1 自動微分
我們在《數值分析》課程中已經學過許多經典的數值微分方法。許多經典的數值微分演算法非常快,因為它們只需要計算差商。然而,他們的主要缺點在於他們是數值的,這意味著有限的算術精度和不精確的函式求值,而這些都從根本上限制了求解結果的質量。因此。充滿噪聲的、複雜多變的函式很難得到精準的數值微分。
自動微分技術(稱為“automatic differentiation, autodiff”)是介於符號微分和數值微分的一種技術,它是在計算效率和計算精度之間的一種折衷。自動微分不受任何離散化演算法誤差的約束,它充分利用了微分的鏈式法則和其他關於導數的性質來準確地計算它們。
2 前向自動微分
我們先來計算簡單的前向自動微分。假設我們有兩個變數\(u\)和\(v\),使用浮點數儲存。我們將變數\(u′=du/dt\)和\(v′=dv/dt\)和這些變數一起儲存,這裡\(t\)是獨立的變數。在一些程式設計語言(如Python)中,我們可以選擇定義一種新的資料型別來儲存\([u,u′]\)和\([v,v′]\)這類數對。我們可以在這些數對上定義一種代數運算,這些代數運算編碼了一些經典的操作:
在進行前向自動微分之前,我們需要先將計算\(f(t)\)所產生的操作序列表示為計算圖。接著,採用自底向上的遞推演算法的思想,從做為遞推起點的數對\(t≡[t_0,1]\)(因為\(dt/dt= 1\))開始,我們能夠按照我們上述編碼規則同時對函式\(f(t)\)和它的導數\(f′(t)\)進行求值。我們在程式語言中可以選擇令數對過載運算子,這樣額外的求導數運算就可以對使用者透明地執行了。
例1 比如,對於函式\(f(x) = \exp(x^2 - x)/{x}\),想要依次計算\({dy}_i/dx\)(這裡\(y_i\)為所有計算中間項)。則我們先從\(x\)開始將表示式分解為計算圖:
然後前向遞推地按照我們之前所述的編碼規則來進行求導
注意鏈式法則(chain rule)告訴我們:
所以我們對
有
事實上,我們也能夠處理有多個輸入的函式\(g\):
多元微分鏈式法則如下:
比如,對於
我們有
下面展示了一個對二元函式模擬前向自動微分的過程。
例2 設\(f(x_1, x_2) = x_1\cdot \exp(x_2) - x_1\),模擬前向微分過程。
接下來我們看如何用Python程式碼來實現單變數函式的前向自動微分過程。為了簡便起見,我們下面只編碼了幾個常用的求導規則。
import math
class Var:
def __init__(self, val, deriv=1.0):
self.val = val
self.deriv = deriv
def __add__(self, other):
if isinstance(other, Var):
val = self.val + other.val
deriv = self.deriv + other.deriv
else:
val = self.val + other
deriv = self.deriv
return Var(val, deriv)
def __radd__(self, other):
return self + other
def __sub__(self, other):
if isinstance(other, Var):
val = self.val - other.val
deriv = self.deriv - other.deriv
else:
val = self.val - other
deriv = self.deriv
return Var(val, deriv)
def __rsub__(self, other):
val = other - self.val
deriv = - self.deriv
return Var(val, deriv)
def __mul__(self, other):
if isinstance(other, Var):
val = self.val * other.val
deriv = self.val * other.deriv + self.deriv * other.val
else:
val = self.val * other
deriv = self.deriv * other
return Var(val, deriv)
def __rmul__(self, other):
return self * other
def __truediv__(self, other):
if isinstance(other, Var):
val = self.val / other.val
deriv = (self.deriv * other.val - self.val * other.deriv)/other.val**2
else:
val = self.val / other
deriv = self.deriv / other
return Var(val, deriv)
def __rtruediv__(self, other):
val = other / self.val
deriv = other * 1/self.val**2
return Var(val, deriv)
def __repr__(self):
return "value: {}\t gradient: {}".format(self.val, self.deriv)
def exp(f: Var):
return Var(math.exp(f.val), math.exp(f.val) * f.deriv)
例如,我們若嘗試計算函式\(f(x) = \exp(x^2 - x)/{x}\)在\(x=2.0\)處的導數\(f'(2.0)\)如下:
fx = lambda x: exp(x*x - x)/x
df = fx(Var(2.0))
print(df)
列印輸出:
value: 3.694528049465325 deriv: 9.236320123663312
可見,前向過程完成計算得到\(f(2.0)\approx 3.69\), \(f'(2.0)\approx 9.24\)。
3 反向自動微分
我們前面介紹的前向自動微分方法在計算\(y = f(t)\)的時候並行地計算\(f'(t)\)。接下來我們介紹一種“反向”自動微分方法,相比上一種的方法它僅需要更少的函式求值,不過需要以更多的記憶體消耗和更復雜的實現做為代價。
同樣,這個技術需要先將計算\(f(t)\)所產生的操作序列表示為計算圖。不過,與之前的從\(dt/dt = 1\)開始,然後往\(dy/dt\)方向計算不同,反向自動求導演算法從\(dy/dy = 1\)開始並且按與之前同樣的規則往反方向計算,一步步地將分母替換為\(dt\)。反向自動微分可以避免不必要的計算,特別是當\(y\)是一個多元函式的時候。例如,對\(f(t_1, t_2) = f_1(t_1) + f_2(t_2)\),反向自動微分並不需要計算\(f_1\)關於\(t_2\)的微分或\(f_2\)關於\(t_1\)的微分。
例3 設\(f(x_1, x_2) = x_1\cdot \exp(x_2) - x_1\),模擬反向自動微分過程。
可見若採用反向自動微分,我們需要儲存計算過程中的所有東西,故記憶體的使用量會和時間成正比。不過,在現有的深度學習框架中,對反向自動微分的實現進行了進一步最佳化,我們會在深度學習專題文章中再進行詳述。
4 總結
自動微分被廣泛認為是一種未被充分重視的數值技術, 它可以以儘量小的執行代價來產生函式的精確導數。它在軟體需要計算導數或Hessian來執行最佳化演算法時顯得格外有價值,從而避免每次目標函式改變時都去重新手動計算導數。當然,做為其便捷性的代價,自動微分也會帶來計算的效率問題,因為在實際工作中自動微分方法並不會去化簡表示式,而是直接應用最顯式的編碼規則。
參考
-
[1] Solomon J. Numerical algorithms: methods for computer vision, machine learning, and graphics[M]. CRC press, 2015.
-
[2] S&DS 631: Computation and Optimization Automatic Differentiation