[TRTLLM-9111][feat] provide the uniform test framework to test all MoE backends#11128
xxi-nv merged 2 commits into NVIDIA:main from
Conversation
|
Will update the CI test DB in another PR.
📝 Walkthrough
A capability-checking framework is introduced across MoE backend implementations through an abstract interface method that each backend implements.
Changes
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (warning)
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py (1)
1-1: ⚠️ Potential issue | 🟠 Major — Add NVIDIA copyright header.
Please add the required NVIDIA copyright header (latest modification year) at the top of this source file.
As per coding guidelines, `**/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh,py}`: All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of latest meaningful modification.

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (1)
1-1: ⚠️ Potential issue | 🟠 Major — Add NVIDIA copyright header.
Please add the required NVIDIA copyright header (latest modification year) at the top of this source file.
As per coding guidelines, `**/*.{cpp,cc,cxx,h,hpp,hxx,cu,cuh,py}`: All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of latest meaningful modification.

tests/unittest/_torch/modules/moe/quantize_utils.py (1)
204-244: ⚠️ Potential issue | 🟡 Minor — Ensure the custom SwiGLU path activates when any swiglu param is set.
If only `swiglu_beta`/`swiglu_limit` are provided (without `swiglu_alpha`), the custom activation is currently skipped.
🐛 Suggested fix

```diff
-        self.experts = nn.ModuleList(
+        use_custom_swiglu = any(
+            v is not None for v in (swiglu_alpha, swiglu_beta, swiglu_limit)
+        )
+        self.experts = nn.ModuleList(
@@
-            activation=custom_swiglu if swiglu_alpha is not None else F.silu,
+            activation=custom_swiglu if use_custom_swiglu else F.silu,
```
🤖 Fix all issues with AI agents
In `@tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py`:
- Around line 98-124: The can_implement method in class ConfigurableMoE declares
parameters quant_algo, dtype_activation and gptoss_style but doesn't use them,
triggering ARG003; fix it by explicitly marking them unused inside
ConfigurableMoE.can_implement (e.g., assign them to dummy vars or prefix with
underscores) so linters accept it — update the method body of
ConfigurableMoE.can_implement to reference quant_algo, dtype_activation, and
gptoss_style in a no-op way (e.g., `_ = quant_algo; _ = dtype_activation; _ =
gptoss_style`) or rename the params to
_quant_algo/_dtype_activation/_gptoss_style to silence the lint while keeping
the return behavior unchanged.
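Assuming `can_implement` follows the `(bool, Optional[str])` return contract the other backend comments describe, the underscore-rename variant of this fix might look like the following sketch; the real `ConfigurableMoE` body may differ.

```python
# Hypothetical sketch of the ARG003 fix via underscore-prefixed parameter
# names; the real ConfigurableMoE.can_implement body may differ.
from typing import Optional, Tuple


class ConfigurableMoE:
    @classmethod
    def can_implement(
        cls,
        _quant_algo: Optional[str],        # unused here; kept for interface parity
        _dtype_activation: Optional[str],  # unused here; kept for interface parity
        _gptoss_style: bool = False,       # unused here; kept for interface parity
    ) -> Tuple[bool, Optional[str]]:
        # The configurable wrapper accepts everything at this level and
        # defers real capability checks to the selected backend.
        return True, None
```

Ruff's ARG003 rule ignores arguments whose names start with an underscore, so no `noqa` comment or no-op reference is needed.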
In `@tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py`:
- Line 1: Add the required NVIDIA copyright header to the top of the module file
fused_moe_deepgemm.py: insert the standard NVIDIA copyright block (matching
other TensorRT-LLM source files) including the latest modification year and
license text as the very first lines before any imports or code so the file
complies with the repository rule for **/*.{py} sources.
In `@tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py`:
- Around line 1266-1331: Add the required NVIDIA copyright header at the top of
the file, move the local import "from tensorrt_llm.models.modeling_utils import
QuantAlgo" out of can_implement into module scope using namespace-preserving
import (e.g. "from tensorrt_llm.models import modeling_utils"), then update the
can_implement signature/type hints and all comparisons to reference
modeling_utils.QuantAlgo (change Optional["QuantAlgo"] to
Optional[modeling_utils.QuantAlgo] and replace QuantAlgo.* checks with
modeling_utils.QuantAlgo.*).
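As a generic illustration of the namespace-preserving import pattern this comment asks for — with a stand-in namespace, since the real `modeling_utils` module is not reproduced here — the qualified-comparison style looks like this:

```python
# Stand-in for `from tensorrt_llm.models import modeling_utils`; in the real
# file the import would pull in the actual module rather than this shim.
import enum
import types

modeling_utils = types.SimpleNamespace(
    QuantAlgo=enum.Enum("QuantAlgo", ["FP8", "NVFP4"])
)


def can_implement(quant_algo):
    # All comparisons stay qualified through the module namespace, matching
    # the `modeling_utils.QuantAlgo.*` style the review comment requests.
    if quant_algo in (modeling_utils.QuantAlgo.FP8, modeling_utils.QuantAlgo.NVFP4):
        return True, None
    return False, f"unsupported quant_algo: {quant_algo}"
```

Keeping the qualification makes every use site self-documenting about where `QuantAlgo` comes from, at the cost of slightly longer expressions.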
In `@tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py`:
- Line 1: Add the NVIDIA copyright header (with the latest modification year) to
the top of the module file fused_moe_trtllm_gen.py before any imports (e.g.,
before the existing "import inspect" line); ensure the header matches the
project's required header style used for TensorRT-LLM Python files and includes
the correct year, copyright owner (NVIDIA CORPORATION) and any required
license/boilerplate text.
🧹 Nitpick comments (11)
tensorrt_llm/_torch/modules/fused_moe/interface.py (1)
10-12: Prefer module-level imports for the new logger/QuantAlgo additions.
Keeps module namespaces intact and aligns with the repo import guideline.
♻️ Suggested adjustment

```diff
-from tensorrt_llm.logger import logger
-from tensorrt_llm.models.modeling_utils import QuantAlgo
+import tensorrt_llm.logger as trtllm_logger
+import tensorrt_llm.models.modeling_utils as modeling_utils
@@
-    logger.warning(reason)
+    trtllm_logger.logger.warning(reason)
@@
-    quant_algo: Optional[QuantAlgo],
+    quant_algo: Optional[modeling_utils.QuantAlgo],
```

As per coding guidelines: Always maintain the namespace when importing Python modules, even if only one class or function from a module is used.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (1)
7-10: Prefer module-qualified imports for new utility references.
The new imports should keep their module namespaces (including `_warn_and_return`) to align with project style. Consider module-qualified access instead of `from ... import ...`.
♻️ Suggested refactor (namespace imports)

```diff
-from tensorrt_llm._utils import get_sm_version, nvtx_range
-from tensorrt_llm.models.modeling_utils import QuantAlgo
+import tensorrt_llm._utils as trt_utils
+import tensorrt_llm.models.modeling_utils as modeling_utils
+from . import interface as moe_interface
@@
-@nvtx_range("[DG] preprocess_after_permute")
+@trt_utils.nvtx_range("[DG] preprocess_after_permute")
@@
-@nvtx_range("[DG]")
+@trt_utils.nvtx_range("[DG]")
@@
-@nvtx_range("[DG] forward")
+@trt_utils.nvtx_range("[DG] forward")
@@
-    quant_algo: Optional[QuantAlgo],
+    quant_algo: Optional[modeling_utils.QuantAlgo],
@@
-        from .interface import _warn_and_return
-
-        sm_version = get_sm_version()
+        sm_version = trt_utils.get_sm_version()
@@
-        return _warn_and_return(
+        return moe_interface._warn_and_return(
@@
-        if quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
+        if quant_algo == modeling_utils.QuantAlgo.FP8_BLOCK_SCALES:
             return True, None
-        return _warn_and_return(
+        return moe_interface._warn_and_return(
```

As per coding guidelines: Always maintain the namespace when importing Python modules, even if only one class or function from a module is used.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py (1)
4-13: Prefer module-qualified imports for new utility references.
New imports should retain module namespaces per project style. Consider module-qualified access for `_utils` and `modeling_utils`.
♻️ Suggested refactor (namespace imports)

```diff
-from tensorrt_llm._utils import get_sm_version
-from tensorrt_llm.models.modeling_utils import QuantAlgo
+import tensorrt_llm._utils as trt_utils
+import tensorrt_llm.models.modeling_utils as modeling_utils
@@
-    _SUPPORTED_QUANT_ALGOS = {
-        QuantAlgo.NVFP4,
+    _SUPPORTED_QUANT_ALGOS = {
+        modeling_utils.QuantAlgo.NVFP4,
@@
-        sm_version = get_sm_version()
+        sm_version = trt_utils.get_sm_version()
@@
-    quant_algo: Optional[QuantAlgo],
+    quant_algo: Optional[modeling_utils.QuantAlgo],
@@
-        sm_version = get_sm_version()
+        sm_version = trt_utils.get_sm_version()
```

As per coding guidelines: Always maintain the namespace when importing Python modules, even if only one class or function from a module is used.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py (1)
7-9: Prefer module-qualified imports for new utility references.
To align with namespace-import guidance, consider module-qualified access for `_utils` and `modeling_utils`.
♻️ Suggested refactor (namespace imports)

```diff
-from tensorrt_llm._utils import get_sm_version, is_sm_100f
-from tensorrt_llm.models.modeling_utils import QuantAlgo
+import tensorrt_llm._utils as trt_utils
+import tensorrt_llm.models.modeling_utils as modeling_utils
@@
-    if is_sm_100f():
+    if trt_utils.is_sm_100f():
@@
-    quant_algo: Optional[QuantAlgo],
+    quant_algo: Optional[modeling_utils.QuantAlgo],
@@
-        sm_version = get_sm_version()
+        sm_version = trt_utils.get_sm_version()
@@
-        if quant_algo == QuantAlgo.NVFP4:
+        if quant_algo == modeling_utils.QuantAlgo.NVFP4:
```

As per coding guidelines: Always maintain the namespace when importing Python modules, even if only one class or function from a module is used.
tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py (1)
10-12: Prefer module-qualified imports for new utility references.
To follow the namespace-import guideline, consider module-qualified access for `_utils` and `modeling_utils`, and update the table references accordingly.
♻️ Suggested refactor (namespace imports)

```diff
-from tensorrt_llm._utils import get_sm_version
-from tensorrt_llm.models.modeling_utils import QuantAlgo
+import tensorrt_llm._utils as trt_utils
+import tensorrt_llm.models.modeling_utils as modeling_utils
@@
-    QuantAlgo.FP8: {
+    modeling_utils.QuantAlgo.FP8: {
@@
-        sm_version = get_sm_version()
+        sm_version = trt_utils.get_sm_version()
```

As per coding guidelines: Always maintain the namespace when importing Python modules, even if only one class or function from a module is used.
tests/unittest/_torch/modules/moe/quantize_utils.py (1)
21-27: Prefer module-qualified helper imports.
For consistency with project style, import helper modules and qualify usages (apply the same change across all call sites).
♻️ Suggested refactor (namespace imports)

```diff
-from _torch.helpers import (
-    calc_woq_tolerence,
-    per_block_cast_to_fp8,
-    per_block_cast_to_fp8_e8m0,
-    per_token_cast_to_fp8_e8m0,
-)
-from utils.util import check_accuracy
+import _torch.helpers as torch_helpers
+import utils.util as util
@@
-    check_accuracy(output, ref_output, rtol=2e-1, atol=2e-1, percent=0.96)
+    util.check_accuracy(output, ref_output, rtol=2e-1, atol=2e-1, percent=0.96)
@@
-    quant_fn = per_block_cast_to_fp8_e8m0 if use_e8m0_scale else per_block_cast_to_fp8
+    quant_fn = (torch_helpers.per_block_cast_to_fp8_e8m0
+                if use_e8m0_scale else torch_helpers.per_block_cast_to_fp8)
@@
-    act_fp8, act_sf = per_token_cast_to_fp8_e8m0(permuted_data)
+    act_fp8, act_sf = torch_helpers.per_token_cast_to_fp8_e8m0(permuted_data)
@@
-    atol = calc_woq_tolerence(ref_output, weight_dtype)
+    atol = torch_helpers.calc_woq_tolerence(ref_output, weight_dtype)
```

As per coding guidelines: Always maintain the namespace when importing Python modules, even if only one class or function from a module is used.
tests/unittest/_torch/modules/moe/test_moe_backend.py (5)
242-246: Unused `quant_algo` parameter.
The `quant_algo` parameter is declared but never used in this function. If it's included for API consistency with other `should_skip_*` functions or reserved for future use, consider prefixing it with an underscore to signal intent:

```diff
 def should_skip_gptoss(
     backend_type: MoeBackendType,
-    quant_algo: Optional[QuantAlgo],
+    _quant_algo: Optional[QuantAlgo],
     gptoss_style: bool,
 ) -> Optional[str]:
```
275-278: Unused `quant_algo` parameter.
Similar to `should_skip_gptoss`, the `quant_algo` parameter is included but not used. Consider prefixing it with an underscore, or removing it if not needed for API consistency:

```diff
 def supports_autotuner_capture(
     backend_type: MoeBackendType,
-    quant_algo: Optional[QuantAlgo],
+    _quant_algo: Optional[QuantAlgo],
 ) -> bool:
```
902-902: Consider using `tempfile` for the cache path.
The hardcoded `/tmp/moe_autotuner_cache.json` path could cause issues in multi-user environments or parallel test runs. Consider using `tempfile` for safer temporary file handling:
🛡️ Proposed fix

```diff
+import tempfile
+import os
 ...
-    with torch.inference_mode(), autotune(cache_path="/tmp/moe_autotuner_cache.json"):
+    # Use a unique temp file to avoid conflicts in parallel test runs
+    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
+        cache_path = f.name
+    try:
+        with torch.inference_mode(), autotune(cache_path=cache_path):
+            _ = run_moe()
+    finally:
+        if os.path.exists(cache_path):
+            os.unlink(cache_path)
```

Alternatively, if the cache file is intentionally shared across test runs for performance, document this in a comment.
894-898: AutoTuner state modifications may leak to other tests.
The test modifies `AutoTuner` singleton state (`warmup`, `repeat`, `stream_delay_micro_secs`) without restoring the original values. If tests run in the same process, this could affect subsequent tests. Consider saving and restoring the original values:
🛠️ Proposed fix

```diff
     # Configure AutoTuner for faster profiling (reduce warmup/repeat for unit tests)
     autotuner = AutoTuner.get()
+    original_warmup = autotuner.warmup
+    original_repeat = autotuner.repeat
+    original_stream_delay = autotuner.stream_delay_micro_secs
     autotuner.warmup = 0  # default: 2
     autotuner.repeat = 1  # default: 10
     autotuner.stream_delay_micro_secs = 10  # default: 1000
+
+    try:
+        # ... rest of the test ...
+    finally:
+        # Restore original AutoTuner state
+        autotuner.warmup = original_warmup
+        autotuner.repeat = original_repeat
+        autotuner.stream_delay_micro_secs = original_stream_delay
```

Alternatively, if the test is always skipped (as indicated by the skip marker), this may be a non-issue, but it's good practice for when the skip is removed.
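The save/restore idea above generalizes: a small context manager (a generic sketch, not TRT-LLM's actual API) can patch singleton attributes and guarantee restoration even when the test body raises.

```python
# Generic sketch of the save/restore pattern for singleton state; FakeTuner
# is a hypothetical stand-in for an object like AutoTuner.get().
from contextlib import contextmanager


@contextmanager
def patched_attrs(obj, **overrides):
    # Snapshot current values, apply overrides, and always restore.
    saved = {name: getattr(obj, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(obj, name, value)
        yield obj
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)


class FakeTuner:  # stand-in for the tuner singleton
    warmup = 2
    repeat = 10


with patched_attrs(FakeTuner, warmup=0, repeat=1):
    assert FakeTuner.warmup == 0 and FakeTuner.repeat == 1
assert FakeTuner.warmup == 2 and FakeTuner.repeat == 10  # restored
```

The `finally` clause makes restoration unconditional, which is exactly what the explicit try/finally in the proposed fix achieves.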
790-792: Direct assignment to `mapping.rank` after construction.
Assigning `mapping.rank = mpi_rank()` directly after creating a `Mapping()` object works but is unusual. Consider passing the rank during construction if the `Mapping` class supports it:

```python
mapping = Mapping(rank=mpi_rank())
```

If the class doesn't support this, the current approach is fine.
/bot run --disable-fail-fast
PR_Github #34161 [ run ] triggered by Bot. Commit:
PR_Github #34161 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #34374 [ run ] triggered by Bot. Commit:
…E backends Signed-off-by: xxi <xxi@nvidia.com>
/bot run --disable-fail-fast
PR_Github #34377 [ run ] triggered by Bot. Commit:
PR_Github #34377 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #34464 [ run ] triggered by Bot. Commit:
PR_Github #34464 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #34505 [ run ] triggered by Bot. Commit:
PR_Github #34505 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #34575 [ run ] triggered by Bot. Commit:
PR_Github #34575 [ run ] completed with state
…E backends (NVIDIA#11128) Signed-off-by: xxi <xxi@nvidia.com>
Description
Summary
This PR introduces a unified test framework for MoE (Mixture of Experts) backends that enables systematic testing of all backend implementations through their backend-level interfaces (`quantize_input` + `run_moe`), rather than the high-level `forward()` interface, which will be deprecated in the future.
Key Changes
Added a standardized can_implement() classmethod to all MoE backend classes:
Each implementation checks:
Implemented a new test module that provides:
Pre-computed test parameters at module load time for fast test collection
Test coverage for:
Extended quantization utilities to support the test framework with additional helper classes and methods.
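The `can_implement()` contract described above can be sketched as follows; the backend name, algo strings, and SM threshold here are illustrative assumptions, not the PR's exact code.

```python
# Illustrative backend using the (supported, reason) capability contract;
# DemoMoEBackend and its support table are hypothetical.
from typing import Optional, Tuple


class DemoMoEBackend:
    _SUPPORTED_QUANT_ALGOS = {"FP8", "NVFP4"}  # hypothetical support table

    @classmethod
    def can_implement(
        cls,
        quant_algo: Optional[str],
        sm_version: int,
    ) -> Tuple[bool, Optional[str]]:
        # Return (True, None) when the backend can run this configuration,
        # otherwise (False, reason) so the test framework can skip the case
        # with a human-readable message instead of failing.
        if sm_version < 90:
            return False, f"requires SM90+, got SM{sm_version}"
        if quant_algo is not None and quant_algo not in cls._SUPPORTED_QUANT_ALGOS:
            return False, f"unsupported quant_algo: {quant_algo}"
        return True, None
```

A test harness can then call `can_implement()` for every (backend, quant config) pair at collection time and turn a `False` result into a pytest skip carrying the returned reason.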
Design Goals
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...` — Provides a user-friendly way for developers to interact with a Jenkins server.
Run `/bot [-h|--help]` to print this help message. See details below for each supported subcommand.
Details
`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--reuse-test (optional)pipeline-id` (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages, and sanity check stages. Note: Does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- `--detailed-log` (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see `docs/source/reference/ci-overview.md` and the `scripts/test_to_stage_mapping.py` helper.

kill

`kill` — Kill all running builds associated with the pull request.

skip

`skip --comment COMMENT` — Skip testing for the latest commit on the pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

`reuse-pipeline` — Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
Summary by CodeRabbit
New Features
Tests