当前位置:   article > 正文

DPM(Defomable Parts Model) 源码分析_deformable part model 源码

deformable part model 源码

DPM(Deformable Parts Model)--原理(一)

原文:http://blog.csdn.net/ttransposition/article/details/12966521

DPM(Deformable Parts Model)

Reference:

Object detection with discriminatively trained partbased models. IEEE Trans. PAMI, 32(9):1627–1645, 2010.

"Support Vector Machines for Multiple-Instance Learning,"Proc. Advances in Neural Information Processing Systems,2003.

作者主页:http://www.cs.berkeley.edu/~rbg/latent/index.html

  1. 大体思路

DPM是一个非常成功的目标检测算法,连续获得VOC(Visual Object Class)07,08,09年的检测冠军。目前已成为众多分类器、分割、人体姿态和行为分类的重要部分。2010年Pedro Felzenszwalb被VOC授予"终身成就奖"DPM可以看做是HOG(Histogrrams of Oriented Gradients)的扩展,大体思路与HOG一致。先计算梯度方向直方图,然后用SVM(Surpport Vector Machine )训练得到物体的梯度模型(Model)。有了这样的模板就可以直接用来分类了,简单理解就是模型和目标匹配。DPM只是在模型上做了很多改进工作。

上图是HOG论文中训练出来的人形模型。它是单模型,对直立的正面和背面人检测效果很好,较以前取得了重大的突破。也是目前为止最好的的特征(最近被CVPR20 13年的一篇论文 《Histograms of Sparse Codes for Object Detection》 超过了)。但是, 如果是侧面呢?所以自然我们会想到用多模型来做。DPM就使用了2个模型,主页上最新版本Versio5的程序使用了12个模型。

                                                                                                                                             

                                                                 

上图就是自行车的模型,左图为侧面看,右图为从正前方看。好吧,我承认已经面目全非了,这只是粗糙版本。训练的时候只是给了一堆自行车的照片,没有标注是属于component 1,还是component 2.直接按照边界的长宽比,分为2半训练。这样肯定会有很多很多分错了的情况,训练出来的自然就失真了。不过没关系,论文里面只是把这两个Model当做初始值。重点就是作者用了多模型。

 

上图右边的两个模型各使用了6个子模型,白色矩形框出来的区域就是一个子模型。基本上见过自行车的人都知道这是自行车。之所以会比左边好辨识,是因为分错component类别的问题基本上解决了,还有就是图像分辨率是左边的两倍,这个就不细说,看论文。

有了多模型就能解决视角的问题了,还有个严重的问题,动物是动的,就算是没有生命的车也有很多款式,单单用一个Model,如果动物动一下,比如美女搔首弄姿,那模型和这个美女的匹配程度就低了很多。也就是说,我们的模型太死板了,不能适应物体的运动,特别是非刚性物体的运动。自然我们又能想到添加子模型,比如给手一个子模型,当手移动时,子模型能够检测到手的位置。把子模型和主模型的匹配程度综合起来,最简单的就是相加,那模型匹配程度不就提高了吗?思路很简单吧!还有个小细节,子模型肯定不能离主模型太远了,试想下假如手到身体的位置有两倍身高那么远,那这还是人吗?也许这是个检测是不是鬼的好主意。所以我们加入子模型与主模型的位置偏移作为Cost,也就是说综合得分要减去偏移Cost.本质上就是使用子模型和主模型的空间先验知识。

 

好了,终于来了一张合影。最右边就是我们的偏移Cost,圆圈中心自然就是子模型的理性位置,如果检测出来的子模型的位置恰好在此,那Cost就为0,在周边那就要减掉一定的值,偏离的越远减掉的值越大。

最后再理一下继承发展关系,HOG特征源自于SIFT,参见《Distinctive image features from scale-invariant Keypoints》。Part Model 早在1973年就被提出参见《The representation and matching of pictorial structures》(木有看……)。

另外HOG特征可以参考鄙人博客:Opencv HOG行人检测 源码分析SIFT特征本来也想写的但是,那时候懒,而且表述比较啰嗦,就参考一位跟我同一届的北大美女的系列博客吧。【OpenCV】SIFT原理与源码分析


总之,DPM的本质就是弹簧形变模型,参见 1973年的一篇论文  The representation and matching of pictorial structures


2.检测

检测过程比较简单:

综合得分:

是rootfilter (我前面称之为主模型)的得分,或者说是匹配程度,本质就是的卷积,后面的partfilter也是如此。中间是n个partfilter(前面称之为子模型)的得分。是为了component之间对齐而设的rootoffset. 为rootfilter的left-top位置在root feature map中的坐标,为第个partfilter映射到part feature map中的坐标。是因为part feature map的分辨率是root feature map的两倍,为相对于rootfilter left-top 的偏移。

 的得分如下:

上式是在patfilter理想位置,即anchor position的一定范围内,寻找一个综合匹配和形变最优的位置。为偏移向量,为偏移向量为偏移的Cost权值。比如即为最普遍的欧氏距离。这一步称为距离变换,即下图中的transformed response。这部分的主要程序有train.m、featpyramid.m、dt.cc.

3.训练

3.1多示例学习(Multiple-instance learning)

3.1.1 MI-SVM

一般机器学习算法,每一个训练样本都需要类别标号(对于二分类:1/-1)。实际上那样的数据其实已经经过了抽象,实际的数据要获得这样的标号还是很难,图像就是个典型。还有就是数据标记的工作量太大,我们想偷懒了,所以多只是给了正负样本集。负样本集里面的样本都是负的,但是正样本里面的样本不一定都是正的,但是至少有一个样本是正的。比如检测人的问题,一张天空的照片就可以是一个负样本集;一张某某自拍照就是一个正样本集(你可以在N个区域取N个样本,但是只有部分是有人的正样本)。这样正样本的类别就很不明确,传统的方法就没法训练。

疑问来了,图像的不是有标注吗?有标注就应该有类别标号啊?这是因为图片是人标的,数据量特大,难免会有些标的不够好,这就是所谓的弱监督集(weakly supervised set)。所以如果算法能够自动找出最优的位置,那分类器不就更精确吗? 标注位置不是很准确,这个例子不是很明显,还记得前面讲过的子模型的位置吗?比如自行车的车轮的位置,是完全没有位置标注的,只知道在bounding box区域附件有一个车轮。不知道精确位置,就没法提取样本。这种情况下,车轮会有很多个可能的位置,也就会形成一个正样本集,但里面只有部分是包含轮子的。

针对上述问题《Support Vector Machines for Multiple-Instance Learning》提出了MI-SVM。本质思想是将标准SVM的最大化样本间距扩展为最大化样本集间距。具体来说是选取正样本集中最像正样本的样本用作训练,正样本集内其它的样本就等候发落。同样取负样本中离分界面最近的负样本作为负样本。因为我们的目的是要保证正样本中有正,负样本不能为正。就基本上化为了标准SVM。取最大正样本(离分界面最远),最小负样本(离分界面最近):

对于正样本: 为正样本集中选中的最像大正样本的样本。

对于负样本:可以将max展开,因为最小的负样本满足的话,其余负样本就都能满足,所以任意负样本有:

目标函数:

也就是说选取正样本集中最大的正样本,负样本集中的所有样本。与标准SVM的唯一不同之处在于拉格朗日系数的界限。

而标准SVM的约束是:

最终化为一个迭代优化问题:

思想很简单:第一步是在正样本集中优化;第二步是优化SVM模型。与K-Means这类聚类算法一样都只是简单的两步,却爆发了无穷的力量。

这里可以参考一篇博客Multiple-instance learning

关于SVM的详细理论推导就不得不推荐我最为膜拜的MIT Doctor pluskid: 支持向量机系列

关于SVM的求解:SVM学习——Sequential Minimal Optimization

SVM学习——Coordinate Desent Method

此外,与多示例学习对应的还有多标记学习(multi-lable learning)有兴趣可以了解下。二者联系很大,多示例是输入样本的标记具有歧义(可正可负),而多标记是输出样本有歧义。

3.1.2 Latent SVM

1)Latent-SVM实质上和MI-SVM是一样的。区别在于扩展了Latent变量。首先解释下Latent变量,MI-SVM决定正样本集中哪一个样本作为正样本的就是一个latent变量。不过这个变量是单一的,比较简单,取值只是正样本集中的序号而已。DPM中也是要选择最大的正样本,但是它的latent变量就特别多。比如bounding box的实际位置,在HOG特征金字塔中的level,某样本属于哪一类component。也就是说我们有了一张正样本的图片,标注了bounding box,我们要在某一位置,某一尺度,提取出一个最大正样本作为某一component的正样本。

直接看Latent-SVM的训练过程:

这一部分还牵扯到了Data-minig。先不管,先只看循环中的3-6,12.

3-6就对于MI-SVM的第一步。12就对应了MI-SVM的第二步。作者这里直接用了梯度下降法,求解最优模型β。

2)现在说下Data-minig。作者为什么不直接优化,还搞个Data-minig干嘛呢?因为,负样本数目巨大,Version3中用到的总样本数为2^28,其中Pos样本数目占的比例特别低,负样本太多,直接导致优化过程很慢,因为很多负样本远离分界面对于优化几乎没有帮助。Data-minig的作用就是去掉那些对优化作用很小的Easy-examples保留靠近分界面的Hard-examples。分别对应13和10。这样做的的理论支撑证明如下:

3)再简单说下随机梯度下降法(Stochastic Gradient Decent):

首先梯度表达式:

梯度近似:

优化流程:

这部分的主要程序:pascal_train.m->train.m->detect.m->learn.cc

3.2 训练初始化

LSVM对初始值很敏感,因此初始化也是个重头戏。分为三个阶段。英语方面我就不班门弄斧了,直接上截图。

下面稍稍提下各阶段的工作,主要是论文中没有的Latent 变量分析:

Phase1:是传统的SVM训练过程,与HOG算法一致。作者是随机将正样本按照aspect ration(长宽比)排序,然后很粗糙的均分为两半训练两个component的rootfilte。这两个rootfilter的size也就直接由分到的pos examples决定了。后续取正样本时,直接将正样本缩放成rootfilter的大小。

Phase2:是LSVM训练。Latent variables 有图像中正样本的实际位置包括空间位置(x,y),尺度位置level,以及component的类别c,即属于component1 还是属于 component 2。要训练的参数为两个 rootfilter,offset(b)

Phase3:也是LSVM过程。

先提下子模型的添加。作者固定了每个component有6个partfilter,但实际上还会根据实际情况减少。为了减少参数,partfilter都是对称的。partfilter在rootfilter中的锚点(anchor location)在按最大energy选取partfilter的时候就已经固定下来了。

这阶段的Latent variables是最多的有:rootfilter(x,y,scale),partfilters(x,y,scale)。要训练的参数为 rootfilters, rootoffset, partfilters, defs(的偏移Cost)。

这部分的主要程序:pascal_train.m

  1. 4.细节

4.1轮廓预测(Bounding Box Prediction)

仔细看下自行车的左轮,如果我们只用rootfilter检测出来的区域,即红色区域,那么前轮会被切掉一部分,但是如果能综合partfilter检测出来的bounding box就能得到更加准确的bounding box如右图。

这部分很简单就是用最小二乘(Least Squres)回归,程序中trainbox.m中直接左除搞定。

4.2 HOG

作者对HOG进行了很大的改动。作者没有用4*9=36维向量,而是对每个8x8的cell提取18+9+4=31维特征向量。作者还讨论了依据PCA(Principle Component Analysis)可视化的结果选9+4维特征,能达到HOG 4*9维特征的效果。

这里很多就不细说了。开题一个字都还没写,要赶着开题……主要是features.cc。有了下面这张图,自己慢慢研究下:

 

源码分析:

DPM(Defomable Parts Model) 源码分析-检测

DPM(Defomable Parts Model) 源码分析-训练


DPM(Defomable Parts Model) 源码分析-检测(二)

原文:http://blog.csdn.net/ttransposition/article/details/12954195

DPM(Defomable Parts Model)原理

首先声明此版本为V3.1。因为和论文最相符。V4增加了模型数由2个增加为6个,V5提取了语义特征。源码太长纯代码应该在2K+,只选取了核心部分代码

demo.m

[cpp]  view plain copy 在CODE上查看代码片 派生到我的代码片
  1. function demo()  
  2.   
  3. test('000034.jpg''car');  
  4. test('000061.jpg''person');  
  5. test('000084.jpg''bicycle');  
  6.   
  7. function test(name, cls)  
  8. % load and display image  
  9. im=imread(name);  
  10. clf;  
  11. image(im);  
  12. axis equal;   
  13. axis on;  
  14. disp('input image');  
  15. disp('press any key to continue'); pause;  
  16.   
  17. % load and display model  
  18. load(['VOC2007/' cls '_final']); %加载模型  
  19. visualizemodel(model);  
  20. disp([cls ' model']);  
  21. disp('press any key to continue'); pause;  
  22.   
  23. % detect objects  
  24. boxes = detect(im, model, 0); %model为mat中的结构体  
  25. top = nms(boxes, 0.5);  %Non-maximum suppression.  
  26. showboxes(im, top);  
  27. %print(gcf, '-djpeg90''-r0', [cls '.jpg']);  
  28. disp('detections');  
  29. disp('press any key to continue'); pause;  
  30.   
  31. % get bounding boxes  
  32. bbox = getboxes(model, boxes);  %根据检测到的root,parts,预测bounding  
  33. top = nms(bbox, 0.5);  
  34. bbox = clipboxes(im, top); %预测出来的bounding,可能会超过图像原始尺寸,所以要减掉  
  35. showboxes(im, bbox);  
  36. disp('bounding boxes');  
  37. disp('press any key to continue'); pause;  


detect.m

[cpp]  view plain copy 在CODE上查看代码片 派生到我的代码片
  1. function [boxes] = detect(input, model, thresh, bbox, ...  
  2.                           overlap, label, fid, id, maxsize)  
  3. % 论文 fig.4                         
  4.   
  5. % boxes = detect(input, model, thresh, bbox, overlap, label, fid, id, maxsize)  
  6. % Detect objects in input using a model and a score threshold.  
  7. % Higher threshold leads to fewer detections.  
  8. % boxes = [rx1 ry1 rx2 ry2 | px1 py1 px2 py2 ...| componetindex | score ]  
  9. % The function returns a matrix with one row per detected object.  The  
  10. % last column of each row gives the score of the detection.  The  
  11. % column before last specifies the component used for the detection.  
  12. % The first 4 columns specify the bounding box for the root filter and  
  13. % subsequent columns specify the bounding boxes of each part.  
  14. %  
  15. % If bbox is not empty, we pick best detection with significant overlap.   
  16. % If label and fid are included, we write feature vectors to a data file.  
  17.   
  18. %phase 2: im, model, 0, bbox, overlap, 1, fid, 2*i-1  
  19. % trian boxex : detect(im, model, 0, bbox, overlap)  
  20. if nargin > 3 && ~isempty(bbox)  
  21.   latent = true;  
  22. else  
  23.   latent = false;  
  24. end  
  25.   
  26. if nargin > 6 && fid ~= 0  
  27.   write = true;  
  28. else  
  29.   write = false;  
  30. end  
  31.   
  32. if nargin < 9  
  33.   maxsize = inf;  
  34. end  
  35.   
  36. % we assume color images  
  37. input = color(input);   %如果是灰度图,扩充为三通道 R=G=B=Gray  
  38.   
  39. % prepare model for convolutions  
  40. rootfilters = [];  
  41. for i = 1:length(model.rootfilters) %   
  42.   rootfilters{i} = model.rootfilters{i}.w;% r*w*31维向量,9(方向范围 0~180) +18(方向范围 0-360)+4(cell熵和)  
  43. end  
  44. partfilters = [];  
  45. for i = 1:length(model.partfilters)  
  46.   partfilters{i} = model.partfilters{i}.w;  
  47. end  
  48.   
  49. % cache some data 获取所有 root,part的所有信息  
  50. for c = 1:model.numcomponents   % releas3.1 一种对象,只有2个模型,releas5 有3*2个模型  
  51.   ridx{c} = model.components{c}.rootindex; % m1=1,m2=2  
  52.   oidx{c} = model.components{c}.offsetindex; %o1=1,o2=2  
  53.   root{c} = model.rootfilters{ridx{c}}.w;  
  54.   rsize{c} = [size(root{c},1) size(root{c},2)]; %root size,单位为 sbin*sbin的block块,相当于原始HOG中的一个cell  
  55.   numparts{c} = length(model.components{c}.parts); %目前为固定值6个,但是有些part是 fake  
  56.   for j = 1:numparts{c}  
  57.     pidx{c,j} = model.components{c}.parts{j}.partindex; %part是在该对象的所有component的part下连续编号  
  58.     didx{c,j} = model.components{c}.parts{j}.defindex;  % 在 rootfiter中的 anchor location  
  59.     part{c,j} = model.partfilters{pidx{c,j}}.w; % 6*6*31  
  60.     psize{c,j} = [size(part{c,j},1) size(part{c,j},2)]; %   
  61.     % reverse map from partfilter index to (component, part#)  
  62.     rpidx{pidx{c,j}} = [c j];  
  63.   end  
  64. end  
  65.   
  66. % we pad the feature maps to detect partially visible objects  
  67. padx = ceil(model.maxsize(2)/2+1); % 7/2+1 = 5  
  68. pady = ceil(model.maxsize(1)/2+1); % 11/2+1 = 7  
  69.   
  70. % the feature pyramid  
  71. interval = model.interval;  %10  
  72. %--------------------------------特征金字塔---------------------------------------------------------  
  73. % feat的尺寸为 img.rows/sbin,img.cols/sbin  
  74. % scales:缩放了多少  
  75. [feat, scales] = featpyramid(input, model.sbin, interval); % 8,10  
  76.   
  77. % detect at each scale  
  78. best = -inf;  
  79. ex = [];  
  80. boxes = [];  
  81. %---------------------逐层检测目标-----------------------------------------------------------%  
  82. for level = interval+1:length(feat) %注意是从第二层开始  
  83.   scale = model.sbin/scales(level);  % 1/缩小了多少    
  84.   if size(feat{level}, 1)+2*pady < model.maxsize(1) || ... %扩展后还是未能达到 能同时计算两个component的得分  
  85.      size(feat{level}, 2)+2*padx < model.maxsize(2) || ...  
  86.      (write && ftell(fid) >= maxsize) %已经没有空间保存样本了  
  87.     continue;  
  88.   end  
  89.     
  90.   if latent %训练时使用,检测时跳过  
  91.     skip = true;  
  92.     for c = 1:model.numcomponents  
  93.       root_area = (rsize{c}(1)*scale) * (rsize{c}(2)*scale);% rootfilter  
  94.       box_area = (bbox(3)-bbox(1)+1) * (bbox(4)-bbox(2)+1); % bbox该class 所有 rootfilter 的交集即minsize  
  95.       if (root_area/box_area) >= overlap && (box_area/root_area) >= overlap %这句话真纠结,a>=0.7b,b>=0.7a -> a>=0.7b>=0.49a  
  96.         skip = false;  
  97.       end  
  98.     end  
  99.     if skip  
  100.       continue;  
  101.     end  
  102.   end  
  103.       
  104.   % -----------convolve feature maps with filters -----------  
  105.   %rootmatch,partmatch ,得分图root的尺度总是part的一半,  
  106.   %rootmatch尺寸是partmatch的一半  
  107.   featr = padarray(feat{level}, [pady padx 0], 0);  % 上下各补充 pady 行0,左右各补充padx行 0  
  108.   %C = fconv(A, cell of B, start, end);  
  109.   rootmatch = fconv(featr, rootfilters, 1, length(rootfilters));  
  110.   if length(partfilters) > 0  
  111.     featp = padarray(feat{level-interval}, [2*pady 2*padx 0], 0);  
  112.     partmatch = fconv(featp, partfilters, 1, length(partfilters));  
  113.   end  
  114.   %-------------------逐component检测-----------------------------------  
  115.   % 参见论文 Fig 4  
  116.   % 最终得到  综合得分图   score  
  117.   for c = 1:model.numcomponents  
  118.     % root score + offset  
  119.     score = rootmatch{ridx{c}} + model.offsets{oidx{c}}.w;    
  120.     % add in parts  
  121.     for j = 1:numparts{c}  
  122.       def = model.defs{didx{c,j}}.w;  
  123.       anchor = model.defs{didx{c,j}}.anchor;  
  124.       % the anchor position is shifted to account for misalignment  
  125.       % between features at different resolutions  
  126.       ax{c,j} = anchor(1) + 1; %  
  127.       ay{c,j} = anchor(2) + 1;  
  128.       match = partmatch{pidx{c,j}};  
  129.       [M, Ix{c,j}, Iy{c,j}] = dt(-match, def(1), def(2), def(3), def(4)); % dx,dy,dx^2,dy^2的偏移惩罚系数  
  130.       % M part的综合匹配得分图,与part尺寸一致。Ix{c,j}, Iy{c,j} 即part实际的最佳位置(相对于root)  
  131.       % 参见论文公式 9  
  132.       score = score - M(ay{c,j}:2:ay{c,j}+2*(size(score,1)-1), ...  
  133.                         ax{c,j}:2:ax{c,j}+2*(size(score,2)-1));  
  134.     end  
  135.       
  136.     %-------阈值淘汰------------------------  
  137.     if ~latent  
  138.       % get all good matches  
  139.       % ---thresh  在 分类时为0,在 找 hard exmaple 时是 -1.05--  
  140.       I = find(score > thresh);  %返回的是从上到下从左到右的索引  
  141.       [Y, X] = ind2sub(size(score), I);  %还原为 行,列坐标        
  142.       tmp = zeros(length(I), 4*(1+numparts{c})+2);  %一个目标的root,part,score信息,见程序开头说明  
  143.       for i = 1:length(I)  
  144.         x = X(i);  
  145.         y = Y(i);  
  146.         [x1, y1, x2, y2] = rootbox(x, y, scale, padx, pady, rsize{c});  
  147.         b = [x1 y1 x2 y2];  
  148.         if write  
  149.           rblocklabel = model.rootfilters{ridx{c}}.blocklabel;  
  150.           oblocklabel = model.offsets{oidx{c}}.blocklabel;        
  151.           f = featr(y:y+rsize{c}(1)-1, x:x+rsize{c}(2)-1, :);  
  152.           xc = round(x + rsize{c}(2)/2 - padx); %   
  153.           yc = round(y + rsize{c}(1)/2 - pady);  
  154.           ex = [];  
  155.           ex.header = [label; id; level; xc; yc; ...  
  156.                        model.components{c}.numblocks; ...  
  157.                        model.components{c}.dim];  
  158.           ex.offset.bl = oblocklabel;  
  159.           ex.offset.w = 1;  
  160.           ex.root.bl = rblocklabel;  
  161.           width1 = ceil(rsize{c}(2)/2);  
  162.           width2 = floor(rsize{c}(2)/2);  
  163.           f(:,1:width2,:) = f(:,1:width2,:) + flipfeat(f(:,width1+1:end,:));  
  164.           ex.root.w = f(:,1:width1,:);  
  165.           ex.part = [];  
  166.         end  
  167.         for j = 1:numparts{c}  
  168.           [probex, probey, px, py, px1, py1, px2, py2] = ...  
  169.               partbox(x, y, ax{c,j}, ay{c,j}, scale, padx, pady, ...  
  170.                       psize{c,j}, Ix{c,j}, Iy{c,j});  
  171.           b = [b px1 py1 px2 py2];  
  172.           if write  
  173.             if model.partfilters{pidx{c,j}}.fake  
  174.               continue;  
  175.             end  
  176.             pblocklabel = model.partfilters{pidx{c,j}}.blocklabel;  
  177.             dblocklabel = model.defs{didx{c,j}}.blocklabel;  
  178.             f = featp(py:py+psize{c,j}(1)-1,px:px+psize{c,j}(2)-1,:);  
  179.             def = -[(probex-px)^2; probex-px; (probey-py)^2; probey-py];  
  180.             partner = model.partfilters{pidx{c,j}}.partner;  
  181.             if partner > 0  
  182.               k = rpidx{partner}(2);  
  183.               [kprobex, kprobey, kpx, kpy, kpx1, kpy1, kpx2, kpy2] = ...  
  184.                   partbox(x, y, ax{c,k}, ay{c,k}, scale, padx, pady, ...  
  185.                           psize{c,k}, Ix{c,k}, Iy{c,k});  
  186.               kf = featp(kpy:kpy+psize{c,k}(1)-1,kpx:kpx+psize{c,k}(2)-1,:);  
  187.               % flip linear term in horizontal deformation model  
  188.               kdef = -[(kprobex-kpx)^2; kpx-kprobex; ...  
  189.                        (kprobey-kpy)^2; kprobey-kpy];  
  190.               f = f + flipfeat(kf);  
  191.               def = def + kdef;  
  192.             else  
  193.               width1 = ceil(psize{c,j}(2)/2);  
  194.               width2 = floor(psize{c,j}(2)/2);  
  195.               f(:,1:width2,:) = f(:,1:width2,:) + flipfeat(f(:,width1+1:end,:));  
  196.               f = f(:,1:width1,:);  
  197.             end  
  198.             ex.part(j).bl = pblocklabel;  
  199.             ex.part(j).w = f;  
  200.             ex.def(j).bl = dblocklabel;  
  201.             ex.def(j).w = def;  
  202.           end  
  203.         end  
  204.         if write  
  205.           exwrite(fid, ex); % 写入负样本  
  206.         end  
  207.         tmp(i,:) = [b c score(I(i))];  
  208.       end  
  209.       boxes = [boxes; tmp];  
  210.     end  
  211.   
  212.     if latent  
  213.       % get best match  
  214.       for x = 1:size(score,2)  
  215.         for y = 1:size(score,1)  
  216.           if score(y, x) > best    
  217.             % 以该(y,x)为left-top点的rootfilter的范围在原图像中的位置  
  218.             [x1, y1, x2, y2] = rootbox(x, y, scale, padx, pady, rsize{c});  
  219.             % intesection with bbox  
  220.             xx1 = max(x1, bbox(1));  
  221.             yy1 = max(y1, bbox(2));  
  222.             xx2 = min(x2, bbox(3));  
  223.             yy2 = min(y2, bbox(4));  
  224.             w = (xx2-xx1+1);  
  225.             h = (yy2-yy1+1);  
  226.             if w > 0 && h > 0  
  227.               % check overlap with bbox  
  228.               inter = w*h;  
  229.               a = (x2-x1+1) * (y2-y1+1); % rootfilter 的面积  
  230.               b = (bbox(3)-bbox(1)+1) * (bbox(4)-bbox(2)+1); % bbox的面积  
  231.               % 计算很很独特,如果只是 inter / b 那么 如果a很大,只是一部分与 bounding box重合,那就不可靠了,人再怎么标注错误,也不会这么大  
  232.               % 所以,a越大,要求的重合率越高才好,所以分母+a,是个不错的选择,但是这样减小的太多了,所以减去 inter  
  233.               o = inter / (a+b-inter);  
  234.               if (o >= overlap)  
  235.                 %  
  236.                 best = score(y, x);  
  237.                 boxes = [x1 y1 x2 y2];  
  238.                 % 这一部分一直被覆盖,最后保留的是 best样本  
  239.                 if write                    
  240.                   f = featr(y:y+rsize{c}(1)-1, x:x+rsize{c}(2)-1, :);  
  241.                   rblocklabel = model.rootfilters{ridx{c}}.blocklabel;  
  242.                   oblocklabel = model.offsets{oidx{c}}.blocklabel;        
  243.                   xc = round(x + rsize{c}(2)/2 - padx);  
  244.                   yc = round(y + rsize{c}(1)/2 - pady);            
  245.                   ex = [];  
  246.                   % label; id; level; xc; yc,正样本的重要信息!  
  247.                   % xc,yc,居然是相对于剪切后的图片  
  248.                   ex.header = [label; id; level; xc; yc; ...  
  249.                                model.components{c}.numblocks; ...  
  250.                                model.components{c}.dim];  
  251.                   ex.offset.bl = oblocklabel;  
  252.                   ex.offset.w = 1;  
  253.                   ex.root.bl = rblocklabel;  
  254.                   width1 = ceil(rsize{c}(2)/2);  
  255.                   width2 = floor(rsize{c}(2)/2);  
  256.                   f(:,1:width2,:) = f(:,1:width2,:) + flipfeat(f(:,width1+1:end,:));  
  257.                   ex.root.w = f(:,1:width1,:); %样本特征  
  258.                   ex.part = [];  
  259.                 end  
  260.                 for j = 1:numparts{c}  
  261.                   %probex,probey综合得分最高的位置,相对于featp  
  262.                   %px1,py1,px2,py2 转化成相对于featr  
  263.                   [probex, probey, px, py, px1, py1, px2, py2] = ...  
  264.                       partbox(x, y, ax{c,j}, ay{c,j}, scale, ...  
  265.                               padx, pady, psize{c,j}, Ix{c,j}, Iy{c,j});  
  266.                   boxes = [boxes px1 py1 px2 py2];  
  267.                   if write  
  268.                     if model.partfilters{pidx{c,j}}.fake  
  269.                       continue;  
  270.                     end  
  271.                     p = featp(py:py+psize{c,j}(1)-1, ...  
  272.                               px:px+psize{c,j}(2)-1, :);  
  273.                     def = -[(probex-px)^2; probex-px; (probey-py)^2; probey-py];  
  274.                     pblocklabel = model.partfilters{pidx{c,j}}.blocklabel;  
  275.                     dblocklabel = model.defs{didx{c,j}}.blocklabel;  
  276.                     partner = model.partfilters{pidx{c,j}}.partner;  
  277.                     if partner > 0  
  278.                       k = rpidx{partner}(2);  
  279.                       [kprobex, kprobey, kpx, kpy, kpx1, kpy1, kpx2, kpy2] = ...  
  280.                           partbox(x, y, ax{c,k}, ay{c,k}, scale, padx, pady, ...  
  281.                                   psize{c,k}, Ix{c,k}, Iy{c,k});  
  282.                       kp = featp(kpy:kpy+psize{c,k}(1)-1, ...  
  283.                                  kpx:kpx+psize{c,k}(2)-1, :);  
  284.                       % flip linear term in horizontal deformation model  
  285.                       kdef = -[(kprobex-kpx)^2; kpx-kprobex; ...  
  286.                                (kprobey-kpy)^2; kprobey-kpy];  
  287.                       p = p + flipfeat(kp);  
  288.                       def = def + kdef;  
  289.                     else  
  290.                       width1 = ceil(psize{c,j}(2)/2);  
  291.                       width2 = floor(psize{c,j}(2)/2);  
  292.                       p(:,1:width2,:) = p(:,1:width2,:) + ...  
  293.                           flipfeat(p(:,width1+1:end,:));  
  294.                       p = p(:,1:width1,:);  
  295.                     end  
  296.                     ex.part(j).bl = pblocklabel;  
  297.                     ex.part(j).w = p;  
  298.                     ex.def(j).bl = dblocklabel;  
  299.                     ex.def(j).w = def;  
  300.                   end  
  301.                 end  
  302.                 boxes = [boxes c best];  
  303.               end  
  304.             end  
  305.           end  
  306.         end  
  307.       end  
  308.     end  
  309.   end  
  310. end  
  311.   
  312. if latent && write && ~isempty(ex)  
  313.   exwrite(fid, ex); %datfile  
  314. end  
  315.   
  316. % The functions below compute a bounding box for a root or part   
  317. template placed in the feature hierarchy.  
  318. %  
  319. % coordinates need to be transformed to take into account:  
  320. % 1. padding from convolution  
  321. % 2. scaling due to sbin & image subsampling  
  322. % 3. offset from feature computation      
  323. %  
  324.   
  325. function [x1, y1, x2, y2] = rootbox(x, y, scale, padx, pady, rsize)  
  326. x1 = (x-padx)*scale+1;  %图像是先缩放(构造金字塔时)再打补丁  
  327. y1 = (y-pady)*scale+1;  
  328. x2 = x1 + rsize(2)*scale - 1; % 宽度也要缩放  
  329. y2 = y1 + rsize(1)*scale - 1;  
  330.   
  331. function [probex, probey, px, py, px1, py1, px2, py2] = ...  
  332.     partbox(x, y, ax, ay, scale, padx, pady, psize, Ix, Iy)  
  333. probex = (x-1)*2+ax; %最优位置  
  334. probey = (y-1)*2+ay;  
  335. px = double(Ix(probey, probex)); %综合得分最高的位置  
  336. py = double(Iy(probey, probex));  
  337. px1 = ((px-2)/2+1-padx)*scale+1; % pading是root的两倍  
  338. py1 = ((py-2)/2+1-pady)*scale+1;  
  339. px2 = px1 + psize(2)*scale/2 - 1;  
  340. py2 = py1 + psize(1)*scale/2 - 1;  
  341.   
  342. % write an example to the data file  
  343. function exwrite(fid, ex)  
  344. fwrite(fid, ex.header, 'int32');  
  345. buf = [ex.offset.bl; ex.offset.w(:); ...  
  346.        ex.root.bl; ex.root.w(:)];  
  347. fwrite(fid, buf, 'single');  
  348. for j = 1:length(ex.part)  
  349.   if ~isempty(ex.part(j).w)  
  350.     buf = [ex.part(j).bl; ex.part(j).w(:); ...  
  351.            ex.def(j).bl; ex.def(j).w(:)];  
  352.     fwrite(fid, buf, 'single');  
  353.   end  
  354. end  


features.cc

[cpp]  view plain copy 在CODE上查看代码片 派生到我的代码片
  1. #include <math.h>  
  2. #include "mex.h"  
  3.   
  4. // small value, used to avoid division by zero  
  5. #define eps 0.0001  
  6.   
  7. #define bzero(a, b) memset(a, 0, b)   
  8. int round(float a) { float tmp = a - (int)a; if( tmp >= 0.5 ) return (int)a + 1; else return (int)a; }  
  9. // unit vectors used to compute gradient orientation  
  10. // cos(20*i)  
  11. double uu[9] = {1.0000,   
  12.         0.9397,   
  13.         0.7660,   
  14.         0.500,   
  15.         0.1736,   
  16.         -0.1736,   
  17.         -0.5000,   
  18.         -0.7660,   
  19.         -0.9397};  
  20. //sin(20*i)  
  21. double vv[9] = {0.0000,   
  22.         0.3420,   
  23.         0.6428,   
  24.         0.8660,   
  25.         0.9848,   
  26.         0.9848,   
  27.         0.8660,   
  28.         0.6428,   
  29.         0.3420};  
  30.   
  31. static inline double min(double x, double y) { return (x <= y ? x : y); }  
  32. static inline double max(double x, double y) { return (x <= y ? y : x); }  
  33.   
  34. static inline int min(int x, int y) { return (x <= y ? x : y); }  
  35. static inline int max(int x, int y) { return (x <= y ? y : x); }  
  36.   
  37. // main function:  
  38. // takes a double color image and a bin size   
  39. // returns HOG features  
  40. mxArray *process(const mxArray *mximage, const mxArray *mxsbin) {  
  41.   double *im = (double *)mxGetPr(mximage);  
  42.   const int *dims = mxGetDimensions(mximage);  
  43.   if (mxGetNumberOfDimensions(mximage) != 3 ||  
  44.       dims[2] != 3 ||  
  45.       mxGetClassID(mximage) != mxDOUBLE_CLASS)  
  46.     mexErrMsgTxt("Invalid input");  
  47.   
  48.   int sbin = (int)mxGetScalar(mxsbin);  
  49.   
  50.   // memory for caching orientation histograms & their norms  
  51.   int blocks[2];  
  52.   blocks[0] = (int)round((double)dims[0]/(double)sbin);//行  
  53.   blocks[1] = (int)round((double)dims[1]/(double)sbin);//列  
  54.   double *hist = (double *)mxCalloc(blocks[0]*blocks[1]*18, sizeof(double));//只需要计算18bin,9bin的推  
  55.   double *norm = (double *)mxCalloc(blocks[0]*blocks[1], sizeof(double));  
  56.   
  57.   // memory for HOG features  
  58.   int out[3];//size  
  59.   out[0] = max(blocks[0]-2, 0);//减去2干嘛??  
  60.   out[1] = max(blocks[1]-2, 0);  
  61.   out[2] = 27+4;  
  62.   mxArray *mxfeat = mxCreateNumericArray(3, out, mxDOUBLE_CLASS, mxREAL);//特征,size=out   
  63.   double *feat = (double *)mxGetPr(mxfeat);  
  64.     
  65.   int visible[2];  
  66.   visible[0] = blocks[0]*sbin;  
  67.   visible[1] = blocks[1]*sbin;  
  68.   //先列再行  
  69.   for (int x = 1; x < visible[1]-1; x++) {  
  70.     for (int y = 1; y < visible[0]-1; y++) {  
  71.       // first color channel  
  72.       double *s = im + min(x, dims[1]-2)*dims[0] + min(y, dims[0]-2);//在im中的位置  
  73.       double dy = *(s+1) - *(s-1);  
  74.       double dx = *(s+dims[0]) - *(s-dims[0]); //坐标系是一样的,c和matlab的存储顺序不一样  
  75.       double v = dx*dx + dy*dy;  
  76.   
  77.       // second color channel  
  78.       s += dims[0]*dims[1];  
  79.       double dy2 = *(s+1) - *(s-1);  
  80.       double dx2 = *(s+dims[0]) - *(s-dims[0]);  
  81.       double v2 = dx2*dx2 + dy2*dy2;  
  82.   
  83.       // third color channel  
  84.       s += dims[0]*dims[1];  
  85.       double dy3 = *(s+1) - *(s-1);  
  86.       double dx3 = *(s+dims[0]) - *(s-dims[0]);  
  87.       double v3 = dx3*dx3 + dy3*dy3;  
  88.   
  89.       // pick channel with strongest gradient,计算v  
  90.       if (v2 > v) {  
  91.         v = v2;  
  92.         dx = dx2;  
  93.         dy = dy2;  
  94.           }   
  95.           if (v3 > v) {  
  96.         v = v3;  
  97.         dx = dx3;  
  98.         dy = dy3;  
  99.       }  
  100.   
  101.       // snap to one of 18 orientations,就算角度best_o  
  102.       double best_dot = 0;  
  103.       int best_o = 0;  
  104.       for (int o = 0; o < 9; o++) {  
  105.         // (sinθ)^2+(cosθ)^2 =1  
  106.         // max cosθ*dx+ sinθ*dy 对其求导,可得极大值 θ = arctan dy/dx  
  107.         double dot = uu[o]*dx + vv[o]*dy;  
  108.         if (dot > best_dot) {  
  109.           best_dot = dot;  
  110.           best_o = o;  
  111.         } else if (-dot > best_dot) {  
  112.           best_dot = -dot;  
  113.           best_o = o+9;  
  114.         }  
  115.       }  
  116.         
  117.       // add to 4 histograms around pixel using linear interpolation  
  118.       double xp = ((double)x+0.5)/(double)sbin - 0.5;  
  119.       double yp = ((double)y+0.5)/(double)sbin - 0.5;  
  120.       int ixp = (int)floor(xp);  
  121.       int iyp = (int)floor(yp);  
  122.       double vx0 = xp-ixp;  
  123.       double vy0 = yp-iyp;  
  124.       double vx1 = 1.0-vx0;  
  125.       double vy1 = 1.0-vy0;  
  126.       v = sqrt(v);  
  127.     //左上角     
  128.       if (ixp >= 0 && iyp >= 0) {  
  129.         *(hist + ixp*blocks[0] + iyp + best_o*blocks[0]*blocks[1]) +=   
  130.           vx1*vy1*v;  
  131.       }  
  132.       //右上角        
  133.       if (ixp+1 < blocks[1] && iyp >= 0) {  
  134.         *(hist + (ixp+1)*blocks[0] + iyp + best_o*blocks[0]*blocks[1]) +=   
  135.           vx0*vy1*v;  
  136.       }  
  137.       //左下角  
  138.       if (ixp >= 0 && iyp+1 < blocks[0]) {  
  139.         *(hist + ixp*blocks[0] + (iyp+1) + best_o*blocks[0]*blocks[1]) +=   
  140.           vx1*vy0*v;  
  141.       }  
  142.       //右下角  
  143.       if (ixp+1 < blocks[1] && iyp+1 < blocks[0]) {  
  144.         *(hist + (ixp+1)*blocks[0] + (iyp+1) + best_o*blocks[0]*blocks[1]) +=   
  145.           vx0*vy0*v;  
  146.       }  
  147.     }  
  148.   }  
  149.   
  150.   // compute energy in each block by summing over orientations  
  151.   //计算每一个cell的 sum( ( v(oi)+v(oi+9) )^2 ),oi=0..8  
  152.   for (int o = 0; o < 9; o++) {  
  153.     double *src1 = hist + o*blocks[0]*blocks[1];  
  154.     double *src2 = hist + (o+9)*blocks[0]*blocks[1];  
  155.     double *dst = norm;  
  156.     double *end = norm + blocks[1]*blocks[0];  
  157.     while (dst < end) {  
  158.       *(dst++) += (*src1 + *src2) * (*src1 + *src2);  
  159.       src1++;  
  160.       src2++;  
  161.     }  
  162.   }  
  163.   
  164.   // compute features  
  165.   for (int x = 0; x < out[1]; x++) {  
  166.     for (int y = 0; y < out[0]; y++) {  
  167.       double *dst = feat + x*out[0] + y;        
  168.       double *src, *p, n1, n2, n3, n4;  
  169.   
  170.       p = norm + (x+1)*blocks[0] + y+1;//右下角的constrain insensitive sum  
  171.       n1 = 1.0 / sqrt(*p + *(p+1) + *(p+blocks[0]) + *(p+blocks[0]+1) + eps);  
  172.       p = norm + (x+1)*blocks[0] + y;//右边  
  173.       n2 = 1.0 / sqrt(*p + *(p+1) + *(p+blocks[0]) + *(p+blocks[0]+1) + eps);  
  174.       p = norm + x*blocks[0] + y+1;//下边  
  175.       n3 = 1.0 / sqrt(*p + *(p+1) + *(p+blocks[0]) + *(p+blocks[0]+1) + eps);  
  176.       p = norm + x*blocks[0] + y;//自己        
  177.       n4 = 1.0 / sqrt(*p + *(p+1) + *(p+blocks[0]) + *(p+blocks[0]+1) + eps);  
  178.   
  179.       double t1 = 0;  
  180.       double t2 = 0;  
  181.       double t3 = 0;  
  182.       double t4 = 0;  
  183.   
  184.       // contrast-sensitive features  
  185.       src = hist + (x+1)*blocks[0] + (y+1);  
  186.       for (int o = 0; o < 18; o++) {  
  187.         double h1 = min(*src * n1, 0.2);//截短  
  188.         double h2 = min(*src * n2, 0.2);  
  189.         double h3 = min(*src * n3, 0.2);  
  190.         double h4 = min(*src * n4, 0.2);  
  191.         *dst = 0.5 * (h1 + h2 + h3 + h4);//求和  
  192.         t1 += h1;  
  193.         t2 += h2;  
  194.         t3 += h3;  
  195.         t4 += h4;  
  196.         dst += out[0]*out[1];//下一个bin  
  197.         src += blocks[0]*blocks[1];  
  198.       }  
  199.   
  200.       // contrast-insensitive features  
  201.       src = hist + (x+1)*blocks[0] + (y+1);  
  202.       for (int o = 0; o < 9; o++) {  
  203.         double sum = *src + *(src + 9*blocks[0]*blocks[1]);  
  204.         double h1 = min(sum * n1, 0.2);  
  205.         double h2 = min(sum * n2, 0.2);  
  206.         double h3 = min(sum * n3, 0.2);  
  207.         double h4 = min(sum * n4, 0.2);  
  208.         *dst = 0.5 * (h1 + h2 + h3 + h4);  
  209.         dst += out[0]*out[1];  
  210.         src += blocks[0]*blocks[1];  
  211.       }  
  212.   
  213.       // texture features  
  214.       *dst = 0.2357 * t1;  
  215.       dst += out[0]*out[1];  
  216.       *dst = 0.2357 * t2;  
  217.       dst += out[0]*out[1];  
  218.       *dst = 0.2357 * t3;  
  219.       dst += out[0]*out[1];  
  220.       *dst = 0.2357 * t4;  
  221.     }  
  222.   }  
  223.   
  224.   mxFree(hist);  
  225.   mxFree(norm);  
  226.   return mxfeat;  
  227. }  
  228.   
  229. // matlab entry point  
  230. // F = features(image, bin)  
  231. // image should be color with double values  
  232. void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {   
  233.   if (nrhs != 2)  
  234.     mexErrMsgTxt("Wrong number of inputs");   
  235.   if (nlhs != 1)  
  236.     mexErrMsgTxt("Wrong number of outputs");  
  237.   plhs[0] = process(prhs[0], prhs[1]);  
  238. }  

 

dt.cc

[cpp]  view plain copy 在CODE上查看代码片 派生到我的代码片
  1. #include <math.h>  
  2. #include <sys/types.h>  
  3. #include "mex.h"  
  4.   
  5. #define int32_t int  
  6. /* 
  7.  * Generalized distance transforms. 
  8.  * We use a simple nlog(n) divide and conquer algorithm instead of the 
  9.  * theoretically faster linear method, for no particular reason except 
  10.  * that this is a bit simpler and I wanted to test it out. 
  11.  * 
  12.  * The code is a bit convoluted because dt1d can operate either along 
  13.  * a row or column of an array.   
  14.  */  
  15.   
  16. static inline int square(int x) { return x*x; }  
  17.   
  18. // dt helper function  
  19. void dt_helper(double *src, double *dst, int *ptr, int step,   
  20.            int s1, int s2, int d1, int d2, double a, double b) {  
  21.  if (d2 >= d1) {  
  22.    int d = (d1+d2) >> 1;  
  23.    int s = s1;  
  24.    for (int p = s1+1; p <= s2; p++)  
  25.      if (src[s*step] + a*square(d-s) + b*(d-s) >   
  26.      src[p*step] + a*square(d-p) + b*(d-p))  
  27.     s = p;  
  28.    dst[d*step] = src[s*step] + a*square(d-s) + b*(d-s);  
  29.    ptr[d*step] = s;  
  30.    dt_helper(src, dst, ptr, step, s1, s, d1, d-1, a, b);  
  31.    dt_helper(src, dst, ptr, step, s, s2, d+1, d2, a, b);  
  32.  }  
  33. }  
  34.   
  35. // dt of 1d array  
  36. void dt1d(double *src, double *dst, int *ptr, int step, int n,   
  37.       double a, double b) {  
  38.   dt_helper(src, dst, ptr, step, 0, n-1, 0, n-1, a, b);  
  39. }  
  40.   
  41. // matlab entry point  
  42. // [M, Ix, Iy] = dt(vals, ax, bx, ay, by)  
  43. void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {   
  44.   if (nrhs != 5)  
  45.     mexErrMsgTxt("Wrong number of inputs");   
  46.   if (nlhs != 3)  
  47.     mexErrMsgTxt("Wrong number of outputs");  
  48.   if (mxGetClassID(prhs[0]) != mxDOUBLE_CLASS)  
  49.     mexErrMsgTxt("Invalid input");  
  50.   
  51.   const int *dims = mxGetDimensions(prhs[0]);  
  52.   double *vals = (double *)mxGetPr(prhs[0]);  
  53.   double ax = mxGetScalar(prhs[1]);  
  54.   double bx = mxGetScalar(prhs[2]);  
  55.   double ay = mxGetScalar(prhs[3]);  
  56.   double by = mxGetScalar(prhs[4]);  
  57.     
  58.   mxArray *mxM = mxCreateNumericArray(2, dims, mxDOUBLE_CLASS, mxREAL);  
  59.   mxArray *mxIx = mxCreateNumericArray(2, dims, mxINT32_CLASS, mxREAL);  
  60.   mxArray *mxIy = mxCreateNumericArray(2, dims, mxINT32_CLASS, mxREAL);  
  61.   double *M = (double *)mxGetPr(mxM);  
  62.   int32_t *Ix = (int32_t *)mxGetPr(mxIx);  
  63.   int32_t *Iy = (int32_t *)mxGetPr(mxIy);  
  64.   
  65.   double *tmpM = (double *)mxCalloc(dims[0]*dims[1], sizeof(double)); // part map  
  66.   int32_t *tmpIx = (int32_t *)mxCalloc(dims[0]*dims[1], sizeof(int32_t));  
  67.   int32_t *tmpIy = (int32_t *)mxCalloc(dims[0]*dims[1], sizeof(int32_t));  
  68.   
  69.   for (int x = 0; x < dims[1]; x++)  
  70.     dt1d(vals+x*dims[0], tmpM+x*dims[0], tmpIy+x*dims[0], 1, dims[0], ay, by);  
  71.   
  72.   for (int y = 0; y < dims[0]; y++)  
  73.     dt1d(tmpM+y, M+y, tmpIx+y, dims[0], dims[1], ax, bx);  
  74.   
  75.   // get argmins and adjust for matlab indexing from 1  
  76.   for (int x = 0; x < dims[1]; x++) {  
  77.     for (int y = 0; y < dims[0]; y++) {  
  78.       int p = x*dims[0]+y;  
  79.       Ix[p] = tmpIx[p]+1;  
  80.       Iy[p] = tmpIy[tmpIx[p]*dims[0]+y]+1;  
  81.     }  
  82.   }  
  83.   
  84.   mxFree(tmpM);  
  85.   mxFree(tmpIx);  
  86.   mxFree(tmpIy);  
  87.   plhs[0] = mxM;  
  88.   plhs[1] = mxIx;  
  89.   plhs[2] = mxIy;  
  90. }  

DPM(Defomable Parts Model) 源码分析-训练(三)

原文:http://blog.csdn.net/ttransposition/article/details/12954631

DPM(Defomable Parts Model)原理

首先调用格式:

example:
pascal('person', 2);   % train and evaluate a 2 component person model

pascal_train.m

[cpp]  view plain copy 在CODE上查看代码片 派生到我的代码片
  1. function model = pascal_train(cls, n) % n=2  
  2.   
  3. % model = pascal_train(cls)  
  4. % Train a model using the PASCAL dataset.  
  5.   
  6. globals;   
  7. %----------读取正负样本-----------------------  
  8. % pos.im,neg.im存储了图像路径,pos.x1..pos.y2为box,负样本无box  
  9. [pos, neg] = pascal_data(cls);  
  10.   
  11. % 按照长宽比,分成等量的两部分? 即将 component label  固定,phase2时,该值为latent variable。  spos为索引  
  12. spos = split(pos, n);  
  13.   
  14. % -----------phase 1 : train root filters using warped positives & random negatives-----------  
  15. try  
  16.   load([cachedir cls '_random']);  
  17. catch  
  18. % -----------------------------phas 1--------------------------------  
  19. % 初始化 rootfilters  
  20.   for i=1:n  
  21.     models{i} = initmodel(spos{i});  
  22.     %---------train-------------  
  23.     % model.rootfilters{i}.w  
  24.     % model.offsets{i}.w  
  25.     models{i} = train(cls, models{i}, spos{i}, neg, 1, 1, 1, 1, 2^28);  
  26.   
  27.   end  
  28.   save([cachedir cls '_random'], 'models');  
  29. end  
  30.   
  31. % -----------------phase2-------------------------------------------  
  32. % :merge models and train using latent detections & hard negatives  
  33. try   
  34.   load([cachedir cls '_hard']);  
  35. catch  
  36.   model = mergemodels(models);  
  37.   model = train(cls, model, pos, neg(1:200), 0, 0, 2, 2, 2^28, true, 0.7);  
  38.   save([cachedir cls '_hard'], 'model');  
  39. end  
  40. %----------------phase 3----------------------------------------------  
  41. % add parts and update models using latent detections & hard negatives.  
  42. try   
  43.   load([cachedir cls '_parts']);  
  44. catch  
  45.   for i=1:n  
  46.     model = addparts(model, i, 6);  
  47.   end   
  48.   % use more data mining iterations in the beginning  
  49.   model = train(cls, model, pos, neg(1:200), 0, 0, 1, 4, 2^30, true, 0.7);  
  50.   model = train(cls, model, pos, neg(1:200), 0, 0, 6, 2, 2^30, true, 0.7, true);  
  51.   save([cachedir cls '_parts'], 'model');  
  52. end  
  53.   
  54. % update models using full set of negatives.  
  55. try   
  56.   load([cachedir cls '_mine']);  
  57. catch  
  58.   model = train(cls, model, pos, neg, 0, 0, 1, 3, 2^30, true, 0.7, true, ...  
  59.                 0.003*model.numcomponents, 2);  
  60.   save([cachedir cls '_mine'], 'model');  
  61. end  
  62.   
  63. % train bounding box prediction  
  64. try  
  65.   load([cachedir cls '_final']);  
  66. catch  
  67.  % 论文中说用最小二乘,怎么直接相除了,都不考虑矩阵的奇异性  
  68.   model = trainbox(cls, model, pos, 0.7);  
  69.   save([cachedir cls '_final'], 'model');  
  70. end  

initmodel.m

[cpp]  view plain copy 在CODE上查看代码片 派生到我的代码片
  1. function model = initmodel(pos, sbin, size)  
  2.   
  3. % model = initmodel(pos, sbin, size)  
  4. % Initialize model structure.  
  5. %  
  6. % If not supplied the dimensions of the model template are computed  
  7. % from statistics in the postive examples.  
  8. %   
  9. % This should be documented! :-)  
  10. % model.sbin         8  
  11. % model.interval     10  
  12. % model.numblocks     phase 1 :单独训练rootfilter时为2,offset,rootfilter;phase 2,为 4   
  13. % model.numcomponents  1  
  14. % model.blocksizes     (1)=1,(2)= root.h*root.w/2*31  
  15. % model.regmult        0,1  
  16. % model.learnmult      20,1  
  17. % model.maxsize        root 的size   
  18. % model.minsize  
  19. % model.rootfilters{i}  
  20. %   .size               以sbin为单位,尺寸为综合各样本的h/w,area计算出来的  
  21. %   .w  
  22. %   .blocklabel        blocklabel是为编号,offset(2),rootfilter(2),partfilter(12 or less),def (12 same as part)虽然意义不同但是放在一起统一编号  
  23. % model.partfilters{i}  
  24. %   .w  
  25. %   .blocklabel  
  26. % model.defs{i}  
  27. %   .anchor  
  28. %   .w  
  29. %   .blocklabel  
  30. % model.offsets{i}  
  31. %   .w               0  
  32. %   .blocklabel       1  
  33. % model.components{i}  
  34. %   .rootindex    1  
  35. %   .parts{j}  
  36. %     .partindex  
  37. %     .defindex  
  38. %   .offsetindex    1  
  39. %   .dim             2 + model.blocksizes(1) + model.blocksizes(2)  
  40. %   .numblocks       2  
  41.   
  42. % pick mode of aspect ratios  
  43. h = [pos(:).y2]' - [pos(:).y1]' + 1;  
  44. w = [pos(:).x2]' - [pos(:).x1]' + 1;  
  45. xx = -2:.02:2;  
  46. filter = exp(-[-100:100].^2/400); % e^-25,e^25  
  47. aspects = hist(log(h./w), xx); %  
  48. aspects = convn(aspects, filter, 'same');  
  49. [peak, I] = max(aspects);  
  50. aspect = exp(xx(I)); %滤波后最大的h/w,作为最典型的h/w  
  51.   
  52. % pick 20 percentile area  
  53. areas = sort(h.*w);  
  54. area = areas(floor(length(areas) * 0.2)); % 比它大的,可以缩放,比该尺寸小的呢?  
  55. area = max(min(area, 5000), 3000); %限制在 3000-5000  
  56.   
  57. % pick dimensions  
  58. w = sqrt(area/aspect);  
  59. h = w*aspect;  
  60.   
  61. % size of HOG features  
  62. if nargin < 4  
  63.   model.sbin = 8;  
  64. else  
  65.   model.sbin = sbin;  
  66. end  
  67.   
  68. % size of root filter  
  69. if nargin < 5  
  70.   model.rootfilters{1}.size = [round(h/model.sbin) round(w/model.sbin)];  
  71. else  
  72.   model.rootfilters{1}.size = size;  
  73. end  
  74.   
  75. % set up offset   
  76. model.offsets{1}.w = 0;  
  77. model.offsets{1}.blocklabel = 1;  
  78. model.blocksizes(1) = 1;  
  79. model.regmult(1) = 0;  
  80. model.learnmult(1) = 20;  
  81. model.lowerbounds{1} = -100;  
  82.   
  83. % set up root filter  
  84. model.rootfilters{1}.w = zeros([model.rootfilters{1}.size 31]);  
  85. height = model.rootfilters{1}.size(1);  
  86. % root filter is symmetricf  
  87. width = ceil(model.rootfilters{1}.size(2)/2);  % ??? /2  
  88. model.rootfilters{1}.blocklabel = 2;  
  89. model.blocksizes(2) = width * height * 31;  
  90. model.regmult(2) = 1;  
  91. model.learnmult(2) = 1;  
  92. model.lowerbounds{2} = -100*ones(model.blocksizes(2),1);  
  93.   
  94. % set up one component model  
  95. model.components{1}.rootindex = 1;  
  96. model.components{1}.offsetindex = 1;  
  97. model.components{1}.parts = {};  
  98. model.components{1}.dim = 2 + model.blocksizes(1) + model.blocksizes(2);  
  99. model.components{1}.numblocks = 2;  
  100.   
  101. % initialize the rest of the model structure  
  102. model.interval = 10;  
  103. model.numcomponents = 1;  
  104. model.numblocks = 2;  
  105. model.partfilters = {};  
  106. model.defs = {};  
  107. model.maxsize = model.rootfilters{1}.size;  
  108. model.minsize = model.rootfilters{1}.size;  


 

learn.cc

[cpp]  view plain copy 在CODE上查看代码片 派生到我的代码片
  1. #include <stdio.h>  
  2. #include <stdlib.h>  
  3. #include <string.h>  
  4. #include <math.h>  
  5. #include <sys/time.h>  
  6. #include <errno.h>  
  7.   
  8. /* 
  9.  * Optimize LSVM objective function via gradient descent. 
  10.  * 
  11.  * We use an adaptive cache mechanism.  After a negative example 
  12.  * scores beyond the margin multiple times it is removed from the 
  13.  * training set for a fixed number of iterations. 
  14.  */  
  15.   
  16. // Data File Format  
  17. // EXAMPLE*  
  18. //   
  19. // EXAMPLE:  
  20. //  long label          ints  
  21. //  blocks              int  
  22. //  dim                 int  
  23. //  DATA{blocks}  
  24. //  
  25. // DATA:  
  26. //  block label         float  
  27. //  block data          floats  
  28. //  
  29. // Internal Binary Format  
  30. //  len           int (byte length of EXAMPLE)  
  31. //  EXAMPLE       <see above>  
  32. //  unique flag   byte  
  33.   
  34. // number of iterations  
  35. #define ITER 5000000  
  36.   
  37. // small cache parameters  
  38. #define INCACHE 3  
  39. #define WAIT 10  
  40.   
  41. // error checking  
  42. #define check(e) \  
  43. (e ? (void)0 : (printf("%s:%u error: %s\n%s\n", __FILE__, __LINE__, #e, strerror(errno)), exit(1)))  
  44.   
  45. // number of non-zero blocks in example ex  
  46. #define NUM_NONZERO(ex) (((int *)ex)[labelsize+1])  
  47.   
  48. // float pointer to data segment of example ex  
  49. #define EX_DATA(ex) ((float *)(ex + sizeof(int)*(labelsize+3)))  
  50.   
  51. // class label (+1 or -1) for the example  
  52. #define LABEL(ex) (((int *)ex)[1])  
  53.   
  54. // block label (converted to 0-based index)  
  55. #define BLOCK_IDX(data) (((int)data[0])-1)  
  56.   
  57. int labelsize;  
  58. int dim;  
  59.   
  60. // comparison function for sorting examples   
  61. // 参见 http://blog.sina.com.cn/s/blog_5155e8d401009145.html  
  62. int comp(const void *a, const void *b) {  
  63.   // sort by extended label first, and whole example second...  
  64.     
  65.   //逐字节比较的,当buf1<buf2时,返回值<0,当buf1=buf2时,返回值=0,当buf1>buf2时,返回值>0  
  66.   // 先比较这五个量 [label id level x y],也就是说按照 样本类别->id->level->x->y排序样本  
  67.   int c = memcmp(*((char **)a) + sizeof(int),   
  68.          *((char **)b) + sizeof(int),   
  69.          labelsize*sizeof(int));// 5  
  70.   if (c) //label 不相等  
  71.     return c;  
  72.     
  73.   // labels are the same ,怎么可能会一样呢 id在正负样本集内从1开始是递增的啊  phase 2 阶段同一张图片产生的样本,id都是一样的  
  74.   int alen = **((int **)a);  
  75.   int blen = **((int **)b);  
  76.   if (alen == blen) //长度一样  
  77.     return memcmp(*((char **)a) + sizeof(int),   
  78.           *((char **)b) + sizeof(int),   
  79.           alen); //真霸气,所有字节都比较……  
  80.   return ((alen < blen) ? -1 : 1);//按长度排序  
  81. }  
  82.   
  83. // a collapsed example is a sequence of examples  
  84. struct collapsed {  
  85.   char **seq;  
  86.   int num;  
  87. };  
  88.   
  89. // set of collapsed examples  
  90. struct data {  
  91.   collapsed *x;  
  92.   int num;  
  93.   int numblocks;  
  94.   int *blocksizes;  
  95.   float *regmult;  
  96.   float *learnmult;  
  97. };  
  98.   
  99. // seed the random number generator with the current time  
  100. void seed_time() {  
  101.  struct timeval tp;  
  102.  check(gettimeofday(&tp, NULL) == 0);  
  103.  srand48((long)tp.tv_usec);  
  104. }  
  105.   
  106. static inline double min(double x, double y) { return (x <= y ? x : y); }  
  107. static inline double max(double x, double y) { return (x <= y ? y : x); }  
  108.   
  109. // gradient descent  
  110. //---------------参照论文公式17 后的步骤---------------------------------------  
  111. void gd(double C, double J, data X, double **w, double **lb) {  
  112. //  C=0.0002, J=1, X, w==0, lb==-100);  
  113. //      
  114.   int num = X.num; //组数  
  115.     
  116.   // state for random permutations  
  117.   int *perm = (int *)malloc(sizeof(int)*X.num);  
  118.   check(perm != NULL);  
  119.   
  120.   // state for small cache  
  121.   int *W = (int *)malloc(sizeof(int)*num);  
  122.   check(W != NULL);  
  123.   for (int j = 0; j < num; j++)  
  124.     W[j] = 0;  
  125.   
  126.   int t = 0;  
  127.   while (t < ITER) {  // 5000000 ,霸气……  
  128.     // pick random permutation  
  129.     for (int i = 0; i < num; i++) //组数  
  130.       perm[i] = i;  
  131.     //-------打乱顺序-----  
  132.     // 论文中是随机选择一个样本,这里是随机排好序,再顺序取。  
  133.     // 类似于随机取,但是这里能保证取到全部样本,避免单个样本重复被抽到,重复作用  
  134.     for (int swapi = 0; swapi < num; swapi++) {  
  135.       int swapj = (int)(drand48()*(num-swapi)) + swapi; //drand48 产生 0-1之间的均匀分布  
  136.       int tmp = perm[swapi];  
  137.       perm[swapi] = perm[swapj];  
  138.       perm[swapj] = tmp;  
  139.     }  
  140.   
  141.     // count number of examples in the small cache  
  142.     int cnum = 0; //下面的循环部分的实际循环次数  
  143.     for (int i = 0; i < num; i++) {  
  144.       if (W[i] <= INCACHE) // 3  
  145.         cnum++;  
  146.     }  
  147.     //-------------------------------------------------------  
  148.     for (int swapi = 0; swapi < num; swapi++) {  
  149.       // select example  
  150.       int i = perm[swapi];  
  151.       collapsed x = X.x[i];  
  152.   
  153.       // skip if example is not in small cache  
  154.       //负样本分对一次+1,分错一次清为0  
  155.       //连续三次都分对了,那么这个样本很有可能是 easy 样本  
  156.       //直接让他罚停四次迭代  
  157.       if (W[i] > INCACHE) { //3  
  158.             W[i]--;  
  159.             continue;  
  160.       }  
  161.   
  162.       // learning rate  
  163.       double T = t + 1000.0; //学习率,直接1/t太大了  
  164.       double rateX = cnum * C / T;  
  165.       double rateR = 1.0 / T;  
  166.   
  167.       if (t % 10000 == 0) {  
  168.         printf(".");  
  169.         fflush(stdout); //清除文件缓冲区,文件以写方式打开时将缓冲区内容写入文件  
  170.       }  
  171.       t++;  
  172.         
  173.       // compute max over latent placements  
  174.       //  -----step 3----  
  175.       int M = -1;  
  176.       double V = 0;  
  177.       // 组内循环,选择 Zi=argmax β*f 即文中的第3部  
  178.       // 训练rootfiter时,x.num=1,因为随机产生的负样本其id不同  
  179.       for (int m = 0; m < x.num; m++) {   
  180.         double val = 0;  
  181.         char *ptr = x.seq[m];  
  182.         float *data = EX_DATA(ptr); //特征数据的地址 第9个数据开始,  
  183.         //后面跟着是 block1 label | block2 data|block2 lable | block2 data    
  184.         //                 1      |       1    |     2       |  h*w/2*31个float  
  185.         int blocks = NUM_NONZERO(ptr); // phase 1,phase 2 : 2 个,offset,rootfilter  
  186.         for (int j = 0; j < blocks; j++) {  
  187.           int b = BLOCK_IDX(data); //   
  188.           data++;  
  189.           for (int k = 0; k < X.blocksizes[b]; k++)//(1)=1,(2)= root.h*root.w/2*31  
  190.             val += w[b][k] * data[k]; //第一次循环是0  
  191.           data += X.blocksizes[b];  
  192.         }  
  193.         if (M < 0 || val > V) {  
  194.           M = m;  
  195.           V = val;  
  196.         }  
  197.       }  
  198.         
  199.       // update model  
  200.       //-----step.4 也算了step.5 的一半 ---------------  
  201.       // 梯度下降,减小 w  
  202.       for (int j = 0; j < X.numblocks; j++) {// 2  
  203.         double mult = rateR * X.regmult[j] * X.learnmult[j]; // 0,1  20,1,1/T,对于block2,学习率at就是 1/t,block 1 为0  
  204.         for (int k = 0; k < X.blocksizes[j]; k++) {  
  205.           w[j][k] -= mult * w[j][k]; //不管是分对了,还是分错了,都要减掉 at*β,见公式17下的4,5   
  206.         }  
  207.       }  
  208.       char *ptr = x.seq[M];  
  209.       int label = LABEL(ptr);  
  210.       //----step.5----------分错了,往梯度的负方向移动  
  211.       if (label * V < 1.0)   
  212.       {  
  213.         W[i] = 0;  
  214.         float *data = EX_DATA(ptr);  
  215.         int blocks = NUM_NONZERO(ptr);  
  216.         for (int j = 0; j < blocks; j++) {  
  217.             int b = BLOCK_IDX(data);  
  218.             //  yi*cnum * C / T*1,见论文中 公式16,17  
  219.             double mult = (label > 0 ? J : -1) * rateX * X.learnmult[b];         
  220.             data++;  
  221.             for (int k = 0; k < X.blocksizes[b]; k++)  
  222.                 w[b][k] += mult * data[k];  
  223.             data += X.blocksizes[b];  
  224.         }  
  225.       } else if (label == -1)   
  226.       {  
  227.             if (W[i] == INCACHE) //3  
  228.                 W[i] = WAIT; //10  
  229.             else  
  230.                 W[i]++;  
  231.       }  
  232.     }  
  233.   
  234.     // apply lowerbounds  
  235.     for (int j = 0; j < X.numblocks; j++) {  
  236.       for (int k = 0; k < X.blocksizes[j]; k++) {  
  237.         w[j][k] = max(w[j][k], lb[j][k]);  
  238.       }  
  239.     }  
  240.   
  241.   }  
  242.   
  243.   free(perm);  
  244.   free(W);  
  245. }  
  246.   
  247. // score examples  
  248. double *score(data X, char **examples, int num, double **w) {  
  249.   double *s = (double *)malloc(sizeof(double)*num);  
  250.   check(s != NULL);  
  251.   for (int i = 0; i < num; i++) {  
  252.     s[i] = 0.0;  
  253.     float *data = EX_DATA(examples[i]);  
  254.     int blocks = NUM_NONZERO(examples[i]);  
  255.     for (int j = 0; j < blocks; j++) {  
  256.       int b = BLOCK_IDX(data);  
  257.       data++;  
  258.       for (int k = 0; k < X.blocksizes[b]; k++)  
  259.         s[i] += w[b][k] * data[k];  
  260.       data += X.blocksizes[b];  
  261.     }  
  262.   }  
  263.   return s;    
  264. }  
  265.   
  266. // merge examples with identical labels  
  267. void collapse(data *X, char **examples, int num) {  
  268. //&X, sorted, num_unique  
  269.   collapsed *x = (collapsed *)malloc(sizeof(collapsed)*num);  
  270.   check(x != NULL);  
  271.   int i = 0;  
  272.   x[0].seq = examples;  
  273.   x[0].num = 1;  
  274.   for (int j = 1; j < num; j++) {  
  275.     if (!memcmp(x[i].seq[0]+sizeof(int), examples[j]+sizeof(int),   
  276.         labelsize*sizeof(int))) {  
  277.       x[i].num++; //如果label 五个量相同  
  278.     } else {  
  279.       i++;  
  280.       x[i].seq = &(examples[j]);  
  281.       x[i].num = 1;  
  282.     }  
  283.   }  
  284.   X->x = x;  
  285.   X->num = i+1;    
  286. }  
  287.   
  288. //调用参数 C=0.0002, J=1, hdrfile, datfile, modfile, inffile, lobfile  
  289. int main(int argc, char **argv) {    
  290.   seed_time();  
  291.   int count;  
  292.   data X;  
  293.   
  294.   // command line arguments  
  295.   check(argc == 8);  
  296.   double C = atof(argv[1]);  
  297.   double J = atof(argv[2]);  
  298.   char *hdrfile = argv[3];  
  299.   char *datfile = argv[4];  
  300.   char *modfile = argv[5];  
  301.   char *inffile = argv[6];  
  302.   char *lobfile = argv[7];  
  303.   
  304.   // read header file  
  305.   FILE *f = fopen(hdrfile, "rb");  
  306.   check(f != NULL);  
  307.   int header[3];  
  308.   count = fread(header, sizeof(int), 3, f);  
  309.   check(count == 3);  
  310.   int num = header[0]; //正负样本总数  
  311.   labelsize = header[1]; // labelsize = 5;  [label id level x y]  
  312.   X.numblocks = header[2]; // 2  
  313.   X.blocksizes = (int *)malloc(X.numblocks*sizeof(int)); //(1)=1,(2)= root.h*root.w/2*31  
  314.   count = fread(X.blocksizes, sizeof(int), X.numblocks, f);  
  315.   check(count == X.numblocks);  
  316.   X.regmult = (float *)malloc(sizeof(float)*X.numblocks); //0 ,1  
  317.   check(X.regmult != NULL);  
  318.   count = fread(X.regmult, sizeof(float), X.numblocks, f);  
  319.   check(count == X.numblocks);  
  320.   X.learnmult = (float *)malloc(sizeof(float)*X.numblocks);//20, 1  
  321.   check(X.learnmult != NULL);  
  322.   count = fread(X.learnmult, sizeof(float), X.numblocks, f);  
  323.   check(count == X.numblocks);  
  324.   check(num != 0);  
  325.   fclose(f);  
  326.   printf("%d examples with label size %d and %d blocks\n",  
  327.      num, labelsize, X.numblocks);  
  328.   printf("block size, regularization multiplier, learning rate multiplier\n");  
  329.   dim = 0;  
  330.   for (int i = 0; i < X.numblocks; i++) {  
  331.     dim += X.blocksizes[i];  
  332.     printf("%d, %.2f, %.2f\n", X.blocksizes[i], X.regmult[i], X.learnmult[i]);  
  333.   }  
  334.   
  335.   // ---------------从 datfile 读取  正负 examples----------------  
  336.   // examples [i] 存储了第i个样本的信息 长度为 1 int + 7 int +dim 个float + 1 byte  
  337.   // 1 int      legth 样本包括信息头在内的总字节长度  
  338.   // 7 int      [1/-1 id 0 0 0 2 dim] ,id为样本编号,[label id level centry_x centry_y],2是block个数  
  339.   // dim float  feature,dim=2+1+root.h*root.w/2*31,意义如下  
  340.   //         block1 label | block2 data|block2 lable | block2 data  
  341.   //               1      |       1    |     2       |  h*w/2*31个float  
  342.   // 1 byte     unique=0  
  343.   f = fopen(datfile, "rb");  
  344.   check(f != NULL);  
  345.   printf("Reading examples\n");  
  346.     
  347.   //+,-example数据  
  348.   char **examples = (char **)malloc(num*sizeof(char *));   
  349.     
  350.   check(examples != NULL);  
  351.     for (int i = 0; i < num; i++) {  
  352.     // we use an extra byte in the end of each example to mark unique  
  353.     // we use an extra int at the start of each example to store the   
  354.     // example's byte length (excluding unique flag and this int)  
  355.     //[legth label id level x y  unique] unique=0  
  356.     int buf[labelsize+2];   
  357.     //写入时的值为[1/-1 i 0 0 0 2 dim]   
  358.     count = fread(buf, sizeof(int), labelsize+2, f);  
  359.     check(count == labelsize+2);  
  360.     // byte length of an example's data segment  
  361.       
  362.     //---前面七个是头,后面dim个float是样本特征数据,dim=2+1+root.h*root.w/2*31  
  363.     int len = sizeof(int)*(labelsize+2) + sizeof(float)*buf[labelsize+1];     
  364.     // memory for data, an initial integer, and a final byte  
  365.     examples[i] = (char *)malloc(sizeof(int)+len+1);  
  366.       
  367.     check(examples[i] != NULL);  
  368.     // set data segment's byte length  
  369.     ((int *)examples[i])[0] = len;  
  370.     // set the unique flag to zero  
  371.     examples[i][sizeof(int)+len] = 0;  
  372.     // copy label data into example  
  373.     for (int j = 0; j < labelsize+2; j++)  
  374.       ((int *)examples[i])[j+1] = buf[j];  
  375.     // read the rest of the data segment into the example  
  376.     count = fread(examples[i]+sizeof(int)*(labelsize+3), 1,   
  377.           len-sizeof(int)*(labelsize+2), f);  
  378.     check(count == len-sizeof(int)*(labelsize+2));  
  379.   }  
  380.   fclose(f);  
  381.   printf("done\n");  
  382.   
  383.   // sort  
  384.   printf("Sorting examples\n");  
  385.   char **sorted = (char **)malloc(num*sizeof(char *));  
  386.   check(sorted != NULL);  
  387.   memcpy(sorted, examples, num*sizeof(char *));  
  388.     
  389.   //qsort 库函数,真正的比较函数为 comp  
  390.   //从小到大,快速排序  
  391.   //依次按照 样本类别->id->level->cx->cy  排序样本  
  392.   //如果前面五个量都一样……  
  393.   //1.等长度,比较所有字节;  
  394.   //2.谁长谁小,长度不同是因为不同的component的 尺寸不一致   
  395.     
  396.   qsort(sorted, num, sizeof(char *), comp);   
  397.   printf("done\n");  
  398.   
  399.   // find unique examples  
  400.   // 唯一的样本,unique flag=1,  
  401.   // 相同的样本第一个样本的unique flag为1,其余为0 ,有的样本的位置被,unique替代了,但是并没有完全删除掉  
  402.   int i = 0;  
  403.   int len = *((int *)sorted[0]); //负样本的第一个  
  404.   sorted[0][sizeof(int)+len] = 1; // unique flag 置 1  
  405.   for (int j = 1; j < num; j++) {  
  406.     int alen = *((int *)sorted[i]);  
  407.     int blen = *((int *)sorted[j]);  
  408.     if (alen != blen || memcmp(sorted[i] + sizeof(int), sorted[j] + sizeof(int), alen)) //component不同 || 不同样本  
  409.     {  
  410.       i++;  
  411.       sorted[i] = sorted[j];  
  412.       sorted[i][sizeof(int)+blen] = 1; //标记为 unique  
  413.     }  
  414.   }  
  415.   int num_unique = i+1;  
  416.   printf("%d unique examples\n", num_unique);  
  417.   
  418.   // -------------------collapse examples----------------  
  419.   // 前面是找完全不一样的样本,这里是分组  
  420.   // label 的五个量 [label id level centry_x centry_y] 相同的分为一组,在detect时,写入了datfile   
  421.   // 负样本的 cx,cy都是相对于整张图片的,正样本是相对于剪切后的图像  
  422.   // 前面五个全相同,  
  423.   // 对于phase1 不可能,因为正负样本的id都不相同  
  424.   // 对于phase2 正样本只保留了最有可能是正样本的样本,只有一种情况,  
  425.   // rootfilter1,rootfilter2在同一张图片(id相同),检测出来的 Hard负样本 的cx,cy相同,因此一组最多应该只能出现2个 (待验证)  
  426.   // 原因是此时的latent variable 为(cx,cy,component),上述情况相下,我们只能保留component1或者component2  
  427.   // 后续训练时,这两个量是连续使用的,为什么呢??  
  428.   // collapse.seq(char **) 记录了每一组的第一个样本  
  429.   // collapse.num 每组的个数  
  430.   // X.num 组数  
  431.   // X.x=&collapse[0],也就是第一个 collapse的地址  
  432.   collapse(&X, sorted, num_unique);  
  433.   printf("%d collapsed examples\n", X.num);  
  434.   
  435.   // initial model  
  436.   // 读modfile文件,得到w的初始值。phase 1 初始化为全 0,phase 2 为上一次训练的结果……  
  437.   double **w = (double **)malloc(sizeof(double *)*X.numblocks);//2  
  438.   check(w != NULL);  
  439.   f = fopen(modfile, "rb");  
  440.   for (int i = 0; i < X.numblocks; i++) {  
  441.     w[i] = (double *)malloc(sizeof(double)*X.blocksizes[i]); //(1)=1,(2)= root.h*root.w/2*31  
  442.     check(w[i] != NULL);  
  443.     count = fread(w[i], sizeof(double), X.blocksizes[i], f);  
  444.     check(count == X.blocksizes[i]);  
  445.   }  
  446.   fclose(f);  
  447.   
  448.   // lower bounds  
  449.   // 读lobfile文件,初始化为全 滤波器参数下线-100 ……  
  450.   double **lb = (double **)malloc(sizeof(double *)*X.numblocks);  
  451.   check(lb != NULL);  
  452.   f = fopen(lobfile, "rb");  
  453.   for (int i = 0; i < X.numblocks; i++) {  
  454.     lb[i] = (double *)malloc(sizeof(double)*X.blocksizes[i]);  
  455.     check(lb[i] != NULL);  
  456.     count = fread(lb[i], sizeof(double), X.blocksizes[i], f);  
  457.     check(count == X.blocksizes[i]);  
  458.   }  
  459.   fclose(f);  
  460.     
  461.   
  462.   printf("Training");  
  463.   //-------------------------------- train -------------------------------  
  464.   //-----梯度下降发训练参数 w,参见论文 公式17 后面的步骤  
  465.   gd(C, J, X, w, lb);  
  466.   printf("done\n");  
  467.   
  468.   // save model  
  469.   printf("Saving model\n");  
  470.   f = fopen(modfile, "wb");  
  471.   check(f != NULL);  
  472.   //   存储 block1,block2的训练结果,w  
  473.   for (int i = 0; i < X.numblocks; i++) {  
  474.     count = fwrite(w[i], sizeof(double), X.blocksizes[i], f);  
  475.     check(count == X.blocksizes[i]);  
  476.   }  
  477.   fclose(f);  
  478.   
  479.   // score examples  
  480.   // ---所有的样本都的得分,没有乘以 label y   
  481.   printf("Scoring\n");  
  482.   double *s = score(X, examples, num, w);  
  483.   
  484.   // ---------Write info file-------------  
  485.   printf("Writing info file\n");  
  486.   f = fopen(inffile, "w");  
  487.   check(f != NULL);  
  488.   for (int i = 0; i < num; i++) {  
  489.     int len = ((int *)examples[i])[0];  
  490.     // label, score, unique flag  
  491.     count = fprintf(f, "%d\t%f\t%d\n", ((int *)examples[i])[1], s[i],   
  492.                     (int)examples[i][sizeof(int)+len]);  
  493.     check(count > 0);  
  494.   }  
  495.   fclose(f);  
  496.     
  497.   printf("Freeing memory\n");  
  498.   for (int i = 0; i < X.numblocks; i++) {  
  499.     free(w[i]);  
  500.     free(lb[i]);  
  501.   }  
  502.   free(w);  
  503.   free(lb);  
  504.   free(s);  
  505.   for (int i = 0; i < num; i++)  
  506.     free(examples[i]);  
  507.   free(examples);  
  508.   free(sorted);  
  509.   free(X.x);  
  510.   free(X.blocksizes);  
  511.   free(X.regmult);  
  512.   free(X.learnmult);  
  513.   
  514.   return 0;  
  515. }  

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/615856
推荐阅读
相关标签
  

闽ICP备14008679号