您当前的位置: 首页 >  pytorch

Xavier Jiezou

暂无认证

  • 2浏览

    0关注

    394博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

记录 PyTorch Lightning 的一个坑

Xavier Jiezou 发布时间:2022-05-21 09:55:49 ,浏览量:2

项目场景

PyTorch Lightning 对 PyTorch 做了进一步的封装,并集成了日志记录,分布式训练等工具,让我们能够把研究核心放在模型改进上而不是工程代码的编写。近期使用发现一个小问题,在此记录一下。

问题描述

模型训练的时候很正常,但验证的时候报错:

TypeError: validation_step() takes 3 positional arguments but 4 were given

并且,测试的时候也会遇到类似的问题。

原因分析

原来是我重写 LightningModule 的 validation_steptest_step 方法时没有指定 batch_idx 参数,虽然这个参数在方法中没有被使用,但是却会被隐式地调用。batch_idx 就是批数据的索引,例如打印训练进度条的时候肯定会被调用的。但如果不显式地指定,就是导致位置参数和关键字参数识别冲突,从而引发异常。

解决方案

这是我原来的代码:

def validation_step(self, batch):
	pass

def test_step(self, batch):
    pass

加上 batch_idx 参数就行了:

def validation_step(self, batch, batch_idx):
	pass

def test_step(self, batch, batch_idx):
    pass
引用参考

https://github.com/PyTorchLightning/pytorch-lightning/issues/1034

关注
打赏
1661408149
查看更多评论
立即登录/注册

微信扫码登录

0.0546s