PyTorch torch.hub
一、PyTorch Hub 简介
PyTorch Hub 是一个预训练模型库,旨在促进研究的可重复性。通过 PyTorch Hub,用户可以轻松地加载和使用由社区贡献的预训练模型,极大地方便了模型的共享和复用。这些模型涵盖了多种常见的深度学习任务,如图像分类、目标检测、语义分割等。
二、发布模型
在 PyTorch Hub 中发布模型需要以下步骤:
(一)创建 hubconf.py
文件
在 GitHub 仓库中添加一个 hubconf.py
文件,该文件定义了模型的入口点。每个入口点都是一个 Python 函数,用于加载和返回预训练模型。
dependencies = ['torch'] # 模型加载所需的依赖包
def resnet18(pretrained=False, **kwargs):
"""
Resnet18 模型
pretrained (bool): 是否加载预训练权重
"""
from torchvision.models import resnet18 as _resnet18
model = _resnet18(pretrained=pretrained, **kwargs)
return model
(二)注意事项
- 模型应至少发布在分支或标签中,不能是随机提交。
- 预训练权重可以存储在 GitHub 仓库中,也可以通过
torch.hub.load_state_dict_from_url()
加载。
三、加载模型
PyTorch Hub 提供了便捷的 API 来加载预训练模型。
(一)列出可用模型
使用 torch.hub.list()
浏览 GitHub 仓库中可用的所有模型。
entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
print(entrypoints)
(二)查看模型帮助文档
使用 torch.hub.help()
查看模型的文档字符串和示例。
help_doc = torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)
print(help_doc)
(三)加载预训练模型
使用 torch.hub.load()
加载预训练模型。
model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
print(model)
四、缓存机制与模型保存路径
(一)缓存机制
PyTorch Hub 默认使用缓存来存储下载的模型文件。如果需要强制重新加载模型,可以设置 force_reload=True
。
(二)模型保存路径
模型文件默认保存在以下路径之一:
- 调用
torch.hub.set_dir(<PATH_TO_HUB_DIR>)
指定的目录。 $TORCH_HOME/hub
,如果设置了环境变量TORCH_HOME
。$XDG_CACHE_HOME/torch/hub
,如果设置了环境变量XDG_CACHE_HOME
。~/.cache/torch/hub
。
五、实际案例:加载预训练的 ResNet 模型并进行推理
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
## 加载预训练的 ResNet50 模型
model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
## 将模型设置为评估模式
model.eval()
## 加载并预处理图像
image_path = 'image.jpg' # 图像文件路径
input_image = Image.open(image_path)
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # 创建批次维度
## 进行推理
with torch.no_grad():
output = model(input_batch)
## 输出结果
print(output)
六、总结
PyTorch Hub 提供了一个便捷的平台,用于共享和使用预训练模型。通过简单的几行代码,用户可以轻松地加载和使用由社区贡献的模型,大大加速了深度学习项目的开发过程。希望本教程能帮助您快速上手 PyTorch Hub,并在您的项目中充分利用预训练模型的优势。
更多建议: