您当前的位置: 首页 >  pytorch

wendy_ya

暂无认证

  • 3浏览

    0关注

    342博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

PyTorch重难点(二)——利用Dataset和DataLoader构建数据集代码实战

wendy_ya 发布时间:2021-11-13 17:39:10 ,浏览量:3

目录
  • 一、案例描述
  • 二、代码实战
    • 2.1 自定义数据和标签
    • 2.2 创建Dataset的子类
    • 2.3 Dataset子类实例化
    • 2.4 对数据集进行拆分
    • 2.5 利用DataLoader加载数据集
    • 2.6 输出数据集结果

上文PyTorch重难点(一)——利用Dataset和DataLoader构建数据集原理介绍中介绍了Dataset和DataLoader的基本原理,本文对案例进行代码实战介绍。

一、案例描述

原始程序中我首先构建了一个手势识别的数据集,包含手势图像和标签,手势图像是一共有1320个图像,其形状是300*300的单通道图像,其size为[1320, 300, 300],手势标签一共有1320个标签,其size为[1320],这里为了简单方便,自行随机定义一个类似结构的数据集,样本数量为50,大小为300 *300,对其进行创建为Dataset以及进行DataLoader的加载。

二、代码实战 2.1 自定义数据和标签

自定义一个样本数量为50,大小为300 *300的图像数据集:

# 自定义数据和标签
img=torch.rand(size=(50,1,300,300),dtype=torch.float32)#数据
label=torch.rand(size=(50,))#标签
2.2 创建Dataset的子类

创建一个名为MyData的Dataset子类:

#创建Dataset的子类
class MyData(Dataset):
    def __init__(self,img,label):
        self.img = img
        self.label = label
    def __getitem__(self, idx):
        return self.img[idx],self.label[idx]
    def __len__(self):
        return self.img.shape[0]
2.3 Dataset子类实例化

实例化MyData类:

mydata = MyData(img,label)
2.4 对数据集进行拆分

将数据集随机拆分为给定长度的新数据集,这里按照7:3的比例对50个样本的数据集进行拆分:

ds_train,ds_valid = random_split(mydata,[int(50*0.7),50-int(50*0.7)])
2.5 利用DataLoader加载数据集

分别对训练集和验证机进行加载,加载数据集代码如下:

dl_train = DataLoader(ds_train,batch_size = 10,shuffle=True)
dl_valid = DataLoader(ds_valid,batch_size = 10)

iter(train_iter).next()[0]语句可以输出一个batch_size的图像,这里是10,我们可以查看一下它的size:

iter(dl_train).next()[0].size()

运行结果: torch.Size([10, 1, 300, 300])

2.6 输出数据集结果

对每个batch_size进行打印输出,输出示例代码如下:

#对每个batch_size进行打印输出
for step, (data, label) in enumerate(dataloader):
    print('step is :', step)
    print('img is {}, label is {}'.format(img, label))

step输出结果如下: step is : 0 step is : 1 step is : 2 step is : 3

查看训练集的数据和标签的形状:

for features,labels in dl_train:
    print(features.shape)
    print(labels.shape)
    break   #运行一个batch后就break

运行结果: torch.Size([10, 1, 300, 300]) torch.Size([10])

当我们想取出features和对应的labels时,可以用如下代码实现:

# 表示输出数据
features[0]
# 表示输出标签
labels[0]

结果如下: 在这里插入图片描述

也可以查看训练集和验证机的长度:

print(len(ds_train))
print(len(ds_valid))

运行结果: 35 15

ok,以上便是本文的全部内容了,详细代码上文已经介绍过了,如果没有看明白,可以从链接自取:https://download.csdn.net/download/didi_ya/41704695

如果对你有所帮助,记得点个赞哟~

还有什么问题欢迎在评论区补充,感谢您的点赞与留言~

关注
打赏
1659256378
查看更多评论
立即登录/注册

微信扫码登录

0.0413s