Denoising Diffusion Model - A minimal example generating MNIST images

— 14 October 2022

Following up on our blog post on the simplified diffusion loss (without a detour through the ELBO and KL divergences), we wrote a minimal model to show that the core ideas can be implemented in only a “few” lines of code. The notebook can be found on our public repo at:

Excerpts from the notebook

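import torch
from torch import nn

# BatchLoss, N0, normalize_pixels and loss_simple are helpers defined elsewhere
# in the notebook. The constructor (not shown) sets self.T, self.alpha_bar
# (indexed 0..T), self.device and the noise-prediction network self.model.

class DiffusionModel(nn.Module, BatchLoss):

    def q_t(self, t, x_0, eps):
        r"""Forward model $q(x_t \mid x_0)$ from Eq. 1 in our blog post."""
        t = t.view(-1, 1, 1, 1)  # shape (B, 1, 1, 1) so alpha_bar[t] broadcasts against x_0
        a_bar = self.alpha_bar[t]
        return torch.sqrt(a_bar) * x_0 + torch.sqrt(1 - a_bar) * eps

    def batch_loss(self, b):
        """Computes the simplified loss for a whole batch."""
        device = self.device

        x0, labels = b  # images, labels
        B, _, _, _ = x0.size()  # batch_size, channels, height, width

        t = torch.randint(1, self.T + 1, size=(B,), device=device)  # t ~ U{1, ..., T}
        x0 = normalize_pixels(x0)
        eps = N0(x0.size()).to(device)  # Gaussian noise with the same shape as x0
        xt = self.q_t(t=t, x_0=x0, eps=eps)
        eps_hat = self.model(xt, t, labels)  # label-conditioned noise prediction

        L = loss_simple(eps_hat, eps, t)
        return L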
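The excerpt relies on pieces that are defined elsewhere in the notebook: the noise schedule `alpha_bar`, `normalize_pixels` and `loss_simple`. The NumPy sketch below fills them in so that the forward process and the simplified loss can be run end to end. It makes several assumptions, and these are not necessarily the notebook's choices. The schedule is a linear β schedule from 1e-4 to 0.02 over T = 1000 steps, matching the values in Ho et al. (2020). Pixels are rescaled from [0, 1] to [-1, 1]. The simplified loss is taken to be an unweighted mean-squared error. `make_alpha_bar` and the helper bodies are hypothetical.

```python
import numpy as np


def make_alpha_bar(T=1000, beta_start=1e-4, beta_end=0.02):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linear beta schedule.
    alpha_bar[0] = 1 so that indices 1..T match t ~ U{1, ..., T}."""
    betas = np.linspace(beta_start, beta_end, T)
    return np.concatenate([[1.0], np.cumprod(1.0 - betas)])


def normalize_pixels(x):
    """Hypothetical helper: rescale pixel values from [0, 1] to [-1, 1]."""
    return 2.0 * x - 1.0


def q_t(alpha_bar, t, x_0, eps):
    """Sample x_t ~ q(x_t | x_0) by reparameterization; t has shape (B,)."""
    a = alpha_bar[t].reshape(-1, 1, 1, 1)  # (B, 1, 1, 1) broadcasts over (C, H, W)
    return np.sqrt(a) * x_0 + np.sqrt(1.0 - a) * eps


def loss_simple(eps_hat, eps, t=None):
    """Hypothetical helper: unweighted MSE between predicted and true noise.
    t is accepted to match the call in batch_loss but is not used here."""
    return float(np.mean((eps_hat - eps) ** 2))


rng = np.random.default_rng(0)
T = 1000
alpha_bar = make_alpha_bar(T)
x0 = normalize_pixels(rng.random((8, 1, 28, 28)))  # random stand-in for an MNIST batch
t = rng.integers(1, T + 1, size=8)                  # t ~ U{1, ..., T}
eps = rng.standard_normal(x0.shape)                 # eps ~ N(0, I)
xt = q_t(alpha_bar, t, x0, eps)

# A "model" that always predicts zero noise: its loss estimates E[eps^2], i.e. about 1.
baseline = loss_simple(np.zeros_like(eps), eps, t)
print(f"zero-prediction loss: {baseline:.3f}")
```

Padding `alpha_bar` with a leading 1 keeps the indexing consistent with `torch.randint(1, self.T + 1, ...)` in `batch_loss`, which draws t from {1, ..., T}. With this schedule `alpha_bar[T]` is on the order of 1e-5, so x_T is almost pure noise.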

Example rollouts (samples) from our model:
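The excerpt does not show how these rollouts are sampled. A common choice is the ancestral sampler of Ho et al. (2020, Algorithm 2), and a minimal notebook plausibly uses something like it. The sampler starts from pure noise $x_T$ and repeatedly applies $x_{t-1} = \big(x_t - \beta_t / \sqrt{1 - \bar\alpha_t}\,\hat\epsilon(x_t, t)\big) / \sqrt{\alpha_t} + \sigma_t z$ with $\sigma_t^2 = \beta_t$. The sketch below assumes a linear β schedule (1e-4 to 0.02, T = 1000). Here `eps_model(x_t, t)` is a hypothetical stand-in for the notebook's label-conditioned `self.model(xt, t, labels)`, with labels omitted. No trained network is available here, so the demo uses an oracle for a toy dataset in which every image is the constant c = 0.5. For that dataset the exact noise is $(x_t - \sqrt{\bar\alpha_t}\,c)/\sqrt{1 - \bar\alpha_t}$.

```python
import numpy as np


def linear_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Return (betas, alphas, alpha_bar), indexed 1..T; index 0 is padding."""
    betas = np.concatenate([[0.0], np.linspace(beta_start, beta_end, T)])
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)  # alpha_bar[0] == 1
    return betas, alphas, alpha_bar


def sample(eps_model, shape, T=1000, seed=0):
    """DDPM ancestral sampling with sigma_t^2 = beta_t.

    eps_model(x_t, t) returns the predicted noise; it is a hypothetical
    stand-in for the notebook's label-conditioned self.model(xt, t, labels).
    """
    rng = np.random.default_rng(seed)
    betas, alphas, alpha_bar = linear_schedule(T)
    x = rng.standard_normal(shape)  # x_T ~ N(0, I)
    for t in range(T, 0, -1):
        eps_hat = eps_model(x, t)
        mean = (x - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps_hat) / np.sqrt(alphas[t])
        # Add fresh Gaussian noise on every step except the last one (t = 1).
        z = rng.standard_normal(shape) if t > 1 else 0.0
        x = mean + np.sqrt(betas[t]) * z
    return x


# Toy check: if every training image were the constant c, the exact noise
# given x_t is (x_t - sqrt(alpha_bar_t) * c) / sqrt(1 - alpha_bar_t).
c = 0.5
_, _, ab = linear_schedule(1000)


def oracle_eps(x_t, t):
    return (x_t - np.sqrt(ab[t]) * c) / np.sqrt(1.0 - ab[t])


samples = sample(oracle_eps, shape=(2, 1, 28, 28))
print(samples.min(), samples.max())
```

With the oracle, the final step lands on c. At t = 1 we have $\bar\alpha_1 = \alpha_1 = 1 - \beta_1$ and no noise is added, so the update reduces to $x_0 = c$. With a trained network, `eps_model` would simply wrap the model's forward pass.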