keras自定義網路層

快到皖裡來發表於2021-02-16

在深度學習領域,Keras是一個高度封裝的庫並被廣泛應用,可以通過呼叫其內建網路模組(各種網路層)實現針對性的模型結構;當所需要的網路層功能不被包含時,則需要通過自定義網路層或模型實現。

如何在keras框架下自定義層,基本“套路”如下。

一般地,keras中的網路層是一個類,所以自定義層即編寫一個類,更為重要的是這個類(即自定義層)需要繼承Layer父類,而且需要實現以下四種方法:

  1. __init __ (self, output_dim, **kwargs)

這個方法是用來初始化並自定義自定義層所需的屬性,比如output_dim;
此外,該方法需要執行super().__init __(**kwargs),這行程式碼是執行Layer類中的初始化函式;
當執行上述程式碼就沒有必要去管input_shape,weights,trainable等關鍵字引數,因為父類(Layer)的初始化函式實現了它們與layer例項的繫結。

  1. build(self, input_shape)

這個方法是用來建立層的權重;
在該方法中,根據之前的繼承,通過Layer類的add_weight方法來自定義並新增一個權重矩陣,這個方法需要input_shape引數;
該方法必須設self.built = True,目的是為了保證這個層的權重定義函式build被執行過了;
在built函式中,需要說明這個權重各方面的屬性,比如shape、初始化方式以及可訓練性等資訊。

  1. call(self, x)

這個方法是用來編寫層的功能邏輯;
在該方法中,需要關注傳入call的第一個引數:輸入張量x;x只能是一種形式變數,不能是具體的變數,即它不能被定義;
這個call函式就是該層的計算邏輯,當建立好這個層例項後,該例項可以執行call函式;
可見,這個層的核心應該是一段符號式的輸入張量到輸出張量的計算過程。

  1. compute_output_shape(self, input_shape)

這個方法是用來保證輸出shape是正確的;
這裡重寫compute_output_shape方法去覆蓋父類中的同名方法,來保證輸出的shape符合實際;
父類Layer中的compute_output_shape方法直接返回的是input_shape這明顯是不對的,所以需要重寫該方法。

示例

結合官方文件的例子,給出如下一個自定義層的程式碼:

使用自定義層,就如同使用keras內建網路層一樣,如下圖所示:(另外,本例使用kears內建的啟用函式層ReLU承接自定義層的輸出,從而避免將啟用函式的功能加入到自定義層中)

相關文章