Dependencies:
- PyTorch
- einops
- pytest

Setup:

```
$ pip install torch torchvision torchaudio einops
$ pip install pytest   # for running tests
$ pip install -U "huggingface_hub[cli]"
```

Authenticate with Hugging Face:

```
$ hf auth login
```

Go to the repository root and run the following:

```
$ ./download.sh
```

If the clone fails, here is the explicit link to access the dataset.
Running the Demo Script:

To train the model:

```
$ cd cpu/
$ python train.py
```

Modify `model_use_lstm: bool = False` to choose whether to train the LSTM or the GRU.

To run the inference demo:

```
$ python inference.py
```

Change `model = 'GRU'  # other option: 'LSTM'` to either 'GRU' or 'LSTM' to test the respective model.

Running Tests:

```
$ pytest ./
```
Instructions for reproducing the benchmark results:

Note: In the zip, the build directory is already present.

```
# source build
$ mkdir -p gpu/build && cd gpu/build
$ cmake ..
$ make -j$(nproc)

# token-by-token processing and generation (sequential - main.cu)
$ ./mini_rnn_infer          # GRU by default
$ ./mini_rnn_infer --lstm   # LSTM variant

# parallel-scan timing demo
$ ./parallel_forward_test 4096          # any T up to GPU RAM
$ ./parallel_forward_test --lstm 4096   # the LSTM scan path
```

For CPU sequential mode vs GPU parallel scan mode, run the following:

```
# CPU-vs-GPU comparison
$ cd cpu
$ python tests/compare.py
```

For the NSight Profiler, run the following:

```
$ cd gpu/build
$ ncu --set full --export parallel_4k --target-processes all ./parallel_forward_test 128
```

Note: This documentation includes both CPU and GPU implementations. The CPU implementation is in the cpu directory, and the GPU implementation is in the gpu directory. The GPU explanation starts at section 8.
This README contains the CPU implementation and outlines a GPU parallelization strategy for Gated Recurrent Units (GRUs) and Long Short-Term Memory (LSTM) networks based on the paper "Were RNNs All We Needed?" by Feng et al. The primary motivation of this work is to explore the potential of achieving computational speedups in RNNs by simplifying their gating mechanisms and using the parallel scan algorithm. Traditional RNNs, while powerful for sequence modeling, suffer from inherent sequential dependencies, limiting their training and inference speed, especially on parallel hardware like GPUs. This project aims to implement "minimal" versions, minGRU and minLSTM, which are designed to be fully parallelizable during training, and to lay the groundwork for an efficient CUDA-based GPU implementation.
The core contributions of the original paper, and thus the focus of this project, involve:
- Simplifying Gates: Modifying the standard LSTM and GRU gate computations to remove dependencies on the previous hidden state ($h_{t-1}$).
- Parallel Training via Parallel Scan: Utilizing the parallel scan (prefix sum) algorithm to compute the recurrent states across an entire sequence in parallel, once the gate simplifications are in place.
This README will cover:
- The original formulations of LSTM and GRU gates.
- The specific simplifications made to create minGRU and minLSTM.
- The architecture and pseudo-code for the CPU implementation of minGRU and minLSTM.
- An explanation of how the parallel scan algorithm enables parallelization.
- A plan for the GPU implementation, including kernel descriptions and optimization considerations.
- Instructions for running the CPU demonstration and its test cases.
Before detailing the simplifications, I'll quickly go over the standard formulations of LSTM and GRU, as these are the starting points for the minGRU and minLSTM architectures.
LSTMs were introduced to address the vanishing gradient problem in traditional RNNs, enabling them to capture long-range dependencies. An LSTM cell maintains a cell state ($c_t$) alongside the hidden state ($h_t$), regulated by input, forget, and output gates.
The standard PyTorch LSTM equations are as follows:
- Input Gate ($i_t$): Controls which new information is stored in the cell state. $$i_t = \sigma(W_{ii}x_t + b_{ii} + W_{hi}h_{t-1} + b_{hi})$$
- Forget Gate ($f_t$): Determines what information is discarded from the cell state. $$f_t = \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1} + b_{hf})$$
- Cell Gate ($g_t$, or candidate cell state $\tilde{c}_t$): Computes new candidate values to be added to the cell state. $$g_t = \tanh(W_{ig}x_t + b_{ig} + W_{hg}h_{t-1} + b_{hg})$$
- Output Gate ($o_t$): Decides what information from the cell state is output. $$o_t = \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1} + b_{ho})$$
- Cell State Update ($c_t$): $$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$
- Hidden State Update ($h_t$): $$h_t = o_t \odot \tanh(c_t)$$
Where:
- $x_t$: Input vector at time step $t$.
- $h_{t-1}$: Hidden state from the previous time step.
- $c_{t-1}$: Cell state from the previous time step.
- $W, b$: Weight matrices and bias vectors.
- $\sigma$: Sigmoid activation function.
- $\tanh$: Hyperbolic tangent activation function.
- $\odot$: Element-wise multiplication.
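As a concrete sanity check, the six LSTM equations above can be sketched for a single scalar unit in plain Python. The weights below are arbitrary illustrative values, not trained parameters, and the function names are my own:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W, b):
    # W and b hold one scalar weight/bias per gate: i, f, g, o.
    i_t = sigmoid(W["ii"] * x_t + W["hi"] * h_prev + b["i"])    # input gate
    f_t = sigmoid(W["if"] * x_t + W["hf"] * h_prev + b["f"])    # forget gate
    g_t = math.tanh(W["ig"] * x_t + W["hg"] * h_prev + b["g"])  # cell candidate
    o_t = sigmoid(W["io"] * x_t + W["ho"] * h_prev + b["o"])    # output gate
    c_t = f_t * c_prev + i_t * g_t   # cell state update
    h_t = o_t * math.tanh(c_t)       # hidden state update
    return h_t, c_t

# Arbitrary weights for illustration only.
W = {"ii": 0.5, "hi": 0.1, "if": 0.4, "hf": 0.2,
     "ig": 0.9, "hg": 0.3, "io": 0.6, "ho": 0.1}
b = {"i": 0.0, "f": 1.0, "g": 0.0, "o": 0.0}

h, c = 0.0, 0.0
for x in [1.0, -0.5, 0.25]:
    h, c = lstm_step(x, h, c, W, b)
```

Note that every gate reads $h_{t-1}$: this is exactly the dependency the minimal variants later remove.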
GRUs simplify the LSTM architecture by combining the forget and input gates into a single "update gate" and merging the cell state and hidden state. This results in fewer parameters and often comparable performance.
The standard PyTorch GRU equations are as follows:
- Reset Gate ($r_t$): Determines how much of the previous hidden state to forget. $$r_t = \sigma(W_{ir}x_t + b_{ir} + W_{hr}h_{t-1} + b_{hr})$$
- Update Gate ($z_t$): Controls how much of the previous hidden state is kept versus how much of the new candidate state is used. $$z_t = \sigma(W_{iz}x_t + b_{iz} + W_{hz}h_{t-1} + b_{hz})$$
- New Gate ($n_t$, or candidate hidden state $\tilde{h}_t$): Computes the candidate hidden state. $$n_t = \tanh(W_{in}x_t + b_{in} + r_t \odot (W_{hn}h_{t-1} + b_{hn}))$$
- Hidden State Update ($h_t$): $$h_t = (1 - z_t) \odot n_t + z_t \odot h_{t-1}$$
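A matching scalar sketch of one GRU step in plain Python (again with arbitrary illustrative weights; the parameter names mirror the equations above and are my own):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gru_step(x_t, h_prev, p):
    # p holds one scalar weight/bias per term; names mirror the equations.
    r_t = sigmoid(p["W_ir"] * x_t + p["b_ir"] + p["W_hr"] * h_prev + p["b_hr"])
    z_t = sigmoid(p["W_iz"] * x_t + p["b_iz"] + p["W_hz"] * h_prev + p["b_hz"])
    n_t = math.tanh(p["W_in"] * x_t + p["b_in"]
                    + r_t * (p["W_hn"] * h_prev + p["b_hn"]))
    # Convex combination of candidate and previous state.
    return (1.0 - z_t) * n_t + z_t * h_prev

# Arbitrary weights for illustration only.
p = {"W_ir": 0.5, "b_ir": 0.0, "W_hr": 0.3, "b_hr": 0.0,
     "W_iz": 0.4, "b_iz": 0.0, "W_hz": 0.2, "b_hz": 0.0,
     "W_in": 0.8, "b_in": 0.0, "W_hn": 0.5, "b_hn": 0.0}

h = 0.0
for x in [1.0, -0.5, 0.25]:
    h = gru_step(x, h, p)
```

Because $h_t$ is a convex combination of $n_t \in (-1, 1)$ and $h_{t-1}$, the state stays bounded in $(-1, 1)$ when initialized there.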
Note: The formulation for $h_t$ above follows the PyTorch convention, where $z_t$ weights the previous hidden state; the minGRU definition below uses the opposite convention, with $z_t$ weighting the candidate state.
The main innovation in "Were RNNs All We Needed?" is the simplification of LSTM and GRU gates to remove their direct dependency on the previous hidden state $h_{t-1}$.
The transformation from GRU to minGRU involves two main steps:
1. Drop Previous State Dependencies from Gates:
   - The update gate $z_t$ and candidate hidden state $\tilde{h}_t$ (originally $n_t$) are redefined to depend only on the current input $x_t$:
     - Original $z_t = \sigma(\text{Linear}_z(x_t, h_{t-1}))$ $\rightarrow$ Simplified $z_t = \sigma(\text{Linear}_z(x_t))$
     - Original $\tilde{h}_t = \tanh(\text{Linear}_h(x_t, r_t \odot h_{t-1}))$ $\rightarrow$ Simplified $\tilde{h}_t = \tanh(\text{Linear}_h(x_t))$
   - The reset gate ($r_t$) is removed entirely, as its primary role was to modulate the influence of $h_{t-1}$ on $\tilde{h}_t$, which is no longer a direct dependency.
2. Drop Range Restriction of Candidate States:
   - The hyperbolic tangent ($\tanh$) activation is removed from the computation of the candidate hidden state $\tilde{h}_t$, simplifying it to a linear transformation: $\tilde{h}_t = \text{Linear}_h(x_t)$.
The resulting recurrence for minGRU is: $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$.
This fits the parallel scan form $h_t = a_t \odot h_{t-1} + b_t$, with $a_t = 1 - z_t$ and $b_t = z_t \odot \tilde{h}_t$.
Since $a_t$ and $b_t$ depend only on $x_t$, all scan coefficients can be computed in parallel across the sequence before the scan resolves the recurrence.
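To see why this kind of recurrence admits a scan, note that two consecutive affine steps compose into another affine step: applying $(a_1, b_1)$ then $(a_2, b_2)$ to $h$ gives $a_2(a_1 h + b_1) + b_2 = (a_2 a_1)h + (a_2 b_1 + b_2)$, and this composition is associative. A minimal pure-Python check with illustrative scalar values:

```python
def compose(step1, step2):
    # (a1, b1) followed by (a2, b2) is again of the form h -> a*h + b.
    a1, b1 = step1
    a2, b2 = step2
    return (a2 * a1, a2 * b1 + b2)

def apply_step(step, h):
    a, b = step
    return a * h + b

# Three arbitrary affine steps standing in for (1 - z_t, z_t * h_tilde_t).
steps = [(0.7, 0.3), (0.5, 1.0), (0.9, -0.2)]
h0 = 2.0

# Sequential application...
h_seq = h0
for s in steps:
    h_seq = apply_step(s, h_seq)

# ...equals applying the single composed step to h0.
fused = compose(compose(steps[0], steps[1]), steps[2])
h_fused = apply_step(fused, h0)
```

Associativity is exactly the property the parallel scan exploits: partial compositions can be built in any grouping, in $O(\log T)$ depth.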
The transformation from LSTM to minLSTM involves the following steps:
1. Drop Previous State Dependencies from Gates:
   - The forget gate $f_t$, input gate $i_t$, and candidate cell state $\tilde{c}_t$ are redefined to depend only on $x_t$:
     - Simplified $f_t = \sigma(\text{Linear}_f(x_t))$
     - Simplified $i_t = \sigma(\text{Linear}_i(x_t))$
     - Simplified $\tilde{c}_t = \tanh(\text{Linear}_c(x_t))$ (initially; the tanh is removed in step 2)
2. Drop Range Restriction of Candidate States and Hidden State Activation:
   - The tanh activation is removed from $\tilde{c}_t$, making it $\tilde{c}_t = \text{Linear}_c(x_t)$.
   - The tanh activation on the cell state in the hidden state computation ($h_t = o_t \odot \tanh(c_t)$) is also removed, so $h_t = o_t \odot c_t$.
3. Simplify Scaling of the Output (Merging Cell and Hidden State):
   - The output gate $o_t$ is removed.
   - The cell state $c_t$ is effectively merged with the hidden state $h_t$; the recurrence is now directly on $h_t$, with candidate state $\tilde{h}_t = \text{Linear}_h(x_t)$.
   - The recurrence becomes $h_t = f_t \odot h_{t-1} + i_t \odot \tilde{h}_t$.
4. Normalization for Time-Independent Outputs (Length-Independence Scaling):
   - To prevent the magnitude of $h_t$ from growing with sequence length, the forget and input gates are normalized: $$f'_t = \frac{f_t}{f_t + i_t + \epsilon} \qquad i'_t = \frac{i_t}{f_t + i_t + \epsilon}$$ where $\epsilon$ is a small constant for numerical stability ($10^{-8}$).
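The normalization guarantees $f'_t + i'_t \approx 1$, so each step is (up to $\epsilon$) a convex combination of the previous state and the candidate. A quick pure-Python check with arbitrary pre-activation values, purely for illustration:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def normalized_gates(f_pre, i_pre, epsilon=1e-8):
    # f_pre, i_pre stand in for the pre-activations Linear_f(x_t), Linear_i(x_t).
    f_t = sigmoid(f_pre)
    i_t = sigmoid(i_pre)
    denom = f_t + i_t + epsilon
    return f_t / denom, i_t / denom

f_prime, i_prime = normalized_gates(0.3, -1.2)
```

Because the two normalized gates sum to one, the magnitude of $h_t$ cannot grow with $t$, which is the point of the length-independence scaling.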
The final minLSTM recurrence is: $h_t = f'_t \odot h_{t-1} + i'_t \odot \tilde{h}_t$.
This fits the parallel scan form $h_t = a_t \odot h_{t-1} + b_t$, with $a_t = f'_t$ and $b_t = i'_t \odot \tilde{h}_t$.
Again, $a_t$ and $b_t$ depend only on the current input $x_t$, so the full sequence of hidden states can be computed with a parallel scan.
These simplifications assume that the reduced architectural complexity and modified gate interactions are still sufficient to capture the necessary temporal dependencies for the tasks at hand. The significant reduction in parameters (minGRU uses $O(2\,d_h d_x)$ parameters and minLSTM $O(3\,d_h d_x)$, versus $O(3\,d_h(d_x + d_h))$ and $O(4\,d_h(d_x + d_h))$ for the standard GRU and LSTM, where $d_x$ is the input size and $d_h$ the hidden size) also reduces compute and memory.
The CPU implementation of minGRU is based on the definitions from the "Were RNNs All We Needed?" paper's appendix.
The core equations are:
- Update Gate ($z_t$): $z_t = \sigma(W_z x_t + b_z)$
- Candidate Hidden State ($\tilde{h}_t$): $\tilde{h}_t = W_h x_t + b_h$ (Note: tanh is removed)
- Hidden State Recurrence ($h_t$): $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$

Where $W_z, W_h$ and $b_z, b_h$ are the weights and biases of the two linear layers.
Pseudo-code (Vanilla Sequential Mode - from Appendix B.2 of Feng et al.):

```python
def minGRU_forward_sequential(x_t, h_prev, linear_z, linear_h):
    # x_t: current input (Batch, Dim_input)
    # h_prev: previous hidden state (Batch, Dim_hidden)
    # linear_z: torch.nn.Linear for update gate
    # linear_h: torch.nn.Linear for candidate hidden state
    z_t = torch.sigmoid(linear_z(x_t))
    h_tilde_t = linear_h(x_t)  # tanh removed as per final minGRU definition
    h_t = (1 - z_t) * h_prev + z_t * h_tilde_t
    return h_t
```

Pseudo-code (Vanilla Parallel Mode - from Appendix B.2 of Feng et al.):
This mode computes all gate values for the whole sequence at once, then resolves the recurrence with an associative scan.
```python
def minGRU_forward_parallel(x_seq, h_0, linear_z, linear_h):
    # x_seq: full input sequence (SeqLen, Batch, Dim_input)
    # h_0: initial hidden state (Batch, Dim_hidden)
    z_seq = torch.sigmoid(linear_z(x_seq))  # (SeqLen, Batch, Dim_hidden)
    h_tilde_seq = linear_h(x_seq)           # (SeqLen, Batch, Dim_hidden)
    # Parameters for the parallel scan: h_t = a_t * h_{t-1} + b_t
    a_seq = (1 - z_seq)              # (SeqLen, Batch, Dim_hidden)
    b_seq = z_seq * h_tilde_seq      # (SeqLen, Batch, Dim_hidden)
    h_sequence = associative_parallel_scan(a_seq, b_seq, h_0)
    return h_sequence
```

The paper uses a numerically stable log-space version for the parallel scan.
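For reference, the semantics assumed for `associative_parallel_scan` above can be pinned down with a sequential pure-Python model (plain floats stand in for tensors here; the real implementation vectorizes and parallelizes this):

```python
def associative_parallel_scan(a_seq, b_seq, h_0):
    """Reference semantics: h_t = a_t * h_{t-1} + b_t, returning [h_1, ..., h_T].

    a_seq, b_seq: per-step scan coefficients; h_0: initial state.
    A parallel implementation computes the same values using the associative
    operator (a1, b1) + (a2, b2) -> (a2*a1, a2*b1 + b2) in O(log T) depth.
    """
    h, out = h_0, []
    for a_t, b_t in zip(a_seq, b_seq):
        h = a_t * h + b_t
        out.append(h)
    return out

# Scalar example: a_t plays the role of (1 - z_t), b_t of z_t * h_tilde_t.
hs = associative_parallel_scan([0.5, 0.25], [1.0, 3.0], h_0=2.0)
# hs == [2.0, 3.5]: h_1 = 0.5*2 + 1, h_2 = 0.25*2.0 + 3
```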
Table: minGRU Equations and Parallel Scan Components

| Component | Equation | Notes |
|---|---|---|
| Update Gate ($z_t$) | $z_t = \sigma(W_z x_t + b_z)$ | Depends only on $x_t$ |
| Candidate State ($\tilde{h}_t$) | $\tilde{h}_t = W_h x_t + b_h$ | Depends only on $x_t$; no tanh |
| Recurrence ($h_t$) | $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$ | Standard GRU-like update form |
| Scan Parameter ($a_t$) | $a_t = 1 - z_t$ | Multiplicative factor for $h_{t-1}$ |
| Scan Parameter ($b_t$) | $b_t = z_t \odot \tilde{h}_t$ | Additive factor |
The CPU implementation of minLSTM is derived from the simplifications detailed in the "Were RNNs All We Needed?" paper's appendix.
The core equations are:
- Forget Gate ($f_t$): $f_t = \sigma(W_f x_t + b_f)$
- Input Gate ($i_t$): $i_t = \sigma(W_i x_t + b_i)$
- Candidate Hidden State ($\tilde{h}_t$): $\tilde{h}_t = W_h x_t + b_h$ (Note: tanh removed; $o_t$ and $c_t$ merged)
- Normalized Forget Gate ($f'_t$): $f'_t = f_t / (f_t + i_t + \epsilon)$
- Normalized Input Gate ($i'_t$): $i'_t = i_t / (f_t + i_t + \epsilon)$
- Hidden State Recurrence ($h_t$): $h_t = f'_t \odot h_{t-1} + i'_t \odot \tilde{h}_t$
Pseudo-code (Vanilla Sequential Mode - from Appendix B.2 of Feng et al.):

```python
def minLSTM_forward_sequential(x_t, h_prev, linear_f, linear_i, linear_h, epsilon=1e-8):
    # x_t: current input (Batch, Dim_input)
    # h_prev: previous hidden state (Batch, Dim_hidden)
    # linear_f, linear_i, linear_h: torch.nn.Linear layers
    f_t = torch.sigmoid(linear_f(x_t))
    i_t = torch.sigmoid(linear_i(x_t))
    h_tilde_t = linear_h(x_t)  # tanh removed
    # Normalization for time-independent outputs
    f_prime_t = f_t / (f_t + i_t + epsilon)
    i_prime_t = i_t / (f_t + i_t + epsilon)
    h_t = f_prime_t * h_prev + i_prime_t * h_tilde_t
    return h_t
```

Pseudo-code (Vanilla Parallel Mode - from Appendix B.2 of Feng et al., conceptually):
```python
def minLSTM_forward_parallel(x_seq, h_0, linear_f, linear_i, linear_h, epsilon=1e-8):
    # x_seq: full input sequence (SeqLen, Batch, Dim_input)
    # h_0: initial hidden state (Batch, Dim_hidden)
    f_seq = torch.sigmoid(linear_f(x_seq))
    i_seq = torch.sigmoid(linear_i(x_seq))
    h_tilde_seq = linear_h(x_seq)
    # Normalization
    f_prime_seq = f_seq / (f_seq + i_seq + epsilon)
    i_prime_seq = i_seq / (f_seq + i_seq + epsilon)
    # Parameters for the parallel scan: h_t = a_t * h_{t-1} + b_t
    a_seq = f_prime_seq
    b_seq = i_prime_seq * h_tilde_seq
    # Example from the paper's pseudocode for vanilla parallel mode:
    # h = parallel_scan(f_prime, torch.cat([h_0.unsqueeze(0), i_prime[:-1]*h_tilde_seq[:-1]], dim=0))
    h_sequence = associative_parallel_scan(a_seq, b_seq, h_0)  # returns all h_t
    return h_sequence
```

Table: minLSTM Equations and Parallel Scan Components
| Component | Equation | Notes |
|---|---|---|
| Forget Gate ($f_t$) | $f_t = \sigma(W_f x_t + b_f)$ | Depends only on $x_t$ |
| Input Gate ($i_t$) | $i_t = \sigma(W_i x_t + b_i)$ | Depends only on $x_t$ |
| Candidate State ($\tilde{h}_t$) | $\tilde{h}_t = W_h x_t + b_h$ | Depends only on $x_t$; no tanh |
| Normalized Forget Gate ($f'_t$) | $f'_t = f_t / (f_t + i_t + \epsilon)$ | For time-independent scaling |
| Normalized Input Gate ($i'_t$) | $i'_t = i_t / (f_t + i_t + \epsilon)$ | For time-independent scaling |
| Recurrence ($h_t$) | $h_t = f'_t \odot h_{t-1} + i'_t \odot \tilde{h}_t$ | Simplified LSTM-like update |
| Scan Parameter ($a_t$) | $a_t = f'_t$ | Multiplicative factor for $h_{t-1}$ |
| Scan Parameter ($b_t$) | $b_t = i'_t \odot \tilde{h}_t$ | Additive factor |
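The scan parameters in this table can be checked numerically with a scalar pure-Python sketch. The gate activations below are arbitrary numbers in $(0, 1)$ standing in for $\sigma(W x_t + b)$, chosen only for illustration:

```python
# Scalar check that the minLSTM recurrence equals the a_t/b_t scan form.
epsilon = 1e-8

# Arbitrary per-step gate activations and linear candidates (three time-steps).
f = [0.9, 0.2, 0.6]
i = [0.3, 0.7, 0.5]
h_tilde = [1.0, -2.0, 0.5]

h = 0.0  # h_0
for t in range(3):
    f_p = f[t] / (f[t] + i[t] + epsilon)   # normalized forget gate f'_t
    i_p = i[t] / (f[t] + i[t] + epsilon)   # normalized input gate  i'_t
    a_t, b_t = f_p, i_p * h_tilde[t]       # scan coefficients from the table
    h_direct = f_p * h + i_p * h_tilde[t]  # recurrence as written
    h_scan = a_t * h + b_t                 # scan form h_t = a_t*h_{t-1} + b_t
    assert abs(h_direct - h_scan) < 1e-12
    h = h_scan
```

Since $f'_t + i'_t \approx 1$, each $h_t$ stays bounded by the largest candidate magnitude regardless of sequence length.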
Given an input sequence of length $T$ and an associative binary operator $\oplus$, a scan computes all prefix results $y_t = x_1 \oplus x_2 \oplus \dots \oplus x_t$.
source for Blelloch's algorithm: https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
Blelloch's algorithm is the parallel scan algorithm I'll be using, performing $O(T)$ work in $O(\log T)$ parallel depth. It proceeds in two phases:
- Reduce Phase (Up-Sweep): Computes partial results of $\oplus$ in a tree-like manner from leaves to root.
- Down-Sweep Phase: Uses the intermediate values from the up-sweep phase to compute the final prefix sums for all elements, traversing from root to leaves.
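The two phases can be sketched in pure Python for an exclusive scan over addition (following the structure of the GPU Gems formulation linked above; on a GPU the inner loops over `i` run as parallel threads, while here they are sequential):

```python
def blelloch_exclusive_scan(values, op, identity):
    """Work-efficient exclusive scan; len(values) must be a power of two."""
    x = list(values)
    n = len(x)
    # Up-sweep (reduce): build partial results bottom-up over a binary tree.
    d = 1
    while d < n:
        for i in range(0, n, 2 * d):
            x[i + 2 * d - 1] = op(x[i + d - 1], x[i + 2 * d - 1])
        d *= 2
    # Down-sweep: clear the root, then push prefixes back down the tree.
    x[n - 1] = identity
    d = n // 2
    while d >= 1:
        for i in range(0, n, 2 * d):
            t = x[i + d - 1]
            x[i + d - 1] = x[i + 2 * d - 1]
            x[i + 2 * d - 1] = op(t, x[i + 2 * d - 1])
        d //= 2
    return x

prefix = blelloch_exclusive_scan([3, 1, 7, 0, 4, 1, 6, 3], lambda a, b: a + b, 0)
# prefix == [0, 3, 4, 11, 11, 15, 16, 22]
```

Swapping `op` and `identity` for the affine-step composition and the identity step $(1, 0)$ gives exactly the scan needed for the minGRU/minLSTM recurrence.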
For the minGRU/minLSTM recurrence $h_t = a_t \odot h_{t-1} + b_t$, the scan elements are pairs $(a_t, b_t)$ combined with the associative operator $(a_1, b_1) \oplus (a_2, b_2) = (a_2 a_1,\; a_2 b_1 + b_2)$.
The input to the parallel scan algorithm is a sequence of such pairs, $(a_1, b_1), \dots, (a_T, b_T)$, together with the initial state $h_0$; applying the $t$-th prefix composition to $h_0$ yields $h_t$.
For numerical stability, especially when many factors $a_t \in (0, 1)$ are multiplied over long sequences, the paper performs the scan in log-space, replacing products with sums of logarithms.
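A sketch of the log-space idea in pure Python, under the assumption that $a_t, b_t > 0$ (the helper names are my own): products of the multiplicative coefficients become sums of logs, and the additive part uses a log-sum-exp.

```python
import math

def logaddexp(x, y):
    # Numerically stable log(exp(x) + exp(y)).
    m = max(x, y)
    return m + math.log(math.exp(x - m) + math.exp(y - m))

def compose_log(step1, step2):
    """Compose two affine steps stored as (log a, log b), assuming a, b > 0.

    Direct space: (a1, b1) + (a2, b2) -> (a2*a1, a2*b1 + b2)
    Log space:    log(a2*a1)      = la2 + la1
                  log(a2*b1 + b2) = logaddexp(la2 + lb1, lb2)
    """
    la1, lb1 = step1
    la2, lb2 = step2
    return (la2 + la1, logaddexp(la2 + lb1, lb2))

s1 = (math.log(0.7), math.log(0.3))
s2 = (math.log(0.5), math.log(1.0))
la, lb = compose_log(s1, s2)
# exp(la) == 0.5*0.7 == 0.35, exp(lb) == 0.5*0.3 + 1.0 == 1.15
```

Handling mixed-sign candidates requires tracking signs separately (or using complex logarithms, as the paper's released code does); this sketch only shows the positive case.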
The CPU implementation is developed in Python with PyTorch, following the paper's logic and pseudocode for minGRU and the derived logic for minLSTM.
- `minGRUCell` and `minLSTMCell`: Modules for single time-step computation, useful for sequential unrolling and understanding.
- `minGRU` and `minLSTM` Layers: Full layers capable of processing entire sequences. These layers support:
  - Sequential Mode: Iterative computation, $h_t = \text{cell}(x_t, h_{t-1})$, as a comparison baseline.
  - Parallel Mode: Computation of all $a_t, b_t$ terms for the sequence at once using vectorized PyTorch operations, followed by a Python-based loop that implements the associative scan logic.
- Inference: Forward pass execution for generating outputs.
- Layer Tests for `minGRU` and `minLSTM`:
  - For a given input sequence and initial hidden state, the output from the sequential unrolling must be numerically very close to the output from the parallel scan-like execution; `torch.allclose()` with a suitable tolerance is used.
  - Batch Processing: Ensure the layers correctly handle batched inputs.
  - Output Shape Verification: Confirm that output tensors have the expected dimensions.
  - Pre-trained Model Inference: Loads a pre-trained model, checks the output values, and validates that the outputs are not NaN/Inf.
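A pure-Python sketch of the core equivalence test (the real suite uses torch tensors and `torch.allclose`; here scalars and `math.isclose` stand in, the weights are arbitrary, and the "parallel" path is modeled by composing the scan operator):

```python
import math

def min_gru_step(x_t, h_prev, w_z, b_z, w_h, b_h):
    z_t = 1.0 / (1.0 + math.exp(-(w_z * x_t + b_z)))  # sigmoid update gate
    h_tilde = w_h * x_t + b_h                          # linear candidate (no tanh)
    return (1.0 - z_t) * h_prev + z_t * h_tilde

def test_sequential_matches_scan():
    w_z, b_z, w_h, b_h = 0.8, 0.0, 1.2, -0.1   # arbitrary illustrative weights
    xs, h0 = [0.5, -1.0, 2.0, 0.0], 0.3

    # Path 1: sequential unrolling.
    h = h0
    for x in xs:
        h = min_gru_step(x, h, w_z, b_z, w_h, b_h)

    # Path 2: compose the (a_t, b_t) scan steps, then apply once to h0.
    A, B = 1.0, 0.0                             # identity affine step
    for x in xs:
        z = 1.0 / (1.0 + math.exp(-(w_z * x + b_z)))
        a_t, b_t = 1.0 - z, z * (w_h * x + b_h)
        A, B = a_t * A, a_t * B + b_t           # compose with the running step
    assert math.isclose(h, A * h0 + B, rel_tol=1e-9)
```

Run with `pytest` (or simply call `test_sequential_matches_scan()`); the two paths agree up to floating-point rounding.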
This section presents benchmarking results that demonstrate the computational advantages of the minimal GRU/LSTM architectures when implemented with parallel scan algorithms on GPU hardware.
| Symbol | Meaning | Value in the demo |
|---|---|---|
| $D$ | model (embedding) width | 512 |
| $H$ | inner-state width | 768 |
| $T$ | sequence / context length | 256 - 65,536 |
Our benchmarking compares three different implementation approaches to highlight the benefits of the minimal GRU parallelization strategy:
| Path | Code that runs | Recurrence depth in $T$ | Explanation |
|---|---|---|---|
| CPU-seq | `for t in range(T): cell(x_t, h_{t-1})` in PyTorch | $O(T)$ | Baseline: classic GRU/LSTM generation loop. |
| CPU-scan | Full sequence through minLM; gates computed in one vectorised call, scan resolved in a Python loop | $O(T)$ | Shows what you get if you vectorise the gates but still resolve the recurrence on CPU. |
| GPU-scan | (1) `launch_gru_extract_scan_params` (gate mat-vecs), (2) `launch_parallel_scan` (work-efficient, depth $O(\log T)$), (3) projection | $O(\log T)$ | Implementation of the paper's claim: the minimal GRU can be fully parallelised on GPU. |
We can do a theoretical complexity analysis to reveal why each implementation path scales differently:
| Component | FLOPs / memory | CPU-seq | CPU-scan | GPU-scan |
|---|---|---|---|---|
| Gate mat-vecs | $O(T \cdot D \cdot H)$ | serial | single BLAS call, parallel | single CUDA GEMM, massively parallel |
| Recurrence | $O(T \cdot H)$ | $O(T)$ serial | $O(T)$ serial | Blelloch scan, $O(\log T)$ depth |
While CPU implementations must resolve the recurrence sequentially even with vectorized gate computations, the GPU implementation uses the associative parallel scan to reduce the recurrence depth from $O(T)$ to $O(\log T)$.
Testing was performed on an Intel i9-12900K with 16 threads and an RTX 4090 GPU (personal machine). Sequential mode was not run at longer lengths because it was too slow.
| $T$ | CPU-seq | CPU-scan | GPU-scan | CPU-GPU diff |
|---|---|---|---|---|
| 256 | 634 ms | 32.8 ms | 25.8 ms | 7.0 ms |
| 512 | 1 247 ms | 53.0 ms | 46.0 ms | 7.0 ms |
| 1 024 | 2 395 ms | 97.6 ms | 92.1 ms | 5.5 ms |
| 2 048 | 2 919 ms | 165 ms | 193 ms | -28 ms |
| 4 096 | 5 493 ms | 300 ms | 340 ms | -40 ms |
| 8 192 | – | 661 ms | 657 ms | 4 ms |
| 16 384 | – | 2 683 ms | 1 333 ms | 1 350 ms |
| 32 768 | – | 5 524 ms | 2 680 ms | 2 844 ms |
| 65 536 | – | 10 989 ms | 5 330 ms | 5 659 ms |
| 131 072 | – | 22 712.9 ms | 14 001.5 ms | 8 711.4 ms |
| 262 144 | – | 44 855.6 ms | 27 395.3 ms | 17 460.3 ms |
| 524 288 | – | 90 704.6 ms | 56 149.6 ms | 34 555.0 ms |
| $T$ | CPU-seq | CPU-scan | GPU-scan | CPU-GPU diff |
|---|---|---|---|---|
| 256 | 701.0 ms | 37.0 ms | 22.0 ms | 15.0 ms |
| 512 | 1440.9 ms | 58.9 ms | 72.6 ms | -13.7 ms |
| 1 024 | 2966.7 ms | 96.8 ms | 107.8 ms | -11.0 ms |
| 2 048 | 5956.8 ms | 170.9 ms | 226.0 ms | -55.1 ms |
| 4 096 | 12035.1 ms | 416.5 ms | 431.5 ms | -15.0 ms |
| 8 192 | – | 664.3 ms | 809.5 ms | -145.2 ms |
| 16 384 | – | 2993.6 ms | 1693.9 ms | 1299.7 ms |
| 32 768 | – | 6329.9 ms | 3457.4 ms | 2872.5 ms |
| 65 536 | – | 13005.4 ms | 6709.8 ms | 6295.6 ms |
| 131 072 | – | 24744.7 ms | 13590.1 ms | 11154.6 ms |
| 262 144 | – | 49180.6 ms | 28217.9 ms | 20962.7 ms |
| 524 288 | – | 97331.7 ms | 55078.1 ms | 42253.6 ms |
Key Observations:
- For short contexts ($T < 1024$), GPU performance gains come primarily from higher FLOP/s throughput; the logarithmic scan depth advantage is negligible.
- For long contexts, the $O(\log T)$ scan complexity is the main win: at $T = 65{,}536$, the GPU implementation is roughly 2$\times$ faster than the CPU scan and about 20$\times$ faster than the token-sequential loop.
- The CPU-scan vs CPU-seq comparison demonstrates that gate vectorization alone provides roughly an order-of-magnitude (10-20$\times$) speedup, yet the implementation still scales linearly with sequence length.
The GPU implementation translates the parallelizable logic from the CPU implementation into optimized CUDA kernels.
- `launch_gru_extract_scan_params`
  - Computes gate parameters ($z_t$, $\tilde{h}_t$) and scan coefficients ($a_t$, $b_t$) for the entire sequence
  - Fuses linear transformations with activation functions
- `launch_parallel_scan`
  - Implements the work-efficient Blelloch scan algorithm
  - Uses log-space computation for numerical stability
  - Achieves $O(\log T)$ depth complexity with $O(T)$ work
- Projection layers
  - Final linear transformations for output generation
  - Optimized for coalesced memory access patterns
| CPU Operation (PyTorch) | GPU Kernel | GPU Task Description |
|---|---|---|
| `linear_z(x)`, `linear_h(x)` | `launch_gru_extract_scan_params` (matrix multiply + bias + sigmoid) | Compute intermediate gate values and scan coefficients for the whole sequence. |
| `torch.sigmoid` | fused within parameter extraction kernel | Apply sigmoid activation efficiently within gate computation. |
| `associative_scan` loop | `launch_parallel_scan` (up-sweep + down-sweep phases) | Perform the numerically stable parallel scan using the associative operator to compute all hidden states. |
- Minimal GRU/LSTM removes $h_{t-1}$ dependencies from gate inputs, enabling the entire sequence to be evaluated through a single associative scan operation.
- On GPU, the scan achieves $O(\log T)$ depth complexity; CPU implementations remain $O(T)$ unless parallel prefix algorithms are implemented.
- While autoregressive generation must still proceed sequentially due to token feedback requirements, each step benefits from GPU acceleration. The scan accelerates batch inference and training backpropagation.
- The results validate the theoretical predictions: logarithmic scaling becomes dominant for long sequences, while short sequences benefit primarily from higher computational throughput.
To understand why the GPU implementation achieves the speedups shown in the benchmarks, I conducted detailed profiling using NVIDIA Nsight Compute. This analysis reveals the kernel-level performance characteristics and identifies remaining optimization opportunities.
| Rank | Kernel (demangled) | Time/launch | Launches | % wall-time | Comment |
|---|---|---|---|---|---|
| 1 | `min_gru_extract_scan_params_kernel` (fused, BLK_H = 16) | 180 µs | 1 | 8 % | gate mat-vec + sigmoid, now shared-mem tiled |
| 2-9 | `compose_offset_kernel` × 12 | 3 µs | 12 | < 1 % | up-sweep (Blelloch reduce) |
| 10 | `apply_scan_op_kernel` × 4096 | 2 µs | 4096 | 10 % | down-sweep; 4096 launches hurt latency |
| 11 | `matvec_kernel` × 4096 | 93 µs | 4096 | 72 % | per-token projection (H × D mat-vec) |
Total wall-clock: 270 ms (previously 340 ms). The original implementation launched separate kernels for each timestep, resulting in 4096 individual kernel launches for gate extraction. By fusing these into a single kernel with shared memory tiling, I reduced the gate extraction stage from 93 ms to just 8% of wall-time, achieving a 1.3× speed-up.
| Kernel | Memory view | Interpretation / next action |
|---|---|---|
| `min_gru_extract_scan_params_kernel` | Mem Throughput = 1.9 TB/s, L1 hit = 66 %, L2 hit = 99 % | 16 threads/row keep all weights in shared mem and load x_t through the texture path; the kernel is now bandwidth-limited, not latency-blocked. Raising BLK_H above 16 would shave another 10 % but exceeds the 48 KiB default shared-memory budget; you would need cudaFuncSetAttribute to opt in to 96 KiB. |
| `compose_offset_kernel` | 2-3 µs per launch | Already negligible. Could fuse a couple of levels together, but it won't make much difference. |
| `apply_scan_op_kernel` | 96 global inst / launch, grid = (6, 1, 1) --> only 2 blocks | Each launch moves only one time-step, so the GPU is mostly idle. An idea is to run H warps per launch so one launch covers many time-steps. |
| `matvec_kernel` | 23.9 GB/s, L1 hit 87 %, L2 hit 5 % (streaming) | 4096 launches dominate runtime. |
The algorithmic speed-up has been achieved, reaching $O(\log T)$ recurrence depth on the GPU; the remaining wall-time is dominated by the per-token projection rather than the scan itself.
I optimized a computation step and hit peak memory speed at 1.9 TB/s. Now the main problem is processing each token one by one - this eats 72% of my runtime because I'm launching 4,096 tiny operations that waste memory bandwidth (only hitting 23 GB/s).
Profiling:
After fusing the gate-extraction path into a single shared-memory kernel, this stage dropped from 93 ms to 8% of wall-time. Nsight Compute shows it saturating 1.9 TB/s of L2 bandwidth with a 66% L1 hit-rate, meaning the kernel is memory-bandwidth-bound, not latency-limited. The remaining bottleneck is the per-token projection: 4096 launches of a matvec spend 72% of runtime doing 23 GB/s strided reads. A single cuBLAS GEMM over all $T$ tokens would replace those 4096 launches and is the clear next optimization.








