您当前的位置: 首页 >  Python

Xavier Jiezou

暂无认证

  • 2浏览

    0关注

    394博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

【Python】提取 MNIST 数据集中的图片到本地

Xavier Jiezou 发布时间:2022-01-18 22:48:53 ,浏览量:2

引言 | Introduction

MNIST 数据集是最经典的一个机器学习的数据集,常被视为图像分类问题的入门级数据。虽然 Python 的很多第三方包都对其进行了封装,但对于模型训练来说,我们常用的还是本地的数据。今天教大家如何提取 MNIST 数据到本地。

安装 | Install
pip install torchvision==0.11.2 tqdm==4.54.1
方法 | Method

我们利用 torchvision 包封装的 MNIST 数据集来提取图片到本地。MNIST 数据集是一个典型的多分类数据集,其中存放的是 7 万张手写数字的灰度图片(6 万训练和1 万测试),每张灰度图片的大小是 28×28。共有 10 类标签,分别对应数字 0-9。

一般来说,我们回将整个数据集按照 8:1:1 的比例(或其他比例)划分为 3 个子集:训练集,验证集和测试集。MNIST 官方只划分了两个子集,笔者自认为不太合理,故提取到本地时没有单独的创建 train 和 test 子文件夹来存放图片,不过在图片名称前加了 train 和 test 字样,以标识该图片是从哪个子数据集中获取的。

下方给出了具体的实现代码。您需要安装两个第三方 Python 包:torchvision 和 tqdm,然后给定数据在本地保存的文件夹路径,运行代码即可。

代码 | Code
import os
import shutil
from tqdm import tqdm
from torchvision import datasets
from concurrent.futures import ThreadPoolExecutor


def mnist_export(root: str = './data/minst'):
    """Export MNIST data to a local folder using multi-threading.

    Args:
        root (str, optional): Path to local folder. Defaults to './data/minst'.
    """
    for i in range(10):
        os.makedirs(os.path.join(root, f'./{i}'), exist_ok=True)
    split_list = ['train', 'test']
    data = {
        split: datasets.MNIST(
            root='./tmp',
            train=split == 'train',
            download=True
        ) for split in split_list
    }
    total = sum([len(data[split]) for split in split_list])
    with tqdm(total=total) as pbar:
        with ThreadPoolExecutor() as tp:
            for split in split_list:
                for index, (image, label) in enumerate(data[split]):
                    tmp = os.path.join(root, f'{label}/{split}_{index}.png')
                    tp.submit(image.save, tmp).add_done_callback(
                        lambda func: pbar.update()
                    )
    shutil.rmtree('./tmp')


if __name__ == '__main__':
    mnist_export('./data/minst')
参考 | References
  • https://docs.python.org/3/library/os.html
  • https://docs.python.org/3/library/shutil.html#shutil.rmtree
  • https://tqdm.github.io/
  • https://pytorch.org/vision/stable/datasets.html#mnist
  • https://docs.python.org/3/library/concurrent.futures.html
下载 | Download

https://cdn.jsdelivr.net/gh/XavierJiezou/pytorch-lstm-mnist@main/data/mnist.7z

关注
打赏
1661408149
查看更多评论
立即登录/注册

微信扫码登录

0.0737s