Linear recurrent models such as Mamba and linear attention possess a remarkable feature: they can process extremely long sequences, which is key for applications that require long-context reasoning (like summarizing long texts or agents with long-term memory). Indeed, this is their key advantage over their main competitor, the Transformers, which are bottlenecked by their finite context window and their quadratic complexity in the sequence length.
Previously, the issue with recurrent models was their performance: on short sequences they were less capable than Transformers. But recent architectural breakthroughs have brought the performance of recurrent models on par with Transformers, to the point that they are now used in several industry applications such as audio modeling and code completion. However, several recent works have found that recurrent models still fall short: they may match Transformers within the training length, but in many cases they struggle to generalize beyond it.
For example, we show the performance of the official Mamba-2 checkpoints as a function of the sequence position $t$, measured in perplexity (lower is better). For positions $t$ beyond the training context $T=2048$, these models become virtually useless: they fail to length generalize.
This is an issue: existing recurrent models perform poorly on long sequences and are not much more efficient than Transformers on shorter ones, so they seem to fall short on both fronts.
Does this mean that recurrent models are useless? Not at all! In our work, we show that length generalization is easily achievable in recurrent models through simple training interventions: post-training for 500 steps (~0.1% of the pre-training budget) enables length generalization on sequences of up to 256k tokens. Therefore, recurrent models possess unrealised potential rather than a fundamental limitation.
Why Do Recurrent Models Fail to Length Generalize? The Unexplored States Hypothesis
For an input sequence with $t$ elements $(x_1, x_2, ..., x_{t-1}, x_t)$, recurrent models compress the input context $(x_1, x_2, ..., x_{t-1})$ into a fixed-size recurrent state $h_{t-1}$. At time $t=0$, the state is initialized with some value $h_{-1}$, and then it is updated at each $t$ with an update function $f$:
$ h_t = f(h_{t-1}, x_t) $
Similarly, the output at time $t$ only depends on the state $h_t$ and the current input $x_t$, i.e. for some other function $g$ the output $y_t$ can be written as
$y_t = g(h_t, x_t)$
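The two equations above can be sketched as a plain loop. This is a minimal, framework-free sketch: `rollout`, `ema_update` and `readout` are hypothetical names, and the scalar exponential moving average is only a stand-in for a real update function such as the one in a Mamba-2 layer.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

S = TypeVar("S")  # state type
X = TypeVar("X")  # input type
Y = TypeVar("Y")  # output type


def rollout(
    f: Callable[[S, X], S],
    g: Callable[[S, X], Y],
    h_init: S,
    xs: Sequence[X],
) -> Tuple[List[Y], S]:
    """Compute h_t = f(h_{t-1}, x_t) and y_t = g(h_t, x_t) for every t.

    f and g are the same at every step, so the loop never looks at t
    and works for sequences of any length.
    """
    h = h_init  # plays the role of h_{-1}
    ys: List[Y] = []
    for x in xs:
        h = f(h, x)          # state update
        ys.append(g(h, x))   # output depends only on h_t and x_t
    return ys, h


# Hypothetical toy update and readout (not an actual Mamba-2 layer).
def ema_update(h: float, x: float) -> float:
    return 0.9 * h + 0.1 * x


def readout(h: float, x: float) -> float:
    return h + x


ys, h_final = rollout(ema_update, readout, 0.0, [1.0, 1.0, 1.0])
```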
The functions $f$ and $g$ do not depend on the position $t$, so in theory recurrent models can naturally process any sequence length. But then, how can it be that they fail when $t$ is large?
In our work, we show that the distribution of the state $h_t$ changes over time. Therefore, even if $f$ and $g$ work correctly up to some $T$, the states $h_t$ with $t>T$ might be significantly different from those seen during training, and thus the model can fail to produce the correct output. Indeed, in the following figure we show how the norm of the state of Mamba-2 increases significantly over time:
This explains why recurrent models fail to length generalize: when processing sequences longer than those seen during training, they encounter states $h_t$ that have not been explored during training, and thus they have not learnt to process them. Based on this insight, we propose the unexplored states hypothesis to explain the failure to length generalize:
Unexplored States Hypothesis
Recurrent models fail to length generalize when they are trained only on a subset of all attainable state distributions, i.e. on a subset of the states that would be attained if the state recurrence were rolled out indefinitely.
When trained for long enough, the models overfit to this subset and perform poorly on long sequences because they encounter unexplored state distributions.
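To make the hypothesis concrete, here is a toy illustration, not a model of Mamba-2: a hypothetical diagonal linear recurrence whose slowest channel has a decay rate close to 1, fed inputs with a positive mean (also an assumption). If training only ever covers the first $T$ steps, the state norms seen are bounded by some value; rolling the same recurrence out to $4T$ steps produces states whose norms lie well outside that range, i.e. unexplored states.

```python
import math
import random
from typing import List, Sequence


def state_norms(decay: Sequence[float], length: int, seed: int = 0) -> List[float]:
    """Roll out h_t = decay * h_{t-1} + x_t channel-wise from h_{-1} = 0; record ||h_t||."""
    rng = random.Random(seed)
    h = [0.0] * len(decay)
    norms = []
    for _ in range(length):
        # Hypothetical inputs with a positive mean, so slowly decaying channels accumulate.
        x = [1.0 + rng.gauss(0.0, 0.1) for _ in decay]
        h = [a * h_i + x_i for a, h_i, x_i in zip(decay, h, x)]
        norms.append(math.sqrt(sum(v * v for v in h)))
    return norms


T = 2048                       # training context length
decay = [0.5, 0.99, 0.9999]    # hypothetical per-channel decay rates
norms = state_norms(decay, 4 * T)
max_seen = max(norms[:T])      # largest norm reached within the training length
frac_unseen = sum(n > max_seen for n in norms[T:]) / len(norms[T:])
print(f"max norm for t < T: {max_seen:.0f}, final norm: {norms[-1]:.0f}, "
      f"fraction of later states above it: {frac_unseen:.2f}")
```

In this toy the growth comes from the channel with decay 0.9999, whose steady-state magnitude is far beyond what 2048 steps can reach; a real model can drift for other reasons.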
Interventions to Enable Length Generalization
The unexplored states hypothesis indicates that length generalization can be achieved not by changing the architecture or its mechanisms, but by training the model on a more diverse set of state distributions, in particular on the distributions that arise when rolling out the state recurrence on long sequences. To do so, we could directly train the model on longer sequences, but this is not always possible due to GPU memory constraints or a lack of sufficiently long training sequences.
The recipe to achieve length generalization: interventions on the initial state
Most modern architectures assume a zero initial state ($h_{-1}=0$). In our work, we consider four simple interventions on the initial state $h_{-1}$, which increase the diversity of states that the model explores during training without the need to train on longer sequences.
The four training interventions can be seen as sampling the initial state $h_{-1}$ from four different distributions that progressively get closer to the distribution of attainable states:
1. Random Noise: The state is initialized with an IID Gaussian with zero mean and a constant standard deviation (using the same mean / standard deviation for all layers and heads).
2. Fitted Noise: During training, we record the mean and standard deviation of the final states of the sequences across all layers and heads. Then, we initialize the state with an IID Gaussian distribution with mean and standard deviation fitted to the ones seen during training (using a different mean / standard deviation for each layer and head).
3. State Passing (SP)¹: We use the final state of a previous (unrelated) sequence as the initial state. These final states are obtained by applying the state recurrence on a given sequence to obtain $h_T$, which is then used as $h_{-1}$ for another sequence. This mirrors what happens at validation time: the model does not stop at $T$, but keeps rolling out the state and producing outputs from $h_T$.
4. Truncated Backpropagation Through Time (TBTT): We split a long sequence into smaller chunks and use the final state of each chunk as the initial state of the next one. In the forward pass this is equivalent to processing the whole sequence, but gradients are not propagated between chunks.
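The first three interventions can be read as samplers for $h_{-1}$. The sketch below is a minimal illustration under stated assumptions: the state is stored as a hypothetical dictionary from `(layer, head)` to a flat list of floats, all function names are illustrative, and Fitted Noise pools every entry of a head's recorded final states into a single mean and standard deviation per layer and head. TBTT differs mainly in how the data is batched rather than in how $h_{-1}$ is sampled.

```python
import random
import statistics
from typing import Dict, List, Sequence, Tuple

Key = Tuple[int, int]           # (layer, head); hypothetical state layout
State = Dict[Key, List[float]]  # flattened state of each head


def random_noise_init(keys: Sequence[Key], dim: int, std: float,
                      rng: random.Random) -> State:
    """1. Random Noise: IID N(0, std^2), same std for every layer and head."""
    return {k: [rng.gauss(0.0, std) for _ in range(dim)] for k in keys}


def fit_noise(final_states: Sequence[State]) -> Dict[Key, Tuple[float, float]]:
    """2a. Fitted Noise: mean / std of recorded final states, per (layer, head)."""
    stats = {}
    for k in final_states[0]:
        values = [v for s in final_states for v in s[k]]
        stats[k] = (statistics.fmean(values), statistics.pstdev(values))
    return stats


def fitted_noise_init(stats: Dict[Key, Tuple[float, float]], dim: int,
                      rng: random.Random) -> State:
    """2b. Sample IID Gaussian entries with the fitted per-(layer, head) mean / std."""
    return {k: [rng.gauss(mu, sd) for _ in range(dim)]
            for k, (mu, sd) in stats.items()}


def state_passing_init(previous_final_state: State) -> State:
    """3. State Passing: reuse h_T of a previous, unrelated sequence as h_{-1}.

    With an autograd framework this copy would also be detached from the graph.
    """
    return {k: list(v) for k, v in previous_final_state.items()}


rng = random.Random(0)
keys = [(layer, head) for layer in range(2) for head in range(2)]
h0_random = random_noise_init(keys, dim=4, std=0.1, rng=rng)
recorded = [{k: [1.0, 3.0, 1.0, 3.0] for k in keys}]  # pretend recorded final states
stats = fit_noise(recorded)                            # (2.0, 1.0) for every key
h0_fitted = fitted_noise_init(stats, dim=4, rng=rng)
h0_passed = state_passing_init(recorded[-1])
```

Because each entry of Fitted Noise is sampled independently, it cannot reproduce correlations between state entries; only State Passing reuses states actually produced by the recurrence.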
Difference between SP and TBTT
For simplicity, we implement SP by using the final state of the previous batch of sequences as the initial state of the new one. Thus, in practice the only difference between SP and TBTT is that TBTT requires carefully setting up the dataloader so that the sequences of the previous batch correspond to the prior parts of the sequences in the new batch.
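The dataloader requirement for TBTT can be sketched as follows, assuming for simplicity that all documents have the same length and that a hypothetical scalar recurrence stands in for the model. Row $i$ of batch $j+1$ is the continuation of row $i$ of batch $j$, so carrying each row's final state forward reproduces the forward pass over the whole document.

```python
from typing import List, Sequence


def tbtt_batches(docs: Sequence[Sequence[int]], chunk_len: int) -> List[List[List[int]]]:
    """Lay out long documents so that row i of batch j+1 continues row i of batch j.

    Assumes (for simplicity) that all documents have the same length,
    a multiple of chunk_len.
    """
    n_chunks = len(docs[0]) // chunk_len
    return [[list(doc[j * chunk_len:(j + 1) * chunk_len]) for doc in docs]
            for j in range(n_chunks)]


def toy_final_state(h: float, xs: Sequence[int]) -> float:
    """Hypothetical stand-in for a model: h_t = 0.5 * h_{t-1} + x_t."""
    for x in xs:
        h = 0.5 * h + x
    return h


docs = [list(range(0, 8)), list(range(100, 108))]
batches = tbtt_batches(docs, chunk_len=4)
# batches[0] == [[0, 1, 2, 3], [100, 101, 102, 103]]
# batches[1] == [[4, 5, 6, 7], [104, 105, 106, 107]]

state = [0.0] * len(docs)  # zero initial state for the first chunk
for batch in batches:
    # Each row's final state seeds the same row of the next batch. In an autograd
    # framework it would be detached here, which is what truncates backpropagation.
    state = [toy_final_state(h, row) for h, row in zip(state, batch)]

full = [toy_final_state(0.0, doc) for doc in docs]  # same values as `state`
```

For SP, the same loop can run over batches of unrelated sequences: the carried state is reused even though row $i$ of the next batch does not continue row $i$ of the previous one, which is why SP needs no special dataloader.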
The following figures show the results of post-training the official Mamba-2 models for 500 steps (~0.1% of pre-training budget) with each intervention:
Takeaway #1: SP and TBTT enable length generalization
State Passing and TBTT, the interventions whose initial states are closest to realistic states, enable length generalization on sequences much longer than those seen during training. Thus:
Takeaway
Length generalization is expected to be readily achievable in recurrent models through simple training interventions.
Note that our results were achieved with only ~0.02% of the original pre-training budget!
Takeaway #2: Properties of the state of recurrent models
Takeaway
We can infer properties of the distribution of the state of recurrent models by looking at the performance of the interventions.
The Random Noise intervention fails to length generalize for the 370m model, whereas Fitted Noise works. This suggests that for the 370m model the distribution of attainable states cannot be approximated by a Gaussian with a fixed variance, but can be approximated by an IID Gaussian with a variance fitted to each layer and head of the state. However, Fitted Noise fails to achieve length generalization in the 1.3b model, indicating that the state of larger models probably has complex dependencies among its elements and thus cannot be approximated with IID values.
Additionally, the interventions also fix the increasing state norm behavior we showed before, by making the model output states with similar norm at all timesteps:
SP in prior works
¹ Prior works have used the State Passing technique, but applied it to different recurrent architectures (e.g., time-invariant ones) or to tasks other than text modeling. To the best of our knowledge, we are the first to show that this technique, used as a training intervention, can greatly improve the length generalization of several recurrent models, and that it is as effective as TBTT in text modeling.
Performance on Long Context Tasks
We have seen that the interventions enable length robustness (i.e., no decrease in performance after the training context $T$), but it is not clear whether they enable length generalization (i.e., solving tasks that require exploiting relationships between tokens separated by more than $T$ positions). One may wonder whether the interventions achieve length robustness simply by preventing the model from reasoning beyond the training context length, similar to sliding window attention, which cannot reason over tokens separated by more than the window size. In that case, the models would have constant performance for all evaluation contexts $t > T$, but could not solve tasks that require long-context reasoning. In our work, we show that the interventions do enable length generalization through results on three long-context tasks.
BABILong. BABILong is a challenging benchmark that tests both a model's common sense understanding and its ability to capture long-range dependencies in text. The figure below shows that State Passing enhances the length generalization of the model in both the few-shot and finetuned settings (recall that the model is trained and finetuned on sequences of length 2048). Therefore, State Passing is useful not only for fixing the diverging perplexity of established language models, but also for enhancing their ability to solve long-context reasoning tasks.
Passkey retrieval. The passkey retrieval task requires the model to retrieve a 5-digit passkey inserted at a given depth of a long context. In the figure below we show the performance of the official Mamba-2 370m and 780m checkpoints in three settings: zero-shot, regular finetuning, and finetuning with Fitted Noise². The models finetuned with Fitted Noise can exploit relationships between tokens that are far more than 2048 positions apart (the training context length). In particular, the 780m model solves the passkey task perfectly on sequences of length 256k.
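A passkey sample of the kind described above can be generated along these lines. This is a hedged sketch: the filler sentences, the question wording, and the helper names `make_passkey_prompt` and `is_correct` are hypothetical choices for illustration, not the exact prompt format used in our evaluation.

```python
import random
from typing import Tuple

# Hypothetical filler and phrasing; real benchmarks may use different text.
FILLER = "The grass is green. The sky is blue. The sun is yellow. "
QUESTION = "What is the pass key? The pass key is"


def make_passkey_prompt(n_filler: int, depth: float,
                        rng: random.Random) -> Tuple[str, str]:
    """Insert a random 5-digit passkey at relative depth in [0, 1] of a filler context."""
    passkey = str(rng.randrange(10_000, 100_000))  # always 5 digits
    parts = [FILLER] * n_filler
    parts.insert(round(depth * n_filler), f"The pass key is {passkey}. Remember it. ")
    return "".join(parts) + QUESTION, passkey


def is_correct(completion: str, passkey: str) -> bool:
    """Count a retrieval as correct if the first 5 generated digits match the key."""
    digits = "".join(c for c in completion if c.isdigit())
    return digits[:5] == passkey


prompt, key = make_passkey_prompt(n_filler=1000, depth=0.5, rng=random.Random(0))
```

Varying `n_filler` controls the context length and `depth` controls how far back the model must look, so tasks with dependencies far beyond 2048 tokens are obtained simply by using long filler with a small depth.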
Choice of intervention for passkey retrieval
² Contrary to typical language modeling datasets, the distribution of tokens in the passkey task is not stationary (in other words, there is no well-defined behavior for what the model should do after the passkey is revealed). This is why we show results for the Fitted Noise intervention: it does not require using the final state of a sequence (i.e., the state right after the passkey is revealed), which might not be appropriate as an initial state.
Synthetic Copying. The synthetic copying task consists of copying an arbitrary sequence of tokens. In the table below, we show that using State Passing during training greatly improves the validation performance on sequences more than three times longer than those seen during training. Thus, State Passing helps the model length generalize, solving long-context tasks that are harder than those seen during training.
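The copying task can be sketched as a data generator plus a per-token accuracy metric. The layout below (random tokens, then a separator whose id is `vocab_size`, then the model must repeat the tokens) and the lengths `train_len` and `eval_len` are hypothetical choices for illustration.

```python
import random
from typing import List, Tuple


def make_copy_example(length: int, vocab_size: int,
                      rng: random.Random) -> Tuple[List[int], List[int]]:
    """Return (prompt, target): `length` random tokens followed by a separator.

    Tokens are drawn from 0..vocab_size-1 and the separator is id vocab_size
    (a hypothetical convention). The target is the same random tokens.
    """
    tokens = [rng.randrange(vocab_size) for _ in range(length)]
    return tokens + [vocab_size], tokens


def copy_accuracy(pred: List[int], target: List[int]) -> float:
    """Fraction of target positions reproduced exactly."""
    hits = sum(p == t for p, t in zip(pred, target))
    return hits / len(target)


rng = random.Random(0)
train_len = 64                 # hypothetical training length
eval_len = 3 * train_len + 1   # more than three times longer
prompt, target = make_copy_example(eval_len, vocab_size=16, rng=rng)
```

Since the whole prompt must be retained in the state before copying starts, accuracy on `eval_len` examples directly tests whether the model can use information from beyond the training length.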
A Deeper Look into How Recurrent Models Process Context
We have shown that the interventions on the initial state enable length robustness and allow solving long context tasks. On top of these findings, we now present a metric that sheds light on how sequence models process their context.
Ideally, in the case of text modeling we would like the model to pay attention to the recent context and not focus too much on tokens that are far away. But how can we quantify this behavior? We introduce Effective Remembrance to measure how much an autoregressive model is "effectively" remembering previous tokens. Denote by $q(\cdot \mid \text{context})$ the probabilities that an autoregressive sequence model outputs for the next token given a context. Then, we define:
$ \text{EffRem}_T(t) = d\big(q(\cdot \mid x[0:T]),\, q(\cdot \mid x[t:T])\big) $
where $d(p,\bar{p})$ is a distance between probability distributions (e.g., total variation). $\text{EffRem}_T(t)$ roughly measures how much the model "effectively remembers" the tokens $x[0:t-1]$ at time $T$. If $\text{EffRem}_T(t) = 0$, the predictions using $x[t:T]$ and using $x[0:T]$ are the same, meaning that the model does not "effectively remember" any of the past tokens $x[0:t-1]$. Conversely, if $\text{EffRem}_T(t)$ is high, the model is substantially influenced by the tokens $x[0:t-1]$, since removing them from the context changes the prediction significantly.
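The definition translates directly into code. The sketch below assumes total variation as $d$ and uses Python slice semantics, so `x[t:T]` drops tokens `0..t-1`. `effective_remembrance` and `make_toy_model` are illustrative names, and the toy model (a hypothetical recurrence with exponentially decaying one-hot memory followed by a softmax) only stands in for a real network; for an actual language model, `model` would run the network on the given context and return the next-token distribution at the last position.

```python
import math
from typing import Callable, List, Sequence

NextTokenProbs = Callable[[Sequence[int]], List[float]]


def total_variation(p: Sequence[float], q: Sequence[float]) -> float:
    """TV distance between two discrete distributions: 0.5 * sum |p_i - q_i|."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


def effective_remembrance(model: NextTokenProbs, x: Sequence[int], T: int, t: int) -> float:
    """EffRem_T(t) = TV(q(. | x[0:T]), q(. | x[t:T]))."""
    return total_variation(model(x[0:T]), model(x[t:T]))


def make_toy_model(vocab_size: int, decay: float = 0.9) -> NextTokenProbs:
    """Hypothetical toy: h_t = decay * h_{t-1} + onehot(x_t), q = softmax(h_T)."""
    def probs(context: Sequence[int]) -> List[float]:
        h = [0.0] * vocab_size  # zero initial state
        for tok in context:
            h = [decay * v for v in h]
            h[tok] += 1.0
        z = [math.exp(v) for v in h]
        total = sum(z)
        return [v / total for v in z]
    return probs


model = make_toy_model(vocab_size=5)
x = [0] * 45 + [1, 2, 3, 4, 1]  # T = 50 tokens
T = len(x)
curve = [effective_remembrance(model, x, T, t) for t in range(T + 1)]
```

With this decaying toy, dropping the oldest few tokens barely changes the prediction (`curve[5]` is close to 0), while dropping almost everything changes it a lot (`curve[45]` is large), i.e. the gradually increasing shape that one would expect for text modeling.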
The following figure shows $\text{EffRem}_T(t)$ for two official Mamba-2 checkpoints (which fail to length generalize) for varying $t$ and $T=8192$ (four times the training context):
Intuitively, we would expect that, while every token contributes to the model's output, the most recent tokens should have a significantly stronger influence. However, notice how the $\text{EffRem}$ curves immediately jump up and then gradually taper off. This behavior is clearly problematic: the next-token prediction at time $T=8192$ should not change drastically depending on whether the model sees only the recent tokens $x[4096:8192]$ or the full sequence $x[0:8192]$. In natural language, the model should primarily rely on recent context, and the earlier tokens $x[0:4096]$ should not completely alter the prediction, especially not to the extent that the total variation between the two output probability distributions approaches 1. This means that the model is disproportionately influenced by tokens at the beginning of the sequence.
Intuition
We hypothesize that when a model is always trained with a zero initial state, it uses the first few tokens it sees to rapidly differentiate the state, which in turn causes overfitting to these tokens.
State Passing fixes Effective Remembrance
After post-training with State Passing, the $\text{EffRem}$ curves show a gradual increase, indicating that the model places minimal weight on distant tokens and places progressively more weight on recent ones. In particular, tokens in the immediate context (e.g. the previous words in a sentence) have a critical impact on the next token predictions, which is the desired behavior in text modeling.
Takeaway
Through Effective Remembrance, we can check that State Passing helps the models prioritize recent context and not be needlessly disrupted by tokens that are far away in the past.
Conclusion
We have shown that length generalization is expected to be achievable in recurrent models through simple training interventions, without changing the architecture or the internal mechanisms of the model. Moreover, these interventions improve performance on long-context reasoning tasks, suggesting that existing recurrent models are not realising their full potential and can be easily improved.
Secondly, we believe that this work has significant implications for architecture research. For example, it has become very popular for works on modern recurrent architectures to compare out-of-length extrapolation abilities. In our work, we show that simple training interventions substantially improve length generalization across several recurrent architectures, so research can focus mostly on in-length performance (or, when directly studying length generalization, it is important to account for these interventions).
Lastly, we have proposed Effective Remembrance as a tool to understand how any autoregressive sequence model processes its context, thus making it easy to quantify how much models are "effectively remembering" parts of the context.