PyTorch 分布式训练师与 AWS 集成实战教程
随着深度学习模型规模的不断扩大和数据量的持续增长,单机训练方式已难以满足高效训练的需求。分布式训练成为一种必然选择,它通过将计算任务分布在多个 GPU 或服务器上,显著提升了训练效率。AWS 作为全球领先的云计算平台,提供了强大的计算资源和灵活的服务架构,为分布式训练提供了理想的运行环境。本文将深入探讨如何在 AWS 上搭建和运行 PyTorch 分布式训练系统,通过实际案例助力您高效开展深度学习项目。
一、AWS 环境搭建
(一)创建实例
在 AWS 上创建两个多 GPU 节点,选择适合深度学习任务的实例类型,如 p2.8xlarge
,其配备 8 个 NVIDIA Tesla K80 GPU,为分布式训练提供强大的计算支持。
(二)配置安全组
确保实例之间的通信畅通无阻,是分布式训练成功的关键。创建一个新的安全组,并配置入站和出站规则,允许节点之间所有类型的数据流量。具体操作步骤如下:
- 登录 AWS 管理控制台,选择 “EC2” 服务。
- 在左侧导航栏中,选择 “安全组”。
- 点击 “创建安全组”,设置安全组名称和描述。
- 在 “入站规则” 栏中,添加规则允许来自新安全组的 “所有流量”。
- 在 “出站规则” 栏中,同样添加规则允许流向新安全组的 “所有流量”。
(三)获取节点 IP 地址
在 EC2 仪表板中找到正在运行的实例,记录每个节点的 IPv4 公网 IP 和私网 IP。公网 IP 用于 SSH 连接,私网 IP 用于节点间通信。这些 IP 地址在后续配置中将被频繁使用。
二、环境配置
(一)创建并激活 conda 环境
在每个节点上创建并激活一个新的 conda 环境,为 PyTorch 提供干净的运行环境:
conda create -n pytorch_env python=3.8
conda activate pytorch_env
(二)安装 PyTorch 和 torchvision
安装支持 CUDA 的 PyTorch 夜度构建版本以及从源代码构建的 torchvision:
pip install torch --index-url https://download.pytorch.org/whl/nightly/cu118
cd ~
git clone https://github.com/pytorch/vision.git
cd vision
python setup.py install
(三)设置 NCCL 网络接口
为了优化 GPU 之间的通信,设置 NCCL 套接字的网络接口名称。通过运行 ifconfig
命令确定网络接口名称,并设置环境变量:
export NCCL_SOCKET_IFNAME=ens3
三、分布式训练代码实现
(一)导入必要的模块
import time
import sys
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
(二)定义辅助函数和类
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
(三)定义训练和验证函数
def train(train_loader, model, criterion, optimizer, epoch):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
model.train()
end = time.time()
for i, (input, target) in enumerate(train_loader):
data_time.update(time.time() - end)
input = input.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
output = model(input)
loss = criterion(output, target)
prec1, prec5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(prec1[0], input.size(0))
top5.update(prec5[0], input.size(0))
optimizer.zero_grad()
loss.backward()
optimizer.step()
batch_time.update(time.time() - end)
end = time.time()
if i % 10 == 0:
print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
epoch, i, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses, top1=top1, top5=top5))
def validate(val_loader, model, criterion):
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
model.eval()
with torch.no_grad():
end = time.time()
for i, (input, target) in enumerate(val_loader):
input = input.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
output = model(input)
loss = criterion(output, target)
prec1, prec5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(prec1[0], input.size(0))
top5.update(prec5[0], input.size(0))
batch_time.update(time.time() - end)
end = time.time()
if i % 100 == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(val_loader), batch_time=batch_time, loss=losses,
top1=top1, top5=top5))
print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
.format(top1=top1, top5=top5))
return top1.avg
(四)初始化进程组
def main():
batch_size = 32
workers = 2
num_epochs = 2
starting_lr = 0.1
world_size = 4
dist_backend = 'nccl'
dist_url = "tcp://<node0-privateIP>:23456" # 替换为实际的节点私有 IP
print("Initialize Process Group...")
dist.init_process_group(backend=dist_backend, init_method=dist_url,
rank=int(sys.argv[1]), world_size=world_size)
local_rank = int(sys.argv[2])
dp_device_ids = [local_rank]
torch.cuda.set_device(local_rank)
print("Initialize Model...")
model = models.resnet18(pretrained=False).cuda()
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=dp_device_ids)
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), starting_lr, momentum=0.9, weight_decay=1e-4)
print("Initialize Dataloaders...")
transform = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = datasets.STL10(root='./data', split='train', download=True, transform=transform)
valset = datasets.STL10(root='./data', split='test', download=True, transform=transform)
train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=(train_sampler is None),
num_workers=workers, pin_memory=False,
sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
shuffle=False, num_workers=workers,
pin_memory=False)
best_prec1 = 0
for epoch in range(num_epochs):
train_sampler.set_epoch(epoch)
adjust_learning_rate(starting_lr, optimizer, epoch)
train(train_loader, model, criterion, optimizer, epoch)
prec1 = validate(val_loader, model, criterion)
best_prec1 = max(prec1, best_prec1)
print("Epoch Summary: ")
print("\tEpoch Accuracy: {}".format(prec1))
print("\tBest Accuracy: {}".format(best_prec1))
if __name__ == "__main__":
main()
四、运行训练
在每个节点上打开多个 SSH 终端,分别运行以下命令:
- 在 node0 的第一个终端上:
python main.py 0 0
- 在 node0 的第二个终端上:
python main.py 1 1
- 在 node1 的第一个终端上:
python main.py 2 0
- 在 node1 的第二个终端上:
python main.py 3 1
以上内容已同步发布至编程狮网站,欢迎访问编程狮 PyTorch 教程获取更多深度学习和 PyTorch 相关的优质教程。在学习过程中,如果您有任何疑问或需要进一步的技术支持,欢迎加入编程狮社区,与广大编程爱好者和专家进行交流和互动。
更多建议: