天天动画片 > 八卦谈 > 第一篇—数据加载(Dataset,Sampler,DataLoader)

第一篇—数据加载(Dataset,Sampler,DataLoader)

八卦谈 佚名 2022-10-27 13:35:38

  • Pytorch对数据集的处理有三个重要的类:

  • Dataset,Sampler和DataLoader,他们均是torch.utils.data包下的模块(类)

  • Dataset:数据集的类,主要用于定义数据集;

  • Sampler:采样器的类,主要用于定义从数据集中选出数据的规则,比如是随机读取或顺序读取;

  • DataLoader:数据的加载类,主要用于读取数据;

  • 总结:

  • Dataset:定义整个数据集;

  • Sampler:定义读取数据的规则;

  • DataLoader:用于读取数据;


  • 实现:

  • 创建数据集(30张图片,3个类别0,1,2)



  • Dataset定义数据集

  • Dataset位于torch.utils.data下,可通过定义继承自这个类的子类来自定义数据集;

  • 有两个重要的方法需要重写:

  •     __len__(self):返回数据集的大小;

  •     __getitem__(self,index):index为下标,返回数据集中对应下标的数据;


  • DataLoader读取数据

  • DataLoader对Dataset打包,完成最后对数据的读取;

  • 一般不需要自己定义或重写DataLoader类,直接使用即可;

  • 常用参数:

  • dataset:一个Dataset类对象,定义好的数据集;

  • batch_size:整数值,每个batch的样本数量,默认为1;

  • shuffle:布尔值,如果为True,则在每个epoch开始的时候打乱数据集,默认为False;

  • sampler:一个Sampler类对象,定义读数据的规则,Sampler每次返回一个索引,默认为None;

  • batch_sampler:也是一个Sampler类对象,与sampler参数不同的是,它接收的Sampler类对象每次返回一个batch的索引,默认为None;

  • num_workers:整数值,定义有几个进程来处理数据,默认为0,表示所有的数据都会被加载到主进程;

  • pin_memory:布尔值,如果为True,那么将加载的数据拷贝到cuda固定的内存中;

  • drop_last:布尔值,如果为True,则对最后一个batch来说,如果样本量不足batch_size,则舍弃这个batch,如果为False,则最后一个batch不舍弃,默认为False;

  • timeout:如果是正数,表示设置加载一个batch的等待时间,若超出设定的时间还没加载完,则舍弃这个batch,如果是0,表示不设置限制时间,默认为0;

  • collate_fn:一个函数,用于将一个batch的样本打包成一个大的tensor;设置batch_size=8,实际上,从Dataset读取的是单独的数据,如每次采样得到一个tuple:(img,label),collate_fn的作用是用于包装batch,即每从Dataset中抽出8个这样的tuple,就把8个(img,label)包装成一个list—[imgs,labels],其中imgs和labels都是tensor,imgs.shape为(8,c,h,w);

  • 注意:

  • DataLoader参数之间存在互斥的情况,主要针对自己定义的采样器:

  • sampler:如果自行指定了sampler参数,则shuffle必须保持默认值,即False;

  • batch_sampler:如果自行指定了batch_sampler参数,则batch_size,sampler,shuffle,drop_last都必须保持默认值;

  • 如果没有自行指定采样器(sampler和batch_sampler均为None),则:

  • sampler:

  •     shuffle:True,sampler采用RandomSampler,即随机采样;

  •     shuffle:False,sampler采用SequentialSampler,即顺序采样;

  • batch_sampler:采用BatchSampler,即根据batch_size进行batch采样;

  • 其中,RandomSampler,SequentialSampler和BatchSampler都是Sampler的子类。

一个batch的数据


  • Sampler定义读取数据的规则

  • 主要用于定义采样的规则,Sampler是一个可迭代对象,其类方法__iter__()定义了迭代后返回的内容。

  • SequentialSampler类:

  • 是一个按顺序进行采样的采样器,接收一个数据集(任何可迭代对象都可以)做参数,按顺序对其进行采样。

  • RandomSampler类:

  • 是一个随机采样器,返回随机采样的值;

  • 参数:

  • dataset:数据集(任何可迭代对象);

  • replacement:布尔值,True表示放回采样,默认为False;

  • num_samples:只有replacement=True时才可设置,表示采样的个数,默认为数据集的总长度;

  • BatchSampler类:

  • Sampler在每次迭代都只返回一个索引,而BatchSampler的作用是对返回一个索引的采样器进行包装,按照设定的batch_size返回一组索引;

  • 参数:

  • sampler:一个Sampler对象(或一个可迭代对象);

  • batch_size:batch的大小;

  • drop_last:是否丢弃最后一个可能不足batch_size大小的batch;




本文标题:第一篇—数据加载(Dataset,Sampler,DataLoader) - 八卦谈
本文地址:www.ttdhp.com/article/5282.html

天天动画片声明:登载此文出于传递更多信息之目的,并不意味着赞同其观点或证实其描述。
扫码关注我们