Mamba, Linear State-Space Layer
Published on July 22, 2024 by JunYoung
Mamba HiPPO LSSL S4
17 min READ
첫번째 글을 올린 이후 상당히 많은 시간이 지났다. 중간에 ECCV 논문 리부탈과 졸업 준비, 논문 등등 바쁜 일이 참 많았다. 결국 졸업도 해서 학위도 받고 ECCV 논문도 억셉되고 여러모로 좋은 일이 대부분 끝났지만 SSM 공부를 한 지 너무 오랜 시간이 지나버려서 논문을 처음부터 다시 읽어야한다는 큰 문제가 생겨 늦어지게 되었다. 아이고야 …
HiPPO는 SSM 구조에서 Long-term range를 구축하기 위한 matrix $A$ 구조의 중요성을 언급하였고, LSSL에서는 SSM을 연속적으로 존재하는 $A$ 전부에 대해 이를 일반화하였다. 이전 글에서 Mamba modeling의 기초가 되는 LSSL에 대해서 설명했었고, 해당 글에 간단한 수식과 관련된 증명을 첨부했었다. 솔직한 심정을 담아 말해보자면, 아직 본인은 이러한 기본 내용들을 전부 이해하지 못했다고 생각하고 있고 이 글을 통해서 맘바를 이해하고자 하는 것은 너무 돌아가는 과정이라고 생각되기도 한다. 자신감 없이 말한 감이 없지 않아 있지만 결국 맘바 모델링 자체를 이해하는데 있어서 state-space modeling을 완전히 뼛속부터 이해할 필요는 없다고 생각한다 (Bottom-up 보다는 Top-down이 맞는 방향이라는 개인적인 의견).
그리고 여러 블로그에 보면 시각화와 함께 맘바 모델링을 한큐에 이해할 수 있게 쉽게 정리해둔 글도 많이 보인다. 그럼에도 불구하고 굳이 글을 길게 써서 리뷰했던 이유는 앞선 글에서 말했던 것처럼 맘바의 근본적인 내용에 대해 이해해보려 노력하는 과정이 의미가 있다고 생각했기 때문이다. 맘바를 간단하게만 이해한다면 맘바가 도대체 왜 transformer 구조가 가지는 문제들을 해결할 수 있었는지, 그리고 단순히 새로운 아키텍쳐로서 등장했다고 해서 무조건 좋다고 생각해야하는 것은 아니기 때문이다.
무엇이든 쉽게 얻으면 그만큼 쉽게 잃는 법이니까. 분명이 맘바가 가지는 특징을 제대로 이해할 수 있다면 맘바가 가지는 근본적인 장단점을 발견할 수 있을 것이고, 이를 통해 향후 연구 및 모델링 개발의 기반이 될 것이다. 이번 글에서는 LSSL가 가지는 연산량과 연산 불안정성을 해결하고자 한 S4 모델링에 대해서 정리해보도록 하겠다.
시퀀스 모델링에서 가장 중요한 것은 적당한 길이의 시퀀스를 얼마나 효율적으로 처리하는가와 시퀀스의 길이가 길어질수록 참조할 수 있는 문맥의 길이도 그에 따라 길어져야한다는 것이다. 트랜스포머는 연산량을 희생하여 단일 연산으로 전체 시퀀스에 대한 어텐션 정보를 획득할 수 있고, 토큰 임베딩의 길이를 늘임으로써 이를 해결할 수 있었으나 결국 연산량의 한계가 있다는 문제가 있다. RNN 및 CNN 각각이 가지는 특징들도 있지만 모든 모델링은 각자가 가지는 장단점 때문에 만능일 수는 없었고, 이에 대한 대응으로 SSM(state-space modeling)을 제안한 것이 바로 LMU, HiPPO를 비롯한 논문들이었던 것이다.
SSM은 한마디로 Linear Time-Invariant System으로 latent space를 구축하고자 한 것이다. LTI system의 미분 방정식을 구성하는 matrix가 시간 불변성을 가진다는 특징은 결국 연속 신호를 이산화한 관측 단위에서 미분 방정식은 동일한 식으로 구축되며, 따라서 시퀀스 길이에 무관하게 동일한 latent space $x(t)$를 만들어낼 수 있다는 것이다. gate에 의존하는 RNN과는 다르게 특정 구조를 가지는 Matrix $A$가 있다면 실제 딥러닝 모델의 예측에 가장 중요한 특징 벡터인 latent를 long range dependent하게 뽑아낼 수 있게 된다. 트랜스포머로 각 토큰 단위로 어텐션을 구하든, RNN 구조로 연속으로 들어오는 데이터로 implicit latent를 만들든 그런 방법들이 아니라 실제로 시간 불변성이 성립하는 latent space를 예측해낼 수 있다면 아무리 오랜 시간이 흘러도 (관측 범위가 처음과 크게 벗어나도) hidden space는 동일한 함수로 구현될 것이기 때문에 Long Range Dependency를 보장할 수 있다는 것이다. 따라서 어텐션 연산에 대해 고정된 문맥 길이를 가지는 Transformer 구조에 비해 상대적으로 더 긴 길이의 시퀀스 데이터를 처리할 수 있고 이론상으로는 무한한 길이의 인풋을 감당 가능하다.
논문에 나와있는 그림을 통해 이해해보도록 하자. 우선 좌측의 “Continuous State Space” 부분을 확인해보면, SSM 구조에서 인풋에 해당하는 $u(t)$가 행렬 $A, B, C, D$로 구성된 Linear system의 상태(이를 hidden state, 혹은 latent state $x(t)$라 부른다)를 통해 모델링된 출력값 $y(t)$를 내게 된다. 이때 $x(t)$가 지속적으로 변화하는 input to output $\mathbf{u} \rightarrow \mathbf{y}$를 저장하는 메모리 역할을 수행하게 된다면, 이론상 불변하는(고정된 요소를 가지는) 행렬 $A, B, C, D$에 대해 꾸준이 이전 정보를 저장할 수 있게 된다. 바로 이것이 중앙에 보이는 “Long-Range Dependencies”에 해당되고, 실제로 이에 대한 구조화된 행렬의 효과성을 입증한 것이 “HiPPO: Recurrent Memory with Optimal Polynomial Projections” 논문에 해당된다.
이때 기본적으로 SSM은 “Recurrent system”을 가지게 되는데, 이는 시스템의 구조가 입력에 대한 출력 구조로 이어져있으며, 이전 입력 대비 출력 결과에 따른 시스템 변화가 이후 입출력에 영향을 미치게 된다는 것이다. 이를 연산하는 방식으로는 귀납적으로(Recurrent) 연산 후 연산을 하는 방식으로도 구현이 가능하지만 단순화하여 콘볼루션 연산으로 수행하는 것도 가능하다. 그러나 여전히 고차원의 데이터에 대해 필연적으로 증가하는 행렬 연산($e.g.$ $A, B, C, D$ 행렬이나 $\bar{K}$) 때문에 연산량이 높다는 문제가 생긴다.
따라서 지금 글에서 다루고자 하는 논문인 “Efficiently Modeling Long Sequences with Structured State Spaces” (S4)에서의 목적은 다음과 같다. 기존에 LSSL(단순 SSM)의 높은 연산량을 해결하기 위한 방법들이 제안되었지만, 모두 연산상에 numerical stability(연산 안정성 혹은 연산 엄밀성)이 부족했다. 그렇기 때문에 연산 안정성도 높임과 동시에 기존에 존재하던 “잘” 구조화된 행렬에 적용 가능한 알고리즘들을 활용하기 위해 SSM의 토대가 되는 matrix $A$를 다시 구조화하는 전략을 사용하였다. 바로 원래 $A$의 rank보다 훨씬 낮은 rank(서로 독립인 column의 갯수를 의미하며, 낮은 rank를 가지는 matrix는 독립이 아닌 column을 모두 배제할 경우 그만큼 차원 수를 줄여서 표현 가능하다)를 가지는 요소와 normal matrix(특수한 케이스의 정사각 행렬로, commutable한 특징이나 diagonal 요소로 분리가 가능하다는 등등의 특징을 사용하여 non-normal matrix에 비해 빠르게 연산이 가능하다) 요소로 분리가 가능하다는 점을 사용한다.
또한 기존의 SSM이 coefficient space(latent를 표현하는 함수는 사전 정의된 여러 orthogonal한 함수들의 coefficient 가중합으로 표현하고자 했었다.)로 접근하는 방식을 사용했다면, 이번에는 주파수 차원으로 올려서 계산하게 된다. 시간 단위에서의 콘볼루션은 주파수 단위에서의 곱연산으로 표현된다.
이를 통해 Low-Rank 행렬은 Woodbury identity로, Normal 행렬은 Cauchy kernel로 교정 가능하며 이를 토대로 연산량을 $O(N^2L)$에서 $\tilde{O}(N+L)$로, 메모리는 $O(NL)$에서 $O(N+L)$로 줄일 수 있었다.
긴 문맥을 보존할 수 있게끔 모델링된 matrix $A$ (e$.g.$ HiPPO 행렬)을 사용한다. 그렇게 되면 다음과 같은 연립 미분 방정식으로 표현되는 시스템을 구축할 수 있다.
[ \begin{aligned} \mathbf{x}^\prime = A\mathbf{x} + B\mathbf{u} \newline \mathbf{y} = C\mathbf{x}+D\mathbf{u} \end{aligned},\quad A^{\text{HiPPO}}_{n, k}= -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2},&\text{if }n > k \newline n+1,&\text{if }n = k \newline 0,&\text{if } n < k \end{cases}. ]
이때 일반적인 컴퓨팅 시스템에서는 Continuous system을 Discrete(이산화된 입력)으로 변화하는 과정이 필요하다. 이를 적용한 식이 실제 SSM에서 적용될 Discrete SSM이다.
[ \begin{aligned} x_k = \overline{A}x_{k-1} + \overline{B}u_k\quad &\overline{A} = (I-\Delta/2 \cdot A)^{-1} (I+\Delta/2\cdot A) \newline y_k = \overline{C}x_k,\quad &\overline{B} = (1-\Delta/2\cdot A)^{-1}\Delta B \newline &C = \overline{C} \end{aligned} ]
이에 대한 증명이나 보다 자세한 내용은 이전 게시글인 LSSL을 보고 오면 좋다. 아무튼 Discrete SSM이 의미하는 바는 SSM이 결국 recurrent 연산 구조를 가지기 때문에 RNN의 연산 특징을 가진다는 것. 이산화된 Matrix $\overline{A}$가 hidden state $x$의 transition matrix 역할을 수행한다. 헌데, 위의 식에 대한 hidden state와 output을 $x_{-1} = 0$라 가정한 후에 전개하면, convolution kernel에 대한 연산으로도 표현이 가능하다.
[ \begin{aligned} x_0 =& \overline{B}u_0,\quad x_1 = \overline{AB}u_0 + \overline{B}u_1,\quad x_2 = \overline{A}^2\overline{B}u_0+\overline{A}\overline{B}u_1 +\overline{B}u_2,~\cdots \newline y_0 =& \overline{CB}u_0,\quad x_1 = \overline{CAB}u_0 + \overline{CB}u_1,\quad x_2 = \overline{CA}^2\overline{B}u_0+\overline{CA}\overline{B}u_1 +\overline{CB}u_2,~\cdots \newline y_k =& \overline{K} \ast \mathbf{u},\quad \overline{K}\in\mathbb{R}^L := \mathcal{K}_L(\overline{A}, \overline{B}, \overline{C}) := (\overline{C}\overline{A}^i\overline{B})_{0\le i < L} \end{aligned} ]
그러므로 만약 convolutional filter $\overline{K}$에 대해 알고 있다는 가정을 한다면 FFT(빠른 콘볼루션 연산 알고리즘)을 통해 연산 속도를 개선할 수 있지만, 이 필터를 연산하는 과정 자체도 행렬곱이 필요하며 non-normal matrix $\overline{A}$에 대해 일반화된 연산을 하기에는 어려움이 따르게 된다. 즉, 이 논문에서 하고자 하는 것은 해당 필터를 어떻게 효율적으로 연산하는가에 대한 부분이다.
행렬의 대각화 (Diagonalization)는 행렬의 고유값(eigenvalue)인 $\lambda$와 고유벡터(eigenvalue)를 활용하여 대상이 되는 행렬을 고유값이 대각선 성분인 행렬로 만드는 과정이다. 예컨데 대각화(Diagonalization)이 가능한 행렬 $A \in \mathbb{R^{n \times n}}$가 존재한다면, 이 행렬의 eigenvalue $n$개와 이에 대응되는 eigenvector $n$에 대해서 $\Lambda = V^{-1}AV$로 표현 가능하다. 이때, $\Lambda$와 $V$는 각각 다음과 같다.
[ \text{For eigenvalues }\{ \lambda_i\}_{i=1}^n \text{ and eigenvectors } \{v_i\}_{i=1}^n,\quad\Lambda = \begin{bmatrix} \lambda_1 & 0 & 0 & \cdots & 0 \newline 0 & \lambda_2 & 0 & \cdots & 0 \newline \vdots & \vdots & \ddots & \vdots & \vdots \newline 0 & 0 & 0& \lambda_{n-1} & 0 \newline 0 & 0 & 0 & 0 & \lambda_n \end{bmatrix},\quad V = \begin{bmatrix} \mid & \mid & \cdots & \mid \newline v_1 & v_2 & \cdots & v_n \newline \mid & \mid & \cdots & \mid \newline \end{bmatrix} ]
이런 상황에서, 기존의 식을 조금 바꿔보면 다음과 같이 정리할 수 있다. 우선, 일반적으로 SSM에서 $D = 0$으로 간소화하여 사용하기 때문에 $\mathbf{y} = C\mathbf{x}$로 표현하도록 하겠다.
[ \begin{aligned} \mathbf{x}^\prime =& A\mathbf{x} + B\mathbf{u} \newline \mathbf{y} =& C\mathbf{x} \end{aligned}\quad\rightarrow\quad \begin{aligned} \tilde{\mathbf{x}}^\prime =& V^{-1}AV\tilde{\mathbf{x}} + V^{-1}B\mathbf{u} \newline \mathbf{y} =& CV\tilde{\mathbf{x}} \end{aligned} ]
두 식이 서로 조금 달라보이지만, 우측과 같이 전개된 시스템의 좌측에 전부 $V$를 곱하게 되면, $x = V\tilde{x}$인 SSM과 동치인 것을 알 수 있다. 즉 입력 대비 출력으로 이어지는 관계성 $\mathbf{u} \rightarrow \mathbf{y}$은 동일한 SSM 시스템이고, 이때의 system latent는 행렬 $A$의 eigenvector matrix $V$에 의해 바뀌게 된다. 위의 식처럼 새로운 시스템에서의 $A$행렬에 해당되는 대각 행렬 $V^{-1}AV$을 구축할 수만 있다면, 앞서 보았던 콘볼루션에서 $A$의 곱연산의 연산량을 효과적으로 줄일 수 있다. (이를 Vandermonde product라 한다. 그냥 그렇다고 생각하고 넘어가자)
그러나 유감이지만, HiPPO 행렬에 대한 대각화는 진행할 수 없다.
그 이유는 슬프게도 HiPPO 행렬이 대각화되면서 고유벡터로 이루어진 matrix $V$의 성분이 너무 커지기 때문이다. 쉽게 말하면, 컴퓨터 상에서 연산이 안정적이려면(실제 수학적 계산값과 일치하려면), 행렬 연산 과정에서 행렬 요소가 너무 큰 값을 가지면 안되기 때문에 불가능하다는 것이다. 앞서 보여준 HiPPO 행렬 $A$를 대각화하면 요소가 $V_{ij} = {i+j \choose i-j}$가 되는데 (combination), state size $N$이 커지면 커질수록 최대 $2^{4N/3}$까지 커지는 요소를 감당할 수 없다 ($e.g.$ 출력값인 $CV\tilde{\mathbf{x}}$ 연산에 문제가 생길 수 있다.).
위에서 다룬 것은 대각화를 통해 연산의 용이성을 올려보자!라는 내용이지만, 결국 기본적인 HiPPO 행렬에 대해서 적용하기는 힘들고 추가 작업이 필요하다는 것을 암시한다. 가장 이상적인 조건은 대각화의 대상이 되는 행렬 $A$가 unitary matrix와 같이 특수 행렬로 대각화가 가능한 경우에 해당된다. 선형 대수에서 이러한 조건이 만족하는 행렬 $A$의 모음을 “normal matrices”라 부른다. 눈치챘을 수도 있지만 당연히 HiPPO matrix는 normal matrix가 아니고, 그렇기 때문에 대각화할 때 고유벡터 요소가 커지는 문제가 발생한다.
다행이지만 저자는 HiPPO matrix $A$는 normal matrix가 아니지만, 해당 행렬을 normal matrix와 low rank matrix의 합으로 나타낼 수 있음을 발견하였다. 하지만 문제는 콘볼루션 필터 연산 시 합(Normal + Low Rank)에 대한 제곱 연산이 필요한데, 이 역시 시간이 오래 걸리고 최적화가 필요한 부분이라는 문제가 발생한다.
[ \overline{K}\in\mathbb{R}^L := \mathcal{K}_L(A, B, C) := (CA^iB)_{0\le i < L} ]
이 문제를 해결하기 위해 세 가지 알고리즘을 추가로 적용하여 커널 필터를 계산하게 된다. 각 알고리즘에 대한 내용을 아래 그림과 관련지어 정리하면 다음과 같다. 아직 디테일하게 설명한 부분이 없어 그림의 내용에 대한 이해는 못하지만 순차적으로 보도록 하자.
위의 알고리즘을 따라가기 위해서는 행렬 $A$가 NPLR (Normal + Low-Rank) 혹은 DPLR (Diagonal + Low-Rank)로 표현이 가능함을 이해하고 넘어가야한다. 이 부분은 논문에 Appendix C.1. 파트를 보면 empirical하게 모든 HiPPO 행렬에 대해 입증해 놓았다. 사실 수식으로의 이해는 필요가 없는 부분이고, 그냥 받아들이면 되는 파트다.
그런 뒤, 기존의 콘볼루션 커널을 계산하던 방식에서 차이를 두게 된다. $\overline{K}$를 직접 계산하지 않고 $\overline{K}$의 Discrete Fourier Transform(DFT) 변환 형태인 $\hat{K}$룰 사용한다. DFT는 이산화된 시간 축에서의 신호를 (여기서는 필터의 요소인 $CB, CAB, CA^2B, \cdots$ 를 연속된 시간 축에서의 신호로 생각하면 된다) 이산화된 주파수 축으로의 스펙트럼으로 바꿔주는 변환에 해당되고, DFT 변환 및 이의 역변환 IDFT 알고리즘은 Fast Fourier Transform (FFT)라 하며 연산 속도는 $O(L\log L)$ 선에서 해결 가능하다.
[ \hat{K}_j = \sum_{k=0}^{L-1} \overline{K}_k\exp(-2\pi j\frac{k}{L}) ]
뒤에서 더 자세히 정리하겠지만, Truncate SSM을 구성하여 스펙트럼 단위에서 연산하게 되면 필터 연산 시 $A$를 여러 번 곱하는 방식에서 벗어나 한번의 행렬 연산으로도 연산이 가능하게 된다. 이 과정에서 골칫거리인 term인 $(1-\overline{A}^L)$가 발생하는데 (결국 powering이 필요), 이를 reparameterization을 통해 반복된 연산을 피해 메모리 절약 및 속도 향상을 할 수 있게 된다. 그리고 위에서 가정한 구조화를 통해 스펙트럼 커널 $\hat{K}$를 연산하는 것이 “Cauchy kernel”과 동일함을 알 수 있고, 효율적인 알고리즘을 적용할 수 있다. 짧게 정리했지만 실제 순서대로 간단하게 표현하면 다음과 같다.
앞서 소개한 효율화 과정에 대해서 자세하게 살펴보자. 개인적으로는 디퓨전 논문보다 어려운 것 같다. 그렇지만 힘내보자.
HiPPO 행렬은 모두 “NPLR (Normal Plus Low-Rank)”로 표현 가능하다. HiPPO 행렬은 총 4가지가 존재하지만, 가장 보편적인 케이스인 LegS에 대해서만 살펴보면 다음과 같다.
[ A^{\text{HiPPO}} _{n, k}= -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2},&\text{if }n > k \newline n+1,&\text{if }n = k \newline 0,&\text{if } n < k \end{cases}. ]
요 식 모든 요소에 $\frac{1}{2}(2n+1)^{1/2}(2k+1)^{1/2}$를 더해주게 되면 다음과 같은 식이 된다.
[ -\begin{cases} \frac{1}{2}(2n+1)^{1/2}(2k+1)^{1/2},&\text{if }n > k \newline \frac{1}{2},&\text{if }n = k \newline -\frac{1}{2}(2n+1)^{1/2}(2k+1)^{1/2},&\text{if } n < k \end{cases}. ]
이 식의 대각 성분을 따로 분리하게 되면 $-\frac{1}{2}I + S$로 표현 가능하고, 이때 $S$는 normal matrix에 속하는 skew-symmetric matrix가 된다. 또한 원래의 행렬의 모든 요소에 같은 값을 더하는 행렬은 rank가 $1$인 행렬이다. 모든 HiPPO 행렬에 대한 증명은 논문을 참고하면 되고, 이를 통해 암시할 수 있는 사실은 모든 HiPPO state 행렬을 Diagonal part + Low-Rank part로 분리 가능하다는 사실이다. 이를 다음과 같이 표현하도록 하겠다. 논문에서는 증명 과정에서 conjugate($\ast$) 표시를 사용했는데, 본인은 이 기호가 조금 익숙치 않아서 transpose 기호($\top$)로 대체하여 사용하도록 하겠다.
[ A = \Lambda - PQ^\top ]
이제 이렇게 대체된 식으로 discrete system matrix를 표현해보면 다음과 같다.
[ \begin{aligned} \overline{A} &= (I-\Delta / 2 \cdot A)^{-1} (I+\Delta / 2 \cdot A) \newline \overline{B} &= (1-\Delta / 2 \cdot A)^{-1} \Delta B \newline \newline I + \Delta/2\cdot A &= I + \Delta / 2 \cdot\left(\Lambda - PQ^\top\right) \newline &= \frac{\Delta}{2} \left(\frac{2}{\Delta}I + \Lambda - PQ^\top\right) \newline &= \frac{\Delta}{2}A_0 \newline \newline (I-\Delta/2 \cdot A)^{-1} &= \frac{2}{\Delta}\left(\frac{2}{\Delta}-\Lambda+PQ^\top\right)^{-1} \newline &= \frac{2}{\Delta}\left(D-DP(I+Q^\top DP)^{-1}Q^\top D \right) \newline &= \frac{2}{\Delta}A_1, \text{ where } D = \left(\frac{2}{\Delta}-\Lambda\right)^{-1} \text{ (Diagonal term)} \end{aligned} ]
$A_0$을 계산하는 것은 크게 문제가 없는데 $A_1$ 연산에는 큰 문제가 있는데, 바로 역행렬 연산이다. 행렬 차원 수 $N$이 증가할수록 연산량이 기하급수적으로 늘어난다 . 따라서 역행렬 연산을 DPLR 행렬에 대해 위와 같이 Woodbury’s Identity를 통해 단순화할 수 있다. 대각화된 행렬에 대한 inverse는 쉽게 구할 수 있으며, 뒤에 붙는 low-rank term에는 무관하게 연산이 가능하므로 전체 계산식에 대한 역행렬 연산보다 단순화할 수 있다는 것이다. Woodbury’s Identity의 경우에는 앞으로 전개될 증명 과정에 계속 활용되기 때문에 계속 인지하고 있는 편이 용이하다 (DPLR 구조의 행렬만 가지면 계속 효율적으로 적용이 가능).
Woodbury’s Identity는 다음과 같이 적용할 수 있다. 교환 법칙이 성립하는 세 행렬 $A\in\mathcal{R}^{N \times N}, U, V \in \mathcal{R}^{N \times p}$에 대해 (여기서 $\mathcal{R}$은 commutative ring으로, 이에 속하는 원소들에 대해서는 곱연산에 대해 교환 법칙이 성립한다고 생각하면 된다.)
[ (A+UV^\top)^{-1} = A^{-1}-A^{-1}U(I_p + V^\top A^{-1}U)^{-1}V^\top A. ]
위의 식을 만족하게 된다.
암튼 구한 식으로 다시 discrete system을 정의해보면, DPLR인 행렬 $A_1,$ $A_0$에 대해 $O(N)$의 연산량으로 해결 가능하다.
[ \begin{aligned} x_k =& A_1A_0x_{k-1} + 2A_{1}Bu_k \newline y_k =& C^\top x_k \end{aligned}. ]
그리고 저자는 이 부분에서 row vector였던 $C$를 column vector로 간주하여 다른 파라미터($B, P, Q$)들과 shape을 맞추었기 때문에, 필자도 이에 따라 요 부분부터는 기존 시스템 식의 $C$를 $C^\top$으로 바꿔 표기하도록 하겠다.
이제 이산화된 시스템 행렬은 얼추 알겠고, 사실 가장 중요한 것은 Recurrent system에서의 콘볼루션 필터 $\overline{K}$를 빠르게 연산하는 것이다. DPLR이 행렬의 이산화 과정에서 Woodbury’s Identity를 활용할 수 있게 되면서 효율성을 올려준다는 사실을 인지하였으나, 실제로 콘볼루션 필터를 연산하는 과정에서의 반복곱 연산에서는 큰 도움이 되지 않는다는 것을 알 수 있다. 반복곱 연산 대신, DPLR를 활용하기 위해서는 역행렬 연산이 필요하므로 스펙트럼 단위로 넘기는 (coefficients) generating function을 생각해볼 수 있다. 예컨데 무한한 길이의 콘볼루션 필터 신호가 있다고 가정해보자.
[ \mathcal{K}(\overline{A}, \overline{B}, \overline{C}) = \{\overline{C}^\top\overline{B},\overline{C}^\top \overline{A}\overline{B},\overline{C}^\top \overline{A}^2\overline{B},\ldots\} ]
우리가 현재 신호에 대해 가질 수 있는 것은 이상적인 콘볼루션 필터 중 $L$의 길이를 가진 한정된 길이의 필터이다.
[ \mathcal{K}_L(\overline{A}, \overline{B}, \overline{C}) = \{\overline{C}^\top\overline{B},\overline{C}^\top \overline{A}\overline{B},\overline{C}^\top \overline{A}^2\overline{B},\ldots,\overline{C}^\top \overline{A}^{L-1}\overline{B}\} ]
길이가 $L$인 이산(discrete) 신호는 주파수 $2\pi\times\frac{0}{L} \sim 2\pi \times \frac{L-1}{L}$의 성분으로 분해가 가능하다. 이 주파수를 표현하는 unit $z$라는 변수로 표현한다면, 이를 $z$함수에 대한 coefficient의 집합으로 대체 가능하며 이를 $z$-transform이라고 부른다. 이때, 일반적으로 $z$는 복소수 단위(Real + Imag)를 의미하며 주파수 단위에서는 이를 오일러 각도 변환 식인 ($e^{-i\Omega}$)에서 $\Omega =\{\frac{2\pi l}{L}\}_{l=0}^{L-1}$의 합으로 표현 가능하다.
[ \hat{K}_L(z; \overline{A}, \overline{B}, \overline{C}) := \sum_{i=0}^{L-1} \overline{C}^\top\overline{A}^i\overline{B}z^i = \overline{C}^\top(I-\overline{A}^L)(I-\overline{A}z)^{-1}\overline{B} ]
이러한 변환을 DFT(Discrete Fourier Transform)이라 부르며, 이산화된 시간 축의 신호를 이산화된 주파수 축으로 변환하는 과정에 주로 활용된다. 맨 앞단의 $\overline{C}^\top(I-\overline{A}^L) = \tilde{C}$로 두게 되면,
[ \hat{K}_L(z; \overline{A}, \overline{B}, \overline{C}) = \tilde{C} (1-\overline{A}z)^{-1}\overline{B} ]
이 식에서 discretized matrix $\overline{A}, \overline{B}$ 의 closed form으로 대체하여 $A, B$에 대해 표현 가능하다.
[ \begin{aligned} \overline{A} =& (I-\Delta/2 \cdot A)^{-1} (I+\Delta/2\cdot A) \newline \overline{B} =& (1-\Delta/2\cdot A)^{-1}\Delta B \end{aligned} ]
이 식을 그대로 위에 대입하게 되면,
[ \begin{aligned} \tilde{C}^(I-\overline{A}z)^{-1}\overline{B} =& \tilde{C}\left(I-(I-\Delta/2 \cdot A)^{-1} (I+\Delta/2\cdot A)z \right)^{-1}\overline{B} \newline =& \tilde{C}\overline{B}\left(I-\frac{\Delta}{2}A\right)\left(\left(I-\frac{\Delta}{2}A\right)-\left(I+\frac{\Delta}{2}A\right)z\right)^{-1} \newline =& \tilde{C} \Delta B\left( I(1-z) - \frac{\Delta}{2}A(1+z)\right) \newline =& \frac{\Delta}{1-z}\tilde{C} \left( I - \frac{\Delta A}{2\frac{1-z}{1+z}} \right)^{-1}B \newline =& \frac{2}{1+z}\tilde{C}\left(\frac{2}{\Delta}\frac{1-z}{1+z}I -A \right)^{-1}B \end{aligned} ]
+앞서 우리는 시스템 행렬이 DPLR(Diagonal Plus Low-Rank)임을 보였기 때문에 다시 표현하게 되면,
[ \frac{2}{1+z}\tilde{C}\left(\frac{2}{\Delta}\frac{1-z}{1+z}I -A \right)^{-1}B = \frac{2}{1+z}\tilde{C}\left(\frac{2}{\Delta}\frac{1-z}{1+z}I -\Lambda+PQ^\top \right)^{-1}B ]
이제 앞서 언급했던 Woodbury’s Identity를 사용해볼 수 있다. 식을 간소화하기 위해 다음과 같이 두개 되면,
[ R(z) = \left( \frac{2}{\Delta}\frac{1-z}{1+z} - \Lambda \right)^{-1} ]
최종적으로는 다음 식으로 전개할 수 있다.
[ \tilde{C}(1-\overline{A}z)^{-1}\overline{B} = \frac{2}{1+z}\left(\tilde{C}R(z)B - \tilde{C}R(z)P(1+Q^\top R(z)P)^{-1} Q^\top R(z)B \right). ]
여기서 생길 수 있는 의문점은, 앞서 식을 전개하는 과정 상에서 앞단의 $\overline{C}^\top(I-\overline{A}^L) = \tilde{C}$로 정의된 부분이다. 원래대로라면 매번 $\overline{A}^L$를 연산해야 하지만, 단순히 학습 파라미터를 초기에 $\tilde{C}$로 초기화한 상태로 학습 가능하게끔 reparameterization을 하게될 경우 이에 따른 연산 코스트를 줄일 수 있다.
이제 마지막 단계까지 왔다. 결국 전개한 식을 요약하자면, $\overline{K}$를 연산하는 부분을 generating function $\hat{K}_L(\Omega; \overline{A}, \overline{B}, \overline{C})$로 계산하고, 이때 $\overline{A}$가 Diagonal 성분을 가짐을 사용하여 풀어낸 식이다. 그런데 이렇게 풀어낸 식이 결국 Cauchy kernel이랑 정확하게 일치하고, Cauchy kernel은 효율적으로 연산할 수 있는 알고리즘이 존재한다. 우선 Cauchy matrix / kernel의 구조는 다음과 같이 정의된다.
[ M \in \mathbb{C}^{M \times N} = M(\Omega, \Lambda) = (M_{ij})_{0<=i < M,~0<=j<N},\quad M_{ij} = \frac{1}{\omega_i - \lambda_j} ]
최종 식에서의 $Q^\top R(\Omega, \Lambda)P$ 부분을 살펴보게 되면(기존의 $z$를 unit $\Omega$의 각 요소에 대해 생각해볼 수 있다), 이를 계산하는 과정이 Cauchy matrix-vector multiplication 연산량으로 계산 가능하다는 것이다. 예컨데 원래대로라면 길이 $L$인 콘볼루션 커널을 총 $N$만큼의 hidden state에 대해 연산하려면 $O(LN)$ 만큼의 연산량이 소모되는데, 약간의 오차를 허용하면 이를 $O((L+N) \log(L+N) \log \frac{1}{\epsilon})$의 연산량으로 처리가 가능하다. 증명하는 과정은 거의 불필요한데 요약하자면, $Q^\top R(\Omega, \Lambda)P$를 계산하는 것은 $\Omega$에 속하는 모든 $\omega$에 대해 우리는 $\sum_{j} \frac{q_j^\top p_j}{\omega-\lambda_j}$를 구하고자 하는 것과 같으며, 결국 이 식은 Cauchy kernel의 형태와 같기 때문이다 ($\omega \neq \lambda_j$ 라는 조건은 항상 성립함).
이렇게 알고리즘이 모두 정리가 되었다. 다시 앞서 간단하게 언급했던 알고리즘을 다시 가져오면 다음과 같다.
먼저, $A$가 DPLR이라는 점에서 시작하여 식을 전개하였고, 이때 $\overline{K}$를 다이렉트로 계산하지 않고 generating function $\hat{K}$를 계산하기 위해 푸리에 변환을 실시하였다. 이때 나온 식에서 $\overline{C}$를 reparameterization해준다. 그런 뒤, discretized된 모든 matrix를 state matrix의 closed form으로 표현한 뒤, Woodbury’s Identity를 통해 식을 다시 전개하면,
[ \begin{bmatrix}\tilde{C}^\top & Q\end{bmatrix}^\top \left(\frac{2}{\Delta} \frac{1-\omega}{1+\omega} -\Lambda \right)^{-1} \begin{bmatrix}B&P\end{bmatrix} ]
가 된다. 다만 본인의 식 전개와 실제 논문 알고리즘에서 살짝 다른 부분이 있다면 필자는 $\tilde{C}^\top = \tilde{C}$로 쭉 전개해왔다는 사실이다. 같은 구조이기 때문에 큰 문제는 없다고 생각된다. 아무튼 이렇게 계산된 각 요소들로 효율적으로 계산된 $\hat{K}$에 다시 푸리에 역변환을 적용하면 $\overline{K}$를 구할 수 있게 된다. 드디어 이 논문의 아이디어가 되는 모든 알고리즘을 이해할 수 있었다.
앞에서 본 내용을 통해 S4에 필요한 파라미터는 다음과 같다는 것을 알 수 있다. $A$는 HiPPO 행렬의 어떠한 형태로 초기화가 될 예정이고, 해당 $A$는 그 형태에 맞게끔 특정 Diagonal 및 vector인 $\Lambda, P, Q, B, C$로 구성된다. Diagonal도 실질적으로 대각 성분 이외에는 다른 파라미터를 저장할 필요가 없기 때문에 S4는 state dimension $N$에 대해 총 $5N$의 학습 가능한 파라미터 수를 가진다. S4 자체는 Linear mapping 이지만 (동일한 길이의 시퀀스를 뽑아내는 구조), 이를 여러층 쌓고 Non-linearity를 더하게 되면 결국 Deep layer로 사용될 수 있다.
실험 결과를 따로 보여주기보다는 이 상태로 마무리하는게 좋을 것 같다. 실험 결과야 당연히 efficiency를 보여주면서 long-range 효과성을 유지하는 그림이겠거니, 예상했고 논문에 예상한 그대로가 display된 것을 볼 수 있다. 가장 중요하다고 생각한 점은 모델링의 기초가 되는 SSM을 풀어냈던 이전 논문으로부터 개선점을 지속적으로 잡아내고 (예컨데, 행렬곱을 단순화하기 위해 구조를 분해하고 이를 수식으로 증명하는 것) 이러한 과정에서 연속으로 contribution을 낼 수 있다는 점을 배우게 된 것 같다.