Table of Contents
- Learning JAX in 2023: Part 1 — The Ultimate Guide to Accelerating Numerical Computation and Machine Learning
- 🙌🏻 Introduction
- Configuring Your Development Environment
- Having Problems Configuring Your Development Environment?
- 🤔 What Is JAX?
- 👀 What Is JAX (revisited)?
- ⬇️ Import JAX
- 📚 Understanding the Components: API Layering of JAX
- 💯 Numerical Computation in JAX
- Summary
Learning JAX in 2023: Part 1 — The Ultimate Guide to Accelerating Numerical Computation and Machine Learning
In this tutorial, you will learn the basics of the JAX library, including how to install and use it to perform numerical computation and machine learning tasks using NumPy-like syntax and GPU acceleration.
This lesson is the 1st of a 3-part series on Learning JAX in 2023:
- Learning JAX in 2023: Part 1 — The Ultimate Guide to Accelerating Numerical Computation and Machine Learning (today’s tutorial)
- Learning JAX in 2023: Part 2 — JAX’s Power Tools
  grad, jit, vmap, and pmap
- Learning JAX in 2023: Part 3 — A Step-by-Step Guide to Training Your First Machine Learning Model with JAX
To learn how to get started with JAX, just keep reading.
🙌🏻 Introduction
As deep learning practitioners, it can be tough to keep up with all the new developments. New academic papers and models are always coming out; there’s a new framework to learn every few years. Recently, many people have been talking about JAX, a new numerical computing library that can make your code run faster.
Many people have asked us to create a course about JAX, so we decided to take on the challenge. In this series, we’ll not only teach you about JAX, but also how to learn and understand new concepts. We’ll keep the language simple and avoid using jargon, but if you need help understanding anything, please let us know, and we’ll do our best to help.
Once you complete this course, you’ll be able to understand and work with any code written in JAX/FLAX. Major companies like Google Research, Hugging Face, and OpenAI are already using JAX heavily, so this is a valuable skill to have. Let’s get started and learn all about it!
Configuring Your Development Environment
To follow this guide, you need to have the JAX library installed on your system. JAX is written in pure Python, but it depends on XLA, which needs to be installed as the jaxlib package (from: jax repository).
Luckily, jaxlib and jax are pip-installable:
$ pip install jaxlib
$ pip install numpy
$ pip install autograd
$ pip install jax
If you need help configuring your development environment for OpenCV, we highly recommend that you read our pip install OpenCV guide — it will have you up and running in a matter of minutes.
Having Problems Configuring Your Development Environment?
All that said, are you:
- Short on time?
- Learning on your employer’s administratively locked system?
- Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
- Ready to run the code right now on your Windows, macOS, or Linux system?
Then join PyImageSearch University today!
Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.
And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!
🤔 What Is JAX?
JAX is the combination of autograd and XLA. Before diving into the nitty-gritty of JAX, let us look into autograd and XLA briefly.
Note: The sections on autograd and XLA are meant to provide a more holistic understanding of the principles on which JAX was built. They are optional for getting started with JAX.
autograd
Gradients run the deep learning world quite literally. For example, we can compute the gradients (derivatives) of an equation in the following ways:
- Manual: We use our calculus knowledge and derive the derivatives by hand. The problem with this approach is that it is manual. It would take a lot of time for a Deep Learning researcher to derive the model’s derivatives by hand.
- Symbolic: We can obtain the derivatives via symbols and a program that can mimic the manual process. The problem with this approach is termed expression swell. Here the derivatives of a particular expression are exponentially longer (think chain rule) than the expression itself. This becomes quite difficult to track.
- Numeric: Here, we use the finite differences method to approximate the derivatives.
- Automatic: The star ⭐️ of the show.
Automatic differentiation (autodiff) is the type of differentiation we all love and use when training our deep neural networks. In a previous tutorial, we covered the math and code base to understand how autodiff works. If you are interested in how autodiff works, we urge you to read the linked tutorial first and then return to the current one.
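To make the numeric approach from the list above concrete, here is a minimal sketch in plain Python. The helper name central_difference and the step size h are assumptions chosen for illustration; they do not come from any library.

```python
# A minimal sketch of numeric differentiation with central differences.
# `central_difference` and the step size `h` are hypothetical choices.

def square(x):
    # The function we want to differentiate: f(x) = x^2
    return x ** 2


def central_difference(f, x, h=1e-5):
    # Approximate f'(x) from two nearby function evaluations
    return (f(x + h) - f(x - h)) / (2.0 * h)


point = 4.0
numeric_grad = central_difference(square, point)
analytic_grad = 2.0 * point  # derivative of x^2 worked out by hand

print(f"numeric  => {numeric_grad}")
print(f"analytic => {analytic_grad}")
```

The numeric estimate is only an approximation whose quality depends on the step size, which is one reason automatic differentiation is preferred for training neural networks.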
autograd is a Python package that performs automatic differentiation on native Python and NumPy code. The code base is fairly simple.
Two points to note here:
- There is a light wrapper, autograd.numpy, around the native NumPy codebase. This allows users to use NumPy-like semantics while harnessing the power of automatic differentiation.
- autograd.grad and autograd.elementwise_grad help with the actual automatic differentiation.
We will now look at a few code snippets to demonstrate how autograd works.
Let’s start by importing the necessary packages.
Note: autograd’s numpy module is imported as anp to distinguish it from np (original NumPy) and jnp (jax numpy).
# Import the necessary packages
from autograd import numpy as anp
from autograd import grad
from autograd import elementwise_grad as egrad
We will first define a simple function and then compute the function’s gradient using automatic differentiation. Next, we find the function’s gradient at a particular point (scalar) and for a list of points (vector).
We build a function that takes a value and returns its square.
def func_square(x):
    # Return the square of the input
    return x**2

# Build a scalar input and pass it to the
# square function
x = 4.0
squared_x = func_square(x=x)
print(f"x => {x}\nx**2 => {squared_x}")
>>> x => 4.0
>>> x**2 => 16.0
The function that we have defined is f(x) = x^2. We know that the function’s derivative is f'(x) = 2x.
We can obtain this derivative by applying the autograd.grad function. Let us see how that works.
# Compute the derivative of the square function
grad_func = grad(func_square)
point = 1.0

# Retrieve the gradient of the function at a particular point
print(f"Gradient of square func at {point} => {grad_func(point)}")
>>> Gradient of square func at 1.0 => 2.0
The above code snippet allows us to calculate the gradient of a scalar function. Next, let us see how to do the same with vectors.
# Let's pass a vector to the square function
vector = anp.arange(1, 10, dtype=anp.float32)
out_vector = func_square(vector)

# Iterate over the vector and its output
for v, o in zip(vector, out_vector):
    print(f"Value at point {v} => {o}")
>>> Value at point 1.0 => 1.0
>>> Value at point 2.0 => 4.0
>>> Value at point 3.0 => 9.0
>>> Value at point 4.0 => 16.0
>>> Value at point 5.0 => 25.0
>>> Value at point 6.0 => 36.0
>>> Value at point 7.0 => 49.0
>>> Value at point 8.0 => 64.0
>>> Value at point 9.0 => 81.0
What happens if we send the entire vector to our gradient function?
try:
    out_vector = grad_func(vector)
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")
>>> Type of exception => TypeError
>>> Exception => Grad only applies to real scalar-output functions. Try jacobian, elementwise_grad or holomorphic_grad.
Let us do what the exception tells us to do and use elementwise_grad for the vectorization process.
# Let us now vectorize the gradient code
egrad_func = egrad(func_square)

try:
    out_vector = egrad_func(vector)
    for v, o in zip(vector, out_vector):
        print(f"Grad at point {v} => {o}")
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")
With the elementwise_grad function, we could vectorize the gradient function.
>>> Grad at point 1.0 => 2.0
>>> Grad at point 2.0 => 4.0
>>> Grad at point 3.0 => 6.0
>>> Grad at point 4.0 => 8.0
>>> Grad at point 5.0 => 10.0
>>> Grad at point 6.0 => 12.0
>>> Grad at point 7.0 => 14.0
>>> Grad at point 8.0 => 16.0
>>> Grad at point 9.0 => 18.0
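For comparison, the same element-wise gradient can be sketched in JAX itself. This is only one possible way to do it, assuming that composing jax.vmap with jax.grad is an acceptable stand-in for autograd's elementwise_grad; the variable names are chosen for illustration.

```python
# A sketch of an element-wise gradient in JAX, assuming vmap(grad(f))
# as a stand-in for autograd's elementwise_grad.
import jax
import jax.numpy as jnp


def func_square(x):
    # Return the square of the input
    return x ** 2


vector = jnp.arange(1, 10, dtype=jnp.float32)

# grad handles one scalar input; vmap maps it over every element
elementwise_grad_func = jax.vmap(jax.grad(func_square))
grads = elementwise_grad_func(vector)

for v, g in zip(vector, grads):
    print(f"Grad at point {v} => {g}")
```

The values should match the autograd results above, since the derivative of the square function is 2x at every point.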
Automatic differentiation is at the very heart of Deep Learning. Any framework that facilitates differentiable programming allows users to navigate and exploit patterns in data through backpropagation.
Learn about automatic differentiation and differentiable programming in our blog post series:
- Automatic Differentiation Part 1: Understanding the Math
- Automatic Differentiation Part 2: Implementation Using Micrograd
XLA
It is safe to say that the fields of Deep Learning (DL) and Machine Learning (ML) consist of an enormous amount of Linear Algebra. All computations from start to finish are mostly Linear Algebra.
What if we told you there is a compiler in town that can make Linear Algebra operations more efficient?
Enter XLA: XLA stands for Accelerated Linear Algebra. It is a domain-specific compiler that accelerates linear algebra operations. The source program is device agnostic: the same code can be compiled for the CPU, GPU, or TPU with no code change.
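As a small illustration of what compiling with XLA looks like from the JAX side, the sketch below wraps a function with jax.jit and checks that the compiled version agrees with the uncompiled one. The function scaled_matmul and the array shapes are assumptions made up for this example.

```python
# A minimal sketch: compile a function through XLA with jax.jit.
# `scaled_matmul` and the shapes below are hypothetical.
import jax
import jax.numpy as jnp


def scaled_matmul(a, b):
    # A small linear algebra workload: matmul followed by a nonlinearity
    return jnp.tanh(a @ b) * 0.5


compiled_matmul = jax.jit(scaled_matmul)

a = jnp.ones((4, 3), dtype=jnp.float32)
b = jnp.ones((3, 2), dtype=jnp.float32)

eager_out = scaled_matmul(a, b)
jit_out = compiled_matmul(a, b)

# The backend is whichever device JAX found (cpu, gpu, or tpu)
print(f"backend => {jax.default_backend()}")
print(f"results match => {bool(jnp.allclose(eager_out, jit_out))}")
```

The same script should run unchanged on whichever backend is available, which is the practical meaning of device agnostic here.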
👀 What Is JAX (revisited)?
Understanding what autograd and XLA do gives us a basic intuition about JAX.
JAX is a high-performance, numerical computing library incorporating composable function transformations.
—Why You Should (or Shouldn’t) be Using Google’s JAX in 2023
That sounds intimidating, but think about it again. Thanks to autograd, the NumPy-like API and automatic differentiation engine make JAX a very efficient numerical computing library.
The inclusion of the XLA compiler makes JAX a highly performant numerical computing library incorporating composable function transformations.
We will talk about what composable function transformation means in an upcoming blog post.
⬇️ Import JAX
Let’s talk about JAX by working on it hands-on. We have already installed JAX on our system. Now let’s import it to get started.
import jax
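A quick sanity check after importing can confirm which version is installed and which devices JAX can see. This is only a sketch; the printed values depend on your installation.

```python
# A quick check of the installed JAX version and visible devices.
import jax

jax_version = jax.__version__
devices = jax.devices()

print(f"JAX version => {jax_version}")
print(f"Devices => {devices}")
```

On a machine without an accelerator, the device list should contain a single CPU device.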
📚 Understanding the Components: API Layering of JAX
Before we start multiplying matrices and backpropagating on them, let us take a moment to understand the various components of JAX. While starting with a library, knowing its basic API design is always a good practice.
The version of JAX used when writing this tutorial is 0.3.25. The API of JAX is designed so that we have the high-level abstraction of jax.numpy and the low-level abstraction of jax.lax.
While jax.numpy is similar to the original NumPy package, jax.lax is a wrapper around Google’s XLA compiler.
Note: Did you notice that lax is an anagram of xla? 🤯
If you head over to the official documentation of JAX API, you will see several sub-packages and sub-topics with their APIs listed.
The most used APIs are the following:
- jax.numpy
- jax.lax
The topics that are most important in the API design paradigm are:
- Just-in-time compilation (jit)
- Automatic differentiation (grad)
- Vectorization (vmap)
- Parallelization (pmap)
We will discuss these topics and sub-packages with corresponding code snippets as we go through the tutorial. Let us import them first into our work environment.
import numpy as np
import jax
from jax import numpy as jnp
from jax import make_jaxpr
from jax import grad, jit, vmap, pmap
💯 Numerical Computation in JAX
This section will take us through the most used APIs of JAX: jax.numpy and jax.lax. Before diving in, we must note that JAX is not a Deep Learning (DL) framework. Instead, it is a numerical computation library; DL simply falls within the numerical computation paradigm.
For the ease of numerical computation, it has a NumPy API that mirrors the API of yet another very powerful numerical computation library (yes, you guessed it, NumPy 😁).
The thing that makes JAX stand out is its wrapper for the XLA compiler, jax.lax. The jax.numpy API is itself implemented in terms of jax.lax operations, which are lowered to XLA. This makes JAX code not only device agnostic but also JIT compilable.
Being device agnostic means that the same code can be run on different hardware (CPUs, GPUs, and TPUs). With the JIT compilation, the same code can run much faster and more efficiently. This is why JAX is referred to as NumPy on steroids.
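One way to peek at this layering is make_jaxpr, which prints the intermediate representation JAX traces from a function. The sketch below uses a made-up function named affine; the exact formatting of the printed output varies across JAX versions, so treat it as illustrative.

```python
# A sketch of inspecting how jax.numpy calls map to lower-level operations.
# `affine` is a hypothetical function used only for illustration.
import jax.numpy as jnp
from jax import make_jaxpr


def affine(x):
    # High-level jax.numpy calls
    return jnp.add(jnp.multiply(x, 2.0), 1.0)


jaxpr = make_jaxpr(affine)(3.0)
jaxpr_text = str(jaxpr)
print(jaxpr_text)
```

The printed representation should mention low-level operations such as mul and add, which correspond to the jax.lax layer underneath jax.numpy.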
jax.numpy
In this section, we learn how to write NumPy-like code using jax.numpy.
# Build an array of 0 to 9 with the `jax.numpy` API
array = jnp.arange(0, 10, dtype=jnp.int8)
print(f"array => {array}")
>>> array => [0 1 2 3 4 5 6 7 8 9]
This is a great thing to have. Anyone with a fair amount of NumPy knowledge does not need to learn a new API. We can port many programs built on NumPy to JAX simply by adding the extra j. Thanks to Python duck typing, jax.numpy can act as a near drop-in replacement for a lot of numpy code.
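The porting idea can be sketched with a function that receives the array module as an argument and runs unchanged on both libraries. The function normalize and its xp parameter are hypothetical names used only for this illustration.

```python
# A sketch of writing code once and running it with NumPy or jax.numpy.
# `normalize` and the `xp` parameter are hypothetical names.
import numpy as np
import jax.numpy as jnp


def normalize(xp, values):
    # `xp` is either the numpy module or the jax.numpy module
    centered = values - xp.mean(values)
    return centered / xp.std(values)


data = [1.0, 2.0, 3.0, 4.0]
numpy_result = normalize(np, np.array(data, dtype=np.float32))
jax_result = normalize(jnp, jnp.array(data, dtype=jnp.float32))

print(f"NumPy => {numpy_result}")
print(f"JAX   => {jax_result}")
```

Both calls should agree up to floating-point tolerance, because the two modules expose the same function names for these operations.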
Let’s now look into some differences. The first one is the type of the array object that jax.numpy returns.
print(type(array))
>>> <class 'jaxlib.xla_extension.DeviceArray'>
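The DeviceArray is the JAX equivalent of numpy.ndarray. However, the two are not exactly the same. For example, a DeviceArray can live on a CPU, GPU, or TPU, while a numpy.ndarray always lives in host (CPU) memory. Note that newer JAX releases (0.4 and later) report a unified jax.Array type here instead of DeviceArray.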
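Because the two array types are distinct, it is sometimes useful to convert between them. The following sketch uses np.asarray and jnp.asarray for the round trip; the variable names are assumptions for illustration.

```python
# A sketch of converting between JAX arrays and NumPy arrays.
import numpy as np
import jax.numpy as jnp

jax_values = jnp.arange(0, 10, dtype=jnp.int8)

# JAX array -> NumPy array (copies the data to host memory if needed)
as_numpy = np.asarray(jax_values)

# NumPy array -> JAX array
back_to_jax = jnp.asarray(as_numpy)

print(type(as_numpy).__name__)
print(type(back_to_jax).__name__)
```

The first printed name is ndarray, while the second depends on the installed JAX version.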
We have seen that the jax.numpy wrapper mirrors the NumPy Python library well, but there remain a few stark differences between the two.
A major one is that DeviceArrays are immutable, unlike numpy.ndarrays. We illustrate this using the following code snippet.
jax_array = jnp.arange(1, 10, dtype=jax.numpy.int8)
numpy_array = np.arange(1, 10).astype(np.int8)

try:
    numpy_array[2] = 2
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")

try:
    jax_array[2] = 2
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")
>>> Type of exception => TypeError
>>> Exception => '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
To work around this, JAX provides the at[].set() syntax.
Note: at[].set() does not modify the array in place; it returns a new DeviceArray with the requested change. This holds outside of JIT; in most cases within JIT-compiled code, the update will happen in place (from: GitHub Discussion).
This design decision was taken to make functions pure (we will discuss pure functions in an upcoming blog post).
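try:
    mutated_jax_array = jax_array.at[2].set(200)
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")

print(f"Original Array => {jax_array}")
print(f"Mutated Array => {mutated_jax_array}")
>>> Original Array => [1 2 3 4 5 6 7 8 9]
>>> Mutated Array => [ 1 2 -56 4 5 6 7 8 9]
Note that the updated value shows up as -56 rather than 200: the array’s dtype is int8, which can only hold values from -128 to 127, so 200 wraps around to 200 - 256 = -56.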
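Besides set(), the at[] property offers other functional updates such as add() and multiply(). The sketch below chains a few of them on a small int32 array (chosen here to avoid the int8 overflow seen above); the variable names are made up for illustration.

```python
# A sketch of functional updates with the .at[] property.
import jax.numpy as jnp

counts = jnp.zeros(5, dtype=jnp.int32)

bumped = counts.at[1].add(3)          # add 3 at index 1
replaced = bumped.at[4].set(7)        # set index 4 to 7
scaled = replaced.at[1].multiply(2)   # double the value at index 1

print(f"counts => {counts}")
print(f"scaled => {scaled}")
```

Each call returns a new array, so counts still holds all zeros after the updates, while scaled holds 0, 6, 0, 0, 7.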
Another key point to note is the indexing of tensors in JAX.
try:
    print("Indexing 1000th position of a NumPy array...")
    print(numpy_array[1000])
except Exception as ex:
    print(type(ex).__name__)
    print(ex)
>>> Indexing 1000th position of a NumPy array...
>>> IndexError
>>> index 1000 is out of bounds for axis 0 with size 9
try:
    print("Indexing 1000th position of a JAX array...")
    print(jax_array[1000])
except Exception as ex:
    print(type(ex).__name__)
    print(ex)
>>> Indexing 1000th position of a JAX array...
>>> 9
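👀 What happened here?
In JAX, an out-of-bounds index used for retrieval is clamped to the bounds of the array, so asking for position 1000 returns the last element (9) instead of raising an error. This is a little caveat that we need to take care of so that we do not see our code fail silently.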
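If silent clamping is a concern, two defensive options can be sketched: asking at[].get() for an explicit fill value, or validating the index before using it. The helper checked_get is a hypothetical function written for this example and is not part of JAX.

```python
# A sketch of two ways to guard against out-of-bounds reads.
# `checked_get` is a hypothetical helper, not a JAX API.
import jax.numpy as jnp

jax_array = jnp.arange(1, 10, dtype=jnp.int8)

# Option 1: request a fill value for out-of-bounds reads
filled = jax_array.at[1000].get(mode="fill", fill_value=-1)
print(f"Filled read => {filled}")


# Option 2: validate the index in plain Python before indexing
def checked_get(array, index):
    size = array.shape[0]
    if not -size <= index < size:
        raise IndexError(
            f"index {index} is out of bounds for axis 0 with size {size}"
        )
    return array[index]


try:
    checked_get(jax_array, 1000)
except IndexError as ex:
    print(f"{type(ex).__name__} => {ex}")
```

The plain Python check only works when the index is a concrete value, so it suits code outside of traced or JIT-compiled functions.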
jax.lax
Let’s talk a little bit about jax.lax now. While the NumPy API makes it easier for you to enter the world of JAX, jax.lax is what powers the library with all of its functionalities.
jax.lax is a library of primitive operations that underpins libraries such as jax.numpy.
(from: jax.lax module)
While jax.numpy is a high-level abstraction that makes it easier to code, jax.lax is more powerful but also comes with more constraints.
For instance, jax.lax does not support automatic type casting. This is demonstrated using the following code snippets.
# Checking the lenient `jax.numpy` API
try:
    print(jnp.add(1, 2.0))
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")
>>> 3.0
# Checking the stricter `jax.lax` API 😭
try:
    jax.lax.add(1, 2.0)
except Exception as ex:
    print(f"Type of exception => {type(ex).__name__}")
    print(f"Exception => {ex}")
>>> Type of exception => TypeError
>>> Exception => lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).
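There are two straightforward ways around the strict dtype check, sketched below: let jax.numpy promote the types, or cast both operands explicitly before calling jax.lax. The variable names are assumptions for illustration.

```python
# A sketch of two ways to add an int and a float in JAX.
import jax
import jax.numpy as jnp

# Option 1: jax.numpy promotes the integer to a float automatically
promoted = jnp.add(1, 2.0)

# Option 2: cast both operands to the same dtype before calling jax.lax
explicit = jax.lax.add(jnp.float32(1), jnp.float32(2.0))

print(f"jnp.add => {promoted}")
print(f"lax.add => {explicit}")
```

The explicit cast makes the dtype decision visible in the code, which is the trade-off jax.lax asks for in exchange for its lower-level control.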
Summary
Great job on completing the first part of the tutorial on JAX! In this tutorial, we covered the background and origins of JAX, specifically highlighting its parent libraries autograd and XLA. We also explored the API layering of JAX and delved into the details of the two most commonly used APIs: jax.numpy and jax.lax.
Now, it’s time to move on to the next part of the tutorial, which will focus on functional transformations in JAX. These transformations, such as grad, jit, vmap, and pmap, are essential tools in the JAX toolkit and allow you to optimize your code for better performance and efficiency.
Finally, in the third and final part of the tutorial, we will put everything we’ve learned to the test by training a model from scratch using JAX. This will be a great opportunity to apply the concepts and techniques covered in the first two parts of the tutorial and see the power of JAX in action. So buckle up because the next tutorial will be an exciting and resource-filled adventure!
We would like to acknowledge the detailed review and discussion from Jake Vanderplas.
References
- What is Automatic Differentiation?
- You don’t know JAX
- Why You Should (or Shouldn’t) be Using Google’s JAX in 2023
- The Sharp Bits 🔪 — JAX documentation
- Training a Simple Neural Network, with tensorflow/datasets Data Loading
- jax/README.md at main · google/jax · GitHub
- GitHub – google/jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
- JAX Crash Course – Accelerating Machine Learning code!
Citation Information
A. R. Gosthipaty and R. Raha. “Learning JAX in 2023: Part 1 — The Ultimate Guide to Accelerating Numerical Computation and Machine Learning,” PyImageSearch, P. Chugh, S. Huot, K. Kidriavsteva, and A. Thanki, eds., 2023, https://pyimg.co/uwe1j
@incollection{ARG-RR_2023_JAX1, author = {Aritra Roy Gosthipaty and Ritwik Raha}, title = {Learning {JAX} in 2023: Part 1 — The Ultimate Guide to Accelerating Numerical Computation and Machine Learning}, booktitle = {PyImageSearch}, editor = {Puneet Chugh and Susan Huot and Kseniia Kidriavsteva and Abhishek Thanki}, year = {2023}, url = {https://pyimg.co/uwe1j}, }