
mindiffusion's Introduction

minDiffusion

The goal of this educational repository is to provide a self-contained, minimalistic implementation of diffusion models using PyTorch.

Many implementations of diffusion models can be a bit overwhelming. Here, superminddpm, a fully self-contained implementation of DDPM in under 200 lines of PyTorch, is a good starting point for anyone who wants to get started with Denoising Diffusion Probabilistic Models without having to spend time on the details.

Simply:

$ python superminddpm.py

The above script is self-contained. (Of course, you need to have PyTorch and torchvision installed. The latest versions should suffice; we do not use any cutting-edge features.)

If you want the slightly more refactored code that trains on the CIFAR-10 dataset:

$ python train_cifar10.py

The above result took about 2 hours of training on a single 3090 GPU. The top 8 images are generated; the bottom 8 are ground truth.

Here is another example, trained for 100 epochs (about 1.5 hours).

Currently has:

  • Tiny implementation of DDPM
  • MNIST and CIFAR-10 datasets.
  • Simple UNet structure + simple time embeddings.
  • CelebA dataset.
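The "simple time embeddings" above can be as little as a sinusoidal encoding of the integer timestep, in the spirit of Transformer positional encodings. A minimal NumPy sketch (the function name is illustrative, not this repository's actual code):

```python
import numpy as np

def sinusoidal_time_embedding(t: np.ndarray, dim: int) -> np.ndarray:
    """Map integer timesteps t of shape [B] to embeddings of shape [B, dim]."""
    half = dim // 2
    # Geometric frequency ladder from 1 down to ~1/10000, as in Transformers.
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)
    args = t[:, None].astype(np.float64) * freqs[None, :]  # [B, half]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)

emb = sinusoidal_time_embedding(np.arange(4), 32)
print(emb.shape)  # (4, 32)
```

The embedding is then typically projected by a small MLP and added to the UNet's feature maps at each resolution.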

TODOS

  • DDIM
  • Classifier Guidance
  • Multimodality

Updates!

  • Using more parameters yields better results for MNIST.
  • More comments in superminddpm.py

mindiffusion's People

Contributors

cloneofsimo, silencemonk


mindiffusion's Issues

how to load custom dataset?

Hi! I'm new to this stuff and don't understand how I can load a custom dataset and add it to this model. My questions are: what format should the dataset be in, and how do I train the model on it?

How to overfit a single image

I'm getting the following samples after training:

  • 200 epochs for 1 CELEBA image
    ddpm_1img_200epochs

  • 100 epochs for 1000 CELEBA images
    ddpm_1000imgs_100epochs

Shouldn't using only 1 image for training make the model overfit that image within a few epochs and always produce that image for any given z?

Why does using more training samples make the model converge faster?

Thank you!
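One practical pitfall when overfitting a single image: if the dataset yields that image only once per epoch, an "epoch" is a single gradient step, so 200 epochs is only 200 updates. Also, a DDPM samples a random timestep t and random noise eps at every step, so generated samples will not be pixel-identical until the noise predictor is accurate at every t. A common workaround (a hypothetical helper, not in this repository) is to repeat the image so each epoch performs many steps:

```python
import torch
from torch.utils.data import Dataset

class RepeatedImage(Dataset):
    """Presents one image tensor `repeats` times per epoch."""

    def __init__(self, image: torch.Tensor, repeats: int = 1000):
        self.image, self.repeats = image, repeats

    def __len__(self) -> int:
        return self.repeats

    def __getitem__(self, i: int) -> torch.Tensor:
        return self.image

ds = RepeatedImage(torch.zeros(3, 64, 64), repeats=500)
print(len(ds))  # 500
```

With 500 repeats, one "epoch" performs 500 gradient steps over the same image, which makes overfitting visible much sooner.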

Why are you normalizing to 1.414 in unet.py?

class Conv3(nn.Module):
    ...

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.main(x)
        if self.is_res:
            x = x + self.conv(x)
            return x / 1.414  # <= here
        else:
            return self.conv(x)

Quality of the generated images

Dear @cloneofsimo and @SilenceMonk, thanks so much for this code! It is very helpful and precisely the missing piece I needed to understand diffusion models better. I also appreciated that I could just run the CIFAR-10 training without any code modification.
I am playing around with your code to better understand the guided_diffusion repository, which I find too complex and need to simplify.

I have trained on CIFAR-10 and obtained the following results after 100 epochs.
ddpm_sample_cifar99

As you can see, the prediction quality seems quite far from the ground truth.
I plan to extend your code to images with a larger resolution, however, I am hesitant now, as I do not understand if the network is learning or not. I would like to extend the code while maintaining convergence.

i) Is this behavior normal? Is there some critical hyperparameter to tune to obtain clearer images?
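On hyperparameters: in DDPM the noise schedule is often the first thing to check. The standard linear schedule from the original DDPM paper runs beta from 1e-4 to 0.02 over T = 1000 steps, and the cumulative product alpha-bar determines how much signal survives at each timestep. A small illustrative sketch (NumPy, not the repository's exact code):

```python
import numpy as np

def ddpm_schedule(beta1: float = 1e-4, beta2: float = 0.02, T: int = 1000):
    """Linear beta schedule and the derived alpha-bar used by DDPM training."""
    beta = np.linspace(beta1, beta2, T)
    alpha = 1.0 - beta
    alphabar = np.cumprod(alpha)  # fraction of signal variance kept at each t
    return beta, alphabar

beta, alphabar = ddpm_schedule()
# Forward process: x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps
print(alphabar[0], alphabar[-1])  # near 1.0 at t=0, near 0.0 at t=T
```

If alpha-bar does not decay close to zero by t = T, samples drawn from pure noise start from a distribution the model never saw during training, which degrades quality.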

UPDATE: I have trained on CelebA and obtained the following results after 21 epochs (approx. 14 hours on a 3090):
ddpm_sample_celeba021
The CelebA results already seem better than the CIFAR-10 ones, but I might need more training epochs because the generated images are still far from the ground truth.

Still referring to the CelebA results, you can see in the following image that the generated images can collapse to a constant-color background.
(Below is CelebA after 19 epochs.)
ddpm_sample_celeba020
This issue is similar to openai/guided-diffusion#81 .

Furthermore, training does not progress linearly: if we take epoch 22 of CelebA, we can notice that the network again outputs smooth predictions with no structure.

ddpm_sample_celeba022

So overall I am not getting the training stability I was expecting. These results are (unfortunately) consistent with the issues I reported for the guided_diffusion repository, openai/guided-diffusion#42.

iii) Do you have any comments that could help overcome this issue?

Thanks again for your help!
Stefano
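One standard remedy for the epoch-to-epoch sample instability described above (used by guided_diffusion, but not by this minimal repository) is to sample from an exponential moving average (EMA) of the model weights rather than the raw weights. A hedged sketch:

```python
import copy
import torch

@torch.no_grad()
def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module,
               decay: float = 0.999):
    """Blend the live weights into the EMA copy after each optimizer step."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)

model = torch.nn.Linear(4, 4)
ema_model = copy.deepcopy(model)
# After each training step: ema_update(ema_model, model)
# At sampling time: generate with ema_model instead of model
```

Because the EMA averages over many recent checkpoints, it smooths out the step-to-step oscillations that make individual epochs (like epoch 22 above) produce structureless samples.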
