Skip to content

dhruvmsheth/cuRNN

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

3 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Were RNNs All We Needed?

Instructions for Running the CPU Demo

  1. Dependencies:

    • PyTorch
    • einops
    • pytest
  2. Setup:

    pip install torch torchvision torchaudio einops
    pip install pytest # For running tests
    pip install -U "huggingface_hub[cli]"

    Authenticate with Hugging Face:

    hf auth login

    Go to repository root and run the following:

    $ ./download.sh

    If the clone fails, here is the explicit link to access the dataset.

  3. Running the Demo Script: To train the model:

    $ cd cpu/
    $ python train.py

    Modify model_use_lstm: bool = False to either train LSTM or GRU

    To run the inference demo:

    $ python inference.py

    Change model = 'GRU' # other option: 'LSTM' to either 'GRU' or 'LSTM' to test the respective model.

  4. Running Tests:

    $ pytest ./

Instructions for Running the GPU Demo

Reproduction Instructions

Instructions for reproducing the benchmark results:

Note: In the zip, the build directory is already present.

# source build
$ mkdir -p gpu/build && cd gpu/build
$ cmake ..
$ make -j$(nproc)

# token-by-token processing and generation (sequential - main.cu)
$ ./mini_rnn_infer                      # GRU by default
$ ./mini_rnn_infer --lstm               # LSTM variant

# parallel-scan timing demo
$ ./parallel_forward_test 4096          # any T up to GPU RAM
$ ./parallel_forward_test --lstm 4096   # the LSTM scan path

For CPU sequential mode vs GPU parallel scan mode, run the following:

# CPU-vs-GPU comparison
$ cd cpu
$ python tests/compare.py

For NSight Profiler, run the following:

$ cd gpu/build
$ ncu --set full --export parallel_4k --target-processes all ./parallel_forward_test 128

Note: This documentation includes both CPU and GPU implementations. The CPU implementation is in the cpu directory, and the GPU implementation is in the gpu directory. GPU explanation starts at section 8.

1. Introduction

This README contains the CPU implementation and outlines a GPU parallelization strategy for Gated Recurrent Units (GRUs) and Long Short-Term Memory (LSTM) networks based on the paper "Were RNNs All We Needed?" by Feng et al. The primary motivation of this work is to explore the potential of achieving computational speedups in RNNs by simplifying their gating mechanisms and using the parallel scan algorithm. Traditional RNNs, while powerful for sequence modeling, suffer from inherent sequential dependencies, limiting their training and inference speed, especially on parallel hardware like GPUs. This project aims to implement "minimal" versions, minGRU and minLSTM, which are designed to be fully parallelizable during training, and to lay the groundwork for an efficient CUDA-based GPU implementation.

The core contributions of the original paper, and thus the focus of this project, involve:

  • Simplifying Gates: Modifying the standard LSTM and GRU gate computations to remove dependencies on the previous hidden state ($h_{t-1}$).
  • Parallel Training via Parallel Scan: Utilizing the parallel scan (prefix sum) algorithm to compute the recurrent states across an entire sequence in parallel, once the gate simplifications are in place.

This README will cover:

  • The original formulations of LSTM and GRU gates.
  • The specific simplifications made to create minGRU and minLSTM.
  • The architecture and pseudo-code for the CPU implementation of minGRU and minLSTM.
  • An explanation of how the parallel scan algorithm enables parallelization.
  • A plan for the GPU implementation, including kernel descriptions and optimization considerations.
  • Instructions for running the CPU demonstration and its test cases.

2. Original RNN Architectures

Before detailing the simplifications, I'll quickly go over the standard formulations of LSTM and GRU, as these are the starting points for the minGRU and minLSTM architectures.

2.1. Long Short-Term Memory (LSTM)

LSTMs were introduced to address the vanishing gradient problem in traditional RNNs, enabling them to capture long-range dependencies. An LSTM cell maintains a cell state ($c_t$) and a hidden state ($h_t$), controlled by three main gates: the forget gate, the input gate, and the output gate.

The standard PyTorch LSTM equations are as follows:

  • Input Gate ($i_t$): Controls which new information is stored in the cell state. $$i_t = \sigma(W_{ii}x_t + b_{ii} + W_{hi}h_{t-1} + b_{hi})$$
  • Forget Gate ($f_t$): Determines what information is discarded from the cell state. $$f_t = \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1} + b_{hf})$$
  • Cell Gate ($g_t$ or Candidate Cell State $\tilde{c}t$): Computes new candidate values to be added to the cell state. $$g_t = \tanh(W{ig}x_t + b_{ig} + W_{hg}h_{t-1} + b_{hg})$$
  • Output Gate ($o_t$): Decides what information from the cell state is output. $$o_t = \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1} + b_{ho})$$
  • Cell State Update ($c_t$): $$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$
  • Hidden State Update ($h_t$): $$h_t = o_t \odot \tanh(c_t)$$

Where:

  • $x_t$: Input vector at time step $t$.
  • $h_{t-1}$: Hidden state from the previous time step.
  • $c_{t-1}$: Cell state from the previous time step.
  • $W, b$: Weight matrices and bias vectors.
  • $\sigma$: Sigmoid activation function.
  • $\tanh$: Hyperbolic tangent activation function.
  • $\odot$: Element-wise multiplication.

2.2. Gated Recurrent Unit (GRU)

GRUs, simplify the LSTM architecture by combining the forget and input gates into a single "update gate" and merging the cell state and hidden state. This results in fewer parameters and often comparable performance.

The standard PyTorch GRU equations are as follows:

  • Reset Gate ($r_t$): Determines how much of the previous hidden state to forget. $$r_t = \sigma(W_{ir}x_t + b_{ir} + W_{hr}h_{t-1} + b_{hr})$$
  • Update Gate ($z_t$): Controls how much of the previous hidden state is kept versus how much the new candidate state is used. $$z_t = \sigma(W_{iz}x_t + b_{iz} + W_{hz}h_{t-1} + b_{hz})$$
  • New Gate ($n_t$ or Candidate Hidden State $\tilde{h}t$): Computes the candidate hidden state. $$n_t = \tanh(W{in}x_t + b_{in} + r_t \odot (W_{hn}h_{t-1} + b_{hn}))$$
  • Hidden State Update ($h_t$): $$h_t = (1 - z_t) \odot n_t + z_t \odot h_{t-1}$$

Note: The formulation for $n_t$ can vary slightly across different deep learning frameworks. The PyTorch version applies the reset gate $r_t$ to the $W_{hn}h_{t-1}$ term, whereas other common formulations apply it directly to $h_{t-1}$ before the matrix multiplication.

3. Simplifications for Parallelism: minLSTM and minGRU

The main innovation in "Were RNNs All We Needed?" is the simplification of LSTM and GRU gates to remove their direct dependency on the previous hidden state $h_{t-1}$. This modification allows the recurrence relation to be expressed in a form of the parallel scan algorithm, $h_t = a_t \odot h_{t-1} + b_t$, where $a_t$ and $b_t$ only depend on the current input $x_t$.

3.1. minGRU Simplifications

The transformation from GRU to minGRU involves two main steps:

  • Drop Previous State Dependencies from Gates:

    • The update gate $z_t$ and candidate hidden state $\tilde{h}_t$ (originally $n_t$) are redefined to depend only on the current input $x_t$:
      • Original $z_t = \sigma(\text{Linear}{z_i}(x_t, h{t-1}))$ $\rightarrow$ Simplified $z_t = \sigma(\text{Linear}_z(x_t))$
      • Original $\tilde{h}t = \tanh(\text{Linear}{h_i}(x_t, r_t \odot h_{t-1}))$ $\rightarrow$ Simplified $\tilde{h}_t = \tanh(\text{Linear}_h(x_t))$
    • The reset gate ($r_t$) is removed entirely, as its primary role was to modulate the influence of $h_{t-1}$ on $\tilde{h}_t$, which is no longer a direct dependency.
  • Drop Range Restriction of Candidate States:

    • The hyperbolic tangent ($\tanh$) activation function is removed from the computation of the candidate hidden state $\tilde{h}_t$. This simplifies the computation to a linear transformation:
      • $\tilde{h}_t = \text{Linear}_h(x_t)$

The resulting recurrence for minGRU is: $$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$

This fits the parallel scan form $h_t = a_t \odot h_{t-1} + b_t$ with: $a_t = (1 - z_t)$ $b_t = z_t \odot \tilde{h}_t$

Since $z_t$ and $\tilde{h}_t$ (and thus $a_t$ and $b_t$) depend only on $x_t$, they can be computed in parallel for all time steps.

3.2. minLSTM Simplifications

The transformation from LSTM to minLSTM involves three steps:

  • Drop Previous State Dependencies from Gates:

    • The forget gate $f_t$, input gate $i_t$, and candidate cell state $\tilde{c}_t$ are redefined to depend only on $x_t$:
      • Simplified $f_t = \sigma(\text{Linear}_f(x_t))$
      • Simplified $i_t = \sigma(\text{Linear}_i(x_t))$
      • Simplified $\tilde{c}_t = \tanh(\text{Linear}_c(x_t))$ (initially, then tanh is removed in step 2)
  • Drop Range Restriction of Candidate States and Hidden State Activation:

    • The tanh activation is removed from $\tilde{c}_t$, making it $\tilde{c}_t = \text{Linear}_c(x_t)$.
    • The tanh activation on the cell state $c_t$ in the hidden state computation ($h_t = o_t \odot \tanh(c_t)$) is also removed, so $h_t = o_t \odot c_t$.
  • Simplifying Scaling of Output (Merging Cell and Hidden State):

    • The output gate $o_t$ is removed.
    • The cell state $c_t$ is effectively merged with the hidden state $h_t$. The recurrence is now directly on $h_t$. The candidate state is denoted $\tilde{h}_t = \text{Linear}_h(x_t)$.
    • The recurrence becomes: $h_t = f_t \odot h_{t-1} + i_t \odot \tilde{h}_t$.
  • Normalization for Time-Independent Outputs (Length Independence Scaling):

    • To prevent the magnitude of $h_t$ from growing with sequence length, the forget and input gates are normalized: $$f'_t = \frac{f_t}{f_t + i_t + \epsilon}$$ $$i'_t = \frac{i_t}{f_t + i_t + \epsilon}$$ (where $\epsilon$ is a small constant for numerical stability, $10^{-8}$).

The final minLSTM recurrence is: $h_t = f't \odot h{t-1} + i'_t \odot \tilde{h}_t$.

This fits the parallel scan form $h_t = a_t \odot h_{t-1} + b_t$ with: $a_t = f'_t$ $b_t = i'_t \odot \tilde{h}_t$

Again, $a_t$ and $b_t$ depend only on $x_t$.

3.3. Assumptions

These simplifications assume that the reduced architectural complexity and modified gate interactions are still sufficient to capture the necessary temporal dependencies for the tasks at hand. The significant reduction in parameters (minGRU uses $2d_h(d_x + 1)$ and minLSTM uses $3d_h(d_x + 1)$ parameters ) and the ability to parallelize training are major advantages. The empirical results in the paper suggest these simplified models perform surprisingly well, rivaling more complex architectures.

4. The minGRU Architecture

The CPU implementation of minGRU is based on the definitions from the "Were RNNs All We Needed?" paper's appendix.

The core equations are:

  • Update Gate ($z_t$): $z_t = \sigma(W_z x_t + b_z)$
  • Candidate Hidden State ($\tilde{h}_t$): $\tilde{h}_t = W_h x_t + b_h$ (Note: tanh is removed)
  • Hidden State Recurrence ($h_t$): $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$

Where $W_z, b_z, W_h, b_h$ are learnable parameters of linear layers.

Pseudo-code (Vanilla Sequential Mode - from Appendix B.2 of Feng et al.):

def minGRU_forward_sequential(x_t, h_prev, linear_z, linear_h):
    # x_t: current input (Batch, Dim_input)
    # h_prev: previous hidden state (Batch, Dim_hidden)
    # linear_z: torch.nn.Linear for update gate
    # linear_h: torch.nn.Linear for candidate hidden state

    z_t = torch.sigmoid(linear_z(x_t))
    h_tilde_t = linear_h(x_t)  # tanh removed as per final minGRU definition

    h_t = (1 - z_t) * h_prev + z_t * h_tilde_t
    return h_t

Pseudo-code (Vanilla Parallel Mode - from Appendix B.2 of Feng et al.): This mode computes all $a_t$ and $b_t$ terms first across the sequence and then applies the scan.

def minGRU_forward_parallel(x_seq, h_0, linear_z, linear_h):
    # x_seq: full input sequence (SeqLen, Batch, Dim_input)
    # h_0: initial hidden state (Batch, Dim_hidden)

    z_seq = torch.sigmoid(linear_z(x_seq)) # (SeqLen, Batch, Dim_hidden)
    h_tilde_seq = linear_h(x_seq) # (SeqLen, Batch, Dim_hidden)

    # Parameters for the parallel scan: h_t = a_t * h_{t-1} + b_t
    a_seq = (1 - z_seq) # (SeqLen, Batch, Dim_hidden)
    b_seq = z_seq * h_tilde_seq # (SeqLen, Batch, Dim_hidden)
    h_sequence = associative_parallel_scan(a_seq, b_seq, h_0) 
    return h_sequence

The paper uses a numerically stable log-space version for the parallel scan.

Table: minGRU Equations and Parallel Scan Components

Component Equation Notes
Update Gate ($z_t$) $z_t = \sigma(\text{Linear}_z(x_t))$ Depends only on $x_t$
Candidate State ($\tilde{h}_t$) $\tilde{h}_t = \text{Linear}_h(x_t)$ Depends only on $x_t$; tanh removed
Recurrence ($h_t$) $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$ Standard GRU-like update form
Scan Parameter ($a_t$) $a_t = (1 - z_t)$ Multiplicative factor for $h_{t-1}$
Scan Parameter ($b_t$) $b_t = z_t \odot \tilde{h}_t$ Additive factor

5. The minLSTM Architecture

The CPU implementation of minLSTM is derived from the simplifications detailed in the "Were RNNs All We Needed?" paper's appendix.

The core equations are:

  • Forget Gate ($f_t$): $f_t = \sigma(W_f x_t + b_f)$
  • Input Gate ($i_t$): $i_t = \sigma(W_i x_t + b_i)$
  • Candidate Hidden State ($\tilde{h}_t$): $\tilde{h}_t = W_h x_t + b_h$ (Note: tanh removed, $o_t$ and $c_t$ merged)
  • Normalized Forget Gate ($f'_t$): $f'_t = f_t / (f_t + i_t + \epsilon)$
  • Normalized Input Gate ($i'_t$): $i'_t = i_t / (f_t + i_t + \epsilon)$
  • Hidden State Recurrence ($h_t$): $h_t = f't \odot h{t-1} + i'_t \odot \tilde{h}_t$

Pseudo-code (Vanilla Sequential Mode - from Appendix B.2 of Feng et al.):

def minLSTM_forward_sequential(x_t, h_prev, linear_f, linear_i, linear_h, epsilon=1e-8):
    # x_t: current input (Batch, Dim_input)
    # h_prev: previous hidden state (Batch, Dim_hidden)
    # linear_f, linear_i, linear_h: torch.nn.Linear layers

    f_t = torch.sigmoid(linear_f(x_t))
    i_t = torch.sigmoid(linear_i(x_t))
    h_tilde_t = linear_h(x_t)  # tanh removed

    # Normalization for time-independent outputs
    f_prime_t = f_t / (f_t + i_t + epsilon)
    i_prime_t = i_t / (f_t + i_t + epsilon)

    h_t = f_prime_t * h_prev + i_prime_t * h_tilde_t
    return h_t

Pseudo-code (Vanilla Parallel Mode - from Appendix B.2 of Feng et al., conceptually):

def minLSTM_forward_parallel(x_seq, h_0, linear_f, linear_i, linear_h, epsilon=1e-8):
    # x_seq: full input sequence (SeqLen, Batch, Dim_input)
    # h_0: initial hidden state (Batch, Dim_hidden)

    f_seq = torch.sigmoid(linear_f(x_seq))
    i_seq = torch.sigmoid(linear_i(x_seq))
    h_tilde_seq = linear_h(x_seq)

    # Normalization
    f_prime_seq = f_seq / (f_seq + i_seq + epsilon)
    i_prime_seq = i_seq / (f_seq + i_seq + epsilon)

    # Parameters for the parallel scan: h_t = a_t * h_{t-1} + b_t
    a_seq = f_prime_seq
    b_seq = i_prime_seq * h_tilde_seq

    # Example from paper's pseudocode for vanilla parallel mode:
    # h = parallel_scan(f_prime, torch.cat([h_0.unsqueeze(0), i_prime[:-1]*h_tilde_seq[:-1]], dim=0))
    h_sequence = associative_parallel_scan(a_seq, b_seq, h_0) # Returns all h_t
    return h_sequence

Table: minLSTM Equations and Parallel Scan Components

Component Equation Notes
Forget Gate ($f_t$) $f_t = \sigma(\text{Linear}_f(x_t))$ Depends only on $x_t$
Input Gate ($i_t$) $i_t = \sigma(\text{Linear}_i(x_t))$ Depends only on $x_t$
Candidate State ($\tilde{h}_t$) $\tilde{h}_t = \text{Linear}_h(x_t)$ Depends only on $x_t$; tanh, $o_t$, $c_t$ removed/merged
Normalized Forget Gate ($f'_t$) $f'_t = f_t / (f_t + i_t + \epsilon)$ For time-independent scaling
Normalized Input Gate ($i'_t$) $i'_t = i_t / (f_t + i_t + \epsilon)$ For time-independent scaling
Recurrence ($h_t$) $h_t = f't \odot h{t-1} + i'_t \odot \tilde{h}_t$ Simplified LSTM-like update
Scan Parameter ($a_t$) $a_t = f'_t$ Multiplicative factor for $h_{t-1}$
Scan Parameter ($b_t$) $b_t = i'_t \odot \tilde{h}_t$ Additive factor

6. Parallelization via Parallel Scan Algorithm

Given an input sequence $x_0, x_1, \dots, x_{n-1}$ and an associative binary operator $\oplus$, an inclusive scan computes the sequence $y_0 = x_0, y_1 = x_0 \oplus x_1, \dots, y_{n-1} = x_0 \oplus \dots \oplus x_{n-1}$. An exclusive scan typically sets $y_0$ to an identity element and $y_k = x_0 \oplus \dots \oplus x_{k-1}$. The associativity of $\oplus$ is important as it allows the coputation to be reordered and parallelized into a treelike structure.

source for Blelloch's algorithm: https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda

Blelloch's algorithm is the parallel scan algorithm I'll be using, performing $O(N)$ work in $O(\log N)$ steps (depth) on $O(N / \log N)$ processors. It generally consists of two phases:

  1. Reduce Phase (Up-Sweep): Computes partial sums (or results of $\oplus$) in a tree-like manner from leaves to root.
  2. Down-Sweep Phase: Uses the intermediate values from the up-sweep phase to compute the final prefix sums for all elements, traversing from root to leaves.

For the minGRU/minLSTM recurrence $h_t = a_t \odot h_{t-1} + b_t$ (element-wise operations), the "operator" is more complex than simple addition. If we have $h_t = a_t h_{t-1} + b_t$ and $h_{t-1} = a_{t-1} h_{t-2} + b_{t-1}$, then by substitution: $h_t = a_t (a_{t-1} h_{t-2} + b_{t-1}) + b_t = (a_t a_{t-1}) h_{t-2} + (a_t b_{t-1} + b_t)$. This defines an associative operator $\oplus$ on pairs $(A, B)$, where a pair represents the transformation $h_k = A \cdot h_j + B$. The composition rule is: $$(A_2, B_2) \oplus (A_1, B_1) = (A_2 A_1, A_2 B_1 + B_2)$$

The input to the parallel scan algorithm is a sequence of such pairs $(a_t, b_t)$ for each time step $t$. The scan then computes the cumulative transformation to obtain each $h_t$ from an initial $h_0$.

For numerical stability, especially when $a_t$ terms (like $1 - z_t$ or $f't$) can be very small or involve exponentials, the paper "Were RNNs All We Needed" uses a log-space scan. This involves transforming the recurrence and performing the scan operations in the logarithmic domain to avoid underflow/overflow issues and maintain precision. The exact formulation of the associative operator in log-space is specific to this transformation. The paper describes a log-space formulation for $h_t = \alpha_t h{t-1} + \beta_t$ (where $\alpha_t = a_t, \beta_t = b_t$) as $A_t = \prod_{i=1}^t \alpha_i$ and $H_t = \sum_{i=1}^t (\prod_{j=i+1}^t \alpha_j) \beta_i + A_t h_0$. Both $A_t$ (log-sum of $\log \alpha_i$) and the sum component can be computed using parallel scans.

7. CPU Implementation and Demonstration

It is developed in Python using PyTorch using logic and pseudocode from the paper for minGRU and the derived logic for minLSTM.

7.1. Features Implemented

  • minGRUCell and minLSTMCell: Modules for single time-step computation, useful for sequential unrolling and understanding.
  • minGRU and minLSTM Layers: Full layers capable of processing entire sequences. These layers support:
    • Sequential Mode: Iterative computation, $h_t = \text{cell}(x_t, h_{t-1})$ as a comparison baseline.
    • Parallel Mode: Computation of all $a_t, b_t$ terms for the sequence at once using vectorized PyTorch operations, followed by a Python-based loop that implements the associative scan logic.
  • Inference: Forward pass execution for generating outputs.

7.2. Test Cases and Verification

  • Layer Tests for minGRU and minLSTM:
    • For a given input sequence and initial hidden state, the output from the sequential unrolling must be numerically very close to the output from the parallel scan-like execution. torch.allclose() with a suitable tolerance is used.
    • Batch Processing: Ensure the layers correctly handle batched inputs.
    • Output Shape Verification: Confirm that output tensors have the expected dimensions.
    • Pre-trained Model Inference: This loads a pre-trained model and checks for the output values and validates that the outputs are not NaN / Inf.

8. Latency Comparison: Classic GRU vs. Minimal GRU Scan

This section presents benchmarking results that demonstrate the computational advantages of the minimal GRU/LSTM architectures when implemented with parallel scan algorithms on GPU hardware.

Symbol Meaning Value in the demo
$D$ model (embedding) width 512
$H$ inner-state width 768
$T$ sequence / context length 256 - 65,536

8.1. Benchmark Implementation Paths

Our benchmarking compares three different implementation approaches to highlight the benefits of the minimal GRU parallelization strategy:

Path Code that runs Recurrence depth in $T$ Explanation
CPU-seq for t in range(T): cell(x_t, h_{t-1}) in PyTorch $O(T)$ Baseline: classic GRU/LSTM generation loop.
CPU-scan Full sequence through minLM; gates are one sequential call, scan is a Python loop $O(T)$ Shows what you get if you vectorise the gates but still resolve the recurrence on CPU.
GPU-scan (1) launch_gru_extract_scan_params (gate mat-vecs)
(2) launch_parallel_scan (work-efficient, depth $==\log_2T$)
(3) projection
$O(\log T)$ Implementation of the paper's claim: minimal GRU can be fully parallelised on GPU.

8.2. Theoretical Computational Analysis

We can do a theoretical complexity analysis to reveal why each implementation path scales differently:

Component FLOPs / memory CPU-seq CPU-scan GPU-scan
Gate mat-vecs $Wz·x_t$, $Wh·x_t$ $2THD$ serial single BLAS, parallel single CUDA GEMM, massively parallel
Recurrence $h_t = a_t h_{t-1}+b_t$ $2TH$ serial $O(T)$ serial $O(T)$ Blelloch scan $O(\log T)$

The main reason why we see speedup later is:

While CPU implementations must resolve the recurrence sequentially even with vectorized gate computations, the GPU implementation uses the associative parallel scan to reduce the recurrence depth from $O(T)$ to $O(\log T)$.

8.3. Empirical Benchmark Results

Testing was performed on Intel i9-12900K with 16 threads and RTX 4090 GPU (personal machine). Did not do for longer lengths in sequential mode because it was too slow.

GRU Results

$T$ CPU-seq CPU-scan GPU-scan CPU-GPU diff
256 634 ms 32.8 ms 25.8 ms 7.0 ms
512 1 247 ms 53.0 ms 46.0 ms 7.0 ms
1 024 2 395 ms 97.6 ms 92.1 ms 5.5 ms
2 048 2 919 ms 165 ms 193 ms -28 ms
4 096 5 493 ms 300 ms 340 ms -40 ms
8 192 661 ms 657 ms 4 ms
16 384 2 683 ms 1 333 ms 1 350 ms
32 768 5 524 ms 2 680 ms 2 844 ms
65 536 10 989 ms 5 330 ms 5 659 ms
131 072 22 712.9 ms 14 001.5 ms 8 711.4 ms
262 144 44 855.6 ms 27 395.3 ms 17 460.3 ms
524 288 90 704.6 ms 56 149.6 ms 34 555.0 ms

LSTM Results

$T$ CPU-seq CPU-scan GPU-scan CPU-GPU diff
256 701.0 ms 37.0 ms 22.0 ms 15.0 ms
512 1440.9 ms 58.9 ms 72.6 ms -13.7 ms
1 024 2966.7 ms 96.8 ms 107.8 ms -11.0 ms
2 048 5956.8 ms 170.9 ms 226.0 ms -55.1 ms
4 096 12035.1 ms 416.5 ms 431.5 ms -15.0 ms
8 192 664.3 ms 809.5 ms -145.2 ms
16 384 2993.6 ms 1693.9 ms 1299.7 ms
32 768 6329.9 ms 3457.4 ms 2872.5 ms
65 536 13005.4 ms 6709.8 ms 6295.6 ms
131 072 24744.7 ms 13590.1 ms 11154.6 ms
262 144 49180.6 ms 28217.9 ms 20962.7 ms
524 288 97331.7 ms 55078.1 ms 42253.6 ms

Sequential CPU vs GPU parallelized GRU Performance Comparison LSTM Performance Comparison

Key Observations:

  • For short contexts ($T < 1024$), GPU performance gains come primarily from higher FLOP/s throughput; the logarithmic scan depth advantage is negligible.
  • For long contexts, the $O(\log T)$ scan complexity mainly helps: at $T = 65536$, the GPU implementation is 2 $\times$ faster than the CPU scan and about 20 $\times$ faster than the token-sequential loop.
  • The CPU-scan vs CPU-seq comparison demonstrates that gate vectorization alone provides 10 $\times$ speedup, yet the implementation still scales linearly with sequence length.

8.4. GPU Implementation Architecture

The GPU implementation translates the parallelizable logic from the CPU implementation into optimized CUDA kernels.

8.4.1. Core GPU Kernels

  1. launch_gru_extract_scan_params

    • Computes gate parameters ($z_t$, $\tilde{h}_t$) and scan coefficients ($a_t$, $b_t$) for entire sequence
    • Fuses linear transformations with activation functions
  2. launch_parallel_scan

    • Implements work-efficient Blelloch scan algorithm
    • Uses log-space computation for numerical stability
    • Achieves $O(\log T)$ depth complexity with $O(T)$ work
  3. Projection layers

    • Final linear transformations for output generation
    • Optimized for coalesced memory access patterns

8.4.2. Kernel Mapping

CPU Operation (PyTorch) GPU Kernel GPU Task Description
linear_z(x), linear_h(x) launch_gru_extract_scan_params (matrix multiply + bias + sigmoid) Compute intermediate gate values and $a_t, b_t$ scan parameters from input $x_t$ and weights.
torch.sigmoid fused within parameter extraction kernel Apply sigmoid activation efficiently within gate computation.
associative_scan loop launch_parallel_scan (up-sweep + down-sweep phases) Perform numerically stable parallel scan using associative operator to compute all hidden states $h_t$.

8.5. Key Takeaways

  • Minimal GRU/LSTM removes $h_{t-1}$ dependencies from gate inputs, enabling the entire sequence to be evaluated through a single associative scan operation.
  • On GPU, the scan achieves $O(\log T)$ depth complexity; CPU implementations remain $O(T)$ unless parallel prefix algorithms are implemented.
  • While autoregressive generation must still proceed sequentially due to token feedback requirements, each step benefits from GPU acceleration. The scan accelerates batch inference and training backpropagation.
  • The results validate the theoretical predictions -- logarithmic scaling becomes dominant for long sequences, while short sequences benefit primarily from higher computational throughput.

9. Nsight Compute Analysis

To understand why the GPU implementation achieves the speedups shown in the benchmarks, I conducted detailed profiling using NVIDIA Nsight Compute. This analysis reveals the kernel-level performance characteristics and identifies remaining optimization opportunities.

9.1. Kernel at T = 4096

Rank Kernel (demangled) Time/launch Launches % wall-time Comment
1 min_gru_extract_scan_params_kernel
(fused, BLK_H = 16)
180 $\mu$ s 1 8 % gate mat-vec + sigmoid, now shared-mem tiled
2–9 compose_offset_kernel × 12 3 $\mu$ s 12 < 1 % up-sweep (Blelloch reduce)
10 – apply_scan_op_kernel × 4096 2 $\mu$ s 4096 10 % down-sweep; 4096 launches hurt latency
11– matvec_kernel × 4096 93 $\mu$ s 4096 72 % per-token projection (H $\rightarrow$ D) – new bottleneck

Total wall-clock: 270 ms (previously 340 ms). The original implementation launched separate kernels for each timestep, resulting in 4096 individual kernel launches for gate extraction. By fusing these into a single kernel with shared memory tiling, I reduced the gate extraction stage from 93 ms to just 8% of wall-time, achieving a 1.3× speed-up.

9.2. Kernel-by-Kernel Performance Analysis

Kernel Memory view Interpretation / next action
min_gru_extract_scan_params_kernel Mem Throughput = 1.9 TB/s, L1 hit = 66 %, L2 hit = 99 % 16 threads/row keep all weights in shared mem, loads x_t through texture path – we are bandwidth–limited now, not latency–blocked. Raising BLK_H above 16 would shave another 10% but exceeds 48 KiB default s-mem; you would need cudaFuncSetAttribute to opt-in to 96 KiB.
compose_offset_kernel 2–3 $\mu$ s each, "SM Busy < 30 %" Already negligible. Could fuse a couple of levels together, but won't make much difference.
apply_scan_op_kernel 96 global inst / launch, grid = (6, 1, 1) --> only 2 blocks Each launch moves only one time-step so the GPU is mostly idle. An idea is to run H warps $\times$ (T/32) blocks so a single launch walks the entire sequence or run a CTA per SM and loop inside. 2 $\mu$ s × 4096 is 8 ms of pure kernel-launch latency – cheap relative to the projection but "free" speed-up if you touch it.
matvec_kernel 23.9 GB/s, L1 hit 87 %, L2 hit 5 % (streaming) 4096 launches dominate runtime.

9.3. Performance Assessment Summary

The algorithmic speed-up has been achieved, reaching $O(\log T)$ depth with the GPU implementation providing approximately 20× speedup over CPU at 65k tokens. The gate extraction stage has been optimized through shared-memory tiling and constant-memory weights, moving it off the critical path to just 8% of wall-time. The scan kernels are functional but show observable per-step launch overhead that could be optimized if the matvec bottleneck is eliminated first. The projection stage using the matvec kernel is now the primary bottleneck, consuming 70% of wall-time. The overall peak utilization shows 38% SM Busy according to Nsight "Summary % Peak".

I optimized a computation step and hit peak memory speed at 1.9 TB/s. Now the main problem is processing each token one by one - this eats 72% of my runtime because I'm launching 4,096 tiny operations that waste memory bandwidth (only hitting 23 GB/s).

Profiling: After fusing the gate-extraction path into a single shared-memory kernel, this stage dropped from 93 ms to 8% of wall-time. Nsight Compute shows it saturating 1.9 TB/s of L2 bandwidth with 66% L1 hit-rate – meaning the kernel is memory-bandwidth-bound, not latency-limited. The remaining bottleneck is the per-token projection: 4096 launches of a matvec spend 72% of runtime doing 23 GB/s strided reads. A single cuBLAS GEMM ($C = A \cdot W^T$) would eliminate that launch overhead and is expected to bring total latency for $T = 4096$ from 270 ms down to 120 ms, pushing effective speed-up vs. CPU-scan above 20 $\times$ for long sequences.

Profile 1 Profile 2 Profile 3 Profile 4 Profile 5 Profile 6

About

CUDA implementation of "Were RNNs All We Needed?" - minLSTM and minGRU models with fewer parameters and full training parallelization

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors