Multi-Head Attention

transformers
deep learning
attention
aiayn
How and why transformers learn multiple distinct attention subspaces simultaneously, with a full PyTorch implementation.
Published

March 31, 2026

Written by Michael Gethers

Introduction

The purpose of this document is to detail the concepts behind and implementation of Multi-Head Attention.

This is the second document in a series that I am writing on the Attention Is All You Need (Vaswani et al., 2017) paper. The intent of this series is both to solidify my own understanding of the paper’s core concepts and to help others build a practical understanding of the hows and whys of the transformer architecture.

The first document in this series covered scaled dot-product attention. The reader is invited to refer to this document as needed, as multi-head attention depends upon scaled dot-product attention, which will not be covered at length here.

Note that multi-head attention is applied to both the encoder and decoder of the Transformer. We will primarily be discussing the encoder in this document. The mechanics of multi-head attention are identical in the encoder and decoder, with the only difference being the inputs used.

The intuition

In the document on scaled dot-product attention, we discussed its general purpose, which is to allow the Transformer to learn how tokens in a sequence should influence each other given the context in which they are used.

The purpose of multi-head attention is to take the matrix of vector representations of a particular token sequence (which we’ll call \(X\) in this writeup), and map those vector representations to \(n_{\mathrm{heads}}\) different learned geometric subspaces that are optimized for different semantic relationships between tokens.

Conceptually, what those \(n_{\mathrm{heads}}\) represent are different learned ways of applying attention: one head may learn to focus on pronouns and their referents, one may learn to focus on noun/verb interactions, another may look for adjectives applied to nouns. The subspaces are optimized for these different comprehension tasks.

After passing through the attentional sublayers in these \(n_{\mathrm{heads}}\), the outputs are synthesized back into a single matrix, which now contains a richer representation of each token, built from the context of every other token in the sequence.

Limitations of a single head

A single attention head allows for the creation of a particular subspace for a particular attentional task.  The query, key, and value matrices in that attentional head are optimized for that task, and the resulting subspaces are specifically tailored to it.

But we have already identified many different tasks that an attention head could be learning (noun/verb interactions, etc.).  The question is whether a single attention head could generalize to these different tasks.

Let’s think about that in the context of our example sentence:

The creek was cold and the spunky dog wanted to jump over it.

Let’s look at the words “creek”, “cold”, and “it”, and examine them in context. There are two clauses in this sentence: “the creek was cold”, and “the spunky dog wanted to jump over it”.

  • “creek”: subject of the first clause
  • “cold”: adjective modifying “creek”
  • “it”: object pronoun, referring to “creek”

The words “cold” and “it” have qualitatively different relationships to the word “creek”, but “creek” is fundamental to a contextual understanding of both of them. A single learned \(Q\) matrix cannot simultaneously encode “I am a pronoun looking for my referent” and “I am an adjective looking for my noun”, because a single \(Q\) matrix maps every token into one shared vector space where similarity means one thing. Maximizing the dot products of the query vectors of “cold” and “it” with the key vector of “creek” would require those query vectors to be similar to each other, but they are different enough queries that forcing them into the same geometric region would corrupt both.

This is the fundamental motivation behind multi-head attention: to allow the model to learn which types of relationships are most important for sequence comprehension, and create those separate subspaces where those relationships can be properly encoded.

The function

The model “decides” organically which subspaces it wants to create, and the number of heads, i.e. subspaces, to include is a hyperparameter of the model.

The process by which these subspaces are created is defined by the following multi-head attention function, which is taken directly from the paper:

\[ \begin{align} \mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head}_1, ..., \mathrm{head}_h)W^O \\ \text{where }\mathrm{head}_i &= \mathrm{Attention}(QW_{i}^{Q}, KW_{i}^{K}, VW_{i}^{V}) \end{align} \] In the above, the linear projections are parameter matrices: \[ W_{i}^{Q} \in \mathbb{R}^{d_\mathrm{model} \times d_k}\\ W_{i}^{K} \in \mathbb{R}^{d_\mathrm{model} \times d_k}\\ W_{i}^{V} \in \mathbb{R}^{d_\mathrm{model} \times d_v}\\ W^{O} \in \mathbb{R}^{hd_v \times d_\mathrm{model}}\\ \]

The inputs to the function

The multi-head attention function takes 3 separate inputs, \(Q\), \(K\), and \(V\).

Functionally, however, in the encoder, \(Q\), \(K\), and \(V\) are all the same \(T \times d_{\mathrm{model}}\) matrix, where \(T\) is the number of tokens in the sequence, and \(d_{\mathrm{model}}\) is a parameter of the model determining the size of the vectors that represent each token. \[ Q = K = V = X \in \mathbb{R}^{T \times d_{\mathrm{model}}} \] This \(X\) matrix has one row for every token in the sequence, and each row represents the learned vector representation of that token.

As a dummy example, the \(X\) matrix for our above sentence might look something like this: \[ \begin{array}{c|ccccccc} \text{token} & d_1 & d_2 & d_3 & d_4 & d_5 & d_6 & \cdots & d_{\mathrm{model}} \\ \hline \text{the} & 0.1 & 0.0 & 0.0 & 0.2 & 0.0 & 0.1 & \cdots & 0.0 \\ \text{creek} & 0.0 & 0.2 & 0.1 & 0.0 & 0.3 & 0.5 & \cdots & 0.2 \\ \text{was} & 0.1 & 0.1 & 0.0 & 0.0 & 0.0 & 0.2 & \cdots & 0.1 \\ \text{cold} & 0.0 & 0.3 & 0.2 & 0.1 & 0.0 & 0.4 & \cdots & 0.2 \\ \text{and} & 0.1 & 0.0 & 0.1 & 0.0 & 0.1 & 0.0 & \cdots & 0.1 \\ \text{the} & 0.1 & 0.0 & 0.0 & 0.2 & 0.0 & 0.1 & \cdots & 0.0 \\ \text{spunky} & 0.2 & 0.1 & 0.3 & 0.0 & 0.2 & 0.1 & \cdots & 0.0 \\ \text{dog} & 0.3 & 0.2 & 0.1 & 0.0 & 0.1 & 0.2 & \cdots & 0.1 \\ \text{wanted} & 0.0 & 0.1 & 0.2 & 0.3 & 0.1 & 0.0 & \cdots & 0.2 \\ \text{to} & 0.1 & 0.0 & 0.0 & 0.1 & 0.0 & 0.1 & \cdots & 0.0 \\ \text{jump} & 0.2 & 0.3 & 0.1 & 0.0 & 0.2 & 0.1 & \cdots & 0.3 \\ \text{over} & 0.0 & 0.1 & 0.2 & 0.1 & 0.3 & 0.2 & \cdots & 0.1 \\ \text{it} & 0.1 & 0.2 & 0.1 & 0.0 & 0.2 & 0.3 & \cdots & 0.2 \\ \end{array} \]

Linear transformations

This \(X\) matrix gets linearly transformed in \(3h\) different ways: once for each of \(Q\), \(K\), and \(V\), for each of the \(h\) heads.

This produces a Query matrix, a Key matrix, and a Value matrix for each head. This is the mechanism that allows for the creation of different geometric subspaces that encode different relationships between tokens in the text.

The paper (and most ML literature) uses the simplified notation of \(XW_{i}^{X}\) to signify these linear transformations. However, a linear transformation generally (though not necessarily) includes a bias term: the input matrix gets multiplied by a learned weight matrix \(W\), and a learned bias vector \(b\) is added to each row. In the first head, this would be represented like so: \[ XW_{1}^{X} + b_{1}^{X}\\ \text{where } W_{1}^{X} \in \mathbb{R}^{d_\mathrm{model} \times d_k}\\ \text{and } b_{1}^{X} \in \mathbb{R}^{d_k}\\ \]

Each of the \(3h\) weight matrices has dimensions \(d_{\mathrm{model}} \times d_k\), so the resulting transformation is a projection of the input matrix from \(d_{\mathrm{model}}\) space down to \(d_k\) space.
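As a concrete sketch of a single head’s query projection (using the paper’s base values of \(d_{\mathrm{model}} = 512\) and \(h = 8\), so \(d_k = 64\); the sequence length is illustrative), this is just an nn.Linear from \(d_{\mathrm{model}}\) to \(d_k\):

```python
import torch
import torch.nn as nn

T, d_model, h = 13, 512, 8   # illustrative sequence length; paper's base widths
d_k = d_model // h           # 64

X = torch.randn(T, d_model)      # one row per token
w_q_1 = nn.Linear(d_model, d_k)  # W_1^Q (with its bias b_1^Q)
q_1 = w_q_1(X)                   # X W_1^Q + b_1^Q

print(q_1.shape)  # torch.Size([13, 64])
```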

On dimensionality reduction

There is a legitimate question as to why this dimensionality reduction takes place at this stage. Another approach could reasonably have been to keep projections in \(d_{\mathrm{model}}\) space, concatenate up to \(T \times hd_{\mathrm{model}}\) and later use the \(W^O\) transformation to project back down to \(d_{\mathrm{model}}\) space.

The reason for projecting down to \(d_k\) at this stage is primarily computational efficiency, combined with empirical validation. Scaled dot-product attention costs \(O(T^2 \cdot d_k)\) per head. If each head operated in \(d_{\mathrm{model}}\) dimensions instead of \(d_k\), that per-head cost becomes \(O(T^2 \cdot d_{\mathrm{model}})\), and across all \(h\) heads we’ve multiplied our compute by \(h\) with no architectural benefit.

Additionally, there is a parameter count argument. By projecting from \(d_{\mathrm{model}}\) to \(d_k\), we have: \(3*h*d_{\mathrm{model}}*d_k\) projection parameters before scaled dot-product attention, and \(d_{\mathrm{model}} * d_{\mathrm{model}}\) projection parameters in \(W^O\).

With \(d_k = d_{\mathrm{model}}/h\), that becomes: \[ \begin{align} &= 3*h*d_{\mathrm{model}}*d_k + d_{\mathrm{model}} * d_{\mathrm{model}}\\ &= 3*h*d_{\mathrm{model}}*\frac{d_{\mathrm{model}}}{h} + d_{\mathrm{model}}^{2}\\ &= 3d_{\mathrm{model}}^2 + d_{\mathrm{model}}^2 \\ &= 4d_{\mathrm{model}}^2 \end{align} \]

Notably, our parameter count is not a function of \(h\). If we did not reduce dimensionality (so \(d_k = d_{\mathrm{model}}\), and \(W^O\) is \(hd_{\mathrm{model}} \times d_{\mathrm{model}}\)), that becomes: \[ \begin{align} &= 3*h*d_{\mathrm{model}}*d_k + hd_{\mathrm{model}} * d_{\mathrm{model}} \\ &= 3*h*d_{\mathrm{model}}*d_{\mathrm{model}} + hd_{\mathrm{model}} * d_{\mathrm{model}} \\ &= 3hd_{\mathrm{model}}^2 + hd_{\mathrm{model}}^2\\ &= 4hd_{\mathrm{model}}^2\\ \end{align} \] Without the reduction in dimensionality, the total parameter count for projections would increase linearly with \(h\).
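This arithmetic is easy to verify directly. The following sanity check (using the paper’s base values of \(d_{\mathrm{model}} = 512\) and \(h = 8\), and counting weights only, ignoring biases) confirms both results:

```python
d_model, h = 512, 8
d_k = d_model // h

# With the reduction (d_k = d_model / h):
# 3 * h projections of d_model x d_k, plus W^O at d_model x d_model.
reduced = 3 * h * d_model * d_k + d_model * d_model
assert reduced == 4 * d_model ** 2        # independent of h

# Without the reduction (d_k = d_model, W^O is (h * d_model) x d_model):
full = 3 * h * d_model * d_model + h * d_model * d_model
assert full == 4 * h * d_model ** 2       # grows linearly with h
```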

Nevertheless, it is worth noting that the dimensionality reduction is a design choice with a computational motivation, not a mathematical necessity.

Scaled dot-product attention

After these initial \(3h\) linear projections down to \(d_k\) dimensions, we apply scaled dot-product attention to each of our \(h\) heads, which each have their own \(Q\), \(K\), and \(V\) matrices.

\[ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]

The mechanics of scaled dot-product attention have been covered at length already in this series, so I’ll not cover it extensively here.

But this is the mechanism by which the Transformer applies its learned understanding of how each token in the sequence contributes to the understanding of each other token, and the understanding of the sequence as a whole.

Concatenation

The output of scaled dot-product attention at every head is a single \(T \times d_k\) matrix representing a weighted blend of each token’s contextual meaning (represented as a vector of length \(d_k\)).

These \(h\) matrices are concatenated column-wise, which results in a single \(T \times hd_k = T \times d_{\mathrm{model}}\) matrix that contains the output of all heads.
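This concatenation can be sketched in a few lines (dimensions are illustrative):

```python
import torch

T, h, d_k = 13, 8, 64   # illustrative values
d_model = h * d_k       # 512

# One T x d_k output matrix per head ...
heads = [torch.randn(T, d_k) for _ in range(h)]

# ... concatenated column-wise into a single T x (h * d_k) = T x d_model matrix.
concat = torch.cat(heads, dim=-1)

assert concat.shape == (T, d_model)
# Columns i*d_k through (i+1)*d_k of the result are head i's output, unchanged.
assert torch.equal(concat[:, 2 * d_k:3 * d_k], heads[2])
```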

The final linear projection of \(W^O\)

After attention is applied to each head, and the outputs are concatenated into a single \(T \times d_{\mathrm{model}}\) matrix, we have one matrix that contains information from every head: meaningful information about what the tokens, and the sequence as a whole, mean in context.

But after concatenation, that information is currently contained in relative isolation: yes it exists within the same matrix, but the different heads have not had any ability to inform each other about what they know of the tokens’ contextual meaning.

This is where we apply our final linear transformation in multi-head attention: the multiplication of our concatenated \(T \times d_{\mathrm{model}}\) matrix by the learned \(d_{\mathrm{model}} \times d_{\mathrm{model}}\) \(W^O\) matrix.

This final transformation is the “mixing matrix” that allows the model to combine signals across heads into a single unified representation. Each output token’s final vector can now reflect coreference information, syntactic relationships, and whatever the other heads learned, simultaneously, in a single \(d_{\mathrm{model}}\) vector.

Implementation

We can implement multi-head attention using PyTorch as follows:

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()

        ## Static components
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        ## Learnable components
        self.w_q = nn.Linear(self.d_model, self.d_model)
        self.w_k = nn.Linear(self.d_model, self.d_model)
        self.w_v = nn.Linear(self.d_model, self.d_model)
        self.w_o = nn.Linear(self.d_model, self.d_model)
    
    def forward(self, X, mask=None):
        '''
        X: torch.Tensor of shape (B, T, d_model), where d_model matches the
        in_features used in nn.Linear
        mask: boolean tensor of shape (B, T, T) (True = block attention), or None
        '''
        B = X.shape[0]
        T = X.shape[-2]
        
        # Start with the d_model x d_model transform
        q_full = self.w_q(X)
        k_full = self.w_k(X)
        v_full = self.w_v(X)

        # Now reshape to spread those out over n_heads
        q = q_full.view(B, T, self.n_heads, self.d_k).permute(0,2,1,3)
        k = k_full.view(B, T, self.n_heads, self.d_k).permute(0,2,1,3)
        v = v_full.view(B, T, self.n_heads, self.d_k).permute(0,2,1,3)

        # Now do scaled dot product attention
        if isinstance(mask, torch.Tensor):
            mask = mask.unsqueeze(1)
        sdpa = scaled_dot_product_attention(q, k, v, mask=mask)
                
        # Now concatenate
        concat = sdpa.permute(0,2,1,3).reshape(B, T, self.d_model)

        # Apply the final transform with w_o
        output = self.w_o(concat)
        
        return output

Where scaled_dot_product_attention() is the function defined in our scaled dot-product implementation writeup.

We’re going to walk through this implementation line by line, as there are some interesting and non-obvious elements that are worth explaining in more detail.

__init__

Before doing anything in our __init__ function, recognize that MultiHeadAttention is a subclass of nn.Module. As such, immediately call super().__init__() to initialize the parent class, ensuring that nn.Module’s internal state is properly set up within our new object.

Static components

Our MultiHeadAttention class will take two arguments upon initialization: d_model (\(d_{\mathrm{model}}\)) and n_heads (\(h\)). We give these defaults of 512 and 8 respectively, to match the Transformer paper.

These are static values, and we simply save them as attributes of our class:

## Static components
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads

Note that we are also defining self.d_k here, which is simply d_model/n_heads, as previously discussed.

Learnable components

Now, let’s look at our learnable components:

## Learnable components
self.w_q = nn.Linear(self.d_model, self.d_model)
self.w_k = nn.Linear(self.d_model, self.d_model)
self.w_v = nn.Linear(self.d_model, self.d_model)
self.w_o = nn.Linear(self.d_model, self.d_model)

We have defined four weight matrices here, one for \(W^Q\), \(W^K\), and \(W^V\), and then our final \(W^O\) mixing matrix. But looking closely at these definitions, they do not immediately appear to map to the dimensions we outlined above. We expect our \(W^Q\), \(W^K\), and \(W^V\) matrices to project our \(T \times d_{\mathrm{model}}\) matrix \(X\) from \(d_{\mathrm{model}}\) dimensions down to \(d_k\). So why are these linear transformations all \(d_{\mathrm{model}}\) to \(d_{\mathrm{model}}\)?

Recall that we intend to have \(h\) distinct \(W^Q\), \(W^K\), and \(W^V\) matrices, each \(d_{\mathrm{model}} \times d_k\). To do this directly, we would need to register \(3h\) different transformations just for our initial linear projections, and we would have to do this dynamically to respond to the user’s input for \(h\).

But this is messy and computationally inefficient, and there is a much cleaner way to do it.

Instead of registering \(h\) different transformations for each of \(Q\), \(K\), and \(V\), recognize that by our definition of \(d_k\), it is simply \(d_{\mathrm{model}} / h\), which is to say that a single \(d_{\mathrm{model}} \times d_{\mathrm{model}}\) weight matrix contains exactly as many parameters as \(h\) distinct \(d_{\mathrm{model}} \times d_k\) weight matrices.

Using that insight, for each of \(Q\), \(K\), and \(V\), we can apply a single \(d_{\mathrm{model}} \to d_{\mathrm{model}}\) transformation, producing one \(T \times d_{\mathrm{model}}\) output matrix, and simply reshape that output to \(h \times T \times d_k\) in our forward pass, effectively creating the \(h\) different \(T \times d_k\) matrices in a cleaner and more efficient manner. This reshape occurs in our MultiHeadAttention.forward function, which we will discuss next.
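To see that the fused projection really is equivalent, here is a small check (with illustrative dimensions): slicing the fused weight matrix, along with the matching bias slice, reproduces head \(i\)’s individual \(d_{\mathrm{model}} \to d_k\) projection exactly:

```python
import torch
import torch.nn as nn

T, d_model, h = 5, 512, 8   # illustrative sequence length; paper's base widths
d_k = d_model // h          # 64

X = torch.randn(T, d_model)
w_q = nn.Linear(d_model, d_model)   # one fused projection covering all heads

q_full = w_q(X)                                # (T, d_model)
q = q_full.view(T, h, d_k).permute(1, 0, 2)    # (h, T, d_k)

# nn.Linear stores its weight as (out_features, in_features), so rows
# i*d_k:(i+1)*d_k of w_q.weight (plus the matching bias slice) are exactly
# head i's d_model -> d_k projection.
i = 3
W_i = w_q.weight[i * d_k:(i + 1) * d_k, :]
b_i = w_q.bias[i * d_k:(i + 1) * d_k]
assert torch.allclose(q[i], X @ W_i.T + b_i, atol=1e-5)
```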

forward

You’ll notice that in the very first line of our forward function, we define B. B is our batch variable: it defines the number of sequences we train our model on concurrently. It is used strictly for computational efficiency, but it is not at all fundamental to the way multi-head attention works, and can make the core concepts slightly more difficult to understand. For now, we will omit B, and add it back in as a final step at the end.

T is the number of tokens in our input sequence, and is the number of rows in our input matrix. It can be obtained simply by taking the second-to-last dimension of X:

T = X.shape[-2]

Note that without batching, this would be equivalent to T = X.shape[0]; when we include batching later, it will be T = X.shape[1]. For this reason, we simply use -2 as our shape index, which is correct in either case.

Initial \(d_{\mathrm{model}} \times d_{\mathrm{model}}\) transforms

We first perform our transformations from \(d_{\mathrm{model}}\) space to \(d_{\mathrm{model}}\) space. As discussed above, this is done for tidiness and efficiency, and it is an equivalent transformation to the \(h\) distinct transformations of \(X\) as outlined in the paper.

These transformations are performed via calls to the self.w_q, self.w_k, and self.w_v nn.Linear modules we defined in initialization:

q_full = self.w_q(X)
k_full = self.w_k(X)
v_full = self.w_v(X)

Reshaping to \(h \times T \times d_k\)

Conceptually, reshaping our three \(T \times d_{\mathrm{model}}\) matrices into three \(h \times T \times d_k\) matrices is quite simple. We want to split our matrix up into \(h\) sections column-wise, where each section contains \(d_k\) columns, and use each of those \(T \times d_k\) matrices as the inputs to our \(h\) heads.

This ends up being a very simple reshape in PyTorch as well, but it does require some understanding of how reshaping works in PyTorch.

Ultimately, the reshape is done like this:

# Now reshape to spread those out over n_heads
q = q_full.view(T, self.n_heads, self.d_k).permute(1,0,2)
k = k_full.view(T, self.n_heads, self.d_k).permute(1,0,2)
v = v_full.view(T, self.n_heads, self.d_k).permute(1,0,2)

I’ve used a view + permute approach here. This solution is clean and efficient, and I’ve written a separate document (The Mechanics of Reshaping in Multi-Head Attention) which outlines exactly how this two-step solution works. I would highly encourage you to read through that document if the rationale behind this transformation is not self-evident, as it gets into the weeds of how PyTorch stores data and what views/reshapes are actually doing in practice.

Otherwise, we proceed from here to scaled dot-product attention.

Performing scaled dot product attention

Now that we have our \(Q\), \(K\), and \(V\) input matrices, each in the shape of \((h, T, d_k)\), we can pass them into the scaled dot product attention function we created previously, which looks like this:

import math
import torch

def scaled_dot_product_attention(Q, K, V, mask=None):
    '''
    Q: torch.Tensor with shape (..., T, d_k)
    K: torch.Tensor with shape (..., T, d_k)
    V: torch.Tensor with shape (..., T, d_v)
    mask: boolean tensor broadcastable to (..., T, T), or None
    Any leading dimensions (e.g. batch, heads) are carried through unchanged.
    '''
    d_k = Q.shape[-1]
    C = Q @ K.transpose(-1, -2)          # raw attention scores: (..., T, T)
    divisor = math.sqrt(d_k)
    scaled = C / divisor                 # scale by sqrt(d_k)
    C_masked = scaled
    if isinstance(mask, torch.Tensor):
        C_masked = scaled.masked_fill(mask, float('-inf'))
    softmaxed = torch.softmax(C_masked, dim=-1)

    final = softmaxed @ V
    return final

We pass our q, k, and v variables in directly as arguments:

sdpa = scaled_dot_product_attention(q, k, v)

The result, sdpa, is an \(h \times T \times d_k\) tensor.

Concatenation

We now reassemble the output of our attention heads into a single \(T \times d_{\mathrm{model}}\) matrix:

concat = sdpa.permute(1,0,2).reshape(T, self.d_model)

This uses the same concept we used to split our \(T \times d_{\mathrm{model}}\) matrix into heads, but in reverse. We have an \(h \times T \times d_k\) tensor. If we were to reshape directly into \(T \times d_{\mathrm{model}}\) (i.e. without the permute), the first row of our concatenated matrix would be the first \(d_{\mathrm{model}}\) elements in our sdpa tensor’s memory, which is the first \(d_{\mathrm{model}}\) elements of our first output head. This is not what we need.

Instead, we want to reorder our sdpa output so that it has shape \(T \times h \times d_k\) (exactly the opposite of the reordering we did previously). When this is made contiguous, the elements are in the right order for our final reshape: after permute(1,0,2), the first \(d_{\mathrm{model}}\) elements in memory are the first token’s outputs across all heads, so each row of the reshaped matrix corresponds to one token’s output across all heads, concatenated.

Note that as we’ve discussed previously, reshaping with view requires contiguity. sdpa.permute(1,0,2) breaks contiguity, so we would need to call .contiguous() before we could reshape with view. The reshape function is effectively shorthand for this sequence: it returns a view when the tensor’s memory layout allows it, and otherwise makes the data contiguous (copying it) before viewing, all in a single, more convenient call.
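The distinction is easy to demonstrate (dimensions are illustrative):

```python
import torch

h, T, d_k = 8, 5, 64
sdpa = torch.randn(h, T, d_k)

permuted = sdpa.permute(1, 0, 2)   # (T, h, d_k): same storage, new strides
assert not permuted.is_contiguous()

# .view() refuses to reinterpret non-contiguous memory ...
try:
    permuted.view(T, h * d_k)
    raised = False
except RuntimeError:
    raised = True
assert raised

# ... while .reshape() copies into contiguous memory when it must.
concat = permuted.reshape(T, h * d_k)
assert concat.shape == (T, h * d_k)
assert torch.equal(concat, permuted.contiguous().view(T, h * d_k))
```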

Final linear transform with \(W^O\)

The final step is to apply the final “mixing” transformation using the \(d_{\mathrm{model}} \times d_{\mathrm{model}}\) weight matrix \(W^O\). This is done via a simple call to our nn.Linear module self.w_o:

output = self.w_o(concat)

Adding batches B back in

The Transformer is generally trained in batches. In order to do this, we must add a batch dimension to our forward function. Instead of \(T \times d_{\mathrm{model}}\), the shape of our output will be \(B \times T \times d_{\mathrm{model}}\). This dimension \(B\) will be a part of our input matrix \(X\), and will be carried through every step of the logic.

In this sense, it is a very simple change to make: we don’t need to change any core logic, we just need to make sure to 1) extract \(B\) from the input matrix, and 2) include it in all of our reshapes and permutations:

B = X.shape[0]
T = X.shape[-2]

# Start with the d_model x d_model transform
q_full = self.w_q(X)
k_full = self.w_k(X)
v_full = self.w_v(X)

# Now reshape to spread those out over n_heads
q = q_full.view(B, T, self.n_heads, self.d_k).permute(0,2,1,3)
k = k_full.view(B, T, self.n_heads, self.d_k).permute(0,2,1,3)
v = v_full.view(B, T, self.n_heads, self.d_k).permute(0,2,1,3)

# Now do scaled dot product attention
sdpa = scaled_dot_product_attention(q, k, v) # the attention math broadcasts over the leading batch and head dimensions
        
# Now concatenate
concat = sdpa.permute(0,2,1,3).reshape(B, T, self.d_model)

# Apply the final transform with w_o
output = self.w_o(concat)

Masking

The final step is to include masking. The causal mask depends only on \(T\), so in theory it could be constructed inside forward. But the padding mask depends on the actual length of each sequence, which is only known at the data pipeline level (i.e. when we’re tokenizing and batching). Instead of constructing the mask within MultiHeadAttention, the design is cleaner if we construct it alongside our input matrix \(X\).

So for the purpose of our forward function, all we’re going to do is accept mask as an optional argument, add the head dimension to it, and pass it directly into scaled_dot_product_attention:

if isinstance(mask, torch.Tensor):
    mask = mask.unsqueeze(1)
sdpa = scaled_dot_product_attention(q, k, v, mask=mask)

The purpose of masking is twofold. First and most crucially, it stops the model from learning from future tokens while it is training; this is the causal mask, which is applied in the decoder only (the encoder uses full bidirectional attention, meaning every token can attend to every other token). Second, it facilitates batching by enabling sequences of varying lengths to be passed through the model simultaneously; this padding mask is applied in both the encoder and decoder during training. Recall that scaled dot-product attention first performs a matrix multiplication of \(Q\) and \(K^T\), and then divides it by \(\sqrt{d_k}\). What we have at this point is a \(T \times T\) matrix containing the dot products of each query and key vector.

While we’ve primarily discussed the encoder to this point, I’ll include detail on the causal mask here, which, as noted, is only used in the decoder. Achieving the first objective is simple: we mask out the upper triangle of the \(T \times T\) matrix by setting each of those values to -inf (these entries then pass through the softmax, which maps -inf to an attention weight of 0). Any elements in the upper triangle represent keys that the given query should not yet have access to, as they appear later in the sequence. For example, for the sequence “this is a sample sequence”:

\[ \begin{array}{r|ccccc} \text{Q} \downarrow \, / \, \text{K} \rightarrow & \text{this} & \text{is} & \text{a} & \text{sample} & \text{sequence} \\ \hline \text{this} & 1 & -\infty & -\infty & -\infty & -\infty \\ \text{is} & 1 & 1 & -\infty & -\infty & -\infty \\ \text{a} & 1 & 1 & 1 & -\infty & -\infty \\ \text{sample} & 1 & 1 & 1 & 1 & -\infty \\ \text{sequence} & 1 & 1 & 1 & 1 & 1 \\ \end{array} \] Note: cells shown as 1 represent arbitrary (non-zero) attention scores for illustration. The actual values are the entries of \(QK^T / \sqrt{d_k}\).

The second objective is slightly more nuanced, because the same mask cannot be applied to every sequence. Every batch will contain sequences of varying length, and \(T\) will be defined as the length of the longest sequence in the batch. Say our sample sequence above, which is length 5, is in a batch where the longest sequence is length 7. \(T = 7\) in this case, and therefore our sequence needs to be represented as length 7.

In order to do this, we use “padding tokens”, which are essentially meaningless tokens that are meant to extend shorter sequences up to the required length \(T\) for their batch. In our case, our sequence would become “this is a sample sequence <pad> <pad>”.

This necessarily alters our mask too: just as we don’t want future tokens to inform the training of the Transformer, we do not want padding tokens to inform it either: they are meaningless tokens, strictly present to make batch computation possible.

So to incorporate both the upper triangle masking and the padding token masking, the mask that ends up getting passed will produce this result: \[ \begin{array}{r|ccccccc} \text{Q} \downarrow \, / \, \text{K} \rightarrow & \text{this} & \text{is} & \text{a} & \text{sample} & \text{sequence} & \text{<pad>} & \text{<pad>}\\ \hline \text{this} & 1 & -\infty & -\infty & -\infty & -\infty & -\infty & -\infty\\ \text{is} & 1 & 1 & -\infty & -\infty & -\infty & -\infty & -\infty\\ \text{a} & 1 & 1 & 1 & -\infty & -\infty & -\infty & -\infty\\ \text{sample} & 1 & 1 & 1 & 1 & -\infty & -\infty & -\infty\\ \text{sequence} & 1 & 1 & 1 & 1 & 1 & -\infty & -\infty\\ \text{<pad>} & -\infty & -\infty & -\infty & -\infty & -\infty & -\infty & -\infty\\ \text{<pad>} & -\infty & -\infty & -\infty & -\infty & -\infty & -\infty & -\infty\\ \end{array} \]

In the encoder, where the causal mask is not used, this would translate to: \[ \begin{array}{r|ccccccc} \text{Q} \downarrow \, / \, \text{K} \rightarrow & \text{this} & \text{is} & \text{a} & \text{sample} & \text{sequence} & \text{<pad>} & \text{<pad>}\\ \hline \text{this} & 1 & 1 & 1 & 1 & 1 & -\infty & -\infty\\ \text{is} & 1 & 1 & 1 & 1 & 1 & -\infty & -\infty\\ \text{a} & 1 & 1 & 1 & 1 & 1 & -\infty & -\infty\\ \text{sample} & 1 & 1 & 1 & 1 & 1 & -\infty & -\infty\\ \text{sequence} & 1 & 1 & 1 & 1 & 1 & -\infty & -\infty\\ \text{<pad>} & -\infty & -\infty & -\infty & -\infty & -\infty & -\infty & -\infty\\ \text{<pad>} & -\infty & -\infty & -\infty & -\infty & -\infty & -\infty & -\infty\\ \end{array} \]
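As noted, these masks are constructed at the data pipeline level, not inside MultiHeadAttention. Here is one possible sketch of that pipeline step (the function name build_masks and the seq_lens argument are illustrative, not from the paper), producing boolean masks where True means “block this score”, to match the masked_fill convention in our attention function:

```python
import torch

def build_masks(seq_lens, T):
    # seq_lens: true (unpadded) length of each sequence in the batch.
    # Returns boolean masks with True = "block this attention score".
    positions = torch.arange(T)
    is_pad = positions.unsqueeze(0) >= torch.tensor(seq_lens).unsqueeze(1)  # (B, T)

    # Padding mask: block a score if either the query or the key is <pad>.
    pad_mask = is_pad.unsqueeze(1) | is_pad.unsqueeze(2)                    # (B, T, T)

    # Causal mask (decoder only): block the strict upper triangle.
    causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)     # (T, T)

    return pad_mask, pad_mask | causal  # encoder mask, decoder mask

# "this is a sample sequence" (length 5) batched with a length-7 sequence:
enc_mask, dec_mask = build_masks(seq_lens=[5, 7], T=7)

assert not dec_mask[0, 0, 0] and dec_mask[0, 0, 1]    # causal: row "this"
assert enc_mask[0, 0, 5] and not enc_mask[0, 0, 4]    # padding columns blocked
assert dec_mask[0, 5].all() and dec_mask[0, 6].all()  # padding rows blocked
```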

Each sequence in the batch will have its own mask, and each head will need to receive the mask as well, so the final mask shape that we pass will be \(B \times h \times T \times T\). As mentioned above, it will be constructed at the data pipeline level, when our input matrix \(X\) (\(B \times T \times d_{\mathrm{model}}\)) is created. At construction, this mask will be \(B \times T \times T\).

Ultimately, we need to take this input mask from \(B \times T \times T\) to \(B \times h \times T \times T\). But by leveraging the mechanics of PyTorch’s broadcasting rules, we can make this simpler by expanding the mask to \(B \times 1 \times T \times T\), which mask.unsqueeze(1) does.

A note on broadcasting:

The reason this works in practice is because of broadcasting in PyTorch. Broadcasting is what PyTorch does when you try to apply an operation to two tensors that don’t have the same shape. Rather than throwing an error, it tries to “expand” one or both tensors to make the shapes compatible. It follows two rules:

  1. Align shapes from the right. If the tensors have different numbers of dimensions, pad the shorter shape with 1s to the left.
  2. For each dimension, sizes must either match, or one of them must be 1. If one of them is 1, it gets stretched to match the other.

If either rule is violated, you get an error.

When we pass a shape of \(B \times 1 \times T \times T\) to a function that expects a shape of \(B \times h \times T \times T\), the broadcasting rules will start from the right:

  • \(T\) matches \(T\): pass
  • \(T\) matches \(T\): pass
  • \(h\) does not match \(1\), but one dimension is \(1\), so it gets stretched to match \(h\): pass
  • \(B\) matches \(B\): pass

If we had not used unsqueeze, we’d try to pass a \(B \times T \times T\) shape where a \(B \times h \times T \times T\) is expected:

  • The tensors have different numbers of dimensions, 3 and 4. The shorter shape gets padded with 1s to the left, so we now have \(1 \times B \times T \times T\) and \(B \times h \times T \times T\)
  • \(T\) matches \(T\): pass
  • \(T\) matches \(T\): pass
  • \(h\) does not match \(B\), and neither dimension is \(1\): FAIL

unsqueeze is what allows us to make use of PyTorch’s broadcasting rules, which ultimately makes our function simpler, more interpretable, and more efficient.
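Here is that failure mode, and the unsqueeze fix, in a few lines (dimensions are illustrative):

```python
import torch

B, h, T = 2, 8, 5
scores = torch.randn(B, h, T, T)                # QK^T / sqrt(d_k) for all heads
mask = torch.zeros(B, T, T, dtype=torch.bool)
mask[0, :, -1] = True                           # block the last key of sequence 0

# (B, 1, T, T) broadcasts against (B, h, T, T): the size-1 head dimension
# is stretched so that every head sees the same mask.
masked = scores.masked_fill(mask.unsqueeze(1), float('-inf'))
assert masked.shape == (B, h, T, T)
assert torch.isinf(masked[0, :, :, -1]).all()   # blocked in every head

# Without unsqueeze, (B, T, T) left-pads to (1, B, T, T), and B vs. h fails.
try:
    scores.masked_fill(mask, float('-inf'))
    broadcast_ok = True
except RuntimeError:
    broadcast_ok = False
assert not broadcast_ok
```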