{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Predicting DFT Energies with GNNs\n",
"\n",
"QM9 is a dataset of 134,000 molecules consisting of 9 heavy atoms drawn from the elements C, H, O, N, F{cite}`ramakrishnan2014quantum`. The features are the xyz coordinates ($\\mathbf{X}$) and elements ($\\vec{e}$) of the molecule. The coordinates are determined from B3LYP/6-31G(2df,p) level DFT geometry optimization. There are multiple labels (see table below), but we'll be interested specifically in the energy of formation (Enthalpy at 298.15 K). The goal in this chapter is to regress a graph neural network to predict the energy of formation given the coordinates of a molecule. We will build upon ideas in the following chapters:\n",
"\n",
"1. {doc}`../ml/regression`\n",
"2. {doc}`../dl/gnn`\n",
"3. {doc}`../dl/data`\n",
"\n",
"\n",
"QM9 is one of the most popular dataset for machine learning and deep learning since it came out in 2014. The first papers could achieve about 10 kcal/mol on this regression problem and now are down to ~1 kcal/mol and lower. Any model on this dataset must be translation, rotation, and permutation invariant. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Label Description\n",
"\n",
"|Index | Name | Units | Description|\n",
" |:-----|-------|-------|-----------:|\n",
" |0 |index | - |Consecutive, 1-based integer identifier of molecule|\n",
" |1 |A | GHz |Rotational constant A|\n",
" |2 |B | GHz |Rotational constant B|\n",
" |3 |C | GHz |Rotational constant C|\n",
" |4 |mu | Debye |Dipole moment|\n",
" |5 |alpha | Bohr^3 |Isotropic polarizability|\n",
" |6 |homo | Hartree |Energy of Highest occupied molecular orbital (HOMO)|\n",
" |7 |lumo | Hartree |Energy of Lowest unoccupied molecular orbital (LUMO)|\n",
" |8 | gap | Hartree | Gap, difference between LUMO and HOMO|\n",
" |9 | r2 | Bohr^2 | Electronic spatial extent|\n",
" |10 | zpve | Hartree | Zero point vibrational energy|\n",
" |11 | U0 | Hartree | Internal energy at 0 K|\n",
" |12 | U | Hartree | Internal energy at 298.15 K|\n",
" |13 | H | Hartree | Enthalpy at 298.15 K|\n",
" |14 | G | Hartree | Free energy at 298.15 K|\n",
" |15 | Cv | cal/(mol K) | Heat capacity at 298.15 K|\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data\n",
"\n",
"I have written some helper code in the `fetch_data.py` file. It downloads the data and converts into a format easily used in Python. The data returned from this function is broken into the features $\\mathbf{X}$ and $\\vec{e}$. $\\mathbf{X}$ is an $N\\times4$ matrix of atom positions + partial charge of the atom. $\\vec{e}$ is vector of atomic numbers for each atom in the molecule. Remember to slice the specific label you want from the label vector."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running This Notebook\n",
"\n",
"\n",
"Click the above to launch this page as an interactive Google Colab. See details below on installing packages.\n",
"\n",
"````{tip} My title\n",
":class: dropdown\n",
"To install packages, execute this code in a new cell. \n",
"\n",
"```\n",
"!pip install dmol-book\n",
"```\n",
"\n",
"If you find install problems, you can get the latest working versions of packages used in [this book here](https://github.com/whitead/dmol-book/blob/main/package/setup.py)\n",
"\n",
"````"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"from fetch_data import qm9_parse, qm9_fetch\n",
"import dmol"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load the data. This step will take a few minutes as it is downloaded and processed. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-output"
]
},
"outputs": [],
"source": [
"qm9_records = qm9_fetch()\n",
"data = qm9_parse(qm9_records)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`data` is an iterable containing the data for the 133k molecules. Let's examine the first one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for d in data:\n",
" print(d)\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These are Tensorflow Tensors. They can be converted to numpy arrays via `x.numpy()`. The first item is the element vector `6,1,1,1,1`. Do you recognize the elements? It's C, H, H, H, H. The positions come next. Note that there is an extra column containing the atom partial charges, which we will not use as a feature. Finally, the last tensor is the label vector. \n",
"\n",
"Now we will do some processing of the data to get into a more usable format. Let's convert to numpy arrays, remove the partial charges, and convert the elements into one-hot vectors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def convert_record(d):\n",
" # break up record\n",
" (e, x), y = d\n",
" #\n",
" e = e.numpy()\n",
" x = x.numpy()\n",
" r = x[:, :3]\n",
" # make ohc size larger\n",
" # so use same node feature\n",
" # shape later\n",
" ohc = np.zeros((len(e), 16))\n",
" ohc[np.arange(len(e)), e - 1] = 1\n",
" return (ohc, r), y.numpy()[13]\n",
"\n",
"\n",
"for d in data:\n",
" (e, x), y = convert_record(d)\n",
" print(\"Element one hots\\n\", e)\n",
" print(\"Coordinates\\n\", x)\n",
" print(\"Label:\", y)\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Baseline Model\n",
"\n",
"Before we get too far into modeling, let's see what a simple model can do for accuracy. This will help establish a baseline model which any more sophisticated implementation should exceed in accuracy. You can make many choices for this, but I'll just make a linear regression based on number of atom types."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import jax.example_libraries.optimizers as optimizers\n",
"import jax\n",
"import warnings\n",
"import matplotlib.pyplot as plt\n",
"\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"\n",
"@jax.jit\n",
"def baseline_model(nodes, w, b):\n",
" # get sum of each element type\n",
" atom_count = jnp.sum(nodes, axis=0)\n",
" yhat = atom_count @ w + b\n",
" return yhat\n",
"\n",
"\n",
"def baseline_loss(nodes, y, w, b):\n",
" return (baseline_model(nodes, w, b) - y) ** 2\n",
"\n",
"\n",
"baseline_loss_grad = jax.grad(baseline_loss, (2, 3))\n",
"w = np.ones(16)\n",
"b = 0.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We've set-up our simple regression model. One complexity is that we cannot batch the molecules like normal because each molecule contains different shaped tensors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# we'll just train on 5,000 and use 1,000 for test\n",
"# shuffle, but only once (reshuffle_each_iteration=False) so\n",
"# we lock in which are train/test/val\n",
"shuffled_data = data.shuffle(7000, reshuffle_each_iteration=False)\n",
"test_set = shuffled_data.take(1000)\n",
"valid_set = shuffled_data.skip(1000).take(1000)\n",
"train_set = shuffled_data.skip(2000).take(5000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The labels in this data are quite large, so we're going to make a transform on them to make our learning rates and training going more smoothly. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ys = [convert_record(d)[1] for d in train_set]\n",
"train_ym = np.mean(ys)\n",
"train_ys = np.std(ys)\n",
"print(\"Mean = \", train_ym, \"Std =\", train_ys)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we'll just use this transform when training: $y_s = \\frac{y - \\mu_y}{\\sigma_y}$ and then our predictions will be transformed by $\\hat{y} = \\hat{f}(e,x) \\cdot \\sigma_y + \\mu_y$. This just helps standardize our range of outputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def transform_label(y):\n",
" return (y - train_ym) / train_ys\n",
"\n",
"\n",
"def transform_prediction(y):\n",
" return y * train_ys + train_ym"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"epochs = 16\n",
"eta = 1e-3\n",
"baseline_val_loss = [0.0 for _ in range(epochs)]\n",
"for epoch in range(epochs):\n",
" for d in train_set:\n",
" (e, x), y_raw = convert_record(d)\n",
" y = transform_label(y_raw)\n",
" grad_est = baseline_loss_grad(e, y, w, b)\n",
" # update regression weights\n",
" w -= eta * grad_est[0]\n",
" b -= eta * grad_est[1]\n",
" # compute validation loss\n",
" for v in valid_set:\n",
" (e, x), y_raw = convert_record(v)\n",
" y = transform_label(y_raw)\n",
" # convert SE to RMSE\n",
" baseline_val_loss[epoch] += baseline_loss(e, y, w, b)\n",
" baseline_val_loss[epoch] = jnp.sqrt(baseline_val_loss[epoch] / 1000)\n",
" eta *= 0.9\n",
"plt.plot(baseline_val_loss)\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Val Loss\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is poor performance, but it gives us a baseline value of what we can expect. One unusual detail I did in this training was to slowly reduce the learning rate. This is because our features and labels are all in different magnitudes. Our weights need to move far to get into the right order of magnitude and then need to fine-tune a little. Thus, we start at high learning rate and decrease."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ys = []\n",
"yhats = []\n",
"for v in valid_set:\n",
" (e, x), y = convert_record(v)\n",
" ys.append(y)\n",
" yhat_raw = baseline_model(e, w, b)\n",
" yhat = transform_prediction(yhat_raw)\n",
" yhats.append(yhat)\n",
"\n",
"\n",
"plt.plot(ys, ys, \"-\")\n",
"plt.plot(ys, yhats, \".\")\n",
"plt.xlabel(\"Energy\")\n",
"plt.ylabel(\"Predicted Energy\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see that the broad trends about molecule size capture a lot of variance, but more work needs to be done. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example GNN Model\n",
"\n",
"We now can work with this data to build a model. Let's build a simple model that can model energy and obeys the invariances required of the problem. We will use a graph neural network (GNN) because it obeys permutation invariance. We will create a *graph* from the coordinates/element vector by joining all atoms to all other atoms and using their inverse pairwise distance as the edge weight. The choice of pairwise distance gives us translation and rotation invariance. The choice of inverse distance means that atoms which are far away naturally have low edge weights.\n",
"\n",
"I will now define our model using the Battaglia equations {cite}`battaglia2018relational`. As opposed to most examples we've seen in class, I will use the graph level feature vector $\\vec{u}$ which will ultimately be our estimate of energy. The edge update will only consider the sender and the edge weight with trainable parameters:\n",
"\n",
"\\begin{equation}\n",
" \\vec{e}^{'}_k = \\phi^e\\left( \\vec{e}_k, \\vec{v}_{rk}, \\vec{v}_{sk}, \\vec{u}\\right) = \\sigma\\left(\\vec{v}_{sk}\\vec{w}_ee_k + \\vec{b}_e\\right)\n",
"\\end{equation}\n",
"\n",
"where the input edge $e_k$ will be a single number (inverse pairwise distance) and $\\vec{b}_e$ is a trainable bias vector. We will use a sum aggregation for edges (not shown). $\\sigma$ is a leaky ReLU. The leaky just prevents vanishing gradients, which I found empirically to reduce performance here. The node update will be \n",
"\n",
"\\begin{equation}\n",
" \\vec{v}^{'}_i = \\phi^v\\left( \\bar{e}^{'}_i, \\vec{v}_i, \\vec{u}\\right) = \\sigma\\left(\\mathbf{W}_v \\bar{e}^{'}_i\\right) + \\vec{v}_i\n",
"\\end{equation}\n",
"\n",
"The global node aggregation will also be a sum. Finally, we have our graph feature vector update:\n",
"\n",
"\\begin{equation}\n",
" \\vec{u}^{'} = \\phi^u\\left( \\bar{e}^{'},\\bar{v}^{'}, \\vec{u}\\right) = \\sigma\\left(\\mathbf{W}_u\\bar{v}^{'}\\right) + \\vec{u}\n",
"\\end{equation}\n",
"\n",
"\n",
"To compute the final energy, we'll use our regression equation:\n",
"\n",
"\\begin{equation}\n",
" \\hat{E} = \\vec{w}\\cdot \\vec{u} + b\n",
"\\end{equation}\n",
"\n",
"One final detail is that we will pass on $\\vec{u}$ and the nodes, but we will keep the edges the same at each GNN layer. Remember this is an example model: there are many changes that could be made to the above. Also, it is not kernel learning which is the favorite for this domain. Let's implement it though and see if it works. "
]
},
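{
"cell_type": "markdown",
"metadata": {},
"source": [
"The invariance claims above can be checked numerically. The following is a minimal sketch and not part of the model in this chapter: it uses plain NumPy, made-up random coordinates for 5 atoms, and a hypothetical helper `inverse_sq_distance` that mirrors the edge construction described above. It then applies a random orthogonal transform, a translation, and an atom permutation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def inverse_sq_distance(x):\n",
"    # hypothetical helper: 1 / r^2 between all atom pairs, 0 on the diagonal\n",
"    r2 = np.sum((x - x[:, np.newaxis, :]) ** 2, axis=-1)\n",
"    safe_r2 = np.where(r2 > 0, r2, 1.0)\n",
"    return np.where(r2 > 0, 1.0 / safe_r2, 0.0)\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x = rng.normal(size=(5, 3))  # made-up coordinates for 5 atoms\n",
"\n",
"# random orthogonal matrix (rotation or reflection) from a QR decomposition\n",
"q, _ = np.linalg.qr(rng.normal(size=(3, 3)))\n",
"shift = rng.normal(size=3)\n",
"x_moved = x @ q.T + shift\n",
"\n",
"edges = inverse_sq_distance(x)\n",
"edges_moved = inverse_sq_distance(x_moved)\n",
"# rigid motion leaves the edge weights unchanged\n",
"print(np.allclose(edges, edges_moved))\n",
"\n",
"# permuting atoms permutes rows and columns of the edge matrix,\n",
"# so a sum over all edges is unchanged\n",
"perm = rng.permutation(5)\n",
"edges_perm = inverse_sq_distance(x[perm])\n",
"print(np.allclose(edges_perm, edges[perm][:, perm]))\n",
"print(np.isclose(edges_perm.sum(), edges.sum()))\n",
"```\n",
"\n",
"The edge matrix itself is only permutation *equivariant*. Invariance of the final prediction comes from the sum aggregations over edges and nodes."
]
},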
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### JAX Model Implementation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def x2e(x):\n",
" \"\"\"convert xyz coordinates to inverse pairwise distance\"\"\"\n",
" r2 = jnp.sum((x - x[:, jnp.newaxis, :]) ** 2, axis=-1)\n",
" e = jnp.where(r2 != 0, 1 / r2, 0.0)\n",
" return e\n",
"\n",
"\n",
"def gnn_layer(nodes, edges, features, we, web, wv, wu):\n",
" \"\"\"Implementation of the GNN\"\"\"\n",
" # make nodes be N x N so we can just multiply directly\n",
" # ek is now shaped N x N x features\n",
" ek = jax.nn.leaky_relu(\n",
" web\n",
" + jnp.repeat(nodes[jnp.newaxis, ...], nodes.shape[0], axis=0)\n",
" @ we\n",
" * edges[..., jnp.newaxis]\n",
" )\n",
" # sum over neighbors to get N x features\n",
" ebar = jnp.sum(ek, axis=1)\n",
" # dense layer for new nodes to get N x features\n",
" new_nodes = jax.nn.leaky_relu(ebar @ wv) + nodes\n",
" # sum over nodes to get shape features\n",
" global_node_features = jnp.sum(new_nodes, axis=0)\n",
" # dense layer for new features\n",
" new_features = jax.nn.leaky_relu(global_node_features @ wu) + features\n",
" # just return features for ease of use\n",
" return new_nodes, edges, new_features"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have implemented the code to convert coordinates into inverse pairwise distance and the GNN equations above. Let's test them out."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"graph_feature_len = 8\n",
"node_feature_len = 16\n",
"msg_feature_len = 16\n",
"\n",
"\n",
"# make our weights\n",
"def init_weights(g, n, m):\n",
" we = np.random.normal(size=(n, m), scale=1e-1)\n",
" wb = np.random.normal(size=(m), scale=1e-1)\n",
" wv = np.random.normal(size=(m, n), scale=1e-1)\n",
" wu = np.random.normal(size=(n, g), scale=1e-1)\n",
" return [we, wb, wv, wu]\n",
"\n",
"\n",
"# make a graph\n",
"nodes = e\n",
"edges = x2e(x)\n",
"features = jnp.zeros(graph_feature_len)\n",
"\n",
"# eval\n",
"out = gnn_layer(\n",
" nodes,\n",
" edges,\n",
" features,\n",
" *init_weights(graph_feature_len, node_feature_len, msg_feature_len),\n",
")\n",
"print(\"input feautres\", features)\n",
"print(\"output features\", out[2])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great! Our model can update the graph features. Now we need to convert this into callable and loss. We'll stack two GNN layers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get weights for both layers\n",
"w1 = init_weights(graph_feature_len, node_feature_len, msg_feature_len)\n",
"w2 = init_weights(graph_feature_len, node_feature_len, msg_feature_len)\n",
"w3 = np.random.normal(size=(graph_feature_len))\n",
"b = 0.0\n",
"\n",
"\n",
"@jax.jit\n",
"def model(nodes, coords, w1, w2, w3, b):\n",
" f0 = jnp.zeros(graph_feature_len)\n",
" e0 = x2e(coords)\n",
" n0 = nodes\n",
" n1, e1, f1 = gnn_layer(n0, e0, f0, *w1)\n",
" n2, e2, f2 = gnn_layer(n1, e1, f1, *w2)\n",
" yhat = f2 @ w3 + b\n",
" return yhat\n",
"\n",
"\n",
"def loss(nodes, coords, y, w1, w2, w3, b):\n",
" return (model(nodes, coords, w1, w2, w3, b) - y) ** 2\n",
"\n",
"\n",
"loss_grad = jax.grad(loss, (3, 4, 5, 6))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{margin}\n",
"You could pad the molecules to all be the same shape. This is a common strategy. We will skip this though for simplicity. \n",
"```\n",
"\n",
"One small change we've made below is that we scale the learning rate for the GNN to be $ 1 / 10$ of the rate for the regression parameters. This is because the GNN parameters need to vary slower based on trial and error. "
]
},
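{
"cell_type": "markdown",
"metadata": {},
"source": [
"The padding strategy mentioned in the margin could look roughly like the sketch below. This is an illustration under stated assumptions, not code used in this chapter: `pad_molecule` is a hypothetical helper, `max_atoms=8` is an arbitrary cap chosen for the example, and the 5-atom record is made up.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def pad_molecule(nodes, coords, max_atoms):\n",
"    # hypothetical helper: zero-pad node features and coordinates\n",
"    # and return a boolean mask marking the real atoms\n",
"    n = nodes.shape[0]\n",
"    if n > max_atoms:\n",
"        raise ValueError('molecule has more atoms than max_atoms')\n",
"    pad = max_atoms - n\n",
"    padded_nodes = np.pad(nodes, ((0, pad), (0, 0)))\n",
"    padded_coords = np.pad(coords, ((0, pad), (0, 0)))\n",
"    mask = np.arange(max_atoms) < n\n",
"    return padded_nodes, padded_coords, mask\n",
"\n",
"\n",
"# made-up methane-like record: 5 atoms, 16 one-hot features\n",
"nodes = np.zeros((5, 16))\n",
"nodes[0, 5] = 1  # carbon, atomic number 6\n",
"nodes[1:, 0] = 1  # hydrogens, atomic number 1\n",
"coords = np.random.default_rng(0).normal(size=(5, 3))\n",
"\n",
"p_nodes, p_coords, mask = pad_molecule(nodes, coords, max_atoms=8)\n",
"print(p_nodes.shape, p_coords.shape, mask)\n",
"\n",
"# an edge is real only if both of its atoms are real\n",
"edge_mask = mask[:, np.newaxis] & mask[np.newaxis, :]\n",
"print(edge_mask.sum())\n",
"```\n",
"\n",
"Padded atoms sit at the origin, so their distances to real atoms are non-zero. A padded version of the model would therefore need to multiply edge messages and node sums by masks like these so that padding does not contribute to the prediction."
]
},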
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eta = 1e-3\n",
"val_loss = [0.0 for _ in range(epochs)]\n",
"for epoch in range(epochs):\n",
" for d in train_set:\n",
" (e, x), y_raw = convert_record(d)\n",
" y = transform_label(y_raw)\n",
" grad = loss_grad(e, x, y, w1, w2, w3, b)\n",
" # update regression weights\n",
" w3 -= eta * grad[2]\n",
" b -= eta * grad[3]\n",
" # update GNN weights\n",
" for i, w in [(0, w1), (1, w2)]:\n",
" for j in range(len(w)):\n",
" w[j] -= eta * grad[i][j] / 10\n",
" # compute validation loss\n",
" for v in valid_set:\n",
" (e, x), y_raw = convert_record(v)\n",
" y = transform_label(y_raw)\n",
" # convert SE to RMSE\n",
" val_loss[epoch] += loss(e, x, y, w1, w2, w3, b)\n",
" val_loss[epoch] = jnp.sqrt(val_loss[epoch] / 1000)\n",
" eta *= 0.9\n",
"plt.plot(baseline_val_loss, label=\"baseline\")\n",
"plt.plot(val_loss, label=\"GNN\")\n",
"plt.legend()\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Val Loss\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a large dataset and we're under training, but hopefully you get the principles of this process! Finally, we'll examine our parity plot. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ys = []\n",
"yhats = []\n",
"for v in valid_set:\n",
" (e, x), y = convert_record(v)\n",
" ys.append(y)\n",
" yhat_raw = model(e, x, w1, w2, w3, b)\n",
" yhats.append(transform_prediction(yhat_raw))\n",
"\n",
"\n",
"plt.plot(ys, ys, \"-\")\n",
"plt.plot(ys, yhats, \".\")\n",
"plt.xlabel(\"Energy\")\n",
"plt.ylabel(\"Predicted Energy\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The clusters are molecule types/sizes. You can see we're starting to get the correct trend within the clusters, but a lot of work needs to be done to move some of them. Additional learning required!"
]
},
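{
"cell_type": "markdown",
"metadata": {},
"source": [
"To compare against the kcal/mol errors quoted at the start of the chapter, the parity data can be summarized as a mean absolute error and converted from Hartree. This is a minimal sketch: `mae_kcal_mol` is a hypothetical helper, the conversion factor is the approximate value of 627.509 kcal/mol per Hartree, and the numbers are made up to stand in for `ys` and `yhats`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"HARTREE_TO_KCAL_MOL = 627.509  # approximate conversion factor\n",
"\n",
"\n",
"def mae_kcal_mol(y_true, y_pred):\n",
"    # hypothetical helper: mean absolute error, converted from Hartree\n",
"    y_true = np.asarray(y_true, dtype=float)\n",
"    y_pred = np.asarray(y_pred, dtype=float)\n",
"    return np.mean(np.abs(y_true - y_pred)) * HARTREE_TO_KCAL_MOL\n",
"\n",
"\n",
"# made-up enthalpies in Hartree, standing in for ys and yhats\n",
"example_true = [-40.475, -56.522, -76.401]\n",
"example_pred = [-40.470, -56.530, -76.400]\n",
"print(mae_kcal_mol(example_true, example_pred))\n",
"```\n",
"\n",
"In the notebook, calling `mae_kcal_mol(ys, yhats)` after the parity plot should give the validation error in the same units as the literature values."
]
},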
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cited References\n",
"\n",
"```{bibliography}\n",
":style: unsrtalpha\n",
":filter: docname in docnames\n",
"```"
]
}
],
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}