Position-wise Feed-Forward Networks
Written by Michael Gethers
Introduction
As we covered in our previous document, multi-head attention is, at its core, a weighted averaging operation. Each token’s output is a weighted sum of value vectors. The weights come from the softmax, the values come from \(V\), but the final operation itself is fundamentally linear. For each token in the sequence, attention has produced a \(d_{\mathrm{model}}\) vector that encodes something like “this token, in the context of its neighbors”.
But is that all we need? Can linear relationships capture the complexity we need our model to understand? The answer is no, and this is the basis for the feed-forward network, which allows us to take the rich contextual information from attention and make non-linear computations from it.
The intuition behind the Feed-Forward Network
Before jumping into the actual calculations and implementation of the FFN, let’s go a bit deeper into the intuitions for why it is necessary. While it may sound self-evident that 1) multi-head attention is a fundamentally linear operation, and 2) this is not sufficient, it is worth solidifying that intuition as completely as possible.
Multi-head attention is a linear operation
We have walked through MHA in depth in previous posts, so if you’ve started from the beginning it is probably clear that MHA is a linear operation. But for the sake of completeness of this as a standalone document, let’s quickly run through a summary of MHA.
We start with an input matrix \(X\), which is a matrix of the vector representations of each token in the given sequence. For this walkthrough we’ll use the encoder rather than the decoder, so that we only have this single matrix getting passed to MHA.
First, \(X\) is linearly transformed \(3h\) different ways, producing a \(Q\), \(K\), and \(V\) matrix for each of our \(h\) heads. Linear operation.
For each head these \(Q\), \(K\), and \(V\) matrices are passed into the scaled dot-product attention function, which looks like this:
\[ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \]
While the intuitions behind this operation are worth understanding and are spelled out in detail in the Scaled Dot-Product Attention post, the structure is simple: the \(Q\) and \(K^T\) matrices are multiplied together, divided by a scalar, softmaxed, and then multiplied by the \(V\) matrix to produce the final output. The softmax is not itself linear: it contains exponentials (\(\sigma(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}\)). But that non-linearity is used only to determine the weights. The final output of scaled dot-product attention is still just a weighted average of value vectors from matrix \(V\), so the blending of vectors itself is linear. Linear operation.
The output of scaled dot-product attention for each of our \(h\) heads is a \(T \times d_k\) matrix. These \(h\) matrices are concatenated column-wise and then multiplied by a “mixing matrix” of sorts, \(W^O\), which allows the model to combine signals across heads into a single unified representation. Linear operation.
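That’s all there is to it: apart from the softmax that sets the weights, MHA is just a sequence of multiplications by learned matrices. Once the attention weights are fixed, what happens to the value vectors is a strictly linear operation.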
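To make the "weighted average" claim concrete, here is a minimal sketch for a single head. The sizes and the random matrices are hypothetical stand-ins, not values from a trained model. It checks two things: each row of softmax weights sums to 1, and with the weights held fixed, the map from \(V\) to the output obeys the usual linearity rule.

```python
import torch

torch.manual_seed(0)
T, d_k = 4, 8  # hypothetical sequence length and head dimension

Q = torch.randn(T, d_k)
K = torch.randn(T, d_k)
V1 = torch.randn(T, d_k)
V2 = torch.randn(T, d_k)

# The softmax is the non-linear part: it turns raw scores into weights.
weights = torch.softmax(Q @ K.T / d_k ** 0.5, dim=-1)

# Each row of weights is non-negative and sums to 1, so every output row
# is a weighted average of the rows of V.
row_sums = weights.sum(dim=-1)

# With the weights held fixed, V -> weights @ V is linear:
# f(a*V1 + b*V2) == a*f(V1) + b*f(V2).
a, b = 2.0, -3.0
lhs = weights @ (a * V1 + b * V2)
rhs = a * (weights @ V1) + b * (weights @ V2)
is_linear_in_v = torch.allclose(lhs, rhs, atol=1e-5)
```

Note that the check holds the weights fixed. The weights themselves depend non-linearly on the input through the softmax, which is exactly the distinction drawn above.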
Linear operations are not sufficient
Let’s say that one token’s attention output has two features, and let’s call them \(a\) and \(b\). Maybe \(a\) represents “this token is a subject”, and \(b\) represents “this token is plural”.
A linear function of \(a\) and \(b\) can compute things like \(2a + 3b\) or \(a-b\).
But what if the thing we need to detect, the thing that is actually important to understanding the sequence, is that a token is “subject and plural simultaneously”? In other words, what if the signal we need is only meaningful when both features are active together? That’s more like \(a \times b\), or a threshold: “if \(a > 0.5\) and \(b > 0.5\)”.
A linear function cannot compute \(a \times b\) or apply thresholding logic. We need a neural network to learn these kinds of relationships, and that is what the feed-forward network is here to do. MHA can almost be thought of as an elaborate feature engineering exercise, and the FFN is the simple model that extracts the power from those complex features.
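A tiny worked example may help here. The sketch below treats \(a\) and \(b\) as binary features and uses one hand-picked ReLU unit (the weights and bias are illustrative assumptions, not learned values) to reproduce \(a \times b\) on all four input combinations. It also shows why no linear (or affine) function can do the same.

```python
def relu(z):
    return max(0.0, z)

def and_detector(a, b):
    # One hidden ReLU unit with weights (1, 1) and bias -1.
    # These values are hand-picked for illustration, not learned.
    return relu(a + b - 1.0)

inputs = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]
products = [a * b for a, b in inputs]
detector_outputs = [and_detector(a, b) for a, b in inputs]

# Any affine function f(a, b) = w1*a + w2*b + c satisfies
#   f(1, 1) + f(0, 0) == f(1, 0) + f(0, 1).
# The target a * b breaks that identity, so no choice of w1, w2, c
# can match it on all four corners.
target_gap = (products[3] + products[0]) - (products[1] + products[2])
```

On these four inputs the single ReLU unit matches the product exactly, while the non-zero `target_gap` is the obstruction that rules out every affine function.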
Why position-wise?
Crucially, the feed-forward network is applied position-wise, which means that the FFN is only working on each token in isolation. It has no mechanism to mix information across positions: it sees one token vector, transforms it, and moves to the next, and the same weight matrices are applied to each token.
What this means conceptually is that all cross-token information is coming from the attention sub-layer: once we are past MHA, we are operating exclusively at the token level.
This should help drive home the importance of both MHA and the FFN. MHA has constructed a row vector for each token that incorporates the context in which that token appeared. From that point forward, within the FFN, each row is treated as a single sample in isolation, producing its own output without interacting with the other tokens in the sequence.
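The position-wise property is easy to check empirically. The sketch below uses small hypothetical sizes and a randomly initialized two-layer network as a stand-in for the FFN. It confirms that running the whole sequence at once gives the same result as running each token separately, and that changing one token leaves every other token's output untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, d_model, d_ff = 5, 16, 64  # small hypothetical sizes

# Stand-in FFN: linear -> ReLU -> linear, randomly initialized.
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
X = torch.randn(T, d_model)

with torch.no_grad():
    # Whole sequence in one call.
    full = ffn(X)
    # One token at a time, then stacked back together.
    one_by_one = torch.cat([ffn(X[t:t + 1]) for t in range(T)])

    # Replace only token 0 and run the sequence again.
    X_changed = X.clone()
    X_changed[0] = torch.randn(d_model)
    changed = ffn(X_changed)

same_as_per_token = torch.allclose(full, one_by_one, atol=1e-5)
others_unaffected = torch.allclose(full[1:], changed[1:], atol=1e-5)
```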
An example
Let’s go back to our sample sentence:
The creek was cold and the spunky dog wanted to jump over it.
Multi-head attention has incorporated things like the fact that “it” is a pronoun referring to “creek”, and that “cold” is an adjective modifying “creek” much in the same way that “spunky” is an adjective modifying “dog”, and it has produced a vector representation for “it” that has been appropriately weighted by the representations of “creek” and “cold”.
These are all features of the tokens in the sequence, and when these features are read together, we as humans who understand the language generate a mental picture that incorporates their contextual meanings.
A “cold creek” produces a much different mental picture than a “warm spring”. In both cases, it’s a natural feature that the dog is jumping over. But in the case of the cold creek, we might think of something uncomfortable, unpleasant, something the dog is trying to avoid. If we’re talking about a warm spring, the context changes considerably: a warm spring sounds desirable, maybe playful.
This is essentially what the FFN is doing. It is taking a matrix of features from MHA, and synthesizing it into a “picture”: a representation of the tokens that has now had the ability to incorporate and model all of the relationships derived from the attention process into a rich, coherent view of the full sequence.
The computation
The computation of the feed-forward network itself is quite straightforward. As shown in the paper, the FFN is defined by:
\[ \mathrm{FFN}(x) = \mathrm{max}(0, xW_1 + b_1)W_2 + b_2 \]
Descriptively, you’ll notice that this is simply two linear transformations applied consecutively, but with a ReLU activation function in between. It is this ReLU that actually produces the non-linearity.
Step-by-step calculation
The FFN receives the output from MHA as a \(T \times d_{\mathrm{model}}\) matrix.
The first step is a linear transformation to \(d_{ff}\)-dimensional space, so \(W_1\) is a \(d_{\mathrm{model}} \times d_{ff}\) learned weight matrix. After \(X\) is multiplied by \(W_1\) and the bias \(b_1\) is added, we have a \(T \times d_{ff}\) matrix.
Then ReLU is applied to this matrix (i.e. \(\mathrm{max}(0, \cdot)\)). All ReLU does is zero out any negative elements of our first linear transformation, but this simple function is all it takes to create the non-linearity we need to model more complex relationships.
Effectively, what we have after ReLU is a matrix of \(d_{ff}\)-dimensional vectors, one for each token, which indicates which detectors fired for that particular token, and their magnitude. But at this stage, those firing neurons are independent: they don’t yet know about each other.
We then perform one final linear transformation, back down to \(d_{\mathrm{model}}\), which synthesizes the signals from the fired detectors, and writes them back into a usable output representation.
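The three steps can be traced with explicit matrices. This is a minimal sketch: \(T = 6\) is an arbitrary sequence length, and the randomly generated weights and zero biases are stand-ins for learned parameters, so only the shapes and the effect of ReLU are meaningful here.

```python
import torch

torch.manual_seed(0)
T, d_model, d_ff = 6, 512, 2048  # T is an arbitrary choice

X = torch.randn(T, d_model)             # stand-in for the MHA output
W1 = torch.randn(d_model, d_ff) * 0.02  # random stand-ins for learned weights
b1 = torch.zeros(d_ff)
W2 = torch.randn(d_ff, d_model) * 0.02
b2 = torch.zeros(d_model)

# Step 1: expand to d_ff dimensions.
hidden = X @ W1 + b1            # shape (T, d_ff)

# Step 2: ReLU zeroes out every negative entry.
activated = torch.relu(hidden)  # shape (T, d_ff)

# Step 3: project back down to d_model.
output = activated @ W2 + b2    # shape (T, d_model)

# Fraction of hidden units that did not "fire" across all tokens.
fraction_zero = (activated == 0).float().mean().item()
```

With these symmetric random stand-ins, roughly half of the hidden units are zeroed out for each token. In a trained model, the pattern of which units fire is what carries the information.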
How does ReLU create complex non-linearity?
This is rooted in the universal approximation theorem (Cybenko 1989, Hornik 1991), which states that a feed-forward neural network with a single hidden layer and a suitable non-linear activation can approximate any continuous function on a closed and bounded domain to arbitrary accuracy, provided the hidden layer is wide enough.
So how does ReLU do this, exactly? Let’s think about this concretely, using a simple function, say \(f(x) = x^2\).
ReLU itself is a piecewise linear function (i.e. \(0\) for \(x < 0\) and \(x\) for \(x \geq 0\)). So for a single neuron with a scalar input, we have 2 linear pieces, created by the one “kink” that the ReLU affords. But if we have \(N\) neurons in the hidden layer, this simple non-linearity is applied \(N\) times (simultaneously, once per neuron), and each neuron can place its kink at a different point. That gives us up to \(N\) kinks, and therefore up to \(N + 1\) distinct linear pieces.
On any arbitrary bounded interval \([A, B]\) in the domain of \(f(x)\), we can use those \(N + 1\) linear pieces to approximate \(x^2\) in that region.
And this approximation can be arbitrarily good: if we want a better approximation, we can simply add more neurons, meaning we get more segments with which to approximate.
It is worth noting that this applies on any bounded interval for any continuous function, but breaks the moment we remove the bounded limitation. Any feed-forward network has a finite number of neurons, meaning it has a finite number of linear pieces with which to approximate the underlying function. Beyond the rightmost kink (and before the leftmost one), the network’s approximation is a single straight line, forever, so the error against a function like \(x^2\) grows without bound. In practice this is rarely a problem, because the layer normalization applied around each sub-layer keeps the inputs to the FFN within a limited range.
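The piecewise-linear argument can be checked numerically. The sketch below hand-constructs (rather than trains) a one-hidden-layer ReLU network that interpolates \(x^2\) on \([-1, 1]\). The helper names `relu_interpolant` and `max_error` are hypothetical, and for simplicity the first kink is placed at the left endpoint, so \(N\) units give \(N\) linear pieces inside the interval.

```python
import torch

def relu_interpolant(f, A, B, N):
    """Build a function from N ReLU units that interpolates f at N + 1 evenly spaced knots on [A, B]."""
    knots = torch.linspace(A, B, N + 1, dtype=torch.float64)
    values = f(knots)
    # Slope of each of the N straight segments between neighbouring knots.
    slopes = (values[1:] - values[:-1]) / (knots[1:] - knots[:-1])
    # Output weight of each unit: the first slope, then the change in slope at each later kink.
    coeffs = torch.cat([slopes[:1], slopes[1:] - slopes[:-1]])
    kinks = knots[:-1]  # one kink per hidden unit

    def approx(x):
        hidden = torch.relu(x.unsqueeze(-1) - kinks)  # shape (len(x), N)
        return values[0] + hidden @ coeffs
    return approx

def max_error(N, A=-1.0, B=1.0):
    # Largest absolute error against x^2 on a dense grid over [A, B].
    x = torch.linspace(A, B, 2001, dtype=torch.float64)
    approx = relu_interpolant(lambda t: t ** 2, A, B, N)
    return (approx(x) - x ** 2).abs().max().item()

errors = {N: max_error(N) for N in (2, 4, 8, 16, 32)}
```

For this particular target, the worst-case error of evenly spaced linear interpolation is \(h^2/4\), where \(h\) is the segment width, so each doubling of \(N\) cuts the error by roughly a factor of four.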
On dimensionality
The paper notes that the dimensionality of each transformation is actually different: the first transformation increases dimensionality from \(d_{\mathrm{model}}\) to \(d_{ff}\) (which is a new parameter, not used previously in the transformer), and the second reduces it back down from \(d_{ff}\) to \(d_{\mathrm{model}}\).
The authors used \(d_{ff} = 2048\), which is a 4x increase in dimensionality over \(d_{\mathrm{model}} = 512\). This is, on some level, an arbitrary choice. There is nothing mathematically derived about the selection of 2048 or \(4d_{\mathrm{model}}\); it is simply empirical: the authors found that it worked well and was computationally feasible.
The reason for this expansion is that the increased dimensionality gives the model more room to learn complex relationships. It creates a higher-dimensional space where the ReLU non-linearity can carve out more complex decision boundaries before projecting back down to \(d_{\mathrm{model}}\).
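The cost of this expansion is easy to quantify. As a quick sketch using the paper's sizes, the following counts the parameters in the two linear layers of a single FFN sub-layer, both by asking PyTorch and by hand.

```python
import torch.nn as nn

d_model, d_ff = 512, 2048

w_1 = nn.Linear(d_model, d_ff)  # d_model * d_ff weights + d_ff biases
w_2 = nn.Linear(d_ff, d_model)  # d_ff * d_model weights + d_model biases

# Count every learnable scalar in both layers.
counted = sum(p.numel() for layer in (w_1, w_2) for p in layer.parameters())

# The same count from the shapes alone.
by_hand = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
```

Both counts come to 2,099,712, or roughly 2.1 million parameters per FFN sub-layer, almost all of them in the two weight matrices.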
Implementation
We can implement the position-wise feed-forward network using PyTorch as follows:
import torch.nn as nn


class FeedForwardNetwork(nn.Module):
    def __init__(self, d_model=512, d_ff=2048):
        super().__init__()
        ## Static components
        self.d_model = d_model
        self.d_ff = d_ff
        self.relu = nn.ReLU()
        ## Learnable components
        self.w_1 = nn.Linear(self.d_model, self.d_ff)
        self.w_2 = nn.Linear(self.d_ff, self.d_model)

    def forward(self, X):
        # Start with first transformation, which is inside the ReLU:
        x1 = self.w_1(X)
        # Now apply ReLU
        x_relu = self.relu(x1)
        # Now apply second transformation, which is now outside of the ReLU:
        x2 = self.w_2(x_relu)
        return x2

As previously noted, the computation for the FFN is quite straightforward, as is the implementation of it.
__init__
The class has two primary static components: d_model and d_ff. We also create relu in __init__(), simply so that we do not have to instantiate a separate nn.ReLU() every time forward() is called.
We then have two learnable components, which are our weight matrices w_1 and w_2.
forward
The forward function is just three lines.
First, we apply our first linear transformation, which occurs inside the ReLU: x1 = self.w_1(X).
Then, we apply ReLU, which because of our instantiation of nn.ReLU() at class initialization is simply: self.relu(x1).
And then our final linear transformation: self.w_2(x_relu).
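One detail worth verifying is that these three lines really do compute \(\mathrm{max}(0, xW_1 + b_1)W_2 + b_2\). The sketch below uses two standalone nn.Linear layers with small hypothetical sizes rather than the full class. It relies on the fact that nn.Linear stores its weight with shape (out_features, in_features) and computes \(xW^T + b\), so the \(W_1\) of the formula corresponds to the transpose of the stored weight.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, d_model, d_ff = 3, 8, 32  # small hypothetical sizes

w_1 = nn.Linear(d_model, d_ff)
w_2 = nn.Linear(d_ff, d_model)
X = torch.randn(T, d_model)

with torch.no_grad():
    # The same three steps as forward(): linear, ReLU, linear.
    module_out = w_2(torch.relu(w_1(X)))

    # The formula written out by hand. nn.Linear keeps weight as
    # (out_features, in_features), hence the transposes.
    W1, b1 = w_1.weight.T, w_1.bias
    W2, b2 = w_2.weight.T, w_2.bias
    formula_out = torch.clamp(X @ W1 + b1, min=0) @ W2 + b2

matches = torch.allclose(module_out, formula_out, atol=1e-5)
```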