Training On Metal

TL;DR

  • Using the nightly PyTorch build, I am no longer getting CPU fallback warnings due to missing kernels when training nanoGPT on Metal.
  • However, on my MacBook at least, training on Metal is significantly slower with torch.compile than without it (though still faster than training on the CPU).
  • For the time being, if you want first-class Metal support for training models from scratch, MLX (Apple Silicon-only) and tinygrad (cross-platform) seem to be your best options.

Background

  • Apple Silicon’s unified memory architecture makes the Mac a potentially attractive platform for ML researchers, particularly when it comes to models too large to fit into VRAM on a single consumer NVIDIA card.
  • To take advantage of this, though, you need an ML framework with good Metal (MPS) support, ideally one capable of fusing kernels and optimizing memory loads through JIT compilation (or in some other way).
  • A Metal backend was first added to PyTorch back in 2022. While much progress has been made since then, MPS op coverage is still incomplete.
    • For example, if you try to run torch.svd on MPS, you get this warning: “The operator aten::linalg_svd is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications.” (See the snippet after this list.)
    • As of now, over 70 ops are tagged with either “To triage” or “To be implemented”.
    • There is a GitHub issue you can comment on to drive prioritization.
  • Apple has recently released its own ML framework, called MLX, which supports not only MPS but also Neural Engine (ANE).
    • Unlike Core ML, which is primarily geared towards inference and fine-tuning, MLX is intended for training models from scratch in a flexible way, similar to PyTorch or JAX.
    • There are no tools for automatically converting PyTorch (or JAX) models to MLX the way you can convert models to Core ML with coremltools, though.
    • You can certainly port the training loop manually, but it’s non-trivial. When I tried to get Gemini CLI to port nanoGPT to MLX, it kept getting stuck, even after much prodding.
  • JAX has historically prioritized TPUs when it comes to non-NVIDIA accelerators, which makes sense given its Google lineage. While support for MPS in JAX is in the works, it’s still experimental at this point.
  • Then there’s tinygrad, which aims to provide first-class support for MPS out of the box. However, as with MLX, there is no way to automatically convert PyTorch models to tinygrad. A number of popular models have been ported, though. You can see what training looks like here.
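
To make the fallback behavior concrete, here is a minimal sketch against the PyTorch MPS backend. It assumes a recent PyTorch build on Apple Silicon; an op without an MPS kernel raises NotImplementedError by default, and falls back to the CPU with the warning quoted above when the PYTORCH_ENABLE_MPS_FALLBACK=1 environment variable is set.

import torch

# Check that this build includes the MPS backend and that a Metal device is available.
print(torch.backends.mps.is_built())      # True if PyTorch was compiled with MPS support
print(torch.backends.mps.is_available())  # True if a Metal-capable GPU can actually be used

x = torch.randn(64, 64, device="mps")
y = x @ x  # ops with MPS kernels run on the GPU

# An op without an MPS kernel (linalg_svd at the time of writing) raises
# NotImplementedError, unless PYTORCH_ENABLE_MPS_FALLBACK=1 is set, in which
# case it prints the warning quoted above and runs on the CPU instead.
u, s, vh = torch.linalg.svd(x)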

nanoGPT as a training benchmark

  • nanoGPT is Andrej Karpathy’s from-scratch implementation of GPT-2 in PyTorch.
  • There are two main drivers, train.py for training and sample.py for inference.
  • The default is to use NVIDIA GPUs and JIT-compile the model into optimized kernels via torch.compile, but you can override this using the --device and --compile command-line switches, respectively (see the sketch after this list).
  • When nanoGPT was first released (in 2022), MPS support in PyTorch was still nascent. Training on the Mac was limited to the CPU in practice, so much so that the quick start section of README.md branched on “I have a GPU” vs “I only have a macbook”.
  • This has steadily improved over time, though, as evidenced by this GitHub issue.
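
At a high level, the device/compile plumbing amounts to moving the model to the chosen device and conditionally wrapping it in torch.compile. The following is a simplified sketch of that pattern, not nanoGPT’s actual code; the tiny nn.Sequential model is a stand-in for its GPT class.

import torch
import torch.nn as nn

# The equivalents of nanoGPT's --device and --compile switches.
device = "mps" if torch.backends.mps.is_available() else "cpu"
use_compile = False

# Stand-in model; nanoGPT builds a GPT module here instead.
model = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64)).to(device)

if use_compile:
    # torch.compile JIT-compiles the model, fusing kernels where the backend allows it.
    model = torch.compile(model)

# One dummy training step to show the model running on the selected device.
x = torch.randn(8, 64, device=device)
loss = model(x).square().mean()
loss.backward()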

Training nanoGPT on the Mac

Setup

This assumes you have uv installed:

  • git clone https://github.com/karpathy/nanoGPT.git
  • cd nanoGPT
  • uv venv
  • uv pip install numpy transformers datasets tiktoken wandb tqdm
  • uv pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
  • uv run data/shakespeare_char/prepare.py
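
To confirm that the nightly wheel was picked up and that the MPS backend is visible from inside the venv (the exact version string will differ on your machine):

uv run python -c "import torch; print(torch.__version__, torch.backends.mps.is_available())"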

Metal training without torch.compile

/usr/bin/time uv run train.py config/train_shakespeare_char.py --device=mps --compile=False

...

step 0: train loss 4.2874, val loss 4.2823
iter 0: loss 4.2639, time 25857.56ms, mfu -100.00%
iter 10: loss 3.1459, time 220.50ms, mfu 1.69%
iter 20: loss 2.7319, time 221.44ms, mfu 1.69%
iter 30: loss 2.6226, time 221.30ms, mfu 1.69%
iter 40: loss 2.5756, time 220.74ms, mfu 1.69%
iter 50: loss 2.5239, time 220.24ms, mfu 1.69%

...

iter 4990: loss 0.8203, time 276.23ms, mfu 1.33%
step 5000: train loss 0.6166, val loss 1.7099
iter 5000: loss 0.8117, time 33010.34ms, mfu 1.19%
     3588.66 real        90.18 user        32.03 sys

Metal training with torch.compile

/usr/bin/time uv run train.py config/train_shakespeare_char.py --device=mps --compile=True

...

step 0: train loss 4.2874, val loss 4.2823
iter 0: loss 4.2672, time 85292.53ms, mfu -100.00%
iter 10: loss 3.1471, time 304.79ms, mfu 1.22%
iter 20: loss 2.7730, time 305.33ms, mfu 1.22%
iter 30: loss 2.6538, time 306.39ms, mfu 1.22%
iter 40: loss 2.5960, time 306.01ms, mfu 1.22%
iter 50: loss 2.5479, time 305.41ms, mfu 1.22%

...

iter 4990: loss 0.8286, time 329.89ms, mfu 1.11%
step 5000: train loss 0.6221, val loss 1.7106
iter 5000: loss 0.8242, time 1255915.73ms, mfu 1.00%
    39039.74 real        72.21 user        21.28 sys

CPU training

uv run train.py config/train_shakespeare_char.py --device=cpu --compile=False

...

step 0: train loss 4.2874, val loss 4.2823
iter 0: loss 4.2655, time 119984.47ms, mfu -100.00%
iter 10: loss 3.1338, time 1973.80ms, mfu 0.19%
iter 20: loss 2.7556, time 1984.34ms, mfu 0.19%
iter 30: loss 2.6094, time 1988.45ms, mfu 0.19%
iter 40: loss 2.5535, time 1983.75ms, mfu 0.19%
iter 50: loss 2.5237, time 2003.05ms, mfu 0.19%

Observations

  • I didn’t run CPU training to completion, but based on the first 50 iterations, training on Metal without torch.compile is roughly 9 times faster than training on the CPU (221ms vs 1986ms per iteration).
  • With torch.compile, it’s only 6.5 times faster (306ms per iteration). In other words, turning on torch.compile slows training down by a factor of 1.38.
  • You can use asitop to confirm that GPU usage is at 100% when training on MPS, whether or not torch.compile is used (see the note after this list).
  • CPU+GPU+ANE is stable at 100% without torch.compile. With torch.compile, it drops slightly below that at times.
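
asitop is a regular Python package that wraps Apple’s powermetrics utility, which requires root. Assuming a plain pip install outside the project venv, running it looks roughly like this:

pip install asitop
sudo asitop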