Dao and Gu show that if you simplify Mamba so its state-space layer uses a diagonal matrix A_t that is a scalar times the identity matrix, i.e., A_t = a_t I, the state-space transformation can be expressed as a form of causal linear attention[a] by compounding the coefficients a_1 ... a_t at each time step t. The equivalence of the simplified state-space layer and causal linear attention constitutes the duality the authors refer to in the title. By taking advantage of this duality, Mamba-2 can be trained more efficiently, i.e., faster than the original Mamba on GPUs.
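The duality is easy to sanity-check numerically. Below is a toy sketch (my own, not the paper's code, and with a scalar state for simplicity): the SSM recurrence h_t = a_t h_{t-1} + b_t x_t on one side, and the same computation rewritten as causal "attention" whose mask entry for (t, s) is the compounded decay a_{s+1} · ... · a_t.

```python
import math

# Toy check of the SSM / linear-attention duality (my own sketch, not the
# paper's code). With scalar-times-identity decay A_t = a_t * I, the SSM
# recurrence h_t = a_t * h_{t-1} + b_t * x_t unrolls to
#   h_t = sum_s (a_{s+1} * ... * a_t) * b_s * x_s,
# i.e. causal "attention" with mask L[t, s] = a_{s+1} * ... * a_t.

def ssm_recurrent(a, b, x):
    h, out = 0.0, []
    for a_t, b_t, x_t in zip(a, b, x):
        h = a_t * h + b_t * x_t
        out.append(h)
    return out

def ssm_as_attention(a, b, x):
    out = []
    for t in range(len(x)):
        acc = 0.0
        for s in range(t + 1):
            decay = math.prod(a[s + 1 : t + 1])  # compounded a_{s+1..t}
            acc += decay * b[s] * x[s]
        out.append(acc)
    return out

a = [0.9, 0.8, 0.7, 0.95]
b = [1.0, 0.5, 2.0, 1.0]
x = [1.0, -1.0, 0.5, 2.0]
r1 = ssm_recurrent(a, b, x)
r2 = ssm_as_attention(a, b, x)
assert all(abs(u - v) < 1e-12 for u, v in zip(r1, r2))
```

Mamba-2's actual formulation works with matrix-valued states and a structured (semiseparable) view of this mask, but the scalar version already shows where the duality comes from.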
Great work by Tri Dao (of FlashAttention fame) and Albert Gu, as usual.
The key question, for me and many others, is whether Mamba, Mamba-2, RWKV, and other linear RNN / linear attention models will ever match the performance of standard Softmax attention. My understanding and experience is that all the linear attention models out there [b] still underperform Softmax attention on things like recall tasks.[c]
Notably, humans also underperform Transformers on recall tasks. And yet we do OK on many others, even with our imperfect recall. So I hope we can identify a set of high-value tasks on which these new architectures outperform Transformers, and start benchmarking Transformers on those tasks, too. Recall isn’t really “all you need” in this space, although it certainly impresses and helps plug the capability gaps.
> Notably, humans also underperform Transformers on recall tasks...
Thank you. Yes, that's obviously true... but keep in mind that human beings can pull up information from notes, books, URLs, etc. and store that information in "short-term memory" at any time, as needed, for cognitive tasks. To match or exceed peak human performance, AI models must be able to do that too -- whether it's via a long context window or some other mechanism (e.g., searching on external storage) is a secondary issue. I'd say recall is a prerequisite to match peak human performance.
Can you memorize “War and Peace” in one sitting and randomly and precisely recall things from it on demand? No? Didn’t think you could. Could you perhaps translate Farsi into Hawaiian and then speak it? Can you give instant, and nearly always correct, answers to almost any textbook question irrespective of the domain? All of that is recall, and all of that GPT-4 could do for well over a year.
> Can you memorize “War and Peace” in one sitting - neither can an LLM, unless you weaken the claim to "well, I meant input that fits in context", which isn't memorizing
> precisely recall things from it on demand - "precisely" is load-bearing here, and the question at hand. (no, LLMs don't have perfect or precise recall, even from in-context material, and it's not particularly close)
> translate Farsi into Hawaiian and then speak it - zero models can do this; is this supposed to be a gpt-4o reference? What does this have to do with "perfect recall"? Let's pretend it does. Are you claiming that gpt-4o has a 100% success rate translating Farsi to Hawaiian?
> Can you give instant, and nearly always correct to almost any textbook question - sobs in just spent 4 days benchmarking. For the record, the best non-RAG result I have is gpt-4o at 86% on MedQA and LegalBench. With RAG, 93%. RAG gpt-4o just barely scrapes by my cofounder's USMLE scores. It's excellent. But it's not superhuman.
> All of that is recall - no, translation is not recall
But you see, while some of those things aren’t _just_ recall, they are recall, too. Just different types of it, rather than point recall. And a modern SOTA LLM can accurately recall _vastly more_ data than any human or even any group of humans. Transformers are quite literally very fancy associative memories.
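To make the "very fancy associative memory" point concrete, here's a toy outer-product associative memory, which is essentially the fixed-size state a linear-attention layer maintains (my own illustration, with hand-picked orthonormal keys so retrieval is exact):

```python
# Toy outer-product associative memory (my illustration, not from the
# paper). Store (key, value) pairs in a single matrix S = sum_i k_i v_i^T;
# with (near-)orthonormal keys, querying with q = k_i retrieves v_i.

def outer(u, v):
    return [[ui * vj for vj in v] for ui in u]

def mat_add(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def vec_mat(q, S):
    # q^T S: project the query onto the stored outer products
    return [sum(q[i] * S[i][j] for i in range(len(q))) for j in range(len(S[0]))]

# hand-picked orthonormal keys, so retrieval is interference-free
keys = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
vals = [[2.0, 3.0], [-1.0, 4.0]]

S = [[0.0, 0.0] for _ in range(3)]
for k, v in zip(keys, vals):
    S = mat_add(S, outer(k, v))

assert vec_mat(keys[0], S) == [2.0, 3.0]   # recalls the first value
assert vec_mat(keys[1], S) == [-1.0, 4.0]  # recalls the second
```

With non-orthogonal keys the recalled values interfere, which is one intuition for why fixed-size-state models degrade on hard recall tasks.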
Quadratic transformers outperform weaker forms of attention on recall tasks, but recurrent models are strictly more powerful than transformers (without chain of thought) at state-tracking problems. Mamba-2 likely has the same state-tracking limitations as a full transformer, due to being parallelizable: https://arxiv.org/abs/2404.08819
The last figure in this paper is a huge disappointment when you consider how reality meets such theoretical arguments. A typical Mamba model has around 100 layers (maybe 60 for smaller ones), yet that figure only scales up to 4 layers. Four layers are sufficient for the problem they consider, so the argument goes that another RNN needed only one layer for it.
Why bother with more layers? The conventional transformer's lack of recurrent layers is mostly a weakness. The final layers of an LLM do very little (except the last one) and most of the time just let the token pass through from the layer that actually determined it. A recurrent architecture could dynamically perform as many passes as it needs to produce the next token. This would mean a speedup for easy tokens and a slowdown for hard tokens, compared with the fixed "you must go through all layers regardless" architecture of classical LLMs.
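A minimal sketch of that "as many passes as needed" idea. Everything here is hypothetical (`block`, `halt_score`, `MAX_PASSES` are my own stand-ins, not an existing API): a shared block is applied repeatedly until a halting score clears a threshold, so easy inputs take fewer passes than hard ones.

```python
MAX_PASSES = 8  # hypothetical cap on recurrent passes

def block(h):
    # stand-in for a shared transformer/SSM layer; here it just nudges
    # the (scalar) state toward 1.0
    return h + (1.0 - h) * 0.5

def halt_score(h):
    # stand-in for a learned halting head; here the state itself
    return h

def adaptive_forward(h, threshold=0.95):
    passes = 0
    while passes < MAX_PASSES and halt_score(h) < threshold:
        h = block(h)
        passes += 1
    return h, passes

easy, n_easy = adaptive_forward(0.9)  # nearly resolved already
hard, n_hard = adaptive_forward(0.0)  # needs more refinement steps
assert n_easy < n_hard
```

This is the same intuition behind adaptive-computation-time and early-exit schemes: spend compute where the token is hard, not uniformly.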
Mamba is technically a recurrent neural network (and typically decodes as such), simply of a constrained architecture that, among other things, keeps the norms of its matrices finite independently of gating. I think the above paper overloaded the word "state" to make some cheap points that may have hit the social networks at the time, yet it didn't demonstrate a practical benefit of earlier RNNs over Mamba, and the example looks a bit silly. Using a ResNet with fewer than 4 layers to make a point about vision models would be the equivalent of this paper. It might have been stronger if the authors had cared more, but we will not find out.
(Disclaimer: I've not looked at Mamba specifically yet.) But state spaces as currently used differ from traditional RNNs in a very simple way: the state is accumulated linearly (often associatively), while in an RNN the state is passed through a non-linear function at each step.
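A tiny illustration of why that distinction matters: linear state updates compose associatively, so two steps fuse into a single equivalent update, which is what makes parallel scans (and the attention rewriting) possible. A nonlinear update like h_t = tanh(w h_{t-1} + u_t) has no such closed-form composition.

```python
# Linear state update h_t = a_t * h_{t-1} + u_t, represented as the pair
# (a_t, u_t). Two consecutive updates collapse into one:
#   (a2, u2) o (a1, u1) = (a2 * a1, a2 * u1 + u2)
# This associativity is what linear SSMs exploit; a tanh RNN step cannot
# be fused this way.

def step(state, upd):
    a, u = upd
    return a * state + u

def compose(u1, u2):
    # combine two linear updates into a single equivalent one
    a1, b1 = u1
    a2, b2 = u2
    return (a2 * a1, a2 * b1 + b2)

h0 = 0.3
u1, u2 = (0.9, 1.0), (0.5, -2.0)
sequential = step(step(h0, u1), u2)   # apply the two steps one by one
fused = step(h0, compose(u1, u2))     # apply the fused update once
assert abs(sequential - fused) < 1e-12
```

Because `compose` is associative, a whole sequence of updates can be reduced in O(log T) parallel steps instead of T sequential ones.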
It is fundamentally impossible for linear attention to outperform quadratic attention. What you could do instead is use a few quadratic attention layers at the start and one at the end, and have everything else use linear attention. This would allow a context length of millions of tokens within 24 GB of VRAM for a 7B model, because you have eliminated the KV cache for the inner layers while still retaining the ability to perform any-token-to-any-token attention. The final layer is also somewhat important, since it allows the model to see all its reasoning outputs and attend over its previous reasoning steps.
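Rough back-of-envelope arithmetic for that hybrid idea. All numbers here are my own assumptions (32 layers, d_model = 4096, fp16, no grouped-query attention); real 7B models vary.

```python
# Back-of-envelope KV-cache sizing: assumed 32 layers, d_model = 4096,
# fp16 (2 bytes/element), no GQA -- illustrative only.

def kv_cache_bytes(n_quadratic_layers, seq_len, d_model=4096, bytes_per_elem=2):
    # each quadratic layer caches one K and one V vector per token
    return n_quadratic_layers * seq_len * d_model * 2 * bytes_per_elem

seq_len = 1_000_000
full = kv_cache_bytes(32, seq_len)    # every layer quadratic
hybrid = kv_cache_bytes(3, seq_len)   # e.g. 2 early layers + 1 final layer

print(f"full:   {full / 2**30:.1f} GiB")
print(f"hybrid: {hybrid / 2**30:.1f} GiB")
```

The inner linear layers add only a fixed-size state per layer, so the cache shrinks roughly in proportion to the number of quadratic layers kept (here about 10x); whether a given sequence length then fits in 24 GB depends on GQA, quantization, and the rest of the model's memory.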
If a) a linear SSM (a form of RNN?) is equivalent to Attention without the scaling and softmax; and
b) Attention is "all you need" and the thing that made Transformers radically outperform all the previous architectures like LSTMs that used to dominate NLP;
does that imply c) that the scaling and softmax parts of the attention equation, in particular, are the magic touch that makes Transformers work so well?
The major difference is that transformer state grows as the sequence gets longer, while recurrent models use a fixed-size state. So presumably at sequence length T > state size N, the transformer will be better on some very specific tasks (not all), especially those that require the model to select information from the beginning of the sequence conditional on something at the end. Transformers can refocus at any time, while SSMs need to guess right from the start what to keep and what to drop. SSMs could use the old trick of repeating the input twice, so that selecting from the beginning can be conditioned on the end as well.
An important role is played by the softmax function, which normalizes the attention scores and lets the model weigh different parts of the input sequence dynamically. This means that, unlike RNNs, which sequentially process inputs and update states, Transformers can directly access and prioritize information from any part of the sequence, and they are not slower for T < N.
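A minimal softmax-attention sketch of that dynamic weighting (my own toy example): because scores are normalized into a probability distribution, the model can put essentially all of its weight on any position, however far back it is.

```python
import math

# Minimal single-query softmax attention over a short sequence. The
# softmax turns raw dot-product scores into a distribution, so one
# position (here the oldest) can receive almost all the weight.

def softmax(xs):
    m = max(xs)                       # subtract max for numerical stability
    es = [math.exp(x - m) for x in xs]
    z = sum(es)
    return [e / z for e in es]

def attend(q, keys, values):
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    w = softmax(scores)
    # weighted sum of the value vectors
    return [sum(wi * v[j] for wi, v in zip(w, values)) for j in range(len(values[0]))]

keys = [[4.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
out = attend([4.0, 0.0], keys, values)
assert out[0] > 0.99  # almost all weight on the first (oldest) position
```

Dropping the softmax (as linear attention does) removes exactly this winner-take-most renormalization, which is one candidate for the "magic touch" asked about above.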
"From one perspective, Mamba-2 isn’t strictly better than Mamba-1: while it’s a dramatic improvement from a training perspective, Mamba-1 might be better from a pure inference perspective. Since inference speed of SSMs is entirely governed by the state dimension, if one wants to maximize performance for a target inference efficiency (i.e. for a particular state size N), then the increased expressivity of Mamba-1 might be better."
Theoretical stuff aside, Mamba-2's performance seems to scale slightly better than original Mamba: https://tridao.me/assets/img/2024-05-31-mamba-2/pile_8k_mamb...
Here's the code implementing Mamba-2: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/mo...
---
[a] https://arxiv.org/abs/2006.16236
[b] https://github.com/topics/linear-attention / https://github.com/topics/linear-attention-model -- this list is by no means complete!
[c] https://arxiv.org/abs/2402.01032