Welcome back, future AI architect! In our journey so far, we’ve built a strong foundation in deep learning, mastering neural network architectures, understanding training workflows, and optimizing models. We’ve even considered how powerful hardware like GPUs accelerates our tasks. But what happens when your model becomes so massive it won’t fit on a single GPU? Or when your dataset is so enormous that training takes weeks, even on the most powerful single machine?

This chapter is your gateway to solving these grand challenges. We’ll explore the fascinating world of distributed training, where the power of multiple devices—be it GPUs, CPUs, or even entire clusters—is harnessed to train models that were once thought impossible. You’ll learn the core paradigms of how to split work across machines, understand the crucial considerations for effective scaling, and get hands-on with implementing distributed training using PyTorch. By the end, you’ll be equipped to tackle real-world problems that demand immense computational resources, significantly boosting your capabilities as an AI/ML engineer.

Before we dive in, make sure you’re comfortable with:

  • Deep learning fundamentals, including model architecture design and training loops (Chapters 10-12).
  • Optimization techniques and hyperparameter tuning (Chapter 13).
  • Understanding of GPU accelerators and their role in deep learning (Chapter 16).

Let’s unlock the power of distributed AI!

The Need for Speed and Scale: Why Go Distributed?

Imagine you’re building the next-generation Large Language Model (LLM) or a state-of-the-art multimodal AI that can understand both images and text. These models can have hundreds of billions of parameters, and their training datasets can consist of petabytes of information. A single GPU, no matter how powerful, simply cannot handle this scale. This is where distributed training becomes not just an optimization, but a necessity.

Here are the primary reasons we embrace distributed training:

  1. Model Size Exceeds Single-Device Memory: Modern deep learning models, especially LLMs and foundation models, are incredibly large. Their parameters alone can consume hundreds of gigabytes, and training must also store gradients, optimizer states, and activations. That far exceeds the memory of any single GPU: high-end consumer cards offer roughly 24 to 32 GB, and even the largest data-center accelerators offer on the order of 80 to 192 GB. Distributed training allows us to split the model across multiple devices.
  2. Massive Datasets: Training on vast datasets (millions or billions of examples) takes an unacceptably long time on a single device. Distributing the data across multiple workers allows for parallel processing, dramatically reducing training time.
  3. Faster Experimentation: Even for models and datasets that could fit on a single device, distributed training can significantly speed up the training process, allowing researchers and engineers to iterate on ideas and experiments much more quickly. This rapid feedback loop is crucial for innovation.
  4. Resource Efficiency: By efficiently utilizing multiple, potentially less expensive, GPUs or even CPUs in parallel, distributed training can sometimes be more cost-effective than relying on a single, extremely high-end machine.
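To put rough numbers on the first point, the sketch below uses a common rule of thumb for mixed-precision training with the Adam optimizer: about 16 bytes per parameter (fp16 weights and gradients at 2 bytes each, plus fp32 master weights and two fp32 Adam moment buffers at 4 bytes each). Activations and framework overhead come on top, so treat these figures as lower bounds; the model sizes are illustrative.

```python
def training_memory_gb(num_params, bytes_per_param=16):
    """Rough lower bound on memory for mixed-precision Adam training.

    16 bytes/param assumes fp16 weights (2) + fp16 gradients (2)
    + fp32 master weights (4) + Adam first and second moments (4 + 4).
    Activations, buffers, and framework overhead are NOT included.
    """
    return num_params * bytes_per_param / 1e9

for name, n in [("1.3B", 1.3e9), ("7B", 7e9), ("70B", 70e9)]:
    print(f"{name} params: weights ~{n * 2 / 1e9:.0f} GB (fp16), "
          f"training state ~{training_memory_gb(n):.0f} GB")
# 1.3B params: weights ~3 GB (fp16), training state ~21 GB
# 7B params: weights ~14 GB (fp16), training state ~112 GB
# 70B params: weights ~140 GB (fp16), training state ~1120 GB
```

Even a 7B-parameter model needs more training state than one 80 GB GPU holds under these assumptions, before counting a single activation.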

Core Paradigms of Distributed Training

When you decide to distribute your training workload, you generally have two main strategies: Data Parallelism and Model Parallelism. Often, advanced systems combine both.

1. Data Parallelism: Sharing the Workload

Data parallelism is the most common and often the easiest form of distributed training to implement.

What it is: In data parallelism, every participating device (e.g., GPU) gets a complete copy of the neural network model. However, each device processes a different subset (a mini-batch) of the training data.

How it works:

  1. Replication: The model is replicated on each device.
  2. Data Sharding: The total training mini-batch is divided into smaller sub-mini-batches, and each device receives one.
  3. Local Forward Pass & Gradient Computation: Each device independently performs a forward pass and computes gradients for its sub-mini-batch.
  4. Gradient Aggregation: After local gradient computation, the gradients from all devices are aggregated (typically averaged) to ensure all model copies stay synchronized. This aggregation step is crucial for maintaining a consistent global model state.
  5. Parameter Update: The aggregated gradients are then used to update the model parameters on each device, ensuring all model copies are identical before the next training step.

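This process effectively multiplies your batch size: the global (effective) batch size is the per-device batch size times the number of devices, yet no single device needs more memory. Because each step processes that many more samples in parallel, an epoch completes in fewer steps and training finishes faster.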
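Why is averaging the right aggregation? With a mean-reduced loss and equally sized shards, the average of the per-shard gradients equals the gradient of the full batch. The single-process sketch below simulates two workers by reusing one model object (standing in for identical replicas) on two halves of a random batch; the tiny linear model and data are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
loss_fn = nn.MSELoss()  # mean reduction over the batch

# Reference: gradient of the full batch on one device
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Simulated data parallelism: 2 "workers", identical weights, half the batch each
world_size = 2
shard_grads = []
for xs, ys in zip(x.chunk(world_size), y.chunk(world_size)):
    model.zero_grad()
    loss_fn(model(xs), ys).backward()
    shard_grads.append(model.weight.grad.clone())

# Aggregation step: average the per-worker gradients (what an all-reduce + divide does)
avg_grad = torch.stack(shard_grads).mean(dim=0)
print(torch.allclose(full_grad, avg_grad, atol=1e-6))  # True
```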

Let’s visualize this with a simple flowchart:

flowchart TD
    Start[Training Step Starts] --> Data[Divide Data Batch]
    Data --> Worker1(Worker 1: Model Copy + Data Slice 1)
    Data --> Worker2(Worker 2: Model Copy + Data Slice 2)
    Data --> WorkerN(Worker N: Model Copy + Data Slice N)
    Worker1 -->|Compute Gradients G1| Grad1[Gradients G1]
    Worker2 -->|Compute Gradients G2| Grad2[Gradients G2]
    WorkerN -->|Compute Gradients GN| GradN[Gradients GN]
    Grad1 & Grad2 & GradN --> Aggregate[Aggregate Gradients]
    Aggregate --> Update[Update Model Parameters on all Workers]
    Update --> End[Training Step Ends]

Synchronous vs. Asynchronous Data Parallelism:

  • Synchronous: All devices must complete their gradient computation and aggregation before any model parameters are updated. This ensures consistent updates but can be bottlenecked by the slowest device. This is the most common and generally preferred approach for stable training.
  • Asynchronous: Devices update the global model parameters as soon as they finish their local gradient computation, without waiting for others. This avoids waiting on slow devices but introduces “stale gradients”: updates computed from an older version of the parameters, which can slow down or destabilize convergence.
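A toy illustration of stale gradients, under simplifying assumptions: we minimize f(w) = w² with plain gradient descent, and in the "asynchronous" case every update uses the gradient evaluated at the parameters from `staleness` steps earlier, as a worker holding an outdated copy would. Real asynchronous systems have variable, random delays; the step size of 0.4 is a hypothetical value chosen to make the effect obvious.

```python
def sgd_with_staleness(staleness, lr=0.4, steps=40, w0=1.0):
    """Gradient descent on f(w) = w**2 (gradient 2w), where each update uses
    the gradient computed `staleness` steps earlier (toy model of async training)."""
    history = [w0] * (staleness + 1)  # history[-1] is the current parameter value
    for _ in range(steps):
        stale_w = history[-1 - staleness]          # parameters the worker saw
        history.append(history[-1] - lr * 2 * stale_w)
    # Largest |w| over the last 20 updates (long enough to span an oscillation)
    return max(abs(w) for w in history[-20:])

print(sgd_with_staleness(0))  # fresh gradients: shrinks toward 0
print(sgd_with_staleness(3))  # 3-step-old gradients: oscillates with growing amplitude
```

With the same step size, fresh gradients converge geometrically, while gradients that are three steps old overshoot repeatedly and diverge; in practice asynchronous schemes compensate with smaller step sizes, which is part of why synchronous training is the default.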

2. Model Parallelism: Splitting the Giant

Model parallelism is an umbrella term for techniques that split the model itself across devices, including layer-wise, tensor, and pipeline parallelism (described below). It is employed when the model is too large to fit into the memory of a single device.

What it is: Instead of replicating the entire model, different layers (or even parts of layers) of the neural network are placed on different devices.

How it works:

  1. Model Partitioning: The model is logically divided. For instance, the first few layers might reside on GPU 1, the next few on GPU 2, and so on.
  2. Sequential Processing: During the forward pass, data flows from one device to the next as it progresses through the model’s layers. GPU 1 processes its layers, then passes the intermediate activations to GPU 2, which processes its layers, and so on.
  3. Backward Pass: The gradients flow backward through the devices in a similar sequential fashion.

Model parallelism is more complex to implement than data parallelism because it requires careful orchestration of data transfer between devices and synchronization of computations. It often introduces communication bottlenecks, as devices must wait for intermediate results from other devices.

Let’s illustrate model parallelism:

flowchart LR
    Input[Input Data] --> GPU1(GPU 1: Layer 1-3)
    GPU1 -->|Activations 1| GPU2(GPU 2: Layer 4-6)
    GPU2 -->|Activations 2| GPU3(GPU 3: Layer 7-9)
    GPU3 --> Output[Output]
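A minimal sketch of naive layer-wise model parallelism in PyTorch, assuming two GPUs are visible (it falls back to the CPU so it still runs elsewhere). The class name `TwoStageModel` and the layer sizes are hypothetical; the point is that each stage lives on its own device and the forward pass explicitly moves the intermediate activations between them.

```python
import torch
import torch.nn as nn

class TwoStageModel(nn.Module):
    """Stage 1 on dev0, stage 2 on dev1 (naive layer-wise model parallelism)."""
    def __init__(self, dev0, dev1):
        super().__init__()
        self.dev0, self.dev1 = dev0, dev1
        self.stage1 = nn.Sequential(nn.Linear(16, 32), nn.ReLU()).to(dev0)
        self.stage2 = nn.Linear(32, 4).to(dev1)

    def forward(self, x):
        x = self.stage1(x.to(self.dev0))
        # Hand the intermediate activations to the next device
        return self.stage2(x.to(self.dev1))

# Use two GPUs if available; otherwise run both stages on the CPU
if torch.cuda.device_count() >= 2:
    dev0, dev1 = torch.device("cuda:0"), torch.device("cuda:1")
else:
    dev0 = dev1 = torch.device("cpu")

model = TwoStageModel(dev0, dev1)
out = model(torch.randn(8, 16))
out.sum().backward()  # autograd routes gradients back across the device boundary
print(out.shape)      # torch.Size([8, 4])
```

Note that while stage 2 runs, the device holding stage 1 sits idle; that idle time is exactly what pipeline parallelism tries to reduce.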

Types of Model Parallelism:

  • Layer-wise Parallelism: Each device gets one or more full layers.
  • Tensor Parallelism: A single layer’s operations (e.g., a large matrix multiplication) are split across multiple devices. This is common for very large dense layers in LLMs.
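  • Pipeline Parallelism: Builds on layer-wise partitioning: consecutive groups of layers (stages) live on different devices, and each mini-batch is split into smaller micro-batches that flow through the stages one after another, so several stages can work at the same time instead of sitting idle.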
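The column split behind tensor parallelism can be checked numerically. This single-process sketch uses plain tensors to stand in for two devices (sizes chosen arbitrarily): it splits a linear layer's weight matrix `W` by output columns, computes each half separately, and concatenates the partial outputs. Real tensor-parallel layers also handle the inter-device communication and pair this with a row-parallel layer, which the sketch omits.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)   # activations: batch of 4, hidden size 8
W = torch.randn(8, 6)   # weight of one "large" linear layer

# Column-parallel split: each of 2 "devices" holds half of W's output columns
W0, W1 = W.chunk(2, dim=1)
y0 = x @ W0             # would run on device 0
y1 = x @ W1             # would run on device 1

# Gathering (concatenating) the partial outputs reproduces the full layer's output
y = torch.cat([y0, y1], dim=1)
print(torch.allclose(y, x @ W))  # True
```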

3. Hybrid Approaches

For truly massive models and datasets, a hybrid approach combining both data and model parallelism is often necessary. For example, you might use model parallelism to split a huge LLM across several GPUs on one node, and then use data parallelism to replicate this “model-parallel” group across multiple nodes, each handling a different subset of the training data.
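To make the hybrid layout concrete, the sketch below computes which ranks belong together for a hypothetical setup of 2 nodes with 4 GPUs each, using tensor parallelism inside a node and data parallelism across nodes. The helper `build_parallel_groups` is illustrative and assumes consecutive ranks share a node; in practice frameworks build these groups for you (for example with `torch.distributed.new_group`).

```python
def build_parallel_groups(world_size, tp_size):
    """Split ranks into tensor-parallel groups (consecutive ranks, same node)
    and data-parallel groups (one rank from each tensor-parallel group)."""
    assert world_size % tp_size == 0, "world_size must be a multiple of tp_size"
    dp_size = world_size // tp_size
    tp_groups = [list(range(i * tp_size, (i + 1) * tp_size)) for i in range(dp_size)]
    dp_groups = [list(range(j, world_size, tp_size)) for j in range(tp_size)]
    return tp_groups, dp_groups

tp, dp = build_parallel_groups(world_size=8, tp_size=4)
print(tp)  # [[0, 1, 2, 3], [4, 5, 6, 7]]        each node holds one model-parallel replica
print(dp)  # [[0, 4], [1, 5], [2, 6], [3, 7]]    matching shards all-reduce their gradients
```

Each data-parallel group contains ranks that hold the same slice of the model, so gradient averaging happens only among them.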

Framework Support for Distributed Training

Modern deep learning frameworks provide robust tools to simplify distributed training.

  • PyTorch:

    • torch.distributed: The backbone for distributed communication primitives.
    • DistributedDataParallel (DDP): The go-to for data parallelism. It’s highly optimized and widely used.
    • Fully Sharded Data Parallel (FSDP): An advanced technique (introduced in PyTorch 1.11, improved significantly in 2.x) that shards model parameters, gradients, and optimizer states across GPUs, effectively enabling larger models to be trained with data parallelism.
    • torchrun (the command-line entry point for torch.distributed.run, which supersedes the deprecated torch.distributed.launch): A utility for launching distributed training scripts.
    • torch.nn.parallel.DistributedDataParallel Official Docs
    • torch.distributed.fsdp Official Docs
  • TensorFlow:

    • tf.distribute.Strategy: A high-level API for distributed training.
      • MirroredStrategy: For synchronous data parallelism on a single host with multiple GPUs.
      • MultiWorkerMirroredStrategy: For synchronous data parallelism across multiple hosts with multiple GPUs.
      • ParameterServerStrategy: For asynchronous, parameter-server-based training, where dedicated parameter-server tasks hold the model variables and worker tasks compute updates independently.
    • tf.distribute Official Docs
  • JAX:

    • jax.pmap: SPMD data parallelism across multiple devices (the JAX documentation now steers new code toward jax.jit with sharding annotations, or shard_map).
    • jax.experimental.shard_map: A more flexible and composable API for both data and model parallelism.
    • JAX Distributed Official Docs
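The torch.distributed primitives can be tried without any GPU. The sketch below starts two CPU processes with the "gloo" backend; each contributes `rank + 1` to a `dist.all_reduce`, and afterwards every process holds the sum. It assumes a POSIX system (it uses the "fork" start method so it also works when pasted into an interactive session) and a free local TCP port; the helper names are illustrative.

```python
import socket
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _free_port():
    # Ask the OS for an unused TCP port to use for the rendezvous
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]

def _worker(rank, world_size, port, results):
    # "gloo" runs on CPU, so no GPU is needed for this sketch
    dist.init_process_group("gloo", init_method=f"tcp://127.0.0.1:{port}",
                            rank=rank, world_size=world_size)
    t = torch.tensor([float(rank + 1)])        # each rank contributes a different value
    dist.all_reduce(t, op=dist.ReduceOp.SUM)   # in place: every rank now holds the sum
    results.put((rank, t.item()))
    dist.destroy_process_group()

def run_all_reduce_demo(world_size=2):
    ctx = mp.get_context("fork")
    results = ctx.SimpleQueue()
    mp.start_processes(_worker, args=(world_size, _free_port(), results),
                       nprocs=world_size, start_method="fork")
    return sorted(results.get() for _ in range(world_size))

if __name__ == "__main__":
    print(run_all_reduce_demo(2))  # [(0, 3.0), (1, 3.0)]
```

This is the same collective that DDP runs on gradient buckets behind the scenes, and that the evaluation code later in this chapter calls explicitly.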

As of 2026, PyTorch’s DDP and FSDP are incredibly popular for their flexibility and performance, especially in research and for training large models. We’ll focus on DDP for our hands-on example due to its widespread applicability and ease of understanding for data parallelism.

Step-by-Step Implementation: PyTorch DDP

Let’s put theory into practice! We’ll set up a simple convolutional neural network (CNN) and train it using PyTorch’s DistributedDataParallel (DDP) across multiple GPUs on a single machine. If you only have one GPU, don’t worry: you can still run the script as a single process to verify that the DDP setup works, though you won’t see a speedup.

Prerequisites:

  • Python 3.9+
  • PyTorch 2.x (the stable release series as of 2026-01-17).
  • torchvision
  • CUDA-enabled GPUs (if running on actual hardware).

Installation:

# Ensure you have the correct CUDA version for your system
# Example for CUDA 12.1, PyTorch 2.x
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
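Before launching anything distributed, it can help to confirm what the installed build supports. A small sketch (the function name `environment_report` is hypothetical) using standard `torch` and `torch.distributed` queries:

```python
import torch
import torch.distributed as dist

def environment_report():
    """Summarize what this machine offers for distributed training."""
    dist_ok = dist.is_available()  # False on builds compiled without distributed support
    return {
        "torch_version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "gpu_count": torch.cuda.device_count(),
        "nccl_available": dist_ok and dist.is_nccl_available(),
        "gloo_available": dist_ok and dist.is_gloo_available(),
    }

for key, value in environment_report().items():
    print(f"{key}: {value}")
```

If `nccl_available` is False or `gpu_count` is smaller than the number of processes you plan to launch, the multi-GPU DDP script below will not run as-is.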

1. The Basic Training Script (Single GPU Baseline)

First, let’s establish a baseline: a standard training script for a simple CNN on the CIFAR-10 dataset. This is what we’ll modify for DDP.

Create a file named cnn_single_gpu.py:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 1. Define the CNN Model
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512) # CIFAR-10 images are 32x32. After 2 pooling layers, 32/4 = 8.
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(512, 10) # 10 classes for CIFAR-10

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8) # Flatten the tensor
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x

def train_model(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')
    print(f'Epoch {epoch} finished, Avg Loss: {running_loss / len(train_loader):.4f}')

def test_model(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
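            # criterion returns the batch mean; scale by the batch size to accumulate a per-sample sum
            test_loss += criterion(output, target).item() * data.size(0)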
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.2f}%)\n')

def main_single_gpu():
    # 2. Setup Device, DataLoaders
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

    # 3. Instantiate Model, Optimizer, Loss Function
    model = SimpleCNN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # 4. Training Loop
    num_epochs = 5
    for epoch in range(1, num_epochs + 1):
        train_model(model, device, train_loader, optimizer, criterion, epoch)
        test_model(model, device, test_loader, criterion)

if __name__ == '__main__':
    main_single_gpu()

To run this baseline:

python cnn_single_gpu.py

Observe the training loss and accuracy. This will be our reference point.

2. Converting to DistributedDataParallel (DDP)

Now, let’s adapt cnn_single_gpu.py to cnn_ddp.py for distributed training. We need to introduce several DDP-specific components.

Create a new file named cnn_ddp.py:

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 1. Define the CNN Model (same as before)
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x

def train_model(model, device, train_loader, optimizer, criterion, epoch, rank):
    model.train()
    # Ensure distributed sampler resamples data correctly for each epoch
    train_loader.sampler.set_epoch(epoch)
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Only print from rank 0 to avoid cluttered output
        if rank == 0 and batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')
    if rank == 0:
        # This average covers only rank 0's shard of the training data
        print(f'Epoch {epoch} finished, Avg Loss (rank 0 shard): {running_loss / len(train_loader):.4f}')

def test_model(model, device, test_loader, criterion, rank):
    model.eval()
    test_loss = 0.0
    correct = 0
    num_samples = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # criterion returns the batch mean; scale by the batch size to get a per-sample sum
            test_loss += criterion(output, target).item() * data.size(0)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            num_samples += data.size(0)

    # Each rank only evaluated its own shard of the test set, so sum the loss, the
    # correct count, and the number of samples actually seen across all ranks.
    # (len(test_loader.dataset) is the FULL test-set size on every rank, so summing
    # it would over-count by a factor of world_size.)
    stats = torch.tensor([test_loss, correct, num_samples], dtype=torch.float64, device=device)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    total_loss, total_correct, total_samples = stats.tolist()

    if rank == 0:
        # If the test-set size is not divisible by world_size, DistributedSampler pads
        # with repeated samples, so total_samples can slightly exceed the dataset size.
        avg_test_loss = total_loss / total_samples
        accuracy = 100. * total_correct / total_samples
        print(f'\nTest set: Average loss: {avg_test_loss:.4f}, Accuracy: {int(total_correct)}/{int(total_samples)} ({accuracy:.2f}%)\n')

def setup(rank, world_size):
    # Pin this process to its GPU before creating the NCCL process group.
    # On a single machine RANK equals LOCAL_RANK; on multiple nodes use int(os.environ["LOCAL_RANK"]).
    torch.cuda.set_device(rank)
    # torchrun already exports MASTER_ADDR and MASTER_PORT; setdefault only fills them in
    # when the script is started some other way, instead of overwriting torchrun's values.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')
    # With no init_method given, init_process_group uses 'env://', i.e. reads the variables above
    dist.init_process_group("nccl", rank=rank, world_size=world_size)  # "nccl" is recommended for GPU training

def cleanup():
    dist.destroy_process_group()

def main_ddp(rank, world_size):
    # 2. Setup Distributed Environment
    setup(rank, world_size)
    print(f"Rank {rank}/{world_size} initialized.")

    # 3. Setup Device, DataLoaders
    # Each process uses its dedicated GPU
    device = torch.device(f"cuda:{rank}")
    print(f"Rank {rank} using device: {device}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

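    # Download CIFAR-10 on rank 0 only, then let the other ranks read the files,
    # so several processes don't write the same archive at the same time.
    if rank == 0:
        datasets.CIFAR10(root='./data', train=True, download=True)
        datasets.CIFAR10(root='./data', train=False, download=True)
    dist.barrier()  # every rank waits here until rank 0 has finished downloading
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)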

    # Use DistributedSampler to ensure each process gets a unique subset of data
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    test_sampler = DistributedSampler(test_dataset, num_replicas=world_size, rank=rank, shuffle=False) # Often not shuffled for testing

    # DataLoader now uses the sampler
    # Note: When using DistributedSampler, the batch_size argument refers to the batch size per GPU.
    train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=64, sampler=test_sampler, num_workers=2)

    # 4. Instantiate Model, Optimizer, Loss Function
    model = SimpleCNN().to(device)
    # Wrap the model with DDP
    model = DDP(model, device_ids=[rank]) # device_ids specifies which GPU this process uses

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # 5. Training Loop
    num_epochs = 5
    for epoch in range(1, num_epochs + 1):
        train_model(model, device, train_loader, optimizer, criterion, epoch, rank)
        test_model(model, device, test_loader, criterion, rank) # Pass rank to test_model as well

    # 6. Cleanup Distributed Environment
    cleanup()
    print(f"Rank {rank} finished and cleaned up.")

if __name__ == '__main__':
    # Get world size (number of GPUs/processes) from environment variable set by torchrun
    world_size = int(os.environ["WORLD_SIZE"])
    # Get current rank from environment variable set by torchrun
    rank = int(os.environ["RANK"])
    main_ddp(rank, world_size)

Key Changes Explained:

  1. Import torch.distributed and DDP:

    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data.distributed import DistributedSampler
    

    These are the core components for DDP.

  2. setup(rank, world_size) function:

    def setup(rank, world_size):
        torch.cuda.set_device(rank)
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '12355')
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
    
    • MASTER_ADDR and MASTER_PORT: These environment variables tell every process where the coordinating (“master”) process can be reached, so they can rendezvous. torchrun already exports them, so os.environ.setdefault only supplies fallbacks (localhost and an arbitrary free port) when the script is started some other way, instead of overwriting torchrun’s values. For multiple machines, pass the master machine’s address to torchrun (for example with --master_addr and --master_port).
    • dist.init_process_group("nccl", ...): This initializes the distributed backend. “nccl” (NVIDIA Collective Communications Library) is highly optimized for GPU-to-GPU communication and is the recommended backend for multi-GPU training. Other options include “gloo” (CPU/network-based) and “mpi”.
    • torch.cuda.set_device(rank): Each process must be pinned to its own GPU. On a single machine, the rank (a unique ID for each process, from 0 to world_size - 1) maps directly to cuda:0, cuda:1, and so on. Calling it before init_process_group ensures NCCL sets up its communicator on the right device. On multiple machines, use LOCAL_RANK (also exported by torchrun) instead, because the global rank can exceed the number of GPUs on a node.
  3. cleanup() function:

    def cleanup():
        dist.destroy_process_group()
    

    Always call this at the end to properly shut down the distributed environment.

  4. DistributedSampler:

    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
    test_sampler = DistributedSampler(test_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=64, sampler=test_sampler, num_workers=2)
    
    • This is essential for data parallelism! The DistributedSampler gives each process a distinct slice of the dataset indices in every epoch. By default (drop_last=False), if the dataset size is not divisible by the number of processes, it pads by repeating a few indices so that every process gets the same number of samples; each sample is therefore seen once per epoch, apart from those few repeats.
    • Important: When using DistributedSampler, do not pass shuffle=True to the DataLoader. The sampler handles shuffling, and DataLoader raises an error if a sampler is combined with shuffle=True.
    • train_loader.sampler.set_epoch(epoch): This line in train_model is vital. The sampler seeds its shuffle with the epoch number, so calling set_epoch gives a different data split in each epoch; without it, every epoch uses the same ordering.
  5. Wrapping the Model with DDP:

    model = SimpleCNN().to(device)
    model = DDP(model, device_ids=[rank])
    
    • After moving your model to the correct device (.to(device)), you wrap it with DDP.
    • device_ids=[rank]: This tells DDP which GPU this particular process is managing. DDP then handles the gradient synchronization (the all_reduce operation) automatically during loss.backward().
  6. Conditional Printing:

    if rank == 0:
        print(f'Epoch: {epoch}, Batch: {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')
    

    Since multiple processes are running simultaneously, each will try to print to the console. To avoid overwhelming output, it’s a common practice to only print logs from rank == 0 (the “master” process for logging).

  7. Aggregating Test Metrics:

    stats = torch.tensor([test_loss, correct, num_samples], dtype=torch.float64, device=device)
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    total_loss, total_correct, total_samples = stats.tolist()
    

    Each process computes the summed loss, the correct count, and the number of samples on its own shard of the test data. dist.all_reduce with ReduceOp.SUM adds the tensor element-wise across all processes and leaves the result on every process, in place; packing the three values into one tensor needs a single collective instead of three. Count the samples each rank actually processed: len(test_loader.dataset) is the full test-set size on every rank, so summing it across ranks would overstate the total by a factor of world_size.

  8. Main Execution Block:

    if __name__ == '__main__':
        world_size = int(os.environ["WORLD_SIZE"])
        rank = int(os.environ["RANK"])
        main_ddp(rank, world_size)
    

    Instead of directly calling main_ddp(), we extract WORLD_SIZE and RANK from environment variables. These are automatically set by the torchrun utility, which is how we’ll launch our distributed script.
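To see the DistributedSampler behaviour from item 4 without launching any processes, the sketch below builds one sampler per simulated rank. Passing `num_replicas` and `rank` explicitly means no process group is required; the 10-item list standing in for a dataset and the world size of 4 are arbitrary choices that make the padding visible.

```python
from torch.utils.data.distributed import DistributedSampler

# Stand-in dataset: DistributedSampler only needs len(), so a list of 10 items works
dataset = list(range(10))
world_size = 4

def shards_for_epoch(epoch):
    shards = []
    for rank in range(world_size):
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                     shuffle=True, seed=0)
        sampler.set_epoch(epoch)  # same role as train_loader.sampler.set_epoch(epoch)
        shards.append(list(sampler))
    return shards

for epoch in range(2):
    print(f"epoch {epoch}: {shards_for_epoch(epoch)}")
```

Each rank receives ceil(10 / 4) = 3 indices, so 12 indices are handed out in total: all 10 samples appear, and 2 of them twice because of the padding. Changing the epoch reshuffles which rank sees which samples.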

3. Launching the DDP Script with torchrun

torchrun (which supersedes the deprecated torch.distributed.launch) is the recommended way to launch multi-process distributed training jobs. It starts the requested number of Python processes and exports the environment variables each one needs (MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK, WORLD_SIZE, and others). It does not pin processes to GPUs itself; the script does that with torch.cuda.set_device.

To run cnn_ddp.py on two GPUs (if you have them):

torchrun --nproc_per_node=2 cnn_ddp.py
  • --nproc_per_node=2: This tells torchrun to launch 2 processes on the current node, with RANK 0 and 1. Inside the script, torch.cuda.set_device(rank) maps them to GPU 0 and GPU 1. If you have 4 GPUs, you could use --nproc_per_node=4.
  • cnn_ddp.py: Your training script.
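If you later move to several machines, the GPU index should come from LOCAL_RANK rather than RANK. A small hypothetical helper, `dist_env`, that reads the variables torchrun exports and falls back to a single-process configuration when they are absent (so the same script can also be started with plain `python`):

```python
import os

def dist_env():
    """Read the process-placement variables exported by torchrun.

    Falls back to a single-process configuration when they are absent.
    """
    return {
        "rank": int(os.environ.get("RANK", "0")),              # global rank across all nodes
        "local_rank": int(os.environ.get("LOCAL_RANK", "0")),  # rank within this node = GPU index
        "world_size": int(os.environ.get("WORLD_SIZE", "1")),  # total number of processes
    }

env = dist_env()
print(env)
```

Under this assumption, a multi-node version of the script would call `torch.cuda.set_device(env["local_rank"])` and pass `device_ids=[env["local_rank"]]` to DDP, while still using `env["rank"]` for the rank-0 logging checks.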

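If you only have one GPU, you can still run with --nproc_per_node=1 to see the DDP setup working, though you won’t get actual parallelism benefits. Launching more processes than you have GPUs fails with this script, because torch.cuda.set_device(rank) would point at a GPU that doesn’t exist; to experiment with several processes on CPU, switch the backend to “gloo” and keep the model and tensors on the CPU.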

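Observe the output. You should see logs from Rank 0 only. Each rank processes 1/N of the training set per epoch, so len(train_loader) and the epoch time shrink roughly N-fold compared with the single-GPU baseline, minus communication overhead. Keep in mind that the global batch size is now 64 × N, so the loss curve will not match the baseline exactly.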
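The per-rank batch counts you will see in the logs can be predicted. A short sketch, assuming the per-GPU batch size of 64 from cnn_ddp.py, the 50,000-image CIFAR-10 training set, and DistributedSampler's default drop_last=False (the helper name is hypothetical):

```python
import math

def batches_per_rank(dataset_size, per_gpu_batch, world_size):
    # DistributedSampler (drop_last=False) gives every rank ceil(N / world_size) samples
    samples_per_rank = math.ceil(dataset_size / world_size)
    # DataLoader (drop_last=False) keeps the final partial batch
    return math.ceil(samples_per_rank / per_gpu_batch)

for ws in (1, 2, 4):
    print(f"world_size={ws}: global batch {64 * ws}, "
          f"{batches_per_rank(50_000, 64, ws)} batches per rank per epoch")
# world_size=1: global batch 64, 782 batches per rank per epoch
# world_size=2: global batch 128, 391 batches per rank per epoch
# world_size=4: global batch 256, 196 batches per rank per epoch
```

Fewer optimizer steps per epoch with a larger global batch is why the learning rate often needs retuning when the number of GPUs changes.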

Mini-Challenge: Customize Logging

Right now, only Rank 0 prints training progress. This is good for avoiding spam, but sometimes you might want to see specific information from other ranks, or perhaps aggregate metrics in a different way.

Challenge: Modify the train_model function in cnn_ddp.py so that:

  1. All ranks print a message at the start of their train_model function, indicating their rank and the device they are using.
  2. Only Rank 0 prints the detailed loss.item() every 100 batches.
  3. All ranks print their final Avg Loss for the epoch after the train_model function finishes, but ensure it’s still clear which rank is printing.

Hint: Use dist.get_rank() to determine the current process’s rank. Remember that rank is already passed to train_model.

What to observe/learn: This exercise reinforces your understanding of how to control logging and perform rank-specific operations in a distributed environment, which is crucial for debugging and monitoring. You’ll see how each process operates somewhat independently, yet coordinates.

Common Pitfalls & Troubleshooting

Distributed training introduces new complexities. Here are some common issues and how to approach them:

  1. Deadlocks and Hangs:

    • Cause: Most often due to mismatched all_reduce or barrier calls. If one process calls dist.barrier() and another doesn’t, or if they call collective operations in a different order, the processes will wait indefinitely for each other.
    • Troubleshooting:
      • Ensure all processes execute the same collective operations in the same order.
      • Check for conditional logic that might prevent some ranks from reaching a dist.barrier() or DDP synchronization point.
      • Verify that DDP is correctly initialized and wrapping the model.
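      • Set the environment variable TORCH_DISTRIBUTED_DEBUG=DETAIL (and NCCL_DEBUG=INFO when using the NCCL backend) to get detailed logs of collective calls, which helps pinpoint mismatched or missing collectives.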
  2. Incorrect DistributedSampler Usage:

    • Cause: Not passing the DistributedSampler to the DataLoader (so every rank iterates over the full dataset and the ranks duplicate each other’s work), or not calling train_loader.sampler.set_epoch(epoch) at the start of each epoch (so every epoch uses the same shuffled order). Combining a sampler with shuffle=True in the DataLoader is rejected with an error rather than failing silently.
    • Troubleshooting: Double-check your DataLoader and DistributedSampler configuration. Ensure set_epoch() is called.
  3. Synchronization Issues (Model Divergence):

    • Cause: If gradients are not properly averaged/reduced, or if model parameters are not synchronized before the next step, different model copies on different GPUs can diverge.
    • Troubleshooting:
      • Ensure your model is wrapped with DDP. DDP handles gradient synchronization automatically.
      • If implementing custom collective operations, ensure dist.all_reduce() or dist.broadcast() are used correctly.
      • Verify that all replicas start from identical weights. DDP broadcasts rank 0’s parameters and buffers to all other ranks when it wraps the model, so if you load a checkpoint, do so before wrapping (at least on rank 0).
  4. Networking Problems:

    • Cause: Firewalls blocking the MASTER_PORT, network latency, or incorrect MASTER_ADDR configuration, especially in multi-node setups.
    • Troubleshooting:
      • Check firewall rules to ensure the MASTER_PORT is open.
      • Verify network connectivity between nodes.
      • Use ping to test basic connectivity.
      • Ensure MASTER_ADDR is reachable from all worker nodes.
  5. Debugging Distributed Systems:

    • Challenge: Standard debuggers often attach to a single process.
    • Troubleshooting:
      • Print statements: Use rank == 0 for general logging, but strategically add print(f"Rank {dist.get_rank()}: Debugging point X reached") to specific sections to track individual process flow.
      • Logging: Use Python’s logging module and configure it to write to separate files for each rank.
      • Remote Debugging: Tools like VS Code’s remote debugging capabilities can attach to multiple processes, though setup can be complex.
      • Profiling Tools: Tools like NVIDIA Nsight Systems can help identify communication bottlenecks.
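The per-rank log files suggested under item 5 take only a few lines with Python’s standard logging module. A minimal sketch; the function name `get_rank_logger` and the `logs/rankN.log` layout are assumptions, and the rank is read from the RANK variable that torchrun exports:

```python
import logging
import os

def get_rank_logger(log_dir="logs"):
    """Return a logger that writes to one file per rank, e.g. logs/rank0.log."""
    # RANK is exported by torchrun; default to 0 so the sketch also runs without it
    rank = int(os.environ.get("RANK", "0"))
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger(f"train.rank{rank}")
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep per-rank messages out of the shared console
    if not logger.handlers:   # avoid duplicate handlers if called more than once
        handler = logging.FileHandler(os.path.join(log_dir, f"rank{rank}.log"))
        handler.setFormatter(logging.Formatter(
            f"%(asctime)s [rank {rank}] %(levelname)s %(message)s"))
        logger.addHandler(handler)
    return logger

log = get_rank_logger()
log.info("Debugging point X reached")
```

Because each rank writes its own file, a hang can be diagnosed by comparing the last line in every file to find the rank that stopped early.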

Summary

Congratulations! You’ve successfully navigated the complexities of distributed training and taken a significant step towards becoming a true AI scaling expert.

Here are the key takeaways from this chapter:

  • Necessity of Scale: Distributed training is crucial for handling models that exceed single-device memory and datasets that require immense training time.
  • Data Parallelism: The most common approach, where each device trains a copy of the model on a different subset of data, with gradients aggregated. PyTorch’s DistributedDataParallel (DDP) is a prime example.
  • Model Parallelism: Used when the model itself is too large for a single device, splitting layers or parts of layers across devices.
  • PyTorch DDP Implementation: You learned how to set up the distributed environment, use DistributedSampler for data distribution, wrap your model with DDP, and launch your script using torchrun.
  • Communication Overhead: The primary challenge in distributed systems, which is why efficient communication libraries (like NCCL) and fast interconnects (like NVLink and InfiniBand) matter.
  • Troubleshooting: You’re now aware of common pitfalls like deadlocks, data sampling issues, and how to debug a multi-process environment.

By mastering these concepts, you’re now capable of building and training AI models that push the boundaries of what’s possible on single machines. This skill is indispensable for working with the cutting-edge large models that define modern AI.

What’s next? While we’ve learned to scale training, effective AI engineering also requires meticulous tracking of experiments, understanding subtle model behaviors, and debugging when things go wrong. In the next chapter, we’ll dive into Experimentation, Tracking, and Debugging Model Behavior, equipping you with the tools and methodologies to manage your AI development lifecycle like a professional.




This page is AI-assisted and reviewed. It references official documentation and recognized resources where relevant.