项目场景
PyTorch Lightning 对 PyTorch 做了进一步的封装,并集成了日志记录,分布式训练等工具,让我们能够把研究核心放在模型改进上而不是工程代码的编写。近期使用发现一个小问题,在此记录一下。
问题描述模型训练的时候很正常,但验证的时候报错:
TypeError: validation_step() takes 3 positional arguments but 4 were given
并且,测试的时候也会遇到类似的问题。
原因分析原来是我重写 LightningModule 的 validation_step
和 test_step
方法时没有指定 batch_idx
参数,虽然这个参数在方法中没有被使用,但是却会被隐式地调用。batch_idx
就是批数据的索引,例如打印训练进度条的时候肯定会被调用的。但如果不显式地指定,就是导致位置参数和关键字参数识别冲突,从而引发异常。
这是我原来的代码:
def validation_step(self, batch):
pass
def test_step(self, batch):
pass
加上 batch_idx 参数就行了:
def validation_step(self, batch, batch_idx):
pass
def test_step(self, batch, batch_idx):
pass
引用参考
https://github.com/PyTorchLightning/pytorch-lightning/issues/1034