import numpy as np
import time
# Toy dataset used by the example code in this module: the 784 consecutive
# integers 0..783 laid out row-major as 16 samples of shape 7x7.
# `targets` is built independently, so it holds the same values but is a
# separate array object.
inputs = np.arange(784).reshape(16, 7, 7)
targets = np.arange(784).reshape(16, 7, 7)
# Mini-batch generator for the case where there is only data (no labels).
def get_batchs(inputs=None, batch_size=None, shuffle=False, drop_last=True):
    """Yield successive mini-batches of ``inputs`` taken along axis 0.

    Args:
        inputs: A sequence supporting ``len()`` and integer-array ("fancy")
            indexing, e.g. a NumPy ndarray.
        batch_size: Number of samples per batch; must be a positive int.
        shuffle: If True, samples are visited in a random order produced by
            ``np.random.shuffle`` (NumPy's global random state).
        drop_last: If True (the default, matching the original behavior), a
            trailing batch with fewer than ``batch_size`` samples is
            discarded. If False, it is yielded as a final, shorter batch.

    Yields:
        ``inputs[idx]`` for each batch of indices ``idx``. For NumPy arrays
        this fancy indexing returns a copy, not a view.

    Raises:
        ValueError: If ``batch_size`` is not positive (checked when iteration
            starts, since this is a generator).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    n_samples = len(inputs)
    indices = np.arange(n_samples)
    if shuffle:
        np.random.shuffle(indices)
    # With drop_last, stop before any start index that cannot fill a whole
    # batch; otherwise walk to the end and allow a short final batch.
    stop = n_samples - batch_size + 1 if drop_last else n_samples
    for start_idx in range(0, stop, batch_size):
        excerpt = indices[start_idx:start_idx + batch_size]
        yield inputs[excerpt]
# Demo: print each shuffled batch of 10 samples from the toy `inputs`.
# Guarded so that importing this module does not print anything; running
# the file as a script produces the same output as before.
if __name__ == "__main__":
    for batch in get_batchs(inputs, 10, True):
        print(batch)
# Mini-batch generator for the case where there is data AND labels.
def get_batch(inputs=None, targets=None, batch_size=None, shuffle=False,
              drop_last=True):
    """Yield aligned mini-batches of ``(inputs, targets)`` along axis 0.

    The same index selection is applied to both arrays, so sample ``i`` of
    an input batch always corresponds to sample ``i`` of its target batch,
    including when ``shuffle`` is True.

    Args:
        inputs: A sequence supporting ``len()`` and integer-array ("fancy")
            indexing, e.g. a NumPy ndarray.
        targets: Same requirements as ``inputs``; must have the same length.
        batch_size: Number of samples per batch; must be a positive int.
        shuffle: If True, samples are visited in a random order produced by
            ``np.random.shuffle`` (NumPy's global random state).
        drop_last: If True (the default, matching the original behavior), a
            trailing batch with fewer than ``batch_size`` samples is
            discarded. If False, it is yielded as a final, shorter batch.

    Yields:
        ``(inputs[idx], targets[idx])`` for each batch of indices ``idx``.

    Raises:
        ValueError: If the lengths of ``inputs`` and ``targets`` differ, or
            ``batch_size`` is not positive (checked when iteration starts,
            since this is a generator).
    """
    # Explicit check instead of `assert`: asserts are stripped under `-O`,
    # which would let mismatched arrays fail later or silently truncate.
    if len(inputs) != len(targets):
        raise ValueError(
            "inputs and targets must have the same length "
            f"(got {len(inputs)} and {len(targets)})"
        )
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    n_samples = len(inputs)
    indices = np.arange(n_samples)
    if shuffle:
        np.random.shuffle(indices)
    # With drop_last, stop before any start index that cannot fill a whole
    # batch; otherwise walk to the end and allow a short final batch.
    stop = n_samples - batch_size + 1 if drop_last else n_samples
    for start_idx in range(0, stop, batch_size):
        excerpt = indices[start_idx:start_idx + batch_size]
        yield inputs[excerpt], targets[excerpt]
# Example usage with labels (no shuffling):
# for a, b in get_batch(inputs, targets, 10, False):
#     print(a)
#     print(b)
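# Source article title: "Tensorflow训练中产生batch"
# (Generating mini-batches for TensorFlow training).
# Trailing page-UI text "关注" (Follow) and "打赏" (Tip) was scraped along
# with the code; it is kept here only as a comment so the file stays valid Python.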