PyTorch torch.utils.data

2025-07-02 16:04 更新

PyTorch 数据加载与处理详解

一、PyTorch 数据加载器简介

torch.utils.data.DataLoader 是 PyTorch 提供的核心数据加载工具,它可以方便地从数据集中加载数据,并支持多种高级功能,如多进程加载、自动批处理、自定义数据转换等。

二、数据加载器的核心参数与用法

(一)DataLoader 的基本构造

DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None
)

  1. dataset :数据集对象,可以是映射式数据集(实现 __getitem____len__ 方法)或迭代式数据集(实现 __iter__ 方法)。
  2. batch_size :每个批次加载的样本数量,默认为 1。
  3. shuffle :是否在每个 epoch 开始时打乱数据,默认为 False
  4. sampler :自定义采样器,用于指定数据加载顺序,不能与 shuffle 同时使用。
  5. num_workers :加载数据时使用的子进程数量,默认为 0(即单进程加载)。
  6. collate_fn :用于将单个样本合并成批次的函数,默认会将样本列表转换为张量。
  7. pin_memory :是否将数据加载到固定内存中,以便更快地传输到 GPU,默认为 False
  8. drop_last :如果数据集大小不能被批次大小整除,是否丢弃最后一个不完整的批次,默认为 False

(二)数据集类型

  1. 映射式数据集
    • 实现 __getitem____len__ 协议。
    • 适合于数据已经存储在磁盘上且可以按索引访问的场景。

  1. 迭代式数据集
    • 实现 __iter__ 协议。
    • 适合于数据流式读取的场景,如实时生成的数据或从数据库中读取的数据。

(三)采样器

  1. torch.utils.data.SequentialSampler :按顺序采样数据。
  2. torch.utils.data.RandomSampler :随机采样数据,可以指定是否替换采样。
  3. torch.utils.data.SubsetRandomSampler :从给定的索引列表中随机采样。
  4. torch.utils.data.WeightedRandomSampler :根据给定的权重进行采样。

(四)多进程数据加载

num_workers 参数设置为大于 0 的值可以启用多进程数据加载。每个工作进程会加载一个子集的数据,从而加速数据加载过程。

(五)内存固定

pin_memory 参数设置为 True,可以将数据加载到固定内存中,这样在将数据传输到 GPU 时会更快。

三、数据集的创建与使用

(一)创建自定义数据集

  1. 映射式数据集示例

import torch
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels


    def __len__(self):
        return len(self.data)


    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


## 示例数据
data = torch.randn(100, 3, 224, 224)
labels = torch.randint(0, 10, (100,))


dataset = CustomDataset(data, labels)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True)

  1. 迭代式数据集示例

from torch.utils.data import IterableDataset


class CustomIterableDataset(IterableDataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels


    def __iter__(self):
        for i in range(len(self.data)):
            yield self.data[i], self.labels[i]


dataset = CustomIterableDataset(data, labels)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=2)

(二)数据集的分割与合并

  1. 数据集分割

from torch.utils.data import random_split


train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])


train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)

  1. 数据集合并

from torch.utils.data import ConcatDataset


dataset1 = CustomDataset(data1, labels1)
dataset2 = CustomDataset(data2, labels2)


combined_dataset = ConcatDataset([dataset1, dataset2])
combined_loader = DataLoader(combined_dataset, batch_size=10, shuffle=True)

四、数据加载器的高级用法

(一)自定义 collate_fn

在某些情况下,可能需要自定义如何将单个样本合并成一个批次。例如,对于变长序列数据,可以自定义 collate_fn 来填充序列使其长度一致。

def custom_collate_fn(batch):
    # batch 是一个列表,其中每个元素是一个数据样本
    # 这里可以实现自定义的批次合并逻辑
    # 例如,填充序列使其长度一致
    return batch


dataloader = DataLoader(dataset, batch_size=10, collate_fn=custom_collate_fn)

(二)使用 worker_init_fn 自定义工作进程初始化

worker_init_fn 可以在每个工作进程初始化时执行自定义逻辑,例如设置随机种子。

def worker_init_fn(worker_id):
    import numpy as np
    np.random.seed(worker_id)


dataloader = DataLoader(dataset, num_workers=4, worker_init_fn=worker_init_fn)

(三)分布式数据加载

使用 torch.utils.data.distributed.DistributedSampler 可以在分布式训练中将数据集分割成多个子集,每个进程加载不同的子集。

from torch.utils.data.distributed import DistributedSampler


sampler = DistributedSampler(dataset, num_replicas=4, rank=0, shuffle=True)
dataloader = DataLoader(dataset, batch_size=10, sampler=sampler)

五、总结

通过本教程,我们详细了解了 PyTorch 中 torch.utils.data 模块的使用方法,包括数据加载器的核心参数与用法、数据集的创建与使用、采样器的使用、多进程数据加载、内存固定以及高级功能如自定义 collate_fn 和分布式数据加载。合理利用这些功能可以显著提升数据预处理和加载的效率,为模型训练提供有力支持。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号