tf.one_hot()用法 獨熱編碼

weixin_33866037發表於2018-05-13

tf.one_hot()進行獨熱編碼

首先肯定需要解釋下什麼叫做獨熱編碼(one-hot encoding),獨熱編碼一般是在有監督學習中對資料集進行標註時候使用的,指的是在分類問題中,將存在資料類別的那一類用X表示,不存在的用Y表示,這裡的X常常是1, Y常常是0。
舉個例子:
比如我們有一個5類分類問題,我們有資料(Xi,Yi),其中類別Yi有五種取值(因為是五類分類問題),所以如果Yj為第一類那麼其獨熱編碼為: [1,0,0,0,0],如果是第二類那麼獨熱編碼為:[0,1,0,0,0],也就是說只對存在有該類別的數的位置上進行標記為1,其他皆為0。這個編碼方式經常用於多分類問題,特別是損失函式為交叉熵函式的時候。接下來我們再介紹下TensorFlow中自帶的對資料進行獨熱編碼的函式tf.one_hot(),首先先貼出其API手冊

one_hot(
    indices,#輸入,這裡是一維的
    depth,# one hot dimension.
    on_value=None,#output 預設1
    off_value=None,#output 預設0
    axis=None,
    dtype=None,
    name=None
)

需要指定indices,和depth,其中depth是編碼深度,on_value和off_value相當於是編碼後的開閉值,如同我們剛才描述的X值和Y值,需要和dtype相同型別(指定了dtype的情況下),axis指定編碼的軸。這裡給個小的例項:

import tensorflow as tf
var0 = tf.one_hot(indices=[1, 2, 3], depth=3, axis=0)
var1 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=0)
var2 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=1)
# axis=1 按行排
var3 = tf.one_hot(indices=[1, 2, 3], depth=4, axis=-1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a0 = sess.run(var0)
    a1 = sess.run(var1)
    a2 = sess.run(var2)
    a3 = sess.run(var3)
    print("var0(axis=0 depth=3)\n",a0)
    print("var1(axis=0 depth=4P)\n",a1)
    print("var2(axis=1)\n",a2)
    print("var3(axis=-1)\n",a3)
4550574-3e3fb51b11ed96f7.png
結果

相關文章