PyTorch torch.utils.data
PyTorch 数据加载与处理详解
一、PyTorch 数据加载器简介
torch.utils.data.DataLoader
是 PyTorch 提供的核心数据加载工具,它可以方便地从数据集中加载数据,并支持多种高级功能,如多进程加载、自动批处理、自定义数据转换等。
二、数据加载器的核心参数与用法
(一)DataLoader
的基本构造
DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None
)
dataset
:数据集对象,可以是映射式数据集(实现__getitem__
和__len__
方法)或迭代式数据集(实现__iter__
方法)。batch_size
:每个批次加载的样本数量,默认为 1。shuffle
:是否在每个 epoch 开始时打乱数据,默认为False
。sampler
:自定义采样器,用于指定数据加载顺序,不能与shuffle
同时使用。num_workers
:加载数据时使用的子进程数量,默认为 0(即单进程加载)。collate_fn
:用于将单个样本合并成批次的函数,默认会将样本列表转换为张量。pin_memory
:是否将数据加载到固定内存中,以便更快地传输到 GPU,默认为False
。drop_last
:如果数据集大小不能被批次大小整除,是否丢弃最后一个不完整的批次,默认为False
。
(二)数据集类型
- 映射式数据集
- 实现
__getitem__
和__len__
协议。 - 适合于数据已经存储在磁盘上且可以按索引访问的场景。
- 实现
- 迭代式数据集
- 实现
__iter__
协议。 - 适合于数据流式读取的场景,如实时生成的数据或从数据库中读取的数据。
- 实现
(三)采样器
torch.utils.data.SequentialSampler
:按顺序采样数据。torch.utils.data.RandomSampler
:随机采样数据,可以指定是否替换采样。torch.utils.data.SubsetRandomSampler
:从给定的索引列表中随机采样。torch.utils.data.WeightedRandomSampler
:根据给定的权重进行采样。
(四)多进程数据加载
将 num_workers
参数设置为大于 0 的值可以启用多进程数据加载。每个工作进程会加载一个子集的数据,从而加速数据加载过程。
(五)内存固定
将 pin_memory
参数设置为 True
,可以将数据加载到固定内存中,这样在将数据传输到 GPU 时会更快。
三、数据集的创建与使用
(一)创建自定义数据集
- 映射式数据集示例
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
## 示例数据
data = torch.randn(100, 3, 224, 224)
labels = torch.randint(0, 10, (100,))
dataset = CustomDataset(data, labels)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True)
- 迭代式数据集示例
from torch.utils.data import IterableDataset
class CustomIterableDataset(IterableDataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __iter__(self):
for i in range(len(self.data)):
yield self.data[i], self.labels[i]
dataset = CustomIterableDataset(data, labels)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=2)
(二)数据集的分割与合并
- 数据集分割
from torch.utils.data import random_split
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)
- 数据集合并
from torch.utils.data import ConcatDataset
dataset1 = CustomDataset(data1, labels1)
dataset2 = CustomDataset(data2, labels2)
combined_dataset = ConcatDataset([dataset1, dataset2])
combined_loader = DataLoader(combined_dataset, batch_size=10, shuffle=True)
四、数据加载器的高级用法
(一)自定义 collate_fn
在某些情况下,可能需要自定义如何将单个样本合并成一个批次。例如,对于变长序列数据,可以自定义 collate_fn
来填充序列使其长度一致。
def custom_collate_fn(batch):
# batch 是一个列表,其中每个元素是一个数据样本
# 这里可以实现自定义的批次合并逻辑
# 例如,填充序列使其长度一致
return batch
dataloader = DataLoader(dataset, batch_size=10, collate_fn=custom_collate_fn)
(二)使用 worker_init_fn
自定义工作进程初始化
worker_init_fn
可以在每个工作进程初始化时执行自定义逻辑,例如设置随机种子。
def worker_init_fn(worker_id):
import numpy as np
np.random.seed(worker_id)
dataloader = DataLoader(dataset, num_workers=4, worker_init_fn=worker_init_fn)
(三)分布式数据加载
使用 torch.utils.data.distributed.DistributedSampler
可以在分布式训练中将数据集分割成多个子集,每个进程加载不同的子集。
from torch.utils.data.distributed import DistributedSampler
sampler = DistributedSampler(dataset, num_replicas=4, rank=0, shuffle=True)
dataloader = DataLoader(dataset, batch_size=10, sampler=sampler)
五、总结
通过本教程,我们详细了解了 PyTorch 中 torch.utils.data
模块的使用方法,包括数据加载器的核心参数与用法、数据集的创建与使用、采样器的使用、多进程数据加载、内存固定以及高级功能如自定义 collate_fn
和分布式数据加载。合理利用这些功能可以显著提升数据预处理和加载的效率,为模型训练提供有力支持。
更多建议: