【Python 3】keras.layers.Lambda解析與使用

Better Bench發表於2021-01-02

1 作用

Lambda表示式: 用一行程式碼去表示一個函式,簡化和美觀程式碼。

keras.layers.Lambda(): 是Lambda表示式的應用。指定在神經網路模型中,如果某一層需要通過一個函式去變換資料,那利用keras.layers.Lambda()這個函式單獨把這一步資料操作命為單獨的一Lambda層。

2 引數解析

keras.layers.core.Lambda(function, output_shape=None, mask=None, arguments=None)

引數

  • function:要實現的函式,該函式僅接受一個變數,即神經網路上一層的輸出

  • output_shape:函式應該返回的值的shape,可以是一個tuple,也可以是一個根據輸入shape計算輸出shape的函式

  • mask: 掩膜

  • arguments:可選,是字典格式,用來傳參

3 舉例

3.1 傳參舉例

arguments引數,利用字典格式來傳參

# index是引數,
def slice(x,index):
    return x[:,:,index]

# 通過字典將引數index = 0傳遞進去
x1 = Lambda(slice,output_shape=(4,1),arguments={'index':0})(a)
# 通過字典將引數index = 1 傳遞進去
x2 = Lambda(slice,output_shape=(4,1),arguments={'index':1})(a)

3.2 簡單Demo

from keras.layers import Lambda
from keras.models import Input, Model
import numpy as np

## 第一步 定義模型
# 初始化兩個輸入形參
a = Input(shape=(2, ))
b = Input(shape=(2, ))
 
# 定義lambda要執行的函式
def minus(inputs):
    x, y = inputs
    return (x+y)
# 使用lambda表示式,對函式進行傳參
minus_layer = Lambda(minus, name='minus')([a, b])
model = Model(inputs=[a, b], outputs=[minus_layer])

## 第二步 測試模型
# 隨便定義的兩個陣列
v0 = np.array([5, 2])
v1 = np.array([8, 4])
# 轉成1*2的矩陣後測試模型
print(model.predict([v0.reshape(1, 2), v1.reshape(1, 2)]))

3.3 利用Lambda表示式實現某層資料的切片

Lambda傳引數
參考文件:keras Lambda自定義層實現資料的切片,Lambda傳引數

import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation,Reshape
from keras.layers import merge
from keras.utils.visualize_util import plot
from keras.layers import Input, Lambda
from keras.models import Model
 
def slice(x,index):
        return x[:,:,index]
 
a = Input(shape=(4,2))
x1 = Lambda(slice,output_shape=(4,1),arguments={'index':0})(a)
x2 = Lambda(slice,output_shape=(4,1),arguments={'index':1})(a)
x1 = Reshape((4,1,1))(x1)
x2 = Reshape((4,1,1))(x2)
output = merge([x1,x2],mode='concat')
model = Model(a, output)
x_test = np.array([[[1,2],[2,3],[3,4],[4,5]]])
print model.predict(x_test)
plot(model, to_file='lambda.png',show_shapes=True)

上述程式碼實現的是,將矩陣的每一列提取出來,然後單獨進行操作,最後在拼在一起。視覺化的圖如下所示。
在這裡插入圖片描述

相關文章