6. Deep Learning Overview

Deep learning is a category of machine learning, and machine learning is a category of artificial intelligence. Deep learning is the use of neural networks to do machine learning tasks such as classifying and regressing data. This chapter provides an overview; we will dive further into these topics in later chapters.

Audience & Objectives

This chapter builds on Regression and Introduction to Machine Learning. After completing this chapter, you should be able to

  • Define deep learning

  • Define a neural network

  • Connect the previous regression principles to neural networks

There are many good resources on deep learning to supplement these chapters. The goal of this book is to present a chemistry and materials-first introduction to deep learning. These other resources can provide better depth in certain topics and cover topics we skip because I do not find them relevant to deep learning in chemistry (e.g., image processing). I found the introduction from Ian Goodfellow’s book to be a good starting point. If you’re more visually oriented, Grant Sanderson has made a short video series specifically about neural networks that gives an applied introduction to the topic. DeepMind has a high-level video showing what can be accomplished with deep learning & AI. When people write “deep learning is a powerful tool” in their research papers, they typically cite this Nature paper by Yann LeCun, Yoshua Bengio, and Geoffrey Hinton. Zhang, Lipton, Li, and Smola have written a practical and example-driven online book that gives each example in TensorFlow, PyTorch, and MXNet. You can find many chemistry-specific examples and information about deep learning in chemistry via the excellent DeepChem project. Finally, some deep learning packages provide a short introduction to deep learning via tutorials of their APIs: Keras, PyTorch.

The main advice I would give to beginners in deep learning is to focus less on the neurologically inspired language (i.e., connections between neurons), and instead view deep learning as a series of linear algebra operations in which many of the matrices are filled with adjustable parameters. Of course, non-linear functions (activations) are used to join the linear algebra operations, but deep learning is essentially linear algebra operations specified via a “computation network” (aka computation graph) that vaguely looks like neurons connected in a brain.

Non-linearity

A function \(f(\vec{x})\) is linear if two conditions hold:

(6.1)\[\begin{equation} f(\vec{x} + \vec{y}) = f(\vec{x}) + f(\vec{y}) \end{equation}\]

for all \(\vec{x}\) and \(\vec{y}\). And

(6.2)\[\begin{equation} f(s\vec{x}) = sf(\vec{x}) \end{equation}\]

where \(s\) is a scalar. A function is non-linear if either condition fails for some choice of \(\vec{x}\), \(\vec{y}\), or \(s\).
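
To make these conditions concrete, here is a minimal NumPy check (the matrix and vectors are made up purely for illustration) showing that a matrix multiplication satisfies both conditions, while adding a constant offset already breaks the first one.

import numpy as np

A = np.array([[2.0, -1.0], [0.5, 3.0]])
c = np.array([1.0, 1.0])
x, y, s = np.array([1.0, 2.0]), np.array([-3.0, 0.5]), 2.5

f = lambda v: A @ v      # linear map
g = lambda v: A @ v + c  # affine map, not linear

# condition (6.1): f(x + y) == f(x) + f(y)
print(np.allclose(f(x + y), f(x) + f(y)))  # True
print(np.allclose(g(x + y), g(x) + g(y)))  # False

# condition (6.2): f(s x) == s f(x)
print(np.allclose(f(s * x), s * f(x)))     # True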

6.1. What is a neural network?

The deep in deep learning means we have many layers in our neural networks. What is a neural network? Without loss of generality, we can view neural networks as having two components: (1) a non-linear function \(g(\cdot)\), which operates on our input features \(\mathbf{X}\) and outputs a new set of features \(\mathbf{H} = g(\mathbf{X})\), and (2) a linear model like the one we saw in our Introduction to Machine Learning. Our model equation for deep learning regression is:

(6.3)\[\begin{equation} \hat{y} = \vec{w}g(\vec{x}) + b \end{equation}\]

One of the main discussion points in our ML chapters was how arcane and difficult it is to choose features. Here, we have replaced our features with a set of trainable features \(g(\vec{x})\) and then use the same linear model as before. So how do we design \(g(\vec{x})\)? That is the deep learning part. \(g(\vec{x})\) is a differentiable function composed of layers, which are themselves differentiable functions each with trainable weights (free variables). Deep learning is a mature field and there is a set of standard layers, each with a different purpose. For example, convolution layers look at a fixed neighborhood around each element of an input tensor. Dropout layers randomly inactivate inputs as a form of regularization. The most commonly used and basic layer is the dense or fully-connected layer.

A dense layer is defined by two things: the desired output feature shape and the activation. The equation is:

(6.4)\[\begin{equation} \vec{h} = \sigma(\mathbf{W}\vec{x} + \vec{b}) \end{equation}\]

where \(\mathbf{W}\) is a trainable \(F \times D\) matrix, where \(D\) is the input vector (\(\vec{x}\)) dimension and \(F\) is the output vector (\(\vec{h}\)) dimension, \(\vec{b}\) is a trainable \(F\)-dimensional vector, and \(\sigma(\cdot)\) is the activation function. \(F\) is an example of a hyperparameter: it is not trainable but is a problem-dependent choice. \(\sigma(\cdot)\) is another hyperparameter. In principle, any differentiable function with a domain of \((-\infty, \infty)\) can be used for activation. However, only a handful of activations are used in practice because they have been found empirically to balance computational cost and effectiveness. One example we’ve seen before is the sigmoid. Another is the hyperbolic tangent, which behaves similarly (domain/range) to the sigmoid. The most commonly used activation is the rectified linear unit (ReLU), which is

(6.5)\[\begin{equation} \sigma(x) = \left\{\begin{array}{lr} x & x > 0\\ 0 & \textrm{otherwise}\\ \end{array}\right. \end{equation}\]
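
To connect Equations (6.3)–(6.5), here is a minimal NumPy sketch of a one-hidden-layer network with ReLU activation. The dimensions and random weights are arbitrary choices for illustration, not trained values.

import numpy as np


def relu(x):
    # eq. (6.5): pass positive values through, zero out the rest
    return np.where(x > 0, x, 0.0)


rng = np.random.default_rng(0)
D, F = 17, 32                # input dimension and hidden (dense layer output) dimension

W = rng.normal(size=(F, D))  # trainable weights of the dense layer
b = rng.normal(size=F)       # trainable bias of the dense layer
w = rng.normal(size=F)       # weights of the final linear model
b_out = rng.normal()         # bias of the final linear model

x = rng.normal(size=D)       # one input feature vector
h = relu(W @ x + b)          # eq. (6.4): hidden features g(x)
yhat = w @ h + b_out         # eq. (6.3): regression output
print(h.shape, yhat)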

6.1.1. Universal Approximation Theorem

One of the reasons that neural networks are a good choice for approximating unknown functions (\(f(\vec{x})\)) is that a neural network can approximate any function given a large enough network depth (number of layers) or width (size of hidden layers). There are many variations of this theorem, covering infinitely wide or infinitely deep neural networks. For example, any one-dimensional function can be approximated by a depth-5 neural network with ReLU activation functions and infinitely wide layers (infinite hidden dimension) [LPW+17]. The universal approximation theorem shows that neural networks are, in the limit of large depth or width, expressive enough to fit any function.
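
As a small illustration (not a proof), the sketch below fits a wiggly one-dimensional function with a single wide tanh hidden layer, using the Keras framework introduced in the next subsection. The target function, width, and number of epochs are arbitrary choices.

import numpy as np
import tensorflow as tf

# made-up 1D target function to approximate
x = np.linspace(-3, 3, 256).reshape(-1, 1).astype(np.float32)
y = np.sin(2 * x) + 0.3 * x**2

approx = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(256, activation="tanh"),  # one wide hidden layer
        tf.keras.layers.Dense(1),                       # linear readout
    ]
)
approx.compile(optimizer="adam", loss="mean_squared_error")
approx.fit(x, y, epochs=500, verbose=0)
print(approx.evaluate(x, y, verbose=0))  # loss should be small after training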

6.1.2. Frameworks

Deep learning has lots of “gotchas”: mistakes that are easy to make and that make it difficult to implement things yourself. This is especially true with numerical stability, which only reveals itself when your model fails to learn. We will move to a somewhat more abstract software framework than JAX for some examples. We’ll use Keras, which is one of many possible choices for deep learning frameworks.

6.1.3. Discussion

When it comes to introducing deep learning, I will be as terse as possible. There are good learning resources out there. You should use some of the reading above and tutorials put out by Keras (or PyTorch) to get familiar with the concepts of neural networks and learning.

6.2. Revisiting Solubility Model

We’ll see our first example of deep learning by revisiting the solubility dataset with a two-layer dense neural network.

6.2.1. Running This Notebook

This page can be launched as an interactive Google Colab notebook. See details below on installing packages, either in your own environment or on Google Colab.

The hidden cells below set up our imports and install the necessary packages.

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl
import tensorflow as tf
import numpy as np

np.random.seed(0)  # make this notebook reproducible
import warnings

warnings.filterwarnings("ignore")
sns.set_context("notebook")
sns.set_style(
    "dark",
    {
        "xtick.bottom": True,
        "ytick.left": True,
        "xtick.color": "#666666",
        "ytick.color": "#666666",
        "axes.edgecolor": "#666666",
        "axes.linewidth": 0.8,
    },
)
color_cycle = ["#1BBC9B", "#F06060", "#5C4B51", "#F3B562", "#6e5687"]
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=color_cycle)

6.2.2. Load Data

We download the data and load it into a Pandas data frame and then standardize our features as before.

# soldata = pd.read_csv('https://dataverse.harvard.edu/api/access/datafile/3407241?format=original&gbrecs=true')
# had to rehost because dataverse isn't reliable
soldata = pd.read_csv(
    "https://github.com/whitead/dmol-book/raw/master/data/curated-solubility-dataset.csv"
)
features_start_at = list(soldata.columns).index("MolWt")
feature_names = soldata.columns[features_start_at:]
# standardize the features
soldata[feature_names] -= soldata[feature_names].mean()
soldata[feature_names] /= soldata[feature_names].std()

6.3. Prepare Data for Keras

The deep learning libraries simplify many common tasks, like splitting data and building layers. The code below builds our dataset from NumPy arrays.

full_data = tf.data.Dataset.from_tensor_slices(
    (soldata[feature_names].values, soldata["Solubility"].values)
)
N = len(soldata)
test_N = int(0.1 * N)
test_data = full_data.take(test_N).batch(16)
train_data = full_data.skip(test_N).batch(16)

Notice that we used skip and take (see tf.data.Dataset) to split our dataset into two pieces and create batches of data.
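
If take, skip, and batch are new to you, this toy example (with made-up integers 0–9) shows what each call does.

import numpy as np
import tensorflow as tf

toy = tf.data.Dataset.from_tensor_slices(np.arange(10))
test = toy.take(3).batch(2)   # first 3 elements, batched in pairs
train = toy.skip(3).batch(2)  # the remaining 7 elements, batched in pairs
print([t.numpy().tolist() for t in test])   # [[0, 1], [2]]
print([t.numpy().tolist() for t in train])  # [[3, 4], [5, 6], [7, 8], [9]]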

6.4. Neural Network

Now we build our neural network model. In this case, our \(g(\vec{x}) = \sigma\left(\mathbf{W^0}\vec{x} + \vec{b}\right)\). We will call the function \(g(\vec{x})\) a hidden layer, because we do not observe its output. Remember, the solubility will be \(y = \vec{w}g(\vec{x}) + b\). We’ll choose our activation, \(\sigma(\cdot)\), to be tanh and the output dimension of the hidden layer to be 32. The choice of tanh is empirical: there are many choices of non-linearity, and they are typically chosen based on efficiency and empirical accuracy. You can read more about this Keras API here; however, you should be able to understand the process from the function names and comments.

# our hidden layer
# We only need to define the output dimension - 32.
hidden_layer = tf.keras.layers.Dense(32, activation="tanh")
# Last layer - which we want to output one number
# the predicted solubility.
output_layer = tf.keras.layers.Dense(1)

# Now we put the layers into a sequential model
model = tf.keras.Sequential()
model.add(hidden_layer)
model.add(output_layer)

# our model is complete

# Try out our model on first few datapoints
model(soldata[feature_names].values[:3])
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[0.6522273 ],
       [0.04297334],
       [0.36673006]], dtype=float32)>

We can see our model predicting the solubility for three molecules above. There may be a warning that our Pandas data uses float64 (double precision floating point numbers) while our model uses float32 (single precision). This doesn’t matter much: it warns us because we are technically throwing out a little bit of precision, but our solubility has much more variance than the difference between 32-bit and 64-bit floating point numbers. We can remove this warning by modifying the last line to be:

model(soldata[feature_names].values[:3].astype(float))
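
Equivalently, you could cast the features to single precision explicitly so the input dtype matches the model’s float32 weights; this is a small variation on the line above, not what the rest of the chapter does.

model(soldata[feature_names].values[:3].astype(np.float32))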

At this point, we’ve defined how our model structure should work and it can be called on data. Now we need to train it! We prepare the model for training by calling model.compile, which is where we define our optimization method (typically a flavor of stochastic gradient descent) and the loss.

model.compile(optimizer="SGD", loss="mean_squared_error")
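
If you want more control, you can pass an optimizer object instead of a string; the learning rate below is an arbitrary illustration, not a recommendation. model.summary() then prints the layer output shapes and trainable parameter counts.

optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)
model.compile(optimizer=optimizer, loss="mean_squared_error")
model.summary()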

Look back at how much work it took to set up the loss and optimization process previously! Now we can train our model:

model.fit(train_data, epochs=50)
Epoch 1/50
562/562 [==============================] - 1s 1ms/step - loss: 2.0916
Epoch 2/50
562/562 [==============================] - 1s 1ms/step - loss: 1.6184
...
Epoch 49/50
562/562 [==============================] - 1s 1ms/step - loss: 1.2070
Epoch 50/50
562/562 [==============================] - 1s 1ms/step - loss: 1.2051
<keras.callbacks.History at 0x7f455ccc2340>

That was quite simple!

For reference, we got a loss of about 3 at best in our previous work. Training was also much faster, thanks to the library’s optimizations. Now let’s see how our model did on the test data:

# get model predictions on test data and get labels
# squeeze to remove extra dimensions
yhat = np.squeeze(model.predict(test_data))
test_y = soldata["Solubility"].values[:test_N]
plt.plot(test_y, yhat, ".")
plt.plot(test_y, test_y, "-")
plt.xlabel("Measured Solubility $y$")
plt.ylabel("Predicted Solubility $\hat{y}$")
plt.text(
    min(test_y) + 1,
    max(test_y) - 2,
    f"correlation = {np.corrcoef(test_y, yhat)[0,1]:.3f}",
)
plt.text(
    min(test_y) + 1,
    max(test_y) - 3,
    f"loss = {np.sqrt(np.mean((test_y - yhat)**2)):.3f}",
)
plt.show()
[Figure: parity plot of predicted vs. measured solubility on the test data, annotated with the correlation coefficient and loss.]

This performance is better than our simple linear model.
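
Keras can also compute the test loss directly; this quick check (not part of the original analysis) reports the same mean squared error used in training, averaged over the batched test data, rather than the root mean squared error shown on the plot.

test_loss = model.evaluate(test_data, verbose=0)
print(test_loss)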

6.5. Exercises

  1. Make a plot of the ReLU function. Prove it is non-linear.

  2. Try increasing the number of layers in the neural network. Discuss what you see in the context of the bias-variance trade-off.

  3. Show that a neural network would be equivalent to linear regression if \(\sigma(\cdot)\) were the identity function.

  4. What are the advantages and disadvantages of using deep learning instead of non-linear regression for fitting data? When might you choose non-linear regression over deep learning?

6.6. Chapter Summary

  • Deep learning is a category of machine learning that utilizes neural networks for classification and regression of data.

  • Neural networks are a series of operations with matrices of adjustable parameters.

  • A neural network transforms input features into a new set of features that can be subsequently used for regression or classification.

  • The most common layer is the dense layer. Each input element affects each output element. It is defined by the desired output feature shape and the activation function.

  • With enough layers or wide enough hidden layers, neural networks can approximate unknown functions.

  • Hidden layers are called that because we do not directly observe their output.

  • Libraries such as TensorFlow make it easy both to split data into training and testing sets and to build the layers of a neural network.

  • Building a neural network allows us to predict various properties of molecules, such as solubility.

6.7. Cited References

LPW+17

Zhou Lu, Hongming Pu, Feicheng Wang, Zhiqiang Hu, and Liwei Wang. The expressive power of neural networks: a view from the width. In Proceedings of the 31st International Conference on Neural Information Processing Systems, 6232–6240. 2017.