Andreas Bærentzen
This notebook represents my efforts at understanding smoothed Wasserstein distances. It is a topic that has obvious relevance to a number of areas within my research interests, and the basic idea is simple and obvious. However, the derivation of the equations is somewhat non-trivial, hence this note.
```python
# Import the packages we will need and set up for rendering
from math import *
import random
import numpy as np
from PIL import Image
from scipy.optimize import linprog
from scipy.ndimage import gaussian_filter
from numpy.linalg import norm
from bokeh.plotting import figure, output_file, show
from bokeh.layouts import row, column
from bokeh.io import output_notebook
output_notebook()

def rand_color():
    # Random RGB colour, squared to favour saturated colours, scaled so the largest channel is 1
    v = np.array((random.random(), random.random(), random.random()))
    v = v**2
    v /= np.max(v)
    return v

def color_image(img, col):
    # Turn a density image into a packed RGBA image (uint32) for Bokeh's image_rgba;
    # the opacity saturates once the density exceeds 0.002
    img_out = np.empty(img.shape, dtype=np.uint32)
    view = img_out.view(dtype=np.uint8).reshape(img.shape + (4,))
    alpha = np.minimum(1.0, img / 0.002)
    for c in range(3):
        view[..., c] = 255.0 * col[c] * alpha
    view[..., 3] = 255 * alpha
    return img_out

def color_image_combine(a, b):
    # In place: wherever b is more opaque than a, copy b's pixel into a
    viewa = a.view(dtype=np.uint8).reshape(a.shape + (4,))
    viewb = b.view(dtype=np.uint8).reshape(b.shape + (4,))
    mask = viewa[..., 3] < viewb[..., 3]
    viewa[mask] = viewb[mask]

def boxes(shape):
    # Four indicator images, one for each quadrant of an image of the given shape
    imgs = []
    w2 = shape[0] // 2
    h2 = shape[1] // 2
    for r0, r1 in [(0, 0), (w2, 0), (0, h2), (w2, h2)]:
        img = np.zeros(shape)
        img[r0:r0+w2, r1:r1+h2] = 1
        imgs.append(img)
    return imgs

def load_image(fn):
    # Load an image as greyscale floats (scipy.misc.imread has been removed from SciPy)
    # and flip it vertically so that row 0 is at the bottom
    img = np.asarray(Image.open(fn).convert('L'), dtype=float)
    return np.flip(img, 0)

sigma = 1
def filter(image):
    # Gaussian blur with zero padding; stands in for multiplication by the heat kernel
    # matrix H (up to a constant factor)
    return gaussian_filter(image, sigma, mode='constant', truncate=32)

def make_distribution(image_in):
    # Blur an image and normalise it so that it sums to one
    dist = filter(image_in)
    dist /= np.sum(dist)
    return dist

# Element-wise helpers (NumPy already provides vectorised versions)
powv = np.power
maxv = np.maximum
logv = lambda x: np.log(np.maximum(x, 1e-300))
```
In recent years, several authors (e.g. [Cuturi 2013, Solomon et al. 2015]) have proposed efficient methods for computing smoothed Wasserstein or "Earth Mover's" distances in a number of settings. These appear to have many applications in image processing, machine learning and shape analysis, to name just a few areas. The recent paper by Solomon et al. [2015] is an excellent place to start in order to get an overview of the possibilities. Unfortunately, while computing the smoothed Wasserstein distance is both simple and efficient, it is not trivial to understand precisely how or why the algorithm works.
The goal of this brief note is to explain, step by step, how and why it works, starting from the ordinary (unsmoothed) Wasserstein distance between discrete distributions.
In practice, we might want to compute the Wasserstein distance on, say, a tensor product of two manifolds, but, first, we must understand the basics.
Transportation theory operates on probability distributions. In this note, we consider only discrete distributions. In practical terms, a distribution is a set of probabilities that cover all outcomes of an experiment.
We can represent a discrete distribution, say ${p}$, as a vector s.t. $0\le p_i \le 1$ and $\sum_i p_i = 1$.
For instance, we might have $p = \frac{1}{6} [1\;1\;1\;1\;1\;1]^T$ if the experiment is a cast of a fair die. For many applications of transport theory it might not matter that the distribution is a probability distribution, and, of course, we can simply normalize to ensure that it sums to 1. Nonetheless, it is easier to understand the notions if we appreciate that we are talking about probability distributions.
Given two distributions, we often want to know the distance between them. One measure is the squared Euclidean distance:
$$ D_2(p,q) = \sum_i |p_i - q_i|^2 $$
Now, consider the distributions $p$, $q$, and $r$:
```python
# Draw the p, q, and r distributions as vertical bars
# (width/height replace plot_width/plot_height, which were removed in Bokeh 3)
pic1 = figure(title="p", width=250, height=250, tools="")
pic2 = figure(title="q", width=250, height=250, tools="")
pic3 = figure(title="r", width=250, height=250, tools="")
p = [0.1,0.5,0.1,0.1,0.1,0.1]
q = [0.1,0.1,0.1,0.5,0.1,0.1]
r = [0.1,0.1,0.1,0.1,0.1,0.5]
outcomes = np.arange(1, 7)    # the six outcomes 1, ..., 6
pic1.vbar(outcomes, width=0.9, top=p)
pic2.vbar(outcomes, width=0.9, top=q)
pic3.vbar(outcomes, width=0.9, top=r)
show(row(pic1, pic2, pic3))
```
It is easy to see from the three charts that $D_2(p,q) = D_2(p,r)$. Unfortunately, that is probably not what we want. In distribution $p$, the hump with probability $0.5$ is at outcome 2; in $q$ it is at outcome 4, and in $r$ it is at outcome 6. The $D_2$ distance is completely blind to how far the hump has moved, so we are looking for a distance that tells us how much of the "mass" of the distribution is moved around, and how far.
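To make the comparison concrete, here is a small sketch (the helper names `is_distribution` and `D2` are ours, not from the note) that checks that the vectors are valid distributions and evaluates the squared Euclidean distance between them.

```python
import numpy as np

def is_distribution(p, tol=1e-12):
    # A discrete distribution: non-negative entries that sum to one
    p = np.asarray(p, dtype=float)
    return bool(np.all(p >= 0) and abs(p.sum() - 1.0) < tol)

def D2(p, q):
    # Squared Euclidean distance between two distributions
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    return float(np.sum((p - q) ** 2))

fair_die = np.full(6, 1 / 6)
p = [0.1, 0.5, 0.1, 0.1, 0.1, 0.1]
q = [0.1, 0.1, 0.1, 0.5, 0.1, 0.1]
r = [0.1, 0.1, 0.1, 0.1, 0.1, 0.5]
print(is_distribution(fair_die), D2(p, q), D2(p, r))
```

Both distances come out the same, no matter how far the hump moves.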
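This is where the Wasserstein distance comes into the picture. Let us define the 2-Wasserstein distance as follows (strictly speaking, this is the square of the 2-Wasserstein distance, which is usually defined as the square root of the expression below, but we will not take the square root in this note):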
$$ W_2(p,q) = \inf_{\pi \in \Pi(p,q)} \sum_{i,\,j} \pi_{i,\,j} d^2(i,j) $$
where $\Pi(p,q)$ is the set of couplings, i.e. the set
$$ \Pi(p,q) = \{ \pi \ge 0 \, |\, \pi \,\mathbf 1 = p \wedge \pi^T \mathbf 1 = q \} $$
We can think of a coupling between $p$ and $q$ as a non-negative matrix whose row sums give $p$ and whose column sums give $q$: entry $\pi_{i,\,j}$ is the amount of mass moved from outcome $i$ to outcome $j$. Clearly, there are many different couplings for any pair of distributions, and to obtain the 2-Wasserstein distance, we need to find the coupling that yields the smallest possible cost. We are also assuming here that there is a distance $d(i,j)$ between outcomes. For the purpose of this note, we will simply assume that the probability distributions are associated with a grid of points in Euclidean space.
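As an illustration of the coupling constraints, the "independent" coupling $\pi = p\,q^T$ always lies in $\Pi(p,q)$, since its row sums are $p_i \sum_j q_j = p_i$ and its column sums are likewise $q_j$. The sketch below assumes, as a choice of ours, that the six outcomes sit at the points $0, 1, \dots, 5$ on a line; it verifies the marginals and evaluates the transport cost of this coupling. The coupling is valid but generally not optimal, so its cost is only an upper bound on $W_2(p,q)$.

```python
import numpy as np

p = np.array([0.1, 0.5, 0.1, 0.1, 0.1, 0.1])
q = np.array([0.1, 0.1, 0.1, 0.5, 0.1, 0.1])

pi = np.outer(p, q)                      # independent coupling pi_ij = p_i q_j
row_sums = pi.sum(axis=1)                # pi 1, should equal p
col_sums = pi.sum(axis=0)                # pi^T 1, should equal q

x = np.arange(6, dtype=float)            # assumed positions of the six outcomes
d2 = (x[:, None] - x[None, :]) ** 2      # d^2(i, j)
cost = float(np.sum(pi * d2))            # transport cost of this (non-optimal) coupling
print(row_sums, col_sums, cost)
```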
We observe that the constraints on the coupling are linear (equality constraints plus non-negativity) and that the objective is a linear function of the entries of $\pi$. Thus, we have a linear programming problem which we can solve, e.g., using the simplex method. With this in place, we can now compute the 2-Wasserstein distance between pairs of distributions like $p$, $q$, and $r$.
The code below computes precisely $D_2$ and $W_2$ for these three distributions.
```python
# Computing the 2-Wasserstein distances from p to p, q, and r
n = 6
d2 = np.zeros((n, n))
A_eq = np.zeros((2*n, n*n))
for j in range(n):
    for i in range(n):
        d2[j, i] = (i - j)**2        # squared distance between outcomes j and i
        A_eq[j, j*n + i] = 1         # row j of the coupling must sum to p[j]
        A_eq[n + j, i*n + j] = 1     # column j must sum to entry j of the other distribution
c = np.reshape(d2, n*n)              # costs, in the same row-major order as the flattened coupling
for name, other in (("p", p), ("q", q), ("r", r)):
    # linprog's default bounds (0, None) keep every entry of the coupling non-negative
    output = linprog(c, A_eq=A_eq, b_eq=np.array(p + other))
    D2 = np.sum((np.array(p) - np.array(other))**2)   # squared Euclidean distance
    print("D2(p,%s)= %5.2f" % (name, D2), " W2(p,%s)= %5.2f" % (name, c.dot(output.x)))
```
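In one dimension, the optimal coupling for a convex cost such as $d^2$ is known to be the monotone one, which matches mass from left to right in sorted order (the "north-west corner" rule on sorted outcomes). This gives a cheap cross-check of the linear program for 1D examples like $p$, $q$, and $r$. The function `w2_1d` is a hypothetical helper written for this sketch; it does not generalize to the 2D grids used later.

```python
import numpy as np

def w2_1d(p, q, x=None):
    """Optimal cost sum_ij pi_ij (x_i - x_j)^2 between two 1D distributions on the same
    sorted points x, using the monotone ("north-west corner") coupling, which is optimal
    in 1D for convex costs such as the squared distance."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    x = np.arange(len(p), dtype=float) if x is None else np.asarray(x, dtype=float)
    pi = np.zeros((len(p), len(q)))
    a, b = p.copy(), q.copy()           # mass still to be sent / still to be received
    i = j = 0
    while i < len(a) and j < len(b):
        m = min(a[i], b[j])             # move as much as possible between the leftmost points
        pi[i, j] += m
        a[i] -= m
        b[j] -= m
        if a[i] <= 1e-15:               # source i is (numerically) empty: go to the next one
            i += 1
        else:                           # otherwise target j is full
            j += 1
    cost = float(np.sum(pi * (x[:, None] - x[None, :]) ** 2))
    return cost, pi

p = [0.1, 0.5, 0.1, 0.1, 0.1, 0.1]
q = [0.1, 0.1, 0.1, 0.5, 0.1, 0.1]
r = [0.1, 0.1, 0.1, 0.1, 0.1, 0.5]
for name, other in (("p", p), ("q", q), ("r", r)):
    print("W2(p,%s) = %5.2f" % (name, w2_1d(p, other)[0]))
```

The values should agree with those found by `linprog` above, since both use the squared distance between outcome indices.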
Observe that $D_2(p,q) = D_2(p,r)$ but $W_2(p,q)<W_2(p,r)$, just as we would expect. However, the cost of computing $W_2$ can be quite significant for a large multidimensional distribution. Also (while this brief note will not go into issues of asymptotic complexity or practical efficiency), it appears to be very difficult to compute $W_2$ in a highly efficient, parallelizable way that scales to large problems.
For this reason, we turn to regularization.
In order to make the problem tractable, we need to introduce the notion of entropy. The entropy of a probability distribution can be defined as
$$ E(p) = - \sum_i p_i \ln(p_i) $$ and similarly, for a coupling $\pi$, $E(\pi) = - \sum_{i,\,j} \pi_{i,\,j} \ln(\pi_{i,\,j})$.
For a relatively smooth distribution, the entropy is greater than for a distribution with many outcomes that have 0 probability and a few with high probability. For a very hand-wavy illustration, consider the "priceless Ming dynasty vase". In its unbroken state there is really only a single configuration of the shards, but if it is broken (the high-entropy state) there are many equally probable configurations of the shards that are roughly equivalent.
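A quick numerical sketch of this intuition, with a hypothetical `entropy` helper that uses the usual convention $0 \ln 0 = 0$: the uniform distribution of a fair die has the largest possible entropy over six outcomes, $\ln 6$, the one-hump distribution $p$ from above has less, and a distribution concentrated on a single outcome has zero entropy.

```python
import numpy as np

def entropy(p):
    # E(p) = -sum_i p_i ln(p_i), using the convention 0 ln 0 = 0
    p = np.asarray(p, dtype=float)
    nz = p[p > 0]
    return float(-np.sum(nz * np.log(nz)))

uniform = np.full(6, 1 / 6)                        # fair die: as smooth as it gets
p = np.array([0.1, 0.5, 0.1, 0.1, 0.1, 0.1])       # one hump, as above
spike = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 0.0])   # all mass on a single outcome
print(entropy(uniform), entropy(p), entropy(spike))
```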
The point now is that we can smooth the Wasserstein distance using entropy in order to make the problem easier to solve [Cuturi 2013]. Let us define the smoothed 2-Wasserstein distance:
$$ W_{2,e}(p,q) = \inf_{\pi \in \Pi(p,q)} \sum_{i,\,j} \pi_{i,\,j} d^2(i,j) - \lambda E(\pi) $$
The benefit of the entropy term is that the resulting objective is strongly convex in $\pi$ (the negative entropy $-E$ is strongly convex) [Peyre 2017], so the minimizer is unique and we can more easily find it.
Let us define the value at $y$ of a heat kernel centered at $x$ (or vice versa) as
$$ H(x,y) = \exp\left(-\frac{d^2(x,y)}{\sigma^2}\right) $$
If we now isolate $d^2(x,y)$, we get [Solomon 2015]:
$$ d^2(x,y) = - \sigma^2 \ln H(x,y) $$
If we plug both this and the entropy into the expression for $W_{2,e}$ then we get
$$ W_{2,e}(p,q) = \inf_{\pi \in \Pi(p,q)} \sum_{i,\,j} -\pi_{i,\,j} \sigma^2 \ln(H_{i,\,j}) + \lambda \sum_{i,\,j} \pi_{i,\,j} \ln(\pi_{i,\,j}) $$
It is possible to significantly simplify this expression by setting $\lambda = \sigma^2$, merging the sums and exploiting properties of $\ln$. We arrive at [Solomon et al. 2015a]:
$$ W_{2,e}(p,q) = \inf_{\pi \in \Pi(p,q)} \sigma^2 \sum_{i,\,j} \pi_{i,\,j} \ln \left(\frac{\pi_{i,\,j}}{H_{i,\,j}}\right) = \inf_{\pi \in \Pi(p,q)}\sigma^2 [1 + \mathrm{KL}(\pi \| H)] $$
This is much simpler, and from the last equality we see that the solution is the minimizer of the Kullback-Leibler divergence from $H$ to $\pi$: $H$ is fixed, and we seek the coupling $\pi \in \Pi(p,q)$ for which $\mathrm{KL}(\pi \| H)$ is smallest. The prepositions matter because the KL divergence is not symmetric.
The KL divergence (in its generalized form, which does not require $H$ to sum to one) is defined as follows [Solomon et al. 2015a]:
$$ \mathrm{KL}(\pi\|H) = \sum_{i,\,j} \pi_{i,\,j} \left( \ln\left(\frac{\pi_{i,\,j}}{H_{i,\,j}}\right) -1 \right) = \sum_{i,\,j} \pi_{i,\,j} \left( \ln\left(\frac{\pi_{i,\,j}}{H_{i,\,j}}\right) \right) -1 $$
where the second equality holds since the entries of $\pi$ sum to one. Below, we will be using the form with $-1$ inside the parentheses since, otherwise, the partial derivatives become less nice.
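The simplifications above are pure algebra, so they can be sanity-checked numerically. The sketch below uses a 1D grid, a hypothetical value of $\sigma$, and an arbitrary positive matrix `pi` whose entries sum to one (it need not satisfy the marginal constraints, since the identities do not depend on them). It checks that the transport cost minus $\sigma^2 E(\pi)$ equals $\sigma^2 \sum_{i,j} \pi_{i,j} \ln(\pi_{i,j}/H_{i,j})$, and that this in turn equals $\sigma^2[1 + \mathrm{KL}(\pi\|H)]$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, sigma = 5, 1.5                              # hypothetical grid size and kernel width
x = np.arange(n, dtype=float)
d2 = (x[:, None] - x[None, :]) ** 2            # squared distances between grid points
H = np.exp(-d2 / sigma**2)                     # heat kernel H_ij = exp(-d^2(i, j) / sigma^2)

pi = rng.random((n, n))
pi /= pi.sum()                                 # arbitrary positive matrix whose entries sum to one

E = -np.sum(pi * np.log(pi))                   # entropy of the coupling
lhs = np.sum(pi * d2) - sigma**2 * E           # transport cost - lambda E(pi), with lambda = sigma^2
rhs = sigma**2 * np.sum(pi * np.log(pi / H))   # simplified form
KL = np.sum(pi * (np.log(pi / H) - 1))         # KL divergence with the -1 inside the sum
print(lhs, rhs, sigma**2 * (1 + KL))
```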
But we still have the constraints $\pi \mathbf{1} = p$ and $\pi^T \mathbf{1} = q$. We can incorporate them using Lagrange multipliers, as explained below in an exposition that owes much to Justin Solomon's PhD thesis [Solomon 2015]. Note that if $p$ and $q$ are both $n$-dimensional, then we have $2n$ Lagrange multipliers. The non-negativity constraint needs no multiplier, since the solution we obtain turns out to be positive anyway. Below, we omit the constant $\sigma^2$ since it does not influence the minimizer.
The energy we want to minimize is
$$ \mathcal{E}(\pi) = \sum_{i,j} \pi_{i,\,j} \left( \ln \frac{\pi_{i,\,j}}{H_{i,\,j}} -1 \right) $$
Now, let $f$ and $g$ be vectors of the same dimensions as $p$ and $q$. The constrained energy becomes:
$$ \mathcal{E}_c(\pi) = \sum_{i,j} \pi_{i,\,j} \left( \ln \frac{\pi_{i,\,j}}{H_{i,\,j}}-1\right) - f^T (\pi \mathbf{1}-p) - g^T (\pi^T \mathbf 1 - q) $$
We compute the partial derivatives
$$ \frac{\partial \mathcal{E}_c(\pi)}{\partial \pi_{i,\,j}} = \ln \pi_{i,\,j}- \ln H_{i,\,j} -f_i - g_j $$
Solving $\frac{\partial \mathcal{E}_c(\pi)}{\partial \pi_{i,\,j}} = 0$ leads to
$$ \ln \pi_{i,\,j} = \ln H_{i,\,j} + f_i + g_j $$
Taking the exponential gives
$$ \pi_{i,\,j} = e^{f_i} H_{i,\,j} e^{g_j} $$
and substituting $v_i = e^{f_i}$ as well as $w_j = e^{g_j}$ yields
$$ \pi_{i,\,j} = v_i H_{i,\,j} w_j $$
or, finally,
$$ \pi = \mathrm{diag}(v) H \mathrm{diag}(w) $$
where $\mathrm{diag}(v)$ is the diagonal matrix that has $v$ as its diagonal.
So, $\pi = \mathrm{diag}(v) H \mathrm{diag}(w)$, and plugging this into the marginal constraints we get
$$\mathrm{diag}(v) H \mathrm{diag}(w) \mathbf{1} = p $$ and $$\mathrm{diag}(w) H^T \mathrm{diag}(v) \mathbf{1} = q $$
It turns out that Sinkhorn [67] has shown that the solution can be found using a very simple iterative algorithm, which alternately rescales the rows and the columns:
$$ w^{(0)} \leftarrow \mathbf 1 $$
$$ v^{(l+1)} \leftarrow \frac{p}{H w^{(l)}} \;\; , \;\; w^{(l+1)} \leftarrow \frac{q}{H^T v^{(l+1)}} $$
where the division is per element.
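For a small 1D grid, the iteration can be written directly with an explicit heat kernel matrix. The sketch below is a dense version under our own assumptions (grid size, $\sigma = 3$, and two bump-shaped distributions chosen for illustration). After it runs, the column sums of $\pi = \mathrm{diag}(v) H \mathrm{diag}(w)$ equal $q$, since the last update enforces them, while the row sums approach $p$ as the iteration converges.

```python
import numpy as np

def sinkhorn_dense(p, q, H, iterations=1000):
    # Alternating scaling: the v-update matches the row sums to p,
    # the w-update matches the column sums to q
    w = np.ones_like(q)
    for _ in range(iterations):
        v = p / (H @ w)
        w = q / (H.T @ v)
    return v, w

n, sigma = 20, 3.0
x = np.arange(n, dtype=float)
H = np.exp(-(x[:, None] - x[None, :]) ** 2 / sigma**2)   # H_ij = exp(-d^2(i, j) / sigma^2)

# Two hypothetical bump-shaped distributions on the grid
p = np.exp(-(x - 6) ** 2 / 4)
p /= p.sum()
q = np.exp(-(x - 12) ** 2 / 4)
q /= q.sum()

v, w = sinkhorn_dense(p, q, H)
pi = v[:, None] * H * w[None, :]                         # pi = diag(v) H diag(w)
print(np.abs(pi.sum(axis=1) - p).max(), np.abs(pi.sum(axis=0) - q).max())
```

An explicit `H` is only practical for small grids; for images it would have one entry per pair of pixels.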
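```python
def Sinkhorn(mu, iterations=250):
    # mu = (mu_0, mu_1) = (p, q): two distributions on the same grid
    v = np.ones(mu[0].shape)
    w = np.ones(mu[0].shape)
    for _ in range(iterations):
        # filter() stands in for multiplication by H; maxv guards against division by zero
        v = mu[0] / maxv(filter(w), 1e-300)
        w = mu[1] / maxv(filter(v), 1e-300)
    # sigma^2 (p^T ln v + q^T ln w); see the derivation below
    wasserstein_dist = (mu[0]*logv(v) + mu[1]*logv(w)).sum() * sigma**2
    return (wasserstein_dist, v, w)
```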
where $\mu_0 = p$ and $\mu_1 = q$ (`mu` above is $\mu$). Note that the division above is still element-wise. Also, note that since $H$ is the heat kernel matrix, $H \mathrm{diag}(w) \mathbf 1 = H w$ can be computed (up to a constant factor) by simply applying a Gaussian filter to $w$, as noted by Solomon et al. [2015]. This is important since we do not really want to construct the heat kernel matrix, especially since Gaussian filtering is often implemented in a library and is also separable, which means that for nD problems we can filter each dimension separately with a 1D Gaussian filter.
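The claim that filtering replaces multiplication by $H$ can be checked on a 1D grid. SciPy's `gaussian_filter` with standard deviation $s$ uses the kernel $\exp(-k^2/(2s^2))$, truncated at $|k| \le$ `int(truncate * s + 0.5)` and normalized to sum to one, and `mode='constant'` treats values outside the grid as zero. So, assuming the grid is narrower than the truncation radius, the filter is exactly multiplication by the matrix built below: the note's $H$ with $\sigma^2 = 2s^2$, divided by a constant $Z$. Such a constant factor does not change the optimal coupling; it only rescales $v$ and $w$, and through the logarithms it adds a constant to the computed distance.

```python
import numpy as np
from scipy.ndimage import gaussian_filter

n, s, truncate = 25, 1.0, 32
x = np.arange(n, dtype=float)

# Normalisation constant of SciPy's truncated kernel exp(-k^2 / (2 s^2)), |k| <= radius
radius = int(truncate * s + 0.5)
k = np.arange(-radius, radius + 1)
Z = np.exp(-k**2 / (2 * s**2)).sum()

# Explicit matrix: the heat kernel exp(-d^2 / sigma^2) with sigma^2 = 2 s^2, divided by Z
H = np.exp(-(x[:, None] - x[None, :]) ** 2 / (2 * s**2)) / Z

w = np.random.default_rng(1).random(n)
filtered = gaussian_filter(w, s, mode='constant', truncate=truncate)   # what the notebook does
explicit = H @ w                                                       # multiplication by the matrix
print(np.abs(filtered - explicit).max())
```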
Wait, how did we compute the actual Wasserstein distance? This happens in the second-to-last line of the code, but, like many of the steps, it is not trivial. Following Solomon [2015], assuming we have found the coupling $\pi$ that attains the infimum in $W_{2,e}$, we now need to compute
$$ W_{2,e} = \sigma^2 \sum_{i,\,j} \pi_{i,\,j} \ln \left(\frac{\pi_{i,\,j}}{H_{i,\,j}}\right) = \sigma^2 \sum_{i,\,j} \pi_{i,\,j} \ln \left(v_i w_j \right) = \sigma^2 \sum_{i,\,j} \pi_{i,\,j} (\ln(v_i) + \ln(w_j)) = \sigma^2 (p^T \ln(v) + q^T \ln(w)) $$
This is not obvious (at least it was not to me), but the second equality exploits the factorization $\pi = \mathrm{diag}(v) H \mathrm{diag}(w)$ from the Sinkhorn algorithm, which gives $\pi_{i,\,j}/H_{i,\,j} = v_i w_j$. The fourth equality exploits that the rows and columns of $\pi$ sum to the marginal distributions $p$ and $q$. The natural logarithm of a vector should be construed as a per-element operation.
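The last equality only needs $p$ and $q$ to be the actual row and column sums of $\pi$. The sketch below checks the identity (with $\sigma = 1$, so the factor $\sigma^2$ drops out) for arbitrary positive scaling vectors `v` and `w` on a 1D grid; these are hypothetical stand-ins rather than a Sinkhorn solution, so the marginals used are the ones $\pi = \mathrm{diag}(v) H \mathrm{diag}(w)$ actually has. At a converged Sinkhorn solution these are $p$ and $q$.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 6
x = np.arange(n, dtype=float)
H = np.exp(-(x[:, None] - x[None, :]) ** 2)   # heat kernel with sigma = 1
v = rng.random(n) + 0.1                       # arbitrary positive scaling vectors
w = rng.random(n) + 0.1

pi = v[:, None] * H * w[None, :]              # pi = diag(v) H diag(w)
row, col = pi.sum(axis=1), pi.sum(axis=0)     # the marginals this pi actually has

direct = np.sum(pi * np.log(pi / H))          # sum_ij pi_ij ln(pi_ij / H_ij)
factored = row @ np.log(v) + col @ np.log(w)  # p^T ln v + q^T ln w with p, q = these marginals
print(direct, factored)
```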
```python
# Setup code for testing the computation of smooth Wasserstein distances
def move_blob_test():
    dist = [make_distribution(load_image("one-blob.png")),
            make_distribution(load_image("one-blob-moved.png")),
            make_distribution(load_image("one-blob-moved-even-more.png")),
            make_distribution(load_image("one-blob-moved-even-more-again.png"))]
    pics = [figure(title="Original", x_range=(0,1), y_range=(0,1), width=250, height=250)]
    pics[0].image(image=[dist[0]], x=0, y=0, dw=1, dh=1)
    for i in range(1, 4):
        wd, v, w = Sinkhorn((dist[0], dist[i]))
        ed = norm(dist[0] - dist[i])**2      # squared Euclidean distance D2
        fig = figure(title="W2 = " + str(wd) + " D2 = " + str(ed),
                     x_range=(0,1), y_range=(0,1), width=250, height=250)
        fig.image(image=[dist[i]], x=0, y=0, dw=1, dh=1)
        pics.append(fig)
    show(row(pics))
```
Below, we test the implementation of Sinkhorn's algorithm applied to the problem of computing entropically smoothed Wasserstein distances. The distributions are in 2D and simply a white blob that moves. The example is similar to the one shown in the introduction for the non-smooth Wasserstein distance, only in 2D. Again we see that as the distribution on the right moves, there is a change in $W_{2,e}$ but not in $D_2$.
```python
# Compute W2 between several similar distributions
move_blob_test()
```
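Often, we want the optimal coupling matrix $\pi$ and not just the Wasserstein distance. Remember that $\pi \mathbf 1 = p$, i.e. one marginal distribution, and, likewise, $q = \pi^T \mathbf 1$. If we instead multiply $\pi$ by a vector $a$ that is 1 on some region and 0 elsewhere, then entry $i$ of $\pi a$ is the amount of mass at outcome $i$ of $p$ that is transported into that region of $q$; in other words, $\pi a$ tells us where the mass ending up in that part of $q$ comes from. Similarly, $\pi^T a$ tells us where the mass in a region of $p$ goes.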
Again, to avoid explicitly constructing the heat kernel $H$, we apply Gaussian filtering to the input and use the factored formulation from the Sinkhorn algorithm as seen in the code below.
```python
# Code for the evaluation of the coupling matrix and quadrant test
def SinkhornEvalR(v, w, a):
    # pi a = diag(v) H diag(w) a, with H applied by Gaussian filtering
    return v * filter(w * a)

def SinkhornEvalL(v, w, a):
    # pi^T a = diag(w) H^T diag(v) a (H is symmetric)
    return w * filter(v * a)

dist1 = make_distribution(load_image("smile_s.png"))
dist2 = make_distribution(load_image("dots_s.png"))
#output_file("sinkhorn.html")

def quadrant_test():
    # Compute the coupling between two distributions and use it to find out
    # which part of the first distribution is transported to each quadrant of the second
    wd, v, w = Sinkhorn((dist1, dist2))
    print("Wasserstein distance = ", wd)
    img_out = np.zeros(dist1.shape, dtype=np.uint32)
    cmp_out = np.zeros(dist1.shape, dtype=np.uint32)
    for b in boxes(dist1.shape):
        color = rand_color()
        img = SinkhornEvalR(v, w, b)              # mass of dist1 sent into quadrant b of dist2
        col_img = color_image(img, color)
        cmp_img = color_image(b * dist2, color)   # quadrant b of dist2 itself
        color_image_combine(img_out, col_img)
        color_image_combine(cmp_out, cmp_img)
    p1 = figure(title="Original", x_range=(0,1), y_range=(1,0), width=420, height=420)
    p1.image_rgba(image=[cmp_out], x=0, y=1, dw=1, dh=1)
    p2 = figure(title="Wasserstein distance = " + str(wd), x_range=(0,1), y_range=(1,0), width=420, height=420)
    p2.image_rgba(image=[img_out], x=0, y=1, dw=1, dh=1)
    show(row(p1, p2))
```
With this in place, we can color each of the four quadrants of the dots image and find out which parts of the smiley are transported to each of them:
```python
# Show how four blobs map to a smiley using the coupling from the smoothed W2
quadrant_test()
```
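The factored evaluation in `SinkhornEvalR` and `SinkhornEvalL` can be checked against an explicit coupling on a small 1D grid. In the sketch below, `v` and `w` are random positive stand-ins for the Sinkhorn scaling vectors (not an actual solution), the hypothetical `filt` mirrors the notebook's `filter` with `sigma = 1`, and the matrix that the filter applies is recovered by filtering the unit vectors. The indicator `a` marks the left half of the grid.

```python
import numpy as np
from scipy.ndimage import gaussian_filter

def filt(u):
    # Same filter settings as the notebook (sigma = 1, zero padding, wide truncation)
    return gaussian_filter(u, 1, mode='constant', truncate=32)

n = 16
rng = np.random.default_rng(3)
v, w = rng.random(n) + 0.1, rng.random(n) + 0.1   # stand-ins for Sinkhorn's scaling vectors

# The matrix that filt() applies: column j is filt(e_j)
H = np.column_stack([filt(e) for e in np.eye(n)])
pi = np.diag(v) @ H @ np.diag(w)                  # explicit coupling diag(v) H diag(w)

a = np.zeros(n)
a[:n // 2] = 1.0                                  # indicator of the left half of the grid
right = v * filt(w * a)                           # pi a, as in SinkhornEvalR
left = w * filt(v * a)                            # pi^T a, as in SinkhornEvalL
print(np.abs(right - pi @ a).max(), np.abs(left - pi.T @ a).max())
```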
```python
# Show smoothed Wasserstein interpolation
def WassersteinBarycenter(alpha, mu, dims, iterations=100):
    # alpha: weights (summing to one); mu: input distributions on a grid of shape dims
    k = len(mu)
    v = [np.ones(dims) for _ in range(k)]
    w = [np.ones(dims) for _ in range(k)]
    d = [np.ones(dims) for _ in range(k)]
    mu_out = np.ones(dims)
    for _ in range(iterations):
        mu_out = np.ones(dims)
        for i in range(k):
            w[i] = mu[i] / maxv(filter(v[i]), 1e-100)
            d[i] = v[i] * filter(w[i])
            mu_out = mu_out * powv(d[i], alpha[i])   # weighted geometric mean of the d[i]
        for i in range(k):
            v[i] = v[i] * mu_out / maxv(d[i], 1e-100)
    return mu_out

def WassersteinBarycenter_test():
    # Compute a sequence of Wasserstein barycenters interpolating between our two distributions
    dist1 = make_distribution(load_image("smile_s.png"))
    dist2 = make_distribution(load_image("dots_s.png"))
    rows = []
    pics = []
    for alpha in np.linspace(0, 1, 8):
        dist = WassersteinBarycenter((alpha, 1.0 - alpha), (dist1, dist2), dist1.shape)
        fig = figure(x_range=(0,1), y_range=(1,0), width=200, height=200)
        fig.image(image=[dist], x=0, y=1, dw=1, dh=1)
        pics.append(fig)
        if len(pics) == 4:
            rows.append(row(pics))
            pics = []
    show(column(rows))
```
Often we want something a little different, like a more natural interpolation between distributions. This can be done with Wasserstein barycenters. The idea, essentially, is to find something like a convex combination of distributions, but one that minimizes the weighted sum of Wasserstein distances to the distributions being combined.
From Solomon et al. [2015]:
$$ \min_\mu \sum_i \alpha_i W_{2,e}(\mu,\mu_i) $$
where the weights $\alpha_i \ge 0$ sum to one. In words, we are looking for the distribution $\mu$ that has the smallest weighted sum of smoothed Wasserstein distances to several distributions $\mu_i$. This can be computed efficiently as well, as shown in the code above.
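To see what the barycenter iteration does, here is a 1D sketch of the same scheme as `WassersteinBarycenter`, written for any number of inputs and applied to two hypothetical bumps centred at 12 and 28 on a 41-point grid with equal weights. The filter width `s = 3.0` is our choice, not the notebook's. By symmetry, we would expect a single bump centred at 20, whereas the ordinary average `naive` keeps two separate humps.

```python
import numpy as np
from scipy.ndimage import gaussian_filter

def filt(u, s=3.0):
    # Gaussian blur with zero padding, standing in for multiplication by the heat kernel
    return gaussian_filter(u, s, mode='constant', truncate=32)

def barycenter(alpha, mus, iterations=200):
    # Same updates as WassersteinBarycenter, for any number of input distributions
    k = len(mus)
    v = [np.ones_like(mus[0]) for _ in range(k)]
    w = [np.ones_like(mus[0]) for _ in range(k)]
    d = [np.ones_like(mus[0]) for _ in range(k)]
    mu = np.ones_like(mus[0])
    for _ in range(iterations):
        mu = np.ones_like(mus[0])
        for i in range(k):
            w[i] = mus[i] / np.maximum(filt(v[i]), 1e-300)
            d[i] = v[i] * filt(w[i])
            mu = mu * d[i] ** alpha[i]          # weighted geometric mean of the d[i]
        for i in range(k):
            v[i] = v[i] * mu / np.maximum(d[i], 1e-300)
    return mu

x = np.arange(41, dtype=float)
mu0 = np.exp(-(x - 12) ** 2 / 8)
mu0 /= mu0.sum()                                 # hypothetical bump centred at 12
mu1 = np.exp(-(x - 28) ** 2 / 8)
mu1 /= mu1.sum()                                 # hypothetical bump centred at 28
bary = barycenter((0.5, 0.5), (mu0, mu1))
naive = 0.5 * mu0 + 0.5 * mu1                    # ordinary average, for comparison
print(int(np.argmax(bary)), naive[12], naive[20])
```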
Below, we compute a sequence of Wasserstein barycenters between just two distributions. Note that in this particular case, we could have used the coupling matrix to compute a flow from one distribution to the other, which is called displacement interpolation in the literature [Solomon, personal communication].
```python
# Interpolate between two distributions
WassersteinBarycenter_test()
```