{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "1x7mjzUHDtvx"
},
"source": [
"# Classification\n",
"\n",
"Classification is supervised learning with categorical labels. You are given labeled data consisting of features and labels $\\{\\vec{x}_i, \\vec{y}_i\\}$, where $\\vec{y}_i$ is a vector of binary values indicating class membership. An example of $\\vec{y}_i$ that indicates membership of the classes \"soluble in THF\", \"soluble in water\", and \"soluble in chloroform\" might be:\n",
"\n",
"| THF | water | chloroform |\n",
"|:------|:-----:| ----------:|\n",
"| 1 | 0 | 1 |\n",
"\n",
"where we've indicated that the molecule is soluble in THF and chloroform but not in water. As a vector, it is $\\vec{y} = (1, 0, 1)$. This is the most general format of classification and is called **multi-label** classification, because a single data point can carry several labels at once (here, THF soluble and chloroform soluble). If we restrict each data point to belong to exactly one class, we have **multi-class** classification. Visible color is an example: a molecule can be red or green or orange, but not several colors at once. Finally, if there is only a single class and each data point either belongs to it or does not, we have **binary** classification, the most common classification type. In multi-label or multi-class classification, $\\vec{y}$ is a vector of length $K$, where $K$ is the number of classes. In binary classification, the label is a single value of 1 or 0, where 1 means the data point is a member of the class. You can view this as there being two classes: a **positive** class ($y = 1$) and a **negative** class ($y = 0$). For example, you could be predicting if a molecule will kill cells. If the molecule is in the positive class, it kills cells. If it is in the negative class, it is inert and does not kill cells. Depending on your choice of model type, the predicted labels ($\\hat{\\vec{y}}$) could be probabilities rather than hard 0/1 values.\n",
"\n",
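"A small sketch of how the three label formats can be stored as NumPy arrays. The molecules, class names, and 0/1 values are made up for illustration and are not from this chapter's dataset.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# multi-label: a row can contain several 1s\n",
"# columns: soluble in THF, soluble in water, soluble in chloroform\n",
"y_multilabel = np.array([[1, 0, 1], [0, 1, 0]])\n",
"\n",
"# multi-class: exactly one 1 per row (one-hot), e.g. red, green, orange\n",
"y_multiclass = np.array([[0, 1, 0], [1, 0, 0]])\n",
"\n",
"# binary: a single 0/1 value per molecule (1 = positive class)\n",
"y_binary = np.array([1, 0])\n",
"\n",
"print(y_multilabel.shape, y_multiclass.shape, y_binary.shape)\n",
"# one-hot rows always sum to 1; multi-label rows need not\n",
"print(y_multiclass.sum(axis=1), y_multilabel.sum(axis=1))\n",
"```\n",
"\n",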
"```{admonition} Audience & Objectives\n",
"This chapter builds on {doc}`regression` and a basic knowledge of probability theory, specifically random variables and normalization. The metrics section below also touches on calibration (empirical agreement of the model's predicted distribution with the true distribution). You can read [my notes](https://raw.githubusercontent.com/whitead/numerical_stats/master/unit_2/lectures/lecture_3.pdf) or any introductory probability text to learn these topics. After completing this chapter, you should be able to: \n",
"\n",
" * Distinguish between types of classification\n",
" * Set-up and train a classifier with a cross-entropy loss function\n",
" * Characterize classifier performance\n",
" * Identify and address class imbalance\n",
"```\n",
"\n",
"\n",
"\n",
"```{margin}\n",
"Note that in multi-class and binary classification $\\sum \\hat{\\vec{y}} = 1$, but \n",
"in multi-label classification this is not the case. Multi-label classification is like doing $K$ instances of binary classification. \n",
"```\n",
"\n",
"The goal of classification is to find a function that describes the relationship between features and class, $\\hat{f}(\\vec{x}) = \\hat{y}$. We'll see that this problem can be converted to regression by using probability or *distance from a decision boundary*. This means much of what we learned previously can be applied to classification.\n",
"\n",
"The classic application of classification to structure-activity relationships is in drug discovery, where we want to predict if a molecule will be active (positive class) as a function of its structure. This dates back to the 1970s. Classification is now widely used in materials science and chemistry, and many molecular design problems can be formulated as classification. For example, you can use it to design new organic photovoltaic materials {cite}`sun2019machine` or antimicrobial peptides {cite}`barrett2018classifying`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HN3F0TXYDtv2"
},
"source": [
"## Data\n",
"\n",
"The dataset for this lecture was prepared by the MoleculeNet group {cite}`wu2018moleculenet`. It is a collection of molecules that succeeded or failed in clinical trials. The development of a new drug can cost well over $1 billion, so any way to predict if a molecule will fail during clinical trials is highly valuable. Molecules often fail clinical trials because of safety problems, so even though some of these drugs failed because they were not effective, there may be something common to the failed ones that we can learn.\n",
"\n",
"The labels will be the `FDA_APPROVED` column, which is 1 or 0 indicating FDA approval status. This is an example of binary classification."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "y78vGacyDtv2"
},
"source": [
"## Running This Notebook\n",
"\n",
"\n",
"Click the launch button above to open this page as an interactive Google Colab notebook. See details below on installing packages.\n",
"\n",
"````{tip} Installing packages\n",
":class: dropdown\n",
"To install packages, execute this code in a new cell. \n",
"\n",
"```\n",
"!pip install dmol-book\n",
"```\n",
"\n",
"If you run into installation problems, you can get the latest working versions of packages used in [this book here](https://github.com/whitead/dmol-book/blob/master/package/requirements.txt)\n",
"\n",
"````"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4YbkujizDtv3"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import rdkit, rdkit.Chem, rdkit.Chem.Draw\n",
"import numpy as np\n",
"import jax.numpy as jnp\n",
"import mordred, mordred.descriptors\n",
"import jax\n",
"import dmol"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ao1j3KExDtv4"
},
"source": [
"Now we load the data. The file is a gzip-compressed CSV hosted on GitHub, and pandas can read it directly from the URL."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xcgFOCizDtv5"
},
"outputs": [],
"source": [
"# pandas infers gzip compression from the \".gz\" extension, so it can read\n",
"# the compressed CSV directly from a URL. The original MoleculeNet copy is at\n",
"# https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz\n",
"toxdata = pd.read_csv(\n",
"    \"https://github.com/whitead/dmol-book/raw/master/data/clintox.csv.gz\"\n",
")\n",
"toxdata.head()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nBLU1hGuDtv5"
},
"source": [
"## Molecular Descriptors\n",
"\n",
"This time, our data does not come with pre-computed descriptors. We only have the SMILES string, which is a way of writing a molecule using letters, numbers, and symbols (a string). We can use RDKit to convert the SMILES string into a molecule, and then we can use a package called Mordred {cite}`moriwaki2018mordred` to compute a set of descriptors for each molecule. This package will compute around 1500 descriptors for each molecule. \n",
"\n",
"We'll start by converting our molecules into rdkit objects and building a calculator to compute the descriptors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "txmFSMPPDtv6"
},
"outputs": [],
"source": [
"# make object that can compute descriptors\n",
"calc = mordred.Calculator(mordred.descriptors, ignore_3D=True)\n",
"# convert each SMILES string to an RDKit molecule (None if parsing fails)\n",
"molecules = [rdkit.Chem.MolFromSmiles(smi) for smi in toxdata.smiles]\n",
"\n",
"# view one molecule to make sure things look good.\n",
"molecules[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "i1LWxpcYDtv6"
},
"source": [
"Some of our molecules failed to be converted. We'll have to remove them. We need to remember which ones were deleted too, since we need to remove the failed molecules from the labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0dOVHw5jDtv6"
},
"outputs": [],
"source": [
"# invalid SMILES give None instead of a molecule, so keep a boolean mask\n",
"# of valid entries (to filter the labels later) and the valid molecules\n",
"valid_mol_idx = [m is not None for m in molecules]\n",
"valid_mols = [m for m in molecules if m is not None]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MhpCozWeDtv6",
"tags": [
"remove-output"
]
},
"outputs": [],
"source": [
"features = calc.pandas(valid_mols)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "n1077JZQDtv7"
},
"source": [
"Now we just need to stitch everything back together so that our labels line up with the remaining molecules, and standardize our features."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yJQtBlxJDtv7"
},
"outputs": [],
"source": [
"labels = toxdata[valid_mol_idx].FDA_APPROVED\n",
"features -= features.mean()\n",
"features /= features.std()\n",
"\n",
"# we have some nans in features, likely because std was 0\n",
"features.dropna(inplace=True, axis=1)\n",
"\n",
"print(f\"We have {len(features.columns)} features per molecule\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5QpVHo5TDtv7"
},
"source": [
"## Classification Models\n",
"\n",
"### Linear Perceptron\n",
"\n",
"We are able to predict single values from regression. How can we go from a predicted value to a class? The simplest answer is to use the same linear regression equation from chapter {doc}`../ml/regression` $\\hat{f}(\\vec{x})$ and assign $\\hat{y} = 1$ when $\\hat{f}(\\vec{x}) > 0$, $\\hat{y} = 0$ otherwise. Our model equation is then:\n",
"\n",
"\\begin{equation}\n",
"\\hat{f}(\\vec{x}) = \\begin{cases} \n",
" 1 & \\vec{w}\\cdot \\vec{x} + b > 0 \\\\\n",
" 0 & \\textrm{otherwise}\\\\\n",
" \\end{cases}\n",
"\\end{equation}\n",
"\n",
"The term $\\vec{w}\\cdot \\vec{x} + b$ is proportional to the signed **distance from the decision boundary**, where the decision boundary is at $\\vec{w}\\cdot \\vec{x} + b = 0$ (dividing by $\\lVert\\vec{w}\\rVert$ gives the Euclidean distance). If it is large and positive, we are far away from classifying the point as $0$. If it is close to zero, a small change in $\\vec{x}$ could flip the prediction. It can be loosely thought of as \"confidence\"."
]
},
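{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick numeric illustration, here is a sketch with a hypothetical two-feature model (the weights, bias, and points are made up). It computes $\\vec{w}\\cdot \\vec{x} + b$, the Euclidean distance to the boundary, and the resulting hard prediction.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# hypothetical weights and bias; the boundary is the line w . x + b = 0\n",
"w = np.array([2.0, -1.0])\n",
"b = 0.5\n",
"x = np.array([[1.0, 1.0], [0.0, 3.0], [0.25, 0.75]])\n",
"\n",
"v = x @ w + b  # proportional to the signed distance from the boundary\n",
"dist = v / np.linalg.norm(w)  # signed Euclidean distance\n",
"yhat = (v > 0).astype(int)  # perceptron prediction\n",
"print(v, dist)\n",
"print(yhat)  # [1 0 1]\n",
"```\n",
"\n",
"The third point has $\\vec{w}\\cdot \\vec{x} + b = 0.25$, so it sits close to the boundary and a small change in its features could flip its class."
]
},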
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M14yp0E_Dtv8"
},
"outputs": [],
"source": [
"def perceptron(x, w, b):\n",
" v = jnp.dot(x, w) + b\n",
"    # pass the branches positionally; recent JAX versions reject x=/y= keywords\n",
"    y = jnp.where(v > 0, jnp.ones_like(v), jnp.zeros_like(v))\n",
" return y"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AsFRyvBlDtv8"
},
"source": [
"This particular model is called a **perceptron** and is one of the earliest neural networks for classification. It was invented in 1958 by Frank Rosenblatt, a psychologist at Cornell University. It was not the first neural network, but it is often the first one that students learn. The perceptron is an example of a **hard** classifier: it does not predict the probability of each class and instead predicts exactly one class. \n",
"\n",
"\n",
"Now that we have a model, we must choose a loss function. So far we've only seen one loss function: mean squared error. Let us begin with a related loss called mean absolute error (MAE). With hard 0/1 predictions, MAE counts the disagreements between the true and predicted classes, so it is the fraction of examples we get wrong, that is, one minus the accuracy.\n",
"\n",
"\\begin{equation}\n",
" L = \\frac{1}{N} \\sum_i \\left|y_i - \\hat{y}_i\\right|\n",
"\\end{equation}"
]
},
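{
"cell_type": "markdown",
"metadata": {},
"source": [
"A tiny worked example of the claim above, with made-up labels and hard predictions: for 0/1 values, MAE equals the error rate, which is one minus the accuracy.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"y = np.array([1, 0, 1, 1, 0])  # made-up true labels\n",
"yhat = np.array([1, 1, 1, 0, 0])  # made-up hard predictions\n",
"\n",
"mae = np.mean(np.abs(y - yhat))  # 2 of 5 predictions are wrong\n",
"acc = np.mean(y == yhat)\n",
"print(mae, acc)  # 0.4 0.6\n",
"```"
]
},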
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L_9SI1FsDtv8"
},
"outputs": [],
"source": [
"def loss(y, yhat):\n",
" return jnp.mean(jnp.abs(y - yhat))\n",
"\n",
"\n",
"def loss_wrapper(w, b, x, y):\n",
" yhat = perceptron(x, w, b)\n",
" return loss(y, yhat)\n",
"\n",
"\n",
"loss_grad = jax.grad(loss_wrapper, (0, 1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YJnPp-WiDtv9"
},
"outputs": [],
"source": [
"batch_size = 32\n",
"# the first 80% of rows are used for training, the rest for testing\n",
"train_N = int(len(labels) * 0.8)\n",
"N = len(labels)\n",
"# start index of each training mini-batch\n",
"batch_idx = range(0, train_N, batch_size)\n",
"w = np.random.normal(size=len(features.columns))\n",
"b = 0.0\n",
"\n",
"test_x = features[train_N:].values.astype(np.float32)\n",
"test_y = labels[train_N:].values"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4AHfyXs4Dtv9"
},
"source": [
"Let's now try out our gradient to make sure it works."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lOx_q5I6Dtv9"
},
"outputs": [],
"source": [
"loss_grad(w, b, test_x, test_y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "r3WqkjLEDtv-",
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"from myst_nb import glue\n",
"\n",
"x = np.linspace(-3, 3, 500)\n",
"y = 1 / (1 + np.exp(-x))\n",
"plt.plot(x, y)\n",
"plt.xlabel(r\"$x$\")\n",
"plt.ylabel(r\"$\\sigma(x)$\")\n",
"plt.axvline(0, color=\"gray\")\n",
"plt.title(\"Sigmoid\")\n",
"glue(\"sigmoid\", plt.gcf(), display=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SGCxleaVDtv-"
},
"source": [
"It's all zeros! Why is that? The output of our {obj}`jnp.where` statement is piecewise constant (`1` or `0`), so its derivative is zero everywhere except exactly at the decision boundary, where it is undefined. No gradient information flows back to $\\vec{w}$ or $b$. The perceptron actually has a special training procedure that does not use derivatives. One of the motivating reasons that deep learning is popular is that we do not need to construct a special training procedure for each model, like the one for the perceptron. \n",
"\n",
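"To see the zero gradient numerically without JAX, here is a small NumPy sketch (the helper names are ours) that estimates derivatives with central finite differences. The step function gives exactly zero away from $v = 0$, while the smooth sigmoid used in the next step of this chapter has a nonzero slope everywhere.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def step(v):\n",
"    # hard perceptron output: 1 if v > 0 else 0\n",
"    return np.where(v > 0, 1.0, 0.0)\n",
"\n",
"def sigmoid(v):\n",
"    return 1 / (1 + np.exp(-v))\n",
"\n",
"def central_diff(f, v, h=1e-4):\n",
"    # numerical derivative estimate (f(v + h) - f(v - h)) / (2h)\n",
"    return (f(v + h) - f(v - h)) / (2 * h)\n",
"\n",
"v = np.array([-2.0, -0.5, 0.5, 2.0])\n",
"print(central_diff(step, v))  # all zeros: the step is flat away from v = 0\n",
"print(central_diff(sigmoid, v))  # nonzero slopes\n",
"```\n",
"\n",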
"Rather than teach and discuss the special perceptron training procedure, we'll move to a more modern related classifier called a softmax binary classifier. This is a tiny change; the softmax binary classifier is:\n",
"\n",
"\\begin{equation}\n",
"\\hat{f}(\\vec{x}) = \\sigma\\left(\\vec{w}\\cdot \\vec{x} + b\\right)\n",
"\\end{equation}\n",
"\n",
"```{glue:figure} sigmoid\n",
"----\n",
"name: simgoid\n",
"----\n",
"The sigmoid function. Input is any real number and the output is a probability. Positive numbers map to probabilities greater than 0.5 and negative numbers to probabilities less than 0.5. \n",
"```\n",
"\n",
"```{margin}\n",
"Softmax is the generalization of sigmoid to multiple classes. Although we call our binary classifier a softmax classifier, it doesn't use the softmax function.\n",
"```\n",
"\n",
"where $\\sigma$ is the **sigmoid** function. The sigmoid has a domain of $(-\\infty, \\infty)$ and outputs a probability in $(0, 1)$. The input to the sigmoid can be viewed as log-odds, called **logits** for short. Odds are ratios of probabilities: odds of 1 mean the probability of class 1 is 0.5 and of class 0 is 0.5. Odds of 2 mean the probability of class 1 is 0.67 and of class 0 is 0.33. Log-odds are the natural logarithm of the odds, so log-odds of 0 mean the odds are 1 and the output probability should be 0.5. One definition of the sigmoid is \n",
"\n",
"\\begin{equation}\n",
"\\sigma(x) = \\frac{1}{1 + e^{-x}}\n",
"\\end{equation}\n",
"\n",
"however, in practice some care is needed to implement the sigmoid in a numerically stable way. This type of binary classifier is sometimes called **logistic regression** because we're regressing logits. \n",
"\n",
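"Here is a minimal NumPy sketch of the odds arithmetic above and of one common way to make the sigmoid numerically stable: only ever exponentiate a non-positive number, so `exp` cannot overflow. The function name `stable_sigmoid` is ours, and this branch-based form is one option among several; it is not meant to show how {obj}`jax.nn.sigmoid` is implemented.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def stable_sigmoid(x):\n",
"    x = np.asarray(x, dtype=float)\n",
"    # exp(-|x|) is at most 1, so it never overflows\n",
"    z = np.exp(-np.abs(x))\n",
"    # x >= 0: 1 / (1 + exp(-x));  x < 0: exp(x) / (1 + exp(x))\n",
"    return np.where(x >= 0, 1 / (1 + z), z / (1 + z))\n",
"\n",
"# odds of 1 and 2 -> log-odds (logits) -> probabilities of class 1\n",
"logits = np.log(np.array([1.0, 2.0]))\n",
"print(stable_sigmoid(logits))  # approximately 0.5 and 0.667\n",
"\n",
"# the naive 1 / (1 + np.exp(-x)) overflows np.exp at x = -1000\n",
"print(stable_sigmoid(np.array([-1000.0, 1000.0])))  # [0. 1.]\n",
"```\n",
"\n",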
"In essence, all we've done is replace the inequality of the perceptron with a smooth, differentiable version. Just like before, a positive value of $\\vec{w}\\cdot \\vec{x} + b$ indicates class 1 (FDA approved), but now it maps to a continuum of probabilities from 0.5 to 1.0. This is **soft** classification: we give probabilities of class membership instead of a hard assignment. However, our loss function now needs to be modified as well."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bNhnPLHODtv-"
},
"source": [
"There is a different loss function that works better with classification called **cross-entropy**. You can experiment with mean absolute error or mean squared error with classification, but you'll find they are almost always worse than cross-entropy.\n",
"\n",
"Cross-entropy is a loss function that describes the distance between two probability distributions. When it is minimized, the two probability distributions are identical. Cross-entropy is closely related to the [Kullback–Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence), which is a way to measure distance between two probability distributions: the cross-entropy equals the KL divergence plus the entropy of the true distribution, which does not depend on the model, so minimizing one minimizes the other. Technically the KL divergence is not a distance since it's not symmetric with respect to its arguments, but in practice it is close enough to a distance that we treat it as one. \n",
"\n",
"How is comparing predicted values $\\hat{y}$ and $y$ like comparing two probability distributions? The true labels $y$ are 1s and 0s, but across the classes they sum to 1, so we can treat them as a probability distribution that puts all of its mass on the correct class. The predictions $\\hat{y}$ from our classifier are probabilities that also sum to 1. Cross-entropy is defined as:\n",
"\n",
"\\begin{equation}\n",
" L = -\\sum_c^K y_c \\log \\hat{y}_c\n",
"\\end{equation}\n",
"\n",
"where $c$ indicates which class of the $K$ we're considering, and it's assumed that $\\sum_c^K y_c = 1$ and $\\sum_c^K \\hat{y}_c = 1$ like probabilities (and they are positive). In the case of binary classification (only two classes), this becomes:\n",
"\n",
"\\begin{equation}\n",
" L = -\\left[ y_0 \\log \\hat{y}_0 + y_1 \\log \\hat{y}_1 \\right]\n",
"\\end{equation}\n",
"\n",
"where $y_0$ is for the first class and $y_1$ is for the second class. However, because these are probabilities, we also know that $y_1 = 1 - y_0$ and $\\hat{y}_1 = 1 - \\hat{y}_0$. We can rewrite this as:\n",
"\n",
"\\begin{equation}\n",
" L = -\\left[ y_0 \\log \\hat{y}_0 + (1 - y_0) \\log ( 1- \\hat{y}_0) \\right]\n",
"\\end{equation}\n",
"\n",
"Finally, we can drop the indication of the class:\n",
"\n",
"\\begin{equation}\n",
" L = -\\left[ y \\log \\hat{y} + (1 - y) \\log ( 1- \\hat{y}) \\right]\n",
"\\end{equation}\n",
"\n",
"\n",
"```{margin}\n",
"The more robust way to avoid numerical instability in sigmoid cross-entropy classification is to have your model output the logits and use a loss function that works on logits instead of probabilities. For example, {obj}`tf.nn.sigmoid_cross_entropy_with_logits`.\n",
"```\n",
"and this matches our data, where we have a single value for each label indicating if it is a class member. Now we have features, labels, loss, and a model. Let's create a batched gradient descent algorithm to train our classifier. Note that we use the built-in JAX {obj}`jax.nn.sigmoid` function to avoid numerical instabilities, and we also add a small number inside every log so that we never take the log of zero."
]
},
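{
"cell_type": "markdown",
"metadata": {},
"source": [
"The margin note above suggests computing the loss from logits. Here is a minimal NumPy sketch of that idea (the function names are ours; this is not the TensorFlow or JAX implementation). Substituting $\\hat{y} = \\sigma(z)$ into the binary cross-entropy and simplifying gives $\\log(1 + e^{z}) - yz$, and {obj}`np.logaddexp` evaluates $\\log(1 + e^{z})$ without overflow.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def bce_from_probs(y, p, eps=1e-10):\n",
"    # same idea as cross_ent in this chapter: log of a probability plus a small epsilon\n",
"    return -(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))\n",
"\n",
"def bce_from_logits(y, z):\n",
"    # -[y log(sigmoid(z)) + (1 - y) log(1 - sigmoid(z))] = log(1 + exp(z)) - y z\n",
"    return np.logaddexp(0, z) - y * z\n",
"\n",
"y = np.array([1.0, 0.0, 1.0])\n",
"z = np.array([2.0, -1.0, -40.0])  # logits; the last one is a confident mistake\n",
"p = 1 / (1 + np.exp(-z))\n",
"print(bce_from_probs(y, p))  # last loss is capped near -log(1e-10), about 23\n",
"print(bce_from_logits(y, z))  # last loss is the exact value, about 40\n",
"```\n",
"\n",
"For the first two examples both functions agree. They differ only when the predicted probability of the true class falls below the epsilon, which is exactly where the probability-based loss stops giving useful gradients."
]
},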
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "W0OwHQBcDtv_"
},
"outputs": [],
"source": [
"def bin_classifier(x, w, b):\n",
" v = jnp.dot(x, w) + b\n",
" y = jax.nn.sigmoid(v)\n",
" return y\n",
"\n",
"\n",
"def cross_ent(y, yhat):\n",
" return jnp.mean(-(y * jnp.log(yhat + 1e-10) + (1 - y) * jnp.log(1 - yhat + 1e-10)))\n",
"\n",
"\n",
"def loss_wrapper(w, b, x, y):\n",
" yhat = bin_classifier(x, w, b)\n",
" return cross_ent(y, yhat)\n",
"\n",
"\n",
"loss_grad = jax.grad(loss_wrapper, (0, 1))\n",
"w = np.random.normal(scale=0.01, size=len(features.columns))\n",
"b = 1.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rWY6ESj6Dtv_"
},
"outputs": [],
"source": [
"loss_progress = []\n",
"test_loss_progress = []\n",
"eta = 0.2\n",
"for epoch in range(5):\n",
" for i in range(len(batch_idx) - 1):\n",
" x = features[batch_idx[i] : batch_idx[i + 1]].values.astype(np.float32)\n",
" y = labels[batch_idx[i] : batch_idx[i + 1]].values\n",
" grad = loss_grad(w, b, x, y)\n",
" w -= eta * grad[0]\n",
" b -= eta * grad[1]\n",
" loss_progress.append(loss_wrapper(w, b, x, y))\n",
" test_loss_progress.append(loss_wrapper(w, b, test_x, test_y))\n",
"plt.plot(loss_progress, label=\"Training Loss\")\n",
"plt.plot(test_loss_progress, label=\"Testing Loss\")\n",
"\n",
"plt.xlabel(\"Step\")\n",
"plt.legend()\n",
"plt.ylabel(\"Loss\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8gCAUAypDtwA"
},
"source": [
"We are making good progress with our classifier, as judged from testing loss. You can run the code longer, but I'll leave it at that. We have a reasonably well-trained model."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3aDuhvTiDtwA"
},
"source": [
"## Classification Metrics\n",
"\n",
"In regression, we assessed model performance with a parity plot, correlation coefficient, or mean squared error. In classification, we use slightly different metrics. The first metric is **accuracy**: the fraction of examples for which the predicted label matches the true label. We do not have a hard classifier, so we have to choose how to turn probability into a specific class. For now, we will choose the class with the highest probability, which in binary classification means predicting 1 when the probability is above 0.5. Let's see how this looks:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wWbuc2qUDtwA"
},
"outputs": [],
"source": [
"def accuracy(y, yhat):\n",
" # convert from prob to hard class\n",
" hard_yhat = np.where(yhat > 0.5, np.ones_like(yhat), np.zeros_like(yhat))\n",
"    # compare the hard predictions (not the probabilities) to the labels\n",
"    disagree = np.sum(np.abs(y - hard_yhat))\n",
" return 1 - disagree / len(y)\n",
"\n",
"\n",
"accuracy(test_y, bin_classifier(test_x, w, b))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EL8jRt0MDtwA",
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"glue(\"acc\", accuracy(test_y, bin_classifier(test_x, w, b)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4vuzrIOGDtwA"
},
"source": [
"An accuracy of {glue:text}`acc:.2f` seems quite reasonable! However, consider this model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7gf4oMVQDtwB"
},
"outputs": [],
"source": [
"def alt_classifier(x):\n",
" return np.ones((x.shape[0]))\n",
"\n",
"\n",
"accuracy(test_y, alt_classifier(test_x))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NvXgsT4PDtwB"
},
"source": [
"This model, which always returns 1, has better accuracy than our model. How is this possible? \n",
"\n",
"```{admonition} Answer\n",
":class: tip, dropdown\n",
"If you examine the data, you'll see the majority of the molecules passed FDA clinical trials ($y = 1$), so that just guessing $1$ is a good strategy.\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x78lCHDQDtwC"
},
"source": [
"### Error Types\n",
"\n",
"Let's recall what we're trying to do. We're trying to predict if a molecule will make it through FDA clinical trials. Our model can be incorrect in two ways. It can predict that a molecule will pass clinical trials when it actually fails; this is called a false positive. Or it can predict that a drug will not make it through clinical trials when it actually does; this is called a false negative.\n",
"\n",
"```{margin}\n",
"False positives are also known as Type I (pronounced type one) errors and false negatives as Type II errors.\n",
"```\n",
"\n",
"Our `alt_classifier` model, which simply reports everything as positive, has no false negative errors. It has many false positive errors. These two types of errors can be quantified. We're going to add one complexity -- **threshold**. Our model provides probabilities which we're converting into hard class memberships -- 1s and 0s. We have been choosing to just take the most probable class. However, we will now instead choose a threshold for when we report a positive (class 1). The rationale is that although we train our model to minimize cross-entropy, we may want to be more conservative or aggressive in our classification with the trained model. If we want to minimize false negatives, we can lower the threshold and report even predictions that have a probability of 30% as positive. Or, if we want to minimize false positives we may set our threshold so that our model must predict above 90% before we predict a positive."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_wlBwdcEDtwC"
},
"outputs": [],
"source": [
"def error_types(y, yhat, threshold):\n",
"    # convert probabilities to hard 0/1 predictions\n",
"    hard_yhat = np.where(yhat >= threshold, np.ones_like(yhat), np.zeros_like(yhat))\n",
"    # false positive: predicted 1, actually 0, so hard_yhat - y == 1\n",
"    fp = np.sum((hard_yhat - y) > 0)\n",
"    # false negative: predicted 0, actually 1, so y - hard_yhat == 1\n",
"    fn = np.sum((y - hard_yhat) > 0)\n",
" return fp, fn"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nkQJ_PdTDtwD"
},
"outputs": [],
"source": [
"print(\"Alt Classifier\", error_types(test_y, alt_classifier(test_x), 0.5))\n",
"print(\"Trained Classifier\", error_types(test_y, bin_classifier(test_x, w, b), 0.5))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qq6IL9qnDtwD"
},
"source": [
"Now we have a better sense of how our model does in comparison. The number of errors is indeed larger for our trained model, but its errors are more evenly split between the two types. What is more important? In our case, I would argue doing clinical trials that fail is worse than mistakenly not starting them. That is, false positives are worse than false negatives. Let's see if we can tune our threshold value to minimize false positives."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3j11ETYeDtwD"
},
"outputs": [],
"source": [
"print(\"Threshold 0.7\", error_types(test_y, bin_classifier(test_x, w, b), 0.7))\n",
"print(\"Threshold 0.9\", error_types(test_y, bin_classifier(test_x, w, b), 0.9))\n",
"print(\"Threshold 0.95\", error_types(test_y, bin_classifier(test_x, w, b), 0.95))\n",
"print(\"Threshold 0.99\", error_types(test_y, bin_classifier(test_x, w, b), 0.99))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bgcO3BdkDtwE"
},
"source": [
"By adjusting the threshold, we can achieve a balance of errors closer to what we want from our model. In fact, we're able to have only 1 false positive, at the cost of missing 218 of the molecules. But are we still predicting any positives at all? Are we actually going to get some **true positives?** We can measure that as well:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HBj3MVZLDtwE"
},
"outputs": [],
"source": [
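"yhat_test = bin_classifier(test_x, w, b)\n",
"total_pos = np.sum(test_y)\n",
"# count predictions at or above the 0.99 threshold, as in error_types\n",
"predicted_pos = np.asarray(yhat_test) >= 0.99\n",
"print(\n",
"    \"Total positives:\",\n",
"    total_pos,\n",
"    \"Predicted positives:\",\n",
"    np.sum(predicted_pos),\n",
"    \"True positives:\",\n",
"    np.sum(predicted_pos & (test_y == 1)),\n",
")"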
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MPMLguTlDtwE"
},
"source": [
"Yes, our model is able to pick out molecules that will pass FDA clinical trials with very few false positives (1 here). The ability to tune this trade-off is a sign of a useful model. Our other model, which always predicts 1, has good accuracy, but we cannot adjust it to better balance Type I and Type II errors."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A5dV0HlzDtwE"
},
"source": [
"### Receiver-Operating Characteristic Curve\n",
"\n",
"We can plot threshold, false positive rate, and true positive rate together to capture both model accuracy and the balance between error types in a Receiver-Operating Characteristic curve (ROC curve). The x-axis of an ROC curve is the false positive rate and the y-axis is the true positive rate. Each point on the plot is our model with a different threshold. How do we choose which thresholds to use? We use the set of unique class probabilities the model predicts (computed with {obj}`np.unique`). We also need to add two extremes to this set: all positive (a threshold at or below 0.0) and all negative (a threshold above 1.0). Recall our alternate/baseline model of always predicting positive: it can only have a few points on the ROC curve because its unique set of probabilities is just 1.0. Let's make one and discuss what we're seeing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xvrIo3e1DtwF"
},
"outputs": [],
"source": [
"unique_threshes = np.unique(bin_classifier(test_x, w, b))\n",
"fp = []\n",
"tp = []\n",
"total_pos = np.sum(test_y)\n",
"for ut in list(unique_threshes) + [-0.1, 1.01]:\n",
" errors = error_types(test_y, bin_classifier(test_x, w, b), ut)\n",
" fp.append(errors[0])\n",
" tp.append(total_pos - errors[1])\n",
"\n",
"# sort them so can plot as a line\n",
"idx = np.argsort(fp)\n",
"fpr = np.array(fp)[idx] / (len(test_y) - np.sum(test_y))\n",
"tpr = np.array(tp)[idx] / np.sum(test_y)\n",
"\n",
"# now remove duplicate x-values\n",
"fpr_nd = []\n",
"tpr_nd = []\n",
"last = None\n",
"for f, t in zip(fpr, tpr):\n",
" if last is None or f != last:\n",
" last = f\n",
" fpr_nd.append(f)\n",
" tpr_nd.append(t)\n",
"\n",
"plt.plot(fpr_nd, tpr_nd, \"-o\", label=\"Trained Model\")\n",
"plt.plot([0, 1], [0, 1], label=\"Naive Classifier\")\n",
"plt.ylabel(\"True Positive Rate\")\n",
"plt.xlabel(\"False Positive Rate\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3hk4JjZRDtwF"
},
"source": [
"This plot nicely shows how our trained model is sensitive to the threshold, so we could choose to screen more carefully for false negatives or false positives. The best curves lie toward the top-left of this plot. Our naive classifier randomly labels a fixed fraction of examples as positive, which traces the diagonal line. You can compute the area under this curve by numerical integration, and this is a good way to measure classifier performance that captures the effect of both false negatives and false positives. The area under the ROC curve is known as the **ROC AUC score** and is often preferred to accuracy because it captures the balance of Type I and II errors."
]
},
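{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sketch below uses made-up labels and scores (the function names are ours). It computes ROC points with the same `>=` threshold rule as `error_types`, integrates them with the trapezoid rule, and checks the result against an equivalent reading of ROC AUC: the probability that a randomly chosen positive is scored higher than a randomly chosen negative.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def roc_points(y, score):\n",
"    # thresholds from high to low; inf gives the all-negative point (0, 0)\n",
"    thresholds = np.concatenate([[np.inf], np.unique(score)[::-1]])\n",
"    n_pos, n_neg = np.sum(y == 1), np.sum(y == 0)\n",
"    fpr = np.array([np.sum((score >= t) & (y == 0)) / n_neg for t in thresholds])\n",
"    tpr = np.array([np.sum((score >= t) & (y == 1)) / n_pos for t in thresholds])\n",
"    return fpr, tpr\n",
"\n",
"def trapezoid_area(fpr, tpr):\n",
"    # area under the piecewise-linear curve; fpr must be non-decreasing\n",
"    return np.sum((fpr[1:] - fpr[:-1]) * (tpr[1:] + tpr[:-1]) / 2)\n",
"\n",
"def rank_auc(y, score):\n",
"    # fraction of (positive, negative) pairs ranked correctly; ties count 1/2\n",
"    pos, neg = score[y == 1], score[y == 0]\n",
"    greater = np.sum(pos[:, None] > neg[None, :])\n",
"    ties = np.sum(pos[:, None] == neg[None, :])\n",
"    return (greater + 0.5 * ties) / (len(pos) * len(neg))\n",
"\n",
"y = np.array([1, 1, 0, 1, 0, 0])\n",
"score = np.array([0.9, 0.8, 0.7, 0.6, 0.4, 0.1])\n",
"fpr, tpr = roc_points(y, score)\n",
"auc_trap = trapezoid_area(fpr, tpr)\n",
"auc_rank = rank_auc(y, score)\n",
"print(auc_trap, auc_rank)  # both 8/9, about 0.889\n",
"```"
]
},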
{
"cell_type": "markdown",
"metadata": {
"id": "GunRy3jlDtwG"
},
"source": [
"### Other metrics\n",
"\n",
"I will just mention that there are other ways to assess how your model balances the two error types. One common pair is **precision** and **recall**. Precision is the fraction of predicted positives that are actually positive, and recall is the fraction of actual positives that the model predicts as positive. This can be a good viewpoint when doing molecular screening: you may want high precision, so that the molecules you propose are correct, while sacrificing recall, meaning you miss many molecules that would have worked. Recall is the same as the true positive rate, so points at the top of an ROC curve have good recall, while points on the left (few false positives) tend to be precise. There are also F1 scores, likelihoods, Matthews correlation coefficients, Jaccard index, Brier score, and balanced accuracy, which all try to summarize a classifier's performance in one number. We will rarely explore these other measures, but you should know they exist.\n",
"\n",
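"As a small illustration with made-up counts (the helper name `binary_metrics` is ours), here is how precision, recall, and a few of the single-number summaries are computed from the four outcome counts. The formulas divide by zero if, for example, nothing is predicted positive, so real code would need to guard against that.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def binary_metrics(tp, fp, fn, tn):\n",
"    precision = tp / (tp + fp)  # fraction of predicted positives that are correct\n",
"    recall = tp / (tp + fn)  # fraction of actual positives found (the TPR)\n",
"    f1 = 2 * precision * recall / (precision + recall)  # harmonic mean\n",
"    mcc = (tp * tn - fp * fn) / np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))\n",
"    balanced_acc = (recall + tn / (tn + fp)) / 2  # mean of TPR and TNR\n",
"    return precision, recall, f1, mcc, balanced_acc\n",
"\n",
"# a cautious screen: few proposed molecules (high precision, low recall)\n",
"precision, recall, f1, mcc, balanced_acc = binary_metrics(tp=8, fp=1, fn=32, tn=59)\n",
"print(precision, recall, f1, mcc, balanced_acc)\n",
"```\n",
"\n",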
"#### Confusion Matrix\n",
"\n",
"A confusion matrix is a table of counts of true versus predicted classes. It can be used for binary classification, but it really stands out as a visual assessment for multi-class classification. For example, suppose we are categorizing molecules into three classes: insoluble, weakly soluble, and soluble. We can represent a classifier's performance in a table:\n",
"\n",
"| true👇\\predicted👉 | insoluble | weakly soluble | soluble |\n",
"|:--------------|:-----------:|:--------------:| -------:|\n",
"| insoluble | 121 | 8 | 1 |\n",
"| weakly soluble| 7 | 45 | 18 |\n",
"| soluble | 11 | 4 | 56 |\n",
"\n",
"\n",
"The diagonal elements show when the predicted and true labels agree. For example, 121 molecules were actually insoluble and predicted to be insoluble. We can also read off how the classifier failed. One molecule was predicted to be soluble but was actually insoluble, and 4 molecules were predicted to be weakly soluble but were actually soluble. This can help us understand *how* the classifier is failing."
]
},
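{
"cell_type": "markdown",
"metadata": {},
"source": [
"A confusion matrix is easy to build by counting (true, predicted) pairs. Here is a minimal NumPy sketch with made-up class indices; the helper name is ours, and scikit-learn provides a similar `sklearn.metrics.confusion_matrix`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def confusion_matrix(y_true, y_pred, n_classes):\n",
"    # rows are true classes, columns are predicted classes\n",
"    cm = np.zeros((n_classes, n_classes), dtype=int)\n",
"    np.add.at(cm, (y_true, y_pred), 1)  # count each (true, predicted) pair\n",
"    return cm\n",
"\n",
"# made-up class indices: 0 = insoluble, 1 = weakly soluble, 2 = soluble\n",
"y_true = np.array([0, 0, 1, 1, 2, 2, 2])\n",
"y_pred = np.array([0, 2, 1, 2, 2, 1, 2])\n",
"cm = confusion_matrix(y_true, y_pred, 3)\n",
"print(cm)\n",
"# accuracy is the sum of the diagonal divided by the total count\n",
"print(np.trace(cm) / cm.sum())\n",
"```"
]
},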
{
"cell_type": "markdown",
"metadata": {
"id": "jKfEkNBMDtwG"
},
"source": [
"## Class Imbalance\n",
"\n",
"The reason for this uneven number of false positives and false negatives is that we have very few negative examples, that is, molecules which failed FDA clinical trials. This also explains why just predicting success has a high accuracy. How can we address this problem?\n",
"\n",
"The first answer is to do nothing. Is this imbalance a problem at all? Perhaps drugs in general succeed at clinical trials, and thus the imbalance in the training data reflects what we expect to see in testing. Judging from the difficulty and expense of creating new drug molecules, this is clearly not the case. However, this should be the first thing you ask yourself. If you're creating a classifier to detect lung cancer from X-ray images, you will probably have imbalanced training data, and at test time, when evaluating patients, you also will not have 50% of patients with lung cancer. This comes back to the discussion in {doc}`regression` about the training data distribution. If your testing data follow the same distribution as your training data, including the class proportions, then the class imbalance does not need to be explicitly addressed.\n",
"\n",
"The second solution is to somehow weight your training data to appear more like your testing data when you think you do have **label shift** (a different ratio of classes at test time than in training). There are two ways to accomplish this. You could \"augment\" your training data by repeating the minority class until the ratio of minority to majority examples matches the assumed testing data. There are research papers written on this topic, with intuitive results {cite}`chawla2002smote`. Over-sampling the minority class can lead to a large dataset, so you can instead under-sample the majority class. This is a robust approach that is independent of how you train, and it is typically as good as more sophisticated methods {cite}`youbi2021simple`.\n",
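"\n",
"A minimal sketch of both options, assuming integer labels in a NumPy array. The function `resample_indices` and its `ratio` parameter (the desired minority:majority ratio, 1:1 by default) are hypothetical names for illustration. This is plain random duplication or removal of rows; it does not implement SMOTE's synthetic interpolation between neighboring minority examples.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def resample_indices(labels, minority=0, ratio=1.0, mode=\"over\", seed=0):\n",
"    # returns shuffled row indices whose minority:majority ratio is about `ratio`\n",
"    rng = np.random.default_rng(seed)\n",
"    labels = np.asarray(labels)\n",
"    minor_idx = np.flatnonzero(labels == minority)\n",
"    major_idx = np.flatnonzero(labels != minority)\n",
"    if mode == \"over\":\n",
"        # repeat minority rows, sampling with replacement\n",
"        n_minor = int(round(ratio * len(major_idx)))\n",
"        minor_idx = rng.choice(minor_idx, size=n_minor, replace=True)\n",
"    elif mode == \"under\":\n",
"        # drop majority rows, sampling without replacement\n",
"        n_major = min(len(major_idx), int(round(len(minor_idx) / ratio)))\n",
"        major_idx = rng.choice(major_idx, size=n_major, replace=False)\n",
"    else:\n",
"        raise ValueError(\"mode must be 'over' or 'under'\")\n",
"    idx = np.concatenate([minor_idx, major_idx])\n",
"    rng.shuffle(idx)\n",
"    return idx\n",
"\n",
"\n",
"labels = np.array([1] * 90 + [0] * 10)  # hypothetical 9:1 imbalance\n",
"over = resample_indices(labels, mode=\"over\")\n",
"under = resample_indices(labels, mode=\"under\")\n",
"print(len(over), np.mean(labels[over] == 0))\n",
"print(len(under), np.mean(labels[under] == 0))\n",
"```\n",
"\n",
"Assuming `features` and `labels` are pandas objects, the returned indices could be applied with `.iloc[idx]` to both before batching. Resampling should only touch the training split, never the test data.\n",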
"\n",
"Another way of weighting data is to modify your loss function to increase the gradient updates applied to minority examples. This is equivalent to saying there is a difference in loss between a false positive and a false negative. In our case, false positives are rarer and also more important in reality: we would rather skip a clinical trial (false negative) than start one and have it fail (false positive). We already tried minimizing false positives by changing the threshold on a trained model, but let's see how this works during training. We'll create a weight vector that is high for negative labels so that if they are misclassified (false positive), there will be a bigger update.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FbF1hZIpDtwG"
},
"outputs": [],
"source": [
"def bin_classifier(x, w, b):\n",
" v = jnp.dot(x, w) + b\n",
" y = jax.nn.sigmoid(v)\n",
" return y\n",
"\n",
"\n",
"def weighted_cross_ent(y, yhat, yw):\n",
" # weights may not be normalized\n",
" N = jnp.sum(yw)\n",
" # use weighted sum instead\n",
" return (\n",
" jnp.sum(\n",
" -(yw * y * jnp.log(yhat + 1e-10) + yw * (1 - y) * jnp.log(1 - yhat + 1e-10))\n",
" )\n",
" / N\n",
" )\n",
"\n",
"\n",
"def loss_wrapper(w, b, x, y, yw):\n",
" yhat = bin_classifier(x, w, b)\n",
" return weighted_cross_ent(y, yhat, yw)\n",
"\n",
"\n",
"loss_grad = jax.grad(loss_wrapper, (0, 1))\n",
"w2 = np.random.normal(scale=0.01, size=len(features.columns))\n",
"b2 = 1.0\n",
"weights = np.ones_like(labels)\n",
"# make the labels = 0 values be much larger\n",
"weights[labels.values == 0] *= 1000\n",
"# now make weights be on average 1\n",
"# to keep our learning rate/avg update consistent\n",
"weights = weights * len(weights) / np.sum(weights)\n",
"\n",
"loss_progress = []\n",
"test_loss_progress = []\n",
"eta = 0.2\n",
"# make epochs larger since this has\n",
"# very large steps that converge poorly\n",
"for epoch in range(10):\n",
" for i in range(len(batch_idx) - 1):\n",
" x = features[batch_idx[i] : batch_idx[i + 1]].values.astype(np.float32)\n",
" y = labels[batch_idx[i] : batch_idx[i + 1]].values\n",
" yw = weights[batch_idx[i] : batch_idx[i + 1]]\n",
" grad = loss_grad(w2, b2, x, y, yw)\n",
" w2 -= eta * grad[0]\n",
" b2 -= eta * grad[1]\n",
" loss_progress.append(loss_wrapper(w2, b2, x, y, yw))\n",
" test_loss_progress.append(\n",
" loss_wrapper(w2, b2, test_x, test_y, np.ones_like(test_y))\n",
" )\n",
"plt.plot(loss_progress, label=\"Training Loss\")\n",
"plt.plot(test_loss_progress, label=\"Testing Loss\")\n",
"\n",
"plt.xlabel(\"Step\")\n",
"plt.legend()\n",
"plt.ylabel(\"Loss\")\n",
"plt.show()\n",
"\n",
"print(\"Normal Classifier\", error_types(test_y, bin_classifier(test_x, w, b), 0.5))\n",
"print(\"Weighted Classifier\", error_types(test_y, bin_classifier(test_x, w2, b2), 0.5))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hAaO0dfIDtwH"
},
"source": [
"The spikes in loss occur when we see a rare negative example, which is weighted heavily. Compared to the normal classifier trained above, we have fewer false positives at a threshold of 0.5. However, we also have more false negatives. We saw above that we could tweak this by changing our threshold. Let's plot the weighted model on an ROC curve to compare it with the previous model across all thresholds.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "42tvh69XDtwH"
},
"outputs": [],
"source": [
"unique_threshes = np.unique(bin_classifier(test_x, w2, b2))\n",
"fp = []\n",
"tp = []\n",
"total_pos = np.sum(test_y)\n",
"for ut in list(unique_threshes) + [-0.1, 1.01]:\n",
" errors = error_types(test_y, bin_classifier(test_x, w2, b2), ut)\n",
" fp.append(errors[0])\n",
" tp.append(total_pos - errors[1])\n",
"\n",
"# sort them so can plot as a line\n",
"idx = np.argsort(fp)\n",
"fpr = np.array(fp)[idx] / (len(test_y) - np.sum(test_y))\n",
"tpr = np.array(tp)[idx] / np.sum(test_y)\n",
"\n",
"# now remove duplicate x-values\n",
"fpr_nd2 = []\n",
"tpr_nd2 = []\n",
"last = None\n",
"for f, t in zip(fpr, tpr):\n",
" if last is None or f != last:\n",
" last = f\n",
" fpr_nd2.append(f)\n",
" tpr_nd2.append(t)\n",
"\n",
"plt.plot(fpr_nd, tpr_nd, \"-o\", label=\"Normal Model\")\n",
"plt.plot(fpr_nd2, tpr_nd2, \"-o\", label=\"Weighted Model\")\n",
"plt.plot([0, 1], [0, 1], label=\"Naive Classifier\")\n",
"plt.ylabel(\"True Positive Rate\")\n",
"plt.xlabel(\"False Positive Rate\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-1KEJUl5DtwH"
},
"source": [
"It appears our weighted training did not actually improve model performance, except in a small range of false positive rates between 0.25 and 0.4. It is even worse at low false positive rates, which is where we would like to operate. In conclusion, we can modify the balance of false positives and false negatives through modifications to training. However, we can also modify this balance after training by adjusting the classification threshold. This post-training procedure gives similar or even slightly better performance in our example.\n",
"\n",
"This may not always be the case. An overview of methods is available in {cite}`he2009learning`, and you can find a more recent discussion of the effects of reweighting, including when combined with regularization, in Byrd and Lipton {cite}`byrd2019effect`. Byrd and Lipton show that reweighting has little effect unless combined with L2 regularization and batch normalization, perhaps accounting for the small effect we observed.\n",
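"\n",
"As a sketch of the post-training route, the function below picks the lowest threshold whose false positive rate stays at or below a chosen limit. The name `threshold_for_fpr`, the limit, and the toy scores are hypothetical. It assumes a prediction counts as positive when its score is at least the threshold, which may differ from the comparison inside this chapter's `error_types` helper. In practice the threshold should be chosen on validation data rather than on the test set.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def threshold_for_fpr(y_true, scores, max_fpr=0.1):\n",
"    # positive prediction means score >= threshold, so raising the\n",
"    # threshold can only lower (or keep) the false positive rate\n",
"    y_true = np.asarray(y_true)\n",
"    scores = np.asarray(scores)\n",
"    neg_scores = scores[y_true == 0]\n",
"    # candidates in ascending order; np.inf predicts nothing as positive\n",
"    for t in np.append(np.unique(scores), np.inf):\n",
"        if np.mean(neg_scores >= t) <= max_fpr:\n",
"            return t\n",
"\n",
"\n",
"# hypothetical labels and model scores\n",
"y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1])\n",
"scores = np.array([0.1, 0.3, 0.6, 0.8, 0.4, 0.7, 0.9, 0.95])\n",
"t = threshold_for_fpr(y_true, scores, max_fpr=0.25)\n",
"tpr = np.mean(scores[y_true == 1] >= t)\n",
"print(t, tpr)\n",
"```\n",
"\n",
"Because the true positive rate can only fall as the threshold rises, the lowest threshold that satisfies the false positive limit also gives the highest true positive rate available under that limit.\n",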
"\n",
"### Screening: no negative examples\n",
"Class imbalance is common in peptide and drug discovery, where screening is used to generate data. Screening data typically contains only positive examples, meaning you have literally zero negative examples. This is an active research topic in the field of **positive-unlabeled learning** {cite}`song2021inferring`.\n",
"\n",
"\n",
"## Overfitting\n",
"\n",
"The goal of this chapter is to introduce classification. For simplicity, we did not use any of the techniques from the last chapter except the training/testing split. You can and should use techniques like Jackknife+ and cross-validation to assess overfitting. Also, we used a large number of descriptors (a few hundred), so regularization could be helpful for these models."
]
},
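{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of the cross-validation mentioned above, the generator below yields train/test index splits. Because negative examples are rare here, it uses a stratified variant (our assumption, not something done in this chapter): each class's indices are shuffled and dealt round-robin into `k` folds, so every fold keeps roughly the overall class ratio. The name `stratified_kfold_indices` is hypothetical; scikit-learn's `StratifiedKFold` in `sklearn.model_selection` is a maintained implementation of the same idea.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def stratified_kfold_indices(labels, k=5, seed=0):\n",
"    # yields (train_idx, test_idx) pairs that preserve the class ratio\n",
"    rng = np.random.default_rng(seed)\n",
"    labels = np.asarray(labels)\n",
"    fold_of = np.empty(len(labels), dtype=int)\n",
"    for c in np.unique(labels):\n",
"        # shuffle this class, then deal its rows into folds 0, 1, ..., k-1\n",
"        idx = rng.permutation(np.flatnonzero(labels == c))\n",
"        fold_of[idx] = np.arange(len(idx)) % k\n",
"    for f in range(k):\n",
"        yield np.flatnonzero(fold_of != f), np.flatnonzero(fold_of == f)\n",
"\n",
"\n",
"labels = np.array([1] * 90 + [0] * 10)  # hypothetical 9:1 imbalance\n",
"for train_idx, test_idx in stratified_kfold_indices(labels, k=5):\n",
"    print(len(train_idx), len(test_idx), np.sum(labels[test_idx] == 0))\n",
"```\n",
"\n",
"Each `(train_idx, test_idx)` pair could be used to fit one model per fold and score it on the held-out fold, for example by ROC area under the curve. Comparing the average held-out score with the training score gives a sense of how much the model overfits."
]
},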
{
"cell_type": "markdown",
"metadata": {
"id": "sU7Fp2-rDtwH"
},
"source": [
"## Chapter Summary\n",
"\n",
"* We introduced classification, which is supervised learning with categorical labels. The labels can be single binary values representing two classes, which is called binary classification.\n",
"* We can compute descriptors for molecules using Python packages and do not require them to be part of our dataset\n",
"* Cross-entropy loss should be used for classification tasks\n",
"* Classification models (called classifiers) can output the distance from the decision boundary or, more commonly, the probability of class membership\n",
"* The Perceptron is an early example of a neural network classifier that has a special training procedure\n",
"* The sigmoid and soft-max functions convert real numbers into probabilities\n",
"* Binary classification errors can be false positives or false negatives\n",
"* Accuracy does not distinguish these two errors, so receiver operating characteristic (ROC) curves can be used to assess model performance. Precision and recall are other commonly used measures.\n",
"* An imbalance of classes in training data is not necessarily a problem and can be addressed by weighting training examples"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q4vNM5IgDtwI"
},
"source": [
"## Exercises\n",
"\n",
"### Classification\n",
"\n",
"1. Design your own examples of labels for binary, multi-class, and multi-label classification. For example, \"A multi-class label is the country a person lives in. A label for this is a 225-element vector with one non-zero element indicating the country the person lives in.\"\n",
"\n",
"2. Write out the equations for cross-entropy in multi-class and multi-label settings. \n",
"\n",
"### Data\n",
"\n",
"1. Use the dimensionality reduction methods from our [first chapter](./introduction.ipynb) to plot the molecules here in 2D. Color the points based on their labels. Do you see any patterns?\n",
"\n",
"2. Now, use clustering to color the molecules. Use an elbow plot to choose your cluster number.\n",
"\n",
"### Assessment\n",
"\n",
"1. Repeat the model fitting with L1 and L2 regularization and plot them on a ROC curve. What effect does regularization have on these? Choose a strength of 0.1.\n",
"\n",
"2. Could you use leave-one-class-out cross-validation in binary classification? Why or why not?\n",
"\n",
"3. We said that class imbalance alone has little effect on model training, as long as the testing distribution matches the training distribution. However, can you make an argument using the bias-variance decomposition about why this may not be true with a small dataset?\n",
"\n",
"4. Compute the area under the curve of an ROC curve using numerical trapezoidal integration.\n",
"\n",
"### Complete Model\n",
"\n",
"Do your best to create a binary classifier for this dataset with regularization and any other methods we learned in this chapter and the previous ones. What is the best area under the curve you can achieve?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "F7JSY86NDtwI"
},
"source": [
"## Cited References\n",
"\n",
"```{bibliography}\n",
":style: unsrtalpha\n",
":filter: docname in docnames\n",
"```"
]
}
],
"metadata": {
"celltoolbar": "Tags",
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 0
}