NumPy vs JAX: Optimizing Performance and Functionality
Introduction
NumPy and JAX are both Python libraries widely used in numerical and scientific computing. Although their APIs are similar, some key differences make JAX particularly useful for machine learning applications. In this tutorial, we explore some of those differences and illustrate them with code snippets.
Installation
Both NumPy and JAX can be installed using pip. To install NumPy, run the following command in your terminal:
pip install numpy
To install JAX, run:
pip install jax
JAX also depends on the jaxlib library. Recent JAX releases install a CPU build of jaxlib automatically as a dependency of jax; if it is missing from your environment, you can install it explicitly:
pip install jaxlib
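To confirm that both installations succeeded, a quick check along the following lines can help. This is a minimal sketch that only assumes both packages import cleanly; the printed version numbers depend on your environment.

```python
import numpy as np
import jax

# If either import fails, the corresponding package is not installed
# in the active Python environment.
numpy_version = np.__version__
jax_version = jax.__version__

print("NumPy version:", numpy_version)
print("JAX version:", jax_version)
```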
Functionality
Let us take a look at how NumPy and JAX code compares for both simple and more complex scientific computing and machine learning tasks.
Array creation
NumPy and JAX both provide functions for creating arrays. The simplest way to create an array in NumPy is to use the np.array function:
import numpy as np
# create a 1D array
npArrOne = np.array([1, 2, 3])
print(npArrOne)
# create a 2D array
npArrTwo = np.array([[1, 2], [3, 4]])
print(npArrTwo)
In JAX, the corresponding function is jnp.array:
import jax.numpy as jnp
# create a 1D array
jxArrOne = jnp.array([1, 2, 3])
print(jxArrOne)
# create a 2D array
jxArrTwo = jnp.array([[1, 2], [3, 4]])
print(jxArrTwo)
Note that the syntax is very similar in both libraries, but in JAX we import the NumPy-style functions from the jax.numpy module rather than from the numpy package.
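One difference that the matching syntax hides is the default precision. The sketch below assumes JAX's default configuration, in which 64-bit types are disabled; the variable names are illustrative only. It also shows how an array can be converted from one library to the other.

```python
import numpy as np
import jax.numpy as jnp

np_floats = np.array([1.0, 2.0, 3.0])
jx_floats = jnp.array([1.0, 2.0, 3.0])

print(np_floats.dtype)  # float64
print(jx_floats.dtype)  # float32 under JAX's default (32-bit) configuration

# Convert between the two libraries.
from_numpy = jnp.asarray(np_floats)  # NumPy -> JAX
to_numpy = np.asarray(jx_floats)     # JAX -> NumPy
print(type(to_numpy))
```

If your code relies on 64-bit precision, this default is worth checking before comparing results between the two libraries.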
Array indexing
Both NumPy and JAX provide powerful indexing capabilities for arrays. The basic syntax for reading elements is the same in both libraries:
# indexing in numpy
npArr = np.array([1, 2, 3])
print(npArr[0])
# output: 1
# indexing in jax
jxArr = jnp.array([1, 2, 3])
print(jxArr[0])
# output: 1
However, updating elements works differently, because JAX arrays are immutable. Instead of assigning to an index, JAX provides the .at property, which returns an updated copy of the array (the older jax.ops.index_update function has been removed from current releases):
# update an element in numpy
npArr = np.array([1, 2, 3])
npArr[0] = 4
print(npArr)
# output: [4 2 3]
# update an element in jax
jxArr = jnp.array([1, 2, 3])
jxArr = jxArr.at[0].set(4)
print(jxArr)
# output: [4 2 3]
Note that in JAX we cannot modify an array in place, so we reassign the result of jxArr.at[0].set(4) to the original name.
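In current JAX releases, indexed updates are expressed through the .at property of an array, which offers several update methods and always returns a new array. The following is a small sketch with illustrative variable names, showing that the original array is left untouched:

```python
import jax.numpy as jnp

original = jnp.array([1, 2, 3, 4])

set_one = original.at[0].set(10)      # replace one element
added = original.at[1:3].add(5)       # add 5 to every element of a slice
scaled = original.at[-1].multiply(2)  # multiply the last element by 2

print(original.tolist())  # [1, 2, 3, 4]  (unchanged)
print(set_one.tolist())   # [10, 2, 3, 4]
print(added.tolist())     # [1, 7, 8, 4]
print(scaled.tolist())    # [1, 2, 3, 8]
```

Because every update returns a new array, these operations can be used inside functions transformed by jit or vmap.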
Array broadcasting
One of the most powerful features of NumPy and JAX is their ability to broadcast arrays. Broadcasting allows you to perform element-wise operations on arrays with different but compatible shapes. Here is an example of broadcasting in NumPy:
# broadcasting in numpy
npArrOne = np.array([[1, 2], [3, 4]])
npArrTwo = np.array([10, 20])
print(npArrOne + npArrTwo)
# output: [[11, 22], [13, 24]]
In JAX, broadcasting works in the same way:
# broadcasting in jax
jxArrOne = jnp.array([[1, 2], [3, 4]])
jxArrTwo = jnp.array([10, 20])
print(jxArrOne + jxArrTwo)
# output: [[11, 22], [13, 24]]
In the examples above, NumPy and JAX add the 2D array (npArrOne or jxArrOne) and the 1D array (npArrTwo or jxArrTwo) by automatically broadcasting the 1D array across each row of the 2D array, so that the shapes match.
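The broadcasting rule itself can be made more concrete: shapes are compared from the trailing dimension backwards, and two dimensions are compatible when they are equal or one of them is 1. The sketch below uses NumPy with illustrative names; jax.numpy follows the same rule.

```python
import numpy as np

column = np.array([[0], [10], [20]])  # shape (3, 1)
row = np.array([1, 2, 3, 4])          # shape (4,)

# (3, 1) and (4,) are aligned from the right and broadcast to (3, 4).
grid = column + row
print(grid.shape)  # (3, 4)

# Shapes that cannot be aligned raise an error instead of broadcasting:
# (3, 2) and (3,) clash in the trailing dimension (2 vs 3).
try:
    np.ones((3, 2)) + np.ones(3)
    broadcast_failed = False
except ValueError:
    broadcast_failed = True
print(broadcast_failed)  # True
```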
JAX also provides a transformation that is particularly useful for machine learning: jax.vmap. It vectorizes a function written for a single input so that it is applied along an array axis (the leading axis by default) without an explicit Python loop:
# vmap in JAX
import jax
def add(x, y):
    return x + y
jxArrOne = jnp.array([[1, 2], [3, 4]])
jxArrTwo = jnp.array([[10, 20], [30, 40]])
# apply add to each pair of rows of the two inputs
jxArrThree = jax.vmap(add)(jxArrOne, jxArrTwo)
print(jxArrThree)
# output: [[11, 22], [33, 44]]
In the code snippet above, we define a function add that adds its two arguments. We then use jax.vmap to map this function over the leading axis of jxArrOne and jxArrTwo, which have the same shape, so add is applied to each pair of rows. The result is a new array jxArrThree with the same shape as the inputs, where each element is the sum of the corresponding elements of jxArrOne and jxArrTwo.
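Element-wise addition would work without any transformation, so the benefit of vmap is easier to see with a function written for a single example. The sketch below assumes a hypothetical predict function and uses the in_axes argument of jax.vmap to batch over only one of its inputs.

```python
import jax
import jax.numpy as jnp

def predict(weights, x):
    # Written for ONE input vector x of shape (3,).
    return jnp.dot(weights, x)

weights = jnp.array([1.0, 2.0, 3.0])
batch = jnp.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [1.0, 1.0, 1.0]])  # three input vectors, shape (3, 3)

# in_axes=(None, 0): do not map over weights, map over axis 0 of batch.
batched_predict = jax.vmap(predict, in_axes=(None, 0))
outputs = batched_predict(weights, batch)
print(outputs.tolist())  # [1.0, 2.0, 6.0]
```

The function body never mentions a batch dimension; vmap adds it, which keeps per-example code simple.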
Performance
One of the key advantages of JAX over NumPy is that it can compile code with the XLA compiler and run it on GPUs and TPUs as well as CPUs, which can significantly improve performance for large-scale machine learning applications.
Feature | NumPy | JAX |
Hardware | CPU | CPU, GPU, TPU |
Execution | Synchronously | Asynchronously |
Parallel computation | No | Yes |
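The hardware and execution rows of the table can be inspected directly. The following sketch, with illustrative variable names, asks JAX which backend and devices it sees and then waits explicitly for an asynchronously dispatched result. The backend name printed depends on your machine.

```python
import jax
import jax.numpy as jnp

backend = jax.default_backend()    # for example "cpu" on a machine without accelerators
device_count = jax.device_count()  # number of devices JAX can use
print("Default backend:", backend)
print("Visible devices:", jax.devices())

x = jnp.ones((1000, 1000))
# jnp.dot returns as soon as the work is dispatched to the device ...
product = jnp.dot(x, x)
# ... so block_until_ready() is used to wait for the actual result.
product.block_until_ready()
print(product.shape)  # (1000, 1000)
```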
In the example below, we time a simple matrix multiplication in NumPy and in JAX:
NumPy:
# matrix multiplication in numpy
import numpy as np
import time
npArrOne = np.random.rand(1000, 1000)
npArrTwo = np.random.rand(1000, 1000)
start = time.time()
npArrThree = np.dot(npArrOne, npArrTwo)
end = time.time()
print("NumPy time:", end - start)
Output:
NumPy time: 0.5427792072296143
JAX:
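# matrix multiplication in jax
import time
import jax
import jax.numpy as jnp
from jax import jit
# JAX uses explicit PRNG keys instead of a global random state
keyOne, keyTwo = jax.random.split(jax.random.PRNGKey(0))
jxArrOne = jax.random.uniform(keyOne, (1000, 1000))
jxArrTwo = jax.random.uniform(keyTwo, (1000, 1000))
@jit
def matmul(jxArrOne, jxArrTwo):
    return jnp.dot(jxArrOne, jxArrTwo)
# warm-up call so that compilation time is not included in the measurement
matmul(jxArrOne, jxArrTwo).block_until_ready()
start = time.time()
# block_until_ready waits for the asynchronous computation to finish
jxArrThree = matmul(jxArrOne, jxArrTwo).block_until_ready()
end = time.time()
print("JAX time:", end - start)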
Output:
JAX time: 0.03486919403076172
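In the code snippets above, we generate two random 1000 x 1000 matrices in each library, and we use the np.dot function to perform matrix multiplication in NumPy and the jnp.dot function to perform it in JAX. We also use the @jit decorator to compile the matmul function using the XLA compiler. Because JAX compiles a function on its first call and dispatches work asynchronously, a reliable timing excludes the first call and calls block_until_ready() on the result.
In the timings shown above, JAX is significantly faster than NumPy, although the exact numbers depend on the hardware, on the BLAS library NumPy is linked against, and on how the measurement is taken.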
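A single timed call is noisy, so a slightly more careful measurement averages several runs after a warm-up call. The sketch below is one way to do this; time_jitted is a hypothetical helper, not part of JAX, and the 500 x 500 size is chosen only to keep the example quick.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def matmul(a, b):
    return jnp.dot(a, b)

def time_jitted(fn, *args, repeats=10):
    # Hypothetical helper: average wall-clock seconds per call after a warm-up.
    fn(*args).block_until_ready()      # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()  # wait for each asynchronous result
    return (time.perf_counter() - start) / repeats

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (500, 500))
b = jax.random.uniform(key_b, (500, 500))

avg_seconds = time_jitted(matmul, a, b)
print("Average JAX time per call:", avg_seconds)
```

The same pattern, without block_until_ready(), can be applied to np.dot to obtain a comparable NumPy figure on the same machine.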
Conclusion
Summary table of the similarities and differences between NumPy and JAX
Feature | NumPy | JAX |
Functionality | Basic array operations, linear algebra, statistics, image processing, Fourier transform | Most of NumPy's API (through jax.numpy), plus automatic differentiation, just-in-time compilation, automatic vectorization, and parallel execution |
Performance | Good for simple tasks | Excellent for complex tasks |
Scalability | Good for small to medium datasets | Excellent for large datasets |
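Automatic differentiation, listed in the table but not shown earlier, can be sketched in a few lines. The example below assumes a simple illustrative function, sum_of_squares, and combines jax.grad with jax.jit; the gradient of the sum of squares is 2x.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    return jnp.sum(x ** 2)

# jax.grad builds a function that returns d(sum_of_squares)/dx;
# jax.jit then compiles that gradient function with XLA.
grad_fn = jax.jit(jax.grad(sum_of_squares))

x = jnp.array([1.0, 2.0, 3.0])
gradient = grad_fn(x)
print(gradient.tolist())  # [2.0, 4.0, 6.0]
```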
So, we have explored some of the differences between NumPy and JAX, two powerful Python libraries that are widely used in numerical and scientific computing.
While their array APIs are similar, JAX adds features that are particularly useful for machine learning applications, including functional indexed updates, function transformations such as vmap and jit, and the ability to run code on GPUs and TPUs through the XLA compiler.
By understanding the strengths and weaknesses of each library, you can choose the one that best suits your needs, and maximize the performance and functionality of your scientific computing projects.
Good luck!
Written by
Wesley Kambale
Wesley is a machine learning engineer and data scientist, adept at crafting production-ready ML systems that provide impactful solutions in the African market. As a tech conference speaker, he shares his expertise through insightful talks and occasional articles on TensorFlow and Keras, aiming to disseminate his knowledge and experiences. He is a seasoned community organizer with vast experience in launching and building Google Developer communities in western Uganda. He is an active organizer in Google Developer Groups (GDG) program and an alumni of the Google Developer Students Club (GDSC) program. Wesley has an undergraduate degree in computer science from Mbarara University of Science and Technology and holds various certificates and certifications in data science and machine learning.