- 一、基础概念
- 二、创建数据集常用的方法
- 2.1 使用 torch.utils.data.TensorDataset创建数据集
- 2.2 使用torchvision.datasets.ImageFolder创建图片数据集
- 2.3 继承torch.utils.data.Dataset创建自定义数据集
- 三、Dataset的介绍和使用
- 3.1 Dataset的介绍
- 3.2 Dataset的核心接口
- 3.3 Dataset的使用
- 3.3.1 导入Dataset类
- 3.3.2 创建Dataset的子类
- 3.3.3 实例化该类
- 四、DataLoader的介绍和使用
- 4.1 DataLoader的介绍
- 4.2 DataLoader的核心接口
- 4.3 DataLoader加载现有数据集
- 4.4 DataLoader的使用
- 4.4.1 导入DataLoader
- 4.4.2 加载数据集
- 五、总结
Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据集。
Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素。
而DataLoader定义了按batch加载数据集的方法,它是一个实现了__iter__方法的可迭代对象,每次迭代输出一个batch的数据。
DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。
在绝大部分情况下,用户只需实现Dataset的__len__方法和 __getitem__方法,就可以轻松构建自己的数据集,并用默认数据集进行加载。
二、创建数据集常用的方法 2.1 使用 torch.utils.data.TensorDataset创建数据集利用TensorDataset类创建数据集可以参考文章:PyTorch实战案例(二)——利用PyTorch实现线性回归算法(进阶),这里不再进行详细介绍,创建完数据集后,可以利用torch.utils.data.random_split 将一个数据集分割成多份,常用于分割训练集,验证集和测试集。
2.2 使用torchvision.datasets.ImageFolder创建图片数据集使用 torchvision.datasets.ImageFolder创建图片数据集使用频率较低,这里不进行介绍,需要的可以自行查阅资料。
2.3 继承torch.utils.data.Dataset创建自定义数据集利用torch.utils.data.Dataset创建自定义数据集在下文第三部分进行介绍。
三、Dataset的介绍和使用 3.1 Dataset的介绍torch.utils.data.Dataset这样的抽象类(Abstract Class)可以用来创建数据集。学过面向对象的应该清楚,抽象类不能实例化,因此我们需要构造这个抽象类的子类来创建数据集,并且我们还可以定义自己的继承和重写方法。这其中最重要的就是len和getitem这两个函数,前者给出数据集的大小,后者是用于查找数据和标签。
3.2 Dataset的核心接口Dataset的核心接口代码如下:
class Dataset(object):
def __init__(self):
pass
def __len__(self):
raise NotImplementedError
def __getitem__(self,index):
raise NotImplementedError
3.3 Dataset的使用
3.3.1 导入Dataset类
首先我们需要引入Dataset这个抽象类:
from torch.utils.data import Dataset
3.3.2 创建Dataset的子类
首先对类进行初始化,定义数据和标签:
def __init__(self,img,label):
self.img = img
self.label = label
定义getitem方法获取数据集的内容和标签:
def __getitem__(self, idx):
return self.img[idx],self.label[idx]
定义len方法获取数据集的长度:
def __len__(self):
return self.img.shape[0]
3.3.3 实例化该类
实例化代码如下:(类名为MyData)
mydata = MyData(img,label)
四、DataLoader的介绍和使用
4.1 DataLoader的介绍
PyTorch用类torch.utils.data.DataLoader加载数据,并对数据进行采样,生成batch迭代器: torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False) 数据加载器 常用参数如下:
- dataset:加载数据的数据集;
- batch_size:每个batch要加载多少样本(默认为1);
- shuffle:是否对数据集进行打乱重新排列(默认为False,即不重新排列);
Dataset的核心接口代码如下:
class DataLoader(object):
def __init__(self,dataset,batch_size,collate_fn,shuffle = True,drop_last = False):
self.dataset = dataset
self.sampler =torch.utils.data.RandomSampler if shuffle else \
torch.utils.data.SequentialSampler
self.batch_sampler = torch.utils.data.BatchSampler
self.sample_iter = self.batch_sampler(
self.sampler(range(len(dataset))),
batch_size = batch_size,drop_last = drop_last)
def __next__(self):
indices = next(self.sample_iter)
batch = self.collate_fn([self.dataset[i] for i in indices])
return batch
4.3 DataLoader加载现有数据集
torchvision.datasets 是用来进行数据加载的,PyTorch团队在这个包中帮我们提前处理好了很多很多图片数据集,如:
- MNIST
- FashionMNIST
- COCO(用于图像标注和目标检测)(Captioning and Detection)
- LSUN Classification
- ImageFolder
- Imagenet-12
- CIFAR10 and CIRAR100
- STL10
代码示例:
#加载MNIST数据集
train_dataset=datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)
test_dataset=datasets.MNIST(root='./data',train=False,transform=transforms.ToTensor())
4.4 DataLoader的使用
4.4.1 导入DataLoader
首先我们需要引入DataLoader这个抽象类:
from torch.utils.data import DataLoader
4.4.2 加载数据集
加载数据集可以利用之前Dataset实例化后的类用作数据集进行加载,也可以利用第二部分介绍的相关方法作为数据集,加载代码如下:
dataloader= DataLoader(mydata,batch_size = 10,shuffle=True)
五、总结
本文介绍了一下利用Dataset和DataLoader构建数据集的基本原理,包括两种类的基础概念、核心接口、使用的基本流程等,下一篇文章将介绍一下利用Dataset和DataLoader构建数据集实例代码,敬请期待。 下一篇文章: PyTorch重难点(二)——利用Dataset和DataLoader构建数据集代码实战
参考:
- https://www.heywhale.com/mw/project/5f33d5c0af3980002cb83cfa
- https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
如果对你有所帮助,记得点个赞哟~