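I've been working on my ITP thesis project lately. The idea is a game you play by doodling, with AI recognizing your drawings. It's similar to Scribblenauts, except that instead of typing words, you enter them by doodling and letting an AI recognize the doodle.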
Google hasn't released an API for doodle recognition, only the dataset, so I had to build my own server to provide a doodle-recognition service.
This week I got the server set up. I'm sharing how I implemented it here, so it can serve as a reference for others later.
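The complete code is here:
EonYang/flaskServer (github.com)
My demo is here:
p5.js Web Editor (editor.p5js.org)
I'm not a machine learning expert, and I only started learning Python this year. I referenced, borrowed, and copied a lot of other people's code and models.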
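Fortunately, Kaggle held a QuickDraw competition, and many participants published their approaches and models. I directly borrowed the code and model of the top-ranked one. The original post and link are here.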
Reading his code, you can see what he did:
My computer is weak and I was too lazy to train the model myself, so I just downloaded his `model.h5` file.
We also need to borrow the function he used to prepare images. Note that this model must be given images with white strokes on a black background.
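Environment: Python 3.5.5, plus whatever the latest TensorFlow and Keras were at the time (I don't know the exact versions). The code relies on TensorFlow 1.x APIs such as `tf.InteractiveSession`, `tf.set_random_seed`, and `tf.get_default_graph`, so it won't run unmodified on TensorFlow 2.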
In `predictor.py`, import the libraries we need (there are a lot of them):

```python
# Jupyter-notebook leftover; requires IPython to be installed
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import os
import json
import datetime as dt
import cv2
import base64
import io
from PIL import Image
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation
from tensorflow.keras.metrics import categorical_accuracy, top_k_categorical_accuracy, categorical_crossentropy
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.mobilenet import preprocess_input
from tensorflow.keras import backend as K
```
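Many of these were used during testing. I later deleted some of the test code, so several of these libraries are no longer used, but I didn't bother to clean them up.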
Only a few of them are really core. Next, define some constants:
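```python
BASE_SIZE = 256   # not used for prediction (training leftover)
NCSVS = 100       # not used for prediction (training leftover)
NCATS = 340       # number of categories the model outputs
np.random.seed(seed=1987)
tf.set_random_seed(seed=1987)  # TensorFlow 1.x API


def top_3_accuracy(y_true, y_pred):
    # Custom metric: correct if the true class is among the top 3 predictions
    return top_k_categorical_accuracy(y_true, y_pred, k=3)


STEPS = 800       # training leftovers
EPOCHS = 16
size = 64         # the model takes 64x64 grayscale images
batchsize = 680
```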
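Put the `model.h5` you downloaded earlier into the `model` folder. The actual model is MobileNet, imported from Keras; `model.h5` is only a weights file, so you first construct MobileNet and then load the file with `load_weights`.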
```python
def init():
    # TensorFlow 1.x: create a session and keep a handle to the default graph,
    # so that the Flask request threads can reuse them later
    sess = tf.InteractiveSession()
    # Build the MobileNet architecture (64x64x1 input, NCATS classes),
    # then load the pre-trained weights into it
    loaded_model = MobileNet(input_shape=(size, size, 1), alpha=1., weights=None, classes=NCATS)
    loaded_model.load_weights("./model/model.h5")
    loaded_model.compile(optimizer=Adam(lr=0.002), loss='categorical_crossentropy',
                         metrics=[categorical_crossentropy, categorical_accuracy, top_3_accuracy])
    loaded_model.summary()  # prints the summary itself (it returns None)
    graph = tf.get_default_graph()
    return loaded_model, sess, graph


model, sess, graph = init()
```
Want to test it? My test code is already deleted, so let's just keep going.
```python
def prepareImage(im):
    # Grayscale
    im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    # Binarize: pixels above 127 become 255, the rest 0
    thresh = 127
    im = cv2.threshold(im, thresh, 255, cv2.THRESH_BINARY)[1]
    # The model expects white strokes on black: if most pixels are white,
    # the background is white, so invert
    n_white_pix = np.sum(im == 255)
    n_black_pix = np.sum(im == 0)
    if n_white_pix > n_black_pix:
        im = cv2.bitwise_not(im)
    # trim1: move the content to the top-left corner by dropping empty
    # leading columns/rows and padding the same number of zeros at the end
    height, width = im.shape
    sum0 = im.sum(axis=0)
    sum1 = im.sum(axis=1)
    for i in range(len(sum0)):
        if sum0[i] == 0:
            im = np.delete(im, 0, 1)
            im = np.append(im, np.zeros((height, 1)), 1)
        else:
            break
    for i in range(len(sum1)):
        if sum1[i] == 0:
            im = np.delete(im, 0, 0)
            im = np.append(im, np.zeros((1, width)), 0)
        else:
            break
    # trim2: crop a square that just fits the content
    sum3 = im.sum(axis=0)
    sum4 = im.sum(axis=1)
    x2 = 1
    y2 = 1
    while x2 < len(sum3) and sum3[-x2] == 0:
        x2 += 1
    while y2 < len(sum4) and sum4[-y2] == 0:
        y2 += 1
    # sum3[-x2] is the last non-empty column, so the content is width - x2 + 1 wide
    w = width - x2 + 1
    h = height - y2 + 1
    contentSize = max(w, h)
    # Only crop if there is really some content
    if contentSize > 16:
        im = im[0:contentSize, 0:contentSize]
    return im
```
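This function converts the image to grayscale, binarizes it, inverts it if necessary so the strokes are white on a black background, moves the drawing to the top-left corner, and crops it to a square that just fits the content.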
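For comparison, the two trim steps in `prepareImage` can also be written with a bounding box from `np.nonzero`, instead of deleting and appending one column or row at a time. This is a minimal sketch of my own, not the code from the Kaggle kernel: `trim_to_square` is a hypothetical name, it assumes a 2D array that is already binarized with white (non-zero) strokes on black, and it aims for the same behavior on a square canvas (move the content to the top-left corner, then crop a square that fits it, but only if the content is larger than `min_size`).

```python
import numpy as np


def trim_to_square(im, min_size=16):
    """Shift the strokes to the top-left corner, then crop a square around them.

    Assumes `im` is a 2D binary array with non-zero strokes on a zero background.
    """
    rows = np.nonzero(im.any(axis=1))[0]
    cols = np.nonzero(im.any(axis=0))[0]
    if rows.size == 0:
        # Blank image: nothing to move or crop
        return im
    # Bounding box of the strokes
    content = im[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]
    h, w = content.shape
    # trim1: same-sized canvas with the content moved to the top-left corner
    shifted = np.zeros_like(im)
    shifted[:h, :w] = content
    # trim2: crop a square whose side is the longer edge of the content
    side = max(h, w)
    if side > min_size:
        return shifted[:side, :side]
    return shifted
```

A 20-row by 10-column block of strokes anywhere on a 40x40 canvas comes back as a 20x20 array with the strokes in its left half.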
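```python
def prepareImageAndPredict(model, cv2ImageData, size=64):
    try:
        # Downsize to size x size (64x64 by default)
        image64 = cv2.resize(cv2ImageData, (size, size))
        # A batch containing one grayscale image: shape (1, size, size, 1)
        x = np.zeros((1, size, size, 1))
        x[0, :, :, 0] = image64
        # MobileNet preprocessing scales pixel values to [-1, 1]
        x = preprocess_input(x).astype(np.float32)
        prediction = model.predict(x, batch_size=128, verbose=1)
        # Indices of the 5 highest-scoring classes, best first
        top5 = np.argsort(-prediction, axis=1)[:, :5]
        return top5[0]
    except Exception as e:
        print(e)
        raise  # re-raise instead of silently returning None
```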
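This function downsizes the image to 64x64, puts it into an array of shape (1, 64, 64, 1), and applies MobileNet's `preprocess_input`. It then calls `model.predict` to get the prediction, sorts the class scores to find the indices of the top 5 classes, and returns `top5[0]`, the top 5 for our single image.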
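To see what the last two lines of `prepareImageAndPredict` do, here's a small stand-alone example with made-up scores for a batch of one image and six classes (the numbers are invented, not real model output). Negating the scores makes `np.argsort`, which sorts in ascending order, put the highest-scoring classes first; `[:, :k]` keeps the first k per row, and `[0]` picks the row for our single image. `top_k_indices` is a hypothetical helper name.

```python
import numpy as np


def top_k_indices(prediction, k=5):
    """Return the indices of the k highest scores for the first image in a batch.

    `prediction` has shape (batch_size, num_classes), like model.predict's output.
    """
    topk = np.argsort(-prediction, axis=1)[:, :k]
    return topk[0]


# Made-up scores for one image and six classes
scores = np.array([[0.05, 0.40, 0.10, 0.30, 0.01, 0.14]])
print(top_k_indices(scores, k=3).tolist())  # [1, 3, 5]
```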
```python
imagePath = "./whateverDoodle.jpg"
image = cv2.imread(imagePath)
image = prepareImage(image)

with sess.as_default():
    with graph.as_default():
        prediction = prepareImageAndPredict(model, image).tolist()

print(prediction)
```
I deleted my earlier test code; this snippet was typed up on the spot and never run, so it may contain errors.
If it succeeds, it prints 5 numbers between 0 and 339, each representing one object category.
The numbers are indices into the following list:
```python
categories = ['airplane', 'alarm clock', 'ambulance', 'angel', 'animal migration', 'ant', 'anvil', 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball', 'baseball bat', 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench', 'bicycle', 'binoculars', 'bird', 'birthday cake', 'blackberry', 'blueberry', 'book', 'boomerang', 'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling fan', 'cell phone', 'cello', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud', 'coffee cup', 'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile', 'crown', 'cruise ship', 'cup', 'diamond', 'dishwasher', 'diving board', 'dog', 'dolphin', 'donut', 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant', 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire hydrant', 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip flops', 'floor lamp', 'flower', 'flying saucer', 'foot', 'fork', 'frog', 'frying pan', 'garden', 'garden hose', 'giraffe', 'goatee', 'golf club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones', 'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey puck', 'hockey stick', 'horse', 'hospital', 'hot air balloon', 'hot dog', 'hot tub', 'hourglass', 'house', 'house plant', 'hurricane', 'ice cream', 'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'ladder', 'lantern', 'laptop', 'leaf', 'leg', 'light bulb', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster', 'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 'microphone', 'microwave', 'monkey',
              'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom', 'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paint can', 'paintbrush', 'palm tree', 'panda', 'pants', 'paper clip', 'parachute', 'parrot', 'passport', 'peanut', 'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup truck', 'picture frame', 'pig', 'pillow', 'pineapple', 'pizza', 'pliers', 'police car', 'pond', 'pool', 'popsicle', 'postcard', 'potato', 'power outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote control', 'rhinoceros', 'river', 'roller coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw', 'saxophone', 'school bus', 'scissors', 'scorpion', 'screwdriver', 'sea turtle', 'see saw', 'shark', 'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping bag', 'smiley face', 'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer ball', 'sock', 'speedboat', 'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo', 'stethoscope', 'stitches', 'stop sign', 'stove', 'strawberry', 'streetlight', 'string bean', 'submarine', 'suitcase', 'sun', 'swan', 'sweater', 'swing set', 'sword', 't-shirt', 'table', 'teapot', 'teddy-bear', 'telephone', 'television', 'tennis racquet', 'tent', 'The Eiffel Tower', 'The Great Wall of China', 'The Mona Lisa', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado', 'tractor', 'traffic light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing machine', 'watermelon', 'waterslide', 'whale', 'wheel', 'windmill', 'wine bottle', 'wine glass', 'wristwatch', 'yoga', 'zebra', 'zigzag']
```
Once that works, we can create the Flask server, `server.py`, and, following Python tradition, import a pile of libraries, even though some of them are no longer needed:

```python
from flask import Flask, jsonify, request, render_template, send_from_directory
from flask_cors import CORS
from predictor import *  # runs init() (loads the model) and brings in prepareImage etc.
import random
import json
import pandas as pd
import numpy as np
from tensorflow.keras import models
import time
import datetime
import cv2
import sys, getopt
import os
import base64
import io
from PIL import Image

app = Flask(__name__)
CORS(app)
```
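CORS lets pages on other domains make Ajax requests to your API. This was another shortcut on my part: my client demo isn't on my own server, it's hosted directly on p5js. If your client and your API are on the same domain, you don't need it.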
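`flask_cors` handles this for you. To give an idea of what it does in the simplest case, here's a hypothetical, standard-library-only WSGI middleware (`allow_all_origins` is my own name, not part of any library) that adds an `Access-Control-Allow-Origin: *` header to every response. It is deliberately minimal: it does not answer the preflight `OPTIONS` requests that browsers send before some cross-origin requests. Assuming the client posts the doodle as ordinary `application/x-www-form-urlencoded` form data, that counts as a "simple request" under the CORS rules and needs no preflight.

```python
def allow_all_origins(app):
    """Wrap a WSGI app so every response allows requests from any origin."""
    def wrapped(environ, start_response):
        def start_response_with_cors(status, headers, exc_info=None):
            # Append the CORS header to whatever headers the app already set
            headers = list(headers) + [("Access-Control-Allow-Origin", "*")]
            return start_response(status, headers, exc_info)
        return app(environ, start_response_with_cors)
    return wrapped


# Possible use with Flask, instead of CORS(app):
# app.wsgi_app = allow_all_origins(app.wsgi_app)
```

In practice `flask_cors` is the better choice, since it also covers preflight requests and per-route configuration.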
```python
def stringToRGB(base64_string):
    # Decode the base64 string into image bytes and open them with PIL.
    # PIL returns RGB; swapping the R and B channels gives the BGR order
    # that OpenCV (and prepareImage) expect.
    imgdata = base64.b64decode(str(base64_string))
    image = Image.open(io.BytesIO(imgdata))
    return cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)


@app.route("/api/doodlePredict", methods=["POST"])
def predictAPI():
    global model, graph
    print("this is the request: ", request.form.to_dict())
    # The doodle arrives as a base64 string in the form field "data"
    image_raw = request.form.to_dict()["data"]
    image_raw = stringToRGB(image_raw)
    image = prepareImage(image_raw)
    response = {'prediction': {
        'numbers': [],
        'names': []
    }}
    # Run the model inside the session/graph created by init() in predictor.py
    with sess.as_default():
        with graph.as_default():
            response['prediction']['numbers'] = prepareImageAndPredict(model, image).tolist()
    # Map the class indices to category names
    for number in response['prediction']['numbers']:
        response['prediction']['names'].append(categories[number])
    print("this is the response: ", response['prediction']['names'])
    # Save the incoming doodle; cv2.imwrite just returns False if the folder is missing.
    # Note: ':' in the timestamp is not allowed in Windows file names.
    os.makedirs("./doodleHistory", exist_ok=True)
    cv2.imwrite("./doodleHistory/" + ', '.join(response['prediction']['names']) + ", "
                + datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y") + ".jpg", image_raw)
    return jsonify(response)
```
This route assumes the image has been converted to base64 and sent via POST in the `data` field of a form.
So it first reads the form, converts it into a dict, reads `data` from the dict, and then decodes that into an image array with `stringToRGB`.
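If the client builds the image string with the browser's `canvas.toDataURL()`, the string starts with a prefix like `data:image/png;base64,`. `base64.b64decode` won't strip that: by default it silently skips characters outside the base64 alphabet and decodes the rest, so you get either an error or corrupted bytes. I don't know whether the p5 client strips the prefix, so below is a hypothetical, defensive version of the decoding step (`decode_image_field` is my own name). It returns raw image bytes that could then go to `Image.open(io.BytesIO(...))` as in `stringToRGB`. It also undoes a common form-encoding pitfall: an unescaped `+` in the body is read as a space.

```python
import base64
import binascii


def decode_image_field(value):
    """Turn the 'data' form field into raw image bytes.

    Accepts plain base64 or a full data URL such as 'data:image/png;base64,iVBOR...'.
    """
    # Strip a data-URL prefix if there is one
    if value.startswith("data:") and "," in value:
        value = value.split(",", 1)[1]
    # In form encoding an unescaped '+' arrives as a space; put it back
    value = value.replace(" ", "+")
    try:
        # validate=True rejects stray characters instead of silently skipping them
        return base64.b64decode(value, validate=True)
    except binascii.Error as e:
        raise ValueError("data field is not valid base64") from e
```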
Then it calls our `prepareImage` to make sure the image has white strokes on a black background, with no extra empty margin.
Then, just like in the test above, it enters `sess` and `graph` and calls `model.predict()` (through `prepareImageAndPredict`).
Along the way, I also save each incoming doodle to disk.
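The route names each saved file after the predicted labels plus a timestamp like `03:25PM on March 01, 2019`. That works on Linux and macOS, but `:` is not allowed in Windows file names, and two identical predictions in the same minute overwrite each other. Here's a hypothetical alternative (the function name and timestamp format are my own choice) that uses a filesystem-friendly timestamp with seconds, which makes collisions less likely without ruling them out.

```python
import datetime
import os


def history_path(names, now=None, folder="./doodleHistory"):
    """Build a file path for a saved doodle from its predicted names and a timestamp."""
    if now is None:
        now = datetime.datetime.now()
    # e.g. 2019-03-01_15-25-07: no ':' characters, and it sorts chronologically
    stamp = now.strftime("%Y-%m-%d_%H-%M-%S")
    return os.path.join(folder, ", ".join(names) + ", " + stamp + ".jpg")
```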
Finally, it packs the results into JSON and returns them.
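The JSON the route returns has the shape `{"prediction": {"numbers": [...], "names": [...]}}`. Below is a stand-alone sketch of that step, with the index-to-name lookup written as a list comprehension. `build_response` is a hypothetical helper, and the three-item list in the example is a stand-in for the full 340-item `categories`.

```python
import json


def build_response(numbers, categories):
    """Pair predicted class indices with their names, in the route's JSON shape."""
    return {
        "prediction": {
            "numbers": list(numbers),
            "names": [categories[n] for n in numbers],
        }
    }


print(json.dumps(build_response([2, 0], ["airplane", "alarm clock", "ambulance"])))
# {"prediction": {"numbers": [2, 0], "names": ["ambulance", "airplane"]}}
```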
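Finally, start the server with `app.run()`:

```python
if __name__ == "__main__":
    # debug=True is meant for development, not for a public-facing server
    app.run(host="0.0.0.0", port=5800, debug=True)
```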
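Note that you must set `host = "0.0.0.0"` for the API to be reachable from other machines; by default Flask only listens on `127.0.0.1` (localhost).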
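Once the server is reachable, any HTTP client can call the API, not just the p5 demo. Here's a minimal client sketch using only Python's standard library, assuming the request format the route expects (base64 image in a form field named `data`). The host address and `doodle.png` are placeholders, and `build_predict_request` is a hypothetical helper. `urlencode` percent-escapes the `+`, `/` and `=` characters that base64 can contain, so they survive form decoding on the server.

```python
import base64
import urllib.parse
import urllib.request


def build_predict_request(host, image_bytes, port=5800):
    """Build a POST request for /api/doodlePredict with the image in the 'data' field."""
    body = urllib.parse.urlencode(
        {"data": base64.b64encode(image_bytes).decode("ascii")}
    ).encode("ascii")
    return urllib.request.Request(
        "http://{}:{}/api/doodlePredict".format(host, port),
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )


# Usage (placeholder address; the server must be running):
# with open("doodle.png", "rb") as f:
#     req = build_predict_request("203.0.113.5", f.read())
# with urllib.request.urlopen(req) as resp:
#     print(resp.read().decode("utf-8"))
```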
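I wrote the client directly in the p5js online editor because that was quicker. I won't go through it piece by piece here; I'll just describe the approach and share the code.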
That's all for today; I'll add images later.
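The game is still in development, and the core modules are almost done. If I find the time, I'll write another post about how I built the game and how it uses ML.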