Wasserstein Gradient Flows

Once $\Wass_2$ is viewed as a dynamic metric, one can run gradient descent directly on the space of measures. This chapter derives the formal Wasserstein gradient, explains the JKO minimizing-movement scheme, records the role of geodesic convexity in convergence, and then applies the same calculus to mean-field neural-network training.

Minimizing Movements and Wasserstein Gradients

This first section explains how a variational implicit-Euler step on measures gives rise, in the small-step limit, to a continuity equation driven by the Wasserstein gradient of the energy.

We consider a function $f(\alpha)$ and seek a minimizing evolution $(\alpha_t)_t$. The minimizing-movement strategy over a metric space builds a discrete-time evolution using an implicit Euler scheme:

$$\alpha_{t+\tau} := \argmin_\alpha \frac{1}{2\tau}\Wass_2(\alpha_t,\alpha)^2+f(\alpha). \tag{1}$$

Euclidean Gradient Flows

If (1) is restricted to Dirac masses $\alpha_t=\delta_{x(t)}$ and $\alpha=\delta_x$, so that $\Wass_2(\delta_{x(t)},\delta_x)=\norm{x-x(t)}$, it becomes the implicit Euler scheme

$$x(t+\tau) := \argmin_x \frac{1}{2\tau}\norm{x-x(t)}^2+h(x), \qquad h(x)=f(\delta_x).$$

Its solution is formally

$$x(t+\tau)=(\Id+\tau\nabla h)^{-1}(x(t)).$$

In contrast, explicit Euler uses

$$x(t+\tau)=(\Id-\tau\nabla h)(x(t))=x(t)-\tau\nabla h(x(t)).$$

Both schemes converge as $\tau\to0$ to the classical gradient flow

$$\dot x(t)=-\nabla h(x(t)). \tag{5}$$
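The two Euler schemes can be compared on a small example. The sketch below assumes a quadratic potential $h(x)=\frac12 x^\top A x$, chosen only because the implicit step $(\Id+\tau\nabla h)^{-1}$ then reduces to a linear solve; the matrix `A`, the horizon `T`, and the function names are illustrative choices, not part of the text.

```python
import numpy as np

def explicit_euler(x0, grad_h, tau, steps):
    """Iterate x <- x - tau * grad_h(x)."""
    x = np.asarray(x0, dtype=float)
    for _ in range(steps):
        x = x - tau * grad_h(x)
    return x

def implicit_euler_quadratic(x0, A, tau, steps):
    """Iterate x <- (Id + tau * A)^{-1} x, the implicit step for h(x) = 0.5 * x^T A x."""
    x = np.asarray(x0, dtype=float)
    M = np.eye(x.size) + tau * A
    for _ in range(steps):
        x = np.linalg.solve(M, x)
    return x

# Hypothetical test problem: quadratic potential with curvatures 1 and 10.
A = np.diag([1.0, 10.0])
x0 = np.array([1.0, 1.0])
T = 1.0
exact = np.exp(-np.diag(A) * T) * x0  # closed-form gradient flow at time T

errors = {}
for steps in (10, 100, 1000):
    tau = T / steps
    xe = explicit_euler(x0, lambda x: A @ x, tau, steps)
    xi = implicit_euler_quadratic(x0, A, tau, steps)
    errors[steps] = (np.linalg.norm(xe - exact), np.linalg.norm(xi - exact))
    print(steps, errors[steps])
```

Both errors shrink as $\tau\to0$. With a large step such as $\tau=1$, the implicit iterates still contract while the explicit ones diverge along the stiff direction, which is one practical reason to prefer the implicit form.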

Wasserstein Gradient Formula

The implicit Euler scheme has the advantage that it does not require $h$ or $f$ to be smooth. For $f$, this is crucial because the measures along the evolution may have densities, atoms or other singular parts.

As $\tau\to0$, under suitable conditions on $f$, (1) defines a continuous evolution $t\mapsto\alpha_t$. As in the dynamic formulation, this evolution admits a Lagrangian description, in which mass is advected by a velocity field. We use the following first-variation convention: for any $\beta\in\Pp(\RR^d)$ and the signed zero-mass perturbation $\rho=\beta-\alpha$,

$$f((1-\tau)\alpha+\tau\beta) = f(\alpha+\tau\rho) = f(\alpha)+ \tau\int[\delta f(\alpha)](x)\d\rho(x) +o(\tau).$$

The key infinitesimal object is the vector field $\Wgrad f(\alpha)=\nabla\delta f(\alpha)$, which represents this differential in the Wasserstein metric.

The associated formal gradient flow is the continuity equation

$$\frac{\partial\alpha_t}{\partial t} +\operatorname{div}(-\Wgrad f(\alpha_t)\alpha_t)=0. \tag{8}$$

The following proposition explains why this vector field is the Riemannian gradient for the $L^2(\alpha)$ metric on velocities.

Proof

The push-forward expansion gives, in the sense of distributions,

$$(\Id+\tau v)_\sharp\alpha = \alpha-\tau\operatorname{div}(\alpha v)+o(\tau).$$

Using the definition of the first variation,

$$f((\Id+\tau v)_\sharp\alpha) = f(\alpha) -\tau\int\delta f(\alpha)\operatorname{div}(\alpha v)\d x +o(\tau).$$

An integration by parts, with compact support or vanishing boundary flux, gives

$$-\int\delta f(\alpha)\operatorname{div}(\alpha v)\d x = \int\dotp{\nabla\delta f(\alpha)}{v}\d\alpha.$$

The differential of $f$ in the direction $v$ is thus the $L^2(\alpha)$ inner product of $v$ with $\nabla\delta f(\alpha)$. By definition of the Riesz representative for the $L^2(\alpha)$ metric, the gradient is therefore $\nabla\delta f(\alpha)$.

The Wasserstein gradient-flow viewpoint already appears in John D. Lafferty’s PhD work, published as “The Density Manifold and Configuration Space Quantization”, under the name “density manifold”. It was then systematically developed by Otto, who exposed the formal Riemannian structure of this space (Otto, 2001). Rigorous metric-space treatments and numerical JKO schemes can be found in Ambrosio et al., 2006; Benamou et al., 2016; Peyré, 2015; Gallouët & Monsaingeon, 2017.

From the JKO Step to the Velocity Field

A first-order expansion of the JKO step explains why (8) uses the vector field $\Wgrad f(\alpha)$. Write (1) as a minimization over displacement fields $v$ such that $\alpha=(\Id+\tau v)_\sharp\alpha_t$:

$$\min_v \frac{1}{2\tau}\tau^2\norm{v}_{L^2(\alpha_t)}^2 + f((\Id+\tau v)_\sharp\alpha_t).$$

The push-forward and energy expansions are

$$(\Id+\tau v)_\sharp\alpha_t = \alpha_t-\tau\operatorname{div}(v\alpha_t)+o(\tau),$$

$$f((\Id+\tau v)_\sharp\alpha_t) = f(\alpha_t) -\tau\int\delta f(\alpha_t)\operatorname{div}(v\alpha_t)\d x +o(\tau)$$

and hence

$$f((\Id+\tau v)_\sharp\alpha_t) = f(\alpha_t) + \tau\int \dotp{\nabla_x\delta f(\alpha_t)(x)}{v(x)} \d\alpha_t(x) +o(\tau).$$

Thus the problem minimized in (1) has the first-order expansion

$$\min_v f(\alpha_t) + \tau\int \left[ \frac12\norm{v(x)}^2 + \dotp{\Wgrad f(\alpha_t)(x)}{v(x)} \right] \d\alpha_t(x) +o(\tau).$$

The pointwise minimizer is $v=-\Wgrad f(\alpha_t)$, which gives the velocity in the continuity equation. We now detail examples of such Wasserstein gradient flows.
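The pointwise minimization can be made explicit by completing the square in the integrand. The identity below is a short worked step, written with the chapter's macros (`\norm`, `\dotp`, `\Wgrad`); it is formal, like the expansion it accompanies.

```latex
\frac12\norm{v(x)}^2 + \dotp{\Wgrad f(\alpha_t)(x)}{v(x)}
  = \frac12\norm{v(x) + \Wgrad f(\alpha_t)(x)}^2
  - \frac12\norm{\Wgrad f(\alpha_t)(x)}^2 .
```

The first term vanishes exactly when $v(x)=-\Wgrad f(\alpha_t)(x)$, and the remaining term shows that, to first order, one JKO step decreases the energy by $\frac{\tau}{2}\int\norm{\Wgrad f(\alpha_t)}^2\d\alpha_t$.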

Figure. JKO minimizing movements for the entropy flow in one dimension. The left panel displays successive implicit-Euler minimizers for the heat equation, colored from red to blue. The right panel tracks inverse CDF values $Q_t(s)=F_t^{-1}(s)$ for selected probability levels $s$, giving a Lagrangian view of the proximal movement in Wasserstein space.

The interactive demo uses the heat-flow representative of the entropy JKO scheme: changing the step size changes the spacing between implicit Euler iterates, while the quantile panel shows how the same movement is seen in Lagrangian coordinates.


Discrete Evolutions

If $f(\alpha)$ can be evaluated on discrete distributions and $\Wgrad f(\alpha)$ is continuous in this case, the flow (8) maintains the number of Dirac masses:

$$\alpha_t=\frac1n\sum_i\delta_{x_i(t)}.$$

The particles $X(t)=(x_i(t))_i$ evolve according to the coupled ODE

$$\dot x_i(t)=-n\nabla_{x_i}F(X(t)),$$

where $F(X)=f\left(\frac1n\sum_i\delta_{x_i}\right)$. The factor $n$ comes from the empirical Wasserstein metric $\frac1n\sum_i\norm{\dot x_i}^2$.

Linear Functionals

The simplest example is a linear functional

$$f(\alpha)=\int h(x)\d\alpha(x).$$

Here $\delta f(\alpha)=h$ is independent of $\alpha$. The flow (8) becomes

$$\frac{\partial\alpha_t}{\partial t} + \operatorname{div}(-\nabla h\,\alpha_t)=0.$$

Thus particles move independently according to the usual gradient flow (5).
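The role of the factor $n$ in the particle ODE can be checked numerically for a linear functional: since $F(X)=\frac1n\sum_i h(x_i)$, the rescaled gradient $n\nabla_{x_i}F(X)$ should equal $\nabla h(x_i)$. The sketch below assumes a hypothetical smooth potential `h` and uses central finite differences; neither is prescribed by the text.

```python
import numpy as np

def h(x):
    # Hypothetical smooth potential, used only for illustration.
    return 0.5 * np.sum(x**2, axis=-1) + np.sum(np.cos(x), axis=-1)

def grad_h(x):
    # Exact gradient of the potential above.
    return x - np.sin(x)

def F(X):
    # F(X) = f(1/n sum_i delta_{x_i}) for the linear functional f(alpha) = int h d(alpha).
    return np.mean(h(X))

def particle_velocity(X, eps=1e-6):
    """Return -n * grad_{x_i} F(X) for every particle, by central finite differences."""
    n, d = X.shape
    V = np.zeros_like(X)
    for i in range(n):
        for k in range(d):
            E = np.zeros_like(X)
            E[i, k] = eps
            V[i, k] = -n * (F(X + E) - F(X - E)) / (2 * eps)
    return V

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))
V = particle_velocity(X)
print(np.max(np.abs(V + grad_h(X))))  # discrepancy with -grad h(x_i)
```

The printed discrepancy is at the level of finite-difference error, so each particle follows $\dot x_i=-\nabla h(x_i)$ independently of the others.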

Shannon Neg-Entropy

A very different behavior is obtained by considering functionals that require $\alpha_t$ to have a density. The canonical example is Shannon neg-entropy

$$f(\alpha) = \int \log\left(\frac{\d\alpha}{\d x}(x)\right) \d\alpha(x).$$

Here $\delta f(\alpha)=\log(\frac{\d\alpha}{\d x})$ up to an additive constant, so $\Wgrad f(\alpha)=\nabla\alpha/\alpha$, often called the score. Since $\operatorname{div}(\alpha_t\,\nabla\alpha_t/\alpha_t)=\Delta\alpha_t$, the flow (8) becomes the heat equation

$$\partial_t\alpha_t=\Delta\alpha_t. \tag{23}$$

Other entropy functionals lead to nonlinear diffusion equations; finite-volume and particle discretizations are discussed in Carrillo et al., 2015; Gianazza et al., 2009; Maas, 2011; Erbar, 2010.
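The entropy-flow interpretation of the heat equation can be observed on a grid. The sketch below is not a JKO solver: it assumes a plain explicit finite-difference scheme on a periodic one-dimensional grid, with an illustrative Gaussian initial density, and checks three consequences of the theory: mass is conserved, the variance grows by $2t$, and the neg-entropy $\int\rho\log\rho\,\d x$ decreases along the flow.

```python
import numpy as np

# Grid and time step (illustrative values; dt <= dx^2 / 2 keeps the scheme stable).
x = np.linspace(-10.0, 10.0, 401)
dx = x[1] - x[0]
dt = 0.4 * dx**2
T = 1.0
steps = int(round(T / dt))

var0 = 0.5  # variance of the assumed Gaussian initial density
rho = np.exp(-x**2 / (2 * var0)) / np.sqrt(2 * np.pi * var0)

def neg_entropy(rho):
    """Riemann-sum approximation of int rho log rho dx (with 0 log 0 = 0)."""
    mask = rho > 0
    return np.sum(rho[mask] * np.log(rho[mask])) * dx

def heat_step(rho):
    """One explicit Euler step of d(rho)/dt = Laplacian(rho), periodic boundary."""
    lap = (np.roll(rho, 1) - 2 * rho + np.roll(rho, -1)) / dx**2
    return rho + dt * lap

entropies = [neg_entropy(rho)]
for _ in range(steps):
    rho = heat_step(rho)
    entropies.append(neg_entropy(rho))

mass = rho.sum() * dx
var = np.sum(x**2 * rho) * dx / mass
print(mass, var, entropies[0], entropies[-1])
```

With this step size each update is a local averaging of the density, which is why the discrete neg-entropy cannot increase; an actual JKO discretization would replace `heat_step` by a proximal step in $\Wass_2$.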

For example, a generalized entropy

$$f(\alpha)=\int g\left(\frac{\d\alpha}{\d x}\right)\d x \tag{25}$$

for a scalar convex function $g$ leads, in the smooth-density regime, to

$$\frac{\partial\alpha_t}{\partial t} = \Delta(P(\alpha_t)),$$

where the pressure $P$ satisfies $P'(s)=s g''(s)$. For $g(s)=s\log s$, one has $P(s)=s$ and recovers (23); for $g(s)=s^m/(m-1)$ with $m>1$, one obtains $P(s)=s^m$ up to an additive constant and the porous-medium equation.
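The pressure relation follows from a short formal computation, sketched here for a smooth positive density $\rho=\frac{\d\alpha}{\d x}$ and written with the chapter's macro `\Wgrad`. It only uses the first-variation convention and the continuity equation (8).

```latex
\delta f(\alpha) = g'(\rho),
\qquad
\Wgrad f(\alpha) = \nabla g'(\rho) = g''(\rho)\,\nabla\rho,
\\[4pt]
\partial_t\rho
  = \operatorname{div}\bigl(\rho\,\Wgrad f(\alpha)\bigr)
  = \operatorname{div}\bigl(\rho\,g''(\rho)\,\nabla\rho\bigr)
  = \operatorname{div}\bigl(P'(\rho)\,\nabla\rho\bigr)
  = \Delta P(\rho).
```

For instance, $g(s)=s^m/(m-1)$ gives $g''(s)=m\,s^{m-2}$, hence $P'(s)=m\,s^{m-1}$ and $P(s)=s^m$ up to a constant.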

A celebrated theorem by McCann (McCann, 1997) states that an internal energy of the form (25), for $g:\RR^+\to\RR\cup\{+\infty\}$ with $g(0)=0$, is geodesically convex on $\Pp(\RR^d)$ when $g$ is convex and the map $r\mapsto r^d g(r^{-d})$ is convex and nonincreasing on $(0,+\infty)$. Examples include $g(s)=s^q$ for $q>1$ and Shannon entropy $g(s)=s\log s$. By contrast, $g(s)=-\log s$, associated with the reverse KL divergence, does not satisfy this displacement-convexity criterion.

Figure. Entropy-driven Wasserstein gradient flows from the same compact initial density. The heat flow is generated by Shannon entropy $g(\rho)=\rho\log\rho$ and instantly develops Gaussian tails. The porous-medium flows use the power entropy $g(\rho)=\rho^m/(m-1)$, hence $\partial_t\rho=\Delta(\rho^m)$: the middle panel has $m=2$, while the right panel has the stronger nonlinearity $m=6$, i.e. $\partial_t\rho=\Delta(\rho^6)$. Larger powers diffuse mainly where the density is high, producing a flatter core and a sharper compact free boundary.

The interactive demo isolates the effect of the entropy exponent. The heat curve keeps Gaussian tails, while increasing $m$ keeps a compact front and spreads mass mainly from the high-density core.


Interaction Energies

To obtain nonlinear evolutions without requiring the measure to have a density, one can consider

$$f(\alpha) := \iint k(x,y)\d\alpha(x)\d\alpha(y).$$

For a symmetric kernel $k$,

$$\delta f(\alpha)(x) = 2\int k(x,y)\d\alpha(y), \qquad \Wgrad f(\alpha)(x) = 2\int\nabla_x k(x,y)\d\alpha(y).$$

For $\alpha_0=\frac1n\sum_i\delta_{x_i}$, the flow (8) implies the particle system

$$\dot x_i(t) = -\frac2n\sum_j\nabla_x k(x_i(t),x_j(t)),$$

where $\nabla_x$ denotes the gradient with respect to the first argument of $k$.

If $k$ is positive definite, or more generally conditionally positive definite on signed measures of zero total mass as for the energy-distance kernel $k(x,y)=-\norm{x-y}$, and one minimizes the squared kernel discrepancy to a teacher distribution $\beta$, then

$$\norm{\alpha-\beta}_k^2 = \iint k\d\alpha\d\alpha -2\int\left(\int k(x,y)\d\beta(y)\right)\d\alpha(x) +\mathrm{constant}.$$

Thus MMD-type training energies are exactly an interaction energy plus a linear potential. The teacher distribution appears through the potential $x\mapsto-2\int k(x,y)\d\beta(y)$, and the corresponding empirical Wasserstein gradient flow is

$$\dot x_i(t) = -\frac2n\sum_j\nabla_x k(x_i(t),x_j(t)) + 2\int\nabla_x k(x_i(t),y)\d\beta(y).$$

The first term is a kernelized self-interaction; the second is the attraction induced by the continuous teacher kernel mean. At the continuum level, characteristic positive-definite kernels, and the Euclidean energy-distance kernel on probability measures, have $\beta$ as the unique minimizer of $\norm{\alpha-\beta}_k^2$. For finitely many particles, however, the flow can only form a kernelized quadrature of $\beta$, and small particle systems may cover the target modes poorly. The particle-count figure below illustrates this finite-particle effect.

Figure. Particle count in the deterministic Wasserstein gradient flow of the squared MMD-type discrepancy to a smooth two-Gaussian teacher distribution, using here the energy-distance kernel $k(x,y)=-\norm{x-y}$. The teacher itself is shown only through true density contours, while red dots are a compact shifted Gaussian initialization placed away from the target, red-to-blue curves show a thinned subset of particle trajectories, and blue dots show the stabilized long-time particles. With too few particles, the empirical measure forms a sparse kernelized quadrature and may under-cover the target modes; increasing $n$ makes the particle cloud approximate the continuous target geometry more faithfully.
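The MMD particle system can be simulated in a few lines. The sketch below is not the code behind the figure: it assumes a Gaussian kernel instead of the energy-distance kernel (so that $\nabla_x k$ is smooth), represents the teacher $\beta$ by samples `Y`, and integrates the ODE with explicit Euler steps; the bandwidth `sigma`, step `dt`, and sample sizes are illustrative.

```python
import numpy as np

def gauss_kernel(X, Y, sigma):
    """Return K[i, j] = exp(-|x_i - y_j|^2 / (2 sigma^2)) and differences D[i, j] = x_i - y_j."""
    D = X[:, None, :] - Y[None, :, :]
    K = np.exp(-np.sum(D**2, axis=-1) / (2 * sigma**2))
    return K, D

def mmd2(X, Y, sigma):
    """Squared kernel discrepancy between the empirical measures of X and Y."""
    Kxx, _ = gauss_kernel(X, X, sigma)
    Kxy, _ = gauss_kernel(X, Y, sigma)
    Kyy, _ = gauss_kernel(Y, Y, sigma)
    return Kxx.mean() - 2 * Kxy.mean() + Kyy.mean()

def velocity(X, Y, sigma):
    """-(2/n) sum_j grad_x k(x_i, x_j) + (2/m) sum_l grad_x k(x_i, y_l)."""
    Kxx, Dxx = gauss_kernel(X, X, sigma)
    Kxy, Dxy = gauss_kernel(X, Y, sigma)
    # For the Gaussian kernel, grad_x k(x, y) = -(x - y) / sigma^2 * k(x, y).
    grad_xx = -(Dxx * Kxx[..., None]) / sigma**2
    grad_xy = -(Dxy * Kxy[..., None]) / sigma**2
    return -2 * grad_xx.mean(axis=1) + 2 * grad_xy.mean(axis=1)

rng = np.random.default_rng(0)
# Teacher samples from two Gaussian modes; particles start in a compact shifted cloud.
Y = np.concatenate([rng.normal([-2.0, 0.0], 0.5, size=(100, 2)),
                    rng.normal([2.0, 0.0], 0.5, size=(100, 2))])
X = rng.normal([0.0, 3.0], 0.3, size=(30, 2))

sigma, dt, steps = 1.5, 0.1, 500
history = [mmd2(X, Y, sigma)]
for _ in range(steps):
    X = X + dt * velocity(X, Y, sigma)
    history.append(mmd2(X, Y, sigma))
print(history[0], history[-1])
```

Changing the number of rows of `X` gives a quick way to observe the finite-particle effect: the discrepancy decreases along the flow, but a small cloud can only approximate the two modes coarsely.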

The interactive demo turns this finite-particle effect into a parameter: increasing the number of particles makes the same deterministic force field approximate the teacher geometry more faithfully.


Figure. Interaction-energy particle flows for three choices of $k$. A positive Gaussian kernel $k(x,y)=\exp(-\norm{x-y}^2/(2\sigma^2))$ produces short-range repulsion under Wasserstein descent; changing its sign produces attraction and collapse; adding a quadratic long-range attraction to the repulsive kernel yields a balanced attraction–repulsion dynamics. The curves use arclength-based red-to-blue coloring along a longer integration of the coupled particle ODE (20).

The interactive demo lets the sign and strength of the interaction change without editing the hidden particle solver. This is the quickest way to see how the same formal ODE can repel, collapse, or self-organize.


Figure. Particle trajectories induced by different discrepancy geometries. The red particles and blue target cloud are the same in all panels. Straight OT displacement produces rays from an optimal matching; an MMD-type witness field gives smoother nonlocal forces; the Sinkhorn-divergence force is an entropic, debiased transport attraction; and the normalized drifting field combines attraction to data with self-repulsion. The figure is qualitative: it compares geometric behavior, not solver performance.

The interactive demo keeps the source and target fixed while switching the discrepancy geometry. The smoothing parameter controls how local or nonlocal the induced force appears.


Stochastic Particles and McKean--Vlasov Limits

Deterministic particle flows have stochastic counterparts, where Brownian noise at the particle level becomes an entropy term at the measure level. If the drift does not depend on the empirical measure, each particle evolves independently according to

$$\d X_t=b(X_t)\d t+\sqrt2\,\sigma\d B_t,$$

and the one-particle law $\alpha_t=\rho_t\d x$ satisfies the linear Fokker–Planck equation

$$\partial_t\rho_t = -\operatorname{div}(b\rho_t)+\sigma^2\Delta\rho_t.$$

For example, if $b=-\nabla V$, this is the $\Wass_2$ gradient flow of the free energy

$$\int V\rho\,\d x+\sigma^2\int\rho\log\rho\,\d x.$$
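The link between particle noise and the entropy term can be checked on the simplest case. The sketch below assumes the quadratic potential $V(x)=x^2/2$ and an Euler–Maruyama discretization, both illustrative choices; the free energy above is then minimized by the Gibbs density proportional to $e^{-V/\sigma^2}$, a centered Gaussian of variance $\sigma^2$, which the long-time particle cloud should reproduce.

```python
import numpy as np

def langevin(x0, grad_V, sigma, dt, steps, rng):
    """Euler-Maruyama scheme for dX = -grad_V(X) dt + sqrt(2) * sigma dB."""
    x = np.array(x0, dtype=float)
    for _ in range(steps):
        noise = rng.normal(size=x.shape)
        x = x - dt * grad_V(x) + np.sqrt(2 * dt) * sigma * noise
    return x

rng = np.random.default_rng(0)
sigma = 0.7
x0 = np.full(20000, 3.0)  # independent particles, all started at x = 3
xT = langevin(x0, lambda x: x, sigma, dt=0.01, steps=500, rng=rng)

# Empirical mean and variance versus the Gibbs values 0 and sigma^2.
print(xT.mean(), xT.var(), sigma**2)
```

The particles are independent here because the drift ignores the empirical measure; the mean-field case only changes `grad_V` into a function of the whole cloud.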

The mean-field case is different: the drift is recomputed from the current empirical distribution of all particles,

$$\d X_i^n(t) = b(X_i^n(t),\mu_t^n)\d t+\sqrt2\,\sigma\d B_i(t), \qquad \mu_t^n=\frac1n\sum_{i=1}^n\delta_{X_i^n(t)}.$$

For finite $n$, the empirical law $\mu_t^n$ is random. Under suitable Lipschitz, growth and chaotic-initialization assumptions, propagation of chaos states that finitely many particles become asymptotically independent as $n\to\infty$, all with the same deterministic law $\rho_t\d x$. Equivalently, $\mu_t^n$ converges in probability to this law. The limiting density solves the nonlinear Fokker–Planck, or McKean–Vlasov, equation

$$\partial_t\rho_t = -\operatorname{div}\big(b(x,\rho_t)\rho_t\big) + \sigma^2\Delta\rho_t.$$

When the interaction drift has variational form

$$b(x,\rho) = -\nabla\frac{\delta\mathcal E}{\delta\rho}(x),$$

this PDE is the Wasserstein gradient flow of the entropy-regularized energy

$$\mathcal E(\rho)+\sigma^2\int\rho\log\rho\,\d x.$$

Figure. Three numerical representations of the same entropy-regularized Wasserstein gradient flow of $\KL(\rho|\beta)$, where $\beta$ is a two-Gaussian target shifted to the right of an initially isotropic Gaussian density. The first row simulates independent Langevin particles and displays a thinned set of trajectories in the left panel. The second row evolves many deterministic particles with velocity $\tau(\nabla\log\beta-\nabla\log\rho_t)$, estimating $\nabla\log\rho_t$ by a sharper kernel-density score; only representative trajectories and particle subsets are displayed. The third row solves the corresponding Fokker–Planck equation on a grid, starting from the initial density in the left panel. The remaining columns use front-loaded times, so that the onset of the flow and the later deformation toward a bimodal law are both visible.

The interactive demo compares three views of the same entropy-regularized relaxation: stochastic Langevin particles, deterministic score particles, and a smoothed grid density. The noise slider controls the entropy strength.


Geodesic Convexity and Convergence

Geodesic convexity is the convexity notion adapted to Wasserstein geometry. It is the condition that turns the formal gradient-flow calculus into a convergence theory.

Geodesics and Convexity

A constant-speed $\Wass_2$ geodesic between $\alpha_0$ and $\alpha_1$ is obtained, as in the McCann interpolation, from any optimal coupling $\pi^\star\in\Couplings(\alpha_0,\alpha_1)$ by

$$\alpha_t=((1-t)P_0+tP_1)_\sharp\pi^\star, \qquad t\in[0,1],$$

where $P_0(x,y)=x$ and $P_1(x,y)=y$. If the optimal plan is induced by a Brenier map $T$, this reduces to $((1-t)\Id+tT)_\sharp\alpha_0$. The coupling formula matters because geodesics exist even when no Monge map exists, for instance when a Dirac mass must split.
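On the real line the coupling formula becomes concrete, because for two uniform empirical measures with the same number of atoms the optimal coupling for the quadratic cost matches sorted points. The sketch below relies on this one-dimensional fact; the sample sizes and distributions are illustrative, and the function names are not from the text.

```python
import numpy as np

def w2_1d(x, y):
    """W2 between two uniform empirical measures on the line with equally many atoms."""
    return np.sqrt(np.mean((np.sort(x) - np.sort(y))**2))

def mccann_interpolation_1d(x, y, t):
    """Atoms of alpha_t = ((1 - t) P0 + t P1)_# pi*, with pi* the sorted (monotone) coupling."""
    return (1 - t) * np.sort(x) + t * np.sort(y)

rng = np.random.default_rng(0)
x = rng.normal(-2.0, 0.5, size=200)
y = np.concatenate([rng.normal(1.0, 0.3, size=100),
                    rng.normal(4.0, 0.3, size=100)])

total = w2_1d(x, y)
for t in (0.0, 0.25, 0.5, 1.0):
    xt = mccann_interpolation_1d(x, y, t)
    # Constant speed: the distance from alpha_0 to alpha_t is t * W2(alpha_0, alpha_1).
    print(t, w2_1d(x, xt), t * total)
```

The two printed columns agree, which is the constant-speed property of the geodesic; the single initial mode splits into two without any mass being created or destroyed.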

Proof

Along a Monge geodesic $X_t=(1-t)X_0+tX_1$, convexity of $h$ gives $h(X_t)\leq(1-t)h(X_0)+t h(X_1)$, and strong convexity gives the additional quadratic term; integrating proves the first claim.

The interaction claim follows similarly by applying convexity of $W$ to pairwise differences $X_t-X_t'=(1-t)(X_0-X_0')+t(X_1-X_1')$ and integrating over two independent copies. The entropy claim is McCann’s displacement-convexity theorem; at the density level it follows from the concavity of the $d$-th root of the Jacobian determinant under the interpolation of optimal maps. Finally,

$$\KL(\alpha|\gamma) = \int\rho\log\rho\,\d x+\int V\d\alpha+\mathrm{constant},$$

so it is the sum of the displacement-convex entropy and a $\lambda$-geodesically convex linear potential.

Convergence of the Flow

In general, analyzing (8) is delicate. The cleanest case is when $f$ is geodesically convex. This condition is the Wasserstein analogue of convexity in Euclidean gradient descent.

Proof

The chain rule and the formal Wasserstein-gradient proposition give

$$\frac{\d}{\d t}f(\alpha_t) = \int\dotp{\Wgrad f(\alpha_t)(x)}{v_t(x)}\d\alpha_t(x) = -\int\norm{\Wgrad f(\alpha_t)(x)}^2\d\alpha_t(x).$$

Geodesic convexity along the geodesic $((1-s)\Id+sT_t)_\sharp\alpha_t$ gives

$$f(\alpha^\star)-f(\alpha_t) \geq \int\dotp{\Wgrad f(\alpha_t)(x)}{T_t(x)-x}\d\alpha_t(x).$$

Since $v_t=-\Wgrad f(\alpha_t)$,

$$f(\alpha_t)-f(\alpha^\star) \leq \int\dotp{v_t(x)}{T_t(x)-x}\d\alpha_t(x).$$

The first-variation formula for the squared Wasserstein distance gives

$$\frac{\d}{\d t}\frac12\Wass_2^2(\alpha_t,\alpha^\star) = \int\dotp{x-T_t(x)}{v_t(x)}\d\alpha_t(x),$$

which, combined with the previous inequality, proves the differential inequality $\frac{\d}{\d t}\frac12\Wass_2^2(\alpha_t,\alpha^\star)\leq-\bigl(f(\alpha_t)-f(\alpha^\star)\bigr)$. Integrating it from $0$ to $t$ and using monotonicity of $s\mapsto f(\alpha_s)$ gives

$$t\bigl(f(\alpha_t)-f(\alpha^\star)\bigr) \leq \int_0^t \bigl(f(\alpha_s)-f(\alpha^\star)\bigr)\d s \leq \frac12\Wass_2^2(\alpha_0,\alpha^\star).$$

If $f$ is $\lambda$-geodesically convex, the Wasserstein analogue of strong convexity gives the slope inequality

$$\int\norm{\Wgrad f(\alpha_t)}^2\d\alpha_t \geq 2\lambda\bigl(f(\alpha_t)-f(\alpha^\star)\bigr).$$

Combining it with the energy dissipation identity yields

$$\frac{\d}{\d t}\bigl(f(\alpha_t)-f(\alpha^\star)\bigr) \leq -2\lambda\bigl(f(\alpha_t)-f(\alpha^\star)\bigr),$$

and Gronwall’s lemma gives the exponential rate.
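The exponential rate can be observed on a linear functional $f(\alpha)=\int h\,\d\alpha$ with a $\lambda$-strongly convex potential, for which the flow reduces to independent particles. The sketch below assumes the hypothetical potential $h(x)=\frac\lambda2x^2+\log\cosh x$, which satisfies $h''\geq\lambda$ and has minimum $h(0)=0$, so that $f(\alpha^\star)=0$ with $\alpha^\star=\delta_0$; time integration uses small explicit Euler steps.

```python
import numpy as np

lam = 0.5

def h(x):
    # lam-strongly convex potential with minimum h(0) = 0 (hypothetical example):
    # h(x) = lam/2 * x^2 + log(cosh(x)), with log(cosh) computed stably.
    return 0.5 * lam * x**2 + np.logaddexp(x, -x) - np.log(2.0)

def grad_h(x):
    return lam * x + np.tanh(x)

rng = np.random.default_rng(0)
x = rng.normal(2.0, 1.0, size=500)  # atoms of the initial empirical measure

dt, steps = 1e-3, 4000
times = dt * np.arange(steps + 1)
gaps = np.empty(steps + 1)          # f(alpha_t) - f(alpha_star), with f(alpha_star) = 0
gaps[0] = h(x).mean()
for k in range(steps):
    x = x - dt * grad_h(x)          # explicit Euler step on dx/dt = -grad h(x)
    gaps[k + 1] = h(x).mean()

bound = gaps[0] * np.exp(-2 * lam * times)
print(np.max(gaps - bound))         # nonpositive up to rounding if the rate holds
```

The energy gap stays below $e^{-2\lambda t}\bigl(f(\alpha_0)-f(\alpha^\star)\bigr)$ along the whole trajectory, in line with the Gronwall argument.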

Proof

Let $(\alpha_t)_t$ be the McCann interpolation between $\alpha_0$ and $\alpha_1$, written with an optimal coupling as $X_t=(1-t)X_0+tX_1$. For a linear energy with convex potential $h$, convexity gives

$$h(X_t)\leq(1-t)h(X_0)+t h(X_1),$$

and the strong convexity version gives the additional term $-\frac{\lambda}{2}t(1-t)\norm{X_0-X_1}^2$. Integrating over the optimal coupling proves geodesic convexity and $\lambda$-geodesic convexity.

For interaction energies, use two independent copies of the optimal coupling. The pairwise displacement evolves as

$$X_t-X_t' = (1-t)(X_0-X_0')+t(X_1-X_1').$$

Convexity of $W$ gives the convexity inequality after integration over the product coupling. Evenness of $W$ ensures that the interaction is symmetric in the two particles and matches the usual factor $1/2$ in (27).

The entropy claim is McCann’s displacement-convexity theorem. For smooth positive densities and Brenier maps, it follows from the change-of-variables formula and the concavity of $A\mapsto\det(A)^{1/d}$ on positive semidefinite matrices; the general statement is obtained by approximation. Finally,

$$\KL(\alpha|\gamma) = \int\rho\log\rho\,\d x+\int V\d\alpha+\log Z,$$

so it is the sum of the displacement-convex entropy and the $\lambda$-geodesically convex linear potential generated by $V$. The energy-decay proposition then applies to all four cases.

Convexity and Curvature

The same language is not restricted to subsets of $\RR^d$. If $(\X,\dist,\mathfrak m)$ is a geodesic metric-measure space, $\Wass_2$ geodesics can be defined by transporting each pair of endpoints along metric geodesics, or more intrinsically by dynamical optimal plans on path space. Given a reference measure $\mathfrak m$, the entropy relative to $\mathfrak m$ is

$$\mathrm{Ent}_{\mathfrak m}(\alpha) \eqdef \begin{cases} \displaystyle\int_\X\rho\log\rho\,\d\mathfrak m, &\text{if }\alpha=\rho\,\mathfrak m,\\ +\infty, &\text{otherwise.} \end{cases}$$

On a smooth Riemannian manifold $(M,g)$, the Ricci curvature tensor $\mathrm{Ric}_g$ is the trace of the Riemann curvature tensor. The lower bound $\mathrm{Ric}_g\geq\lambda g$ means that $\mathrm{Ric}_g(v,v)\geq\lambda |v|_g^2$ for every tangent vector $v$. The fundamental link between curvature and optimal transport is that this tensor lower bound is exactly encoded by geodesic convexity of entropy.

This equivalence was developed in the smooth Riemannian setting by Cordero-Erausquin, McCann and Schmuckenschläger and by von Renesse and Sturm (Cordero-Erausquin et al., 2001; Renesse & Sturm, 2005); it is a central theme of the optimal-transport approach to curvature in Villani’s monograph (Villani, 2009). Lott–Villani and Sturm then used the same entropy-convexity principle to define synthetic lower Ricci curvature bounds on metric-measure spaces (Lott & Villani, 2009; Sturm, 2006; Sturm, 2006). Outside this convex, curvature-controlled regime, such as in the mean-field neural-network example below, the flow may still be informative but its convergence analysis requires problem-specific arguments.

Training Two-Layer MLPs as Wasserstein Flows

Mean-field limits recast the training of wide neural networks as transport of a distribution of neurons. This section shows how the particle ODE of gradient descent becomes a Wasserstein flow in parameter space.

We use $z\in\RR^d$ for the input data and $y\in\RR^{d'}$ for the label. A neuron is a particle

$$x=(u,v)\in\RR^d\times\RR^{d'},$$

where $u$ is the inner weight and $v$ is the outer vector weight. For a scalar nonlinearity $\sigma$, define the vector-valued feature

$$\psi(x,z)=v\,\sigma(\dotp{u}{z})\in\RR^{d'}.$$

The width-$n$ network and its mean-field version are

$$G_X(z)=\frac1n\sum_{i=1}^n\psi(x_i,z), \qquad G_\alpha(z)=\int\psi(x,z)\d\alpha(x), \qquad \alpha=\frac1n\sum_i\delta_{x_i}.$$

This formulation removes the artificial ordering of neurons and allows $\alpha$ to be a continuous distribution of infinitely many neurons.

Let $\rho$ be a probability distribution on data-label pairs $(z,y)\in\RR^d\times\RR^{d'}$. The population risk is

$$f(\alpha)=\int\ell(G_\alpha(z),y)\d\rho(z,y),$$

and the empirical risk is the special case $\rho=\rho_N\eqdef N^{-1}\sum_{k=1}^N\delta_{(z_k,y_k)}$. Since $\alpha\mapsto G_\alpha$ is linear, $f$ is convex as a function of $\alpha$ whenever $\ell(\cdot,y)$ is convex. For the empirical neuron law $\alpha_X=n^{-1}\sum_i\delta_{x_i}$, the Wasserstein metric induces on particles the rescaled metric $n^{-1}\sum_i\norm{\dot x_i}^2$. The corresponding particle flow is

$$\dot x_i=-n\nabla_{x_i}F(X), \qquad F(X)=f\!\left(\frac1n\sum_i\delta_{x_i}\right).$$

This is the gradient flow of $F(X)=f(\alpha_X)$ for the Wasserstein particle metric, equivalently the Euclidean gradient flow of $F$ with time rescaled by a factor $n$. It gives a particle discretization of (8).

Assume that $\ell$ is differentiable in its first variable. The first variation is

$$\delta f(\alpha)(x) = \int \dotp{\nabla_1\ell(G_\alpha(z),y)}{\psi(x,z)} \d\rho(z,y),$$

and the Wasserstein gradient in parameter space is

$$\Wgrad f(\alpha)(x) = \nabla_x\delta f(\alpha)(x) = \int [D_x\psi(x,z)]^\top\nabla_1\ell(G_\alpha(z),y) \d\rho(z,y).$$

For the squared Euclidean loss $\ell(s,y)=\frac12\norm{s-y}^2$, the energy is the sum of a quadratic interaction and a linear potential:

$$f(\alpha) = \frac12\iint k(x,x')\d\alpha(x)\d\alpha(x') + \int g(x)\d\alpha(x) + \frac12\int\norm{y}^2\d\rho(z,y), \tag{73}$$

with

$$k(x,x') = \int\dotp{\psi(x,z)}{\psi(x',z)}\d\rho(z,y), \qquad g(x) = -\int\dotp{y}{\psi(x,z)}\d\rho(z,y).$$

Thus

$$\delta f(\alpha)(x) = \int k(x,x')\d\alpha(x')+g(x), \qquad \Wgrad f(\alpha)(x) = \int\nabla_x k(x,x')\d\alpha(x')+\nabla_x g(x).$$

These kernels are generally not convex in the particle variable, so the geodesic-convex convergence theory above does not apply directly.
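The formula for the Wasserstein gradient in parameter space can be checked against the particle gradient $n\nabla_{x_i}F(X)$. The sketch below assumes the squared loss, an empirical data distribution, and the smooth activation $\sigma=\tanh$ (chosen so that finite differences are accurate; the text does not fix $\sigma$ here). Array names and sizes are illustrative.

```python
import numpy as np

def network(U, V, Z):
    """G_X(z_k) = (1/n) sum_i v_i * tanh(<u_i, z_k>), returned with shape (N, d')."""
    return np.tanh(Z @ U.T) @ V / U.shape[0]

def risk(U, V, Z, Y):
    """Empirical risk F(X) with the squared loss 0.5 * |G - y|^2."""
    R = network(U, V, Z) - Y
    return 0.5 * np.mean(np.sum(R**2, axis=1))

def wasserstein_grad(U, V, Z, Y):
    """Wgrad f(alpha_X)(x_i) = int [D_x psi]^T (G(z) - y) d(rho), split into (u, v) parts."""
    N = Z.shape[0]
    R = network(U, V, Z) - Y          # grad_1 of the squared loss, shape (N, d')
    H = np.tanh(Z @ U.T)              # sigma(<u_i, z_k>), shape (N, n)
    S = (R @ V.T) * (1 - H**2)        # <v_i, r_k> * sigma'(<u_i, z_k>), shape (N, n)
    return S.T @ Z / N, H.T @ R / N   # u-part (n, d), v-part (n, d')

def finite_difference_grad(U, V, Z, Y, eps=1e-6):
    """n * grad_{x_i} F(X) by central differences, for comparison."""
    n = U.shape[0]
    gU, gV = np.zeros_like(U), np.zeros_like(V)
    for idx in np.ndindex(*U.shape):
        E = np.zeros_like(U)
        E[idx] = eps
        gU[idx] = n * (risk(U + E, V, Z, Y) - risk(U - E, V, Z, Y)) / (2 * eps)
    for idx in np.ndindex(*V.shape):
        E = np.zeros_like(V)
        E[idx] = eps
        gV[idx] = n * (risk(U, V + E, Z, Y) - risk(U, V - E, Z, Y)) / (2 * eps)
    return gU, gV

rng = np.random.default_rng(0)
n, d, d_out, N = 4, 3, 2, 20
U, V = rng.normal(size=(n, d)), rng.normal(size=(n, d_out))
Z, Y = rng.normal(size=(N, d)), rng.normal(size=(N, d_out))

gU, gV = wasserstein_grad(U, V, Z, Y)
fU, fV = finite_difference_grad(U, V, Z, Y)
print(np.max(np.abs(gU - fU)), np.max(np.abs(gV - fV)))
```

Both discrepancies are at finite-difference accuracy, which confirms on this example that the particle flow $\dot x_i=-n\nabla_{x_i}F(X)$ moves each neuron along $-\Wgrad f(\alpha_X)(x_i)$.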

Figure. Mean-field training of a homogeneous two-layer model as transport in neuron space. The left panel shows the Wasserstein particle gradient flow in the reduced homogeneous coordinates $(|u|v_1,|u|v_2)$, with black dashed rays marking the teacher directions. The right panel shows the weighted angular density along a front-loaded sequence of times, colored from red to blue, so that the early concentration of neuron directions is visible. The display follows the rendering of the auxiliary MLP experiment but keeps only the $\Wass_2$ flow, not the spectral-flow comparison.

The interactive demo gives a lightweight version of the same phenomenon: particles move in reduced neuron coordinates, while their angles concentrate around the teacher directions.


Classical Convexity and Stationarity

Before using the specific homogeneity mechanism of Chizat and Bach, it is useful to isolate a simpler convex-analytic principle behind many mean-field arguments. Consider an energy

$$F(\alpha) = \frac12\iint k(x,x')\d\alpha(x)\d\alpha(x') + \int V(x)\d\alpha(x) +C$$

on probability measures over a parameter domain. Assume that the quadratic part is convex in the classical affine structure of measures:

$$Q((1-s)\alpha+s\beta) \leq (1-s)Q(\alpha)+sQ(\beta), \qquad Q(\alpha)=\frac12\iint k\d\alpha\d\alpha.$$

This is ordinary convexity of the functional on the convex set of measures, not displacement convexity along $\Wass_2$ geodesics.

Proof Sketch

The dissipation identity for the gradient flow gives stationarity of the limit: formally, after passing to the limit,

$$\int\norm{\nabla\delta F(\alpha_\infty)}^2\d\alpha_\infty=0.$$

Without support and positivity assumptions, this identity only controls the first variation on the region explored by the limit. The density hypothesis allows one to test against sufficiently many signed density perturbations of total mass zero. By approximation and the assumed regularity, this yields the displayed first-order variational inequality for arbitrary competitors $\beta$. Classical convexity of $F$ in the affine variable $\alpha$ then gives the usual subgradient inequality

$$F(\beta) \geq F(\alpha_\infty) + \int\delta F(\alpha_\infty)\d(\beta-\alpha_\infty) \geq F(\alpha_\infty).$$

Thus no competitor has smaller energy. For square-loss two-layer mean-field models, (73) is exactly of this quadratic-plus-linear form, and positive semidefiniteness of the induced kernel $k$ is the classical convexity assumption.

The mean-field description of two-layer training was developed in several works, including Chizat & Bach, 2018; Mei et al., 2018. The distinctive contribution of Chizat and Bach is a global-convergence analysis for positively homogeneous networks without adding an explicit regularizer or relying on noisy SGD to create a Laplacian term. The following formal statement isolates the core mechanism and ignores the technical issues due to ReLU non-smoothness, support propagation and compactness.

Proof

Write

$$h_\alpha(x) = \delta f(\alpha)(x) = \left\langle\nabla J(G_\alpha),\psi(x,\cdot)\right\rangle_\rho.$$

By two-homogeneity of $\psi$, $h_\alpha(\lambda x)=\lambda^2h_\alpha(x)$ for every $\lambda>0$. Normalize a nonzero direction $\omega$ and choose $r_\omega>0$ with $r_\omega\omega\in\operatorname{supp}(\alpha)$. Stationarity gives a zero radial derivative at this point:

$$0 = \frac{\d}{\d r}h_\alpha(r\omega)\bigg|_{r=r_\omega} = 2r_\omega h_\alpha(\omega).$$

Hence $h_\alpha(\omega)=0$ for every direction $\omega$, and by homogeneity $h_\alpha(x)=0$ for every $x$.

For any competitor $\beta$, convexity of $J$ gives

$$f(\beta)-f(\alpha) \geq \int h_\alpha(x)\d(\beta-\alpha)(x)=0.$$

Thus no competitor has smaller risk. The rigorous theorem replaces the full directional support assumption by propagation and overparameterization hypotheses ensuring that a negative descent direction would be present in the support and would contradict stationarity.

References
  1. Otto, F. (2001). The geometry of dissipative evolution equations: the porous medium equation. Communications in Partial Differential Equations, 26(1–2), 101–174.
  2. Ambrosio, L., Gigli, N., & Savaré, G. (2006). Gradient Flows in Metric Spaces and in the Space of Probability Measures. Springer.
  3. Benamou, J.-D., Carlier, G., Mérigot, Q., & Oudet, E. (2016). Discretization of functionals involving the Monge–Ampère operator. Numerische Mathematik, 134(3), 611–636.
  4. Peyré, G. (2015). Entropic approximation of Wasserstein gradient flows. SIAM Journal on Imaging Sciences, 8(4), 2323–2351.
  5. Gallouët, T. O., & Monsaingeon, L. (2017). A JKO splitting scheme for Kantorovich–Fisher–Rao gradient flows. SIAM Journal on Mathematical Analysis, 49(2), 1100–1130.
  6. Carrillo, J. A., Chertock, A., & Huang, Y. (2015). A finite-volume method for nonlinear nonlocal equations with a gradient flow structure. Communications in Computational Physics, 17(1), 233–258.
  7. Gianazza, U., Savaré, G., & Toscani, G. (2009). The Wasserstein gradient flow of the Fisher information and the quantum drift-diffusion equation. Archive for Rational Mechanics and Analysis, 194(1), 133–220.
  8. Maas, J. (2011). Gradient flows of the entropy for finite Markov chains. Journal of Functional Analysis, 261(8), 2250–2292.
  9. Erbar, M. (2010). The heat equation on manifolds as a gradient flow in the Wasserstein space. Annales de l’Institut Henri Poincaré, Probabilités et Statistiques, 46(1), 1–23.
  10. McCann, R. J. (1997). A convexity principle for interacting gases. Advances in Mathematics, 128(1), 153–179.
  11. Cordero-Erausquin, D., McCann, R. J., & Schmuckenschläger, M. (2001). A Riemannian interpolation inequality à la Borell, Brascamp and Lieb. Inventiones Mathematicae, 146(2), 219–257.
  12. von Renesse, M.-K., & Sturm, K.-T. (2005). Transport inequalities, gradient estimates, entropy and Ricci curvature. Communications on Pure and Applied Mathematics, 58(7), 923–940.
  13. Villani, C. (2009). Optimal Transport: Old and New (Vol. 338). Springer.
  14. Lott, J., & Villani, C. (2009). Ricci curvature for metric-measure spaces via optimal transport. Annals of Mathematics, 169(3), 903–991.
  15. Sturm, K.-T. (2006). On the geometry of metric measure spaces. I. Acta Mathematica, 196(1), 65–131.