**Table of Contents**

- Learning JAX in 2023: Part 1 — The Ultimate Guide to Accelerating Numerical Computation and Machine Learning
- 🙌🏻 Introduction
- Configuring Your Development Environment
- Having Problems Configuring Your Development Environment?
- 🤔 What Is JAX?
- 👀 What Is JAX (revisited)?
- ⬇️ Import JAX
- 📚 Understanding the Components: API Layering of JAX
- 💯 Numerical Computation in JAX
- Summary

**Learning JAX in 2023: Part 1 — The Ultimate Guide to Accelerating Numerical Computation and Machine Learning**

In this tutorial, you will learn the basics of the JAX library, including how to install and use it to perform numerical computation and machine learning tasks using NumPy-like syntax and GPU acceleration.

This lesson is the 1st of a 3-part series on Learning JAX in 2023:

- ***Learning JAX in 2023: Part 1 — The Ultimate Guide to Accelerating Numerical Computation and Machine Learning***
- *Learning JAX in 2023: Part 2 — JAX’s Power Tools* `grad`, `jit`, `vmap`, and `pmap`
- *Learning JAX in 2023: Part 3 — A Step-by-Step Guide to Training Your First Machine Learning Model with JAX*

**To learn how to get started with JAX,** *just keep reading.*

**🙌🏻 Introduction**

As deep learning practitioners, it can be tough to keep up with all the new developments. New academic papers and models are always coming out; there’s a new framework to learn every few years. Recently, many people have been talking about JAX, a new numerical computing library that can make your code run faster.

Many people have asked us to create a course about JAX, so we decided to take on the challenge. In this series, we’ll not only teach you about JAX, but also how to learn and understand new concepts. We’ll keep the language simple and avoid using jargon, but if you need help understanding anything, please let us know, and we’ll do our best to help.

Once you complete this course, you’ll be able to understand and work with any code written in JAX/FLAX. Major companies like Google Research, Hugging Face, and OpenAI are already using JAX heavily, so this is a valuable skill to have. Let’s get started and learn all about it!

**Configuring Your Development Environment**

To follow this guide, you need to have the JAX library installed on your system. JAX is written in pure Python, but it depends on XLA, which needs to be installed as the jaxlib package (from: jax repository).

Luckily, jaxlib and jax are pip-installable:

    $ pip install jaxlib
    $ pip install numpy
    $ pip install autograd
    $ pip install jax

**If you need help configuring your development environment for OpenCV, we highly recommend that you read our** *pip install OpenCV* **guide; it will have you up and running in a matter of minutes.**

**Having Problems Configuring Your Development Environment?**

All that said, are you:

- Short on time?
- Learning on your employer’s administratively locked system?
- Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
- **Ready to run the code** *right now* **on your Windows, macOS, or Linux system?**

Then join PyImageSearch University today!

**Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are ***pre-configured*** to run on Google Colab’s ecosystem right in your web browser!** No installation required.

And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!

**🤔 What Is JAX?**

JAX is the combination of `autograd` and `XLA`. Before diving into the nitty-gritty of JAX, let us look into `autograd` and `XLA` briefly.

**Note:** The section about `autograd` and `XLA` is meant to provide a more holistic understanding of the principles on which JAX was built. It is optional for getting started with JAX.

`autograd`

Gradients run the deep learning world quite literally. For example, we can compute the gradients (derivatives) of an equation in the following ways:

- **Manual:** We use our calculus knowledge and derive the derivatives by hand. The problem with this approach is that it is manual. It would take a lot of time for a Deep Learning researcher to derive the model’s derivatives by hand.
- **Symbolic:** We can obtain the derivatives via symbols and a program that can mimic the manual process. The problem with this approach is termed **expression swell**. Here the derivatives of a particular expression are exponentially longer (think chain rule) than the expression itself. This becomes quite difficult to track.
- **Numeric:** Here, we use the finite differences method to derive the derivatives.
- **Automatic:** The star ⭐️ of the show.
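To make the numeric approach concrete, here is a minimal sketch that compares a central finite difference against a derivative worked out by hand. The names `cube` and `central_difference` and the step size `h` are our own illustrative choices, not part of any library.

```python
def cube(x):
    # Illustrative function to differentiate: f(x) = x**3
    return x ** 3


def central_difference(f, x, h=1e-5):
    # Approximate f'(x) with a symmetric finite difference.
    # The step size h is an assumption; too large or too small hurts accuracy.
    return (f(x + h) - f(x - h)) / (2.0 * h)


x = 3.0
numeric = central_difference(cube, x)
manual = 3.0 * x ** 2  # derivative derived by hand: f'(x) = 3x^2

print(f"numeric => {numeric:.6f}")
print(f"manual  => {manual:.6f}")
```

The two values agree only approximately, and the numeric version needs two function evaluations per input. These are some of the reasons automatic differentiation is usually preferred.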

Automatic differentiation (autodiff) is the type of differentiation we all love and use when training our deep neural networks. In a previous tutorial, we covered the math and code base to understand how autodiff works. If you are interested in how autodiff works, we urge you to read the linked tutorial first and then return to the current one.

`autograd` is a Python package that performs *automatic differentiation* on native Python and NumPy code. The code base is fairly simple.

Two points to note here:

- There is a light wrapper `autograd.numpy` around the native NumPy codebase. This allows users to use NumPy-like semantics, harnessing the power of automatic differentiation.
- `autograd.grad` and `autograd.elementwise_grad` help with the actual automatic differentiation.

We will now look at a few code snippets to demonstrate how `autograd` works.

Let’s start by importing the necessary packages.

*Note:* `autograd`’s numpy module is imported as `anp` to distinguish it from `np` (original NumPy) and `jnp` (jax numpy).

    # Import the necessary packages
    from autograd import numpy as anp
    from autograd import grad
    from autograd import elementwise_grad as egrad

We will first define a simple function and then compute the function’s gradient using automatic differentiation. Next, we find the function’s gradient at a particular point (scalar) and for a list of points (vector).

We build a function that takes a value and returns its square.

    def func_square(x):
        # Return the square of the input
        return x**2

    # Build a scalar input and pass it to the
    # square function
    x = 4.0
    squared_x = func_square(x=x)
    print(f"x => {x}\nx**2 => {squared_x}")

    >>> x => 4.0
    >>> x**2 => 16.0

The function that we have defined is $f(x) = x^2$. We know that the function’s derivative is $f'(x) = 2x$.

We can obtain this derivative by applying the `autograd.grad` function. Let us see how that works.

    # Compute the derivative of the square function
    grad_func = grad(func_square)
    point = 1.0

    # Retrieve the gradient of the function at a particular point
    print(f"Gradient of square func at {point} => {grad_func(point)}")

    >>> Gradient of square func at 1.0 => 2.0

The above code snippet allows us to calculate the gradient of a scalar function. Next, let us see how to do the same with vectors.

    # Let's pass a vector to the square function
    vector = anp.arange(1, 10, dtype=anp.float32)
    out_vector = func_square(vector)

    # Iterate over the vector and its output
    for v, o in zip(vector, out_vector):
        print(f"Value at point {v} => {o}")

    >>> Value at point 1.0 => 1.0
    >>> Value at point 2.0 => 4.0
    >>> Value at point 3.0 => 9.0
    >>> Value at point 4.0 => 16.0
    >>> Value at point 5.0 => 25.0
    >>> Value at point 6.0 => 36.0
    >>> Value at point 7.0 => 49.0
    >>> Value at point 8.0 => 64.0
    >>> Value at point 9.0 => 81.0

*What happens if we send the entire vector to our gradient function?*

    try:
        out_vector = grad_func(vector)
    except Exception as ex:
        print(f"Type of exception => {type(ex).__name__}")
        print(f"Exception => {ex}")

    >>> Type of exception => TypeError
    >>> Exception => Grad only applies to real scalar-output functions. Try jacobian, elementwise_grad or holomorphic_grad.

Let us do what the exception tells us to do and use `elementwise_grad` for the vectorization process.

    # Let us now vectorize the gradient code
    egrad_func = egrad(func_square)

    try:
        out_vector = egrad_func(vector)
        for v, o in zip(vector, out_vector):
            print(f"Grad at point {v} => {o}")
    except Exception as ex:
        print(f"Type of exception => {type(ex).__name__}")
        print(f"Exception => {ex}")

With the `elementwise_grad` function, we can vectorize the gradient function.

    >>> Grad at point 1.0 => 2.0
    >>> Grad at point 2.0 => 4.0
    >>> Grad at point 3.0 => 6.0
    >>> Grad at point 4.0 => 8.0
    >>> Grad at point 5.0 => 10.0
    >>> Grad at point 6.0 => 12.0
    >>> Grad at point 7.0 => 14.0
    >>> Grad at point 8.0 => 16.0
    >>> Grad at point 9.0 => 18.0

Automatic Differentiation is at the very heart of Deep Learning. Any framework that facilitates differentiable programming allows users to navigate and exploit patterns in data through backpropagation.

Learn about Automatic Differentiation and Differentiable Programming in our blog post series:

- Automatic Differentiation Part 1: Understanding the Math
- Automatic Differentiation Part 2: Implementation Using Micrograd

`XLA`

It is safe to say that the fields of Deep Learning (DL) and Machine Learning (ML) consist of an enormous amount of **Linear Algebra**. All computations from start to finish are mostly Linear Algebra.

What if we told you there is a compiler in town that can make Linear Algebra operations more efficient?

**Enter XLA:** XLA stands for Accelerated Linear Algebra. It is a domain-specific compiler that accelerates linear algebra operations. The code you write is device agnostic: XLA can compile it for the CPU, GPU, or TPU with no code change.

**👀 What Is JAX (revisited)?**

Understanding what `autograd` and `XLA` do gives us a basic intuition about JAX.

JAX is a high-performance, numerical computing library incorporating composable function transformations.

—Why You Should (or Shouldn’t) be Using Google’s JAX in 2023

That sounds intimidating, but think about it again. Thanks to `autograd`, the NumPy-like API and automatic differentiation engine make JAX a very efficient **numerical computing library**.

The inclusion of the `XLA` compiler makes JAX a **highly performant numerical computing library incorporating composable function transformations**.

We will talk about what **composable function transformation** means in an upcoming blog post.

**⬇️ Import JAX**

Let’s explore JAX hands-on. We have already installed JAX on our system, so let’s import it to get started.

import jax

**📚 Understanding the Components: API Layering of JAX**

Before we start multiplying matrices and backpropagating on them, let us take a moment to understand the various components of JAX. While starting with a library, knowing its basic API design is always a good practice.

The version of JAX used when writing this tutorial is `0.3.25`. The JAX API is layered: `jax.numpy` is the high-level abstraction, and `jax.lax` is the low-level one.

While `jax.numpy` is similar to the original NumPy package, `jax.lax` is a wrapper around Google’s XLA compiler.

**Note:** Did you notice that **lax** is an anagram of **xla**? 🤯

If you head over to the official documentation of JAX API, you will see several sub-packages and sub-topics with their APIs listed.

The most used APIs are the following:

- `jax.numpy`
- `jax.lax`

The topics that matter most in the API design are:

- Just-in-time compilation (`jit`)
- Automatic differentiation (`grad`)
- Vectorization (`vmap`)
- Parallelization (`pmap`)

We will discuss these topics and sub-packages with corresponding code snippets as we go through the tutorial. Let us import them first into our work environment.

    import numpy as np
    import jax
    from jax import numpy as jnp
    from jax import make_jaxpr
    from jax import grad, jit, vmap, pmap

**💯 Numerical Computation in JAX**

This section will take us through the most used APIs of JAX: `jax.numpy` and `jax.lax`. Before diving in, we must note that JAX is **not** a Deep Learning (DL) framework. Instead, it is a numerical computation library. It is just that DL falls into the numerical computation paradigm.

For the ease of numerical computation, it has a NumPy API that mirrors the API of yet another very powerful numerical computation library (yes, you guessed it, NumPy 😁).

The thing that makes JAX stand out is its wrapper for the XLA compiler, `jax.lax`. The `jax.numpy` functions are themselves written in terms of `jax.lax` operations, which XLA then compiles. This makes JAX code not only **device agnostic** but also **jit compilable**.

Being device agnostic means that the same code can be run on different hardware (CPUs, GPUs, and TPUs). With the JIT compilation, the same code can run much faster and more efficiently. This is why JAX is referred to as *NumPy on steroids*.
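As a small sketch of this layering, we can trace a `jax.numpy` function with `jax.make_jaxpr` to inspect the primitive operations it is built from, and then compile the same function with `jax.jit`. The function `shifted_sine` is a hypothetical example of our own; the exact printed form of the trace may differ between JAX versions.

```python
import jax
from jax import numpy as jnp


def shifted_sine(x):
    # Plain jax.numpy code: no explicit jax.lax calls here.
    return jnp.sin(x) + 1.0


x = jnp.arange(4, dtype=jnp.float32)

# Trace the function to see the primitive operations behind it.
jaxpr = jax.make_jaxpr(shifted_sine)(x)
print(jaxpr)

# The same function can be compiled through XLA without code changes.
fast_shifted_sine = jax.jit(shifted_sine)
print(fast_shifted_sine(x))
```

The compiled function returns the same values as the uncompiled one. Any speed-up depends on the workload and hardware, so it is worth timing on your own setup.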

`jax.numpy`

In this section, we learn how to write NumPy-like code using `jax.numpy`.

    # Build an array of 0 to 9 with the `jax.numpy` API
    array = jnp.arange(0, 10, dtype=jnp.int8)
    print(f"array => {array}")

    >>> array => [0 1 2 3 4 5 6 7 8 9]

This is a great thing to have. Anyone with a fair amount of NumPy knowledge does not need to learn a new API. Many programs built on NumPy can be ported to JAX by adding the extra `j` (`np` becomes `jnp`). Thanks to Python duck typing, `jax.numpy` can often serve as a drop-in replacement for `numpy` code.
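A minimal sketch of this porting idea: the hypothetical helper `normalize` below takes the array module as an argument (we call it `xp`, an illustrative name), so the same body runs with either `numpy` or `jax.numpy`.

```python
import numpy as np
from jax import numpy as jnp


def normalize(xp, values):
    # `xp` is either the numpy or the jax.numpy module (duck typing).
    return (values - xp.mean(values)) / xp.std(values)


data = [1.0, 2.0, 3.0, 4.0]
numpy_result = normalize(np, np.asarray(data, dtype=np.float32))
jax_result = normalize(jnp, jnp.asarray(data, dtype=jnp.float32))

print(f"NumPy => {numpy_result}")
print(f"JAX   => {jax_result}")
```

Both calls should produce the same values up to floating-point tolerance; only the array type of the result differs.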

Let’s now look into some **differences**. The first one is the data type of the values in `jax.numpy`.

print(type(array))

    >>> <class 'jaxlib.xla_extension.DeviceArray'>

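The `DeviceArray` is the JAX equivalent of `numpy.ndarray` (newer JAX releases, 0.4.1 onward, expose it as the unified `jax.Array` type). However, the two are not exactly the same. For example, a `DeviceArray` can live on a CPU, GPU, or TPU, while a `numpy.ndarray` always lives in host memory.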
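The relationship between the two array types can be sketched with a round trip between host memory and a JAX device. `jax.devices` and `jax.device_put` are existing JAX functions; the variable names are our own, and the printed type name depends on the installed JAX version.

```python
import numpy as np
import jax
from jax import numpy as jnp

# List the devices JAX can see (CPU only unless an accelerator is set up).
print(jax.devices())

numpy_array = np.arange(5, dtype=np.float32)

# Copy a host NumPy array onto the default JAX device.
jax_array = jax.device_put(numpy_array)

# Copy the values back into host memory as a numpy.ndarray.
round_trip = np.asarray(jax_array)

print(type(jax_array).__name__)
print(type(round_trip).__name__)
```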

We have seen that the `jax.numpy` wrapper mirrors the NumPy Python library well, but there remain a few stark differences between the two.

A major one is that **DeviceArrays are immutable, unlike numpy.ndarrays.** We illustrate this using the following code snippet.

    jax_array = jnp.arange(1, 10, dtype=jnp.int8)
    numpy_array = np.arange(1, 10).astype(np.int8)

    try:
        numpy_array[2] = 2
    except Exception as ex:
        print(f"Type of exception => {type(ex).__name__}")
        print(f"Exception => {ex}")

    try:
        jax_array[2] = 2
    except Exception as ex:
        print(f"Type of exception => {type(ex).__name__}")
        print(f"Exception => {ex}")

    >>> Type of exception => TypeError
    >>> Exception => '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

To work around this, JAX provides the `.at[].set()` syntax.

**Note:** The clause does not make changes *in place*. The update creates another `DeviceArray` with the necessary changes. This is only correct outside JIT; in most cases within JIT, updates will happen in-place (from: GitHub Discussion). This design decision was taken to make functions **pure** (we will discuss pure functions in an upcoming blog post).

    try:
        mutated_jax_array = jax_array.at[2].set(200)
    except Exception as ex:
        print(f"Type of exception => {type(ex).__name__}")
        print(f"Exception => {ex}")

    print(f"Original Array => {jax_array}")
    print(f"Mutated Array => {mutated_jax_array}")

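    >>> Original Array => [1 2 3 4 5 6 7 8 9]
    >>> Mutated Array => [ 1 2 -56 4 5 6 7 8 9]

Note that the updated entry reads -56 rather than 200. The array has dtype `int8`, which cannot represent 200, so the value wraps around (200 - 256 = -56).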
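Here is a short sketch of the same out-of-place update pattern on an `int32` array, which is wide enough to hold 200. It also uses `.at[].add()`, one of the other `.at[]` methods mentioned in the error message. The variable names are illustrative.

```python
from jax import numpy as jnp

# int32 is wide enough to hold 200, unlike the int8 array above.
original = jnp.arange(1, 10, dtype=jnp.int32)

replaced = original.at[2].set(200)    # out-of-place version of x[2] = 200
incremented = original.at[2].add(10)  # out-of-place version of x[2] += 10

print(f"original    => {original}")
print(f"replaced    => {replaced}")
print(f"incremented => {incremented}")
```

The `original` array is left untouched; each `.at[]` call returns a new array.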

Another key point to note is the **indexing** of tensors in JAX.

    try:
        print("Indexing 1000th position of a NumPy array...")
        print(numpy_array[1000])
    except Exception as ex:
        print(type(ex).__name__)
        print(ex)

    >>> Indexing 1000th position of a NumPy array...
    >>> IndexError
    >>> index 1000 is out of bounds for axis 0 with size 9

    try:
        print("Indexing 1000th position of a JAX array...")
        print(jax_array[1000])
    except Exception as ex:
        print(type(ex).__name__)
        print(ex)

    >>> Indexing 1000th position of a JAX array...
    >>> 9

👀 What happened here?

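In JAX, out-of-bounds indices used for retrieval are **capped** (clamped) to the valid range, so index 1000 returns the last element, 9, instead of raising an error. This is a caveat that we need to take care of so that our code does not fail *silently*.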
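One way to make out-of-bounds reads visible is to request a fill value through `.at[].get()`. The sketch below assumes a JAX version in which `.at[].get()` accepts the `mode` and `fill_value` arguments; the sentinel `-1` is our own choice.

```python
from jax import numpy as jnp

jax_array = jnp.arange(1, 10, dtype=jnp.int32)

# Default retrieval: the out-of-bounds index is clamped to the last element.
clamped = jax_array[1000]

# Explicit alternative: return a sentinel for out-of-bounds reads.
filled = jax_array.at[1000].get(mode="fill", fill_value=-1)

print(f"clamped => {clamped}")
print(f"filled  => {filled}")
```

A sentinel that cannot occur in valid data makes it easier to detect a bad index downstream.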

`jax.lax`

Let’s talk a little bit about `jax.lax` now. While the NumPy API makes it easier for you to enter the world of JAX, `jax.lax` is what powers the library with all of its functionalities.

`jax.lax` is a library of primitive operations that underpins libraries such as `jax.numpy`.

—jax.lax module

While `jax.numpy` is a high-level abstraction that makes it easier to code, `jax.lax` is more powerful but also much stricter.

For example, `jax.lax` does not perform automatic type promotion. This is demonstrated using the following code snippets.

    # Checking the lenient `jax.numpy` API
    try:
        print(jnp.add(1, 2.0))
    except Exception as ex:
        print(f"Type of exception => {type(ex).__name__}")
        print(f"Exception => {ex}")

    >>> 3.0

    # Checking the stricter `jax.lax` API 😭
    try:
        jax.lax.add(1, 2.0)
    except Exception as ex:
        print(f"Type of exception => {type(ex).__name__}")
        print(f"Exception => {ex}")

    >>> Type of exception => TypeError
    >>> Exception => lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).
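To use `jax.lax.add` with mixed inputs, we can cast the operands ourselves so that the dtypes agree. The sketch below shows this next to the `jnp.add` call suggested by the error tip; the variable names are illustrative.

```python
import jax
from jax import numpy as jnp

# jax.numpy promotes the Python int to a float automatically.
promoted = jnp.add(1, 2.0)

# jax.lax needs matching dtypes, so cast the int operand explicitly.
explicit = jax.lax.add(jnp.float32(1), 2.0)

print(f"jnp.add => {promoted} ({promoted.dtype})")
print(f"lax.add => {explicit} ({explicit.dtype})")
```

The stricter API trades convenience for predictability: the dtype of every operand is decided by the caller rather than by promotion rules.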

**Summary**

Great job on completing the first part of the tutorial on JAX! In this tutorial, we covered the background and origins of JAX, specifically highlighting its parent libraries `autograd` and `XLA`. We also explored the API layering of JAX and delved into the details of the two most commonly used APIs: `jax.numpy` and `jax.lax`.

Now, it’s time to move on to the next part of the tutorial, which will focus on functional transformations in JAX. These transformations, such as `grad`, `jit`, `vmap`, and `pmap`, are essential tools in the JAX toolkit and allow you to optimize your code for better performance and efficiency.

Finally, in the third and final part of the tutorial, we will put everything we’ve learned to the test by training a model from scratch using JAX. This will be a great opportunity to apply the concepts and techniques covered in the first two parts of the tutorial and see the power of JAX in action. So buckle up because the next tutorial will be an exciting and resource-filled adventure!

We would like to acknowledge the detailed review and discussion from Jake Vanderplas.

**References**

- What is Automatic Differentiation?
- You don’t know JAX
- Why You Should (or Shouldn’t) be Using Google’s JAX in 2023
- The Sharp Bits 🔪 — JAX documentation
- Training a Simple Neural Network, with tensorflow/datasets Data Loading
- jax/README.md at main · google/jax · GitHub
- GitHub – google/jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
- JAX Crash Course – Accelerating Machine Learning code!

**Citation Information**

**A. R. Gosthipaty and R. Raha.** “Learning JAX in 2023: Part 1 — The Ultimate Guide to Accelerating Numerical Computation and Machine Learning,” *PyImageSearch*, P. Chugh, S. Huot, K. Kidriavsteva, and A. Thanki, eds., 2023, https://pyimg.co/uwe1j

    @incollection{ARG-RR_2023_JAX1,
      author = {Aritra Roy Gosthipaty and Ritwik Raha},
      title = {Learning {JAX} in 2023: Part 1 — The Ultimate Guide to Accelerating Numerical Computation and Machine Learning},
      booktitle = {PyImageSearch},
      editor = {Puneet Chugh and Susan Huot and Kseniia Kidriavsteva and Abhishek Thanki},
      year = {2023},
      url = {https://pyimg.co/uwe1j},
    }
