In this post, we implement two GAN variants, the Wasserstein GAN (WGAN) and the Wasserstein GAN with Gradient Penalty (WGAN-GP), to address the training instability discussed in my previous post, GAN Training Challenges: DCGAN for Color Images. We will train the WGAN and WGAN-GP models to generate colorful 64×64 anime faces.
WGAN models require diverse and extensive training data to generate high-quality anime faces. A well-curated dataset is crucial in training these models to capture the nuances of anime art styles.
This is the fourth post of our GAN tutorial series:
- Intro to Generative Adversarial Networks (GANs)
- Get Started: DCGAN for Fashion-MNIST
- GAN Training Challenges: DCGAN for Color Images
- Anime Faces with WGAN and WGAN-GP (this tutorial)
We will first walk through a WGAN tutorial step-by-step focusing on the new concepts introduced by the WGAN paper. Then we discuss how to improve WGAN with a few changes to make WGAN-GP.
Wasserstein GAN
The Wasserstein GAN (WGAN) was introduced in the paper Wasserstein GAN. Its main contribution was the use of the Wasserstein loss to address GAN training instability, which was a major breakthrough for GAN training.
Recall in DCGAN, when the discriminator is too weak or too strong, it won’t give the generator useful feedback for making improvements. Training longer doesn’t necessarily make the DCGAN model better.
With WGAN, these training issues can be solved with the new Wasserstein loss: we no longer need a careful balance in the training of the discriminator and the generator, or a careful design of the network architecture. WGAN has linear gradients that are continuous and differentiable almost everywhere (Figure 1). This solves the vanishing gradient problem of regular GAN training.
Here are a few new concepts or key changes introduced in the WGAN paper:
- Wasserstein distance (or Earth mover’s distance): measures the effort needed to transform one distribution into another.
- Wasserstein loss: a new loss function that measures the Wasserstein distance.
- The discriminator is now called a critic in WGAN. Instead of training a discriminator (a binary classifier) to tell whether an image is real or fake (generated), we train a critic that outputs a number.
- The critic must meet the Lipschitz constraint for the Wasserstein loss to work.
- WGAN uses weight clipping to enforce the 1-Lipschitz constraint.
As we implement each new GAN architecture, I will highlight the changes compared with a previous GAN variant to help you learn the new concepts. Table 1 summarizes the key changes needed for updating a DCGAN to a WGAN:
Now let’s walk through the code to implement these changes in WGAN with TensorFlow 2 / Keras. While following the tutorial below, please refer to the WGAN Colab notebook here for the complete code.
Setup
First, we make sure to set the runtime of the Colab hardware accelerator to GPU. Then we import all the libraries needed (TensorFlow 2, Keras, Matplotlib, etc.).
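As a rough sketch, the imports used throughout this tutorial might look like the following (check the Colab notebook for the exact list):

import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers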
Prepare the Data
We will train the WGAN with a dataset called Anime Face Dataset from Kaggle, which is a collection of anime faces scraped from www.getchu.com. There are 63,565 small color images, which will be resized to 64×64 for training.
To download data from Kaggle, you will need to provide your Kaggle credential. You could either upload the Kaggle .json file to Colab or put your Kaggle user name and key in the notebook. We chose the latter option.
os.environ['KAGGLE_USERNAME'] = "enter-your-own-user-name"
os.environ['KAGGLE_KEY'] = "enter-your-own-key"
Download and unzip the data to a directory called datasets.
!kaggle datasets download -d splcher/animefacedataset -p datasets
!unzip datasets/animefacedataset.zip -d datasets/
After downloading and unzipping the data, we set a directory where the images are.
anime_data_dir = "/content/datasets/images"
Then we use the Keras utility function image_dataset_from_directory to create a tf.data.Dataset from the images in the directory, which will be used for training the model later on. We specify an image size of 64×64 and a batch size of 256.
train_images = tf.keras.utils.image_dataset_from_directory(
    anime_data_dir, label_mode=None, image_size=(64, 64), batch_size=256)
Let’s visualize one random training image.
image_batch = next(iter(train_images))
random_index = np.random.choice(image_batch.shape[0])
random_image = image_batch[random_index].numpy().astype("int32")
plt.axis("off")
plt.imshow(random_image)
plt.show()
Here is what this random training image looks like in Figure 2:
Same as before, we normalize the images to the range of [-1, 1] because the generator's final layer activation uses tanh. Finally, we apply the normalization by using the map function of the tf.data.Dataset with a lambda function.
train_images = train_images.map(lambda x: (x - 127.5) / 127.5)
The Generator
There is no change in the WGAN generator architecture, which is the same as in DCGAN. We create the generator architecture with the Keras Sequential API in the build_generator function. Refer to my previous two DCGAN posts, DCGAN for Fashion-MNIST and DCGAN for Color Images, for the details of how to create the generator architecture.
After defining the generator architecture in the build_generator() function, we build the generator model with generator = build_generator() and call generator.summary() to visualize the model architecture.
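If you don't have the DCGAN code handy, here is a minimal sketch of what build_generator could look like for 64×64 color images. The filter counts and layer order are illustrative assumptions, not the notebook's exact code:

def build_generator(latent_dim=100):
    # DCGAN-style generator sketch: project the latent vector, then upsample to 64x64x3
    model = keras.Sequential(name="generator")
    model.add(layers.Dense(8 * 8 * 256, input_dim=latent_dim))
    model.add(layers.Reshape((8, 8, 256)))
    # Upsample 8x8 -> 16x16 -> 32x32 -> 64x64
    for filters in (128, 64, 32):
        model.add(layers.Conv2DTranspose(filters, (4, 4), strides=(2, 2), padding="same"))
        model.add(layers.BatchNormalization())
        model.add(layers.ReLU())
    # tanh activation so the output matches the [-1, 1] image range
    model.add(layers.Conv2D(3, (4, 4), padding="same", activation="tanh"))
    return model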
The Critic
In WGAN, instead of a discriminator that performs binary classification of real and fake images, we have a critic that assigns a score measuring the Wasserstein distance. Note the critic's output is now a score instead of a probability. The critic is also constrained by a 1-Lipschitz continuity condition.
There are quite a few changes here:
- Rename discriminator to critic
- Use weight clipping to enforce 1-Lipschitz continuity on the critic
- Change the critic's activation function from sigmoid to linear
Rename discriminator to critic
If you start with the DCGAN code, you will need to rename the discriminator to critic. You can use the “Find and replace” feature in Colab to make all the updates.
So now we have a function called build_critic instead of build_discriminator.
Weight clipping
WGAN enforces the 1-Lipschitz constraint by using weight clipping, which we implement by subclassing keras.constraints.Constraint. Refer to the Keras documentation on layer weight constraints for details. Here is how we create the WeightClipping class:
class WeightClipping(tf.keras.constraints.Constraint):
    def __init__(self, clip_value):
        self.clip_value = clip_value

    def __call__(self, weights):
        return tf.clip_by_value(weights, -self.clip_value, self.clip_value)

    def get_config(self):
        return {'clip_value': self.clip_value}
Then in the build_critic function, we create a constraint of [-0.01, 0.01] with the WeightClipping class.
constraint = WeightClipping(0.01)
Now we add kernel_constraint = constraint to all the Conv2D layers of the critic. For example:
model.add(layers.Conv2D(64, (4, 4), padding="same", strides=(2, 2), kernel_constraint = constraint, input_shape=input_shape))
Linear activation
In the last layer of the critic, we update the activation from sigmoid to linear.
model.add(layers.Dense(1, activation="linear"))
Please note that in Keras, the Dense layer by default has linear activation, so we could have omitted the activation="linear" part and written the code like this:
model.add(layers.Dense(1))
I left activation="linear" there to make it clear we are changing from sigmoid to linear activation when updating a DCGAN to a WGAN.
Now that we have defined the model architecture in the build_critic function, let's build the critic model with critic = build_critic(64, 64, 3) and call critic.summary() to visualize the critic model architecture.
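For reference, a sketch of build_critic with all three changes in place (the rename, the WeightClipping constraint on the conv layers, and the linear output) might look like the following; the filter counts and the use of batch norm here are assumptions based on a typical DCGAN discriminator, not a verbatim copy of the notebook:

def build_critic(height, width, depth, alpha=0.2):
    # Clip the conv kernel weights to [-0.01, 0.01] to enforce the 1-Lipschitz constraint
    constraint = WeightClipping(0.01)

    model = keras.Sequential(name="critic")
    model.add(keras.Input(shape=(height, width, depth)))
    for filters in (64, 128, 256):
        model.add(layers.Conv2D(filters, (4, 4), strides=(2, 2), padding="same",
                                kernel_constraint=constraint))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(alpha))
    model.add(layers.Flatten())
    # Linear activation: the critic outputs an unbounded score, not a probability
    model.add(layers.Dense(1, activation="linear"))
    return model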
The WGAN Model
We define the WGAN model architecture by subclassing keras.Model and overriding train_step to define the custom training loops.
There are a few changes in this section for WGAN:
- Update the critic more frequently than the generator
- No more image labels for the critic
- Use Wasserstein loss instead of Binary Crossentropy (BCE) loss
Update the critic more often than the generator
Per the paper's recommendation, we update the critic 5 times for every generator update. To achieve this, we pass in an additional argument called critic_extra_steps to __init__ of the WGAN class.
def __init__(self, critic, generator, latent_dim, critic_extra_steps):
    ...
    self.c_extra_steps = critic_extra_steps
    ...
Then in train_step(), we use a for loop to apply the extra training steps.
for i in range(self.c_extra_steps):
    # Step 1. Train the critic
    ...

# Step 2. Train the generator
...
Image labels
Depending on how we write the Wasserstein loss functions, we could either 1) assign ones as the labels of the real images and negative ones as the labels of the fake images, or 2) not assign any labels at all.
Here is a brief explanation of the two options. When using labels, the Wasserstein loss is calculated as tf.reduce_mean(y_true * y_pred). If we write the critic loss as the loss on real images plus the loss on fake images, and the generator loss on fake images only, then it leads to tf.reduce_mean(1 * pred_real - 1 * pred_fake) for the critic (the quantity the critic tries to maximize; as a loss to minimize, it is negated, as in the d_wasserstein_loss function below) and -tf.reduce_mean(pred_fake) for the generator loss.
Note that the critic's objective is not to assign a label of 1 or -1; instead, it tries to maximize the difference between its predictions on real images and its predictions on fake images. So in the case of the Wasserstein loss, the labels don't really matter much.
So we choose the latter option of not assigning labels, and you will see that all the code for real and fake labels has been removed.
Wasserstein loss
The Wasserstein loss functions for the critic and the generator get passed in through model.compile:
def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
    super(WGAN, self).compile()
    ...
    self.d_loss_fn = d_loss_fn
    self.g_loss_fn = g_loss_fn
Then in train_step, we use these functions to calculate the critic loss and generator loss, respectively, during training.
def train_step(self, real_images):
    for i in range(self.c_extra_steps):
        # Step 1. Train the critic
        ...
        d_loss = self.d_loss_fn(pred_real, pred_fake)  # critic loss

    # Step 2. Train the generator
    ...
    g_loss = self.g_loss_fn(pred_fake)  # generator loss
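To make the structure concrete, here is a simplified sketch of what the full train_step might look like with the GradientTape-based updates filled in. Attribute names such as self.d_optimizer, self.g_optimizer, and self.latent_dim are assumptions based on the compile override above, so treat this as an illustration rather than the notebook's exact code:

def train_step(self, real_images):
    batch_size = tf.shape(real_images)[0]

    # Step 1. Train the critic for c_extra_steps (5) steps per generator update
    for i in range(self.c_extra_steps):
        noise = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            pred_real = self.critic(real_images, training=True)
            fake_images = self.generator(noise, training=True)
            pred_fake = self.critic(fake_images, training=True)
            d_loss = self.d_loss_fn(pred_real, pred_fake)  # critic loss
        d_grads = tape.gradient(d_loss, self.critic.trainable_variables)
        self.d_optimizer.apply_gradients(zip(d_grads, self.critic.trainable_variables))

    # Step 2. Train the generator once
    noise = tf.random.normal(shape=(batch_size, self.latent_dim))
    with tf.GradientTape() as tape:
        fake_images = self.generator(noise, training=True)
        pred_fake = self.critic(fake_images, training=True)
        g_loss = self.g_loss_fn(pred_fake)  # generator loss
    g_grads = tape.gradient(g_loss, self.generator.trainable_variables)
    self.g_optimizer.apply_gradients(zip(g_grads, self.generator.trainable_variables))

    return {"d_loss": d_loss, "g_loss": g_loss}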
Keras Callback for Training Monitoring
Same code as DCGAN with no change: we override the Keras Callback class to monitor and visualize the generated images during training.
class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img, latent_dim):
        ...

    def on_epoch_end(self, epoch, logs=None):
        ...

    def on_train_end(self, logs=None):
        ...
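If you don't have the DCGAN version handy, a minimal sketch of such a callback could look like this; saving images to disk, the 4×4 grid, and the rescaling details are illustrative choices, not necessarily those of the Colab notebook:

class GANMonitor(keras.callbacks.Callback):
    def __init__(self, num_img=16, latent_dim=100):
        self.num_img = num_img
        self.latent_dim = latent_dim
        # Fixed noise so we can watch the same latent points evolve over epochs
        self.seed = tf.random.normal([num_img, latent_dim])

    def on_epoch_end(self, epoch, logs=None):
        generated = self.model.generator(self.seed, training=False)
        # Rescale from [-1, 1] back to [0, 1] for display
        generated = (generated * 127.5 + 127.5) / 255.0
        plt.figure(figsize=(4, 4))
        for i in range(self.num_img):  # assumes num_img is a perfect square (here 16)
            plt.subplot(4, 4, i + 1)
            plt.imshow(generated[i].numpy())
            plt.axis("off")
        plt.savefig(f"generated_epoch_{epoch:03d}.png")
        plt.close()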
Compile and Train WGAN
Putting together the WGAN model
We put together the wgan model with the WGAN class defined above. Note that we need to set the extra training steps for the critic to 5, per the WGAN paper.
wgan = WGAN(critic=critic,
            generator=generator,
            latent_dim=LATENT_DIM,
            critic_extra_steps=5)  # UPDATE for WGAN
Wasserstein loss functions
As mentioned before, the main change in WGAN is the usage of Wasserstein loss. Here is how to calculate Wasserstein loss for the critic and the generator — by defining custom loss functions in Keras.
# Wasserstein loss for the critic
def d_wasserstein_loss(pred_real, pred_fake):
    real_loss = tf.reduce_mean(pred_real)
    fake_loss = tf.reduce_mean(pred_fake)
    return fake_loss - real_loss

# Wasserstein loss for the generator
def g_wasserstein_loss(pred_fake):
    return -tf.reduce_mean(pred_fake)
Compile WGAN
Now we compile the wgan model with the RMSProp optimizer and a learning rate of 0.00005, per the WGAN paper.
LR = 0.00005  # UPDATE for WGAN: learning rate per WGAN paper

wgan.compile(
    d_optimizer=keras.optimizers.RMSprop(learning_rate=LR, clipvalue=1.0, decay=1e-8),  # UPDATE for WGAN: use RMSprop instead of Adam
    g_optimizer=keras.optimizers.RMSprop(learning_rate=LR, clipvalue=1.0, decay=1e-8),  # UPDATE for WGAN: use RMSprop instead of Adam
    d_loss_fn=d_wasserstein_loss,
    g_loss_fn=g_wasserstein_loss
)
Note that in DCGAN we use keras.losses.BinaryCrossentropy(), while for WGAN we use the custom d_wasserstein_loss and g_wasserstein_loss functions defined above. These two loss functions get passed in through model.compile(). They will be used in the custom training loop, as discussed in the train_step override section above.
Train the WGAN model
Now we simply call model.fit() to train the wgan model!
NUM_EPOCHS = 50  # number of epochs

wgan.fit(train_images, epochs=NUM_EPOCHS, callbacks=[GANMonitor(num_img=16, latent_dim=LATENT_DIM)])
Wasserstein GAN with Gradient Penalty
While WGAN improves training stability with the Wasserstein loss, even the paper itself admits that “weight clipping is a clearly terrible way to enforce a Lipschitz constraint.” A large clipping parameter can lead to slow training and prevent the critic from reaching optimality, while a clipping parameter that is too small can easily lead to vanishing gradients, the exact problem WGAN was proposed to solve.
The Wasserstein GAN with Gradient Penalty (WGAN-GP) was introduced in the paper Improved Training of Wasserstein GANs. It further improves WGAN by using a gradient penalty instead of weight clipping to enforce the 1-Lipschitz constraint on the critic.
We only need to make a few changes to update a WGAN to a WGAN-GP:
- Remove batch norm from the critic’s architecture.
- Use gradient penalty instead of weight clipping to enforce the Lipschitz constraint.
- Use Adam optimizer (α = 0.0002, β1 = 0.5, β2 = 0.9) instead of RMSProp.
Please refer to the WGAN-GP Colab notebook here for the complete code example. In this tutorial, we discuss only the incremental changes for updating a WGAN to a WGAN-GP.
Add Gradient Penalty
The gradient penalty penalizes the critic when the norm of its gradients (taken with respect to interpolated images) deviates from 1. Here is how we calculate it in Keras:
def gradient_penalty(self, batch_size, real_images, fake_images):
    """Calculates the gradient penalty.

    The gradient penalty is calculated on an interpolated image
    and added to the critic loss.
    """
    alpha = tf.random.normal([batch_size, 1, 1, 1], 0.0, 1.0)
    diff = fake_images - real_images
    # 1. Create the interpolated image
    interpolated = real_images + alpha * diff

    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolated)
        # 2. Get the critic's output for the interpolated image
        pred = self.critic(interpolated, training=True)

    # 3. Calculate the gradients w.r.t. the interpolated image
    grads = gp_tape.gradient(pred, [interpolated])[0]
    # 4. Calculate the norm of the gradients
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    # 5. Calculate the gradient penalty
    gradient_penalty = tf.reduce_mean((norm - 1.0) ** 2)
    return gradient_penalty
Then in train_step, we calculate the gradient penalty and add it to the original critic loss. Note the penalty weight (or coefficient λ) controls the magnitude of the penalty, and it's set to 10 per the WGAN-GP paper.
gp = self.gradient_penalty(batch_size, real_images, fake_images)
d_loss = self.d_loss_fn(pred_real, pred_fake) + gp * self.gp_weight
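The gp_weight attribute used above has to be set when the model is constructed; here is a sketch of how the __init__ might be extended (the class name and the extra argument are illustrative assumptions, not the notebook's exact code):

class WGAN_GP(keras.Model):
    def __init__(self, critic, generator, latent_dim, critic_extra_steps, gp_weight=10.0):
        super().__init__()
        self.critic = critic
        self.generator = generator
        self.latent_dim = latent_dim
        self.c_extra_steps = critic_extra_steps
        self.gp_weight = gp_weight  # penalty coefficient lambda = 10 per the WGAN-GP paper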
Remove batchnorm
While batch normalization helps stabilize GAN training, it doesn't work with the gradient penalty, because we penalize the norm of the critic's gradient with respect to each input independently, not with respect to the entire batch. So we need to remove the batch norm code from the critic's model architecture.
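Concretely, a single downsampling block in the WGAN-GP critic might then look something like this (a sketch with illustrative filter counts); note there is also no kernel_constraint, since the gradient penalty replaces weight clipping:

# WGAN-GP critic block sketch: no BatchNormalization and no weight clipping constraint
model.add(layers.Conv2D(128, (4, 4), strides=(2, 2), padding="same"))
model.add(layers.LeakyReLU(0.2))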
Adam Optimizer instead of RMSProp
DCGAN uses the Adam optimizer, and for WGAN, we switch to the RMSProp optimizer. Now for WGAN-GP, we switch back to Adam optimizer with a learning rate of 0.0002 per the WGAN-GP paper recommendation.
LR = 0.0002  # WGAN-GP paper recommends lr of 0.0002
d_optimizer = keras.optimizers.Adam(learning_rate=LR, beta_1=0.5, beta_2=0.9)
g_optimizer = keras.optimizers.Adam(learning_rate=LR, beta_1=0.5, beta_2=0.9)
We compile and train the WGAN-GP model for 50 epochs, and we observe more stable training and better image quality generated by the model.
Figure 3 compares the real (training) images and images generated by WGAN and WGAN-GP, respectively.
Both WGAN and WGAN-GP improve training stability. The tradeoff is that their training converges more slowly than DCGAN's, and the image quality may be slightly worse; however, with the improved training stability, we can use much more complex generator network architectures, which results in improved image quality. Many later GAN variants adopted the Wasserstein loss and gradient penalty by default, for example, ProGAN and StyleGAN. Even the TF-GAN library uses the Wasserstein loss by default.
Summary
In this post, you learned how to use WGAN and WGAN-GP to improve GAN training stability. You learned about incremental changes moving from a DCGAN to WGAN, then from a WGAN to WGAN-GP with TensorFlow 2 / Keras. You learned how to generate anime faces with WGAN and WGAN-GP.
Citation Information
Maynard-Reid, M. “Anime Faces with WGAN and WGAN-GP,” PyImageSearch, 2022, https://pyimg.co/9avys
@article{Maynard-Reid_2022_Anime_Faces,
  author = {Margaret Maynard-Reid},
  title = {Anime Faces with {WGAN} and {WGAN-GP}},
  journal = {PyImageSearch},
  year = {2022},
  note = {https://pyimg.co/9avys},
}