PyTorch torch.nn
一、PyTorch 简介
PyTorch 是一款开源的机器学习框架,它具有动态计算图、强大的 GPU 支持和易于使用的 API 等优势,广泛应用于深度学习领域。
二、torch.nn 模块基础
torch.nn
是 PyTorch 中用于构建神经网络的核心模块,它提供了各种神经网络层的实现,如卷积层、线性层、池化层等,以及常用的激活函数和损失函数。
1. 常见神经网络层
线性层(Linear Layer)
线性层是最基本的神经网络层之一,它对输入数据进行线性变换。
import torch
import torch.nn as nn
## 创建一个线性层,输入特征数为 5,输出特征数为 3
linear_layer = nn.Linear(5, 3)
## 创建一个输入张量,形状为 (2, 5)
input_tensor = torch.randn(2, 5)
## 通过线性层进行前向传播
output = linear_layer(input_tensor)
print("输入张量形状:", input_tensor.shape)
print("输出张量形状:", output.shape)
运行结果示例:
输入张量形状: torch.Size([2, 5])
输出张量形状: torch.Size([2, 3])
卷积层(Convolutional Layer)
卷积层是卷积神经网络(CNN)中的关键组件,用于提取数据的空间特征。
## 创建一个一维卷积层
conv1d_layer = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
## 创建一个输入张量,形状为 (10, 16, 50)
input_tensor = torch.randn(10, 16, 50)
## 通过卷积层进行前向传播
output = conv1d_layer(input_tensor)
print("输入张量形状:", input_tensor.shape)
print("输出张量形状:", output.shape)
运行结果示例:
输入张量形状: torch.Size([10, 16, 50])
输出张量形状: torch.Size([10, 32, 50])
池化层(Pooling Layer)
池化层用于对数据进行下采样,减少数据量,同时保留重要特征。
## 创建一个最大池化层
max_pool_layer = nn.MaxPool2d(kernel_size=2, stride=2)
## 创建一个输入张量,形状为 (20, 16, 50, 32)
input_tensor = torch.randn(20, 16, 50, 32)
## 通过池化层进行前向传播
output = max_pool_layer(input_tensor)
print("输入张量形状:", input_tensor.shape)
print("输出张量形状:", output.shape)
运行结果示例:
输入张量形状: torch.Size([20, 16, 50, 32])
输出张量形状: torch.Size([20, 16, 25, 16])
2. 激活函数
激活函数为神经网络引入非线性,使网络能够学习复杂的模式。
ReLU 激活函数
ReLU(Rectified Linear Unit)是常用的激活函数之一,它将所有负值置零。
## 创建一个 ReLU 激活函数
relu = nn.ReLU()
## 创建一个输入张量
input_tensor = torch.randn(2, 3)
## 应用 ReLU 激活函数
output = relu(input_tensor)
print("输入张量:\n", input_tensor)
print("输出张量:\n", output)
运行结果示例:
输入张量:
tensor([[ 0.1234, -0.5678, 0.9012],
[-1.2345, 0.6789, -0.3456]])
输出张量:
tensor([[ 0.1234, 0.0000, 0.9012],
[ 0.0000, 0.6789, 0.0000]])
Sigmoid 激活函数
Sigmoid 函数将输入值映射到 (0, 1) 区间,常用于二分类问题。
## 创建一个 Sigmoid 激活函数
sigmoid = nn.Sigmoid()
## 创建一个输入张量
input_tensor = torch.randn(2, 3)
## 应用 Sigmoid 激活函数
output = sigmoid(input_tensor)
print("输入张量:\n", input_tensor)
print("输出张量:\n", output)
运行结果示例:
输入张量:
tensor([[ 0.1234, -0.5678, 0.9012],
[-1.2345, 0.6789, -0.3456]])
输出张量:
tensor([[0.5312, 0.3622, 0.7109],
[0.2242, 0.6642, 0.4156]])
3. 损失函数
损失函数用于衡量模型预测结果与真实值之间的差异。
均方误差损失(MSELoss)
均方误差损失是回归问题中常用的损失函数,它计算预测值与真实值之间的平方差的平均值。
## 创建一个均方误差损失函数
mse_loss = nn.MSELoss()
## 创建预测值和真实值张量
predictions = torch.randn(3, 5, requires_grad=True)
targets = torch.randn(3, 5)
## 计算损失
loss = mse_loss(predictions, targets)
print("损失值:", loss.item())
运行结果示例:
损失值: 0.8765
交叉熵损失(CrossEntropyLoss)
交叉熵损失常用于分类问题,它结合了 LogSoftmax 和 NLLLoss。
## 创建一个交叉熵损失函数
cross_entropy_loss = nn.CrossEntropyLoss()
## 创建预测值和真实值张量
predictions = torch.randn(3, 5, requires_grad=True)
targets = torch.empty(3, dtype=torch.long).random_(5)
## 计算损失
loss = cross_entropy_loss(predictions, targets)
print("损失值:", loss.item())
运行结果示例:
损失值: 1.2345
三、代码实操环节
实操 1:构建简单的神经网络
## 定义一个简单的神经网络
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc1 = nn.Linear(5, 10) # 输入层到隐藏层
self.fc2 = nn.Linear(10, 3) # 隐藏层到输出层
self.relu = nn.ReLU() # ReLU 激活函数
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
## 创建网络实例
net = SimpleNet()
## 创建输入数据
input_data = torch.randn(2, 5)
## 前向传播
output = net(input_data)
print("输入数据形状:", input_data.shape)
print("输出数据形状:", output.shape)
运行结果示例:
输入数据形状: torch.Size([2, 5])
输出数据形状: torch.Size([2, 3])
实操 2:训练神经网络
## 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
## 训练数据
train_data = torch.randn(100, 5)
train_labels = torch.randn(100, 3)
## 训练循环
for epoch in range(100):
# 前向传播
outputs = net(train_data)
loss = criterion(outputs, train_labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (epoch+1) % 10 == 0:
print(f'Epoch [{epoch+1}/100], Loss: {loss.item():.4f}')
运行结果示例:
Epoch [10/100], Loss: 0.8765
Epoch [20/100], Loss: 0.7654
...
Epoch [100/100], Loss: 0.1234
四、总结与展望
本教程从零基础出发,带领大家认识了 PyTorch 环境,了解了 torch.nn
模块中的基本概念,包括神经网络层、激活函数和损失函数等内容,并通过代码实操环节加深了对所学知识的理解。后续可以继续探索更多 PyTorch 相关知识,如自动梯度、神经网络的构建与训练等。
希望这篇教程能够帮助大家更好地入门 PyTorch,如果在学习过程中有任何疑问,欢迎访问编程狮(W3Cschool)官网获取更多资源和支持。
更多建议: