当前位置:   article > 正文

关键点检测——直接回归法_关键点回归

关键点回归

一、数据集格式

 二、解析xml文件,生成data_center.txt

  1. from PIL import Image
  2. import math,os
  3. from xml.etree import ElementTree as ET
  4. def keep_image_size_open(path, size=(256, 256)):
  5. img = Image.open(path)
  6. temp = max(img.size)
  7. mask = Image.new('RGB', (temp, temp), (0, 0, 0))
  8. mask.paste(img, (0, 0))
  9. mask = mask.resize(size)
  10. return mask
  11. def make_data_center_txt(xml_dir):
  12. with open('data_center.txt', 'a') as f:
  13. f.truncate(0)
  14. path=r'data/images'
  15. xml_names = os.listdir(xml_dir)
  16. for xml in xml_names:
  17. xml_path = os.path.join(xml_dir, xml)
  18. in_file = open(xml_path)
  19. tree = ET.parse(in_file)
  20. root = tree.getroot()
  21. image_path = root.find('path')
  22. polygon = root.find('outputs/object/item/polygon')
  23. data = []
  24. c_data = []
  25. data_str = ''
  26. print(xml)
  27. for i in polygon:
  28. data.append(int(i.text))
  29. data_str = data_str + ' ' + str(i.text)
  30. for i in range(0, len(data), 2):
  31. c_data.append((data[i], data[i + 1]))
  32. data_str = os.path.join(path,image_path.text.split('\\')[-1]) +data_str
  33. f.write(data_str + '\n')
  34. if __name__ == '__main__':
  35. make_data_center_txt('data/xml')

 三、加载数据集

  1. import torch
  2. from torch.utils.data import Dataset
  3. from torchvision import transforms
  4. from PIL import Image
  5. tf = transforms.Compose([ #标准化处理
  6. transforms.ToTensor()
  7. ])
  8. class MyDataset(Dataset):
  9. def __init__(self,root): #传入路径
  10. f=open(root,'r')
  11. self.dataset=f.readlines() #读所有行
  12. def __len__(self):
  13. return len(self.dataset) #返回数据集长度
  14. def __getitem__(self, index):
  15. data=self.dataset[index] #取当前数据
  16. img_path=data.split(' ')[0] #以空格划分,并取出文件名,即data/images\0.png
  17. img_data=Image.open(img_path) #打开图片
  18. # points = data.split(' ')[1:-2] # 取出后面5个点的x,y坐标,-2是取不到的
  19. points=data.split(' ')[1:] #取出后面5个点的x,y坐标
  20. # print(img_data, points)
  21. points = [int(points[0])/774, int(points[1])/434, int(points[2])/774, int(points[3])/434, int(points[4])/774, int(points[5])/434]
  22. # points=[int(i)/100 for i in points] #图像宽高为100,int(i)/100进行归一化
  23. # print(img_data, points)
  24. return tf(img_data),torch.Tensor(points) #将img_data标准化,将points转化为tensor格式
  25. if __name__ == '__main__':
  26. data=MyDataset('data_center.txt')
  27. for i in data:
  28. print(i[0].shape)
  29. print(i[1].shape)

四、构建网络

  1. import torch
  2. from torchvision import models
  3. from torch import nn
  4. class Net(nn.Module):
  5. def __init__(self):
  6. super(Net, self).__init__()
  7. self.layer=nn.Sequential( #用resnet50模型
  8. models.resnet50(pretrained=True)
  9. )
  10. #全连接层的输出要改为自己对应的输出,将1000分类通过全连接层变为6分类
  11. self.out=nn.Linear(1000,6)
  12. def forward(self,x):
  13. return self.out(self.layer(x)) #将输入x经过resnet50以及全连接层Linear
  14. if __name__ == '__main__':
  15. net=Net()
  16. x=torch.randn(1,3,100,100)
  17. print(net(x).shape)

五、开始训练

  1. import os
  2. from torch import nn,optim
  3. import torch
  4. from dataset import *
  5. from net import *
  6. from torch.utils.data import DataLoader
  7. if __name__ == '__main__':
  8. device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  9. net=Net().to(device) #实例化网络并指认到设备上
  10. weights='params/net.pth'
  11. if os.path.exists(weights): #如果有初始权值就加载
  12. net.load_state_dict(torch.load(weights)) #加载权重
  13. print('loading successfully')
  14. opt=optim.Adam(net.parameters()) #指定优化器并传入参数
  15. loss_fun=nn.MSELoss() #定义损失函数
  16. dataset=MyDataset('data_center.txt') #实例化数据集
  17. data_loader=DataLoader(dataset,batch_size=2,shuffle=True) #加载数据集
  18. epoch = 1
  19. while True:
  20. for i,(image,label) in enumerate(data_loader): #用枚举的方式遍历数据集
  21. image,label=image.to(device),label.to(device) #将图片和标签指认到设备上
  22. # print(image.shape, label.shape)
  23. out=net(image) #将图片输入网络
  24. train_loss=loss_fun(out,label) #预测值和真是标签做损失
  25. print(f'{epoch}-{i}-train_loss:{train_loss.item()}') #打印当前轮次当前批次的训练损失
  26. opt.zero_grad() #梯度清零
  27. train_loss.backward() #反向传播
  28. opt.step() #更新梯度
  29. if epoch % 10 == 0: #每10轮保存一次权重
  30. torch.save(net.state_dict(),f'params/net.pth') #保存参数
  31. print('save successfully')
  32. epoch += 1

六、利用训练好的权重进行预测

  1. import os
  2. import torch
  3. from PIL import Image,ImageDraw
  4. from dataset import *
  5. from net import * #import * 代表导入所有
  6. path='test_image'
  7. net=Net() #实例化网络
  8. net.load_state_dict(torch.load('params/net.pth')) #加载训练好的权重
  9. net.eval() #测试模式
  10. for i in os.listdir(path):
  11. img=Image.open(os.path.join(path,i))
  12. draw=ImageDraw.Draw(img) #创建画板
  13. img_data=tf(img)
  14. img_data=torch.unsqueeze(img_data,dim=0)
  15. out=net(img_data)
  16. # print(out, out.shape)
  17. out=(out[0]).tolist() #取第0个,并由tenser转化成列表形式
  18. out = [out[0]*774,out[1]*434,out[2]*774,out[3]*434,out[4]*774,out[5]*434]
  19. # print(out)
  20. for j in range(0,len(out),2):
  21. draw.ellipse((out[j]-2,out[j+1]-2,out[j]+2,out[j+1]+2),(255,0,0)) #画半径为2的圆
  22. img.show()

七、制作数据集 

精灵标注助手->选择多边形框标注->标注完一张Ctrl+S保存->导出XML格式

reference

>>>>>来自B站大佬

【深度学习关键点回归(直接回归法&heatmap热力图法)】 https://www.bilibili.com/video/BV1sS4y197J1/?p=2&share_source=copy_web&vd_source=95705b32f23f70b32dfa1721628d5874

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/97813
推荐阅读
相关标签
  

闽ICP备14008679号