PyTorch 编写自定义数据集、数据加载器和转换
2025-06-18 17:17 更新
PyTorch 提供了灵活的工具,帮助你高效地加载和预处理数据。本教程将指导你如何编写自定义数据集、数据加载器和数据转换,让你的数据准备工作更加高效、便捷。
一、数据准备与加载
1.1 数据集概述
我们以面部姿态数据集为例,该数据集包含标注了 68 个界标点的面部图像。数据集提供一个 CSV 文件,记录了图像文件名和对应的界标点坐标。
1.2 加载 CSV 文件并读取数据
import pandas as pd
# Read the annotation table; column 0 holds the image file name and the
# remaining columns hold the landmark coordinates of that image.
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

# Pick one row and split it into the file name and an (N, 2) float32 array
# of landmark coordinates (one x, y pair per row).
n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].values.astype('float32').reshape(-1, 2)

print(f'Image name: {img_name}')
print(f'Landmarks shape: {landmarks.shape}')
print(f'First 4 Landmarks: {landmarks[:4]}')
1.3 显示图像及其界标点
import matplotlib.pyplot as plt
from skimage import io
def show_landmarks(image, landmarks):
    """Draw an image and overlay its landmark points as small red dots.

    Args:
        image: Image array accepted by ``plt.imshow``.
        landmarks: 2-D array whose column 0 is used as the x coordinates
            and column 1 as the y coordinates of the points.
    """
    plt.imshow(image)
    xs, ys = landmarks[:, 0], landmarks[:, 1]
    plt.scatter(xs, ys, s=10, marker='.', c='r')
    # Pause very briefly (presumably so the figure gets a chance to refresh).
    plt.pause(0.001)
# Read the image file named by the selected CSV row and plot its landmarks
# on top of it.
image = io.imread(f'data/faces/{img_name}')
show_landmarks(image, landmarks)
# Show the resulting figure.
plt.show()
二、自定义数据集类
继承 torch.utils.data.Dataset 类,创建自定义数据集类。自定义类需要实现 __len__(返回样本数量)和 __getitem__(按索引返回单个样本)两个方法。
2.1 定义数据集类
import torch
import pandas as pd
from skimage import io
from torch.utils.data import Dataset
class FaceLandmarksDataset(Dataset):
    """Face landmarks dataset backed by a CSV annotation file.

    Each CSV row describes one sample: column 0 is the image file name and
    the remaining columns are the landmark coordinates, read as x, y pairs.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """Load the annotation table and remember where the images live.

        Args:
            csv_file (string): Path to the CSV file with the annotations.
            root_dir (string): Directory containing the image files.
            transform (callable, optional): Optional transform applied to
                every sample before it is returned.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.transform = transform
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of rows (samples) in the annotation table."""
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as ``{'image': ..., 'landmarks': ...}``.

        A tensor index is converted to plain Python values first. The
        landmarks are returned as an (N, 2) float32 array. When a transform
        was supplied, its result is returned instead of the raw sample.
        """
        index = idx.tolist() if torch.is_tensor(idx) else idx
        frame = self.landmarks_frame

        img_name = frame.iloc[index, 0]
        image = io.imread(f"{self.root_dir}/{img_name}")

        coords = frame.iloc[index, 1:].values
        landmarks = coords.reshape(-1, 2).astype('float32')

        sample = {'image': image, 'landmarks': landmarks}
        if not self.transform:
            return sample
        return self.transform(sample)
三、数据转换
3.1 定义数据转换类
import numpy as np
from skimage import transform
class Rescale(object):
    """Resize the image of a sample and scale its landmarks to match.

    Args:
        output_size (int or tuple): With a tuple, the output size is taken
            as (height, width). With an int, the shorter image edge is set
            to this value and the other edge is scaled by the same ratio.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def _target_size(self, h, w):
        """Return the integer (height, width) to resize an h x w image to."""
        size = self.output_size
        if not isinstance(size, int):
            new_h, new_w = size
        elif h > w:
            new_h, new_w = size * h / w, size
        else:
            new_h, new_w = size, size * w / h
        # Truncate to whole pixels.
        return int(new_h), int(new_w)

    def __call__(self, sample):
        """Return a new sample dict with the resized image and landmarks."""
        image = sample['image']
        landmarks = sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self._target_size(h, w)

        img = transform.resize(image, (new_h, new_w))
        # Landmark columns are (x, y), so x scales with the width ratio and
        # y with the height ratio.
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}
class RandomCrop(object):
    """Crop a randomly positioned window out of the image of a sample.

    Args:
        output_size (int or tuple): Desired crop size. An int produces a
            square crop; a tuple is taken as (height, width).
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with the cropped image and shifted landmarks.

        Raises:
            ValueError: If the crop is larger than the image in either
                dimension.
        """
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                f"crop size {(new_h, new_w)} is larger than image size {(h, w)}"
            )

        # randint's upper bound is exclusive, so add 1: this keeps the last
        # valid offset reachable and allows a crop exactly as large as the
        # image (offset 0 is then the only choice).
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top:top + new_h, left:left + new_w]
        # Landmark columns are (x, y): move them into the crop's coordinates.
        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}
class ToTensor(object):
    """Convert the NumPy arrays of a sample into torch tensors."""

    def __call__(self, sample):
        """Return a new sample dict holding tensors.

        The image axes are reordered from (H, W, C) to (C, H, W) before the
        conversion; the landmarks are converted as they are.
        """
        # Move the channel axis (last) to the front.
        chw_image = sample['image'].transpose(2, 0, 1)
        image_tensor = torch.from_numpy(chw_image)
        landmarks_tensor = torch.from_numpy(sample['landmarks'])
        return {'image': image_tensor, 'landmarks': landmarks_tensor}
3.2 应用数据转换
from torch.utils.data import DataLoader
# `transforms.Compose` comes from torchvision; it has to be imported for the
# name `transforms` used below to resolve.
from torchvision import transforms

# Build the dataset; every sample is rescaled, randomly cropped and finally
# converted to tensors, in that order.
face_dataset = FaceLandmarksDataset(
    csv_file='data/faces/face_landmarks.csv',
    root_dir='data/faces/',
    transform=transforms.Compose([
        Rescale(256),
        RandomCrop(224),
        ToTensor(),
    ]),
)

# Build the data loader: shuffled batches of 4 samples, num_workers=4.
# NOTE(review): on platforms that start worker processes with "spawn", code
# using num_workers > 0 usually has to run under an
# `if __name__ == "__main__":` guard — confirm for your environment.
dataloader = DataLoader(face_dataset, batch_size=4, shuffle=True, num_workers=4)

# Iterate over the loader and print the tensor sizes of the first 4 batches.
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size())
    if i_batch == 3:
        break
四、总结
通过本教程,你已经掌握了 PyTorch 自定义数据集、数据加载器和数据转换的实现方法。这些工具可以帮助你高效地处理各种数据集,为深度学习模型的训练提供有力支持。希望你在编程狮的学习平台上能够灵活运用这些知识,提升你的深度学习项目开发能力!
以上内容是否对您有帮助:
更多建议: