Duality of Forget Gates and Position Embeddings in Sequence Modeling
The literature on linear transformers and State Space Models (SSMs) is progressing rapidly. One of the most well-known positional embeddings is Rotary Position Embedding (RoPE)
Causal softmax attention takes the form:
\begin{align} \mathbf{O} = \text{Softmax}\left(\frac{\mathbf{Q} \mathbf{K}^\top}{\sqrt{d}} + \mathbf{M}\right) \mathbf{V} \tag{Matrix Form}
\end{align}
where $\mathbf{Q}, \mathbf{K}, \mathbf{V}, \mathbf{O} \in \mathbb{R}^{T \times d}$ denote the queries, keys, values, and the output of the softmax attention, respectively, and $\mathbf{M} \in \{-\infty, 0\}^{T \times T}$ is the lower-triangular causal mask. For simplicity, we assume that the values and keys share the same shape and omit the $\sqrt{d}$ scaling factor from here on.
As shown above, permuting the order of the input tokens ($\mathbf{X}$) induces the same permutation in the attention output ($\mathbf{O}$). Consequently, the attention mechanism itself has no inherent notion of token order and encodes no positional information. Softmax attention without any positional embedding, as presented in the equation above, is referred to as No-Position Embedding (NoPE).
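For reference, here is a minimal sketch of causal softmax attention with no positional embedding (NoPE); the function name and the $(T, d)$ shapes are my own choices, mirroring the matrix form above:

import torch
import torch.nn.functional as F

def nope_attention(Q, K, V):
    """
    Causal softmax attention without any positional embedding (NoPE).
    Q, K, V: (T, d)
    """
    T, d = Q.shape
    # Causal mask: 0 where attention is allowed (i >= j), -inf otherwise
    M = torch.full((T, T), float('-inf')).triu(diagonal=1)
    scores = Q @ K.T / d**0.5 + M          # (T, T)
    return F.softmax(scores, dim=-1) @ V   # (T, d)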
To inject positional information into the tokens, several approaches have been proposed. For example, the vanilla Transformer adds sinusoidal positional embeddings to the input tokens $\mathbf{x}_t$ only once, right after the embedding layer and before the first attention layer. We are mostly interested in positional embeddings that have been shown to be performant, scalable, and sufficiently general to support many variants.
Such positional embeddings typically enter the attention score either multiplicatively, through matrices $\textcolor{blue} {\mathbf{D_{tj}}} \in \mathbb{R}^{d \times d}$, or additively, through scalars $\textcolor{blue}{d_{t j}} \in \mathbb{R}$, which encode the relative position between tokens $t$ and $j$. These matrices (or scalars) may depend only on the positions $t$ and $j$, or may also depend on the inputs $\mathbf{x}_t$ themselves.
The softmax Transformer, while highly performant, has quadratic computational complexity. To address this limitation, linear attention was introduced.
Linear attention removes the softmax operation from the Transformer
\begin{align} \mathbf{O} = \left(\mathbf{Q} \mathbf{K}^\top \odot \mathbf{M}\right) \mathbf{V} \tag{Matrix Form}
\end{align}
with $\mathbf{M}\in\{0,1\}^{T \times T}$ denoting the binary causal mask ($\mathbf{M}_{ij}=1$ for $i \ge j$ and $0$ otherwise). In its recurrent form, linear attention can be expressed as a linear recurrence:
\[\mathbf{S}_t = \mathbf{S}_{t-1} + \mathbf{v}_t \mathbf{k}^\top_t , \quad \mathbf{o}_t = \mathbf{S}_t\mathbf{q}_t\]with $\mathbf{S}_t \in \mathbb{R}^{d \times d}$ being a two-dimensional hidden state.
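A quick way to check this equivalence numerically is to run both forms side by side; the sketch below (variable names are mine) uses the $(T, d)$ layout of the matrix form:

import torch

T, d = 6, 4
Q, K, V = torch.randn(T, d), torch.randn(T, d), torch.randn(T, d)

# Parallel (matrix) form: (Q K^T ⊙ M) V with a binary causal mask
M = torch.ones(T, T).tril()
O_parallel = (Q @ K.T * M) @ V

# Recurrent form: S_t = S_{t-1} + v_t k_t^T,  o_t = S_t q_t
S = torch.zeros(d, d)
O_recurrent = torch.zeros(T, d)
for t in range(T):
    S = S + torch.outer(V[t], K[t])    # v_t k_t^T
    O_recurrent[t] = S @ Q[t]          # o_t = S_t q_t

assert torch.allclose(O_parallel, O_recurrent, atol=1e-5)

The recurrent form only needs the $d \times d$ state $\mathbf{S}_t$, which is what gives linear attention its linear time and constant-memory decoding.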
Vanilla linear Transformers simply accumulate information in their hidden state without forgetting. Soon after the introduction of linear attention, and due to its close connection to RNNs, researchers incorporated forget gates into linear Transformers, inspired by the RNN literature:
\[\mathbf{S}_t = \mathbf{S}_{t-1}\textcolor{red}{\mathbf{A}_t} + \mathbf{v}_t \mathbf{k}^\top_t , \quad \mathbf{o}_t = \mathbf{S}_t\mathbf{q}_t\]with $\textcolor{red}{\mathbf{A}_t}$ being a structured matrix (dense, low-rank, diagonal, etc.) extracted from the input $\mathbf{X}$, whose eigenvalues have magnitude at most 1 to ensure stability.
$\textcolor{red}{\mathbf{A}_t}$ is commonly referred to as the forget gate: it not only decays historical information but also implicitly induces a relative positional embedding in linear Transformers. This effect becomes more apparent by unrolling the recurrence above:
\[\mathbf{o}_t = \sum_{j=1}^t \left( \mathbf{k^\top_j} \left( \textcolor{blue} {\prod_{s=j+1}^t \mathbf{A}_s} \right) \mathbf{q}_t \right) \mathbf{v}_j \tag{Unrolled LinAtt}\]Looking at the product $\textcolor{blue}{\prod_{s=j+1}^t \mathbf{A_s}}$, we see that the forget gate implicitly encodes relative positional information between positions $j$ and $t$.
Since each $\mathbf{A_s}$ has eigenvalues of magnitude at most 1, the product is contractive and can be written in log-space as $\textcolor{blue}{\prod_{s=j+1}^t \mathbf{A_s} = \exp \left(- \sum_{s=j+1}^t \mathbf{D_s}\right)}$, where each $\mathbf{D_s} \succeq 0$. We show that $\sum_{s=j+1}^t \mathbf{D_s}$ defines a distance $\mathbf{D}$ that satisfies all metric properties except symmetry, which is appropriate for causal (autoregressive) modeling.
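For scalar gates, this log-space identity is easy to verify numerically; the sketch below (names mine) builds the pairwise matrix $\prod_{s=j+1}^t a_s$ once by explicit products and once from a cumulative sum of $\log a_s$:

import torch

T = 8
a = torch.rand(T) * 0.5 + 0.5          # scalar gates a_s in (0.5, 1.0)

# Explicit products: P[t, j] = prod_{s=j+1}^t a_s  for j <= t
P_explicit = torch.ones(T, T)
for t in range(T):
    for j in range(t + 1):
        for s in range(j + 1, t + 1):
            P_explicit[t, j] *= a[s]

# Log-space: sum_{s=j+1}^t log a_s = cumsum(log a)[t] - cumsum(log a)[j]
c = torch.cumsum(torch.log(a), dim=0)
P_log = torch.exp(c[:, None] - c[None, :])

assert torch.allclose(P_explicit.tril(), P_log.tril(), atol=1e-5)

This cumulative-sum trick is also what the FoX-style implementations later in the post rely on.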
One of the key differences between different linear recurrent models lies in the structure of the forget gate $\mathbf{A}_t$. Using this insight, we establish a duality between well-known linear Transformers and different PEs, showing how the structure of popular linear models gives rise to well-known PEs.
We begin with the simplest case $\mathbf{A_t}=1$, which results in $\textcolor{black} {\prod_{s=j+1}^t \mathbf{A}_s}=1$. Therefore, no relative positional embedding is applied, and the recurrence takes the form:
\[\mathbf{S}_t = \mathbf{S}_{t-1} + \mathbf{v}_t \mathbf{k}^\top_t\]The dual position embedding of vanilla linear attention is simply applying the causal mask to softmax attention:
\[\mathbf{O} = \text{Softmax}\left(\mathbf{Q} \mathbf{K}^\top + \mathbf{M}\right) \mathbf{V}, \quad \mathbf{M}_{ij} = \begin{cases} \textcolor{blue}{0} & i \ge j \\ -\infty & i < j \end{cases}\]Therefore, $\textcolor{blue}{\mathbf{M}}$ can be interpreted as an additive positional embedding in which all tokens are at the same relative distance from each other.
\[\begin{aligned} & \textcolor{black}{\mathbf{S}_t} \textcolor{black}{= \mathbf{S}_{t-1}\textcolor{blue}{1} + \mathbf{v}_t \mathbf{k}^\top_t} \\ & \textcolor{black}{\text{Softmax}\left({\mathbf{Q} \mathbf{K}^\top} + \textcolor{blue}{\mathbf{D}}\right) }, \quad \textcolor{blue} {\mathbf{D_{ij}}} \textcolor{black}{=\begin{cases} \textcolor{blue}{0} & i \ge j \\ - \infty & i<j \end{cases} }\end{aligned}\]
The above gives our first connection between the forget gate of linear attention and the positional embedding of the softmax Transformer: a forget gate equal to $1$, i.e., no forgetting at all, corresponds to NoPE.
A first improvement over linear attention is the Retentive Network (RetNet), which applies a fixed, data-independent scalar decay $\textcolor{blue}{\gamma} \in (0,1)$ at every step:
\[\mathbf{S}_t = \mathbf{S}_{t-1}\,\textcolor{blue}{\gamma} + \mathbf{v}_t \mathbf{k}^\top_t\]From the above, we can derive the unrolled cumulative forget gate:
\[\textcolor{black}{\prod_{s=j+1}^t \mathbf{A}_s} = \gamma^{\,t-j}.\]By setting $\gamma = \exp(-\lambda)$, the cumulative forget gate can be rewritten as
\[\prod_{s=j+1}^t \mathbf{A}_s = \exp\!\big(\textcolor{blue}{-\lambda (t-j)}\big)\]The $\textcolor{blue}{\text{blue}}$ term above encodes the relative position between tokens $t$ and $j$; applying it as an additive positional embedding in the softmax Transformer recovers the well-known ALiBi position embedding:
\[\begin{aligned} & \textcolor{black}{\mathbf{S}_t} \textcolor{black}{= \mathbf{S}_{t-1}\textcolor{blue}{\exp(-\lambda)} + \mathbf{v}_t \mathbf{k}^\top_t} \\ & \textcolor{black}{\text{Softmax}\left({\mathbf{Q} \mathbf{K}^\top} + \textcolor{blue}{\mathbf{D}}\right) }, \quad \textcolor{blue} {\mathbf{D_{ij}}} \textcolor{black}{=\begin{cases} \textcolor{blue}{-\lambda (i-j)} & i \ge j \\ - \infty & i<j \end{cases} }\end{aligned}\]
Thus, the ALiBi PE admits a direct duality with the RetNet linear Transformer.
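To see the RetNet side of this duality concretely, here is a minimal sketch (names mine, no softmax) checking that the decayed recurrence with $\gamma = \exp(-\lambda)$ matches masked attention weighted by $\gamma^{\,i-j}$, i.e. the elementwise exponential of the ALiBi bias:

import torch

T, d, lambd = 6, 4, 0.1
gamma = torch.exp(torch.tensor(-lambd))
Q, K, V = torch.randn(T, d), torch.randn(T, d), torch.randn(T, d)

# Recurrent form: S_t = gamma * S_{t-1} + v_t k_t^T,  o_t = S_t q_t
S = torch.zeros(d, d)
O_rec = torch.zeros(T, d)
for t in range(T):
    S = gamma * S + torch.outer(V[t], K[t])
    O_rec[t] = S @ Q[t]

# Parallel form: (Q K^T ⊙ D) V with D_ij = gamma^(i-j) = exp(-lambda (i-j)) on the lower triangle
idx = torch.arange(T)
D = (gamma ** (idx[:, None] - idx[None, :]).float()).tril()
O_par = (Q @ K.T * D) @ V

assert torch.allclose(O_rec, O_par, atol=1e-4)

The additive (softmax) side of the duality, ALiBi itself, is sketched next.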
import torch
import torch.nn.functional as F

def alibi(T, lambd):
    """
    Simple ALiBi implementation
    T: Sequence length
    lambd: lambda scaling of ALiBi
    """
    idx = torch.arange(T)
    I, J = torch.meshgrid(idx, idx, indexing='ij')
    D_add = -(torch.abs(I - J).float() * lambd).tril()
    # -inf above the diagonal enforces causality
    D_add = D_add.masked_fill(J > I, float('-inf'))
    return D_add

def att_alibi(Q, K, V, lambd):
    """
    Simple ALiBi Attention
    Q, K, V: input (d, T)
    lambd: lambda scaling of ALiBi
    """
    D_add = alibi(Q.shape[-1], lambd)
    out = F.softmax(Q.T @ K + D_add, dim=-1) @ V.T
    return out

While RetNet incorporates a forget gate, it lacks input-dependent selectivity in the forgetting mechanism. Mamba2 and Gated Linear Attention (GLA) address this by making the forget gate a function of the input.
As an example, the recurrence of GLA is:
\[\textcolor{black}{\mathbf{S}_t} \textcolor{black}{= \mathbf{S}_{t-1}\textcolor{blue}{\sigma(\mathbf{Wx_t})^{1/\tau}} + \mathbf{v}_t \mathbf{k}^\top_t}\]The above leads to the following cumulative forget-gate product:
\[\prod_{s=j+1}^t \mathbf{A}_s = \exp\!\big(\textcolor{blue}{\sum_{s=j+1}^t \log (\sigma(\mathbf{Wx_s})^{1/\tau}) }\big)\]for the case of GLA. Simplifying GLA's forget gate from a diagonal $\mathbf{A}_t$ to a scalar, we reach:
\[\prod_{s=j+1}^t \mathbf{A}_s = \exp\!\big(\textcolor{blue}{\sum_{s=j+1}^t \log (\sigma(\mathbf{wx_s})^{1/\tau}) }\big)\]The above forget gates can be interpreted as additive positional embeddings for Transformers with forget gates, also known as the Forgetting Transformer (FoX):
\[\begin{aligned} & \textcolor{black}{\mathbf{S}_t} \textcolor{black}{= \mathbf{S}_{t-1}\textcolor{blue}{\sigma(\mathbf{Wx_t})^{1/\tau}} + \mathbf{v}_t \mathbf{k}^\top_t} \\ & \textcolor{black}{\text{Softmax}\left({\mathbf{Q} \mathbf{K}^\top} + \textcolor{blue}{\mathbf{D}}\right) }, \quad \textcolor{blue} {\mathbf{D_{ij}}} \textcolor{black}{=\begin{cases} \textcolor{blue}{\sum_{s=j+1}^i \log (\sigma(\mathbf{wx_s+b})) }& i \ge j \\ - \infty & i<j \end{cases} }\end{aligned}\]
Therefore, FoX admits a direct duality with the Mamba2/GLA linear Transformers.
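On the recurrent side, a minimal sketch (names mine; scalar gate, temperature $\tau$ omitted, no softmax) checks that the gated recurrence matches attention weighted by $\exp(\mathbf{D})$, with $\mathbf{D}$ the FoX-style additive mask built from cumulative log-gates:

import torch

T, d = 6, 4
x = torch.randn(d, T)
w = torch.randn(d, 1)
Q, K, V = torch.randn(d, T), torch.randn(d, T), torch.randn(d, T)

a = torch.sigmoid(w.T @ x).squeeze(0)          # scalar gates a_t in (0, 1), shape (T,)

# Recurrent form: S_t = a_t * S_{t-1} + v_t k_t^T,  o_t = S_t q_t
S = torch.zeros(d, d)
O_rec = torch.zeros(d, T)
for t in range(T):
    S = a[t] * S + torch.outer(V[:, t], K[:, t])
    O_rec[:, t] = S @ Q[:, t]

# Parallel form with the additive log-gate mask D_ij = sum_{s=j+1}^i log a_s
c = torch.cumsum(torch.log(a), dim=0)
decay = torch.exp(c[:, None] - c[None, :]).tril()   # exp(D_ij) on the lower triangle, 0 above
O_par = ((Q.T @ K) * decay) @ V.T                   # (T, d)

assert torch.allclose(O_rec.T, O_par, atol=1e-4)

The softmax (FoX) side of the duality is sketched next.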
import torch
import torch.nn.functional as F

def fox_pe(w, x):
    """
    Simple FoX implementation
    x: input (d, T)
    w: learnable weight (d, 1)
    """
    a = F.logsigmoid(w.transpose(-1, -2) @ x)    # (1, T) log forget gates
    a = torch.cumsum(a, dim=-1)                  # (1, T) cumulative sums
    D_add = (a.transpose(-1, -2) - a).tril()     # (T, T): sum_{s=j+1}^{i} log-gates
    # -inf above the diagonal enforces causality
    T = x.shape[-1]
    mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    D_add = D_add.masked_fill(mask, float('-inf'))
    return D_add

def fox(Q, K, V, x, w):
    """
    Simple FoX Attention
    Q, K, V: input (d, T)
    x: input (d, T)
    w: learnable weight (d, 1)
    """
    D_add = fox_pe(w, x)
    out = F.softmax(Q.T @ K + D_add, dim=-1) @ V.T
    return out

Let us now consider queries and keys that are complex-valued, $\mathbf{\tilde{q}_t,\tilde{k}_t}\in\mathbb{C}^{d/2}$. We then choose a diagonal forget gate with unit-modulus (purely rotational) eigenvalues, $\mathbf{A}_t = \exp(i\,\boldsymbol{\Omega})$, where $\boldsymbol{\Omega}$ is a diagonal matrix of frequencies. Formally, the recurrence of this linear Transformer is:
\[\mathbf{S}_t = \mathbf{S}_{t-1}\,\textcolor{blue}{e^{i\mathbf{\Omega}}} + \mathbf{v}_t \tilde{\mathbf{k}}_t^{\mathrm{H}}, \quad \mathbf{o}_t = \mathcal{R} \{\mathbf{S}_t \tilde{\mathbf{q}}_t\}, \quad \mathbf{o}_t = \sum_{j=1}^t \mathbf{v}_j \, \textcolor{blue}{\mathcal{R}\!\left\{\tilde{\mathbf{k}}_j^{\mathrm{H}} e^{i\mathbf{\Omega}(t-j)} \tilde{\mathbf{q}}_t \right\}}.\]with $\mathrm{H}$ denoting the conjugate transpose (Hermitian) operation.
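Before unpacking this in the real domain, here is a minimal numerical sketch (tensor names and sizes are mine) that runs the complex-valued recurrence directly and checks it against the unrolled sum on the right:

import torch

T, d = 6, 8                                       # d must be even
omega = 10000 ** (-2 * torch.arange(d // 2) / d)  # diagonal frequencies of Omega
q = torch.randn(T, d // 2, dtype=torch.cfloat)    # complex queries q_t
k = torch.randn(T, d // 2, dtype=torch.cfloat)    # complex keys    k_t
V = torch.randn(T, d)                             # real values     v_t

phase = torch.exp(1j * omega)                     # A_t = exp(i * Omega), diagonal

# Recurrent form: S_t = S_{t-1} e^{i Omega} + v_t k_t^H,  o_t = Re{S_t q_t}
S = torch.zeros(d, d // 2, dtype=torch.cfloat)
O_rec = torch.zeros(T, d)
for t in range(T):
    S = S * phase + torch.outer(V[t].to(torch.cfloat), k[t].conj())
    O_rec[t] = (S @ q[t]).real

# Unrolled form: o_t = sum_j v_j Re{ k_j^H e^{i Omega (t-j)} q_t }
O_unrolled = torch.zeros(T, d)
for t in range(T):
    for j in range(t + 1):
        score = (k[j].conj() * torch.exp(1j * omega * (t - j)) * q[t]).sum().real
        O_unrolled[t] += score * V[j]

assert torch.allclose(O_rec, O_unrolled, atol=1e-4)

The derivation below shows why this is exactly RoPE once everything is mapped back to real numbers.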
We now break down some details:
Since the queries and keys are complex-valued, $\mathbf{\tilde{q},\tilde{k}}\in\mathbb{C}^{d/2}$, each channel has a real and an imaginary part, giving a total of $d$ floating-point values, as in previous setups.
The diagonal entries of $\mathbf{\Omega}$ are set to $\mathbf{\Omega}_n = 10000^{-2n/d}$, the same frequencies used by RoPE
We can rewrite the complex-valued queries and keys using amplitude–phase notation as: $\mathbf{\tilde{q}_n} = |\mathbf{\tilde{q}_n}|e^{i\phi (\mathbf{\tilde{q}_n})}, \mathbf{\tilde{k}_n} = |\mathbf{\tilde{k}_n}|e^{i\phi (\mathbf{\tilde{k}_n})}$
As $\mathbf{A}_t = \exp(i\boldsymbol{\Omega})$, all of its eigenvalues have magnitude 1; therefore, the recurrence is stable. This forget gate does not actually “forget,” since none of its eigenvalues has magnitude below one.
Using the phase–amplitude split of the queries and keys and $\mathbf{\Omega}$ being diagonal, the attention score ($\textcolor{blue}{\text{blue}}$ part above) can be expanded as:
\[\textcolor{black}{\mathcal{R}\!\left\{\tilde{\mathbf{k}}_j^{\mathrm{H}} e^{i\mathbf{\Omega}(t-j)} \tilde{\mathbf{q}}_t \right\}} = \mathcal{R}\!\left\{ \sum_{n=1}^{d/2} |\mathbf{\tilde{q}_{t,n}}|\,|\mathbf{\tilde{k}_{j,n}}|\, e^{i\left(\mathbf{\Omega_n}(t-j)+\phi (\mathbf{\tilde{q}_{t,n}}) -\phi (\mathbf{\tilde{k}_{j,n}}) \right)} \right\} =\] \[\sum_{n=1}^{d/2} \lvert \mathbf{\tilde{q}_{t,n}} \rvert \, \lvert \mathbf{\tilde{k}_{j,n}} \rvert \, \cos\!\left( \mathbf{\Omega_n} (t - j) + \phi (\mathbf{\tilde{q}_{t,n}}) - \phi (\mathbf{\tilde{k}_{j,n}}) \right).\]Using the cosine angle-addition identity, we can rewrite the cosine term as a 2D rotation matrix $R_{\mathbf{\Omega_n}}^{(j-t)}$, with rotation angle $\mathbf{\Omega_n}(j-t)$, sandwiched between the query and key phases:
\[\cos\left(\mathbf{\Omega_n}(t-j)+\phi (\mathbf{\tilde{q}_{t,n}})-\phi (\mathbf{\tilde{k}_{j,n}})\right) =\] \[\begin{bmatrix} \cos(\phi (\mathbf{\tilde{q}_{t,n}})) \\ \sin(\phi (\mathbf{\tilde{q}_{t,n}})) \end{bmatrix}^{\!\top} \! \textcolor{blue} { \underbrace{ \begin{bmatrix} \cos(\mathbf{\Omega_n}(j-t)) & -\sin(\mathbf{\Omega_n}(j-t)) \\ \sin(\mathbf{\Omega_n}(j-t)) & \cos(\mathbf{\Omega_n}(j-t)) \end{bmatrix} }_{\textbf{Rotation Matrix = } R_{\mathbf{\Omega_n}}^{j-t} } } \! \begin{bmatrix} \cos(\phi (\mathbf{\tilde{k}_{j,n}})) \\ \sin(\phi (\mathbf{\tilde{k}_{j,n}})) \end{bmatrix}.\]where we use the fact that $R_{m\mathbf{\Omega}}=R_{\mathbf{\Omega}}^{m}$, known as rotation composition.
Applying this identity to the attention score:
\[\sum_{n=1}^{d/2} \lvert \mathbf{\tilde{q}_{t,n}} \rvert \, \lvert \mathbf{\tilde{k}_{j,n}} \rvert \, \cos\!\left( \mathbf{\Omega_n} (t - j) + \phi (\mathbf{\tilde{q}_{t,n}}) - \phi (\mathbf{\tilde{k}_{j,n}}) \right) =\] \[\sum_{n=1}^{d/2} \lvert \mathbf{\tilde{q}_{t,n}} \rvert \, \lvert \mathbf{\tilde{k}_{j,n}} \rvert \, \begin{bmatrix} \cos(\phi (\mathbf{\tilde{q}_{t,n}})) \\ \sin(\phi (\mathbf{\tilde{q}_{t,n}})) \end{bmatrix}^{\!\top} \! \textcolor{blue}{R_{\mathbf{\Omega_n}}^{j-t}} \! \begin{bmatrix} \cos(\phi (\mathbf{\tilde{k}_{j,n}})) \\ \sin(\phi (\mathbf{\tilde{k}_{j,n}})) \end{bmatrix}=\] \[\sum_{n=1}^{d/2} \lvert \mathbf{\tilde{q}_{t,n}} \rvert \, \begin{bmatrix} \cos(\phi (\mathbf{\tilde{q}_{t,n}})) \\ \sin(\phi (\mathbf{\tilde{q}_{t,n}})) \end{bmatrix}^{\!\top} \! \textcolor{blue}{R_{\mathbf{\Omega_n}}^{j-t}} \lvert \mathbf{\tilde{k}_{j,n}} \rvert \, \begin{bmatrix} \cos(\phi (\mathbf{\tilde{k}_{j,n}})) \\ \sin(\phi (\mathbf{\tilde{k}_{j,n}})) \end{bmatrix}=\] \[\sum_{n=1}^{d/2} \begin{bmatrix} \mathcal{R} (\mathbf{\tilde{q}_{t,n}}) \\ \mathcal{I} (\mathbf{\tilde{q}_{t,n}}) \end{bmatrix}^{\!\top} \! \textcolor{blue}{R_{\mathbf{\Omega_n}}^{j-t}} \! \begin{bmatrix} \mathcal{R} (\mathbf{\tilde{k}_{j,n}}) \\ \mathcal{I} (\mathbf{\tilde{k}_{j,n}}) \end{bmatrix}.\]By stacking the real and imaginary parts of the queries and keys into real vectors of dimension $d$, we have:
\[\mathbf{q_t} = \bigoplus_{n=1}^{d/2} \begin{bmatrix} \mathcal{R} (\mathbf{\tilde{q}_{t,n}}) \\ \mathcal{I} (\mathbf{\tilde{q}_{t,n}}) \end{bmatrix}, \quad \mathbf{k_t} = \bigoplus_{n=1}^{d/2} \begin{bmatrix} \mathcal{R} (\mathbf{\tilde{k}_{t,n}}) \\ \mathcal{I} (\mathbf{\tilde{k}_{t,n}}) \end{bmatrix}.\]This way, we do not actually need any computation in the complex domain; we can do everything in real numbers using $\mathbf{q_t, k_t}$. Using this we can write the attention-score as:
\[\mathbf{q^\top_{t}} \underbrace{\left(\bigoplus_{n=1}^{d/2} \textcolor{black}{R_{\mathbf{\Omega_n}}^{j-t}}\right)}_{\mathbf{R}^{(j-t)}_{\Omega}} \mathbf{k_{j}} = \mathbf{q^\top_{t}} \mathbf{R}^{(j-t)}_{\Omega} \mathbf{k_{j}} = \left(\mathbf{q^\top_{t}} \mathbf{R}^{-t}_{\Omega} \right) \left( \mathbf{R}^{j}_{\Omega} \mathbf{k_{j}} \right)\]Since the transpose of a rotation matrix is its inverse, the rotation induced by the forget gate $\mathbf{A}_t = \exp(i\mathbf{\Omega})$ can be applied directly to the queries and keys:
\[\left(\mathbf{q^\top_{t}} \mathbf{R}^{-t}_{\Omega} \right) \left( \mathbf{R}^{j}_{\Omega} \mathbf{k_{j}} \right) = \textcolor{blue}{ \left( \mathbf{R}_{t\Omega} \mathbf{q_{t}} \right)^\top} \textcolor{red}{ \left( \mathbf{R}_{j\Omega} \mathbf{k_{j}} \right)} = \textcolor{blue}{ \mathbf{\bar q^\top_{t}}} \textcolor{red}{ \mathbf{\bar k_{j}} }\] \[\mathbf{R}^{(t)}_{\Omega} = \bigoplus_{n=1}^{d/2} \textcolor{black}{R_{\mathbf{\Omega_n}}^{t}}\]To wrap up, we return to our complex-valued recurrence:
\[\mathbf{S}_t = \mathbf{S}_{t-1}\,\textcolor{black}{e^{i\mathbf{\Omega}}} + \mathbf{v}_t \tilde{\mathbf{k}}_t^{\mathrm{H}}, \quad \mathbf{o}_t = \mathcal{R} \{\mathbf{S}_t \tilde{\mathbf{q}}_t\}\]We can apply this forget gate in the real domain by stacking the imaginary and real parts of the queries and keys, and forming the output directly:
\[\mathbf{o}_t = \left(\sum_{j=1}^t \mathbf{v}_j \left( \mathbf{R}_{j\Omega}\mathbf{k_{j}} \right)^\top \right) \left(\mathbf{R}_{t\Omega} \mathbf{q_{t}} \right)\]Constructing the rotation matrices from the angles in $\mathbf{\Omega}$ is exactly the well-known multiplicative positional embedding, Rotary Position Embedding (RoPE):
\[\begin{aligned} & \textcolor{black}{\mathbf{S}_t} \textcolor{black}{= \mathbf{S}_{t-1}\, \textcolor{blue}{e^{i\mathbf{\Omega}}} + \mathbf{v}_t \tilde{\mathbf{k}}_t^{\mathrm{H}}}, \quad \textcolor{black}{\mathbf{o}_t = \mathcal{R}\{\mathbf{S}_t \tilde{\mathbf{q}}_t\}}, \\ & \textcolor{black}{\text{Softmax}\!\left( \textcolor{blue}{\mathbf{\bar Q}}\,\textcolor{blue}{\mathbf{\bar K^\top}} \right)} , \quad \textcolor{blue}{\mathbf{\bar q_{t}}} \textcolor{black}{= \mathbf{R}_{t\Omega} \mathbf{q_{t}}} , \quad \textcolor{blue}{\mathbf{\bar k_{j}}} \textcolor{black}{= \mathbf{R}_{j\Omega} \mathbf{k_{j}}} \end{aligned}\]
Therefore, RoPE admits a direct duality with the complex-valued linear Transformer above, as also stated in the original paper.
import torch

def rope(Q, K, V):
    """
    Simple RoPE Attention (original paper)
    Q, K, V: (d, T), d must be even
    """
    d, T = Q.shape
    device, dtype = Q.device, Q.dtype
    # Original RoPE frequencies
    m = torch.arange(0, d, 2, device=device, dtype=dtype) / d
    omegas = 1.0 / (10000 ** m)                       # (d/2,)
    t = torch.arange(T, device=device, dtype=dtype)   # (T,)
    theta = omegas[:, None] * t[None, :]              # (d/2, T)
    cos, sin = torch.cos(theta), torch.sin(theta)

    def rotate(x):
        x1, x2 = x[0::2], x[1::2]
        y = torch.empty_like(x)
        y[0::2] = x1 * cos - x2 * sin
        y[1::2] = x1 * sin + x2 * cos
        return y

    Qr = rotate(Q)
    Kr = rotate(K)
    out = torch.softmax(Qr.T @ Kr, dim=-1) @ V.T
    return out

Up to this point, all of the linear models considered employ scalar or diagonal forget gates. DeltaNet instead uses a matrix-valued forget gate:

\[\mathbf{S}_t = \mathbf{S}_{t-1}\,\textcolor{blue}{\mathbf{(I- \beta_t k_t k^\top_t)}} + \mathbf{v}_t \mathbf{k}^\top_t\]
with $\beta_t=\sigma(\mathbf{w}^\top \mathbf{x_t})$ being an input-dependent scalar. The above recurrence follows a Householder-like construction, with forget gate $\textcolor{blue}{\mathbf{A}_t = \mathbf{I}- \beta_t \mathbf{k}_t \mathbf{k}^\top_t}$, resulting in the cumulative forget-gate product:
\[\prod_{s=j+1}^t \mathbf{A}_s = \prod_{s=j+1}^t \textcolor{blue}{\mathbf{(I- \beta_s k_s k^\top_s)}}\]Similar to the other positional embeddings, DeltaNet's cumulative forget gate implicitly defines a relative positional embedding. This structure can be directly incorporated into softmax-based Transformers, resulting in an attention mechanism referred to as Position encoding based on Accumulated Products of Householder(-like) Transformations (PaTH) attention:
\[\begin{aligned} & \textcolor{black}{\mathbf{S}_t} \textcolor{black}{= \mathbf{S}_{t-1}\, \textcolor{blue}{\mathbf{(I- \beta_t k_t k^\top_t)}} + \mathbf{v}_t {\mathbf{k}}_t^{\top}}, \quad \textcolor{black}{\mathbf{o}_t = \mathbf{S}_t {\mathbf{q}}_t}, \\ & \textcolor{black}{\mathbf{Att}_{ij} = \mathbf{k^\top_j}}\textcolor{blue}{\prod_{s=j+1}^i \mathbf{(I- \beta_s k_s k^\top_s)}}\textcolor{black}{\mathbf{q_i} },\quad \textcolor{black}{\text{SoftMax} (\mathbf{\textcolor{blue}{Att}})} \end{aligned}\]
Therefore, PaTH admits a direct duality with DeltaNet.
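On the recurrent side, a minimal sketch (names mine, no softmax) checks that the DeltaNet-style recurrence matches the unrolled form with the cumulative Householder(-like) product in the middle:

import torch

T, d = 5, 4
Q, K, V = torch.randn(d, T), torch.randn(d, T), torch.randn(d, T)
K = K / K.norm(dim=0, keepdim=True)       # unit-norm keys keep the gate contractive
beta = torch.rand(T)

# Recurrent form: S_t = S_{t-1} (I - beta_t k_t k_t^T) + v_t k_t^T,  o_t = S_t q_t
S = torch.zeros(d, d)
O_rec = torch.zeros(d, T)
for t in range(T):
    A_t = torch.eye(d) - beta[t] * torch.outer(K[:, t], K[:, t])
    S = S @ A_t + torch.outer(V[:, t], K[:, t])
    O_rec[:, t] = S @ Q[:, t]

# Unrolled form: o_t = sum_j ( k_j^T (prod_{s=j+1}^t A_s) q_t ) v_j
O_unr = torch.zeros(d, T)
for t in range(T):
    for j in range(t + 1):
        P = torch.eye(d)
        for s in range(j + 1, t + 1):
            P = P @ (torch.eye(d) - beta[s] * torch.outer(K[:, s], K[:, s]))
        O_unr[:, t] += (K[:, j] @ (P @ Q[:, t])) * V[:, j]

assert torch.allclose(O_rec, O_unr, atol=1e-4)

The softmax (PaTH) side is sketched next.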
import torch

def PaTH_attn(Q, K, beta):
    """
    Simple PaTH implementation
    Q, K: input (d, T)
    beta: input-dependent scalars (T,)
    """
    d, T = Q.shape
    Kt = K.transpose(-1, -2)                       # (T, d)
    # Householder-like gates H_s = I - beta_s k_s k_s^T, shape (T, d, d)
    H = torch.eye(d) - beta.reshape(-1, 1, 1) * (Kt.unsqueeze(-1) @ Kt.unsqueeze(-2))
    raw_scores = torch.full((T, T), float('-inf'))
    for i in range(T):
        for j in range(i + 1):                     # causal: j <= i
            # Product of H_s from s = j+1 to i (identity when j = i)
            P = torch.eye(d)
            for s in range(j + 1, i + 1):
                P = P @ H[s]
            # k_j^T (P q_i)
            raw_scores[i, j] = K[:, j] @ (P @ Q[:, i])
    A = torch.softmax(raw_scores, dim=-1)
    return A, raw_scores
⚠️ PaTH supports parallelization across the sequence length ($T$) and leverages an efficient implementation based on the UT transform. Further implementation details are available in the fla repository.
As DeltaNet uses Householder products as its forget gate, it lacks the sharp decay of Mamba2 and GLA; therefore, Gated DeltaNet combines the two, multiplying the Householder-like gate by an input-dependent scalar decay:
\[\mathbf{S}_t = \mathbf{S}_{t-1}\,\textcolor{red}{a_t}\,\textcolor{blue}{\mathbf{(I- \beta_t k_t k^\top_t)}} + \mathbf{v}_t \mathbf{k}^\top_t, \quad \textcolor{red}{a_t} = \sigma(\mathbf{w}^\top\mathbf{x_t})^{1/\tau}\]with $\mathbf{w}$ being a learnable weight vector; the decay term follows the scalar decay mechanism of Mamba2.
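A minimal sketch of one step of this recurrence (function and variable names are mine; the keys are normalized so that the Householder-like gate stays contractive):

import torch

def gated_deltanet_step(S, q_t, k_t, v_t, a_t, beta_t):
    """
    One step of the gated recurrence above:
        S_t = a_t * S_{t-1} (I - beta_t k_t k_t^T) + v_t k_t^T
        o_t = S_t q_t
    S: (d, d) state,  q_t, k_t, v_t: (d,),  a_t, beta_t: scalars
    """
    d = k_t.shape[0]
    A_t = torch.eye(d) - beta_t * torch.outer(k_t, k_t)   # Householder-like gate
    S = a_t * (S @ A_t) + torch.outer(v_t, k_t)
    return S, S @ q_t

# Example usage with input-dependent gates (w_a, w_b are hypothetical gate weights)
d, T = 4, 6
x = torch.randn(d, T)
w_a, w_b = torch.randn(d), torch.randn(d)
Q, K, V = torch.randn(d, T), torch.randn(d, T), torch.randn(d, T)
K = K / K.norm(dim=0, keepdim=True)

S, outputs = torch.zeros(d, d), []
for t in range(T):
    a_t, beta_t = torch.sigmoid(w_a @ x[:, t]), torch.sigmoid(w_b @ x[:, t])
    S, o_t = gated_deltanet_step(S, Q[:, t], K[:, t], V[:, t], a_t, beta_t)
    outputs.append(o_t)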
In the context of positional embeddings for softmax Transformers, using this cumulative forget gate as a positional embedding is not entirely new: it can be interpreted as a combination of PaTH and FoX, both described in the previous sections. As a result, PaTH-FoX emerges as a PE that combines an additive and a multiplicative embedding:
\[\begin{aligned} & \textcolor{black}{\mathbf{S}_t} \textcolor{black}{= \mathbf{S}_{t-1}\, \textcolor{red}{a_t}\textcolor{blue}{\mathbf{(I- \beta_t k_t k^\top_t)}} + \mathbf{v}_t {\mathbf{k}}_t^{\top}}, \quad \textcolor{black}{\mathbf{o}_t = \mathbf{S}_t {\mathbf{q}}_t}, \\ & \textcolor{black}{\mathbf{Att}_{ij} = \mathbf{k^\top_j}}\textcolor{blue}{\prod_{s=j+1}^i \mathbf{(I- \beta_s k_s k^\top_s)}}\textcolor{black}{\mathbf{q_i} },\quad \textcolor{black}{\text{SoftMax} (\mathbf{\textcolor{blue}{Att}+\textcolor{red}{D}})} , \\ & \textcolor{red} {\mathbf{D_{ij}}} \textcolor{black}{=\begin{cases} \textcolor{red}{\sum_{s=j+1}^i \log (\sigma(\mathbf{wx_s})^{1/\tau}) }& i \ge j \\ - \infty & i<j \end{cases} } \end{aligned}\]
PaTH-FoX can then be assembled from the two building blocks above:
import torch

def path_fox(Q, K, x, w, beta):
    """
    Simple PaTH-FoX implementation
    Q, K: input (d, T)
    x: input (d, T)
    w: learnable weight of FoX (d, 1)
    beta: input-dependent scalars of PaTH (T,)
    """
    D_add = fox_pe(w, x)                    # additive (FoX) part
    _, scores = PaTH_attn(Q, K, beta)       # multiplicative (PaTH) part
    Attn = scores + D_add                   # pre-softmax logits: apply softmax (and V) for outputs
    return Attn

Finally, the table below provides a compact overview, summarizing all the key dualities between linear Transformers and positional embeddings discussed above, at a single glance 🙂

| Linear Transformer | Forget gate $\mathbf{A}_t$ | Dual positional embedding |
|---|---|---|
| Vanilla linear attention | $1$ | NoPE |
| RetNet | $\gamma = \exp(-\lambda)$ | ALiBi |
| Mamba2 / GLA | $\sigma(\mathbf{Wx_t})^{1/\tau}$ (scalar / diagonal) | FoX |
| Complex-valued linear attention | $e^{i\mathbf{\Omega}}$ | RoPE |
| DeltaNet | $\mathbf{I}-\beta_t\mathbf{k}_t\mathbf{k}^\top_t$ | PaTH |
| Gated DeltaNet | $a_t(\mathbf{I}-\beta_t\mathbf{k}_t\mathbf{k}^\top_t)$ | PaTH-FoX |
If you’re interested in learning more about different variants of linear transformers, I highly recommend checking out Table 4 of
You may also enjoy the excellent blog posts by:
which provide intuitive explanations and deeper insights into the topic.
If you find this blog useful, please consider citing it:
@article{afzal2025legacy,
title={On the Legacy of Linear Transformers in Positional Embedding},
author={Afzal, Arshia},
year={2026},
url={https://arshiaafzal.github.io/blog/2026/pe/},
}