The Mechanics of Reshaping in Multi-Head Attention

transformers
deep learning
attention
aiayn
A deep dive into PyTorch memory layout and stride mechanics, motivated by the head-splitting reshape in multi-head attention.
Published

April 1, 2026

Written by Michael Gethers

Motivation

The purpose of this document is to develop intuition for how PyTorch's reshaping operations (view, reshape, and permute) work. It is primarily motivated by the Multi-Head Attention architecture from the paper Attention Is All You Need (Vaswani et al., 2017), and by my PyTorch implementation of it:

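import math

import torch
import torch.nn as nn


def scaled_dot_product_attention(q, k, v, mask=None):
    # Minimal stand-in so this block runs on its own (an assumption: the
    # author's actual function is in the main Multi-Head Attention document).
    # Positions where mask == 0 are blocked from attention.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v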

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()

        ## Static components
        self.d_model = d_model
        self.n_heads = n_heads
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_k = d_model // n_heads

        ## Learnable components
        self.w_q = nn.Linear(self.d_model, self.d_model)
        self.w_k = nn.Linear(self.d_model, self.d_model)
        self.w_v = nn.Linear(self.d_model, self.d_model)
        self.w_o = nn.Linear(self.d_model, self.d_model)
    
    def forward(self, X, mask=None):
        '''
        X is expected to have shape (B, T, d_model). nn.Linear accepts any
        number of leading dimensions, but the .view(B, T, ...) calls below
        assume exactly one batch dimension and one sequence dimension.
        '''
        B = X.shape[0]
        T = X.shape[-2]
        
        # Start with the d_model x d_model transform
        q_full = self.w_q(X)
        k_full = self.w_k(X)
        v_full = self.w_v(X)

        # Now reshape to spread those out over n_heads
        q = q_full.view(B, T, self.n_heads, self.d_k).permute(0,2,1,3) 
        k = k_full.view(B, T, self.n_heads, self.d_k).permute(0,2,1,3)  
        v = v_full.view(B, T, self.n_heads, self.d_k).permute(0,2,1,3)

        # Now do scaled dot product attention
        if isinstance(mask, torch.Tensor):
            mask = mask.unsqueeze(1)
        sdpa = scaled_dot_product_attention(q, k, v, mask=mask)
                
        # Now concatenate
        concat = sdpa.permute(0,2,1,3).reshape(B, T, self.d_model)

        # Apply the final transform with w_o
        output = self.w_o(concat)
        
        return output

Specifically, we will look at the reshapes of q_full, k_full, and v_full above: each is a \(T \times d_{\mathrm{model}}\) matrix that gets reshaped into \(h\) distinct \(T \times d_k\) matrices, one for each attention head¹. This ends up being a very simple two-step transformation in PyTorch (a view followed by a permute), but it is worth developing a robust intuition for exactly how and why it works.
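As a quick shape check of the batched version used in forward, the sketch below runs the same view/permute on a random tensor. The sizes (B=2, T=3, d_model=8, h=2) are arbitrary illustration values, and q_full here is a random stand-in for the output of self.w_q(X). The second half checks that the permute/reshape used in the concatenation step undoes the split.

```python
import torch

# Illustrative sizes (assumptions, not from the original implementation)
B, T, d_model, n_heads = 2, 3, 8, 2
d_k = d_model // n_heads

# Random stand-in for q_full = self.w_q(X): shape (B, T, d_model)
q_full = torch.randn(B, T, d_model)

# Split into heads: (B, T, d_model) -> (B, T, h, d_k) -> (B, h, T, d_k)
q = q_full.view(B, T, n_heads, d_k).permute(0, 2, 1, 3)
print(q.shape)  # torch.Size([2, 2, 3, 4])

# Concatenate heads back: (B, h, T, d_k) -> (B, T, h, d_k) -> (B, T, d_model)
concat = q.permute(0, 2, 1, 3).reshape(B, T, d_model)
print(torch.equal(concat, q_full))  # True
```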

How reshaping our matrix works in PyTorch

An intuitive explanation

To see what we actually want here, let’s use a dummy example. Say we have a \(3 \times 8\) matrix \(X\): that is, \(T=3\) and \(d_{\mathrm{model}}=8\), and we are performing multi-head attention with \(h=2\) heads, so each head gets \(d_k = d_{\mathrm{model}}/h = 4\) columns:

\[ \begin{aligned} T &= 3 \\ d_{\mathrm{model}} &= 8 \\ h &= 2 \\ d_k &= 4 \end{aligned} \]

Our \(X\) matrix then looks like this:

\[ \begin{bmatrix} x_{1,1} & x_{1,2} & x_{1,3} & x_{1,4} & x_{1,5} & x_{1,6} & x_{1,7} & x_{1,8} \\ x_{2,1} & x_{2,2} & x_{2,3} & x_{2,4} & x_{2,5} & x_{2,6} & x_{2,7} & x_{2,8} \\ x_{3,1} & x_{3,2} & x_{3,3} & x_{3,4} & x_{3,5} & x_{3,6} & x_{3,7} & x_{3,8} \end{bmatrix} \]

Our linear projection preserves this shape, and we want to split the projected matrix column-wise into our 2 heads (from here on, \(X\) stands in for that projected matrix, e.g. q_full):

\[ \underbrace{ \begin{bmatrix} x_{1,1} & x_{1,2} & x_{1,3} & x_{1,4} \\ x_{2,1} & x_{2,2} & x_{2,3} & x_{2,4} \\ x_{3,1} & x_{3,2} & x_{3,3} & x_{3,4} \end{bmatrix} }_{\text{Head 1}} \underbrace{ \begin{bmatrix} x_{1,5} & x_{1,6} & x_{1,7} & x_{1,8} \\ x_{2,5} & x_{2,6} & x_{2,7} & x_{2,8} \\ x_{3,5} & x_{3,6} & x_{3,7} & x_{3,8} \end{bmatrix} }_{\text{Head 2}} \]

It is tempting to simply try something like q_full.view(self.n_heads, T, self.d_k), but that does not give the correct result. The reason lies in the way PyTorch stores data, and in the way view/reshape works given that underlying data model.
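To make the target concrete, the sketch below builds a stand-in for the \(3 \times 8\) dummy matrix with torch.arange, so that each entry's value is simply its flat position in memory (an illustrative choice that makes outputs easy to read). It then writes down the desired head split by explicit column slicing. This is only a reference answer to compare reshapes against, not the approach the rest of the document builds up to.

```python
import torch

T, d_model, n_heads = 3, 8, 2
d_k = d_model // n_heads

# Entry values equal their flat memory position: X[i, j] == i * 8 + j
X = torch.arange(T * d_model).view(T, d_model)

# Desired split: head h owns columns h*d_k .. (h+1)*d_k - 1
target = torch.stack([X[:, h * d_k:(h + 1) * d_k] for h in range(n_heads)])

print(target.shape)        # torch.Size([2, 3, 4])
print(target[0].tolist())  # [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19]]
print(target[1].tolist())  # [[4, 5, 6, 7], [12, 13, 14, 15], [20, 21, 22, 23]]
```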

How data is stored in PyTorch Tensors

Despite having multiple dimensions, a tensor’s data is stored as a flat, one-dimensional block of memory. A shape like (2, 10, 8, 64) is just \(2 \times 10 \times 8 \times 64 = 10{,}240\) numbers sitting in a row.

The shape is essentially layered on top of that block, together with a set of rules that PyTorch uses to let you index into it as if it were multidimensional. Those rules are called “strides”: for each dimension, the stride tells you “to move one step along this dimension, skip this many elements in the flat block.” When you reshape with view, the data itself does not change; view returns a new tensor that shares the same memory, and only the shape and strides (the indexing rules) change.
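A small sketch of shape and strides on a toy tensor (the sizes are arbitrary). Tensor.stride() reports the per-dimension skip described above, and comparing data_ptr() values shows that the result of view points at the same memory as the original.

```python
import torch

x = torch.arange(24).view(2, 3, 4)
print(x.stride())  # (12, 4, 1): next "matrix" is 12 away, next row 4, next column 1

y = x.view(6, 4)
print(y.stride())  # (4, 1)

# Same underlying memory: no data was copied or moved
print(x.data_ptr() == y.data_ptr())  # True

# Writing through one view is visible through the other
y[0, 0] = 100
print(x[0, 0, 0].item())  # 100
```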

The mechanics of view

Next, we need to understand how view works, and how it relates to the order in which the data is actually stored. Think of a shape like (5, 2, 10): this shape represents 5 matrices with 2 rows and 10 columns. For a contiguous tensor (PyTorch’s default row-major layout), this is stored in memory in blocks of 10, the size of the innermost dimension. The first 10 elements are the first row of the first matrix; the second 10 elements are the second row of the first matrix; the third 10 elements are the first row of the second matrix; and so on.
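For a contiguous (5, 2, 10) tensor, that rule can be written as a formula: element [i, j, k] sits at flat position i·20 + j·10 + k, where (20, 10, 1) are exactly the strides. The sketch below checks the formula against PyTorch’s own indexing, using a flattened version of the same data as the "memory".

```python
import torch

x = torch.arange(100).view(5, 2, 10)
flat = x.flatten()  # the one-dimensional block, in storage order
print(x.stride())   # (20, 10, 1)

s0, s1, s2 = x.stride()
for i in range(5):
    for j in range(2):
        for k in range(10):
            # Flat position = sum of (index * stride) over all dimensions
            assert x[i, j, k] == flat[i * s0 + j * s1 + k * s2]

# The third block of 10 is the first row of the second matrix
print(flat[20:30].tolist() == x[1, 0].tolist())  # True
```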

When we use view to reshape into something like (2, 5, 2, 5), what are we actually doing? We are telling PyTorch to reindex the data such that we now have 2 batches, each with 5 matrices, where each matrix now contains 2 rows and 5 columns. To do this, it simply fills the new shape with the flat data, in order:

We again start with the innermost dimension, the columns. So our first block of 5 will be the first row of the first matrix of the first batch; the second block of 5 will be the second row of the first matrix of the first batch; the third block of 5 will be the first row of the second matrix of the first batch; etc.
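The same view on concrete numbers: starting from torch.arange(100) viewed as (5, 2, 10), viewing it as (2, 5, 2, 5) just cuts the same sequence into blocks of 5. The last check highlights that the new "second row" is the back half of an old row: view regroups elements in storage order and knows nothing about the old rows and columns.

```python
import torch

x = torch.arange(100).view(5, 2, 10)
y = x.view(2, 5, 2, 5)

print(y[0, 0, 0].tolist())  # [0, 1, 2, 3, 4]       first row, first matrix, first batch
print(y[0, 0, 1].tolist())  # [5, 6, 7, 8, 9]       second row, first matrix, first batch
print(y[0, 1, 0].tolist())  # [10, 11, 12, 13, 14]  first row, second matrix, first batch

# The new second row is the back half of the old first row
print(torch.equal(y[0, 0, 1], x[0, 0, 5:]))  # True
```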

Applying this to our dummy example

So let’s go back to the matrices we have above, and the simple (hopeful) reshape which we already know will not work: q_full.view(self.n_heads, T, self.d_k), or X.view(2, 3, 4). We do indeed ultimately want a shape of (2, 3, 4). But what is actually happening here?

We’re starting with a (3, 8) shape and reshaping it into a (2, 3, 4): 2 matrices, each with 3 rows and 4 columns. So we’re going to take the first 4 elements, and that’s going to be the first row of the first matrix; then we’ll take the next 4 elements, and that’s going to be the second row of the first matrix; then we’ll take the next 4 elements, and that’ll be the third row of the first matrix, completing the first matrix.

Is that what we want?

This is what we’ve done in this example. The two blocks below are laid out as the heads we want; the colors show where view actually sends each element, with blue ending up in our first matrix and orange in our second:

\[ \underbrace{ \begin{bmatrix} \textcolor{cornflowerblue}{x_{1,1}} & \textcolor{cornflowerblue}{x_{1,2}} & \textcolor{cornflowerblue}{x_{1,3}} & \textcolor{cornflowerblue}{x_{1,4}} \\ \textcolor{cornflowerblue}{x_{2,1}} & \textcolor{cornflowerblue}{x_{2,2}} & \textcolor{cornflowerblue}{x_{2,3}} & \textcolor{cornflowerblue}{x_{2,4}} \\ \textcolor{orange}{x_{3,1}} & \textcolor{orange}{x_{3,2}} & \textcolor{orange}{x_{3,3}} & \textcolor{orange}{x_{3,4}} \\ \end{bmatrix} }_{\text{Head 1}} \underbrace{ \begin{bmatrix} \textcolor{cornflowerblue}{x_{1,5}} & \textcolor{cornflowerblue}{x_{1,6}} & \textcolor{cornflowerblue}{x_{1,7}} & \textcolor{cornflowerblue}{x_{1,8}} \\ \textcolor{orange}{x_{2,5}} & \textcolor{orange}{x_{2,6}} & \textcolor{orange}{x_{2,7}} & \textcolor{orange}{x_{2,8}} \\ \textcolor{orange}{x_{3,5}} & \textcolor{orange}{x_{3,6}} & \textcolor{orange}{x_{3,7}} & \textcolor{orange}{x_{3,8}} \\ \end{bmatrix} }_{\text{Head 2}} \]

Because view simply takes the first 3 blocks of 4 in row-major order and calls that our first matrix, the elements do not end up in the correct matrices.
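This can be confirmed numerically. The sketch uses a stand-in X built with torch.arange, so each value is its flat position in memory (an illustrative assumption), and compares the naive view against an explicit column slicing of the heads. The shapes agree, but the naive first "head" is just the first 12 numbers in memory, not the first four columns.

```python
import torch

X = torch.arange(24).view(3, 8)
target = torch.stack([X[:, :4], X[:, 4:]])  # desired column-wise split

naive = X.view(2, 3, 4)
print(naive.shape == target.shape)  # True
print(naive[0].tolist())   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
print(target[0].tolist())  # [[0, 1, 2, 3], [8, 9, 10, 11], [16, 17, 18, 19]]
print(torch.equal(naive, target))  # False
```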

With this in mind, let’s try a transpose first. I will keep the labels \(x_{i,j}\) the same through the transposition, so \(x_{i,j}\) still refers to row \(i\), column \(j\) of the original \(X\):

\[ \begin{aligned} &\left. \begin{bmatrix} x_{1,1} & x_{2,1} & x_{3,1} \\ x_{1,2} & x_{2,2} & x_{3,2} \\ x_{1,3} & x_{2,3} & x_{3,3} \\ x_{1,4} & x_{2,4} & x_{3,4} \\ \end{bmatrix} \right\} \text{Head 1} \\[1em] &\left. \begin{bmatrix} x_{1,5} & x_{2,5} & x_{3,5} \\ x_{1,6} & x_{2,6} & x_{3,6} \\ x_{1,7} & x_{2,7} & x_{3,7} \\ x_{1,8} & x_{2,8} & x_{3,8} \\ \end{bmatrix} \right\} \text{Head 2} \end{aligned} \]

If we now reshape this transposed matrix in the same way, each resulting matrix will contain the correct elements.

q_full.transpose(-2, -1).contiguous().view(self.n_heads, T, self.d_k)

## OR

X.transpose(-2, -1).contiguous().view(2,3,4)

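Note: we must include .contiguous() because .transpose(-2, -1) does not move any data; it only swaps the strides, so the result is no longer contiguous in memory. Calling .view(2, 3, 4) on it directly raises a RuntimeError, whereas .contiguous() first copies the data into a fresh row-major block that view can then reinterpret.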
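A short sketch of the contiguity issue, using a toy X built with torch.arange. It shows the swapped strides, the failed view, and two ways around it: .contiguous() followed by view, or .reshape(), which copies only when a pure view is impossible. The exact error message is PyTorch’s and may differ between versions, so the sketch only catches the exception type.

```python
import torch

X = torch.arange(24).view(3, 8)
Xt = X.transpose(-2, -1)

print(Xt.stride())         # (1, 8): strides swapped, data untouched
print(Xt.is_contiguous())  # False

try:
    Xt.view(2, 3, 4)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)  # True

# .contiguous() copies into a fresh row-major block, after which view works
a = Xt.contiguous().view(2, 3, 4)
# .reshape() does that copy automatically when it has to
b = Xt.reshape(2, 3, 4)
print(torch.equal(a, b))  # True
```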

This yields the following head splits:

\[ \begin{aligned} &\left. \begin{bmatrix} \textcolor{cornflowerblue}{x_{1,1}} & \textcolor{cornflowerblue}{x_{2,1}} & \textcolor{cornflowerblue}{x_{3,1}} \\ \textcolor{cornflowerblue}{x_{1,2}} & \textcolor{cornflowerblue}{x_{2,2}} & \textcolor{cornflowerblue}{x_{3,2}} \\ \textcolor{cornflowerblue}{x_{1,3}} & \textcolor{cornflowerblue}{x_{2,3}} & \textcolor{cornflowerblue}{x_{3,3}} \\ \textcolor{cornflowerblue}{x_{1,4}} & \textcolor{cornflowerblue}{x_{2,4}} & \textcolor{cornflowerblue}{x_{3,4}} \\ \end{bmatrix} \right\} {\text{Head 1}} \\[1em] &\left. \begin{bmatrix} \textcolor{orange}{x_{1,5}} & \textcolor{orange}{x_{2,5}} & \textcolor{orange}{x_{3,5}} \\ \textcolor{orange}{x_{1,6}} & \textcolor{orange}{x_{2,6}} & \textcolor{orange}{x_{3,6}} \\ \textcolor{orange}{x_{1,7}} & \textcolor{orange}{x_{2,7}} & \textcolor{orange}{x_{3,7}} \\ \textcolor{orange}{x_{1,8}} & \textcolor{orange}{x_{2,8}} & \textcolor{orange}{x_{3,8}} \\ \end{bmatrix} \right\} {\text{Head 2}} \end{aligned} \]

Each head now contains the correct elements. But does this actually produce the right head matrices?

We know from above that we want the first row of our first matrix to be \(\begin{bmatrix} x_{1,1} & x_{1,2} & x_{1,3} & x_{1,4}\end{bmatrix}\). But what matrices will the reshape of the transposed matrix actually produce? If we follow the same procedure, moving through memory in 4-element chunks, we get the following head matrices:

\[ \underbrace{ \begin{bmatrix} \textcolor{cornflowerblue}{\underline{x_{1,1}}} & \textcolor{cornflowerblue}{x_{2,1}} & \textcolor{cornflowerblue}{x_{3,1}} & \textcolor{cornflowerblue}{\underline{x_{1,2}}} \\ \textcolor{cornflowerblue}{x_{2,2}} & \textcolor{cornflowerblue}{x_{3,2}} & \textcolor{cornflowerblue}{\underline{x_{1,3}}} & \textcolor{cornflowerblue}{x_{2,3}} \\ \textcolor{cornflowerblue}{x_{3,3}} & \textcolor{cornflowerblue}{\underline{x_{1,4}}} & \textcolor{cornflowerblue}{x_{2,4}} & \textcolor{cornflowerblue}{x_{3,4}} \\ \end{bmatrix} }_{\text{Head 1}} \underbrace{ \begin{bmatrix} \textcolor{orange}{\underline{x_{1,5}}} & \textcolor{orange}{x_{2,5}} & \textcolor{orange}{x_{3,5}} & \textcolor{orange}{\underline{x_{1,6}}} \\ \textcolor{orange}{x_{2,6}} & \textcolor{orange}{x_{3,6}} & \textcolor{orange}{\underline{x_{1,7}}} & \textcolor{orange}{x_{2,7}} \\ \textcolor{orange}{x_{3,7}} & \textcolor{orange}{\underline{x_{1,8}}} & \textcolor{orange}{x_{2,8}} & \textcolor{orange}{x_{3,8}} \\ \end{bmatrix} }_{\text{Head 2}} \]

In the matrices above, I’ve underlined the elements that should form each head’s first row. While each head holds the correct values, they are in the wrong order.
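The "right values, wrong order" observation can be checked directly. With an illustrative X whose values are their flat positions, each head of the transposed-then-viewed tensor holds exactly the values of the corresponding column-sliced head, yet the two differ element by element.

```python
import torch

X = torch.arange(24).view(3, 8)
target = torch.stack([X[:, :4], X[:, 4:]])  # desired column-wise split

attempt = X.transpose(-2, -1).contiguous().view(2, 3, 4)

results = []
for h in range(2):
    # Same multiset of values? Compare sorted, flattened contents
    same_values = torch.equal(attempt[h].flatten().sort().values,
                              target[h].flatten().sort().values)
    # Same arrangement? Compare element by element
    same_layout = torch.equal(attempt[h], target[h])
    results.append((same_values, same_layout))

print(results)  # [(True, False), (True, False)]
```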

So this reshape does not actually produce what we want, but it is becoming clearer here what we actually do want. The initial transpose allows us to get the values we want, but the reshape is putting things in the wrong order. So let’s modify that reshape.

If we instead reshape the transposed matrix to (2, 4, 3) (i.e., (self.n_heads, self.d_k, T)), each head becomes a \(4 \times 3\) block whose rows are that head’s columns of \(X\); in other words, each block is exactly the transpose of the head matrix we want. The only move left to make is one final transposition of the last two dimensions.

q_full.transpose(-2, -1).contiguous().view(self.n_heads, self.d_k, T).transpose(-2, -1)

## OR

X.transpose(-2, -1).contiguous().view(2,4,3).transpose(-2, -1)

Which gives us: \[ \underbrace{ \begin{bmatrix} \textcolor{cornflowerblue}{\underline{x_{1,1}}} & \textcolor{cornflowerblue}{\underline{x_{1,2}}} & \textcolor{cornflowerblue}{\underline{x_{1,3}}} & \textcolor{cornflowerblue}{\underline{x_{1,4}}} \\ \textcolor{cornflowerblue}{x_{2,1}} & \textcolor{cornflowerblue}{x_{2,2}} & \textcolor{cornflowerblue}{x_{2,3}} & \textcolor{cornflowerblue}{x_{2,4}} \\ \textcolor{cornflowerblue}{x_{3,1}} & \textcolor{cornflowerblue}{x_{3,2}} & \textcolor{cornflowerblue}{x_{3,3}} & \textcolor{cornflowerblue}{x_{3,4}} \\ \end{bmatrix} }_{\text{Head 1}} \underbrace{ \begin{bmatrix} \textcolor{orange}{\underline{x_{1,5}}} & \textcolor{orange}{\underline{x_{1,6}}} & \textcolor{orange}{\underline{x_{1,7}}} & \textcolor{orange}{\underline{x_{1,8}}} \\ \textcolor{orange}{x_{2,5}} & \textcolor{orange}{x_{2,6}} & \textcolor{orange}{x_{2,7}} & \textcolor{orange}{x_{2,8}} \\ \textcolor{orange}{x_{3,5}} & \textcolor{orange}{x_{3,6}} & \textcolor{orange}{x_{3,7}} & \textcolor{orange}{x_{3,8}} \\ \end{bmatrix} }_{\text{Head 2}} \]
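A quick check that this three-step version matches an explicit column slicing of the heads. The arange-based X (values equal flat positions) and the name target are illustrative assumptions.

```python
import torch

T, d_model, n_heads = 3, 8, 2
d_k = d_model // n_heads

X = torch.arange(T * d_model).view(T, d_model)
target = torch.stack([X[:, h * d_k:(h + 1) * d_k] for h in range(n_heads)])

# transpose -> contiguous copy -> view as (h, d_k, T) -> swap the last two dims
heads = X.transpose(-2, -1).contiguous().view(n_heads, d_k, T).transpose(-2, -1)

print(heads.shape)                 # torch.Size([2, 3, 4])
print(torch.equal(heads, target))  # True
```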

Practical implementation

We finally have our n_heads = 2 head matrices with the correct values, in the correct order. Now that we understand what needs to happen, we can use a much simpler two-step technique in PyTorch to construct the same matrices more elegantly.

Instead of constructing the head matrices directly via views alone, we can take a different approach: first split each row of our input matrix \(X\) into \(h\) pieces of length \(d_k\), which gives \(T\) matrices of shape \(h \times d_k\). In effect, the first \(h \times d_k\) matrix contains the first row of each head matrix, the second \(h \times d_k\) matrix contains the second row of each head matrix, and so on. That is:

q_full.view(T, self.n_heads, self.d_k)

## OR

X.view(3, 2, 4)

For our dummy example, the result looks like this:

\[ \begin{aligned} &\left. \begin{bmatrix} \textcolor{cornflowerblue}{\underline{x_{1,1}}} & \textcolor{cornflowerblue}{\underline{x_{1,2}}} & \textcolor{cornflowerblue}{\underline{x_{1,3}}} & \textcolor{cornflowerblue}{\underline{x_{1,4}}} \\ \textcolor{orange}{\underline{x_{1,5}}} & \textcolor{orange}{\underline{x_{1,6}}} & \textcolor{orange}{\underline{x_{1,7}}} & \textcolor{orange}{\underline{x_{1,8}}} \\ \end{bmatrix} \right. \quad T_1 \\[1em] &\left. \begin{bmatrix} \textcolor{cornflowerblue}{x_{2,1}} & \textcolor{cornflowerblue}{x_{2,2}} & \textcolor{cornflowerblue}{x_{2,3}} & \textcolor{cornflowerblue}{x_{2,4}} \\ \textcolor{orange}{x_{2,5}} & \textcolor{orange}{x_{2,6}} & \textcolor{orange}{x_{2,7}} & \textcolor{orange}{x_{2,8}} \\ \end{bmatrix} \right. \quad T_2 \\[1em] &\left. \begin{bmatrix} \textcolor{cornflowerblue}{x_{3,1}} & \textcolor{cornflowerblue}{x_{3,2}} & \textcolor{cornflowerblue}{x_{3,3}} & \textcolor{cornflowerblue}{x_{3,4}} \\ \textcolor{orange}{x_{3,5}} & \textcolor{orange}{x_{3,6}} & \textcolor{orange}{x_{3,7}} & \textcolor{orange}{x_{3,8}} \\ \end{bmatrix} \right. \quad T_3 \end{aligned} \]
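The intermediate \((T, h, d_k)\) tensor can be inspected the same way, again with an illustrative X whose values are their flat positions. Slice t holds row t of X, cut into one \(d_k\)-length piece per head, and because this is a plain split of the last dimension, view needs no copy.

```python
import torch

X = torch.arange(24).view(3, 8)
Y = X.view(3, 2, 4)  # (T, h, d_k)

print(Y[0].tolist())  # [[0, 1, 2, 3], [4, 5, 6, 7]]: row 1 of X, head 1 part then head 2 part
print(Y.stride())     # (8, 4, 1)

# Y[t, h] is row t of head h's matrix
print(torch.equal(Y[2, 1], X[2, 4:]))  # True
```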

From here, it takes a single call to PyTorch’s permute(). permute reorders a tensor’s dimensions without moving any data (it returns a view with reordered strides), which is exactly what we need in this case. Our dimensions after the initial view are \((T, h, d_k)\), and we want \((h, T, d_k)\), so we simply pass the desired ordering of the existing dimensions as indices:

q_full.view(T, self.n_heads, self.d_k).permute(1,0,2)

## OR

X.view(3, 2, 4).permute(1,0,2)

This yields the same correct split of our input matrix that we obtained previously, in a much simpler way:

\[ \underbrace{ \begin{bmatrix} \textcolor{cornflowerblue}{\underline{x_{1,1}}} & \textcolor{cornflowerblue}{\underline{x_{1,2}}} & \textcolor{cornflowerblue}{\underline{x_{1,3}}} & \textcolor{cornflowerblue}{\underline{x_{1,4}}} \\ \textcolor{cornflowerblue}{x_{2,1}} & \textcolor{cornflowerblue}{x_{2,2}} & \textcolor{cornflowerblue}{x_{2,3}} & \textcolor{cornflowerblue}{x_{2,4}} \\ \textcolor{cornflowerblue}{x_{3,1}} & \textcolor{cornflowerblue}{x_{3,2}} & \textcolor{cornflowerblue}{x_{3,3}} & \textcolor{cornflowerblue}{x_{3,4}} \\ \end{bmatrix} }_{\text{Head 1}} \underbrace{ \begin{bmatrix} \textcolor{orange}{\underline{x_{1,5}}} & \textcolor{orange}{\underline{x_{1,6}}} & \textcolor{orange}{\underline{x_{1,7}}} & \textcolor{orange}{\underline{x_{1,8}}} \\ \textcolor{orange}{x_{2,5}} & \textcolor{orange}{x_{2,6}} & \textcolor{orange}{x_{2,7}} & \textcolor{orange}{x_{2,8}} \\ \textcolor{orange}{x_{3,5}} & \textcolor{orange}{x_{3,6}} & \textcolor{orange}{x_{3,7}} & \textcolor{orange}{x_{3,8}} \\ \end{bmatrix} }_{\text{Head 2}} \]
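A final check of the two-step version, with the same kind of illustrative arange-based X. Besides matching an explicit column slicing, it shows what permute actually did: nothing was copied. The result shares memory with X and simply carries reordered strides, which also makes it non-contiguous. Permuted tensors are generally non-contiguous, which is why the concatenation step in forward uses reshape rather than view after its permute.

```python
import torch

T, d_model, n_heads = 3, 8, 2
d_k = d_model // n_heads

X = torch.arange(T * d_model).view(T, d_model)
target = torch.stack([X[:, h * d_k:(h + 1) * d_k] for h in range(n_heads)])

heads = X.view(T, n_heads, d_k).permute(1, 0, 2)  # (T, h, d_k) -> (h, T, d_k)

print(torch.equal(heads, target))        # True
print(heads.stride())                    # (4, 8, 1): permute reorders strides
print(heads.data_ptr() == X.data_ptr())  # True: no copy was made
print(heads.is_contiguous())             # False
```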

This is our final transformation code: the view followed by permute is clean, efficient (neither step copies any data), and produces the correct column-wise matrix split.

This is just a small component of the full multi-head attention implementation. To see how this Tensor reshape fits back into the multi-head attention architecture, please see the main Multi-Head Attention document.

Footnotes

  1. In practice, these tensors include a batch dimension, so the transformation is actually from \(B \times T \times d_{\mathrm{model}}\) to \(B \times h \times T \times d_k\). The batch dimension is not essential to the explanation, so it is omitted here.