Introduction to TFRecords
In this tutorial, you will learn about TensorFlow’s TFRecords.
To learn how to work with TensorFlow’s TFRecords format, just keep reading.
Introduction
The goal of this tutorial is to serve as a one-stop destination for everything you need to know about TFRecords. We have purposefully structured the tutorial so that you build a deeper understanding of the topic. It is designed for beginners, and we expect you to have no prior knowledge of this topic.
So without any further delay, let’s jump straight into our tutorial.
Configuring Your Development Environment
To follow this guide, you need to have the TensorFlow and TensorFlow Datasets library installed on your system.
Luckily, both are pip-installable:
$ pip install tensorflow
$ pip install tensorflow-datasets
Having Problems Configuring Your Development Environment?
All that said, are you:
- Short on time?
- Learning on your employer’s administratively locked system?
- Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
- Ready to run the code right now on your Windows, macOS, or Linux system?
Then join PyImageSearch University today!
Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.
And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!
Project Structure
We first need to review our project directory structure.
Start by accessing the “Downloads” section of this tutorial to retrieve the source code and example images.
From there, take a look at the directory structure:
.
├── create_tfrecords.py
├── example_tf_record.py
├── pyimagesearch
│   ├── advance_config.py
│   ├── config.py
│   ├── __init__.py
│   └── utils.py
├── serialization.py
└── single_tf_record.py

1 directory, 8 files
In the pyimagesearch directory, we have:

- utils.py: The utilities for loading images from and saving images to disk.
- config.py: The configuration file for the single data TFRecord example.
- advance_config.py: The configuration file for the div2k dataset example.
In the core directory, we have four scripts:

- single_tf_record.py: Script that works with a single binary record and shows how to save it to the TFRecord format.
- serialization.py: Script that explains the importance of serializing the data.
- example_tf_record.py: Script to save and load a single image as a TFRecord.
- create_tfrecords.py: Script to save and load the entire div2k dataset.
What Are TFRecords?
TFRecord is a custom TensorFlow format for storing a sequence of binary records. TFRecords are highly optimized for TensorFlow, which gives them the following advantages:
- Efficient form of data storage
- Faster read speed compared to other types of formats
One of the most important use cases for TFRecords arises when we train a model using a TPU. TPUs are extremely powerful but require the data they interact with to be stored remotely (usually in Google Cloud Storage), and that is where TFRecords come in: we store datasets remotely in the TFRecord format when training a model on a TPU, since the format makes saving the data efficient and loading the data easy.
In this blog post, you will learn everything from how to build basic TFRecords to the advanced TFRecords used to train the SRGAN and ESRGAN models covered in the following blog posts:
- https://pyimagesearch.com/2022/06/06/super-resolution-generative-adversarial-networks-srgan/
- https://pyimagesearch.com/2022/06/13/enhanced-super-resolution-generative-adversarial-networks-esrgan/
Before we get started, we would like to mention that this blog post is heavily inspired by Ryan Holbrook’s Kaggle Notebook on TFRecords Basics and TensorFlow’s guide on TFRecords.
Building Your Own TFRecords
Build a TFRecord
Let’s start with something simple. We will create binary records (byte strings) and then use APIs to save them into a TFRecord. This will allow us to understand how to save large datasets into TFRecords.
# USAGE
# python single_tf_record.py

# import the necessary packages
from pyimagesearch import config
from tensorflow.io import TFRecordWriter
from tensorflow.data import TFRecordDataset

# build a byte-string that will be our binary record
record = "12345"
binaryRecord = record.encode()

# print the original data and the encoded data
print(f"Original data: {record}")
print(f"Encoded data: {binaryRecord}")

# use the with context to initialize the record writer
with TFRecordWriter(config.TFRECORD_SINGLE_FNAME) as recordWriter:
    # write the binary record into the TFRecord
    recordWriter.write(binaryRecord)

# open the TFRecord file
with open(config.TFRECORD_SINGLE_FNAME, "rb") as filePointer:
    # print the binary record from the TFRecord
    print(f"Data from the TFRecord: {filePointer.read()}")

# build a dataset from the TFRecord and iterate over it to read
# the data in the decoded format
dataset = TFRecordDataset(config.TFRECORD_SINGLE_FNAME)
for element in dataset:
    # fetch the string from the binary record and then decode it
    element = element.numpy().decode()

    # print the decoded data
    print(f"Decoded data: {element}")
On Lines 5-7, we import our necessary packages.
Let's initialize the data we want to store as a TFRecord. On Line 10, we build a variable named record, which is initialized with the string "12345". Next, on Line 11, we encode this string into a byte string.

This is particularly important because TFRecords can only store binary records, and byte strings are just that.
Lines 14 and 15 print the original and encoded strings to show the difference between the two. We will be able to notice the difference when we look into the output of the script.
Lines 18-20 initialize a TFRecordWriter. We can use the write() API as many times as we want to write binary records into the TFRecord. Notice how the TFRecordWriter is used inside a with context.
On Lines 23-25, we open the TFRecord file and inspect its data.
Line 29 is particularly useful for us, as we can now build a tf.data.Dataset from any TFRecord.
Once we have our dataset, we can iterate through it and use the data to our own will.
You can refer to this blog post on tf.data to brush up on some basics.
Let’s look at the output of this script.
$ python single_tf_record.py
Original data: 12345
Encoded data: b'12345'
Data from the TFRecord: b'\x05\x00\x00\x00\x00\x00\x00\x00\xea\xb2\x04>12345z\x1c\xed\xe8'
Decoded data: 12345
Notice the following:
- The original data and the decoded data should be the same.
- The encoded data is just the byte string of the original data.
- The data from the TFRecord is a serialized binary record; a quick sketch of its on-disk layout follows below.
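If you are curious about those extra bytes, every record in a TFRecord file is framed the same way: an 8-byte little-endian payload length, a 4-byte masked CRC of that length, the payload itself, and a 4-byte masked CRC of the payload. Here is a minimal sketch (not part of the downloads) that unpacks that framing by hand, assuming the TFRecord from above has already been written to disk:

# a minimal sketch that manually unpacks the framing around the single
# record we wrote into the TFRecord file
import struct

from pyimagesearch import config

# read the raw bytes of the TFRecord we wrote earlier
with open(config.TFRECORD_SINGLE_FNAME, "rb") as filePointer:
    raw = filePointer.read()

# the first 8 bytes hold the little-endian payload length; the payload
# starts after a further 4-byte CRC of the length
(length,) = struct.unpack("<Q", raw[:8])
payload = raw[12:12 + length]
print(f"Payload length: {length}")  # 5
print(f"Payload: {payload}")        # b'12345'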
Serialization
So, what is a serialized binary record?
We know that TFRecords store a sequence of binary records. So, we first need to learn how to convert data into binary representations. Later we will build our intuitions on top of this.
TensorFlow has two public APIs that take care of serializing data into and parsing it out of binary records: tf.io.serialize_tensor and tf.io.parse_tensor.
Let’s get our hands dirty with this example.
# USAGE
# python serialization.py

# import the necessary packages
import tensorflow as tf
from tensorflow.io import serialize_tensor, parse_tensor

# build the original data
originalData = tf.constant(
    value=[1, 2, 3, 4],
    dtype=tf.dtypes.uint8
)

# serialize the data into binary records
serializedData = serialize_tensor(originalData)

# read the serialized data into the original format
parsedData = parse_tensor(
    serializedData,
    out_type=tf.dtypes.uint8
)

# print the original, encoded, and the decoded data
print(f"Original Data: {originalData}\n")
print(f"Encoded Data: {serializedData}\n")
print(f"Decoded Data: {parsedData}\n")
Lines 5 and 6 include the imports necessary for the code to run. We are importing our ally TensorFlow and the two public APIs necessary for serialization and deserialization.
On Lines 9-12, we build a tf.constant, which will serve as our original data. The data is of type tf.dtypes.uint8. The data type of the original data is important, as it will be required when we deserialize (decode) the binary records.

On Line 15, we serialize originalData into a byte string (binary representation).

On Lines 18-21, we deserialize serializedData back into the original format. Notice how we need to specify the output format using the parameter out_type. This is where we provide the same data type as the original data (tf.dtypes.uint8).
Lines 24-26 are print statements to help us visualize the process. Let’s look at the output.
$ python serialization.py
Original Data: [1 2 3 4]

Encoded Data: b'\x08\x04\x12\x04\x12\x02\x08\x04"\x04\x01\x02\x03\x04'

Decoded Data: [1 2 3 4]
As is evident from the output, the original data was serialized into a sequence of byte strings and later deserialized into the original data.
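The output also hints at why we had to remember the original data type. As a quick hedged sketch (not part of the downloads), parsing the serialized data with a mismatched out_type should raise a tf.errors.InvalidArgumentError rather than silently returning garbage:

# attempt to deserialize the uint8 data as float32 to show that the
# dtype stored in the serialized TensorProto is checked while parsing
try:
    parse_tensor(serializedData, out_type=tf.dtypes.float32)
except tf.errors.InvalidArgumentError as error:
    print(f"Parsing with the wrong dtype failed: {error}")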
TFRecords from Structured tf.data

Let's back up a little and recap what we have just covered.
We know how a binary record is stored in a TFRecord; we have also covered parsing arbitrary data into a binary record. Now, we will dive deep into the process of turning a structured dataset into TFRecords.

In this step, we will need all of the prerequisites we have covered so far.
Before we work with an entire dataset, let us first try to understand how to work with a single instance of the dataset.
A structured dataset consists of individual instances. Each instance can be thought of as an Example. Let's consider an image and class name pair as an Example. This Example consists of two individual Feature entries, collectively called Features: one Feature is the image, and the other is the class name.
From the TensorFlow official guide on TFRecords, as shown in Figure 2, we can see the different data types that tf.train.Feature can accept: tf.train.BytesList, tf.train.FloatList, and tf.train.Int64List.

This means that a feature first needs to be serialized into one of the above lists and then wrapped into a Feature.
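As a quick illustration (this snippet is not part of the downloads), here is how a raw value of each kind would be wrapped into a tf.train.Feature:

# wrap raw values into the three tf.train.Feature list types
from tensorflow.train import BytesList, Feature, FloatList, Int64List

bytesFeature = Feature(bytes_list=BytesList(value=[b"some bytes"]))
floatFeature = Feature(float_list=FloatList(value=[3.14]))
int64Feature = Feature(int64_list=Int64List(value=[7]))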
Now, let's see how this works end to end in our example scripts.
# import the necessary packages
import tensorflow as tf
from pyimagesearch import config
from tensorflow.io import read_file
from tensorflow.image import decode_image, convert_image_dtype, resize
import matplotlib.pyplot as plt
import os

def load_image(pathToImage):
    # read the image from the path and decode the image
    image = read_file(pathToImage)
    image = decode_image(image, channels=3)

    # convert the image data type and resize it
    image = convert_image_dtype(image, tf.dtypes.float32)
    image = resize(image, (16, 16))

    # return the processed image
    return image

def save_image(image, saveImagePath, title=None):
    # show the image
    plt.imshow(image)

    # check if title is provided, if so, add the title to the plot
    if title:
        plt.title(title)

    # turn off the axis and save the plot to disk
    plt.axis("off")
    plt.savefig(saveImagePath)
In this example, we do many things at once. To begin with, let's look into the utils.py file and get a hold of the pre-processing steps.
On Lines 2-7, we import the necessary packages. Next, we define our load_image function on Lines 9-19, which reads an image from disk, converts it to a 32-bit floating-point format, resizes it to 16×16, and returns the processed image.
Following that, we define our save_image function on Lines 21-31, which takes as input the image and the output image path. On Line 23, we show the image, followed by setting the plot title on Lines 26 and 27. Last, we save the image to disk on Lines 30 and 31.
Let’s now see how we would load a raw image from disk and serialize it in the TFRecord format. We will then see how we can load the serialized TFRecord and de-serialize the image.
# USAGE
# python example_tf_record.py

# import the necessary packages
from pyimagesearch import config
from pyimagesearch import utils
from tensorflow.keras.utils import get_file
from tensorflow.io import serialize_tensor
from tensorflow.io import parse_example
from tensorflow.io import parse_tensor
from tensorflow.io import TFRecordWriter
from tensorflow.io import FixedLenFeature
from tensorflow.train import BytesList
from tensorflow.train import Example
from tensorflow.train import Features
from tensorflow.train import Feature
from tensorflow.data import TFRecordDataset
import tensorflow as tf
import os
From Lines 5-19, we import all the necessary packages.
# a single instance of structured data will consist of an image and its
# corresponding class name
imagePath = get_file(
    config.IMAGE_FNAME,
    config.IMAGE_URL,
)
image = utils.load_image(pathToImage=imagePath)
class_name = config.IMAGE_CLASS

# check to see if the output folder exists, if not, build the output
# folder
if not os.path.exists(config.OUTPUT_PATH):
    os.makedirs(config.OUTPUT_PATH)

# save the resized image
utils.save_image(image=image, saveImagePath=config.RESIZED_IMAGE_PATH)

# build the image and the class name feature
imageFeature = Feature(
    bytes_list=BytesList(value=[
        # notice how we serialize the image
        serialize_tensor(image).numpy(),
    ])
)
classNameFeature = Feature(
    bytes_list=BytesList(value=[
        class_name.encode(),
    ])
)

# wrap the image and class feature with a features dictionary and then
# wrap the features into an example
features = Features(feature={
    "image": imageFeature,
    "class_name": classNameFeature,
})
example = Example(features=features)
On Lines 23-26, we download an image from a specific URL and save it to disk. Next, on Line 27, we use the load_image function to load the image from disk as a tf.Tensor. Finally, Line 28 specifies the class name of the image.
The image and the class name will serve as our single instance of data. We now need to serialize them and save each as an individual Feature. Lines 39-49 take care of the serialization process and wrap the image and class name as Feature objects.
Now that we have our individual Feature objects, we need to wrap them into a collection named Features. Lines 53-56 build a Features collection, which consists of a dictionary of Feature objects. Finally, Line 57 concludes our journey by wrapping the Features into a single Example.
# serialize the entire example
serializedExample = example.SerializeToString()

# write the serialized example into a TFRecord
with TFRecordWriter(config.TFRECORD_EXAMPLE_FNAME) as recordWriter:
    recordWriter.write(serializedExample)

# build the feature schema and the TFRecord dataset
featureSchema = {
    "image": FixedLenFeature([], dtype=tf.string),
    "class_name": FixedLenFeature([], dtype=tf.string),
}
dataset = TFRecordDataset(config.TFRECORD_EXAMPLE_FNAME)

# iterate over the dataset
for element in dataset:
    # get the serialized example and parse it with the feature schema
    element = parse_example(element, featureSchema)

    # grab the serialized class name and the image
    className = element["class_name"].numpy().decode()
    image = parse_tensor(
        element["image"].numpy(),
        out_type=tf.dtypes.float32
    )

    # save the de-serialized image along with the class name
    utils.save_image(
        image=image,
        saveImagePath=config.DESERIALIZED_IMAGE_PATH,
        title=className
    )
On Line 60, we directly serialize the Example using the SerializeToString function. Next, we build the TFRecord from the serialized example on Lines 63 and 64.
Now, we build a schema of the features on Lines 67-70. This schema will be used to parse each example.
As mentioned earlier, building a tf.data.Dataset from TFRecords is very simple. On Line 71, we build our dataset using the simple TFRecordDataset API.
On Lines 74-90, we iterate over the dataset. Line 76 parses each element of the dataset. Notice how we use the feature schema here to parse the examples. On Lines 79-83, we grab the class name and the image in their deserialized states. Finally, we save the image to disk on Lines 86-90.
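As a side note, iterating with a Python for loop is great for inspection, but in a training pipeline, we would typically push the parsing into tf.data itself. Here is a hedged sketch of that idea (not part of the downloads); it reuses the featureSchema from above and uses parse_single_example, since each element of the dataset is a single scalar serialized example:

# a sketch of parsing TFRecord examples inside the tf.data pipeline
from tensorflow.io import parse_single_example

def parse_record(serializedExample):
    # parse the serialized example with the feature schema and
    # deserialize the image tensor
    element = parse_single_example(serializedExample, featureSchema)
    image = parse_tensor(element["image"], out_type=tf.dtypes.float32)

    # return the (image, class name) pair
    return (image, element["class_name"])

# build a parsed dataset that could feed a training loop directly
parsedDataset = TFRecordDataset(config.TFRECORD_EXAMPLE_FNAME).map(parse_record)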
Advanced TFRecord Generation
Let's now take a look at generating advanced TFRecords. In this section, we will load the div2k dataset using tfds (which stands for tensorflow_datasets, a collection of ready-to-use datasets), pre-process it, and then serialize the pre-processed dataset as TFRecords.
# USAGE
# python create_tfrecords.py

# import the necessary packages
from pyimagesearch import config
from tensorflow.io import serialize_tensor
from tensorflow.io import TFRecordWriter
from tensorflow.train import BytesList
from tensorflow.train import Feature
from tensorflow.train import Features
from tensorflow.train import Example
import tensorflow_datasets as tfds
import tensorflow as tf
import os

# define AUTOTUNE object
AUTO = tf.data.AUTOTUNE
On Lines 5-14, we import all the necessary packages, including our config file, the tensorflow_datasets collection, and the other TensorFlow submodules required to serialize a dataset as TFRecords. On Line 17, we define tf.data.AUTOTUNE for optimization purposes.
def pre_process(element):
    # grab the low and high resolution images
    lrImage = element["lr"]
    hrImage = element["hr"]

    # convert the low and high resolution images from tensor to
    # serialized TensorProto proto
    lrByte = serialize_tensor(lrImage)
    hrByte = serialize_tensor(hrImage)

    # return the low and high resolution proto objects
    return (lrByte, hrByte)
On Lines 19-30, we define our pre_process function, which takes an element consisting of a low- and a high-resolution image as input. On Lines 21 and 22, we grab the low- and high-resolution images. On Lines 26 and 27, we convert the low- and high-resolution images from tensors to serialized TensorProto objects. Finally, on Line 30, we return both proto objects.
def create_dataset(dataDir, split, shardSize):
    # load the dataset, save it to disk, and preprocess it
    ds = tfds.load(config.DATASET, split=split, data_dir=dataDir)
    ds = (ds
        .map(pre_process, num_parallel_calls=AUTO)
        .batch(shardSize)
    )

    # return the dataset
    return ds
On Lines 32-41, we define our create_dataset function, which takes the directory where the dataset is stored, the dataset split, and the shard size as input. On Line 34, we load the div2k dataset and store it on disk. On Lines 35-38, we pre-process the dataset and batch it using the shard size. Finally, on Line 41, we return the TensorFlow dataset object.
def create_serialized_example(lrByte, hrByte):
    # create low and high resolution image byte list
    lrBytesList = BytesList(value=[lrByte])
    hrBytesList = BytesList(value=[hrByte])

    # build low and high resolution image feature from the byte list
    lrFeature = Feature(bytes_list=lrBytesList)
    hrFeature = Feature(bytes_list=hrBytesList)

    # build a low and high resolution image feature map
    featureMap = {
        "lr": lrFeature,
        "hr": hrFeature,
    }

    # build a collection of features, followed by building example
    # from features, and serializing the example
    features = Features(feature=featureMap)
    example = Example(features=features)
    serializedExample = example.SerializeToString()

    # return the serialized example
    return serializedExample
On Lines 43-65, we define our create_serialized_example function, which takes a low- and a high-resolution image in byte form as input. On Lines 45 and 46, we create the low- and high-resolution image byte-list objects. On Lines 49-56, we build the low- and high-resolution image features from the byte lists and then the feature map from those features. On Lines 60-62, we build a collection of features from the feature map, build an example from the features, and serialize the example. Finally, on Line 65, we return the serialized example.
def prepare_tfrecords(dataset, outputDir, name, printEvery=50):
    # check whether output directory exists
    if not os.path.exists(outputDir):
        os.makedirs(outputDir)

    # loop over the dataset and create TFRecords
    for (index, images) in enumerate(dataset):
        # get the shard size and build the filename
        shardSize = images[0].numpy().shape[0]
        tfrecName = f"{index:02d}-{shardSize}.tfrec"
        filename = outputDir + f"/{name}-" + tfrecName

        # write to the tfrecords
        with TFRecordWriter(filename) as outFile:
            # write shard size serialized examples to each TFRecord
            for i in range(shardSize):
                serializedExample = create_serialized_example(
                    images[0].numpy()[i], images[1].numpy()[i])
                outFile.write(serializedExample)

            # print the progress to the user
            if index % printEvery == 0:
                print("[INFO] wrote file {} containing {} records..."
                    .format(filename, shardSize))
On Lines 67-90, we define the prepare_tfrecords function, which mainly takes the TensorFlow dataset and the output directory path as input. On Lines 69 and 70, we check if the output directory exists; if it doesn't, we create it. On Line 73, we start looping over the dataset, grabbing the index and the batch of images. On Lines 75-77, we set the shard size, the output TFRecord name, and the path of the output TFRecord. On Lines 80-90, we open an empty TFRecord and write shardSize serialized examples to it.
# create training and validation dataset of the div2k images
print("[INFO] creating div2k training and testing dataset...")
trainDs = create_dataset(dataDir=config.DIV2K_PATH, split="train",
    shardSize=config.SHARD_SIZE)
testDs = create_dataset(dataDir=config.DIV2K_PATH, split="validation",
    shardSize=config.SHARD_SIZE)

# create training and testing TFRecords and write them to disk
print("[INFO] preparing and writing div2k TFRecords to disk...")
prepare_tfrecords(dataset=trainDs, name="train",
    outputDir=config.GPU_DIV2K_TFR_TRAIN_PATH)
prepare_tfrecords(dataset=testDs, name="test",
    outputDir=config.GPU_DIV2K_TFR_TEST_PATH)
On Lines 93-97, we create our div2k training and testing TensorFlow datasets. On Lines 100-104, we call the prepare_tfrecords function to create the training and testing TFRecords, which are saved to disk.
What's next? We recommend PyImageSearch University.
84 total classes • 114+ hours of on-demand code walkthrough videos • Last updated: February 2024
★★★★★ 4.84 (128 Ratings) • 16,000+ Students Enrolled
I strongly believe that if you had the right teacher you could master computer vision and deep learning.
Do you think learning computer vision and deep learning has to be time-consuming, overwhelming, and complicated? Or has to involve complex mathematics and equations? Or requires a degree in computer science?
That’s not the case.
All you need to master computer vision and deep learning is for someone to explain things to you in simple, intuitive terms. And that’s exactly what I do. My mission is to change education and how complex Artificial Intelligence topics are taught.
If you're serious about learning computer vision, your next stop should be PyImageSearch University, the most comprehensive computer vision, deep learning, and OpenCV course online today. Here you’ll learn how to successfully and confidently apply computer vision to your work, research, and projects. Join me in computer vision mastery.
Inside PyImageSearch University you'll find:
- ✓ 86 courses on essential computer vision, deep learning, and OpenCV topics
- ✓ 86 Certificates of Completion
- ✓ 115+ hours of on-demand video
- ✓ Brand new courses released regularly, ensuring you can keep up with state-of-the-art techniques
- ✓ Pre-configured Jupyter Notebooks in Google Colab
- ✓ Run all code examples in your web browser — works on Windows, macOS, and Linux (no dev environment configuration required!)
- ✓ Access to centralized code repos for all 540+ tutorials on PyImageSearch
- ✓ Easy one-click downloads for code, datasets, pre-trained models, etc.
- ✓ Access on mobile, laptop, desktop, etc.
Summary
In this tutorial, you learned what TFRecords are and how to generate them to train deep neural networks using TensorFlow.
We first started with the basics of TFRecords and learned how to serialize data using them. Next, we learned how to pre-process and serialize a large dataset like div2k using TFRecords.
The two main advantages of the TFRecord format are that it helps us store datasets efficiently and that it gives us faster I/O speeds than reading raw data from disk.
TFRecords are extremely beneficial when we are training deep neural networks with TPUs. If this interests you, then definitely check out the SRGAN and ESRGAN tutorials, which cover how to train deep neural networks using both Tensor Processing Units (TPUs) and Graphics Processing Units (GPUs).
Reference
The list of all the tutorials that have helped us:
- https://www.kaggle.com/code/ryanholbrook/tfrecords-basics/notebook
- https://www.tensorflow.org/tutorials/load_data/tfrecord
Citation Information
A. R. Gosthipaty and A. Thanki. “Introduction to TFRecords,” PyImageSearch, D. Chakraborty, P. Chugh, S. Huot, K. Kidriavsteva, and R. Raha, eds., 2022, https://pyimg.co/s5p1b
@incollection{ARG-AT_2022_TFRecords,
  author = {Aritra Roy Gosthipaty and Abhishek Thanki},
  title = {Introduction to TFRecords},
  booktitle = {PyImageSearch},
  editor = {Devjyoti Chakraborty and Puneet Chugh and Susan Huot and Kseniia Kidriavsteva and Ritwik Raha},
  year = {2022},
  note = {https://pyimg.co/s5p1b},
}
Unleash the potential of computer vision with Roboflow - Free!
- Step into the realm of the future by signing up or logging into your Roboflow account. Unlock a wealth of innovative dataset libraries and revolutionize your computer vision operations.
- Jumpstart your journey by choosing from our broad array of datasets, or benefit from PyimageSearch’s comprehensive library, crafted to cater to a wide range of requirements.
- Transfer your data to Roboflow in any of the 40+ compatible formats. Leverage cutting-edge model architectures for training, and deploy seamlessly across diverse platforms, including API, NVIDIA, browser, iOS, and beyond. Integrate our platform effortlessly with your applications or your favorite third-party tools.
- Equip yourself with the ability to train a potent computer vision model in a mere afternoon. With a few images, you can import data from any source via API, annotate images using our superior cloud-hosted tool, kickstart model training with a single click, and deploy the model via a hosted API endpoint. Tailor your process by opting for a code-centric approach, leveraging our intuitive, cloud-based UI, or combining both to fit your unique needs.
- Embark on your journey today with absolutely no credit card required. Step into the future with Roboflow.
To download the source code to this post (and be notified when future tutorials are published here on PyImageSearch), simply enter your email address in the form below!
Download the Source Code and FREE 17-page Resource Guide
Enter your email address below to get a .zip of the code and a FREE 17-page Resource Guide on Computer Vision, OpenCV, and Deep Learning. Inside you'll find my hand-picked tutorials, books, courses, and libraries to help you master CV and DL!
Comment section
Hey, Adrian Rosebrock here, author and creator of PyImageSearch. While I love hearing from readers, a couple years ago I made the tough decision to no longer offer 1:1 help over blog post comments.
At the time I was receiving 200+ emails per day and another 100+ blog post comments. I simply did not have the time to moderate and respond to them all, and the sheer volume of requests was taking a toll on me.
Instead, my goal is to do the most good for the computer vision, deep learning, and OpenCV community at large by focusing my time on authoring high-quality blog posts, tutorials, and books/courses.
If you need help learning computer vision and deep learning, I suggest you refer to my full catalog of books and courses — they have helped tens of thousands of developers, students, and researchers just like yourself learn Computer Vision, Deep Learning, and OpenCV.
Click here to browse my full catalog.