您当前的位置: 首页 >  tensorflow

Better Bench

暂无认证

  • 3浏览

    0关注

    695博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

【北京大学】9 TensorFlow1.x的实现自定义Mnist数据集

Better Bench 发布时间:2020-12-21 18:15:25 ,浏览量:3

目录
  • 1 实现把任意图片放进训练好的网络进行测试
  • 2 实现制作数据
    • 2.1 简介
    • 2.2 生成tfrecords文件
    • 2.3 解析tfrecords文件
    • 2.4 生成自定义数据的完整代码
      • mnist_generateds.py文件
      • mnist_backward.py文件
      • mnist_test.py文件
  • 相关笔记

1 实现把任意图片放进训练好的网络进行测试

输入的图片是白底黑字的数字图片进行测试,测试前需要做两步 (1)转换图片矩阵大小为28*28符合网络的输入 (2)把图片的转换成白字黑底的黑白图片

mnist_app.py
import tensorflow as tf
import numpy as np
from PIL import Image
import mnist_backward
import mnist_forward
def restore_model(testPicArr):
    # 利用tf.Graph()复现之前定义的计算图
    with tf.Graph().as_default() as tg:
        x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
        # 调用mnist_forward文件中的前向传播过程forword()函数
        y = mnist_forward.forward(x, None)
        # 得到概率最大的预测值
        preValue = tf.argmax(y, 1)
        # 实例化具有滑动平均的saver对象
        variable_averages = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)
        with tf.Session() as sess:
            # 通过ckpt获取最新保存的模型
            ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                preValue = sess.run(preValue, feed_dict={x: testPicArr})
                return preValue
            else:
                print("No checkpoint file found")
                return -1
# 预处理,包括resize,转变灰度图,二值化
def pre_pic(picName):
    img = Image.open(picName)
    reIm = img.resize((28, 28), Image.ANTIALIAS)
    #把图片转换为灰度值图片
    im_arr = np.array(reIm.convert('L'))
    # 对图片做二值化处理(这样以滤掉噪声,另外调试中可适当调节阈值)
    threshold = 50
    # 模型的要求是黑底白字,但输入的图是白底黑字,所以需要对每个像素点的值改为255减去原值以得到互补的反色。
    for i in range(28):
        for j in range(28):
            im_arr[i][j] = 255 - im_arr[i][j]
            if (im_arr[i][j]             
关注
打赏
1665674626
查看更多评论
0.0426s