spo's Introduction

Step-aware Preference Optimization: Aligning Preference with Denoising Performance at Each Step

Zhanhao Liang, Yuhui Yuan, Shuyang Gu, Bohan Chen, Tiankai Hang, Ji Li, Liang Zheng

This is the official implementation of SPO, introduced in Step-aware Preference Optimization: Aligning Preference with Denoising Performance at Each Step.

News

2024.07.10 Release the training code of SPO.

2024.06.20 Release the SD v1.5 checkpoint and inference code.

2024.06.07 Release the SDXL checkpoint and inference code.

Abstract

Recently, Direct Preference Optimization (DPO) has extended its success from aligning large language models (LLMs) to aligning text-to-image diffusion models with human preferences. Most existing DPO methods assume that all diffusion steps share a consistent preference order with the final generated images; we argue that this assumption neglects step-specific denoising performance, and that preference labels should instead be tailored to each step's contribution.

To address this limitation, we propose Step-aware Preference Optimization (SPO), a novel post-training approach that independently evaluates and adjusts the denoising performance at each step, using a step-aware preference model and a step-wise resampler to ensure accurate step-aware supervision. Specifically, at each denoising step, we sample a pool of images, find a suitable win-lose pair, and, most importantly, randomly select a single image from the pool to initialize the next denoising step. This step-wise resampler process ensures the next win-lose image pair comes from the same image, making the win-lose comparison independent of the previous step. To assess the preferences at each step, we train a separate step-aware preference model that can be applied to both noisy and clean images.

Our experiments with Stable Diffusion v1.5 and SDXL demonstrate that SPO significantly outperforms the latest Diffusion-DPO in aligning generated images with complex, detailed prompts and enhancing aesthetics, while also achieving more than 20× faster training. Code and model: https://rockeycoss.github.io/spo.github.io/

Method Overview

(Figure: SPO method overview)
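
To make the step-wise resampler concrete, here is a minimal, self-contained sketch of the sampling loop described in the abstract. Everything in it is illustrative: denoise_step, preference_score, and the tensor shapes are stand-ins for the repo's actual diffusion model and step-aware preference model, not its API.

import torch

def denoise_step(x_t, num_candidates):
    """Stand-in for one stochastic denoising step: draws a pool of
    candidate x_{t-1} samples that all start from the same x_t."""
    return x_t.unsqueeze(0) + 0.1 * torch.randn(num_candidates, *x_t.shape)

def preference_score(pool, t):
    """Stand-in for the step-aware preference model, which can score
    noisy intermediate images."""
    return pool.flatten(1).mean(dim=1)

def spo_rollout(x_T, num_steps=20, pool_size=4):
    """Collect one win-lose pair per denoising step via step-wise resampling."""
    x_t = x_T
    pairs = []
    for t in reversed(range(num_steps)):
        pool = denoise_step(x_t, pool_size)              # candidates from the SAME x_t
        scores = preference_score(pool, t)
        win, lose = pool[scores.argmax()], pool[scores.argmin()]
        pairs.append((t, win, lose))                     # step-aware preference pair
        x_t = pool[torch.randint(pool_size, (1,)).item()]  # random resample decouples steps
    return pairs

pairs = spo_rollout(torch.randn(4, 64, 64))
print(f"collected {len(pairs)} step-aware win-lose pairs")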

TODO

  • Release training code
  • Release checkpoints and inference code
  • Initialization

Gallery

(Gallery: teaser examples 0–19)

🔧 Installation

  1. Pull the Docker image:

sudo docker pull rockeycoss/spo:v1

  2. Log in to wandb:

wandb login {Your wandb key}

  3. (Optional) To customize where models downloaded from Hugging Face are cached, set:

export HUGGING_FACE_CACHE_DIR=/path/to/your/cache/dir
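
For reference, a minimal sketch of how a script could honor that variable (an assumption about this repo's internals; cache_dir itself is a standard from_pretrained argument):

import os
from diffusers import StableDiffusionPipeline

# Assumption: the repo's scripts read HUGGING_FACE_CACHE_DIR themselves.
# cache_dir=None falls back to the default Hugging Face cache location.
cache_dir = os.environ.get("HUGGING_FACE_CACHE_DIR")
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", cache_dir=cache_dir
)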

🔧 Inference with Hugging Face Checkpoints

SDXL inference

PYTHONPATH=$(pwd) python inference_scripts/inference_spo_sdxl.py

SD v1.5 inference

PYTHONPATH=$(pwd) python inference_scripts/inference_spo_sd-v1-5.py
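
Both scripts wrap standard diffusers pipelines. As a hedged sketch of what the SDXL path amounts to (the Hugging Face repo id below is inferred from the checkpoint names listed under Available Checkpoints and may differ):

import torch
from diffusers import StableDiffusionXLPipeline

# Assumed repo id, inferred from the checkpoint names below; adjust if needed.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "SPO-Diffusion-Models/SPO-SDXL_4k-prompts_10-epochs",
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(prompt="a starry night over a quiet lake").images[0]
image.save("spo_sdxl_sample.png")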

🔧 Training

The following scripts assume the use of four 80GB A100 GPUs for fine-tuning, as described in the paper.

Before fine-tuning, please download the step-aware preference model checkpoints. You can do this as follows:

sudo apt update
sudo apt install wget

mkdir model_ckpts
cd model_ckpts

wget https://huggingface.co/SPO-Diffusion-Models/Step-Aware_Preference_Models/resolve/main/sd-v1-5_step-aware_preference_model.bin

wget https://huggingface.co/SPO-Diffusion-Models/Step-Aware_Preference_Models/resolve/main/sdxl_step-aware_preference_model.bin

cd ..

To fine-tune SD v1.5, you can use the following command:

PYTHONPATH=$(pwd) accelerate launch --config_file accelerate_cfg/1m4g_fp16.yaml train_scripts/train_spo.py --config configs/spo_sd-v1-5_4k-prompts_num-sam-4_10ep_bs10.py

To fine-tune SDXL, you can use the following command:

PYTHONPATH=$(pwd) accelerate launch --config_file accelerate_cfg/1m4g_fp16.yaml train_scripts/train_spo_sdxl.py --config configs/spo_sdxl_4k-prompts_num-sam-2_3-is_10ep_bs2_gradacc2.py
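
A note on the config filenames (an interpretation, not spelled out in this README): 4k-prompts appears to denote the prompt-set size, num-sam-N the per-step sample-pool size, 10ep the number of epochs, and bsB the per-GPU batch size. Under that reading, the effective batch size is 4 GPUs × 10 = 40 for the SD v1.5 run and 4 GPUs × 2 × 2 gradient-accumulation steps = 16 for the SDXL run.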

🔓 Available Checkpoints

SPO-SDXL_4k-prompts_10-epochs

SPO-SDXL_4k-prompts_10-epochs_LoRA

SPO-SD-v1-5_4k-prompts_10-epochs

SPO-SD-v1-5_4k-prompts_10-epochs_LoRA
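
The LoRA variants can presumably be applied on top of the corresponding base models with the standard diffusers LoRA loader; a hedged sketch (the repo id is assumed from the checkpoint name above):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# load_lora_weights is the standard diffusers API; the repo id is assumed.
pipe.load_lora_weights("SPO-Diffusion-Models/SPO-SDXL_4k-prompts_10-epochs_LoRA")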

Acknowledgement

Our codebase references the code from Diffusers, D3PO and PickScore. We extend our gratitude to their authors for open-sourcing their code.

📬 Citation

If you find this code useful in your research, please consider citing:

@article{liang2024step,
  title={Step-aware Preference Optimization: Aligning Preference with Denoising Performance at Each Step},
  author={Liang, Zhanhao and Yuan, Yuhui and Gu, Shuyang and Chen, Bohan and Hang, Tiankai and Li, Ji and Zheng, Liang},
  journal={arXiv preprint arXiv:2406.04314},
  year={2024}
}

spo's People

Contributors

rockeycoss

spo's Issues

About the training speed

Hello. I am trying to reproduce the SD v1.5 fine-tuning experiments, but the training speed is very low: each online-sampling step takes close to 10 minutes. At that rate, the full run needs (4000/40 steps) × 10 min × 10 epochs / 60 ≈ 167 h for online sampling alone, which is far from the 12 h training time stated in the paper.

Therefore, I would like to ask what might be causing the abnormal training speed. I use four A800-80GB GPUs with a per-GPU batch size of 10, and I do not change any other hyperparameters.

Looking forward to your reply.

How to understand the training loss?

Hi, great work!

I am now trying to reproduce your results. The training loss I get looks a bit strange: it fluctuates around 0.693.

(W&B loss charts, 7/17/2024)

This issue has also been mentioned in #6 (comment). Could you explain why the loss appears like this?

FYI, my GPU does not have enough memory, so I set the training batch size to 4.
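
For context on the 0.693 value (a general property of DPO-style objectives, not a claim about this particular run): the loss −log σ(β·Δ) equals ln 2 when the implicit reward margin Δ is zero, i.e. when the policy has not yet moved away from the reference.

import math

# DPO-style loss at zero margin: -log(sigmoid(0)) = ln 2
print(-math.log(1 / (1 + math.exp(-0.0))))  # 0.6931...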

Out of memory using default config

Hi, thanks for your great work. When I try to train SPO-SDXL with the default config file spo_sdxl_4k-prompts_num-sam-2_3-is_10ep_bs2_gradacc2.py on four 80GB A800 GPUs, I get a CUDA out-of-memory error. Can you help me confirm your training configuration?

LoRA or full-parameter training?

Hi, great work

I'm curious whether you use only LoRA for training.

Is SPO-SDXL_4k-prompts_10-epochs a merge of SDXL-base with SPO-SDXL_4k-prompts_10-epochs_LoRA, rather than a checkpoint from full-parameter training?

Have you tried full-parameter training? Does it introduce problems such as training instability?

Thank you very much

Questions on validation

Dear authors,

I am currently working on reproducing the quantitative results from your table for the Stable Diffusion v1-5 models and would appreciate some guidance.

Test Setup:

  • Prompts: the 500 unique validation prompts from the PickScore set.
  • Parameters: CFG = 7.5, DPMSolver++, 25 steps.

Models Tested:

  • Diffusion-DPO
  • SPO-LoRA

Results Obtained:

Metric        Diffusion-DPO   SPO-LoRA
AES           5.56            5.68
HPS           27.07           27.01
PickScore     20.95           21.04
ImageReward   0.2297          0.165

  • The SPO-LoRA model's generation results exhibit some unusual artifacts. (Artifact example image.)

Could you please provide any recommendations for improving the quantitative evaluation of the SPO-LoRA model? Any advice would be incredibly helpful.

Thank you for your pioneering work and for considering my request.

Best regards,

Training step-aware scorer

Dear authors,
Thank you for your great work!
I wonder if you have any plans to release the training code for the step-aware scorer.
Hope to hear from you soon!

About training loss

I trained the model following your training code and found that the recorded loss curve is very strange: train_loss fluctuates violently around 0.693. Additionally, the performance of the trained model is very poor. Could you share your training records? Is there anything specific I need to pay attention to during training?

(Charts: train_loss, train_ratio_win, train_ratio_lose)

About the training of step-aware preference model

As stated in the paper, when the preference model is fine-tuned, you use the x_0 estimated from the noisy sample according to DDIM. However, estimating x_0 requires the noise predicted from the noisy sample by a diffusion model. Could you let me know which diffusion model you use to estimate x_0 when training the preference model: SD v1.5 or SDXL?
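
For reference, the standard DDIM estimate in question (a textbook formula, not specific to this repo) is $\hat{x}_0 = \left(x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon_\theta(x_t, t, c)\right) / \sqrt{\bar{\alpha}_t}$, so the choice of $\epsilon_\theta$ determines which backbone supplies the estimate.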

About the training ratio win

Hi, thanks for sharing your code. From my understanding, the training ratio win is $p_\theta (x_{t-1}^w|c, t, x_t) / p_{\mathrm{ref}} (x_{t-1}^w|c, t, x_t)$, which should increase during training. When I tried to train SPO-SDXL on the subsampled Pick-a-Pic v2, I noticed that the training ratio win stays almost exactly 1 throughout training. Is that normal?

Launching inference_spo_sdxl.py does not finish. "1Torch was not compiled with flash attention."

(venv) C:\AI\SPO>py inference_spo_sdxl.py
C:\AI\SPO\venv\Lib\site-packages\diffusers\utils\outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
torch.utils._pytree._register_pytree_node(
C:\AI\SPO\venv\Lib\site-packages\huggingface_hub\file_download.py:1132: FutureWarning: resume_download is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use force_download=True.
warnings.warn(
Loading pipeline components...: 100%|████████████████| 7/7 [00:00<00:00, 11.60it/s]
0%| | 0/50 [00:00<?, ?it/s]C:\AI\SPO\venv\Lib\site-packages\diffusers\models\attention_processor.py:1244: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at ..\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:455.)
hidden_states = F.scaled_dot_product_attention(

I launched it previously and it downloaded a ton of models. Now it seems stuck here forever, with GPU usage spiking.

I'm on a 3070 with 8 GB VRAM; maybe this is why?
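
An 8 GB card is indeed tight for SDXL at fp16. As a hedged suggestion (assuming the script builds a standard diffusers pipeline; these calls are stock diffusers memory savers, not part of this repo):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # stream submodules to the GPU on demand
pipe.enable_attention_slicing()  # lower peak attention memory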

Questions on Baseline

Thank you for your excellent work. I have a couple of questions:

  1. Will you be open-sourcing the code for the fine-tuned D3PO model?
  2. I have been following your experimental setup, but I consistently achieve a PickScore of around 17 with D3PO, below the reported score of 20 in the paper. When I continue training, the model generates images with higher PickScores, but they tend to be blurry.

Here are the adjustments I made:

  1. Changed the reward to PickScore
  2. Set beta to 10
  3. Modified the learning rate

Is there any aspect I might be overlooking? Any guidance would be greatly appreciated.

Thank you!

Model output collapses after training for one epoch

Hello, Author!

Due to limited computational resources, I trained SD v1.5 on a single GPU, changing only train.train_batch_size to 1 and keeping all other hyperparameters unchanged.

I found that the test results after the first epoch were normal, but the model's outputs collapsed starting from the second epoch. Below are the test images from the first and second epochs.

Have you encountered this issue before? Is this a normal phenomenon?
(Test images after epoch 1 and epoch 2: lakespo0, lakespo1)

About 1.5

Interesting, thanks for the XL model. Will you release a 1.5 model?

"Our experiments with Stable Diffusion v1.5 and SDXL demonstrate that SPO significantly outperforms the latest Diffusion-DPO in aligning generated images with complex, detailed prompts and enhancing aesthetics, while also achieving more than 20× faster training."

Question about the validation result

Dear authors,

I am currently working on reproducing the quantitative results from your table for the Stable Diffusion v1-5 models and would appreciate some guidance.

I reproduced your results on the validation set of pickapic_v1_no_images and was surprised to find that ImageReward comes out at 0.3 instead of 0.1712. I am not sure why there is such a huge gap. Is there an explanation?

This is the ImageReward model I used:

import ImageReward as RM

model = RM.load("ImageReward-v1.0")
# images are then scored with model.score(prompt, image_paths)
