PyTorch 训练分类器
图像分类作为计算机视觉领域的核心任务,有着广泛的应用前景,如自动驾驶、医疗影像诊断、安防监控等。PyTorch 凭借其强大的功能和灵活的操作,为开发者提供了一个高效构建和训练图像分类器的平台。
一、数据准备:构建模型的基石
在训练分类器之前,我们需要准备合适的训练数据。这里我们将使用经典的 CIFAR10 数据集,它包含 10 个类别的彩色图像,每个类别有 6000 张图像,图像大小为 32x32 像素。
使用 torchvision 加载 CIFAR10 数据集
import torch
import torchvision
import torchvision.transforms as transforms
## 数据预处理:将图像转换为张量,并进行标准化
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
## 下载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
## 定义类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
数据可视化
通过可视化部分训练数据,可以帮助我们更好地了解数据集的结构和内容。
import matplotlib.pyplot as plt
import numpy as np
## 定义一个函数用于显示图像
def imshow(img):
img = img / 2 + 0.5 # 反标准化
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
## 获取一批训练数据
dataiter = iter(trainloader)
images, labels = next(dataiter)
## 显示图像
imshow(torchvision.utils.make_grid(images))
## 打印标签
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))
二、定义卷积神经网络:构建分类器的核心
卷积神经网络(CNN)是处理图像数据的主流网络结构,它通过卷积层自动提取图像特征,能够有效捕捉图像中的空间信息。
定义 CNN 架构
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 定义卷积层和池化层
self.conv1 = nn.Conv2d(3, 6, 5) # 输入通道3,输出通道6,卷积核大小5
self.pool = nn.MaxPool2d(2, 2) # 最大池化层,窗口大小2,步长2
self.conv2 = nn.Conv2d(6, 16, 5) # 输入通道6,输出通道16,卷积核大小5
# 定义全连接层
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
# 前向传播过程
x = self.pool(F.relu(self.conv1(x))) # 卷积 + 激活 + 池化
x = self.pool(F.relu(self.conv2(x))) # 卷积 + 激活 + 池化
x = x.view(-1, 16 * 5 * 5) # 展平操作
x = F.relu(self.fc1(x)) # 全连接 + 激活
x = F.relu(self.fc2(x)) # 全连接 + 激活
x = self.fc3(x) # 输出层
return x
net = Net()
print(net)
三、定义损失函数和优化器:模型训练的指引
损失函数用于衡量模型预测结果与真实标签之间的差距,优化器则负责根据损失函数的梯度信息更新模型参数。
import torch.optim as optim
## 使用交叉熵损失函数和随机梯度下降优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
四、训练网络:提升模型性能的关键
训练过程是模型学习数据特征、优化参数的关键环节。我们需要多次迭代训练数据,逐步调整模型参数,以降低损失函数的值。
for epoch in range(2): # 遍历数据集多次
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# 获取输入数据和标签
inputs, labels = data
# 清空梯度缓存
optimizer.zero_grad()
# 前向传播 + 反向传播 + 优化
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 打印统计信息
running_loss += loss.item()
if i % 2000 == 1999: # 每 2000 个小批量打印一次
print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
running_loss = 0.0
print('Finished Training')
五、保存和加载模型:模型持久化与复用
训练完成后,我们可以将模型参数保存到文件中,以便后续加载和使用。
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
加载模型参数:
net = Net()
net.load_state_dict(torch.load(PATH))
六、测试网络:评估模型性能
在测试集上评估模型的性能,计算分类准确率。
correct = 0
total = 0
with torch.no_grad(): # 在测试阶段不需要计算梯度
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')
七、在 GPU 上训练:加速模型训练
如果电脑配备 GPU,可以利用 GPU 加速模型训练过程,显著提升训练速度。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
## 将输入数据和标签移动到 GPU 上
inputs, labels = inputs.to(device), labels.to(device)
八、总结
通过本教程,你已经掌握了使用 PyTorch 训练图像分类器的核心步骤,包括数据准备、网络定义、模型训练、性能评估以及 GPU 加速等关键技术。在编程狮平台的进一步学习中,你可以尝试以下方向:
- 探索更多数据集 :除了 CIFAR10,还可以尝试 ImageNet、MNIST 等其他知名数据集,挑战不同难度的图像分类任务。
- 优化网络结构 :通过调整卷积层、池化层、全连接层的数量和参数,或者尝试不同的网络架构(如 ResNet、AlexNet 等),提升模型性能。
- 学习高级技巧 :深入了解数据增强、正则化、迁移学习等高级技巧,进一步提高模型的泛化能力和鲁棒性。
更多建议: