
SDPA is Entropy-Regularized Optimal Transport

Tags: attention, optimal-transport, information-geometry, transformers

Scaled dot product attention (SDPA) is usually introduced as something like this: "We compute dot products between queries and keys to measure similarity, then apply softmax to get a probability distribution." If you then prod further as to why we use softmax specifically, you would probably get a reply that "it's a differentiable approximation to argmax" or maybe that it "gives us nice gradients."

Both of these things are true, and almost certainly were the reasons that softmax was originally picked. But there's more going on. SDPA is the unique optimal solution to a well-defined mathematical problem, and the story of how we get there connects 18th-century French mathematics, modern reinforcement learning, and the geometry of probability distributions.

This post is based on a very interesting essay by Zartbot on WeChat covering an original paper by Litman, 2025. The majority of what follows is a translation of his blog post, with further expansion around explanations and connections done by me. You can find more of Zartbot's translated essays on his GitHub.


The Optimization Problem Hidden in Attention

What attention actually computes

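Let's start with the standard SDPA for a single query. We have a query $q \in \mathbb{R}^d$, keys $k_1, \ldots, k_n$, value vectors $v_1, \ldots, v_n$, and a temperature $\tau$ (the standard choice is $\tau = \sqrt{d}$).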

The attention weights are:

$$\alpha_j = \frac{\exp(q \cdot k_j / \tau)}{\sum_i \exp(q \cdot k_i / \tau)}$$

This gives you a probability distribution $\alpha$ over the keys, and we use these weights to take a weighted average of the value vectors. So far the standard story holds, but the interesting part is that these attention weights turn out to be the unique solution to the following optimization problem (we'll get into what this problem actually is next):

$$\alpha^* = \arg\min_{\alpha \in \Delta} \left[ \sum_j c_j \alpha_j - \tau H(\alpha) \right]$$

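where $c_j = -q \cdot k_j$ is the cost of attending to key $j$, $H(\alpha) = -\sum_j \alpha_j \log \alpha_j$ is the Shannon entropy, and $\Delta$ is the probability simplex (nonnegative weights that sum to 1).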

Put simply, attention finds the distribution that minimizes expected cost while maximizing entropy. This perspective of attention tells us that it's actually balancing two competing objectives:

  1. Minimize cost: We want to concentrate attention on keys with a high similarity to the query
  2. Maximize entropy: Spread the attention out, don't put all your eggs in one basket

Figure: The cost-entropy tradeoff in attention. Left: attention weights over six keys. Right: the underlying query-key similarities (scores). As temperature τ increases, attention spreads from the highest-similarity keys toward uniform. The dashed line marks 1/n; keys above this line receive more than their "fair share" of attention.

Here, the temperature parameter $\tau$ controls the tradeoff. A large $\tau$ means more entropy (smoother attention), while a small $\tau$ means lower cost (peakier attention). You can derive the softmax formula directly from this optimization problem using basic calculus. If we set up the Lagrangian with the constraint $\sum_j \alpha_j = 1$:

$$\mathcal{L} = \sum_j c_j \alpha_j + \tau \sum_j \alpha_j \log \alpha_j + \lambda\left(1 - \sum_j \alpha_j\right)$$

Then take the derivative with respect to $\alpha_j$ and set it to zero:

$$c_j + \tau(1 + \log \alpha_j) - \lambda = 0$$

Solve for $\alpha_j$:

$$\alpha_j = \exp\left(\frac{\lambda - \tau - c_j}{\tau}\right) \propto \exp(-c_j/\tau) = \exp(q \cdot k_j / \tau)$$

After normalizing to satisfy the constraint, we get exactly the softmax formula:

$$\alpha_j = \frac{\exp(q \cdot k_j / \tau)}{\sum_i \exp(q \cdot k_i / \tau)}$$

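To make the derivation concrete, here is a minimal NumPy sketch that uses random queries and keys, purely for illustration. It checks two things numerically. First, the softmax weights satisfy the stationarity condition from the Lagrangian: $c_j + \tau \log \alpha_j$ is the same for every key. Second, no randomly sampled point on the simplex achieves a lower value of the cost-minus-entropy objective.

```python
import numpy as np

def softmax(x):
    # Subtracting the max is for numerical stability; it does not change the result.
    z = np.exp(x - np.max(x))
    return z / z.sum()

def objective(alpha, c, tau):
    # sum_j c_j alpha_j - tau * H(alpha), using the convention 0 * log(0) = 0
    safe = np.where(alpha > 0, alpha, 1.0)
    neg_entropy = np.sum(alpha * np.log(safe))
    return float(c @ alpha + tau * neg_entropy)

rng = np.random.default_rng(0)
d, n = 16, 6
q = rng.normal(size=d)
K = rng.normal(size=(n, d))      # one key per row
tau = np.sqrt(d)

c = -K @ q                       # cost c_j = -q . k_j
alpha_star = softmax(-c / tau)   # closed-form solution from the Lagrangian

# Stationarity: c_j + tau * log(alpha_j) equals the same constant for every key.
stationary = c + tau * np.log(alpha_star)
print(stationary.max() - stationary.min())   # ~0 up to floating-point error

# No random point on the simplex should beat the softmax solution.
best = objective(alpha_star, c, tau)
samples = rng.dirichlet(np.ones(n), size=10_000)
gaps = np.array([objective(a, c, tau) - best for a in samples])
print(gaps.min() > 0)            # True
```

The gap for any other distribution $a$ is exactly $\tau \, \mathrm{KL}(a \,\Vert\, \alpha^*)$, which is why it can never be negative.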
Physics Sidenote

If you've seen statistical mechanics, this formula might look familiar. The Boltzmann distribution describes the probability of a system being in state $j$ with energy $E_j$ at temperature $T$:

$$P(j) \propto \exp(-E_j / k_B T)$$

Our attention formula has the same structure: $\alpha_j \propto \exp(-c_j / \tau)$, where the cost $c_j = -q \cdot k_j$ plays the role of energy. The temperature $\tau$ controls how sharply the distribution concentrates on low-energy (high-similarity) states. Both formulas arise from maximizing entropy subject to a constraint on an expected value: expected energy in physics, expected cost in SDPA.


Optimal Transport and Moving Dirt Efficiently

The optimization problem we talked about above turns out to have a name: one-sided entropic optimal transport. To understand where this comes from, let's take a brief detour to 18th-century France.

The dirt-moving problem

In 1781, French mathematician Gaspard Monge posed a deceptively simple question: suppose you have a pile of dirt and want to fill a hole of the same volume. What's the most efficient way to move the dirt?

More formally: you have a source distribution $\mu$ (the dirt pile) and a target distribution $\nu$ (the hole). You want to find a "transport plan" that moves mass from the source to the target while minimizing total transportation cost. Monge originally required each grain of dirt to go to exactly one location, a deterministic mapping. In 1942, Soviet mathematician Leonid Kantorovich relaxed this constraint and allowed dirt to be split, with fractions going to different destinations. This makes the problem convex and much easier to solve.

The output of Kantorovich's problem is a coupling $P$: a joint distribution whose marginals are $\mu$ and $\nu$. The entry $P_{ij}$ tells us how much mass flows from source $i$ to target $j$.

Two-sided vs one-sided constraints

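Standard optimal transport imposes bilateral constraints: the transport plan $P$ must satisfy

$$\sum_j P_{ij} = \mu_i \quad \text{for every source } i, \qquad \sum_i P_{ij} = \nu_j \quad \text{for every target } j.$$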

This is like a matching problem: every source must send its mass somewhere, and every target must receive its allocated amount. Finding such a plan requires iterative algorithms like Sinkhorn-Knopp, which alternates between normalizing rows and columns until convergence. Sinkhorn Attention applies this bilateral framework to transformers, requiring the attention matrix to be doubly stochastic (both rows and columns sum to 1). This enforces a form of "fairness" where every key receives equal attention across all queries.

However, standard softmax attention is different: it only constrains rows (each query's attention sums to 1). There's no requirement on columns, so some keys might receive massive attention while others receive almost none. This one-sided constraint is why softmax has a closed-form solution. Without the column constraint, the optimization problem decouples across queries, and each row can be solved independently by a simple normalization. SDPA's efficiency comes from solving a relaxed transport problem: we drop the bilateral "fairness" constraint in exchange for a closed-form solution that doesn't require iteration.

Entropic regularization: making transport smooth

Classical optimal transport has a problem for us, though: its solutions are often sparse and non-differentiable. In 2013, Cuturi showed that adding an entropy term resolves this (Cuturi, 2013):

$$\min_P \left[ \sum_{ij} C_{ij} P_{ij} - \varepsilon H(P) \right]$$

subject to the marginal constraints (rows sum to $\mu$, columns sum to $\nu$).

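This is called entropic optimal transport (EOT). The entropy term $H(P) = -\sum_{ij} P_{ij} \log P_{ij}$ makes the objective strictly convex, so the optimal plan is unique. It also pushes the plan away from sparse corner solutions.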

The optimal coupling is a rescaled Gibbs kernel: $P_{ij} = u_i K_{ij} v_j$, where $K_{ij} = e^{-C_{ij}/\varepsilon}$ and the positive scaling vectors $u, v$ are chosen to match the marginals. These scaling vectors are exactly what Sinkhorn-Knopp computes. The kernel is strictly positive: every source-target pair gets some mass, though irrelevant pairs get exponentially little. This strict positivity is what makes the solution smooth and differentiable.
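The scaling form above is what the Sinkhorn-Knopp iteration computes. Below is a minimal, unstabilized NumPy sketch. The point clouds, squared-distance cost, uniform marginals, and `eps = 1.0` are all illustrative assumptions. The iteration alternately rescales the rows and columns of the Gibbs kernel until both marginal constraints hold.

```python
import numpy as np

def sinkhorn(C, mu, nu, eps, n_iters=1000):
    """Entropic OT by Sinkhorn-Knopp scaling (minimal sketch, no log-domain stabilization)."""
    K = np.exp(-C / eps)            # Gibbs kernel, strictly positive
    u = np.ones_like(mu)
    v = np.ones_like(nu)
    for _ in range(n_iters):
        u = mu / (K @ v)            # rescale rows so they sum to mu
        v = nu / (K.T @ u)          # rescale columns so they sum to nu
    return u[:, None] * K * v[None, :]   # P = diag(u) K diag(v)

rng = np.random.default_rng(0)
x = rng.normal(scale=0.5, size=(5, 2))   # source points (the "dirt pile")
y = rng.normal(scale=0.5, size=(4, 2))   # target points (the "hole")
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)   # squared-distance cost
mu = np.full(5, 1 / 5)
nu = np.full(4, 1 / 4)

P = sinkhorn(C, mu, nu, eps=1.0)
print(P.sum(axis=1))                # approximately mu
print(P.sum(axis=0))                # approximately nu
print((P > 0).all())                # True: every pair gets some mass
```

For small `eps`, a practical implementation would work in the log domain to avoid underflow in `np.exp(-C / eps)`. This sketch skips that step.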

Attention as one-sided transport

Classical optimal transport has two marginal constraints (source $\mu$ and target $\nu$). In attention, however, the source is a single query with mass 1, and there's no constraint on how much attention each key receives. This is called one-sided optimal transport: the query distributes its attention budget over the keys, but we don't require each key to receive any particular amount. With only a source constraint and entropy regularization, the problem becomes:

$$\min_{\alpha \in \Delta} \left[ \sum_j c_j \alpha_j - \tau H(\alpha) \right]$$

That looks familiar: it is exactly our attention objective! The solution is closed-form (softmax) rather than requiring iterative algorithms like Sinkhorn.
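Here is the one-sided case in code, as a minimal NumPy sketch with random queries and keys (illustrative only). A single row normalization of the Gibbs kernel already satisfies the only constraint, so there is nothing to iterate. The result is identical to row-wise softmax, and the column sums come out to whatever they happen to be.

```python
import numpy as np

rng = np.random.default_rng(1)
n_q, n_k, d = 3, 5, 8
Q = rng.normal(size=(n_q, d))       # one query per row
K = rng.normal(size=(n_k, d))       # one key per row
tau = np.sqrt(d)

C = -(Q @ K.T)                      # cost C_ij = -q_i . k_j
G = np.exp(-C / tau)                # Gibbs kernel
P = G / G.sum(axis=1, keepdims=True)   # a single row scaling: the only constraint

# The same plan written as row-wise softmax of scaled scores.
S = Q @ K.T / tau
A = np.exp(S - S.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)

print(np.allclose(P, A))            # True
print(P.sum(axis=1))                # rows: all 1
print(P.sum(axis=0))                # columns: unconstrained
```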

mHC Sidenote

While standard SDPA uses a one-sided constraint (only row sums = 1), giving a closed-form softmax solution, the bilateral doubly stochastic constraint appears in other transformer contexts. For example, Sinkhorn attention applies it to attention matrices for "fairness" across keys. More recently, DeepSeek's mHC architecture uses the Sinkhorn-Knopp algorithm to constrain residual stream mixing matrices to be doubly stochastic, ensuring stable signal propagation across layers. In both cases, the Birkhoff polytope (doubly stochastic matrices) provides geometric structure, but for different purposes: attention allocation vs. residual stability.

Therefore, every time you compute attention weights, you're solving an optimal transport problem. The query is "moving mass" to the keys, and the softmax gives the unique transport plan that balances cost-efficiency with smoothness.


Backprop is Secretly Doing Reinforcement Learning

The forward pass solves an optimal transport problem. It turns out that backpropagation through attention is mathematically equivalent to a policy gradient algorithm from reinforcement learning.

Setting up the RL interpretation

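If we think of attention as a "policy" that chooses which keys to attend to, the mapping is as follows. The policy is the attention distribution $\alpha$, and each action is a choice of key $j$. The reward for choosing key $j$ is its marginal utility $u_j = -\partial L / \partial \alpha_j$: how much putting more attention on key $j$ would decrease the loss.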

The gradient formula

When we backpropagate through softmax attention, the gradient with respect to the pre-softmax score $s_j = q \cdot k_j$ (so that $\alpha = \mathrm{softmax}(s/\tau)$) is:

$$\frac{\partial L}{\partial s_j} = \frac{1}{\tau} \alpha_j (\bar{u} - u_j)$$

where $\bar{u} = \sum_k \alpha_k u_k$ is the expected utility under the current attention distribution.

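Let's look at the structure of this gradient. Each key's gradient is its current attention weight $\alpha_j$ times the gap between the average utility $\bar{u}$ and its own utility $u_j$, scaled by $1/\tau$. Keys the model barely attends to get small updates, and keys whose utility is exactly average get none.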

Note the sign: the gradient points away from high-utility keys (since we're computing $\partial L / \partial s_j$ and want to minimize loss). When we subtract this gradient during optimization, the sign flips and we get the familiar RL form. More on that below.
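The gradient formula is easy to check numerically. This minimal NumPy sketch assumes an illustrative squared-error loss on the attention output $o = \sum_j \alpha_j v_j$; any differentiable loss would do. It defines the utility as $u_j = -\partial L / \partial \alpha_j$ and compares $\frac{1}{\tau}\alpha_j(\bar{u} - u_j)$ against central finite differences.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

rng = np.random.default_rng(2)
n, d_v, tau = 6, 4, 2.0
s = rng.normal(size=n)             # pre-softmax scores s_j = q . k_j
V = rng.normal(size=(n, d_v))      # value vectors, one per row
target = rng.normal(size=d_v)

def loss(scores):
    # Illustrative loss: squared error between the attention output and a fixed target.
    out = softmax(scores / tau) @ V
    return 0.5 * np.sum((out - target) ** 2)

alpha = softmax(s / tau)
o = alpha @ V
u = -(V @ (o - target))            # u_j = -dL/d(alpha_j), the marginal utility of key j
u_bar = alpha @ u                  # expected utility under current attention
grad_formula = alpha * (u_bar - u) / tau

# Independent check with central finite differences.
h = 1e-6
grad_fd = np.array([(loss(s + h * e) - loss(s - h * e)) / (2 * h) for e in np.eye(n)])
print(np.max(np.abs(grad_formula - grad_fd)))   # tiny (finite-difference error only)
```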

This is exactly the REINFORCE algorithm with a baseline! REINFORCE is the foundational policy gradient method from RL (Williams, 1992). The basic idea: if you want to learn a policy (a probability distribution over actions), you update it by:

$$\nabla_\theta J \propto \mathbb{E}\left[\nabla_\theta \log \pi_\theta(a) \cdot R(a)\right]$$

Basically: increase the log-probability of actions proportionally to their reward. The problem is high variance since rewards can be noisy. The fix is to subtract a baseline (typically the expected reward under the current policy):

$$\nabla_\theta J \propto \mathbb{E}\left[\nabla_\theta \log \pi_\theta(a) \cdot (R(a) - b)\right]$$

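The term $(R(a) - b)$ is called the advantage: how much better this action is than average. This is exactly what we derived for attention. The utility $u_j$ plays the role of the reward, the expected utility $\bar{u}$ is the baseline, and $u_j - \bar{u}$ is the advantage.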

When we do gradient descent with learning rate $\eta$, we update $s_j \leftarrow s_j - \eta \frac{\partial L}{\partial s_j}$. Substituting our gradient formula:

$$s_j \leftarrow s_j - \frac{\eta}{\tau} \alpha_j (\bar{u} - u_j) = s_j + \frac{\eta}{\tau} \alpha_j (u_j - \bar{u})$$

So if key $j$ has above-average utility ($u_j > \bar{u}$), its score increases. If it has below-average utility, its score decreases.

Thus, every time we train a transformer with standard backpropagation, the attention layers are learning via policy gradients. Keys that contribute above-average utility get "rewarded" (higher attention in the future), while keys that contribute below-average utility get "penalized." To be clear, this is not just a metaphor: it is literally the same mathematics. The gradient descent update to attention scores has the exact form of a variance-reduced policy gradient, computed exactly in expectation rather than estimated from samples.
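To see the REINFORCE connection directly, the sketch below treats $\alpha = \mathrm{softmax}(s/\tau)$ as a policy. It uses NumPy, with randomly drawn hypothetical utilities held fixed. It samples keys from the policy and forms the classic score-function estimator, using the expected utility as the baseline. Averaged over many samples, the estimate matches the exact ascent direction $\frac{1}{\tau}\alpha_j(u_j - \bar{u})$, which is the negative of the backprop gradient. Backprop simply computes this expectation in closed form instead of sampling.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

rng = np.random.default_rng(3)
n, tau = 5, 1.5
s = rng.normal(size=n)        # scores
u = rng.normal(size=n)        # hypothetical per-key utilities ("rewards"), held fixed
alpha = softmax(s / tau)
u_bar = alpha @ u             # baseline: expected utility under the policy

# Exact gradient of J(s) = E_{j ~ alpha}[u_j] with respect to s (ascent direction).
exact = alpha * (u - u_bar) / tau

# Monte Carlo REINFORCE with baseline: sample keys, weight grad log-prob by advantage.
num_samples = 200_000
j = rng.choice(n, size=num_samples, p=alpha)
grad_logp = (np.eye(n)[j] - alpha) / tau   # grad_s log alpha_j for softmax(s / tau)
advantage = (u[j] - u_bar)[:, None]
estimate = (grad_logp * advantage).mean(axis=0)

print(exact)
print(estimate)               # close to `exact`, up to Monte Carlo noise
```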


The Geometry That Connects Everything

We've now been through two elegant results:

  1. The forward pass is equivalent to solving an optimal transport problem
  2. The backward pass is policy gradient reinforcement learning

The underlying connection between these two results is information geometry.

Fisher Information and the curvature of probability space

The space of probability distributions isn't flat but curved. The "natural" way to measure distances on this space uses the Fisher Information Matrix (FIM). For the attention distribution $\alpha$, the FIM is:

$$F = \text{diag}(\alpha) - \alpha \alpha^T = M$$

This matrix tells us how "sensitive" the distribution is to changes in the underlying scores. Its entries are the covariances of the score function $\nabla_\eta \log \alpha_j$, with $j$ drawn from $\alpha$. Strictly, $M$ is the FIM with respect to the natural parameters $\eta_j = s_j/\tau$. The dual potential $\psi$ introduced below is $\tau$ times a log-sum-exp of $s/\tau$, so its Hessian with respect to $s$ picks up a factor of $1/\tau$: $\nabla^2 \psi = \frac{1}{\tau}M$.
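A quick numerical check of this identity, as a minimal NumPy sketch with random logits. Here $e_j$ denotes the one-hot vector for key $j$. Computing the Fisher information as the expected outer product of the score function $\nabla_\eta \log \alpha_j = e_j - \alpha$ reproduces $\mathrm{diag}(\alpha) - \alpha\alpha^T$ exactly. The last line also shows that $M$ annihilates the all-ones vector, so $M$ is singular.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

rng = np.random.default_rng(4)
n = 5
eta = rng.normal(size=n)          # natural parameters (logits), eta_j = s_j / tau
alpha = softmax(eta)

# Fisher information w.r.t. eta: E_j[ g_j g_j^T ] with score g_j = e_j - alpha.
E = np.eye(n)
F = sum(alpha[j] * np.outer(E[j] - alpha, E[j] - alpha) for j in range(n))

M = np.diag(alpha) - np.outer(alpha, alpha)
print(np.allclose(F, M))                 # True
print(np.allclose(M @ np.ones(n), 0.0))  # True: constant shifts of the logits are invisible
```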

The connection: Hessian = Fisher Information

This unification emerges from Lagrangian duality. Lagrangian duality is a technique from optimization. Instead of solving a constrained problem directly (the "primal"), you build an equivalent "dual" problem by folding the constraints into the objective with multipliers. The dual is often easier to work with. Under certain conditions, which hold here, both problems have the same optimal value; this is called "strong duality." The primal EOT problem minimizes over distributions:

$$V_p(s) = \min_{\alpha \in \Delta} \left[ -\langle s, \alpha \rangle + \tau \sum_j \alpha_j \log \alpha_j \right]$$

This says: find the distribution $\alpha$ on the simplex that minimizes the negative expected score (i.e., maximizes expected similarity) plus an entropy penalty. The first term $-\langle s, \alpha \rangle$ rewards high scores. The second term $\tau \sum_j \alpha_j \log \alpha_j$ is negative entropy scaled by temperature, which penalizes overly concentrated distributions.

This primal value function has a useful property from the envelope theorem: its gradient is the negative optimal solution, $\nabla V_p = -\alpha^*$. That negative sign is a bit awkward, so Lagrangian duality gives us a cleaner object. When we form the dual problem (relaxing the simplex constraint with a Lagrange multiplier), the key object that comes out is:

$$\psi(s) = \tau \log \sum_j \exp(s_j / \tau)$$

This is the log-sum-exp (LSE) function: $\tau$ times the log of the softmax denominator. By strong duality, $\psi(s) = -V_p(s)$, so the awkward negative sign disappears. LSE is the negated optimal value of an entropy-regularized transport problem, and its derivatives tell the whole story.

First derivative (gradient)

$$\frac{\partial \psi}{\partial s_j} = \tau \cdot \frac{1}{\sum_k \exp(s_k/\tau)} \cdot \frac{1}{\tau}\exp(s_j/\tau) = \alpha_j \quad \checkmark$$

The gradient of LSE is the attention distribution simply as a consequence of duality. The optimal solution to the primal problem equals the gradient of the dual optimal value.
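Both claims so far can be checked in a few lines: LSE is the negated primal value, and its gradient is the attention distribution. Below is a minimal NumPy sketch with random scores. The stable `lse` helper is written here for illustration and is not a library function.

```python
import numpy as np

def lse(s, tau):
    # psi(s) = tau * log sum_j exp(s_j / tau), computed stably
    m = np.max(s / tau)
    return tau * (m + np.log(np.sum(np.exp(s / tau - m))))

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

rng = np.random.default_rng(5)
n, tau = 6, 0.7
s = rng.normal(size=n)
alpha = softmax(s / tau)

# Primal value V_p(s) = -<s, alpha*> + tau * sum_j alpha*_j log alpha*_j
V_p = -s @ alpha + tau * np.sum(alpha * np.log(alpha))
print(np.isclose(lse(s, tau), -V_p))          # True: psi = -V_p

# Gradient of psi by central differences matches the attention weights.
h = 1e-6
grad = np.array([(lse(s + h * e, tau) - lse(s - h * e, tau)) / (2 * h) for e in np.eye(n)])
print(np.allclose(grad, alpha, atol=1e-8))    # True
```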

Second derivative (Hessian)

$$\frac{\partial^2 \psi}{\partial s_j \partial s_k} = \frac{\partial \alpha_j}{\partial s_k}$$

Using the softmax Jacobian:

$$\frac{\partial \alpha_j}{\partial s_k} = \frac{1}{\tau}\alpha_j(\delta_{jk} - \alpha_k)$$

So:

$$\nabla^2 \psi = \frac{1}{\tau}(\text{diag}(\alpha) - \alpha\alpha^T) = \frac{1}{\tau}M$$

The Hessian of LSE is (proportional to) the Fisher Information Matrix, and this is why everything connects. LSE is the dual potential that generates both the forward solution (via its gradient) and the backward geometry (via its Hessian). The softmax Jacobian you compute in backprop is the Hessian of an optimal transport problem. Equivalently, the curvature of the optimal transport objective is the Fisher Information of the attention distribution.
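The second-order claim can be checked the same way. A finite-difference Hessian of $\psi$ (equivalently, the Jacobian of $\mathrm{softmax}(s/\tau)$) matches $\frac{1}{\tau}M$, is symmetric, and is positive semidefinite. Below is a minimal NumPy sketch with random scores.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

rng = np.random.default_rng(6)
n, tau = 5, 1.3
s = rng.normal(size=n)
alpha = softmax(s / tau)
M = np.diag(alpha) - np.outer(alpha, alpha)

# Hessian of psi = Jacobian of grad psi = Jacobian of softmax(s / tau), by central differences.
# Column k holds d(alpha) / d(s_k).
h = 1e-6
H = np.column_stack([(softmax((s + h * e) / tau) - softmax((s - h * e) / tau)) / (2 * h)
                     for e in np.eye(n)])

print(np.allclose(H, M / tau, atol=1e-8))     # True
print(np.allclose(H, H.T, atol=1e-8))         # symmetric, as a Hessian must be
print(np.linalg.eigvalsh(M).min() > -1e-12)   # positive semidefinite
```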

Why this isn't a coincidence

The identity $\nabla^2 \psi = \frac{1}{\tau}F$ is a general property of exponential families, not something specific to attention. The softmax distribution $\alpha_j \propto \exp(s_j/\tau)$ is an exponential family with natural parameters $\eta_j = s_j/\tau$ and log-partition function $A(\eta) = \log \sum_j \exp(\eta_j) = \psi(s)/\tau$. For any exponential family, the Hessian of the log-partition function equals the covariance of the sufficient statistics, which for categorical distributions is exactly the FIM. So the Hessian-FIM connection is guaranteed by the exponential family structure of softmax, and the OT framework gives us a new lens on why softmax takes this form.

One matrix, three identities

So our deepest result is that a single matrix appears in all three frameworks:

$$\boxed{M = \text{diag}(\alpha) - \alpha \alpha^T}$$

| Framework | Role of $M$ | Interpretation |
|---|---|---|
| Optimal Transport | $\nabla^2 \psi = \frac{1}{\tau} M$ | Hessian of dual potential, curvature of the value landscape |
| Information Geometry | $F = M$ | Fisher Information Matrix, natural metric on probability space |
| Reinforcement Learning | $\frac{\partial L}{\partial s} = -\frac{1}{\tau} M u$ | Control matrix, transforms utility into gradient |

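As the saying goes: "Once is accident, twice is a coincidence, three times is a pattern." The matrix $M$ plays three roles at once. It is the (scaled) Hessian of the dual potential, the Fisher Information Matrix of the attention distribution, and the control matrix that turns utilities into score gradients.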

Figure: Attention simplex. Softmax attention as a point on the probability simplex. Each corner represents putting all attention on one key; the center is the uniform distribution (maximum entropy). The heatmap shows entropy across the simplex. As temperature τ increases, the optimal solution α* traces a path from the low-cost corner toward the high-entropy center. The star marks the solution at standard scaling (τ = √d).

These are three descriptions of the same geometric object. The optimization landscape is the statistical manifold is the control geometry. And backpropagation through softmax computes $M u$ (up to the factor $-1/\tau$), which is simultaneously:

  1. A Newton-like step using the optimization Hessian
  2. The natural gradient on probability space
  3. An advantage-weighted policy gradient

Why backprop implements natural gradients

This is a bit of a technical note, so feel free to skip it; it's a subtle but interesting point. The natural gradient is defined as:

$$\tilde{\nabla} L = F^{-1} \nabla L$$

It's the direction of steepest descent in the information geometry of the probability space, not Euclidean space. Natural gradient methods like K-FAC explicitly compute and invert the FIM. For attention, the standard gradient with respect to scores is:

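$$\frac{\partial L}{\partial s} = -\frac{1}{\tau} M u = -\frac{1}{\tau} F u$$

Thus, the natural gradient would be:

$$F^{-1} \frac{\partial L}{\partial s} = F^{-1} \cdot \left(-\frac{1}{\tau} F u\right) = -\frac{1}{\tau} u$$

The natural gradient is proportional to the marginal utility vector. Strictly, $F$ is singular because $F\mathbf{1} = 0$, so $F^{-1}$ must be read as a pseudoinverse. The equality then holds up to an additive constant, which softmax ignores anyway. This means each score moves in proportion to its own marginal utility, with no extra weighting by the current attention $\alpha_j$.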

This is why the advantage-based form emerges automatically. The softmax Jacobian is the FIM (with scaling), so backprop through softmax automatically applies the information-geometric correction that RL algorithms (like natural policy gradient) have to compute explicitly.
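Numerically, because $F$ is singular, the natural gradient has to go through a pseudoinverse. This minimal NumPy sketch uses random scores and hypothetical utilities, with `np.linalg.pinv` standing in for $F^{-1}$. It confirms that the result is $-u/\tau$ up to a constant shift. A constant shift of all scores leaves the softmax unchanged.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

rng = np.random.default_rng(7)
n, tau = 5, 2.0
s = rng.normal(size=n)
u = rng.normal(size=n)                 # hypothetical marginal utilities
alpha = softmax(s / tau)
F = np.diag(alpha) - np.outer(alpha, alpha)

grad = -(F @ u) / tau                  # dL/ds = -(1/tau) M u, with u_j = -dL/d(alpha_j)
nat = np.linalg.pinv(F, 1e-10) @ grad  # F is singular (F @ 1 = 0), so use the pseudoinverse

# Up to an additive constant (which softmax ignores), the natural gradient is -u / tau.
expected = -(u - u.mean()) / tau
print(np.allclose(nat, expected))      # True
```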


A Design Framework for New Attention Mechanisms

The most practical consequence of this theory is that it gives us a principled way to design new attention mechanisms. Instead of guessing and checking, we could:

  1. Choose an objective function with a different regularizer
  2. Derive the optimal solution
  3. Get a new attention mechanism with known properties

The variational framework

The general form is:

$$\alpha^* = \arg\min_{\alpha \in \Delta} \left[ -\langle s, \alpha \rangle + \Omega(\alpha) \right]$$

where $\Omega$ is a convex regularizer. Different choices of $\Omega$ give different mechanisms:

| Regularizer $\Omega(\alpha)$ | Attention Mechanism | Properties |
|---|---|---|
| $-\tau H(\alpha)$ (negative entropy) | Softmax | Dense, smooth, closed-form |
| $\frac{1}{2}\lVert\alpha\rVert^2$ (L2 norm) | Sparsemax | Sparse (exactly zero weights) |
| Negative Tsallis entropy | $\alpha$-entmax | Tunable sparsity |
| $-\tau H(\alpha) + m\sum_j \lvert i-j \rvert \alpha_j$ | ALiBi-like | Locality bias |
| $\tau \cdot \mathrm{KL}(\alpha \,\Vert\, \pi)$ | PriorSoftmax | Incorporates prior beliefs |
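To see the difference between the first two rows of the table, here is a minimal NumPy sketch of sparsemax with $\tau = 1$ and made-up scores. It uses the standard sort-and-threshold algorithm from Martins and Astudillo (2016). Softmax keeps every key strictly positive, while sparsemax zeroes out the low-scoring keys exactly.

```python
import numpy as np

def sparsemax(z):
    """Euclidean projection of z onto the probability simplex (sort-and-threshold)."""
    z_sorted = np.sort(z)[::-1]
    cumsum = np.cumsum(z_sorted)
    k = np.arange(1, len(z) + 1)
    support = 1 + k * z_sorted > cumsum       # sorted entries that stay nonzero
    k_max = k[support][-1]
    threshold = (cumsum[k_max - 1] - 1) / k_max
    return np.maximum(z - threshold, 0.0)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

scores = np.array([2.0, 1.5, 0.1, -1.0, -2.0])
print(softmax(scores))    # every entry strictly positive
print(sparsemax(scores))  # 0.75 and 0.25 on the top two keys, exact zeros elsewhere
```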

Example: Deriving ALiBi from first principles

ALiBi (Attention with Linear Biases) is a position encoding method that adds a distance-based penalty to attention logits. In the original paper, it was introduced as a heuristic. However, in this variational framework, it emerges naturally. If we add a linear position penalty to the entropy regularizer:

$$\Omega(\alpha) = -\tau H(\alpha) + m \sum_j |i - j| \alpha_j$$

The solution is:

$$\alpha_j \propto \exp\left(\frac{s_j - m|i-j|}{\tau}\right)$$

which is exactly ALiBi! The position penalty in the regularizer becomes an additive bias in the logits. The framework explains why ALiBi works: it's the optimal transport plan when you care not just about similarity, but also about distance.

PriorSoftmax: attention as Bayesian inference

The most conceptually rich variant uses KL divergence to a prior as the regularizer:

$$\alpha^* = \arg\min_{\alpha \in \Delta} \left[ -\langle s, \alpha \rangle + \tau \cdot \mathrm{KL}(\alpha \,\Vert\, \pi) \right]$$

where $\pi$ is a prior distribution over keys. The solution is:

$$\alpha_j = \frac{\pi_j \exp(s_j / \tau)}{\sum_k \pi_k \exp(s_k / \tau)}$$

This has a nice Bayesian interpretation. The prior $\pi_j$ is multiplied by a likelihood-like term $\exp(s_j/\tau)$ and renormalized, so $\alpha$ plays the role of the posterior over which key is relevant.

Standard softmax is the special case where the prior is uniform ($\pi_j = 1/n$). This reveals something about ALiBi as well: its additive position bias $-m|i-j|$ in the logits is equivalent to a log-prior:

$$\log \pi_j = -m|i-j| / \tau + \text{const}$$

In other words, ALiBi encodes a prior belief that nearby keys are more relevant. The "heuristic" position penalty is principled Bayesian reasoning.
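The equivalence between ALiBi and PriorSoftmax is a one-liner in code. In this minimal NumPy sketch, the slope `m`, the query position, and the scores are hypothetical. The biased softmax and the prior-weighted softmax with $\pi_j \propto \exp(-m|i-j|/\tau)$ agree exactly, and a uniform prior recovers plain softmax.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def prior_softmax(s, prior, tau):
    # alpha_j proportional to pi_j * exp(s_j / tau), computed in log space
    return softmax(np.log(prior) + s / tau)

rng = np.random.default_rng(8)
n, tau, m = 8, 2.0, 0.5       # m: hypothetical ALiBi slope for this head
i = n - 1                     # query position (last token, causal setting)
s = rng.normal(size=n)        # scores s_j = q . k_j for keys j = 0..i
dist = np.abs(i - np.arange(n))

alibi = softmax((s - m * dist) / tau)     # ALiBi: additive bias in the logits
prior = softmax(-m * dist / tau)          # pi_j proportional to exp(-m |i - j| / tau)
bayes = prior_softmax(s, prior, tau)

print(np.allclose(alibi, bayes))          # True
print(np.allclose(softmax(s / tau), prior_softmax(s, np.full(n, 1 / n), tau)))  # True
```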


Implications for Transformer Design

Based on this framework and some other research such as the Physics of Language Models (PoLM) series by Ziming Liu, Zeyuan Allen-Zhu, and collaborators, we can put together some implications for understanding and designing transformers. Note that this section is getting more into hypotheses than concrete math.

Why linear attention loses something fundamental

Linear attention methods (like those in Mamba or linear transformers) approximate the softmax kernel:

$$\exp(q \cdot k / \tau) \approx \phi(q) \cdot \phi(k)$$

for some feature map $\phi$, which enables $O(n)$ complexity instead of $O(n^2)$. This approximation destroys the Gibbs kernel structure that defines the EOT solution. Recall that the Gibbs kernel is strictly positive, guaranteeing the optimal solution lives in the simplex interior where attention weights are always nonzero (just exponentially small for irrelevant keys). Feature map kernels can be zero, so the solution no longer satisfies the transport constraints at all. The entropy-cost tradeoff that makes softmax attention principled simply does not exist in the linear approximation.

The EOT framework predicts that linear attention should fail when you need precise "transport": when it matters exactly which queries attend to which keys. This matches empirical findings from PoLM Part 4.1 (Allen-Zhu, 2025) that linear attention struggles with multi-hop reasoning tasks where information needs to be retrieved precisely across multiple steps.

Why depth matters more than width for reasoning

Each attention layer solves one optimal transport problem, one routing decision about where to send information. Multi-hop reasoning requires composing these decisions: first retrieve fact A, then use A to look up fact B, then use B to answer the question. Width (more heads, larger dimensions) increases the capacity of each decision, but depth increases the number of sequential decisions. You cannot parallelize "use the output of step 1 as input to step 2" by making step 1 wider.

This aligns with empirical findings from PoLM Part 2.1, which shows that deeper models outperform wider models on reasoning tasks even with similar parameter counts. Their probing experiments make this concrete: shallower layers correctly identify dependencies close to the query, while deeper layers are required for dependencies further away. The model performs layer-by-layer reasoning, recursively building up the dependency graph across depth.

The dispersion problem: why attention struggles with length

The EOT framework predicts a fundamental limitation. DeepMind's paper 'Softmax is Not Enough' (Veličković et al., 2024) proves:

If inputs have bounded range and temperature $\tau$ is constant, then as the number of keys $n \to \infty$, every attention weight must approach $1/n$.

Figure: Attention dispersion with sequence length. Even when one key has much higher similarity than the distractors (3.0 vs 0.5), its attention weight decays as the number of keys grows. This is the 'softmax is not enough' phenomenon: with bounded scores and fixed temperature, attention necessarily dilutes over long contexts.

In other words, attention necessarily disperses with sequence length. No matter how relevant one key is, with enough keys in the context its attention weight gets diluted. Why does this happen? The softmax denominator $\sum_j \exp(s_j/\tau)$ grows with $n$. Suppose scores are bounded, say in $[-B, B]$. Then the numerator is at most $\exp(B/\tau)$, and each of the $n$ terms in the denominator is at least $\exp(-B/\tau)$. So each weight is at most $\exp(2B/\tau)/n$, which vanishes as $n \to \infty$.

For transformers, this is concerning. The inputs to attention (after layer norm) typically have a bounded range, so in theory attention should degrade on very long sequences, and empirically it does.
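The dispersion effect is easy to reproduce. This minimal NumPy sketch uses the same setup as the figure: one key scoring 3.0 and all distractors scoring 0.5. It assumes $\tau = 1$ and prints the weight on the relevant key next to the bound $\exp(2B/\tau)/n$ with $B = 3$. For small $n$ the bound exceeds 1 and is vacuous; it only bites once $n$ is large.

```python
import numpy as np

def relevant_weight(n, tau=1.0, strong=3.0, distractor=0.5):
    # One relevant key scoring `strong`, plus n - 1 distractors scoring `distractor`.
    scores = np.full(n, distractor)
    scores[0] = strong
    z = np.exp((scores - scores.max()) / tau)
    return z[0] / z.sum()

B, tau = 3.0, 1.0                    # every score lies in [-B, B]
sizes = [10, 100, 1_000, 10_000, 100_000]
weights = [relevant_weight(n, tau) for n in sizes]
for n, w in zip(sizes, weights):
    bound = np.exp(2 * B / tau) / n
    print(f"n={n:>7,}  weight={w:.5f}  bound={bound:.5f}")
```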

Temperature selection

One possible solution to the problem described above is adaptive temperature: compute the entropy of the attention distribution and adjust $\tau$ to maintain a target entropy level. In the EOT framework, this means dynamically adjusting the cost-entropy tradeoff based on how many keys you're choosing among. This connects to some recent work on long-context attention (like Ring Attention, landmark tokens, etc.); they're all fighting the same underlying mathematical constraint.

The standard $\tau = \sqrt{d}$ scaling ensures that $q \cdot k / \tau$ has roughly unit variance regardless of dimension, assuming the query and key components are roughly independent with unit variance. In the EOT framework, this means the cost scale stays constant, so the entropy-cost tradeoff remains stable. The theory suggests that temperature should perhaps scale with sequence length too. Longer sequences mean more keys, so you might want weaker entropy regularization (a lower $\tau$) to keep attention from becoming too diffuse. This is an open research direction as well.
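A quick sanity check of the $\sqrt{d}$ scaling, as a minimal NumPy sketch. It assumes i.i.d. standard-normal query and key components. The raw dot product's variance grows like $d$, while the scaled dot product's variance stays near 1.

```python
import numpy as np

rng = np.random.default_rng(9)
num_pairs = 5_000
scaled_var = {}
for d in [16, 64, 256, 1024]:
    # Assumption: i.i.d. standard-normal query and key components.
    q = rng.normal(size=(num_pairs, d))
    k = rng.normal(size=(num_pairs, d))
    dots = np.einsum("ij,ij->i", q, k)        # one dot product per (q, k) pair
    scaled_var[d] = (dots / np.sqrt(d)).var()
    print(f"d={d:>4}  var(q.k)={dots.var():7.1f}  var(q.k/sqrt(d))={scaled_var[d]:.3f}")
```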


Connections to Mechanistic Interpretability

This is some of my own thinking, and may be entirely wrong, but maybe it provokes thinking on the bridge between this theoretical framework and empirical interpretability research.

What the EOT framework says about attention head circuits

In A Mathematical Framework for Transformer Circuits (Elhage et al., 2021), Anthropic decomposes attention heads into two separable operations:

  1. QK circuit: $W_E^T W_{QK}^h W_E$ (where Q is the query projection and K is the key projection). This circuit determines which tokens attend to which, i.e., the attention pattern.
  2. OV circuit: $W_U W_{OV}^h W_E$ (where O is the output projection and V is the value projection). This circuit determines what happens when a token is attended to (the effect on logits).

The EOT framework gives a new perspective on this decomposition. The QK circuit defines the cost function: the attention score $s_j = q \cdot k_j$ is the negative cost $-c_j$ in the transport problem. So the QK circuit is really specifying "how expensive is it to route information from key $j$ to this query?" The optimal transport plan (the attention weights) then minimizes this cost subject to entropy regularization. The OV circuit defines the "cargo": once the transport plan is determined, what actually gets moved is governed by the OV circuit. In OT terms, the QK circuit determines the transport plan and the OV circuit determines what gets transported. This separation is exactly the structure of optimal transport: first solve for the optimal plan, then execute the transport.

Induction heads as learned transport policies

The In-context Learning and Induction Heads paper (Olsson et al., 2022) identified induction heads, attention heads that implement the pattern $[A][B] \ldots [A] \to [B]$ (finding previous occurrences of the current token and predicting what came next). These require:

  1. A "copying" OV circuit (positive eigenvalues, attended tokens boost their own logits)
  2. A "prefix matching" QK circuit via K-composition with a previous-token head

From the EOT perspective, the model has learned a content-dependent cost function. The QK circuit doesn't just measure raw similarity, through K-composition it measures "is this position preceded by a token matching my current token?" This is a structured, algorithmic cost function that emerges from training. The induction head's attention pattern is still the optimal transport solution, but for a learned, compositional cost function that encodes temporal structure.

Why circuits are "separable": an OT explanation

Another key insight from the paper is that you can "freeze" attention patterns and study the OV circuit independently; the logits then become a linear function of tokens. The EOT framework explains why this works: the softmax attention weights are a deterministic function of the scores (the unique EOT solution). The QK circuit produces scores, softmax turns them into the plan, and then OV acts. Once you fix the scores (freeze the attention pattern), the OV circuit acts linearly.

Attention sinks: a necessary consequence of the simplex constraint

Attention sinks are an emergent phenomenon where models allocate disproportionate attention to the first tokens (especially BOS), even when semantically irrelevant. The EOT framework explains not just why it happens, but why it must happen due to the one-sided constraint.

In EOT, the source must send all its mass somewhere: each row must sum to 1. There's no option for "I don't want to transport anything this round." The query must always attend to something. This creates a problem when the optimal action is "don't attend to anything meaningful," because the softmax/EOT solution space doesn't include "no transport."

Attention sinks as learned low-cost dump destinations. The model learns to make certain tokens (BOS, early tokens) have consistently low "cost", not because they're semantically relevant, but because:

  1. They're always available (visible to all subsequent tokens in causal attention)
  2. Their OV circuit contribution can be made nearly null (attending to them doesn't corrupt the output)
  3. They absorb the "excess" probability mass that must go somewhere

Why evicting breaks streaming. In sliding window attention, removing sink tokens removes the low-cost dump destination. The excess mass that must go somewhere now gets forced onto semantically meaningful recent tokens, corrupting the attention pattern. This explains, for example, why StreamingLLM (Xiao et al., 2023) must permanently retain the first few "sink" tokens even as the window slides.

The variational framework predicts this. When all costs $c_j$ are similar with no clear semantic match, entropy maximization would spread attention uniformly. But if one token has learned to be consistently "cheap" (low $c_j$), it absorbs disproportionate mass. The model has learned to exploit the EOT objective by creating designated low-cost destinations for mandatory (but unwanted) attention mass.
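The mechanism can be illustrated with a toy NumPy example. Everything here is a hypothetical setup, not a model of any real network. Content scores are nearly identical (no clear match), and an optional sink token has a high score and a zero value vector. With the sink present, most of the mandatory mass goes to it. The output is then just the no-sink output, scaled down by the mass left over for the content tokens.

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

rng = np.random.default_rng(11)
n, d_v, tau = 10, 4, 1.0
content_scores = rng.normal(scale=0.1, size=n)   # no clear semantic match
content_values = rng.normal(size=(n, d_v))

# Without a sink, the row constraint forces all mass onto content tokens.
a_plain = softmax(content_scores / tau)
out_plain = a_plain @ content_values

# Hypothetical sink: consistently high score ("cheap"), zero value vector ("empty cargo").
sink_score, sink_value = 5.0, np.zeros(d_v)
a_sink = softmax(np.concatenate([[sink_score], content_scores]) / tau)
out_sink = a_sink @ np.vstack([sink_value, content_values])

print(f"mass on sink: {a_sink[0]:.3f}")
# The sink soaks up mass without adding anything, shrinking the content contribution.
print(np.allclose(out_sink, (1 - a_sink[0]) * out_plain))   # True
```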

Composition and multi-step transport

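The three types of composition (Q-, K-, V-composition) from Elhage et al. (2021) correspond to different ways of building complex transport:

  1. Q-composition: an earlier head's output feeds the query, changing what the source is looking for and therefore every cost $c_j$ it faces.
  2. K-composition: an earlier head's output feeds the keys, changing how cheap each destination is to reach.
  3. V-composition: an earlier head's output feeds the values, changing what gets transported rather than where it goes.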

Multi-layer transformers are solving sequential transport problems where each layer's cost function can depend on previous layers' transport decisions, which is much richer than single-step OT.
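A toy two-layer sketch of this dependence, with hand-picked hypothetical weights: layer 1 writes into a dimension of the residual stream that layer 2 reads both its queries and keys from (so this is Q- and K-composition at once). Ablating layer 1's output changes layer 2's costs and therefore its transport plan; here it collapses position 0's plan to uniform.

```python
import math

def softmax(scores):
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [x / z for x in w]

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attention_layer(xs, WQ, WK, WV):
    # Each position solves its own one-sided EOT problem with costs c_j = -q·k_j
    qs = [matvec(WQ, x) for x in xs]
    ks = [matvec(WK, x) for x in xs]
    vs = [matvec(WV, x) for x in xs]
    plans = [softmax([dot(q, k) for k in ks]) for q in qs]
    outs = [[sum(a * v[d] for a, v in zip(p, vs)) for d in range(len(vs[0]))] for p in plans]
    return plans, outs

def add(xs, ys):
    # Residual-stream update: x + attention output
    return [[a + b for a, b in zip(x, y)] for x, y in zip(xs, ys)]

I = [[1.0, 0.0], [0.0, 1.0]]
READ_DIM1 = [[0.0, 0.0], [0.0, 1.0]]        # projects onto the dimension layer 1 writes to
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # toy residual stream: 3 positions, 2 dims

# Layer 1 (hypothetical weights): writes 2 * (attended dim 0) into dim 1
_, out1 = attention_layer(xs, I, I, [[0.0, 0.0], [2.0, 0.0]])

# Layer 2 reads queries and keys from dim 1, so its costs depend on layer 1's transport
plans_composed, _ = attention_layer(add(xs, out1), READ_DIM1, READ_DIM1, I)
plans_ablated, _ = attention_layer(xs, READ_DIM1, READ_DIM1, I)

print([round(a, 3) for a in plans_composed[0]])
print([round(a, 3) for a in plans_ablated[0]])
```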

Multi-head attention as multi-marginal transport

The question of whether MHA can be represented as multi-marginal transport was also posed in Zartbot's essay, so I took a stab at a hypothesis for it. Standard MHA runs $h$ independent one-sided EOT problems in parallel. Recall our core result from earlier: each query's attention weights solve $\alpha^* = \arg\min_{\alpha \in \Delta} \left[ \sum_j c_j \alpha_j - \tau H(\alpha) \right]$. In MHA, each head $k$ has its own score vector $s^{(k)} = q^{(k)} K^{(k)\top}$, so each head independently solves its own version of this problem with no coupling between heads. But in OT theory, multi-marginal transport considers something richer: coupling $S$ distributions simultaneously.

The formulation looks like this. Multi-marginal OT couples $S$ histograms $(a_s)_{s=1}^S$ by solving:

$$\min_{P \in U(a_s)_s} \sum_{i_1, \ldots, i_S} C_{i_1, \ldots, i_S} P_{i_1, \ldots, i_S}$$

where $P$ is now a tensor (not a matrix) with $S$ indices, and the constraint set requires each marginal to match. The key difference from standard OT: multi-marginal OT finds a joint coupling across all $S$ distributions simultaneously, not $S$ independent pairwise couplings.
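To see what "a joint coupling across all $S$ distributions" means computationally, here is a tiny entropic multi-marginal problem ($S = 3$, $n = 3$) solved with the multi-marginal generalization of Sinkhorn: alternately rescale each axis of the Gibbs tensor $\exp(-C/\varepsilon)$ until that axis's marginal matches. The cost (spread of the three chosen points around their mean), $\varepsilon$, and the marginals are made up for illustration. Note that $P$ is stored as a full table of $n^S$ entries.

```python
import itertools
import math

def multimarginal_sinkhorn(C, marginals, eps=1.0, iters=300):
    # C maps index tuples (i_1, ..., i_S) to costs; marginals is a list of S histograms.
    S = len(marginals)
    sizes = [len(a) for a in marginals]
    K = {idx: math.exp(-c / eps) for idx, c in C.items()}   # Gibbs tensor
    u = [[1.0] * n for n in sizes]                           # one scaling vector per axis

    def scaled(idx, skip=None):
        # Entry of K ⊙ (u_1 ⊗ ... ⊗ u_S), optionally leaving out axis `skip`
        w = K[idx]
        for t in range(S):
            if t != skip:
                w *= u[t][idx[t]]
        return w

    for _ in range(iters):
        for s in range(S):
            # Marginal along axis s with every scaling except u_s applied
            m = [0.0] * sizes[s]
            for idx in K:
                m[idx[s]] += scaled(idx, skip=s)
            u[s] = [a / mi for a, mi in zip(marginals[s], m)]
    return {idx: scaled(idx) for idx in K}

def marginal(P, axis, n):
    m = [0.0] * n
    for idx, p in P.items():
        m[idx[axis]] += p
    return m

n, S = 3, 3
points = [0.0, 1.0, 2.0]
C = {}
for idx in itertools.product(range(n), repeat=S):
    xs = [points[i] for i in idx]
    mean = sum(xs) / S
    C[idx] = sum((x - mean) ** 2 for x in xs)   # hypothetical "spread" cost

marginals = [[0.5, 0.3, 0.2], [0.2, 0.3, 0.5], [1 / 3, 1 / 3, 1 / 3]]
P = multimarginal_sinkhorn(C, marginals)
print(len(P), [round(x, 4) for x in marginal(P, 0, n)])
```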

If we treated MHA as multi-marginal EOT, we'd want a joint tensor $P \in \mathbb{R}^{n \times n \times \cdots \times n}$ ($h$ factors) where $P_{j_1, j_2, \ldots, j_h}$ represents the probability that head 1 attends to key $j_1$, head 2 attends to key $j_2$, and so on, jointly. The cost tensor could include interaction terms beyond the per-head costs:

$$C_{j_1, \ldots, j_h} = \sum_k c^{(k)}_{j_k} + \text{interaction terms}$$

Those interaction terms are where it gets interesting. You could penalize redundancy (heads attending to the same key) or reward coordination (heads attending to semantically related keys).
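One observation falls out of the one-sided version of this formulation, where (as for a single query) only the total mass of $P$ is constrained: the minimizer of $\langle C, P \rangle - \tau H(P)$ is $P \propto \exp(-C/\tau)$, so with purely additive per-head costs the joint plan factorizes into the product of the per-head softmaxes, which is exactly what independent heads already compute. The sketch below checks this for $h = 2$, then adds a hypothetical redundancy penalty $\lambda$ whenever both heads pick the same key, which breaks the factorization and moves mass off the diagonal. The penalty form and all numbers are our own assumptions.

```python
import math

def plan(costs, tau=1.0):
    # One-sided EOT plan over a flat list of options: p ∝ exp(-cost / tau)
    m = min(costs)
    w = [math.exp(-(c - m) / tau) for c in costs]
    z = sum(w)
    return [x / z for x in w]

def joint_plan(c1, c2, lam=0.0, tau=1.0):
    # Joint plan over key pairs (j1, j2); only the total mass is constrained (one-sided)
    n = len(c1)
    C = [c1[a] + c2[b] + (lam if a == b else 0.0) for a in range(n) for b in range(n)]
    flat = plan(C, tau)
    return [flat[a * n:(a + 1) * n] for a in range(n)]

c1 = [0.0, 1.0, 2.0]   # hypothetical per-head costs, head 1
c2 = [0.5, 0.0, 1.5]   # hypothetical per-head costs, head 2

P_indep = joint_plan(c1, c2)                  # additive costs only
a1, a2 = plan(c1), plan(c2)                   # independent per-head softmaxes
product = [[x * y for y in a2] for x in a1]

P_coupled = joint_plan(c1, c2, lam=2.0)       # hypothetical redundancy penalty
diag_indep = sum(P_indep[j][j] for j in range(3))
diag_coupled = sum(P_coupled[j][j] for j in range(3))
print(round(diag_indep, 3), round(diag_coupled, 3))
```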

The mathematical obstacle

The problem is dimensionality. For $h$ heads and $n$ keys, the joint tensor $P$ has $n^h$ entries. Even with entropic regularization and Sinkhorn, this is intractable for typical values ($h = 8$, $n = 512$ gives $512^8 = 2^{72} \approx 4.7 \times 10^{21}$ entries). However, Peyré & Cuturi do note that special cost structures can reduce this: when the cost decomposes nicely, the computation reduces from $O(n^S)$ to $O(Sn^2)$ via successive matrix-vector multiplications. Whether attention costs have such structure is an open question.
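A sketch of both halves of this paragraph in plain Python. The entry count is exact integer arithmetic. For the decomposition, we assume (purely for illustration) a chain-structured cost in which each axis only interacts with the next, $C = \sum_s g(i_s, i_{s+1})$; the first marginal of the Gibbs tensor then comes from $S - 1$ matrix-vector products instead of enumerating all $n^S$ entries. The kernel, $\varepsilon$, and sizes are hypothetical.

```python
import itertools
import math

h, n_keys = 8, 512
entries = n_keys ** h
print(entries, f"{entries:.2e}")   # exact count of joint-tensor entries

def first_marginal_bruteforce(K, S):
    # O(n^S): enumerate every entry of the Gibbs tensor prod_s K[i_s][i_{s+1}]
    n = len(K)
    m = [0.0] * n
    for idx in itertools.product(range(n), repeat=S):
        w = 1.0
        for s in range(S - 1):
            w *= K[idx[s]][idx[s + 1]]
        m[idx[0]] += w
    return m

def first_marginal_chain(K, S):
    # O(S n^2): sum out i_S, ..., i_2 with S - 1 successive matrix-vector products
    n = len(K)
    msg = [1.0] * n
    for _ in range(S - 1):
        msg = [sum(K[a][b] * msg[b] for b in range(n)) for a in range(n)]
    return msg

n, S, eps = 4, 5, 2.0
# Hypothetical pairwise cost g(a, b) = (a - b)^2 between consecutive axes
K = [[math.exp(-((a - b) ** 2) / eps) for b in range(n)] for a in range(n)]
brute = first_marginal_bruteforce(K, S)
chain = first_marginal_chain(K, S)
print([round(x, 4) for x in brute], [round(x, 4) for x in chain])
```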

Head specialization as a constraint

We know from mechanistic interpretability that heads specialize: some do induction, some do previous-token attention, some do positional patterns, some become attention sinks, etc. This specialization suggests heads are solving different problems that would each have different cost functions, not coordinating on the same one. Multi-marginal OT would force coordination, which might actually hurt by preventing the diversity that makes MHA effective. Any viable multi-marginal formulation may need to preserve this ability to specialize while adding coordination only where it helps.

A more tractable alternative: Wasserstein barycenters

The (Wasserstein-2) barycenter of $S$ distributions with weights $\lambda_s \ge 0$ summing to 1 is:

$$\alpha^* = \arg\min_{\alpha} \sum_{s=1}^S \lambda_s W_2^2(\alpha, \alpha^{(s)})$$

This finds a single "consensus" distribution that's close to all heads, and with entropic regularization it is computable via Sinkhorn-style iterations. For MHA, this would mean: instead of concatenating head outputs, compute their barycenter in attention space to get a single attention pattern representing a "geometric average" of what all heads want to attend to.
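A minimal sketch of the entropic barycenter, using the iterative-scaling (Bregman projection) scheme for barycenters described in Peyré & Cuturi's book. Two assumptions are ours: the heads' attention patterns are hypothetical bumps over 11 key positions, and the ground cost between positions is squared distance on a grid rescaled to $[0, 1]$. The barycenter puts its mass between the two heads, whereas the plain average of the patterns stays bimodal.

```python
import math

def entropic_barycenter(hists, weights, grid, eps=0.1, iters=1000):
    # Entropic W2^2 barycenter by iterative scaling (Bregman projections).
    # hists: S histograms on `grid`; weights: lambda_s, summing to 1.
    n = len(grid)
    K = [[math.exp(-((x - y) ** 2) / eps) for y in grid] for x in grid]

    def k_dot(vec):
        # K is symmetric for a squared-distance cost, so K v == K^T v
        return [sum(K[i][j] * vec[j] for j in range(n)) for i in range(n)]

    v = [[1.0] * n for _ in hists]
    b = [1.0 / n] * n
    for _ in range(iters):
        # Match each coupling's input marginal to its head's pattern
        u = [[h_i / kv_i for h_i, kv_i in zip(h, k_dot(v_s))] for h, v_s in zip(hists, v)]
        ktu = [k_dot(u_s) for u_s in u]
        # Shared output marginal: weighted geometric mean across heads
        b = [math.exp(sum(w * math.log(ktu[s][i]) for s, w in enumerate(weights)))
             for i in range(n)]
        v = [[b_i / k_i for b_i, k_i in zip(b, ktu[s])] for s in range(len(hists))]
    return b

def bump(grid, center, width=0.01):
    # Hypothetical smooth attention pattern peaked at `center`
    w = [math.exp(-((x - center) ** 2) / width) for x in grid]
    z = sum(w)
    return [x / z for x in w]

grid = [i / 10 for i in range(11)]   # 11 key positions rescaled to [0, 1]
head1 = bump(grid, 0.2)
head2 = bump(grid, 0.8)

bary = entropic_barycenter([head1, head2], [0.5, 0.5], grid)
euclid = [(a + b) / 2 for a, b in zip(head1, head2)]
print([round(x, 3) for x in bary])
print([round(x, 3) for x in euclid])
```

The entropic term blurs the result (larger `eps` means a wider barycenter), which is the usual price for Sinkhorn-style speed.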

A hybrid conjecture

This connects back to our earlier discussion of depth. Perhaps the right level of analysis isn't within a layer but across layers:

  1. Within a layer: Heads remain independent (different cost functions, encouraging specialization)
  2. Across layers: The residual stream aggregates head outputs, and this accumulation, rather than any single layer's attention, is where "coordination" emerges

Standard MHA's concatenation and linear projection is a linear combination of independently transported values. Wasserstein barycenters would be nonlinear, computed via optimization rather than matrix multiplication. The question of whether MHA's simplicity is a feature (efficiency, permitting specialization) or a limitation (no principled coordination) remains open. A concrete test: replace the linear projection with entropic barycenter aggregation and measure whether this helps or hurts on tasks requiring cross-head coordination. My suspicion is that independent heads are the right choice within a layer, and that coordination emerges across layers through the residual stream rather than within layers through explicit coupling.


The Bigger Picture

Let's take a step back and reflect on what we've covered.

Softmax attention is overdetermined

We've seen that softmax attention is the answer to multiple independent questions:

  1. Optimization: What distribution minimizes cost while maximizing entropy?
  2. Statistics: What's the maximum entropy distribution given expected similarity?
    • Even physics: What's the Boltzmann distribution with energy = negative similarity?
  3. Learning: What produces advantage-based policy gradients under backprop?

These questions come from different fields, but all give the same answer, which suggests something more fundamental.
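The three answers can be checked numerically on a single query. In this plain-Python sketch (scores, $\tau$, and per-key "rewards" are hypothetical), the softmax plan attains a lower EOT objective than 1,000 random points on the simplex, its optimal value equals $-\tau \log Z$ (the log-partition function, i.e. the dual potential), and the gradient of the expected reward with respect to the scores has the advantage form $\alpha_j (r_j - \mathbb{E}_\alpha[r]) / \tau$, verified against central finite differences.

```python
import math
import random

def softmax(scores, tau):
    m = max(scores)
    w = [math.exp((s - m) / tau) for s in scores]
    z = sum(w)
    return [x / z for x in w]

def eot_objective(alpha, costs, tau):
    # sum_j c_j alpha_j - tau * H(alpha), with H(alpha) = -sum_j alpha_j log alpha_j
    entropy = -sum(a * math.log(a) for a in alpha if a > 0)
    return sum(c * a for c, a in zip(costs, alpha)) - tau * entropy

tau = 0.7
scores = [1.2, -0.3, 0.5, 2.0]      # hypothetical q·k_j values
costs = [-s for s in scores]        # c_j = -q·k_j (Boltzmann view: energy E_j = c_j)
alpha = softmax(scores, tau)

# 1) Optimization: the softmax plan beats random points on the simplex
rng = random.Random(0)
best = eot_objective(alpha, costs, tau)
trials = []
for _ in range(1000):
    w = [rng.random() for _ in scores]
    total = sum(w)
    trials.append(eot_objective([x / total for x in w], costs, tau))

# 2) Statistics/physics: the optimal value is -tau * log Z (log-partition function)
log_z = math.log(sum(math.exp(s / tau) for s in scores))
psi = -tau * log_z

# 3) Learning: d/ds_j E_alpha[r] = alpha_j * (r_j - E_alpha[r]) / tau, an advantage
rewards = [0.3, -1.0, 2.0, 0.1]     # hypothetical per-key rewards
baseline = sum(a * r for a, r in zip(alpha, rewards))
analytic = [a * (r - baseline) / tau for a, r in zip(alpha, rewards)]

def expected_reward(s):
    return sum(a * r for a, r in zip(softmax(s, tau), rewards))

h = 1e-6
numeric = []
for j in range(len(scores)):
    up = scores[:]
    up[j] += h
    dn = scores[:]
    dn[j] -= h
    numeric.append((expected_reward(up) - expected_reward(dn)) / (2 * h))

print(best, psi, min(trials))
```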

The unreasonable effectiveness of softmax

This helps explain why transformers work so well. Softmax attention is the unique solution to a natural optimization problem that balances exploration (entropy) and exploitation (similarity), and its learning dynamics follow the geometry of probability distributions.

Other attention mechanisms can work too, but they're solving different problems. The framework tells you exactly what tradeoffs you're making, e.g., Sparsemax prioritizes sparsity over smoothness, linear attention prioritizes speed over optimality, and so on.
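For example, Sparsemax (Martins & Astudillo, 2016) keeps the same linear cost term but replaces the entropy regularizer with $\tfrac{1}{2}\lVert\alpha\rVert^2$: it is the Euclidean projection of the scores onto the simplex, since $\lVert\alpha - s\rVert^2 = \lVert\alpha\rVert^2 - 2 s \cdot \alpha + \text{const}$. The sketch below (hypothetical scores) shows the tradeoff: softmax gives every key positive weight, while sparsemax assigns exact zeros.

```python
import math

def softmax(z):
    m = max(z)
    w = [math.exp(x - m) for x in z]
    s = sum(w)
    return [x / s for x in w]

def sparsemax(z):
    # Euclidean projection of z onto the simplex:
    # argmin_{p in simplex} ||p - z||^2  ==  argmin  -z·p + 0.5 * ||p||^2
    zs = sorted(z, reverse=True)
    cumsum, k, cum_k = 0.0, 0, 0.0
    for i, zi in enumerate(zs, start=1):
        cumsum += zi
        if 1 + i * zi > cumsum:      # z_(i) is still in the support
            k, cum_k = i, cumsum
    thresh = (cum_k - 1) / k
    return [max(x - thresh, 0.0) for x in z]

scores = [2.0, 1.5, 0.1, -1.0]       # hypothetical q·k_j
soft = softmax(scores)
sp = sparsemax(scores)
print([round(p, 3) for p in soft])
print([round(p, 3) for p in sp])
```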

Quick reference table

| Optimal Transport | Information Geometry | RL/Control |
|---|---|---|
| Cost $c_j$ | Negative log-likelihood | Negative reward |
| Entropy regularization $\tau$ | Temperature / FIM scale | Exploration parameter |
| Dual potential $\psi$ | Log-partition function | Value function |
| Transport plan $\alpha$ | Exponential family mean params | Policy |

What remains open

Several questions remain open; these are again taken from the original essay:

  1. Multi-head geometry: See the discussion above. Full multi-marginal OT coordination between heads appears both intractable ($n^h$ tensor entries) and potentially counterproductive (it may prevent beneficial specialization), but barycenter-based aggregation of head outputs remains unexplored.

  2. Position encoding: RoPE rotates queries and keys, implicitly defining a position-dependent similarity metric. What cost function does this correspond to in the EOT framework?

  3. Training dynamics: The FIM governs learning geometry. Can we use this to predict which layers will learn fastest, or design better optimizers for attention?

  4. Scaling laws: Does the optimal transport perspective shed light on scaling laws? What happens to the "cost landscape" as model size grows?
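Two small numerical handles on questions 2 and 3, in plain Python with hypothetical vectors; neither answers the question. For RoPE, we use the common rotary convention (pairs of dimensions rotated by $\text{pos} \cdot 10000^{-2i/d}$) and read the cost as $c = -q \cdot k$ after rotation: it then depends only on the offset between query and key positions. For training dynamics, with $\tau = 1$ the Fisher information of the attention distribution with respect to the scores is $\operatorname{diag}(\alpha) - \alpha\alpha^\top$, the same matrix as the softmax Jacobian, which is one concrete place where the FIM enters the backward pass.

```python
import math

# --- Question 2: under RoPE, the EOT cost depends only on relative position ---
def rope(vec, pos, base=10000.0):
    # Rotate pairs (x_{2i}, x_{2i+1}) by angle pos * base**(-2i/d)
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out += [vec[i] * c - vec[i + 1] * s, vec[i] * s + vec[i + 1] * c]
    return out

def rope_cost(q, k, q_pos, k_pos):
    # Cost c = -q·k after rotating the query to q_pos and the key to k_pos
    return -sum(a * b for a, b in zip(rope(q, q_pos), rope(k, k_pos)))

q = [1.0, 0.0, 0.5, 0.2]     # hypothetical query
k = [0.3, 0.8, -0.4, 1.0]    # hypothetical key
c_base = rope_cost(q, k, 5, 2)
c_shifted = rope_cost(q, k, 13, 10)   # same offset, both positions shifted by 8
c_other = rope_cost(q, k, 5, 3)       # different offset

# --- Question 3: Fisher information of the attention distribution (tau = 1) ---
def softmax(s):
    mx = max(s)
    w = [math.exp(x - mx) for x in s]
    z = sum(w)
    return [x / z for x in w]

scores = [0.4, -1.2, 2.0]    # hypothetical scores
alpha = softmax(scores)
n = len(alpha)
# FIM = E_{j ~ alpha}[(e_j - alpha)(e_j - alpha)^T], built from score-function outer products
fim = [[sum(alpha[j] * (float(a == j) - alpha[a]) * (float(b == j) - alpha[b]) for j in range(n))
        for b in range(n)] for a in range(n)]
closed_form = [[alpha[a] * float(a == b) - alpha[a] * alpha[b] for b in range(n)] for a in range(n)]

# Softmax Jacobian by central differences: the same matrix
h = 1e-6
jac = [[0.0] * n for _ in range(n)]
for b in range(n):
    up = scores[:]
    up[b] += h
    dn = scores[:]
    dn[b] -= h
    pu, pd = softmax(up), softmax(dn)
    for a in range(n):
        jac[a][b] = (pu[a] - pd[a]) / (2 * h)

print(c_base, c_shifted, c_other)
```

With a temperature, the same computation picks up a factor of $1/\tau$ in the Jacobian and $1/\tau^2$ in the FIM, matching the "Temperature / FIM scale" row of the table above.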


References

Primary Source

Litman, E. (2025). Scaled-Dot-Product Attention as One-Sided Entropic Optimal Transport. arXiv preprint arXiv:2508.08369. https://arxiv.org/abs/2508.08369

This Post Builds On

渣B (Zartbot). (2025, August 21). 大模型时代的数学基础(9): SDPA、最优传输、强化学习与信息几何的联系 [Mathematical Foundations in the Era of Large Models (9): The Connection Between SDPA, Optimal Transport, Reinforcement Learning, and Information Geometry]. WeChat Official Account. https://mp.weixin.qq.com/s?__biz=MzUxNzQ5MTExNw==&mid=2247494688&idx=1&sn=3d589f6d4be56ee372d5db4f8631b0cc

Optimal Transport

Cuturi, M. (2013). Sinkhorn Distances: Lightspeed Computation of Optimal Transport. Advances in Neural Information Processing Systems 26 (NeurIPS 2013). https://arxiv.org/abs/1306.0895

Peyré, G. & Cuturi, M. (2019). Computational Optimal Transport. Foundations and Trends in Machine Learning, 11(5-6), 355-607. https://arxiv.org/abs/1803.00567

Information Geometry

Amari, S. (1998). Natural Gradient Works Efficiently in Learning. Neural Computation, 10(2), 251-276.

Reinforcement Learning

Williams, R.J. (1992). Simple Statistical Gradient-Following Algorithms for Connectionist Reinforcement Learning. Machine Learning, 8(3-4), 229-256.

Physics of Language Models

Allen-Zhu, Z. (2025). Physics of Language Models: Part 4.1, Architecture Design and the Magic of Canon Layers. NeurIPS 2025. arXiv:2512.17351. https://arxiv.org/abs/2512.17351

Ye, T., Xu, Z., Li, Y., & Allen-Zhu, Z. (2024). Physics of Language Models: Part 2.1, Grade-School Math and the Hidden Reasoning Process. ICLR 2025. https://arxiv.org/abs/2407.20311

Mechanistic Interpretability

Elhage, N., Nanda, N., Olsson, C., et al. (2021). A Mathematical Framework for Transformer Circuits. Transformer Circuits Thread, Anthropic. https://transformer-circuits.pub/2021/framework/index.html

Olsson, C., Elhage, N., Nanda, N., et al. (2022). In-context Learning and Induction Heads. Transformer Circuits Thread, Anthropic. https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html

Attention and Long Context

Xiao, G., Tian, Y., Chen, B., Han, S., & Lewis, M. (2023). Efficient Streaming Language Models with Attention Sinks. ICLR 2024. arXiv:2309.17453. https://arxiv.org/abs/2309.17453

Veličković, P., Perivolaropoulos, C., Barbero, F., & Pascanu, R. (2024). Softmax is Not Enough (for Sharp Out-of-Distribution Generalization). arXiv preprint arXiv:2410.01104. https://arxiv.org/abs/2410.01104

Transformer Architecture

Xie, Z., Wei, Y., Cao, H., et al. (2025). mHC: Manifold-Constrained Hyper-Connections. arXiv preprint arXiv:2512.24880. https://arxiv.org/abs/2512.24880