
Comments (10)

sayakpaul commented on July 20, 2024

Can you check if the outputs of the text encoders vary when loaded using the method you described?

That will be an easier way to reproduce the problem.
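
For example, something along these lines could be used to compare the two loading paths (the repo id and subfolder names below are assumptions for illustration):

import torch
from transformers import AutoTokenizer, T5EncoderModel

repo = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"  # assumed repo id, adjust as needed

tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="tokenizer")
input_ids = tokenizer("A small cactus with a happy face", return_tensors="pt").input_ids.to("cuda")

# variant 1: load in float32, then cast everything down with .to()
encoder_cast = T5EncoderModel.from_pretrained(repo, subfolder="text_encoder").to("cuda", dtype=torch.float16)
# variant 2: pass torch_dtype to from_pretrained (keeps some layers in float32)
encoder_loaded = T5EncoderModel.from_pretrained(repo, subfolder="text_encoder", torch_dtype=torch.float16).to("cuda")

with torch.no_grad():
    out_cast = encoder_cast(input_ids).last_hidden_state
    out_loaded = encoder_loaded(input_ids).last_hidden_state

# any large difference here points at the dtype handling of the two loading paths
print((out_cast - out_loaded).abs().max())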

Cc: @lawrence-cj

asomoza commented on July 20, 2024

Interesting, I can reproduce this error. These are the outputs:

# text_encoder = T5EncoderModel.from_pretrained(...).to(dtype=torch.float16)
tensor([[[ 0.0872, -0.0144, -0.0733,  ...,  0.0432,  0.0251,  0.1550],
         [ 0.0277, -0.1429, -0.1173,  ...,  0.0565, -0.1959,  0.0936],
         [-0.0569,  0.1390, -0.1050,  ...,  0.0665,  0.0408,  0.1098]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)
# text_encoder = T5EncoderModel.from_pretrained(..., torch_dtype=torch.float16)
tensor([[[-1.2744e-01, -1.4755e-02, -6.3416e-02,  ...,  1.0626e-01,
          -3.7567e-02, -1.1975e-01],
         [-1.1462e-01,  6.1569e-03,  1.1475e-01,  ..., -3.8208e-02,
          -1.1078e-01, -1.0980e-01],
         [-5.2605e-03, -7.7438e-03,  3.5763e-06,  ..., -3.6888e-03,
           7.2136e-03,  2.2907e-03]]], device='cuda:0', dtype=torch.float16,
       grad_fn=<MulBackward0>)

I found that with the second one, some layers are still in torch.float32.
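
A quick sketch of how one can list the parameters that stayed in float32 (again assuming the PixArt-Sigma repo id and subfolder):

import torch
from transformers import T5EncoderModel

repo = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"  # assumed repo id
text_encoder = T5EncoderModel.from_pretrained(
    repo, subfolder="text_encoder", torch_dtype=torch.float16
)

# print every parameter that was not downcast to float16
for name, param in text_encoder.named_parameters():
    if param.dtype == torch.float32:
        print(name, param.dtype)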

sayakpaul commented on July 20, 2024

Ccing a Transformers maintainer here: @ArthurZucker

yiyixuxu commented on July 20, 2024

I think for T5, certain layers are upcast to float32 when loading the checkpoint with from_pretrained in fp16: https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/models/t5/modeling_t5.py#L797
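
For reference, this appears to come from the _keep_in_fp32_modules class attribute that from_pretrained honors whenever a torch_dtype is passed; a quick way to check (sketch):

from transformers import T5EncoderModel

# modules listed here are kept/upcast to float32 when the model is loaded
# with torch_dtype=torch.float16; prints ['wo'] in recent transformers versions
print(T5EncoderModel._keep_in_fp32_modules)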

Luciennnnnnn commented on July 20, 2024

If certain layers need to be upcast to float32, is the SD3 training code correct? In the SD3 training code, the T5 text encoder is initially loaded in float32 and then converted to float16 using the to() method when mixed-precision training with fp16 is employed. We do not appear to encounter similar issues when loading the T5 text encoder in SD3. Could this be due to differences between the T5 encoder used in PixArt-Sigma and the one in SD3?
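
Roughly, the pattern in question looks like this (simplified sketch; repo id and subfolder are assumptions):

import torch
from transformers import T5EncoderModel

repo = "stabilityai/stable-diffusion-3-medium-diffusers"  # assumed repo id
weight_dtype = torch.float16  # e.g. when training with --mixed_precision="fp16"

# loaded in float32 by default ...
text_encoder_three = T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_3")
# ... then cast wholesale, which also downcasts the layers that
# from_pretrained(..., torch_dtype=torch.float16) would have kept in float32
text_encoder_three.to("cuda", dtype=weight_dtype)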

sayakpaul commented on July 20, 2024

Training does not seem to be affected by this :/

Luciennnnnnn commented on July 20, 2024

Training does not seem to be affected by this :/

Why? If some parameters of T5 have to be in float32, then casting everything down to fp16 would feed the flow transformer inferior text features.

sayakpaul commented on July 20, 2024

Could very well be but the qualitative samples haven’t told me that yet.

This needs a deeper investigation. But the problem could stem from the fact that the original checkpoints are in float16 and I am not exactly sure about the consequences of any kind of casting here yet.

yiyixuxu commented on July 20, 2024

@Luciennnnnnn, can you run the same experiment for sd3 to see if it also produces a worse image in fp16? #8604 (comment)

The t5 embeddings are used differently in sd3 and pixart, so it is possible this has less or no effect in sd3. But we were not aware before that these layers in t5 need to be in fp32, so it's not impossible that training could work better for sd3 if we keep them in fp32.

yiyixuxu commented on July 20, 2024

a quick test for sd3 here - fp16 (bottom row) seems ok?

import torch
from diffusers import StableDiffusion3Pipeline
from diffusers.utils import make_image_grid

repo = "stabilityai/stable-diffusion-3-medium-diffusers"
dtype = torch.float16

pipe = StableDiffusion3Pipeline.from_pretrained(repo, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
# the wo projections are kept in float32 by from_pretrained, so this prints torch.float32
print(pipe.text_encoder_3.encoder.block[11].layer[1].DenseReluDense.wo.weight.dtype)

out = []

# top row: text_encoder_3 as loaded (wo layers still in float32)
generator = torch.Generator(device="cpu").manual_seed(0)
for i in range(2):
    image = pipe(
        "A cat holding a sign that says hello world",
        negative_prompt="",
        num_inference_steps=28,
        guidance_scale=7.0,
        generator=generator,
    ).images[0]
    out.append(image)

# cast the whole text encoder down, including the layers kept in float32
pipe.text_encoder_3 = pipe.text_encoder_3.to(dtype)
# now prints torch.float16
print(pipe.text_encoder_3.encoder.block[11].layer[1].DenseReluDense.wo.weight.dtype)

# bottom row: text_encoder_3 fully cast to float16, same seed
generator = torch.Generator(device="cpu").manual_seed(0)
for i in range(2):
    image = pipe(
        "A cat holding a sign that says hello world",
        negative_prompt="",
        num_inference_steps=28,
        guidance_scale=7.0,
        generator=generator,
    ).images[0]
    out.append(image)

make_image_grid(out, rows=2, cols=2).save("yiyi_test_1_out.png")

[image: yiyi_test_1_out.png - 2x2 grid; top row: wo layers kept in fp32, bottom row: text_encoder_3 fully cast to fp16]
