您当前的位置: 首页 > 
  • 3浏览

    0关注

    417博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

风格迁移0-10:stylegan-源码无死角解读(6)-loss损失函数详解

江南才尽,年少无知! 发布时间:2019-09-26 09:58:30 ,浏览量:3

以下链接是个人关于stylegan所有见解,如有错误欢迎大家指出,我会第一时间纠正,如有兴趣可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞奥!因为这是对我最大的鼓励。 风格迁移0-00:stylegan-目录-史上最全:https://blog.csdn.net/weixin_43013761/article/details/100895333

损失函数代码注释

通过debug模式,我们可以知道,作者github的代码使用的损失计算再stylegan-master\dnnlib\tflib\optimizer.py文件中,针对本人运行的代码,分别是下面两个函数: generate网络的损失计算:

def G_logistic_nonsaturating(G, D, opt, training_set, minibatch_size): # pylint: disable=unused-argument
    # 获得latents,即论文中的Z
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    labels = training_set.get_random_labels_tf(minibatch_size)
    # 把latents送入生成网络,得到输出图片
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    # 把生成网络生成的图片,输入到鉴别网络,得到鉴别结果
    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))
    # 然后通过逻辑回归
    loss = tf.nn.softplus(-fake_scores_out)  # -log(logistic(fake_scores_out))
    return loss

discriminator 网络的损失计算:

def D_logistic_simplegp(G, D, opt, training_set, minibatch_size, reals, labels, r1_gamma=10.0, r2_gamma=0.0): # pylint: disable=unused-argument
    # 随机获得latents
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    # 把latents送入生成网络生成图像
    fake_images_out = G.get_output_for(latents, labels, is_training=True)
    # 把真实图片送入判别网络,进行预测,得到其为真实图片的概率
    real_scores_out = fp32(D.get_output_for(reals, labels, is_training=True))
    # 把生成图片送入判别网络,得到其为真实图片的概率值
    fake_scores_out = fp32(D.get_output_for(fake_images_out, labels, is_training=True))

    real_scores_out = autosummary('Loss/scores/real', real_scores_out)
    fake_scores_out = autosummary('Loss/scores/fake', fake_scores_out)

    # 然后计算损失
    loss = tf.nn.softplus(fake_scores_out)  # -log(1 - logistic(fake_scores_out))
    loss += tf.nn.softplus(-real_scores_out)  # -log(logistic(real_scores_out)) # temporary pylint workaround # pylint: disable=invalid-unary-operand-type

    # 如果r1_gamma不为0,则对真实图片的损失进行缩放
    if r1_gamma != 0.0:
        with tf.name_scope('R1Penalty'):
            # 对loss进行缩放
            real_loss = opt.apply_loss_scaling(tf.reduce_sum(real_scores_out))
            # 取消loss缩放之后的影响
            real_grads = opt.undo_loss_scaling(fp32(tf.gradients(real_loss, [reals])[0]))
            # 对损失求平方和
            r1_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1,2,3])
            r1_penalty = autosummary('Loss/r1_penalty', r1_penalty)
        #
        loss += r1_penalty * (r1_gamma * 0.5)

    # 如果r2_gamma不为0,则对生成图片的损失进行缩放
    if r2_gamma != 0.0:
        with tf.name_scope('R2Penalty'):
            fake_loss = opt.apply_loss_scaling(tf.reduce_sum(fake_scores_out))
            fake_grads = opt.undo_loss_scaling(fp32(tf.gradients(fake_loss, [fake_images_out])[0]))
            r2_penalty = tf.reduce_sum(tf.square(fake_grads), axis=[1,2,3])
            r2_penalty = autosummary('Loss/r2_penalty', r2_penalty)
        loss += r2_penalty * (r2_gamma * 0.5)
    return loss
损失计算总结

上面有详细的注解,细节部分我就不说了。这里为大家简单的提一下他们之间的区别。 可以很明确的看到,G_logistic_nonsaturating计算的损失,都是生成图片的损失,因为他的目的十分的单纯,就是为了生成逼真的图片,所以只需要对生成的图片进行损失计算即可。 但是对于判别网络,他的目的是在于鉴别图片的真假。他不仅要判断出造假的图片,还要判断出真实的图片。无论是造假还是真实他都要进行损失计算。 这个就是生成网络和判别网络的区别。

一路走来,或许讲解得不是很详细,但是的确花费了不少心思。后面就是对stylegan网络应用的讲解,如:如何进行图片融合等等。

关注
打赏
1592542134
查看更多评论
立即登录/注册

微信扫码登录

0.0403s