
Comments (5)

LioYao commented on August 22, 2024

Thanks for your answers, @ebsmothers and @kartikayk. Based on your tips, I found that the problem was in my environment: in a fresh environment, the code runs on the GPU. The likely explanation is that I mixed conda and pip when setting up the original environment, which left multiple versions of the same libraries installed side by side. Thanks again for your help.
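
For anyone who suspects the same kind of conda/pip mix-up, the following is a minimal sketch (not something used in this thread) that lists distributions visible to the current interpreter under more than one version. It relies only on the standard library's importlib.metadata; the helper names normalize, find_duplicates, and installed_pairs are hypothetical, and an empty result does not prove the environment is clean.

```python
import re
from collections import defaultdict
from importlib import metadata


def normalize(name):
    """Normalize a distribution name (case and '-', '_', '.' runs) for comparison."""
    return re.sub(r"[-_.]+", "-", name).lower()


def find_duplicates(pairs):
    """Return {name: sorted versions} for names seen with more than one version."""
    seen = defaultdict(set)
    for name, version in pairs:
        seen[normalize(name)].add(version)
    return {n: sorted(v) for n, v in seen.items() if len(v) > 1}


def installed_pairs():
    """Yield (name, version) for every distribution found on sys.path."""
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        version = dist.metadata.get("Version")
        if name and version:  # skip entries with incomplete metadata
            yield name, version


if __name__ == "__main__":
    # Prints nothing when no distribution shows up with conflicting versions.
    for name, versions in sorted(find_duplicates(installed_pairs()).items()):
        print(f"{name}: {', '.join(versions)}")
```

If this prints a line for torch or one of its dependencies, recreating the environment with a single installer is probably the simplest fix.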

from multimodal.

ebsmothers commented on August 22, 2024

Hi @LioYao, strangely I cannot reproduce this error on my end; your script runs fine for me. A couple of requests that will help with debugging:

  1. Can you share the version of torch and other dependencies you're using (e.g. if in a conda env just paste the output of conda list)?
  2. What hardware are you running on?
  3. Can you paste the full stack trace you're seeing?
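
The first two items can be gathered in one go. Below is a minimal sketch, assuming only the standard library plus an optional PyTorch install; describe_environment is a hypothetical helper name, and PyTorch's own python -m torch.utils.collect_env produces a more complete report.

```python
import platform
import sys


def describe_environment():
    """Collect the basics usually requested in a bug report."""
    info = {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
    }
    try:
        import torch
    except ImportError:
        info["torch"] = "not installed"
        return info
    info["torch"] = torch.__version__
    info["torch_cuda_build"] = torch.version.cuda  # None for CPU-only builds
    info["cuda_available"] = torch.cuda.is_available()
    if info["cuda_available"]:
        info["gpu"] = torch.cuda.get_device_name(0)
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```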

In the meantime, a couple of other suggestions (I can't be sure they will work until I have more info). One is to explicitly run in a device context manager, i.e. you can wrap everything in a with torch.cuda.device(0): block. Then you shouldn't need to call .to(device) anymore either. A second suggestion is to toggle different kernels for scaled_dot_product_attention. For example, you can run

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
  image_embeddings, text_embeddings = model(image, text)

to enable the flash attention kernel and disable the other kernels. This may not solve your problem, but I would be interested to see if the same error repros under all of the different kernels.
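
To check all three kernels without editing the flags by hand, something like the sketch below could work. It is only an illustration: kernel_configs and try_each_kernel are hypothetical helper names, run_forward stands for whatever callable triggers the failing forward pass, and torch.backends.cuda.sdp_kernel is the context manager used above (newer PyTorch releases deprecate it in favor of torch.nn.attention.sdpa_kernel).

```python
def kernel_configs():
    """One flag set per attention kernel, enabling exactly one kernel at a time."""
    names = ("flash", "mem_efficient", "math")
    return {
        name: {f"enable_{other}": other == name for other in names}
        for name in names
    }


def try_each_kernel(run_forward):
    """Call run_forward() under each kernel; record 'ok' or the error message."""
    import torch  # imported lazily so kernel_configs() works without PyTorch

    results = {}
    for name, flags in kernel_configs().items():
        try:
            with torch.backends.cuda.sdp_kernel(**flags):
                run_forward()
            results[name] = "ok"
        except RuntimeError as err:
            results[name] = f"failed: {err}"
    return results
```

With the objects from this thread, the call would look like try_each_kernel(lambda: model(image, text)). A kernel that is unsupported on the given hardware or dtype will also show up as a failure, so the error messages matter more than the ok/failed split.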


ebsmothers commented on August 22, 2024

Hi @LioYao, thanks for creating the issue. I suspect that the cause of the device mismatch here is ContrastiveLossWithTemperature. Note that this loss contains a parameter, which should also be moved to the GPU. Can you try calling .to(device) on your loss as well?


LioYao commented on August 22, 2024

Thanks for your answer @ebsmothers.
When I moved the loss to the GPU, the same error still appeared. I printed the devices of image, text, the model's parameters, and the loss's parameters, and they are all on cuda:0. More importantly, the error still occurs at "image_embeddings, text_embeddings = model(image,text)", and the traceback ends in the scaled_dot_product_attention function. I wonder whether the multi-head attention implementation forcibly copies the GPU data to the CPU while processing it.
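
A device check like the one described here can be made a little more thorough by counting buffers as well as parameters, since buffers are easy to overlook. The sketch below is an assumption about how such a check might look, not the code actually run in this thread; device_summary is a hypothetical name, and it accepts any object exposing parameters() and buffers() the way torch.nn.Module does.

```python
from collections import Counter


def device_summary(module):
    """Count parameters and buffers per device for an nn.Module-like object."""
    counts = Counter()
    for tensor in list(module.parameters()) + list(module.buffers()):
        counts[str(tensor.device)] += 1
    return dict(counts)
```

Calling device_summary(model) and device_summary(clip_loss) should each return a single key such as 'cuda:0'; any additional key points at a tensor left behind on another device.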


kartikayk commented on August 22, 2024

Hi @LioYao, I tried the following very simple setup and this works just fine for me. So I'm not able to reproduce your error. Do you mind trying this simple version and seeing what you get?

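import torch
# NOTE: import paths are assumed from the torchmultimodal package layout
from torchmultimodal.models.clip.model import clip_vit_b16
from torchmultimodal.modules.losses.contrastive_loss_with_temperature import ContrastiveLossWithTemperature

model = clip_vit_b16(pretrained=False)
model = model.to('cuda:0')
# define loss and other things needed for training
clip_loss = ContrastiveLossWithTemperature()
optimizer = torch.optim.AdamW(model.parameters(), 0.0225)

# random token ids and a random image batch, created directly on the GPU side via .to()
text = torch.randint(0, 49408, (1, 77), dtype=torch.long).to('cuda:0')
image = torch.randn(1, 3, 224, 224).to('cuda:0')

# one forward/backward/optimizer step
image_embeddings, text_embeddings = model(image, text)
loss = clip_loss(image_embeddings, text_embeddings)
loss.backward()
optimizer.step()
optimizer.zero_grad()

# confirm where the parameters and outputs live
print(optimizer.param_groups[0]['params'][0].device)
print(image_embeddings.device)
print(text_embeddings.device)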

And the output is:

cuda:0
cuda:0
cuda:0
