bit-diffusion's Introduction

Bit Diffusion - Pytorch

Implementation of Bit Diffusion, Hinton's group's attempt at discrete denoising diffusion, in Pytorch

It seems like they missed the mark for text, but the research direction still seems promising. I think a clean repository will benefit the research community and anyone branching off from here.

Install

$ pip install bit-diffusion

Usage

from bit_diffusion import Unet, Trainer, BitDiffusion

model = Unet(
    dim = 32,
    channels = 3,
    dim_mults = (1, 2, 4, 8),
).cuda()

bit_diffusion = BitDiffusion(
    model,
    image_size = 128,
    timesteps = 100,
    time_difference = 0.1,       # the paper found that at low timestep counts, a sampling time difference greater than 0 helps FID; as the number of timesteps grows, it can be set to 0, as it no longer helps
    use_ddim = True              # use DDIM sampling
).cuda()

trainer = Trainer(
    bit_diffusion,
    '/path/to/your/data',             # path to your folder of images
    results_folder = './results',     # where to save results
    num_samples = 16,                 # number of samples
    train_batch_size = 4,             # training batch size
    gradient_accumulate_every = 4,    # gradient accumulation
    train_lr = 1e-4,                  # learning rate
    save_and_sample_every = 1000,     # how often to save and sample
    train_num_steps = 700000,         # total training steps
    ema_decay = 0.995,                # exponential moving average decay
)

trainer.train()

Results will be saved periodically to the ./results folder.

If you would like to experiment with the Unet and BitDiffusion classes outside the Trainer:

import torch
from bit_diffusion import Unet, BitDiffusion

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

bit_diffusion = BitDiffusion(
    model,
    image_size = 128,
    timesteps = 1000
)

training_images = torch.rand(8, 3, 128, 128) # dummy images, normalized from 0 to 1
loss = bit_diffusion(training_images)
loss.backward()
# after a lot of training

sampled_images = bit_diffusion.sample(batch_size = 4)
sampled_images.shape # (4, 3, 128, 128)
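
The "analog bits" idea underlying the repo is simple: each 0-255 pixel value is expanded into 8 binary bits, which are shifted and scaled into real values in {-1, 1} so a standard continuous diffusion model can denoise them; thresholding and recombining the bits recovers discrete pixels. A minimal illustrative sketch of the encoding (the repo ships its own decimal_to_bits helper; the function below is a standalone re-implementation for clarity, not the library's exact code):

import torch
from einops import rearrange

BITS = 8  # 8 bits per channel covers the 0-255 pixel range

def to_analog_bits(images):
    # images: (b, c, h, w) float tensor in [0, 1]
    # returns analog bits in {-1., 1.} of shape (b, c * BITS, h, w)
    x = (images * 255).int().clamp(0, 255)
    mask = 2 ** torch.arange(BITS - 1, -1, -1, device = x.device, dtype = torch.int32)
    mask = rearrange(mask, 'd -> d 1 1')
    x = rearrange(x, 'b c h w -> b c 1 h w')
    bits = ((x & mask) != 0).float()  # binary bits in {0., 1.}
    return rearrange(bits * 2 - 1, 'b c d h w -> b (c d) h w')  # scale to {-1., 1.}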

Citations

@article{Chen2022AnalogBG,
    title   = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
    author  = {Ting Chen and Ruixiang Zhang and Geoffrey E. Hinton},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.04202}
}

bit-diffusion's People

Contributors

bruce2233, lucidrains, rostro36, rvorias


bit-diffusion's Issues

Performance of this code in image generation task

Dear @lucidrains, thank you very much for providing the code. It's really helpful for my own research on diffusion models.

I am wondering whether you have reached the performance reported in the Analog Bits paper for the image generation task?

I am trying to reproduce the image captioning task from the paper, but unfortunately I couldn't reach the reported performance. So I am checking the code for the image generation task in this repo, to see if I did anything wrong in my own code.

Also, if you've tried reproducing the image captioning task, I would love to hear your suggestions.

Thank you very much!

compile with self-conditioning

Hi, thanks for sharing your implementation, it is really easy to follow.

The recent PyTorch 2.0 release supports torch.compile, which accelerates training a lot, but it seems that self-conditioning prevents compilation. I am wondering if there is some small tweak that would enable the model to be compiled.
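
For reference, a minimal sketch of what compiling the denoiser could look like, assuming PyTorch >= 2.0; whether the randomly-gated self-conditioning branch causes graph breaks will depend on how that control flow is expressed:

import torch
from bit_diffusion import Unet, BitDiffusion

model = Unet(dim = 64, dim_mults = (1, 2, 4, 8))

# torch.compile traces the module into an optimized graph; data-dependent
# control flow (e.g. enabling self-conditioning at random per batch) can
# force graph breaks or recompilation
compiled_model = torch.compile(model)

bit_diffusion = BitDiffusion(compiled_model, image_size = 128, timesteps = 1000)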

Which formulation is used in the ddpm_sample() implementation?

Hi,

thanks for the implementation!

I am trying to understand how you implemented the method ddpm_sample, in particular how you incorporated "time_next" into the denoising step. Which paper did you use as guidance? I was comparing your ddpm_sample with Algorithm 2 of the paper "Denoising Diffusion Probabilistic Models", but I can't match your use of expm1, for instance.

Thanks a lot.
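
For context, the expm1 term is consistent with the continuous-time posterior q(z_next | z, x_start) from "Variational Diffusion Models" (Kingma et al., 2021), written in the log-SNR parameterization rather than the discrete Algorithm 2 of the DDPM paper. A hedged sketch of one reverse step in that form (standalone, not the repo's exact code):

import torch

def ddpm_posterior_step(img, x_start, log_snr, log_snr_next):
    # variance-preserving schedule: alpha^2 = sigmoid(log_snr), sigma^2 = sigmoid(-log_snr)
    alpha, sigma = log_snr.sigmoid().sqrt(), (-log_snr).sigmoid().sqrt()
    alpha_next, sigma_next = log_snr_next.sigmoid().sqrt(), (-log_snr_next).sigmoid().sqrt()

    # c = 1 - SNR / SNR_next; expm1(x) = exp(x) - 1 is numerically stable for small x
    c = -torch.expm1(log_snr - log_snr_next)

    mean = alpha_next * (img * (1 - c) / alpha + c * x_start)
    variance = (sigma_next ** 2) * c
    return mean + variance.sqrt() * torch.randn_like(img)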

DDIM time delay

It seems the time delay in DDIM is in the wrong place, as the delay is applied only after the related quantities have already been computed:

for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'):
    # get times and noise levels
    log_snr = self.log_snr(times)
    log_snr_next = self.log_snr(times_next)

    padded_log_snr, padded_log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))
    alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
    alpha_next, sigma_next = log_snr_to_alpha_sigma(padded_log_snr_next)

    # add the time delay
    times_next = (times_next - time_difference).clamp(min = 0.)

    # predict x0
    x_start = self.model(img, log_snr, x_start)

    # clip x0
    x_start.clamp_(-self.bit_scale, self.bit_scale)

    # get predicted noise
    pred_noise = (img - alpha * x_start) / sigma.clamp(min = 1e-8)

    # calculate x next
    img = x_start * alpha_next + pred_noise * sigma_next

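
A hedged sketch of the reordering this report implies, applying the delay before any quantities are derived from times_next (this mirrors the loop above and is the reported fix, not necessarily the final upstream change):

for times, times_next in tqdm(time_pairs, desc = 'sampling loop time step'):
    # apply the time delay first, so everything derived from times_next sees the shifted time
    times_next = (times_next - time_difference).clamp(min = 0.)

    # get times and noise levels
    log_snr = self.log_snr(times)
    log_snr_next = self.log_snr(times_next)

    padded_log_snr, padded_log_snr_next = map(partial(right_pad_dims_to, img), (log_snr, log_snr_next))
    alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
    alpha_next, sigma_next = log_snr_to_alpha_sigma(padded_log_snr_next)

    # predict x0, clip, derive noise, and step as before
    x_start = self.model(img, log_snr, x_start)
    x_start.clamp_(-self.bit_scale, self.bit_scale)
    pred_noise = (img - alpha * x_start) / sigma.clamp(min = 1e-8)
    img = x_start * alpha_next + pred_noise * sigma_next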
Non-square images error

Hi, thanks for this implementation. Would it be possible to run bit-diffusion with non-square images?

ex:

20x200

Lucas
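
One hedged workaround, assuming the Unet only requires spatial dims divisible by its total downsampling factor: pad the short side up to a square before training and crop after sampling. The pad_to_square helper below is hypothetical, not part of the repo:

import torch.nn.functional as F

def pad_to_square(images):
    # images: (b, c, h, w); zero-pads the shorter side so h == w (hypothetical helper)
    _, _, h, w = images.shape
    size = max(h, w)
    pad_h, pad_w = size - h, size - w
    # F.pad pads the last two dims as (left, right, top, bottom)
    return F.pad(images, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))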

small bug in bits_to_decimal

Hi,
thanks for the great work. I think there is a small bug in the function bits_to_decimal: rearrange's parameter d should be bits instead of a hardcoded 8 (e.g. in my test case it is 1).

import torch
from einops import rearrange, reduce

BITS = 8

def bits_to_decimal(x, bits = BITS):
    """ expects bits from -1 to 1, outputs image tensor from 0 to 1 """
    device = x.device

    x = (x > 0).int()
    mask = 2 ** torch.arange(bits - 1, -1, -1, device = device, dtype = torch.int32)

    mask = rearrange(mask, 'd -> d 1 1')
    x = rearrange(x, 'b (c d) h w -> b c d h w', d = bits)  # fix: d = bits rather than a hardcoded 8
    dec = reduce(x * mask, 'b c d h w -> b c h w', 'sum')
    return (dec / 255).clamp(0., 1.)
