Step 95: Deep Learning Image Object Detection: Faster R-CNN Modeling

Demonstrated on 64-bit Windows 10

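I. Preface

In this installment, we begin a series on deep learning for image object detection.

Object detection with deep learning is an important subfield of computer vision. Its core goal is to use a deep learning model to recognize and localize specific targets in an image; the targets can be objects, people, animals, or any other recognizable entity. Unlike plain image classification, object detection must determine not only which classes of target are present in the image but also the position and size of each one. This is usually done by drawing one or more bounding boxes on the image, each marking the position and extent of a target.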

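II. Introduction to Faster R-CNN

Faster R-CNN is a popular deep learning object detection algorithm proposed in 2015 by Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun. It is an important milestone in the R-CNN family because it raised detection speed while keeping accuracy high. Its main features and components are: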

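(1) Region Proposal Network (RPN):

The core innovation of Faster R-CNN is a component called the Region Proposal Network (RPN). The RPN generates bounding-box proposals directly on the convolutional feature map, which greatly reduces the time spent computing proposals. At every feature-map location, the RPN places a set of anchor boxes (anchors) with fixed scales and aspect ratios, and for each anchor it predicts coordinate offsets and the probability that an object is present.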
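To make the anchor idea concrete, here is a minimal sketch that generates the anchors centred on a single feature-map location. The function name make_anchors and the default scales and aspect ratios are illustrative assumptions (3 scales x 3 ratios, as in the original paper); this is not the torchvision implementation.

```python
import math


def make_anchors(cx, cy, scales=(128, 256, 512), ratios=(0.5, 1.0, 2.0)):
    """Return (xmin, ymin, xmax, ymax) anchors centred at (cx, cy).

    Each anchor has area scale**2 and a height/width ratio equal to `ratio`.
    """
    anchors = []
    for scale in scales:
        for ratio in ratios:
            w = scale / math.sqrt(ratio)
            h = scale * math.sqrt(ratio)
            anchors.append((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2))
    return anchors


anchors = make_anchors(300, 200)
print(len(anchors))   # 9 anchors: 3 scales x 3 aspect ratios
print(anchors[1])     # the square anchor of the smallest scale
```

The RPN then scores each of these boxes and regresses four offsets that shift and resize it toward a nearby object.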

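(2) Shared convolutional features:

Its predecessor, Fast R-CNN, relied on an external proposal method such as Selective Search. In Faster R-CNN, the RPN and the final detection head share the same convolutional features, so the backbone needs only one forward pass per image, which greatly improves computational efficiency.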

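(3) ROI Pooling:

Once the region proposals are available, Faster R-CNN uses ROI (Region of Interest) Pooling to extract a fixed-size feature from each proposal. Whatever the size of a proposal, its features can then be fed into fully connected layers with a fixed input size for classification and bounding-box regression.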
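The fixed-size property can be illustrated with a toy, pure-Python max pooling over a region of a 2D feature map. The function roi_max_pool and the sample feature map are hypothetical and only sketch the idea; in practice torchvision provides ready-made operators (torchvision.ops.roi_pool and torchvision.ops.roi_align).

```python
def roi_max_pool(feature_map, roi, output_size=2):
    """Max-pool the region roi = (x0, y0, x1, y1) into an output_size x output_size grid.

    Coordinates are feature-map cells; x1 and y1 are exclusive.
    """
    x0, y0, x1, y1 = roi
    height, width = y1 - y0, x1 - x0
    pooled = []
    for i in range(output_size):
        # Row range of bin i; max(..., r0 + 1) keeps every bin non-empty
        r0 = y0 + (i * height) // output_size
        r1 = max(y0 + ((i + 1) * height) // output_size, r0 + 1)
        row = []
        for j in range(output_size):
            c0 = x0 + (j * width) // output_size
            c1 = max(x0 + ((j + 1) * width) // output_size, c0 + 1)
            row.append(max(feature_map[r][c]
                           for r in range(r0, r1) for c in range(c0, c1)))
        pooled.append(row)
    return pooled


feature_map = [
    [1, 2, 3, 4, 5, 6],
    [7, 8, 9, 10, 11, 12],
    [13, 14, 15, 16, 17, 18],
    [19, 20, 21, 22, 23, 24],
]
print(roi_max_pool(feature_map, (0, 0, 6, 4)))  # [[9, 12], [21, 24]]
print(roi_max_pool(feature_map, (1, 1, 4, 3)))  # [[8, 10], [14, 16]]
```

Both regions have different sizes, yet each yields a 2 x 2 output, which is what lets the detection head use layers with a fixed input size.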

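(4) Two-task loss:

The RPN is trained on two tasks at once: classification (object vs. non-object) and bounding-box regression. Combining the two losses pushes the RPN to produce proposals that both separate objects from background and are well localized.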
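As a rough illustration of the two-task idea, the sketch below adds a binary cross-entropy term (object vs. non-object) to a smooth L1 regression term that is counted only for object anchors. The function names, the normalization, and the reg_weight parameter are simplifying assumptions; the paper and torchvision normalize and sample anchors differently.

```python
import math


def smooth_l1(x, beta=1.0):
    """Smooth L1 penalty for a single regression residual."""
    x = abs(x)
    return 0.5 * x * x / beta if x < beta else x - 0.5 * beta


def rpn_loss(probs, labels, pred_deltas, true_deltas, reg_weight=1.0):
    """Toy RPN loss for a handful of anchors.

    probs: predicted objectness probabilities in (0, 1)
    labels: 1 for object anchors, 0 for background anchors
    pred_deltas / true_deltas: four box offsets per anchor
    """
    # Classification term: binary cross-entropy averaged over all anchors
    cls = -sum(
        math.log(p) if y == 1 else math.log(1.0 - p)
        for p, y in zip(probs, labels)
    ) / len(labels)
    # Regression term: smooth L1, counted only for object (positive) anchors
    num_pos = max(sum(labels), 1)
    reg = sum(
        smooth_l1(pd - td)
        for y, pds, tds in zip(labels, pred_deltas, true_deltas) if y == 1
        for pd, td in zip(pds, tds)
    ) / num_pos
    return cls + reg_weight * reg, cls, reg


total, cls, reg = rpn_loss(
    probs=[0.9, 0.2],
    labels=[1, 0],
    pred_deltas=[[0.1, 0.0, 0.5, 2.0], [9.0, 9.0, 9.0, 9.0]],
    true_deltas=[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
)
print(round(cls, 4), round(reg, 4), round(total, 4))
```

The background anchor's large offsets do not contribute to the regression term, because a box target is only meaningful where an object exists.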

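In short, by introducing the Region Proposal Network and sharing convolutional features, Faster R-CNN greatly improved the speed and accuracy of object detection and laid a solid foundation for later research and applications.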

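III. Data Source

The data come from a public dataset. The images (.jpg) and their annotation files (.xml) are stored together in the "tuberculosis-phonecamera" folder.

The task, roughly: mark the location of each MTB (Mycobacterium tuberculosis) with a bounding box.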
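For orientation, the sketch below parses an annotation with the object / bndbox / xmin, ymin, xmax, ymax structure that the code in the next section reads. The XML content here, including the "MTB" label name and the coordinates, is a made-up example; the real files may contain additional fields.

```python
import xml.etree.ElementTree as ET

# Hypothetical annotation content, for illustration only
SAMPLE_XML = """
<annotation>
  <filename>tuberculosis-phone-0001.jpg</filename>
  <object>
    <name>MTB</name>
    <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>60</xmax><ymax>80</ymax></bndbox>
  </object>
  <object>
    <name>MTB</name>
    <bndbox><xmin>100</xmin><ymin>120</ymin><xmax>150</xmax><ymax>170</ymax></bndbox>
  </object>
</annotation>
"""


def boxes_from_xml(xml_text):
    """Return one (xmin, ymin, xmax, ymax) tuple per <object> element."""
    root = ET.fromstring(xml_text.strip())
    boxes = []
    for obj in root.findall("object"):
        bndbox = obj.find("bndbox")
        boxes.append(tuple(int(bndbox.find(tag).text)
                           for tag in ("xmin", "ymin", "xmax", "ymax")))
    return boxes


print(boxes_from_xml(SAMPLE_XML))  # [(10, 20, 60, 80), (100, 120, 150, 170)]
```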

IV. Faster R-CNN in Practice

Straight to the code:

import os
import random
import xml.etree.ElementTree as ET

import matplotlib.pyplot as plt
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F


# Parse a Pascal VOC-style XML annotation file into a list of boxes
def parse_xml(xml_path):
    tree = ET.parse(xml_path)
    root = tree.getroot()
    boxes = []
    for obj in root.findall("object"):
        bndbox = obj.find("bndbox")
        # int(float(...)) also accepts coordinates written as "12.0"
        xmin = int(float(bndbox.find("xmin").text))
        ymin = int(float(bndbox.find("ymin").text))
        xmax = int(float(bndbox.find("xmax").text))
        ymax = int(float(bndbox.find("ymax").text))
        # Keep only boxes with positive width and height
        if xmin < xmax and ymin < ymax:
            boxes.append((xmin, ymin, xmax, ymax))
        else:
            print(f"Warning: Ignored invalid box in {xml_path} - ({xmin}, {ymin}, {xmax}, {ymax})")
    return boxes


# Randomly split the image files into training and validation sets
def split_data(image_dir, split_ratio=0.8):
    all_images = [f for f in os.listdir(image_dir) if f.endswith(".jpg")]
    random.shuffle(all_images)
    split_idx = int(len(all_images) * split_ratio)
    train_images = all_images[:split_idx]
    val_images = all_images[split_idx:]
    return train_images, val_images


# Dataset class for the tuberculosis dataset
class TuberculosisDataset(torch.utils.data.Dataset):
    def __init__(self, image_dir, annotation_dir, image_list, transform=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.image_list = image_list
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_list[idx])
        image = Image.open(image_path).convert("RGB")
        xml_name = os.path.splitext(self.image_list[idx])[0] + ".xml"
        xml_path = os.path.join(self.annotation_dir, xml_name)
        boxes = parse_xml(xml_path)
        # Images without any valid box return None; collate_fn drops them
        if len(boxes) == 0:
            return None
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.ones((len(boxes),), dtype=torch.int64)  # class 1 = MTB
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            "iscrowd": iscrowd,
        }
        if self.transform:
            image = self.transform(image)
        return image, target


# torchvision detection models normalize images internally, so the transform
# only converts the PIL image to a float tensor in [0, 1]
data_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])


# Collate function that drops None samples; an entirely empty batch becomes
# a pair of empty tuples, which the loops below skip
def collate_fn(batch):
    batch = [sample for sample in batch if sample is not None]
    if len(batch) == 0:
        return (), ()
    return tuple(zip(*batch))


# Build a Faster R-CNN model with a new box predictor for num_classes classes
def get_model(num_classes):
    # weights="DEFAULT" needs torchvision >= 0.13; older versions use pretrained=True
    model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model


# Save either the full model or only its state_dict
def save_model(model, path="fasterrcnn_mtb.pth", save_full_model=False):
    if save_full_model:
        torch.save(model, path)
    else:
        torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")


# Intersection over Union of two boxes (pixel-inclusive "+1" convention)
def compute_iou(boxA, boxB):
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
    iou = interArea / float(boxAArea + boxBArea - interArea)
    return iou


# Training function that also records the loss and a rough IoU per epoch
def train_model(model, train_loader, optimizer, device, num_epochs=10):
    model.to(device)
    model.train()
    loss_values = []
    iou_values = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        total_ious = 0.0
        num_boxes = 0
        for images, targets in train_loader:
            # Skip batches in which every sample was filtered out
            if len(images) == 0:
                continue
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            epoch_loss += losses.item()
            num_batches += 1
            # Rough IoU monitor: best IoU of each predicted box with any
            # ground-truth box, on the batch that was just trained on
            model.eval()
            with torch.no_grad():
                predictions = model(images)
            model.train()
            for prediction, target in zip(predictions, targets):
                pred_boxes = prediction["boxes"].cpu().numpy()
                true_boxes = target["boxes"].cpu().numpy()
                for pred_box in pred_boxes:
                    total_ious += max(compute_iou(pred_box, true_box) for true_box in true_boxes)
                    num_boxes += 1
        # max(..., 1) avoids division by zero when nothing was processed
        avg_loss = epoch_loss / max(num_batches, 1)
        avg_iou = total_ious / max(num_boxes, 1)
        loss_values.append(avg_loss)
        iou_values.append(avg_iou)
        print(f"Epoch {epoch+1}/{num_epochs} Loss: {avg_loss} Avg IoU: {avg_iou}")
    # Plot loss and IoU values
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(loss_values, label="Training Loss")
    plt.title("Training Loss across Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.subplot(1, 2, 2)
    plt.plot(iou_values, label="IoU")
    plt.title("IoU across Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("IoU")
    plt.show()
    # Save model after training
    save_model(model)


# Validation function: runs a forward pass only; the metrics are computed
# by evaluate_model further below
def validate_model(model, val_loader, device):
    model.eval()
    model.to(device)
    with torch.no_grad():
        for images, targets in val_loader:
            if len(images) == 0:
                continue
            images = [image.to(device) for image in images]
            model(images)


# Paths to your data (images and XML files share one folder)
image_dir = "tuberculosis-phonecamera"
annotation_dir = "tuberculosis-phonecamera"
# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Split data
train_images, val_images = split_data(image_dir)
# Create datasets and dataloaders
train_dataset = TuberculosisDataset(image_dir, annotation_dir, train_images, transform=data_transform)
val_dataset = TuberculosisDataset(image_dir, annotation_dir, val_images, transform=data_transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)
# Model and optimizer (2 classes: background + MTB)
model = get_model(2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
# Train and validate
train_model(model, train_loader, optimizer, device=device, num_epochs=100)
validate_model(model, val_loader, device=device)


####################################### Print Metrics ######################################
def calculate_metrics(predictions, ground_truths, iou_threshold=0.5):
    TP = 0  # True Positives
    FP = 0  # False Positives
    FN = 0  # False Negatives
    total_iou = 0.0  # to calculate mean IoU
    for pred, gt in zip(predictions, ground_truths):
        pred_boxes = pred["boxes"].cpu().numpy()
        gt_boxes = gt["boxes"].cpu().numpy()
        matched_gt = set()  # each ground-truth box may be matched at most once
        # Predictions are processed in the order returned by the model
        for pred_box in pred_boxes:
            max_iou = 0.0
            best_gt = -1
            for j, gt_box in enumerate(gt_boxes):
                iou = compute_iou(pred_box, gt_box)
                if iou > max_iou:
                    max_iou = iou
                    best_gt = j
            total_iou += max_iou
            if max_iou > iou_threshold and best_gt not in matched_gt:
                matched_gt.add(best_gt)
                TP += 1
            else:
                FP += 1
        # Ground-truth boxes of this image that no prediction matched
        FN += len(gt_boxes) - len(matched_gt)
    precision = TP / (TP + FP) if (TP + FP) != 0 else 0
    recall = TP / (TP + FN) if (TP + FN) != 0 else 0
    f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) != 0 else 0
    mean_iou = total_iou / (TP + FP) if (TP + FP) != 0 else 0
    return precision, recall, f1_score, mean_iou


def evaluate_model(model, dataloader, device):
    model.eval()
    model.to(device)
    all_predictions = []
    all_ground_truths = []
    with torch.no_grad():
        for images, targets in dataloader:
            if len(images) == 0:
                continue
            images = [image.to(device) for image in images]
            predictions = model(images)
            all_predictions.extend(predictions)
            all_ground_truths.extend(targets)
    return calculate_metrics(all_predictions, all_ground_truths)


train_precision, train_recall, train_f1, train_iou = evaluate_model(model, train_loader, device)
val_precision, val_recall, val_f1, val_iou = evaluate_model(model, val_loader, device)
print("Training Set Metrics:")
print(f"Precision: {train_precision:.4f}, Recall: {train_recall:.4f}, F1 Score: {train_f1:.4f}, Mean IoU: {train_iou:.4f}")
print("\nValidation Set Metrics:")
print(f"Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, F1 Score: {val_f1:.4f}, Mean IoU: {val_iou:.4f}")
# Summary table
divider = "+-----------+--------------+----------------+"
print(divider)
print(f"| {'Metric':<9} | {'Training Set':<12} | {'Validation Set':<14} |")
print(divider)
for name, train_value, val_value in [
    ("Precision", train_precision, val_precision),
    ("Recall", train_recall, val_recall),
    ("F1 Score", train_f1, val_f1),
    ("Mean IoU", train_iou, val_iou),
]:
    print(f"| {name:<9} | {train_value:<12.4f} | {val_value:<14.4f} |")
print(divider)


####################################### Train Set ######################################
def plot_predictions_on_image(model, dataset, device, title):
    # Pick a random sample, skipping images that have no valid boxes
    sample = None
    while sample is None:
        sample = dataset[random.randrange(len(dataset))]
    image, target = sample
    # Use the model to make predictions
    model.eval()
    with torch.no_grad():
        prediction = model([image.to(device)])
    # The tensor is already in [0, 1], so it converts directly to a PIL image
    pil_image = F.to_pil_image(image)
    # Plot the image with ground truth boxes
    plt.figure(figsize=(10, 6))
    plt.title(title + " with Ground Truth Boxes")
    plt.imshow(pil_image)
    ax = plt.gca()
    # Draw the ground truth boxes in blue
    for box in target["boxes"].tolist():
        rect = plt.Rectangle(
            (box[0], box[1]), box[2] - box[0], box[3] - box[1],
            fill=False, color='blue', linewidth=2
        )
        ax.add_patch(rect)
    plt.show()
    # Plot the image with predicted boxes
    plt.figure(figsize=(10, 6))
    plt.title(title + " with Predicted Boxes")
    plt.imshow(pil_image)
    ax = plt.gca()
    # Draw the predicted boxes in red
    for box in prediction[0]["boxes"].cpu().tolist():
        rect = plt.Rectangle(
            (box[0], box[1]), box[2] - box[0], box[3] - box[1],
            fill=False, color='red', linewidth=2
        )
        ax.add_patch(rect)
    plt.show()


# Call the function for a random image from the train dataset
plot_predictions_on_image(model, train_dataset, device, "Selected from Training Set")
####################################### Val Set ######################################
# Call the function for a random image from the validation dataset
plot_predictions_on_image(model, val_dataset, device, "Selected from Validation Set")
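Since the metrics above all rest on IoU, a small self-contained sketch with hand-checkable numbers may help. It uses the same pixel-inclusive (+1) convention as compute_iou, and counts TP/FP/FN for one image with one-to-one matching, so a second prediction on an already matched ground-truth box counts as a false positive. The function names and the boxes are made up for illustration.

```python
def iou(box_a, box_b):
    """IoU of two (xmin, ymin, xmax, ymax) boxes, pixel-inclusive (+1) convention."""
    xa, ya = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    xb, yb = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0, xb - xa + 1) * max(0, yb - ya + 1)
    area_a = (box_a[2] - box_a[0] + 1) * (box_a[3] - box_a[1] + 1)
    area_b = (box_b[2] - box_b[0] + 1) * (box_b[3] - box_b[1] + 1)
    return inter / (area_a + area_b - inter)


def count_matches(pred_boxes, gt_boxes, iou_threshold=0.5):
    """Return (TP, FP, FN) for one image; each ground-truth box matches at most once."""
    matched = set()
    tp = fp = 0
    for pred in pred_boxes:
        # Find the ground-truth box that overlaps this prediction the most
        best_j, best_iou = -1, 0.0
        for j, gt in enumerate(gt_boxes):
            value = iou(pred, gt)
            if value > best_iou:
                best_j, best_iou = j, value
        if best_iou > iou_threshold and best_j not in matched:
            matched.add(best_j)
            tp += 1
        else:
            fp += 1  # poor overlap, or a duplicate of an already matched box
    fn = len(gt_boxes) - len(matched)
    return tp, fp, fn


# Two 10x10 boxes sharing a 5x5 corner: 25 / (100 + 100 - 25) = 1/7
print(iou((0, 0, 9, 9), (5, 5, 14, 14)))

gt = [(0, 0, 9, 9), (50, 50, 59, 59)]
preds = [(0, 0, 9, 9), (1, 0, 10, 9), (100, 100, 109, 109)]
tp, fp, fn = count_matches(preds, gt)
print(tp, fp, fn)                      # 1 2 1
print(tp / (tp + fp), tp / (tp + fn))  # precision 1/3, recall 1/2
```

Here the second prediction overlaps the first ground-truth box well but arrives after it has been matched, and the third overlaps nothing, so both are false positives; the second ground-truth box is never found, giving one false negative.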

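I won't walk through the code here; the GPT prompt is given for reference:

Prompt: I have a batch of data stored in the "tuberculosis-phonecamera" folder. It has two parts:

The first part is acid-fast stained sputum smear images of MTB in jpg format, named "tuberculosis-phone-0001.jpg", "tuberculosis-phone-0002.jpg", and so on;

The second part is the annotation files for those images. They mark the exact locations of MTB in each acid-fast stained sputum smear image as a number of red boxes, are in XML format, and are named "tuberculosis-phone-0001.xml", "tuberculosis-phone-0002.xml", and so on. I am uploading one xml file as an example;

Based on the data above, I need to build a Faster R-CNN object detection model with pytorch that detects MTB in the acid-fast stained sputum smear images and marks them with red boxes. The data should be randomly split into a training set (80%) and a validation set (20%).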
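The random 80/20 split requested in the prompt gives a different partition on every run unless it is seeded. A minimal sketch of a reproducible split follows; the function name split_names and the seed parameter are assumptions for illustration and are not part of the code above.

```python
import random


def split_names(names, split_ratio=0.8, seed=0):
    """Shuffle a sorted copy of names with a fixed seed and split it into train/validation."""
    names = sorted(names)  # fixed starting order, independent of the file system
    random.Random(seed).shuffle(names)
    split_idx = int(len(names) * split_ratio)
    return names[:split_idx], names[split_idx:]


files = [f"tuberculosis-phone-{i:04d}.jpg" for i in range(1, 11)]
train, val = split_names(files)
print(len(train), len(val))  # 8 2
```

With a fixed seed, training-set and validation-set metrics from different runs refer to the same images, which makes them easier to compare.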

Let's look at the results:

(1) Loss curve:

(2) Performance metrics:

(3) Prediction on a training-set image:

(4) Prediction on a validation-set image:

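V. Afterword

The pretrained model was used directly, with no hyperparameter tuning. Even so, the metrics on the training set are quite high, while the validation set falls somewhat short. Getting better performance will take serious study of how to tune the model.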
