動手畫混淆矩陣(Confusion Matrix)(含程式碼)

專注的阿熊發表於2022-08-10

import numpy as np

import matplotlib.pyplot as plt

class DrawConfusionMatrix:

     def __init__(self, labels_name, normalize=True):

         """

normalize :是否設元素為百分比形式

         """

         self.normalize = normalize

         self.labels_name = labels_name

         self.num_classes = len(labels_name)

         self.matrix = np.zeros((self.num_classes, self.num_classes), dtype="float32")

     def update(self, predicts, labels):

         """

         :param predicts: 一維預測向量, eg array([0,5,1,6,3,...],dtype=int64)

         :param labels:    一維標籤向量: eg array([0,5,0,6,2,...],dtype=int64)

         :return:

         """

         for predict, label in zip(predicts, labels):

             self.matrix[predict, label] += 1

     def getMatrix(self,normalize=True):

         """

         根據傳入的 normalize 判斷要進行 percent 的轉換,

         如果 normalize True ,外匯跟單gendan5.com則矩陣元素轉換為百分比形式,

         如果 normalize False ,則矩陣元素就為數量

         Returns: 返回一個以百分比或者數量為元素的矩陣

         """

         if normalize:

             per_sum = self.matrix.sum(axis=1)  # 計算每行的和,用於百分比計算

             for i in range(self.num_classes):

                 self.matrix[i] =(self.matrix[i] / per_sum[i])   # 百分比轉換

             self.matrix=np.around(self.matrix, 2)   # 保留 2 位小數點

             self.matrix[np.isnan(self.matrix)] = 0  # 可能存在 NaN ,將其設為 0

         return self.matrix

     def drawMatrix(self):

         self.matrix = self.getMatrix(self.normalize)

         plt.imshow(self.matrix, cmap=plt.cm.Blues)  # 僅畫出顏色格子,沒有值

         plt.title("Normalized confusion matrix")  # title

         plt.xlabel("Predict label")

         plt.ylabel("Truth label")

         plt.yticks(range(self.num_classes), self.labels_name)  # y 軸標籤

         plt.xticks(range(self.num_classes), self.labels_name, rotation=45)  # x 軸標籤

         for x in range(self.num_classes):

             for y in range(self.num_classes):

                 value = float(format('%.2f' % self.matrix[y, x]))  # 數值處理

                 plt.text(x, y, value, verticalalignment='center', horizontalalignment='center')  # 寫值

         plt.tight_layout()  # 自動調整子圖引數,使之填充整個影像區域

         plt.colorbar()  # 色條

         plt.savefig('./ConfusionMatrix.png', bbox_inches='tight')  # bbox_inches='tight' 可確保標籤資訊顯示全

         plt.show()


來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69946337/viewspace-2909904/,如需轉載,請註明出處,否則將追究法律責任。

相關文章