1 Introduction

Optimal Transport (OT) (Villani, 2008) and its associated Wasserstein distance allow the comparison of probability measures by aligning points between the distributions with respect to their masses and transportation costs. Recent advances from a computational perspective, notably the entropic regularization introduced in (Cuturi, 2013) and the Sliced Wasserstein distance (Rabin & Peyré, 2011), led to some success stories of OT in the machine learning community, including Wasserstein Generative Adversarial Networks (Arjovsky et al., 2017), Domain Adaptation (Courty et al., 2014) and Color Transfer (Rabin & Peyré, 2011), to cite a few. Even though the squared Euclidean distance is used most of the time to compare points of the distributions, various other ground metrics can be naturally used or learned to better capture the idiosyncrasies of the application at hand: the Earth mover’s distance in computer vision tasks, the Mahalanobis distance (Paty & Cuturi, 2019), or concave functions in economics such as the square root of the Euclidean distance (Delon et al., 2012), etc.

Whatever the cost function, it is worth noting that the OT problem was originally formulated to deal with distributions that are required to lie in the same space. To relax this constraint, a distance between metric spaces, named Gromov Wasserstein (GW),Footnote 1 has been introduced in (Mémoli, 2007). It takes the form of a generalization of the well-known Quadratic Assignment problem (Beckman & Koopmans, 1957) to any distribution (Mémoli, 2011) and any loss function (Peyré et al., 2016). The intuition is still to align points between two distributions, but the method only relies on pairwise distances, computed in each space separately. This notably makes it possible to take into account the structure of each distribution while being invariant to rotation and translation. Therefore, GW is a relevant tool for matching and partitioning tasks involving graphs (Xu et al., 2019a, Xu et al., 2019b; Vayer et al., 2019a), by allowing one, e.g., to encode structure such as the shortest path between two vertices. GW has been further used in various other domains, such as Heterogeneous Domain Adaptation (Yan et al., 2018), Shape Matching (Mémoli, 2011; Bronstein et al., 2010; Vayer et al., 2019b), Object Modeling with Deep Learning (Ezuz et al., 2017), and Generative Adversarial Networks (Bunne et al., 2019). The Wasserstein distance and the GW distance have also been jointly used in (Vayer et al., 2018), leading to the so-called Fused Gromov-Wasserstein distance.

From an algorithmic perspective, most of the previous methods resort to the entropic approximation (EGW) of the original GW formulation introduced in (Peyré et al., 2016), based on a gradient descent step followed by a projection step, both according to the Kullback-Leibler (KL) divergence. While a naive implementation of the original GW problem leads to a \(O\left( N^4\right)\) complexity, Peyré et al., (2016) further show that one can compute GW in \(O\left( N^3\right)\) operations for a certain class of losses. Some other attempts have been recently proposed in the literature to speed up the GW calculation. Sliced Gromov-Wasserstein (SGW) (Vayer et al., 2019b) takes inspiration from the Sliced Wasserstein distance (Rabin & Peyré, 2011) by projecting each distribution onto a 1D line and then solving the 1D Gromov-Wasserstein problem efficiently in \(O\left( N\log (N)\right)\). The Anchor Energy (AE) distance from Sato et al., (2020),Footnote 2 is also related to the GW distance but simplifies the problem into \(N^2\) linear sub-problems. The overall time complexity for solving AE is \(O\left( N^2 \log (N)\right)\). Scalable Gromov-Wasserstein Learning (S-GWL) (Xu et al., 2019a) recursively decomposes the two large probability measures into a set of small pairwise aligned distributions using a common Gromov-Wasserstein barycenter (Peyré et al., 2016). The final transport plan is the aggregation of the results of GW on each pair of small aligned distributions.

In this paper, we aim at overcoming the main algorithmic bottleneck of \(\hbox {EGW}_{}\): the multiplication of a 4D tensor with a 2D matrix, which we interpret as an expectation over matrices. We leverage this interpretation, using sampling to approximate the expectation instead of computing it entirely, reducing the complexity to \(O\left( N^2\right)\). Unlike SGW and AE, which propose simplified distances, we optimize the original GW distance. Unlike \(\hbox {EGW}_{}\) and S-GWL, which have speedups only for specific loss functions, we lower the complexity for any loss function. We obtain a generic algorithm, called Sampled Gromov Wasserstein, supported by theoretical convergence guarantees. We further show that when the number of sampled matrices is 1, the particular 1D case of OT can be used to compute an update in \(O\left( N\log (N)\right)\). This version, called Pointwise Gromov Wasserstein, overcomes most of the limitations of SGW (Vayer et al., 2019b) detailed in Sect. 3, while still being very fast. Our contributions are supported by experiments on synthetic and real datasets. Interestingly, those experiments show evidence that our method outperforms the state of the art when it comes to finding the best compromise between computation time and quality of the distance. This behavior takes its origin from (i) the stochastic nature of our method, which can reduce the risk of getting stuck in local minima, and (ii) the fact that the other approaches do not scale well. An experiment on a graph classification task shows that being able to change the loss function for free is of high interest for finding the one that best fits the problem at hand.

This article is organized as follows: Sect. 2 details the notations and the necessary background on GW. Section 3 covers the state of the art approaches for solving the underlying problem. Section 4 presents our Sampled Gromov Wasserstein algorithm, derives convergence guarantees for it, and introduces our very fast specialized variant called Pointwise Gromov Wasserstein. Experiments are detailed in Sect. 5.

2 Background on GW

In this section, we introduce the Optimal Transport (OT) problem with its associated Wasserstein distance, and the Gromov Wasserstein distance that allows the comparison of distributions lying in different spaces. Let \(({\mathcal {X}}, {\mathcal {C}}^{{\mathcal {X}}})\) be a compact metric space where \({\mathcal {X}}\) is a set and \({\mathcal {C}}^{{\mathcal {X}}}\) its associated metric. Let \(\mu\) be a distribution with finite p-moment on \(({\mathcal {X}}, {\mathcal {C}}^{{\mathcal {X}}})\). Similarly, \(({\mathcal {Y}}, {\mathcal {C}}^{{\mathcal {Y}}})\) denotes another compact metric space and \(\nu\) a distribution with finite p-moment on that space. We denote as \(\varPi _{\mu \nu }\) the collection of coupling probability measures on \({\mathcal {X}} \times {\mathcal {Y}}\) constrained by the marginals \(\mu\) and \(\nu\). \(\varPi _{\mu \nu }\) defines the so-called set of admissible transport plans from \(\mu\) to \(\nu\), used to define the OT problem.

Optimal Transport OT consists in finding the best mapping (or coupling or transport plan) between two distributions \(\mu\) and \(\nu\) on the same space, i.e., \({\mathcal {X}} = {\mathcal {Y}}\) and \({\mathcal {C}}^{{\mathcal {X}}}= {\mathcal {C}}^{{\mathcal {Y}}}\). Denoting as \({\mathcal {C}}\) this common distance, one can define the p-Wasserstein distance (Kantorovich, 1942) to the power of p, as follows:

$$\begin{aligned} W_p^{p}({\mathcal {C}}) = \min _{\gamma \in \varPi _{\mu \nu }} \int _{{\mathcal {X}} \times {\mathcal {Y}}} {\mathcal {C}}^{p}(x, y) d\gamma (x, y). \end{aligned}$$
(1)

In the discrete version of Problem (1), \(\mu\) and \(\nu\) are empirical measures supported by two finite sets of points. In this context, \(\mu = \sum _{i=1}^{I} a_i \delta _{x_i}\) is defined by I points \((x_i)_{i \in \llbracket 1, I\rrbracket }\) in \({\mathcal {X}}\) and the associated probability vector a. In the same way, we define \(\nu = \sum _{k=1}^{K} b_k \delta _{y_k}\) in \({\mathcal {Y}}\), associated with the probability vector b. The set of admissible transport plans becomes \(\varPi _{a b}= \{T \in {\mathbb {R}}_+^{I \times K} | T \mathbf{1 }_{K} = a, T^\mathrm {T} \mathbf{1 }_{I} = b \}\). In this discrete case, each distance function \({\mathcal {C}}{}\) can be represented as a matrix (or tensor) \(C{}\). Therefore, the discrete p-Wasserstein distance to the power p is written as follows:

$$\begin{aligned} W_p^{p}({\mathcal {C}})&= \min _{ T \in \varPi _{a b}}\sum _{i,k=1}^{I,K}{\mathcal {C}}^{p}(x_i, y_k) T_{ik} = \min _{ T \in \varPi _{a b}}\left\langle C^p,T \right\rangle \end{aligned}$$
(2)

where \(\left\langle .,. \right\rangle\) is the Frobenius dot product. To simplify the notations, it is often assumed that \(I = K\) (same number of points in both sets) and \(N\) is used to denote this value. The optimal transport plan \(T^*\) can be found from (2) using a linear solver (Bonneel et al., 2011) with a complexity of at least \(O\left( N^3 \log (N)\right)\) (Pele & Werman, 2009). To lower this complexity, an entropic regularization can be added (Cuturi, 2013), leading to a strongly convex problem that yields a smooth and unique solution in \(O\left( PN^2\right)\), with \(P\) the number of Sinkhorn iterations. Let \(\epsilon \in {\mathbb {R}}_+\) be a regularization parameter and let \({\mathcal {H}}(T) = \sum _{ik} T_{ik} \log (T_{ik})\) be the negative entropy; the optimal plan \(T^*\) of Eq. (2) can then be approximated by

$$\begin{aligned} T^{*} \approx \underset{T \in \varPi _{a b}}{\text {argmin}}\left\langle C^p,T \right\rangle + \epsilon {\mathcal {H}}(T). \end{aligned}$$
(3)
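To make this entropic OT step concrete, here is a minimal NumPy sketch of the alternating Sinkhorn scaling iterations that approximately solve Eq. (3); the function name, the fixed number of iterations and the toy data are illustrative assumptions, not part of the original text.

```python
# Minimal sketch (assumptions: fixed iteration count, no numerical stabilization).
import numpy as np

def sinkhorn(a, b, C, eps=0.05, n_iters=200):
    """Approximate argmin over Pi_{ab} of <C, T> + eps * H(T), as in Eq. (3)."""
    K = np.exp(-C / eps)                  # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iters):              # alternate scalings to match the marginals
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]    # T = diag(u) K diag(v)

# Toy usage with uniform marginals and a squared-Euclidean cost
rng = np.random.default_rng(0)
x, y = rng.normal(size=(5, 2)), rng.normal(size=(5, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
a = b = np.full(5, 1 / 5)
T = sinkhorn(a, b, C)
print(T.sum(axis=1))                      # close to a
```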

Gromov Wasserstein Distance (GW) While the OT problem requires the two distributions to lie in the same space, the GW distance makes it possible to compare distributions lying in different metric spaces. Let \({\mathcal {L}}\) be a bounded loss function which allows the comparison of two distances. GW (Mémoli, 2011, 2009; Peyré et al., 2016) is defined as follows:

$$\begin{aligned}&GW{}\left( {\mathcal {C}}^{{\mathcal {X}}}, {\mathcal {C}}^{{\mathcal {Y}}}, \mu , \nu \right) = \min _{\gamma \in \varPi _{\mu \nu }} \int _{\left( {\mathcal {X}} \times {\mathcal {Y}}\right) ^2 } {\mathcal {L}}\left( {\mathcal {C}}^{{\mathcal {X}}}(x, x'), {\mathcal {C}}^{{\mathcal {Y}}}(y, y')\right) d\gamma (x, y) d\gamma (x', y'). \end{aligned}$$
(4)

The discrete case (see Fig. 1) can be formulated as:

$$\begin{aligned}&GW{}\left( {\mathcal {C}}^{{\mathcal {X}}}, {\mathcal {C}}^{{\mathcal {Y}}}, a, b\right) = \min _{ T \in \varPi _{a b}}\sum _{i, j = 1}^{I,I} \sum _{k, l = 1}^{K,K} L_{ijkl}T_{ik}T_{jl}, \end{aligned}$$
(5)

where \(L_{ijkl} = {\mathcal {L}}\left( {\mathcal {C}}^{{\mathcal {X}}}(x_i, x_j), {\mathcal {C}}^{{\mathcal {Y}}}(y_k, y_l)\right)\). This formulation exhibits an important property of GW: only the pairwise distances are needed. This explains why the Gromov Wasserstein distance is often used to compare graphs, for which Chowdhury and Mémoli (2019) prove that GW is a pseudometric.
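For a fixed plan T, the quadruple sum of Eq. (5) can be evaluated naively as in the hedged sketch below (the function name and the choice of the absolute loss are ours); it is only meant to make the \(O(N^4)\) time and memory cost explicit, and is not how the methods discussed next proceed.

```python
# Naive O(N^4) evaluation of the discrete GW objective of Eq. (5) for a fixed T.
import numpy as np

def gw_objective(Cx, Cy, T, loss=lambda u, v: np.abs(u - v)):
    # L[i, j, k, l] = loss(Cx[i, j], Cy[k, l]); broadcasting builds the full 4D tensor
    L = loss(Cx[:, :, None, None], Cy[None, None, :, :])
    # sum_{i,j,k,l} L_{ijkl} T_{ik} T_{jl}
    return np.einsum('ijkl,ik,jl->', L, T, T)
```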

Fig. 1 Illustration of GW, with only one term \(L_{ijkl}\) of the quadruple sum of Eq. (5)

3 Approaches to solve GW

We describe here the most widely used method for solving GW, namely Entropic Gromov Wasserstein, as well as two other approaches that aim at lowering its time complexity. As all these methods rely on an iterative optimization, for the sake of simplicity, we omit in this section the number S of iterations (of the outer loop).

Entropic Gromov Wasserstein (EGW) To solve an approximation of Problem (5), the authors of (Peyré et al., 2016) generalize the idea introduced in Solomon et al., (2016) by using a gradient descent step followed by a projection, both according to the Kullback-Leibler (KL) divergence. This boils down to a two-step loop. First, from the current estimation of the transport plan T, a new matrix \(\varLambda\), defined as \(\varLambda _{jl} = \sum _{i, k = 1}^{I, K} L_{ijkl} T_{ik}\), is computed; it can be seen as an updated cost matrix. Second, a new estimate of the transport plan is obtained by solving the following entropic regularization-based OT problem:

$$\begin{aligned} \min _{ T \in \varPi _{a b}}\left\langle \varLambda ,T \right\rangle + \epsilon {\mathcal {H}}(T). \end{aligned}$$
(6)

When the loss \({\mathcal {L}}({\mathcal {C}}^{{\mathcal {X}}},{\mathcal {C}}^{{\mathcal {Y}}})\) can be decomposed as \(f_1({\mathcal {C}}^{{\mathcal {X}}}) + f_2({\mathcal {C}}^{{\mathcal {Y}}}) - h_1({\mathcal {C}}^{{\mathcal {X}}})h_2({\mathcal {C}}^{{\mathcal {Y}}})\) for functions (\(f_1,f_2,h_1,h_2\)), it is shown that the \(\varLambda\) matrix can be computed in \(O\left( N^3\right)\). This notably holds for the square loss and the KL divergence. However, in the general case, the complexity is \(O\left( N^4\right)\), making this method intractable as N grows, as shown in our experiments.
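As an illustration of this decomposition, the sketch below computes \(\varLambda\) in \(O(N^3)\) for the square loss, for which \(f_1(u)=u^2\), \(f_2(v)=v^2\), \(h_1(u)=u\) and \(h_2(v)=2v\); the function name is ours and the code should be read as a plausible transcription of the trick of Peyré et al. (2016) under these assumptions, not as their reference implementation.

```python
# O(N^3) computation of Lambda when L(u, v) = f1(u) + f2(v) - h1(u) h2(v) (square loss here).
import numpy as np

def egw_cost_matrix(Cx, Cy, T, a, b):
    """Lambda[j, l] = sum_{i,k} (Cx[i,j] - Cy[k,l])**2 * T[i,k], without the 4D tensor."""
    f1, f2 = Cx ** 2, Cy ** 2          # square loss: f1(u) = u^2, f2(v) = v^2
    h1, h2 = Cx, 2 * Cy                # h1(u) = u, h2(v) = 2v
    term1 = (f1.T @ a)[:, None]        # depends only on j (uses T 1 = a)
    term2 = (f2.T @ b)[None, :]        # depends only on l (uses T^T 1 = b)
    return term1 + term2 - h1.T @ T @ h2
```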

Sliced Gromov-Wasserstein (SGW) In Rabin and Peyré (2011), the authors introduce an alternative metric, called the Sliced Wasserstein distance, which uses random 1D projections. The advantage of this method lies in the fact that the OT Problem (2) can be simply solved by sorting both empirical distributions (in \(O\left( N\log (N)\right)\)) and matching the sorted lists. In a similar manner, Sliced Gromov-Wasserstein (SGW) (Vayer et al., 2019b) projects each distribution onto a common 1D space to solve the Gromov-Wasserstein problem (5) efficiently. While being very fast to compute, SGW comes with some limitations: (i) it cannot be used in general on graphs because a feature representation is needed to allow the 1D projection, (ii) it does not output an explicit transport plan, which can be a pitfall in some applications like domain adaptation, (iii) it does not approximate the original GW distance and (iv) it is not naturally invariant to rotation (although the authors propose a solution by repeatedly calling SGW). Note that while SGW’s theoretical result and the \(O\left( N\log (N)\right)\) time complexity rely on the square loss, its algorithmic approach can be adapted to handle arbitrary losses. This adaptation results in a \(O\left( N^2\right)\) time complexity.

Scalable GW Learning (S-GWL) Scalable Gromov-Wasserstein Learning (Xu et al., 2019a) aims at making GW tractable for large-scale graph analysis. It recursively decomposes the two original graphs into a set of smaller sub-graph pairs, using Gromov-Wasserstein barycenters (Peyré et al., 2016). Then, these sub-graphs are matched. The transport plan is updated with a proximal gradient method regularized with a KL divergence. The time complexity is \(O\left( N^2 \log (N)\right)\) when the cost matrices \(C^{{\mathcal {X}}}\) and \(C^{{\mathcal {Y}}}\) are not sparse and \({\mathcal {L}}\) is the square loss. However, with an arbitrary \({\mathcal {L}}\), the gain in complexity does not hold anymore because S-GWL cannot leverage the closed-form solution for the barycenter calculation.

4 Scalable GW optimization

We aim to address in this section the algorithmic bottleneck of \(\hbox {EGW}_{}\) (Peyré et al., 2016) which prevents its use on large scale problems. By rewriting Eq. (5) as an alternating optimization problem, we propose to compute the GW distance by solving iteratively an OT problem from a cost matrix seen as the expectation of a random variable. This allows us to propose a sampling strategy to drastically reduce the algorithmic complexity of GW. We introduce our algorithm, called Sampled Gromov Wasserstein (SaGroW), and then derive its convergence guarantees.

We also present a special case and a variant of SaGroW: Pointwise Gromov Wasserstein (PoGroW), which leverages very efficient 1D OT solvers but does not exhibit the drawbacks of SGW, and \(SaGroW{}^{KL}\), a version using a Kullback-Leibler regularization. We finally show that an appropriate sampling strategy can also be used to accurately and efficiently approximate the GW distance from a known transport plan.

4.1 Sampled Gromov Wasserstein (SaGroW)

It is known that the GW problem as described in Eq. (5) is not convex in general and thus difficult to solve. On the other hand, we can note that the transport plan T appears twice in the formulation. In the following, we suggest treating these two instances differently and solving the problem with respect to two transport plan variables T and \(T'\), as follows:

$$\begin{aligned} \min _{ T \in \varPi _{a b}}\min _{ T' \in \varPi _{a b}}\sum _{i,k=1}^{I,K}\sum _{j,l=1}^{I,K} L_{ijkl} T'_{ik} T_{jl}. \end{aligned}$$
(7)

Even though our sampling strategy leverages this decomposition into T and \(T'\), as if they were two different transport plans, note that we still solve the original GW problem. Indeed, as we will explain, our Algorithm fuses T and \(T'\) after each update, fulfilling the \(T = T'\) constraint.

In an alternating optimization, with a fixed T, the optimal \(T'\) is thus the solution of the following OT problem:

$$\begin{aligned} \min _{ T' \in \varPi _{a b}}\sum _{i,k,j,l=1}^{I,K,I,K} T_{jl}L_{ijkl} T'_{ik} = \min _{ T' \in \varPi _{a b}}\left\langle \sum _{j,l=1}^{I,K} T_{jl}L_{.j.l},T' \right\rangle \end{aligned}$$
(8)

where \(L_{.j.l}\) is an extracted matrix i.e., \(\left( L_{.j.l}\right) _{ik} = L_{ijkl}\).

As the transport plan T sums to 1, we can interpret it as (the parameters of) a categorical distribution on pairs of points (j, l), or equivalently on the associated matrices \(L_{.j.l}\). We thus introduce a random variable \(\mathbf{C }\) on matrices, definedFootnote 3 by the distribution \({\mathbb {P}}(\mathbf{C }= L_{.j.l}) = T_{jl} \, \, \forall (j,l) \in \llbracket 1, N\rrbracket ^2\). Leveraging this random variable, the cost matrix \(\sum _{j,l} T_{jl}L_{.j.l}\) used in Problem (8) can be seen as the expectation of \(\mathbf{C }\). Therefore, the problem can be rewritten as follows:

$$\begin{aligned} \min _{ T' \in \varPi _{a b}}\left\langle {\mathbb {E}}(\mathbf{C }),T' \right\rangle . \end{aligned}$$
(9)

While solving this problem is still in \(O\left( N^4\right)\) in general, it presents the advantage of opening the door to a sampling strategy allowing a reduction of the complexity. Indeed, rather than computing the entire expectation \({\mathbb {E}}(\mathbf{C })\), we suggest calculating an approximation by sampling \(M\) matrices \(\left\{ C^m\right\} _{m=1}^M\). To get a matrix \(C^m\) drawn according to the distribution of \(\mathbf{C }\), it suffices to sample two indices \((j_m,l_m)\) following the weights of the matrix T. Consequently, \(C^m\) takes the form of the matrix \(L_{.j_m.l_m}\). Using these sampled matrices, Problem (9) can be approximated as follows:

$$\begin{aligned} \min _{ T' \in \varPi _{a b}}\left\langle \frac{1}{M} \sum _{m=1}^MC^m,T' \right\rangle . \end{aligned}$$
(10)

This approximation comes with two main advantages: (i) it allows a reduction of the computation time of the GW problem and (ii) similarly to a mini-batch gradient descent, it might avoid getting stuck in local minima and thus might lead to a better transport plan. Even though Problem (10) can be solved efficiently with any OT solver, our approach resorts to the Sinkhorn method (Cuturi, 2013), leading to a time complexity of \(O\left( (M+P) N^2\right)\) due to the summation over M matrices and the P iterations of the Sinkhorn algorithm.
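A minimal sketch of this sampling step is given below, assuming the absolute loss and row-major flattening of T; each sampled slice \(L_{.j.l}\) is formed on the fly from one column of \(C^{{\mathcal {X}}}\) and one column of \(C^{{\mathcal {Y}}}\), so the whole approximation costs \(O(MN^2)\). The function name is illustrative.

```python
# Sampled approximation of E(C) used in Problem (10).
import numpy as np

def sampled_cost(Cx, Cy, T, M, rng, loss=lambda u, v: np.abs(u - v)):
    I, K = T.shape
    flat_idx = rng.choice(I * K, size=M, p=T.ravel())   # sample (j, l) with probability T_{jl}
    Lambda_hat = np.zeros((I, K))
    for idx in flat_idx:
        j, l = divmod(idx, K)
        Lambda_hat += loss(Cx[:, j][:, None], Cy[:, l][None, :])   # C^m = L_{.j.l}
    return Lambda_hat / M
```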

Algorithm 1 gives the pseudo-code of Sampled Gromov Wasserstein (SaGroW). In the absence of prior, the transport plan \(T_0\) is initialized to the joint distribution \(ab^T\) (line 1). At each iteration, M pairs of indices \((j_m,l_m)\) are sampled from the current transport plan \(T_s\) (line 3). Then \({\widehat{\varLambda }}{}\), the approximation of \({\mathbb {E}}(\mathbf{C })\), is computed (line 4) and used in an entropic regularization-based OT problem (6) solved using the Sinkhorn algorithm, yielding the plan \(T'_{s}\) (line 5). As indicated before, Problem (7) inherently assumes that \(T=T'\). To ensure that \(T'\) stays close to T and to mitigate the stochastic nature of the process, we perform a partial update \((1-\alpha )T_s + \alpha T'_s\). Given the symmetric roles of T and \(T'\) (as long as \(C^{{\mathcal {X}}}\) and \(C^{{\mathcal {Y}}}\) are symmetric), this partial update becomes our next plan \(T_{s+1}\) (line 6). This update, inspired by the Frank-Wolfe algorithm, allows us to derive theoretical guarantees (see next section). Notice that Algorithm 1 returns a single transport plan and thus aims at minimizing the original GW problem. In practice, other strategies can be used: as the previous plan \(T_s\) and the optimized \(T'_s\) can be interpreted as distributions, line 6 can be omitted and replaced by a KL regularization (on line 5) between them, as detailed in Sect. 4.4.

We end this section by noting that when the expectation is fully computed in SaGroW (\(i.e.\), \(M=\infty\) and “\(M=N^2\)” in terms of complexity as sampling becomes useless) and \(\alpha\) is set to 1, our method is strictly equivalent to the two steps loop of \(\hbox {EGW}_{}\) described in Sect. 3. This connection will be used advantageously in the next section by deriving new convergence guarantees for \(\hbox {EGW}_{}\) when the GW problem is concave.

Algorithm 1 Sampled Gromov Wasserstein (SaGroW)
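Putting the pieces together, the following hedged sketch reproduces the loop of Algorithm 1 as described above, reusing the `sinkhorn` and `sampled_cost` helpers sketched earlier; the line numbers in the comments refer to that description, while the default hyperparameter values are mere placeholders.

```python
# Sketch of the SaGroW loop (assumptions: helpers from the earlier sketches, placeholder defaults).
import numpy as np

def sagrow(Cx, Cy, a, b, M=10, S=100, eps=0.05, alpha=0.8, seed=0):
    rng = np.random.default_rng(seed)
    T = np.outer(a, b)                                # line 1: T_0 = a b^T
    for _ in range(S):
        Lambda_hat = sampled_cost(Cx, Cy, T, M, rng)  # lines 3-4: sample M pairs, average the slices
        T_prime = sinkhorn(a, b, Lambda_hat, eps)     # line 5: entropic OT step on the sampled cost
        T = (1 - alpha) * T + alpha * T_prime         # line 6: partial (Frank-Wolfe-like) update
    return T
```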

4.2 Convergence analysis

In this section, we aim at studying the convergence of Algorithm 1. Note that convergence guarantees have already been derived for \(\hbox {EGW}_{}\) in (Peyré et al., 2016). However, based on Rangarajan et al., (1999), this convergence has been proven only when \(L\) produces a convex problem. Unlike Peyré et al., (2016), the guarantees presented in this section have two main advantages: (i) they hold whatever the loss function, (ii) convergence on average to a stationary point is proven. Note that other results related to the GW problem have been recently derived in the literature. The authors of Xu et al., (2019b) prove the convergence of their proximal point method to a stationary point as long as their regularized GW problem can be solved perfectly at each iteration. On the other hand, Redko et al., (2020) provide a guarantee on the convergence of Problem (7) under the condition that \(L\) yields a concave problem.

Assuming that the two cost functions \({\mathcal {C}}^{{\mathcal {X}}}\) and \({\mathcal {C}}^{{\mathcal {Y}}}\) are symmetric, we introduce the following notations: \({\mathcal {E}}(A, A') := {\mathcal {E}}(A', A) := \sum _{i, j = 1}^I \sum _{k, l = 1}^{K} L_{ijkl} A_{ik} A'_{jl}\) and \({\mathcal {E}}(A) := {\mathcal {E}}(A, A)\). Under these notations, our goal is to minimize (5), i.e., to minimize \({\mathcal {E}}(T)\) under constraints on the marginals of T. Let us now define G(T) as follows: \(G(T) := {\mathcal {E}}(T, T) - \min _{T'\in \varPi _{a b}}{\mathcal {E}}(T, T')\). In a non-convex setting, T is a stationary point of \({\mathcal {E}}(T)\) if and only if \(G(T) = 0\) (Reddi et al., 2016). The goal of our Theorem 1 is to provide a guarantee on the convergence of \(G({\overline{T}})\) with \({\overline{T}}\) uniformly sampled from \((T_{s})_{s \in \llbracket 0, S-1 \rrbracket }\). The convergence is proven on average over these samplings. A practical implementation will naturally take only the last transport plan, \(T_{S- 1}\), and avoid unnecessary computations.

Theorem 1

(Based on Reddi et al., (2016)) For any \(L_{ijkl} \in [0,B]\), for any distributions \(\mu\) and \(\nu\) with uniform weights a and b respectively, for any optimal solution \(T^*\) of Problem (5), for the transport plan \({\overline{T}}\) uniformly sampled from \((T_{s})_{s \in \llbracket 0, S-1 \rrbracket }\), on average over all the samplings, the following bound holds:

$$\begin{aligned} {\mathbb {E}}\left( G \left( {\overline{T}} \right) \right) \le \sqrt{\frac{2B\left( {\mathcal {E}}(T_0) - {\mathcal {E}}(T^*)\right) N}{S}} + B\sqrt{\frac{2N}{M}} + \epsilon \log (N). \end{aligned}$$

Proof

The complete proof is available in the Appendix A.1. It requires a novel lemma that quantifies the difference between the Wasserstein distances obtained with and without the entropic regularization: \(0 \le \left\langle C,T^\epsilon \right\rangle - \left\langle C,T^0 \right\rangle \le \epsilon \log (N)\). We also prove that \({\mathcal {E}}(T)\) is \(2N^2\)-smooth and we bound the difference between two transport plans: \(\left\Vert T - T'\right\Vert _F \le \sqrt{\frac{2}{N}}\). These two results allow us to adapt the proof of Theorem 2 in (Reddi et al., 2016), where our new lemma is needed because the entropy-regularized solvers do not find the exact OT minimum. \(\square\)

While our bound cannot be explicitly computed as \(T^*\) is unknown, it gives meaningful information about Algorithm 1. First of all, it prompts us to initialize \(T_0\) so as to get \({\mathcal {E}}(T_0)\) as close to \({\mathcal {E}}(T^*)\) as possible. Without any prior information, \(ab^T\) (the uniform plan) appears to be a reasonable choice to avoid degenerate cases. Regarding the regularization parameter, if \(\epsilon\) is not small enough, the convergence to a stationary point is not guaranteed. On the other hand, we can note that the number of sampled matrices M appears in only one term of the bound. Therefore, the costly complete computation of the expectation (\(M= \infty\)) would not by itself guarantee convergence, while leading to a \(O\left( N^4\right)\) complexity. Thus, our bound prompts us to find a compromise between reducing M and increasing the number of iterations S, allowing us to control the complexity while getting a reasonable bound.

As the GW problem has been shown in (Redko et al., 2020) to be often concave, especially with the square loss and the Euclidean distance on both spaces, the following Theorem 2 gives a second bound dedicated to this specific concave case. This result presents the major interest of providing an asymptotic convergence to a stationary point for \(\hbox {EGW}_{}\) in this concave case, as the proofs proposed in (Peyré et al., 2016) cover the convergence of \(\hbox {EGW}_{}\) only for high values of \(\epsilon\).

Theorem 2

With the same notations as in Theorem 1, with the entropy regularization parameter \(\epsilon _{s}\) that may now change along the iterations s, when L yields a concave GW problem, the following bound holds:

$$\begin{aligned} {\mathbb {E}}\left( G \left( {\overline{T}} \right) \right) \le \frac{{\mathcal {E}}(T_0) - {\mathcal {E}}(T^*)}{2S} + B\sqrt{\frac{2N}{M}} +\frac{1}{S}\sum _{s=0}^{S-1}\epsilon _s \log (N). \end{aligned}$$

We can make the following comments from this bound. First, the convergence is better in the concave case as, unlike in Theorem 1, the first term now decreases in \(O(1/S)\) instead of \(O(1/\sqrt{S})\). Second, as it can be seen in the proof (see Appendix A.1), it can be shown that in this concave scenario, the best value for \(\alpha\) is 1. Thus, if we completely compute the matrix \(\varLambda\) (\(M = \infty\)), this bound applies to \(\hbox {EGW}_{}\). For any sequence \((\epsilon _{s})_{s\in {\mathbb {N}}}\) such that \(\sum _{s=0}^{S- 1} \epsilon _s\) is \(o(S)\), the convergence of \(\hbox {EGW}_{}\) to a stationary point is guaranteed.

Relationship between SaGroW and the Frank-Wolfe algorithm At first sight, SaGroW seems akin to a Frank-Wolfe algorithm (Frank & Wolfe, 1956). In fact, when the regularization parameter \(\epsilon = 0\), SaGroW is strictly equivalent to a Stochastic Frank-Wolfe (Reddi et al., 2016). The convergence analysis of this general non-convex setting is thus very similar, except for the term that depends on \(\epsilon\), which quantifies the error due to the entropy regularization. Moreover, note that if \(\epsilon = 0\), EGW becomes equivalent to the Frank-Wolfe algorithm (Frank & Wolfe, 1956) when its step size \(\alpha\) is set to 1. Since the \(\alpha\) parameter in our algorithm plays the same role as the step size of the Frank-Wolfe algorithm, we might wonder why SaGroW does not compute the optimal value using a line search. To the best of our knowledge, in this general non-convex setting, there are no convergence guarantees towards a stationary point for a stochastic Frank-Wolfe algorithm that would make use of the optimal step. Moreover, it is worth noting that this optimal step is expensive to calculate without approximation (\(O\left( N^4\right)\) complexity). Considering an approximation would make the derivation of theoretical guarantees even more challenging.

4.3 Particular case: pointwise GW

We focus in this section on the special case of SaGroW where only one matrix C is sampled (i.e., \(M=1\)) at each iteration. This variant, called Pointwise Gromov Wasserstein (PoGroW), makes it possible to leverage a dedicated solver to reduce the algorithmic complexity of GW.

When \(M=1\), if we sample a position (j, l) from T, then we seek to minimize the following problem:

$$\begin{aligned} \min _{ T' \in \varPi _{a b}}\sum _{i,k=1}^{I,K} {\mathcal {L}}({\mathcal {C}}^{{\mathcal {X}}}(x_i, x_j), {\mathcal {C}}^{{\mathcal {Y}}}(y_k, y_l)) T'_{ik}. \end{aligned}$$
(11)

As illustrated in Fig. 2, each point in \({\mathcal {X}}\) (resp. \({\mathcal {Y}}\)) is simply defined by its distance to \(x_j\) (resp. \(y_l\)), as done in papers that define a distribution using a distance to a point (Gelfand et al., 2005; Sato et al., 2020). With a single feature per point, Problem (11) can be solved very efficiently in \(O\left( N\log (N)\right)\) like a 1D OT problem: the two lists of distances can be sorted and matched. With non-convex losses, this sorting approach is only an approximation. PoGroW can be seen as a natural GW extension of Sliced Wasserstein where each point is described by its distance to a chosen “anchor” (instead of a position on a line). Recall that the output of Problem (11) is a transport plan. If needed for the application at hand, the GW value can be computed in \(O\left( N^2\right)\) (see Sect. 4.5).
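A possible implementation of one PoGroW update, i.e., of Problem (11), is sketched below under the assumptions of uniform weights and \(I=K=N\); as noted above, the sort-and-match step solves the resulting 1D-like problem exactly for losses where sorting is optimal and is otherwise an approximation. Names are illustrative.

```python
# One PoGroW update: sample an anchor pair (j, l) ~ T, then sort and match distances.
import numpy as np

def pogrow_update(Cx, Cy, T, rng):
    I, K = T.shape                      # assumed equal (N points on each side, uniform weights)
    idx = rng.choice(I * K, p=T.ravel())
    j, l = divmod(idx, K)
    dx = Cx[:, j]                       # distance of each x_i to the anchor x_j
    dy = Cy[:, l]                       # distance of each y_k to the anchor y_l
    order_x, order_y = np.argsort(dx), np.argsort(dy)
    T_prime = np.zeros((I, K))
    T_prime[order_x, order_y] = 1.0 / I # match the sorted lists, mass 1/N per pair
    return T_prime
```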

Fig. 2 Intuition behind PoGroW when \(j,l = 0,1\) are sampled from T: only the distances to \(x_0\) in \({\mathcal {X}}\) (on the left) and to \(y_1\) in \({\mathcal {Y}}\) (on the right) characterize a pair, and then \(T'\) can be computed in \(O\left( N\log N\right)\) like in 1D OT

In summary, PoGroW has the same low complexity as Sliced Gromov Wasserstein (Vayer et al., 2019b) but it overcomes its main limitations: PoGroW is naturally invariant to rotation; it returns a transport plan; it approximates the actual GW distance; it works with graphs.

4.4 A KL regularization-based variant

As the transport plan T is a distribution and most GW algorithms progressively update T, an interesting idea is to encourage the next plan \(T'\) to be close (in terms of KL divergence) to the current estimate T. This idea, already used in Xu et al., (2019b) based on Xie et al., (2020), can be applied to our SaGroW algorithm: we name this approach SaGroW\(^{KL}\) and describe it below.

In Algorithm 1, we used partial updates to explore the transport plan space while encouraging the new value of T to be close to the preceding one, as reflected in line 6. We suggest here a slight modification, consisting in using a Kullback-Leibler (KL) regularization between T and \(T'\) in line 5 and removing line 6. This accounts, in a natural way, for the requirement that T and \(T'\) stay close to each other during the optimization. This leads to the following sampled optimization problem,

$$\begin{aligned} \min _{ T' \in \varPi _{a b}}&\left\langle \frac{1}{M} \sum _{m=1}^M C^m,T' \right\rangle + \epsilon KL(T'||T), \end{aligned}$$
(12)

which can be rearranged into,

$$\begin{aligned} \min _{ T' \in \varPi _{a b}}\left\langle \frac{1}{M} \sum _{m=1}^M C^m- \epsilon \log (T),T' \right\rangle + \epsilon {\mathcal {H}}(T'). \end{aligned}$$
(13)

This regularization allows us to take advantage of the Sinkhorn-Knopp solver (Cuturi, 2013), as the problem is similar to Eq. (3) with a cost function modified to take into account the current prior T. Even if \(\epsilon\) is high, the optimization might lead to a solution close to the edge of the polytope with enough iterations, which is not the case with a classical entropy regularization without prior. The time complexity does not increase as it is still \(O\left( (P+ M)N^2\right)\). As this regularization is not specific to our method, we will also use it for \(\hbox {EGW}_{}\) during the experiments to allow a fair comparison. On the other hand, note that this regularization cannot be used with PoGroW as it currently does not seem possible to solve 1D entropy-regularized OT in \(O\left( N\log (N)\right)\) (Cuturi et al., 2019). Note also that the convergence Theorem 1 does not hold anymore with this regularization.
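Following Eq. (13), the SaGroW\(^{KL}\) update can be obtained by feeding a shifted cost to the same Sinkhorn solver sketched in Sect. 2; in the hedged sketch below, the small floor that avoids \(\log(0)\) is our own implementation assumption.

```python
# SaGroW^KL update of Eq. (13): Sinkhorn on the cost shifted by -eps * log(T).
import numpy as np

def sagrow_kl_update(Lambda_hat, T, a, b, eps=1.0, floor=1e-16):
    C_kl = Lambda_hat - eps * np.log(np.maximum(T, floor))  # prior T enters the cost
    return sinkhorn(a, b, C_kl, eps)                         # reuse the earlier Sinkhorn sketch
```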

4.5 Efficient computation of the GW distance from a transport plan

This section introduces and evaluates a low-complexity high-accuracy method for the estimation of \({\mathcal {E}}{}(T)\). Indeed, while SaGroW and PoGroW provide important complexity improvements, one might argue that they only find a good transport plan T and do not provide a value for \({\mathcal {E}}{}(T)\). An exact computation of \({\mathcal {E}}{}(T)\) has a \(O\left( N^4\right)\) time complexity, and it would dominate the complexity of our algorithms in applications where \({\mathcal {E}}{}(T)\) is required, for example when GW is used as a dissimilarity measure between graphs. Additionally, having an efficient way of estimating \({\mathcal {E}}{}(T)\) opens the door to selecting the best transport plan among a set of plans, e.g., obtained by varying the hyper-parameters or the random seed of an algorithm.

We address this issue in this section. Similar to Eq. (9), we propose to interpret the sums in the definition of \({\mathcal {E}}{}(T)\) as the expectation of a random variable \(\mathbf{R }{}\) (this time real-valued instead of matrix-valued, so with a quadruple sum), with \({\mathbb {P}}(\mathbf{R }{} =L_{ijkl}) = T_{ik}T_{jl}\):

$$\begin{aligned} {\mathbb {E}}(\mathbf{R }{})&= \sum _{i,j=1}^{I,I}\sum _{k,l=1}^{K,K} L_{ijkl} T_{ik}T_{jl}. \end{aligned}$$
(14)

Instead of simply sampling this expectation, we propose to stratify over each pair of indices (i, j) to improve the quality of the estimate. Let \(U_i\) be the event "i is chosen for the first dimension of \(L\)" and \(U'_j\) be the event "j is chosen for the second dimension of \(L\)". Based on the marginal a and using the law of total expectation, \({\mathbb {E}}(\mathbf{R }{})\) can be rewritten as:

$$\begin{aligned} {\mathbb {E}}(\mathbf{R }{})&= \sum _{i,j=1}^{I,I} {\mathbb {P}}(U_i \cap U'_j) {\mathbb {E}}(\mathbf{R }{}|U_i \cap U'_j) = \sum _{i,j=1}^{I,I} a_i a_j {\mathbb {E}}(\mathbf{R }{}|U_i \cap U'_j). \end{aligned}$$
(15)

For each (i, j), the conditional expectation is approximated using M samples of a random variable \(X_{ij}\), defined by \({\mathbb {P}}(X_{ij} = L_{ijkl}) = {\mathbb {P}}(\mathbf{R }{} = L_{ijkl} | U_i \cap U'_j ) = T_{ik}T_{jl} / (a_i a_j)\). Finally, \(\hat{\mathbf{R }{}} = \sum _{ij} a_i a_j \frac{1}{M} \sum _{m=1}^{M} X_{ij}^m\) defines an unbiased estimate of the GW distance which can be computed in \(O\left( MN^2\right)\) (details about the variance estimate are provided in the Appendix A.3).
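A hedged sketch of this stratified estimator follows; to keep it short, the M conditional samples attached to each source index are drawn once and reused across pairs (two independent sets, one per role, so that every pair (i, j), including i = j, still receives independent draws of k and l), which preserves unbiasedness but is an implementation choice of ours. Names and the absolute loss are illustrative.

```python
# Stratified, unbiased O(M N^2) estimate of E(T) = sum_{ijkl} L_ijkl T_ik T_jl.
import numpy as np

def estimate_gw_value(Cx, Cy, T, M, rng, loss=lambda u, v: np.abs(u - v)):
    I, K = T.shape
    a = T.sum(axis=1)                          # source marginal of the plan
    cond = T / a[:, None]                      # conditional distributions T[i, :] / a[i]
    # Two independent sets of M conditional draws per source index (roles k and l)
    samp_k = np.stack([rng.choice(K, size=M, p=cond[i]) for i in range(I)])
    samp_l = np.stack([rng.choice(K, size=M, p=cond[i]) for i in range(I)])
    estimate = 0.0
    for i in range(I):
        for j in range(I):
            pairs = loss(Cx[i, j], Cy[samp_k[i], samp_l[j]])   # M samples of X_ij
            estimate += a[i] * a[j] * pairs.mean()
    return estimate
```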

As shown in Fig. 3 (left), the prediction is perfect for a sparse transport plan (\(\epsilon = 0\)), while still being almost perfect and much better than a naive sparse approximation of the OT plan as \(\epsilon\) increases. Fig. 3 (right) confirms that this approximation is clearly faster than the exact computation which becomes quickly intractable as \(N\) grows.

Fig. 3 Estimated value of \({\mathcal {E}}{}(T)\) as sparsity decreases due to an increasing \(\epsilon\) regularization in \(\hbox {EGW}_{}\) (left) and evolution of the time required for its estimation as \(N\) grows (right). The absolute loss is used in these experiments and the distributions take the form of two graphs generated using a gaussian random partition graph (Brandes et al., 2003). For a given \(\epsilon\) and \(N\), the same T (obtained using \(\hbox {EGW}_{}\)) is passed to the three considered methods: Real) an exact one which completely computes \({\mathcal {E}}{}(T)\); Sampled) our sampling method described in Sect. 4.5; and Sparse) a sparse approximation which keeps only the \(2N\) largest values of T and sets the other entries to 0. The mean and 2 standard deviations over 10 runs are displayed on both figures. When the standard deviation is not visible, it corresponds either to a deterministic method or a value very close to 0.

Having at our disposal an efficient method for estimating \({\mathcal {E}}{}(T)\), we can now fully compare, in Table 1, the complexity of the state of the art methods with that of SaGroW and PoGroW, for the general case of an arbitrary loss function. From this table, we have evidence that SaGroW allows a drastic reduction of the algorithmic complexity of \(\hbox {EGW}_{}\). On the other hand, PoGroW fully benefits from the 1D projections. But unlike SGW, it provides a transport plan and does approximate the original GW problem.

Table 1 Complexity of each method with an arbitrary loss function, with \(S\) iterations, \(P\) Sinkhorn iterations, \(N\) points in the dataset and \(M\) matrix samples

5 Experiments

In this section,Footnote 4 we first compare different GW methods on both their speed and their accuracy. We use here the term accuracy to express the capability of the methods to minimize \({\mathcal {E}}{}\)(T). Indeed, as the exact (optimal) GW distance is unknown for a given dataset (solving this problem is known to be NP-hard), the best method will be the one with the smallest value of \({\mathcal {E}}{}(T)\). Then, we analyze the impact of the hyperparameters, illustrating that our approach covers a range of very good trade-offs between speed and accuracy. Using a real graph-classification task, we finally illustrate why being able to solve GW for various loss functions is important.

5.1 General setup and methods

We compare SaGroW\(^{KL}\) and PoGroW with: (I) \(\hbox {EGW}_{}\) (Peyré et al., 2016); (II) \(\hbox {EGW}^{KL}_{}\), a KL regularized version of \(\hbox {EGW}_{}\) described in Xu et al., (2019b); (III) EMD-GW, which is similar to \(\hbox {EGW}_{0}\), but uses the OT solver of (Bonneel et al., 2011) as the Sinkhorn algorithm (Cuturi, 2013) cannot handle a null value for \(\epsilon\); (IV) S-GWL (Xu et al., 2019a), adapted for arbitrary loss functions using the optimizer of Wright (1996) to update the barycenter; (V) SGW when the points are available, with an adaptation to arbitrary losses; (VI) the uniform transport plan, used as a baseline.

While Sect. 5.3 will detail the impact of the hyperparameters, the next section reports, for each method, the results obtained with the set of parameters yielding the lowest GW estimation. To take into account the stochasticity of some methods, the GW estimation for each hyperparameter set is averaged over 10 runs. \(\epsilon\) is chosen among {0.001, 0.005, 0.01, 0.05, 0.1} for \(\hbox {EGW}_{}\) and \(\hbox {EGW}^{KL}_{}\), and in {0.001, 0.01, 0.1, 1, 10, 100} for S-GWL. To have comparable sets of hyperparameters, we fix some of our parameters: in PoGroW, a step of \(\alpha = 0.8\), and in SaGroW, the number of samples \(M=10\) and a KL regularization \(\epsilon = 1\). Experiments in the Appendices B.5 and B.6 show that SaGroW is much less sensitive to \(\epsilon\) than \(\hbox {EGW}_{}\) and that \(\alpha = 0.8\) is a reasonable choice. The number of iterations \(S\) is chosen among {10, 100, 500, 1000} to obtain a reasonable accuracy-speed trade-off.

This experiment compares the quality of the transport plan and the computational time of the methods for an increasing number of points \(N\). Each method minimizes Problem (5) and returns a transport plan T (except SGW, see below). In order to assess the quality of this transport plan, \({\mathcal {E}}{}(T)\) is then computed exactly. Notably, our GW distance approximation (see Sect. 4.5) is not used in this first experiment. The mean and standard deviation of \({\mathcal {E}}{}(T)\) over ten runs are reported.

The loss \({\mathcal {L}}\) chosen here is the absolute loss in order to show the capacity of our methods to deal with any arbitrary loss function. We recall that \(\hbox {EGW}_{}\), S-GWL and SGW are much faster (with speeds comparable to our approach) for some specific losses, such as the square loss (see Appendix B.2 and Sect. 5.4).

To include SGW (which needs points to project) in this comparative study, a first dataset uses \(\mu\) and \(\nu\) that are composed of \(N\) points sampled from two different mixtures of Gaussians. Details about the generation of the datasets are available in the Appendix B.1.

5.2 Speed and accuracy of the GW estimate

Figure 4 shows, in a log-log representation, that \(\hbox {EGW}_{}\) and EMD-GW become quickly intractable when the number of points increases and that S-GWL is slightly faster. We exclude \(\hbox {EGW}^{KL}_{}\) for the clarity of the figure as it has a computational time similar to \(\hbox {EGW}_{}\). SaGroW, PoGroW and SGW behave better, with a quadratic complexity (a slope of 2 in the log-log plot) but with different multiplicative factors (offsets in the log-log plot).

Fig. 4 Computational time of various methods to compute the distance between samples from two mixtures of Gaussians. The mean and the standard deviation over 10 runs are reported

Figure 5 reports the quality of the obtained GW value. Comparing SGW to other methods is complicated as it does not return a transport plan, nor aims at computing an approximation of the GW distance. We thus report the distance it computes and also the same rescaled by a factor 25. With rescaling, we see that SGW seems to behave more like the uniform transport plan than like the GW methods (which produce better-than-uniform plans). While all other methods predict very similar GW distances, \(\hbox {EGW}_{}\)-based methods often have the best accuracy. However, when N reaches 1000 points, we can observe interesting behaviors: \(\hbox {EGW}_{}\) is not able to provide any result, PoGroW is the fastest but with a lower accuracy than S-GWL, and SaGroW provides the best value while being much faster than S-GWL.

Fig. 5 GW distance estimation between samples from two mixtures of Gaussians. The mean and standard deviation over 10 runs are reported for the stochastic methods

In a second series of experiments, we make use of graphs that are generated using a gaussian random partition graph (Brandes et al., 2003). On this more difficult dataset, we see in Fig. 6 that SaGroW is very competitive with the best method \(\hbox {EGW}^{KL}_{}\) while being able to scale to more than 200 nodes, which is the limit for all \(\hbox {EGW}_{}\)-based methods. With more nodes, SaGroW is as accurate as S-GWL but remains much faster and scalable (computation times are similar to the ones from the first dataset). In this experiment, a key factor of success seems to be the KL regularization, used in \(\hbox {EGW}^{KL}_{}\), S-GWL and SaGroW. This can explain why PoGroW stays close to the uniform baseline.

Fig. 6 GW distance estimation between synthetic graphs (Brandes et al., 2003). The mean and standard deviation over 10 runs are reported for the stochastic methods

5.3 Hyperparameters analysis

We now focus on the impact of the numbers of iterations \(S\) and samples \(M\), showing that these allow our approach to cover a variety of trade-offs between speed and accuracy. More experiments (in the Appendix B) consider other parameters such as different loss functions or dataset size. We also study, in this experiment, the impact of the \(\epsilon\) parameter of other methods.

Figure 7 shows that increasing the number of iterations \(S\) yields a strong improvement for SaGroW, independently of the number of samples \(M\). Interestingly, the accuracy of SaGroW is similar regardless of the value of \(M\). This remark supports the key assumption of this paper that the entire computation of the expectation is not needed. The standard deviation displayed in Fig. 7 shows that most runs provide similar GW distances, with enough iterations. However, there is a high variance with fewer iterations, which tends to highlight that the different runs of SaGroW take different paths during the optimization. As shown in Fig. 8, the speed of \(\hbox {EGW}_{}\) and S-GWL does not vary much with \(\epsilon\), but this parameter needs to be chosen carefully for those methods to reach a good accuracy.

Fig. 7 Impact of the number of samples \(M\) and the number of iterations \(S\) for SaGroW on the GW distance estimation and computational time, for two sets of 500 points sampled from two mixtures of Gaussians. The mean and standard deviation over 10 runs are displayed

Fig. 8 Impact of the Kullback-Leibler regularization \(\epsilon\) for \(\hbox {EGW}_{}\) and S-GWL on the GW distance estimation and computational time, for two sets of 500 points sampled from two mixtures of Gaussians

Figure 9 shows that PoGroW is even faster than SaGroW: it can provide a reasonable approximation in a second, compared to the three hours required by \(\hbox {EGW}_{}\). Because PoGroW does not resort to a KL regularization, it is more impacted by stochasticity: two runs can yield very different results. This can be used advantageously by keeping the plan that gives the lowest GW among ten runs (crosses in Fig. 9). The combination of SaGroW and PoGroW thus makes it possible to obtain a good trade-off between speed and accuracy.

Fig. 9 Impact of the number of iterations \(S\) for PoGroW on the GW distance estimation and computational time, for two sets of 500 points sampled from two mixtures of Gaussians. The mean and standard deviation over 10 runs are displayed. To take advantage of the large stochasticity, the minimum over 10 runs is also displayed

Beyond the algorithmic advantages shown above, one last key question remains: is it useful, in an application, to compute the GW distance for other losses than the widely used square loss?

5.4 Graph classification

We illustrate here the usefulness of using different loss functions in a context of graph classification. We use the FIRSTMM-DB graph dataset (Neumann et al., 2013), which has the largest average number of nodes (1377) in the database of (Kersting et al., 2016). Each of the 41 graphs of the dataset describes an object from one of the 11 classes (cup, knife, etc.). The distance matrix of each graph, \(C^{{\mathcal {X}}}\) and \(C^{{\mathcal {Y}}}\), is computed using the shortest path length, similarly to Mémoli (2011). For each method, we compute the pairwise GW distance matrix. Finally, a 1-Nearest-Neighbor classifier is used to predict the class of each graph (using a leave-one-graph-out scheme).
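As an illustration of this pipeline, the sketch below builds the shortest-path cost matrix of each graph and the pairwise GW-distance matrix, then runs a 1-NN leave-one-graph-out classification; `adjacency_list` and `labels` are assumed dataset inputs, and `gw_solver` / `gw_value` stand for any GW solver and value estimator, e.g. the `sagrow` and `estimate_gw_value` sketches from Sect. 4, suitably wrapped.

```python
# Hedged sketch of the graph-classification evaluation (assumed inputs and solver wrappers).
import numpy as np
from scipy.sparse.csgraph import shortest_path

def graph_cost_matrix(adjacency):
    """Shortest-path length between every pair of nodes of one graph."""
    return shortest_path(adjacency, method='D', unweighted=True)

def loo_1nn_accuracy(adjacency_list, labels, gw_solver, gw_value):
    costs = [graph_cost_matrix(A) for A in adjacency_list]
    n = len(costs)
    D = np.zeros((n, n))                                   # pairwise GW-distance matrix
    for p in range(n):
        for q in range(p + 1, n):
            a = np.full(costs[p].shape[0], 1 / costs[p].shape[0])
            b = np.full(costs[q].shape[0], 1 / costs[q].shape[0])
            T = gw_solver(costs[p], costs[q], a, b)
            D[p, q] = D[q, p] = gw_value(costs[p], costs[q], T)
    # 1-NN with a leave-one-graph-out scheme: predict the label of the closest other graph
    labels = np.asarray(labels)
    preds = [labels[np.argmin(np.where(np.arange(n) == g, np.inf, D[g]))]
             for g in range(n)]
    return float(np.mean(np.array(preds) == labels))
```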

Section 5.2 showed that \(\hbox {EGW}_{}\), \(\hbox {EGW}^{KL}_{}\) and S-GWL are very slow with arbitrary loss functions on graphs (with around 1000 nodes). Therefore, we use for them the square loss to allow them to be competitive from a time complexity perspective. We consider ten values for the entropic regularization, \(\epsilon \in [10^{-4}, 10^2]\). SGW is excluded as it is unable to handle graphs. For our methods, we set \(\epsilon = 0.1\) for SaGroW and \(\alpha = 0.8\) for PoGroW and keep \(M= 1\), \(S= 100\) for both methods. However, ten different loss functions \({\mathcal {L}}\) are tested, notably \(|C^{{\mathcal {X}}}_{ij} - C^{{\mathcal {Y}}}_{kl}|^p\) for different values of \(p \in [0.5,3]\).

The results are reported in Table 2. Looking at SaGroW, we see that the classical square loss (\(p=2\)) is outperformed, e.g., by the absolute loss (\(p=1\)), which yields a better classification accuracy. Beyond that, the ability of SaGroW to handle arbitrary losses allows it to reach the best overall accuracy across all the methods. A possible explanation is that the L1 loss is more robust to outlier nodes, which might be important on this real dataset. Note that while \(\hbox {EGW}_{}\) and S-GWL are fast as they are computed with the square loss for \({\mathcal {L}}\), SaGroW is still slightly faster. PoGroW has a competitive accuracy and even outperforms \(\hbox {EGW}_{}\) while being very fast. The complete table with every hyperparameter run is available in the Appendix B.7.

While the goal of this experiment is to correctly classify graphs, we can still compare the GW distances obtained from the transport plans returned by all methods. This comparison only makes sense with the same (square) loss for all methods. Averaged over \(41^2\) distances, SaGroW gets the lowest value of 336, followed by EMD-GW with 341. This highlights the fact that, on a real dataset, the stochasticity used by our method can lead to a better GW distance estimation.

Table 2 Classification accuracy and computation time of various methods on an 11-class graph classification task

6 Conclusion

In this paper, we present both algorithmic and theoretical contributions to address the still open problem of computing the Gromov Wasserstein distance. We propose a method to drastically reduce the time complexity of GW for arbitrary loss functions. To do so, we tackle the bottleneck of the most widely used GW solver, namely \(\hbox {EGW}_{}\), by using a sampling strategy to efficiently approximate the costly sum of \(N^2\) matrices. Our SaGroW algorithm is supported by theoretical convergence guarantees to a stationary point in the general non-convex setting. We also introduce PoGroW, an algorithm which samples only one matrix and allows us to benefit from a very low complexity by using 1D OT. We show that PoGroW overcomes the main issues related to SGW. Experiments on synthetic datasets show that our methods are tractable for a large number of points and offer a good trade-off between speed and accuracy. Finally, a real-world experiment on graph classification illustrates the interest of choosing different loss functions. In order to deal with potential outliers, we show that the absolute loss associated with SaGroW gives the highest classification accuracy. We claim that this capacity to choose ad-hoc loss functions will push the state of the art in various graph applications by unlocking their use with large graphs.