赞
踩
这个系列是对哔哩哔哩up主霹雳吧啦Wz所出的FasterRCNN源码解析的视频进行一个记录以及加上自己理解(可能没有多少,更多的是对数据类型怎么变换的进行一个记录),首先学习源码的第一步就是先跑通目标代码
这里附上霹雳吧啦Wz的github链接:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing
课程中的代码都在git中,大家可以自行下载
作者在视频中跑的是mobilnet模型,这里我们尝试跑一下res50+fpn的模型
create_model
这个类是定义模型的部分。
这里需要注意的是
backbone = resnet50_fpn_backbone()
会自动的冻结部分底层权重
代码如下(示例):
def create_model(num_classes): backbone = resnet50_fpn_backbone() # 训练自己数据集时不要修改这里的91,修改的是传入的num_classes参数 model = FasterRCNN(backbone=backbone, num_classes=91) # 载入预训练模型权重 # https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth weights_dict = torch.load("./backbone/fasterrcnn_resnet50_fpn_coco.pth") missing_keys, unexpected_keys = model.load_state_dict(weights_dict, strict=False) if len(missing_keys) != 0 or len(unexpected_keys) != 0: print("missing_keys: ", missing_keys) print("unexpected_keys: ", unexpected_keys) # get number of input features for the classifier in_features = model.roi_heads.box_predictor.cls_score.in_features # replace the pre-trained head with a new one model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) return model
main
这就是训练的主干部分了,
主要步骤有:
data_transform
图像预处理函数def main(parser_data):
device <
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。