crf(條件隨機場)用於遙感影像分類結果的優化

如霧如電發表於2020-11-20

參考連結:
1. https://blog.csdn.net/wzw12315/article/details/106475791
2. https://www.cnblogs.com/wanghui-garcia/p/10761612.html
主要的程式碼與上述參考連結基本相同,只是改用 GDAL 讀入資料;處理結果會有一定變化,參數需要根據自己的資料慢慢調整。

"""
Adapted from the inference.py to demonstrate the usage of the util functions.
"""
import sys
import numpy as np
import pydensecrf.densecrf as dcrf
import cv2
import gdal
from skimage import color
# Get im{read,write} from somewhere.
#  try:
    #  from cv2 import imread, imwrite
#  except ImportError:
    #  # Note that, sadly, skimage unconditionally import scipy and matplotlib,
    #  # so you'll need them if you don't have OpenCV. But you probably have them.
    #  from skimage.io import imread, imsave
    #  imwrite = imsave
    # TODO: Use scipy instead.


from skimage.io import imread, imsave
imwrite = imsave
from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral, create_pairwise_gaussian

def read_img(filename):     # read a raster with GDAL
    """Read a whole raster with GDAL.

    Args:
        filename: path of any raster GDAL can open.

    Returns:
        ``(im_proj, im_geotrans, im_width, im_height, im_data)``: the
        projection string, the geotransform, the raster size in pixels and
        the array returned by ``ReadAsArray`` over the full extent.

    Raises:
        IOError: if GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # Without gdal.UseExceptions(), gdal.Open reports failure by returning
        # None; fail here with a clear message instead of an AttributeError.
        raise IOError("GDAL could not open %s" % filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    del dataset  # release the GDAL dataset handle
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):  # write a GeoTIFF
    """Write an array to a GeoTIFF with the given georeferencing.

    The GDAL pixel type is chosen from ``im_data.dtype``:
    ``int8``/``uint8`` -> Byte, ``uint16`` -> UInt16, ``int16`` -> Int16,
    anything else -> Float32.

    Args:
        filename: output path (GTiff driver).
        im_proj: projection string passed to ``SetProjection``.
        im_geotrans: geotransform passed to ``SetGeoTransform``.
        im_data: ``(rows, cols)`` for one band or ``(bands, rows, cols)``.

    Raises:
        IOError: if the GTiff driver cannot create ``filename``.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed 16-bit data needs Int16; UInt16 would wrap negative values.
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
        bands = im_data
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        bands = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError("GDAL could not create %s" % filename)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    # Write band by band. Iterating the first axis also handles a
    # (1, rows, cols) array, whose band must be passed to WriteArray as 2-D.
    for band_index, band in enumerate(bands, start=1):
        dataset.GetRasterBand(band_index).WriteArray(band)

    del dataset  # dropping the last reference closes and flushes the file

def crf(inimage, img_anno, gt_prob=0.7, n_iter=20, use_2d=False):
    """Refine a hard classification map with a fully connected CRF.

    Every distinct value of ``img_anno`` is treated as its own class. There
    is no background/"unknown" value, which is why ``zero_unsure=False`` is
    used below (e.g. predictions 0, 1, 2, 3 are four classes).

    Args:
        inimage: the original image with channels last; the script below
            passes ``im_data.transpose(1, 2, 0)``. The bilateral term scales
            three channels by ``(13, 13, 13)``. Presumably 3 colour bands are
            expected; TODO confirm for multispectral input.
        img_anno: the predicted label map, same rows/cols as ``inimage``.
        gt_prob: confidence assigned to the predicted label in the unary term.
        n_iter: number of mean-field inference iterations.
        use_2d: use the ``DenseCRF2D`` convenience API instead of the generic
            ``DenseCRF`` + util-function path. The generic path gave slightly
            better results in the author's tests and is the default.

    Returns:
        A ``(rows, cols)`` array of refined labels expressed in the same label
        values as ``img_anno``. If ``img_anno`` contains a single class it is
        returned unchanged.
    """
    img = inimage
    anno_lbl = np.asarray(img_anno)
    print("=========>>", anno_lbl.shape)
    print("=======>>", anno_lbl.shape)
    print(np.max(anno_lbl), np.min(anno_lbl))

    # values: sorted distinct label values. labels: for each pixel, the index
    # of its value in ``values``, i.e. the contiguous class ids 0..n_labels-1
    # the CRF works with. The result is mapped back through ``values``.
    values, labels = np.unique(anno_lbl, return_inverse=True)
    labels = labels.ravel()  # NumPy 2.x may return the inverse in the input's shape
    n_labels = len(values)
    if n_labels <= 1:
        # A single class leaves nothing to refine.
        return anno_lbl

    # Unary potentials (negative log probability) from the hard labels.
    # zero_unsure=False: 0 is a real class here, not "unknown".
    U = unary_from_labels(labels, n_labels, gt_prob=gt_prob, zero_unsure=False)

    if use_2d:
        img = img.astype(int)
        d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], n_labels)  # width, height, n_labels
        d.setUnaryEnergy(U)
        # kernel: CONST_KERNEL, DIAG_KERNEL (the default), FULL_KERNEL
        # normalization: NO_NORMALIZATION, NORMALIZE_BEFORE, NORMALIZE_AFTER,
        #                NORMALIZE_SYMMETRIC (the default)
        d.addPairwiseGaussian(sxy=(3, 3), compat=3, kernel=dcrf.DIAG_KERNEL,
                              normalization=dcrf.NORMALIZE_SYMMETRIC)
        # pydensecrf's Cython wrapper expects a C-contiguous uint8 image here.
        img = np.copy(np.array(img, dtype=np.uint8), order='C')
        d.addPairwiseBilateral(sxy=(80, 80), srgb=(13, 13, 13), rgbim=img,
                               compat=10,
                               kernel=dcrf.CONST_KERNEL,
                               normalization=dcrf.NORMALIZE_SYMMETRIC)
    else:
        # Generic DenseCRF class plus the util functions (recommended path).
        d = dcrf.DenseCRF(img.shape[1] * img.shape[0], n_labels)
        d.setUnaryEnergy(U)

        # Colour-independent features (pixel positions only).
        feats = create_pairwise_gaussian(sdims=(3, 3), shape=img.shape[:2])
        d.addPairwiseEnergy(feats, compat=3,
                            kernel=dcrf.DIAG_KERNEL,
                            normalization=dcrf.NORMALIZE_SYMMETRIC)

        # Colour-dependent features (positions + channel values).
        feats = create_pairwise_bilateral(sdims=(80, 80), schan=(13, 13, 13),
                                          img=img, chdim=2)
        d.addPairwiseEnergy(feats, compat=10,
                            kernel=dcrf.DIAG_KERNEL,
                            normalization=dcrf.NORMALIZE_SYMMETRIC)

    Q = d.inference(n_iter)

    # Most probable class id per pixel, translated back to the input label values.
    MAP = np.argmax(Q, axis=0)
    return values[MAP].reshape(img.shape[:2])

if __name__ == "__main__":
    # Placeholder paths: original image, predicted label raster, output raster.
    img_path = "D:/xx/xx/xx.tif"
    anno = 'D:/xx/result/t.tif'
    out = 'D:/xx/result/t_t.tif'

    # GDAL gives (bands, rows, cols); the CRF wants the bands last.
    *_, im_data = read_img(img_path)
    img = np.transpose(im_data, (1, 2, 0))

    # The label raster is read second, so its georeferencing is the one
    # attached to the refined output.
    im_proj, im_geotrans, im_width, im_height, lab = read_img(anno)

    refined = crf(img, lab)
    write_img(out, im_proj, im_geotrans, refined)


原圖
處理前的結果:
預測結果
處理後的結果:
仔細看變化還是挺大的:CRF 去掉了很多零散的雜點(孤立的誤分類像素),讓類別分佈更純粹。
(此處原文為 CRF 處理前後的對比圖片)

其他版本:
1.

import os
import gdal
import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import compute_unary, create_pairwise_bilateral,create_pairwise_gaussian, softmax_to_unary, unary_from_softmax,unary_from_labels

# """   
# Getting a Unary
# 得到 unary potentials有兩種常見的方法:
# 1)由人類或其他過程產生的硬標籤。該方法由from pydensecrf.utils import unary_from_labels實現
# 2)由概率分佈計算得到,例如深度網路的softmax輸出。即我們之前先對圖片使用訓練好的網路預測得到最終經過softmax函式得到的分類結果,
# 這裡需要將這個結果轉成一元勢
# """

def read_img(filename):
    """Read a whole raster with GDAL.

    Args:
        filename: path of any raster GDAL can open.

    Returns:
        ``(im_proj, im_geotrans, im_width, im_height, im_data)``: the
        projection string, the geotransform, the raster size in pixels and
        the array returned by ``ReadAsArray`` over the full extent.

    Raises:
        IOError: if GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # Without gdal.UseExceptions(), gdal.Open reports failure by returning
        # None; fail here with a clear message instead of an AttributeError.
        raise IOError("GDAL could not open %s" % filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    del dataset  # release the GDAL dataset handle
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write an array to a GeoTIFF with the given georeferencing.

    The GDAL pixel type is chosen from ``im_data.dtype``:
    ``int8``/``uint8`` -> Byte, ``uint16`` -> UInt16, ``int16`` -> Int16,
    anything else -> Float32.

    Args:
        filename: output path (GTiff driver).
        im_proj: projection string passed to ``SetProjection``.
        im_geotrans: geotransform passed to ``SetGeoTransform``.
        im_data: ``(rows, cols)`` for one band or ``(bands, rows, cols)``.

    Raises:
        IOError: if the GTiff driver cannot create ``filename``.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed 16-bit data needs Int16; UInt16 would wrap negative values.
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
        bands = im_data
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        bands = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError("GDAL could not create %s" % filename)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    # Write band by band. Iterating the first axis also handles a
    # (1, rows, cols) array, whose band must be passed to WriteArray as 2-D.
    for band_index, band in enumerate(bands, start=1):
        dataset.GetRasterBand(band_index).WriteArray(band)

    del dataset  # dropping the last reference closes and flushes the file


def dense_crf(img, pre, save, im_proj, im_geotrans, gt_prob=0.7, n_iter=5):
    """Refine a prediction with a dense CRF, save it as a GeoTIFF and return it.

    The unary term can be obtained in two common ways, both supported:
      1) from hard labels produced by a person or another process
         -> ``unary_from_labels``;
      2) from a probability distribution, e.g. a network's softmax output
         -> ``unary_from_softmax`` (the negative log probability).

    Args:
        img: original image with channels last (the script below passes
            ``im_data.transpose(1, 2, 0)``).
        pre: either a 2-D hard label map, where every distinct value is one
            class, or a 3-D ``(n_labels, rows, cols)`` array of per-class
            probabilities. The rows/cols must match ``img``.
        save: output path passed to ``write_img``.
        im_proj: projection passed to ``write_img``.
        im_geotrans: geotransform passed to ``write_img``.
        gt_prob: confidence in the given label (hard-label input only).
        n_iter: number of mean-field inference iterations.

    Returns:
        ``(rows, cols)`` array holding the most probable class index times 255.
        For a hard label map, class index ``i`` is the ``i``-th smallest
        distinct value of ``pre``, so a 0/255 mask comes back as 0/255.
    """
    pre = np.asarray(pre)
    if pre.ndim == 3:
        # Probabilities: unary = -log(p); pydensecrf clips p away from 0 first.
        n_labels = pre.shape[0]
        unary = unary_from_softmax(pre)
    else:
        # Hard labels: map arbitrary label values to contiguous ids 0..k-1.
        values, labels = np.unique(pre, return_inverse=True)
        # unary_from_labels divides by (n_labels - 1), so keep at least 2 classes.
        n_labels = max(len(values), 2)
        # zero_unsure=False: 0 is a real class here, not "unknown".
        unary = unary_from_labels(labels.ravel(), n_labels,
                                  gt_prob=gt_prob, zero_unsure=False)

    # The Cython wrapper needs a C-contiguous float32 (n_labels, rows*cols) array.
    unary = np.ascontiguousarray(unary, dtype=np.float32)

    d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], n_labels)  # width, height, n_labels
    d.setUnaryEnergy(unary)

    # Pairwise potentials. This one penalizes small, spatially isolated pieces
    # of segmentation, which enforces more spatially consistent results.
    feats = create_pairwise_gaussian(sdims=(3, 3), shape=img.shape[:2])
    d.addPairwiseEnergy(feats, compat=3,
                        kernel=dcrf.DIAG_KERNEL,
                        normalization=dcrf.NORMALIZE_SYMMETRIC)

    # Colour-dependent features: the prediction is coarse, and local colour
    # information is used to refine it.
    feats = create_pairwise_bilateral(sdims=(80, 80), schan=(13, 13, 13),
                                      img=img, chdim=2)
    d.addPairwiseEnergy(feats, compat=10,
                        kernel=dcrf.DIAG_KERNEL,
                        normalization=dcrf.NORMALIZE_SYMMETRIC)

    # Iteration count vs. time (from the original notes, image IMG_1702,
    # 2592*1456): 5 iterations ~16.8 s, 20 iterations ~37.6 s.
    Q = d.inference(n_iter)
    res = np.argmax(Q, axis=0).reshape((img.shape[0], img.shape[1]))
    res = res * 255

    write_img(save, im_proj, im_geotrans, res)

    return res


if __name__ == "__main__":
    img_path = "D:/xx/xx/xx.tif"
    anno = 'D:/xx/result/t.tif'
    out = 'D:/xx/result/t_t.tif'

    # Bands moved last for the CRF. The label raster is read second, and its
    # projection/geotransform are used for the written result.
    img = read_img(img_path)[-1].transpose(1, 2, 0)
    im_proj, im_geotrans, _, _, lab = read_img(anno)

    dense_crf(img, lab, out, im_proj, im_geotrans)

import os, sys
import numpy as np
import pydensecrf.densecrf as dcrf
import cv2, gdal
from collections import Counter

# Get im{read,write} from somewhere.
try:
    from cv2 import imread, imwrite
except ImportError:
    # Note that, sadly, skimage unconditionally import scipy and matplotlib,
    # so you'll need them if you don't have OpenCV. But you probably have them.
    from skimage.io import imread, imsave
    imwrite = imsave
    # TODO: Use scipy instead.

from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral, create_pairwise_gaussian

def read_img(filename):
    """Read a whole raster with GDAL.

    Args:
        filename: path of any raster GDAL can open.

    Returns:
        ``(im_proj, im_geotrans, im_width, im_height, im_data)``: the
        projection string, the geotransform, the raster size in pixels and
        the array returned by ``ReadAsArray`` over the full extent.

    Raises:
        IOError: if GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # Without gdal.UseExceptions(), gdal.Open reports failure by returning
        # None; fail here with a clear message instead of an AttributeError.
        raise IOError("GDAL could not open %s" % filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)

    del dataset  # release the GDAL dataset handle
    return im_proj, im_geotrans, im_width, im_height, im_data


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write an array to a GeoTIFF with the given georeferencing.

    The GDAL pixel type is chosen from ``im_data.dtype``:
    ``int8``/``uint8`` -> Byte, ``uint16`` -> UInt16, ``int16`` -> Int16,
    anything else -> Float32.

    Args:
        filename: output path (GTiff driver).
        im_proj: projection string passed to ``SetProjection``.
        im_geotrans: geotransform passed to ``SetGeoTransform``.
        im_data: ``(rows, cols)`` for one band or ``(bands, rows, cols)``.

    Raises:
        IOError: if the GTiff driver cannot create ``filename``.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed 16-bit data needs Int16; UInt16 would wrap negative values.
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
        bands = im_data
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        bands = [im_data]

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError("GDAL could not create %s" % filename)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    # Write band by band. Iterating the first axis also handles a
    # (1, rows, cols) array, whose band must be passed to WriteArray as 2-D.
    for band_index, band in enumerate(bands, start=1):
        dataset.GetRasterBand(band_index).WriteArray(band)

    del dataset  # dropping the last reference closes and flushes the file


def crf(x, y, z):
    """Refine a classification raster with a dense CRF and save the result.

    Adapted from pydensecrf's ``inference.py`` example. The steps are:
      1. Pack the annotation's colours into class ids.
      2. Run a CRF with a Gaussian (position) term and a bilateral
         (position + colour) term for 5 mean-field steps.
      3. Write the most probable class per pixel.
      4. Step the inference manually once more, printing the KL-divergence.

    Args:
        x: path of the original image. It is read with GDAL (``read_img``)
            and moved to channels-last order.
        y: path of the annotation / predicted label image. It is read with
            ``imread`` (OpenCV if importable, otherwise scikit-image).
        z: output path. The first colour channel of the refined map is
            written with ``imwrite``; no georeferencing is written.

    Raises:
        IOError: if the annotation image cannot be read.
    """
    fn_im = x
    fn_anno = y
    fn_output = z

    ##################################
    ### Read images and annotation ###
    ##################################
    _, _, _, _, im_data = read_img(fn_im)
    img = im_data.transpose(1, 2, 0)

    anno_rgb = imread(fn_anno)
    if anno_rgb is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise IOError("could not read annotation image: %s" % fn_anno)
    if anno_rgb.ndim == 2:
        # Single-band label image (e.g. read by scikit-image): replicate it
        # into three channels so the RGB packing below works.
        anno_rgb = np.stack([anno_rgb] * 3, axis=-1)

    # pydensecrf's example treats the all-black colour (0) as "unknown". Here
    # 0 is a real class, so raw value 0 is remapped to 4 before packing.
    # NOTE(review): this merges class 0 into class 4 if both occur, and
    # class 0 is written to the output as 4.
    anno_rgb[anno_rgb == 0] = 4
    anno_rgb = anno_rgb.astype(np.uint32)

    # Convert the annotation's RGB color to a single 32-bit integer color 0xBBGGRR
    anno_lbl = anno_rgb[:, :, 0] + (anno_rgb[:, :, 1] << 8) + (anno_rgb[:, :, 2] << 16)

    # Convert the 32-bit colours to class ids 0..k-1 (indices into the sorted
    # unique colours). unary_from_labels indexes an (n_labels, N) array with
    # these ids, so they must stay below n_labels and are used unchanged.
    colors, labels = np.unique(anno_lbl, return_inverse=True)
    labels = labels.ravel()  # NumPy 2.x may return the inverse in the input's shape

    # But remove the all-0 black, that won't exist in the MAP!
    HAS_UNK = 0 in colors
    if HAS_UNK:
        print(
        "Found a full-black pixel in annotation image, assuming it means 'unknown' label, and will thus not be present in the output!")
        print(
        "If 0 is an actual label for you, consider writing your own code, or simply giving your labels only non-zero values.")
        colors = colors[1:]

    # And create a mapping back from the labels to 32bit integer colors.
    colorize = np.empty((len(colors), 3), np.uint8)
    colorize[:, 0] = (colors & 0x0000FF)
    colorize[:, 1] = (colors & 0x00FF00) >> 8
    colorize[:, 2] = (colors & 0xFF0000) >> 16

    # Number of classes, excluding the "unknown" id 0 when it is present.
    n_labels = len(set(labels.flat)) - int(HAS_UNK)
    print(n_labels, " labels", (" plus \"unknown\" 0: " if HAS_UNK else ""), set(labels.flat))

    ###########################
    ### Setup the CRF model ###
    ###########################

    use_2d = False
    # use_2d = True
    if use_2d:
        print("Using 2D specialized functions")

        # Example using the DenseCRF2D code
        d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], n_labels)  # width, height, n_labels

        # get unary potentials (neg log probability)
        U = unary_from_labels(labels, n_labels, gt_prob=0.7, zero_unsure=HAS_UNK)
        d.setUnaryEnergy(U)

        # This adds the color-independent term, features are the locations only.
        d.addPairwiseGaussian(sxy=(3, 3), compat=3, kernel=dcrf.DIAG_KERNEL,
                              normalization=dcrf.NORMALIZE_SYMMETRIC)

        # This adds the color-dependent term, i.e. features are (x,y,r,g,b).
        # NOTE(review): img is a transposed GDAL array; pydensecrf presumably
        # needs a C-contiguous uint8 image here. Confirm before enabling use_2d.
        d.addPairwiseBilateral(sxy=(80, 80), srgb=(13, 13, 13), rgbim=img,
                               compat=10,
                               kernel=dcrf.DIAG_KERNEL,
                               normalization=dcrf.NORMALIZE_SYMMETRIC)
    else:
        print("Using generic 2D functions")

        # Example using the DenseCRF class and the util functions
        d = dcrf.DenseCRF(img.shape[1] * img.shape[0], n_labels)

        # get unary potentials (neg log probability)
        U = unary_from_labels(labels, n_labels, gt_prob=0.7, zero_unsure=HAS_UNK)
        d.setUnaryEnergy(U)

        # This creates the color-independent features and then add them to the CRF
        feats = create_pairwise_gaussian(sdims=(3, 3), shape=img.shape[:2])
        d.addPairwiseEnergy(feats, compat=3,
                            kernel=dcrf.DIAG_KERNEL,
                            normalization=dcrf.NORMALIZE_SYMMETRIC)

        # This creates the color-dependent features and then add them to the CRF
        feats = create_pairwise_bilateral(sdims=(80, 80), schan=(13, 13, 13),
                                          img=img, chdim=2)
        d.addPairwiseEnergy(feats, compat=10,
                            kernel=dcrf.DIAG_KERNEL,
                            normalization=dcrf.NORMALIZE_SYMMETRIC)

    ####################################
    ### Do inference and compute MAP ###
    ####################################

    # Run five inference steps.
    Q = d.inference(5)

    # Find out the most probable class for each pixel.
    MAP = np.argmax(Q, axis=0)

    # Convert the MAP (labels) back to the corresponding colors and save the image.
    # MAP becomes (rows*cols, 3). Reshape with an explicit 3 rather than with
    # img.shape, whose last axis is the image's band count (not always 3).
    MAP = colorize[MAP, :]
    re_out = MAP.reshape(img.shape[:2] + (3,))
    imwrite(fn_output, re_out[:, :, 0])

    # Manually step the inference, printing the KL-divergence before each step.
    Q, tmp1, tmp2 = d.startInference()
    for i in range(5):
        print("KL-divergence at {}: {}".format(i, d.klDivergence(Q)))
        d.stepInference(Q, tmp1, tmp2)

    print(np.shape(Q), np.shape(MAP), np.shape(tmp2))


if __name__ == "__main__":
    # Single run: original image, class raster to refine, refined output.
    img_path, anno, out = (
        "D:/xx/xx.tif",
        'D:/xx/temp/class_raster.tif',
        'D:/xx/temp/class_raster_crf.tif',
    )
    crf(img_path, anno, out)

    # Batch variant (disabled): pair images, predictions and outputs by
    # file name across three directories.
    # img_path = ''
    # pre_path = ''
    # out_path = ''
    # for name in os.listdir(img_path):
    #     crf(os.path.join(img_path, name),
    #         os.path.join(pre_path, name),
    #         os.path.join(out_path, name))

相關文章