您当前的位置: 首页 >  pytorch

wendy_ya

暂无认证

  • 2浏览

    0关注

    342博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

PyTorch重难点(一)——利用Dataset和DataLoader构建数据集原理介绍

wendy_ya 发布时间:2021-11-13 16:35:43 ,浏览量:2

目录
  • 一、基础概念
  • 二、创建数据集常用的方法
    • 2.1 使用 torch.utils.data.TensorDataset创建数据集
    • 2.2 使用torchvision.datasets.ImageFolder创建图片数据集
    • 2.3 继承torch.utils.data.Dataset创建自定义数据集
  • 三、Dataset的介绍和使用
    • 3.1 Dataset的介绍
    • 3.2 Dataset的核心接口
    • 3.3 Dataset的使用
      • 3.3.1 导入Dataset类
      • 3.3.2 创建Dataset的子类
      • 3.3.3 实例化该类
  • 四、DataLoader的介绍和使用
    • 4.1 DataLoader的介绍
    • 4.2 DataLoader的核心接口
    • 4.3 DataLoader加载现有数据集
    • 4.4 DataLoader的使用
      • 4.4.1 导入DataLoader
      • 4.4.2 加载数据集
  • 五、总结

一、基础概念

Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据集。

Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素。

而DataLoader定义了按batch加载数据集的方法,它是一个实现了__iter__方法的可迭代对象,每次迭代输出一个batch的数据。

DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。

在绝大部分情况下,用户只需实现Dataset的__len__方法和 __getitem__方法,就可以轻松构建自己的数据集,并用默认数据集进行加载。

二、创建数据集常用的方法 2.1 使用 torch.utils.data.TensorDataset创建数据集

利用TensorDataset类创建数据集可以参考文章:PyTorch实战案例(二)——利用PyTorch实现线性回归算法(进阶),这里不再进行详细介绍,创建完数据集后,可以利用torch.utils.data.random_split 将一个数据集分割成多份,常用于分割训练集,验证集和测试集。

2.2 使用torchvision.datasets.ImageFolder创建图片数据集

使用 torchvision.datasets.ImageFolder创建图片数据集使用频率较低,这里不进行介绍,需要的可以自行查阅资料。

2.3 继承torch.utils.data.Dataset创建自定义数据集

利用torch.utils.data.Dataset创建自定义数据集在下文第三部分进行介绍。

三、Dataset的介绍和使用 3.1 Dataset的介绍

torch.utils.data.Dataset这样的抽象类(Abstract Class)可以用来创建数据集。学过面向对象的应该清楚,抽象类不能实例化,因此我们需要构造这个抽象类的子类来创建数据集,并且我们还可以定义自己的继承和重写方法。这其中最重要的就是len和getitem这两个函数,前者给出数据集的大小,后者是用于查找数据和标签。

3.2 Dataset的核心接口

Dataset的核心接口代码如下:

class Dataset(object):
    def __init__(self):
        pass
    
    def __len__(self):
        raise NotImplementedError
        
    def __getitem__(self,index):
        raise NotImplementedError
3.3 Dataset的使用 3.3.1 导入Dataset类

首先我们需要引入Dataset这个抽象类:

from torch.utils.data import Dataset
3.3.2 创建Dataset的子类

首先对类进行初始化,定义数据和标签:

    def __init__(self,img,label):
        self.img = img
        self.label = label

定义getitem方法获取数据集的内容和标签:

    def __getitem__(self, idx):
        return self.img[idx],self.label[idx]

定义len方法获取数据集的长度:

    def __len__(self):
        return self.img.shape[0]
3.3.3 实例化该类

实例化代码如下:(类名为MyData)

mydata = MyData(img,label)
四、DataLoader的介绍和使用 4.1 DataLoader的介绍

PyTorch用类torch.utils.data.DataLoader加载数据,并对数据进行采样,生成batch迭代器: torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False) 数据加载器 常用参数如下:

  • dataset:加载数据的数据集;
  • batch_size:每个batch要加载多少样本(默认为1);
  • shuffle:是否对数据集进行打乱重新排列(默认为False,即不重新排列);
4.2 DataLoader的核心接口

Dataset的核心接口代码如下:

class DataLoader(object):
    def __init__(self,dataset,batch_size,collate_fn,shuffle = True,drop_last = False):
        self.dataset = dataset
        self.sampler =torch.utils.data.RandomSampler if shuffle else \
           torch.utils.data.SequentialSampler
        self.batch_sampler = torch.utils.data.BatchSampler
        self.sample_iter = self.batch_sampler(
            self.sampler(range(len(dataset))),
            batch_size = batch_size,drop_last = drop_last)
        
    def __next__(self):
        indices = next(self.sample_iter)
        batch = self.collate_fn([self.dataset[i] for i in indices])
        return batch
4.3 DataLoader加载现有数据集

torchvision.datasets 是用来进行数据加载的,PyTorch团队在这个包中帮我们提前处理好了很多很多图片数据集,如:

  • MNIST
  • FashionMNIST
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageFolder
  • Imagenet-12
  • CIFAR10 and CIRAR100
  • STL10

代码示例:

#加载MNIST数据集
train_dataset=datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)
test_dataset=datasets.MNIST(root='./data',train=False,transform=transforms.ToTensor())
4.4 DataLoader的使用 4.4.1 导入DataLoader

首先我们需要引入DataLoader这个抽象类:

from torch.utils.data import DataLoader
4.4.2 加载数据集

加载数据集可以利用之前Dataset实例化后的类用作数据集进行加载,也可以利用第二部分介绍的相关方法作为数据集,加载代码如下:

dataloader= DataLoader(mydata,batch_size = 10,shuffle=True)
五、总结

本文介绍了一下利用Dataset和DataLoader构建数据集的基本原理,包括两种类的基础概念、核心接口、使用的基本流程等,下一篇文章将介绍一下利用Dataset和DataLoader构建数据集实例代码,敬请期待。 下一篇文章: PyTorch重难点(二)——利用Dataset和DataLoader构建数据集代码实战

参考:

  1. https://www.heywhale.com/mw/project/5f33d5c0af3980002cb83cfa
  2. https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

如果对你有所帮助,记得点个赞哟~

关注
打赏
1659256378
查看更多评论
立即登录/注册

微信扫码登录

0.3506s