Comments (7)

tridao commented on July 24, 2024

Seems like a Triton error. You might have better luck searching their repo issues.

JHChen1 commented on July 24, 2024

Hi, I have the same problem. Have you solved it?

yingzhige00 commented on July 24, 2024

No, but I'm trying. In triton-lang/triton#2861, changing "cuda:1" to "cuda:0" (or just "cuda") made it work.
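
Switching to "cuda:0" does not force you onto physical GPU 0: CUDA_VISIBLE_DEVICES hides every GPU it does not list and renumbers the listed ones from 0, in the order given, and it has to be set before CUDA is initialized in the process. A minimal sketch of that mapping; `logical_cuda_device` is a hypothetical helper and only handles plain comma-separated integer indices (CUDA also accepts GPU UUIDs, which it ignores):

```python
def logical_cuda_device(physical_id: int, visible: str) -> str:
    """Device string under which physical GPU `physical_id` appears when
    CUDA_VISIBLE_DEVICES=`visible` (integer indices only, e.g. "4" or "2,5")."""
    ids = [int(tok) for tok in visible.split(",") if tok.strip()]
    if physical_id not in ids:
        raise ValueError(f"GPU {physical_id} is hidden by CUDA_VISIBLE_DEVICES={visible!r}")
    # Visible GPUs are renumbered from 0 in the order they are listed.
    return f"cuda:{ids.index(physical_id)}"


print(logical_cuda_device(4, "4"))    # cuda:0
print(logical_cuda_device(5, "2,5"))  # cuda:1
```

So to run on physical GPU 1 while the code only ever uses "cuda" (cuda:0), set `os.environ["CUDA_VISIBLE_DEVICES"] = "1"` before anything touches CUDA (safest: before `import torch`), or launch with `CUDA_VISIBLE_DEVICES=1 python script.py`.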

yingzhige00 commented on July 24, 2024

This code runs, but it is slower:

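import os

# Pick the physical GPU before CUDA is initialized; inside this process it
# then shows up as "cuda" (i.e. cuda:0).
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import timeit

import torch
from mamba_ssm import Mamba, Mamba2


batch, length, dim = 1, 64, 256
cuda = "cuda"
x = torch.randn(batch, length, dim).to(cuda)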

def try_mamba1(batch, length, dim, x):
    model = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim, # Model dimension d_model
        d_state=16,  # SSM state expansion factor
        d_conv=4,    # Local convolution width
        expand=2,    # Block expansion factor
    ).to(cuda)
    y = model(x)
    assert y.shape == x.shape

def try_mamba2(batch, length, dim, x):
    model = Mamba2(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim, # Model dimension d_model
        d_state=64,  # SSM state expansion factor, typically 64 or 128
        d_conv=4,    # Local convolution width
        expand=2,    # Block expansion factor
    ).to(cuda)
    y = model(x)
    assert y.shape == x.shape

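# Each timed call also builds a new model. CUDA kernels launch asynchronously,
# so synchronize inside the timed statement to count the GPU work as well.
mamba1_time = timeit.timeit('try_mamba1(batch, length, dim, x); torch.cuda.synchronize()', number=10, globals=globals())
print(f"Mamba 1 took {mamba1_time} seconds")

mamba2_time = timeit.timeit('try_mamba2(batch, length, dim, x); torch.cuda.synchronize()', number=10, globals=globals())
print(f"Mamba 2 took {mamba2_time} seconds")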
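
A fairer Mamba vs. Mamba2 comparison would build each model once, time only the forward pass, and discard the first calls, since Triton may compile and autotune its kernels on first launch. A minimal, framework-agnostic sketch; `benchmark` is a hypothetical helper, and `sync` is where you would pass `torch.cuda.synchronize` so that asynchronous kernel launches are actually waited on:

```python
import time
from typing import Callable, List


def benchmark(fn: Callable[[], object],
              sync: Callable[[], None] = lambda: None,
              warmup: int = 3,
              iters: int = 10) -> List[float]:
    """Return the wall-clock time in seconds of each of `iters` calls to `fn`.

    `warmup` untimed calls run first (they may include kernel compilation and
    autotuning). `sync` must block until queued device work is finished; for
    CUDA code pass torch.cuda.synchronize, otherwise only launch time is measured.
    """
    for _ in range(warmup):
        fn()
    sync()  # drain the warm-up work before starting the clock
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()  # wait for this call's kernels to finish
        times.append(time.perf_counter() - start)
    return times
```

For example, create `model = Mamba2(...).to(cuda)` once, call `benchmark(lambda: model(x), sync=torch.cuda.synchronize)`, and compare medians (e.g. `statistics.median(times)`) rather than one total that also includes construction.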

JHChen1 commented on July 24, 2024

Hello, when I use any device other than cuda:0, I get "RuntimeError: Triton Error [CUDA]: context is destroyed". Is there any solution?

JHChen1 commented on July 24, 2024

Thanks for your reply. When I set it up the same way you did, it ran on cuda:0 by default instead of on cuda:1, which is the GPU I wanted. The program itself runs on cuda:0, but on cuda:1 I get the same error as you. I have not found the cause yet.
"torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 98.00 MiB (GPU 0; 15.77 GiB total capacity; 3.16 GiB already allocated; 86.19 MiB free; 3.26 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"

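One reading of the numbers in that message (an interpretation, not something stated in the thread): GPU 0 has 15.77 GiB in total, PyTorch in this process has reserved only 3.26 GiB, and just 86.19 MiB is free, so roughly 12.4 GiB is held outside this process's PyTorch allocator, most likely by other processes on the same GPU or by CUDA context overhead. `nvidia-smi` lists per-process usage and can confirm this. The arithmetic, with a hypothetical helper:

```python
MIB_PER_GIB = 1024


def unaccounted_gib(total_gib: float, reserved_gib: float, free_mib: float) -> float:
    """GPU memory that is neither free nor reserved by this PyTorch process, in GiB."""
    return total_gib - reserved_gib - free_mib / MIB_PER_GIB


# Figures copied from the OutOfMemoryError message above.
print(f"{unaccounted_gib(15.77, 3.26, 86.19):.2f} GiB")  # prints "12.43 GiB"
```
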
ZLKong commented on July 24, 2024

Hello, I found a solution in facebookresearch/xformers#681.

Following it, I took this call from my error traceback:
File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 323, in _mamba_chunk_scan_combined_fwd
out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx)

and wrapped it so that it runs with CB's device as the current CUDA device:

with torch.cuda.device(CB.device):
    out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx)
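
The fix makes CB's device the current CUDA device while the Triton kernel launches; the current device otherwise defaults to cuda:0, which appears to be what breaks tensors living on other GPUs. A minimal sketch of the same pattern as a reusable decorator; `on_device_of` is a hypothetical name, and the device-context factory is passed in so the sketch runs without a GPU (with PyTorch you would pass `torch.cuda.device`):

```python
import functools
from typing import Any, Callable, ContextManager


def on_device_of(device_ctx: Callable[[Any], ContextManager[Any]]):
    """Decorator: call the wrapped function inside device_ctx(first_arg.device).

    Hypothetical helper. With PyTorch, device_ctx would be torch.cuda.device,
    mirroring `with torch.cuda.device(CB.device): ...` from the fix above.
    """
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(fn)
        def wrapper(first: Any, *args: Any, **kwargs: Any) -> Any:
            # Make the first argument's device current for the whole call, so
            # kernels launched inside fn target that device.
            with device_ctx(first.device):
                return fn(first, *args, **kwargs)
        return wrapper
    return decorator
```

For instance, `@on_device_of(torch.cuda.device)` above a function whose first argument is a CUDA tensor reproduces the `with` block above. It presumably also works from user code without editing site-packages, e.g. `with torch.cuda.device(x.device): y = model(x)`, or by calling `torch.cuda.set_device(x.device)` once at startup.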
