您当前的位置: 首页 >  tensorflow

星夜孤帆

暂无认证

  • 2浏览

    0关注

    626博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

Tensorflow训练中产生batch

星夜孤帆 发布时间:2018-10-18 16:39:13 ,浏览量:2

import numpy as np
import time
inputs = np.arange(0,784).reshape(-1,7,7)
targets = np.arange(0,784).reshape(-1,7,7)


# 仅有数据时
def get_batchs(inputs=None, batch_size=None, shuffle=False):
  indices = np.arange(len(inputs))
  if shuffle:
    np.random.shuffle(indices)
  for start_idx in range(0,len(inputs)-batch_size+1, batch_size):
    if shuffle:
      excerpt = indices[start_idx:start_idx + batch_size]
    else:
      excerpt = indices[start_idx:start_idx + batch_size]
      
    yield inputs[excerpt]
    
for batch in get_batchs(inputs,10,True):
  print(batch)
  
  
# 有数据有label时
def get_batch(inputs=None, targets=None, batch_size=None, shuffle=False):
  
  assert len(inputs) == len(targets)
  indices = np.arange(len(inputs))
  if shuffle:
    np.random.shuffle(indices)
  # start_idx为batch_size个数
  for start_idx in range(0, len(inputs) -batch_size + 1, batch_size):
    if shuffle:
      excerpt = indices[start_idx:start_idx + batch_size]
#       print(excerpt)
    else:
      excerpt = indices[start_idx:start_idx + batch_size]
#       print(excerpt)
    yield inputs[excerpt] , targets[excerpt]

# for a,b in get_batch(inputs, targets , 10, False):
#   print(a)
#   print(b)

关注
打赏
1636984416
查看更多评论
立即登录/注册

微信扫码登录

0.0392s