
maskgit's Introduction

MaskGIT: Masked Generative Image Transformer

Official Jax Implementation of the CVPR 2022 Paper


[Paper] [Project Page] [Demo Colab]


Summary

MaskGIT is a novel image synthesis paradigm using a bidirectional transformer decoder. During training, MaskGIT learns to predict randomly masked tokens by attending to tokens in all directions. At inference time, the model begins by generating all tokens of an image simultaneously, and then refines the image iteratively, conditioned on the previous generation.
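The iterative decoding loop can be sketched as follows. This is a simplified NumPy illustration of the scheme described in the paper, not the repository's implementation: `predict_fn`, the `MASK` sentinel, and the omission of temperature/Gumbel-noise sampling on the confidences are all simplifications.

```python
import numpy as np

MASK = -1  # hypothetical mask-token id, stands in for the real mask token

def cosine_schedule(t, T):
    # Fraction of tokens still masked after step t (cosine schedule, as in the paper).
    return np.cos(np.pi / 2 * (t + 1) / T)

def parallel_decode(predict_fn, num_tokens, T=8):
    """Iterative parallel decoding: predict all tokens at once, keep the most
    confident predictions, and re-mask the rest according to the schedule.

    predict_fn(tokens) -> (pred_ids, confidences), each of shape [num_tokens].
    """
    tokens = np.full(num_tokens, MASK)
    for t in range(T):
        pred_ids, conf = predict_fn(tokens)
        # Already-fixed tokens are never reconsidered.
        conf = np.where(tokens == MASK, conf, np.inf)
        n_mask = int(np.floor(cosine_schedule(t, T) * num_tokens))
        if n_mask == 0:
            # Last step: accept predictions at every remaining masked position.
            return np.where(tokens == MASK, pred_ids, tokens)
        # Re-mask the n_mask least confident predictions, accept the rest.
        accept = np.ones(num_tokens, dtype=bool)
        accept[np.argsort(conf)[:n_mask]] = False
        tokens = np.where((tokens == MASK) & accept, pred_ids, tokens)
    return tokens
```

With the cosine schedule, the number of re-masked tokens shrinks each step, so the full token grid is resolved in `T` forward passes instead of one pass per token as in autoregressive decoding.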

Running pretrained models

Class-conditional image generation models:

| Dataset  | Resolution | Model               | Link       | FID                   |
|----------|------------|---------------------|------------|-----------------------|
| ImageNet | 256 × 256  | Tokenizer           | checkpoint | 2.28 (reconstruction) |
| ImageNet | 512 × 512  | Tokenizer           | checkpoint | 1.97 (reconstruction) |
| ImageNet | 256 × 256  | MaskGIT Transformer | checkpoint | 6.06 (generation)     |
| ImageNet | 512 × 512  | MaskGIT Transformer | checkpoint | 7.32 (generation)     |

You can run these models for class-conditional image generation and editing in the demo Colab.


Training

[Coming Soon]

BibTeX

@InProceedings{chang2022maskgit,
  title     = {MaskGIT: Masked Generative Image Transformer},
  author    = {Huiwen Chang and Han Zhang and Lu Jiang and Ce Liu and William T. Freeman},
  booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  month     = {June},
  year      = {2022}
}

Disclaimer

This is not an officially supported Google product.

maskgit's People

Contributors

huiwenchang · thesouthfrog


maskgit's Issues

Question about the ResBlock

Hello, thank you for sharing this code, it is very useful!

I am wondering if there is a bug in the ResBlock in vqgan_tokenizer. It currently is like this:

    if input_dim != self.filters:
      if self.use_conv_shortcut:
        residual = self.conv_fn(
            self.filters, kernel_size=(3, 3), use_bias=False)(
                x)
      else:
        residual = self.conv_fn(
            self.filters, kernel_size=(1, 1), use_bias=False)(
                x)
    return x + residual

But this is using the same x coming from the previous convolutions, so should it be like this instead?

    if input_dim != self.filters:
      if self.use_conv_shortcut:
        residual = self.conv_fn(
            self.filters, kernel_size=(3, 3), use_bias=False)(
                residual)
      else:
        residual = self.conv_fn(
            self.filters, kernel_size=(1, 1), use_bias=False)(
                residual)
    return x + residual
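For reference, the intended pattern can be reduced to a functional sketch. This is an illustrative NumPy stand-in, not the repository's module: `conv_main` and `conv_proj` are hypothetical placeholders for the block's convolution stack and its 1×1/3×3 shortcut convolution. The key point matches the proposed fix above: the shortcut projection should consume the saved block input, not the output of the main branch.

```python
import numpy as np

def res_block(x, filters, conv_main, conv_proj):
    """Functional sketch of the residual pattern under discussion.

    conv_main: stand-in for the main branch (norm/conv stack, input_dim -> filters).
    conv_proj: stand-in for the shortcut convolution (input_dim -> filters).
    """
    residual = x                 # save the block *input* before the main branch runs
    h = conv_main(x)             # main branch output with `filters` channels
    if x.shape[-1] != filters:
        # Project the saved input, not the main-branch output.
        residual = conv_proj(residual)
    return h + residual
```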

TypeError: take_along_axis indices must be of integer type, got float32

Hey,

I was trying to run the untouched notebook in an environment with jax 0.3.13 and jaxlib 0.3.10 on an Ubuntu 18.04 machine with CUDA 11.7 and cuDNN 8.2, but I get the error

TypeError: take_along_axis indices must be of integer type, got float32

when running:

    elif run_mode == 'pmap':
        sample_rngs = jax.random.split(sample_rng, jax.local_device_count())
        results = p_generate_256_samples(pmap_input_tokens, sample_rngs)

Any help?
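This error means the index array handed to `take_along_axis` is floating point; both NumPy and `jnp` require integer indices, so an explicit cast fixes it. (It is also worth checking that the jax and jaxlib versions are a tested pair; 0.3.13 and 0.3.10 differ.) A minimal NumPy reproduction and fix; the notebook's internals may differ, this only illustrates the dtype issue:

```python
import numpy as np

logits = np.random.default_rng(0).random((2, 4, 8)).astype(np.float32)
idx = np.argmax(logits, axis=-1)[..., None]       # argmax already yields ints

# Simulate the failure: indices that arrived as float32
idx_float = idx.astype(np.float32)
try:
    np.take_along_axis(logits, idx_float, axis=-1)  # rejected: not integer dtype
except IndexError:
    pass

# The fix: cast the indices back to an integer dtype before gathering
vals = np.take_along_axis(logits, idx_float.astype(np.int32), axis=-1)
```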

Image size should be fixed?

Am I right that with MaskGIT the input image size must be fixed across iterations, because the bidirectional transformer requires a fixed-length input?
As far as I know, autoregressive models have no such limit on length or size.
Is that correct?

License of pretrained models

Hi, cool work!

I know that the license of the source code is Apache 2.0.

I am aware that Google does not own the rights to the dataset (ImageNet), but it would be good to have an explicit license on the pretrained models as well, as far as Google's and the repo owners' rights are concerned (although I am not arguing that pretrained models are copyrightable).

If the intention is for the Apache 2.0 license to apply to the pretrained models (which are hosted elsewhere) as well, then I suggest adding a "License" section to the README file clarifying this.

Please release training code

Thank you for the wonderful work; I found it very interesting. When do you plan to release the training code?

Real size output images

Hi, many thanks for your code.
I'm using the Colab notebook; is it possible to get the result images at full size?

Once this line is executed:

    visualize_images(composite_images, title=f'outputs')

I get images, but at a very small size. How can I get them at 512 × 512?

Many thanks
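The matplotlib grid drawn by `visualize_images` is only a downscaled preview; the arrays themselves should already be full resolution. A hedged sketch for saving them at native size, assuming `composite_images` is a float array in [0, 1] with shape [N, H, W, 3] (the helper name and array layout are taken from the snippet above; verify the shape in your notebook first):

```python
import numpy as np
from PIL import Image

def save_full_size(images, prefix="output"):
    """Save each image in an [N, H, W, 3] float array as a PNG at native resolution."""
    images = np.clip(np.asarray(images), 0.0, 1.0)
    for i, img in enumerate(images):
        # Convert [0, 1] floats to 8-bit RGB and write one file per image.
        Image.fromarray((img * 255).astype(np.uint8)).save(f"{prefix}_{i}.png")
```

For example, `save_full_size(composite_images, prefix="maskgit")` would write `maskgit_0.png`, `maskgit_1.png`, … at whatever resolution the model produced.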

Training code releasing

Hi! Thank you for your great work! May I ask when you are planning to release the full training code for this project?

Reproducing the results of "Stage 1" model

Hello. Thanks for the awesome project!

I want to reproduce the checkpoint of the Stage 1 model, which has rFID < 2.5, since that rFID is much lower than the original VQGAN's (rFID = 4.7).

Could you share the detailed differences in implementation and training relative to the original VQGAN?
I think this is important for fair comparison with other works and for reproducibility.

Crash with pmap on high RAM V100

Hi there! First off, thanks so much for publishing this code!

This issue may just amount to my GPU not having enough memory, but I thought I'd share it since the Colab mentions that pmap should work with V100s. I am running it on Colab Pro+ with the High-RAM runtime shape, but when I get to this line:

results = p_generate_256_samples(pmap_input_tokens, sample_rngs)

I get the following error (I've tried a few times):

---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-9-f2eb36cfcc0b> in <module>()
     17     sample_rngs = jax.random.split(sample_rng, jax.local_device_count())
---> 18     results = p_generate_256_samples(pmap_input_tokens, sample_rngs)
     19 

10 frames
UnfilteredStackTrace: RuntimeError: UNKNOWN: CUDNN_STATUS_NOT_SUPPORTED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4839): 'status'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-9-f2eb36cfcc0b> in <module>()
     16 elif run_mode == 'pmap':
     17     sample_rngs = jax.random.split(sample_rng, jax.local_device_count())
---> 18     results = p_generate_256_samples(pmap_input_tokens, sample_rngs)
     19 
     20     # flatten the pmap results

RuntimeError: UNKNOWN: CUDNN_STATUS_NOT_SUPPORTED
in external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc(4839): 'status'

It may also be notable that the function runs for a long time (20+ seconds) before throwing that error despite the fact that pmap should make it fast. Since I get an actual out-of-memory crash here on other GPUs with less RAM, I figured there may be a chance that this is a different issue.

Thanks again for sharing the code :)
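CUDNN_STATUS_NOT_SUPPORTED can surface when a cuDNN workspace allocation fails, so it is sometimes a memory problem in disguise even without an explicit OOM. One general workaround worth trying (these are standard JAX GPU memory flags, not something the Colab itself documents) is to stop JAX from preallocating most of GPU memory, or to cap the preallocated fraction:

```python
import os

# Must be set before JAX initializes its GPU backend,
# i.e. before the first `import jax` in the notebook.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate GPU memory on demand
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8"   # cap the fraction if preallocating
```

On-demand allocation trades some performance for headroom, which can be enough to let the convolution pick a supported cuDNN algorithm; if it still fails, the run is likely genuinely over the V100's memory budget.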

Expected release of the training code

Dear authors,

Congratulations on your acceptance to CVPR; this is fantastic work.
Do you have a rough expectation for the release date of the training code?

Thanks

Training code

Hello!

Thanks for providing part of the code and the pretrained models.

I am wondering if you guys are planning on releasing the training code.

Thanks again!
