Windows Support about mamba (open issue)

Phylliida commented on July 24, 2024
Windows Support

from mamba.

Comments (11)

Phylliida commented on July 24, 2024

Okay, I've successfully run inference on Windows. I'm on Python 3.9 with CUDA 12.1, and I had to do the following:

(Do all of the following in the x64 Native Tools Command Prompt for VS 2019.)

First, compile causal-conv1d after adding

                        "-DWIN32_LEAN_AND_MEAN",

to the nvcc flags in setup.py.

(You may also need to run

SET DISTUTILS_USE_SDK=1

before building.)
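For anyone scripting this step, here is a minimal sketch that launches the build with DISTUTILS_USE_SDK=1 set for that one process only. It assumes the package is built with `pip install .` from the causal-conv1d checkout (the steps above don't say which command was used), and the helper names are hypothetical. As far as I understand, DISTUTILS_USE_SDK tells the setuptools/PyTorch extension build to reuse the compiler environment that the Native Tools prompt has already configured.

```python
import os
import subprocess
import sys


def sdk_build_env(base_env=None):
    """Return a copy of the environment with DISTUTILS_USE_SDK=1, so the
    extension build reuses the MSVC toolchain already set up by the
    x64 Native Tools Command Prompt instead of locating one itself."""
    env = dict(os.environ if base_env is None else base_env)
    env["DISTUTILS_USE_SDK"] = "1"
    return env


def build_command():
    # Assumption: the package is built with `pip install .` from its checkout.
    return [sys.executable, "-m", "pip", "install", "."]


def build(source_dir, dry_run=True):
    """Build the package in source_dir; with dry_run=True only print the command."""
    cmd = build_command()
    if dry_run:
        print("would run:", " ".join(cmd), "in", source_dir)
        return None
    return subprocess.run(cmd, cwd=source_dir, env=sdk_build_env(), check=True)
```

Run it from inside the Native Tools prompt, e.g. `build(r"C:\path\to\causal-conv1d", dry_run=False)`; the path is a placeholder.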

Next, install triton.

Download the triton wheel from here (scroll down to the bottom and download triton-dist windows-latest), extract it, then run

pip3 install triton-2.1.0-cp39-cp39-win_amd64.whl

If you have a different version of Python and CUDA 11.8, you can use one from here instead, though I haven't tested that.
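Picking the right wheel mostly comes down to matching the interpreter tag in the filename (cp39 in the wheel above means CPython 3.9). A small sketch, assuming CPython 3.8 or newer on 64-bit Windows (older versions used an extra `m` ABI suffix); the function names are made up for illustration:

```python
import sys


def windows_wheel_tag(version_info=sys.version_info):
    """Tag used by CPython wheels for this interpreter on 64-bit Windows,
    e.g. 'cp39-cp39-win_amd64' for Python 3.9 (3.8+ only: no 'm' suffix)."""
    abi = f"cp{version_info[0]}{version_info[1]}"
    return f"{abi}-{abi}-win_amd64"


def wheel_matches(filename, version_info=sys.version_info):
    # Wheel filenames end in '<python tag>-<abi tag>-<platform tag>.whl'.
    return filename.endswith(windows_wheel_tag(version_info) + ".whl")
```

For example, on Python 3.10 the suffix to look for is `cp310-cp310-win_amd64.whl`.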

Next, get the compiled libraries triton needs. You can download them from here; then add their bin directory to your PATH.

If you prefer to compile them yourself, you can see the command here, but be warned that it takes about 1-2 hours.
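Setting PATH in the Windows system settings makes the change permanent. If you'd rather keep it local to one Python session (for example, a launcher script that starts the build or the model), a sketch like the one below prepends the directory for the current process and anything it spawns. LIB_BIN is a hypothetical placeholder for wherever you extracted the libraries, and whether a process-local PATH is enough depends on how triton locates them, which I haven't verified.

```python
import os

# Hypothetical location of the extracted libraries; replace with your own path.
LIB_BIN = r"C:\triton-libs\bin"


def prepend_to_path(directory, env=None):
    """Prepend `directory` to PATH in `env` (default: this process's
    environment) unless it is already listed. Returns the new PATH."""
    env = os.environ if env is None else env
    parts = [p for p in env.get("PATH", "").split(os.pathsep) if p]
    if directory not in parts:
        parts.insert(0, directory)
    env["PATH"] = os.pathsep.join(parts)
    return env["PATH"]
```

Call `prepend_to_path(LIB_BIN)` before importing triton.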

Finally, I modified ops/selective_scan_interface.py as follows:

  1. Remove this line:
import selective_scan_cuda
  2. Replace
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                     return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)

with

def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                     return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)

It would be better to use the kernel, but until we can get it compiling on Windows, we can use the pure-Python reference implementation (selective_scan_ref) instead.
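For intuition about what the reference path does (and why it is slower than a fused kernel), here is a simplified single-channel sketch of the recurrence as I read selective_scan_ref: at every timestep the state decays by exp(delta * A), takes in delta * B * u, and is read out through C, plus a D * u skip term. This is an illustration in plain Python with a made-up name, not the library function; it drops the batch and channel dimensions, complex A, delta_bias/softplus and the z gate.

```python
import math


def selective_scan_1ch(u, delta, A, B, C, D=0.0):
    """Single-channel, real-valued selective scan, one step per timestep.

    u, delta: length-L lists (input and step size at each timestep)
    A:        length-N list (diagonal state matrix for this channel)
    B, C:     length-L lists of length-N lists (input-dependent B_t, C_t)
    D:        skip-connection scalar
    Returns the length-L output y.
    """
    N = len(A)
    h = [0.0] * N  # hidden state starts at zero
    ys = []
    for t in range(len(u)):
        for n in range(N):
            # Discretize: decay the state by exp(delta*A), then add delta*B*u.
            h[n] = math.exp(delta[t] * A[n]) * h[n] + delta[t] * B[t][n] * u[t]
        # Read the state out through C_t and add the D skip term.
        ys.append(sum(C[t][n] * h[n] for n in range(N)) + D * u[t])
    return ys
```

The explicit loop over timesteps is the point: the reference code also steps through the sequence one position at a time, whereas the CUDA kernel does this work in a single fused pass on the GPU.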

With this setup I'm able to run inference using the 2.8b model (at fp16 or fp32) on a 3090.

For example:

Prompt:

User: What is the answer to life the universe and everything? Oracle:

Answer:

I don't know. I'm just a computer.

Grzego commented on July 24, 2024

I think I found a workaround for compiling this package on Windows (however, I have not tested the impact on performance). MSVC has a problem with constexpr and can't handle passing such values to templates as arguments (see this and this). The workaround is to replace constexpr with const static.

diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh
index 440a209..b3ef2a8 100644
--- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh
+++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh
@@ -306,14 +306,14 @@ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
 void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
     // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
     // processing 1 row.
-    constexpr int kNRows = 1;
+    const static int kNRows = 1;
     BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
         BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
             BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
                 BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
                     using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
-                    // constexpr int kSmemSize = Ktraits::kSmemSize;
-                    constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
+                    // const static int kSmemSize = Ktraits::kSmemSize;
+                    const static int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
                     // printf("smem_size = %d\n", kSmemSize);
                     dim3 grid(params.batch, params.dim / kNRows);
                     auto kernel = &selective_scan_fwd_kernel<Ktraits>;
diff --git a/csrc/selective_scan/static_switch.h b/csrc/selective_scan/static_switch.h
index 7920ac0..87493ef 100644
--- a/csrc/selective_scan/static_switch.h
+++ b/csrc/selective_scan/static_switch.h
@@ -16,10 +16,10 @@
 #define BOOL_SWITCH(COND, CONST_NAME, ...)                                           \
     [&] {                                                                            \
         if (COND) {                                                                  \
-            constexpr bool CONST_NAME = true;                                        \
+            const static bool CONST_NAME = true;                                     \
             return __VA_ARGS__();                                                    \
         } else {                                                                     \
-            constexpr bool CONST_NAME = false;                                       \
+            const static bool CONST_NAME = false;                                    \
             return __VA_ARGS__();                                                    \
         }                                                                            \
     }()

With those changes I can compile the package. It seems to work in PyTorch, but like I mentioned, I haven't tested performance or correctness. 😅
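If you'd rather apply the change with a script than by hand (for example, after pulling a newer commit), here is a minimal sketch in Python. It only rewrites lines that begin with the exact constexpr declarations shown in the diff, so `static constexpr` members and other constexpr locals are left alone; the commented-out kSmemSize line is skipped since it has no effect. The file list comes from the diff; the function names are hypothetical, and if upstream has since changed these lines the script simply reports no changes.

```python
from pathlib import Path

# Files touched by the diff, relative to the mamba repo root.
FILES = [
    "csrc/selective_scan/selective_scan_fwd_kernel.cuh",
    "csrc/selective_scan/static_switch.h",
]

# (old, new) prefixes copied from the diff. A line is rewritten only if it
# starts with `old` after its indentation.
REWRITES = [
    ("constexpr int kNRows = 1;", "const static int kNRows = 1;"),
    ("constexpr int kSmemSize = Ktraits::kSmemSize +",
     "const static int kSmemSize = Ktraits::kSmemSize +"),
    ("constexpr bool CONST_NAME = true;", "const static bool CONST_NAME = true;"),
    ("constexpr bool CONST_NAME = false;", "const static bool CONST_NAME = false;"),
]


def patch_text(text):
    """Apply REWRITES line by line, keeping indentation, trailing text
    (e.g. macro backslashes) and the original line endings."""
    out = []
    for line in text.splitlines(keepends=True):
        body = line.lstrip()
        indent = line[: len(line) - len(body)]
        for old, new in REWRITES:
            if body.startswith(old):
                line = indent + new + body[len(old):]
                break
        out.append(line)
    return "".join(out)


def patch_repo(repo_root):
    """Patch FILES in place under repo_root; return the files that changed."""
    changed = []
    for rel in FILES:
        path = Path(repo_root) / rel
        with open(path, encoding="utf-8", newline="") as f:  # keep CRLF/LF as-is
            text = f.read()
        patched = patch_text(text)
        if patched != text:
            with open(path, "w", encoding="utf-8", newline="") as f:
                f.write(patched)
            changed.append(rel)
    return changed
```

For example, `patch_repo(".")` run from the repository root returns the list of files it modified.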

ramzeez88 commented on July 24, 2024

I have managed to build mamba-ssm, but for the life of me I cannot compile causal-conv1d. @Phylliida, the "-DWIN32_LEAN_AND_MEAN", flag goes right here, right?

    extra_compile_args={
        "cxx": ["-O3"],
        "nvcc": append_nvcc_threads(
            [
                "-DWIN32_LEAN_AND_MEAN",
                "-O3",
                "-U__CUDA_NO_HALF_OPERATORS__",
                "-U__CUDA_NO_HALF_CONVERSIONS__",
                "-U__CUDA_NO_BFLOAT16_OPERATORS__",
                "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                "-U__CUDA_NO_BFLOAT162_OPERATORS__",
                "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
                "--expt-relaxed-constexpr",
                "--expt-extended-lambda",
                "--use_fast_math",
                "--ptxas-options=-v",
                "-lineinfo",
            ]
        ),
    },

albertfgu commented on July 24, 2024

Unfortunately, we've never tested the Windows paths, and Windows support isn't on the roadmap right now.

nat42 commented on July 24, 2024

Sorry if you've already checked this, @Phylliida, but are you building the code as C++20? (I'm only guessing, from the way constexpr and lambdas are used, that it needs that version of the language.)

EDIT: Also, the comment you link to (the one pointing to a Stack Overflow post) appears to be unrelated to either issue thread; it's about something completely different. I'd guess the commenter remembered a #define being useful for array declarations and shared it, even though it doesn't relate to the specific defines you mentioned there.

EDIT2: Per https://learn.microsoft.com/en-us/cpp/c-runtime-library/math-constants, it may be better to define _USE_MATH_DEFINES so that constants like M_LOG2E are defined.

EDIT3: Actually, it looks like the code was updated a day or so ago to require C++17 (not C++20 as I had guessed). Maybe check whether you have that recent commit too? 023c25d

Phylliida commented on July 24, 2024

Nice, adding

                        "-D_USE_MATH_DEFINES",

to the nvcc flags is a better alternative.
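To keep setup.py usable on Linux as well, the Windows-only defines can be added conditionally. A sketch, assuming the nvcc list is passed to append_nvcc_threads as in the setup.py snippet quoted earlier in this thread; the helper name is hypothetical. The thread doesn't say whether -D_USE_MATH_DEFINES replaces or complements -DWIN32_LEAN_AND_MEAN, so this passes both.

```python
import sys

# Defines discussed in this thread; they only matter for MSVC/Windows builds.
WINDOWS_NVCC_DEFINES = ["-DWIN32_LEAN_AND_MEAN", "-D_USE_MATH_DEFINES"]


def nvcc_flags(base_flags, platform=sys.platform):
    """Return a new list: base_flags, prefixed with the Windows-only
    defines when building on win32. base_flags itself is not modified."""
    flags = list(base_flags)
    if platform == "win32":
        flags = WINDOWS_NVCC_DEFINES + flags
    return flags
```

In setup.py this would be used as `"nvcc": append_nvcc_threads(nvcc_flags(["-O3", "--use_fast_math"]))`, with the rest of the existing flag list in place of the two shown.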

Compiling with C++17 isn't enough; I still get the errors listed above. Right now I'm trying to get C++20 working, with no success yet.

Edit: OK, it looks like triton is a dependency. I'm trying out the prebuilt wheels from here (scroll down to the bottom, extract the Windows build, then pip install ___.whl for your version of Python; I'm using 3.10 and CUDA 12.1).

Jacky56 commented on July 24, 2024

Grzego's workaround above (replacing constexpr with const static) is a working solution for me: it compiles, but I haven't trained with it yet.

Python 3.11.7
Windows 10

RiceBunny1990 commented on July 24, 2024

@Phylliida Hello, thanks for your method. But I don't understand what should be changed after removing "import selective_scan_cuda". The class SelectiveScanFn still calls "out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus)" in its forward function and "du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(u, delta, A, B, C, D, delta_bias, dout, x, None, ctx.delta_softplus, )" in its backward function.
Please help me.

Grzego commented on July 24, 2024

@RiceBunny1990 You can skip any modifications to ops/selective_scan_interface.py once you have successfully compiled the mamba kernels on Windows, which should be possible after making the changes I posted previously.
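If you want one copy of ops/selective_scan_interface.py that works both with and without the compiled extension, a guarded import is one option instead of deleting the line. A sketch with hypothetical helper names: on Windows a DLL load failure surfaces as ImportError, so it falls through to the reference path as well. Any other place in the file that calls selective_scan_cuda directly would need the same check.

```python
import importlib


def optional_import(name):
    """Import module `name`, or return None if it (or a DLL it depends on)
    cannot be loaded."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


# Would stand in for the plain `import selective_scan_cuda` line.
selective_scan_cuda = optional_import("selective_scan_cuda")
HAVE_KERNEL = selective_scan_cuda is not None


def pick_impl(kernel_fn, reference_fn, have_kernel=HAVE_KERNEL):
    """Use the compiled-kernel path when it loaded, else the reference."""
    return kernel_fn if have_kernel else reference_fn
```

Inside selective_scan_fn this might read `impl = pick_impl(SelectiveScanFn.apply, selective_scan_ref)` followed by `return impl(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)`, since both take the same positional arguments in the code shown earlier.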

F286 commented on July 24, 2024

Is there a simple way to get training and inference working on Windows without WSL (and without recompiling the CUDA kernels)?

lyhyl commented on July 24, 2024

@Phylliida @Grzego Thank you for the information. I have compiled causal_conv1d 1.1.3.post1 and mamba 1.1.3.post1 successfully with Python 3.10 + Windows 11 x64 + torch 2.2 + CUDA 12.1. However, when I try to import mamba, it crashes on import causal_conv1d_cuda with:

ImportError: DLL load failed while importing causal_conv1d_cuda: The specified module could not be found.

I have checked the dependencies of causal_conv1d_cuda.cp310-win_amd64.pyd (AFAIK a .pyd is a DLL on Windows), and all of them exist.
Any idea what causes the failure?
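One thing that may be worth ruling out: since Python 3.8, PATH is no longer searched for the load-time dependencies of extension modules on Windows; only the system paths, the directory containing the .pyd, and directories registered with os.add_dll_directory() are. A dependency checker that searches PATH can therefore report everything as found while the import still fails. A sketch, assuming the missing DLL is a CUDA runtime library in the toolkit's bin directory, located here via the CUDA_PATH variables the CUDA installer normally sets; the helper names are hypothetical.

```python
import os


def cuda_dll_dirs(env=None):
    """Return the bin directories named by CUDA_PATH / CUDA_PATH_V*_*
    variables, without duplicates."""
    env = os.environ if env is None else env
    dirs = []
    for key, value in env.items():
        if value and (key == "CUDA_PATH" or key.startswith("CUDA_PATH_V")):
            bin_dir = os.path.join(value, "bin")
            if bin_dir not in dirs:
                dirs.append(bin_dir)
    return dirs


def register_dll_dirs(dirs):
    """Add each existing directory to the DLL search path for extension
    modules. os.add_dll_directory exists only on Windows (Python 3.8+), so
    this is a no-op elsewhere. Returns handles whose close() undoes it."""
    handles = []
    if hasattr(os, "add_dll_directory"):
        for d in dirs:
            if os.path.isdir(d):
                handles.append(os.add_dll_directory(d))
    return handles
```

Call `register_dll_dirs(cuda_dll_dirs())` after `import torch` and before importing causal_conv1d_cuda or mamba_ssm.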
