我是如何使計算提速>150倍的
書接上文《我是如何使計算時間提速25.6倍》.
上篇文章提到, F-measure使用累計直方圖可以進一步加速計算, 但是E-measure卻沒有改出來. 在寫完上篇文章的那個晚上, 重新整理思路後, 我似乎想到了如何去使用累計直方圖來再次提速.
速度的制約
雖然使用"解耦"的思路可以高效優化每一個閾值下指標的計算過程, 但是整體的 for
迴圈確實仍然會佔用較大的時間. 又考慮到各個閾值下的計算實際上並無太大關聯, 如果可以實現同時計算, 那必然可以進一步提升速度. 這裡我們又要把目光放回到在計算F-measure時大放光彩的累計直方圖的策略上.
在前面的解耦之後, 實際上獲得的關鍵變數是 fg_fg_numel
和 fg_bg_numel
.
fg_fg_numel = np.count_nonzero(binarized_pred & gt)
fg_bg_numel = np.count_nonzero(binarized_pred & ~gt)
從這兩個變數本身入手, 如果使用累計直方圖的話, 實際上可以同時獲得 >=不同閾值
下的前景畫素(值為1)的數量, 計算的本質和 np.count_nonzero
是一樣的東西. 所以我們可以進行直觀的替換:
"""
函式內部變數命名規則:
pred屬性(前景fg、背景bg)_gt屬性(前景fg、背景bg)_變數含義
如果僅考慮pred或者gt,則另一個對應的屬性位置使用`_`替換
"""
fg_fg_hist, _ = np.histogram(pred[gt], bins=bins)
fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins)
fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0)
fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0)
這樣我們就獲得了不同閾值下的對應的一系列 fg_fg_numel
和 fg_bg_numel
了. 這裡需要注意的是, 使用的劃分割槽間 bins
的設定. 由於預設的 histogram
劃分的區間會包含最後一個端點, 所以比較合理的劃分是 bins = np.linspace(0, 256, 257)
, 這樣最後一個區間是 [255, 256]
, 就可以包含到最大的值, 又不會和 254
重複計數.
為了便於計算, 這裡將後面會用到的 pred
前景統計 fg___numel_w_thrs
和背景統計 bg____numel_w_thrs
直接寫出來, 便於使用:
fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs
bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs
後面的步驟和之前的基本一致, numpy的廣播機制使得不需要改動太多. 由於這部分程式碼實際上再多處位置會被使用, 所以提取成一個單獨的方法.
def generate_parts_numel_combinations(self, fg_fg_numel, fg_bg_numel, pred_fg_numel, pred_bg_numel):
bg_fg_numel = self.gt_fg_numel - fg_fg_numel
bg_bg_numel = pred_bg_numel - bg_fg_numel
parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel]
mean_pred_value = pred_fg_numel / self.gt_size
mean_gt_value = self.gt_fg_numel / self.gt_size
demeaned_pred_fg_value = 1 - mean_pred_value
demeaned_pred_bg_value = 0 - mean_pred_value
demeaned_gt_fg_value = 1 - mean_gt_value
demeaned_gt_bg_value = 0 - mean_gt_value
combinations = [
(demeaned_pred_fg_value, demeaned_gt_fg_value),
(demeaned_pred_fg_value, demeaned_gt_bg_value),
(demeaned_pred_bg_value, demeaned_gt_fg_value),
(demeaned_pred_bg_value, demeaned_gt_bg_value)
]
return parts_numel, combinations
後面計算 enhanced_matrix_sum
的部分也就順理成章比較自然的可以寫出來:
parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations(
fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs,
pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs,
)
# 這裡雖然可以使用列表來收集各個results_part,但是列表之後還需要再轉為numpy陣列來求和,倒不如直接一次性申請好空間後面直接裝入即可
results_parts = np.empty(shape=(4, 256), dtype=np.float64)
for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)):
align_matrix_value = 2 * (combination[0] * combination[1]) / \
(combination[0] ** 2 + combination[1] ** 2 + _EPS)
enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
results_parts[i] = enhanced_matrix_value * part_numel
enhanced_matrix_sum = results_parts.sum(axis=0)
整體梳理
主要邏輯已經搞定, 接下來就是將這些程式碼與原始的程式碼融合起來, 也就是整合原始程式碼的 cal_em_with_threshold
和 cal_enhanced_matrix
兩個方法.
def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float:
binarized_pred = pred >= threshold
if self.gt_fg_numel == 0:
binarized_pred_bg_numel = np.count_nonzero(~binarized_pred)
enhanced_matrix_sum = binarized_pred_bg_numel
elif self.gt_fg_numel == self.gt_size:
binarized_pred_fg_numel = np.count_nonzero(binarized_pred)
enhanced_matrix_sum = binarized_pred_fg_numel
else:
enhanced_matrix_sum = self.cal_enhanced_matrix(binarized_pred, gt)
em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
return em
結合前面程式碼中計算出的各個閾值下的前背景元素的統計值, 上面這裡的程式碼實際上可以通過使用現有運算結果進行化簡, 即 if
的前兩個分支. 另外閾值劃分也不需要顯式處理, 因為已經在累計直方圖中搞定了. 所以這裡的程式碼對於動態閾值計算的情況下, 是可以被合併到 cal_enhanced_matrix
的計算過程中的. 直接得到最終的整合後的方法:
def cal_em_with_cumsumhistogram(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
"""
函式內部變數命名規則:
pred屬性(前景fg、背景bg)_gt屬性(前景fg、背景bg)_變數含義
如果僅考慮pred或者gt,則另一個對應的屬性位置使用`_`替換
"""
pred = (pred * 255).astype(np.uint8)
bins = np.linspace(0, 256, 257)
fg_fg_hist, _ = np.histogram(pred[gt], bins=bins)
fg_bg_hist, _ = np.histogram(pred[~gt], bins=bins)
fg_fg_numel_w_thrs = np.cumsum(np.flip(fg_fg_hist), axis=0)
fg_bg_numel_w_thrs = np.cumsum(np.flip(fg_bg_hist), axis=0)
fg___numel_w_thrs = fg_fg_numel_w_thrs + fg_bg_numel_w_thrs
bg___numel_w_thrs = self.gt_size - fg___numel_w_thrs
if self.gt_fg_numel == 0:
enhanced_matrix_sum = bg___numel_w_thrs
elif self.gt_fg_numel == self.gt_size:
enhanced_matrix_sum = fg___numel_w_thrs
else:
parts_numel_w_thrs, combinations = self.generate_parts_numel_combinations(
fg_fg_numel=fg_fg_numel_w_thrs, fg_bg_numel=fg_bg_numel_w_thrs,
pred_fg_numel=fg___numel_w_thrs, pred_bg_numel=bg___numel_w_thrs,
)
results_parts = np.empty(shape=(4, 256), dtype=np.float64)
for i, (part_numel, combination) in enumerate(zip(parts_numel_w_thrs, combinations)):
align_matrix_value = 2 * (combination[0] * combination[1]) / \
(combination[0] ** 2 + combination[1] ** 2 + _EPS)
enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
results_parts[i] = enhanced_matrix_value * part_numel
enhanced_matrix_sum = results_parts.sum(axis=0)
em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
return em
還是為了重用, cal_em_with_threshold
(該方法需要保留, 因為還有另一種E-measure的計算情況需要用到該方法)可以被重構:
def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float:
"""
函式內部變數命名規則:
pred屬性(前景fg、背景bg)_gt屬性(前景fg、背景bg)_變數含義
如果僅考慮pred或者gt,則另一個對應的屬性位置使用`_`替換
"""
binarized_pred = pred >= threshold
fg_fg_numel = np.count_nonzero(binarized_pred & gt)
fg_bg_numel = np.count_nonzero(binarized_pred & ~gt)
fg___numel = fg_fg_numel + fg_bg_numel
bg___numel = self.gt_size - fg___numel
if self.gt_fg_numel == 0:
enhanced_matrix_sum = bg___numel
elif self.gt_fg_numel == self.gt_size:
enhanced_matrix_sum = fg___numel
else:
parts_numel, combinations = self.generate_parts_numel_combinations(
fg_fg_numel=fg_fg_numel, fg_bg_numel=fg_bg_numel,
pred_fg_numel=fg___numel, pred_bg_numel=bg___numel,
)
results_parts = []
for i, (part_numel, combination) in enumerate(zip(parts_numel, combinations)):
align_matrix_value = 2 * (combination[0] * combination[1]) / \
(combination[0] ** 2 + combination[1] ** 2 + _EPS)
enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4
results_parts.append(enhanced_matrix_value * part_numel)
enhanced_matrix_sum = sum(results_parts)
em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS)
return em
效率對比
使用本地的845張灰度預測圖和二值mask真值資料進行測試比較, 重新跑了一遍, 總體時間對比如下:
方法 | 總體耗時(s) | 速度提升(倍) |
---|---|---|
'base' | 539.2173762321472s | x1 |
'best' | 19.94518733024597s | x27.0 (539.22/19.95) |
'cumsumhistogram' | 3.2935903072357178s | x163.8 (539.22/3.29) |
還是那句話, 雖然具體時間可能還受硬體限制, 但是相對快慢還是比較明顯的.
測試程式碼可見我的 github
: https://github.com/lartpang/CodeForArticle/tree/main/sod_metrics