Loading...

参考链接


torch.utils.data.DataLoader

1
CLASS torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')

数据加载器。结合了一个数据集和一个采样器,并提供对给定数据集的可迭代访问。

数据加载器支持映射样式和可迭代样式的数据集,可以使用单进程或多进程加载,自定义加载顺序,并可选择自动批处理(整理)和内存固定。

查看torch.utils.data文档页面以获取更多详细信息。


参数

  • dataset(数据集):要从中加载数据的数据集。

  • batch_size(int,可选):每批加载多少个样本(默认:1)。

  • shuffle(bool,可选):设置为True以在每个时期重新洗牌数据(默认:False)。

  • sampler(采样器或可迭代对象,可选):定义从数据集中抽取样本的策略。可以是任何实现了__len__的可迭代对象。如果指定了sampler,则不能指定shuffle

  • batch_sampler(采样器或可迭代对象,可选):类似于sampler,但一次返回一批索引。与batch_sizeshufflesamplerdrop_last相互排斥。

  • num_workers(int,可选):用于数据加载的子进程数。0表示数据将在主进程中加载(默认:0)。

  • collate_fn(可调用对象,可选):将样本列表合并成一个小批量的张量。在从映射样式数据集进行批量加载时使用。

  • pin_memory(bool,可选):如果为True,则数据加载器将在返回数据之前将张量复制到设备/CUDA固定内存中。如果你的数据元素是自定义类型,或者你的collate_fn返回的批次是自定义类型,请参考下面的示例。

  • drop_last(bool,可选):设置为True以丢弃最后一个不完整的批次,如果数据集大小不能被批次大小整除。如果为False,并且数据集的大小不能被批次大小整除,则最后一个批次将较小(默认:False)。

  • timeout(数值,可选):如果为正数,则为从工作进程收集批次的超时值。应始终为非负数(默认:0)。

  • worker_init_fn(可调用对象,可选):如果不为None,则将在每个工作进程上调用,以工作进程ID(一个位于[0,num_workers - 1]的整数)作为输入,在种子生成和数据加载之后(默认:None)。

  • multiprocessing_context(str或multiprocessing.context.BaseContext,可选):如果为None,则将使用操作系统的默认多进程上下文(默认:None)。

  • generator(torch.Generator,可选):如果不为None,则RandomSampler将使用此RNG生成随机索引,而多进程将生成工作进程的base_seed。 (默认:None)

  • prefetch_factor(int,可选,仅限关键字参数):每个工作进程提前加载的批次数。2表示所有工作进程将预取2 * num_workers批次(默认值取决于num_workers的设置值。如果num_workers=0,则默认值为None。否则,如果num_workers > 0,则默认值为2)。

  • persistent_workers(bool,可选):如果为True,则数据加载器将在数据集被消耗一次后不会关闭工作进程。这允许保持工作进程的数据集实例处于活动状态(默认:False)。

  • pin_memory_device(str,可选):如果pin_memory为True,将数据固定到的设备。


示例代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
import torch.utils.data as Data

def main():
torch.manual_seed(1) # reproducible

BATCH_SIZE = 5 # 批训练的数据个数

x = torch.linspace(1, 10, 10) # x data (torch tensor)
y = torch.linspace(10, 1, 10) # y data (torch tensor)

# 先转换成 torch 能识别的 Dataset
torch_dataset = Data.TensorDataset(x, y)

# 把 dataset 放入 DataLoader
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=2, # 注意: 在Windows和MacOS上,可能需要将num_workers设置为0来避免问题
)

print(enumerate(loader))

for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
#step只有0,1是因为,batchsize=5, 10个数据点,两次循环就用完了所有数据
print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())

if __name__ == '__main__':
main()