PyTorch TorchVision 对象检测微调教程
在深度学习领域,对象检测是一项关键技术,它不仅可以识别图像中的物体类别,还能精确定位它们的位置。PyTorch 作为一款功能强大的开源机器学习框架,在对象检测任务中表现卓越。本教程将教你如何利用 PyTorch TorchVision 微调预训练模型进行对象检测。
一、定义数据集
在 PyTorch 中,定义数据集是进行模型训练的第一步。我们需要创建一个自定义数据集类,继承自 torch.utils.data.Dataset
。这个类要实现 __len__
和 __getitem__
方法。
import os
import numpy as np
import torch
from PIL import Image
class PennFudanDataset(torch.utils.data.Dataset):
def __init__(self, root, transforms):
self.root = root
self.transforms = transforms
self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages"))))
self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks"))))
def __getitem__(self, idx):
img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
img = Image.open(img_path).convert("RGB")
mask = Image.open(mask_path)
mask = np.array(mask)
obj_ids = np.unique(mask)
obj_ids = obj_ids[1:]
masks = mask == obj_ids[:, None, None]
num_objs = len(obj_ids)
boxes = []
for i in range(num_objs):
pos = np.where(masks[i])
xmin = np.min(pos[1])
xmax = np.max(pos[1])
ymin = np.min(pos[0])
ymax = np.max(pos[0])
boxes.append([xmin, ymin, xmax, ymax])
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.ones((num_objs,), dtype=torch.int64)
masks = torch.as_tensor(masks, dtype=torch.uint8)
image_id = torch.tensor([idx])
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["masks"] = masks
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target
def __len__(self):
return len(self.imgs)
上面的代码定义了一个宾夕法尼亚复旦数据集(PennFudan Dataset)的数据集类。__getitem__
方法根据索引返回图像和目标信息,目标包括边界框、标签、掩码等。
二、定义模型
接下来,我们需要定义用于对象检测的模型。这里我们使用 Mask R-CNN,它是一种在 Faster R-CNN 基础上扩展而来的实例分割模型,能够同时进行对象检测和分割。
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
def get_model_instance_segmentation(num_classes):
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
return model
上面的代码首先加载了在 COCO 数据集上预训练的 Mask R-CNN 模型,然后替换了模型的分类器和掩码预测器,使其适应我们的自定义数据集。
三、将所有内容放在一起
现在我们已经定义了数据集和模型,接下来需要将它们整合起来进行训练。
from engine import train_one_epoch, evaluate
import utils
import transforms as T
def get_transform(train):
transforms = []
transforms.append(T.ToTensor())
if train:
transforms.append(T.RandomHorizontalFlip(0.5))
return T.Compose(transforms)
def main():
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
num_classes = 2
dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))
dataset_test = PennFudanDataset('PennFudanPed', get_transform(train=False))
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-50])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=2, shuffle=True, num_workers=4,
collate_fn=utils.collate_fn)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=1, shuffle=False, num_workers=4,
collate_fn=utils.collate_fn)
model = get_model_instance_segmentation(num_classes)
model.to(device)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
num_epochs = 10
for epoch in range(num_epochs):
train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
lr_scheduler.step()
evaluate(model, data_loader_test, device=device)
print("That's it!")
if __name__ == "__main__":
main()
上面的代码首先定义了数据转换函数 get_transform
,用于将图像转换为张量,并在训练时进行随机水平翻转数据增强。main
函数中,我们设置了训练设备(GPU 或 CPU),创建了数据集和数据加载器,定义了模型、优化器和学习率调度器,然后进行了模型的训练和评估。
四、总结
恭喜你!通过以上步骤,你已经成功地在 PyTorch 中利用 TorchVision 微调了一个预训练模型进行对象检测。在编程狮(W3Cschool)上,你可以找到更多关于 PyTorch 的详细教程和实战案例,帮助你进一步提升深度学习技能,成为人工智能领域的编程大神。
更多建议: