当前位置:   article > 正文

python-opencv五种自动白平衡算法,附源码直接可用(均值、完美反射、灰度世界、动态阈值、基于图像分析的偏色检测及颜色校正)

python-opencv五种自动白平衡算法,附源码直接可用(均值、完美反射、灰度世界、动态阈值、基于图像分析的偏色检测及颜色校正)

最近研究了自动白平衡的几种方法,参考了不少,最为感谢python opencv白平衡算法(但是这篇文章提供的算法没有考虑到uint8格式问题,产生了图像的局部失真,这里做了改进):(<-原图,失真图->)

谈谈总体理解:

(本来目标是同一张图,无论在什么样子的滤镜、光照下最后白平衡结果要尽可能相同,最后发现都太难了)

1.均值、灰度世界都建立一种计算平均的算法基础上,适用于色彩分布比较全面平均的场景,其实在很多场合都不适用

2.完美反射、动态阈值建立在白点的基础上,比如完美反射认为最亮的点为白点,以白点为基础进行整体的调节,导致的问题在于如果整张图没有白点算法效果非常不好,其次,由于不同色温下白点所呈现的数值差异性很大,导致白平衡结果不尽如人意。且Ratio的选取也有效果差异。还有一种做法是固定某一区域为白色区域然后算法计算,延展全图,效果展示使用uint8格式时一定要注意的问题(python-opencv完美反射白平衡算法)

3.基于图像分析的偏色检测及颜色校正,看了这篇原文,感觉整体意思是提供一种偏色检测的做法,然后还是采用基于完美反射、灰度世界的改进算法进行白平衡,效果同样局限。


结果展示,在不同的场景下每种白平衡结果都有不同,没有通用性的最好算法:

  1. 第一张: 原图
  2. 第二张:均值白平衡法
  3. 第三张: 完美反射
  4. 第四张: 灰度世界假设
  5. 第五张: 基于图像分析的偏色检测及颜色校正方法
  6. 第六张: 动态阈值算法


源码:

  1. import cv2
  2. import numpy as np
  3. import random
  4. def white_balance_1(img):
  5. '''
  6. 第一种简单的求均值白平衡法
  7. :param img: cv2.imread读取的图片数据
  8. :return: 返回的白平衡结果图片数据
  9. '''
  10. # 读取图像
  11. r, g, b = cv2.split(img)
  12. r_avg = cv2.mean(r)[0]
  13. g_avg = cv2.mean(g)[0]
  14. b_avg = cv2.mean(b)[0]
  15. # 求各个通道所占增益
  16. k = (r_avg + g_avg + b_avg) / 3
  17. kr = k / r_avg
  18. kg = k / g_avg
  19. kb = k / b_avg
  20. r = cv2.addWeighted(src1=r, alpha=kr, src2=0, beta=0, gamma=0)
  21. g = cv2.addWeighted(src1=g, alpha=kg, src2=0, beta=0, gamma=0)
  22. b = cv2.addWeighted(src1=b, alpha=kb, src2=0, beta=0, gamma=0)
  23. balance_img = cv2.merge([b, g, r])
  24. return balance_img
  25. def white_balance_2(img_input):
  26. '''
  27. 完美反射白平衡
  28. STEP 1:计算每个像素的R\G\B之和
  29. STEP 2:按R+G+B值的大小计算出其前Ratio%的值作为参考点的的阈值T
  30. STEP 3:对图像中的每个点,计算其中R+G+B值大于T的所有点的R\G\B分量的累积和的平均值
  31. STEP 4:对每个点将像素量化到[0,255]之间
  32. 依赖ratio值选取而且对亮度最大区域不是白色的图像效果不佳。
  33. :param img: cv2.imread读取的图片数据
  34. :return: 返回的白平衡结果图片数据
  35. '''
  36. img = img_input.copy()
  37. b, g, r = cv2.split(img)
  38. m, n, t = img.shape
  39. sum_ = np.zeros(b.shape)
  40. for i in range(m):
  41. for j in range(n):
  42. sum_[i][j] = int(b[i][j]) + int(g[i][j]) + int(r[i][j])
  43. hists, bins = np.histogram(sum_.flatten(), 766, [0, 766])
  44. Y = 765
  45. num, key = 0, 0
  46. ratio = 0.01
  47. while Y >= 0:
  48. num += hists[Y]
  49. if num > m * n * ratio / 100:
  50. key = Y
  51. break
  52. Y = Y - 1
  53. sum_b, sum_g, sum_r = 0, 0, 0
  54. time = 0
  55. for i in range(m):
  56. for j in range(n):
  57. if sum_[i][j] >= key:
  58. sum_b += b[i][j]
  59. sum_g += g[i][j]
  60. sum_r += r[i][j]
  61. time = time + 1
  62. avg_b = sum_b / time
  63. avg_g = sum_g / time
  64. avg_r = sum_r / time
  65. maxvalue = float(np.max(img))
  66. # maxvalue = 255
  67. for i in range(m):
  68. for j in range(n):
  69. b = int(img[i][j][0]) * maxvalue / int(avg_b)
  70. g = int(img[i][j][1]) * maxvalue / int(avg_g)
  71. r = int(img[i][j][2]) * maxvalue / int(avg_r)
  72. if b > 255:
  73. b = 255
  74. if b < 0:
  75. b = 0
  76. if g > 255:
  77. g = 255
  78. if g < 0:
  79. g = 0
  80. if r > 255:
  81. r = 255
  82. if r < 0:
  83. r = 0
  84. img[i][j][0] = b
  85. img[i][j][1] = g
  86. img[i][j][2] = r
  87. return img
  88. def white_balance_3(img):
  89. '''
  90. 灰度世界假设
  91. :param img: cv2.imread读取的图片数据
  92. :return: 返回的白平衡结果图片数据
  93. '''
  94. B, G, R = np.double(img[:, :, 0]), np.double(img[:, :, 1]), np.double(img[:, :, 2])
  95. B_ave, G_ave, R_ave = np.mean(B), np.mean(G), np.mean(R)
  96. K = (B_ave + G_ave + R_ave) / 3
  97. Kb, Kg, Kr = K / B_ave, K / G_ave, K / R_ave
  98. Ba = (B * Kb)
  99. Ga = (G * Kg)
  100. Ra = (R * Kr)
  101. for i in range(len(Ba)):
  102. for j in range(len(Ba[0])):
  103. Ba[i][j] = 255 if Ba[i][j] > 255 else Ba[i][j]
  104. Ga[i][j] = 255 if Ga[i][j] > 255 else Ga[i][j]
  105. Ra[i][j] = 255 if Ra[i][j] > 255 else Ra[i][j]
  106. # print(np.mean(Ba), np.mean(Ga), np.mean(Ra))
  107. dst_img = np.uint8(np.zeros_like(img))
  108. dst_img[:, :, 0] = Ba
  109. dst_img[:, :, 1] = Ga
  110. dst_img[:, :, 2] = Ra
  111. return dst_img
  112. def white_balance_4(img):
  113. '''
  114. 基于图像分析的偏色检测及颜色校正方法
  115. :param img: cv2.imread读取的图片数据
  116. :return: 返回的白平衡结果图片数据
  117. '''
  118. def detection(img):
  119. '''计算偏色值'''
  120. img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
  121. l, a, b = cv2.split(img_lab)
  122. d_a, d_b, M_a, M_b = 0, 0, 0, 0
  123. for i in range(m):
  124. for j in range(n):
  125. d_a = d_a + a[i][j]
  126. d_b = d_b + b[i][j]
  127. d_a, d_b = (d_a / (m * n)) - 128, (d_b / (n * m)) - 128
  128. D = np.sqrt((np.square(d_a) + np.square(d_b)))
  129. for i in range(m):
  130. for j in range(n):
  131. M_a = np.abs(a[i][j] - d_a - 128) + M_a
  132. M_b = np.abs(b[i][j] - d_b - 128) + M_b
  133. M_a, M_b = M_a / (m * n), M_b / (m * n)
  134. M = np.sqrt((np.square(M_a) + np.square(M_b)))
  135. k = D / M
  136. print('偏色值:%f' % k)
  137. return
  138. b, g, r = cv2.split(img)
  139. # print(img.shape)
  140. m, n = b.shape
  141. # detection(img)
  142. I_r_2 = np.zeros(r.shape)
  143. I_b_2 = np.zeros(b.shape)
  144. sum_I_r_2, sum_I_r, sum_I_b_2, sum_I_b, sum_I_g = 0, 0, 0, 0, 0
  145. max_I_r_2, max_I_r, max_I_b_2, max_I_b, max_I_g = int(r[0][0] ** 2), int(r[0][0]), int(b[0][0] ** 2), int(b[0][0]), int(g[0][0])
  146. for i in range(m):
  147. for j in range(n):
  148. I_r_2[i][j] = int(r[i][j] ** 2)
  149. I_b_2[i][j] = int(b[i][j] ** 2)
  150. sum_I_r_2 = I_r_2[i][j] + sum_I_r_2
  151. sum_I_b_2 = I_b_2[i][j] + sum_I_b_2
  152. sum_I_g = g[i][j] + sum_I_g
  153. sum_I_r = r[i][j] + sum_I_r
  154. sum_I_b = b[i][j] + sum_I_b
  155. if max_I_r < r[i][j]:
  156. max_I_r = r[i][j]
  157. if max_I_r_2 < I_r_2[i][j]:
  158. max_I_r_2 = I_r_2[i][j]
  159. if max_I_g < g[i][j]:
  160. max_I_g = g[i][j]
  161. if max_I_b_2 < I_b_2[i][j]:
  162. max_I_b_2 = I_b_2[i][j]
  163. if max_I_b < b[i][j]:
  164. max_I_b = b[i][j]
  165. [u_b, v_b] = np.matmul(np.linalg.inv([[sum_I_b_2, sum_I_b], [max_I_b_2, max_I_b]]), [sum_I_g, max_I_g])
  166. [u_r, v_r] = np.matmul(np.linalg.inv([[sum_I_r_2, sum_I_r], [max_I_r_2, max_I_r]]), [sum_I_g, max_I_g])
  167. # print(u_b, v_b, u_r, v_r)
  168. b0, g0, r0 = np.zeros(b.shape, np.uint8), np.zeros(g.shape, np.uint8), np.zeros(r.shape, np.uint8)
  169. for i in range(m):
  170. for j in range(n):
  171. b_point = u_b * (b[i][j] ** 2) + v_b * b[i][j]
  172. g0[i][j] = g[i][j]
  173. # r0[i][j] = r[i][j]
  174. r_point = u_r * (r[i][j] ** 2) + v_r * r[i][j]
  175. if r_point>255:
  176. r0[i][j] = 255
  177. else:
  178. if r_point<0:
  179. r0[i][j] = 0
  180. else:
  181. r0[i][j] = r_point
  182. if b_point>255:
  183. b0[i][j] = 255
  184. else:
  185. if b_point<0:
  186. b0[i][j] = 0
  187. else:
  188. b0[i][j] = b_point
  189. return cv2.merge([b0, g0, r0])
  190. def white_balance_5(img):
  191. '''
  192. 动态阈值算法
  193. 算法分为两个步骤:白点检测和白点调整。
  194. 只是白点检测不是与完美反射算法相同的认为最亮的点为白点,而是通过另外的规则确定
  195. :param img: cv2.imread读取的图片数据
  196. :return: 返回的白平衡结果图片数据
  197. '''
  198. b, g, r = cv2.split(img)
  199. """
  200. YUV空间
  201. """
  202. def con_num(x):
  203. if x > 0:
  204. return 1
  205. if x < 0:
  206. return -1
  207. if x == 0:
  208. return 0
  209. yuv_img = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)
  210. (y, u, v) = cv2.split(yuv_img)
  211. # y, u, v = cv2.split(img)
  212. m, n = y.shape
  213. sum_u, sum_v = 0, 0
  214. max_y = np.max(y.flatten())
  215. # print(max_y)
  216. for i in range(m):
  217. for j in range(n):
  218. sum_u = sum_u + u[i][j]
  219. sum_v = sum_v + v[i][j]
  220. avl_u = sum_u / (m * n)
  221. avl_v = sum_v / (m * n)
  222. du, dv = 0, 0
  223. # print(avl_u, avl_v)
  224. for i in range(m):
  225. for j in range(n):
  226. du = du + np.abs(u[i][j] - avl_u)
  227. dv = dv + np.abs(v[i][j] - avl_v)
  228. avl_du = du / (m * n)
  229. avl_dv = dv / (m * n)
  230. num_y, yhistogram, ysum = np.zeros(y.shape), np.zeros(256), 0
  231. radio = 0.5 # 如果该值过大过小,色温向两极端发展
  232. for i in range(m):
  233. for j in range(n):
  234. value = 0
  235. if np.abs(u[i][j] - (avl_u + avl_du * con_num(avl_u))) < radio * avl_du or np.abs(
  236. v[i][j] - (avl_v + avl_dv * con_num(avl_v))) < radio * avl_dv:
  237. value = 1
  238. else:
  239. value = 0
  240. if value <= 0:
  241. continue
  242. num_y[i][j] = y[i][j]
  243. yhistogram[int(num_y[i][j])] = 1 + yhistogram[int(num_y[i][j])]
  244. ysum += 1
  245. # print(yhistogram.shape)
  246. sum_yhistogram = 0
  247. # hists2, bins = np.histogram(yhistogram, 256, [0, 256])
  248. # print(hists2)
  249. Y = 255
  250. num, key = 0, 0
  251. while Y >= 0:
  252. num += yhistogram[Y]
  253. if num > 0.1 * ysum: # 取前10%的亮点为计算值,如果该值过大易过曝光,该值过小调整幅度小
  254. key = Y
  255. break
  256. Y = Y - 1
  257. # print(key)
  258. sum_r, sum_g, sum_b, num_rgb = 0, 0, 0, 0
  259. for i in range(m):
  260. for j in range(n):
  261. if num_y[i][j] > key:
  262. sum_r = sum_r + r[i][j]
  263. sum_g = sum_g + g[i][j]
  264. sum_b = sum_b + b[i][j]
  265. num_rgb += 1
  266. avl_r = sum_r / num_rgb
  267. avl_g = sum_g / num_rgb
  268. avl_b = sum_b / num_rgb
  269. for i in range(m):
  270. for j in range(n):
  271. b_point = int(b[i][j]) * int(max_y) / avl_b
  272. g_point = int(g[i][j]) * int(max_y) / avl_g
  273. r_point = int(r[i][j]) * int(max_y) / avl_r
  274. if b_point>255:
  275. b[i][j] = 255
  276. else:
  277. if b_point<0:
  278. b[i][j] = 0
  279. else:
  280. b[i][j] = b_point
  281. if g_point>255:
  282. g[i][j] = 255
  283. else:
  284. if g_point<0:
  285. g[i][j] = 0
  286. else:
  287. g[i][j] = g_point
  288. if r_point>255:
  289. r[i][j] = 255
  290. else:
  291. if r_point<0:
  292. r[i][j] = 0
  293. else:
  294. r[i][j] = r_point
  295. return cv2.merge([b, g, r])
  296. '''
  297. img : 原图
  298. img1:均值白平衡法
  299. img2: 完美反射
  300. img3: 灰度世界假设
  301. img4: 基于图像分析的偏色检测及颜色校正方法
  302. img5: 动态阈值算法
  303. '''
  304. img = cv2.imread('./dataset/1/3.JPG')
  305. # img = cv2.imread('./dataset/2/1_'+str(i)+'.JPG')
  306. img1 = white_balance_1(img)
  307. img2 = white_balance_2(img)
  308. img3 = white_balance_3(img)
  309. img4 = white_balance_4(img)
  310. img5 = white_balance_5(img)
  311. print('----------------------')
  312. img_stack = np.vstack([img,img1,img2,img3,img4,img5])
  313. # cv2.imwrite("./dataset/"+str(i)+'.JPG',img_stack)
  314. cv2.imshow('image',img_stack)
  315. cv2.waitKey(0)

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/287451
推荐阅读
相关标签
  

闽ICP备14008679号