您当前的位置: 首页 > 

寒冰屋

暂无认证

  • 1浏览

    0关注

    2286博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

(四)训练运行Deep CycleGAN以进行移动风格迁移

寒冰屋 发布时间:2022-03-13 21:47:10 ,浏览量:1

目录

介绍

训练周期GAN

评估CycleGAN

季节更替CycleGAN

下一步

  • 下载项目代码 - 7.2 MB
介绍

在本系列文章中,我们将展示一个基于循环一致对抗网络(CycleGAN)的移动图像到图像转换系统。我们将构建一个CycleGAN,它可以执行不成对的图像到图像的转换,并向您展示一些有趣但具有学术深度的例子。我们还将讨论如何将这种使用TensorFlow和Keras构建的训练有素的网络转换为TensorFlow Lite 并用作移动设备上的应用程序。

我们假设您熟悉深度学习的概念,以及Jupyter Notebooks和TensorFlow。欢迎您下载项目代码。

在上一篇文章中,我们从头实现了CycleGAN。在本文中,我们将在Horse2zebra数据集上训练和测试网络并评估其性能。

训练周期GAN

是时候训练我们的CycleGAN以执行一些有趣的转换了,例如马到斑马,反之亦然。我们将首先设置检查点路径以保存最佳模型:

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

首先,我们将训练超过20个epoch,看看这是否足以获得可接受的结果。根据获得的结果,我们可能需要增加epochs的数量。即使您的训练结果看起来不错,预测可能仍然不太准确。因此,80到100个epoch更有可能让您获得完美的转换,但是这将需要3天以上的训练,除非您使用具有非常高规格的系统或付费的基于云的计算服务,例如AWS或Microsoft Azure。

EPOCHS = 20
def generate_images(model, test_input):
  prediction = model(test_input)
    
  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()


def train_step(real_x, real_y):
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.
    
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)
    
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)
    
    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)
  
  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)
  
  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)
  
  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))
  
  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))
  
  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))

上面的训练循环执行以下操作:

  • 获取预测
  • 计算损失
  • 使用反向传播计算梯度
  • 将梯度应用到优化器

在训练过程中,网络会从训练集中随机选择一张图像并将其与其转换后的版本一起显示,让我们可视化每个epoch后的性能变化,如下图所示。

for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n += 1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

评估CycleGAN

一旦CycleGAN训练完毕,我们就可以开始为它提供新图像并评估它在将马转换为斑马(反之亦然)方面的性能。

让我们在数据集中的图像上测试经过训练的CycleGAN,并可视化其泛化能力。我们将使用generate_images函数,该函数将选取一些图像,将它们传递给经过训练的网络,并显示转换结果。

def generate_images(model, test_input):
  prediction = model(test_input)
    
  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

现在,您可以选择任何测试图像并可视化转换结果:

for inp in test_horses.take(5):
  generate_images(generator_g, inp)

以下是网络仅训练20个epoch后获得的一些示例。对于这么短的训练,结果相当不错。您可以通过添加更多时期来改进它们。

季节更替CycleGAN

我们可以使用我们为不同任务设计的网络,例如日夜更替或季节更替。为了训练我们的季节更替网络,我们需要做的就是将训练数据集更改为summer2winter。

我们在上述数据集上训练了我们的网络80个epoch。看看结果。

下一步

在本文中,我们使用基于U-Net的生成器训练了CycleGAN。在下一篇文章中​​​​​​​,我们将向您展示如何实现基于残差的生成器并在医疗数据集上训练生成的CycleGAN。

https://www.codeproject.com/Articles/5304927/Training-a-Running-a-Deep-CycleGAN-for-Mobile-Styl

关注
打赏
1665926880
查看更多评论
立即登录/注册

微信扫码登录

0.0468s