参考链接
- pytorch官方
- 【Pytorch基础】torch.utils.data.DataLoader方法的使用
- Pytorch】torch.utils.data.DataLoader使用方法
- torch.utils.data.DataLoader使用方法
torch.utils.data.DataLoader
1 | CLASS torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='') |
数据加载器。结合了一个数据集和一个采样器,并提供对给定数据集的可迭代访问。
数据加载器支持映射样式和可迭代样式的数据集,可以使用单进程或多进程加载,自定义加载顺序,并可选择自动批处理(整理)和内存固定。
查看torch.utils.data文档页面以获取更多详细信息。
参数
-
dataset
(数据集):要从中加载数据的数据集。 -
batch_size
(int,可选):每批加载多少个样本(默认:1)。 -
shuffle
(bool,可选):设置为True以在每个时期重新洗牌数据(默认:False)。 -
sampler
(采样器或可迭代对象,可选):定义从数据集中抽取样本的策略。可以是任何实现了__len__
的可迭代对象。如果指定了sampler
,则不能指定shuffle
。 -
batch_sampler
(采样器或可迭代对象,可选):类似于sampler
,但一次返回一批索引。与batch_size
、shuffle
、sampler
和drop_last
相互排斥。 -
num_workers
(int,可选):用于数据加载的子进程数。0表示数据将在主进程中加载(默认:0)。 -
collate_fn
(可调用对象,可选):将样本列表合并成一个小批量的张量。在从映射样式数据集进行批量加载时使用。 -
pin_memory
(bool,可选):如果为True,则数据加载器将在返回数据之前将张量复制到设备/CUDA固定内存中。如果你的数据元素是自定义类型,或者你的collate_fn
返回的批次是自定义类型,请参考下面的示例。 -
drop_last
(bool,可选):设置为True以丢弃最后一个不完整的批次,如果数据集大小不能被批次大小整除。如果为False,并且数据集的大小不能被批次大小整除,则最后一个批次将较小(默认:False)。 -
timeout
(数值,可选):如果为正数,则为从工作进程收集批次的超时值。应始终为非负数(默认:0)。 -
worker_init_fn
(可调用对象,可选):如果不为None,则将在每个工作进程上调用,以工作进程ID(一个位于[0,num_workers - 1]的整数)作为输入,在种子生成和数据加载之后(默认:None)。 -
multiprocessing_context
(str或multiprocessing.context.BaseContext,可选):如果为None,则将使用操作系统的默认多进程上下文(默认:None)。 -
generator
(torch.Generator,可选):如果不为None,则RandomSampler将使用此RNG生成随机索引,而多进程将生成工作进程的base_seed。 (默认:None) -
prefetch_factor
(int,可选,仅限关键字参数):每个工作进程提前加载的批次数。2表示所有工作进程将预取2 *num_workers
批次(默认值取决于num_workers
的设置值。如果num_workers=0
,则默认值为None。否则,如果num_workers > 0
,则默认值为2)。 -
persistent_workers
(bool,可选):如果为True,则数据加载器将在数据集被消耗一次后不会关闭工作进程。这允许保持工作进程的数据集实例处于活动状态(默认:False)。 -
pin_memory_device
(str,可选):如果pin_memory
为True,将数据固定到的设备。
示例代码
1 | import torch |