Following up on our blog post on the simplified diffusion loss (derived without a detour through the ELBO and KL divergences), we wrote a minimal model highlighting that the core ideas can be implemented in only a “few” lines of code. The notebook can be found on our public repo at:
Excerpts from the notebook
class DiffusionModel(nn.Module, BatchLoss):
    ...

    def q_t(self, t, x_0, eps):
        r"""
        Forward model $q(x_t \mid x_0)$
        from Eq. 1 in our blog post.
        """
        t = t.view(-1, 1, 1, 1)  # prepare for broadcasting to x_0
        return sqrt(self.alpha_bar[t]) * x_0 + sqrt(1 - self.alpha_bar[t]) * eps

    def batch_loss(self, b):
        """
        Computes the simplified loss for a whole batch.
        """
        device = self.device
        x0, labels = b  # img, labels
        B, _, _, _ = x0.size()  # batch_size, channels, height, width
        # one random timestep in 1..T per sample
        t = torch.randint(1, self.T + 1, size=(B,), device=device)
        x0 = normalize_pixels(x0)
        eps = N0(x0.size()).to(device)  # noise with the same shape as x0
        xt = self.q_t(t=t, x_0=x0, eps=eps)  # noised images
        eps_hat = self.model(xt, t, labels)  # predicted noise
        L = loss_simple(eps_hat, eps, t)
        return L
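The excerpt relies on pieces defined elsewhere in the notebook (`self.alpha_bar`, `N0`, `loss_simple`). To make the data flow concrete, here is a rough, dependency-free sketch of what those pieces might compute. Everything in it is an assumption for illustration, not the notebook's code: the linear beta schedule and its endpoints, the placeholder entry at index 0 (so that t = 1..T can be used as an index directly), a plain mean squared error as the simple loss (ignoring `t`), and the hypothetical names `make_alpha_bar`, `forward_noise`, and `noise_mse`.

```python
import math
import random


def make_alpha_bar(T, beta_start=1e-4, beta_end=0.02):
    """Cumulative products of (1 - beta_t) for an assumed linear beta schedule.

    Index 0 is a placeholder (alpha_bar[0] = 1.0, i.e. no noise), so the list
    has T + 1 entries and can be indexed with t = 1..T.
    """
    alpha_bar = [1.0]
    for i in range(T):
        beta = beta_start + (beta_end - beta_start) * i / max(T - 1, 1)
        alpha_bar.append(alpha_bar[-1] * (1.0 - beta))
    return alpha_bar


def forward_noise(alpha_bar, t, x_0, eps):
    """x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, element-wise."""
    a = alpha_bar[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x_0, eps)]


def noise_mse(eps_hat, eps):
    """Mean squared error between predicted and true noise."""
    return sum((p - e) ** 2 for p, e in zip(eps_hat, eps)) / len(eps)


T = 1000
alpha_bar = make_alpha_bar(T)

rng = random.Random(0)
x_0 = [-1.0, 0.0, 0.5, 1.0]               # toy "pixels", already scaled to [-1, 1]
eps = [rng.gauss(0.0, 1.0) for _ in x_0]  # standard normal noise
t = rng.randint(1, T)                     # inclusive on both ends: t in 1..T
x_t = forward_noise(alpha_bar, t, x_0, eps)

# A stand-in "model" that always predicts zero noise, just to exercise the loss.
eps_hat = [0.0 for _ in x_t]
loss = noise_mse(eps_hat, eps)
```

Note the differing conventions for the upper bound: `torch.randint(1, self.T + 1, ...)` excludes its upper bound, while `random.Random.randint(1, T)` includes it, so both draw t from 1..T.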