機器學習筆記-多類邏輯迴歸

nt1979發表於2021-09-09

仍然是 動手學嘗試學習系列的筆記,原文見:多類邏輯迴歸 — 從0開始 。 這篇的主要目的,是從一堆服飾圖片中,透過機器學習識別出每個服飾圖片對應的分類是什麼(比如:一個看起來象短袖上衣的圖片,應該歸類到T-Shirt分類)

示例程式碼如下,這篇的程式碼略複雜,分成幾個步驟解讀:

 

一、下載資料,並顯示圖片及標籤

 1 from mxnet import gluon

 2 from mxnet import ndarray as nd

 3 import matplotlib.pyplot as plt

 4 import mxnet as mx

 5 from mxnet import autograd

 6 

 7 def transform(data, label):

 8     return data.astype('float32')/255, label.astype('float32')

 9 

10 #訓練資料集(需聯網下載,網速慢時,會很卡)

11 mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)

12 

13 #測試資料集(需聯網下載)

14 mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)

15 

16 # data, label = mnist_train[0]

17 # ('example shape: ', data.shape, 'label:', label)

18 

19 #顯示服飾圖片

20 def show_images(images):

21     n = images.shape[0]

22     _, figs = plt.subplots(1, n, figsize=(15, 15))

23     for i in range(n):

24         figs[i].imshow(images[i].reshape((28, 28)).asnumpy())

25         figs[i].axes.get_xaxis().set_visible(False)

26         figs[i].axes.get_yaxis().set_visible(False)

27     plt.show()

28 

29 #獲取圖片對應分類標籤文字

30 def get_text_labels(label):

31     text_labels = [

32         'T 恤', '長 褲', '套頭衫', '裙 子', '外 套',

33         '涼 鞋', '襯 衣', '運動鞋', '包 包', '短 靴'

34     ]

35     return [text_labels[int(i)] for i in label]

36 

37 #下面這些程式碼,用於輔助大家理解示例圖片資料集內部結構

38 # tup1 = mnist_train[0:1] #取出訓練集的第1個樣本

39 # print(type(tup1)) # 可以看出這是個元組型別

40 # print(len(tup1)) #2 有2個元素

41 # print(type(tup1[0])) # 第1個元素是一個矩陣

42 # print(type(tup1[1])) # 第2個元素是numpy的矩陣

43 # print(tup1[0].shape) #(1, 28, 28, 1) 第1個元素是一個四維矩陣,用來儲存每張圖中的畫素點對應的值,最後1維表示RGB通道,這裡只取了1個通道

44 # print(tup1[1].shape) #(1,) 第2個元素用於表示圖片對應的文字分類的索引值

45 # print(tup1[0]) #列印第1個元素(即:四維矩陣的值), 結果太長,就不列在註釋裡了

46 # print(tup1[1]) #[2.],列印第2個元素(即:該圖片對應的分類索引數值)

47 # print(get_text_labels(tup1[1])) #顯示分類索引值對應的文字['pullover']

48 

49 #取出訓練集中的圖片資料,以及圖片標籤索引值

50 data, label = mnist_train[0:10]

51 

52 #列印資料集的相關資訊

53 print('example shape: ', data.shape, 'label:', label)

54 

55 #顯示圖片

56 show_images(data)

57 

58 #列印圖片分類標籤

59 print(get_text_labels(label))

首次執行時,可能會很久都沒有反應,讓人誤以為程式碼有問題,其實背後在聯網下載資料,去睡會兒,等醒來的時候,估計就下載好了~_~,下載的資料會儲存在~/.mxnet/datasets/fashion-mnist目錄(mac環境):

圖片描述

下載完成後,上面的程式碼會將圖片資料解析並顯示出來,類似下面這樣:

圖片描述

 

二、讀取資料並初始化引數

 1 #批次讀取資料

 2 batch_size = 256

 3 #訓練集

 4 train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)

 5 #測試集

 6 test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)

 7 

 8 #每張圖片的畫素用向量表示,就是28*28的長度,即:784

 9 num_inputs = 784

10 #要預測10張圖片,即:輸出結果長度為10的向量

11 num_outputs = 10

12 

13 #初始化權重W、偏置b引數矩陣

14 W = nd.random_normal(shape=(num_inputs, num_outputs))

15 b = nd.random_normal(shape=num_outputs)

16 

17 params = [W, b]

18 

19 #附加梯度,方便後面用梯度下降法計算

20 for param in params:

21     param.attach_grad()

這與之前的 機器學習筆記(1):線性迴歸 很類似,不再重複解釋 

 

三、建立模型

 1 #歸一化函式

 2 def softmax(X):

 3     exp = nd.exp(X)

 4     partition = exp.sum(axis=1, keepdims=True)

 5     return exp / partition

 6 

 7 #計算模型(仍然是類似y=w.x+b的方程)

 8 def net(X):

 9     return softmax(nd.dot(X.reshape((-1, num_inputs)), W) + b)

10 

11 #損失函式(使用交叉熵函式)

12 def cross_entropy(yhat, y):

13     return - nd.pick(nd.log(yhat), y)

14 

15 #梯度下降法

16 def SGD(params, lr):

17     for param in params:

18         param[:] = param - lr * param.grad

其中softmax(歸一化)及交叉熵cross_entropy,詳情可參考上篇:歸一化(softmax)、資訊熵、交叉熵

 

四、如何評估準確度

 1 #計算準確度

 2 def accuracy(output, label):

 3     return nd.mean(output.argmax(axis=1) == label).asscalar()

 4 

 5 def _get_batch(batch):

 6     if isinstance(batch, mx.io.DataBatch):

 7         data = batch.data[0]

 8         label = batch.label[0]

 9     else:

10         data, label = batch

11     return data, label

12 

13 #評估準確度

14 def evaluate_accuracy(data_iterator, net):

15     acc = 0.

16     if isinstance(data_iterator, mx.io.MXDataIter):

17         data_iterator.reset()

18     for i, batch in enumerate(data_iterator):

19         data, label = _get_batch(batch)

20         output = net(data)

21         acc += accuracy(output, label)

22     return acc / (i+1)

機器學習的效果如何,通常要有一個評價值,上面的函式就是用來估計演算法和模型準確度的。

注: 這裡面用到了二個新的函式mean,argmax 解釋一下

mean類似sql中的avg函式,就是求平均值,即把一個矩陣的所有元數加起來,然後除以元數個數

+ View Code?

123 from mxnet import ndarray as ndx = nd.array([1,2,3,4,5,6]);print(x,x.mean(),(1+2+3+4+5+6)/6.0)

輸出如下:

[ 1.  2.  3.  4.  5.  6.]

 

[ 3.5]

3.5

而argmax,是找出(指定軸向)最大值的索引下標

from mxnet import ndarray as nd

x = nd.array([1,4,7,3,6])

print(x.argmax(axis=0))

輸出為[ 2.],即:第3列數字7最大。再來個多維矩陣的

圖片描述

如上圖,多維矩陣時,如果指定axis=0,表示軸的方向是縱向(自上而下),顯然第1列中的最大值7在第2行(即:row_index是1),第2列的最大值9在第3行(即:row_index=2),類推第3列的最大值8在第1行(row_index=0),最終輸出的結果就是[1, 2, 0]

如果把axis指定為1,則軸的方向為橫向(自左向右),如下圖:

圖片描述

axis為1時,輸出的索引,為列下標(即:第幾列),顯然8在第2列,7在第0列,9在第1列。

現在我們來想一下:為啥argmax結合mean這二個函式,可以用來評估準確度?

答案:預測的結果也是一個矩陣,通常預測對了,該元素值為1,預測錯誤則為0。

圖片描述

如上圖,假如有3個指標,預測對了2個,第三行,一個都沒預測對,那麼準確率為2/3,即0.6666左右

 

五、訓練

 1 #學習率

 2 learning_rate = .1

 3 

 4 #開始訓練

 5 for epoch in range(5):

 6     train_loss = 0.

 7     train_acc = 0.

 8     for data, label in train_data:

 9         with autograd.record():

10             output = net(data)

11             loss = cross_entropy(output, label)

12         loss.backward()

13         SGD(params, learning_rate / batch_size)

14         train_loss += nd.mean(loss).asscalar()

15         train_acc += accuracy(output, label)

16 

17     test_acc = evaluate_accuracy(test_data, net)

18     print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (

19         epoch, train_loss / len(train_data), train_acc / len(train_data), test_acc))

訓練過程與之前的機器學習筆記(1):線性迴歸 套路一樣,參看之前的即可。

 

六、顯示預測結果

1 #顯示結果    

2 data, label = mnist_test[0:10]

3 show_images(data)

4 print('true labels')

5 print(get_text_labels(label))

7 predicted_labels = net(data).argmax(axis=1)

8 print('predicted labels')

9 print(get_text_labels(predicted_labels.asnumpy()))

執行結果,參考下圖:

圖片描述

可以看到損失函式的計算值在一直下降(即:計算在收斂),最終的結果中紅線部分為100%預測正確的,其它一些外形相似的分類:襯衣、T恤、套頭衫、外套 這些都是"有袖子類的上衣",並沒有完全預測正確,但整體方向還是對的(即:並沒有把"上衣"識別成"鞋子"或"包包"等明顯不靠譜的分類),最終的模型、演算法及引數有待進一步提高。

來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/4822/viewspace-2809038/,如需轉載,請註明出處,否則將追究法律責任。

相關文章