We introduced our Mamba-3 model in Part I, where we mentioned that the three core methodological changes were inspired by the SSM perspective. Here, we’ll do a deep dive into what each of these three improvements entails, along with its motivation and derivation.
But first, let’s refresh our memory on the underlying state space model and its background.
The state space model, at its most primitive, is a simple, continuous ordinary differential equation (ODE). The input $x(t) \in \mathbb{R}$ is mapped to output $y(t) \in \mathbb{R}$ through a hidden state $h(t) \in \mathbb{R}^N$ of size $N$, also referred to as the state size. In the past, in both deep learning and classical control theory, these systems were linear time-invariant (LTI), where the “state decay” transition $A \in \mathbb{R}^{(N\times N)}$, $B \in \mathbb{R}^{N}$, and $C \in \mathbb{R}^{N}$ terms were constant.
\[\begin{aligned} h'(t) &= A h(t) + B x(t) \\ y(t) &= C^\top h(t) \end{aligned}\]We will occasionally refer to $A$ as the state-transition, and $Bx(t)$ as the state-input.
Upon discretization with one’s favorite method, as demonstrated with the zero-order hold (ZOH) used in both Mamba-1 and Mamba-2, a familiar recurrence materializes,
\[\begin{aligned} h_{t} &= e^{\Delta_t A_t} h_{t-1}+ A_t^{-1}(e^{\Delta_t A_t} - I)\,B_t\,x_t \\ y_t &= C_t^\top h_t \end{aligned}\]where the discretized $\bar{A}$ and $\bar{B}$ are now $e^{\Delta_t A_t}$ and $A_t^{-1}(e^{\Delta_t A_t} - I)\,B_t$ respectively.
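To make the ZOH recurrence concrete, here is a minimal scalar sketch (so $A^{-1}$ is just a division) with fixed, time-invariant parameters. The function names and the numeric values are illustrative assumptions, not taken from any Mamba codebase. It checks that, for a constant input, the discrete recurrence lands exactly on the continuous ODE solution at the sample times.

```python
import math

def zoh_discretize(a, b, dt):
    """Scalar ZOH: A_bar = exp(dt * a), B_bar = (exp(dt * a) - 1) / a * b."""
    a_bar = math.exp(dt * a)
    b_bar = (a_bar - 1.0) / a * b
    return a_bar, b_bar

def run_recurrence(a, b, c, dt, xs):
    """h_t = A_bar h_{t-1} + B_bar x_t and y_t = c * h_t, starting from h = 0."""
    a_bar, b_bar = zoh_discretize(a, b, dt)
    h, ys = 0.0, []
    for x in xs:
        h = a_bar * h + b_bar * x
        ys.append(c * h)
    return ys

# Illustrative values (assumptions): a stable scalar system with constant input 1.
a, b, c, dt, steps = -0.5, 2.0, 1.0, 0.1, 20
ys = run_recurrence(a, b, c, dt, [1.0] * steps)

# Continuous solution of h' = a h + b with h(0) = 0 is (b / a) (exp(a t) - 1).
exact = c * (b / a) * (math.exp(a * dt * steps) - 1.0)
print(ys[-1], exact)
```

The agreement is exact (up to floating point) only because the input really is constant over each step, which is precisely the zero-order-hold assumption.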
Eagle-eyed readers may ask “how does one go from an LTI system to a linear time-varying (LTV) system?” If you did, the answer is revealed below!
We’ll let you in on a little secret: prior Mamba discretizations used the canonical ZOH discretization scheme and just converted the fixed time-invariant variables $A$, $B$, and $C$ to time-varying ones!
No worries if this feels uneasy. We felt that too, which is why we formalized the discretization later (sorry for the clickbait, the answer is not here; you’ll have to keep on reading).
While there are no theoretical restrictions on the class of matrices that $\bar{A}$ can belong to, computational constraints keep transition matrices structured, e.g., diagonal, scalar times identity, Householder (identity plus low-rank), etc.
Great, now we’ve set up the underlying mechanism of Mamba used in both Mamba-1 and Mamba-2!
As a quick recap, Mamba-3 builds on Mamba-2 to improve the efficiency-performance trade-off of current SSMs with inference at the forefront. The three core improvements we’ve been discussing are rooted in classical state space theory: a more general exponential-trapezoidal discretization, complex-valued state transitions, and multi-input, multi-output (MIMO) SSMs.
Let’s jump straight into it.
Our end goal is to obtain a more general recurrence than that of current models from first principles. Luckily for us, the discretization of the continuous ODE provides the perfect opportunity to do so.
But first, let’s lay down the foundation for our framework used in discretizing time-varying systems. Remember how we mentioned that prior Mamba discretizations adapted the canonical ZOH discretization by adding a time subscript to convert the method from LTI to LTV? Well, to be honest, we left a bit more out earlier. The actual implementation of Mamba frankensteined the canonical ZOH and Euler methods to create discretized parameters $\bar{A}_t = \exp(\Delta_t A_t), \bar{B}_t=\Delta_t B_t$.
Holy heuristic! But it works empirically ¯\(ツ)/¯.
One potential explanation for why this mixture does so well despite not being theoretically grounded is that Euler is an approximation of ZOH. Taking the ZOH formula for the discretized $\bar{B}_t=A_t^{-1}\left(\exp(\Delta_t A_t) - I \right) B_t$, if we use the first-order approximation $\exp(\Delta_t A_t) \approx I+\Delta_t A_t$, we get $\bar{B}_t \approx A_t^{-1}(\Delta_t A_t) B_t = \Delta_t B_t$.
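As a quick numerical sanity check of this explanation, the sketch below compares the two $\bar{B}$ formulas in the scalar case. The function names and the values of `a`, `b`, and the step sizes are illustrative assumptions. If Euler is the first-order truncation of ZOH, the gap should shrink roughly quadratically with the step size.

```python
import math

def b_bar_zoh(a, b, dt):
    """Scalar ZOH input term: (exp(dt * a) - 1) / a * b."""
    return (math.exp(dt * a) - 1.0) / a * b

def b_bar_euler(b, dt):
    """Euler input term: dt * b."""
    return dt * b

a, b = -1.0, 1.0  # illustrative values (assumptions)
gaps = {dt: abs(b_bar_zoh(a, b, dt) - b_bar_euler(b, dt)) for dt in (0.1, 0.01, 0.001)}

for dt, gap in gaps.items():
    # The gap behaves like |a| * b * dt^2 / 2, the first term dropped by exp(x) ~ 1 + x.
    print(f"dt={dt:<6} gap={gap:.3e}  dt^2/2={dt * dt / 2:.3e}")
```

Shrinking the step by 10x shrinks the gap by about 100x, which is what one would expect from a first-order approximation.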
This heuristic was always bugging us in the back of our minds, so in our work, we finally formalized the discretization.
We develop a method that produces a class of formal discretizations for time-varying systems, including one called exponential-Euler that exactly corresponds to the formula used in Mamba-1/2.
So let’s actually figure out how to discretize our LTV system in a principled manner.
\[h'(t) = A(t)h(t) + B(t)x(t)\]The general intuition behind our framework is that a bare-bones ODE $f'(t) = Af(t)$ has a closed-form solution $f(t) = e^{tA} f(0)$. It follows that the one-step discrete update is then $f_{t+1} = e^{\Delta A}f_t$. Here, since the derivative includes $Ah(t)$, the state directly impacts the rate of change. Thus, the parameterization of $A$ can make the dynamics of the system change rapidly, which forces explicit methods, like Euler, to take small $\Delta$ steps, and this in turn limits the expressivity of the system.
To mitigate this, we adjust the dynamics with an integrating factor of $e^{-At}$ to counteract the dominating exponential and directly analyze $e^{-At}h(t)$ instead. Let’s see how it applies to our system.
Taking our $h'(t)=A(t)h(t)+B(t)x(t)$ system, we apply an integrating factor of $e^{\int_0^t -A(s)ds}$ as $A$ is now time-varying.
\[\begin{aligned} e^{\int_0^t -A(s)ds}h'(t) &= e^{\int_0^t -A(s)ds}A(t)h(t) + e^{\int_0^t -A(s)ds}B(t)x(t) \\ (e^{\int_0^t -A(s)ds}h(t))' &= e^{\int_0^t -A(s)ds}B(t)x(t) \end{aligned}\]since $(e^{\int_0^t -A(s)ds})' = -A(t)e^{\int_0^t -A(s)ds}$.
Thus, when we want to discretize between timesteps $[\tau_{t-1}, \tau_t]$, we can just integrate both sides over that interval. For ease of notation, we denote $z(t) := e^{\int_0^t -A(s)ds}$.
\[\begin{aligned} \tfrac{d}{dt}(z(t)h(t)) &= z(t)B(t)x(t) \\ \int_{\tau_{t-1}}^{\tau_t}\tfrac{d}{d\tau}(z(\tau)h(\tau))d\tau &= \int_{\tau_{t-1}}^{\tau_t}z(\tau)B(\tau)x(\tau)d\tau \\ z(\tau_{t})h(\tau_{t}) - z(\tau_{t-1})h(\tau_{t-1}) &= \int_{\tau_{t-1}}^{\tau_t}z(\tau)B(\tau)x(\tau)d\tau \\ z(\tau_{t})h(\tau_{t}) &= z(\tau_{t-1})h(\tau_{t-1}) + \int_{\tau_{t-1}}^{\tau_t}z(\tau)B(\tau)x(\tau)d\tau \\ \end{aligned}\]Rearranging into a more familiar form and substituting back the $z(\tau)$ value:
\[\begin{aligned} h(\tau_{t}) &= z(\tau_{t})^{-1}z(\tau_{t-1})h(\tau_{t-1}) + z(\tau_{t})^{-1}\int_{\tau_{t-1}}^{\tau_t}z(\tau)B(\tau)x(\tau)d\tau \\ h(\tau_t) &= \exp\left(\int_{\tau_{t-1}}^{\tau_t}A(s)ds\right)h(\tau_{t-1}) + \int_{\tau_{t-1}}^{\tau_t} \exp\left(\int_{\tau}^{\tau_t}A(s)ds\right)B(\tau)x(\tau) d\tau \end{aligned}\]Now, we’ve isolated the state-transition and the state-input through our integrating factor. This means the most “difficult” part of the adjusted system can be calculated independently of the state-input integral, which can then be approximated in many possible ways.
Under the LTV case, because $A(s)$ is continuous, we “sample” it with a right-hold assumption where $\forall s \in [\tau_{t-1},\tau_t], A(s) = A(\tau_t) = A_t$, resulting in
\[h_t \approx \exp(\Delta_t A_t)h_{t-1} + \int_{\tau_{t-1}}^{\tau_t} \exp\left((\tau_t - \tau)A_t\right)B(\tau)x(\tau)d\tau\]If $A$ is LTI, i.e., constant, then the state-transition integral is exactly $\exp(A \Delta_t)$, which is the discretized $\bar{A}$ term for canonical ZOH.
Under the assumption $x(\tau) = x_t$, if $B$ is also LTI, then $\bar{B}$ also recovers the canonical ZOH term, $A^{-1}\left(\exp(\Delta A) - I\right)B$.
This final equation lays the foundation for recovering the prior Mamba discretization methods and is also the inspiration for our “exponential-” style name, as the application of the integrating factor can be seen as a form of exponential tilting or adjustment.
As previously mentioned, prior Mamba discretizations differ on paper and in practice. Now, using our new discretization derivation and certain sampling assumptions, we can recover the reported LTV ZOH discretization and the implemented exponential-adjusted Euler discretization scheme, or exponential-Euler for short.
We have already recovered the $\bar{A}$ term for both, so we will focus only on the remaining state-input integral which evaluates to $\bar{B}$.
ZOH: Under similar assumptions, where $B(\tau), x(\tau)$ are constant over the interval and sampled at the right endpoint,
\[\begin{aligned} (\cdot) &= B(\tau_t)x(\tau_t) \int_{\tau_{t-1}}^{\tau_t} \exp\left((\tau_t - \tau)A_t\right)d\tau\\ &= B_t\,x_t\exp(A_t\tau_t) \int_{\tau_{t-1}}^{\tau_t} \exp\left(-A_t\tau\right)d\tau \\ &= B_t\,x_t\exp(A_t\tau_t)\dfrac{1}{A_t}\left(\exp(-A_t\tau_{t-1}) - \exp(-A_t\tau_t)\right) \\ &= A_t^{-1}\left(\exp(A_t(\tau_t - \tau_{t-1})) - I\right)B_t\,x_t \\ &= A_t^{-1}\left(\exp({\Delta_tA_t})-I\right)B_t\,x_t \end{aligned}\]Exponential-Euler: Once again holding the $B, x$ terms at the right endpoint, we now approximate the remaining integral with Euler’s rule instead of solving it exactly.
\[\begin{aligned} (\cdot) &= B(\tau_t)x(\tau_t) \int_{\tau_{t-1}}^{\tau_t} \exp\left((\tau_t - \tau)A_t\right)d\tau \\ &= B_t\,x_t\exp(A_t\tau_t) \int_{\tau_{t-1}}^{\tau_t} \exp\left(-A_t\tau\right)d\tau \\ &\approx B_t\,x_t\exp(A_t\tau_t) (\tau_t - \tau_{t-1})\exp\left(-A_t\tau_t\right) \\ &= \Delta_t\,B_t\,x_t \\ \end{aligned}\]Thus, the linchpin of converting continuous-time SSMs into various tangible recurrences is approximating $\int_{\tau_{t-1}}^{\tau_t} \exp\left((\tau_t - \tau)A_t\right)B(\tau)x(\tau)d\tau$ in different ways. We’ve shown above that one can either analytically solve it (ZOH) or approximate it (Euler), but what if we don’t want to use inverses but also want something more precise than Euler’s?
We can instead use the trapezoid method, which approximates the integral using both endpoints instead of just one as in Euler. Unlike the standard method, which averages both endpoints equally, we use a convex combination with weight $\lambda_t \in [0, 1]$, which we find empirically performs better. The integral then evaluates to
\[\begin{aligned} & \Delta_t \left(\lambda_t \exp((\tau_t - \tau_t) A_t) B_t\,x_t + (1 - \lambda_t) \exp((\tau_t - \tau_{t-1}) A_t)B_{t-1}\,x_{t-1} \right) \\ =& (1-\lambda_t)\Delta_t\,\exp({\Delta_tA_t})\,B_{t-1}\,x_{t-1} + \lambda_t\,\Delta_t\,B_t\,x_t \end{aligned}\]Interestingly, we can see here that for our new exponential-trapezoidal recurrence, there is some structured time-mixing across the state-input terms. Thus, it acts as an implicit data-dependent convolution of size two on the SSM’s state-input.
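To see what the extra endpoint buys, here is a small scalar sketch comparing exponential-Euler and exponential-trapezoidal against a fine-grained numerical reference for the state-input integral over a single step $[0, \Delta]$. The function names, the choice $g(\tau) = B(\tau)x(\tau) = \cos(\tau)$, and the numeric values are illustrative assumptions. The fixed $\lambda = 0.5$ here stands in for the learned $\lambda_t$.

```python
import math

def state_input_reference(a, g, dt, n=20000):
    """Reference for  int_0^dt exp((dt - tau) a) g(tau) dtau  via a fine midpoint rule."""
    w = dt / n
    total = 0.0
    for k in range(n):
        tau = (k + 0.5) * w
        total += math.exp((dt - tau) * a) * g(tau)
    return total * w

def exp_euler(a, g, dt):
    """Right-endpoint rule: dt * g(dt). The decay factor exp(0 * a) is 1."""
    return dt * g(dt)

def exp_trapezoid(a, g, dt, lam=0.5):
    """Convex combination of both endpoints; the left one is decayed by exp(dt * a)."""
    return dt * (lam * g(dt) + (1.0 - lam) * math.exp(dt * a) * g(0.0))

a, dt = -1.0, 0.1  # illustrative values (assumptions)
g = math.cos       # stands in for the product B(tau) x(tau)

ref = state_input_reference(a, g, dt)
e = exp_euler(a, g, dt)
tr = exp_trapezoid(a, g, dt)
print(f"euler error     = {abs(e - ref):.2e}")
print(f"trapezoid error = {abs(tr - ref):.2e}")
```

Setting `lam=1.0` drops the left endpoint and recovers exponential-Euler, so the prior Mamba update is one member of this family.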
Now, how can we express our recurrence in a parallel form to enable faster training? This hearkens back to the state space duality (SSD) framework introduced in Mamba-2.
Let’s rewrite our recurrence so this doesn’t get too messy:
\[h_t = \alpha_t h_{t-1} + \beta_t B_{t-1}x_{t-1} + \gamma_t B_tx_t\]where $\alpha_t = e^{\Delta_t A_t}, \beta_t=(1-\lambda_t)\Delta_t e^{\Delta_t A_t},\gamma_t=\lambda_t\Delta_t$.
SSD demonstrated that a large class of recurrent SSMs could be represented in a parallel form that uses an element-wise multiplicative mask to model the state-transition decay. Such parallel representations take the form $Y = (L \circ CB^\top) X$ with $L\in\mathbb{R}^{T\times T}$, $C, B\in\mathbb{R}^{T\times N}$, and $X,Y\in\mathbb{R}^{T\times D}$, where $T$ is the total sequence length.
This format makes the connection between SSMs and attention pretty clear, especially when changing the SSM-centric notation to one that is more common in attention literature: $C \to Q, B \to K, X \to V$. When $L$ is a lower triangular matrix of all ones, we get vanilla (causal) linear attention, $Y = (L \circ QK^\top)V$.
This parallel formulation is what enables the matmul-focused forward pass.
If we expand the recurrence with $h_{-1}=0$ and no input before $t=0$,
\[\begin{aligned} h_0 & = \gamma_0 B_0x_0 \\ h_1 & = (\alpha_1 \gamma_0 + \beta_1)B_0x_0 + \gamma_1 B_1x_1 \\ h_2 & = \alpha_2(\alpha_1 \gamma_0 + \beta_1)B_0x_0 + (\alpha_2 \gamma_1 + \beta_2)B_1x_1 + \gamma_2 B_2x_2 \\ & \;\;\vdots \\ h_T & = \alpha_{T\dots2}(\alpha_1 \gamma_0 + \beta_1)B_0x_0 + \ldots + \gamma_T B_Tx_T \end{aligned}\]where $\alpha_{T\dots2} = \alpha_T \cdots \alpha_2$, we can express the output $y_t = C_t^\top h_t$ as a matrix operation
\[\small \begin{bmatrix} y_0 \\ y_1 \\ y_2 \\ \vdots \end{bmatrix} = \left( \begin{bmatrix} \gamma_0 & & \\ (\gamma_0\alpha_1 + \beta_1) & \gamma_1 & \\ \alpha_2(\gamma_0\alpha_1 + \beta_1) & (\gamma_1\alpha_2+\beta_2) & \gamma_2 \\ \vdots & & & \ddots \\ \end{bmatrix} \odot \begin{bmatrix} C_0^\top B_0 & & & \\ C_1^\top B_0 & C_1^\top B_1 & & \\ C_2^\top B_0 & C_2^\top B_1 & C_2^\top B_2 & \\ \vdots & & & \ddots \\ \end{bmatrix} \right) \begin{bmatrix} x_0 \\ x_1 \\ x_2 \\ \vdots \end{bmatrix}\]which can further be decomposed into a 1-semiseparable matrix (Mamba-2’s decay mask) and a 2-band matrix.
The equivalence established between the recurrent and parallel forms of Mamba-3 is another instance of what Mamba-2’s SSD established: that certain classes of SSMs have a matrix form that vectorizes the recurrence. This forms the foundation of the hardware-efficient algorithm used for training.
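The equivalence between the two forms is easy to check numerically. Below is a minimal sketch for a scalar state ($N = 1$, one channel), so that $C_i^\top B_j$ is just a product. The function names and random test values are illustrative assumptions. This is the naive quadratic construction of the mask, not the hardware-efficient chunked algorithm used for training.

```python
import math
import random

def recurrent(alpha, beta, gamma, B, C, x):
    """h_t = alpha_t h_{t-1} + beta_t B_{t-1} x_{t-1} + gamma_t B_t x_t,  y_t = C_t h_t."""
    h, prev, ys = 0.0, 0.0, []  # h_{-1} = 0 and no input before t = 0
    for t in range(len(x)):
        cur = B[t] * x[t]
        h = alpha[t] * h + beta[t] * prev + gamma[t] * cur
        ys.append(C[t] * h)
        prev = cur
    return ys

def mask(alpha, beta, gamma):
    """Lower-triangular mask: gamma_i on the diagonal and
    (gamma_j alpha_{j+1} + beta_{j+1}) * alpha_{j+2} ... alpha_i below it."""
    T = len(alpha)
    L = [[0.0] * T for _ in range(T)]
    for i in range(T):
        L[i][i] = gamma[i]
        for j in range(i):
            decay = math.prod(alpha[j + 2 : i + 1])  # empty product is 1
            L[i][j] = (gamma[j] * alpha[j + 1] + beta[j + 1]) * decay
    return L

def parallel(alpha, beta, gamma, B, C, x):
    """y = (L o C B^T) x, written out entry by entry."""
    L = mask(alpha, beta, gamma)
    return [sum(L[i][j] * C[i] * B[j] * x[j] for j in range(i + 1)) for i in range(len(x))]

random.seed(0)
T = 6
dt = [random.uniform(0.01, 0.2) for _ in range(T)]
A = [-random.uniform(0.1, 2.0) for _ in range(T)]
lam = [random.uniform(0.0, 1.0) for _ in range(T)]
alpha = [math.exp(d * a) for d, a in zip(dt, A)]
beta = [(1.0 - l) * d * al for l, d, al in zip(lam, dt, alpha)]
gamma = [l * d for l, d in zip(lam, dt)]
B = [random.gauss(0, 1) for _ in range(T)]
C = [random.gauss(0, 1) for _ in range(T)]
x = [random.gauss(0, 1) for _ in range(T)]

y_rec = recurrent(alpha, beta, gamma, B, C, x)
y_par = parallel(alpha, beta, gamma, B, C, x)
print(max(abs(u - v) for u, v in zip(y_rec, y_par)))
```

With $\lambda_t = 1$ the $\beta$ terms vanish and the mask reduces to $\gamma_j$ times the usual cumulative decay, i.e., the Mamba-2 style mask with the input scaling folded in.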
As we’ve mentioned, the simplification of the SSM over the past couple of years to improve efficiency has reduced the abilities of newer models.
While more complex state transitions, such as diagonal plus low-rank (DPLR) in older LTI SSM models, can improve the method’s expressivity, the simplification of transitions across iterations of LTV SSMs has resulted in even the simplest state-tracking tasks falling out of reach for Mamba-style models.
One of the simplest tasks, parity (determining whether the sum of a sequence of 0’s and 1’s is even), is unsolvable by Mamba models in a constant number of layers. The ideal solution requires the hidden state to track whether the running sum is even or odd and to alternate depending on the next input, modeling a simple two-state automaton.
But parity and other modulo tasks can be solved with rotations! One way to visualize how rotations solve modulo $m$ problems is to picture a 2D vector that can be rotated around the origin. The full range of angles $[0, 2\pi)$ is partitioned into $m$ sections, and each unit added to the running sum rotates the vector by $\tfrac{2\pi}{m}$, so its angle always aligns with the current running remainder.
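The picture above can be sketched in a few lines. This is a hand-built toy, not a trained model: the function names and the decoding rule (reading the remainder off the final angle) are illustrative assumptions. The state is a 2D vector, and each input bit of 1 applies one rotation by $\tfrac{2\pi}{m}$.

```python
import math

def rotate(v, theta):
    """Rotate the 2D vector v counterclockwise by theta."""
    c, s = math.cos(theta), math.sin(theta)
    return (c * v[0] - s * v[1], s * v[0] + c * v[1])

def running_remainder(bits, m=2):
    """Track sum(bits) mod m using only rotations of a 2D state."""
    step = 2.0 * math.pi / m
    v = (1.0, 0.0)
    for b in bits:
        if b:  # input-dependent transition: rotate only when the bit is 1
            v = rotate(v, step)
    angle = math.atan2(v[1], v[0]) % (2.0 * math.pi)
    return round(angle / step) % m

print(running_remainder([1, 0, 1, 1]))       # three ones: odd parity
print(running_remainder([1, 1, 0, 0]))       # two ones: even parity
print(running_remainder([1, 1, 1, 1], m=3))  # sum 4, remainder mod 3
```

The key property is that the transition depends on the input and its eigenvalues are not restricted to positive reals, which is what a purely real, decaying scalar transition cannot provide.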
Working with complex values in computer systems is quite a pain due to their multiplicative interactions. Luckily for us, diagonal complex-valued continuous SSMs can be represented as discretized real-valued SSMs (without any additional approximation loss compared to standard discretization).
While the full proof can be found within our paper, the general intuition is to expand the original complex-valued SSM with state size $N$ into a real-valued one with state size $2N$, in which each complex-valued dimension is split into its real and imaginary counterparts and the complex transition matrix is partitioned into its scaling and oscillatory portions. Because the scaling and oscillatory components commute (due to the diagonal structure of the underlying matrix), we can map the continuous diagonal complex transition into a block-diagonal, scaled rotation transformation.
It is also possible to convert a general, unstructured complex SSM transition matrix to a real-valued SSM, though the scaled rotation intuition breaks down. The conversion still doubles the state size, but while the expansion of the $B, C, x$ will remain similar, the transition matrix will no longer be as simple. With an unstructured transition matrix $\mathbf{A} + i\Theta$, the exponential (resulting from our integrating factor technique) cannot be factored as
\[\exp(\Delta(\mathbf{A} + i\Theta)) \neq \exp(\Delta\mathbf{A}) \exp(i\Delta\Theta)\]since $\mathbf{A},\Theta$ generally do not commute, unlike the diagonal case. Consequently, while a real-valued equivalent exists, computing it would require the expensive full matrix exponential.
Eventually, under the prior exponential-Euler discretization, we obtain the following recurrence
\[\begin{aligned} h_t &= e^{\Delta_t A_t} \underbrace{ \begin{bmatrix} \cos(\Delta_t \theta_t) & -\sin(\Delta_t \theta_t) \\ \sin(\Delta_t \theta_t) & \cos(\Delta_t \theta_t) \end{bmatrix}}_{\vphantom{\Big|}R_t} h_{t-1} + \Delta_t B_t x_t \end{aligned}\]for an $N=2$ state. For larger states, the rotation matrix $R_t$ is block-diagonal and the $\theta$’s can differ.
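As a sanity check on this correspondence, the sketch below runs one complex scalar state and its $N=2$ real counterpart side by side under the exponential-Euler update. The function names, the complex-valued `B_c`, and the per-step values are illustrative assumptions. The two states should agree component-wise (real part and imaginary part).

```python
import cmath
import math

def complex_step(h, A, theta, dt, B, x):
    """h_t = exp(dt (A + i theta)) h_{t-1} + dt B x_t with complex h and B."""
    return cmath.exp(dt * complex(A, theta)) * h + dt * B * x

def real_step(h, A, theta, dt, B, x):
    """Same update on (real, imag) pairs: a scaled 2x2 rotation plus dt B x_t."""
    c, s = math.cos(dt * theta), math.sin(dt * theta)
    scale = math.exp(dt * A)
    return (scale * (c * h[0] - s * h[1]) + dt * B[0] * x,
            scale * (s * h[0] + c * h[1]) + dt * B[1] * x)

# Illustrative per-step values (assumptions): (A_t, theta_t, dt_t, x_t).
steps = [(-0.3, 1.1, 0.10, 0.7), (-0.5, 0.4, 0.20, -1.2), (-0.1, 2.0, 0.05, 0.3)]
B_c = complex(0.8, -0.4)

h_c, h_r = 0j, (0.0, 0.0)
for A, theta, dt, x in steps:
    h_c = complex_step(h_c, A, theta, dt, B_c, x)
    h_r = real_step(h_r, A, theta, dt, (B_c.real, B_c.imag), x)

print(h_c, h_r)
```

For a larger diagonal complex state, the same check applies independently to each complex dimension, which is why the real transition ends up block-diagonal.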
Great, now we’ve shown that we can implement a complex SSM without having to explicitly model the imaginary components! But another issue remains: rotating the hidden state requires us to reimplement the kernels to incorporate this new type of transition. More moving parts, rotating the entire hidden state, and so on: seems like quite the hassle, ugh… or is it?
Luckily for us, given the structure of $A$, we can sidestep all of this and directly adjust the $B, C$ to achieve the same goal. This is because the output for timestep $t$ can be modeled as
\[y_t = C^\top_t\bar{B}_t x_t + \cdots + C^\top_t(\bar{A}R)^\times_{t\cdots 1}\bar{B}_0 x_0\]where $(\bar{A}R)^\times_{t\cdots 1}$ denotes the cumulative product $\bar{A}_tR_t \cdots \bar{A}_1R_1$. Since $\bar{A}$ is a scaled identity matrix, we can ignore the $\bar{A}$ terms for now by absorbing them into $C$. This leaves terms of the form $C_i^\top R_i \cdots R_{j+1} \bar{B}_j$, which can be rewritten as $\left(R_0^\top \cdots R_i^\top C_i\right)^\top \left(R_0^\top \cdots R_j^\top \bar{B}_j\right)$: each rotation is orthogonal ($R_k R_k^\top = I$), so the shared factors $R_j \cdots R_0$ cancel. The $\bar{A}$ terms can be reintroduced at this point. Thus, the rotations can be embedded into the $B, C$ terms prior to performing the SSM recurrence instead of directly adjusting the transition matrix.
The application of our data-dependent rotations onto $B, C$ can be done efficiently. Instead of performing numerous matrix multiplications, we can run a cumulative sum over the $\theta$’s and use the efficient realization of rotary matrix multiplication from the RoFormer paper, which itself used data-independent rotations. This inspired us to call the use of a vanilla SSM to compute a complex SSM the “RoPE trick.”
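Here is a minimal 2D sketch of the RoPE trick, leaving out the scalar decay terms. The function names and numeric values are illustrative assumptions. The sign convention is also an assumption on our part: with $R_t$ defined as a counterclockwise rotation by $\theta_t$, rotating $C_i$ and $B_j$ by the negated cumulative angle reproduces $C_i^\top R_i \cdots R_{j+1} B_j$.

```python
import math
from itertools import accumulate

def rot(theta, v):
    """Rotate the 2D vector v counterclockwise by theta."""
    c, s = math.cos(theta), math.sin(theta)
    return (c * v[0] - s * v[1], s * v[0] + c * v[1])

def dot(u, v):
    return u[0] * v[0] + u[1] * v[1]

def direct(C_i, B_j, thetas, i, j):
    """C_i^T R_i ... R_{j+1} B_j, applying the rotations one at a time."""
    v = B_j
    for k in range(j + 1, i + 1):
        v = rot(thetas[k], v)
    return dot(C_i, v)

def rope_trick(C_i, B_j, thetas, i, j):
    """Pre-rotate C_i and B_j by their own cumulative angles, then take a plain dot product."""
    cum = list(accumulate(thetas))  # cum[t] = theta_0 + ... + theta_t
    return dot(rot(-cum[i], C_i), rot(-cum[j], B_j))

thetas = [0.3, -1.2, 0.7, 2.1, 0.05]  # illustrative data-dependent angles (assumptions)
C_i, B_j = (0.4, -1.3), (1.1, 0.6)
i, j = 4, 1

lhs = direct(C_i, B_j, thetas, i, j)
rhs = rope_trick(C_i, B_j, thetas, i, j)
print(lhs, rhs)
```

The payoff is that each $B_j$ and $C_i$ is rotated once, using a single cumulative sum of angles, rather than the hidden state being rotated at every step inside the recurrence.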
The RoPE trick extends to our exponential-trapezoidal recurrence, and we empirically validate that our complex-valued SSM is able to solve state-tracking tasks previously too hard for prior Mambas.
The compute paradigm in scaling LLMs has shifted from training to inference in the past two years or so. Nowadays, more and more compute is dedicated to the actual deployment and usage of these models, and to some degree, the writing was on the wall. Emergent properties such as chain-of-thought and in-context learning techniques dramatically improved performance of earlier models by making them think longer and process more tokens, and all of the best models to date are reasoners which have (almost certainly) been post-trained with reinforcement learning using large rollout budgets. With the advent of agentic workflows, we have agents spawning subagents and so forth.
What does such a paradigm change mean for hardware efficiency?
Compared to the compute-bound nature of training, the deployment of the same models, especially decoding, is memory-bound. Throughout training, the hardware is constantly performing operations, but during decoding, the compute units sit idle for large swathes of time as they wait for data to be moved across different levels of the memory hierarchy.
As an example of why this happens, think about a simple MLP. During training, the entire sequence is processed, but during decode, only the current token is processed as the past tokens are cached. The latency spent moving the MLP weights is around the same for both training and decode, but can be amortized a lot better over more computation under a training regime.
So with current linear models where the state update and output calculation can be performed in constant time, compute units sit idle for most of the time, and we are bottlenecked by simply moving data back and forth! One way to estimate how “hard” the hardware is working is through arithmetic intensity, a ratio of compute performed to memory moved.
Let’s analyze how SSMs are deployed in practice and their arithmetic intensity. A typical SSM, say Mamba-2 for instance, is organized into heads with head dimension $P$, where a single head is composed of $P$ SISO SSMs that share the same $a_t, B_t, C_t$.
\[\begin{aligned} \mathbf{h}_t & = a_t \mathbf{h}_{t-1} + B_t \mathbf{x}_t^\top \\ \mathbf{y}_t &= C_t^\top \mathbf{h}_t \end{aligned}\]where $a_t$ is a scalar decay, $B_t, C_t \in \mathbb{R}^{N}$, and $\mathbf{x}_t, \mathbf{y}_t \in \mathbb{R}^{P}, \mathbf{h}_t \in \mathbb{R}^{N\times P}$. If we use 2-byte data, for a single decode step, the total memory traffic is $2(1 + 2N + P + NP)$ bytes when accounting for all SSM parameters. The movement of the hidden state is clearly the main contributing factor at reasonable values of $P$ and $N$.
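When calculating the number of FLOPs used for the same operation, we get around $5NP - P$. The arithmetic intensity is therefore roughly $\tfrac{5NP - P}{2(1 + 2N + P + NP)}$, which approaches only $2.5$ FLOPs per byte for large $N$ and $P$.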
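Plugging in numbers makes the imbalance concrete. The sketch below simply evaluates the two expressions above. The function name is an illustrative assumption, and the values $N=128$, $P=64$ are the typical sizes quoted later in this post.

```python
def siso_decode_cost(N, P, bytes_per_elem=2):
    """FLOPs and bytes moved for one decode step of a single SISO head."""
    flops = 5 * N * P - P              # decay, outer product, add, and C^T h
    elems = 1 + 2 * N + P + N * P      # a_t, B_t, C_t, x_t, and the state h
    return flops, bytes_per_elem * elems

flops, traffic = siso_decode_cost(N=128, P=64)
intensity = flops / traffic
print(f"{flops} FLOPs, {traffic} bytes, {intensity:.2f} FLOPs/byte")
```

A couple of FLOPs per byte is very low for modern accelerators, whose compute throughput far outpaces their memory bandwidth, so the compute units are mostly waiting on memory.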
Since we pay for the entire rack and the expensive tensor cores, how can we use as many of them as possible?
A: Either increase the numerator (FLOPs performed) or decrease the denominator (bytes moved) of the arithmetic intensity.
We’ve seen empirically that the state size is quite important for performance but expanding it also increases memory… so, let’s keep that the same. Now how do we increase the compute required for calculating the hidden state recurrence while maintaining the same hidden state? Referring back to our state space/control theory toolbox, multi-input, multi-output (MIMO) SSMs can be used instead of the single-input, single-output (SISO) SSMs we’ve been using.
By expanding $\mathcal{B}_t, \mathcal{C}_t$ to $N \times R$ and $\mathbf{x}_t, \mathbf{y}_t$ to $P \times R$, where $R$ is the rank of the system, we can maintain similar memory traffic (for small enough $R$) while increasing the FLOPs used: the state-input $\mathcal{B}_t \mathbf{x}_t^\top$ becomes a matrix multiplication instead of an outer product.
\[\begin{aligned} \mathbf{h}_{t} &= a_t \mathbf{h}_{t-1} + \mathcal{B}_t \mathbf{x}_t^\top \\ \mathbf{y}_t &= \mathcal{C}_t^\top \mathbf{h}_t \end{aligned}\]The total FLOP count thus increases to $4NPR + NP - PR$, which results in an arithmetic intensity that scales as $O(R)$ when $R \ll P, N$ (generally the case, as $P=64, N=128$ and $R=4$).
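The same back-of-the-envelope calculation can be repeated for MIMO. The FLOP count below is the expression from the text. The memory-traffic expression is our assumption: it is obtained from the SISO one by scaling the $B, C$ and $x$ terms by $R$ while leaving the state untouched. The function name is likewise illustrative.

```python
def mimo_decode_cost(N, P, R, bytes_per_elem=2):
    """FLOPs and (assumed) bytes moved for one decode step of a rank-R MIMO head."""
    flops = 4 * N * P * R + N * P - P * R
    # Assumption: B and C grow to N x R, x grows to P x R, the state stays N x P.
    elems = 1 + 2 * N * R + P * R + N * P
    return flops, bytes_per_elem * elems

N, P = 128, 64
siso_flops, siso_bytes = mimo_decode_cost(N, P, R=1)  # R = 1 reduces to the SISO counts
mimo_flops, mimo_bytes = mimo_decode_cost(N, P, R=4)

siso_intensity = siso_flops / siso_bytes
mimo_intensity = mimo_flops / mimo_bytes
print(f"SISO: {siso_intensity:.2f} FLOPs/byte")
print(f"MIMO: {mimo_intensity:.2f} FLOPs/byte ({mimo_intensity / siso_intensity:.2f}x)")
```

Under these assumptions, memory traffic grows by only about 11% at $R=4$ while the FLOP count more than triples, which is the sense in which MIMO puts otherwise idle compute to work.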
The downstream gains and comparable decoding latency associated with switching from SISO to MIMO require compute costs that scale linearly with $R$ during compute-bound training.
Expressing the output of a MIMO SSM would require $R^2$ SISO SSMs due to its rank $R$ state-input and $R$ unique outputs. Its hidden state can be partitioned into the sum of $R$ SISO hidden states, and subsequently, the hidden state needs to be instantiated $R$ times for each of the outputs. But if the expressivity is an $R^2$ increase, how does the training compute required scale by only $R$?
The chunked training algorithm is the reason for this disparity. Most linear models, including both Mamba-2 and Mamba-3, are computed in chunked fashion, where the sequence is partitioned into chunks of size $C$. The hidden state is aggregated across chunks in a sequential manner, while the output of the SSM within each chunk is calculated with a quadratic, parallel algorithm.
For MIMO, the computation of outputs between chunks increases by a factor of $R$, whereas the computation of outputs within each chunk increases by $R^2$. So by decreasing the chunk size to $\tfrac{C}{R}$, the total FLOP count only increases by a factor of $R$. Our paper covers the actual FLOP calculations, but one way to think about it is that we want to reduce the amount of compute required for each quadratic algorithmic pass by increasing the number of times we call it.
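As a rough worked count (a sketch that ignores constants and the $N, P$ factors, and is not the paper's actual calculation), assume the within-chunk cost is quadratic in the chunk size and that a sequence of length $T$ is split into $T/C$ chunks:

\[\begin{aligned} \text{SISO, chunk size } C: &\quad \tfrac{T}{C}\cdot C^2 = TC \\ \text{MIMO, chunk size } C: &\quad \tfrac{T}{C}\cdot R^2C^2 = R^2\,TC \\ \text{MIMO, chunk size } \tfrac{C}{R}: &\quad \tfrac{T}{C/R}\cdot R^2\left(\tfrac{C}{R}\right)^2 = R\,TC \end{aligned}\]

Shrinking the chunk means $R$ times as many calls to the quadratic pass, each $R^2$ times cheaper than it would be at the full chunk size, so the within-chunk total grows by $R$ rather than $R^2$, matching the factor of $R$ for the between-chunk computation.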
Given the interpretation of MIMO SSMs as multiple SISO ones, improvements introduced for vanilla SSMs, like our exponential-trapezoidal discretization and complex-valued transition, can be directly applied to our MIMO variant. However, the conversion must be done carefully to avoid drastically increasing the total parameter count. The naive solution of expanding the projection size would lead to an $R\times$ increase, as the SSM inputs $x, B, C$ would all need to be adjusted. The subsequent rank $R$ output $Y$ would also force the output gate $Z$ and output projection to expand as well. This approach is clearly untenable.
Instead, we can use Mamba’s multi-value attention structure to our advantage. Since the $B, C$ are tied across all heads, we can increase the projection size without much issue, resulting in a fairly negligible $DN \to DNR$ increase for the entire layer. However, the input $x$, output $y$, and gate $Z$ are unique per head and are the main source of parameters, thus cannot be increased in such a way. Instead, we keep the original projections then element-wise scale each dimension of the projected value to size $R$ using a learnable, data-independent vector. For each head, we are able to reduce the parameter count from $DPR$ to $DP + PR$, which is quite the reduction given the number of heads each Mamba layer has!
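To get a feel for the savings, here is a small sketch of the per-head parameter counts and of the scaling step itself. The function names, the model dimension `D = 2048`, and the exact form of the expansion (one learnable weight per dimension and rank) are illustrative assumptions rather than the actual implementation.

```python
def naive_params(D, P, R):
    """Per-head parameters if the projection itself is widened to P * R outputs."""
    return D * P * R

def scaled_params(D, P, R):
    """Per-head parameters when keeping the D x P projection plus a P x R scaling."""
    return D * P + P * R

def expand_to_rank(x_head, scale):
    """Turn a projected length-P vector into a P x R matrix by element-wise scaling.

    scale is a learnable, data-independent P x R table (hypothetical layout).
    """
    return [[x_p * s for s in row] for x_p, row in zip(x_head, scale)]

D, P, R = 2048, 64, 4  # D is a hypothetical model dimension
print(naive_params(D, P, R), scaled_params(D, P, R))

# Tiny example with P = 2 and R = 3.
expanded = expand_to_rank([1.0, 2.0], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(expanded)
```

With these illustrative sizes, the scaled variant uses roughly a quarter of the naive parameter count per head, since the $PR$ term is negligible next to $DP$.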
We show that our instantiation balances the expressivity of multi-input, multi-output systems with parameter efficiency. In parameter-matched settings, our Mamba-3 MIMO variant further improves the already strong performance of regular Mamba-3 at all scales we tested. When comparing state size (a proxy for decoding speed) against performance in controlled experiments, Mamba-3 sets the Pareto front relative to Mamba-2 and achieves comparable performance with half the state size.
Mamba-3 offers a faster model with the same quality or a better model for the same speed.
We had to cut a bunch of proofs and results to keep the content in here digestible, but if that interests you, please do read our paper!
Within our work, we’ve aimed at boosting the performance and capabilities of the Mamba series through a few SSM-centric improvements. We’re curious to see how and where the community explores in architecture research. In particular, we are quite excited about the following directions (and think addressing them would be really impactful):
Building better hybrids: It’s been amazing to see the general research community and industry labs appreciate the benefits linear models can provide, especially with hybrid models.
Improving Layer Primitives: Our methodological improvements, while most natural to SSMs, can be applied to other architectures. It would be interesting to see how they scale under different transition mechanisms. In addition, there seems to be a whole trove of untapped improvements waiting to be uncovered or inspired in the “classics,” if you will. Just as Mamba and other SSMs are grounded within signal processing and traditional state space literature, such parallels can be found in other types of linear models, such as fast-weight programmers for linear attention.