Writing Custom Datasets, DataLoaders and Transforms in PyTorch

Updated 2025-06-18 17:17

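PyTorch provides flexible tools for loading and preprocessing data efficiently. This tutorial shows how to write a custom dataset, define data transforms, and batch the results with a data loader, using a face landmark dataset as the running example.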

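1. Data Preparation and Loading

1.1 Dataset Overview

We use a face pose dataset as the example: each face image is annotated with 68 landmark points. The dataset comes with a CSV file in which each row holds an image file name followed by the x and y coordinates of that image's landmarks.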
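To make that layout concrete, the sketch below parses a tiny, made-up CSV held in memory. The column names (image_name, part_0_x, part_0_y, ...) and the values are assumptions for illustration, and only 2 landmarks are shown instead of 68. The point is that a row is one file name followed by alternating x, y values, so reshape(-1, 2) turns it into one (x, y) pair per row.

```python
import io

import numpy as np
import pandas as pd

# Hypothetical miniature CSV: one image, 2 landmarks (the real file has 68 per row)
csv_text = """image_name,part_0_x,part_0_y,part_1_x,part_1_y
face_a.jpg,32,65,33,76
"""

frame = pd.read_csv(io.StringIO(csv_text))

# Column 0 is the file name; the remaining columns alternate x, y
name = frame.iloc[0, 0]
points = frame.iloc[0, 1:].to_numpy().reshape(-1, 2).astype("float32")

print(name)          # face_a.jpg
print(points.shape)  # (2, 2)
print(points)        # each row is one (x, y) landmark
```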

1.2 Loading the CSV File and Reading the Data

import pandas as pd

# Load the CSV file: one row per image
landmarks_frame = pd.read_csv('data/faces/face_landmarks.csv')

# Take row n: column 0 is the image file name, the remaining
# columns are x0, y0, x1, y1, ..., reshaped into a (68, 2) array
n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].to_numpy().reshape(-1, 2).astype('float32')

print(f'Image name: {img_name}')
print(f'Landmarks shape: {landmarks.shape}')
print(f'First 4 Landmarks: {landmarks[:4]}')

1.3 Displaying an Image with Its Landmarks

import matplotlib.pyplot as plt
from skimage import io


def show_landmarks(image, landmarks):
    """Show an image with its landmarks overlaid as red dots."""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # brief pause so the plot window refreshes


image = io.imread(f'data/faces/{img_name}')
show_landmarks(image, landmarks)
plt.show()

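2. Custom Dataset Class

torch.utils.data.Dataset is an abstract class representing a dataset. A custom dataset inherits from it and overrides two methods: __len__, so that len(dataset) returns the number of samples, and __getitem__, so that dataset[i] returns the i-th sample.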
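A minimal sketch of that protocol, using random in-memory arrays instead of image files so it runs without the face dataset. The class name InMemoryLandmarks and the array shapes are hypothetical; the structure (a sample dict plus an optional transform) mirrors the class defined next.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class InMemoryLandmarks(Dataset):
    """Hypothetical dataset that keeps images and landmarks in memory."""

    def __init__(self, images, landmarks, transform=None):
        assert len(images) == len(landmarks)
        self.images = images        # shape (N, H, W, C)
        self.landmarks = landmarks  # shape (N, 68, 2)
        self.transform = transform

    def __len__(self):
        # Samplers and DataLoader use this to know the valid index range
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = {"image": self.images[idx], "landmarks": self.landmarks[idx]}
        if self.transform:
            sample = self.transform(sample)
        return sample


rng = np.random.default_rng(0)
ds = InMemoryLandmarks(
    images=rng.integers(0, 256, size=(5, 32, 32, 3)).astype(np.uint8),
    landmarks=rng.uniform(0, 32, size=(5, 68, 2)).astype("float32"),
)
print(len(ds))                   # 5
print(ds[2]["image"].shape)      # (32, 32, 3)
print(ds[2]["landmarks"].shape)  # (68, 2)
```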

2.1 Defining the Dataset Class

import os

import pandas as pd
import torch
from skimage import io
from torch.utils.data import Dataset


class FaceLandmarksDataset(Dataset):
    """Face landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the CSV file with the landmark annotations.
            root_dir (string): Directory containing all the images.
            transform (callable, optional): Optional transform applied to each sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Images are read lazily, one per call, instead of all at once in __init__
        img_name = self.landmarks_frame.iloc[idx, 0]
        img_path = os.path.join(self.root_dir, img_name)
        image = io.imread(img_path)
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy().reshape(-1, 2).astype('float32')

        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

3. Data Transforms

3.1 Defining Transform Classes

import numpy as np
import torch
from skimage import transform


class Rescale:
    """Rescale the image in a sample to a given size.

    Args:
        output_size (int or tuple): Desired output size. If a tuple (h, w),
            the image is resized to exactly that size. If an int, the shorter
            edge is matched to it and the aspect ratio is kept.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        h, w = image.shape[:2]

        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)
        img = transform.resize(image, (new_h, new_w))
        # Landmarks are (x, y): x scales with width, y with height
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}
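The size arithmetic in Rescale is easy to get backwards, so here it is isolated in a hypothetical helper, rescaled_size, that copies the branch logic above without needing skimage. With an int, the shorter edge becomes output_size; the landmarks are then scaled by the same width and height factors.

```python
import numpy as np


def rescaled_size(h, w, output_size):
    """Hypothetical helper mirroring Rescale's size logic."""
    if isinstance(output_size, int):
        # The shorter edge becomes output_size; the longer edge keeps the ratio
        if h > w:
            new_h, new_w = output_size * h / w, output_size
        else:
            new_h, new_w = output_size, output_size * w / h
    else:
        new_h, new_w = output_size
    return int(new_h), int(new_w)


h, w = 300, 400
new_h, new_w = rescaled_size(h, w, 150)
print(new_h, new_w)  # 150 200

# Landmarks are (x, y): x scales with width, y with height
landmarks = np.array([[100.0, 60.0]])
print(landmarks * [new_w / w, new_h / h])  # [[50. 30.]]
```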


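class RandomCrop:
    """Randomly crop the image in a sample.

    Args:
        output_size (int or tuple): Desired output size. If an int, a square
            crop is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # randint excludes its upper bound; + 1 allows the last valid offset
        # and avoids a ValueError when the image is exactly the crop size
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top:top + new_h, left:left + new_w]
        # Shift landmarks into the crop's coordinate frame
        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}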
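NumPy's random integer functions exclude the upper bound, so the valid top-left offsets for a crop run from 0 to h - new_h inclusive, and the bound passed in must be h - new_h + 1. The standalone sketch below (random_crop is a hypothetical name, and it uses numpy's Generator.integers rather than np.random.randint) checks the edge case where the image is exactly the crop size.

```python
import numpy as np


def random_crop(image, landmarks, size, rng):
    """Hypothetical standalone version of RandomCrop for an (H, W, C) array."""
    h, w = image.shape[:2]
    new_h, new_w = size
    # integers() excludes the upper bound, hence the + 1
    top = rng.integers(0, h - new_h + 1)
    left = rng.integers(0, w - new_w + 1)
    cropped = image[top:top + new_h, left:left + new_w]
    # Shift (x, y) landmarks into the crop's coordinate frame
    return cropped, landmarks - [left, top]


rng = np.random.default_rng(0)
image = np.zeros((224, 224, 3))
pts = np.array([[10.0, 20.0]])

# Crop size equal to the image size: the only valid offset is (0, 0)
out, shifted = random_crop(image, pts, (224, 224), rng)
print(out.shape, shifted)  # (224, 224, 3) [[10. 20.]]
```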


class ToTensor:
    """Convert the ndarrays in a sample to tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        # NumPy image: H x W x C; torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
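The transpose((2, 0, 1)) step exists because NumPy and skimage images are laid out H x W x C while PyTorch image layers expect C x H x W. A quick shape check with a made-up 224 x 224 RGB array is sketched below. Note that torch.from_numpy keeps the NumPy dtype (float64 here), whereas torch.nn layers default to float32 parameters, so a .float() cast is often needed before the forward pass.

```python
import numpy as np
import torch

image = np.random.rand(224, 224, 3)  # H x W x C, float64 here
landmarks = np.random.rand(68, 2)

chw = image.transpose((2, 0, 1))     # C x H x W (a view, no copy)
image_t = torch.from_numpy(chw)      # shares memory with `image`
landmarks_t = torch.from_numpy(landmarks)

print(image_t.shape)  # torch.Size([3, 224, 224])
print(image_t.dtype)  # torch.float64

# torch.nn layers default to float32 parameters, so cast before feeding a model
image_f32 = image_t.float()
print(image_f32.dtype)  # torch.float32
```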

3.2 Applying the Transforms

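from torch.utils.data import DataLoader
from torchvision import transforms

# Create the dataset; Compose applies Rescale, RandomCrop and ToTensor in order
face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/',
                                    transform=transforms.Compose([
                                        Rescale(256),
                                        RandomCrop(224),
                                        ToTensor()
                                    ]))

# Create the data loader: batches of 4, reshuffled every epoch, 4 worker processes
dataloader = DataLoader(face_dataset, batch_size=4, shuffle=True, num_workers=4)

# Iterate over the first 4 batches
# (with num_workers > 0, run this under `if __name__ == '__main__':` on Windows/macOS)
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(), sample_batched['landmarks'].size())
    if i_batch == 3:
        break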
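DataLoader's default collate function stacks each key of the per-sample dicts along a new first (batch) dimension, which is why the loop above prints shapes such as [4, 3, 224, 224]. The sketch below shows this with a hypothetical RandomFaces dataset that returns random tensors, so it runs without image files. It uses num_workers=0 for simplicity and wraps the entry point in if __name__ == "__main__":, which is required when num_workers > 0 on platforms that start workers with spawn (Windows, and macOS by default).

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomFaces(Dataset):
    """Hypothetical dataset returning already-transformed tensor samples."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return {
            "image": torch.rand(3, 224, 224),
            "landmarks": torch.rand(68, 2),
        }


def main():
    loader = DataLoader(RandomFaces(), batch_size=4, shuffle=True, num_workers=0)
    shapes = []
    for i_batch, batch in enumerate(loader):
        # Each value is stacked: image -> (B, 3, 224, 224), landmarks -> (B, 68, 2)
        shapes.append((tuple(batch["image"].shape), tuple(batch["landmarks"].shape)))
        print(i_batch, batch["image"].size(), batch["landmarks"].size())
    return shapes


if __name__ == "__main__":
    main()
```

With 10 samples and batch_size=4, the last batch holds only 2 samples; pass drop_last=True to DataLoader if every batch must be full.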

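4. Summary

In this tutorial you implemented a custom PyTorch dataset, wrote transforms for rescaling, random cropping and tensor conversion, and batched the results with a data loader. The same pattern carries over to other datasets: read samples lazily in __getitem__, keep preprocessing in small composable transform classes, and let DataLoader handle batching, shuffling and parallel loading. Hopefully you can put these techniques to use in your own deep learning projects.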
