您当前的位置: 首页 > 
  • 2浏览

    0关注

    417博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

姿态估计0-05:DenseFusion(6D姿态估计)-源码解析(1)-训练代码初探,框架了解

江南才尽,年少无知! 发布时间:2019-11-16 12:04:02 ,浏览量:2

以下链接是个人关于DenseFusion(6D姿态估计) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文末附带 \color{blue}{文末附带} 文末附带 公众号 − \color{blue}{公众号 -} 公众号− 海量资源。 \color{blue}{ 海量资源}。 海量资源。

姿态估计0-00:DenseFusion(6D姿态估计)-目录-史上最新无死角讲解https://blog.csdn.net/weixin_43013761/article/details/103053585

代码详细注解

从之前的博客,我相信大家都已经知道,训练代码为tools/train.py,下面时对该代码的详细注解(这里只要随便看看就好,最后面还有总结)

# --------------------------------------------------------
# DenseFusion 6D Object Pose Estimation by Iterative Dense Fusion
# Licensed under The MIT License [see LICENSE for details]
# Written by Chen
# --------------------------------------------------------

import _init_paths
import argparse
import os
import random
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from datasets.ycb.dataset import PoseDataset as PoseDataset_ycb
from datasets.warehouse.dataset import PoseDataset as PoseDataset_warehouse
from datasets.linemod.dataset import PoseDataset as PoseDataset_linemod
from lib.network import PoseNet, PoseRefineNet
from lib.loss import Loss
from lib.loss_refiner import Loss_refine
from lib.utils import setup_logger
from torchsummary import summary


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default = 'ycb', help='ycb or warehouse or linemod')
parser.add_argument('--dataset_root', type=str, default = '', help='dataset root dir (''YCB_Video_Dataset'' or ''Warehouse_Dataset'' or ''Linemod_preprocessed'')')
parser.add_argument('--batch_size', type=int, default = 8, help='batch size')

# 加载数据的线程数目
parser.add_argument('--workers', type=int, default = 10, help='number of data loading workers')

# 初始学习率
parser.add_argument('--lr', default=0.0001, help='learning rate')

parser.add_argument('--lr_rate', default=0.3, help='learning rate decay rate')

# 初始权重
parser.add_argument('--w', default=0.015, help='learning rate')
# 权重衰减率
parser.add_argument('--w_rate', default=0.3, help='learning rate decay rate')

#
parser.add_argument('--decay_margin', default=0.016, help='margin to decay lr & w')

# 大概是loss到了这个设定的值,则会进行refine模型的训练
parser.add_argument('--refine_margin', default=0.013, help='margin to start the training of iterative refinement')

# 给训练数据添加噪声相关的参数,可以理解为数据增强
parser.add_argument('--noise_trans', default=0.03, help='range of the random noise of translation added to the training data')

# 训练refinenet的时候是连续迭代几次
parser.add_argument('--iteration', type=int, default = 2, help='number of refinement iterations')

# 训练到多少个epoch则停止
parser.add_argument('--nepoch', type=int, default=500, help='max number of epochs to train')

# 是否继续训练posenet模型,继续训练则加载posenet预训练模型
parser.add_argument('--resume_posenet', type=str, default = '',  help='resume PoseNet model')
# 是否继续训练refinenet模型,继续训练则加载refinenet预训练模型
parser.add_argument('--resume_refinenet', type=str, default = '',  help='resume PoseRefineNet model')

parser.add_argument('--start_epoch', type=int, default = 1, help='which epoch to start')
opt = parser.parse_args()


def main():
    opt.manualSeed = random.randint(1, 100)

    # 为CPU随机生成数设定的种子
    random.seed(opt.manualSeed)
    # 为GPU随机生成数设定的种子
    torch.manual_seed(opt.manualSeed)

    # 根据数据集的不同,分别配置其
    # 训练数据的物体种类数目,输入点云的数目,训练模型保存的目录,log保存的目录,起始的epoch数目
    if opt.dataset == 'ycb':
        opt.num_objects = 21 #number of object classes in the dataset
        opt.num_points = 1000 #number of points on the input pointcloud
        opt.outf = 'trained_models/ycb' #folder to save trained models
        opt.log_dir = 'experiments/logs/ycb' #folder to save logs
        opt.repeat_epoch = 1 #number of repeat times for one epoch training
    elif opt.dataset == 'warehouse':
        opt.num_objects = 13
        opt.num_points = 1000
        opt.outf = 'trained_models/warehouse'
        opt.log_dir = 'experiments/logs/warehouse'
        opt.repeat_epoch = 1
    elif opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        opt.repeat_epoch = 20
    else:
        print('Unknown dataset')
        return

    # 该处为网络的构建,构建完成之后,能对物体的6D姿态进行预测
    estimator = PoseNet(num_points = opt.num_points, num_obj = opt.num_objects)
    estimator.cuda()
    #summary(estimator,[(3, 120, 120),(500,3),(1,500),(1,)])
    # 对初步预测的姿态进行提炼
    refiner = PoseRefineNet(num_points = opt.num_points, num_obj = opt.num_objects)
    refiner.cuda()

    # 对posenet以及refinenet模型的加载,然后标记对应的网络是否已经开始训练过了,以及是否进行衰减
    if opt.resume_posenet != '':
        estimator.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))
    if opt.resume_refinenet != '':
        refiner.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.decay_start = True
        opt.lr *= opt.lr_rate
        opt.w *= opt.w_rate
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

    # 加载对应的训练和验证数据集
    if opt.dataset == 'ycb':
        dataset = PoseDataset_ycb('train', opt.num_points, True, opt.dataset_root, opt.noise_trans, opt.refine_start)
    elif opt.dataset == 'warehouse':
        dataset = PoseDataset_warehouse('train', opt.num_points, True, opt.dataset_root, opt.noise_trans, opt.refine_start)
    elif opt.dataset == 'linemod':
        dataset = PoseDataset_linemod('train', opt.num_points, True, opt.dataset_root, opt.noise_trans, opt.refine_start)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=opt.workers)
    if opt.dataset == 'ycb':
        test_dataset = PoseDataset_ycb('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start)
    elif opt.dataset == 'warehouse':
        test_dataset = PoseDataset_warehouse('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start)
    elif opt.dataset == 'linemod':
        test_dataset = PoseDataset_linemod('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers)


    opt.sym_list = dataset.get_sym_list()
    #print(opt.sym_list)
    opt.num_points_mesh = dataset.get_num_points_mesh()


    print('>>>>>>>>----------Dataset loaded!-------------------epoch {0} train finish---------            
关注
打赏
1592542134
查看更多评论
0.0522s