oskopek / mvae
Mixed-curvature Variational Autoencoders (ICLR 2020)
Home Page: https://openreview.net/forum?id=S1g6xeSKDS
License: Apache License 2.0
Describe the bug
Despite setting the --fixed_curvature=True flag, I can see that the curvature changes during training.
To Reproduce
Steps to reproduce the behavior:
python -m mt.examples.run --dataset="mnist" --model="s10" --fixed_curvature=True
The value of comp_000_s10/curvature changes during training.
Expected behavior
Curvature should be constant and set to the default value, which is 1.
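The expected behavior can be illustrated with a minimal sketch in which a parameter marked as fixed is skipped by the update step. The `Param` class, the `trainable` field, and `sgd_step` are hypothetical names for illustration only; this is not the repository's training code.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class Param:
    """Hypothetical scalar parameter with a flag that marks it as trainable."""
    value: float
    trainable: bool = True


def sgd_step(params: Dict[str, Param], grads: Dict[str, float], lr: float = 0.1) -> None:
    """Apply one gradient step, leaving non-trainable parameters untouched."""
    for name, param in params.items():
        if param.trainable:
            param.value -= lr * grads[name]


# Assumed setup: the curvature is fixed, one ordinary weight is trainable.
params = {"curvature": Param(1.0, trainable=False), "w": Param(0.5)}
sgd_step(params, {"curvature": 2.0, "w": 1.0})
```

With a fixed curvature, `params["curvature"].value` stays at its initial value no matter what gradient is reported for it, while `w` is updated as usual.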
Thanks for the package!
When running (Mac Hierra):
python -m mt.examples.run --dataset="mnist" --model="h2,s2,e2" --fixed_curvature=False
I got an error after 100 epochs:
FileNotFoundError: [Errno 2] No such file or directory: './chkpt/vae-mnist-e2,h2,s2-2020-06-14T17:45:14.337416/None.chkpt'
Thanks in advance!
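The file name None.chkpt suggests that an epoch value was still None when the path was built, although that is only an inference from the error message. A minimal sketch of a guard for that case follows; `checkpoint_path` and its arguments are hypothetical and are not taken from the repository.

```python
import os
from typing import Optional


def checkpoint_path(chkpt_dir: str, epoch: Optional[int]) -> Optional[str]:
    """Return the checkpoint file path, or None when no epoch was recorded."""
    if epoch is None:
        # Nothing was saved, so there is no file to load.
        return None
    return os.path.join(chkpt_dir, f"{epoch}.chkpt")
```

A caller would then skip loading when the function returns None instead of trying to open a file that was never written.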
Hi, thank you for open-sourcing the repo! Is there a particular reason why the exponential map of the Euclidean manifold is written as x+v/2 instead of x+v?
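For reference, the textbook form under the standard Euclidean metric is exp_x(v) = x + v, with log_x(y) = y - x as its inverse. The sketch below shows only that textbook form on plain lists; it does not explain the x+v/2 variant in the repository, which may follow from a different metric convention that is not established here.

```python
from typing import List


def exp_map(x: List[float], v: List[float]) -> List[float]:
    """Textbook Euclidean exponential map: move from x along tangent vector v."""
    return [xi + vi for xi, vi in zip(x, v)]


def log_map(x: List[float], y: List[float]) -> List[float]:
    """Inverse of exp_map: the tangent vector at x that reaches y."""
    return [yi - xi for xi, yi in zip(x, y)]
```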
Dear authors,
Thanks for the great code! I am wondering whether you also have the pre-trained embeddings of those four datasets included in the repository, because I want to visualize the embeddings. It seems they should be in mt/data/, but I only find four scripts there.
Also, when I use mt/examples/run.py to generate the embeddings myself with the following parameters:
--dataset="mnist" --model="h2,s2,e2" --fixed_curvature=True --epochs=50 --warmup=10 --lookahead=5 --show_embeddings=10
only the embeddings of the last epoch (49) are recorded in the chkpt/ folder, instead of embeddings every 10 epochs. I looked into the code a little and found that show_embeddings is only used in _test_epoch, which is only called once in train_stopping. Maybe I am missing something, but it seems that only the last embeddings are saved.
Thanks for your help in advance!
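The behavior expected here, recording every N epochs plus the final one, can be sketched as a small predicate. The function name and its arguments are hypothetical; the repository's own handling of show_embeddings may differ.

```python
def should_record_embeddings(epoch: int, every: int, last_epoch: int) -> bool:
    """Record on every `every`-th epoch and always on the final epoch."""
    if every <= 0:
        # Treat a non-positive interval as "recording disabled".
        return False
    return epoch % every == 0 or epoch == last_epoch


# With 50 epochs (0..49) and an interval of 10:
recorded = [e for e in range(50) if should_record_embeddings(e, 10, 49)]
# recorded == [0, 10, 20, 30, 40, 49]
```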
Hi, thanks for your paper; I really enjoyed it.
I found that the code calculating the logdet differs from the formula in the paper.
Specifically, the torch.log(radius) term might be redundant in https://github.com/oskopek/mvae/blob/master/mt/mvae/ops/hyperbolics.py#L63,
since log(R) is already contained in torch.log(r).
The logdet for the sphere might have the same problem.
Could you please help me check the code? Thanks a lot!
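As a point of comparison, the commonly stated log-determinant of the exponential map on a hyperboloid of radius R is (n - 1) * log(R * sinh(r / R) / r), where r is the norm of the tangent vector and n is the manifold dimension. The sketch below evaluates that expression under the assumption that r is the unscaled norm. Whether the log(R) term is redundant in the repository depends on how r is defined there, which this sketch does not check.

```python
import math


def logdet_exp_map(r: float, radius: float, dim: int) -> float:
    """(dim - 1) * log(radius * sinh(r / radius) / r), for r > 0.

    Assumes r is the unscaled tangent-vector norm; names are illustrative.
    """
    return (dim - 1) * (
        math.log(radius) + math.log(math.sinh(r / radius)) - math.log(r)
    )
```

For radius 1 the expression reduces to (dim - 1) * log(sinh(r) / r), and for small r it approaches zero.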
Describe the bug
Running the example from README gives the following error:
Traceback (most recent call last):
File "/home/maciej/.conda/envs/pt/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/home/maciej/.conda/envs/pt/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/maciej/Documents/mvae/mt/examples/run.py", line 22, in <module>
from ..data import create_dataset
File "/home/maciej/Documents/mvae/mt/data/__init__.py", line 18, in <module>
from .image_reconstruction import CifarVaeDataset, MnistVaeDataset, OmniglotVaeDataset
File "/home/maciej/Documents/mvae/mt/data/image_reconstruction.py", line 24, in <module>
from ..mvae.distributions import EuclideanUniform
File "/home/maciej/Documents/mvae/mt/mvae/distributions/__init__.py", line 16, in <module>
from .hyperspherical_uniform import HypersphericalUniform
File "/home/maciej/Documents/mvae/mt/mvae/distributions/hyperspherical_uniform.py", line 24, in <module>
from ..ops.common import ln_2, ln_pi
File "/home/maciej/Documents/mvae/mt/mvae/ops/__init__.py", line 18, in <module>
from .poincare import PoincareBall
File "/home/maciej/Documents/mvae/mt/mvae/ops/poincare.py", line 18, in <module>
import geoopt.manifolds.poincare.math as pm
File "/home/maciej/.local/lib/python3.7/site-packages/geoopt/__init__.py", line 1, in <module>
from . import manifolds
File "/home/maciej/.local/lib/python3.7/site-packages/geoopt/manifolds/__init__.py", line 6, in <module>
from .birkhoff_polytope import BirkhoffPolytope
File "/home/maciej/.local/lib/python3.7/site-packages/geoopt/manifolds/birkhoff_polytope.py", line 203, in <module>
x, max_iter: int = 300, eps: float = 1e-5, tol: float = 1e-5
File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/__init__.py", line 823, in script
ast = get_jit_def(obj)
File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 158, in get_jit_def
return build_def(ctx, py_ast.body[0], type_line, self_name)
File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 198, in build_def
build_stmts(ctx, body))
File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 122, in build_stmts
stmts = [build_stmt(ctx, s) for s in stmts]
File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 122, in <listcomp>
stmts = [build_stmt(ctx, s) for s in stmts]
File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 174, in __call__
return method(ctx, node)
File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 319, in build_While
build_stmts(ctx, stmt.body))
File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 122, in build_stmts
stmts = [build_stmt(ctx, s) for s in stmts]
File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 122, in <listcomp>
stmts = [build_stmt(ctx, s) for s in stmts]
File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 174, in __call__
return method(ctx, node)
File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 332, in build_If
build_stmts(ctx, stmt.body),
File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 122, in build_stmts
stmts = [build_stmt(ctx, s) for s in stmts]
File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 122, in <listcomp>
stmts = [build_stmt(ctx, s) for s in stmts]
File "/home/maciej/.conda/envs/pt/lib/python3.7/site-packages/torch/jit/frontend.py", line 173, in __call__
raise UnsupportedNodeError(ctx, node)
torch.jit.frontend.UnsupportedNodeError: break statements aren't supported
x, max_iter: int = 300, eps: float = 1e-5, tol: float = 1e-5
):
    iter = 0
    c = 1.0 / (x.sum(dim=-2, keepdim=True) + eps)
    r = 1.0 / ((x @ c.transpose(-1, -2)) + eps)
    while iter < max_iter:
        iter += 1
        cinv = torch.matmul(r.transpose(-1, -2), x)
        if torch.max(torch.abs(cinv * c - 1)) <= tol:
            break
            ~~~~~ <--- HERE
        c = 1.0 / (cinv + eps)
        r = 1.0 / ((x @ c.transpose(-1, -2)) + eps)
    return x * (r @ c)
To Reproduce
Steps to reproduce the behavior:
python -m mt.examples.run --dataset="mnist" --model="h2,s2,e2" --fixed_curvature=False
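The error says that the TorchScript front end in this environment rejects break statements. One general workaround for this kind of restriction is to fold the exit test into the loop condition. The sketch below shows that pattern in plain Python on a scalar fixed-point iteration; it is not geoopt's code, and `fixed_point_no_break` and `step` are hypothetical names.

```python
from typing import Callable, Tuple


def fixed_point_no_break(
    step: Callable[[float], float],
    x0: float,
    max_iter: int = 300,
    tol: float = 1e-5,
) -> Tuple[float, int]:
    """Iterate x = step(x) until the change is within tol, without using break."""
    x = x0
    it = 0
    converged = False
    # The flag takes over the role of the early exit.
    while it < max_iter and not converged:
        it += 1
        x_new = step(x)
        converged = abs(x_new - x) <= tol
        x = x_new
    return x, it


# Usage: the Babylonian iteration for the square root of 2.
root, iterations = fixed_point_no_break(lambda x: 0.5 * (x + 2.0 / x), 1.0)
```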
Hello!
Great code, thank you!
However, I have a small question: usually, when a VAE is in eval mode, we generate samples using only the mean (so we don't sample from the latent distribution and simply forward encoder(x) to the decoder). Is it possible to enforce such behavior in your models?
Thanks in advance!
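The requested behavior can be sketched for a plain Gaussian latent as follows. The class and its `training` flag are hypothetical and only mimic the usual train/eval switch; for the non-Euclidean components in this repository, the equivalent would presumably be to return the encoder's mean point on the manifold, which is not shown here.

```python
import random
from typing import List


class GaussianLatent:
    """Hypothetical latent layer: samples in training mode, returns the mean in eval mode."""

    def __init__(self) -> None:
        self.training = True

    def sample(self, mean: List[float], std: List[float]) -> List[float]:
        if not self.training:
            # Eval mode: deterministic, pass the mean straight to the decoder.
            return list(mean)
        # Training mode: reparameterized sample, mean + std * noise.
        return [m + s * random.gauss(0.0, 1.0) for m, s in zip(mean, std)]
```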
Describe the bug
The following RuntimeError appears when I try to run run.py on the cifar dataset.
To Reproduce
Steps to reproduce the behavior:
conda activate pt
python -m mt.examples.run --dataset="cifar" --model="e2,h2,s2" --fixed_curvature=False --epochs=200