In this tutorial, you will learn how to perform image classification with pre-trained networks using PyTorch. Using these networks, you can accurately classify input images into 1,000 common object categories in only a few lines of code.
Today’s tutorial is part four in our five-part series on PyTorch fundamentals:
- What is PyTorch?
- Intro to PyTorch: Training your first neural network using PyTorch
- PyTorch: Training your first Convolutional Neural Network
- PyTorch image classification with pre-trained networks (today’s tutorial)
- August 2nd: PyTorch object detection with pre-trained networks (next week’s tutorial)
Throughout the rest of this tutorial, you’ll gain experience using PyTorch to classify input images using seminal, state-of-the-art image classification networks, including VGG, Inception, DenseNet, and ResNet.
To learn how to perform image classification with pre-trained PyTorch networks, just keep reading.
PyTorch image classification with pre-trained networks
In the first part of this tutorial, we’ll discuss what pre-trained image classification networks are, including those that are built into the PyTorch library.
From there, we’ll configure our development environment and review our project directory structure.
I’ll then show you how to implement a Python script that can accurately classify input images using pre-trained PyTorch networks.
We’ll wrap up this tutorial with a discussion of our results.
What are pre-trained image classification networks?
When it comes to image classification, there is no dataset/challenge more famous than ImageNet. The goal of ImageNet is to accurately classify input images into a set of 1,000 common object categories that computer vision systems will “see” in everyday life.
Most popular deep learning frameworks, including PyTorch, Keras, TensorFlow, fast.ai, and others, include pre-trained networks. These are highly accurate, state-of-the-art models that computer vision researchers trained on the ImageNet dataset.
After training on ImageNet was complete, researchers saved their models to disk and then published them freely for other researchers, students, and developers to learn from and use in their own projects.
This tutorial will show how to use PyTorch to classify input images using the following state-of-the-art classification networks:
- VGG16
- VGG19
- Inception
- DenseNet
- ResNet
Let’s get started!
Configuring your development environment
To follow this guide, you need to have both PyTorch and OpenCV installed on your system.
Luckily, both PyTorch and OpenCV are extremely easy to install using pip:
```
$ pip install torch torchvision
$ pip install opencv-contrib-python
```
If you need help configuring your development environment for PyTorch, I highly recommend that you read the PyTorch documentation — PyTorch’s documentation is comprehensive and will have you up and running quickly.
And if you need help installing OpenCV, be sure to refer to my pip install OpenCV tutorial.
Having problems configuring your development environment?
All that said, are you:
- Short on time?
- Learning on your employer’s administratively locked system?
- Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
- Ready to run the code right now on your Windows, macOS, or Linux system?
Then join PyImageSearch University today!
Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.
And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!
Project structure
Before we implement image classification with PyTorch, let’s first review our project directory structure.
Start by accessing the “Downloads” section of this guide to retrieve the source code and example images. You’ll then be presented with the following directory structure.
```
$ tree . --dirsfirst
.
├── images
│   ├── bmw.png
│   ├── boat.png
│   ├── clint_eastwood.jpg
│   ├── jemma.png
│   ├── office.png
│   ├── scotch.png
│   ├── soccer_ball.jpg
│   └── tv.png
├── pyimagesearch
│   └── config.py
├── classify_image.py
└── ilsvrc2012_wordnet_lemmas.txt
```
Inside the pyimagesearch module we have a single file, config.py. This file stores important configurations, such as:
- Our input image dimensions
- Mean and standard deviation for mean subtraction and scaling
- Whether or not we are using a GPU for inference
- Path to the human-readable ImageNet class labels (i.e., ilsvrc2012_wordnet_lemmas.txt)
Our classify_image.py script will load our config and then classify an input image using either VGG16, VGG19, Inception, DenseNet, or ResNet (depending on which model architecture we supply as our command line argument).
The images directory contains a number of sample images to which we’ll apply these image classification networks.
Creating our configuration file
Before we implement our image classification driver script, let’s first create a configuration file to store important configurations.
Open the config.py file in the pyimagesearch module and insert the following code:
```
# import the necessary packages
import torch

# specify image dimension
IMAGE_SIZE = 224

# specify ImageNet mean and standard deviation
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# determine the device we will be using for inference
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# specify path to the ImageNet labels
IN_LABELS = "ilsvrc2012_wordnet_lemmas.txt"
```
Line 5 defines our input image spatial dimensions, meaning that each image will be resized to 224×224 pixels before being passed through our pre-trained PyTorch network for classification.
Note: Most networks trained on the ImageNet dataset accept images that are 224×224 or 227×227. Some networks, particularly fully convolutional networks, may accept larger image dimensions.
From there, we define the mean and standard deviation of RGB pixel intensities across our training set (Lines 8 and 9). Prior to passing an input image through our network for classification, we first normalize the image pixel intensities by subtracting the mean and then dividing by the standard deviation; this preprocessing is typical for CNNs trained on large, diverse image datasets such as ImageNet.
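To make the operation concrete, here is that normalization applied to a single RGB pixel (a hypothetical value, purely for illustration), using the MEAN and STD constants defined above:

```
# a single RGB pixel already scaled to the [0, 1] range
# (hypothetical value, purely for illustration)
pixel = [0.5, 0.5, 0.5]

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# subtract the per-channel mean, then divide by the
# per-channel standard deviation
normalized = [(p - m) / s for (p, m, s) in zip(pixel, MEAN, STD)]
print(normalized)  # ~[0.066, 0.196, 0.418]
```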
From there, Line 12 specifies whether we are using our CPU or GPU for inference, while Line 15 defines the path to our input text file of ImageNet class labels.
If you were to open this file in your favorite text editor, you would see the following contents:
```
tench, Tinca_tinca
goldfish, Carassius_auratus
...
bolete
ear, spike, capitulum
toilet_tissue, toilet_paper, bathroom_tissue
```
Each row in this text file maps to the name of a class label our pre-trained PyTorch networks were trained to recognize and classify.
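Here is a minimal sketch of how these rows can be mapped to an index-to-label dictionary (our driver script does exactly this later on):

```
# map each row of the label file to its integer class index
imagenetLabels = dict(enumerate(open("ilsvrc2012_wordnet_lemmas.txt")))
print(imagenetLabels[0].strip())  # tench, Tinca_tinca
```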
Implementing our image classification script
With our configuration file taken care of, let’s move on to implementing our main driver script used to classify input images using our pre-trained PyTorch networks.
Open the classify_image.py file in your project directory structure, and let’s get to work:
```
# import the necessary packages
from pyimagesearch import config
from torchvision import models
import numpy as np
import argparse
import torch
import cv2
```
We start on Lines 2-7 importing our Python packages, including:
- config: The configuration file we implemented in the previous section
- models: Contains PyTorch’s pre-trained neural networks
- numpy: Numerical array processing
- torch: Accesses the PyTorch API
- cv2: Our OpenCV bindings
With our imports taken care of, let’s define a function to accept an input image and preprocess it:
```
def preprocess_image(image):
    # swap the color channels from BGR to RGB, resize it, and scale
    # the pixel values to [0, 1] range
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (config.IMAGE_SIZE, config.IMAGE_SIZE))
    image = image.astype("float32") / 255.0

    # subtract ImageNet mean, divide by ImageNet standard deviation,
    # set "channels first" ordering, and add a batch dimension
    image -= config.MEAN
    image /= config.STD
    image = np.transpose(image, (2, 0, 1))
    image = np.expand_dims(image, 0)

    # return the preprocessed image
    return image
```
Our preprocess_image function takes a single argument, image, which is the image we’ll be preprocessing for classification.
We start the preprocessing operations by:
- Swapping from BGR to RGB channel ordering (the pre-trained networks we’re using here utilized RGB channel ordering whereas OpenCV uses BGR ordering by default)
- Resizing our image to fixed dimensions (i.e., 224×224), ignoring aspect ratio
- Converting our image to a floating point data type and then scaling the pixel intensities to the range [0, 1]
From there, we perform a second set of preprocessing operations:
- Subtracting the mean (Line 18) and dividing by the standard deviation (Line 19)
- Moving the channels dimension to the front of the array (Line 20), which is called channels-first ordering and is the default channel ordering method that PyTorch expects
- Adding a batch dimension to the array (Line 21)
The preprocessed image is then returned to the calling function.
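As a quick sanity check, here is a usage sketch with one of the sample images from the images directory, verifying that the output has the shape PyTorch expects:

```
# quick sanity check: preprocess one of the sample images
image = cv2.imread("images/boat.png")
image = preprocess_image(image)

# the result has a batch dimension and channels-first ordering:
# (batch, channels, height, width)
print(image.shape)  # (1, 3, 224, 224)
print(image.dtype)  # float32
```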
Next, let’s parse our command line arguments:
```
# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required=True,
    help="path to the input image")
ap.add_argument("-m", "--model", type=str, default="vgg16",
    choices=["vgg16", "vgg19", "inception", "densenet", "resnet"],
    help="name of pre-trained network to use")
args = vars(ap.parse_args())
```
We have two command line arguments to parse:
- --image: The path to the input image that we wish to classify
- --model: The pre-trained CNN model we’ll be using to classify the image
Let’s now define a MODELS dictionary which maps the name of the --model command line argument to its corresponding PyTorch function:
```
# define a dictionary that maps model names to their classes
# inside torchvision
MODELS = {
    "vgg16": models.vgg16(pretrained=True),
    "vgg19": models.vgg19(pretrained=True),
    "inception": models.inception_v3(pretrained=True),
    "densenet": models.densenet121(pretrained=True),
    "resnet": models.resnet50(pretrained=True)
}

# load the network weights from disk, flash it to the current
# device, and set it to evaluation mode
print("[INFO] loading {}...".format(args["model"]))
model = MODELS[args["model"]].to(config.DEVICE)
model.eval()
```
Lines 37-43 create our MODELS dictionary:
- The key to the dictionary is the human-readable name of the model, passed in via the --model command line argument.
- The value to the dictionary is the corresponding PyTorch function used to load the model with the weights pre-trained on ImageNet.
You’ll be able to use the following pre-trained models to classify an input image with PyTorch:
- VGG16
- VGG19
- Inception
- DenseNet
- ResNet
Specifying the pretrained=True flag instructs PyTorch to not only load the model architecture definition, but also download the pre-trained ImageNet weights for the model.
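Note: In newer versions of torchvision (v0.13 and later), the pretrained flag is deprecated in favor of an explicit weights argument. A minimal sketch of the updated call:

```
# newer torchvision API: request the ImageNet weights explicitly
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)

# or simply request the best weights currently available
model = models.vgg16(weights="DEFAULT")
```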
Line 48 then loads the model and pre-trained weights (if you’ve never downloaded the model weights before, they will be automatically downloaded and cached for you) and then sets the model to run on either your CPU or GPU, depending on the DEVICE set in your configuration file.
Line 49 puts our model into evaluation mode, instructing PyTorch to handle special layers, such as dropout and batch normalization, differently from how it would handle them during training. Putting your model into evaluation mode before making predictions is critical, so don’t forget to do it!
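Relatedly, since we are only making predictions (no backpropagation takes place), it is also good practice to disable gradient tracking during inference. A minimal sketch, applied to the forward pass we’ll perform below:

```
# disable gradient tracking during inference to save memory and
# computation (the forward pass itself is unchanged)
with torch.no_grad():
    logits = model(image)
```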
Now that our model is loaded, we need an input image — let’s take care of that now:
```
# load the image from disk, clone it (so we can draw on it later),
# and preprocess it
print("[INFO] loading image...")
image = cv2.imread(args["image"])
orig = image.copy()
image = preprocess_image(image)

# convert the preprocessed image to a torch tensor and flash it to
# the current device
image = torch.from_numpy(image)
image = image.to(config.DEVICE)

# load the ImageNet labels
print("[INFO] loading ImageNet labels...")
imagenetLabels = dict(enumerate(open(config.IN_LABELS)))
```
Line 54 loads our input image from disk. We make a copy of it on Line 55 so that we can draw on it and visualize the top prediction of our network. We also make use of our preprocess_image function on Line 56 to perform resizing and scaling.
Line 60 converts our image from a NumPy array to a PyTorch tensor, while Line 61 moves the image to our device (either the CPU or GPU).
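One detail worth knowing: torch.from_numpy does not copy the underlying data; the tensor and the NumPy array share memory until the tensor is moved to another device. A quick, self-contained illustration:

```
import numpy as np
import torch

arr = np.zeros(3, dtype="float32")
tensor = torch.from_numpy(arr)

# modifying the array is reflected in the tensor (shared memory)
arr[0] = 42.0
print(tensor)  # tensor([42.,  0.,  0.])
```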
Finally, Line 65 loads our input ImageNet class labels from disk.
We are now ready to make predictions on the input image using our model:
```
# classify the image and extract the predictions
print("[INFO] classifying image with '{}'...".format(args["model"]))
logits = model(image)
probabilities = torch.nn.Softmax(dim=-1)(logits)
sortedProba = torch.argsort(probabilities, dim=-1, descending=True)

# loop over the predictions and display the rank-5 predictions and
# corresponding probabilities to our terminal
for (i, idx) in enumerate(sortedProba[0, :5]):
    print("{}. {}: {:.2f}%".format
        (i, imagenetLabels[idx.item()].strip(),
        probabilities[0, idx.item()] * 100))
```
Line 69 performs a forward pass of our network, producing the raw output scores (logits).
We pass these through the Softmax function on Line 70 to obtain the predicted probabilities for each of the 1,000 possible class labels the model was trained on.
Line 71 then sorts the probabilities in descending order with higher probabilities at the front of the list.
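As an aside, torch.topk provides a more direct way to grab the top-k probabilities and their class indices in a single call. A minimal, equivalent sketch of the display loop below:

```
# grab the top-5 probabilities and their class indices directly
(topProbs, topIdxs) = torch.topk(probabilities, k=5, dim=-1)

# display each of the top-5 predictions
for (i, (prob, idx)) in enumerate(zip(topProbs[0], topIdxs[0])):
    print("{}. {}: {:.2f}%".format(i,
        imagenetLabels[idx.item()].strip(), prob.item() * 100))
```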
We then display the top-5 predicted class labels and corresponding probabilities to our terminal on Lines 75-78 by:
- Looping over the top-5 predictions
- Looking up the name of the class label using our imagenetLabels dictionary
- Displaying the predicted probability
Our final code block draws the top-1 (i.e., top predicted label) on our output image:
```
# draw the top prediction on the image and display the image to
# our screen
(label, prob) = (imagenetLabels[probabilities.argmax().item()],
    probabilities.max().item())
cv2.putText(orig, "Label: {}, {:.2f}%".format(label.strip(), prob * 100),
    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
cv2.imshow("Classification", orig)
cv2.waitKey(0)
```
The result is then displayed to our screen.
Image classification with PyTorch results
We are now ready to apply image classification with PyTorch!
Be sure to access the “Downloads” section of this tutorial to retrieve the source code and example images.
From there, try classifying an input image using the following command:
```
$ python classify_image.py --image images/boat.png
[INFO] loading vgg16...
[INFO] loading image...
[INFO] loading ImageNet labels...
[INFO] classifying image with 'vgg16'...
0. wreck: 99.99%
1. seashore, coast, seacoast, sea-coast: 0.01%
2. pirate, pirate_ship: 0.00%
3. breakwater, groin, groyne, mole, bulwark, seawall, jetty: 0.00%
4. sea_lion: 0.00%
```
It appears that Captain Jack Sparrow is stranded on the beach! And sure enough, the VGG16 network is able to correctly classify the input image as a “wreck” (i.e., shipwreck) with 99.99% probability.
It’s also interesting to see that “seashore” is the second top prediction from the model — this prediction is also accurate, due to the boat being on the beach.
Let’s try a different image, this time using the DenseNet model:
```
$ python classify_image.py --image images/bmw.png --model densenet
[INFO] loading densenet...
[INFO] loading image...
[INFO] loading ImageNet labels...
[INFO] classifying image with 'densenet'...
0. convertible: 96.61%
1. sports_car, sport_car: 2.25%
2. car_wheel: 0.45%
3. beach_wagon, station_wagon, wagon, estate_car, beach_waggon, station_waggon, waggon: 0.22%
4. racer, race_car, racing_car: 0.13%
```
The top prediction from DenseNet is “convertible” with 96.61% probability. The second top prediction, “sports car,” is also accurate.
This image contains Jemma, my family’s beagle:
```
$ python classify_image.py --image images/jemma.png --model resnet
[INFO] loading resnet...
[INFO] loading image...
[INFO] loading ImageNet labels...
[INFO] classifying image with 'resnet'...
0. beagle: 95.98%
1. bluetick: 1.46%
2. Walker_hound, Walker_foxhound: 1.11%
3. English_foxhound: 0.45%
4. maraca: 0.25%
```
Here we are using the ResNet architecture to classify our input image. Jemma is a “beagle” (a type of dog), which ResNet accurately predicts with 95.98% probability.
Interestingly, a “bluetick,” “walker hound,” and “English foxhound” are all types of dogs belonging to the “hound” family — all of these would be reasonable predictions from the model.
Let’s take a look at one final example:
```
$ python classify_image.py --image images/soccer_ball.jpg --model inception
[INFO] loading inception...
[INFO] loading image...
[INFO] loading ImageNet labels...
[INFO] classifying image with 'inception'...
0. soccer_ball: 100.00%
1. volleyball: 0.00%
2. sea_urchin: 0.00%
3. rugby_ball: 0.00%
4. silky_terrier, Sydney_silky: 0.00%
```
Our Inception model correctly classifies the input image as “soccer ball” with 100% probability.
Image classification allows us to assign one or more labels to an input image; however, it tells us nothing about where in the image the object resides.
To determine where in an input image a given object is, we need to apply object detection.
Just as we have pre-trained networks for image classification, we also have pre-trained networks for object detection. Next week you’ll learn how to use PyTorch to detect objects in images using specialized object detection networks.
What's next? We recommend PyImageSearch University.
86 total classes • 115+ hours of on-demand code walkthrough videos • Last updated: October 2024
★★★★★ 4.84 (128 Ratings) • 16,000+ Students Enrolled
I strongly believe that if you had the right teacher you could master computer vision and deep learning.
Do you think learning computer vision and deep learning has to be time-consuming, overwhelming, and complicated? Or has to involve complex mathematics and equations? Or requires a degree in computer science?
That’s not the case.
All you need to master computer vision and deep learning is for someone to explain things to you in simple, intuitive terms. And that’s exactly what I do. My mission is to change education and how complex Artificial Intelligence topics are taught.
If you're serious about learning computer vision, your next stop should be PyImageSearch University, the most comprehensive computer vision, deep learning, and OpenCV course online today. Here you’ll learn how to successfully and confidently apply computer vision to your work, research, and projects. Join me in computer vision mastery.
Inside PyImageSearch University you'll find:
- ✓ 86 courses on essential computer vision, deep learning, and OpenCV topics
- ✓ 86 Certificates of Completion
- ✓ 115+ hours of on-demand video
- ✓ Brand new courses released regularly, ensuring you can keep up with state-of-the-art techniques
- ✓ Pre-configured Jupyter Notebooks in Google Colab
- ✓ Run all code examples in your web browser — works on Windows, macOS, and Linux (no dev environment configuration required!)
- ✓ Access to centralized code repos for all 540+ tutorials on PyImageSearch
- ✓ Easy one-click downloads for code, datasets, pre-trained models, etc.
- ✓ Access on mobile, laptop, desktop, etc.
Summary
In this tutorial, you learned how to perform image classification using PyTorch. Specifically, we utilized popular pre-trained network architectures, including:
- VGG16
- VGG19
- Inception
- DenseNet
- ResNet
These models were trained by the researchers responsible for inventing and proposing the novel architectures listed above. After training was complete, these researchers saved the model weights to disk and then published them for other researchers, students, and developers to learn from and use in their own projects.
While the models are free to use, make sure you check any terms/conditions associated with them, as some models are not free to use in commercial applications (typically entrepreneurs in the AI space get around this restriction by training the models themselves rather than using the pre-trained weights provided by the original authors).
Stay tuned for next week’s blog post, where you’ll learn how to perform object detection using PyTorch.
To download the source code to this post (and be notified when future tutorials are published here on PyImageSearch), simply enter your email address in the form below!