当前位置:   article > 正文

踩坑:Libtorch实现U-net在C++上对BraTS2019数据集的分割任务部署_brats 2019 数据集

brats 2019 数据集

根据项目组要求,最近用Pytorch完成了U-net训练BraTS2019数据的任务,输出模型并在C++上重现部署,简单说说我遇到的坑,可按目录浏览:

- 模型转换问题

步骤1:先在pytorch验证你的模型(.pth)是否能够重现结果,一般来讲都可以成功复原的,这一步很重要,在C++上重现时最好有对照,以免思路混乱,这里贴出我的测试一张图片的部分代码:

def load_model():
    net = UNet2D(1, 5, 64 ).to(torch.device('cpu')) #cpu测试
    state_dict = torch.load(saved_model_path,map_location='cpu')
    net.load_state_dict(state_dict,strict=True) #记得加True
    return net
    
if __name__ == "__main__":        
    np.set_printoptions(threshold=np.inf)  #为了看完整的结果
    net = load_model()
    
    img = sitk.ReadImage('BraTS19_2013_2_1_flair.nii.gz')#需要安装SimpleITK这个包来读nii
    nda = sitk.GetArrayFromImage(img)  
    test_data = np.asarray(nda[110])
    test_data = norm_vol(test_data)#这里的归一化自己写的,按照自己需求来确定是否要做归一化
	#转成tensor
    test_data = torch.from_numpy(test_data)
    test_data = torch.tensor(test_data,dtype=torch.float32)
    #按照网络输入需求拓展维度
    test_data=torch.unsqueeze(test_data,0)
    test_data=torch.unsqueeze(test_data,1)
	#预测阶段
    with torch.no_grad():
        net.eval()
        predict = net(test_data)
        predict = F.softmax(predict,dim=1)
        predict = torch.max(predict,dim=1)[1]
        predict = predict.squeeze().long().data
    
    io.imwrite('result.jpg', predict)#imageio工具包
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

步骤2:将pth模型转至pt文件,后续用于C++预测,这个过程官方的例子里也有,这是我按照自己需求来改的代码:

import torch
import torch.nn as nn
from unet2d import UNet2D
saved_model_path = 'best.pth'
net = UNet2D(1, 5, 64 ).to(torch.device('cpu')) 
state_dict = torch.load(saved_model_path,map_location='cpu')
net.load_state_dict(state_dict,strict=True)#同样记得TRUE
net.eval()#重要!
example = torch.rand(1, 1, 240, 240).float()
traced = torch.jit.trace(net, example)
traced.save('best.pt')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

其他:可以按压缩包的打开方式来打开pt文件来查看模型追踪是否正确:
在这里插入图片描述
其中这个文件可以查看整体的追踪结果。
在这里插入图片描述

- tensor是nan的值

libtorch的安装以及cmake运行在我之前的博客里有提到,可以翻一下对号入座来安装,网上教程也很多,这里不加赘述,我pytorch和libtorch都是1.4.0。
部分C++部署代码如下:

int main() {

	auto tensor1 = torch::empty(1 * 1 * 240 * 240);
	float* data1 = tensor1.data<float>();
	for (int i = 0; i < 1; i++)
	{
		for (int j = 0; j < 1; j++)
		{
			for (int x = 0; x < 240; x++)
			{
				for (int y = 0; y < 240; y++)
				{
					*data1++ = (itk[110][x][y]-min*1.0)/(max.0-1.0);
				}
			}
		}
	}
	auto t = tensor1.resize_({ 1,1,240,240 });
	t=t.div(255);
	torch::jit::script::Module module = torch::jit::load(model_path,torch::kCPU);  //load model
 //   init model
	module.eval();
	cout << "model input is ok\n";
	vector<torch::jit::IValue> inputs;  //def an input

	inputs.emplace_back(t.toType(torch::kFloat32));
	float start = getTickCount(); 
	auto result = module.forward( inputs ).toTensor();  //前向传播获取结果 = net(image)
	float end = getTickCount();
	float last = end - start;
	cout << "time consume: " << (last / getTickFrequency()) << endl;
	//rescalling input element into range(0,1) and sum to 1; output size is same to input
	auto prob = result.softmax(1);//torch::nn::functional::softmax(result, 1);
	auto prediction = prob.max(1);//tuple类型
	inputs.pop_back();
	std::tuple_element<1, decltype(prediction)>::type cnt = std::get<1>(prediction);
	std::cout << "cnt = " << cnt.sizes() << std::endl;// 1 240 240
	
	cnt=cnt.squeeze().data;//这句代码有问题,最好按照需求来确认要不要写
	
	cout << "result sizes:" << cnt.sizes() << endl; //240 240

	Mat m(cnt.size(0), cnt.size(1), CV_32FC1, cnt.data());
	imwrite("res.jpg", m);

	return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

从上述代码可以看出流程:
读入数据->载入模型->将数据整理成tensor(这里我没用torch::from_blob)->预测->结果整理->输出
大体流程与上面的pytorch流程一致。
但在前向传播这一步的时候,结果输出nan值:
在这里插入图片描述
很明显是不正确的,问题是出在我加了一句t=t.div(255);,导致我数据类型出错,问题发生点可以看我在pytorch论坛里的提问:
Pytorch Forums–Having problems in segmenting image when using libtorch

- 输出结果只有部分的一块

结果输出有很多情况,就我而言,遇到的问题是输出全黑的图片和只有部分分割结果的图片:
在这里插入图片描述
这种情况最好检查输入tensor的值,是否与pytorch上输出的一致
在这里插入图片描述
这种情况最好确认是否需要加==cnt.squeeze.data();==这一句代码,我删除以后就没问题了。
在这里插入图片描述
因为我只训练了50个epoch,所以结果很差,将就着先用。

先写到这,想到再补充。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/97978
推荐阅读
相关标签
  

闽ICP备14008679号