oskopek / mvae

Stars: 58, Watchers: 4, Forks: 13, Size: 249 KB

Mixed-curvature Variational Autoencoders (ICLR 2020)

Home Page: https://openreview.net/forum?id=S1g6xeSKDS

License: Apache License 2.0

Languages: Makefile 0.09%, Python 97.24%, Shell 2.67%
Topics: vae, variational-autoencoder, variational-inference, curvature, machine-learning, deep-learning, iclr, iclr2020, eth-zurich, non-euclidean-geometry

mvae's People

Contributors

macio232, oskopek

mvae's Issues

Does `--fixed_curvature=True` work?

Describe the bug
Despite setting the --fixed_curvature=True flag, I can see that the curvature changes during training.

To Reproduce
Steps to reproduce the behavior:

  1. Set up your environment as described in the README
  2. Run experiment with python -m mt.examples.run --dataset="mnist" --model="s10" --fixed_curvature=True
  3. Observe that comp_000_s10/curvature changes

Expected behavior
Curvature should be constant and set to the default value, which is 1.

Desktop:

  • OS: Linux
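The expected behavior can be illustrated with a toy example. In PyTorch, a fixed curvature would usually mean the curvature tensor has `requires_grad` set to False or is left out of the optimizer's parameter list. The plain-Python sketch below models that with a hypothetical `trainable` flag and a hand-written SGD step; `Param`, `sgd_step` and `make_params` are illustrative names, not part of mvae.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class Param:
    """A scalar parameter; `trainable` plays the role of requires_grad (hypothetical)."""
    value: float
    trainable: bool = True


def sgd_step(params: Dict[str, Param], grads: Dict[str, float], lr: float) -> None:
    """Apply one plain SGD update, skipping parameters that are not trainable."""
    for name, p in params.items():
        if p.trainable:
            p.value -= lr * grads.get(name, 0.0)


def make_params(fixed_curvature: bool) -> Dict[str, Param]:
    # The curvature starts at the default value 1 (as stated in the issue)
    # and is frozen when fixed_curvature is True.
    return {
        "weight": Param(0.5),
        "curvature": Param(1.0, trainable=not fixed_curvature),
    }


if __name__ == "__main__":
    params = make_params(fixed_curvature=True)
    for _ in range(3):
        sgd_step(params, {"weight": 0.1, "curvature": 0.2}, lr=0.1)
    print(params["curvature"].value)  # 1.0: never updated
```

Even though a gradient for the curvature is supplied on every step, the frozen parameter keeps its initial value, which is what the reporter expected `--fixed_curvature=True` to guarantee.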

Error in demo: FileNotFoundError

Thanks for the package!

When running the following on macOS High Sierra: python -m mt.examples.run --dataset="mnist" --model="h2,s2,e2" --fixed_curvature=False

I got an error after 100 epochs: FileNotFoundError: [Errno 2] No such file or directory: './chkpt/vae-mnist-e2,h2,s2-2020-06-14T17:45:14.337416/None.chkpt'

Thanks in advance!
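The file name in the error ends in `None.chkpt`. One plausible reading (an inference from the path, not something the issue confirms) is that the checkpoint name was formatted from a variable that was still `None`, for example because no checkpoint had been saved under the expected name. A guard like the hypothetical helper below fails early with a clearer message; `checkpoint_path` and its arguments are illustrative and not part of mvae.

```python
import os
from typing import Optional


def checkpoint_path(chkpt_dir: str, tag: Optional[object]) -> str:
    """Build '<chkpt_dir>/<tag>.chkpt', refusing to format a missing tag (hypothetical helper)."""
    if tag is None:
        # Formatting None with an f-string would silently yield '.../None.chkpt'.
        raise ValueError(
            f"no checkpoint tag available in {chkpt_dir!r}; was a checkpoint ever saved?"
        )
    return os.path.join(chkpt_dir, f"{tag}.chkpt")


if __name__ == "__main__":
    print(checkpoint_path("./chkpt/run", 99))
    try:
        checkpoint_path("./chkpt/run", None)
    except ValueError as err:
        print(err)
```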

Exponential map of Euclidean manifold

Hi, thank you for open-sourcing the repo! Is there a particular reason why the exponential map of the Euclidean manifold is written as x+v/2 instead of x+v?
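For reference, with the standard Euclidean metric the exponential map is exp_x(v) = x + v and the logarithmic map is log_x(y) = y - x, so the two are inverses. The sketch below (plain Python tuples, not the repository's implementation) checks that round trip and shows that x + v/2 recovers only half of the tangent vector. The issue does not say why the repository uses the factor 1/2.

```python
from typing import Tuple

Vec = Tuple[float, ...]


def euclidean_exp(x: Vec, v: Vec) -> Vec:
    """exp_x(v) = x + v: follow the straight-line geodesic from x with velocity v for unit time."""
    return tuple(xi + vi for xi, vi in zip(x, v))


def euclidean_log(x: Vec, y: Vec) -> Vec:
    """log_x(y) = y - x: the tangent vector at x that points to y."""
    return tuple(yi - xi for xi, yi in zip(x, y))


def halved_exp(x: Vec, v: Vec) -> Vec:
    """The x + v/2 variant asked about in the issue, for comparison only."""
    return tuple(xi + vi / 2 for xi, vi in zip(x, v))


if __name__ == "__main__":
    x, v = (1.0, 2.0), (0.5, -1.0)
    print(euclidean_log(x, euclidean_exp(x, v)))  # (0.5, -1.0): round trip recovers v
    print(euclidean_log(x, halved_exp(x, v)))     # (0.25, -0.5): only half of v
```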

Pre-trained embeddings for those four datasets

Dear authors,

Thanks for the great code! I am wondering whether the pre-trained embeddings for those four datasets are included in the repository, since I would like to visualize them. It seems like they should be in mt/data/, but I only found four scripts there.

Also, when I use mt/examples/run.py to generate the embeddings myself with the following parameters:
--dataset="mnist" --model="h2,s2,e2" --fixed_curvature=True --epochs=50 --warmup=10 --lookahead=5 --show_embeddings=10
only the embeddings from the last epoch (49) are saved in the chkpt/ folder, instead of every 10 epochs. I looked into the code a little and found that show_embeddings is only used in _test_epoch, which is called only once, in train_stopping. Maybe I am missing something, but it seems that only the last embeddings will be saved.
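The behavior the reporter expected (embeddings written every `show_embeddings` epochs) could look like the loop below. This is a hypothetical sketch: `train_epoch` and `save_embeddings` are placeholder callbacks, not mvae's `_train_epoch` or `_test_epoch`, and treating `show_embeddings <= 0` as "never save" is an assumption.

```python
from typing import Callable, List


def train_with_periodic_embeddings(
    epochs: int,
    show_embeddings: int,
    train_epoch: Callable[[int], None],
    save_embeddings: Callable[[int], None],
) -> List[int]:
    """Train for `epochs` epochs, saving embeddings every `show_embeddings` epochs.

    Returns the 0-indexed epochs after which embeddings were saved.
    """
    saved: List[int] = []
    for epoch in range(epochs):
        train_epoch(epoch)  # placeholder for one training pass
        is_last = epoch == epochs - 1
        # Assumption: show_embeddings <= 0 disables saving entirely.
        if show_embeddings > 0 and ((epoch + 1) % show_embeddings == 0 or is_last):
            save_embeddings(epoch)  # placeholder for evaluating and writing embeddings
            saved.append(epoch)
    return saved


if __name__ == "__main__":
    print(train_with_periodic_embeddings(50, 10, lambda e: None, lambda e: None))
    # [9, 19, 29, 39, 49]
```

With `epochs=50` and `show_embeddings=10`, this saves after epochs 9, 19, 29, 39 and 49 (0-indexed). Always saving after the final epoch keeps the last embeddings even when `epochs` is not a multiple of `show_embeddings`.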

Thanks for your help in advance!

torch.jit.frontend.UnsupportedNodeError: break statements aren't supported

Describe the bug
Running the example from the README gives the following error:

Traceback (most recent call last):
  File "/home/maciej/.conda/envs/pt/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/maciej/.conda/envs/pt/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/maciej/Documents/mvae/mt/examples/run.py", line 22, in <module>
    from ..data import create_dataset
  File "/home/maciej/Documents/mvae/mt/data/__init__.py", line 18, in <module>
    from .image_reconstruction import CifarVaeDataset, MnistVaeDataset, OmniglotVaeDataset
  File "/home/maciej/Documents/mvae/mt/data/image_reconstruction.py", line 24, in <module>
    from ..mvae.distributions import EuclideanUniform
  File "/home/maciej/Documents/mvae/mt/mvae/distributions/__init__.py", line 16, in <module>
    from .hyperspherical_uniform import HypersphericalUniform
  File "/home/maciej/Documents/mvae/mt/mvae/distributions/hyperspherical_uniform.py", line 24, in <module>
    from ..ops.common import ln_2, ln_pi
  File "/home/maciej/Documents/mvae/mt/mvae/ops/__init__.py", line 18, in <module>
    from .poincare import PoincareBall
  File "/home/maciej/Documents/mvae/mt/mvae/ops/poincare.py", line 18, in <module>
    import geoopt.manifolds.poincare.math as pm
  File "/home/maciej/.local/lib/python3.7/site-packages/geoopt/__init__.py", line 1, in <module>
    from . import manifolds
  File "/home/maciej/.local/lib/python3.7/site-packages/geoopt/manifolds/__init__.py", line 6, in <module>
    from .birkhoff_polytope import BirkhoffPolytope
  File "/home/maciej/.local/lib/python3.7/site-packages/geoopt/manifolds/birkhoff_polytope.py", line 203, in <module>
    x, max_iter: int = 300, eps: float = 1e-5, tol: float = 1e-5
  File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/__init__.py", line 823, in script
    ast = get_jit_def(obj)
  File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 158, in get_jit_def
    return build_def(ctx, py_ast.body[0], type_line, self_name)
  File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 198, in build_def
    build_stmts(ctx, body))
  File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 122, in build_stmts
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 122, in <listcomp>
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 174, in __call__
    return method(ctx, node)
  File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 319, in build_While
    build_stmts(ctx, stmt.body))
  File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 122, in build_stmts
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 122, in <listcomp>
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 174, in __call__
    return method(ctx, node)
  File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 332, in build_If
    build_stmts(ctx, stmt.body),
  File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 122, in build_stmts
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 122, in <listcomp>
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 173, in __call__
    raise UnsupportedNodeError(ctx, node)
torch.jit.frontend.UnsupportedNodeError: break statements aren't supported
    x, max_iter: int = 300, eps: float = 1e-5, tol: float = 1e-5
):
    iter = 0
    c = 1.0 / (x.sum(dim=-2, keepdim=True) + eps)
    r = 1.0 / ((x @ c.transpose(-1, -2)) + eps)
    while iter < max_iter:
        iter += 1
        cinv = torch.matmul(r.transpose(-1, -2), x)
        if torch.max(torch.abs(cinv * c - 1)) <= tol:
            break
            ~~~~~ <--- HERE
        c = 1.0 / (cinv + eps)
        r = 1.0 / ((x @ c.transpose(-1, -2)) + eps)
    return x * (r @ c)

To Reproduce
Steps to reproduce the behavior:

  1. Set up your environment as described in the README
  2. Run python -m mt.examples.run --dataset="mnist" --model="h2,s2,e2" --fixed_curvature=False

Desktop:

  • OS: Linux (no CUDA)
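The traceback shows TorchScript, in the PyTorch version installed here, rejecting the `break` inside a Sinkhorn-style loop in geoopt's `birkhoff_polytope.py`, which is compiled with `torch.jit.script` at import time. This likely points to a version mismatch between geoopt and PyTorch; the issue does not say which fix the authors recommend. To illustrate that the loop itself does not need `break`, here is the same iteration in plain Python (nested lists instead of tensors) with the exit condition folded into the `while` test through a `converged` flag. It is a sketch of the control flow, not a patch for geoopt.

```python
from typing import List

Matrix = List[List[float]]


def sinkhorn_without_break(
    x: Matrix, max_iter: int = 300, eps: float = 1e-5, tol: float = 1e-5
) -> Matrix:
    """Scale a positive matrix towards a doubly stochastic one, mirroring the loop in the traceback."""
    n_rows, n_cols = len(x), len(x[0])

    def row_factors(c: List[float]) -> List[float]:
        # r_i = 1 / (sum_j x_ij * c_j + eps), i.e. 1 / ((x @ c^T) + eps)
        return [1.0 / (sum(x[i][j] * c[j] for j in range(n_cols)) + eps) for i in range(n_rows)]

    # c_j = 1 / (sum_i x_ij + eps), i.e. 1 / (x.sum(dim=-2) + eps)
    c = [1.0 / (sum(x[i][j] for i in range(n_rows)) + eps) for j in range(n_cols)]
    r = row_factors(c)
    it = 0
    converged = False
    # The convergence test lives in the loop condition instead of a `break`.
    while it < max_iter and not converged:
        it += 1
        # cinv_j = sum_i r_i * x_ij, i.e. r^T @ x
        cinv = [sum(r[i] * x[i][j] for i in range(n_rows)) for j in range(n_cols)]
        if max(abs(cinv[j] * c[j] - 1.0) for j in range(n_cols)) <= tol:
            converged = True  # same exit point as the original `break`
        else:
            c = [1.0 / (cinv[j] + eps) for j in range(n_cols)]
            r = row_factors(c)
    # x * (r @ c): scale row i by r_i and column j by c_j
    return [[x[i][j] * r[i] * c[j] for j in range(n_cols)] for i in range(n_rows)]


if __name__ == "__main__":
    for row in sinkhorn_without_break([[1.0, 2.0], [3.0, 4.0]]):
        print(row, sum(row))
```

Because the `converged` branch skips the update exactly where the original code would `break`, the returned scaling is the same as that of the loop shown in the error message.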

Deterministic behavior in eval mode

Hello!

Great code, thank you!

However, I have a small question: usually, when a VAE is in eval mode, we generate outputs using only the mean (i.e., we don't sample from the latent distribution and simply pass the encoder's output for x to the decoder). Is it possible to enforce this behavior in your models?

Thanks in advance!
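A common way to get this behavior is to branch on the module's training flag: sample from q(z|x) during training and return its location parameter in eval mode. In PyTorch, `model.eval()` sets `training` to False on the module and all its submodules. The toy class below imitates that flag in plain Python with a Euclidean Gaussian latent; `ToyVAE` and its callbacks are hypothetical. For hyperbolic and spherical components, the analogue of "the mean" would be the distribution's location on the manifold rather than a Euclidean average, and the issue does not say whether mvae exposes such a switch.

```python
import random
from typing import Callable, List, Tuple

Vector = List[float]


class ToyVAE:
    """Hypothetical Gaussian VAE wrapper: stochastic latents in training, mean latents in eval."""

    def __init__(
        self,
        encode: Callable[[Vector], Tuple[Vector, Vector]],  # returns (mean, std)
        decode: Callable[[Vector], Vector],
    ) -> None:
        self.encode = encode
        self.decode = decode
        self.training = True  # mirrors torch.nn.Module.training

    def train(self) -> "ToyVAE":
        self.training = True
        return self

    def eval(self) -> "ToyVAE":
        self.training = False
        return self

    def latent(self, x: Vector) -> Vector:
        mean, std = self.encode(x)
        if not self.training:
            return list(mean)  # deterministic: skip sampling, use the mean
        # Reparametrization trick: z = mean + std * noise, noise ~ N(0, 1)
        return [m + s * random.gauss(0.0, 1.0) for m, s in zip(mean, std)]

    def forward(self, x: Vector) -> Vector:
        return self.decode(self.latent(x))


if __name__ == "__main__":
    vae = ToyVAE(lambda x: ([sum(x)], [1.0]), lambda z: [2.0 * z[0]]).eval()
    print(vae.forward([1.0, 2.0]), vae.forward([1.0, 2.0]))  # identical outputs
```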

RuntimeError when using the cifar dataset

Describe the bug
The following RuntimeError appears when I try to run run.py on the cifar dataset.

To Reproduce
Steps to reproduce the behavior:

  1. Set up your environment: conda activate pt
  2. Run experiment with python -m mt.examples.run --dataset="cifar" --model="e2,h2,s2" --fixed_curvature=False --epochs=200
  3. See error:

Running on: cuda
VAE Model: e2,h2,s2; Epochs: 200; Time: 2021-01-26T16:19:41.074145; Fixed curvature: False; Dataset: cifar
TrainEpoch 0: Traceback (most recent call last):
  File "/home/chaopan2/miniconda3/envs/pt/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/chaopan2/miniconda3/envs/pt/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/chaopan2/simulations/Py/p3/mvae/mt/examples/run.py", line 169, in <module>
    main()
  File "/home/chaopan2/simulations/Py/p3/mvae/mt/examples/run.py", line 162, in main
    max_epochs=args.epochs)
  File "/home/chaopan2/simulations/Py/p3/mvae/mt/mvae/models/train.py", line 118, in train_stopping
    train_results[self.epoch] = self._train_epoch(optimizer, train_data, beta=beta)
  File "/home/chaopan2/simulations/Py/p3/mvae/mt/mvae/models/train.py", line 188, in _train_epoch
    stats, (reparametrized, _, ) = self.model.train_step(optimizer, x_mb, beta=beta)
  File "/home/chaopan2/simulations/Py/p3/mvae/mt/mvae/models/vae.py", line 154, in train_step
    reparametrized, concat_z, x_mb = self(x_mb)
  File "/home/chaopan2/miniconda3/envs/pt/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/chaopan2/simulations/Py/p3/mvae/mt/mvae/models/vae.py", line 74, in forward
    q_z, p_z, _ = component(x_encoded)
  File "/home/chaopan2/miniconda3/envs/pt/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/chaopan2/simulations/Py/p3/mvae/mt/mvae/components/component.py", line 33, in forward
    z_params = self.encode(x)
  File "/home/chaopan2/simulations/Py/p3/mvae/mt/mvae/components/component.py", line 64, in encode
    z_mean = self.fc_mean(x)
  File "/home/chaopan2/miniconda3/envs/pt/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/chaopan2/miniconda3/envs/pt/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 92, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/chaopan2/miniconda3/envs/pt/lib/python3.7/site-packages/torch/nn/functional.py", line 1406, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: size mismatch, m1: [100 x 8192], m2: [400 x 2] at /opt/conda/conda-bld/pytorch_1556653215914/work/aten/src/THC/generic/THCTensorMathBlas.cu:268

Desktop:

  • OS: CentOS 7.7.1908
  • PyTorch version 1.1.0
  • Python version 3.7.3
  • Other Python packages installed as listed in the requirements
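Reading the error: in this PyTorch version, `F.linear` computes `torch.addmm(bias, input, weight.t())` (visible in the traceback), so `m1` is the input batch (100 x 8192) and `m2` is the transposed weight of `fc_mean` (400 x 2). The inner dimensions must agree, so `fc_mean` was built for 400 input features while the encoder output it receives on CIFAR has 8192 features per example. The plain-Python helper below (hypothetical, not part of mvae or PyTorch) reproduces that shape check; it does not explain where in the CIFAR configuration the 400 or the 8192 comes from.

```python
from typing import Tuple


def addmm_output_shape(m1: Tuple[int, int], m2: Tuple[int, int]) -> Tuple[int, int]:
    """Shape of m1 @ m2; raises with a message in the same format as the error above."""
    rows, inner_1 = m1
    inner_2, cols = m2
    if inner_1 != inner_2:
        raise ValueError(f"size mismatch, m1: [{rows} x {inner_1}], m2: [{inner_2} x {cols}]")
    return (rows, cols)


if __name__ == "__main__":
    print(addmm_output_shape((100, 400), (400, 2)))  # (100, 2): what fc_mean expects
    try:
        addmm_output_shape((100, 8192), (400, 2))    # what the CIFAR run fed it
    except ValueError as err:
        print(err)
```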
