[None][fix] nccl symmetric with graceful fallbacks #11042

Merged
Tabrizian merged 14 commits into NVIDIA:main from nv-lschneider:debug-nccl-symm
Jan 28, 2026

Conversation

@nv-lschneider
Collaborator

@nv-lschneider nv-lschneider commented Jan 27, 2026

Summary by CodeRabbit

Release Notes

  • Bug Fixes & Improvements
    • Enhanced NCCL window buffer allocation with better handling during GPU graph capture operations
    • Improved error handling with graceful fallbacks when buffer allocation fails
    • Added warnings to guide users when optimal buffer configurations are unavailable, ensuring operations continue with alternate paths


Description

This makes NCCL_SYMMETRIC degrade more gracefully.
If a problem occurs during creation of registered tensors, the unregistered NCCL operation is performed instead.

Also new: it detects execution during graph capture, since buffers cannot be registered while a CUDA graph is being captured.
In a follow-up PR, I will add registering buffers before any graph capture.
This PR only replaces failure cases with a graceful fallback that may compromise performance (though it is not worse than plain NCCL).

Test Coverage

Covered by existing tests.

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with the Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
@coderabbitai
Contributor

coderabbitai bot commented Jan 27, 2026

📝 Walkthrough

Walkthrough

The changes introduce CUDA graph capture awareness to the NCCL window buffer allocation system. When allocating buffers, the code now checks if the current CUDA stream is in capture state; if so, it returns an empty buffer to avoid allocation during capture. The changes also add error handling and fallback mechanisms for symmetric window buffer creation in the AllReduce operation, gracefully degrading to input buffers when window buffers are unavailable.

Changes

Cohort / File(s) Summary
NCCL Window Buffer Allocation
cpp/tensorrt_llm/common/ncclUtils.cpp, cpp/tensorrt_llm/common/ncclUtils.h
Added CUDA headers and capture-state checking logic. In ncclUtils.cpp, buffer allocation now detects active CUDA graph capture and returns empty buffers when capturing. In ncclUtils.h, createNCCLWindowTensor wrapped in try/catch to handle exceptions and invalid buffers gracefully, returning empty/invalid buffers instead of throwing.
NCCL AllReduce Operation
cpp/tensorrt_llm/thop/allreduceOp.cpp
Modified AllreduceOp.runNCCLAllReduceSymmetric to validate window buffers before use and conditionally create symmetric window buffers only when existing buffers are unavailable and size exceeds threshold. Added fallback logic to use input/output tensors directly when symmetric buffer creation fails, with warnings logged for each fallback case.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title '[None][fix] nccl symmetric with graceful fallbacks' directly relates to the main changes: adding graceful fallback behavior and CUDA graph capture detection to NCCL symmetric operations.
Description check ✅ Passed The PR description explains the rationale and changes but is quite brief. However, it covers the key aspects: graceful fallback on buffer registration failure, CUDA graph capture detection, test coverage claim, and checklist confirmation.



@nv-lschneider
Collaborator Author

@coderabitai review

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
cpp/tensorrt_llm/common/ncclUtils.h (1)

1-3: Update the NVIDIA copyright year to reflect 2026.

This file was modified in this PR, but the header still lists 2025. Please bump the year to 2026 (or 2025–2026) to reflect the latest meaningful modification. As per coding guidelines, please keep the header year current.

🔧 Suggested update
- * Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2025-2026, NVIDIA CORPORATION.  All rights reserved.
cpp/tensorrt_llm/common/ncclUtils.cpp (1)

1-3: Update the NVIDIA copyright year to reflect 2026.

This file was modified in this PR, but the header still lists 2025. Please bump the year to 2026 (or 2025–2026) to reflect the latest meaningful modification. As per coding guidelines, please keep the header year current.

🔧 Suggested update
- * Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2025-2026, NVIDIA CORPORATION.  All rights reserved.
cpp/tensorrt_llm/thop/allreduceOp.cpp (2)

1-4: Update the SPDX copyright range to include 2026.

This file was modified in this PR, but the header ends at 2024. Please update the range to include 2026 to reflect the latest meaningful modification. As per coding guidelines, please keep the header year current.

🔧 Suggested update
- * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
+ * SPDX-FileCopyrightText: Copyright (c) 1993-2026 NVIDIA CORPORATION &

512-551: Gate output window allocation on input registration status to prevent undefined NCCL behavior.

The current code can allocate a registered output buffer even when the input buffer is unregistered (due to size threshold or allocation failure). NCCL requires that for any collective call, either all ranks use registered buffers or all use unregistered buffers; mixing them results in undefined behavior.

Condition output window allocation on the same decision that controls input registration (i.e., whether windowBuffer0.isValid() or via a useSymmetric flag) to ensure both input and output follow the same registration path.

♻️ Example refactor
+        bool const useSymmetric = windowBuffer0.isValid();
+
+        torch::Tensor outputTensor;
+        void* outputPtr = nullptr;
+        if (useSymmetric)
+        {
+            auto [normOut, windowBuffer1] = createNCCLWindowTensor(comm, input.sizes(), input.scalar_type());
+            if (windowBuffer1.isValid())
+            {
+                outputTensor = normOut;
+                outputPtr = windowBuffer1.ptr;
+            }
+            else
+            {
+                outputTensor = torch::empty_like(inputTensor);
+                outputPtr = outputTensor.data_ptr();
+                TLLM_LOG_WARNING(
+                    "[runNCCLAllReduceSymmetric] No valid symmetric buffer available; "
+                    "using plain CUDA tensor for output");
+            }
+        }
+        else
+        {
+            outputTensor = torch::empty_like(inputTensor);
+            outputPtr = outputTensor.data_ptr();
+        }
-        // Use window-backed output buffer
-        auto [normOut, windowBuffer1] = createNCCLWindowTensor(comm, input.sizes(), input.scalar_type());
-        torch::Tensor outputTensor = windowBuffer1.isValid() ? normOut : torch::empty_like(inputTensor);
-        void* outputPtr = windowBuffer1.isValid() ? windowBuffer1.ptr : outputTensor.data_ptr();
-        if (!windowBuffer1.isValid())
-        {
-            TLLM_LOG_WARNING(
-                "[runNCCLAllReduceSymmetric] No valid symmetric buffer available; "
-                "using plain CUDA tensor for output");
-        }

Also applies to: 562–569

@nv-lschneider
Copy link
Collaborator Author

/bot run --disable-fail-fast --add-multi-gpu-test

@tensorrt-cicd
Collaborator

PR_Github #33778 [ run ] triggered by Bot. Commit: b1e932d

Collaborator

@hyukn hyukn left a comment


I guess there might be no coverage in today's tests for NCCL_SYMMETRIC_PATH, because #10463 is a very simple change. Do you think it would be better to combine them to get feedback from CI on this PR more quickly?

@hyukn hyukn requested a review from liji-nv January 27, 2026 23:51
@tensorrt-cicd
Collaborator

PR_Github #33778 [ run ] completed with state SUCCESS. Commit: b1e932d
/LLM/main/L0_MergeRequest_PR pipeline #26051 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
@nv-lschneider nv-lschneider requested a review from a team as a code owner January 28, 2026 02:46
Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
@nv-lschneider
Collaborator Author

I added the activation of NCCL_SYMMETRIC in autotuning for better test coverage.
CI for #10463 was green at some point before rebase, so that should be OK to merge with this one too.

@nv-lschneider
Collaborator Author

/bot run --disable-fail-fast --add-multi-gpu-test

@tensorrt-cicd
Collaborator

PR_Github #33803 [ run ] triggered by Bot. Commit: 0529e4c

@tensorrt-cicd
Collaborator

PR_Github #33803 [ run ] completed with state SUCCESS. Commit: 0529e4c
/LLM/main/L0_MergeRequest_PR pipeline #26071 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@nv-lschneider
Collaborator Author

/bot run --disable-fail-fast --add-multi-gpu-test

@tensorrt-cicd
Collaborator

PR_Github #33891 [ run ] triggered by Bot. Commit: 0529e4c

@tensorrt-cicd
Collaborator

PR_Github #33891 [ run ] completed with state FAILURE. Commit: 0529e4c

@nv-lschneider
Collaborator Author

/bot run --disable-fail-fast --add-multi-gpu-test

@tensorrt-cicd
Collaborator

PR_Github #33894 [ run ] triggered by Bot. Commit: 0529e4c

@tensorrt-cicd
Collaborator

PR_Github #33894 [ run ] completed with state SUCCESS. Commit: 0529e4c
/LLM/main/L0_MergeRequest_PR pipeline #26142 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@nv-lschneider
Collaborator Author

/bot run --disable-fail-fast --add-multi-gpu-test

@tensorrt-cicd
Collaborator

PR_Github #33920 [ run ] triggered by Bot. Commit: 0529e4c

@tensorrt-cicd
Collaborator

PR_Github #33920 [ run ] completed with state SUCCESS. Commit: 0529e4c
/LLM/main/L0_MergeRequest_PR pipeline #26160 completed with status: 'SUCCESS'

@Tabrizian
Member

I think this test needs to be unwaived too:

accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency] SKIP (https://nvbugs/5814309)

@nv-lschneider nv-lschneider changed the title [https://nvbugs/5814309][fix] nccl symmetric with graceful fallbacks [None][fix] nccl symmetric with graceful fallbacks Jan 28, 2026
@Tabrizian Tabrizian merged commit 4e10bf8 into NVIDIA:main Jan 28, 2026
7 checks passed
@hyukn
Collaborator

hyukn commented Jan 29, 2026

Hi @nv-lschneider. Thanks for the effort. I noticed that the last failed CI before the final successful one hits this test

_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py FAILED [ 24%]

I am not sure if it could be a flaky one.

hyukn pushed a commit to hyukn/TensorRT-LLM that referenced this pull request Jan 29, 2026
Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
@nv-lschneider nv-lschneider deleted the debug-nccl-symm branch January 29, 2026 21:03
@nv-lschneider
Collaborator Author

Hi @nv-lschneider. Thanks for the effort. I noticed that the last failed CI before the final successful one hits this test

_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py FAILED [ 24%]

I am not sure if it could be a flaky one.

Thanks for pointing that out.
I ran the test locally about 20 times on GB200 and did not notice any failure.
So at least on that hardware it does not seem very flaky to me.

hyukn added a commit that referenced this pull request Jan 30, 2026
… graceful fallbacks (#11042) (#11090)

Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
Co-authored-by: Ludwig Schneider <lschneider@nvidia.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Feb 13, 2026
…c with graceful fallbacks (NVIDIA#11042) (NVIDIA#11090)

Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
Co-authored-by: Ludwig Schneider <lschneider@nvidia.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Feb 13, 2026
…c with graceful fallbacks (NVIDIA#11042) (NVIDIA#11090)

Signed-off-by: Ludwig Schneider <lschneider@nvidia.com>
Co-authored-by: Ludwig Schneider <lschneider@nvidia.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>