What is torch.distributed.fsdp?
torch.distributed.fsdp (Fully Sharded Data Parallel) is PyTorch's distributed training strategy for reducing per-GPU memory usage by sharding model parameters, gradients, and optimizer states across multiple GPUs. Unlike traditional DDP (DistributedDataParallel), which replicates the entire model on every GPU, FSDP:
- Splits model parameters across devices
- Reduces memory footprint per GPU
- Maintains training efficiency through optimized communication
- Supports giant models that don’t fit on single GPUs
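To make the contrast concrete, here is a minimal sketch (using MyLargeModel from the examples below as a stand-in for your own nn.Module) showing that the two wrappers are drop-in alternatives around the same module:

```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group(backend="nccl")  # one process per GPU, e.g. launched via torchrun

# DDP: every rank keeps a full replica of the model's parameters
# ddp_model = DDP(MyLargeModel().cuda())

# FSDP: each rank stores only a shard and all-gathers parameters on demand
fsdp_model = FSDP(MyLargeModel().cuda())
```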
Key Benefits
- Enables training models many times larger than a single GPU's memory
- Automatic mixed precision support
- Flexible sharding strategies
- Native integration with torch.distributed (same process groups and launchers as DDP)
Code Examples
1. Basic FSDP Setup
```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, MixedPrecision

# Initialize the distributed environment (one process per GPU, e.g. launched via torchrun)
dist.init_process_group(backend="nccl")

# Wrap the model with FSDP
model = FSDP(
    MyLargeModel().cuda(),
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # shard params, grads, optimizer states
    mixed_precision=MixedPrecision(param_dtype=torch.float16),  # mixed_precision takes a policy, not a bare dtype
)

# Normal training loop
optimizer = torch.optim.Adam(model.parameters())
for inputs, labels in dataloader:
    outputs = model(inputs.cuda())
    loss = criterion(outputs, labels.cuda())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
2. Custom Sharding Strategies
```python
from torch.distributed.fsdp import CPUOffload, ShardingStrategy

# Different sharding approaches
strategies = {
    "full_shard": ShardingStrategy.FULL_SHARD,        # default (most memory efficient)
    "shard_grad_op": ShardingStrategy.SHARD_GRAD_OP,  # shard only gradients and optimizer states
    "no_shard": ShardingStrategy.NO_SHARD,            # like DDP, but with FSDP infrastructure
}

model = FSDP(
    MyModel(),
    sharding_strategy=strategies["full_shard"],
    cpu_offload=CPUOffload(offload_params=True),  # cpu_offload takes a CPUOffload config, not a bool
)
```
3. FSDP with Activation Checkpointing
```python
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

# Apply activation checkpointing to a specific layer before wrapping with FSDP
# (CPU offload of activations is available separately via offload_wrapper in the same module)
checkpointed_layer = checkpoint_wrapper(MyLargeLayer())

model = FSDP(
    torch.nn.Sequential(
        checkpointed_layer,
        MyOtherLayers(),
    )
)
```
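For larger models, a common pattern is to apply checkpointing to many submodules at once with apply_activation_checkpointing from the same module. The sketch below assumes the FSDP import from example 1 and uses TransformerBlock as a placeholder for whatever layer class your model repeats:

```python
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

model = FSDP(MyLargeModel().cuda())

# Wrap every TransformerBlock submodule in a checkpoint_wrapper, in place.
# TransformerBlock is a placeholder for your own repeated layer class.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda submodule: isinstance(submodule, TransformerBlock),
)
```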
Common Methods & Configurations
| Method/Config | Purpose |
| --- | --- |
| `FULL_SHARD` | Shards parameters, gradients, and optimizer states (most memory efficient) |
| `SHARD_GRAD_OP` | Shards only gradients and optimizer states; parameters stay gathered between forward and backward |
| `NO_SHARD` | No sharding (DDP-like behavior) |
| `cpu_offload` | Offloads parameters to CPU when not in use |
| `mixed_precision` | Enables mixed-precision training via a `MixedPrecision` policy |
| `limit_all_gathers` | Rate-limits all-gathers to cap memory used by prefetched parameters |
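As a sketch of how these options combine on one wrapper (MyModel is a placeholder module, and the fp16 dtype choices are illustrative, not required):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy

model = FSDP(
    MyModel().cuda(),                               # MyModel is a placeholder for your module
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # shard params, grads, optimizer states
    cpu_offload=CPUOffload(offload_params=True),    # keep sharded params on CPU when idle
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,                  # compute/communicate params in fp16
        reduce_dtype=torch.float16,                 # reduce gradients in fp16
        buffer_dtype=torch.float16,                 # keep buffers in fp16
    ),
    limit_all_gathers=True,                         # throttle prefetch all-gathers to cap memory
)
```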
Errors & Debugging Tips
Common Issues
- CUDA Out of Memory – can still occur with FSDP, often because the wrapping/sharding setup is too coarse or activation memory dominates
- Deadlocks – from improper process synchronization (all ranks must reach the same collective calls)
- Slow Training – communication overhead dominating compute, e.g. over slow interconnects
- Checkpointing Problems – when saving/loading FSDP models, the sharded state dict needs explicit handling (see the sketch below)
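For the last point, here is a sketch of gathering a full, unsharded state dict for saving, assuming a recent PyTorch version and that model is an FSDP-wrapped module built as in the examples above ("checkpoint.pt" is only an example path):

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

# Gather a full (unsharded) state dict, materialized on CPU and only on rank 0
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state = model.state_dict()

if dist.get_rank() == 0:
    torch.save(cpu_state, "checkpoint.pt")
```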