Speeding up JAX on Macs with MLX
Overview
- While PyTorch supports Metal reasonably well, JAX has lagged behind.
- A while ago, Apple published a post about jax-metal, a Python package that added experimental Metal support to JAX.
- The idea was to create a PJRT plugin for OpenXLA which maps StableHLO primitives to their MPSGraph counterparts.
- jax-metal was an interesting proof of concept, but Apple has not invested in keeping it up to date: the latest supported macOS version is Sonoma.
- Enter jax-mps, an open-source project inspired by jax-metal. The overall approach is similar, but StableHLO is mapped to MLX primitives instead of MPSGraph ones.
- Here JAX effectively serves as the compiler front end, MLX as the compiler back end, and StableHLO (a dialect of MLIR) as the IR.
- OpenXLA does not perform any of the optimizations it usually would for JAX, though. This is because there is no LLVM backend for Metal (i.e. no Apple Silicon equivalent of NVPTX, the LLVM backend XLA relies on for NVIDIA GPUs).
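Since jax-mps consumes the StableHLO that JAX emits, it can be instructive to look at that IR directly. Below is a minimal sketch using JAX's ahead-of-time lowering API (jax.jit(...).lower(...).as_text()); the toy function f and the regex-based extraction of op names are illustrative assumptions, not part of jax-mps. Lowering only needs shapes and dtypes, so nothing is executed on any device.

```python
import re

import jax
import jax.numpy as jnp


def f(x):
    # Toy function (hypothetical): an elementwise op, a multiply and a reduction.
    return jnp.sum(jnp.tanh(x) * 2.0)


# Describe the input by shape and dtype only; no array is materialized.
x = jax.ShapeDtypeStruct((8,), jnp.float32)

# Lower the jitted function without compiling it; as_text() returns the
# StableHLO module that JAX would hand to the backend's PJRT plugin.
stablehlo_text = jax.jit(f).lower(x).as_text()
print(stablehlo_text)

# Collect the distinct StableHLO op names that appear in the module.
ops = sorted(set(re.findall(r"stablehlo\.[a-z_][a-z0-9_]*", stablehlo_text)))
print(ops)
```

Comparing such an op list against the operations jax-mps implements is one way to judge in advance whether a program is likely to run on it.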
Should you care?
- It’s still early days for the jax-mps project. Many StableHLO operations are not yet implemented, and the reported speed-up over the CPU backend is fairly modest: roughly 4x.
- There is a real risk that jax-mps could ultimately go the way of jax-metal, but at least it supports more recent versions of macOS and jaxlib.
- For now, it may be best to think of jax-mps as a way to potentially speed up your JAX code on Apple hardware without explicitly porting it to MLX, provided the StableHLO operations it needs are already supported.
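Because the gain depends on the workload, it can be worth timing your own code against the CPU backend rather than relying on the reported figure. Below is a minimal benchmarking sketch, assuming jax-mps is installed and supports the ops involved (matrix multiplication and tanh); the workload matmul_chain, the helper time_on, and the matrix size are hypothetical choices for illustration. Inputs are built with NumPy on the host and moved with jax.device_put, so no random-number ops have to run on the Metal device.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def matmul_chain(a, b):
    # Illustrative, matmul-heavy workload (a hypothetical choice, not from jax-mps).
    return jnp.tanh(a @ b) @ b


def time_on(device, n=512, repeats=10):
    """Return the mean wall-clock seconds per call of the jitted workload on `device`."""
    rng = np.random.default_rng(0)
    # Build inputs on the host, then commit them to the target device;
    # jit then runs on the device its committed inputs live on.
    a = jax.device_put(rng.standard_normal((n, n), dtype=np.float32), device)
    b = jax.device_put(rng.standard_normal((n, n), dtype=np.float32), device)
    fn = jax.jit(matmul_chain)
    fn(a, b).block_until_ready()  # warm-up call, so compilation is not timed
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously; block so the actual computation is timed.
        fn(a, b).block_until_ready()
    return (time.perf_counter() - start) / repeats


cpu = jax.devices("cpu")[0]
default = jax.devices()[0]  # expected to be the 'mps' device when jax-mps is active
for device in (cpu, default):
    print(f"{device.platform}: {time_on(device) * 1e3:.2f} ms per call")
```

Small or purely elementwise workloads may see little or no gain, since per-call dispatch overhead can dominate; varying n gives a rough picture of where the crossover lies on a given machine.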
Install instructions
- Install uv, then create a virtual environment and install jax-mps:
uv venv jax-metal --python 3.13
source jax-metal/bin/activate
uv pip install "jaxlib==0.9.0" jax numpy wheel
uv pip install jax-mps
python -c 'import jax; print(jax.devices()); print(jax.numpy.arange(10))'
If all goes well, you should see output along the following lines:
WARNING:2026-04-12 16:44:41,122:jax._src.xla_bridge:905: Platform 'mps' is experimental and not all JAX functionality may be correctly supported!
[0 1 2 3 4 5 6 7 8 9]
Note that only jaxlib 0.9.0 is supported out of the box at this time.
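Given the version pin, a small sanity check can help when something goes wrong after installation. The sketch below reads the installed jaxlib version via importlib.metadata and checks whether an 'mps' device is visible; the helper check_setup and the assumption that jax-mps devices report the platform name "mps" (inferred from the warning shown above) are illustrative, not part of jax-mps itself.

```python
from importlib.metadata import version

import jax

REQUIRED_JAXLIB = "0.9.0"  # the only jaxlib version supported out of the box


def check_setup():
    """Report the installed jaxlib version and whether an 'mps' device is visible."""
    jaxlib_version = version("jaxlib")
    # Assumption: jax-mps registers its devices under the platform name 'mps',
    # matching the "Platform 'mps' is experimental" warning.
    has_mps = any(d.platform == "mps" for d in jax.devices())
    return {
        "jaxlib": jaxlib_version,
        "jaxlib_ok": jaxlib_version == REQUIRED_JAXLIB,
        "mps_available": has_mps,
    }


print(check_setup())
```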