Comments (11)
Okay, I've successfully run inference on Windows. I'm on Python 3.9 with CUDA 12.1. I had to do the following things:
(do all of the following in x64 Native Tools Command Prompt for VS 2019)
Compile causal-conv1d after adding
"-DWIN32_LEAN_AND_MEAN",
to the nvcc flags in setup.py.
(you may also need to run
SET DISTUTILS_USE_SDK=1
)
Next, we need to install triton.
Download the triton wheel from here: scroll down to the bottom and download the triton-dist windows-latest artifact.
Extract it, then run
pip3 install triton-2.1.0-cp39-cp39-win_amd64.whl
If you have a different version of Python and CUDA 11.8, you can use one from here instead, though I haven't tested that.
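If you're unsure which wheel matches your interpreter, a quick stdlib check can tell you (a sketch; the printed tag should appear in the wheel's filename):

```python
import sys
import sysconfig

# The "cp39" part of a wheel name encodes the CPython version; the platform
# part (e.g. win_amd64) comes from the interpreter build.
abi_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_tag = sysconfig.get_platform().replace("-", "_").replace(".", "_")
print(f"look for a wheel tagged {abi_tag}-{abi_tag}-{platform_tag}")
```

On a 64-bit Windows CPython 3.9 this prints a tag matching triton-2.1.0-cp39-cp39-win_amd64.whl.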
Next, you need the compiled libraries triton depends on. You can download them from here; add the bin directory to your PATH.
If you prefer to compile them yourself, you can see the command here, but be warned: it takes about 1-2 hours.
Finally, I modified ops/selective_scan_interface.py as follows:
- Remove this line:
import selective_scan_cuda
- Replace
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
with
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
It would be better to use the kernel, but until we can get it compiling on Windows, we can use the reference implementation in pure Python instead.
With this setup I'm able to run inference using the 2.8b model (at fp16 or fp32) on a 3090.
For example:
Prompt:
User: What is the answer to life the universe and everything? Oracle:
Answer:
I don't know. I'm just a computer.
from mamba.
I think I found a workaround for compiling this package for Windows (however, I have not tested the impact on performance). MSVC has a problem with constexpr and can't handle passing such variables to templates as arguments (see this and this). The workaround is to replace constexpr with const static.
diff --git a/csrc/selective_scan/selective_scan_fwd_kernel.cuh b/csrc/selective_scan/selective_scan_fwd_kernel.cuh
index 440a209..b3ef2a8 100644
--- a/csrc/selective_scan/selective_scan_fwd_kernel.cuh
+++ b/csrc/selective_scan/selective_scan_fwd_kernel.cuh
@@ -306,14 +306,14 @@ template<int kNThreads, int kNItems, typename input_t, typename weight_t>
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
// Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block
// processing 1 row.
- constexpr int kNRows = 1;
+ const static int kNRows = 1;
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] {
BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] {
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, input_t, weight_t>;
- // constexpr int kSmemSize = Ktraits::kSmemSize;
- constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
+ // const static int kSmemSize = Ktraits::kSmemSize;
+ const static int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
// printf("smem_size = %d\n", kSmemSize);
dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
diff --git a/csrc/selective_scan/static_switch.h b/csrc/selective_scan/static_switch.h
index 7920ac0..87493ef 100644
--- a/csrc/selective_scan/static_switch.h
+++ b/csrc/selective_scan/static_switch.h
@@ -16,10 +16,10 @@
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
- constexpr bool CONST_NAME = true; \
+ const static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
- constexpr bool CONST_NAME = false; \
+ const static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
With those changes I can compile the package. It seems to work in PyTorch, but like I mentioned, I haven't tested performance or correctness. 😅
I have managed to build mamba-ssm, but for the life of me, I cannot compile causal-conv1d. @Phylliida, the "-DWIN32_LEAN_AND_MEAN", goes right in here:
extra_compile_args={
    "cxx": ["-O3"],
    "nvcc": append_nvcc_threads(
        [
            "-DWIN32_LEAN_AND_MEAN",
            "-O3",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_BFLOAT16_OPERATORS__",
            "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
            "-U__CUDA_NO_BFLOAT162_OPERATORS__",
            "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
            "--expt-relaxed-constexpr",
            "--expt-extended-lambda",
            "--use_fast_math",
            "--ptxas-options=-v",
            "-lineinfo",
        ]
    ),
},
right?
Unfortunately we've never tested Windows paths, and it's not on the roadmap right now.
Sorry if this is something you've already checked/covered, @Phylliida, but have you checked whether you are building the code as C++20? (Just guessing from the way constexpr and lambdas are used that it needs that version of the language.)
EDIT: also, the comment you link to, which links to a Stack Overflow post, appears to be unrelated to either issue thread; it's talking about something completely different (I'd hazard a guess the commenter remembered a #define being useful for array declaration and was sharing it, even though it didn't relate to the specific defines you mentioned there).
EDIT2: per https://learn.microsoft.com/en-us/cpp/c-runtime-library/math-constants it might be better to define _USE_MATH_DEFINES so that constants like M_LOG2E are defined.
EDIT3: actually, it looks like the code was updated a day or so ago to ask that it be compiled as C++17 (not 20, as I had guessed); maybe check that you also have the recent commit? 023c25d
Nice. Adding
"-D_USE_MATH_DEFINES",
to the nvcc flags is a better alternative.
Compiling with C++17 isn't enough; I get the errors listed above. Right now I'm trying to get C++20 working, no success yet.
Edit: OK, it looks like triton is a dependency. I'm trying out prebuilt wheels from here (scroll down to the bottom, extract the Windows build, then pip install ___.whl for your version of Python; I'm using 3.10 and CUDA 12.1).
(quoting @Grzego's constexpr / const static workaround above)
Working solution (compiled, but haven't trained).
Python 3.11.7, Windows 10
@Phylliida Hello, thanks for your method. But I don't understand what should be added after removing import selective_scan_cuda. In the class SelectiveScanFn there are out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus) and du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(u, delta, A, B, C, D, delta_bias, dout, x, None, ctx.delta_softplus) in the forward and backward functions.
Please help me.
@RiceBunny1990 You can skip any modifications to ops/selective_scan_interface.py once you have successfully compiled the mamba kernels on Windows, which should be possible after applying the changes I posted previously.
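A quick way to confirm the compiled kernels imported cleanly (a sketch; if this prints True, the pure-Python fallback edit is unnecessary):

```python
def kernels_available():
    """True when the compiled mamba / causal-conv1d extensions can be imported."""
    try:
        import selective_scan_cuda  # noqa: F401
        import causal_conv1d_cuda   # noqa: F401
        return True
    except ImportError:
        return False

print(kernels_available())
```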
Is there a simple way to get the training and inference (without recompiling the CUDA kernels) working on Windows without using WSL?
@Phylliida @Grzego Thank you for your information. I have compiled causal_conv1d 1.1.3.post1 and mamba 1.1.3.post1 successfully on Python 3.10 + Windows 11 x64 + torch 2.2 + CUDA 12.1. However, when I try to import mamba, it crashes on import causal_conv1d_cuda with:
ImportError: DLL load failed while importing causal_conv1d_cuda: The specified module could not be found.
I have checked causal_conv1d_cuda.cp310-win_amd64.pyd's dependencies (AFAIK a .pyd is a DLL on Windows), and all of them exist.
Any idea why it fails?
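One common cause worth checking here (an assumption, not confirmed from this thread): since Python 3.8, extension modules on Windows no longer resolve their DLL dependencies via PATH, so the CUDA bin directory has to be registered with os.add_dll_directory before the import. A sketch (the path shown is a hypothetical default install location):

```python
import os

def register_dll_dir(path):
    """On Windows, register a directory for extension-module DLL resolution.
    Returns True if the directory was registered, False otherwise
    (non-Windows platform, or the path doesn't exist)."""
    if os.name == "nt" and os.path.isdir(path):
        os.add_dll_directory(path)
        return True
    return False

# Hypothetical default CUDA 12.1 location; adjust to your install:
# register_dll_dir(r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\bin")
# import causal_conv1d_cuda
```

If the import succeeds only after registering the directory, the missing module was a CUDA runtime DLL rather than the .pyd itself.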