您当前的位置: 首页 > 
  • 2浏览

    0关注

    417博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

风格迁移0-06:stylegan-源码无死角解读(2)-数据预处理process_reals

江南才尽,年少无知! 发布时间:2019-09-22 08:51:26 ,浏览量:2

以下链接是个人关于stylegan所有见解,如有错误欢迎大家指出,我会第一时间纠正,如有兴趣可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞奥!因为这是对我最大的鼓励。 风格迁移0-00:stylegan-目录-史上最全:https://blog.csdn.net/weixin_43013761/article/details/100895333

数据处理

在上篇博客中,我给大家注释了一下根目录下的training/training_loop.py文件,了解一下大致的流程,这小节我们就来讲解一下数据的预处理,即process_reals函数,该函数也存在于training/training_loop.py文件中,其实该函数花费了我好些心思才理解过来。:

def training_loop()
	......
    # 加载训练数据,其会把所有分辨率的数据都加载进来
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)
    ......
    ......
            # 获得训练数据图片和标签
            reals, labels = training_set.get_minibatch_tf()
            # 对训练数据的真实图片进行处理,主要把图片分成多个区域进行平滑,注意这里的reals包含多张图片,分别对应不同的分辨率,
            # 其实这里说是分辨率不太合适,总的来说,他们分辨率都是1024,但是平滑插值不一样.其不是用来训练的数据,是用来求损失用的,具体细节后面分析,也属于一个比较重要的地方
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)  

现在我们来看看process_reals函数:

def process_reals(x, lod, mirror_augment, drange_data, drange_net):
    """

    :param x: 该为输入的图片,(batch_size, 3, 1024, 1024)
    :param lod: 该值从零开始,随着训练图片的张数,更改变为0,1,2,3,4,5,6,7,8
    :param mirror_augment: # 是否进行镜像翻转
    :param drange_data: #数据动态变化的范围[0,255],输入
    :param drange_net: #数据动态变化的网络[-1,1],输出
    :return:
    """
    with tf.name_scope('ProcessReals'):
        with tf.name_scope('DynamicRange'):
            x = tf.cast(x, tf.float32)
            # 把原来的像素先缩小到2/255然后减去1
            x = misc.adjust_dynamic_range(x, drange_data, drange_net)

        if mirror_augment:
            with tf.name_scope('MirrorAugment'):
                s = tf.shape(x)
                # 随机产生(batch_size, 1, 1, 1)维度的数组,其值为0到1之间
                mask = tf.random_uniform([s[0], 1, 1, 1], 0.0, 1.0)
                # 对前面产生的像素,进行复制,复制之后的维度为[batch_size, 3, 1024, 1024]
                mask = tf.tile(mask, [1, s[1], s[2], s[3]])
                #小于0.5的返回原值,否则返回对第三维进行翻转之后的值(经过一位小伙伴的提醒,进行了修改)
                x = tf.where(mask             
关注
打赏
1592542134
查看更多评论
0.0516s