赞
踩
本文章将带大家实现灾害监测中一种常用的图像分类方法,即区域生长算法。与前面介绍的几种图像分割方法不同,区域生长算法可直接对高于Uint8灰级的数据直接进行处理,所以保持了原数据的结构形式。另外,区域生长算法涉及到的参数较多,分类的结果与参数关联度较高,所以笔者也添加了阈值参量的调试程序。
区域生长算法是将图像中所有种子附近的像素按照像元值的相似度和似然性进行合并、归类的二值化过程。区域生长算法的生长准则包括四邻域和八邻域两种。除了生长准则外,图像中种子点的位置与个数也与分类结果紧密相关。关于种子点的选择方式也有两种,分别是手动选点和自动选点,鉴于本文的数据量与分割精度,笔者采用了自动选点方式(其实,手动选点精度会更高一些)。
- import cv2
- import gdal
- import numpy as np
-
- def image_open(image):
- data = gdal.Open(image)
- if data == "None":
- print("数据无法打开")
- return data
-
- Filepath = r"E:yynctryedata20180911_yync(DA).tif"
- data = image_open(Filepath).ReadAsArray().transpose(1, 2, 0)
- data1 = data[:, :, 0:3]
-
- Normalize_data = np.zeros(data1.shape)
- for d in range(data1.shape[2]):
- for i in range(data1.shape[0]):
- for j in range(data1.shape[1]):
- Normalize_data[i][j][d] = (data1[i][j][d] - np.min(data1)) / (np.max(data1) - np.min(data1))
- Normalize_data = Normalize_data*256
- data2 = Normalize_data.astype(np.uint8)
-
- # data = image_open(Filepath).ReadAsArray()
- # data1 = data[0:3, :, :].transpose(1, 2, 0)
- # data2 = data1.astype(np.uint8)
-
- cv2.imwrite(r"E:yynctryedata20180911_yync(DA)1.jpg", data2)

代码关键点说明:
这段代码的主要过程是将多波段图像的BGR波段合成了RGB真彩色图像,目的是根据图像获取需要的种子点坐标。笔者采用了遍历+归一化的方法进行图像转换。这里需要说明的一点是传统印象里,要实现图像的格式转换和数据压缩可能只需要用numpy的astype方法就可以实现,也就是下面这段代码:
- # data = image_open(Filepath).ReadAsArray()
- # data1 = data[0:3, :, :].transpose(1, 2, 0)
- # data2 = data1.astype(np.uint8)
经过笔者的亲测试验这种方法不可行,至少对于本文章所使用的数据是不可行的,到底有多么的不可行呢?大家看看下图吧。虽然图(b)的显示仍有问题,但效果比图(a)好的不是一点半点。
- import cv2
- import gdal
-
- def image_open(image):
- data = gdal.Open(image)
- if data == "None":
- print("数据无法打开")
- return data
-
- def on_EVENT_LBUTTONDOWN(event, x, y, flags, param):
- if event == cv2.EVENT_LBUTTONDOWN:
- xy = "%d,%d" % (x, y)
- cv2.circle(img, (x, y), 1, (255, 0, 0), thickness = -1)
- cv2.putText(img, xy, (x, y), cv2.FONT_HERSHEY_PLAIN,
- 1.0, (0,0,0), thickness = 1)
- cv2.imshow("image", img)
-
- img = cv2.imread(r"E:yynctryedata20180911_yync(DA)1.jpg")
-
- cv2.namedWindow("image")
- cv2.setMouseCallback("image", on_EVENT_LBUTTONDOWN)
- while(1):
- cv2.imshow("image", img)
- if cv2.waitKey(0)&0xFF==27:
- break
- cv2.destroyAllWindows()

代码关键点说明:
把之前的图像输入进来并运行之后,会展示出图像,此时只需要点击图像中的像素位置即可显示矩阵坐标信息。
- import numpy as np
- import cv2
- import gdal
-
- class Point(object):
- def __init__(self, x, y):
- self.x = x
- self.y = y
-
- def getX(self):
- return self.x
-
- def getY(self):
- return self.y
-
- def image_open(image):
- data = gdal.Open(image)
- if data == "None":
- print("数据不存在")
- return data
-
- def getGrayDiff(img, currentPoint, tmpPoint):
- return abs(int(img[currentPoint.x, currentPoint.y]) - int(img[tmpPoint.x, tmpPoint.y]))
-
- def selectConnects(p):
- if p != 0:
- connects = [Point(-1, -1), Point(0, -1), Point(1, -1), Point(1, 0), Point(1, 1),
- Point(0, 1), Point(-1, 1), Point(-1, 0)]
- else:
- connects = [Point(0, -1), Point(1, 0), Point(0, 1), Point(-1, 0)]
- return connects
-
- def regionGrow(img, seeds, thresh, p=1):
- height, weight = img.shape
- seedMark = np.zeros(img.shape)
- seedList = []
- for seed in seeds:
- seedList.append(seed)
- label = 1
- connects = selectConnects(p)
- while (len(seedList) > 0):
- currentPoint = seedList.pop(0)
-
- seedMark[currentPoint.x, currentPoint.y] = label
- for i in range(8):
- tmpX = currentPoint.x + connects[i].x
- tmpY = currentPoint.y + connects[i].y
- if tmpX < 0 or tmpY < 0 or tmpX >= height or tmpY >= weight:
- continue
- grayDiff = getGrayDiff(img, currentPoint, Point(tmpX, tmpY))
- if grayDiff < thresh and seedMark[tmpX, tmpY] == 0:
- seedMark[tmpX, tmpY] = label
- seedList.append(Point(tmpX, tmpY))
- return seedMark
-
- def regionGrow_t(t):
- thresh = t
- binary_image = regionGrow(img, seeds, thresh, p=1) #p=0四邻域,p≠0八邻域
- cv2.imshow("thresh_test", binary_image)
-
- #选取图像中适于分类的敏感波段
- Filepath = r"E:yynctryedata20180911_yync(DA).tif"
- img = image_open(Filepath).ReadAsArray()
- img = img[1, :, :]
- print(img.shape)
-
- #设定种子
- seeds = [Point(163,79),Point(173,127),Point(184,15),
- Point(73,144),Point(85,199)]
- thresh = 10
- cv2.namedWindow("thresh_test")
-
- #设定调试的阈值范围
- cv2.createTrackbar('thresh', "thresh_test", 1, 30, regionGrow_t)
-
- #调试可视化
- while(1):
- k = cv2.waitKey(1)&0xFF
- if k==27:
- break
- thresh = cv2.getTrackbarPos('thresh',"thresh_test")

代码关键点说明:
这段代码调试的是生长规则判定过程中的阈值,移动阈值选项卡就可以查看不同阈值下的分割结果,另外阈值范围大家可以自己调试。这里,我们选择27作为最优的阈值。
Python代码实现(结果输出)
- import gdal
- import numpy as np
- import cv2
-
- def image_open(image):
- data = gdal.Open(image)
- if data == "None":
- print("数据不存在")
- return data
-
- def datasave(Filename, data):
- output1 = gdal.GetDriverByName("GTiff")
- output2 = output1.Create(Filename, width, height, 1, gdal.GDT_Float32)
- output2.SetProjection(projection)
- output2.SetGeoTransform(transform)
- output2.GetRasterBand(1)
- output2.WirteArray(data)
-
- class Point(object):
- def __init__(self, x, y):
- self.x = x
- self.y = y
-
- def getX(self):
- return self.x
-
- def getY(self):
- return self.y
-
- def getGrayDiff(img, currentPoint, tmpPoint):
- return abs(int(img[currentPoint.x, currentPoint.y]) - int(img[tmpPoint.x, tmpPoint.y]))
-
- def selectConnects(p):
- if p != 0:
- connects = [Point(-1, -1), Point(0, -1), Point(1, -1), Point(1, 0), Point(1, 1),
- Point(0, 1), Point(-1, 1), Point(-1, 0)]
- else:
- connects = [Point(0, -1), Point(1, 0), Point(0, 1), Point(-1, 0)]
- return connects
-
- def regionGrow(img, seeds, thresh, p=1):
- height, weight = img.shape
- seedMark = np.zeros(img.shape)
- seedList = []
- for seed in seeds:
- seedList.append(seed)
- label1 = 255
- connects = selectConnects(p)
- while (len(seedList) > 0):
- currentPoint = seedList.pop(0)
- seedMark[currentPoint.x, currentPoint.y] = label1
- for i in range(8):
- tmpX = currentPoint.x + connects[i].x
- tmpY = currentPoint.y + connects[i].y
- if tmpX < 0 or tmpY < 0 or tmpX >= height or tmpY >= weight:
- continue
- grayDiff = getGrayDiff(img, currentPoint, Point(tmpX, tmpY))
- if grayDiff < thresh and seedMark[tmpX, tmpY] == 0:
- seedMark[tmpX, tmpY] = label1
- seedList.append(Point(tmpX, tmpY))
- return seedMark
-
- #数据读取及参数设定
- Filepath = r"E:yynctryedata20180911_yync(DA).tif"
- Filename1 = r"E:yynctryedataquyushengzhang.jpg"
- data = image_open(Filepath)
- width = data.RasterXSize
- height = data.RasterYSize
- projection = data.GetProjection()
- transform = data.GetGeoTransform()
- data1 = data.GetRasterBand(2)
- img = data1.ReadAsArray(0, 0, width, height)
- result1 = np.zeros(img.shape)
-
- #设定种子
- seeds = [Point(163,79),Point(173,127),Point(184,15),
- Point(73,144),Point(85,199)]
-
- #输入调试得到的最优阈值
- binaryImg = regionGrow(img, seeds, 27).astype(np.uint8)
-
- cv2.imwrite(Filename1, binaryImg)
- cv2.imshow("image", binaryImg)
- cv2.waitKey(0)

代码关键点说明:
图像输出是不带坐标的结果,大家可以在相关软件里附上坐标,或者在gdal库里进行处理
个人感觉比之前的阈值分割方法好很多,四邻域的生长准则大家可以自行试下。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。