jax


JAX - Autograd and XLA (Accelerated Linear Algebra)

JAX is a Python framework that combines a NumPy-like API with a system of composable function transformations: grad (automatic differentiation), jit (just-in-time compilation with XLA), vmap (automatic vectorization), and pmap (parallelization across devices).
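Because the transformations compose, they can be stacked on one ordinary Python function. The sketch below uses a toy mean-squared-error loss of our own choosing (not from the original skill) and chains jax.grad, jax.jit, and jax.vmap to get compiled gradients for a batch of weight vectors. pmap is left out because it needs multiple devices to be meaningful.

```python
import jax
import jax.numpy as jnp

# Toy loss (illustrative): mean squared error of a linear model x @ w against y.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# grad: derivative of loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# jit: compile the gradient function with XLA.
fast_grad = jax.jit(grad_loss)

# vmap: map over axis 0 of w (a batch of weight vectors),
# passing x and y through unchanged (in_axes=None).
batched_grad = jax.vmap(fast_grad, in_axes=(0, None, None))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
ws = jnp.zeros((3, 2))  # three candidate weight vectors

grads = batched_grad(ws, x, y)
print(grads.shape)  # (3, 2)
```

With w = 0 the gradient is -x^T y = [-7, -10] for every row, which is a quick sanity check that the batched result matches the single-example math.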

When to Use

  • High-performance scientific simulations requiring GPU/TPU acceleration.
  • Custom machine learning research where PyTorch/TensorFlow abstractions are too restrictive.
  • Computing Jacobians and higher-order derivatives such as Hessians for optimization.
  • Physics-informed machine learning and differentiable simulations.
  • Automatic vectorization of functions written for a single example, without manual batching.
  • Running the same code on CPU, GPU, and TPU without changes.
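As a rough illustration of the derivative use case, the sketch below takes a small hand-picked function (an assumption for demonstration only) and computes its gradient, Hessian, and the Jacobian of its gradient with jax.grad, jax.hessian, and jax.jacfwd. The same script runs unchanged on whatever backend jax.devices() reports.

```python
import jax
import jax.numpy as jnp

# Illustrative scalar function: f(v) = v0^3 + v1^3 + v0 * v1.
def f(v):
    return jnp.sum(v ** 3) + v[0] * v[1]

v = jnp.array([1.0, 2.0])

# First order: gradient [3*v0^2 + v1, 3*v1^2 + v0] = [5, 13].
g = jax.grad(f)(v)

# Second order: Hessian [[6*v0, 1], [1, 6*v1]] = [[6, 1], [1, 12]].
H = jax.hessian(f)(v)

# Jacobian of the (vector-valued) gradient via forward mode;
# for a scalar f this equals the Hessian.
J = jax.jacfwd(jax.grad(f))(v)

# Lists the devices JAX found (CPU, GPU, or TPU); no code changes needed.
print(jax.devices())
```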

Reference Documentation

Official docs: https://jax.readthedocs.io/
GitHub: https://github.com/google/jax
Search patterns: jax.numpy, jax.jit, jax.grad, jax.vmap, jax.random
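Of the search patterns above, jax.random behaves most differently from NumPy: random state is an explicit key passed to every call rather than a hidden global. A minimal sketch of the usual split-then-sample pattern, assuming the long-standing jax.random.PRNGKey constructor (newer releases also offer jax.random.key):

```python
import jax

# Explicit PRNG state: create a key from a seed.
key = jax.random.PRNGKey(0)

# Split rather than reuse, so each draw consumes a fresh subkey.
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, shape=(3,))

# The same key always reproduces the same numbers.
same = jax.random.normal(subkey, shape=(3,))

# A different key gives different numbers.
other = jax.random.normal(key, shape=(3,))
```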

Related skills

More from tondevrel/scientific-agent-skills

First Seen
Feb 8, 2026