以下链接是个人关于PVNet(6D姿态估计) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文末附带 \color{blue}{文末附带} 文末附带 公众号 − \color{blue}{公众号 -} 公众号− 海量资源。 \color{blue}{ 海量资源}。 海量资源。
姿态估计2-00:PVNet(6D姿态估计)-目录-史上最新无死角讲解
train_net.py注释下面是对train_net.py文件的注释,该代码十分的简单,所以注释也十分简洁:
from lib.config import cfg, args
from lib.networks import make_network
from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler
from lib.datasets import make_data_loader
from lib.utils.net_utils import load_model, save_model, load_network
from lib.evaluators import make_evaluator
import torch.multiprocessing
def train(cfg, network):
# 如果训练数据为City,这进行文件系统共享
if cfg.train.dataset[:4] != 'City':
torch.multiprocessing.set_sharing_strategy('file_system')
# 制作训练器
trainer = make_trainer(cfg, network)
# 制作优化器
optimizer = make_optimizer(cfg, network)
# 制作学习率调整器
scheduler = make_lr_scheduler(cfg, optimizer)
# 用于记录信息
recorder = make_recorder(cfg)
# 用于评估
evaluator = make_evaluator(cfg)
# 进行模型加载
begin_epoch = load_model(network, optimizer, scheduler, recorder, cfg.model_dir, resume=cfg.resume)
# set_lr_scheduler(cfg, scheduler)
# 创建训练以及评估数据集
train_loader = make_data_loader(cfg, is_train=True, max_iter=cfg.ep_iter)
val_loader = make_data_loader(cfg, is_train=False)
# train_loader = make_data_loader(cfg, is_train=True, max_iter=100)
# 循环进行迭代训练
for epoch in range(begin_epoch, cfg.train.epoch):
recorder.epoch = epoch
# 进行一个epoch的迭代训练
trainer.train(epoch, train_loader, optimizer, recorder)
# 记录学习了一个epoch,并且根据预设定的参数,看是否需要对学习率进行更改
scheduler.step()
# 迭代到指定次数,保存好训练的
if (epoch + 1) % cfg.save_ep == 0:
save_model(network, optimizer, scheduler, recorder, epoch, cfg.model_dir)
# 迭代到指定次数,进行评估训练
if (epoch + 1) % cfg.eval_ep == 0:
trainer.val(epoch, val_loader, evaluator, recorder)
return network
def test(cfg, network):
# 根据配置创建训练器
trainer = make_trainer(cfg, network)
# 创建数据迭代器
val_loader = make_data_loader(cfg, is_train=False)
# 创建评估器
evaluator = make_evaluator(cfg)
# 加载权重
epoch = load_network(network, cfg.model_dir, resume=cfg.resume, epoch=cfg.test.epoch)
# 进行评估
trainer.val(epoch, val_loader, evaluator)
def main():
# 根据配置参数,构建网路
network = make_network(cfg)
# 根据传入的参数选择测试或者训练
if args.test:
test(cfg, network)
else:
train(cfg, network)
if __name__ == "__main__":
main()
总结
训练代码的套路基本都是差不多的,基本就是 1.解析参数 2.构建网络模型 3.加载训练测试数据集迭代器 4.迭代训练 5.模型评估保存