\n",
" \n",
"

\n",
"\n",
"\n",
"\n",
"This polymer has each bead (atom) joined by a harmonic bond, a harmonic angle between each three, and a Lennard-Jones interaction potential. Knowing these items will not be necessary for the example. We'll construct a VAE that can compress the trajectory to some latent space and generate new conformations.\n",
"\n",
"To begin, we'll use the lessons learned from {doc}`data` about how to align points from a trajectory. This will then serve as our training data. The space of our problem will be 12 2D vectors. Our system need not be permutation invariant, so we can flatten these vectors into a 24 dimensional input. The code belows loads and aligns the trajectory"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"###---------Transformation Functions----###\n",
"def center_com(paths):\n",
" \"\"\"Align paths to COM at each frame\"\"\"\n",
" coms = np.mean(paths, axis=-2, keepdims=True)\n",
" return paths - coms\n",
"\n",
"\n",
"def make_2drot(angle):\n",
" mats = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])\n",
" # swap so batch axis is first\n",
" return np.swapaxes(mats, 0, -1)\n",
"\n",
"\n",
"def find_principle_axis(points):\n",
" \"\"\"Compute single principle axis for points\"\"\"\n",
" inertia = points.T @ points\n",
" evals, evecs = np.linalg.eigh(inertia)\n",
" # get biggest eigenvalue\n",
" order = np.argsort(evals)\n",
" return evecs[:, order[-1]]\n",
"\n",
"\n",
"def align_principle(paths, axis_finder=find_principle_axis):\n",
" # This is a degenarate version, I removed mirror disambiguation\n",
" # to make latent space jump less. Data augmentation will\n",
" # have to overcome this issue\n",
" # the code is commented out below\n",
" vecs = [axis_finder(p) for p in paths]\n",
" vecs = np.array(vecs)\n",
" # find angle to rotate so these are pointed towards pos x\n",
" cur_angle = np.arctan2(vecs[:, 1], vecs[:, 0])\n",
" # cross = np.cross(vecs[:,0], vecs[:,1])\n",
" rot_angle = -cur_angle # - (cross < 0) * np.pi\n",
" rot_mat = make_2drot(rot_angle)\n",
" return paths @ rot_mat\n",
"\n",
"\n",
"###-----------------------------------###"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"urllib.request.urlretrieve(\n",
" \"https://github.com/whitead/dmol-book/raw/master/data/long_paths.npz\",\n",
" \"long_paths.npz\",\n",
")\n",
"paths = np.load(\"long_paths.npz\")[\"arr\"]\n",
"# transform to be rot/trans invariant\n",
"data = align_principle(center_com(paths))\n",
"cmap = plt.get_cmap(\"cool\")\n",
"for i in range(0, data.shape[0], 16):\n",
" plt.plot(data[i, :, 0], data[i, :, 1], \"-\", alpha=0.1, color=\"C2\")\n",
"plt.title(\"All Frames\")\n",
"plt.xticks([])\n",
"plt.yticks([])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training, let's examine some of the **marginals** of the data. Marginals mean we've transformed (by integration) our probability distribution to be a function of only 1-2 variables so that we can plot nicely. We'll look at the pairwise distance between points. "
]
},
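{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small illustration of what marginalizing means, the sketch below integrates a known 2D density over one variable on a grid and compares the result with the exact 1D marginal. The correlated Gaussian is an arbitrary stand-in for $p(x, y)$, not the polymer distribution. For samples like our trajectory, histogramming one quantity while ignoring the rest performs the same integral implicitly.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# arbitrary correlated 2D Gaussian standing in for p(x, y)\n",
"cov = np.array([[1.0, 0.6], [0.6, 2.0]])\n",
"prec = np.linalg.inv(cov)\n",
"norm = 1.0 / (2.0 * np.pi * np.sqrt(np.linalg.det(cov)))\n",
"\n",
"grid = np.linspace(-8.0, 8.0, 801)\n",
"dy = grid[1] - grid[0]\n",
"X, Y = np.meshgrid(grid, grid, indexing='ij')\n",
"quad = prec[0, 0] * X**2 + 2.0 * prec[0, 1] * X * Y + prec[1, 1] * Y**2\n",
"p_xy = norm * np.exp(-0.5 * quad)\n",
"\n",
"# marginalize: integrate out y (axis 1) with a simple Riemann sum\n",
"p_x = p_xy.sum(axis=1) * dy\n",
"# the exact marginal of this Gaussian is N(0, cov[0, 0])\n",
"p_x_exact = np.exp(-(grid**2) / (2.0 * cov[0, 0])) / np.sqrt(2.0 * np.pi * cov[0, 0])\n",
"print(np.max(np.abs(p_x - p_x_exact)))\n",
"```"
]
},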
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axs = plt.subplots(ncols=4, squeeze=True, figsize=(16, 4))\n",
"for i, j in enumerate(range(1, 9, 2)):\n",
" axs[i].set_title(f\"Dist between 0-{j}\")\n",
" sns.distplot(np.linalg.norm(data[:, 0] - data[:, j], axis=1), ax=axs[i])\n",
"plt.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These look a little like the chi distribution with two degrees of freedom. Notice that the support (x-axis) changes between them though. We'll keep an eye on these when we evaluate the efficacy of our VAE. \n",
"\n",
"### VAE Model\n",
"\n",
"We'll build the VAE like above. I will make two changes. I will use JAX's random number generator and I will make the number of layers variable. The code is hidden below, but you can expand to see the details. We'll be starting with 4 layers total (3 hidden) with a hidden layer dimension of 256. Another detail is that we flatten the input/output since the order is preserved and thus we do not worry about separating the x,y dimension out."
]
},
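{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the architecture concrete, this sketch lists the weight-matrix shapes that `init_theta` and `init_phi` create with the settings used here (latent dimension 2, 256 hidden units, 4 layers, 24 inputs) and counts the parameters. The helper names `layer_shapes` and `count_params` are hypothetical; they only mirror the `w @ x + b` convention of the hidden code.\n",
"\n",
"```python\n",
"def layer_shapes(in_dim, hidden_units, out_dim, num_layers):\n",
"    # (rows, cols) of each weight matrix, since every layer computes w @ x + b\n",
"    dims = [in_dim] + [hidden_units] * (num_layers - 1) + [out_dim]\n",
"    return [(dims[k + 1], dims[k]) for k in range(num_layers)]\n",
"\n",
"\n",
"def count_params(shapes):\n",
"    # each layer has rows * cols weights plus one bias per row\n",
"    return sum(rows * cols + rows for rows, cols in shapes)\n",
"\n",
"\n",
"decoder_shapes = layer_shapes(2, 256, 24, 4)  # latent -> data\n",
"encoder_shapes = layer_shapes(24, 256, 2 * 2, 4)  # data -> (mean, std) per latent dim\n",
"print(decoder_shapes, count_params(decoder_shapes))\n",
"print(encoder_shapes, count_params(encoder_shapes))\n",
"```\n",
"\n",
"The encoder ends with `2 * latent_dim` outputs because it predicts a mean and a standard deviation for each latent dimension."
]
},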
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"input_dim = 12 * 2\n",
"hidden_units = 256\n",
"num_layers = 4\n",
"latent_dim = 2\n",
"\n",
"\n",
"def init_theta(input_dim, hidden_units, latent_dim, num_layers, key, scale=0.1):\n",
" key, subkey = jax.random.split(key)\n",
" w1 = jax.random.normal(key=subkey, shape=(hidden_units, latent_dim)) * scale\n",
" b1 = jnp.zeros(hidden_units)\n",
" theta = [(w1, b1)]\n",
" for i in range(1, num_layers - 1):\n",
" key, subkey = jax.random.split(key)\n",
" w = jax.random.normal(key=subkey, shape=(hidden_units, hidden_units)) * scale\n",
" b = jnp.zeros(hidden_units)\n",
" theta.append((w, b))\n",
" key, subkey = jax.random.split(key)\n",
" w = jax.random.normal(key=subkey, shape=(input_dim, hidden_units)) * scale\n",
" b = jnp.zeros(input_dim)\n",
" theta.append((w, b))\n",
" return theta, key\n",
"\n",
"\n",
"def decoder(z, theta):\n",
" num_layers = len(theta)\n",
" for i in range(num_layers - 1):\n",
" w, b = theta[i]\n",
" z = jax.nn.relu(w @ z + b)\n",
" w, b = theta[-1]\n",
" x = w @ z + b\n",
" return x\n",
"\n",
"\n",
"def init_phi(input_dim, hidden_units, latent_dim, num_layers, key, scale=0.1):\n",
" key, subkey = jax.random.split(key)\n",
" w1 = jax.random.normal(key=subkey, shape=(hidden_units, input_dim)) * scale\n",
" b1 = jnp.zeros(hidden_units)\n",
" phi = [(w1, b1)]\n",
" for i in range(1, num_layers - 1):\n",
" key, subkey = jax.random.split(key)\n",
" w = jax.random.normal(key=subkey, shape=(hidden_units, hidden_units)) * scale\n",
" b = jnp.zeros(hidden_units)\n",
" phi.append((w, b))\n",
" key, subkey = jax.random.split(key)\n",
" w = jax.random.normal(key=subkey, shape=(latent_dim * 2, hidden_units)) * scale\n",
" b = jnp.zeros(latent_dim * 2)\n",
" phi.append((w, b))\n",
" return phi, key\n",
"\n",
"\n",
"def encoder(x, phi):\n",
" num_layers = len(phi)\n",
" for i in range(num_layers - 1):\n",
" w, b = phi[i]\n",
" x = jax.nn.relu(w @ x + b)\n",
" w, b = phi[-1]\n",
" hz = w @ x + b\n",
" hz = hz.reshape(-1, 2)\n",
" mu = hz[:, 0:1]\n",
" std = jax.nn.softplus(hz[:, 1:2])\n",
" return jnp.concatenate((mu, std), axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loss\n",
"\n",
"The loss function is similar to above, but I will not even bother with the Gaussian outputs. You can see the only change is that we drop the output Gaussian standard deviation from the loss, which remember was not trainable anyway. "
]
},
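{
"cell_type": "markdown",
"metadata": {},
"source": [
"The KL term in the loss below is the closed-form KL divergence between $\\mathcal{N}(\\mu, \\sigma^2)$ and the standard normal prior, computed for each latent dimension as $-\\frac{1}{2} - \\log \\sigma + \\frac{1}{2}\\mu^2 + \\frac{1}{2}\\sigma^2$. This NumPy sketch (the function names are hypothetical) checks that expression against a Monte Carlo estimate drawn with the same reparameterization $z = \\mu + \\sigma \\epsilon$ used in the loss.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def kl_closed_form(mu, sigma):\n",
"    # KL(N(mu, sigma^2) || N(0, 1)), the same expression as klloss\n",
"    return -0.5 - np.log(sigma) + 0.5 * mu**2 + 0.5 * sigma**2\n",
"\n",
"\n",
"def kl_monte_carlo(mu, sigma, n=1_000_000, seed=0):\n",
"    rng = np.random.default_rng(seed)\n",
"    # reparameterization: z = mu + sigma * eps with eps ~ N(0, 1)\n",
"    eps = rng.standard_normal(n)\n",
"    z = mu + sigma * eps\n",
"    log_q = -0.5 * eps**2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)\n",
"    log_p = -0.5 * z**2 - 0.5 * np.log(2.0 * np.pi)\n",
"    # KL is the expectation of log q(z) - log p(z) under q\n",
"    return np.mean(log_q - log_p)\n",
"\n",
"\n",
"print(kl_closed_form(0.5, 0.7), kl_monte_carlo(0.5, 0.7))\n",
"print(kl_closed_form(0.0, 1.0))  # zero when q equals the prior\n",
"```"
]
},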
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def loss(x, theta, phi, rng_key):\n",
" \"\"\"VAE Loss\"\"\"\n",
" # reconstruction loss\n",
" sampled_z_params = encoder(x, phi)\n",
" # reparameterization trick\n",
" # we use standard normal sample and multiply by parameters\n",
" # to ensure derivatives correctly propogate to encoder\n",
" sampled_z = (\n",
" jax.random.normal(rng_key, shape=(latent_dim,)) * sampled_z_params[:, 1]\n",
" + sampled_z_params[:, 0]\n",
" )\n",
" # MSE now instead\n",
" xp = decoder(sampled_z, theta)\n",
" rloss = jnp.sum((xp - x) ** 2)\n",
" # LK loss\n",
" klloss = (\n",
" -0.5\n",
" - jnp.log(sampled_z_params[:, 1] + 1e-8)\n",
" + 0.5 * sampled_z_params[:, 0] ** 2\n",
" + 0.5 * sampled_z_params[:, 1] ** 2\n",
" )\n",
" # combined\n",
" return jnp.array([rloss, jnp.mean(klloss)])\n",
"\n",
"\n",
"# update compiled functions\n",
"batched_loss = jax.vmap(loss, in_axes=(0, None, None, None), out_axes=0)\n",
"batched_decoder = jax.vmap(decoder, in_axes=(0, None), out_axes=0)\n",
"batched_encoder = jax.vmap(encoder, in_axes=(0, None), out_axes=0)\n",
"grad = jax.grad(modified_loss, (1, 2))\n",
"fast_grad = jax.jit(grad)\n",
"fast_loss = jax.jit(batched_loss)"
]
},
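{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is why dropping the output standard deviation is harmless. Suppose the decoder defined a Gaussian $p(x \\mid z) = \\mathcal{N}(x \\mid \\hat{x}, \\sigma^2 I)$ over the $D = 24$ flattened coordinates, where $\\hat{x}$ is the decoder output and $\\sigma$ is fixed. Its negative log-likelihood is\n",
"\n",
"\\begin{equation}\n",
"-\\log p(x \\mid z) = \\frac{1}{2\\sigma^2}\\lVert x - \\hat{x} \\rVert^2 + \\frac{D}{2}\\log\\left(2\\pi\\sigma^2\\right)\n",
"\\end{equation}\n",
"\n",
"The second term is a constant, so minimizing the squared error `rloss` minimizes this negative log-likelihood up to a scale factor. Multiplying the negative ELBO, $\\mathbb{E}_{q(z \\mid x)}\\left[-\\log p(x \\mid z)\\right] + D_{KL}\\left(q(z \\mid x) \\,\\|\\, p(z)\\right)$, by $2\\sigma^2$ gives the expected squared error plus $2\\sigma^2 D_{KL}$ plus a constant. The weight $\\beta$ on the KL term therefore plays the role of $2\\sigma^2$, up to how the KL term is averaged over latent dimensions."
]
},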
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training\n",
"\n",
"Finally comes the training. The only changes to this code are to flatten our input data and shuffle to prevent the each batch from having similar conformations."
]
},
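{
"cell_type": "markdown",
"metadata": {},
"source": [
"The shuffle-then-batch pattern can be sketched on its own with NumPy. The function name `minibatches`, the random stand-in trajectory, and reshuffling on every call are assumptions of this sketch; the training cell below shuffles once before the first epoch.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def minibatches(flat, batch_size, rng):\n",
"    # shuffle row (frame) order so neighboring frames rarely share a batch\n",
"    order = rng.permutation(len(flat))\n",
"    for start in range(0, len(flat), batch_size):\n",
"        yield flat[order[start : start + batch_size]]\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"traj = rng.normal(size=(100, 12, 2))  # stand-in for aligned frames\n",
"flat = traj.reshape(len(traj), -1)  # (frames, 24)\n",
"batches = list(minibatches(flat, 32, rng))\n",
"print(flat.shape, [b.shape for b in batches])\n",
"```\n",
"\n",
"The last batch is smaller when the number of frames is not a multiple of the batch size, which matches how the slicing in the training loop behaves."
]
},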
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 32\n",
"epochs = 250\n",
"key = jax.random.PRNGKey(0)\n",
"\n",
"flat_data = data.reshape(-1, input_dim)\n",
"# scramble it\n",
"flat_data = jax.random.shuffle(key, flat_data)\n",
"\n",
"\n",
"opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)\n",
"theta0, key = init_theta(input_dim, hidden_units, latent_dim, num_layers, key)\n",
"phi0, key = init_phi(input_dim, hidden_units, latent_dim, num_layers, key)\n",
"opt_state = opt_init((theta0, phi0))\n",
"losses = []\n",
"# KL/Reconstruction balance\n",
"beta = 0.01\n",
"for e in range(epochs):\n",
" for bi, i in enumerate(range(0, len(flat_data), batch_size)):\n",
" # make a batch into shape B x 1\n",
" batch = flat_data[i : (i + batch_size)].reshape(-1, input_dim)\n",
" # udpate random number key\n",
" key, subkey = jax.random.split(key)\n",
" # get current parameter values from optimizer\n",
" theta, phi = get_params(opt_state)\n",
" last_state = opt_state\n",
" # compute gradient and update\n",
" grad = fast_grad(batch, theta, phi, key, beta)\n",
" opt_state = opt_update(bi, grad, opt_state)\n",
" # use large batch for tracking progress\n",
" lvalue = jnp.mean(fast_loss(flat_data[:100], theta, phi, subkey), axis=0)\n",
" losses.append(lvalue)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.plot([l[0] for l in losses], label=\"Reconstruction\")\n",
"plt.plot([l[1] for l in losses], label=\"KL\")\n",
"plt.plot([l[1] + l[0] for l in losses], label=\"ELBO\")\n",
"plt.legend()\n",
"plt.ylim(0, 20)\n",
"plt.xlabel(\"epoch\")\n",
"plt.ylabel(\"loss\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As usual, this model is undertrained. A latent space of 2, which we chose for plotting convenience, is also probably a little too compressed. Let's sample a few conformation and see how they look."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sampled_data = decoder(jax.random.normal(key, shape=[latent_dim]), theta).reshape(-1, 2)\n",
"plt.plot(sampled_data[:, 0], sampled_data[:, 1], \"-o\", alpha=1)\n",
"plt.xticks([])\n",
"plt.yticks([])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These look reasonable compared with the trajectory video showing the training conformations."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using VAE on a Trajectory\n",
"\n",
"There are three main things to do with a VAE on a trajectory. The first is to go from a trajectory in the feature dimension to the latent dimension. This can simplify analysis of dynamics or act as a reaction coordinate for free energy methods. The second is to generate new conformations. This could be used to fill-in under sampling or perhaps extrapolate to new regions of latent space. You can also use the VAE to examine marginals that are perhaps under-sampled. Finally, you can do optimization on the latent space. For example, you could try to find the most compact structure. We'll examine these examples but there are many other things you could examine. For a more complete model example with attention and 3D coordinates, see Winter et al. {cite}`winter2021auto`. You can find applications of VAEs on trajectories for molecular design {cite}`shmilovich2020discovery`, coarse-graining {cite}`wang2019coarse`, and identifying rare-events {cite}`ribeiro2018reweighted`. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Latent Trajectory\n",
"\n",
"Let's start by computing a latent trajctory. I'm going to load a shorter trajectory which has the frames closer together in time. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"urllib.request.urlretrieve(\n",
" \"https://github.com/whitead/dmol-book/raw/master/data/paths.npz\", \"paths.npz\"\n",
")\n",
"paths = np.load(\"paths.npz\")[\"arr\"]\n",
"short_data = align_principle(center_com(paths))\n",
"\n",
"# get latent params\n",
"# throw away standard deviation\n",
"latent_traj = batched_encoder(short_data.reshape(-1, input_dim), phi)[:, 0]\n",
"plt.plot(latent_traj[:, 0], latent_traj[:, 1], \"-o\", markersize=5, alpha=0.5)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see that the trajectory is relatively continuous, except for a few wide jumps. We'll see below that this is because the alignment process can have big jumps as our principle axis rapidly moves when the points rearrange. Let's compare the video and the z-path side-by-side. You can find the code for this movie on the github repo."
]
},
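{
"cell_type": "markdown",
"metadata": {},
"source": [
"A simple way to locate those jumps is to compare each step length in the latent path with a typical step. The sketch below flags steps longer than a multiple of the median step. The function name `find_jumps`, the threshold factor, and the synthetic path are hypothetical choices for illustration.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def find_jumps(path, factor=5.0):\n",
"    # path: (T, d) latent coordinates; returns indices t where step t -> t + 1 is large\n",
"    steps = np.linalg.norm(np.diff(path, axis=0), axis=1)\n",
"    return np.nonzero(steps > factor * np.median(steps))[0]\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"path = np.cumsum(rng.normal(scale=0.01, size=(200, 2)), axis=0)  # smooth random walk\n",
"path[120:] += np.array([1.0, -1.0])  # inject one abrupt jump between frames 119 and 120\n",
"print(find_jumps(path))\n",
"```\n",
"\n",
"With the notebook's variables, the same check could be run as `find_jumps(np.asarray(latent_traj))`; the right threshold depends on how noisy the latent path is."
]
},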
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"from moviepy.editor import VideoClip\n",
"from moviepy.video.io.bindings import mplfig_to_npimage\n",
"from matplotlib.collections import LineCollection\n",
"\n",
"\n",
"def make_segments(data, time_index):\n",
" points = np.array([data[time_index, :, 0], data[time_index, :, 1]]).T.reshape(\n",
" -1, 1, 2\n",
" )\n",
" segments = np.concatenate([points[:-1], points[1:]], axis=1)\n",
" return segments\n",
"\n",
"\n",
"dpi = 100\n",
"fig, axs = plt.subplots(ncols=3, figsize=(1920 / dpi, 1080 / dpi / 2), dpi=dpi)\n",
"fps = 60\n",
"fronts = axs[0].plot(\n",
" paths[-1][:, 0], paths[-1][:, 1], \"o\", zorder=0, color=\"C0\", markersize=12\n",
")[0]\n",
"afronts = axs[1].plot(\n",
" short_data[-1][:, 0], short_data[-1][:, 1], \"o\", zorder=0, color=\"C0\", markersize=12\n",
")[0]\n",
"zfront = axs[2].plot(\n",
" latent_traj[0, 0], latent_traj[0, 1], \"o\", zorder=0, color=\"C0\", markersize=12\n",
")[0]\n",
"zl = axs[2].plot(latent_traj[0, 0], latent_traj[1, 0], \"-\", alpha=1, color=\"C1\")[0]\n",
"axs[0].set_xlim(np.nanmin(paths, axis=(0, 1))[0], np.nanmax(paths, axis=(0, 1))[0])\n",
"axs[0].set_ylim(np.nanmin(paths, axis=(0, 1))[1], np.nanmax(paths, axis=(0, 1))[1])\n",
"axs[1].set_xlim(\n",
" np.nanmin(short_data, axis=(0, 1))[0], np.nanmax(short_data, axis=(0, 1))[0]\n",
")\n",
"axs[1].set_ylim(\n",
" np.nanmin(short_data, axis=(0, 1))[1], np.nanmax(short_data, axis=(0, 1))[1]\n",
")\n",
"axs[2].set_xlim(np.min(latent_traj[:, 0]), np.max(latent_traj[:, 0]))\n",
"axs[2].set_ylim(np.min(latent_traj[:, 1]), np.max(latent_traj[:, 1]))\n",
"axs[0].set_title(\"Trajectory\", fontsize=36)\n",
"axs[1].set_title(\"Aligned\", fontsize=36)\n",
"axs[2].set_title(\"Latent\", fontsize=36)\n",
"for i in range(3):\n",
" axs[i].set_xticks([])\n",
" axs[i].set_yticks([])\n",
"plt.tight_layout()\n",
"\n",
"T = paths.shape[0]\n",
"line_collections = []\n",
"line_segments = []\n",
"aline_segments = []\n",
"for i in range(T):\n",
" seg = make_segments(paths, i)\n",
" line_segments.append(seg)\n",
" seg = make_segments(short_data, i)\n",
" aline_segments.append(seg)\n",
"\n",
"\n",
"def make_frame(t):\n",
" frame = int(fps * t)\n",
" if len(line_collections) == 0:\n",
" lc = LineCollection(\n",
" line_segments[frame], color=\"C0\", norm=plt.Normalize(0, 1), alpha=0.2\n",
" )\n",
" axs[0].add_collection(lc)\n",
" line_collections.append(lc)\n",
" lc = LineCollection(\n",
" aline_segments[frame], color=\"C0\", norm=plt.Normalize(0, 1), alpha=0.2\n",
" )\n",
" axs[1].add_collection(lc)\n",
" line_collections.append(lc)\n",
" j = min(frame, T - 1)\n",
" # Set the values used for colormapping\n",
" # lc.set_array(np.linspace(1,0,T)[:frame])\n",
" line_collections[0].set_segments(line_segments[j])\n",
" line_collections[1].set_segments(aline_segments[j])\n",
" fronts.set_data(paths[j][:, 0], paths[j][:, 1])\n",
" afronts.set_data(short_data[j][:, 0], short_data[j][:, 1])\n",
" zfront.set_data(latent_traj[j, 0], latent_traj[j, 1])\n",
" zl.set_data(latent_traj[:j, 0], latent_traj[:j, 1])\n",
" return mplfig_to_npimage(fig)\n",
"\n",
"\n",
"duration = T / fps\n",
"animation = VideoClip(make_frame, duration=duration)\n",
"animation.write_videofile(filename=\"../_static/images/latent_traj.mp4\", fps=fps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"fig = plt.figure(figsize=(1024 / dpi, 720 / dpi))\n",
"ax = plt.gca()\n",
"fps = 60\n",
"fronts = ax.plot(\n",
" paths[-1][:, 0], paths[-1][:, 1], \"o\", zorder=0, color=\"C0\", markersize=12\n",
")[0]\n",
"ax.set_xlim(np.nanmin(paths, axis=(0, 1))[0], np.nanmax(paths, axis=(0, 1))[0])\n",
"ax.set_ylim(np.nanmin(paths, axis=(0, 1))[1], np.nanmax(paths, axis=(0, 1))[1])\n",
"ax.set_title(\"Trajectory\", fontsize=36)\n",
"ax.set_xticks([])\n",
"ax.set_yticks([])\n",
"plt.tight_layout()\n",
"\n",
"T = paths.shape[0]\n",
"line_collections = []\n",
"line_segments = []\n",
"aline_segments = []\n",
"for i in range(T):\n",
" seg = make_segments(paths, i)\n",
" line_segments.append(seg)\n",
" seg = make_segments(short_data, i)\n",
" aline_segments.append(seg)\n",
"\n",
"\n",
"def make_frame(t):\n",
" frame = int(fps * t)\n",
" if len(line_collections) == 0:\n",
" lc = LineCollection(\n",
" line_segments[frame], color=\"C0\", norm=plt.Normalize(0, 1), alpha=0.2\n",
" )\n",
" ax.add_collection(lc)\n",
" line_collections.append(lc)\n",
" lc = LineCollection(\n",
" aline_segments[frame], color=\"C0\", norm=plt.Normalize(0, 1), alpha=0.2\n",
" )\n",
" j = min(frame, T - 1)\n",
" # Set the values used for colormapping\n",
" # lc.set_array(np.linspace(1,0,T)[:frame])\n",
" line_collections[0].set_segments(line_segments[j])\n",
" fronts.set_data(paths[j][:, 0], paths[j][:, 1])\n",
" return mplfig_to_npimage(fig)\n",
"\n",
"\n",
"duration = T / fps\n",
"animation = VideoClip(make_frame, duration=duration)\n",
"animation.write_videofile(filename=\"../_static/images/traj.mp4\", fps=fps)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"\n",
" \n",
"

\n",
"\n",
"You can see the quick change is due to our alignment quickly changing. This is why aligning on the principle axis isn't always perfect: your axis can flip 90 degrees because the internal points change the moment of inertia enough to change."
]
},
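{
"cell_type": "markdown",
"metadata": {},
"source": [
"The flip can be reproduced with four points. In this NumPy sketch, the helper `principal_axis` is a hypothetical stand-in for `find_principle_axis` that also subtracts the mean. Stretching a square by 1% along $x$ versus shrinking it by 1% rotates the principal axis by 90 degrees, even though the two shapes are nearly identical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def principal_axis(points):\n",
"    # eigenvector of the largest eigenvalue of the 2x2 second-moment matrix\n",
"    centered = points - points.mean(axis=0)\n",
"    evals, evecs = np.linalg.eigh(centered.T @ centered)\n",
"    return evecs[:, np.argmax(evals)]\n",
"\n",
"\n",
"square = np.array([[-1.0, -1.0], [1.0, -1.0], [1.0, 1.0], [-1.0, 1.0]])\n",
"for stretch in [1.01, 0.99]:\n",
"    axis = principal_axis(square * np.array([stretch, 1.0]))\n",
"    # angle of the axis from +x, ignoring the sign ambiguity of eigenvectors\n",
"    angle = np.degrees(np.arctan2(abs(axis[1]), abs(axis[0])))\n",
"    print(stretch, round(angle, 6))\n",
"```"
]
},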
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate New Samples\n",
"\n",
"Let's see how our samples look. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axs = plt.subplots(ncols=2, figsize=(12, 4))\n",
"sampled_data = batched_decoder(\n",
" np.random.normal(size=(data.shape[0], latent_dim)), theta\n",
").reshape(data.shape[0], -1, 2)\n",
"for i in range(0, data.shape[0]):\n",
" axs[0].plot(data[i, :, 0], data[i, :, 1], \"-\", alpha=0.1, color=\"C1\")\n",
" axs[1].plot(\n",
" sampled_data[i, :, 0], sampled_data[i, :, 1], \"-\", alpha=0.1, color=\"C1\"\n",
" )\n",
"axs[0].set_title(\"Training\")\n",
"axs[1].set_title(\"Generated\")\n",
"for i in range(2):\n",
" axs[i].set_xticks([])\n",
" axs[i].set_yticks([])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The samples are not perfect, but we're close. Let's examine the marginals. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axs = plt.subplots(ncols=4, squeeze=True, figsize=(16, 4))\n",
"for i, j in enumerate(range(1, 9, 2)):\n",
" axs[i].set_title(f\"Dist between 0-{j}\")\n",
" sns.distplot(np.linalg.norm(data[:, 0] - data[:, j], axis=1), ax=axs[i])\n",
" sns.distplot(\n",
" np.linalg.norm(sampled_data[:, 0] - sampled_data[:, j], axis=1),\n",
" ax=axs[i],\n",
" hist=False,\n",
" )\n",
"plt.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see that there are some issues here as well. Remember that our latent space is quite small: 2D. So we should not be that surprised that we're losing information from our 24D input space. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Optimization on Latent Space\n",
"\n",
"Finally, let us examine how we can optimize in the latent space. Let's say I want to find the most compact structure. We'll define our loss function as the radius of gyration and take its derivative with respect to $z$. Recall the definition of radius of gyration is\n",
"\n",
"\\begin{equation}\n",
"R_g = \\frac{1}{N}\\sum_i r_i^2\n",
"\\end{equation}\n",
"\n",
"where $r_i$ is distance to center of mass. Our generated samples are, by definition, centered at the origin though so we do not need to worry about center of mass. We want to take derivatives in $z$, but need samples in $x$ to compute radius of gyration. We use the decoder to get an $x$ and can propagate derivatives through it, because it is a differentiable neural network."
]
},
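{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sketch below computes $R_g$ with the center-of-mass subtraction included (equal bead masses assumed) and checks two facts the notebook relies on: $R_g$ does not change when a structure is translated, and for a centered structure the square root of the summed squared coordinates, which is what `rg_loss` returns, equals $\\sqrt{N} R_g$ and so has the same minimizer. The function name `radius_of_gyration` is hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def radius_of_gyration(x):\n",
"    # x: (N, 2) bead positions; subtract the center of mass (equal masses assumed)\n",
"    r = x - x.mean(axis=0)\n",
"    return np.sqrt(np.mean(np.sum(r**2, axis=1)))\n",
"\n",
"\n",
"square = np.array([[-1.0, -1.0], [1.0, -1.0], [1.0, 1.0], [-1.0, 1.0]])\n",
"print(radius_of_gyration(square))  # each bead is sqrt(2) from the center\n",
"print(radius_of_gyration(square + 5.0))  # translation does not change R_g\n",
"# the unnormalized quantity used in rg_loss, for a centered structure\n",
"print(np.sqrt(np.sum(square**2)), np.sqrt(len(square)) * radius_of_gyration(square))\n",
"```"
]
},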
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def rg_loss(z):\n",
" x = decoder(z, theta).reshape(-1, 2)\n",
" rg = jnp.sum(x**2)\n",
" return jnp.sqrt(rg)\n",
"\n",
"\n",
"rg_grad = jax.jit(jax.grad(rg_loss))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we will find the $z$ that minimizes the radius of gyration by using gradient descent with the derivative. "
]
},
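{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same idea can be sketched without a trained network. Here a hypothetical linear map `W @ z + b` stands in for the decoder so that the chain-rule gradient can be written by hand (the notebook instead gets it from `jax.grad` through the MLP decoder), and the loss is the unnormalized squared radius of gyration. Under this stand-in, gradient descent on $z$ should converge to the least-squares solution of $W z \\approx -b$.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n_beads, latent_dim = 12, 2\n",
"# hypothetical linear stand-in for a trained decoder: x = W z + b\n",
"W = rng.normal(size=(2 * n_beads, latent_dim))\n",
"b = rng.normal(size=2 * n_beads)\n",
"\n",
"\n",
"def rg2_loss(z):\n",
"    # squared radius of gyration up to the constant 1/N factor (no COM subtraction)\n",
"    x = W @ z + b\n",
"    return np.sum(x**2)\n",
"\n",
"\n",
"def rg2_grad(z):\n",
"    # chain rule through the linear decoder: d/dz sum((W z + b)**2) = 2 W^T (W z + b)\n",
"    return 2.0 * W.T @ (W @ z + b)\n",
"\n",
"\n",
"z = rng.normal(size=latent_dim)\n",
"eta = 1e-2\n",
"losses = []\n",
"for _ in range(500):\n",
"    losses.append(rg2_loss(z))\n",
"    z = z - eta * rg2_grad(z)\n",
"\n",
"z_exact = np.linalg.lstsq(W, -b, rcond=None)[0]\n",
"print(losses[0], losses[-1], np.abs(z - z_exact).max())\n",
"```"
]
},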
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"z = jax.random.normal(key, shape=[latent_dim])\n",
"losses = []\n",
"eta = 1e-2\n",
"for i in range(100):\n",
" losses.append(rg_loss(z))\n",
" g = rg_grad(z)\n",
" z -= eta * g\n",
"plt.plot(losses)\n",
"plt.xlabel(\"Iterations\")\n",
"plt.ylabel(\"$R_g$\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have a $z$ with a very low radius of gyration. How good is it? Well, we can also see what was the lowest radius of gyration *observed* structure in our trajectory. We compare them below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get min from training\n",
"train_rgmin = np.argmin(np.sum(data**2, axis=(1, 2)))\n",
"# use new z\n",
"opt_rgmin = decoder(z, theta).reshape(-1, 2)\n",
"plt.plot(\n",
" data[train_rgmin, :, 0], data[train_rgmin, :, 1], \"o-\", label=\"Training\", alpha=0.8\n",
")\n",
"plt.plot(opt_rgmin[:, 0], opt_rgmin[:, 1], \"o-\", label=\"Optimized\", alpha=0.8)\n",
"plt.xticks([])\n",
"plt.yticks([])\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What is remarkable about this is that the optimized one has no overlaps and still reasonable bond-lengths. It is also more compact than the lowest radius of gyration found in the training example. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Relevant Videos\n",
"\n",
"### Using VAE for Coarse-Grained Molecular Simulation\n",
"\n",
"\n",
"\n",
"### Using VAE for Molecular Graph Generation\n",
"\n",
"\n",
"\n",
"### Review of Molecular Graph Generative Models (including VAE)\n",
"\n",
""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chapter Summary \n",
"\n",
"* A variational autoencoder is a generative deep learning model capable of unsupervised learning. It is capable of of generating new data points not seen in training.\n",
"* A VAE is a set of two trained conditional probability distributions that operate on examples from the data $x$ and the latent space $z$. The encoder goes from data to latent and the decoder goes from latent to data.\n",
"* The loss function is the log likelihood that we observed the training point $x_i$.\n",
"* Taking the log allows us to sum/average over data to aggregate multiple points.\n",
"* The VAE can be used for both discrete or continuous features.\n",
"* The goal with VAE is to reproduce the probability distribution of $x$. Comparing the distribution over $z$ and that of $x$ allows us to evaluate how well the VAE operates.\n",
"* A bead-spring polymer VAE example shows how VAEs operate on a trajectory. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cited References\n",
"\n",
"```{bibliography}\n",
":style: unsrtalpha\n",
":filter: docname in docnames\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}