{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Explaining Predictions\n",
"\n",
"Neural network predictions are not interpretable in general. In this chapter, we explore how to explain predictions. This is part of the broader topic of explainable AI (XAI). These explanations should help us understand why particular predictions are made. This is a critical topic because being able to understand model predictions is justified from a practical, theoretical, and increasingly a regulatory stand-point. It is practical because it has been shown that people are more likely to use predictions of a model if they can understand the rationale {cite}`lee2004trust`. Another practical concern is that correctly implementing methods is much easier when one can understand how a model arrived at a prediction. A theoretical justification for transparency is that it can help identify incompleteness in model domains (i.e., covariate shift){cite}`doshi2017towards`. It is now becoming a compliance problem because both the European Union {cite}`goodman2017european` and the G20 {cite}`Development2019` have recently adopted guidelines that recommend or require explanations for machine predictions. The US and EU are also considering going further with more [strict draft legislation](https://digital-strategy.ec.europa.eu/en/library/proposal-regulation-laying-down-harmonised-rules-artificial-intelligence-artificial-intelligence) and a so-called White House AI Bill of Rights {cite}`blumenthal2022ai`.\n",
"\n",
"\n",
"```{admonition} Audience & Objectives\n",
"This chapter builds on {doc}`layers` and {doc}`NLP`. It also assumes a good knowledge of probability theory, including conditional probabilities. You can read [my notes](https://raw.githubusercontent.com/whitead/numerical_stats/master/unit_2/lectures/lecture_3.pdf) or any introductory probability text to get an overview. After completing this chapter, you should be able to \n",
"\n",
" * Justify why explanations are important\n",
" * Distinguish between justification, interpretation, and explanation\n",
" * Compute feature importance and Shapley values\n",
" * Define a counterfactual and compute them\n",
" * Know which models are interpretable and how to fit interpretable surrogate models \n",
"```\n",
"\n",
"A famous example on the need for explainable AI is found in Caruana et al.{cite}`caruana2015intelligible` who built an ML predictor to assess mortality risk of patients in the ER with pneumonia. The idea is that patients with pneumonia are screened with this tool and it helps doctors know which patients are more at risk of dying. It was found to be quite accurate. When the interpretation of its predictions were examined though, the reasoning was medically insane. The model surprisingly suggested patients with asthma (called asthmatics) have a reduced mortality risk when coming to the ER with pneumonia. Asthma, a condition which makes it difficult to breathe, was found to *make pneumonia patients less likely to die.* This was incidental; asthmatics are actually more at risk of dying from pneumonia but doctors are acutely aware of this and are thus more aggressive and attentive with them. Thanks to the increase care and attention from doctors, there are fewer mortalities. From an empirical standpoint, the model predictions are correct. However if the model were put into practice, it could have cost lives by incorrectly characterizing asthmatics as low mortality risk. Luckily the interpretability of their model helped researchers identify this problem. Thus, we can see that interpretation should always be a step in the construction of predictive models. \n",
"\n",
"## What is an explanation?\n",
"\n",
"We'll use the definition of explanation from Miller {cite}`miller2019explanation`. Miller distinguishes between interpretability, justification, and explanation with the following definitions:\n",
"\n",
"* **interpretability** \"the degree to which an observer can understand the cause of a decision\". Miller considers this synonymous with explainability. *This is generally a property of a model.*\n",
"* **justification** evidence or explanation of why a decision is good, like testing error or accuracy of a model. *This is a property of a model.*\n",
"* **explanation** explanations are a presentation of information intended for humans that give the context and cause for an outcome. These are the major focus of this chapter. *This is generally something extra we generate and not a property of a model.*\n",
"\n",
"We will dig deeper into what constitutes an *explanation*, but note an explanation is different than justifying a prediction. Justification is what we've focused on previously: empirical evidence for why we should believe model predictions are accurate. An explanation provides a *cause* for the prediction. Ultimately, explanations are intended to be understood by humans.\n",
"\n",
"Deep learning alone is a black box modeling technique. It is not interpretable or explainable. Examining the weights or model equation provides little insight into why predictions are made. Thus, interpretability is an extra task and means adding an explanation to predictions from the model. This is a challenge because of both the black box nature of deep learning and because there is no consensus on what exactly constitutes an \"explanation\" for model predictions {cite}`doshi2017towards`. For some, interpretability means having a natural language explanation justifying each prediction. For others, it can be simply showing which features contributed most to the prediction. \n",
"\n",
"There are two broad approaches to interpretation of ML models: post hoc interpretation via explanations and self-explaining models {cite}`Murdoch2019`. Self-explaining models are constructed so that an expert can view output of the model and connect it with the features through reasoning. They are inherently interpretable. Self-explaining models are highly dependent on the task model{cite}`montavon2018methods`. A familiar example would be a physics based simulation like molecular dynamics or a single-point quantum energy calculation. You can examine the molecular dynamics trajectory, look at output numbers, and an expert can explain why, for example, the simulation predicts a drug molecule will bind to a protein. \n",
"\n",
"It may seem like self-explaining models would be useless for deep learning interpretation. However, we will see later that we can create a **surrogate model** (sometimes **proxy model**) that is self-explaining and train it to agree with the deep learning model. Why will this training burden be any less than just using the surrogate model from the beginning? We can generate an infinite amount of training data because our trained neural network can label arbitrary points. You can also construct deep learning models which have self-explaining features in them, like attention {cite}`bahdanau2014neural`. This allows you to connect the input features to the prediction based on attention. There is also work within machine learning called **symbolic regression**, which tries to construct self-explaining models by working with mathematical equations that can be directly interpreted{cite}`ansari2021iterative,billard2000regression,udrescu2020ai`. Symbolic regression is then used to generate the surrogate model{cite}`cranmer2020discovering`.\n",
"\n",
"Post hoc interpretation by creating explanations can be approached in a number of ways, but the most common are training data importance, feature importance, and counterfactual explanations{cite}`wellawatte_seshadri_white_2021,ribeiro2016should,ribeiro2016model,wachter2017counterfactual`. An example of a post hoc interpretation based on data importance is identifying the most influential training data to explain a prediction {cite}`koh2017understanding`. It is perhaps arguable if this gives an *explanation*, but it certainly helps understand which data is relevant for a prediction. Feature importance is probably the most common XAI approach and frequently appears in computer vision research where the pixels most important for the class of an image are highlighted. \n",
"\n",
"Counterfactual explanations are an emerging post hoc interpretation method. Counterfactuals are new data point that serve as an explanation. A counterfactual gives insight into how important and sensitive the features are. An example might be in a model that recommends giving a loan. A model could produce the following counterfactual explanation (from {cite}`wachter2017counterfactual`):\n",
"\n",
"> You were denied a loan based on your annual income, zip code, and assets. If \n",
"> your annual income had been $45,000, you would have been offered a loan.\n",
"\n",
"The second sentence is the conuterfactual and shows how the features could be changed to affect the model outcome. Counterfactuals provide a nice balance of complexity and explanatory power.\n",
"\n",
"This was a brief overview of large field of XAI. You can find a recent review of interpretable deep learning in Samek et al. {cite}`9369420` and Christopher Molnar has a [broad online book](https://christophm.github.io/interpretable-ml-book/) about interpretable machine learning, including deep learning {cite}`molnar2019`. Prediction error and confidence in predictions are not covered here, since they are more about justification, but see the methods from {doc}`../ml/regression` which apply. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Feature Importance\n",
"\n",
"Feature importance is the most straightforward and common method of interpreting a machine learning model. The output of feature importance is a ranking or numerical values for each feature, typically for a single prediction. If you are trying to understand the feature importance across the whole model, this is called **global** feature importance and **local** for a single prediction. Global feature importance and global interpretability is relatively rare because accurate deep learning models change which features are important in different regions of feature space.\n",
"\n",
"Let's start with a linear model to see feature importance:\n",
"\n",
"\\begin{equation}\n",
"\\hat{y} = \\vec{w}\\vec{x} + b \n",
"\\end{equation}\n",
"\n",
"where $\\vec{x}$ is our feature vector. A simple way to assess feature importance is to simply look at the weight value $w_i$ for a particular feature $x_i$. The weight $w_i$ shows how much $\\hat{y}$ would change if $x_i$ is increased by 1, while all other features are constant. If the magnitude of our features are comparable, then this would be a reasonable way to rank features. However, if our features have units, this approach is sensitive to unit choices and relative magnitude of features. For example if our temperature was changed from Celsius to Fahrenheit, a 1 degree increase will have a smaller effect. \n",
"\n",
"To remove the effect of feature magnitude and units, a slightly better way to assess feature importance is to divide $w_i$ by the **standard error** in the feature values. Recall that standard error is just the ratio of sum of squared error in predicted value divided by the total deviation in the feature. Standard error is a ratio of prediction accuracy to feature variance. $w_i$ divided by standard error is called the $t$-statistic because it can be compared with the $t$-distribution for assessing feature importance.\n",
"\n",
"\\begin{equation}\n",
"t_i = \\frac{w_i}{S_{w_i}},\\; S^2_{w_i} = \\frac{1}{N - D}\\sum_j \\frac{\\left(\\hat{y}_j - y_j\\right)^2}{\\left(x_{ij} - \\bar{x}_i\\right)^2}\n",
"\\end{equation}\n",
"\n",
"where $N$ is the number of examples, $D$ is the number of features, and $\\bar{x}_i$ is the average value of the $i$th feature. The $t_i$ value can be used to rank features and it can be used for a hypothesis test: if $P(t > t_i) < 0.05$ then that feature is significant, where $P(t)$ is Student's $t$-distribution. Note that a feature's significance is sensitive to which features are present in a model; if you add new features some may become redundant.\n",
"\n",
"If we move to a nonlinear learned function $\\hat{f}(\\vec{x})$, we must compute how the prediction changes if a feature value increases by 1 via the derivative approximation:\n",
"\n",
"$$\n",
"\\frac{\\Delta \\hat{f}(\\vec{x})}{\\Delta x_i} \\approx \\frac{\\partial \\hat{f}(\\vec{x})}{\\partial x_i}\n",
"$$\n",
"\n",
"so a change by 1 is\n",
"\n",
"\\begin{equation}\n",
"\\Delta \\hat{f}(\\vec{x}) \\approx \\frac{\\partial \\hat{f}(\\vec{x})}{\\partial x_i}.\n",
"\\end{equation}\n",
"\n",
"\n",
"In practice, we make a slight variation on this equation -- instead of a Taylor series centered at 0 approximating this change, we center at some other root (point where the function is 0). This \"grounds\" the series at the decision boundary (a root) and then you can view the partials as \"pushing\" the predicted class away or towards the decision boundary. Another way to think about this is that we use the first-order terms of the Taylor series to build a linear model. Then we just apply what we did above to that linear model and use the coefficients as the \"importance\" of features. Specifically, we use this surrogate function for $\\hat{f}(\\vec{x})$:\n",
"\n",
"\\begin{equation}\n",
"\\require{cancel}\n",
"\\hat{f}(\\vec{x}) \\approx \\cancelto{0}{f(\\vec{x}')} + \\nabla\\hat{f}(\\vec{x}')\\cdot\\left(\\vec{x} - \\vec{x}'\\right)\n",
"\\end{equation}\n",
"\n",
"where $\\vec{x}'$ is the root of $\\hat{f}(\\vec{x})$. In practice people may choose the trivial root $\\vec{x}' = \\vec{0}$, however a nearby root is ideal. This root is often called the **baseline** input. Note that as opposed to the linear example above, we consider the product of the partial $\\frac{\\partial \\hat{f}(\\vec{x})}{\\partial x_i}$ and the increase above baseline $(x_i - x_i')$."
]
},
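{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small numerical check of the surrogate above, consider a hypothetical toy model $\\hat{f}(\\vec{x}) = \\tanh(\\vec{w}\\cdot\\vec{x})$, which has a root at $\\vec{x}' = \\vec{0}$ and an analytic gradient. The sketch below (NumPy, with made-up weights) computes the importance of each feature as the gradient at the baseline times the increase above baseline:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"w = np.array([1.5, -2.0, 0.5])  # hypothetical model weights\n",
"\n",
"\n",
"def f(x):\n",
"    return np.tanh(x @ w)\n",
"\n",
"\n",
"def grad_f(x):\n",
"    # analytic gradient of tanh(w . x)\n",
"    return (1 - np.tanh(x @ w) ** 2) * w\n",
"\n",
"\n",
"def taylor_importance(grad_fn, x, baseline):\n",
"    # first-order Taylor term: gradient at the baseline times (x - x')\n",
"    return grad_fn(baseline) * (x - baseline)\n",
"\n",
"\n",
"x = np.array([0.2, 0.1, -0.4])\n",
"baseline = np.zeros(3)  # f(0) = 0, so this baseline is a root\n",
"importance = taylor_importance(grad_f, x, baseline)\n",
"print(importance, importance.sum(), f(x))\n",
"```\n",
"\n",
"Because $\\vec{x}$ is close to the root, the importances sum to nearly $\\hat{f}(\\vec{x})$; farther from the root, the first-order approximation degrades."
]
},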
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Neural Network Feature Importance\n",
"\n",
"In neural networks, the partial derivatives are a poor approximation of the real changes to the output. Small changes to the input can have discontinuous changes (because of nonlinearities like ReLU), making the terms above have little explanatory power. This is called the **shattered gradients** problem {cite}`pmlr-v70-balduzzi17b`. Breaking down each feature separately also misses correlations between features -- which don't exist in a linear model. Thus the derivative approximation works satisfactorily in locally linear models, but not deep neural networks.\n",
"\n",
"There are a variety of techniques that get around the issue of shattered gradients in neural networks. Two popular methods are integrated gradients {cite}`sundararajan2017axiomatic` and SmoothGrad{cite}`smilkov2017smoothgrad`. Integrated gradients creates a path from $\\vec{x}'$ to $\\vec{x}$ and integrates Equation 4 along that path:\n",
"\n",
"\\begin{equation}\n",
"\\textrm{IG}_i = \\left(\\vec{x} - \\vec{x}'\\right) \\int_0^1\\left[\\nabla\\hat{f}\\left(\\vec{x}' + t\\left(\\vec{x} - \\vec{x}'\\right)\\right)\\right]_i\\,dt\n",
"\\end{equation}\n",
"\n",
"where $t$ is some increment along the path such that $\\vec{x}' + t\\left(\\vec{x} - \\vec{x}'\\right) = \\vec{x}'$ when $t = 0$ and $\\vec{x}' + t\\left(\\vec{x} - \\vec{x}'\\right) = \\vec{x}$ when $t = 1$. This gives us the integrated gradient for each feature $i$. The integrated gradients are the importance of each feature, but without the complexity of shattered gradients. There are some nice properties too, like $\\sum_i \\textrm{IG}_i = f(\\vec{x}) - f(\\vec{x}')$ so that the integrated gradients provide a complete partition of the change from the baseline to the prediction{cite}`sundararajan2017axiomatic`.\n",
"\n",
"Implementing integrated gradients is actually relatively simple. You approximate the path integral with a Riemann sum by breaking the path into a set of discrete inputs between the input features $\\vec{x}$ and the baseline $\\vec{x}'$. You compute the gradient of these inputs with the neural network. Then you multiply that by the change in features above baseline: $\\left(\\vec{x} - \\vec{x}'\\right)$.\n",
"\n",
"SmoothGrad is a similar idea to the integrated gradients. Rather than summing up the gradients along a path though, we sum gradients from random points nearby our prediction. The equation is:\n",
"\n",
"\\begin{equation}\n",
"\\textrm{SG}_i = \\sum_j^M\\left[\\nabla\\hat{f}\\left(\\vec{x}' + \\vec{\\epsilon}\\right)\\right]_i\n",
"\\end{equation}\n",
"\n",
"where $M$ is a choice of sample number and $\\vec{\\epsilon}$ is sampled from $D$ zero-mean Guassians {cite}`smilkov2017smoothgrad`. The only change in implementation here is to replace the path with a series of random perturbations.\n",
"\n",
"Beyond these gradient based approaches, Layer-wise Relevance Propagation (LRP) is another popular approach for feature importance analysis in neural networks. LRP works by doing a backwards propogation through the neural network that partitions the output value of one layer to the input features. It \"distributes relevance.\" What is unusual about LRP is that each layer type needs its own implementation. It doesn't rely on the analytic derivative, but instead a Taylor series expansion of the layer equation. There are variants for GNNs and sequence models, so that LRP can be used in most settings in materials and chemistry {cite}`Montavon2019`."
]
},
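{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch, assuming a small bias-free network with one hidden ReLU layer and random hypothetical weights, here is the LRP-$\\epsilon$ rule for dense layers in NumPy: $R_j = \\sum_k \\frac{a_j w_{jk}}{z_k + \\epsilon\\,\\textrm{sign}(z_k)} R_k$, where $a_j$ are the lower layer's activations and $z_k = \\sum_j a_j w_{jk}$. Each upper neuron's relevance $R_k$ is split among the lower neurons in proportion to their contributions. Real LRP implementations use different rules for different layer types; this shows only one rule:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"W1 = rng.normal(size=(4, 8))  # hypothetical input -> hidden weights\n",
"W2 = rng.normal(size=(8, 1))  # hypothetical hidden -> output weights\n",
"\n",
"\n",
"def forward(x):\n",
"    a1 = np.maximum(0, x @ W1)  # hidden ReLU activations\n",
"    return a1, a1 @ W2\n",
"\n",
"\n",
"def lrp_epsilon(a, W, R_upper, eps=1e-6):\n",
"    # pre-activations z_k = sum_j a_j w_jk, stabilized away from zero\n",
"    z = a @ W\n",
"    z = z + eps * np.where(z >= 0, 1.0, -1.0)\n",
"    s = R_upper / z\n",
"    # R_j = a_j * sum_k w_jk * R_k / z_k\n",
"    return a * (W @ s)\n",
"\n",
"\n",
"x = rng.normal(size=4)\n",
"a1, out = forward(x)\n",
"R_hidden = lrp_epsilon(a1, W2, out)  # output relevance is the output itself\n",
"R_input = lrp_epsilon(x, W1, R_hidden)\n",
"print(R_input, R_input.sum(), out[0])\n",
"```\n",
"\n",
"Without biases and with a tiny `eps`, the input relevances sum to (nearly) the network output, which is the conservation property that motivates LRP."
]
},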
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Shapley Values\n",
"\n",
"A model agnostic way to treat feature importance is with **Shapley values.** Shapley values come from game theory and are a solution to how to pay a coalition of cooperating players according to their contributions. Imagine each feature is a player and we would like to \"pay\" them according to their contribution to the predicted value. A Shapley value $\\phi_i(x)$ is the pay to feature $i$ at instance $x$. We break-up the predicted function value $\\hat{f}(x)$ into the Shapley values so that the sum of the pay is the function value: $\\sum_i \\phi_i(x) = \\hat{f}(x)$. This means you can interpret the Shapley value of a feature as its numerical contribution to the prediction. Shapley values are powerful because their calculation is agnostic to the model, they partition the predicted value among each feature, and they have other attributes that we would desire in an explanation of a prediction (symmetry, linearity, permutation invariant, etc.). Their disadvantage are that exact computation is combinatorial with respect to feature number and they have no sparsity, making them less helpful as feature number grows. Most methods we discuss here also have no sparsity. You can always force your model to be sparse to achieve sparse explanations, like with L1 regularization (see {doc}`layers`).\n",
"\n",
"Shapley values are computed as\n",
"\n",
"\\begin{equation}\n",
"\\phi_i(x) = \\frac{1}{Z}\\sum_{S \\in N \\backslash x_i}v(S\\cup x_i) - v(S)\n",
"\\end{equation}\n",
"$$\n",
"Z = \\frac{|S|!\\left(N - |S| - 1\\right)!}{N!}\n",
"$$\n",
"\n",
"where $S \\in N \\backslash x_i$ means all sets of features that exclude feature $x_i$, $S\\cup x_i$ means putting back feature $x_i$ into the set, and $v(S)$ is the value of $\\hat{f}(x)$ using only the features included in $S$, and $Z$ is a normalization value. The formula can be interpreted as the mean of all possible differences in $\\hat{f}$ formed by adding/removing feature $i$. \n",
"\n",
"One immediate concern though is how can we \"remove\" feature $i$ from a model equation? We marginalize out feature $i$. Recall a marginal is a way to integrate out a random variable $P(x) = \\int\\, P(x,y)\\,dy$. That integrates over all possible $x$ values. Marginalization can be used on functions of random variables, which obviously are also random variables, by taking an expectation: $E_y[f | X = x] = \\int\\,f(X=x,y)P(X=x,y)\\, dy$. I've emphasized that the random variable $X$ is fixed in the integral and thus $E_y[f]$ is a function of $x$. $y$ is removed by computing the expected value of $f(x,y)$ where $x$ is fixed (the function argument). We're essentially replacing $f(x,y)$ with a new function $E_y[f]$ that is the average of all possible $y$ values. I'm over-explaining this though, it's quite intuitive once you see the code below. The other detail is that *value* is the change relative to the average of $\\hat{f}$. You can typically ignore this extra term - it cancels, but I include it for completeness. Thus the value equation becomes {cite}`vstrumbelj2014explaining`:\n",
"\n",
"\\begin{equation}\n",
"v(x_i) = \\int\\,f(x_0, x_1, \\ldots, x_i,\\ldots, x_N)P(x_0, x_1, \\ldots, x_i,\\ldots, x_N)\\, dx_i - E\\left[\\hat{f}(\\vec{x})\\right]\n",
"\\end{equation}\n",
"\n",
"How do we compute the marginal $\\int\\,f(x_0, x_1, \\ldots, x_i,\\ldots, x_N)P(x_0, x_1, \\ldots, x_i,\\ldots, x_N)\\, dx_i$? We do not have a known probability distribution $P(\\vec{x})$. We can sample from $P(\\vec{x})$ by considering our data as an **empirical distribution**. That is, we can sample from $P(\\vec{x})$ by sampling data points. There is a little bit of complexity here because we need to sample the $\\vec{x}$'s jointly, we cannot just mix together individual features randomly because there are correlations between features that will be removed. \n",
"\n",
"Strumbelj et al. {cite}`vstrumbelj2014explaining` showed that we can directly estimate the $i$th Shapley value with:\n",
"\n",
"\\begin{equation}\n",
"\\phi_i(\\vec{x}) = \\frac{1}{M}\\sum^M \\hat{f}\\left(\\vec{z}_{+i}\\right) - \\hat{f}\\left(\\vec{z}_{-i}\\right)\n",
"\\end{equation}\n",
"\n",
"where $\\vec{z}$ is a \"chimera\" example constructed from the real example $\\vec{x}$ and a randomly drawn example $\\vec{x}'$. We randomly select from $\\vec{x}$ and $\\vec{x}'$ to construct $\\vec{z}$, except $\\vec{z}_{+i}$ specifically has the $i$th feature from the example $\\vec{x}$ and $\\vec{z}_{-i}$ has the $i$th feature from the random example $\\vec{x}'$. $M$ is chosen large enough to get a good sample for this value. {cite}`vstrumbelj2014explaining` gives guidance on choosing $M$, but basically as large $M$ as computationally feasible reasonable. One change in this approximation though is that we end-up with an explicit term for the expectation (sometimes denoted $\\phi_0$) so that our \"completeness\" equation is:\n",
"\n",
"\\begin{equation}\n",
"\\sum_i \\phi_i(\\vec{x}) = \\hat{f}(\\vec{x}) - E[\\hat{f}(\\vec{x})]\n",
"\\end{equation}\n",
"\n",
"Or if you explicitly include expectation as $\\phi_0$, which is independent of $\\vec{x}$\n",
"\n",
"\\begin{equation}\n",
"\\phi_0 + \\sum_{i=1} \\phi_i(\\vec{x}) = \\hat{f}(\\vec{x})\n",
"\\end{equation}\n",
"\n",
"```{margin}\n",
"Marginalizing features *is not* the same as replacing features with their average.\n",
"```\n",
"With this efficient approximation method, the strong theory, and independence of model choice, Shapley values are an excellent choice for describing feature importance for predictions."
]
},
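{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of the sampling estimator follows. It assumes a hypothetical toy model and random `background` data (defined here so the block stands alone). Each pair of chimeras is built from a random background row and a random feature ordering, where features before $i$ in the ordering come from $\\vec{x}$. The value of `M` is an illustrative choice:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"background = rng.normal(size=(500, 3))  # hypothetical data standing in for P(x)\n",
"\n",
"\n",
"def f(X):\n",
"    # hypothetical nonlinear model with an interaction between features 1 and 2\n",
"    return X[:, 0] + 2 * X[:, 1] * X[:, 2]\n",
"\n",
"\n",
"def sampled_shapley(x, i, M=2000):\n",
"    n = len(x)\n",
"    z_plus = np.empty((M, n))\n",
"    z_minus = np.empty((M, n))\n",
"    for m in range(M):\n",
"        x_rand = background[rng.integers(len(background))]\n",
"        order = rng.permutation(n)\n",
"        # features that come before i in the random ordering are taken from x\n",
"        before_i = order[: np.where(order == i)[0][0]]\n",
"        z = x_rand.copy()\n",
"        z[before_i] = x[before_i]\n",
"        z_minus[m] = z  # feature i from the random example\n",
"        z[i] = x[i]\n",
"        z_plus[m] = z  # feature i from x\n",
"    return np.mean(f(z_plus) - f(z_minus))\n",
"\n",
"\n",
"x = np.array([1.0, 0.5, -1.0])\n",
"phi = np.array([sampled_shapley(x, i) for i in range(len(x))])\n",
"phi_0 = f(background).mean()  # the expectation term\n",
"print(phi, phi_0 + phi.sum(), f(x[None])[0])\n",
"```\n",
"\n",
"Because this is a Monte Carlo estimate, $\\phi_0 + \\sum_i \\phi_i$ only approximately matches $\\hat{f}(\\vec{x})$, and the agreement improves as `M` grows."
]
},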
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running This Notebook\n",
"\n",
"\n",
"Click the above to launch this page as an interactive Google Colab. See details below on installing packages.\n",
"\n",
"````{tip} My title\n",
":class: dropdown\n",
"To install packages, execute this code in a new cell. \n",
"\n",
"```\n",
"!pip install dmol-book\n",
"```\n",
"\n",
"If you find install problems, you can get the latest working versions of packages used in [this book here](https://github.com/whitead/dmol-book/blob/main/package/setup.py)\n",
"\n",
"````"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import haiku as hk\n",
"import jax\n",
"import tensorflow as tf\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import urllib\n",
"from functools import partial\n",
"from jax.example_libraries import optimizers as opt\n",
"import dmol"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(0)\n",
"tf.random.set_seed(0)\n",
"\n",
"ALPHABET = [\n",
" \"-\",\n",
" \"A\",\n",
" \"R\",\n",
" \"N\",\n",
" \"D\",\n",
" \"C\",\n",
" \"Q\",\n",
" \"E\",\n",
" \"G\",\n",
" \"H\",\n",
" \"I\",\n",
" \"L\",\n",
" \"K\",\n",
" \"M\",\n",
" \"F\",\n",
" \"P\",\n",
" \"S\",\n",
" \"T\",\n",
" \"W\",\n",
" \"Y\",\n",
" \"V\",\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now define a few functions we'll need to convert between amino acid sequence and one-hot vectors. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"def seq2array(seq, L=200):\n",
" return np.pad(list(map(ALPHABET.index, seq)), (0, L - len(seq))).reshape(1, -1)\n",
"\n",
"\n",
"def array2oh(a):\n",
" a = np.squeeze(a)\n",
" o = np.zeros((len(a), 21))\n",
" o[np.arange(len(a)), a] = 1\n",
" return o.astype(np.float32).reshape(1, -1, 21)\n",
"\n",
"\n",
"urllib.request.urlretrieve(\n",
" \"https://github.com/whitead/dmol-book/raw/main/data/hemolytic.npz\",\n",
" \"hemolytic.npz\",\n",
")\n",
"with np.load(\"hemolytic.npz\", \"rb\") as r:\n",
" pos_data, neg_data = r[\"positives\"], r[\"negatives\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Feature Importance Example\n",
"\n",
"Let's see an example of these feature importance methods on a peptide prediction task to predict if a peptide will kill red blood cells (hemolytic). This is similar to the solubility prediction example from {doc}`layers`. The data is from {cite}`barrett2020investigating`. The model takes in peptides sequences (e.g., `DDFRD`) and predicts the probability that the peptide is hemolytic. The goal of the feature importance method here will be to identify which amino acids matter most for the hemolytic activity. The hidden-cell below loads and processes the data into a dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"# create labels and stich it all into one\n",
"# tensor\n",
"labels = np.concatenate(\n",
" (\n",
" np.ones((pos_data.shape[0], 1), dtype=pos_data.dtype),\n",
" np.zeros((neg_data.shape[0], 1), dtype=pos_data.dtype),\n",
" ),\n",
" axis=0,\n",
")\n",
"features = np.concatenate((pos_data, neg_data), axis=0)\n",
"# we now need to shuffle before creating TF dataset\n",
"# so that our train/test/val splits are random\n",
"i = np.arange(len(labels))\n",
"np.random.shuffle(i)\n",
"labels = labels[i]\n",
"features = features[i]\n",
"L = pos_data.shape[-2]\n",
"\n",
"# need to add token for empty amino acid\n",
"# dataset just has all zeros currently\n",
"features = np.concatenate((np.zeros((features.shape[0], L, 1)), features), axis=-1)\n",
"features[np.sum(features, -1) == 0, 0] = 1.0\n",
"\n",
"batch_size = 16\n",
"full_data = tf.data.Dataset.from_tensor_slices((features.astype(np.float32), labels))\n",
"\n",
"# now split into val, test, train\n",
"N = pos_data.shape[0] + neg_data.shape[0]\n",
"split = int(0.1 * N)\n",
"test_data = full_data.take(split).batch(batch_size)\n",
"nontest = full_data.skip(split)\n",
"val_data, train_data = nontest.take(split).batch(batch_size), nontest.skip(\n",
" split\n",
").shuffle(1000).batch(batch_size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We rebuild the convolution model in Jax (using [Haiku](https://github.com/deepmind/dm-haiku)) to make working with gradients a bit easier. We also make a few changes to the model -- we pass in the sequence length and amino acid fractions as extra information in addition to the convolutions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def binary_cross_entropy(logits, y):\n",
" \"\"\"Binary cross entropy without sigmoid. Works with logits directly\"\"\"\n",
" return (\n",
" jnp.clip(logits, 0, None) - logits * y + jnp.log(1 + jnp.exp(-jnp.abs(logits)))\n",
" )\n",
"\n",
"\n",
"def model_fn(x):\n",
" # get fractions, excluding skip character\n",
" aa_fracs = jnp.mean(x, axis=1)[:, 1:]\n",
" # compute convolutions/poolings\n",
" mask = jnp.sum(x[..., 1:], axis=-1, keepdims=True)\n",
" for kernel, pool in zip([5, 3, 3], [4, 2, 2]):\n",
" x = hk.Conv1D(16, kernel)(x) * mask\n",
" x = jax.nn.tanh(x)\n",
" x = hk.MaxPool(pool, pool, \"VALID\")(x)\n",
" mask = hk.MaxPool(pool, pool, \"VALID\")(mask)\n",
" # combine fractions, length, and convolution ouputs\n",
" x = jnp.concatenate((hk.Flatten()(x), aa_fracs, jnp.sum(mask, axis=1)), axis=1)\n",
" # dense layers. no bias, so zeros give P=0.5\n",
" logits = hk.Sequential(\n",
" [\n",
" hk.Linear(256, with_bias=False),\n",
" jax.nn.tanh,\n",
" hk.Linear(64, with_bias=False),\n",
" jax.nn.tanh,\n",
" hk.Linear(1, with_bias=False),\n",
" ]\n",
" )(x)\n",
" return logits\n",
"\n",
"\n",
"model = hk.without_apply_rng(hk.transform(model_fn))\n",
"\n",
"\n",
"def loss_fn(params, x, y):\n",
" logits = model.apply(params, x)\n",
" return jnp.mean(binary_cross_entropy(logits, y))\n",
"\n",
"\n",
"@jax.jit\n",
"def hemolytic_prob(params, x):\n",
" logits = model.apply(params, x)\n",
" return jax.nn.sigmoid(jnp.squeeze(logits))\n",
"\n",
"\n",
"@jax.jit\n",
"def accuracy_fn(params, x, y):\n",
" logits = model.apply(params, x)\n",
" return jnp.mean((logits >= 0) * y + (logits < 0) * (1 - y))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rng = jax.random.PRNGKey(0)\n",
"xi, yi = features[:batch_size], labels[:batch_size]\n",
"params = model.init(rng, xi)\n",
"\n",
"opt_init, opt_update, get_params = opt.adam(1e-2)\n",
"opt_state = opt_init(params)\n",
"\n",
"\n",
"@jax.jit\n",
"def update(step, opt_state, x, y):\n",
" value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state), x, y)\n",
" opt_state = opt_update(step, grads, opt_state)\n",
" return value, opt_state"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"epochs = 32\n",
"for e in range(epochs):\n",
" avg_v = 0\n",
" for i, (xi, yi) in enumerate(train_data):\n",
" v, opt_state = update(i, opt_state, xi.numpy(), yi.numpy())\n",
" avg_v += v\n",
"opt_params = get_params(opt_state)\n",
"\n",
"\n",
"def predict(x):\n",
" return jnp.squeeze(model.apply(opt_params, x))\n",
"\n",
"\n",
"def predict_prob(x):\n",
" return hemolytic_prob(opt_params, x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you're having trouble following the code, that's OK! The goal of this chapter is to show how to get explanations of a model, not necessarily how to build the model. So focus on the next few lines where I show how to use the model to get predictions and explain them. The model is called via `predict(x)` for logits or `predict_prob` for probability."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{margin} Sequence Models\n",
"Review {doc}`NLP` to refresh ideas about\n",
"one-hots and sequence models.\n",
"```\n",
"\n",
"Let's try an amino acid sequence, a peptide, to get a feel for the model. The model outputs logits (logarithm of odds), which we put through a sigmoid to get probabilities. The peptides must be converted from a sequence to a matrix of one-hot column vectors. We'll try two known sequences: Q is known to be common in hemolytic residues and the second sequence is poly-G, which is the simplest amino acid."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"s = \"QQQQQ\"\n",
"sm = array2oh(seq2array(s))\n",
"p = predict_prob(sm)\n",
"print(f\"Probability {s} of being hemolytic {p:.2f}\")\n",
"\n",
"s = \"GGGGG\"\n",
"sm = array2oh(seq2array(s))\n",
"p = predict_prob(sm)\n",
"print(f\"Probability {s} of being hemolytic {p:.2f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It looks reasonable -- the model matches our intuition about these two sequences.\n",
"\n",
"Now we compute the accuracy of our model, which is quite good."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"acc = []\n",
"for xi, yi in test_data:\n",
" acc.append(accuracy_fn(opt_params, xi.numpy(), yi.numpy()))\n",
"print(jnp.mean(np.array(acc)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Gradients\n",
"\n",
"Now we can start examining *why* a particular sequence is hemolytic! We'll begin by computing the gradients with respect to the input -- the naive approach that is susceptible to shattered gradients. This computation is also a component of integrated gradients and SmoothGrad, so it is not wasted effort. We will use a more complex peptide sequence that is known to be hemolytic to get a more interesting analysis."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"def plot_grad(g, s, ax=None):\n",
" # g = np.array(g)\n",
" if ax is None:\n",
" plt.figure()\n",
" ax = plt.gca()\n",
" if len(g.shape) == 3:\n",
" h = g[0, np.arange(len(s)), list(map(ALPHABET.index, s))]\n",
" else:\n",
" h = g\n",
" ax.bar(np.arange(len(s)), height=h)\n",
" ax.set_xticks(range(len(s)))\n",
" ax.set_xticklabels(s)\n",
" ax.set_xlabel(\"Amino Acid $x_i$\")\n",
" ax.set_ylabel(r\"Gradient $\\frac{\\partial \\hat{f}(\\vec{x})}{\\partial x_i}$\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"s = \"RAGLQFPVGRLLRRLLRRLLR\"\n",
"sm = array2oh(seq2array(s))\n",
"p = predict_prob(sm)\n",
"print(f\"Probability {s} of being hemolytic {p:.2f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code is quite simple, just a gradient computation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gradient = jax.grad(predict, 0)\n",
"g = gradient(sm)\n",
"plot_grad(g, s)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Remember that the model outputs logits. Positive values of the gradient mean this amino acid is responsible for pushing the hemolytic probability higher, and negative values mean the amino acid is pushing towards non-hemolytic. Interestingly, you can see a strong position dependence on the leucine (L) and arginine (R).\n",
"\n",
"### Integrated Gradients\n",
"\n",
"We'll now implement the integrated gradients method. We go through three basic steps:\n",
"\n",
"1. Create an array of inputs going from baseline to input peptide\n",
"2. Evaluate gradient on each input\n",
"3. Average the gradients (a Riemann sum approximation of the path integral) and multiply by the difference between the peptide and the baseline\n",
"\n",
"The baseline for us is all zeros -- which gives a probability of 0.5 (logits = 0, a model root). This baseline is exactly on the decision boundary. You could use other baselines, like all glycines or all alanines, but they should be at or near a probability of 0.5. You can find a detailed and interactive exploration of the baseline choice in {cite}`sturmfels2020visualizing`."
]
},
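{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why averaging the gradients and multiplying by the input difference works, the following NumPy sketch applies the same three steps to a hypothetical toy function, $f(\\vec{x}) = \\tanh(\\vec{w} \\cdot \\vec{x})$, whose gradient is known analytically. It is not the peptide model; it only illustrates the *completeness* property of integrated gradients: the attributions should sum to approximately $f(\\vec{x}) - f(\\vec{x}')$, up to the error of the Riemann sum.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# hypothetical toy model f(x) = tanh(w . x) with an analytic gradient\n",
"rng = np.random.default_rng(0)\n",
"w = rng.normal(size=5)\n",
"\n",
"\n",
"def f(x):\n",
"    return np.tanh(w @ x)\n",
"\n",
"\n",
"def grad_f(x):\n",
"    # d/dx tanh(w . x) = w * (1 - tanh(w . x)**2)\n",
"    return w * (1.0 - np.tanh(w @ x) ** 2)\n",
"\n",
"\n",
"def integrated_gradients_np(x, baseline, n_steps=1024):\n",
"    # 1. straight-line path of inputs from the baseline to x\n",
"    t = np.linspace(0.0, 1.0, n_steps)[:, None]\n",
"    path = baseline + t * (x - baseline)\n",
"    # 2. gradient at every point on the path\n",
"    grads = np.array([grad_f(p) for p in path])\n",
"    # 3. average gradient (Riemann sum / n_steps) times (x - baseline)\n",
"    return grads.mean(axis=0) * (x - baseline)\n",
"\n",
"\n",
"x = rng.normal(size=5)\n",
"baseline = np.zeros(5)\n",
"ig = integrated_gradients_np(x, baseline)\n",
"# completeness: attributions sum to approximately f(x) - f(baseline)\n",
"print(ig.sum(), f(x) - f(baseline))\n",
"```"
]
},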
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def integrated_gradients(sm, N):\n",
" baseline = jnp.zeros((1, L, 21))\n",
" t = jnp.linspace(0, 1, N).reshape(-1, 1, 1)\n",
" path = baseline * (1 - t) + sm * t\n",
"\n",
" def get_grad(pi):\n",
" # compute gradient\n",
" # add/remove batch axes\n",
" return gradient(pi[jnp.newaxis, ...])[0]\n",
"\n",
" gs = jax.vmap(get_grad)(path)\n",
" # sum pieces (Riemann sum), multiply by (x - x')\n",
" ig = jnp.mean(gs, axis=0, keepdims=True) * (sm - baseline)\n",
" return ig\n",
"\n",
"\n",
"ig = integrated_gradients(sm, 1024)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_grad(ig, s)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We see that the position dependence has become more pronounced, with arginine being very sensitive to position. Relatively little has qualitatively changed between this and the vanilla gradients.\n",
"\n",
"### SmoothGrad\n",
"\n",
"To do SmoothGrad, our steps are almost identical:\n",
"\n",
"1. Create an array of inputs that are random perturbations of the input peptide\n",
"2. Evaluate gradient on each input\n",
"3. Compute the mean of the gradients\n",
"\n",
"There is one additional hyperparameter, $\\sigma$, which in principle should be as small as possible while still causing the model output to change.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def smooth_gradients(sm, N, rng, sigma=1e-3):\n",
"    # add Gaussian noise to N copies of the one-hot input\n",
"    t = jax.random.normal(rng, shape=(N, sm.shape[1], sm.shape[2])) * sigma\n",
"    path = sm + t\n",
"    # remove examples that are negative and force summing to 1\n",
"    path = jnp.clip(path, 0, 1)\n",
"    path /= jnp.sum(path, axis=2, keepdims=True)\n",
"\n",
"    def get_grad(pi):\n",
"        # compute gradient\n",
"        # add/remove batch axes\n",
"        return gradient(pi[jnp.newaxis, ...])[0]\n",
"\n",
"    gs = jax.vmap(get_grad)(path)\n",
"    # mean over the noisy samples\n",
"    sg = jnp.mean(gs, axis=0, keepdims=True)\n",
"    return sg\n",
"\n",
"\n",
"sg = smooth_gradients(sm, 1024, jax.random.PRNGKey(0))\n",
"plot_grad(sg, s)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It looks remarkably similar to the vanilla gradient setting -- probably because our 1D input/shallow network is not as sensitive to shattered gradients."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Shapley Value\n",
"\n",
"Now we will approximate the Shapley values for each feature using Equation 10.9. The Shapley value computation is different from the previous approaches because it does not require gradients. The basic algorithm is:\n",
"\n",
"1. Select a random point $x'$ from the training data\n",
"2. Create a point $z$ by combining features from $x$ and $x'$\n",
"3. Compute the change in the predicted function when feature $i$ of $z$ comes from $x$ instead of $x'$\n",
"\n",
"One efficiency change we make is to avoid modifying the padding of the sequence, which prevents the sampling from exploring longer sequences."
]
},
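{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Monte Carlo estimate below samples coalitions at random. For intuition, here is a minimal NumPy sketch that instead computes Shapley values *exactly* by enumerating every coalition, which is only feasible for a handful of features. As a simplifying assumption, features outside a coalition are set to a fixed baseline instead of a random $x'$, so the values sum to $\\hat{f}(x) - \\hat{f}(\\textrm{baseline})$ rather than to the model output minus its expected value. The `toy_model` function is hypothetical.\n",
"\n",
"```python\n",
"import itertools\n",
"import math\n",
"\n",
"import numpy as np\n",
"\n",
"\n",
"def exact_shapley(f, x, baseline):\n",
"    # exact Shapley values by enumerating every coalition S of the other features\n",
"    n = len(x)\n",
"    phi = np.zeros(n)\n",
"\n",
"    def v(S):\n",
"        # value function (assumed): features in S come from x, the rest from baseline\n",
"        z = baseline.copy()\n",
"        z[list(S)] = x[list(S)]\n",
"        return f(z)\n",
"\n",
"    for i in range(n):\n",
"        others = [j for j in range(n) if j != i]\n",
"        for k in range(n):\n",
"            for S in itertools.combinations(others, k):\n",
"                weight = math.factorial(k) * math.factorial(n - k - 1) / math.factorial(n)\n",
"                phi[i] += weight * (v(S + (i,)) - v(S))\n",
"    return phi\n",
"\n",
"\n",
"def toy_model(z):\n",
"    # hypothetical model with a linear term and two interaction terms\n",
"    return 2.0 * z[0] + z[1] * z[2] - z[0] * z[1]\n",
"\n",
"\n",
"x = np.array([1.0, 2.0, 3.0])\n",
"baseline = np.zeros(3)\n",
"phi = exact_shapley(toy_model, x, baseline)\n",
"# the Shapley values sum to toy_model(x) - toy_model(baseline)\n",
"print(phi, phi.sum(), toy_model(x) - toy_model(baseline))\n",
"```\n",
"\n",
"Each interaction term is split equally between the two features it involves, which is why the Shapley values are a *complete* decomposition of the change in output."
]
},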
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def shapley(i, sm, sampled_x, rng, model):\n",
" M, F, *_ = sampled_x.shape\n",
" z_choice = jax.random.bernoulli(rng, shape=(M, F))\n",
" # only swap out features within length of sm\n",
" mask = jnp.sum(sm[..., 1:], -1)\n",
" z_choice *= mask\n",
" z_choice = 1 - z_choice\n",
" # construct with and w/o ith feature\n",
" z_choice = z_choice.at[:, i].set(0.0)\n",
" z_choice_i = z_choice.at[:, i].set(1.0)\n",
" # select them via multiplication\n",
" z = sm * z_choice[..., jnp.newaxis] + sampled_x * (1 - z_choice[..., jnp.newaxis])\n",
" z_i = sm * z_choice_i[..., jnp.newaxis] + sampled_x * (\n",
" 1 - z_choice_i[..., jnp.newaxis]\n",
" )\n",
" v = model(z_i) - model(z)\n",
" return jnp.squeeze(jnp.mean(v, axis=0))\n",
"\n",
"\n",
"# assume data is already shuffled, so just take M\n",
"M = 4096\n",
"sl = len(s)\n",
"sampled_x = train_data.unbatch().batch(M).as_numpy_iterator().next()[0]\n",
"# make batched shapley so we can compute for all features\n",
"bshapley = jax.vmap(shapley, in_axes=(0, None, None, 0, None))\n",
"sv = bshapley(\n",
" jnp.arange(sl),\n",
" sm,\n",
" sampled_x,\n",
" jax.random.split(jax.random.PRNGKey(0), sl),\n",
" predict,\n",
")\n",
"\n",
"# compute global expectation of the model output\n",
"eyhat = 0\n",
"for xi, yi in full_data.batch(M).as_numpy_iterator():\n",
"    eyhat += jnp.sum(predict(xi))\n",
"eyhat /= len(full_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"from myst_nb import glue\n",
"\n",
"val = []\n",
"ms = np.linspace(2, 400, 25)\n",
"for m in ms.astype(np.int32):\n",
" sampled_x = train_data.unbatch().batch(m).as_numpy_iterator().next()[0]\n",
" val.append(\n",
" eyhat\n",
" + jnp.sum(\n",
" bshapley(\n",
" jnp.arange(sl),\n",
" sm,\n",
" sampled_x,\n",
" jax.random.split(jax.random.PRNGKey(0), sl),\n",
" predict,\n",
" )\n",
" )\n",
" )\n",
"plt.plot(ms, val, \"-o\", label=\"Sum of Shapley Values\")\n",
"plt.xlabel(\"Sample Number\")\n",
"plt.ylabel(\"Function Value [logits]\")\n",
"plt.axhline(predict(sm), color=\"C1\", label=r\"$\\hat{f}\\left(\\vec{x}\\right)$\")\n",
"plt.legend()\n",
"plt.tight_layout()\n",
"glue(\"shapley_convg\", plt.gcf(), display=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One useful check on Shapley values is that their sum should equal the model output minus the expected model output across all instances. Because we made approximations to use the equation from {cite}`vstrumbelj2014explaining`, we cannot expect perfect agreement. The two values are:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(np.sum(sv), predict(sm) - eyhat)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"which shows *some* disagreement. This is an effect of the approximation method we're using. We can check that by examining how the number of samples affects the sum of Shapley values.\n",
"\n",
"```{glue:figure} shapley_convg\n",
"---\n",
"name: shapley_convg\n",
"---\n",
"A comparison of the sum of Shapley values and the function value as a function of the number of samples in the Shapley value approximation.\n",
"```\n",
"\n",
"The sum converges slowly. Finally, we can view the individual Shapley values, which are our explanation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_grad(sv, s)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The four methods are shown side-by-side below. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"heights = []\n",
"plt.figure(figsize=(12, 4))\n",
"x = np.arange(len(s))\n",
"for i, (gi, l) in enumerate(zip([g, ig, sg], [\"Gradient\", \"Integrated\", \"Smooth\"])):\n",
" h = gi[0, np.arange(len(s)), list(map(ALPHABET.index, s))]\n",
" plt.bar(x + i / 5 - 1 / 4, h, width=1 / 5, edgecolor=\"black\", label=l)\n",
"plt.bar(x + 3 / 5 - 1 / 4, sv, width=1 / 5, edgecolor=\"black\", label=\"Shapley\")\n",
"ax = plt.gca()\n",
"ax.set_xticks(range(len(s)))\n",
"ax.set_xticklabels(s)\n",
"ax.set_xlabel(\"Amino Acid $x_i$\")\n",
"ax.set_ylabel(r\"Importance [logits]\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As someone who works with peptides, I believe the Shapley values are the most accurate here. I wouldn't expect the pattern of L and R to be that significant, which is what the Shapley values show. Another difference is that the Shapley values do not show the phenylalanine (F) as having a significant effect.\n",
"\n",
"What can we conclude from this information? We could perhaps add an explanation like this: \"The sequence is predicted to be hemolytic primarily because of the glutamine, proline, and arrangement of leucine and arginine.\"\n",
"\n",
"## What is feature importance for?\n",
"\n",
"Feature importance rarely leads to a clear explanation that gives the cause for a prediction or insight that is actionable. The lack of causality can lead us to find meaning in feature explanations where none exists {cite}`chuang2018comment`. Another caveat is to remember that we are explaining the *model*, not the actual chemical system. For example, avoid saying \"Hemolytic activity was caused by the glutamine in position 5.\" Instead: \"Our model predicted hemolytic activity because of glutamine in position 5.\"\n",
"\n",
"An **actionable** explanation is one that shows how to modify the features to affect the outcome, which is similar to saying we know the cause for an outcome. Thus, there is ongoing debate about whether feature importance is an explanation {cite}`lipton2018mythos`. A popular line of work that tries to connect feature importance to human *concepts* is called Quantitative Testing with Concept Activation Vectors (TCAV) {cite}`kim2018interpretability`. I personally have moved away from feature importance for XAI because the explanations are not actionable or causal and can often add confusion."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training Data Importance\n",
"\n",
"Another kind of explanation or interpretation we might desire is *which* training data points contribute most to a prediction. This is a more literal answer to the question: \"Why did my model predict this?\" -- neural networks are a result of training data and thus the answer to why a prediction is made can be traced to training data. Ranking training data for a given prediction helps us understand which training examples are causing the neural network to predict a value. This is like an influence function, $\\mathcal{I}(x_i, x)$, which gives a score of influence for training point $i$ and input $x$. The most straightforward way to compute the influence would be to train the neural network with (i.e., $\\hat{f}(x)$) and without $x_i$ (i.e., $\\hat{f}_{-x_i}(x)$) and define the influence as\n",
"\n",
"\\begin{equation}\n",
"\\mathcal{I}(x_i, x) = \\hat{f}_{-x_i}(x) - \\hat{f}(x)\n",
"\\end{equation}\n",
"\n",
"For example, if a prediction is higher after removing the training point $x_i$ from training, we would say that point has a positive influence. Computing this influence function requires training the model as many times as you have points -- typically computationally infeasible. {cite}`koh2017understanding` show a way to approximate this by looking at infinitesimal changes to the *weights* of each training point. Computing these influence functions requires the Hessian of the training loss with respect to the model parameters, so they are not commonly used. If you're using JAX though, this is simple to do.\n",
"\n",
"```{margin}\n",
"If using a kernel model, the features are the training data. The feature importance methods above, like integrated gradients, then give training data importance.\n",
"```\n",
"\n",
"Training data importance provides an interpretation that is useful for deep learning experts. It tells you which training examples are most influential for a given prediction. This can help troubleshoot issues with data or tracing explanations for spurious predictions. However, a typical user of predictions from a deep learning model will probably be unsatisfied with a ranking of training data as an explanation. "
]
},
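{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of that last point, here is a minimal JAX sketch of the influence approximation of {cite}`koh2017understanding`, applied to a hypothetical ridge regression model rather than a neural network so the result can be checked against brute-force retraining. Removing training point $i$ is approximated as down-weighting its loss by $1/n$, which shifts the parameters by roughly $\\frac{1}{n}H^{-1}\\nabla_\\theta \\ell_i$, where $H$ is the Hessian of the training objective. The data and the `point_loss`, `objective`, and `fit` names are assumptions for this example only.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# hypothetical data for a small ridge regression problem\n",
"n, d, lam = 200, 3, 0.1\n",
"X = jax.random.normal(jax.random.PRNGKey(0), (n, d))\n",
"y = X @ jnp.array([1.0, -2.0, 0.5]) + 0.1 * jax.random.normal(jax.random.PRNGKey(1), (n,))\n",
"\n",
"\n",
"def point_loss(theta, xi, yi):\n",
"    # squared error of a single training point\n",
"    return (xi @ theta - yi) ** 2\n",
"\n",
"\n",
"def objective(theta, weights):\n",
"    # weighted sum of point losses (normalized by n) plus an L2 penalty\n",
"    losses = jax.vmap(point_loss, in_axes=(None, 0, 0))(theta, X, y)\n",
"    return jnp.sum(weights * losses) / n + lam * theta @ theta\n",
"\n",
"\n",
"def fit(weights):\n",
"    # the objective is quadratic, so a single Newton step from zero is exact\n",
"    theta0 = jnp.zeros(d)\n",
"    H = jax.hessian(objective)(theta0, weights)\n",
"    g = jax.grad(objective)(theta0, weights)\n",
"    return theta0 - jnp.linalg.solve(H, g)\n",
"\n",
"\n",
"ones = jnp.ones(n)\n",
"theta_star = fit(ones)\n",
"x_test = jnp.array([0.5, 0.5, 0.5])\n",
"\n",
"# influence approximation for removing training point i\n",
"i = 0\n",
"H = jax.hessian(objective)(theta_star, ones)\n",
"g_i = jax.grad(point_loss)(theta_star, X[i], y[i])\n",
"delta_theta = jnp.linalg.solve(H, g_i) / n\n",
"approx = x_test @ delta_theta  # approximates f_{-x_i}(x) - f(x)\n",
"\n",
"# brute force: refit with point i's weight set to zero\n",
"exact = x_test @ (fit(ones.at[i].set(0.0)) - theta_star)\n",
"print(approx, exact)\n",
"```\n",
"\n",
"The approximation only needs one Hessian at the trained parameters, instead of one retraining per training point; for a neural network the Hessian is much larger, which is why it is usually approximated rather than formed explicitly."
]
},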
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Surrogate Models\n",
"\n",
"One of the more general ideas in interpretability is to fit an interpretable model to a black box model *in the neighborhood of a specific example*. We assume that an interpretable model cannot be fit globally to a black box model -- otherwise we could just use the interpretable model and throw away the black box model. However, if we fit the interpretable model to just a small region around an example of interest, we can provide an explanation through the locally correct interpretable model. We call the interpretable model a **local surrogate model**. Examples of local surrogate models that are inherently interpretable include decision trees, linear models, sparse linear models (for succinct explanations), a Naive Bayes Classifier, etc.\n",
"\n",
"A popular algorithm for this process of fitting a local surrogate model is called Local Interpretable Model-Agnostic Explanations (LIME) {cite}`ribeiro2016should`. LIME fits the local surrogate model in the neighborhood of the example of interest utilizing the loss function that trained the original black box model. The loss function for the local surrogate model is weighted so that we value points closer to the example of interest as we regress the surrogate model. The LIME paper includes sparsifying the surrogate model in its notation, but we'll omit that from the loss equation since that is more of an attribute of the local surrogate model. Thus, our definition for the local surrogate model loss is\n",
"\n",
"\\begin{equation}\n",
"\\ell^s\\left(x'\\right) = w(x', x)\\,\\ell\\left(\\hat{f}_s(x'), \\hat{f}(x')\\right)\n",
"\\end{equation}\n",
"\n",
"where $w(x', x)$ is a weight kernel function that weights points near the example of interest $x$, $\\ell(\\cdot,\\cdot)$ is the original black box model loss, $\\hat{f}(\\cdot)$ is the black box model, and $\\hat{f}_s(\\cdot)$ is the local surrogate model.\n",
"\n",
"\n",
"```{margin}\n",
"LIME as formulated in {cite}`ribeiro2016should` gives feature importance descriptions, but some surrogate models, like a decision tree, may already be interpretable.\n",
"```\n",
"\n",
"The weight function is a bit ad hoc -- it depends on the data type. For regression tasks with scalar labels, we use a kernel function, and you have a variety of choices: Gaussian, cosine, Epanechnikov. For text, the LIME implementations use a [Hamming distance](https://en.wikipedia.org/wiki/Hamming_distance), which just counts the number of text tokens that do not match between two strings. Images use the same distance, but with each superpixel either the same as in the example or blank.\n",
"\n",
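"As a minimal sketch of the text case, the NumPy code below counts mismatched tokens between two equal-length sequences (here peptides, treating each amino acid as a token) and turns that distance into a weight with an exponential kernel. The kernel form and its `width` are assumptions for illustration, not the exact choice made by any particular LIME implementation, and the mutant peptide is hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def hamming(a, b):\n",
"    # number of positions where two equal-length token sequences differ\n",
"    if len(a) != len(b):\n",
"        raise ValueError('sequences must have the same length')\n",
"    return sum(ai != bi for ai, bi in zip(a, b))\n",
"\n",
"\n",
"def kernel_weight(dist, width=2.0):\n",
"    # exponential kernel: nearby sequences get weights close to 1\n",
"    return np.exp(-(dist**2) / width**2)\n",
"\n",
"\n",
"base = 'RAGLQFPVGRLLRRLLRRLLR'\n",
"mutant = 'RAGLAFPVGRLLRRLLRRLLR'  # hypothetical single substitution, Q -> A\n",
"d = hamming(base, mutant)\n",
"print(d, kernel_weight(d))\n",
"```\n",
"\n",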
"How are the points $x'$ generated? In the continuous case $x'$ is sampled *uniformly*, which is quite a feat since feature spaces are often unbounded. You could sample $x'$ according to your weight function and then omit the weighting (since it was sampled according to that) to avoid issues like unbounded feature spaces. In general, LIME is a bit subjective in continuous vector feature spaces. For images and text, $x'$ is formed by masking tokens (words) and zeroing (making black) superpixels. This leads to explanations that should feel quite similar to Shapley values -- and indeed you can show LIME is equivalent to Shapley values with some small notation changes."
]
},
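{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of a LIME-style local surrogate for continuous features follows. It assumes a hypothetical two-feature black box, samples $x'$ from a Gaussian around $x$, weights the samples with a Gaussian kernel $w(x', x)$, and fits a linear surrogate by weighted least squares (that is, it assumes the black box loss is squared error). It omits the sparsification step and is not the `lime` package's implementation.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def black_box(X):\n",
"    # hypothetical stand-in for the black box model, evaluated row-wise\n",
"    return np.sin(X[:, 0]) + X[:, 1] ** 2\n",
"\n",
"\n",
"def lime_linear(f, x, n_samples=1000, scale=0.1, kernel_width=0.1, seed=0):\n",
"    rng = np.random.default_rng(seed)\n",
"    # perturbed points x' in the neighborhood of x\n",
"    Xp = x + scale * rng.normal(size=(n_samples, x.size))\n",
"    # Gaussian kernel weights w(x', x)\n",
"    dist2 = np.sum((Xp - x) ** 2, axis=1)\n",
"    w = np.exp(-dist2 / kernel_width**2)\n",
"    # weighted least squares for f_s(x') = b + c . (x' - x)\n",
"    A = np.hstack([np.ones((n_samples, 1)), Xp - x])\n",
"    sw = np.sqrt(w)[:, None]\n",
"    coef, *_ = np.linalg.lstsq(A * sw, f(Xp) * sw[:, 0], rcond=None)\n",
"    return coef[0], coef[1:]\n",
"\n",
"\n",
"x = np.array([0.5, 1.0])\n",
"intercept, slopes = lime_linear(black_box, x)\n",
"print(intercept, slopes)\n",
"```\n",
"\n",
"Near $x$, the fitted slopes should be close to the black box's local gradient, $(\\cos 0.5, 2)$, which is what makes the linear surrogate locally faithful. Widening `kernel_width` or `scale` trades locality for a more global (and less faithful) surrogate."
]
},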
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Counterfactuals \n",
"\n",
"```{margin} Counterfactuals\n",
"Our optimization formulation is what's used in XAI, but\n",
"counterfactuals in other contexts do not have the \"nearness\" \n",
"criterion\n",
"```\n",
"\n",
"A counterfactual is a solution to an optimization problem: find an example $x'$ that has a different label than $x$ and is as close as possible to $x$ {cite}`wachter2017counterfactual`. You can formulate this like:\n",
"\n",
"```{math}\n",
":label: cf\n",
"\\textrm{minimize}\\qquad d(x, x')\\\\\n",
"\\textrm{such that}\\qquad \\hat{f}(x) \\neq \\hat{f}(x')\n",
"```\n",
"\n",
"In regression settings where $\\hat{f}(x)$ outputs a scalar, you need to modify your constraint to be some $\\Delta$ away from $\\hat{f}(x)$. $x'$ that satisfies this optimization problem is the counterfactual: a condition that did not occur and would have led to a different outcome. Typically finding $x'$ is treated as a derivative-free optimization. You can calculate $\\frac{\\partial \\hat{f}}{\\partial x'}$ and do constrained optimization, but in practice it can be faster to just randomly perturb $x$ until $\\hat{f}(x) \\neq \\hat{f}(x')$ like a Monte Carlo optimization. You can also use a generative model that can propose new $x'$ via unsupervised training. See {cite}`wellawatte_seshadri_white_2021` for a universal counterfactual generator for molecules. See {cite}`numeroso2020explaining` for a method specifically for graph neural networks of molecules. \n",
"\n",
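"The random-perturbation strategy can be sketched in a few lines of NumPy. The classifier here is a hypothetical linear decision rule standing in for $\\hat{f}$, the distance is assumed to be Euclidean, and the search simply keeps the closest perturbed point whose label differs from $\\hat{f}(x)$.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def classifier(x):\n",
"    # hypothetical stand-in for f-hat: a linear decision boundary\n",
"    return int(x[0] + x[1] > 1.0)\n",
"\n",
"\n",
"def random_counterfactual(f, x, n_trials=5000, scale=0.5, seed=0):\n",
"    # Monte Carlo search: perturb x at random and keep the closest\n",
"    # perturbed point x' whose predicted label differs from f(x)\n",
"    rng = np.random.default_rng(seed)\n",
"    label = f(x)\n",
"    best, best_dist = None, np.inf\n",
"    for _ in range(n_trials):\n",
"        xp = x + scale * rng.normal(size=x.shape)\n",
"        dist = np.linalg.norm(xp - x)\n",
"        if f(xp) != label and dist < best_dist:\n",
"            best, best_dist = xp, dist\n",
"    return best, best_dist\n",
"\n",
"\n",
"x = np.array([0.4, 0.3])\n",
"cf, d = random_counterfactual(classifier, x)\n",
"print(classifier(x), classifier(cf), d)\n",
"```\n",
"\n",
"For this classifier the true minimum distance is $(1 - 0.7)/\\sqrt{2} \\approx 0.21$, so the Monte Carlo search can only approach it from above.\n",
"\n",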
"Defining distance is an important subjective concern, as we saw above for LIME. A common example for molecular structures is Tanimoto similarity (also known as Jaccard index) of molecular fingerprints/descriptors like Morgan fingerprints {cite}`rogers2010extended`.\n",
"\n",
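"For binary fingerprints, Tanimoto similarity is the number of bits set in both fingerprints divided by the number of bits set in either. Below is a minimal NumPy sketch on two made-up 8-bit fingerprints; a distance for the counterfactual problem can then be taken as one minus the similarity. In practice you would compute Morgan fingerprints with RDKit, which also provides `DataStructs.TanimotoSimilarity`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def tanimoto(a, b):\n",
"    # |a AND b| / |a OR b| for binary fingerprint vectors\n",
"    a = np.asarray(a, dtype=bool)\n",
"    b = np.asarray(b, dtype=bool)\n",
"    union = np.sum(a | b)\n",
"    # convention assumed here: two empty fingerprints are identical\n",
"    return np.sum(a & b) / union if union else 1.0\n",
"\n",
"\n",
"fp1 = [1, 1, 0, 1, 0, 0, 1, 0]  # hypothetical fingerprint bits\n",
"fp2 = [1, 0, 0, 1, 1, 0, 1, 0]\n",
"print(tanimoto(fp1, fp2))  # prints 0.6 (3 shared bits out of 5 set bits)\n",
"print(1 - tanimoto(fp1, fp2))  # the corresponding distance\n",
"```\n",
"\n",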
"Counterfactuals have one disadvantage compared to Shapley values: they do not provide a *complete* explanation. Shapley values sum to the prediction, meaning we are not missing any part of the explanation. Counterfactuals modify as few features as possible (minimizing distance) and so may omit information about features that still contribute to a prediction. On the other hand, one advantage of counterfactuals is that they are actionable: you can use the counterfactual directly.\n",
"\n",
"### Example\n",
"\n",
"We can quickly implement this idea for our peptide example above. We can define our distance as the Hamming distance. Then the closest $x'$ would be a single amino acid substitution. Let's just try enumerating those and see if we can achieve a label swap. We'll define a function that does a single substitution:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def check_cf(x, i, j):\n",
" # copy\n",
" x = jnp.array(x)\n",
" # substitute\n",
" x = x.at[:, i].set(0)\n",
" x = x.at[:, i, j].set(1)\n",
" return predict(x)\n",
"\n",
"\n",
"check_cf(sm, 0, 0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then build all possible substitutions with {obj}`jnp.meshgrid` and apply our function over that with {obj}`vmap`. {obj}`.ravel()` flattens our arrays of indices into a single dimension, so we do not need to worry about doing a complex vmap."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ii, jj = jnp.meshgrid(jnp.arange(sl), jnp.arange(21))\n",
"ii, jj = ii.ravel(), jj.ravel()\n",
"x = jax.vmap(check_cf, in_axes=(None, 0, 0))(sm, ii, jj)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we'll display all the single amino acid substitutions which resulted in a negative prediction - the logits are less than zero."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display, HTML\n",
"\n",
"out = [\"\"]\n",
"for i, j in zip(ii[jnp.squeeze(x) < 0], jj[jnp.squeeze(x) < 0]):\n",
" out.append(f'{s[:i]}{ALPHABET[j]}{s[i+1:]} ')\n",
"out.append(\"\")\n",
"display(HTML(\"\".join(out)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have a few to choose from, but the interpretation is essentially: exchange the glutamine for a hydrophobic group or replace the proline with V, F, A, or C to make the peptide non-hemolytic. Stated as a counterfactual: \"If the glutamine were exchanged with a hydrophobic amino acid, the peptide would not be hemolytic.\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explanations for Specific Architectures\n",
"\n",
"The same principles above apply to GNNs, but there are competing ideas about how best to translate these ideas to work on graphs. See {cite}`agarwal2021towards` for a discussion of theory of interpretability specifically for GNNs and {cite}`yuan2020explainability` for a survey of the methods available for constructing explanations in GNNs. \n",
"\n",
"NLP is another area where there are specific approaches to constructing explanations and interpretation. See {cite}`madsen2021post` for a recent survey."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Agnostic Molecular Counterfactual Explanations\n",
"\n",
"The main challenge associated with counterfactuals in chemistry is the difficulty in computing the derivative in {eq}`cf`. Therefore, most methods which focus on this task are specific to model architectures, as we saw previously. Wellawatte et al. {cite}`wellawatte_seshadri_white_2021` have introduced a method named Molecular Model Agnostic Counterfactual Explanations (MMACE) to do this for molecules regardless of model architecture.\n",
"\n",
"The MMACE method is implemented in the `exmol` package. Given a molecule and a model, `exmol` is able to generate local counterfactual explanations. There are two main steps involved in the MMACE method. First, a local chemical space is expanded around the given base molecule. Next, each sample point is labeled with the user-given model. These labels are then used to identify the counterfactuals in the local chemical space. As the MMACE method is model agnostic, the `exmol` package is able to generate counterfactuals for both classification and regression tasks.\n",
"\n",
"Now let's see how to generate molecular counterfactuals using `exmol`. In this example, we will train a dense neural network classifier that predicts the clinical toxicity of molecules. For this binary classification task, we'll be using the same dataset we used in the {doc}`../ml/classification` chapter, presented by the MoleculeNet group {cite}`wu2018moleculenet`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running This Notebook\n",
"\n",
"\n",
"Click the launch button above to open this page as an interactive Google Colab. See details below on installing packages, either in your own environment or on Google Colab.\n",
"\n",
"````{tip} Installing packages\n",
":class: dropdown\n",
"To install packages, execute this code in a new cell\n",
"\n",
"```\n",
"!pip install exmol jupyter-book matplotlib numpy pandas seaborn scikit-learn mordred[full] rdkit\n",
"```\n",
"\n",
"````"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"import rdkit, rdkit.Chem, rdkit.Chem.Draw\n",
"from rdkit.Chem.Draw import IPythonConsole\n",
"import numpy as np\n",
"import mordred, mordred.descriptors\n",
"import warnings\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense\n",
"from sklearn.model_selection import StratifiedKFold\n",
"from sklearn.model_selection import cross_val_score\n",
"import exmol\n",
"\n",
"IPythonConsole.ipython_useSVG = True\n",
"\n",
"\n",
"toxdata = pd.read_csv(\n",
" \"https://github.com/whitead/dmol-book/raw/main/data/clintox.csv.gz\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# make object that can compute descriptors\n",
"calc = mordred.Calculator(mordred.descriptors, ignore_3D=True)\n",
"# convert the SMILES strings to RDKit molecules\n",
"molecules = [rdkit.Chem.MolFromSmiles(smi) for smi in toxdata.smiles]\n",
"\n",
"# view one molecule to make sure things look good.\n",
"molecules[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After importing the data, we generate input descriptors with the `Mordred` package."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get valid molecules from the sample\n",
"valid_mol_idx = [bool(m) for m in molecules]\n",
"valid_mols = [m for m in molecules if m]\n",
"# Compute molecular descriptors using Mordred\n",
"features = calc.pandas(valid_mols, quiet=True)\n",
"labels = toxdata[valid_mol_idx].FDA_APPROVED\n",
"# Standardize the features\n",
"features -= features.mean()\n",
"features /= features.std()\n",
"\n",
"# we have some nans in features, likely because std was 0\n",
"features = features.values.astype(float)\n",
"features_select = np.all(np.isfinite(features), axis=0)\n",
"features = features[:, features_select]\n",
"print(f\"We have {features.shape[1]} features per molecule\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example, we are using a simple dense neural network classifier implemented with `Keras`. First, let's train this simple classifier and use it to generate labels for the counterfactuals in `exmol`. You can expect more accurate results by improving the performance of the trained model, but the following example is sufficient to understand the workings of `exmol` for now."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Train and test split\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
" features, labels, test_size=0.2, shuffle=True\n",
")\n",
"ft_shape = X_train.shape[-1]\n",
"\n",
"# reshape data\n",
"X_train = X_train.reshape(-1, ft_shape)\n",
"X_test = X_test.reshape(-1, ft_shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's build our model and compile it! You can find an in-depth introduction to dense models in the {doc}`introduction` chapter."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.models.Sequential()\n",
"model.add(tf.keras.Input(shape=(ft_shape,)))\n",
"model.add(tf.keras.layers.Dense(32, activation=\"relu\"))\n",
"model.add(tf.keras.layers.Dense(32))\n",
"model.add(Dense(1, activation=\"sigmoid\"))\n",
"model.compile(loss=\"binary_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Model training\n",
"model.fit(X_train, y_train, epochs=50, batch_size=32, verbose=0)\n",
"_, accuracy = model.evaluate(X_test, y_test)\n",
"print(f\"Model accuracy: {accuracy:.2%}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our model seems to have good accuracy!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we'll write a wrapper function that takes in SMILES and/or SELFIES molecule representations and outputs label predictions from the trained classifier. A detailed description of SELFIES can be found in the {doc}`NLP` chapter. This wrapper function is given as an input to the {obj}`exmol.sample_space` function in `exmol` to create a local chemical space around a given base molecule. `exmol` uses the Superfast Traversal, Optimization, Novelty, Exploration and Discovery (STONED) algorithm {cite}`nigam_stoned` as a generative algorithm to expand the local space. Given a base molecule, the STONED algorithm randomly mutates the SELFIES representation of the molecule. These mutations can be string substitutions, additions, or deletions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def model_eval(smiles, selfies):\n",
" molecules = [rdkit.Chem.MolFromSmiles(smi) for smi in smiles]\n",
" features = calc.pandas(molecules)\n",
" features = features.values.astype(float)\n",
" features = features[:, features_select]\n",
" labels = np.round(model.predict(np.nan_to_num(features).reshape(-1, ft_shape)))\n",
" return labels"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we use STONED to sample local chemical space with {obj}`exmol.sample_space`. You can modify the size of the sample space with the `num_samples` argument. The base molecule selected here is a non-FDA approved molecule."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-output"
]
},
"outputs": [],
"source": [
"space = exmol.sample_space(\"C1CC(=O)NC(=O)C1N2CC3=C(C2=O)C=CC=C3N\", model_eval);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once the sample space is created, we can identify counterfactuals in the local chemical space using the {obj}`exmol.cf_explain` function. Each counterfactual is a Python `dataclass` that contains additional information."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"exps = exmol.cf_explain(space, 2)\n",
"exps[1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can easily visualize the generated counterfactuals using the plotting functions in `exmol`: {obj}`exmol.plot_cf` and {obj}`exmol.plot_space`. Similarity between the base and counterfactuals is the Tanimoto similarity of ECFP4 fingerprints. The top 3 counterfactuals are shown here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"exmol.plot_cf(exps, nrows=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The base molecule which we selected here is NOT FDA approved. By looking at the generated counterfactuals, we can conclude that the heterocyclic group has an impact on the toxicity of the base. Therefore, by altering the heterocyclic group, the base molecule might be made non-toxic according to our model. This also shows why counterfactual explanations give actionable insight into how modifications can be made."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also visualize the generated chemical space!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"exmol.plot_space(space, exps)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chapter Summary \n",
"\n",
"* Interpretation of deep learning models is imperative for ensuring model correctness and making predictions useful to humans, and it can be required for legal compliance.\n",
"* Interpretability of neural networks is part of the broader topic of explainable AI (XAI), a field that is still in its infancy.\n",
"* An *explanation* is still ill-defined, but it is most often expressed in terms of model features.\n",
"* Strategies for explanations include feature importance, training data importance, counterfactuals, and surrogate models that are locally accurate.\n",
"* Most explanations are generated per-example (at inference).\n",
"* Shapley values are the most systematic explanations, but they are expensive to compute.\n",
"* Some argue that counterfactuals provide the most intuitive and satisfying explanations, but they may not be complete explanations.\n",
"* `exmol` is a software package that generates model-agnostic molecular counterfactual explanations."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Exercises\n",
"\n",
"1. Computing gradient-based feature importance requires computing $\\nabla \\hat{f}(x)$, the gradient of the output with respect to the input. Is this the same gradient we compute when training a neural network?\n",
"\n",
"2. Why might computing $\\nabla \\hat{f}(x)$ be more difficult when the input is a graph (molecule) instead of an image or dense vector?\n",
"\n",
"3. Some attributes of an explanation are whether it is actionable, faithful (agrees with the NN), sparse, and complete. Make a table comparing these attributes for explanations generated by training data importance, feature importance, surrogate models, and counterfactual methods.\n",
"\n",
"4. Can we average feature importances across the whole training dataset to provide a global explanation?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cited References\n",
"\n",
"```{bibliography}\n",
":style: unsrtalpha\n",
":filter: docname in docnames\n",
"```"
]
}
],
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3.7.8 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.8"
},
"vscode": {
"interpreter": {
"hash": "3e5a039a7a113538395a7d74f5574b0c5900118222149a18efb009bf03645fce"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}