PyTorch torch.nn 到底是什么?

2025-06-24 14:06 更新

PyTorchtorch.nn 模块是构建和训练神经网络的核心。它提供了丰富的类和函数,帮助开发者轻松构建、训练和部署神经网络模型。本教程将详细讲解 torch.nn 的基本概念和使用方法。

一、torch.nn 模块概述

torch.nn 是 PyTorch 中用于构建神经网络的模块,它提供了以下关键组件:

(一)nn.Module

nn.Module 是所有神经网络模块的基类。您可以将其视为一个容器,用于存储模型的参数、缓冲区和其他子模块。每个 nn.Module 都可以定义自己的前向传播方法,描述如何将输入数据转换为输出数据。

示例:定义一个简单的线性回归模型

import torch
import torch.nn as nn


class LinearRegression(nn.Module):
    def __init__(self, input_size, output_size):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(input_size, output_size)


    def forward(self, x):
        return self.linear(x)


model = LinearRegression(10, 1)
print(model)

(二)nn.Parameter

nn.Parameter 是用于表示模型可学习参数的类。它本质上是一个张量,但具有 requires_grad 属性,默认为 True,表示在反向传播过程中需要计算梯度。

(三)torch.nn.functional

torch.nn.functional 模块包含了一系列常用的函数,如激活函数、损失函数等。这些函数可以用于构建神经网络的不同组件。

示例:使用 ReLU 激活函数

import torch.nn.functional as F


x = torch.randn(2, 3)
y = F.relu(x)
print(y)

(四)优化器

torch.optim 模块提供了多种优化算法,用于更新模型的参数,以最小化损失函数。

示例:使用 SGD 优化器

import torch.optim as optim


model = LinearRegression(10, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)


## 假设输入和目标
inputs = torch.randn(5, 10)
targets = torch.randn(5, 1)


## 前向传播
outputs = model(inputs)
loss = criterion(outputs, targets)


## 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()

二、构建和训练神经网络

(一)数据准备

在训练神经网络之前,需要准备训练数据。通常,数据是以张量的形式存储的。可以使用 torch.utils.data.Datasettorch.utils.data.DataLoader 来加载和迭代数据。

示例:加载 MNIST 数据集

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms


## 下载并加载 MNIST 数据集
train_dataset = MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = MNIST(root='./data', train=False, transform=transforms.ToTensor())


## 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

(二)模型定义

定义一个神经网络模型,包括输入层、隐藏层和输出层。

示例:定义一个简单的全连接网络

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # 输入层到隐藏层
        self.fc2 = nn.Linear(128, 10)       # 隐藏层到输出层


    def forward(self, x):
        x = x.view(-1, 28 * 28)  # 展平图像
        x = F.relu(self.fc1(x))   # 使用 ReLU 激活函数
        x = self.fc2(x)
        return x


model = SimpleNN()

(三)训练循环

使用定义好的模型、损失函数和优化器,进行模型的训练。

示例:训练循环

## 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)


## 训练模型
num_epochs = 5
for epoch in range(num_epochs):
    for images, labels in train_loader:
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)


        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

(四)模型评估

在测试集上评估训练好的模型的性能。

示例:模型评估

## 测试模型
model.eval()  # 设置为评估模式
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()


    print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

通过本教程,您可以在编程狮(W3Cschool)平台上轻松掌握 PyTorch 的 torch.nn 模块的使用方法。torch.nn 是构建和训练神经网络的核心工具,希望大家能够学以致用,在实际项目中灵活应用这些技术。在编程狮(W3Cschool)学习更多相关内容,提升你的深度学习开发技能。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号