当前位置:   article > 正文

基于Resnet50的pytorch框架下的图像特征提取_resnet50提取图像特征

resnet50提取图像特征

在ResNet50 CNN结构下实现图像的特征提取。这里采用的是cv2(OpenCV)的图像读入方式,最后把提取到的图像特征保存为npy格式输出,得到每张图像对应的特征。

  1. # -*- coding: utf-8 -*-
  2. """
  3. Function: 图像特征的提取,可以依据需求修改CNN的输出,得到不同层网络的输出图像特征
  4. Writer: Zenght
  5. date:2019.2.16
  6. """
  7. from __future__ import print_function, division, absolute_import
  8. import torch
  9. import torchvision
  10. import torch.nn as nn
  11. from torch.autograd import Variable
  12. import torch.optim as optim
  13. from torch.optim import lr_scheduler
  14. import numpy as np
  15. from torchvision import datasets, models, transforms
  16. import os
  17. import cv2
  18. import time
  19. import copy
  20. import torch.utils.data as data
  21. from Rsenet50 import Resnet
  22. class Net(nn.Module):
  23. # 此处可以添加自行设定的网络结构
  24. def __init__(self):
  25. super(Net, self).__init__()
  26. def cv2_imageloader(path):
  27. mean = [0.485, 0.456, 0.406]
  28. std = [0.229, 0.224, 0.225]
  29. img = cv2.imread(path)
  30. img = cv2.resize(img, (224, 224))
  31. im_arr = np.float32(img)
  32. im_arr = np.ascontiguousarray(im_arr[..., ::-1])
  33. im_arr = im_arr.transpose(2, 0, 1)# Convert Img from BGR to RGB
  34. for channel, _ in enumerate(im_arr):
  35. # Normalization
  36. im_arr[channel] /= 255
  37. im_arr[channel] -= mean[channel]
  38. im_arr[channel] /= std[channel]
  39. # Convert to float tensor
  40. im_as_ten = torch.from_numpy(im_arr).float()
  41. # Convert to Pytorch variable
  42. im_as_var = Variable(im_as_ten, requires_grad=True)
  43. return im_as_var
  44. def default_loader(path):
  45. return cv2_imageloader(path)
  46. class CustomImageLoader(data.Dataset):
  47. ##自定义类型数据输入
  48. def __init__(self, img_path, txt_path, dataset = '', loader = default_loader, save_path='/home/haitaizeng/stanforf/ctumb_zht/TFP/pytorch/Feature/MIT67'):
  49. im_list = []
  50. im_dirs = []
  51. im_labels = []
  52. with open(txt_path, 'r') as files:
  53. for line in files:
  54. items = line.split()
  55. if items[0][0] == '/':
  56. imname = line.split()[0][1:]
  57. fnewname = '_'.join(imname[:-4].split('/')) + '.npy'
  58. else:
  59. imname = line.split()[0]
  60. fnewname = '_'.join(imname[:-4].split('/'))+'.npy'
  61. im_list.append(os.path.join(img_path, imname))
  62. im_labels.append(int(items[1]))
  63. im_dirs.append(os.path.join(save_path, fnewname))
  64. self.imgs = im_list
  65. self.labels = im_labels
  66. self.save_dir = im_dirs
  67. self.loader = loader
  68. self.dataset = dataset
  69. def __len__(self):
  70. return len(self.imgs)
  71. def __getitem__(self, item):
  72. # print(item)
  73. img_name = self.imgs[item]
  74. label = self.labels[item]
  75. imdir = self.save_dir[item]
  76. img = self.loader(img_name)
  77. return img, label, imdir
  78. batch_size = 64
  79. device = torch.device('cuda:0')
  80. # SUN397 INPUT
  81. # image_dir = '/media/haitaizeng/000222840009D764/Images'#
  82. # image_datasets = {x : CustomImageLoader(image_dir, txt_path=('/home/haitaizeng/stanforf/ctumb_zht/TFP/pytorch/Trainfiles/SUN397/'+x+'Images.label'),
  83. #
  84. # dataset=x) for x in ['Train', 'Test']
  85. # }
  86. #MIT67 INPUT
  87. image_dir = '/media/haitaizeng/00038FCE000387A5/cgw/Datasets/MIT67/Images'
  88. image_datasets = {x : CustomImageLoader(image_dir, txt_path=('/home/haitaizeng/stanforf/alex_mit/data_image/'+x+'Images.label'),
  89. dataset=x) for x in ['Train', 'Test']
  90. }
  91. dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
  92. batch_size=batch_size,
  93. shuffle=True) for x in ['Train', 'Test']}
  94. dataset_sizes = {x: len(image_datasets[x]) for x in ['Train', 'Test']}
  95. def Feature_extractor(models, savepath):
  96. for phase in ['Train', 'Test']:
  97. for images, labels,save_dir in dataloders[phase]:
  98. images.to(device)
  99. labels.to(device)
  100. # 输出特征,并转换为NPY格式进行保存
  101. output3 = models(images.cuda())
  102. output = nn.functional.softmax(output3, dim=0)
  103. print(output.shape)
  104. output = output.cpu()
  105. output = torch.squeeze(output)
  106. output = output.data.numpy()
  107. for feat, featpath in zip(output, save_dir):
  108. np.save(featpath, feat)
  109. if __name__ == '__main__':
  110. Num_class = 67
  111. pthpath = '/home/haitaizeng/stanforf/ctumb_zht/TFP/pytorch/save_model/MIT67/Places0.8500.pth'
  112. # model_ft = net() ##这是自行编写的Resnet50,用于后面的特征提取的操作
  113. model_ft = Resnet([3, 4, 6, 3], Num_class)
  114. ckpt = torch.load(pthpath, map_location=lambda storage, loc: storage)
  115. model_ft.load_state_dict(ckpt)
  116. model_ft.eval()
  117. model_ft = model_ft.to(device)
  118. model_ft.cuda()
  119. path = '/home/haitaizeng/stanforf/ctumb_zht/TFP/pytorch/Feature/MIT67'
  120. Feature_extractor(models=model_ft, savepath=path)

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小舞很执着/article/detail/880925
推荐阅读
相关标签
  

闽ICP备14008679号