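Traffic-Net is a dataset of traffic images. It was collected so that machine learning systems can be trained to detect traffic conditions and provide real-time monitoring, analysis, and alerts.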
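This is the first release of the Traffic-Net dataset. It contains 4,400 images across 4 classes. The classes included in this release are:
- Accident
- Dense Traffic
- Fire
- Sparse Traffic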
Each class has 1,100 images: 900 for training and 200 for testing.
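The code requires the `imageai` package. Running it under TensorFlow 2 raises some errors. Read the error log and patch the corresponding files inside the installed `imageai` package: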
1. Change `from tensorflow.python.keras.utils` to `from tensorflow.keras.utils`.
2. Change `optimizer = Adam(lr=self.__initial_learning_rate, decay=1e-4)` to `optimizer = tf.optimizers.Adam(lr=self.__initial_learning_rate, decay=1e-4)`. The patched file must also contain `import tensorflow as tf`.
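If you prefer not to edit files by hand, the two fixes above can be applied as plain text substitutions. This is a minimal sketch, not an official ImageAI or TensorFlow tool. `patch_imageai_source`, `patch_file` and `patch_package` are hypothetical names. The sketch assumes that prepending `import tensorflow as tf` is safe, which fails only if a patched file has `from __future__` imports. It also walks every `.py` file rather than only those named in the log, so back up the package first.

```python
import os
import re

OLD_UTILS_IMPORT = "from tensorflow.python.keras.utils"
NEW_UTILS_IMPORT = "from tensorflow.keras.utils"
# A bare "Adam(lr=" call, i.e. not already preceded by "tf.optimizers." or another attribute access
BARE_ADAM_CALL = re.compile(r"(?<![\w.])Adam\(lr=")
TF_IMPORT = re.compile(r"^import tensorflow as tf\b", re.MULTILINE)


def patch_imageai_source(text):
    """Apply both TF2 fixes to one file's source text; running it twice changes nothing more."""
    patched = text.replace(OLD_UTILS_IMPORT, NEW_UTILS_IMPORT)
    patched = BARE_ADAM_CALL.sub("tf.optimizers.Adam(lr=", patched)
    if "tf.optimizers.Adam(" in patched and TF_IMPORT.search(patched) is None:
        # tf.optimizers needs the "tf" alias (assumes no __future__ imports in the file)
        patched = "import tensorflow as tf\n" + patched
    return patched


def patch_file(path):
    """Rewrite one file in place; return True if it changed."""
    with open(path, encoding="utf-8") as f:
        original = f.read()
    patched = patch_imageai_source(original)
    if patched != original:
        with open(path, "w", encoding="utf-8") as f:
            f.write(patched)
    return patched != original


def patch_package(package_dir):
    """Patch every .py file under package_dir; return the list of changed paths."""
    changed = []
    for root, _dirs, files in os.walk(package_dir):
        for name in files:
            path = os.path.join(root, name)
            if name.endswith(".py") and patch_file(path):
                changed.append(path)
    return changed
```

One way to locate the installed package is `os.path.dirname(importlib.util.find_spec("imageai").origin)`. `find_spec` returns `None` if `imageai` is not installed.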
```python
import os
import shutil
from zipfile import ZipFile

import requests
from imageai.Prediction.Custom import ModelTraining, CustomImagePrediction

execution_path = os.getcwd()

SOURCE_PATH = "https://github.com/OlafenwaMoses/Traffic-Net/releases/download/1.0/trafficnet_dataset_v1.zip"
FILE_DIR = os.path.join(execution_path, "trafficnet_dataset_v1.zip")
# Folder the archive extracts to (this previously pointed at the .zip file by mistake)
DATASET_DIR = os.path.join(execution_path, "trafficnet_dataset_v1")


def download_traffic_net():
    # Download the archive only if it is not already present
    if not os.path.exists(FILE_DIR):
        print("Downloading trafficnet_dataset_v1.zip")
        with requests.get(SOURCE_PATH, stream=True) as data:
            data.raise_for_status()
            with open(FILE_DIR, "wb") as file:
                shutil.copyfileobj(data.raw, file)

    # Extract only if the dataset folder does not exist yet
    if not os.path.exists(DATASET_DIR):
        with ZipFile(FILE_DIR) as extract:
            extract.extractall(execution_path)


def train_traffic_net():
    download_traffic_net()

    trainer = ModelTraining()
    trainer.setModelTypeAsResNet()
    trainer.setDataDirectory(DATASET_DIR)
    # 4 classes, 200 training epochs, save the full model, enable data augmentation
    trainer.trainModel(num_objects=4, num_experiments=200, batch_size=32,
                       save_full_model=True, enhance_data=True)


def run_predict():
    predictor = CustomImagePrediction()
    # Trained model file and the class-index JSON that maps outputs to class names
    predictor.setModelPath(model_path="trafficnet_resnet_model_ex-055_acc-0.913750.h5")
    predictor.setJsonPath(model_json="model_class.json")
    # The .h5 file holds the full model, so no model type has to be set
    predictor.loadFullModel(num_objects=4)

    predictions, probabilities = predictor.predictImage(image_input="images/1.jpg", result_count=4)
    for prediction, probability in zip(predictions, probabilities):
        print(prediction, " : ", probability)


if __name__ == "__main__":
    # Un-comment the line below to train your model
    # train_traffic_net()

    # Un-comment the line below to run predictions
    run_predict()
```
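`run_predict` hardcodes one checkpoint, `trafficnet_resnet_model_ex-055_acc-0.913750.h5`, whose name encodes the experiment (epoch) number and accuracy. After your own training run, you may want to pick the best checkpoint automatically. This sketch assumes checkpoint names follow that same `ex-NNN_acc-X.h5` pattern. It also assumes they sit in one folder; in ImageAI 2.x, as far as I know, that folder is a `models` subfolder of the data directory. `best_checkpoint` and `best_checkpoint_in` are hypothetical helpers.

```python
import os
import re

# Matches names such as "trafficnet_resnet_model_ex-055_acc-0.913750.h5"; any prefix is ignored
CHECKPOINT_PATTERN = re.compile(r"ex-(\d+)_acc-(\d+(?:\.\d+)?)\.h5$")


def best_checkpoint(filenames):
    """Return the name with the highest accuracy (ties go to the later experiment), or None."""
    best_key, best_name = None, None
    for name in filenames:
        match = CHECKPOINT_PATTERN.search(name)
        if match is None:
            continue  # skip files like model_class.json
        key = (float(match.group(2)), int(match.group(1)))
        if best_key is None or key > best_key:
            best_key, best_name = key, name
    return best_name


def best_checkpoint_in(models_dir):
    """Apply best_checkpoint to the files in a directory and return the full path, or None."""
    name = best_checkpoint(os.listdir(models_dir))
    return None if name is None else os.path.join(models_dir, name)
```

The result can then be passed to `predictor.setModelPath(model_path=...)` in place of the hardcoded file name.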