Previously, we discussed how to save and serialize your models to disk after training is complete. We also learned how to spot underfitting and overfitting as they happen, enabling you to kill off experiments that are not performing well while keeping the models that show promise during training.
However, you might be wondering if it’s possible to combine both of these strategies. Can we serialize models whenever our loss/accuracy improves? Or is it possible to serialize only the best model (i.e., the one with the lowest loss or highest accuracy) during the training process? You bet. And luckily, we don’t have to build a custom callback either — this functionality is baked right into Keras.
To learn how to use the ModelCheckpoint callback with Keras and TensorFlow, just keep reading.
How to use the ModelCheckpoint callback with Keras and TensorFlow
A good application of checkpointing is to serialize your network to disk each time there is an improvement during training. We define an “improvement” to be either a decrease in loss or an increase in accuracy — we’ll set this parameter inside the actual Keras callback.
In this example, we’ll be training the MiniVGGNet architecture on the CIFAR-10 dataset and then serializing our network weights to disk each time model performance improves. To get started, open a new file, name it cifar10_checkpoint_improvements.py, and insert the following code:
# import the necessary packages
from sklearn.preprocessing import LabelBinarizer
from pyimagesearch.nn.conv import MiniVGGNet
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.datasets import cifar10
import argparse
import os
Lines 2-8 import our required Python packages. Take note of the ModelCheckpoint class imported on Line 4 — this class will enable us to checkpoint and serialize our networks to disk whenever we find an incremental improvement in model performance.
Next, let’s parse our command line arguments:
# construct the argument parse and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-w", "--weights", required=True,
	help="path to weights directory")
args = vars(ap.parse_args())
The only command line argument we need is --weights, the path to the output directory that will store our serialized models during the training process. We then perform our standard routine of loading the CIFAR-10 dataset from disk, scaling the pixel intensities to the range [0, 1], and one-hot encoding the labels:
# load the training and testing data, then scale it into the
# range [0, 1]
print("[INFO] loading CIFAR-10 data...")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0

# convert the labels from integers to vectors
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
Given our data, we are now ready to initialize our SGD optimizer along with the MiniVGGNet architecture:
# initialize the optimizer and model
print("[INFO] compiling model...")
opt = SGD(lr=0.01, decay=0.01 / 40, momentum=0.9, nesterov=True)
model = MiniVGGNet.build(width=32, height=32, depth=3, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt,
	metrics=["accuracy"])
We’ll use the SGD optimizer with an initial learning rate of α = 0.01 and slowly decay it over the course of 40 epochs. We’ll also apply a momentum of γ = 0.9 and indicate that Nesterov acceleration should be used.
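If you’re curious what decay=0.01 / 40 actually does, below is a minimal sketch of the schedule implemented by the legacy Keras decay argument (a standalone illustration, not part of the training script):

# a minimal sketch of the legacy Keras "decay" schedule: the learning
# rate shrinks with every batch update t as lr_t = lr0 / (1 + decay * t)
lr0 = 0.01
decay = 0.01 / 40

def decayed_lr(iteration):
	return lr0 / (1.0 + decay * iteration)

# with 50,000 training images and a batch size of 64 (782 updates per
# epoch), the learning rate after 40 epochs has dropped to roughly:
print(decayed_lr(40 * 782))  # ~0.0011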
The MiniVGGNet architecture is instantiated to accept input images with a width of 32 pixels, a height of 32 pixels, and a depth of 3 (number of channels). We set classes=10 since the CIFAR-10 dataset has ten possible class labels.
The critical step to checkpointing our network can be found in the code block below:
# construct the callback to save only the *best* model to disk
# based on the validation loss
fname = os.path.sep.join([args["weights"],
	"weights-{epoch:03d}-{val_loss:.4f}.hdf5"])
checkpoint = ModelCheckpoint(fname, monitor="val_loss", mode="min",
	save_best_only=True, verbose=1)
callbacks = [checkpoint]
On Lines 37 and 38, we construct a special filename (fname) template string that Keras uses when writing our models to disk. The first variable in the template, {epoch:03d}, is our epoch number, written out to three digits.

The second variable is the metric we want to monitor for improvement, {val_loss:.4f}, the validation loss for the current epoch. Of course, if we wanted to monitor the validation accuracy, we could replace val_loss with val_acc (or val_accuracy in newer versions of TensorFlow/Keras). If we instead wanted to monitor the training loss and accuracy, the variables would become loss and acc, respectively (although I would recommend monitoring your validation metrics, as they will give you a better sense of how your model will generalize).
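As a quick sanity check, here is how Python’s string formatting fills in those two template fields (a standalone illustration with hypothetical values, not part of the training script):

# hypothetical values for one epoch, just to show how the template
# expands into the filenames we'll see on disk
fname = "weights-{epoch:03d}-{val_loss:.4f}.hdf5"
print(fname.format(epoch=3, val_loss=0.81757))
# weights-003-0.8176.hdf5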
Once the output filename template is defined, we then instantiate the ModelCheckpoint class on Lines 39 and 40. The first parameter to ModelCheckpoint is the string representing our filename template. We then pass in what we would like to monitor. In this case, we would like to monitor the validation loss (val_loss).
The mode parameter controls whether the ModelCheckpoint should be looking for values that minimize our metric or maximize it. Since we are working with loss, lower is better, so we set mode="min". If we were instead working with val_acc, we would set mode="max" (since higher accuracy is better).
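For example, a hypothetical accuracy-monitoring variant (a drop-in replacement inside the script above, not something we run in this tutorial) would look like the following — note the template field changes along with the monitored metric:

# hypothetical variant: monitor validation accuracy instead, where
# higher is better, so mode flips to "max" (use "val_accuracy" on
# newer versions of TensorFlow/Keras)
fname = os.path.sep.join([args["weights"],
	"weights-{epoch:03d}-{val_acc:.4f}.hdf5"])
checkpoint = ModelCheckpoint(fname, monitor="val_acc", mode="max",
	save_best_only=True, verbose=1)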
Setting save_best_only=True ensures that the latest best model (according to the metric monitored) will not be overwritten. Finally, the verbose=1 setting simply logs a notification to our terminal when a model is being serialized to disk during training.
Line 41 then constructs a list of callbacks — the only callback we need is our checkpoint.
The last step is to simply train the network and allow our checkpoint to take care of the rest:
# train the network
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY),
	batch_size=64, epochs=40, callbacks=callbacks, verbose=2)
To execute our script, simply open a terminal and execute the following command:
$ python cifar10_checkpoint_improvements.py --weights weights/improvements
[INFO] loading CIFAR-10 data...
[INFO] compiling model...
[INFO] training network...
Train on 50000 samples, validate on 10000 samples
Epoch 1/40
171s - loss: 1.6700 - acc: 0.4375 - val_loss: 1.2697 - val_acc: 0.5425
Epoch 2/40
Epoch 00001: val_loss improved from 1.26973 to 0.98481, saving model to test/weights-001-0.9848.hdf5
...
Epoch 40/40
Epoch 00039: val_loss did not improve
315s - loss: 0.2594 - acc: 0.9075 - val_loss: 0.5707 - val_acc: 0.8190
As we can see from my terminal output and Figure 1, every time the validation loss decreases, we save a new serialized model to disk.
At the end of the training process, we have 18 separate files, one for each incremental improvement:
$ find ./ -printf "%f\n" | sort
./
weights-000-1.2697.hdf5
weights-001-0.9848.hdf5
weights-003-0.8176.hdf5
weights-004-0.7987.hdf5
weights-005-0.7722.hdf5
weights-006-0.6925.hdf5
weights-007-0.6846.hdf5
weights-008-0.6771.hdf5
weights-009-0.6212.hdf5
weights-012-0.6121.hdf5
weights-013-0.6101.hdf5
weights-014-0.5899.hdf5
weights-015-0.5811.hdf5
weights-017-0.5774.hdf5
weights-019-0.5740.hdf5
weights-022-0.5724.hdf5
weights-024-0.5628.hdf5
weights-033-0.5546.hdf5
As you can see, each filename has three components. The first is a static string, weights. We then have the epoch number. The final component of the filename is the metric we are measuring for improvement, which in this case is validation loss.
Our best validation loss was obtained on epoch 33 with a value of 0.5546. We could then take this model and load it from disk.
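As a minimal sketch, loading that checkpoint back from disk might look like the following (the filename assumes the epoch-33 result shown above; yours will differ):

# a minimal sketch: load the serialized epoch-33 checkpoint from disk
# (the exact filename depends on your own training run)
from tensorflow.keras.models import load_model

model = load_model("weights/improvements/weights-033-0.5546.hdf5")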
Keep in mind that your results will not match mine, as neural networks are stochastic and their weights are randomly initialized. Depending on the initial values, you might have dramatically different model checkpoints, but at the end of the training process, our networks should obtain similar accuracy (± a few percentage points).
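If you want runs that are closer to repeatable, one option is to seed the relevant random number generators before training — a sketch assuming TensorFlow 2 (exact determinism, especially on a GPU, still isn’t guaranteed):

# optional: seed the RNGs for more repeatable (though not perfectly
# deterministic) training runs
import random
import numpy as np
import tensorflow as tf

random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)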
Configuring your development environment
To follow this guide, you need to have TensorFlow, scikit-learn, and OpenCV installed on your system.
Luckily, all three are pip-installable:
$ pip install tensorflow scikit-learn opencv-contrib-python
If you need help configuring your development environment for OpenCV, I highly recommend that you read my pip install OpenCV guide — it will have you up and running in a matter of minutes.
Having problems configuring your development environment?
All that said, are you:
- Short on time?
- Learning on your employer’s administratively locked system?
- Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
- Ready to run the code right now on your Windows, macOS, or Linux system?
Then join PyImageSearch University today!
Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.
And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!
Checkpointing Best Neural Network Only
Perhaps the biggest downside of checkpointing incremental improvements is that we end up with a bunch of extra files that we are unlikely to be interested in. This is especially true if our validation loss moves up and down over training epochs — each of these incremental improvements will be captured and serialized to disk. In this case, it’s best to save only one model and simply overwrite it every time our metric improves during training.
Luckily, accomplishing this action is as simple as updating the ModelCheckpoint class to accept a simple string (i.e., a file path without any template variables). Then, whenever our metric improves, that file is simply overwritten. To understand the process, let’s create a second Python file named cifar10_checkpoint_best.py and review the differences.
First, we need to import our required Python packages:
# import the necessary packages
from sklearn.preprocessing import LabelBinarizer
from pyimagesearch.nn.conv import MiniVGGNet
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.datasets import cifar10
import argparse
Then parse our command line arguments:
# construct the argument parse and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-w", "--weights", required=True,
	help="path to best model weights file")
args = vars(ap.parse_args())
The name of the command line argument itself is the same (--weights), but the description of the switch is now different: “path to best model weights file.” Thus, this command line argument will be a simple string to an output path — there will be no template applied to this string.
From there we can load our CIFAR-10 dataset and prepare it for training:
# load the training and testing data, then scale it into the
# range [0, 1]
print("[INFO] loading CIFAR-10 data...")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0

# convert the labels from integers to vectors
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)
As well as initialize our SGD optimizer and MiniVGGNet architecture:
# initialize the optimizer and model
print("[INFO] compiling model...")
opt = SGD(lr=0.01, decay=0.01 / 40, momentum=0.9, nesterov=True)
model = MiniVGGNet.build(width=32, height=32, depth=3, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt,
	metrics=["accuracy"])
We are now ready to update the ModelCheckpoint code:
# construct the callback to save only the *best* model to disk
# based on the validation loss
checkpoint = ModelCheckpoint(args["weights"], monitor="val_loss",
	save_best_only=True, verbose=1)
callbacks = [checkpoint]
Notice how the fname template string is gone — all we are doing is supplying the value of --weights to ModelCheckpoint. Since there are no template values to fill in, Keras will simply overwrite the existing serialized weights file whenever our monitoring metric improves (in this case, validation loss).
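Also note that we omitted the mode argument here; when it is left out, Keras defaults to mode="auto" and infers the direction from the metric’s name. The call above is therefore equivalent to the following explicit version:

# equivalent call with the mode made explicit: "val_loss" should be
# minimized, which is exactly what mode="auto" infers from the name
checkpoint = ModelCheckpoint(args["weights"], monitor="val_loss",
	mode="min", save_best_only=True, verbose=1)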
Finally, we train our network in the code block below:
# train the network
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY),
	batch_size=64, epochs=40, callbacks=callbacks, verbose=2)
To execute our script, issue the following command:
$ python cifar10_checkpoint_best.py \
	--weights weights/best/cifar10_best_weights.hdf5
[INFO] loading CIFAR-10 data...
[INFO] compiling model...
[INFO] training network...
Train on 50000 samples, validate on 10000 samples
Epoch 1/40
Epoch 00000: val_loss improved from inf to 1.26677, saving model to test_best/cifar10_best_weights.hdf5
305s - loss: 1.6657 - acc: 0.4441 - val_loss: 1.2668 - val_acc: 0.5584
Epoch 2/40
Epoch 00001: val_loss improved from 1.26677 to 1.21923, saving model to test_best/cifar10_best_weights.hdf5
309s - loss: 1.1996 - acc: 0.5828 - val_loss: 1.2192 - val_acc: 0.5798
...
Epoch 40/40
Epoch 00039: val_loss did not improve
173s - loss: 0.2615 - acc: 0.9079 - val_loss: 0.5511 - val_acc: 0.8250
Here, you can see that we overwrite our cifar10_best_weights.hdf5 file with the updated network only if our validation loss decreases. This has two primary benefits:
- There is only one serialized file at the end of the training process — the model epoch that obtained the lowest loss.
- We are not capturing “incremental improvements” where loss fluctuates up and down. Instead, we only save and overwrite the existing best model if our metric obtains a loss lower than all previous epochs.
To confirm this, take a look at my weights/best directory, where you can see there is only one output file:
$ ls -l weights/best/
total 17024
-rw-rw-r-- 1 adrian adrian 17431968 Apr 28 09:47 cifar10_best_weights.hdf5
You can then take this serialized MiniVGGNet and further evaluate it on the testing data or apply it to your own images.
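As a minimal sketch (assuming the same CIFAR-10 preprocessing as in the script above), loading and evaluating that single best model might look like this:

# a minimal sketch: load the single best model and evaluate it on the
# CIFAR-10 testing data prepared earlier in the script
from tensorflow.keras.models import load_model

model = load_model("weights/best/cifar10_best_weights.hdf5")
(loss, acc) = model.evaluate(testX, testY, batch_size=64, verbose=0)
print("[INFO] loss: {:.4f}, accuracy: {:.4f}".format(loss, acc))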
What's next? We recommend PyImageSearch University.
86 total classes • 115+ hours of on-demand code walkthrough videos • Last updated: October 2024
★★★★★ 4.84 (128 Ratings) • 16,000+ Students Enrolled
I strongly believe that if you had the right teacher you could master computer vision and deep learning.
Do you think learning computer vision and deep learning has to be time-consuming, overwhelming, and complicated? Or has to involve complex mathematics and equations? Or requires a degree in computer science?
That’s not the case.
All you need to master computer vision and deep learning is for someone to explain things to you in simple, intuitive terms. And that’s exactly what I do. My mission is to change education and how complex Artificial Intelligence topics are taught.
If you're serious about learning computer vision, your next stop should be PyImageSearch University, the most comprehensive computer vision, deep learning, and OpenCV course online today. Here you’ll learn how to successfully and confidently apply computer vision to your work, research, and projects. Join me in computer vision mastery.
Inside PyImageSearch University you'll find:
- ✓ 86 courses on essential computer vision, deep learning, and OpenCV topics
- ✓ 86 Certificates of Completion
- ✓ 115+ hours of on-demand video
- ✓ Brand new courses released regularly, ensuring you can keep up with state-of-the-art techniques
- ✓ Pre-configured Jupyter Notebooks in Google Colab
- ✓ Run all code examples in your web browser — works on Windows, macOS, and Linux (no dev environment configuration required!)
- ✓ Access to centralized code repos for all 540+ tutorials on PyImageSearch
- ✓ Easy one-click downloads for code, datasets, pre-trained models, etc.
- ✓ Access on mobile, laptop, desktop, etc.
Summary
In this tutorial, we reviewed how to monitor a given metric (e.g., validation loss, validation accuracy, etc.) during training and then save high-performing networks to disk. There are two methods to accomplish this inside Keras:
- Checkpoint incremental improvements.
- Checkpoint only the best model found during the process.
Personally, I prefer the latter over the former since it results in fewer files and a single output file that represents the best epoch found during the training process.
To download the source code to this post (and be notified when future tutorials are published here on PyImageSearch), simply enter your email address in the form below!
Download the Source Code and FREE 17-page Resource Guide
Enter your email address below to get a .zip of the code and a FREE 17-page Resource Guide on Computer Vision, OpenCV, and Deep Learning. Inside you'll find my hand-picked tutorials, books, courses, and libraries to help you master CV and DL!