PyTorch torch.utils.model_zoo

2025-07-02 17:36 更新

PyTorch 模型加载与迁移学习:torch.hub 实践指南

一、torch.hub 的简介与优势

torch.hub 是 PyTorch 提供的一个方便的工具,用于加载和使用预训练模型。它简化了模型的获取过程,使得用户可以轻松地从模型仓库中下载并加载预训练模型。这对于快速上手深度学习项目、进行模型迁移学习以及复现研究结果非常有帮助。

二、模型加载函数详解

(一)torch.hub.load

用于从 GitHub 仓库加载模型或模型组件。您可以直接通过模型仓库的路径和模型名称来获取模型。

(二)torch.hub.load_state_dict_from_url

从指定的 URL 下载并加载模型的状态字典。此函数非常有用,当您需要直接从网络加载模型权重时。

三、实际操作示例

(一)加载预训练模型

我们以加载预训练的 ResNet-18 模型为例,展示如何使用 torch.hub

import torch


## 加载 ResNet-18 模型
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)


## 将模型设置为评估模式
model.eval()


## 打印模型结构
print(model)

(二)加载模型状态字典

如果您需要从 URL 加载模型的状态字典,可以使用以下方法:

state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', map_location=torch.device('cpu'))


## 加载状态字典到模型
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18')
model.load_state_dict(state_dict)

四、应用场景与技巧

(一)迁移学习

torch.hub 在迁移学习中非常有用。您可以加载一个预训练模型,然后根据自己的任务需求进行微调。

## 加载预训练模型并进行微调
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)


## 替换模型的最后一层以适应新的分类任务
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 2)  # 假设新的分类任务有 2 个类别


## 将模型设置为训练模式
model.train()

(二)复现研究结果

在复现其他研究者的模型时,您可以直接使用 torch.hub 加载他们提供的预训练模型或模型组件。

## 假设有一个研究者提供的模型仓库和模型名称
## model = torch.hub.load('researcher/model_repo', 'model_name', source='github')

(三)使用不同版本的模型

torch.hub 支持加载特定版本的模型,这在模型更新后需要保持一致性时非常有用。

## 加载特定版本的模型
model = torch.hub.load('pytorch/vision:v0.5.0', 'resnet18', pretrained=True)

五、总结

通过本教程,我们详细介绍了如何使用 torch.hub 加载和使用预训练模型,以及在迁移学习和复现研究结果中的应用。torch.hub 提供了便捷的接口,使得模型的获取和使用变得更加简单。掌握这些技巧,可以帮助您更高效地进行深度学习开发和研究。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号