当前位置:   article > 正文

基于时空序列模型ConvLstm的气象预测_convlstm时序图像预测

convlstm时序图像预测

一、数据获取

  1. 从国家气象科学数据中心下载雷达拼图。可以手动收集,也可以编写爬虫程序自动收集;手动收集时需要逐一更改日期和时间。

    中国气象数据网 - Online Data


二、数据预处理

  1.  得到这些PNG图片后,首先做处理,只关注东南部分,其他的都扔掉(即从图中截取一个固定的长方形,就差不多是下面这个部分。保证截取的图中不包括左下角的南海诸岛以及右下角的基本反射率图例,这些都是干扰。你可以扔掉更多的部分,比如西宁以西都扔掉)。对所有的PNG图片都这样操作。
  2. 截取完之后,如果图片像素仍然过多,可以再截取更小的区域。例如下面的代码截取 img[633:725, 510:665](OpenCV 图像按 (y, x) 即先行后列的顺序索引):
    # Crop every radar mosaic to the region of interest (the south-east part
    # of the map), dropping the South China Sea inset and the reflectivity
    # legend. Relies on names defined elsewhere in the notebook: os, cv
    # (OpenCV), plt (matplotlib.pyplot), wd, pic_root_path, pic_name_list.
    # NOTE(review): only len(pic_name_list) is used -- the files are expected
    # to be named 0.png, 1.png, ... inside pic_root_path; confirm.
    cuted_img_list = []
    save_img_path = os.path.join(wd, 'output')
    # exist_ok replaces the racy "check exists, then create" pattern.
    os.makedirs(save_img_path, exist_ok=True)
    for pic_id in range(len(pic_name_list)):
        pic_name = str(pic_id) + '.png'
        temp = os.path.join(pic_root_path, pic_name)
        img = cv.imread(temp)
        # cv.imread reports a missing/unreadable file by returning None rather
        # than raising; fail here with a clear message instead of at img.shape.
        if img is None:
            raise FileNotFoundError("Cannot read image: " + temp)
        print("This is the "+pic_name+":")
        print(img.shape)
        # Index order is (y, x): rows first, then columns.
        cuted_img = img[633:725, 510:665]
        # Uncomment to also save the cropped colour image:
        #cv.imwrite(os.path.join(save_img_path, pic_name),cuted_img)
        cuted_img_list.append(cuted_img)
        # OpenCV stores pixels as BGR while matplotlib expects RGB; convert a
        # copy for display only so cuted_img itself stays BGR.
        plt.imshow(cv.cvtColor(cuted_img, cv.COLOR_BGR2RGB))
        plt.show()

  3. 识别出基本反射率图例中每一个不同数值对应的颜色常量,将图片中每个像素的颜色映射到基本反射率。如果没有对应的值(比如白色、黑色),统一设置为 255。注意:cv.imread 读入的像素通道顺序是 BGR 而不是 RGB。下面的代码实际只是把一组干扰颜色统一置为白色 (255, 255, 255),随后再转为灰度图。
  4. 转为灰度图。

        

  def viewColor(pic, color):
      """Paint every pixel of *pic* that exactly equals *color* white, in place.

      Args:
          pic: NumPy image array (or a view into one), assumed (H, W, 3) as
              returned by cv.imread, i.e. BGR channel order. It is modified
              in place.
          color: sequence of channel values to match exactly, in the same
              channel order as *pic*.

      Returns:
          None; the input array is mutated.
      """
      # Used e.g. to remove an unwanted line near Nanning from the mosaic.
      # One vectorised boolean mask replaces the original per-pixel Python
      # loop. This function runs once per interfering colour, so the loop
      # cost added up.
      mask = np.all(pic == np.asarray(color), axis=-1)
      pic[mask] = (255, 255, 255)
  def get_usedColor(img):
      """Return the distinct pixel colours of *img* in first-seen order.

      Pixels are scanned row by row (row-major), and each colour is reported
      once, at the position of its first occurrence.

      Args:
          img: image array, assumed (H, W, C), e.g. BGR from cv.imread.

      Returns:
          List of colours, each a list of C Python ints.
      """
      pixels = np.asarray(img)
      pixels = pixels.reshape(-1, pixels.shape[-1])
      if len(pixels) == 0:
          return []
      # np.unique returns colours sorted by value; first_idx records where
      # each one first appears, so sorting by it restores scan order.
      colours, first_idx = np.unique(pixels, axis=0, return_index=True)
      return colours[np.argsort(first_idx)].tolist()
  # Colours painted white before grayscale conversion. Per the article text,
  # these are colours with no corresponding reflectivity value in the legend
  # (background, borders, labels, ...) -- presumably; confirm against the
  # legend. They are in the channel order of the loaded image (BGR when read
  # with cv.imread).
  NOISE_COLORS = (
      (178, 178, 178),
      (247, 221, 136),
      (104, 104, 104),
      (182, 255, 255),
      (0, 0, 102),
      (219, 144, 58),
      (58, 144, 219),
      (102, 0, 0),
      (255, 255, 182),
      (219, 255, 255),
      (219, 182, 182),
      (219, 144, 144),
      (219, 219, 219),
      (0, 58, 144),
      (182, 182, 182),
      (0, 0, 0),
      (58, 0, 0),
      (255, 219, 144),
  )


  def Image_Preprocessing(img, colors=NOISE_COLORS):
      """Paint every pixel matching one of *colors* white, in place.

      Args:
          img: image array accepted by viewColor; it is modified in place.
          colors: colours to remove, applied in order. Defaults to
              NOISE_COLORS.

      Returns:
          The same *img* object (mutated), for call chaining.
      """
      for c in colors:
          viewColor(img, c)
      return img
  def pltShow(img):
      """Display *img* in a matplotlib figure.

      matplotlib interprets 3-channel data as RGB, so an OpenCV (BGR) image
      should be converted by the caller before being passed in.
      """
      plt.imshow(img)
      return plt.show()
  def cvShow(img):
      """Show *img* in an OpenCV window titled "mat" until a key is pressed.

      The window is destroyed afterwards. Otherwise it stays on screen and
      unresponsive, because no further GUI events are processed once
      waitKey returns.
      """
      cv.imshow("mat", img)
      cv.waitKey(0)
      cv.destroyWindow("mat")
  # Remove interfering colours from every cropped image, convert it to
  # grayscale and write it to save_img_path as <index>.png.
  # Relies on cuted_img_list / save_img_path from the cropping step and on
  # os, cv, Image_Preprocessing and pltShow being defined elsewhere.
  # NOTE: Image_Preprocessing works in place, so the arrays held in
  # cuted_img_list are modified too.
  imgGray_list = []
  for idx, cuted_img in enumerate(cuted_img_list):
      img = Image_Preprocessing(cuted_img)
      imgGray = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
      # matplotlib expects RGB but OpenCV images are BGR; convert for display.
      pltShow(cv.cvtColor(img, cv.COLOR_BGR2RGB))
      temp = os.path.join(save_img_path, str(idx)+'.png')
      # cv.imwrite overwrites an existing file, so no delete is needed. It
      # reports failure only through its return value, so check it.
      if not cv.imwrite(temp, imgGray):
          raise OSError("Failed to write " + temp)
      imgGray_list.append(imgGray)
      print(str(idx)+'.png is preprocesed, Wait next!')
  print("Everything is OK!")

三、ConvLSTM 模型

  ## NEW NETWORK ##
  def mainmodel(nk=128, kernel_size=(5, 5), dtype='float32'):
      """Build the ConvLSTM sequence-to-sequence model (uncompiled).

      Layout (every ConvLSTM2D: 'same' padding, return_sequences=True,
      he_normal initialisation):
          encoding:    x1 = ConvLSTM(input);  x2 = ConvLSTM(x1)
          forecasting: x3 = ConvLSTM(x1);     x4 = ConvLSTM(x3 + x2)
          prediction:  Conv3D(1 filter, 5x5x5, sigmoid) on concat(x4, x3)

      Args:
          nk: number of kernels per ConvLSTM2D layer (the original comment
              lists 48 as an alternative).
          kernel_size: ConvLSTM2D kernel size.
          dtype: dtype of the input tensor.

      Returns:
          An uncompiled keras Model. Its output has the same time and
          spatial dimensions as the input, one channel, values in [0, 1].
      """
      def conv_lstm(name):
          # All recurrent layers share the same configuration.
          return ConvLSTM2D(nk, kernel_size, padding='same', return_sequences=True,
                            kernel_initializer='he_normal', name=name)

      # NOTE(review): Keras conv layers read (time, rows, cols, channels);
      # here rows = WIDTH and cols = HEIGHT. Confirm this matches the axis
      # order of INPUT_SEQUENCE.
      contentInput = Input(shape=(None, WIDTH, HEIGHT, 1), name='content_input', dtype=dtype)
      # Encoding network
      x1 = conv_lstm('layer1')(contentInput)
      x2 = conv_lstm('layer2')(x1)
      # Forecasting network: layer3 branches off x1 (not x2); x2 is merged
      # back in through the Add.
      x3 = conv_lstm('layer3')(x1)
      add1 = Add()([x3, x2])
      x4 = conv_lstm('layer4')(add1)
      # Prediction network
      conc = Concatenate()([x4, x3])
      predictions = Conv3D(1, (5, 5, 5), activation='sigmoid', padding='same', name='prediction')(conc)
      return Model(inputs=contentInput, outputs=predictions)

四、训练

  

  ## NEW NETWORK ##
  def mainmodel(nk=128, kernel_size=(5, 5), dtype='float32'):
      """Build the ConvLSTM sequence-to-sequence model (uncompiled).

      Layout (every ConvLSTM2D: 'same' padding, return_sequences=True,
      he_normal initialisation):
          encoding:    x1 = ConvLSTM(input);  x2 = ConvLSTM(x1)
          forecasting: x3 = ConvLSTM(x1);     x4 = ConvLSTM(x3 + x2)
          prediction:  Conv3D(1 filter, 5x5x5, sigmoid) on concat(x4, x3)

      Args:
          nk: number of kernels per ConvLSTM2D layer (the original comment
              lists 48 as an alternative).
          kernel_size: ConvLSTM2D kernel size.
          dtype: dtype of the input tensor.

      Returns:
          An uncompiled keras Model. Its output has the same time and
          spatial dimensions as the input, one channel, values in [0, 1].
      """
      def conv_lstm(name):
          # All recurrent layers share the same configuration.
          return ConvLSTM2D(nk, kernel_size, padding='same', return_sequences=True,
                            kernel_initializer='he_normal', name=name)

      # NOTE(review): Keras conv layers read (time, rows, cols, channels);
      # here rows = WIDTH and cols = HEIGHT. Confirm this matches the axis
      # order of INPUT_SEQUENCE.
      contentInput = Input(shape=(None, WIDTH, HEIGHT, 1), name='content_input', dtype=dtype)
      # Encoding network
      x1 = conv_lstm('layer1')(contentInput)
      x2 = conv_lstm('layer2')(x1)
      # Forecasting network: layer3 branches off x1 (not x2); x2 is merged
      # back in through the Add.
      x3 = conv_lstm('layer3')(x1)
      add1 = Add()([x3, x2])
      x4 = conv_lstm('layer4')(add1)
      # Prediction network
      conc = Concatenate()([x4, x3])
      predictions = Conv3D(1, (5, 5, 5), activation='sigmoid', padding='same', name='prediction')(conc)
      return Model(inputs=contentInput, outputs=predictions)
  # Train model
  def _make_adam(learning_rate):
      """Create Adam(beta_1=0.9, beta_2=0.999, amsgrad=False) on old and new Keras.

      Keras >= 2.3 / tf.keras name the argument ``learning_rate``, and newer
      releases deprecate ``lr`` (it may be ignored or rejected). Keras < 2.3
      only accepts ``lr`` and raises TypeError for unknown keywords.
      """
      try:
          return Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, amsgrad=False)
      except TypeError:
          return Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999, amsgrad=False)


  def train(main_model=True, batchsize=5, epochs=50, save=False):
      """Compile and fit mainmodel() or test_model() on the global sequences.

      Reads the module-level INPUT_SEQUENCE / NEXT_SEQUENCE arrays. Stores
      the fitted model and its History object in the globals ``model`` and
      ``history``; performance() reads ``history``.

      Args:
          main_model: True trains mainmodel() on the first 40 sequences
              (10% validation split); False trains test_model() on the
              first 60 (5% validation split).
          batchsize: mini-batch size passed to model.fit.
          epochs: number of training epochs.
          save: if True, write the trained model under models/model1_ConvLSTM/
              (the directory is created if missing).
      """
      smooth = 1e-9

      # Additional metrics: SSIM, PSNR, POD, FAR. Keras calls metrics as
      # fn(y_true, y_pred), so x is the target and y the prediction. Keep
      # the function names: they become the keys in history.history.
      def ssim(x, y, max_val=1.0):
          return tf.image.ssim(x, y, max_val)

      def psnr(x, y, max_val=1.0):
          return tf.image.psnr(x, y, max_val)

      # Probability of detection (recall): tp / (tp + fn).
      def POD(x, y):
          y_pos = K.clip(x, 0, 1)
          y_pred_pos = K.clip(y, 0, 1)
          y_pred_neg = 1 - y_pred_pos
          tp = K.sum(y_pos * y_pred_pos)
          fn = K.sum(y_pos * y_pred_neg)
          return (tp + smooth) / (tp + fn + smooth)

      # False alarm ratio: fp / (tp + fp).
      def FAR(x, y):
          y_pred_pos = K.clip(y, 0, 1)
          y_pos = K.clip(x, 0, 1)
          y_neg = 1 - y_pos
          tp = K.sum(y_pos * y_pred_pos)
          fp = K.sum(y_neg * y_pred_pos)
          return fp / (tp + fp + smooth)

      metrics = ['accuracy', ssim, psnr, POD, FAR]
      global history, model
      model_dir = "models/model1_ConvLSTM"
      if main_model:
          model = mainmodel()
          print("[INFO] Compiling Main Model...")
          optim = _make_adam(0.001)
          # logcosh gives better results than crossentropy or mse
          model.compile(loss='logcosh', optimizer=optim, metrics=metrics)
          print("[INFO] Compiling Main Model: DONE")
          print("[INFO] Training Main Model...")
          history = model.fit(INPUT_SEQUENCE[:40], NEXT_SEQUENCE[:40], batch_size=batchsize, epochs=epochs, validation_split=0.1, verbose=1, use_multiprocessing=True)
          print("[INFO] Training of Main Model: DONE")
          if save:
              print("[INFO] Saving Model...")
              # open() / save_weights fail if the directory does not exist.
              os.makedirs(model_dir, exist_ok=True)
              # Architecture as JSON, weights as HDF5.
              with open(model_dir + "/mainmodel_1.json", "w") as json_file:
                  json_file.write(model.to_json())
              model.save_weights(model_dir + "/mainmodel_1.h5")
              print("[INFO] Model Saved")
          else:
              print("[INFO] Model not saved")
      else:
          model = test_model()
          print("[INFO] Compiling Test Model...")
          model.compile(loss='logcosh', optimizer='adam', metrics=metrics)
          print("[INFO] Compiling Test Model: DONE")
          print("[INFO] Training Test Model...:")
          history = model.fit(INPUT_SEQUENCE[:60], NEXT_SEQUENCE[:60], batch_size=batchsize, epochs=epochs, validation_split=0.05, verbose=1, use_multiprocessing=True)
          print("[INFO] Training of Test Model: DONE")
          if save:
              print("[INFO] Saving Test Model...")
              os.makedirs(model_dir, exist_ok=True)
              model.save(model_dir + "/trained_test_model_samples.h5")
              print("[INFO] Model Saved")
          else:
              print("[INFO] Model not saved")
  ### PLOT LOSS vs EPOCHS ###
  def _plot_history(hist, keys, labels, title, ylabel):
      """Plot each available series ``hist[key]`` against epoch in one figure.

      Keys missing from *hist* are skipped, e.g. validation curves when
      training ran without a validation split.
      """
      present = [(key, label) for key, label in zip(keys, labels) if key in hist]
      for key, _ in present:
          plt.plot(hist[key])
      plt.title(title)
      plt.ylabel(ylabel)
      plt.xlabel('Epoch')
      plt.legend([label for _, label in present], loc='upper left')
      plt.show()


  def performance():
      """Plot accuracy, loss and POD/FAR curves from the global ``history``.

      Must be called after train(), which sets ``history``. The "val_*"
      curves come from the validation_split data used in train().
      """
      hist = history.history
      # Keras < 2.3 logs accuracy as 'acc'; Keras >= 2.3 / tf.keras as 'accuracy'.
      acc_key = 'acc' if 'acc' in hist else 'accuracy'
      _plot_history(hist, [acc_key, 'val_' + acc_key], ['Train', 'Validation'], 'Model accuracy', 'Accuracy')
      _plot_history(hist, ['loss', 'val_loss'], ['Train', 'Validation'], 'Model loss', 'Loss')
      _plot_history(hist, ['POD', 'FAR'], ['POD', 'FAR'], 'POD, FAR plot', 'POD / FAR')
  # Train and save the main model. Pass main_model=False to train
  # test_model() instead.
  train(epochs=8, batchsize=4, save=True, main_model=True)

五、预测

(待补)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/煮酒与君饮/article/detail/769720
推荐阅读
相关标签
  

闽ICP备14008679号