PyTorch Elastic Distributed Training

What is torch.distributed.fsdp?

torch.distributed.fsdp (Fully Sharded Data Parallel) is PyTorch’s advanced distributed training strategy that reduces per-GPU memory usage by sharding model parameters, gradients, and optimizer states across multiple GPUs. Unlike traditional DDP (DistributedDataParallel), which replicates the entire model on every GPU, FSDP:

  • Splits model parameters across devices
  • Reduces memory footprint per GPU
  • Maintains training efficiency through optimized communication
  • Supports giant models that don’t fit on single GPUs
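
To make the contrast concrete, here is a minimal sketch of how the two wrappers are applied (MyLargeModel is a placeholder, and a process group is assumed to already be initialized, as in the setup example further down):

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# DDP: every rank keeps a full copy of all parameters, gradients, and optimizer state
ddp_model = DDP(MyLargeModel().cuda())

# FSDP: each rank keeps only its shard; full parameters are gathered on the fly
# for the forward/backward pass of each wrapped unit, then freed again
fsdp_model = FSDP(MyLargeModel().cuda())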

Key Benefits

  • Trains models many times larger than a single GPU’s memory
  • Automatic mixed precision support
  • Flexible sharding strategies
  • Built on the same torch.distributed infrastructure as DDP

Code Examples

1. Basic FSDP Setup

import os

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# Initialize distributed environment (torchrun sets RANK, WORLD_SIZE, LOCAL_RANK)
torch.distributed.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Wrap model with FSDP
model = FSDP(
    MyLargeModel().cuda(),
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # Shard params, grads, optimizer states
    mixed_precision=MixedPrecision(param_dtype=torch.float16),  # expects a MixedPrecision config, not a raw dtype
)

# Normal training loop (MyLargeModel, dataloader, and criterion are assumed to be defined elsewhere)
optimizer = torch.optim.Adam(model.parameters())
for inputs, labels in dataloader:
    outputs = model(inputs.cuda())
    loss = criterion(outputs, labels.cuda())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
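
Note that FSDP, like DDP, expects one process per GPU. The snippet above would typically be launched with torchrun (for example, torchrun --nproc_per_node=4 train.py, where train.py stands in for your training script), which sets the RANK, WORLD_SIZE, and LOCAL_RANK environment variables that init_process_group and the device-selection code rely on.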

2. Custom Sharding Strategies

from torch.distributed.fsdp import CPUOffload

# Different sharding approaches
strategies = {
    "full_shard": ShardingStrategy.FULL_SHARD,  # Default (most memory efficient)
    "shard_grad": ShardingStrategy.SHARD_GRAD_OP,  # Shard only gradients and optimizer states
    "no_shard": ShardingStrategy.NO_SHARD  # Like DDP but with FSDP infrastructure
}

model = FSDP(
    MyModel(),
    sharding_strategy=strategies["full_shard"],
    cpu_offload=CPUOffload(offload_params=True)  # expects a CPUOffload config; offloading params gives additional memory savings
)

3. FSDP with Activation Checkpointing

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)

# Apply activation checkpointing to specific layers
checkpointed_model = checkpoint_wrapper(
    MyLargeLayer(),
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,  # recompute activations in backward instead of storing them
    # Newer releases provide offload_wrapper for moving checkpointed activations to CPU
)

model = FSDP(
    torch.nn.Sequential(
        checkpointed_model,
        MyOtherLayers()
    )
)
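
Instead of wrapping layers one by one, checkpointing can also be applied in place across a whole model with apply_activation_checkpointing. The following is a rough sketch under the assumption that the model’s large blocks are instances of the placeholder class MyLargeLayer (the helper lives under a private module path and its details have shifted between PyTorch releases):

import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

# Wrap with FSDP first, then checkpoint every matching submodule in place
model = FSDP(MyModelWithManyLargeLayers().cuda())  # hypothetical model for illustration

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    check_fn=lambda submodule: isinstance(submodule, MyLargeLayer),  # which modules to checkpoint
)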

Common Methods & Configurations

Method/Config        Purpose
FULL_SHARD           Shards everything (most memory efficient)
SHARD_GRAD_OP        Only shards gradients and optimizer states
NO_SHARD             No sharding (DDP-like behavior)
cpu_offload          Offloads parameters to CPU when not in use
mixed_precision      Enables mixed-precision training via a MixedPrecision config
limit_all_gathers    Rate-limits all-gathers to reduce communication overhead and peak memory
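
Most of these options map directly onto FSDP constructor arguments. A sketch of how they might be combined (reusing the placeholder MyLargeModel; exact defaults vary slightly across PyTorch releases):

import torch
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)

model = FSDP(
    MyLargeModel().cuda(),
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=True),  # park parameter shards in CPU memory when idle
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,   # compute in fp16
        reduce_dtype=torch.float16,  # gradient reduction in fp16
        buffer_dtype=torch.float16,  # buffers (e.g. norm statistics) in fp16
    ),
    limit_all_gathers=True,  # throttle all-gathers to cap peak memory
)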

Errors & Debugging Tips

Common Issues

  1. CUDA Out of Memory – can still occur with FSDP, usually pointing to an incorrect wrapping/sharding setup (e.g., the whole model gathered as a single FSDP unit)
  2. Deadlocks – from improper process synchronization, e.g., ranks taking different code paths around collectives
  3. Slow Training – communication overhead from too many small all-gathers and reduce-scatters
  4. Checkpointing Problems – when saving/loading sharded FSDP models (see the sketch below)
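
For the checkpointing issue in particular, a sharded model cannot simply be torch.save’d from every rank. One common pattern is to gather a full state dict onto rank 0; a minimal sketch (the checkpoint filename is arbitrary, and newer releases also offer torch.distributed.checkpoint for fully sharded saves):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

# Gather a full (unsharded) state dict, materialized on CPU and only on rank 0
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state = model.state_dict()

if dist.get_rank() == 0:
    torch.save(cpu_state, "checkpoint.pt")

For deadlocks, running with the environment variable TORCH_DISTRIBUTED_DEBUG=DETAIL usually helps pinpoint the rank and collective where processes diverge.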
