您当前的位置: 首页 >  ar

风间琉璃•

暂无认证

  • 4浏览

    0关注

    337博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

CIFAR10数据下载失败解决办法

风间琉璃• 发布时间:2022-02-11 20:06:08 ,浏览量:4

将数据集从官网下载下来放在.keras/ datasets目录下是不行的,因为在load_data里面指定了路径来自官网,前不可修改。自定义加载数据集

import numpy as np
import os

def load_batch(file):
    import pickle
    with open(file, 'rb') as fo:
        d = pickle.load(fo, encoding='bytes')
        d_decoded = {}
        for k, v in d.items():
            d_decoded[k.decode('utf8')] = v
        d = d_decoded
        data = d['data']
        labels = d['labels']
        data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels

def load_data(path = r'C:\Users\xiaochao\.keras\datasets\cifar-10-batches-py'):
    """Loads CIFAR10 dataset.
    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    from keras import backend as K

    num_train_samples = 50000

    x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')
    y_train = np.empty((num_train_samples,), dtype='uint8')

    for i in range(1, 6):
        fpath = os.path.join(path, 'data_batch_' + str(i))
        (x_train[(i - 1) * 10000: i * 10000, :, :, :],
         y_train[(i - 1) * 10000: i * 10000]) = load_batch(fpath)

    fpath = os.path.join(path, 'test_batch')
    x_test, y_test = load_batch(fpath)

    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))

    if K.image_data_format() == 'channels_last':
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)

    return (x_train, y_train), (x_test, y_test)

将它进行模块保存在load文件中

加载数据集:

(x, y), (x_val, y_val) =load.load_data()

关注
打赏
1665385461
查看更多评论
立即登录/注册

微信扫码登录

0.0365s