PyTorch 實戰:計算 Wasserstein 距離
(給Python開發者加星標,提升Python技能)
編譯:機器之心,作者:Daniel Daza
最優傳輸理論及 Wasserstein 距離是很多讀者都希望瞭解的基礎,本文主要通過簡單案例展示了它們的基本思想,並通過 PyTorch 介紹如何實戰 W 距離。
機器學習中的許多問題都涉及到令兩個分佈儘可能接近的思想,例如在 GAN 中令生成器分佈接近判別器分佈就能偽造出逼真的影象。但是 KL 散度等分佈的度量方法有很多侷限性,本文則介紹了 Wasserstein 距離及 Sinkhorn 迭代方法,它們 GAN 及眾多工上都展示了傑出的效能。
在簡單的情況下,我們假設從未知資料分佈 p(x) 中觀測到一些隨機變數 x(例如,貓的圖片),我們想要找到一個模型 q(x|θ)(例如一個神經網路)能作為 p(x) 的一個很好的近似。如果 p 和 q 的分佈很相近,那麼就表明我們的模型已經學習到如何識別貓。
因為 KL 散度可以度量兩個分佈的距離,所以只需要最小化 KL(q‖p) 就可以了。可以證明,最小化 KL(q‖p) 等價於最小化一個負對數似然,這樣的做法在我們訓練一個分類器時很常見。例如,對於變分自編碼器來說,我們希望後驗分佈能夠接近於某種先驗分佈,這也是我們通過最小化它們之間的 KL 散度來實現的。
儘管 KL 散度有很廣泛的應用,在某些情況下,KL 散度則會失效。不妨考慮一下如下圖所示的離散分佈:
KL 散度假設這兩個分佈共享相同的支撐集(也就是說,它們被定義在同一個點集上)。因此,我們不能為上面的例子計算 KL 散度。由於這一個限制和其他計算方面的因素促使研究人員尋找一種更適合於計算兩個分佈之間差異的方法。
在本文中,作者將:
簡單介紹最優傳輸問題
將 Sinkhorn 迭代描述為對解求近似
使用 PyTorch 計算 Sinkhorn 距離
描述用於計算 mini-batch 之間的距離的對該實現的擴充套件
移動概率質量函式
我們不妨把離散的概率分佈想象成空間中分散的點的質量。我們可以觀測這些帶質量的點從一個分佈移動到另一個分佈需要做多少功,如下圖所示:
接著,我們可以定義另一個度量標準,用以衡量移動做所有點所需要做的功。要想將這個直觀的概念形式化定義下來,首先,我們可以通過引入一個耦合矩陣 P(coupling matrix),它表示要從 p(x) 支撐集中的一個點上到 q(x) 支撐集中的一個點需要分配多少概率質量。對於均勻分佈,我們規定每個點都具有 1/4 的概率質量。如果我們將本例支撐集中的點從左到右排列,我們可以將上述的耦合矩陣寫作:
也就是說,p(x) 支撐集中點 1 的質量被分配給了 q(x) 支撐集中的點 4,p(x) 支撐集中點 2 的質量被分配給了 q(x) 支撐集中的點 3,以此類推,如上圖中的箭頭所示。
為了算出質量分配的過程需要做多少功,我們將引入第二個矩陣:距離矩陣。該矩陣中的每個元素 C_ij 表示將 p(x) 支撐集中的點移動到 q(x) 支撐集中的點上的成本。點與點之間的歐幾里得距離是定義這種成本的一種方式,它也被稱為「ground distance」。如果我們假設 p(x) 的支撐集和 q(x) 的支撐集分別為 {1,2,3,4} 和 {5,6,7,8},成本矩陣即為:
根據上述定義,總的成本可以通過 P 和 C 之間的 Frobenius 內積來計算:
你可能已經注意到了,實際上有很多種方法可以把點從一個支撐集移動到另一個支撐集中,每一種方式都會得到不同的成本。上面給出的只是一個示例,但是我們感興趣的是最終能夠讓成本較小的分配方式。這就是兩個離散分佈之間的「最優傳輸」問題,該問題的解是所有耦合矩陣上的最低成本 L_C。
由於不是所有矩陣都是有效的耦合矩陣,最後一個條件會引入了一個約束。對於一個耦合矩陣來說,其所有列都必須要加到帶有 q(x) 概率質量的向量中。在本例中,該向量包含 4 個值為 1/4 的元素。更一般地,我們可以將兩個向量分別記為 a 和 b,因此最有運輸問題可以被寫作:
當距離矩陣基於一個有效的距離函式構建時,最小成本即為我們所說的「Wasserstein 距離」。
關於該問題的解以及將其擴充套件到連續概率分佈中還有大量問題需要解決。如果想要獲取更正式、更容易理解的解釋,讀者可以參閱 Gabriel Peyré 和 Marco Cuturi 編寫的「Computational Optimal Transport」一書,此書也是本文寫作的主要參考來源之一。
這裡的基本設定是,我們已經把求兩個分佈之間距離的問題定義為求最優耦合矩陣的問題。事實證明,我們可以通過一個小的修改讓我們以迭代和可微分的方式解決這個問題,這將讓我們可以很好地使用深度學習自動微分機制完成該工作。
熵正則化和 Sinkhorn 迭代
首先,我們將一個矩陣的熵定義如下:
正如資訊理論中概率分佈的熵一樣,一個熵較低的矩陣將會更稀疏,它的大部分非零值集中在幾個點周圍。相反,一個具有高熵的矩陣將會更平滑,其最大熵是在均勻分佈的情況下獲得的。我們可以將正則化係數 ε 引入最優傳輸問題,從而得到更平滑的耦合矩陣:
通過增大 ε,最終得到的耦合矩陣將會變得更加平滑;而當 ε 趨近於零時,耦合矩陣會更加稀疏,同時最終的解會更加趨近於原始最優運輸問題。
通過引入這種熵正則化,該問題變成了一個凸優化問題,並且可 以通過使用「Sinkhorn iteration」求解。解可以被寫作 P=diag(u)Kdiag(v),在迭代過程中交替更新 u 和 v:
其中 K 是一個用 C 計算的核矩陣(kernel matrix)。由於這些迭代過程是在對原始問題的正則化版本求解,因此對應產生的 Wasserstein 距離有時被稱為 Sinkhorn 距離。該迭代過程會形成一個線性操作的序列,因此對於深度學習模型,通過這些迭代進行反向傳播是非常簡單的。
通過 PyTorch 實現 Sinkhorn 迭代
為了提升 Sinkhorn 迭代的收斂性和穩定性,還可以加入其它的步驟。我們可以在 GitHub 上找到 Gabriel Peyre 完成的詳細實現。
專案連結:https://github.com/gpeyre/SinkhornAutoDiff。
讓我們先用一個簡單的例子來測試一下,現在我們將研究二維空間(而不是上面的一維空間)中的離散均勻分佈。在這種情況下,我們將在平面上移動概率質量。讓我們首先定義兩個簡單的分佈:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
np.random.seed(42)
n_points = 5
a = np.array([[i, 0] for i in range(n_points)])
b = np.array([[i, 1] for i in range(n_points)])
plt.figure(figsize=(6, 3))
plt.scatter(a[:, 0], a[:, 1], label='supp($p(x)$)')
plt.scatter(b[:, 0], b[:, 1], label='supp($q(x)$)')
plt.legend();
我們很容易看出,最優傳輸對應於將 p(x) 支撐集中的每個點分配到 q(x) 支撐集上的點。對於所有的點來說,距離都是 1,同時由於分佈是均勻的,每點移動的概率質量是 1/5。因此,Wasserstein 距離是 5×1/5= 1。現在我們用 Sinkhorn 迭代來計算這個距離:
import torch
from layers import SinkhornDistance
x = torch.tensor(a, dtype=torch.float)
y = torch.tensor(b, dtype=torch.float)
sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction=None)
dist, P, C = sinkhorn(x, y)
print("Sinkhorn distance: {:.3f}".format(dist.item()))
————————————————————————————————————————————————
Sinkhorn distance: 1.000
結果正如我們所計算的那樣,距離為 1。現在,讓我們檢視一下「Sinkhorn( )」方法返回的矩陣,其中 P 是計算出的耦合矩陣,C 是距離矩陣。距離矩陣如下圖所示:
plt.imshow(C)
plt.title('Distance matrix')
plt.colorbar();
plt.imshow(C)plt.title('Distance matrix')plt.colorbar();
元素「C[0, 0]」說明了將(0,0)點的質量移動到(0,1)所需要的成本 1 是如何產生的。在該行的另一端,元素「C[0, 4]」包含了將點(0,0)的質量移動到點(4,1)所需要的成本,這個成本是整個矩陣中最大的:
由於我們為距離矩陣使用的是平方後的 ℓ2 範數,計算結果如上所示。現在,讓我們看看計算出的耦合矩陣吧:
plt.imshow(P)
plt.title('Coupling matrix');
plt.imshow(P)plt.title('Coupling matrix');
該圖很好地向我們展示了演算法是如何有效地發現最優耦合,它與我們前面確定的耦合矩陣是相同的。到目前為止,我們使用了 0.1 的正則化係數。如果將該值增加到 1 會怎樣?
sinkhorn = SinkhornDistance(eps=1, max_iter=100, reduction=None)
dist, P, C = sinkhorn(x, y)
print("Sinkhorn distance: {:.3f}".format(dist.item()))
plt.imshow(P);
————————————————————————————————————————————————
Sinkhorn distance: 1.408
正如我們前面討論過的,加大 ε 有增大耦合矩陣熵的作用。接下來,我們看看 P 是如何變得更加平滑的。但是,這樣做也會為計算出的距離帶來一個不好的影響,導致對 Wasserstein 距離的近似效果變差。
視覺化支撐集的空間分配也很有意思:
def show_assignments(a, b, P):
norm_P = P/P.max()
for i in range(a.shape[0]):
for j in range(b.shape[0]):
plt.arrow(a[i, 0], a[i, 1], b[j, 0]-a[i, 0], b[j, 1]-a[i, 1],
alpha=norm_P[i,j].item())
plt.title('Assignments')
plt.scatter(a[:, 0], a[:, 1])
plt.scatter(b[:, 0], b[:, 1])
plt.axis('off')
show_assignments(a, b, P)
讓我們在一個更有趣的分佈(Moons 資料集)上完成這項工作。
from sklearn.datasets import make_moons
X, Y = make_moons(n_samples = 30)
a = X[Y==0]
b = X[Y==1]
x = torch.tensor(a, dtype=torch.float)
y = torch.tensor(b, dtype=torch.float)
sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction=None)
dist, P, C = sinkhorn(x, y)
print("Sinkhorn distance: {:.3f}".format(dist.item()))
show_assignments(a, b, P)
——————————————————————————————————————————
Sinkhorn distance: 1.714
Mini-batch 上的 Sinkhorn 距離
在深度學習中,我們通常對使用 mini-batch 來加速計算十分感興趣。我們也可以通過使用額外的批處理維度修改 Sinkhorn 迭代來滿足該設定。將此更改新增到具體實現中後,我們可以在一個 mini-batch 中計算多個分佈的 Sinkhorn 距離。下面我們將通過另一個容易被驗證的例子說明這一點。
程式碼:https://github.com/dfdazac/wassdistance/blob/master/layers.py
我們將計算包含 5 個支撐點的 4 對均勻分佈的 Sinkhorn 距離,它們垂直地被 1(如上所示)、2、3 和 4 個單元分隔開。這樣,它們之間的 Wasserstein 距離將分別為 1、4、9 和 16。
n = 5
batch_size = 4
a = np.array([[[i, 0] for i in range(n)] for b in range(batch_size)])
b = np.array([[[i, b + 1] for i in range(n)] for b in range(batch_size)])
# Wrap with torch tensors
x = torch.tensor(a, dtype=torch.float)
y = torch.tensor(b, dtype=torch.float)
sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction=None)
dist, P, C = sinkhorn(x, y)
print("Sinkhorn distances: ", dist)
——————————————————————————————————————————
Sinkhorn distances: tensor([ 1.0001, 4.0001, 9.0000, 16.0000])
這樣做確實有效!同時,也請注意,現在 P 和 C 為 3 維張量,它包含 mini-batch 中每對分佈的耦合矩陣和距離矩陣:
print('P.shape = {}'.format(P.shape))
print('C.shape = {}'.format(C.shape))
——————————————————————————————————————————
P.shape = torch.Size([4, 5, 5])
C.shape = torch.Size([4, 5, 5])
結語
分佈之間的 Wasserstein 距離及其通過 Sinkhorn 迭代實現的計算方法為我們帶來了許多可能性。該框架不僅提供了對 KL 散度等距離的替代方法,而且在建模過程中提供了更大的靈活性,我們不再被迫要選擇特定的引數分佈。這些迭代過程可以在 GPU 上高效地執行,並且是完全可微分的,這使得它對於深度學習來說是一個很好的選擇。這些優點在機器學習領域的最新研究中得到了充分的利用(如自編碼器和距離嵌入),使其在該領域的應用前景更加廣闊。
原文連結:https://dfdazac.github.io/sinkhorn.html
推薦閱讀
(點選標題可跳轉閱讀)
PyTorch 0.4.0 大更新,正式支援 Windows 平臺
覺得本文對你有幫助?請分享給更多人
關注「Python開發者」加星標,提升Python技能
喜歡就點一下「好看」唄~
相關文章
- Pytorch計算機視覺實戰(更新中)PyTorch計算機視覺
- Levenshtein:計算字串的編輯距離字串
- 優於VAE,為萬能近似器高斯混合模型加入Wasserstein距離模型
- milvus 使用 l2 歐式距離計算向量的距離,計算出來的距離的最大值是多少?
- 28、(向量)歐幾里得距離計算
- JAVA計算兩經緯度間的距離Java
- 【leetcode】72. Edit Distance 編輯距離計算LeetCode
- 經緯度距離換算
- 最小距離分類器,互動式選取影像樣本分類資料,進行最小距離分類(實現歐式距離,馬氏距離,計程距離)
- 根據兩點經緯度計算距離和角度——java實現Java
- 計算地圖中兩點之間的距離地圖
- IBM量子計算機亮相 距離標準量子計算機相距甚遠IBM計算機
- 通過經緯度計算距離實現附近、附近的人等功能
- PHP實現透過經緯度計算距離和查附近店門PHP
- 微信小程式——計算2點之間的距離微信小程式
- 透過經緯度計算距離獲取附近商家
- 計算幾何 —— 二維幾何基礎 —— 距離度量方法
- Python 計算多少天前後、距離 X日多久的日期Python
- 通過sql 計算兩經緯度之間的距離SQL
- C語言:使用函式計算兩點間的距離C語言函式
- 馬氏距離與歐氏距離
- 常見問題01:計算地球上兩個點的距離
- 微信小程式結合騰訊地圖實現使用者商家距離計算微信小程式地圖
- pytorch(1)梯度計算PyTorch梯度
- 【Python】距離Python
- 根據經緯度計算兩點之間的距離的公式公式
- 我們可能會遇到的距離量算方法
- JavaScript 元素距離視窗頂部的距離JavaScript
- java 經緯度處理、計算兩地的距離、獲取當前一定距離以內的經緯度值Java
- 編輯距離及編輯距離演算法演算法
- 曼哈頓距離與切比雪夫距離
- JavaScript獲取元素距離文件頂部的距離JavaScript
- Php兩點地理座標距離的計算方法和具體程式碼PHP
- 當支援向量機遇上神經網路:這項研究揭示了SVM、GAN、Wasserstein距離之間的關係神經網路
- 場景設計中距離感的設計
- Laravel 距離排序Laravel排序
- unit原子距離
- 餘弦距離