Stormin' The Castle


Sparse Attention

by John Robinson @johnrobinsn

SpargeAttention: Notes on the Sparse Attention Method from Tsinghua

Been looking into SpargeAttention from the thu-ml group (the same team behind SageAttention). It's a training-free sparse attention method that claims a 4-7x speedup over FlashAttention. Wanted to jot down how it actually works.

The Basic Idea

Attention maps are sparse in practice — lots of near-zero values that we compute anyway. SpargeAttn predicts which blocks will be near-zero and skips them entirely.
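To make that concrete, here's a rough PyTorch sketch of what block-sparse attention looks like once you have a block-level mask. This is toy code in my own naming, nothing like the real fused kernel (which never materializes a full score matrix):

import torch

def block_sparse_attention(q, k, v, block_mask, Bq=128, Bk=64):
    # q, k, v: (seq, dim) for one head; block_mask: (seq // Bq, seq // Bk) boolean
    seq, dim = q.shape
    out = torch.zeros_like(q)
    for i in range(seq // Bq):
        qi = q[i * Bq:(i + 1) * Bq]
        scores = torch.full((Bq, seq), float("-inf"))
        for j in range(seq // Bk):
            if not block_mask[i, j]:
                continue  # predicted near-zero block: skip the work entirely
            kj = k[j * Bk:(j + 1) * Bk]
            scores[:, j * Bk:(j + 1) * Bk] = qi @ kj.T / dim ** 0.5
        out[i * Bq:(i + 1) * Bq] = torch.softmax(scores, dim=-1) @ v  # skipped blocks get weight 0
    return out

# toy tensors with block-level redundancy, so the later sketches have structure to find
seq, dim = 1024, 64
base = torch.randn(seq // 128, dim).repeat_interleave(128, dim=0)
q = base + 0.3 * torch.randn(seq, dim)
k = base + 0.3 * torch.randn(seq, dim)
v = torch.randn(seq, dim)

The rest of the method is about predicting block_mask cheaply.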

How the Prediction Works

This is the interesting part. They call it "selective token compression." The attention matrix is divided into tiles (128×64 blocks in SpargeAttn's case) to align with GPU compute patterns.

Step 1: Compress ALL blocks

Every Q block and K block gets compressed to a single representative token by averaging the tokens within. This happens universally — no thresholding here.
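A minimal sketch of that compression, continuing the toy tensors above (the names are mine, not the repo's):

def compress_blocks(x, block_size):
    # (seq, dim) -> (num_blocks, dim): one averaged "representative token" per block
    seq, dim = x.shape
    return x[: seq - seq % block_size].reshape(-1, block_size, dim).mean(dim=1)

q_c = compress_blocks(q, 128)  # one token per 128-token Q block
k_c = compress_blocks(k, 64)   # one token per 64-token K block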

Step 2: Compute self-similarity scores

For each block, compute the mean cosine similarity between tokens within that block. This tells you how "uniform" the block is — can a single averaged token reliably represent all the tokens in this block?
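Roughly, continuing the sketch:

def block_self_similarity(x, block_size):
    # mean pairwise cosine similarity of the tokens inside each block
    seq, dim = x.shape
    blocks = x[: seq - seq % block_size].reshape(-1, block_size, dim)
    normed = torch.nn.functional.normalize(blocks, dim=-1)
    return (normed @ normed.transpose(1, 2)).mean(dim=(1, 2))  # diagonal included; close enough for a sketch

k_sim = block_self_similarity(k, 64)  # how "uniform" each K block is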

Step 3: Run cheap attention on compressed tokens

Compute attention on the compressed q and k. This is fast because instead of seq_len × seq_len, you're doing num_blocks × num_blocks. The output is a rough block-level attention map.

Before this step, they set Ŝ[:, j] = -∞ for any K block with low self-similarity (below threshold θ). This excludes unreliable K blocks from the prediction.
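Continuing the sketch (theta is a made-up threshold value, just for illustration):

theta = 0.7                                    # self-similarity threshold (hypothetical value)
S_hat = (q_c @ k_c.T) / q_c.shape[-1] ** 0.5   # (num_q_blocks, num_k_blocks) block-level scores
S_hat[:, k_sim < theta] = float("-inf")        # exclude unreliable K blocks from the prediction
P_hat = torch.softmax(S_hat, dim=-1)           # rough block-level attention map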

Step 4: TopK mask generation

For each row of the compressed attention map, select blocks whose cumulative probability mass reaches τ (the topk parameter). These get mask value 1, everything else gets 0.
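A sketch of that per-row cumulative selection, continuing from above:

def topcdf_mask(P_hat, tau):
    # per row: keep the smallest set of blocks whose probability mass adds up to tau
    vals, idx = P_hat.sort(dim=-1, descending=True)
    keep_sorted = (vals.cumsum(dim=-1) - vals) < tau  # include a block while the mass before it is < tau
    mask = torch.zeros_like(P_hat, dtype=torch.bool)
    return mask.scatter(-1, idx, keep_sorted)

block_mask = topcdf_mask(P_hat, tau=0.5)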

Step 5: Force-include low-similarity blocks

Here's the safety mechanism: blocks that couldn't be reliably compressed get force-included regardless of what the prediction said.
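Roughly, reusing the hypothetical names from the sketches above (I'm applying the override on both the K-block and Q-block side, which is how I read the method):

q_sim = block_self_similarity(q, 128)
block_mask[:, k_sim < theta] = True   # heterogeneous K blocks: always compute
block_mask[q_sim < theta, :] = True   # same idea on the query side
out = block_sparse_attention(q, k, v, block_mask)  # the final mask then drives the sparse kernel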

This ensures you never skip a heterogeneous block just because its compressed representation happened to score low.

Why this matters

The key insight: the similarity threshold doesn't control "compress vs. don't compress" (ALL blocks get compressed); it controls whether to trust the prediction for a given block. High-similarity blocks go through the Top-CDF selection from Step 4 normally. Low-similarity blocks bypass the prediction entirely and get computed no matter what.

In practice, low-similarity "fix blocks" are a small fraction (~2% in some experiments), so most blocks participate in the normal sparsity selection.

Second Stage: Softmax-Aware Filtering

Even after the first filter, some blocks will be dominated by neighbors after softmax. They track local vs global max during online softmax — if a block's local max is way below global max, its post-softmax contribution is negligible. These get pruned with zero overhead since it happens during the softmax pass anyway.
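In pseudocode the check looks something like this (toy numbers; lam is a made-up log-space slack, and the real kernel does this per tile inside the fused FlashAttention-style loop):

import torch

# one step of the online-softmax loop for a single Q tile
m_running = torch.tensor([6.0, 5.5])                  # running row-wise max seen so far ("global" max)
s_block = torch.tensor([[-1.0, -2.0], [-0.5, -1.5]])  # scores against the current K tile ("local")
lam = 5.0                                             # hypothetical log-space threshold
m_local = s_block.amax(dim=-1)
if bool((m_local < m_running - lam).all()):
    pass  # every weight in this tile is below exp(-lam) of the row max: skip its P @ V entirely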

Usage

If you're already on SageAttention, migration is one line:

from spas_sage_attn import spas_sage2_attn_meansim_topk_cuda
output = spas_sage2_attn_meansim_topk_cuda(q, k, v, topk=0.5, is_causal=False)

The topk parameter (0-1) controls the sparsity vs. accuracy tradeoff. At 1.0 nothing gets skipped and you're effectively just running SageAttention. At 0.5 each query block only attends to the K blocks covering roughly half of its predicted attention mass (plus any force-included ones).

For diffusion models, they've found you can push it fairly aggressively (topk around 0.3-0.5) without visible quality loss, since attention there is naturally sparse thanks to temporal/spatial redundancy.

Real-World Use: TurboDiffusion

The same group integrated this into TurboDiffusion, which combines SpargeAttn + step distillation (rCM) + W8A8 quantization. On an RTX 5090 they report roughly a 97x end-to-end speedup for the 1.3B model at 480p.



John Robinson © 2022-2025