您当前的位置: 首页 >  pytorch

Xavier Jiezou

暂无认证

  • 0浏览

    0关注

    394博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

【pytorch】使用torch.utils.data.random_split()划分数据集

Xavier Jiezou 发布时间:2021-04-05 17:48:47 ,浏览量:0

写在前面

不用自己写划分数据集的函数,pytorch已经给我们封装好了,那就是torch.utils.data.random_split()

用法详解

torch.utils.data.random_split(dataset, lengths, generator=)

描述

随机将一个数据集分割成给定长度的不重叠的新数据集。可选择固定生成器以获得可复现的结果(效果同设置随机种子)。

参数
  • dataset (Dataset) – 要划分的数据集。
  • lengths (sequence) – 要划分的长度。
  • generator (Generator) – 用于随机排列的生成器。
示例

代码:

import torch
from torch.utils.data import random_split
dataset = range(10)
train_dataset, test_dataset = random_split(
    dataset=dataset,
    lengths=[7, 3],
    generator=torch.Generator().manual_seed(0)
)
print(list(train_dataset))
print(list(test_dataset))

输出:

[4, 1, 7, 5, 3, 9, 0]
[8, 6, 2]

torch.Generator().manual_seed(0)torch.manual_seed(0)的效果相同,我们验证一下。

代码:

import torch
from torch.utils.data import random_split
dataset = range(10)
torch.manual_seed(0)
train_dataset, test_dataset = random_split(
    dataset=dataset,
    lengths=[7, 3]
)
print(list(train_dataset))
print(list(test_dataset))

输出:

[4, 1, 7, 5, 3, 9, 0]
[8, 6, 2]
引用参考

https://pytorch.org/docs/stable/data.html#torch.utils.data.random_split

关注
打赏
1661408149
查看更多评论
立即登录/注册

微信扫码登录

0.0448s