{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Kernel Learning\n",
"\n",
"Kernel learning is a way to transform features in either classification or regression problems. Recall that in regression, we have the following model equation:\n",
"\n",
"\\begin{equation}\n",
" \\hat{y} = \\vec{w}\\vec{x} + b\n",
"\\end{equation}\n",
"\n",
"where $\\vec{x}$ is our feature vector of dimension $D$. In kernel learning, we transform our feature vector from $D$ features to *distances to training data points* of dimension $N$, where $N$ is the number of training data points:\n",
"\n",
"\\begin{equation}\n",
" \\hat{y} = \\sum_i^N w_i \\left<\\vec{x}, \\vec{x}_i\\right>+ b\n",
"\\end{equation}\n",
"\n",
"where $\\left<\\cdot\\right>$ is the distance between two feature vectors and $\\vec{x}_i$ is the $i$th training data point. $\\vec{x}$ is the function argument whereas $\\vec{x}_i$ are known values. \n",
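"\n",
"To make the two equations concrete, here is a small NumPy sketch that evaluates both forms for a single new point. The data and weights are made up (not fitted), and the Euclidean distance is assumed as the kernel. Note that the primal form needs $D$ weights while the dual form needs $N$.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# made-up training data: N = 3 points with D = 2 features each\n",
"train_x = np.array([[0.0, 1.0], [1.0, 0.0], [2.0, 2.0]])\n",
"x = np.array([1.0, 1.0])  # new point to predict\n",
"b = 0.5\n",
"\n",
"# primal form: one weight per feature (D weights)\n",
"w_primal = np.array([0.2, -0.1])\n",
"yhat_primal = w_primal @ x + b\n",
"\n",
"\n",
"def kernel(x1, x2):\n",
"    # Euclidean distance, assumed as the kernel for this sketch\n",
"    return np.sqrt(np.sum((x1 - x2) ** 2))\n",
"\n",
"\n",
"# dual form: one weight per training point (N weights)\n",
"w_dual = np.array([0.3, 0.3, -0.2])\n",
"s = np.array([kernel(x, xi) for xi in train_x])  # distance to each training point\n",
"yhat_dual = s @ w_dual + b\n",
"print(w_primal.shape, w_dual.shape)  # (2,) (3,)\n",
"```\n",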
"\n",
"```{admonition} Audience & Objectives\n",
"This chapter builds on {doc}`regression` and {doc}`classification`. After completing this chapter, you should be able to \n",
"\n",
" * Distinguish between primal and dual form\n",
" * Choose when kernel learning could be beneficial\n",
" * Understand the principles of training curves and the connection between training size and feature number \n",
"```\n",
"\n",
"\n",
"```{margin}\n",
"Although we will use the word distance, the $\\left<\\cdot\\right>$ function is actually an inner product which is more flexible than a distance.\n",
"```\n",
"\n",
"One of the consequences of this transformation is that our training weight vector, $\\vec{w}$, no longer depends on the number of features. **Instead the number of weights depends on the number of training data points.** This property is what makes kernel learning useful. You might have few training data points but large feature vectors ($N < D$). By using the kernel formulation, you'll reduce the number of weights. It might also be that your feature space is hard to model with a linear equation, but when you view it as distances from training data it becomes linear (often $N > D$). Finally, it could be that your feature vector is infinite dimensional (i.e., it is a function, not a vector) or for some reason you cannot compute it. In kernel learning, you only need to define your *kernel function* $\\left<\\cdot\\right>$ and never explicitly work with the feature vector. \n",
"\n",
"The distance function is called a *kernel function* $\\left<\\cdot\\right>$. A kernel function is a binary function (takes two arguments) that outputs a scalar and has the following properties:\n",
"\n",
"1. Positive: $\\left<x, x'\\right> \\geq 0$\n",
"2. Symmetric: $\\left<x, x'\\right> = \\left<x', x\\right>$\n",
"3. Point-separating: $\\left<x, x'\\right> = 0$ if and only if $x = x'$\n",
"\n",
"The classic kernel function example is the $L_2$ norm (Euclidean distance): $\\left<\\vec{x}, \\vec{x}'\\right>=\\sqrt{\\sum^D_i (x_i - x_i^{'})^2}$. Some of the most interesting applications of kernel learning, though, are when $x$ is not a vector but a function or some other structured object. \n",
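"\n",
"These properties can be spot-checked numerically. The sketch below does so for the Euclidean distance on a few random vectors. Passing on random samples is only a sanity check, not a proof.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def l2_kernel(x1, x2):\n",
"    # Euclidean distance between two feature vectors\n",
"    return np.sqrt(np.sum((x1 - x2) ** 2))\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"points = rng.normal(size=(5, 3))  # 5 random vectors with D = 3\n",
"\n",
"for a in points:\n",
"    assert l2_kernel(a, a) == 0.0  # zero distance to itself\n",
"    for c in points:\n",
"        k = l2_kernel(a, c)\n",
"        assert k >= 0.0  # positive\n",
"        assert np.isclose(k, l2_kernel(c, a))  # symmetric\n",
"        if not np.array_equal(a, c):\n",
"            assert k > 0.0  # point-separating: distinct points have nonzero distance\n",
"print('all property checks passed')\n",
"```\n",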
"\n",
"```{admonition} Primal & Dual Form\n",
":class: tip\n",
"**Dual Form** is what some call our model equation when it uses the kernels: $\\hat{y} = \\sum_i^N w_i \\left<\\vec{x}, \\vec{x}_i\\right>+ b$. To distinguish from the dual form, you can also refer to the usual model equation as the **Primal Form** $\\hat{y} = \\vec{w}\\vec{x} + b$. It also sounds cool. \n",
"```\n",
"\n",
"Kernel learning is a widely used approach for learning potential energy functions and force fields in molecular modeling {cite}`scherer2020kernel,rupp2012fast`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Solubility Example\n",
"\n",
"Let's revisit the solubility AqSolDB {cite}`Sorkun2019` dataset from {doc}`regression`. Recall that it has about 10,000 unique compounds with measured solubility in water (label) and 17 molecular descriptors (features)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running This Notebook\n",
"\n",
"Use the launch icon at the top of this page to open it as an interactive Google Colab. See details below on installing packages.\n",
"\n",
"````{tip} My title\n",
":class: dropdown\n",
"To install packages, execute this code in a new cell. \n",
"\n",
"```\n",
"!pip install dmol-book\n",
"```\n",
"\n",
"If you run into installation problems, you can get the latest working versions of packages used in [this book here](https://github.com/whitead/dmol-book/blob/main/package/setup.py).\n",
"\n",
"````"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As usual, the code below sets up our imports.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import jax.numpy as jnp\n",
"import jax\n",
"import dmol"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# soldata = pd.read_csv('https://dataverse.harvard.edu/api/access/datafile/3407241?format=original&gbrecs=true')\n",
"soldata = pd.read_csv(\n",
" \"https://github.com/whitead/dmol-book/raw/main/data/curated-solubility-dataset.csv\"\n",
")\n",
"features_start_at = list(soldata.columns).index(\"MolWt\")\n",
"feature_names = soldata.columns[features_start_at:]\n",
"np.random.seed(0)\n",
"\n",
"# Split into train(80) and test(20)\n",
"N = len(soldata)\n",
"split = int(N * 0.8)\n",
"shuffled = soldata.sample(N, replace=False)\n",
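"# copy the slices so the in-place standardization below works on independent DataFrames\n",
"train = shuffled[:split].copy()\n",
"test = shuffled[split:].copy()\n",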
"\n",
"# standardize the features using only train\n",
"test[feature_names] -= train[feature_names].mean()\n",
"test[feature_names] /= train[feature_names].std()\n",
"train[feature_names] -= train[feature_names].mean()\n",
"train[feature_names] /= train[feature_names].std()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Kernel Definition\n",
"\n",
"We'll start by creating our kernel function. *Our kernel function does not need to be differentiable*. In contrast to the functions we see in deep learning, we can use sophisticated and non-differentiable functions in kernel learning. For example, you could use a two-component molecular dynamics simulation to compute the kernel between two molecules. We'll still implement our kernel functions in JAX for this example because it is efficient and consistent. Remember that our kernel should take two feature vectors and return a scalar. In our example, we will simply use the $L_2$ norm with one small change: we divide by the dimension (taking the mean rather than the sum inside the square root). This makes the magnitude of our kernel output independent of the number of dimensions of $x$. Other options for kernels for dense vectors are $1 - $ cosine similarity (normalized dot product) or [Mahalanobis distance](https://en.wikipedia.org/wiki/Mahalanobis_distance).\n",
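"\n",
"As a sketch of the alternatives just mentioned, here are JAX versions of a $1 - $ cosine similarity kernel and a Mahalanobis distance kernel. The covariance matrix is an assumption of this sketch; one natural choice would be the covariance of the training features. Note that $1 - $ cosine similarity is zero for any two parallel vectors, not only identical ones, so it is not strictly point-separating.\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def cosine_kernel(x1, x2):\n",
"    # 1 - cosine similarity: 0 for parallel vectors, up to 2 for opposite ones\n",
"    cos = jnp.dot(x1, x2) / (jnp.linalg.norm(x1) * jnp.linalg.norm(x2))\n",
"    return 1.0 - cos\n",
"\n",
"\n",
"def make_mahalanobis_kernel(cov):\n",
"    # invert the (assumed invertible) covariance matrix once\n",
"    cov_inv = jnp.linalg.inv(cov)\n",
"\n",
"    def mahalanobis_kernel(x1, x2):\n",
"        d = x1 - x2\n",
"        return jnp.sqrt(d @ cov_inv @ d)\n",
"\n",
"    return mahalanobis_kernel\n",
"\n",
"\n",
"# with an identity covariance, the Mahalanobis distance is the Euclidean distance\n",
"mk = make_mahalanobis_kernel(jnp.eye(3))\n",
"a = jnp.array([1.0, 2.0, 2.0])\n",
"print(mk(a, jnp.zeros(3)))  # 3.0\n",
"print(cosine_kernel(a, 2 * a))  # approximately 0\n",
"```\n",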
"\n",
"Choosing a kernel function is an open question for molecules. There are a variety of ways to work with the molecular structure directly. You can use {doc}`../dl/gnn` to convert molecules into vectors and use any of the vector kernels above. You can also work with fingerprints (vectors of bits) from cheminformatics libraries like RDKit and compare such vectors with [Tanimoto similarity](https://en.wikipedia.org/wiki/Jaccard_index) {cite}`rankovic_griffiths_moss_schwaller_2022,yang2022classifying`."
]
},
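{
"cell_type": "markdown",
"metadata": {},
"source": [
"A Tanimoto kernel on fingerprint bit vectors can be sketched with plain NumPy. The 8-bit fingerprints below are made up; in practice they would come from a fingerprint generator such as those in RDKit. Since Tanimoto is a *similarity* (1 for identical fingerprints), the sketch also defines $1 - $ Tanimoto as a distance-like kernel that is 0 for identical fingerprints.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def tanimoto_similarity(fp1, fp2):\n",
"    # fp1, fp2: boolean arrays of fingerprint bits\n",
"    both = np.sum(fp1 & fp2)  # bits set in both fingerprints\n",
"    either = np.sum(fp1 | fp2)  # bits set in at least one fingerprint\n",
"    if either == 0:\n",
"        return 1.0  # treat two empty fingerprints as identical\n",
"    return both / either\n",
"\n",
"\n",
"def tanimoto_kernel(fp1, fp2):\n",
"    # distance-like version: 0 for identical fingerprints\n",
"    return 1.0 - tanimoto_similarity(fp1, fp2)\n",
"\n",
"\n",
"# made-up 8-bit fingerprints\n",
"fp_a = np.array([1, 1, 0, 1, 0, 0, 1, 0], dtype=bool)\n",
"fp_b = np.array([1, 0, 0, 1, 0, 1, 1, 0], dtype=bool)\n",
"print(tanimoto_similarity(fp_a, fp_b))  # 3 shared bits / 5 set bits = 0.6\n",
"```"
]
},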
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def kernel(x1, x2):\n",
"    return jnp.sqrt(jnp.mean((x1 - x2) ** 2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Definition\n",
"\n",
"\n",
"```{margin}\n",
"Since we're doing linear regression, you can compute the fit coefficients by just doing matrix algebra. We'll still approach this problem with gradient descent even though there are more efficient model-specific procedures.\n",
"```\n",
"\n",
"Now we define our regression model equation in the *dual form*. Remember that our function must always take the training data as input to compute the distances to a new given point. We will use the vectorization feature of JAX (`jax.vmap`) to compute all the kernels simultaneously for our new point.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def model(x, train_x, w, b):\n",
"    # make vectorized version of kernel\n",
"    vkernel = jax.vmap(kernel, in_axes=(None, 0), out_axes=0)\n",
"    # compute kernel with all training data\n",
"    s = vkernel(x, train_x)\n",
"    # dual form\n",
"    yhat = jnp.dot(s, w) + b\n",
"    return yhat\n",
"\n",
"\n",
"# make batched version that can handle multiple xs\n",
"batch_model = jax.vmap(model, in_axes=(0, None, None, None), out_axes=0)"
]
},
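{
"cell_type": "markdown",
"metadata": {},
"source": [
"Evaluating `batch_model` on several points is equivalent to building a kernel matrix $K$, whose rows are the new points and whose columns are the training points, and then computing $K\\vec{w} + b$. The sketch below builds $K$ with a nested `vmap` on small random arrays. It is an illustration only and is not used later in this chapter.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"\n",
"def kernel(x1, x2):\n",
"    return jnp.sqrt(jnp.mean((x1 - x2) ** 2))\n",
"\n",
"\n",
"# inner vmap: one point vs. all training points; outer vmap: all new points\n",
"kernel_matrix = jax.vmap(jax.vmap(kernel, in_axes=(None, 0)), in_axes=(0, None))\n",
"\n",
"rng = np.random.default_rng(0)\n",
"train_x = rng.normal(size=(6, 4))  # 6 made-up training points, 4 features\n",
"new_x = rng.normal(size=(3, 4))  # 3 made-up points to predict\n",
"w = rng.normal(size=6)\n",
"b = 0.1\n",
"\n",
"K = kernel_matrix(new_x, train_x)\n",
"print(K.shape)  # (3, 6)\n",
"yhat = K @ w + b  # dual form for all new points at once\n",
"```"
]
},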
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training\n",
"\n",
"We now have trainable weights and a model equation. To begin training, we need to define a loss function and compute its gradient. We'll use mean squared error as usual for the loss function. We can use regularization, as we saw previously, but will skip it for now. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def loss(w, b, train_x, x, y):\n",
"    return jnp.mean((batch_model(x, train_x, w, b) - y) ** 2)\n",
"\n",
"\n",
"loss_grad = jax.grad(loss, (0, 1))\n",
"\n",
"# convert from pandas dataframe to numpy arrays\n",
"train_x = train[feature_names].values\n",
"train_y = train[\"Solubility\"].values\n",
"test_x = test[feature_names].values\n",
"test_y = test[\"Solubility\"].values"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We've defined our loss and split our data into training/testing. Now we will set up the training parameters, including breaking up our training data into batches. An **epoch** is one iteration through the whole dataset. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eta = 1e-5\n",
"batch_size = 32\n",
"epochs = 10\n",
"\n",
"\n",
"# reshape into batches\n",
"batch_num = train_x.shape[0] // batch_size\n",
"# first truncate data so it's a whole number of batches\n",
"trunc = batch_num * batch_size\n",
"train_x = train_x[:trunc]\n",
"train_y = train_y[:trunc]\n",
"# split into batches\n",
"x_batches = train_x.reshape(-1, batch_size, train_x.shape[-1])\n",
"y_batches = train_y.reshape(-1, batch_size)\n",
"\n",
"\n",
"# make trainable parameters\n",
"# w = np.random.normal(scale = 1e-30, size=train_x.shape[0])\n",
"w = np.zeros(train_x.shape[0])\n",
"b = np.mean(train_y) # just set to mean, since it's a good first guess"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You may notice our learning rate, $\\eta$, is unusually low at $10^{-5}$. This is because each of our roughly 8,000 training data points contributes to the final $\\hat{y}$, so a large training step can create very big changes to $\\hat{y}$. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss_progress = []\n",
"test_loss_progress = []\n",
"\n",
"for _ in range(epochs):\n",
"    # go in random order (each batch exactly once per epoch)\n",
"    for i in np.random.permutation(batch_num):\n",
"        # update step\n",
"        x = x_batches[i]\n",
"        y = y_batches[i]\n",
"        loss_progress.append(loss(w, b, train_x, x, y))\n",
"        test_loss_progress.append(loss(w, b, train_x, test_x, test_y))\n",
"        grad = loss_grad(w, b, train_x, x, y)\n",
"        w -= eta * grad[0]\n",
"        b -= eta * grad[1]\n",
"plt.plot(loss_progress, label=\"Training Loss\")\n",
"plt.plot(test_loss_progress, label=\"Testing Loss\")\n",
"\n",
"plt.xlabel(\"Step\")\n",
"plt.yscale(\"log\")\n",
"plt.legend()\n",
"plt.ylabel(\"Loss\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One small change from previous training loops is that we randomized our batches in the `for` loop."
]
},
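{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are two common ways to randomize the batch order: shuffling the batch indices, which visits every batch exactly once per epoch, or sampling indices with replacement. The small sketch below contrasts `np.random.permutation` with `np.random.randint`. Note that the upper bound of `randint` is exclusive.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"np.random.seed(0)\n",
"batch_num = 5\n",
"\n",
"# shuffling: every batch index appears exactly once per epoch\n",
"order = np.random.permutation(batch_num)\n",
"print(np.sort(order))  # [0 1 2 3 4]\n",
"\n",
"# randint samples with replacement and its upper bound is exclusive,\n",
"# so randint(0, batch_num - 1) never picks the last batch (index 4)\n",
"samples = np.random.randint(0, batch_num - 1, size=1000)\n",
"print(samples.max())  # 3\n",
"```"
]
},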
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"yhat = batch_model(test_x, train_x, w, b)\n",
"plt.plot(test_y, test_y)\n",
"plt.plot(test_y, yhat, \".\")\n",
"plt.text(min(test_y) + 1, max(test_y) - 2, f\"correlation = {np.corrcoef(test_y, yhat)[0,1]:.3f}\")\n",
"plt.text(min(test_y) + 1, max(test_y) - 3, f\"RMSE = {np.sqrt(np.mean((test_y - yhat)**2)):.3f}\")\n",
"plt.title(\"Testing Data\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see our results show underfitting. As usual, I want to make this code execute fast so I have not done many epochs. You can increase the epoch number and watch the loss and correlation improve over time. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Regularization\n",
"\n",
"\n",
"You'll notice that our trainable parameter number, by design, is equal to the number of training data points. If we were to compute the fit coefficients directly with a pseudo-inverse, this could cause problems: the $N \\times N$ kernel matrix can be ill-conditioned or singular, which makes the solution numerically unstable. Thus, most people add a regularization term, both to make matrix algebra solutions tractable and because it seems wise with such a large number of trainable parameters. Just like we saw in linear regression, L1 regularization gives **Lasso Regression** and L2 regularization gives **Ridge Regression**, which in the kernel setting is called **Kernel Ridge Regression**. Remember that L1 zeros-out specific parameters, which was useful for interpreting the importance of features in linear regression. However, in the kernel setting this would only zero-out specific training data points and thus usually provides no real insight (but see {doc}`../dl/xai`). Kernel ridge regression is thus more popular in the kernel setting."
]
},
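{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of the matrix-algebra route, fix the bias $b$ (set to the mean label here, an assumption of this sketch) and minimize $\\lVert K\\vec{w} - (\\vec{y} - b) \\rVert^2 + \\lambda \\lVert \\vec{w} \\rVert^2$, where $K$ is the kernel matrix between training points. The closed-form solution is $\\vec{w} = (K^\\top K + \\lambda I)^{-1} K^\\top (\\vec{y} - b)$. For $\\lambda > 0$ the matrix $K^\\top K + \\lambda I$ is positive definite, so it can always be solved. The standard kernel ridge regression formulation instead penalizes $\\vec{w}^\\top K \\vec{w}$ (assuming a positive semi-definite $K$), which leads to $\\vec{w} = (K + \\lambda I)^{-1}(\\vec{y} - b)$. The 1D data below are made up.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def fit_kernel_ridge(K, y, lam):\n",
"    # closed-form minimizer of ||K w - (y - b)||^2 + lam * ||w||^2 with b = mean(y)\n",
"    b = np.mean(y)\n",
"    n = K.shape[1]\n",
"    # K^T K + lam I is positive definite for lam > 0, so this solve is well posed\n",
"    A = K.T @ K + lam * np.eye(n)\n",
"    w = np.linalg.solve(A, K.T @ (y - b))\n",
"    return w, b\n",
"\n",
"\n",
"# made-up 1D data; K[i, j] = |x_i - x_j| (Euclidean distance kernel)\n",
"rng = np.random.default_rng(0)\n",
"x = rng.uniform(-2, 2, size=20)\n",
"y = np.sin(x)\n",
"K = np.abs(x[:, None] - x[None, :])\n",
"w, b = fit_kernel_ridge(K, y, lam=1e-3)\n",
"print(np.sqrt(np.mean((K @ w + b - y) ** 2)))  # training RMSE\n",
"```"
]
},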
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training Curves\n",
"\n",
"The bias-variance trade-off from {doc}`../ml/regression` showed how increasing model complexity could reduce model bias (more expressive and able to fit data better) at the cost of increased model variance (more sensitive to training data choice and amount). The model complexity was controlled by adjusting feature number. In kernel learning, we cannot control feature number because it is always equal to the number of training data points. Thus, we can only control hyperparameters like the choice of kernel, regularization, learning rate, etc. To assess these effects, we usually do not report the test loss alone, because it is strongly connected to the amount of training data you have: in kernel learning, more training data means more trainable parameters, a more expressive model, and thus usually lower loss. So it is common, especially in kernel learning, to show how the test loss changes as a function of training data amount. These plots are usually drawn on log-log axes because both quantities span several orders of magnitude. They are called **training curves** (or sometimes **learning curves**). Training curves can be applied broadly in ML and deep learning, but you'll most often see them in kernel learning.\n",
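"\n",
"Training curves often look roughly like a power law, $\\textrm{loss} \\approx a N^{-\\alpha}$, which is a straight line with slope $-\\alpha$ on log-log axes. The sketch below fits that slope with `np.polyfit` on synthetic, made-up losses. It only illustrates how to read a training curve and is not a result from this chapter.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# synthetic training curve following loss = 2 * N^(-1/2); purely illustrative numbers\n",
"n_train = np.array([32, 256, 512, 1024, 2048, 5120])\n",
"test_loss = 2.0 * n_train ** -0.5\n",
"\n",
"# a straight-line fit in log-log space recovers the exponent as the slope\n",
"slope, intercept = np.polyfit(np.log(n_train), np.log(test_loss), 1)\n",
"print(round(slope, 3))  # -0.5\n",
"```\n",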
"\n",
"Let's revisit our solubility model and compare L1 and L2 regularization with a training curve. Note that this code is very slow because we must fit $M$ models, where $M$ is the number of points we want on our training curve. To keep things efficient for this textbook, I'll use only a few points on the curve.\n",
"\n",
"First, we'll turn our training procedure into a function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def fit_model(loss, npoints, eta=1e-6, batch_size=16, epochs=25):\n",
"    sample_idx = np.random.choice(\n",
"        np.arange(train_x.shape[0]), replace=False, size=npoints\n",
"    )\n",
"    sample_x = train_x[sample_idx, :]\n",
"    sample_y = train_y[sample_idx]\n",
"\n",
"    # reshape into batches\n",
"    batch_num = npoints // batch_size\n",
"    # first truncate data so it's a whole number of batches\n",
"    trunc = batch_num * batch_size\n",
"    sample_x = sample_x[:trunc]\n",
"    sample_y = sample_y[:trunc]\n",
"    # split into batches\n",
"    x_batches = sample_x.reshape(-1, batch_size, sample_x.shape[-1])\n",
"    y_batches = sample_y.reshape(-1, batch_size)\n",
"\n",
"    # get loss grad\n",
"    loss_grad = jax.grad(loss, (0, 1))\n",
"\n",
"    # make trainable parameters\n",
"    # w = np.random.normal(scale = 1e-30, size=sample_x.shape[0])\n",
"    w = np.zeros(sample_x.shape[0])\n",
"    b = np.mean(sample_y)  # just set to mean, since it's a good first guess\n",
"    for _ in range(epochs):\n",
"        # go in random order (each batch exactly once per epoch)\n",
"        for i in np.random.permutation(batch_num):\n",
"            # update step\n",
"            x = x_batches[i]\n",
"            y = y_batches[i]\n",
"            grad = loss_grad(w, b, sample_x, x, y)\n",
"            w -= eta * grad[0]\n",
"            b -= eta * grad[1]\n",
"    # report plain test MSE, without any regularization penalty, so losses are comparable\n",
"    return jnp.mean((batch_model(test_x, sample_x, w, b) - test_y) ** 2)\n",
"\n",
"\n",
"# test it out\n",
"fit_model(loss, 256)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we'll create L1 and L2 versions of our loss. We must choose the *strength* of the regularization. Since our weights are less than 1 in magnitude, their squares are even smaller, so at the same strength the L2 penalty would be much weaker than the L1 penalty. I'll therefore choose much stronger regularization for L2. These are hyperparameters though, and you can adjust them to improve your fit. "
]
},
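{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick numeric illustration of this reasoning: for weights smaller than 1 in magnitude, the sum of squares is far smaller than the sum of absolute values, so L2 needs a larger strength to have a comparable effect. The weights below are made-up values, not the ones trained in this chapter.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"w = np.full(1000, 0.01)  # made-up small weights, not the trained ones\n",
"l1_penalty = np.sum(np.abs(w))  # 1000 * 0.01 = 10\n",
"l2_penalty = np.sum(w**2)  # 1000 * 0.0001 = 0.1\n",
"print(l1_penalty / l2_penalty)  # approximately 100\n",
"```"
]
},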
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def loss_l1(w, b, train_x, x, y):\n",
" return jnp.mean((batch_model(x, train_x, w, b) - y) ** 2) + 1e-2 * jnp.sum(\n",
" jnp.abs(w)\n",
" )\n",
"\n",
"\n",
"@jax.jit\n",
"def loss_l2(w, b, train_x, x, y):\n",
" return jnp.mean((batch_model(x, train_x, w, b) - y) ** 2) + 1e2 * jnp.sum(w**2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now we can generate the points necessary for our curves!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nvalues = [32, 256, 512, 1024, 2048, 1024 * 5]\n",
"\n",
"nor_losses = [fit_model(loss, n) for n in nvalues]\n",
"l1_losses = [fit_model(loss_l1, n) for n in nvalues]\n",
"l2_losses = [fit_model(loss_l2, n) for n in nvalues]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.plot(nvalues, nor_losses, label=\"No Regularization\")\n",
"plt.plot(nvalues, l1_losses, label=\"L1 Regularization\")\n",
"plt.plot(nvalues, l2_losses, label=\"L2 Regularization\")\n",
"plt.legend()\n",
"plt.xlabel(\"Training Data Amount\")\n",
"plt.ylabel(\"Test Loss\")\n",
"plt.gca().set_yscale(\"log\")\n",
"plt.gca().set_xscale(\"log\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we see our training curves showing the different approaches. Regularization has some effect on the final test loss. It is hard to say if L1 and L2 are simply worse, or if I need to tune the regularization strength more. Nevertheless, this plot shows you how kernel learning methods are typically evaluated. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercises\n",
"\n",
"1. Compute the analytical gradient for the dual form regression equation and use it to describe why the kernel function does not need to be differentiable.\n",
"2. Is it faster or slower to do training with kernel learning? Explain.\n",
"3. Is it faster or slower to do inference with kernel learning? Explain.\n",
"4. How can we modify Equation 4.2 to do classification? \n",
"5. Do the weight values give relative importance of training examples regardless of kernel?\n",
"6. Create a training curve from the above example showing 5 different L1 regularization strengths. Why might regularization not matter here?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chapter Summary\n",
"\n",
"* In this section we introduced kernel learning, which is a method to transform features into distances between samples.\n",
"* A kernel function takes two arguments and outputs a scalar. A kernel function must have three properties: positive, symmetric, and point-separating.\n",
"* The distance function (inner product) is a kernel function.\n",
"* Kernel functions do not need to be differentiable.\n",
"* Kernel learning is appropriate when you would like to use features that can only be specified as a binary kernel function.\n",
"* The number of trainable parameters in a kernel model is proportional to the number of training points, not the dimension of the features."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cited References\n",
"\n",
"```{bibliography}\n",
":style: unsrtalpha\n",
":filter: docname in docnames\n",
"```"
]
}
],
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}