
Style Transfer 2-05: MUNIT (Multimodal Unsupervised) - Thorough Source Code Walkthrough (1) - Training Code Overview

江南才尽,年少无知! Published: 2020-04-18 10:56:35

The links below collect all of my personal notes on MUNIT (multimodal unsupervised) image style transfer. If you spot any mistakes, please point them out and I will correct them right away. Anyone interested is welcome to add me on WeChat at 17575010159 to discuss the technical details. If this post helps you, please remember to give it a like, as that is the greatest encouragement for me. My public account, with plenty of resources, is linked at the end of the article.

Style Transfer 2-00: MUNIT (Multimodal Unsupervised) - Table of Contents - Thorough Walkthrough

Configuration file

Before walking through the source code, let us first look at the configuration file configs/edges2shoes_folder.yaml. My annotations are as follows:

# How often (in iterations) to save output images during training
image_save_iter: 10000        # How often do you want to save output images during training
# How often (in iterations) to display output images during training
image_display_iter: 500       # How often do you want to display output images during training
# Number of images shown in each display
display_size: 16              # How many images do you want to display each time
# Save a model checkpoint once every this many iterations
snapshot_save_iter: 10000     # How often do you want to save trained models
# How often to log/print the training statistics
log_iter: 10                  # How often do you want to log the training stats



# optimization options
# Maximum number of training iterations
max_iter: 1000000             # maximum number of training iterations
# Size of each mini-batch
batch_size: 1                 # batch size
# Weight decay
weight_decay: 0.0001          # weight decay
# Adam optimizer parameters
beta1: 0.5                    # Adam parameter
beta2: 0.999                  # Adam parameter
# Weight initialization scheme
init: kaiming                 # initialization [gaussian/kaiming/xavier/orthogonal]
# Initial learning rate
lr: 0.0001                    # initial learning rate
# Learning rate decay policy
lr_policy: step               # learning rate scheduler
# Decay the learning rate once every this many iterations
step_size: 100000             # how often to decay learning rate
# Learning rate decay factor
gamma: 0.5                    # how much to decay learning rate
# Weight of the adversarial (GAN) loss
gan_w: 1                      # weight of adversarial loss
# Weight of the image reconstruction loss
recon_x_w: 10                 # weight of image reconstruction loss
# Weight of the style reconstruction loss
recon_s_w: 1                  # weight of style reconstruction loss
# Weight of the content reconstruction loss
recon_c_w: 1                  # weight of content reconstruction loss

recon_x_cyc_w: 0              # weight of explicit style augmented cycle consistency loss
# Weight of the domain-invariant perceptual (VGG) loss
vgg_w: 0                      # weight of domain-invariant perceptual loss

# model options
gen:
  # Number of filters in the bottommost conv layer of the generator
  dim: 64                     # number of filters in the bottommost layer
  # Number of filters in the MLP (fully connected) layers
  mlp_dim: 256                # number of filters in MLP
  # Length (dimensionality) of the style code
  style_dim: 8                # length of style code
  # Activation function
  activ: relu                 # activation function [relu/lrelu/prelu/selu/tanh]
  # Number of downsampling layers in the content encoder
  n_downsample: 2             # number of downsampling layers in content encoder
  # Number of residual blocks in the content encoder/decoder
  n_res: 4                    # number of residual blocks in content encoder/decoder
  # Padding type
  pad_type: reflect           # padding type [zero/reflect]

dis:
  # Number of filters in the bottommost conv layer of the discriminator
  dim: 64                     # number of filters in the bottommost layer
  # Normalization layer type
  norm: none                  # normalization layer [none/bn/in/ln]
  # Activation function
  activ: lrelu                # activation function [relu/lrelu/prelu/selu/tanh]
  # Number of layers in the discriminator D
  n_layer: 4                  # number of layers in D
  # How the GAN loss is computed
  gan_type: lsgan             # GAN loss [lsgan/nsgan]
  # Number of scales used by the multi-scale discriminator (see the sketch after the train.py listing)
  num_scales: 3               # number of scales
  # Padding type
  pad_type: reflect           # padding type [zero/reflect]

# data options
input_dim_a: 3                              # number of image channels [1/3]
input_dim_b: 3                              # number of image channels [1/3]
num_workers: 8                              # number of data loading threads
# Resize the shorter image side to this size first
new_size: 256                               # first resize the shortest image side to this size
# Height/width of the random crop
crop_image_height: 256                      # random crop image of this height
crop_image_width: 256                       # random crop image of this width
#data_root: ./datasets/edges2shoes/     # dataset folder location
# Root directory of the dataset
data_root: ../2.Dataset/edges2shoes        # dataset folder location
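For context, train.py (shown below) reads this file through get_config from utils.py, which simply parses the YAML into an ordinary Python dict, so values are accessed as config['max_iter'], config['gen']['dim'], and so on. A minimal sketch of that helper, assuming PyYAML is installed (the real utils.py may differ in the details):

import yaml

def get_config(config_path):
    # Parse the YAML file into a plain Python dict; nested sections such as
    # gen: and dis: become nested dicts (e.g. config['gen']['style_dim'] == 8).
    with open(config_path, 'r') as stream:
        return yaml.safe_load(stream)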

train.py code with annotations
"""
Copyright (C) 2018 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
from utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer
import argparse
from torch.autograd import Variable
from trainer import MUNIT_Trainer, UNIT_Trainer
import torch.backends.cudnn as cudnn
import torch
try:
    from itertools import izip as zip
except ImportError: # will be 3.x series
    pass
import os
import sys
import tensorboardX
import shutil
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/edges2shoes_folder.yaml', help='Path to the config file.')
    parser.add_argument('--output_path', type=str, default='.', help="outputs path")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT")
    opts = parser.parse_args()

    cudnn.benchmark = True

    # Load experiment setting (parse the YAML config)
    config = get_config(opts.config)

    # Maximum number of training iterations
    max_iter = config['max_iter']

    # Number of sample images to display
    display_size = config['display_size']

    # Path for the VGG model used by the domain-invariant perceptual loss
    config['vgg_model_path'] = opts.output_path

    # Setup model and data loader (build the trainer specified in the config)
    if opts.trainer == 'MUNIT':
        trainer = MUNIT_Trainer(config)
    elif opts.trainer == 'UNIT':
        trainer = UNIT_Trainer(config)
    else:
        sys.exit("Only support MUNIT|UNIT")
    trainer.cuda()

    # Create the training and test data loaders, and take display_size images from each loader, stacked into a batch;
    # these fixed images are reused throughout training as reference samples for visualizing the generated results
    train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(config)
    train_display_images_a = torch.stack([train_loader_a.dataset[i] for i in range(display_size)]).cuda()
    train_display_images_b = torch.stack([train_loader_b.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_a = torch.stack([test_loader_a.dataset[i] for i in range(display_size)]).cuda()
    test_display_images_b = torch.stack([test_loader_b.dataset[i] for i in range(display_size)]).cuda()

    # Setup logger and output folders
    # Derive the model name from the config file name
    model_name = os.path.splitext(os.path.basename(opts.config))[0]
    # Create a tensorboardX writer to record information during training
    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", model_name))
    # Prepare the output directories and copy the corresponding config .yaml file there
    output_directory = os.path.join(opts.output_path + "/outputs", model_name)
    checkpoint_directory, image_directory = prepare_sub_folder(output_directory)
    shutil.copy(opts.config, os.path.join(output_directory, 'config.yaml')) # copy config file to output folder

    # Start training; if opts.resume=True, resume from the previously saved checkpoints
    iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0
    while True:
        # Fetch a batch of training data from each domain
        for it, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)):
            # Update the learning rate
            trainer.update_learning_rate()
            # Move the data onto the GPU and detach it
            images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()

            with Timer("Elapsed time in update: %f"):

                # Main training code: update the discriminator, then the generator
                trainer.dis_update(images_a, images_b, config)
                trainer.gen_update(images_a, images_b, config)
                torch.cuda.synchronize()

            # Dump training stats to the tensorboard log
            if (iterations + 1) % config['log_iter'] == 0:
                print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
                write_loss(iterations, trainer, train_writer)

            # Write images: every image_save_iter iterations, save the generated samples to the image folder under a new file name, so the progress can be inspected
            if (iterations + 1) % config['image_save_iter'] == 0:
                with torch.no_grad():
                    test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b)
                    train_image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
                write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1))
                write_2images(train_image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1))
                # HTML
                write_html(output_directory + "/index.html", iterations + 1, config['image_save_iter'], 'images')

            # Write images: every image_display_iter iterations, write the current samples, overwriting the previous ones
            if (iterations + 1) % config['image_display_iter'] == 0:
                with torch.no_grad():
                    image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
                write_2images(image_outputs, display_size, image_directory, 'train_current')

            # Save network weights (model checkpoints)
            if (iterations + 1) % config['snapshot_save_iter'] == 0:
                trainer.save(checkpoint_directory, iterations)

            # Exit training once the maximum number of iterations has been reached
            iterations += 1
            if iterations >= max_iter:
                sys.exit('Finish training')
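A quick note on the num_scales: 3 entry under dis: in the config above: MUNIT's discriminator is multi-scale, i.e. inside trainer.dis_update each image is scored not only at full resolution but also at progressively downsampled resolutions, and the adversarial losses from the scales are combined. The snippet below is only a rough sketch of that idea, not the actual MsImageDis implementation in networks.py; the class name MultiScaleDiscriminator, the _make_net helper and the pooling parameters are my own assumptions.

import torch.nn as nn

class MultiScaleDiscriminator(nn.Module):
    # Rough sketch: run a small conv discriminator at num_scales resolutions,
    # halving the resolution between scales.
    def __init__(self, input_dim=3, dim=64, n_layer=4, num_scales=3):
        super().__init__()
        self.downsample = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
        self.nets = nn.ModuleList([self._make_net(input_dim, dim, n_layer)
                                   for _ in range(num_scales)])

    def _make_net(self, input_dim, dim, n_layer):
        layers, in_ch = [], input_dim
        for _ in range(n_layer):
            layers += [nn.Conv2d(in_ch, dim, 4, stride=2, padding=1), nn.LeakyReLU(0.2)]
            in_ch, dim = dim, dim * 2
        layers += [nn.Conv2d(in_ch, 1, 1)]      # one patch-wise score map per scale
        return nn.Sequential(*layers)

    def forward(self, x):
        outputs = []
        for net in self.nets:
            outputs.append(net(x))              # score the image at the current scale
            x = self.downsample(x)              # then downsample for the next scale
        return outputs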

As you can see, it is all quite simple and follows the usual recipe: 1. build the training and test data loaders; 2. construct the network model; 3. train iteratively; 4. evaluate and save the model. That is it for the overall structure; the next sections will walk through every detail of the code.

