zyinghua / uncond-image-generation-ldm

Unconditional Image Generation using a [modifiable] pretrained VQVAE based Latent Diffusion Model, adapted from huggingface diffusers.


uncond-image-generation-ldm's Introduction

Training an unconditional latent diffusion model

Creating a training image set is described in a different document.

Cloning to local

git clone https://github.com/zyinghua/uncond-image-generation-ldm.git

Then navigate into the repository:

cd uncond-image-generation-ldm

Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

pip install -r requirements.txt

And initialize an 🤗Accelerate environment with:

accelerate config
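
If you prefer to skip the interactive prompts, recent 🤗 Accelerate versions can also write a default single-machine configuration (optional):

accelerate config default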

Change Pretrained VAE settings

You can specify which pretrained VAE model to use by changing the VAE_PRETRAINED_PATH and VAE_KWARGS variables at the top of train.py.
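
For illustration, the top of train.py might then look like the sketch below; the checkpoint name and kwargs are placeholders, not the repository's defaults:

# Assumed example values -- point these at whichever pretrained VAE/VQVAE you want to use.
VAE_PRETRAINED_PATH = "CompVis/ldm-celebahq-256"
VAE_KWARGS = {"subfolder": "vqvae"}
# The script presumably loads the model roughly like:
#   vae = VQModel.from_pretrained(VAE_PRETRAINED_PATH, **VAE_KWARGS)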

Unconditional Flowers

An example command to train a DDPM UNet model on the Oxford Flowers dataset, without multi-GPU training (see the next section for the multi-GPU variant):

accelerate launch train.py \
  --dataset_name="huggan/flowers-102-categories" \
  --resolution=256 \
  --output_dir="ddpm-ema-flowers-256" \
  --train_batch_size=16 \
  --num_epochs=150 \
  --gradient_accumulation_steps=1 \
  --use_ema \
  --learning_rate=1e-4 \
  --lr_warmup_steps=500 \
  --mixed_precision=no

Training with multiple GPUs

accelerate allows for seamless multi-GPU training. After setting up with accelerate config, simply add --multi_gpu to the launch command. For more information, see the 🤗 Accelerate documentation on running distributed training. Here is an example command:

accelerate launch --multi_gpu train.py \
  --dataset_name="huggan/flowers-102-categories" \
  --resolution=256 \
  --output_dir="ddpm-ema-flowers-256" \
  --train_batch_size=16 \
  --num_epochs=150 \
  --gradient_accumulation_steps=1 \
  --use_ema \
  --learning_rate=1e-4 \
  --lr_warmup_steps=500 \
  --mixed_precision=no

To be able to use Weights and Biases (wandb) as a logger you need to install the library: pip install wandb.
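
If the script keeps the same logger switch as the upstream diffusers unconditional-training example it was adapted from, enabling wandb might look like this (an assumption; check the arguments defined in train.py):

accelerate launch train.py \
  --dataset_name="huggan/flowers-102-categories" \
  --logger="wandb" \
  <other-arguments>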

Using your own data

To use your own dataset, there are 3 ways:

  • you can provide your own folder of images via --train_data_dir
  • or you can provide a .zip file containing the data via --train_data_files
  • or you can upload your dataset to the Hub (possibly as a private repo, if you prefer) and simply pass the --dataset_name argument.

Below, the first two options are explained in more detail.

Provide the dataset as a folder/zip file

If you provide your own folders with images, the script expects the following directory structure:

data_dir/xxx.png
data_dir/xxy.png
data_dir/[...]/xxz.png

In other words, the script will take care of gathering all images inside the folder. You can then run the script like this:

accelerate launch train.py \
    --train_data_dir <path-to-train-directory> \
    <other-arguments>

Or (if it is a zip file):

accelerate launch train.py \
    --train_data_files <path-to-train-zip-file> \
    <other-arguments>

Internally, the script will use the ImageFolder feature which will automatically turn the folders into 🤗 Dataset objects.
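
Under the hood this corresponds roughly to the following 🤗 Datasets call (a sketch; the script's exact arguments may differ):

from datasets import load_dataset

# Recursively gathers all images under the folder into a Dataset with an "image" column
dataset = load_dataset("imagefolder", data_dir="<path-to-train-directory>", split="train")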

The official diffusers repository also provides a pipeline for unconditional latent diffusion (LDMPipeline).
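
For reference, sampling from such a pipeline typically looks like the sketch below (shown with the public CompVis/ldm-celebahq-256 checkpoint; loading a locally trained output directory requires that directory to contain a complete pipeline):

from diffusers import DiffusionPipeline

# Load an unconditional latent diffusion pipeline and draw a single sample
pipe = DiffusionPipeline.from_pretrained("CompVis/ldm-celebahq-256")
image = pipe(num_inference_steps=200).images[0]
image.save("sample.png")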


uncond-image-generation-ldm's Issues

Captions data

Can you please explain how you are loading the captions data with the images?

May I ask for the result?

Hi, I appreciate that you released the code here.
I wonder if you have tested the code in the training process?

How to run inference, and an inference error.

I trained a model by following the instructions provided, but when I try to run the inference code I get an error saying that there are keys missing:


ValueError: Cannot load <class 'diffusers.models.unets.unet_2d.UNet2DModel'> from /home/david/testing/uncond-image-generation-ldm/output/vae because the following keys are missing:
down_blocks.2.resnets.0.norm1.bias, up_blocks.2.resnets.2.time_emb_proj.bias, up_blocks.2.resnets.2.norm1.bias, down_blocks.1.resnets.1.norm2.weight, mid_block.resnets.0.time_emb_proj.weight, down_blocks.2.resnets.0.norm1.weight, up_blocks.1.resnets.2.norm2.weight, up_blocks.1.resnets.1.conv1.bias, down_blocks.2.resnets.1.conv1.weight, down_blocks.1.resnets.1.conv1.weight, mid_block.attentions.0.to_q.weight, down_blocks.0.resnets.0.norm2.bias, up_blocks.2.resnets.0.norm2.bias, up_blocks.0.resnets.1.norm2.bias, up_blocks.0.resnets.2.conv1.weight, up_blocks.0.resnets.0.norm1.bias, conv_out.weight, up_blocks.0.resnets.0.time_emb_proj.bias, down_blocks.0.downsamplers.0.conv.weight, mid_block.attentions.0.to_k.weight, mid_block.resnets.1.norm1.bias, conv_out.bias, up_blocks.1.resnets.2.conv1.bias, up_blocks.0.resnets.0.time_emb_proj.weight, down_blocks.2.resnets.1.norm2.weight, up_blocks.1.upsamplers.0.conv.bias, up_blocks.2.resnets.1.time_emb_proj.weight, mid_block.attentions.0.to_v.weight, up_blocks.2.resnets.1.conv2.weight, down_blocks.1.resnets.1.conv2.weight, up_blocks.1.resnets.2.conv2.bias, mid_block.attentions.0.to_v.bias, down_blocks.1.resnets.0.conv_shortcut.bias, up_blocks.0.resnets.1.norm2.weight, up_blocks.0.resnets.2.norm1.weight, down_blocks.1.downsamplers.0.conv.weight, up_blocks.1.resnets.0.norm2.weight, up_blocks.1.resnets.0.conv2.weight, mid_block.resnets.0.norm1.weight, up_blocks.0.resnets.1.conv2.bias, down_blocks.1.resnets.0.norm1.weight, down_blocks.0.resnets.0.norm2.weight, down_blocks.1.resnets.0.norm2.weight, mid_block.resnets.0.conv2.bias, up_blocks.0.resnets.0.norm2.weight, up_blocks.1.resnets.2.norm2.bias, mid_block.resnets.0.conv2.weight, down_blocks.0.resnets.1.norm2.bias, down_blocks.2.resnets.0.conv1.bias, up_blocks.0.resnets.0.conv2.bias, up_blocks.2.resnets.1.norm2.weight, down_blocks.0.downsamplers.0.conv.bias, up_blocks.2.resnets.1.norm1.bias, time_embedding.linear_2.bias, up_blocks.1.resnets.1.conv2.weight, up_blocks.1.resnets.2.norm1.weight, down_blocks.0.resnets.0.conv2.bias, up_blocks.2.resnets.2.norm1.weight, conv_norm_out.weight, up_blocks.2.resnets.2.norm2.weight, down_blocks.0.resnets.1.conv2.weight, mid_block.attentions.0.group_norm.weight, up_blocks.0.resnets.1.conv1.bias, up_blocks.2.resnets.0.time_emb_proj.bias, up_blocks.1.upsamplers.0.conv.weight, mid_block.resnets.0.norm2.bias, up_blocks.1.resnets.0.conv_shortcut.bias, up_blocks.0.resnets.2.norm2.bias, up_blocks.0.resnets.2.conv1.bias, up_blocks.1.resnets.1.norm2.bias, mid_block.resnets.1.norm2.weight, up_blocks.1.resnets.0.norm1.weight, up_blocks.1.resnets.1.time_emb_proj.weight, down_blocks.1.resnets.0.conv1.bias, down_blocks.2.resnets.0.conv2.weight, up_blocks.0.resnets.2.time_emb_proj.weight, up_blocks.2.resnets.1.norm1.weight, up_blocks.1.resnets.1.conv1.weight, down_blocks.2.resnets.0.conv2.bias, up_blocks.0.resnets.0.conv_shortcut.weight, up_blocks.0.resnets.2.norm1.bias, up_blocks.0.resnets.2.conv2.bias, up_blocks.2.resnets.0.norm1.bias, down_blocks.0.resnets.0.conv1.weight, conv_in.weight, down_blocks.0.resnets.0.norm1.bias, mid_block.resnets.1.norm1.weight, down_blocks.1.resnets.1.conv1.bias, down_blocks.2.resnets.1.norm2.bias, up_blocks.1.resnets.1.norm1.weight, mid_block.resnets.0.norm2.weight, time_embedding.linear_2.weight, up_blocks.2.resnets.0.conv1.bias, down_blocks.2.resnets.1.conv1.bias, mid_block.attentions.0.to_out.0.weight, up_blocks.2.resnets.0.norm2.weight, up_blocks.2.resnets.1.conv1.bias, down_blocks.0.resnets.1.norm1.weight, 
up_blocks.2.resnets.2.time_emb_proj.weight, time_embedding.linear_1.weight, up_blocks.2.resnets.1.conv1.weight, down_blocks.1.downsamplers.0.conv.bias, up_blocks.0.resnets.1.time_emb_proj.weight, down_blocks.0.resnets.0.norm1.weight, down_blocks.1.resnets.0.norm1.bias, mid_block.resnets.0.conv1.weight, up_blocks.2.resnets.1.conv2.bias, down_blocks.2.resnets.1.conv2.bias, up_blocks.0.resnets.0.conv_shortcut.bias, up_blocks.1.resnets.1.time_emb_proj.bias, up_blocks.1.resnets.2.time_emb_proj.weight, mid_block.resnets.1.conv2.weight, up_blocks.0.upsamplers.0.conv.weight, down_blocks.0.resnets.1.conv1.bias, down_blocks.1.resnets.1.norm1.bias, up_blocks.0.resnets.0.conv1.bias, mid_block.resnets.1.time_emb_proj.weight, down_blocks.1.resnets.0.conv2.bias, up_blocks.0.resnets.1.norm1.weight, up_blocks.2.resnets.0.time_emb_proj.weight, up_blocks.2.resnets.2.conv2.weight, mid_block.resnets.1.conv1.bias, time_embedding.linear_1.bias, up_blocks.1.resnets.0.conv_shortcut.weight, mid_block.attentions.0.to_q.bias, up_blocks.1.resnets.0.conv1.bias, down_blocks.1.resnets.1.conv2.bias, down_blocks.1.resnets.0.conv2.weight, down_blocks.0.resnets.1.norm2.weight, up_blocks.1.resnets.2.time_emb_proj.bias, up_blocks.0.resnets.2.time_emb_proj.bias, up_blocks.1.resnets.1.conv2.bias, up_blocks.0.resnets.0.conv1.weight, mid_block.attentions.0.to_k.bias, down_blocks.0.resnets.0.conv1.bias, up_blocks.2.resnets.0.conv1.weight, up_blocks.0.resnets.0.norm2.bias, mid_block.resnets.1.conv1.weight, mid_block.resnets.1.time_emb_proj.bias, down_blocks.0.resnets.1.norm1.bias, mid_block.resnets.1.conv2.bias, up_blocks.0.upsamplers.0.conv.bias, up_blocks.2.resnets.2.conv2.bias, up_blocks.0.resnets.1.norm1.bias, down_blocks.1.resnets.1.norm1.weight, down_blocks.2.resnets.0.conv_shortcut.bias, down_blocks.1.resnets.0.conv_shortcut.weight, down_blocks.2.resnets.0.norm2.bias, up_blocks.0.resnets.1.conv1.weight, down_blocks.0.resnets.1.conv1.weight, up_blocks.1.resnets.2.conv1.weight, up_blocks.0.resnets.2.norm2.weight, up_blocks.1.resnets.2.norm1.bias, mid_block.attentions.0.group_norm.bias, up_blocks.2.resnets.2.norm2.bias, down_blocks.2.resnets.1.norm1.weight, mid_block.resnets.0.time_emb_proj.bias, mid_block.attentions.0.to_out.0.bias, up_blocks.0.resnets.0.norm1.weight, up_blocks.0.resnets.1.time_emb_proj.bias, up_blocks.0.resnets.0.conv2.weight, down_blocks.1.resnets.0.norm2.bias, conv_norm_out.bias, up_blocks.1.resnets.0.norm2.bias, down_blocks.1.resnets.1.norm2.bias, conv_in.bias, up_blocks.0.resnets.1.conv2.weight, mid_block.resnets.0.conv1.bias, down_blocks.2.resnets.0.conv1.weight, up_blocks.2.resnets.1.norm2.bias, down_blocks.2.resnets.0.norm2.weight, down_blocks.1.resnets.0.conv1.weight, down_blocks.2.resnets.1.conv2.weight, up_blocks.2.resnets.0.conv2.weight, up_blocks.1.resnets.0.conv2.bias, down_blocks.2.resnets.1.norm1.bias, up_blocks.1.resnets.0.conv1.weight, up_blocks.0.resnets.2.conv2.weight, up_blocks.1.resnets.0.time_emb_proj.weight, up_blocks.2.resnets.2.conv1.bias, up_blocks.1.resnets.1.norm1.bias, down_blocks.2.resnets.0.conv_shortcut.weight, up_blocks.2.resnets.0.conv2.bias, up_blocks.2.resnets.0.norm1.weight, up_blocks.1.resnets.0.norm1.bias, up_blocks.1.resnets.1.norm2.weight, up_blocks.2.resnets.1.time_emb_proj.bias, down_blocks.0.resnets.0.conv2.weight, down_blocks.0.resnets.1.conv2.bias, mid_block.resnets.1.norm2.bias, mid_block.resnets.0.norm1.bias, up_blocks.1.resnets.0.time_emb_proj.bias, up_blocks.1.resnets.2.conv2.weight, up_blocks.2.resnets.2.conv1.weight.
Please make sure to pass low_cpu_mem_usage=False and device_map=None if you want to randomly initialize those weights or else make sure your checkpoint file is correct.
