Math Primitives

This document outlines the core mathematical primitives implemented in src/maths/primitives.py. These primitives are designed with numerical stability in mind and form the foundation of the Variational Linear Attention (VLA) system.

1. Inner-Product Score

The attention score \(s_t\) at timestep \(t\) is computed as the scaled dot product of a key vector \(k_t \in \mathbb{R}^d\) and a query vector \(q_t \in \mathbb{R}^d\):

\[ s_t = \frac{k_t^\top q_t}{\sqrt{d}} \]

Implementation Details:

  • Returns: A scalar value indicating query-key relevance.

  • Scaling: Includes the standard \(1/\sqrt{d}\) factor used in scaled dot-product attention. When keys and queries have independent, zero-mean, unit-variance entries, \(k_t^\top q_t\) has variance \(d\), so the factor keeps the score at unit scale regardless of dimension and helps prevent vanishing or exploding gradients.

  • Safety Measures: Checks that all inputs are finite (no NaN or Inf), because any non-finite value would propagate into the subsequent inverse updates.
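A minimal NumPy sketch of the score together with the finiteness check described above. It is an illustration, not the code in src/maths/primitives.py: the function name inner_product_score and the choice to raise ValueError on non-finite or mismatched inputs are assumptions.

```python
import math

import numpy as np


def inner_product_score(k: np.ndarray, q: np.ndarray) -> float:
    """Return s = k^T q / sqrt(d) for a key k and a query q in R^d."""
    if k.ndim != 1 or k.shape != q.shape:
        raise ValueError(f"expected two vectors of equal length, got {k.shape} and {q.shape}")
    # Reject NaN/Inf up front so they cannot reach the later inverse updates.
    if not (np.isfinite(k).all() and np.isfinite(q).all()):
        raise ValueError("key or query contains NaN or Inf")
    d = k.shape[0]
    return float(k @ q) / math.sqrt(d)


k = np.array([1.0, 2.0, 0.0, 1.0])
q = np.ones(4)
print(inner_product_score(k, q))  # 2.0
```

Here \(k^\top q = 4\) and \(\sqrt{d} = 2\), so the score is 2.0.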


2. Sherman–Morrison Rank-1 Inverse Update

Solving a linear system in the penalty matrix at every timestep is dominated by the cost of matrix inversion. The Sherman–Morrison formula lets VLA update the inverse of the penalty matrix \(M_t\) directly when \(M_t\) changes by a rank-1 term \(u_t u_t^\top\), without re-inverting from scratch.

Theoretical Foundation

Given:

  • \(M_{t-1} \in \mathbb{R}^{d \times d}\): The prior penalty matrix (symmetric positive definite).

  • \(A_{t-1} = M_{t-1}^{-1}\): The exact inverse of the prior penalty matrix.

  • \(u_t \in \mathbb{R}^d\): The new update vector proposed by the PenaltyBuilder.

The updated penalty matrix is \(M_t = M_{t-1} + u_t u_t^\top\). Computing \(M_t^{-1}\) directly costs \(\mathcal{O}(d^3)\); the Sherman–Morrison identity instead yields the updated inverse in \(\mathcal{O}(d^2)\):

\[ A_t = A_{t-1} - \frac{A_{t-1} u_t u_t^\top A_{t-1}}{1 + u_t^\top A_{t-1} u_t} \]
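Multiplying \(M_t\) by the proposed \(A_t\) confirms the identity (a standard verification, included here for reference). Write \(\delta = 1 + u_t^\top A_{t-1} u_t\) and use \(M_{t-1} A_{t-1} = I\):

\[ \begin{aligned} M_t A_t &= \left(M_{t-1} + u_t u_t^\top\right)\left(A_{t-1} - \frac{A_{t-1} u_t u_t^\top A_{t-1}}{\delta}\right) \\ &= I - \frac{u_t u_t^\top A_{t-1}}{\delta} + u_t u_t^\top A_{t-1} - \frac{u_t \left(u_t^\top A_{t-1} u_t\right) u_t^\top A_{t-1}}{\delta} \\ &= I + u_t u_t^\top A_{t-1} \, \frac{\delta - 1 - u_t^\top A_{t-1} u_t}{\delta} \\ &= I \end{aligned} \]

The final step uses the definition of \(\delta\); the division by \(\delta\) is also why the update requires \(\delta \neq 0\).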

Algorithmic Execution Steps

  1. Compute Intermediate Vector (\(z \in \mathbb{R}^d\)): \( z = A_{t-1} u_t \)

  2. Compute Denominator (Scalar): \( \delta = 1 + u_t^\top z \)

  3. Numerical Safety Enforcement: If \(|\delta| < \epsilon\), a fallback that adds \(\epsilon I\) is triggered to prevent division by zero and exploding gradients (default \(\epsilon \approx 10^{-6}\)). In exact arithmetic this cannot occur: \(A_{t-1}\) is symmetric positive definite, so \(u_t^\top A_{t-1} u_t \ge 0\) and \(\delta \ge 1\). The guard only catches drift introduced by floating-point error.

  4. Compute Outer Product (\(O \in \mathbb{R}^{d \times d}\)): \( O = z z^\top \), which equals \(A_{t-1} u_t u_t^\top A_{t-1}\) because \(A_{t-1}\) is symmetric.

  5. Final Rank-1 Update: \( A_t = A_{t-1} - \frac{O}{\delta} \)

  6. Periodic Stabilization: Every \(K\) steps, add \(\epsilon I\) to \(A_t\) (that is, add \(\epsilon\) to each diagonal entry) to counteract accumulated floating-point drift. This matters most for very long sequences, where rounding errors compound over many updates.
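The update procedure can be sketched in NumPy as follows. This is a hedged illustration rather than the project's PyTorch implementation: the names sherman_morrison_step and run_updates, the stabilize_every parameter (playing the role of \(K\)), and the exact form of the small-denominator fallback (regularize \(A_{t-1}\) with \(\epsilon I\), retry once, then raise) are assumptions. It also assumes \(A_{t-1}\) is symmetric, which holds when \(M_{t-1}\) is symmetric positive definite.

```python
import numpy as np


def sherman_morrison_step(A_prev: np.ndarray, u: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Given symmetric A_prev = M^{-1}, return (M + u u^T)^{-1}."""
    z = A_prev @ u                  # intermediate vector z = A u
    delta = 1.0 + float(u @ z)      # scalar denominator 1 + u^T A u
    if abs(delta) < eps:
        # Assumed fallback: regularize A_prev with eps * I and retry once.
        A_prev = A_prev + eps * np.eye(A_prev.shape[0])
        z = A_prev @ u
        delta = 1.0 + float(u @ z)
        if abs(delta) < eps:
            raise FloatingPointError(f"denominator too small: {delta}")
    O = np.outer(z, z)              # equals A u u^T A because A is symmetric
    return A_prev - O / delta       # final rank-1 update


def run_updates(A0, us, eps=1e-6, stabilize_every=None):
    """Apply one rank-1 update per timestep; every K steps add eps * I to A."""
    A = A0
    for t, u in enumerate(us, start=1):
        A = sherman_morrison_step(A, u, eps)
        if stabilize_every is not None and t % stabilize_every == 0:
            A = A + eps * np.eye(A.shape[0])  # periodic stabilization
    return A


rng = np.random.default_rng(0)
d, T = 4, 10
M0 = 2.0 * np.eye(d)
us = rng.normal(size=(T, d))
A = run_updates(np.linalg.inv(M0), us)
M_T = M0 + us.T @ us                 # M0 plus the sum of u_t u_t^T
print(np.allclose(A, np.linalg.inv(M_T)))  # True
```

With stabilize_every left at None, the result matches a direct inverse up to rounding. With stabilization enabled, \(A_t\) is deliberately perturbed by \(\epsilon I\) every \(K\) steps, so it will differ slightly from the exact inverse; that trade-off is the point of the regularization.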


3. Multiple Rank-1 Updates (Woodbury Generalization)

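To support rank-\(r\) updates \(M_t = M_{t-1} + \sum_{i=1}^{r} u_i u_i^\top\), VLA applies \(r\) Sherman–Morrison updates with \(\{u_1, u_2, \dots, u_r\}\) in sequence. In exact arithmetic the result equals the Woodbury matrix identity applied to the full rank-\(r\) update. Applying the updates one at a time avoids materializing the Woodbury intermediates (a \(d \times r\) product and an \(r \times r\) matrix to invert) and keeps each step at \(\mathcal{O}(d^2)\), for \(\mathcal{O}(r d^2)\) in total.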

\[ \begin{aligned} A^{(0)} &= A_{t-1} \\ A^{(i)} &= \text{ShermanMorrison}\left(A^{(i-1)}, u_i\right) \quad \text{for } i \in \{1, \dots, r\} \\ A_t &= A^{(r)} \end{aligned} \]
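The sequential scheme can be checked against the closed-form Woodbury identity \((M + UU^\top)^{-1} = A - AU\,(I_r + U^\top A U)^{-1}\,U^\top A\), where the columns of \(U\) are \(u_1, \dots, u_r\). The NumPy sketch below is illustrative only: the function names are hypothetical, the rank-1 step omits the safety guard for brevity, and \(A_{t-1}\) is assumed symmetric.

```python
import numpy as np


def sherman_morrison_step(A: np.ndarray, u: np.ndarray) -> np.ndarray:
    """Rank-1 inverse update for symmetric A (safety guard omitted for brevity)."""
    z = A @ u
    return A - np.outer(z, z) / (1.0 + float(u @ z))


def sequential_update(A_prev: np.ndarray, U: np.ndarray) -> np.ndarray:
    """A^(0) = A_prev, A^(i) = SM(A^(i-1), u_i); returns A^(r). U has shape (d, r)."""
    A = A_prev
    for i in range(U.shape[1]):
        A = sherman_morrison_step(A, U[:, i])
    return A


def woodbury_update(A_prev: np.ndarray, U: np.ndarray) -> np.ndarray:
    """Closed form: A - A U (I_r + U^T A U)^{-1} U^T A."""
    AU = A_prev @ U                          # d x r intermediate
    C = np.eye(U.shape[1]) + U.T @ AU        # r x r capacitance matrix
    return A_prev - AU @ np.linalg.solve(C, AU.T)


rng = np.random.default_rng(0)
d, r = 6, 3
A_prev = np.linalg.inv(2.0 * np.eye(d))
U = rng.normal(size=(d, r))
print(np.allclose(sequential_update(A_prev, U), woodbury_update(A_prev, U)))  # True
```

Both routes give the same inverse; the sequential one only ever holds \(d \times d\) matrices and \(d\)-vectors, while the closed form builds the \(d \times r\) and \(r \times r\) intermediates.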

4. Recovering Optimal Coefficients \(\alpha^*\)

In standard linear attention, the memory matrix \(S_t\) is updated additively with the fixed outer product \(v_t k_t^\top\). VLA instead computes an optimal coefficient vector \(\alpha_t\) that minimizes the associative reconstruction error of the value vector \(v_t\), regularized by the current penalty matrix \(M_t\).

Because VLA maintains the inverse \(A_t = M_t^{-1}\) directly, the optimum \(\alpha^* = M_t^{-1} s_t\) reduces to a single matrix-vector product, costing \(\mathcal{O}(d^2)\) rather than the \(\mathcal{O}(d^3)\) needed to solve the system from scratch:

\[ \alpha_t = A_t s_t \]
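A minimal sketch of this step, assuming \(s_t\) here is a vector in \(\mathbb{R}^d\) so that \(\alpha_t \in \mathbb{R}^d\). The function name optimal_coefficients and the randomly generated SPD matrix are illustrative only; in VLA, \(A_t\) would come from the tracked rank-1 updates rather than np.linalg.inv. The check confirms that \(A_t s_t\) solves \(M_t \alpha = s_t\) without an explicit solve.

```python
import numpy as np


def optimal_coefficients(A_t: np.ndarray, s_t: np.ndarray) -> np.ndarray:
    """alpha_t = A_t s_t, where A_t = M_t^{-1} is already maintained."""
    return A_t @ s_t  # one O(d^2) matrix-vector product


rng = np.random.default_rng(0)
d = 5
B = rng.normal(size=(d, d))
M_t = B @ B.T + d * np.eye(d)      # an arbitrary SPD penalty matrix for illustration
A_t = np.linalg.inv(M_t)           # stands in for the tracked inverse
s_t = rng.normal(size=d)
alpha = optimal_coefficients(A_t, s_t)
print(np.allclose(M_t @ alpha, s_t))  # True: alpha solves M_t alpha = s_t
```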

5. Memory Matrix Update

With \(\alpha_t\) computed, the memory matrix \(S_t\) is updated to incorporate the new context:

\[ S_t = S_{t-1} + v_t \alpha_t^\top \]

Implementation Details:

  • Batched Outer Products: Computed as v.unsqueeze(2) * alpha.unsqueeze(1), which broadcasts to one outer product \(v_t \alpha_t^\top\) per batch element without explicit loops.

  • Out-of-place Update: The update is written out-of-place (\(S_t = S_{t-1} + \Delta\)) so that \(S_{t-1}\), which autograd may need for the backward pass, is never overwritten in the PyTorch computational graph. In-place operations such as += are not used.

  • Renormalization Guard: When the Frobenius norm \(\|S_t\|_F\) exceeds a threshold, \(S_t\) is rescaled to avoid numerical overflow over very long sequences.
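The update and its guards might look as follows in NumPy, where v[:, :, None] * alpha[:, None, :] plays the role of v.unsqueeze(2) * alpha.unsqueeze(1). This is a sketch under assumptions: it uses the update \(S_t = S_{t-1} + v_t \alpha_t^\top\) implied by that broadcast, and the function name memory_update, the max_norm threshold, and the rule of rescaling each batch element back onto max_norm are hypothetical, since the threshold and rescaling rule are not specified here.

```python
import numpy as np


def memory_update(S_prev, v, alpha, max_norm=None):
    """Batched S_t = S_{t-1} + v_t alpha_t^T.

    S_prev: (B, d_v, d), v: (B, d_v), alpha: (B, d).
    """
    # NumPy analogue of v.unsqueeze(2) * alpha.unsqueeze(1): (B, d_v, 1) * (B, 1, d).
    delta = v[:, :, None] * alpha[:, None, :]
    S = S_prev + delta                       # out-of-place; S_prev is not modified
    if max_norm is not None:
        # Hypothetical guard: shrink any batch element whose Frobenius norm
        # exceeds max_norm back onto max_norm; smaller ones are left unchanged.
        norms = np.linalg.norm(S, axis=(1, 2), keepdims=True)
        S = S * np.minimum(1.0, max_norm / np.maximum(norms, 1e-12))
    return S


rng = np.random.default_rng(0)
B, d_v, d = 2, 3, 4
S0 = np.zeros((B, d_v, d))
v, alpha = rng.normal(size=(B, d_v)), rng.normal(size=(B, d))
S1 = memory_update(S0, v, alpha)
print(S1.shape)                                        # (2, 3, 4)
print(np.allclose(S1[0], np.outer(v[0], alpha[0])))    # True
```

Rescaling by a per-element factor keeps the direction of \(S_t\) and only caps its magnitude; a real implementation might instead choose a different threshold policy.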