20 State Space Models: Mamba & Beyond
Who this chapter is for: Mid / FDE
What you’ll be able to answer after reading this:
- Why transformers have a fundamental O(n²) cost and what SSMs do differently
- The mathematical structure of SSMs: linear recurrence, discretization, and structured state spaces
- What makes Mamba’s selective scan fundamentally different from prior SSMs
- How parallel scan enables SSMs to train as fast as transformers despite recurrent structure
- The current practical boundaries: where SSMs win and where transformers still dominate
20.1 The Transformer Bottleneck
The O(n²) complexity of self-attention is not an implementation accident — it is a fundamental property of the computation. Attention computes pairwise relationships between every position pair in the sequence: the QK^T matrix multiplication produces an n×n matrix per head, where n is sequence length. Computing this matrix is O(n²d) time and storing it requires O(n²) memory. For n=32,768 (32k context), a single head’s attention matrix is roughly 2GB at FP16, before considering batch size, heads, or layers. At n=131,072 (128k), it is roughly 34GB per head — impossible to materialize in full on a single GPU once heads and layers are counted. FlashAttention avoids storing the full matrix in HBM, but the compute complexity remains O(n²). Doubling context length quadruples attention compute, not doubles it, making very long contexts prohibitively expensive even with memory-efficient implementations.
The KV cache compounds this at inference time. As the model generates tokens autoregressively, every new token must attend to all previous tokens. The KV cache stores the key and value projections from all previous positions: its size scales as 2 × n × num_heads × head_dim × num_layers × 2 bytes (FP16). For a 32-layer model with 32 heads of dimension 128, a single 32k-token sequence requires approximately 2 × 32768 × 32 × 128 × 32 × 2 ≈ 16GB for the KV cache alone. Serving a large batch of long-context requests means the KV cache dominates GPU memory allocation, directly limiting throughput. The memory problem is linear in context length, but the compute problem is quadratic — both grow unsustainably as you scale context.
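As a quick sanity check on these numbers, here is a sketch of the sizing arithmetic; the 32-layer, 32-head, 128-dim shapes are the illustrative configuration used in this section, not a specific released model.

    def kv_cache_bytes(seq_len, num_layers=32, num_heads=32, head_dim=128, bytes_per_elem=2):
        # two tensors (K and V) per layer, each of shape [seq_len, num_heads, head_dim], FP16
        return 2 * seq_len * num_heads * head_dim * num_layers * bytes_per_elem

    for n in (32_768, 100_000, 500_000):
        print(f"{n:>7} tokens -> {kv_cache_bytes(n) / 1e9:.1f} GB of KV cache per sequence")
    #  32768 tokens -> 17.2 GB   (~16 GiB, the figure quoted above)
    # 100000 tokens -> 52.4 GB
    # 500000 tokens -> 262.1 GB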
Recurrent architectures (RNNs, LSTMs) solve the memory problem trivially: they maintain a fixed-size hidden state that summarizes all past context, regardless of sequence length. Inference is O(1) per token in memory and O(d²) in compute, independent of sequence length. The fatal flaw is training: because each step depends on the previous, backpropagation through time (BPTT) must process steps sequentially, making training O(n) sequential operations. This prevents parallelism across the sequence dimension — GPUs and TPUs are optimized for parallel matrix operations, not sequential chains. LSTM training on a 32k-length sequence requires 32k sequential recurrent steps, which is catastrophically slow compared to transformers that process all positions in parallel. This is why RNNs were replaced despite their inference-time advantages: the training bottleneck made them non-viable for large-scale pretraining.
SSMs thread the needle: they are recurrent in structure (constant-memory inference) but can be reformulated as a global convolution that admits fully parallel training. The recurrence h_{t} = A h_{t-1} + B x_t defines the state evolution; but if the parameters are time-invariant (fixed A, B, C matrices), the entire sequence output can be computed as a single convolution y = x ★ K, where K is the SSM’s convolution kernel derived analytically from the A, B, C matrices. This global convolution runs as a single parallel operation over the full sequence, enabling the same GPU parallelism as transformer training. At inference, you switch back to the recurrent form and process one token at a time with O(1) memory. This dual representation — parallel for training, recurrent for inference — is the architectural innovation that makes SSMs computationally attractive.
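A minimal NumPy sketch of this equivalence for an already-discretized, single-channel LTI SSM; all shapes and values below are illustrative, not taken from any particular model.

    import numpy as np

    rng = np.random.default_rng(0)
    N, L = 4, 16                                    # state size, sequence length
    A_bar = np.diag(rng.uniform(0.5, 0.95, N))      # stable diagonal transition
    B_bar = rng.normal(size=(N, 1))
    C = rng.normal(size=(1, N))
    x = rng.normal(size=L)

    # 1) Recurrent form: O(1) state, one step per token (inference mode).
    h = np.zeros((N, 1))
    y_rec = []
    for t in range(L):
        h = A_bar @ h + B_bar * x[t]
        y_rec.append((C @ h).item())

    # 2) Convolutional form: kernel K[k] = C @ A_bar^k @ B_bar, applied to the
    #    whole sequence at once (training mode).
    K = np.array([(C @ np.linalg.matrix_power(A_bar, k) @ B_bar).item() for k in range(L)])
    y_conv = np.array([np.dot(K[:t + 1][::-1], x[:t + 1]) for t in range(L)])

    assert np.allclose(y_rec, y_conv)               # same outputs, two computation modes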
20.2 SSM Fundamentals
The continuous-time SSM is defined by two equations: h’(t) = Ah(t) + Bx(t) for state evolution, and y(t) = Ch(t) for output. Here h(t) ∈ ℝ^N is the hidden state, x(t) ∈ ℝ is a scalar input (or a channel of a multi-channel input), A ∈ ℝ^{N×N} is the state transition matrix, B ∈ ℝ^{N×1} is the input projection, and C ∈ ℝ^{1×N} is the output projection. This is exactly a linear time-invariant (LTI) system from control theory. The parameters A, B, C are fixed matrices shared across all time steps — they do not depend on the input. This time-invariance is both the SSM’s efficiency advantage and its limitation before Mamba.
Discretization converts the continuous-time equations to a discrete recurrence suitable for processing token sequences. The zero-order hold (ZOH) method is the standard: given a step size Δ (the timescale parameter), the discretized matrices are Ā = exp(ΔA) and B̄ = (ΔA)^{-1}(exp(ΔA) − I) ΔB. In practice, with diagonal A, exp(ΔA) is computed element-wise and efficiently. The discrete recurrence is h_t = Āh_{t-1} + B̄x_t and y_t = Ch_t, with Ā and B̄ derived from A, B, Δ. The timescale parameter Δ controls how much “memory” the state retains: because A is stable (negative eigenvalues), a small Δ keeps Ā close to the identity, so the state persists and captures longer-range dependencies, while a large Δ makes the state decay quickly and emphasizes the current input and local context. In pre-Mamba SSMs, Δ is a learned scalar or vector but still input-independent.
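A small sketch of ZOH discretization for a diagonal A, which also makes the Δ behavior concrete; the values are illustrative.

    import numpy as np

    def discretize_zoh(A_diag, B, delta):
        """A_diag: (N,) diagonal of A; B: (N,) input projection; delta: scalar step size."""
        A_bar = np.exp(delta * A_diag)                          # exp(ΔA), element-wise for diagonal A
        B_bar = (A_bar - 1.0) / (delta * A_diag) * (delta * B)  # (ΔA)^{-1}(exp(ΔA) - I) ΔB
        return A_bar, B_bar

    A_diag = -np.arange(1.0, 5.0)     # stable (negative) diagonal entries
    B = np.ones(4)
    for delta in (0.01, 1.0):
        A_bar, _ = discretize_zoh(A_diag, B, delta)
        print(f"delta={delta}: per-step decay factors {np.round(A_bar, 3)}")
    # small delta -> factors near 1: the state persists (long memory)
    # large delta -> factors near 0: the state decays fast (local focus)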
The structured state space design addresses the question: what should A be? An unconstrained N×N matrix is expensive to compute exp(ΔA) for and may have poor gradient flow. The HiPPO (High-order Polynomial Projection Operator) framework provides a theoretically motivated initialization for A that gives the SSM a principled mechanism for memorizing polynomial projections of its input history. Concretely, the HiPPO matrix is structured such that the hidden state h_t optimally approximates the history of x up to time t as a projection onto a set of basis polynomials (Legendre, Fourier, etc.). This is why SSMs initialized with HiPPO can theoretically memorize arbitrary-length history — the state tracks the continuous function of past inputs, not just a lossy summary. In practice, S4 (Structured State Space Sequence Model) uses diagonal-plus-low-rank (DPLR) approximations of HiPPO-initialized A, enabling efficient computation while retaining the theoretical memory properties.
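For concreteness, one common HiPPO variant (HiPPO-LegS) can be written down directly; treat this as an illustrative construction following the S4 literature, not a drop-in replica of any specific library’s initializer.

    import numpy as np

    def hippo_legs(N):
        """HiPPO-LegS transition matrix (negated so the SSM dynamics are stable)."""
        A = np.zeros((N, N))
        for n in range(N):
            for k in range(N):
                if n > k:
                    A[n, k] = np.sqrt(2 * n + 1) * np.sqrt(2 * k + 1)
                elif n == k:
                    A[n, k] = n + 1
        return -A

    print(hippo_legs(4))   # lower-triangular structure; S4 approximates it as diagonal-plus-low-rank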
Parallel scan is the algorithmic trick that enables efficient training of SSMs without sequential computation. The recurrence h_t = Ā h_{t-1} + B̄ x_t is a first-order linear recurrence. The parallel prefix scan algorithm computes all h_t values in O(log n) sequential rounds: the sequential depth drops from O(n) to O(log n), at the cost of O(n log n) total work in the simple formulation (work-efficient variants bring this back to O(n)). The algorithm proceeds by computing partial compositions in a tree structure: first pair adjacent elements, then combine pairs, then combine groups of 4, and so on, doubling the covered range at each level. On GPU hardware with many parallel processors, all operations at each level execute simultaneously, giving effective O(log n) wall-clock time instead of O(n). For n=32,768, this reduces 32,768 sequential steps to just 15 rounds, enabling training that is competitive with transformer training speeds despite the recurrent formulation.
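A sequential-versus-parallel sketch of this scan for the scalar recurrence h_t = a_t h_{t-1} + b_t, in a simple Hillis-Steele-style formulation; values are illustrative.

    import numpy as np

    def combine(left, right):
        # compose two affine maps h -> a*h + b; `left` is applied first, then `right`
        a_l, b_l = left
        a_r, b_r = right
        return a_r * a_l, a_r * b_l + b_r

    def parallel_linear_scan(a, b):
        a, b = a.astype(float).copy(), b.astype(float).copy()
        n, stride = len(a), 1
        while stride < n:                           # log2(n) rounds
            a_prev, b_prev = a.copy(), b.copy()
            for t in range(stride, n):              # on a GPU, this inner loop runs in parallel
                a[t], b[t] = combine((a_prev[t - stride], b_prev[t - stride]),
                                     (a_prev[t], b_prev[t]))
            stride *= 2
        return b                                    # with h_0 = 0, the accumulated b is exactly h_t

    rng = np.random.default_rng(0)
    a, b = rng.uniform(0.5, 1.0, 8), rng.normal(size=8)

    h, h_seq = 0.0, []                              # sequential reference: 8 dependent steps
    for t in range(8):
        h = a[t] * h + b[t]
        h_seq.append(h)

    assert np.allclose(parallel_linear_scan(a, b), h_seq)   # 3 rounds instead of 8 steps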
20.3 Mamba’s Selective State Space
The central limitation of pre-Mamba SSMs is input-independence. The matrices A, B, C, Δ are learned parameters but apply identically to every token in every sequence — the SSM computes the same transformation regardless of the content of the input. This means an SSM cannot selectively “forget” irrelevant context or “sharpen” its attention on salient inputs. It processes a boring filler word and a critical keyword with the same transition matrix. This is adequate for tasks that depend on the aggregate statistics of a sequence (like many classification tasks) but insufficient for tasks requiring selective recall of specific content — precisely the tasks where transformers excel through attention.
Mamba’s key innovation is making the SSM parameters input-dependent: B, C, and Δ are computed as functions of the current input x_t. Concretely, B_t = LinearB(x_t), C_t = LinearC(x_t), and Δ_t = softplus(LinearΔ(x_t) + parameter_bias). The A matrix remains structured (fixed diagonal structure with learned parameters, not input-dependent) for computational tractability, but B, C, and Δ varying per token gives the model the ability to selectively update its state and selectively read from it. When Δ_t is large for a particular token, the state updates strongly toward that token (a strong memory write that displaces older content); when Δ_t is small, the state largely ignores the input and persists its existing contents. When C_t is in a direction aligned with the current state, the output strongly reflects that state component; when misaligned, that state component is suppressed. The combination of selective write (via Δ and B) and selective read (via C) gives Mamba associative recall capability that pure LTI SSMs lack.
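A heavily simplified sequential reference for the selective update, using one scalar input channel and illustrative shapes and initializations; the real Mamba layer wraps this in input/output projections, a gating branch, and the fused kernel described next.

    import numpy as np

    def softplus(z):
        return np.log1p(np.exp(z))

    rng = np.random.default_rng(0)
    d_state, L = 16, 32
    A_diag = -np.exp(rng.normal(size=d_state))   # fixed structured A: diagonal, negative
    W_B = rng.normal(size=d_state) * 0.1         # per-token projection for B_t (illustrative)
    W_C = rng.normal(size=d_state) * 0.1         # per-token projection for C_t (illustrative)
    w_dt, b_dt = 0.1, 0.5                        # per-token projection and bias for Δ_t

    x = rng.normal(size=L)
    h = np.zeros(d_state)
    y = np.zeros(L)
    for t in range(L):
        B_t = W_B * x[t]                         # selective: parameters depend on x_t
        C_t = W_C * x[t]
        dt = softplus(w_dt * x[t] + b_dt)        # Δ_t > 0
        A_bar = np.exp(dt * A_diag)              # per-token ZOH discretization (diagonal A)
        B_bar = (A_bar - 1.0) / A_diag * B_t
        h = A_bar * h + B_bar * x[t]             # selective write
        y[t] = C_t @ h                           # selective read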
The hardware-aware parallel scan algorithm in Mamba addresses a practical obstacle: the selective (input-dependent) SSM cannot precompute a single convolution kernel because the kernel changes with every input. Standard parallel scan would require materializing intermediate state tensors of shape (batch, seq_len, d_state), which at large sequence lengths exceeds GPU SRAM capacity and requires slow HBM reads. Mamba’s hardware-aware algorithm reorders computation to keep intermediate states in SRAM by fusing the scan with kernel computation: it tiles the sequence into chunks that fit in SRAM, processes each chunk entirely in registers and SRAM, and writes only the final chunk state to HBM. This is conceptually analogous to FlashAttention’s tiling approach — both avoid materializing large intermediate tensors in slow memory. The result is that Mamba’s selective scan runs at near-memory-bandwidth-bound speed despite the input-dependent parameters, making it practical to train on sequences of millions of tokens.
Inference characteristics are where Mamba’s advantage is most pronounced. At inference time, Mamba switches from parallel scan to pure recurrence: each new token updates the state h_t = Ā_t h_{t-1} + B̄_t x_t and produces an output y_t = C_t h_t. This requires O(1) memory per layer — a constant-size state of d_state × d_model elements, independent of sequence length. For a Mamba layer with d_model=1024 and d_state=16, that is 16,384 state elements, roughly 32KB at FP16 — contrasted with a transformer KV cache that grows linearly. At 100k-token context, a comparable transformer’s KV cache runs to tens of gigabytes; the Mamba state is still a few tens of kilobytes per layer. This makes Mamba practically advantageous for ultra-long streaming inference where the transformer KV cache would exhaust GPU memory long before the sequence ends.
20.4 Hybrid Models and Practical Boundaries
Hybrid architectures that combine attention and SSM layers address the complementary weaknesses of each. Pure SSMs excel at long-range sequential processing and streaming inference but struggle with precise in-context retrieval — associative recall of specific facts from long context (e.g., “what was the exact value mentioned in paragraph 3?”) is harder for SSMs than for transformers. Pure transformers excel at precise attention-based retrieval but fail at ultra-long context due to O(n²) cost. Hybrid models interleave attention layers (for precise retrieval) with SSM layers (for efficient long-range context compression), getting the best of both.
Jamba (AI21 Labs, 2024) combines Mamba layers with transformer attention layers and MoE FFN layers in a single architecture. The design rationale: attention layers handle precise retrieval tasks where Mamba underperforms, while Mamba layers handle long-range context integration efficiently. MoE FFN layers provide high model capacity without proportional compute cost. Jamba-1.5-Large achieved competitive quality with transformers of similar compute budget while supporting 256k context windows without the memory explosion that would afflict a pure transformer. The interleaving ratio (how many attention vs. Mamba layers) is a hyperparameter: more attention layers improve quality on retrieval-heavy tasks at higher memory cost; more Mamba layers reduce memory and improve throughput at some quality cost.
RWKV (Receptance Weighted Key Value) takes a different hybrid approach, reformulating attention as a linear recurrence without the softmax. The core idea: standard attention output is a weighted sum of values, with weights computed via softmax(QK^T/√d). RWKV approximates this with a linear recurrence where each state update is a running numerator-denominator pair tracking the weighted sum of values. This gives RWKV transformer-like parallelism during training (via parallel prefix scan on the recurrence) and RNN-like O(1) memory during inference. RWKV forgoes the full O(n²) attention capability, which limits precision on long-range associative recall but eliminates the quadratic cost. RWKV models up to 14B parameters have been trained and perform comparably to transformers on most benchmarks but fall behind on tasks requiring precise long-range recall.
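A simplified sketch of the running numerator/denominator idea follows; it deliberately omits RWKV’s per-channel decays and the bonus term for the current token, and shows only the O(1)-state pattern, not RWKV’s exact formulation.

    import numpy as np

    rng = np.random.default_rng(0)
    d, L = 8, 32
    decay = 0.9                          # scalar time decay; RWKV learns this per channel
    k = rng.normal(size=(L, d)) * 0.1    # "key" scores per token
    v = rng.normal(size=(L, d))          # values per token

    num = np.zeros(d)                    # running weighted sum of values
    den = np.zeros(d)                    # running sum of weights
    outputs = []
    for t in range(L):
        w = np.exp(k[t])                 # positive weight for this token
        num = decay * num + w * v[t]
        den = decay * den + w
        outputs.append(num / den)        # normalized read; cost and memory are the same at every t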
The transformer still dominates production deployment for three practical reasons. First, the pretraining data advantage: virtually all large-scale foundation models trained at frontier scale (70B+) are transformers, meaning available pretrained weights, instruction-tuned checkpoints, and fine-tuning ecosystems are transformer-centric. SSM pretraining infrastructure is less mature. Second, quality on reasoning and retrieval benchmarks: at equivalent scale and compute, Mamba-style models show quality gaps on tasks involving in-context learning and associative recall, which are important for enterprise use cases. Third, tooling: vLLM, TGI, TensorRT-LLM, and the rest of the serving optimization stack are built and optimized for transformer architectures, and SSM serving infrastructure is nascent. The gap is closing, and SSMs are compelling for specific use cases (ultra-long streaming context, edge deployment), but the practical choice for most production deployments in 2024-2025 remains transformers with efficient attention implementations.
20.5 Interview Questions
Q1. What is the key computational advantage of SSMs over transformers?
SSMs have O(n) training complexity (with parallel scan) and O(1) inference memory, compared to transformers’ O(n²) attention complexity and O(n) KV cache memory. The transformer’s self-attention must compute relationships between every pair of positions in the sequence, producing an n×n attention matrix. At long sequence lengths, this becomes the dominant cost — both in FLOP count and in memory. SSMs instead maintain a fixed-size hidden state that summarizes all past context and update it with each new token. No pairwise comparison, no growing cache.
The practical implication is dramatic at long context: a transformer processing a 1M-token sequence requires an astronomically large attention matrix (or a very expensive FlashAttention pass), while an SSM processes the same sequence token-by-token with the same fixed state size throughout. This makes SSMs attractive for streaming inference applications (IoT, real-time audio/video processing) and ultra-long document processing where transformer KV caches would exhaust GPU memory.
Q2. Why can’t you just use an RNN instead of a transformer for long sequences?
RNNs have the same inference-time advantage as SSMs — O(1) memory per token generated — but they have a catastrophic training disadvantage. Because each hidden state depends on the previous hidden state, backpropagation through time (BPTT) must unroll the network through every time step sequentially. For a 32k-token sequence, that is 32k sequential operations during the backward pass. GPUs are designed for massively parallel computation; sequential dependencies eliminate almost all parallelism. Training an RNN on a 32k-sequence dataset is orders of magnitude slower than training a transformer on the same data.
SSMs escape this trap by exploiting the mathematical structure of linear recurrences. When the state transition is linear (h_t = A h_{t-1} + B x_t, with no nonlinearity inside the recurrence), the entire sequence of hidden states can be computed via parallel prefix scan in O(log n) sequential steps using O(n) processors in parallel. This parallel scan formulation gives SSMs transformer-like training speed while retaining the RNN’s inference-time efficiency. RNNs and LSTMs place nonlinear activations and gates inside the recurrence (the gates exist to combat vanishing gradients), so their state updates do not compose as simple affine maps and cannot be parallelized this way; this is why they could not compete with transformers for large-scale pretraining.
Q3. What makes Mamba “selective” compared to earlier SSMs?
Earlier SSMs (S4, S5, H3) use fixed matrices A, B, C that are learned during training but applied identically to every input token at inference. The SSM computes the same state transition regardless of whether the current token is a critical keyword or an irrelevant filler word. This input-independence means the SSM cannot focus on or ignore specific parts of the input — it processes everything with the same transformation.
Mamba makes the B, C, and Δ parameters input-dependent: for each token x_t, Mamba computes B_t = Linear(x_t), C_t = Linear(x_t), and Δ_t = softplus(Linear(x_t)). The Δ parameter controls the state update rate — when Δ_t is large, the new token’s information strongly updates the state (high write strength); when small, the state mostly ignores the input. C_t controls what aspect of the state is read for the current output. This mechanism allows Mamba to selectively memorize relevant tokens, selectively forget irrelevant tokens, and selectively attend to specific state components — approximating the discriminative behavior of attention through selective state updates and reads, but with O(1) inference memory instead of a growing KV cache.
Q4. Explain the parallel scan trick that makes SSMs trainable efficiently.
A first-order linear recurrence h_t = a_t h_{t-1} + b_t (where a_t and b_t may vary per step) looks sequential — each value depends on the previous. But the parallel prefix scan algorithm computes all n values in O(log n) sequential rounds using O(n) parallel processors.
The key insight: any prefix of the recurrence can be collapsed into an affine function h_t = A_{t:0} h_0 + B_{t:0}, where A and B are accumulated transition coefficients. Two affine functions compose as (A_2, B_2) ∘ (A_1, B_1) = (A_2 A_1, A_2 B_1 + B_2). This composition is associative, enabling tree-structured parallel computation. In round 1, compute pairwise compositions for all adjacent pairs (n/2 operations in parallel). In round 2, compose results with stride-2 steps (n/4 operations in parallel). Continue for log₂(n) rounds. Every h_t is computable independently once you have the prefix accumulated up to position t.
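A tiny numeric check of the associativity claim, using the same composition convention as the scan sketch in 20.2 (equal up to floating-point rounding):

    import numpy as np

    def compose(left, right):
        (a1, b1), (a2, b2) = left, right
        return a2 * a1, a2 * b1 + b2      # apply `left` first, then `right`

    f1, f2, f3 = (0.9, 0.1), (0.5, -0.2), (1.1, 0.3)
    assert np.allclose(compose(compose(f1, f2), f3),
                       compose(f1, compose(f2, f3)))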
On GPU hardware, all operations within a round execute simultaneously. For n=32,768, naive sequential scan requires 32,768 sequential steps; parallel scan requires only 15 sequential rounds of decreasing-size parallel operations. The total work is O(n log n) vs. O(n) for sequential, but the wall-clock time advantage is enormous on parallel hardware. Mamba’s hardware-aware implementation further fuses the scan with the selective parameter computation, keeping intermediate states in SRAM rather than writing them to HBM, achieving near-peak memory bandwidth utilization.
Q5. Compare Mamba’s inference characteristics (memory, speed) to a transformer at 100k token context.
At 100k token context, the comparison is stark. A transformer must maintain a KV cache containing the K and V projections for all 100,000 previous tokens. For a 32-layer transformer with 32 heads of dimension 128, this is 2 × 100,000 × 32 × 128 × 32 × 2 bytes ≈ 52GB per sequence. A single 100k-context sequence therefore consumes most of an A100 80GB GPU in KV cache alone, before considering weights or activations. Batching multiple such sequences is essentially impossible without significant compression.
A Mamba model of equivalent size maintains a fixed state per layer: for Mamba-3B with hidden dimension 2,560 and state size 16, the state per layer is 2,560 × 16 × 4 bytes ≈ 160KB. Across 32 layers, total state is ~5MB, completely independent of context length. At 100k tokens, Mamba uses 5MB of recurrent state; the transformer needs 52GB of KV cache. Memory is no longer the binding constraint for Mamba’s long-context inference.
Decode speed: Mamba’s per-token decode requires one matrix-vector multiply per layer (state update) — the same operation regardless of context length. Transformer decode requires attending over all previous KV positions, with compute growing linearly with context. At 100k context, transformer decode is noticeably slower per token than at 1k context; Mamba decode is identical. The tradeoff is quality: transformers can precisely retrieve any specific token from the 100k-long KV cache; Mamba’s state is a lossy compression that may lose specific details from much earlier in the sequence.
Q6. What is Jamba and why did Jamba combine attention with Mamba?
Jamba (AI21 Labs, 2024) is a hybrid transformer-Mamba-MoE architecture that interleaves standard multi-head attention layers with Mamba SSM layers, and replaces the dense FFN with an MoE expert bank in a subset of layers. The published design groups layers into blocks with roughly one attention layer for every seven Mamba layers, and applies MoE FFNs in every other layer.
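A sketch of what such an interleaving schedule looks like; the layer counts and ratios here are hypothetical, chosen for illustration rather than copied from Jamba’s configuration.

    def hybrid_schedule(num_layers=32, attn_every=8, moe_every=2):
        """Return a (mixer, ffn) type per layer for a hypothetical hybrid stack."""
        layers = []
        for i in range(1, num_layers + 1):
            mixer = "attention" if i % attn_every == 0 else "mamba"
            ffn = "moe" if i % moe_every == 0 else "dense"
            layers.append((mixer, ffn))
        return layers

    for idx, layer in enumerate(hybrid_schedule()[:8], start=1):
        print(idx, layer)    # e.g. layer 8 uses attention; even layers use MoE FFNs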
The motivation for combining attention with Mamba addresses the complementary failure modes of each. Pure Mamba models underperform transformers on tasks requiring precise in-context retrieval — when the model must retrieve a specific fact mentioned thousands of tokens ago, the Mamba state may have compressed it into a lossy representation. Attention layers provide exact retrieval capability: they can attend directly to any position in the sequence, regardless of how it has been compressed by the SSM state. Pure transformers cannot scale to extremely long contexts due to O(n²) attention cost and O(n) KV cache growth. Mamba layers handle long-range context integration efficiently, compressing distant context into a fixed-size state.
The interleaving serves both goals: Mamba layers efficiently propagate and summarize long-range context; the periodic attention layers provide precise retrieval from the most recently attended context. Jamba-1.5-Large demonstrated that this hybrid achieves competitive benchmark quality with transformer-only models while supporting 256k context windows at a fraction of the memory cost, validating the complementarity hypothesis.
Q7. A customer has a use case processing 500k-token documents in real-time streaming. Would you recommend Mamba or transformer?
For 500k-token streaming document processing, Mamba or a Mamba-hybrid is the technically correct recommendation with important caveats about production readiness.
The transformer is not viable without massive infrastructure: 500k-token context requires a KV cache of approximately 260GB per sequence (for a 32-layer, 32-head transformer at FP16), which cannot fit on any available GPU and would require aggressive quantization and offloading strategies that add hundreds of milliseconds of latency. Even with FlashAttention-3, the O(n²) compute cost at 500k tokens is prohibitive for real-time requirements.
Mamba’s 500k-token inference requires the same fixed state size as 1k-token inference — a fundamental advantage. For streaming (left-to-right processing as tokens arrive), Mamba’s recurrent mode is exactly the right paradigm: process each token in O(1) memory and O(d²) compute, accumulate state, produce output, move to next token. Per-token latency is constant regardless of how many tokens have been seen.
However, important caveats for enterprise deployment in 2025: production-grade serving infrastructure for Mamba is less mature than for transformers. vLLM’s support for SSM and hybrid architectures is newer and less battle-tested than its transformer serving paths. Fine-tuning tooling (PEFT, Axolotl, etc.) is primarily transformer-focused. The quality of available Mamba checkpoints at the scale needed for complex document understanding lags behind frontier transformer models. Jamba-1.5-Large is the most production-ready option, offering 256k native context with transformer-competitive quality and better tooling than pure Mamba.
My recommendation: prototype with Jamba for up to 256k context or a Mamba-hybrid for 500k; plan for a transition to more mature SSM infrastructure as the ecosystem develops in 2025-2026.
Q8. What are the current gaps in SSM model quality vs. transformers at equivalent scale?
The quality gaps between SSMs and transformers at equivalent scale (matching total parameters and training tokens) cluster around three capability areas as of 2025.
In-context learning (ICL) and few-shot prompting: transformers learn from examples provided in the prompt by attending directly to them during generation. SSMs must compress prompt examples into the recurrent state, and the state may not preserve all the inductive signals needed for reliable few-shot learning. Mamba models show measurable ICL capability but lag transformers on tasks requiring precise pattern extraction from 5-10 in-context examples, particularly when the pattern is subtle.
Precise retrieval from long context: the “needle in a haystack” style tasks — finding a specific fact embedded in a large document — remain weaker for pure Mamba vs. transformers. The recurrent state is a lossy compression and information from very early in a long sequence may be overwritten by later content. Hybrid models (Jamba, Zamba) largely close this gap by including periodic attention layers.
Complex multi-step reasoning: tasks on MATH, AIME, and competitive programming benchmarks show a systematic ~3-5 point gap between the best SSM models and equivalent-parameter transformers. The hypothesis is that multi-step reasoning requires precise “in-context scratch pad” operations that attention’s exact retrieval enables but SSM’s lossy state hinders. This gap is largest on reasoning benchmarks and smallest on language modeling perplexity.
What is not a gap: language modeling perplexity on standard pretraining distributions, instruction following quality (for well-fine-tuned models), factual QA, and summarization are essentially equivalent between SSMs and transformers at matched scale and training tokens. The practical gap is narrower than the architectural difference might suggest.