PyTorch torch.hub

2025-06-25 14:59 更新

一、PyTorch Hub 简介

PyTorch Hub 是一个预训练模型库,旨在促进研究的可重复性。通过 PyTorch Hub,用户可以轻松地加载和使用由社区贡献的预训练模型,极大地方便了模型的共享和复用。这些模型涵盖了多种常见的深度学习任务,如图像分类、目标检测、语义分割等。

二、发布模型

在 PyTorch Hub 中发布模型需要以下步骤:

(一)创建 hubconf.py 文件

在 GitHub 仓库中添加一个 hubconf.py 文件,该文件定义了模型的入口点。每个入口点都是一个 Python 函数,用于加载和返回预训练模型。

dependencies = ['torch']  # 模型加载所需的依赖包


def resnet18(pretrained=False, **kwargs):
    """
    Resnet18 模型
    pretrained (bool): 是否加载预训练权重
    """
    from torchvision.models import resnet18 as _resnet18
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model

(二)注意事项

  • 模型应至少发布在分支或标签中,不能是随机提交。
  • 预训练权重可以存储在 GitHub 仓库中,也可以通过 torch.hub.load_state_dict_from_url() 加载。

三、加载模型

PyTorch Hub 提供了便捷的 API 来加载预训练模型。

(一)列出可用模型

使用 torch.hub.list() 浏览 GitHub 仓库中可用的所有模型。

entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
print(entrypoints)

(二)查看模型帮助文档

使用 torch.hub.help() 查看模型的文档字符串和示例。

help_doc = torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)
print(help_doc)

(三)加载预训练模型

使用 torch.hub.load() 加载预训练模型。

model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
print(model)

四、缓存机制与模型保存路径

(一)缓存机制

PyTorch Hub 默认使用缓存来存储下载的模型文件。如果需要强制重新加载模型,可以设置 force_reload=True

(二)模型保存路径

模型文件默认保存在以下路径之一:

  • 调用 torch.hub.set_dir(<PATH_TO_HUB_DIR>) 指定的目录。
  • $TORCH_HOME/hub,如果设置了环境变量 TORCH_HOME
  • $XDG_CACHE_HOME/torch/hub,如果设置了环境变量 XDG_CACHE_HOME
  • ~/.cache/torch/hub

五、实际案例:加载预训练的 ResNet 模型并进行推理

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image


## 加载预训练的 ResNet50 模型
model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)


## 将模型设置为评估模式
model.eval()


## 加载并预处理图像
image_path = 'image.jpg'  # 图像文件路径
input_image = Image.open(image_path)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # 创建批次维度


## 进行推理
with torch.no_grad():
    output = model(input_batch)


## 输出结果
print(output)

六、总结

PyTorch Hub 提供了一个便捷的平台,用于共享和使用预训练模型。通过简单的几行代码,用户可以轻松地加载和使用由社区贡献的模型,大大加速了深度学习项目的开发过程。希望本教程能帮助您快速上手 PyTorch Hub,并在您的项目中充分利用预训练模型的优势。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号