
Mamba excessive memory usage about transformers HOT 7 OPEN

GooseIt avatar GooseIt commented on June 18, 2024 8

from transformers.

Comments (7)

famishedrover avatar famishedrover commented on June 18, 2024

Before I get into the internals of the Mamba implementation, this may be a useful thread.

The following snippet keeps a check on memory (model is a Mamba model already moved to 'cuda'):

import torch

for ix in range(2):
    inp = torch.randint(0, 128, (2, 512)).to('cuda')
    out = model(inp, use_cache=False)
    del out
    del inp
print(torch.cuda.max_memory_allocated())

for ix in range(2):
    inp = torch.randint(0, 128, (2, 512)).to('cuda')
    out = model(inp, use_cache=False)
    del out
    del inp
print(torch.cuda.max_memory_allocated())

and I get the following output:

513157120
513157120
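The identical numbers are expected from how the counter works: torch.cuda.max_memory_allocated() reports the high-water mark since startup (or since the last reset), so a second loop that never exceeds the first peak prints the same value. A small self-contained sketch of that semantics (no model involved):

```python
import torch

if torch.cuda.is_available():
    x = torch.zeros(1024, 1024, device='cuda')  # allocates ~4 MB of float32
    del x
    # the peak survives the deletion...
    print(torch.cuda.max_memory_allocated())
    # ...until the counter is explicitly reset to the current allocation
    torch.cuda.reset_peak_memory_stats()
    print(torch.cuda.max_memory_allocated())
```

Resetting between iterations would show each pass's own peak instead of the historical maximum.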

After only loading the modules (imports), torch.cuda.max_memory_allocated() = 0 (expected).
After loading the model to RAM, torch.cuda.max_memory_allocated() = 0 (expected).
After model = model.to('cuda'), torch.cuda.max_memory_allocated() = 5529600, ~5.5 MB.
After a forward pass with the dummy batch above, the result is 513157120, ~0.5 GB.

Even with the del statements, the usage stays the same.
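One way to check whether that ~0.5 GB peak is autograd keeping activations alive, rather than a leak, is to compare the same forward pass with and without gradients. A sketch under that assumption (the helper name is mine, not from the thread):

```python
import torch

def peak_forward_memory(model, batch, grad):
    """Peak CUDA bytes for one forward pass; None when no GPU is present."""
    if not torch.cuda.is_available():
        return None
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()  # measure this pass in isolation
    ctx = torch.enable_grad() if grad else torch.no_grad()
    with ctx:
        model(batch)
    return torch.cuda.max_memory_allocated()
```

If the no-grad peak comes out much smaller, the bulk of the 0.5 GB is activations saved for backward, which autograd frees only when the graph is released.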


GooseIt avatar GooseIt commented on June 18, 2024

@famishedrover

The issue doesn't seem to be any of those described in that thread. Rather than growing gradually each epoch, memory spikes during the forward pass, and the spike amplitude appears consistent across epochs.

I've updated the Kaggle notebook - please check Version 6. It now includes a torch.cuda memory profile image, as well as memory-usage prints in different setups, backing my suspicion that the problem is in the model itself.


GooseIt avatar GooseIt commented on June 18, 2024

@famishedrover

A gentle reminder about this issue in case you've forgotten about it.


famishedrover avatar famishedrover commented on June 18, 2024

I agree, there is something going on with the forward pass itself. The del hack at least prevents linear growth across multiple forward passes (but a single pass still takes an unreasonable amount of memory). I will take a look at this later in the week (limited bandwidth).


v4ndi avatar v4ndi commented on June 18, 2024

@famishedrover @ArthurZucker @younesbelkada @amyeroberts @koayon
I've encountered the same issue as described above. I tried using the transformers Mamba implementation instead of state-spaces/mamba. If necessary, I can provide example code. Please fix this issue, I would really appreciate it.


ArthurZucker avatar ArthurZucker commented on June 18, 2024

You would need to provide a reproducer. If you try the original state-space model and have the kernels, the HF model should not really change much.
You should make sure you are testing equivalent things: gradient or not, fast path or not, use cache or not.
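A reproducer along those lines could toggle exactly those knobs and print the peak for each combination. A sketch, assuming a CUDA machine and a transformers build with Mamba support (the checkpoint name is illustrative, not from the thread):

```python
import torch

def peak_bytes(model, inp, grad, use_cache):
    # reset the high-water mark so each configuration is measured in isolation
    torch.cuda.reset_peak_memory_stats()
    ctx = torch.enable_grad() if grad else torch.no_grad()
    with ctx:
        model(inp, use_cache=use_cache)
    return torch.cuda.max_memory_allocated()

if torch.cuda.is_available():
    # requires a transformers version that ships the Mamba architecture
    from transformers import MambaForCausalLM

    model = MambaForCausalLM.from_pretrained('state-spaces/mamba-130m-hf').to('cuda')
    inp = torch.randint(0, 128, (2, 512), device='cuda')
    for grad in (True, False):
        for cache in (True, False):
            print(f'grad={grad} use_cache={cache}: {peak_bytes(model, inp, grad, cache)}')
```

Comparing these four numbers against the same grid run on state-spaces/mamba would show whether the gap comes from the implementation or from mismatched settings.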


GooseIt avatar GooseIt commented on June 18, 2024

@v4ndi Please provide the reproducer; it will be very helpful.

