PyTorch
Overview
PyTorch is a deep learning framework for building and training neural networks with dynamic computation graphs and automatic differentiation. It provides tensor operations with GPU acceleration, nn.Module for defining architectures, DataLoader for efficient data loading, mixed precision training for performance, and export tools (TorchScript, ONNX) for production deployment.
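As a minimal sketch of the dynamic graph and automatic differentiation mentioned above, the snippet below differentiates a small expression. The tensor values and variable names are illustrative assumptions, not part of the skill itself.

```python
import torch

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# requires_grad=True asks autograd to record operations on this tensor.
x = torch.tensor([1.0, 2.0, 3.0], device=device, requires_grad=True)

# The graph is built on the fly as ordinary Python code executes.
y = (x ** 2).sum()

# Backpropagate through the recorded graph: d(sum(x^2))/dx = 2x.
y.backward()
grad = x.grad.cpu()
```

Because the graph is rebuilt on every forward pass, ordinary Python control flow (loops, conditionals) can change the computation from one call to the next.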
Instructions
- When defining models, subclass `nn.Module` with `__init__` for layers and `forward` for computation, using `nn.Sequential` for simple stacks and custom forward logic for complex architectures.
- When training, implement the standard loop: forward pass, loss computation, `loss.backward()`, `optimizer.step()`, `optimizer.zero_grad()`. For stability, clip gradients with `clip_grad_norm_` after `loss.backward()` and before `optimizer.step()`.
- When loading data, subclass `Dataset` with `__len__` and `__getitem__`, then use `DataLoader` with `num_workers=4` and `pin_memory=True` as a starting point for GPU training throughput, tuning `num_workers` to the available CPU cores.
- When optimizing performance, use `torch.compile(model)` on PyTorch 2.0+ (often a 20-50% speedup, depending on model and hardware), mixed precision with `torch.amp.autocast(device_type="cuda")` (up to roughly half the activation memory and up to twice the throughput on GPUs with fast half-precision support), and `DistributedDataParallel` for multi-GPU training.
- When doing transfer learning, load pretrained models from `torchvision.models` or Hugging Face, freeze the backbone, and replace the classifier head for your task.
- When deploying, use `torch.export.export()` or `torch.jit.trace()` for production, `torch.onnx.export()` for cross-framework compatibility, and `torch.quantization` (`torch.ao.quantization` in recent releases) for INT8 inference speedup.
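The model, data loading, and training-loop instructions above can be sketched together as follows. This is an illustration under stated assumptions rather than a reference implementation: `ToyDataset`, `TinyClassifier`, and `train_one_epoch` are hypothetical names, the data is random, and `num_workers=0` is used so the sketch runs anywhere (a real GPU job would more likely start from `num_workers=4`).

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset: random feature vectors with binary labels."""

    def __init__(self, n_samples: int = 64, n_features: int = 8) -> None:
        gen = torch.Generator().manual_seed(0)
        self.x = torch.randn(n_samples, n_features, generator=gen)
        # Label is 1 when the feature sum is positive, else 0.
        self.y = (self.x.sum(dim=1) > 0).long()

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx: int):
        return self.x[idx], self.y[idx]


class TinyClassifier(nn.Module):
    """Hypothetical model: nn.Sequential covers a simple layer stack."""

    def __init__(self, n_features: int = 8, n_classes: int = 2) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_features, 16),
            nn.ReLU(),
            nn.Linear(16, n_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def train_one_epoch(model, loader, optimizer, loss_fn, device, max_norm=1.0):
    """Run one pass over the loader and return the mean per-sample loss."""
    model.train()
    total = 0.0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        loss = loss_fn(model(inputs), targets)   # forward pass + loss
        loss.backward()                          # compute gradients
        # Clip after backward() and before step() so the update is bounded.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        optimizer.step()                         # apply the update
        optimizer.zero_grad()                    # reset for the next batch
        total += loss.item() * inputs.size(0)
    return total / len(loader.dataset)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
model = TinyClassifier().to(device)
loader = DataLoader(
    ToyDataset(),
    batch_size=16,
    shuffle=True,
    num_workers=0,                        # assumption: portable default
    pin_memory=(device.type == "cuda"),   # pinning only helps GPU transfers
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()
losses = [
    train_one_epoch(model, loader, optimizer, loss_fn, device) for _ in range(5)
]
```

Calling `optimizer.zero_grad()` at the end of each iteration follows the order listed above; calling it at the start of the iteration instead is equally common and behaves the same way.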
Examples
Example 1: Fine-tune a vision model for image classification
Related skills