赞
踩
所用代码地址:crnn.caffe
这里需要将图像数据统一转换到12832(宽度高度)上来,当然也可以更改为其它类型的长度,这里只是使用了该尺寸。
Label数据是与图像数据对应的数据,其中包含了图像中的具体字符数据。在制作label数据的时候需要将不同的字符转换到不同的数字标号上来,这里需要将字符映射表的最后一位设置为“_blank_”。这里还需要注意的是Label的长度应该和最大label的长度对应否则会超过label的表示范围。
这里首先假设需要的分类数目为N,再加上背景那么就是N+1类。所以就要在CRNN中首先就要修改的便是:
layer { name: "fc1" type: "InnerProduct" bottom: "drop1" top: "fc1" param { lr_mult: 1 decay_mult: 1 } param { lr_mult: 2 decay_mult: 0 } inner_product_param { num_output: N+1 axis: 2 weight_filler { type: "xavier" } bias_filler { type: "constant" value: 0 } } }
还有
layer {
name: "ctc_loss"
type: "CtcLoss"
bottom: "fc1"
bottom: "label"
top: "ctc_loss"
loss_weight: 1.0
ctc_loss_param {
blank_label: N
alphabet_size: N+1
time_step: 32
}
}
还有精度层
layer {
name: "accuracy"
type: "LabelsequenceAccuracy"
bottom: "premuted_fc"
bottom: "label"
top: "accuracy"
labelsequence_accuracy_param {
blank_label: N
}
}
其中还需要修改的便是time_step的值,也就是CRNN中的帧率。将time_step都设置为32,因为我们的输入图像宽度为128。
修改reshape层
layer {
name: "reshape"
type: "Reshape"
bottom: "conv6"
top: "reshape"
reshape_param {
shape {
#n*c*(w*h)
dim: 64
dim: 512
dim: 32
}
}
}
这里的batch_size为64,所以也需要修改前面的data层。
训练用的代码
# -*- coding=utf-8 -*- import numpy as np import sys sys.path.append('~/Desktop/crnn.caffe/python') import caffe # training caffe.set_device(2) caffe.set_mode_gpu() solver = caffe.SGDSolver('solver.prototxt') net = solver.net print net.blobs['label'].data[0].shape iter_nums = 100000 for _ in range(iter_nums): solver.step(1)
在这里需要注意solver参数的选择,否则会出现不收敛的情况-_-||,这也算是一个坑吧…
deploy文件就自己生成了哈,也记得修改其中的batch_size…
# -*- coding=utf-8 -*- import sys sys.path.append('~/Desktop/crnn.caffe/python') import caffe from PIL import Image import numpy as np model_file = './snapshot/_iter_60000.caffemodel' deploy_file = 'crnn_deploy.prototxt' test_img = '2.jpg' # set device caffe.set_device(2) caffe.set_mode_gpu() # load model net = caffe.Net(deploy_file, model_file, caffe.TEST) # load test img img = Image.open(test_img) img = img.resize((128, 32), Image.ANTIALIAS) in_ = np.array(img, dtype=np.float32) in_ = in_[:,:,::-1] in_ = in_.transpose((2,0,1)) # 执行上面设置的图片预处理操作,并将图片载入到blob中 # shape for input (data blob is N x C x H x W), set data net.blobs['data'].reshape(1, *in_.shape) net.blobs['data'].data[...] = in_ # run net net.forward() # get result res = net.blobs['probs'].data print('result shape is:', res.shape) # 取出标签文档 char_set = [] with open('label.txt', 'r') as f: line = f.readline() while line: line = line.strip('\n\r') # print(line) char_set.append(str(line)) line = f.readline() # 取出最多可能的label标签 for i in range(32): data = res[i, :, :] index = np.argmax(data) #print(index, data[0, index]) print(char_set[index])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。