What is torch.distributed.fsdp.fully_shard?
The fully_shard function is PyTorch's granular, module-level API for applying Fully Sharded Data Parallel (FSDP) to specific model components. Unlike wrapping an entire model with the FullyShardedDataParallel wrapper, fully_shard enables:
- Selective sharding of individual model components
- Mixed parallelism strategies within a single model
- Finer memory control for massive models
- Progressive adoption of sharding techniques
Key Differences from FSDP Wrapping
| Feature | FSDP() Wrapper | fully_shard |
|---|---|---|
| Scope | Entire model | Per-module |
| Flexibility | Less | More |
| Adoption | All-or-nothing | Gradual |
| Use Case | Standard FSDP | Custom parallelism |
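To make the contrast concrete, here is a minimal sketch of the two styles. It assumes torch.distributed is already initialized with one GPU per rank; the tiny nn.Sequential model is purely illustrative, and in older PyTorch releases fully_shard lives under torch.distributed._composable.fsdp rather than torch.distributed.fsdp.

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import fully_shard

def build_model() -> nn.Sequential:
    return nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))

# Wrapper style: a new FSDP object takes ownership of the whole model.
wrapped = FSDP(build_model().cuda())

# Composable style: shard submodules in place, bottom-up, then the root.
# The model keeps its original class, so it still composes with other
# features (activation checkpointing, torch.compile, custom training loops).
model = build_model().cuda()
for layer in model:
    fully_shard(layer)
fully_shard(model)
```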
Code Examples
1. Basic fully_shard Application
```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # FSDP2 composable API

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 4096)
        self.layer2 = nn.Linear(4096, 4096)
        self.layer3 = nn.Linear(4096, 1024)

    def forward(self, x):
        return self.layer3(self.layer2(self.layer1(x)))

model = MyModel().cuda()

# Shard only specific layers: parameters, gradients, and optimizer state
# for layer1 and layer2 are sharded across ranks; layer3 stays unsharded.
fully_shard(model.layer1)
fully_shard(model.layer2)
```
2. Mixed Precision with fully_shard
```python
import torch
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

# Configure per-module precision: fp16 for compute and communication.
# Consider reduce_dtype=torch.float32 if gradient reduction is unstable.
fp16_policy = MixedPrecisionPolicy(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
)

fully_shard(model.attention_layer, mp_policy=fp16_policy)
```
3. Sharding the Heavy Components and the Root
Rather than wrapping the remaining components in DDP, the usual composable pattern is to shard the memory-intensive blocks first and then the root module, which picks up any parameters not already assigned to a child group:
```python
# Shard each memory-intensive block first...
for block in model.gpt_layers:
    fully_shard(block)

# ...then shard the root; parameters not already covered by a child group
# (embeddings, norms, output head) are handled by the root's group.
fully_shard(model)
```
Common Methods & Configurations
| Parameter / Option | Description |
|---|---|
| mesh | DeviceMesh to shard over; a 1-D mesh (the default) fully shards parameters, gradients, and optimizer state across all ranks, while a 2-D mesh gives hybrid sharding (shard within a node, replicate across nodes) |
| reshard_after_forward | Whether to free the unsharded parameters after forward, trading extra all-gathers in backward for lower memory |
| MixedPrecisionPolicy | Per-module mixed precision (param_dtype, reduce_dtype, output_dtype), passed as mp_policy |
| CPUOffloadPolicy | Keeps the sharded parameters and gradients (and hence optimizer state) on CPU, passed as offload_policy |
| ignored_params | Excludes specific parameters from sharding |
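The sketch below combines several of these options on one module. The (2, 8) mesh shape, the bf16/fp32 dtypes, and the stand-in transformer block are assumptions for illustration; MixedPrecisionPolicy and CPUOffloadPolicy are importable from torch.distributed.fsdp in recent releases (earlier ones expose them under torch.distributed._composable.fsdp).

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
    fully_shard,
    MixedPrecisionPolicy,
    CPUOffloadPolicy,
)

# Hybrid sharding: replicate across 2 nodes, shard across 8 GPUs per node
# (the mesh shape must match your world size).
mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("replicate", "shard"))

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # compute / all-gather dtype
    reduce_dtype=torch.float32,   # gradient-reduction dtype
)

# Stand-in for one block of your model.
transformer_block = nn.TransformerEncoderLayer(d_model=1024, nhead=16).cuda()

fully_shard(
    transformer_block,
    mesh=mesh,
    reshard_after_forward=True,          # free unsharded params after forward
    mp_policy=mp_policy,
    offload_policy=CPUOffloadPolicy(),   # keep sharded state on CPU
)
```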
Errors & Debugging Tips
Common Pitfalls
- Incorrect Shard Ordering: apply fully_shard bottom-up, to child modules before the root; sharding the root first puts every parameter into a single group and defeats per-module control (see the sketch after this list)
- Device Mismatches: Mixed CPU/GPU modules
- Communication Deadlocks: From improper process synchronization
- Checkpoint Conflicts: When saving/loading partially sharded models
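A minimal ordering sketch; TinyModel is a made-up stand-in for a real network, and torch.distributed is assumed to be initialized.

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(1000, 256)
        self.layers = nn.ModuleList(nn.Linear(256, 256) for _ in range(4))
        self.head = nn.Linear(256, 1000)

model = TinyModel().cuda()

# Bottom-up: shard each child block first, so each block gets its own
# parameter group (its own all-gather / reduce-scatter)...
for block in model.layers:
    fully_shard(block)

# ...then shard the root, which picks up the remaining parameters
# (the embedding and head here) as one final group.
fully_shard(model)
```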
Debugging Strategies
✔ Start small: Test with single sharded module first
✔ Use TORCH_DISTRIBUTED_DEBUG=DETAIL
✔ Validate shard placement: print(next(model.layer1.parameters()).device)
✔ Check gradient flow: after backward, confirm each sharded parameter has a non-None .grad on every rank
✔ Profile memory: torch.cuda.memory_summary()
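For shard-placement checks specifically, FSDP2 represents each managed parameter as a DTensor, so you can inspect it directly. A small sketch, assuming layer1 was sharded as in the earlier example (older releases expose DTensor under torch.distributed._tensor instead):

```python
import torch
from torch.distributed.tensor import DTensor

p = model.layer1.weight
assert isinstance(p, DTensor)       # sharded parameters are DTensors
print(p.device)                     # device of this rank's local shard
print(p.placements)                 # e.g. (Shard(dim=0),)
print(p.to_local().shape)           # shape of the local shard

print(torch.cuda.memory_summary())  # per-rank memory report
```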
✅ People Also Ask (FAQ)
1. When should I use fully_shard vs FSDP wrapper?
Use fully_shard when you need:
- Custom sharding strategies for different modules
- To combine FSDP with other parallelism approaches
- Gradual adoption of sharding in existing code
- Special handling for specific model components
2. Can I fully_shard only part of my model?
Yes! This is the primary advantage:
```python
# Shard only the attention layers
for attn_layer in model.attention_layers:
    fully_shard(attn_layer)
```
3. How does fully_shard affect performance?
Communication Overhead increases with:
- More fine-grained sharding
- Smaller per-shard parameter sizes
- Frequent all-gather operations
Optimize by:
- Sharding at sensible granularity (entire layers vs individual weights)
- Using hybrid sharding for multi-node setups (a 2-D device mesh: shard within a node, replicate across nodes)
- Tuning reshard_after_forward per module to trade memory for fewer all-gathers (see the sketch below)
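A hedged tuning sketch along those lines; the (2, 8) mesh shape and the model.layers attribute are assumptions about your cluster and model.

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

# Hybrid sharding: all-gathers stay inside a node; gradients are
# reduce-scattered within a node and all-reduced across nodes.
num_nodes, gpus_per_node = 2, 8   # adjust to your cluster
mesh = init_device_mesh(
    "cuda", (num_nodes, gpus_per_node), mesh_dim_names=("replicate", "shard")
)

n = len(model.layers)             # model.layers: your stack of blocks (assumed)
for i, block in enumerate(model.layers):
    fully_shard(
        block,
        mesh=mesh,
        # Keep the last block unsharded between forward and backward to
        # save one all-gather per step, at the cost of extra memory.
        reshard_after_forward=(i < n - 1),
    )
fully_shard(model, mesh=mesh)
```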
4. Does fully_shard work with activation checkpointing?
Yes, and they combine powerfully:
```python
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
from torch.distributed.fsdp import fully_shard

# Apply activation checkpointing first, then shard the wrapped module.
layer = checkpoint_wrapper(MyBigLayer())
fully_shard(layer)  # works together: recompute activations, shard parameters
```
5. How do I handle model saving and loading?
Special considerations:
- Sharded state dicts: after fully_shard, model.state_dict() contains DTensor shards, so a plain torch.save captures only the local rank's shards
- Distributed checkpoints: torch.distributed.checkpoint (DCP) saves and loads the sharded state across ranks and can reshard on load (see the sketch below)
- Load with the same structure: re-apply fully_shard to a freshly built model before loading a sharded checkpoint
- Full state dicts: gather an unsharded copy when you need a single-file checkpoint for inference or export
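A hedged sketch of both paths using torch.distributed.checkpoint (DCP); the checkpoint directory name is arbitrary, model is assumed to already have fully_shard applied, and get_model_state_dict/StateDictOptions live under torch.distributed.checkpoint.state_dict in recent releases.

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    StateDictOptions,
)

# Sharded checkpoint: every rank writes its own shards to the directory.
dcp.save({"model": model.state_dict()}, checkpoint_id="ckpt_dir")

# Loading: rebuild the model, re-apply fully_shard, then load in place.
state = {"model": model.state_dict()}
dcp.load(state, checkpoint_id="ckpt_dir")
model.load_state_dict(state["model"])

# Full (unsharded) state dict gathered to CPU, e.g. for export from rank 0.
full_sd = get_model_state_dict(
    model, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
)
```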
Advanced Optimization Techniques
- Overlap Communication: fully_shard already prefetches all-gathers implicitly so communication overlaps with compute; for finer control, sharded modules also expose explicit prefetching:
```python
# Assumes each block in model.layers was already passed to fully_shard.
# While block i runs forward, start prefetching block i+1's all-gather.
for cur, nxt in zip(model.layers, model.layers[1:]):
    cur.set_modules_to_forward_prefetch([nxt])
```
- Custom Device Meshes: fully_shard takes a DeviceMesh rather than a raw process group, which controls exactly which ranks a module is sharded over:
```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard

# 1-D mesh over the whole job (the mesh shape must match the world size).
shard_mesh = init_device_mesh("cuda", (8,))
fully_shard(module, mesh=shard_mesh)
```
- Memory-Efficient Initialization:
with torch.device("meta"): # Create model without allocating memory huge_model = MyGiantModel() # Materialize parameters only when sharded fully_shard(huge_model.layer1) huge_model.layer1.to_empty(device="cuda")
Conclusion
PyTorch’s fully_shard API provides fine-grained, per-module control over model sharding. Key takeaways:
- Progressive adoption – Start sharding only your heaviest layers
- Mixed strategies – Combine different sharding approaches
- Precision control – Configure per-module mixed precision
- Debug carefully – Sharding introduces new failure modes
Ready to shard? Begin with:
```python
# Just shard one layer to start
fully_shard(model.your_biggest_layer)
```