theadamcolton / ddpo-pytorch-xl

This project forked from kvablack/ddpo-pytorch

DDPO for finetuning diffusion models, implemented in PyTorch with LoRA and SDXL support

License: MIT License

Python 100.00%

ddpo-pytorch-xl's Introduction

ddpo-pytorch

This is an implementation of Denoising Diffusion Policy Optimization (DDPO) in PyTorch with support for low-rank adaptation (LoRA). Unlike our original research code (which you can find here), this implementation runs on GPUs, and if LoRA is enabled, requires less than 10GB of GPU memory to finetune Stable Diffusion!

DDPO

Installation

Requires Python 3.10 or newer.

git clone git@github.com:kvablack/ddpo-pytorch.git
cd ddpo-pytorch
pip install -e .

Usage

accelerate launch scripts/train.py

This will immediately start finetuning Stable Diffusion v1.5 for compressibility on all available GPUs using the config from config/base.py. It should work as long as each GPU has at least 10GB of memory. If you don't want to log into wandb, you can run wandb disabled before the above command.

Please note that the default hyperparameters in config/base.py are not meant to achieve good performance; they are only meant to get the code up and running as fast as possible. I would not expect good results without a much larger number of samples per epoch and more gradient accumulation steps.

Important Hyperparameters

A detailed explanation of all the hyperparameters can be found in config/base.py. Here are a few of the important ones.

prompt_fn and reward_fn

At a high level, the problem of finetuning a diffusion model is defined by two things: a set of prompts to generate images from, and a reward function to evaluate those images. The prompts are defined by a prompt_fn, which takes no arguments and generates a random prompt each time it is called. The reward function is defined by a reward_fn, which takes in a batch of images and returns a batch of rewards for those images. All of the prompt and reward functions currently implemented can be found in ddpo_pytorch/prompts.py and ddpo_pytorch/rewards.py, respectively.
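As a rough illustration of these two interfaces, here is a minimal sketch that follows the description above. The names simple_animals_prompt, zlib_compressibility_reward, and ANIMALS are hypothetical and do not come from the repository. The images are assumed to be raw pixel bytes, and zlib stands in for a real image codec. The functions in ddpo_pytorch/prompts.py and ddpo_pytorch/rewards.py may take or return additional values, so check them before copying this pattern.

```python
import random
import zlib

# Hypothetical prompt vocabulary, for illustration only.
ANIMALS = ["cat", "dog", "horse", "rabbit"]


def simple_animals_prompt():
    """Take no arguments and return one randomly chosen prompt."""
    return f"a photo of a {random.choice(ANIMALS)}"


def zlib_compressibility_reward(images):
    """Take a batch of images and return one reward per image.

    Assumption: each image is a bytes object of raw pixel values.
    A smaller compressed size (in kB) gives a higher reward.
    """
    return [-len(zlib.compress(img)) / 1000.0 for img in images]


if __name__ == "__main__":
    print(simple_animals_prompt())
    flat_image = bytes(4096)  # all-zero pixels, compresses very well
    noisy_image = bytes(random.getrandbits(8) for _ in range(4096))
    print(zlib_compressibility_reward([flat_image, noisy_image]))
```

The flat image should receive the higher (less negative) reward, since it compresses to far fewer bytes than the noisy one.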

Batch Sizes and Accumulation Steps

Each DDPO epoch consists of generating a batch of images, computing their rewards, and then doing some training steps on those images. One important hyperparameter is the number of images generated per epoch; you want enough images to get a good estimate of the average reward and the policy gradient. Another important hyperparameter is the number of training steps per epoch.

However, these are not defined explicitly but are instead defined implicitly by several other hyperparameters. First note that all batch sizes are per GPU. Therefore, the total number of images generated per epoch is sample.batch_size * num_gpus * sample.num_batches_per_epoch. The effective total training batch size (if you include multi-GPU training and gradient accumulation) is train.batch_size * num_gpus * train.gradient_accumulation_steps. The number of training steps per epoch is the first number divided by the second number, or (sample.batch_size * sample.num_batches_per_epoch) / (train.batch_size * train.gradient_accumulation_steps).

(This assumes that train.num_inner_epochs == 1. If this is set to a higher number, then training will loop over the same batch of images multiple times before generating a new batch of images, and the number of training steps per epoch will be multiplied accordingly.)
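The arithmetic above can be sketched as a small helper. The function name ddpo_epoch_stats and the example values are hypothetical; they are not taken from config/base.py or config/dgx.py. The divisibility check is an assumption of this sketch so that the step count is a whole number; the training script may validate its configuration differently.

```python
def ddpo_epoch_stats(
    sample_batch_size,
    sample_num_batches_per_epoch,
    train_batch_size,
    train_gradient_accumulation_steps,
    num_gpus,
    num_inner_epochs=1,
):
    """Return (images per epoch, effective train batch size, training steps per epoch)."""
    # All batch sizes are per GPU, so multiply by the number of GPUs.
    images_per_epoch = sample_batch_size * num_gpus * sample_num_batches_per_epoch
    effective_train_batch = (
        train_batch_size * num_gpus * train_gradient_accumulation_steps
    )
    # Assumption of this sketch: require a whole number of training steps.
    if images_per_epoch % effective_train_batch != 0:
        raise ValueError(
            "images per epoch must be divisible by the effective train batch size"
        )
    # Each inner epoch loops over the same images again.
    steps_per_epoch = (images_per_epoch // effective_train_batch) * num_inner_epochs
    return images_per_epoch, effective_train_batch, steps_per_epoch


if __name__ == "__main__":
    # Hypothetical values: 8 GPUs, 8 images per sampling batch, 4 sampling batches,
    # training batch size 4 with 2 gradient accumulation steps.
    print(ddpo_epoch_stats(8, 4, 4, 2, num_gpus=8))  # (256, 64, 4)
    print(ddpo_epoch_stats(8, 4, 4, 2, num_gpus=8, num_inner_epochs=2))  # (256, 64, 8)
```

Note that num_gpus cancels out of the step count: (8 * 4) / (4 * 2) = 4 regardless of how many GPUs are used.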

At the beginning of each training run, the script will print out the calculated value for the number of images generated per epoch, the effective total training batch size, and the number of training steps per epoch. Make sure to double-check these numbers!

Reproducing Results

The image at the top of this README was generated using LoRA! However, I did use a fairly powerful DGX machine with 8xA100 GPUs, on which each experiment took about 4 hours for 100 epochs. In order to run the same experiments with a single small GPU, you would set sample.batch_size = train.batch_size = 1 and multiply sample.num_batches_per_epoch and train.gradient_accumulation_steps accordingly.
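To make "multiply accordingly" concrete, here is a minimal sketch of the conversion. The helper to_single_gpu, the dictionary keys, and the example numbers are hypothetical; the real values live in config/dgx.py and are not reproduced here. The idea is to keep the images generated per epoch and the effective training batch size unchanged while dropping both per-GPU batch sizes to 1.

```python
def to_single_gpu(cfg, num_gpus):
    """Convert multi-GPU batch settings to an equivalent single-GPU, batch-size-1 setup.

    cfg is a plain dict with hypothetical keys mirroring the hyperparameter names.
    """
    return {
        "sample.batch_size": 1,
        # Fold the old per-GPU batch size and GPU count into the number of batches.
        "sample.num_batches_per_epoch": (
            cfg["sample.batch_size"] * num_gpus * cfg["sample.num_batches_per_epoch"]
        ),
        "train.batch_size": 1,
        # Fold the old per-GPU batch size and GPU count into accumulation steps.
        "train.gradient_accumulation_steps": (
            cfg["train.batch_size"]
            * num_gpus
            * cfg["train.gradient_accumulation_steps"]
        ),
    }


if __name__ == "__main__":
    # Hypothetical 8-GPU settings, not the values from config/dgx.py.
    multi_gpu = {
        "sample.batch_size": 8,
        "sample.num_batches_per_epoch": 4,
        "train.batch_size": 4,
        "train.gradient_accumulation_steps": 2,
    }
    print(to_single_gpu(multi_gpu, num_gpus=8))
```

With these example numbers, the single-GPU run would still generate 256 images per epoch and use an effective training batch size of 64, at the cost of a much longer wall-clock time per epoch.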

You can find the exact configs I used for the 4 experiments in config/dgx.py. For example, to run the aesthetic quality experiment:

accelerate launch scripts/train.py --config config/dgx.py:aesthetic

If you want to run the LLaVA prompt-image alignment experiments, you need to dedicate a few GPUs to running LLaVA inference using this repo.

Reward Curves

ddpo-pytorch-xl's People

Contributors

desaixie, kvablack, theadamcolton
