[譯]使用 Python 實現接縫裁剪演算法

caoyi發表於2018-07-12

接縫裁剪是一種新型的裁剪影象的方式,它不會丟失影象中的重要內容。這通常被稱之為“內容感知”裁剪或影象重定向。你可以從這張照片中感受一下這個演算法:

[譯]使用 Python 實現接縫裁剪演算法

照片由 Unsplash 使用者 Pietro De Grandi 提供

變成下面這張:

[譯]使用 Python 實現接縫裁剪演算法

正如你所看到的,影象中的非常重要內容 —— 船隻,都保留下來了。該演算法去除了一些岩層和水(讓船看起來更靠近)。核心演算法可以參考 Shai Avidan 和 Ariel Shamir 的原始論文 Seam Carving for Content-Aware Image Resizing。在這篇文章中,我將展示如何在 Python 中基本實現該演算法。

概要

該演算法的工作原理如下:

  1. 為每個畫素分派一個能量值(energy)
  2. 找到能量最低的畫素的 8 聯通區域
  3. 刪除該區域內所有的畫素
  4. 重複 1-3,直到刪除所需要保留的行/列數

接下來,假設我們只是嘗試裁剪影象的寬度,即刪除列。對於刪除行來說也是類似的,至於原因最後會說明。

以下是 Python 程式碼需要引入的包:

import sys

import numpy as np
from imageio import imread, imwrite
from scipy.ndimage.filters import convolve

# tqdm 並不是必需的,但它可以向我們展示一個漂亮的進度條
from tqdm import trange
複製程式碼

能量圖

第一步是計算每個畫素的能量值,論文中定義了許多不同的可以使用的能量函式。我們來使用最基礎的那個:

[譯]使用 Python 實現接縫裁剪演算法

這意味著什麼呢?I 代表影象,所以這個式子告訴我們,對於影象中的每個畫素和每個通道,我們執行以下幾個步驟:

  • 找到 x 軸的偏導數
  • 找到 y 軸的偏導數
  • 將他們的絕對值求和

這就是該畫素的能量值。那麼問題就來了,“你怎麼計算影象的導數?”,維基百科上的 Image derivations(影象導數)給我們展示了許多不同的計算影象導數的方法。我們將使用 Sobel 濾波器。這是一個在影象上的每個通道上的計算的convolutional kernel(卷積核)。以下是影象的兩個不同方向的過濾器:

[譯]使用 Python 實現接縫裁剪演算法

直觀地說,我們可以認為第一個濾波器是將每個畫素替換為它上邊的值和下邊的值之差。第二個過濾器將每個畫素替換為它右邊的值和左邊的值之差。這種濾波器捕捉到的是每個畫素相鄰所構成的 3x3 區域中畫素的總體趨勢。事實上,這種方法與邊緣檢測演算法也有關係。計算能量圖的方式非常簡單:

def calc_energy(img):
    filter_du = np.array([
        [1.0, 2.0, 1.0],
        [0.0, 0.0, 0.0],
        [-1.0, -2.0, -1.0],
    ])
    # 將一個 2D 的濾波器轉為 3D 的濾波器,為每個通道設定相同的濾波器:R,G,B
    filter_du = np.stack([filter_du] * 3, axis=2)

    filter_dv = np.array([
        [1.0, 0.0, -1.0],
        [2.0, 0.0, -2.0],
        [1.0, 0.0, -1.0],
    ])
    # 將一個 2D 的濾波器轉為 3D 的濾波器,為每個通道設定相同的濾波器:R,G,B
    filter_dv = np.stack([filter_dv] * 3, axis=2)

    img = img.astype('float32')
    convolved = np.absolute(convolve(img, filter_du)) + np.absolute(convolve(img, filter_dv))

    # 我們將紅綠色藍三通道中的能量相加
    energy_map = convolved.sum(axis=2)

    return energy_map
複製程式碼

視覺化能量圖後,我們可以看到:

[譯]使用 Python 實現接縫裁剪演算法

顯然,像天空和水的靜止部分這樣變化最小的區域,具有非常低的能量(暗的部分)。當我們執行接縫裁剪演算法的時候,被移除的線條一般都與影象的這些部分緊密相關,同時試圖保留高能量部分(亮的部分)。

### 找到最小能量的接縫(seam)

我們下一個目標就是找到一條從影象頂部到影象底部的能量最小的路徑。這條線必須是 8 聯通的:這意味著線中的每個畫素都可以他通過邊或叫角碰到線中的下一個畫素。舉個例子,這就是下圖中的紅色線條:

[譯]使用 Python 實現接縫裁剪演算法

所以我們怎麼找到這條線呢?事實證明,這個問題可以很好地使用動態規劃來解決!

[譯]使用 Python 實現接縫裁剪演算法

讓我們建立一個名為 M 的 2D 陣列 來儲存每個畫素的最小能量值。如果您不熟悉動態規劃,這簡單來說就是,從影象頂部到該點的所有可能接縫(seam)中的最小能量即為 M[i,j]。因此,M 的最後一行中就將包含從影象頂部到底部的最小能量。我們需要從此回溯以查詢此接縫中存在的畫素,所以我們將保留這些值,儲存在名為backtrack 的 2D 陣列中。

def minimum_seam(img):
    r, c, _ = img.shape
    energy_map = calc_energy(img)

    M = energy_map.copy()
    backtrack = np.zeros_like(M, dtype=np.int)

    for i in range(1, r):
        for j in range(0, c):
            # 處理影象的左邊緣,防止索引到 -1
            if j == 0:
                idx = np.argmin(M[i - 1, j:j + 2])
                backtrack[i, j] = idx + j
                min_energy = M[i - 1, idx + j]
            else:
                idx = np.argmin(M[i - 1, j - 1:j + 2])
                backtrack[i, j] = idx + j - 1
                min_energy = M[i - 1, idx + j - 1]

            M[i, j] += min_energy

    return M, backtrack
複製程式碼

刪除最小能量的接縫中的畫素

然後我們就可以刪除有著最低能量的接縫中的畫素,返回新的圖片:

def carve_column(img):
    r, c, _ = img.shape

    M, backtrack = minimum_seam(img)

    # 建立一個(r,c)矩陣,所有值都為 True
    # 我們將刪除影象中矩陣裡所有為 False 的對應的畫素
    mask = np.ones((r, c), dtype=np.bool)

    # 找到 M 最後一行中最小元素的那一列的索引
    j = np.argmin(M[-1])

    for i in reversed(range(r)):
        # 標記這個畫素之後需要刪除
        mask[i, j] = False
        j = backtrack[i, j]

    # 因為影象是三通道的,我們將 mask 轉為 3D 的
    mask = np.stack([mask] * 3, axis=2)

    # 刪除 mask 中所有為 False 的位置所對應的畫素,並將
    # 他們重新調整為新影象的尺寸
    img = img[mask].reshape((r, c - 1, 3))

    return img
複製程式碼

對每列重複操作

所有的基礎工作都已做完了!現在,我們只要一次次地執行 carve_column 函式,直到我們刪除到了所需的列數。我們再建立一個 crop_c 函式,影象和縮放因子作為輸入。如果影象的尺寸為(300,600),並且我們想要將其減小到(150,600),scale_c 設定為 0.5 即可。

def crop_c(img, scale_c):
    r, c, _ = img.shape
    new_c = int(scale_c * c)

    for i in trange(c - new_c): # 如果你不想用 tqdm,這裡將 trange 改為 range
        img = carve_column(img)

    return img
複製程式碼

將它們合在一起

我們可以新增一個 main 函式,讓程式碼可以通過命令列呼叫:

def main():
    scale = float(sys.argv[1])
    in_filename = sys.argv[2]
    out_filename = sys.argv[3]

    img = imread(in_filename)
    out = crop_c(img, scale)
    imwrite(out_filename, out)

if __name__ == '__main__':
    main()
複製程式碼

然後執行這段程式碼:

python carver.py 0.5 image.jpg cropped.jpg
複製程式碼

cropped.jpg 現在應該顯示以下這樣的影象:

![]https://user-gold-cdn.xitu.io/2018/7/12/1648d13cb3f0ab58?w=400&h=533&f=jpeg&s=57795)

行應該怎麼處理呢?

然後,我們可以開始研究怎麼修改我們的迴圈來換個方向處理資料。或者...只需旋轉影象就可以執行 crop_c

def crop_r(img, scale_r):
    img = np.rot90(img, 1, (0, 1))
    img = crop_c(img, scale_r)
    img = np.rot90(img, 3, (0, 1))
    return img
複製程式碼

將這段程式碼新增到 main 函式中,現在我們也可以裁剪行!

def main():
    if len(sys.argv) != 5:
        print('usage: carver.py <r/c> <scale> <image_in> <image_out>', file=sys.stderr)
        sys.exit(1)

    which_axis = sys.argv[1]
    scale = float(sys.argv[2])
    in_filename = sys.argv[3]
    out_filename = sys.argv[4]

    img = imread(in_filename)

    if which_axis == 'r':
        out = crop_r(img, scale)
    elif which_axis == 'c':
        out = crop_c(img, scale)
    else:
        print('usage: carver.py <r/c> <scale> <image_in> <image_out>', file=sys.stderr)
        sys.exit(1)
    
    imwrite(out_filename, out)
複製程式碼

執行程式碼:

python carver.py r 0.5 image2.jpg cropped.jpg
複製程式碼

然後我們就可以把這張圖:

[譯]使用 Python 實現接縫裁剪演算法

Photo by Brent Cox on Unsplash

變成這樣:

[譯]使用 Python 實現接縫裁剪演算法

總結

我希望你是愉快而又收穫地讀到這裡的。我很享受實現這篇論文的過程,並打算構建一個這個演算法更快的版本。比如說,使用相同的計算過的影象接縫去除多個接縫。在我的實驗中,這可以使演算法更快,每次迭代可以幾乎線性地移除接縫,但質量明顯下降。另一個優化是計算 GPU 上的能量圖,在這裡探討的

這是完整的程式:

#!/usr/bin/env python

"""
Usage: python carver.py <r/c> <scale> <image_in> <image_out>
Copyright 2018 Karthik Karanth, MIT License
"""

import sys

from tqdm import trange
import numpy as np
from imageio import imread, imwrite
from scipy.ndimage.filters import convolve

def calc_energy(img):
    filter_du = np.array([
        [1.0, 2.0, 1.0],
        [0.0, 0.0, 0.0],
        [-1.0, -2.0, -1.0],
    ])
    # 將一個 2D 的濾波器轉為 3D 的濾波器,為每個通道設定相同的濾波器:R,G,B
    filter_du = np.stack([filter_du] * 3, axis=2)

    filter_dv = np.array([
        [1.0, 0.0, -1.0],
        [2.0, 0.0, -2.0],
        [1.0, 0.0, -1.0],
    ])
    # 將一個 2D 的濾波器轉為 3D 的濾波器,為每個通道設定相同的濾波器:R,G,B
    filter_dv = np.stack([filter_dv] * 3, axis=2)

    img = img.astype('float32')
    convolved = np.absolute(convolve(img, filter_du)) + np.absolute(convolve(img, filter_dv))

    # 我們將紅綠色藍三通道中的能量相加
    energy_map = convolved.sum(axis=2)

    return energy_map

def crop_c(img, scale_c):
    r, c, _ = img.shape
    new_c = int(scale_c * c)

    for i in trange(c - new_c):
        img = carve_column(img)

    return img

def crop_r(img, scale_r):
    img = np.rot90(img, 1, (0, 1))
    img = crop_c(img, scale_r)
    img = np.rot90(img, 3, (0, 1))
    return img

def carve_column(img):
    r, c, _ = img.shape

    M, backtrack = minimum_seam(img)
    mask = np.ones((r, c), dtype=np.bool)

    j = np.argmin(M[-1])
    for i in reversed(range(r)):
        mask[i, j] = False
        j = backtrack[i, j]

    mask = np.stack([mask] * 3, axis=2)
    img = img[mask].reshape((r, c - 1, 3))
    return img

def minimum_seam(img):
    r, c, _ = img.shape
    energy_map = calc_energy(img)

    M = energy_map.copy()
    backtrack = np.zeros_like(M, dtype=np.int)

    for i in range(1, r):
        for j in range(0, c):
            # 處理影象的左邊緣,防止索引到 -1
            if j == 0:
                idx = np.argmin(M[i-1, j:j + 2])
                backtrack[i, j] = idx + j
                min_energy = M[i-1, idx + j]
            else:
                idx = np.argmin(M[i - 1, j - 1:j + 2])
                backtrack[i, j] = idx + j - 1
                min_energy = M[i - 1, idx + j - 1]

            M[i, j] += min_energy

    return M, backtrack

def main():
    if len(sys.argv) != 5:
        print('usage: carver.py <r/c> <scale> <image_in> <image_out>', file=sys.stderr)
        sys.exit(1)

    which_axis = sys.argv[1]
    scale = float(sys.argv[2])
    in_filename = sys.argv[3]
    out_filename = sys.argv[4]

    img = imread(in_filename)

    if which_axis == 'r':
        out = crop_r(img, scale)
    elif which_axis == 'c':
        out = crop_c(img, scale)
    else:
        print('usage: carver.py <r/c> <scale> <image_in> <image_out>', file=sys.stderr)
        sys.exit(1)
    
    imwrite(out_filename, out)

if __name__ == '__main__':
    main()
複製程式碼

修改於(2018 年 5 月 5 日): 正如一個熱心的 reddit 使用者所說,通過使用 numba 來加速計算繁重的功能,可以很容易的得到幾十倍的效能提升。要想體驗 numba,只要在函式 carve_columnminimum_seam 之前加上 @numba.jit。就像下面這樣:

@numba.jit
def carve_column(img):

@numba.jit
def minimum_seam(img):
複製程式碼

如果發現譯文存在錯誤或其他需要改進的地方,歡迎到 掘金翻譯計劃 對譯文進行修改並 PR,也可獲得相應獎勵積分。文章開頭的 本文永久連結 即為本文在 GitHub 上的 MarkDown 連結。


掘金翻譯計劃 是一個翻譯優質網際網路技術文章的社群,文章來源為 掘金 上的英文分享文章。內容覆蓋 AndroidiOS前端後端區塊鏈產品設計人工智慧等領域,想要檢視更多優質譯文請持續關注 掘金翻譯計劃官方微博知乎專欄

相關文章