Getting Started & Reproducibility

This guide explains how to set up the environment, run the core experiments, and use the Variational Linear Attention (VLA) API in your own projects. All commands should be executed from the root of the repository.

Environment Setup

The repository relies strictly on float32 and float64 PyTorch operations for maximum numerical stability. We recommend a dedicated Python environment (e.g., conda or venv).

# Clone the repository
git clone https://github.com/deepbrain-labs/variational-linear-attention
cd variational-linear-attention

# Install dependencies
pip install -r requirements.txt

Key Dependencies

  • torch >= 2.0.0

  • numpy, scipy

  • pytest (for unit tests and strict verification)

  • wandb (for tracking experiments)


1. Verifying Core Primitives

Before running large-scale models, we highly recommend verifying that your local environment computes the mathematical primitives with the required precision.

Our math-primitive tests run in float64 with a tight tolerance of \(10^{-6}\) to verify that the Sherman-Morrison inversion behaves correctly.

# Run the test suite on CPU
pytest tests/

Note: PyTorch tests must be executed strictly on the CPU to avoid non-reproducible GPU floating-point non-determinism during the standard mathematical checks. The codebase itself fully supports .to(device).
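Independently of the test suite, the precision claim above can be spot-checked directly: the Sherman-Morrison identity should hold to roughly \(10^{-6}\) in float64. The sketch below uses a symmetric rank-1 update of a well-conditioned matrix; the dimension and tolerance are illustrative, not the values used by the actual tests.

```python
import torch

torch.manual_seed(0)
d = 8
# Well-conditioned SPD matrix A and a symmetric rank-1 update u u^T, in float64.
B = torch.randn(d, d, dtype=torch.float64)
A = B @ B.T + d * torch.eye(d, dtype=torch.float64)
u = torch.randn(d, 1, dtype=torch.float64)

A_inv = torch.linalg.inv(A)
# Sherman-Morrison: (A + u u^T)^{-1} = A^{-1} - (A^{-1} u u^T A^{-1}) / (1 + u^T A^{-1} u)
denom = 1.0 + (u.T @ A_inv @ u)
sm_inv = A_inv - (A_inv @ u @ u.T @ A_inv) / denom

# Compare against direct inversion of the updated matrix.
direct_inv = torch.linalg.inv(A + u @ u.T)
max_err = (sm_inv - direct_inv).abs().max().item()
assert max_err < 1e-6
```

If this assertion fails on your machine, the float64 primitive tests in tests/ will likely fail as well.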


2. Running Synthetic Memory Tasks

To verify the core hypothesis that VLA outperforms standard Linear Attention on long-context memorization tasks, you can run the synthetic verification scripts.

The Copy Task

This tests whether the model can read a sequence of tokens and reproduce it exactly, without loss of information.

python -m tests.verify_vla --task copy
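The exact data format is defined inside tests/verify_vla; the sketch below only illustrates the task structure, with hypothetical vocabulary size, delimiter token, and ignore-index conventions.

```python
import torch

torch.manual_seed(0)
vocab_size, seq_len = 16, 32   # hypothetical sizes for illustration
COPY_TOKEN = vocab_size        # hypothetical "now repeat" delimiter id

src = torch.randint(0, vocab_size, (seq_len,))
# Input: the sequence, a delimiter, then blank positions the model fills in.
inputs = torch.cat([src, torch.tensor([COPY_TOKEN]),
                    torch.zeros(seq_len, dtype=torch.long)])
# Target: ignored positions (-100, CrossEntropyLoss's default ignore_index),
# followed by the original sequence the model must reproduce.
targets = torch.cat([torch.full((seq_len + 1,), -100), src])

assert inputs.shape == targets.shape == (2 * seq_len + 1,)
```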

Delayed Recall

This tests whether the model can remember a key-value pair after observing a large number of noisy distractor tokens.

# 10k context delay
python -m tests.verify_vla --task delayed_recall --delay 10000
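As with the copy task, the real sample format lives in tests/verify_vla; this sketch only illustrates the shape of a delayed-recall instance, with hypothetical key/value token ids.

```python
import torch

torch.manual_seed(0)
vocab_size, delay = 16, 10_000   # delay matches the --delay 10000 setting
key, value = 3, 7                # hypothetical key/value token ids

distractors = torch.randint(0, vocab_size, (delay,))
# Input: the (key, value) pair, `delay` noisy tokens, then the key as a query;
# the model must emit `value` at the final position.
inputs = torch.cat([torch.tensor([key, value]), distractors,
                    torch.tensor([key])])
target = value

assert inputs.shape == (delay + 3,)
```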

Logs: Training logs and results for the synthetic tasks are saved automatically to results/synthetic_copy/ with timestamped filenames.


3. Running LRA Benchmarks

To reproduce our Long Range Arena (LRA) results, use the dedicated benchmarking scripts. You must first ensure the Hugging Face LRA datasets are downloaded.

# Setup LRA datasets
# This may take a while depending on your internet connection
python scripts/download_lra.py

# Run VLA on the Image task
python src/benchmarks/run_lra.py --model vla --task image

# Run VLA on the Path-X task (Extreme long context: 16k)
python src/benchmarks/run_lra.py --model vla --task pathx

4. Re-generating Paper Plots

The symbolic and diagnostic plots (e.g., eigenvalue stability, penalty-matrix heatmaps) shown in the Experiments section can be re-generated locally.

# Generates figures into results/symbolic_experiments/
python scripts/generate_plots.py --all

Utilizing VLA in Your Project

If you wish to drop VLA into an existing PyTorch codebase, simply import VLALayer and replace your standard attention blocks.

import torch
from src.models.attention.vla_layer import VLALayer

# Input dimensions: (Batch, Sequence_Length, d_model)
batch_size, seq_len, d_model = 32, 1024, 256
x = torch.randn(batch_size, seq_len, d_model).cuda()

# Initialize VLA. d_head must equal d_model.
vla = VLALayer(d_model=d_model, d_head=d_model, rank=1).cuda()

# Forward pass (stats contains diagnostic info)
output, stats = vla(x)

assert output.shape == (batch_size, seq_len, d_model)