{
"cells": [
{
"cell_type": "markdown",
"id": "a4389326",
"metadata": {},
"source": [
"# Modern Molecular NNs\n",
"\n",
"We have seen two chapters about equivariance, {doc}`data` and {doc}`Equivariant`, and one chapter, {doc}`gnn`, on treating molecules as graphs with permutation equivariance. In this chapter we combine these ideas to create neural networks that treat arbitrary molecules as point clouds while remaining permutation equivariant. We already saw that SchNet does this with an invariant point cloud representation (the distances between atoms), but modern networks combine ideas from {doc}`Equivariant` with graph neural networks (GNNs). This is a highly active research area, especially for predicting energies, forces, and relaxed structures of molecules.\n",
"\n",
"```{admonition} Audience & Objectives\n",
"This chapter assumes you have read {doc}`data`, {doc}`Equivariant`, and {doc}`gnn`. You should be able to\n",
"\n",
" * Categorize a task (features/labels) by equivariance \n",
" * Understand body-ordered expansions\n",
" * Differentiate models based on their message passing, message type, and body-ordering \n",
"```"
]
},
{
"cell_type": "markdown",
"id": "763e14e1",
"metadata": {},
"source": [
"```{warning}\n",
"This chapter is in progress\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d468e41",
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"# This cell is for making plots, not part of examples\n",
"import rdkit, rdkit.Chem, rdkit.Chem.rdDepictor, rdkit.Chem.Draw\n",
"from myst_nb import glue\n",
"import networkx as nx\n",
"import dmol\n",
"\n",
"# I hate to do this manually, but I cannot get the\n",
"# damn molecular fonts to be big enough\n",
"import skunk\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def _mol2svg(m, size):\n",
" d = rdkit.Chem.Draw.rdMolDraw2D.MolDraw2DSVG(*size)\n",
" d.DrawMolecule(m)\n",
" d.FinishDrawing()\n",
" return d.GetDrawingText()\n",
"\n",
"\n",
"m1 = rdkit.Chem.MolFromSmiles(\"C1CCC2CCCCC2C1\")\n",
"m2 = rdkit.Chem.MolFromSmiles(\"C1CCC(C1)C2CCCC2\")\n",
"s1 = _mol2svg(m1, (200, 200))\n",
"s2 = _mol2svg(m2, (200, 200))\n",
"_, axs = plt.subplots(1, 2, squeeze=True)\n",
"axs[0].set_title(\"decalin\")\n",
"axs[1].set_title(\"bicyclopentyl\")\n",
"axs[0].axis(\"off\")\n",
"axs[1].axis(\"off\")\n",
"skunk.connect(axs[0], \"m1\")\n",
"skunk.connect(axs[1], \"m2\")\n",
"svg = skunk.insert({\"m1\": s1, \"m2\": s2})\n",
"with open(\"lwtest.svg\", \"w\") as f:\n",
" f.write(svg)"
]
},
{
"cell_type": "markdown",
"id": "1bed2a9e",
"metadata": {},
"source": [
"# Expressiveness\n",
"\n",
"The SO(3)-equivariant layers from {doc}`Equivariant` will not work on variable-sized molecules because they are not permutation equivariant. We also know that graph neural networks (GNNs) are permutation equivariant and, with the correct choice of edge features, rotation and translation invariant. So why go beyond GNNs?\n",
"\n",
"One reason is that standard GNNs cannot distinguish certain graphs that matter in chemistry. For example, they cannot distinguish decalin from bicyclopentyl, even though the two molecules have different properties. Look at {numref}`decaline-bicylopentyl` below and consider the degree and neighbors of the atoms where the rings join: if you apply message passing, the two molecules look identical. This is the same failure as in the Weisfeiler-Lehman test {cite}`weisfeiler1968reduction`; standard message passing GNNs can distinguish at most the graphs that this test distinguishes.\n",
"\n",
"\n",
"```{figure} lwtest.svg\n",
"---\n",
"alt: \"decalin and bicyclopentyl structures drawn side-by-side, which are visually different.\"\n",
"name: \"decaline-bicylopentyl\"\n",
"---\n",
"Comparison of decalin and bicyclopentyl, which produce identical outputs in most GNNs despite being different molecules.\n",
"```"
]
},
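{
"cell_type": "markdown",
"id": "5a1e0c01",
"metadata": {},
"source": [
"The message passing argument above can be checked with a minimal sketch of 1-WL color refinement in plain Python. The carbon skeletons of both molecules are written out by hand as edge lists (hydrogens are omitted, and the atom numbering is an arbitrary choice for this example). Every atom starts with the same color, and each refinement step replaces an atom's color with its own color plus the sorted colors of its neighbors, which is the information a message passing layer can see. If the final color histograms are equal, this test, and therefore a standard message passing GNN, cannot tell the graphs apart.\n",
"\n",
"```python\n",
"from collections import Counter\n",
"\n",
"# decalin: bridgehead carbons 0 and 1 share a bond and are joined by two 4-carbon paths (two 6-membered rings)\n",
"decalin = [(0, 1), (0, 2), (2, 3), (3, 4), (4, 5), (5, 1), (0, 6), (6, 7), (7, 8), (8, 9), (9, 1)]\n",
"# bicyclopentyl: rings 0-2-3-4-5 and 1-6-7-8-9 (two 5-membered rings) joined by the bond (0, 1)\n",
"bicyclopentyl = [(0, 1), (0, 2), (2, 3), (3, 4), (4, 5), (5, 0), (1, 6), (6, 7), (7, 8), (8, 9), (9, 1)]\n",
"\n",
"\n",
"def neighbors(edges):\n",
"    # undirected adjacency list from an edge list\n",
"    adj = {}\n",
"    for u, v in edges:\n",
"        adj.setdefault(u, []).append(v)\n",
"        adj.setdefault(v, []).append(u)\n",
"    return adj\n",
"\n",
"\n",
"def wl_histogram(edges, iterations=3):\n",
"    adj = neighbors(edges)\n",
"    # every atom is a carbon, so all atoms start with the same color\n",
"    colors = {n: 'C' for n in adj}\n",
"    for _ in range(iterations):\n",
"        # new color = (own color, sorted multiset of neighbor colors)\n",
"        colors = {n: (colors[n], tuple(sorted(colors[m] for m in adj[n]))) for n in adj}\n",
"    # count how many atoms end up with each color\n",
"    return Counter(colors.values())\n",
"\n",
"\n",
"print(wl_histogram(decalin) == wl_histogram(bicyclopentyl))  # True\n",
"```\n",
"\n",
"The two graphs are not isomorphic (one has 6-membered rings, the other 5-membered rings), yet every round of refinement produces the same color histogram: two ring-junction atoms, four atoms next to a junction, and four atoms farther away."
]
},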
{
"cell_type": "markdown",
"id": "68b5d551",
"metadata": {},
"source": [
"These two molecules can be distinguished if we also have (and use) their Cartesian coordinates. GNNs also cannot distinguish enantiomers, except perhaps with pre-computed node attributes. Even those start to break down for helical chirality, which is not centered at any one atom.\n",
"\n",
"These are arguments for using Cartesian coordinates in addition to a GNN, but why use equivariant neural networks? Most research on molecular neural networks (molnets) is for **neural potentials**: neural networks that predict energy and forces given atom positions and elements. The force on each atom is given by\n",
"\n",
"\\begin{equation}\n",
"F\\left(\\vec{r}\\right) = -\\nabla U\\left(\\vec{r}\\right)\n",
"\\end{equation}\n",
"\n",
"where $U\\left(\\vec{r}\\right)$ is the rotation-invariant potential energy given all atom positions $\\vec{r}$. So if we're predicting a translation-, rotation-, and permutation-invariant potential, why use equivariant layers? Part of the answer is performance. Invariant models like SchNet or ANI are not as accurate as models like NequIP or TorchMD-NET that have equivariance in their internal layers. Another reason is that there are specific 3D configurations that should have different energies (according to quantum chemistry calculations) but cannot be distinguished when represented by pairwise distances alone {cite}`pozdnyakov2022incompleteness`."
]
},
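{
"cell_type": "markdown",
"id": "5a1e0c02",
"metadata": {},
"source": [
"Because $F = -\\nabla U$, an invariant potential automatically gives rotation-equivariant forces: rotating all atom positions rotates every force vector by the same rotation. The sketch below checks this numerically. It uses a hypothetical toy Lennard-Jones-style pair potential (reduced units) as a stand-in for a learned $U$, estimates the forces by central finite differences, and compares the forces of a rotated configuration with the rotated forces. The positions, rotation angles, and helper names are assumptions made for this example.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def pair_energy(x):\n",
"    # toy invariant potential U: Lennard-Jones-style sum over all atom pairs (reduced units)\n",
"    r = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)\n",
"    r = r[np.triu_indices(len(x), k=1)]  # each pair counted once\n",
"    return np.sum(4.0 * (r ** -12 - r ** -6))\n",
"\n",
"\n",
"def forces(x, h=1e-5):\n",
"    # F = -dU/dx, estimated by central finite differences\n",
"    f = np.zeros_like(x)\n",
"    for i in range(x.shape[0]):\n",
"        for k in range(3):\n",
"            dx = np.zeros_like(x)\n",
"            dx[i, k] = h\n",
"            f[i, k] = -(pair_energy(x + dx) - pair_energy(x - dx)) / (2 * h)\n",
"    return f\n",
"\n",
"\n",
"def rotation(a, b):\n",
"    # rotation about z by angle a, followed by rotation about x by angle b\n",
"    rz = np.array([[np.cos(a), -np.sin(a), 0.0], [np.sin(a), np.cos(a), 0.0], [0.0, 0.0, 1.0]])\n",
"    rx = np.array([[1.0, 0.0, 0.0], [0.0, np.cos(b), -np.sin(b)], [0.0, np.sin(b), np.cos(b)]])\n",
"    return rx @ rz\n",
"\n",
"\n",
"x = np.array([[0.0, 0.0, 0.0], [1.1, 0.0, 0.0], [0.0, 1.2, 0.0], [0.3, 0.4, 1.0]])\n",
"R = rotation(0.7, 0.3)\n",
"x_rot = x @ R.T  # rotate every atom position: r_i -> R r_i\n",
"\n",
"print(np.isclose(pair_energy(x), pair_energy(x_rot)))  # True: U is invariant\n",
"print(np.allclose(forces(x_rot), forces(x) @ R.T, atol=1e-6))  # True: F is equivariant\n",
"print(np.allclose(forces(x).sum(axis=0), 0.0, atol=1e-6))  # True: U is translation invariant, so forces sum to zero\n",
"```\n",
"\n",
"A model that only predicts $U$ and takes its gradient inherits this equivariance for free; the question in this chapter is whether equivariance inside the layers makes the predicted $U$ more accurate."
]
},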
{
"cell_type": "markdown",
"id": "a18782f2",
"metadata": {},
"source": [
"## The Elements of Modern Molecular NNs\n",
"\n",
"There has been a flurry of ideas about molnets in the last few years, especially with the advances in equivariant neural network layers. Batatia et al. {cite}`batatia2022design` have proposed a categorization of the main elements of molnets (which they call E(3)-equivariant NNs) that I will adopt here. They categorize the decisions to be made into three parts of the architecture: the atomic cluster expansion (ACE), the body-order of the messages, and the architecture of the message passing neural network (MPNN). This categorization might also be viewed within GNN theory as node features (ACE), message creation and aggregation (body-order), and node update (MPNN details). See {doc}`gnn` for more details on MPNNs.\n",
"\n",
"This is a relatively new categorization and certainly is not necessary to use. Most papers do not use it, and it takes some effort to put models into it. The benefit of thinking about models with these abstractions is that it helps us differentiate between the very large number of models now being pursued in the literature. There is also a bit of chaos in teasing out what *differentiates* the best models from others. For example, it took a while to discover that the most important features in NequIP were data normalization and how atom embeddings are treated {cite}`batatia2022design`. This categorization is also improving how these models are designed."
]
},
{
"cell_type": "markdown",
"id": "242add02",
"metadata": {},
"source": [
"### Atom features\n",
"\n",
"Let's start with the general terminology for an atom. In the input to these networks, an atom is just a Cartesian coordinate $\\vec{r}_i$ and an element $z_i$. The atoms are the nodes of the graph, and as we pass through GNN layers the number of features per atom grows. The atom features need to be organized a bit differently than in previous chapters because some of the features should be invariant with respect to the group (SO(3)) and some need to be equivariant."
]
},
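{
"cell_type": "markdown",
"id": "5a1e0c03",
"metadata": {},
"source": [
"One minimal way to store such features, sketched below under our own assumptions, is to keep one array per angular order $l$: an invariant $l = 0$ block of scalars per channel and an equivariant $l = 1$ block of 3-vectors per channel, stored here as Cartesian vectors. The element vocabulary, the channel count, and the random stand-in values for the $l = 1$ block are hypothetical. Rotating the molecule leaves the scalars alone and rotates every vector, and the norm of each vector is an invariant that can be built from the equivariant part.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"elements = [1, 6, 8]  # hypothetical element vocabulary: H, C, O\n",
"z = np.array([8, 1, 1])  # element z_i of each atom (a water molecule)\n",
"\n",
"# l = 0 block, shape (atoms, channels): one-hot element embedding, invariant to rotation\n",
"scalars = np.array([[float(zi == e) for e in elements] for zi in z])\n",
"# l = 1 block, shape (atoms, channels, 3): starts at zero because a lone atom has no preferred direction\n",
"vectors = np.zeros((len(z), len(elements), 3))\n",
"\n",
"# pretend some layers have filled in the l = 1 block (random stand-in values)\n",
"vectors = np.random.default_rng(0).normal(size=vectors.shape)\n",
"\n",
"a = 0.5\n",
"R = np.array([[np.cos(a), -np.sin(a), 0.0], [np.sin(a), np.cos(a), 0.0], [0.0, 0.0, 1.0]])\n",
"rotated = {'scalars': scalars, 'vectors': vectors @ R.T}  # only the l = 1 block changes\n",
"\n",
"# vector norms are invariant, so they could be fed back into the scalar channels\n",
"print(np.allclose(np.linalg.norm(vectors, axis=-1), np.linalg.norm(rotated['vectors'], axis=-1)))  # True\n",
"```"
]
},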
{
"cell_type": "markdown",
"id": "3e34c1b9",
"metadata": {},
"source": [
"### Atomic Cluster Expansions\n",
"\n",
"An ACE is a per-atom tensor. The main idea of ACE is to encode the local environment of an atom into a feature tensor that describes its neighborhood of nearby atoms. This is like distinguishing between an oxygen in an alcohol group and an oxygen in an ether. Both are oxygens, but we expect them to behave differently. ACE is the same idea, but for nearby atoms in space instead of just neighbors on the molecular graph.\n",
"\n",
"The general equation for ACE (assuming O(3) equivariance) is [cite]:\n",
"\n",
"\\begin{equation}\n",
"A^{(t)}_{i, kl_3m_3} = \\sum_{l_1m_1,l_2m_2}C_{l_1m_1,l_2m_2}^{l_3m_3}\\sum_{j \\in \\mathcal{N}(i)} R^{(t)}_{kl_1l_2l_3}\\left(r_{ji}\\right)Y_{l_1}^{m_1}\\left(\\hat{\\mathbf{r}}_{ji}\\right)\\mathcal{W}^{(t)}_{kl_2}h_{j,l_2m_2}^{(t)}\n",
"\\end{equation}\n",
"\n",
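"Wow! What an expression. Let's go through this carefully, starting with the output. $A^{(t)}_{i, kl_3m_3}$ are the feature tensor values for atom $i$ at layer $t$. They are indexed by a channel $k$ and the spherical harmonic indices $l_3m_3$. The right-hand side is nearly identical to the G-equivariant neural network layer equation from {doc}`Equivariant`. We have the input features $h_{j,l_2m_2}^{(t)}$ of each neighbor $j \\in \\mathcal{N}(i)$, which are mixed into channel $k$ by the learnable weights $\\mathcal{W}^{(t)}_{kl_2}$. The direction to the neighbor enters through the spherical harmonic $Y_{l_1}^{m_1}\\left(\\hat{\\mathbf{r}}_{ji}\\right)$ of the unit vector $\\hat{\\mathbf{r}}_{ji}$, and the distance enters through the learnable radial function $R^{(t)}_{kl_1l_2l_3}\\left(r_{ji}\\right)$. Finally, the Clebsch-Gordan coefficients $C_{l_1m_1,l_2m_2}^{l_3m_3}$ combine the angular order $l_1$ of the spherical harmonic with the order $l_2$ of the neighbor features into an output of order $l_3$, which keeps $A$ equivariant.\n",
"\n",
"How is this different from an MPNN?"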
]
},
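{
"cell_type": "markdown",
"id": "5a1e0c04",
"metadata": {},
"source": [
"To make the ACE equation concrete, here is a minimal NumPy sketch of the special case where the neighbor features are scalars ($l_2 = 0$) and the spherical harmonics go up to $l_1 = 1$. Coupling with $l_2 = 0$ is trivial, so the Clebsch-Gordan coefficient reduces to $\\delta_{l_1 l_3}\\delta_{m_1 m_3}$ and each output order $l_3$ equals $l_1$. The Gaussian radial basis, cosine cutoff, random weights, and the convention $\\vec{r}_{ji} = \\vec{r}_j - \\vec{r}_i$ are our assumptions, not details of a specific published model. The check at the end confirms that the $l = 0$ outputs are invariant and the $l = 1$ outputs rotate with the molecule.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def real_sph_harm_l01(unit):\n",
"    # real spherical harmonics: l = 0 (shape (..., 1)) and l = 1 with m = -1, 0, 1 ordered as (y, z, x)\n",
"    y0 = np.full(unit.shape[:-1] + (1,), 0.5 / np.sqrt(np.pi))\n",
"    y1 = np.sqrt(3.0 / (4.0 * np.pi)) * unit[..., [1, 2, 0]]\n",
"    return y0, y1\n",
"\n",
"\n",
"def radial_basis(r, n_basis, cutoff):\n",
"    # hypothetical Gaussian basis times a cosine cutoff that goes smoothly to zero at the cutoff\n",
"    centers = np.linspace(0.0, cutoff, n_basis)\n",
"    fcut = 0.5 * (np.cos(np.pi * r / cutoff) + 1.0) * (r < cutoff)\n",
"    return np.exp(-((r[..., None] - centers) ** 2) / 0.25) * fcut[..., None]\n",
"\n",
"\n",
"def atomic_basis(pos, h, W, M0, M1, cutoff=3.0):\n",
"    # A_{i,k,l m} for l = 0, 1 with scalar neighbor features h_j (l_2 = 0)\n",
"    rji = pos[None, :, :] - pos[:, None, :]  # rji[i, j] = r_j - r_i\n",
"    dist = np.linalg.norm(rji, axis=-1)\n",
"    nbr = (dist > 0) & (dist < cutoff)  # j in N(i), excluding j = i\n",
"    safe = np.where(nbr, dist, 1.0)  # avoid dividing by zero\n",
"    unit = np.where(nbr[..., None], rji / safe[..., None], 0.0)\n",
"    y0, y1 = real_sph_harm_l01(unit)\n",
"    basis = radial_basis(dist, M0.shape[0], cutoff) * nbr[..., None]\n",
"    R0, R1 = basis @ M0, basis @ M1  # learnable radial functions per channel k, shape (n, n, K)\n",
"    Wh = h[:, None] * W[None, :]  # W_k h_j, shape (n, K)\n",
"    A0 = np.einsum('ijk,ijm,jk->ikm', R0, y0, Wh)  # sum over neighbors j, shape (n, K, 1)\n",
"    A1 = np.einsum('ijk,ijm,jk->ikm', R1, y1, Wh)  # shape (n, K, 3)\n",
"    return A0, A1\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n_atoms, K, n_basis = 5, 4, 6\n",
"pos = rng.normal(size=(n_atoms, 3))\n",
"h = rng.normal(size=n_atoms)\n",
"W, M0, M1 = rng.normal(size=K), rng.normal(size=(n_basis, K)), rng.normal(size=(n_basis, K))\n",
"A0, A1 = atomic_basis(pos, h, W, M0, M1)\n",
"\n",
"a = 0.9\n",
"Rz = np.array([[np.cos(a), -np.sin(a), 0.0], [np.sin(a), np.cos(a), 0.0], [0.0, 0.0, 1.0]])\n",
"B0, B1 = atomic_basis(pos @ Rz.T, h, W, M0, M1)\n",
"to_xyz = [2, 0, 1]  # reorder (y, z, x) -> (x, y, z)\n",
"print(np.allclose(A0, B0))  # True: l = 0 features are invariant\n",
"print(np.allclose(B1[..., to_xyz], A1[..., to_xyz] @ Rz.T))  # True: l = 1 features rotate\n",
"```\n",
"\n",
"This sketch stops at the single sum over neighbors, so each $A$ only carries two-body (atom $i$ plus one neighbor $j$) information about the environment."
]
},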
{
"cell_type": "markdown",
"id": "cef0d13f",
"metadata": {},
"source": [
"## Normalization\n",
"\n",
"* Physics-based energy/force normalization\n",
"* Pooling\n",
"* Layers"
]
},
{
"cell_type": "markdown",
"id": "b083f55a",
"metadata": {},
"source": [
"## Running This Notebook\n",
"\n",
"\n",
"Click the above to launch this page as an interactive Google Colab. See details below on installing packages.\n",
"\n",
"````{tip} My title\n",
":class: dropdown\n",
"To install packages, execute this code in a new cell. \n",
"\n",
"```\n",
"!pip install dmol-book\n",
"```\n",
"\n",
"If you find install problems, you can get the latest working versions of packages used in [this book here](https://github.com/whitead/dmol-book/blob/main/package/setup.py)\n",
"\n",
"````"
]
},
{
"cell_type": "markdown",
"id": "6bb5fcb7",
"metadata": {},
"source": [
"## Cited References\n",
"\n",
"```{bibliography}\n",
":style: unsrtalpha\n",
":filter: docname in docnames\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7bf44568",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import pandas as pd\n",
"import rdkit, rdkit.Chem, rdkit.Chem.rdDepictor, rdkit.Chem.Draw\n",
"import networkx as nx\n",
"import dmol\n",
"\n",
"my_elements = {6: \"C\", 8: \"O\", 1: \"H\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f525d6f",
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"\n",
"def smiles2graph(sml):\n",
"    \"\"\"Convert a SMILES string into a molecular graph with explicit hydrogens.\n",
"\n",
"    returns: one-hot node features with shape (N, len(my_elements)) and an\n",
"    adjacency tensor with shape (N, N, 5), where adj[u, v, order] = 1 for a bond\n",
"    of that order (1 single, 2 double, 3 triple, 4 aromatic; index 0 is unused)\n",
"    \"\"\"\n",
"    m = rdkit.Chem.MolFromSmiles(sml)\n",
"    m = rdkit.Chem.AddHs(m)\n",
"    order_string = {\n",
"        rdkit.Chem.rdchem.BondType.SINGLE: 1,\n",
"        rdkit.Chem.rdchem.BondType.DOUBLE: 2,\n",
"        rdkit.Chem.rdchem.BondType.TRIPLE: 3,\n",
"        rdkit.Chem.rdchem.BondType.AROMATIC: 4,\n",
"    }\n",
"    N = m.GetNumAtoms()\n",
"    nodes = np.zeros((N, len(my_elements)))\n",
"    lookup = list(my_elements.keys())\n",
"    for i in m.GetAtoms():\n",
"        nodes[i.GetIdx(), lookup.index(i.GetAtomicNum())] = 1\n",
"\n",
"    adj = np.zeros((N, N, 5))\n",
"    for j in m.GetBonds():\n",
"        u = min(j.GetBeginAtomIdx(), j.GetEndAtomIdx())\n",
"        v = max(j.GetBeginAtomIdx(), j.GetEndAtomIdx())\n",
"        order = j.GetBondType()\n",
"        if order in order_string:\n",
"            order = order_string[order]\n",
"        else:\n",
"            # skip unsupported bond types instead of crashing on str + BondType\n",
"            warnings.warn(f\"Ignoring bond order {order}\")\n",
"            continue\n",
"        adj[u, v, order] = 1\n",
"        adj[v, u, order] = 1\n",
"    return nodes, adj"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6dc9231b",
"metadata": {},
"outputs": [],
"source": [
"# THIS CELL IS USED TO GENERATE A FIGURE\n",
"# AND NOT RELATED TO CHAPTER\n",
"# YOU CAN SKIP IT\n",
"from myst_nb import glue\n",
"from moviepy.editor import VideoClip\n",
"from moviepy.video.io.bindings import mplfig_to_npimage\n",
"\n",
"\n",
"def draw_vector(x, y, s, v, ax, cmap, **kwargs):\n",
" x += s / 2\n",
" y += s / 2\n",
" # iterate vertically\n",
" L = max([len(vi) for vi in v])\n",
" for dy, vi in enumerate(v):\n",
" for dx, vij in enumerate(vi):\n",
" if cmap is not None:\n",
" ax.add_patch(\n",
" mpl.patches.Rectangle(\n",
" (x + dx * s * 1.5 / L, y + dy * s),\n",
" s * 1.5 / L,\n",
" s,\n",
" facecolor=cmap(vij),\n",
" **kwargs,\n",
" )\n",
" )\n",
" else:\n",
" ax.add_patch(\n",
" mpl.patches.Rectangle(\n",
" (x + dx * s * 1.5 / L, y + dy * s),\n",
" s * 1.5 / L,\n",
" s,\n",
" facecolor=\"#FFF\",\n",
" edgecolor=\"#333\",\n",
" **kwargs,\n",
" )\n",
" )\n",
" ax.text(\n",
" x + dx * s * 1.5 / L + s * 1.5 / L / 2,\n",
" y + s / 2 + dy * s,\n",
" \"{:.2f}\".format(vij),\n",
" verticalalignment=\"center\",\n",
" horizontalalignment=\"center\",\n",
" fontsize=5,\n",
" )\n",
"\n",
"\n",
"def draw_key(x, y, s, v, ax, cmap, **kwargs):\n",
" x += s / 2\n",
" y += s / 2\n",
" for vi in v:\n",
" ax.add_patch(\n",
" mpl.patches.Rectangle((x, y), s * 1.5, s, facecolor=cmap(1.0), **kwargs)\n",
" )\n",
" ax.text(\n",
" x + s * 1.5 / 2,\n",
" y + s / 2,\n",
" vi,\n",
" verticalalignment=\"center\",\n",
" horizontalalignment=\"center\",\n",
" )\n",
" y += s\n",
" ax.text(\n",
" x, y + s / 2, \"Key:\", verticalalignment=\"center\", horizontalalignment=\"left\"\n",
" )\n",
"\n",
"\n",
"def draw(\n",
" nodes,\n",
" node_features,\n",
" adj,\n",
" ax,\n",
" highlight=None,\n",
" key=False,\n",
" labels=None,\n",
" mask=None,\n",
" draw_nodes=None,\n",
"):\n",
" G = nx.Graph()\n",
" for i in range(adj.shape[0]):\n",
" for j in range(adj.shape[0]):\n",
" if np.any(adj[i, j]):\n",
" G.add_edge(i, j)\n",
" if mask is None:\n",
" mask = [True] * len(G)\n",
" if draw_nodes is None:\n",
" draw_nodes = nodes\n",
" # go from atomic number to element\n",
" elements = np.argmax(draw_nodes, axis=-1)\n",
" el_labels = {i: list(my_elements.values())[e] for i, e in enumerate(elements)}\n",
" try:\n",
" pos = nx.nx_agraph.graphviz_layout(G, prog=\"sfdp\")\n",
" except ImportError:\n",
" pos = nx.spring_layout(G, iterations=100, seed=4, k=1)\n",
" pos = nx.rescale_layout_dict(pos)\n",
" c = [\"white\"] * len(G)\n",
" all_h = []\n",
" if highlight is not None:\n",
" for i, h in enumerate(highlight):\n",
" for hj in h:\n",
" c[hj] = \"C{}\".format(i + 1)\n",
" all_h.append(hj)\n",
" # now we add all edges to close (using pos) atoms to emphasize spatial locality\n",
" for i in range(adj.shape[0]):\n",
" for j in range(adj.shape[0]):\n",
" if (\n",
" i != j\n",
" and not np.any(adj[i, j])\n",
" and np.linalg.norm(np.array(pos[i]) - np.array(pos[j])) < 1\n",
" ):\n",
" G.add_edge(i, j, space=True)\n",
" # set-up edge colors based on if they are space or not\n",
" edge_colors = [\"#000\"] * len(G.edges)\n",
" for i, (u, v, d) in enumerate(G.edges(data=True)):\n",
" if d.get(\"space\", False):\n",
" edge_colors[i] = \"#AAA\"\n",
" nx.draw(\n",
" G,\n",
" ax=ax,\n",
" pos=pos,\n",
" labels=el_labels,\n",
" node_size=700,\n",
" node_color=c,\n",
" edge_color=edge_colors,\n",
" )\n",
" cmap = plt.get_cmap(\"Wistia\")\n",
" for i in range(len(G)):\n",
" if not mask[i]:\n",
" continue\n",
" if i in all_h:\n",
" draw_vector(*pos[i], 0.15, node_features[i], ax, cmap)\n",
" else:\n",
" draw_vector(*pos[i], 0.15, node_features[i], ax, None)\n",
" if key:\n",
" draw_key(-1, -1, 0.15, my_elements.values(), ax, cmap)\n",
" if labels is not None:\n",
" legend_elements = []\n",
" for i, l in enumerate(labels):\n",
" p = mpl.lines.Line2D(\n",
" [0], [0], marker=\"o\", color=\"C{}\".format(i + 1), label=l, markersize=15\n",
" )\n",
" legend_elements.append(p)\n",
" ax.legend(handles=legend_elements)\n",
" ax.set_xlim(-1.2, 1.2)\n",
" ax.set_ylim(-1.2, 1.2)\n",
" ax.set_facecolor(\"#f5f4e9\")\n",
"\n",
"\n",
"nodes, adj = smiles2graph(\"CO\")\n",
"print(nodes)\n",
"nodes_vectors = [[n, n[:2]] for i, n in enumerate(nodes)]\n",
"print(nodes_vectors[1])\n",
"fig = plt.figure(figsize=(8, 5))\n",
"draw(\n",
" nodes,\n",
" nodes_vectors,\n",
" adj,\n",
" plt.gca(),\n",
" highlight=[[1], [5, 0]],\n",
" labels=[\"center\", \"neighbors\"],\n",
")\n",
"fig.set_facecolor(\"#f5f4e9\")\n",
"glue(\"dframe\", plt.gcf(), display=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd8e064f",
"metadata": {},
"outputs": [],
"source": [
"# THIS CELL IS USED TO GENERATE A FIGURE\n",
"# AND NOT RELATED TO CHAPTER\n",
"# YOU CAN SKIP IT\n",
"fig, axs = plt.subplots(1, 2, squeeze=True, figsize=(14, 6), dpi=100)\n",
"order = [5, 1, 0, 2, 3, 4]\n",
"time_per_node = 2\n",
"last_layer = [0]\n",
"layers = 2\n",
"input_nodes = np.copy(nodes)\n",
"fig.set_facecolor(\"#f5f4e9\")\n",
"\n",
"\n",
"def make_frame(t):\n",
" axs[0].clear()\n",
" axs[1].clear()\n",
"\n",
" layer_i = int(t / (time_per_node * len(order)))\n",
" axs[0].set_title(f\"Layer {layer_i + 1} Input\")\n",
" axs[1].set_title(f\"Layer {layer_i + 1} Output\")\n",
"\n",
" flat_adj = np.sum(adj, axis=-1)\n",
" # out_nodes = np.einsum(\n",
" # \"i,ij,jk->ik\",\n",
" # 1 / (np.sum(flat_adj, axis=1) + 1),\n",
" # flat_adj + np.eye(*flat_adj.shape),\n",
" # nodes,\n",
" # )\n",
" out_nodes = nodes\n",
" node_vectors = [[n, n[:2]] for i, n in enumerate(out_nodes)]\n",
" if last_layer[0] != layer_i:\n",
" print(\"recomputing\")\n",
" nodes[:] = out_nodes\n",
" last_layer[0] = layer_i\n",
"\n",
" t -= layer_i * time_per_node * len(order)\n",
" i = order[int(t / time_per_node)]\n",
" print(last_layer, layer_i, i, t)\n",
" mask = [False] * nodes.shape[0]\n",
" for j in order[: int(t / time_per_node) + 1]:\n",
" mask[j] = True\n",
" print(mask, i)\n",
" neighs = list(np.where(adj[i])[0])\n",
" if (t - int(t / time_per_node) * time_per_node) >= time_per_node / 4:\n",
" draw(\n",
" nodes,\n",
" node_vectors,\n",
" adj,\n",
" axs[0],\n",
" highlight=[[i], neighs],\n",
" labels=[\"center\", \"neighbors\"],\n",
" draw_nodes=input_nodes,\n",
" )\n",
" else:\n",
" draw(\n",
" nodes,\n",
" node_vectors,\n",
" adj,\n",
" axs[0],\n",
" highlight=[[i]],\n",
" labels=[\"center\", \"neighbors\"],\n",
" draw_nodes=input_nodes,\n",
" )\n",
" if (t - int(t / time_per_node) * time_per_node) < time_per_node / 2:\n",
" mask[j] = False\n",
" draw(\n",
" out_nodes,\n",
" node_vectors,\n",
" adj,\n",
" axs[1],\n",
" highlight=[[i]],\n",
" key=True,\n",
" mask=mask,\n",
" draw_nodes=input_nodes,\n",
" )\n",
" fig.set_facecolor(\"#f5f4e9\")\n",
" return mplfig_to_npimage(fig)\n",
"\n",
"\n",
"animation = VideoClip(make_frame, duration=time_per_node * nodes.shape[0] * layers)\n",
"animation.write_gif(\"../_static/images/molnet.gif\", fps=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e7cf098",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3.7.8 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.8"
},
"vscode": {
"interpreter": {
"hash": "3e5a039a7a113538395a7d74f5574b0c5900118222149a18efb009bf03645fce"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}