Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion src/cfp/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Any, Literal

import jax
import jax.numpy as jnp
from ott.geometry import costs, pointcloud
from ott.geometry import costs, geometry, graph, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

Expand Down Expand Up @@ -61,3 +62,80 @@ def match_linear(
solver = sinkhorn.Sinkhorn(threshold=threshold, **kwargs)
out = solver(problem)
return out.matrix


def _get_nearest_neighbors(
X: jnp.ndarray, Y: jnp.ndarray | None = None, k: int = 30
) -> tuple[jnp.ndarray, jnp.ndarray]:
concat = X if Y is None else jnp.concatenate((X, Y), axis=0)
pairwise_euclidean_distances = pointcloud.PointCloud(concat, concat).cost_matrix
distances, indices = jax.lax.approx_min_k(
pairwise_euclidean_distances, k=k, recall_target=0.95, aggregate_to_topk=True
)
connectivities = jnp.multiply(jnp.exp(-distances), (distances > 0))
return connectivities / jnp.sum(connectivities), indices


def _create_cost_matrix_lin(
X: jnp.array,
Y: jnp.array,
k_neighbors: int,
) -> jnp.array:
distances, indices = _get_nearest_neighbors(X, Y, k_neighbors)
a = jnp.zeros((len(X) + len(Y), len(X) + len(Y)))
adj_matrix = a.at[
jnp.repeat(jnp.arange(len(X) + len(Y)), repeats=k_neighbors).flatten(),
indices.flatten(),
].set(distances.flatten())
return graph.Graph.from_graph(
adj_matrix,
normalize=True,
).cost_matrix[: len(X), len(X) :]


def match_linear_geodesic(
source_batch: jnp.ndarray,
target_batch: jnp.ndarray,
epsilon: float = 1e-3,
scale_cost: ScaleCost_t = "mean",
tau_a: float = 1.0,
tau_b: float = 1.0,
k_neighbors: int | None = None,
threshold: float | None = None,
**kwargs,
) -> jnp.ndarray:
"""Compute the OT coupling based on a geodesic distance between source and target batch.

Parameters
----------
source_batch
Source point cloud of shape ``[n, d]``.
target_batch
Target point cloud of shape ``[m, d]``.
epsilon
Regularization parameter.
scale_cost
Scaling of the cost matrix.
tau_a
Parameter in :math:`(0, 1]` that defines how unbalanced the problem is
in the source distribution. If :math:`1`, the problem is balanced in the source distribution.
tau_b
Parameter in :math:`(0, 1]` that defines how unbalanced the problem is in the target
distribution. If :math:`1`, the problem is balanced in the target distribution.
threshold
Convergence criterion for the Sinkhorn algorithm.
kwargs
Additional arguments for :class:`ott.solvers.linear.Sinkhorn`.

Returns
-------
Geodesic distance between the two point clouds.
"""
if threshold is None:
threshold = 1e-3 if (tau_a == 1.0 and tau_b == 1.0) else 1e-2
k_neighbors = len(source_batch) + 1 if k_neighbors is None else k_neighbors
cm = _create_cost_matrix_lin(source_batch, target_batch, k_neighbors, **kwargs)
geom = geometry.Geometry(cost_matrix=cm, epsilon=epsilon, scale_cost=scale_cost)
solver = sinkhorn.Sinkhorn(threshold=threshold, **kwargs)
out = solver(linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b))
return out.matrix