当前位置:   article > 正文

机器学习笔记 - Traffic-Net训练交通拥堵程度

traffic-net

Traffic-Net简介

        Traffic-Net是交通图像的数据集,其收集目的是确保可以训练机器学习系统以检测交通状况并提供实时监视,分析和警报。

        样本集下载地址:https://github.com/OlafenwaMoses/Traffic-Net/releases/download/1.0/trafficnet_dataset_v1.zipicon-default.png?t=M3K6https://github.com/OlafenwaMoses/Traffic-Net/releases/download/1.0/trafficnet_dataset_v1.zip

        这是Traffic-Net数据集的第一个版本。它包含4个类别的4,400张图像。此版本中包含的类为:

  • 事故
  • 密集交通
  • 稀疏交通

        每个类别有1,100张图像,其中900张图像进行训练200张图像进行测试

代码参考

        需要安装imageai的包,并且在tensorflow2下运行会有一些错误,看log修改对应的错误的imageai的包的文件

        1、from tensorflow.python.keras.utils 需要改成from tensorflow.keras.utils

        2、optimizer = Adam(lr=self.__initial_learning_rate, decay=1e-4) 改成 optimizer = tf.optimizers.Adam(lr=self.__initial_learning_rate, decay=1e-4)

  1. from io import open
  2. import requests
  3. import shutil
  4. from zipfile import ZipFile
  5. from imageai.Prediction.Custom import ModelTraining, CustomImagePrediction
  6. import os
  7. execution_path = os.getcwd()
  8. SOURCE_PATH = "https://github.com/OlafenwaMoses/Traffic-Net/releases/download/1.0/trafficnet_dataset_v1.zip"
  9. FILE_DIR = os.path.join(execution_path, "trafficnet_dataset_v1.zip")
  10. DATASET_DIR = os.path.join(execution_path, "trafficnet_dataset_v1.zip")
  11. def download_traffic_net():
  12. if (os.path.exists(FILE_DIR) == False):
  13. print("Downloading trafficnet_dataset_v1.zip")
  14. data = requests.get(SOURCE_PATH,
  15. stream=True)
  16. with open(FILE_DIR, "wb") as file:
  17. shutil.copyfileobj(data.raw, file)
  18. del data
  19. extract = ZipFile(FILE_DIR)
  20. extract.extractall(execution_path)
  21. extract.close()
  22. def train_traffic_net():
  23. download_traffic_net()
  24. trainer = ModelTraining()
  25. trainer.setModelTypeAsResNet()
  26. trainer.setDataDirectory("trafficnet_dataset_v1")
  27. trainer.trainModel(num_objects=4, num_experiments=200, batch_size=32, save_full_model=True, enhance_data=True)
  28. def run_predict():
  29. predictor = CustomImagePrediction()
  30. predictor.setModelPath(model_path="trafficnet_resnet_model_ex-055_acc-0.913750.h5")
  31. predictor.setJsonPath(model_json="model_class.json")
  32. predictor.loadFullModel(num_objects=4)
  33. predictions, probabilities = predictor.predictImage(image_input="images/1.jpg", result_count=4)
  34. for prediction, probability in zip(predictions, probabilities):
  35. print(prediction, " : ", probability)
  36. #Un-comment the line below to train your model
  37. #train_traffic_net()
  38. #Un-comment the line below to run predictions
  39. run_predict()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/458515
推荐阅读
  

闽ICP备14008679号