Sampling with Stein Variational Gradient Descent (SVGD)

The SVGD Update Rule

Stein Variational Gradient Descent (Liu & Wang, 2016) is a particle-based sampling algorithm. It starts from samples drawn from an initial reference distribution $q^0$ and iteratively transports them along the SVGD velocity field $\phi$. After enough iterations, the empirical distribution of the transformed particles closely approximates the target distribution.

Formally, at iteration $l \in \{0, \dots, L-1\}$, the SVGD update rule is:

$$x_i^{l+1} = x_i^{l} + \epsilon \underbrace{ \mathbb{E}_{x_j^l \sim q^l} \left[ \kappa(x_i^l, x_j^l) \nabla_{x_j^l} \log \bar p(x_j^l) + \nabla_{x_j^l} \kappa(x_i^l, x_j^l) \right] }_{\phi(x_i^l)},$$

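where $x_i^l$ is the $i$-th particle at iteration $l$, $\epsilon$ is the step size, $\bar p$ is the unnormalized target density, and $\kappa$ is a kernel function, most commonly the RBF kernel. In practice, the expectation over $q^l$ is approximated by the average $\frac{1}{M}\sum_{j=0}^{M-1}$ over the $M$ current particles, including the term $j = i$.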
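The update above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the `svgd` package used for the figures: the function name `svgd_step` and its argument conventions are hypothetical, and the expectation over $q^l$ is replaced by the empirical mean over the current particles.

```python
import numpy as np


def svgd_step(x, score, kernel, eps):
    """One SVGD update, with the expectation over q^l replaced by the mean over particles.

    x      : (M, D) array, the current particles x_i^l
    score  : callable, maps an (M, D) array to the (M, D) scores grad log p_bar(x_j)
    kernel : callable, returns (K, grad_K) with K[i, j] = kappa(x_i, x_j), shape (M, M),
             and grad_K[i, j] = gradient of kappa(x_i, x_j) w.r.t. x_j, shape (M, M, D)
    eps    : step size epsilon
    """
    m = x.shape[0]
    k, grad_k = kernel(x)
    # phi(x_i) = 1/M * sum_j [ kappa(x_i, x_j) score(x_j) + grad_{x_j} kappa(x_i, x_j) ]
    phi = (k @ score(x) + grad_k.sum(axis=1)) / m
    return x + eps * phi
```

With a single particle and the RBF kernel (for which $\kappa(x, x) = 1$ and $\nabla_{x_j}\kappa(x_i, x_j) = 0$ at $x_j = x_i$), the step reduces to plain gradient ascent on $\log \bar p$.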

Figure 1: Particles’ evolution starting from $q^0$ and following the SVGD velocity field.

Code for Figure 1.
from svgd.sampler import SVGD
from svgd.distributions import TorchDistribution
from svgd.kernels import RBF
from svgd.kernels.parameters import HeuristicKP
from svgd.lrs import ParameterLR
from svgd.callbacks import Logger

import torch
from torch.distributions import MixtureSameFamily, Categorical, Independent, Normal

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

torch.manual_seed(0)

target_distribution = TorchDistribution(
    MixtureSameFamily(
        Categorical(torch.ones(2)),
        Independent(Normal(torch.tensor([[-1.0, 0.0], [1.0, 0.0]]), 1 / 3), 1),
    )
)
initial_distribution = TorchDistribution(Independent(Normal(torch.zeros(2), 1.0), 1))
kernel = RBF(HeuristicKP("median"))
lr = ParameterLR(torch.tensor(0.5))
logger = Logger(log_x=True)
logger.activated = True
svgd = SVGD(
    target_distribution=target_distribution,
    initial_distribution=initial_distribution,
    kernel=kernel,
    lr=lr,
    callbacks=[logger],
)

n_particles = 100
n_steps = 50
x, _, _ = svgd.sample(n_particles=n_particles, n_steps=n_steps)
x = x.detach()

grid = torch.arange(-2, 2, 0.001)
xg, yg = torch.meshgrid(grid, grid, indexing="ij")
grid = torch.cat((xg.reshape(-1)[:, None], yg.reshape(-1)[:, None]), dim=-1)
zg = target_distribution.log_prob(grid).exp().view(xg.shape)

fig, ax = plt.subplots()
ax.pcolormesh(xg, yg, zg, cmap="Oranges")
ax.set_xlim(-2.0, 2.0)
ax.set_ylim(-2.0, 2.0)
ax.axis("off")
scatter = ax.scatter([], [], color="black", alpha=0.6)


def animate(frame):
    scatter.set_offsets(logger.x[frame])
    return (scatter,)


animation = FuncAnimation(
    fig,
    animate,
    len(logger.x),
    interval=1000 / 60,
    blit=False,
)
animation.save("particle_evolution.gif", fps=120)

A key property of SVGD is that the velocity field $\phi$ is chosen, among functions in the unit ball of a reproducing kernel Hilbert space (RKHS), to maximally decrease the KL divergence between the particle distribution and the target. Intuitively, each update moves the particles in the direction that makes their distribution approach the target distribution as quickly as possible.

The RBF Kernel

The RBF kernel is defined as

$$\kappa(x_i^l, x_j^l) = \exp\left( -\frac{1}{2\sigma^2} \|x_i^l - x_j^l\|^2 \right),$$

where $\sigma$ is the kernel bandwidth.
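As a sketch, assuming NumPy and particles stored as an $(M, D)$ array, the kernel matrix and the gradient of $\kappa(x_i, x_j)$ with respect to its second argument, $\nabla_{x_j} \kappa(x_i, x_j) = \kappa(x_i, x_j)(x_i - x_j)/\sigma^2$, can be computed with broadcasting. The name `rbf_kernel` is illustrative only.

```python
import numpy as np


def rbf_kernel(x, sigma):
    """RBF kernel matrix and its gradient with respect to the second argument.

    x     : (M, D) array of particles
    sigma : kernel bandwidth
    Returns K, shape (M, M), with K[i, j] = exp(-||x_i - x_j||^2 / (2 sigma^2)), and
    grad_K, shape (M, M, D), with grad_K[i, j] = K[i, j] * (x_i - x_j) / sigma^2,
    the gradient of kappa(x_i, x_j) with respect to x_j.
    """
    diff = x[:, None, :] - x[None, :, :]      # diff[i, j] = x_i - x_j
    sq_dist = np.sum(diff ** 2, axis=-1)       # ||x_i - x_j||^2
    k = np.exp(-sq_dist / (2.0 * sigma ** 2))
    grad_k = k[:, :, None] * diff / sigma ** 2
    return k, grad_k
```

The gradient vanishes on the diagonal ($x_j = x_i$), so a particle exerts no repulsion on itself.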

Using the RBF kernel, the SVGD update rule becomes:

$$x_i^{l+1} = x_i^{l} + \epsilon \, \mathbb{E}_{x_j^l \sim q^l} \Bigg[ \underbrace{ \kappa(x_i^l, x_j^l) \nabla_{x_j^l} \log \bar p(x_j^l) }_{\text{drift term}} - \underbrace{ \frac{1}{\sigma^2} \kappa(x_i^l, x_j^l) (x_j^l - x_i^l) }_{\text{repulsion term}} \Bigg].$$

In the drift term, the kernel value $\kappa(x_i^l, x_j^l)$ determines how strongly particle $x_i^l$ is influenced by the score $\nabla_{x_j^l} \log \bar p(x_j^l)$ evaluated at particle $x_j^l$. This term moves particles toward regions of high probability. In the repulsion term, the same kernel value controls how strongly $x_i^l$ is pushed away from $x_j^l$, i.e., along the direction $x_i^l - x_j^l$ (note the minus sign in front of $x_j^l - x_i^l$). This term enforces diversity among particles, preventing them from collapsing onto a single mode.
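The two terms can be computed separately to check their directions. The sketch below is hypothetical: it assumes NumPy, a fixed bandwidth `sigma`, and the empirical mean over particles in place of the expectation, and it uses a standard normal target (score $-x$) with two particles placed symmetrically around the mode.

```python
import numpy as np


def svgd_terms_rbf(x, score, sigma):
    """Drift and repulsion terms of the RBF-SVGD velocity, averaged over the M particles."""
    m = x.shape[0]
    diff = x[:, None, :] - x[None, :, :]                      # diff[i, j] = x_i - x_j
    k = np.exp(-np.sum(diff ** 2, axis=-1) / (2.0 * sigma ** 2))
    drift = k @ score(x) / m                                   # 1/M sum_j kappa_ij score(x_j)
    # -(1/sigma^2) kappa_ij (x_j - x_i) = kappa_ij (x_i - x_j) / sigma^2, pointing away from x_j
    repulsion = np.einsum("ij,ijd->id", k, diff) / (sigma ** 2 * m)
    return drift, repulsion


def score(x):
    return -x  # score of a standard normal target


x = np.array([[-0.1], [0.1]])
drift, repulsion = svgd_terms_rbf(x, score, sigma=1.0)
print(drift.ravel())      # pulls the particles toward the mode at 0
print(repulsion.ravel())  # pushes them apart
```

Here the repulsion dominates, so the two particles spread out until the terms balance. For this target with $\sigma = 1$, the velocity of the particle at $-r$ is $\frac{r}{2}(1 - 3e^{-2r^2})$, which vanishes at $r = \sqrt{\ln 3 / 2} \approx 0.74$.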

The kernel takes its maximum value of 1 when $x_i^l = x_j^l$ and approaches 0 as $\|x_i^l - x_j^l\|^2 \to \infty$. Intuitively, $\kappa(x_i^l, x_j^l)$ measures the similarity between particles based on their Euclidean distance, with the effective neighborhood size controlled by $\sigma$.
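For two particles at a fixed distance $d$, the magnitude of the repulsion is $\frac{d}{\sigma^2} \exp\left(-\frac{d^2}{2\sigma^2}\right)$. Setting its derivative with respect to $\sigma$ to zero gives a single maximum at $\sigma = d/\sqrt{2}$, and the magnitude tends to 0 both as $\sigma \to 0$ and as $\sigma \to \infty$. A quick numerical check, assuming NumPy and an arbitrary $d = 1$:

```python
import numpy as np

d = 1.0                                   # fixed distance ||x_i - x_j|| (arbitrary choice)
sigmas = np.logspace(-2, 2, 2001)         # bandwidths from 0.01 to 100
# magnitude of grad_{x_j} kappa(x_i, x_j) = (d / sigma^2) * exp(-d^2 / (2 sigma^2))
repulsion = d / sigmas ** 2 * np.exp(-d ** 2 / (2.0 * sigmas ** 2))

best = sigmas[np.argmax(repulsion)]
print(best)                         # close to d / sqrt(2) ≈ 0.707
print(repulsion[0], repulsion[-1])  # both ends are close to 0
```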

Figure 2: The neighborhood of $x_i^l = 0$ as a function of $x_j^l$ in 1D, in terms of similarity $\kappa(0, x_j^l)$ and repulsive force $\nabla_{x_j^l} \kappa(0, x_j^l)$, for different values of $\sigma$.

Code for Figure 2.
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

torch.manual_seed(0)

bound = 10.0
x = torch.arange(-bound, bound, 0.1).unsqueeze(-1)

sigma = torch.tensor([0.5, 2.5, 5.0])
gamma = sigma.pow(2).mul(2).pow(-1)

norm = x.pow(2).sum(-1)
k = gamma.unsqueeze(-1).mul(norm).mul(-1).exp()
k_xj = gamma.mul(2).unsqueeze(-1).unsqueeze(-1).mul(k.unsqueeze(-1)).mul(x.mul(-1))

data = pd.DataFrame(
    [
        {
            "sigma": s.item(),
            "x": x[idx, 0].item(),
            "k": k[s_idx, idx].item(),
            "k_xj": k_xj[s_idx, idx, 0].item(),
        }
        for s_idx, s in enumerate(sigma)
        for idx in range(x.shape[0])
    ]
)

r, c = 1, 2
fig, ax = plt.subplots(r, c, figsize=(6.4 * c, 4.8 * r))

if c == 1:
    ax = np.array([ax])

if r == 1:
    ax = np.array([ax])

sns.lineplot(
    data,
    x="x",
    y="k",
    hue="sigma",
    ax=ax[0, 0],
    palette=sns.color_palette("tab10"),
)
ax[0, 0].grid()
ax[0, 0].set_ylabel("$\\kappa(0, x_j^l)$")
ax[0, 0].set_xlabel("$x_j^l$")

sns.lineplot(
    data,
    x="x",
    y="k_xj",
    hue="sigma",
    ax=ax[0, 1],
    palette=sns.color_palette("tab10"),
)
ax[0, 1].grid()
ax[0, 1].set_ylabel("$\\nabla_{x_j^l} \\kappa(0, x_j^l)$")
ax[0, 1].set_xlabel("$x_j^l$")

handles, labels = ax[0, 0].get_legend_handles_labels()
labels = [f"$\\sigma = {label}$" for label in labels]

ax[0, 0].legend_.remove()
ax[0, 1].legend_.remove()

legend = fig.legend(
    handles,
    labels,
    loc="lower center",
    bbox_to_anchor=(0.42, 0.7),
    framealpha=1.0,
)

fig.savefig("bandwidth.svg", bbox_inches="tight")

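Choosing an appropriate value for $\sigma$ is nontrivial: both extremely small and extremely large bandwidths cause the repulsion term to vanish. When $\sigma$ is very small, $\kappa(x_i^l, x_j^l) \approx 0$ for any two distinct particles; when $\sigma$ is very large, $\kappa \approx 1$ but the factor $1/\sigma^2$ drives the repulsion to zero. A common heuristic is the median trick, which sets $\sigma = \mathrm{median} \{ \|x_i^l - x_j^l\| \}_{i,j=0}^{M-1} / \sqrt{2 \log M}$, where $M$ is the number of particles. With this choice, a pair of particles at the median distance has kernel value $\exp(-\log M) = 1/M$, which roughly ensures that $\sum_{j \neq i} \kappa(x_i^l, x_j^l) \approx M \cdot \frac{1}{M} = 1$.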

However, it is not obvious that this property is always desirable, and in practice the median trick can be somewhat brittle depending on the geometry of the target distribution.
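A minimal NumPy sketch of the median trick follows. It takes the median over all $M^2$ pairs, as in the formula for $\sigma$ above (including the zero distances for $i = j$); the name `median_bandwidth` is hypothetical, and the sketch assumes $M \geq 2$ and that not all particles coincide. It also checks that a pair at the median distance gets kernel value $1/M$.

```python
import numpy as np


def median_bandwidth(x):
    """Median trick: sigma = median ||x_i - x_j|| / sqrt(2 log M), median over all M^2 pairs.

    x : (M, D) array of particles; assumes M >= 2 and that not all particles coincide.
    """
    m = x.shape[0]
    dist = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)   # (M, M) pairwise distances
    return np.median(dist) / np.sqrt(2.0 * np.log(m))


rng = np.random.default_rng(0)
x = rng.normal(size=(50, 2))
sigma = median_bandwidth(x)

# A pair of particles at the median distance has kernel value exp(-log M) = 1/M.
med = np.median(np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1))
print(np.exp(-med ** 2 / (2.0 * sigma ** 2)), 1 / 50)
```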

Derivation of the SVGD Update Rule

Suppose we want to sample from a target distribution $p$. Given $x \sim q$, the idea behind SVGD is to find a velocity field $\phi$ that maximally decreases the KL divergence between the distribution $q_\epsilon$ of $f_\epsilon(x) = x + \epsilon \phi(x)$ and $p$, which we denote by $D_{KL}(q_\epsilon \,\|\, p)$. Note that $q_0 = q$, since $f_0(x) = x$.

For small enough $\epsilon$, we can approximate $D_{KL}(q_\epsilon \,\|\, p)$ by its first-order Taylor expansion in $\epsilon$ around 0:

$$D_{KL}(q_\epsilon \,\|\, p) \approx D_{KL}(q_0 \,\|\, p) + \epsilon \nabla_\epsilon D_{KL}(q_\epsilon \,\|\, p) \big|_{\epsilon=0}.$$

We can, therefore, estimate the decrease in the KL divergence as:

$$D_{KL}(q_0 \,\|\, p) - D_{KL}(q_\epsilon \,\|\, p) \approx -\epsilon \nabla_\epsilon D_{KL}(q_\epsilon \,\|\, p) \big|_{\epsilon=0}.$$

Hence, since the step size $\epsilon > 0$ does not depend on $\phi$, finding the $\phi$ that maximizes this decrease is approximately equivalent to finding the $\phi$ that maximizes $-\nabla_\epsilon D_{KL}(q_\epsilon \,\|\, p) \big|_{\epsilon=0}$. The approximation becomes exact as $\epsilon \to 0$, in which case we are looking for the $\phi$ that maximizes the instantaneous rate of decrease of $D_{KL}(q_\epsilon \,\|\, p)$.

Formally, we want to solve the following optimization problem:

$$\phi^* = \underset{\phi \in \mathcal{F}}{\mathrm{argmax}} \; -\nabla_{\epsilon} D_{KL}(q_\epsilon \,\|\, p) \big|_{\epsilon=0},$$

where $\mathcal{F}$ is a suitable family of functions.

Let $y \sim p$. One can show that $D_{KL}(q_\epsilon \,\|\, p) = D_{KL}(q \,\|\, \tilde p)$, where $\tilde p$ is the distribution of $f_\epsilon^{-1}(y)$, assuming that $f_\epsilon$ is invertible (the KL divergence is invariant when the same invertible transformation is applied to both distributions). Since $q$ does not depend on $\epsilon$, this yields a closed-form expression for the gradient of $D_{KL}$ with respect to $\epsilon$:

$$-\nabla_{\epsilon} D_{KL}(q_\epsilon \,\|\, p) \big|_{\epsilon=0} = \mathbb{E}_{x\sim q} \left[ \nabla_{x} \log p(x)^{\top} \phi(x) + \mathrm{Tr}(\nabla_x \phi(x)) \right].$$
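Both identities can be checked numerically in a hypothetical 1D setting where everything is Gaussian: $q = \mathcal{N}(0, 1)$, $p = \mathcal{N}(\mu, s^2)$ and an affine velocity field $\phi(x) = a x + b$, so that $q_\epsilon = \mathcal{N}(\epsilon b, (1 + \epsilon a)^2)$ and the KL divergence has a closed form. The values of $\mu$, $s^2$, $a$, $b$ are arbitrary choices for illustration, and the sketch assumes NumPy. In this case the right-hand side evaluates to $a - (a - \mu b)/s^2$, which we compare with a finite-difference derivative and a Monte Carlo estimate.

```python
import numpy as np

# Hypothetical setting: q = N(0, 1), p = N(mu, s2), phi(x) = a x + b (arbitrary values).
mu, s2 = 1.5, 0.25
a, b = -0.3, 0.7


def kl_gauss(m1, v1, m2, v2):
    """KL(N(m1, v1) || N(m2, v2)) for 1D Gaussians given by mean and variance."""
    return 0.5 * (np.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)


def kl_pushforward(eps):
    # f_eps(x) = (1 + eps a) x + eps b maps q = N(0, 1) to q_eps = N(eps b, (1 + eps a)^2)
    return kl_gauss(eps * b, (1.0 + eps * a) ** 2, mu, s2)


def kl_pullback(eps):
    # f_eps^{-1}(y) = (y - eps b) / (1 + eps a) maps p = N(mu, s2) to p_tilde
    c = 1.0 + eps * a
    return kl_gauss(0.0, 1.0, (mu - eps * b) / c, s2 / c ** 2)


# KL(q_eps || p) = KL(q || p_tilde)
print(kl_pushforward(0.2), kl_pullback(0.2))

# -d/deps KL(q_eps || p) at eps = 0, by central finite differences
h = 1e-5
finite_diff = -(kl_pushforward(h) - kl_pushforward(-h)) / (2.0 * h)

# Right-hand side E_q[(d/dx log p(x)) phi(x) + phi'(x)] with d/dx log p(x) = -(x - mu) / s2:
# using E[x] = 0 and E[x^2] = 1 it equals a - (a - mu b) / s2
closed_form = a - (a - mu * b) / s2

# The same expectation estimated by Monte Carlo with samples from q
rng = np.random.default_rng(0)
x = rng.standard_normal(1_000_000)
monte_carlo = np.mean(-(x - mu) / s2 * (a * x + b) + a)

print(finite_diff, closed_form, monte_carlo)  # all approximately 5.1
```

The positive value means that this particular $\phi$ decreases the KL divergence to first order: it shifts $q$ toward $\mu$ and shrinks its variance toward $s^2$.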

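Suppose $\mathcal{F}$ is a subset of $\mathcal{H}^D = \mathcal{H} \times \dots \times \mathcal{H}$, where $\mathcal{H}$ is the Reproducing Kernel Hilbert Space (RKHS) associated with a positive-definite kernel $\kappa(x, y)$ and $D$ is the dimension of $x$. By the reproducing property, each component of $\phi$ satisfies $\phi_d(x) = \langle \phi_d(\cdot), \kappa(x, \cdot) \rangle_{\mathcal{H}}$, and, for a smooth kernel, its derivatives satisfy $\partial_{x_d} \phi_d(x) = \langle \phi_d(\cdot), \partial_{x_d} \kappa(x, \cdot) \rangle_{\mathcal{H}}$. Substituting both identities and exchanging the expectation with the inner product, we can rewrite the gradient of the KL divergence as an inner product in $\mathcal{H}^D$: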

$$-\nabla_{\epsilon} D_{KL}(q_\epsilon \,\|\, p) \big|_{\epsilon=0} = \left\langle \phi(\cdot), \mathbb{E}_{x\sim q} \left[ \kappa(x, \cdot) \nabla_{x} \log p(x) + \nabla_{x} \kappa(x, \cdot) \right] \right\rangle_{\mathcal{H}^D}.$$

If we constrain the norm of $\phi$ in $\mathcal{H}^D$, so that $\mathcal{F} = \{ \phi \in \mathcal{H}^D : \|\phi\|_{\mathcal{H}^D} \leq 1 \}$, then by the Cauchy-Schwarz inequality the inner product is maximized when $\phi$ is proportional to its second argument, so the maximizer is:

$$\phi_{p,q}(\cdot) \propto \mathbb{E}_{x\sim q} \left[ \kappa(x, \cdot) \nabla_{x} \log p(x) + \nabla_{x} \kappa(x, \cdot) \right].$$

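For a distribution known only up to a normalization constant, $p(x) = \bar p(x) / Z$, we have $\nabla_x \log p(x) = \nabla_x \log \bar p(x) - \nabla_x \log Z = \nabla_x \log \bar p(x)$, so the maximizer does not depend on the normalization constant $Z$. Replacing the expectation over $q$ with the average over the current particles and absorbing the proportionality constant into the step size $\epsilon$ recovers the SVGD update rule.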
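Putting the pieces together, the sketch below runs SVGD on a hypothetical 1D target given only through its unnormalized density $\bar p(x) = \exp(-(x - 1)^2 / 2)$, i.e. $\mathcal{N}(1, 1)$ without its normalization constant. It assumes NumPy, plain gradient steps with $\epsilon = 0.1$, the RBF kernel, and the median trick recomputed at every iteration; the function name `svgd` and all settings are illustrative choices, not those of the package used for the figures.

```python
import numpy as np


def score(x):
    # Hypothetical target p_bar(x) = exp(-(x - 1)^2 / 2): grad log p_bar(x) = -(x - 1),
    # the same as the score of the normalized N(1, 1)
    return -(x - 1.0)


def svgd(x, n_steps=2000, eps=0.1):
    """Plain SVGD with the RBF kernel and the median trick (recomputed every step)."""
    m = x.shape[0]
    for _ in range(n_steps):
        diff = x[:, None, :] - x[None, :, :]                 # diff[i, j] = x_i - x_j
        dist = np.sqrt(np.sum(diff ** 2, axis=-1))           # (M, M) pairwise distances
        sigma = np.median(dist) / np.sqrt(2.0 * np.log(m))   # median trick
        k = np.exp(-dist ** 2 / (2.0 * sigma ** 2))
        drift = k @ score(x)                                  # sum_j kappa_ij score(x_j)
        repulsion = np.einsum("ij,ijd->id", k, diff) / sigma ** 2
        x = x + eps * (drift + repulsion) / m                 # empirical mean over j
    return x


rng = np.random.default_rng(0)
x0 = rng.normal(loc=-2.0, scale=0.5, size=(100, 1))   # particles from q^0 = N(-2, 0.5^2)
x = svgd(x0)
print(x.mean(), x.std())
```

Starting from particles concentrated around $-2$, the final particles should have a mean and a standard deviation both close to 1, even though the normalization constant of the target was never used.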

References
1. Liu, Q., & Wang, D. (2016). Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm. Advances in Neural Information Processing Systems (NeurIPS).