Last updated on July 4, 2021.
In this tutorial, you will learn how the Keras .fit and .fit_generator functions work, including the differences between them. To help you gain hands-on experience, I’ve included a full example showing you how to implement a Keras data generator from scratch.
Today’s blog post is inspired by PyImageSearch reader, Shey.
Shey asks:
Hi Adrian, thanks for your tutorials. I’ve been methodically going through every one. They’ve really helped me learn deep learning.
I have a question about the Keras “.fit_generator” function.
I’ve noticed you use it quite a bit in your blog posts but I’m not really sure how the function is different than Keras’ standard “.fit” function.
How is it different? How do I know when to use each? And how do I create a data generator for the “.fit_generator” function?
Shey asks a great question.
The Keras deep learning library includes three separate functions that can be used to train your own models:
- .fit
- .fit_generator
- .train_on_batch
If you’re new to Keras and deep learning you may feel a bit overwhelmed trying to determine which function you’re supposed to use — this confusion is only compounded if you need to work with your own custom data.
To help lift the cloud of confusion regarding the Keras fit and fit_generator functions, I’m going to spend this tutorial discussing:
- The differences between Keras’ .fit, .fit_generator, and .train_on_batch functions
- When to use each when training your own deep learning models
- How to implement your own Keras data generator and utilize it when training a model using .fit_generator
- How to use the .predict_generator function when evaluating your network after training
To learn more about Keras’ .fit and .fit_generator functions, including how to train a deep learning model on your own custom dataset, just keep reading!
- Update July 2021: For TensorFlow 2.2+ users, just use the .fit method for your projects. The .fit_generator method is deprecated and will be removed in future releases of TensorFlow, as the .fit method can automatically detect whether the input data is an array or a generator.
How to use Keras fit and fit_generator (a hands-on tutorial)
2020-05-13 Update: This blog post is now TensorFlow 2+ compatible! TensorFlow is in the process of deprecating the .fit_generator method, which supported data augmentation. If you are using tensorflow==2.2.0 or tensorflow-gpu==2.2.0 (or higher), then you should use the .fit method, which now supports data augmentation. Please keep this in mind while reading this legacy tutorial. Of course, the concept of data augmentation stays the same. Note that the code in this tutorial has been updated for TensorFlow 2.2+ compatibility; however, you may still see in-text references to the legacy .fit_generator method.
In the first part of today’s tutorial we’ll discuss the differences between Keras’ .fit, .fit_generator, and .train_on_batch functions.
From there I’ll show you an example of a “non-standard” image dataset which doesn’t contain any actual PNG, JPEG, etc. images at all! Instead, the entire image dataset is represented by two CSV files, one for training and the second for evaluation.
Our goal will be to implement a Keras generator capable of training a network on this CSV image data (don’t worry, I’ll show you how to implement such a generator function from scratch).
Finally, we’ll train and evaluate our network.
When to use Keras’ fit, fit_generator, and train_on_batch functions?
Keras provides three functions that can be used to train your own deep learning models:
- .fit
- .fit_generator
- .train_on_batch
All three of these functions can essentially accomplish the same task — but how they go about doing it is very different.
Let’s explore each of these functions one by one, looking at an example function call and then discussing how they differ from each other.
The Keras .fit function
Let’s start with a call to .fit:
model.fit(trainX, trainY, batch_size=32, epochs=50)
Here you can see that we are supplying our training data (trainX) and training labels (trainY).
We then instruct Keras to allow our model to train for 50 epochs with a batch size of 32.
The call to .fit is making two primary assumptions here:
- Our entire training set can fit into RAM
- There is no data augmentation going on (i.e., there is no need for Keras generators)
Instead, our network will be trained on the raw data.
The raw data itself will fit into memory — we have no need to move old batches of data out of RAM and move new batches of data into RAM.
Furthermore, we will not be manipulating the training data on the fly using data augmentation.
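As a rough check of the first assumption, the sketch below estimates how much RAM a dataset needs when every image is held in memory as float32 values. The dataset_nbytes helper is hypothetical, and the sizes are assumptions. The first is Flowers-17 as described later in this post (17 classes × 80 images = 1,360 images at 64×64×3). The second is a made-up one-million-image dataset at 224×224×3, included for contrast.

```python
def dataset_nbytes(num_images, height, width, channels, bytes_per_value=4):
    """Bytes needed to hold every image at once (4 bytes per float32 value)."""
    return num_images * height * width * channels * bytes_per_value

# Flowers-17 at 64x64x3: comfortably fits in RAM
small = dataset_nbytes(1360, 64, 64, 3)
print(small, "bytes =", small / 2**20, "MiB")  # 66846720 bytes = 63.75 MiB

# hypothetical 1,000,000 images at 224x224x3: far beyond typical RAM
large = dataset_nbytes(1_000_000, 224, 224, 3)
print(round(large / 2**30, 1), "GiB")  # 560.8 GiB
```

The first case is what .fit is designed for. The second is where a generator that loads batches on demand becomes necessary.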
The Keras fit_generator function
For small, simplistic datasets it’s perfectly acceptable to use Keras’ .fit function.
These datasets are often not very challenging and do not require any data augmentation.
However, real-world datasets are rarely that simple:
- Real-world datasets are often too large to fit into memory.
- They also tend to be challenging, requiring us to perform data augmentation to avoid overfitting and increase the ability of our model to generalize.
In those situations we need to utilize Keras’ .fit_generator function:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# initialize the number of epochs and batch size
EPOCHS = 100
BS = 32

# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
    width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
    horizontal_flip=True, fill_mode="nearest")

# train the network
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
    validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
    epochs=EPOCHS)
2020-05-13 Update: With TensorFlow 2.2+ we now use .fit instead of .fit_generator. It works exactly the same way under the hood, accommodating data augmentation when the first argument provided is a Python generator object.
Here we start by first initializing the number of epochs we are going to train our network for along with the batch size.
We then initialize aug, a Keras ImageDataGenerator object that is used to apply data augmentation, randomly translating, rotating, resizing, etc. images on the fly.
Performing data augmentation is a form of regularization, enabling our model to generalize better.
However, applying data augmentation implies that our training data is no longer “static” — the data is constantly changing.
Each new batch of data is randomly adjusted according to the parameters supplied to ImageDataGenerator.
Thus, we now need to utilize Keras’ .fit_generator function to train our model.
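To make “the data is constantly changing” concrete, here is a toy stand-in for aug.flow written with NumPy. The random_flip_flow function is hypothetical and applies only one augmentation, a random horizontal flip. ImageDataGenerator also rotates, shifts, zooms, and shears. The point is simply that an infinite generator can hand back a differently transformed copy of the same image on every pass.

```python
import numpy as np

def random_flip_flow(X, y, batch_size, seed=None):
    """Hypothetical stand-in for ImageDataGenerator.flow (horizontal flips only).

    Loops forever, reshuffling on every pass, and mirrors each image
    left-to-right with probability 0.5, so the same image can come out
    differently every time it is drawn.
    """
    rng = np.random.default_rng(seed)
    n = len(X)
    while True:  # a Keras-style data generator never returns
        idxs = rng.permutation(n)
        for start in range(0, n, batch_size):
            batch_idxs = idxs[start:start + batch_size]
            # fancy indexing returns a copy, so X itself is never modified
            batchX = X[batch_idxs]
            flip = rng.random(len(batch_idxs)) < 0.5
            # batch layout is (batch, height, width, channels): reverse axis 2
            batchX[flip] = batchX[flip][:, :, ::-1, :]
            yield batchX, y[batch_idxs]

# toy data: two distinct 4x4 single-channel "images", labelled by index
X = np.arange(2 * 4 * 4 * 1, dtype="float32").reshape((2, 4, 4, 1))
y = np.array([0, 1])
gen = random_flip_flow(X, y, batch_size=2, seed=0)
batchX, batchY = next(gen)
print(batchX.shape, sorted(batchY.tolist()))  # (2, 4, 4, 1) [0, 1]
```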
As the name suggests, the .fit_generator function assumes there is an underlying function that is generating the data for it.
The function itself is a Python generator.
Internally, Keras is using the following process when training a model with .fit_generator:
- Keras calls the generator function supplied to .fit_generator (in this case, aug.flow).
- The generator function yields a batch of size BS to the .fit_generator function.
- The .fit_generator function accepts the batch of data, performs backpropagation, and updates the weights in our model.
- This process is repeated until we have reached the desired number of epochs.
You’ll notice we now need to supply a steps_per_epoch parameter when calling .fit_generator (our earlier .fit call on in-memory arrays needed no such parameter).
Why do we need steps_per_epoch?
Keep in mind that a Keras data generator is meant to loop infinitely; it should never return or exit.
Since the function is intended to loop infinitely, Keras has no way to determine when one epoch ends and the next begins.
Therefore, we compute the steps_per_epoch value as the total number of training data points divided by the batch size. Once Keras hits this step count, it knows that a new epoch has begun.
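The process above, together with the role of steps_per_epoch, can be sketched in plain Python. This is a simplified model of the idea, not Keras’s actual implementation, which also handles callbacks, queues, and validation. The names fit_from_generator, count_up, and fake_step are hypothetical and used only for illustration.

```python
import itertools

def fit_from_generator(train_step, generator, steps_per_epoch, epochs):
    """Simplified sketch of how a generator-driven fit loop consumes batches.

    train_step is a hypothetical callable standing in for "run backprop and
    update the weights on this batch"; it returns that batch's loss.
    """
    history = []
    for epoch in range(epochs):
        # The generator never ends on its own, so steps_per_epoch is the only
        # thing that marks where one epoch stops and the next one starts.
        losses = [train_step(batch)
                  for batch in itertools.islice(generator, steps_per_epoch)]
        history.append(sum(losses) / len(losses))
    return history

def count_up(bs):
    """Toy infinite generator yielding batches of consecutive integers."""
    i = 0
    while True:
        yield list(range(i, i + bs))
        i += bs

seen = []

def fake_step(batch):
    seen.append(batch)
    return sum(batch) / len(batch)  # pretend the "loss" is the batch mean

hist = fit_from_generator(fake_step, count_up(4), steps_per_epoch=3, epochs=2)
print(hist)  # [5.5, 17.5]
```

The second epoch starts at [12, 13, 14, 15] rather than back at 0, because nothing ever restarts the generator. It simply keeps producing data, and steps_per_epoch is what slices that stream into epochs.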
The Keras train_on_batch function
If you are looking for the finest-grained control over training your Keras models, you may wish to use the .train_on_batch function:
model.train_on_batch(batchX, batchY)
The train_on_batch function accepts a single batch of data, performs backpropagation, and then updates the model parameters.
The batch of data can be of arbitrary size (i.e., it does not require an explicit batch size to be provided).
The data itself can be generated however you like as well. This data could be raw images on disk or data that has been modified or augmented in some manner.
You’ll typically use the .train_on_batch function when you have very explicit reasons for wanting to maintain your own training data iterator, such as the data iteration process being extremely complex and requiring custom code.
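To illustrate what “maintaining your own training data iterator” means, here is a sketch of a hand-written epoch/batch loop. It assumes only that model exposes train_on_batch(x, y), which a compiled Keras model does. The RecordingModel class is a hypothetical stand-in so the example runs without TensorFlow, and the per-epoch shuffling is just one possible iteration strategy.

```python
import numpy as np

def train_manually(model, X, y, batch_size, epochs, seed=None):
    """Own the data iterator yourself and hand the model one batch at a time.

    `model` only needs a train_on_batch(x, y) method, as a compiled Keras
    model has; whatever that method returns is collected and returned.
    """
    rng = np.random.default_rng(seed)
    losses = []
    for epoch in range(epochs):
        idxs = rng.permutation(len(X))  # custom iteration logic lives here
        for start in range(0, len(X), batch_size):
            batch = idxs[start:start + batch_size]
            # one call = one forward pass, backprop, and weight update
            losses.append(model.train_on_batch(X[batch], y[batch]))
    return losses

class RecordingModel:
    """Hypothetical stand-in that only records the batch sizes it receives."""
    def __init__(self):
        self.batch_sizes = []

    def train_on_batch(self, x, y):
        self.batch_sizes.append(len(x))
        return 0.0

stub = RecordingModel()
losses = train_manually(stub, np.zeros((10, 3)), np.zeros(10),
                        batch_size=4, epochs=2, seed=0)
print(stub.batch_sizes)  # [4, 4, 2, 4, 4, 2]
```

With a real compiled Keras model you would pass it in place of stub, for example train_manually(model, trainX, trainY, batch_size=32, epochs=50). The trailing batch of 2 shows that train_on_batch does not require a fixed batch size.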
If you find yourself asking if you need the .train_on_batch function, then in all likelihood you probably don’t.
In the vast majority of situations you will not need such fine-grained control over training your deep learning models. Instead, a custom Keras data generator used with .fit_generator is likely all you need.
That said, it’s good to know that the function exists if you ever need it.
I typically only recommend using the .train_on_batch
function if you are an advanced deep learning practitioner/engineer, and you know exactly what you’re doing and why.
An image dataset…as a CSV file?
The dataset we will be using here today is the Flowers-17 dataset, a collection of 17 different flower species with 80 images per class.
Our goal will be to train a Keras Convolutional Neural Network to correctly classify each species of flowers.
However, there’s a bit of a twist to this project:
- Instead of working with the raw image files residing on disk…
- …I’ve serialized the entire image dataset to two CSV files (one for training, and one for evaluation).
To construct each CSV file I:
- Looped over all images in our input dataset
- Resized them to 64×64 pixels
- Flattened the 64×64×3 = 12,288 RGB pixel intensities into a single list
- Wrote the class label followed by the 12,288 pixel values to the CSV file (one image per line)
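The serialization steps above can be sketched as a pair of helpers. These are hypothetical and are not the author’s build_dataset.py, which is mentioned in the comments below. They assume uint8 pixel values and a label containing no commas. They put the label first because that is how the generator later in this post parses each line (label = line[0]).

```python
import numpy as np

def image_to_csv_row(image, label):
    """Serialize a 64x64x3 image as one CSV line: label, then 12,288 pixels."""
    pixels = image.astype("uint8").flatten()  # C-order (row-major) flatten
    return ",".join([label] + [str(p) for p in pixels.tolist()])

def csv_row_to_image(row):
    """Parse a row back, mirroring the parsing done inside the generator."""
    values = row.strip().split(",")
    image = np.array([int(x) for x in values[1:]], dtype="float32")
    return values[0], image.reshape((64, 64, 3))

# round-trip a random image to check that the layout survives
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)
row = image_to_csv_row(img, "daisy")
label, restored = csv_row_to_image(row)
print(label, restored.shape, len(row.split(",")))  # daisy (64, 64, 3) 12289
```

Because flatten and reshape both use row-major order, the round trip restores every pixel to its original position.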
Our goal is now to write a custom Keras generator to parse the CSV file and yield batches of images and labels to the .fit_generator function.
Wait, why bother with a CSV file if you already have the images?
Today’s tutorial is meant to be an example of how to implement your own Keras generator for the .fit_generator function.
In the real world, datasets are not nicely curated for you:
- You may have unstructured directories of images.
- You could be working with both images and text.
- Your images could be serialized in a particular format, whether that’s a CSV file, a Caffe or TensorFlow record file, etc.
In these situations, you will need to know how to write your own Keras generator functions.
Keep in mind that it’s not the particular data format that’s important here — it’s the actual process of writing your own Keras generator that you need to learn (and that’s exactly what’s covered in the rest of the tutorial).
Project structure
Let’s inspect the project tree for today’s example:
$ tree --dirsfirst
.
├── pyimagesearch
│   ├── __init__.py
│   └── minivggnet.py
├── flowers17_testing.csv
├── flowers17_training.csv
├── plot.png
└── train.py

1 directory, 6 files
Today we’ll be using the MiniVGGNet CNN. We won’t be covering the implementation here today as I’ll assume you already know how to implement a CNN. If not, no worries — just refer to my Keras tutorial.
Our serialized image dataset is contained within flowers17_training.csv and flowers17_testing.csv (included in the “Downloads” associated with today’s post).
We’ll be reviewing train.py, our training script, in the next two sections.
Implementing a custom Keras fit_generator function
Let’s go ahead and get started.
I’ll be assuming you have the following libraries installed on your system:
- NumPy
- TensorFlow + Keras
- Scikit-learn
- Matplotlib
Each of these packages can be installed via pip in your virtual environment. If you have virtualenvwrapper installed, you can create an environment with mkvirtualenv and activate your environment with the workon command. From there you can use pip to set up your environment:
$ mkvirtualenv cv -p python3
$ workon cv
$ pip install numpy
$ pip install tensorflow # or tensorflow-gpu
$ pip install keras
$ pip install scikit-learn
$ pip install matplotlib
Once your virtual environment is set up, you can proceed with writing the training script. Make sure you use the “Downloads” section of today’s post to grab the source code and Flowers-17 CSV image dataset.
Open up the train.py file and insert the following code:
# set the matplotlib backend so figures can be saved in the background
import matplotlib
matplotlib.use("Agg")

# import the necessary packages
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import SGD
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
from pyimagesearch.minivggnet import MiniVGGNet
import matplotlib.pyplot as plt
import numpy as np
Lines 2-12 import our required packages and modules. Since we’ll be saving our training plot to disk, Line 3 sets matplotlib’s backend appropriately.
Notable imports include ImageDataGenerator, which contains the data augmentation and image generator functionality, along with MiniVGGNet, our CNN that we will be training.
Let’s define the csv_image_generator function:
def csv_image_generator(inputPath, bs, lb, mode="train", aug=None):
    # open the CSV file for reading
    f = open(inputPath, "r")
On Line 14 we’ve defined the csv_image_generator. This function is responsible for reading our CSV data file and loading images into memory. It yields batches of data to our Keras .fit_generator function.
As such, the function accepts the following parameters:
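- inputPath: The path to the CSV dataset file.
- bs: The batch size. We’ll be using 32.
- lb: A label binarizer object which contains our class labels.
- mode: (default is "train") If and only if mode == "eval", the generator stops filling the current batch when it reaches the end of the file instead of wrapping around to the beginning, so the final evaluation batch is never padded with samples from the start of the file.
- aug: (default is None) If an augmentation object is specified, then we’ll apply it before we yield our images and labels. Data augmentation is controlled entirely by this argument, so pass aug=None whenever you do not want it.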
On Line 16 we’ll go ahead and open the CSV data file for reading.
Let’s begin looping over the lines of data:
    # loop indefinitely
    while True:
        # initialize our batches of images and labels
        images = []
        labels = []
Each line of data in the CSV file contains an image serialized as a text string. Again, I generated the text strings from the Flowers-17 dataset. Additionally, I know this isn’t the most efficient way to store an image, but it is great for the purposes of this example.
Our Keras generator must loop indefinitely, as defined on Line 19. The .fit_generator function will request a new batch from our csv_image_generator each time it needs one.
Furthermore, Keras maintains a cache/queue of data, ensuring the model we are training always has data to train on. Keras constantly keeps this queue full, so even once you have reached the total number of epochs to train for, keep in mind that Keras may still be pulling from the data generator to keep data in the queue.
Always make sure your function yields data; otherwise, Keras will error out saying it could not obtain more training data from your generator.
At each iteration of the loop, we’ll reinitialize our images and labels to empty lists (Lines 21 and 22).
From there, we’ll begin appending images and labels to these lists until we’ve reached our batch size:
        # keep looping until we reach our batch size
        while len(images) < bs:
            # attempt to read the next line of the CSV file
            line = f.readline()

            # check to see if the line is empty, indicating we have
            # reached the end of the file
            if line == "":
                # reset the file pointer to the beginning of the file
                # and re-read the line
                f.seek(0)
                line = f.readline()

                # if we are evaluating we should now break from our
                # loop to ensure we don't continue to fill up the
                # batch from samples at the beginning of the file
                if mode == "eval":
                    break

            # extract the label and construct the image
            line = line.strip().split(",")
            label = line[0]
            image = np.array([int(x) for x in line[1:]], dtype="float32")
            image = image.reshape((64, 64, 3))

            # update our corresponding batches lists
            images.append(image)
            labels.append(label)
Let’s walk through this loop:
- First, we read a line from our text file object, f (Line 27).
- If line is empty:
  - …we reset our file pointer and try to read a line (Lines 34 and 35).
  - And if we’re in evaluation mode, we go ahead and break from the loop (Lines 40 and 41).
- At this point, we’ll parse our image and label from the CSV file (Lines 44-46).
- We go ahead and call .reshape to reshape our 1D array into our image, which is 64×64 pixels with 3 color channels (Line 47).
- Finally, we append the image and label to their respective lists, repeating this process until our batch of images is full (Lines 50 and 51).
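To see the end-of-file handling in isolation, the sketch below copies just the batch-filling loop. It runs the loop on a tiny in-memory file whose rows are single letters instead of 12,289 values. The read_batch helper is hypothetical, and io.StringIO stands in for the CSV file. The wrap-around and eval-mode break mirror the code above.

```python
import io

def read_batch(f, bs, mode="train"):
    """Isolated copy of the generator's batch-filling loop, returning raw rows."""
    rows = []
    while len(rows) < bs:
        line = f.readline()
        if line == "":
            # end of file: rewind and read the first line again
            f.seek(0)
            line = f.readline()
            # in eval mode, stop so the batch is not padded with
            # samples from the beginning of the file
            if mode == "eval":
                break
        rows.append(line.strip())
    return rows

data = "a\nb\nc\nd\ne\n"
train_f, eval_f = io.StringIO(data), io.StringIO(data)
train_batches = [read_batch(train_f, 2) for _ in range(3)]
eval_batches = [read_batch(eval_f, 2, mode="eval") for _ in range(3)]
print(train_batches)  # [['a', 'b'], ['c', 'd'], ['e', 'a']]
print(eval_batches)   # [['a', 'b'], ['c', 'd'], ['e']]
```

In train mode the third batch wraps around and borrows a from the top of the file. In eval mode it stops short, which is why the final evaluation batch can be smaller than bs.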
Note: The key to making evaluation work here is that we supply the number of steps to model.predict (model.predict_generator in older versions), ensuring that each image in the testing set is predicted only once. I’ll be covering how to do this later in the tutorial.
With our batch of images and corresponding labels ready, we can now take two steps before yielding our batch:
        # one-hot encode the labels
        labels = lb.transform(np.array(labels))

        # if the data augmentation object is not None, apply it
        if aug is not None:
            (images, labels) = next(aug.flow(np.array(images),
                labels, batch_size=bs))

        # yield the batch to the calling function
        yield (np.array(images), labels)
Our final steps include:
- One-hot encoding labels (Line 54)
- Applying data augmentation if necessary (Lines 57-59)
Finally, our generator “yields” our array of images and our one-hot encoded labels to the calling function on request (Line 62). If you aren’t familiar with the yield keyword, it turns a function into a Python generator. A generator is a convenient alternative to writing an iterator class, and it produces values lazily instead of building them all in memory up front. You can read more about Python Generators here.
Let’s initialize our training parameters:
# initialize the paths to our training and testing CSV files
TRAIN_CSV = "flowers17_training.csv"
TEST_CSV = "flowers17_testing.csv"

# initialize the number of epochs to train for and batch size
NUM_EPOCHS = 75
BS = 32

# initialize the total number of training and testing images
NUM_TRAIN_IMAGES = 0
NUM_TEST_IMAGES = 0
A number of initializations are hardcoded in this example training script:
- Our training and testing CSV filepaths (Lines 65 and 66).
- The number of epochs and batch size for training (Lines 69 and 70).
- Two variables which will hold the number of training and testing images (Lines 73 and 74).
Let’s take a look at the next block of code:
# open the training CSV file, then initialize the unique set of class
# labels in the dataset along with the testing labels
f = open(TRAIN_CSV, "r")
labels = set()
testLabels = []

# loop over all rows of the CSV file
for line in f:
    # extract the class label, update the labels list, and increment
    # the total number of training images
    label = line.strip().split(",")[0]
    labels.add(label)
    NUM_TRAIN_IMAGES += 1

# close the training CSV file and open the testing CSV file
f.close()
f = open(TEST_CSV, "r")

# loop over the lines in the testing file
for line in f:
    # extract the class label, update the test labels list, and
    # increment the total number of testing images
    label = line.strip().split(",")[0]
    testLabels.append(label)
    NUM_TEST_IMAGES += 1

# close the testing CSV file
f.close()
This block of code is long, but it has three purposes:
- Extract all labels from our training dataset so that we can subsequently determine unique labels. Notice that labels is a set, which only allows unique entries.
- Assemble a list of testLabels.
- Count the NUM_TRAIN_IMAGES and NUM_TEST_IMAGES.
Let’s build our LabelBinarizer object and construct the data augmentation object:
# create the label binarizer for one-hot encoding labels, then encode
# the testing labels
lb = LabelBinarizer()
lb.fit(list(labels))
testLabels = lb.transform(testLabels)

# construct the training image generator for data augmentation
aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,
    width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,
    horizontal_flip=True, fill_mode="nearest")
Using the unique labels, we’ll .fit our LabelBinarizer object (Lines 107 and 108).
We’ll also go ahead and transform our testLabels into binary one-hot encoded testLabels (Line 109).
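If you want to see what the one-hot encoding produces, here is a minimal NumPy sketch in the spirit of LabelBinarizer.fit followed by transform. It is a hand-rolled equivalent for illustration, not scikit-learn’s implementation. It differs in one known way: for exactly two classes, LabelBinarizer returns a single column rather than two.

```python
import numpy as np

def one_hot(labels):
    """Minimal NumPy sketch of multiclass one-hot encoding.

    Mirrors what LabelBinarizer.fit + transform produce for 3+ classes:
    classes are sorted, and each label becomes a row with a single 1.
    """
    classes = sorted(set(labels))
    index = {c: i for i, c in enumerate(classes)}
    encoded = np.zeros((len(labels), len(classes)), dtype="int64")
    for row, label in enumerate(labels):
        encoded[row, index[label]] = 1
    return classes, encoded

classes, Y = one_hot(["tulip", "daisy", "iris", "daisy"])
print(classes)                    # ['daisy', 'iris', 'tulip']
print(Y.argmax(axis=1).tolist())  # [2, 0, 1, 0]
```

Because the classes are sorted, calling argmax(axis=1) on an encoded row recovers the index into classes. The evaluation code relies on the same trick when it passes testLabels.argmax(axis=1) and lb.classes_ to classification_report.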
From there, we’ll construct aug, an ImageDataGenerator (Lines 112-114). Our image data augmentation object will randomly rotate, flip, shear, etc. our training images.
Now let’s initialize our training and testing image generators:
# initialize both the training and testing image generators
trainGen = csv_image_generator(TRAIN_CSV, BS, lb,
    mode="train", aug=aug)
testGen = csv_image_generator(TEST_CSV, BS, lb,
    mode="train", aug=None)
Our trainGen and testGen generator objects generate image data from their respective CSV files using csv_image_generator (Lines 117-120).
Notice the subtle similarities and differences:
- We’re using mode="train" for both generators
- Only trainGen will perform data augmentation
Let’s initialize + compile our MiniVGGNet model with Keras and begin training:
# initialize our Keras model and compile it
model = MiniVGGNet.build(64, 64, 3, len(lb.classes_))
opt = SGD(learning_rate=1e-2, momentum=0.9, decay=1e-2 / NUM_EPOCHS)
model.compile(loss="categorical_crossentropy", optimizer=opt,
    metrics=["accuracy"])

# train the network
print("[INFO] training w/ generator...")
H = model.fit(
    x=trainGen,
    steps_per_epoch=NUM_TRAIN_IMAGES // BS,
    validation_data=testGen,
    validation_steps=NUM_TEST_IMAGES // BS,
    epochs=NUM_EPOCHS)
2020-05-13 Update: This blog post is now TensorFlow 2+ compatible! We no longer use the .fit_generator function; instead, we use the .fit method.
Lines 123-126 compile our model. We’re using a Stochastic Gradient Descent optimizer with a hardcoded initial learning rate of 1e-2. The decay argument applies time-based learning rate decay after every batch update: Keras scales the rate by 1 / (1 + decay × iterations). Categorical crossentropy is used since we have more than 2 classes (binary crossentropy would be used otherwise). Be sure to refer to my Keras tutorial for additional reading.
On Lines 130-135 we call .fit to start training.
The trainGen generator object is responsible for yielding batches of data and labels to the .fit function.
Notice how we compute the steps per epoch and validation steps based on the number of images and the batch size. It’s paramount that we supply the steps_per_epoch value; otherwise, Keras will not know when one epoch ends and the next one begins. The same reasoning applies to validation_steps, since testGen also loops forever.
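Here is the step arithmetic worked through with concrete numbers, assuming 1,020 training and 340 testing rows. Neither count is stated in the post. The 340 is the support total of the classification report further down, and 1,020 is what remains of Flowers-17’s 1,360 images. Together they reproduce the “Train for 31 steps, validate for 10 steps” line of the training log. The steps_for helper is hypothetical.

```python
def steps_for(num_samples, batch_size):
    """Full batches per epoch, computed with floor division as in train.py."""
    return num_samples // batch_size

# assumed split (see above); BS matches the training script
NUM_TRAIN_IMAGES, NUM_TEST_IMAGES, BS = 1020, 340, 32
train_steps = steps_for(NUM_TRAIN_IMAGES, BS)
val_steps = steps_for(NUM_TEST_IMAGES, BS)
leftover = NUM_TRAIN_IMAGES - train_steps * BS
print(train_steps, val_steps, leftover)  # 31 10 28
```

Because the generator is never restarted between epochs, the 28 leftover images are not discarded. They simply open the next epoch’s stream. Validation, on the other hand, sees 10 × 32 = 320 of the 340 testing images per epoch.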
Now let’s evaluate the results of training:
# re-initialize our testing data generator, this time for evaluating
testGen = csv_image_generator(TEST_CSV, BS, lb,
    mode="eval", aug=None)

# make predictions on the testing images, finding the index of the
# label with the corresponding largest predicted probability
predIdxs = model.predict(x=testGen,
    steps=(NUM_TEST_IMAGES // BS) + 1)
predIdxs = np.argmax(predIdxs, axis=1)

# show a nicely formatted classification report
print("[INFO] evaluating network...")
print(classification_report(testLabels.argmax(axis=1), predIdxs,
    target_names=lb.classes_))
We go ahead and re-initialize our testGen, this time changing the mode to "eval" for evaluation purposes.
After re-initialization, we make predictions using our .predict function and our testGen (Lines 143 and 144). At the end of this process, we’ll proceed to grab the max prediction indices (Line 145).
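The (NUM_TEST_IMAGES // BS) + 1 expression is easiest to follow with numbers. The sketch assumes 340 testing rows, which is the support total in the classification report further down, and a batch size of 32. The eval_batch_sizes helper is hypothetical; it models the eval-mode generator rather than calling it.

```python
import math

def eval_batch_sizes(num_samples, batch_size):
    """Batch sizes an eval-mode generator yields when predict() is given
    (num_samples // batch_size) + 1 steps, as in train.py."""
    steps = num_samples // batch_size + 1
    sizes = [batch_size] * (num_samples // batch_size)
    # eval mode stops at end-of-file instead of wrapping, so the last step
    # only receives whatever rows are left over
    sizes.append(num_samples - sum(sizes))
    return steps, sizes

steps, sizes = eval_batch_sizes(340, 32)
print(steps, sizes[-1], sum(sizes))  # 11 20 340

# if the count were an exact multiple of the batch size, the extra step
# would have nothing left to read
print(eval_batch_sizes(320, 32)[1][-1])  # 0
print(math.ceil(340 / 32), math.ceil(320 / 32))  # 11 10
```

Ten full batches plus a final batch of 20 give exactly 340 predictions, matching testLabels. Suppose instead the number of testing rows were an exact multiple of the batch size. Then the extra step would find nothing left to read, and the generator would yield an empty batch. In that case math.ceil(NUM_TEST_IMAGES / BS) may be a safer way to compute steps.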
Using the testLabels and predIdxs, we’ll generate a classification_report via scikit-learn (Lines 149 and 150). The classification report is printed nicely to our terminal for inspection at the end of training and evaluation.
As a final step, we’ll use our training history object, H (specifically its H.history dictionary), to generate a plot with matplotlib:
# plot the training loss and accuracy
N = NUM_EPOCHS
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, N), H.history["accuracy"], label="train_acc")
plt.plot(np.arange(0, N), H.history["val_accuracy"], label="val_acc")
plt.title("Training Loss and Accuracy on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig("plot.png")
The accuracy/loss plot is generated and saved to disk as plot.png for inspection upon script exit.
Training a Keras model using fit_generator and evaluating with predict_generator
2020-06-03 Update: Despite the heading to this section, we now use .fit (sans .fit_generator) and .predict (sans .predict_generator).
To train our Keras model using our custom data generator, make sure you use the “Downloads” section to download the source code and example CSV image dataset.
From there, open a terminal, navigate to where you downloaded the source code + dataset, and execute the following command:
$ python train.py
Using TensorFlow backend.
[INFO] training w/ generator...
Train for 31 steps, validate for 10 steps
Epoch 1/75
31/31 [==============================] - 10s 317ms/step - loss: 3.6791 - accuracy: 0.1401 - val_loss: 1828.6441 - val_accuracy: 0.0625
Epoch 2/75
31/31 [==============================] - 9s 287ms/step - loss: 3.0351 - accuracy: 0.2077 - val_loss: 246.5172 - val_accuracy: 0.0938
Epoch 3/75
31/31 [==============================] - 9s 288ms/step - loss: 2.8571 - accuracy: 0.2621 - val_loss: 92.0763 - val_accuracy: 0.0750
...
31/31 [==============================] - 9s 287ms/step - loss: 0.4484 - accuracy: 0.8548 - val_loss: 1.3388 - val_accuracy: 0.6531
Epoch 73/75
31/31 [==============================] - 9s 287ms/step - loss: 0.4025 - accuracy: 0.8619 - val_loss: 1.1642 - val_accuracy: 0.7125
Epoch 74/75
31/31 [==============================] - 9s 287ms/step - loss: 0.3401 - accuracy: 0.8720 - val_loss: 1.2229 - val_accuracy: 0.7188
Epoch 75/75
31/31 [==============================] - 9s 287ms/step - loss: 0.3605 - accuracy: 0.8780 - val_loss: 1.2207 - val_accuracy: 0.7063
[INFO] evaluating network...
              precision    recall  f1-score   support

    bluebell       0.63      0.81      0.71        21
   buttercup       0.69      0.73      0.71        15
   coltsfoot       0.55      0.76      0.64        21
     cowslip       0.73      0.40      0.52        20
      crocus       0.53      0.88      0.66        24
    daffodil       0.82      0.33      0.47        27
       daisy       0.77      0.94      0.85        18
   dandelion       0.71      0.83      0.77        18
  fritillary       1.00      0.77      0.87        22
        iris       0.95      0.75      0.84        24
  lilyvalley       0.92      0.55      0.69        22
       pansy       0.89      0.89      0.89        18
    snowdrop       0.69      0.50      0.58        22
   sunflower       0.90      1.00      0.95        18
   tigerlily       0.87      0.93      0.90        14
       tulip       0.33      0.50      0.40        16
  windflower       0.81      0.85      0.83        20

    accuracy                           0.72       340
   macro avg       0.75      0.73      0.72       340
weighted avg       0.76      0.72      0.71       340
Here you can see that our network has obtained 72% accuracy on the evaluation set, which is quite respectable for the relatively shallow CNN used.
Most importantly, you learned how to utilize:
- Data generators
- .fit (formerly .fit_generator)
- .predict (formerly .predict_generator)
…all to train and evaluate your own custom Keras model!
Again, it’s not the actual format of the data itself that’s important here. Instead of CSV files, we could have been working with Caffe or TensorFlow record files, a combination of numerical/categorical data along with images, or any other synthesis of data that you may encounter in the real world.
Rather, it’s the actual process of implementing your own Keras data generator that matters here.
Follow the steps in this tutorial and you’ll have a blueprint that you can use for implementing your own Keras data generators.
If you’re using TensorFlow 2.2+, just use “.fit”; there’s no reason to use “.fit_generator”
You can check your TensorFlow version by using pip freeze and then looking for your TensorFlow version:
$ pip freeze | grep 'tensorflow'
tensorflow==2.4.1
tensorflow-estimator==2.4.0
Provided you are using TensorFlow 2.2 or greater, you should just be using the .fit method.
TensorFlow is deprecating the .fit_generator method and will remove it in a future release, as the .fit method can automatically detect whether the input data is an array or a generator.
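If you would rather check from Python than with pip freeze, a small helper like the hypothetical fit_accepts_generators below can compare a version string against the 2.2 threshold used in this post. It assumes the first two dot-separated parts of the version string are plain integers, so patch-level suffixes such as 2.2.0rc1 are fine.

```python
def fit_accepts_generators(version):
    """Return True if a TensorFlow version string is 2.2 or newer, the
    threshold this post uses for switching from .fit_generator to .fit."""
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) >= (2, 2)

print(fit_accepts_generators("2.4.1"))  # True
print(fit_accepts_generators("2.1.0"))  # False
```

In a real environment you would pass tf.__version__ after import tensorflow as tf.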
Summary
In this tutorial you learned the differences between Keras’ three primary functions used to train a deep neural network:
- .fit: Used when the entire training dataset can fit into memory and no data augmentation is applied. As of TensorFlow 2.2, this method also accepts generators and therefore supports data augmentation.
- .fit_generator: For legacy code using versions of TensorFlow/Keras prior to 2.2. Should be used when either (1) the dataset is too large to fit into memory, (2) data augmentation needs to be applied, or (3) in any situation when it’s more convenient to yield training data in batches (e.g., using the flow_from_directory function).
- .train_on_batch: Can be used to train a Keras model on a single batch of data. Should be utilized only when you need the finest-grained control over training your network, such as in situations where your data iterator is highly complex.
From there, we discovered how to:
- Implement our own custom Keras generator function
- Use our custom generator along with Keras’ .fit_generator (or .fit in TensorFlow 2.2+) to train our deep neural network
You can use today’s example code as a template when implementing your own Keras generators in your own projects.
I hope you enjoyed today’s blog post!
Sivarama Krishnan Rajaraman
Hi Adrian,
Thanks for this wonderful post. We work with images all the time. However, there is no clear information online on how to serialize an image dataset, along with its labels, to CSV files. I am sure many enthusiastic readers of your blog would love to see this kind of post. Looking forward to it.
Best,
Shiva
Adrian Rosebrock
The general algorithm is actually quite simple:
1. Loop over all images in your dataset
2. Load image
3. Resize to fixed dimensions (or embed the dimensions as the first entries for the row)
4. Flatten the image to a list of pixels
5. Write the label, flattened list, and any other metadata (such as dimension info) to the CSV file
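The five steps above can be sketched in Python. This is a minimal illustration, not the build_dataset.py script from the downloads: the helper names (resize_nearest, write_images_to_csv, read_row), the 64x64x3 target size, and the synthetic images are assumptions for the example, and a simple NumPy nearest-neighbor resize stands in for loading and resizing with a library such as OpenCV.

```python
import csv
import os
import tempfile

import numpy as np


def resize_nearest(image, width, height):
    # Nearest-neighbor resize with NumPy indexing (a stand-in for e.g. cv2.resize)
    rows = (np.arange(height) * image.shape[0] / height).astype(int)
    cols = (np.arange(width) * image.shape[1] / width).astype(int)
    return image[rows][:, cols]


def write_images_to_csv(samples, csv_path, width=64, height=64):
    """Hypothetical helper: samples is an iterable of (label, HxWxC uint8 array)."""
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        for label, image in samples:  # step 1: loop over the (already loaded) images
            image = resize_nearest(image, width, height)  # step 3: fixed dimensions
            pixels = image.flatten().tolist()  # step 4: flatten to a list of pixels
            writer.writerow([label] + pixels)  # step 5: label first, then the pixels


def read_row(row, width=64, height=64, channels=3):
    # Inverse operation: the first entry is the label, the rest are pixel intensities
    image = np.array([int(x) for x in row[1:]], dtype="uint8")
    return row[0], image.reshape((height, width, channels))


# Usage with two synthetic "images" of different sizes (stand-ins for step 2, loading)
rng = np.random.default_rng(0)
samples = [
    ("daisy", rng.integers(0, 256, size=(120, 90, 3), dtype=np.uint8)),
    ("tulip", rng.integers(0, 256, size=(50, 70, 3), dtype=np.uint8)),
]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "flowers.csv")
    write_images_to_csv(samples, path)
    with open(path, newline="") as f:
        decoded = [read_row(row) for row in csv.reader(f)]

print([(label, img.shape) for label, img in decoded])
```

Because every row has the layout label,pixel_0,pixel_1,…, reading it back is just the reverse: split on commas, take the first entry as the label, and reshape the rest to the fixed dimensions.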
Xu Zhang
If you could show some code, it would help a lot. Thank you so much for your great post
Sivarama Krishnan Rajaraman
I agree with Zhang’s request. If you can point us to some reliable code for the process, it would be very helpful.
Adrian Rosebrock
I’ve uploaded the .zip associated with this post (available via the “Downloads” section) to include my build_dataset.py file, which can be used to create a CSV file of images. Enjoy!
Safi
Hi Sir,
I’ve downloaded the code and tried to use build_dataset.py to convert some images into CSV files, but I’m stuck on the argument parsing. In the source code you provided, do I need to input the values through the ap.add_argument lines, or how? When I try to run it, I get this error: “error: the following arguments are required: -d/--dataset”.
I don’t quite understand the concept of argument parsing here.
Thanks Sir.
Adrian Rosebrock
It’s okay if you are new to command line arguments, but make sure you read this tutorial on argparse first. From there you will have the knowledge you need to continue.
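The error in Safi’s comment means the script expects the dataset path on the command line rather than hard-coded in the source. A minimal argparse sketch of that pattern follows; only the -d/--dataset flag comes from the error message, while the description text and the example path are placeholders, so the real build_dataset.py may define additional arguments.

```python
import argparse


def parse_args(argv=None):
    # argv=None means "read sys.argv"; passing a list lets us simulate a command line
    ap = argparse.ArgumentParser(description="Serialize an image dataset to CSV (sketch)")
    ap.add_argument("-d", "--dataset", required=True,
                    help="path to the input image dataset")
    # vars() turns the Namespace into a dict, so values are read as args["dataset"]
    return vars(ap.parse_args(argv))


# Equivalent to running:  python build_dataset.py --dataset path/to/flowers17
args = parse_args(["--dataset", "path/to/flowers17"])
print(args["dataset"])
```

Running the script without the --dataset flag makes argparse print the usage message and exit, which produces the “the following arguments are required” error quoted in the comment.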
Safi
Thanks so much Adrian Rosebrock, the tutorial on argparse is so helpful; I was able to figure it out after reading it.
Thanks so much. Keep up the good work. You’re amazing and talented.
Adrian Rosebrock
Thanks Safi, I’m glad it helped you 🙂
Sanjeevi
Actually, data augmentation is used to produce more data by rotating and shifting images. Data augmentation is used when our dataset is small, right?
Adrian Rosebrock
Your understanding of data augmentation is slightly incorrect. See my reply to Sagar.
Xu Zhang
In Francois Chollet’s book “Deep Learning with Python”, on page 139, he wrote: “Data augmentation takes the approach of generating more training data from existing training samples, … The goal is that at training time, your model will never see the exact same picture twice. …”
Could you explain your view?
Adrian Rosebrock
You’re not understanding Francois’ explanation. He is saying that data augmentation takes the original training data and then modifies it on the fly via random perturbations. Data augmentation is not an additive operation, meaning that the network is NOT trained on the original data + augmented data. Instead, it’s trained on data that is augmented, on the fly, from the original training data.
I would strongly encourage you, or anyone else who has this same question, to read through Deep Learning for Computer Vision with Python where I discuss data augmentation and how it works in more detail.
Martin
Hello, Adrian,
There are augmentation tools out there, such as Augmentor, that create a bunch of extra images while still keeping the original images. But in that case, you first generate the data, save the images, and create a matching CSV file, so, if I read this correctly, the point about the number of images growing is also correct. But you’re right about the “on-the-fly” method, which you use here: the dataset doesn’t get bigger. I think that’s where the questioner stumbled.
Best regards
Martin
Adrian Rosebrock
You should take a look at my dedicated tutorial on data augmentation coming out in a few weeks 🙂 Keep an eye on the PyImageSearch blog.
Tom
Thanks for the post. Could you post an example of classifying a moving object in video, such as whether a person is walking or has fallen to the ground?
Adrian Rosebrock
I think what you are referring to is called “human activity recognition”. I don’t have any tutorials on human activity recognition but I will consider it for the future.
Ravi
Thanks for the post Adrian. Excuse me for posting a slightly off-topic question.
We can train a model with the Keras wrapper over TF and save the model in H5 format by following your instructions above. Is there a way to export the model to ckpt files? What changes do we need to make in the code when saving? Is it possible to export a TF MetaGraph directly from Keras?
Kim, Eun-ho
Dear Adrian,
Thank you for this very useful article.
With data augmentation, the total number of training data points per epoch is steps_per_epoch (len(trainX) // BS) multiplied by the batch size (BS); therefore, no data augmentation is occurring.
And you have said that the proper number of training data points per class is 1000 ~ 5000. So the total number of training data points per epoch should be the number of classes multiplied by (1000 ~ 5000).
I think that steps_per_epoch should be:
class_number x (1000 ~ 5000) // batch_size
Adrian Rosebrock
No, the number of steps per epoch is the total number of training examples divided by the batch size. Data augmentation is applied internally inside the data generator. I’m not sure where the multiplication comment is coming from, so perhaps you can clarify it, but my intuition is that you have a misunderstanding of how data augmentation actually works. Make sure you see my reply to “Sagar”.
Kim, Eun-ho
I think that, with data augmentation, the total number of training examples per epoch is not the number of training data points but the number of classes times (1000 ~ 5000).
Adrian Rosebrock
No, that is incorrect. The number of training examples per epoch with data augmentation is the number of total training data points. Applying data augmentation does not add more data to your training set, it simply augments it by randomly perturbing each and every data point with some transformation.
Again, your understanding of data augmentation inside of Keras is incorrect. Data augmentation is not “additive” — data augmentation replaces the original training set with randomly perturbed examples.
Kim, Eun-ho
Yes, you are right. But I think it is possible to increase the total number of training examples per epoch by changing the steps_per_epoch argument of the fit_generator method.
Accordingly, I think that NUM_TRAIN_IMAGES in steps_per_epoch should not be the number of training data points but the number of classes times (1000 ~ 5000).
Adrian Rosebrock
You can arbitrarily increase the number of batches or images per epoch, yes. But there’s no reason to do that.
As for your second remark, no, that is 100% false. The number of training steps per epoch is the total number of training images divided by the batch size. The total number of class labels has absolutely nothing to do with the batch size.
I would highly encourage you to read through Deep Learning for Computer Vision with Python. The book will help you understand the fundamentals and clear up any confusion you have surrounding batch sizes and steps per epoch. Be sure to take a look.
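As a worked example of the rule in the reply above: steps per epoch depends only on the number of training images and the batch size, and the class count never enters the calculation. The figures below (1,020 images, batch size 32) and the steps_per_epoch helper are hypothetical. The sketch also shows a ceil-based variant, which adds one extra step for the final, smaller batch.

```python
import math


def steps_per_epoch(num_train_images, batch_size, include_partial_batch=False):
    if include_partial_batch:
        # Round up: one extra step picks up the final, smaller batch
        return math.ceil(num_train_images / batch_size)
    # Round down (NUM_TRAIN_IMAGES // BS): leftover images do not form a full step
    return num_train_images // batch_size


NUM_TRAIN_IMAGES, BS = 1020, 32  # hypothetical dataset size and batch size

floor_steps = steps_per_epoch(NUM_TRAIN_IMAGES, BS)
ceil_steps = steps_per_epoch(NUM_TRAIN_IMAGES, BS, include_partial_batch=True)
images_per_epoch = floor_steps * BS  # images drawn from the generator per epoch

print(floor_steps, ceil_steps, images_per_epoch)  # 31 32 992
```

With a generator that keeps reading from where it left off between epochs, the 28 leftover rows under the floor version are not lost; they are simply read at the start of the next epoch.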
Atul
Hi Adrian,
I closely follow your tutorials; thanks for this one.
I have one question. Above, you showed how to train on custom data in Keras, but as you know, Keras ships with a few models like VGG16, ResNet50, etc. Is there any way to fine-tune these models? I want to add a few more classes to an existing Keras model: it has 1000 classes and I want to add 10 more to the same model.
Adrian Rosebrock
Yes, absolutely. I cover transfer learning, feature extraction, and fine-tuning in detail inside my book, Deep Learning for Computer Vision with Python. I would suggest starting there.
Sagar Rathod
Good explanation, Adrian!
I have one question here:
These lines of code in the csv_image_generator function are going to modify all images in the current batch if an augmentation object is supplied, right? If so, that means the model is never going to see the original images in the dataset.
# if the data augmentation object is not None, apply it
if aug is not None:
    (images, labels) = next(aug.flow(np.array(images),
        labels, batch_size=bs))
I thought data augmentation meant augmenting the training set by adding additional images, in particular to increase the size of the training set.
Adrian Rosebrock
When applying data augmentation, the goal is to purposely apply it to each and every batch of images, implying that each image is randomly transformed in some way. The goal is not to replace the dataset; it’s to randomly modify each image. You can think of data augmentation as applying a set of transformations with probability “p”. The goal is not to enlarge the dataset; it’s simply to augment it on the fly.
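To make the “probability p” idea concrete, here is a small NumPy sketch. It is not the Keras aug.flow call used in the post; a random horizontal flip stands in for the full set of transformations, and augment_batch and augmented_epoch are hypothetical names. Each batch comes back with the same number of images it went in with, so over one pass the model sees exactly as many images as the dataset contains; only their content is perturbed.

```python
import numpy as np


def augment_batch(images, rng, p=0.5):
    # Same-size output: each image is flipped left-right with probability p
    out = images.copy()
    for i in range(len(images)):
        if rng.random() < p:
            out[i] = images[i][:, ::-1]  # images are (N, H, W, C); flip the W axis
    return out


def augmented_epoch(images, batch_size, rng, p=0.5):
    # One pass over the data: batches are perturbed on the fly, never appended to
    for start in range(0, len(images), batch_size):
        yield augment_batch(images[start:start + batch_size], rng, p)


rng = np.random.default_rng(7)
data = rng.integers(0, 256, size=(10, 4, 4, 3), dtype=np.uint8)
batch_sizes = [len(batch) for batch in augmented_epoch(data, 4, rng)]
print(batch_sizes, sum(batch_sizes))  # [4, 4, 2] 10
```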
Rajesh Agrawal
In the case of data augmentation, will the batch size remain the same? Or will it be the batch size of training images plus the augmented images?
Adrian Rosebrock
The batch size will remain the same. Data augmentation does not add new images to the training set, it just augments the existing ones on the fly. I would suggest you read Deep Learning for Computer Vision with Python so you can learn more about data augmentation and how it works.
Xu Zhang
If the purpose of data augmentation is not to enlarge the dataset, how can data augmentation reduce overfitting? What are the mechanisms?
If I use this technique, I can generate more images from the original ones and save them into the dataset, then load the rebuilt dataset and train the model. Is there a difference between doing data augmentation in the code and training on a dataset enlarged with the same augmentation technique? Thanks
Adrian Rosebrock
Kindly refer to my reply to your other comment.
Xu Zhang
Thank you for your tutorial. Why did you convert the image files into .csv files? Thanks a lot.
Adrian Rosebrock
Take a look at the “Wait, why bother with a CSV file if you already have the images?” section.
Kleyson Rios
Hi Adrian,
A few days before you posted this very nice blog post, I had been playing with Keras generators, and after validating my code I noticed some strange behavior.
The first is steps_per_epoch and validation_steps. Doing it the way you did, which is correct according to the Keras documentation, might not feed the model the full dataset as expected. See the discussion in this thread: https://github.com/keras-team/keras/issues/11877
The correct way should be: math.ceil(NUM_TRAIN_IMAGES / BS)
The second is about .fit_generator itself; please take a look at this thread to better understand the issue: https://github.com/keras-team/keras/issues/11878
Still regarding the second issue, I would like to see a new blog post 🙂 using a Sequence instead of a generator, as suggested by a member in the respective thread.
Best Regards.
Kleyson Rios.
Adrian Rosebrock
Regarding the first issue, that’s normally an implementation-specific choice by the DL engineer: whether or not to pass the final, non-full-size batch through the model. I wouldn’t call that an “issue”, just a matter of preference.
As for the sequence vs. generator question, I’ve never run into that before. I’ll have to take a look.
Ivan Donadello
Hi Adrian,
First of all, thank you very much for all your posts. PyImageSearch is a very precious and useful resource for researchers, workers, and computer vision lovers.
Regarding this post, do you have any hints or tutorials for writing our own generators with data augmentation?
Thanks a lot
Ivan
Adrian Rosebrock
Thanks Ivan.
As for your question, this tutorial actually shows how you can apply data augmentation within the generator, so perhaps I’m not understanding your question properly?
JP Cassar
Thanks Adrian for the post,
I was wondering if you could add an example of classification (classify.py) using the MiniVGGNet model trained in this post and images from Flowers-17.
Adrian Rosebrock
You mean the actual images themselves and not the serialized images? If so, yes, I cover that topic inside Deep Learning for Computer Vision with Python.
Mike
Hi Adrian,
Shouldn’t the mode of testGen also be set to ‘eval’ when training?
Davy Jones
No, eval mode stops generating data when you reach the end of the file (for predicting after training is complete).
Mike
And another question:
Why do you reset the file pointer to the beginning of the file once the end of the file is reached? I think this will never happen during training, since you set the number of steps per epoch to the number of examples divided by the batch size.
Rajesh Agrawal
I am a bit confused about model.fit. When we specify batch_size, how can you say it fits the whole dataset in RAM? Or do you mean the weights and biases of the model?
Adrian Rosebrock
The call to “model.fit” assumes your entire dataset is in RAM. The “.fit” method does not use a data generator so the entire dataset must be loaded into RAM before calling it.
Rajesh
Thanks Adrian for clearing up my concept. I get your point: fit needs the training data to be readily available in the code before calling it. It makes perfect sense. Thank you!
rabah
There’s something I really don’t understand.
You train the model with an output of size len(lb.classes_). So that will depend on the batch size, right? Since you add a label in that loop.
But what if you have X labels and your loop length is only X/2 ?
Adrian Rosebrock
No, the “lb” is “fit” on all the input “labels”. It has nothing to do with the batch size. Go back and review the code again.
Tiago Carvalho
Hi Adrian
First of all, thanks a lot for your blog. It is very useful for following advances in the CV, ML, and image processing fields when working with Python, OpenCV, and deep learning frameworks.
My question is about performance with the fit() and fit_generator() methods.
Currently I’m trying to reproduce the results I got with fit(), but now using fit_generator(). However, even with the same parameters and inputs, my results with fit_generator() are much worse than my results with plain fit(). I believe this is because of the way fit() splits the input data into training batches, but I’m not completely sure.
Do you have any guesses? Is there some way you know of to obtain exactly the same results?
Thank you so much in advance.
Bests
Adrian Rosebrock
I would go back and double-check your code. Make sure you are using the same hyperparameters between the two examples. It could also be the case that you have a bug in your generator function causing incorrect data + corresponding labels to be generated.
donto
Hi @adrian,
I’m somewhat confused by the ‘steps_per_epoch’ parameter. You wrote:
“Since the function is intended to loop infinitely, Keras has no ability to determine when one epoch starts and a new epoch begins.”
In my case, I use a custom generator (https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly) to generate my data, and I simply set how many epochs I need. So, what is the relationship between ‘epoch’ and ‘steps_per_epoch’?
Adrian Rosebrock
You can absolutely set the number of epochs you want your network to train for. However, if you are using a data generator, you also need to supply the number of steps per epoch. The steps per epoch is the total number of training images divided by your batch size.
YaoHui
Why doesn’t Keras calculate steps_per_epoch internally? Can’t Keras itself get the total number of samples and the batch size? Why doesn’t it calculate the value itself?
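A plain Python generator, like the one in this post, loops forever and has no length, so Keras cannot work out where an epoch ends; that is why steps_per_epoch has to be supplied. Subclasses of tf.keras.utils.Sequence implement __len__ (the number of batches) and __getitem__ (one batch by index), and when a Sequence is passed to fit, Keras can take the number of steps from len(). The sketch below only mimics that interface so it runs without TensorFlow; ArrayBatchSequence and infinite_batches are hypothetical names, and in real code you would subclass tf.keras.utils.Sequence.

```python
import math

import numpy as np


class ArrayBatchSequence:
    """Mimics the tf.keras.utils.Sequence interface (__len__ and __getitem__)."""

    def __init__(self, images, labels, batch_size):
        self.images = images
        self.labels = labels
        self.batch_size = batch_size

    def __len__(self):
        # Batches per epoch, including a final partial batch
        return math.ceil(len(self.images) / self.batch_size)

    def __getitem__(self, idx):
        start = idx * self.batch_size
        end = start + self.batch_size
        return self.images[start:end], self.labels[start:end]


def infinite_batches(images, labels, batch_size):
    # Generator style used with fit_generator: it never ends, so it has no length
    while True:
        for start in range(0, len(images), batch_size):
            yield images[start:start + batch_size], labels[start:start + batch_size]


images, labels = np.zeros((100, 8)), np.arange(100)
seq = ArrayBatchSequence(images, labels, batch_size=32)
gen = infinite_batches(images, labels, batch_size=32)
print(len(seq), seq[3][0].shape)  # 4 (4, 8)
print(hasattr(gen, "__len__"))    # False
```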
Martin
Hej Adrian,
Thanks for a nice tutorial! I wonder if you have any suggestions for how to shuffle the data when using fit_generator?
Best regards
Martin
Adrian Rosebrock
There are a few things you could do:
1. Pre-shuffle the data
2. At each epoch, pick a random index into your data and then start generating your batches from there
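Both options can be sketched with NumPy index arrays. preshuffle and random_start_epoch are hypothetical helpers: the first shuffles once (for example, before writing the CSV) while keeping each image paired with its label, and the second picks a random starting row each epoch and reads sequentially, wrapping around the end of the data.

```python
import numpy as np


def preshuffle(images, labels, rng):
    # Option 1: one permutation applied to both arrays keeps (image, label) pairs aligned
    order = rng.permutation(len(images))
    return images[order], labels[order]


def random_start_epoch(num_samples, batch_size, rng):
    # Option 2: random starting index, then sequential reads with wrap-around
    start = int(rng.integers(num_samples))
    order = (start + np.arange(num_samples)) % num_samples
    for s in range(0, num_samples, batch_size):
        yield order[s:s + batch_size]  # row indices for one batch


rng = np.random.default_rng(3)
images, labels = np.arange(10), np.arange(10) * 10  # toy data: label = 10 * image
shuffled_images, shuffled_labels = preshuffle(images, labels, rng)
batches = list(random_start_epoch(10, 4, rng))
print(shuffled_images, [b.tolist() for b in batches])
```

Option 2 only changes where each epoch begins, not the relative order of the rows, so it is usually combined with option 1.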
Alan
Hi Adrian,
I was wondering why we train on the testGen sample and also evaluate on the testGen sample?
Thanks!
Adrian Rosebrock
You’re not training on the “testGen”, you’re training with the “trainGen”.
Martin
Usually I understand source code fluently and don’t have questions, but right now I am a little confused about how the image generator really works.
Let’s assume my RAM can only handle 1000 pictures at a time, but I want to train my machine learning model with 100,000 pictures. Then I have to set the BS value to 1000 or even smaller, right?
So on the first batch/chunk, it reads in the first 1000 images with labels and trains on them. On the second chunk, it has to start reading lines 1001 to 2000 of your CSV file. But I can’t find this in your code above. In my humble opinion, it always starts at line 0 when I call the method. Is the method treated like a thread? And do I only have to reset the value for the next epoch? Or why don’t you save the last line number you were at and start from that line?
Adrian Rosebrock
I get the example that you’re including, but typical batch values are BS={8,16,32,64,128,256}. Very rarely would a batch size be larger than 256.
As for always starting at line 0 of the file, that’s not the case. The file pointer only restarts if the line read was empty (which would happen at the end of the file).
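The file-pointer behavior described in this reply can be sketched as a simplified generator. Because a Python generator pauses at yield and resumes from the same point on the next call, the open file keeps its position between batches, so each batch continues from the row after the previous one. The function below is a hypothetical, trimmed-down illustration, not the post’s csv_image_generator: it takes an already-open text file, rewinds at end-of-file in "train" mode, and stops at end-of-file in "eval" mode.

```python
import io

import numpy as np


def csv_batch_generator(f, batch_size, mode="train"):
    # f: an open text file (or io.StringIO) whose rows are "label,pixel0,pixel1,..."
    while True:
        images, labels = [], []
        while len(images) < batch_size:
            line = f.readline()
            if line == "":  # readline() returns "" only at end of file
                if mode == "eval":
                    break  # evaluation: don't wrap around to the first rows
                f.seek(0)  # training: rewind so the generator can loop forever
                line = f.readline()
            values = line.strip().split(",")
            labels.append(values[0])
            images.append([int(x) for x in values[1:]])
        if images:
            yield np.array(images, dtype="uint8"), labels
        if len(images) < batch_size:  # only reached in eval mode, at end of file
            return


data = "0,1,2,3\n1,4,5,6\n0,7,8,9\n"  # three tiny rows: label, then three pixels
train_gen = csv_batch_generator(io.StringIO(data), batch_size=2, mode="train")
print(next(train_gen)[1], next(train_gen)[1])  # ['0', '1'] ['0', '0']
eval_batches = list(csv_batch_generator(io.StringIO(data), batch_size=2, mode="eval"))
print([len(labels) for _, labels in eval_batches])  # [2, 1]
```

With a real file you would open it yourself (for example, with open("train.csv") as f, where the file name is a placeholder) and pass f in.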
Ponraj
Hello Adrian,
Thanks for your post.
When calling model.fit_generator, is it necessary to use the same BS for both trainGen and testGen? My dataset has only a small amount of testing data, so can I set the number of validation steps arbitrarily?
Bruno Fontana
Great tutorial, thanks for your help.
I have a question about using custom generator functions for prediction.
I wrote my own custom generator, which provides batches of (X_train, Y_train), where Y_train are the true output labels.
The training/validation works fine.
However, I would like to use model.predict_generator with my testGen object.
The problem is that I only get the predictions, but I don’t have access to the true labels of the batches of Y_train.
My generated examples have random nature, so every call to testGen will have different examples. Is there a way to call model.predict_generator and output the true labels of each batch?
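One way to keep predictions and true labels aligned, sketched under assumptions: pull each (x, y) batch from the generator yourself and predict on that same x, instead of letting predict_generator draw its own batches. predict_with_labels and toy_generator are hypothetical; in Keras, predict_fn would be model.predict_on_batch, but here a toy function and toy generator stand in so the example runs on its own.

```python
import numpy as np


def predict_with_labels(batch_gen, predict_fn, steps):
    preds, trues = [], []
    for _ in range(steps):
        x, y = next(batch_gen)  # draw ONE batch so x and y stay paired
        preds.append(predict_fn(x))
        trues.append(y)
    return np.concatenate(preds), np.concatenate(trues)


def toy_generator(rng, batch_size=4):
    # Random examples on every call, like the generator described in the comment
    while True:
        x = rng.normal(size=(batch_size, 3))
        yield x, (x.sum(axis=1) > 0).astype(int)


rng = np.random.default_rng(1)
y_pred, y_true = predict_with_labels(toy_generator(rng),
                                     lambda x: (x.sum(axis=1) > 0).astype(int), steps=5)
print(y_pred.shape, y_true.shape, (y_pred == y_true).mean())  # (20,) (20,) 1.0
```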
Valentino Pereira
How do I debug an existing script that uses the Keras fitGenerator function?
I am using a script, and it keeps exiting at the first epoch without throwing any error. How should I deal with this?
K M Ibrahim Khalilullah
Thanks for your wonderful post. I have a question; would you please answer it?
Is it possible to see/inspect output of any layer of your model during training?
Thanks again,
Khalilullah
Adrian Rosebrock
What do you mean by “inspect” and “see”?
Hadeer El-Saadawy
Hello,
This tutorial is really very useful, thank you. But I’m seeing something weird. I applied this code to my data, and I use the same data for validation and testing. However, the validation accuracy is very high while the testing accuracy is very low.
Any help?
Adrian Rosebrock
It sounds like your network is overfitting and/or your testing set is not representative of the rest of your training/validation data. I cover how to resolve that issue inside Deep Learning for Computer Vision with Python.
Rahul Tewatia
Hi Adrian
I am using Google Colab, which has 25 GB of RAM, to train my model.
I loaded an .h5 file containing 17,000 images with a batch_size of 128 and the steps per epoch set as required.
Doing this uses up 16 GB of my RAM. Is that normal, given that my dataset is not that big?
Anchor Jiang
Hi Adrian! Thank you for your code!
I studied it, and I think a couple of things can be improved:
1. If the dataset is very big, the CSV will be huge (because the image pixels are written into the CSV, and every batch has to read and convert them), so training speed is not great. A better way is to write only the image paths.
2. Using sklearn’s LabelBinarizer to one-hot encode labels for binary classification causes a problem.
Now I use keras.utils.to_categorical to process the labels instead.
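Point 2 is worth spelling out: with exactly two classes, scikit-learn’s LabelBinarizer returns a single 0/1 column (shape (N, 1)) instead of two one-hot columns, which does not match a two-unit softmax output layer. The sketch below produces two-column one-hot labels in NumPy, the same shape keras.utils.to_categorical gives for integer class ids; the one_hot helper is hypothetical. Keeping the single column and using one sigmoid output unit with binary cross-entropy is the other common fix.

```python
import numpy as np


def one_hot(labels):
    # Map each string label to an integer id (classes come back sorted),
    # then pick the matching rows of an identity matrix
    classes, ids = np.unique(np.asarray(labels), return_inverse=True)
    return classes, np.eye(len(classes), dtype="float32")[ids.ravel()]


classes, y = one_hot(["cat", "dog", "dog", "cat"])
print(classes.tolist(), y.shape)  # ['cat', 'dog'] (4, 2)
print(y.tolist())  # [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
```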
Thành Hoàng MTA
With image datasets we have functions like flow and flow_from_directory that automatically generate and yield the batches, so I wonder if there is any way out there that is shorter than your csv_image_generator function for handling the CSV file. Thank you very much!!!
Rolando Pula
Hello, I just read this documentation and tutorial, but I cannot find an answer on dealing with images that have (x, y, z) values, like .tiff files. Can someone help me with how to use .fit_generator with .tiff files in a 3D CNN? Thank you so much in advance
Adrian Rosebrock
I don’t have any tutorials on using 3D data but I may cover it in the future.
Mubashir
Hi! Thanks for your tutorial! Could you kindly explain how you included the labels in the two CSVs you created? The original dataset only includes file names. Did you assume that the first 80 images belong to one category, etc.?
Adrian Rosebrock
The label is included in the CSV file. It’s the first entry in each row (see Lines 44-47).