12. Attention Layers
Attention is a concept in machine learning and AI that goes back many years, especially in computer vision [BP97]. Like the term "neural network", attention was inspired by the idea of attention in how human brains deal with the massive amount of visual and audio input [TG80]. Attention layers are deep learning layers that evoke this idea. You can read more about attention in deep learning in Luong et al. [LPM15] and get a practical overview here. Attention layers have been empirically shown to be so effective at modeling sequences, like language, that they have become indispensable [VSP+17]. The most common place you'll see attention layers is in transformer neural networks that model sequences. We'll also sometimes see attention in graph neural networks.
Audience & Objectives
This chapter builds on Standard Layers and Tensors and Shapes. You should be comfortable with broadcasting, matrices, and tensor shapes. After completing this chapter, you should be able to
Correctly specify shapes and input/output of attention layers
Implement attention layers
Understand how attention can be put into other layer types
Attention layers are fundamentally a weighted mean reduction. They just compute a mean, where you somehow weight each element contributing to the mean. Since it is a mean, attention decreases the rank of the input tensor. Attention is unusual among layers because it takes three inputs, whereas most layers in deep learning take just one or perhaps two. These inputs are called the query, the values, and the keys. The reduction occurs over the values; so if the values are rank 3, the output will be rank 2. The query should be one rank less than the keys. The keys should be the same rank as the values. The keys and query determine how to weight the values according to the attention mechanism, a fancy word for equation.
The table below summarizes these three input arguments. Note that the query is often batched, in which case its rank will be 2 rather than 1. If the input query is batched, then the output will be batched as well and its rank will be 2 instead of 1. A short shape sketch in code follows the table.
| | Rank | Axes | Purpose | Example |
|---|---|---|---|---|
| Query | 1 | (# of attn features) | input for checking against keys | One word represented as feature vector |
| Keys | 2 | (sequence length, # of attn features) | used to compute attention against query | All words in a sentence represented as matrix of feature vectors |
| Values | 2 | (sequence length, # of value features) | used to compute value of output | A vector of numbers for each word in a sentence |
| Output | 1 | (# of value features) | attention-weighted mean over values | single vector |
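Here is a minimal shape sketch in NumPy. The sequence length and feature sizes below are made up, and I use uniform weights as a stand-in for the attention vector just to show the reduction over the values.

import numpy as np

seq_len, attn_feat, value_feat = 6, 3, 2
query = np.zeros(attn_feat)               # rank 1: (# of attn features)
keys = np.zeros((seq_len, attn_feat))     # rank 2: (sequence length, # of attn features)
values = np.zeros((seq_len, value_feat))  # rank 2: (sequence length, # of value features)

# stand-in attention vector: uniform weights over the sequence
b = np.ones(seq_len) / seq_len
output = b @ values                       # rank 2 -> rank 1 reduction over the sequence
print(output.shape)  # (2,)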
12.1. Example
Attention is best conceptualized as operating on a sequence. Let's use a sentence like "The sleepy child reads a book". The words in the sentence correspond to the keys. If we represent our words as embeddings, our keys will be rank 2. For example, the word "sleepy" might be represented by an embedding vector of length 3: \([2, 0, 1]\), where these embedding values are trained or taken from a standard language embedding. By convention, the zeroth axis of the keys is the position in the sequence and the first axis contains these vectors. The query is often an element from the keys, like the word "book". The point of attention is to see which parts of the sentence the query should be influenced by. "Book" should have strong attention on "child" and "reads", but probably not on "sleepy". You'll see soon that we will actually compute this as a vector, called the attention vector \(\vec{b}\). The output from the attention layer will be a reduction over the values where each element of the values is weighted by the attention between the query and the corresponding key. Thus there should be one key and one value for each element in our sentence. The values could be identical to the keys, which is common.
Let's see how this looks mathematically. The attention layer consists of two steps: (1) computing the attention vector \(\vec{b}\) using the attention mechanism and (2) reducing over the values using the attention vector \(\vec{b}\). "Attention mechanism" is a fancy term for the attention equation. Consider our example above. We'll use a 3-dimensional embedding for our words:
| Index | Embedding | Word |
|---|---|---|
| 0 | 0,0,0 | The |
| 1 | 2,0,1 | Sleepy |
| 2 | 1,-1,-2 | Child |
| 3 | 2,3,1 | Reads |
| 4 | -2,0,0 | A |
| 5 | 0,2,1 | Book |
The keys will be a rank 2 tensor (matrix) putting all these together:

\[
\mathbf{K} = \begin{bmatrix}
0 & 0 & 0\\
2 & 0 & 1\\
1 & -1 & -2\\
2 & 3 & 1\\
-2 & 0 & 0\\
0 & 2 & 1\\
\end{bmatrix}
\]

Note that these are only integers to make this example clearer; typically, words are represented with floating point numbers when embedded.

The keys are shape \((6, 3)\) because our sentence has 6 words and each word is represented with a 3-dimensional embedding vector. Let's make our values simple: we'll have one value for each word. These values are what determine our output. Perhaps they could be the sentiment of the word: is it a positive word ("happy") or a negative word ("angry")?

Note that the values \(\mathbf{V}\) should be the same rank as the keys, so their shape is interpreted as \((6, 1)\). Finally, the query should be one rank less than the keys. Our query is the word "book":

\[
\vec{q} = [0, 2, 1]
\]
12.2. Attention Mechanism Equation
The attention mechanism equation uses the query and keys arguments only. It outputs a tensor one rank less than the keys, giving a scalar for each key corresponding to the attention the query should pay to that key. This attention vector should be normalized. The most common attention mechanism is a dot product followed by a softmax:

\[
\vec{b} = \mathrm{softmax}\left(\vec{q}\cdot\mathbf{K}\right) = \mathrm{softmax}\left(\sum_j q_j k_{ij}\right)
\]
where index \(i\) is the position in the sequence and \(j\) is the index of the feature. Softmax is defined as

\[
\mathrm{softmax}\left(\vec{x}\right)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}
\]
and ensures that \(\vec{b}\) is normalized. Substituting our values from the example above:

\[
\vec{b} = \mathrm{softmax}\left(\vec{q}\cdot\mathbf{K}\right) = \mathrm{softmax}\left([0, 1, -4, 7, 0, 5]\right) \approx [0, 0, 0, 0.88, 0, 0.12]
\]
I've rounded the numbers here, but essentially the attention vector only gives weight to the word itself ("book") and the verb ("reads"). I made this up, remember, but it gives you an idea of how attention gives you a way to connect words. It may even remind you of our graph neural networks' idea of neighbors.
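We can check this arithmetic with a short NumPy sketch. The keys and query below are taken straight from the table above; the inline softmax is only for this example.

import numpy as np

# keys: one 3-dimensional embedding per word of "The sleepy child reads a book"
K = np.array([[0, 0, 0], [2, 0, 1], [1, -1, -2], [2, 3, 1], [-2, 0, 0], [0, 2, 1]])
# query: the embedding for "book"
q = np.array([0, 2, 1])

s = K @ q                          # dot product of the query with each key
b = np.exp(s) / np.sum(np.exp(s))  # softmax
print(np.round(b, 2))
# [0.   0.   0.   0.88 0.   0.12]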
12.3. Attention Reduction
After computing the attention vector \(\vec{b}\), it is used to compute a weighted mean over the values:

\[
\textrm{output}_j = \sum_i b_i v_{ij}
\]
Conceptually, our example computed the attention-weighted sentiment of the query word "book" in our sentence. You can see that attention layers do two things: compute an attention vector with the attention mechanism and then use it to take the attention-weighted average over the values.
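Continuing the sketch above, the reduction step just needs some values. The sentiment numbers below are hypothetical, chosen only so there is something to reduce over.

# hypothetical per-word sentiment values (one value feature per word)
V = np.array([[0.0], [0.1], [0.3], [0.4], [0.0], [0.2]])

# attention-weighted mean over the values, using b from the previous cell
output = b @ V
print(output)  # close to the sentiment of "reads", with a little of "book" mixed in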
12.4. Tensor-Dot
This dot product, softmax, and reduction is called a tensor-dot and is the most common attention layer [LPM15]. One common modification is to divide the dot product by the square root of the dimension of the keys (the last axis dimension). Remember the keys are not normalized. If they are random numbers, the magnitude of the output from the dot product scales with the square root of the dimension of the keys due to the central limit theorem. This can make the softmax behave poorly, since you're taking \(e^{\vec{q} \cdot \mathbf{K}}\). Putting this all together, the equation is:

\[
\vec{b} = \mathrm{softmax}\left(\frac{1}{\sqrt{d}}\vec{q}\cdot\mathbf{K}\right), \qquad \textrm{output} = \vec{b}\cdot\mathbf{V}
\]

where \(d\) is the dimension of the query vector.
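Here is a quick numerical check of that scaling argument: for standard-normal random vectors, the standard deviation of the dot product grows like \(\sqrt{d}\), so dividing by \(\sqrt{d}\) keeps the softmax inputs at a reasonable scale. The dimensions chosen are arbitrary.

import numpy as np

rng = np.random.default_rng(0)
for d in [4, 64, 1024]:
    # many random query/key pairs of dimension d
    q = rng.normal(size=(10_000, d))
    k = rng.normal(size=(10_000, d))
    dots = np.sum(q * k, axis=1)
    print(d, np.std(dots), np.std(dots / np.sqrt(d)))
# the unscaled std grows like sqrt(d); the scaled std stays near 1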
12.5. Soft, Hard, and Temperature Attention
One possible change to attention is to replace the \(\mathrm{softmax}\) with a one at the position of highest attention and zeros at all other positions. This is called hard attention. The equation for hard attention is to replace the softmax with a "hardmax", defined as

\[
\mathrm{hardmax}\left(\vec{x}\right)_i = \lim_{T\rightarrow 0^+} \frac{e^{x_i / T}}{\sum_j e^{x_j / T}}
\]

which is a mathematical way to formulate putting a \(1\) in the position of the largest element of \(\vec{x}\) and a \(0\) at all others. The variable \(T\) is called the temperature, because this equation is similar to Boltzmann's distribution from statistical mechanics. You can see that the limit \(T = 0\) is hard attention, \(T = 1\) is soft attention, and \(T = \infty\) gives uniform attention. You could tune \(T\) to some intermediate value as well.
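A small sketch of the temperature effect, using a tempered softmax on an arbitrary vector of attention scores (the scores and temperatures below are made up):

import numpy as np

def tempered_softmax(x, T=1.0):
    # subtract the max for numerical stability at small T
    z = (x - np.max(x)) / T
    return np.exp(z) / np.sum(np.exp(z))

scores = np.array([0.0, 1.0, -4.0, 7.0, 0.0, 5.0])
for T in [0.01, 1.0, 100.0]:
    print(T, np.round(tempered_softmax(scores, T), 2))
# T -> 0 approaches hard attention (one-hot), T = 1 is the ordinary softmax,
# and large T approaches a uniform weighting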
12.6. Self-Attention
Remember how everything is batched in deep learning? The batched input to an attention layer is usually the query. So although in the discussion above the query was a tensor of one rank less than the keys (typically a vector), once it has been batched it will be the same rank as the keys. Almost always, the query is in fact equal to the keys. Like in our example, our query was the embedding vector for the word "book", which is one of the keys. If you batch the query so that you consider every word in the sentence, the query becomes equal to the keys. A further special case is when the query, values, and keys are equal. This is called self-attention. This just means our attention mechanism uses the values directly and there is no extra set of "keys" input to the layer.
12.7. Trainable Attention
There are no trainable parameters in our definitions above. How can you do learning with attention? Typically, you don't put trainable parameters in the attention equations directly. Instead, you put the keys, values, and query through dense layers (see Standard Layers) before the attention. So when viewed as a layer, attention has no trainable parameters. When viewed as a block with dense layers and an attention layer, it is trainable. We'll see this explicitly below.
12.8. Multi-head Attention Block
Inspired by the idea of convolutions with multiple filters, there is a block (group of layers) that splits into multiple parallel attention layers. These are called "multi-head attention" blocks. If your values are shape \((L, V)\), you will get back an \((H, V)\) tensor, where \(H\) is the number of parallel attention layers (heads). If there are no trainable parameters in attention layers, though, what is the point of this? Well, you must introduce weights. These are square weight matrices because we need all shapes to remain constant across the attention heads.
Consider an attention layer to be defined by \(A(\vec{q}, \mathbf{K}, \mathbf{V})\). The multi-head attention is

\[
A_H(\vec{q}, \mathbf{K}, \mathbf{V}) = \left[A\left(\mathbf{W}^1_q\vec{q}, \mathbf{K}\mathbf{W}^1_k, \mathbf{V}\mathbf{W}^1_v\right), A\left(\mathbf{W}^2_q\vec{q}, \mathbf{K}\mathbf{W}^2_k, \mathbf{V}\mathbf{W}^2_v\right), \ldots, A\left(\mathbf{W}^H_q\vec{q}, \mathbf{K}\mathbf{W}^H_k, \mathbf{V}\mathbf{W}^H_v\right)\right]
\]

where each element of the output vector \([\ldots]\) is itself an output from an attention layer, making \(H\) \((L, V)\)-shaped tensors for a batched query. So the whole output is an \((H, L, V)\) tensor. The most famous example of the multi-head attention block is in transformers [VSP+17], where they use self-attention multi-head attention blocks.
Typically we apply multiple sequential blocks of attention, so we need the values input to the next block to be of rank 2 again (instead of the rank 3 \((H, L, V)\) tensor). Thus the output from the multi-head attention is often reduced by matrix multiplication with an \((H, V, V)\) weight tensor or an \((H)\) tensor of weights so that you get back to rank 2. If this seems confusing, see the example below.
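The code example later in the chapter combines heads with the simpler length-\(H\) weight vector. Just to illustrate the \((H, V, V)\) version, here is a minimal sketch with made-up shapes and random stand-in values:

import numpy as np

H, L, V = 2, 11, 4

# pretend multi-head output and a reduction weight tensor (random placeholders)
multi_head_out = np.random.normal(size=(H, L, V))
w_reduce = np.random.normal(size=(H, V, V))

# contract away the head axis and mix value features: (H, L, V) -> (L, V)
reduced = np.einsum("hlv,hvw->lw", multi_head_out, w_reduce)
print(reduced.shape)  # (11, 4)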
12.10. Code Examples
Let's see how attention can be implemented in code. I will use random variables here for the different quantities, but I will indicate which variables should be trained with a w_ prefix and which should be inputs with an i_ prefix.
12.10.1. Tensor-Dot Mechanism
We'll begin by implementing the tensor-dot attention mechanism. As an example, we'll use a sequence length of 11, a keys feature dimension of 4, and a values feature dimension of 2. Remember the keys and query must share their feature dimension size.
import numpy as np


def softmax(x, axis=None):
    # keepdims so the normalization broadcasts correctly for batched (2D) input
    return np.exp(x) / np.sum(np.exp(x), axis=axis, keepdims=True)


def tensor_dot(q, k):
    # scaled dot product between the query and each key, then normalize
    b = softmax((k @ q) / np.sqrt(q.shape[-1]))
    return b


i_query = np.random.normal(size=(4,))
i_keys = np.random.normal(size=(11, 4))

b = tensor_dot(i_query, i_keys)
print("b = ", b)
b = [0.10708407 0.2159088 0.18910064 0.02449849 0.09297742 0.04962959
0.02203747 0.06559855 0.11954705 0.04353657 0.07008136]
As expected, we get out a vector \(\vec{b}\) whose sum is 1.
12.10.2. General Attention
Now let's put this attention mechanism into an attention layer.
def attention_layer(q, k, v):
    b = tensor_dot(q, k)
    return b @ v


i_values = np.random.normal(size=(11, 2))
attention_layer(i_query, i_keys, i_values)
array([0.14328042, 0.0463141 ])
We get two values, one for each feature dimension.
12.10.3. Self-attention
For self-attention, we make the queries, keys, and values equal. One small difference is that the queries are batched in this setting, so we should get a rank 2 output.
def batched_tensor_dot(q, k):
    # a will be (batch, seq) = (N, N) attention scores
    # batched dot product in einstein notation, scaled by sqrt of the feature dim
    a = np.einsum("ij,kj->ik", q, k) / np.sqrt(q.shape[-1])
    # now we softmax over the sequence axis
    b = softmax(a, axis=1)
    return b


def self_attention(x):
    b = batched_tensor_dot(x, x)
    return b @ x


i_batched_query = np.random.normal(size=(11, 4))
self_attention(i_batched_query)
array([[ 0.01832814, 0.0749668 , -0.04577256, 0.21776567],
[-0.28260916, 0.1767006 , 0.43049974, 0.35189732],
[-1.12945171, 0.74194802, 0.26632622, 0.1826393 ],
[ 0.08365006, 0.08586253, -0.11776762, 0.23627548],
[ 0.02641636, -0.01291467, 0.3669651 , 0.73047754],
[-0.12526439, 0.40507159, 0.20178456, 0.3394749 ],
[-0.08332396, -0.4524976 , 0.67658427, 1.38369207],
[ 0.37910667, 0.32785657, 0.06896593, 0.7083815 ],
[-0.12169481, -0.16310902, 0.38171132, 0.58598476],
[ 0.03113162, 0.10217473, -0.01623162, 0.32127187],
[-0.08716809, 0.25544837, 0.28385506, 0.42890038]])
We are given as output an \(11\times4\) matrix, which is correct.
12.10.4. Adding Trainable Parameters
You can add trainable parameters to these steps by adding weight matrices. Let's do this for the self-attention. Although the keys, values, and query are equal in self-attention, I can multiply them by different weights. Just to demonstrate, I'll have the values change to feature dimension 2.
# weights should be input feature_dim -> desired output feature_dim
w_q = np.random.normal(size=(4, 4))
w_k = np.random.normal(size=(4, 4))
w_v = np.random.normal(size=(4, 2))


def trainable_self_attention(x, w_q, w_k, w_v):
    # project the input into separate queries, keys, and values
    q = x @ w_q
    k = x @ w_k
    v = x @ w_v
    b = batched_tensor_dot(q, k)
    return b @ v


trainable_self_attention(i_batched_query, w_q, w_k, w_v)
array([[-0.12415358, 0.19886661],
[ 0.3107283 , 0.62858142],
[-6.14330661, 10.38434859],
[-0.18217992, 0.2592068 ],
[ 1.25461404, 1.07323339],
[-0.2912892 , 3.2461432 ],
[ 2.52707095, 1.62410927],
[ 0.57845772, 2.03749672],
[ 1.13981545, 0.60648038],
[-0.07851417, 0.55365682],
[ 0.20726197, 1.38793509]])
Since we had our values change to feature dimension 2 with the weights, we get out an \(11\times 2\) output.
12.10.5. Multi-head
The only change for multi-head attention is that we have one set of weights for each head, and we must agree on how to combine the outputs after running through the heads. I'll just use a length \(H\) vector of trainable weights. Other strategies are to concatenate the head outputs or use a reduction (e.g., mean, max).
w_q_h1 = np.random.normal(size=(4, 4))
w_k_h1 = np.random.normal(size=(4, 4))
w_v_h1 = np.random.normal(size=(4, 2))
w_q_h2 = np.random.normal(size=(4, 4))
w_k_h2 = np.random.normal(size=(4, 4))
w_v_h2 = np.random.normal(size=(4, 2))
w_h = np.random.normal(size=2)


def multihead_attention(x, w_q_h1, w_k_h1, w_v_h1, w_q_h2, w_k_h2, w_v_h2):
    h1_out = trainable_self_attention(x, w_q_h1, w_k_h1, w_v_h1)
    h2_out = trainable_self_attention(x, w_q_h2, w_k_h2, w_v_h2)
    # join along last axis so we can use dot.
    all_h = np.stack((h1_out, h2_out), -1)
    return all_h @ w_h


multihead_attention(i_batched_query, w_q_h1, w_k_h1, w_v_h1, w_q_h2, w_k_h2, w_v_h2)
array([[ 6.77177965e-01, 5.59331900e-02],
[ 2.27143381e+00, 3.78831009e-01],
[ 1.09948881e+03, 5.83190853e+02],
[ 1.35308286e+00, 3.07478796e-01],
[ 3.68488594e-01, -5.26710264e-01],
[ 1.36394854e+02, -9.30940659e+00],
[-2.35289075e+00, -9.72529355e-01],
[ 1.80170350e+01, 4.90282722e-02],
[-1.69200964e+00, -7.16240273e-01],
[ 1.67545404e+00, 5.52312680e-01],
[ 3.22989922e+01, -1.62346886e+00]])
As expected, we do get an \(11\times2\) rank 2 output.
12.11. Attention in Graph Neural Networks
Recall that the key attribute of a graph neural network is permutation equivariance. We used reductions like sum or mean over neighbors as the way to make graph neural network layers permutation equivariant. Attention layers are also permutation invariant (when not batched) and equivariant (when batched). This has made attention a popular choice for aggregating neighbor information. Attention layers are good at finding important neighbors, and so are especially useful with high-degree graphs (lots of neighbors). This is rare in molecules, but you can simply define all atoms to be connected and then put distances as the edge attributes. Recall that graph convolution layers (GCN layers), and most GNN layers, only allow information to propagate one bond per layer. Thus joining all atoms and using attention can give you long-range communication without so many layers. The disadvantage is that your network must now learn how to give attention to the correct bonds/atoms.
Let's see how attention fits into the Battaglia equations [BHB+18]. Recall that the Battaglia equations are general standard equations for defining a GNN. Attention can appear in multiple places, but as discussed above, it appears when considering neighbors. Specifically, the query will be the \(i\)th node, and the keys/values will be some combination of neighboring node and edge features. There is no step in the Battaglia equations where this fits neatly, but we can split up the attention layer as follows. Most of the attention layer will fit into the edge update equation:

\[
\vec{e}^{'}_k = \phi^e\left(\vec{e}_k, \vec{v}_{rk}, \vec{v}_{sk}, \vec{u}\right)
\]
Recall that this is a general equation and our choice of \(\phi^e()\) defines the GNN. \(\vec{e}_k\) is the feature vector of edge \(k\), \(\vec{v}_{rk}\) is the receiving node feature vector for edge \(k\), \(\vec{v}_{sk}\) is the sending node feature vector for edge \(k\), and \(\vec{u}\) is the global graph feature vector. We will use this step for the attention mechanism, where the query is the receiving node \(\vec{v}_{rk}\) and the keys/values are composed of the sender and edge vectors. To be specific, we'll use the approach from Zhang et al. [ZSX+18] with a tensor-dot mechanism. They only considered node features and set the keys and values to be identical to the node features. However, they put trainable parameters at each layer that translate the node features into the keys/query.
Putting it compactly into one equation:

\[
\vec{e}^{'}_k = \mathrm{softmax}_k\left(\left(\mathbf{W}_q\vec{v}_{rk}\right)\cdot\left(\mathbf{W}_k\vec{v}_{sk}\right)\right)\mathbf{W}_v\vec{v}_{sk}
\]

where the softmax is normalized over all edges \(k\) that share the same receiving node.
Now we have attention-weighted edge feature vectors. Finally, we sum over these edge features in the edge aggregation step:

\[
\bar{e}^{'}_i = \sum_{rk = i} \vec{e}^{'}_k
\]
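Below is a minimal NumPy sketch of this attention-based edge update and aggregation on a toy fully connected graph. The graph size, feature dimension, and weight matrices are all made up for illustration, and only a single head is used; it is meant to show where the query, keys, and values come from, not to reproduce Zhang et al. exactly.

import numpy as np

rng = np.random.default_rng(0)
n_nodes, n_feat = 5, 8

# toy node features and stand-in trainable projections
nodes = rng.normal(size=(n_nodes, n_feat))
w_q = rng.normal(size=(n_feat, n_feat))
w_k = rng.normal(size=(n_feat, n_feat))
w_v = rng.normal(size=(n_feat, n_feat))


def gnn_attention_update(nodes, w_q, w_k, w_v):
    q = nodes @ w_q  # query: one per receiving node
    k = nodes @ w_k  # keys: one per sending node
    v = nodes @ w_v  # values: one per sending node
    # attention score between every receiver i and sender j (fully connected graph)
    scores = q @ k.T / np.sqrt(k.shape[-1])
    b = np.exp(scores) / np.sum(np.exp(scores), axis=1, keepdims=True)
    # attention-weighted edge features, then sum over senders (edge aggregation)
    return b @ v


new_nodes = gnn_attention_update(nodes, w_q, w_k, w_v)
print(new_nodes.shape)  # (5, 8)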
In Zhang et al. [ZSX+18], they used multi-head attention as well. How would multi-head attention work here? Your edge feature matrix \(E_i^{'}\) now becomes an edge feature tensor, where axis 0 is the edge (\(k\)), axis 1 is the feature, and axis 2 is the head. Recall that the "head" just means which set of \(\mathbf{W}^h_q, \mathbf{W}^h_k, \mathbf{W}^h_v\) we used. To reduce the tensor back to the expected matrix, we simply use another weight matrix that maps from the last two axes (feature, head) down to features only. I will write out the indices explicitly to be more clear:
\[
e^{'}_{ikl} = \sum_{j} \sum_{h} e^{'}_{ikjh} w_{jhl}
\]

where \(j\) is the input edge feature index, \(l\) is the output edge feature index, and \(k\), \(h\), and \(i\) are defined as before. Transformer is another name for a network built on multi-head attention, so you'll also see transformer graph neural networks [MDM+20].
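As a sketch of that index contraction, here is how it could look with np.einsum for the edges of a single receiving node; the numbers of edges, features, and heads are arbitrary placeholders.

import numpy as np

n_edges, n_feat, n_heads = 6, 8, 4

# multi-head edge features for one receiving node: (edge k, feature j, head h)
e_multi = np.random.normal(size=(n_edges, n_feat, n_heads))
# weights mapping (feature j, head h) -> output feature l
w = np.random.normal(size=(n_feat, n_heads, n_feat))

# e'_{kl} = sum_{j,h} e'_{kjh} w_{jhl}
e_out = np.einsum("kjh,jhl->kl", e_multi, w)
print(e_out.shape)  # (6, 8)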
12.12. Chapter Summary
Attention layers are inspired by human ideas of attention, but are fundamentally a weighted mean reduction.
The attention layer takes in three inputs: the query, the values, and the keys. These inputs are often identical, where the query is one of the keys and the keys and the values are equal.
They are good at modeling sequences, such as language.
The attention vector should be normalized, which can be achieved using a softmax activation function, but the attention mechanism equation is a hyperparameter.
Attention layers compute an attention vector with the attention mechanism, and then use it to compute the attention-weighted average over the values.
Hard attention (the hardmax function) puts all of the weight on the position with the highest attention, so it selects a single value.
The tensor-dot followed by a softmax is the most common attention mechanism.
Self-attention is achieved when the query, values, and the keys are equal.
Attention layers by themselves are not trainable.
A multi-head attention block is a group of layers that splits into multiple parallel attention layers.
12.13. Cited References
- BHB+18
Peter W Battaglia, Jessica B Hamrick, Victor Bapst, Alvaro Sanchez-Gonzalez, Vinicius Zambaldi, Mateusz Malinowski, Andrea Tacchetti, David Raposo, Adam Santoro, Ryan Faulkner, and others. Relational inductive biases, deep learning, and graph networks. arXiv preprint arXiv:1806.01261, 2018.
- ZSX+18(1,2)
Jiani Zhang, Xingjian Shi, Junyuan Xie, Hao Ma, Irwin King, and Dit-Yan Yeung. GaAN: gated attention networks for learning on large and spatiotemporal graphs. arXiv preprint arXiv:1803.07294, 2018.
- BP97
Shumeet Baluja and Dean A. Pomerleau. Expectation-based selective attention for visual monitoring and control of a robot vehicle. Robotics and Autonomous Systems, 22(3):329–344, 1997. Robot Learning: The New Wave. URL: http://www.sciencedirect.com/science/article/pii/S0921889097000468, doi:10.1016/S0921-8890(97)00046-8.
- TG80
Anne M Treisman and Garry Gelade. A feature-integration theory of attention. Cognitive Psychology, 12(1):97–136, 1980.
- LPM15(1,2)
Minh-Thang Luong, Hieu Pham, and Christopher D Manning. Effective approaches to attention-based neural machine translation. arXiv preprint arXiv:1508.04025, 2015.
- VSP+17(1,2)
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems, 5998–6008. 2017.
- MDM+20
Łukasz Maziarka, Tomasz Danel, Sławomir Mucha, Krzysztof Rataj, Jacek Tabor, and Stanisław Jastrzębski. Molecule attention transformer. arXiv preprint arXiv:2002.08264, 2020.