alexisthual / fugw

Scalable Python GPU solvers for fused unbalanced Gromov-Wasserstein optimal transport problems, with routines and examples to align brain data (fMRI)

Home Page: https://alexisthual.github.io/fugw/

License: MIT License

Language: Python
Topics: optimal-transport, python3, brain-alignment

fugw's Introduction

Fused Unbalanced Gromov-Wasserstein for Python


This package implements multiple GPU-compatible PyTorch solvers to the Fused Unbalanced Gromov-Wasserstein optimal transport problem.

This package is under active development. There is no guarantee that the API and solvers won't change in the near future.

Installation

To install this package, make sure you have an up-to-date version of pip.

From PyPI

In a dedicated Python env, run:

pip install fugw

From source

git clone https://github.com/alexisthual/fugw.git
cd fugw

In a dedicated Python env, run:

pip install -e .

Contributors should also install the development dependencies in order to test and automatically format their contributions.

pip install -e ".[dev]"
pre-commit install

Tests run on CPU and GPU, depending on the configuration of your machine. You can run them with:

pytest
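
A minimal usage sketch (not an official quickstart; it assumes the FUGWSolver.solve signature that appears in the issue reports below, and uses random matrices as stand-ins for real feature and geometry data):

import torch
from fugw.solvers import FUGWSolver

# Toy stand-ins: feature cost matrix F and within-space geometry matrices Ds, Dt
n_source, n_target = 100, 150
F = torch.rand(n_source, n_target)
Ds = torch.rand(n_source, n_source)
Dt = torch.rand(n_target, n_target)

solver = FUGWSolver(nits_bcd=10, nits_uot=100)
res = solver.solve(
    alpha=0.5,
    rho_s=1.0,
    rho_t=1.0,
    eps=1e-2,
    reg_mode="independent",
    divergence="kl",
    solver="sinkhorn",
    F=F / F.max(),
    Ds=Ds / Ds.max(),
    Dt=Dt / Dt.max(),
)
pi = res["pi"]  # transport plan, shape (n_source, n_target)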

Citing this work

If this package was useful to you, please cite it in your work:

@article{Thual-2022-fugw,
  title={Aligning individual brains with Fused Unbalanced Gromov-Wasserstein},
  author={Thual, Alexis and Tran, Huy and Zemskova, Tatiana and Courty, Nicolas and Flamary, Rémi and Dehaene, Stanislas and Thirion, Bertrand},
  publisher={arXiv},
  doi={10.48550/ARXIV.2206.09398},
  url={https://arxiv.org/abs/2206.09398},
  year={2022},
  copyright={Creative Commons Attribution 4.0 International}
}

fugw's People

Contributors

6ulm, alexisthual, antoinecollas, bthirion, dimitripapadopoulos, pbarbarant


fugw's Issues

[BUG] Dense/sparse barycenter features contain NaNs

After a few iterations of the dense/sparse barycenter solver, NaNs appear in the computation of the barycenter features. This is due to a zero-division error in the line:

acc = w * pi.T @ f.T / pi.sum(0).reshape(-1, 1)
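
One possible mitigation (a sketch only, not necessarily the fix adopted in the package) is to guard the normalisation against empty columns of the plan before dividing:

import torch

def barycenter_features_guarded(w, pi, f, eps=1e-16):
    # Same formula as above, but the column masses of pi are clamped
    # away from zero and any remaining NaNs are zeroed out.
    col_mass = pi.sum(0).reshape(-1, 1).clamp(min=eps)
    acc = w * pi.T @ f.T / col_mass
    return torch.nan_to_num(acc, nan=0.0)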

Install instructions incomplete?

`pip install -e .` fails on my box

bertrandthirion@ptb-11107357:~/mygit/fugw$ pip install -e .
Defaulting to user installation because normal site-packages is not writeable
Obtaining file:///home/bertrandthirion/mygit/fugw
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
ERROR: Project file:///home/bertrandthirion/mygit/fugw has a 'pyproject.toml' and its build backend is missing the 'build_editable' hook. Since it does not have a 'setup.py' nor a 'setup.cfg', it cannot be installed in editable mode. Consider using a build backend that supports PEP 660.

Error about transform

Hi, I think something may be wrong when computing the mapped result after obtaining the matrix pi.

Please see the transform code, lines 338 to 342 in fugw/src/fugw/mappings/dense.py:

  pi.T
  @ source_features_tensor.T
  / pi.sum(dim=0).reshape(-1, 1)

You use $\pi^T \cdot S^T$.

But the formula should be $\pi \cdot T_t$ (the target data), not the source data. Please check the transform code from POT:

  transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]

  # set nans to 0
  transp = nx.nan_to_num(transp, nan=0, posinf=0, neginf=0)

  # compute transported samples
  transp_Xs = nx.dot(transp, self.xt_)

I can show you proofs based on both the application and the theory.

Proof 1 - application

Here is an example based on your example "Transport distributions using dense solvers".

After training and obtaining pi, you can plot the training points to compare them with the mapped points:

# modified from transformed_data = mapping.transform(source_features_test)
transformed_data_train = mapping.transform(source_features_train)

fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot()
ax.set_title("Source and target features")
ax.set_aspect("equal", "datalim")
ax.scatter(source_features_train[0], source_features_train[1], label="Source")
ax.scatter(target_features_train[0], target_features_train[1], label="Target")
ax.scatter(transformed_data_train[0], transformed_data_train[1], label="trans")
ax.legend()
plt.show()

The plot will be like:

(plot omitted: the transformed points overlap the source points)

You can see that the mapped data is actually close to the source data.
If you instead use the POT way,

mapped_data_train = np.dot(pi, target_features_train.T) / pi.sum(dim=1).reshape(-1, 1)

fig = plt.figure(figsize=(4, 4))
ax = fig.add_subplot()
ax.set_title("Source and target features")
ax.set_aspect("equal", "datalim")
ax.scatter(source_features_train[0], source_features_train[1], label="Source")
ax.scatter(target_features_train[0], target_features_train[1], label="Target")
ax.scatter(mapped_data_train.T[0], mapped_data_train.T[1], label="trans")
ax.legend()
plt.show()

Then the plot will be:

(plot omitted: the transformed points overlap the target points)

You can see that the mapped data is close to the target data.

Proof 2 - theory

Here I show why the result does not make sense.

We assume:

$S_s$ denotes the source data in the source space. Its shape is [3000, 50], i.e. 3000 points with 50 features each, living in a 50-dim space.

$T_t$ denotes the target data in the target space. Its shape is [9000, 100], i.e. 9000 points with 100 features each, living in a 100-dim space.

The OT matrix $\pi$ has shape [3000, 9000].

POT code

From the POT code, if we want to get the mapped source data in the target space, $S_t$, we can use:

$$S_t = \pi \cdot T_t$$

The shape of $S_t$ will be [3000, 100]; in terms of the shapes in the formula above:
$$[3000, 100] = [3000, 9000] \cdot [9000, 100]$$

The source data is mapped from shape [3000, 50] in the 50-dim space to [3000, 100] in the 100-dim space; the number of points does not change. Each point simply moves from the 50-dim space to the 100-dim space.

So the interpretation of the OT algorithm is:

the OT algorithm maps the data from the source space to the target space, without changing the number of points.

FUGW code

According to the FUGW code, if we want to get the mapped source data in the target space, $S_t$, we would use:

$$S_t = \pi^T \cdot S_s$$

The resulting shape will be [9000, 50]; in terms of the shapes in the formula:
$$[9000, 50] = [9000, 3000] \cdot [3000, 50]$$

So the source data of shape [3000, 50] in the 50-dim space is mapped to [9000, 50], which is still in the 50-dim space; the data does not end up in the 100-dim target space! It does not make sense!
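
A quick shape check with hypothetical random arrays (sizes taken from the assumptions above) makes the mismatch concrete:

import numpy as np

S_s = np.random.rand(3000, 50)   # source data, 50-dim feature space
T_t = np.random.rand(9000, 100)  # target data, 100-dim feature space
pi = np.random.rand(3000, 9000)  # transport plan

print((pi @ T_t).shape)    # (3000, 100): source points mapped into the target space
print((pi.T @ S_s).shape)  # (9000, 50): still in the source feature space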

Please let me know if I was wrong :)

Btw, thanks a lot for the contribution to FUGW, it helps me a lot.

Make functions in `solvers/utils.py` private.

> I agree, most functions in `solvers/utils.py` are not meant to be public (because they are called through `fugw.mappings` instances, which themselves call `fugw.solvers` instances).
>
> I added a docstring, and suggest we hide all non-public functions in a follow-up PR.

Originally posted by @alexisthual in #29 (comment)

RuntimeError: sparse tensors do not have strides

Hi, I tried using the sparse version with the code in test_sparse_transformer.py. At the fugw.transform() step I get the following error.

transformed_data = fugw.transform(source_features_test)
File "...fugw-main\src\fugw\sparse.py", line 213, in transform
torch.sparse.mm(self.pi.T, source_features_torch.T).to_dense()
RuntimeError: sparse tensors do not have strides

I'm using Windows, conda, and Python 3.6; I don't have a CUDA graphics card, so I'm using the CPU-only build of torch.

Fix the case alpha = 0

When running the dense/sparse solvers with alpha = 0, the following error occurs:

  File "/data/parietal/store3/work/pbarbara/fugw/src/fugw/solvers/dense.py", line 425, in solve
    current_loss = compute_fugw_loss(pi, gamma)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/parietal/store3/work/pbarbara/fugw/src/fugw/solvers/dense.py", line 166, in fugw_loss
    "gromov_wasserstein": loss_gromov_wasserstein.item(),
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'item'
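
A possible guard (sketch only, not necessarily how the package resolves the issue) is to report the Gromov-Wasserstein term only when it was actually computed:

def gw_loss_entry(loss_gromov_wasserstein):
    # When alpha == 0 the Gromov-Wasserstein term is never computed,
    # so return 0.0 instead of calling .item() on None.
    if loss_gromov_wasserstein is None:
        return 0.0
    return loss_gromov_wasserstein.item()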

[BUG] Remove res["pi"].detach().cpu()

In fugw.mapping:

# Store variables of interest in model
self.pi = res["pi"].detach().cpu()

This causes problems when dealing with barycenters, as the transport plans get detached before the next barycenter iteration. Discovered with @antoinecollas.
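
A possible direction (illustrative sketch only, using a hypothetical class) is to keep the plan attached on its original device and expose a detached CPU copy only when exporting:

class MappingSketch:
    # Hypothetical illustration: store the live transport plan so that
    # barycenter iterations can keep using it, and only detach it when
    # the caller explicitly asks for a copy.
    def fit(self, res):
        self.pi = res["pi"]
        return self

    def plan_numpy(self):
        return self.pi.detach().cpu().numpy()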

[TEST] Coarse-to-fine test fails on GPU

Two tests fail on coarse-to-fine:

FAILED tests/scripts/test_coarse_to_fine.py::test_coarse_to_fine[device2-False] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
FAILED tests/scripts/test_coarse_to_fine.py::test_coarse_to_fine[device3-True] - TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
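
The usual remedy for this error (sketched here on a toy tensor, not the actual test code) is to move the tensor to host memory before converting:

import torch

t = torch.rand(3, device="cuda" if torch.cuda.is_available() else "cpu")
arr = t.detach().cpu().numpy()  # works for both CPU and GPU tensors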

[DOC] Fix documentation that relies on the dense mapping

Following #50, examples that use the dense mapping no longer compile, as they pipe in the whole cost matrix instead of the factored embedding.

Examples include:

01_brain_alignment/plot_1_aligning_brain_dense.py
00_basics/plot_1_dense.py
02_miscellaneous/plot_3_memory_usage_callback.py
01_brain_alignment/plot_3_aligning_low_res_volumes.py
02_miscellaneous/plot_1_pearson_correlation.py
02_miscellaneous/plot_2_simple_crossval.py

There are NaNs in the coupling

    import torch
    from fugw.solvers import FUGWSolver

    # Ds, Dt, c_z and the hyperparameters used below come from my own data and settings
    torch.manual_seed(1)
    torch.backends.cudnn.benchmark = True

    nits_bcd = 100
    eval_bcd = 2        
    Ds_normalized = Ds / Ds.max()
    Dt_normalized = Dt / Dt.max()
    F_normalized = c_z / c_z.max()

    fugw = FUGWSolver(
        nits_bcd=nits_bcd,
        nits_uot=1000,
        tol_bcd=1e-7,
        tol_uot=1e-7,
        tol_loss=1e-5,
        eval_bcd=eval_bcd,
        eval_uot=10,
        ibpp_eps_base=1e5,
    )

    divergence="kl"
    
    reg_mode="independent"
    
    solver="sinkhorn"
    
    res = fugw.solve(
        alpha=alpha,
        rho_s=rho_s,
        rho_t=rho_t,
        eps=eps,
        reg_mode=reg_mode,
        divergence=divergence,
        F=F_normalized,
        Ds=Ds_normalized,
        Dt=Dt_normalized,
        init_plan=None,
        solver=solver,
        callback_bcd=None,
        verbose=False,
    )
    
Why would the coupling contain NaNs? In what cases does this happen easily?
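
One way to narrow this down (a hedged diagnostic sketch, not an official answer) is to inspect the returned plan; in entropic solvers, a very small eps relative to the scale of the normalised cost matrices is a common source of numerical underflow and NaNs:

import torch

pi = res["pi"]
print("NaNs in plan:", torch.isnan(pi).any().item())
print("total mass:", pi.sum().item())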

undo parameters concatenation

> I agree, this concatenation of parameters is difficult to understand. It is somewhat consistent across solvers, but I think we should flatten these variables. Once again, I think this would fit more naturally in a follow-up PR.

Originally posted by @alexisthual in #29 (comment)

Test barycenter edge cases with POT

> Maybe we could test edge cases (although this may require using external libraries):
> • one could test that alpha=0 returns Wasserstein barycenters similar to those of POT
> • similarly, one could test that alpha=1 returns Gromov barycenters similar to those of POT

Originally posted by @alexisthual in #45 (comment)
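
A sketch of what the POT side of such a comparison could look like (toy random data; the fugw side and the actual assertions are omitted, since the barycenter API here may differ):

import numpy as np
import ot

rng = np.random.default_rng(0)
n = 30

# Entropic Wasserstein barycenter of two histograms on a common support
# (what the alpha=0 case could be checked against).
support = rng.normal(size=(n, 2))
M = ot.dist(support, support)
M /= M.max()
a = rng.random(n); a /= a.sum()
b = rng.random(n); b /= b.sum()
A = np.vstack([a, b]).T  # one histogram per column
bary_weights = ot.bregman.barycenter(A, M, reg=1e-2)

# Gromov-Wasserstein barycenter of two geometries
# (what the alpha=1 case could be checked against).
C1 = ot.dist(rng.normal(size=(n, 2)))
C2 = ot.dist(rng.normal(size=(n, 2)))
bary_geometry = ot.gromov.gromov_barycenters(
    n, [C1, C2],
    ps=[ot.unif(n), ot.unif(n)],
    p=ot.unif(n),
    lambdas=[0.5, 0.5],
    loss_fun="square_loss",
)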
