PyTorch torch.nn.init
一、torch.nn.init 是什么?
torch.nn.init
是 PyTorch 中用于初始化神经网络权重的模块。它提供了多种权重初始化方法,帮助我们在训练神经网络时获得更好的初始权重,从而加速网络收敛并提升模型性能。
二、为什么需要权重初始化?
在训练神经网络时,合适的权重初始化至关重要。如果初始权重设置不当,可能导致网络训练缓慢甚至无法收敛。例如,权重过大或过小可能引发梯度消失或梯度爆炸问题。而合理的初始化方法能确保网络在训练初期就处于良好的状态,有助于信息在神经网络中的有效传播,促进模型更快地学习和收敛。
三、torch.nn.init 常用函数详解
(一)torch.nn.init.calculate_gain
- 函数定义
torch.nn.init.calculate_gain(nonlinearity, param=None)
:返回给定非线性函数的推荐增益值,用于一些初始化方法(如 Xavier 初始化)中,帮助调整初始化权重的尺度。
- 参数说明
nonlinearity
:非线性函数名称(如'relu'
、'tanh'
等)。param
:非线性函数的可选参数(如 Leaky ReLU 的负斜率)。
- 示例
- 计算 Leaky ReLU 的增益值,负斜率为 0.2:
gain = torch.nn.init.calculate_gain('leaky_relu', 0.2)
(二)均匀分布初始化函数
torch.nn.init.uniform_
- 功能 :用从均匀分布 ([a, b]) 中得出的值填充输入张量。
- 参数 :
tensor
:要填充的张量。a
:均匀分布的下限(默认为 0.0)。b
:均匀分布的上限(默认为 1.0)。
- 示例 :将一个 (3 \times 5) 张量用均匀分布值填充:
w = torch.empty(3, 5)
torch.nn.init.uniform_(w)
torch.nn.init.xavier_uniform_
- 功能 :根据 Xavier 初始化方法(适用于前馈神经网络),用均匀分布的值填充输入张量。计算公式为 (U\left[-\sqrt{\frac{6.0}{fan_in + fan_out}}, \sqrt{\frac{6.0}{fan_in + fan_out}}\right]),其中 (fan_in) 是输入单元数,(fan_out) 是输出单元数。
- 参数 :
tensor
:输入张量。gain
:可选的比例因子(默认为 1.0)。
- 示例 :用 Xavier 均匀初始化方法填充张量,结合 ReLU 激活函数的增益值:
w = torch.empty(3, 5)
torch.nn.init.xavier_uniform_(w, gain=torch.nn.init.calculate_gain('relu'))
torch.nn.init.kaiming_uniform_
- 功能 :根据 He 初始化方法(适用于使用 ReLU 激活函数的卷积神经网络),用均匀分布的值填充输入张量。计算公式为 (U\left[-\sqrt{\frac{6.0}{fan_in}}, \sqrt{\frac{6.0}{fan_in}}\right]),主要考虑激活函数的非线性特性。
- 参数 :
tensor
:输入张量。a
:整流器的负斜率(默认为 0)。mode
:'fan_in'
(默认,保留前向传播中权重的方差)或'fan_out'
(保留反向传播中的方差)。nonlinearity
:非线性函数名称(推荐与'relu'
或'leaky_relu'
一起使用)。
- 示例 :用 He 均匀初始化方法填充张量,适用于 ReLU 激活函数:
w = torch.empty(3, 5)
torch.nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
(三)正态分布初始化函数
torch.nn.init.normal_
- 功能 :使用从正态分布 (\mathcal{N}(\text{mean}, \text{std})) 中得出的值填充输入张量。
- 参数 :
tensor
:要填充的张量。mean
:正态分布的均值(默认为 0.0)。std
:正态分布的标准差(默认为 1.0)。
- 示例 :将张量用正态分布值填充:
w = torch.empty(3, 5)
torch.nn.init.normal_(w)
torch.nn.init.xavier_normal_
- 功能 :根据 Xavier 初始化方法,使用正态分布的值填充输入张量。计算公式为 (\text{std} = \text{gain} \times \sqrt{\frac{2.0}{fan_in + fan_out}})。
- 参数 :
tensor
:输入张量。gain
:可选的比例因子(默认为 1.0)。
- 示例 :用 Xavier 正态初始化方法填充张量:
w = torch.empty(3, 5)
torch.nn.init.xavier_normal_(w)
torch.nn.init.kaiming_normal_
- 功能 :根据 He 初始化方法,使用正态分布的值填充输入张量。计算公式为 (\text{std} = \sqrt{\frac{2.0}{(1 + a^2) \times fan_in}}),其中 (a) 是负斜率。
- 参数 :
tensor
:输入张量。a
:整流器的负斜率(默认为 0)。mode
:'fan_in'
(默认)或'fan_out'
。nonlinearity
:非线性函数名称(推荐与'relu'
或'leaky_relu'
一起使用)。
- 示例 :用 He 正态初始化方法填充张量,适用于 ReLU 激活函数:
w = torch.empty(3, 5)
torch.nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
(四)其他初始化函数
torch.nn.init.constant_
- 功能 :用指定的值填充输入张量。
- 参数 :
tensor
:输入张量。val
:用于填充的值。
- 示例 :将张量所有元素初始化为 0.3:
w = torch.empty(3, 5)
torch.nn.init.constant_(w, 0.3)
torch.nn.init.ones_
和torch.nn.init.zeros_
- 功能 :分别用 1 和 0 填充输入张量。
- 参数 :
tensor
(要填充的张量)。 - 示例 :
- 初始化为 1:
w = torch.empty(3, 5)
torch.nn.init.ones_(w)
- 初始化为 0:
w = torch.empty(3, 5)
torch.nn.init.zeros_(w)
torch.nn.init.eye_
- 功能 :用单位矩阵填充二维输入张量,在线性层中保留输入的身份信息。
- 参数 :
tensor
(二维张量)。 - 示例 :将 (3 \times 5) 张量初始化为单位矩阵(注意只有方阵才能成为真正的单位矩阵,非方阵会填充部分 1 和 0):
w = torch.empty(3, 5)
torch.nn.init.eye_(w)
torch.nn.init.dirac_
- 功能 :用 Dirac delta 函数填充 3D、4D 或 5D 张量,保留卷积层中输入的身份信息。
- 参数 :
tensor
(3D、4D 或 5D 张量)。 - 示例 :初始化一个 (3 \times 16 \times 5 \times 5) 张量:
w = torch.empty(3, 16, 5, 5)
torch.nn.init.dirac_(w)
torch.nn.init.orthogonal_
- 功能 :用(半)正交矩阵填充输入张量,适用于深度线性网络初始化。
- 参数 :
tensor
:输入张量(至少 2 维)。gain
:可选比例因子(默认为 1)。
- 示例 :将张量初始化为正交矩阵:
w = torch.empty(3, 5)
torch.nn.init.orthogonal_(w)
torch.nn.init.sparse_
- 功能 :将 2D 张量填充为稀疏矩阵,适用于需要稀疏连接的神经网络层。
- 参数 :
tensor
:输入张量(2D)。sparsity
:每列中要设置为零的元素比例。std
:生成非零值的正态分布的标准差(默认为 0.01)。
- 示例 :将张量初始化为稀疏矩阵,稀疏性为 0.1:
w = torch.empty(3, 5)
torch.nn.init.sparse_(w, sparsity=0.1)
四、如何选择合适的初始化方法?
- 前馈神经网络 :通常使用 Xavier 初始化(
xavier_uniform_
或xavier_normal_
),在结合 Sigmoid 或 Tanh 激活函数时表现良好。 - 卷积神经网络 :使用 He 初始化(
kaiming_uniform_
或kaiming_normal_
)更适合,尤其当使用 ReLU 激活函数时,能够有效缓解梯度消失问题。 - 其他情况 :如果对激活函数不确定,可以先尝试默认的均匀分布或正态分布初始化。对于需要稀疏连接的层,使用稀疏初始化(
sparse_
)。
五、示例代码:综合应用
以下是一个综合示例,展示如何在 PyTorch 模型中使用多种初始化方法:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# 线性层,使用 Xavier 初始化
self.fc1 = nn.Linear(10, 20)
torch.nn.init.xavier_uniform_(self.fc1.weight, gain=1.0)
torch.nn.init.constant_(self.fc1.bias, 0.0)
# 卷积层,使用 He 初始化
self.conv1 = nn.Conv2d(3, 16, kernel_size=3)
torch.nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
torch.nn.init.constant_(self.conv1.bias, 0.0)
# 使用稀疏初始化的线性层
self.fc2 = nn.Linear(20, 10)
torch.nn.init.sparse_(self.fc2.weight, sparsity=0.5)
torch.nn.init.constant_(self.fc2.bias, 0.0)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
## 创建模型实例
model = MyModel()
## 打印模型参数(部分)
print("FC1 Weight:", model.fc1.weight[:2, :2]) # 打印前 2x2 权重
print("FC1 Bias:", model.fc1.bias[:2]) # 打印前 2 个偏置
print("Conv1 Weight Shape:", model.conv1.weight.shape)
print("FC2 Weight Shape:", model.fc2.weight.shape)
在这个例子中,我们为不同的层选择了合适的初始化方法,以适应各自的激活函数和网络结构。通过这种方式,我们可以更好地初始化模型,为后续的训练打下良好的基础。
六、总结
通过本教程,我们详细介绍了 PyTorch 中的 torch.nn.init
模块及其常用的权重初始化方法。从简单的均匀分布和正态分布初始化,到针对特定激活函数的 Xavier 和 He 初始化,以及稀疏和正交初始化等高级方法,这些工具为我们提供了多样化的权重初始化选择。合理选择和应用这些初始化方法,能够有效提升神经网络的训练效果和性能。
在实际项目中,我们可以根据网络类型和激活函数等因素,灵活选择合适的初始化方法。同时,也可以通过实验来比较不同初始化方法对模型性能的影响,找到最适合当前任务的方案。
更多建议: