您当前的位置: 首页 >  pytorch
  • 1浏览

    0关注

    483博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

Pytorch转换为onnx模型:ShapeMatchingGan初体验

高精度计算机视觉 发布时间:2022-04-28 14:25:35 ,浏览量:1

源码:

GitHub - VITA-Group/ShapeMatchingGAN: [ICCV 2019, Oral] Controllable Artistic Text Style Transfer via Shape-Matching GAN

介绍:

ICCV 2019 开源论文 | ShapeMatchingGAN:打造炫酷动态的艺术字 | 机器之心 

文章的核心是网络结构的设计,没有太多难以理解的地方,大致摘录如下。

双向形状匹配策略

ShapeMatchingGAN 的首要目的是学会文字的变形。不同于纹理尺度、风格强度等可用超参描述的特征,文字变形难以定义与建模,同时也没有对应的数据集支撑。为了解决这个问题,文章提出了双向形状匹配策略:

整体思路是比较直观理解的。分为两个阶段,第一个阶段(反向结构迁移),提取风格图的结构,反向将文字的形状风格迁移到结构图上,获得简化的结构图。第二个阶段(正向风格迁移),正向学习该上述过程的逆过程,即学习将简化的结构映射到原始结构再进一步映射回风格图。这样网络就学会了为文字边缘增添风格图的形状特征和渲染纹理。

但是我们还面临两个挑战,首先,如何在风格图只有一张的条件下,训练网络;其次,如何训练一个网络来快速处理不同的变形程度。

其他内容我就不重复了,作者提供了一个文件

ShapeMatchingGAN.ipynb

展示了整个网络处理的全部过程。我拆分成三部分,分别对应为

sketchMatchingGan1.py
sketchMatchingGan2.py
sketchMatchingGan3.py

对应可以用vscode调试分步运行,注意需要有cuda支持。具体可以到我的github仓库查看,

GitHub - SpaceView/ShapeMatchingGan_Test

另外,netron对ckpt的支持不好,为了看网络结构,我把模型都转换成了onnx格式,转换的源码如下(GB.ckpt是指GB-iccv.ckpt这个文件,源码也上传到了我的git仓库),

from __future__ import print_function
import torch
from torch.autograd import Variable
from models import SketchModule
from utils import load_image, to_data, to_var, visualize, save_image, gaussian, weights_init
from utils import load_train_batchfnames, prepare_text_batch
import argparse
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys

from pathlib import Path as ppath

# Resolve the directory that contains this script and make it the current
# working directory, so the relative '../data' / '../save' paths used below
# resolve the same way no matter where the script is launched from.
FILE = ppath(__file__).resolve()
ROOT = FILE.parents[0]   # directory of this file
cwdir = os.getcwd()      # remember the original working directory
# BUG FIX: os.chdir() returns None, so the original `cudir = os.chdir(ROOT)`
# stored a meaningless None; only the side effect is wanted.
os.chdir(ROOT)

# Hyper-parameters / options for the SketchModule export.
# FIX: the original abused argparse.ArgumentParser as a plain attribute bag
# (parse_args is never called); argparse.Namespace is the intended
# lightweight attribute container for this.
opts = argparse.Namespace()
opts.GB_nlayers = 6                     # layers in sketch generator G_B
opts.DB_nlayers = 5                     # layers in discriminator D_B
opts.GB_nf = 32                         # base feature-map count of G_B
opts.DB_nf = 32                         # base feature-map count of D_B
opts.gpu = True                         # move the model to CUDA below
opts.epochs = 3
opts.save_GB_name = '../save/GB.ckpt'   # checkpoint loaded in __main__
opts.batchsize = 16
opts.text_path = '../data/rawtext/yaheiB/train'
opts.augment_text_path = '../data/rawtext/augment'
opts.text_datasize = 708
opts.augment_text_datasize = 5
opts.Btraining_num = 12800

# Build the SketchModule (G_B generator / D_B discriminator) with the
# hyper-parameters defined above and move it to the GPU.
print('--- create model ---')
netSketch = SketchModule(opts.GB_nlayers, opts.DB_nlayers, opts.GB_nf, opts.DB_nf, opts.gpu)
if opts.gpu:
    netSketch.cuda()  # requires a CUDA-capable device
# Random init / train mode are intentionally disabled: the weights are
# loaded from the GB.ckpt checkpoint in __main__ and the model is exported
# in eval mode.
#netSketch.init_networks(weights_init)
#netSketch.train()


import torch.onnx 

# Export helper: trace the globally-loaded SketchModule and write it to ONNX.
def Convert_ONNX():
    """Export the module-level `model` (a SketchModule with weights already
    loaded) to 'GB-ckpt1.onnx', tracing it on the left half of the leaf
    style image with deformation parameter -1."""
    # Inference mode: freezes dropout / batch-norm statistics for tracing.
    model.eval()

    style = load_image('../data/style/leaf.png')
    # Keep only the left half of the image width-wise (the style half).
    style = to_var(style[:, :, :, 0:int(style.size(3) / 2)])

    torch.onnx.export(
        model,                      # model being run
        (style, -1),                # example input tuple used for tracing
        "GB-ckpt1.onnx",            # output file
        export_params=True,         # embed the trained weights in the file
        opset_version=10,           # target ONNX opset
        do_constant_folding=True,   # fold constant subgraphs
        input_names=['modelInput'],
        output_names=['modelOutput'],
        # leave the batch dimension dynamic on both ends
        dynamic_axes={'modelInput': {0: 'batch_size'},
                      'modelOutput': {0: 'batch_size'}},
    )
    print(" ")
    print('Model has been converted to ONNX')

if __name__ == "__main__":

    # Convert_ONNX() reads the module-level name `model`, so bind it before
    # the call.
    model = netSketch
    #path = "myFirstModel.pth"
    #model.load_state_dict(torch.load(path))
    # Load the pretrained SketchModule weights (GB.ckpt is GB-iccv.ckpt
    # renamed, per the surrounding article).
    # NOTE(review): torch.load without map_location assumes CUDA is
    # available, matching the netSketch.cuda() call above.
    state_dict = torch.load('../save/GB.ckpt')
    model.load_state_dict(state_dict)

    # Conversion to ONNX
    Convert_ONNX()

print("all done!")

 其他子模型的转换如下,

from __future__ import print_function
import torch
from models import SketchModule, ShapeMatchingGAN
from utils import load_image, to_data, to_var, visualize, save_image, gaussian, weights_init
from utils import load_train_batchfnames, prepare_text_batch, load_style_image_pair, cropping_training_batches
import random
import argparse
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from pathlib import Path as ppath

# Resolve the directory that contains this script and make it the current
# working directory, so the relative '../data' / '../save' paths used below
# resolve the same way no matter where the script is launched from.
FILE = ppath(__file__).resolve()
ROOT = FILE.parents[0]   # directory of this file
cwdir = os.getcwd()      # remember the original working directory
# BUG FIX: os.chdir() returns None, so the original `cudir = os.chdir(ROOT)`
# stored a meaningless None; only the side effect is wanted.
os.chdir(ROOT)

# Hyper-parameters / options for the ShapeMatchingGAN export.
# FIX: the original abused argparse.ArgumentParser as a plain attribute bag
# (parse_args is never called); argparse.Namespace is the intended
# lightweight attribute container for this.
opts = argparse.Namespace()
# SMGAN: structure generator G_S / texture generator G_T and discriminators
opts.GS_nlayers = 6
opts.DS_nlayers = 4
opts.GS_nf = 32
opts.DS_nf = 32
opts.GT_nlayers = 6
opts.DT_nlayers = 4
opts.GT_nf = 32
opts.DT_nf = 32

# SketchModule
opts.GB_nlayers = 6
opts.DB_nlayers = 5
opts.GB_nf = 32
opts.DB_nf = 32
opts.load_GB_name = '../save/GB-iccv.ckpt'

# train
opts.gpu = True
opts.step1_epochs = 30
opts.step2_epochs = 40
opts.step3_epochs = 80
opts.step4_epochs = 10
opts.batchsize = 16
opts.Straining_num = 2560
opts.scale_num = 4
opts.Sanglejitter = True
opts.subimg_size = 256
opts.glyph_preserve = False
opts.text_datasize = 708
opts.text_path = '../data/rawtext/yaheiB/train'

# data and path
opts.save_path = '../save/'
opts.save_name = 'maple'
opts.style_name = '../data/style/maple.png'


# Build the ShapeMatchingGAN (structure branch G_S/D_S, texture branch
# G_T/D_T) with the hyper-parameters defined above and move it to the GPU.
print('--- create model ---')
netShapeM = ShapeMatchingGAN(opts.GS_nlayers, opts.DS_nlayers, opts.GS_nf, opts.DS_nf,
                 opts.GT_nlayers, opts.DT_nlayers, opts.GT_nf, opts.DT_nf, opts.gpu)

if opts.gpu:
    netShapeM.cuda()  # requires a CUDA-capable device
# Random init / train mode are intentionally disabled: the sub-generator
# weights are loaded from checkpoints in __main__ for eval-mode export.
#netShapeM.init_networks(weights_init)
#netShapeM.train()

import torch.onnx 
from torch.autograd import Variable

#Function to Convert to ONNX 
def Convert_ONNX(model, model_name, dummy_input): 
    model.eval()
   
    # Export the model   
    torch.onnx.export(model,          # model being run 
        dummy_input, # (I, -1),       # model input (or a tuple for multiple inputs) 
        model_name,  # "GB-ckpt1.onnx",      # where to save the model  
        export_params=True,  # store the trained parameter weights inside the model file 
        opset_version=10,    # the ONNX version to export the model to 
        do_constant_folding=True,    # whether to execute constant folding for optimization 
        input_names = ['modelInput'],     # the model's input names 
        output_names = ['modelOutput'],   # the model's output names 
        dynamic_axes={'modelInput' : {0 : 'batch_size'},  # variable length axes
                               'modelOutput' : {0 : 'batch_size'}}) 
    print(" ") 
    print('Model has been converted to ONNX: ', model_name) 

if __name__ == "__main__":
    # --- export the structure generator G_S ---  (verified OK 2022-04-28)
    model = netShapeM.G_S
    state_dict = torch.load('../save/maple-GS-iccv.ckpt')
    model.load_state_dict(state_dict)
    # Trace input: a validation glyph image, center-cropped to 256x256.
    I = load_image('../data/rawtext/yaheiB/val/0801.png')
    I = to_var(I[:,:,32:288,32:288])
    # Add Gaussian noise to the first channel only.
    I[:,0:1] = gaussian(I[:,0:1], stddev=0.2)
    # Second element is a scalar parameter passed alongside the image —
    # presumably the deformation degree; confirm against G_S.forward.
    dummy_input = (I, 1.0)
    Convert_ONNX(model, 'maple-GS-iccv_ckpt.onnx', dummy_input)

    # --- export the texture generator G_T ---
    state_dict = torch.load('../save/maple-GT-iccv.ckpt')
    model = netShapeM.G_T
    model.load_state_dict(state_dict)
    # Random 1x3x320x320 trace input on the GPU (Variable is deprecated in
    # modern PyTorch but behaves as a plain tensor here).
    I = Variable(torch.randn(1, 3, 320, 320, requires_grad=True)).cuda()
    Convert_ONNX(model, 'maple-GT-iccv_ckpt.onnx', I)

print("all done!")

转换完成后,会生成两个子文件,

maple-GS-iccv_ckpt.onnx

maple-GT-iccv_ckpt.onnx

然后,就可以用netron打开直观地看结果了。

关注
打赏
1661664439
查看更多评论
立即登录/注册

微信扫码登录

0.0386s