Garbage Classification with a ResNet-50 Network


I. Introduction

The classic ResNet was proposed in 2015 by Kaiming He's team in the paper "Deep Residual Learning for Image Recognition".

ResNet addresses the "degradation" problem of deep neural networks: as a plain network is made deeper, its fitting performance actually gets worse, and this problem is not caused by overfitting.

ResNet is also called a residual network; the network is built from residual blocks.

A residual block consists of several stacked convolutional layers plus a shortcut connection; the two outputs are summed and then passed through a ReLU activation to give the block's output, i.e. y = ReLU(F(x) + x), where F(x) is the output of the stacked convolutions and x arrives through the shortcut. Several residual blocks can be chained together to build a deeper network.

Residual blocks come in two designs (the figure from the original post is omitted here).

 

The first design (the basic block) is used for shallower networks such as ResNet-18/34; the second, known as the "bottleneck" building block, is used for deeper networks such as ResNet-50/101/152, and its purpose is to reduce the number of parameters.

The paper presents ResNets at five different depths: 18, 34, 50, 101, and 152 layers.

In ResNet-18/34 each residual block uses two 3×3 convolutions; in ResNet-50/101/152 each block uses 1×1, 3×3, and 1×1 convolutions in sequence.
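For comparison only, here is a minimal sketch (not part of the original post) of the first, non-bottleneck design as it is usually written in PyTorch; the attribute name extention follows the spelling used in the code later in this post, and the ResNet-50 implementation below uses only the bottleneck variant.

import torch.nn as nn

class BasicBlock(nn.Module):
    """Two 3x3 convolutions plus a shortcut, as used in ResNet-18/34."""
    extention = 1  # no channel expansion in the basic block

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()
        self.downsample = downsample  # projects the shortcut when shapes differ

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            residual = self.downsample(x)
        return self.relu(out + residual)  # add the shortcut, then apply ReLU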

The paper also illustrates the complete 34-layer ResNet architecture.

II. Implementing garbage classification

1. Preparing the dataset:
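The original post shows the prepared dataset only as a screenshot. Judging from the paths used in the loading code below, the expected folder layout is roughly the following (four class folders under each split; exact file names and counts are not given in the post):

data/
    train/
        Hazardous waste/
        Kitchen waste/
        Other garbage/
        Recyclable garbage/
    val/
        Hazardous waste/
        Kitchen waste/
        Other garbage/
        Recyclable garbage/
    test/
        Hazardous waste/
        ...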

2. Loading the dataset:

class garbage_datasets(Dataset):
    """Loads every image into memory as a normalized 224x224 tensor together with its label."""
    # folder name -> integer label (0, 1, 2, 3 in this order)
    CLASSES = ['Hazardous waste', 'Kitchen waste', 'Other garbage', 'Recyclable garbage']

    def __init__(self, filepath):
        self.images = []
        self.labels = []
        self.transform = transform  # the module-level transform (see the full source below)
        for label, classname in enumerate(self.CLASSES):
            for filename in tqdm(os.listdir(filepath + classname)):
                # convert('RGB') guards against grayscale/RGBA files (Normalize expects 3 channels)
                image = Image.open(filepath + classname + '/' + filename).convert('RGB')
                image = image.resize((224, 224))
                self.images.append(self.transform(image))
                self.labels.append(label)
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, index):
        return self.images[index], self.labels[index]

    def __len__(self):
        return len(self.images)

train_data = garbage_datasets('data/train/')
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data = garbage_datasets('data/val/')
val_loader = DataLoader(val_data, batch_size=batch_size)
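As a side note (not used in this post), torchvision's ImageFolder can build an equivalent dataset directly from this folder structure and decodes images lazily instead of holding them all in memory; its class indices follow the sorted folder names, so the alphabetical order Hazardous/Kitchen/Other/Recyclable happens to reproduce the same 0-3 labels as above. A sketch:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

folder_transform = transforms.Compose([
    transforms.Resize((224, 224)),          # replaces the manual PIL resize above
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
train_data = datasets.ImageFolder('data/train/', transform=folder_transform)
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)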

3. Building the network:

class Bottleneck(nn.Module):
    # channel expansion factor: each block outputs planes * 4 channels
    extention = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # 1x1 conv reduces the channel count (and carries the stride in downsampling stages)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 conv
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 conv expands the channels back
        self.conv3 = nn.Conv2d(planes, planes * self.extention, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.extention)
        self.relu = nn.ReLU()
        # projection shortcut, used when the main path changes shape
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        # no ReLU here: the activation is applied after the shortcut addition
        if self.downsample is not None:
            residual = self.downsample(x)
        out = out + residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, num_class):
        self.inplane = 64
        super(ResNet, self).__init__()
        self.block = block
        self.layers = layers
        # stem: 7x7 conv + 3x3 max pool, 3x224x224 -> 64x56x56
        self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplane)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # four stages; [3, 4, 6, 3] bottleneck blocks gives ResNet-50
        self.stage1 = self.make_layer(self.block, 64, layers[0], stride=1)
        self.stage2 = self.make_layer(self.block, 128, layers[1], stride=2)
        self.stage3 = self.make_layer(self.block, 256, layers[2], stride=2)
        self.stage4 = self.make_layer(self.block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)  # assumes 224x224 inputs (the feature map is 7x7 here)
        self.fc = nn.Linear(512 * block.extention, num_class)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

    def make_layer(self, block, plane, block_num, stride=1):
        block_list = []
        downsample = None
        # the first block of a stage may change resolution/channels, so its shortcut needs a 1x1 projection
        if stride != 1 or self.inplane != plane * block.extention:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplane, plane * block.extention, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(plane * block.extention)
            )
        block_list.append(block(self.inplane, plane, stride=stride, downsample=downsample))
        self.inplane = plane * block.extention
        # the remaining blocks of the stage keep the shape and use identity shortcuts
        for i in range(1, block_num):
            block_list.append(block(self.inplane, plane, stride=1))
        return nn.Sequential(*block_list)

model = ResNet(Bottleneck, [3, 4, 6, 3], 4)  # ResNet-50 with 4 output classes
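A quick sanity check (not in the original post) is to push a dummy batch through the freshly built model and confirm that it produces one score per class:

import torch

dummy = torch.randn(1, 3, 224, 224)   # one fake 224x224 RGB image
print(model(dummy).shape)             # expected: torch.Size([1, 4])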

4. Training the model:

def train(epoch):
    model.train()
    print("epoch:", epoch + 1)
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss = running_loss + loss.item()
    # average loss per batch (batch_idx is zero-based, hence the +1)
    print('train loss: %.3f' % (running_loss / (batch_idx + 1)))
    torch.save(model.state_dict(), './model1.pth')  # checkpoint after every epoch
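train() relies on the globals model, device, criterion, and optimizer. In the full source at the end of the post they are set up roughly like this (Adam with a learning rate of 0.001 and cross-entropy loss):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet(Bottleneck, [3, 4, 6, 3], 4).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)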

5. Validating the model:

def val():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in val_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('accuracy on validation set: %d %%' % (100 * correct / total))
    return correct / total

6. Testing the model:

def test(imgpath):
    font = {'color': 'red',
            'size': 20,
            'family': 'Times New Roman',
            'style': 'italic'}
    classes = ['Hazardous waste', 'Kitchen waste', 'Other garbage', 'Recyclable garbage']
    o_img = Image.open(imgpath).convert('RGB')
    o_img1 = o_img.resize((224, 224))
    img = transform(o_img1)
    img = img.unsqueeze(0)   # add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
    img = img.cuda()
    print(img.shape)
    model = ResNet(Bottleneck, [3, 4, 6, 3], 4)
    model.load_state_dict(torch.load("model1.pth"))  # the checkpoint written by train()
    model = model.cuda()
    model.eval()             # inference mode for the BatchNorm layers
    with torch.no_grad():
        output = model(img)
    _, predict = torch.max(output, dim=1)
    name = classes[predict.item()]  # map the predicted index back to the class name
    print(name)
    plt.imshow(o_img)
    plt.text(0, -6.0, name, fontdict=font)
    plt.show()
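In the test script at the end of the post the function is called on a single image from the test split, for example:

test('data/test/Hazardous waste/2.jpg')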

Full source code:

import torch.nn as nn
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision
import torch.nn.functional as F
import torch.optim as optim
import os
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

batch_size = 8
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
torch.cuda.empty_cache()

class garbage_datasets(Dataset):
    """Loads every image into memory as a normalized 224x224 tensor together with its label."""
    # folder name -> integer label (0, 1, 2, 3 in this order)
    CLASSES = ['Hazardous waste', 'Kitchen waste', 'Other garbage', 'Recyclable garbage']

    def __init__(self, filepath):
        self.images = []
        self.labels = []
        self.transform = transform
        for label, classname in enumerate(self.CLASSES):
            for filename in tqdm(os.listdir(filepath + classname)):
                # convert('RGB') guards against grayscale/RGBA files (Normalize expects 3 channels)
                image = Image.open(filepath + classname + '/' + filename).convert('RGB')
                image = image.resize((224, 224))
                self.images.append(self.transform(image))
                self.labels.append(label)
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, index):
        return self.images[index], self.labels[index]

    def __len__(self):
        return len(self.images)

train_data = garbage_datasets('data/train/')
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data = garbage_datasets('data/val/')
val_loader = DataLoader(val_data, batch_size=batch_size)

class Bottleneck(nn.Module):
    # channel expansion factor: each block outputs planes * 4 channels
    extention = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # 1x1 conv reduces the channel count (and carries the stride in downsampling stages)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 conv
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 conv expands the channels back
        self.conv3 = nn.Conv2d(planes, planes * self.extention, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.extention)
        self.relu = nn.ReLU()
        # projection shortcut, used when the main path changes shape
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        # no ReLU here: the activation is applied after the shortcut addition
        if self.downsample is not None:
            residual = self.downsample(x)
        out = out + residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, num_class):
        self.inplane = 64
        super(ResNet, self).__init__()
        self.block = block
        self.layers = layers
        # stem: 7x7 conv + 3x3 max pool, 3x224x224 -> 64x56x56
        self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplane)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # four stages; [3, 4, 6, 3] bottleneck blocks gives ResNet-50
        self.stage1 = self.make_layer(self.block, 64, layers[0], stride=1)
        self.stage2 = self.make_layer(self.block, 128, layers[1], stride=2)
        self.stage3 = self.make_layer(self.block, 256, layers[2], stride=2)
        self.stage4 = self.make_layer(self.block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)  # assumes 224x224 inputs (the feature map is 7x7 here)
        self.fc = nn.Linear(512 * block.extention, num_class)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

    def make_layer(self, block, plane, block_num, stride=1):
        block_list = []
        downsample = None
        # the first block of a stage may change resolution/channels, so its shortcut needs a 1x1 projection
        if stride != 1 or self.inplane != plane * block.extention:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplane, plane * block.extention, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(plane * block.extention)
            )
        block_list.append(block(self.inplane, plane, stride=stride, downsample=downsample))
        self.inplane = plane * block.extention
        # the remaining blocks of the stage keep the shape and use identity shortcuts
        for i in range(1, block_num):
            block_list.append(block(self.inplane, plane, stride=1))
        return nn.Sequential(*block_list)

model = ResNet(Bottleneck, [3, 4, 6, 3], 4)  # ResNet-50 with 4 output classes
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# resume from a previously saved checkpoint; comment this line out when training from scratch
model.load_state_dict(torch.load("model1.pth"))
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train(epoch):
    model.train()
    print("epoch:", epoch + 1)
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss = running_loss + loss.item()
    # average loss per batch (batch_idx is zero-based, hence the +1)
    print('train loss: %.3f' % (running_loss / (batch_idx + 1)))
    torch.save(model.state_dict(), './model1.pth')  # checkpoint after every epoch

def val():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in val_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('accuracy on validation set: %d %%' % (100 * correct / total))
    return correct / total

if __name__ == '__main__':
    acc_list = []
    epoch_list = []
    for epoch in range(5):
        train(epoch)
        acc = val()
        acc_list.append(acc)
        epoch_list.append(epoch + 1)
    plt.plot(epoch_list, acc_list)
    plt.ylabel("ACC")
    plt.xlabel("Epoch")
    plt.show()

Test script:

from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

class Bottleneck(nn.Module):
    # channel expansion factor: each block outputs planes * 4 channels
    extention = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.extention, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.extention)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        # no ReLU here: the activation is applied after the shortcut addition
        if self.downsample is not None:
            residual = self.downsample(x)
        out = out + residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, layers, num_class):
        self.inplane = 64
        super(ResNet, self).__init__()
        self.block = block
        self.layers = layers
        self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplane)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.stage1 = self.make_layer(self.block, 64, layers[0], stride=1)
        self.stage2 = self.make_layer(self.block, 128, layers[1], stride=2)
        self.stage3 = self.make_layer(self.block, 256, layers[2], stride=2)
        self.stage4 = self.make_layer(self.block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)
        self.fc = nn.Linear(512 * block.extention, num_class)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.stage4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

    def make_layer(self, block, plane, block_num, stride=1):
        block_list = []
        downsample = None
        if stride != 1 or self.inplane != plane * block.extention:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplane, plane * block.extention, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(plane * block.extention)
            )
        block_list.append(block(self.inplane, plane, stride=stride, downsample=downsample))
        self.inplane = plane * block.extention
        for i in range(1, block_num):
            block_list.append(block(self.inplane, plane, stride=1))
        return nn.Sequential(*block_list)

def test(imgpath):
    font = {'color': 'red',
            'size': 20,
            'family': 'Times New Roman',
            'style': 'italic'}
    classes = ['Hazardous waste', 'Kitchen waste', 'Other garbage', 'Recyclable garbage']
    o_img = Image.open(imgpath).convert('RGB')
    o_img1 = o_img.resize((224, 224))
    img = transform(o_img1)
    img = img.unsqueeze(0)   # add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
    img = img.cuda()
    print(img.shape)
    model = ResNet(Bottleneck, [3, 4, 6, 3], 4)
    model.load_state_dict(torch.load("model1.pth"))  # the checkpoint written during training
    model = model.cuda()
    model.eval()             # inference mode for the BatchNorm layers
    with torch.no_grad():
        output = model(img)
    _, predict = torch.max(output, dim=1)
    name = classes[predict.item()]  # map the predicted index back to the class name
    print(name)
    plt.imshow(o_img)
    plt.text(0, -6.0, name, fontdict=font)
    plt.show()

if __name__ == "__main__":
    test('data/test/Hazardous waste/2.jpg')

The final validation accuracy reaches about 70%.

A few example images with successful test predictions are shown in the original post (omitted here).
