PyTorch 修剪教程
2025-06-23 16:03 更新
在深度学习模型开发过程中,模型修剪是一种有效的压缩技术,可以减少模型参数数量,降低内存占用和计算成本,同时保持较高的模型性能。本教程将详细讲解如何使用 PyTorch 进行模型修剪。
一、模型修剪概述
模型修剪通过移除神经网络中不重要的连接或神经元,来减小模型规模、提高推理速度和降低存储需求。常见的修剪方法包括结构化修剪和非结构化修剪。
二、建立模型
我们以 LeNet 模型为例,展示如何在 PyTorch 中实现模型修剪。
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = torch.nn.functional.max_pool2d(torch.nn.functional.relu(self.conv1(x)), (2, 2))
x = torch.nn.functional.max_pool2d(torch.nn.functional.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = torch.nn.functional.relu(self.fc1(x))
x = torch.nn.functional.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet()
三、修剪模块
使用 PyTorch 的 torch.nn.utils.prune
模块对模型进行修剪。
## 修剪 conv1 层的 weight 参数,随机修剪 30% 的连接
prune.random_unstructured(model.conv1, name="weight", amount=0.3)
修剪后,模型的参数和缓冲区会发生变化。
print(list(model.conv1.named_parameters()))
print(list(model.conv1.named_buffers()))
print(model.conv1.weight)
四、迭代修剪
可以对同一参数进行多次修剪,每次修剪的效果会累积。
## 按 L1 范数修剪 bias 参数,移除 3 个最小值
prune.l1_unstructured(model.conv1, name="bias", amount=3)
五、序列化修剪的模型
修剪后的模型可以像普通模型一样进行序列化和保存。
torch.save(model.state_dict(), "trimmed_model.pth")
六、删除修剪重新参数化
修剪完成后,可以删除重新参数化,使修剪永久化。
prune.remove(model.conv1, 'weight')
七、修剪模型中的多个参数
可以同时修剪模型中的多个参数。
new_model = LeNet()
for name, module in new_model.named_modules():
if isinstance(module, torch.nn.Conv2d):
prune.l1_unstructured(module, name='weight', amount=0.2)
elif isinstance(module, torch.nn.Linear):
prune.l1_unstructured(module, name='weight', amount=0.4)
八、全局修剪
全局修剪会在整个模型范围内进行修剪,而不是针对单个层。
model = LeNet()
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'),
)
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.2,
)
九、自定义修剪方法
可以通过继承 BasePruningMethod
类来实现自定义修剪方法。
class FooBarPruningMethod(prune.BasePruningMethod):
PRUNING_TYPE = 'unstructured'
def compute_mask(self, t, default_mask):
mask = default_mask.clone()
mask.view(-1)[::2] = 0
return mask
def foobar_unstructured(module, name):
FooBarPruningMethod.apply(module, name)
return module
model = LeNet()
foobar_unstructured(model.fc3, name='bias')
print(model.fc3.bias_mask)
通过本教程,大家可以在编程狮(W3Cschool)平台上轻松掌握 PyTorch 模型修剪的方法。模型修剪是优化 PyTorch 模型的重要技术,希望大家能够学以致用,在实际项目中灵活应用这些技术。在编程狮(W3Cschool)学习更多相关内容,提升你的深度学习开发技能。
以上内容是否对您有帮助:
更多建议: