Comments (10)
Can you check if the outputs of the text encoders vary when loaded using the method you described?
That will be an easier way to reproduce the problem.
Cc: @lawrence-cj
from diffusers.
Interesting, I can reproduce this error. Here are the outputs:

```python
# text_encoder = T5EncoderModel.from_pretrained(...).to(dtype=torch.float16)
tensor([[[ 0.0872, -0.0144, -0.0733,  ...,  0.0432,  0.0251,  0.1550],
         [ 0.0277, -0.1429, -0.1173,  ...,  0.0565, -0.1959,  0.0936],
         [-0.0569,  0.1390, -0.1050,  ...,  0.0665,  0.0408,  0.1098]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<MulBackward0>)

# text_encoder = T5EncoderModel.from_pretrained(..., torch_dtype=torch.float16)
tensor([[[-1.2744e-01, -1.4755e-02, -6.3416e-02,  ...,  1.0626e-01,
          -3.7567e-02, -1.1975e-01],
         [-1.1462e-01,  6.1569e-03,  1.1475e-01,  ..., -3.8208e-02,
          -1.1078e-01, -1.0980e-01],
         [-5.2605e-03, -7.7438e-03,  3.5763e-06,  ..., -3.6888e-03,
           7.2136e-03,  2.2907e-03]]], device='cuda:0', dtype=torch.float16,
       grad_fn=<MulBackward0>)
```
I found that with the second one, some layers are still torch.float32.
Ccing a Transformers maintainer here: @ArthurZucker
I think for T5, certain layers are upcast to float32 when loading the checkpoint with from_pretrained in fp16: https://github.com/huggingface/transformers/blob/9ba9369a2557e53a01378199a9839ec6e82d8bc7/src/transformers/models/t5/modeling_t5.py#L797
If certain layers need to be upcast to float32, is the training code of SD3 correct? In the SD3 training code, the T5 text encoder is initially loaded in float32 and then converted to float16 via the to() method when using mixed-precision training with fp16. We do not appear to encounter similar issues when loading the T5 text encoder in SD3. Could this be due to differences between the T5 encoder used in PixArt-Sigma and the one in SD3?
Training does not seem to be affected by this :/
> Training does not seem to be affected by this :/

Why? If some parameters of T5 have to be in float32, keeping them in float16 will give the flow transformer inferior text features.
Could very well be, but the qualitative samples haven't shown me that yet.
This needs a deeper investigation. The problem could also stem from the fact that the original checkpoints are in float16, and I am not yet sure about the consequences of any kind of casting here.
@Luciennnnnnn, can you run the same experiment for SD3 to see if it also produces a worse image in fp16? #8604 (comment)
T5 embeddings are used differently in SD3 and PixArt, so it is possible the effect is smaller or absent in SD3. But we were not aware before that these layers in T5 need to be in fp32, so it's not impossible that training could work better for SD3 if we do that.
A quick test for SD3 here; fp16 (bottom row) seems OK?

```python
import torch
from diffusers import StableDiffusion3Pipeline
from diffusers.utils import make_image_grid
from transformers import T5EncoderModel

repo = "stabilityai/stable-diffusion-3-medium-diffusers"
dtype = torch.float16

pipe = StableDiffusion3Pipeline.from_pretrained(repo, torch_dtype=dtype)
pipe.enable_model_cpu_offload()
# With torch_dtype loading, the wo projection is kept in float32.
print(pipe.text_encoder_3.encoder.block[11].layer[1].DenseReluDense.wo.weight.dtype)

out = []
generator = torch.Generator(device="cpu").manual_seed(0)
for i in range(2):
    image = pipe(
        "A cat holding a sign that says hello world",
        negative_prompt="",
        num_inference_steps=28,
        guidance_scale=7.0,
        generator=generator,
    ).images[0]
    out.append(image)

# Cast the whole T5 encoder to float16, wo projections included.
pipe.text_encoder_3 = pipe.text_encoder_3.to(dtype)
print(pipe.text_encoder_3.encoder.block[11].layer[1].DenseReluDense.wo.weight.dtype)

generator = torch.Generator(device="cpu").manual_seed(0)
for i in range(2):
    image = pipe(
        "A cat holding a sign that says hello world",
        negative_prompt="",
        num_inference_steps=28,
        guidance_scale=7.0,
        generator=generator,
    ).images[0]
    out.append(image)

make_image_grid(out, rows=2, cols=2).save("yiyi_test_1_out.png")
```