以下链接是个人关于DenseFusion(6D姿态估计) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文末附带 \color{blue}{文末附带} 文末附带 公众号 − \color{blue}{公众号 -} 公众号− 海量资源。 \color{blue}{ 海量资源}。 海量资源。
姿态估计0-00:DenseFusion(6D姿态估计)-目录-史上最新无死角讲解https://blog.csdn.net/weixin_43013761/article/details/103053585
代码详细注解从之前的博客,我相信大家都已经知道,训练代码为tools/train.py,下面时对该代码的详细注解(这里只要随便看看就好,最后面还有总结)
# --------------------------------------------------------
# DenseFusion 6D Object Pose Estimation by Iterative Dense Fusion
# Licensed under The MIT License [see LICENSE for details]
# Written by Chen
# --------------------------------------------------------
import _init_paths
import argparse
import os
import random
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from datasets.ycb.dataset import PoseDataset as PoseDataset_ycb
from datasets.warehouse.dataset import PoseDataset as PoseDataset_warehouse
from datasets.linemod.dataset import PoseDataset as PoseDataset_linemod
from lib.network import PoseNet, PoseRefineNet
from lib.loss import Loss
from lib.loss_refiner import Loss_refine
from lib.utils import setup_logger
from torchsummary import summary
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default = 'ycb', help='ycb or warehouse or linemod')
parser.add_argument('--dataset_root', type=str, default = '', help='dataset root dir (''YCB_Video_Dataset'' or ''Warehouse_Dataset'' or ''Linemod_preprocessed'')')
parser.add_argument('--batch_size', type=int, default = 8, help='batch size')
# 加载数据的线程数目
parser.add_argument('--workers', type=int, default = 10, help='number of data loading workers')
# 初始学习率
parser.add_argument('--lr', default=0.0001, help='learning rate')
parser.add_argument('--lr_rate', default=0.3, help='learning rate decay rate')
# 初始权重
parser.add_argument('--w', default=0.015, help='learning rate')
# 权重衰减率
parser.add_argument('--w_rate', default=0.3, help='learning rate decay rate')
#
parser.add_argument('--decay_margin', default=0.016, help='margin to decay lr & w')
# 大概是loss到了这个设定的值,则会进行refine模型的训练
parser.add_argument('--refine_margin', default=0.013, help='margin to start the training of iterative refinement')
# 给训练数据添加噪声相关的参数,可以理解为数据增强
parser.add_argument('--noise_trans', default=0.03, help='range of the random noise of translation added to the training data')
# 训练refinenet的时候是连续迭代几次
parser.add_argument('--iteration', type=int, default = 2, help='number of refinement iterations')
# 训练到多少个epoch则停止
parser.add_argument('--nepoch', type=int, default=500, help='max number of epochs to train')
# 是否继续训练posenet模型,继续训练则加载posenet预训练模型
parser.add_argument('--resume_posenet', type=str, default = '', help='resume PoseNet model')
# 是否继续训练refinenet模型,继续训练则加载refinenet预训练模型
parser.add_argument('--resume_refinenet', type=str, default = '', help='resume PoseRefineNet model')
parser.add_argument('--start_epoch', type=int, default = 1, help='which epoch to start')
opt = parser.parse_args()
def main():
opt.manualSeed = random.randint(1, 100)
# 为CPU随机生成数设定的种子
random.seed(opt.manualSeed)
# 为GPU随机生成数设定的种子
torch.manual_seed(opt.manualSeed)
# 根据数据集的不同,分别配置其
# 训练数据的物体种类数目,输入点云的数目,训练模型保存的目录,log保存的目录,起始的epoch数目
if opt.dataset == 'ycb':
opt.num_objects = 21 #number of object classes in the dataset
opt.num_points = 1000 #number of points on the input pointcloud
opt.outf = 'trained_models/ycb' #folder to save trained models
opt.log_dir = 'experiments/logs/ycb' #folder to save logs
opt.repeat_epoch = 1 #number of repeat times for one epoch training
elif opt.dataset == 'warehouse':
opt.num_objects = 13
opt.num_points = 1000
opt.outf = 'trained_models/warehouse'
opt.log_dir = 'experiments/logs/warehouse'
opt.repeat_epoch = 1
elif opt.dataset == 'linemod':
opt.num_objects = 13
opt.num_points = 500
opt.outf = 'trained_models/linemod'
opt.log_dir = 'experiments/logs/linemod'
opt.repeat_epoch = 20
else:
print('Unknown dataset')
return
# 该处为网络的构建,构建完成之后,能对物体的6D姿态进行预测
estimator = PoseNet(num_points = opt.num_points, num_obj = opt.num_objects)
estimator.cuda()
#summary(estimator,[(3, 120, 120),(500,3),(1,500),(1,)])
# 对初步预测的姿态进行提炼
refiner = PoseRefineNet(num_points = opt.num_points, num_obj = opt.num_objects)
refiner.cuda()
# 对posenet以及refinenet模型的加载,然后标记对应的网络是否已经开始训练过了,以及是否进行衰减
if opt.resume_posenet != '':
estimator.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))
if opt.resume_refinenet != '':
refiner.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
opt.refine_start = True
opt.decay_start = True
opt.lr *= opt.lr_rate
opt.w *= opt.w_rate
opt.batch_size = int(opt.batch_size / opt.iteration)
optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
else:
opt.refine_start = False
opt.decay_start = False
optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)
# 加载对应的训练和验证数据集
if opt.dataset == 'ycb':
dataset = PoseDataset_ycb('train', opt.num_points, True, opt.dataset_root, opt.noise_trans, opt.refine_start)
elif opt.dataset == 'warehouse':
dataset = PoseDataset_warehouse('train', opt.num_points, True, opt.dataset_root, opt.noise_trans, opt.refine_start)
elif opt.dataset == 'linemod':
dataset = PoseDataset_linemod('train', opt.num_points, True, opt.dataset_root, opt.noise_trans, opt.refine_start)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=opt.workers)
if opt.dataset == 'ycb':
test_dataset = PoseDataset_ycb('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start)
elif opt.dataset == 'warehouse':
test_dataset = PoseDataset_warehouse('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start)
elif opt.dataset == 'linemod':
test_dataset = PoseDataset_linemod('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start)
testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers)
opt.sym_list = dataset.get_sym_list()
#print(opt.sym_list)
opt.num_points_mesh = dataset.get_num_points_mesh()
print('>>>>>>>>----------Dataset loaded!-------------------epoch {0} train finish---------
关注
打赏
最近更新
- 深拷贝和浅拷贝的区别(重点)
- 【Vue】走进Vue框架世界
- 【云服务器】项目部署—搭建网站—vue电商后台管理系统
- 【React介绍】 一文带你深入React
- 【React】React组件实例的三大属性之state,props,refs(你学废了吗)
- 【脚手架VueCLI】从零开始,创建一个VUE项目
- 【React】深入理解React组件生命周期----图文详解(含代码)
- 【React】DOM的Diffing算法是什么?以及DOM中key的作用----经典面试题
- 【React】1_使用React脚手架创建项目步骤--------详解(含项目结构说明)
- 【React】2_如何使用react脚手架写一个简单的页面?