
Efficiently Modeling Long Sequences with Structured State Spaces

Paper

Introduction

A major problem in sequence modeling is capturing long-range dependencies (LRDs). Results on the Long Range Arena (LRA) benchmark show that sequence models perform poorly on LRD tasks. Specialized variants of RNNs, CNNs, and Transformers have been proposed to address this issue, but they still perform poorly on challenging benchmarks like LRA.

An alternative approach to LRDs is based on state space models (SSMs). Deep SSMs can perform very well when equipped with special state matrices \(A\). The Linear State Space Layer (LSSL), a deep SSM from prior work, conceptually unifies the strengths of continuous-time models (CTMs), RNNs, and CNNs, and can in theory address LRDs. However, the LSSL is computationally expensive: for state dimension \(N\) and sequence length \(L\), computing the latent state requires \(O(N^2L)\) time and \(O(NL)\) space.

This paper introduces the Structured State Space (S4) sequence model, based on SSMs, which removes this computational bottleneck. S4 reparameterizes the structured state matrix \(A\) by decomposing it as the sum of a low-rank term and a normal term.

Background

State Space Models

The state space model is defined by the following equations:

\[ \begin{equation} \begin{aligned} \frac{dx}{dt} &= x'(t) = Ax(t) + Bu(t) \\ y(t) &= Cx(t) + Du(t) \end{aligned} \end{equation} \]

Here \(u(t) \in \mathbb{R}\) is a 1-D input signal, \(x(t) \in \mathbb{R}^N\) is the \(N\)-dimensional hidden state, and \(y(t) \in \mathbb{R}\) is the 1-D output signal. The matrices \(A, B, C, D\) are parameters of appropriate dimensions, learned by gradient descent.

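The goal is to use the SSM as a black-box sequence-to-sequence layer in a deep model. The term \(Du\) acts as a skip connection and is easy to compute, so it is omitted from here on (equivalently, \(D = 0\)).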

Addressing Long-Range Dependencies with HiPPO

The basic SSM (1) performs poorly in practice. Intuitively, linear first-order ODEs solve to exponential functions, which leads to vanishing or exploding gradients. To address this, the LSSL uses HiPPO, a theory of continuous-time memorization that specifies a class of matrices \(A \in \mathbb{R}^{N \times N}\) allowing the state \(x(t)\) to memorize the history of the input \(u(t)\). The key matrix in this class, the HiPPO matrix, is:

\[ \begin{equation} A_{nk} = - \begin{cases} (2n + 1)^{1/2} (2k + 1)^{1/2} & \text{if } n > k \\ n + 1 & \text{if } n = k \\ 0 & \text{if } n < k \\ \end{cases} \end{equation}\]
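The matrix in (2) can be built in a few lines. The following is a minimal NumPy sketch; the helper name `hippo_legs` is our own label (the matrix is commonly called HiPPO-LegS), not an API from the paper's code.

```python
import numpy as np

def hippo_legs(N: int) -> np.ndarray:
    """Build the N x N HiPPO matrix of equation (2), with 0-indexed n (row) and k (column)."""
    root = np.sqrt(2 * np.arange(N) + 1)          # (2n + 1)^{1/2} for n = 0..N-1
    below = np.tril(np.outer(root, root), k=-1)   # (2n+1)^{1/2} (2k+1)^{1/2} where n > k, else 0
    diag = np.diag(np.arange(1, N + 1))           # n + 1 on the diagonal
    return -(below + diag)                        # leading minus sign from (2)

A = hippo_legs(4)
print(np.round(A, 3))
```

Because the matrix is lower triangular, its eigenvalues are its diagonal entries \(-1, -2, \ldots, -N\), so the continuous-time SSM with this \(A\) is stable.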

Discrete-time SSM

To apply SSMs to a discrete input sequence \((u_0, u_1, \cdots)\), the continuous-time SSM (1) must be discretized. Conceptually, the inputs \(u_k\) are samples of an underlying continuous signal \(u(t)\) at uniformly spaced times, \(u_k = u(k\Delta)\), where \(\Delta\) is the step size. The bilinear discretization method is used, which converts the state matrix \(A\) into an approximation \(\bar{A}\) as follows:

\[\begin{equation} \begin{aligned} x_{k} &= \bar{A} x_{k-1} + \bar{B} u_k \\ y_k &= \bar{C} x_k \\ \bar{A} &= (I - \frac{\Delta}{2} A)^{-1} (I + \frac{\Delta}{2} A) \\ \bar{B} &= (I - \frac{\Delta}{2} A)^{-1} \Delta B \\ \bar{C} &= C \\ \end{aligned} \end{equation}\]

The state equation is now a recurrence in \(x_k\), so the discrete SSM can be computed like an RNN.
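As a minimal sketch of (3), assuming a single-input single-output SSM with \(B \in \mathbb{R}^{N \times 1}\) and \(C \in \mathbb{R}^{1 \times N}\), the bilinear discretization and the RNN-style recurrence look like this in NumPy. The function names, the step size \(\Delta = 0.1\), and the random \(B, C\) are illustrative choices, not values from the paper.

```python
import numpy as np

def hippo_legs(N):
    # HiPPO matrix from (2)
    root = np.sqrt(2 * np.arange(N) + 1)
    return -(np.tril(np.outer(root, root), -1) + np.diag(np.arange(1, N + 1)))

def discretize(A, B, C, dt):
    """Bilinear discretization, equation (3)."""
    I = np.eye(A.shape[0])
    inv = np.linalg.inv(I - dt / 2 * A)
    Ab = inv @ (I + dt / 2 * A)
    Bb = inv @ (dt * B)
    return Ab, Bb, C                      # C-bar = C

def run_recurrence(Ab, Bb, Cb, u):
    """x_k = Ab x_{k-1} + Bb u_k, y_k = Cb x_k, starting from a zero state."""
    x = np.zeros((Ab.shape[0], 1))
    ys = []
    for u_k in u:
        x = Ab @ x + Bb * u_k
        ys.append((Cb @ x).item())
    return np.array(ys)

N, L, dt = 8, 64, 0.1
rng = np.random.default_rng(0)
A = hippo_legs(N)
B = rng.standard_normal((N, 1))
C = rng.standard_normal((1, N))
Ab, Bb, Cb = discretize(A, B, C, dt)
y = run_recurrence(Ab, Bb, Cb, rng.standard_normal(L))
```

The bilinear transform maps an eigenvalue \(\lambda\) of \(A\) to \((1 + \Delta\lambda/2)/(1 - \Delta\lambda/2)\), which lies inside the unit circle whenever \(\mathrm{Re}(\lambda) < 0\), so the recurrence stays stable.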

Training SSMs

The recurrent SSM (3) is slow to train because of its sequential nature. To address this, the SSM is reformulated as a convolution. Unrolling (3) with initial state \(x_{-1} = 0\) gives:

\[\begin{equation} \begin{aligned} y_k &= \bar{C} \bar{A}^k \bar{B} u_0 + \bar{C} \bar{A}^{k-1} \bar{B} u_1 + \cdots + \bar{C} \bar{A} \bar{B} u_{k-1} + \bar{C} \bar{B} u_k \\ y &= \bar{K} * u \\ \end{aligned} \end{equation}\]
\[\begin{equation} \bar{K} \in \mathbb{R}^L := \mathcal{K}_L(\bar{A}, \bar{B}, \bar{C}) := (\bar{C}\bar{A}^i\bar{B})_{i\in[L]} = (\bar{C} \bar{B}, \bar{C} \bar{A} \bar{B}, \cdots, \bar{C}\bar{A}^{L-1} \bar{B}) \\ \end{equation}\]

Here \(*\) denotes (causal) convolution. Computing \(\bar{K}\) naively, by repeated multiplication with \(\bar{A}\), takes \(O(N^2L)\) time and \(O(NL)\) space. Once \(\bar{K}\) is known, the convolution itself takes \(O(L \log L)\) time using the FFT.
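The equivalence between the recurrence (3) and the convolution (4)-(5) can be checked numerically. This sketch (hypothetical helper names; a random stable \(\bar{A}\) stands in for a discretized HiPPO matrix) builds \(\bar{K}\) naively, applies it with an FFT-based causal convolution, and compares the result with the step-by-step recurrence.

```python
import numpy as np

def ssm_kernel_naive(Ab, Bb, Cb, L):
    """K_i = C Ab^i B for i = 0..L-1, using L mat-vec products (O(N^2 L) time)."""
    K = np.empty(L)
    v = Bb
    for i in range(L):
        K[i] = (Cb @ v).item()
        v = Ab @ v
    return K

def causal_fft_conv(K, u):
    """y_k = sum_{j<=k} K_j u_{k-j} in O(L log L), zero-padded to length 2L to avoid wrap-around."""
    L = len(u)
    n = 2 * L
    return np.fft.irfft(np.fft.rfft(K, n) * np.fft.rfft(u, n), n)[:L]

def run_recurrence(Ab, Bb, Cb, u):
    """Reference: the RNN form (3), starting from a zero state."""
    x, ys = np.zeros_like(Bb), []
    for u_k in u:
        x = Ab @ x + Bb * u_k
        ys.append((Cb @ x).item())
    return np.array(ys)

rng = np.random.default_rng(0)
N, L = 4, 32
M = rng.standard_normal((N, N))
Ab = M / (1.1 * np.max(np.abs(np.linalg.eigvals(M))))   # hypothetical stable A-bar
Bb, Cb = rng.standard_normal((N, 1)), rng.standard_normal((1, N))
u = rng.standard_normal(L)

K = ssm_kernel_naive(Ab, Bb, Cb, L)
y_conv = causal_fft_conv(K, u)
y_rnn = run_recurrence(Ab, Bb, Cb, u)
```

Zero-padding to length \(2L\) matters: without it, the FFT computes a circular convolution and late inputs would wrap around into early outputs.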

However, computing \(\bar{K}\) is still expensive for large \(N\). The next section introduces the S4 model which addresses this bottleneck.

Methodology

Motivation: Diagonalization

The main computational bottleneck in training SSMs is computing the convolution kernel \(\bar{K}\) which requires \(O(N^2L)\) time and \(O(NL)\) space. The key idea to address this is to choose a special structure for the state matrix \(A\) that allows \(\bar{K}\) to be computed more efficiently.

Lemma 3.1 Conjugation is an equivalence relation on SSMs. Specifically, if \(V\) is an invertible matrix, then the SSM \((A, B, C)\) is equivalent to \((V^{-1}AV, V^{-1}B, CV)\).
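A quick numerical check of Lemma 3.1, as a sketch under the assumption that both SSMs are discretized with the bilinear method from (3): conjugating \((A, B, C)\) by a random invertible \(V\) leaves the kernel \(\bar{K}\) unchanged. All names and sizes here are illustrative.

```python
import numpy as np

def discretize(A, B, dt):
    """Bilinear discretization from (3); C-bar = C, so C is not touched here."""
    I = np.eye(A.shape[0])
    M = I - dt / 2 * A
    return np.linalg.solve(M, I + dt / 2 * A), np.linalg.solve(M, dt * B)

def kernel(A, B, C, dt, L):
    """Naive kernel (C Ab^i Bb) for i < L of the discretized SSM."""
    Ab, Bb = discretize(A, B, dt)
    out, v = [], Bb
    for _ in range(L):
        out.append((C @ v).item())
        v = Ab @ v
    return np.array(out)

rng = np.random.default_rng(1)
N, L, dt = 5, 16, 0.1
A = rng.standard_normal((N, N)) / np.sqrt(N)
B, C = rng.standard_normal((N, 1)), rng.standard_normal((1, N))
V = rng.standard_normal((N, N))            # any invertible matrix
Vinv = np.linalg.inv(V)

K_original = kernel(A, B, C, dt, L)
K_conjugated = kernel(Vinv @ A @ V, Vinv @ B, C @ V, dt, L)
```

Conjugation commutes with the bilinear transform, so \(\bar{A}\) becomes \(V^{-1}\bar{A}V\) and \(\bar{B}\) becomes \(V^{-1}\bar{B}\); the factors \(V\) and \(V^{-1}\) then cancel in every term \(\bar{C}\bar{A}^i\bar{B}\).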

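If \(A\) were diagonal, computing \(\bar{K}\) would be much cheaper: \(\bar{A}\) would then also be diagonal, and \(\bar{K}\) would reduce to a Vandermonde matrix-vector product. Unfortunately, naively diagonalizing the HiPPO matrix \(A\) is numerically unstable, because the diagonalizing matrix \(V\) has entries of magnitude up to \(2^{4N/3}\).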
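Both halves of this argument can be illustrated in a short sketch (helper names are ours). With \(\bar{A} = \mathrm{diag}(\lambda)\), \(\bar{K}_i = \sum_n C_n B_n \lambda_n^i\), a Vandermonde matrix-vector product. The second part measures the condition number of the eigenvector matrix that NumPy returns for the HiPPO matrix; we use it only as a proxy for the instability, since NumPy normalizes the columns and so does not return the exact \(V\) from the paper.

```python
import numpy as np

def diag_kernel(lam, B, C, L):
    """K_i = sum_n C_n B_n lam_n^i: a Vandermonde matrix-vector product."""
    vander = lam[:, None] ** np.arange(L)[None, :]    # (N, L), entry lam_n^i
    return (C * B) @ vander

def hippo_legs(N):
    # HiPPO matrix from (2)
    root = np.sqrt(2 * np.arange(N) + 1)
    return -(np.tril(np.outer(root, root), -1) + np.diag(np.arange(1, N + 1)))

def eigvec_condition(N):
    """Condition number of the (column-normalized) eigenvector matrix of the HiPPO matrix."""
    _, V = np.linalg.eig(hippo_legs(N))
    return np.linalg.cond(V)

rng = np.random.default_rng(0)
lam = rng.uniform(-0.9, 0.9, size=6)
B, C = rng.standard_normal(6), rng.standard_normal(6)
K = diag_kernel(lam, B, C, 10)
K_check = np.array([C @ np.linalg.matrix_power(np.diag(lam), i) @ B for i in range(10)])

conds = [eigvec_condition(N) for N in (4, 8, 16)]
print(conds)
```

The condition numbers grow quickly with \(N\), which is why conjugating by \(V\) is numerically unreliable at realistic state sizes.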

S4 Parametrization

To avoid this instability, \(A\) should only be conjugated by a well-conditioned matrix \(V\), ideally a unitary one. By the spectral theorem, a matrix can be diagonalized by a unitary matrix exactly when it is normal. The HiPPO matrix \(A\) is not normal, so S4 decomposes it as the sum of a normal matrix and a low-rank matrix. On its own, this form is still slow and hard to optimize: unlike a diagonal matrix, powering up this sum in (5) is no cheaper than for a general matrix. S4 therefore combines three techniques:

  • Generating functions: Instead of computing \(\bar{K}\) directly, S4 evaluates its truncated generating function \(\sum_{j=0}^{L-1} \bar{K}_j \zeta^j\) at the \(L\)-th roots of unity \(\zeta\). This replaces the matrix powers in (5) with a matrix inverse, and an inverse FFT then recovers \(\bar{K}\) from these \(L\) evaluations.
  • Woodbury identity: The generating function involves the inverse \((I - \bar{A} \zeta)^{-1}\). After accounting for the bilinear discretization, this amounts to inverting a diagonal plus low-rank matrix, and the Woodbury identity expresses that inverse through the inverse of the diagonal part alone, reducing the problem to the diagonal case.
  • Cauchy kernel: After applying the Woodbury identity, the remaining work consists of sums of terms like \(\frac{1}{\omega_j - \zeta_k}\). This is a Cauchy kernel, a well-studied problem with stable near-linear-time algorithms (e.g., fast multipole methods).
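The generating-function step can be sketched on a generic dense \(\bar{A}\). This only illustrates the identity and is not S4's algorithm: each node uses a dense \(O(N^3)\) solve, and the helper names and the random stable \(\bar{A}\) are our own assumptions. The Woodbury and Cauchy steps are what make each evaluation cheap in S4.

```python
import numpy as np

def kernel_naive(Ab, Bb, Cb, L):
    out, v = [], Bb
    for _ in range(L):
        out.append((Cb @ v).item())
        v = Ab @ v
    return np.array(out)

def kernel_via_generating_function(Ab, Bb, Cb, L):
    """Evaluate sum_{j<L} K_j z^j at the L-th roots of unity, then invert with an FFT.

    For z**L == 1 the truncated geometric series gives
        sum_{j<L} C Ab^j B z^j = C (I - Ab^L) (I - Ab z)^{-1} B,
    so the L matrix powers become one linear solve per node z.
    """
    I = np.eye(Ab.shape[0])
    Ct = Cb @ (I - np.linalg.matrix_power(Ab, L))         # "C-tilde"
    z = np.exp(-2j * np.pi * np.arange(L) / L)             # matches np.fft's sign convention
    Khat = np.array([(Ct @ np.linalg.solve(I - Ab * zk, Bb)).item() for zk in z])
    return np.fft.ifft(Khat)                               # Khat equals np.fft.fft(K)

rng = np.random.default_rng(0)
N, L = 4, 16
M = rng.standard_normal((N, N))
Ab = M / (1.1 * np.max(np.abs(np.linalg.eigvals(M))))     # hypothetical stable A-bar
Bb, Cb = rng.standard_normal((N, 1)), rng.standard_normal((1, N))

K_ref = kernel_naive(Ab, Bb, Cb, L)
K_gf = kernel_via_generating_function(Ab, Bb, Cb, L)
```

Evaluating at \(z_k = e^{-2\pi i k / L}\) makes the vector of evaluations equal to `np.fft.fft` of \(\bar{K}\), so `np.fft.ifft` recovers the kernel up to rounding error.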

Theorem 1: All HiPPO matrices have a normal plus low-rank (NPLR) representation

\[ \begin{equation} A = V \Lambda V^* - PQ^T = V (\Lambda - (V^*P)(V^*Q)^*)V^* \end{equation} \]

for unitary \(V \in \mathbb{C}^{N \times N}\), diagonal \(\Lambda\), and low-rank factorization \(P, Q \in \mathbb{R}^{N \times r}\). For the HiPPO matrices considered, \(r = 1\) or \(r = 2\); the matrix in (2) has \(r = 1\).
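For the matrix in (2), the rank-1 case can be written down explicitly. Adding \(PP^T\) with \(P_n = (n + 1/2)^{1/2}\) turns every entry below the diagonal into \(-\frac{1}{2}(2n+1)^{1/2}(2k+1)^{1/2}\), every entry above it into \(+\frac{1}{2}(2n+1)^{1/2}(2k+1)^{1/2}\), and every diagonal entry into \(-\frac{1}{2}\). The result is \(-\frac{1}{2} I\) plus a skew-symmetric matrix \(S\), which is normal, so (6) holds with \(Q = P\) and \(r = 1\). The sketch below (our own helper names) checks this numerically; it diagonalizes \(S\) through the Hermitian matrix \(iS\), for which `np.linalg.eigh` returns a unitary \(V\).

```python
import numpy as np

def hippo_legs(N):
    # HiPPO matrix from (2)
    root = np.sqrt(2 * np.arange(N) + 1)
    return -(np.tril(np.outer(root, root), -1) + np.diag(np.arange(1, N + 1)))

def nplr_legs(N):
    """Return (Lam, V, P) with A = V diag(Lam) V^* - P P^T for the HiPPO matrix (2)."""
    A = hippo_legs(N)
    P = np.sqrt(np.arange(N) + 0.5)[:, None]     # rank r = 1, with Q = P
    S = A + P @ P.T + 0.5 * np.eye(N)            # skew-symmetric part of the normal term
    # i*S is Hermitian, so eigh gives real eigenvalues w and a unitary V
    w, V = np.linalg.eigh(1j * S)
    Lam = -0.5 - 1j * w                          # eigenvalues of the normal term -I/2 + S
    return Lam, V, P

N = 8
Lam, V, P = nplr_legs(N)
A_rebuilt = V @ np.diag(Lam) @ V.conj().T - P @ P.T
```

Since \(iS = V\,\mathrm{diag}(w)\,V^*\), the normal term is \(-\frac{1}{2}I + S = V\,\mathrm{diag}(-\frac{1}{2} - iw)\,V^*\), which is exactly the \(V\Lambda V^*\) in (6).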

S4 Algorithms

[Figure: S4 algorithm]

By equation (6), NPLR matrices can be conjugated into diagonal plus low-rank (DPLR) form. With this structure, computing the SSM convolution kernel \(\bar{K}\) requires only \(\tilde{O}(N + L)\) operations and \(O(N + L)\) memory.
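Putting the pieces together, the following is a minimal NumPy sketch of the kernel computation for a DPLR state matrix \(A = \Lambda - PQ^*\), under simplifying assumptions stated in the docstring (rank 1, odd \(L\), and a dense computation of \(\tilde{C} = \bar{C}(I - \bar{A}^L)\)). The derivation behind it: with the bilinear transform, \((I - \bar{A}z)^{-1}\bar{B} = \frac{2}{1+z}\,(g(z) I - A)^{-1} B\) where \(g(z) = \frac{2}{\Delta}\frac{1-z}{1+z}\); the Woodbury identity reduces \((g(z) I - \Lambda + PQ^*)^{-1}\) to four Cauchy sums; and an inverse FFT over the roots of unity recovers \(\bar{K}\). This is our reconstruction of the technique, not the paper's reference implementation, and the Cauchy sums are computed densely here in \(O(NL)\) rather than with a fast Cauchy algorithm.

```python
import numpy as np

def s4_kernel_dplr(Lam, P, Q, B, C, dt, L):
    """Sketch: SSM kernel for A = diag(Lam) - P Q^* via Cauchy sums + Woodbury + inverse FFT.

    Simplifying assumptions (ours): rank r = 1; L odd, so z = -1 is never a node;
    C-tilde = C (I - Ab^L) is formed densely, which a real implementation would avoid.
    """
    N = Lam.shape[0]
    I = np.eye(N)
    A = np.diag(Lam) - np.outer(P, Q.conj())
    Ab = np.linalg.solve(I - dt / 2 * A, I + dt / 2 * A)     # bilinear A-bar, eq. (3)
    Ct = C @ (I - np.linalg.matrix_power(Ab, L))             # C-tilde

    z = np.exp(-2j * np.pi * np.arange(L) / L)                # L-th roots of unity
    g = (2 / dt) * (1 - z) / (1 + z)                          # bilinear transform of each node
    cauchy = 1 / (g[:, None] - Lam[None, :])                  # (L, N) Cauchy kernel 1 / (g_k - lam_n)

    # Each k_ab(z) = sum_n a_n b_n / (g(z) - lam_n) is a Cauchy matrix-vector product.
    k00 = cauchy @ (Ct * B)
    k01 = cauchy @ (Ct * P)
    k10 = cauchy @ (Q.conj() * B)
    k11 = cauchy @ (Q.conj() * P)

    # Woodbury: C~ (D + P Q^*)^{-1} B = k00 - k01 k10 / (1 + k11), with D = g I - diag(Lam);
    # the factor 2 / (1 + z) comes from the bilinear discretization.
    Khat = 2 / (1 + z) * (k00 - k01 * k10 / (1 + k11))
    return np.fft.ifft(Khat)

# Check against the naive kernel C Ab^i Bb on a random DPLR system.
rng = np.random.default_rng(0)
N, L, dt = 6, 33, 0.1
Lam = -rng.uniform(0.5, 2.0, N) + 1j * rng.uniform(-5.0, 5.0, N)
P, Q, B, C = (0.2 * (rng.standard_normal(N) + 1j * rng.standard_normal(N)) for _ in range(4))

K_fast = s4_kernel_dplr(Lam, P, Q, B, C, dt, L)

I = np.eye(N)
A = np.diag(Lam) - np.outer(P, Q.conj())
Ab = np.linalg.solve(I - dt / 2 * A, I + dt / 2 * A)
Bb = np.linalg.solve(I - dt / 2 * A, dt * B)
K_naive = np.array([C @ np.linalg.matrix_power(Ab, i) @ Bb for i in range(L)])
```

The naive kernel at the end serves only as a correctness check; the fast path never forms powers of \(\bar{A}\) beyond the one-time \(\tilde{C}\).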

Architecture

An S4 layer is parameterized by \((\Lambda, P, Q, B, C)\). The SSM is first initialized with \(A\) set to the HiPPO matrix (2). By Theorem 1, this SSM is unitarily equivalent to \((\Lambda - PQ^*, B, C)\) for some diagonal \(\Lambda\) and vectors \(P, Q, B, C \in \mathbb{C}^{N \times 1}\), which make up S4's \(5N\) trainable parameters.

The overall deep architecture of S4 is similar to prior work. A single S4 SSM defines a map from \(\mathbb{R}^{L}\) to \(\mathbb{R}^{L}\), i.e., a 1-D sequence map, whereas deep networks typically operate on feature maps of size \(H\). S4 handles this by using \(H\) independent copies of the SSM, one per feature, and mixing their outputs with a position-wise linear layer, for a total of \(O(H^2) + O(HN)\) parameters per layer.
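The per-feature structure can be sketched as follows. Everything here is an illustrative assumption: `s4_block` and `s4_param_count` are hypothetical helpers, random arrays stand in for the per-feature S4 kernels, and the parameter count treats each of \(\Lambda, P, Q, B, C\) as \(N\) parameters while ignoring details such as the step size \(\Delta\).

```python
import numpy as np

def s4_block(u, kernels, W, b):
    """H independent 1-D SSM convolutions followed by a position-wise linear mixing layer.

    u, kernels: (H, L) arrays, one input channel and one SSM kernel per feature.
    W: (H, H), b: (H,) are the mixing parameters (hypothetical names).
    """
    H, L = u.shape
    n = 2 * L                                                    # zero-padding -> causal, not circular
    y = np.fft.irfft(np.fft.rfft(kernels, n) * np.fft.rfft(u, n), n)[:, :L]
    return W @ y + b[:, None]                                    # same W applied at every position

def s4_param_count(H, N):
    """Rough count: 5N SSM parameters per feature plus an H x H linear layer with bias."""
    return H * 5 * N + H * H + H

rng = np.random.default_rng(0)
H, L = 4, 32
u = rng.standard_normal((H, L))
kernels = rng.standard_normal((H, L))      # stand-ins for per-feature S4 kernels
W, b = rng.standard_normal((H, H)), np.zeros(H)
out = s4_block(u, kernels, W, b)
```

Each feature is filtered by its own kernel, so the convolutions cost \(O(HL \log L)\) in total; features interact only through the \(H \times H\) linear layer.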

Experiments & Results

Efficiency: S4 is ~30 times faster than LSSL and uses ~390 times less memory for \(d=512\). S4 is also as efficient as Performer and Linear Transformer.

LRD: S4 outperforms prior models on all 6 tasks of the LRA benchmark, with an average score of 80.48% compared to the previous SoTA of 60%.

Ablation Studies

These studies test whether S4 performs well because of the HiPPO initialization or because of the NPLR structure.

Unconstrained SSMs: SSMs with a randomly initialized (Gaussian) \(A\) are compared with HiPPO-initialized ones. Although both reach perfect training accuracy, HiPPO generalizes much better.

NPLR SSMs: Even with the NPLR structure, a randomly initialized \(A\) still generalizes poorly. This shows that the HiPPO initialization is critical for good performance.

Limitations and Future Directions

  • S4 is not a drop-in replacement for Transformers in NLP tasks. There is still a gap between S4 and Transformers on language modeling tasks.
  • Exploring combinations of S4 with other sequence models to complement each other's strengths.
  • Applying S4 to other domains like vision, reinforcement learning, etc.