원생
[논문 정리] Mamba: Linear-Time Sequence Modeling with Selective State Spaces 본문
딥러닝/논문 정리
[논문 정리] Mamba: Linear-Time Sequence Modeling with Selective State Spaces
wonlife 2024. 2. 24. 16:59
Background
HiPPO → LSSL → S4 → Mamba
Times series $($sequential$)$ input을 처리하기 위한 State Space Model 기반의 efficient architecture 계보
State Space Model
- 상태 $($State$)$
어떤 시점(t=t0)에서의 변수를 알고, 시간이 지난 어느 시점$($t≥t0$)$에서의 입력을 알면, 입력이 주어진 시점$($t≥t0$)$에서의 시스템의 거동을 완전히 결정할 수 있을 때, 이러한 변수$($상태변수$)$들의 최소집합을 말한다$($Minimum set of variables, known as state variables, that fully describe the system and its response to any given set of inputs$)$
- 상태 방정식$($State equation$)$
시스템의 현재 상태를 기술하는 방정식$($$\dot{x}: x$의 derivative$)$ $$ \dot{h}_1 = f_1(h,x,t)\\ \dot{h}_2 = f_2(h,x,t)\\ ...\\ \dot{h}_1 = f_n(h,x,t)\\ $$
LTI$($linear and time-invariant$)$ 시스템에서는 위 식을 1차 선형미분방정식으로 표현 가능하다.
$$ {d\over dt} \begin{bmatrix} h_1\\ h_2\\ \vdots\\ h_n\\ \end{bmatrix}= \begin{bmatrix} a_{11} & a_{12} & ... & a_{1n}\\ a_{21} & a_{22} & ... & a_{2n}\\ \vdots & \vdots & \ddots & \vdots\\ a_{n1} & a_{n2} & ... & a_{nn}\\ \end{bmatrix} \begin{bmatrix} h_1\\ h_2\\ \vdots\\ h_n\\ \end{bmatrix} + \begin{bmatrix} b_{11} & b_{12} & ... & b_{1n}\\ b_{21} & b_{22} & ... & b_{2n}\\ \vdots & \vdots & \ddots & \vdots\\ b_{n1} & b_{n2} & ... & b_{nn}\\ \end{bmatrix} \begin{bmatrix} x_1\\ x_2\\ \vdots\\ x_n\\ \end{bmatrix} $$$$ \dot{h} = Ah + Bx $$
시스템의 output 역시 유사하게 표현 가능
$$ y = Ch + Dx \ or \\ y = Ch $$
$($많은 물리 시스템에서 D는 null matrix라고 한다.$)$
따라서, LTI system을 standard state space form으로 표현하는 모델은 아래와 같다.
$$ \dot{h}(t) = Ah(t) + Bx(t)\\ y(t) = Ch(t) + Dx(t) $$
유튜브링크 1 2
Hippo
- HiPPO (High-order Polynomial Projection Operators)
- HiPPO는 time series input의 길이가 매우 길 때, 해당 input의 cumulative history $($축적된 정보$)$를 압축하여 표현하는 방법을 제안한다.
- 그림에서 $($1$)$임의의 함수 f$($t$)$를 $($2$)$ polynomial 기저공간$($subspace$)$ $\mathcal{G}$에 중요도에 대한 지표 $($meause$)$ $\mu^{(t)}$를 가중하여 투영$($projection$)$한다. $($3$)$ 잘 정의된 기저 공간에 대해, 얻어진 각 basis에 대한 coefficients $c(t) \in \R^{N}$는 함수 f의 history를 잘 압축할 수 있다. $($4$)$ 해당 시스템을 이산화하면 online으로 추가되는 데이터에 대해 효율적인 closed-form recurrence 시스템을 구성할 수 있다.
S4
- S4 $($Structured State Space sequence Model$)$
- SSM은 long range dependency를 잘 표현할 수 있지만, computational burden이 매우 커 practical하지 않다.
- S4는 이를 해결하기 위해 먼저 long-range dependancy를 HiPPO matrix를 통해 표현하고, reparameterization을 통해 단순한 discrete representation을 제공한다.
Introduction
Intro
- Foundation model trends
- Foundation Models$($FMs$)$를 방대한 데이터로 pretrain한 후, 원하는 downstream task에 fintune하는 패러다임이 최근 많이 사용되는데, 이중 가장 popular한 구조는 단연 Transformer 구조이다.
- 하지만, self-attention은 학습된 finite length window에 그 효과가 한정되어 있고, window의 크기의 제곱에 비례하는 높은 computational complexity를 가진다는 단점이 있다.
- SSM
- 최근, structured state space sequence models $($S4$)$가 efficient한 long-range dependency 모델로 떠오르고 있다.
- 해당 모델은 continuous signal을 다루는 분야$($오디오, 비디오$)$에서는 높은 성능을 보이지만, discrete data를 다루는 language, genomic 분야에서는 효과가 입증되지 않았다.
Overview
- Selection Mechanism
- 기존 SSM은 LTI system을 가정하기 때문에 input에 independant한 fixed parameter $($$A, B, \bar{A}, \bar{B}, \dots$$)$를 가진다.
- 이는 변화하는 input에 따라 필요한 데이터를 선택하는 능력을 크게 제한하는 요인으로, 본 논문에서는 input에 variant한 SSM parameter를 설정해 SSM에서 필요한 데이터를 선택할 수 있는 selection mechansim을 구성했다.
- Hardware-aware Algorithm
- 기존 SSM은 time 및 input에 invariant하여 효율적으로 $($S4논문 참고$)$ output 연산을 할 수 있었다.
- 하지만, 본 논문에서 제안한 selection mechanism은 input variant한 parameter들을 가져 reparameterization 등의 기법을 활용할 수 없다.
- 따라서 본 논문에서는 효율성을 지키기 위해, convolutional 구조 대신 recurrent 구조를 활용하며, 서로 다른 GPU memory hierarchy 간 IO access를 최소화하기 위해 expanded state를 materialize하지 않는 방법을 제안한다.
- Architecture
- 이전 SSM 구조와 transformer의 mlp 블록을 결합해 sequence 모델을 일반화 및 단순화한 Mamba architecture를 제안한다.
- Contributions
제안하는 모델은 위에서 언급한 모듈 및 알고리즘을 활용해- High quality: 선택 알고리즘$($selectivity$)$을 통한 dense modality에서 성능 향상,
- Fast training and inference: 연산량 및 메모리가 input size에 linear하게 증가하여 효율적인 training, inference가 가능,
- Long context: 1M가량의 input sequence lenght에도 quality 및 efficiency를 보장
State Space Models
- SSM
- State space model은 sequence model로, implicit latent state $ h(t)\in \mathbb{R}^{N} $을 거쳐 1D input을 $ x(t) \in \mathbb{R} \mapsto y(t)\in \mathbb{R} $와 같이 mapping한다.
$ \dot{h}(t) = Ah(t) + Bx(t) $
$y(t) = Ch(t) + Dx(t) $
- Discretization
- 위 식은 continuous system을 해석하기 위한 식으로, 우리가 다루는 discrete system으로의 변환이 필요하다.
- 해당 변환은 기존 “continuous parameters” $(A, B, \Delta)$($\Delta$: sampling time difference)를 입력받아 “discrete parameters” $(\bar{A}, \bar{B})$를 출력하는 formulas $\bar{A}=f_{A}(\Delta, A), \bar{B}=f_{B}(\Delta, B)$로 이루어져 있으며, $(f_A, f_B)$는 discretization rule이라 불린다.
- Discretization된 SSM은 linear recurrence혹은 global convolution로 연산 가능하다.
$\bar{K}=(C\bar{B}, C\bar{A}\bar{B}, \dots, C\bar{A}^k\bar{B}, \dots),\\ y=x\ast\bar{K},$
$h_t=\bar{A}h_{t-1}+\bar{B}x_t,\\ y_t=Ch_t,$ - Computation
- 일반적으로 input sequence가 한번에 모두 주어지는 training 시에는 paralellize된 global convolution 연산을, input이 순차적으로 주어지는 inference시에는 recurrent mode로 연산을 진행한다.
- Linear Time Invariance
- 기존 SSM들은 LTI, 즉 모든 parameter$($$A, B, C, \Delta$$)$들이 모든 입력에 대해 동일했다.
- 이러한 특성은 위 linear recurrence 공식을 global convolution으로 전환할 수 있게 해주는 요인이었다.
- LTI system을 가정함으로써 효율적인 학습 및 inference가 가능했지만, 아래 언급하듯이 LTI적 특성은 모델의 성능을 제한하는 요인이 되기도 한다.
Selective State Space Models
Motivation: Selection as a Means of Compression
- Sequence modeling in term of compressing
- Sequence model의 성능은 이전 sequence의 정보를 얼마나 잘 유지하고, 얼마나 잘 압축했는지로 갈음할 수 있다.
- Self-attention의 경우, 모든 sequence의 정보를 압축하지 않고(key, value cache) 활용하여 매우 효과적이지만, 동시에 매우 비효율적이다.
- Recurrent 모델의 경우, finite state에 정보를 저장하므로 효율적이지만, 해당 state에 얼마나 효과적으로 context를 유지했느냐에 따라 성능이 나뉜다. - Limitation of LTI model
- 위 그림에서 왼쪽 단순한 copying 문제는 단순한 LTI system $($e.g. global conv$)$로 쉽게 풀 수 있다.
- 하지만, 오른쪽 위 Selective Copying $($sequential input data에서 기억해야 할 token의 위치가 varying하는 경우$)$나 Induction Heads $($검정색 token 뒤에는 파란색 token이 온다는 사실을 기억해 검정색 token이 input될 경우 파란색 token을 attention 하는 task$)$는 constant dynamics로는 해결할 수 없다.
- Summary
- Efficiency: context를 얼마나 잘 압축했냐.
- Effectiveness: context를 얼마나 잘 유지했냐.
- LTI model: 한계가 있다.
→ Selectivity를 활용해 input dependant $($or context-aware$)$하게 state를 저장해 효율성과 효과성을 둘 다 얻도록 하자.
Improving SSMs with Selection
- Algorithms
- Explanation
- $\Delta, B, C$를 input에 대한 function으로 표현해 input time에 varying하는 $($L차원의 존재는 input 순서에 따라 다른 param 의미$)$ system 구성한다.
- 위에서 언급한 convolution으로의 전환은 불가하다.
- 각각 $s_B(x) = Linear_N(x),\ s_C(x) = Linear_N(x),\ s_{\Delta}(x) = Broadcast_D(Linear_1(x)),$ and $\tau_{\Delta} = softplus$ $($$s_{\Delta}$와 $\tau_{\Delta}$를 위와 같이 선택한 것은 RNN과의 연관성을 위한 것$)$
Efficient Implementation of Selective SSMs
- Selective Scan
- Selection mechanism은 LTI 모델의 limitation을 해결했지만, SSM의 computational cost 문제를 다시 야기하게 된다.
- Recurrent한 computation은 convolution mode에 비해 유연하지만, latent state $h$를 $($모든 sequence에서$)$ 연산 및 저장하기 때문에 convolution에 비해 메모리 사용량이 증가하게 된다.
- 이를 해결하기 위해 kernel fusion, parallel scan, recomputaion 세 가지 테크닉을 활용한다.
- Kernel fusion
$($추측, 코드 확인 필요$)$
- $(B, L,D,N)$ 크기의 $\bar{A}, \bar{B}$는 크기가 작은 SRAM에 저장할 수 없다.
- 그렇다면, 원래는 data load 속도가 느린 HBM에서 $A, B, C, \Delta$를 load해 $\bar{A}, \bar{B}$연산 후 다시 HBM에 저장하므로 HBM을 통한 데이터 I/O가 많아져 속도가 느리다.
- 제안한 방법에서는 $A, B, C, \Delta$를 SRAM에 직접 로드하고 상대적으로 작은 $(B, L, D)$크기의discretization 및 recurrence 결과를 HBM에 전달해 data I/O에 의한 bottleneck을 줄인다.
- Recomputation- 기존 pytorch 등 library에서는 backpropagation시 연산량을 줄이기 위해 중간 계산 결과들을 backprogation시까지 저장한다.
- 이렇게 intermediate states($h$)를 모두 저장할 경우 memory 사용량이 증가한다.
- 각 states들을 저장하여 backpropagation에 사용하는 대신, backpropagation을 진행할 때 다시 연산하는$($recomputation$)$ 방식으로 메모리 사용량 문제를 해결한다.
A simplified SSM Architecture
Code
TODO
Discussion
- Continuous-Discrete Spectrum
앞서 설명 했듯, S4구조는 continuous signal에 대해 좋은 성능을 가지고 있었다.
- 하지만, Mamba 구조는 text, DNA와 같은 discrete data에서 성능을 끌어올릴 수 있었지만 반대로 continuous signal에 대한 taks 성능은 하락하였다.
- Downstream Affordances
- Transformer 기반 foundation model들처럼 finetuning, adaptation, prompting 등 하위 태스크에 잘 적용할 수 있는지 확인해야된다.
- Scaling
- 7B 파라미터까지는 Transformer와 비견되는 성능을 보인다 확인했지만, 모델 크기가 커질 때에도 이러한 경향성이 유지되는지 확인해야한다.
- 또한, 모델 크기가 커짐에 따라 본 논문에서 다루지 않은 여러 문제점이 발생할 수 있기 때문에 이에 대한 검증도 필요하다.
Conclusion
우리는 structured state space model에 selection mechanism을 도입하여, 시퀀스 길이에 선형적으로 확장되면서 context-depdendent한 연산을 수행할 수 있게 합니다. 간단한 attention free 구조에 통합될 때, Mamba는 다양한 도메인에 대해 SOTA 결과를 달성하며, 강력한 트랜스포머 모델의 성능을 비견되거나 능가합니다. 우리는 selective state space model이 유전체학, 오디오, 비디오와 같이 긴 문맥을 요구하는 모달리티를 포함하여 다양한 도메인에 대한 기반 모델을 구축하는 데 넓은 응용 가능성에 대해 흥분됩니다. 우리의 결과는 Mamba가 일반 시퀀스 모델 백본이 될 강력한 후보임을 제안합니다.
Related papers
HiPPO, LSSL, S4, VisionMamba
'딥러닝 > 논문 정리' 카테고리의 다른 글
[기타] Visio 수식 깨짐 해결 (0) | 2023.09.03 |
---|