
kakaobrain / minDALL-E

630 · 14 · 65 · 45.66 MB

PyTorch implementation of a 1.3B text-to-image generation model trained on 14 million image-text pairs

License: Other

Python 100.00%

minDALL-E's People

Contributors

leedoyup, saehoonkim


minDALL-E's Issues

What training setup did you use?

This looks great!

Could you share some information on what setup you used for the training of the transformer model?

  • how many gpu / for how long
  • how many steps
  • what batch size

It would be helpful to have this information to better understand the cost of training DALL-E models.

How much VRAM is needed for this?

I was trying to run sampling_ex.py, but no matter how low I set num_candidates (even one or two), it always runs out of memory. I am using an NVIDIA Quadro M5000 with 8 GB of VRAM.
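A quick back-of-envelope calculation suggests why 8 GB is tight: the stage-2 transformer alone has roughly 1.3B parameters, so its weights in fp32 already take close to 5 GiB before counting the VQGAN, the CLIP re-ranker, and activations. A rough sketch (the sizes are assumptions, not measured values):

```python
# Back-of-envelope VRAM estimate for a 1.3B-parameter model.
# Rough assumptions, not measured numbers.
params = 1.3e9
GiB = 1024 ** 3

fp32_weights = params * 4 / GiB   # 4 bytes per parameter
fp16_weights = params * 2 / GiB   # halved when loading in half precision

print(f"fp32 weights: {fp32_weights:.1f} GiB")  # ~4.8 GiB before activations
print(f"fp16 weights: {fp16_weights:.1f} GiB")  # ~2.4 GiB
```

Loading the model in half precision would roughly halve the weight footprint, which may be enough to get sampling under 8 GB.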

CUDA out-of-memory

Hi,
It is mentioned in the "Transfer Learning Examples" section that you fine-tuned the pre-trained DALL-E on 8 V100 GPUs. I tried running your transfer_learning_ex.py script on V100 GPUs (16 GB of memory per GPU), and it throws a CUDA OOM error. Could you please share the exact specs of the hardware you used for this?
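One common workaround while waiting for the exact specs is to lower the per-GPU batch size and compensate with gradient accumulation (PyTorch Lightning exposes this as the Trainer's accumulate_grad_batches option), which keeps the effective batch size unchanged. The numbers below are hypothetical, only to illustrate the arithmetic:

```python
# Hypothetical numbers: trade per-GPU batch size for gradient accumulation
# so the effective batch size stays the same on smaller-memory cards.
per_gpu_batch = 2              # small enough to fit in 16 GB (assumption)
n_gpus = 8
accumulate_grad_batches = 4    # 2 * 4 = 8 samples per GPU per optimizer step

effective = per_gpu_batch * n_gpus * accumulate_grad_batches
print(effective)  # 64
```

The trade-off is wall-clock time: accumulation does more forward/backward passes per optimizer step, so training takes proportionally longer.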

Notebook tweaks for Google Colab

For info, on Google Colab, the provided notebook examples/sampling_interactive_demo.ipynb has to be slightly edited.

One has to:

  • toggle ON the GPU usage,
  • run the following cell at the top of the notebook:
%cd /content
!git clone https://github.com/kakaobrain/minDALL-E.git
%cd /content/minDALL-E
%pip install -q pytorch-lightning omegaconf einops tokenizers
%pip install -q git+https://github.com/openai/CLIP.git

One could instead run:

%pip install -q -r requirements.txt

However, that takes longer for no added value, since some of the packages are already preinstalled on Colab.

Project dependencies may have API risk issues

Hi. In minDALL-E, inappropriate dependency version constraints can cause risks.

Below are the dependencies and version constraints that the project is using

torch==1.8.0
torchvision>=0.8.2
tokenizers>=0.10.2
pyflakes>=2.2.0
tqdm>=4.46.0
pytorch-lightning>=1.5
einops
omegaconf
git+https://github.com/openai/CLIP.git
matplotlib

The == constraint introduces a risk of dependency conflicts because it is overly strict.
Constraints with no upper bound (or *) introduce a risk of missing-API errors, because the latest version of a dependency may remove some APIs.

After further analysis, in this project,
The version constraint of dependency tqdm can be changed to >=4.36.0,<=4.64.0.

The above suggestions reduce dependency conflicts as much as possible while staying as close as possible to the latest versions that do not raise API errors in the project.
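Applied to requirements.txt, the suggested tqdm constraint would look like the fragment below (the bounds on the other two lines are illustrative assumptions, not tested against the codebase):

```
torch>=1.8.0,<1.9
tqdm>=4.36.0,<=4.64.0
pytorch-lightning>=1.5,<1.6
```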

The project invokes all of the following methods.

Methods called from tqdm:
tqdm.tqdm.set_description
tqdm.tqdm

Methods called from all modules:
self.resid_drop
torch.cuda.manual_seed_all
PIL.Image.fromarray
PIL.Image.fromarray.save
ExpConfig
self.key
hashlib.md5
module.weight.data.normal_
self.head
pytorch_lightning.loggers.TensorBoardLogger
self.lr_schedulers.get_last_lr
text_features.image_features.F.cosine_similarity.squeeze
W.B.device.H.torch.arange.repeat.transpose
numpy.transpose
min
argparse.ArgumentParser.add_argument
self.quantize.get_codebook_entry
self.v
sorted_idx_remove_cond.scatter
self.quant_conv
RuntimeError
self.apply
ImageNetDataModule
self.sos.repeat
pytorch_lightning.Trainer.fit
torchvision.transforms.Compose
self.stage2.sos
AttnBlock
model.stage1.from_ckpt
from_file
reversed
get_positional_encoding
datetime.datetime.now
tokens.to.unsqueeze
torch.nn.functional.cosine_similarity
probs.torch.multinomial.clone
self.encode
pl_module.stage1
self.down.append
Normalize
self.mid.block_1
download
self.conv1
Downsample
z_q.permute.contiguous
self.conv
OptConfig
torch.nn.functional.pad
Stage1Hparams
self.embedding
super
w_.permute.permute
i.images.astype
source.info.get
from_file.enable_truncation
self.norm2
random.seed
numpy.random.seed
os.path.expanduser
x.self.query.view
codes.device.T.torch.arange.repeat
layers.Block
device.args.num_candidates.args.softmax_temperature.args.top_p.args.top_k.args.prompt.model.sampling.cpu
self.conv_in
device.H.torch.arange.repeat
self.mlp.transpose
cutoff_topp_probs.masked_fill
self.norm1
k.reshape.reshape
torch.cuda.amp.autocast
x.contiguous.contiguous
loop.update
argparse.ArgumentParser.parse_args
prompt.clip.tokenize.to
self.tok_emb_txt
device.args.num_candidates.args.softmax_temperature.args.top_p.args.top_k.args.prompt.model.sampling.cpu.numpy
Stage2Hparams
os.path.dirname
torch.tril
self.ln1
pytorch_lightning.callbacks.ModelCheckpoint
cnt.code_.unsqueeze
model_clip.encode_text
y.transpose.contiguous.view
ImageNetDataModule.setup
tuple
enumerate
torch.nn.Linear
self.resid_drop.transpose
tokenizer.build_tokenizer
i_block.i_level.self.down.attn
self.register_buffer
self.dropout
torchvision.utils.make_grid
self.mid.attn_1
x.self.value.view
torch.randn
output.write
self.pos_emb_img
self.n_heads.C.self.n_heads.B.T.x.self.key.view.transpose
self.ln2
self.nin_shortcut
self.stage2.eval
self.lr_schedulers.step
self.blocks
os.path.abspath
model.stage2.from_ckpt
torch.multinomial
self.encoder
quant.permute.permute
min_encoding_indices.self.embedding.view
torch.nn.functional.interpolate
labels.self.sos.unsqueeze
print
torchvision.transforms.Normalize
sys.path.append
self.decoder
torch.einsum
self.norm_out
torch.optim.AdamW
images.self.stage1.get_codes.detach.view
MultiHeadSelfAttention
einops.rearrange
urllib.parse.urlparse
stage2.transformer.Transformer1d
self.stage1.get_codes
DataConfig
self.drop
omegaconf.OmegaConf.structured
dalle.models.Dalle.from_pretrained.sampling
preprocess_clip
images.torch.stack.to
tqdm.tqdm.set_description
utils.config.get_base_config
tqdm.tqdm
x.self.key.view
self.n_heads.C.self.n_heads.B.T.x.self.query.view.transpose
torch.cat.clone
self.decode
self.stage2
self.query
i_level.self.up.upsample
urllib.request.urlopen
torch.nn.ModuleList.append
self.conv2
source.info
self.n_heads.C.self.n_heads.B.T.x.self.value.view.transpose
self.lr_schedulers
layers.Encoder
tarfile.open
images.self.stage1.get_codes.detach
model_clip.encode_image
cutoff_topk_logits
utils.sampling.sampling
torch.nn.Sequential
torch.nn.ModuleList
setup_callbacks
self.value
tokens.to.to
self.log
math.sqrt
isinstance
omegaconf.OmegaConf.merge
open
torch.cat
torch.ones
torch.topk
self.proj_out.reshape
torch.argmin
self.q
self.stage1.parameters
os.path.join
os.path.exists
torch.utils.data.DataLoader
self.embedding.weight.data.uniform_
scores.torch.argsort.cpu
torch.nn.Module
cutoff_topk_logits.to
dalle.utils.utils.clip_score
int
cutoff_topk_logits.clone
N.x.contiguous
f.extract
torch.stack
torch.sort
self.attn_drop.masked_fill
torchvision.datasets.ImageNet
torchvision.transforms.CenterCrop
optimizer.step
download_target.open.read
cnt.pos_enc_code_.unsqueeze
args.config_downstream.os.path.basename.split
self
torch.optim.lr_scheduler.CosineAnnealingLR
stage1.vqgan.VQGAN
ValueError
torch.argsort
Stage1Config
range
torch.nn.functional.avg_pool2d
omegaconf.OmegaConf.load
self.sos
x.transpose.contiguous
torch.manual_seed
os.path.isfile
image.astype
present.torch.stack.clone
pl_module.logger.experiment.add_image
os.path.basename
ImageLogger
self.stage1.eval
pytorch_lightning.seed_everything
torch.cat.size
v.reshape.reshape
sos.self.stage2.sos.unsqueeze
torchvision.transforms.Resize
url.split
clip.tokenize
datetime.datetime.now.strftime
device.W.torch.arange.repeat
torch.nn.Conv2d
torch.nn.LayerNorm
dalle.utils.utils.set_seed
cls_idx.torch.LongTensor.to
torch.nn.functional.softmax
i_block.i_level.self.up.attn
ResnetBlock
torch.nn.functional.cross_entropy
probs.torch.multinomial.clone.detach
float
images.texts.torch.cat.contiguous
f.getmembers
z_q.permute.contiguous.view
dalle.models.Dalle.from_pretrained
source.read
VectorQuantizer
pytorch_lightning.Trainer
torch.sigmoid
self.tok_emb_img
i_block.i_level.self.down.block
torch.clamp
self.tokenizer.encode
h.self.quantize.view
self.conv_out
nonlinearity
model_clip.to
self.ln_f
q.permute.reshape
torch.arange
self.load_state_dict
q.permute.permute
self.k
functools.partial
torch.sum
self.stage2.sos.repeat
self.norm
self.mid.block_2
self.head_txt
cls
utils.realpath_url_or_path
torch.load
torch.no_grad
format
past.append
torchvision.transforms.ToTensor
device.N.torch.arange.repeat
presents.append
self.stage1.decode_code
self.quantize
from_file.token_to_id
os.makedirs
self.pos_emb_txt
torch.nn.Embedding
utils.sampling.sampling_igpt
code.clone.detach
dalle.models.ImageGPT.from_pretrained
z_q.permute.contiguous.permute
torchvision.transforms.RandomCrop
self.attn
Upsample
stage2.transformer.iGPT
self.post_quant_conv
torch.cumsum
super.__init__
download_target.open.read.hashlib.md5.hexdigest
self.proj_out
i_level.self.down.downsample
h.sos.torch.cat.contiguous
ImageNetDataModule.train_dataloader
self.stage2.view
self.head_img
self.proj
ImageNetDataModule.valid_dataloader
self.parameters
len
z.rearrange.contiguous
torch.clip
torch.nn.GroupNorm
torch.nn.Parameter
model.sampling
argparse.ArgumentParser
torch.nn.Dropout
sorted_idx_remove_cond.clone
block.sample
torch.LongTensor
self.log_img
from_file.enable_padding
torch.bmm
self.mlp
self.conv_shortcut
y.transpose.contiguous
recons.cpu.cpu
module.bias.data.zero_
GELU
self.up.insert
dataclasses.field
module.weight.data.fill_
clip.load
torch.nn.functional.gelu
i_block.i_level.self.up.block
present.torch.stack.clone.detach
from_file.add_special_tokens
Stage2Config
torch.repeat_interleave
dalle.models.Dalle.from_pretrained.to
layers.Decoder
scores.torch.argsort.cpu.numpy
cutoff_topp_probs
self.mask.torch.tril.view
sos.self.stage2.sos.unsqueeze.repeat
torch.cat.transpose
images.cpu.cpu
self.attn_drop
quant.rearrange.contiguous
z.rearrange.contiguous.view

@developer
Could you please help me check this issue?
May I open a pull request to fix it?
Thank you very much.

Training hyperparameters

First off, great work.

Will information about the training be published anywhere? I'm specifically interested in the number of training epochs used and the LR.

sampling in GPU with 12 GB memory

I found that the sampling code examples/sampling_ex.py fails to save the images if num_candidates is smaller than 16.

This is because the value 16 is hardcoded on line 61:
for i in range(16):

The modification below works for lower num_candidates values:
for i in range(min(16, args.num_candidates)):

Script for VQGAN Finetuning

This is an incredible project! For reproducibility, and for some of my own work, would you mind sharing/pointing me to code for fine-tuning VQGAN models (e.g., vqgan_imagenet_f16_16384) on custom datasets? This would be different from code for training VQGAN from scratch on different datasets.

Additionally, how long does fine-tuning take?

Increasing positional embeddings text

I am finetuning the minDALL-E model on a self-made dataset but my tokenized text prompts are sometimes longer than 64. What would be the best technique to increase the length of the positional encodings to e.g. 128? I was thinking of keeping the original 64 embeddings and appending 64 more, which have to be trained from scratch. However, I think it might mess with the finetuning, since the embeddings are in the very first layer.

Are there better options/techniques to accomplish this?
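The approach you describe (keep the pretrained 64 rows, append 64 new ones) is a common one: copy the pretrained weights into a larger table and give the new rows a small-variance normal init, GPT-style. A minimal sketch, where the hidden size and the idea that the table is a plain nn.Embedding are assumptions about the model internals:

```python
import torch
import torch.nn as nn

# Sketch: grow a learned text positional embedding from 64 to 128 positions.
# d_model is a placeholder; old_pos_emb stands in for the pretrained table.
d_model = 1536
old_pos_emb = nn.Embedding(64, d_model)

new_pos_emb = nn.Embedding(128, d_model)
with torch.no_grad():
    new_pos_emb.weight[:64] = old_pos_emb.weight          # keep pretrained rows
    new_pos_emb.weight[64:].normal_(mean=0.0, std=0.02)   # fresh rows, GPT-style init

# Sanity check: the pretrained positions are preserved.
assert torch.equal(new_pos_emb.weight[:64], old_pos_emb.weight)
```

To limit disruption during fine-tuning, you could freeze (or use a lower learning rate for) the first 64 rows for the first few epochs, letting the new rows catch up first.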

Does zero-shot work in minDALL-E?

Thanks for your amazing work!

I'm attempting zero-shot image-to-image translation, as described in the original paper, by inserting only half of the image. The outcomes are as follows. Will this problem be solved if I increase the size of the model?

[Screenshot, 2021-12-15: sample outputs from half-image conditioning]

Completing images?

Could you add a way to complete images, i.e., generate the rest of an image given a partial one?

How to do inference from a half image

Hi, I want to know whether the code can do inference when we input the text and half of the image, as in iGPT and Taming Transformers. If possible, would you mind pointing to the relevant code for this?
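Conceptually this should work: since stage 2 is an autoregressive transformer over [text tokens; image tokens], you can encode the known half with stage 1, feed those codes as a fixed prefix, and only sample the remaining positions. A sketch of the loop, where sample_next() and the token counts are hypothetical stand-ins and NOT this repo's API:

```python
# Hedged sketch of prefix-conditioned sampling. sample_next() stands in for one
# step of the stage-2 transformer and is not a real function in this repo.
TOTAL_IMAGE_TOKENS = 256          # e.g. a 16x16 code grid (assumption)

def sample_next(prefix):
    """Placeholder: return the next image token given all tokens so far."""
    return 0

known_codes = [0] * 128           # stand-in for stage-1 codes of the known half
tokens = list(known_codes)        # start generation from the fixed prefix
while len(tokens) < TOTAL_IMAGE_TOKENS:
    tokens.append(sample_next(tokens))

print(len(tokens))  # 256
```

The key point is that the known codes are never resampled; the model only fills in the positions after the prefix.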

Comparison against GLIDE

Recently OpenAI released GLIDE, a diffusion model for generating images from text, much like DALL-E.

Would it be possible to compare minDALL-E to GLIDE and put the results on GitHub?

Thank you in advance!

Also I have to say this is amazing!

Amazing work; models CDN?

Hi there! Just want to quickly congratulate all the effort done in this project!

Will the models / tokenizers also be stored as binaries in GitHub Releases? It could be good as a backup / alternative.

text token index slice to N-1

Hi, thanks for sharing the code.

In the forward function of Transformer1d,
the text indices are sliced as 0 … N-2 and the image indices as N-1 … N+T-2.

B, T = images.shape
_, N = texts.shape
...
x = torch.cat([texts, images], axis=1).contiguous()
...
texts = x[:, :N-1].contiguous()
images = x[:, N-1:-1].contiguous()

Could you please clarify why you didn't slice like below? Thanks!

texts = x[:, :N]
images = x[:, N:]
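The shift comes from next-token prediction: the hidden state at position i produces the logits for token i+1. With x = [texts, images], positions 0 … N-2 predict text tokens 1 … N-1, and positions N-1 … N+T-2 predict all T image tokens, hence the N-1 offset. A toy illustration in plain Python:

```python
# Toy illustration: the output at position i predicts token i+1.
texts = ["t0", "t1", "t2"]               # N = 3 text tokens
images = ["i0", "i1", "i2", "i3"]        # T = 4 image tokens
x = texts + images
N = len(texts)

text_positions = x[:N-1]     # outputs here predict t1..t2
image_positions = x[N-1:-1]  # outputs here predict i0..i3 (t2's output predicts i0)

print(text_positions)   # ['t0', 't1']
print(image_positions)  # ['t2', 'i0', 'i1', 'i2']
```

Slicing with x[:, :N] and x[:, N:] would instead pair each position's output with the token at the *same* position, which is not what an autoregressive loss needs.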
