以下链接是个人关于 MVSNet(R-MVSNet)多视角立体深度推导重建的所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。文末附带公众号 - 海量资源。
3D点云重建0-00:MVSNet(R-MVSNet)–目录-史上最新无死角讲解:https://blog.csdn.net/weixin_43013761/article/details/102852209
代码引导
根据前面的介绍,我们运行测试代码的命令如下:
python test.py --dense_folder ../../MVS_TRANING/scan9/scan9 --model_dir ../../MVS_TRANING/models/tf_model_190307/tf_model/ --regularization 3DCNNs --max_w 1152 --max_h 864 --max_d 192 --interval_scale 1.06
代码注释
#!/usr/bin/env python
"""
Copyright 2019, Yao Yao, HKUST.
Test script.
"""
from __future__ import print_function
import os
import time
import sys
import math
import argparse
import numpy as np
import cv2
import tensorflow as tf
sys.path.append("../")
from tools.common import Notify
from preprocess import *
from model import *
from loss import *
# Hide every CUDA device from TensorFlow so the test runs on the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

_flags = tf.app.flags

# ---- dataset parameters ----
_flags.DEFINE_string('dense_folder', None, """Root path to dense folder.""")
_flags.DEFINE_string('model_dir', '/data/tf_model', """Path to restore the model.""")
_flags.DEFINE_integer('ckpt_step', 100000, """ckpt step.""")

# ---- input parameters ----
_flags.DEFINE_integer('view_num', 5, """Number of images (1 ref image and view_num - 1 view images).""")
_flags.DEFINE_integer('max_d', 256, """Maximum depth step when testing.""")
_flags.DEFINE_integer('max_w', 1600, """Maximum image width when testing.""")
_flags.DEFINE_integer('max_h', 1200, """Maximum image height when testing.""")
_flags.DEFINE_float('sample_scale', 0.25, """Downsample scale for building cost volume (W and H).""")
_flags.DEFINE_float('interval_scale', 0.8, """Downsample scale for building cost volume (D).""")
_flags.DEFINE_float('base_image_size', 8, """Base image size""")
_flags.DEFINE_integer('batch_size', 1, """Testing batch size.""")
_flags.DEFINE_bool('adaptive_scaling', True, """Let image size to fit the network, including 'scaling', 'cropping'""")

# ---- network architecture ----
_flags.DEFINE_string('regularization', 'GRU', """Regularization method, including '3DCNNs' and 'GRU'""")
_flags.DEFINE_boolean('refinement', False, """Whether to apply depth map refinement for MVSNet""")
_flags.DEFINE_bool('inverse_depth', True, """Whether to apply inverse depth for R-MVSNet""")

# Parsed flag values, read throughout the rest of the script.
FLAGS = _flags.FLAGS
class MVSGenerator:
    """Data generator for MVSNet testing.

    tf.data.Dataset.from_generator only accepts a generator that takes no
    parameters, so the sample list and view count are stored on this object
    and the samples are produced by ``__iter__``.
    """

    def __init__(self, sample_list, view_num):
        """Store the sample paths and count the samples.

        :param sample_list: list of samples. Each sample is a flat list of
            paths alternating image / camera file: ``data[2 * v]`` is the
            image of view ``v`` and ``data[2 * v + 1]`` its camera file.
            The reference view comes first.
        :param view_num: number of views fed to the network
            (1 reference view + ``view_num - 1`` source views).
        """
        self.sample_list = sample_list
        self.view_num = view_num
        self.sample_num = len(sample_list)
        self.counter = 0

    def __iter__(self):
        """Yield preprocessed network inputs, cycling over the samples forever.

        Each yielded item is ``(scaled_images, centered_images, scaled_cams,
        image_index)``:

        * ``scaled_images``: stacked views, cropped and then downsampled by
          ``FLAGS.sample_scale`` (not centered).
        * ``centered_images``: list of the cropped full-size views, each
          passed through ``center_image``.
        * ``scaled_cams``: stacked cameras matching ``scaled_images``.
        * ``image_index``: integer parsed from the reference image's file
          name (e.g. ``00000003.jpg`` gives 3).
        """
        while True:
            for data in self.sample_list:
                images = []
                cams = []
                # Index of the reference view, taken from its numeric file name.
                image_index = int(os.path.splitext(os.path.basename(data[0]))[0])
                # The sample stores (image, camera) path pairs.
                selected_view_num = len(data) // 2

                for view in range(min(self.view_num, selected_view_num)):
                    with file_io.FileIO(data[2 * view], mode='rb') as image_file:
                        # NOTE: scipy.misc.imread was removed in SciPy 1.2;
                        # this line needs an older SciPy (with Pillow).
                        image = scipy.misc.imread(image_file, mode='RGB')
                    # Keep images in OpenCV's BGR channel order.
                    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
                    with file_io.FileIO(data[2 * view + 1], mode='rb') as cam_file:
                        cam = load_cam(cam_file, FLAGS.interval_scale)
                    # cam[1][3][2] is the number of depth samples; 0 means the
                    # camera file does not specify it, so fall back to max_d.
                    if cam[1][3][2] == 0:
                        cam[1][3][2] = FLAGS.max_d
                    images.append(image)
                    cams.append(cam)

                # If the sample lists fewer than view_num views, pad with
                # copies of the reference view (copies, so that later
                # in-place camera edits cannot alias between views).
                for _ in range(selected_view_num, self.view_num):
                    images.append(np.copy(images[0]))
                    cams.append(np.copy(cams[0]))

                # Choose a resize scale: the largest of max_h / H and
                # max_w / W over all views. Upscaling is refused.
                resize_scale = 1
                if FLAGS.adaptive_scaling:
                    h_scale = 0
                    w_scale = 0
                    for view in range(self.view_num):
                        height_scale = float(FLAGS.max_h) / images[view].shape[0]
                        width_scale = float(FLAGS.max_w) / images[view].shape[1]
                        h_scale = max(h_scale, height_scale)
                        w_scale = max(w_scale, width_scale)
                    if h_scale > 1 or w_scale > 1:
                        print("max_h, max_w should < W and H!")
                        sys.exit(-1)
                    resize_scale = max(h_scale, w_scale)

                # Resize the images; the cameras are adjusted with them.
                scaled_input_images, scaled_input_cams = scale_mvs_input(images, cams, scale=resize_scale)
                # Crop to a size that fits the network.
                croped_images, croped_cams = crop_mvs_input(scaled_input_images, scaled_input_cams)

                # Full-size, centered views: the network input.
                centered_images = [center_image(croped_images[view]) for view in range(self.view_num)]

                # Downsampled cameras/images used for building the cost volume.
                scaled_cams = np.stack(scale_mvs_camera(croped_cams, scale=FLAGS.sample_scale), axis=0)
                scaled_images = np.stack(
                    [scale_image(croped_images[view], scale=FLAGS.sample_scale)
                     for view in range(self.view_num)],
                    axis=0)

                self.counter += 1
                yield (scaled_images, centered_images, scaled_cams, image_index)
def mvsnet_pipeline(mvs_list):
    """Run MVSNet / R-MVSNet depth inference on every sample and save results.

    For each reference view, the following files are written into
    ``<FLAGS.dense_folder>/depths_mvsnet``. Here ``%08d`` is the reference
    image index.

    * ``%08d_init.pfm``: depth map
    * ``%08d_prob.pfm``: probability map
    * ``%08d.jpg``: downsampled reference image
    * ``%08d.txt``: camera of that image

    :param mvs_list: list of samples, each a flat list of alternating
        image / camera paths (see MVSGenerator).
    :raises ValueError: if FLAGS.regularization is neither '3DCNNs' nor 'GRU'.
    """
    print('sample number: ', len(mvs_list))

    # Create the output folder 'depths_mvsnet' inside the dense folder.
    output_folder = os.path.join(FLAGS.dense_folder, 'depths_mvsnet')
    if not os.path.isdir(output_folder):
        os.mkdir(output_folder)

    # Testing set. MVSGenerator cycles over the samples forever, so nothing
    # must be pulled from it here. The loop at the bottom takes exactly
    # len(mvs_list) items through the dataset.
    mvs_generator = iter(MVSGenerator(mvs_list, FLAGS.view_num))
    generator_data_type = (tf.float32, tf.float32, tf.float32, tf.int32)
    mvs_set = tf.data.Dataset.from_generator(lambda: mvs_generator, generator_data_type)
    mvs_set = mvs_set.batch(FLAGS.batch_size)
    mvs_set = mvs_set.prefetch(buffer_size=1)

    # Data from the dataset via an iterator:
    #   scaled_images   -- downsampled, un-centered views
    #   centered_images -- full-size centered views (network input)
    #   scaled_cams     -- cameras matching scaled_images
    #   image_index     -- index of the reference image
    mvs_iterator = mvs_set.make_initializable_iterator()
    scaled_images, centered_images, scaled_cams, image_index = mvs_iterator.get_next()

    # Set shapes: images are (batch, view, height, width, 3) and cameras
    # are (batch, view, 2, 4, 4).
    scaled_images.set_shape(tf.TensorShape([None, FLAGS.view_num, None, None, 3]))
    centered_images.set_shape(tf.TensorShape([None, FLAGS.view_num, None, None, 3]))
    scaled_cams.set_shape(tf.TensorShape([None, FLAGS.view_num, 2, 4, 4]))

    # The reference camera (view 0) stores the depth range in row 3 of its
    # second 4x4 matrix: [depth_start, depth_interval, depth_num, depth_end].
    depth_start = tf.reshape(
        tf.slice(scaled_cams, [0, 0, 1, 3, 0], [FLAGS.batch_size, 1, 1, 1, 1]), [FLAGS.batch_size])
    depth_interval = tf.reshape(
        tf.slice(scaled_cams, [0, 0, 1, 3, 1], [FLAGS.batch_size, 1, 1, 1, 1]), [FLAGS.batch_size])
    depth_num = tf.cast(
        tf.reshape(tf.slice(scaled_cams, [0, 0, 1, 3, 2], [1, 1, 1, 1, 1]), []), 'int32')

    # Deal with inverse depth. Take depth_end from the camera; otherwise
    # derive it from start, interval and number of samples.
    if FLAGS.regularization == '3DCNNs' and FLAGS.inverse_depth:
        depth_end = tf.reshape(
            tf.slice(scaled_cams, [0, 0, 1, 3, 3], [FLAGS.batch_size, 1, 1, 1, 1]), [FLAGS.batch_size])
    else:
        depth_end = depth_start + (tf.cast(depth_num, tf.float32) - 1) * depth_interval

    print(FLAGS.regularization)
    if FLAGS.regularization == '3DCNNs':
        # MVSNet: depth map inference with 3D-CNN regularization
        # (inference_mem is defined in model.py).
        init_depth_map, prob_map = inference_mem(
            centered_images, scaled_cams, FLAGS.max_d, depth_start, depth_interval)
        if FLAGS.refinement:
            ref_image = tf.squeeze(tf.slice(centered_images, [0, 0, 0, 0, 0], [-1, 1, -1, -1, 3]), axis=1)
            # NOTE(review): refined_depth_map is added to the graph (and its
            # variables must exist in the checkpoint), but it is never fetched
            # by the sess.run below, so the refined result is not saved.
            refined_depth_map = depth_refine(
                init_depth_map, ref_image, FLAGS.max_d, depth_start, depth_interval, True)
    elif FLAGS.regularization == 'GRU':
        # R-MVSNet: depth map inference with GRU regularization.
        init_depth_map, prob_map = inference_winner_take_all(centered_images, scaled_cams,
                                                             depth_num, depth_start, depth_end, reg_type='GRU',
                                                             inverse_depth=FLAGS.inverse_depth)
    else:
        # Without this, init_depth_map would be undefined at sess.run time.
        raise ValueError("Unknown regularization '%s', expected '3DCNNs' or 'GRU'"
                         % FLAGS.regularization)

    # init option
    init_op = tf.global_variables_initializer()
    var_init_op = tf.local_variables_initializer()

    # GPU memory grows incrementally.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        # initialization
        sess.run(var_init_op)
        sess.run(init_op)
        total_step = 0

        # Load the pretrained model <model_dir>/<regularization>/model.ckpt-<ckpt_step>.
        if FLAGS.model_dir is not None:
            pretrained_model_ckpt_path = os.path.join(FLAGS.model_dir, FLAGS.regularization, 'model.ckpt')
            ckpt_path = '-'.join([pretrained_model_ckpt_path, str(FLAGS.ckpt_step)])
            restorer = tf.train.Saver(tf.global_variables())
            restorer.restore(sess, ckpt_path)
            print(Notify.INFO, 'Pre-trained model restored from %s' % ckpt_path, Notify.ENDC)
            total_step = FLAGS.ckpt_step

        # Run inference for each reference view.
        sess.run(mvs_iterator.initializer)
        for step in range(len(mvs_list)):
            start_time = time.time()
            try:
                out_init_depth_map, out_prob_map, out_images, out_cams, out_index = sess.run(
                    [init_depth_map, prob_map, scaled_images, scaled_cams, image_index])
            except tf.errors.OutOfRangeError:
                print("all dense finished")  # ==> "End of dataset"
                break
            duration = time.time() - start_time
            print(Notify.INFO, 'depth inference %d finished. (%.3f sec/step)' % (step, duration),
                  Notify.ENDC)

            # Squeeze the batch axis and keep only the reference view (view 0).
            out_init_depth_image = np.squeeze(out_init_depth_map)
            out_prob_map = np.squeeze(out_prob_map)
            out_ref_image = np.squeeze(out_images)
            out_ref_image = np.squeeze(out_ref_image[0, :, :, :])
            out_ref_cam = np.squeeze(out_cams)
            out_ref_cam = np.squeeze(out_ref_cam[0, :, :, :])
            out_index = np.squeeze(out_index)

            # paths
            init_depth_map_path = output_folder + ('/%08d_init.pfm' % out_index)
            prob_map_path = output_folder + ('/%08d_prob.pfm' % out_index)
            out_ref_image_path = output_folder + ('/%08d.jpg' % out_index)
            out_ref_cam_path = output_folder + ('/%08d.txt' % out_index)

            # save output
            write_pfm(init_depth_map_path, out_init_depth_image)
            write_pfm(prob_map_path, out_prob_map)
            # The image comes out of the dataset as float32. scipy.misc.imsave
            # min-max rescales non-uint8 arrays, which would stretch the
            # colours. Convert back to uint8 first, then BGR -> RGB for saving.
            out_ref_image = np.clip(np.rint(out_ref_image), 0, 255).astype(np.uint8)
            out_ref_image = cv2.cvtColor(out_ref_image, cv2.COLOR_RGB2BGR)
            with file_io.FileIO(out_ref_image_path, mode='wb') as image_file:
                scipy.misc.imsave(image_file, out_ref_image)
            write_cam(out_ref_cam_path, out_ref_cam)
            total_step += 1
def main(_):  # pylint: disable=unused-argument
    """Program entry point invoked by tf.app.run()."""
    # Build the list of (image, camera) path samples from the dense folder,
    # then run depth inference over all of them.
    sample_list = gen_pipeline_mvs_list(FLAGS.dense_folder)
    mvsnet_pipeline(sample_list)


if __name__ == '__main__':
    tf.app.run(main=main)
通过前面的介绍,相信理解这段代码已经没有什么难度了,所以只是稍微注释了一下。运行测试代码之后,本人在 scan9/scan9/depths_mvsnet 目录下生成了测试结果,其中最重要的当然是深度图(*_init.pfm),以及与之对应的概率图(*_prob.pfm)。
这是我第一次接触 3D,做完下来感觉这样的网络很难落地,不知道能用它做些什么。但不可否认,的确学到了很多东西。有了这样初步的了解之后,相信再做其他 3D 项目时,应该能很快上手。再见了!