The SVGD-Induced Density

As a particle-based variational inference method, SVGD (Liu & Wang, 2016) evolves a set of $M$ particles $\{ x_i^0 \}_{i=0}^{M-1}$ sampled from an initial distribution $q^0$ to collectively match an arbitrarily complex target distribution $p$, as long as it is possible to compute the score of $p$. However, SVGD does not provide an explicit expression for the density it induces: while we can generate samples, we do not know the value of the corresponding probability density at those sampled points. That is, if we evolve $\{ x_i^0 \}_{i=0}^{M-1}$ for $L$ steps according to the SVGD update rule, we obtain $\{ x_i^L \}_{i=0}^{M-1} \sim q^L$ with $q^L \approx p$, but we do not know the value of $q^L(x_i^L)$ for a given $x_i^L$. This is problematic for downstream tasks such as likelihood-based evaluation, uncertainty quantification, or entropy estimation, as these require access to the density itself rather than just samples.
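For reference, the SVGD update rule mentioned above can be sketched as follows. This is a minimal PyTorch sketch of the update from Liu & Wang (2016), assuming an RBF kernel with a fixed bandwidth `h` and a standard normal target whose score is `-x`. The names `svgd_velocity`, `score`, and `h` are illustrative and are not part of the `svgd` package used for Figure 1.

```python
import torch


def svgd_velocity(x, score, h=1.0):
    """SVGD velocity phi(x_i) for all particles x of shape (M, D)."""
    # RBF kernel matrix K[i, j] = exp(-||x_i - x_j||^2 / (2 h^2)).
    sq_dists = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)
    K = torch.exp(-sq_dists / (2 * h**2))
    # Driving term: sum_j K[i, j] * score(x_j) pulls particles toward high density.
    drive = K @ score(x)
    # Repulsive term: sum_j grad_{x_j} K[i, j] = sum_j K[i, j] (x_i - x_j) / h^2.
    repulse = (K.sum(dim=1, keepdim=True) * x - K @ x) / h**2
    return (drive + repulse) / x.shape[0]


torch.manual_seed(0)
score = lambda x: -x  # score of a standard normal target (assumption)
x0 = torch.randn(50, 2) + 3.0  # particles initialized away from the target
x, eps = x0.clone(), 0.1
for _ in range(200):
    x = x + eps * svgd_velocity(x, score)
```

The loop produces samples `x`, but nothing in it tracks the density of those samples, which is the gap the rest of this page addresses.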

Derivation of the SVGD-Induced Density

Figure 1: Density evolution starting from $q^0$ and following the SVGD velocity field.

Code for Figure 1.
from svgd.sampler import SVGD
from svgd.distributions import TorchDistribution
from svgd.kernels import RBF
from svgd.kernels.parameters import HeuristicKP
from svgd.lrs import ParameterLR
from svgd.callbacks import Logger

import torch
from torch.distributions import MixtureSameFamily, Categorical, Independent, Normal

import matplotlib
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

torch.manual_seed(0)

target_distribution = TorchDistribution(
    MixtureSameFamily(
        Categorical(torch.ones(2)),
        Independent(Normal(torch.tensor([[-1.0, 0.0], [1.0, 0.0]]), 1 / 3), 1),
    )
)
initial_distribution = TorchDistribution(Independent(Normal(torch.zeros(2), 1.0), 1))
kernel = RBF(HeuristicKP("median"))
lr = ParameterLR(torch.tensor(0.5))
logger = Logger(log_x=True)
logger.activated = True
svgd = SVGD(
    target_distribution=target_distribution,
    initial_distribution=initial_distribution,
    kernel=kernel,
    lr=lr,
    callbacks=[logger],
)

n_particles = 100
n_steps = 50
x, _, _ = svgd.sample(n_particles=n_particles, n_steps=n_steps)
x = x.detach()

bound = 2.5

grid = torch.arange(-bound, bound, 0.01)
xg, yg = torch.meshgrid(grid, grid, indexing="ij")
grid = torch.cat((xg.reshape(-1)[:, None], yg.reshape(-1)[:, None]), dim=-1)
zg = target_distribution.log_prob(grid).exp().view(xg.shape)

fig, ax = plt.subplots()
ax.pcolormesh(xg, yg, zg, cmap="Oranges")
ax.set_xlim(-bound, bound)
ax.set_ylim(-bound, bound)
ax.axis("off")
scatter = ax.scatter([], [], color="black", alpha=0.6)


def animate(frame):
    for artist in ax.collections:
        if isinstance(artist, matplotlib.contour.QuadContourSet):
            artist.remove()

    scatter.set_offsets(logger.x[frame])
    sns.kdeplot(
        x=logger.x[frame][:, 0],
        y=logger.x[frame][:, 1],
        ax=ax,
        alpha=0.4,
        color="black",
    )

    return (scatter,)


animation = FuncAnimation(
    fig,
    animate,
    len(logger.x),
    interval=1000 / 60,
    blit=False,
)
animation.save("density_evolution.gif", fps=120)

Suppose that, after $l$ SVGD steps, the particle $x^l$ is distributed according to $q^l$. Then, by the change-of-variables formula for densities (CVF), the density of $x^{l+1} = x^l + \epsilon \phi(x^l)$, where $\phi$ is the SVGD velocity field and $\epsilon$ is the step-size, is:

\begin{align*}
q^{l+1}(x^{l+1}) &= q^{l}(x^{l}) \left|\det \nabla_{x^l} \left( x^l + \epsilon \phi(x^l) \right)\right|^{-1} \\
&= q^{l}(x^{l}) \left|\det \left( I + \epsilon \nabla_{x^l} \phi(x^l) \right)\right|^{-1}.
\end{align*}

However, the CVF applies only if the SVGD transformation is invertible. To ensure this, MET-SVGD adapts to SVGD a sufficient condition from Behrmann et al. (2019) for the invertibility of transformations of the form $f(x) = x + g(x)$. Namely, $f(x^l) = x^l + \epsilon \phi(x^l)$ is invertible if $\epsilon \sup_{x^l} \|\nabla_{x^l} \phi(x^l)\|_2 < 1$, where $\|\nabla_{x^l} \phi(x^l)\|_2$ denotes the spectral norm of $\nabla_{x^l} \phi(x^l)$.
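To illustrate why a contraction condition of this kind gives invertibility, here is a small sketch of the fixed-point inversion used by Behrmann et al. (2019) for maps of the form $f(x) = x + g(x)$. The map `g` below is a hypothetical stand-in for $\epsilon \phi$, not the SVGD velocity field; it is built so that its Lipschitz constant is at most 0.5.

```python
import torch

torch.manual_seed(0)
W = torch.randn(2, 2, dtype=torch.float64)
W = W / torch.linalg.matrix_norm(W, ord=2)  # rescale so that ||W||_2 = 1


def g(x):
    # Stand-in for eps * phi: tanh is 1-Lipschitz, so Lip(g) <= 0.5 * ||W||_2 = 0.5.
    return 0.5 * torch.tanh(x @ W.T)


def f(x):
    return x + g(x)


def f_inverse(y, n_iters=100):
    # Banach fixed-point iteration x <- y - g(x); it converges because g is a contraction.
    x = y.clone()
    for _ in range(n_iters):
        x = y - g(x)
    return x


x = torch.randn(5, 2, dtype=torch.float64)
x_rec = f_inverse(f(x))  # recovers x up to numerical precision
```

Each iteration shrinks the error by at least the Lipschitz constant of `g`, so the inverse exists and is unique whenever that constant is below 1.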

In practice, it is easier to work with the log of the density, also called the log-likelihood:

$$
\log q^{l+1}(x^{l+1}) = \log q^{l}(x^{l}) - \log \left|\det \left( I + \epsilon \nabla_{x^l} \phi(x^l) \right)\right|.
$$

Unfortunately, this expression involves the log-determinant of a Jacobian, which is expensive to compute. To avoid this, under the condition $\epsilon |\lambda_\text{max}(\nabla_{x^l} \phi(x^l))| < 1$, MET-SVGD approximates $\log \left|\det \left( I + \epsilon \nabla_{x^l} \phi(x^l) \right)\right|$ to first order in $\epsilon$ by $\epsilon \mathrm{Tr}\left( \nabla_{x^l} \phi(x^l) \right)$, giving:

$$
\log q^{l+1}(x^{l+1}) \approx \log q^{l}(x^{l}) - \epsilon \mathrm{Tr}\left( \nabla_{x^l} \phi(x^l) \right).
$$
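A quick numerical check makes the quality of this approximation concrete. The sketch below uses a random matrix `A` as a hypothetical stand-in for the Jacobian $\nabla_{x^l} \phi(x^l)$, rescaled to have spectral norm 1, and compares the exact log-determinant with the trace term for a few step-sizes.

```python
import torch

torch.manual_seed(0)
d = 3
A = torch.randn(d, d, dtype=torch.float64)
A = A / torch.linalg.matrix_norm(A, ord=2)  # rescale so that ||A||_2 = 1

errors = {}
for eps in (0.5, 0.1, 0.01):
    # Exact log|det(I + eps A)| versus its first-order approximation eps Tr(A).
    exact = torch.linalg.slogdet(torch.eye(d, dtype=torch.float64) + eps * A).logabsdet
    approx = eps * torch.trace(A)
    errors[eps] = (exact - approx).abs().item()
    print(f"eps={eps}: exact={exact.item():+.6f}, trace approx={approx.item():+.6f}")
```

The gap between the two quantities is second order in $\epsilon$: for this normalization it is bounded by $d \epsilon^2 / (2(1 - \epsilon))$, so smaller step-sizes make the trace term a tighter substitute.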

For samples $\{ x_i^l \}_{i=0}^{M-1} \sim q^l$, the trace term evaluated at $x_i^l$ is:

\begin{align*}
\mathrm{Tr}\left( \nabla_{x_i^l} \phi(x_i^l) \right) &= \frac{1}{M} \sum_{j=0}^{M-1} \left[ \nabla_{x_i^l} \kappa(x_i^l, x_j^l)^\top \nabla_{x_j^l} \log \bar p(x_j^l) + \mathrm{Tr}\left( \nabla_{x_i^l} \nabla_{x_j^l} \kappa(x_i^l, x_j^l) \right) \right] \\
&+ \frac{1}{M} \mathrm{Tr}\left( \nabla_{x_i^l}^2 \log \bar p(x_i^l) \right).
\end{align*}

When $\kappa$ is the RBF kernel, the first term in the above expression can be computed efficiently using only vector dot products. However, the second term is computationally expensive because it involves computing the trace of a Hessian, which has $\mathcal{O}(D^2)$ complexity.
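The dot-product form of the first term can be checked against automatic differentiation. The sketch below assumes the RBF parametrization $\kappa(x, x') = \exp(-\|x - x'\|^2 / (2h^2))$ with a fixed bandwidth `h` and a standard normal target; both are illustrative choices. The velocity is evaluated at a query point `x` that is not one of the particles, so only the sum over $j$ appears and the Hessian term is absent.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
M, D, h = 10, 3, 1.5
particles = torch.randn(M, D, dtype=torch.float64)
scores = -particles  # score of a standard normal target at each particle (assumption)


def phi(x):
    # SVGD velocity at a query point x, with the particles and their scores held fixed.
    diff = x - particles  # (M, D), rows x - x_j
    k = torch.exp(-(diff**2).sum(-1) / (2 * h**2))  # kappa(x, x_j)
    grad_k = diff / h**2 * k[:, None]  # grad_{x_j} kappa(x, x_j)
    return (k[:, None] * scores + grad_k).mean(0)


def trace_first_term(x):
    # Closed form of the sum over j, using only dot products.
    diff = x - particles
    sq = (diff**2).sum(-1)
    k = torch.exp(-sq / (2 * h**2))
    drive = -(diff * scores).sum(-1) / h**2  # grad_x kappa^T score_j, divided by kappa
    repulse = D / h**2 - sq / h**4  # Tr(grad_x grad_{x_j} kappa), divided by kappa
    return (k * (drive + repulse)).mean()


x = torch.randn(D, dtype=torch.float64)
autograd_trace = torch.trace(jacobian(phi, x))  # builds the full D x D Jacobian
closed_form = trace_first_term(x)  # no Jacobian needed
```

Both values agree, but the closed form costs only a handful of dot products per particle pair, whereas the autograd route materializes the whole Jacobian.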

One way to bypass this term altogether is to estimate the expectation in $\phi(x_i^l)$ using $\{ x_j^l \}_{j=0}^{M-1} \setminus \{ x_i^l \}$, that is, all particles except $x_i^l$ itself. However, this turns out to be suboptimal in the finite-particle case. Instead, MET-SVGD efficiently estimates it as:

\begin{align*}
\frac1M\mathrm{Tr}\left( \nabla_{x_i^l}^2 \log \bar p(x_i^l) \right) &\stackrel{(i)}{=} \frac1M\mathbb{E}_{v \sim p_v}\left[ v^\top \nabla_{x_i^l}^2 \log \bar p(x_i^l) v \right] \\
&\stackrel{(ii)}{=} \frac1M\mathbb{E}_{v \sim p_v}\left[ \nabla_{x_i^l} \left( v^\top \nabla_{x_i^l} \log \bar p(x_i^l) \right) v \right] \\
&\approx \frac1{MV} \sum_{k=0}^{V-1} \nabla_{x_i^l} \left( v_k^\top \nabla_{x_i^l} \log \bar p(x_i^l) \right) v_k,
\end{align*}

where the probe vectors $v \sim p_v$ satisfy $\mathbb{E}_{v \sim p_v}[v]=0$ and $\mathbb{E}_{v \sim p_v}[vv^\top]=I$, and $V$ is the number of samples $v_k$. Step (i) is Hutchinson's trace estimator (Hutchinson, 1989), and step (ii) rewrites the quadratic form as a gradient of the projected score $v^\top \nabla_{x_i^l} \log \bar p(x_i^l)$, as in sliced score matching (Song et al., 2020), so that the Hessian is never formed explicitly.

Since the estimator is weighted by $\frac1M$, its variance is greatly reduced and, in practice, a single probe vector $v$ is sufficient.
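The estimator above can be sketched with two calls to `torch.autograd.grad`. The toy target `log_p_bar`, the helper name `hutchinson_hessian_trace`, and the choice of Rademacher probes are assumptions made for illustration; any $p_v$ with zero mean and identity covariance would do.

```python
import torch


def log_p_bar(x):
    # Unnormalized log-density of a Gaussian with diagonal precision (toy target, assumption).
    precision = torch.tensor([1.0, 2.0, 3.0], dtype=x.dtype)
    return -0.5 * (precision * x**2).sum()


def hutchinson_hessian_trace(log_prob, x, n_probes=1):
    """Estimate Tr(Hessian of log_prob at x) from Hessian-vector products."""
    x = x.detach().requires_grad_(True)
    # Score with a graph attached, so that it can be differentiated again.
    score = torch.autograd.grad(log_prob(x), x, create_graph=True)[0]
    estimate = torch.zeros((), dtype=x.dtype)
    for _ in range(n_probes):
        # Rademacher probe: E[v] = 0 and E[v v^T] = I.
        v = torch.randint(0, 2, x.shape).to(x.dtype) * 2 - 1
        # grad_x (v^T score) is the Hessian-vector product H v, as in step (ii).
        hvp = torch.autograd.grad(score @ v, x, retain_graph=True)[0]
        estimate = estimate + hvp @ v
    return estimate / n_probes


torch.manual_seed(0)
x = torch.randn(3)
trace_estimate = hutchinson_hessian_trace(log_p_bar, x, n_probes=1)
```

For this toy target the Hessian is diagonal, so a single Rademacher probe recovers the trace exactly; for a general target each probe is only an unbiased estimate. To obtain the term in the trace formula, the result still has to be multiplied by $\frac1M$.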

Unifying the Step-Size Conditions

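The correctness of the previous derivation depends on two conditions:

  1. $\epsilon \sup_{x^l} \|\nabla_{x^l} \phi(x^l)\|_2 < 1$, which ensures that the SVGD transformation is invertible.
  2. $\epsilon |\lambda_\text{max}(\nabla_{x^l} \phi(x^l))| < 1$, under which the log-determinant can be approximated by the trace.

While these are separate conditions, they can be unified by considering the order relation between the spectral norm of a real-valued square matrix $A$ and the magnitude of $\lambda_\text{max}(A)$.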

According to Wolkowicz & Styan (1980), for $A \in \mathbb{R}^{d \times d}$:

$$
|\lambda_i(A)| \leq \sigma_i(A) \leq \sqrt{ \mathrm{Tr}(A^\top A) } \quad \forall i \in [1 \dots d],
$$

where $\lambda_i(A)$ and $\sigma_i(A)$ are the $i$-th eigenvalue and singular value of $A$, respectively.
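The case of the largest eigenvalue and the largest singular value, which is the one the step-size argument relies on, is easy to check numerically. The sketch below uses an arbitrary random matrix as an example.

```python
import torch

torch.manual_seed(0)
A = torch.randn(5, 5, dtype=torch.float64)

spectral_radius = torch.linalg.eigvals(A).abs().max()  # max_i |lambda_i(A)|
spectral_norm = torch.linalg.matrix_norm(A, ord=2)  # sigma_max(A)
frobenius_norm = torch.sqrt(torch.trace(A.T @ A))  # sqrt(Tr(A^T A))

print(spectral_radius.item(), spectral_norm.item(), frobenius_norm.item())
```

The three printed values come out in non-decreasing order, and only the last one can be computed without an eigenvalue or singular value decomposition.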

Given that $\|\nabla_{x^l} \phi(x^l)\|_2 = \sigma_\text{max}(\nabla_{x^l} \phi(x^l))$, we have:

$$
\epsilon |\lambda_\text{max}(\nabla_{x^l} \phi(x^l))| \leq \epsilon \sup_{x^l} \|\nabla_{x^l} \phi(x^l)\|_2 \leq \epsilon \sup_{x^l} \sqrt{ \mathrm{Tr}\left( \nabla_{x^l} \phi(x^l)^\top \nabla_{x^l} \phi(x^l) \right) } .
$$

Therefore, in order to satisfy both conditions, it is sufficient to choose the step-size such that:

$$
\epsilon < \left( \sup_{x^l} \sqrt{ \mathrm{Tr}\Big( \nabla_{x^l} \phi(x^l)^\top \nabla_{x^l} \phi(x^l) \Big) } \right)^{-1} .
$$

Note that $\mathrm{Tr}\Big( \nabla_{x^l} \phi(x^l)^\top \nabla_{x^l} \phi(x^l) \Big)$ can be efficiently computed using only vector dot products and first-order derivatives, similarly to $\mathrm{Tr}\left( \nabla_{x^l} \phi(x^l) \right)$. In practice, MET-SVGD approximates $\sup_{x^l}$ by taking the maximum over the particles $\{ x_i^l \}_{i=0}^{M-1}$ at iteration $l$.
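The resulting step-size rule can be sketched as follows. This is a simplified illustration rather than the MET-SVGD implementation: it assumes an RBF kernel with fixed bandwidth `h` and a standard normal target, holds the particle set fixed when differentiating `phi` (so the $\frac1M$ Hessian term is left out), and builds full Jacobians with autograd instead of using the efficient dot-product form. The safety factor 0.9 is also an arbitrary choice.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
M, D, h = 20, 2, 1.0
particles = torch.randn(M, D, dtype=torch.float64)
scores = -particles  # score of a standard normal target (assumption)


def phi(x):
    # SVGD velocity at a query point, with the particle set held fixed.
    diff = x - particles
    k = torch.exp(-(diff**2).sum(-1) / (2 * h**2))
    return (k[:, None] * scores + diff / h**2 * k[:, None]).mean(0)


# One Jacobian per particle; the sup over x is replaced by a max over particles.
jacobians = torch.stack([jacobian(phi, x) for x in particles])  # (M, D, D)
frobenius = torch.linalg.matrix_norm(jacobians)  # default ord="fro": sqrt(Tr(J^T J))
eps = 0.9 / frobenius.max()  # factor 0.9 keeps the inequality strict

# Both original conditions then hold at every particle.
spectral = torch.linalg.matrix_norm(jacobians, ord=2)
radius = torch.linalg.eigvals(jacobians).abs().amax(dim=-1)
```

With this `eps`, both `eps * spectral` and `eps * radius` stay below 1 at every particle, which is what the chain of inequalities guarantees.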

References
  1. Liu, Q., & Wang, D. (2016). Stein variational gradient descent: A general purpose Bayesian inference algorithm. NeurIPS.
  2. Behrmann, J., Grathwohl, W., Chen, R. T., Duvenaud, D., & Jacobsen, J.-H. (2019). Invertible residual networks. ICML.
  3. Hutchinson, M. F. (1989). A stochastic estimator of the trace of the influence matrix for Laplacian smoothing splines. Commun. Stat. Simul. Comput.
  4. Song, Y., Garg, S., Shi, J., & Ermon, S. (2020). Sliced score matching: A scalable approach to density and score estimation. UAI.
  5. Wolkowicz, H., & Styan, G. P. (1980). Bounds for eigenvalues using traces. Linear Algebra Appl.