赞
踩
Zero-DCE 论文作者的项目主页:https://li-chongyi.github.io/Proj_Zero-DCE.html
Zero-DCE（Zero-Reference Deep Curve Estimation）把微光增强表述为深度曲线估计问题：用一个轻量的深度卷积神经网络（DCE-Net）估计光照增强曲线的参数，再用这些曲线对微光图像进行增强，从而调整在不同光照条件下采集的光照不均匀图像和弱光图像。
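网络最后一层输出 24 个通道，即 8 组参数图 $\mathcal{A}_1,\dots,\mathcal{A}_8$，每组 3 个通道，分别作用于输入图像的 3 个颜色通道。增强时对每个通道迭代应用曲线 8 次，每次迭代以上一次的输出和本次的参数图为输入：$LE_n = LE_{n-1} + \mathcal{A}_n\left(LE_{n-1}^2 - LE_{n-1}\right)$，其中 $LE_0$ 为输入图像（对应下方代码中的 `x = x + r*(torch.pow(x,2)-x)`）。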
pytorch实现:https://github.com/Li-Chongyi/Zero-DCE
class enhance_net_nopool(nn.Module):
    """DCE-Net from Zero-DCE.

    Seven 3x3 convolutions (with symmetric skip concatenations) predict a
    24-channel map of curve parameters in [-1, 1] (tanh).  The map is split
    into 8 groups of 3 channels, and each group drives one iteration of the
    light-enhancement curve  x <- x + r * (x^2 - x)  applied to the image.

    forward(x) returns a tuple:
      * enhance_image_1 -- the image after the first 4 curve iterations,
      * enhance_image   -- the image after all 8 curve iterations,
      * r               -- the 24 curve-parameter channels (8 groups x 3).
    """

    def __init__(self):
        super(enhance_net_nopool, self).__init__()

        self.relu = nn.ReLU(inplace=True)

        number_f = 32
        self.e_conv1 = nn.Conv2d(3, number_f, 3, 1, 1, bias=True)
        self.e_conv2 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        self.e_conv3 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        self.e_conv4 = nn.Conv2d(number_f, number_f, 3, 1, 1, bias=True)
        # e_conv5..7 consume a channel concatenation of two earlier feature maps.
        self.e_conv5 = nn.Conv2d(number_f * 2, number_f, 3, 1, 1, bias=True)
        self.e_conv6 = nn.Conv2d(number_f * 2, number_f, 3, 1, 1, bias=True)
        self.e_conv7 = nn.Conv2d(number_f * 2, 24, 3, 1, 1, bias=True)

        # Not used by forward(); kept so existing attribute access keeps working.
        # Neither layer has parameters, so the state_dict is unaffected.
        self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x):
        x1 = self.relu(self.e_conv1(x))
        x2 = self.relu(self.e_conv2(x1))
        x3 = self.relu(self.e_conv3(x2))
        x4 = self.relu(self.e_conv4(x3))
        x5 = self.relu(self.e_conv5(torch.cat([x3, x4], 1)))
        x6 = self.relu(self.e_conv6(torch.cat([x2, x5], 1)))
        # torch.tanh replaces the deprecated F.tanh (identical values, no warning).
        x_r = torch.tanh(self.e_conv7(torch.cat([x1, x6], 1)))

        # Eight non-overlapping 3-channel groups: group i is channels [3i, 3i+3).
        curve_maps = torch.split(x_r, 3, dim=1)

        enhance_image_1 = None
        for i, r_i in enumerate(curve_maps):
            x = x + r_i * (torch.pow(x, 2) - x)
            if i == 3:
                # Intermediate result after 4 iterations, also returned.
                enhance_image_1 = x
        enhance_image = x

        r = torch.cat(curve_maps, 1)
        return enhance_image_1, enhance_image, r
Zero-DCE是一个轻量网络,使用到的算子也不复杂,算子都为tensorrt已经支持的操作,我们可以直接调用tensorrt的操作来部署,不需要重新自定义算子。
auto pow = network->addElementWise(*data, *data, ElementWiseOperation::kPROD); auto sub = network->addElementWise(*pow->getOutput(0), *data, ElementWiseOperation::kSUB); //e_conv1 IConvolutionLayer* conv1 = network->addConvolutionNd(*data, 32, DimsHW{ 3, 3 }, weightMap["e_conv1.weight"], weightMap["e_conv1.bias"]); assert(conv1); conv1->setPaddingNd(DimsHW{ 1, 1 }); conv1->setStrideNd(DimsHW{ 1,1 }); IActivationLayer* relu1 = network->addActivation(*conv1->getOutput(0), ActivationType::kRELU); assert(relu1); //e_conv2 IConvolutionLayer* conv2 = network->addConvolutionNd(*relu1->getOutput(0), 32, DimsHW{ 3, 3 }, weightMap["e_conv2.weight"], weightMap["e_conv2.bias"]); conv2->setPaddingNd(DimsHW{ 1, 1 }); conv2->setStrideNd(DimsHW{ 1,1 }); IActivationLayer* relu2 = network->addActivation(*conv2->getOutput(0), ActivationType::kRELU); //e_conv3 IConvolutionLayer* conv3 = network->addConvolutionNd(*relu2->getOutput(0), 32, DimsHW{ 3, 3 }, weightMap["e_conv3.weight"], weightMap["e_conv3.bias"]); conv3->setPaddingNd(DimsHW{ 1, 1 }); conv3->setStrideNd(DimsHW{ 1,1 }); IActivationLayer* relu3 = network->addActivation(*conv3->getOutput(0), ActivationType::kRELU); //e_conv4 IConvolutionLayer* conv4 = network->addConvolutionNd(*relu3->getOutput(0), 32, DimsHW{ 3, 3 }, weightMap["e_conv4.weight"], weightMap["e_conv4.bias"]); conv4->setPaddingNd(DimsHW{ 1, 1 }); conv4->setStrideNd(DimsHW{ 1,1 }); IActivationLayer* relu4 = network->addActivation(*conv4->getOutput(0), ActivationType::kRELU); //concat relu3 and relu4 ITensor* inputTensors34[] = { relu3->getOutput(0), relu4->getOutput(0) }; auto cat34 = network->addConcatenation(inputTensors34, 2); //e_conv5 IConvolutionLayer* conv5 = network->addConvolutionNd(*cat34->getOutput(0), 32, DimsHW{ 3, 3 }, weightMap["e_conv5.weight"], weightMap["e_conv5.bias"]); conv5->setPaddingNd(DimsHW{ 1, 1 }); conv5->setStrideNd(DimsHW{ 1,1 }); IActivationLayer* relu5 = network->addActivation(*conv5->getOutput(0), ActivationType::kRELU); //concat 
relu2 and relu5 ITensor* inputTensors25[] = { relu2->getOutput(0), relu5->getOutput(0) }; auto cat25 = network->addConcatenation(inputTensors25, 2); //e_conv6 IConvolutionLayer* conv6 = network->addConvolutionNd(*cat25->getOutput(0), 32, DimsHW{ 3, 3 }, weightMap["e_conv6.weight"], weightMap["e_conv6.bias"]); conv6->setPaddingNd(DimsHW{ 1, 1 }); conv6->setStrideNd(DimsHW{ 1,1 }); IActivationLayer* relu6 = network->addActivation(*conv6->getOutput(0), ActivationType::kRELU); //concat relu1 and relu6 ITensor* inputTensors16[] = { relu1->getOutput(0), relu6->getOutput(0) }; auto cat16 = network->addConcatenation(inputTensors16, 2); //e_conv7 IConvolutionLayer* conv7 = network->addConvolutionNd(*cat16->getOutput(0), 24, DimsHW{ 3, 3 }, weightMap["e_conv7.weight"], weightMap["e_conv7.bias"]); conv7->setPaddingNd(DimsHW{ 1, 1 }); conv7->setStrideNd(DimsHW{ 1,1 }); IActivationLayer* relu7 = network->addActivation(*conv7->getOutput(0), ActivationType::kTANH); //addSlice Dims d = relu7->getOutput(0)->getDimensions(); ISliceLayer* slice0 = network->addSlice(*relu7->getOutput(0), Dims3{ 0,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 }); ISliceLayer* slice1 = network->addSlice(*relu7->getOutput(0), Dims3{ 1,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 }); ISliceLayer* slice2 = network->addSlice(*relu7->getOutput(0), Dims3{ 2,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 }); ISliceLayer* slice3 = network->addSlice(*relu7->getOutput(0), Dims3{ 3,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 }); ISliceLayer* slice4 = network->addSlice(*relu7->getOutput(0), Dims3{ 4,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 }); ISliceLayer* slice5 = network->addSlice(*relu7->getOutput(0), Dims3{ 5,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 }); ISliceLayer* slice6 = network->addSlice(*relu7->getOutput(0), Dims3{ 6,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 }); ISliceLayer* slice7 = network->addSlice(*relu7->getOutput(0), 
Dims3{ 7,0,0 }, Dims3{ d.d[0] / 8,d.d[1],d.d[2] }, Dims3{ 1,1,1 }); //split auto mul = network->addElementWise(*slice0->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD); auto add = network->addElementWise(*data, *mul->getOutput(0), ElementWiseOperation::kSUM); pow = network->addElementWise(*add->getOutput(0), *add->getOutput(0), ElementWiseOperation::kPROD); sub = network->addElementWise(*pow->getOutput(0), *add->getOutput(0), ElementWiseOperation::kSUB); mul = network->addElementWise(*slice1->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD); add = network->addElementWise(*add->getOutput(0), *mul->getOutput(0), ElementWiseOperation::kSUM); pow = pow = network->addElementWise(*add->getOutput(0), *add->getOutput(0), ElementWiseOperation::kPROD); sub = network->addElementWise(*pow->getOutput(0), *add->getOutput(0), ElementWiseOperation::kSUB); mul = network->addElementWise(*slice2->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD); add = network->addElementWise(*add->getOutput(0), *mul->getOutput(0), ElementWiseOperation::kSUM); pow = network->addElementWise(*add->getOutput(0), *add->getOutput(0), ElementWiseOperation::kPROD); sub = network->addElementWise(*pow->getOutput(0), *add->getOutput(0), ElementWiseOperation::kSUB); mul = network->addElementWise(*slice3->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD); add = network->addElementWise(*add->getOutput(0), *mul->getOutput(0), ElementWiseOperation::kSUM); pow = network->addElementWise(*add->getOutput(0), *add->getOutput(0), ElementWiseOperation::kPROD); sub = network->addElementWise(*pow->getOutput(0), *add->getOutput(0), ElementWiseOperation::kSUB); mul = network->addElementWise(*slice4->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD); add = network->addElementWise(*add->getOutput(0), *mul->getOutput(0), ElementWiseOperation::kSUM); pow = network->addElementWise(*add->getOutput(0), *add->getOutput(0), ElementWiseOperation::kPROD); sub = 
network->addElementWise(*pow->getOutput(0), *add->getOutput(0), ElementWiseOperation::kSUB); mul = network->addElementWise(*slice5->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD); add = network->addElementWise(*add->getOutput(0), *mul->getOutput(0), ElementWiseOperation::kSUM); pow = network->addElementWise(*add->getOutput(0), *add->getOutput(0), ElementWiseOperation::kPROD); sub = network->addElementWise(*pow->getOutput(0), *add->getOutput(0), ElementWiseOperation::kSUB); mul = network->addElementWise(*slice6->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD); add = network->addElementWise(*add->getOutput(0), *mul->getOutput(0), ElementWiseOperation::kSUM); pow = network->addElementWise(*add->getOutput(0), *add->getOutput(0), ElementWiseOperation::kPROD); sub = network->addElementWise(*pow->getOutput(0), *add->getOutput(0), ElementWiseOperation::kSUB); mul = network->addElementWise(*slice7->getOutput(0), *sub->getOutput(0), ElementWiseOperation::kPROD); add = network->addElementWise(*add->getOutput(0), *mul->getOutput(0), ElementWiseOperation::kSUM); add->getOutput(0)->setName(OUTPUT_BLOB_NAME);
Zero-DCE使用的tensorrt算子:
tensorrt | pytorch |
---|---|
addElementWise(ElementWiseOperation::kSUM) | + |
addElementWise(ElementWiseOperation::kSUB) | - |
addElementWise(ElementWiseOperation::kPROD) | torch.pow(x, 2)（用 x 与自身逐元素相乘实现）以及逐元素乘法 `r * (...)` |
addConvolutionNd | Conv2d |
addActivation(ActivationType::kRELU) | ReLU |
addActivation(ActivationType::kTANH) | tanh |
addConcatenation | torch.cat |
addSlice（第 i 个切片从第 3i 个通道开始，取 3 个通道，切片互不重叠） | torch.split(x_r, 3, dim=1) |
tensorrtx部署步骤参考:https://github.com/wang-xinyu/tensorrtx
将Zero-DCE模型转换为wts权重文件
python gen_wts.py -w zero_dce.pt -o zero_dce.wts
gen_wts.py
"""Convert a Zero-DCE PyTorch checkpoint into the tensorrtx .wts text format.

Usage:
    python gen_wts.py -w zero_dce.pt -o zero_dce.wts

The .wts file starts with the number of tensors. It then has one line per
tensor: "<name> <count>" followed by <count> big-endian float32 values
written as hex. That is the layout the C++ loadWeights() reader parses.
"""
import sys
import argparse
import os
import struct

import torch


def parse_args():
    """Parse command-line arguments and return (weights_path, wts_path)."""
    parser = argparse.ArgumentParser(description='Convert .pt file to .wts')
    parser.add_argument('-w', '--weights', required=True,
                        help='Input weights (.pt) file path (required)')
    parser.add_argument('-o', '--output', help='Output (.wts) file path (optional)')
    args = parser.parse_args()
    if not os.path.isfile(args.weights):
        raise SystemExit('Invalid input file')
    if not args.output:
        args.output = os.path.splitext(args.weights)[0] + '.wts'
    elif os.path.isdir(args.output):
        args.output = os.path.join(
            args.output, os.path.splitext(os.path.basename(args.weights))[0] + '.wts')
    return args.weights, args.output


def load_state_dict(pt_file):
    """Load ``pt_file`` on CPU and return a flat {name: tensor} mapping.

    Accepted layouts:
      * a plain state_dict, e.g. ``torch.save(net.state_dict(), path)``;
      * a dict wrapping the weights under 'model' or 'state_dict'
        (as a state_dict or as an nn.Module);
      * a pickled nn.Module.

    If every key carries the 'module.' prefix added by nn.DataParallel, the
    prefix is stripped. The keys then match the names the TensorRT builder
    looks up (e.g. 'e_conv1.weight').

    Raises SystemExit for any other layout.

    NOTE: torch.load unpickles the file; only convert checkpoints you trust.
    """
    ckpt = torch.load(pt_file, map_location='cpu')
    if isinstance(ckpt, dict):
        for key in ('model', 'state_dict'):
            if key in ckpt:
                ckpt = ckpt[key]
                break
    if isinstance(ckpt, torch.nn.Module):
        ckpt = ckpt.state_dict()
    if not isinstance(ckpt, dict):
        raise SystemExit('Unsupported checkpoint format: {}'.format(type(ckpt).__name__))
    prefix = 'module.'
    if ckpt and all(k.startswith(prefix) for k in ckpt):
        ckpt = {k[len(prefix):]: v for k, v in ckpt.items()}
    return ckpt


def write_wts(state_dict, wts_file):
    """Write ``state_dict`` to ``wts_file`` in the .wts text format.

    Every tensor is flattened and converted to float32 before writing.
    """
    with open(wts_file, 'w') as f:
        f.write('{}\n'.format(len(state_dict)))
        for k, v in state_dict.items():
            vr = v.reshape(-1).float().cpu().numpy()
            f.write('{} {} '.format(k, len(vr)))
            f.write(''.join(' ' + struct.pack('>f', float(vv)).hex() for vv in vr))
            f.write('\n')


def main():
    pt_file, wts_file = parse_args()
    write_wts(load_state_dict(pt_file), wts_file)


if __name__ == '__main__':
    main()
读取 zero_dce.wts 中的权重信息，存入构建网络时使用的 weightMap:
std::map<std::string, Weights> loadWeights(const std::string file) { std::cout << "Loading weights: " << file << std::endl; std::map<std::string, Weights> weightMap; // Open weights file std::ifstream input(file); assert(input.is_open() && "Unable to load weight file. please check if the .wts file path is right!!!!!!"); // Read number of weight blobs int32_t count; input >> count; assert(count > 0 && "Invalid weight map file."); while (count--) { Weights wt{ DataType::kFLOAT, nullptr, 0 }; uint32_t size; // Read name and type of blob std::string name; input >> name >> std::dec >> size; wt.type = DataType::kFLOAT; // Load blob uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size)); for (uint32_t x = 0, y = size; x < y; ++x) { input >> std::hex >> val[x]; } wt.values = val; wt.count = size; weightMap[name] = wt; } return weightMap; }
其余tensorrt初始化部分与tensorrtx其余项目类似不再赘述了。
之前部署 YOLOv4 时，为实现 scatterND 算子花了不小的功夫（结果实现后发现 TensorRT 22.01 开始支持 scatter plugin…）。只要 TensorRT 已经支持所需的算子，用 TensorRT 部署还是很方便的。切记根据输入和输出的大小修改开辟的缓冲区空间（踩过坑…）。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。