
[ROCm] Enabling forward_includes_kv_cache on ROCm MHA backends#33106

Merged
tjtanaa merged 4 commits into vllm-project:main from ROCm:rocm_attn_backends_forward_includes_kv_cache
Jan 28, 2026

@gshtras

@gshtras gshtras commented Jan 26, 2026

Collaborator

Adds support for #32335 to the following backends:

  • RocmAiterUnifiedAttention
  • RocmAttention
  • TritonAttention

rocm_attn is also covered in #32543 from what I can tell

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
@gshtras gshtras added the rocm Related to AMD ROCm label Jan 26, 2026
@mergify mergify Bot added the v1 label Jan 26, 2026
@gshtras gshtras marked this pull request as draft January 26, 2026 17:34

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request refactors the KV cache update logic across RocmAiterUnifiedAttention, RocmAttention, and TritonAttention backends. The forward_includes_kv_cache: bool = False attribute is added to each backend, and the KV cache update operations are moved into a new do_kv_cache_update method. This improves modularity and clarifies the responsibilities of the forward method, which now focuses solely on attention computation, delegating cache updates to the dedicated method. The changes are consistent across all modified files and align with the stated objective of enabling forward_includes_kv_cache on ROCm MHA backends.
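A minimal sketch of the split described in this review, assuming a toy backend rather than the real vLLM classes. `ToyBackendImpl`, `run_layer`, and the dict-based cache are hypothetical names introduced only for illustration; only `forward_includes_kv_cache` and `do_kv_cache_update` come from the PR discussion, and their real signatures may differ.

```python
from dataclasses import dataclass, field


@dataclass
class ToyBackendImpl:
    """Hypothetical stand-in for an attention impl; not vLLM's real class."""

    # False means forward() does not write the KV cache itself, so the
    # caller is expected to invoke do_kv_cache_update() beforehand.
    forward_includes_kv_cache: bool = False
    calls: list = field(default_factory=list)

    def do_kv_cache_update(self, key, value, kv_cache, slot_mapping):
        # Write each new (key, value) pair into its assigned cache slot.
        for slot, k, v in zip(slot_mapping, key, value):
            kv_cache[slot] = (k, v)
        self.calls.append("do_kv_cache_update")

    def forward(self, query, key, value, kv_cache):
        # Placeholder for the attention computation: it only reads the
        # cache and reports how many entries each query could attend to.
        self.calls.append("forward")
        return [len(kv_cache) for _ in query]


def run_layer(impl, query, key, value, kv_cache, slot_mapping):
    """Hypothetical caller: update the cache first unless forward() does it."""
    if not impl.forward_includes_kv_cache:
        impl.do_kv_cache_update(key, value, kv_cache, slot_mapping)
    return impl.forward(query, key, value, kv_cache)


impl = ToyBackendImpl()
cache = {}
out = run_layer(impl, ["q0", "q1"], ["k0", "k1"], ["v0", "v1"], cache, [5, 9])
print(impl.calls, cache, out)
```

Keeping the cache write in its own method is what would let a later change fuse it with another operation without touching the attention computation.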

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 2 potential issues.

Comment thread vllm/v1/attention/backends/rocm_aiter_unified_attn.py Outdated
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)

Missing encoder attention check in do_kv_cache_update

Medium Severity

TritonAttentionImpl.do_kv_cache_update() is missing the early return for encoder attention that exists in FlashAttentionImpl.do_kv_cache_update(). Since TritonAttentionBackend.supports_attn_type() returns True for ENCODER and ENCODER_ONLY, this method could be called for encoder attention. Without the check if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): return, it would attempt to access and write to the KV cache for encoder attention, which doesn't use caching.
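The guard described in this comment could look roughly like the sketch below. It is a toy under stated assumptions: `ToyTritonImpl` and the dict-based cache are hypothetical, and the `AttentionType` enum here is a local mirror of the names quoted above, not an import from vLLM.

```python
from enum import Enum


class AttentionType(Enum):
    """Local, hypothetical mirror of the attention types named in the review."""

    DECODER = "decoder"
    ENCODER = "encoder"
    ENCODER_ONLY = "encoder_only"


class ToyTritonImpl:
    """Hypothetical impl showing only the early-return guard."""

    def __init__(self, attn_type):
        self.attn_type = attn_type

    def do_kv_cache_update(self, key, value, kv_cache, slot_mapping):
        # Encoder attention does not use a KV cache, so return before
        # touching kv_cache (which may not even be allocated).
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            return
        for slot, k, v in zip(slot_mapping, key, value):
            kv_cache[slot] = (k, v)


# Encoder path: kv_cache is None here and must never be dereferenced.
ToyTritonImpl(AttentionType.ENCODER_ONLY).do_kv_cache_update(["k"], ["v"], None, [0])

# Decoder path: the cache is written as usual.
decoder_cache = {}
ToyTritonImpl(AttentionType.DECODER).do_kv_cache_update(["k"], ["v"], decoder_cache, [0])
print(decoder_cache)
```

Without the early return, the encoder call above would fail when it tries to index into the missing cache.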

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
@gshtras gshtras marked this pull request as ready for review January 26, 2026 18:05
@mergify

mergify Bot commented Jan 27, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gshtras.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Jan 27, 2026

@tjtanaa tjtanaa left a comment

Member

LGTM. Thanks for adopting the latest abstraction.

@tjtanaa tjtanaa added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 27, 2026
@tjtanaa

tjtanaa commented Jan 27, 2026

Member

@gshtras can you help to rebase this PR?

@ProExpertProg

Collaborator

@gshtras can we get some perf & eval numbers?

…forward_includes_kv_cache

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
@mergify mergify Bot removed the needs-rebase label Jan 27, 2026
Comment thread vllm/v1/attention/backends/rocm_aiter_unified_attn.py Outdated
Comment thread vllm/v1/attention/backends/rocm_attn.py Outdated
Comment thread vllm/v1/attention/backends/triton_attn.py Outdated
@gshtras

gshtras commented Jan 27, 2026

Collaborator Author

> @gshtras can we get some perf & eval numbers?

The main perf improvement is expected from the RoPE + KV cache fusion PR that will follow. Here are some numbers for this PR, collected with the following commands:

vllm serve meta_llama/Llama-3.1-70B-Instruct -tp 8 --no-enable-prefix-caching
and
vllm serve amd/Llama-3.1-70B-Instruct-FP8-KV --kv-cache-dtype fp8 -tp 8 --no-enable-prefix-caching

vllm bench serve --model $model --percentile-metrics ttft,tpot,itl,e2el --dataset-name sharegpt --dataset-path /models/ShareGPT_Vicuna_unfiltered/ShareGPT_V3_unfiltered_cleaned_split.json --sharegpt-output-len 1024 --max-concurrency 4 --num-prompts 40 --ignore-eos

lm_eval --model local-completions --model_args base_url=http://localhost:8000/v1/completions,pretrained=${model},add_bos_token=true,trust_remote_code=true --tasks gsm8k --num_fewshot 5 --batch_size 64

Llama-70

ROCM_ATTN

main

============ Serving Benchmark Result ============
----------------End-to-end Latency----------------
Mean E2EL (ms):                          11943.43  
Median E2EL (ms):                        11936.98  
P99 E2EL (ms):                           12036.01  
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9287|±  |0.0071|
|     |       |strict-match    |     5|exact_match|↑  |0.8711|±  |0.0092|

PR

============ Serving Benchmark Result ============
----------------End-to-end Latency----------------
Mean E2EL (ms):                          12142.87  
Median E2EL (ms):                        12147.02  
P99 E2EL (ms):                           12246.63  
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9249|±  |0.0073|
|     |       |strict-match    |     5|exact_match|↑  |0.8658|±  |0.0094|

TRITON_ATTN

main

============ Serving Benchmark Result ============
----------------End-to-end Latency----------------
Mean E2EL (ms):                          12865.37  
Median E2EL (ms):                        12874.88  
P99 E2EL (ms):                           13029.51  
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9303|±  |0.0070|
|     |       |strict-match    |     5|exact_match|↑  |0.8749|±  |0.0091|

PR

============ Serving Benchmark Result ============
----------------End-to-end Latency----------------
Mean E2EL (ms):                          12574.40  
Median E2EL (ms):                        12571.19  
P99 E2EL (ms):                           12782.47  
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9234|±  |0.0073|
|     |       |strict-match    |     5|exact_match|↑  |0.8643|±  |0.0094|

ROCM_AITER_UNIFIED_ATTN

main

============ Serving Benchmark Result ============
----------------End-to-end Latency----------------
Mean E2EL (ms):                          12457.69  
Median E2EL (ms):                        12423.01  
P99 E2EL (ms):                           12725.19  
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9257|±  |0.0072|
|     |       |strict-match    |     5|exact_match|↑  |0.8605|±  |0.0095|

PR

============ Serving Benchmark Result ============
----------------End-to-end Latency----------------
Mean E2EL (ms):                          11926.23  
Median E2EL (ms):                        11823.47  
P99 E2EL (ms):                           12821.11  
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9249|±  |0.0073|
|     |       |strict-match    |     5|exact_match|↑  |0.8590|±  |0.0096|

Llama-70-FP8-KV

ROCM_ATTN

main

============ Serving Benchmark Result ============
----------------End-to-end Latency----------------
Mean E2EL (ms):                          11273.76  
Median E2EL (ms):                        11271.16  
P99 E2EL (ms):                           11335.56  
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9310|±  | 0.007|
|     |       |strict-match    |     5|exact_match|↑  |0.8431|±  | 0.010|

PR

============ Serving Benchmark Result ============
----------------End-to-end Latency----------------
Mean E2EL (ms):                          11484.81  
Median E2EL (ms):                        11297.84  
P99 E2EL (ms):                           13226.65  
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9242|±  |0.0073|
|     |       |strict-match    |     5|exact_match|↑  |0.8476|±  |0.0099|

TRITON_ATTN

main

============ Serving Benchmark Result ============
----------------End-to-end Latency----------------
Mean E2EL (ms):                          11294.83  
Median E2EL (ms):                        11319.47  
P99 E2EL (ms):                           11459.19  
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9052|±  |0.0081|
|     |       |strict-match    |     5|exact_match|↑  |0.8302|±  |0.0103|

PR

============ Serving Benchmark Result ============
----------------End-to-end Latency----------------
Mean E2EL (ms):                          11120.36  
Median E2EL (ms):                        11146.96  
P99 E2EL (ms):                           11269.46  
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9083|±  |0.0080|
|     |       |strict-match    |     5|exact_match|↑  |0.8347|±  |0.0102|

ROCM_AITER_UNIFIED_ATTN

main

============ Serving Benchmark Result ============
----------------End-to-end Latency----------------
Mean E2EL (ms):                          11097.16  
Median E2EL (ms):                        11032.52  
P99 E2EL (ms):                           12012.72  
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9272|±  |0.0072|
|     |       |strict-match    |     5|exact_match|↑  |0.8491|±  |0.0099|

PR

============ Serving Benchmark Result ============
----------------End-to-end Latency----------------
Mean E2EL (ms):                          11719.97  
Median E2EL (ms):                        11632.63  
P99 E2EL (ms):                           12428.84  
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9265|±  |0.0072|
|     |       |strict-match    |     5|exact_match|↑  |0.8552|±  |0.0097|
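To make the main-versus-PR numbers above easier to compare, here is a small sketch that computes the relative change in mean E2EL and a rough z-score for the GSM8K flexible-extract scores. The figures are copied from the results above; the helper names are hypothetical, and treating the two runs as independent samples is an assumption, so the z-score is a rough sanity check rather than a rigorous test.

```python
import math

# (model, backend) -> (main mean E2EL in ms, PR mean E2EL in ms)
MEAN_E2EL_MS = {
    ("Llama-70", "ROCM_ATTN"): (11943.43, 12142.87),
    ("Llama-70", "TRITON_ATTN"): (12865.37, 12574.40),
    ("Llama-70", "ROCM_AITER_UNIFIED_ATTN"): (12457.69, 11926.23),
    ("Llama-70-FP8-KV", "ROCM_ATTN"): (11273.76, 11484.81),
    ("Llama-70-FP8-KV", "TRITON_ATTN"): (11294.83, 11120.36),
    ("Llama-70-FP8-KV", "ROCM_AITER_UNIFIED_ATTN"): (11097.16, 11719.97),
}

# (model, backend) -> ((main value, main stderr), (PR value, PR stderr))
# for the gsm8k flexible-extract exact_match metric.
GSM8K_FLEX = {
    ("Llama-70", "ROCM_ATTN"): ((0.9287, 0.0071), (0.9249, 0.0073)),
    ("Llama-70", "TRITON_ATTN"): ((0.9303, 0.0070), (0.9234, 0.0073)),
    ("Llama-70", "ROCM_AITER_UNIFIED_ATTN"): ((0.9257, 0.0072), (0.9249, 0.0073)),
    ("Llama-70-FP8-KV", "ROCM_ATTN"): ((0.9310, 0.0070), (0.9242, 0.0073)),
    ("Llama-70-FP8-KV", "TRITON_ATTN"): ((0.9052, 0.0081), (0.9083, 0.0080)),
    ("Llama-70-FP8-KV", "ROCM_AITER_UNIFIED_ATTN"): ((0.9272, 0.0072), (0.9265, 0.0072)),
}


def pct_change(main, pr):
    """Relative change of the PR value versus main, in percent."""
    return (pr - main) / main * 100.0


def z_score(main, pr):
    """Score difference divided by the combined standard error.

    Assumes the two runs are independent, which is a simplification.
    """
    (main_val, main_se), (pr_val, pr_se) = main, pr
    return (pr_val - main_val) / math.hypot(main_se, pr_se)


for key, (main_ms, pr_ms) in MEAN_E2EL_MS.items():
    model, backend = key
    z = z_score(*GSM8K_FLEX[key])
    print(f"{model:16s} {backend:24s} "
          f"E2EL {pct_change(main_ms, pr_ms):+6.2f}%  gsm8k z={z:+.2f}")
```

On these single runs, mean E2EL moves by roughly -4% to +6% depending on model and backend, and every GSM8K flexible-extract difference is smaller than one combined standard error.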

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
@tjtanaa tjtanaa merged commit 22ad649 into vllm-project:main Jan 28, 2026
51 checks passed
@gshtras gshtras deleted the rocm_attn_backends_forward_includes_kv_cache branch January 28, 2026 15:36
num_actual_tokens = attn_metadata.num_actual_tokens

# Handle encoder attention differently - no KV cache needed
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):

Collaborator

@gshtras this check is missing in do_kv_cache_update, causing a failure for encoder models using TRITON_ATTN

apd10 pushed a commit to apd10/vllm that referenced this pull request Jan 31, 2026
…project#33106)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
PiratePai pushed a commit to PiratePai/epd_shm that referenced this pull request Feb 3, 2026
…project#33106)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Pai <416932041@qq.com>
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
…project#33106)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
…project#33106)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
0826joyce pushed a commit to 0826joyce/vllm-serving-optimization that referenced this pull request May 19, 2026
…project#33106)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Labels

ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm v1

3 participants