Data loading:
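The Dataset below wraps a flat folder of images (for example an MVTec AD train/good directory) in a standard torch.utils.data.Dataset: every image is resized, optionally augmented (color jitter, flips, random grayscale, random affine), center-cropped, and normalized with ImageNet statistics.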
import os

import PIL.Image
import torch
from torchvision import transforms

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        source,
        resize=512,
        imagesize=512,
        rotate_degrees=0,
        translate=0,
        brightness_factor=0,
        contrast_factor=0,
        saturation_factor=0,
        gray_p=0,
        h_flip_p=0,
        v_flip_p=0,
        scale=0,
        **kwargs,
    ):
        super().__init__()
        self.source = source
        self.imgpathlist = [
            os.path.join(self.source, filename) for filename in os.listdir(self.source)
        ]
        self.transform_img = transforms.Compose([
            transforms.Resize(resize),
            # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR),
            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),
            transforms.RandomHorizontalFlip(h_flip_p),
            transforms.RandomVerticalFlip(v_flip_p),
            transforms.RandomGrayscale(gray_p),
            transforms.RandomAffine(
                rotate_degrees,
                translate=(translate, translate),
                scale=(1.0 - scale, 1.0 + scale),
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
            transforms.CenterCrop(imagesize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])

        self.transform_mask = transforms.Compose([
            transforms.Resize(resize),
            transforms.CenterCrop(imagesize),
            transforms.ToTensor(),
        ])
        self.imagesize = (3, imagesize, imagesize)

    def __getitem__(self, idx):
        image = PIL.Image.open(self.imgpathlist[idx]).convert("RGB")
        image = self.transform_img(image)
        mask = torch.zeros([1, *image.size()[1:]])
        return image, mask

    def __len__(self):
        return len(self.imgpathlist)

# dataset = Dataset(source=r"D:\dataset\mvtec_anomaly_detection\bottle\train\good")
# image, mask = dataset[0]
# print(image.shape, mask.shape)
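Note that __getitem__ always returns an all-zero mask: an MVTec train/good folder contains only defect-free images, so there is no pixel-level ground truth to load, and transform_mask is defined but unused here. With the defaults, image has shape (3, 512, 512) and mask has shape (1, 512, 512).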
Network:
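The network is a SimpleNet-style anomaly detector (helpers such as PatchMaker, Preprocessing and NetworkFeatureAggregator closely match the PatchCore/SimpleNet reference code): WideResNet-50 layer2 and layer3 feature maps are unfolded into overlapping 3x3 patches, pooled to a common 1536-dimensional embedding, linearly projected, and scored by a small MLP discriminator; patch scores are then upsampled to the 512x512 input resolution and Gaussian-smoothed into an anomaly mask. Note that backbones.load('wideresnet50') refers to a local backbones module that is not shown in this post.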
import torch
import math
from torch import onnx
from torch import nn
import numpy as np
import torch.nn.functional as F
import scipy.ndimage as ndimage
import backbones
import copy


def pth2onnx(model, dummy_input, dynamiconnx):
    torch.set_grad_enabled(False)
    input_names = ["input1"]
    output_names = ["output1"]
    # Export an ONNX file whose batch dimension is dynamic.
    onnx.export(model=model, args=dummy_input, f=dynamiconnx, input_names=input_names,
                output_names=output_names, verbose=False,
                dynamic_axes=dict([(k, {0: 'batch_size'}) for k in input_names] +
                                  [(k, {0: 'batch_size'}) for k in output_names]),
                keep_initializers_as_inputs=True)
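The commented-out lines at the very end of this file show the intended export usage; the same call written out as a minimal sketch (unsupervisedNet is defined further down, and the output filename is illustrative):

    # hypothetical usage; 'model_dynamic.onnx' is an illustrative filename
    net = unsupervisedNet(1, False)
    pth2onnx(net, torch.randn(1, 3, 512, 512), 'model_dynamic.onnx')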

class Convolution(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(Convolution, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.active = nn.Mish(True)

    def forward(self, x):
        return self.active(self.bn(self.conv(x)))


class PatchMaker(nn.Module):
    def __init__(self, patchsize, top_k=0, stride=None):
        super(PatchMaker, self).__init__()
        self.patchsize = patchsize
        self.stride = stride
        self.top_k = top_k

    def forward(self, features):
        """Convert a tensor into a tensor of respective patches.

        Args:
            features: [torch.Tensor, bs x c x w x h]
        Returns:
            patches: [torch.Tensor, bs * w//stride * h//stride, c, patchsize, patchsize]
            plus the patch counts per spatial dimension
        """
        return_spatial_info = True
        padding = int((self.patchsize - 1) / 2)  # 1 for patchsize=3
        unfolder = torch.nn.Unfold(
            kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1
        )
        unfolded_features = unfolder(features)
        number_of_total_patches = []
        for s in features.shape[-2:]:
            n_patches = (
                s + 2 * padding - 1 * (self.patchsize - 1) - 1
            ) / self.stride + 1
            number_of_total_patches.append(int(n_patches))
        unfolded_features = unfolded_features.reshape(
            *features.shape[:2], self.patchsize, self.patchsize, -1
        )
        unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3)

        if return_spatial_info:
            return unfolded_features, number_of_total_patches
        return unfolded_features
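The patch count follows the standard convolution output-size formula, n_patches = (s + 2*padding - (patchsize - 1) - 1) / stride + 1. With the values used below (patchsize=3, stride=1, padding=1) this reduces to n_patches = s, so unfolding preserves the spatial resolution of the feature map.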

class Resblock(nn.Module):
    def __init__(self, ch):
        super(Resblock, self).__init__()
        self.conv1 = Convolution(ch, ch // 2, 1, 1, 0)
        self.conv2 = Convolution(ch // 2, ch // 2, 3, 1, 1)
        self.conv3 = nn.Conv2d(ch // 2, ch, 1, 1)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        y = self.conv1(x)
        y = self.conv2(y)
        y = self.conv3(y)
        return self.relu(x + y)


class Preprocessing(torch.nn.Module):
    def __init__(self, input_dims, output_dim):
        super(Preprocessing, self).__init__()
        self.input_dims = input_dims
        self.output_dim = output_dim
        self.preprocessing_modules = torch.nn.ModuleList()
        for input_dim in input_dims:
            module = MeanMapper(output_dim)
            self.preprocessing_modules.append(module)

    def forward(self, features):
        _features = []
        for module, feature in zip(self.preprocessing_modules, features):
            _features.append(module(feature))
        return torch.stack(_features, dim=1)


class MeanMapper(torch.nn.Module):
    def __init__(self, preprocessing_dim):
        super(MeanMapper, self).__init__()
        self.preprocessing_dim = preprocessing_dim

    def forward(self, features):
        features = features.reshape(len(features), 1, -1)
        return F.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1)


def init_weight(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_normal_(m.weight)
    elif isinstance(m, torch.nn.Conv2d):
        torch.nn.init.xavier_normal_(m.weight)


class Projection(torch.nn.Module):
    def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0):
        super(Projection, self).__init__()
        if out_planes is None:
            out_planes = in_planes
        self.layers = torch.nn.Sequential()
        _in = None
        _out = None
        for i in range(n_layers):
            _in = in_planes if i == 0 else _out
            _out = out_planes
            self.layers.add_module(f"{i}fc", torch.nn.Linear(_in, _out))
            if i < n_layers - 1:
                # if layer_type > 0:
                #     self.layers.add_module(f"{i}bn", torch.nn.BatchNorm1d(_out))
                if layer_type > 1:
                    self.layers.add_module(f"{i}relu", torch.nn.LeakyReLU(.2))
        self.apply(init_weight)

    def forward(self, x):
        # x = .1 * self.layers(x) + x
        x = self.layers(x)
        return x


class Discriminator(torch.nn.Module):
    def __init__(self, in_planes, n_layers=1, hidden=None):
        super(Discriminator, self).__init__()
        _hidden = in_planes if hidden is None else hidden
        self.body = torch.nn.Sequential()
        for i in range(n_layers - 1):
            _in = in_planes if i == 0 else _hidden
            _hidden = int(_hidden // 1.5) if hidden is None else hidden
            self.body.add_module('block%d' % (i + 1),
                                 torch.nn.Sequential(
                                     torch.nn.Linear(_in, _hidden),
                                     torch.nn.BatchNorm1d(_hidden),
                                     torch.nn.LeakyReLU(0.2)
                                 ))
        self.tail = torch.nn.Linear(_hidden, 1, bias=False)
        self.apply(init_weight)

    def forward(self, x):
        x = self.body(x)
        x = self.tail(x)
        return x


class ForwardHook:
    def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str):
        self.hook_dict = hook_dict
        self.layer_name = layer_name
        self.raise_exception_to_break = copy.deepcopy(
            layer_name == last_layer_to_extract
        )

    def __call__(self, module, input, output):
        self.hook_dict[self.layer_name] = output
        return None

class NetworkFeatureAggregator(torch.nn.Module):
    """Efficient extraction of network features.

    Runs a network only up to the last layer of the list of layers that
    features should be extracted from.

    Args:
        backbone: torchvision.model
        layers_to_extract_from: [list of str]
    """

    def __init__(self, backbone, layers_to_extract_from, device, train_backbone=False):
        super(NetworkFeatureAggregator, self).__init__()
        self.layers_to_extract_from = layers_to_extract_from
        self.backbone = backbone
        self.device = device
        self.train_backbone = train_backbone

        for extract_layer in layers_to_extract_from:
            if extract_layer == 'layer2':
                self.network_layer2 = backbone.__dict__["_modules"][extract_layer]
            if extract_layer == 'layer3':
                self.network_layer3 = backbone.__dict__["_modules"][extract_layer]
        self.to(self.device)

    def forward(self, images, eval=True):
        # Fallback output in case the backbone call fails below.
        y = torch.randn((1, 1000))
        if self.train_backbone and not eval:
            y = self.backbone(images)
        else:
            with torch.no_grad():
                try:
                    y = self.backbone(images)
                except Exception:
                    pass
        return y

class unsupervisedNet(torch.nn.Module):
    def __init__(self, batchsize, train):
        super(unsupervisedNet, self).__init__()
        self.batchsize = batchsize
        self.train_backbone = train
        self.backbone_name = 'wideresnet50'
        self.backbone = backbones.load(self.backbone_name)
        self.device = 'cpu'  # 0
        self.input_shape = [3, 512, 512]

        self.target_size = self.input_shape[-2:]
        self.patchsize = 3
        self.stride = 1
        self.top_k = 0
        self.input_dims = [512, 1024]  # channel counts of layer2 / layer3
        self.output_dim = 1536
        self.smoothing = 4

        # self.feature_dimensions = [512, 1024]
        # self.preprocessing = Preprocessing(self.feature_dimensions, self.output_dim)

        self.preprocessing_modules = torch.nn.ModuleList()
        for input_dim in self.input_dims:
            module = MeanMapper(self.output_dim)
            self.preprocessing_modules.append(module)
        self.pre_projection = Projection(self.output_dim, self.output_dim, 1, 0)

        self.discriminator = Discriminator(self.output_dim, n_layers=2, hidden=1024)

        # Truncated backbones; note that layer3 re-runs the stem, layer1 and layer2.
        self.layer2 = nn.Sequential(
            self.backbone.conv1,
            self.backbone.bn1,
            self.backbone.relu,
            self.backbone.maxpool,
            self.backbone.layer1,
            self.backbone.layer2,
        )

        self.layer3 = nn.Sequential(
            self.backbone.conv1,
            self.backbone.bn1,
            self.backbone.relu,
            self.backbone.maxpool,
            self.backbone.layer1,
            self.backbone.layer2,
            self.backbone.layer3,
        )

        self.unfolded_features = []
        self.patch_shapes = []
        self.padding = int((self.patchsize - 1) / 2)
        self.unfolder = nn.Unfold(kernel_size=self.patchsize, stride=self.stride,
                                  padding=self.padding, dilation=1)

    def score(self, x):
        was_numpy = False
        if isinstance(x, np.ndarray):
            was_numpy = True
            x = torch.from_numpy(x)
        while x.ndim > 2:
            x = torch.max(x, dim=-1).values
        if x.ndim == 2:
            if self.top_k > 1:
                x = torch.topk(x, self.top_k, dim=1).values.mean(1)
            else:
                x = torch.max(x, dim=1).values
        if was_numpy:
            return x.numpy()
        return x

    def forward(self, x):
        output1 = self.layer2(x)  # bs x 512 x h1 x w1
        output2 = self.layer3(x)  # bs x 1024 x h2 x w2

        # Unfold layer2 features into overlapping 3x3 patches.
        unfolded_features1 = self.unfolder(output1)
        patch_shapes1 = []
        for s in output1.shape[-2:]:
            n_patches = (
                s + 2 * self.padding - 1 * (self.patchsize - 1) - 1
            ) / self.stride + 1
            patch_shapes1.append(int(n_patches))
        unfolded_features1 = unfolded_features1.reshape(
            *output1.shape[:2], self.patchsize, self.patchsize, -1
        )
        unfolded_features1 = unfolded_features1.permute(0, 4, 1, 2, 3)

        # Same unfolding for layer3 features.
        unfolded_features2 = self.unfolder(output2)
        patch_shapes2 = []
        for s in output2.shape[-2:]:
            n_patches = (
                s + 2 * self.padding - 1 * (self.patchsize - 1) - 1
            ) / self.stride + 1
            patch_shapes2.append(int(n_patches))
        unfolded_features2 = unfolded_features2.reshape(
            *output2.shape[:2], self.patchsize, self.patchsize, -1
        )
        unfolded_features2 = unfolded_features2.permute(0, 4, 1, 2, 3)

        # Bilinearly resize the layer3 patch grid to the layer2 grid resolution.
        ref_num_patches = patch_shapes1
        _features = unfolded_features2
        patch_dims = patch_shapes2
        _features = _features.reshape(
            _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:]
        )
        _features = _features.permute(0, -3, -2, -1, 1, 2)
        perm_base_shape = _features.shape
        _features = _features.reshape(-1, *_features.shape[-2:])
        _features = F.interpolate(
            _features.unsqueeze(1),
            size=(ref_num_patches[0], ref_num_patches[1]),
            mode="bilinear",
            align_corners=False,
        )
        _features = _features.squeeze(1)
        _features = _features.reshape(
            *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1]
        )
        _features = _features.permute(0, -2, -1, 1, 2, 3)
        _features = _features.reshape(len(_features), -1, *_features.shape[-3:])
        unfolded_features2 = _features

        unfolded_features1 = unfolded_features1.reshape(-1, *unfolded_features1.shape[-3:])
        unfolded_features2 = unfolded_features2.reshape(-1, *unfolded_features2.shape[-3:])

        # Pool each patch to a common dimension and merge the two feature levels.
        model1 = self.preprocessing_modules[0]
        feature1 = model1(unfolded_features1)
        model2 = self.preprocessing_modules[1]
        feature2 = model2(unfolded_features2)
        features = torch.stack([feature1, feature2], dim=1)

        features = features.reshape(len(features), 1, -1)
        features = F.adaptive_avg_pool1d(features, self.output_dim)
        features = features.reshape(len(features), -1)

        patch_shapes = [patch_shapes1, patch_shapes2]

        features = self.pre_projection(features)  # torch.Size([4096, 1536])
        self.features = features  # kept with its graph for the training loop

        patch_scores = image_scores = -self.discriminator(features)

        patch_scores = patch_scores.cpu().detach().numpy()
        image_scores = image_scores.cpu().detach().numpy()

        image_scores = image_scores.reshape(self.batchsize, -1, *image_scores.shape[1:])
        image_scores = image_scores.reshape(*image_scores.shape[:2], -1)

        image_scores = self.score(image_scores)
        image_scores = image_scores.reshape(self.batchsize, -1, *image_scores.shape[1:])
        scales = patch_shapes[0]
        # patch_scores = patch_scores.reshape(1, scales[0], scales[1])
        patch_scores = patch_scores.reshape(self.batchsize, scales[0], scales[1])
        # features = features.reshape(1, scales[0], scales[1], -1)
        features = features.reshape(self.batchsize, scales[0], scales[1], -1)

        with torch.no_grad():
            if isinstance(patch_scores, np.ndarray):
                patch_scores = torch.from_numpy(patch_scores)
            _scores = patch_scores.to(self.device)
            _scores = _scores.unsqueeze(1)
            _scores = F.interpolate(
                _scores, size=self.target_size, mode="bilinear", align_corners=False
            )
            _scores = _scores.squeeze(1)
            patch_scores = _scores.cpu().numpy()

            if isinstance(features, np.ndarray):
                features = torch.from_numpy(features)
            features = features.to(self.device).permute(0, 3, 1, 2)
            # Interpolate in sub-batches when the full tensor would exceed 2**31 elements.
            if self.target_size[0] * self.target_size[1] * features.shape[0] * features.shape[1] >= 2 ** 31:
                subbatch_size = int((2 ** 31 - 1) / (self.target_size[0] * self.target_size[1] * features.shape[1]))
                interpolated_features = []
                for i_subbatch in range(int(features.shape[0] / subbatch_size + 1)):
                    subfeatures = features[i_subbatch * subbatch_size:(i_subbatch + 1) * subbatch_size]
                    subfeatures = subfeatures.unsqueeze(0) if len(subfeatures.shape) == 3 else subfeatures
                    subfeatures = F.interpolate(
                        subfeatures, size=self.target_size, mode="bilinear", align_corners=False
                    )
                    interpolated_features.append(subfeatures)
                features = torch.cat(interpolated_features, 0)
            else:
                features = F.interpolate(
                    features, size=self.target_size, mode="bilinear", align_corners=False
                )
            # features = features.cpu().detach().numpy()
        # Smooth each upsampled score map into the final anomaly mask.
        masks = [ndimage.gaussian_filter(patch_score, sigma=self.smoothing) for patch_score in patch_scores]
        masks = torch.tensor(np.array(masks))
        return masks  # , self.patch_shapes

# net = unsupervisedNet(1, False)
# net.cuda()
# x = torch.randn((1, 3, 512, 512))  # .cuda()
# y = net(x)
# print(y.shape)
# pth2onnx(net, x, 'test.onnx')
# trace_script_module = torch.jit.trace(net, x)
# trace_script_module.save('net1.torchscript')
Training code:
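Training follows the SimpleNet-style recipe: each batch is run through the network to obtain the projected patch features, Gaussian noise is added to those features to create pseudo-anomalous "fake" features, and the discriminator is trained with a margin loss that pushes true scores above +dsc_margin and fake scores below -dsc_margin; the projection head and the backbone are updated through the same loss.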
from net import unsupervisedNet
import torch
from mvdataset import Dataset
from torch.utils.data import DataLoader

if __name__ == '__main__':
    lr = 1e-3
    mix_noise = 1
    noise_std = 0.015
    dsc_margin = .5
    x = torch.randn(2, 3, 512, 512)  # only used to fix the model's batch size
    label = torch.zeros(x.shape[0], 512, 512)
    model = unsupervisedNet(x.shape[0], True)
    dataset = Dataset(source=r"D:\dataset\mvtec_anomaly_detection\bottle\train\good")
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
    model.discriminator.train()
    model.pre_projection.train()
    dsc_opt = torch.optim.Adam(model.discriminator.parameters(), lr=0.0002, weight_decay=1e-5)
    proj_opt = torch.optim.AdamW(model.pre_projection.parameters(), lr * .1)
    backbone_opt = torch.optim.AdamW(model.backbone.parameters(), lr)
    # x = x.cuda()
    # label = label.cuda()
    # model.cuda()
    for epoch in range(100):
        for data, label in dataloader:
            dsc_opt.zero_grad()
            proj_opt.zero_grad()
            backbone_opt.zero_grad()  # clear backbone grads too, so they don't accumulate across steps
            mask = model(data)
            true_feats = model.features
            # Build pseudo-anomalous features by adding (mixed-scale) Gaussian noise.
            noise_idxs = torch.randint(0, mix_noise, torch.Size([true_feats.shape[0]]))
            noise_one_hot = torch.nn.functional.one_hot(noise_idxs, num_classes=mix_noise)  # .to(self.device)
            noise = torch.stack([
                torch.normal(0, noise_std * 1.1 ** (k), true_feats.shape)
                for k in range(mix_noise)], dim=1)  # .to(self.device)
            noise = (noise * noise_one_hot.unsqueeze(-1)).sum(1)
            fake_feats = true_feats + noise
            scores = model.discriminator(torch.cat([true_feats, fake_feats]))
            true_scores = scores[:len(true_feats)]
            fake_scores = scores[len(fake_feats):]
            th = dsc_margin
            # Fractions of correctly separated samples (useful for monitoring).
            p_true = (true_scores.detach() >= th).sum() / len(true_scores)
            p_fake = (fake_scores.detach() < -th).sum() / len(fake_scores)
            # Margin loss: push true scores above +th and fake scores below -th.
            true_loss = torch.clip(-true_scores + th, min=0)
            fake_loss = torch.clip(fake_scores + th, min=0)
            loss = true_loss.mean() + fake_loss.mean()
            print(loss.item())
            loss.backward()
            backbone_opt.step()
            proj_opt.step()
            dsc_opt.step()
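The loop above never writes a checkpoint; a minimal saving sketch, assuming you want to keep the trained heads and the fine-tuned backbone (the filename is illustrative), placed after the epoch loop:

    # hypothetical checkpoint step; 'unsupervised_net.pth' is an illustrative filename
    torch.save({
        'discriminator': model.discriminator.state_dict(),
        'pre_projection': model.pre_projection.state_dict(),
        'backbone': model.backbone.state_dict(),
    }, 'unsupervised_net.pth')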