RuntimeError: each element in list of batch should be of equal size
Sample code:
import os
import re
from torch.utils.data import Dataset, DataLoader
data_base_path = r'./aclImdb/'
# 1. Define the tokenization function
def tokenize(text):
    # Regex-escaped punctuation and control characters to strip from the text
    filters = ['!', '"', '#', '$', '%', '&', '\(', '\)', '\*', '\+', ',', '-', '\.', '/', ':', ';', '<', '=', '>', '\?', '@',
               '\[', '\\\\', '\]', '^', '_', '`', '\{', '\|', '\}', '~', '\t', '\n', '\x97', '\x96', '”', '“']
    text = re.sub("<.*?>", " ", text, flags=re.S)            # strip HTML tags such as <br />
    text = re.sub("|".join(filters), " ", text, flags=re.S)  # replace punctuation with spaces
    return [i.strip() for i in text.split()]
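# A hedged example of what tokenize produces (the sample string is ours,
# not from the original post):
#   tokenize("This movie is great!<br />Loved it.")
#   -> ['This', 'movie', 'is', 'great', 'Loved', 'it']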
# 2. Prepare the dataset
class ImdbDataset(Dataset):
    def __init__(self, mode):
        super().__init__()
        if mode == "train":
            text_path = [os.path.join(data_base_path, i) for i in ["train/neg", "train/pos"]]
        else:
            text_path = [os.path.join(data_base_path, i) for i in ["test/neg", "test/pos"]]
        self.total_file_path_list = []
        for i in text_path:
            self.total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])

    def __getitem__(self, item):
        cur_path = self.total_file_path_list[item]
        cur_filename = os.path.basename(cur_path)
        # Parse the rating out of the filename and shift it from [1-10] into [0-9]
        label = int(cur_filename.split("_")[-1].split(".")[0]) - 1
        text = tokenize(open(cur_path, encoding="utf-8").read().strip())  # tokenize the review text
        return label, text

    def __len__(self):
        return len(self.total_file_path_list)
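# For an IMDB file named like "123_8.txt", __getitem__ returns
# (7, ['This', 'movie', ...]): the 1-10 rating in the filename becomes a
# label in [0, 9], and the review becomes a variable-length token list.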
# 3. Instantiate the dataset and prepare the DataLoader
dataset = ImdbDataset(mode="train")
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)
# 4. Inspect the output
for idx, (label, text) in enumerate(dataloader):
    print("idx:", idx)
    print("label:", label)
    print("text:", text)
    break
Run result: RuntimeError: each element in list of batch should be of equal size
Cause of the error:
The line dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True) turns out to be the culprit: changing batch_size=2 to batch_size=1 makes the error go away and the loop runs normally.
But what if we want to keep batch_size=2? How do we fix the error then?
The solution is as follows:
The root cause lies in the DataLoader parameter collate_fn. Its default value is PyTorch's built-in default_collate, and collate_fn's job is to merge each batch of samples into a single batch; the default default_collate fails here because the token lists returned by __getitem__ vary in length, while it insists that every list in a given batch position be the same size.
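A minimal reproduction sketch (ours, not from the original post) makes the failure concrete. In recent PyTorch versions (1.11+) default_collate is importable from torch.utils.data; it transposes the batch and, for list elements, requires every list to have the same length:

from torch.utils.data import default_collate

# Two samples whose token lists have different lengths, mimicking two reviews:
default_collate([(0, ["a", "b"]), (1, ["c", "d", "e"])])
# -> RuntimeError: each element in list of batch should be of equal size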
Ideas for solving the problem:
- Option 1: first convert the text into numeric sequences and check whether the result meets the requirement; DataLoader did not raise a similar error before when fed numeric data (a rough sketch follows this list)
- Option 2: define a custom collate_fn and inspect the result
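For reference, option 1 might look like this rough sketch (our illustration; vocab and MAX_LEN are hypothetical names): once every sample is an equal-length list of ids, the default default_collate can stack the batch without complaint.

MAX_LEN = 200  # hypothetical fixed sequence length

def to_padded_ids(tokens, vocab, pad_id=0, unk_id=1):
    # Map tokens to ids through a (hypothetical) vocab dict, truncate to
    # MAX_LEN, then right-pad with pad_id so every sample is the same length.
    ids = [vocab.get(t, unk_id) for t in tokens[:MAX_LEN]]
    ids += [pad_id] * (MAX_LEN - len(ids))
    return ids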
Here we take option 2: define a custom collate_fn and see what happens:
def collate_fn(batch):
    # batch is a list of tuples, each tuple being one result of the dataset's __getitem__
    batch = list(zip(*batch))
    labels = torch.tensor(batch[0], dtype=torch.int32)
    texts = batch[1]
    del batch
    return labels, texts
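The design choice here: the integer labels are uniform, so they can be stacked into a tensor right away, while the variable-length token lists are passed through untouched as a tuple, deferring any padding or numericalization to a later step.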
Full code:
import os
import re
import torch
from torch.utils.data import Dataset, DataLoader
data_base_path = r'./aclImdb/'
# 1. Define the tokenization function
def tokenize(text):
    # Regex-escaped punctuation and control characters to strip from the text
    filters = ['!', '"', '#', '$', '%', '&', '\(', '\)', '\*', '\+', ',', '-', '\.', '/', ':', ';', '<', '=', '>', '\?', '@',
               '\[', '\\\\', '\]', '^', '_', '`', '\{', '\|', '\}', '~', '\t', '\n', '\x97', '\x96', '”', '“']
    text = re.sub("<.*?>", " ", text, flags=re.S)            # strip HTML tags such as <br />
    text = re.sub("|".join(filters), " ", text, flags=re.S)  # replace punctuation with spaces
    return [i.strip() for i in text.split()]
# 2. Prepare the dataset
class ImdbDataset(Dataset):
    def __init__(self, mode):
        super().__init__()
        if mode == "train":
            text_path = [os.path.join(data_base_path, i) for i in ["train/neg", "train/pos"]]
        else:
            text_path = [os.path.join(data_base_path, i) for i in ["test/neg", "test/pos"]]
        self.total_file_path_list = []
        for i in text_path:
            self.total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])

    def __getitem__(self, item):
        cur_path = self.total_file_path_list[item]
        cur_filename = os.path.basename(cur_path)
        # Parse the rating out of the filename and shift it from [1-10] into [0-9]
        label = int(cur_filename.split("_")[-1].split(".")[0]) - 1
        text = tokenize(open(cur_path, encoding="utf-8").read().strip())  # tokenize the review text
        return label, text

    def __len__(self):
        return len(self.total_file_path_list)
def collate_fn(batch):
    # batch is a list of tuples, each tuple being one result of the dataset's __getitem__
    batch = list(zip(*batch))
    labels = torch.tensor(batch[0], dtype=torch.int32)
    texts = batch[1]
    del batch
    return labels, texts
# 3. Instantiate the dataset and prepare the DataLoader
dataset = ImdbDataset(mode="train")
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
# 4. Inspect the output
for idx, (label, text) in enumerate(dataloader):
    print("idx:", idx)
    print("label:", label)
    print("text:", text)
    break
Run result:
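With the custom collate_fn in place the batch now goes through: label comes out as a length-2 int32 tensor (e.g. tensor([2, 8], dtype=torch.int32)) and text as a tuple of two variable-length token lists; the exact values depend on the shuffled samples.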