[None][feat] Use new index api, add block scale support, fix max_seq_len estimation, add flash mla support #11334
Conversation
Force-pushed 7ea1b99 to e43e829.
📝 Walkthrough

The PR refactors KV cache offset computation by removing a conditional branching mechanism and replacing it with parameterized scaling via new parameters.

Changes

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (2 warnings)
Force-pushed e43e829 to 278254e.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.h (1)

Lines 1-3: ⚠️ Potential issue | 🟡 Minor — Update copyright year to reflect 2026 changes.

🛠️ Suggested update

```diff
- * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

As per coding guidelines: "All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification".
cpp/tensorrt_llm/batch_manager/kvCacheManagerV2Utils.cu (1)

Lines 253-290: ⚠️ Potential issue | 🟠 Major — Validate `indexScales`/`kvOffset` shapes before kernel launch.

The kernel indexes both arrays by `poolIdx`; mismatched shapes can cause OOB reads and memory corruption. Add shape/length checks alongside the existing tensor validations.

🛡️ Suggested shape validation

```diff
@@
-    auto const& srcShape = input.getShape();
-    auto const& dstShape = output.getShape();
-    auto const& copyIndexShape = copyIndex.getShape();
+    auto const& srcShape = input.getShape();
+    auto const& dstShape = output.getShape();
+    auto const& copyIndexShape = copyIndex.getShape();
+    auto const& indexScalesShape = indexScales.getShape();
+    auto const& kvOffsetShape = kvOffset.getShape();
@@
-    SizeType32 numBlocksPerSeq = srcShape.d[3];
-    SizeType32 numSeqs = copyIndexShape.d[0];
+    SizeType32 numBlocksPerSeq = srcShape.d[3];
+    SizeType32 numSeqs = copyIndexShape.d[0];
+    constexpr int32_t kExpectedVectorDim = 1;
+    TLLM_CHECK(indexScalesShape.nbDims == kExpectedVectorDim);
+    TLLM_CHECK(kvOffsetShape.nbDims == kExpectedVectorDim);
+    TLLM_CHECK_WITH_INFO(indexScalesShape.d[0] >= numPools,
+        "indexScales must have at least numPools=%d elements.", numPools);
+    TLLM_CHECK_WITH_INFO(kvOffsetShape.d[0] >= numPools,
+        "kvOffset must have at least numPools=%d elements.", numPools);
```
🤖 Fix all issues with AI agents
In `@tensorrt_llm/_torch/pyexecutor/resource_manager.py`:
- Line 1574: Remove the stray debug print in KVCacheManagerV2.__init__: replace
the call to print(config.layers) with either nothing (delete it) or a structured
logger call such as logger.debug("layers=%s", config.layers) so stdout is not
polluted; update the KVCacheManagerV2 constructor to use the logger (or remove)
and ensure any required logger is imported/available in the class.
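The stdout-pollution fix suggested above can be sketched with Python's standard `logging` module; the `Config` class below is a hypothetical stand-in for the real `KVCacheManagerV2` config object, which is not shown here:

```python
import logging

logger = logging.getLogger("resource_manager")


class Config:
    """Hypothetical stand-in for the config object mentioned in the review."""
    layers = [0, 1, 2, 3]


config = Config()

# Before: print(config.layers) wrote straight to stdout.
# After: route the information through the logger at DEBUG level,
# so it only appears when debug logging is explicitly enabled.
logger.debug("layers=%s", config.layers)
```

With the default WARNING level, the `logger.debug` call emits nothing, so ordinary runs keep a clean stdout.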
🧹 Nitpick comments (2)

tensorrt_llm/_torch/pyexecutor/resource_manager.py (2)

Lines 1839-1850: Consider vectorizing the per-element Python loop.

The list comprehension on lines 1842-1846 iterates per-element in Python, which can be slow for long sequences. A vectorized approach would be more efficient:

Proposed vectorized implementation

```diff
 def get_block_ids_per_seq(self, request_ids: List[int]) -> torch.Tensor:
     block_ids_per_seq = self.get_batch_cache_indices(request_ids)
     block_ids_per_seq_tensors = [
-        torch.tensor([
-            i // self.num_local_layers if i != BAD_PAGE_INDEX else i
-            for i in sublist
-        ],
-                     dtype=torch.int) for sublist in block_ids_per_seq
+        torch.where(
+            (t := torch.tensor(sublist, dtype=torch.int)) != BAD_PAGE_INDEX,
+            t // self.num_local_layers,
+            t,
+        ) for sublist in block_ids_per_seq
     ]
     padded_tensor = torch.nn.utils.rnn.pad_sequence(
         block_ids_per_seq_tensors, batch_first=True, padding_value=0)
     return padded_tensor
```

Based on learnings: "In files under tensorrt_llm/_torch/pyexecutor, avoid accessing torch.Tensor objects inside for-loops when iterating over requests. Convert batched tensors to Python lists beforehand using tensor.tolist(), and then iterate over those lists."
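For reference, the transformation in the suggested diff can be sketched without torch. `BAD_PAGE_INDEX` and `num_local_layers` are names taken from the review comment; the sentinel value `-1` is an assumption for illustration only:

```python
BAD_PAGE_INDEX = -1  # hypothetical sentinel value; the real constant lives in resource_manager.py


def block_ids_per_seq(seqs, num_local_layers):
    """Pure-Python reference for the transformation in the diff above:
    divide every valid page index by num_local_layers, keep the
    BAD_PAGE_INDEX sentinel untouched, then right-pad all rows with 0."""
    mapped = [
        [i // num_local_layers if i != BAD_PAGE_INDEX else i for i in seq]
        for seq in seqs
    ]
    width = max((len(row) for row in mapped), default=0)
    return [row + [0] * (width - len(row)) for row in mapped]


print(block_ids_per_seq([[4, 8, BAD_PAGE_INDEX], [12]], num_local_layers=4))
# [[1, 2, -1], [3, 0, 0]]
```

The `torch.where` version in the diff computes the same result batched on tensors, which avoids the per-element Python loop for long sequences.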
Lines 1546-1549: Consider extracting the `"nvfp4"` check to a local variable.

The string comparison `kv_cache_config.dtype == "nvfp4"` is repeated four times in `__init__`. Extracting it to a local boolean would improve readability and reduce the risk of typos.

Proposed refactor

Add near the top of `__init__`, after `self.dtype = dtype`:

```python
is_nvfp4 = kv_cache_config.dtype == "nvfp4"
```

Then replace all four occurrences with `is_nvfp4`.

Also applies to: lines 1595, 1612, 1633.
Force-pushed 278254e to 19bcac5.
/bot run --disable-fail-fast

PR_Github #35080 [ run ] triggered by Bot.
eopXD left a comment:

Looks good to me for the max_seq_len estimation.
/bot run --disable-fail-fast

PR_Github #35089 [ run ] triggered by Bot.

PR_Github #35089 [ run ] completed.

/bot run --disable-fail-fast

PR_Github #35174 [ run ] triggered by Bot.

PR_Github #35174 [ run ] completed.
Force-pushed 3ae9a75 to fbff29f.
/bot run --disable-fail-fast

PR_Github #35241 [ run ] triggered by Bot.

PR_Github #35241 [ run ] completed.
Force-pushed fbff29f to 3e8d0d6.
/bot run --disable-fail-fast

PR_Github #35254 [ run ] triggered by Bot.

PR_Github #35254 [ run ] completed.

/bot run --disable-fail-fast

PR_Github #35274 [ run ] triggered by Bot.
Force-pushed 59a1185 to dcda6be.
/bot run --disable-fail-fast
Force-pushed c4048db to 4f90a75.
Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Force-pushed 4f90a75 to d9bd46f.
/bot run --disable-fail-fast

PR_Github #35626 [ run ] triggered by Bot.

PR_Github #35626 [ run ] completed.
Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
/bot run --disable-fail-fast

/bot kill

/bot run --disable-fail-fast

PR_Github #35697 [ kill ] triggered by Bot.

PR_Github #35697 [ kill ] completed.

PR_Github #35698 [ run ] triggered by Bot.

PR_Github #35701 [ run ] triggered by Bot.

/bot run --disable-fail-fast

PR_Github #35755 [ run ] triggered by Bot.

PR_Github #35755 [ run ] completed.

/bot run --disable-fail-fast

PR_Github #35853 [ run ] triggered by Bot.

PR_Github #35853 [ run ] completed.

/bot run --disable-fail-fast

PR_Github #36000 [ run ] triggered by Bot.

PR_Github #36000 [ run ] completed.
…len estimation, add flash mla support (NVIDIA#11334)

Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
Signed-off-by: yizhang-nv <187001205+yizhang-nv@users.noreply.github.com>
Signed-off-by: peihu-nv <259410613+peihu-nv@users.noreply.github.com>
Summary by CodeRabbit
Release Notes
New Features
Improvements
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update the tava architecture diagram if there is a significant design change in the PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provides a user-friendly way for developers to interact with a Jenkins server.

Run `/bot [-h|--help]` to print this help message.

See details below for each supported subcommand.
Details
run

`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--reuse-test (optional)pipeline-id` (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL): Disable fail-fast on build/test/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages, and sanity-check stages. Note: does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force-run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline plus the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- `--detailed-log` (OPTIONAL): Enable flushing all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: does NOT update GitHub check status.

For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.

kill
`kill`

Kill all running builds associated with the pull request.

skip

`skip --comment COMMENT`

Skip testing for the latest commit on the pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

`reuse-pipeline`

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.