Understand the PyTorch network-building workflow + multi-class evaluation metrics in one post

Published by 西西嘛呦 on 2021-05-16

A note up front: I originally wanted to use a simple multi-layer perceptron to experiment with different optimization methods, but while writing I ended up studying evaluation metrics first. I have written about this before: https://www.cnblogs.com/xiximayou/p/13700934.html
Unlike that earlier post, this time we add some new implementations. Let's go through them step by step.

The full workflow of building a multi-layer perceptron classifier with PyTorch

Import the required packages

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc

Set the random seeds

Setting random seeds is always worth doing: it makes the experiment reproducible, i.e., the randomly initialized data produce the same results on every run.

np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

Load the data

We use the simple digits dataset that ships with sklearn:

print("載入資料")
digits = load_digits()
data, label = digits.data, digits.target
# print(data.shape, label.shape)
train_data, test_data, train_label, test_label = train_test_split(data, label, test_size=.3, random_state=123)
print('訓練資料:', train_data.shape)
print('測試資料:', test_data.shape)

Define the parameters

print("定義相關引數")
epochs = 30
batch_size = train_data.shape[0]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 
input_dim = data.shape[1]
hidden_dim = 256
output_dim = len(set(label))

Build the dataset

To build a dataset in PyTorch, we can implement our own class that inherits from Dataset and overrides the __len__ and __getitem__ methods.

print("構建資料集")
class DigitsDataset(Dataset):
  def __init__(self, input_data, input_label):
    data = []
    for i,j in zip(input_data, input_label):
      data.append((i,j))
    self.data = data

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index):
    d, l = self.data[index]
    return d, l

In __init__ we put each sample together with its label into a list; __len__ returns the total number of samples, and __getitem__ fetches a single sample by its index. Next we use DataLoader to turn the datasets we defined into data loaders.

trainDataset = DigitsDataset(train_data, train_label)
testDataset = DigitsDataset(test_data, test_label)
# print(trainDataset[0])
trainDataLoader = DataLoader(trainDataset, batch_size=batch_size, shuffle=True, num_workers=2)
testDataLoader = DataLoader(testDataset, batch_size=batch_size, shuffle=False, num_workers=2)

Define the model

Here we simply implement a multi-layer perceptron:

class Model(nn.Module):
  def __init__(self, input_dim, hidden_dim, output_dim):
    super(Model, self).__init__()
    self.fc1 = nn.Linear(input_dim, hidden_dim) 
    self.relu = nn.ReLU()
    self.fc2 = nn.Linear(hidden_dim, output_dim)

  def forward(self, x):
    x = self.fc1(x)
    x = self.relu(x)
    x = self.fc2(x)
    return x

Define the loss function and optimizer, and initialize the parameters

model = Model(input_dim, hidden_dim, output_dim)
print(model)
model.to(device)

print("定義損失函式、優化器")
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

print("初始化相關引數")
for param in model.parameters():
  nn.init.normal_(param, mean=0, std=0.01)

Training and testing

Here we simply use sklearn's built-in metric functions: accuracy_score (accuracy), precision_score (precision), recall_score (recall), f1_score (F1), classification_report (per-class report) and confusion_matrix (confusion matrix). How exactly they are used is easiest to see directly in the code.
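Before we get to the training loop, here is a minimal toy sketch of how these functions are called (the labels below are made up purely for illustration, and zero_division=0 simply silences warnings for classes that never occur in the true labels):

from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, classification_report, confusion_matrix)

# toy labels, purely for illustration
y_true = [1, 2, 1, 3, 2, 1, 1]
y_pred = [0, 1, 1, 3, 2, 2, 1]

print(accuracy_score(y_true, y_pred))                    # fraction of exact matches
print(precision_score(y_true, y_pred, average='micro'))  # micro-averaged precision
print(recall_score(y_true, y_pred, average='micro'))     # micro-averaged recall
print(f1_score(y_true, y_pred, average='micro'))         # micro-averaged F1
print(confusion_matrix(y_true, y_pred))                  # rows = true labels, columns = predicted labels
print(classification_report(y_true, y_pred, zero_division=0))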

print("開始訓練主迴圈")
total_step = len(trainDataLoader)

model.train()
for epoch in range(epochs):
  tot_loss = 0.0
  tot_acc = 0.0
  train_preds = []
  train_trues = []
  # model.train()
  for i,(train_data_batch, train_label_batch) in enumerate(trainDataLoader):
    train_data_batch = train_data_batch.float().to(device) # cast the double input to float
    train_label_batch = train_label_batch.to(device)
    outputs = model(train_data_batch)
    # _, preds = torch.max(outputs.data, 1)
    loss = criterion(outputs, train_label_batch)
    # print(loss)
    # backward pass and parameter update
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # accumulate the loss over steps
    tot_loss += loss.item()
    train_outputs = outputs.argmax(dim=1)

    train_preds.extend(train_outputs.detach().cpu().numpy())
    train_trues.extend(train_label_batch.detach().cpu().numpy())

    # tot_acc += (outputs.argmax(dim=1) == train_label_batch).sum().item()

  sklearn_accuracy = accuracy_score(train_trues, train_preds) 
  sklearn_precision = precision_score(train_trues, train_preds, average='micro')
  sklearn_recall = recall_score(train_trues, train_preds, average='micro')
  sklearn_f1 = f1_score(train_trues, train_preds, average='micro')
  print("[sklearn_metrics] Epoch:{} loss:{:.4f} accuracy:{:.4f} precision:{:.4f} recall:{:.4f} f1:{:.4f}".format(epoch, tot_loss, sklearn_accuracy, sklearn_precision, sklearn_recall, sklearn_f1))

test_preds = []
test_trues = []
model.eval()
with torch.no_grad():
  for i,(test_data_batch, test_data_label) in enumerate(testDataLoader):
    test_data_batch = test_data_batch.float().to(device) # cast the double input to float
    test_data_label = test_data_label.to(device)
    test_outputs = model(test_data_batch)
    test_outputs = test_outputs.argmax(dim=1)
    test_preds.extend(test_outputs.detach().cpu().numpy())
    test_trues.extend(test_data_label.detach().cpu().numpy())

  sklearn_accuracy = accuracy_score(test_trues, test_preds)
  sklearn_precision = precision_score(test_trues, test_preds, average='micro')
  sklearn_recall = recall_score(test_trues, test_preds, average='micro')
  sklearn_f1 = f1_score(test_trues, test_preds, average='micro')
  print(classification_report(test_trues, test_preds))
  conf_matrix = get_confusion_matrix(test_trues, test_preds)
  print(conf_matrix)
  plot_confusion_matrix(conf_matrix)
  print("[sklearn_metrics] accuracy:{:.4f} precision:{:.4f} recall:{:.4f} f1:{:.4f}".format(sklearn_accuracy, sklearn_precision, sklearn_recall, sklearn_f1))

Define and plot the confusion matrix

As an extra, here are the two helpers called in the test loop above for computing and plotting the confusion matrix.

def get_confusion_matrix(trues, preds):
  labels = [0,1,2,3,4,5,6,7,8,9]
  conf_matrix = confusion_matrix(trues, preds, labels=labels)
  return conf_matrix
  
def plot_confusion_matrix(conf_matrix):
  plt.imshow(conf_matrix, cmap=plt.cm.Greens)
  indices = range(conf_matrix.shape[0])
  labels = [0,1,2,3,4,5,6,7,8,9]
  plt.xticks(indices, labels)
  plt.yticks(indices, labels)
  plt.colorbar()
  plt.xlabel('y_pred')
  plt.ylabel('y_true')
  # write the count of each cell onto the heatmap
  for first_index in range(conf_matrix.shape[0]):
    for second_index in range(conf_matrix.shape[1]):
      # x is the column (predicted label), y is the row (true label)
      plt.text(second_index, first_index, conf_matrix[first_index, second_index])
  plt.savefig('heatmap_confusion_matrix.jpg')
  plt.show()

Results

Loading data
Training data: (1257, 64)
Test data: (540, 64)
Defining parameters
Building datasets
Defining evaluation metrics
Defining the model
Model(
  (fc1): Linear(in_features=64, out_features=256, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)
Defining loss function and optimizer
Initializing parameters
Starting the main training loop
[sklearn_metrics] Epoch:0 loss:2.2986 accuracy:0.1098 precision:0.1098 recall:0.1098 f1:0.1098
[sklearn_metrics] Epoch:1 loss:2.2865 accuracy:0.1225 precision:0.1225 recall:0.1225 f1:0.1225
[sklearn_metrics] Epoch:2 loss:2.2637 accuracy:0.1702 precision:0.1702 recall:0.1702 f1:0.1702
[sklearn_metrics] Epoch:3 loss:2.2316 accuracy:0.3174 precision:0.3174 recall:0.3174 f1:0.3174
[sklearn_metrics] Epoch:4 loss:2.1915 accuracy:0.5561 precision:0.5561 recall:0.5561 f1:0.5561
[sklearn_metrics] Epoch:5 loss:2.1438 accuracy:0.6881 precision:0.6881 recall:0.6881 f1:0.6881
[sklearn_metrics] Epoch:6 loss:2.0875 accuracy:0.7669 precision:0.7669 recall:0.7669 f1:0.7669
[sklearn_metrics] Epoch:7 loss:2.0213 accuracy:0.8226 precision:0.8226 recall:0.8226 f1:0.8226
[sklearn_metrics] Epoch:8 loss:1.9428 accuracy:0.8409 precision:0.8409 recall:0.8409 f1:0.8409
[sklearn_metrics] Epoch:9 loss:1.8494 accuracy:0.8552 precision:0.8552 recall:0.8552 f1:0.8552
[sklearn_metrics] Epoch:10 loss:1.7397 accuracy:0.8568 precision:0.8568 recall:0.8568 f1:0.8568
[sklearn_metrics] Epoch:11 loss:1.6140 accuracy:0.8632 precision:0.8632 recall:0.8632 f1:0.8632
[sklearn_metrics] Epoch:12 loss:1.4748 accuracy:0.8616 precision:0.8616 recall:0.8616 f1:0.8616
[sklearn_metrics] Epoch:13 loss:1.3259 accuracy:0.8640 precision:0.8640 recall:0.8640 f1:0.8640
[sklearn_metrics] Epoch:14 loss:1.1735 accuracy:0.8703 precision:0.8703 recall:0.8703 f1:0.8703
[sklearn_metrics] Epoch:15 loss:1.0245 accuracy:0.8791 precision:0.8791 recall:0.8791 f1:0.8791
[sklearn_metrics] Epoch:16 loss:0.8858 accuracy:0.8878 precision:0.8878 recall:0.8878 f1:0.8878
[sklearn_metrics] Epoch:17 loss:0.7625 accuracy:0.9006 precision:0.9006 recall:0.9006 f1:0.9006
[sklearn_metrics] Epoch:18 loss:0.6575 accuracy:0.9045 precision:0.9045 recall:0.9045 f1:0.9045
[sklearn_metrics] Epoch:19 loss:0.5709 accuracy:0.9077 precision:0.9077 recall:0.9077 f1:0.9077
[sklearn_metrics] Epoch:20 loss:0.5004 accuracy:0.9093 precision:0.9093 recall:0.9093 f1:0.9093
[sklearn_metrics] Epoch:21 loss:0.4436 accuracy:0.9101 precision:0.9101 recall:0.9101 f1:0.9101
[sklearn_metrics] Epoch:22 loss:0.3982 accuracy:0.9109 precision:0.9109 recall:0.9109 f1:0.9109
[sklearn_metrics] Epoch:23 loss:0.3615 accuracy:0.9149 precision:0.9149 recall:0.9149 f1:0.9149
[sklearn_metrics] Epoch:24 loss:0.3314 accuracy:0.9173 precision:0.9173 recall:0.9173 f1:0.9173
[sklearn_metrics] Epoch:25 loss:0.3065 accuracy:0.9196 precision:0.9196 recall:0.9196 f1:0.9196
[sklearn_metrics] Epoch:26 loss:0.2856 accuracy:0.9228 precision:0.9228 recall:0.9228 f1:0.9228
[sklearn_metrics] Epoch:27 loss:0.2673 accuracy:0.9236 precision:0.9236 recall:0.9236 f1:0.9236
[sklearn_metrics] Epoch:28 loss:0.2512 accuracy:0.9268 precision:0.9268 recall:0.9268 f1:0.9268
[sklearn_metrics] Epoch:29 loss:0.2370 accuracy:0.9300 precision:0.9300 recall:0.9300 f1:0.9300
              precision    recall  f1-score   support

           0       0.98      0.98      0.98        59
           1       0.86      0.86      0.86        56
           2       0.98      0.91      0.94        53
           3       0.98      0.93      0.96        46
           4       0.95      0.97      0.96        61
           5       0.98      0.91      0.95        57
           6       0.96      0.96      0.96        57
           7       0.92      0.98      0.95        50
           8       0.87      0.81      0.84        48
           9       0.77      0.91      0.83        53

    accuracy                           0.92       540
   macro avg       0.93      0.92      0.92       540
weighted avg       0.93      0.92      0.92       540

[[58  0  0  0  1  0  0  0  0  0]
 [ 0 48  0  0  0  0  1  0  0  7]
 [ 0  2 48  0  0  0  0  1  2  0]
 [ 0  0  1 43  0  0  0  1  1  0]
 [ 0  0  0  0 59  0  0  1  1  0]
 [ 0  0  0  0  1 52  0  0  0  4]
 [ 1  1  0  0  0  0 55  0  0  0]
 [ 0  0  0  0  0  0  0 49  0  1]
 [ 0  4  0  0  1  1  1  0 39  2]
 [ 0  1  0  1  0  0  0  1  2 48]]
<Figure size 640x480 with 2 Axes>
[sklearn_metrics] accuracy:0.9241 precision:0.9241 recall:0.9241 f1:0.9241

[Figure: confusion-matrix heatmap (heatmap_confusion_matrix.jpg)]

Evaluation metrics: accuracy, precision, recall and F1

(1) Basics
Earlier we went through the whole pipeline of loading data with PyTorch, building the model, training and testing, and evaluating with sklearn; now let's look more closely at the evaluation metrics. First, a rough description of the four basic metrics (for multi-class classification):
accuracy: the fraction of samples that are identified correctly, measured over the whole set. For example, with
predicted labels [0,1,1,3,2,2,1],
true labels [1,2,1,3,2,1,1],
the accuracy is 4 / 7 = 0.5714: 4 is the number of positions where the two lists agree (indices 2, 3, 4 and 6), and 7 is the total length of the lists.
precision: among the samples predicted as positive, the fraction that really is positive. Take class 1 as the positive class and set every other label to 0 (here 0 means "negative", not class 0):
predicted labels [0,1,1,0,0,0,1]
true labels [1,0,1,0,0,1,1]
Three samples are predicted positive (value 1), and 2 of them agree with the true labels (indices 2 and 6), so the precision is 2 / 3 = 0.6667.
recall: among the truly positive samples, the fraction that is recognized. Again for class 1: the true labels contain 4 positives, of which 2 are predicted correctly, so the recall is 2 / 4 = 0.5000.
f1: combines precision and recall: 2 * p * r / (p + r).
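A quick sanity check of these hand-computed numbers with sklearn (a sketch using the toy lists above; the default average='binary' with pos_label=1 treats class 1 as the positive class):

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

y_pred = [0, 1, 1, 3, 2, 2, 1]
y_true = [1, 2, 1, 3, 2, 1, 1]
print(accuracy_score(y_true, y_pred))  # 4 / 7 = 0.5714...

# binarize against class 1, as in the text: 1 -> positive, everything else -> negative
y_pred_bin = [1 if p == 1 else 0 for p in y_pred]  # [0, 1, 1, 0, 0, 0, 1]
y_true_bin = [1 if t == 1 else 0 for t in y_true]  # [1, 0, 1, 0, 0, 1, 1]
print(precision_score(y_true_bin, y_pred_bin))  # 2 / 3 = 0.6667
print(recall_score(y_true_bin, y_pred_bin))     # 2 / 4 = 0.5
print(f1_score(y_true_bin, y_pred_bin))         # 2 * p * r / (p + r) = 0.5714...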
(2) Computing the metrics from TP, FP, FN and TN
The building blocks are TP, FP, FN and TN; let's spell out what they mean:
the first letter is True/False and judges whether the prediction is right; the second letter is Positive/Negative and is the predicted class.
TP, True Positive: predicted positive and actually positive.
FP, False Positive: predicted positive but actually negative.
FN, False Negative: predicted negative but actually positive.
TN, True Negative: predicted negative and actually negative.
With the definitions above:
accuracy = (TP + TN) / (TP + FP + FN + TN)
precision = TP / (TP + FP)
recall = TP / (TP + FN)
f1 = 2 * precision * recall / (precision + recall)
For the class-1 example from (1): TP = 2, FP = 1, FN = 2, TN = 2, which again gives precision 2/3, recall 2/4 and accuracy (2 + 2) / 7.
(3) micro-F1 and macro-F1
In short, micro-F1 first accumulates TP, FP, FN and TN over all classes and then computes the metrics from those totals, so under class imbalance it weights each class by its number of samples. macro-F1 computes the metrics for each class first and then averages them, so every class counts equally and the result is easily swayed by classes with extreme precision or recall. Note that for single-label multi-class problems, micro-averaged precision, recall and F1 all reduce to plain accuracy, which is why those four numbers are identical in the training log above.
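To see the difference concretely, here is a small sketch with sklearn on an invented, imbalanced set of labels (purely illustrative):

from sklearn.metrics import f1_score

# imbalanced toy example: class 0 dominates, class 2 is rare and never predicted correctly
y_true = [0, 0, 0, 0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 0, 0, 0, 1, 1, 1, 0]

print(f1_score(y_true, y_pred, average='micro', zero_division=0))  # = accuracy = 7/9 ≈ 0.778
print(f1_score(y_true, y_pred, average='macro', zero_division=0))  # mean of per-class F1 ≈ 0.544
print(f1_score(y_true, y_pred, average=None, zero_division=0))     # per-class F1: [0.833, 0.8, 0.0]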

Implementing the metrics ourselves

Next, let's implement the metrics ourselves based on the understanding above.
(1) Basic implementation

def get_acc_p_r_f1(trues, preds):
  labels = [0,1,2,3,4,5,6,7,8,9]
  TP,FP,FN,TN = 0,0,0,0
  for label in labels:
    preds_tmp = np.array([1 if pred == label else 0 for pred in preds])
    trues_tmp = np.array([1 if true == label else 0 for true in trues])
    # print(preds_tmp, trues_tmp)
    # print()
    # TP: predicted 1, true 1
    # TN: predicted 0, true 0
    # FN: predicted 0, true 1
    # FP: predicted 1, true 0
    TP += ((preds_tmp == 1) & (trues_tmp == 1)).sum()
    TN += ((preds_tmp == 0) & (trues_tmp == 0)).sum()
    FN += ((preds_tmp == 0) & (trues_tmp == 1)).sum()
    FP += ((preds_tmp == 1) & (trues_tmp == 0)).sum()
  # print(TP, FP, FN)
  precision = TP / (TP + FP)
  recall = TP / (TP + FN)
  f1 = 2 * precision * recall / (precision + recall)
  return precision, recall, f1

def get_acc(trues, preds):
  accuracy = (np.array(trues) == np.array(preds)).sum() / len(trues)
  return accuracy

I won't go through it line by line; the code is easy to follow.
(2) Implementation based on the confusion matrix

def get_p_r_f1_from_conf_matrix(conf_matrix):
  TP,FP,FN,TN = 0,0,0,0
  labels = [0,1,2,3,4,5,6,7,8,9]
  nums = len(labels)
  for i in labels:
    TP += conf_matrix[i, i]
    FP += (conf_matrix[:i, i].sum() + conf_matrix[i+1:, i].sum())
    FN += (conf_matrix[i, i+1:].sum() + conf_matrix[i, :i].sum())
  print(TP, FP, FN)
  precision = TP / (TP + FP)
  recall = TP / (TP + FN)
  f1 = 2 * precision * recall / (precision + recall)
  return precision, recall, f1

def get_acc_from_conf_matrix(conf_matrix):
  labels = [0,1,2,3,4,5,6,7,8,9]
  return sum([conf_matrix[i, i] for i in range(len(labels))]) / np.sum(np.sum(conf_matrix, axis=0))

Final results

Loading data
Training data: (1257, 64)
Test data: (540, 64)
Defining parameters
Building datasets
Defining evaluation metrics
Defining the model
Model(
  (fc1): Linear(in_features=64, out_features=256, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)
Defining loss function and optimizer
Initializing parameters
Starting the main training loop
[custom_metrics] Epoch:0 loss:2.2986 accuracy:0.1098 precision:0.1098 recall:0.1098 f1:0.1098
[sklearn_metrics] Epoch:0 loss:2.2986 accuracy:0.1098 precision:0.1098 recall:0.1098 f1:0.1098
[custom_metrics] Epoch:1 loss:2.2865 accuracy:0.1225 precision:0.1225 recall:0.1225 f1:0.1225
[sklearn_metrics] Epoch:1 loss:2.2865 accuracy:0.1225 precision:0.1225 recall:0.1225 f1:0.1225
[custom_metrics] Epoch:2 loss:2.2637 accuracy:0.1702 precision:0.1702 recall:0.1702 f1:0.1702
[sklearn_metrics] Epoch:2 loss:2.2637 accuracy:0.1702 precision:0.1702 recall:0.1702 f1:0.1702
[custom_metrics] Epoch:3 loss:2.2316 accuracy:0.3174 precision:0.3174 recall:0.3174 f1:0.3174
[sklearn_metrics] Epoch:3 loss:2.2316 accuracy:0.3174 precision:0.3174 recall:0.3174 f1:0.3174
[custom_metrics] Epoch:4 loss:2.1915 accuracy:0.5561 precision:0.5561 recall:0.5561 f1:0.5561
[sklearn_metrics] Epoch:4 loss:2.1915 accuracy:0.5561 precision:0.5561 recall:0.5561 f1:0.5561
[custom_metrics] Epoch:5 loss:2.1438 accuracy:0.6881 precision:0.6881 recall:0.6881 f1:0.6881
[sklearn_metrics] Epoch:5 loss:2.1438 accuracy:0.6881 precision:0.6881 recall:0.6881 f1:0.6881
[custom_metrics] Epoch:6 loss:2.0875 accuracy:0.7669 precision:0.7669 recall:0.7669 f1:0.7669
[sklearn_metrics] Epoch:6 loss:2.0875 accuracy:0.7669 precision:0.7669 recall:0.7669 f1:0.7669
[custom_metrics] Epoch:7 loss:2.0213 accuracy:0.8226 precision:0.8226 recall:0.8226 f1:0.8226
[sklearn_metrics] Epoch:7 loss:2.0213 accuracy:0.8226 precision:0.8226 recall:0.8226 f1:0.8226
[custom_metrics] Epoch:8 loss:1.9428 accuracy:0.8409 precision:0.8409 recall:0.8409 f1:0.8409
[sklearn_metrics] Epoch:8 loss:1.9428 accuracy:0.8409 precision:0.8409 recall:0.8409 f1:0.8409
[custom_metrics] Epoch:9 loss:1.8494 accuracy:0.8552 precision:0.8552 recall:0.8552 f1:0.8552
[sklearn_metrics] Epoch:9 loss:1.8494 accuracy:0.8552 precision:0.8552 recall:0.8552 f1:0.8552
[custom_metrics] Epoch:10 loss:1.7397 accuracy:0.8568 precision:0.8568 recall:0.8568 f1:0.8568
[sklearn_metrics] Epoch:10 loss:1.7397 accuracy:0.8568 precision:0.8568 recall:0.8568 f1:0.8568
[custom_metrics] Epoch:11 loss:1.6140 accuracy:0.8632 precision:0.8632 recall:0.8632 f1:0.8632
[sklearn_metrics] Epoch:11 loss:1.6140 accuracy:0.8632 precision:0.8632 recall:0.8632 f1:0.8632
[custom_metrics] Epoch:12 loss:1.4748 accuracy:0.8616 precision:0.8616 recall:0.8616 f1:0.8616
[sklearn_metrics] Epoch:12 loss:1.4748 accuracy:0.8616 precision:0.8616 recall:0.8616 f1:0.8616
[custom_metrics] Epoch:13 loss:1.3259 accuracy:0.8640 precision:0.8640 recall:0.8640 f1:0.8640
[sklearn_metrics] Epoch:13 loss:1.3259 accuracy:0.8640 precision:0.8640 recall:0.8640 f1:0.8640
[custom_metrics] Epoch:14 loss:1.1735 accuracy:0.8703 precision:0.8703 recall:0.8703 f1:0.8703
[sklearn_metrics] Epoch:14 loss:1.1735 accuracy:0.8703 precision:0.8703 recall:0.8703 f1:0.8703
[custom_metrics] Epoch:15 loss:1.0245 accuracy:0.8791 precision:0.8791 recall:0.8791 f1:0.8791
[sklearn_metrics] Epoch:15 loss:1.0245 accuracy:0.8791 precision:0.8791 recall:0.8791 f1:0.8791
[custom_metrics] Epoch:16 loss:0.8858 accuracy:0.8878 precision:0.8878 recall:0.8878 f1:0.8878
[sklearn_metrics] Epoch:16 loss:0.8858 accuracy:0.8878 precision:0.8878 recall:0.8878 f1:0.8878
[custom_metrics] Epoch:17 loss:0.7625 accuracy:0.9006 precision:0.9006 recall:0.9006 f1:0.9006
[sklearn_metrics] Epoch:17 loss:0.7625 accuracy:0.9006 precision:0.9006 recall:0.9006 f1:0.9006
[custom_metrics] Epoch:18 loss:0.6575 accuracy:0.9045 precision:0.9045 recall:0.9045 f1:0.9045
[sklearn_metrics] Epoch:18 loss:0.6575 accuracy:0.9045 precision:0.9045 recall:0.9045 f1:0.9045
[custom_metrics] Epoch:19 loss:0.5709 accuracy:0.9077 precision:0.9077 recall:0.9077 f1:0.9077
[sklearn_metrics] Epoch:19 loss:0.5709 accuracy:0.9077 precision:0.9077 recall:0.9077 f1:0.9077
[custom_metrics] Epoch:20 loss:0.5004 accuracy:0.9093 precision:0.9093 recall:0.9093 f1:0.9093
[sklearn_metrics] Epoch:20 loss:0.5004 accuracy:0.9093 precision:0.9093 recall:0.9093 f1:0.9093
[custom_metrics] Epoch:21 loss:0.4436 accuracy:0.9101 precision:0.9101 recall:0.9101 f1:0.9101
[sklearn_metrics] Epoch:21 loss:0.4436 accuracy:0.9101 precision:0.9101 recall:0.9101 f1:0.9101
[custom_metrics] Epoch:22 loss:0.3982 accuracy:0.9109 precision:0.9109 recall:0.9109 f1:0.9109
[sklearn_metrics] Epoch:22 loss:0.3982 accuracy:0.9109 precision:0.9109 recall:0.9109 f1:0.9109
[custom_metrics] Epoch:23 loss:0.3615 accuracy:0.9149 precision:0.9149 recall:0.9149 f1:0.9149
[sklearn_metrics] Epoch:23 loss:0.3615 accuracy:0.9149 precision:0.9149 recall:0.9149 f1:0.9149
[custom_metrics] Epoch:24 loss:0.3314 accuracy:0.9173 precision:0.9173 recall:0.9173 f1:0.9173
[sklearn_metrics] Epoch:24 loss:0.3314 accuracy:0.9173 precision:0.9173 recall:0.9173 f1:0.9173
[custom_metrics] Epoch:25 loss:0.3065 accuracy:0.9196 precision:0.9196 recall:0.9196 f1:0.9196
[sklearn_metrics] Epoch:25 loss:0.3065 accuracy:0.9196 precision:0.9196 recall:0.9196 f1:0.9196
[custom_metrics] Epoch:26 loss:0.2856 accuracy:0.9228 precision:0.9228 recall:0.9228 f1:0.9228
[sklearn_metrics] Epoch:26 loss:0.2856 accuracy:0.9228 precision:0.9228 recall:0.9228 f1:0.9228
[custom_metrics] Epoch:27 loss:0.2673 accuracy:0.9236 precision:0.9236 recall:0.9236 f1:0.9236
[sklearn_metrics] Epoch:27 loss:0.2673 accuracy:0.9236 precision:0.9236 recall:0.9236 f1:0.9236
[custom_metrics] Epoch:28 loss:0.2512 accuracy:0.9268 precision:0.9268 recall:0.9268 f1:0.9268
[sklearn_metrics] Epoch:28 loss:0.2512 accuracy:0.9268 precision:0.9268 recall:0.9268 f1:0.9268
[custom_metrics] Epoch:29 loss:0.2370 accuracy:0.9300 precision:0.9300 recall:0.9300 f1:0.9300
[sklearn_metrics] Epoch:29 loss:0.2370 accuracy:0.9300 precision:0.9300 recall:0.9300 f1:0.9300
              precision    recall  f1-score   support

           0       0.98      0.98      0.98        59
           1       0.86      0.86      0.86        56
           2       0.98      0.91      0.94        53
           3       0.98      0.93      0.96        46
           4       0.95      0.97      0.96        61
           5       0.98      0.91      0.95        57
           6       0.96      0.96      0.96        57
           7       0.92      0.98      0.95        50
           8       0.87      0.81      0.84        48
           9       0.77      0.91      0.83        53

    accuracy                           0.92       540
   macro avg       0.93      0.92      0.92       540
weighted avg       0.93      0.92      0.92       540

[[58  0  0  0  1  0  0  0  0  0]
 [ 0 48  0  0  0  0  1  0  0  7]
 [ 0  2 48  0  0  0  0  1  2  0]
 [ 0  0  1 43  0  0  0  1  1  0]
 [ 0  0  0  0 59  0  0  1  1  0]
 [ 0  0  0  0  1 52  0  0  0  4]
 [ 1  1  0  0  0  0 55  0  0  0]
 [ 0  0  0  0  0  0  0 49  0  1]
 [ 0  4  0  0  1  1  1  0 39  2]
 [ 0  1  0  1  0  0  0  1  2 48]]
<Figure size 640x480 with 2 Axes>
[custom_metrics] accuracy:0.9241 precision:0.9241 recall:0.9241 f1:0.9241
[sklearn_metrics] accuracy:0.9241 precision:0.9241 recall:0.9241 f1:0.9241
[cm_metrics] accuracy:0.9241 precision:0.9241 recall:0.9241 f1:0.9241

The metrics we computed ourselves match sklearn's. To make sure they are really correct, let's also print the per-class precision, recall and F1 on the test set.

[custom_metrics] 0 precision:0.9831 recall:0.9831 f1:0.9831
[custom_metrics] 1 precision:0.8571 recall:0.8571 f1:0.8571
[custom_metrics] 2 precision:0.9796 recall:0.9057 f1:0.9412
[custom_metrics] 3 precision:0.9773 recall:0.9348 f1:0.9556
[custom_metrics] 4 precision:0.9516 recall:0.9672 f1:0.9593
[custom_metrics] 5 precision:0.9811 recall:0.9123 f1:0.9455
[custom_metrics] 6 precision:0.9649 recall:0.9649 f1:0.9649
[custom_metrics] 7 precision:0.9245 recall:0.9800 f1:0.9515
[custom_metrics] 8 precision:0.8667 recall:0.8125 f1:0.8387
[custom_metrics] 9 precision:0.7742 recall:0.9057 f1:0.8348
[cm_metrics] 0 precision:0.9831 recall:0.9831 f1:0.9831
[cm_metrics] 1 precision:0.8571 recall:0.8571 f1:0.8571
[cm_metrics] 2 precision:0.9796 recall:0.9057 f1:0.9412
[cm_metrics] 3 precision:0.9773 recall:0.9348 f1:0.9556
[cm_metrics] 4 precision:0.9516 recall:0.9672 f1:0.9593
[cm_metrics] 5 precision:0.9811 recall:0.9123 f1:0.9455
[cm_metrics] 6 precision:0.9649 recall:0.9649 f1:0.9649
[cm_metrics] 7 precision:0.9245 recall:0.9800 f1:0.9515
[cm_metrics] 8 precision:0.8667 recall:0.8125 f1:0.8387
[cm_metrics] 9 precision:0.7742 recall:0.9057 f1:0.8348

These match sklearn's classification_report.
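The per-class numbers tagged [cm_metrics] come from the confusion matrix. The exact helper is not shown in this post, so here is a minimal sketch of what it could look like (the function name is ours, and it assumes sklearn's convention that rows are true labels and columns are predictions), reusing the conf_matrix computed in the test block:

def get_per_class_p_r_f1_from_conf_matrix(conf_matrix):
  for i in range(conf_matrix.shape[0]):
    TP = conf_matrix[i, i]
    FP = conf_matrix[:, i].sum() - TP  # predicted as class i, actually another class
    FN = conf_matrix[i, :].sum() - TP  # actually class i, predicted as another class
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0.0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    print("[cm_metrics] {} precision:{:.4f} recall:{:.4f} f1:{:.4f}".format(i, precision, recall, f1))

get_per_class_p_r_f1_from_conf_matrix(conf_matrix)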

Plot the ROC curves and compute the AUC

Last of all, we plot the ROC curves and compute the AUC. I'll be lazy here and skip introducing these two metrics; the labels first need to be binarized:

from itertools import cycle  # used below for the per-class plot colors
from numpy import interp     # used below to interpolate the macro-average curve

def get_roc_auc(trues, preds):
  labels = [0,1,2,3,4,5,6,7,8,9]
  nb_classes = len(labels)
  fpr = dict()
  tpr = dict()
  roc_auc = dict()
  # print(trues, preds)
  for i in range(nb_classes):
    fpr[i], tpr[i], _ = roc_curve(trues[:, i], preds[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
  # Compute micro-average ROC curve and ROC area
  fpr["micro"], tpr["micro"], _ = roc_curve(trues.ravel(), preds.ravel())
  roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
  # First aggregate all false positive rates
  all_fpr = np.unique(np.concatenate([fpr[i] for i in range(nb_classes)]))
  # Then interpolate all ROC curves at these points
  mean_tpr = np.zeros_like(all_fpr)
  for i in range(nb_classes):
    mean_tpr += interp(all_fpr, fpr[i], tpr[i])
  # Finally average it and compute AUC
  mean_tpr /= nb_classes
  fpr["macro"] = all_fpr
  tpr["macro"] = mean_tpr
  roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
  # Plot all ROC curves
  lw = 2
  plt.figure()
  plt.plot(fpr["micro"], tpr["micro"],label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc["micro"]),color='deeppink', linestyle=':', linewidth=4)
  plt.plot(fpr["macro"], tpr["macro"],label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc["macro"]),color='navy', linestyle=':', linewidth=4)
  colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
  for i, color in zip(range(nb_classes), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw, label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))
  plt.plot([0, 1], [0, 1], 'k--', lw=lw)
  plt.xlim([0.0, 1.0])
  plt.ylim([0.0, 1.05])
  plt.xlabel('False Positive Rate')
  plt.ylabel('True Positive Rate')
  plt.title('Some extension of Receiver operating characteristic to multi-class')
  plt.legend(loc="lower right")
  plt.savefig("ROC_10分類.png")
  plt.show()
  
test_trues = label_binarize(test_trues, classes=[i for i in range(10)])
test_preds = label_binarize(test_preds, classes=[i for i in range(10)])
get_roc_auc(test_trues, test_preds)
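One caveat: the curves here are built from binarized hard predictions, so each per-class ROC has only one real operating point. ROC and AUC are normally computed from continuous scores; below is a sketch of how the test loop could collect softmax probabilities instead (test_scores and test_trues_int are our own names, not part of the original script):

test_scores, test_trues_int = [], []
model.eval()
with torch.no_grad():
  for test_data_batch, test_data_label in testDataLoader:
    test_data_batch = test_data_batch.float().to(device)
    probs = torch.softmax(model(test_data_batch), dim=1)  # one probability column per class
    test_scores.extend(probs.cpu().numpy())
    test_trues_int.extend(test_data_label.numpy())

trues_bin = label_binarize(test_trues_int, classes=list(range(10)))
get_roc_auc(trues_bin, np.array(test_scores))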

[Figure: micro-average, macro-average and per-class ROC curves (ROC_10分類.png)]
