{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Discrete Denoising Diffusion Models\n", "\n", "Denoising diffusion models (DDMs) are currently the state-of-the-art approach for image generation, but can they be used for generating discrete data like language and protein sequences? We derived the basic principles for DDMs with continuous-valued data last week. Here, we will show how these concepts extend to discrete data, and how concepts like continuous-time diffusion and the score function in the reverse diffusion SDE extend to the discrete setting." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Discrete-Time, Discrete-State DDMs\n", "\n", "We will start in discrete time and develop a DDM for discrete-valued data\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup\n", "- Let $x_0 \\in \\cX$ denote a data point.\n", "- Let $|\\cX| = S < \\infty$ is the vocabulary size.\n", "- Let $q_0(x_0)$ denote the data distribution\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Noising process\n", "\n", "Use a Markov chain that gradually converts $x_0$ to $x_T \\sim q_T(x_T)$, which is pure \"noise,\"\n", "\n", "\\begin{align*}\n", "q(x_{0:T}) &= q_0(x_0) \\prod_{t=1}^T q_{t|t-1}(x_t \\mid x_{t-1})\n", "\\end{align*}\n", "\n", "\n", "For example, $q_T(x_T) = \\mathrm{Unif}_{\\cX}(x_T)$ could be achieved by,\n", "\n", "\\begin{align*}\n", "q_{t|t-1}(x_t \\mid x_{t-1})\n", "&= (1 - \\lambda_t) \\bbI[x_t = x_{t-1}] + \\frac{\\lambda_t}{S-1} \\bbI[x_t \\neq x_{t-1}]\n", "\\end{align*}\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Masking diffusion\n", "\n", "One of the most effective noising processes is the _masking_ diffusion which introduces a $\\mathsf{MASK}$ token that is an absorbing state of the Markov process,\n", "\n", "\\begin{align*}\n", "q_{t|t-1}(x_t \\mid x_{t-1}) \n", "&= (1 - \\lambda_t) \\bbI[x_t = x_{t-1}] + \\lambda_t \\bbI[x_t = \\mathsf{MASK}]\n", "\\end{align*}\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reverse process\n", "\n", "The reverse of the noising process is also a Markov chain! (We can derive this from the graphical model.) It factors as,\n", "\n", "\\begin{align*}\n", "q(x_{0:T}) &= q_T(x_T) \\prod_{t=T-1}^0 q_{t|t+1}(x_t \\mid x_{t+1})\n", "\\end{align*}\n", "\n", "The reverse transition probabilities can be obtained via Bayes' rule,\n", "\n", "\\begin{align*}\n", "q_{t|t+1}(x_t \\mid x_{t+1}) \n", "&= \\frac{q_t(x_t) \\, q_{t+1|t}(x_{t+1} \\mid x_t)}{q_{t+1}(x_{t+1})}\n", "\\end{align*}\n", "\n", "Alternatively, we can express the reverse transition probabilities in terms of the _denoising distributions_ $q_{0|t+1}(x_0 \\mid x_{t+1})$ as follows,\n", "\n", "\\begin{align*}\n", "q_{t|t+1}(x_t \\mid x_{t+1}) \n", "&= q_{t+1|t}(x_{t+1} \\mid x_t) \\frac{\\sum_{x_0} q_{t|0}(x_t \\mid x_0) \\, q_0(x_0)}{q_{t+1}(x_{t+1})} \\\\\n", "&= q_{t+1|t}(x_{t+1} \\mid x_t) \\sum_{x_0} \\frac{q_{t|0}(x_t \\mid x_0)}{q_{t+1|0}(x_{t+1} \\mid x_0)} q_{0|t+1}(x_0 \\mid x_{t+1})\n", "\\end{align*}\n", "\n", ":::{admonition} Explanation\n", ":class: dropdown\n", "In the last line we used the fact that,\n", "\n", "\\begin{align*}\n", "\\frac{q_0(x_0)}{q_{t+1}(x_{t+1})} \n", "&= \n", "\\frac{q_{0|t+1}(x_0 \\mid x_{t+1})}{q_{t+1|0}(x_{t+1} \\mid x_0)},\n", "\\end{align*}\n", "\n", "which follows from the chain rule.\n", ":::" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Approximating the Reverse Process\n", "\n", "**Problem:** We know everything in the reverse transition probability but the denoising distribution $q_{0|t_1}(x_0 \\mid x_{t+1})$. \n", "\n", "**Solution:** Learn it! Parameterize the reverse transition probability as,\n", "\\begin{align*}\n", "p_{t|t+1}(x_t \\mid x_{t+1}; \\theta) &= \n", "q_{t+1|t}(x_{t+1} \\mid x_t) \\sum_{x_0} \\frac{q_{t|0}(x_t \\mid x_0)}{q_{t+1|0}(x_{t+1} \\mid x_0)} p_{0|t+1}(x_0 \\mid x_{t+1}; \\theta)\n", "\\end{align*}\n", "\n", "where $p_{0|t+1}(x_0 \\mid x_{t+1}; \\theta)$ is a **learned, approximate denoising distribution**\n", "\n", "We can then sample from the approximate reverse process one step at a time from $T$ down to $0$,\n", "\\begin{align*}\n", "p(x_{0:T}; \\theta) &= q_T(x_T) \\prod_{t=T-1}^0 p_{t|t+1}(x_t \\mid x_{t+1}; \\theta).\n", "\\end{align*}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The Evidence Lower Bound\n", "\n", "We will estimate the model parameters $\\theta$ by maximizing the ELBO, which is a sum over data points of,\n", "\n", "\\begin{align*}\n", "\\cL(\\theta, x_0) \n", "&= \\E_{q(x_{1:T} \\mid x_0)}\\left[\\log p(x_{0:T}; \\theta) - \\log q(x_{1:T} \\mid x_0) \\right] \\\\\n", "&= \\E_{q(x_{1:T} \\mid x_0)} \\left[ \\sum_{t=0}^{T-1} \\log p_{t|t+1}(x_t \\mid x_{t+1}; \\theta) - \\sum_{t=1}^{T-1} \\log q(x_t \\mid x_{t+1}, x_0) \\right] \\\\\n", "&= \\sum_{t=1}^{T-1} \\E_{q(x_t, x_{t+1} \\mid x_0)} \\left[\\log p_{t|t+1}(x_t \\mid x_{t+1}; \\theta) - \\log q(x_t \\mid x_{t+1}, x_0) \\right] + \\cL_0(\\theta, x_0) \\\\\n", "&= \\sum_{t=1}^{T-1} \\E_{q(x_{t+1} \\mid x_0)} \\left[-\\KL{q(x_t \\mid x_{t+1}, x_0)}{p_{t|t+1}(x_t \\mid x_{t+1}; \\theta)} \\right] + \\cL_0(\\theta, x_0) \n", "\\end{align*}\n", "\n", "where \n", "\n", "\\begin{align*}\n", "\\cL_0(\\theta, x_0) &= \\E_{q(x_1 | x_0)}\\left[\\log p(x_0 \\mid x_1; \\theta) \\right]\n", "+ \\E_{q(x_1 | x_0)}\\left[\\log p(x_0 \\mid x_1; \\theta) \\right]\n", "\\end{align*}\n", "\n", "is the ELBO for the last term in the sum.\n", "\n", "Note that we are again using Rao-Blackwellization to write the ELBO in terms of expectations over fewer random variables for each term in the sum. \n", "\n", ":::{admonition} Important\n", ":class: warning\n", "\n", "We choose $q$ such that the marginal distribution $q_{t|0}(x_t)$ and interpolating distribution $q(x_t \\mid x_{t+1}, x_0)$ are available in closed form!\n", ":::" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Continuous-Time Markov Chains\n", "\n", "We ended our discussion of continuous-state DDMs by noting that in the continuous-time limit the noising and reverse processes are SDEs. For discrete-state DDMs, the continuous-time limit involves **Continuous-Time Markov Chains (CTMCs)**, which are closely related to Poisson processes! " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Properties of CTMCs\n", "\n", "A CTMC is a stochastic process $\\{x_t : t\\in[0,T]\\}$ taking values on a finte state space $\\cX$ such that:\n", "\n", "1. Sample paths $t \\mapsto x_t$ are right-continuous and have finitely many jumps\n", "\n", "2. The **Markov propert** holds:\n", "\n", " \\begin{align*}\n", " \\Pr(x_{t+\\Delta t} = j \\mid \\{x_s: s\\leq t\\}) &= \\Pr(x_{t+\\Delta t}=j \\mid x_t)\n", " \\end{align*}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Transition Probabilities\n", "\n", "CTMCs are uniquely characterized by the transition distributions $q_{t|s}(x_t=j \\mid x_s=i) q_{t|s}(x_t=j \\mid x_s=i)$ for $t \\geq s$. \n", "\n", "These must satisfy the **Chapman-Kolmogorov Equations**\n", "\n", "\\begin{align*}\n", "q_{t|s}(x_t=j \\mid x_s=i) \n", "&= \\sum_{k \\in \\cX} q_{u|s}(x_u=k \\mid x_s=i) q_{t|u}(x_t=j \\mid x_u=k)\n", "\\end{align*}\n", "\n", "for $s \\leq u \\leq t$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Rate matrices\n", "\n", "Equivalently, we can identify a CTMC by its **rate matrices**\n", "\\begin{align*}\n", "R_s(i \\to j) &= \\frac{\\partial}{\\partial t} q_{t|s}(x_t=j \\mid x_s=i) \\bigg|_{t=s}\n", "\\end{align*} \n", "\n", "Intuitively, $R_s(i \\to j)$ is the amount of probability flow from $i$ to $j$ at time $s$.\n", "\n", "Properties of rate matrices:\n", "1. $R_s(i \\to j) \\geq 0$ for $i \\neq j$. (_Probability flow must be outward._)\n", "2. $\\sum_j R_s(i \\to j) = 0$ (_Probability is conserved._)\n", "\n", "**Define** $R_s(i) = \\sum_{j\\neq i} R_s(i \\to j)$ to be the total outward probability flow.\n", "\n", "A **homogenous** CTMC has a fixed rate matrix $R_s \\equiv R$ for all times $s \\in [0,T]$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gillespie's Algorithm\n", "\n", "Consider a homogenous CTMC with rate matrix $R$. We can simulate a draw from the CTMC using Gillespie's Algorithm\n", "\n", "**Initialize** $x_0 \\sim \\pi_0$ and set $t_0 = 0$, $i=0$.\n", "\n", "While $t_i < T$\n", "- Draw the _waiting time_ $\\Delta_i \\sim \\mathrm{Exp}(R(x_i))$ \n", "- If $t_i + \\Delta_i > T$, return $\\{(t_j, x_j)\\}_{j=0}^i$\n", "- Else, set $t_{i+1} \\leftarrow t_i + \\Delta_i$ and draw the next state\n", " \\begin{align*}\n", " x_{i+1} \\sim \\mathrm{Cat}\\left(\\left[\\frac{R(x_i \\to x_j)}{R(x_i)} \\right]_{j \\neq i} \\right)\n", " \\end{align*}\n", "\n", "Does this look familiar? " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Connection to Poisson processes\n", "\n", "We can cast a CTMC as a _marked_ Poisson process with times $t_i$ and marks $x_i \\in \\cX$. \n", "\n", "The process follows a _conditional_ intensity function that depends on the history $\\cH_t$. In particular, the history contains the current state $x_t$ (since the state path is right continuous),\n", "\\begin{align*}\n", "\\lambda(t, x \\mid \\cH_t) \n", "&= R_t(x_t \\to x) \\cdot \\bbI[x \\neq x_t]\n", "\\end{align*}\n", "\n", "Gillespie's algorithm is using the Poisson superposition and Poisson thinning properties we discussed last time! To sample the waiting time, we are using the fact that $\\lambda(t \\mid \\cH_t) = \\sum_x \\lambda(t, x \\mid \\cH_t)$ is a Poisson process on the time of the next event. Once we sample the time, we use Poisson thinning to sample the next state (i.e. mark).\n", "\n", "Rao and Teh (2010) used this construction to develop a very clever Gibbs sampling algorithm for CTMCs!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### CTMCs in Continuous-Time Discrete DDMs\n", "\n", "The reversal of a CTMC is another CTMC, and Campbell et al. (2022) showed how to parameterize the reverse process of a discrete-state, continuous-time DDM in terms of the backward rates. \n", "\n", "It turns out the backward rate is,\n", "\\begin{align*}\n", "\\tilde{R}_{T-t}(i \\to j) &= R_t(j \\to i) \\frac{q_t(x_t=j)}{q_t(x_t=i)}\n", "\\end{align*}\n", "where the density ratio $\\frac{q_t(x_t=j)}{q_t(x_t=i)}$ can be seen as the analog of the _score function_ for a discrete distribution.\n", "\n", "Sampling the backward process tricky because the reverse rate is **inhomogeneous**, and Gillespies algorithm for inhomogeneous processes requires integrating rate matrices. Instead, Campbell et al. (2022) propose to use a technique called **tau-leaping** to approximately sample the backward process. Then they use **corrector steps** to try to correct for the errors in the approximation. In recent work, we show how to develop more informative correctors for discrete diffusion with masking processes (Zhao et al., 2024).\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n", "\n", "Discrete DDMs are a nice way to wrap up this course! They combine old and new: Poisson processes and CTMCs, as well as modern deep generative models.\n", "\n", "These models have recently been used for langugage modeling, and sampling from them can be much more faster than from an autoregressive model like a Transformer since many words can be generated in parallel. See Inception AI!" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 2 }