機器學習筆記(3):多類邏輯迴歸

菩提樹下的楊過發表於2017-11-19

仍然是 動手學嘗試學習系列的筆記,原文見:多類邏輯迴歸 — 從0開始 。 這篇的主要目的,是從一堆服飾圖片中,通過機器學習識別出每個服飾圖片對應的分類是什麼(比如:一個看起來象短袖上衣的圖片,應該歸類到T-Shirt分類)

示例程式碼如下,這篇的程式碼略複雜,分成幾個步驟解讀:

 

一、下載資料,並顯示圖片及標籤

 1 from mxnet import gluon
 2 from mxnet import ndarray as nd
 3 import matplotlib.pyplot as plt
 4 import mxnet as mx
 5 from mxnet import autograd
 6 
 7 def transform(data, label):
 8     return data.astype('float32')/255, label.astype('float32')
 9 
10 #訓練資料集(需聯網下載,網速慢時,會很卡)
11 mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
12 
13 #測試資料集(需聯網下載)
14 mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)
15 
16 # data, label = mnist_train[0]
17 # ('example shape: ', data.shape, 'label:', label)
18 
19 #顯示服飾圖片
20 def show_images(images):
21     n = images.shape[0]
22     _, figs = plt.subplots(1, n, figsize=(15, 15))
23     for i in range(n):
24         figs[i].imshow(images[i].reshape((28, 28)).asnumpy())
25         figs[i].axes.get_xaxis().set_visible(False)
26         figs[i].axes.get_yaxis().set_visible(False)
27     plt.show()
28 
29 #獲取圖片對應分類標籤文字
30 def get_text_labels(label):
31     text_labels = [
32         'T 恤', '長 褲', '套頭衫', '裙 子', '外 套',
33         '涼 鞋', '襯 衣', '運動鞋', '包 包', '短 靴'
34     ]
35     return [text_labels[int(i)] for i in label]
36 
37 #下面這些程式碼,用於輔助大家理解示例圖片資料集內部結構
38 # tup1 = mnist_train[0:1] #取出訓練集的第1個樣本
39 # print(type(tup1)) #<class 'tuple'> 可以看出這是個元組型別
40 # print(len(tup1)) #2 有2個元素
41 # print(type(tup1[0])) #<class 'mxnet.ndarray.ndarray.NDArray'> 第1個元素是一個矩陣
42 # print(type(tup1[1])) #<class 'numpy.ndarray'> 第2個元素是numpy的矩陣
43 # print(tup1[0].shape) #(1, 28, 28, 1) 第1個元素是一個四維矩陣,用來儲存每張圖中的畫素點對應的值,最後1維表示RGB通道,這裡只取了1個通道
44 # print(tup1[1].shape) #(1,) 第2個元素用於表示圖片對應的文字分類的索引值
45 # print(tup1[0]) #列印第1個元素(即:四維矩陣的值),<NDArray 1x28x28x1 @cpu(0)> 結果太長,就不列在註釋裡了
46 # print(tup1[1]) #[2.],列印第2個元素(即:該圖片對應的分類索引數值)
47 # print(get_text_labels(tup1[1])) #顯示分類索引值對應的文字['pullover']
48 
49 #取出訓練集中的圖片資料,以及圖片標籤索引值
50 data, label = mnist_train[0:10]
51 
52 #列印資料集的相關資訊
53 print('example shape: ', data.shape, 'label:', label)
54 
55 #顯示圖片
56 show_images(data)
57 
58 #列印圖片分類標籤
59 print(get_text_labels(label))
View Code

首次執行時,可能會很久都沒有反應,讓人誤以為程式碼有問題,其實背後在聯網下載資料,去睡會兒,等醒來的時候,估計就下載好了~_~,下載的資料會儲存在~/.mxnet/datasets/fashion-mnist目錄(mac環境):

下載完成後,上面的程式碼會將圖片資料解析並顯示出來,類似下面這樣:

 

二、讀取資料並初始化引數

 1 #批量讀取資料
 2 batch_size = 256
 3 #訓練集
 4 train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)
 5 #測試集
 6 test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)
 7 
 8 #每張圖片的畫素用向量表示,就是28*28的長度,即:784
 9 num_inputs = 784
10 #要預測10張圖片,即:輸出結果長度為10的向量
11 num_outputs = 10
12 
13 #初始化權重W、偏置b引數矩陣
14 W = nd.random_normal(shape=(num_inputs, num_outputs))
15 b = nd.random_normal(shape=num_outputs)
16 
17 params = [W, b]
18 
19 #附加梯度,方便後面用梯度下降法計算
20 for param in params:
21     param.attach_grad()
View Code

這與之前的 機器學習筆記(1):線性迴歸 很類似,不再重複解釋 

 

三、建立模型

 1 #歸一化函式
 2 def softmax(X):
 3     exp = nd.exp(X)
 4     partition = exp.sum(axis=1, keepdims=True)
 5     return exp / partition
 6 
 7 #計算模型(仍然是類似y=w.x+b的方程)
 8 def net(X):
 9     return softmax(nd.dot(X.reshape((-1, num_inputs)), W) + b)
10 
11 #損失函式(使用交叉熵函式)
12 def cross_entropy(yhat, y):
13     return - nd.pick(nd.log(yhat), y)
14 
15 #梯度下降法
16 def SGD(params, lr):
17     for param in params:
18         param[:] = param - lr * param.grad
View Code

其中softmax(歸一化)及交叉熵cross_entropy,詳情可參考上篇:歸一化(softmax)、資訊熵、交叉熵

 

四、如何評估準確度

 1 #計算準確度
 2 def accuracy(output, label):
 3     return nd.mean(output.argmax(axis=1) == label).asscalar()
 4 
 5 def _get_batch(batch):
 6     if isinstance(batch, mx.io.DataBatch):
 7         data = batch.data[0]
 8         label = batch.label[0]
 9     else:
10         data, label = batch
11     return data, label
12 
13 #評估準確度
14 def evaluate_accuracy(data_iterator, net):
15     acc = 0.
16     if isinstance(data_iterator, mx.io.MXDataIter):
17         data_iterator.reset()
18     for i, batch in enumerate(data_iterator):
19         data, label = _get_batch(batch)
20         output = net(data)
21         acc += accuracy(output, label)
22     return acc / (i+1)
View Code

機器學習的效果如何,通常要有一個評價值,上面的函式就是用來估計演算法和模型準確度的。

注: 這裡面用到了二個新的函式mean,argmax 解釋一下

mean類似sql中的avg函式,就是求平均值,即把一個矩陣的所有元數加起來,然後除以元數個數

from mxnet import ndarray as nd
x = nd.array([1,2,3,4,5,6]);
print(x,x.mean(),(1+2+3+4+5+6)/6.0)

輸出如下:

[ 1.  2.  3.  4.  5.  6.]
<NDArray 6 @cpu(0)> 
[ 3.5]
<NDArray 1 @cpu(0)> 3.5

而argmax,是找出(指定軸向)最大值的索引下標

from mxnet import ndarray as nd
x = nd.array([1,4,7,3,6])
print(x.argmax(axis=0))

輸出為[ 2.],即:第3列數字7最大。再來個多維矩陣的

如上圖,多維矩陣時,如果指定axis=0,表示軸的方向是縱向(自上而下),顯然第1列中的最大值7在第2行(即:row_index是1),第2列的最大值9在第3行(即:row_index=2),類推第3列的最大值8在第1行(row_index=0),最終輸出的結果就是[1, 2, 0]

如果把axis指定為1,則軸的方向為橫向(自左向右),如下圖:

axis為1時,輸出的索引,為列下標(即:第幾列),顯然8在第2列,7在第0列,9在第1列。

現在我們來想一下:為啥argmax結合mean這二個函式,可以用來評估準確度?

答案:預測的結果也是一個矩陣,通常預測對了,該元素值為1,預測錯誤則為0。

如上圖,假如有3個指標,預測對了2個,第三行,一個都沒預測對,那麼準確率為2/3,即0.6666左右

 

五、訓練

 1 #學習率
 2 learning_rate = .1
 3 
 4 #開始訓練
 5 for epoch in range(5):
 6     train_loss = 0.
 7     train_acc = 0.
 8     for data, label in train_data:
 9         with autograd.record():
10             output = net(data)
11             loss = cross_entropy(output, label)
12         loss.backward()
13         SGD(params, learning_rate / batch_size)
14         train_loss += nd.mean(loss).asscalar()
15         train_acc += accuracy(output, label)
16 
17     test_acc = evaluate_accuracy(test_data, net)
18     print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
19         epoch, train_loss / len(train_data), train_acc / len(train_data), test_acc))
View Code

訓練過程與之前的機器學習筆記(1):線性迴歸 套路一樣,參看之前的即可。

 

六、顯示預測結果

1 #顯示結果    
2 data, label = mnist_test[0:10]
3 show_images(data)
4 print('true labels')
5 print(get_text_labels(label))
6 
7 predicted_labels = net(data).argmax(axis=1)
8 print('predicted labels')
9 print(get_text_labels(predicted_labels.asnumpy()))
View Code

執行結果,參考下圖:

可以看到損失函式的計算值在一直下降(即:計算在收斂),最終的結果中紅線部分為100%預測正確的,其它一些外形相似的分類:襯衣、T恤、套頭衫、外套 這些都是"有袖子類的上衣",並沒有完全預測正確,但整體方向還是對的(即:並沒有把"上衣"識別成"鞋子"或"包包"等明顯不靠譜的分類),最終的模型、演算法及引數有待進一步提高。

相關文章