
SuperPoint Study and Training Notes: Inference-Only and Trainable Versions (Part 1)

superpoint   the original paper; read through it once before anything else.

superpoint   the official source code, based on PyTorch; unfortunately, the training code and the corresponding rendered training data are not provided.

superpoint   a well-regarded third-party reimplementation that includes training, based on TensorFlow.

1. Running the official source code

Download the source from the command line (or via the link above):

git clone https://github.com/magicleap/SuperPointPretrainedNetwork.git 

First set up the environment (installing Anaconda and working in a virtual environment is strongly recommended, as it makes things much easier to manage):

pip install opencv-python
pip install torch

My working versions: opencv-python 4.4, torch 1.8.1.
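For reference, a minimal Anaconda setup might look like this (the environment name and Python version are my own choice, not from the original post):

conda create -n superpoint python=3.7
conda activate superpoint
pip install opencv-python torch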

Official requirements: (the screenshot from the original post is omitted; see the repository README for the exact versions)

Run image mode (the GitHub page lists the corresponding commands for the other modes, which take a webcam or a video as input):

./demo_superpoint.py assets/<path to your image folder>/

Note: if your images are not PNG, you need to change the img_glob default (or pass the flag on the command line, as shown below); if it runs sluggishly, you can also change H and W to shrink the input images.
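For example, to run on a folder of JPEG images downscaled to 240x320 (the path is hypothetical; --img_glob, --H and --W are the flags defined in the script further below):

./demo_superpoint.py assets/my_images/ --img_glob '*.jpg' --H 240 --W 320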

Tracking result (the detected feature points are tracked as they move across the images):

 

The original code only provides the tracking demo. I made some modifications on top of it and added matching, plus visualization of statistics such as extraction time and match count:

import argparse
import glob
import os
import time

import cv2
import numpy as np
import torch

if int(cv2.__version__[0]) < 3:  # pragma: no cover
    print('Warning: OpenCV 3 is not installed')

class SuperPointNet(torch.nn.Module):
    """ PyTorch definition of the SuperPoint network. """

    def __init__(self):
        super(SuperPointNet, self).__init__()  # call the parent (nn.Module) constructor first
        self.relu = torch.nn.ReLU(inplace=True)
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        c1, c2, c3, c4, c5, d1 = 64, 64, 128, 128, 256, 256
        # Shared Encoder.
        self.conv1a = torch.nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)  # encoder conv layers
        self.conv1b = torch.nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
        self.conv2a = torch.nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv2b = torch.nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
        self.conv3a = torch.nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv3b = torch.nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
        self.conv4a = torch.nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv4b = torch.nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
        # Detector Head.
        self.convPa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)  # decoder conv layers
        self.convPb = torch.nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
        # Descriptor Head.
        self.convDa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Shared Encoder.
        x = self.relu(self.conv1a(x))
        x = self.relu(self.conv1b(x))
        x = self.pool(x)
        x = self.relu(self.conv2a(x))
        x = self.relu(self.conv2b(x))
        x = self.pool(x)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool(x)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))
        # Detector Head.
        cPa = self.relu(self.convPa(x))
        semi = self.convPb(cPa)
        # Descriptor Head.
        cDa = self.relu(self.convDa(x))
        desc = self.convDb(cDa)
        dn = torch.norm(desc, p=2, dim=1)  # Compute the norm.
        desc = desc.div(torch.unsqueeze(dn, 1))  # Divide by norm to normalize.
        return semi, desc  # detector logits and dense descriptors
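
# Shape sketch (my annotation, assuming a 480x640 grayscale input): the three
# 2x2 max-pools downsample by a factor of 8, so semi is (1, 65, 60, 80) -- one
# logit per 8x8 cell for each of the 64 pixel positions plus a "dustbin"
# channel -- and desc is (1, 256, 60, 80), one L2-normalized 256-D descriptor
# per cell. For example:
#   net = SuperPointNet()
#   semi, desc = net(torch.zeros(1, 1, 480, 640))
#   print(semi.shape, desc.shape)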

class SuperPointFrontend(object):
    """ Wrapper around pytorch net to help with pre and post image processing.
    The front end owns a SuperPointNet instance. """

    def __init__(self, weights_path, nms_dist, conf_thresh, nn_thresh,
                 cuda=False):
        self.name = 'SuperPoint'
        self.cuda = cuda
        self.nms_dist = nms_dist
        self.conf_thresh = conf_thresh
        self.nn_thresh = nn_thresh  # L2 descriptor distance for good match.
        self.cell = 8  # Size of each output cell. Keep this fixed.
        self.border_remove = 4  # Remove points this close to the border.
        # Load the network in inference mode.
        self.net = SuperPointNet()
        if cuda:
            # Train on GPU, deploy on GPU.
            self.net.load_state_dict(torch.load(weights_path))
            self.net = self.net.cuda()
        else:
            # Train on GPU, deploy on CPU.
            self.net.load_state_dict(torch.load(weights_path,
                                                map_location=lambda storage, loc: storage))
        self.net.eval()

    def nms_fast(self, in_corners, H, W, dist_thresh):
        """ Fast approximate non-maximum suppression on a 3xN array of corners
        (x, y, confidence); dist_thresh defaults to 4 in the demo. """
        grid = np.zeros((H, W)).astype(int)  # Track NMS state: 1 = candidate, -1 = kept, 0 = empty/suppressed.
        inds = np.zeros((H, W)).astype(int)  # Store the index of the point in each cell.
        # Sort by confidence (row 2 of in_corners), highest first; argsort
        # returns the indices that would sort the array.
        inds1 = np.argsort(-in_corners[2, :])
        corners = in_corners[:, inds1]  # all corners, ordered by confidence
        rcorners = corners[:2, :].round().astype(int)  # x, y rounded to nearest int
        # Check for edge case of 0 or 1 corners.
        if rcorners.shape[1] == 0:
            return np.zeros((3, 0)).astype(int), np.zeros(0).astype(int)
        if rcorners.shape[1] == 1:
            out = np.vstack((rcorners, in_corners[2])).reshape(3, 1)
            return out, np.zeros((1)).astype(int)
        # Initialize the grid.
        for i, rc in enumerate(rcorners.T):  # enumerate yields each point rc with its index i
            grid[rcorners[1, i], rcorners[0, i]] = 1  # mark the cell as occupied
            inds[rcorners[1, i], rcorners[0, i]] = i  # remember which point sits there
        # Pad the border of the grid, so that we can NMS points near the border.
        pad = dist_thresh
        grid = np.pad(grid, ((pad, pad), (pad, pad)), mode='constant')  # zero padding
        # Iterate through points, highest to lowest conf, suppress neighborhood.
        count = 0
        for i, rc in enumerate(rcorners.T):
            # Account for top and left padding.
            pt = (rc[0] + pad, rc[1] + pad)
            if grid[pt[1], pt[0]] == 1:  # If not yet suppressed.
                grid[pt[1] - pad:pt[1] + pad + 1, pt[0] - pad:pt[0] + pad + 1] = 0
                grid[pt[1], pt[0]] = -1  # mark as kept after suppressing its neighborhood
                count += 1
        # Get all surviving -1's and return sorted array of remaining corners.
        keepy, keepx = np.where(grid == -1)
        keepy, keepx = keepy - pad, keepx - pad  # undo the padding offset
        inds_keep = inds[keepy, keepx]
        out = corners[:, inds_keep]  # surviving (x, y, conf) columns
        values = out[-1, :]  # confidences
        inds2 = np.argsort(-values)  # re-sort survivors by confidence
        out = out[:, inds2]
        out_inds = inds1[inds_keep[inds2]]
        return out, out_inds  # 3xN surviving corners and their original indices
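
    # Toy NMS example (my annotation): with dist_thresh=4, of the candidates
    # (10,10) conf 0.9 and (12,11) conf 0.8 only the first survives, because
    # the second falls inside the suppressed 9x9 neighborhood; a distant point
    # like (50,50) conf 0.7 is kept. With fe = SuperPointFrontend(...):
    #   pts = np.array([[10., 12., 50.], [10., 11., 50.], [0.9, 0.8, 0.7]])
    #   out, _ = fe.nms_fast(pts, 100, 100, dist_thresh=4)  # out has 2 columns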

    def run(self, img):
        """ Process a float32 grayscale image; return 3xN points (x, y, conf),
        DxN descriptors and the full-resolution heatmap. """
        assert img.ndim == 2, 'Image must be grayscale.'
        assert img.dtype == np.float32, 'Image must be float32.'
        H, W = img.shape[0], img.shape[1]
        inp = img.copy()
        inp = (inp.reshape(1, H, W))
        inp = torch.from_numpy(inp)
        inp = torch.autograd.Variable(inp).view(1, 1, H, W)
        if self.cuda:
            inp = inp.cuda()
        # Forward pass of network.
        outs = self.net.forward(inp)
        semi, coarse_desc = outs[0], outs[1]
        # Convert pytorch -> numpy.
        semi = semi.data.cpu().numpy().squeeze()
        # --- Process points.
        dense = np.exp(semi)  # Softmax.
        dense = dense / (np.sum(dense, axis=0) + .00001)  # Should sum to 1.
        # Remove dustbin.
        nodust = dense[:-1, :, :]
        # Reshape to get full resolution heatmap: the 64 remaining channels of
        # each (Hc, Wc) cell are unfolded into the 8x8 pixels of that cell
        # (a depth-to-space shuffle done with reshapes and transposes).
        Hc = int(H / self.cell)
        Wc = int(W / self.cell)
        nodust = nodust.transpose(1, 2, 0)
        heatmap = np.reshape(nodust, [Hc, Wc, self.cell, self.cell])
        heatmap = np.transpose(heatmap, [0, 2, 1, 3])
        heatmap = np.reshape(heatmap, [Hc * self.cell, Wc * self.cell])
        xs, ys = np.where(heatmap >= self.conf_thresh)  # Confidence threshold.
        if len(xs) == 0:
            return np.zeros((3, 0)), None, None
        pts = np.zeros((3, len(xs)))  # Populate point data sized 3xN.
        pts[0, :] = ys
        pts[1, :] = xs
        pts[2, :] = heatmap[xs, ys]
        pts, _ = self.nms_fast(pts, H, W, dist_thresh=self.nms_dist)  # Apply NMS.
        inds = np.argsort(pts[2, :])
        pts = pts[:, inds[::-1]]  # Sort by confidence.
        # Remove points along border.
        bord = self.border_remove
        toremoveW = np.logical_or(pts[0, :] < bord, pts[0, :] >= (W - bord))
        toremoveH = np.logical_or(pts[1, :] < bord, pts[1, :] >= (H - bord))
        toremove = np.logical_or(toremoveW, toremoveH)
        pts = pts[:, ~toremove]
        # --- Process descriptor.
        D = coarse_desc.shape[1]
        if pts.shape[1] == 0:
            desc = np.zeros((D, 0))
        else:
            # Interpolate into descriptor map using 2D point locations.
            samp_pts = torch.from_numpy(pts[:2, :].copy())
            samp_pts[0, :] = (samp_pts[0, :] / (float(W) / 2.)) - 1.
            samp_pts[1, :] = (samp_pts[1, :] / (float(H) / 2.)) - 1.
            samp_pts = samp_pts.transpose(0, 1).contiguous()
            samp_pts = samp_pts.view(1, 1, -1, 2)
            samp_pts = samp_pts.float()
            if self.cuda:
                samp_pts = samp_pts.cuda()
            desc = torch.nn.functional.grid_sample(coarse_desc, samp_pts)
            desc = desc.data.cpu().numpy().reshape(D, -1)
            desc /= np.linalg.norm(desc, axis=0)[np.newaxis, :]
        return pts, desc, heatmap

def nn_match_two_way(desc1, desc2, nn_thresh):
    """ Mutual (two-way) nearest-neighbor matching of two DxN descriptor sets. """
    assert desc1.shape[0] == desc2.shape[0]
    if desc1.shape[1] == 0 or desc2.shape[1] == 0:
        return np.zeros((3, 0))
    if nn_thresh < 0.0:
        raise ValueError('\'nn_thresh\' should be non-negative')
    # Compute L2 distance. Easy since vectors are unit normalized.
    dmat = np.dot(desc1.T, desc2)
    dmat = np.sqrt(2 - 2 * np.clip(dmat, -1, 1))
    # Get NN indices and scores.
    idx = np.argmin(dmat, axis=1)
    scores = dmat[np.arange(dmat.shape[0]), idx]
    # Threshold the NN matches.
    keep = scores < nn_thresh
    # Check if nearest neighbor goes both directions and keep those.
    idx2 = np.argmin(dmat, axis=0)
    keep_bi = np.arange(len(idx)) == idx2[idx]
    keep = np.logical_and(keep, keep_bi)
    idx = idx[keep]
    scores = scores[keep]
    # Get the surviving point indices.
    m_idx1 = np.arange(desc1.shape[1])[keep]
    m_idx2 = idx
    # Populate the final 3xN match data structure.
    matches = np.zeros((3, int(keep.sum())))
    matches[0, :] = m_idx1
    matches[1, :] = m_idx2
    matches[2, :] = scores
    return matches
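
# Toy matching example (my annotation): columns are unit-norm descriptors, and
# only mutual nearest neighbors with L2 distance below nn_thresh are returned:
#   d1 = np.array([[1., 0.], [0., 1.]])
#   d2 = np.array([[0.96, 0.], [0.28, 1.]])
#   nn_match_two_way(d1, d2, 0.7)
#   # -> both columns match: distances sqrt(2 - 2*0.96) ~= 0.28 and 0.0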

class VideoStreamer(object):
    """ Streams frames from an image directory, a movie file, or a webcam. """

    def __init__(self, basedir, camid, height, width, skip, img_glob):
        # In image-directory mode this constructor mainly builds self.listing,
        # the sorted list of all image paths matching img_glob (e.g. '*.png').
        self.cap = []
        self.camera = False
        self.video_file = False
        self.listing = []
        self.sizer = [height, width]
        self.i = 0
        self.skip = skip
        self.maxlen = 1000000
        # If the "basedir" string is the word camera, then use a webcam.
        if basedir == "camera/" or basedir == "camera":
            print('==> Processing Webcam Input.')
            self.cap = cv2.VideoCapture(camid)
            self.listing = range(0, self.maxlen)
            self.camera = True
        else:
            # Try to open as a video.
            self.cap = cv2.VideoCapture(basedir)
            lastbit = basedir[-4:len(basedir)]  # last four characters, i.e. the extension
            if (type(self.cap) == list or not self.cap.isOpened()) and (lastbit == '.mp4'):
                raise IOError('Cannot open movie file')
            elif type(self.cap) != list and self.cap.isOpened() and (lastbit != '.txt'):
                print('==> Processing Video Input.')
                num_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))  # total frame count
                self.listing = range(0, num_frames)
                self.listing = self.listing[::self.skip]  # keep every skip-th frame
                self.camera = True
                self.video_file = True
                self.maxlen = len(self.listing)  # update the frame count
            else:
                print('==> Processing Image Directory Input.')
                search = os.path.join(basedir, img_glob)  # join into a glob pattern string
                self.listing = glob.glob(search)  # all image paths matching the pattern
                self.listing.sort()  # sort in ascending order
                self.listing = self.listing[::self.skip]  # keep every skip-th image
                self.maxlen = len(self.listing)  # update the frame count
                if self.maxlen == 0:
                    raise IOError('No images were found (maybe bad \'--img_glob\' parameter?)')

    def read_image(self, impath, img_size):
        """ Read image as grayscale and resize to img_size.
        Inputs
          impath: Path to input image.
          img_size: (W, H) tuple specifying resize size.
        Returns
          grayim: float32 numpy array sized H x W with values in range [0, 1].
        """
        grayim = cv2.imread(impath, 0)  # 0 = read as grayscale
        if grayim is None:
            raise Exception('Error reading image %s' % impath)
        # Image is resized via opencv.
        interp = cv2.INTER_AREA  # an interpolation method generally used for downscaling
        grayim = cv2.resize(grayim, (img_size[1], img_size[0]), interpolation=interp)
        grayim = (grayim.astype('float32') / 255.)  # convert to float in [0, 1]
        return grayim

    def next_frame(self):
        """ Return the next frame, and increment internal counter.
        Returns
          image: Next H x W image.
          status: True or False depending whether image was loaded.
        """
        if self.i == self.maxlen:  # past the last frame
            return (None, False)
        if self.camera:  # webcam or video mode
            ret, input_image = self.cap.read()
            if ret is False:
                print('VideoStreamer: Cannot get image from camera (maybe bad --camid?)')
                return (None, False)
            if self.video_file:  # video mode: seek to the next listed frame
                self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.listing[self.i])
            input_image = cv2.resize(input_image, (self.sizer[1], self.sizer[0]),
                                     interpolation=cv2.INTER_AREA)
            input_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY)
            input_image = input_image.astype('float') / 255.0
        else:
            image_file = self.listing[self.i]  # current image path
            input_image = self.read_image(image_file, self.sizer)
        # Increment internal counter.
        self.i = self.i + 1
        input_image = input_image.astype('float32')
        return (input_image, True)

def match_descriptors(kp1, desc1, kp2, desc2):
    """ Match the keypoints of two images with nearest-neighbor search. """
    # Brute-force matcher: takes each descriptor in desc1 and computes the L2
    # distance to every descriptor in desc2; crossCheck keeps mutual matches only.
    bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
    matches = bf.match(desc1, desc2)
    # Each returned DMatch pairs queryIdx (index into the first set) with
    # trainIdx (index of its best match in the second set).
    matches_idx = np.array([m.queryIdx for m in matches])
    m_kp1 = [kp1[idx] for idx in matches_idx]
    matches_idx = np.array([m.trainIdx for m in matches])
    m_kp2 = [kp2[idx] for idx in matches_idx]
    # m_kp1 and m_kp2 are the keypoints of the two images, now aligned
    # one-to-one (each pair is the closest match by descriptor distance).
    return m_kp1, m_kp2, matches


def showpoint(img, ptx):
    """ Draw each point of a 3xN (x, y, conf) array as a small circle. """
    for i in range(ptx.shape[1]):
        x = int(round(ptx[0, i]))
        y = int(round(ptx[1, i]))
        cv2.circle(img, (x, y), 3, color=(255, 0, 0))
    return img

def drawMatches(img1, kp1, img2, kp2, matches):
    """
    My own implementation of cv2.drawMatches, since OpenCV 2.4.9 does not
    provide it (it is supported from OpenCV 3.0.0).

    Takes two grayscale images with their keypoints (kp1, kp2 as 3xN arrays of
    x, y, confidence) and a 3xN match array (as returned by nn_match_two_way),
    and produces a montage with the first image beside the second: keypoints
    are drawn as circles and lines connect the matching pairs.
    """
    cols1 = img1.shape[1]  # horizontal offset for points in the second image
    # Stack each grayscale image into 3 channels and place them side by side;
    # work in uint8 so the random 0-255 colours render correctly.
    i1 = (np.dstack([img1, img1, img1]) * 255.).astype('uint8')
    i2 = (np.dstack([img2, img2, img2]) * 255.).astype('uint8')
    out = np.hstack([i1, i2])
    # For each pair of points we have between both images
    # draw circles, then connect a line between them.
    for i in range(matches.shape[1]):
        # Indices of the matching keypoints in each image.
        idx1 = int(matches[0, i])
        idx2 = int(matches[1, i])
        # x - columns, y - rows.
        x1 = kp1[0, idx1]
        y1 = kp1[1, idx1]
        x2 = kp2[0, idx2]
        y2 = kp2[1, idx2]
        # Random colour per match; see the cv2.circle / cv2.line documentation
        # for the drawing arguments.
        a = int(np.random.randint(0, 256))
        b = int(np.random.randint(0, 256))
        c = int(np.random.randint(0, 256))
        cv2.circle(out, (int(np.round(x1)), int(np.round(y1))), 2, (a, b, c), 1)
        cv2.circle(out, (int(np.round(x2) + cols1), int(np.round(y2))), 2, (a, b, c), 1)
        cv2.line(out, (int(np.round(x1)), int(np.round(y1))),
                 (int(np.round(x2) + cols1), int(np.round(y2))), (a, b, c), 1, shift=0)
    return out

if __name__ == '__main__':
    # Parse command line arguments.
    parser = argparse.ArgumentParser(description='PyTorch SuperPoint Demo.')
    parser.add_argument('input', type=str, default='',
                        help='Image directory or movie file or "camera" (for webcam).')
    parser.add_argument('--weights_path', type=str, default='superpoint_v1.pth',
                        help='Path to pretrained weights file (default: superpoint_v1.pth).')
    parser.add_argument('--img_glob', type=str, default='*.png',
                        help='Glob match if directory of images is specified (default: \'*.png\').')
    parser.add_argument('--skip', type=int, default=1,
                        help='Images to skip if input is movie or directory (default: 1).')
    parser.add_argument('--show_extra', action='store_true',
                        help='Show extra debug outputs (default: False).')
    parser.add_argument('--H', type=int, default=480,
                        help='Input image height (default: 480).')
    parser.add_argument('--W', type=int, default=640,
                        help='Input image width (default: 640).')
    parser.add_argument('--display_scale', type=int, default=2,
                        help='Factor to scale output visualization (default: 2).')
    parser.add_argument('--min_length', type=int, default=2,
                        help='Minimum length of point tracks (default: 2).')
    parser.add_argument('--max_length', type=int, default=5,
                        help='Maximum length of point tracks (default: 5).')
    parser.add_argument('--nms_dist', type=int, default=4,
                        help='Non Maximum Suppression (NMS) distance (default: 4).')
    parser.add_argument('--conf_thresh', type=float, default=0.015,
                        help='Detector confidence threshold (default: 0.015).')
    parser.add_argument('--nn_thresh', type=float, default=0.7,
                        help='Descriptor matching threshold (default: 0.7).')
    parser.add_argument('--camid', type=int, default=0,
                        help='OpenCV webcam video capture ID, usually 0 or 1 (default: 0).')
    parser.add_argument('--waitkey', type=int, default=1,
                        help='OpenCV waitkey time in ms (default: 1).')
    parser.add_argument('--cuda', action='store_true',
                        help='Use cuda GPU to speed up network processing speed (default: False)')
    parser.add_argument('--no_display', action='store_true',
                        help='Do not display images to screen. Useful if running remotely (default: False).')
    parser.add_argument('--write', action='store_true',
                        help='Save output frames to a directory (default: False)')
    parser.add_argument('--write_dir', type=str, default='tracker_outputs/',
                        help='Directory where to write output frames (default: tracker_outputs/).')
    opt = parser.parse_args()
    print(opt)

    # Frame source: hands out images one at a time.
    vs = VideoStreamer(opt.input, opt.camid, opt.H, opt.W, opt.skip, opt.img_glob)

    print('==> Loading pre-trained network.')
    # This class runs the SuperPoint network and processes its outputs.
    fe = SuperPointFrontend(weights_path=opt.weights_path,  # path to the weights file
                            nms_dist=opt.nms_dist,          # NMS distance, default 4
                            conf_thresh=opt.conf_thresh,    # detector threshold, default 0.015
                            nn_thresh=opt.nn_thresh,        # matcher threshold, default 0.7
                            cuda=opt.cuda)                  # GPU acceleration, default False
    print('==> Successfully loaded pre-trained network.')

    # Create a window to display the demo.
    if not opt.no_display:
        win = 'SuperPoint Tracker'
        cv2.namedWindow(win)
    else:
        print('Skipping visualization, will not show a GUI.')

    # Font parameters for visualization.
    font = cv2.FONT_HERSHEY_DUPLEX
    font_clr = (255, 255, 255)
    font_pt = (4, 12)
    font_sc = 0.4

    # Create the output directory if requested.
    if opt.write:  # default: False
        print('==> Will write outputs to %s' % opt.write_dir)
        if not os.path.exists(opt.write_dir):
            os.makedirs(opt.write_dir)

    print('==> Running Demo.')
    # First image: extract points and descriptors, and time the extraction.
    img1, status = vs.next_frame()
    start1 = time.time()
    pts, desc, heatmap = fe.run(img1)
    end1 = time.time()
    print('Image 1: extraction took %.3fs, %d keypoints detected' % (end1 - start1, pts.shape[1]))
    imgx = img1.copy()
    img11 = showpoint(imgx, pts)
    cv2.imshow("imgone", img11)

    # Second image: same processing.
    img2, status = vs.next_frame()
    start1 = time.time()
    pts1, desc1, heatmap1 = fe.run(img2)
    end1 = time.time()
    print('Image 2: extraction took %.3fs, %d keypoints detected' % (end1 - start1, pts1.shape[1]))
    imgx = img2.copy()
    img22 = showpoint(imgx, pts1)
    cv2.imshow("imgtwo", img22)

    # Two-way nearest-neighbor matching of the two descriptor sets.
    match = nn_match_two_way(desc, desc1, opt.nn_thresh)
    print('Matches between image 1 and image 2: %d' % match.shape[1])

    # Hand-written match visualization (still somewhat rough).
    out = drawMatches(img1, pts, img2, pts1, match)
    cv2.namedWindow("matcher", 0)
    cv2.imshow("matcher", out)
    cv2.waitKey(0)
    print('==> Finished Demo.')
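The original script also contained a commented-out alternative using OpenCV's built-in cv2.drawMatches. Below is a minimal working sketch of that idea (my completion, not the author's code), assuming the pts, pts1, match, img1 and img2 produced above; the conversions to uint8 images, cv2.KeyPoint and cv2.DMatch are what the OpenCV API expects:

# Convert the float images and 3xN arrays into the types cv2.drawMatches expects.
im1 = (img1 * 255.).astype('uint8')
im2 = (img2 * 255.).astype('uint8')
cv_kpts1 = [cv2.KeyPoint(float(pts[0, i]), float(pts[1, i]), 1)
            for i in range(pts.shape[1])]
cv_kpts2 = [cv2.KeyPoint(float(pts1[0, i]), float(pts1[1, i]), 1)
            for i in range(pts1.shape[1])]
cv_matches = [cv2.DMatch(int(match[0, i]), int(match[1, i]), float(match[2, i]))
              for i in range(match.shape[1])]
matched = cv2.drawMatches(im1, cv_kpts1, im2, cv_kpts2, cv_matches, None,
                          matchColor=(0, 255, 0), singlePointColor=(0, 0, 255))
cv2.imshow("cv2 matcher", matched)
cv2.waitKey(0)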

The result of running the code:

 
