Comments (7)
After taking a deep look into the code and testing Flash Attention support on AMD GPUs, here is what I found:
AMD Instinct GPUs, gfx90a and gfx942 (MI210, MI250, MI300), support Flash Attention via specially written Composable Kernel (CK) libraries. I haven't tested this myself, but it is reported to work, and there are published numbers showing the 2-3x speedup vLLM gives you with CK Flash Attention.
Radeon RDNA3 GPUs, the 7900 XTX and W7900 (gfx1100), lack the necessary Composable Kernel libraries for the mechanism above, so the engineers at AMD opted to have these GPUs use an implementation of Flash Attention written in OpenAI's Triton. This Triton Flash Attention is supposed to work, but every test I've done (using various branches and docker builds) with VLLM_USE_TRITON_FLASH_ATTN=1
hits the same "stack frame size exceeds limit" error during the Triton JIT compile at runtime. I am sure the compile is not failing due to system resources: I tested the Radeon Pro W7900 on two powerful systems, a Ryzen 9 7950X with 64 GB of RAM and a Threadripper Pro 5975WX with 128 GB of RAM, and in both cases the Triton compile takes a very long time (upwards of several hours) and still fails with the same stack frame size error (see screenshot).
Flash Attention forward-pass support for RDNA3 was added thanks to howiejay; however, in my testing this implementation no longer works: the hipify_python patch fails and the build breaks on newer versions of PyTorch+ROCm (tried on rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1 and rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1).
In summary, the only way to get vLLM working on Radeon and Radeon Pro graphics cards at the moment seems to be to build without CK Flash Attention support (BUILD_FA="0") and to disable the Triton Flash Attention implementation (VLLM_USE_TRITON_FLASH_ATTN=0). vLLM then runs, but you do not get any of the speedups vLLM is known for; in my testing, inference with vLLM is the same speed or slower than things like llama.cpp and Ollama.
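The workaround above amounts to a build/run pair like the following; the Dockerfile name and image tag are assumptions based on my setup, and BUILD_FA is the build argument mentioned above:

```shell
# Build the ROCm image without the CK Flash Attention libraries
# (Dockerfile.rocm and the vllm-rocm tag are illustrative names).
docker build -f Dockerfile.rocm --build-arg BUILD_FA="0" -t vllm-rocm .

# At runtime, also disable the Triton Flash Attention path
# before launching the server.
export VLLM_USE_TRITON_FLASH_ATTN=0
```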
The vLLM repos I've already tried are:
https://github.com/vllm-project/vllm (main branch)
https://github.com/ROCm/vllm (main, bf16_temp_fix_navi, TunableOp_Integration_ROCm6.0 branches)
https://github.com/hongxiayang/vllm (main branch)
Commands used to run the vLLM docker image and server were as follows (I tried a few other variations of the commands below, like changing --shm-size or --max-model-len, with no luck):
# Run vllm-rocm Docker image
docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G --name vllm-rocm -v /home/${USER}/Downloads/models:/app/model \
vllm-rocm bash
# Run vllm api server
VLLM_USE_TRITON_FLASH_ATTN=1 CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
    --max-model-len 3072 --download-dir /app/model --quantization=gptq --tensor-parallel-size=1 \
    --enforce-eager --trust-remote-code --dtype=auto --kv-cache-dtype=auto \
    --quantization-param-path=None --device=cuda --block-size=16 \
    --model TechxGenus/Meta-Llama-3-70B-Instruct-GPTQ
Asking that the engineers at AMD look into this and assist in troubleshooting and getting this working for Radeon GPUs (Navi3).
from vllm.
That won't work I think. There's a related Flash Attention discussion on gfx1100 here: ROCm/aotriton#16 although according to this, Navi support was upstreamed last month and the appropriate place to file any navi31 Triton issues is the main repo: https://github.com/openai/triton
(The vLLM bug atm is just that it's not checking for gfx1100 correctly, it shouldn't be trying to use the Triton FA at all?)
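As a sketch of what that check could look like (the function and the supported list here are mine, not vLLM's actual code; ROCm builds of PyTorch expose the arch string via torch.cuda.get_device_properties(0).gcnArchName):

```python
# Hypothetical gate: only enable the Triton/CK FA path on CDNA parts,
# and let RDNA3 (gfx11xx) fall back to the naive attention path.
def triton_fa_supported(arch_name: str) -> bool:
    # gfx90a / gfx942 are the Instinct parts with working CK kernels;
    # anything else (e.g. "gfx1100" on a 7900 XTX) should be excluded.
    supported = ("gfx90a", "gfx942")
    return any(arch_name.startswith(a) for a in supported)

print(triton_fa_supported("gfx1100"))  # RDNA3 -> False
print(triton_fa_supported("gfx90a"))   # MI250 -> True
```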
from vllm.
if possible, can you try building triton from source?
from vllm.
The howiejay branch should build fine on the latest torch stable running ROCm 6. I have py3.11 and py3.12 wheels built against gfx1100 and ROCm 6.0 here. All I run is
pip wheel git+https://github.com/ROCm/flash-attention@howiejay/navi_support --no-deps
in my virtualenvs to produce the wheels.
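Assuming the wheel lands in the current directory, installing and smoke-testing it would look something like this (the wheel filename pattern is illustrative):

```shell
# Build the wheel from the howiejay/navi_support branch, then install it.
pip wheel git+https://github.com/ROCm/flash-attention@howiejay/navi_support --no-deps
pip install flash_attn-*.whl
# Quick import check that the extension loads.
python -c "import flash_attn; print(flash_attn.__version__)"
```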
So I built vLLM with defaults then set VLLM_USE_TRITON_FLASH_ATTN=0 at runtime. On unquantized Llama3 8B I peaked at something like 1550 T/S with BS=96 and 0.95 memory allocation on a 7900 XTX 24G. 400 token response with a few hundred in context. Seems okay-ish?
Update: there's an internal gate against using the CK FA for Navi even if it's installed, because there's no varlen_fwd()
support. You can build and install the howiejay flash-attn fine, but it seems to only be useful for diffusion models atm.
Additionally, I built ROCm/triton from source as of an hour ago and it still just sits pegging one thread for a small eternity before eventually being killed for blowing up the stack. I guess a person could try to increase the stack size, but I really feel like something's not working...
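The stack-size idea would amount to raising the soft limit before launching (untested; it may well be an internal compiler limit rather than the process stack, in which case this won't help):

```shell
# Raise the stack limit for this shell, then launch the server as usual.
ulimit -s unlimited
VLLM_USE_TRITON_FLASH_ATTN=1 python -m vllm.entrypoints.openai.api_server --model <model>
```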
from vllm.
I think I narrowed it to this autotune:
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, num_warps=4)
Disabling that, I can run without VLLM_USE_TRITON_FLASH_ATTN=0
. I'm using Triton nightly as of an hour ago to make sure it has any possible Navi fixes. Though if anything it feels slower? I'll try stable Triton in a bit.
Patch on top of v0.4.2 if someone else wants to play with it.
diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index 11476641..d5f6bbec 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -219,16 +219,16 @@ def _attn_fwd_inner(
num_stages=1,
num_warps=8,
),
- triton.Config(
- {
- "BLOCK_M": 128,
- "BLOCK_N": 128,
- "waves_per_eu": 2,
- "PRE_LOAD_V": False,
- },
- num_stages=1,
- num_warps=4,
- ),
+ # triton.Config(
+ # {
+ # "BLOCK_M": 128,
+ # "BLOCK_N": 128,
+ # "waves_per_eu": 2,
+ # "PRE_LOAD_V": False,
+ # },
+ # num_stages=1,
+ # num_warps=4,
+ # ),
triton.Config(
{
"BLOCK_M": 256,
from vllm.
The upstreamed Triton supports Navi3, but attention performance is slow.
from vllm.
Alright, I tried with stable Triton and the ROCm Triton fork. My patch only helped the official nightly run without hanging.
pip uninstall pytorch-triton-rocm -y; pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps
There might be more configs that need to be disabled to run stable Triton? A person could maybe just disable every config with a block dim ≥ 128 and it'd probably work everywhere. I think Navi favors the small ones anyway?
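A hedged sketch of that pruning idea, with plain dicts standing in for triton.Config objects since the point is just the filter (the function name is mine):

```python
# Keep only autotune configs whose tile dims stay below the size that
# blew the stack on gfx1100 (128x128 and up).
def prune_large_tiles(configs, limit=128):
    return [
        c for c in configs
        if c["BLOCK_M"] < limit and c["BLOCK_N"] < limit
    ]

configs = [
    {"BLOCK_M": 64,  "BLOCK_N": 64},
    {"BLOCK_M": 128, "BLOCK_N": 128},  # the offending config above
    {"BLOCK_M": 256, "BLOCK_N": 64},
]
print(prune_large_tiles(configs))  # only the 64x64 config survives
```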
I also found triton is indeed much faster than naive once you stack the context.
from vllm.