
lucidrains / gigagan-pytorch


Implementation of GigaGAN, new SOTA GAN out of Adobe. Culmination of nearly a decade of research into GANs

License: MIT License

Python 100.00%
artificial-intelligence deep-learning generative-adversarial-network

gigagan-pytorch's Introduction

GigaGAN - Pytorch

Implementation of GigaGAN (project page), new SOTA GAN out of Adobe.

I will also add a few findings from Lightweight GAN for faster convergence (skip layer excitation) and better stability (reconstruction auxiliary loss in the discriminator).
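Skip-layer excitation, as described in the Lightweight GAN paper (Liu et al., 2021, cited below), gates the channels of a high-resolution feature map with a signal computed from a low-resolution one. Below is a minimal PyTorch sketch of the idea. The class name `SkipLayerExcitation` and the exact layer choices (4x4 pooling, LeakyReLU slope 0.1) are assumptions for illustration and may not match this repository's implementation.

```python
import torch
from torch import nn

class SkipLayerExcitation(nn.Module):
    """Gate a high-res feature map with per-channel weights derived from a low-res one."""

    def __init__(self, dim_low: int, dim_high: int):
        super().__init__()
        self.to_gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),           # squeeze the low-res map to 4x4
            nn.Conv2d(dim_low, dim_high, 4),   # 4x4 conv, no padding -> 1x1 spatial
            nn.LeakyReLU(0.1),
            nn.Conv2d(dim_high, dim_high, 1),
            nn.Sigmoid(),                      # one gate in (0, 1) per channel
        )

    def forward(self, x_high: torch.Tensor, x_low: torch.Tensor) -> torch.Tensor:
        gate = self.to_gate(x_low)             # (batch, dim_high, 1, 1)
        return x_high * gate                   # broadcast over height and width

# usage: an 8x8 map with 256 channels excites a 128x128 map with 64 channels
sle = SkipLayerExcitation(dim_low=256, dim_high=64)
x_low = torch.randn(2, 256, 8, 8)
x_high = torch.randn(2, 64, 128, 128)
out = sle(x_high, x_low)
```

Because the gate is a single scalar per channel, the cost is independent of the high-resolution map's size, which is part of why the technique is cheap.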

It will also contain the code for the 1k - 4k upsamplers, which I find to be the highlight of this paper.

Please join us on Discord if you are interested in helping out with the replication with the LAION community.

Appreciation

  • StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.

  • 🤗 Huggingface for their accelerate library

  • All the maintainers at OpenClip, for their SOTA open sourced contrastive learning text-image models

  • Xavier for the very helpful code review, and for discussions on how the scale invariance in the discriminator should be built!

  • @CerebralSeed for pull requesting the initial sampling code for both the generator and upsampler!

  • Keerth for the code review and pointing out some discrepancies with the paper!

Install

$ pip install gigagan-pytorch

Usage

Simple unconditional GAN, for starters

import torch

from gigagan_pytorch import (
    GigaGAN,
    ImageDataset
)

gan = GigaGAN(
    generator = dict(
        dim_capacity = 8,
        style_network = dict(
            dim = 64,
            depth = 4
        ),
        image_size = 256,
        dim_max = 512,
        num_skip_layers_excite = 4,
        unconditional = True
    ),
    discriminator = dict(
        dim_capacity = 16,
        dim_max = 512,
        image_size = 256,
        num_skip_layers_excite = 4,
        unconditional = True
    ),
    amp = True
).cuda()

# dataset

dataset = ImageDataset(
    folder = '/path/to/your/data',
    image_size = 256
)

dataloader = dataset.get_dataloader(batch_size = 1)

# you must then set the dataloader for the GAN before training

gan.set_dataloader(dataloader)

# train the discriminator and generator alternately
# for 100 steps in this example, batch size 1, gradient accumulated 8 times

gan(
    steps = 100,
    grad_accum_every = 8
)

# after much training

images = gan.generate(batch_size = 4) # (4, 3, 256, 256)

For the unconditional Unet upsampler

import torch
from gigagan_pytorch import (
    GigaGAN,
    ImageDataset
)

gan = GigaGAN(
    train_upsampler = True,     # set this to True
    generator = dict(
        style_network = dict(
            dim = 64,
            depth = 4
        ),
        dim = 32,
        image_size = 256,
        input_image_size = 64,
        unconditional = True
    ),
    discriminator = dict(
        dim_capacity = 16,
        dim_max = 512,
        image_size = 256,
        num_skip_layers_excite = 4,
        multiscale_input_resolutions = (128,),
        unconditional = True
    ),
    amp = True
).cuda()

dataset = ImageDataset(
    folder = '/path/to/your/data',
    image_size = 256
)

dataloader = dataset.get_dataloader(batch_size = 1)

gan.set_dataloader(dataloader)

# train the discriminator and generator alternately
# for 100 steps in this example, batch size 1, gradient accumulated 8 times

gan(
    steps = 100,
    grad_accum_every = 8
)

# after much training

lowres = torch.randn(1, 3, 64, 64).cuda()

images = gan.generate(lowres) # (1, 3, 256, 256)

Losses

  • G - Generator
  • MSG - Multiscale Generator
  • D - Discriminator
  • MSD - Multiscale Discriminator
  • GP - Gradient Penalty
  • SSL - Auxiliary Reconstruction in Discriminator (from Lightweight GAN)
  • VD - Vision-aided Discriminator
  • VG - Vision-aided Generator
  • CL - Generator Contrastive Loss
  • MAL - Matching Aware Loss

A healthy run would have G, MSG, D, and MSD with values hovering between 0 and 10, usually staying fairly constant. If, after 1k training steps, these values persist in the triple digits, something is wrong. It is OK for the generator and discriminator values to occasionally dip negative, but they should swing back up into the range above.

GP and SSL should be pushed towards 0. GP can occasionally spike; I like to imagine it as the networks undergoing some epiphany.
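Since GP is logged as a quantity that should trend towards 0, it can help to see how such a penalty is typically computed. The sketch below assumes the common R1 form (squared gradient norm of the discriminator output with respect to real images); the repository's actual GP formulation and weighting may differ, and `r1_penalty` is a hypothetical helper name.

```python
import torch
from torch import nn

def r1_penalty(discriminator: nn.Module, real_images: torch.Tensor) -> torch.Tensor:
    """Mean over the batch of the squared L2 norm of dD(x)/dx on real images."""
    real_images = real_images.detach().requires_grad_(True)
    logits = discriminator(real_images)
    # create_graph=True keeps the penalty differentiable (requires double backward)
    (grads,) = torch.autograd.grad(logits.sum(), real_images, create_graph=True)
    return grads.pow(2).flatten(start_dim=1).sum(dim=1).mean()

# usage with a toy linear discriminator: its input gradient is just its weight,
# so the penalty equals the squared norm of that weight
torch.manual_seed(0)
toy_d = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 1))
penalty = r1_penalty(toy_d, torch.randn(8, 3, 4, 4))
```

In practice the penalty is usually multiplied by a weight before being added to the discriminator loss.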

Multi-GPU Training

The GigaGAN class is now equipped with 🤗 Accelerator. You can easily do multi-GPU training in two steps using their accelerate CLI.

At the project root directory, where the training script is, run

$ accelerate config

Then, in the same directory

$ accelerate launch train.py

Todo

  • make sure it can be trained unconditionally

  • read the relevant papers and knock out all 3 auxiliary losses

    • matching aware loss
    • clip loss
    • vision-aided discriminator loss
    • add reconstruction losses on arbitrary stages in the discriminator (lightweight gan)
    • figure out how the random projections are used from projected-gan
    • vision aided discriminator needs to extract N layers from the vision model in CLIP
    • figure out whether to discard CLS token and reshape into image dimensions for convolution, or stick with attention and condition with adaptive layernorm - also turn off vision aided gan in unconditional case
  • unet upsampler

    • add adaptive conv
    • modify latter stage of unet to also output rgb residuals, and pass the rgb into discriminator. make discriminator agnostic to rgb being passed in
    • do pixel shuffle upsamples for unet
  • get a code review for the multi-scale inputs and outputs, as the paper was a bit vague

  • add upsampling network architecture

  • make unconditional work for both base generator and upsampler

  • make text conditioned training work for both base and upsampler

  • make recon more efficient by random sampling patches

  • make sure generator and discriminator can also accept pre-encoded CLIP text encodings

  • do a review of the auxiliary losses

    • add contrastive loss for generator
    • add vision aided loss
    • add gradient penalty for vision aided discr - make optional
    • add matching awareness loss - figure out if rotating text conditions by one is good enough for mismatching (without drawing an additional batch from dataloader)
    • make sure gradient accumulation works with matching aware loss
    • matching awareness loss runs and is stable
    • vision aided trains
  • add some differentiable augmentations, proven technique from the old GAN days

    • remove any magic being done with automatic rgbs processing, and have it explicitly passed in - offer functions on the discriminator that can process real images into the right multi-scales
    • add horizontal flip for starters
  • move all modulation projections into the adaptive conv2d class

  • add accelerate

    • works single machine
    • works for mixed precision (make sure gradient penalty is scaled correctly), take care of manual scaler saving and reloading, borrow from imagen-pytorch
    • make sure it works multi-GPU for one machine
    • have someone else try multiple machines
  • clip should be optional for all modules, and managed by GigaGAN, with text -> text embeds processed once

  • add ability to select a random subset from multiscale dimension, for efficiency

  • port over CLI from lightweight|stylegan2-pytorch

  • hook up laion dataset for text-image
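The "rotate text conditions by one" idea from the matching-aware loss item can be sketched as below: rolling the text embeddings along the batch dimension pairs each image with another sample's text, producing mismatched pairs without drawing an extra batch. `mismatched_text` is a hypothetical helper, not part of the package. Note that with batch size 1 the rotation is a no-op, and duplicate captions within a batch would still yield matching pairs.

```python
import torch

def mismatched_text(text_embeds: torch.Tensor) -> torch.Tensor:
    """Rotate a (batch, ...) tensor of text embeddings by one position along the batch."""
    # row i of the result is row i - 1 of the input; row 0 receives the last row
    return torch.roll(text_embeds, shifts=1, dims=0)

text_embeds = torch.arange(6.0).reshape(3, 2)   # three toy "captions"
wrong_text = mismatched_text(text_embeds)
```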

Citations

@misc{https://doi.org/10.48550/arxiv.2303.05511,
    url     = {https://arxiv.org/abs/2303.05511},
    author  = {Kang, Minguk and Zhu, Jun-Yan and Zhang, Richard and Park, Jaesik and Shechtman, Eli and Paris, Sylvain and Park, Taesung},  
    title   = {Scaling up GANs for Text-to-Image Synthesis},
    publisher = {arXiv},
    year    = {2023},
    copyright = {arXiv.org perpetual, non-exclusive license}
}
@article{Liu2021TowardsFA,
    title   = {Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis},
    author  = {Bingchen Liu and Yizhe Zhu and Kunpeng Song and A. Elgammal},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2101.04775}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@inproceedings{Karras2020ada,
    title     = {Training Generative Adversarial Networks with Limited Data},
    author    = {Tero Karras and Miika Aittala and Janne Hellsten and Samuli Laine and Jaakko Lehtinen and Timo Aila},
    booktitle = {Proc. NeurIPS},
    year      = {2020}
}

gigagan-pytorch's People

Contributors

cerebralseed, lucidrains, nbardy


gigagan-pytorch's Issues

Turn on/off gradients computation between generator/discriminator

Hi! I was reviewing some parts of your implementation (great as usual!) and noticed that you don't turn gradient computation off for models when it's not required. For example, when computing the discriminator losses we don't need gradients in the generator, and vice versa. I believe this should save some memory?
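A minimal sketch of the suggestion, using plain `nn.Linear` stand-ins for the generator and discriminator (`set_requires_grad` is a hypothetical helper). Freezing G during the discriminator step means no autograd graph is recorded through G at all. Freezing D during the generator step only skips D's parameter gradients, since D's activations are still needed to backpropagate into G, so the saving there is smaller.

```python
import torch
from torch import nn

def set_requires_grad(model: nn.Module, flag: bool) -> None:
    """Enable or disable gradient tracking for every parameter of `model`."""
    for param in model.parameters():
        param.requires_grad_(flag)

torch.manual_seed(0)
G = nn.Linear(8, 8)   # stand-in generator
D = nn.Linear(8, 1)   # stand-in discriminator
z = torch.randn(4, 8)

# discriminator step: freeze G so no graph is built through it
set_requires_grad(G, False)
set_requires_grad(D, True)
D(G(z)).mean().backward()
g_grad_after_d_step = G.weight.grad      # stays None because G was frozen

# generator step: freeze D's parameters; gradients still flow through D into G
D.zero_grad(set_to_none=True)
set_requires_grad(G, True)
set_requires_grad(D, False)
(-D(G(z)).mean()).backward()
d_grad_after_g_step = D.weight.grad      # stays None because D was frozen
```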

Some minor questions regarding the design

Thanks for updating the implementation frequently! The codebase looks much nicer than when I first looked at it. I took a closer look at the details and would like to ask some questions about specific design choices. Of course, I understand that not every design decision comes with a reason, but I would like to know if you have a reference or intuition for these:

  1. It seems like in the main generator, self attention comes before cross attention, while in the upsampler, cross attention comes before self attention.
  2. This line has a residual connection, but it is already done inside the Transformer class. Same here. Is this something new in transformer literature?
  3. Here in the discriminator, the residual is added after attention block. Does it make more sense to add it right after two conv blocks, since the attention block has its own residual connection?
  4. Very tiny issue. The definition in this is unused.

Weights of Gigagan Upscaler

Hi @lucidrains, have the weights of GigaGAN been released, especially for the upscaler?

If the weights are not available, which dataset should I train the upscaler on? It is not mentioned in the README file, and neither is the image format the dataloader expects.

L2 attention is implemented wrong!

From paper
https://arxiv.org/pdf/2006.04710.pdf

First, each token needs to attend to itself to ensure Lipschitz continuity!


Second, torch.cdist is not the correct way to do it. Follow the original paper.


for tied qk

AB = torch.matmul(qk, qk.transpose(-1, -2))
AA = torch.sum(qk ** 2, -1, keepdim=True)
BB = AA.transpose(-1, -2)    # Since query and key are tied.
attn = -(AA - 2 * AB + BB)
attn = attn.mul(self.scale).softmax(-1)

for separate qk

AB = torch.matmul(q, k.transpose(-1, -2))
AA = torch.sum(q ** 2, -1, keepdim=True)
BB = torch.sum(k ** 2, -1, keepdim=True).transpose(-1, -2)
attn = -(AA - 2 * AB + BB)
attn = attn.mul(self.scale).softmax(-1)

This is basically torch.cdist().square(), but more efficient and supports double backward for r1 regularization.

Last, I believe the paper only used L2 self-attention in the discriminator. The generator should still use dot-product attention.
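A self-contained version of the expansion in the issue, assuming inputs shaped `(..., tokens, dim)` and a caller-supplied `scale` (the helper name `l2_attention_logits` is hypothetical). It builds -||q - k||^2 from matmuls and sums only, so it matches `-torch.cdist(q, k).square()` while staying twice differentiable, which the issue notes is needed for R1 regularization. With tied queries and keys, each token's logit to itself is 0, the largest possible value.

```python
import torch

def l2_attention_logits(q: torch.Tensor, k: torch.Tensor, scale: float) -> torch.Tensor:
    """Scaled negative squared L2 distance between every query and every key.

    Uses ||q - k||^2 = ||q||^2 - 2 q.k + ||k||^2, built only from sums and matmuls.
    q: (..., n, d), k: (..., m, d) -> logits: (..., n, m)
    """
    qq = q.pow(2).sum(dim=-1, keepdim=True)                    # (..., n, 1)
    kk = k.pow(2).sum(dim=-1, keepdim=True).transpose(-1, -2)  # (..., 1, m)
    qk = q @ k.transpose(-1, -2)                                # (..., n, m)
    return -(qq - 2 * qk + kk) * scale

# usage: tied queries/keys (one shared projection), as in the issue's first snippet
torch.manual_seed(0)
qk = torch.randn(2, 5, 8, dtype=torch.float64)
attn = l2_attention_logits(qk, qk, scale=8 ** -0.5).softmax(dim=-1)
```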

Inserting the attention mechanism from this code into StyleGAN2-ADA fails

Hello, first of all, thank you for your great work.
I inserted the self-attention and cross-attention modules into StyleGAN2-ADA, placing them after the conv1 SynthesisLayer at the 4, 8, 16 and 32 resolutions. I use CLIP to extract text features and compute the contrastive loss. However, after a few iterations, the resulting images all turned completely the same color. Could you please give me some advice on how to fix this?

License Violation Issue

Hello Phil,

Thanks for the work you do. Someone else has been copying your code without respecting your MIT license and including your name, or any reference to you, whatsoever. Please see here:

https://github.com/jianzhnie/GigaGAN

His code is updated to be nearly identical to yours within a day of each of your updates. As you can see, I opened an issue on his project page about the licensing, which he closed without updating the license or making any mention of you.

Here is how to request a DMCA page takedown in the instance of license violations:
https://docs.github.com/en/site-policy/content-removal-policies/guide-to-submitting-a-dmca-takedown-notice

Just wanted to bring this to your attention.

Cheers

Confused about this project?

Is your project a functional GigaGAN model that any user can use? If so, I'm confused about the steps for actually using the model, and about what it generates or how it upscales. If it's an actual working implementation of GigaGAN, that would be dope.

Gradient Penalty is very high in the start

Hi!

I was running a few experiments and noticed that GP is extremely high in the first few hundred steps: GP > 60000, then gradually going down to around GP = 20.

Is this normal behaviour? In my previous experience with StyleGAN, GP was small at the beginning.

Pretrained models

Thanks for the great work!
May I know when the pretrained models will be available?
Thanks.

NaN losses after hours of training (UPSAMPLER)

I keep getting NaN losses like this after a few hours of training:
| 1360/1000000 [7:25:23<5127:52:45, 18.49s/it]G: 0.75 | MSG: -0.17 | VG: 0.00 | D: nan | MSD: 1.97 | VD: 0.00 | GP: nan | SSL: nan | CL: 0.00 | MAL: 0.00
0%|▎ | 1380/1000000 [7:31:46<5217:53:00, 18.81s/it]G: 0.76 | MSG: -0.17 | VG: 0.00 | D: nan | MSD: 1.97 | VD: 0.00 | GP: nan | SSL: nan | CL: 0.00 | MAL: 0.00
I'm training on about 200k images with the settings from the README.

gan = GigaGAN(
    train_upsampler = True,
    generator = dict(
        style_network = dict(
            dim = 64,
            depth = 4
        ),
        dim = 32,
        image_size = 256,
        input_image_size = 64,
        unconditional = True
    ),
    discriminator = dict(
        dim_capacity = 16,
        dim_max = 512,
        image_size = 256,
        num_skip_layers_excite = 4,
        multiscale_input_resolutions = (128,),
        unconditional = True
    ),
    amp = True
).cuda()

Here's an image from before the NaN loss:
[image: sample-0]
Here's an image after the loss became NaN:
[image: sample-1]

My current batch size is 20 btw.
What should I do? Did you ever manage to train successfully with the provided settings?
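A common defensive pattern for runs that blow up to NaN is to skip any update whose loss or gradient norm is non-finite, and to clip gradients. The sketch below is generic PyTorch, not the repository's training loop; `step_if_finite` and the `max_grad_norm` default are assumptions. It hides the symptom rather than fixing the cause, so it is mostly useful for logging how often it fires while other changes (for example running with `amp = False`, or a lower learning rate) are tried.

```python
import torch
from torch import nn

def step_if_finite(loss: torch.Tensor, model: nn.Module, optimizer: torch.optim.Optimizer,
                   max_grad_norm: float = 1.0) -> bool:
    """Backpropagate and step only if the loss and the gradient norm are finite.

    Returns True when the optimizer stepped, False when the update was skipped.
    """
    optimizer.zero_grad(set_to_none=True)
    if not torch.isfinite(loss).all():
        return False
    loss.backward()
    # clip_grad_norm_ returns the total gradient norm measured before clipping
    total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    if not torch.isfinite(total_norm):
        optimizer.zero_grad(set_to_none=True)   # discard the bad gradients
        return False
    optimizer.step()
    return True

# usage on a toy model
torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 4)
stepped_ok = step_if_finite(model(x).pow(2).mean(), model, optimizer)
weight_before = model.weight.detach().clone()
stepped_nan = step_if_finite(model(x).mean() * float("nan"), model, optimizer)
```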

Multi GPU training

Hi, I ran into some problems while trying to launch training on multi gpu (one machine):

  • First, there were problems with DistributedDataParallel missing some fields in gigagan_pytorch.py: self.D.multiscale_input_resolutions in line 2159 and self.G.input_image_size in line 2095. As a quick fix, I added .module before referencing each field and moved on.
  • Then I got a following error: RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed.You can make a clone to get a normal tensor before doing inplace update originating from gigagan_pytorch.py line 2514 (2479). I tried to clone the generator output for a quick fix (as well as sample function output), but it didn't help.
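The `.module` workaround from the first bullet can be centralized in a small helper. This is a sketch (`unwrap` is a hypothetical name); when Accelerate is in use, its `accelerator.unwrap_model(model)` serves the same purpose.

```python
from torch import nn

def unwrap(model: nn.Module) -> nn.Module:
    """Return the wrapped network from (Distributed)DataParallel, or the model itself."""
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    # both wrappers store the original network in `.module`
    return model.module if isinstance(model, wrappers) else model

net = nn.Linear(2, 2)
wrapped = nn.DataParallel(net)   # stand-in wrapper; DDP needs an initialized process group
```

With such a helper, the failing accesses could read `unwrap(self.D).multiscale_input_resolutions` whether or not the model is wrapped.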

This is my first time using accelerate, so I have no prior experience with it; running on a single GPU, everything works fine.
Since multi-GPU training is marked as done in your todo list, I suppose the issue is on my side, but maybe you have some idea what causes these problems? I use torch 2.0.1+cu117 and accelerate 0.22.0.

Anyway, thanks for the great work!
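On the second error: tensors created under `torch.inference_mode()` cannot be modified in place outside it, and a clone only becomes a normal tensor if it is made outside inference mode. One possible (unconfirmed) reason the clone did not help is that it happened inside the inference-mode context. A minimal demonstration:

```python
import torch

with torch.inference_mode():
    sample = torch.zeros(2, 3)            # e.g. images produced while sampling

# An in-place op on `sample` out here (sample.add_(1)) raises the RuntimeError above.

# Cloning outside inference mode yields a normal tensor that allows in-place updates.
normal = sample.clone()
normal.add_(1)

# Cloning inside inference mode still yields an inference tensor.
with torch.inference_mode():
    still_inference = sample.clone()
```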

Questions regarding the discriminator

Starting a new issue for better visibility to other people. I have some quick questions regarding the implementation of discriminator. The argument to the forward function (images, rgb, real_images) seems to be a bit confusing.

Not sure my understanding is correct. In a training iteration of a general GAN, the discriminator should be called twice, once for a batch of real images and once for a batch of generated images. We could concatenate real and fake images to call the network once, but I think this is not done here, as the output logit has only 1 scalar rather than 2.

So I would like to know the expected arguments for the real pass and the fake pass. I guess for the fake pass, 'rgbs' is the collection of RGB outputs at different resolutions from the generator, and images should be the final output of the generator (the generated images at the highest resolution); for the real pass, images is the real image and 'rgbs' should be None. The real_images argument can always be None.

I may be very wrong, so I would really appreciate it if you could correct me and explain.

Two additional questions:

  1. This line looks tricky: doesn't self.resize_image_to(rgb, rgb.shape[-1]) just return rgb, so it basically concatenates two copies of rgb together?
  2. Just want to confirm that the current implementation does not support the multi-scale loss yet, as that requires modules to process (H, W, 3) images at multiple resolutions.
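On the first question, running D once on a concatenated batch and splitting the logits afterwards is equivalent to two separate passes, provided D has no cross-sample layers such as BatchNorm. A generic sketch follows; it does not use this repository's discriminator signature, and the helper name is hypothetical.

```python
import torch
from torch import nn

def discriminate_real_and_fake(discriminator: nn.Module, real: torch.Tensor, fake: torch.Tensor):
    """One forward pass over real and fake images; returns (real_logits, fake_logits)."""
    logits = discriminator(torch.cat((real, fake), dim=0))
    # split the batch back into the real part and the fake part
    return logits.split([real.shape[0], fake.shape[0]], dim=0)

torch.manual_seed(0)
toy_d = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 1))   # no cross-sample layers
real = torch.randn(2, 3, 4, 4)
fake = torch.randn(3, 3, 4, 4)
real_logits, fake_logits = discriminate_real_and_fake(toy_d, real, fake)
```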

fidelity

In your tests, have you tried using this on video? Face upscaling, especially in videos, is lacking, and I am wondering whether you have tested it and seen much temporal inconsistency.

Hardware Requirements

I'd like to suggest adding the minimum hardware requirements for training to the README.

I am new to this field and have only recently started studying these technologies, so I would also like to know these requirements, because I recently purchased a fairly simple machine to start my studies.

Is this project able to run on my weak machine?

Core i9 12900HK Interposed @ 5.0GHz
DDR4 64GB RAM
1 x RTX 4090 24GB GDDR6X

I know some generative models that run well on this hardware, but some text models like Llama require much more RAM/VRAM to run.

I'm planning to acquire more GPUs in the future, but I think this modest hardware will work well to start my study journey.

So, what can you share about which hardware works well for training and inference?
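As a rough way to reason about training requirements, the memory needed for fp32 weights, their gradients, and Adam's two moment buffers can be computed from the parameter count. The helper below is a hypothetical back-of-the-envelope sketch: it ignores activations, which often dominate for high-resolution image GANs, and mixed precision changes the bytes per parameter. For a GAN, it would be applied to both the generator and the discriminator and the results added.

```python
from torch import nn

def estimate_training_memory_gib(model: nn.Module, bytes_per_param: int = 4,
                                 optimizer_buffers: int = 2) -> float:
    """Back-of-the-envelope memory for weights + gradients + optimizer state, in GiB.

    Activations are NOT included. bytes_per_param=4 assumes fp32; Adam keeps 2 buffers.
    """
    n_params = sum(p.numel() for p in model.parameters())
    copies = 1 + 1 + optimizer_buffers   # weights, gradients, optimizer buffers
    return n_params * bytes_per_param * copies / 1024 ** 3

# a layer with exactly 2**20 parameters (1023 * 1024 weights + 1024 biases)
layer = nn.Linear(1023, 1024)
estimate = estimate_training_memory_gib(layer)   # 2**20 params * 4 bytes * 4 copies = 16 MiB
```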

Training plans?

I've got a lot of compute for the next couple of weeks and am thinking of training this on LAION.

Wondering if there is any other training going on right now. Would hate to duplicate efforts too much.

Generator/Discriminator dim Argument

It seems that the current implementation of the Generator/Discriminator does not make any use of the primary dim argument. It looks like capacity is the main "go to" for altering the overall model size on the z dimension.

Perhaps dim could be removed, or the term capacity could be replaced with dim.

How can I use this?

I'm sorry for leaving such an annoying issue, but how can I use this as a package for txt2img/upscaling?

Question about discriminator

Thank you for your great project! I have a small problem. Can the discriminator accept multi-scale input? I notice that the discriminator's forward() only accepts one input, not multi-scale inputs.
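One way to think about multi-scale inputs is as an image pyramid: the same real batch downsampled to each resolution listed in a setting like `multiscale_input_resolutions = (128,)` from the README. The helper below is a hypothetical sketch of producing such inputs with `F.interpolate`; it is not the discriminator's actual preprocessing.

```python
from typing import List, Sequence

import torch
import torch.nn.functional as F

def image_pyramid(images: torch.Tensor, resolutions: Sequence[int]) -> List[torch.Tensor]:
    """Resize a (batch, channels, h, w) batch to each square size in `resolutions`."""
    return [
        # antialias=True low-pass filters while downsampling to reduce aliasing
        F.interpolate(images, size=(res, res), mode="bilinear", align_corners=False, antialias=True)
        for res in resolutions
    ]

real = torch.rand(2, 3, 256, 256)
multiscale_real = image_pyramid(real, (128, 64))
```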

Some doubts about the multi-scale input and output of the super-resolution part

Hello, thanks for your code. While trying it out, I found that the resolutions of "multiscale_input_resolutions" and "rgbs" are mismatched. My understanding of the paper is that the multi-scale inputs are smaller than the original image. In your code, however, you keep sizes larger than the original image, presumably to preserve more features. I'm confused about this and hope for your reply.

What makes the gan upsampler so effective

I find the results of the GAN upsampler quite impressive compared to the SD upscaler and Real-ESRGAN. However, in my personal experience, training an SR model on paired LR-HR data in a supervised manner can make it challenging to generate realistic textures in the HR image, especially compared to purely generative models. This often results in Real-ESRGAN-like outcomes. What, in your opinion, is the main concept that makes the GAN upsampler so effective?

Question about the timing and complexity of GigaGAN replication

How long would it take to do a full replication?
And how workable will it be in terms of what is presented in the official article?
As I understand it, you are using existing technology and adapting something as needed?

I am overwhelmed by your enthusiasm and desire to make available what is not publicly available to all.

It deserves a lot of respect, you are making a huge contribution to the evolution of new technologies by making them openly available, overcoming many questions and challenges!

Is this project ready to train?

Is this ready to start training?

I am very curious about this project: how can you implement things in it without having weights to check results? How can you make sure the implementation is correct?

I am just beginning my studies, but I have access to great hardware for training. I don't know if it will be enough, but if it is, I may be able to help train the first weights.

[Question] About the upscaler

Hello lucidrains,

Thanks for your code!

I was wondering, is it possible to upscale an image from 1024px to 4K px with GigaGAN? (Even if it takes a lot of memory and GPU/CPU)

In your code, it seems to upscale from 256px to 512px?

If it is possible to upscale from 1024px, which values should I change in your code?

gan = GigaGAN(
    train_upsampler = True,     # set this to True
    generator = dict(
        style_network = dict(
            dim = 64,
            depth = 4
        ),
        dim = 32,
        image_size = 256,
        input_image_size = 64,
        unconditional = True
    ),
    discriminator = dict(
        dim_capacity = 16,
        dim_max = 512,
        image_size = 256,
        num_skip_layers_excite = 4,
        multiscale_input_resolutions = (128,),
        unconditional = True
    ),
    amp = True
).cuda()

dataset = ImageDataset(
    folder = '/path/to/your/data',
    image_size = 256
)

Should I change the dim, multiscale_input_resolutions? I'm a bit lost. :-)

Regards,

Sparfell
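Some arithmetic that may help frame the question: the README's upsampler config maps `input_image_size = 64` to `image_size = 256`, a 4x factor, which is the same factor as going from 1024px to 4096px. The hypothetical helper below computes the scale factor and, assuming the upsampler is built from successive 2x stages, how many such stages are needed. Whether the package handles 1024px to 4096px directly, and how much memory that would take, is not confirmed here.

```python
import math

def upsample_plan(input_size: int, output_size: int):
    """Return (scale_factor, num_2x_stages) for a square upsampler.

    num_2x_stages assumes the network is built from successive 2x upsamples.
    """
    if output_size % input_size:
        raise ValueError("output_size must be a multiple of input_size")
    scale = output_size // input_size
    stages = math.log2(scale)
    if not stages.is_integer():
        raise ValueError("scale factor must be a power of two")
    return scale, int(stages)

readme_plan = upsample_plan(64, 256)      # the README's upsampler example
target_plan = upsample_plan(1024, 4096)   # the size asked about here
```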

Possible Discrepancies

Hi Phil,

I noted a few possible discrepancies, and I have made some suggestions as well.

You may want to consider:

  1. Turning off demodulation for ToRGB
    Demodulation is turned on in ToRGB: in the current implementation, ToRGB uses demod=True. This makes it a lot harder for the generator to get the correct feature magnitude at each level of the progressive growing pathway.

  2. Removing sample adaptive kernels in the ToRGB component
    I don't believe that SAKS is used by GigaGAN in the ToRGB section. The caption of Figure 4 states "The style code modulates the main generator using our style-adaptive kernel selection, shown on the right". The main generator is the upsampling pathway consisting of style-based feature modulation, not the progressive growing section. The output of the ToRGB section is only three channels, so adding more kernels probably won't improve representational capacity but may lead to oscillations.
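For context on the first suggestion, here is a StyleGAN2-style modulated convolution with an optional demodulation step. It is a sketch: `modulated_conv2d` here is a standalone function, not this repository's class. With `demod=True`, any global rescaling of the style is normalized away, so the layer cannot control its output magnitude through the style; StyleGAN2's official ToRGB layer likewise uses `demodulate=False`.

```python
import torch
import torch.nn.functional as F

def modulated_conv2d(x: torch.Tensor, weight: torch.Tensor, style: torch.Tensor,
                     demod: bool = True, eps: float = 1e-8) -> torch.Tensor:
    """StyleGAN2-style modulated convolution (odd kernel size, 'same' padding).

    x: (b, in_ch, h, w), weight: (out_ch, in_ch, k, k), style: (b, in_ch)
    """
    b, in_ch, height, width = x.shape
    out_ch, _, k, _ = weight.shape
    # modulate: scale the kernel's input channels per sample -> (b, out_ch, in_ch, k, k)
    kernels = weight[None] * style[:, None, :, None, None]
    if demod:
        # demodulate: rescale every output filter of every sample to unit L2 norm
        kernels = kernels * torch.rsqrt(kernels.pow(2).sum(dim=(2, 3, 4), keepdim=True) + eps)
    # grouped-conv trick: fold the batch into channels so each sample uses its own kernels
    x = x.reshape(1, b * in_ch, height, width)
    kernels = kernels.reshape(b * out_ch, in_ch, k, k)
    out = F.conv2d(x, kernels, padding=k // 2, groups=b)
    return out.reshape(b, out_ch, height, width)

torch.manual_seed(0)
x = torch.randn(2, 4, 5, 5, dtype=torch.float64)
weight = torch.randn(3, 4, 3, 3, dtype=torch.float64)
style = torch.rand(2, 4, dtype=torch.float64) + 0.5
with_demod = modulated_conv2d(x, weight, style, demod=True)
with_demod_scaled = modulated_conv2d(x, weight, style * 10, demod=True)       # unchanged output
without_demod_scaled = modulated_conv2d(x, weight, style * 10, demod=False)   # 10x larger output
```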

How far is this codebase?

Thanks a lot for starting this effort. I had a quick look at the code, and it seems that although the GAN network structure is done, there is no training code or reference to data. Can you please explain? Many thanks.

Multi GPU with gradient accumulation

Hi! While training on multiple GPUs with gradient accumulation steps > 1, there's no substantial speedup compared to a single GPU (there is a speedup if the value is equal to 1). I found the following threads on Hugging Face, here and here, that seem to provide a solution. I even ran a dummy test by just adding the proper argument to Accelerator, and the training was actually much faster (in your class I set the gradient accumulation steps to 1, but for Accelerator to 8; I didn't make any other changes to account for this, so the results weren't particularly useful 😉). If you have time to check whether this is interesting for you, I'd be grateful.
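The sketch below shows, in plain PyTorch, the usual mechanism behind this kind of fix, assuming the model may be wrapped in `DistributedDataParallel`: on every micro-batch except the last, forward and backward run inside `model.no_sync()`, so the gradient all-reduce happens once per optimizer step instead of once per micro-batch. Accelerate packages the same behaviour as `Accelerator(gradient_accumulation_steps=...)` together with the `accelerator.accumulate(model)` context manager. `accumulate_gradients` is a hypothetical helper, and the loss scaling assumes equal-size micro-batches.

```python
import contextlib

import torch
from torch import nn

def accumulate_gradients(model: nn.Module, loss_fn, micro_batches) -> None:
    """Accumulate gradients over micro-batches so they match one step on the full batch."""
    n = len(micro_batches)
    for i, (inputs, targets) in enumerate(micro_batches):
        is_last = i == n - 1
        # DistributedDataParallel.no_sync() defers the gradient all-reduce; only the
        # last micro-batch synchronizes. Unwrapped modules get a no-op context.
        if not is_last and hasattr(model, "no_sync"):
            ctx = model.no_sync()
        else:
            ctx = contextlib.nullcontext()
        with ctx:
            # divide by n so the summed gradients equal those of the mean loss
            loss = loss_fn(model(inputs), targets) / n
            loss.backward()
```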

How to use this model for SR ?

Hi,

Would you mind providing an example of how to use this model for SR? What would the text input be? I assume that, since CLIP is used, rather than a text input we use an image and its embedding (since they are in a joint space)?
