您当前的位置: 首页 >  pytorch

Better Bench

暂无认证

  • 2浏览

    0关注

    695博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

【Pytorch】Expected hidden[0] size (2, 136, 256), got [2, 256, 256]

Better Bench 发布时间:2021-07-21 17:06:09 ,浏览量:2

问题

我在使用pytorch的 LSTM (RNN) 构建多类文本分类网络时遇到此错误,网络结构没有问题,能够运行起来,但是运行到几个batch后就报错Expected hidden[0] size (2, 136, 256), got [2, 256, 256]

分析

该错误是由于的训练数据不能被批量大小整除造成的。前面的batch都是256个,但是最后一个batch不足256,只有136个。 假设训练数据有 100个,batch大小为 16,划分为6个batch,最后一个batch将只有 4 个(100%16 = 4)个。

解决方案

(1)方法一 修改batchsize,让数据集大小能整除batchsize (2)方法二 如果使用Dataloader,设置一个参数drop_last=True,会自动舍弃最后不足batchsize的batch

from torch.utils.data import DataLoader

train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size, drop_last=True)

参考:https://stackoverflow.com/questions/54878904/runtimeerror-expected-hidden0-size-2-20-256-got-2-50-256

关注
打赏
1665674626
查看更多评论
立即登录/注册

微信扫码登录

0.2010s