当前位置:   article > 正文

Python实现图片美化,醉后不知天在水?(附上代码) | 机器学习_机器学习图片美化










本文基于我的另一篇文章《如何将照片美化，DPED机器学习开源项目安装使用 | 机器学习》（阿良的博客，CSDN）。







模型文件下载地址：https://pan.baidu.com/s/1IUm8xz5dhh8iW_bLWfihPQ ，提取码：TUAN。下载后请把模型文件放到运行目录下的 models_orig/ 目录中，代码会通过相对路径 models_orig/iphone_orig 加载模型。


环境依赖可直接参考我的另一篇文章《Python实现替换照片人物背景，精细到头发丝（附上代码）| 机器学习》（阿良的博客，CSDN）。






  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # @Time : 2021/11/27 13:48
  4. # @Author : 剑客阿良_ALiang
  5. # @Site :
  6. # @File : dped.py
  7. # python test_model.py model=iphone_orig dped_dir=dped/ test_subset=full
  8. # iteration=all resolution=orig use_gpu=true
  9. import imageio
  10. from PIL import Image
  11. import numpy as np
  12. import tensorflow as tf
  13. import os
  14. import sys
  15. import scipy.stats as st
  16. import uuid
  17. from functools import reduce
  18. # ---------------------- hy add 2 ----------------------
  19. def log10(x):
  20. numerator = tf.compat.v1.log(x)
  21. denominator = tf.compat.v1.log(tf.constant(10, dtype=numerator.dtype))
  22. return numerator / denominator
  23. def _tensor_size(tensor):
  24. from operator import mul
  25. return reduce(mul, (d.value for d in tensor.get_shape()[1:]), 1)
  26. def gauss_kernel(kernlen=21, nsig=3, channels=1):
  27. interval = (2 * nsig + 1.) / (kernlen)
  28. x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
  29. kern1d = np.diff(st.norm.cdf(x))
  30. kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
  31. kernel = kernel_raw / kernel_raw.sum()
  32. out_filter = np.array(kernel, dtype=np.float32)
  33. out_filter = out_filter.reshape((kernlen, kernlen, 1, 1))
  34. out_filter = np.repeat(out_filter, channels, axis=2)
  35. return out_filter
  36. def blur(x):
  37. kernel_var = gauss_kernel(21, 3, 3)
  38. return tf.nn.depthwise_conv2d(x, kernel_var, [1, 1, 1, 1], padding='SAME')
  39. def process_command_args(arguments):
  40. # specifying default parameters
  41. batch_size = 50
  42. train_size = 30000
  43. learning_rate = 5e-4
  44. num_train_iters = 20000
  45. w_content = 10
  46. w_color = 0.5
  47. w_texture = 1
  48. w_tv = 2000
  49. dped_dir = 'dped/'
  50. vgg_dir = 'vgg_pretrained/imagenet-vgg-verydeep-19.mat'
  51. eval_step = 1000
  52. phone = ""
  53. for args in arguments:
  54. if args.startswith("model"):
  55. phone = args.split("=")[1]
  56. if args.startswith("batch_size"):
  57. batch_size = int(args.split("=")[1])
  58. if args.startswith("train_size"):
  59. train_size = int(args.split("=")[1])
  60. if args.startswith("learning_rate"):
  61. learning_rate = float(args.split("=")[1])
  62. if args.startswith("num_train_iters"):
  63. num_train_iters = int(args.split("=")[1])
  64. # -----------------------------------
  65. if args.startswith("w_content"):
  66. w_content = float(args.split("=")[1])
  67. if args.startswith("w_color"):
  68. w_color = float(args.split("=")[1])
  69. if args.startswith("w_texture"):
  70. w_texture = float(args.split("=")[1])
  71. if args.startswith("w_tv"):
  72. w_tv = float(args.split("=")[1])
  73. # -----------------------------------
  74. if args.startswith("dped_dir"):
  75. dped_dir = args.split("=")[1]
  76. if args.startswith("vgg_dir"):
  77. vgg_dir = args.split("=")[1]
  78. if args.startswith("eval_step"):
  79. eval_step = int(args.split("=")[1])
  80. if phone == "":
  81. print("\nPlease specify the camera model by running the script with the following parameter:\n")
  82. print("python train_model.py model={iphone,blackberry,sony}\n")
  83. sys.exit()
  84. if phone not in ["iphone", "sony", "blackberry"]:
  85. print("\nPlease specify the correct camera model:\n")
  86. print("python train_model.py model={iphone,blackberry,sony}\n")
  87. sys.exit()
  88. print("\nThe following parameters will be applied for CNN training:\n")
  89. print("Phone model:", phone)
  90. print("Batch size:", batch_size)
  91. print("Learning rate:", learning_rate)
  92. print("Training iterations:", str(num_train_iters))
  93. print()
  94. print("Content loss:", w_content)
  95. print("Color loss:", w_color)
  96. print("Texture loss:", w_texture)
  97. print("Total variation loss:", str(w_tv))
  98. print()
  99. print("Path to DPED dataset:", dped_dir)
  100. print("Path to VGG-19 network:", vgg_dir)
  101. print("Evaluation step:", str(eval_step))
  102. print()
  103. return phone, batch_size, train_size, learning_rate, num_train_iters, \
  104. w_content, w_color, w_texture, w_tv, \
  105. dped_dir, vgg_dir, eval_step
  106. def process_test_model_args(arguments):
  107. phone = ""
  108. dped_dir = 'dped/'
  109. test_subset = "small"
  110. iteration = "all"
  111. resolution = "orig"
  112. use_gpu = "true"
  113. for args in arguments:
  114. if args.startswith("model"):
  115. phone = args.split("=")[1]
  116. if args.startswith("dped_dir"):
  117. dped_dir = args.split("=")[1]
  118. if args.startswith("test_subset"):
  119. test_subset = args.split("=")[1]
  120. if args.startswith("iteration"):
  121. iteration = args.split("=")[1]
  122. if args.startswith("resolution"):
  123. resolution = args.split("=")[1]
  124. if args.startswith("use_gpu"):
  125. use_gpu = args.split("=")[1]
  126. if phone == "":
  127. print("\nPlease specify the model by running the script with the following parameter:\n")
  128. print(
  129. "python test_model.py model={iphone,blackberry,sony,iphone_orig,blackberry_orig,sony_orig}\n")
  130. sys.exit()
  131. return phone, dped_dir, test_subset, iteration, resolution, use_gpu
  132. def get_resolutions():
  134. res_sizes = {}
  135. res_sizes["iphone"] = [1536, 2048]
  136. res_sizes["iphone_orig"] = [1536, 2048]
  137. res_sizes["blackberry"] = [1560, 2080]
  138. res_sizes["blackberry_orig"] = [1560, 2080]
  139. res_sizes["sony"] = [1944, 2592]
  140. res_sizes["sony_orig"] = [1944, 2592]
  141. res_sizes["high"] = [1260, 1680]
  142. res_sizes["medium"] = [1024, 1366]
  143. res_sizes["small"] = [768, 1024]
  144. res_sizes["tiny"] = [600, 800]
  145. return res_sizes
  146. def get_specified_res(res_sizes, phone, resolution):
  147. if resolution == "orig":
  148. IMAGE_HEIGHT = res_sizes[phone][0]
  149. IMAGE_WIDTH = res_sizes[phone][1]
  150. else:
  151. IMAGE_HEIGHT = res_sizes[resolution][0]
  152. IMAGE_WIDTH = res_sizes[resolution][1]
  155. def extract_crop(image, resolution, phone, res_sizes):
  156. if resolution == "orig":
  157. return image
  158. else:
  159. x_up = int((res_sizes[phone][1] - res_sizes[resolution][1]) / 2)
  160. y_up = int((res_sizes[phone][0] - res_sizes[resolution][0]) / 2)
  161. x_down = x_up + res_sizes[resolution][1]
  162. y_down = y_up + res_sizes[resolution][0]
  163. return image[y_up: y_down, x_up: x_down, :]
  164. # ---------------------- hy add 1 ----------------------
  165. def resnet(input_image):
  166. with tf.compat.v1.variable_scope("generator"):
  167. W1 = weight_variable([9, 9, 3, 64], name="W1")
  168. b1 = bias_variable([64], name="b1")
  169. c1 = tf.nn.relu(conv2d(input_image, W1) + b1)
  170. # residual 1
  171. W2 = weight_variable([3, 3, 64, 64], name="W2")
  172. b2 = bias_variable([64], name="b2")
  173. c2 = tf.nn.relu(_instance_norm(conv2d(c1, W2) + b2))
  174. W3 = weight_variable([3, 3, 64, 64], name="W3")
  175. b3 = bias_variable([64], name="b3")
  176. c3 = tf.nn.relu(_instance_norm(conv2d(c2, W3) + b3)) + c1
  177. # residual 2
  178. W4 = weight_variable([3, 3, 64, 64], name="W4")
  179. b4 = bias_variable([64], name="b4")
  180. c4 = tf.nn.relu(_instance_norm(conv2d(c3, W4) + b4))
  181. W5 = weight_variable([3, 3, 64, 64], name="W5")
  182. b5 = bias_variable([64], name="b5")
  183. c5 = tf.nn.relu(_instance_norm(conv2d(c4, W5) + b5)) + c3
  184. # residual 3
  185. W6 = weight_variable([3, 3, 64, 64], name="W6")
  186. b6 = bias_variable([64], name="b6")
  187. c6 = tf.nn.relu(_instance_norm(conv2d(c5, W6) + b6))
  188. W7 = weight_variable([3, 3, 64, 64], name="W7")
  189. b7 = bias_variable([64], name="b7")
  190. c7 = tf.nn.relu(_instance_norm(conv2d(c6, W7) + b7)) + c5
  191. # residual 4
  192. W8 = weight_variable([3, 3, 64, 64], name="W8")
  193. b8 = bias_variable([64], name="b8")
  194. c8 = tf.nn.relu(_instance_norm(conv2d(c7, W8) + b8))
  195. W9 = weight_variable([3, 3, 64, 64], name="W9")
  196. b9 = bias_variable([64], name="b9")
  197. c9 = tf.nn.relu(_instance_norm(conv2d(c8, W9) + b9)) + c7
  198. # Convolutional
  199. W10 = weight_variable([3, 3, 64, 64], name="W10")
  200. b10 = bias_variable([64], name="b10")
  201. c10 = tf.nn.relu(conv2d(c9, W10) + b10)
  202. W11 = weight_variable([3, 3, 64, 64], name="W11")
  203. b11 = bias_variable([64], name="b11")
  204. c11 = tf.nn.relu(conv2d(c10, W11) + b11)
  205. # Final
  206. W12 = weight_variable([9, 9, 64, 3], name="W12")
  207. b12 = bias_variable([3], name="b12")
  208. enhanced = tf.nn.tanh(conv2d(c11, W12) + b12) * 0.58 + 0.5
  209. return enhanced
  210. def adversarial(image_):
  211. with tf.compat.v1.variable_scope("discriminator"):
  212. conv1 = _conv_layer(image_, 48, 11, 4, batch_nn=False)
  213. conv2 = _conv_layer(conv1, 128, 5, 2)
  214. conv3 = _conv_layer(conv2, 192, 3, 1)
  215. conv4 = _conv_layer(conv3, 192, 3, 1)
  216. conv5 = _conv_layer(conv4, 128, 3, 2)
  217. flat_size = 128 * 7 * 7
  218. conv5_flat = tf.reshape(conv5, [-1, flat_size])
  219. W_fc = tf.Variable(tf.compat.v1.truncated_normal(
  220. [flat_size, 1024], stddev=0.01))
  221. bias_fc = tf.Variable(tf.constant(0.01, shape=[1024]))
  222. fc = leaky_relu(tf.matmul(conv5_flat, W_fc) + bias_fc)
  223. W_out = tf.Variable(
  224. tf.compat.v1.truncated_normal([1024, 2], stddev=0.01))
  225. bias_out = tf.Variable(tf.constant(0.01, shape=[2]))
  226. adv_out = tf.nn.softmax(tf.matmul(fc, W_out) + bias_out)
  227. return adv_out
  228. def weight_variable(shape, name):
  229. initial = tf.compat.v1.truncated_normal(shape, stddev=0.01)
  230. return tf.Variable(initial, name=name)
  231. def bias_variable(shape, name):
  232. initial = tf.constant(0.01, shape=shape)
  233. return tf.Variable(initial, name=name)
  234. def conv2d(x, W):
  235. return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
  236. def leaky_relu(x, alpha=0.2):
  237. return tf.maximum(alpha * x, x)
  238. def _conv_layer(net, num_filters, filter_size, strides, batch_nn=True):
  239. weights_init = _conv_init_vars(net, num_filters, filter_size)
  240. strides_shape = [1, strides, strides, 1]
  241. bias = tf.Variable(tf.constant(0.01, shape=[num_filters]))
  242. net = tf.nn.conv2d(net, weights_init, strides_shape, padding='SAME') + bias
  243. net = leaky_relu(net)
  244. if batch_nn:
  245. net = _instance_norm(net)
  246. return net
  247. def _instance_norm(net):
  248. batch, rows, cols, channels = [i.value for i in net.get_shape()]
  249. var_shape = [channels]
  250. mu, sigma_sq = tf.compat.v1.nn.moments(net, [1, 2], keepdims=True)
  251. shift = tf.Variable(tf.zeros(var_shape))
  252. scale = tf.Variable(tf.ones(var_shape))
  253. epsilon = 1e-3
  254. normalized = (net - mu) / (sigma_sq + epsilon) ** (.5)
  255. return scale * normalized + shift
  256. def _conv_init_vars(net, out_channels, filter_size, transpose=False):
  257. _, rows, cols, in_channels = [i.value for i in net.get_shape()]
  258. if not transpose:
  259. weights_shape = [filter_size, filter_size, in_channels, out_channels]
  260. else:
  261. weights_shape = [filter_size, filter_size, out_channels, in_channels]
  262. weights_init = tf.Variable(
  263. tf.compat.v1.truncated_normal(
  264. weights_shape,
  265. stddev=0.01,
  266. seed=1),
  267. dtype=tf.float32)
  268. return weights_init
  269. # ---------------------- hy add 0 ----------------------
  270. def beautify(pic_path: str, output_dir: str, gpu='1'):
  271. tf.compat.v1.disable_v2_behavior()
  272. # process command arguments
  273. phone = "iphone_orig"
  274. test_subset = "full"
  275. iteration = "all"
  276. resolution = "orig"
  277. # get all available image resolutions
  278. res_sizes = get_resolutions()
  279. # get the specified image resolution
  280. IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_SIZE = get_specified_res(
  281. res_sizes, phone, resolution)
  282. if gpu == '1':
  283. use_gpu = 'true'
  284. else:
  285. use_gpu = 'false'
  286. # disable gpu if specified
  287. config = tf.compat.v1.ConfigProto(
  288. device_count={'GPU': 0}) if use_gpu == "false" else None
  289. # create placeholders for input images
  290. x_ = tf.compat.v1.placeholder(tf.float32, [None, IMAGE_SIZE])
  291. x_image = tf.reshape(x_, [-1, IMAGE_HEIGHT, IMAGE_WIDTH, 3])
  292. # generate enhanced image
  293. enhanced = resnet(x_image)
  294. with tf.compat.v1.Session(config=config) as sess:
  295. # test_dir = dped_dir + phone.replace("_orig",
  296. # "") + "/test_data/full_size_test_images/"
  297. # test_photos = [f for f in os.listdir(
  298. # test_dir) if os.path.isfile(test_dir + f)]
  299. test_photos = [pic_path]
  300. if test_subset == "small":
  301. # use five first images only
  302. test_photos = test_photos[0:5]
  303. if phone.endswith("_orig"):
  304. # load pre-trained model
  305. saver = tf.compat.v1.train.Saver()
  306. saver.restore(sess, "models_orig/" + phone)
  307. for photo in test_photos:
  308. # load training image and crop it if necessary
  309. new_pic_name = uuid.uuid4()
  310. print(
  311. "Testing original " +
  312. phone.replace(
  313. "_orig",
  314. "") +
  315. " model, processing image " +
  316. photo)
  317. image = np.float16(np.array(
  318. Image.fromarray(imageio.imread(photo)).resize([res_sizes[phone][1], res_sizes[phone][0]]))) / 255
  319. image_crop = extract_crop(
  320. image, resolution, phone, res_sizes)
  321. image_crop_2d = np.reshape(image_crop, [1, IMAGE_SIZE])
  322. # get enhanced image
  323. enhanced_2d = sess.run(enhanced, feed_dict={x_: image_crop_2d})
  324. enhanced_image = np.reshape(
  325. enhanced_2d, [IMAGE_HEIGHT, IMAGE_WIDTH, 3])
  326. before_after = np.hstack((image_crop, enhanced_image))
  327. photo_name = photo.rsplit(".", 1)[0]
  328. # save the results as .png images
  329. # imageio.imwrite(
  330. # "visual_results/" +
  331. # phone +
  332. # "_" +
  333. # photo_name +
  334. # "_enhanced.png",
  335. # enhanced_image)
  336. imageio.imwrite(os.path.join(output_dir, '{}.png'.format(new_pic_name)), enhanced_image)
  337. # imageio.imwrite(
  338. # "visual_results/" +
  339. # phone +
  340. # "_" +
  341. # photo_name +
  342. # "_before_after.png",
  343. # before_after)
  344. imageio.imwrite(os.path.join(output_dir, '{}_before_after.png'.format(new_pic_name)), before_after)
  345. return os.path.join(output_dir, '{}.png'.format(new_pic_name))
  346. else:
  347. num_saved_models = int(len([f for f in os.listdir(
  348. "models_orig/") if f.startswith(phone + "_iteration")]) / 2)
  349. if iteration == "all":
  350. iteration = np.arange(1, num_saved_models) * 1000
  351. else:
  352. iteration = [int(iteration)]
  353. for i in iteration:
  354. # load pre-trained model
  355. saver = tf.compat.v1.train.Saver()
  356. saver.restore(
  357. sess,
  358. "models_orig/" +
  359. phone +
  360. "_iteration_" +
  361. str(i) +
  362. ".ckpt")
  363. for photo in test_photos:
  364. # load training image and crop it if necessary
  365. new_pic_name = uuid.uuid4()
  366. print("iteration " + str(i) + ", processing image " + photo)
  367. image = np.float16(np.array(
  368. Image.fromarray(imageio.imread(photo)).resize(
  369. [res_sizes[phone][1], res_sizes[phone][0]]))) / 255
  370. image_crop = extract_crop(
  371. image, resolution, phone, res_sizes)
  372. image_crop_2d = np.reshape(image_crop, [1, IMAGE_SIZE])
  373. # get enhanced image
  374. enhanced_2d = sess.run(enhanced, feed_dict={x_: image_crop_2d})
  375. enhanced_image = np.reshape(
  376. enhanced_2d, [IMAGE_HEIGHT, IMAGE_WIDTH, 3])
  377. before_after = np.hstack((image_crop, enhanced_image))
  378. photo_name = photo.rsplit(".", 1)[0]
  379. # save the results as .png images
  380. # imageio.imwrite(
  381. # "visual_results/" +
  382. # phone +
  383. # "_" +
  384. # photo_name +
  385. # "_iteration_" +
  386. # str(i) +
  387. # "_enhanced.png",
  388. # enhanced_image)
  389. imageio.imwrite(os.path.join(output_dir, '{}.png'.format(new_pic_name)), enhanced_image)
  390. # imageio.imwrite(
  391. # "visual_results/" +
  392. # phone +
  393. # "_" +
  394. # photo_name +
  395. # "_iteration_" +
  396. # str(i) +
  397. # "_before_after.png",
  398. # before_after)
  399. imageio.imwrite(os.path.join(output_dir, '{}_before_after.png'.format(new_pic_name)), before_after)
  400. return os.path.join(output_dir, '{}.png'.format(new_pic_name))
  401. if __name__ == '__main__':
  402. print(beautify('C:/Users/yi/Desktop/6.jpg', 'result/'))













