
[None][feat] Add NVFP4 dynamic quantization support for visual_gen models #11563

Merged

chang-l merged 5 commits into NVIDIA:main from chang-l:nvfp4_dyn_quant on Feb 20, 2026
Conversation

@chang-l (Collaborator) commented Feb 18, 2026

  • Add dynamic NVFP4 (W4A4) quantization for diffusion transformers, enabling
    on-the-fly quantization from BF16 checkpoints without pre-quantized NVFP4 weights
  • Fix QKV fused projection global scale for NVFP4 dynamic quant
  • Add accuracy tests (static vs dynamic vs BF16) for TI2V-5B and T2V-A14B,
    including mixed quantization with ignore patterns
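The two-level scheme behind these bullets (a per-block scale plus a global weight_scale_2) can be illustrated with a minimal NumPy sketch. quantize_nvfp4_ref, its nearest-value rounding, and the returned layout are illustrative assumptions; the PR's quantize_nvfp4 runs a TRT-LLM CUDA kernel and packs 4-bit E2M1 codes with FP8 block scales rather than returning dequantized floats.

```python
import numpy as np

E2M1_MAX = 6.0        # largest representable FP4 (E2M1) magnitude
FP8_E4M3_MAX = 448.0  # largest representable FP8 (E4M3) magnitude
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_nvfp4_ref(w, block_size=16):
    """Reference (non-CUDA) sketch of two-level NVFP4 quantization.

    Returns values dequantized from FP4, per-block scales, and the
    global scale (weight_scale_2). Assumes a nonzero input tensor.
    """
    amax = float(np.abs(w).max())
    # Global scale chosen so per-block scales fit the FP8 E4M3 range.
    weight_scale_2 = amax / (E2M1_MAX * FP8_E4M3_MAX)
    blocks = w.reshape(-1, block_size)
    # Per-block scale maps each block's amax onto the FP4 range
    # (stored as FP8 E4M3 by the real kernel).
    block_amax = np.abs(blocks).max(axis=1, keepdims=True)
    block_scale = np.maximum(block_amax / E2M1_MAX / weight_scale_2, 1e-12)
    scaled = blocks / (block_scale * weight_scale_2)
    # Round each magnitude to the nearest E2M1 grid point, keep the sign.
    idx = np.abs(np.abs(scaled)[..., None] - E2M1_GRID).argmin(axis=-1)
    q = np.sign(scaled) * E2M1_GRID[idx]
    dq = (q * block_scale * weight_scale_2).reshape(w.shape)
    return dq, block_scale.ravel(), weight_scale_2
```

Because the global scale is sized from the whole tensor's amax, a fused projection must share one weight_scale_2 across its components, which is what the fused QKV fix addresses.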

Summary by CodeRabbit

  • New Features

    • Added NVFP4 dynamic activation quantization support with runtime computation capabilities
    • Introduced enhanced weight scaling parameters for quantization operations
    • Added fused weight quantization processing for multiple weight configurations
  • Tests

    • Expanded test coverage for NVFP4 quantization validation and static/dynamic comparisons
    • Added mixed quantization configuration testing scenarios

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages that don't match the specified backends. Only [pytorch, cpp, tensorrt, triton] are supported. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
chang-l requested review from a team as code owners on February 18, 2026 05:38
@chang-l (Collaborator, Author) commented Feb 18, 2026

/bot run --disable-fail-fast

@coderabbitai bot (Contributor) commented Feb 18, 2026

📝 Walkthrough

This pull request extends NVFP4 dynamic quantization support across the TensorRT-LLM quantization pipeline. Changes include adding weight_scale_2 parameter to linear quantization methods, modifying _input_prepare to return alpha values, implementing fused NVFP4 quantization, and adding comprehensive test coverage for dynamic NVFP4 quantization scenarios.

Changes

  • Core Linear Quantization Extensions (tensorrt_llm/_torch/modules/linear.py):
    Added a weight_scale_2 parameter across quantization methods; updated _input_prepare to return a triple (act_fp4, act_sf, alpha) instead of a pair; changed load_weight_scales to return four values, including the new weight_scale_2; extended the weight-loading paths (vanilla, fused_qkv, fused_gate_up) to load and propagate weight_scale_2 and alpha; added dynamic activation quantization handling with conditional input_scale/alpha computation.
  • Dynamic Quantization Configuration (tensorrt_llm/_torch/visual_gen/config.py):
    When the NVFP4 quant_algo is specified with dynamic weight quantization and no config_groups, dynamic activation quantization is enabled for NVFP4 to keep weight and activation quantization consistent.
  • NVFP4 Quantization Loader (tensorrt_llm/_torch/visual_gen/quantization/loader.py):
    Introduced NVFP4 dynamic quantization support in _maybe_dynamic_quantize; added an _is_fused_nvfp4_dynamic detector for fused weight sets; implemented _quantize_fused_nvfp4 to quantize fused weights (QKV, gate/up) together while preserving a consistent global scale and weight_scale_2; updated load_linear_weights to use the fused NVFP4 path when applicable.
  • Quantization Operations (tensorrt_llm/_torch/visual_gen/quantization/ops.py):
    Added an E2M1_MAX constant (6.0) and a new public function quantize_nvfp4(weight, block_size=16) that produces quantized weights, per-block scales, and a global weight scale (weight_scale_2) using a TRT-LLM CUDA kernel.
  • NVFP4 Dynamic Quantization Tests (tests/unittest/_torch/visual_gen/test_wan.py):
    Added an NVFP4-specific checkpoint path and a new test, test_load_wan_pipeline_with_nvfp4_quantization, validating NVFP4 weights and metadata; introduced a wan22_t2v_bf16_checkpoint_exists fixture; renamed and extended NVFP4 tests to compare static vs. dynamic quantization against BF16 baselines; added mixed quantization scenarios with dynamic NVFP4 flows.
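The loader's fused-weight rule above (quantize QKV or gate/up components together so they share one consistent global scale) can be sketched in NumPy. fused_global_scale is a hypothetical stand-in for illustration, not the PR's loader code:

```python
import numpy as np

# Constants as described in the ops.py change summary.
E2M1_MAX = 6.0        # largest FP4 (E2M1) magnitude
FP8_E4M3_MAX = 448.0  # largest FP8 (E4M3) magnitude

def fused_global_scale(weights):
    """Compute one shared global scale (weight_scale_2) over all fused
    components (e.g. Q, K, V), equivalent to scaling their concatenation.

    Sketch only: the real path quantizes with a CUDA kernel; this shows
    why the global scale must come from the joint amax, not per-weight.
    """
    amax = max(float(np.abs(w).max()) for w in weights)
    # Sized so per-block FP8 scales stay within the E4M3 range.
    return amax / (E2M1_MAX * FP8_E4M3_MAX)
```

Deriving the scale from the joint amax makes per-component quantization interchangeable with quantizing the concatenated weight, which is the consistency property the fused path preserves.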

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 74.07%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check (⚠️ Warning): The PR description is largely incomplete; it only provides bullet points of the main changes without the required template structure. Resolution: fill in the missing sections, provide a proper title with a [type] tag, expand the Description section with details, and list specific Test Coverage for the changes.

✅ Passed checks (1 passed)

  • Title check (✅ Passed): The title accurately describes the main feature addition: NVFP4 dynamic quantization support for visual_gen models.


@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/visual_gen/quantization/loader.py (1)

1-19: ⚠️ Potential issue | 🟡 Minor

Missing NVIDIA copyright header.

Same as ops.py — this file needs an NVIDIA copyright header per coding guidelines. As per coding guidelines: "All source files must contain an NVIDIA copyright header with the year of latest meaningful modification."

tensorrt_llm/_torch/visual_gen/quantization/ops.py (1)

1-14: ⚠️ Potential issue | 🟡 Minor

Missing NVIDIA copyright header.

Per coding guidelines, all .py source files must contain an NVIDIA copyright header with the Apache License 2.0 format. This file starts with a docstring but has no copyright header. As per coding guidelines: "All source files must contain an NVIDIA copyright header with the year of latest meaningful modification."

🧹 Nitpick comments (3)
tensorrt_llm/_torch/modules/linear.py (1)

1219-1266: Dynamic quantization path in _input_prepare looks correct, but relies on force_dynamic_quantization being set.

In the dynamic case, module.input_scale is still a Parameter(torch.empty([1])) (created at line 1187) — it is never set to None by process_weights_after_loading_vanilla (unlike FP8QDQLinearMethod which does so at line 654). The dynamic branch (line 1252) will only trigger when module.force_dynamic_quantization is True.

This works correctly for the visual_gen pipeline because config.py (lines 458–461) sets force_dynamic_quantization = True for dynamic NVFP4. However, if anyone uses NVFP4 dynamic quantization via a different entry-point without setting force_dynamic_quantization, the code would silently use the uninitialized module.input_scale (garbage values from torch.empty).

Consider adding a process_weights_after_loading_vanilla override to NVFP4LinearMethod that nullifies input_scale and alpha when no static values were loaded, mirroring the FP8QDQLinearMethod pattern (lines 652–657). This makes the dynamic path self-consistent regardless of how the module is initialized.

Sketch of the suggested override

Add to NVFP4LinearMethod (after load_weights_fused_gate_up_linear):

def process_weights_after_loading_vanilla(self, module: Linear):
    if not hasattr(module, "has_static_input_scale"):
        module.input_scale = None
        module.alpha = None
    else:
        delattr(module, "has_static_input_scale")

And set has_static_input_scale when input_scale is not None in load_weights_vanilla (line 1394).

tests/unittest/_torch/visual_gen/test_wan.py (2)

2302-2310: Reuse the already-computed output_nvfp4_static_float instead of calling .float() twice.

output_nvfp4_static_float is defined at line 2262 in the same branch.

♻️ Proposed fix
-    nvfp4_static_range = (
-        output_nvfp4_static.float().min().item(),
-        output_nvfp4_static.float().max().item(),
-    )
+    nvfp4_static_range = (
+        output_nvfp4_static_float.min().item(),
+        output_nvfp4_static_float.max().item(),
+    )

1886-1886: Rename unused loop variable name_name (Ruff B007).

Applies to three identical patterns at Lines 1886, 2102, and 2143.

♻️ Proposed fix (apply identically at all three sites)
-for name, module in pipeline_nvfp4_dynamic.transformer.named_modules():
+for _name, module in pipeline_nvfp4_dynamic.transformer.named_modules():

Also applies to: 2102-2102, 2143-2143

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/modules/linear.py`:
- Around line 1251-1258: The dynamic branch in _input_prepare reads
module.weight_scale_2 unconditionally, which can be uninitialized; add a
defensive check before using it in the dynamic quantization path (the block that
computes input_scale and alpha) to ensure module.weight_scale_2 is present and
valid — either assert with a clear error message referencing
module.weight_scale_2 or initialize a safe fallback (e.g., copy from
module.weight_scale or compute a default tensor matching expected
dtype/device/shape) so alpha is computed from a valid value; update the
_input_prepare dynamic branch to perform this validation and produce a clear
failure or deterministic fallback rather than reading uninitialized memory.
- Around line 1474-1484: In load_weights_fused_gate_up_linear, the code copies
input_scale but fails to update the corresponding inverse scale; when
input_scale is not None call the same update performed in the vanilla path:
copy_weight(module.input_scale, input_scale) and also set module.inv_input_scale
= 1.0 / input_scale.item() (or mirror how inv_input_scale is computed elsewhere)
so module.inv_input_scale is kept in sync with input_scale; use the existing
copy_weight helper and the same pattern used for module.input_scale and
module.alpha to locate where to add this update.
- Around line 1434-1444: In load_weights_fused_qkv_linear the static-path branch
copies input_scale but fails to update module.inv_input_scale, leaving the
default from create_weights; update the static path so that when input_scale is
not None you also set module.inv_input_scale.data = module.input_scale /
E2M1_MAX (same logic as the vanilla path), ensuring module.inv_input_scale is
consistent for downstream consumers like KV cache scaling; locate this in
load_weights_fused_qkv_linear near the copy_weight(module.input_scale,
input_scale) call and apply the same assignment used in the vanilla load_weights
path.

In `@tensorrt_llm/_torch/visual_gen/quantization/loader.py`:
- Around line 199-264: _is_fused_nvfp4_dynamic currently uses
any(self._should_dynamic_quantize(...)) which can select the fused path when
only some components need dynamic quantization and lead to mixed-dtype torch.cat
in _quantize_fused_nvfp4; change the decision to require uniform
dynamic-quantize need by using all(...) (i.e., replace any with all) so fused
quantization is only attempted when every weight_dict in weight_dicts needs
NVFP4, and additionally add a defensive check inside _quantize_fused_nvfp4 to
verify all source tensors are the same dtype (or convert them to the expected
BF16/CUDA dtype) before concatenation and raise a clear ValueError if not; refer
to _is_fused_nvfp4_dynamic, _quantize_fused_nvfp4, _should_dynamic_quantize, and
quantize_nvfp4 to locate the logic.

In `@tests/unittest/_torch/visual_gen/test_wan.py`:
- Around line 1884-1892: Add an explicit guard assertion that
dynamic_quant_modules > 0 after counting quantized Linear modules for
pipeline_nvfp4_dynamic to mirror the static check; locate the loop over
pipeline_nvfp4_dynamic.transformer.named_modules() that increments
dynamic_quant_modules (checking isinstance(module, Linear) and
module.weight_scale) and, immediately after the print(f"[NVFP4 Dynamic]
Quantized Linear modules: {dynamic_quant_modules}"), add an assertion (or pytest
assert) that fails with a clear message when dynamic_quant_modules == 0 so a
broken dynamic-quantization path is diagnosed early.
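The any()-versus-all() finding for _is_fused_nvfp4_dynamic can be illustrated with a small, hypothetical stand-in (the function names below are illustrative, not the actual loader code):

```python
# needs_dynamic mimics _should_dynamic_quantize(...) evaluated per
# fused component (e.g. Q, K, V).
def take_fused_path_buggy(needs_dynamic):
    # any(): enters the fused path even if only one component needs
    # dynamic quantization, risking a mixed-dtype torch.cat.
    return any(needs_dynamic)

def take_fused_path_fixed(needs_dynamic):
    # all(): fused path only when every component needs NVFP4, as the
    # review comment recommends.
    return all(needs_dynamic)
```

For example, with flags = [True, False, True] (K shipped pre-quantized, Q and V in BF16), the buggy detector would still take the fused path while the fixed one falls back to per-weight quantization.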


@tensorrt-cicd (Collaborator)

PR_Github #36132 [ run ] triggered by Bot. Commit: 130e193

@tensorrt-cicd (Collaborator)

PR_Github #36132 [ run ] completed with state FAILURE. Commit: 130e193
/LLM/main/L0_MergeRequest_PR pipeline #27921 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@chang-l (Collaborator, Author) commented Feb 18, 2026

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #36166 [ run ] triggered by Bot. Commit: 130e193

@tensorrt-cicd (Collaborator)

PR_Github #36166 [ run ] completed with state SUCCESS. Commit: 130e193
/LLM/main/L0_MergeRequest_PR pipeline #27949 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


@chang-l (Collaborator, Author) commented Feb 18, 2026

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #36186 [ run ] triggered by Bot. Commit: 130e193

@tensorrt-cicd (Collaborator)

PR_Github #36186 [ run ] completed with state FAILURE. Commit: 130e193
/LLM/main/L0_MergeRequest_PR pipeline #27967 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


@chang-l (Collaborator, Author) commented Feb 19, 2026

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #36268 [ run ] triggered by Bot. Commit: 130e193

@tensorrt-cicd (Collaborator)

PR_Github #36268 [ run ] completed with state SUCCESS. Commit: 130e193
/LLM/main/L0_MergeRequest_PR pipeline #28042 completed with status: 'SUCCESS'


Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
@chang-l (Collaborator, Author) commented Feb 20, 2026

/bot reuse-pipeline --comment "only change comment from last successful pipeline"

@tensorrt-cicd (Collaborator)

PR_Github #36344 [ reuse-pipeline ] triggered by Bot. Commit: 0a4f0ee

@tensorrt-cicd (Collaborator)

PR_Github #36344 [ reuse-pipeline ] completed with state SUCCESS. Commit: 0a4f0ee
Reusing PR_Github #36268 for commit 0a4f0ee


@chang-l chang-l merged commit 0c239d0 into NVIDIA:main Feb 20, 2026
5 checks passed
