PyTorch torch.nn.init

2025-06-25 15:11 更新

一、torch.nn.init 是什么?

torch.nn.initPyTorch 中用于初始化神经网络权重的模块。它提供了多种权重初始化方法,帮助我们在训练神经网络时获得更好的初始权重,从而加速网络收敛并提升模型性能。

二、为什么需要权重初始化?

在训练神经网络时,合适的权重初始化至关重要。如果初始权重设置不当,可能导致网络训练缓慢甚至无法收敛。例如,权重过大或过小可能引发梯度消失或梯度爆炸问题。而合理的初始化方法能确保网络在训练初期就处于良好的状态,有助于信息在神经网络中的有效传播,促进模型更快地学习和收敛。

三、torch.nn.init 常用函数详解

(一)torch.nn.init.calculate_gain

  1. 函数定义
    • torch.nn.init.calculate_gain(nonlinearity, param=None):返回给定非线性函数的推荐增益值,用于一些初始化方法(如 Xavier 初始化)中,帮助调整初始化权重的尺度。

  1. 参数说明
    • nonlinearity:非线性函数名称(如 'relu''tanh' 等)。
    • param:非线性函数的可选参数(如 Leaky ReLU 的负斜率)。

  1. 示例
    • 计算 Leaky ReLU 的增益值,负斜率为 0.2:

gain = torch.nn.init.calculate_gain('leaky_relu', 0.2)

(二)均匀分布初始化函数

  1. torch.nn.init.uniform_
    • 功能 :用从均匀分布 ([a, b]) 中得出的值填充输入张量。
    • 参数
      • tensor:要填充的张量。
      • a:均匀分布的下限(默认为 0.0)。
      • b:均匀分布的上限(默认为 1.0)。

  • 示例 :将一个 (3 \times 5) 张量用均匀分布值填充:

w = torch.empty(3, 5)
torch.nn.init.uniform_(w)

  1. torch.nn.init.xavier_uniform_
    • 功能 :根据 Xavier 初始化方法(适用于前馈神经网络),用均匀分布的值填充输入张量。计算公式为 (U\left[-\sqrt{\frac{6.0}{fan_in + fan_out}}, \sqrt{\frac{6.0}{fan_in + fan_out}}\right]),其中 (fan_in) 是输入单元数,(fan_out) 是输出单元数。
    • 参数
      • tensor:输入张量。
      • gain:可选的比例因子(默认为 1.0)。

  • 示例 :用 Xavier 均匀初始化方法填充张量,结合 ReLU 激活函数的增益值:

w = torch.empty(3, 5)
torch.nn.init.xavier_uniform_(w, gain=torch.nn.init.calculate_gain('relu'))

  1. torch.nn.init.kaiming_uniform_
    • 功能 :根据 He 初始化方法(适用于使用 ReLU 激活函数的卷积神经网络),用均匀分布的值填充输入张量。计算公式为 (U\left[-\sqrt{\frac{6.0}{fan_in}}, \sqrt{\frac{6.0}{fan_in}}\right]),主要考虑激活函数的非线性特性。
    • 参数
      • tensor:输入张量。
      • a:整流器的负斜率(默认为 0)。
      • mode'fan_in'(默认,保留前向传播中权重的方差)或 'fan_out'(保留反向传播中的方差)。
      • nonlinearity:非线性函数名称(推荐与 'relu''leaky_relu' 一起使用)。

  • 示例 :用 He 均匀初始化方法填充张量,适用于 ReLU 激活函数:

w = torch.empty(3, 5)
torch.nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

(三)正态分布初始化函数

  1. torch.nn.init.normal_
    • 功能 :使用从正态分布 (\mathcal{N}(\text{mean}, \text{std})) 中得出的值填充输入张量。
    • 参数
      • tensor:要填充的张量。
      • mean:正态分布的均值(默认为 0.0)。
      • std:正态分布的标准差(默认为 1.0)。

  • 示例 :将张量用正态分布值填充:

w = torch.empty(3, 5)
torch.nn.init.normal_(w)

  1. torch.nn.init.xavier_normal_
    • 功能 :根据 Xavier 初始化方法,使用正态分布的值填充输入张量。计算公式为 (\text{std} = \text{gain} \times \sqrt{\frac{2.0}{fan_in + fan_out}})。
    • 参数
      • tensor:输入张量。
      • gain:可选的比例因子(默认为 1.0)。

  • 示例 :用 Xavier 正态初始化方法填充张量:

w = torch.empty(3, 5)
torch.nn.init.xavier_normal_(w)

  1. torch.nn.init.kaiming_normal_
    • 功能 :根据 He 初始化方法,使用正态分布的值填充输入张量。计算公式为 (\text{std} = \sqrt{\frac{2.0}{(1 + a^2) \times fan_in}}),其中 (a) 是负斜率。
    • 参数
      • tensor:输入张量。
      • a:整流器的负斜率(默认为 0)。
      • mode'fan_in'(默认)或 'fan_out'
      • nonlinearity:非线性函数名称(推荐与 'relu''leaky_relu' 一起使用)。

  • 示例 :用 He 正态初始化方法填充张量,适用于 ReLU 激活函数:

w = torch.empty(3, 5)
torch.nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

(四)其他初始化函数

  1. torch.nn.init.constant_
    • 功能 :用指定的值填充输入张量。
    • 参数
      • tensor:输入张量。
      • val:用于填充的值。

  • 示例 :将张量所有元素初始化为 0.3:

w = torch.empty(3, 5)
torch.nn.init.constant_(w, 0.3)

  1. torch.nn.init.ones_torch.nn.init.zeros_
    • 功能 :分别用 1 和 0 填充输入张量。
    • 参数tensor(要填充的张量)。
    • 示例
      • 初始化为 1:

w = torch.empty(3, 5)
torch.nn.init.ones_(w)

  • 初始化为 0:

w = torch.empty(3, 5)
torch.nn.init.zeros_(w)

  1. torch.nn.init.eye_
    • 功能 :用单位矩阵填充二维输入张量,在线性层中保留输入的身份信息。
    • 参数tensor(二维张量)。
    • 示例 :将 (3 \times 5) 张量初始化为单位矩阵(注意只有方阵才能成为真正的单位矩阵,非方阵会填充部分 1 和 0):

w = torch.empty(3, 5)
torch.nn.init.eye_(w)

  1. torch.nn.init.dirac_
    • 功能 :用 Dirac delta 函数填充 3D、4D 或 5D 张量,保留卷积层中输入的身份信息。
    • 参数tensor(3D、4D 或 5D 张量)。
    • 示例 :初始化一个 (3 \times 16 \times 5 \times 5) 张量:

w = torch.empty(3, 16, 5, 5)
torch.nn.init.dirac_(w)

  1. torch.nn.init.orthogonal_
    • 功能 :用(半)正交矩阵填充输入张量,适用于深度线性网络初始化。
    • 参数
      • tensor:输入张量(至少 2 维)。
      • gain:可选比例因子(默认为 1)。

  • 示例 :将张量初始化为正交矩阵:

w = torch.empty(3, 5)
torch.nn.init.orthogonal_(w)

  1. torch.nn.init.sparse_
    • 功能 :将 2D 张量填充为稀疏矩阵,适用于需要稀疏连接的神经网络层。
    • 参数
      • tensor:输入张量(2D)。
      • sparsity:每列中要设置为零的元素比例。
      • std:生成非零值的正态分布的标准差(默认为 0.01)。

  • 示例 :将张量初始化为稀疏矩阵,稀疏性为 0.1:

w = torch.empty(3, 5)
torch.nn.init.sparse_(w, sparsity=0.1)

四、如何选择合适的初始化方法?

  1. 前馈神经网络 :通常使用 Xavier 初始化(xavier_uniform_xavier_normal_),在结合 Sigmoid 或 Tanh 激活函数时表现良好。
  2. 卷积神经网络 :使用 He 初始化(kaiming_uniform_kaiming_normal_)更适合,尤其当使用 ReLU 激活函数时,能够有效缓解梯度消失问题。
  3. 其他情况 :如果对激活函数不确定,可以先尝试默认的均匀分布或正态分布初始化。对于需要稀疏连接的层,使用稀疏初始化(sparse_)。

五、示例代码:综合应用

以下是一个综合示例,展示如何在 PyTorch 模型中使用多种初始化方法:

import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 线性层,使用 Xavier 初始化
        self.fc1 = nn.Linear(10, 20)
        torch.nn.init.xavier_uniform_(self.fc1.weight, gain=1.0)
        torch.nn.init.constant_(self.fc1.bias, 0.0)


        # 卷积层,使用 He 初始化
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3)
        torch.nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
        torch.nn.init.constant_(self.conv1.bias, 0.0)


        # 使用稀疏初始化的线性层
        self.fc2 = nn.Linear(20, 10)
        torch.nn.init.sparse_(self.fc2.weight, sparsity=0.5)
        torch.nn.init.constant_(self.fc2.bias, 0.0)


    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


## 创建模型实例
model = MyModel()


## 打印模型参数(部分)
print("FC1 Weight:", model.fc1.weight[:2, :2])  # 打印前 2x2 权重
print("FC1 Bias:", model.fc1.bias[:2])  # 打印前 2 个偏置
print("Conv1 Weight Shape:", model.conv1.weight.shape)
print("FC2 Weight Shape:", model.fc2.weight.shape)

在这个例子中,我们为不同的层选择了合适的初始化方法,以适应各自的激活函数和网络结构。通过这种方式,我们可以更好地初始化模型,为后续的训练打下良好的基础。

六、总结

通过本教程,我们详细介绍了 PyTorch 中的 torch.nn.init 模块及其常用的权重初始化方法。从简单的均匀分布和正态分布初始化,到针对特定激活函数的 Xavier 和 He 初始化,以及稀疏和正交初始化等高级方法,这些工具为我们提供了多样化的权重初始化选择。合理选择和应用这些初始化方法,能够有效提升神经网络的训练效果和性能。

在实际项目中,我们可以根据网络类型和激活函数等因素,灵活选择合适的初始化方法。同时,也可以通过实验来比较不同初始化方法对模型性能的影响,找到最适合当前任务的方案。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号