Table of Contents
- Enhanced Super-Resolution Generative Adversarial Networks (ESRGAN)
- Preface
- Enhanced Super-Resolution GANs
- Configuring Your Development Environment
- Having Problems Configuring Your Development Environment?
- Project Structure
- Configuring the Prerequisites
- Implementing Data Processing Utilities
- Implementing the ESRGAN Architecture
- Building the Training Pipeline for the ESRGAN
- Creating Utility Functions to Aid GAN Training
- Training the ESRGAN
- Building an Inference Script for the ESRGAN
- Visualizations of the ESRGAN
- Summary
Enhanced Super-Resolution Generative Adversarial Networks (ESRGAN)
Last week we learned about Super-Resolution GANs, which worked tremendously well at producing sharper super-resolution images. But was that the end of the road for GANs in the domain of super-resolution?
A common theme in deep learning is that growth never stops. Thus, we move on to Enhanced Super-Resolution GANs. As the name suggests, it brings in many updates over the original SRGAN architecture, which drastically improves performance and visualizations.
In this tutorial, you will learn how to implement ESRGAN using TensorFlow.
This lesson is the 2nd in a 4-part series on GANs 201:
- Super-Resolution Generative Adversarial Networks (SRGAN)
- Enhanced Super-Resolution Generative Adversarial Networks (ESRGAN) (this tutorial)
- Pix2Pix GAN for Image-to-Image Translation
- CycleGAN for Image-to-Image Translation
To learn how to implement an ESRGAN, just keep reading.
Preface
GANs train two neural networks simultaneously: the discriminator and the generator. The generator creates fake images, while the discriminator judges whether they are real or fake.
SRGANs applied this idea to the domain of image super-resolution: the generator produces super-resolution images, while the discriminator judges whether they are real or fake.
Enhanced Super-Resolution GANs
Building on the foundation laid by SRGANs, the ESRGAN’s main aim is to introduce model modifications that make training more efficient and less complex.
A brief recap of SRGANs:
- Feed low-resolution images as input to a generator and get super-resolution images as outputs.
- Pass those predictions through a discriminator, which labels them as real or fake.
- Use a VGG net to add a perceptual loss (a pixel-wise MSE computed on VGG feature maps) that brings more sharpness to our predicted fake image.
But what updates did ESRGANs bring?
For starters, some major steps have been taken for the generator to ensure an increase in performance:
- Removal of Batch-Normalization Layers: A brief recap of the SRGAN architecture will show that batch-normalization layers were used extensively throughout the generator. ESRGANs scrap BN layers entirely, which both improves performance and decreases computational complexity.
- Residual in Residual Dense Block: An upgrade over the standard residual block, this structure allows all the layer outputs within a block to be passed to the subsequent layers, as shown in Figure 1. The intuition is that the model has access to many features and can determine their relevance on its own. ESRGAN also uses Residual Scaling to scale down the residual outputs and prevent instability.
- For the discriminator, the main addition is the relativistic loss. Instead of judging each image in isolation, it estimates the probability that a real image is relatively more realistic than a fake (predicted) one. Training against this relativistic loss pushes the generator to produce outputs that look more realistic than the real images do (a short sketch of this computation follows this list).
- The perceptual loss is changed slightly, with the loss now computed on features taken right before the activation function rather than after it, as was done in last week’s SRGAN.
- The total loss is now the combination of the GAN loss, perceptual loss, and the pixel-wise distance between the ground truth high-resolution and predicted images.
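To make the relativistic idea a bit more concrete, here is a minimal, illustrative sketch (not the full training code, which appears later in this post) of how the relativistic average predictions can be formed from raw discriminator scores; the tensor values are dummies:

import tensorflow as tf

# dummy raw (pre-sigmoid) discriminator scores for a batch of real
# high-resolution images and generated (fake) super-resolution images
rawReal = tf.random.normal((8, 1))
rawFake = tf.random.normal((8, 1))

# relativistic average predictions: "is this real image more realistic
# than the average fake?" and "is this fake more realistic than the
# average real?"
predReal = tf.sigmoid(rawReal - tf.reduce_mean(rawFake))
predFake = tf.sigmoid(rawFake - tf.reduce_mean(rawReal))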
These additions improve the results drastically. In our implementation, we have stayed true to the paper and brought these updates to the traditional SRGAN to improve super-resolution results.
The core ideology behind ESRGAN is not only to enhance the results but also to make the process far more efficient. Therefore, the paper does not condemn batch normalization globally; it simply states that scrapping BN layers is beneficial for our particular task, where fidelity down to the smallest pixel matters.
Configuring Your Development Environment
To follow this guide, you need to have the OpenCV library installed on your system.
Luckily, OpenCV is pip-installable:
$ pip install opencv-contrib-python
If you need help configuring your development environment for OpenCV, we highly recommend that you read our pip install OpenCV guide — it will have you up and running in a matter of minutes.
Having Problems Configuring Your Development Environment?
All that said, are you:
- Short on time?
- Learning on your employer’s administratively locked system?
- Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
- Ready to run the code right now on your Windows, macOS, or Linux system?
Then join PyImageSearch University today!
Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.
And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!
Project Structure
We first need to review our project directory structure.
Start by accessing the “Downloads” section of this tutorial to retrieve the source code and example images.
From there, take a look at the directory structure:
!tree .
.
├── create_tfrecords.py
├── inference.py
├── outputs
├── pyimagesearch
│   ├── config.py
│   ├── data_preprocess.py
│   ├── __init__.py
│   ├── losses.py
│   ├── esrgan.py
│   ├── esrgan_training.py
│   ├── utils.py
│   └── vgg.py
└── train_esrgan.py

2 directories, 11 files
In the pyimagesearch directory, we have:
- config.py: Contains an end-to-end configuration pipeline for the complete project.
- data_preprocess.py: Contains functions to aid in data processing.
- __init__.py: Makes the directory act like a python package.
- losses.py: Initializes the losses required to train the ESRGAN.
- esrgan.py: Contains the ESRGAN architecture.
- esrgan_training.py: Contains the training class which runs the ESRGAN training.
- utils.py: Contains additional utility functions.
- vgg.py: Initializes a VGG model for perceptual loss calculation.
In the root directory, we have:
- create_tfrecords.py: Creates TFRecords from the dataset we will use.
- inference.py: Draws inference using the trained models.
- train_esrgan.py: Executes the ESRGAN training using the esrgan.py and esrgan_training.py scripts.
Configuring the Prerequisites
The config.py script located in the pyimagesearch directory houses several parameters and paths required throughout the project. It is a good coding practice to keep your configuration variables separate. For that, let us move to the config.py script.
# import the necessary packages
import os

# name of the TFDS dataset we will be using
DATASET = "div2k/bicubic_x4"

# define the shard size and batch size
SHARD_SIZE = 256
TRAIN_BATCH_SIZE = 64
INFER_BATCH_SIZE = 8

# dataset specs
HR_SHAPE = [96, 96, 3]
LR_SHAPE = [24, 24, 3]
SCALING_FACTOR = 4
We start with referencing our project’s dataset on Line 5.
The SHARD_SIZE for the TFRecords is defined on Line 8. This is followed by the TRAIN_BATCH_SIZE and INFER_BATCH_SIZE definitions on Lines 9 and 10.
Our high-resolution output images will have dimensions of 96 x 96 x 3, while our input low-resolution images will have dimensions of 24 x 24 x 3 (Lines 13 and 14). Accordingly, the SCALING_FACTOR is set to 4 on Line 15.
# GAN model specs
FEATURE_MAPS = 64
RESIDUAL_BLOCKS = 16
LEAKY_ALPHA = 0.2
DISC_BLOCKS = 4
RESIDUAL_SCALAR = 0.2

# training specs
PRETRAIN_LR = 1e-4
FINETUNE_LR = 3e-5
PRETRAIN_EPOCHS = 1500
FINETUNE_EPOCHS = 1000
STEPS_PER_EPOCH = 10

# define the path to the dataset
BASE_DATA_PATH = "dataset"
DIV2K_PATH = os.path.join(BASE_DATA_PATH, "div2k")
As we had done for the SRGAN, the architecture consists of residual networks. First, we set the number of filters used in the Conv2D layers (Line 18). On Line 19, we define the number of residual blocks. The alpha parameter for our LeakyReLU function is set on Line 20.
The discriminator architecture will be automated based on the value of DISC_BLOCKS (Line 21). We also define a value for RESIDUAL_SCALAR, which scales the residual block outputs down to keep the training procedure stable (Line 22).
Next, a repeat of our SRGAN parameters (learning rate, epochs, etc.) is done on Lines 25-29. We will pretrain our GAN and then fully train it for comparison; for that reason, we have defined separate learning rates and epoch counts for the pretrained GAN and the fully trained GAN.
The BASE_DATA_PATH is set to define where we store our dataset. The DIV2K_PATH references the DIV2K dataset (Lines 32 and 33). The DIV2K dataset is perfect for aiding image super-resolution research, as it contains a variety of high-resolution images.
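If you are wondering where the DIV2K data comes from, it is available directly through TensorFlow Datasets under the name stored in config.DATASET. As a hedged sketch (roughly what create_tfrecords.py builds on, though that script is not reproduced in this post), loading it looks like this:

import tensorflow_datasets as tfds

# load the bicubic x4 DIV2K splits from TFDS; each element is a
# dictionary holding an "lr" (low-resolution) and "hr" (high-resolution) image
trainDs = tfds.load("div2k/bicubic_x4", split="train")
valDs = tfds.load("div2k/bicubic_x4", split="validation")

# inspect a single example pair
for example in trainDs.take(1):
    print(example["lr"].shape, example["hr"].shape)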
# define the path to the tfrecords for GPU training
GPU_BASE_TFR_PATH = "tfrecord"
GPU_DIV2K_TFR_TRAIN_PATH = os.path.join(GPU_BASE_TFR_PATH, "train")
GPU_DIV2K_TFR_TEST_PATH = os.path.join(GPU_BASE_TFR_PATH, "test")

# define the path to the tfrecords for TPU training
TPU_BASE_TFR_PATH = "gs://<PATH_TO_GCS_BUCKET>/tfrecord"
TPU_DIV2K_TFR_TRAIN_PATH = os.path.join(TPU_BASE_TFR_PATH, "train")
TPU_DIV2K_TFR_TEST_PATH = os.path.join(TPU_BASE_TFR_PATH, "test")

# path to our base output directory
BASE_OUTPUT_PATH = "outputs"

# GPU training ESRGAN model paths
GPU_PRETRAINED_GENERATOR_MODEL = os.path.join(BASE_OUTPUT_PATH,
    "models", "pretrained_generator")
GPU_GENERATOR_MODEL = os.path.join(BASE_OUTPUT_PATH, "models",
    "generator")

# TPU training ESRGAN model paths
TPU_OUTPUT_PATH = "gs://<PATH_TO_GCS_BUCKET>/outputs"
TPU_PRETRAINED_GENERATOR_MODEL = os.path.join(TPU_OUTPUT_PATH,
    "models", "pretrained_generator")
TPU_GENERATOR_MODEL = os.path.join(TPU_OUTPUT_PATH, "models",
    "generator")

# define the path to the inferred images and to the grid image
BASE_IMAGE_PATH = os.path.join(BASE_OUTPUT_PATH, "images")
GRID_IMAGE_PATH = os.path.join(BASE_OUTPUT_PATH, "grid.png")
For a comparison of training efficiency, we will be training the GAN on both a TPU and a GPU. For that, we have to create separate paths referencing the data and outputs for GPU training and TPU training.
On Lines 36-38, we define the TFRecords for GPU training. On Lines 41-43, we define the TFRecords for TPU training.
Now, we define the base output path on Line 46. This is followed by the referenced paths for the GPU-trained ESRGAN generator models (Lines 49-52). We do the same for the TPU-trained ESRGAN generator models (Lines 55-59).
With all that done, the only remaining task is to reference the paths to the inferred images (Lines 62 and 63).
Implementing Data Processing Utilities
Training GANs properly requires lots of computational power and data. To ensure that we have sufficient data, we will employ several data augmentation techniques. Let’s look at those in the data_preprocess.py script located in the pyimagesearch directory.
# import the necessary packages
from tensorflow.io import FixedLenFeature
from tensorflow.io import parse_single_example
from tensorflow.io import parse_tensor
from tensorflow.image import flip_left_right
from tensorflow.image import rot90
import tensorflow as tf

# define AUTOTUNE object
AUTO = tf.data.AUTOTUNE

def random_crop(lrImage, hrImage, hrCropSize=96, scale=4):
    # calculate the low resolution image crop size and image shape
    lrCropSize = hrCropSize // scale
    lrImageShape = tf.shape(lrImage)[:2]

    # calculate the low resolution image width and height offsets
    lrW = tf.random.uniform(shape=(),
        maxval=lrImageShape[1] - lrCropSize + 1, dtype=tf.int32)
    lrH = tf.random.uniform(shape=(),
        maxval=lrImageShape[0] - lrCropSize + 1, dtype=tf.int32)

    # calculate the high resolution image width and height
    hrW = lrW * scale
    hrH = lrH * scale

    # crop the low and high resolution images
    lrImageCropped = tf.slice(lrImage, [lrH, lrW, 0],
        [(lrCropSize), (lrCropSize), 3])
    hrImageCropped = tf.slice(hrImage, [hrH, hrW, 0],
        [(hrCropSize), (hrCropSize), 3])

    # return the cropped low and high resolution images
    return (lrImageCropped, hrImageCropped)
Considering the number of TensorFlow data operations we will chain together in this project, it is a good idea to define a tf.data.AUTOTUNE object, which lets TensorFlow dynamically tune parallelism and prefetching for us.
The first data augmentation function we have defined is random_crop (Line 12). It takes in the following arguments:
- lrImage: The low-resolution image.
- hrImage: The high-resolution image.
- hrCropSize: The high-resolution crop size, from which the low-resolution crop size is calculated.
- scale: The factor by which we calculate the low-resolution crop.
The lr width and height offsets are then calculated (Lines 18-21).
To calculate the corresponding high-resolution values, we simply multiply the low-resolution values with the scale factor (Lines 24 and 25).
Using these values, we crop out the low-resolution image and its corresponding high-resolution crop and return them (Lines 28-34).
def get_center_crop(lrImage, hrImage, hrCropSize=96, scale=4):
    # calculate the low resolution image crop size and image shape
    lrCropSize = hrCropSize // scale
    lrImageShape = tf.shape(lrImage)[:2]

    # calculate the low resolution image width and height
    lrW = lrImageShape[1] // 2
    lrH = lrImageShape[0] // 2

    # calculate the high resolution image width and height
    hrW = lrW * scale
    hrH = lrH * scale

    # crop the low and high resolution images
    lrImageCropped = tf.slice(lrImage, [lrH - (lrCropSize // 2),
        lrW - (lrCropSize // 2), 0], [lrCropSize, lrCropSize, 3])
    hrImageCropped = tf.slice(hrImage, [hrH - (hrCropSize // 2),
        hrW - (hrCropSize // 2), 0], [hrCropSize, hrCropSize, 3])

    # return the cropped low and high resolution images
    return (lrImageCropped, hrImageCropped)
The next function in the data augmentation utilities is get_center_crop (Line 36), which takes in the following arguments:
- lrImage: The low-resolution image.
- hrImage: The high-resolution image.
- hrCropSize: The high-resolution crop size, from which the low-resolution crop size is calculated.
- scale: The factor by which we calculate the low-resolution crop.
Just like we created crop size values for our previous function, we get the lr crop size values and image shape on Lines 38 and 39.
Now, to get the center pixel coordinates, we simply have to divide the low-resolution shape by 2 (Lines 42 and 43).
To get the corresponding high-resolution center points, multiply the lr center points by the scale factor (Lines 46 and 47).
def random_flip(lrImage, hrImage):
    # calculate a random chance for flip
    flipProb = tf.random.uniform(shape=(), maxval=1)
    (lrImage, hrImage) = tf.cond(flipProb < 0.5,
        lambda: (lrImage, hrImage),
        lambda: (flip_left_right(lrImage), flip_left_right(hrImage)))

    # return the randomly flipped low and high resolution images
    return (lrImage, hrImage)
We have the random_flip function to flip images on Line 58. It takes in the low-resolution and high-resolution images as its arguments.
Based on a flip probability value drawn with tf.random.uniform, we either keep the images as they are or flip both of them, and then return them (Lines 60-66).
def random_rotate(lrImage, hrImage):
    # randomly generate the number of 90 degree rotations
    n = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)

    # rotate the low and high resolution images
    lrImage = rot90(lrImage, n)
    hrImage = rot90(hrImage, n)

    # return the randomly rotated images
    return (lrImage, hrImage)
On Line 68, we have another data augmentation function called random_rotate, which takes in the low-resolution image and the high-resolution image as its arguments.
The variable n holds a randomly generated value that determines the number of 90-degree rotations to apply to both images (Lines 70-77).
def read_train_example(example):
    # get the feature template and parse a single image according to
    # the feature template
    feature = {
        "lr": FixedLenFeature([], tf.string),
        "hr": FixedLenFeature([], tf.string),
    }
    example = parse_single_example(example, feature)

    # parse the low and high resolution images
    lrImage = parse_tensor(example["lr"], out_type=tf.uint8)
    hrImage = parse_tensor(example["hr"], out_type=tf.uint8)

    # perform data augmentation
    (lrImage, hrImage) = random_crop(lrImage, hrImage)
    (lrImage, hrImage) = random_flip(lrImage, hrImage)
    (lrImage, hrImage) = random_rotate(lrImage, hrImage)

    # reshape the low and high resolution images
    lrImage = tf.reshape(lrImage, (24, 24, 3))
    hrImage = tf.reshape(hrImage, (96, 96, 3))

    # return the low and high resolution images
    return (lrImage, hrImage)
With the data augmentation functions out of the way, we can move to the image-reading function read_train_example on Line 79. This function takes in a single example (a low-resolution image and its corresponding high-resolution image).
On Lines 82-90, we create an lr-hr feature template and parse the example according to it.
Now, we apply the data augmentation functions to the lr-hr set (Lines 93-95). Then we reshape the lr-hr images back to their required dimensions (Lines 98-102).
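For context, this feature template assumes that each TFRecord example stores the two images as serialized tensors under the keys "lr" and "hr". A minimal, hypothetical sketch of the writing side (roughly what create_tfrecords.py would do; the names and dummy data here are illustrative only) could look like this:

import tensorflow as tf

def build_example(lrImage, hrImage):
    # serialize both image tensors to byte strings and wrap them in the
    # same "lr"/"hr" feature template that read_train_example expects
    feature = {
        "lr": tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[tf.io.serialize_tensor(lrImage).numpy()])),
        "hr": tf.train.Feature(bytes_list=tf.train.BytesList(
            value=[tf.io.serialize_tensor(hrImage).numpy()])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

# dummy uint8 image pairs standing in for real DIV2K crops
imagePairs = [(tf.zeros((128, 128, 3), tf.uint8),
    tf.zeros((512, 512, 3), tf.uint8))]

# write the examples to a single shard
with tf.io.TFRecordWriter("train-000.tfrec") as writer:
    for (lrImage, hrImage) in imagePairs:
        writer.write(build_example(lrImage, hrImage).SerializeToString())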
def read_test_example(example):
    # get the feature template and parse a single image according to
    # the feature template
    feature = {
        "lr": FixedLenFeature([], tf.string),
        "hr": FixedLenFeature([], tf.string),
    }
    example = parse_single_example(example, feature)

    # parse the low and high resolution images
    lrImage = parse_tensor(example["lr"], out_type=tf.uint8)
    hrImage = parse_tensor(example["hr"], out_type=tf.uint8)

    # center crop both low and high resolution image
    (lrImage, hrImage) = get_center_crop(lrImage, hrImage)

    # reshape the low and high resolution images
    lrImage = tf.reshape(lrImage, (24, 24, 3))
    hrImage = tf.reshape(hrImage, (96, 96, 3))

    # return the low and high resolution images
    return (lrImage, hrImage)
We create a similar function to read_train_example for the inference images, called read_test_example, which takes in an lr-hr image set (Line 104). Everything done in the previous function is repeated, except for the data augmentation process (Lines 107-125).
def load_dataset(filenames, batchSize, train=False):
    # get the TFRecords from the filenames
    dataset = tf.data.TFRecordDataset(filenames,
        num_parallel_reads=AUTO)

    # check if this is the training dataset
    if train:
        # read the training examples
        dataset = dataset.map(read_train_example,
            num_parallel_calls=AUTO)
    # otherwise, we are working with the test dataset
    else:
        # read the test examples
        dataset = dataset.map(read_test_example,
            num_parallel_calls=AUTO)

    # batch and prefetch the data
    dataset = (dataset
        .shuffle(batchSize)
        .batch(batchSize)
        .repeat()
        .prefetch(AUTO)
    )

    # return the dataset
    return dataset
Now, to tie it all together, we have the load_dataset function on Line 127. It takes in the filenames, the batch size, and a Boolean variable indicating whether the mode is training or inference.
On Lines 129 and 130, we get the TFRecords from the filenames provided. If the mode is set to train, we map the read_train_example function over the dataset. This way, all entries inside it get passed through the read_train_example function (Lines 133-136).
If the mode is inference, we map the read_test_example function over the dataset (Lines 138-141).
With our dataset created, it is then shuffled, batched, repeated, and set to prefetch automatically (Lines 144-152).
Implementing the ESRGAN Architecture
Our next destination is the esrgan.py script located in the pyimagesearch directory. This script houses the complete ESRGAN architecture. As we have already discussed the changes that ESRGAN brings, let’s go through them one by one.
# import the necessary packages
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import GlobalAvgPool2D
from tensorflow.keras.layers.experimental.preprocessing import Rescaling
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Lambda
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Add
from tensorflow.nn import depth_to_space
from tensorflow.keras import Model
from tensorflow.keras import Input

class ESRGAN(object):
    @staticmethod
    def generator(scalingFactor, featureMaps, residualBlocks,
            leakyAlpha, residualScalar):
        # initialize the input layer
        inputs = Input((None, None, 3))
        xIn = Rescaling(scale=1.0/255, offset=0.0)(inputs)

        # pass the input through CONV => LeakyReLU block
        xIn = Conv2D(featureMaps, 9, padding="same")(xIn)
        xIn = LeakyReLU(leakyAlpha)(xIn)
For ease of workflow, it is better to define the ESRGAN as a class template (Line 14).
First, we work on the generator. The generator function on Line 16 serves as our generator definition and takes in the following arguments:
- scalingFactor: The determining factor for the output image scaling.
- featureMaps: The number of convolution filters.
- residualBlocks: The number of residual blocks added to the architecture.
- leakyAlpha: The factor determining the threshold value for our LeakyReLU function.
- residualScalar: A value that keeps the outputs of the residual blocks scaled so that the training is stable.
The inputs are initialized, and the pixels are scaled to the range of 0 to 1 (Lines 19 and 20).
The processed inputs are then passed through a Conv2D layer followed by a LeakyReLU activation function (Lines 23 and 24). The arguments for these layers have been defined previously in config.py.
        # construct the residual in residual block
        x = Conv2D(featureMaps, 3, padding="same")(xIn)
        x1 = LeakyReLU(leakyAlpha)(x)
        x1 = Add()([xIn, x1])
        x = Conv2D(featureMaps, 3, padding="same")(x1)
        x2 = LeakyReLU(leakyAlpha)(x)
        x2 = Add()([x1, x2])
        x = Conv2D(featureMaps, 3, padding="same")(x2)
        x3 = LeakyReLU(leakyAlpha)(x)
        x3 = Add()([x2, x3])
        x = Conv2D(featureMaps, 3, padding="same")(x3)
        x4 = LeakyReLU(leakyAlpha)(x)
        x4 = Add()([x3, x4])
        x4 = Conv2D(featureMaps, 3, padding="same")(x4)
        xSkip = Add()([xIn, x4])

        # scale the residual outputs with a scalar between [0,1]
        xSkip = Lambda(lambda x: x * residualScalar)(xSkip)
As we have mentioned before, ESRGAN uses Residual in Residual Blocks. Hence, the base block is defined next.
We start by adding a Conv2D and a LeakyReLU layer. Due to the nature of the block, the output of this combination of layers, x1, is then added to the initial input xIn. This is repeated three times, with a final Conv2D layer added before the skip connection joining the initial input xIn and the block output x4 (Lines 27-40).
Now, this is a slight deviation from the original paper, where all layers within a block are densely interconnected so that the model has access to all previous features at each step. Our simpler approach is sufficient to give us desirable results for the task and dataset we use today.
After the skip connection addition, the outputs are scaled using the residualScalar variable on Line 43.
        # create a number of residual in residual blocks
        for blockId in range(residualBlocks-1):
            x = Conv2D(featureMaps, 3, padding="same")(xSkip)
            x1 = LeakyReLU(leakyAlpha)(x)
            x1 = Add()([xSkip, x1])
            x = Conv2D(featureMaps, 3, padding="same")(x1)
            x2 = LeakyReLU(leakyAlpha)(x)
            x2 = Add()([x1, x2])
            x = Conv2D(featureMaps, 3, padding="same")(x2)
            x3 = LeakyReLU(leakyAlpha)(x)
            x3 = Add()([x2, x3])
            x = Conv2D(featureMaps, 3, padding="same")(x3)
            x4 = LeakyReLU(leakyAlpha)(x)
            x4 = Add()([x3, x4])
            x4 = Conv2D(featureMaps, 3, padding="same")(x4)
            xSkip = Add()([xSkip, x4])
            xSkip = Lambda(lambda x: x * residualScalar)(xSkip)
The block repetition is automated with a for loop: based on the number of residual blocks specified, the block layers are added (Lines 46-61).
        # process the residual output with a conv kernel
        x = Conv2D(featureMaps, 3, padding="same")(xSkip)
        x = Add()([xIn, x])

        # upscale the image with pixel shuffle
        x = Conv2D(featureMaps * (scalingFactor // 2), 3,
            padding="same")(x)
        x = depth_to_space(x, 2)
        x = LeakyReLU(leakyAlpha)(x)

        # upscale the image with pixel shuffle
        x = Conv2D(featureMaps, 3, padding="same")(x)
        x = depth_to_space(x, 2)
        x = LeakyReLU(leakyAlpha)(x)

        # get the output layer
        x = Conv2D(3, 9, padding="same", activation="tanh")(x)
        output = Rescaling(scale=127.5, offset=127.5)(x)

        # create the generator model
        generator = Model(inputs, output)

        # return the generator model
        return generator
The final residual output is processed with another Conv2D layer and added to the initial input (Lines 64 and 65).
Now, the upscaling process starts on Line 68, where the scalingFactor variable comes into play. This is followed by the depth_to_space utility function, which increases the height and width of the feature maps by uniformly decreasing the channel size accordingly (Line 70). A LeakyReLU activation function finishes this specific layer combination (Line 71).
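To make the pixel-shuffle step concrete, here is a tiny, illustrative sketch (the shapes are dummies, not values from the actual model) showing how depth_to_space with a block size of 2 trades channels for spatial resolution: the channel count drops by a factor of 4 while the height and width double.

import tensorflow as tf

# a dummy feature map: batch of 1, 24 x 24 spatial size, 128 channels
x = tf.random.normal((1, 24, 24, 128))

# pixel shuffle with block_size=2: channels / 4, height and width x 2
y = tf.nn.depth_to_space(x, 2)
print(y.shape)  # (1, 48, 48, 32)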
This same set of layers is repeated on Lines 73-76. The output layer is achieved by passing the feature maps through another Conv2D layer. Notice how this convolution layer has a tanh activation function, which scales the outputs to the range of -1 to 1.
For this reason, the pixels are rescaled back to the range of 0 to 255 (Lines 79 and 80).
With the initialization of the generator model on Line 83, our ESRGAN generator is complete.
    @staticmethod
    def discriminator(featureMaps, leakyAlpha, discBlocks):
        # initialize the input layer and process it with conv kernel
        inputs = Input((None, None, 3))
        x = Rescaling(scale=1.0/127.5, offset=-1)(inputs)
        x = Conv2D(featureMaps, 3, padding="same")(x)
        x = LeakyReLU(leakyAlpha)(x)

        # pass the output from previous layer through a CONV => BN =>
        # LeakyReLU block
        x = Conv2D(featureMaps, 3, padding="same")(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(leakyAlpha)(x)
As we know, the objective of the discriminator is to take an image as input and output a single value indicating whether the image is real or fake.
The discriminator function on Line 89 is defined with the following arguments:
- featureMaps: The number of filters for the Conv2D layers.
- leakyAlpha: The parameter required for the LeakyReLU activation function.
- discBlocks: The number of discriminator blocks we require in the architecture.
The inputs for the discriminator are initialized, and the pixels are scaled to the range of -1 to 1 (Lines 91 and 92).
The architecture starts with a Conv2D layer followed by a LeakyReLU activation layer (Lines 93 and 94).
Although we have discarded batch normalization layers for the generator, we will use them for the discriminator. The next set of layers is a CONV => BN => LeakyReLU combination (Lines 98-100).
        # create a downsample conv kernel config
        downConvConf = {
            "strides": 2,
            "padding": "same",
        }

        # create a number of discriminator blocks
        for i in range(1, discBlocks):
            # first CONV => BN => LeakyReLU block
            x = Conv2D(featureMaps * (2 ** i), 3, **downConvConf)(x)
            x = BatchNormalization()(x)
            x = LeakyReLU(leakyAlpha)(x)

            # second CONV => BN => LeakyReLU block
            x = Conv2D(featureMaps * (2 ** i), 3, padding="same")(x)
            x = BatchNormalization()(x)
            x = LeakyReLU(leakyAlpha)(x)
On Lines 103-106, we create a downsampling convolution template configuration. This is then used in the automated discriminator blocks on Lines 109-118.
        # process the feature maps with global average pooling
        x = GlobalAvgPool2D()(x)
        x = LeakyReLU(leakyAlpha)(x)

        # final FC layer with sigmoid activation function
        x = Dense(1, activation="sigmoid")(x)

        # create the discriminator model
        discriminator = Model(inputs, x)

        # return the discriminator model
        return discriminator
The feature maps are then passed through a GlobalAvgPool2D layer and another LeakyReLU activation layer, after which a final dense layer with a sigmoid activation gives us our output (Lines 121-125).
The discriminator object is initialized and returned on Lines 128-131, and this concludes the discriminator function.
Building the Training Pipeline for the ESRGAN
With the architecture complete, it’s time to move on to the esrgan_training.py script located in the pyimagesearch directory.
# import the necessary packages
from tensorflow.keras import Model
from tensorflow import concat
from tensorflow import zeros
from tensorflow import ones
from tensorflow import GradientTape
from tensorflow.keras.activations import sigmoid
from tensorflow.math import reduce_mean
import tensorflow as tf

class ESRGANTraining(Model):
    def __init__(self, generator, discriminator, vgg, batchSize):
        # initialize the generator, discriminator, vgg model, and
        # the global batch size
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.vgg = vgg
        self.batchSize = batchSize
To make things easier, the complete training module is packaged inside a class defined on Line 11.
Naturally, the first function is __init__, which takes in the generator model, discriminator model, VGG model, and the batch size specification (Line 12).
In this function, we create the corresponding class variables for the arguments (Lines 16-19). These variables will be used for the class functions later.
    def compile(self, gOptimizer, dOptimizer, bceLoss, mseLoss):
        super().compile()
        # initialize the optimizers for the generator
        # and discriminator
        self.gOptimizer = gOptimizer
        self.dOptimizer = dOptimizer

        # initialize the loss functions
        self.bceLoss = bceLoss
        self.mseLoss = mseLoss
The compile function on Line 21 takes in the generator and discriminator optimizers, the binary cross-entropy loss function, and the mean squared error loss function.
This function initializes the optimizers and loss functions for the generator and the discriminator (Lines 25-30).
    def train_step(self, images):
        # grab the low and high resolution images
        (lrImages, hrImages) = images
        lrImages = tf.cast(lrImages, tf.float32)
        hrImages = tf.cast(hrImages, tf.float32)

        # generate super resolution images
        srImages = self.generator(lrImages)

        # combine them with real images
        combinedImages = concat([srImages, hrImages], axis=0)

        # assemble labels discriminating real from fake images where
        # label 0 is for predicted images and 1 is for original high
        # resolution images
        labels = concat(
            [zeros((self.batchSize, 1)), ones((self.batchSize, 1))],
            axis=0)
Now it’s time for us to define the training procedure. This is done in the train_step function defined on Line 32, which takes in the images as its argument.
We unpack the image set into its corresponding low-resolution and high-resolution images and cast them to the float32 data type (Lines 34-36).
On Line 39, we get a batch of fake super-resolution images from the generator. These are concatenated with the real high-resolution images, and the labels are created accordingly (Lines 47-49).
        # train the discriminator with relativistic error
        with GradientTape() as tape:
            # get the raw predictions and divide them into
            # raw fake and raw real predictions
            rawPreds = self.discriminator(combinedImages)
            rawFake = rawPreds[:self.batchSize]
            rawReal = rawPreds[self.batchSize:]

            # process the relative raw error and pass it through the
            # sigmoid activation function
            predFake = sigmoid(rawFake - reduce_mean(rawReal))
            predReal = sigmoid(rawReal - reduce_mean(rawFake))

            # concat the predictions and calculate the discriminator
            # loss
            predictions = concat([predFake, predReal], axis=0)
            dLoss = self.bceLoss(labels, predictions)

        # compute the gradients
        grads = tape.gradient(dLoss,
            self.discriminator.trainable_variables)

        # optimize the discriminator weights according to the
        # gradients computed
        self.dOptimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_variables)
        )

        # generate misleading labels
        misleadingLabels = ones((self.batchSize, 1))
First, we will define the discriminator training. Initiating a GradientTape, we get predictions from our discriminator on the combined image set (Lines 52-55). We separate the fake image predictions from the real image predictions and compute relativistic errors for both; the values are then passed through a sigmoid function to get our final output values (Lines 56-62).
The predictions are again concatenated, and the discriminator loss is calculated by passing the predictions through the binary cross entropy loss (Lines 66 and 67).
With the loss values, the gradients are calculated, and the discriminator weights are changed accordingly (Lines 70-77).
With the discriminator training over, we now generate misleading labels required for the generator training (Line 80).
        # train the generator (note that we should *not* update
        # the weights of the discriminator)
        with GradientTape() as tape:
            # generate fake images
            fakeImages = self.generator(lrImages)

            # calculate predictions
            rawPreds = self.discriminator(fakeImages)
            realPreds = self.discriminator(hrImages)
            relativisticPreds = rawPreds - reduce_mean(realPreds)
            predictions = sigmoid(relativisticPreds)

            # compute the discriminator predictions on the fake images
            # todo: try with logits
            #gLoss = self.bceLoss(misleadingLabels, predictions)
            gLoss = self.bceLoss(misleadingLabels, predictions)

            # compute the pixel loss
            pixelLoss = self.mseLoss(hrImages, fakeImages)

            # compute the normalized vgg outputs
            srVGG = tf.keras.applications.vgg19.preprocess_input(
                fakeImages)
            srVGG = self.vgg(srVGG) / 12.75
            hrVGG = tf.keras.applications.vgg19.preprocess_input(
                hrImages)
            hrVGG = self.vgg(hrVGG) / 12.75

            # compute the perceptual loss
            percLoss = self.mseLoss(hrVGG, srVGG)

            # compute the total GAN loss
            gTotalLoss = 5e-3 * gLoss + percLoss + 1e-2 * pixelLoss

        # compute the gradients
        grads = tape.gradient(gTotalLoss,
            self.generator.trainable_variables)

        # optimize the generator weights according to the gradients
        # calculated
        self.gOptimizer.apply_gradients(zip(grads,
            self.generator.trainable_variables)
        )

        # return the generator and discriminator losses
        return {"dLoss": dLoss, "gTotalLoss": gTotalLoss,
            "gLoss": gLoss, "percLoss": percLoss, "pixelLoss": pixelLoss}
We again initiate a GradientTape for the generator and generate fake super-resolution images with it (Lines 84-86).
The predictions for both the fake super-resolution images and the real high-resolution images are calculated, and the relativistic errors are computed (Lines 89-92).
The predictions are fed to a binary cross entropy loss function while the pixel loss is calculated using the mean squared error loss function (Lines 97-100).
Next, we compute the VGG outputs and the perceptual loss (Lines 103-111). With all the loss values available, we compute the total GAN loss using the equation on Line 114.
The generator gradients are computed and applied (Lines 117-128).
Creating Utility Functions to Aid GAN Training
We have used a few utility scripts to aid our training pipeline. The first is a script that holds the losses used in our training. For that, let’s move into the losses.py script inside the pyimagesearch directory.
# import necessary packages
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.losses import Reduction
from tensorflow import reduce_mean

class Losses:
    def __init__(self, numReplicas):
        self.numReplicas = numReplicas

    def bce_loss(self, real, pred):
        # compute binary cross entropy loss without reduction
        BCE = BinaryCrossentropy(reduction=Reduction.NONE)
        loss = BCE(real, pred)

        # compute reduced mean over the entire batch
        loss = reduce_mean(loss) * (1. / self.numReplicas)

        # return reduced bce loss
        return loss
The __init__
function defines the batch size used in the consequent loss functions (Lines 8 and 9).
The losses are packaged into a class on Line 7. The first loss we have defined is the binary cross entropy loss on Line 11. It takes in the real labels and the predicted labels.
The binary cross entropy loss object is defined on Line 13, and the loss is calculated on Line 14. The loss is then adjusted over the entire batch (Line 17).
    def mse_loss(self, real, pred):
        # compute mean squared error loss without reduction
        MSE = MeanSquaredError(reduction=Reduction.NONE)
        loss = MSE(real, pred)

        # compute reduced mean over the entire batch
        loss = reduce_mean(loss) * (1. / self.numReplicas)

        # return reduced mse loss
        return loss
The next loss is the mean squared error function defined on Line 22. A mean squared error loss object is initialized, followed by the loss calculation over the entire batch (Lines 24-28).
This concludes our losses.py script. Next, we move into the utils.py script, which will help us better assess the images generated by the GAN.
# import the necessary packages
from . import config
from matplotlib.pyplot import subplots
from matplotlib.pyplot import savefig
from matplotlib.pyplot import title
from matplotlib.pyplot import xticks
from matplotlib.pyplot import yticks
from matplotlib.pyplot import show
from tensorflow.keras.preprocessing.image import array_to_img
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
import os

# the following code snippet has been taken from:
# https://keras.io/examples/vision/super_resolution_sub_pixel
def zoom_into_images(image, imageTitle):
    # create a new figure with a default 111 subplot.
    (fig, ax) = subplots()
    im = ax.imshow(array_to_img(image[::-1]), origin="lower")

    title(imageTitle)
    # zoom-factor: 2.0, location: upper-left
    axins = zoomed_inset_axes(ax, 2, loc=2)
    axins.imshow(array_to_img(image[::-1]), origin="lower")

    # specify the limits.
    (x1, x2, y1, y2) = 20, 40, 20, 40
    # apply the x-limits.
    axins.set_xlim(x1, x2)
    # apply the y-limits.
    axins.set_ylim(y1, y2)

    # remove the xticks and yticks
    yticks(visible=False)
    xticks(visible=False)

    # make the line.
    mark_inset(ax, axins, loc1=1, loc2=3, fc="none", ec="blue")

    # build the image path and save it to disk
    imagePath = os.path.join(config.BASE_IMAGE_PATH,
        f"{imageTitle}.png")
    savefig(imagePath)

    # show the image
    show()
This script contains a single function called zoom_into_images (Line 16), which takes in the image and the image title as arguments.
First, the subplots are defined, and the image is plotted (Lines 18 and 19). On Lines 21-24, we zoom into the upper-left area of the image and plot that part again.
The limits of this plot are set on Lines 27-31. Now, we remove the ticks on both the x-axis and y-axis and insert the lines on our original plot (Lines 34-38).
With the images plotted, we save the image and conclude the function (Lines 41-43).
Our final utility script is vgg.py, which initializes a VGG model for our perceptual loss.
# import the necessary packages
from tensorflow.keras.applications import VGG19
from tensorflow.keras import Model

class VGG:
    @staticmethod
    def build():
        # initialize the pre-trained VGG19 model
        vgg = VGG19(input_shape=(None, None, 3), weights="imagenet",
            include_top=False)

        # slicing the VGG19 model till layer #20
        model = Model(vgg.input, vgg.layers[20].output)

        # return the sliced VGG19 model
        return model
We create a class for the VGG model on Line 5. It contains a single function called build, which initializes a pretrained VGG19 architecture and returns a VGG model sliced at the 20th layer (Lines 7-16). This concludes the vgg.py script.
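If you want to confirm which layer index 20 maps to in your installed Keras version (layer ordering can differ slightly across versions), a quick, optional check like the following prints the name of the sliced layer:

from tensorflow.keras.applications import VGG19

# build the headless VGG19 and print the name of the layer we slice at
vgg = VGG19(input_shape=(None, None, 3), weights="imagenet",
    include_top=False)
print(vgg.layers[20].name)  # expected to be a late block5 conv layer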
Training the ESRGAN
Now we have all our blocks ready. We just need to execute them in the correct order for proper GAN training. To achieve this, we move into the train_esrgan.py script.
# USAGE
# python train_esrgan.py --device gpu
# python train_esrgan.py --device tpu

# import tensorflow and fix the random seed for better reproducibility
import tensorflow as tf
tf.random.set_seed(42)

# import the necessary packages
from pyimagesearch.data_preprocess import load_dataset
from pyimagesearch.esrgan_training import ESRGANTraining
from pyimagesearch.esrgan import ESRGAN
from pyimagesearch.losses import Losses
from pyimagesearch.vgg import VGG
from pyimagesearch import config
from tensorflow import distribute
from tensorflow.config import experimental_connect_to_cluster
from tensorflow.tpu.experimental import initialize_tpu_system
from tensorflow.keras.optimizers import Adam
from tensorflow.io.gfile import glob
import argparse
import sys
import os

# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("--device", required=True, default="gpu",
    choices=["gpu", "tpu"], type=str,
    help="device to use for training (gpu or tpu)")
args = vars(ap.parse_args())
The first task here is to define an argument parser so that the user can choose whether the GAN training will be done using a TPU or GPU (Lines 26-30). As we have already mentioned, we have trained the GAN using both TPU and GPU to assess efficiency.
# check if we are using TPU, if so, initialize the TPU strategy
if args["device"] == "tpu":
    # initialize the TPUs
    tpu = distribute.cluster_resolver.TPUClusterResolver()
    experimental_connect_to_cluster(tpu)
    initialize_tpu_system(tpu)
    strategy = distribute.TPUStrategy(tpu)

    # ensure the user has entered a valid gcs bucket path
    if config.TPU_BASE_TFR_PATH == "gs://<PATH_TO_GCS_BUCKET>/tfrecord":
        print("[INFO] not a valid GCS Bucket path...")
        sys.exit(0)

    # set the train TFRecords, pretrained generator, and final
    # generator model paths to be used for TPU training
    tfrTrainPath = config.TPU_DIV2K_TFR_TRAIN_PATH
    pretrainedGenPath = config.TPU_PRETRAINED_GENERATOR_MODEL
    genPath = config.TPU_GENERATOR_MODEL

# otherwise, we are using multi/single GPU so initialize the mirrored
# strategy
elif args["device"] == "gpu":
    # define the multi-gpu strategy
    strategy = distribute.MirroredStrategy()

    # set the train TFRecords, pretrained generator, and final
    # generator model paths to be used for GPU training
    tfrTrainPath = config.GPU_DIV2K_TFR_TRAIN_PATH
    pretrainedGenPath = config.GPU_PRETRAINED_GENERATOR_MODEL
    genPath = config.GPU_GENERATOR_MODEL

# else, invalid argument was provided as input
else:
    # exit the program
    print("[INFO] please enter a valid device argument...")
    sys.exit(0)

# display the number of accelerators
print(f"[INFO] number of accelerators: {strategy.num_replicas_in_sync}...")
According to the device choice, we have to initialize strategies. First, we explore the case of the TPU choice (Line 33).
To utilize the power of TPUs properly, we initialize a TPUClusterResolver for efficient usage of resources. Next, the TPU strategy is initialized (Lines 35-43).
The TFRecords path for the TPU training data, the pretrained generator, and the fully trained generator are defined (Lines 47-49).
Now the second device choice, the GPU, is explored. For the GPU, the mirrored strategy is used (Line 55), and the GPU-specific TFRecords path, pretrained generator path, and fully trained generator path are defined (Lines 59-61).
If any other choice is given, the script exits by itself (Lines 64-67).
# grab train TFRecord filenames
print("[INFO] grabbing the train TFRecords...")
trainTfr = glob(tfrTrainPath +"/*.tfrec")

# build the div2k datasets from the TFRecords
print("[INFO] creating train and test dataset...")
trainDs = load_dataset(filenames=trainTfr, train=True, batchSize=config.TRAIN_BATCH_SIZE * strategy.num_replicas_in_sync)

# call the strategy scope context manager
with strategy.scope():
    # initialize our losses class object
    losses = Losses(numReplicas=strategy.num_replicas_in_sync)

    # initialize the generator, and compile it with Adam optimizer and
    # MSE loss
    generator = ESRGAN.generator(
        scalingFactor=config.SCALING_FACTOR,
        featureMaps=config.FEATURE_MAPS,
        residualBlocks=config.RESIDUAL_BLOCKS,
        leakyAlpha=config.LEAKY_ALPHA,
        residualScalar=config.RESIDUAL_SCALAR)
    generator.compile(optimizer=Adam(learning_rate=config.PRETRAIN_LR),
        loss=losses.mse_loss)

    # pretraining the generator
    print("[INFO] pretraining ESRGAN generator ...")
    generator.fit(trainDs, epochs=config.PRETRAIN_EPOCHS,
        steps_per_epoch=config.STEPS_PER_EPOCH)
We grab the TFRecords files and then create a training dataset using the load_dataset function (Lines 74-79).
First, we will initialize the pretrained generator. For that, we call the strategy scope context manager to initialize the losses and the generator on Lines 82-95. This is followed by pretraining the generator on Lines 99 and 100.
# check whether output model directory exists, if it doesn't, then
# create it
if args["device"] == "gpu" and not os.path.exists(config.BASE_OUTPUT_PATH):
    os.makedirs(config.BASE_OUTPUT_PATH)

# save the pretrained generator
print("[INFO] saving the pretrained generator...")
generator.save(pretrainedGenPath)

# call the strategy scope context manager
with strategy.scope():
    # initialize our losses class object
    losses = Losses(numReplicas=strategy.num_replicas_in_sync)

    # initialize the vgg network (for perceptual loss) and
    # discriminator network
    vgg = VGG.build()
    discriminator = ESRGAN.discriminator(
        featureMaps=config.FEATURE_MAPS,
        leakyAlpha=config.LEAKY_ALPHA,
        discBlocks=config.DISC_BLOCKS)

    # build the ESRGAN model and compile it
    esrgan = ESRGANTraining(
        generator=generator,
        discriminator=discriminator,
        vgg=vgg,
        batchSize=config.TRAIN_BATCH_SIZE)
    esrgan.compile(
        dOptimizer=Adam(learning_rate=config.FINETUNE_LR),
        gOptimizer=Adam(learning_rate=config.FINETUNE_LR),
        bceLoss=losses.bce_loss,
        mseLoss=losses.mse_loss,
    )

    # train the ESRGAN model
    print("[INFO] training ESRGAN...")
    esrgan.fit(trainDs, epochs=config.FINETUNE_EPOCHS,
        steps_per_epoch=config.STEPS_PER_EPOCH)

# save the ESRGAN generator
print("[INFO] saving ESRGAN generator to {}..."
    .format(genPath))
esrgan.generator.save(genPath)
If the device is set to GPU, the base output model directory is initialized if not done already (Lines 104 and 105). The pretrained generator is then saved to the designated path (Line 109).
Now we move on to the fully trained ESRGAN. We initialize the strategy scope context manager again and initialize a loss object (Lines 112-114).
The VGG model required for the perceptual loss is initialized, followed by the ESRGAN (Lines 118-129). The ESRGAN is then compiled with required optimizers and losses (Lines 130-135).
The final step here is to fit the ESRGAN with the training data and let it train (Lines 139 and 140).
Once completed, the trained weights are saved in the predetermined path on Line 145.
Building an Inference Script for the ESRGAN
With our ESRGAN training complete, we can now assess how well our ESRGAN has fared. For that, let’s look at the inference.py script located in the root directory.
# USAGE
# python inference.py --device gpu
# python inference.py --device tpu

# import the necessary packages
from pyimagesearch.data_preprocess import load_dataset
from pyimagesearch.utils import zoom_into_images
from pyimagesearch import config
from tensorflow import distribute
from tensorflow.config import experimental_connect_to_cluster
from tensorflow.tpu.experimental import initialize_tpu_system
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import array_to_img
from tensorflow.io.gfile import glob
from matplotlib.pyplot import subplots
import argparse
import sys
import os

# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("--device", required=True, default="gpu",
    choices=["gpu", "tpu"], type=str,
    help="device to use for training (gpu or tpu)")
args = vars(ap.parse_args())
Depending on the device we had utilized for training, we need to provide the same device to initialize models and load the required weights accordingly. For that, we build an argument parser that takes the device choice from the user (Lines 21-25).
# check if we are using TPU, if so, initialize the strategy
# accordingly
if args["device"] == "tpu":
    tpu = distribute.cluster_resolver.TPUClusterResolver()
    experimental_connect_to_cluster(tpu)
    initialize_tpu_system(tpu)
    strategy = distribute.TPUStrategy(tpu)

    # set the train TFRecords, pretrained generator, and final
    # generator model paths to be used for TPU training
    tfrTestPath = config.TPU_DIV2K_TFR_TEST_PATH
    pretrainedGenPath = config.TPU_PRETRAINED_GENERATOR_MODEL
    genPath = config.TPU_GENERATOR_MODEL

# otherwise, we are using multi/single GPU so initialize the mirrored
# strategy
elif args["device"] == "gpu":
    # define the multi-gpu strategy
    strategy = distribute.MirroredStrategy()

    # set the train TFRecords, pretrained generator, and final
    # generator model paths to be used for GPU training
    tfrTestPath = config.GPU_DIV2K_TFR_TEST_PATH
    pretrainedGenPath = config.GPU_PRETRAINED_GENERATOR_MODEL
    genPath = config.GPU_GENERATOR_MODEL

# else, invalid argument was provided as input
else:
    # exit the program
    print("[INFO] please enter a valid device argument...")
    sys.exit(0)
Depending on the choice input by the user, we have to set up the strategies for dealing with data.
The first device choice (TPU) is handled by initializing the TPUClusterResolver, the strategy, and the TPU-specific paths the same way we did in the training script (Lines 29-33).
For the second choice (GPU), we repeat the same procedures as the training script (Lines 43-51).
If any other input is given, the script exits itself (Lines 54-57).
# get the dataset
print("[INFO] loading the test dataset...")
testTfr = glob(tfrTestPath + "/*.tfrec")
testDs = load_dataset(testTfr, config.INFER_BATCH_SIZE, train=False)

# get the first batch of testing images
(lrImage, hrImage) = next(iter(testDs))

# call the strategy scope context manager
with strategy.scope():
    # load the ESRGAN trained models
    print("[INFO] loading the pre-trained and fully trained ESRGAN model...")
    esrganPreGen = load_model(pretrainedGenPath, compile=False)
    esrganGen = load_model(genPath, compile=False)

    # predict using ESRGAN
    print("[INFO] making predictions with pre-trained and fully trained ESRGAN model...")
    esrganPreGenPred = esrganPreGen.predict(lrImage)
    esrganGenPred = esrganGen.predict(lrImage)
For testing purposes, we create a test dataset on Line 62. Using next(iter()), we grab a single batch of images, which we unpack on Line 65.
Next, the pretrained GAN and the fully trained ESRGAN are initialized and loaded on Lines 71 and 72. The low-resolution images are then passed through these GANs for predictions (Lines 76 and 77).
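If you would like a quantitative sanity check alongside the visual comparison, a small optional sketch like the following (not part of the original inference script) computes the PSNR of each model’s predictions against the ground-truth high-resolution batch loaded above:

import tensorflow as tf

# PSNR expects both tensors in the same value range; the generators
# output pixels rescaled to [0, 255], matching the uint8 ground truth
psnrPre = tf.image.psnr(tf.cast(hrImage, tf.float32),
    esrganPreGenPred, max_val=255.0)
psnrFull = tf.image.psnr(tf.cast(hrImage, tf.float32),
    esrganGenPred, max_val=255.0)

print("pretrained ESRGAN PSNR:", tf.reduce_mean(psnrPre).numpy())
print("fully trained ESRGAN PSNR:", tf.reduce_mean(psnrFull).numpy())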
# plot the respective predictions
print("[INFO] plotting the ESRGAN predictions...")
(fig, axes) = subplots(nrows=config.INFER_BATCH_SIZE, ncols=4,
    figsize=(50, 50))

# plot the predicted images from low res to high res
for (ax, lowRes, esrPreIm, esrGanIm, highRes) in zip(axes, lrImage,
        esrganPreGenPred, esrganGenPred, hrImage):
    # plot the low resolution image
    ax[0].imshow(array_to_img(lowRes))
    ax[0].set_title("Low Resolution Image")

    # plot the pretrained ESRGAN image
    ax[1].imshow(array_to_img(esrPreIm))
    ax[1].set_title("ESRGAN Pretrained")

    # plot the ESRGAN image
    ax[2].imshow(array_to_img(esrGanIm))
    ax[2].set_title("ESRGAN")

    # plot the high resolution image
    ax[3].imshow(array_to_img(highRes))
    ax[3].set_title("High Resolution Image")

# check whether output image directory exists, if it doesn't, then
# create it
if not os.path.exists(config.BASE_IMAGE_PATH):
    os.makedirs(config.BASE_IMAGE_PATH)

# serialize the results to disk
print("[INFO] saving the ESRGAN predictions to disk...")
fig.savefig(config.GRID_IMAGE_PATH)

# plot the zoomed in images
zoom_into_images(esrganPreGenPred[0], "ESRGAN Pretrained")
zoom_into_images(esrganGenPred[0], "ESRGAN")
To visualize our results, subplots are initialized on Lines 81 and 82. Then we loop over the batch and plot the low-resolution image, pretrained GAN output, ESRGAN output, and actual high-resolution image for comparison (Lines 85-101).
Visualizations of the ESRGAN
Figures 3 and 4 show us the final predicted images from the pretrained ESRGAN and the fully trained ESRGAN, respectively.
The outputs of these two models are visually indistinguishable. However, the results are better than last week’s SRGAN outputs, even though the ESRGAN was trained for fewer epochs.
The zoomed-in patches show the intricate sharpness of the pixel-level information the ESRGAN recovers, showing that the enhancements made to the SRGAN recipe have worked quite well.
Summary
SRGANs have already left a great impression with their results, outperforming several existing super-resolution algorithms. Building enhancements on that foundation is something the whole deep learning community greatly appreciates.
The additions are extremely well thought out, and the results speak for themselves. Our ESRGAN achieves great results despite being trained for far fewer epochs, which fits ESRGAN’s focus on efficiency very well.
GANs continue to impress us, and to this day, new domains are being taken on using GANs. In today’s project, we dealt with approaches that can be used to improve the end result.
Citation Information
Chakraborty, D. “Enhanced Super-Resolution Generative Adversarial Networks (ESRGAN),” PyImageSearch, P. Chugh, A. R. Gosthipaty, S. Huot, K. Kidriavsteva, R. Raha, and A. Thanki, eds., 2022, https://pyimg.co/jt2cb
@incollection{Chakraborty_2022_ESRGAN,
  author = {Devjyoti Chakraborty},
  title = {Enhanced Super-Resolution Generative Adversarial Networks {(ESRGAN)}},
  booktitle = {PyImageSearch},
  editor = {Puneet Chugh and Aritra Roy Gosthipaty and Susan Huot and Kseniia Kidriavsteva and Ritwik Raha and Abhishek Thanki},
  year = {2022},
  note = {https://pyimg.co/jt2cb},
}