Comments (6)

sfc-gh-zhwang commented on June 11, 2024

@mgoin it's only tp=8 that doesn't work.

from vllm.

mgoin commented on June 11, 2024

Hi @sfc-gh-zhwang, FWIW I was able to run this with TP=2 on 2xA6000 using vllm==0.4.2.

from vllm import LLM

llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    enable_lora=True,
    tensor_parallel_size=2,
)

print(llm.generate("Hello"))

Output:

2024-05-22 06:43:47,998 INFO worker.py:1749 -- Started a local Ray instance.
INFO 05-22 06:43:49 llm_engine.py:100] Initializing an LLM engine (v0.4.2) with config: model='mistralai/Mistral-7B-Instruct-v0.2', speculative_config=None, tokenizer='mistralai/Mistral-7B-Instruct-v0.2', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=mistralai/Mistral-7B-Instruct-v0.2)
INFO 05-22 06:43:53 utils.py:660] Found nccl from library /root/.config/vllm/nccl/cu12/libnccl.so.2.18.1
(RayWorkerWrapper pid=523946) INFO 05-22 06:43:53 utils.py:660] Found nccl from library /root/.config/vllm/nccl/cu12/libnccl.so.2.18.1
INFO 05-22 06:43:54 selector.py:81] Cannot use FlashAttention-2 backend because the flash_attn package is not found. Please install it for better performance.
INFO 05-22 06:43:54 selector.py:32] Using XFormers backend.
(RayWorkerWrapper pid=523946) INFO 05-22 06:43:54 selector.py:81] Cannot use FlashAttention-2 backend because the flash_attn package is not found. Please install it for better performance.
(RayWorkerWrapper pid=523946) INFO 05-22 06:43:54 selector.py:32] Using XFormers backend.
INFO 05-22 06:43:56 pynccl_utils.py:43] vLLM is using nccl==2.18.1
(RayWorkerWrapper pid=523946) INFO 05-22 06:43:56 pynccl_utils.py:43] vLLM is using nccl==2.18.1
INFO 05-22 06:43:57 utils.py:132] reading GPU P2P access cache from /root/.config/vllm/gpu_p2p_access_cache_for_0,1.json
(RayWorkerWrapper pid=523946) INFO 05-22 06:43:57 utils.py:132] reading GPU P2P access cache from /root/.config/vllm/gpu_p2p_access_cache_for_0,1.json
INFO 05-22 06:43:58 weight_utils.py:199] Using model weights format ['*.safetensors']
(RayWorkerWrapper pid=523946) INFO 05-22 06:43:58 weight_utils.py:199] Using model weights format ['*.safetensors']
INFO 05-22 06:44:00 model_runner.py:175] Loading model weights took 6.7544 GB
(RayWorkerWrapper pid=523946) INFO 05-22 06:44:01 model_runner.py:175] Loading model weights took 6.7544 GB
INFO 05-22 06:44:07 distributed_gpu_executor.py:45] # GPU blocks: 33401, # CPU blocks: 4096
INFO 05-22 06:44:08 model_runner.py:937] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 05-22 06:44:08 model_runner.py:941] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(RayWorkerWrapper pid=523946) INFO 05-22 06:44:08 model_runner.py:937] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
(RayWorkerWrapper pid=523946) INFO 05-22 06:44:08 model_runner.py:941] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 05-22 06:44:13 custom_all_reduce.py:246] Registering 2275 cuda graph addresses
INFO 05-22 06:44:13 model_runner.py:1017] Graph capturing finished in 5 secs.
(RayWorkerWrapper pid=523946) INFO 05-22 06:44:13 custom_all_reduce.py:246] Registering 2275 cuda graph addresses
(RayWorkerWrapper pid=523946) INFO 05-22 06:44:13 model_runner.py:1017] Graph capturing finished in 5 secs.
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.56it/s]
[RequestOutput(request_id=0, prompt='Hello', prompt_token_ids=[1, 22557], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=", I'm quite new to Linux so I'm really sorry if this", token_ids=[28725, 315, 28742, 28719, 3448, 633, 298, 19486, 579, 315, 28742, 28719, 1528, 7371, 513, 456], cumulative_logprob=-27.695576645433903, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1716360253.5880315, last_token_time=1716360253.5880315, first_scheduled_time=1716360253.5906928, first_token_time=1716360253.6384084, time_in_queue=0.0026612281799316406, finished_time=1716360253.8712075), lora_request=None)]

sfc-gh-zhwang commented on June 11, 2024

@FurtherAI in case you have some idea 😃

sfc-gh-zhwang commented on June 11, 2024

I narrowed it down further to this line:

punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)

where, for tp=8 (which errors out), the tensor sizes are:

buffer: [32768, 16]
x: [32768, 512]
wa_t_all: [1, 1, 16, 512]

while for tp=4 (which works), the tensor sizes are:

buffer: [32768, 16]
x: [32768, 1024]
wa_t_all: [1, 1, 16, 1024]

I'm still trying to figure out what is special about going from 1024 to 512.
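The two widths themselves are easy to account for if these shapes come from a layer whose input dimension is split evenly across tensor-parallel ranks. A minimal sketch, assuming a full input width of 4096 (the hidden size of Mistral-7B, which may not be the model in the failing run); shard_width is a hypothetical helper, not a vLLM function:

```python
# Assumption: the layer's full input width is 4096 (Mistral-7B's hidden size)
# and tensor parallelism splits it evenly across ranks.
FULL_INPUT_WIDTH = 4096


def shard_width(full_width: int, tp: int) -> int:
    """Hypothetical helper: per-rank feature width under an even split."""
    if full_width % tp != 0:
        raise ValueError(f"{full_width} is not divisible by tp={tp}")
    return full_width // tp


for tp in (4, 8):
    print(f"tp={tp}: x is [seq_len, {shard_width(FULL_INPUT_WIDTH, tp)}]")
```

Under that assumption, tp=8 is simply the first setting where the per-rank width drops to 512, which would explain why smaller tp values are unaffected.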

FurtherAI commented on June 11, 2024

I tracked it a little further, and it seems to be due to the sequence length. I'm not sure why: from a brief glance, the kernel shouldn't care about the sequence length. In my tests, 65536 worked, 32768 and 16384 did not, and 8192 and 4096 worked; I didn't test more. So for now, @sfc-gh-zhwang, run with a different sequence length.

Here's some code to reproduce:

import torch
import vllm._punica_C as punica_kernels

# Shapes match the failing tp=8 case: 32768 rows, LoRA rank 16, 512 input features.
seq_length, rank = 32768, 16
buffer = torch.randn((seq_length, rank), device='cuda', dtype=torch.float32)
x = torch.randn((seq_length, 512), device='cuda', dtype=torch.bfloat16)
wa_t_all = torch.randn((1, 1, rank, 512), device='cuda', dtype=torch.bfloat16)
indicies = torch.full((seq_length,), 1, device='cuda', dtype=torch.int64)
punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, 0, 1.0)

# Kernel launches are asynchronous; synchronizing surfaces any CUDA error here.
torch.cuda.synchronize()
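For comparison, here is a pure-Python sketch of what I understand the call to compute. The name bgmv_reference is hypothetical, and the semantics (y[i] += scale * W[indices[i], layer_idx] @ x[i]) are an assumption inferred from the argument shapes, not copied from the vLLM source. It does not handle any special index values.

```python
def bgmv_reference(y, x, w, indices, layer_idx, scale):
    """Assumed semantics: y[i] += scale * (w[indices[i]][layer_idx] @ x[i]).

    y: [rows][feat_out], x: [rows][feat_in],
    w: [num_loras][num_layers][feat_out][feat_in], indices: [rows].
    Updates y in place and returns it.
    """
    for i, lora in enumerate(indices):
        weight = w[lora][layer_idx]  # [feat_out][feat_in] for this row's LoRA
        for j, weight_row in enumerate(weight):
            y[i][j] += scale * sum(a * b for a, b in zip(weight_row, x[i]))
    return y


# Tiny example: one row, feat_in=3, feat_out=2, a single LoRA and layer.
y = [[0.0, 0.0]]
x = [[1.0, 2.0, 3.0]]
w = [[[[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]]
print(bgmv_reference(y, x, w, [0], 0, 1.0))
```

If that reading is right, each row is processed independently, so the sequence length should only change how many rows there are, which is why the dependence on it is surprising.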

sfc-gh-zhwang commented on June 11, 2024

I think I found the root cause: the code here can index past the end of X for certain tensor shapes. I think the fix is to add a condition like if (threadIdx.y * tx * vec_size < feat_in).
But should we instead fold this into line 84 and just change for (tile_idx = 1; to for (tile_idx = 0;? ROCm is doing this anyway:

for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
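To see which threads the proposed condition would exclude, here is a small Python sketch of the offset arithmetic. The launch parameters tx=32, ty=4 and vec_size=8 are assumptions chosen for illustration, not values confirmed in this thread, and both function names are hypothetical.

```python
def first_tile_offsets(tx, ty, vec_size):
    """Start offset within a row of X for each threadIdx.y in the first tile."""
    return [y * tx * vec_size for y in range(ty)]


def rows_failing_guard(feat_in, tx=32, ty=4, vec_size=8):
    """threadIdx.y values that fail the guard threadIdx.y * tx * vec_size < feat_in."""
    offsets = first_tile_offsets(tx, ty, vec_size)
    return [y for y, offset in enumerate(offsets) if offset >= feat_in]


for feat_in in (512, 1024):
    print(f"feat_in={feat_in}: threadIdx.y failing the guard = {rows_failing_guard(feat_in)}")
```

With these assumed values one tile spans 32 * 4 * 8 = 1024 elements, so feat_in = 1024 fits exactly, while feat_in = 512 leaves threadIdx.y 2 and 3 starting beyond the end of the row. That would be consistent with tp=4 working and tp=8 failing.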
