PyTorch 训练分类器

2025-06-18 17:17 更新

图像分类作为计算机视觉领域的核心任务,有着广泛的应用前景,如自动驾驶、医疗影像诊断、安防监控等。PyTorch 凭借其强大的功能和灵活的操作,为开发者提供了一个高效构建和训练图像分类器的平台。

一、数据准备:构建模型的基石

在训练分类器之前,我们需要准备合适的训练数据。这里我们将使用经典的 CIFAR10 数据集,它包含 10 个类别的彩色图像,每个类别有 6000 张图像,图像大小为 32x32 像素。

使用 torchvision 加载 CIFAR10 数据集

import torch
import torchvision
import torchvision.transforms as transforms


## 数据预处理:将图像转换为张量,并进行标准化
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


## 下载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)


testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)


## 定义类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

数据可视化

通过可视化部分训练数据,可以帮助我们更好地了解数据集的结构和内容。

import matplotlib.pyplot as plt
import numpy as np


## 定义一个函数用于显示图像
def imshow(img):
    img = img / 2 + 0.5  # 反标准化
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


## 获取一批训练数据
dataiter = iter(trainloader)
images, labels = next(dataiter)


## 显示图像
imshow(torchvision.utils.make_grid(images))


## 打印标签
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

二、定义卷积神经网络:构建分类器的核心

卷积神经网络(CNN)是处理图像数据的主流网络结构,它通过卷积层自动提取图像特征,能够有效捕捉图像中的空间信息。

定义 CNN 架构

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 定义卷积层和池化层
        self.conv1 = nn.Conv2d(3, 6, 5)  # 输入通道3,输出通道6,卷积核大小5
        self.pool = nn.MaxPool2d(2, 2)   # 最大池化层,窗口大小2,步长2
        self.conv2 = nn.Conv2d(6, 16, 5) # 输入通道6,输出通道16,卷积核大小5
        # 定义全连接层
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)


    def forward(self, x):
        # 前向传播过程
        x = self.pool(F.relu(self.conv1(x)))  # 卷积 + 激活 + 池化
        x = self.pool(F.relu(self.conv2(x)))  # 卷积 + 激活 + 池化
        x = x.view(-1, 16 * 5 * 5)            # 展平操作
        x = F.relu(self.fc1(x))               # 全连接 + 激活
        x = F.relu(self.fc2(x))               # 全连接 + 激活
        x = self.fc3(x)                       # 输出层
        return x


net = Net()
print(net)

三、定义损失函数和优化器:模型训练的指引

损失函数用于衡量模型预测结果与真实标签之间的差距,优化器则负责根据损失函数的梯度信息更新模型参数。

import torch.optim as optim


## 使用交叉熵损失函数和随机梯度下降优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

四、训练网络:提升模型性能的关键

训练过程是模型学习数据特征、优化参数的关键环节。我们需要多次迭代训练数据,逐步调整模型参数,以降低损失函数的值。

for epoch in range(2):  # 遍历数据集多次
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # 获取输入数据和标签
        inputs, labels = data


        # 清空梯度缓存
        optimizer.zero_grad()


        # 前向传播 + 反向传播 + 优化
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()


        # 打印统计信息
        running_loss += loss.item()
        if i % 2000 == 1999:    # 每 2000 个小批量打印一次
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0


print('Finished Training')

五、保存和加载模型:模型持久化与复用

训练完成后,我们可以将模型参数保存到文件中,以便后续加载和使用。

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

加载模型参数:

net = Net()
net.load_state_dict(torch.load(PATH))

六、测试网络:评估模型性能

在测试集上评估模型的性能,计算分类准确率。

correct = 0
total = 0
with torch.no_grad():  # 在测试阶段不需要计算梯度
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()


print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')

七、在 GPU 上训练:加速模型训练

如果电脑配备 GPU,可以利用 GPU 加速模型训练过程,显著提升训练速度。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)


## 将输入数据和标签移动到 GPU 上
inputs, labels = inputs.to(device), labels.to(device)

八、总结

通过本教程,你已经掌握了使用 PyTorch 训练图像分类器的核心步骤,包括数据准备、网络定义、模型训练、性能评估以及 GPU 加速等关键技术。在编程狮平台的进一步学习中,你可以尝试以下方向:

  • 探索更多数据集 :除了 CIFAR10,还可以尝试 ImageNet、MNIST 等其他知名数据集,挑战不同难度的图像分类任务。
  • 优化网络结构 :通过调整卷积层、池化层、全连接层的数量和参数,或者尝试不同的网络架构(如 ResNet、AlexNet 等),提升模型性能。
  • 学习高级技巧 :深入了解数据增强、正则化、迁移学习等高级技巧,进一步提高模型的泛化能力和鲁棒性。
以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号