A high-performance inference engine for block diffusion language models
SDAR · LLaDA · dLLM-Var
JetEngine is a lightweight, production-ready inference engine for block diffusion language models (SDAR, LLaDA, dLLM-Var). It supports dense and MoE architectures, hybrid Data Parallel + Tensor Parallel distributed inference, CUDA graph acceleration, and advanced remasking strategies for optimal generation quality.
| GPU | Throughput | Key Optimizations |
|---|---|---|
| NVIDIA H200 | 7,500+ tok/s | FA3 + CUDA Graphs + Chain + dynamic_pmax |
| NVIDIA A800 | 3,000+ tok/s | FA2 + Triton Kernels + Paged KV Cache |
- Block Diffusion Decoding: generates tokens in fixed-size blocks via iterative denoising, fundamentally different from autoregressive decoding
- 11 Remasking Strategies: from simple sequential to novel joint-distribution-aware multi-token commit (`dynamic_pmax`)
- Flash Attention 3: Hopper SM90 paged attention for denoise, with FA2/flashinfer fallback
- CUDA Graph Capture: graphs for batch sizes 1–128, near-zero kernel launch overhead
- Chain Mechanism: runs up to 5 denoising steps within a single `step()` call, eliminating scheduler overhead
- Hybrid DP+TP: `tensor_parallel_size` × data parallel across any GPU count (e.g., TP=2 × DP=4 on 8 GPUs)
- Streaming Generation: `generate_streaming()` streams prompts through a fixed active window with automatic prompt interleaving for batch diversity
- Model Offloading: completely offload model weights and KV cache to free GPU memory for RL training loops
- Selective Logits: computes LM head only for DENOISING sequences, skipping SAVING sequences
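
To make the first feature concrete, here is a minimal toy sketch of block-wise decoding. It is not JetEngine's implementation: `toy_model`, `decode_block`, `generate`, and the `MASK` value are hypothetical names invented for this illustration, and the "model" just returns random guesses. It only shows the control flow: each block starts fully masked, positions are committed over several denoising steps, and finished blocks are appended left to right.

```python
import random

MASK = -1  # hypothetical mask marker for this toy example only


def toy_model(context, block, vocab_size, rng):
    """Stand-in for a real model: one (token, confidence) guess per block position.

    A real model would condition on `context` and on the committed tokens in
    `block`; this toy version ignores both and guesses at random.
    """
    return [(rng.randrange(vocab_size), rng.random()) for _ in block]


def decode_block(context, block_length, vocab_size, rng):
    """Denoise one block: start fully masked, commit one position per step."""
    block = [MASK] * block_length
    steps = 0
    while MASK in block:
        guesses = toy_model(context, block, vocab_size, rng)
        masked = [i for i, tok in enumerate(block) if tok == MASK]
        # Commit the masked position with the highest toy confidence.
        best = max(masked, key=lambda i: guesses[i][1])
        block[best] = guesses[best][0]
        steps += 1
    return block, steps


def generate(num_blocks, block_length=4, vocab_size=100, seed=0):
    """Generate blocks left to right; each finished block extends the context."""
    rng = random.Random(seed)
    sequence, total_steps = [], 0
    for _ in range(num_blocks):
        block, steps = decode_block(sequence, block_length, vocab_size, rng)
        sequence.extend(block)
        total_steps += steps
    return sequence, total_steps


if __name__ == "__main__":
    tokens, steps = generate(num_blocks=3)
    print(len(tokens), "tokens in", steps, "denoising steps")
```

Committing exactly one position per step means a block of length 4 takes 4 steps. Multi-token strategies aim to finish a block in fewer steps by committing several positions at once.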
Requirements:
- Python >= 3.10
- PyTorch >= 2.1
- transformers >= 4.52.4
- flash-attn >= 2.5
- accelerate

Install:

```bash
pip install flash-attn --no-build-isolation
git clone https://github.com/Labman42/JetEngine.git
cd JetEngine
pip install .
```

Optional (for Hopper GPUs):

```bash
# Flash Attention 3: enables FA3 paged attention (significant speedup on H100/H200)
pip install flash-attn-3
```

Launch with accelerate:

```bash
# Single GPU (only device 0 is visible)
CUDA_VISIBLE_DEVICES='0' accelerate launch --multi_gpu example.py

# Uses all visible GPUs for data parallel inference
accelerate launch --multi_gpu example.py

# 8 GPUs: TP=2 (model split across 2 GPUs) × DP=4 (4 data shards)
accelerate launch --multi_gpu --num_processes=8 your_script.py \
    --tensor_parallel_size 2
```

Basic usage:

```python
from jetengine import LLM, SamplingParams

# Initialize engine
llm = LLM(
    "path/to/SDAR-4B-Chat",
    mask_token_id=151669,        # Required: model's mask token
    block_length=4,              # Required: block diffusion block size
    tensor_parallel_size=1,      # TP degree (1 = single GPU)
    max_model_len=4096,
    gpu_memory_utilization=0.9,
)

# Configure sampling
sampling_params = SamplingParams(
    temperature=1.0,
    max_tokens=4096,
    block_length=4,
    denoising_steps=4,
    remasking_strategy="dynamic_pmax",  # Best strategy for accuracy
    dynamic_threshold=0.75,             # Commit threshold for multi-token
    topk=0,
    topp=1.0,
)

# Example prompts: List[str] or List[List[int]]
prompts = ["Solve x^2 = 4", "What is pi?"]

# Generate
outputs = llm.generate_streaming(
    prompts,
    sampling_params,
    max_active=128,              # Max concurrent sequences
)
for output in outputs:
    print(output["text"])
```

Direct generation with `generate()`:

```python
outputs = llm.generate(
    ["Solve x^2 = 4", "What is pi?"],
    sampling_params,
)
```

Offloading for RL training loops:

```python
# Free GPU memory after inference for training
llm.offload_parameters()    # Free model weights (keep buffers)
llm.free_all_resources()    # Free everything (graphs + KV cache)

# Reload from a HuggingFace model for the next eval round
llm.reload_from_hf_model(hf_model)
```

The remasking strategy controls which positions in a block to commit at each denoising step and how many to commit simultaneously. This is the core design knob for block diffusion quality and speed.

Recommended strategies:

| Strategy | Multi-Token | Description | Best For |
|---|---|---|---|
| `dynamic_pmax` | Yes | Threshold on P(argmax) per position; commits argmax token. EOS-safe. | Best pass@1 & pass@k |
| `low_confidence_dynamic` | Yes | Threshold on P(sampled); fallback to leftmost | General use |

Single-token strategies (one position committed per step):

| Strategy | Description |
|---|---|
| `sequential` | Commit leftmost masked position (left-to-right) |
| `anti_sequential` | Commit rightmost masked position |
| `low_confidence_static` | Commit position with highest P(sampled) |
| `least_entropy` | Commit position with lowest entropy |
| `top2_margin` | Commit position with largest top1-top2 probability gap |
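
The single-token rules in this table can be sketched in a few lines of plain Python. This is an illustration of the table's descriptions, not JetEngine's code: `pick_position`, `entropy`, and `top2_margin_of` are hypothetical helper names, and the inputs are small hand-written probability lists rather than real model outputs.

```python
import math


def entropy(probs):
    """Shannon entropy (in nats) of one probability distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def top2_margin_of(probs):
    """Gap between the largest and second-largest probability."""
    top = sorted(probs, reverse=True)
    return top[0] - top[1]


def pick_position(strategy, dists, masked, sampled=None):
    """Return the one masked position to commit at this denoising step.

    dists:   per-position probability distributions over the vocabulary
    masked:  indices of positions that are still masked
    sampled: per-position sampled token ids (needed for low_confidence_static)
    """
    if strategy == "sequential":
        return min(masked)
    if strategy == "anti_sequential":
        return max(masked)
    if strategy == "low_confidence_static":
        return max(masked, key=lambda i: dists[i][sampled[i]])
    if strategy == "least_entropy":
        return min(masked, key=lambda i: entropy(dists[i]))
    if strategy == "top2_margin":
        return max(masked, key=lambda i: top2_margin_of(dists[i]))
    raise ValueError(f"unknown strategy: {strategy}")


if __name__ == "__main__":
    # Toy block of 4 positions over a 3-token vocabulary; position 0 is committed.
    dists = [
        [0.50, 0.30, 0.20],
        [0.40, 0.35, 0.25],
        [0.90, 0.05, 0.05],
        [0.60, 0.39, 0.01],
    ]
    masked = [1, 2, 3]
    sampled = [0, 1, 1, 0]
    for name in ("sequential", "anti_sequential", "low_confidence_static",
                 "least_entropy", "top2_margin"):
        print(name, "->", pick_position(name, dists, masked, sampled))
```

In this toy block, `least_entropy` and `top2_margin` agree on position 2, since both reward a sharply peaked distribution. `low_confidence_static` depends on which token was actually sampled, so it can pick a different position.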

Multi-token strategies (several positions may be committed per step):

| Strategy | Description |
|---|---|
| `dynamic_pmax` | P(argmax) > threshold; commits argmax tokens with EOS safety |
| `low_confidence_dynamic` | P(sampled) > threshold; fallback to leftmost |
| `entropy_bounded` | Commit positions sorted by entropy up to a budget |
| `causal_waterfall` | Contiguous leftmost prefix where each P(argmax) > floor |
| `logit_salience` | Z-score outlier detection in logit space |
| `relative_top` | Commit prefix where P(argmax) ≥ α × max across block |
| `consensus_commit` | K=4 multinomial majority vote |
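
As a rough illustration of how a threshold rule commits several positions at once, the sketch below implements two of the descriptions from this table on plain lists. It rests on assumptions: the function names are hypothetical, the `floor=0.5` default is made up, the fallback to the leftmost masked position is borrowed from the `low_confidence_dynamic` row (the table does not state a fallback for `dynamic_pmax`), and EOS safety is not modeled.

```python
def commit_by_threshold(pmax, masked, threshold=0.75):
    """Commit every masked position whose P(argmax) exceeds the threshold.

    Assumption: if nothing clears the threshold, fall back to the leftmost
    masked position so that each step makes progress.
    """
    chosen = [i for i in sorted(masked) if pmax[i] > threshold]
    return chosen if chosen else [min(masked)]


def causal_waterfall(pmax, masked, floor=0.5):
    """Commit the leftmost run of masked positions with P(argmax) > floor.

    The run stops at the first masked position that fails the floor.
    Assumption: same leftmost fallback as above when the run is empty.
    """
    chosen = []
    for i in sorted(masked):
        if pmax[i] <= floor:
            break
        chosen.append(i)
    return chosen if chosen else [min(masked)]


if __name__ == "__main__":
    pmax = [0.95, 0.80, 0.40, 0.90]   # toy P(argmax) per position
    masked = [0, 1, 2, 3]
    print(commit_by_threshold(pmax, masked))  # positions above 0.75
    print(causal_waterfall(pmax, masked))     # left prefix above 0.5
```

The threshold rule may commit non-adjacent positions (here it skips position 2), while the waterfall rule only commits a left-to-right prefix and stops at the first low-confidence position.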
| Strategy | pass@1 | pass@4 |
|---|---|---|
| dynamic_pmax (t=0.75) | 0.698 | 0.840 |
| low_confidence_dynamic (t=0.90) | 0.682 | 0.834 |
| sequential | 0.669 | 0.730 |
| low_confidence_static | 0.677 | 0.750 |

At gen16, `dynamic_pmax` achieves pass@16 = 0.96 vs. 0.93 for `low_confidence_dynamic` (+0.03).

`LLM` parameters:

| Parameter | Default | Description |
|---|---|---|
| `model` | — | Path to model checkpoint |
| `mask_token_id` | — | Required. Model's mask token ID |
| `block_length` | 4 | Block size for diffusion decoding |
| `tensor_parallel_size` | 1 | Number of GPUs for tensor parallelism |
| `max_num_seqs` | 512 | Maximum sequences in KV cache |
| `max_model_len` | 4096 | Maximum sequence length |
| `gpu_memory_utilization` | 0.8 | Fraction of GPU memory for KV cache |
| `enforce_eager` | False | Disable CUDA graphs (for debugging) |
| `kvcache_block_size` | 256 | KV cache page size (must be multiple of 256) |
| `dtype` | "auto" | Model dtype ("auto", "bfloat16", "float16") |

`SamplingParams` parameters:

| Parameter | Default | Description |
|---|---|---|
| `temperature` | 1.0 | Sampling temperature |
| `max_tokens` | 64 | Maximum completion length |
| `block_length` | 4 | Block size (must match LLM) |
| `denoising_steps` | 4 | Denoising iterations per block |
| `remasking_strategy` | "low_confidence_static" | Strategy for token commitment |
| `dynamic_threshold` | 0.75 | Confidence threshold for dynamic/pmax strategies |
| `topk` | 0 | Top-k filtering (0 = disabled) |
| `topp` | 1.0 | Top-p (nucleus) filtering |
| `repetition_penalty` | 1.0 | Repetition penalty |
| `eb_threshold` | 0.35 | Entropy budget for `entropy_bounded` strategy |
| `pos_temp_slope` | 0.0 | Position-temperature: T_i = T × (1 + slope × i/(L-1)) |
| `stop_words` | None | Token IDs that trigger sequence termination |
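
The `pos_temp_slope` formula is easier to read with numbers plugged in. The sketch below assumes that `i` is the zero-based position inside a block and `L` is the block length; the table does not define either symbol, so treat that reading as an assumption. `position_temperatures` is a hypothetical helper, not part of the JetEngine API.

```python
def position_temperatures(base_temp, slope, block_length):
    """Evaluate T_i = T * (1 + slope * i / (L - 1)) for i = 0 .. L-1.

    Assumption: i is the position within the block and L the block length.
    A single-position block has no slope to apply, so it keeps base_temp.
    """
    if block_length == 1:
        return [base_temp]
    return [
        base_temp * (1 + slope * i / (block_length - 1))
        for i in range(block_length)
    ]


if __name__ == "__main__":
    # With T=1.0, slope=0.5, L=4 the temperature rises linearly from 1.0 to 1.5.
    print(position_temperatures(1.0, 0.5, 4))
    # The default slope of 0.0 leaves every position at the base temperature.
    print(position_temperatures(1.0, 0.0, 4))
```

Under this reading, a positive slope samples later positions in a block at a higher temperature than earlier ones, and the default of 0.0 disables the effect.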

Supported models:

| Model | Type | `mask_token_id` | `block_length` |
|---|---|---|---|
| SDAR-1.7B/4B-Chat | Dense | 151669 | 4 |
| SDAR-MoE | MoE | 151669 | 4 |
| LLaDA-8B-Instruct | Dense | 126336 | 1024 |
| dLLM-Var | Dense | 126336 | 64 |

SDAR:

```python
from jetengine import LLM, SamplingParams

llm = LLM("SDAR-4B-Chat", mask_token_id=151669, block_length=4)
sp = SamplingParams(temperature=1.0, max_tokens=4096, block_length=4,
                    denoising_steps=4, remasking_strategy="dynamic_pmax",
                    dynamic_threshold=0.75)
```

LLaDA:

```python
from jetengine import LLM, SamplingParams

llm = LLM("LLaDA-8B-Instruct", mask_token_id=126336, block_length=1024,
          gpu_memory_utilization=0.9)
sp = SamplingParams(temperature=1.0, max_tokens=2048, block_length=1024,
                    denoising_steps=1024, remasking_strategy="low_confidence_dynamic",
                    dynamic_threshold=0.90)
```

Tip: Set `block_length` > prompt length for pure diffusion mode. With `block_length` < prompt length, JetEngine uses prefill + block diffusion (hybrid mode with interesting behaviors).

dLLM-Var:

```python
from jetengine import LLM, SamplingParams

llm = LLM("dLLM-Var", mask_token_id=126336, block_length=64,
          gpu_memory_utilization=0.9)
sp = SamplingParams(temperature=1.0, max_tokens=2048, block_length=64,
                    denoising_steps=64, remasking_strategy="low_confidence_dynamic",
                    dynamic_threshold=0.90)
```

Architecture:

```
LLM (llm.py)
 └─ LLMEngine (engine/llm_engine.py)
     ├─ ModelRunner (engine/model_runner.py)
     │   ├─ SDAR / SDAR-MoE / LLaDA model
     │   ├─ Flash Attention 3 / flashinfer / FA2 (layers/attention.py)
     │   ├─ CUDA Graph capture & replay (bs=1..128)
     │   └─ Paged KV Cache (kvcache_block_size=256)
     ├─ Scheduler (engine/scheduler.py)
     │   ├─ Block Manager: allocate/deallocate KV cache pages
     │   ├─ postprocess_unify(): batched sampling + strategy dispatch
     │   │   ├─ Dense path (all masked, step 0)
     │   │   ├─ Sparse path (partial masks, steps 1+)
     │   │   └─ Position fast-path (sequential/anti_sequential)
     │   └─ Chain mechanism: up to 5 steps per step() call
     └─ DistributedManager: hybrid DP + TP via accelerate
```

Mode 1: Ideal Decode (total_seqs ≤ max_active) — all sequences prefill together, denoise together, drain together. Chain=5 (full block per step). Maximum GPU utilization.
Mode 2: Streaming Decode (total_seqs > max_active) — sequences stream through via generate_streaming(). Prompts are automatically interleaved for batch diversity. Adaptive chain depth (3–5) based on pending queue.
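
The difference between the two modes can be illustrated with a toy simulation of a fixed active window. This is only a sketch under simplifying assumptions: `decode_mode` and `simulate_window` are hypothetical names, each sequence is reduced to a count of scheduler steps it needs (at least 1), and prompt interleaving and adaptive chain depth are not modeled.

```python
from collections import deque


def decode_mode(total_seqs, max_active):
    """Mode 1 when everything fits in the active window, otherwise Mode 2."""
    return "ideal" if total_seqs <= max_active else "streaming"


def simulate_window(steps_needed, max_active):
    """Run sequences through a fixed-size active window.

    steps_needed: toy per-sequence count of scheduler steps to finish (>= 1)
    Returns (total scheduler steps, peak number of active sequences).
    """
    pending = deque(steps_needed)
    active = []
    total_steps = 0
    peak = 0
    while pending or active:
        # Admit pending sequences until the window is full.
        while pending and len(active) < max_active:
            active.append(pending.popleft())
        peak = max(peak, len(active))
        # One scheduler step advances every active sequence; finished ones leave.
        active = [remaining - 1 for remaining in active if remaining - 1 > 0]
        total_steps += 1
    return total_steps, peak


if __name__ == "__main__":
    work = [3, 1, 2, 2]
    for window in (4, 2):
        print(decode_mode(len(work), window), simulate_window(work, window))
```

With a window of 4 all four toy sequences run together and finish in 3 steps. With a window of 2 they stream through, which takes 5 steps but never holds more than 2 sequences at once.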
| Opt | Description | pass@1 | tok/s (H200) |
|---|---|---|---|
| baseline | Original | 0.683 | ~1,500 |
| opt17 | CUDA graph 1-128, selective logits | — | 2,677 |
| opt18 | flashinfer paged attention | — | 4,109 |
| opt22 | Lazy entropy in chain intermediate | 0.709 | 4,242 |
| opt23 | Sparse logits (LM head only masked) | 0.711 | 4,268 |
| opt24 | Flash Attention 3 (Hopper SM90) | 0.722 | 5,677 |
| opt25 | Chain=3 for pending prefills + FA3 prefill | 0.705 | 5,677 |
| fix | Gen64 quality fix (multinomial + interleave) | 0.685 | — |
| opt26 | TP support fixes + sequential fast-path | 0.666 | +27% seq |
| opt27 | `dynamic_pmax` strategy (P(argmax) threshold) | 0.696 | — |
| opt28 | Threshold tuning (0.9 → 0.75) | 0.698 | — |

For pure diffusion models (LLaDA, dLLM-Var), the logits tensor scales with batch × context_length × vocab_size and can exhaust GPU memory. Mitigations:
- Decrease `max_num_seqs` in `LLM()` initialization
- Decrease `max_active` in `generate_streaming()`
- Use `gpu_memory_utilization=0.9` or lower
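
A quick back-of-the-envelope calculation shows why lowering the batch helps. The numbers below are illustrative assumptions, not measurements: the vocabulary size of 128,000 is a made-up round figure, and logits are assumed to be stored in a 2-byte dtype such as bfloat16 (float32 would double every result). `logits_bytes` and `to_gib` are hypothetical helpers.

```python
def logits_bytes(batch, context_length, vocab_size, bytes_per_element=2):
    """Size of a dense logits tensor of shape (batch, context_length, vocab_size)."""
    return batch * context_length * vocab_size * bytes_per_element


def to_gib(num_bytes):
    """Convert bytes to GiB (2**30 bytes)."""
    return num_bytes / 2**30


if __name__ == "__main__":
    context_length = 2048    # assumed context length
    vocab_size = 128_000     # hypothetical round vocabulary size
    for batch in (128, 16):
        size = to_gib(logits_bytes(batch, context_length, vocab_size))
        print(f"batch={batch}: {size:.2f} GiB of logits")
```

Under these assumptions a batch of 128 needs 62.5 GiB for logits alone, while a batch of 16 needs about 7.8 GiB. Because the size is linear in the batch dimension, reducing `max_num_seqs` or `max_active` is an effective lever.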
For issues or inquiries:
- Yihan Bian, University of Maryland, College Park — ybian@umd.edu
- GitHub Issues: Labman42/JetEngine