Penalty Builder Module¶
Overview¶
The PenaltyBuilder is the central component that grants Variational Linear Attention its adaptive memory capabilities. Located in `src/models/attention/penalty_builder.py`, this module constructs the data-dependent penalty matrix \(M_t(\theta)\) at each timestep \(t\). By parameterizing \(M_t\), the model learns how to weigh and forget historical information.
The module supports multiple parameterizations for the penalty matrix, balancing expressivity and computational cost.
Parameterizations¶
1. Diagonal + Rank-1 (\(M_t = \lambda_t I + u_t u_t^\top\))¶
The simplest and most robust parameterization. The matrix is decomposed into a uniform decay scalar (\(\lambda_t\)) and a rank-1 outer product (\(u_t u_t^\top\)), which allows the model to selectively penalize specific directions in the key-space.
Formulas:

\[
\lambda_t = \text{softplus}(\text{MLP}_\lambda(k_t)), \qquad u_t = W_u k_t
\]
Implementation Constraints:

- \(\lambda_t\) is strictly clamped to \(\ge \lambda_{\text{min}}\) (e.g., \(10^{-3}\)) to guarantee positive definiteness.
- \(W_u \in \mathbb{R}^{d \times d}\) is a learnable projection matrix mapping the key dimension to the update vector.
Output: Returns the tuple \((\lambda_t, u_t)\).
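The rank-1 parameterization above can be sketched as a small PyTorch module. This is a hedged illustration, not the repository's implementation: the class name `RankOnePenaltyBuilder`, the two-layer `lambda_net`, and the attribute `w_u` are assumptions chosen to match the formulas.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RankOnePenaltyBuilder(nn.Module):
    """Sketch of the diagonal + rank-1 penalty: M_t = lambda_t I + u_t u_t^T."""

    def __init__(self, d: int, lambda_min: float = 1e-3):
        super().__init__()
        self.lambda_min = lambda_min
        # MLP_lambda: key -> scalar base penalty (architecture is illustrative)
        self.lambda_net = nn.Sequential(nn.Linear(d, d), nn.SiLU(), nn.Linear(d, 1))
        # W_u: key -> rank-1 update direction
        self.w_u = nn.Linear(d, d, bias=False)

    def forward(self, k_t: torch.Tensor):
        # softplus keeps lambda_t positive; the clamp enforces lambda_t >= lambda_min
        lam = F.softplus(self.lambda_net(k_t)).clamp(min=self.lambda_min)
        u = self.w_u(k_t)  # u_t = W_u k_t
        return lam, u

builder = RankOnePenaltyBuilder(d=16)
k = torch.randn(4, 16)      # (B, d) streaming input
lam, u = builder(k)
print(lam.shape, u.shape)   # torch.Size([4, 1]) torch.Size([4, 16])
```

Because `nn.Linear` acts on the last dimension, the same module also accepts `(B, T, d)` batch-training inputs without modification.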
2. Diagonal + Rank-r (\(M_t = \lambda_t I + \sum_{m=1}^r u_{t,m} u_{t,m}^\top\))¶
A higher-capacity parameterization that allows the model to penalize an \(r\)-dimensional subspace at each step.
Formulas:

\[
\lambda_t = \text{softplus}(\text{MLP}_\lambda(k_t)), \qquad u_{t,m} = W_{u,m} k_t \quad \text{for } m = 1, \dots, r
\]
Implementation Constraints:
\(U_t = [u_{t,1}, \dots, u_{t,r}]\) is computed efficiently via a single batched linear projection \(W_u \in \mathbb{R}^{(r \cdot d) \times d}\) and subsequently reshaped into \((B, r, d)\).
Output: Returns \((\lambda_t, U_t)\).
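The batched rank-\(r\) projection and reshape can be sketched as follows. Variable names (`w_u`, `M_lowrank`) are illustrative; the full matrix is formed here only to show what the factors represent, since the actual recurrence never materializes it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, d, r = 4, 16, 3
# Single batched projection W_u in R^{(r*d) x d}, as described above
w_u = nn.Linear(d, r * d, bias=False)
k = torch.randn(B, d)
U = w_u(k).view(B, r, d)    # U_t = [u_{t,1}, ..., u_{t,r}], shape (B, r, d)

# The implied low-rank term sum_m u_{t,m} u_{t,m}^T, as a batched sum of
# outer products (for illustration only -- never built in the recurrence):
M_lowrank = torch.einsum('brd,bre->bde', U, U)   # (B, d, d), symmetric PSD
print(U.shape, M_lowrank.shape)
```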
3. Kernelized Penalty (Low-rank Approximation)¶
This parameterization uses a generic feature map \(\phi(x)\) to approximate full kernelized penalty landscapes.
Formulas:

\[
M_t \approx \lambda I + \phi_t \phi_t^\top, \qquad \phi_t = W_\phi k_t
\]
Note: Currently implemented via `KernelPenaltyBuilder`. It shares the rank-1 logic but is designed for future expansion to full RBF kernel approximations.
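For intuition on what that future expansion could look like, here is a random-Fourier-feature map (Rahimi and Recht) that approximates an RBF kernel. Nothing in this sketch is in the repository; the current builder uses the plain linear map \(\phi_t = W_\phi k_t\), and `rbf_features` is a hypothetical replacement for it.

```python
import math
import torch

torch.manual_seed(0)

def rbf_features(k: torch.Tensor, W: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # phi(k) = sqrt(2/D) * cos(W k + b), so that phi(x) . phi(y) approximates
    # the unit-bandwidth RBF kernel exp(-||x - y||^2 / 2)
    D = W.shape[0]
    return math.sqrt(2.0 / D) * torch.cos(k @ W.T + b)

d, D = 16, 64
W = torch.randn(D, d)             # random frequencies ~ N(0, I)
b = torch.rand(D) * 2 * math.pi   # random phases ~ U[0, 2*pi)
k = torch.randn(4, d)
phi = rbf_features(k, W, b)       # (B, D) feature vectors; phi phi^T is rank-D
```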
Recurrence & Inverse Tracking¶
The outputs from the PenaltyBuilder are not used as \(M_t\) directly; doing so would require \(\mathcal{O}(d^3)\) inversion. Instead, \(\lambda_t\) and \(u_t\) are passed to the InversePenaltyTracker module.
Crucial Detail: In `VLALayer`, the time-varying scalar \(\lambda_t\) from `PenaltyBuilder` is unused in the \(A_t\) update step. The recurrence strictly relies on \(u_t\) and an initial \(\lambda_0\). Consequently, the parameters within the \(\lambda\)-network (`lambda_net`) do not receive gradients during the standard VLA forward pass.
Inputs and Outputs¶
Inputs:
- \(k_t\): The input key vector. Shape: `(B, d)` for streaming auto-regressive decoding, or `(B, T, d)` for batch training.
Outputs:
- \(\lambda_t\): Scalar base penalty. Shape: `(B, 1)` or `(B, T, 1)`.
- \(u_t\) (or \(U_t\)): Update vector(s). Shape: `(B, d)` / `(B, T, d)` for rank-1, or `(B, r, d)` / `(B, T, r, d)` for rank-\(r\).
- `stats`: A dictionary containing internal tracking statistics (e.g., mean norms, eigenvalues) for logging via WandB or TensorBoard. Modules must return this dictionary rather than invoking logging frameworks directly.
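A minimal sketch of the input/output contract, covering both the streaming and batch shapes and the returned `stats` dictionary. The single-layer `lambda_net` and the stat keys here are placeholders, not the module's real internals.

```python
import torch
import torch.nn as nn

d = 16
# Illustrative stand-ins for the builder's sub-networks
lambda_net = nn.Sequential(nn.Linear(d, 1), nn.Softplus())
w_u = nn.Linear(d, d, bias=False)

def build(k: torch.Tensor):
    lam = lambda_net(k).clamp(min=1e-3)   # (B, 1) or (B, T, 1)
    u = w_u(k)                            # (B, d) or (B, T, d)
    # Return stats for the caller to log; never call WandB/TensorBoard here
    stats = {
        "lambda_mean": lam.mean().item(),
        "u_norm_mean": u.norm(dim=-1).mean().item(),
    }
    return lam, u, stats

lam_s, u_s, _ = build(torch.randn(4, d))          # streaming: (B, d)
lam_b, u_b, stats = build(torch.randn(4, 32, d))  # batch: (B, T, d)
print(lam_s.shape, lam_b.shape)   # torch.Size([4, 1]) torch.Size([4, 32, 1])
```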
Computation Graph¶
```mermaid
graph TD
    K[Key k_t] --> MLP[MLP_lambda]
    MLP --> Softplus
    Softplus --> Clamp[Clamp >= lambda_min]
    Clamp --> Lambda[lambda_t]
    K --> Wu[Linear W_u]
    Wu --> U[u_t / U_t]
    Lambda --> Tracker[InverseTracker]
    U --> Tracker
```