PyTorch torch.utils.data
PyTorch 数据加载实用程序的核心是 torch.utils.data.DataLoader
类。 它表示可在数据集上迭代的 Python,并支持
这些选项由 DataLoader
的构造函数参数配置,该参数具有签名:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
以下各节详细介绍了这些选项的效果和用法。
数据集类型
DataLoader
构造函数的最重要参数是dataset
,它指示要从中加载数据的数据集对象。 PyTorch 支持两种不同类型的数据集:
地图样式数据集
映射样式数据集是一种实现__getitem__()
和__len__()
协议的数据集,它表示从(可能是非整数)索引/关键字到数据样本的映射。
例如,当使用dataset[idx]
访问时,此类数据集可以从磁盘上的文件夹中读取第idx
张图像及其对应的标签。
迭代式数据集
可迭代样式的数据集是 IterableDataset
子类的实例,该子类实现了__iter__()
协议,并表示数据样本上的可迭代。 这种类型的数据集特别适用于随机读取价格昂贵甚至不大可能,并且批处理大小取决于所获取数据的情况。
例如,这种数据集称为iter(dataset)
时,可以返回从数据库,远程服务器甚至实时生成的日志中读取的数据流。
注意
当将 IterableDataset
与一起使用时,多进程数据加载。 在每个工作进程上都复制相同的数据集对象,因此必须对副本进行不同的配置,以避免重复的数据。 有关如何实现此功能的信息,请参见 IterableDataset
文档。
数据加载顺序和 Sampler
对于迭代式数据集,数据加载顺序完全由用户定义的迭代器控制。 这样可以更轻松地实现块读取和动态批次大小的实现(例如,通过每次生成一个批次的样本)。
本节的其余部分涉及地图样式数据集的情况。 torch.utils.data.Sampler
类用于指定数据加载中使用的索引/键的顺序。 它们代表数据集索引上的可迭代对象。 例如,在具有随机梯度体面(SGD)的常见情况下, Sampler
可以随机排列一列索引,一次生成每个索引,或者为小批量生成少量索引 新币。
基于 DataLoader
的shuffle
参数,将自动构建顺序采样或混洗的采样器。 或者,用户可以使用sampler
参数指定一个自定义 Sampler
对象,该对象每次都会产生要提取的下一个索引/关键字。
可以一次生成批量索引列表的自定义 Sampler
作为batch_sampler
参数传递。 也可以通过batch_size
和drop_last
参数启用自动批处理。 有关更多详细信息,请参见下一部分的。
Note
sampler
和batch_sampler
都不与可迭代样式的数据集兼容,因为此类数据集没有键或索引的概念。
加载批处理和非批处理数据
DataLoader
支持通过参数batch_size
,drop_last
和batch_sampler
将各个提取的数据样本自动整理为批次。
自动批处理(默认)
这是最常见的情况,对应于获取一小批数据并将其整理为批处理的样本,即包含张量,其中一维为批处理维度(通常是第一维)。
当batch_size
(默认1
)不是None
时,数据加载器将生成批处理的样本,而不是单个样本。 batch_size
和drop_last
参数用于指定数据加载器如何获取数据集密钥的批处理。 对于地图样式的数据集,用户可以选择指定batch_sampler
,它一次生成一个键列表。
Note
batch_size
和drop_last
自变量本质上用于从sampler
构造batch_sampler
。 对于地图样式的数据集,sampler
由用户提供或基于shuffle
参数构造。 对于可迭代样式的数据集,sampler
是一个虚拟的无限数据集。
Note
当从可重复样式数据集进行多重处理提取时,drop_last
参数会删除每个工作人员数据集副本的最后一个非完整批次。
使用来自采样器的索引获取样本列表后,作为collate_fn
参数传递的函数用于将样本列表整理为批次。
在这种情况下,从地图样式数据集加载大致等效于:
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
从可迭代样式的数据集加载大致等效于:
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
自定义collate_fn
可用于自定义排序规则,例如,将顺序数据填充到批处理的最大长度。
禁用自动批处理
在某些情况下,用户可能希望以数据集代码手动处理批处理,或仅加载单个样本。 例如,直接加载批处理的数据(例如,从数据库中批量读取或读取连续的内存块)可能更便宜,或者批处理大小取决于数据,或者该程序设计为可处理单个样本。 在这种情况下,最好不要使用自动批处理(其中collate_fn
用于整理样本),而应让数据加载器直接返回dataset
对象的每个成员。
当batch_size
和batch_sampler
均为None
时(batch_sampler
的默认值已为None
),自动批处理被禁用。 从dataset
获得的每个样本都将作为collate_fn
参数传递的函数进行处理。
禁用自动批处理时,默认值collate_fn
仅将 NumPy 数组转换为 PyTorch 张量,而其他所有内容均保持不变。
In this case, loading from a map-style dataset is roughly equivalent with:
for index in sampler:
yield collate_fn(dataset[index])
and loading from an iterable-style dataset is roughly equivalent with:
for data in iter(dataset):
yield collate_fn(data)
使用collate_fn
启用或禁用自动批处理时,collate_fn
的使用略有不同。
禁用自动批处理时,将对每个单独的数据样本调用collate_fn
,并且从数据加载器迭代器产生输出。 在这种情况下,默认的collate_fn
仅转换 PyTorch 张量中的 NumPy 数组。
启用自动批处理时,会每次调用collate_fn
并带有数据样本列表。 期望将输入样本整理为一批,以便从数据加载器迭代器中获得收益。 本节的其余部分描述了这种情况下默认collate_fn
的行为。
例如,如果每个数据样本都包含一个 3 通道图像和一个整体类标签,即数据集的每个元素返回一个元组(image, class_index)
,则默认值collate_fn
将此类元组的列表整理为一个元组 批处理图像张量和批处理类标签 Tensor。 特别是,默认collate_fn
具有以下属性:
- 它始终将新维度添加为批次维度。
- 它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。
- 它保留了数据结构,例如,如果每个样本都是一个字典,它将输出一个具有相同键集但将批处理 Tensors 作为值的字典(如果无法将这些值转换为 Tensors,则将其列出)。 与
list
,tuple
,namedtuple
等相同。
用户可以使用自定义的collate_fn
来实现自定义批处理,例如,沿除第一个维度之外的其他维度进行校对,各种长度的填充序列或添加对自定义数据类型的支持。
单进程和多进程数据加载
默认情况下, DataLoader
使用单进程数据加载。
在 Python 进程中,全局解释器锁定(GIL)阻止了跨线程真正的完全并行化 Python 代码。 为了避免在加载数据时阻塞计算代码,PyTorch 提供了一个简单的开关,只需将参数num_workers
设置为正整数即可执行多进程数据加载。
单进程数据加载(默认)
在此模式下,以与初始化 DataLoader
相同的过程完成数据提取。 因此,数据加载可能会阻止计算。 然而,当用于在进程之间共享数据的资源(例如,共享存储器,文件描述符)受到限制时,或者当整个数据集很小并且可以完全加载到存储器中时,该模式可能是优选的。 此外,单进程加载通常显示更多可读的错误跟踪,因此对于调试很有用。
多进程数据加载
将参数num_workers
设置为正整数将打开具有指定数量的加载程序工作进程的多进程数据加载。
在此模式下,每次创建 DataLoader
的迭代器时(例如,当您调用enumerate(dataloader)
时),都会创建num_workers
工作进程。 此时,dataset
,collate_fn
和worker_init_fn
被传递给每个工作程序,在这里它们被用来初始化和获取数据。 这意味着数据集访问及其内部 IO 转换(包括collate_fn
)在工作进程中运行。
torch.utils.data.get_worker_info()
在工作进程中返回各种有用的信息(包括工作 ID,数据集副本,初始种子等),并在主进程中返回None
。 用户可以在数据集代码和/或worker_init_fn
中使用此功能来分别配置每个数据集副本,并确定代码是否正在工作进程中运行。 例如,这在分片数据集时特别有用。
对于地图样式的数据集,主过程使用sampler
生成索引并将其发送给工作人员。 因此,任何随机播放都是在主过程中完成的,该过程通过为索引分配索引来引导加载。
对于可迭代样式的数据集,由于每个工作进程都获得dataset
对象的副本,因此幼稚的多进程加载通常会导致数据重复。 用户可以使用 torch.utils.data.get_worker_info()
和/或worker_init_fn
独立配置每个副本。 (有关如何实现此操作的信息,请参见 IterableDataset
文档。)出于类似的原因,在多进程加载中,drop_last
参数删除每个工作程序的可迭代样式数据集副本的最后一个非完整批次。
一旦迭代结束或迭代器被垃圾回收,工作器将关闭。
警告
通常不建议在多进程加载中返回 CUDA 张量,因为在使用 CUDA 和在并行处理中共享 CUDA 张量时存在很多微妙之处(请参见在并行处理中的 CUDA)。 相反,我们建议使用自动内存固定(即,设置pin_memory=True
),该功能可以将数据快速传输到支持 CUDA 的 GPU。
平台特定的行为
由于工作程序依赖于 Python multiprocessing
,因此与 Unix 相比,Windows 上的工作程序启动行为有所不同。
- 在 Unix 上,
fork()
是默认的multiprocessing
启动方法。 使用fork()
,童工通常可以直接通过克隆的地址空间访问dataset
和 Python 参数函数。 - 在 Windows 上,
spawn()
是默认的multiprocessing
启动方法。 使用spawn()
启动另一个解释器,该解释器运行您的主脚本,然后运行内部工作程序函数,该函数通过序列化pickle
接收dataset
,collate_fn
和其他参数。
这种独立的序列化意味着您应该采取两个步骤来确保在使用多进程数据加载时与 Windows 兼容:
- 将您的大部分主脚本代码包装在
if __name__ == '__main__':
块中,以确保在启动每个工作进程时,该脚本不会再次运行(很可能会产生错误)。 您可以在此处放置数据集和DataLoader
实例创建逻辑,因为它不需要在 worker 中重新执行。 - 确保在
__main__
检查之外将任何自定义collate_fn
,worker_init_fn
或dataset
代码声明为顶级定义。 这样可以确保它们在工作进程中可用。 (这是必需的,因为将函数仅作为引用而不是bytecode
进行腌制。)
多进程数据加载中的随机性
默认情况下,每个工作人员的 PyTorch 种子将设置为base_seed + worker_id
,其中base_seed
是主进程使用其 RNG 生成的长整数(因此,强制使用 RNG 状态)。 但是,初始化工作程序(例如 NumPy)时,可能会复制其他库的种子,导致每个工作程序返回相同的随机数。
在worker_init_fn
中,您可以使用 torch.utils.data.get_worker_info().seed
或 torch.initial_seed()
访问每个工作人员的 PyTorch 种子集,并在加载数据之前使用它为其他库提供种子。
内存固定
主机到 GPU 副本源自固定(页面锁定)内存时,速度要快得多。 有关通常何时以及如何使用固定内存的更多详细信息,请参见使用固定内存缓冲区。
对于数据加载,将pin_memory=True
传递到 DataLoader
将自动将获取的数据张量放置在固定内存中,从而更快地将数据传输到启用 CUDA 的 GPU。
默认的内存固定逻辑仅识别张量以及包含张量的映射和可迭代对象。 默认情况下,如果固定逻辑看到一个自定义类型的批处理(如果您有一个collate_fn
返回自定义批处理类型,则会发生),或者如果该批处理的每个元素都是自定义类型,则固定逻辑将 无法识别它们,它将返回该批处理(或那些元素)而不固定内存。 要为自定义批处理或数据类型启用内存固定,请在自定义类型上定义pin_memory()
方法。
请参见下面的示例。
例:
class SimpleCustomBatch:
def __init__(self, data):
transposed_data = list(zip(*data))
self.inp = torch.stack(transposed_data[0], 0)
self.tgt = torch.stack(transposed_data[1], 0)
# custom memory pinning method on custom type
def pin_memory(self):
self.inp = self.inp.pin_memory()
self.tgt = self.tgt.pin_memory()
return self
def collate_wrapper(batch):
return SimpleCustomBatch(batch)
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
pin_memory=True)
for batch_ndx, sample in enumerate(loader):
print(sample.inp.is_pinned())
print(sample.tgt.is_pinned())
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)¶
数据加载器。 组合数据集和采样器,并在给定的数据集上提供可迭代的。
DataLoader
支持地图样式和可迭代样式的数据集,具有单进程或多进程加载,自定义加载顺序以及可选的自动批处理(归类)和内存固定。
有关更多详细信息,请参见 torch.utils.data
文档页面。
参数
- 数据集 (数据集)–要从中加载数据的数据集。
- batch_size (python:int , 可选)–每批次要加载多少个样本(默认值:
1
)。 - 随机播放 (bool , 可选)–设置为
True
以使数据在每个时间段都重新随机播放(默认值:False
)。 - 采样器 (采样器 , 可选)–定义了从数据集中抽取样本的策略。 如果指定,则
shuffle
必须为False
。 - batch_sampler (采样器 , 可选)–类似
sampler
,但在 时间。 与batch_size
,shuffle
,sampler
和drop_last
互斥。 - num_workers (python:int , 可选)–多少个子进程用于数据加载。
0
表示将在主进程中加载数据。 (默认:0
) - collate_fn (可调用的, 可选)–合并样本列表以形成张量的小批量。 在从地图样式数据集中使用批量加载时使用。
- pin_memory (bool , 可选)–如果
True
,则数据加载器将张量复制到 CUDA 固定的内存中,然后返回。 如果您的数据元素是自定义类型,或者您的collate_fn
返回的是自定义类型的批次,请参见下面的示例。 - drop_last (布尔 , 可选)–设置为
True
以删除最后不完整的批次,如果数据集大小不可分割 按批次大小。 如果False
并且数据集的大小不能被批次大小整除,那么最后一批将较小。 (默认:False
) - 超时(数字 , 可选)–如果为正,则表示从工作人员处收集批次的超时值。 应始终为非负数。 (默认:
0
) - worker_init_fn (可调用 , 可选)–如果不是
None
,则将在每个具有工作人员 ID (在播种之后和数据加载之前,将[0, num_workers - 1]
中的 int 作为输入。 (默认:None
)
Warning
如果使用spawn
启动方法,则worker_init_fn
不能是不可拾取的对象,例如 lambda 函数。 有关 PyTorch 中与并行处理有关的更多详细信息,请参见并行处理最佳实践。
Note
len(dataloader)
启发式方法基于所用采样器的长度。 当dataset
是 IterableDataset
时,无论多进程加载配置如何,都将返回len(dataset)
(如果实现),因为 PyTorch 信任用户dataset
代码可以正确处理多进程加载 避免重复数据。 有关这两种类型的数据集以及 IterableDataset
如何与多进程数据加载交互的更多详细信息,请参见数据集类型。
class torch.utils.data.Dataset¶
表示 Dataset
的抽象类。
代表从键到数据样本的映射的所有数据集都应将其子类化。 所有子类都应该覆盖__getitem__()
,支持为给定键获取数据样本。 子类还可以选择覆盖__len__()
,它有望通过许多 Sampler
实现以及 DataLoader
的默认选项返回数据集的大小。
Note
默认情况下, DataLoader
构造一个索引采样器,该采样器产生整数索引。 要使其与具有非整数索引/键的地图样式数据集一起使用,必须提供自定义采样器。
class torch.utils.data.IterableDataset¶
可迭代的数据集。
代表可迭代数据样本的所有数据集都应将其子类化。 当数据来自流时,这种形式的数据集特别有用。
所有子类都应覆盖__iter__()
,这将返回此数据集中的样本迭代器。
当子类与 DataLoader
一起使用时,数据集中的每个项目都将由 DataLoader
迭代器产生。 当num_workers > 0
时,每个工作进程将具有数据集对象的不同副本,因此通常需要独立配置每个副本,以避免从工作进程返回重复的数据。 get_worker_info()
在工作程序进程中调用时,返回有关工作程序的信息。 可以在数据集的__iter__()
方法或 DataLoader
的worker_init_fn
选项中使用它来修改每个副本的行为。
示例 1:在__iter__()
中将工作负载分配给所有工作人员:
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end >= start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... worker_info = torch.utils.data.get_worker_info()
... if worker_info is None: # single-process data loading, return the full iterator
... iter_start = self.start
... iter_end = self.end
... else: # in a worker process
... # split workload
... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... iter_start = self.start + worker_id * per_worker
... iter_end = min(iter_start + per_worker, self.end)
... return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>> # Mult-process loading with two worker processes
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]
>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]
示例 2:使用worker_init_fn
在所有工作人员之间分配工作量:
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
... def __init__(self, start, end):
... super(MyIterableDataset).__init__()
... assert end > start, "this example code only works with end >= start"
... self.start = start
... self.end = end
...
... def __iter__(self):
... return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)
>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]
>>> # Define a `worker_init_fn` that configures each dataset differently
>>> def worker_init_fn(worker_id):
... worker_info = torch.utils.data.get_worker_info()
... dataset = worker_info.dataset # the dataset copy in this worker process
... overall_start = dataset.start
... overall_end = dataset.end
... # configure the dataset to only process the split workload
... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
... worker_id = worker_info.id
... dataset.start = overall_start + worker_id * per_worker
... dataset.end = min(dataset.start + per_worker, overall_end)
...
>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]
>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
class torch.utils.data.TensorDataset(*tensors)¶
数据集包装张量。
每个样本将通过沿第一维索引张量来检索。
Parameters
张量 (tensor*)–具有与第一维相同大小的张量。
class torch.utils.data.ConcatDataset(datasets)¶
数据集是多个数据集的串联。
此类对于组装不同的现有数据集很有用。
Parameters
数据集(序列)–要连接的数据集列表
class torch.utils.data.ChainDataset(datasets)¶
用于链接多个 IterableDataset
的数据集。
此类对于组装不同的现有数据集流很有用。 链接操作是即时完成的,因此将大型数据集与此类连接起来将非常有效。
Parameters
数据集(IterableDataset 的可迭代)–链接在一起的数据集
class torch.utils.data.Subset(dataset, indices)¶
指定索引处的数据集子集。
Parameters
- 数据集 (数据集)–整个数据集
- 索引(序列)–为子集选择的整个集合中的索引
torch.utils.data.get_worker_info()¶
返回有关当前 DataLoader
迭代器工作进程的信息。
在工作线程中调用时,此方法返回一个保证具有以下属性的对象:
id
:当前工作人员 ID。num_workers
:工人总数。seed
:当前工作程序的随机种子集。 该值由主进程 RNG 和工作程序 ID 确定。 有关更多详细信息,请参见DataLoader
的文档。dataset
:此流程在中的数据集对象的副本。 请注意,在不同的过程中,这将是与主过程中的对象不同的对象。
在主进程中调用时,将返回None
。
Note
在传递给 DataLoader
的worker_init_fn
中使用时,此方法可用于不同地设置每个工作进程,例如,使用worker_id
将dataset
对象配置为仅读取 分片数据集的特定部分,或使用seed
播种数据集代码中使用的其他库(例如 NumPy)。
torch.utils.data.random_split(dataset, lengths)¶
将数据集随机拆分为给定长度的不重叠的新数据集。
Parameters
- 数据集 (数据集)–要拆分的数据集
- 长度(序列)–要产生的分割的长度
class torch.utils.data.Sampler(data_source)¶
所有采样器的基类。
每个 Sampler 子类都必须提供__iter__()
方法(提供一种对数据集元素的索引进行迭代的方法)和__len__()
方法,该方法返回返回的迭代器的长度。
Note
DataLoader
并非严格要求__len__()
方法,但在涉及 DataLoader
长度的任何计算中都应采用。
class torch.utils.data.SequentialSampler(data_source)¶
始终以相同顺序顺序采样元素。
Parameters
data_source (数据集)–要从中采样的数据集
class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)¶
随机采样元素。 如果不进行替换,则从经过改组的数据集中采样。 如果要更换,则用户可以指定num_samples
进行绘制。
Parameters
- data_source (Dataset) – dataset to sample from
- 替换 (bool )–如果
True
为默认值,则替换为True
- num_samples (python:int )–要绘制的样本数,默认为 len(dataset)。 仅当<cite>替换</cite>为
True
时才应指定此参数。
class torch.utils.data.SubsetRandomSampler(indices)¶
从给定的索引列表中随机抽样元素,而无需替换。
Parameters
索引(序列)–索引序列
class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)¶
以给定的概率(权重)从[0,..,len(weights)-1]
中采样元素。
Parameters
- 权重(序列)–权重序列,不必累加一个
- num_samples (python:int )–要绘制的样本数
- 替代品 (bool )–如果
True
,则抽取替代品抽取样品。 如果没有,则它们将被替换而不会被绘制,这意味着当为一行绘制样本索引时,无法为该行再次绘制它。
例
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[0, 0, 0, 1, 0]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)¶
包装另一个采样器以产生一个小批量的索引。
Parameters
- 采样器 (采样器)–基本采样器。
- batch_size (python:int )–迷你批量的大小。
- drop_last (bool )–如果为
True
,则采样器将丢弃最后一批,如果其大小小于batch_size
Example
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True)¶
将数据加载限制为数据集子集的采样器。
与 torch.nn.parallel.DistributedDataParallel
结合使用时特别有用。 在这种情况下,每个进程都可以将 DistributedSampler 实例作为 DataLoader 采样器传递,并加载原始数据集的专有子集。
Note
假定数据集大小恒定。
Parameters
- 数据集 –用于采样的数据集。
- num_replicas (可选)–参与分布式训练的进程数。
- 等级(可选)–当前进程在 num_replicas 中的等级。
- 随机播放(可选)–如果为 true(默认值),采样器将随机播放索引
更多建议: