loading

Mamba modeling의 기초 (1) - Linear State-Space Layer (LSSL)에 대하여


Mamba, Linear State-Space Layer

Published on February 01, 2024 by JunYoung

Mamba LSSL HiPPO

21 min READ

연속 데이터 구조에 대한 DNN의 발전

Sequential한 데이터를 처리하기 위해 딥러닝 모델은 수많은 변화와 발전을 이루었다. 그 중 요즘 대표적으로 LLMmultimodal 연구에서 활발하게 활용되는 것은 Transformer 구조이지만, 그 이전에는 LSTM이나 GRU같이 Long term(거리가 먼 문맥 간의 관계성 파악) 모듈과 함께 연구된 Recurrent Neural Network (RNN), 그리고 가장 베이직한 DNN 구조인 CNN(Convolutional Neural Network)를 temporal dataset에 적절하게 변형시켜사용하는 방법이 있었다(예컨데, 비디오 데이터셋에는 temporal information 간의 정보도 사용하기 위해 시간 축을 추가한 3D convolution을 사용하였다).

이외의 방법으로는 신경망 자체의 발전으로는 유명하지는 않지만 보다 복잡한 continuous data를 처리하기 위해 neural differential equations (NDEs)를 직접 모델링하는 방법이 주로 사용되었다.

하지만 모든 네트워크는 나름의 장단점이 확실했다. RNN (Recurrent Neural Network)은 모두가 알다시피 Long-term module의 발전이 있었음에도 불구하고 여전히 긴 문맥을 처리하는데 연산량이나 시간이 비례해서 증가한다는 문제점이 있었으며, 또한 Long-term 모듈에 의존하기에 복잡한 데이터에서의 문맥 파악을 학습시키기 어렵다는 근본적인 문제가 있었다.

대체로 gradient vanishing problem이나 gradient exploding problem은 continual learning에서와 더불어 RNN과 같은 연속 데이터를 학습함에 있어 catastrophic forgetting의 주된 이유로 등장하기도 했다.

CNN(Convolutional Neural Network)는 local한 정보에 대해 (서로 차원이 붙어있는 특징) 최적화가 빠르다는 장점이 있으며, 어느 정도 문맥이 명확한 비디오 데이터셋이라던지, 이미지와 같은 object centric/semantic centric 데이터에 대해 inductive bias를 가진다는 장점이 있었다. 하지만 연산 자체가 sequence에 대응할 수 있는 구조가 아니다보니, 길이가 길어질수록 RNN과 같은 문제가 발생하였다. 결국 convolution 연산 또한 정해진 context 내에서의 local information만 뽑아내는 구조다 보니, context length에 따라 연산량이나 시간이 비례한다는 문제는 똑같이 생기게 되었다.

NDE (Neural Differential Equation) 모델링은 특정 modality나 정해진 문제를 수학 모델링을 통해 이론화했지만, 그리 효율적이지 않다는 문제가 있다. 대표적으로는 diffusion modeling을 생각해볼 수 있는데, 생성 모델인 diffusion을 이런 효율의 문제를 score function의 이산화로 해결했다. Implicit model의 부담을 줄여주어서 간단한 U-Net 구조를 사용했고, consistency modeling과 같이 또다른 implicit mapping을 통해 해결할 수 있었지만, 이는 각 구간에서의 미분 방정식 solution을 numerical하게 구할 수 있었기 때문이었고 모든 형태의 미분 방정식에서 일괄적으로 신경망이 효율적으로 학습될 수 있는 구조를 찾는 것은 불가능하다.

결국 가장 이상적인 모델 구조의 발전 방향은

  • Convolution과 같이 병렬화 연산이 가능한 구조여야 효율적일 수 있음.
  • Recurrence 형태의 상태 추론과정을 통한 문맥 처리가 되어야함.
  • Differential equation과 같이 이산화된 신호가 아닌 time-scale에 적용 가능해야함.

로 요약할 수 있다. 이러한 모델링을 찾기 위해 끊임없는 시도가 있었다.


여러 모델링 방법들 소개

CKConv

그 중 하나인 CKConv(Continuous Kernel Convolution)콘볼루션 커널을 일종의 vector continuous function $\psi : \mathbb{R} \rightarrow \mathbb{R}^{N_{out} \times N_{in}}$ 으로 보는 방식이다. 이때 연속 함수 $\psi$는 작은 신경망 MLP로 parameterize하여 학습시키게 되는데, MLP는 value로 time-step을 스칼라 값으로 받아 해당 position에서의 convolution kernel을 벡터로 내보내는 형식이 된다.

UnICORNN

RNN 계열에서 ODE 기반의 모델링 (time-scaling을 통한 long-time dependency 확보)에서는 UnICORNN과 같은 연구가 진행되기도 하였다. 간단하게 방법만 소개하면 해당 RNN은 2차 ODE(일반 미방)을 시간 축으로 이산화 (discretization)할 수 있는 오일러 메소드를 사용한다.

위의 그림에서 나와있는 $y$가 얻고자 하는 함수이고 $z$는 얻고자 하는 함수의 1차 미분 함수에 해당된다. 2차 ODE를 직접 풀어서 원하는 함수를 얻기가 힘들기 때문에 $y$의 1차 미분 함수인 $y^\prime$을 $z$라는 임시 변수로 선언함으로써 2차 미분 ODE를 $z, y$ 간의 1차 미방으로 바꿀 수 있다.

이렇게 변경된 ODE 시스템을 “Hamiltonian system”이라고 부른다. 이 Hamiltonian system을 풀어내는 과정에서 시간별 input에 의존하는 연속 함수가 구현이 되고,

[ H(y, z, t) = \frac{\alpha}{2} \parallel y \parallel^2 + \frac{1}{2}\parallel z \parallel^2 + \sum_{i=1}^m\frac{1}{w_i} \log (\cosh (w_iy_i + (Vu(t))_i + b_i)) ]

각 벡터 $y, z$의 유클리디안 norm 연산인 $\parallel \cdot \parallel$ 을 통해 위와 같이 정리된다. 이 연속 신호 미분 방정식을 오일러 메소드를 통해 이산화하면 얻고자 하는 discrete dynamical system이 추출된다.

LMU

Parallelizing Legendre Memory Unit Training (LMU) 에서는 RNN의 단점 중 하나인 병렬화 불가능 문제를 linear recurrence convolution으로 해결하는 시도를 보였다. 만약 우리가 특정 input의 이전/이후 state를 가져올 수 있는 딜레이 구조의 시스템을 구축할 수 있다면, 해당 시스템의 output으로 input의 recurrence 구조를 확보할 수 있다는 장점이 생긴다. 우리는 Linear system을 찾고자 하기 때문에 (애초에 학습하고자 하는 신경망 연산 자체가 텐서 및 행렬 기반이기 때문이라 생각하면 편하다), 다음과 같이 네 개의 matrices $<A, B, C, D>$ 로 표현되는 LTI system을 찾는 것이 목표가 된다.

[ \begin{aligned} \dot{m} =& Am + Bu \newline y =& Cm + Du \end{aligned} ]

그리고 I/O 의 라플라스 변환 형태인 $u(s), y(s)$로 SISO system의 transfer function $G(s)$를 정의할 수 있게 된다. 하지만 해당 transfer function이 내포하는 어려움은 infinite dimensional하며, continous delay $\theta$를 모두 커버치기 불가능하다는 문제가 생긴다.

[ G(s) = \frac{y(s)}{u(s)} = e^{-\theta s} ]

이제 finite하고 causal한 state space realization 차원으로 가져오기 위해서는 transfer function $G(s)$를 $s$에 대한 polynomial로 구성을 해야한다. 보통 transfer function은 분자와 분모가 각각 특정 차수를 가지는 $s$의 다항식으로 표현되는데, proper 한 dimension을 가지는 시스템은 분모의 차수가 더 높아야한다(그래야 시스템의 convergence를 보장할 수 있기 때문이다). 아무튼 위에 있는 저 식을 approximation 해야한다는 결론에 다다르게 된다. 이를 Linear system에서 구현하기 위해서 앞서 확인했던 것처럼 matrices를 구해야하고, $i,~j \in [0,d-1]$ 에 대해서 다음이 성립하는 행렬 요소를 사용하게 된다.

디테일한 내용이나 증명 과정은 해당 페이퍼의 이전 논문인 LMU를 보거나 아래에 있는 증명 과정을 보면 된다.

[ \begin{aligned} A_{i,j} =& \frac{(2i+1)}{\theta}\begin{cases} -1 & i < j \newline (-1)^{i-j+1} & i \ge j \end{cases}\newline B_i =& \frac{(2i+1)(-1)^i}{\theta} \newline C_i =& (-1)^i \sum_{l=0}^i {i \choose l}{i+l \choose j}(-1)^l \newline D =& 0 \end{aligned} ]

해당 매트릭스들 중 세번째 matrix인 $C$가 가장 주요 아이디어에 해당된다. $C$는 풀게 되면 르장드르 다항식으로 표현되며, $D = 0$이기 때문에 shifted input $u(t-\theta)$ 의 정확도를 현재 state $m_t$를 기준으로 판별할 수 있다. 예컨데, $\theta^\prime$만큼의 phase가 이동된 신호를 예측하고자 한다면 다음과 같이 shifted Legendre polynomial를 통해 근사할 수 있다.

[ \begin{aligned} C_i(\theta^\prime) = (-1)^i \sum_{l=0}^i {i \choose l}{i+1 \choose j}&\left(-\frac{\theta^\prime}{\theta}\right)^l,~0 \le \theta^\prime \le \theta \newline u(t-\theta^\prime) \approx& C(\theta^\prime)^\top m_t \end{aligned} ]

설명이 길었지만 다시 풀어서 설명하자면, 이상적인 딜레이 시스템을 LTI 시스템으로 구축하여 표현한 것이 기존의 Linear state machine 디자인이었고, 이를 다시 non-linear neural network system을 사용하여 학습한 것이 LMU 구조가 되겠다. 딜레이 시스템을 솔루션으로 삼아 네트워크를 학습하려고 한 것이다.

HiPPO

HiPPO는 LMU를 일반화한 구조에 해당된다.

HiPPO의 방법은 다음과 같다:

원래 함숫값과 예측된 함숫값 사이의 차이를 measure할 수 있는 Hilbert space $\mu$상에서 각 구간의 연속 함수인 $f$를 $g$라는 subspace로 보내는 과정을 거친 뒤, 이를 적당한 vector basis의 coefficient의 배열로 표현한다. 그렇게 되면 Continous-time ODE를 LTI system의 미분 방정식으로 표현할 수 있게 되며, 이때 system의 주축이 되는 $A(t)$와 $B(t)$ 함수의 형태를 결정하여 시퀀스 메모리에 대한 중요도를 매핑한다. 이를 통해 기존 LMU를 continuous-time memorization으로 일반화시켰다. 왜냐하면 기존 LMU(르장드르 메모리 유닛)에서는 특정 슬라딩 윈도우 크기($\theta$)를 가지는 이상적인 delay system의 LTI 미분 방정식을 그대로 이산화하여 사용하기 때문이다.


LSSL 모델링

이런 이전 모듈들의 발전은 모두 공통적으로 기존 CNN/RNN의 구조 및 단점을 time-step 차원에서 접근했다는 점이다. 하지만 모든 방법들은 convolutional/recurrent model의 문제점을 근본적으로 해결하지 못했다는 점이 한계점으로 작용했다.

Linear State-Space Layer (LSSL)은 위의 그림에서 나오는 각각의 장점을 통합한 구조를 고안하는 것을 주된 목적으로 삼았다. 결국 formulation은 이전 approach와 큰 차이는 없다. LSSL은 1-dimensional function 혹은 sequence $u(t) \rightarrow y(t)$를 implicit function $x(t)$를 통해 mapping하고자 하는 방법이다.

$A$는 앞서 봤던 LMU에서와 같이 system의 implicit function $x(t)$의 evolution을 조정하는 matrix이며, $B, C, D$는 projection에 사용된다.

[ \begin{aligned} \dot{x}(t) =& Ax(t) + Bu(t) \newline y(t) =& Cx(t) + Du(t) \end{aligned} ]

만약 $\Delta t$를 discrete step-size로 정한다면, LSSL은 정해진 갯수의 메모리와 연산으로 각 시간 축에 따라 state를 변화시키는 recurrent model로 해석할 수 있으며, LTI system인 위의 두 수식은 결국 continous convolution으로 표현될 수 있다. 고로 discrete-time version의 LTI system 또한 convolution으로 병렬화가 가능하다. 학습 속도가 빨라질 수 있다는 것이다. 마지막으로 LSSL은 LTI system의 모델링 자체가 differential equation이기 때문에 continous-time model의 모든 적용 가능한 상황을 그대로 모방할 수 있다.

결국 이 논문에서 밝히고자 한 내용은 위의 LSSL이 고전적인 제어 이론으로부터 익히 알려져있는 사실과 같이 모든 형태의 1-D Convolution을 표현할 수 있을 뿐만 아니라, 적절한 step size인 $\Delta t$ 그리고 적절한 state matrix $A$를 가지고 RNN 및 ODE가 가지는 특성(특히 장점에 집중)을 그대로 가져올 수 있다는 것이다. $A$는 다시 말하지면 시스템의 변화를 주도하는 학습 행렬로 사용되는데, HiPPO와 같은 이전 연구에서 드러났던 것처럼 연속 시간에 대한 memory를 고려하면서 동시에 long dependency를 고려해야한다.


Continuous-time memorization

Continuous time memorization 에 대한 근사화(approximation)는 HiPPO 그리고 LSSL 논문에서 공통적으로 가지는 이론적/기술적 배경에 해당된다.

필연적으로 연속 시간 모델링을 그대로 적용할 수 없기 때문에 이를 이산 시간 모델로 근사화 혹은 다운 샘플링하는 과정을 거치게 된다.

디퓨전 모델링에서도 확인할 수 있었던 것처럼 결국 연속 시간 미분 방정식의 $dt$를 얼마나 조정하냐에 따라 생성 성능이 달라졌기 때문에, 결국 연속 시간 모델링을 이산화할 때는 step size/time scale인 $\Delta t$를 조절하는 것이 중요하다.

해당 섹션에서는 LSSL 모델링으로부터 여러 property에 대한 insight를 얻을 수 있는 근거라고 볼 수 있는 개념들에 대해서 정리하도록 하겠다.

Approximations of differential equations

모든 형태의 differential equation $\dot{x}(t) = f(t, x(t))$는 integral equation $x(t) = x(t_0) + \int_{t_0}^t f(s, x(s))ds$을 동치로 가지게 된다. 해당 integral solution은 함수 $x$의 근사치를 $f(s, x(s))$에 넣고 계속 연산을 하는 방식으로 풀어낼 수 있다. 예컨데 $x_0(t) = x(t_0)$라는 함수 초기 조건을 가지고 있다면,

[ x_{i+1} (t) = x_0 (t) + \int_{t_0}^t f(s, x_{i}(t))ds ]

위와 같이 근사화할 수 있다. 이를 Picard iteration 이라고 부른다.

Discretization

그리고 이산화 과정에서 함수를 직접 적분해낼 수 없기 때문에 discrete times $t_i$에 대해, $x(t_i)$를 쪼개서 얻어내야한다. Integral equation의 form을 closed form으로 정확히 계산할 수 있다면 단순히 downsampling하는 방법으로 각 $x(t_0), x(t_1), \cdots$ 를 얻어내거나, closed form으로 알지 못하더라도 picard iteration을 각 구간별 integral equation인

[ x(t_{k+1}) = x(t_k) + \int_{t_k}^{t_{k+1}} f(s, x(s)) ds ]

에 적용하여 각 $t_k$ 시점의 함숫값들을 샘플링할 수 있다. 다른 방법으로는 generalized bilinear transform (GBT)가 있는데, 이는 현재 우리가 관심있는 Linear ODE에 적용될 수 있는 방법이다. 풀고자하는 Linear ODE의 형태가 다음과 같을때,

[ \begin{aligned} \dot{x}(t) =& Ax(t) + Bu(t) \newline y(t) =& Cx(t) + Du(t) \end{aligned} ]

GBT update는 다음의 수식으로 진행된다. 수식에서의 $\Delta t$는 step size를 의미한다.

[ x(t+\Delta t) = (I-\alpha \Delta t \cdot A)^{-1}(I+(1-\alpha)\Delta t \cdot A)x(t) +\Delta t(I-\alpha \Delta t \cdot A)^{-1}B \cdot u(t) ]

수식이 조금 복잡해서 한번에 잘 이해가 되질 않지만 특별한 케이스를 보면 이해하기 어렵지 않다. $\alpha = 0$을 위 수식에 대입하면,

[ \begin{aligned} x(t+\Delta t) =& x(t) + \Delta t \cdot (Ax(t) + Bu(t)) \newline =& x(t) + \Delta t \cdot \dot{x}(t) \end{aligned} ]

위와 같이 표현되며 이는 가장 대표적인 방법인 Euler method임을 알 수 있다. 결국 $\alpha$는 동일하게 함수를 구하는 방식에서 어느 위치에서의 미분값을 사용하냐에 따라 달려있다. $\alpha=1$이 되면 backward Euler method 가 되는데, 이는 동일하게 함수를 예측할 때 특정 위치에서의 도함수에 기반한 first order approximation이라는 점은 같지만 특정 위치가 $t$ 가 아닌 $t + \Delta t$ 라는 점에서 차이가 있다.

[ x(t+\Delta t) = (I-\Delta t A)^{-1}x(t) + \Delta t (I - \Delta t A)^{-1} B \cdot \dot{x}(t) ]

따라서 $\alpha = \frac{1}{2}$를 사용하게 되면 서로 다른 두 위치의 도함수 평균을 쓰게 되므로, 만약 곡률이 큰 복잡도가 높은 함수가 솔루션을 구성하는 상황에서는 같은 $\Delta t$를 사용하더라도 보다 안정적인 함수 예측이 가능해진다. 이를 bilinear 방법이라고 부른다.

[ x(t+\Delta t) = (I-\Delta t / 2A)^{-1}(I+\Delta t / 2A) x(t) + \Delta t (I - \Delta t / 2A)^{-1} B\cdot\dot{x}(t) ]

이렇게 bilinear 방법에 사용되는 matrix A와 B를 $\bar{A}, \bar{B}$ 라고 했을 때, 이를 통해 위의 시스템을 discretize하게 되면 다음과 같은 discrete-time state-space model을 구할 수 있다.

[ \begin{aligned} x_t =& \bar{A}x_{t-1} + \bar{B}u_t \newline y_t =& Cx_t + Du_t \end{aligned} ]

Timescale factor

시퀀스 길이에 따른 dependency는 길이가 길어질수록 줄어든다. 예컨데 $\Delta t$ 만큼을 시간 간격으로 잡는다면 의존도는 그에 반비례하게 된다. 대부분의 ODE 기반 RNN 구조에서는 $\Delta t$를 고정값으로 사용하였지만, classical RNN의 gating 메커니즘은 이를 학습하는 것과 같은 효과를 지닌다. 그리고 CNN 관점에서의 $\Delta t$는 convolution kernel의 크기를 조절하는 형태로 해석이 가능하다. 즉, CNN이든 RNN이든 ODE 기반으로 해석한다면 모두 시간 간격인 $\Delta t$를 어떻게하면 최적화할 수 있을까에 대한 문제로 해석이 가능하다는 것이다.

Continuous-time memory

입력되는 함수 $u(t)$와 고정된 probability measure(메트릭) $\omega(t)$가 있을 때, 함수의 기본꼴이 되는 $N$개의 basis가 있다고 가정해보자. 각 time step $t$마다 이전까지의 input들인 $u(\tau)\vert_{\tau < t}$ 는 $N$개의 basis의 조합으로 표현이 가능하고, 이는 곧 함수를 projection하여 획득한 coefficient vector $x(t) \in \mathbb{R}^{N}$ 이다. 이때 각 time step마다의 최적의 솔루션은 거리 메트릭 $\omega(t)$에 의존하게 된다. 이렇듯 함수 $u(t)$를 coefficient $x(t)$로 표현하는 과정이 앞서 소개했던 HiPPO (High-Order Polynomial Projection Operator)가 된다.

HiPPO의 경우에는 두 가지 경우(해당 논문에서는 LegT, LagT라는 이름으로 제안된 메트릭)를 제안하였는데, 모든 time step에 같은 중요도를 매핑하는 uniform measure $\omega = \mathbb{I}{[0, 1]}$ 와, 가까운 time step에 보다 높은 중요도를 매핑하는 exponential-decaying measure $\omega(t) = \exp(-t)$ 가 있다. 논외긴 하지만 HiPPO에서는 정해진 sliding window 크기를 가지는 translated Legendre (LegT) 대신 long dependency 및 forgetting 문제를 해결하고자 scaled Legendre (LegS)를 사용하였다. 둘의 공통점은 window 안에서 균일한 measure weight을 가진다는 점이지만, LegS는 시간이 흐를수록 window 크기가 커진다는 차이점이 있다. 아무튼 중요한 점은 measure 종류에 따라 matrix $A$를 closed form으로 풀어낼 수 있으며, 이를 토대로 long dependency에 대한 모델링이 가능하다.

각 메트릭에 따른 matrix $A$를 정하는 과정은 HiPPO 논문의 Appendix를 참고하면 되는데, 이를 조금 간단하게 정리해보고자 한다. 관련 내용을 이해하는데 필요한 사전 지식이 너무 방대하여 완벽한 증명 과정을 담기에는 무리가 있지만 그럼에도 HiPPO 전반적인 내용을 이해해야 LSSL 모델링을 해석할 수 있기 때문이다.

Orthogonal Polynomials

Orthogonal polynomials (서로 수직 관계에 있는 다항식)은 함수를 해석하는데 사용되는 기본적인 툴이다. 모든 measure $\mu$ 상에서 해당 OP에 대응되는 unique한 함수 시퀀스가 나오게 된다. 여기서 measure metric은 적분이 이루어지는 서브 공간으로 이해하면 된다. OP의 특징은, 서로 다른 OP들을 measure 상에서 적분했을때 0이 나와야한다는 것이다. 그리고 $i$번째 Polynomial은 차수가 $i$라는 constraints도 포함된다.

[ \langle P_i, P_j \rangle_\mu = \int P_i(x) P_j(x) d\mu (x) = 0~~(i \neq j),~\deg (P_i) = i ]

이러한 조건에서 $f$라는 이상적인 함수에 근사하는 최적의 솔루션은 다음과 같이 계산된다.

[ \sum_{i=0}^{N-1} c_i P_i(x) / \parallel P_i \parallel_\mu^2,~\text{where }c_i = \langle f,P_i \rangle_\mu = \int f(x)P_i (x) d\mu(x) ]

가장 대표적으로 유명한 OP에는 Fourier series basis를 생각해볼 수 있고, Jacobi, Laguerre 혹은 Hermite Polynomial도 이에 포함된다. 여기에서 소개할 OP는 Jacobi Polynomial에 속하는 르장드르 다항식이다.

Legendre Polynomials

르장드르 다항식은 흔히 구면 좌표계에서 많이 사용한다. 공학 수학을 배울 때의 악몽이 떠오르는 기분이다. 암튼 orthogonal 관계는 익히 알려진대로 구간 $[-1, 1]$ 내에서 $L^2$ 내적을 취했을 때 $\frac{2}{2n+1}$ 만큼 스케일링된 크로네커 델타를 획득할 수 있다. 그리고 유명한 성질 중 하나가 $P_n(1) = 1, P_n(-1) = (-1)^n$ 라는 경계조건을 가진다는 것.

여기서 일종의 선형성을 통해 다양한 time-scale 축에 대한 Polynomial 또한 구할 수 있다. 결국 르장드르 다항식이 성립하는 measure 공간 자체도 균일 확률 분포였기 때문에 가능한 일이다.

원래의 orthogonality는 $[-1, 1]$에서 성립했고, 이를 $[0, t]$ 구간에서 성립하게 하기 위해 함수 구간을 맞춰주게 되면 다음과 같다.

[ \begin{aligned} (2n+1)\int_0^t P_n \left( \frac{2x}{t} - 1 \right) P_m \left( \frac{2x}{t}-1 \right) \frac{1}{t} dx = \frac{2n+1}{2}\int P_n P_m \omega_\text{leg} dx \end{aligned} ]

적분 구간만 맞춰줬는데 다시 크로네커 델타를 획득할 수 있다. 고로 measure가 스케일링된 경우 르장드르 다항식은 원래의 다항식을 스케일링 해주면 쉽게 얻을 수 있다.

[ (2n+1)^{1/2} P_n \left(\frac{2x}{t} - 1\right) ]

Translated Legendre

Translated Legendre는 윈도우 크기가 $\theta$이고, 현재 지점이 $t$인 경우의 Legendre measure를 의미한다.

[ \begin{aligned} \omega(t, x) =& \frac{1}{\theta} \mathbb{I}_{[t-\theta, t]} \newline p_n(t, x) =& (2n+1)^{1/2}P_n\left(\frac{2(x-t)}{\theta} + 1\right) \newline g_n(t, x) =& \lambda_n p_n (t, x) \end{aligned} ]

그리고 원래 여기에서 tilting 개념이 등장하면서 굳이 OP를 쓰지 않을 때 사용하는 함수가 등장한다. 이를 $\chi$라고 하는데, 만약 $p_n(t, x)$ 대신 조합된 함수 형태인 $p_n(x)\chi(x)$를 쓴다고 가정한다면 각 time step에서 이번에는 $\omega/\chi^2$에 orthogonal해지게 된다 (OP 곱하기 OP 곱하기 $\chi^2$이 되므로). 만약 normalized된 measure와 orthonormal basis를 구한다치면,

[ \zeta(t) = \int \frac{\omega}{\chi^2} = \int \frac{\omega^{(t)}(x)}{(\chi^{(t)}(x))^2}dx ]

해당 함수가 곧 normalization constant가 된다. 그렇기에 normalized된 measure인 $\nu^{(t)}$는 $\frac{\omega^{(t)}(x)}{\zeta(t)\cdot(\chi^{(t)}(x))^2}$를 density로 가진다. 이렇게까지 하는 이유는 결국 tilted OP를 orthonormal하게 맞춰주기 위함이다. 위의 수식을 사용하여 orthogonality를 확인하면 르장드르에서의 orthogonality가 원래의 measure $\omega$에 대해 정규화가 됨을 알 수 있다. 하지만 이건 특수한 경우에 formulation을 위해 사용하게 되지만, 르장드르에 의한 projection에는 사용되지 않는다. 따라서 그냥 일반적인 수식을 생각해주면 된다. 앞서 추가로 언급했던 르장드르 다항식의 특성을 활용하면 마찬가지로 shifted and scaled Legendre에 대해,

[ \begin{aligned} g_n(t, t) =& \lambda_n (2n+1)^{1/2} \newline g_n(t,t-\theta) =& \lambda_n (-1)^n (2n+1)^{1/2} \end{aligned} ]

위의 경계조건을 가진다.

Projection and Coefficients

$A$ 하나 유도하는데 너무 돌아가는 듯 하지만 HiPPO를 완전히 정복하기 위해선 필수적인 수식들이다. 앞서 tilting을 고려한 measure를 유도했었는데, 이를 사용하여 coefficient를 계산하기 위해 measure에 projection한 결과는 다음과 같다.

[ c_n(t) = \zeta(t)^{-1/2} \lambda_n \int fp_n^{(t)} \frac{\omega^{(t)}}{\chi^{(t)}} ]

해당 수식을 토대로 end-to-end model을 구성하고, 해당 네트워크가 online prediction에 기반에서 이전의 함숫값 $f$ 그리고 현재의 함수를 제대로 대변하게 하기 위해서는 $c(t)$를 벡터로 표현해야하고, 이는 곧 coefficient의 벡터 형태로 얻고자 하는 목적에 부합한다.

Coefficient는 항상 현재의 예측에 기반하여 업데이트되어야한다. 즉 coefficient는 고정되어있지 않고 지속적으로 변하는 함수로 고려해야하며, 이에 맞는 미분 방정식을 생각해볼 수 있다.

[ \begin{aligned} \frac{d}{dt} c_n(t) &= \zeta(t)^{-1/2} \lambda_n \int f(x) \left(\frac{\partial}{\partial t}p_n (t, x) \right) \frac{\omega}{\chi} (t, x) dx \newline &+\int f(x) \left( \zeta^{-1/2}\lambda_n p_n(t, x) \right)\left(\frac{\partial}{\partial t} \frac{\omega}{\chi} (t, x)\right) dx \end{aligned} ]

Coefficient dynamics with Translated Legendre

르장드르 다항식의 projection을 구할 때 tilting을 무시한다고 했다. 그러면 위의 수식을 풀어낼 때 필요한 것은 OP의 편미분과 measure의 편미분이다. OP의 편미분은 자세한 과정은 생략하고 결과만 언급하자면 $n$번째 르장드르의 미분은 $n-1$번째의 르장드르까지의 linear combination으로 표현할 수 있다. 놀라운 르장드르의 세계.

그래서 정말 다행이지만 $\lambda_n p_n(t, x)$의 미분은 수많은 $g$들로 간단하게 표현 가능하다.

[ \frac{\partial}{\partial t} g_n (t, x) = -\lambda_n (2n+1)^{1/2} \frac{2}{\theta} \left( \lambda_{n-1}^{-1} (2n-1)^{1/2}g_{n-1} + \lambda_{n-3}^{-1} (2n-1)^{1/2} g_{n-3} + \cdots \right) ]

그리고 measure에 대한 편미분은 rectangular function에 대한 미분과 같다.

[ \frac{\partial}{\partial t} \omega (t, x) = \frac{1}{\theta}\delta_t - \frac{1}{\theta} \delta_{t-\theta} ]

준비물이 모두 완료되었기 때문에 이를 통해 앞서 구했던 coefficient dynamics를 표현한 미분 방정식에 대입이 가능하다.

[ \frac{d}{dt}c_n(t) = -\frac{\lambda_n}{\theta} (2n+1)^{1/2} \sum_{k=0}^{N-1} M_{nk} (2k+1)^{1/2} \frac{c_k(t)}{\lambda_k} + (2n+1)^{1/2} \frac{\lambda_n}{\theta} f(t) ]

이며 이 때 $M_{nk}$는 $k$가 $n$보다 작거나 같으면 무조건 $1$이고 $k$가 $n$보다 크면 $(-1)^{n-k}$의 값을 가지는 value이다. 이제 임의로 정해줄 수 있는 $\lambda_n = (2n+1)^{1/2}(-1)^n$를 적용하면

[ \frac{d}{dt} c(t) = -\frac{1}{\theta} Ac(t) + \frac{1}{\theta} B f(t) ]

의 수식에서

[ A_{nk} = (2n+1)\begin{cases} (-1)^{n-k}& \text{if }k < n \newline 1 & \text{if }k \ge n \end{cases},~~B_n = (2n+1)(-1)^n ]

앞서 소개했던 LMU가 그대로 나오는 것을 확인할 수 있다.


LSSL 해석해보기

다시 LSSL로 돌아와서 Fixed state space representation $A, B, C, D$가 주어진 상황을 가정해보자. 간단하게도 LSSL은 input sequence를 output sequence로 매핑하는 과정이 된다. LSSL는 이러한 매핑 과정에서 파라미터 행렬 $A, B, C, D$ 그리고 discretize에 필수적인 $\Delta t$로 정의된다. 이제 이러한 LSSL이 대체 어떻게 RNN, CNN 그리고 Neural ODE의 모든 특징을 가질 수 있는지 해석해보도록 하겠다.

LSSL to RNN

LSSL에서의 recurrent state는 각 time step$t$$x_{t-1}$에 해당한다. 현재 state $x_t$ 그리고 output $y_t$는 이산화된 LSSL formulation에 의해 계산된다.

[ \begin{aligned} x_t =& \bar{A}x_{t-1} + \bar{B}u_t \newline y_t =& Cx_t + Du_t \end{aligned} ]

따라서 RNN 구조와 같이 동작하는 것을 알 수 있다. 심지어 RNN 구조에서의 gated recurrence 도 만족한다. 예컨데 1차원의 gated recurrence 구조 $(1-\sigma (z))x_{t-1} + \sigma(z) u_t$는 backward-Euler method로 $\dot{x}(t) = -x(t) + u(t)$를 이산화한 것과 동일하다. $z$는 임의의 expression이 모두 가능한데, sigmoid function 특성과 앞서 소개한 GBT를 생각하면 $\Delta t = \exp (z)$로 표현했을때 gated recurrence가 $A = -1, B = 1$인 backward-Euler method임을 증명할 수 있다. 그런데 여기서 의문이 생길 수 있는 점은, Linear system에서 구축한 state layer가 과연 일반적인 deep RNN이 가지는 non-linearity 및 복잡도를 표현할 수 있는가에 대한 문제이다.

앞서 단순히 $\dot{x}(t) = -x(t) + u(t)$의 이산화에 대해 언급했었는데, 이를 다르게 해석해서 Picard iteration 을 사용한다고 생각하면, 결국 deep RNN은 학습 과정에서 Picard iteration 을 거치면서 함수를 찾아간다고 생각할 수 있다. 즉, 만약 linear recurrence가 아닌 non-linear recurrence를 사용한다면 LSSL 또한 non-linearity를 학습할 수 있게 된다. 이를 통해 RNN 구조와 LSSL는 필요충분 관계에 놓여있다고 볼 수 있다.

LSSL to CNN

간단한 상황을 가정하기 위해 initial state를 $0$이라 가정해보자. 그렇게 되면 Linear state system을 풀어낸 output을

[ y_k = C(\bar{A})^k\bar{B}u_0 + C(\bar{A})^{k-1}\bar{B}u_1 + \cdots + C\bar{A} \bar{B}u_{k-1} + \bar{B}u_k + Du_k ]

이처럼 정리할 수 있으며, 이는 곧 discrete-time convolution으로 표현 가능하다.

[ \begin{aligned} &y = \mathcal{K}_L (\bar{A}, \bar{B}, C) \ast u + Du \newline &\mathcal{K}_L (\bar{A}, \bar{B}, C) = (CA^iB)_{i \in [L]} \in \mathbb{R}^L \end{aligned} ]

따라서 LSSL은 output이 convolution에 의해 연산되는 모델로 해석 가능하며, 콘볼루션 연산은 FFT로 가속화가 가능하다.

일반적인 continous state-space system의 관점에서 output $y$는 input $u$에 대해 시스템의 impulse response function $h$와의 콘볼루션 연산으로 표현된다.

[ y(t) = \int h(\tau)u(t-\tau) d\tau ]

이와는 조금 다르게, convolutional filter가 만약 rational functional degree ($N$)를 가지는 경우, 크기가 $N$인 state-space model로 필터를 나타낼 수 있다. 기존 연구들에서 밝혔던 점을 토대로 임의의 convolutional filter $h$는 유한한 degree 값을 가지는 rational function으로 표현이 가능하다. 앞서 봤던 HiPPO matrix의 케이스를 예로 들어보도록 하자. 필요한 사전지식을 정리할 때 Translated Legendre의 경우를 보게 되면, $A$는 특정 구간($\theta$) 내에서 동일한 확률 분포를 가지는 measure에서 정의되었다. 일반적인 LSSL에서 $dt$를 고정시켜서 생각했을 때, 첫번째 식인

[ \dot{x}(t) = Ax(t) + Bu(t) ]

은 history element를 기억하는 과정에 해당되고 두번째 식인

[ y(t) = Cx(t) + Du(t) ]

은 해당 윈도우 내에서 유의미한 feature를 뽑는 작업이다. 그렇기 때문에 LSSL은 결국 width가 학습 가능한 convolutional kernel filter를 학습하는 과정과 동치라고 생각할 수 있다.


Deep Linear State-System Layers

일반적인 LSSL은 간단하게 요약하면 입력 시퀀스를 출력 시퀀스로 매핑하는 시스템이었다. 예컨데 길이가 $L$인 신호가 있다면, LSSL은 $\mathbb{R}^L \rightarrow \mathbb{R}^L$을 수행하는 하나의 vec to vec 함수 구조이며 이때 함수 자체는 parameterized 되어있다. 만약 LSSL을 $\psi$라고 한다면,

[ \psi(\cdot \vert A, B, C, D, \Delta t),~A \in \mathbb{R}^{N \times N},~B \in \mathbb{R}^{N \times 1},~C \in \mathbb{R}^{1 \times N},~D \in \mathbb{R}^{1 \times 1} ]

이처럼 표현할 수 있다. 앞서 언급했던 것처럼 단일 LSSL은 Recurrence, Convolution의 특징을 모두 가지고 있기 때문에 RNN과 CNN의 대표적인 레이어인 recurent unit이나 convolution kernel처럼 사용할 수 있다.

또한 입력 시퀀스가 transformer의 input처럼 $H$의 hidden dimension을 가지고 있다고 하면($L \times H$), LSSL은 $H$만큼의 LSSL을 독립적으로 학습하게 되고, Transformer의 multi-head 효과 또한 그대로 적용할 수 있다.

말하고자 했던 것은 LSSL를 stacking하는 과정으로 기존 DNN 방법론과 같이 다양한 함수를 모사할 수 있으며 동시에 normalization, residual connection과 같은 방법론과 함께 모델링될 수 있다는 사실이다.


LSSL과 Continuous-time Memorization

LSSL이 기존 DNN 모델링의 특징을 살리면서 사용될 수 있다고 해서 무작정 사용할 수는 없는 노릇이고, LSSL이 장점을 보일 수 있어야 한다.

Long dependency into LSSLs

Discretized Linear system ODE에서, 시스템은 이산화된 parameter $\bar{A}$가 계속 곱해지며 발전해간다.

[ x_t = \bar{A}x_{t-1} + \bar{B}u_t ]

그 말은 gradient descent로 학습하게 되면vanishing gradient 문제를 피할 수 없다는 것이다. 이처럼 만약 $A$를 랜덤하게 초기화한 후 학습하는 형태를 사용하면, 기대하는 성능이 나오지 않을 것이라는 말이 된다.

하지만 HiPPO와 같은 framework에서는 measure $\omega$에 따라 어떤 방식으로 이전 function을 기억할 지에 대한 문제를 언급했었다 (projection/coefficient화 과정을 통해). 그러나 HiPPO의 문제점이라고 한다면 이렇게 매뉴얼하게 정한 hippo matrix를 학습하지 못하고 그대로 사용해야한다는 점이다. 왜냐하면 HiPPO에서는 르장드르를 포함한 일부 measure에 대해서만 이를 풀어낼 수 있는 structured solution matrix $A$가 존재했고, 모든 일반적인measure에도 다른 형태의 $A$가 존재할 수 있다는 사실을 밝히지 못했기 때문이다.

따라서 LSSL에서는 이를 arbitrary measure $\omega$로 확장시키고, 이때 Low-recurrence width $A$에 대한 미분 방정식을 찾을 수 있다고 증명하였다.

Efficient Algorithms for LSSLs

그러나 A와 $\Delta t$가 상당히 중요한 parameter임이 드러났음에도 불구하고, naive LSSL에서는 학습하기 어렵다는 문제가 발생한다. LSSL은 MVM(Matrix Vector Multiplication) 그리고 Krylov function을 연산할 때 (각각 convolution/recurrence에 해당) 전자의 경우에는 matrix inversion이 필요하다는 어려움과,

[ x(t+\Delta t) = (I-\alpha \Delta t \cdot A)^{-1}(I+(1-\alpha)\Delta t \cdot A)x(t) +\Delta t(I-\alpha \Delta t \cdot A)^{-1}B \cdot u(t) ]

후자는 $\bar{A}$를 feautre의 길이인 $L$만큼 곱해야 한다는 문제가 발생한다.

[ \mathcal{K}_L (\bar{A}, \bar{B}, C) = (CB, CAB, \ldots, CA^{L-1}B) ]

이를 해결하기 위해 $A$를 학습할 때의 효율성을 증대하기 위한 조건이 하나 더 발생한다. 모든 기존의 fixed LSSL의 $A$는 3-quasiseparable함이 증명되었다. 만약 학습되는 $A$ 또한 quasiseparable 특성을 유지할 경우, MVM과 krylov function 연산이 보다 적은 연산량으로 처리될 수 있다.


Evaluations and Demonstrations

실제로 특정 조건을 가지는 $A$를 학습할 수 있으면, 이는 이전 HiPPO system $A$보다 더 좋은 성능을 보임을 확인하였다.

또한 길이가 긴 음성 신호의 classification 성능을 통해 long time dependency 또한 입증하였다.

또한 기존 SoTA에 필적하는 성능을 보이기까지 학습 epoch가 훨씬 적어질 수 있음을 보여주었다.


Conclusion

Mamba modeling의 가장 기초가 되는 LSSL을 살펴보았으며, LSSL의 이해에는 HiPPO의 이해가 필수적이기 때문에 해당 논문도 함께 다루었다. 앞으로 몇개의 포스팅을 통해 Mamba를 리뷰하게 될지는 모르겠지만 State Modeling에 대해서는 아무도 제대로 정리를 안해놓을 것 같아서..

A n o t h e r p o s t i n c a t e g o r y