深度學習之新聞多分類問題

無風聽海 發表於 2021-04-15
深度學習

平時除了遇到二分類問題,碰到最多的就是多分類問題,例如我們釋出blogs時候選擇的tag等。如果每個樣本只關聯一個標籤則是單標籤多分類,如果每個樣本可以關聯多個樣本,則是多標籤多分類。今天我們來看下新聞的多分類問題。

一、資料集

這裡使用路透社在1986年釋出的資料集,它包含很多的短新聞及其對應的主題,它包含46個主題,是一個簡單的被廣泛使用的分類資料集。

    def load_data(self):
        return reuters.load_data(num_words=self.num_words)
        
    
    (train_data, train_labels), (test_data, test_labels) = self.load_data()
        print(len(train_data))
        print(len(test_data))
        print(train_data[0])
        print(train_labels[0])

可以看到有8982個訓練樣本及2246個測試樣本,同時也可以看到第一個訓練樣本的內容和標籤都是數字。

8982
2246
[1, 2, 2, 8, 43, 10, 447, 5, 25, 207, 270, 5, 3095, 111, 16, 369, 186, 90, 67, 7, 89, 5, 19, 102, 6, 19, 124, 15, 90, 67, 84, 22, 482, 26, 7, 48, 4, 49, 8, 864, 39, 209, 154, 6, 151, 6, 83, 11, 15, 22, 155, 11, 15, 7, 48, 9, 4579, 1005, 504, 6, 258, 6, 272, 11, 15, 22, 134, 44, 11, 15, 16, 8, 197, 1245, 90, 67, 52, 29, 209, 30, 32, 132, 6, 109, 15, 17, 12]
3

看下第一個訓練樣本的實際內容

    def get_text(self, data):
        word_id_index = reuters.get_word_index()
        id_word_index = dict([(id, value) for (value, id) in word_id_index.items()])
        return ' '.join([id_word_index.get(i - 3, '?') for i in data])
        
    
    print(self.get_text(train_data[0]))

執行後的樣本內容

? ? ? said as a result of its december acquisition of space co it expects earnings per share in 1987 of 1 15 to 1 30 dlrs per share up from 70 cts in 1986 the company said pretax net should rise to nine to 10 mln dlrs from six mln dlrs in 1986 and rental operation revenues to 19 to 22 mln dlrs from 12 5 mln dlrs it said cash flow per share this year should be 2 50 to three dlrs reuter 3

二、資料格式化

使用one-hot方式編碼訓練資料

    def vectorize_sequences(self, sequences, dimension=10000):
        results = np.zeros((len(sequences), dimension))
        for i,sequence in enumerate(sequences):
            results[i, sequence] = 1.
        return results
    
    self.x_train = x_train = self.vectorize_sequences(train_data)
    self.x_test = x_test = self.vectorize_sequences(test_data)

編碼標籤資料

    def to_one_hot(self, labels, dimension=46):
        results = np.zeros((len(labels), dimension))
        for i,label in enumerate(labels):
            results[i, label] = 1
        return results
        
    self.one_hot_train_labels = one_hot_train_labels = self.to_one_hot(train_labels)
    self.one_hot_test_labels = one_hot_test_labels = self.to_one_hot(test_labels)    

三、構建模型

這裡有46個新聞類別,所以中間層的維度不能太少,否則丟失的資訊太多,這裡我們使用64個隱藏單元。

        model = self.model = models.Sequential()
        model.add(layers.Dense(64, activation='relu', input_shape=(10000,)))
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(46, activation='softmax'))
        model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics='accuracy')

最後一層輸出是46個維度的向量,每個維度程式碼樣本屬於對應分類的概率。
這裡使用便於計算兩個概率分佈距離的分類交叉熵作為損失函式。

四、校驗模型

從訓練集中保留一部分作為校驗資料集。

        x_val = x_train[:1000]
        partial_x_train = x_train[1000:]

        y_val = one_hot_train_labels[:1000]
        partial_y_train = one_hot_train_labels[1000:]

還是以512個樣本作為一個小的批次,訓練20輪。

        history = model.fit(partial_x_train, partial_y_train, epochs=self.epochs, batch_size=512, validation_data=(x_val, y_val))

繪製損失曲線圖

    def plt_loss(self, history):
        plt.clf()
        loss = history.histroy['loss']
        val_loss = history.histroy['val_loss']
        epochs = range(1, len(loss) + 1)
        plt.plot(epochs, loss, 'bo', label='Training loss')
        plt.plot(epochs, val_loss, 'b', label='Validation loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.show()

深度學習之新聞多分類問題

繪製準確度曲線

    def plt_accuracy(self, history):
        plt.clf()
        acc = history.history['accuracy']
        val_acc = history.history['val_accuracy']
        epochs = range(1, len(acc) + 1)

        plt.plot(epochs, acc, 'bo', label='Training accuracy')
        plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.show()

深度學習之新聞多分類問題

從圖中可以看到訓練到第九輪之後開始出現過擬合,改為9輪進行訓練模型,並在測試機上評估模型。

    def evaluate(self):
        results = self.model.evaluate(self.x_test, self.one_hot_test_labels)
        print('evaluate test data:')
        print(results)

最終訓練之後精度可以達到79%。

evaluate test data:
[0.9847680330276489, 0.7925200462341309]

五、總結

  • 網路最後一層的大小應該跟類別的數量保持一致;
  • 單標籤多分類問題,最後一層需要使用softmax啟用函式,方便輸出概率分佈。
  • 單標籤多分類問題,需要使用分類交叉熵作為損失函式。
  • 中間層的維度不能小於輸出標籤數量。

完整原始碼

from tensorflow.keras.datasets import reuters
import numpy as np
from tensorflow.keras import models
from tensorflow.keras import layers
import matplotlib.pyplot as plt


class MultiClassifier:

    def __init__(self, num_words, epochs):
        self.num_words = num_words
        self.epochs = epochs
        self.model = None
        self.eval = False if epochs == 20 else True

    def load_data(self):
        return reuters.load_data(num_words=self.num_words)

    def get_text(self, data):
        word_id_index = reuters.get_word_index()
        id_word_index = dict([(id, value) for (value, id) in word_id_index.items()])
        return ' '.join([id_word_index.get(i - 3, '?') for i in data])

    def vectorize_sequences(self, sequences, dimension=10000):
        results = np.zeros((len(sequences), dimension))
        for i,sequence in enumerate(sequences):
            results[i, sequence] = 1.
        return results

    def to_one_hot(self, labels, dimension=46):
        results = np.zeros((len(labels), dimension))
        for i,label in enumerate(labels):
            results[i, label] = 1
        return results

    def plt_loss(self, history):
        plt.clf()
        loss = history.history['loss']
        val_loss = history.history['val_loss']
        epochs = range(1, len(loss) + 1)
        plt.plot(epochs, loss, 'bo', label='Training loss')
        plt.plot(epochs, val_loss, 'b', label='Validation loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.show()

    def plt_accuracy(self, history):
        plt.clf()
        acc = history.history['accuracy']
        val_acc = history.history['val_accuracy']
        epochs = range(1, len(acc) + 1)

        plt.plot(epochs, acc, 'bo', label='Training accuracy')
        plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.show()

    def evaluate(self):
        results = self.model.evaluate(self.x_test, self.one_hot_test_labels)
        print('evaluate test data:')
        print(results)


    def train(self):
        (train_data, train_labels), (test_data, test_labels) = self.load_data()
        print(len(train_data))
        print(len(test_data))
        print(train_data[0])
        print(train_labels[0])
        print(self.get_text(train_data[0]))

        self.x_train = x_train = self.vectorize_sequences(train_data)
        self.x_test = x_test = self.vectorize_sequences(test_data)

        self.one_hot_train_labels = one_hot_train_labels = self.to_one_hot(train_labels)
        self.one_hot_test_labels = one_hot_test_labels = self.to_one_hot(test_labels)

        model = self.model = models.Sequential()
        model.add(layers.Dense(64, activation='relu', input_shape=(10000,)))
        model.add(layers.Dense(64, activation='relu'))
        model.add(layers.Dense(46, activation='softmax'))
        model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics='accuracy')

        x_val = x_train[:1000]
        partial_x_train = x_train[1000:]

        y_val = one_hot_train_labels[:1000]
        partial_y_train = one_hot_train_labels[1000:]

        history = model.fit(partial_x_train, partial_y_train, epochs=self.epochs, batch_size=512, validation_data=(x_val, y_val))



        if self.eval:
            self.evaluate()
            print(self.model.predict(x_test))
        else:
            self.plt_loss(history)
            self.plt_accuracy(history)

classifier = MultiClassifier(num_words=10000, epochs=20)

# classifier = MultiClassifier(num_words=10000, epochs=9)
classifier.train()