既然要用到模型训练，我想到了两种方式：第一种是使用 opencv.js，但我的 JS 水平实在不敢恭维；所以我选择了另一种方式——用 OpenCV 直接读取视频流，逐帧截取图片，再把图片送入模型进行评估。



  1. #define APP_CPU 1
  2. #define PRO_CPU 0
  3. #include "OV2640.h"
  4. #include <WiFi.h>
  5. #include <WebServer.h>
  6. #include <WiFiClient.h>
  7. #include <esp_bt.h>
  8. #include <esp_wifi.h>
  9. #include <esp_sleep.h>
  10. #include <driver/rtc_io.h>
  12. #include "camera_pins.h"
  13. #include "home_wifi_multi.h"
  14. OV2640 cam;
  15. WebServer server(80);
  16. // ===== rtos task handles =========================
  17. // Streaming is implemented with 3 tasks:
  18. TaskHandle_t tMjpeg; // handles client connections to the webserver
  19. TaskHandle_t tCam; // handles getting picture frames from the camera and storing them locally
  20. TaskHandle_t tStream; // actually streaming frames to all connected clients
  21. // frameSync semaphore is used to prevent streaming buffer as it is replaced with the next frame
  22. SemaphoreHandle_t frameSync = NULL;
  23. // Queue stores currently connected clients to whom we are streaming
  24. QueueHandle_t streamingClients;
  25. // We will try to achieve 25 FPS frame rate
  26. const int FPS = 7;
  27. // We will handle web client requests every 50 ms (20 Hz)
  28. const int WSINTERVAL = 100;
  29. // ======== Server Connection Handler Task ==========================
  30. void mjpegCB(void* pvParameters) {
  31. TickType_t xLastWakeTime;
  32. const TickType_t xFrequency = pdMS_TO_TICKS(WSINTERVAL);
  33. // Creating frame synchronization semaphore and initializing it
  34. frameSync = xSemaphoreCreateBinary();
  35. xSemaphoreGive( frameSync );
  36. // Creating a queue to track all connected clients
  37. streamingClients = xQueueCreate( 10, sizeof(WiFiClient*) );
  38. //=== setup section ==================
  39. // Creating RTOS task for grabbing frames from the camera
  40. xTaskCreatePinnedToCore(
  41. camCB, // callback
  42. "cam", // name
  43. 4096, // stacj size
  44. NULL, // parameters
  45. 2, // priority
  46. &tCam, // RTOS task handle
  47. APP_CPU); // core
  48. // Creating task to push the stream to all connected clients
  49. xTaskCreatePinnedToCore(
  50. streamCB,
  51. "strmCB",
  52. 4 * 1024,
  53. NULL, //(void*) handler,
  54. 2,
  55. &tStream,
  56. APP_CPU);
  57. // Registering webserver handling routines
  58. server.on("/mjpeg/1", HTTP_GET, handleJPGSstream);
  59. server.on("/jpg", HTTP_GET, handleJPG);
  60. server.onNotFound(handleNotFound);
  61. // Starting webserver
  62. server.begin();
  63. //=== loop() section ===================
  64. xLastWakeTime = xTaskGetTickCount();
  65. for (;;) {
  66. server.handleClient();
  67. // After every server client handling request, we let other tasks run and then pause
  68. taskYIELD();
  69. vTaskDelayUntil(&xLastWakeTime, xFrequency);
  70. }
  71. }
  72. // Commonly used variables:
  73. volatile size_t camSize; // size of the current frame, byte
  74. volatile char* camBuf; // pointer to the current frame
  75. // ==== RTOS task to grab frames from the camera =========================
  76. void camCB(void* pvParameters) {
  77. TickType_t xLastWakeTime;
  78. // A running interval associated with currently desired frame rate
  79. const TickType_t xFrequency = pdMS_TO_TICKS(1000 / FPS);
  80. // Mutex for the critical section of swithing the active frames around
  81. portMUX_TYPE xSemaphore = portMUX_INITIALIZER_UNLOCKED;
  82. // Pointers to the 2 frames, their respective sizes and index of the current frame
  83. char* fbs[2] = { NULL, NULL };
  84. size_t fSize[2] = { 0, 0 };
  85. int ifb = 0;
  86. //=== loop() section ===================
  87. xLastWakeTime = xTaskGetTickCount();
  88. for (;;) {
  89. // Grab a frame from the camera and query its size
  90. cam.run();
  91. size_t s = cam.getSize();
  92. // If frame size is more that we have previously allocated - request 125% of the current frame space
  93. if (s > fSize[ifb]) {
  94. fSize[ifb] = s * 4 / 3;
  95. fbs[ifb] = allocateMemory(fbs[ifb], fSize[ifb]);
  96. }
  97. // Copy current frame into local buffer
  98. char* b = (char*) cam.getfb();
  99. memcpy(fbs[ifb], b, s);
  100. // Let other tasks run and wait until the end of the current frame rate interval (if any time left)
  101. taskYIELD();
  102. vTaskDelayUntil(&xLastWakeTime, xFrequency);
  103. // Only switch frames around if no frame is currently being streamed to a client
  104. // Wait on a semaphore until client operation completes
  105. xSemaphoreTake( frameSync, portMAX_DELAY );
  106. // Do not allow interrupts while switching the current frame
  107. portENTER_CRITICAL(&xSemaphore);
  108. camBuf = fbs[ifb];
  109. camSize = s;
  110. ifb++;
  111. ifb &= 1; // this should produce 1, 0, 1, 0, 1 ... sequence
  112. portEXIT_CRITICAL(&xSemaphore);
  113. // Let anyone waiting for a frame know that the frame is ready
  114. xSemaphoreGive( frameSync );
  115. // Technically only needed once: let the streaming task know that we have at least one frame
  116. // and it could start sending frames to the clients, if any
  117. xTaskNotifyGive( tStream );
  118. // Immediately let other (streaming) tasks run
  119. taskYIELD();
  120. // If streaming task has suspended itself (no active clients to stream to)
  121. // there is no need to grab frames from the camera. We can save some juice
  122. // by suspedning the tasks
  123. if ( eTaskGetState( tStream ) == eSuspended ) {
  124. vTaskSuspend(NULL); // passing NULL means "suspend yourself"
  125. }
  126. }
  127. }
  128. // ==== Memory allocator that takes advantage of PSRAM if present =======================
  129. char* allocateMemory(char* aPtr, size_t aSize) {
  130. // Since current buffer is too smal, free it
  131. if (aPtr != NULL) free(aPtr);
  132. size_t freeHeap = ESP.getFreeHeap();
  133. char* ptr = NULL;
  134. // If memory requested is more than 2/3 of the currently free heap, try PSRAM immediately
  135. if ( aSize > freeHeap * 2 / 3 ) {
  136. if ( psramFound() && ESP.getFreePsram() > aSize ) {
  137. ptr = (char*) ps_malloc(aSize);
  138. }
  139. }
  140. else {
  141. // Enough free heap - let's try allocating fast RAM as a buffer
  142. ptr = (char*) malloc(aSize);
  143. // If allocation on the heap failed, let's give PSRAM one more chance:
  144. if ( ptr == NULL && psramFound() && ESP.getFreePsram() > aSize) {
  145. ptr = (char*) ps_malloc(aSize);
  146. }
  147. }
  148. // Finally, if the memory pointer is NULL, we were not able to allocate any memory, and that is a terminal condition.
  149. if (ptr == NULL) {
  150. ESP.restart();
  151. }
  152. return ptr;
  153. }
  154. // ==== STREAMING ======================================================
  155. const char HEADER[] = "HTTP/1.1 200 OK\r\n" \
  156. "Access-Control-Allow-Origin: *\r\n" \
  157. "Content-Type: multipart/x-mixed-replace; boundary=123456789000000000000987654321\r\n";
  158. const char BOUNDARY[] = "\r\n--123456789000000000000987654321\r\n";
  159. const char CTNTTYPE[] = "Content-Type: image/jpeg\r\nContent-Length: ";
  160. const int hdrLen = strlen(HEADER);
  161. const int bdrLen = strlen(BOUNDARY);
  162. const int cntLen = strlen(CTNTTYPE);
  163. // ==== Handle connection request from clients ===============================
  164. void handleJPGSstream(void)
  165. {
  166. // Can only acommodate 10 clients. The limit is a default for WiFi connections
  167. if ( !uxQueueSpacesAvailable(streamingClients) ) return;
  168. // Create a new WiFi Client object to keep track of this one
  169. WiFiClient* client = new WiFiClient();
  170. *client = server.client();
  171. // Immediately send this client a header
  172. client->write(HEADER, hdrLen);
  173. client->write(BOUNDARY, bdrLen);
  174. // Push the client to the streaming queue
  175. xQueueSend(streamingClients, (void *) &client, 0);
  176. // Wake up streaming tasks, if they were previously suspended:
  177. if ( eTaskGetState( tCam ) == eSuspended ) vTaskResume( tCam );
  178. if ( eTaskGetState( tStream ) == eSuspended ) vTaskResume( tStream );
  179. }
  180. // ==== Actually stream content to all connected clients ========================
  181. void streamCB(void * pvParameters) {
  182. char buf[16];
  183. TickType_t xLastWakeTime;
  184. TickType_t xFrequency;
  185. // Wait until the first frame is captured and there is something to send
  186. // to clients
  187. ulTaskNotifyTake( pdTRUE, /* Clear the notification value before exiting. */
  188. portMAX_DELAY ); /* Block indefinitely. */
  189. xLastWakeTime = xTaskGetTickCount();
  190. for (;;) {
  191. // Default assumption we are running according to the FPS
  192. xFrequency = pdMS_TO_TICKS(1000 / FPS);
  193. // Only bother to send anything if there is someone watching
  194. UBaseType_t activeClients = uxQueueMessagesWaiting(streamingClients);
  195. if ( activeClients ) {
  196. // Adjust the period to the number of connected clients
  197. xFrequency /= activeClients;
  198. // Since we are sending the same frame to everyone,
  199. // pop a client from the the front of the queue
  200. WiFiClient *client;
  201. xQueueReceive (streamingClients, (void*) &client, 0);
  202. // Check if this client is still connected.
  203. if (!client->connected()) {
  204. // delete this client reference if s/he has disconnected
  205. // and don't put it back on the queue anymore. Bye!
  206. delete client;
  207. }
  208. else {
  209. // Ok. This is an actively connected client.
  210. // Let's grab a semaphore to prevent frame changes while we
  211. // are serving this frame
  212. xSemaphoreTake( frameSync, portMAX_DELAY );
  213. client->write(CTNTTYPE, cntLen);
  214. sprintf(buf, "%d\r\n\r\n", camSize);
  215. client->write(buf, strlen(buf));
  216. client->write((char*) camBuf, (size_t)camSize);
  217. client->write(BOUNDARY, bdrLen);
  218. // Since this client is still connected, push it to the end
  219. // of the queue for further processing
  220. xQueueSend(streamingClients, (void *) &client, 0);
  221. // The frame has been served. Release the semaphore and let other tasks run.
  222. // If there is a frame switch ready, it will happen now in between frames
  223. xSemaphoreGive( frameSync );
  224. taskYIELD();
  225. }
  226. }
  227. else {
  228. // Since there are no connected clients, there is no reason to waste battery running
  229. vTaskSuspend(NULL);
  230. }
  231. // Let other tasks run after serving every client
  232. taskYIELD();
  233. vTaskDelayUntil(&xLastWakeTime, xFrequency);
  234. }
  235. }
  236. const char JHEADER[] = "HTTP/1.1 200 OK\r\n" \
  237. "Content-disposition: inline; filename=capture.jpg\r\n" \
  238. "Content-type: image/jpeg\r\n\r\n";
  239. const int jhdLen = strlen(JHEADER);
  240. // ==== Serve up one JPEG frame =============================================
  241. void handleJPG(void)
  242. {
  243. WiFiClient client = server.client();
  244. if (!client.connected()) return;
  245. cam.run();
  246. client.write(JHEADER, jhdLen);
  247. client.write((char*)cam.getfb(), cam.getSize());
  248. }
  249. // ==== Handle invalid URL requests ============================================
  250. void handleNotFound()
  251. {
  252. String message = "Server is running!\n\n";
  253. message += "URI: ";
  254. message += server.uri();
  255. message += "\nMethod: ";
  256. message += (server.method() == HTTP_GET) ? "GET" : "POST";
  257. message += "\nArguments: ";
  258. message += server.args();
  259. message += "\n";
  260. server.send(200, "text / plain", message);
  261. }
  262. // ==== SETUP method ==================================================================
  263. void setup()
  264. {
  265. // Setup Serial connection:
  266. Serial.begin(115200);
  267. delay(1000); // wait for a second to let Serial connect
  268. // Configure the camera
  269. camera_config_t config;
  270. config.ledc_channel = LEDC_CHANNEL_0;
  271. config.ledc_timer = LEDC_TIMER_0;
  272. config.pin_d0 = Y2_GPIO_NUM;
  273. config.pin_d1 = Y3_GPIO_NUM;
  274. config.pin_d2 = Y4_GPIO_NUM;
  275. config.pin_d3 = Y5_GPIO_NUM;
  276. config.pin_d4 = Y6_GPIO_NUM;
  277. config.pin_d5 = Y7_GPIO_NUM;
  278. config.pin_d6 = Y8_GPIO_NUM;
  279. config.pin_d7 = Y9_GPIO_NUM;
  280. config.pin_xclk = XCLK_GPIO_NUM;
  281. config.pin_pclk = PCLK_GPIO_NUM;
  282. config.pin_vsync = VSYNC_GPIO_NUM;
  283. config.pin_href = HREF_GPIO_NUM;
  284. config.pin_sscb_sda = SIOD_GPIO_NUM;
  285. config.pin_sscb_scl = SIOC_GPIO_NUM;
  286. config.pin_pwdn = PWDN_GPIO_NUM;
  287. config.pin_reset = RESET_GPIO_NUM;
  288. config.xclk_freq_hz = 20000000;
  289. config.pixel_format = PIXFORMAT_JPEG;
  290. // Frame parameters: pick one
  291. // config.frame_size = FRAMESIZE_UXGA;
  292. // config.frame_size = FRAMESIZE_SVGA;
  293. // config.frame_size = FRAMESIZE_QVGA;
  294. config.frame_size = FRAMESIZE_VGA;
  295. config.jpeg_quality = 12;
  296. config.fb_count = 2;
  297. #if defined(CAMERA_MODEL_ESP_EYE)
  298. pinMode(13, INPUT_PULLUP);
  299. pinMode(14, INPUT_PULLUP);
  300. #endif
  301. if (cam.init(config) != ESP_OK) {
  302. Serial.println("Error initializing the camera");
  303. delay(10000);
  304. ESP.restart();
  305. }
  306. // Configure and connect to WiFi
  307. IPAddress ip;
  308. WiFi.mode(WIFI_STA);
  309. WiFi.begin("201", "1234abcd");//WIFI名称和密码
  310. Serial.print("Connecting to WiFi");
  311. while (WiFi.status() != WL_CONNECTED)
  312. {
  313. delay(500);
  314. Serial.print(F("."));
  315. }
  316. ip = WiFi.localIP();
  317. Serial.println(F("WiFi connected"));
  318. Serial.println("");
  319. Serial.print("Stream Link: http://");
  320. Serial.print(ip);
  321. Serial.println("/mjpeg/1");
  322. // Start mainstreaming RTOS task
  323. xTaskCreatePinnedToCore(
  324. mjpegCB,
  325. "mjpeg",
  326. 4 * 1024,
  327. NULL,
  328. 2,
  329. &tMjpeg,
  330. APP_CPU);
  331. }
  332. void loop() {
  333. vTaskDelay(1000);
  334. }




  1. import cv2
  2. #url = ''
  3. url = 'your ip stream'
  4. cap = cv2.VideoCapture(url)
  5. width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  6. height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  7. fourcc = cv2.VideoWriter_fourcc(*"mp4v")
  8. out = cv2.VideoWriter('./video/test.mp4', fourcc, 20, (width, height))
  9. while(cap.isOpened()):
  10. ret, frame = cap.read()
  11. if ret:
  12. out.write(frame)
  13. cv2.imshow('frame', frame)
  14. if cv2.waitKey(25) & 0xFF == ord('q'): #按键盘Q键退出
  15. break
  16. else:
  17. continue
  18. cap.release()
  19. out.release()
  20. cv2.destroyAllWindows()


  1. import cv2
  2. # 使用opencv按一定间隔截取视频帧,并保存为图片
  3. vc = cv2.VideoCapture('./video/test.mp4') # 读取视频文件
  4. c = 0
  5. d = 0
  6. print("------------")
  7. if vc.isOpened(): # 判断是否正常打开
  8. print("yes")
  9. rval, frame = vc.read()
  10. else:
  11. rval = False
  12. print("false")
  13. timeF = 100 # 视频帧计数间隔频率
  14. while rval: # 循环读取视频帧
  15. rval, frame = vc.read()
  16. print(c,timeF, c%timeF)
  17. if (c % timeF == 0):# 每隔timeF帧进行存储操作
  18. print("write...")
  19. cv2.imwrite(f'./video_cut/1_{d}.jpg', frame) # 存储为图像
  20. print("success!")
  21. c = c + 100000
  22. d = d + 1
  23. cv2.waitKey(1)
  24. vc.release()
  25. print("==================================")





  1. #导入库文件
  2. import matplotlib.pyplot as plt
  3. import matplotlib.image as mpimg
  4. import numpy as np
  5. import os
  6. import pandas as pd
  7. #计算颜色矩特征模型
  8. def img2vector(filename):
  9. returnvect = np.zeros((1, 9))
  10. #一个1*9的二维数组
  11. fr = mpimg.imread(filename)
  12. #用matplotlib读取图片文件
  13. l_max = fr.shape[0]//2+50 #读取矩阵的第一个维度,然后除以二后向下取整,再加50
  14. l_min = fr.shape[0]//2-50
  15. w_max = fr.shape[1]//2+50
  16. w_min = fr.shape[1]//2-50
  17. water = fr[l_min:l_max, w_min:w_max, :].reshape(1, 10000, 3)#重塑为一个三维矩阵,1*10000*3
  18. for i in range(3):
  19. this = water[:, :, i]/255
  20. print(this)
  21. returnvect[0, i] = np.mean(this) #0,1,2存储一阶颜色矩
  22. returnvect[0, 3+i] = np.sqrt(np.mean(np.square(this-returnvect[0, i])))#3,4,5存储二阶颜色矩
  23. returnvect[0, 6+i] = np.cbrt(np.mean(np.power(this-returnvect[0, i], 3)))#6,7,8存储三阶颜色矩
  24. print(returnvect)
  25. return returnvect
  26. #计算每个图片的特征
  27. trainfilelist = os.listdir('./water_image')#读取目录下文件列表
  28. m = len(trainfilelist) #计算文件数目
  29. labels = np.zeros((1, m)) #生成两个196个0的空矩阵
  30. train = np.zeros((1, m))
  31. #trainingMat=[]
  32. #print(trainfilelist)
  33. trainingMat=np.zeros((m, 9)) #m行9列的0空矩阵
  34. for i in range(m):
  35. filenamestr = trainfilelist[i] #获取当前文件名,例1_1.jpg
  36. filestr = filenamestr.split('.')[0] #按照.划分,取前一部分
  37. classnumstr = int(filestr.split('_')[0])#按照_划分,后一部分为该类图片中的序列
  38. picture_num = int(filestr.split('_')[1])
  39. labels[0, i] = classnumstr #前一部分为该图片的标签
  40. train[0, i] = picture_num
  41. trainingMat[i, :] = img2vector('./water_image/%s' % filenamestr) #构成数组
  42. #保存
  43. d = np.concatenate((labels.T, train.T, trainingMat), axis=1)#连接数组
  44. dataframe = pd.DataFrame(d, columns=['Water kind','number', 'R_1', 'G_1', 'B_1', 'R_2', 'G_2', 'B_2', 'R_3', 'G_3', 'B_3'])
  45. dataframe.to_csv('./data/moment.csv', encoding='utf-8', index=False)#保存文件


  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import matplotlib.image as mpimg
  4. import os
  5. import pandas as pd
  6. from sklearn.model_selection import train_test_split
  7. from sklearn import svm
  8. from sklearn import metrics
  9. import joblib
  10. import warnings
  11. from first_step import img2vector
  12. from sklearn.model_selection import GridSearchCV
  13. import seaborn as sns
  14. warnings.filterwarnings("ignore")#防止标签缺失的警报
  15. #计算每个图片的特征
  16. trainfilelist = os.listdir('./video_cut')#读取目录下文件列表
  17. m = len(trainfilelist) #计算文件数目
  18. labels = np.zeros((1, m)) #生成两个196个0的空矩阵
  19. train = np.zeros((1, m))
  20. #trainingMat=[]
  21. #print(trainfilelist)
  22. trainingMat=np.zeros((m, 9)) #m行9列的0空矩阵
  23. for i in range(m):
  24. filenamestr = trainfilelist[i] #获取当前文件名,例1_1.jpg
  25. filestr = filenamestr.split('.')[0] #按照.划分,取前一部分
  26. classnumstr = int(filestr.split('_')[0])#按照_划分,后一部分为该类图片中的序列
  27. picture_num = int(filestr.split('_')[1])
  28. labels[0, i] = classnumstr #前一部分为该图片的标签
  29. train[0, i] = picture_num
  30. trainingMat[i, :] = img2vector('./video_cut/%s' % filenamestr) #构成数组
  31. #保存
  32. d = np.concatenate((labels.T, train.T, trainingMat), axis=1)#连接数组
  33. dataframe = pd.DataFrame(d, columns=['Water kind', 'number', 'R_1', 'G_1', 'B_1', 'R_2', 'G_2', 'B_2', 'R_3', 'G_3', 'B_3'])
  34. dataframe.to_csv('./real_data/real_moment.csv', encoding='utf-8', index=False)#保存文件


  1. import matplotlib.pyplot as plt
  2. import pandas as pd
  3. from pandas import DataFrame,Series
  4. import random
  5. import numpy as np
  6. # -*- coding:utf-8 -*-
  7. def cm_plot(y, yp):
  8. from sklearn.metrics import confusion_matrix # 导入混淆矩阵函数
  9. cm = confusion_matrix(y, yp) # 混淆矩阵
  10. import matplotlib.pyplot as plt # 导入作图库
  11. plt.matshow(cm, cmap=plt.cm.Greens) # 画混淆矩阵图,配色风格使用cm.Greens,更多风格请参考官网。
  12. plt.colorbar() # 颜色标签
  13. for x in range(len(cm)): # 数据标签
  14. for y in range(len(cm)):
  15. plt.annotate(cm[x, y], xy=(x, y), horizontalalignment='center', verticalalignment='center')
  16. plt.ylabel('True label') # 坐标轴标签
  17. plt.xlabel('Predicted label') # 坐标轴标签
  18. return plt
  19. inputfile = './data/moment.csv'
  20. data = pd.read_csv(inputfile, encoding='gbk')
  21. # 注意,此处不能用shuffle
  22. sampler = np.random.permutation(len(data))
  23. d = data.take(sampler).values
  24. data_train = d[:int(0.8*len(data)),:] #选取前80%做训练集
  25. data_test = d[int(0.8*len(data)):,:] #选取后20%做测试集
  26. print(data_train.shape)
  27. print(data_test.shape)
  28. # 构建支持向量机模型代码
  29. x_train = data_train[:, 2:]*30 #放大特征
  30. y_train = data_train[:,0].astype(int)
  31. x_test = data_test[:, 2:]*30 #放大特征
  32. y_test = data_test[:,0].astype(int)
  33. print(x_train.shape)
  34. print(x_test.shape)
  35. # 导入模型相关的支持向量机函数 建立并且训练模型
  36. from sklearn import svm
  37. model = svm.SVC()
  38. model.fit(x_train, y_train)
  39. import pickle
  40. pickle.dump(model, open('./save_model/clf.model', 'wb'))
  41. # model = pickle.load(open('svcmodel.model','rb'))
  42. # 导入输出相关的库,生成混淆矩阵
  43. from sklearn import metrics
  44. cm_train = metrics.confusion_matrix(y_train, model.predict(x_train)) # 训练样本的混淆矩阵
  45. cm_test = metrics.confusion_matrix(y_test, model.predict(x_test)) # 测试样本的混淆矩阵
  46. print(cm_train.shape)
  47. print(cm_test.shape)
  48. df1 = DataFrame(cm_train, index = range(1,5), columns=range(1,5))
  49. df2 = DataFrame(cm_test, index = range(1,5), columns=range(1,5))
  50. df1.to_excel('./train_data_xlxs/trainPre.xlsx')
  51. df2.to_excel('./train_data_xlxs/testPre.xlsx')
  52. print(model.score(x_train,y_train)) # 评价模型训练的准确率
  53. print(model.score(x_test,y_test)) # 评价模型测试的准确率
  54. cm_plot(y_train, model.predict(x_train)).show() # cm_plot是自定义的画混淆矩阵的函数
  55. cm_plot(y_test, model.predict(x_test)).show() # cm_plot是自定义的画混淆矩阵的函数
  56. #------------------------------------------------------------------------------------------------------------------
  57. #------------------------------------------------------------------------------------------------------------------
  58. #------------------------------------------------------------------------------------------------------------------
  59. #正式开始的数据
  60. inputfile1 = './real_data/real_moment.csv'
  61. data1 = pd.read_csv(inputfile1, encoding='gbk')
  62. sampler = np.random.permutation(len(data1))
  63. d = data1.take(sampler).values
  64. data_train1 = d[:int(0.8*len(data1)),:] #选取前80%做训练集
  65. data_test1 = d[int(0.8*len(data1)):,:] #选取后20%做测试集
  66. print(data_train1.shape)
  67. print(data_test1.shape)
  68. x_train1 = data_train1[:, 2:] * 30 #放大特征
  69. y_train1 = data_train1[:, 0].astype(int)
  70. x_test1 = data_test1[:, 2:] * 30 #放大特征
  71. y_test1 = data_test1[:, 0].astype(int)
  72. print(x_train1.shape)
  73. print(x_test1.shape)
  74. cm_train1 = metrics.confusion_matrix(y_train1, model.predict(x_train1))
  75. # df3 = DataFrame(cm_train1, index = range(1, 5), columns=range(1, 5))
  76. # df3.to_excel('./real_data_xlxs/realPreTrain.xlsx')
  77. # print(model.score(x_train1, y_train1)) # 评价模型测试的准确率
  78. cm_plot(y_train1, model.predict(x_train1)).show() # cm_plot是自定义的画混淆矩阵的函数
  79. cm_test1 = metrics.confusion_matrix(y_test1, model.predict(x_test1))
  80. # df4 = DataFrame(cm_test1, index = range(1, 5), columns=range(1, 5))
  81. # df4.to_excel('./real_data_xlxs/realPreTest.xlsx')
  82. # print(model.score(x_test1, y_test1))
  83. cm_plot(y_test1, model.predict(x_test1)).show()
  84. print(model.score(x_train1, y_train1)) # 评价模型训练的准确率
  85. print(model.score(x_test1, y_test1)) # 评价模型测试的准确率




