{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Graph Neural Networks\n",
"\n",
"Historically, the biggest difficulty for machine learning with molecules was the choice and computation of \"descriptors\". Graph neural networks (GNNs) are a category of deep neural networks whose inputs are graphs and provide a way around the choice of descriptors. A GNN can take a molecule directly as input.\n",
"\n",
"\n",
"```{admonition} Audience & Objectives\n",
"This chapter builds on {doc}`layers` and {doc}`../ml/regression`. Although it is defined here, it would be good to be familiarize yourself with graphs/networks. After completing this chapter, you should be able to \n",
"\n",
" * Represent a molecule in a graph\n",
" * Discuss and categorize common graph neural network architectures\n",
" * Build a GNN and choose a read-out function for the type of labels\n",
" * Distinguish between graph, edge, and node features\n",
" * Formulate a GNN into edge-updates, node-updates, and aggregation steps\n",
"```\n",
"\n",
"GNNs are specific layers that input a graph and output a graph. You can find reviews of GNNs in Dwivedi *et al.*{cite}`dwivedi2020benchmarking`, Bronstein *et al.*{cite}`bronstein2017geometric`, and Wu *et al.*{cite}`wu2020comprehensive`. GNNs can be used for everything from coarse-grained molecular dynamics {cite}`li2020graph` to predicting NMR chemical shifts {cite}`yang2020predicting` to modeling dynamics of solids {cite}`xie2019graph`. Before we dive too deep into them, we must first understand how a graph is represented in a computer and how molecules are converted into graphs. \n",
"\n",
"You can find an interactive introductory article on graphs and graph neural networks at [distill.pub](https://distill.pub/2021/gnn-intro/) {cite}`sanchez-lengeling2021a`. Most current research in GNNs is done with specialized deep learning libraries for graphs. As of 2022, the most common are [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/), [Deep Graph library](https://www.dgl.ai/), [DIG](https://github.com/divelab/DIG), [Spektral](https://graphneural.network/), and [TensorFlow GNNS](https://github.com/tensorflow/gnn)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Representing a Graph\n",
"\n",
"A graph $\\mathbf{G}$ is a set of nodes $\\mathbf{V}$ and edges $\\mathbf{E}$. In our setting, node $i$ is defined by a vector $\\vec{v}_i$, so that the set of nodes can be written as a rank 2 tensor. The edges can be represented as an adjacency matrix $\\mathbf{E}$, where if $e_{ij} = 1$ then nodes $i$ and $j$ are connected by an edge. In many fields, graphs are often immediately simplified to be directed and acyclic, which simplifies things. Molecules are instead undirected and have cycles (rings). Thus, our adjacency matrices are always symmetric $e_{ij} = e_{ji}$ because there is no concept of direction in chemical bonds. Often our edges themselves have features, so that $e_{ij}$ is itself a vector. Then the adjacency matrix becomes a rank 3 tensor. Examples of edge features might be covalent bond order or distance between two nodes.\n",
"\n",
"\n",
"```{figure} ./methanol.jpg\n",
"----\n",
"name: methanol\n",
"width: 400px\n",
"----\n",
"Methanol with atoms numbered so that we can convert it to a graph. \n",
"```\n",
"\n",
"```{margin} one-hot\n",
"Recall that a one-hot is a vector of\n",
"all 0s and a single 1 - `[0, 1, 0, 0]`. The \n",
"index of the non-zero element indictates the class.\n",
"In this case, class is element. \n",
"```\n",
"\n",
"Let's see how a graph can be constructed from a molecule. Consider methanol, shown in {numref}`methanol`. I've numbered the atoms so that we have an order for defining the nodes/edges. First, the node features. You can use anything for node features, but often we'll begin with one-hot encoded feature vectors:\n",
"\n",
"| Node | C | H | O |\n",
"|:-----|----|----|---:|\n",
"| 1 | 0 | 1 | 0 |\n",
"| 2 | 0 | 1 | 0 |\n",
"| 3 | 0 | 1 | 0 |\n",
"| 4 | 1 | 0 | 0 |\n",
"| 5 | 0 | 0 | 1 |\n",
"| 6 | 0 | 1 | 0 |\n",
"\n",
"$\\mathbf{V}$ will be the combined feature vectors of these nodes. The adjacency matrix $\\mathbf{E}$ will look like:\n",
"\n",
"\n",
"| | 1 | 2 | 3 | 4 | 5 | 6 | \n",
"|:---|----|----|----|----|----|---:|\n",
"| 1 | 0 | 0 | 0 | 1 | 0 | 0 |\n",
"| 2 | 0 | 0 | 0 | 1 | 0 | 0 |\n",
"| 3 | 0 | 0 | 0 | 1 | 0 | 0 |\n",
"| 4 | 1 | 1 | 1 | 0 | 1 | 0 |\n",
"| 5 | 0 | 0 | 0 | 1 | 0 | 1 |\n",
"| 6 | 0 | 0 | 0 | 0 | 1 | 0 |\n",
"\n",
"\n",
"Take a moment to understand these two. For example, notice that rows 1, 2, and 3 only have the 4th column as non-zero. That's because atoms 1-3 are bonded only to carbon (atom 4). Also, the diagonal is always 0 because atoms cannot be bonded with themselves. \n",
"\n",
"You can find a similar process for converting crystals into graphs in Xie et al. {cite}`Xie2018Crystal`. We'll now begin with a function which can convert a smiles string into this representation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running This Notebook\n",
"\n",
"\n",
"Click the above to launch this page as an interactive Google Colab. See details below on installing packages.\n",
"\n",
"````{tip} My title\n",
":class: dropdown\n",
"To install packages, execute this code in a new cell. \n",
"\n",
"```\n",
"!pip install dmol-book\n",
"```\n",
"\n",
"If you find install problems, you can get the latest working versions of packages used in [this book here](https://github.com/whitead/dmol-book/blob/master/package/setup.py)\n",
"\n",
"````"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import pandas as pd\n",
"import rdkit, rdkit.Chem, rdkit.Chem.rdDepictor, rdkit.Chem.Draw\n",
"import networkx as nx\n",
"import dmol"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"soldata = pd.read_csv(\n",
" \"https://github.com/whitead/dmol-book/raw/master/data/curated-solubility-dataset.csv\"\n",
")\n",
"np.random.seed(0)\n",
"my_elements = {6: \"C\", 8: \"O\", 1: \"H\"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The hidden cell below defines our function `smiles2graph`. This creates one-hot node feature vectors for the element C, H, and O. It also creates an adjacency tensor with one-hot bond order being the feature vector."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"def smiles2graph(sml):\n",
" \"\"\"Argument for the RD2NX function should be a valid SMILES sequence\n",
" returns: the graph\n",
" \"\"\"\n",
" m = rdkit.Chem.MolFromSmiles(sml)\n",
" m = rdkit.Chem.AddHs(m)\n",
" order_string = {\n",
" rdkit.Chem.rdchem.BondType.SINGLE: 1,\n",
" rdkit.Chem.rdchem.BondType.DOUBLE: 2,\n",
" rdkit.Chem.rdchem.BondType.TRIPLE: 3,\n",
" rdkit.Chem.rdchem.BondType.AROMATIC: 4,\n",
" }\n",
" N = len(list(m.GetAtoms()))\n",
" nodes = np.zeros((N, len(my_elements)))\n",
" lookup = list(my_elements.keys())\n",
" for i in m.GetAtoms():\n",
" nodes[i.GetIdx(), lookup.index(i.GetAtomicNum())] = 1\n",
"\n",
" adj = np.zeros((N, N, 5))\n",
" for j in m.GetBonds():\n",
" u = min(j.GetBeginAtomIdx(), j.GetEndAtomIdx())\n",
" v = max(j.GetBeginAtomIdx(), j.GetEndAtomIdx())\n",
" order = j.GetBondType()\n",
" if order in order_string:\n",
" order = order_string[order]\n",
" else:\n",
" raise Warning(\"Ignoring bond order\" + order)\n",
" adj[u, v, order] = 1\n",
" adj[v, u, order] = 1\n",
" return nodes, adj"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nodes, adj = smiles2graph(\"CO\")\n",
"nodes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A Graph Neural Network\n",
"\n",
"A graph neural network (GNN) is a neural network with two defining attributes:\n",
"\n",
"1. Its input is a graph\n",
"2. Its output is permutation equivariant\n",
"\n",
"We can understand clearly the first point. Here, a graph permutation means re-ordering our nodes. In our methanol example above, we could have easily made the carbon be atom 1 instead of atom 4. Our new adjacency matrix would then be:\n",
"\n",
"| | 1 | 2 | 3 | 4 | 5 | 6 | \n",
"|:---|----|----|----|----|----|---:|\n",
"| 1 | 0 | 1 | 1 | 1 | 1 | 0 |\n",
"| 2 | 1 | 0 | 0 | 0 | 0 | 0 |\n",
"| 3 | 1 | 0 | 0 | 0 | 0 | 0 |\n",
"| 4 | 1 | 0 | 0 | 0 | 1 | 0 |\n",
"| 5 | 1 | 0 | 0 | 0 | 0 | 1 |\n",
"| 6 | 0 | 0 | 0 | 0 | 1 | 0 |\n",
"\n",
"\n",
"A GNN is permutation equivariant if the output change the same way as these exchanges. If you are trying to model a per-atom quantity like partial charge or chemical shift, this is obviously essential. If you change the order of atoms input, you would expect the order of their partial charges to similarly change. \n",
"\n",
"Often we want to model a whole-molecule property, like solubility or energy. This should be **invariant** to changing the order of the atoms. To make an equivariant model invariant, we use read-outs (defined below). See {doc}`data` for a more detailed discussion of equivariance."
]
},
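{
"cell_type": "markdown",
"metadata": {},
"source": [
"These definitions can be made concrete with a permutation matrix $\\mathbf{P}$: re-ordering the nodes maps $\\mathbf{V} \\rightarrow \\mathbf{P}\\mathbf{V}$ and $\\mathbf{E} \\rightarrow \\mathbf{P}\\mathbf{E}\\mathbf{P}^T$. The NumPy sketch below is only an illustration. It swaps atoms 1 and 4 of methanol so that carbon becomes atom 1, then checks that a simple neighbor sum, $f(\\mathbf{V}, \\mathbf{E}) = \\mathbf{E}\\mathbf{V}$, is permutation equivariant and that summing its output over all nodes is permutation invariant.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Methanol with 0-based indices: atoms H, H, H, C, O, H and one-hot (C, H, O)\n",
"V = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0]])\n",
"E = np.zeros((6, 6), dtype=int)\n",
"for i, j in [(0, 3), (1, 3), (2, 3), (3, 4), (4, 5)]:\n",
"    E[i, j] = E[j, i] = 1\n",
"\n",
"# Swap atoms 1 and 4 (indices 0 and 3): row k of P selects old node perm[k]\n",
"perm = [3, 1, 2, 0, 4, 5]\n",
"P = np.eye(6, dtype=int)[perm]\n",
"V_p = P @ V\n",
"E_p = P @ E @ P.T\n",
"\n",
"\n",
"def neighbor_sum(V, E):\n",
"    # each node receives the sum of its neighbors' features\n",
"    return E @ V\n",
"\n",
"\n",
"# Equivariant: permuting the input permutes the output rows in the same way\n",
"print(np.array_equal(neighbor_sum(V_p, E_p), P @ neighbor_sum(V, E)))\n",
"# Invariant: a sum over nodes gives the same graph-level vector\n",
"print(np.array_equal(neighbor_sum(V_p, E_p).sum(axis=0), neighbor_sum(V, E).sum(axis=0)))\n",
"print(E_p)  # adjacency matrix with carbon as atom 1\n",
"```"
]
},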
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### A simple GNN\n",
"\n",
"We will often mention a GNN when we really mean a layer from a GNN. Most GNNs implement a specific layer that can deal with graphs, and so usually we are only concerned with this layer. Let's see an example of a simple layer for a GNN:\n",
"\n",
"\\begin{equation}\n",
"f_k = \\sigma\\left( \\sum_i \\sum_j v_{ij}w_{jk} \\right)\n",
"\\end{equation}\n",
"\n",
"This equation shows that we first multiply every node ($v_{ij}$) feature by trainable weights $w_{jk}$, sum over all node features, and then apply an activation. This will yield a single feature vector for the graph. Is this equation permutation equivariant? Yes, because the node index in our expression is index $i$ which can be re-ordered without affecting the output.\n",
"\n",
"Let's see an example that is similar, but not permutation equivariant:\n",
"\n",
"\\begin{equation}\n",
"f_k = \\sigma\\left( \\sum_i v_{ij}w_{ik} \\right)\n",
"\\end{equation}\n",
"\n",
"This is a small change. We have one weight vector per node now. This makes the trainable weights depend on the ordering of the nodes. Then if we swap the node ordering, our weights will no longer align. So if we were to input two methanol molecules, which should have the same output, but we switched two atom numbers, we would get different answers. These simple examples differ from real GNNs in two important ways: (i) they give a single feature vector output, which throws away per-node information, and (ii) they do not use the adjacency matrix. Let's see a real GNN that has these properties while maintaining permutation equivariant."
]
},
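{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of the two simple layers above, with random node features and weights (the sizes, the `tanh` activation, and the function names are illustrative assumptions). The first shares one weight matrix across all nodes and sums over nodes, so re-ordering the nodes leaves its output unchanged. The second ties a separate weight matrix to each node position, so the same re-ordering changes its output.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n_nodes, n_feat, n_out = 6, 3, 4\n",
"v = rng.normal(size=(n_nodes, n_feat))  # v_ij: feature j of node i\n",
"w_shared = rng.normal(size=(n_feat, n_out))  # w_jk, shared by all nodes\n",
"w_per_node = rng.normal(size=(n_nodes, n_feat, n_out))  # one matrix per node\n",
"\n",
"\n",
"def f_shared(v):\n",
"    # f_k = sigma(sum_i sum_j v_ij w_jk)\n",
"    return np.tanh(np.einsum('ij,jk->k', v, w_shared))\n",
"\n",
"\n",
"def f_per_node(v):\n",
"    # the weights depend on the node position i, so node order matters\n",
"    return np.tanh(np.einsum('ij,ijk->k', v, w_per_node))\n",
"\n",
"\n",
"perm = [3, 1, 2, 0, 4, 5]  # swap nodes 1 and 4, as in the methanol example\n",
"print(np.allclose(f_shared(v), f_shared(v[perm])))  # invariant to node order\n",
"print(np.allclose(f_per_node(v), f_per_node(v[perm])))  # depends on node order\n",
"```"
]
},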
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Kipf & Welling GCN\n",
"\n",
"One of the first popular GNNs was the Kipf & Welling graph convolutional network (GCN) {cite}`kipf2016semi`. Although some people consider GCNs to be a broad class of GNNs, we'll use GCNs to refer specifically the Kipf & Welling GCN. \n",
"Thomas Kipf has written an [excellent article introducing the GCN](https://tkipf.github.io/graph-convolutional-networks/).\n",
"\n",
"The input to a GCN layer is $\\mathbf{V}$, $\\mathbf{E}$ and it outputs an updated $\\mathbf{V}'$. Each node feature vector is updated. The way it updates a node feature vector is by averaging the feature vectors of its neighbors, as determined by $\\mathbf{E}$. The choice of averaging over neighbors is what makes a GCN layer permutation equivariant. Averaging over neighbors is not trainable, so we must add trainable parameters. We multiply the neighbor features by a trainable matrix before the averaging, which gives the GCN the ability to learn. In Einstein notation, this process is:\n",
"\n",
"$$\n",
"v_{il} = \\sigma\\left(\\frac{1}{d_i}e_{ij}v_{jk}w_{lk}\\right)\n",
"$$ (gcn)\n",
"\n",
"where $i$ is the node we're considering, $j$ is the neighbor index, $k$ is the node input feature, $l$ is the output node feature, $d_i$ is the degree of node i (which makes it an average instead of sum), $e_{ij}$ isolates neighbors so that all non-neighbor $v_{jk}$s are zero, $\\sigma$ is our activation, and $w_{lk}$ is the trainable weights. This equation is a mouthful, but it truly just is the average over neighbors with a trainable matrix thrown in. One common modification is to make all nodes neighbors of themselves. This is so that the output node features $v_{il}$ depends on the input features $v_{ik}$. We do not need to change our equation, just make the adjacency matrix have $1$s on the diagonal instead of $0$ by adding the identity matrix during pre-processing.\n",
"\n",
"Building understanding about the GCN is important for understanding other GNNs. You can view the GCN layer as a way to \"communicate\" between a node and its neighbors. The output for node $i$ will depend only on its immediate neighbors. For chemistry, this is not satisfactory. You can stack multiple layers though. If you have two layers, the output for node $i$ will include information about node $i$'s neighbors' neighbors. Another important detail to understand in GCNs is that the averaging procedure accomplishes two goals: (i) it gives permutation equivariance by removing the effect of neighbor order and (ii) it prevents a change in magnitude in node features. A sum would accomplish (i) but would cause the magnitude of the node features to grow after each layer. Of course, you could ad-hoc put a batch normalization layer after each GCN layer to keep output magnitudes stable but averaging is easy."
]
},
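{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both points in the paragraph above can be checked with a few lines of NumPy. This sketch omits the weights and activation and uses the methanol graph with self-loops (atoms numbered as in the tables earlier in this chapter). After one averaging layer, a hydrogen bonded to carbon still has a zero oxygen feature; after two layers, oxygen information has reached it. Repeating a plain sum instead of an average makes the largest feature value grow with every layer, while the average stays bounded.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Methanol: atoms H, H, H, C, O, H with one-hot (C, H, O) features\n",
"V = np.array(\n",
"    [[0, 1, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=float\n",
")\n",
"A = np.eye(6)  # self-loops so each node keeps its own features\n",
"for i, j in [(0, 3), (1, 3), (2, 3), (3, 4), (4, 5)]:\n",
"    A[i, j] = A[j, i] = 1\n",
"deg = A.sum(axis=1)\n",
"\n",
"\n",
"def mean_layer(x):\n",
"    # average over neighbors (including self)\n",
"    return (A @ x) / deg[:, None]\n",
"\n",
"\n",
"def sum_layer(x):\n",
"    # plain sum over neighbors (including self)\n",
"    return A @ x\n",
"\n",
"\n",
"h1 = mean_layer(V)\n",
"h2 = mean_layer(h1)\n",
"print('O feature of atom 1 after one layer:', h1[0, 2])\n",
"print('O feature of atom 1 after two layers:', h2[0, 2])\n",
"\n",
"# Largest feature value after each of four layers: mean vs. sum\n",
"x_mean, x_sum = V, V\n",
"for layer in range(4):\n",
"    x_mean, x_sum = mean_layer(x_mean), sum_layer(x_sum)\n",
"    print(layer + 1, round(x_mean.max(), 2), x_sum.max())\n",
"```"
]
},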
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"# THIS CELL IS USED TO GENERATE A FIGURE\n",
"# AND NOT RELATED TO CHAPTER\n",
"# YOU CAN SKIP IT\n",
"from myst_nb import glue\n",
"from moviepy.editor import VideoClip\n",
"from moviepy.video.io.bindings import mplfig_to_npimage\n",
"\n",
"\n",
"def draw_vector(x, y, s, v, ax, cmap, **kwargs):\n",
" x += s / 2\n",
" y += s / 2\n",
" for vi in v:\n",
" if cmap is not None:\n",
" ax.add_patch(\n",
" mpl.patches.Rectangle((x, y), s * 1.5, s, facecolor=cmap(vi), **kwargs)\n",
" )\n",
" else:\n",
" ax.add_patch(\n",
" mpl.patches.Rectangle(\n",
" (x, y), s * 1.5, s, facecolor=\"#FFF\", edgecolor=\"#333\", **kwargs\n",
" )\n",
" )\n",
" ax.text(\n",
" x + s * 1.5 / 2,\n",
" y + s / 2,\n",
" \"{:.2f}\".format(vi),\n",
" verticalalignment=\"center\",\n",
" horizontalalignment=\"center\",\n",
" )\n",
" y += s\n",
"\n",
"\n",
"def draw_key(x, y, s, v, ax, cmap, **kwargs):\n",
" x += s / 2\n",
" y += s / 2\n",
" for vi in v:\n",
" ax.add_patch(\n",
" mpl.patches.Rectangle((x, y), s * 1.5, s, facecolor=cmap(1.0), **kwargs)\n",
" )\n",
" ax.text(\n",
" x + s * 1.5 / 2,\n",
" y + s / 2,\n",
" vi,\n",
" verticalalignment=\"center\",\n",
" horizontalalignment=\"center\",\n",
" )\n",
" y += s\n",
" ax.text(\n",
" x, y + s / 2, \"Key:\", verticalalignment=\"center\", horizontalalignment=\"left\"\n",
" )\n",
"\n",
"\n",
"def draw(\n",
" nodes, adj, ax, highlight=None, key=False, labels=None, mask=None, draw_nodes=None\n",
"):\n",
" G = nx.Graph()\n",
" for i in range(adj.shape[0]):\n",
" for j in range(adj.shape[0]):\n",
" if np.any(adj[i, j]):\n",
" G.add_edge(i, j)\n",
" if mask is None:\n",
" mask = [True] * len(G)\n",
" if draw_nodes is None:\n",
" draw_nodes = nodes\n",
" # go from atomic number to element\n",
" elements = np.argmax(draw_nodes, axis=-1)\n",
" el_labels = {i: list(my_elements.values())[e] for i, e in enumerate(elements)}\n",
" try:\n",
" pos = nx.nx_agraph.graphviz_layout(G, prog=\"sfdp\")\n",
" except ImportError:\n",
" pos = nx.spring_layout(G, iterations=100, seed=4, k=1)\n",
" pos = nx.rescale_layout_dict(pos)\n",
" c = [\"white\"] * len(G)\n",
" all_h = []\n",
" if highlight is not None:\n",
" for i, h in enumerate(highlight):\n",
" for hj in h:\n",
" c[hj] = \"C{}\".format(i + 1)\n",
" all_h.append(hj)\n",
" nx.draw(G, ax=ax, pos=pos, labels=el_labels, node_size=700, node_color=c)\n",
" cmap = plt.get_cmap(\"Wistia\")\n",
" for i in range(len(G)):\n",
" if not mask[i]:\n",
" continue\n",
" if i in all_h:\n",
" draw_vector(*pos[i], 0.15, nodes[i], ax, cmap)\n",
" else:\n",
" draw_vector(*pos[i], 0.15, nodes[i], ax, None)\n",
" if key:\n",
" draw_key(-1, -1, 0.15, my_elements.values(), ax, cmap)\n",
" if labels is not None:\n",
" legend_elements = []\n",
" for i, l in enumerate(labels):\n",
" p = mpl.lines.Line2D(\n",
" [0], [0], marker=\"o\", color=\"C{}\".format(i + 1), label=l, markersize=15\n",
" )\n",
" legend_elements.append(p)\n",
" ax.legend(handles=legend_elements)\n",
" ax.set_xlim(-1.2, 1.2)\n",
" ax.set_ylim(-1.2, 1.2)\n",
" ax.set_facecolor(\"#f5f4e9\")\n",
"\n",
"\n",
"fig = plt.figure(figsize=(8, 5))\n",
"draw(nodes, adj, plt.gca(), highlight=[[1], [5, 0]], labels=[\"center\", \"neighbors\"])\n",
"fig.set_facecolor(\"#f5f4e9\")\n",
"glue(\"dframe\", plt.gcf(), display=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"# THIS CELL IS USED TO GENERATE A FIGURE\n",
"# AND NOT RELATED TO CHAPTER\n",
"# YOU CAN SKIP IT\n",
"fig, axs = plt.subplots(1, 2, squeeze=True, figsize=(14, 6), dpi=100)\n",
"order = [5, 1, 0, 2, 3, 4]\n",
"time_per_node = 2\n",
"last_layer = [0]\n",
"layers = 2\n",
"input_nodes = np.copy(nodes)\n",
"fig.set_facecolor(\"#f5f4e9\")\n",
"\n",
"\n",
"def make_frame(t):\n",
" axs[0].clear()\n",
" axs[1].clear()\n",
"\n",
" layer_i = int(t / (time_per_node * len(order)))\n",
" axs[0].set_title(f\"Layer {layer_i + 1} Input\")\n",
" axs[1].set_title(f\"Layer {layer_i + 1} Output\")\n",
"\n",
" flat_adj = np.sum(adj, axis=-1)\n",
" out_nodes = np.einsum(\n",
" \"i,ij,jk->ik\",\n",
" 1 / (np.sum(flat_adj, axis=1) + 1),\n",
" flat_adj + np.eye(*flat_adj.shape),\n",
" nodes,\n",
" )\n",
"\n",
" if last_layer[0] != layer_i:\n",
" print(\"recomputing\")\n",
" nodes[:] = out_nodes\n",
" last_layer[0] = layer_i\n",
"\n",
" t -= layer_i * time_per_node * len(order)\n",
" i = order[int(t / time_per_node)]\n",
" print(last_layer, layer_i, i, t)\n",
" mask = [False] * nodes.shape[0]\n",
" for j in order[: int(t / time_per_node) + 1]:\n",
" mask[j] = True\n",
" print(mask, i)\n",
" neighs = list(np.where(adj[i])[0])\n",
" if (t - int(t / time_per_node) * time_per_node) >= time_per_node / 4:\n",
" draw(\n",
" nodes,\n",
" adj,\n",
" axs[0],\n",
" highlight=[[i], neighs],\n",
" labels=[\"center\", \"neighbors\"],\n",
" draw_nodes=input_nodes,\n",
" )\n",
" else:\n",
" draw(\n",
" nodes,\n",
" adj,\n",
" axs[0],\n",
" highlight=[[i]],\n",
" labels=[\"center\", \"neighbors\"],\n",
" draw_nodes=input_nodes,\n",
" )\n",
" if (t - int(t / time_per_node) * time_per_node) < time_per_node / 2:\n",
" mask[j] = False\n",
" draw(\n",
" out_nodes,\n",
" adj,\n",
" axs[1],\n",
" highlight=[[i]],\n",
" key=True,\n",
" mask=mask,\n",
" draw_nodes=input_nodes,\n",
" )\n",
" fig.set_facecolor(\"#f5f4e9\")\n",
" return mplfig_to_npimage(fig)\n",
"\n",
"\n",
"animation = VideoClip(make_frame, duration=time_per_node * nodes.shape[0] * layers)\n",
"\n",
"animation.write_gif(\"../_static/images/gcn.gif\", fps=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{glue:figure} dframe\n",
"----\n",
"name: dframe\n",
"----\n",
"Intermediate step of the graph convolution layer. The 3D vectors are the node features and start as one-hot, so a `[1.00, 0.00, 0.00]` means hydrogen. The center node will be updated by averaging its neighbors features.\n",
"```\n",
"\n",
"\n",
"To help understand the GCN layer, look at {numref}`dframe`. It shows an intermediate step of the GCN layer. Each node feature is represented here as a one-hot encoded vector at input. The animation in {numref}`gcnanim` shows the averaging process over neighbor features. To make this animation easy to follow, the trainable weights and activation functions are not considered. Note that the animation repeats for a second layer. Watch how the \"information\" about there being an oxygen atom in the molecule is propagated only after two layers to each atom. All GNNs operate with similar approaches, so try to understand how this animation works. \n",
"\n",
"\n",
"\n",
"```{figure} ../_static/images/gcn.gif\n",
"----\n",
"name: gcnanim\n",
"----\n",
"Animation of the graph convolution layer operation. The left is input, right is output node features. Note that two layers are shown (see title change). As the animation plays out, you can see how the information about the atoms propagates through the molecule via the averaging over neigbhors. So the oxygen goes from being just an oxygen, to an oxygen bonded to C and H, to an oxygen bonded to an H and CH3. The colors just reflect the same information in the numerical values.\n",
"```\n",
"\n",
"\n",
"### GCN Implementation\n",
"\n",
"Let's now create a tensor implementation of the GCN. We'll skip the activation and trainable weights for now.\n",
"We must first compute our rank 2 adjacency matrix. The `smiles2graph` code above computes an adjacency tensor with feature vectors. We can fix that with a simple reduction and add the identity at the same time\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nodes, adj = smiles2graph(\"CO\")\n",
"adj_mat = np.sum(adj, axis=-1) + np.eye(adj.shape[0])\n",
"adj_mat"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To compute degree of each node, we can do another reduction:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"degree = np.sum(adj_mat, axis=-1)\n",
"degree"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can put all these pieces together into the Einstein equation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(nodes[0])\n",
"# note to divide by degree, make the input 1 / degree\n",
"new_nodes = np.einsum(\"i,ij,jk->ik\", 1 / degree, adj_mat, nodes)\n",
"print(new_nodes[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To now implement this as a layer in Keras, we must put this code above into a new Layer subclass. The code is relatively straightforward, but you can read-up on the function names and Layer class in [this tutorial](https://keras.io/guides/making_new_layers_and_models_via_subclassing/). The three main changes are that we create trainable parameters `self.w` and use them in the {obj}`tf.einsum`, we use an activation `self.activation`, and we output both our new node features and the adjacency matrix. The reason to output the adjacency matrix is so that we can stack multiple GCN layers without having to pass the adjacency matrix each time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class GCNLayer(tf.keras.layers.Layer):\n",
" \"\"\"Implementation of GCN as layer\"\"\"\n",
"\n",
" def __init__(self, activation=None, **kwargs):\n",
" # constructor, which just calls super constructor\n",
" # and turns requested activation into a callable function\n",
" super(GCNLayer, self).__init__(**kwargs)\n",
" self.activation = tf.keras.activations.get(activation)\n",
"\n",
" def build(self, input_shape):\n",
" # create trainable weights\n",
" node_shape, adj_shape = input_shape\n",
" self.w = self.add_weight(shape=(node_shape[2], node_shape[2]), name=\"w\")\n",
"\n",
" def call(self, inputs):\n",
" # split input into nodes, adj\n",
" nodes, adj = inputs\n",
" # compute degree\n",
" degree = tf.reduce_sum(adj, axis=-1)\n",
" # GCN equation\n",
" new_nodes = tf.einsum(\"bi,bij,bjk,kl->bil\", 1 / degree, adj, nodes, self.w)\n",
" out = self.activation(new_nodes)\n",
" return out, adj"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A lot of the code above is Keras/TF specific and getting the variables to the right place. There are really only two key lines here. The first is to compute the degree by summing over the columns of the adjacency matrix:\n",
"\n",
"```python\n",
"degree = tf.reduce_sum(adj, axis=-1)\n",
"```\n",
"\n",
"The second key line is to do the GCN equation {eq}`gcn` (without the activation)\n",
"\n",
"```python\n",
"new_nodes = tf.einsum(\"bi,bij,bjk,kl->bil\", 1 / degree, adj, nodes, self.w)\n",
"```\n",
"\n",
"We can now try our layer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gcnlayer = GCNLayer(\"relu\")\n",
"# we insert a batch axis here\n",
"gcnlayer((nodes[np.newaxis, ...], adj_mat[np.newaxis, ...]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It outputs (1) the new node features and (2) the adjacency matrix. Let's make sure we can stack these and apply the GCN multiple times"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = (nodes[np.newaxis, ...], adj_mat[np.newaxis, ...])\n",
"for i in range(2):\n",
" x = gcnlayer(x)\n",
"print(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It works! Why do we see zeros though? Probably because we had negative numbers that were removed by our ReLU activation. This will be solved by training and increasing our dimension number. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Solubility Example\n",
"\n",
"We'll now revisit predicting solubility with GCNs. Remember before that we used the features included with the dataset. Now we can use the molecular structures directly. Our GCN layer outputs node-level features. To predict solubility, we need to get a graph-level feature. We'll see later how to be more sophisticated in this process, but for now let's just take the average over all node features after our GCN layers. This is simple, permutation invariant, and gets us from node-level to graph level. Here's an implementation of this"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class GRLayer(tf.keras.layers.Layer):\n",
" \"\"\"A GNN layer that computes average over all node features\"\"\"\n",
"\n",
" def __init__(self, name=\"GRLayer\", **kwargs):\n",
" super(GRLayer, self).__init__(name=name, **kwargs)\n",
"\n",
" def call(self, inputs):\n",
" nodes, adj = inputs\n",
" reduction = tf.reduce_mean(nodes, axis=1)\n",
" return reduction"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The key line in that code is to just to compute the mean over the nodes (`axis=1`):\n",
"\n",
"```python\n",
"reduction = tf.reduce_mean(nodes, axis=1)\n",
"```\n",
"\n",
"\n",
"To complete our deep solubility predictor, we can add some dense layers and make sure we have a single-output without activation since we're doing regression. Note this model is defined using the [Keras functional API](https://keras.io/guides/functional_api/) which is necessary when you have multiple inputs. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ninput = tf.keras.Input(\n",
" (\n",
" None,\n",
" 100,\n",
" )\n",
")\n",
"ainput = tf.keras.Input(\n",
" (\n",
" None,\n",
" None,\n",
" )\n",
")\n",
"# GCN block\n",
"x = GCNLayer(\"relu\")([ninput, ainput])\n",
"x = GCNLayer(\"relu\")(x)\n",
"x = GCNLayer(\"relu\")(x)\n",
"x = GCNLayer(\"relu\")(x)\n",
"# reduce to graph features\n",
"x = GRLayer()(x)\n",
"# standard layers (the readout)\n",
"x = tf.keras.layers.Dense(16, \"tanh\")(x)\n",
"x = tf.keras.layers.Dense(1)(x)\n",
"model = tf.keras.Model(inputs=(ninput, ainput), outputs=x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"where does the 100 come from? Well, this dataset has lots of elements so we cannot use our size 3 one-hot encodings because we'll have more than 3 unique elements. We previously only had C, H and O. This is a good time to update our `smiles2graph` function to deal with this."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hidden-cell"
]
},
"outputs": [],
"source": [
"def gen_smiles2graph(sml):\n",
" \"\"\"Argument for the RD2NX function should be a valid SMILES sequence\n",
" returns: the graph\n",
" \"\"\"\n",
" m = rdkit.Chem.MolFromSmiles(sml)\n",
" m = rdkit.Chem.AddHs(m)\n",
" order_string = {\n",
" rdkit.Chem.rdchem.BondType.SINGLE: 1,\n",
" rdkit.Chem.rdchem.BondType.DOUBLE: 2,\n",
" rdkit.Chem.rdchem.BondType.TRIPLE: 3,\n",
" rdkit.Chem.rdchem.BondType.AROMATIC: 4,\n",
" }\n",
" N = len(list(m.GetAtoms()))\n",
" nodes = np.zeros((N, 100))\n",
" for i in m.GetAtoms():\n",
" nodes[i.GetIdx(), i.GetAtomicNum()] = 1\n",
"\n",
" adj = np.zeros((N, N))\n",
" for j in m.GetBonds():\n",
" u = min(j.GetBeginAtomIdx(), j.GetEndAtomIdx())\n",
" v = max(j.GetBeginAtomIdx(), j.GetEndAtomIdx())\n",
" order = j.GetBondType()\n",
" if order in order_string:\n",
" order = order_string[order]\n",
" else:\n",
" raise Warning(\"Ignoring bond order\" + order)\n",
" adj[u, v] = 1\n",
" adj[v, u] = 1\n",
" adj += np.eye(N)\n",
" return nodes, adj"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nodes, adj = gen_smiles2graph(\"CO\")\n",
"model((nodes[np.newaxis], adj_mat[np.newaxis]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{margin}\n",
"We have switched from adjacency tensor to matrix only because a GCN cannot use edge features. Other architectures though can. \n",
"```\n",
"It outputs one number! That's always nice to have. Now we need to do some work to get a trainable dataset. Our dataset is a little bit complex because our features are tuples of tensors($\\mathbf{V}, \\mathbf{E}$) so that our dataset is a tuple of tuples: $\\left((\\mathbf{V}, \\mathbf{E}), y\\right)$. We use a **generator**, which is just a python function that can return multiple times. Our function returns once for every training example. Then we have to pass it to the `from_generator` {obj}`tf.data.Dataset` constructor which requires explicit declaration of the shapes of these examples. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def example():\n",
" for i in range(len(soldata)):\n",
" graph = gen_smiles2graph(soldata.SMILES[i])\n",
" sol = soldata.Solubility[i]\n",
" yield graph, sol\n",
"\n",
"\n",
"data = tf.data.Dataset.from_generator(\n",
" example,\n",
" output_types=((tf.float32, tf.float32), tf.float32),\n",
" output_shapes=(\n",
" (tf.TensorShape([None, 100]), tf.TensorShape([None, None])),\n",
" tf.TensorShape([]),\n",
" ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Whew, that's a lot. Now we can do our usual splitting of the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"test_data = data.take(200)\n",
"val_data = data.skip(200).take(200)\n",
"train_data = data.skip(400)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And finally, time to train."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-output"
]
},
"outputs": [],
"source": [
"model.compile(\"adam\", loss=\"mean_squared_error\")\n",
"result = model.fit(train_data.batch(1), validation_data=val_data.batch(1), epochs=10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(result.history[\"loss\"], label=\"training\")\n",
"plt.plot(result.history[\"val_loss\"], label=\"validation\")\n",
"plt.legend()\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This model is definitely underfit. One reason is that our batch size is 1. This is a side-effect of making the number of atoms variable and then Keras/tensorflow has trouble batching together our data if there are two unknown dimensions. A standard trick is to group together multiple molecules into one graph, but making sure they are disconnected (no bonds between the molecules). That allows you to batch molecules without increasing the rank of your model/data.\n",
"\n",
"Let's now check the parity plot."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"yhat = model.predict(test_data.batch(1), verbose=0)[:, 0]\n",
"test_y = [y for x, y in test_data]\n",
"plt.figure()\n",
"plt.plot(test_y, test_y, \"-\")\n",
"plt.plot(test_y, yhat, \".\")\n",
"plt.text(\n",
" min(test_y) + 1,\n",
" max(test_y) - 2,\n",
" f\"correlation = {np.corrcoef(test_y, yhat)[0,1]:.3f}\",\n",
")\n",
"plt.text(\n",
" min(test_y) + 1,\n",
" max(test_y) - 3,\n",
" f\"loss = {np.sqrt(np.mean((test_y - yhat)**2)):.3f}\",\n",
")\n",
"plt.title(\"Testing Data\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Message Passing Viewpoint\n",
"\n",
"One way to more broadly view a GCN layer is that it is a kind of \"message-passing\" layer. You first compute a message coming from each neighboring node:\n",
"\n",
"\\begin{equation}\n",
"\\vec{e}_{{s_i}j} = \\vec{v}_{{s_i}j} \\mathbf{W}\n",
"\\end{equation}\n",
"\n",
"where $v_{{s_i}j}$ means the $j$th neighbor of node $i$. The $s_i$ means senders to $i$. This is how a GCN computes the messages, it's just a weight matrix times each neighbor node features. After getting the messages that will go to node $i$, $\\vec{e}_{{s_i}j}$, we aggregate them using a function which is permutation invariant to the order of neighbors:\n",
"\n",
"\\begin{equation}\n",
"\\vec{e}_{i} = \\frac{1}{|\\vec{e}_{{s_i}j}|}\\sum \\vec{e}_{{s_i}j} \n",
"\\end{equation}\n",
"\n",
"In the GCN this aggregation is just a mean, but it can be any permutation invariant (possibly trainable) function. Finally, we update our node using the aggregated message in the GCN:\n",
"\n",
"\\begin{equation}\n",
"\\vec{v}^{'}_{i} = \\sigma(\\vec{e}_i)\n",
"\\end{equation}\n",
"\n",
"where $v^{'}$ indicates the new node features. This is simply the activated aggregated message. Writing it out this way, you can see how it is possible to make small changes. One important paper by Gilmer et al. explored some of these choices and described how this general idea of message passing layers does well in learning to predict molecular energies from quantum mechanics {cite}`gilmer2017neural`. Examples of changes to the above GCN equations are to include edge information when computing the neighbor messages or use a dense neural network layer in place of $\\sigma$. You can think of the GCN as one type of a broader class of message passing graph neural networks, sometimes abbreviated as MPNN.\n",
"\n",
"## Gated Graph Neural Network\n",
"\n",
"\n",
"One common variant of the message passing layer is the **gated graph neural network** (GGN) {cite}`li2015gated`. It replaces the last equation, the node update, with\n",
"\n",
"\\begin{equation}\n",
"\\vec{v}^{'}_{i} = \\textrm{GRU}(\\vec{v}_i, \\vec{e}_i)\n",
"\\end{equation}\n",
"\n",
"where the $\\textrm{GRU}(\\cdot, \\cdot)$ is a gated recurrent unit{cite}`chung2014empirical`. A GRU is a binary (two input arguments) neural network that is typically used in sequence modeling. The interesting property of a GGN relative to a GCN is that it has trainable parameters in the node update (from the GRU), giving the model a bit more flexibility. In a GGN, the GRU parameters are kept the same at each layer, like how a GRU is used to model sequences. What's nice about this is that you can stack infinite GGN layers without increasing the number of trainable parameters (assuming you make $\\mathbf{W}$ the same at each layer). Thus GGNs are suited for large graphs, like a large protein or large unit cell.\n",
"\n",
"```{margin}\n",
"You'll often see the prefix \"gated\" on GNNs and that means that the nodes are updated according to a GRU.\n",
"```\n"
]
},
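{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of a GGN, assuming one common GRU convention (an update gate $z$ and a reset gate $r$) and mean aggregation of messages over a symmetric adjacency matrix. Library GRU implementations, such as `tf.keras.layers.GRUCell`, may differ in details. The class and function names are hypothetical. The point it illustrates is that the same $\\mathbf{W}$ and GRU weights are reused at every layer, so the parameter count does not depend on depth.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def sigmoid(x):\n",
"    return 1.0 / (1.0 + np.exp(-x))\n",
"\n",
"\n",
"class SharedGRU:\n",
"    # hypothetical minimal GRU: input = aggregated message, state = node features\n",
"    def __init__(self, dim, rng):\n",
"        self.W = rng.normal(scale=0.1, size=(3, dim, dim))  # input weights: z, r, candidate\n",
"        self.U = rng.normal(scale=0.1, size=(3, dim, dim))  # state weights: z, r, candidate\n",
"        self.b = np.zeros((3, dim))\n",
"\n",
"    def num_params(self):\n",
"        return self.W.size + self.U.size + self.b.size\n",
"\n",
"    def __call__(self, v, e):\n",
"        z = sigmoid(e @ self.W[0] + v @ self.U[0] + self.b[0])  # update gate\n",
"        r = sigmoid(e @ self.W[1] + v @ self.U[1] + self.b[1])  # reset gate\n",
"        h = np.tanh(e @ self.W[2] + (r * v) @ self.U[2] + self.b[2])  # candidate\n",
"        return (1 - z) * v + z * h  # blend old node features with the candidate\n",
"\n",
"\n",
"def ggn(V, adj, W, gru, n_layers):\n",
"    degree = np.maximum(adj.sum(axis=1, keepdims=True), 1)\n",
"    for _ in range(n_layers):  # the same W and gru are used at every layer\n",
"        e = adj @ (V @ W) / degree  # mean of neighbor messages\n",
"        V = gru(V, e)\n",
"    return V\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"dim = 8\n",
"adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])\n",
"V = rng.normal(size=(3, dim))\n",
"W = rng.normal(scale=0.1, size=(dim, dim))\n",
"gru = SharedGRU(dim, rng)\n",
"\n",
"for depth in (2, 20):\n",
"    out = ggn(V, adj, W, gru, depth)\n",
"    print(depth, out.shape, W.size + gru.num_params())\n",
"```\n",
"\n",
"The printed parameter count is the same for 2 and 20 layers, because depth only changes how many times the shared weights are applied."
]
},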
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pooling\n",
"\n",
"Within the message passing viewpoint, and in general for GNNS, the way that messages from neighbors are combined is a key step. This is sometimes called **pooling**, since it's similar to the pooling layer used in convolutional neural networks. Just like in pooling for convolutional neural networks, there are multiple reduction operations you can use. Typically you see a sum or mean reduction in GNNs, but you can be quite sophisticated like in the Graph Isomorphism Networks {cite}`xu2018powerful`. We'll see an example in our attention chapter of using self-attention, which can also be used for pooling. It can be tempting to focus on this step, but it's been empirically found that the choice of pooling is not so important{cite}`luzhnica2019graph,mesquita2020rethinking`. The key property of the pooling is permutation *invariance* - we want the aggregation operation to not depend on order of nodes (or edges if pooling over them). You can find a recent review of pooling methods in Grattarola et al. {cite}`grattarola2021understanding`.\n",
"\n",
"You can see a more visual comparison and overview of the various pooling strategies in this distill article by Daigavane et al. {cite}`daigavane2021understanding`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Readout Function\n",
"\n",
"GNNs output a graph by design. It is rare that our labels are graphs -- typically we have node labels or a single graph label. An example of a node label is partial charge of atoms. An example of a graph label would be the energy of the molecule. The process of converting the graph output from the GNN into our predicted node labels or graph label is called the **readout**. If we have node labels, we can simply discard the edges and use our output node feature vectors from the GNN as the prediction, perhaps with a few dense layers before our predicted output label.\n",
"\n",
"If we're trying to predict a graph-level label like energy of the molecule or net charge, we need to be careful when converting from node/edge features to a graph label. If we simply put the node features into a dense layer to get to the desired shape graph label, we will lose permutation equivariance (technically it's permutation invariance now since our output is graph label, not node labels). The readout we did above in the solubility example was a reduction over the node features to get a graph feature. Then we used this graph feature in dense layers. It turns out this is the only way {cite}`zaheer2017deep` to do a graph feature readout: a reduction over nodes to get graph feature and then dense layers to get predicted graph label from those graph features. You can also do some dense layers on the node features individually, but that already happens in GNN so I do not recommend it. This readout is sometimes called DeepSets because it is the same form as the DeepSets architecture, which is a permutation invariant architecture for features that are sets{cite}`zaheer2017deep`.\n",
"\n",
"You may notice that the pooling and readouts both use permutation invariant functions. Thus, DeepSets can be used for pooling and attention could be used for readouts. \n",
"\n",
"### Intensive vs Extensive\n",
"\n",
"One important consideration of a readout in regression is if your labels are **intensive** or **extensive**. An intensive label is one whose value is independent of the number of nodes (or atoms). For example, the index of refraction or solubility are intensive. The readout for an intensive label should (generally) be independent of the number of a nodes/atoms. So the reduction in the readout could be a mean or max, but not a sum. In contrast, an extensive label should (generally) use a sum for the reduction in the readout. An example of an extensive molecular property is enthalpy of formation. "
]
},
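{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small NumPy illustration of the difference, assuming per-atom contributions have already been computed by the readout MLP. If we duplicate a molecule (twice as many atoms, same composition), a sum readout doubles, as an extensive property should, while a mean readout is unchanged, as an intensive property should be. The numbers are arbitrary.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"node_values = np.array([0.5, -1.0, 2.0])  # per-atom contributions from the readout MLP\n",
"doubled = np.concatenate([node_values, node_values])  # two copies of the molecule\n",
"\n",
"print(\"sum:\", node_values.sum(), \"->\", doubled.sum())  # extensive: 1.5 -> 3.0\n",
"print(\"mean:\", node_values.mean(), \"->\", doubled.mean())  # intensive: 0.5 -> 0.5\n",
"```"
]
},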
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Battaglia General Equations\n",
"\n",
"As you can see, message passing layers is a general way to view GNN layers. Battaglia *et al.* {cite}`battaglia2018relational` went further and created a general set of equations which captures nearly all GNNs. They broke the GNN layer equations down into 3 update equations, like the node update equation we saw in the message passing layer equations, and 3 aggregation equations (6 total equations). There is a new concept in these equations: graph feature vectors. Instead of having two parts to your network (GNN then readout), a graph level feature is updated at every GNN layer. The graph feature vector is a set of features which represent the whole graph or molecule. For example, when computing solubility it may have been useful to build up a per-molecule feature vector that is eventually used to compute solubility instead of having the readout. Any kind of per-molecule quantity, like energy, should be predicted with the graph-level feature vector. \n",
"\n",
"The first step in these equations is updating the edge feature vectors, written as $\\vec{e}_k$, which we haven't seen yet:\n",
"\n",
"$$\n",
"\\vec{e}^{'}_k = \\phi^e\\left( \\vec{e}_k, \\vec{v}_{rk}, \\vec{v}_{sk}, \\vec{u}\\right)\n",
"$$ (edge-update)\n",
"\n",
"where $\\vec{e}_k$ is the feature vector of edge $k$, $\\vec{v}_{rk}$ is the receiving node feature vector for edge $k$, $\\vec{v}_{sk}$ is the sending node feature vector for edge $k$, $\\vec{u}$ is the graph feature vector, and $\\phi^e$ is one of the three update functions that the define the GNN layer. Note that these are meant to be general expressions and you define $\\phi^e$ for your specific GNN layer. \n",
"\n",
"Our molecular graphs are undirected, so how do we decide which node is receiving $\\vec{v}_{rk}$ and which node is sending $\\vec{v}_{sk}$? The individual $\\vec{e}^{'}_k$ are aggregated in the next step as all the inputs into node $v_{rk}$. In our molecular graph, all bonds are both \"inputs\" and \"outputs\" from an atom (how else could it be?), so it makes sense to just view every bond as two directed edges: a C-H bond has an edge from C to H and an edge from H to C. In fact, our adjacency matrices already reflect that. There are two non-zero elements in them for each bond: one for C to H and one for H to C. Back to the original question, what is $\\vec{v}_{rk}$ and $\\vec{v}_{sk}$? We consider every element in the adjacency matrix (every $k$) and when we're on element $k = \\{ij\\}$, which is $A_{ij}$, then the receiving node is $j$ and the sending node is $i$. When we consider the companion edge $A_{ji}$, the receiving node is $i$ and the sending node is $j$. \n",
"\n",
"$\\vec{e}^{'}_k$ is like the message from the GCN. Except it's more general: it can depend on the receiving node and the graph feature vector $\\vec{u}$. The metaphor of a \"message\" doesn't quite apply, since a message cannot be affected by the receiver. Anyway, the new edge updates are then aggregated with the first aggregation function:\n",
"\n",
"$$\n",
"\\bar{e}^{'}_i = \\rho^{e\\rightarrow v}\\left( E_i^{'}\\right)\n",
"$$ (edge-aggregation)\n",
"\n",
"where $\\rho^{e\\rightarrow v}$ is our defined function and $E_i^{'}$ represents stacking all $\\vec{e}^{'}_k$ from edges **into** node i. Having our aggregated edges, we can compute the node update:\n",
"\n",
"$$\n",
"\\vec{v}^{'}_i = \\phi^v\\left( \\bar{e}^{'}_i, \\vec{v}_i, \\vec{u}\\right)\n",
"$$ (node-update)\n",
"\n",
"This concludes the usual steps of a GNN layer because we have new nodes and new edges. If you are updating the graph features ($\\vec{u}$), the following additional steps may be defined:\n",
"\n",
"$$\n",
"\\bar{e}^{'} = \\rho^{e\\rightarrow u}\\left( E^{'}\\right)\n",
"$$ (edge-all-aggregation)\n",
"\n",
"This equation aggregates all messages/aggregated edges across the whole graph. Then we can aggregate the new nodes across the whole graph:\n",
"\n",
"$$\n",
"\\bar{v}^{'} = \\rho^{v\\rightarrow u}\\left( V^{'}\\right)\n",
"$$ (node-all-aggregation)\n",
"\n",
"Finally, we can compute the update to the graph feature vector as:\n",
"\n",
"$$\n",
"\\vec{u}^{'} = \\phi^u\\left( \\bar{e}^{'},\\bar{v}^{'}, \\vec{u}\\right)\n",
"$$ (global-update)\n"
]
},
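{
"cell_type": "markdown",
"metadata": {},
"source": [
"The six equations can be collected into a single function. This is a minimal NumPy sketch, not the reference implementation from Battaglia *et al.*: edges are given as `senders`/`receivers` index arrays (edge $k$ goes from `senders[k]` to `receivers[k]`), and the three update functions ($\\phi$) and three aggregation functions ($\\rho$) are passed in as arguments. The toy choices in the usage example (sums and a `tanh`) are arbitrary and only check that the shapes flow through.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def gn_block(E, V, u, senders, receivers, phi_e, phi_v, phi_u, rho_ev, rho_eu, rho_vu):\n",
"    # E: [n_edges, d], V: [n_nodes, d], u: [d]\n",
"    E_new = phi_e(E, V[receivers], V[senders], u)  # edge update\n",
"    # aggregate the updated edges coming into each node i\n",
"    e_bar_i = np.stack([rho_ev(E_new[receivers == i]) for i in range(V.shape[0])])\n",
"    V_new = phi_v(e_bar_i, V, u)  # node update\n",
"    e_bar = rho_eu(E_new)  # aggregate all edges in the graph\n",
"    v_bar = rho_vu(V_new)  # aggregate all nodes in the graph\n",
"    u_new = phi_u(e_bar, v_bar, u)  # graph feature update\n",
"    return E_new, V_new, u_new\n",
"\n",
"\n",
"def total(x):\n",
"    return x.sum(axis=0)\n",
"\n",
"\n",
"d = 4\n",
"rng = np.random.default_rng(0)\n",
"senders = np.array([0, 1, 1, 2])\n",
"receivers = np.array([1, 0, 2, 1])\n",
"E, V, u = rng.normal(size=(4, d)), rng.normal(size=(3, d)), np.zeros(d)\n",
"\n",
"E2, V2, u2 = gn_block(\n",
"    E, V, u, senders, receivers,\n",
"    phi_e=lambda e, vr, vs, u: np.tanh(e + vr + vs + u),\n",
"    phi_v=lambda e_bar, v, u: np.tanh(e_bar + v + u),\n",
"    phi_u=lambda e_bar, v_bar, u: np.tanh(e_bar + v_bar + u),\n",
"    rho_ev=total,\n",
"    rho_eu=total,\n",
"    rho_vu=total,\n",
")\n",
"print(E2.shape, V2.shape, u2.shape)  # (4, 4) (3, 4) (4,)\n",
"```\n",
"\n",
"Choosing these six functions is what distinguishes one GNN layer from another. For a layer that does not update graph features, $\\phi^u$ can simply return `u` unchanged."
]
},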
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Reformulating GCN into Battaglia equations\n",
"\n",
"Let's see how the GCN is presented in this form. We first compute our neighbor messages for all possible neighbors using {eq}`edge-update`. Remember in the GCN, messages only depend on the senders.\n",
"\n",
"```{margin}\n",
"Even though we use the \"edge update\" function, remember in a GCN we ignore \n",
"the edge features. We only care edges for defining the connectivity of the graph.\n",
"```\n",
"\n",
"$$\n",
"\\vec{e}^{'}_k = \\phi^e\\left( \\vec{e}_k, \\vec{v}_{rk}, \\vec{v}_{sk}, \\vec{u}\\right) = \\vec{v}_{sk} \\mathbf{W}\n",
"$$\n",
"\n",
"To aggregate our messages coming into node $i$ in {eq}`edge-aggregation`, we average them.\n",
"\n",
"$$\n",
"\\bar{e}^{'}_i = \\rho^{e\\rightarrow v}\\left( E_i^{'}\\right) = \\frac{1}{|E_i^{'}|}\\sum E_i^{'}\n",
"$$\n",
"\n",
"Our node update is then the activation {eq}`node-update`\n",
"\n",
"$$\n",
"\\vec{v}^{'}_i = \\phi^v\\left( \\bar{e}^{'}_i, \\vec{v}_i, \\vec{u}\\right) = \\sigma(\\bar{e}^{'}_i)\n",
"$$\n",
"\n",
"we could include the self-loop above using $\\sigma(\\bar{e}^{'}_i + \\vec{v}_i)$. The other functions are not used in a GCN, so those three completely define the GCN."
]
},
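{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check this reformulation numerically, here is a self-contained NumPy sketch that plugs the GCN choices into the edge update, the mean edge aggregation, and the node update, using a directed edge list built from a symmetric adjacency matrix. It then compares the result with the matrix form $\\sigma(\\mathbf{D}^{-1}\\mathbf{A}\\mathbf{V}\\mathbf{W})$, where $\\mathbf{D}$ is the diagonal matrix of node degrees. ReLU stands in for $\\sigma$ and all arrays are random placeholders.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"A = np.array([[0, 1, 1, 0], [1, 0, 1, 0], [1, 1, 0, 1], [0, 0, 1, 0]])\n",
"V = rng.normal(size=(4, 3))\n",
"W = rng.normal(size=(3, 3))  # square, so the self-loop variant below has matching shapes\n",
"senders, receivers = np.nonzero(A)  # element A_ij: sender i, receiver j\n",
"\n",
"# edge update: the message depends only on the sender, e'_k = v_sk W\n",
"E_new = V[senders] @ W\n",
"# edge aggregation: mean over the edges coming into each node i\n",
"e_bar = np.stack([E_new[receivers == i].mean(axis=0) for i in range(V.shape[0])])\n",
"# node update: just the activation (ReLU as sigma)\n",
"V_new = np.maximum(e_bar, 0.0)\n",
"# variant with the self-loop, sigma(e_bar_i + v_i)\n",
"V_new_self = np.maximum(e_bar + V, 0.0)\n",
"\n",
"# compare with the matrix form sigma(D^-1 A V W) for this symmetric A\n",
"degree = A.sum(axis=1, keepdims=True)\n",
"print(np.allclose(V_new, np.maximum(A @ V @ W / degree, 0.0)))  # True\n",
"```"
]
},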
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The SchNet Architecture\n",
"\n",
"One of the earliest and most popular GNNs is the SchNet network {cite}`schutt2018schnet`. It wasn't really recognized at publication time as a GNN, but its now recognized as one and you'll see it often used as a baseline model. A **baseline** model is a well-accepted and accurate model that is compared with.\n",
"\n",
"```{margin} Baseline Models\n",
"A common piece of wisdom is if you want to solve a real problem with deep learning, you should read the most recent popular paper in an area and use the baseline they compare against instead of their proposed model. The reason is that a baseline model usually must be easy, fast, and well-tested, which is generally more important than being the most accurate\n",
"```\n",
"\n",
"SchNet is for atoms represented as xyz coordinates (points) -- not as a molecular graph. All our previous examples used the underlying molecular graph as the input. In SchNet we will convert our xyz coodinates into a graph, so that we can apply a GNNN. SchNet was developed for predicting energies and forces from atom configurations without bond information. Thus, we need to first see how a set of atoms and their positions is converted into a graph. To get the nodes, we do a similar process as above and the atomic number is passed through an embedding layer, which is just means we assign a trainable vector to each atomic number (See {doc}`layers` for a review of embeddings). \n",
"\n",
"Getting the adjacency matrix is simple too: we just make every atom be connected to every atom. It might seem confusing what the point of using a GNN is, if we're just connecting everything. *It is because GNNs are permutation equivariant.* If we tried to do learning on the atoms as xyz coordinates, we would have weights depending on the ordering of atoms and probably fail to handle different numbers of atoms.\n",
"\n",
"There is one more missing detail: where do the xyz coordinates go? We make the model depend on xyz coordinates by constructing the edge features from the xyz coordinates. The edge $\\vec{e}$ between atoms $i$ and $j$ is computed purely from their distance $r$:\n",
"\n",
"$$\n",
"e_k = \\exp\\left(-\\gamma \\left(r - \\mu_k\\right)^2\\right)\n",
"$$ (rbf-edge)\n",
"\n",
"where $\\gamma$ is a hyperparameter (e.g., 10Å) and $\\mu_k$ is an equally-space grid of scalars - like `[0, 5, 10, 15 , 20]`. The purpose of {eq}`rbf-edge` is similar to turning a category feature like atomic number or covalent bond type into a one-hot vector. We cannot do a one-hot vector though, because there is an infinite number of possible distances. Thus, we have a kind of \"smoothing\" that gives us a pseudo one-hot for distance. Let's see an example to get a sense of it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gamma = 1\n",
"mu = np.linspace(0, 10, 5)\n",
"\n",
"\n",
"def rbf(r):\n",
" return np.exp(-gamma * (r - mu) ** 2)\n",
"\n",
"\n",
"print(\"input\", 2)\n",
"print(\"output\", np.round(rbf(2), 2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"# THIS CELL IS USED TO GENERATE A FIGURE\n",
"# AND NOT RELATED TO CHAPTER\n",
"# YOU CAN SKIP IT\n",
"\n",
"\n",
"def ssp(x):\n",
" return np.log(0.5 * np.exp(x) + 0.5)\n",
"\n",
"\n",
"def relu(x):\n",
" return np.where(x < 0, np.zeros_like(x), x)\n",
"\n",
"\n",
"plt.figure()\n",
"x = np.linspace(-1.5, 1.5)\n",
"plt.axvline(x=0, linestyle=\"--\", color=\"#AAA\")\n",
"plt.axhline(y=0, linestyle=\"--\", color=\"#AAA\")\n",
"plt.plot(x, ssp(x), label=\"Shifted Softplus\")\n",
"plt.plot(x, relu(x), label=\"ReLU\")\n",
"plt.xlabel(\"$x$\")\n",
"plt.ylabel(\"$\\sigma(x)$\")\n",
"plt.legend()\n",
"glue(\"softplus\", plt.gcf(), display=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see that a distance of $r=2$ gives a vector with most of the activation for the $k = 1$ position - which corresponds to $\\mu_1 = 2$. \n",
"\n",
"We have our nodes and edges and are close to defining the GNN update equations. We need a bit more notation. I'm going to use $h(\\vec{x})$ to indicate a multilayer perceptron (MLP) -- basically a 1 to 2 dense layers neural network. The exact number of dense layers and when/where activation is used in these MLPs will be defined in the implementation, because it is not so important for understanding. Recall, the definition of a dense layer is\n",
"\n",
"$$\n",
"h(\\vec{x}) = \\sigma\\left(Wx + b\\right)\n",
"$$\n",
"\n",
"We'll also use a different activation function $\\sigma$ called \"shifted softplus\" in SchNet: $\\ln\\left(0.5e^{x} + 0.5\\right)$. You can see $\\sigma(x)$ compared with the usual ReLU activation in {numref}`softplus`. The rationale for using shifted softplus is that it is smooth with-respect to its input, so it could be used to compute forces in a molecular dynamics simulation which requires taking smooth derivatives with respect to pairwise distances.\n",
"\n",
"```{glue:figure} softplus\n",
"----\n",
"name: softplus\n",
"----\n",
"Comparison of the usual ReLU activation function and the shifted softplus used in the SchNet model.\n",
"```\n",
"\n",
"Now, the GNN equations! The edge update equation {eq}`edge-update` is composed of two pieces. First, we run the incoming edge feature through an MLP and the atoms through an MLP. Then the result is run through an MLP:\n",
"\n",
"$$\n",
"\\vec{e}^{'}_k = \\phi^e\\left( \\vec{e}_k, \\vec{v}_{rk}, \\vec{v}_{sk}, \\vec{u}\\right) =h_1\\left(\\vec{v}_{sk}\\right) \\cdot h_2\\left(\\vec{e}_k\\right)\n",
"$$"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next equation is the edge aggregation equation, {eq}`edge-aggregation`. For SchNet, the edge aggregation is a sum over the neighbor atom features.\n",
"\n",
"$$\n",
"\\bar{e}^{'}_i = \\sum E_i^{'}\n",
"$$\n",
"\n",
"Finally, the node update equation for SchNet is:\n",
"\n",
"$$\n",
"\\vec{v}^{'}_i = \\phi^v\\left( \\bar{e}^{'}_i, \\vec{v}_i, \\vec{u}\\right) = \\vec{v}_i + h_3\\left(\\bar{e}^{'}_i\\right)\n",
"$$\n",
"\n",
"The GNN updates are applied typically 3-6 times. Although we have an edge update equation, like in GCN we do not actually override the edges and keep them the same at each layer. The original SchNet was for predicting energies and forces, so a readout can be done using sum-pooling or any other strategy described above. \n",
"\n",
"These are sometimes changed, but in the original SchNet paper $h_1$ is one dense layer without activation, $h_2$ is two dense layers with activation, and $h_3$ is 2 dense layers with activation on the first and not the second.\n",
"\n",
"\n",
"```{admonition} What is SchNet?\n",
"The key GNN feature of a SchNet-like GNN are (1) use edge & node features in the edge update (message construction):\n",
"\n",
"$$\n",
"\\vec{e}^{'}_k = h_1(\\vec{v}_{sk}) \\cdot h_2(\\vec{e}_k)\n",
"$$\n",
"\n",
"where $h_i()$s are some trainable functions and (2) use a residue in the node update:\n",
"\n",
"$$\n",
"\\vec{v}^{'}_i = \\vec{v}_i + h_3\\left(\\bar{e}^{'}_i\\right)\n",
"$$\n",
"\n",
"```\n",
"\n",
"All the other details about how to featurize the edges, how deep $h_i$ is, what activation to choose, how to readout, and how to convert point clouds to graphs are about the specific SchNet model in {cite}`schutt2018schnet`. "
]
},
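{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of one SchNet interaction, following the three equations above on a fixed-degree neighbor list: node features `V` of shape `[N, F]`, edge features of shape `[N, K, D]`, and neighbor indices of shape `[N, K]`. The random weights, the toy neighbor list, and the helper names are hypothetical. Only the structure (an element-wise product of $h_1$ and $h_2$, a sum over neighbors, and a residual update through $h_3$) follows the text.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"\n",
"def ssp(x):\n",
"    return np.logaddexp(0.0, x) - np.log(2.0)  # shifted softplus\n",
"\n",
"\n",
"def dense(n_in, n_out):\n",
"    # hypothetical helper: random weights and zero bias for one dense layer\n",
"    return rng.normal(scale=0.1, size=(n_in, n_out)), np.zeros(n_out)\n",
"\n",
"\n",
"def run_layer(layer, x, activation=None):\n",
"    W, b = layer\n",
"    y = x @ W + b\n",
"    return activation(y) if activation is not None else y\n",
"\n",
"\n",
"N, K, D, F = 5, 3, 8, 16  # nodes, neighbors per node, RBF features, channels\n",
"V = rng.normal(size=(N, F))\n",
"edge_features = rng.random(size=(N, K, D))  # stand-in for RBF-expanded distances\n",
"neighbors = np.array([[1, 2, 3], [0, 2, 4], [0, 1, 3], [0, 2, 4], [1, 2, 3]])\n",
"\n",
"h1 = dense(F, F)  # one dense layer, no activation\n",
"h2 = [dense(D, F), dense(F, F)]  # two dense layers, both activated\n",
"h3 = [dense(F, F), dense(F, F)]  # activation on the first layer only\n",
"\n",
"v_sk = V[neighbors]  # [N, K, F]: features of each sending neighbor\n",
"e_k = run_layer(h1, v_sk) * run_layer(h2[1], run_layer(h2[0], edge_features, ssp), ssp)\n",
"e_i = e_k.sum(axis=1)  # sum over the K incoming edges: [N, F]\n",
"V_new = V + run_layer(h3[1], run_layer(h3[0], e_i, ssp))  # residual node update\n",
"print(V_new.shape)  # (5, 16)\n",
"```"
]
},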
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SchNet Example: Predicting Space Groups\n",
"\n",
"Our next example will be a SchNet model that predict space groups of points. Identifying the space group of atoms is an important part of crystal structure identification, and when doing simulations of crystallization. Our SchNet model will take as input points and output the predicted space group. This is a classification problem; specifically it is multi-class becase a set of points should only be in one space group. To simplify our plots and analysis, we will work in 2D where there are 17 possible space groups. \n",
"\n",
"Our data for this is a set of points from various point groups. The features are xyz coordinates and the label is the space group. We will not have multiple atom types for this problem. The hidden cell below loads the data and reshapes it for the example. \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"import gzip\n",
"import pickle\n",
"import urllib\n",
"\n",
"urllib.request.urlretrieve(\n",
" \"https://github.com/whitead/dmol-book/raw/master/data/sym_trajs.pb.gz\",\n",
" \"sym_trajs.pb.gz\",\n",
")\n",
"with gzip.open(\"sym_trajs.pb.gz\", \"rb\") as f:\n",
" trajs = pickle.load(f)\n",
"\n",
"label_str = list(set([k.split(\"-\")[0] for k in trajs]))\n",
"\n",
"# now build dataset\n",
"def generator():\n",
" for k, v in trajs.items():\n",
" ls = k.split(\"-\")[0]\n",
" label = label_str.index(ls)\n",
" traj = v\n",
" for i in range(traj.shape[0]):\n",
" yield traj[i], label\n",
"\n",
"\n",
"data = tf.data.Dataset.from_generator(\n",
" generator,\n",
" output_signature=(\n",
" tf.TensorSpec(shape=(None, 2), dtype=tf.float32),\n",
" tf.TensorSpec(shape=(), dtype=tf.int32),\n",
" ),\n",
").shuffle(\n",
" 1000,\n",
" reshuffle_each_iteration=False, # do not change order each time (!) otherwise will contaminate\n",
")\n",
"\n",
"# The shuffling above is really important because this dataset is in order of labels!\n",
"\n",
"val_data = data.take(100)\n",
"test_data = data.skip(100).take(100)\n",
"train_data = data.skip(200)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a look at a few examples from the dataset\n",
"\n",
"```{admonition} The Data\n",
":class: dropdown\n",
"This data was generated from {cite}`cox2022symmetric` and all points are constrained to match the space group exactly during a molecular dynamics simulation. The trajectories were NPT with a positive pressure and followed the procedure in that paper for Figure 2. The force field is Lennard-Jones with $\\sigma=1$ and $\\epsilon=1$\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"fig, axs = plt.subplots(4, 5, figsize=(12, 8))\n",
"axs = axs.flatten()\n",
"\n",
"# get a few example and plot them\n",
"for i, (x, y) in enumerate(data):\n",
" if i == 20:\n",
" break\n",
" axs[i].plot(x[:, 0], x[:, 1], \".\")\n",
" axs[i].set_title(label_str[y.numpy()])\n",
" axs[i].axis(\"off\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see that there is a variable number of points and a few examples for each space group. The goal is to infer those titles on the plot from the points alone.\n",
"\n",
"### Building the graphs\n",
"\n",
"We now need to build the graphs for the points. The nodes are all identical - so they can just be 1s (we'll reserve 0 in case we want to mask or pad at some point in the future). As described in the SchNet section above, the edges should be distance to every other atom. In most implementations of SchNet, we practically add a cut-off on either distance or maximum degree (edges per node). We'll do maximum degree for this work of 16.\n",
"\n",
"I have a function below that is a bit sophisticated. It takes a matrix of point positions in arbitrary dimension and returns the distances and indices to the nearest `k` neighbors - exactly what we need. It uses some tricks from {doc}`../math/tensors-and-shapes`. However, it is not so important for you to understand this function. Just know it takes in points and gives us the edge features and edge nodes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# this decorator speeds up the function by \"compiling\" it (tracing it)\n",
"# to run efficienty\n",
"@tf.function(\n",
" reduce_retracing=True,\n",
")\n",
"def get_edges(positions, NN, sorted=True):\n",
" M = tf.shape(input=positions)[0]\n",
" # adjust NN\n",
" NN = tf.minimum(NN, M)\n",
" qexpand = tf.expand_dims(positions, 1) # one column\n",
" qTexpand = tf.expand_dims(positions, 0) # one row\n",
" # repeat it to make matrix of all positions\n",
" qtile = tf.tile(qexpand, [1, M, 1])\n",
" qTtile = tf.tile(qTexpand, [M, 1, 1])\n",
" # subtract them to get distance matrix\n",
" dist_mat = qTtile - qtile\n",
" # mask distance matrix to remove zros (self-interactions)\n",
" dist = tf.norm(tensor=dist_mat, axis=2)\n",
" mask = dist >= 5e-4\n",
" mask_cast = tf.cast(mask, dtype=dist.dtype)\n",
" # make masked things be really far\n",
" dist_mat_r = dist * mask_cast + (1 - mask_cast) * 1000\n",
" topk = tf.math.top_k(-dist_mat_r, k=NN, sorted=sorted)\n",
" return -topk.values, topk.indices"
]
},
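{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, here is a minimal NumPy sketch of the same nearest-neighbor idea as `get_edges`, using broadcasting and `np.argsort` in place of `tf.tile` and `tf.math.top_k`. It is meant for reading rather than for the `tf.data` pipeline. Unlike `get_edges`, it excludes each point's self-distance by index (setting the diagonal to infinity) instead of with a distance threshold. The name `knn_edges` is hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def knn_edges(positions, k):\n",
"    # pairwise displacement vectors via broadcasting: [M, M, dim]\n",
"    diff = positions[np.newaxis, :, :] - positions[:, np.newaxis, :]\n",
"    dist = np.linalg.norm(diff, axis=-1)  # [M, M] distance matrix\n",
"    np.fill_diagonal(dist, np.inf)  # exclude self-interactions\n",
"    k = min(k, positions.shape[0] - 1)  # cannot have more neighbors than other points\n",
"    idx = np.argsort(dist, axis=1)[:, :k]  # indices of the k nearest neighbors\n",
"    return np.take_along_axis(dist, idx, axis=1), idx\n",
"\n",
"\n",
"points = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 0.0]])\n",
"d, i = knn_edges(points, 2)\n",
"print(i[0], d[0])  # [1 2] [1. 2.]\n",
"```"
]
},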
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see how this function works by showing the connections between points in one of our examples. I've hidden the code below. It shows some point's neighbors and connects them so you can get a sense of how a set of points is converted into a graph. The complete graph will have all points' neighborhoods."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"from matplotlib import collections\n",
"\n",
"fig, axs = plt.subplots(2, 3, figsize=(12, 8))\n",
"axs = axs.flatten()\n",
"for i, (x, y) in enumerate(data):\n",
" if i == 6:\n",
" break\n",
" e_f, e_i = get_edges(x, 8)\n",
"\n",
" # make things easier for plotting\n",
" e_i = e_i.numpy()\n",
" x = x.numpy()\n",
" y = y.numpy()\n",
"\n",
" # make lines from origin to its neigbhors\n",
" lines = []\n",
" colors = []\n",
" for j in range(0, x.shape[0], 23):\n",
" # lines are [(xstart, ystart), (xend, yend)]\n",
" lines.extend([[(x[j, 0], x[j, 1]), (x[k, 0], x[k, 1])] for k in e_i[j]])\n",
" colors.extend([f\"C{j}\"] * len(e_i[j]))\n",
" lc = collections.LineCollection(lines, linewidths=2, colors=colors)\n",
" axs[i].add_collection(lc)\n",
" axs[i].plot(x[:, 0], x[:, 1], \".\")\n",
" axs[i].axis(\"off\")\n",
" axs[i].set_title(label_str[y])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will now add this function and the edge featurization of SchNet {eq}`rbf-edge` to get the graphs for the GNN steps."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"MAX_DEGREE = 16\n",
"EDGE_FEATURES = 8\n",
"MAX_R = 20\n",
"\n",
"gamma = 1\n",
"mu = np.linspace(0, MAX_R, EDGE_FEATURES)\n",
"\n",
"\n",
"def rbf(r):\n",
" return tf.exp(-gamma * (r[..., tf.newaxis] - mu) ** 2)\n",
"\n",
"\n",
"def make_graph(x, y):\n",
" edge_r, edge_i = get_edges(x, MAX_DEGREE)\n",
" edge_features = rbf(edge_r)\n",
" return (tf.ones(tf.shape(x)[0], dtype=tf.int32), edge_features, edge_i), y[None]\n",
"\n",
"\n",
"graph_train_data = train_data.map(make_graph)\n",
"graph_val_data = val_data.map(make_graph)\n",
"graph_test_data = test_data.map(make_graph)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's examine one graph to see what it looks like. We'll slice out only the first nodes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for (n, e, nn), y in graph_train_data:\n",
" print(\"first node:\", n[1].numpy())\n",
" print(\"first node, first edge features:\", e[1, 1].numpy())\n",
" print(\"first node, all neighbors\", nn[1].numpy())\n",
" print(\"label\", y.numpy())\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Implementing the MLPs\n",
"\n",
"Now we can implement the SchNet model! Let's start with the $h_1,h_2,h_3$ MLPs that are used in the GNN update equations. In the SchNet paper these each had different numbers of layers and different decisions about which layers had activation. Let's create them now."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def ssp(x):\n",
" # shifted softplus activation\n",
" return tf.math.log(0.5 * tf.math.exp(x) + 0.5)\n",
"\n",
"\n",
"def make_h1(units):\n",
" return tf.keras.Sequential([tf.keras.layers.Dense(units)])\n",
"\n",
"\n",
"def make_h2(units):\n",
" return tf.keras.Sequential(\n",
" [\n",
" tf.keras.layers.Dense(units, activation=ssp),\n",
" tf.keras.layers.Dense(units, activation=ssp),\n",
" ]\n",
" )\n",
"\n",
"\n",
"def make_h3(units):\n",
" return tf.keras.Sequential(\n",
" [tf.keras.layers.Dense(units, activation=ssp), tf.keras.layers.Dense(units)]\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One detail that can be missed is that the weights in each MLP should change in each layer of SchNet. Thus, we've written the functions above to always return a new MLP. This means that a new set of trainable weights is generated on each call, meaning there is no way we could erroneously have the same weights in multiple layers.\n",
"\n",
"\n",
"### Implementing the GNN\n",
"\n",
"Now we have all the pieces to make the GNN. This code will be very similar to the GCN example above, except we now have edge features. One more detail is that our readout will be an MLP as well, following the SchNet paper. The only change we'll make is that we want our output property to be (1) multi-class classification and (2) intensive (independent of number of atoms). So we'll end with an average (intensive) and end with an output vector of logits the size of our labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SchNetModel(tf.keras.Model):\n",
" \"\"\"Implementation of SchNet Model\"\"\"\n",
"\n",
" def __init__(self, gnn_blocks, channels, label_dim, **kwargs):\n",
" super(SchNetModel, self).__init__(**kwargs)\n",
" self.gnn_blocks = gnn_blocks\n",
"\n",
" # build our layers\n",
" self.embedding = tf.keras.layers.Embedding(2, channels)\n",
" self.h1s = [make_h1(channels) for _ in range(self.gnn_blocks)]\n",
" self.h2s = [make_h2(channels) for _ in range(self.gnn_blocks)]\n",
" self.h3s = [make_h3(channels) for _ in range(self.gnn_blocks)]\n",
" self.readout_l1 = tf.keras.layers.Dense(channels // 2, activation=ssp)\n",
" self.readout_l2 = tf.keras.layers.Dense(label_dim)\n",
"\n",
" def call(self, inputs):\n",
" nodes, edge_features, edge_i = inputs\n",
" # turn node types as index to features\n",
" nodes = self.embedding(nodes)\n",
" for i in range(self.gnn_blocks):\n",
" # get the node features per edge\n",
" v_sk = tf.gather(nodes, edge_i)\n",
" e_k = self.h1s[i](v_sk) * self.h2s[i](edge_features)\n",
" e_i = tf.reduce_sum(e_k, axis=1)\n",
" nodes += self.h3s[i](e_i)\n",
" # readout now\n",
" nodes = self.readout_l1(nodes)\n",
" nodes = self.readout_l2(nodes)\n",
" return tf.reduce_mean(nodes, axis=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Remember that the key attribute of a SchNet GNN is the way it combines edge and node features. We can see this mixing in the key line that computes the edge update (the message values):\n",
"\n",
"```python\n",
"e_k = self.h1s[i](v_sk) * self.h2s[i](edge_features)\n",
"```\n",
"\n",
"followed by aggregation of the edge updates (pooling messages):\n",
"\n",
"```python\n",
"e_i = tf.reduce_sum(e_k, axis=1)\n",
"```\n",
"\n",
"and finally the node update:\n",
"\n",
"```python\n",
"nodes += self.h3s[i](e_i)\n",
"```\n",
"\n",
"Also of note is how we go from node features to multi-class outputs. Two dense layers map each node's features to a vector with one entry per class:\n",
"\n",
"```python\n",
"self.readout_l1 = tf.keras.layers.Dense(channels // 2, activation=ssp)\n",
"self.readout_l2 = tf.keras.layers.Dense(label_dim)\n",
"```\n",
"\n",
"and then we take the average over all nodes:\n",
"\n",
"```python\n",
"return tf.reduce_mean(nodes, axis=0)\n",
"```\n",
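"\n",
"To trace the tensor shapes through one block, here is a NumPy sketch of a single step on random data. It assumes, as in the model above, that `edge_i` holds the indices of `NN` neighbors for each of `N` nodes; random matrices stand in for the `h1`, `h2`, and `h3` MLPs, so it illustrates shapes rather than trained behavior:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"N, NN, C, E = 6, 3, 8, 10  # nodes, neighbors per node, channels, edge features\n",
"\n",
"nodes = rng.normal(size=(N, C))\n",
"edge_i = rng.integers(0, N, size=(N, NN))  # neighbor indices for each node\n",
"edge_features = rng.normal(size=(N, NN, E))\n",
"\n",
"# random linear maps stand in for the h1, h2, h3 MLPs (an assumption)\n",
"W1 = rng.normal(size=(C, C))\n",
"W2 = rng.normal(size=(E, C))\n",
"W3 = rng.normal(size=(C, C))\n",
"\n",
"v_sk = nodes[edge_i]  # like tf.gather: (N, NN, C)\n",
"e_k = (v_sk @ W1) * (edge_features @ W2)  # messages: (N, NN, C)\n",
"e_i = e_k.sum(axis=1)  # pool over neighbors: (N, C)\n",
"nodes = nodes + e_i @ W3  # residual node update: (N, C)\n",
"print(v_sk.shape, e_k.shape, e_i.shape, nodes.shape)\n",
"```\n",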
"\n",
"---\n",
"\n",
"Now let's use the model on some data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"small_schnet = SchNetModel(3, 32, len(label_str))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for x, y in graph_train_data:\n",
" yhat = small_schnet(x)\n",
" break\n",
"print(yhat.numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The output has the correct shape. Remember that it contains logits; to turn them into class probabilities that sum to 1, we apply a softmax:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"predicted class probabilities\", tf.nn.softmax(yhat).numpy())"
]
},
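{
"cell_type": "markdown",
"metadata": {},
"source": [
"Softmax is monotonic, so it rescales the logits into probabilities without changing which class is largest. A minimal NumPy version (a sketch, not TensorFlow's implementation) makes this explicit:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def softmax(logits):\n",
"    # subtracting the max avoids overflow and does not change the result\n",
"    z = np.exp(logits - np.max(logits))\n",
"    return z / z.sum()\n",
"\n",
"\n",
"logits = np.array([2.0, 0.5, -1.0])  # hypothetical logits for 3 classes\n",
"probs = softmax(logits)\n",
"print(probs.sum())  # 1, up to floating point rounding\n",
"print(np.argmax(probs) == np.argmax(logits))  # True\n",
"```\n",
"\n",
"So taking the `argmax` of the logits directly gives the same predicted class as taking it after the softmax."
]
},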
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training\n",
"\n",
"Great! The model is untrained, though, so now we set up training. Our loss will be cross-entropy computed from logits, but we need to be careful about its form. Our labels are integers, which are called \"sparse\" labels because they are not full one-hot vectors. Multi-class classification is also known as categorical classification. Thus, the loss we want is sparse categorical cross-entropy from logits."
]
},
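{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sparse and non-sparse losses differ only in the label format. This NumPy sketch (with made-up logits, not the Keras implementation) shows that indexing the log-probabilities with integer labels gives the same loss as multiplying by one-hot labels:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])  # 2 examples, 3 classes\n",
"sparse_labels = np.array([0, 2])  # integer class indices\n",
"one_hot_labels = np.eye(3)[sparse_labels]  # the same labels as one-hots\n",
"\n",
"# numerically stable log-softmax\n",
"shifted = logits - logits.max(axis=1, keepdims=True)\n",
"log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))\n",
"\n",
"# sparse: select the true class's log-probability by index\n",
"sparse_ce = -log_probs[np.arange(len(sparse_labels)), sparse_labels].mean()\n",
"# categorical: weight the log-probabilities by the one-hot vector\n",
"categorical_ce = -(one_hot_labels * log_probs).sum(axis=1).mean()\n",
"print(np.isclose(sparse_ce, categorical_ce))  # True\n",
"```"
]
},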
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-output"
]
},
"outputs": [],
"source": [
"small_schnet.compile(\n",
" optimizer=tf.keras.optimizers.Adam(1e-4),\n",
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" metrics=\"sparse_categorical_accuracy\",\n",
")\n",
"result = small_schnet.fit(graph_train_data, validation_data=graph_val_data, epochs=20)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(result.history[\"sparse_categorical_accuracy\"], label=\"training accuracy\")\n",
"plt.plot(result.history[\"val_sparse_categorical_accuracy\"], label=\"validation accuracy\")\n",
"plt.axhline(y=1 / len(label_str), linestyle=\"--\", label=\"random\")\n",
"plt.legend()\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Accuracy\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The accuracy is not great, but it looks like we could keep training. We also have a very small SchNet: the standard SchNet described in {cite}`schutt2018schnet` uses 6 layers, 64 channels, and 300 edge features, while ours has 3 layers and 32 channels. Nevertheless, the model learns something. Let's visualize what the trained model does on some test data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"fig, axs = plt.subplots(4, 5, figsize=(12, 8))\n",
"axs = axs.flatten()\n",
"\n",
"for i, ((x, y), (gx, _)) in enumerate(zip(test_data, graph_test_data)):\n",
" if i == 20:\n",
" break\n",
" axs[i].plot(x[:, 0], x[:, 1], \".\")\n",
" yhat = small_schnet(gx)\n",
" yhat_i = tf.math.argmax(tf.nn.softmax(yhat)).numpy()\n",
" axs[i].set_title(f\"True: {label_str[y.numpy()]}\\nPredicted: {label_str[yhat_i]}\")\n",
" axs[i].axis(\"off\")\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll revisit this example later! One unique feature of this dataset is that it is *synthetic*, meaning there is no label noise. As discussed in {doc}`../ml/regression`, without label noise there is little to overfit, which leads us to favor high-variance models. The goal of teaching a model to predict space groups is to apply it to real simulations or microscopy data, which will certainly have noise. We could have mimicked this by adding noise to the labels and/or by randomly removing atoms to simulate defects, which would help our model work in a real setting."
]
},
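{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of those two kinds of augmentation, here is a hypothetical NumPy sketch. The function names, the fractions, and the assumed 17 classes are illustrative choices, and with real data the neighbor lists would have to be rebuilt after deleting atoms:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"\n",
"def add_vacancies(positions, drop_fraction=0.1):\n",
"    # hypothetical augmentation: delete a random fraction of atoms\n",
"    keep = rng.random(len(positions)) >= drop_fraction\n",
"    return positions[keep]\n",
"\n",
"\n",
"def add_label_noise(labels, n_classes, noise_fraction=0.05):\n",
"    # hypothetical augmentation: replace a random fraction of labels\n",
"    # with uniformly random classes\n",
"    labels = labels.copy()\n",
"    flip = rng.random(len(labels)) < noise_fraction\n",
"    labels[flip] = rng.integers(0, n_classes, size=flip.sum())\n",
"    return labels\n",
"\n",
"\n",
"positions = rng.normal(size=(100, 2))  # stand-in for one 2D structure\n",
"labels = rng.integers(0, 17, size=50)  # stand-in labels for 50 examples\n",
"print(add_vacancies(positions).shape)  # on average about 90 atoms remain\n",
"print(add_label_noise(labels, n_classes=17)[:10])\n",
"```"
]
},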
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Current Research Directions\n",
"\n",
"### Common Architecture Motifs and Comparisons\n",
"\n",
"We've now seen message passing layer GNNs, GCNs, GGNs, and the generalized Battaglia equations. You'll find common motifs in these architectures, like gating, {doc}`attention`, and pooling strategies. For example, Gated GNNs (GGNs) can be combined with attention pooling to create Gated Attention GNNs (GAANs) {cite}`zhang2018gaan`. GraphSAGE is similar to a GCN, but it samples neighbors when pooling, which makes the neighbor updates a fixed size {cite}`hamilton2017inductive`. So you'll see the suffix \"sage\" when neighbors are sampled while pooling. These can all be represented in the Battaglia equations, but you should be aware of these names.\n",
"\n",
"The enormous variety of architectures has led to work on identifying the \"best\" or most general GNN architecture {cite}`dwivedi2020benchmarking,errica2019fair,shchur2018pitfalls`. Unfortunately, the question of which GNN architecture is best is as difficult as \"what benchmark problems are best?\" Thus, there are no agreed-upon conclusions on the best architecture. However, those papers are great resources on training, hyperparameters, and reasonable starting guesses, and I highly recommend reading them before designing your own GNN. There has been some theoretical work showing that simple architectures, like GCNs, cannot distinguish between certain simple graphs {cite}`xu2018powerful`. How much this matters in practice depends on your data. Ultimately, there is so much variety in hyperparameters, data equivariances, and training decisions that you should think carefully about how much the GNN architecture matters before exploring it in depth."
]
},
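{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sampling idea behind GraphSAGE can be sketched in a few lines of NumPy. This is a simplified, hypothetical version: it samples with replacement (which also handles nodes with fewer than `k` neighbors) and uses a mean aggregator, while omitting other parts of GraphSAGE such as combining the pooled vector with the node's own features:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"\n",
"def sample_and_mean(features, neighbors, k):\n",
"    # sample exactly k neighbors per node (with replacement), so every\n",
"    # node pools over the same number of messages, then average them\n",
"    pooled = np.empty((len(neighbors), features.shape[1]))\n",
"    for node, nbrs in enumerate(neighbors):\n",
"        sampled = rng.choice(nbrs, size=k, replace=True)\n",
"        pooled[node] = features[sampled].mean(axis=0)\n",
"    return pooled\n",
"\n",
"\n",
"features = rng.normal(size=(4, 3))\n",
"neighbors = [[1, 2], [0], [0, 3, 1], [2]]  # adjacency lists of varying length\n",
"print(sample_and_mean(features, neighbors, k=2).shape)  # (4, 3)\n",
"```"
]
},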
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Nodes, Edges, and Features\n",
"\n",
"You'll find that most GNNs use the node-update equation in the Battaglia equations but do not update edges. For example, the GCN updates nodes at each layer, but the edges stay constant. Some recent work has shown that updating edges can be important for learning when the edges carry geometric information, for example when the input graph is a molecule and the edges are distances between atoms {cite}`klicpera2019directional`. As we'll see in the chapter on equivariances ({doc}`../dl/data`), one of the key properties of neural networks for point clouds (i.e., Cartesian xyz coordinates) is rotation equivariance. {cite}`klicpera2019directional` showed that you can achieve this by doing edge updates and encoding the edge vectors with a rotation equivariant basis set of spherical harmonics and Bessel functions. These kinds of edge-updating GNNs can be used to predict protein structure {cite}`jing2020learning`.\n",
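"\n",
"A quick NumPy check shows why distances are convenient edge features and why directional information needs more care. Under a random orthogonal transformation (built here with a QR decomposition, an illustrative choice that may include a reflection), pairwise distances do not change, while edge vectors transform along with the coordinates:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"xyz = rng.normal(size=(5, 3))  # hypothetical coordinates of 5 atoms\n",
"\n",
"# random orthogonal matrix from a QR decomposition\n",
"q, _ = np.linalg.qr(rng.normal(size=(3, 3)))\n",
"rotated = xyz @ q.T\n",
"\n",
"\n",
"def pairwise_distances(x):\n",
"    diff = x[:, None, :] - x[None, :, :]\n",
"    return np.linalg.norm(diff, axis=-1)\n",
"\n",
"\n",
"# distances (and distance-based edge features) are invariant\n",
"print(np.allclose(pairwise_distances(xyz), pairwise_distances(rotated)))  # True\n",
"# edge vectors are equivariant: they transform with the same matrix\n",
"print(np.allclose(rotated[1] - rotated[0], q @ (xyz[1] - xyz[0])))  # True\n",
"```\n",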
"\n",
"Another common variation is to pack more into the node features than just element identity. In many examples, you will see people include valence, elemental mass, electronegativity, a bit indicating whether the atom is in a ring, a bit indicating whether the atom is aromatic, etc. Typically these are unnecessary, since a model should be able to learn any of these features, which are computed from the graph and node elements. However, we and others have empirically found that some can help, specifically indicating whether an atom is in a ring {cite}`li2020graph`. Still, choosing extra features to include should be at the bottom of your list of things to explore when designing and using GNNs."
]
},
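{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an illustration, the ring-membership bit can be computed from the adjacency matrix alone. This NumPy sketch uses a simple, unoptimized bridge test chosen for clarity (cheminformatics toolkits like RDKit provide ring information directly) and appends the bit to the node features:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def in_ring(adjacency):\n",
"    # an atom is in a ring if one of its bonds is not a bridge, i.e., the\n",
"    # bonded neighbor is still reachable after that bond is removed\n",
"    n = len(adjacency)\n",
"    flags = np.zeros(n, dtype=bool)\n",
"    for u in range(n):\n",
"        for v in np.flatnonzero(adjacency[u]):\n",
"            a = adjacency.copy()\n",
"            a[u, v] = a[v, u] = 0  # remove the bond u-v\n",
"            seen, stack = {u}, [u]\n",
"            while stack:  # depth-first search starting from u\n",
"                w = stack.pop()\n",
"                for x in np.flatnonzero(a[w]):\n",
"                    if x not in seen:\n",
"                        seen.add(x)\n",
"                        stack.append(x)\n",
"            if v in seen:\n",
"                flags[u] = True\n",
"                break\n",
"    return flags\n",
"\n",
"\n",
"# toluene heavy atoms: ring carbons 0-5 and a methyl carbon 6 bonded to atom 0\n",
"adjacency = np.zeros((7, 7), dtype=int)\n",
"for i, j in [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0), (0, 6)]:\n",
"    adjacency[i, j] = adjacency[j, i] = 1\n",
"\n",
"ring_bit = in_ring(adjacency).astype(float)[:, None]\n",
"element_one_hot = np.ones((7, 1))  # all carbons: a single element type\n",
"node_features = np.concatenate([element_one_hot, ring_bit], axis=1)\n",
"print(ring_bit.ravel())  # [1. 1. 1. 1. 1. 1. 0.]\n",
"```"
]
},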
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Beyond Message Passing\n",
"\n",
"One of the common themes of GNN research is moving \"beyond message passing,\" where message passing means message construction, aggregation, and node updates with the aggregated messages. Some view this as impossible, claiming that all GNNs can be recast as message passing {cite}`velivckovic2022message`. Another direction is decoupling the graph that is input to the GNN from the graph used to compute updates. We partially saw this above with SchNet, where we restricted the maximum degree for message passing. More useful are ideas like \"lifting\" graphs into more structured objects like simplicial complexes {cite}`bodnar2021weisfeiler`. Finally, you can also choose where to send messages beyond just neighbors {cite}`thiede2021autobahn`. For example, all nodes on a path could exchange messages, or all nodes in a clique.\n",
"\n",
"### Do we need graphs?\n",
"\n",
"It is possible to convert a graph into a string if you're working with an adjacency matrix without continuous values. Molecules specifically can be converted into strings; SMILES is one way to do this. This means you can use layers for sequences/strings (e.g., recurrent neural networks or 1D convolutions) and avoid the complexities of a graph neural network. With SMILES, you cannot directly predict a per-atom quantity, so a graph neural network is needed for atom/bond labels. However, the choice is less clear for per-molecule properties like toxicity or solubility. There is no consensus about whether a graph or string/SMILES representation is better: SMILES can exceed certain graph neural networks in accuracy on some tasks and is typically better on generative tasks, while graphs clearly beat SMILES for atom- and bond-level labels because they have the granularity of nodes and edges. We'll see how to model SMILES in {doc}`NLP`.\n",
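"\n",
"To make the string route concrete, here is a minimal sketch of character-level SMILES encoding into the padded integer arrays that a sequence model (e.g., an embedding layer followed by an RNN or 1D convolution) would consume. It is a simplification: splitting per character breaks two-letter symbols like `Cl` and `Br`, which practical tokenizers treat as single tokens.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"smiles = [\"CCO\", \"c1ccccc1O\", \"CC(=O)O\"]  # ethanol, phenol, acetic acid\n",
"\n",
"# character-level vocabulary; index 0 is reserved for padding\n",
"vocab = {ch: i + 1 for i, ch in enumerate(sorted(set(\"\".join(smiles))))}\n",
"max_len = max(len(s) for s in smiles)\n",
"\n",
"# integer-encode each string and right-pad with zeros to a fixed length\n",
"encoded = np.zeros((len(smiles), max_len), dtype=int)\n",
"for row, s in enumerate(smiles):\n",
"    encoded[row, : len(s)] = [vocab[ch] for ch in s]\n",
"print(encoded.shape)  # (3, 9)\n",
"```\n",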
"\n",
"\n",
"### Stereochemistry/Chiral Molecules\n",
"\n",
"Stereochemistry is fundamentally a 3D property of molecules and thus is not captured by covalent bonding alone. Chirality is measured experimentally by seeing whether molecules rotate plane-polarized light, and a molecule is called chiral or \"optically active\" if it is experimentally known to have this property. Stereochemistry is the categorization of how molecules can preferentially rotate polarized light through asymmetries with respect to their mirror images. In organic chemistry, most stereochemistry concerns enantiomers. Enantiomers differ in \"handedness\" around specific atoms called chiral centers, which have four different bonded groups. These can be treated in a graph by indicating which nodes are chiral centers and what their state, or mixture of states (racemic), is. This can be treated as an extra processing step. Amino acids (except achiral glycine) are chiral, and proteins are built almost exclusively from one enantiomer of each, so proteins are chiral. This chirality of proteins means that many drug molecules can be more or less potent depending on their stereochemistry.\n",
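"\n",
"A minimal sketch of one possible encoding follows; the tag categories and element indices are hypothetical choices, not a standard scheme. Each atom gets a one-hot chirality tag that is concatenated onto its element one-hot:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# hypothetical chirality tags: not a chiral center, R, S, racemic/unspecified\n",
"CHIRAL_TAGS = [\"none\", \"R\", \"S\", \"racemic\"]\n",
"chirality = np.array([0, 1, 0, 0])  # e.g., atom 1 is an R center\n",
"chiral_one_hot = np.eye(len(CHIRAL_TAGS))[chirality]\n",
"\n",
"element_one_hot = np.eye(3)[[0, 0, 1, 2]]  # hypothetical element indices\n",
"node_features = np.concatenate([element_one_hot, chiral_one_hot], axis=1)\n",
"print(node_features.shape)  # (4, 7)\n",
"```\n",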
"\n",
"\n",
"```{figure} ../_static/images/helicene.mp4\n",
"----\n",
"name: helicene\n",
"width: 500px\n",
"class: autoplay-video\n",
"----\n",
"This is a molecule with axial stereochemistry. Its small helix can be either left- or right-handed.\n",
"```\n",
"\n",
"Adding node labels is generally not enough. Molecules can interconvert between stereoisomers at chiral centers (racemization), for example through tautomerization. There are also types of stereochemistry that are not at a specific atom, like atropisomers, which are rotamers that cannot interconvert because rotation around a bond is hindered. Then there is stereochemistry that involves multiple atoms, like in helicene. As shown in {numref}`helicene`, this molecule has no chiral centers but is \"optically active\" (experimentally measured to be chiral) because its helix can be left- or right-handed."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Relevant Videos\n",
"\n",
"### Intro to GNNs\n",
"\n",
"\n",
"\n",
"### Overview of GNN with Molecule, Compiler Examples\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chapter Summary \n",
"\n",
"* Molecules can be represented as graphs using one-hot encoded feature vectors that show the elemental identity of each node (atom) and an adjacency matrix that shows immediate neighbors (bonded atoms).\n",
"* Graph neural networks are a category of deep neural networks that have graphs as inputs.\n",
"* One of the early GNNs is the Kipf & Welling GCN. The GCN takes the node feature matrix and the adjacency matrix as input and returns updated node features. The main reason a GCN is *permutation equivariant* is that it pools over each node's neighbors in a permutation *invariant* way (e.g., averaging).\n",
"* A GCN can be viewed as a message-passing layer, in which we have senders and receivers. Messages are computed from neighboring nodes and, once aggregated, update the receiving node.\n",
"* A gated graph neural network is a variant of the message passing layer, for which the nodes are updated according to a gated recurrent unit function. \n",
"* The aggregation of messages is sometimes called pooling, for which there are multiple reduction operations. \n",
"* GNNs output a graph. To get a per-atom or per-molecule property, use a readout function. The readout depends on whether your property is intensive or extensive.\n",
"* The Battaglia equations encompass almost all GNNs in a set of 6 update and aggregation equations.\n",
"* You can convert xyz coordinates into a graph and use a GNN like SchNet."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cited References\n",
"\n",
"```{bibliography}\n",
":style: unsrtalpha\n",
":filter: docname in docnames\n",
"```"
]
}
],
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}