当前位置:   article > 正文

深度学习区分不同种类的图片_google/vit-base-patch16-224-in21k

google/vit-base-patch16-224-in21k

数据集格式

 

之前使用 ResNet 从头开始训练,效果比较差;后来改为在谷歌的预训练模型 google/vit-base-patch16-224-in21k 上进行微调,取得了很好的效果。

训练代码如下:

  # Data preparation for fine-tuning google/vit-base-patch16-224-in21k.
  import numpy as np
  from datasets import load_dataset
  from datasets import load_metric
  from torchvision.transforms import (
      CenterCrop,
      Compose,
      Normalize,
      RandomResizedCrop,
      Resize,
      ToTensor,
  )
  from transformers import AutoFeatureExtractor
  from transformers import DefaultDataCollator

  # Other dataset locations used by the author:
  #   /home/huhao/TensorFlow2.0_ResNet/dataset
  #   /home/huhao/dataset
  scene = load_dataset("/home/huhao/TensorFlow2.0_ResNet/dataset")
  dataset = scene['train']
  # Fixed seed so the 80/20 split is the same on every run.
  scene = dataset.train_test_split(test_size=0.2, seed=42)

  # Class names and the two lookup tables handed to the model config.
  labels = scene["train"].features["label"].names
  label2id = {label: str(i) for i, label in enumerate(labels)}
  id2label = {str(i): label for i, label in enumerate(labels)}

  feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")


  def _target_size(size):
      """Convert ``feature_extractor.size`` into a torchvision size argument.

      ``size`` is a plain int in older transformers releases and a dict with
      either a ``shortest_edge`` key or ``height``/``width`` keys in newer ones.
      """
      if isinstance(size, dict):
          if "shortest_edge" in size:
              return size["shortest_edge"]
          return (size["height"], size["width"])
      return size


  normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
  _size = _target_size(feature_extractor.size)
  # Random crops are augmentation and belong to training only.
  _transforms = Compose([RandomResizedCrop(_size), ToTensor(), normalize])
  # Deterministic pipeline so evaluation metrics are comparable between runs.
  _eval_transforms = Compose([Resize(_size), CenterCrop(_size), ToTensor(), normalize])


  def _apply_pipeline(pipeline, examples):
      """Replace the ``image`` column of a batch by normalised ``pixel_values``."""
      examples["pixel_values"] = [pipeline(img.convert("RGB")) for img in examples["image"]]
      del examples["image"]
      return examples


  def transforms(examples):
      """Batch transform for the training split (random resized crop)."""
      return _apply_pipeline(_transforms, examples)


  def eval_transforms(examples):
      """Batch transform for the evaluation split (resize + centre crop)."""
      return _apply_pipeline(_eval_transforms, examples)


  # Transforms are applied lazily, each time a batch is read.
  scene["train"].set_transform(transforms)
  scene["test"].set_transform(eval_transforms)

  data_collator = DefaultDataCollator()
  27. from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
  def compute_metric(eval_pred):
      """Return the classification accuracy for one evaluation pass.

      The accuracy is computed directly with NumPy, so no metric script has to
      be loaded on every evaluation, and the full logits array is no longer
      dumped to stdout each time.

      Args:
          eval_pred: ``(logits, labels)`` pair; the predicted class of each
              sample is the arg-max over the last axis of ``logits``.

      Returns:
          ``{"accuracy": value}`` with the fraction of samples whose predicted
          class equals the reference label.
      """
      logits, labels = eval_pred
      predictions = np.argmax(logits, axis=-1)
      accuracy = float(np.mean(predictions == np.asarray(labels)))
      return {"accuracy": accuracy}
  # Pretrained ViT backbone with a new classification head sized to the dataset.
  model = AutoModelForImageClassification.from_pretrained(
      "google/vit-base-patch16-224-in21k",
      num_labels=len(labels),
      id2label=id2label,
      label2id=label2id,
  )

  training_args = TrainingArguments(
      output_dir="./results",
      # A real boolean; the string 'True' only worked because it is truthy.
      overwrite_output_dir=True,
      per_device_train_batch_size=16,
      evaluation_strategy="steps",
      num_train_epochs=4,
      eval_steps=100,
      logging_steps=10,
      learning_rate=2e-4,
      # The batch transform reads the raw "image" column, so it must be kept.
      remove_unused_columns=False,
      load_best_model_at_end=False,
      # No intermediate checkpoints are written, so save_steps and
      # save_total_limit had no effect and were dropped.
      save_strategy='no',
  )

  trainer = Trainer(
      model=model,
      args=training_args,
      data_collator=data_collator,
      train_dataset=scene["train"],
      eval_dataset=scene["test"],
      tokenizer=feature_extractor,
      compute_metrics=compute_metric,
  )

  trainer.train()
  # Show the final metrics instead of discarding the returned dict.
  print(trainer.evaluate())
  trainer.save_model('/home/huhao/script/model')

测试代码如下:

  # Evaluation setup: load the published fine-tuned checkpoint and a test split.
  import numpy as np
  from datasets import load_dataset
  from datasets import load_metric
  from torchvision.transforms import (
      CenterCrop,
      Compose,
      Normalize,
      RandomResizedCrop,
      Resize,
      ToTensor,
  )
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification
  from transformers import DefaultDataCollator
  from transformers import TrainingArguments, Trainer

  # The trained model has been uploaded to the Hub and is downloaded from there.
  CHECKPOINT = "HaoHu/vit-base-patch16-224-in21k-classify-4scence"

  # Loaded once; both names are kept because later code may use either.
  feature_extractor = AutoFeatureExtractor.from_pretrained(CHECKPOINT)
  extractor = feature_extractor

  # Path the dataset is loaded from. Other locations used by the author:
  #   /home/huhao/TensorFlow2.0_ResNet/dataset
  #   /home/huhao/dataset
  scene = load_dataset("/home/huhao/script/dataset")
  dataset = scene['train']
  # Fixed seed so the evaluated split is the same on every run.
  # NOTE(review): this split is drawn independently of the one used for
  # training, so it may contain images the checkpoint was trained on - confirm.
  scene = dataset.train_test_split(test_size=0.2, seed=42)

  labels = scene["train"].features["label"].names
  label2id = {label: str(i) for i, label in enumerate(labels)}
  id2label = {str(i): label for i, label in enumerate(labels)}


  def _target_size(size):
      """Convert ``feature_extractor.size`` into a torchvision size argument.

      ``size`` is a plain int in older transformers releases and a dict with
      either a ``shortest_edge`` key or ``height``/``width`` keys in newer ones.
      """
      if isinstance(size, dict):
          if "shortest_edge" in size:
              return size["shortest_edge"]
          return (size["height"], size["width"])
      return size


  normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
  _size = _target_size(feature_extractor.size)
  # Deterministic resize + centre crop: random crops would make the reported
  # score change from run to run.
  _transforms = Compose([Resize(_size), CenterCrop(_size), ToTensor(), normalize])


  def transforms(examples):
      """Replace the ``image`` column of a batch by normalised ``pixel_values``."""
      examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
      del examples["image"]
      return examples


  scene = scene.with_transform(transforms)
  data_collator = DefaultDataCollator()

  # Only evaluation is run here, so the training-only hyperparameters are gone.
  training_args = TrainingArguments(
      output_dir="./results",
      per_device_eval_batch_size=16,
      # The batch transform reads the raw "image" column, so it must be kept.
      remove_unused_columns=False,
  )

  model = AutoModelForImageClassification.from_pretrained(
      CHECKPOINT,
      num_labels=len(labels),
      id2label=id2label,
      label2id=label2id,
  )
  def _macro_f1(predictions, references):
      """Return the unweighted mean of the per-class F1 scores.

      Every class that occurs in ``predictions`` or ``references`` contributes
      one score; a class that is never predicted correctly scores 0.
      """
      predictions = np.asarray(predictions)
      references = np.asarray(references)
      classes = np.union1d(predictions, references)
      if classes.size == 0:
          return 0.0
      scores = []
      for cls in classes:
          predicted = predictions == cls
          actual = references == cls
          true_pos = np.count_nonzero(predicted & actual)
          # 2*TP + FP + FN equals |predicted| + |actual|, which is never zero
          # for a class taken from the union of both arrays.
          denominator = np.count_nonzero(predicted) + np.count_nonzero(actual)
          scores.append(2.0 * true_pos / denominator)
      return float(np.mean(scores))


  def compute_metric(eval_pred):
      """Print the labels and predictions, then return the macro-averaged F1.

      The score is computed with NumPy, so no metric script is loaded on each
      call.

      Args:
          eval_pred: ``(logits, labels)`` pair; the predicted class of each
              sample is the arg-max over the last axis of ``logits``.

      Returns:
          ``{"f1": value}`` with the macro-averaged F1 score.
      """
      logits, labels = eval_pred
      print(len(logits), len(labels))
      predictions = np.argmax(logits, axis=-1)
      print('对测试集进行评估')
      print('labels')
      print(labels)
      print('predictions')
      print(predictions)
      return {"f1": _macro_f1(predictions, labels)}
  # Evaluation-only Trainer: no training split is supplied.
  _trainer_options = {
      "model": model,
      "args": training_args,
      "data_collator": data_collator,
      "eval_dataset": scene["test"],
      "tokenizer": feature_extractor,
      "compute_metrics": compute_metric,
  }
  trainer = Trainer(**_trainer_options)

  # The metric dict returned by evaluate(); the score is read from 'eval_f1'.
  # Sample output recorded by the author (NOTE(review): it shows eval_accuracy,
  # so it appears to come from an earlier accuracy-based run):
  # {'eval_loss': 0.04495017230510712, 'eval_accuracy': 0.9943181818181818, 'eval_runtime': 30.8715, 'eval_samples_per_second': 11.402, 'eval_steps_per_second': 1.425}
  compute_metrics = trainer.evaluate()

  # Heading and score on separate lines.
  print('输出最后的结果eval_f1:', compute_metrics['eval_f1'], sep='\n')
  77. from doctest import Example
  # Classify every image in the validation directory and print its class index.
  import os

  from transformers import AutoFeatureExtractor, AutoModelForImageClassification, ImageClassificationPipeline
  from transformers import pipeline

  extractor = AutoFeatureExtractor.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
  model = AutoModelForImageClassification.from_pretrained("HaoHu/vit-base-patch16-224-in21k-classify-4scence")
  vision_classifier = pipeline(task="image-classification", model=model, feature_extractor=extractor)

  # Numeric index printed for each predicted label.
  result_dict = {'City_road': 0, 'fog': 1, 'rain': 2, 'snow': 3}
  val_path = '/home/huhao/script/val/'

  # Sorted because os.listdir returns entries in arbitrary order.
  all_img = sorted(os.listdir(val_path))
  for img in all_img:
      img_path = os.path.join(val_path, img)
      if not os.path.isfile(img_path):
          # Sub-directories and other non-file entries are not images.
          continue
      score_list = vision_classifier(img_path)
      # Each entry is a dict with 'score' and 'label'; keep the best-scoring one.
      best = max(score_list, key=lambda sample: sample['score'])
      end_label = best['label']
      print(result_dict[end_label])

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/287817
推荐阅读
相关标签
  

闽ICP备14008679号