Overview
Direct Answer
Gradient checkpointing is a memory optimisation technique that reduces peak GPU memory consumption during neural network training by selectively discarding intermediate activations during the forward pass and recomputing them on-demand during backpropagation. This approach trades increased computational cost for substantially lower memory requirements, enabling training of larger models or larger batch sizes on fixed hardware.
How It Works
During the forward pass, only the activations at designated checkpoints (typically the inputs to selected layers or blocks) are stored; the intermediate values inside each segment are discarded. During backpropagation, each segment's forward computation is re-executed from its checkpoint to regenerate the discarded activations needed for gradient calculation. Applied to deep residual or transformer architectures, this selective recomputation reduces activation memory from linear to sub-linear in network depth: with a checkpoint every √n layers in an n-layer network, memory grows as O(√n), at the cost of roughly one additional forward pass.
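The mechanism can be illustrated without a framework. The sketch below assumes a toy network of scalar tanh layers whose loss is simply the final output (both illustrative choices, not part of any library); the function names are hypothetical. It compares ordinary backpropagation, which keeps every layer input, with a checkpointed version that keeps only every `segment`-th layer input and re-runs each segment's forward pass during the backward sweep.

```python
import math
from typing import List, Tuple


def layer(w: float, x: float) -> float:
    """Toy scalar 'layer' (an assumption for illustration): y = tanh(w * x)."""
    return math.tanh(w * x)


def layer_backward(w: float, x: float, grad_out: float) -> Tuple[float, float]:
    """Given the layer input x and dL/dy, return (dL/dx, dL/dw)."""
    y = math.tanh(w * x)
    local = grad_out * (1.0 - y * y)  # dL/d(w*x), since d tanh(u)/du = 1 - tanh(u)^2
    return local * w, local * x


def backprop_full(ws: List[float], x0: float) -> Tuple[List[float], int]:
    """Ordinary backprop: store the input of every layer during the forward pass."""
    inputs = []
    x = x0
    for w in ws:
        inputs.append(x)
        x = layer(w, x)
    grad = 1.0  # loss L is the final output, so dL/dx_n = 1
    dws = [0.0] * len(ws)
    for i in reversed(range(len(ws))):
        grad, dws[i] = layer_backward(ws[i], inputs[i], grad)
    return dws, len(inputs)  # number of activations held for the backward pass


def backprop_checkpointed(ws: List[float], x0: float, segment: int) -> Tuple[List[float], int]:
    """Checkpointed backprop: keep every `segment`-th layer input, recompute the rest."""
    checkpoints = {}  # layer index -> stored input of that layer
    x = x0
    for i, w in enumerate(ws):
        if i % segment == 0:
            checkpoints[i] = x
        x = layer(w, x)  # values inside a segment are not kept
    grad = 1.0
    dws = [0.0] * len(ws)
    peak = len(checkpoints)
    for start in sorted(checkpoints, reverse=True):  # process the last segment first
        end = min(start + segment, len(ws))
        # Re-run the forward pass for this segment, starting from its checkpoint.
        seg_inputs = [checkpoints[start]]
        for i in range(start, end - 1):
            seg_inputs.append(layer(ws[i], seg_inputs[-1]))
        # Checkpoints are never freed in this sketch, so this count is conservative;
        # seg_inputs[0] is the checkpoint itself, hence the "- 1".
        peak = max(peak, len(checkpoints) + len(seg_inputs) - 1)
        for i in reversed(range(start, end)):
            grad, dws[i] = layer_backward(ws[i], seg_inputs[i - start], grad)
    return dws, peak


if __name__ == "__main__":
    ws = [0.8 + 0.05 * (k % 5) for k in range(16)]
    full_grads, full_peak = backprop_full(ws, 0.7)
    ckpt_grads, ckpt_peak = backprop_checkpointed(ws, 0.7, segment=4)
    print(full_peak, ckpt_peak)  # 16 7
    print(all(math.isclose(a, b) for a, b in zip(full_grads, ckpt_grads)))  # True
```

Both versions produce the same gradients. With 16 layers and segment=4, the checkpointed run holds at most 7 activations instead of 16, paid for by re-running 12 layer evaluations. Frameworks perform this bookkeeping automatically; PyTorch, for example, provides torch.utils.checkpoint.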
Why It Matters
Training state-of-the-art large language models and vision transformers often requires more activation memory than the available GPUs provide. Checkpointing lets organisations train larger models, or use larger batch sizes, within existing infrastructure budgets, avoiding costly hardware upgrades. This is particularly valuable in resource-constrained environments.
Common Applications
The technique is widely employed when training large language models and other transformers, vision transformers, and deep convolutional networks where memory is the limiting factor. It is built into mainstream deep learning frameworks (PyTorch, for example, exposes it as torch.utils.checkpoint) and is standard practice in large-scale natural language processing and computer vision research.
Key Considerations
The computational overhead typically ranges from 20–50% additional computation, because the forward pass of every checkpointed segment runs twice; the optimisation is therefore most effective when memory, rather than compute, is the critical bottleneck. Checkpoint granularity must be chosen carefully: checkpointing very frequently saves little memory, while very long segments must be recomputed and held in memory all at once during backpropagation. Suboptimal choices can degrade wall-clock training speed more than the memory saving justifies.
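The granularity trade-off can be made concrete with a rough cost model. This sketch assumes every layer has activations of equal size and equal forward cost, and that the input of every `segment`-th layer is checkpointed while one segment at a time is rebuilt during backpropagation; real networks are rarely this uniform. The function names are hypothetical.

```python
import math
from typing import Tuple


def checkpoint_cost(n_layers: int, segment: int) -> Tuple[int, int]:
    """Return (peak stored activations, recomputed layer evaluations) under the toy model."""
    n_segments = math.ceil(n_layers / segment)
    longest = min(segment, n_layers)
    # All checkpoints stay resident while one segment is rebuilt;
    # the segment's first input is already a checkpoint, hence the "- 1".
    peak = n_segments + longest - 1
    # Each segment re-runs every layer except its last one.
    recomputed = n_layers - n_segments
    return peak, recomputed


def best_segment(n_layers: int) -> int:
    """Segment length minimising peak memory (ties go to the shorter segment)."""
    return min(range(1, n_layers + 1), key=lambda k: checkpoint_cost(n_layers, k)[0])


if __name__ == "__main__":
    n = 16
    for k in (1, 2, 4, 8, 16):
        peak, extra = checkpoint_cost(n, k)
        print(f"segment={k:2d}  peak={peak:2d}  recomputed={extra:2d}/{n}")
    print("best segment:", best_segment(n))
```

Under these assumptions the peak bottoms out near √n (segment = 4 for 16 layers, storing 7 activations instead of 16), while the recomputation always stays below one full extra forward pass. Checkpointing every layer (segment = 1) recomputes nothing but saves nothing; a single giant segment recomputes almost everything and also saves nothing.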
More in Deep Learning
Graph Neural Network
Architectures: A neural network designed to operate on graph-structured data, learning representations of nodes, edges, and entire graphs.
Attention Head
Training & Optimisation: An individual attention computation within a multi-head attention layer that learns to focus on different aspects of the input, with outputs concatenated for richer representations.
Contrastive Learning
Architectures: A self-supervised learning approach that trains models by comparing similar and dissimilar pairs of data representations.
Pooling Layer
Architectures: A neural network layer that reduces spatial dimensions by aggregating values, commonly using max or average operations.
Prefix Tuning
Language Models: A parameter-efficient method that prepends trainable continuous vectors to the input of each transformer layer, guiding model behaviour without altering base parameters.
ReLU
Training & Optimisation: Rectified Linear Unit, an activation function that outputs the input directly if positive and zero otherwise.
Softmax Function
Training & Optimisation: An activation function that converts a vector of numbers into a probability distribution, commonly used in multi-class classification.
Dropout
Training & Optimisation: A regularisation technique that randomly deactivates neurons during training to prevent co-adaptation and reduce overfitting.