We recommend using GPUs from Colab to finish this
project!
Overview
In part B you will train your own diffusion model on MNIST. Starter code can
be found in the provided notebook.
START EARLY!
This project, in many ways, will be the most
difficult project this semester.
Part 1: Training a Single-Step Denoising UNet
1.0 Problem Formulation
Let's warm up by building a simple one-step denoiser. Given a noisy image $z$, we
aim to train a denoiser $D_\theta$ such that it maps $z$ to a clean
image $x$. To do so, we can optimize over an L2 loss:
$$L = \mathbb{E}_{z,x} \|D_\theta(z) - x\|^2 \tag{1}$$
1.1 Implementing the UNet
In this project, we implement the denoiser as a UNet. It consists of a
few downsampling and upsampling blocks with skip connections.
Figure 1: Unconditional UNet
The diagram above uses a number of standard tensor operations defined as follows:
Figure 2: UNet Operations
Note:
(1) Conv doesn't change the image resolution, only the channel dimension.
(2) DownConv downsamples the tensor by 2.
(3) UpConv upsamples the tensor by 2.
(4) Flatten flattens a 7x7 tensor into a 1x1 tensor. 7 is the resulting height and width after the downsampling operations.
(5) Unflatten unflattens a 1x1 tensor into a 7x7
tensor.
(6) Concat is a simple channel-wise concatenation between tensors with the same 2D shape. This is simply torch.cat.
D is the number of hidden channels and is a hyperparameter that we will set ourselves.
We define composed operations using our simple operations in order to make our network deeper. This doesn't change the tensor's height, width, or number of channels, but simply adds more learnable parameters.
(7) ConvBlock is similar to Conv but includes an additional Conv. Note that it has the same input and output shape as (1) Conv.
(8) DownBlock is similar to DownConv but includes an additional ConvBlock. Note that it has the same input and output shape as (2) DownConv.
(9) UpBlock is similar to UpConv but includes an additional ConvBlock. Note that it has the same input and output shape as (3) UpConv.
Within the simple operations:
Conv2d(kernel_size, stride, padding) is nn.Conv2d(kernel_size, stride, padding)
BN is nn.BatchNorm2d
GELU is nn.GELU()
Conv2d⁻¹(kernel_size, stride, padding) is nn.ConvTranspose2d(kernel_size, stride, padding)
AvgPool(kernel_size) is nn.AvgPool2d(kernel_size)
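The operations above could be sketched in PyTorch roughly as follows. Figures 1 and 2 are not reproduced in this text, so every kernel size, stride, padding, and channel width below is an assumption chosen only to match the shape behavior described in notes (1) through (9); the helper names (`conv`, `down_conv`, and so on) are hypothetical as well.

```python
import torch
import torch.nn as nn


def conv(c_in: int, c_out: int) -> nn.Sequential:
    """(1) Conv: changes channels only (assumed 3x3 kernel, stride 1, padding 1)."""
    return nn.Sequential(nn.Conv2d(c_in, c_out, 3, 1, 1), nn.BatchNorm2d(c_out), nn.GELU())


def down_conv(c_in: int, c_out: int) -> nn.Sequential:
    """(2) DownConv: halves H and W (assumed 3x3 kernel, stride 2, padding 1)."""
    return nn.Sequential(nn.Conv2d(c_in, c_out, 3, 2, 1), nn.BatchNorm2d(c_out), nn.GELU())


def up_conv(c_in: int, c_out: int) -> nn.Sequential:
    """(3) UpConv: doubles H and W (assumed 4x4 kernel, stride 2, padding 1)."""
    return nn.Sequential(nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.GELU())


def flatten() -> nn.Sequential:
    """(4) Flatten: 7x7 -> 1x1 via average pooling (the trailing GELU is an assumption)."""
    return nn.Sequential(nn.AvgPool2d(7), nn.GELU())


def unflatten(c_in: int, c_out: int) -> nn.Sequential:
    """(5) Unflatten: 1x1 -> 7x7 (assumed 7x7 kernel, stride 7, padding 0)."""
    return nn.Sequential(nn.ConvTranspose2d(c_in, c_out, 7, 7, 0), nn.BatchNorm2d(c_out), nn.GELU())


def conv_block(c_in: int, c_out: int) -> nn.Sequential:
    """(7) ConvBlock: Conv followed by an additional Conv."""
    return nn.Sequential(conv(c_in, c_out), conv(c_out, c_out))


def down_block(c_in: int, c_out: int) -> nn.Sequential:
    """(8) DownBlock: DownConv followed by a ConvBlock."""
    return nn.Sequential(down_conv(c_in, c_out), conv_block(c_out, c_out))


def up_block(c_in: int, c_out: int) -> nn.Sequential:
    """(9) UpBlock: UpConv followed by a ConvBlock."""
    return nn.Sequential(up_conv(c_in, c_out), conv_block(c_out, c_out))


# Shape walk-through on a fake MNIST-sized batch (channel widths are illustrative).
D = 8
x = torch.randn(2, 1, 28, 28)
h0 = conv_block(1, D)(x)                 # 28x28, D channels
h1 = down_block(D, D)(h0)                # 14x14
h2 = down_block(D, 2 * D)(h1)            # 7x7
flat = flatten()(h2)                     # 1x1
unflat = unflatten(2 * D, 2 * D)(flat)   # back to 7x7
cat = torch.cat([unflat, h2], dim=1)     # (6) Concat: channel-wise skip connection
up = up_block(4 * D, D)(cat)             # 14x14
```

The walk-through at the bottom is only a shape check; how the blocks are wired together and which channel widths they use should follow Figure 1.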
1.2 Using the UNet to Train a Denoiser
Recall from equation 1 that we aim to solve the following denoising
problem:
Given a noisy image $z$, we
aim to train a denoiser $D_\theta$ such that it maps $z$ to a clean
image $x$. To do so, we can optimize over an L2 loss:
$$L = \mathbb{E}_{z,x} \|D_\theta(z) - x\|^2$$
To train our denoiser, we need to generate training data pairs of ($z$, $x$), where each $x$ is a clean MNIST digit. For each training batch, we can generate $z$ from $x$ using the following noising process:
$$z = x + \sigma \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$
Visualize the noising process over a range of noise levels $\sigma$, assuming normalized $x$.
It should be similar to the following plot:
Figure 3. Varying levels of noise on MNIST digits
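As a rough illustration of the noising process, the sketch below adds Gaussian noise scaled by a noise level to a batch of images. The function name `add_noise`, the random stand-in batch, and the list of noise levels are all assumptions for illustration; they are not the assignment's values.

```python
import torch


def add_noise(x: torch.Tensor, sigma: float) -> torch.Tensor:
    """Return z = x + sigma * eps, with eps drawn from N(0, I)."""
    return x + sigma * torch.randn_like(x)


torch.manual_seed(0)
x = torch.rand(4, 1, 28, 28)            # stand-in for a batch of normalized MNIST digits
sigmas = [0.0, 0.25, 0.5, 0.75, 1.0]    # illustrative noise levels only
noisy = [add_noise(x, s) for s in sigmas]  # one noised copy of the batch per level
```

Each entry of `noisy` can then be plotted as one column of a grid, with the clean digit at the lowest noise level and the noisiest version at the highest.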
1.2.1 Training
Now, we will train the model to perform denoising.
Objective: Train a denoiser to denoise a noisy image $z$, produced by applying noise at a fixed level $\sigma$ to a clean image $x$.
Dataset and dataloader: Use the MNIST dataset via torchvision.datasets.MNIST with flags to access training and test sets. Shuffle the dataset before creating the dataloader. Recommended batch size: 256. We'll train over our dataset for 5 epochs.
We recommend only noising the image batches when fetched from the dataloader.
Model: Use the UNet architecture defined in section 1.1 with recommended hidden dimension of 128 (this is D in the diagrams above).
Optimizer: Use Adam optimizer with learning rate of 1e-4.
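The settings above could be combined into a training loop along the following lines. This is a sketch under stated assumptions: a single convolution and random tensors stand in for the UNet and the MNIST dataloader so the snippet runs anywhere, `train_denoiser` is a hypothetical name, and `sigma=0.5` is a placeholder noise level.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def train_denoiser(model, loader, sigma, num_epochs, lr=1e-4, device="cpu"):
    """Train `model` to map noisy images back to clean ones with an L2 loss."""
    model.to(device).train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for epoch in range(num_epochs):
        for x, _label in loader:
            x = x.to(device)
            # Noise the batch only when it is fetched, so each epoch sees fresh noise.
            z = x + sigma * torch.randn_like(x)
            loss = F.mse_loss(model(z), x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())  # one point per iteration for the loss curve
    return losses


# Stand-ins so the sketch is self-contained (replace with MNIST and the UNet).
torch.manual_seed(0)
fake_images = torch.rand(64, 1, 28, 28)
fake_labels = torch.zeros(64, dtype=torch.long)
loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=16, shuffle=True)
model = nn.Conv2d(1, 1, kernel_size=3, padding=1)
losses = train_denoiser(model, loader, sigma=0.5, num_epochs=1)
```

Recording the loss at every iteration keeps the data needed for the training loss curve without a separate logging pass.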
Figure 4. Training Loss Curve
You should visualize denoised results on the test set at the end of
training. Display sample results after the 1st and 5th epoch.
They should look something like these:
Figure 5. Results on digits from the test set after 1 epoch of training
Figure 6. Results on digits from the test set after 5 epochs of training
1.2.2 Out-of-Distribution Testing
Our denoiser was trained on MNIST digits noised at a single noise level $\sigma$. Let's see how the denoiser performs on other values of $\sigma$ that it wasn't trained for.
Visualize the denoiser results on test set digits with varying levels of noise $\sigma$.
Figure 7. Results on digits from the test set with varying noise levels.
1.3 Deliverables
A visualization of the noising process over different values of $\sigma$ (figure 3).
A training loss curve plot every few iterations during the whole
training process (figure 4). That means you are in charge of:
Successfully implementing the UNet.
Dataset and dataloader creation on MNIST with train/test
split.
Optimizer and model creation.
Training loop.
Sample results on the test set after the 1st and the 5th epoch
(staff solution takes ~7 minutes for 5 epochs on a Colab T4 GPU)
(figures 5 and 6).
Sample results on the test set with out-of-distribution noise levels after the model is trained. Keep the same image and
vary $\sigma$ (figure 7).
Hint
Since training can take a while, we strongly recommend that you
checkpoint your model every epoch onto your personal Google
Drive.
This is because Colab sessions aren't persistent: if you are
idle for a while, you will lose your connection and your training progress.
This consists of:
Google Drive mounting.
Epoch-wise model & optimizer checkpointing.
Model & optimizer resuming from checkpoints.
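One possible shape for epoch-wise checkpointing and resuming is sketched below. The helper names `save_checkpoint` and `load_checkpoint` and the file name are hypothetical, a temporary directory stands in for the Drive folder, and the Drive mount is shown only as a comment because it works only inside Colab.

```python
import os
import tempfile

import torch
import torch.nn as nn

# On Colab, mount Drive first and point ckpt_dir at a folder inside it:
#   from google.colab import drive
#   drive.mount("/content/drive")


def save_checkpoint(path, model, optimizer, epoch):
    """Save model and optimizer state together with the epoch just finished."""
    torch.save(
        {"epoch": epoch, "model": model.state_dict(), "optimizer": optimizer.state_dict()},
        path,
    )


def load_checkpoint(path, model, optimizer):
    """Restore state if a checkpoint exists; return the epoch to start from."""
    if not os.path.exists(path):
        return 0
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"] + 1


# Tiny stand-in model so the sketch runs without the UNet.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
ckpt_dir = tempfile.mkdtemp()  # stand-in for a Google Drive folder
ckpt_path = os.path.join(ckpt_dir, "denoiser.pt")

start_epoch = load_checkpoint(ckpt_path, model, optimizer)
for epoch in range(start_epoch, 2):
    # ... one epoch of training would go here ...
    save_checkpoint(ckpt_path, model, optimizer, epoch)
resumed_epoch = load_checkpoint(ckpt_path, model, optimizer)
```

Saving the optimizer state alongside the model matters for Adam, since its running moment estimates would otherwise restart from zero after a resume.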
Part 2: Training a DDPM Denoising UNet
Now, we are ready for diffusion, where we can iteratively denoise the image. We will implement DDPM in this part.
One small change is that we're going to change our UNet from part 1 to predict the noise instead of the clean image (like in part 4A of the project).
2.0 Problem Formulation
Let's reconsider the problem in part 1, but taken to its extreme:
Given a pure noise image $\epsilon \sim \mathcal{N}(0, I)$, we aim to train a denoiser $D_\theta$
such that it maps the noise image $\epsilon$ to a clean image $x$.
To do so, we can still apply a simple L2 loss:
$$L = \mathbb{E}_{\epsilon,x} \|D_\theta(\epsilon) - x\|^2$$
The difference here, compared to part 1, is that $\epsilon$ is
pure noise.
If we can learn to remove pure noise, this will allow us to generate novel images, not just those in our training set.
However, we saw in part A that one-step denoising does not yield good results. Instead, we can iteratively denoise the image for better results.
For iterative denoising, we condition our model on the timestep $t$ such that it can learn time-specific denoising. We can equivalently predict the noise added to the image rather than the denoised image itself.
For now, the coefficients that scale the image and the noise can be thought of as some arbitrary function of the timestep.
You can imagine that, with a time-conditioned denoising UNet, we can go from one-step denoising to iterative denoising:
We can therefore perform iterative denoising to get better results than a one-step denoiser as shown in part 1, which is especially useful when our noisy inputs are pure Gaussian noise.
In practice, our model predicts the entire noise added to the clean image to get the noisy image at a given timestep, rather than an intermediate amount of noise, but the schedule coefficients scale this prediction appropriately so that we recover the intermediate noisy samples at each timestep.
Additionally, because of this scaling we cannot directly subtract the noise as shown above, but will instead do so under a different process shown later.
2.1 Refactoring Your Unconditional UNet for DDPM
In order to do iterative denoising, we first need to add a timestep condition $t$
into our model.
We will also add a class-label condition $c$ into our model for when we
later do class-conditioned denoising with classifier-free guidance.
Let's first define a new operator called
FCBlock (fully-connected block):
Figure 8. FCBlock for conditioning
Here L(F_in, F_out) is a linear layer with
F_in input features and F_out output
features. You can implement it using nn.Linear.
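A possible FCBlock is sketched below. Figure 8 is not reproduced in this text, so the exact layer order is an assumption: here a linear layer, a GELU, and a second linear layer.

```python
import torch
import torch.nn as nn


class FCBlock(nn.Module):
    """Small MLP that embeds a conditioning vector (assumed Linear -> GELU -> Linear)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.GELU(),
            nn.Linear(out_features, out_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


# A scalar timestep per sample gives F_in = 1; a one-hot MNIST label gives F_in = 10.
fc_t = FCBlock(1, 16)
fc_c = FCBlock(10, 16)
t_emb = fc_t(torch.rand(4, 1))
c_emb = fc_c(torch.eye(10)[:4])
```

The output width (16 here, purely illustrative) has to match the channel count of the feature map that the embedding will modulate.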
To condition our network on time and class-label, we can apply conditioning after unflattening and the first upsample:
Figure 9. Conditional UNet
You can embed $t$ and $c$ by following this pseudo code:
fc1_t = FCBlock(...)
fc1_c = FCBlock(...)
fc2_t = FCBlock(...)
fc2_c = FCBlock(...)
t1 = fc1_t(t)
c1 = fc1_c(c)
t2 = fc2_t(t)
c2 = fc2_c(c)
# Follow diagram to get unflatten.
# Replace the original unflatten with modulated unflatten.
unflatten = c1 * unflatten + t1
# Follow diagram to get up1.
...
# Replace the original up1 with modulated up1.
up1 = c2 * up1 + t2
# Follow diagram to get the output.
...
Note that the class-label should be encoded as a one-hot vector.
Another modification you need to make is to have the model take in a
batch-wise mask vector whose entries are either 0 or 1.
It indicates whether or not to drop the condition $c$: drop it when the mask
is 0, keep it when the mask is 1.
The mask can be null, which means the condition is always used. This is so we
can perform line 6 of algorithm 1.
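The one-hot encoding, the mask, and the scale-and-shift modulation from the pseudo code could look roughly like this. The helper names `encode_condition` and `modulate` are hypothetical, and reshaping the embeddings to `(B, C, 1, 1)` so they broadcast over height and width is an assumption about how the embeddings meet the feature maps.

```python
import torch
import torch.nn.functional as F


def encode_condition(labels, num_classes=10, mask=None):
    """One-hot encode class labels; zero out rows whose mask entry is 0."""
    c = F.one_hot(labels, num_classes=num_classes).float()
    if mask is not None:  # a null mask means the condition is always kept
        c = c * mask.view(-1, 1).to(c.dtype)
    return c


def modulate(feature, c_emb, t_emb):
    """Scale by the class embedding and shift by the time embedding."""
    b, ch = feature.shape[:2]
    return c_emb.view(b, ch, 1, 1) * feature + t_emb.view(b, ch, 1, 1)


labels = torch.tensor([3, 7])
mask = torch.tensor([1.0, 0.0])  # keep the first label, drop the second
c = encode_condition(labels, mask=mask)
c_all = encode_condition(labels)  # no mask: nothing is dropped

feature = torch.ones(2, 3, 7, 7)
out = modulate(feature, c_emb=torch.full((2, 3), 2.0), t_emb=torch.ones(2, 3))
```

Zeroing the one-hot vector is one simple way to represent a dropped condition: the network then sees the same all-zero input whenever the class is unknown.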
2.2 Implementing DDPM Forward and Reverse Process
Now that we have some intuition from part 2.0, it's time to implement
the forward and reverse process of DDPM.
DDPM considers a very specific noising and denoising process:
Figure 10: DDPM Markov chain. The forward process is
denoted by $q(x_t \mid x_{t-1})$ and the reverse process is denoted
by $p_\theta(x_{t-1} \mid x_t)$.
(Image source: Ho et al. 2020 with a few additional annotations
from Lilian Weng)
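Specifically, each forward step adds Gaussian noise in a
variance-preserving way for some variance schedule
$\beta_1, \ldots, \beta_T$:
$$q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t;\ \sqrt{1 - \beta_t}\, x_{t-1},\ \beta_t I\right)$$
Using the reparameterization trick presented in section 2 of the DDPM
paper, we can compute an effective one-step noising function, since a
Gaussian convolved with a Gaussian is still Gaussian (see here
for more details).
Concretely, let $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$, then we can sample a noisy $x_t$ for an
arbitrary $t$:
$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \tag{5}$$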
DDPM Scheduler
Let's first implement the DDPM scheduler to fetch all relevant
variables.
Given the schedule parameters, follow the docstring to compute all the useful
values.
You will use them in a bit!
TODO: Implement ddpm_schedule()
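As a rough guide, a schedule function might precompute the per-timestep quantities like this. The starter code's docstring is not reproduced here, so the signature, the linear spacing of the betas, the returned keys, and the example values (`1e-4`, `0.02`, `300`) are all assumptions.

```python
import torch


def ddpm_schedule(beta1: float, beta2: float, num_ts: int) -> dict:
    """Precompute DDPM constants; index t - 1 holds the values for timestep t."""
    assert 0.0 < beta1 < beta2 < 1.0, "expected 0 < beta1 < beta2 < 1"
    betas = torch.linspace(beta1, beta2, num_ts)   # assumed linear variance schedule
    alphas = 1.0 - betas                           # alpha_t = 1 - beta_t
    alpha_bars = torch.cumprod(alphas, dim=0)      # running product of the alphas
    return {"betas": betas, "alphas": alphas, "alpha_bars": alpha_bars}


schedule = ddpm_schedule(1e-4, 0.02, 300)  # illustrative values only
```

Precomputing these once avoids recomputing the cumulative product inside every training and sampling step.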
DDPM Forward Process
For brevity, we don't show the mathematical details
here. If you'd like to see the mathematical details, check out here.
TODO: Implement our ddpm_forward() function by
following algorithm 1:
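Algorithm 1 itself is not reproduced in this text. As a rough orientation only, a noise-prediction training step in the spirit of the DDPM paper might look like the sketch below. The `unet(x_t, t, c, mask)` call signature, the normalized timestep input, the drop probability `p_uncond`, and the schedule values are all assumptions, and a zero-output stand-in replaces the real UNet.

```python
import torch
import torch.nn.functional as F


def ddpm_forward(unet, alpha_bars, x_0, c, p_uncond, num_ts):
    """Return the L2 loss between the true and predicted noise for one batch."""
    batch = x_0.shape[0]
    device = x_0.device
    # mask = 0 drops the class condition, mask = 1 keeps it.
    mask = (torch.rand(batch, device=device) >= p_uncond).float()
    # One random timestep in {1, ..., num_ts} per sample.
    t = torch.randint(1, num_ts + 1, (batch,), device=device)
    eps = torch.randn_like(x_0)
    a_bar = alpha_bars.to(device)[t - 1].view(batch, 1, 1, 1)
    # Closed-form forward noising: x_t from x_0 in a single step.
    x_t = a_bar.sqrt() * x_0 + (1.0 - a_bar).sqrt() * eps
    # Assumed interface: the UNet receives a timestep normalized to (0, 1].
    eps_hat = unet(x_t, t.float() / num_ts, c, mask)
    return F.mse_loss(eps_hat, eps)


def zero_unet(x_t, t, c, mask):
    """Stand-in network that always predicts zero noise."""
    return torch.zeros_like(x_t)


torch.manual_seed(0)
num_ts = 300  # illustrative
alpha_bars = torch.cumprod(1.0 - torch.linspace(1e-4, 0.02, num_ts), dim=0)
x_0 = torch.rand(64, 1, 28, 28)
c = F.one_hot(torch.randint(0, 10, (64,)), num_classes=10).float()
loss = ddpm_forward(zero_unet, alpha_bars, x_0, c, p_uncond=0.1, num_ts=num_ts)
```

With the zero-output stand-in, the loss is simply the mean squared value of the sampled noise, which gives a quick sanity check before plugging in the real model.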
DDPM Reverse Process
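Recall that in the reverse process we progressively work backwards to
reconstruct the original image $x_0$ from noise $x_T$.
We can sample from this process following:
$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\left(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t I\right) \tag{6}$$
We can think of $\tilde{\mu}_t(x_t, x_0)$ as a linear
combination of $x_t$ and $x_0$:
$$\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1 - \bar{\alpha}_t}\, x_0 + \frac{\sqrt{\alpha_t}\, (1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t \tag{7}$$
Using the same reparameterization trick from equation 5, we can solve
for $x_0$:
$$x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} \left(x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon\right)$$
If we plug this back into equation 7, we get:
$$\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}} \left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon\right)$$
Thus, by using the reparameterization trick and following equation
6, we can cleanly represent $p_\theta(x_{t-1} \mid x_t)$ as:
$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\left(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\right)$$
To model this, we train a network $\mu_\theta(x_t, t)$. We don't
need to learn the $\Sigma_\theta$ term (since it is defined by the
schedule), therefore we only optimize $\mu_\theta$.
Similarly, we have access to the $x_t$ and $\alpha_t$ terms,
so we only need to optimize a network $\epsilon_\theta(x_t, t)$ to
predict the noise term.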
TODO: Implement the reverse sampling process
ddpm_sample():
Again, if you are interested in the math, check out here.
Some important details:
At each denoising iteration, you want to use a different
random seed. A good strategy for reproducibility is to
first sample $x_T$ with the input seed, then shift the
seed by 1 at each denoising step.
When sampling with CFG, you will need to forward the
unet twice, once with mask =
torch.ones(...) and once with mask =
torch.zeros(...). You could do it in one batch, but it
is easy to make a mistake, so I would suggest just forwarding it
twice.
Note that we will return caches. This is just a
cache of some intermediate sampling results for animation. How
to cache this is your choice, but we suggest the following
way:
if t[0] % 20 == 0 or t[0] == num_ts or t[0] < 8:
    caches.append(x)
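Putting the details above together, a sampling loop in the style of the DDPM paper might be sketched as follows. Several choices here are assumptions rather than the starter code's definitions: the `unet(x, t, c, mask)` interface with a normalized timestep, the guidance formula (unconditional estimate plus the scaled difference to the conditional one), the use of the schedule's beta as the sampling variance, and the schedule values. A zero-output stand-in replaces the trained UNet.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def ddpm_sample(unet, schedule, c, img_shape, num_ts, guidance_scale=5.0, seed=0):
    """Iteratively denoise from pure noise; return the final image and cached frames."""
    batch, device = c.shape[0], c.device

    def seeded_noise(s):
        # A fresh generator per call makes every step reproducible from its seed.
        g = torch.Generator().manual_seed(s)
        return torch.randn((batch, *img_shape), generator=g).to(device)

    x = seeded_noise(seed)  # starting pure-noise image, drawn with the input seed
    ones = torch.ones(batch, device=device)
    zeros = torch.zeros(batch, device=device)
    caches = []
    for step in range(num_ts, 0, -1):
        seed += 1  # shift the seed by 1 at each denoising step
        beta = schedule["betas"][step - 1]
        alpha = schedule["alphas"][step - 1]
        a_bar = schedule["alpha_bars"][step - 1]
        t = torch.full((batch,), step / num_ts, device=device)
        # Two separate forward passes: with and without the class condition.
        eps_cond = unet(x, t, c, ones)
        eps_uncond = unet(x, t, c, zeros)
        eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
        # No fresh noise is added on the final step.
        z = seeded_noise(seed) if step > 1 else torch.zeros_like(x)
        x = (x - (1.0 - alpha) / (1.0 - a_bar).sqrt() * eps) / alpha.sqrt() + beta.sqrt() * z
        if step % 20 == 0 or step == num_ts or step < 8:
            caches.append(x)
    return x, caches


def zero_unet(x, t, c, mask):
    """Stand-in network that always predicts zero noise."""
    return torch.zeros_like(x)


num_ts = 40  # kept small so the sketch runs quickly
betas = torch.linspace(1e-4, 0.02, num_ts)
schedule = {"betas": betas, "alphas": 1.0 - betas, "alpha_bars": torch.cumprod(1.0 - betas, dim=0)}
c = F.one_hot(torch.arange(4), num_classes=10).float()
sample, caches = ddpm_sample(zero_unet, schedule, c, (1, 28, 28), num_ts, seed=0)
```

Under this guidance formula, a scale of 0 reduces to the unconditional prediction and a scale of 1 to the plain conditional one, which is a useful check when comparing outputs across guidance scales.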
2.3 Putting It All Together
We have all the pieces, let's now train our diffusion model.
Please consider this pseudo code for your training step.
Use the same network hyperparameters as in part 1.
You might need to increase num_epochs to 20 to get
good results (staff solution takes ~26 minutes on a Colab T4 GPU).
2.4 Deliverables
Your deliverables should include the following for this problem:
A training loss curve plot every few iterations during the
whole training process.
Sample results after the 1st, 5th, and your final epoch,
with guidance_scale = 5.
Sample results after the model is trained,
guidance_scale = [0, 5, 10].
Describe what you see and what you think CFG is doing.
For reference, here are the staff solution results (without skip
connections) for epochs 1, 5, 10, 15, and 20 with guidance scale
5.0.
Note: you do not need to generate gifs (this can be done as a Bells &
Whistles item below).
Epoch 1
Epoch 5
Epoch 10
Epoch 15
Epoch 20
Bells & Whistles
Sampling Gifs (.1 Cookie Points)
Create your own sampling gifs similar to the ones shown above.
Improve the UNet Architecture (.15 Cookie Points)
For ease of explanation and implementation, our UNet architecture
above is pretty simple.
Modify the UNet (e.g. with skip connections) such that it can fit
better during training.
Implement Rectified Flow (.25 Cookie Points)
Implement rectified flow, which is a state-of-the-art diffusion model.
You can reference any code on github, but your implementation needs to follow the same code structure as our DDPM implementation.
In other words, the code change required should be minimal: only changing the forward and sample functions.