赞
踩
下图为VGG网络结构图,最常用的是表中的D结构,即16层结构(13层卷积层+3层全连接层)。卷积核大小为3×3,stride为1,padding为1;maxpool的池化核大小为2,stride为2(池化只改变特征图的高和宽,不改变特征图的深度,即通道数)。
VGG网络结构可以看作两部分:特征提取网络(全连接层之前的卷积层和池化层)+分类网络(3层全连接层)
VGG模型一共分为两部分,特征提取部分和分类网络部分,我们分别进行搭建
1、定义配置字典,包含四种网络结构(vgg11、vgg13、vgg16、vgg19)的配置
# Layer configurations of the four VGG variants, consumed by make_features.
# Each int is a 3x3 convolution with that many kernels (output channels);
# the string "M" is a 2x2 max-pooling layer. Every variant has five
# convolutional stages, each closed by exactly one "M".
cfgs = {
    "vgg11": [
        64, "M",
        128, "M",
        256, 256, "M",
        512, 512, "M",
        512, 512, "M",
    ],
    "vgg13": [
        64, 64, "M",
        128, 128, "M",
        256, 256, "M",
        512, 512, "M",
        512, 512, "M",
    ],
    "vgg16": [
        64, 64, "M",
        128, 128, "M",
        256, 256, 256, "M",
        512, 512, 512, "M",
        512, 512, 512, "M",
    ],
    "vgg19": [
        64, 64, "M",
        128, 128, "M",
        256, 256, 256, 256, "M",
        512, 512, 512, 512, "M",
        512, 512, 512, 512, "M",
    ],
}
2、定义一个函数,生成vgg网络第一部分:特征提取网络
def make_features(cfg: list, in_channels: int = 3) -> nn.Sequential:
    """Build the VGG feature-extraction (convolutional) network from a config.

    Args:
        cfg: Layer specification, e.g. one of the ``cfgs`` values. Each int
            ``v`` adds a 3x3 ``nn.Conv2d`` with ``v`` output channels
            (padding 1, so H and W are preserved) followed by an in-place
            ReLU; the string ``"M"`` adds a 2x2 max-pooling layer with
            stride 2.
        in_channels: Channel count of the input images (3 for RGB, which is
            the default and the previous hard-coded value).

    Returns:
        An ``nn.Sequential`` holding the layers in ``cfg`` order.
    """
    layers = []
    for v in cfg:
        if v == "M":
            # Pooling halves H and W but leaves the channel count unchanged.
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(inplace=True)]
            # The next convolution consumes this layer's output channels.
            in_channels = v
    # Unpack the list so each layer becomes a positional child of Sequential.
    return nn.Sequential(*layers)
3、定义VGG类
class VGG(nn.Module):
    """VGG network: feature extractor + adaptive pooling + 3-layer MLP classifier.

    Args:
        features: Feature-extraction module, e.g. the result of
            ``make_features``. It must output 512 channels, because the
            classifier input size is fixed at ``512 * 7 * 7``.
        num_classes: Number of output logits.
        init_weights: If True, re-initialise every ``nn.Conv2d`` and
            ``nn.Linear`` with Xavier-uniform weights and zero biases.
        hidden_features: Width of the two hidden fully connected layers
            (4096 in the original VGG, the default).
        dropout: Drop probability of the two ``nn.Dropout`` layers
            (0.5 by default).
    """

    def __init__(self, features, num_classes=1000, init_weights=False,
                 hidden_features=4096, dropout=0.5):
        super().__init__()
        self.features = features
        # Bring the feature map to 7x7 regardless of the input resolution.
        # With the cfgs extractors (five stride-2 pools) a 224x224 input is
        # already 7x7 here, so this is an exact identity in that case. The
        # layer has no parameters, so previously saved state_dicts still load.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),  # randomly zero activations to reduce overfitting
            nn.Linear(512 * 7 * 7, hidden_features),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(True),
            nn.Linear(hidden_features, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for the batch ``x``."""
        # e.g. N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x H' x W' (7 x 7 for a 224x224 input)
        x = self.avgpool(x)
        # N x 512 x 7 x 7 -> N x 512*7*7; dim 0 is the batch and is kept
        x = torch.flatten(x, start_dim=1)
        return self.classifier(x)

    def _initialize_weights(self):
        """Xavier-uniform init for conv/linear weights, zeros for their biases."""
        # self.modules() walks every submodule recursively, i.e. every layer.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
4、实例化VGG模型
# 实例化vgg
def vgg(model_name="vgg16", **kwargs):
try:
cfg = cfgs[model_name]
except:
print("Warning: model number {} not in cfgs dict!".format(model_name))
exit(-1)
model = VGG(make_features(cfg), **kwargs) # **kwargs可变长度的字典变量
return model
# Example: build a VGG13 with the default VGG arguments.
# NOTE(review): as a module-level statement this constructs the full model
# whenever the file is imported; if it lives in the module that the training
# and prediction scripts import `vgg` from, consider moving it under
# `if __name__ == "__main__":`.
vgg_model = vgg("vgg13")
import os
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

from model import vgg


def _train_one_epoch(net, loader, loss_function, optimizer, device, epoch, epochs):
    """Run one training pass over ``loader`` and return the mean batch loss."""
    net.train()
    running_loss = 0.0
    train_bar = tqdm(loader)
    for images, labels in train_bar:
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        batch_loss = loss.item()
        running_loss += batch_loss
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, batch_loss)
    return running_loss / len(loader)


def _evaluate(net, loader, device, num_samples):
    """Return the top-1 accuracy of ``net`` over ``loader`` (``num_samples`` images)."""
    net.eval()
    correct = 0  # number of correctly classified images in this pass
    with torch.no_grad():
        for val_images, val_labels in tqdm(loader):
            outputs = net(val_images.to(device))
            predict_y = torch.argmax(outputs, dim=1)
            correct += torch.eq(predict_y, val_labels.to(device)).sum().item()
    return correct / num_samples


def main():
    """Train VGG16 on the flower dataset and keep the best checkpoint.

    Expects ``<cwd>/../../data_set/flower_data/{train,val}`` in torchvision
    ``ImageFolder`` layout. Writes the index -> class-name mapping to
    ``class_indices.json`` and the weights with the best validation accuracy
    to ``./vgg16Net.pth``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Training: random crop + horizontal flip for augmentation; validation:
    # deterministic resize. ToTensor gives [0, 1], Normalize(0.5, 0.5) maps
    # each channel to [-1, 1].
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    # An explicit raise (unlike assert) is not stripped under `python -O`.
    if not os.path.exists(image_path):
        raise FileNotFoundError("{} path does not exist.".format(image_path))

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # e.g. {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
    flower_list = train_dataset.class_to_idx
    # Invert to index -> class name; json.dumps turns the int keys into strings.
    cla_dict = {val: key for key, val in flower_list.items()}
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json.dumps(cla_dict, indent=4))

    batch_size = 32
    # os.cpu_count() may return None, which would make min() raise TypeError.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    model_name = "vgg16"
    # One output per class found in the training folder (5 for flower_data).
    net = vgg(model_name=model_name, num_classes=len(flower_list), init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 30
    best_acc = 0.0
    save_path = './{}Net.pth'.format(model_name)
    for epoch in range(epochs):
        train_loss = _train_one_epoch(net, train_loader, loss_function, optimizer,
                                      device, epoch, epochs)
        val_accurate = _evaluate(net, validate_loader, device, val_num)
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, train_loss, val_accurate))

        # Keep only the checkpoint with the best validation accuracy so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import vgg


def _require_file(path):
    """Raise FileNotFoundError if ``path`` does not exist (unlike assert, not stripped by -O)."""
    if not os.path.exists(path):
        raise FileNotFoundError("file: '{}' dose not exist.".format(path))


def main():
    """Classify one image with the trained VGG16 and show it with its prediction.

    Reads ``../tulip.jpg``, ``./class_indices.json`` and ``./vgg16Net.pth``
    (paths relative to the working directory).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Resize to 224x224 and normalise each channel to [-1, 1].
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    _require_file(img_path)
    # convert("RGB"): grayscale / RGBA / palette images would otherwise not
    # match the 3-channel Normalize above; the with-block closes the file.
    with Image.open(img_path) as pil_img:
        img = pil_img.convert("RGB")
    plt.imshow(img)
    img = data_transform(img)          # [C, H, W]
    img = torch.unsqueeze(img, dim=0)  # add batch dimension -> [N, C, H, W]

    # read class_indict (string index -> class name)
    json_path = './class_indices.json'
    _require_file(json_path)
    with open(json_path, "r") as json_file:
        class_indict = json.load(json_file)

    # create model; num_classes must match the trained checkpoint
    model = vgg(model_name="vgg16", num_classes=5).to(device)

    # load model weights
    weights_path = "./vgg16Net.pth"
    _require_file(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = int(torch.argmax(predict))

    print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                              predict[predict_cla].item())
    plt.title(print_res)
    print(print_res)
    plt.show()


if __name__ == '__main__':
    main()
参考视频:https://www.bilibili.com/video/BV1i7411T7ZN?spm_id_from=333.999.0.0
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。