当前位置:   article > 正文

【keras】读取模型进行测试的方式_keras读取.h5模型进行测试

keras读取.h5模型进行测试

本篇博客仅供自己查资料时使用。 

  1. from keras.preprocessing import image
  2. import numpy as np
  3. from keras.models import load_model
  4. import os
  5. from keras.applications.resnet50 import preprocess_input
  6. from shutil import copyfile
  7. from keras.preprocessing.image import ImageDataGenerator
  8. work_dir = ''
  9. #载入模型
  10. def read_model():
  11. model = load_model(work_dir + '/model_weight.h5')
  12. return model
  13. #单张图片读取,并预测
  14. def read_model_predict(img_path,model):
  15. img = image.load_img(img_path, target_size=(100, 100))
  16. x = image.img_to_array(img)
  17. x = np.expand_dims(x, axis=0)
  18. #print(x)
  19. #归一化
  20. amin, amax = x.min(), x.max() # 求最大最小值
  21. x = (x-amin)/(amax-amin)
  22. preds = model.predict(x)
  23. return preds
  24. #测试数据集读取
  25. def read_test(test_data_dir):
  26. test_datagen = ImageDataGenerator(rescale=1. / 255)
  27. test_generator = test_datagen.flow_from_directory(
  28. test_data_dir,
  29. target_size=(100, 100),
  30. batch_size=64,
  31. class_mode='binary'
  32. )
  33. model = load_model(work_dir + '/model_weight.h5')
  34. score = model.evaluate_generator(test_generator,steps=1)
  35. print("样本准确率%s: %.2f%%" % (model.metrics_names[1], score[1] * 100))
  36. #y = model.evaluate_generator(test_generator, 20, max_q_size=10,workers=1, use_multiprocessing=False)
  37. #name_list = model.predict_generator.filenames()
  38. #print(name_list)
  39. #return y
  40. #迭代读取文件夹下的所有文件,对每一张图片进行预测
  41. def read_file_all(data_dir_path,model):
  42. right = 0
  43. wrong = 0
  44. for f in os.listdir(data_dir_path):
  45. image_path = os.path.join(data_dir_path, f)
  46. #print(f)
  47. if os.path.isfile(image_path):
  48. preds = read_model_predict(image_path,model)
  49. print(preds[0][0])
  50. if preds[0][0] >= 0.5:
  51. #rdst = 'E:/pcb_image_data/data_2500/right/' + f
  52. #copyfile(image_path, rdst)
  53. right += 1
  54. else:
  55. #wdst = 'E:/pcb_image_data/data_2500/wrong/' + f
  56. #copyfile(image_path, wdst)
  57. #print(preds[0][0])
  58. wrong += 1
  59. else:
  60. read_file_all(image_path)
  61. all_num = right + wrong
  62. Tacc = right/all_num
  63. Facc = wrong/all_num
  64. return Tacc,Facc
  65. if __name__ == '__main__':
  66. img_file = '/test'
  67. model = read_model()
  68. tc,fc = read_file_all(img_file,model)
  69. print('True 识别率',tc,'\n','False 识别率',fc)

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/572845
推荐阅读
相关标签
  

闽ICP备14008679号