RuntimeError: each element in list of batch should be of equal size

IT之一小佬 · Published 2021-03-23 21:57:56


Sample code:

import os
import re
from torch.utils.data import Dataset, DataLoader

data_base_path = r'./aclImdb/'


#  1. Define the tokenize method
def tokenize(text):
    filters = ['!', '"', '#', '$', '%', '&', '\(', '\)', '\*', '\+', ',', '-', '\.', '/', ':', ';', '<', '=', '>', '\?', '@',
               '\[', '\\\\', '\]', '\^', '_', '`', '\{', '\|', '\}', '~', '\t', '\n', '\x97', '\x96', '”', '“']
    text = re.sub("<.*?>", " ", text, flags=re.S)            # strip HTML tags such as <br />
    text = re.sub("|".join(filters), " ", text, flags=re.S)  # replace punctuation and control characters with spaces
    return [i.strip() for i in text.split()]
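
A quick sanity check of the tokenizer (the sample sentence below is made up for illustration, it is not a file from the dataset):

print(tokenize("This movie was great!<br /><br />Loved every minute."))
# ['This', 'movie', 'was', 'great', 'Loved', 'every', 'minute']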


#  2. Prepare the dataset
class ImdbDataset(Dataset):
    def __init__(self, mode):
        super().__init__()
        if mode == "train":
            text_path = [os.path.join(data_base_path, i) for i in ["train/neg", "train/pos"]]
        else:
            text_path = [os.path.join(data_base_path, i) for i in ["test/neg", "test/pos"]]
        self.total_file_path_list = []
        for i in text_path:
            self.total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])

    def __getitem__(self, item):
        cur_path = self.total_file_path_list[item]
        cur_filename = os.path.basename(cur_path)
        label = int(cur_filename.split("_")[-1].split(".")[0]) - 1  # filenames look like id_rating.txt; extract the rating and map it to [0, 9]
        text = tokenize(open(cur_path, encoding="utf-8").read().strip())  # read the review and tokenize it
        return label, text

    def __len__(self):
        return len(self.total_file_path_list)


#  3. Instantiate the dataset and prepare the DataLoader
dataset = ImdbDataset(mode="train")
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)

#  4. Inspect the output
for idx, (label, text) in enumerate(dataloader):
    print("idx:", idx)
    print("label:", label)
    print("text:", text)
    break

Running this raises:

RuntimeError: each element in list of batch should be of equal size

Cause of the error:

The failing line is dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True). Changing batch_size=2 to batch_size=1 makes the error go away, because a batch containing a single sample has nothing to align across samples.

But how can we fix the error while keeping batch_size=2?

The solution is as follows:

The problem lies in the DataLoader's collate_fn parameter.

collate_fn defaults to torch's default_collate, whose job is to assemble the individual samples into a batch. For each field it tries to stack the values into a tensor (or, for lists, to align them element by element), which only works when every sample has the same size. Our tokenized reviews have different lengths, so the default default_collate fails.
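
The failure is easy to reproduce in isolation. A minimal sketch (the two toy samples are made up for illustration):

from torch.utils.data.dataloader import default_collate

batch = [(0, ["a", "short", "review"]),
         (1, ["a", "somewhat", "longer", "review"])]
default_collate(batch)
# RuntimeError: each element in list of batch should be of equal size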

Two ways to attack the problem:

  • Approach 1: first convert the texts into equal-length numeric sequences and check whether the result meets our needs; DataLoader never raised this error when used that way before (a sketch of this idea follows the list).
  • Approach 2: define a custom collate_fn and observe the result.
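
A minimal sketch of approach 1. The word2idx mapping, max_len, and the pad/unknown indices are assumptions for illustration, not part of the original post; the point is that once every sample is a fixed-length LongTensor, the default collate can stack the batch without complaint:

import torch

def text_to_sequence(tokens, word2idx, max_len=100, pad_idx=0, unk_idx=1):
    ids = [word2idx.get(t, unk_idx) for t in tokens][:max_len]  # map words to ids, truncate
    ids += [pad_idx] * (max_len - len(ids))                     # pad to a fixed length
    return torch.tensor(ids, dtype=torch.long)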

Here we take approach 2: define a custom collate_fn and look at the result:

def collate_fn(batch):
    #  batch is a list of tuples; each tuple is one __getitem__ result from the dataset
    batch = list(zip(*batch))                           # transpose into ([labels...], [texts...])
    labels = torch.tensor(batch[0], dtype=torch.int32)  # labels are scalars, so they stack fine
    texts = batch[1]                                    # keep the variable-length token lists as-is
    del batch
    return labels, texts
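
A quick check on a hand-made batch (the toy samples are an illustration):

batch = [(3, ["great", "movie"]), (7, ["not", "my", "thing"])]
labels, texts = collate_fn(batch)
print(labels)  # tensor([3, 7], dtype=torch.int32)
print(texts)   # (['great', 'movie'], ['not', 'my', 'thing'])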

Full code:

import os
import re
import torch
from torch.utils.data import Dataset, DataLoader

data_base_path = r'./aclImdb/'


#  1. Define the tokenize method
def tokenize(text):
    filters = ['!', '"', '#', '$', '%', '&', '\(', '\)', '\*', '\+', ',', '-', '\.', '/', ':', ';', '<', '=', '>', '\?', '@',
               '\[', '\\\\', '\]', '\^', '_', '`', '\{', '\|', '\}', '~', '\t', '\n', '\x97', '\x96', '”', '“']
    text = re.sub("<.*?>", " ", text, flags=re.S)            # strip HTML tags such as <br />
    text = re.sub("|".join(filters), " ", text, flags=re.S)  # replace punctuation and control characters with spaces
    return [i.strip() for i in text.split()]


#  2. Prepare the dataset
class ImdbDataset(Dataset):
    def __init__(self, mode):
        super().__init__()
        if mode == "train":
            text_path = [os.path.join(data_base_path, i) for i in ["train/neg", "train/pos"]]
        else:
            text_path = [os.path.join(data_base_path, i) for i in ["test/neg", "test/pos"]]
        self.total_file_path_list = []
        for i in text_path:
            self.total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])

    def __getitem__(self, item):
        cur_path = self.total_file_path_list[item]
        cur_filename = os.path.basename(cur_path)
        label = int(cur_filename.split("_")[-1].split(".")[0]) - 1  # filenames look like id_rating.txt; extract the rating and map it to [0, 9]
        text = tokenize(open(cur_path, encoding="utf-8").read().strip())  # read the review and tokenize it
        return label, text

    def __len__(self):
        return len(self.total_file_path_list)


def collate_fn(batch):
    #  batch is a list of tuples; each tuple is one __getitem__ result from the dataset
    batch = list(zip(*batch))                           # transpose into ([labels...], [texts...])
    labels = torch.tensor(batch[0], dtype=torch.int32)  # labels are scalars, so they stack fine
    texts = batch[1]                                    # keep the variable-length token lists as-is
    del batch
    return labels, texts


#  3. Instantiate the dataset and prepare the DataLoader
dataset = ImdbDataset(mode="train")
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

#  4. Inspect the output
for idx, (label, text) in enumerate(dataloader):
    print("idx:", idx)
    print("label:", label)
    print("text:", text)
    break

The loop now runs without error: it prints the batch index, a label tensor of shape (2,), and a tuple containing the two variable-length token lists.
