
[Kernel] Porting the TRTLLM minimax_allreduce_rms kernels #37045

Merged
youkaichao merged 50 commits into vllm-project:main from jeejeelee:minmax-m2-norm-trtllm-kernel
Apr 10, 2026

Conversation

Collaborator

@jeejeelee jeejeelee commented Mar 14, 2026

Purpose

See: NVIDIA/TensorRT-LLM#12163

Plan

  • Clean up the code
  • Verify the accuracy on GSM8K Benchmark
  • Performance Benchmark
  • Add unit test

Test Plan

Accuracy Verification (69f231c)

  • vLLM script (TP4)
vllm serve MiniMaxAI/MiniMax-M2.5 \
  --tensor-parallel-size 4 \
  --tool-call-parser minimax_m2 \
  --reasoning-parser minimax_m2_append_think  \
  --served-model-name m25 \
  --trust-remote-code \
  --enable-auto-tool-choice
Requesting API: 100%|██████████████████████████████████████████████████████████████████| 1319/1319 [01:35<00:00, 13.87it/s]
2026-03-15:03:49:34 INFO     [loggers.evaluation_tracker:316] Output path not provided, skipping saving results aggregated
local-completions ({'model': 'MiniMaxAI/MiniMax-M2.5', 'base_url': 'http://0.0.0.0:8000/v1/completions', 'tokenized_requests': False, 'tokenizer_backend': None, 'num_concurrent': 128}), gen_kwargs: ({}), limit: None, num_fewshot: 5, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9295|±  |0.0071|
|     |       |strict-match    |     5|exact_match|↑  |0.9265|±  |0.0072|
  • This PR
Requesting API: 100%|██████████████████████████████████████████████████████████████████| 1319/1319 [01:36<00:00, 13.67it/s]
2026-03-15:04:03:23 INFO     [loggers.evaluation_tracker:316] Output path not provided, skipping saving results aggregated
local-completions ({'model': 'MiniMaxAI/MiniMax-M2.5', 'base_url': 'http://0.0.0.0:8000/v1/completions', 'tokenized_requests': False, 'tokenizer_backend': None, 'num_concurrent': 128}), gen_kwargs: ({}), limit: None, num_fewshot: 5, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9287|±  |0.0071|
|     |       |strict-match    |     5|exact_match|↑  |0.9257|±  |0.0072|
  • AIME25 (epoch=4): 0.8417

Performance

  • server
vllm serve MiniMaxAI/MiniMax-M2.5 \
  -tp 4 \
  --tool-call-parser minimax_m2 \
  --reasoning-parser minimax_m2_append_think  \
  --trust-remote-code \
  --no-enable-prefix-caching \
  --load-format fastsafetensors \
  --enable-auto-tool-choice
  • benchmark
vllm bench serve \
  --backend vllm \
  --model MiniMaxAI/MiniMax-M2.5 \
  --endpoint /v1/completions \
  --dataset-name random \
  --random-input-len 2048 \
  --random-output-len 1024 \
  --max-concurrency 10 \
  --num-prompts 100 \
  --trust-remote-code
  • with fused kernel
============ Serving Benchmark Result ============
Successful requests:                     100       
Failed requests:                         0         
Maximum request concurrency:             10        
Benchmark duration (s):                  107.92    
Total input tokens:                      204800    
Total generated tokens:                  102400    
Request throughput (req/s):              0.93      
Output token throughput (tok/s):         948.81    
Peak output token throughput (tok/s):    1010.00   
Peak concurrent requests:                20.00     
Total token throughput (tok/s):          2846.43   
---------------Time to First Token----------------
Mean TTFT (ms):                          293.79    
Median TTFT (ms):                        320.57    
P99 TTFT (ms):                           449.71    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.26     
Median TPOT (ms):                        10.27     
P99 TPOT (ms):                           10.51     
---------------Inter-token Latency----------------
Mean ITL (ms):                           10.26     
Median ITL (ms):                         10.14     
P99 ITL (ms):                            10.57     
==================================================
  • without fused kernel
============ Serving Benchmark Result ============
Successful requests:                     100       
Failed requests:                         0         
Maximum request concurrency:             10        
Benchmark duration (s):                  112.06    
Total input tokens:                      204800    
Total generated tokens:                  102400    
Request throughput (req/s):              0.89      
Output token throughput (tok/s):         913.78    
Peak output token throughput (tok/s):    970.00    
Peak concurrent requests:                20.00     
Total token throughput (tok/s):          2741.33   
---------------Time to First Token----------------
Mean TTFT (ms):                          320.53    
Median TTFT (ms):                        354.95    
P99 TTFT (ms):                           547.76    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          10.64     
Median TPOT (ms):                        10.62     
P99 TPOT (ms):                           10.86     
---------------Inter-token Latency----------------
Mean ITL (ms):                           10.64     
Median ITL (ms):                         10.50     
P99 ITL (ms):                            10.81     
==================================================

2026-04-08 Accuracy Update

|Name        |This PR|Official data|
|------------|------:|------------:|
|AIME 2025   | 0.8583|        0.853|
|GPQA Diamond| 0.8485|        0.839|

Test Result

@jeejeelee jeejeelee marked this pull request as draft March 14, 2026 09:32
@mergify mergify Bot added the ci/build label Mar 14, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a new CUDA kernel for minimax_reduce_rms operations, including a float4 variant, and integrates it into the vLLM framework. Key changes include adding the kernel to the build system, defining its parameters and structures, registering the operations with PyTorch, and implementing a LamportWorkspace for managing CUDA IPC memory. The MambaMixer in minimax_m2.py is updated to utilize a new MiniMaxText01RMSNormAR class for fused Q+K RMS normalization. Review comments highlight the need to address a FIXME comment and a TODO regarding potentially incorrect indexing logic in minimax_reduce_rms_kernel.cu, as well as a TODO for a performance optimization in the local reduction step. Additionally, the max_tokens parameter in linear_attn.py is hardcoded and should be made configurable to prevent memory issues, and a large block of commented-out code needs to be removed for clarity.
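
For readers unfamiliar with the Lamport pattern the workspace is named after, here is a minimal sketch of the idea under illustrative names (this is not the kernel's actual API): the data buffer doubles as its own ready flag. The workspace is pre-cleared to a sentinel bit pattern (-0.0f), each rank writes its shard directly into every peer's buffer, and readers spin until a slot no longer holds the sentinel, so no separate flag array or extra barrier is needed.

__device__ __forceinline__ bool is_lamport_sentinel(float v) {
  // -0.0f compares equal to 0.0f, so test the raw bit pattern instead.
  return __float_as_uint(v) == 0x80000000u;
}

__device__ float lamport_wait(const volatile float* slot) {
  // Spin until a peer's write replaces the sentinel the buffer was
  // pre-cleared with; the arriving value itself signals readiness.
  float v = *slot;
  while (is_lamport_sentinel(v)) {
    v = *slot;
  }
  return v;
}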

Comment thread csrc/minimax_reduce_rms_kernel.cu Outdated
// step 3: calculate the rms norm (input * rsqrt(variance + eps))

// load norm weight
// TODO: correct the access_id_in_token
Contributor


Severity: high

This TODO comment suggests that the logic for indexing into the rms_gamma weight tensor using access_id_in_token might be incorrect or needs verification. An error here could lead to incorrect model outputs. Please verify this logic and either fix it or remove the TODO comment if it's stale.
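
For reference, here is the invariant being questioned, sketched with scalar loads and hypothetical names (the real kernel uses vectorized float4 accesses): the index into rms_gamma must be the element's offset within its token, i.e. in [0, hidden_dim), never the thread's global offset into the flattened buffer.

__global__ void rms_apply(const float* in, const float* rms_gamma,
                          const float* var, float* out,
                          int hidden_dim, float eps, int n_elems) {
  int global_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (global_id >= n_elems) return;
  int token_id = global_id / hidden_dim;
  // gamma is shared across tokens, so it is indexed by the position
  // inside the token, not by global_id itself.
  int access_id_in_token = global_id % hidden_dim;
  out[global_id] = in[global_id] *
                   rsqrtf(var[token_id] + eps) *
                   rms_gamma[access_id_in_token];
}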

Comment thread csrc/minimax_reduce_rms_kernel.cu Outdated
Comment on lines +377 to +378
// TODO: we can do local reduce only within q threads and k threads
// respectively
Contributor


Severity: high

The TODO on line 377 suggests a potential performance optimization for the local reduction step. The current implementation uses all threads for reducing both Q and K variances, while it might be more efficient to perform these reductions within their respective thread groups (Q-threads and K-threads). Implementing this could improve the kernel's performance.
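
A minimal sketch of the suggested split, assuming Q lanes occupy the low half of each warp and K lanes the high half (names and lane layout here are hypothetical): partitioning the warp into 16-lane tiles lets each group tree-reduce only its own sum of squares, roughly halving the shuffle traffic versus every lane reducing both sums.

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

__device__ float group_sum_sq(float partial_sum_sq) {
  // Split the warp into two 16-lane tiles; each tile performs an
  // independent shuffle-based reduction, so Q lanes never exchange
  // data with K lanes and vice versa.
  auto warp = cg::tiled_partition<32>(cg::this_thread_block());
  auto half_warp = cg::tiled_partition<16>(warp);
  return cg::reduce(half_warp, partial_sum_sq, cg::plus<float>());
}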

Comment thread vllm/model_executor/layers/mamba/linear_attn.py Outdated
@jeejeelee jeejeelee marked this pull request as ready for review March 19, 2026 17:32
Contributor

mergify Bot commented Mar 19, 2026

Hi @jeejeelee, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Member

@yewentao256 yewentao256 left a comment


Thanks for the work!

Comment thread CMakeLists.txt Outdated
Comment thread vllm/model_executor/layers/mamba/lamport_workspace.py Outdated
@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 19, 2026
@jeejeelee jeejeelee force-pushed the minmax-m2-norm-trtllm-kernel branch from 0b6b533 to 2e63a80 on April 7, 2026 05:19
@khluu khluu added this to the v0.19.1 cherry picks milestone Apr 7, 2026
Comment thread vllm/model_executor/models/minimax_m2.py Outdated
@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
Contributor


need to plug the test into CI.

Contributor

wzhao18 commented Apr 7, 2026

I ran some benchmarks with 1K/1K and 8K/1K ISL/OSL profiles on MiniMax FP8 at TP=4, with max_num_batched_tokens set to 2K and 16K respectively. This optimization gives a 1-2% speedup across most points on the Pareto curve.

(Benchmark plots: isl1024_osl1024 and isl8192_osl1024)

Contributor

mergify Bot commented Apr 8, 2026

Hi @jeejeelee, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Contributor

mergify Bot commented Apr 8, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jeejeelee.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Apr 8, 2026
@mergify mergify Bot removed the needs-rebase label Apr 9, 2026
@youkaichao youkaichao merged commit ecd1ea1 into vllm-project:main Apr 10, 2026
140 of 141 checks passed
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD Apr 10, 2026
@jeejeelee jeejeelee deleted the minmax-m2-norm-trtllm-kernel branch April 11, 2026 04:44
wojciech-wais pushed a commit to wojciech-wais/vllm that referenced this pull request Apr 13, 2026
khluu pushed a commit that referenced this pull request Apr 16, 2026
(cherry picked from commit ecd1ea1)
khluu pushed a commit that referenced this pull request Apr 16, 2026
(cherry picked from commit ecd1ea1)
anish-shanbhag added a commit to anish-shanbhag/vllm that referenced this pull request Apr 22, 2026
Cherry-picks the MiniMax QK norm allreduce+RMSNorm Lamport fusion pass
from vllm-project#37045 into our 0.17.1-based branch.

This replaces the per-layer QK norm allreduce + variance computation +
RMSNorm with a single fused Lamport-based CUDA kernel, eliminating
multiple kernel launches per layer (62 layers x ~3 kernels saved).

Changes:
- New CUDA kernel: csrc/minimax_reduce_rms_kernel.{cu,h}
- New Lamport IPC workspace: vllm/model_executor/layers/mamba/lamport_workspace.py
- New compilation fusion pass: vllm/compilation/passes/fusion/minimax_qk_norm_fusion.py
- Config: add fuse_minimax_qk_norm to PassConfig with None default
- Pass manager: register MiniMaxQKNormPass after AllReduceFusionPass
- Compile ranges: add split point so fusion only applies to decode-size batches
- Bindings: register minimax_allreduce_rms and minimax_allreduce_rms_qk ops
- Test: tests/kernels/core/test_minimax_reduce_rms.py

Adapted compile_ranges_endpoints (upstream API) to compile_ranges_split_points
(our 0.17.1 API). Model change (.contiguous() removal) was already done in
commit 2cdf163.

Enable with: --compilation-config '{"pass_config":{"fuse_minimax_qk_norm":true}}'

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
whk-lab pushed a commit to whk-lab/vllm that referenced this pull request Apr 23, 2026
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Apr 27, 2026

Labels

ci/build · ready (ONLY add when PR is ready to merge/full CI is needed) · rocm (Related to AMD ROCm)

Projects

Status: Done

7 participants