debug-distributed
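A GitHub-hosted debugging guide for distributed training in AReaL (FSDP2, TP, etc.). It covers troubleshooting deadlocks, wrong results, OOM, and communication errors. It emphasizes the minimal-reproduction principle and provides concrete steps and code examples for environment configuration, py-spy stack analysis, and DTensor validation.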
Install
npx skills add areal-project/AReaL --skill debug-distributed -g -y
SKILL.md
Frontmatter
{
"name": "debug-distributed",
"description": "Guide for debugging distributed training issues in AReaL. Use when user encounters hangs, wrong results, OOM, or communication errors."
}
Debug Distributed Training
Debugging guide for distributed training issues in AReaL (FSDP2, TP, CP, EP).
When to Use
This skill is triggered when:
- Training hangs or deadlocks
- Results differ across ranks or are numerically wrong
- OOM errors in distributed settings
- NCCL/communication errors or device mesh issues
Debugging Principles
Minimal Reproduction
Always follow the minimal-demo principle: reproduce the problem with the least code possible, so the failing operation is isolated quickly.
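# Bad: debug inside the full training loop
# Good: create a minimal script (launch with torchrun so LOCAL_RANK is set)
import os
import torch
import torch.distributed as dist
# Bind each process to its own GPU before creating the NCCL group
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group("nccl")
rank = dist.get_rank()
# Reproduce the exact operation that fails
tensor = torch.ones(10).cuda()
dist.all_reduce(tensor)  # <-- Isolate the failing op
print(f"Rank {rank}: {tensor}")
dist.destroy_process_group()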
Reduction strategy:
- Remove unrelated model components
- Use small tensor sizes
- Reduce world_size to minimum (e.g., 2 GPUs)
- Remove torch.compile if possible
- Disable activation checkpointing
Step-by-Step Debugging Guide
1. Hang Debugging (Deadlocks, Synchronization)
Environment Variables for Debugging:
# Full debug logging
export TORCH_DISTRIBUTED_DEBUG=DETAIL
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
# torch.compile debugging
export TORCH_LOGS="+dynamo,recompiles"
export TORCHDYNAMO_VERBOSE=1
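When the launch command is hard to change (for example under a job scheduler), the same variables can be set from Python instead. The following is a minimal sketch, assuming it runs at the very top of the entry script, before the process group is created. `enable_distributed_debug` is a hypothetical helper name, not an AReaL API.

```python
import os

# Debug variables from the shell exports above; values mirror them.
DEBUG_ENV = {
    "TORCH_DISTRIBUTED_DEBUG": "DETAIL",
    "NCCL_DEBUG": "INFO",
    "NCCL_DEBUG_SUBSYS": "ALL",
}


def enable_distributed_debug(env=os.environ, overwrite=False):
    """Set debug variables and return the names that were changed.

    Hypothetical helper (not part of AReaL). Values the user already
    exported are kept unless overwrite=True.
    """
    changed = []
    for key, value in DEBUG_ENV.items():
        if overwrite or key not in env:
            env[key] = value
            changed.append(key)
    return changed
```

Keeping user-exported values by default means a quieter setting such as `NCCL_DEBUG=WARN` on the command line still wins over the script.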
Dump Call Stack with py-spy (for hung processes):
# Find process IDs
ps aux | grep python
# Dump call stack of specific rank
py-spy dump --pid <PID>
# Record flame graph for performance analysis
py-spy record -o profile.svg --pid <PID> --duration 30
Common Causes:
- Mismatched Collectives: One rank calls all_reduce, another doesn't.
- Wrong Process Group: Using wrong group for collective.
- Tensor Shape Mismatch: Different shapes across ranks.
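A mismatched collective usually shows up as the first point where the ranks' call sequences diverge. The sketch below assumes each rank has recorded the names of the collectives it issued (for example by logging before each call) and that those lists were collected offline. Both `first_divergence` and the trace format are hypothetical.

```python
from itertools import zip_longest


def first_divergence(traces):
    """Return (step, {rank: op}) for the first step where ranks disagree, else None.

    `traces` maps rank -> ordered list of collective names issued by that rank.
    A rank that ran out of calls is reported with op=None.
    """
    ranks = sorted(traces)
    for step, ops in enumerate(zip_longest(*(traces[r] for r in ranks))):
        if len(set(ops)) > 1:
            return step, dict(zip(ranks, ops))
    return None


traces = {
    0: ["all_gather", "all_reduce", "reduce_scatter"],
    1: ["all_gather", "reduce_scatter"],  # rank 1 skipped all_reduce
}
print(first_divergence(traces))
```

The reported step points at the code path to inspect: here rank 1 reaches `reduce_scatter` while rank 0 is still waiting in `all_reduce`.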
Debug Steps:
# Verify group membership
mesh = parallel_dims.get_mesh("dp_shard_cp")
group = mesh.get_group()
print(f"Rank {dist.get_rank()}: group size = {dist.get_world_size(group)}")
# Print shapes on all ranks
print(f"Rank {dist.get_rank()}: tensor.shape = {tensor.shape}")
dist.barrier()
Timeout Adjustment (for debugging only):
from areal.engine.core.distributed import patch_dist_group_timeout
from datetime import timedelta
patch_dist_group_timeout(timedelta(minutes=30))
2. Wrong Results (Gradient, Reduction Issues)
Check DTensor Placements:
from torch.distributed.tensor import DTensor
for name, param in model.named_parameters():
    if isinstance(param, DTensor):
        print(f"Param {name}: placements={param.placements}, mesh={param.device_mesh}")
Verify Gradient Reduction:
for name, param in model.named_parameters():
    if param.grad is not None:
        print(f"Rank {dist.get_rank()}: {name} grad_sum = {param.grad.sum().item()}")
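Printed sums are easier to act on once they are compared with a tolerance. This sketch assumes the parameters being compared are replicated across the ranks in question (local shards of a sharded gradient legitimately differ per rank) and that the per-rank sums have already been collected, for instance with `dist.all_gather_object`. `mismatched_grads` is a hypothetical helper.

```python
import math


def mismatched_grads(per_rank_sums, rel_tol=1e-5, abs_tol=1e-8):
    """Return {name: {rank: value}} for parameters whose gradient sums disagree.

    `per_rank_sums` maps rank -> {param_name: grad_sum} and must be non-empty.
    The lowest rank is the reference. A parameter missing on some rank, or a
    NaN sum, is also reported.
    """
    ranks = sorted(per_rank_sums)
    ref = per_rank_sums[ranks[0]]
    names = set().union(*(per_rank_sums[r] for r in ranks))
    bad = {}
    for name in sorted(names):
        values = {r: per_rank_sums[r].get(name) for r in ranks}
        ref_val = ref.get(name)
        ok = ref_val is not None and all(
            v is not None and math.isclose(v, ref_val, rel_tol=rel_tol, abs_tol=abs_tol)
            for v in values.values()
        )
        if not ok:
            bad[name] = values
    return bad
```

A tolerance is used rather than exact equality because reduction order can introduce small floating-point differences even when the logic is correct.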
3. OOM Issues (Memory, Sharding)
Check Memory Usage:
print(f"Rank {dist.get_rank()}: "
      f"allocated={torch.cuda.memory_allocated()/1e9:.2f}GB, "
      f"reserved={torch.cuda.memory_reserved()/1e9:.2f}GB")
Check FSDP Coverage:
for name, param in model.named_parameters():
    is_dtensor = isinstance(param, DTensor)
    print(f"{name}: is_dtensor={is_dtensor}, shape={param.shape}")
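With many parameters, per-line output is hard to scan. The sketch below condenses it, assuming `(name, is_dtensor, numel)` triples were collected from a loop like the one above. `coverage_report` is a hypothetical helper. Note that a DTensor with a replicated placement is still fully materialized on every rank, so the placements should be checked as well.

```python
def coverage_report(params):
    """Summarize DTensor coverage.

    `params` is an iterable of (name, is_dtensor, numel) triples.
    Returns (fraction of elements held as DTensor, names of plain-tensor params).
    """
    total = 0
    wrapped = 0
    plain_names = []
    for name, is_dtensor, numel in params:
        total += numel
        if is_dtensor:
            wrapped += numel
        else:
            plain_names.append(name)
    fraction = wrapped / total if total else 0.0
    return fraction, plain_names
```

Large entries in the returned name list are the first candidates to check when one rank holds more memory than expected.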
4. Communication Errors
| Error | Cause | Solution |
|---|---|---|
| NCCL WARN Cuda failure | GPU communication | Check NCCL version, GPU topology |
| RuntimeError: Timed out | Rank synchronization | Increase timeout, check code paths |
| Invalid device mesh | Mesh configuration | Verify world_size = dp * tp * cp |
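The mesh-size check in the last row can be run before any process group is created. This is a minimal sketch of the arithmetic only. It does not model how AReaL's parallel dims compose additional dimensions, and the name `check_mesh_sizes` and its keyword names are illustrative.

```python
import math


def check_mesh_sizes(world_size, **dims):
    """Raise ValueError unless the product of the parallel degrees equals world_size.

    `dims` are named degrees, e.g. dp=2, tp=2, cp=1 (names are illustrative).
    """
    for name, size in dims.items():
        if size < 1:
            raise ValueError(f"{name} must be >= 1, got {size}")
    product = math.prod(dims.values())
    if product != world_size:
        detail = " * ".join(f"{k}={v}" for k, v in dims.items())
        raise ValueError(f"{detail} = {product}, but world_size = {world_size}")
    return True
```

For example, `check_mesh_sizes(8, dp=2, tp=2, cp=2)` passes, while `check_mesh_sizes(8, dp=2, tp=2, cp=1)` raises with a message that shows both numbers.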
Debugging Tools
Environment Variables Reference
| Variable | Purpose |
|---|---|
| TORCH_DISTRIBUTED_DEBUG=DETAIL | Detailed distributed logging |
| NCCL_DEBUG=INFO | NCCL communication logging |
| NCCL_DEBUG_SUBSYS=ALL | All NCCL subsystems |
| TORCH_LOGS="+dynamo,recompiles" | torch.compile logging |
| TORCHDYNAMO_VERBOSE=1 | Dynamo verbose output |
| CUDA_LAUNCH_BLOCKING=1 | Synchronous CUDA (slow, for debugging) |
py-spy for Call Stack Analysis
# Install
pip install py-spy
# Dump call stack of hung process
py-spy dump --pid <PID>
# Dump all Python processes
pgrep -f python | xargs -I {} py-spy dump --pid {}
# Record flame graph
py-spy record -o profile.svg --pid <PID> --duration 30
Rank-Conditional Printing
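def print_all_ranks(msg):
    # Ranks take turns printing; every rank must call this function
    for r in range(dist.get_world_size()):
        if dist.get_rank() == r:
            print(f"[Rank {r}] {msg}", flush=True)
        dist.barrier()  # all ranks reach the barrier on every iteration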
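The barrier-based helper is itself a collective, so it hangs if any rank fails to reach it, which is exactly the situation when debugging a deadlock. An alternative sketch, assuming the job is launched with torchrun (which sets the `RANK` environment variable), prefixes each line with the rank and lets every rank log independently. `get_rank_logger` is a hypothetical name.

```python
import logging
import os
import sys


def get_rank_logger(name="debug", stream=sys.stderr):
    """Return a logger whose lines start with "[Rank N]".

    Reads RANK from the environment (set by torchrun); defaults to 0 otherwise.
    """
    rank = int(os.environ.get("RANK", "0"))
    logger = logging.getLogger(f"{name}.rank{rank}")
    if not logger.handlers:  # avoid duplicate handlers on repeated calls
        handler = logging.StreamHandler(stream)
        handler.setFormatter(logging.Formatter(f"[Rank {rank}] %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
        logger.propagate = False
    return logger
```

Output from different ranks may interleave, but no rank ever blocks on another just to print.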
Check Device Mesh
def debug_mesh(parallel_dims):
    mesh = parallel_dims.world_mesh
    for dim_name in mesh.mesh_dim_names:
        submesh = parallel_dims.get_mesh(dim_name)
        if submesh is not None:
            print(f"Rank {dist.get_rank()}: {dim_name} size={submesh.size()}")
Validate Tensor Consistency
def check_tensor_consistency(tensor, name, group=None):
    local_sum = tensor.sum().item()
    tensor_sums = [None] * dist.get_world_size(group)
    dist.all_gather_object(tensor_sums, local_sum, group=group)
    if dist.get_rank() == 0 and len(set(tensor_sums)) > 1:
        print(f"WARNING: {name} inconsistent: {tensor_sums}")
Key Files Reference
| Component | File |
|---|---|
| Parallel Dims | areal/experimental/models/archon/parallel_dims.py |
| Expert Parallel | areal/experimental/models/archon/expert_parallel.py |
| Ulysses (CP) | areal/experimental/models/archon/ulysses.py |
| FSDP/TP Apply | areal/experimental/models/archon/qwen2/infra/parallelize.py |
Version History
- d99124e Current 2026-07-05 09:13


