JAX

JAX is often described as "NumPy on steroids": it combines Autograd-style automatic differentiation with XLA compilation. As of 2025, Flax NNX (a PyTorch-style, object-oriented API) is becoming the standard way to build neural networks on top of it.

When to Use

  • TPU Training: JAX runs natively on Google TPUs.
  • Research: When you need higher-order derivatives (say, a 10th-order derivative) or other unusual math; see the sketch after this list.
  • Massive Scale: Google DeepMind and other large labs use JAX to train frontier models.
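
As a quick illustration of the higher-order derivative use case, a minimal sketch: jax.grad can simply be composed with itself. Here it is applied four times to jnp.sin, whose fourth derivative is again sin(x).

```python
import jax
import jax.numpy as jnp

# Compose jax.grad with itself to obtain higher-order derivatives.
d4_sin = jnp.sin
for _ in range(4):
    d4_sin = jax.grad(d4_sin)

# d^4/dx^4 sin(x) = sin(x), so both printed values should match.
x = 1.0
print(d4_sin(x), jnp.sin(x))
```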

Core Concepts

Functional Transformations

grad() (automatic differentiation), jit() (XLA compilation), vmap() (automatic vectorization), and pmap() (multi-device parallelism). These are composable: you can vmap a grad, then jit the result.
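
A minimal sketch of that composition on a toy squared-error loss (the loss function and shapes here are arbitrary, chosen only to show the transformations):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error for a single example.
    return (jnp.dot(w, x) - y) ** 2

# grad: gradient with respect to the first argument (w).
grad_loss = jax.grad(loss)

# vmap: vectorize the per-example gradient over a batch of (x, y) pairs.
batched_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the whole pipeline with XLA.
fast_batched_grads = jax.jit(batched_grads)

w = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)      # batch of 4 examples
ys = jnp.array([1.0, 2.0, 3.0, 4.0])
print(fast_batched_grads(w, xs, ys).shape)  # (4, 3): one gradient per example
```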

Flax (NNX)

Flax is JAX's neural network library. The NNX API introduces mutable, object-oriented state so that model code feels like PyTorch while remaining compatible with JAX's transformations.
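
A minimal NNX sketch, assuming a recent Flax release that ships the flax.nnx module (the layer sizes and model name here are arbitrary):

```python
import jax
import jax.numpy as jnp
from flax import nnx

class MLP(nnx.Module):
    def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
        # Layers are plain attributes; their parameters live as mutable state on the object.
        self.linear1 = nnx.Linear(din, dhidden, rngs=rngs)
        self.linear2 = nnx.Linear(dhidden, dout, rngs=rngs)

    def __call__(self, x):
        return self.linear2(jax.nn.relu(self.linear1(x)))

model = MLP(4, 32, 2, rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 4)))  # eager, PyTorch-style call
print(y.shape)               # (1, 2)
```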
