pytorch-lightning-vae's Introduction

VAE for color images in PyTorch Lightning

This repo is the implementation that accompanies the matching Medium tutorial.

Reconstructions on CIFAR-10

To run

pip install -r requirements.txt

python vae.py --gpus 1

pytorch-lightning-vae's People

Contributors

williamfalcon

pytorch-lightning-vae's Issues

self.log_scale in line 34 in vae.py should be a vector of size latent_dim

Hello William,
This is very good code. The only issue is that we initially could not train on our dataset. After reading the VAE paper, we found that the self.log_scale parameter should be a vector the size of the output dimension rather than a single real number. After correcting this, our model started to train. Please check whether this finding is relevant here.
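
A minimal sketch of what the change described above might look like, assuming the decoder likelihood is a Gaussian whose mean is the reconstruction and whose log standard deviation is the learnable log_scale. The class name GaussianLikelihood and the shape argument are hypothetical and not taken from vae.py. The issue title says latent_dim while the body says output dimension, so the shape is left to the caller here: shape=() keeps a single shared scalar, and a non-empty shape gives one log-scale per entry.

```python
import torch
from torch import nn


class GaussianLikelihood(nn.Module):
    """Gaussian log-likelihood with a learnable log standard deviation.

    Hypothetical helper for illustration; not the code in vae.py.
    """

    def __init__(self, shape=()):
        super().__init__()
        # shape=() -> one scalar shared by all outputs.
        # shape=(C, H, W) -> one log-scale per output entry (an assumption).
        self.log_scale = nn.Parameter(torch.zeros(tuple(shape)))

    def forward(self, x_hat, x):
        # exp() keeps the standard deviation strictly positive.
        scale = torch.exp(self.log_scale)
        # scale broadcasts against the batch dimension of x_hat.
        dist = torch.distributions.Normal(x_hat, scale)
        log_pxz = dist.log_prob(x)
        # Sum over every non-batch dimension: one value per sample.
        return log_pxz.flatten(start_dim=1).sum(dim=1)


if __name__ == "__main__":
    x = torch.rand(8, 3, 32, 32)
    x_hat = torch.rand(8, 3, 32, 32)
    scalar = GaussianLikelihood()                    # single shared log-scale
    per_dim = GaussianLikelihood(shape=(3, 32, 32))  # one log-scale per entry
    print(scalar(x_hat, x).shape, per_dim(x_hat, x).shape)
```

Both variants return one log-likelihood per sample, so either can be dropped into the same loss. Only the number of learnable scale parameters differs.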

AttributeError: 'CIFAR10DataModule' object has no attribute 'train_transforms'

When I try to run the VAE code I get this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_2476/1748225191.py in <module>
      4 vae = VAE()
      5 trainer = pl.Trainer(gpus=1, max_epochs=20, callbacks=[sampler])
----> 6 trainer.fit(vae, dataset)

~/.miniconda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    695         self.strategy.model = model
    696         self._call_and_handle_interrupt(
--> 697             self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    698         )
    699 

~/.miniconda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    648                 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
    649             else:
--> 650                 return trainer_fn(*args, **kwargs)
    651         # TODO(awaelchli): Unify both exceptions below, where `KeyboardError` doesn't re-raise
    652         except KeyboardInterrupt as exception:

~/.miniconda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    733             ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
    734         )
--> 735         results = self._run(model, ckpt_path=self.ckpt_path)
    736 
    737         assert self.state.stopped

~/.miniconda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
   1103         self.__setup_profiler()
   1104 
-> 1105         self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
   1106 
   1107         # check if we should delay restoring checkpoint till later

~/.miniconda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _call_setup_hook(self)
   1445 
   1446         if self.datamodule is not None:
-> 1447             self._call_lightning_datamodule_hook("setup", stage=fn)
   1448         self._call_callback_hooks("setup", stage=fn)
   1449         self._call_lightning_module_hook("setup", stage=fn)

~/.miniconda/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _call_lightning_datamodule_hook(self, hook_name, *args, **kwargs)
   1567         if callable(fn):
   1568             with self.profiler.profile(f"[LightningDataModule]{self.datamodule.__class__.__name__}.{hook_name}"):
-> 1569                 return fn(*args, **kwargs)
   1570 
   1571     def _call_callback_hooks(

~/.miniconda/lib/python3.7/site-packages/pl_bolts/datamodules/vision_datamodule.py in setup(self, stage)
     69         """
     70         if stage == "fit" or stage is None:
---> 71             train_transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
     72             val_transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
     73 

AttributeError: 'CIFAR10DataModule' object has no attribute 'train_transforms'

I am using pl '1.7.7' and pl_bolts '0.3.2post1'.
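
The traceback shows setup() in pl_bolts reading self.train_transforms, which the data module object does not define. One possible workaround, untested against these exact versions, is to set the missing attributes to None before calling trainer.fit, so that the default_transforms() branch visible in the traceback is taken. ensure_transform_attrs is a hypothetical helper, and including test_transforms is an assumption, since only train_transforms and val_transforms appear in the traceback.

```python
def ensure_transform_attrs(
    datamodule,
    names=("train_transforms", "val_transforms", "test_transforms"),
):
    """Set each missing transform attribute to None and return the object.

    Attributes that already exist are left untouched.
    """
    for name in names:
        if not hasattr(datamodule, name):
            setattr(datamodule, name, None)
    return datamodule


if __name__ == "__main__":
    # Stand-in object so the sketch runs without pl_bolts installed.
    class FakeDataModule:
        pass

    dm = ensure_transform_attrs(FakeDataModule())
    print(dm.train_transforms, dm.val_transforms, dm.test_transforms)

    # Intended use with the objects from the traceback (assumption):
    # dataset = ensure_transform_attrs(dataset)
    # trainer.fit(vae, dataset)
```

This only papers over the missing attributes. If the two libraries are otherwise incompatible, installing a matching pair of versions may still be necessary.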

quality of image is not as good as the one in the blog

Hi William,
I am following your script to build a CIFAR-10 model. I checked the images after 20 epochs, and they are not as good as the ones in the blog. Could you share any hints on what might be wrong in my experiment? (I did not actually change the script.) Thanks a lot!

Big reconstruction loss

Hi, thanks for your great work!
I have modeled each parameter output by the decoder as a Gaussian for my dataset. Although I get high-quality reconstructed images, the reconstruction loss is very large. How can I solve this problem?
Thank you again!
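
One thing worth checking is whether the loss is summed over all output dimensions. The sketch below assumes a Gaussian likelihood with unit scale (log_scale = 0) and a CIFAR-10-sized output of 3 x 32 x 32 values; the function names are hypothetical and this may not match the poster's setup. Under those assumptions, even a perfect reconstruction contributes 0.5 * log(2 * pi) per dimension, so the summed negative log-likelihood lands in the thousands while the per-dimension value stays below 1.

```python
import math


def gaussian_log_prob(x, mean, log_scale=0.0):
    """Log-density of a 1-D Gaussian with standard deviation exp(log_scale)."""
    scale = math.exp(log_scale)
    return (
        -0.5 * math.log(2 * math.pi)
        - log_scale
        - (x - mean) ** 2 / (2 * scale ** 2)
    )


def reconstruction_loss(x, x_hat, log_scale=0.0):
    """Negative log-likelihood summed over all dimensions of one sample."""
    return -sum(gaussian_log_prob(a, b, log_scale) for a, b in zip(x, x_hat))


if __name__ == "__main__":
    # A perfect reconstruction of a flattened 3 x 32 x 32 image.
    pixels = [0.5] * (3 * 32 * 32)
    total = reconstruction_loss(pixels, pixels)
    print("summed loss:", total)
    print("per-dimension loss:", total / len(pixels))
```

Dividing by the number of dimensions, or tracking the value across epochs rather than its absolute size, makes the number easier to compare between datasets of different sizes.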
