
What is torch.distributed.fsdp.fully_shard?

The fully_shard function is PyTorch’s granular, module-level API for applying Fully Sharded Data Parallelism (FSDP) to specific model components. Unlike wrapping entire models with FSDP, fully_shard enables:

  • Selective sharding of individual model components
  • Mixed parallelism strategies within a single model
  • Finer memory control for massive models
  • Progressive adoption of sharding techniques

Key Differences from FSDP Wrapping

Feature        FSDP() Wrapper     fully_shard
Scope          Entire model       Per-module
Flexibility    Less               More
Adoption       All-or-nothing     Gradual
Use Case       Standard FSDP      Custom parallelism
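
The difference is easiest to see side by side. Below is a minimal sketch assuming PyTorch 2.4+ (where fully_shard is exported from torch.distributed.fsdp) and an already-initialized process group; the tiny two-layer model is purely illustrative:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, fully_shard

def build():
    return nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024)).cuda()

# Option A: classic wrapper, the whole model becomes one sharded unit
model_a = FSDP(build())

# Option B: composable API, shard a chosen submodule in place and leave the rest alone
model_b = build()
fully_shard(model_b[0])  # only the first Linear is sharded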

Code Examples

1. Basic fully_shard Application

import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.fsdp.api import ShardingStrategy

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1024, 4096)
        self.layer2 = nn.Linear(4096, 4096)
        self.layer3 = nn.Linear(4096, 1024)

        # Shard only specific layers (keyword names follow the composable
        # FSDP API and may differ between PyTorch releases)
        fully_shard(self.layer1, sharding_strategy=ShardingStrategy.FULL_SHARD)
        fully_shard(self.layer2, sharding_strategy=ShardingStrategy.HYBRID_SHARD)

model = MyModel().cuda()

2. Mixed Precision with fully_shard

import torch
from torch.distributed.fsdp import MixedPrecision

# Configure per-module precision
fp16_policy = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16
)

fully_shard(
    model.attention_layer,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=fp16_policy
)

3. Combining with DDP

from torch.nn.parallel import DistributedDataParallel as DDP

# Shard only the memory-intensive components
fully_shard(model.gpt_layers, sharding_strategy=ShardingStrategy.FULL_SHARD)

# Wrap the full model in DDP to handle the remaining (unsharded) components
model = DDP(model)

Common Methods & Configurations

Method/Parameter               Description
ShardingStrategy.FULL_SHARD    Shards parameters, gradients, and optimizer states
ShardingStrategy.HYBRID_SHARD  Shards within a node but replicates across nodes
MixedPrecision()               Configures per-module precision
cpu_offload                    Offloads sharded parameters to CPU
ignored_modules                Excludes specific submodules from sharding
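
In practice the offload and exclusion knobs look roughly like this. The sketch below uses the FSDP wrapper, which accepts these keyword names directly; whether your fully_shard version exposes the same names depends on the release, and model.encoder / model.encoder.pooler are hypothetical submodules standing in for your own:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload

sharded_encoder = FSDP(
    model.encoder,                                # hypothetical large submodule
    cpu_offload=CPUOffload(offload_params=True),  # park sharded params on CPU between uses
    ignored_modules=[model.encoder.pooler],       # hypothetical submodule kept unsharded
)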

Errors & Debugging Tips

Common Pitfalls

  1. Incorrect Shard Ordering: Sharding child modules before parents
  2. Device Mismatches: Mixed CPU/GPU modules
  3. Communication Deadlocks: From improper process-group initialization or synchronization (see the setup sketch after this list)
  4. Checkpoint Conflicts: When saving/loading partially sharded models
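
Pitfalls 2 and 3 usually trace back to per-rank setup. Here is a minimal launch sketch, assuming one process per GPU (e.g. via torchrun) and reusing MyModel from the first example:

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

# Bind this process to its GPU before building or sharding anything, so
# parameters land on the right device and every rank issues the same
# collectives in the same order.
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")

model = MyModel().cuda()   # lands on this rank's GPU
fully_shard(model.layer1)  # identical call on every rank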

Debugging Strategies

✔ Start small: Test with single sharded module first
✔ Use TORCH_DISTRIBUTED_DEBUG=DETAIL
✔ Validate shard placement: print(next(model.layer1.parameters()).device)
✔ Check gradient flow: torch.autograd.gradcheck()
✔ Profile memory: torch.cuda.memory_summary()


✅ People Also Ask (FAQ)

1. When should I use fully_shard vs FSDP wrapper?

Use fully_shard when you need:

  • Custom sharding strategies for different modules
  • To combine FSDP with other parallelism approaches
  • Gradual adoption of sharding in existing code
  • Special handling for specific model components

2. Can I fully_shard only part of my model?

Yes! This is the primary advantage:

# Shard only the attention layers
for attn_layer in model.attention_layers:
    fully_shard(attn_layer)

3. How does fully_shard affect performance?

Communication Overhead increases with:

  • More fine-grained sharding
  • Smaller per-shard parameter sizes
  • Frequent all-gather operations

Optimize by:

  • Sharding at sensible granularity (entire layers vs individual weights)
  • Using HYBRID_SHARD for multi-node setups
  • Enabling limit_all_gathers=True (see the sketch below)
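
A hedged sketch of that multi-node recipe, written with the FSDP wrapper since sharding_strategy and limit_all_gathers are its keyword arguments; check whether your fully_shard version exposes the same names:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy

# Shard within each node, replicate across nodes, and rate-limit
# all-gathers to cap peak memory.
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    limit_all_gathers=True,
)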

4. Does fully_shard work with activation checkpointing?

Yes, and they combine powerfully:

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

layer = checkpoint_wrapper(MyBigLayer())
fully_shard(layer)  # Works perfectly

5. How do I handle model saving and loading?

Special considerations:

  1. Save sharded states: torch.save(model.state_dict(), "model.pth")
  2. Load with the same sharding: reapply fully_shard before loading
  3. Use distributed checkpoints for multi-node scenarios (see the sketch below)
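
For option 3, here is a minimal sketch of the torch.distributed.checkpoint route, assuming a recent PyTorch where dcp.save / dcp.load accept a checkpoint_id directory; the path is illustrative:

import torch.distributed.checkpoint as dcp

# Each rank writes only its own shards; no rank materializes the full model.
state = {"model": model.state_dict()}
dcp.save(state, checkpoint_id="checkpoints/step_1000")

# Later: rebuild the model, reapply fully_shard the same way, then load the
# shards in place and hand them back to the module.
dcp.load(state, checkpoint_id="checkpoints/step_1000")
model.load_state_dict(state["model"])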

Advanced Optimization Techniques

  1. Overlap Communication: FSDP hides all-gather latency by prefetching; with the FSDP wrapper this is exposed as forward_prefetch / backward_prefetch (the composable API's equivalent varies by release):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, BackwardPrefetch

model = FSDP(
    module,
    forward_prefetch=True,                            # prefetch the next forward all-gather
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch before gradient computation
)
  2. Custom Process Groups:

from torch.distributed import new_group
shard_group = new_group(...)
fully_shard(module, process_group=shard_group)
  3. Memory-Efficient Initialization:

with torch.device("meta"):  # Create model without allocating memory
    huge_model = MyGiantModel()
    
# Materialize parameters only when sharded
fully_shard(huge_model.layer1)
huge_model.layer1.to_empty(device="cuda")

Conclusion

PyTorch’s fully_shard API provides fine-grained control over how and where a model is sharded. Key takeaways:

  1. Progressive adoption – Start sharding only your heaviest layers
  2. Mixed strategies – Combine different sharding approaches
  3. Precision control – Configure per-module mixed precision
  4. Debug carefully – Sharding introduces new failure modes

Ready to shard? Begin with:

# Just shard one layer to start
fully_shard(model.your_biggest_layer)
