- # -*- coding: utf-8 -*-
- """
- Function: 图像特征的提取,可以依据需求修改CNN的输出,得到不同层网络的输出图像特征
- Writer: Zenght
- date:2019.2.16
- """
- from __future__ import print_function, division, absolute_import
- import torch
- import torchvision
- import torch.nn as nn
- from torch.autograd import Variable
- import torch.optim as optim
- from torch.optim import lr_scheduler
- import numpy as np
- from torchvision import datasets, models, transforms
- import os
- import cv2
- import time
- import copy
- import torch.utils.data as data
- from Rsenet50 import Resnet
class Net(nn.Module):
    """Placeholder for a user-defined network architecture.

    Currently it registers no layers. Custom layers can be declared in
    ``__init__`` when a different backbone is wanted.
    """

    def __init__(self):
        """Initialise the underlying ``nn.Module`` machinery; nothing else yet."""
        # Explicit two-argument super() keeps the file Python-2 compatible
        # (it imports from __future__).
        super(Net, self).__init__()
def cv2_imageloader(path, size=(224, 224),
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Load an image with OpenCV and return it as a normalised CHW tensor.

    The pipeline is:

    1. Read the image with ``cv2.imread`` (BGR channel order).
    2. Resize it.
    3. Reverse the channel axis to get RGB.
    4. Transpose HWC -> CHW.
    5. Scale to [0, 1].
    6. Subtract ``mean`` and divide by ``std`` per channel. The defaults are
       the values the original code hard-coded.

    Args:
        path: file path of the image to read.
        size: ``dsize`` argument of ``cv2.resize``, i.e. ``(width, height)``.
        mean: per-channel (RGB order) value subtracted after scaling to [0, 1].
        std: per-channel (RGB order) value divided out after subtracting ``mean``.

    Returns:
        A float32 tensor of shape ``(3, height, width)`` created with
        ``requires_grad=True``.

    Raises:
        IOError: if OpenCV cannot read ``path``.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the error would surface as an obscure
        # assertion inside cv2.resize.
        raise IOError('cannot read image: {}'.format(path))
    img = cv2.resize(img, size)

    # Convert BGR -> RGB by reversing the channel axis, then HWC -> CHW.
    im_arr = np.ascontiguousarray(np.float32(img)[..., ::-1]).transpose(2, 0, 1)

    # Vectorised per-channel normalisation in float32 (same arithmetic,
    # step by step, as scaling, subtracting and dividing each channel).
    mean_arr = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std_arr = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    im_arr = (im_arr / np.float32(255) - mean_arr) / std_arr

    # Convert to a contiguous float tensor.
    im_as_ten = torch.from_numpy(np.ascontiguousarray(im_arr)).float()
    # requires_grad=True is kept from the original so callers relying on
    # input gradients keep working.
    return Variable(im_as_ten, requires_grad=True)
def default_loader(path):
    """Default per-sample loader used by ``CustomImageLoader``.

    Simply delegates to ``cv2_imageloader`` with its default settings.
    """
    loaded = cv2_imageloader(path)
    return loaded
class CustomImageLoader(data.Dataset):
    """Dataset read from a label file of ``<image path> <integer label>`` lines.

    Each sample is ``(image, label, feature_path)``:

    - ``image`` is whatever ``loader`` returns for the full image path.
    - ``feature_path`` is a ``.npy`` path under ``save_path``. Its file name is
      the image's relative path with the extension removed and ``/``
      replaced by ``_``.
    """

    def __init__(self, img_path, txt_path, dataset='', loader=None,
                 save_path='/home/haitaizeng/stanforf/ctumb_zht/TFP/pytorch/Feature/MIT67'):
        """Parse ``txt_path`` and build the sample lists.

        Args:
            img_path: root directory that the image paths in ``txt_path``
                are joined onto.
            txt_path: label file. Each non-blank line holds an image path and
                an integer label separated by whitespace. Leading ``/`` on
                the path are ignored.
            dataset: split name; stored on ``self.dataset`` only.
            loader: callable mapping a full image path to a sample. ``None``
                (the default) means ``default_loader``.
            save_path: directory in which the per-sample ``.npy`` paths are
                placed.
        """
        if loader is None:
            # Resolved lazily so the default can be swapped at module level.
            loader = default_loader
        im_list = []
        im_labels = []
        im_dirs = []
        with open(txt_path, 'r') as files:
            for line in files:
                items = line.split()
                if not items:
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                # Strip leading '/' so os.path.join does not discard img_path.
                imname = items[0].lstrip('/')
                # splitext handles extensions of any length (.jpg, .jpeg, ...).
                stem = os.path.splitext(imname)[0]
                fnewname = '_'.join(stem.split('/')) + '.npy'
                im_list.append(os.path.join(img_path, imname))
                im_labels.append(int(items[1]))
                im_dirs.append(os.path.join(save_path, fnewname))
        self.imgs = im_list
        self.labels = im_labels
        self.save_dir = im_dirs
        self.loader = loader
        self.dataset = dataset

    def __len__(self):
        """Number of samples listed in the label file."""
        return len(self.imgs)

    def __getitem__(self, item):
        """Return ``(loaded image, label, feature .npy path)`` for index ``item``."""
        img_name = self.imgs[item]
        label = self.labels[item]
        imdir = self.save_dir[item]
        img = self.loader(img_name)
        return img, label, imdir
# Module-level data configuration. NOTE: this runs at import time.
# Constructing the datasets opens the label files, so they must exist
# whenever this module is imported.

# Batch size shared by the Train and Test data loaders.
batch_size = 64
# Target device; the __main__ block moves the model here.
device = torch.device('cuda:0')
# SUN397 INPUT (alternative configuration, currently disabled)
# image_dir = '/media/haitaizeng/000222840009D764/Images'#
# image_datasets = {x : CustomImageLoader(image_dir, txt_path=('/home/haitaizeng/stanforf/ctumb_zht/TFP/pytorch/Trainfiles/SUN397/'+x+'Images.label'),
#
# dataset=x) for x in ['Train', 'Test']
# }

# MIT67 input: image root directory plus one label file per split,
# named '<split>Images.label' (split is 'Train' or 'Test').
image_dir = '/media/haitaizeng/00038FCE000387A5/cgw/Datasets/MIT67/Images'
image_datasets = {x : CustomImageLoader(image_dir, txt_path=('/home/haitaizeng/stanforf/alex_mit/data_image/'+x+'Images.label'),
                                        dataset=x) for x in ['Train', 'Test']
                  }
# One DataLoader per split. The misspelled name `dataloders` is what the rest
# of the file refers to, so it must not be renamed in isolation.
# NOTE(review): shuffling is not required for feature extraction, since every
# sample carries its own output path.
dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                             batch_size=batch_size,
                                             shuffle=True) for x in ['Train', 'Test']}
# Number of samples per split.
dataset_sizes = {x: len(image_datasets[x]) for x in ['Train', 'Test']}
def Feature_extractor(models, savepath, loaders=None, target_device=None):
    """Run ``models`` over every Train/Test batch and save per-sample softmax outputs.

    Each batch yielded by a loader is ``(images, labels, save_paths)``. The
    softmax of the model output for sample ``i`` is written with ``np.save``
    to ``save_paths[i]``.

    Args:
        models: network called as ``models(images)``. The caller is expected
            to have already set eval mode and moved it to ``target_device``.
        savepath: unused; kept for backward compatibility. Output locations
            come from the per-sample paths yielded by the loaders.
        loaders: mapping of phase name ('Train', 'Test') to an iterable of
            batches. Defaults to the module-level ``dataloders``.
        target_device: device the images are moved to. Defaults to the
            module-level ``device``.
    """
    if loaders is None:
        loaders = dataloders
    if target_device is None:
        target_device = device
    # Inference only. Without no_grad, inputs created with requires_grad=True
    # (as cv2_imageloader does) would make every forward pass build an
    # autograd graph.
    with torch.no_grad():
        for phase in ['Train', 'Test']:
            for images, _labels, save_dir in loaders[phase]:
                # Tensor.to() is not in-place; its result must be used.
                images = images.to(target_device)
                output = models(images)
                # Softmax over the class axis (dim=1), so each sample's scores
                # are normalised on their own. dim=0 would normalise across the
                # batch and make features depend on the other samples.
                output = nn.functional.softmax(output, dim=1)
                print(output.shape)
                output = output.cpu().numpy()
                # Save each output as NPY.
                for feat, featpath in zip(output, save_dir):
                    # Squeeze per sample: singleton feature dims are dropped, but
                    # the batch axis never is, so a final batch of size 1 still
                    # saves the full vector rather than its first element.
                    np.save(featpath, np.squeeze(feat))
if __name__ == '__main__':
    # Number of output classes of the checkpoint (67 for MIT67).
    Num_class = 67
    pthpath = '/home/haitaizeng/stanforf/ctumb_zht/TFP/pytorch/save_model/MIT67/Places0.8500.pth'
    # model_ft = Net()  # a self-written network (e.g. a custom ResNet-50) could be used here instead
    # NOTE(review): [3, 4, 6, 3] is presumably the number of blocks per stage
    # (the ResNet-50 layout); confirm against Rsenet50.Resnet.
    model_ft = Resnet([3, 4, 6, 3], Num_class)
    # Map every storage to the CPU while loading, regardless of where it was saved.
    ckpt = torch.load(pthpath, map_location=lambda storage, loc: storage)
    model_ft.load_state_dict(ckpt)
    model_ft.eval()
    # Move the model once, to the same `device` the input batches are sent to.
    # A further .cuda() call would target the *current* CUDA device, which
    # need not be `device`.
    model_ft = model_ft.to(device)
    path = '/home/haitaizeng/stanforf/ctumb_zht/TFP/pytorch/Feature/MIT67'
    Feature_extractor(models=model_ft, savepath=path)
# Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。 (All rights reserved.)