12. Attention Layers

Attention is a concept in machine learning and AI that goes back many years, especially in computer vision [BP97]. Like the term “neural network”, attention was inspired by how human brains deal with the massive amount of visual and audio input [TG80]. Attention layers are deep learning layers that evoke this idea. You can read more about attention in deep learning in Luong et al. [LPM15]. Attention layers have been empirically shown to be so effective at modeling sequences, like language, that they have become indispensable [VSP+17]. The most common place you’ll see attention layers is in transformer neural networks that model sequences. We’ll also sometimes see attention in graph neural networks.

Audience & Objectives

This chapter builds on Standard Layers and Tensors and Shapes. You should be comfortable with broadcasting, matrices, and tensor shapes. After completing this chapter, you should be able to:

  • Correctly specify shapes and input/output of attention layers

  • Implement attention layers

  • Understand how attention can be put into other layer types

Attention layers are fundamentally a weighted mean reduction: they compute a mean in which each element contributing to the mean gets its own weight. Since it is a mean, attention decreases the rank of an input tensor. Attention is unusual among layers because it takes three inputs, whereas most layers in deep learning take just one or perhaps two. These inputs are called the query, the values, and the keys. The reduction occurs over the values, so if the values are rank 3, the output will be rank 2. The query should be one rank less than the keys, and the keys should be the same rank as the values. The keys and query determine how to weight the values according to the attention mechanism, which is a fancy word for the equation that computes those weights.

The table below summarizes these three input arguments. Note that the query is often batched, in which case its rank is 2 instead of 1. The output is then batched as well and also has rank 2 instead of 1.

|  | Rank | Axes | Purpose | Example |
|---|---|---|---|---|
| Query | 1 | (# of attn features) | input for checking against keys | One word represented as a feature vector |
| Keys | 2 | (sequence length, # of attn features) | used to compute attention against query | All words in a sentence represented as a matrix of feature vectors |
| Values | 2 | (sequence length, # of value features) | used to compute value of output | A vector of numbers for each word in a sentence |
| Output | 1 | (# of value features) | attention-weighted mean over values | A single vector |
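The shapes in this table can be checked with a small NumPy sketch. It assumes a plain dot-product-and-softmax mechanism (the mechanism itself is introduced below), arbitrary sizes, and a hypothetical helper name `attend`. It only illustrates how the ranks change.

```python
import numpy as np

rng = np.random.default_rng(0)
L, n_attn, n_val = 6, 3, 2  # sequence length, # of attn features, # of value features

query = rng.normal(size=(n_attn,))    # rank 1
keys = rng.normal(size=(L, n_attn))   # rank 2
values = rng.normal(size=(L, n_val))  # rank 2


def attend(q, k, v):
    """Hypothetical helper: attention-weighted mean of v over the sequence axis."""
    logits = k @ q                     # one score per key, shape (L,)
    w = np.exp(logits - logits.max())
    w = w / w.sum()                    # weights sum to 1, so this is a weighted mean
    return w @ v                       # reduces the sequence axis: shape (n_val,)


out = attend(query, keys, values)
print(out.shape)  # (2,)

# batching the query adds a leading axis to the output, giving rank 2
batched_query = rng.normal(size=(4, n_attn))
batched_out = np.stack([attend(qi, keys, values) for qi in batched_query])
print(batched_out.shape)  # (4, 2)
```

Because the weights are non-negative and sum to 1, each output element always lies between the smallest and largest value in its column.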

12.1. Example

Attention is best conceptualized as operating on a sequence. Let’s use a sentence like “The sleepy child reads a book”. The words in the sentence correspond to the keys. If we represent our words as embeddings, our keys will be rank 2. For example, the word “sleepy” might be represented by an embedding vector of length 3: \([2,0,1]\), where these embedding values are trained or taken from a standard language embedding. By convention, the zeroth axis of the keys is the position in the sequence and the first axis holds these embedding vectors. The query is often an element from the keys, like the word “book.” The point of attention is to see which parts of the sentence the query should be influenced by. “Book” should have strong attention to “child” and “reads,” but probably not to “sleepy.” You’ll soon see that we actually compute this as a vector, called the attention vector \(\vec{b}\). The output from the attention layer is a reduction over the values, where each element of the values is weighted by the attention between the query and the corresponding key. Thus there should be one key and one value for each element in our sentence. The values could be identical to the keys, which is common.

Let’s see how this looks mathematically. The attention layer consists of two steps: (1) computing the attention vector \(\vec{b}\) using the attention mechanism and (2) the reduction over the values using the attention vector \(\vec{b}\). Consider our example above. We’ll use a 3-dimensional embedding for our words:

| Index | Embedding | Word |
|---|---|---|
| 0 | 0, 0, 0 | The |
| 1 | 2, 0, 1 | Sleepy |
| 2 | 1, -1, -2 | Child |
| 3 | 2, 3, 1 | Reads |
| 4 | -2, 0, 0 | A |
| 5 | 0, 2, 1 | Book |

The keys will be a rank 2 tensor (matrix) that puts all of these together. Note that these are integers only to make the example clearer; words are typically embedded as floating point numbers.

\[ \mathbf{K} = \left[ \begin{array}{rrr} 0 & 0 & 0\\ 2 & 0 & 1\\ 1 & -1 & -2\\ 2 & 3 & 1\\ -2 & 0 & 0\\ 0 & 2 & 1 \end{array}\right] \tag{12.1} \]

The keys have shape \((6, 3)\) because our sentence has 6 words and each word is represented with a 3-dimensional embedding vector. Let’s keep our values simple: one number for each word. These values are what determine our output. Perhaps they could be the sentiment of each word: is it a positive word (“happy”) or a negative word (“angry”)?

\[ \mathbf{V} = \left[ 0, -0.2, 0.3, 0.4, 0, 0.1\right] \tag{12.2} \]

Note that the values \(\mathbf{V}\) should be the same rank as the keys, so their shape is interpreted as \((6, 1)\). Finally, the query should be one rank less than the keys. Our query is the word “book”:

\[ \vec{q} = \left[0, 2, 1\right] \tag{12.3} \]

12.2. Attention Mechanism Equation

The attention mechanism equation uses only the query and keys arguments. It outputs a tensor one rank less than the keys, giving a scalar for each key that corresponds to the attention the query should pay to that key. This attention vector should be normalized so that its elements sum to 1. The most common attention mechanism is a dot product followed by a softmax:

\[ \vec{b} = \mathrm{softmax}\left(\vec{q}\cdot \mathbf{K}\right) = \mathrm{softmax}\left(\sum_j q_j k_{ij}\right) \tag{12.4} \]

where index \(i\) is the position in the sequence and \(j\) is the index of the feature. Softmax is defined as

\[ \mathrm{softmax}\left(\vec{x}\right)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} \tag{12.5} \]

and ensures that \(\vec{b}\) is normalized. Substituting our values from the example above:

\[ \vec{b} = \mathrm{softmax}\left( \left[ \begin{array}{rrr} 0 & 0 & 0\\ 2 & 0 & 1\\ 1 & -1 & -2\\ 2 & 3 & 1\\ -2 & 0 & 0\\ 0 & 2 & 1 \end{array}\right] \left[ \begin{array}{r} 0\\ 2\\ 1 \end{array}\right] \right) = \mathrm{softmax}\left( \left[0, 1, -4, 7, 0, 5\right]\right) \tag{12.6} \]
\[ \vec{b} \approx \left[0, 0, 0, 0.88, 0, 0.12\right] \tag{12.7} \]

I’ve rounded the numbers here, but essentially the attention vector only gives weight to the word itself (“book”) and the verb “reads”. Remember that I made up these embeddings, but the example gives you an idea of how attention provides a way to connect words. It may even remind you of the idea of neighbors in graph neural networks.
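The arithmetic in (12.6) and (12.7) can be reproduced with a few lines of NumPy. This minimal sketch stores the keys with one row per word, i.e. shape \((6, 3)\), so the dot product with the query is `K @ q`.

```python
import numpy as np

# keys for "The sleepy child reads a book": one row per word, shape (6, 3)
K = np.array([
    [0, 0, 0],    # The
    [2, 0, 1],    # sleepy
    [1, -1, -2],  # child
    [2, 3, 1],    # reads
    [-2, 0, 0],   # a
    [0, 2, 1],    # book
], dtype=float)
q = np.array([0, 2, 1], dtype=float)  # query: the embedding of "book"

logits = K @ q                               # sum_j q_j k_ij, values [0, 1, -4, 7, 0, 5]
b = np.exp(logits) / np.sum(np.exp(logits))  # softmax, eq. (12.5)

print(logits)
print(np.round(b, 2))  # approximately [0, 0, 0, 0.88, 0, 0.12]
```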

12.3. Attention Reduction

After computing the attention vector \(\vec{b}\), we use it to compute a weighted mean over the values:

\[ \vec{b}\cdot\mathbf{V} = \sum_i b_i v_i = \left[0, 0, 0, 0.88, 0, 0.12\right] \cdot \left[ 0, -0.2, 0.3, 0.4, 0, 0.1\right] \approx 0.36 \tag{12.8} \]

Conceptually, our example computed the attention-weighted sentiment of the query word “book” in our sentence. You can see that attention layers do two things: compute an attention vector with the attention mechanism and then use it to take the attention-weighted average over the values.
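In the same NumPy setting, the reduction in (12.8) is a single dot product between the attention vector and the values. In this sketch, the unrounded attention vector gives about 0.362, while the rounded vector from (12.7) gives 0.364. Both round to 0.36.

```python
import numpy as np

K = np.array([[0, 0, 0], [2, 0, 1], [1, -1, -2],
              [2, 3, 1], [-2, 0, 0], [0, 2, 1]], dtype=float)  # keys, one row per word
q = np.array([0, 2, 1], dtype=float)                            # query: "book"
V = np.array([0, -0.2, 0.3, 0.4, 0, 0.1])                       # one value per word

logits = K @ q
b = np.exp(logits) / np.sum(np.exp(logits))  # attention vector before rounding

out_exact = b @ V                                      # attention-weighted mean of the values
out_rounded = np.array([0, 0, 0, 0.88, 0, 0.12]) @ V  # 0.88 * 0.4 + 0.12 * 0.1
print(out_exact, out_rounded)  # approximately 0.362 and 0.364
```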

12.4. Tensor-Dot

This dot product, softmax, and reduction is called a tensor-dot and is the most common attention layer [LPM15]. One common modification is to divide the dot product by the square root of the dimension of the keys (the size of their last axis). Remember that the keys are not normalized. If their elements are random numbers, the magnitude of the dot product grows with the square root of the key dimension, due to the central limit theorem. Large dot products make the softmax behave poorly, since you’re taking \(e^{\vec{q} \cdot \mathbf{K}}\). The largest element dominates, the attention vector approaches a one-hot vector, and its gradients become very small. Putting this all together, the equation is:

\[ \vec{b} = \mathrm{softmax}\left(\frac{1}{\sqrt{d}}\vec{q}\cdot \mathbf{K}\right) \tag{12.9} \]

where \(d\) is the dimension of the query vector.
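Here is a quick numerical check of the square-root scaling. It is a sketch that assumes query and key elements are independent standard normal random numbers; that is an assumption for illustration, not a property of trained embeddings. The spread of the raw dot products grows like \(\sqrt{d}\), while the scaled version stays near 1.

```python
import numpy as np

rng = np.random.default_rng(42)
n_samples = 2000

results = {}
for d in (4, 64, 1024):
    q = rng.normal(size=(n_samples, d))
    k = rng.normal(size=(n_samples, d))
    dots = np.einsum("nd,nd->n", q, k)  # n_samples independent dot products
    scaled = dots / np.sqrt(d)          # the 1/sqrt(d) factor in (12.9)
    results[d] = (dots.std(), scaled.std())
    print(f"d={d:5d}  std(q.k)={dots.std():7.2f}  std(q.k/sqrt(d))={scaled.std():.2f}")
```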

12.5. Soft, Hard, and Temperature Attention

One possible change to attention is to replace the \(\mathrm{softmax}\) with a one at the position of highest attention and zero at all others. This is called hard attention. The equation for hard attention is to replace softmax with a “hardmax”, defined as

\[ \mathrm{hardmax}\left(\vec{x}\right)_i = \lim_{T\rightarrow 0}\frac{e^{x_i / T}}{\sum_j e^{x_j / T}} \tag{12.10} \]

which is a mathematical way to formulate putting \(1\) in the position of the largest element of \(\vec{x}\) and \(0\) at all others. The symbol \(T\) stands for temperature, because this equation is similar to the Boltzmann distribution from statistical mechanics. In the limit \(T \rightarrow 0\) you get hard attention, \(T = 1\) gives the usual soft attention, and \(T \rightarrow \infty\) gives uniform attention. You could tune \(T\) to intermediate values as well.
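You can see the effect of the temperature by applying a softmax to \(\vec{x}/T\) for several values of \(T\). In this sketch, the helper name `softmax_with_temperature` is hypothetical, and the logits are the \(\vec{q}\cdot\mathbf{K}\) values from the worked example.

```python
import numpy as np


def softmax_with_temperature(x, T):
    """Hypothetical helper: softmax of x / T, where T is the temperature."""
    z = np.asarray(x, dtype=float) / T
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


logits = np.array([0.0, 1.0, -4.0, 7.0, 0.0, 5.0])  # q . K from the worked example

results = {T: softmax_with_temperature(logits, T) for T in (0.01, 1.0, 1e6)}
for T, b in results.items():
    print(f"T={T:g}:", np.round(b, 3))
# small T approaches hardmax (one-hot at index 3); large T approaches uniform 1/6
```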

12.6. Self-Attention

Remember how everything is batched in deep learning? The batched input to an attention layer is usually the query. In the discussion above, the query was a tensor one rank less than the keys (typically a vector). Once it has been batched, it has the same rank as the keys. Very often, the query is in fact equal to the keys. In our example, the query was the embedding vector for the word “book”, which is one of the keys. If you batch the query so that it includes every word in the sentence, the query becomes equal to the keys. A further special case is when the query, values, and keys are all equal. This is called self-attention. It just means that our attention mechanism uses the values directly and there is no extra set of “keys” input to the layer.

12.7. Trainable Attention

There are no trainable parameters in our definitions above. How can you do learning with attention? Typically, the attention equations themselves have no trainable parameters. Instead, you put the keys, values, and query through a dense layer (see Standard Layers) before the attention. So when viewed as a layer, attention has no trainable parameters. When viewed as a block with a dense layer and an attention layer, it is trainable. We’ll see this explicitly in the code examples below.

12.8. Multi-head Attention Block

Inspired by the idea of convolutions with multiple filters, there is a block (group of layers) that splits into multiple parallel attention layers. This is called “multi-head attention”. If your values have shape \((L, V)\) and the query is a single vector, you will get back an \((H, V)\) tensor, where \(H\) is the number of parallel attention layers (heads). If there are no trainable parameters in attention layers, what’s the point of this, though? Without weights, every head would compute exactly the same output, so each head must get its own trainable weights. The weights must give every head outputs of the same shape so that the heads can be combined. Square weight matrices are a simple choice because they keep all shapes unchanged.

Consider an attention layer to be defined by \(A(\vec{q}, \mathbf{K}, \mathbf{V})\). The multi-head attention is

\[ \left[A(\mathbf{W}_q^1\vec{q}, \mathbf{W}_k^1\mathbf{K}, \mathbf{W}_v^1\mathbf{V}), A(\mathbf{W}_q^2\vec{q}, \mathbf{W}_k^2\mathbf{K}, \mathbf{W}_v^2\mathbf{V}), \ldots, A(\mathbf{W}_q^H\vec{q}, \mathbf{W}_k^H\mathbf{K}, \mathbf{W}_v^H\mathbf{V})\right] \tag{12.11} \]

where each element of \([\ldots]\) is itself the output of an attention layer, one per head. With a single query vector, each head returns a vector of length \(V\), so the whole output is an \((H, V)\) tensor. In self-attention, the query is batched over all \(L\) positions. Each head then returns an \((L, V)\) tensor, and the whole output is an \((H, L, V)\) tensor. The most famous example of the multi-head attention block is in transformers [VSP+17], which use multi-head self-attention blocks.

Typically we apply multiple sequential blocks of attention, so we need the values input to the next block to be rank 2 again (instead of the rank 3 \((H, L, V)\) tensor). Thus the output from the multi-head attention is often reduced by multiplication with an \((H, V, V)\) weight tensor or an \((H)\) vector of weights, so that you get back a rank 2 tensor. If this seems confusing, see the code example below.
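Here is a minimal NumPy sketch of a multi-head self-attention block with an \((H, V, V)\) reduction. It makes a few assumptions. It uses the row-vector convention `x @ W`, which is equivalent to the left-multiplied \(\mathbf{W}\) in (12.11) up to a transpose. The weights are square \((D, D)\) matrices, so \(V = D\). Random weights stand in for trained ones.

```python
import numpy as np

rng = np.random.default_rng(0)
L, D, H = 5, 4, 3  # sequence length, feature dimension (V = D here), number of heads


def attention(q, k, v):
    """Tensor-dot attention with batched queries: q (N, D), k (L, D), v (L, V) -> (N, V)."""
    logits = q @ k.T / np.sqrt(q.shape[-1])              # (N, L)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    b = np.exp(logits)
    b = b / b.sum(axis=1, keepdims=True)                 # each row sums to 1
    return b @ v                                         # (N, V)


x = rng.normal(size=(L, D))  # self-attention: query, keys, and values all come from x

# one set of square (D, D) weights per head
W_q = rng.normal(size=(H, D, D))
W_k = rng.normal(size=(H, D, D))
W_v = rng.normal(size=(H, D, D))

# run the H heads in parallel and stack them: (H, L, V)
heads = np.stack([attention(x @ W_q[h], x @ W_k[h], x @ W_v[h]) for h in range(H)])
print(heads.shape)  # (3, 5, 4)

# reduce over the head axis with an (H, V, V) weight tensor to get back to rank 2
W_o = rng.normal(size=(H, D, D))
out = np.einsum("hlv,hvw->lw", heads, W_o)
print(out.shape)  # (5, 4)
```

The einsum is the same as summing `heads[h] @ W_o[h]` over the heads. The rank 2 `out` can then be fed into the next attention block.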


12.10. Code Examples

Let’s see how attention can be implemented in code. I will use random values for the different quantities. Variables that should be trained are prefixed with w_, and variables that are inputs are prefixed with i_.

12.10.1. Tensor-Dot Mechanism

We’ll begin by implementing the tensor-dot attention mechanism. As an example, we’ll use a sequence length of 11, a key feature dimension of 4, and a value feature dimension of 2. Remember that the keys and query must have the same feature dimension.

import numpy as np


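def softmax(x, axis=None):
    # subtract the max for numerical stability; keepdims=True makes the
    # division broadcast over the correct axis when axis is given
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)


def tensor_dot(q, k):
    # scale by the square root of the feature dimension (last axis)
    b = softmax((k @ q) / np.sqrt(q.shape[-1]))
    return b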


i_query = np.random.normal(size=(4,))
i_keys = np.random.normal(size=(11, 4))

b = tensor_dot(i_query, i_keys)
print("b = ", b)
b =  [0.10708407 0.2159088  0.18910064 0.02449849 0.09297742 0.04962959
 0.02203747 0.06559855 0.11954705 0.04353657 0.07008136]

As expected, we get a vector \(\vec{b}\) whose elements sum to 1.

12.10.2. General Attention

Now let’s put this attention mechanism into an attention layer.

def attention_layer(q, k, v):
    b = tensor_dot(q, k)
    return b @ v


i_values = np.random.normal(size=(11, 2))
attention_layer(i_query, i_keys, i_values)
array([0.14328042, 0.0463141 ])

We get two values, one for each feature dimension.

12.10.3. Self-attention

In self-attention, the query, keys, and values are equal. The query is now batched (one query per position in the sequence). We therefore need a batched version of the tensor-dot, and we should get a rank 2 output.

def batched_tensor_dot(q, k):
    # a will be batch x seq, which is N x N here:
    # one row of attention logits per query
    # batched dot product in Einstein notation, scaled by sqrt(feature dim)
    a = np.einsum("ij,kj->ik", q, k) / np.sqrt(q.shape[-1])
    # now we softmax over the sequence axis
    b = softmax(a, axis=1)
    return b


def self_attention(x):
    b = batched_tensor_dot(x, x)
    return b @ x


i_batched_query = np.random.normal(size=(11, 4))
self_attention(i_batched_query).shape
(11, 4)

The output is an \(11\times4\) matrix, one 4-dimensional vector for each of the 11 positions, which is correct.

12.10.4. Adding Trainable Parameters

You can add trainable parameters to these steps by adding weight matrices. Let’s do this for self-attention. Although the keys, values, and query are equal in self-attention, I can multiply them by different weights. Just to demonstrate, I’ll have the values change to feature dimension 2.

# weights should be input feature_dim -> desired output feature_dim
w_q = np.random.normal(size=(4, 4))
w_k = np.random.normal(size=(4, 4))
w_v = np.random.normal(size=(4, 2))


def trainable_self_attention(x, w_q, w_k, w_v):
    q = x @ w_q
    k = x @ w_k
    v = x @ w_v
    b = batched_tensor_dot(q, k)
    return b @ v


trainable_self_attention(i_batched_query, w_q, w_k, w_v).shape
(11, 2)

Because the value weights change the feature dimension to 2, we get an \(11\times 2\) output.

12.10.5. Multi-head

The only change for multi-head attention is that we have one set of weights for each head, and we decide how to combine the heads’ outputs after running through them. I’ll just use a length \(H\) vector of trainable weights. Other strategies are to concatenate the outputs or to use a reduction (e.g., mean, max).

w_q_h1 = np.random.normal(size=(4, 4))
w_k_h1 = np.random.normal(size=(4, 4))
w_v_h1 = np.random.normal(size=(4, 2))
w_q_h2 = np.random.normal(size=(4, 4))
w_k_h2 = np.random.normal(size=(4, 4))
w_v_h2 = np.random.normal(size=(4, 2))
w_h = np.random.normal(size=2)


def multihead_attention(x, w_q_h1, w_k_h1, w_v_h1, w_q_h2, w_k_h2, w_v_h2, w_h):
    h1_out = trainable_self_attention(x, w_q_h1, w_k_h1, w_v_h1)
    h2_out = trainable_self_attention(x, w_q_h2, w_k_h2, w_v_h2)
    # stack the heads along a new last axis: shape (11, 2, H)
    all_h = np.stack((h1_out, h2_out), -1)
    # weighted sum over the head axis with the length-H vector w_h: shape (11, 2)
    return all_h @ w_h


multihead_attention(
    i_batched_query, w_q_h1, w_k_h1, w_v_h1, w_q_h2, w_k_h2, w_v_h2, w_h
).shape
(11, 2)

As expected, we do get an \(11\times2\) rank 2 output.

12.11. Attention in Graph Neural Networks

Recall that the key attribute of a graph neural network is permutation equivariance. We used reductions like sum or mean over neighbors to make graph neural network layers permutation equivariant. Attention layers have the same kind of symmetry. For a single query, they are permutation invariant with respect to reordering the key/value pairs. With a batch of queries, they are permutation equivariant with respect to reordering the queries. This has made attention a popular choice for aggregating neighbor information. Attention layers are good at finding important neighbors, so they are especially useful in high-degree graphs (nodes with lots of neighbors). High-degree nodes are rare in molecules, but you can define all atoms to be connected and then use distances as the edge attributes. Recall that graph convolution layers (GCN layers), and most GNN layers, only allow information to propagate one bond per layer. Thus connecting all atoms and using attention can give you long-range communication without so many layers. The disadvantage is that your network must now learn to give attention to the correct bonds/atoms.
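Both symmetry properties can be checked numerically. This sketch uses a tensor-dot attention function written just for this check (not from a library) and random inputs. It permutes the key/value pairs together, then permutes a batch of queries, and finally permutes the input to self-attention.

```python
import numpy as np

rng = np.random.default_rng(1)


def attention(q, k, v):
    """Tensor-dot attention; q is one query (D,) or a batch of queries (N, D)."""
    logits = q @ k.T / np.sqrt(k.shape[-1])
    logits = logits - logits.max(axis=-1, keepdims=True)
    b = np.exp(logits)
    b = b / b.sum(axis=-1, keepdims=True)  # softmax over the key axis
    return b @ v


q = rng.normal(size=(4,))
k = rng.normal(size=(7, 4))
v = rng.normal(size=(7, 2))
perm = rng.permutation(7)

# invariance: reordering the key/value pairs together does not change the output
invariant = np.allclose(attention(q, k, v), attention(q, k[perm], v[perm]))

# equivariance: reordering a batch of queries reorders the outputs the same way
Q = rng.normal(size=(7, 4))
equivariant = np.allclose(attention(Q, k, v)[perm], attention(Q[perm], k, v))

# self-attention combines both: permuting the input permutes the output
x = rng.normal(size=(7, 4))
self_equivariant = np.allclose(attention(x, x, x)[perm],
                               attention(x[perm], x[perm], x[perm]))

print(invariant, equivariant, self_equivariant)  # True True True
```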

Let’s see how attention fits into the Battaglia equations [BHB+18]. Recall that the Battaglia equations are general, standard equations for defining a GNN. Attention can appear in multiple places, but as discussed above it appears when considering neighbors. Specifically, the query will be the \(i\)th node, and the keys/values will be some combination of neighboring node and edge features. There is no step in the Battaglia equations where this fits neatly, but we can split up the attention layer as follows. Most of the attention layer fits into the edge update equation:

\[ \vec{e}^{'}_k = \phi^e\left( \vec{e}_k, \vec{v}_{rk}, \vec{v}_{sk}, \vec{u}\right) \tag{12.12} \]

Recall that this is a general equation and our choice of \(\phi^e()\) defines the GNN. \(\vec{e}_k\) is the feature vector of edge \(k\), \(\vec{v}_{rk}\) is the receiving node feature vector for edge \(k\), \(\vec{v}_{sk}\) is the sending node feature vector for edge \(k\), and \(\vec{u}\) is the global graph feature vector. We will use this step for the attention mechanism, where the query is the receiving node \(\vec{v}_{rk}\) and the keys/values are composed of the sender and edge vectors. To be specific, we’ll use the approach from Zhang et al. [ZSX+18] with a tensor-dot mechanism. They only considered node features and set the keys and values to be the node features. However, they put trainable parameters at each layer that transform the node features into the query, keys, and values.

\[ \vec{q} = \mathbf{W}_q\vec{v}_{rk} \tag{12.13} \]
\[ \mathbf{K} = \mathbf{W}_k\vec{v}_{sk} \tag{12.14} \]
\[ \mathbf{V} = \mathbf{W}_v\vec{v}_{sk} \tag{12.15} \]
\[ \vec{b}_k = \mathrm{softmax}\left(\frac{1}{\sqrt{d}} \vec{q}\cdot \mathbf{K}\right) \tag{12.16} \]
\[ \vec{e}^{'}_k = \vec{b}_k \mathbf{V} \tag{12.17} \]

Putting it compactly into one equation:

\[ \vec{e}^{'}_k = \mathrm{softmax}\left(\frac{1}{\sqrt{d}} \mathbf{W}_q\vec{v}_{rk}\cdot \mathbf{W}_k\vec{v}_{sk}\right)\mathbf{W}_v\vec{v}_{sk} \tag{12.18} \]

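Now we have weighted edge feature vectors from the attention. Note that the softmax must normalize over all edges \(k\) that share the same receiving node, so it needs information from more than a single edge. This is part of why attention does not fit neatly into one step. Finally, we sum over these edge features in the edge aggregation step.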

\[ \bar{e}^{'}_i = \rho^{e\rightarrow v}\left( E_i^{'}\right) = \sum E_i^{'} \tag{12.19} \]
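Here is a minimal NumPy sketch of (12.13) to (12.19) for a single receiving node with three incoming edges. It follows the description of Zhang et al. above and uses only node features. The weights are random stand-ins for trained parameters. As an assumption about where normalization happens, the softmax runs over the edges that share the receiving node.

```python
import numpy as np

rng = np.random.default_rng(2)
n_feat, d, n_edges = 5, 4, 3  # node feature size, attention feature size, edges into node i

v_r = rng.normal(size=(n_feat,))          # receiving node i
v_s = rng.normal(size=(n_edges, n_feat))  # sending node for each incoming edge k
W_q = rng.normal(size=(d, n_feat))
W_k = rng.normal(size=(d, n_feat))
W_v = rng.normal(size=(d, n_feat))

q = W_q @ v_r    # (12.13): shape (d,)
K = v_s @ W_k.T  # (12.14): row k is W_k v_sk, shape (n_edges, d)
V = v_s @ W_v.T  # (12.15): row k is W_v v_sk, shape (n_edges, d)

scores = K @ q / np.sqrt(d)  # one scalar per edge
b = np.exp(scores - scores.max())
b = b / b.sum()              # (12.16), normalized over the edges into node i

e_prime = b[:, None] * V     # (12.17): weighted edge features, shape (n_edges, d)
e_bar = e_prime.sum(axis=0)  # (12.19): sum aggregation onto node i, shape (d,)
print(b, e_bar)
```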

In Zhang et al. [ZSX+18], they used multi-headed attention as well. How would multi-headed attention work? Your edge feature matrix \( E_i^{'}\) now becomes an edge feature tensor, where axis 0 is the edge (\(k\)), axis 1 is the feature, and axis 2 is the head. Recall that the “head” just means which set of \(\mathbf{W}^h_q, \mathbf{W}^h_k, \mathbf{W}^h_v\) we used. To reduce the tensor back to the expected matrix, we use a weight tensor that maps the last two axes (feature, head) down to features only. I will write out the indices explicitly to be clearer:

\[ \bar{e}^{'}_{il} = \rho^{e\rightarrow v}\left( E_i^{'}\right) = \sum_{k,j,h} e_{ikjh}^{'}w_{jhl} \tag{12.20} \]

where \(j\) is the input edge feature index, \(l\) is the output edge feature index, and \(k,h,i\) are defined as before. Transformer is another name for a network built on multi-headed attention, so you’ll also see graph neural networks built this way called transformers, such as the molecule attention transformer [MDM+20].
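The index bookkeeping in (12.20) maps directly onto `np.einsum`. This sketch fixes a single receiving node \(i\) and uses random numbers for the edge feature tensor and the weight tensor. It then checks the einsum against explicit loops.

```python
import numpy as np

rng = np.random.default_rng(3)
n_edges, n_feat, n_heads, n_out = 3, 4, 2, 5

# edge features for one receiving node i: axes (edge k, feature j, head h)
E_i = rng.normal(size=(n_edges, n_feat, n_heads))
# weight tensor w_{jhl} mapping (feature, head) -> output feature l
w = rng.normal(size=(n_feat, n_heads, n_out))

# (12.20): sum over edges k, features j, and heads h
e_bar = np.einsum("kjh,jhl->l", E_i, w)

# the same reduction written as explicit loops, to check the index bookkeeping
check = np.zeros(n_out)
for k in range(n_edges):
    for j in range(n_feat):
        for h in range(n_heads):
            check += E_i[k, j, h] * w[j, h, :]

print(e_bar.shape, np.allclose(e_bar, check))  # (5,) True
```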

12.12. Chapter Summary

  • Attention layers are inspired by human ideas of attention, but are fundamentally a weighted mean reduction.

  • The attention layer takes in three inputs: the query, the values, and the keys. These inputs are often related: the query is often one of the keys, and the keys and values are often equal.

  • They are good at modeling sequences, such as language.

  • The attention vector should be normalized, which can be achieved using a softmax activation function, but the choice of attention mechanism equation is a hyperparameter.

  • Attention layers compute an attention vector with the attention mechanism, and then use it to compute the attention-weighted average of the values.

  • Hard attention (the hardmax function) puts all of the weight on the key with the largest attention score, so the output is the value at that position.

  • The tensor-dot followed by a softmax is the most common attention mechanism.

  • Self-attention is achieved when the query, values, and keys are equal.

  • Attention layers by themselves are not trainable.

  • A multi-head attention block is a group of layers that splits into multiple parallel attention layers (heads).

12.13. Cited References

BHB+18

Peter W Battaglia, Jessica B Hamrick, Victor Bapst, Alvaro Sanchez-Gonzalez, Vinicius Zambaldi, Mateusz Malinowski, Andrea Tacchetti, David Raposo, Adam Santoro, Ryan Faulkner, and others. Relational inductive biases, deep learning, and graph networks. arXiv preprint arXiv:1806.01261, 2018.

ZSX+18

Jiani Zhang, Xingjian Shi, Junyuan Xie, Hao Ma, Irwin King, and Dit-Yan Yeung. GaAN: gated attention networks for learning on large and spatiotemporal graphs. arXiv preprint arXiv:1803.07294, 2018.

BP97

Shumeet Baluja and Dean A. Pomerleau. Expectation-based selective attention for visual monitoring and control of a robot vehicle. Robotics and Autonomous Systems, 22(3):329–344, 1997. Robot Learning: The New Wave. URL: http://www.sciencedirect.com/science/article/pii/S0921889097000468, doi:10.1016/S0921-8890(97)00046-8.

TG80

Anne M Treisman and Garry Gelade. A feature-integration theory of attention. Cognitive psychology, 12(1):97–136, 1980.

LPM15

Minh-Thang Luong, Hieu Pham, and Christopher D Manning. Effective approaches to attention-based neural machine translation. arXiv preprint arXiv:1508.04025, 2015.

VSP+17

Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in neural information processing systems, 5998–6008. 2017.

MDM+20

Łukasz Maziarka, Tomasz Danel, Sławomir Mucha, Krzysztof Rataj, Jacek Tabor, and Stanisław Jastrzębski. Molecule attention transformer. arXiv preprint arXiv:2002.08264, 2020.