[#10826][feat] AutoDeploy: Eagle One-Model [2/n]: Prefill-Only Implementation #11073
46745b6 to 57cf48a
/bot run --disable-fail-fast
PR_Github #33930 [ run ] triggered by Bot. Commit:
📝 Walkthrough
This pull request implements a wrapper-based speculative decoding flow for Eagle models by introducing the EagleWrapper class with embedding/LM head management, implementing EagleConfig for configuration handling, enhancing Eagle3DrafterForCausalLM with conditional component loading, and adding comprehensive test infrastructure for prefill-only Eagle workflows.
Changes
Sequence Diagrams
sequenceDiagram
actor User
participant EagleWrapper
participant TargetModel
participant DraftModel
participant Sampler
User->>EagleWrapper: sample_and_verify(inputs_embeds, draft_input_ids, ...)
EagleWrapper->>TargetModel: forward(inputs_embeds)
TargetModel-->>EagleWrapper: logits, hidden_state
EagleWrapper->>DraftModel: apply_eagle3_fc(hidden_state)
DraftModel-->>EagleWrapper: compressed_hidden_state
EagleWrapper->>DraftModel: apply_lm_head(compressed_hidden_state)
DraftModel-->>EagleWrapper: draft_logits
EagleWrapper->>Sampler: sample_greedy(draft_logits)
Sampler-->>EagleWrapper: draft_token_ids
EagleWrapper->>EagleWrapper: compare_with_target(draft_ids, target_logits)
EagleWrapper-->>User: accepted_tokens, num_newly_accepted
sequenceDiagram
actor Caller
participant EagleWrapper as EagleWrapper._forward_prefill_only
participant TargetModel
participant DraftModel
Caller->>EagleWrapper: forward(inputs_embeds, position_ids, num_previously_accepted=None)
activate EagleWrapper
EagleWrapper->>TargetModel: generate logits iteratively
TargetModel-->>EagleWrapper: target_logits
EagleWrapper->>DraftModel: draft_tokens via sample_and_verify
DraftModel-->>EagleWrapper: accepted_tokens, hidden_states
EagleWrapper->>EagleWrapper: consolidate hidden_state_history
EagleWrapper-->>Caller: EagleWrapperOutput(new_tokens, new_tokens_lens)
deactivate EagleWrapper
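The verification step in the diagrams above (greedy draft tokens compared against the target's logits, plus a bonus token taken from the target) can be sketched in plain Python. This is a minimal illustration, not the PR's `sample_and_verify`: the function name `verify_greedy` and the list-of-ints representation are assumptions made for the example, and the real code operates on batched tensors.

```python
from typing import List, Tuple


def verify_greedy(draft_tokens: List[int], target_tokens: List[int]) -> Tuple[List[int], int]:
    """Hypothetical helper: accept the longest draft prefix the target agrees with.

    draft_tokens:  k tokens proposed by the draft model.
    target_tokens: k + 1 greedy tokens from the target model, where entry i is
                   the target's choice after seeing the prompt plus draft_tokens[:i].
    """
    if len(target_tokens) != len(draft_tokens) + 1:
        raise ValueError("target must provide one more token than the draft")

    num_accepted = 0
    for draft_tok, target_tok in zip(draft_tokens, target_tokens):
        if draft_tok != target_tok:
            break  # first disagreement ends the accepted prefix
        num_accepted += 1

    # The target's own token at the first unverified position is always kept
    # (the "bonus" token), so every step emits at least one token.
    new_tokens = draft_tokens[:num_accepted] + [target_tokens[num_accepted]]
    return new_tokens, num_accepted
```

For example, with draft `[5, 7, 9]` and target `[5, 7, 2, 4]`, two draft tokens are accepted and the emitted tokens are `[5, 7, 2]`.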
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/auto_deploy/models/eagle.py (1)
1-2: Update the NVIDIA copyright year to 2026.
This file was meaningfully modified in this PR; the header should reflect the latest year.
As per coding guidelines “All TensorRT-LLM source files (.cpp, .h, .cu, .py, and other source files) should contain an NVIDIA copyright header with the year of latest meaningful modification.”
✅ Suggested update
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_eagle.py (1)
56-77: Fix `inputs_embeds` initialization when `input_embeds` is provided.
As written, `inputs_embeds` is undefined when `input_embeds` is non-None, which will raise `UnboundLocalError`.
✅ Suggested fix
-        if input_embeds is None:
-            inputs_embeds = self.model.embed_tokens(input_ids)
+        if input_embeds is None:
+            inputs_embeds = self.model.embed_tokens(input_ids)
+        else:
+            inputs_embeds = input_embeds
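The failure mode described in this comment is easy to reproduce outside the model code. The sketch below uses hypothetical stand-in functions (`embed_buggy`, `embed_fixed`) and a list comprehension in place of `self.model.embed_tokens`; it only illustrates why the missing `else` branch raises `UnboundLocalError`.

```python
def embed_buggy(input_ids, input_embeds=None):
    # Stand-in for the original code path: the local name is only bound
    # when input_embeds is None.
    if input_embeds is None:
        inputs_embeds = [float(tok) for tok in input_ids]
    return inputs_embeds  # unbound if the caller passed input_embeds


def embed_fixed(input_ids, input_embeds=None):
    # Same logic with the suggested else branch, so the name is always bound.
    if input_embeds is None:
        inputs_embeds = [float(tok) for tok in input_ids]
    else:
        inputs_embeds = input_embeds
    return inputs_embeds


def raises_unbound_local() -> bool:
    """Return True if the buggy variant fails when embeddings are supplied."""
    try:
        embed_buggy(None, input_embeds=[0.5])
    except UnboundLocalError:
        return True
    return False
```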
🤖 Fix all issues with AI agents
In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py`:
- Around line 452-455: The apply_eagle3_fc method can return None when
self.model.fc is absent, breaking callers expecting a tensor; modify
apply_eagle3_fc (in the class containing self.model) to return the original
hidden_states when self.model.fc is None (i.e., use an explicit
else/early-return returning hidden_states) so callers always receive a
torch.Tensor.
- Around line 588-694: The code is slicing tensors using 0-d CUDA tensors
(num_previously_accepted, num_newly_accepted_tokens), which will fail on CUDA;
convert those index tensors to Python ints before using them for
slicing/indexing (use .item() or int(... .cpu().item()) as appropriate). Update
usages in sample_and_verify where unchecked_input_ids and
unchecked_target_logits are built (the list comprehensions using
num_previously_accepted[i] and seq_len), the loop that builds draft_input_ids
(prev_accepted = input_ids[i, 1 : int(num_previously_accepted[i])],
newly_accepted slice using num_previously_accepted[i] and
num_newly_accepted_tokens[i], next_token indexing into
unchecked_output_ids[i][0][int(num_newly_accepted_tokens[i])]), and when
selecting bonus_logit from unchecked_target_logits (index with
int(num_newly_accepted_tokens[i])). After changes, run a CUDA test to verify
slicing no longer errors.
In `@tests/integration/defs/examples/test_ad_speculative_decoding.py`:
- Around line 603-613: In LlamaModelWithCapture.forward change the kwarg passed
to self.model from input_embeds to inputs_embeds (or rename the parameter to
inputs_embeds) so the call uses inputs_embeds=input_embeds; update the forward
signature or the call site accordingly (function: LlamaModelWithCapture.forward,
symbol: self.model(... input_embeds=... ) and logits = self.lm_head(...)) to
ensure exactly one of input_ids or inputs_embeds is provided.
- Around line 623-624: The method get_output_embeddings in LlamaModelWithCapture
currently returns self.model.lm_head which will raise because
LlamaModelWithCapture does not define model.lm_head; change
get_output_embeddings to return the module's actual output head (e.g.,
self.lm_head) or access the correct attribute on the wrapped model (e.g.,
getattr(self, "lm_head", getattr(self.model, "lm_head", None))) so it returns
the real output embedding layer; update get_output_embeddings to reference the
existing attribute (lm_head) on LlamaModelWithCapture rather than
self.model.lm_head.
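Two of the fixes requested above are small defensive patterns: return the input unchanged when an optional projection is missing, and look up the output head on the object that actually owns it. The following is a minimal sketch with hypothetical stand-in classes (`DraftModelSketch`, `InnerModelSketch`, `ModelWithCaptureSketch`) and plain lists in place of tensors; the real classes keep `fc` on `self.model` and hold `torch.nn` modules.

```python
from typing import Callable, List, Optional


class DraftModelSketch:
    """Stand-in for the drafter: `fc` is optional."""

    def __init__(self, fc: Optional[Callable[[List[float]], List[float]]] = None):
        self.fc = fc

    def apply_eagle3_fc(self, hidden_states: List[float]) -> List[float]:
        # Explicit early return so callers never receive None.
        if self.fc is None:
            return hidden_states
        return self.fc(hidden_states)


class InnerModelSketch:
    """Stand-in for the wrapped base model, which has no lm_head attribute."""


class ModelWithCaptureSketch:
    """Stand-in for the capture wrapper, which owns the output head itself."""

    def __init__(self):
        self.model = InnerModelSketch()
        self.lm_head = "output-head"  # placeholder for the real module

    def get_output_embeddings(self):
        # Prefer the wrapper's own head, fall back to the inner model, else None.
        return getattr(self, "lm_head", getattr(self.model, "lm_head", None))
```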
🧹 Nitpick comments (5)
tensorrt_llm/_torch/auto_deploy/models/eagle.py (1)
24-33: Use module-qualified imports and annotate `_drafter_mapping` as `ClassVar`.
This aligns with the namespace import guideline and resolves the mutable class-attribute lint.
As per coding guidelines “Always maintain the namespace when importing Python modules, even if only one class or function from a module is used.”
♻️ Suggested refactor
-from dataclasses import dataclass
-from typing import Dict, Type
+import dataclasses
+import typing
-from transformers import PretrainedConfig, PreTrainedModel
+import transformers
@@
-@dataclass
+@dataclasses.dataclass
 class EagleConfigInfo:
@@
-    config_class: Type[PreTrainedModel]
+    config_class: typing.Type[transformers.PreTrainedModel]
@@
-    _drafter_mapping: Dict[str, EagleConfigInfo] = {
+    _drafter_mapping: typing.ClassVar[typing.Dict[str, EagleConfigInfo]] = {
Also applies to: 94-104
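As a small illustration of the suggested pattern, the sketch below combines module-qualified imports with a `typing.ClassVar` annotation on a shared registry dict. The names `ConfigInfoSketch` and `DrafterFactorySketch` are hypothetical stand-ins, and `dict` is used as a placeholder for a real model class.

```python
import dataclasses
import typing


@dataclasses.dataclass
class ConfigInfoSketch:
    """Stand-in for EagleConfigInfo with a single field."""

    config_class: type


class DrafterFactorySketch:
    # ClassVar tells readers and linters that this dict is deliberately shared
    # by the class, not a per-instance attribute with a mutable default.
    _drafter_mapping: typing.ClassVar[typing.Dict[str, ConfigInfoSketch]] = {
        "llama": ConfigInfoSketch(config_class=dict),
    }

    @classmethod
    def lookup(cls, model_type: str) -> ConfigInfoSketch:
        return cls._drafter_mapping[model_type]
```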
tensorrt_llm/_torch/auto_deploy/models/custom/modeling_eagle.py (1)
470-497: Align `EagleWrapperOutput.new_tokens` type with actual return value.
`_forward_prefill_only` returns a list of tensors, but `new_tokens` is typed as `Optional[torch.Tensor]`. Consider updating the annotation (or stacking to a tensor).
♻️ Suggested type update
-    new_tokens: Optional[torch.Tensor] = None
+    new_tokens: Optional[list[torch.Tensor]] = None
tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_eagle.py (1)
25-30: Use module-qualified imports per namespace guideline.
Prefer importing the module and qualifying symbols to keep namespaces intact.
As per coding guidelines “Always maintain the namespace when importing Python modules, even if only one class or function from a module is used.”
♻️ Suggested refactor
-from tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle import (
-    Eagle3DrafterForCausalLM,
-    Eagle3DraftOutput,
-)
-from tensorrt_llm._torch.auto_deploy.models.eagle import EagleConfigInfo, EagleDrafterFactory
+import tensorrt_llm._torch.auto_deploy.models.custom.modeling_eagle as modeling_eagle
+import tensorrt_llm._torch.auto_deploy.models.eagle as eagle_models
@@
-class MockEagle3ModelForCausalLM(Eagle3DrafterForCausalLM):
+class MockEagle3ModelForCausalLM(modeling_eagle.Eagle3DrafterForCausalLM):
@@
-        return Eagle3DraftOutput(logits=logits, last_hidden_state=draft_output.last_hidden_state)
+        return modeling_eagle.Eagle3DraftOutput(
+            logits=logits, last_hidden_state=draft_output.last_hidden_state
+        )
@@
-class MockEagleDrafterFactory(EagleDrafterFactory):
+class MockEagleDrafterFactory(eagle_models.EagleDrafterFactory):
@@
-        "llama": EagleConfigInfo(
+        "llama": eagle_models.EagleConfigInfo(
             config_class=MockEagle3ModelForCausalLM,
tests/integration/defs/examples/test_ad_speculative_decoding.py (2)
18-31: Use module-qualified imports per namespace guideline.
Several new `from ... import ...` statements violate the namespace import rule.
As per coding guidelines “Always maintain the namespace when importing Python modules, even if only one class or function from a module is used.”
♻️ Suggested refactor (pattern)
-from dataclasses import dataclass
+import dataclasses
@@
-from typing import Optional, Set
+import typing
@@
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.masking_utils import create_causal_mask
-from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import LlamaModel
-from transformers.utils.generic import ModelOutput
+import transformers
Then reference as `dataclasses.dataclass`, `typing.Optional`, `transformers.AutoModelForCausalLM`, `transformers.models.llama.modeling_llama.LlamaModel`, etc.
663-672: Remove or use `max_seq_len` in `build_eagle_wrapper`.
It’s currently unused; either drop it or wire it into buffer sizing to avoid confusion.
57cf48a to b3e57fd
PR_Github #33930 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #34101 [ run ] triggered by Bot. Commit:
PR_Github #34101 [ run ] completed with state
lucaslie left a comment
Overall this looks good and looks helpful for testing. My only concern is that this is very detached from what will eventually be the e2e workflow with AutoDeploy and looks more like an HF-style implementation to run it e2e. That being said, it's a good intermediate milestone, and I just wanted to make you aware that things may have to change once it gets integrated more closely into AD.
b3e57fd to 8578338
/bot run
PR_Github #34252 [ run ] triggered by Bot. Commit:
f15b347 to 66162fb
/bot run
PR_Github #34256 [ run ] triggered by Bot. Commit:
PR_Github #34256 [ run ] completed with state
66162fb to 6ea54f1
/bot run
PR_Github #34271 [ run ] triggered by Bot. Commit:
PR_Github #34271 [ run ] completed with state
/bot run
PR_Github #34324 [ run ] triggered by Bot. Commit:
PR_Github #34324 [ run ] completed with state
/bot run
PR_Github #34339 [ run ] triggered by Bot. Commit:
PR_Github #34339 [ run ] completed with state
/bot run
PR_Github #34344 [ run ] triggered by Bot. Commit:
794d8c1 to 3a94eac
/bot run
PR_Github #34347 [ run ] triggered by Bot. Commit:
PR_Github #34347 [ run ] completed with state
… code with Llama + Eagle3. Acceptance rate seems good Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
…ded in full Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
3a94eac to 56ae5af
/bot run
PR_Github #34369 [ run ] triggered by Bot. Commit:
PR_Github #34369 [ run ] completed with state
fixes: #10826
--Manual Summary--
Implements the Eagle one-model flow in a prefill-only setting. There is no cached attention, so sequences are provided to the model in full. This is intended for tracing and export in AutoDeploy.
Testing plan: Since this implements end-to-end Eagle support in the prefill-only setting, it can be tested by checking the acceptance rate of the model. This tests the correctness of the "glue code" (`EagleWrapper`) in this PR and also verifies the correctness of the Eagle checkpoint architecture and model loading.
The test (`test_eagle_wrapper_forward()`) sets up a custom "target model with capture" that manually captures hidden states from preconfigured layers, creates a resource manager to store these target hidden states, then instantiates the `EagleWrapper` with the target model, resource manager, and draft model. It then runs the `EagleWrapper` autoregressively and checks the acceptance rate for decode requests. This is done for batch sizes 1 and 2.
A note: The batch size 2 tests are quite hacky because, as we speculate, the sequences grow at different rates, and we are dealing with a padded (not packed) representation here. To compensate, we discard all speculated tokens between runs and calculate an acceptance rate by manually running these speculated tokens through the target model in a separate step. This makes the autoregressive loop identical to the normal decoding loop, while still checking the accuracy of the speculated tokens at each step.
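The batch-size-2 strategy described above (discard speculated tokens, advance one target token per step, score the speculation in a separate pass) can be sketched with toy models. Everything here is an assumption for illustration: `acceptance_rate`, `toy_target`, and the toy draft functions are hypothetical, operate on a single unpadded sequence of ints, and do not reflect the test's actual tensors or resource manager.

```python
from typing import Callable, List


def toy_target(seq: List[int]) -> int:
    """Toy greedy target model: the next token is the last token plus one, mod 10."""
    return (seq[-1] + 1) % 10


def perfect_draft(seq: List[int], k: int) -> List[int]:
    """Toy draft that always agrees with toy_target."""
    out, last = [], seq[-1]
    for _ in range(k):
        last = (last + 1) % 10
        out.append(last)
    return out


def half_right_draft(seq: List[int], k: int) -> List[int]:
    """Toy draft whose first token is right and whose remaining tokens are wrong."""
    return [(seq[-1] + 1) % 10] + [-1] * (k - 1)


def acceptance_rate(
    prompt: List[int],
    steps: int,
    k: int,
    target_next: Callable[[List[int]], int],
    draft_propose: Callable[[List[int], int], List[int]],
) -> float:
    seq = list(prompt)
    accepted = proposed = 0
    for _ in range(steps):
        draft = draft_propose(seq, k)
        proposed += len(draft)
        # Separate verification pass: replay the draft through the target.
        check = list(seq)
        for tok in draft:
            if target_next(check) != tok:
                break
            check.append(tok)
            accepted += 1
        # Discard the speculated tokens and advance by one target token, so
        # the loop matches ordinary decoding and sequences grow in lockstep.
        seq.append(target_next(seq))
    return accepted / proposed
```

Because every sequence grows by exactly one token per step, a padded batch stays aligned, at the cost of not actually benefiting from accepted tokens during the test loop.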
Summary by CodeRabbit
New Features
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`
Provide a user friendly way for developers to interact with a Jenkins server.
Run `/bot [-h|--help]` to print this help message.
See details below for each supported subcommand.
Details
`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]`
Launch build/test pipelines. All previously running jobs will be killed.
`--reuse-test (optional)pipeline-id` (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
`--disable-reuse-test` (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
`--disable-fail-fast` (OPTIONAL) : Disable fail fast on build/tests/infra failures.
`--skip-test` (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
`--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
`--gpu-type "A30, H100_PCIe"` (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
`--test-backend "pytorch, cpp"` (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
`--only-multi-gpu-test` (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
`--disable-multi-gpu-test` (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
`--add-multi-gpu-test` (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
`--post-merge` (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
`--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
`--detailed-log` (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
`--debug` (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.
kill
`kill`
Kill all running builds associated with pull request.
skip
`skip --comment COMMENT`
Skip testing for latest commit on pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
reuse-pipeline
`reuse-pipeline`
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.