ncsnv2's Introduction

Improved Techniques for Training Score-Based Generative Models

This repo contains the official implementation for the paper Improved Techniques for Training Score-Based Generative Models.

by Yang Song and Stefano Ermon, Stanford AI Lab.

Note: The method has been extended by the subsequent work Score-Based Generative Modeling through Stochastic Differential Equations (code), which achieves better sample quality and enables exact log-likelihood computation.


We significantly improve the method proposed in Generative Modeling by Estimating Gradients of the Data Distribution. Score-based generative models are flexible neural networks trained to capture the score function of an underlying data distribution: a vector field pointing in the directions where the data density increases most rapidly. We present new techniques that improve the performance of score-based generative models and scale them to high-resolution images that were previously out of reach. Without requiring adversarial training, these models can produce sharp and diverse image samples that rival GANs.
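For reference, the score of a data density p(x) is the gradient of its log-density (this is the standard definition, not a formula quoted from this README):

s(\mathbf{x}) = \nabla_{\mathbf{x}} \log p(\mathbf{x})

The score networks trained here approximate this vector field at multiple noise levels.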

samples

(From left to right: Our samples on FFHQ 256px, LSUN bedroom 128px, LSUN tower 128px, LSUN church_outdoor 96px, and CelebA 64px.)

Running Experiments

Dependencies

Run the following to install all necessary Python packages for our code.

pip install -r requirements.txt

Project structure

main.py is the file that you should run for both training and sampling. Execute python main.py --help to get its usage description:

usage: main.py [-h] --config CONFIG [--seed SEED] [--exp EXP] --doc DOC
               [--comment COMMENT] [--verbose VERBOSE] [--test] [--sample]
               [--fast_fid] [--resume_training] [-i IMAGE_FOLDER] [--ni]

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       Path to the config file
  --seed SEED           Random seed
  --exp EXP             Path for saving running related data.
  --doc DOC             A string for documentation purpose. Will be the name
                        of the log folder.
  --comment COMMENT     A string for experiment comment
  --verbose VERBOSE     Verbose level: info | debug | warning | critical
  --test                Whether to test the model
  --sample              Whether to produce samples from the model
  --fast_fid            Whether to do fast fid test
  --resume_training     Whether to resume training
  -i IMAGE_FOLDER, --image_folder IMAGE_FOLDER
                        The folder name of samples
  --ni                  No interaction. Suitable for Slurm Job launcher

Configuration files are in config/. You don't need to include the config/ prefix when specifying --config. All files generated when running the code are stored under the directory specified by --exp. They are structured as:

<exp> # a folder named by the argument `--exp` given to main.py
├── datasets # all dataset files
├── logs # contains checkpoints and samples produced during training
│   └── <doc> # a folder named by the argument `--doc` specified to main.py
│      ├── checkpoint_x.pth # the checkpoint file saved at the x-th training iteration
│      ├── config.yml # the configuration file for training this model
│      ├── stdout.txt # all outputs to the console during training
│      └── samples # all samples produced during training
├── fid_samples # contains all samples generated for fast fid computation
│   └── <i> # a folder named by the argument `-i` specified to main.py
│      └── ckpt_x # a folder of image samples generated from checkpoint_x.pth
├── image_samples # contains generated samples
│   └── <i>
│       └── image_grid_x.png # samples generated from checkpoint_x.pth       
└── tensorboard # tensorboard files for monitoring training
    └── <doc> # this is the log_dir of tensorboard

Training

For example, we can train an NCSNv2 on LSUN bedroom by running the following

python main.py --config bedroom.yml --doc bedroom

Log files will be saved in <exp>/logs/bedroom.
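To resume an interrupted run, the --resume_training flag documented above can be added to the same command (this uses only flags already listed in the usage message):

python main.py --config bedroom.yml --doc bedroom --resume_training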

Sampling

If we want to sample from NCSNv2 on LSUN bedroom, we can edit bedroom.yml to specify the ckpt_id under the sampling group, and then run the following

python main.py --sample --config bedroom.yml -i bedroom

Samples will be saved in <exp>/image_samples/bedroom.

We can interpolate between different samples (see the paper for more details): just set interpolation to true and an appropriate n_interpolations under the sampling group in bedroom.yml, as sketched below. We can also perform other tasks such as inpainting; usage should be clear from the code and configuration files.
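As a rough sketch of the relevant part of bedroom.yml (only the key names ckpt_id, interpolation, and n_interpolations come from the text above; the values are illustrative and other keys in the group are omitted):

sampling:
  ckpt_id: 300000       # load checkpoint_300000.pth (illustrative value)
  interpolation: true   # enable interpolation between samples
  n_interpolations: 8   # number of interpolation points (illustrative value)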

Computing FID values quickly for a range of checkpoints

We can specify begin_ckpt and end_ckpt under the fast_fid group in the configuration file. For example, by running the following command, we can generate a small number of samples per checkpoint within the range begin_ckpt-end_ckpt for a quick (and rough) FID evaluation.

python main.py --fast_fid --config bedroom.yml -i bedroom

You can find samples in <exp>/fid_samples/bedroom.
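For reference, the relevant group in the configuration file might look like this (only the key names begin_ckpt and end_ckpt come from the text above; the checkpoint numbers are illustrative):

fast_fid:
  begin_ckpt: 100000   # first checkpoint to evaluate
  end_ckpt: 300000     # last checkpoint to evaluate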

Pretrained Checkpoints

Link: https://drive.google.com/drive/folders/1217uhIvLg9ZrYNKOR3XTRFSurt4miQrd?usp=sharing

You can produce samples using these checkpoints on all datasets we tested in the paper. The code assumes the --exp argument is set to exp, so the checkpoints should presumably be placed following the directory layout above (e.g., exp/logs/<doc>/checkpoint_x.pth).

References

If you find the code/idea useful for your research, please consider citing

@inproceedings{song2020improved,
  author    = {Yang Song and Stefano Ermon},
  editor    = {Hugo Larochelle and
               Marc'Aurelio Ranzato and
               Raia Hadsell and
               Maria{-}Florina Balcan and
               Hsuan{-}Tien Lin},
  title     = {Improved Techniques for Training Score-Based Generative Models},
  booktitle = {Advances in Neural Information Processing Systems 33: Annual Conference
               on Neural Information Processing Systems 2020, NeurIPS 2020, December
               6-12, 2020, virtual},
  year      = {2020}
}

and/or our previous work

@inproceedings{song2019generative,
  title={Generative Modeling by Estimating Gradients of the Data Distribution},
  author={Song, Yang and Ermon, Stefano},
  booktitle={Advances in Neural Information Processing Systems},
  pages={11895--11907},
  year={2019}
}


ncsnv2's Issues

Low GPU utilization

Hi Yang Song! Thank you for your code!
I am running your network on a local machine (single GPU: RTX TITAN), but GPU utilization is only around 30% during sampling. Did you notice the same issue?
Best wishes,
Tianrong

FID score is different from the paper

Hi Author,
Thanks for your great work.
For the three checkpoints you provided, "best_with_denoise", "best_without_denoise", and "checkpoint_300000.pth", I got FID scores around 39 by specifying the "--fast_fid" argument, but the paper reports an FID of around 10. Am I missing something? Any help would be appreciated, thanks!

Requirements - Pykerberos, Numba, Scikit-image

I think the requirements.txt file should be checked for version validity problems.

Running pip install -r requirements.txt yields errors concerning the numba, scikit-image, and pykerberos packages.

pykerberos's issue is discussed here: https://github.com/requests/requests-kerberos/issues/109

On Colab, both scikit-image and numba fail with output like:

...
Building wheels for collected packages: scikit-image
Building wheel for scikit-image (setup.py) ... error
ERROR: Failed building wheel for scikit-image
Running setup.py clean for scikit-image
Failed to build scikit-image
...
Running setup.py install for scikit-image ... error
Rolling back uninstall of scikit-image
Moving to /usr/local/bin/skivi
from /tmp/pip-uninstall-6hpppytk/skivi
Moving to /usr/local/lib/python3.7/dist-packages/scikit_image-0.14.0.dist-info/
from /usr/local/lib/python3.7/dist-packages/~cikit_image-0.14.0.dist-info
Moving to /usr/local/lib/python3.7/dist-packages/skimage/
from /usr/local/lib/python3.7/dist-packages/~kimage

ERROR: Command errored out with exit status 1: /usr/bin/python3 -u -c 'import io, os, sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-k0wygbfd/scikit-image_c91cf20f99884249931a8a1d91256722/setup.py'"'"'; __file__='"'"'/tmp/pip-install-k0wygbfd/scikit-image_c91cf20f99884249931a8a1d91256722/setup.py'"'"';f = getattr(tokenize, '"'"'open'"'"', open)(__file__) if os.path.exists(__file__) else io.StringIO('"'"'from setuptools import setup; setup()'"'"');code = f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' install --record /tmp/pip-record-7dzzbcjg/install-record.txt --single-version-externally-managed --compile --install-headers /usr/local/include/python3.7/scikit-image Check the logs for full command output.

I am not sure, but I think it is due to version upgrades. I fixed it by pinning scikit-image to version 0.14.0 and numba to 0.39.0.

Still testing the repo, though, to check if everything is working fine.

requirements.txt

When I followed the original requirements.txt, I ran into several problems:

  1. Package conflicts. (I have forgotten exactly how I resolved them; in any case, I could not use the original requirements.txt as-is.)
  2. It requires an old version of protobuf.
  3. With torch==1.5.0 on Python 3.7, the following line takes several minutes, which is obviously unusual:
     sigmas = torch.tensor(np.exp(np.linspace(np.log(90), np.log(0.01), 500))).float().to(device)
     I found a blog post that may explain the reason.

In the end, I did not use the torch and Python versions from the original requirements.txt, and it ran successfully. My requirements are as follows:

    lmdb==1.4.1
    numpy==1.19.5
    pandas==1.1.5
    Pillow==10.2.0
    PyYAML==6.0.1
    Requests==2.31.0
    scikit_learn==1.4.0
    scipy==1.12.0
    six==1.16.0
    torch==1.10.1+cu111
    torchvision==0.11.2+cu111
    tqdm==4.66.1

rectangular data

Hi there, thanks for providing your repo!
I have been trying to adapt the NCSNv2 model to work with rectangular data, say 5x1024 instead of normal images. I have found that the only incompatible part is the ConvMeanPool layer, which uses the summation:

output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
              output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.

This is the only part that requires both spatial dimensions to be even. Are there any recommendations for some way around this?
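One possible workaround (a sketch, not code from this repo): for even spatial dimensions, the strided-slice sum above is exactly 2x2 average pooling, so torch.nn.functional.avg_pool2d can replace it, and ceil_mode=True additionally tolerates odd dimensions by letting the last window cover fewer elements:

import torch
import torch.nn.functional as F

def mean_pool(output: torch.Tensor) -> torch.Tensor:
    # Matches the strided-slice sum when both spatial dims are even;
    # with an odd dim, ceil_mode=True keeps the trailing row/column and
    # count_include_pad=False averages edge windows over real elements only.
    return F.avg_pool2d(output, kernel_size=2, stride=2,
                        ceil_mode=True, count_include_pad=False)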

Training on my own semantic dataset

Hi,

I'm new to this area and would like to train or fine-tune on my own dataset.

Can you tell me how I can set up my own folder of images for training using main.py?

Thanks
