Mamba, Selective SSM
Published on September 20, 2024 by JunYoung
Mamba Selective SSM H3 Gated MLP
13 min READ
이전 글들을 올린 후 꽤나 많은 시간이 지났다. 처음으로 올렸던 글인 LSSL(Linear State-Space Layer)에서는 연속 시퀀스 데이터셋에 대해 딥러닝 모델이 효과적으로 latent space를 정의할 수 있는 구조의 발달 양상을 살펴보았다. 그 중 가장 주요한 키포인트가 되는 HiPPO 논문의 경우 임의의 길이를 가지는 시퀀스의 hidden state를 모델링할 수 있는 근거로 자리잡았고, 이후 LSSL은 레이어 개념으로 확장시켜 HiPPO 행렬의 학습을 통해 성능을 높일 수 있다는 가능성을 보여주게 된다. 그리고 학습이 진행됨에 따라 떨어질 수 있는 학습 안정성 및 수치 엄밀성을 확보하기 위해, 이에 추가로 이후 scalable(데이터 및 모델 확장 가능성)을 높이기 위해 제안된 S4모델을 두번째 글로 다루었다. 이제 이러한 기존 SSM based modeling을 근거로 하여, 트랜스포머 모델의 불가피한 연산량 증대를 개선하고자 한 Mamba 모델을 소개한다.
최근 가장 많이 다루게 되는 딥러닝 모델, 혹은 foundation model은 아마 대부분 알고 있겠지만 Transformer에 해당된다. 기계 번역 분야에서 등장한 transformer 구조는 이후 Computer vision, NLP, Audio 등등 모달리티(데이터 형태)의 종류에 무관하게 활발하게 활용되었으며, 가장 중요한 특성인 model scalability(모델 크기와 데이터 크기가 증가함에 따라 성능도 같이 향상됨) 특성이 현존하는 모델링 중 가장 확연하게 드러난 모델이라고 할 수 있다.
그러나 트랜스포머 모델에는 가장 큰 문제점이 있다. 그것은 바로 한정된 길이의 시퀀스를 (보통은 트랜스포머 모델은 토크나이저를 통해 문장을 토큰 단위로 분해하고, 이를 임베딩화하여 사용한다) 받아들이며 이를 병렬 연산(Attention)하기 위해 그만큼의 메모리를 소모하게 되고, 이는 결국 동시에 처리 가능한 데이터의 길이가 어느 정도 한정될 수 밖에 없는 것이다.
만약 짧은 문장 하나를 트랜스포머에 넣게 되면, 적당히 토크나이징하여 연산을 수행하면 되지만 논문과 같이 긴 줄글의 경우 한정된 연산량 때문에 이를 분리해서 넣게 된다. 가장 큰 문제는 트랜스포머 모델은 논문 제목에서도 볼 수 있듯이(Attention is All you Need) RNN모델과 같이 hidden state를 따로 학습하지 않다 보니, 모델에 들어간 시퀀스 내에서 모든 것을 수행하게끔 되어있다. 즉, 논문 전체를 쿼리로 하여 특정 질문에 대한 주요 부분들을 추출하거나 하는 과정에서는 충분한 연산량이 보장되지 않으면 좋은 성능을 보이지 않게 된다. 이러한 한계점은 곧 트랜스포머 모델의 어텐션 연산을 위한 효율화 작업으로 이어지게 된다.
그럼 대체 왜 꼭 “트랜스포머”여야만 하는가??라고 한다면, 그간 연산 효율화를 위해 단순 어텐션을 벗어난 모델링인 linear attention, gated convolution, RNN, structured state space models (SSMs) 모두 연구되었지만 Language와 같은 데이터에 대해 트랜스포머의 어텐션보다 높은 성능을 보이지 못했다는 단순한 사실이다. 트랜스포머가 보여준 가능성에 대해 수많은 연구가 진행되었기 때문에 하드웨어 친화적인 알고리즘이나 다양한 학습법 등등 많은 연구가 진행되어 이미 상당히 높은 발전을 이루어낸 트랜스포머 시장에서 적당한 연산량으로 애매한 성능을 보이는 다른 모델이 주목받기가 힘든 상황이다. 특히나 최근 long range dependency/reasoning 에 집중했던 SSM의 경우에는 텍스트와 같이 오히려 정보 집약적인 task에 대해서 높은 성능을 보이지 못했다.
맘바는 SSM에 집중한다. SSM은 기본적으로 특정 시스템을 모델링하기 위한 구조로, 이 방법에 attention, RNN에서 사용되는 “gate” 개념을 “selection”으로 가져간 것이다. Attention은 사실 대단한 알고리즘은 아니고, 현재 토큰에 대해 참고가 가능한 시퀀스 내에서 집중할 부분과 무시할 부분을 구분하고, 이를 예측에 활용하는 것이다. 결국 SSM에서도 특정 정보만 선택적으로 활용하는 방법을 사용해볼 수 있는 것이다. 그러나 selective SSM은 일반적인 SSM에 비해 효율적 콘볼루션 연산 등 연산 효율화를 위한 장치를 전혀 사용하지 못하게 된다. 고로 가장 기본적인 연산 방식인 recurrent를 기본적으로 사용하게 되는데, 이때 하드웨어 친화적인 알고리즘을 고안하여 연산 비효율성을 보완하게 되는 것이다. 즉 어텐션과 MLP 없이, selective SSM과 이를 효율적으로 연산할 수 있는 방법을 더하여 Mamba를 만들어낸 것이다. 그래서 사실 맘바 논문의 가장 주요한 키포인트는 SSM에 있다기 보다는 FlashAttention과 비슷한 맥락인 GPU 활용력에 있는 것이다.
이전의 SSM에서 다뤄지지 않은 내용은 “input”에 대해 선택적 알고리즘이 없다는 사실이다.
SSM 모델링을 보게 되면 길이에 무관한 모델링을 할 수 있다는 장점이 있지만, 이를 다르게 해석하면 임의의 시점에 들어오는 입력은 모두 동일하게 모델링된 state를 보기 때문에 관찰 시점 이전의 input 혹은 이후의 input 중 일부 무시해야할 내용을 구분할 수 없다. SSM의 가장 큰 장점은 아무리 긴 길이의 입력이 들어오더라도, 모든 정보를 함축적으로 모델링할 수 있다는 사실이다. 그러나 위와 같은 예시를 보면 어텐션과의 차이점이 크게 드러난다. 예컨데 어떤 텍스트의 초반에 “고양이를 5년 전부터 키워왔다.”라는 정보가 있고, 중간에 그와 무관한 이야기인 “시골집에서 살던 내용”이 포함되어있고, 텍스트 후반부에 고양이 이름이 나와있는 경우를 생각해보자.
해당 쿼리에 대해 “OOO는 몇살 정도로 예상되는가?”라는 질문을 받는다고 가정하면 어텐션 모델은 우선적으로 해당 질문과 가장 큰 연관성을 지니는 “고양이 이름은 OOO이다.” 라는 내용과 “대략 5년 전부터 고양이를 키웠다.”라는 내용을 참고하겠지만, SSM은 중간에 있는 시골집에 살던 내용까지 전부 참고하여 정답을 내놓게될 것이다. 물론 SSM이 잘못된 대답을 내놓지 않고 정답을 내놓을 수 있지만, 결국 말하고자 하는 것은 이처럼 정보가 집약적인 데이터(특정 부분에 집중해야 제대로 된 QnA가 가능한 데이터 구조)에 대해서는 어텐션 만큼 효율적이고 정확한 방법이 없다는 것이다. 따라서 맘바에서는 기존 SSM 구조에 추가로 입력 신호에 대해 SSM 파라미터(이전의 정보들)를 다변화하는 구조를 통해 Selection 매커니즘을 추가하게 된다.
이전의 모든 SSM을 위한 효율화 알고리즘은 selective SSM에는 적용되지 않는다. 이는 LTI(Invariant)와 같은 모델링 구조에서 기본적으로 입력 및 시간에 무관하게 시스템은 동일하다는 가정을 가지고 있기 때문이다. 따라서 기존 논문에서 제안된 콘볼루션 기반 방법들을 모두 사용할 수 없게 되었고, 오로지 recurrent 연산을 사용할 수 밖에 없으므로, 이를 하드웨어 친화적으로 연산하는 방법을 고안하게 된다. 실제로 구현된 하드웨어 친화적인 알고리즘은 시퀀스 길이에 대해 Linear한 복잡도를 가지게 되어, 이전 콘볼루션 기반 알고리즘이 pseudo-linearity를 가졌던 것에 비해 recurrent 연산을 더욱 효과적으로 수행할 수 있게 되었다.
모델 구조는 간단하다. Selective SSM 구조를 하나의 모듈처럼 취급하여, 기존 트랜스포머 모델을 구성하는 MLP 파트를(Attention 및 Projection 등등) SSM 모듈로 갈아끼워서 사용한다. Selective SSM이 아닌 일반적인 SSM에 대한 내용은 이전에 다루었던 글들을 통해 간단하게 이해하고 오면 좋다. 간단하게 소개하자면 대개 Structured SSM은 4개의 파라미터$(\Delta, A, B, C)$를 기본으로 적용되며, 이를 discretize한 ($\overline{A}, \overline{B}, C$)을 사용한다. SSM은 LTI 시스템을 기반으로 하여 시간에 따른 시스템 함수 불변성을 가정하였으나, 이러한 불변성이 필연적으로 가지는 한계 때문에 맘바에서는 기존에 적용될 수 있었던 연산 효율성을 포기하고 Selective SSM을 채택하게 된다. 이 부분에서는 어떠한 파트가 구체화되어 Selective SSM이 설계되었는지 단계별로 정리하고자 한다.
Compression(축약)을 위한 Selection
시퀀스를 다루는 모든 모델링은 임의의 길이를 가지는 ‘문맥’을 어떻게 하면 작은 크기의 ‘hidden state’ 혹은 ‘latent’로 함축하는가?를 다루게 된다. 모든 시퀀스 모델링은 해당 관점에서의 trade-off를 고려할 수 밖에 없는데, 대표적인 시퀀스 모델링에 해당되는 ‘트랜스포머(Transformer)’는 문맥을 전혀 압축/함축하지 않는다는 특징을 가지고 있다. 이러한 특징은 autoregressive한 추론 단계에서 Key-Value 문맥 전체를 참조하기 위해 길이에 따라 연산 속도 및 메모리가 증가하게 되며, 이는 트랜스포머의 quadratic-time consuming의 주된 원인으로 작용한다. 반대로 RNN과 같은 recurrent model은 한정된 크기의 state를 가진다는 점에서 학습 효율성을 가지나, 과연 한정된 크기의 state에 얼만큼 context(문맥)이 잘 요약될 수 있는가가 문제점으로 작용된다.
위의 그림에 나타난 생성 task를 보게 되면 이러한 trade-off system을 SSM(LTI system)의 레벨 단에서 이해하기 쉽다.
좌측에 나타난 task는 입력으로 들어온 연속 신호 중 일부분(연속되어 색칠된 부분)을 복사하여 생성하는 과정이다. LTI system이 처리할 수 있는 가장 기본적인 형태의 delay라고 볼 수 있다. 즉 LTI system으로 매핑 가능한 일반적인 모델로 간단하게 수행할 수 있는 과제이다.
우측에 나타난 두 가지의 task 중 위쪽은 입력으로 들어온 연속 신호 중 관련 신호(색칠된 부분)과 무관한 신호(흰색 부분)을 구분하고, 관련 신호를 입력 순서대로 복사하여 생성하는 과제이다. 앞선 복사 task처럼 LTI system으로 수행될 수 없기에 time-varying system 및 non-linear system이 활용되어야하는 것을 볼 수 있다. 아래쪽은 Induction heads라는 과제로, 흔히 요즘 LLM에서의 In-context learning에서 대두되는 task라고 볼 수 있다. 입력으로 넣어준 일련의 신호에 대해 맥락을 파악하고, 이후에 특정 신호(검정색 토큰)를 입력으로 넣어줬을때 문맥에 맞는 정답을 내놓게되는 것이다(파란색 토큰). 이 역시 입력 신호에 대해 어떤 특정 신호가 뒤따를지 모르기 때문에 시스템이 문맥에 대한 추론이 필수적이고, 이를 위한 모델링을 추가로 수행해야한다.
결국 위의 그림으로 이해하고자 한 것은 여러 복잡한 생성 이론을 효과적으로 수행하기 위해 “선택적으로” 문맥을 이해하는 과정을 모델링에 추가해야 한다는 사실은 시퀀스를 처리하는 모든 모델링이 다루는 문제라는 사실이다. 해당 문제를 수행하기 위해 기존 방법론들을 총정리했을때, trade-off로서 context 용량과 효율성 간의 합의점이 필요하고, 현재 다루는 모델에서 이를 어떻게 적용해낼지(Attention으로 일부 특징들을 걸러낼 것인지, Recurrent module로 문맥을 요약한 state를 구축할 것인지) 고민하게 된다. 그렇기 때문에 SSM에서도 비슷한 맥락으로의 구조가 필요하고, 기존 시퀀스 모델에서 개별적으로 적용되던 context compression의 수단으로 selection mechanism을 넣은 것이다.
SSM에 selection을 넣기
결국 Mamba는 SSM에 어떻게 selection mechanism을 심느냐는 것인데, 저자는 RNN의 recurrent dynamics나 CNN의 파라미터와 같이 직접적으로 입력에 영향을 주는 “파라미터”를 입력 신호에 따라 바꾸는 방식을 생각해냈다.
좌측과 우측을 비교하게되면 그 차이가 나타난다. 주된 차이는 $B, C, \Delta$ 파라미터가 더이상 입력(배치) 및 각 입력의 타이밍(길이)에 무관하지 않고, 입력 및 출력 신호와 동일한 크기를 가지는 것을 볼 수 있고, 이는 더이상 Time-invariant system이 아닌 Time-variant system이 되었다는 것을 의미한다.
또한 $B, C. \Delta$가 어떻게 정해지는지 우측을 잘 보게되면 $s_B(x), s_C(x), s_\Delta(x)$와 같은 방식으로 입력 신호에 대한 함수로 표현이 되어있는 것을 볼 수 있다.
[ s_B(x) = \text{Linear}_{N}(x),~s_C(x) = \text{Linear}_{N}(x) ]
이는 가장 간단한 형태로, dimension $D$인 입력을 받아 $N$인 출력을 해주는 Linear module로 parameterize하여 함수를 구성하고,
[ s_\Delta(x) = \text{Broadcast}_D(\text{Linear}_1(x)) ]
이산 신호의 간격은 위와 같이 스칼라 값을 dimension에 브로드캐스팅하는 형태로 함수를 구성하였다. 이렇게 SSM을 시간 변화에 따른 함수로 파라미터화 하였다.
하드웨어 친화적 알고리즘
Convolution 모델이라던지, Attention 모델들은 하드웨어 친화적으로 설계가 되었다. 콘볼루션 커널은 입력 크기와 무관하게 항상 일정한 receptive field 크기를 가져 메모리를 최적화할 수 있으며, 어텐션은 길이에 따라 메모리 및 시간이 증가하기는 하지만 HBM 대신 SRAM에서 잘 동작할 수 있는 알고리즘이 등장했으니까 (실제로 FlashAttention 저자인 Tri Dao가 맘바 저자로 참여했음). Selective SSM도 비록 LTI system을 사용할 수 없게 되어버렸지만 분명 학습 효율화할 수 있는 부분은 있을 것이다. 기존 방법들의 한계점은 다음과 같다.
(1) SSM과 같은 recurrent model은 표현력(state size)과 속도 사이의 합의점이 필요하다. 높은 표현력을 가지면서도 속도 저하가 심하지 않은 방법을 찾는 것이 목적.
(2) Recurrent가 Convolution보다 더 유연하다. 후자가 전자의 확장판이기 때문에 latent state 구축을 위한 연산량이 (B, L, D, N) 만큼 필요한데, 이러한 문제를 해결하려는 방법이 나옴 (S4 모델).
(3) 기존의 LTI state model은 표현력 확보를 위한 state dimension $N$의 넉넉한 확보를 위해 dual recurrent-convolutional form을 고안함.
우리는 이제 selection mechanism을 적용하기 때문에 LTI system을 사용할 수 없다. LTI system이 가지는 한계점을 해결하기 위해 Selective SSM을 고안하였으나 연산 비효율성 문제를 해결해야한다는 점에 직면하게 된다. 저자는 문제를 해결하기 전 두가지 중요한 특징을 활용한다.
결국 주된 아이디어는 엄청 특별한 내용은 아니고, hidden state $h$를 GPU에서 효율적으로 연산할 수 있는 방법들(kernel fusion, parallel scan, recomputation)로 빠르게 구해보자는 것이다.
SSM의 시스템 주축이 되는 $\overline{A}, \overline{B}$를 직접 HBM에서 계산하지 않고, SSM parameter $A, B, C, \Delta$를 SRAM단으로 로드, 이산화 작업을 거져 다시 HBM에 쓰는 방식을 취한다. 또한 순차성 부분은 스캔할 타이밍에서 parallel scan algorithm을 적용하게 된다. 이로써 적은 메모리 bandwidth를 가지는 SRAM과의 데이터 송수신 관련 코스트를 최소화하여 사용한다. 이외의 backpropagation 시의 recomputation 방식은 FlashAttention과 하드웨어적으로 동일하게 적용된다.
Neural Network에 Mamba 섞기
Structured SSM(S4)와 마찬가지로 Selective SSM(Mamba) 또한 시퀀스에 대한 변환 모듈에 해당되기 때문에 neural network에 적용할 수 있다. 맘바의 구조를 종합적으로 이해하기 위해서는 H3와 Gated Unit을 이해하는 과정이 필요하다. 속칭 ‘배고픈 하마 (Hungry Hungry Hippos)’라 불리는 H3의 경우 트랜스포머의 Attention Algorithm의 효과를 따라갈 수 있는 SSM 구조 모델링을 위해 Shifting SSM과 Recalling SSM을 구별하고, 이를 multiplication으로 엮는 시도를 하게 된다. 이렇게 모델링하게 되면 Q, K, V로 추출되는 입력에 대한 정보가 Shifting SSM에서 이전 입력을 참조하기 위해 옮겨주는 역할을 수행하고, 만약 현재 입력 정보가 기억이 된다면(Shifting $\odot$ SSM), 그 이후 입력에 대한 출력값(Value $\odot$ SSM)을 응답으로 내놓는 구조가 된다. Gated MLP의 경우에도 결국 트랜스포머의 Attention 구조를 MLP 구조에 통합하고자 한 구조에 해당된다.
즉, 맘바의 경우에도 기본적으로 SSM 구조를 사용하기 때문에 트랜스포머의 Attention 효과를 활용하고자 했던 H3와 근본적으로 문제시하는 부분이 동일하다. 그렇기 때문에 내부적으로 연산되는 SSM 부분은 H3와 동일하다. 그러나 차이가 있는 점은 H3는 Linear Attention의 Q, K, V 구조를 활용하였지만, Mamba에서는 이러한 어텐션 구조를 전혀 사용하지 않고 Gated MLP를 2개의 SSM 시스템을 Wrapping하는 방식으로 구조화하였다.
이와 같이 모델링했다. 이때의 각 요소별 효과를 간략하게 서술하면 다음과 같다.
[ \begin{aligned} \mathbf{h_t}^\prime = A\mathbf{h_{t-1}} + B\mathbf{x_t} \newline \mathbf{y_t} = C\mathbf{h_t}+D\mathbf{x_t} \end{aligned} ]
Selective copying(좌측) 그리고 Induction head(우측) 각각의 성능이 기존 SSM baseline에 비해 월등히 좋아지는 것을 확인할 수 있다.
또한 확연히 좋아지는 부분은 Perplexity인데, 연산량이 늘어날수록(모델의 파라미터 수가 증가할수록) 문맥에 대한 생성 능력이 확연히 올라간 모습을 보여준다. 이전까지는 H3까지도 어텐션에 필적하지 못했던 부분이었는데, 맘바를 통해 꽤나 많이 따라잡은 것을 확인할 수 있다.
여러 downstream task에 대해 zero-shot 성능을 확인하였다. 파라미터 수가 증가할수록 perplexity는 감소하고 average는 증가하였고, 다소 적은 파라미터 수를 가지고도 좋은 성능을 보인다. 이외에 DNA, Audio modeling등 다른 시계열 모달리티에 대해서도 좋은 성능을 보여준다.
Mamba의 장점 중 가장 주요한 포인트는 context 길이가 길어질 경우에 연산량 및 추론 시간을 줄일 수 있다는 점인데, 실제로 Attention을 효율화한 FlashAttention과 비교했을 때에도 Mamba의 inference time 및 throughput이 좋아지는 것을 볼 수 있다.