[TRTLLM-10022][feat] Add hopper xqa decode support for skip softmax attention #10264

Merged
pengbowang-nv merged 11 commits into NVIDIA:main from pengbowang-nv:dev-add-hopper-xqa-skip-softmax
Jan 12, 2026
Conversation

@pengbowang-nv pengbowang-nv commented Dec 24, 2025

Summary by CodeRabbit

Release Notes

  • New Features

    • Added softmax attention skipping optimization for multi-head attention kernels with configurable threshold control.
    • Optional block-level statistics collection for skipped attention operations.
    • Environment variable support for tuning subsequence configuration.
  • Bug Fixes

    • Added data alignment validation in kernel operations.
  • Refactor

    • Enhanced data type handling to support multiple KV cache formats in tensor operations.


Description

Added skip-softmax attention to the XQA Hopper-style kernel, and enabled its use in TensorRT-LLM.

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in the PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only [pytorch, cpp, tensorrt, triton] are supported. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause the top of tree to break.

@pengbowang-nv pengbowang-nv marked this pull request as ready for review December 24, 2025 08:01

coderabbitai bot commented Dec 24, 2025

📝 Walkthrough

Walkthrough

This PR introduces a skip-softmax attention optimization mechanism across the XQA kernel infrastructure. New preprocessor configuration flags and helper functions enable conditional softmax skipping based on thresholds, with optional block-level statistics tracking. The feature gates kernel behavior via compile-time and runtime checks integrated into host-side launch functions and the SM90/Hopper kernel implementation.

Changes

Cohort / File(s) / Summary
Kernel Configuration Defaults
cpp/kernels/xqa/defines.h
Adds three preprocessor configuration macros with default values to control skip-softmax behavior: SKIP_SOFTMAX_ATTN, SKIP_SOFTMAX_ATTN_BLOCK_STATS, and SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE.
Memory Alignment & Data Structures
cpp/kernels/xqa/gmma.cuh, cpp/kernels/xqa/mha_stdheaders.cuh
Enforces data alignment validation in makeMatDesc; introduces public pair<T1, T2> template for namespace mha with conditional compilation for CUBIN vs. standard paths.
Host-Side API & Helper Functions
cpp/kernels/xqa/mha.h, cpp/kernels/xqa/mha.cu
Adds computeNbSubSeqPerSeqMHA and computeNbSubSeqPerSeqHopperF8MHA helper functions; expands launchMHA and launchHopperF8MHA signatures to accept skipSoftmaxThresholdScaleFactor and optional block statistics pointers, gated by SKIP_SOFTMAX_ATTN flags.
SM90/Hopper Kernel Implementation
cpp/kernels/xqa/mha_sm90.cu
Core kernel modifications: adds skip-softmax-specific shared memory buffers (nbKBuf, nbVBuf, nbXBuf), synchronization barriers (xBar, skipSoftmaxXBar), voting arrays, threshold computation, and optional atomic statistics collection. Extends kernel signature to accept skipSoftmaxThresholdScaleFactor and block count pointers.
Reference Implementation & Testing
cpp/kernels/xqa/test/refAttention.h, cpp/kernels/xqa/test/refAttention.cpp, cpp/kernels/xqa/test/test.cpp
Extends refFlashAttention and runTest signatures with skipSoftmaxThresholdScaleFactor, block counters, and multiBlockNum parameters; implements skip-threshold logic in reference path and adds diagnostic output for skipped block statistics.
XQA Parameter Structure
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h
Adds conditional members skip_softmax_total_blocks and skip_softmax_skipped_blocks under SKIP_SOFTMAX_STAT flag; updates toString() to include these fields.
TensorRT-LLM Integration - Parameter Conversion & Compilation
cpp/tensorrt_llm/common/attentionOp.cpp, cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/compileEngine.cpp, cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/include/nvrtcWrapper.h, cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp
Adds use_skip_softmax_attn field to JIT context struct; conditionally populates XQA parameters with skip-softmax statistics; generates SKIP_SOFTMAX_ATTN and SKIP_SOFTMAX_ATTN_BLOCK_STATS NVRTC macros based on context flag.
JIT Kernel Wrapper & Launch
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp
Derives isSkipSoftmax flag from parameter threshold; enforces GMMA-only constraint; appends skip-softmax parameters and optional statistics to kernel launch call in both spec-dec and standard paths.
Kernel Compatibility & Configuration
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp
Disables HMMA and MLA kernels when skip-softmax is enabled; relaxes KV-cache type restrictions for QGMMA to accept FP16/BF16/E4M3 when skip-softmax is active.
Precompiled Kernel Guard
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp
Adds early exit condition in shouldUse() to reject precompiled kernel path when skip_softmax_threshold_scale_factor is non-zero.
Tensor Map Data Type Handling
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/tensorMapUtils.cpp
Introduces getDataTypeFromXqaParams() helper to map KV-cache data types; replaces hardcoded CU_TENSOR_MAP_DATA_TYPE_UINT8 with dynamic type derivation across K/V/Q tensor map creation functions.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 13.33%, below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check ❓ Inconclusive — The PR description is minimal and does not adequately document the scope of changes. The template sections for Test Coverage are empty, and specific test cases are not listed. Complete the Test Coverage section by specifying which tests safeguard the skip-softmax attention feature changes, and provide more details about the implementation approach and impact.
✅ Passed checks (1 passed)
  • Title check ✅ Passed — The PR title clearly and specifically summarizes the main change: adding Hopper XQA decode support for skip softmax attention, with proper JIRA ticket reference and feature type.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp (1)

1-15: Update the copyright year to include 2025.

The copyright header shows 2020-2023, but this file has meaningful modifications in 2025. As per coding guidelines, the copyright should reflect the year of latest modification.

-* Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+* Copyright (c) 2020-2025, NVIDIA CORPORATION.  All rights reserved.
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp (1)

235-388: Kernel arg buffer too small when skip-softmax stats are enabled

The new skip-softmax wiring (isSkipSoftmax + extra appendParam calls) can exceed the fixed kMAX_NB_KERNEL_PARAMS = 16:

  • In GMMA paths with:
    • RoPE fused (applyRoPEInXqaKernel == true),
    • isFp8Out && !needOutputCvt,
    • beam_width > 1,
    • skip-softmax enabled (isSkipSoftmax),
    • and SKIP_SOFTMAX_STAT defined,

the non-spec-dec GMMA branch now appends 18 parameters (including stats, semaphores, scratch). The spec-dec GMMA branch with skip-softmax+stats also reaches 18.

Since appendParam asserts idxNextParam < kMAX_NB_KERNEL_PARAMS, those configurations will trip the check at runtime once idxNextParam reaches 16.

Consider bumping kMAX_NB_KERNEL_PARAMS to something safely above the maximum (e.g., 24) so future additions don’t regress, and keep the assertion as a safeguard.

The GMMA-only guard for skip-softmax (TLLM_CHECK_WITH_INFO(isGMMAKernel, ...)) looks correct otherwise.

cpp/kernels/xqa/test/test.cpp (1)

845-915: Guard DRAM “with-skip” estimate against zero total block count

The DRAM estimate with skip:

size_t const totalNbCacheLoadWithSkip = gmemCacheHeadBytes
    * (nbKHeads + nbVHeads * (1 - 1.0f * kernelSkippedBlockCount[0] / kernelTotalBlockCount[0]))
    * nbLoadedCacheTokens;
...
float const dramSolRatioWithSkip = dramSolTimeWithSkip / ms;

assumes kernelTotalBlockCount[0] > 0. In tests where skip-softmax is enabled but no blocks qualify for skipping (or stats are misconfigured), kernelTotalBlockCount[0] can be zero, leading to a 0/0 and NaN/Inf ratios.

You already guard the printed percentage with kernelTotalBlockCount[0] == 0 ? 0.0f : ...; consider adding a similar guard before computing totalNbCacheLoadWithSkip (e.g., early‑out or treat “no stats” as “no skip”) to avoid NaNs in dramSolRatioWithSkip.

🧹 Nitpick comments (8)
cpp/kernels/xqa/test/test.cpp (2)

339-349: Macro gating around skip-softmax threshold and kernel stats is sound

  • Under #if SKIP_SOFTMAX_ATTN_BLOCK_STATS, the kernel-side counters are conditionally allocated and zeroed.
  • When SKIP_SOFTMAX_ATTN is not enabled, you explicitly expect skipSoftmaxThresholdScaleFactor == 0.0f via EXPECT_EQ, which avoids silently ignoring a nonzero threshold.

This keeps the test harness behavior aligned with the compile-time configuration and prevents misuse when the feature is compiled out.


1129-1190: Reference skip-softmax wiring and host/kernel skip counters look consistent

  • The reference branch under if (useQGMMA) now calls refFlashAttention<CacheElem, 64> with:
    • skipSoftmaxThresholdScaleFactor
    • addresses of skippedBlockCount / totalBlockCount
    • multiBlockNum

which matches the updated refFlashAttention signature.

  • Final prints at the end of refCheck:

    • Always emit host skippedBlockCount / totalBlockCount (with a zero-guard).
    • Optionally emit kernel block stats under SKIP_SOFTMAX_ATTN_BLOCK_STATS.

This gives a clear comparison of host vs kernel skip behavior. No issues here; just be aware that host counters aggregate over all heads and requests in the run, which is fine for coarse diagnostics.

cpp/kernels/xqa/mha.cu (2)

2736-2754: computeNbSubSeqPerSeqMHA logic matches intended multi-block behavior

The new helper:

  • Respects allowMultiBlockMode (returns 1 when disabled).
  • Allows an explicit override via XQA_NB_SUB_SEQ if set to a positive value.
  • Otherwise picks nbSubSeqPerSeq proportional to prop.multiProcessorCount / (batchSize * nbKHeads), clamped between 1 and divUp(maxSeqLen, ctaTile.x).

This is a sensible policy and centralizes the logic that was previously inlined, making it easier to share with tests.


2756-2892: launchMHA signature extended for skip-softmax compatibility; ensure includes and callers are aligned

  • launchMHA now accepts additional arguments under #if SKIP_SOFTMAX_ATTN (and SKIP_SOFTMAX_ATTN_BLOCK_STATS) purely for compatibility with launchHopperF8MHA; they are not used in this TU, which is fine.
  • nbSubSeqPerSeq is now obtained via computeNbSubSeqPerSeqMHA, keeping host configuration consistent with the kernel’s grid logic.

Two follow-ups to consider:

  1. Includes: computeNbSubSeqPerSeqMHA uses std::getenv and std::stoi. If those are not already provided transitively by hostUtils.h or other headers in this TU, you’ll want to add the appropriate standard headers (<cstdlib>, <string>) here to avoid relying on non-portable transitive includes.

  2. Callers: Make sure every non-test caller of launchMHA is compiled with consistent SKIP_SOFTMAX_ATTN / SKIP_SOFTMAX_ATTN_BLOCK_STATS values and has been updated to pass the new arguments where required (or is built with the feature disabled so the extra parameters are compiled out).

cpp/kernels/xqa/test/refAttention.cpp (1)

51-75: Skip-softmax reference logic is reasonable but missing a direct include

The extended refFlashAttention:

  • Computes an nbSubSeq based on multiBlockNum and nbTiles, mirroring the multi-block behavior.
  • Maintains per-subsequence skipRowMaxs and uses a simple threshold test on localRowMax - prevSkipRowMax to decide whether to skip a tile, updating skippedBlockCount / totalBlockCount accordingly.
  • Only evaluates log(skipSoftmaxThreshold) when skipSoftmaxThreshold > 0, so it avoids domain errors.

Two small nits:

  • This TU now uses std::vector but does not include <vector> explicitly (neither here nor in refAttention.h). It would be safer to add #include <vector> to avoid relying on transitive includes.
  • The new <cstdio> include appears unused in this function; consider removing it if nothing else in the file needs it.
cpp/kernels/xqa/mha_sm90.cu (3)

1010-1023: Memory fence is essential, not optional.

The comment // maybe not used on line 1020 is misleading. The fence.proxy.async.shared::cta is required here to ensure the write to skipSoftmaxVotesGemm0ToGemm1[idxXBuf] (line 1015) is visible to the Gemm1 warp group before they read it after xBar.produced.arrive(). Without this fence, there's a potential data race.

Consider updating the comment to clarify its purpose:

-                asm volatile("fence.proxy.async.shared::cta;\n"); // maybe not used
+                asm volatile("fence.proxy.async.shared::cta;\n"); // ensures vote visibility before xBar signal

3486-3490: Inconsistent parameter naming convention.

The parameters skipped_block_count and total_block_count use snake_case, but per the coding guidelines, local variables and parameters in C++ should use camelCase. The kernel parameters on lines 708-709 correctly use skippedBlockCount and totalBlockCount.

🔎 Suggested fix for consistency
 #if SKIP_SOFTMAX_ATTN
     float const skipSoftmaxThresholdScaleFactor,
 #if SKIP_SOFTMAX_ATTN_BLOCK_STATS
-    uint32_t* __restrict__ skipped_block_count, uint32_t* __restrict__ total_block_count,
+    uint32_t* __restrict__ skippedBlockCount, uint32_t* __restrict__ totalBlockCount,
 #endif
 #endif

And update the corresponding kernel launch call on lines 3584-3585.

As per coding guidelines, local variables and parameters should use camelCase.


1673-1673: TODO acknowledged: Multiblock mode not yet integrated with skip-softmax.

The comment correctly flags that multiblock mode needs additional work for skip-softmax. Ensure this is tracked for follow-up.

Would you like me to open an issue to track the multiblock mode integration with skip-softmax attention?

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ecea71c and dbbd4c3.

📒 Files selected for processing (18)
  • cpp/kernels/xqa/defines.h
  • cpp/kernels/xqa/gmma.cuh
  • cpp/kernels/xqa/mha.cu
  • cpp/kernels/xqa/mha.h
  • cpp/kernels/xqa/mha_sm90.cu
  • cpp/kernels/xqa/mha_stdheaders.cuh
  • cpp/kernels/xqa/test/refAttention.cpp
  • cpp/kernels/xqa/test/refAttention.h
  • cpp/kernels/xqa/test/test.cpp
  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/compileEngine.cpp
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/include/nvrtcWrapper.h
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/tensorMapUtils.cpp
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{cpp,h,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{cpp,h,cu,cuh}: Closing braces of namespaces should have a comment saying the namespace it closes: } // namespace foo
Prefer const or constexpr variables over #define whenever possible, as the latter are not visible to the compiler
A variable that is not modified after its initialization should be declared as const
For naming of constants in C++, follow the naming section conventions
Except 0 (only used in comparison for checking signness/existence/emptiness) and nullptr, true, false, all other literals should only be used for variable initialization in C++
Use the Allman indentation style in C++
Put the semicolon for an empty for or while loop in a new line in C++
The statement forming the body of a switch, while, do .. while or for statement shall be a compound statement (use brace-delimited statements) in C++
If and else should always be followed by brace-delimited statements, even if empty or a single statement in C++
C++ filenames should use camel case with first letter lowercase: thisIsASubDir and thisIsAFilename.cpp
All files involved in the compilation of a compilation target (.exe/.so) must have filenames that are case-insensitive unique in C++
All types (including class names) in C++ should use camel case with uppercase first letter: FooBarClass
Local variables, methods and namespaces in C++ should use camel case with first letter lowercase: localFooBar
Non-magic-number global variables that are non-static and not defined in anonymous namespace in C++ should use camel case prefixed by a lower case 'g': gDontUseGlobalFoos
Non-magic-number global variables that are static or defined in an anonymous namespace in C++ should use camel case prefixed by a lower case 's': sMutableStaticGlobal
Locally visible static variables in C++ should use camel case with lowercase prefix 's' as the first letter: static std::once_flag sFlag;
Public, private and protected class member variables in C++ should use camel case prefi...

Files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/include/nvrtcWrapper.h
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp
  • cpp/kernels/xqa/defines.h
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp
  • cpp/kernels/xqa/mha_stdheaders.cuh
  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/tensorMapUtils.cpp
  • cpp/kernels/xqa/mha.h
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp
  • cpp/kernels/xqa/test/refAttention.cpp
  • cpp/kernels/xqa/mha.cu
  • cpp/kernels/xqa/gmma.cuh
  • cpp/kernels/xqa/test/refAttention.h
  • cpp/kernels/xqa/mha_sm90.cu
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/compileEngine.cpp
  • cpp/kernels/xqa/test/test.cpp
**/*.h

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.h: Use a preprocessor guard in C++ header files with the format TRTLLM_<FILENAME>_H derived from the filename in all caps
The preprocessor guard name in C++ must have prefix TRTLLM_ followed by the filename, all in caps. Only use the file name, not directory names
Do not use prefix with underscore in C++ preprocessor guard symbols as such symbols are reserved in C++ standard for compilers or implementation
Do not use trailing underscore in C++ preprocessor guard symbols (unlike Google C++ guideline)

Files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/include/nvrtcWrapper.h
  • cpp/kernels/xqa/defines.h
  • cpp/kernels/xqa/mha.h
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h
  • cpp/kernels/xqa/test/refAttention.h
**/*.{cpp,h,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification

Files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/include/nvrtcWrapper.h
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp
  • cpp/kernels/xqa/defines.h
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp
  • cpp/kernels/xqa/mha_stdheaders.cuh
  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/tensorMapUtils.cpp
  • cpp/kernels/xqa/mha.h
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp
  • cpp/kernels/xqa/test/refAttention.cpp
  • cpp/kernels/xqa/mha.cu
  • cpp/kernels/xqa/gmma.cuh
  • cpp/kernels/xqa/test/refAttention.h
  • cpp/kernels/xqa/mha_sm90.cu
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/compileEngine.cpp
  • cpp/kernels/xqa/test/test.cpp
🧠 Learnings (19)
📓 Common learnings
Learnt from: thorjohnsen
Repo: NVIDIA/TensorRT-LLM PR: 6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.
📚 Learning: 2025-08-14T15:38:01.771Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: cpp/tensorrt_llm/pybind/thop/bindings.cpp:55-57
Timestamp: 2025-08-14T15:38:01.771Z
Learning: In TensorRT-LLM Python bindings, tensor parameter collections like mla_tensor_params and spec_decoding_tensor_params are kept as required parameters without defaults to maintain API consistency, even when it might affect backward compatibility.

Applied to files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp
📚 Learning: 2025-08-14T21:04:50.248Z
Learnt from: thorjohnsen
Repo: NVIDIA/TensorRT-LLM PR: 6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.

Applied to files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp
  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/kernels/xqa/test/refAttention.cpp
  • cpp/kernels/xqa/test/refAttention.h
📚 Learning: 2025-08-15T06:46:54.897Z
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6767
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-15T06:46:54.897Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp addToken function, newly allocated blocks are unshared by design. The beam search path in addToken (when sequence.getNumTokens() > windowSize) is currently broken/non-functional with SWA, so the block allocation doesn't follow a shared-then-unshared pattern.

Applied to files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp
  • cpp/kernels/xqa/mha_sm90.cu
  • cpp/kernels/xqa/test/test.cpp
📚 Learning: 2025-12-19T06:31:54.973Z
Learnt from: nvyocox
Repo: NVIDIA/TensorRT-LLM PR: 10117
File: tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_attention.py:336-339
Timestamp: 2025-12-19T06:31:54.973Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_attention.py, the cast to torch.float16 for qkv_node before creating the AttentionPlugin is intentional and required because DriveOS LLM expects float16 dtype specifically. This should not be changed to preserve original dtype or made configurable for bfloat16 models in the DriveOS LLM ONNX export path.

Applied to files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp
  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/tensorMapUtils.cpp
📚 Learning: 2025-09-23T15:01:00.070Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/nccl_device/config.cu:15-17
Timestamp: 2025-09-23T15:01:00.070Z
Learning: In TensorRT-LLM NCCL device kernels, the <sstream> header is not needed as an explicit include in config.cu because it's provided transitively through other headers. Local compilation testing confirms this works without the explicit include.

Applied to files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp
📚 Learning: 2025-12-17T22:39:44.244Z
Learnt from: CR
Repo: NVIDIA/TensorRT-LLM PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-12-17T22:39:44.244Z
Learning: Applies to **/*.{cpp,h,cu,cuh} : Use `#if` / `#endif` to disable C++ code, preferably with a mnemonic condition

Applied to files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp
📚 Learning: 2025-12-17T22:39:44.244Z
Learnt from: CR
Repo: NVIDIA/TensorRT-LLM PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-12-17T22:39:44.244Z
Learning: Applies to **/*.{cpp,h,cu,cuh} : Macros in C++ should follow uppercase snake_case: `FOO_VERSION`

Applied to files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp
📚 Learning: 2025-12-17T22:39:44.244Z
Learnt from: CR
Repo: NVIDIA/TensorRT-LLM PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-12-17T22:39:44.244Z
Learning: Applies to **/*.{cpp,h,cu,cuh} : `#define` and `#undef` of macros in C++ should be done only at global namespace

Applied to files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp
📚 Learning: 2025-12-17T22:39:44.244Z
Learnt from: CR
Repo: NVIDIA/TensorRT-LLM PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-12-17T22:39:44.244Z
Learning: Applies to **/*.{cpp,h,cu,cuh} : Prefer `const` or `constexpr` variables over `#define` whenever possible, as the latter are not visible to the compiler

Applied to files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp
📚 Learning: 2025-12-17T22:39:44.244Z
Learnt from: CR
Repo: NVIDIA/TensorRT-LLM PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-12-17T22:39:44.244Z
Learning: Applies to **/*.{cpp,h,cu,cuh} : Avoid the use of `#ifdef` and `#ifndef` directives in C++ (except in header include guards). Prefer `#if defined(...)` or `#if !defined(...)` instead

Applied to files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp
📚 Learning: 2025-08-20T06:56:02.889Z
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:577-579
Timestamp: 2025-08-20T06:56:02.889Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, maxSequenceLength is now enforced as a non-optional argument in the BlockManager constructor, so concerns about std::nullopt defaulting to 0 are not applicable. When windowSize > maxSequenceLength, a warning should be added instead of handling optional parameter cases.

Applied to files:

  • cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp
  • cpp/kernels/xqa/test/test.cpp
📚 Learning: 2025-08-08T22:03:40.707Z
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1198-1209
Timestamp: 2025-08-08T22:03:40.707Z
Learning: In the CUTLASS MoE kernels (cpp/tensorrt_llm/cutlass_extensions), when `layout_info.fusion` is set to `TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE`, the `router_scales` parameter must be non-null by design. The fused finalize kernel epilogue does not perform nullptr checks and requires valid router scales to function correctly. This is an implicit contract that callers must satisfy when enabling the FINALIZE fusion mode.

Applied to files:

  • cpp/kernels/xqa/gmma.cuh
📚 Learning: 2025-08-19T03:35:20.866Z
Learnt from: djns99
Repo: NVIDIA/TensorRT-LLM PR: 6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4616-4626
Timestamp: 2025-08-19T03:35:20.866Z
Learning: In the MOE profiler TMA workspace preparation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu), the overlapping of TMA WS regions for NONE and FINALIZE variants is deliberate design to save memory space, as confirmed by djns99. The comment "reuse the same pointers to save space" reflects this intentional behavior.

Applied to files:

  • cpp/kernels/xqa/mha_sm90.cu
📚 Learning: 2025-09-22T19:25:45.607Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/userbuffers/ub_allocator.cpp:170-179
Timestamp: 2025-09-22T19:25:45.607Z
Learning: In NCCLUserBufferAllocator::getNCCLDevComm(), multimem support is hard-coded to true because multimem is required for this function. The caller is responsible for ensuring multimem is available before calling this function - it should not be called if multimem is not supported.

Applied to files:

  • cpp/kernels/xqa/mha_sm90.cu
📚 Learning: 2025-08-08T05:10:38.906Z
Learnt from: sklevtsov-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 3294
File: cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp:0-0
Timestamp: 2025-08-08T05:10:38.906Z
Learning: The ScaledAccPerRowBiasPerColScaleScatter fusion in CUTLASS extensions (cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp) is specifically designed for per-column scaling factors only, so it uses a fixed Stride<_0,_1,int64_t> rather than conditional stride logic.

Applied to files:

  • cpp/kernels/xqa/mha_sm90.cu
📚 Learning: 2025-09-23T14:58:05.372Z
Learnt from: nv-lschneider
Repo: NVIDIA/TensorRT-LLM PR: 7910
File: cpp/tensorrt_llm/kernels/nccl_device/config.cu:42-49
Timestamp: 2025-09-23T14:58:05.372Z
Learning: In TensorRT-LLM NCCL device kernels (cpp/tensorrt_llm/kernels/nccl_device/), the token partitioning intentionally uses ceil-like distribution (same token_per_rank for all ranks) to ensure all ranks launch the same number of blocks. This is required for optimal NCCL device API barrier performance, even though it may launch extra blocks for non-existent tokens on later ranks. Runtime bounds checking in the kernel (blockID validation) handles the overshoot cases.

Applied to files:

  • cpp/kernels/xqa/mha_sm90.cu
📚 Learning: 2025-08-14T15:36:37.610Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: cpp/tensorrt_llm/kernels/mlaKernels.cu:436-439
Timestamp: 2025-08-14T15:36:37.610Z
Learning: CUDA kernels prioritize performance and should avoid runtime bounds checking or conditional operations that cause branching/warp divergence. Input validation should be done at the host level before kernel launch, not per-thread in the kernel.

Applied to files:

  • cpp/kernels/xqa/mha_sm90.cu
📚 Learning: 2025-08-21T09:41:49.347Z
Learnt from: eopXD
Repo: NVIDIA/TensorRT-LLM PR: 6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:2010-2045
Timestamp: 2025-08-21T09:41:49.347Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, updateSequenceCacheBlockOffsets is specifically for updating bookkeeping when blocks are added during the context phase, not for refreshing offsets after detach operations. During detach operations, GenerationRequest::removeFrontBlock handles the necessary cache block bookkeeping internally.

Applied to files:

  • cpp/kernels/xqa/test/test.cpp
🧬 Code graph analysis (4)
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp (1)
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp (4)
  • xqaParams (109-138)
  • xqaParams (109-109)
  • xqaParams (140-157)
  • xqaParams (140-140)
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp (1)
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp (4)
  • xqaParams (109-138)
  • xqaParams (109-109)
  • xqaParams (140-157)
  • xqaParams (140-140)
cpp/kernels/xqa/mha.cu (1)
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQARunner.cpp (2)
  • divUp (63-66)
  • divUp (63-63)
cpp/kernels/xqa/test/test.cpp (1)
cpp/kernels/xqa/mha_sm90.cu (2)
  • computeNbSubSeqPerSeqHopperF8MHA (3430-3446)
  • computeNbSubSeqPerSeqHopperF8MHA (3430-3431)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (32)
cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/kernelUtils.cpp (3)

99-108: LGTM!

The gating logic correctly restricts the Hopper QGMMA kernel path with FP16/BF16 KV cache to cases where skip-softmax is enabled, while preserving E4M3 support unconditionally. The inline comment clearly documents the design intent.


177-181: LGTM!

Correctly disables the HMMA path when skip-softmax is enabled, ensuring the feature routes exclusively through the Hopper QGMMA kernel.


215-219: LGTM!

Consistent with the HMMA gating: the MLA path is disabled when skip-softmax is enabled.

cpp/kernels/xqa/mha_stdheaders.cuh (1)

1276-1286: The minimal pair implementation is limited; verify it's actually used before committing.

The custom pair struct lacks constructors and comparison operators beyond what std::pair provides. While aggregate initialization like pair<T1, T2>{val1, val2} works, attempting construction, assignment, or comparison could cause compilation failures in GENERATE_CUBIN builds. Ensure this type is instantiated in the codebase and document its intended use, or add standard constructors to match std::pair behavior.

cpp/kernels/xqa/gmma.cuh (1)

109-109: Good addition of alignment assertion for GMMA operations.

The assertion enforces strict alignment requirements (up to 1024 bytes) based on swizzle mode, which is critical for correctness of Hopper GMMA operations. This defensive check will catch alignment violations early during development.

Most current call sites pass nullptr as the data parameter, so the assertion trivially passes; however, this protection is still valuable for detecting misuse if future code paths pass non-null data pointers. The SharedMem structure is properly aligned with alignas(128), and the assertion pattern is consistent with similar checks in the codebase (e.g., line 86).

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp (1)

496-499: Guarding precompiled XQA against unsupported skip‑softmax configs looks correct

Rejecting configs with skip_softmax_threshold_scale_factor != 0 in shouldUse() is the right way to ensure skip‑softmax is only handled by the JIT path and not by precompiled kernels that don’t implement it. Logging the explicit reason keeps diagnostics clear.

cpp/kernels/xqa/defines.h (1)

132-143: Skip‑softmax macro defaults are well‑chosen and consistent

The new SKIP_SOFTMAX_ATTN* defines are guarded with #ifndef, use clear uppercase names, and default to a safe configuration (feature off, stats off, fix enabled). This matches the existing style in this header and cleanly gates skip‑softmax–related code.

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/xqaParams.h (1)

121-127: Skip‑softmax parameterization in XQAParams is coherent

Using skip_softmax_threshold_scale_factor = 0 as the disabled default, plus optional skip_softmax_total_blocks / skip_softmax_skipped_blocks under SKIP_SOFTMAX_STAT, matches how the feature is gated elsewhere. Including these fields in toString() only when stats are compiled in gives useful diagnostics without impacting default builds.

Also applies to: 207-210

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/tensorMapUtils.cpp (1)

67-80: Dynamic tensor‑map data type selection for KV/Q looks correct

getDataTypeFromXqaParams() cleanly maps kv_cache_data_type to the appropriate CUtensorMapDataType_enum, and the updated Hopper XQA / MLA tensor‑map builders correctly derive elemBytes and pass dataType through. This enables BF16/FP16 KV layouts while still handling INT8/FP8 as 1‑byte elements, and the TLLM_CHECK guard on unsupported enums is a good safety net.

Also applies to: 146-169, 179-187, 200-205

cpp/kernels/xqa/mha.h (1)

93-100: MHA/HopperF8 interfaces cleanly extend to support skip‑softmax

The added computeNbSubSeqPerSeq* helpers and the conditional skipSoftmaxThresholdScaleFactor + block‑stat parameters on launchMHA / launchHopperF8MHA are wired consistently with SKIP_SOFTMAX_ATTN and SKIP_SOFTMAX_ATTN_BLOCK_STATS. This keeps the interface minimal when the feature is off while exposing everything needed for skip‑softmax and per‑block statistics when enabled.

Also applies to: 131-139, 142-149, 180-189

cpp/tensorrt_llm/common/attentionOp.cpp (2)

300-305: XQA skip‑softmax fields are plumbed correctly from AttentionOp

Wiring mSkipSoftmaxThresholdScaleFactorDecode into xqaParams.skip_softmax_threshold_scale_factor and threading the optional mSkipSoftmaxTotalBlocks / mSkipSoftmaxSkippedBlocks through XQAParams (under SKIP_SOFTMAX_STAT) gives the XQA decode path the same control and observability as FMHA, without affecting builds where stats are disabled.


1896-1906: Context FMHA skip‑softmax wiring and stat handling are sound

Setting fmhaParams.skipSoftmaxThresholdScaleFactor from the prefill‑specific knob and optionally providing the shared block‑stat buffers under SKIP_SOFTMAX_STAT aligns FMHA with the new skip‑softmax pipeline. The explicit throw when getEnvPrintSkipSoftmaxStat() is set but stats aren’t compiled in avoids silent misuse of the feature.

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/include/nvrtcWrapper.h (1)

67-70: New use_skip_softmax_attn flag in JIT context is appropriate

Adding use_skip_softmax_attn as a trailing field on tllmXqaJitContext provides a clear, single source of truth for enabling skip‑softmax in NVRTC‑compiled kernels, without perturbing existing fields or call‑sites that zero‑initialize the context.

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/src/nvrtcWrapper.cpp (1)

218-221: Skip‑softmax NVRTC macro wiring matches the runtime context

Deriving SKIP_SOFTMAX_ATTN (and, when available, SKIP_SOFTMAX_ATTN_BLOCK_STATS) directly from context->use_skip_softmax_attn ensures JIT‑compiled XQA kernels specialize correctly for skip‑softmax and optionally emit block‑stat code. Builds without SKIP_SOFTMAX_STAT still get SKIP_SOFTMAX_ATTN set while relying on the default SKIP_SOFTMAX_ATTN_BLOCK_STATS 0, which is the intended behavior.

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/compileEngine.cpp (1)

90-110: Propagating skip-softmax flag into JIT context looks consistent

use_skip_softmax_attn derived from skip_softmax_threshold_scale_factor != 0 is consistent with the runtime flag in DecoderXQAImplJIT::runImpl and cleanly extends tllmXqaJitContext. No issues from this side.

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp (1)

506-526: Skip-softmax parameter ordering must stay in lockstep with Hopper JIT kernel

The new block:

  • Appends SpecDecParams for speculative decoding.
  • Then, when isSkipSoftmax, appends:
    • skip_softmax_threshold_scale_factor
    • (optionally) skip_softmax_total_blocks, skip_softmax_skipped_blocks under SKIP_SOFTMAX_STAT.

This ordering needs to exactly match the Hopper XQA JIT kernel’s argument list for both spec-dec and non-spec-dec variants. Any divergence between this host-side order and the device-side signature will silently corrupt arguments.

If you haven’t already, please double‑check the SM90/Hopper XQA kernel prototypes to ensure the order and presence of these parameters (including the conditional stats pointers) are identical across:

  • the GMMA spec-dec path
  • the GMMA non-spec-dec path
  • the JIT compilation context (where use_skip_softmax_attn is set).
cpp/kernels/xqa/test/refAttention.h (1)

88-103: refFlashAttention API extension is consistent with implementation

The added parameters for skip-softmax (scale factor, block counters, multiBlockNum) are appended to the end of the template signature and match the implementation in refAttention.cpp and the new call sites in test.cpp. This keeps the reference path aligned with the Hopper skip-softmax behavior.

No issues from the header side; just ensure any other manual callers of refFlashAttention are updated to pass these extra arguments.

cpp/kernels/xqa/test/test.cpp (4)

152-233: runTest API extension and GMMA-only skip-softmax constraint look good

  • Adding float skipSoftmaxThresholdScaleFactor = 0.0f at the end of runTest keeps existing callers source-compatible.
  • Initializing skippedBlockCount / totalBlockCount and asserting that skipSoftmaxThresholdScaleFactor > 0 implies useQGMMA ensures that skip-softmax is only requested on the Hopper/QGMMA path.

This is a clean way to plumb the new control while keeping non-skip tests unchanged.


747-751: Using shared helpers to determine multi-block count keeps ref and kernel in sync

Computing multiBlockNum via:

  • computeNbSubSeqPerSeqHopperF8MHA for Hopper/QGMMA
  • computeNbSubSeqPerSeqMHA for the generic MHA path

ensures the reference path sees the same subsequence count that the kernel uses (modulo any environment override via XQA_NB_SUB_SEQ). This is the right place to centralize the logic.


803-808: Skip-softmax arguments correctly threaded into launchMHA / launchHopperF8MHA

The launchFunc invocation now passes:

  • skipSoftmaxThresholdScaleFactor always under SKIP_SOFTMAX_ATTN
  • Optional kernelSkippedBlockCount / kernelTotalBlockCount when SKIP_SOFTMAX_ATTN_BLOCK_STATS is also enabled

which matches the updated prototypes of launchHopperF8MHA and launchMHA. This keeps the test harness agnostic to which implementation is selected by useQGMMA.

Just ensure all non-test call sites of these launchers have also been updated (or compiled with SKIP_SOFTMAX_ATTN == 0) so signatures remain consistent across the build.


1313-1319: New skip-softmax test cases provide good coverage of thresholds

The extra runTest invocations under #if SKIP_SOFTMAX_ATTN:

  • Exercise both “disabled” (0.f) and several nontrivial thresholds (55.f, 80.f, 125.f, 455.f).
  • Use a variety of (nbKHeads, batchSize, seqLen) configurations.

This should give reasonable coverage of the new skip-softmax behavior and its interaction with multi-block scheduling.

cpp/kernels/xqa/test/refAttention.cpp (1)

164-170: Template instantiation macro updated correctly for new parameters

INSTANTIATE_refFlashAttention now matches the extended refFlashAttention signature, including the skip-softmax threshold, counters, and multiBlockNum. This keeps all explicit instantiations consistent with the implementation.

No issues here.

cpp/kernels/xqa/mha_sm90.cu (10)

52-54: LGTM: Compatibility guard for SKIP_SOFTMAX_ATTN.

The static_assert correctly restricts the skip-softmax feature to supported configurations.


160-171: LGTM: Buffer count adjustments for skip-softmax.

The extra buffer for nbVBuf is properly documented, and static_assert(nbXBuf == nbVBuf) ensures consistency at compile time.


236-252: LGTM: Shared memory layout for skip-softmax synchronization.

The dual-vote/dual-barrier design correctly separates synchronization for V-loader and Gemm1 consumers.


814-839: LGTM: Barrier initialization with skip-softmax adjustments.

The barrier thread counts are correctly adjusted: vBar includes the V-loader warp in both counts when SKIP_SOFTMAX_ATTN is enabled to prevent race conditions.


788-790: LGTM: Skip threshold calculation.

Scaling the threshold by 1/cacheSeqLen is reasonable as attention weights are distributed across more tokens in longer sequences.


1137-1160: LGTM: Gemm1 skip-softmax control flow.

The synchronization correctly handles both skip and non-skip paths, with proper barrier coordination between warp groups.


1631-1656: LGTM: V loader skip path with race condition fix.

The additional vBar.produced.arrive() on line 1655 correctly compensates for the barrier count adjustment, preventing the race condition described in the comment (lines 1635-1639).


2133-2255: LGTM: Skip-softmax voting in computeWarpGrpColMax_sync.

The implementation correctly handles both threshold variants with well-documented trade-offs. The atomicAnd for collective voting is efficient.


3430-3446: LGTM: Extracted helper function for subsequence calculation.

Good refactoring to extract this logic into a reusable function. The environment variable override is useful for tuning.


2398-2403: LGTM: Use of mha::pair for device code.

Using mha::pair instead of std::pair is appropriate for CUDA device code compatibility.

@pengbowang-nv pengbowang-nv force-pushed the dev-add-hopper-xqa-skip-softmax branch from e7facc3 to 8bf8e2d Compare December 24, 2025 08:45
@pengbowang-nv
Collaborator Author

/bot run --add-multi-gpu-test

@tensorrt-cicd
Collaborator

PR_Github #29787 [ run ] triggered by Bot. Commit: 8bf8e2d

@pengbowang-nv
Collaborator Author

/bot kill

@tensorrt-cicd
Collaborator

PR_Github #29789 [ kill ] triggered by Bot. Commit: 8bf8e2d

@tensorrt-cicd
Collaborator

PR_Github #29789 [ kill ] completed with state SUCCESS. Commit: 8bf8e2d
Successfully killed previous jobs for commit 8bf8e2d

@pengbowang-nv pengbowang-nv force-pushed the dev-add-hopper-xqa-skip-softmax branch from 8bf8e2d to 2202e65 Compare December 24, 2025 09:05
@pengbowang-nv
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #29801 [ run ] triggered by Bot. Commit: 490bb60

@tensorrt-cicd
Collaborator

PR_Github #29801 [ run ] completed with state ABORTED. Commit: 490bb60
LLM/main/L0_MergeRequest_PR #22906 (Blue Ocean) completed with status: ABORTED

@pengbowang-nv
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #30657 [ run ] triggered by Bot. Commit: 490bb60

@pengbowang-nv pengbowang-nv force-pushed the dev-add-hopper-xqa-skip-softmax branch from 490bb60 to d0c7bfc Compare January 6, 2026 05:44
@pengbowang-nv
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #30685 [ run ] triggered by Bot. Commit: d0c7bfc

@tensorrt-cicd
Collaborator

PR_Github #30685 [ run ] completed with state SUCCESS. Commit: d0c7bfc
/LLM/main/L0_MergeRequest_PR pipeline #23675 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@pengbowang-nv
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #30802 [ run ] triggered by Bot. Commit: d0c7bfc

Collaborator

@bobboli bobboli left a comment


I am unwaiving the CI test on Hopper #10420.
Please make sure that test passes.

@tensorrt-cicd
Collaborator

PR_Github #30802 [ run ] completed with state SUCCESS. Commit: d0c7bfc
/LLM/main/L0_MergeRequest_PR pipeline #23784 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@pengbowang-nv
Collaborator Author

/bot run --add-multi-gpu-test

@tensorrt-cicd
Collaborator

PR_Github #30975 [ run ] triggered by Bot. Commit: d0c7bfc

@tensorrt-cicd
Collaborator

PR_Github #30975 [ run ] completed with state SUCCESS. Commit: d0c7bfc
/LLM/main/L0_MergeRequest_PR pipeline #23933 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@pengbowang-nv
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #31158 [ run ] completed with state SUCCESS. Commit: d0c7bfc
/LLM/main/L0_MergeRequest_PR pipeline #24069 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
@pengbowang-nv pengbowang-nv force-pushed the dev-add-hopper-xqa-skip-softmax branch from d0c7bfc to 4bb57fa Compare January 9, 2026 15:05
@pengbowang-nv
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@pengbowang-nv pengbowang-nv enabled auto-merge (squash) January 9, 2026 15:06
@tensorrt-cicd
Collaborator

PR_Github #31269 [ run ] triggered by Bot. Commit: 4bb57fa

@tensorrt-cicd
Collaborator

PR_Github #31269 [ run ] completed with state ABORTED. Commit: 4bb57fa
LLM/main/L0_MergeRequest_PR #24163 (Blue Ocean) completed with status: ABORTED

@pengbowang-nv
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #31360 [ run ] triggered by Bot. Commit: 4bb57fa

@tensorrt-cicd
Collaborator

PR_Github #31360 [ run ] completed with state SUCCESS. Commit: 4bb57fa
/LLM/main/L0_MergeRequest_PR pipeline #24251 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@pengbowang-nv
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #31387 [ run ] triggered by Bot. Commit: 4bb57fa

@tensorrt-cicd
Collaborator

PR_Github #31387 [ run ] completed with state DISABLED
CI server is currently disabled for scheduled maintenance. Estimated completion time: 8 AM PST on 1/11.

@pengbowang-nv
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #31416 [ run ] triggered by Bot. Commit: 4bb57fa

@tensorrt-cicd
Collaborator

PR_Github #31416 [ run ] completed with state SUCCESS. Commit: 4bb57fa
/LLM/main/L0_MergeRequest_PR pipeline #24277 completed with status: 'SUCCESS'

@pengbowang-nv pengbowang-nv merged commit c0e25e5 into NVIDIA:main Jan 12, 2026
5 checks passed
videodanchik pushed a commit to videodanchik/TensorRT-LLM that referenced this pull request Jan 14, 2026
…ttention (NVIDIA#10264)

Signed-off-by: Pengbo Wang <221450789+pengbowang-nv@users.noreply.github.com>
Signed-off-by: Daniil Kulko <kulkodaniil@gmail.com>
@pengbowang-nv pengbowang-nv deleted the dev-add-hopper-xqa-skip-softmax branch January 26, 2026 09:49