機器學習筆記-多類邏輯迴歸
仍然是 動手學嘗試學習系列的筆記,原文見:多類邏輯迴歸 — 從0開始 。 這篇的主要目的,是從一堆服飾圖片中,透過機器學習識別出每個服飾圖片對應的分類是什麼(比如:一個看起來象短袖上衣的圖片,應該歸類到T-Shirt分類)
示例程式碼如下,這篇的程式碼略複雜,分成幾個步驟解讀:
一、下載資料,並顯示圖片及標籤
1 from mxnet import gluon
2 from mxnet import ndarray as nd
3 import matplotlib.pyplot as plt
4 import mxnet as mx
5 from mxnet import autograd
6
7 def transform(data, label):
8 return data.astype('float32')/255, label.astype('float32')
9
10 #訓練資料集(需聯網下載,網速慢時,會很卡)
11 mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
12
13 #測試資料集(需聯網下載)
14 mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)
15
16 # data, label = mnist_train[0]
17 # ('example shape: ', data.shape, 'label:', label)
18
19 #顯示服飾圖片
20 def show_images(images):
21 n = images.shape[0]
22 _, figs = plt.subplots(1, n, figsize=(15, 15))
23 for i in range(n):
24 figs[i].imshow(images[i].reshape((28, 28)).asnumpy())
25 figs[i].axes.get_xaxis().set_visible(False)
26 figs[i].axes.get_yaxis().set_visible(False)
27 plt.show()
28
29 #獲取圖片對應分類標籤文字
30 def get_text_labels(label):
31 text_labels = [
32 'T 恤', '長 褲', '套頭衫', '裙 子', '外 套',
33 '涼 鞋', '襯 衣', '運動鞋', '包 包', '短 靴'
34 ]
35 return [text_labels[int(i)] for i in label]
36
37 #下面這些程式碼,用於輔助大家理解示例圖片資料集內部結構
38 # tup1 = mnist_train[0:1] #取出訓練集的第1個樣本
39 # print(type(tup1)) #
40 # print(len(tup1)) #2 有2個元素
41 # print(type(tup1[0])) #
42 # print(type(tup1[1])) #
43 # print(tup1[0].shape) #(1, 28, 28, 1) 第1個元素是一個四維矩陣,用來儲存每張圖中的畫素點對應的值,最後1維表示RGB通道,這裡只取了1個通道
44 # print(tup1[1].shape) #(1,) 第2個元素用於表示圖片對應的文字分類的索引值
45 # print(tup1[0]) #列印第1個元素(即:四維矩陣的值),
46 # print(tup1[1]) #[2.],列印第2個元素(即:該圖片對應的分類索引數值)
47 # print(get_text_labels(tup1[1])) #顯示分類索引值對應的文字['pullover']
48
49 #取出訓練集中的圖片資料,以及圖片標籤索引值
50 data, label = mnist_train[0:10]
51
52 #列印資料集的相關資訊
53 print('example shape: ', data.shape, 'label:', label)
54
55 #顯示圖片
56 show_images(data)
57
58 #列印圖片分類標籤
59 print(get_text_labels(label))
首次執行時,可能會很久都沒有反應,讓人誤以為程式碼有問題,其實背後在聯網下載資料,去睡會兒,等醒來的時候,估計就下載好了~_~,下載的資料會儲存在~/.mxnet/datasets/fashion-mnist目錄(mac環境):
下載完成後,上面的程式碼會將圖片資料解析並顯示出來,類似下面這樣:
二、讀取資料並初始化引數
1 #批次讀取資料
2 batch_size = 256
3 #訓練集
4 train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)
5 #測試集
6 test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)
7
8 #每張圖片的畫素用向量表示,就是28*28的長度,即:784
9 num_inputs = 784
10 #要預測10張圖片,即:輸出結果長度為10的向量
11 num_outputs = 10
12
13 #初始化權重W、偏置b引數矩陣
14 W = nd.random_normal(shape=(num_inputs, num_outputs))
15 b = nd.random_normal(shape=num_outputs)
16
17 params = [W, b]
18
19 #附加梯度,方便後面用梯度下降法計算
20 for param in params:
21 param.attach_grad()
這與之前的 機器學習筆記(1):線性迴歸 很類似,不再重複解釋
三、建立模型
1 #歸一化函式
2 def softmax(X):
3 exp = nd.exp(X)
4 partition = exp.sum(axis=1, keepdims=True)
5 return exp / partition
6
7 #計算模型(仍然是類似y=w.x+b的方程)
8 def net(X):
9 return softmax(nd.dot(X.reshape((-1, num_inputs)), W) + b)
10
11 #損失函式(使用交叉熵函式)
12 def cross_entropy(yhat, y):
13 return - nd.pick(nd.log(yhat), y)
14
15 #梯度下降法
16 def SGD(params, lr):
17 for param in params:
18 param[:] = param - lr * param.grad
其中softmax(歸一化)及交叉熵cross_entropy,詳情可參考上篇:歸一化(softmax)、資訊熵、交叉熵
四、如何評估準確度
1 #計算準確度
2 def accuracy(output, label):
3 return nd.mean(output.argmax(axis=1) == label).asscalar()
4
5 def _get_batch(batch):
6 if isinstance(batch, mx.io.DataBatch):
7 data = batch.data[0]
8 label = batch.label[0]
9 else:
10 data, label = batch
11 return data, label
12
13 #評估準確度
14 def evaluate_accuracy(data_iterator, net):
15 acc = 0.
16 if isinstance(data_iterator, mx.io.MXDataIter):
17 data_iterator.reset()
18 for i, batch in enumerate(data_iterator):
19 data, label = _get_batch(batch)
20 output = net(data)
21 acc += accuracy(output, label)
22 return acc / (i+1)
機器學習的效果如何,通常要有一個評價值,上面的函式就是用來估計演算法和模型準確度的。
注: 這裡面用到了二個新的函式mean,argmax 解釋一下
mean類似sql中的avg函式,就是求平均值,即把一個矩陣的所有元數加起來,然後除以元數個數
+ View Code?
123 |
from mxnet import ndarray as nd x = nd.array([ 1 , 2 , 3 , 4 , 5 , 6 ]); print (x,x.mean(),( 1 + 2 + 3 + 4 + 5 + 6 ) / 6.0 )
|
輸出如下:
[ 1. 2. 3. 4. 5. 6.]
[ 3.5]
而argmax,是找出(指定軸向)最大值的索引下標
from mxnet import ndarray as nd
x = nd.array([1,4,7,3,6])
print(x.argmax(axis=0))
輸出為[ 2.],即:第3列數字7最大。再來個多維矩陣的
如上圖,多維矩陣時,如果指定axis=0,表示軸的方向是縱向(自上而下),顯然第1列中的最大值7在第2行(即:row_index是1),第2列的最大值9在第3行(即:row_index=2),類推第3列的最大值8在第1行(row_index=0),最終輸出的結果就是[1, 2, 0]
如果把axis指定為1,則軸的方向為橫向(自左向右),如下圖:
axis為1時,輸出的索引,為列下標(即:第幾列),顯然8在第2列,7在第0列,9在第1列。
現在我們來想一下:為啥argmax結合mean這二個函式,可以用來評估準確度?
答案:預測的結果也是一個矩陣,通常預測對了,該元素值為1,預測錯誤則為0。
如上圖,假如有3個指標,預測對了2個,第三行,一個都沒預測對,那麼準確率為2/3,即0.6666左右
五、訓練
1 #學習率
2 learning_rate = .1
3
4 #開始訓練
5 for epoch in range(5):
6 train_loss = 0.
7 train_acc = 0.
8 for data, label in train_data:
9 with autograd.record():
10 output = net(data)
11 loss = cross_entropy(output, label)
12 loss.backward()
13 SGD(params, learning_rate / batch_size)
14 train_loss += nd.mean(loss).asscalar()
15 train_acc += accuracy(output, label)
16
17 test_acc = evaluate_accuracy(test_data, net)
18 print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
19 epoch, train_loss / len(train_data), train_acc / len(train_data), test_acc))
訓練過程與之前的機器學習筆記(1):線性迴歸 套路一樣,參看之前的即可。
六、顯示預測結果
1 #顯示結果
2 data, label = mnist_test[0:10]
3 show_images(data)
4 print('true labels')
5 print(get_text_labels(label))
6
7 predicted_labels = net(data).argmax(axis=1)
8 print('predicted labels')
9 print(get_text_labels(predicted_labels.asnumpy()))
執行結果,參考下圖:
可以看到損失函式的計算值在一直下降(即:計算在收斂),最終的結果中紅線部分為100%預測正確的,其它一些外形相似的分類:襯衣、T恤、套頭衫、外套 這些都是"有袖子類的上衣",並沒有完全預測正確,但整體方向還是對的(即:並沒有把"上衣"識別成"鞋子"或"包包"等明顯不靠譜的分類),最終的模型、演算法及引數有待進一步提高。
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/4822/viewspace-2809038/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 機器學習:邏輯迴歸機器學習邏輯迴歸
- 機器學習之邏輯迴歸機器學習邏輯迴歸
- 機器學習整理(邏輯迴歸)機器學習邏輯迴歸
- 機器學習 | 線性迴歸與邏輯迴歸機器學習邏輯迴歸
- 學習筆記——機器學習演算法(一): 基於邏輯迴歸的分類預測筆記機器學習演算法邏輯迴歸
- 人工智慧-機器學習-邏輯迴歸人工智慧機器學習邏輯迴歸
- 【機器學習基礎】邏輯迴歸——LogisticRegression機器學習邏輯迴歸
- 機器學習之邏輯迴歸:計算機率機器學習邏輯迴歸計算機
- 【機器學習】邏輯迴歸過程推導機器學習邏輯迴歸
- 機器學習之邏輯迴歸:計算概率機器學習邏輯迴歸
- 機器學習之邏輯迴歸:模型訓練機器學習邏輯迴歸模型
- 機器學習之使用Python完成邏輯迴歸機器學習Python邏輯迴歸
- 【6%】100小時機器學習——邏輯迴歸機器學習邏輯迴歸
- 從零開始學機器學習——邏輯迴歸機器學習邏輯迴歸
- 機器學習入門 - 快速掌握邏輯迴歸模型機器學習邏輯迴歸模型
- 手擼機器學習演算法 - 邏輯迴歸機器學習演算法邏輯迴歸
- 機器學習筆記(2): Logistic 迴歸機器學習筆記
- 機器學習演算法--邏輯迴歸原理介紹機器學習演算法邏輯迴歸
- 機器學習演算法(一): 基於邏輯迴歸的分類預測機器學習演算法邏輯迴歸
- 機器學習(三):理解邏輯迴歸及二分類、多分類程式碼實踐機器學習邏輯迴歸
- 機器學習演算法:Logistic迴歸學習筆記機器學習演算法筆記
- 機器學習(六):迴歸分析——鳶尾花多變數回歸、邏輯迴歸三分類只用numpy,sigmoid、實現RANSAC 線性擬合機器學習變數邏輯迴歸Sigmoid
- 【火爐煉AI】機器學習009-用邏輯迴歸分類器解決多分類問題AI機器學習邏輯迴歸
- [DataAnalysis]機器學習演算法——線性模型(邏輯迴歸+LDA)機器學習演算法模型邏輯迴歸LDA
- 從零開始學習邏輯迴歸邏輯迴歸
- 吳恩達機器學習筆記 —— 5 多變數線性迴歸吳恩達機器學習筆記變數
- 【《白話機器學習的數學》筆記1】迴歸機器學習筆記
- 機器學習-邏輯迴歸:從技術原理到案例實戰機器學習邏輯迴歸
- 100天搞定機器學習|Day17-18 神奇的邏輯迴歸機器學習邏輯迴歸
- 邏輯迴歸邏輯迴歸
- 吳恩達機器學習筆記 —— 7 Logistic迴歸吳恩達機器學習筆記
- 【機器學習筆記】:大話線性迴歸(二)機器學習筆記
- 【機器學習筆記】:大話線性迴歸(一)機器學習筆記
- 機器學習簡介之基礎理論- 線性迴歸、邏輯迴歸、神經網路機器學習邏輯迴歸神經網路
- 數學推導+純Python實現機器學習演算法:邏輯迴歸Python機器學習演算法邏輯迴歸
- 把ChatGPT調教成機器學習專家,以邏輯迴歸模型的學習為例ChatGPT機器學習邏輯迴歸模型
- 【機器學習】求解邏輯迴歸引數(三種方法程式碼實現)機器學習邏輯迴歸
- 《應用迴歸及分類》學習筆記1筆記