当前位置:   article > 正文

代码解读:基于深度学习的单目深度估计(4)_coarse_mean_lpips

coarse_mean_lpips

代码解读:基于深度学习的单目深度估计(4)


今天再来分析深度网络的coarse和fine的stack结构

  1. def define_coarse_stack(self, imnet_feats):
  2. full1 = self.create_unit('full1', ninput=test_shape(imnet_feats)[1])
  3. f_1 = relu(full1.infer(imnet_feats))
  4. f_1_drop = random_zero(f_1, 0.5)
  5. f_1_mean = 0.5 * f_1
  6. full2 = self.create_unit('full2', ninput=test_shape(f_1_mean)[1])
  7. f_2_drop = full2.infer(f_1_drop)
  8. f_2_mean = full2.infer(f_1_mean)
  9. # prediction
  10. (h, w) = self.output_size
  11. pred_drop = f_2_drop.reshape((self.bsize, h, w))
  12. pred_mean = f_2_mean.reshape((self.bsize, h, w))
  13. self.coarse = MachinePart(locals())
  14. def define_fine_stack(self, x0):
  15. # pproc slightly different from imagenet because no cmrnorm
  16. x0_pproc = (x0 - self.meta.images_mean) \
  17. * self.meta.images_istd
  18. conv_s2_1 = self.create_unit('conv_s2_1')
  19. z_s2_1 = relu(conv_s2_1.infer(x0_pproc))
  20. pool_s2_1 = self.create_unit('pool_s2_1')
  21. (p_s2_1, s_s2_1) = pool_s2_1.infer(z_s2_1)
  22. # concat input features with coarse prediction
  23. (h, w) = self.output_size
  24. coarse_drop = self.coarse.pred_drop.reshape((self.bsize, 1, h, w))
  25. coarse_mean = self.coarse.pred_mean.reshape((self.bsize, 1, h, w))
  26. p_1_concat_drop = T.concatenate(
  27. (coarse_drop,
  28. p_s2_1[:, 1:, :, :]),
  29. axis=1)
  30. p_1_concat_mean = T.concatenate(
  31. (coarse_mean,
  32. p_s2_1[:, 1:, :, :]),
  33. axis=1)
  34. conv_s2_2 = self.create_unit('conv_s2_2')
  35. z_s2_2_drop = relu(conv_s2_2.infer(p_1_concat_drop))
  36. z_s2_2_mean = relu(conv_s2_2.infer(p_1_concat_mean))
  37. conv_s2_3 = self.create_unit('conv_s2_3')
  38. z_s2_3_drop = conv_s2_3.infer(z_s2_2_drop)
  39. z_s2_3_mean = conv_s2_3.infer(z_s2_2_mean)
  40. # prediction
  41. pred_drop = z_s2_3_drop[:,0,:,:]
  42. pred_mean = z_s2_3_mean[:,0,:,:]
  43. self.fine = MachinePart(locals())

从这段代码中,可以了解:

1,define_fine_stack可以了解到三个卷积层的连接,这和论文的描述一致

2,但是define_coarse_stack的讲述与论文不符,不清楚这个原因

接下来看定义误差函数的代码,

  1. #定义损失函数 这个会不会就是文献的创新点呢?缩放不变的损失函数
  2. def define_cost(self, pred, y0, m0):
  3. bsize = self.bsize
  4. npix = int(np.prod(test_shape(y0)[1:]))
  5. y0_target = y0.reshape((self.bsize, npix))
  6. y0_mask = m0.reshape((self.bsize, npix))
  7. pred = pred.reshape((self.bsize, npix))
  8. p = pred * y0_mask
  9. t = y0_target * y0_mask
  10. d = (p - t)
  11. nvalid_pix = T.sum(y0_mask, axis=1)
  12. depth_cost = (T.sum(nvalid_pix * T.sum(d**2, axis=1))
  13. - 0.5*T.sum(T.sum(d, axis=1)**2)) \
  14. / T.maximum(T.sum(nvalid_pix**2), 1)
  15. return depth_cost

具体的数学推导这里不去详解


总之,分析到现在,已经把源码中models文件夹下的depth.conf 和 depth.py 两个文件看完了,了解到:

1,depth.conf 讲述了深度网络的搭建模式,基本和论文是对应的

2,depth.py 介绍了深度推测函数,图片的预处理,深度网络的初始化以及误差函数的定义

而文件common存放的是一些更为底层的函数,不必要做更深分析

在源码的根目录下,net.py 和 pooling.py 是关于网络层以及池化层的底层函数,不做太多了解

说白了,跟论文相关的部分已经分析完了

最后,把目光转向 test.py 文件,之前分析了前几行,来看下面的代码,

  1. def main():
  2. # location of depth module, config and parameters
  3. module_fn = 'models/depth.py'
  4. config_fn = 'models/depth.conf'#网络结构
  5. params_dir = 'weights/depth'#网络相关参数
  6. # load depth network
  7. machine = net.create_machine(module_fn, config_fn, params_dir)
  8. # demo image
  9. rgb = Image.open('demo_nyud_rgb.jpg')
  10. rgb = rgb.resize((320, 240), Image.BICUBIC)
  11. # build depth inference function and run
  12. rgb_imgs = np.asarray(rgb).reshape((1, 240, 320, 3))
  13. pred_depths = machine.infer_depth(rgb_imgs)
  14. # save prediction
  15. (m, M) = (pred_depths.min(), pred_depths.max())
  16. depth_img_np = (pred_depths[0] - m) / (M - m)
  17. depth_img = Image.fromarray((255*depth_img_np).astype(np.uint8))
  18. depth_img.save('demo_nyud_depth_prediction.png')
  19. if __name__ == '__main__':
  20. main()

从这段代码可以了解:

1,net的函数存入create_machine()初始化整个深度网络

2,machine的infer_depth()用来估计深度,这个函数之前分析过

3,对结果的depth图像进行处理,保存图像


总而言之,我把基于深度学习的单目深度估计从论文到算法到源码认真整理了一遍,

很有意思!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/149404
推荐阅读
相关标签
  

闽ICP备14008679号