Skip to content

Comments

[#10245][feat] AutoDeploy: Add Minimax M2 support#10525

Merged
bmarimuthu-nv merged 7 commits intoNVIDIA:mainfrom
nv-auto-deploy:bala/minimax-m2
Jan 28, 2026
Merged

[#10245][feat] AutoDeploy: Add Minimax M2 support#10525
bmarimuthu-nv merged 7 commits intoNVIDIA:mainfrom
nv-auto-deploy:bala/minimax-m2

Conversation

@bmarimuthu-nv
Copy link
Collaborator

@bmarimuthu-nv bmarimuthu-nv commented Jan 8, 2026

Summary by CodeRabbit

  • New Features

    • Added support for exporting MiniMax-M2 models with mixture-of-experts via PyTorch.
    • Added distributed RMSNorm operator for efficient sharded inference across multiple GPUs.
    • Enabled RMSNorm weight sharding in attention layers for improved distributed performance.
  • Tests

    • Added validation tests for MiniMax-M2 mixture-of-experts patching.
    • Added multi-GPU tests for distributed RMSNorm operator.
    • Added transformation tests for RMSNorm sharding.

✏️ Tip: You can customize this high-level summary in your review settings.

Description

Fixes #10245

  • Adds MoE patch for Minimax M2 MoE layer
  • Fixes qk_norm weight sharding
    • q/k/v -> qk_norm -> attn -> o_proj. The qk_norm weights also need to be sharded on head dim same way as qkv.
  • Adds some debug utils to dump graph IR after every transform based on the env var: AD_DUMP_GRAPHS_DIR=<dir to dump graphs>

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@bmarimuthu-nv bmarimuthu-nv changed the title [AutoDeploy] Minimax M2 support ##10245 [feat] AutoDeploy: Add Minimax M2 support Jan 8, 2026
@bmarimuthu-nv bmarimuthu-nv changed the title ##10245 [feat] AutoDeploy: Add Minimax M2 support #10245 [feat] AutoDeploy: Add Minimax M2 support Jan 8, 2026
@bmarimuthu-nv bmarimuthu-nv changed the title #10245 [feat] AutoDeploy: Add Minimax M2 support [#10245][feat] AutoDeploy: Add Minimax M2 support Jan 8, 2026
@bmarimuthu-nv
Copy link
Collaborator Author

@coderabbitai summary

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 8, 2026

✅ Actions performed

Summary regeneration triggered.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 8, 2026

📝 Walkthrough

Walkthrough

The changes introduce support for the MiniMax-M2 model in AutoDeploy through torch-export-compatible MoE implementation, distributed RMSNorm sharding for improved performance across multiple devices, and comprehensive unit and functional tests. Includes minor improvements to flashinfer attention and diagnostic error messages.

Changes

Cohort / File(s) Summary
FlashInfer Attention Optimization
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py
Added .contiguous() calls after reshaping q, k, v tensors to ensure memory contiguity without altering semantics.
MiniMax-M2 MoE Export Support
tensorrt_llm/_torch/auto_deploy/models/patches/minimax_m2.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_minimax_m2_patches.py
Implements torch-export-friendly MoE forward via minimax_m2_moe method with sigmoid routing, top-k expert selection, and weight normalization. Patches AutoModelForCausalLM.from_config with runtime module replacement mechanism. Tests validate numerical equivalence against original HuggingFace implementation.
Distributed RMSNorm Sharding
tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_sharded_rmsnorm.py
Introduces sharded_rmsnorm operator leveraging distributed all_reduce for global normalization across devices. Functional tests verify correctness across multiple GPU ranks with different hidden sizes and dtypes.
Graph Sharding Transformation
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py, tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_rmsnorm_sharding.py
Adds RMSNormShardingInfo class to detect and replace torch_rmsnorm with sharded variants in attention subgraphs. Extends ShardingTransformContainer with rmsnorm_transforms pipeline. Enhanced _determine_fused_weight_dims and _update_node_args for broader fused-weight pattern matching. Unit tests validate detection and transformation of RMSNorm operations in full and per-head norm configurations.
Diagnostic Improvements
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
Expanded assertion error message to include diagnostic context (terminating node name and opening node list) when node matching fails.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 75.76% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The PR title clearly and specifically describes the main feature being added: support for Minimax M2 in AutoDeploy, directly matching the primary objective from the linked issue #10245.
Description check ✅ Passed The description explains the primary changes (MoE patch, qk_norm sharding fix, debug utilities) and links to the related issue, providing sufficient context for the pull request objectives.
Linked Issues check ✅ Passed The PR implements MoE patch for Minimax M2 and qk_norm weight sharding [#10245], directly addressing the feature request for Minimax model support in AutoDeploy.
Out of Scope Changes check ✅ Passed All code changes are directly related to supporting Minimax M2 (MoE patch, qk_norm sharding, RMSNorm sharding, and debug utilities) with corresponding tests, maintaining focus on the stated objective.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@bmarimuthu-nv
Copy link
Collaborator Author

/bot run

@bmarimuthu-nv
Copy link
Collaborator Author

/bot run

@bmarimuthu-nv bmarimuthu-nv force-pushed the bala/minimax-m2 branch 2 times, most recently from 8e80cd3 to d20057f Compare January 17, 2026 01:16
@bmarimuthu-nv
Copy link
Collaborator Author

/bot run

2 similar comments
@bmarimuthu-nv
Copy link
Collaborator Author

/bot run

@bmarimuthu-nv
Copy link
Collaborator Author

/bot run

@bmarimuthu-nv bmarimuthu-nv marked this pull request as ready for review January 21, 2026 23:19
@bmarimuthu-nv bmarimuthu-nv requested a review from a team as a code owner January 21, 2026 23:19
@bmarimuthu-nv bmarimuthu-nv enabled auto-merge (squash) January 21, 2026 23:27
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@tensorrt_llm/_torch/auto_deploy/models/patches/minimax_m2.py`:
- Around line 1-7: Add the required NVIDIA copyright header (with the year of
latest meaningful modification) at the very top of this file, before the
existing module docstring that explains the MiniMax-M2 MoE patch; ensure the
header follows the project's standard header format and remains a top-of-file
comment so the docstring and the patched AutoModelForCausalLM logic (referenced
in the module text) remain unchanged.
🧹 Nitpick comments (6)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)

954-958: Improved diagnostic message for assertion failures.

The expanded assertion message now includes the terminating linear node name and the list of opening linear node names, which will help with debugging when this assertion fails.

Consider also including the name of the linear node being checked (linear_nodes[start_lin_index].name) in the message for completeness, since that's the actual node the assertion is validating.

💡 Optional: Include checked node's name
     assert linear_nodes[start_lin_index] in opening_linear_nodes, (
-        f"Linear node not found in opening linear nodes - "
+        f"Linear node {linear_nodes[start_lin_index].name} not found in opening linear nodes - "
         f"terminating_linear_node:{terminating_linear_node.name}, "
         f"opening_linear_nodes: {[n.name for n in opening_linear_nodes]}"
     )
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_rmsnorm_sharding.py (1)

163-219: LGTM!

The test correctly validates that per-head norm RMSNorm ops are not transformed to sharded variants, which is the expected behavior for GLM-style per-head normalization.

Consider removing or converting the print statement on line 219 to use logging for cleaner test output.

tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_sharded_rmsnorm.py (1)

93-99: Consider using the project's all_gather wrapper for consistency.

Line 99 uses dist.all_gather directly, while the codebase provides a wrapper at tensorrt_llm/_torch/auto_deploy/distributed/common.py:all_gather that handles the default process group. This is likely fine since initialize() is called, but using the wrapper would be more consistent with other code in the project.

tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_minimax_m2_patches.py (1)

63-65: Consider using more specific exception handling.

The broad Exception catch could mask unexpected errors. Consider catching more specific exceptions like ValueError, RuntimeError, or OSError that are typical for model loading failures.

Suggested improvement
-    except Exception as e:
+    except (ValueError, RuntimeError, OSError, KeyError) as e:
         print(f"Error extracting layer: {e}")
         return None
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2)

1963-1978: Edge case: slice dimension validation could be more robust.

The logic at lines 1971-1978 handles the case where slice dimensions don't sum to the weight dimension. However, if fused_weight_dims[-1] > weight_dim, the adjustment on line 1973 could result in a negative value when sum(fused_weight_dims[:-1]) > weight_dim.

Consider adding validation for this edge case:

Suggested validation
             if sum(fused_weight_dims) != weight_dim:
                 if fused_weight_dims[-1] > weight_dim:
-                    fused_weight_dims[-1] = weight_dim - sum(fused_weight_dims[:-1])
+                    adjusted = weight_dim - sum(fused_weight_dims[:-1])
+                    if adjusted <= 0:
+                        ad_logger.warning(
+                            f"Invalid slice dimensions: adjusted last dim would be {adjusted}. Skipping."
+                        )
+                        return
+                    fused_weight_dims[-1] = adjusted

2055-2059: Minor: Use next(iter(...)) for single element access.

Static analysis suggests using next(iter(weight_node.users)) instead of list(weight_node.users)[0] for cleaner single element access.

Suggested change
-        user_node = list(weight_node.users)[0]
+        user_node = next(iter(weight_node.users))

@bmarimuthu-nv
Copy link
Collaborator Author

/bot run

2 similar comments
@bmarimuthu-nv
Copy link
Collaborator Author

/bot run

@bmarimuthu-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #33319 [ run ] triggered by Bot. Commit: 8483d8a

@tensorrt-cicd
Copy link
Collaborator

PR_Github #33319 [ run ] completed with state SUCCESS. Commit: 8483d8a
/LLM/main/L0_MergeRequest_PR pipeline #25724 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@bmarimuthu-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #33623 [ run ] triggered by Bot. Commit: 0656677

@tensorrt-cicd
Copy link
Collaborator

PR_Github #33623 [ run ] completed with state SUCCESS. Commit: 0656677
/LLM/main/L0_MergeRequest_PR pipeline #25938 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@bmarimuthu-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #33793 [ run ] triggered by Bot. Commit: 07a8174

@tensorrt-cicd
Copy link
Collaborator

PR_Github #33793 [ run ] completed with state SUCCESS. Commit: 07a8174
/LLM/main/L0_MergeRequest_PR pipeline #26063 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
@bmarimuthu-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #33893 [ run ] triggered by Bot. Commit: fa6cf7e

@bmarimuthu-nv
Copy link
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Copy link
Collaborator

PR_Github #33911 [ run ] triggered by Bot. Commit: fa6cf7e

@tensorrt-cicd
Copy link
Collaborator

PR_Github #33911 [ run ] completed with state SUCCESS. Commit: fa6cf7e
/LLM/main/L0_MergeRequest_PR pipeline #26152 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.
Pipeline has performance regression cases. Check the performance regression report for details.

@bmarimuthu-nv bmarimuthu-nv merged commit 393c3d2 into NVIDIA:main Jan 28, 2026
5 checks passed
@bmarimuthu-nv bmarimuthu-nv linked an issue Jan 29, 2026 that may be closed by this pull request
1 task
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Feature]: Support Minimax model in AutoDeploy [Feature]: AutoDeploy: support rms_norm for attention sharding heuristic

3 participants