Lecture 07
We will not spend much time on this, as most data you will encounter is more likely to be in a tabular format (e.g. a data frame) where tools like Pandas are more appropriate.
For basic saving and loading of NumPy arrays there are the save() and load() functions, which use a custom binary format (.npy).
Additional functions exist for saving multiple arrays (savez(), savez_compressed()) or saving a text representation of an array (savetxt()).
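A minimal sketch of these functions (the file names here are hypothetical):

import numpy as np

x = np.arange(6).reshape(2, 3)

np.save("x.npy", x)                     # single array, binary .npy format
x2 = np.load("x.npy")

np.savez("arrays.npz", a=x, b=x * 2)    # multiple arrays in one .npz file
with np.load("arrays.npz") as d:
    a, b = d["a"], d["b"]

np.savetxt("x.txt", x, delimiter=",")   # text representation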
While not particularly recommended, if you need to read delimited (csv, tsv, etc.) data into a NumPy array you can use genfromtxt().
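A sketch of such a call (the file name and options here are assumptions; the values printed below match the classic mtcars data):

import numpy as np

cars = np.genfromtxt("cars.csv", delimiter=",", skip_header=1)
cars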
array([[ 6. , 160. , 110. , 3.9 , 2.62 , 16.46 , 0. , 1. , 4. , 4. ],
[ 6. , 160. , 110. , 3.9 , 2.875, 17.02 , 0. , 1. , 4. , 4. ],
[ 4. , 108. , 93. , 3.85 , 2.32 , 18.61 , 1. , 1. , 4. , 1. ],
[ 6. , 258. , 110. , 3.08 , 3.215, 19.44 , 1. , 0. , 3. , 1. ],
[ 8. , 360. , 175. , 3.15 , 3.44 , 17.02 , 0. , 0. , 3. , 2. ],
[ 6. , 225. , 105. , 2.76 , 3.46 , 20.22 , 1. , 0. , 3. , 1. ],
[ 8. , 360. , 245. , 3.21 , 3.57 , 15.84 , 0. , 0. , 3. , 4. ],
[ 4. , 146.7 , 62. , 3.69 , 3.19 , 20. , 1. , 0. , 4. , 2. ],
[ 4. , 140.8 , 95. , 3.92 , 3.15 , 22.9 , 1. , 0. , 4. , 2. ],
[ 6. , 167.6 , 123. , 3.92 , 3.44 , 18.3 , 1. , 0. , 4. , 4. ],
[ 6. , 167.6 , 123. , 3.92 , 3.44 , 18.9 , 1. , 0. , 4. , 4. ],
[ 8. , 275.8 , 180. , 3.07 , 4.07 , 17.4 , 0. , 0. , 3. , 3. ],
[ 8. , 275.8 , 180. , 3.07 , 3.73 , 17.6 , 0. , 0. , 3. , 3. ],
[ 8. , 275.8 , 180. , 3.07 , 3.78 , 18. , 0. , 0. , 3. , 3. ],
[ 8. , 472. , 205. , 2.93 , 5.25 , 17.98 , 0. , 0. , 3. , 4. ],
[ 8. , 460. , 215. , 3. , 5.424, 17.82 , 0. , 0. , 3. , 4. ],
[ 8. , 440. , 230. , 3.23 , 5.345, 17.42 , 0. , 0. , 3. , 4. ],
[ 4. , 78.7 , 66. , 4.08 , 2.2 , 19.47 , 1. , 1. , 4. , 1. ],
[ 4. , 75.7 , 52. , 4.93 , 1.615, 18.52 , 1. , 1. , 4. , 2. ],
[ 4. , 71.1 , 65. , 4.22 , 1.835, 19.9 , 1. , 1. , 4. , 1. ],
[ 4. , 120.1 , 97. , 3.7 , 2.465, 20.01 , 1. , 0. , 3. , 1. ],
[ 8. , 318. , 150. , 2.76 , 3.52 , 16.87 , 0. , 0. , 3. , 2. ],
[ 8. , 304. , 150. , 3.15 , 3.435, 17.3 , 0. , 0. , 3. , 2. ],
[ 8. , 350. , 245. , 3.73 , 3.84 , 15.41 , 0. , 0. , 3. , 4. ],
[ 8. , 400. , 175. , 3.08 , 3.845, 17.05 , 0. , 0. , 3. , 2. ],
[ 4. , 79. , 66. , 4.08 , 1.935, 18.9 , 1. , 1. , 4. , 1. ],
[ 4. , 120.3 , 91. , 4.43 , 2.14 , 16.7 , 0. , 1. , 5. , 2. ],
[ 4. , 95.1 , 113. , 3.77 , 1.513, 16.9 , 1. , 1. , 5. , 2. ],
[ 8. , 351. , 264. , 4.22 , 3.17 , 14.5 , 0. , 1. , 5. , 4. ],
[ 6. , 145. , 175. , 3.62 , 2.77 , 15.5 , 0. , 1. , 5. , 6. ],
[ 8. , 301. , 335. , 3.54 , 3.57 , 14.6 , 0. , 1. , 5. , 8. ],
[ 4. , 121. , 109. , 4.11 , 2.78 , 18.6 , 1. , 1. , 4. , 2. ]])
Broadcasting is an approach for deciding how to generalize operations between arrays with differing shapes.
Broadcast code is usually shorter and simpler, and it can make the calculation more efficient.
When operating on two arrays, NumPy compares their shapes element-wise. It starts with the trailing (i.e. rightmost) dimensions and works its way left. Two dimensions are compatible when
they are equal, or
one of them is 1
If these conditions are not met, a ValueError: operands could not be broadcast together exception is thrown, indicating that the arrays have incompatible shapes. The size of the resulting array is the size that is not 1 along each axis of the inputs.
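A short sketch illustrating these rules:

import numpy as np

a = np.ones((4, 3))
b = np.arange(3)    # shape (3,): the trailing dimensions match

a + b               # works, result has shape (4, 3)

c = np.ones((3, 4))
c + b               # raises ValueError: operands could not be broadcast together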
Why do the first and third of the following examples work, but not the second?
x   (2d array): 4 x 3
y   (1d array):     3
----------------------
x+y (2d array): 4 x 3

x   (2d array): 3 x 4
y   (1d array):     3
----------------------
x+y (2d array): Error

x   (2d array): 3 x 4
y   (2d array): 3 x 1
----------------------
x+y (2d array): 3 x 4
A common broadcasting pitfall is accidental broadcasting, where operations succeed but produce unexpected results.
Now we’d like to subtract each row’s mean from that row.
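A minimal sketch of the pitfall (the matrix here is made up):

import numpy as np

X = np.arange(9.0).reshape(3, 3)

X - X.mean(axis=1)                 # succeeds, but the (3,) means broadcast across the wrong axis
X - X.mean(axis=1, keepdims=True)  # means keep shape (3, 1) and subtract row-wise as intended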
It is easy to get confused between the behavior of arrays with shapes (n,), (n, 1), and (1, n).
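A sketch (the names v, v_row, and v_col match the discussion below; the values are made up):

import numpy as np

M = np.zeros((3, 3))
v = np.array([1.0, 2.0, 3.0])    # shape (3,)
v_row = v.reshape(1, 3)          # shape (1, 3)
v_col = v.reshape(3, 1)          # shape (3, 1)

M + v        # adds [1, 2, 3] to every row
M + v_row    # same as above
M + v_col    # adds a different constant to each row instead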
All three succeed, but v and v_row broadcast across rows while v_col broadcasts across columns.
Broadcasting doesn’t copy data during the operation, but the result can be much larger than expected.
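For example (the sizes here are arbitrary), adding a column vector to a row vector materializes the full outer shape:

import numpy as np

a = np.ones((10_000, 1))    # ~80 KB
b = np.ones((1, 10_000))    # ~80 KB
c = a + b                   # shape (10000, 10000), ~800 MB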
The @ operator (and np.matmul()) has special broadcasting rules: the last two dimensions are treated as matrices (and must be compatible for matrix multiplication), while any leading (batch) dimensions are broadcast against each other as usual.
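A short sketch of batched matrix multiplication (the shapes are made up):

import numpy as np

A = np.ones((8, 2, 3))    # a batch of eight 2x3 matrices
B = np.ones((3, 4))       # a single 3x4 matrix, broadcast across the batch
(A @ B).shape             # (8, 2, 4)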
For each of the following combinations, determine what the resulting shape will be under broadcasting (one way to check your answers is sketched after the list):
A [128 x 128 x 3] + B [3]
A [8 x 1 x 6 x 1] + B [7 x 1 x 5]
A [2 x 1] + B [8 x 4 x 3]
A [3 x 1] + B [15 x 3 x 5]
A [3] + B [4]
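One way to check your answers is np.broadcast_shapes(), which applies the broadcasting rules without allocating anything:

import numpy as np

np.broadcast_shapes((128, 128, 3), (3,))    # work through each case like this;
                                            # incompatible shapes raise a ValueError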
Below we generate a data set with 3 columns of random normal values. Each column has a different mean and standard deviation which we can check with mean() and std().
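A sketch of such a data set (the particular means, standard deviations, and sample size are made up):

import numpy as np

rng = np.random.default_rng(1234)
d = rng.normal(loc=[-1, 0, 1], scale=[1, 2, 3], size=(1000, 3))

d.mean(axis=0)    # approximately [-1, 0, 1]
d.std(axis=0)     # approximately [ 1, 2, 3]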
Let’s use broadcasting to standardize all three columns to have mean 0 and standard deviation 1.
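Continuing the sketch above, the (1000, 3) array and the (3,) vectors of column means and standard deviations broadcast directly:

d_std = (d - d.mean(axis=0)) / d.std(axis=0)

d_std.mean(axis=0)    # approximately [0, 0, 0]
d_std.std(axis=0)     # approximately [1, 1, 1]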
In addition to arithmetic operators, broadcasting can be used with assignment via array indexing.
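For example (the values are made up), a scalar or a lower-dimensional array assigned into a slice is broadcast to fit:

import numpy as np

X = np.zeros((3, 4))
X[0, :] = 1                              # scalar broadcast across the first row
X[:, 1:3] = np.array([[9], [8], [7]])    # (3, 1) broadcast into the (3, 2) slice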
JAX is a library for array-oriented numerical computation (à la NumPy), with automatic differentiation and JIT compilation to enable high-performance machine learning research.
JAX provides a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.
JAX features built-in Just-In-Time (JIT) compilation via Open XLA, an open-source machine learning compiler ecosystem.
JAX functions support efficient evaluation of gradients via its automatic differentiation transformations.
JAX functions can be automatically vectorized to efficiently map them over arrays representing batches of inputs.
JAX is just one framework in a long history going back almost 20 years. These frameworks all share a common approach: representing computations as a graph of operations.
This enables automatic differentiation (via backpropagation) and optimization/compilation for efficient execution on GPUs or other specialized hardware.
Theano (2007) - U of Montreal
TensorFlow (2015) - Google Brain
PyTorch (2016) - Facebook
JAX (2018) - Google
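For example, we can create a small array and apply a ufunc to it (this particular reconstruction of the code behind the output below is an assumption):

import jax.numpy as jnp

x = jnp.arange(0, 12, 2).reshape(2, 3)
jnp.sin(x)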
Array([[ 0. , 0.9093 , -0.7568 ],
[-0.27942, 0.98936, -0.54402]], dtype=float32)
As we’ve just seen, a JAX array is very similar to a NumPy array, but there are some important differences.
The default JAX array dtypes are 32-bit, not 64-bit (i.e. float32 not float64 and int32 not int64).
jnp.array([1, 2, 3], dtype=jnp.float64)

UserWarning: Explicitly requested dtype float64 requested in array is not available, and
will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration
option or the JAX_ENABLE_X64 shell environment variable. See
https://github.com/jax-ml/jax#current-gotchas for more.

Array([1., 2., 3.], dtype=float32)
64-bit dtypes can be enabled by setting jax_enable_x64=True in the JAX configuration.
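For example, via the update() method on jax.config (typically done at startup, before any arrays are created):

import jax
jax.config.update("jax_enable_x64", True)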
JAX arrays are allocated to one or more devices.
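For example (the device listing will depend on your hardware):

import jax
import jax.numpy as jnp

jax.devices()              # e.g. [CpuDevice(id=0)] on a CPU-only machine
jnp.arange(3).devices()    # the set of devices holding this array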
Using JAX interactively allows for the use of standard Python control flow (if, while, for, etc.), but this is not supported for some of JAX’s more advanced operations (e.g. jit and grad).
There are replacements for most of these constructs in JAX, but they are beyond the scope of this lecture.
Pseudorandom number generation in JAX is a bit different from NumPy’s - the latter depends on a global state that is updated each time a random function is called.
NumPy’s PRNG guarantees something called sequential equivalence, which means that sampling N numbers one at a time gives the same values as sampling all N at once (e.g. as a vector of length N).
Sequential equivalence can be problematic when parallelizing across multiple devices; consider the following code:
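A sketch of the kind of code in question (adapted from the standard JAX random-numbers example; the function names match the question below):

import numpy as np

np.random.seed(0)

def bar():
    return np.random.uniform()

def baz():
    return np.random.uniform()

def foo():
    return bar() + 2 * baz()

foo()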
How do we guarantee that we get consistent results if we don’t know the order that bar() and baz() will run?
JAX makes use of random keys, which are just a fancier version of random seeds - all of JAX’s random functions require a key as their first argument.
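For example:

import jax

key = jax.random.key(1234)
jax.random.normal(key, shape=(3,))    # same key in, same draws out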
Since a key is essentially a seed, we do not want to reuse it (unless we want identical output). Therefore, to generate multiple different pseudorandom numbers, we can split a key to deterministically generate two (or more) new keys.
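For example:

import jax

key = jax.random.key(1234)
key, subkey = jax.random.split(key)      # two new keys; use subkey, carry key forward
subkeys = jax.random.split(key, num=4)   # or split into several keys at once

jax.random.normal(subkey)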
JAX’s jit() function compiles a function using XLA (Accelerated Linear Algebra), which can significantly speed up execution by optimizing the computation graph and reducing Python overhead.
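A minimal sketch of its use:

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

f_jit = jax.jit(f)           # compiled on first call (tracing + XLA compilation)
f_jit(jnp.arange(1000.0))    # later calls with the same shapes/dtypes reuse the compiled code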
When it works, the jit tool is fantastic, but it does have a number of limitations:
Must use pure functions (no side effects)
Must primarily use JAX functions (e.g. jnp.minimum() not np.minimum() or min())
Must generally avoid conditionals / control flow
Issues around concrete values when tracing (static values)
Check performance - there are not always gains, and there is the initial cost of compilation
The grad() function takes a numerical function that returns a scalar, and returns a new function that calculates the gradient of that function.
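For example:

import jax

def f(x):
    return x ** 2

df = jax.grad(f)
df(3.0)    # 6.0, i.e. the derivative 2x evaluated at x = 3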
vmap()
I would like to plot h() and jax.grad(h)() - let’s see what happens.
We can only calculate the gradient for scalar valued functions. However, we can transform our scalar function into a vectorized function using vmap().
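A sketch assuming h() is a ReLU-like function (an assumption, but one consistent with the 0/1 gradient values shown below):

import jax
import jax.numpy as jnp

def h(x):
    return jnp.where(x > 0, x, 0.0)    # hypothetical h(); the course's actual h() is not shown

x = jnp.linspace(-1, 1, 101)
jax.vmap(jax.grad(h))(x)    # the gradient of h evaluated element-wise over x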
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)
JAX requires different packages depending on your hardware.
With uv:
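Something like the following (the exact package extras are a guess here - check the JAX installation docs for your platform):

uv add jax              # CPU only
uv add "jax[cuda12]"    # Linux + NVIDIA GPU (CUDA 12)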
The GPU versions require appropriate drivers and libraries (CUDA toolkit for NVIDIA, Metal for Apple Silicon).
One of JAX’s main advantages is seamless GPU support. Once installed correctly, JAX automatically uses available GPUs without requiring code changes.
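For example, on a machine with an NVIDIA GPU (the output below shows the default device):

import jax
jax.devices()[0]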
CudaDevice(id=0)
Sta 663 - Spring 2026