Skip to content

Commit c21991b

Browse files
committed
Implementation of warmstart for the network simplex solver, which can make use of precomputed potentials from Sinkhorn or even from a related simplex solve
1 parent 3ee4386 commit c21991b

6 files changed

Lines changed: 489 additions & 18 deletions

File tree

ot/lp/EMD.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ enum ProblemType {
2929
MAX_ITER_REACHED
3030
};
3131

32-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
32+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init);
3333
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);
3434

3535
int EMD_wrap_sparse(

ot/lp/EMD_wrapper.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222

2323

2424
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
25-
double* alpha, double* beta, double *cost, uint64_t maxIter) {
25+
double* alpha, double* beta, double *cost, uint64_t maxIter,
26+
double* alpha_init, double* beta_init) {
2627
// beware M and C are stored in row major C style!!!
2728

2829
using namespace lemon;
@@ -93,6 +94,19 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
9394
}
9495
}
9596

97+
// Set warmstart potentials if provided
98+
if (alpha_init != nullptr && beta_init != nullptr) {
99+
// Compress warmstart potentials to only non-zero entries
100+
std::vector<double> alpha_compressed(n);
101+
std::vector<double> beta_compressed(m);
102+
for (uint64_t i = 0; i < n; i++) {
103+
alpha_compressed[i] = alpha_init[indI[i]];
104+
}
105+
for (uint64_t j = 0; j < m; j++) {
106+
beta_compressed[j] = beta_init[indJ[j]];
107+
}
108+
net.setWarmstartPotentials(&alpha_compressed[0], &beta_compressed[0], (int)n, (int)m);
109+
}
96110

97111
// Solve the problem with the network simplex algorithm
98112

ot/lp/_network_simplex.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def emd(
172172
center_dual=True,
173173
numThreads=1,
174174
check_marginals=True,
175+
warmstart_dual=None,
175176
):
176177
r"""Solves the Earth Movers distance problem and returns the OT matrix
177178
@@ -237,6 +238,11 @@ def emd(
237238
check_marginals: bool, optional (default=True)
238239
If True, checks that the marginals mass are equal. If False, skips the
239240
check.
241+
warmstart_dual: tuple of two arrays (alpha, beta), optional (default=None)
242+
Warmstart dual potentials to accelerate convergence. Should be a tuple
243+
(alpha, beta) where alpha is shape (ns,) and beta is shape (nt,).
244+
These potentials are used to guide initial pivots in the network simplex.
245+
Typically obtained from a previous EMD solve or Sinkhorn approximation.
240246
241247
.. note:: The solver automatically detects sparse format using the backend's
242248
:py:meth:`issparse` method. For sparse inputs:
@@ -373,8 +379,18 @@ def emd(
373379
a, b, edge_sources, edge_targets, edge_costs, numItermax
374380
)
375381
else:
382+
# Prepare warmstart if provided
383+
alpha_init = None
384+
beta_init = None
385+
if warmstart_dual is not None:
386+
alpha_init, beta_init = warmstart_dual
387+
alpha_init = np.asarray(alpha_init, dtype=np.float64)
388+
beta_init = np.asarray(beta_init, dtype=np.float64)
389+
376390
# Dense solver
377-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
391+
G, cost, u, v, result_code = emd_c(
392+
a, b, M, numItermax, numThreads, alpha_init, beta_init
393+
)
378394

379395
# ============================================================================
380396
# POST-PROCESS DUAL VARIABLES AND CREATE TRANSPORT PLAN
@@ -513,6 +529,11 @@ def emd2(
513529
check_marginals: bool, optional (default=True)
514530
If True, checks that the marginals mass are equal. If False, skips the
515531
check.
532+
warmstart_dual: tuple of two arrays (alpha, beta), optional (default=None)
533+
Warmstart dual potentials to accelerate convergence. Should be a tuple
534+
(alpha, beta) where alpha is shape (ns,) and beta is shape (nt,).
535+
These potentials are used to guide initial pivots in the network simplex.
536+
Typically obtained from a previous EMD solve or Sinkhorn approximation.
516537
517538
.. note:: The solver automatically detects sparse format using the backend's
518539
:py:meth:`issparse` method. For sparse inputs:

ot/lp/emd_wrap.pyx

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import warnings
2020

2121

2222
cdef extern from "EMD.h":
23-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
23+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
2424
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
2525
int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil
2626
int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
@@ -42,7 +42,7 @@ def check_result(result_code):
4242

4343
@cython.boundscheck(False)
4444
@cython.wraparound(False)
45-
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads):
45+
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads, alpha_init=None, beta_init=None):
4646
"""
4747
Solves the Earth Movers distance problem and returns the optimal transport matrix
4848
@@ -81,6 +81,10 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
8181
max_iter : uint64_t
8282
The maximum number of iterations before stopping the optimization
8383
algorithm if it has not converged.
84+
alpha_init : (ns,) numpy.ndarray, float64, optional
85+
Initial dual potentials for sources (warmstart)
86+
beta_init : (nt,) numpy.ndarray, float64, optional
87+
Initial dual potentials for targets (warmstart)
8488
8589
Returns
8690
-------
@@ -101,6 +105,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
101105
cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0])
102106

103107
cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0)
108+
109+
# Warmstart potentials
110+
cdef np.ndarray[double, ndim=1, mode="c"] alpha_init_c
111+
cdef np.ndarray[double, ndim=1, mode="c"] beta_init_c
112+
cdef double* alpha_init_ptr = NULL
113+
cdef double* beta_init_ptr = NULL
104114

105115
if not len(a):
106116
a=np.ones((n1,))/n1
@@ -110,11 +120,18 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
110120

111121
# init OT matrix
112122
G=np.zeros([n1, n2])
123+
124+
# Setup warmstart pointers if provided
125+
if alpha_init is not None and beta_init is not None:
126+
alpha_init_c = np.ascontiguousarray(alpha_init, dtype=np.float64)
127+
beta_init_c = np.ascontiguousarray(beta_init, dtype=np.float64)
128+
alpha_init_ptr = <double*> alpha_init_c.data
129+
beta_init_ptr = <double*> beta_init_c.data
113130

114131
# calling the function
115132
with nogil:
116133
if numThreads == 1:
117-
result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
134+
result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, alpha_init_ptr, beta_init_ptr)
118135
else:
119136
result_code = EMD_wrap_omp(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, numThreads)
120137
return G, cost, alpha, beta, result_code

0 commit comments

Comments
 (0)