当前位置:   article > 正文

TransUNet:使用自己的数据集完成训练后进行测试_transunet跑自己的数据集测试

transunet跑自己的数据集测试

训练完成后,会保存模型 ,会显示权重保存的路径snapshot_path

 1.修改test.py文件,需要与train.py中设置参数保持一致

test_save_dir为测试结果的路径

同样的与train.py中相同,配置自己数据集的相关信息

接着添加一下权重路径

 2.修改utils.py中的test_single_volume函数,可以根据自己的分割类别修改,自定义不同颜色所代表的分割种类

  1. def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):
  2. image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
  3. _,x, y = image.shape
  4. if x != patch_size[0] or y != patch_size[1]:
  5. image = zoom(image, (1,patch_size[0] / x, patch_size[1] / y), order=3)
  6. input = torch.from_numpy(image).unsqueeze(0).float() #input = torch.from_numpy(image).unsqueeze(0).float().cuda()
  7. net.eval()
  8. with torch.no_grad():
  9. out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
  10. out = out.cpu().detach().numpy()
  11. if x != patch_size[0] or y != patch_size[1]:
  12. prediction = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
  13. else:
  14. prediction = out
  15. metric_list = []
  16. for i in range(1, classes):
  17. metric_list.append(calculate_metric_percase(prediction == i, label == i))
  18. if test_save_path is not None:
  19. a1 = copy.deepcopy(prediction)
  20. a2 = copy.deepcopy(prediction)
  21. a3 = copy.deepcopy(prediction)
  22. ##表示有4个类别
  23. a1[a1 == 1] = 255
  24. a1[a1 == 2] = 0 #代表R通道中输出结果为2的赋值0
  25. a1[a1 == 3] = 255
  26. a1[a1 == 4] = 20
  27. a2[a2 == 1] = 255
  28. a2[a2 == 2] = 255 #代表G通道中输出结果为2的赋值255
  29. a2[a2 == 3] = 0
  30. a2[a2 == 4] = 10
  31. a3[a3 == 1] = 255
  32. a3[a3 == 2] = 77 #代表B通道中输出结果为2的赋值77 ;(0,255,77)对应就是绿色,类别2就是绿色
  33. a3[a3 == 3] = 0
  34. a3[a3 == 4] = 120
  35. a1 = Image.fromarray(np.uint8(a1)).convert('L') #array转换成image,Image.fromarray(np.uint8(img))
  36. a2 = Image.fromarray(np.uint8(a2)).convert('L')
  37. a3 = Image.fromarray(np.uint8(a3)).convert('L')
  38. prediction = Image.merge('RGB', [a1, a2, a3])
  39. prediction.save(test_save_path+'/'+case+'.png')
  40. return metric_list

 3.运行test.py,测试结果会以.png格式保存在前面所设置的test_save_dir中

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号