
[TRTLLM-10325][feat] Refactor speculative decoding workers #10768

Merged
mikeiovine merged 2 commits into NVIDIA:main from cascade812:guiju/spec1 on Jan 21, 2026

Conversation

@cascade812 (Collaborator) commented Jan 16, 2026

Description

Refactor Eagle3OneModelWorker and MTPWorker to move shared code into SpecWorkerBase. This will improve maintainability, reduce code duplication, and ensure critical correctness logic is centralized.
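As a sketch of the pattern (method names and bodies are illustrative, not the actual TensorRT-LLM implementation; only SpecWorkerBase, Eagle3OneModelWorker, and MTPWorker come from the PR), shared logic moves into protected helpers on the base class, and each worker delegates to them instead of duplicating the code inline:

```python
# Minimal sketch of the base-class extraction; signatures are hypothetical.
class SpecWorkerBase:
    """Shared logic for speculative decoding workers."""

    def _draft_sampler_greedy(self, logits):
        # Greedy draft sampling shared by all workers: pick the argmax token.
        return max(range(len(logits)), key=lambda i: logits[i])

    def _execute_guided_decoder_if_present(self, guided_decoder, logits):
        # Centralizing this guard keeps correctness-critical logic in one
        # place instead of repeating it in every worker.
        if guided_decoder is not None:
            logits = guided_decoder(logits)
        return logits


class Eagle3OneModelWorker(SpecWorkerBase):
    def sample(self, logits, guided_decoder=None):
        # Delegate to the shared helpers instead of re-implementing inline.
        logits = self._execute_guided_decoder_if_present(guided_decoder, logits)
        return self._draft_sampler_greedy(logits)


class MTPWorker(SpecWorkerBase):
    def sample(self, logits, guided_decoder=None):
        logits = self._execute_guided_decoder_if_present(guided_decoder, logits)
        return self._draft_sampler_greedy(logits)
```

With this shape, a fix to the sampling or guided-decoder path lands in SpecWorkerBase once and both workers pick it up.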

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is given. If the Git commit ID has changed, this option is always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
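For illustration, a comment combining several of the documented flags might look like the following (the GPU type and backend values are examples only):

```
/bot run --disable-fail-fast --gpu-type "A30" --test-backend "pytorch"
```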

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since careless use without validation can break the top of tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since careless use without validation can break the top of tree.

Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
@cascade812 cascade812 requested a review from a team as a code owner January 16, 2026 21:35
@cascade812 cascade812 requested a review from yweng0828 January 16, 2026 21:35
@coderabbitai bot (Contributor) commented Jan 16, 2026

📝 Walkthrough

Walkthrough

This PR refactors speculative decoding logic by extracting repeated inline operations into dedicated helper methods. Eight new methods are added to the SpecWorkerBase class in interface.py for managing attention metadata, sampling and accepting draft tokens, executing guided decoders, and preparing tokens. Eagle3 and MTP implementations are updated to delegate to these helpers rather than implementing logic inline.

Changes

Cohort / File(s) — Summary

  • Base interface additions (tensorrt_llm/_torch/speculative/interface.py): Added 8 new helper methods to SpecWorkerBase: metadata management (_prepare_attn_metadata_for_spec_dec, _restore_attn_metadata_from_spec_dec), draft token sampling (_sample_and_accept_draft_tokens_base, _draft_sampler_greedy), guided decoder execution (_execute_guided_decoder_if_present), token preparation (_prepare_next_new_tokens, _prepare_context_input_ids), sampling interface extension (_sample_tokens_for_batch), and testing support (_apply_force_accepted_tokens).

  • Eagle3 refactoring (tensorrt_llm/_torch/speculative/eagle3.py): Replaced direct guided decoder calls with _execute_guided_decoder_if_present(), switched attention metadata handling to _prepare_attn_metadata_for_spec_dec() and _restore_attn_metadata_from_spec_dec(), replaced inline token preparation with _prepare_next_new_tokens(), and delegated draft sampling to _draft_sampler_greedy().

  • MTP refactoring (tensorrt_llm/_torch/speculative/mtp.py): Removed the public method restore_attn_metadata() and updated call sites to use _restore_attn_metadata_from_spec_dec(), replaced direct guided decoder execution with _execute_guided_decoder_if_present(), consolidated token preparation logic via _prepare_next_tokens() and _prepare_context_input_ids(), and aligned draft sampling to use _draft_sampler_greedy().

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 68.18%, below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Title check ✅ Passed — The title clearly describes the main refactoring work: moving shared code from Eagle3OneModelWorker and MTPWorker into SpecWorkerBase.
  • Description check ✅ Passed — The PR description provides a clear summary of the refactoring intent (consolidating shared code into SpecWorkerBase) and includes completed PR checklist items.


@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `tensorrt_llm/_torch/speculative/interface.py` (around lines 433-455): `_apply_force_accepted_tokens` should guard against non-positive `force_num_accepted_tokens` values, which otherwise produce zeros and later cause index errors when `num_accepted_tokens - 1` is used. Update the logic that computes `force_total_tokens` (and/or the conditional that checks `self.force_num_accepted_tokens`) to treat `self.force_num_accepted_tokens <= 0` as a no-op, or to clamp the forced value to at least 1, before assigning to `num_accepted_tokens[num_contexts:]`, while keeping the existing `min(...)` cap with `self.max_draft_len + 1`.

@cascade812 (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #32355 [ run ] triggered by Bot. Commit: eb4dbb3

@tensorrt-cicd (Collaborator)

PR_Github #32355 [ run ] completed with state SUCCESS. Commit: eb4dbb3
/LLM/main/L0_MergeRequest_PR pipeline #25071 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@cascade812 (Collaborator, Author)

/bot run

@cascade812 cascade812 requested a review from mikeiovine January 17, 2026 01:36
@tensorrt-cicd (Collaborator)

PR_Github #32375 [ run ] triggered by Bot. Commit: eb4dbb3

@tensorrt-cicd (Collaborator)

PR_Github #32375 [ run ] completed with state SUCCESS. Commit: eb4dbb3
/LLM/main/L0_MergeRequest_PR pipeline #25089 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@cascade812 (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #32390 [ run ] triggered by Bot. Commit: eb4dbb3

@tensorrt-cicd (Collaborator)

PR_Github #32390 [ run ] completed with state SUCCESS. Commit: eb4dbb3
/LLM/main/L0_MergeRequest_PR pipeline #25103 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@cascade812 (Collaborator, Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #32792 [ run ] triggered by Bot. Commit: eb4dbb3

@tensorrt-cicd (Collaborator)

PR_Github #32792 [ run ] completed with state SUCCESS. Commit: eb4dbb3
/LLM/main/L0_MergeRequest_PR pipeline #25375 completed with status: 'SUCCESS'

@mikeiovine mikeiovine merged commit 8cf8fbb into NVIDIA:main Jan 21, 2026
11 of 13 checks passed
greg-kwasniewski1 pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Jan 22, 2026
pathorn added a commit to deepinfra/TensorRT-LLM that referenced this pull request Jan 27, 2026
