This file contains a collection of code examples for various functions offered by the OptimalTransport.jl
package, and can be treated as an informal mini-tutorial for using the package.
Install the package as follows:
# using Pkg
# Pkg.add("OptimalTransport")
and load it into our environment:
using OptimalTransport
using Distances
using LinearAlgebra
using Plots, LaTeXStrings
using StatsPlots
using Random, Distributions
First, let us initialise two random probability measures $\mu$ (source measure) and $\nu$ (target measure) in 1D:
N = 200; M = 200
μ_spt = rand(N)
ν_spt = rand(M)
μ = fill(1/N, N)
ν = fill(1/M, M);
Now we compute the quadratic cost matrix $C_{ij} = \| x_i - y_j \|^2$, where $x_i$ and $y_j$ are the support points of $\mu$ and $\nu$ respectively:
C = pairwise(SqEuclidean(), μ_spt', ν_spt');
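For 1D supports, the same matrix can also be written with plain broadcasting, which makes the definition of $C_{ij}$ explicit. The following is a minimal sketch on a small hand-picked example; the points `x`, `y` and the name `C_manual` are illustrative and not part of the package.

```julia
# Hypothetical 1D support points (not the random ones used in the text).
x = [0.0, 0.5, 1.0]
y = [0.25, 0.75]

# Broadcasting a column vector against a row vector gives every pairwise
# difference x[i] - y[j]; squaring elementwise gives the quadratic cost.
C_manual = (x .- y') .^ 2

size(C_manual)
```

Each row of `C_manual` corresponds to a source point and each column to a target point, matching the layout expected by the solvers used below.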
The earth mover's distance is defined as the optimal value of the Monge-Kantorovich problem:
\[ \min_{\gamma \in \Pi(\mu, \nu)} \langle \gamma, C \rangle = \min_{\gamma \in \Pi(\mu, \nu)} \sum_{i, j} \gamma_{ij} C_{ij} \]
where $\Pi(\mu, \nu)$ denotes the set of joint distributions whose marginals agree with $\mu$ and $\nu$. In the case where $C$ is the quadratic cost, the optimal value is the squared 2-Wasserstein distance between $\mu$ and $\nu$.
N.B. At the moment, this functionality is available through PyCall and the Python OT library.
Using emd() returns the optimal transport plan $\gamma$:
γ = OptimalTransport.emd(μ, ν, C);
whilst using emd2() returns the optimal transport cost at the minimum:
OptimalTransport.emd2(μ, ν, C)
0.00035985765174184283
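In the special case used here (1D supports, the same number of points on each side, uniform weights, quadratic cost), an optimal plan simply matches the sorted source points to the sorted target points in order. This gives an independent way to sanity-check a transport cost without an LP solver. The sketch below assumes exactly that special case; `w2sq_1d_uniform` is a hypothetical helper, not a function of OptimalTransport.jl.

```julia
using Random

# Optimal quadratic transport cost between two uniform empirical measures on
# the real line with the same number of points (assumption of this sketch):
# sort both supports and match them in order.
function w2sq_1d_uniform(x::AbstractVector, y::AbstractVector)
    length(x) == length(y) || throw(ArgumentError("supports must have equal length"))
    xs, ys = sort(x), sort(y)
    return sum(abs2, xs .- ys) / length(x)
end

Random.seed!(1)
x, y = rand(200), rand(200)
w2sq_1d_uniform(x, y)
```

The sorted matching only applies to this 1D uniform setting; for general weights or higher dimensions the linear program has to be solved.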
We may add an entropy term to the Monge-Kantorovich problem to obtain the entropically regularised optimal transport problem:
\[ \min_{\gamma \in \Pi(\mu, \nu)} \langle \gamma, C \rangle - \epsilon H(\gamma) \]
where $H(\gamma) = -\sum_{i, j} \gamma_{ij} \log(\gamma_{ij})$ is the entropy of the coupling $\gamma$ and $\epsilon$ controls the strength of regularisation.
This problem is strictly convex and admits a very efficient iterative scaling scheme for its solution known as the Sinkhorn algorithm [Peyre 2019].
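To make the scaling scheme concrete: the entropic plan has the form $\gamma = \mathrm{diag}(u) K \mathrm{diag}(v)$ with $K = \exp(-C/\epsilon)$, and the Sinkhorn iteration alternately rescales $u$ and $v$ so that the row and column sums match $\mu$ and $\nu$. The following is a minimal sketch of that idea and not the package's implementation; the name `sinkhorn_sketch` and the keyword `iters` are illustrative, and a fixed iteration count stands in for a proper convergence check.

```julia
using LinearAlgebra

# Minimal Sinkhorn iteration (a sketch, not the OptimalTransport.jl code).
function sinkhorn_sketch(μ, ν, C, ϵ; iters = 1000)
    K = exp.(-C ./ ϵ)            # Gibbs kernel
    u = ones(length(μ))
    v = ones(length(ν))
    for _ in 1:iters
        u = μ ./ (K * v)         # rescale rows to match μ
        v = ν ./ (K' * u)        # rescale columns to match ν
    end
    return Diagonal(u) * K * Diagonal(v)
end

# Small hypothetical example: uniform measures on a 5-point grid.
x = collect(range(0, 1; length = 5))
C = (x .- x') .^ 2
μ = fill(1 / 5, 5)
ν = fill(1 / 5, 5)
γ = sinkhorn_sketch(μ, ν, C, 0.1)
sum(γ, dims = 2)                 # row sums, should be close to μ
```

For very small $\epsilon$ the kernel `K` underflows, which is the motivation for the stabilised variants discussed later.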
We compute the transport plan using both the native Julia implementation and POT:
ϵ = 0.01
γ_ = OptimalTransport.pot_sinkhorn(μ, ν, C, ϵ);
γ = OptimalTransport.sinkhorn(μ, ν, C, ϵ)
200×200 Array{Float64,2}:
 2.91507e-8   1.08811e-5   0.000137813  …  5.17463e-6   5.99191e-11
 4.71167e-7   1.38525e-6   0.00011384      2.33743e-5   7.00902e-13
 7.28767e-9   2.351e-5     0.000129499     2.2821e-6    3.74765e-10
 2.97569e-19  2.55367e-5   1.32895e-8      5.16509e-14  5.59045e-5
 3.77886e-9   3.2011e-5    0.000121464     1.52566e-6   8.2118e-10
 6.53149e-12  0.000138215  2.55868e-5   …  2.06694e-8   1.67555e-7
 1.12751e-17  5.17631e-5   8.34279e-8      9.05558e-13  2.67655e-5
 1.27122e-29  1.00441e-8   1.0707e-14      1.50177e-22  5.94785e-5
 2.41221e-34  7.28399e-11  7.8201e-18      1.36278e-26  8.08082e-6
 1.5574e-30   4.04126e-9   2.72209e-15     2.51594e-23  4.30015e-5
 ⋮                                      ⋱
 4.98756e-6   1.19019e-7   6.27328e-5      6.96976e-5   5.61575e-15
 2.2609e-5    1.18287e-8   2.72473e-5      0.0001156    8.49335e-17
 3.4688e-6    1.8495e-7    7.14726e-5      5.98965e-5   1.29693e-14
 2.83593e-21  8.39904e-6   1.11385e-9      1.25015e-15  0.000105311
 1.2626e-21   6.76949e-6   7.13907e-10  …  6.50686e-16  0.000113719
 0.000112871  1.86815e-15  3.38222e-9      6.12877e-6   2.7e-27
 5.51645e-15  0.000118528  1.51543e-6      1.07966e-10  4.35706e-6
 2.29879e-5   1.1457e-8    2.68903e-5      0.000116045  8.03195e-17
 6.35584e-5   8.87114e-10  8.27883e-6      0.000125206  1.06775e-18
Now we can check that both implementations arrive at the same result:
norm(γ - γ_, Inf)
7.050741370959179e-11
Again, we can compute the optimal value (a scalar) of the entropic OT problem using sinkhorn2():
OptimalTransport.sinkhorn2(μ, ν, C, ϵ)
0.00483037072961987
Next, we compute the transport plan for the same problem, this time using a quadratic regularisation [Lorenz 2019] instead of the more common entropic regulariser. We solve the problem
\[ \min_{\gamma \in \Pi(\mu, \nu)} \langle \gamma, C \rangle + \epsilon \frac{\| \gamma \|_F^2}{2} \]
One property of quadratically regularised OT is that the resulting transport plan $\gamma$ is sparse. We take advantage of this and represent it as a sparse matrix.
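One way to see where the sparsity comes from is a short optimality-condition argument. The dual variables $f$ and $g$ below (one per marginal constraint) are introduced only for this sketch. Writing the stationarity condition of the quadratically regularised problem together with the constraint $\gamma \ge 0$ gives

\[ \gamma_{ij} = \frac{1}{\epsilon} \max(0, f_i + g_j - C_{ij}) \]

so $\gamma_{ij}$ is exactly zero wherever $f_i + g_j \le C_{ij}$. By contrast, the entropic plan has the form $\gamma_{ij} = \exp((f_i + g_j - C_{ij})/\epsilon)$, which is strictly positive everywhere, however small.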
γ = OptimalTransport.quadreg(μ, ν, C, ϵ); γ
200×200 SparseArrays.SparseMatrixCSC{Float64,Int64} with 673 stored entries:
  [3  ,   1]  =  0.00180664
  [13 ,   1]  =  0.00154934
  [134,   1]  =  0.00164376
  [191,   1]  =  2.65912e-7
  [21 ,   2]  =  0.00146994
  [142,   2]  =  0.00163552
  [156,   2]  =  0.000852266
  [162,   2]  =  0.00104228
  [67 ,   3]  =  0.000903313
  ⋮
  [154, 198]  =  0.0022591
  [176, 198]  =  0.00225247
  [19 , 199]  =  0.00266787
  [34 , 199]  =  0.000156919
  [41 , 199]  =  0.00206001
  [122, 199]  =  0.000115211
  [12 , 200]  =  0.000139041
  [53 , 200]  =  0.00107328
  [88 , 200]  =  0.00199785
  [127, 200]  =  0.00178983
The stabilised Sinkhorn algorithm is a log-stabilised version of the Sinkhorn algorithm, which is useful when $\epsilon$ is very small [Schmitzer 2019]:
ϵ = 0.005
γ = sinkhorn_stabilized(μ, ν, C, ϵ, max_iter = 5000);
γ_ = pot_sinkhorn(μ, ν, C, ϵ, method = "sinkhorn_stabilized", max_iter = 5000);
norm(γ - γ_, Inf) # Check that we get the same result as POT
4.4244019442606954e-10
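To illustrate why working in the log domain helps, the sketch below iterates on dual potentials $f, g$ (with $u = \exp(f/\epsilon)$, $v = \exp(g/\epsilon)$) and uses a log-sum-exp reduction instead of forming $\exp(-C/\epsilon)$. This is a plain log-domain Sinkhorn, which is simpler than (and not identical to) the absorption-based scheme of [Schmitzer 2019] or the package's `sinkhorn_stabilized`; the names `logsumexp` and `sinkhorn_log_sketch` are illustrative.

```julia
# Numerically stable log(sum(exp.(z))).
function logsumexp(z)
    m = maximum(z)
    return m + log(sum(exp.(z .- m)))
end

# Log-domain Sinkhorn sketch: update potentials f, g rather than scalings u, v.
function sinkhorn_log_sketch(μ, ν, C, ϵ; iters = 2000)
    f = zeros(length(μ))
    g = zeros(length(ν))
    logμ, logν = log.(μ), log.(ν)
    for _ in 1:iters
        for i in eachindex(f)
            f[i] = ϵ * (logμ[i] - logsumexp((g .- C[i, :]) ./ ϵ))
        end
        for j in eachindex(g)
            g[j] = ϵ * (logν[j] - logsumexp((f .- C[:, j]) ./ ϵ))
        end
    end
    # γ_ij = exp((f_i + g_j - C_ij) / ϵ)
    return exp.((f .+ g' .- C) ./ ϵ)
end

# Hypothetical example with well-separated supports, chosen so that every
# entry of exp.(-C ./ ϵ) underflows to zero in Float64.
x = [0.0, 1.0, 2.0]
y = [30.0, 31.0, 32.0]
C = (x .- y') .^ 2
μ = fill(1 / 3, 3)
ν = fill(1 / 3, 3)
γ = sinkhorn_log_sketch(μ, ν, C, 1.0)
```

In this example the plain scaling iteration would divide by zero in its first step, whereas the log-domain version returns a finite plan with the correct marginals.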
In addition to log-stabilisation, we can use $\epsilon$-scaling [Schmitzer 2019]:
γ = sinkhorn_stabilized_epsscaling(μ, ν, C, ϵ, max_iter = 5000)
γ_ = pot_sinkhorn(μ, ν, C, ϵ, method = "sinkhorn_epsilon_scaling", max_iter = 5000)
norm(γ - γ_, Inf) # Check that we get the same result as POT
5.60221516149751e-10
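The idea behind $\epsilon$-scaling is to solve a sequence of problems with decreasing regularisation, reusing the dual potentials from one level as the starting point for the next. The sketch below shows this warm-starting pattern on top of a plain log-domain Sinkhorn sweep. It is only an illustration: the schedule parameters `ϵ0`, `factor`, `inner` and `final_iters` are assumptions of this sketch, and it does not reproduce the package's `sinkhorn_stabilized_epsscaling`.

```julia
# Numerically stable log(sum(exp.(z))).
function logsumexp(z)
    m = maximum(z)
    return m + log(sum(exp.(z .- m)))
end

# One log-domain Sinkhorn sweep, updating the potentials f and g in place.
function sweep!(f, g, logμ, logν, C, ϵ)
    for i in eachindex(f)
        f[i] = ϵ * (logμ[i] - logsumexp((g .- C[i, :]) ./ ϵ))
    end
    for j in eachindex(g)
        g[j] = ϵ * (logν[j] - logsumexp((f .- C[:, j]) ./ ϵ))
    end
    return f, g
end

# ϵ-scaling sketch: start from a large ϵ0, shrink geometrically towards the
# target ϵ, and keep (f, g) between levels as a warm start.
function sinkhorn_epsscaling_sketch(μ, ν, C, ϵ; ϵ0 = 1.0, factor = 0.5,
                                    inner = 100, final_iters = 2000)
    f, g = zeros(length(μ)), zeros(length(ν))
    logμ, logν = log.(μ), log.(ν)
    ϵk = max(ϵ0, ϵ)
    while ϵk > ϵ
        for _ in 1:inner
            sweep!(f, g, logμ, logν, C, ϵk)
        end
        ϵk = max(ϵ, factor * ϵk)
    end
    for _ in 1:final_iters
        sweep!(f, g, logμ, logν, C, ϵ)
    end
    return exp.((f .+ g' .- C) ./ ϵ)
end

# Small hypothetical example.
x = collect(range(0, 1; length = 6))
C = (x .- x') .^ 2
μ = fill(1 / 6, 6)
ν = fill(1 / 6, 6)
γ = sinkhorn_epsscaling_sketch(μ, ν, C, 0.05)
```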
Unbalanced transport was introduced by [Chizat 2019] to interpolate between general positive measures which do not have the same total mass. That is, for $\mu, \nu \in \mathcal{M}_+$ and a cost function (e.g. quadratic) $C$, we solve the following problem:
\[ \min_{\gamma \ge 0} \epsilon \mathrm{KL}(\gamma | \exp(-C/\epsilon)) + \lambda_1 \mathrm{KL}(\gamma_1 | \mu) + \lambda_2 \mathrm{KL}(\gamma_2 | \nu) \]
where $\gamma_1$ and $\gamma_2$ denote the first and second marginals of $\gamma$, $\epsilon$ controls the strength of entropic regularisation, and $\lambda_1, \lambda_2$ control how strongly we enforce the marginal constraints.
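For this KL-penalised problem, the Sinkhorn scaling updates keep their form but are damped by the exponents $\lambda_1/(\lambda_1 + \epsilon)$ and $\lambda_2/(\lambda_2 + \epsilon)$; as $\lambda_1, \lambda_2 \to \infty$ the exponents tend to 1 and the balanced iteration is recovered. The following is a minimal sketch of that iteration with a fixed iteration count. The name `sinkhorn_unbalanced_sketch` is illustrative and the code is not the package's implementation.

```julia
# Unbalanced Sinkhorn scaling sketch (hypothetical helper, not the package API).
function sinkhorn_unbalanced_sketch(μ, ν, C, λ1, λ2, ϵ; iters = 2000)
    K = exp.(-C ./ ϵ)
    u = ones(length(μ))
    v = ones(length(ν))
    for _ in 1:iters
        # Damped marginal updates: exponents < 1 relax the constraints.
        u = (μ ./ (K * v)) .^ (λ1 / (λ1 + ϵ))
        v = (ν ./ (K' * u)) .^ (λ2 / (λ2 + ϵ))
    end
    return u .* K .* v'
end

# Small hypothetical example with different total masses.
x = collect(range(0, 1; length = 4))
y = collect(range(0, 1; length = 6))
C = (x .- y') .^ 2
μ = fill(1 / 4, 4)   # total mass 1
ν = fill(1 / 4, 6)   # total mass 1.5
γ = sinkhorn_unbalanced_sketch(μ, ν, C, 1.0, 1.0, 0.1)
sum(γ)               # total mass of the resulting plan
```

Because the marginal constraints are only penalised, the row and column sums of `γ` need not equal `μ` and `ν`.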
We construct two random measures, now with different total masses:
N = 100; M = 200
μ_spt = rand(N)
ν_spt = rand(M)
μ = fill(1/N, N)
ν = fill(1/N, M);
Set up quadratic cost matrix:
C = pairwise(SqEuclidean(), μ_spt', ν_spt')
100×200 Array{Float64,2}:
 0.00973849   0.00262453  0.0869126    …  0.117825    0.332162     0.067835
 0.0286988    0.0470286   0.316857        0.373745    0.0950142    5.83616e-5
 0.195788     0.156046    0.00239968      2.91246e-7  0.84664      0.365115
 0.142223     0.180266    0.593851        0.670866    0.0101056    0.0463784
 0.118995     0.0885081   0.00235579      0.00940576  0.676683     0.25677
 0.12811      0.164332    0.564629     …  0.639784    0.0143344    0.0384776
 0.143269     0.109598    0.000224522     0.00402348  0.73301      0.291899
 0.120213     0.155371    0.547912        0.621981    0.0171436    0.0342064
 0.0871144    0.0613542   0.00967109      0.0215469   0.597224     0.208775
 0.0691333    0.046431    0.0170462       0.0320438   0.548463     0.18037
 ⋮                                     ⋱
 0.0708998    0.0478807   0.0161857       0.0308599   0.553419     0.183216
 0.0109369    0.0231142   0.248077        0.298684    0.139182     0.00327046
 0.129679     0.166108    0.567918        0.643285    0.0138158    0.0393397
 0.153317     0.192731    0.616305        0.694719    0.00741204   0.0528034
 0.0621895    0.0881092   0.413284     …  0.477921    0.0521084    0.00767559
 0.000243053  0.00101528  0.142811        0.181774    0.243287     0.0314558
 0.00823141   0.0018726   0.0916673       0.12335     0.323054     0.0637537
 0.202143     0.247065    0.710811        0.794849    0.000786697  0.082849
 0.237632     0.286148    0.776105        0.863812    9.65084e-5   0.106085
Now we solve the corresponding unbalanced, entropy-regularised transport problem.
ϵ = 0.01
λ = 1.0
γ_ = pot_sinkhorn_unbalanced(μ, ν, C, ϵ, λ)
γ = sinkhorn_unbalanced(μ, ν, C, λ, λ, ϵ)
100×200 Array{Float64,2}:
 0.000300343  0.000443705  1.16984e-8   …  4.42153e-20  2.53388e-7
 9.0623e-5    1.05125e-5   2.42567e-18     1.76936e-9   0.00044704
 1.45317e-11  5.60844e-10  0.000318578     1.16624e-41  1.81059e-19
 9.86971e-9   1.59444e-10  2.10119e-29     7.99272e-5   4.0368e-5
 2.73685e-8   4.18609e-7   0.000278595     2.44218e-34  7.9992e-15
 3.32599e-8   6.44706e-10  3.20841e-28  …  4.30289e-5   7.30943e-5
 2.42141e-9   5.09237e-8   0.000345585     8.76122e-37  2.39012e-16
 6.50491e-8   1.40241e-9   1.51586e-27     2.88474e-5   9.94786e-5
 6.25133e-7   5.95996e-6   0.000126308     6.49876e-31  9.15334e-13
 3.32004e-6   2.33138e-5   5.31367e-5      7.49421e-29  1.37868e-11
 ⋮                                      ⋱
 2.83126e-6   2.05215e-5   5.89279e-5      4.64594e-29  1.05535e-11
 0.000275985  5.9232e-5    1.21388e-15     1.10125e-11  0.000167149
 2.90864e-8   5.52234e-10  2.36245e-28     4.63662e-5   6.86059e-5
 3.74689e-9   5.27818e-11  2.56165e-30     0.000120471  2.44461e-5
 7.40322e-6   4.02037e-7   3.66215e-22  …  3.00506e-7   0.00048551
 0.000518728  0.000348276  2.92042e-11     2.13945e-16  6.43657e-6
 0.00032958   0.000451483  6.86319e-9      1.03758e-19  3.59689e-7
 4.80451e-11  3.90134e-13  3.40887e-34     0.00039541   2.05011e-6
 1.9863e-12   1.12593e-14  7.15491e-37     0.000609097  2.8861e-7
Check agreement with POT:
norm(γ - γ_, Inf) # Check that we get the same result as POT
1.4350893520406749e-12
Let's construct source and target measures again, this time on a regular grid so that the transport plans can be plotted:
μ_spt = ν_spt = LinRange(-2, 2, 100)
C = pairwise(Euclidean(), μ_spt', ν_spt').^2
μ = exp.((-(μ_spt).^2)/0.5^2)
μ /= sum(μ)
ν = ν_spt.^2 .*exp.((-(ν_spt).^2)/0.5^2)
ν /= sum(ν)
plot(μ_spt, μ, label = L"\mu")
plot!(ν_spt, ν, label = L"\nu")
current()
Now compute the entropic transport plan:
γ = OptimalTransport.sinkhorn(μ, ν, C, 0.01)
heatmap(γ, title = "Entropically regularised OT", size = (500, 500))
current()
With quadratic regularisation, notice how the 'edges' of the transport plan are sharper than those of the entropic transport plan:
γ_quad = Matrix(OptimalTransport.quadreg(μ, ν, C, 5, maxiter = 500))
heatmap(γ_quad, title = "Quadratically regularised OT", size = (500, 500))
current()
For a collection of probability measures $\{ \mu_i \}_{i = 1}^N$ and corresponding cost matrices $C_i$, we define the barycenter to be the measure $\mu$ that solves the following:
\[ \min_{\mu \text{ a distribution}} \sum_{i = 1}^N \lambda_i \mathrm{entropicOT}^{\epsilon}_{C_i}(\mu, \mu_i) \]
where $\mathrm{entropicOT}^\epsilon_{C_i}(\mu, \mu_i)$ denotes the entropic optimal transport cost between measures $(\mu, \mu_i)$ with cost $C_i$ and entropic regularisation strength $\epsilon$.
For instance, below we set up two measures $\mu_0$ and $\mu_1$ and compute the weighted barycenters. The weights are $(w, 1-w)$ where $w = 0.25, 0.5, 0.75$.
spt = LinRange(-1, 1, 250)
f(x, σ) = exp.(-x.^2/σ^2)
normalize(x) = x./sum(x)
mu_all = hcat([normalize(f(spt .- z, 0.1)) for z in [-0.5, 0.5]]...)'
C_all = [pairwise(SqEuclidean(), spt', spt') for i = 1:size(mu_all, 1)]
plot(size = (400, 200));
plot!(spt, mu_all[1, :], label = L"\mu_0")
plot!(spt, mu_all[2, :], label = L"\mu_1")
for s = [0.25, 0.5, 0.75]
    a = sinkhorn_barycenter(mu_all, C_all, 0.01, [s, 1-s]; max_iter = 1000);
    plot!(spt, a, label= latexstring("\\mu_{$s}"))
end
current()
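For readers curious how such a barycenter can be computed, one common approach is iterative Bregman projections: each measure $\mu_k$ gets its own pair of Sinkhorn scalings, and the barycenter is updated as the weighted geometric mean of the second marginals of the current plans. The sketch below assumes a single cost matrix shared by all measures and a fixed iteration count. The name `barycenter_sketch` is illustrative, and this is not necessarily how `sinkhorn_barycenter` is implemented.

```julia
# Entropic barycenter by iterative Bregman projections (a sketch only).
# μs: vector of probability vectors, C: shared cost matrix, w: weights summing to 1.
function barycenter_sketch(μs, C, ϵ, w; iters = 500)
    K = exp.(-C ./ ϵ)
    n = size(C, 1)
    us = [ones(n) for _ in μs]
    vs = [ones(n) for _ in μs]
    a = fill(1 / n, n)
    for _ in 1:iters
        # Make the first marginal of each plan equal to μ_k.
        for k in eachindex(μs)
            us[k] = μs[k] ./ (K * vs[k])
        end
        # Barycenter = weighted geometric mean of the second marginals.
        loga = zeros(n)
        for k in eachindex(μs)
            loga .+= w[k] .* log.(vs[k] .* (K' * us[k]))
        end
        a = exp.(loga)
        # Make the second marginal of each plan equal to the barycenter.
        for k in eachindex(μs)
            vs[k] = a ./ (K' * us[k])
        end
    end
    return a
end

# Hypothetical example: two Gaussian-like bumps centred at -0.5 and 0.5.
spt = collect(range(-1, 1; length = 101))
bump(c) = (p = exp.(-(spt .- c) .^ 2 ./ 0.2^2); p ./ sum(p))
μs = [bump(-0.5), bump(0.5)]
C = (spt .- spt') .^ 2
a = barycenter_sketch(μs, C, 0.05, [0.5, 0.5])
```

With equal weights and mirror-symmetric inputs, the result `a` is symmetric about the origin and concentrated between the two bumps.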