
Implementing SimpleNet yourself: a code reproduction

A from-scratch reproduction of the SimpleNet unsupervised anomaly-detection code, in three parts: data loading, the network, and the training loop.

Data loading

import os

import PIL.Image  # bare "import PIL" does not expose PIL.Image
import torch
from torchvision import transforms

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        source,
        resize=512,
        imagesize=512,
        rotate_degrees=0,
        translate=0,
        brightness_factor=0,
        contrast_factor=0,
        saturation_factor=0,
        gray_p=0,
        h_flip_p=0,
        v_flip_p=0,
        scale=0,
        **kwargs,
    ):
        super().__init__()
        self.source = source
        self.imgpathlist = [
            os.path.join(self.source, filename) for filename in os.listdir(self.source)
        ]
        self.transform_img = [
            transforms.Resize(resize),
            # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR),
            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),
            transforms.RandomHorizontalFlip(h_flip_p),
            transforms.RandomVerticalFlip(v_flip_p),
            transforms.RandomGrayscale(gray_p),
            transforms.RandomAffine(
                rotate_degrees,
                translate=(translate, translate),
                scale=(1.0 - scale, 1.0 + scale),
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
            transforms.CenterCrop(imagesize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
        self.transform_img = transforms.Compose(self.transform_img)
        self.transform_mask = [
            transforms.Resize(resize),
            transforms.CenterCrop(imagesize),
            transforms.ToTensor(),
        ]
        self.transform_mask = transforms.Compose(self.transform_mask)
        self.imagesize = (3, imagesize, imagesize)

    def __getitem__(self, idx):
        image = PIL.Image.open(self.imgpathlist[idx]).convert("RGB")
        image = self.transform_img(image)
        # Training images are all defect-free, so the ground-truth mask is all zeros.
        mask = torch.zeros([1, *image.size()[1:]])
        return image, mask

    def __len__(self):
        return len(self.imgpathlist)


# dataset = Dataset(source=r"D:\dataset\mvtec_anomaly_detection\bottle\train\good")
# image, mask = dataset.__getitem__(0)
# print(image.shape, mask.shape)
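
For a quick sanity check, the commented lines above can be expanded into a short script. This is a minimal sketch; the MVTec path is just the example used later in the training code, so substitute any folder of defect-free training images:

from torch.utils.data import DataLoader

# Example path; any directory of defect-free training images works.
dataset = Dataset(source=r"D:\dataset\mvtec_anomaly_detection\bottle\train\good")
loader = DataLoader(dataset, batch_size=2, shuffle=True)

images, masks = next(iter(loader))
print(images.shape)  # torch.Size([2, 3, 512, 512])
print(masks.shape)   # torch.Size([2, 1, 512, 512])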

Network

import copy

import numpy as np
import scipy.ndimage as ndimage
import torch
import torch.nn.functional as F
from torch import nn, onnx

import backbones  # backbone loader from the SimpleNet repo


def pth2onnx(model, dummy_input, dynamiconnx):
    torch.set_grad_enabled(False)
    input_names = ["input1"]
    output_names = ["output1"]
    # Export an ONNX model with a dynamic batch dimension.
    onnx.export(model=model, args=dummy_input, f=dynamiconnx,
                input_names=input_names, output_names=output_names, verbose=False,
                dynamic_axes=dict([(k, {0: 'batch_size'}) for k in input_names] +
                                  [(k, {0: 'batch_size'}) for k in output_names]),
                keep_initializers_as_inputs=True)


class Convolution(nn.Module):
    """Conv2d -> BatchNorm2d -> Mish."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.active = nn.Mish(inplace=True)

    def forward(self, x):
        return self.active(self.bn(self.conv(x)))


class PatchMaker(nn.Module):
    def __init__(self, patchsize, top_k=0, stride=None):
        super().__init__()
        self.patchsize = patchsize
        self.stride = stride
        self.top_k = top_k

    def forward(self, features):
        """Convert a tensor into a tensor of respective patches.

        Args:
            features: [torch.Tensor, bs x c x w x h]
        Returns:
            [torch.Tensor, bs * w//stride * h//stride, c, patchsize, patchsize]
        """
        return_spatial_info = True
        padding = int((self.patchsize - 1) / 2)  # 1 for patchsize=3
        unfolder = nn.Unfold(
            kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1
        )
        unfolded_features = unfolder(features)
        number_of_total_patches = []
        for s in features.shape[-2:]:
            n_patches = (s + 2 * padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1
            number_of_total_patches.append(int(n_patches))
        unfolded_features = unfolded_features.reshape(
            *features.shape[:2], self.patchsize, self.patchsize, -1
        )
        unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3)
        if return_spatial_info:
            return unfolded_features, number_of_total_patches
        return unfolded_features


class Resblock(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.conv1 = Convolution(ch, ch // 2, 1, 1, 0)
        self.conv2 = Convolution(ch // 2, ch // 2, 3, 1, 1)
        self.conv3 = nn.Conv2d(ch // 2, ch, 1, 1)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        y = self.conv1(x)
        y = self.conv2(y)
        y = self.conv3(y)
        return self.relu(x + y)


class Preprocessing(nn.Module):
    def __init__(self, input_dims, output_dim):
        super().__init__()
        self.input_dims = input_dims
        self.output_dim = output_dim
        self.preprocessing_modules = nn.ModuleList()
        for input_dim in input_dims:
            module = MeanMapper(output_dim)
            self.preprocessing_modules.append(module)

    def forward(self, features):
        _features = []
        for module, feature in zip(self.preprocessing_modules, features):
            _features.append(module(feature))
        return torch.stack(_features, dim=1)


class MeanMapper(nn.Module):
    """Pools each flattened patch feature to a fixed length via adaptive average pooling."""

    def __init__(self, preprocessing_dim):
        super().__init__()
        self.preprocessing_dim = preprocessing_dim

    def forward(self, features):
        features = features.reshape(len(features), 1, -1)
        return F.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1)


def init_weight(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
    elif isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight)


class Projection(nn.Module):
    def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0):
        super().__init__()
        if out_planes is None:
            out_planes = in_planes
        self.layers = nn.Sequential()
        _in = None
        _out = None
        for i in range(n_layers):
            _in = in_planes if i == 0 else _out
            _out = out_planes
            self.layers.add_module(f"{i}fc", nn.Linear(_in, _out))
            if i < n_layers - 1:
                # if layer_type > 0:
                #     self.layers.add_module(f"{i}bn", nn.BatchNorm1d(_out))
                if layer_type > 1:
                    self.layers.add_module(f"{i}relu", nn.LeakyReLU(.2))
        self.apply(init_weight)

    def forward(self, x):
        # x = .1 * self.layers(x) + x
        x = self.layers(x)
        return x


class Discriminator(nn.Module):
    def __init__(self, in_planes, n_layers=1, hidden=None):
        super().__init__()
        _hidden = in_planes if hidden is None else hidden
        self.body = nn.Sequential()
        for i in range(n_layers - 1):
            _in = in_planes if i == 0 else _hidden
            _hidden = int(_hidden // 1.5) if hidden is None else hidden
            self.body.add_module('block%d' % (i + 1),
                                 nn.Sequential(
                                     nn.Linear(_in, _hidden),
                                     nn.BatchNorm1d(_hidden),
                                     nn.LeakyReLU(0.2)
                                 ))
        self.tail = nn.Linear(_hidden, 1, bias=False)
        self.apply(init_weight)

    def forward(self, x):
        x = self.body(x)
        x = self.tail(x)
        return x


class ForwardHook:
    def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str):
        self.hook_dict = hook_dict
        self.layer_name = layer_name
        self.raise_exception_to_break = copy.deepcopy(
            layer_name == last_layer_to_extract
        )

    def __call__(self, module, input, output):
        self.hook_dict[self.layer_name] = output
        return None


class NetworkFeatureAggregator(nn.Module):
    """Efficient extraction of network features.

    Runs a network only to the last layer of the list of layers where
    network features should be extracted from.

    Args:
        backbone: torchvision.model
        layers_to_extract_from: [list of str]
    """

    def __init__(self, backbone, layers_to_extract_from, device, train_backbone=False):
        super().__init__()
        self.layers_to_extract_from = layers_to_extract_from
        self.backbone = backbone
        self.device = device
        self.train_backbone = train_backbone
        for extract_layer in layers_to_extract_from:
            if extract_layer == 'layer2':
                self.network_layer2 = backbone.__dict__["_modules"][extract_layer]
            if extract_layer == 'layer3':
                self.network_layer3 = backbone.__dict__["_modules"][extract_layer]
        self.to(self.device)

    def forward(self, images, eval=True):
        y = torch.randn((1, 1000))
        if self.train_backbone and not eval:
            y = self.backbone(images)
        else:
            with torch.no_grad():
                try:
                    y = self.backbone(images)
                except Exception:
                    pass
        return y


class unsupervisedNet(nn.Module):
    def __init__(self, batchsize, train):
        super().__init__()
        self.batchsize = batchsize
        self.train_backbone = train
        self.backbone_name = 'wideresnet50'
        self.backbone = backbones.load(self.backbone_name)
        self.device = 'cpu'  # or 0 for cuda:0
        self.input_shape = [3, 512, 512]
        self.target_size = self.input_shape[-2:]
        self.patchsize = 3
        self.stride = 1
        self.top_k = 0
        self.input_dims = [512, 1024]  # channel counts of layer2 / layer3 in WideResNet-50
        self.output_dim = 1536
        self.smoothing = 4
        # self.feature_dimensions = [512, 1024]
        # self.preprocessing = Preprocessing(self.feature_dimensions, self.output_dim)
        self.preprocessing_modules = nn.ModuleList()
        for input_dim in self.input_dims:
            module = MeanMapper(self.output_dim)
            self.preprocessing_modules.append(module)
        self.pre_projection = Projection(self.output_dim, self.output_dim, 1, 0)
        self.discriminator = Discriminator(self.output_dim, n_layers=2, hidden=1024)
        # layer3 recomputes everything layer2 computes; the two heads are kept
        # separate so the whole model traces/exports as a single graph.
        self.layer2 = nn.Sequential(
            self.backbone.conv1,
            self.backbone.bn1,
            self.backbone.relu,
            self.backbone.maxpool,
            self.backbone.layer1,
            self.backbone.layer2,
        )
        self.layer3 = nn.Sequential(
            self.backbone.conv1,
            self.backbone.bn1,
            self.backbone.relu,
            self.backbone.maxpool,
            self.backbone.layer1,
            self.backbone.layer2,
            self.backbone.layer3,
        )
        self.unfolded_features = []
        self.patch_shapes = []
        self.padding = int((self.patchsize - 1) / 2)
        self.unfolder = nn.Unfold(kernel_size=self.patchsize, stride=self.stride,
                                  padding=self.padding, dilation=1)

    def score(self, x):
        was_numpy = False
        if isinstance(x, np.ndarray):
            was_numpy = True
            x = torch.from_numpy(x)
        while x.ndim > 2:
            x = torch.max(x, dim=-1).values
        if x.ndim == 2:
            if self.top_k > 1:
                x = torch.topk(x, self.top_k, dim=1).values.mean(1)
            else:
                x = torch.max(x, dim=1).values
        if was_numpy:
            return x.numpy()
        return x

    def forward(self, x):
        output1 = self.layer2(x)
        output2 = self.layer3(x)
        # Unfold both feature maps into 3x3 patches.
        unfolded_features1 = self.unfolder(output1)
        patch_shapes1 = []
        for s in output1.shape[-2:]:
            n_patches = (s + 2 * self.padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1
            patch_shapes1.append(int(n_patches))
        unfolded_features1 = unfolded_features1.reshape(
            *output1.shape[:2], self.patchsize, self.patchsize, -1
        )
        unfolded_features1 = unfolded_features1.permute(0, 4, 1, 2, 3)
        unfolded_features2 = self.unfolder(output2)
        patch_shapes2 = []
        for s in output2.shape[-2:]:
            n_patches = (s + 2 * self.padding - 1 * (self.patchsize - 1) - 1) / self.stride + 1
            patch_shapes2.append(int(n_patches))
        unfolded_features2 = unfolded_features2.reshape(
            *output2.shape[:2], self.patchsize, self.patchsize, -1
        )
        unfolded_features2 = unfolded_features2.permute(0, 4, 1, 2, 3)
        # Upsample the layer3 patch grid to the layer2 patch grid so the two can be fused.
        ref_num_patches = patch_shapes1
        _features = unfolded_features2
        patch_dims = patch_shapes2
        _features = _features.reshape(
            _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:]
        )
        _features = _features.permute(0, -3, -2, -1, 1, 2)
        perm_base_shape = _features.shape
        _features = _features.reshape(-1, *_features.shape[-2:])
        _features = F.interpolate(
            _features.unsqueeze(1),
            size=(ref_num_patches[0], ref_num_patches[1]),
            mode="bilinear",
            align_corners=False,
        )
        _features = _features.squeeze(1)
        _features = _features.reshape(
            *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1]
        )
        _features = _features.permute(0, -2, -1, 1, 2, 3)
        _features = _features.reshape(len(_features), -1, *_features.shape[-3:])
        unfolded_features2 = _features
        unfolded_features1 = unfolded_features1.reshape(-1, *unfolded_features1.shape[-3:])
        unfolded_features2 = unfolded_features2.reshape(-1, *unfolded_features2.shape[-3:])
        # Pool each layer's patch features to output_dim and fuse them.
        feature1 = self.preprocessing_modules[0](unfolded_features1)
        feature2 = self.preprocessing_modules[1](unfolded_features2)
        features = torch.stack([feature1, feature2], dim=1)
        features = features.reshape(len(features), 1, -1)
        features = F.adaptive_avg_pool1d(features, self.output_dim)
        features = features.reshape(len(features), -1)
        patch_shapes = [patch_shapes1, patch_shapes2]
        features = self.pre_projection(features)  # [B * 64 * 64, 1536]; e.g. torch.Size([4096, 1536]) for B=1
        self.features = features  # kept for the training loop
        patch_scores = image_scores = -self.discriminator(features)
        patch_scores = patch_scores.cpu().detach().numpy()
        image_scores = image_scores.cpu().detach().numpy()
        image_scores = image_scores.reshape(self.batchsize, -1, *image_scores.shape[1:])
        image_scores = image_scores.reshape(*image_scores.shape[:2], -1)
        image_scores = self.score(image_scores)
        image_scores = image_scores.reshape(self.batchsize, -1, *image_scores.shape[1:])
        scales = patch_shapes[0]
        patch_scores = patch_scores.reshape(self.batchsize, scales[0], scales[1])
        features = features.reshape(self.batchsize, scales[0], scales[1], -1)
        with torch.no_grad():
            if isinstance(patch_scores, np.ndarray):
                patch_scores = torch.from_numpy(patch_scores)
            _scores = patch_scores.to(self.device)
            _scores = _scores.unsqueeze(1)
            _scores = F.interpolate(
                _scores, size=self.target_size, mode="bilinear", align_corners=False
            )
            _scores = _scores.squeeze(1)
            patch_scores = _scores.cpu().numpy()
            if isinstance(features, np.ndarray):
                features = torch.from_numpy(features)
            features = features.to(self.device).permute(0, 3, 1, 2)
            # Interpolate in sub-batches if the full tensor would exceed 2**31 elements.
            if self.target_size[0] * self.target_size[1] * features.shape[0] * features.shape[1] >= 2 ** 31:
                subbatch_size = int((2 ** 31 - 1) / (self.target_size[0] * self.target_size[1] * features.shape[1]))
                interpolated_features = []
                for i_subbatch in range(int(features.shape[0] / subbatch_size + 1)):
                    subfeatures = features[i_subbatch * subbatch_size:(i_subbatch + 1) * subbatch_size]
                    subfeatures = subfeatures.unsqueeze(0) if len(subfeatures.shape) == 3 else subfeatures  # was "unsuqeeze"
                    subfeatures = F.interpolate(
                        subfeatures, size=self.target_size, mode="bilinear", align_corners=False
                    )
                    interpolated_features.append(subfeatures)
                features = torch.cat(interpolated_features, 0)
            else:
                features = F.interpolate(
                    features, size=self.target_size, mode="bilinear", align_corners=False
                )
        # features = features.cpu().detach().numpy()
        masks = [ndimage.gaussian_filter(patch_score, sigma=self.smoothing)
                 for patch_score in patch_scores]
        masks = torch.from_numpy(np.stack(masks))  # stacking first avoids the slow list-of-ndarray path
        return masks  # , self.patch_shapes


# net = unsupervisedNet(1, False)
# # net.cuda()
# x = torch.randn((1, 3, 512, 512))  # .cuda()
# y = net(x)
# print(y.shape)
# pth2onnx(net, x, 'test.onnx')
# trace_script_module = torch.jit.trace(net, x)
# trace_script_module.save('net1.torchscript')
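
The patch-count formula that appears twice in forward is just the standard output-size formula for nn.Unfold. A standalone sanity check (the 64x64 map below mirrors the layer2 feature map of WideResNet-50 on a 512x512 input):

import torch

patchsize, stride = 3, 1
padding = (patchsize - 1) // 2
x = torch.randn(1, 512, 64, 64)  # layer2-shaped feature map for a 512x512 input

unfolder = torch.nn.Unfold(kernel_size=patchsize, stride=stride, padding=padding)
unfolded = unfolder(x)  # shape [1, 512 * 3 * 3, n_patches_total]

n_patches = [(s + 2 * padding - (patchsize - 1) - 1) // stride + 1 for s in x.shape[-2:]]
print(n_patches)  # [64, 64]: stride 1 with "same" padding keeps the grid size
assert unfolded.shape[-1] == n_patches[0] * n_patches[1]  # 4096 patches per image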

Training code

import torch
from torch.utils.data import DataLoader

from mvdataset import Dataset
from net import unsupervisedNet

if __name__ == '__main__':
    lr = 1e-3
    mix_noise = 1
    noise_std = 0.015
    dsc_margin = .5
    x = torch.randn(2, 3, 512, 512)
    label = torch.zeros(x.shape[0], 512, 512)
    model = unsupervisedNet(x.shape[0], True)
    dataset = Dataset(source=r"D:\dataset\mvtec_anomaly_detection\bottle\train\good")
    # drop_last avoids a final batch smaller than the batchsize baked into the model,
    # which would break the reshape calls in forward().
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4, drop_last=True)
    model.discriminator.train()
    model.pre_projection.train()
    dsc_opt = torch.optim.Adam(model.discriminator.parameters(), lr=0.0002, weight_decay=1e-5)
    proj_opt = torch.optim.AdamW(model.pre_projection.parameters(), lr * .1)
    backbone_opt = torch.optim.AdamW(model.backbone.parameters(), lr)
    # x = x.cuda()
    # label = label.cuda()
    # model.cuda()
    for epoch in range(100):
        for data, label in dataloader:
            dsc_opt.zero_grad()
            proj_opt.zero_grad()
            backbone_opt.zero_grad()  # was missing: without it, backbone grads accumulate across steps
            mask = model(data)
            true_feats = model.features
            # Build fake features by adding Gaussian noise; with mix_noise > 1, each
            # feature picks one of several noise scales (noise_std * 1.1 ** k).
            noise_idxs = torch.randint(0, mix_noise, torch.Size([true_feats.shape[0]]))
            noise_one_hot = torch.nn.functional.one_hot(noise_idxs, num_classes=mix_noise)  # .to(self.device)
            noise = torch.stack([
                torch.normal(0, noise_std * 1.1 ** k, true_feats.shape)
                for k in range(mix_noise)], dim=1)  # .to(self.device)
            noise = (noise * noise_one_hot.unsqueeze(-1)).sum(1)
            fake_feats = true_feats + noise
            scores = model.discriminator(torch.cat([true_feats, fake_feats]))
            true_scores = scores[:len(true_feats)]
            fake_scores = scores[len(fake_feats):]
            th = dsc_margin
            p_true = (true_scores.detach() >= th).sum() / len(true_scores)
            p_fake = (fake_scores.detach() < -th).sum() / len(fake_scores)
            # Truncated L1 (hinge) loss: push true scores above +th and fake scores below -th.
            true_loss = torch.clip(-true_scores + th, min=0)
            fake_loss = torch.clip(fake_scores + th, min=0)
            loss = true_loss.mean() + fake_loss.mean()
            print(loss.item())
            loss.backward()
            backbone_opt.step()
            proj_opt.step()
            dsc_opt.step()
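
The loop above never saves weights and has no inference path. Below is a minimal sketch of both, appended after training finishes; the checkpoint keys and file name are illustrative (not from the original post), and note that batchsize is baked into the model's constructor, so the inference batch must match it:

    # Save the trained sub-modules (the backbone is fine-tuned as well in this setup).
    torch.save({
        'backbone': model.backbone.state_dict(),
        'pre_projection': model.pre_projection.state_dict(),
        'discriminator': model.discriminator.state_dict(),
    }, 'simplenet.pth')

    # Inference: forward() already returns a smoothed per-pixel anomaly map.
    model.pre_projection.eval()
    model.discriminator.eval()
    with torch.no_grad():
        image, _ = dataset[0]
        batch = image.unsqueeze(0).repeat(2, 1, 1, 1)  # batchsize=2 is fixed in the constructor
        heatmap = model(batch)
        print(heatmap.shape)  # torch.Size([2, 512, 512])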
