The following is my personal set of notes on MUNIT (Multimodal Unsupervised image-to-image translation) for image style transfer. If you spot any mistakes, please point them out and I will correct them as soon as possible.
Style Transfer 2-00: MUNIT (Multimodal Unsupervised) - Series Index - a complete, no-blind-spots walkthrough
Configuration file
Before walking through the source code, let's first look at the config file configs/edges2shoes_folder.yaml. My annotations are as follows:
# how often (in iterations) to save output images during training
image_save_iter: 10000        # How often do you want to save output images during training
# how often (in iterations) to display output images during training
image_display_iter: 500       # How often do you want to display output images during training
# number of images displayed each time
display_size: 16              # How many images do you want to display each time
# save a model snapshot every this many iterations
snapshot_save_iter: 10000     # How often do you want to save trained models
# how often to log the training stats
log_iter: 10                  # How often do you want to log the training stats

# optimization options
# maximum number of training iterations
max_iter: 1000000             # maximum number of training iterations
# batch size
batch_size: 1                 # batch size
# weight decay
weight_decay: 0.0001          # weight decay
# Adam optimizer parameters
beta1: 0.5                    # Adam parameter
beta2: 0.999                  # Adam parameter
# weight initialization scheme
init: kaiming                 # initialization [gaussian/kaiming/xavier/orthogonal]
# initial learning rate
lr: 0.0001                    # initial learning rate
# learning rate decay policy
lr_policy: step               # learning rate scheduler
# decay the learning rate every this many iterations
step_size: 100000             # how often to decay learning rate
# learning rate decay factor
gamma: 0.5                    # how much to decay learning rate
# weight of the adversarial (GAN) loss
gan_w: 1                      # weight of adversarial loss
# weight of the image reconstruction loss
recon_x_w: 10                 # weight of image reconstruction loss
# weight of the style reconstruction loss
recon_s_w: 1                  # weight of style reconstruction loss
# weight of the content reconstruction loss
recon_c_w: 1                  # weight of content reconstruction loss
recon_x_cyc_w: 0              # weight of explicit style augmented cycle consistency loss
# weight of the domain-invariant perceptual loss
vgg_w: 0                      # weight of domain-invariant perceptual loss

# model options
gen:
  # number of filters in the first (bottommost) conv layer
  dim: 64                     # number of filters in the bottommost layer
  # number of filters in the MLP
  mlp_dim: 256                # number of filters in MLP
  # length of the style code
  style_dim: 8                # length of style code
  # activation function
  activ: relu                 # activation function [relu/lrelu/prelu/selu/tanh]
  # number of downsampling layers in the content encoder
  n_downsample: 2             # number of downsampling layers in content encoder
  # number of residual blocks in the content encoder/decoder
  n_res: 4                    # number of residual blocks in content encoder/decoder
  # padding type
  pad_type: reflect           # padding type [zero/reflect]
dis:
  # number of filters in the first (bottommost) conv layer
  dim: 64                     # number of filters in the bottommost layer
  # normalization layer type
  norm: none                  # normalization layer [none/bn/in/ln]
  # activation function
  activ: lrelu                # activation function [relu/lrelu/prelu/selu/tanh]
  # number of layers in the discriminator
  n_layer: 4                  # number of layers in D
  # type of GAN loss
  gan_type: lsgan             # GAN loss [lsgan/nsgan]
  # number of scales used by the multi-scale discriminator
  num_scales: 3               # number of scales
  # padding type
  pad_type: reflect           # padding type [zero/reflect]

# data options
input_dim_a: 3                # number of image channels [1/3]
input_dim_b: 3                # number of image channels [1/3]
num_workers: 8                # number of data loading threads
# resize the shortest image side to this size
new_size: 256                 # first resize the shortest image side to this size
# height/width of the random crop
crop_image_height: 256        # random crop image of this height
crop_image_width: 256         # random crop image of this width
#data_root: ./datasets/edges2shoes/    # dataset folder location
# root directory of the dataset
data_root: ../2.Dataset/edges2shoes   # dataset folder location
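A quick note on the six loss weights above: they are the coefficients of the weighted sum that the generator update optimises. Written out in the notation of the MUNIT paper (the exact term names used in trainer.py will be confirmed later in this series, so take this as a rough summary rather than the code itself):

$$
\mathcal{L}_{\mathrm{total}} = \mathtt{gan\_w}\,\mathcal{L}_{\mathrm{GAN}} + \mathtt{recon\_x\_w}\,\mathcal{L}^{x}_{\mathrm{recon}} + \mathtt{recon\_s\_w}\,\mathcal{L}^{s}_{\mathrm{recon}} + \mathtt{recon\_c\_w}\,\mathcal{L}^{c}_{\mathrm{recon}} + \mathtt{recon\_x\_cyc\_w}\,\mathcal{L}_{\mathrm{cyc}} + \mathtt{vgg\_w}\,\mathcal{L}_{\mathrm{vgg}}
$$

With the values in this config the image reconstruction term dominates (weight 10), while the explicit cycle-consistency and VGG perceptual terms are switched off (weight 0).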
Annotated train.py
"""
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
from utils import get_all_data_loaders, prepare_sub_folder, write_html, write_loss, get_config, write_2images, Timer
import argparse
from torch.autograd import Variable
from trainer import MUNIT_Trainer, UNIT_Trainer
import torch.backends.cudnn as cudnn
import torch
try:
from itertools import izip as zip
except ImportError: # will be 3.x series
pass
import os
import sys
import tensorboardX
import shutil
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='configs/edges2shoes_folder.yaml', help='Path to the config file.')
parser.add_argument('--output_path', type=str, default='.', help="outputs path")
parser.add_argument("--resume", action="store_true")
parser.add_argument('--trainer', type=str, default='MUNIT', help="MUNIT|UNIT")
opts = parser.parse_args()
cudnn.benchmark = True
# Load experiment setting,获取环境配置
config = get_config(opts.config)
# 最大的迭代次数
max_iter = config['max_iter']
# 显示图片大小
display_size = config['display_size']
# vgg模型的路径
config['vgg_model_path'] = opts.output_path
# Setup model and data loader, 根据配置创建模型
if opts.trainer == 'MUNIT':
trainer = MUNIT_Trainer(config)
elif opts.trainer == 'UNIT':
trainer = UNIT_Trainer(config)
else:
sys.exit("Only support MUNIT|UNIT")
trainer.cuda()
# 创建训练以及测试得数据迭代器,同时取出对每个迭代器取出display_size张图片,水平拼接到一起,
# 后续会一直拿这些图片作为生成图片的演示,当作一个标本即可
train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(config)
train_display_images_a = torch.stack([train_loader_a.dataset[i] for i in range(display_size)]).cuda()
train_display_images_b = torch.stack([train_loader_b.dataset[i] for i in range(display_size)]).cuda()
test_display_images_a = torch.stack([test_loader_a.dataset[i] for i in range(display_size)]).cuda()
test_display_images_b = torch.stack([test_loader_b.dataset[i] for i in range(display_size)]).cuda()
# Setup logger and output folders, 设置打印信息以及输出目录
# 获得模型的名字
model_name = os.path.splitext(os.path.basename(opts.config))[0]
# 创建一个 tensorboardX,记录训练过程中的信息
train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", model_name))
# 准备并且创建好输出目录,同时拷贝对应的config.yaml文件
output_directory = os.path.join(opts.output_path + "/outputs", model_name)
checkpoint_directory, image_directory = prepare_sub_folder(output_directory)
shutil.copy(opts.config, os.path.join(output_directory, 'config.yaml')) # copy config file to output folder
# Start training,开始训练模型,如果设置opts.resume=Ture,表示接着之前得训练
iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0
while True:
# 获取训练数据
for it, (images_a, images_b) in enumerate(zip(train_loader_a, train_loader_b)):
# 更新学习率,
trainer.update_learning_rate()
# 指定数据存储计算的设备
images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()
with Timer("Elapsed time in update: %f"):
# Main training code,主要的训练代码
trainer.dis_update(images_a, images_b, config)
trainer.gen_update(images_a, images_b, config)
torch.cuda.synchronize()
# Dump training stats in log file,记录训练过程中的信息
if (iterations + 1) % config['log_iter'] == 0:
print("Iteration: %08d/%08d" % (iterations + 1, max_iter))
write_loss(iterations, trainer, train_writer)
# Write images,到达指定次数后,把生成的样本图片写入到输出文件夹,方便观察生成效果,重新保存
if (iterations + 1) % config['image_save_iter'] == 0:
with torch.no_grad():
test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b)
train_image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1))
write_2images(train_image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1))
# HTML
write_html(output_directory + "/index.html", iterations + 1, config['image_save_iter'], 'images')
# Write images,到达指定次数后,把生成的样本图片写入到输出文件夹,方便观察生成效果,覆盖上一次结果
if (iterations + 1) % config['image_display_iter'] == 0:
with torch.no_grad():
image_outputs = trainer.sample(train_display_images_a, train_display_images_b)
write_2images(image_outputs, display_size, image_directory, 'train_current')
# Save network weights, 保存训练的模型
if (iterations + 1) % config['snapshot_save_iter'] == 0:
trainer.save(checkpoint_directory, iterations)
# 如果超过最大迭代次数,则退出训练
iterations += 1
if iterations >= max_iter:
sys.exit('Finish training')
It is really quite simple and follows the usual pattern: 1. build the training and test data loader iterators; 2. construct the network model; 3. train iteratively; 4. evaluate and save the model. That is it for the brief overview of the overall structure; the next subsections will walk through every detail of the code.
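Before that, it is worth previewing two helpers that train.py relies on but that are defined elsewhere: get_config from utils.py and the learning-rate handling behind trainer.update_learning_rate(). The sketch below is only my rough stand-in, under the assumption that get_config boils down to yaml.safe_load and that lr_policy: step maps to torch.optim.lr_scheduler.StepLR; the real implementations in utils.py and trainer.py are covered in later posts.

import yaml
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

# The optimisation-related subset of edges2shoes_folder.yaml, inlined here so this
# sketch runs on its own (hypothetical stand-in for utils.get_config).
CFG = yaml.safe_load("""
lr: 0.0001
beta1: 0.5
beta2: 0.999
weight_decay: 0.0001
lr_policy: step
step_size: 100000
gamma: 0.5
""")

# A dummy module, only so that there are parameters to optimise in this sketch.
model = nn.Conv2d(3, 64, kernel_size=7, padding=3)

# Adam configured from the config values (lr / beta1 / beta2 / weight_decay).
opt = optim.Adam(model.parameters(), lr=CFG['lr'],
                 betas=(CFG['beta1'], CFG['beta2']),
                 weight_decay=CFG['weight_decay'])

# lr_policy: step  ->  multiply the learning rate by gamma every step_size iterations,
# i.e. 1e-4 -> 5e-5 after 100000 iterations, then 2.5e-5 after 200000, and so on.
sched = lr_scheduler.StepLR(opt, step_size=CFG['step_size'], gamma=CFG['gamma'])

for it in range(3):   # stand-in for the real training loop
    opt.step()        # real code: trainer.dis_update(...) / trainer.gen_update(...)
    sched.step()      # real code: trainer.update_learning_rate()

In other words, with lr: 0.0001, step_size: 100000 and gamma: 0.5, the learning rate simply halves every 100,000 iterations.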