PyTorch 修剪教程

2025-06-23 16:03 更新

在深度学习模型开发过程中,模型修剪是一种有效的压缩技术,可以减少模型参数数量,降低内存占用和计算成本,同时保持较高的模型性能。本教程将详细讲解如何使用 PyTorch 进行模型修剪。

一、模型修剪概述

模型修剪通过移除神经网络中不重要的连接或神经元,来减小模型规模、提高推理速度和降低存储需求。常见的修剪方法包括结构化修剪和非结构化修剪。

二、建立模型

我们以 LeNet 模型为例,展示如何在 PyTorch 中实现模型修剪。

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)


    def forward(self, x):
        x = torch.nn.functional.max_pool2d(torch.nn.functional.relu(self.conv1(x)), (2, 2))
        x = torch.nn.functional.max_pool2d(torch.nn.functional.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = LeNet()

三、修剪模块

使用 PyTorch 的 torch.nn.utils.prune 模块对模型进行修剪。

## 修剪 conv1 层的 weight 参数,随机修剪 30% 的连接
prune.random_unstructured(model.conv1, name="weight", amount=0.3)

修剪后,模型的参数和缓冲区会发生变化。

print(list(model.conv1.named_parameters()))
print(list(model.conv1.named_buffers()))
print(model.conv1.weight)

四、迭代修剪

可以对同一参数进行多次修剪,每次修剪的效果会累积。

## 按 L1 范数修剪 bias 参数,移除 3 个最小值
prune.l1_unstructured(model.conv1, name="bias", amount=3)

五、序列化修剪的模型

修剪后的模型可以像普通模型一样进行序列化和保存。

torch.save(model.state_dict(), "trimmed_model.pth")

六、删除修剪重新参数化

修剪完成后,可以删除重新参数化,使修剪永久化。

prune.remove(model.conv1, 'weight')

七、修剪模型中的多个参数

可以同时修剪模型中的多个参数。

new_model = LeNet()
for name, module in new_model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

八、全局修剪

全局修剪会在整个模型范围内进行修剪,而不是针对单个层。

model = LeNet()
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

九、自定义修剪方法

可以通过继承 BasePruningMethod 类来实现自定义修剪方法。

class FooBarPruningMethod(prune.BasePruningMethod):
    PRUNING_TYPE = 'unstructured'
    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask


def foobar_unstructured(module, name):
    FooBarPruningMethod.apply(module, name)
    return module


model = LeNet()
foobar_unstructured(model.fc3, name='bias')
print(model.fc3.bias_mask)

通过本教程,大家可以在编程狮(W3Cschool)平台上轻松掌握 PyTorch 模型修剪的方法。模型修剪是优化 PyTorch 模型的重要技术,希望大家能够学以致用,在实际项目中灵活应用这些技术。在编程狮(W3Cschool)学习更多相关内容,提升你的深度学习开发技能。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号