
What is torch.distributed.elastic?

torch.distributed.elastic is PyTorch’s framework for fault-tolerant, elastic distributed training that automatically adapts to cluster changes. Unlike static distributed training, elastic training:

  • Handles node failures gracefully – Automatically recovers from worker crashes
  • Supports dynamic scaling – Adjusts to adding/removing workers mid-training
  • Maintains training continuity – Preserves progress across interruptions
  • Works with cloud environments – Ideal for spot instances and preemptible VMs
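
The most direct way to see these properties is through the launcher API in torch.distributed.launcher.api (the programmatic counterpart of the torchrun command), where the node range and restart budget are declared up front. Below is a minimal sketch; the endpoint localhost:29400, the run id, and the train_loop entrypoint (defined in the first example later in this post) are illustrative assumptions.

from torch.distributed.launcher.api import LaunchConfig, elastic_launch

# Accept anywhere from 1 to 4 nodes with 4 workers each, and tolerate up to
# 3 restarts of the worker group before failing the job.
config = LaunchConfig(
    min_nodes=1,
    max_nodes=4,
    nproc_per_node=4,
    rdzv_backend="c10d",
    rdzv_endpoint="localhost:29400",  # assumed rendezvous endpoint
    run_id="elastic-demo",            # assumed job id
    max_restarts=3,
)

# elastic_launch starts a local agent that spawns and supervises the workers
results = elastic_launch(config, train_loop)()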

Key Components

  • Agents (per-node coordinators that spawn and monitor workers)
  • Worker groups (the set of worker processes an agent manages)
  • Rendezvous (dynamic worker discovery and membership agreement)
  • Error handling (error propagation and automatic restarts)

Code Examples

1. Basic Elastic Training Setup

import os

import torch
import torch.distributed as dist
from torch.distributed.elastic.agent.server.api import WorkerSpec
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler
from torch.nn.parallel import DistributedDataParallel as DDP

def train_loop():
    # The agent sets MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE and
    # LOCAL_RANK for every worker, so standard initialization is enough
    dist.init_process_group(backend="nccl")

    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = DDP(MyModel().cuda(), device_ids=[local_rank])
    optimizer = torch.optim.Adam(model.parameters())

    # Your training logic here
    for epoch in range(epochs):
        for batch in dataloader:
            # Standard training steps
            ...

# Describe how the workers find each other (single-node c10d rendezvous here)
rdzv_handler = get_rendezvous_handler(RendezvousParameters(
    backend="c10d", endpoint="localhost:29400",
    run_id="basic_example", min_nodes=1, max_nodes=1,
))

spec = WorkerSpec(
    role="trainer",
    local_world_size=4,  # GPUs per node
    entrypoint=train_loop,
    rdzv_handler=rdzv_handler,
)

# Launch with elastic agent (recent releases also take a logs_specs argument)
agent = LocalElasticAgent(spec, start_method="spawn")
agent.run()

2. Custom Rendezvous Backend

from torch.distributed.elastic.rendezvous import (
    RendezvousHandler,
    rendezvous_handler_registry,
)

class CustomRendezvous(RendezvousHandler):
    def get_backend(self):
        return "custom"

    def next_rendezvous(self):
        # Implement custom worker discovery; must return (store, rank, world_size)
        return store, rank, world_size

    # A complete handler also implements is_closed(), set_closed(),
    # num_nodes_waiting(), get_run_id() and shutdown().

# Register a factory so "custom" can be selected as a rendezvous backend by name
rendezvous_handler_registry.register("custom", lambda params: CustomRendezvous())

3. Handling Worker Failures

from torch.distributed.elastic.agent.server.api import WorkerSpec
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent

# Restart behaviour and health monitoring are configured on the WorkerSpec.
# (Std in torch.distributed.elastic.multiprocessing is just an enum for
# redirecting worker stdout/stderr, e.g. via torchrun's --redirects flag.)
spec = WorkerSpec(
    role="trainer",
    local_world_size=4,
    entrypoint=train_loop,
    rdzv_handler=rdzv_handler,  # as constructed in example 1
    monitor_interval=5,  # check worker health every 5 s
    max_restarts=3,      # maximum restart attempts for the worker group
)

agent = LocalElasticAgent(spec, start_method="spawn")
result = agent.run()

if result.is_failed():
    # result.failures maps worker rank to the details of each failure
    print(result.failures)

Common Methods & Components

Component / Method      Purpose
LocalElasticAgent       Spawns and supervises the worker processes on one node
WorkerSpec              Defines the worker group (role, entrypoint, local world size, restart policy)
RendezvousHandler       Manages worker discovery and membership changes
init_process_group()    Standard process-group initialization; reads the environment variables set by the agent
record()                Decorator that writes a failed worker's exception to an error file so the agent can report it
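
The record() entry refers to the error-propagation decorator in torch.distributed.elastic.multiprocessing.errors. Wrapping a worker's entrypoint with it causes any uncaught exception to be written (traceback, timestamp, exit code) to the error file named by the TORCHELASTIC_ERROR_FILE environment variable, which the agent reads to report the root cause. A minimal sketch, assuming the train_loop function from the first example:

from torch.distributed.elastic.multiprocessing.errors import record

@record
def main():
    # Any uncaught exception here is written to the file named by
    # TORCHELASTIC_ERROR_FILE so the agent can surface the root cause
    train_loop()

if __name__ == "__main__":
    main()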

Errors & Debugging Tips

Common Issues

  1. Rendezvous timeouts – Workers fail to discover each other
  2. Version mismatches – Different PyTorch versions across nodes
  3. Partial failures – Some workers crash while others continue
  4. Checkpoint conflicts – Multiple workers trying to save simultaneously
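
For rendezvous timeouts (issue 1), a common first mitigation is to give slow-starting nodes more time to join before the rendezvous gives up. With the c10d backend this is the join_timeout setting, passed via rdzv_configs in LaunchConfig or --rdzv_conf join_timeout=... on the torchrun command line. The sketch below reuses the assumed LaunchConfig setup from earlier; the 900-second value is illustrative.

from torch.distributed.launcher.api import LaunchConfig, elastic_launch

config = LaunchConfig(
    min_nodes=2,
    max_nodes=4,
    nproc_per_node=4,
    rdzv_backend="c10d",
    rdzv_endpoint="localhost:29400",      # assumed rendezvous endpoint
    rdzv_configs={"join_timeout": 900},   # wait up to 15 minutes for nodes to join
    run_id="elastic-demo",
    max_restarts=3,
)

results = elastic_launch(config, train_loop)()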
