
[None][feat] Add priority-based KV cache offload filtering support #10751

Merged
pcastonguay merged 17 commits into NVIDIA:main from nv-yna:yna/offload-filter-prototype on Feb 5, 2026

Conversation

@nv-yna (Collaborator) commented Jan 16, 2026

Summary

  • Add getPriorityByBlockId method to KVCacheManager to expose block retention priorities
  • Extend RequestData with priorities field for the KV cache connector
  • Populate priorities in update_and_build_data to pass to downstream KVBM

This enables KVBM to filter which blocks get offloaded based on retention priority (e.g., only offload high-priority system prompt blocks for explicit prompt caching).
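The filtering this enables can be sketched in Python. RequestData here is a toy stand-in (only the priorities field comes from this PR; the other fields are illustrative), and filter_offload_candidates is a hypothetical KVBM-side filter, not code from the change itself:

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class RequestData:
    """Toy stand-in for the connector's request data."""
    request_id: int
    block_ids: List[int] = field(default_factory=list)
    # New in this PR: per-block retention priorities (None when unset).
    priorities: Optional[List[int]] = None

def filter_offload_candidates(data: RequestData, min_priority: int = 50) -> List[int]:
    """Hypothetical KVBM filter: keep only blocks at or above min_priority,
    e.g. high-priority system-prompt blocks for explicit prompt caching."""
    if data.priorities is None:
        return data.block_ids  # no priorities supplied: consider all blocks
    return [b for b, p in zip(data.block_ids, data.priorities) if p >= min_priority]

data = RequestData(request_id=0, block_ids=[10, 11, 12], priorities=[80, 80, 10])
print(filter_offload_candidates(data))  # [10, 11]
```

With a min_priority threshold of 50, only the two blocks tagged with priority 80 survive as offload candidates.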

Companion PR

Test plan

  • Build C++ and Python bindings
  • Verify get_priority_by_block_id returns correct priority values
  • Integration test with KVBM priority filtering enabled

Summary by CodeRabbit

  • New Features

    • Added new method to query KV cache block retention priorities by block ID and window size
    • Request data now carries per-block retention priorities, enabling priority-based KV cache offloading and management
  • Tests

    • Added unit tests validating block priority retrieval for valid and invalid block IDs
    • Added integration tests verifying per-block priorities in request generation with retention configuration


@nv-yna nv-yna force-pushed the yna/offload-filter-prototype branch 2 times, most recently from d1c2d0d to 2ab02d6 on January 16, 2026 08:33
@jthomson04 (Collaborator) left a comment


Can save this for the end, but would also be nice to include an integration test in https://github.com/NVIDIA/TensorRT-LLM/blob/main/tests/integration/defs/llmapi/test_llm_api_connector.py

@nv-yna (Collaborator, Author) commented Jan 23, 2026

Can save this for the end, but would also be nice to include an integration test in https://github.com/NVIDIA/TensorRT-LLM/blob/main/tests/integration/defs/llmapi/test_llm_api_connector.py

No worries. Added directly. PTAL

@nv-yna nv-yna requested a review from jthomson04 January 23, 2026 20:45
@nv-yna nv-yna changed the title from "feat: Add priority-based KV cache offload filtering support" to "[None][feat] Add priority-based KV cache offload filtering support" on Jan 23, 2026
@nv-yna (Collaborator, Author) commented Jan 23, 2026

/bot run

@tensorrt-cicd

PR_Github #33412 [ run ] triggered by Bot. Commit: 9029f30

@tensorrt-cicd

PR_Github #33412 [ run ] completed with state SUCCESS. Commit: 9029f30
/LLM/main/L0_MergeRequest_PR pipeline #25790 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@nv-yna nv-yna force-pushed the yna/offload-filter-prototype branch from 9029f30 to 7c74b75 on January 26, 2026 06:48
@nv-yna (Collaborator, Author) commented Jan 26, 2026

/bot run

@tensorrt-cicd

PR_Github #33546 [ run ] triggered by Bot. Commit: 7c74b75

@nv-yna nv-yna marked this pull request as ready for review January 26, 2026 07:11
@nv-yna nv-yna requested review from a team as code owners January 26, 2026 07:11
@nv-yna nv-yna requested a review from achartier January 26, 2026 07:12
@coderabbitai bot (Contributor) commented Jan 26, 2026

📝 Walkthrough

Walkthrough

This change introduces a new API method getPriorityByBlockId across the KV cache manager system to query retention priority for a block by its ID and window size. The method is implemented at the C++ level, exposed through language bindings (nanobind and pybind), integrated into the Python scheduler to propagate per-block priorities, and validated with unit and integration tests.

Changes

  • C++ API Definition & Implementation (cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h, cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp): Added pure virtual method getPriorityByBlockId(blockId, windowSize) to BaseKVCacheManager and a concrete override in KVCacheManager. The implementation performs a block lookup; it returns the block's priority if found, and logs a warning and returns the default priority (35) if not.
  • Language Bindings (cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp, cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp): Exposed the new getPriorityByBlockId method through nanobind and pybind11 bindings as get_priority_by_block_id(block_id, window_size) for Python access.
  • C++ Unit Tests (cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp): Added a GetPriorityByBlockId test case validating that the method returns the configured priority (80) for valid block IDs and the default priority for invalid block IDs (-1, 9999).
  • Python Scheduler Integration (tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py): Added an optional priorities field to the RequestData dataclass; the scheduler now computes per-block priorities from the KV cache manager and propagates them via request data.
  • Python Resource Manager API (tensorrt_llm/_torch/pyexecutor/resource_manager.py): Added a public get_priority_by_block_id(block_id, window_size) method to the KVCacheManager wrapper; updated the get_batch_cache_indices signature to accept an optional window_size for variable sliding window attention (VSWA) support.
  • Integration Tests (tests/integration/defs/llmapi/test_llm_api_connector.py): Added two test functions: test_connector_priorities validates that per-block priorities from KvCacheRetentionConfig are correctly propagated to request data; test_connector_priorities_default verifies the priorities field is None when no retention config is provided.
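The lookup contract these changes describe can be sketched with a small Python stub. KVCacheManagerStub, its constructor, and the nested-dict layout are hypothetical; only the method name, its arguments, and the warn-and-return-default behavior come from the summary above (the real method delegates to the C++ bindings):

```python
DEFAULT_RETENTION_PRIORITY = 35  # mirrors the C++ default described above

class KVCacheManagerStub:
    """Toy stand-in for the real KVCacheManager wrapper."""

    def __init__(self, priorities_by_window):
        # Hypothetical layout: {window_size: {block_id: priority}}
        self._priorities = priorities_by_window

    def get_priority_by_block_id(self, block_id, window_size):
        window = self._priorities.get(window_size, {})
        if block_id not in window:
            # The real implementation logs a warning here before
            # falling back to the default retention priority.
            return DEFAULT_RETENTION_PRIORITY
        return window[block_id]

mgr = KVCacheManagerStub({4096: {7: 80}})
print(mgr.get_priority_by_block_id(7, 4096))   # 80 (configured priority)
print(mgr.get_priority_by_block_id(-1, 4096))  # 35 (invalid ID -> default)
```

This matches the unit-test expectations above: the configured priority (80) for a valid block ID, and the default for invalid IDs such as -1 or 9999.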

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 26.32%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title clearly describes the main feature being added: priority-based KV cache offload filtering support. It is specific, concise, and accurately reflects the core change across all modified files.
  • Description check ✅ Passed: The PR description covers key aspects: what was added (getPriorityByBlockId method, priorities field), why it matters (enabling KVBM to filter blocks by retention priority), and test coverage confirmation.



@coderabbitai bot (Contributor) left a comment


Actionable comments posted: 3

🤖 Fix all issues with AI agents
In cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h:
- Around lines 1686-1692: Replace the three-line Doxygen `///` comments above the
virtual method getPriorityByBlockId with the header-style single-line `//!`
comments: change each `/// @brief`, `/// @param`, and `/// @return` line to
`//! @brief`, `//! @param`, and `//! @return` respectively for the declaration of
`executor::RetentionPriority getPriorityByBlockId(KVCacheBlock::IdType blockId,
SizeType32 windowSize) const = 0;` so the header follows the project's `//!`
Doxygen style.

In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:
- Around line 2578-2587: KVCacheManager::getPriorityByBlockId calls
mBlockManager.getBlockById which uses .at() and can throw std::out_of_range
before the warning/log and default return; update getPriorityByBlockId to guard
against invalid blockId/windowSize by either (A) performing an existence/bounds
check on mBlockManager (e.g., a hasBlock or validWindow/validId method) before
calling getBlockById and returning
tle::KvCacheRetentionConfig::kDefaultRetentionPriority when missing, or (B)
wrapping the getBlockById call in a try/catch for std::out_of_range and on
exception log the warning and return
tle::KvCacheRetentionConfig::kDefaultRetentionPriority; keep existing return
type tle::RetentionPriority and preserve the TLLM_LOG_WARNING message.

In tests/integration/defs/llmapi/test_llm_api_connector.py:
- Around lines 446-528: Remove the unused enforce_single_worker parameter from
the test function signatures (test_connector_priorities and
test_connector_priorities_default) and instead mark each test to use the fixture
by adding `@pytest.mark.usefixtures("enforce_single_worker")` above the function;
ensure pytest is imported at the top of the file if not already present so the
decorator resolves.
🧹 Nitpick comments (3)
tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py (1)

318-325: Consider preserving the “None means default” contract.

Right now priorities is always a list (even when all defaults), so None is never used. If downstream treats None as “feature disabled,” this could change behavior. Consider gating priorities (e.g., only populate when a retention config is present) or update the field comment to match behavior.

tests/integration/defs/llmapi/test_llm_api_connector.py (1)

25-25: Keep module namespace on import.

Codebase rule prefers module-qualified access (avoids name clashes and keeps namespace explicit).
As per coding guidelines, update to a module import and qualify usages.

🔧 Suggested refactor
-from tensorrt_llm.llmapi.llm_utils import KvCacheRetentionConfig
+import tensorrt_llm.llmapi.llm_utils as llm_utils
@@
-    retention_config = KvCacheRetentionConfig(
+    retention_config = llm_utils.KvCacheRetentionConfig(
         token_range_retention_priorities=[
-            KvCacheRetentionConfig.TokenRangeRetentionConfig(
+            llm_utils.KvCacheRetentionConfig.TokenRangeRetentionConfig(
                 token_start=0,
                 token_end=32,
                 priority=80,
             ),
-            KvCacheRetentionConfig.TokenRangeRetentionConfig(
+            llm_utils.KvCacheRetentionConfig.TokenRangeRetentionConfig(
                 token_start=32,
                 token_end=None,  # Extend to end of sequence
                 priority=10,
             ),
         ],
         decode_retention_priority=10,
     )

Also applies to: 467-478

tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)

853-865: Avoid hardcoding the default priority value in the docstring.

The default is defined in C++ and could change; the docstring risks drifting.

✏️ Suggested tweak
-        Returns:
-            The retention priority of the block (0-100), or default priority (35) if not found.
+        Returns:
+            The retention priority of the block (0-100), or the default priority if not found.

@tensorrt-cicd

PR_Github #33546 [ run ] completed with state SUCCESS. Commit: 7c74b75
/LLM/main/L0_MergeRequest_PR pipeline #25876 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@nv-yna nv-yna force-pushed the yna/offload-filter-prototype branch from 7c74b75 to 719ccbd on January 26, 2026 18:02
@nv-yna (Collaborator, Author) commented Jan 26, 2026

/bot run

@tensorrt-cicd

PR_Github #33621 [ run ] triggered by Bot. Commit: 719ccbd

@tensorrt-cicd

PR_Github #33621 [ run ] completed with state SUCCESS. Commit: 719ccbd
/LLM/main/L0_MergeRequest_PR pipeline #25936 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@nv-yna (Collaborator, Author) commented Jan 26, 2026

/bot run

@tensorrt-cicd

PR_Github #33634 [ run ] triggered by Bot. Commit: c188771

@tensorrt-cicd

PR_Github #33634 [ run ] completed with state SUCCESS. Commit: c188771
/LLM/main/L0_MergeRequest_PR pipeline #25948 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@nv-yna nv-yna removed the request for review from achartier January 27, 2026 00:35
@nv-yna (Collaborator, Author) commented Jan 27, 2026

/bot run

@pcastonguay (Collaborator) commented

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #34679 [ run ] triggered by Bot. Commit: dd0e58f

@pcastonguay pcastonguay enabled auto-merge (squash) February 3, 2026 19:29
@tensorrt-cicd

PR_Github #34679 [ run ] completed with state FAILURE. Commit: dd0e58f
/LLM/main/L0_MergeRequest_PR pipeline #26760 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@nv-yna (Collaborator, Author) commented Feb 3, 2026

/bot run --disable-fail-fast

@nv-yna (Collaborator, Author) commented Feb 4, 2026

/bot run

@nv-yna (Collaborator, Author) commented Feb 4, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #34729 [ run ] triggered by Bot. Commit: dd0e58f

@nv-yna (Collaborator, Author) commented Feb 4, 2026

/bot kill

@tensorrt-cicd

PR_Github #34797 [ kill ] triggered by Bot. Commit: dd0e58f

@tensorrt-cicd

PR_Github #34797 [ kill ] completed with state SUCCESS. Commit: dd0e58f
Successfully killed previous jobs for commit dd0e58f

@nv-yna (Collaborator, Author) commented Feb 4, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #34810 [ run ] triggered by Bot. Commit: dd0e58f

@tensorrt-cicd

PR_Github #34810 [ run ] completed with state SUCCESS. Commit: dd0e58f
/LLM/main/L0_MergeRequest_PR pipeline #26870 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@nv-yna (Collaborator, Author) commented Feb 5, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #34872 [ run ] triggered by Bot. Commit: dd0e58f

@nv-yna (Collaborator, Author) commented Feb 5, 2026

/bot kill

@tensorrt-cicd

PR_Github #34891 [ kill ] triggered by Bot. Commit: af65dc7

@tensorrt-cicd

PR_Github #34872 [ run ] completed with state ABORTED. Commit: dd0e58f

@tensorrt-cicd

PR_Github #34891 [ kill ] completed with state SUCCESS. Commit: af65dc7
Successfully killed previous jobs for commit af65dc7

@nv-yna (Collaborator, Author) commented Feb 5, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #34894 [ run ] triggered by Bot. Commit: af65dc7

@tensorrt-cicd

PR_Github #34894 [ run ] completed with state SUCCESS. Commit: af65dc7
/LLM/main/L0_MergeRequest_PR pipeline #26925 completed with status: 'SUCCESS'
Pipeline has performance regression cases. Check the performance regression report for details.

@pcastonguay pcastonguay merged commit 0d18b2d into NVIDIA:main Feb 5, 2026
5 checks passed
SchumiDing pushed a commit to SchumiDing/TensorRT-LLM that referenced this pull request Feb 6, 2026
…VIDIA#10751)

Signed-off-by: Yuewei Na <yna@nvidia.com>
Signed-off-by: Yuewei Na <nv-yna@users.noreply.github.com>
Co-authored-by: Yuewei Na <nv-yna@users.noreply.github.com>
inciaf pushed a commit to inciaf/trtllm-energy-monitoring that referenced this pull request Feb 18, 2026
…VIDIA#10751)

Signed-off-by: Yuewei Na <yna@nvidia.com>
Signed-off-by: Yuewei Na <nv-yna@users.noreply.github.com>
Co-authored-by: Yuewei Na <nv-yna@users.noreply.github.com>
Signed-off-by: Ahmet Inci <ainci@nvidia.com>


7 participants