
Comments (7)

farshadghodsian commented on June 18, 2024

After taking a deep look into the code and testing Flash Attention support on AMD GPUs here is what I found:

AMD Instinct GPUs, gfx90a and gfx942 (MI210, MI250, MI300), support Flash Attention by way of specially written Composable Kernel (CK) libraries. Although I haven't tested this myself, it is reported to work, and there are published performance numbers showing the 2-3x speedup vLLM gets from CK Flash Attention.

Radeon RDNA3 GPUs, the 7900 XTX and W7900 (gfx1100), lack the necessary Composable Kernel libraries to use the above-mentioned Flash Attention mechanism, so the engineers at AMD opted to have these GPUs use an implementation of Flash Attention written in OpenAI's Triton. This Triton Flash Attention is supposed to be working, but all tests I've done (using various branches and docker builds) with VLLM_USE_TRITON_FLASH_ATTN=1 hit the same "stack frame size exceeds limit" issue during the Triton JIT compile at runtime. I am sure the compile is not failing due to system resources, as I have tested this using the Radeon Pro W7900 on two powerful systems, a Ryzen 9 7950X w/ 64GB of RAM and a Threadripper Pro 5975WX w/ 128GB of RAM; in both cases the Triton compile takes a really long time (upwards of several hours) and still fails with the same stack frame size error (see screenshot).

[Screenshot from 2024-05-18 06-36-37: Triton JIT compile failing with the "stack frame size exceeds limit" error]

Flash Attention forward-pass support for RDNA3 was added thanks to howiejay; however, this implementation no longer works in my testing, as it fails to run the hipify_python patch and build on newer versions of PyTorch+ROCm (tried on rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1 and rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1).

In summary, the only way to get vLLM working on Radeon and Radeon Pro graphics cards at the moment seems to be to build without CK Flash Attention support (BUILD_FA="0") and disable the Triton Flash Attention implementation (VLLM_USE_TRITON_FLASH_ATTN=0). This gets vLLM running, but you do not get any of the speedups vLLM is known for, and in my testing inference with vLLM is the same or slower than llama.cpp and Ollama.

The vLLM repos I've already tried are:
https://github.com/vllm-project/vllm (main branch)
https://github.com/ROCm/vllm (main, bf16_temp_fix_navi, TunableOp_Integration_ROCm6.0 branches)
https://github.com/hongxiayang/vllm (main branch)

Commands used to run the vLLM docker image and server were as follows (I tried a few other variations of the commands below, like changing --shm-size or --max-model-len, with no luck):

# Run vllm-rocm Docker image 
docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G --name vllm-rocm -v /home/${USER}/Downloads/models:/app/model \
vllm-rocm bash

# Run vllm api server
VLLM_USE_TRITON_FLASH_ATTN=1 CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server --max-model-len 3072 --download-dir /app/model --quantization=gptq --tensor-parallel-size=1 --enforce-eager --trust-remote-code --dtype=auto --kv-cache-dtype=auto --quantization-param-path=None --device=cuda --block-size=16 --model TechxGenus/Meta-Llama-3-70B-Instruct-GPTQ
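
For completeness, a minimal client check against the server started above might look like this (a sketch, assuming vLLM's default port 8000 and the same model name passed to --model):

# Query the OpenAI-compatible completions endpoint to confirm the server responds
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "TechxGenus/Meta-Llama-3-70B-Instruct-GPTQ",
        "prompt": "Hello, my name is",
        "max_tokens": 32,
    },
)
print(resp.json()["choices"][0]["text"])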

I'm asking that the engineers at AMD look into this and assist in troubleshooting/getting this working for Radeon GPUs (Navi3).


lhl commented on June 18, 2024

That won't work, I think. There's a related Flash Attention discussion on gfx1100 here: ROCm/aotriton#16, although according to this, Navi support was upstreamed last month and the appropriate place to file any navi31 Triton issues is the main repo: https://github.com/openai/triton

(The vLLM bug atm is just that it's not checking for gfx1100 correctly, it shouldn't be trying to use the Triton FA at all?)
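
For reference, the kind of guard being described might look roughly like this (a sketch, not vLLM's actual code; it assumes a ROCm build of PyTorch, which exposes the architecture string as gcnArchName in the device properties):

# Decide whether the Triton FA path should even be attempted, based on the GPU architecture
import torch

arch = torch.cuda.get_device_properties(0).gcnArchName  # e.g. "gfx1100" on a 7900 XTX / W7900
use_triton_flash_attn = not arch.startswith("gfx11")  # skip Triton FA on Navi3 by default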


DhruvDh commented on June 18, 2024

if possible, can you try building triton from source?


Beinsezii commented on June 18, 2024

The howiejay branch should build fine on the latest torch stable running ROCm 6. I have py3.11 and py3.12 wheels built against gfx1100 and ROCm 6.0 here. All I run is

pip wheel git+https://github.com/ROCm/flash-attention@howiejay/navi_support --no-deps

in my virtualenvs to produce the wheels.
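
A quick way to sanity-check the resulting wheel would be something like the following (a sketch, assuming the branch exposes the standard flash_attn_func forward interface, which is what the forward-pass support amounts to):

# Fixed-length forward pass through the howiejay/navi_support build
import torch
from flash_attn import flash_attn_func

# q, k, v are (batch, seqlen, nheads, headdim), fp16 on the GPU
q, k, v = (torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda") for _ in range(3))
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([1, 128, 8, 64])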

So I built vLLM with defaults then set VLLM_USE_TRITON_FLASH_ATTN=0 at runtime. On unquantized Llama3 8B I peaked at something like 1550 T/S with BS=96 and 0.95 memory allocation on a 7900 XTX 24G. 400 token response with a few hundred in context. Seems okay-ish?

Update: There's an internal gate against using the CK FA for Navi even if it's installed, because there's no varlen_fwd() support. You can build and install the howiejay flash-attn fine, but it seems to only be useful for diffusion models atm.
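
For context, varlen_fwd() is the ragged-batch kernel that the Python-level flash_attn_varlen_func wraps, and it's what the gate checks for; a call like the sketch below (upstream flash-attn API, shown for illustration) is what the Navi build can't serve, while the fixed-length flash_attn_func above works:

# Packed/ragged batch: two sequences of lengths 5 and 11 concatenated into 16 tokens
import torch
from flash_attn import flash_attn_varlen_func

q = k = v = torch.randn(16, 8, 64, dtype=torch.float16, device="cuda")
cu_seqlens = torch.tensor([0, 5, 16], dtype=torch.int32, device="cuda")
out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
                             max_seqlen_q=11, max_seqlen_k=11, causal=True)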

Additionally, I built ROCm/triton from source as of an hour ago and it still just sits pegging one thread for a small eternity before eventually being killed for blowing up the stack. I guess a person could try to increase the stack size, but I really feel like something's not working...


Beinsezii commented on June 18, 2024

I think I narrowed it down to this autotune config:
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, num_warps=4)

With that config disabled, I can run without VLLM_USE_TRITON_FLASH_ATTN=0. I'm using the triton nightly as of an hour ago to make sure it has any possible Navi fixes. Though if anything it feels slower? I'll try stable triton in a bit.

Patch on top of v0.4.2 if someone else wants to play with it.

diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index 11476641..d5f6bbec 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -219,16 +219,16 @@ def _attn_fwd_inner(
             num_stages=1,
             num_warps=8,
         ),
-        triton.Config(
-            {
-                "BLOCK_M": 128,
-                "BLOCK_N": 128,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
-            },
-            num_stages=1,
-            num_warps=4,
-        ),
+        # triton.Config(
+        #     {
+        #         "BLOCK_M": 128,
+        #         "BLOCK_N": 128,
+        #         "waves_per_eu": 2,
+        #         "PRE_LOAD_V": False,
+        #     },
+        #     num_stages=1,
+        #     num_warps=4,
+        # ),
         triton.Config(
             {
                 "BLOCK_M": 256,


sdli1995 commented on June 18, 2024

> The howiejay branch should build fine on the latest torch stable running ROCm 6. [...] I built ROCm/triton from source as of an hour ago and it still just sits pegging one thread for a small eternity before eventually being killed for blowing up the stack. [...]

The upstream Triton now supports Navi3, but attention performance is slow.


Beinsezii commented on June 18, 2024

Alright, I tried with stable triton and the ROCm triton fork. My patch only helped the official nightly run without hanging.

pip uninstall pytorch-triton-rocm -y; pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps

There might be more configs that need to be disabled to run stable triton? A person could maybe just disable every config with a block dim ≥ 128 and it'd probably work everywhere. I think Navi favors the small ones anyways?
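
In code, that suggestion would amount to something like this (a sketch; configs stands in for the list passed to @triton.autotune in vllm/attention/ops/triton_flash_attention.py, and the two entries shown are just illustrative):

import triton

# Illustrative stand-ins for the autotune list in triton_flash_attention.py
configs = [
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 4, "PRE_LOAD_V": False},
                  num_stages=1, num_warps=8),
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False},
                  num_stages=1, num_warps=4),
]

# Drop every config with a block dim >= 128; triton.Config keeps its meta-parameters in .kwargs
configs = [c for c in configs
           if c.kwargs["BLOCK_M"] < 128 and c.kwargs["BLOCK_N"] < 128]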

I also found triton is indeed much faster than the naive attention path once you stack up the context.

