[ROCm] Enabling forward_includes_kv_cache on ROCm MHA backends #33106
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Code Review
This pull request refactors the KV cache update logic across RocmAiterUnifiedAttention, RocmAttention, and TritonAttention backends. The forward_includes_kv_cache: bool = False attribute is added to each backend, and the KV cache update operations are moved into a new do_kv_cache_update method. This improves modularity and clarifies the responsibilities of the forward method, which now focuses solely on attention computation, delegating cache updates to the dedicated method. The changes are consistent across all modified files and align with the stated objective of enabling forward_includes_kv_cache on ROCm MHA backends.
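The split described in this summary can be sketched roughly as follows. This is a toy illustration, not the vLLM code: `ToyAttentionImpl` and `run_layer` are hypothetical names, plain lists stand in for tensors, and only `forward_includes_kv_cache` and `do_kv_cache_update` are taken from the PR.

```python
# Illustrative sketch only: ToyAttentionImpl and run_layer are hypothetical,
# not vLLM classes. Plain Python lists stand in for tensors.


class ToyAttentionImpl:
    # False means forward() does not write the KV cache itself, so the
    # caller has to invoke do_kv_cache_update() before calling forward().
    forward_includes_kv_cache: bool = False

    def do_kv_cache_update(self, key, value, kv_cache, slot_mapping):
        # Write each new token's key/value pair into its assigned cache slot.
        for k, v, slot in zip(key, value, slot_mapping):
            kv_cache[slot] = (k, v)

    def forward(self, query, kv_cache):
        # Stand-in for attention: reads the cache but never writes to it.
        filled = [entry for entry in kv_cache if entry is not None]
        return [(q, len(filled)) for q in query]


def run_layer(impl, query, key, value, kv_cache, slot_mapping):
    # Hypothetical caller: update the cache separately when forward() won't.
    if not impl.forward_includes_kv_cache:
        impl.do_kv_cache_update(key, value, kv_cache, slot_mapping)
    return impl.forward(query, kv_cache)


cache = [None] * 4
out = run_layer(
    ToyAttentionImpl(), ["q0", "q1"], ["k0", "k1"], ["v0", "v1"], cache, [2, 0]
)
```

Keeping the cache write in its own method is what would let a caller replace or fuse that step without touching the attention computation.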
Cursor Bugbot has reviewed your changes and found 2 potential issues.
    self.kv_cache_dtype,
    layer._k_scale,
    layer._v_scale,
)
Missing encoder attention check in do_kv_cache_update
Medium Severity
TritonAttentionImpl.do_kv_cache_update() is missing the early return for encoder attention that exists in FlashAttentionImpl.do_kv_cache_update(). Since TritonAttentionBackend.supports_attn_type() returns True for ENCODER and ENCODER_ONLY, this method could be called for encoder attention. Without the check if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): return, it would attempt to access and write to the KV cache for encoder attention, which doesn't use caching.
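A minimal sketch of the guard this comment asks for, assuming a simplified method signature. `ToyTritonImpl` and the local `AttentionType` enum are hypothetical stand-ins, not the real vLLM classes; only the `if self.attn_type in (...)` check is taken from the comment.

```python
from enum import Enum


class AttentionType(Enum):
    # Hypothetical stand-in for the real AttentionType; values are illustrative.
    DECODER = "decoder"
    ENCODER = "encoder"
    ENCODER_ONLY = "encoder_only"


class ToyTritonImpl:
    # Hypothetical class; the signature below is an assumption for this sketch.
    def __init__(self, attn_type):
        self.attn_type = attn_type

    def do_kv_cache_update(self, key, value, kv_cache, slot_mapping):
        # Encoder attention does not use a KV cache, so return before
        # touching kv_cache (which may not exist for these layers).
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            return
        # Decoder path: store each token's key/value in its assigned slot.
        for k, v, slot in zip(key, value, slot_mapping):
            kv_cache[slot] = (k, v)


# Encoder layer: no cache is allocated, and the call is a no-op.
ToyTritonImpl(AttentionType.ENCODER_ONLY).do_kv_cache_update(["k"], ["v"], None, [0])

# Decoder layer: the cache slot is written.
decoder_cache = [None, None]
ToyTritonImpl(AttentionType.DECODER).do_kv_cache_update(["k"], ["v"], decoder_cache, [1])
```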
This pull request has merge conflicts that must be resolved before it can be
tjtanaa left a comment
LGTM. Thanks for adopting the latest abstraction.
@gshtras can you help rebase this PR?
@gshtras can we get some perf & eval numbers?
…forward_includes_kv_cache Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
The main perf improvement is expected from the RoPE + KV cache fusion PR that will follow. Here are some numbers for this PR:
[Table: Llama-70, main vs. PR, for ROCM_ATTN, TRITON_ATTN, and AITER_UNIFIED_ATTN]
[Table: Llama-70-FP8-KV, main vs. PR, for ROCM_ATTN, TRITON_ATTN, and ROCM_AITER_UNIFIED_ATTN]
num_actual_tokens = attn_metadata.num_actual_tokens

# Handle encoder attention differently - no KV cache needed
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
@gshtras this check is missing in do_kv_cache_update, causing a failure for encoder models using TRITON_ATTN.
…project#33106) Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
…project#33106) Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Pai <416932041@qq.com>
Adds support for the following backends for #32335
rocm_attn is also covered in #32543 from what I can tell