半監督生成對抗網路
一、SGAN簡介
半監督學習(semi-supervised learning)是GAN在實際應用中最有前途的領域之一,與監督學習(資料集中的每個樣本有一個標籤)和無監督學習(不使用任何標籤)不同,半監督學習只為訓練資料集的一小部分提供類別標籤。通過內化資料中的隱藏結構,半監督學習努力從標註資料點的小子集中歸納,以有效地對從未見過的新樣本進行分類,要使半監督學習有效,標籤資料和無標籤資料必須來自相同的基本分佈。
缺少標籤資料集是機器學習研究和實際應用中的主要瓶頸之一,儘管無標籤資料非常豐富(網際網路實際上就是無標籤影像、視訊和文字的無限來源),但為它們分配類別標籤通常非常昂貴、不切實際且耗時。在ImageNet中手工標註320萬張影像用了兩年半的時間,ImageNet是一個標籤影像的資料庫,在過去的十年中對於影像處理和計算機視覺取得的許多進步均有幫助。
訓練需要大量標籤資料是監督學習的致命弱點。目前,工業中的人工智慧應用絕大多數使用監督學習。缺乏大型標籤資料集的一個領域是醫學,醫學上獲取資料(如來自臨床試驗的結果)通常需要耗費大量的精力和開支,更別說會面臨道德倫理和隱私等更嚴重的問題了,因此,提高演算法從越來越少的標註樣本中學習的能力具有巨大的實際意義。
有趣的是,半監督學習可能也是最接近人類學習方式的機器學習方式之一,小學生學習閱讀和書寫時,老師不必帶他們出門旅行,讓他們在路上看到成幹上萬個字母和數字的樣本以後,再根據需要糾正他們一就像監督學習演算法的運作方式一樣。相反,只需要一組樣本可供孩子學習字母和數字,然後不管何種字型、大小、角度、照明條件和許多其他條件下,他們能夠識別出來。半監督學習旨在按照這種有效的方式教會機器。
作為可用於訓練的附加資訊的來源,生成模型己被證明有助於提高半監督模型的準確性。
1. 什麼是SGAN
半監督生成對抗網路(Semi-Supervised GAN, SGAN)是一種生成對抗網路,其判別器是多分類器。這裡的判別器不只是區分兩個類(真和假),而是學會區分N+1類,其中N是訓練資料集中的類數,生成器生成的偽樣本增加了一個類。
例如,MNIST手寫數字資料集有10個標籤(每個數字一個標籤,從0到9),因此在此資料集上訓練的SGAN鑑別器將預測10+1=11個類。在我們的實現中,SGAN判別器的輸出將表示為10個類別的概率(之和為1.0)加上另一個表示影像是真還是假的概率的向量。將判別器從二分類器轉變為多分類器看似是一個微不足道的變化,但其含義比乍看之下更為深遠。我們從下圖所示的SGAN架構開始解釋。(此SGAN中,生成器輸入隨機噪聲向量z並生成偽樣本\(x^*\)。判別器接收3種資料輸入:來自生成器的偽資料、真實的無標籤資料樣本X和真實的標籤資料樣本(x,y),其中y是給定樣本對應的標籤;然後判別器輸出分類,以區分偽樣本與真實樣本區,併為真實樣本確定正確的類別。注意,標籤資料比無標籤資料少得多。實際情況中,這一對比甚至比本圖所顯示的更明顯,標籤資料僅佔訓練資料的一小部分(通常低至1%~2%))
與傳統GAN相此,區分多個類的任務不僅影響了判別器本身,還增加了SGAN架構、訓練過程和訓練目標的複雜性。
SGAN生成器的目的與原始GAN相同:接收一個隨機數向量並生成偽樣本,力求使偽樣本與訓練資料集別無二致。
但是,SGAN判別器與原始GAN實現有很大不同。它接收3種輸入:生成器生成的偽樣本(x)、訓熱資料集中無標籤的真實樣本(x)和有標籤的真實樣本(x,y),其中y表示給定樣本x的標籤。
SGAN判別器的目標不是二分類,而是在輸入樣本為真的情況下,將其正確分類到相應的類中,或將樣本作為假的(可以認為是特殊的附加類)排除。
有關SGAN子網路的要點見下表。
生成器 | 判別器 | |
---|---|---|
輸入 | 一個隨機數向量(z) | 判別器接收3種輸入:訓練資料集中無標籤的真實樣本(x);訓練資料集中有標籤的真實樣本(x,y);生成器生成的偽樣本(\(x^*\)) |
輸出 | 儘可能令人相信的偽樣本(\(x^*\)) | 表示輸入樣本屬於N個真實類別中的某一個或屬於偽樣本類別的可能性 |
目標 | 生成與訓練資料集別無二致的偽樣本,以欺騙判別器,使之將偽樣本分到真實類別 | 學會將正確的類別標籤分配給真實的樣本,同時將來自生成器的所有樣本判別為假 |
3. 訓練過程
回想一下,常規GAN通過計算\(D(x)和D(x^*)\)的損失並反向傳播總損失來更新判別器的可訓練引數,以使損失最小,從而訓練判別器。生成器通過反向傳播判別器損失\(D(x^*)\)並尋求使其最大化來進行訓練,以便讓判別器將合成的偽樣本錯誤地分類為真。
為了訓練SGAN,除了\(D(x)和D(x^*)\),我們還必須計算有監督訓練樣本的損失:\(D((x, y))\)。些損失與SGAN判別器必須達到的雙重目標相對應:區分真偽樣本;學習將真實樣本正分類。用論文中的術語來說,雙重目標對應於兩種損夫:有監督損失(suprvised loss)和無監督損失(unsupervised loss)。
4. 訓練過程
到目前為止,我們看到的GAN變體都是生成模型。它們的目標是生成逼真的資料樣本,正因如此,人們最感興趣的一直是生成器,判別器網路的主要目的是幫助生成器提高生成影像的質量。在訓練結束時,我們通常會忽略判別器,僅使用訓練好的生成器來建立逼真的臺成資料。
在SGAN中主要關心的反而是判別器,訓練過程的目標是使該網路成為僅使用一小部分標籤資料的半監督分類器,其準確率儘可能接近全監督的分類器(其訓練資料集中的每個樣本都有標籤),生成器的目標是通過提供附加資訊(它生成的偽資料)來幫助判別器學習資料中的相關模式,從而提高其分類準確率,訓練結束時,生成器將被丟棄,而訓練有素的判別器將被用作分類器。
二、SGAN的實現
我們將實現一個SGAN模型。該模型僅使用I00個訓練樣本即可對MNIST資料集中的手寫數字進行分類。最後,我們將模型的分類準確率與其對應的全監督模型進行了比較,看看半監督學習所取得的進步。
1. 架構圖
本教程中實現的SGAN模型的高階示意如下圖所示,(生成器將隨機噪聲轉換為偽樣本;判別器輸入有標籤的真實影像(x,y)、無標籤的真實影像(x)和生成器生成的偽影像\((x^*)\)。為了區分真實樣本和偽樣本,判別器使用了sigmoid函式;為了區分真實標籤的分類,判別器使用了softmax函式)它比開頭介紹的一般概念圖要複雜一些。關鍵在於實現細節。
為了解決區分真實標籤的多分類問題,判別器使用了softmax函式,該函式給出了在給定數量的類別(本例中為10類)上的概率分佈,給一個給定類別標籤分配的概率越高,判別器就越確信該樣本屬於這一給定的類,為了計算分類誤差,我們使用了交叉熵損失,以測量輸出概率與目標獨熱編碼標籤之間的差異。
為了輸出樣本是真還是假的概率,判別器使用了sigmoid啟用函式,並通過反向傳播二元交叉熵損失來訓練其引數。
2. 設定
首先匯入執行模型需要的所有模組和庫,指定輸入影像的大小、噪聲向量z的大小以及半監督分類的真實類別的數量(判別器將學習識別每個數字對應的類)。
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import (
Activation, BatchNormalization, Concatenate, Dense,
Dropout, Flatten, Input, Lambda, Reshape
)
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
#模型輸入維度
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)#輸入影像的維度
z_dim = 100#噪聲向量的大小
num_classes = 10#資料集中類別的數量
3. 資料集
儘管MNIST訓練資料集裡有50000個有標籤的訓練影像,但我們僅將其中的一小部分(由num_labeled引數決定)用於訓練,並假設其餘影像都是無標籤的。我們這樣來實現這一點:取批量有標籤資料時僅從前num_labeled個影像取樣,而在取批量無標籤資料時從其餘(50000-num_labeled)個影像中取樣。
Dataset物件提供了返回所有num_labeled訓練樣本及其標籤的函式,以及能返回MNIST資料集中所有10000個帶標籤的測試影像的函式。訓練後,我們將使用測試集來評估模型的分類在多大程度上可以推廣到以前未見過的樣本。
class Dataset:
def __init__(self, num_labeled):
self.num_labeled = num_labeled#訓練集中使用的有標籤影像的數量
(self.x_train, self.y_train), (self.x_test, self.y_test) = mnist.load_data()
def preprocess_imgs(x):
x = (x.astype(np.float32) - 127.5) / 127.5#灰度畫素值從[0, 255]縮放到[-1, 1]
x = np.expand_dims(x, axis=3)#將影像尺寸擴充套件到寬x高x通道數
return x
def preprocess_labels(y):
return y.reshape(-1, 1)#將元素轉換成一列
#訓練
self.x_train = preprocess_imgs(self.x_train)
self.y_train = preprocess_labels(self.y_train)
#測試
self.x_test = preprocess_imgs(self.x_test)
self.y_test = preprocess_labels(self.y_test)
def batch_labeled(self, batch_size):
#獲取隨機批量的有標籤影像及其標籤
idx = np.random.randint(0, self.num_labeled, batch_size)
imgs = self.x_train[idx]
labels = self.y_train[idx]
return imgs, labels
def batch_unlabeled(self, batch_size):
#獲取隨機批量的無標籤影像
idx = np.random.randint(self.num_labeled, self.x_train.shape[0], batch_size)
imgs = self.x_train[idx]
return imgs
def training_set(self):
x_train = self.x_train[range(self.num_labeled)]
y_train = self.y_train[range(self.num_labeled)]
return x_train, y_train
def test_set(self):
return self.x_test, self.y_test
num_labeled = 100#要使用的有標籤樣本的數量(其餘作為無標籤樣本使用)
dataset = Dataset(num_labeled)
4. 生成器
def build_generator(z_dim):
model = Sequential()
model.add(Dense(256 * 7 * 7, input_dim=z_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(Reshape((7, 7, 256)))
model.add(Conv2DTranspose(128, 3, 2, padding='same'))
model.add(LeakyReLU(alpha=0.2))
model.add(Conv2DTranspose(128, 3, 2, padding='same'))
model.add(LeakyReLU(alpha=0.2))
model.add(Conv2DTranspose(1, 3, 1, padding='same'))
model.add(LeakyReLU(alpha=0.2))
model.add(Conv2D(1, 7, padding='same'))
model.add(Activation('tanh'))
return model
5. 判別器
判別器是SGAN模型中最複雜的部分,它有如下雙重目標。
- 區分真實樣本和偽樣本。為此,SGAN判別器使用了sigmoid函式,輸出用於二元分類的概率。
- 對於真實樣本,還要對其標籤準確分類。為此,SGAN判別器使用了softmax函式,輸出概率向量——每個目標類別對應一個。
1. 核心判別器網路
我們先來定義核心判別器網路。SGAN判別器模型與DCGAN中實現的基於ConvNet的判別器相似。實際上,直到3×3×128卷積層,它的批歸一化和LeakyReLU啟用與之前的一直是完全相同的。
在該層之後新增了一個Dropout,這是一種正則化技術,通過在訓練過程中隨機丟棄神經元及其與網路的連線來防止過擬合。這就迫使剩餘的神經元減少它們之間的相互依賴,並得到對基礎資料更一般的表示形式。隨機丟棄的神經元比例由比例引數指定,在本實現中將其設定為0.5,即mode1.add(Dropout(0.5))。由於SGAN分類任務的複雜性增加,我們使用了Dropout,以提高模型從只有100個有標籤的樣本中歸納的能力。
def build_discriminator_net(img_shape):
model = Sequential()
model.add(Conv2D(128,kernel_size=3,strides=2,input_shape=img_shape,padding='same'))
model.add(LeakyReLU(alpha=0.2))
model.add(Conv2D(128,kernel_size=3,strides=2,input_shape=img_shape,padding='same'))
model.add(LeakyReLU(alpha=0.2))
model.add(Conv2D(128,kernel_size=3,strides=2,input_shape=img_shape,padding='same'))
model.add(LeakyReLU(alpha=0.2))
model.add(Flatten())
model.add(Dropout(0.4))
model.add(Dense(num_classes))
return model
注意,Dropout層是在批歸一化之後新增的。出於這兩種技不之間的相互作用,這種方法已顯示出優越的效能。
另外,請注意前面的網路以一個具有10個神經元的全連線層結束,接下來,我們需要定義從這些神經元計算出的兩個判別器輸出:一個用於有監督的多分類(使用softmax),另一個用於無監督的二分類(使用sigmoid)。
2.有監督的鑑別器
def build_discriminator_supervised(discriminator_net):
model = Sequential()
model.add(discriminator_net)
model.add(Activation('softmax'))
return model
3. 無監督的判別器
predict(x)這個函式將10個神經元(來自核心判別器網路)的輸出轉換成一個二分類的真假預測。
def build_discriminator_unsupervised(discriminator_net):
model = Sequential()
model.add(discriminator_net)
def predict(x):
prediction = 1.0 - (1.0 / (K.sum(K.exp(x), axis=-1, keepdims=True) + 1.0))#將真實類別的分佈轉換為二元真-假率
return prediction
model.add(Lambda(predict))#之前定義的真假輸出元
return model
6. 搭建整個模型
接下來,我們將構建並編譯判別器模型和生成器模型。注意,有監督損失和無監督損失分別使用categorical_crossentropy和binary_crossentropy損失函式。
def build_sgan(generator, discriminator):
model = Sequential()
model.add(generator)
model.add(discriminator)
return model
discriminator_net = build_discriminator_net(img_shape)#核心判別器網路:這些層在有監督和無監督訓練中共享
#構建並編譯有監督訓練判別器
discriminator_supervised = build_discriminator_supervised(discriminator_net)
discriminator_supervised.compile(
loss='categorical_crossentropy',
metrics=['accuracy'],
optimizer=Adam(lr=0.0002, beta_1=0.5)
)
#構建並編譯無監督訓練判別器
discriminator_unsupervised = build_discriminator_unsupervised(discriminator_net)
discriminator_unsupervised.compile(
loss='binary_crossentropy',
optimizer=Adam(lr=0.0002, beta_1=0.5)
)
#構建生成器
generator = build_generator(z_dim)
discriminator_unsupervised.trainable = False#生成器訓練時,判別器引數保持不變
#構建並編譯判別器固定的GAN模型,以訓練生成器(判別器使用無監督版本)
sgan = build_sgan(generator, discriminator_unsupervised)
sgan.compile(
loss='binary_crossentropy',
optimizer=Adam(lr=0.0002, beta_1=0.5)
)
7. 訓練
以下虛擬碼概述了SGAN的訓練演算法。
SGAN訓練演算法
對每次訓練選代,執行以下操作。
- 訓練判別器(有監督)。
- 隨機取小批量有標籤的真實樣本(x, y)
- 計算給定小批量的D((x, y))並反向傳播多分類損失更新\(\theta^{(D)}\),以使損失最小化。
- 訓練判別器(無監督)。
- 隨機取小批量無標籤的真實樣本x。
- 計算給定小批量的D(x)並反向傳播二元分類損失更新\(\theta^{(D)}\),以使損失最小化。
- 隨機取小批量的隨機噪聲z生成一小批量偽樣本:G(z)=\(x^*\)。
- 計算給定小批量的\(D(x^*)\)並反向傳播二元分類損失更新\(\theta^{(D)}\)以使損失最小化。
- 訓練生成器。
- 隨機取小批量的隨機噪聲z生成一小批量偽樣本:G(z)=\(x^*\)。
- 計算給定小批量的D(\(x^*\))並反向傳播二元分類損失更新\(\theta^{(G)}\)以使損失最大化。
結束
supervised_losses = []
iteration_checkpoints = []
def train(iterations, batch_size, sample_interval):
real = np.ones((batch_size, 1))#真實影像的標籤:全為1
fake = np.zeros((batch_size, 1))#偽影像的標籤:全為0
for iteration in range(iterations):
imgs, labels = dataset.batch_labeled(batch_size)#獲取有標籤樣本
labels = to_categorical(labels, num_classes=num_classes)#獨熱編碼標籤
imgs_unlabeled = dataset.batch_unlabeled(batch_size)#獲取無標籤樣本
#生成一批偽影像
z = np.random.normal(0, 1, (batch_size, z_dim))
gen_imgs = generator.predict(z)
#訓練有標籤的真實樣本
d_loss_supervised, accuracy = discriminator_supervised.train_on_batch(imgs, labels)
#訓練無標籤的真實樣本
d_loss_real = discriminator_unsupervised.train_on_batch(imgs_unlabeled, real)
#訓練偽樣本
d_loss_fake = discriminator_unsupervised.train_on_batch(gen_imgs, fake)
d_loss_unsupervised = 0.5 * np.add(d_loss_real, d_loss_fake)
#生成一批偽樣本
z = np.random.normal(0, 1, (batch_size, z_dim))
gen_imgs = generator.predict(z)
#訓練生成器
g_loss = sgan.train_on_batch(z, np.ones((batch_size, 1)))
if(iteration + 1) % sample_interval == 0:
#儲存判別器的有監督分類損失,以便繪製損失曲線
supervised_losses.append(d_loss_supervised)
iteration_checkpoints.append(iteration + 1)
print("%d [D loss supervised: %.4f, acc.: %.2f%%] [D loss unsupervised: %.4f] [G loss: %f]"
% (iteration + 1, d_loss_supervised, 100 * accuracy, d_loss_unsupervised, g_loss))
sample_image(generator)#輸出生成影像的取樣
1. 生成影像
def sample_image(generator, image_grid_rows=4, image_grid_columns=4):
z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, 100))#隨機噪聲取樣
gen_imgs = generator.predict(z)#從隨機噪聲生成影像
gen_imgs = 0.5 * gen_imgs + 0.5#影像縮放到[0, 1]:[-1, 1]--->[0, 1]
fig, axs = plt.subplots(
image_grid_rows,
image_grid_columns,
figsize=(4, 4),
sharex=True,
sharey=True
)
cnt = 0
for i in range(image_grid_rows):
for j in range(image_grid_columns):
axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
axs[i, j].axis('off')
cnt += 1
充分訓練後的SGAN的生成器生成的手寫數字如下圖左所示,為了便於同時比較,DCGAN生成的數字樣本如下圖右所示。我們可以看到,SGAN生成數字樣本明顯優於DCGAN生成的數字樣本。
2. 訓練模型
之所以使用較小的批量,是因為只有100個有標籤的訓練樣本。我們通過反覆試驗確定迭代次數:不斷增加次數,直到判別器的有監督損失趨於平穩,但不要超過穩定點太遠(以降低過擬合的風險)。
iterations = 20000
batch_size = 32
sample_interval = 2000
train(iterations, batch_size, sample_interval)
3. 模型訓練和測試準確率
在訓練過程中,SGAN達到了100%的有監督準確率。儘管這看似很好,但請記住只有100個有標籤的樣本用於有監督訓練——也許模型只是記住了訓練資料集。分類器能在多大程度上泛化到訓練集中未見過的資料上才是重要的。
x, y = dataset.test_set()
y = to_categorical(y, num_classes=num_classes)
_, acc = discriminator_supervised.evaluate(x, y)
print('Test accuracy: %.2f%%' % (100 * acc))
SGAN能夠準確分類測試集中大約93%的樣本,為了解這有多了不起,我們對比一下SGAN和全監督分類器的效能。
三、與全監督分類器的對比
為了使比較儘可能公平,我們讓全監督分類器使用與訓練有監督判別器相同的網路結構。這樣做的意圖在於,這將能突顯出半監督學習GAN對分類器泛化能力的提提高。
#有著與SGAN判別器相同網路結構的全監督分類器
mnist_classifier = build_discriminator_supervised(
build_discriminator_net(img_shape)
)
mnist_classifier.compile(
loss='categorical_crossentropy',
metrics=['acc'],
optimizer=Adam(lr=0.0002, beta_1=0.5)
)
imgs, labels = dataset.training_set()
labels = to_categorical(labels, num_classes=num_classes)
history = mnist_classifier.fit(
imgs,
labels,
batch_size=batch_size,
epochs=30
)
train_result = history.history
trian_loss = train_result['loss']
train_acc = train_result['acc']
epochs = range(1, len(loss) + 1)
plt.figure(figsize=(16, 12))
plt.subplot(221)
plt.grid()
plt.title('全監督分類器訓練損失')
plt.plot(epochs, loss, 'b', label='訓練損失')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.legend(loc='best')
plt.subplot(222)
plt.grid()
plt.title('全監督分類器訓練精度')
plt.plot(epochs, acc, 'r', label='訓練精度')
plt.xlabel('epochs')
plt.ylabel('acc')
plt.legend(loc='best')
mnist_classifier.evaluate(x, y)
與SGAN的判別器一樣,全監督分類器在訓練資料集上達到了接近100%的準確率。但是在測試集上它只能正確分類大約70%的樣本,比SGAN差了約20個百分點。換句話說,SGAN將訓練準確率提高了近30個百分點!
隨著訓練資料的增加,全監督分類器的泛化能力顯著提高。使用相同的設定和訓練,使用10000個有標籤樣本(是最初使用樣本的100倍)訓練的全監督分類器,可以達到約98%的準確率,不過這不是半監督學習。
四、結論
我們通過教判別器輸出真實樣本的類別標籤,來探索如何把GAN用於半監督學習。可以看到,經過SGAN訓練的分類器從少量訓練樣本中泛化的能力明顯優於全監督分類器。
從GAN創新的角度來看,SGAN的主要特點是在判別器訓練中使用標籤,你可能想知道標籤是否也可以用於生成器訓練,條件GAN應用而生。
五、小結
- 半監督生成對抗網路(SGAN)的判別器可用來:區分真實樣本與偽樣本;給真實樣本分配正確的類別標籤。
- SGAN的目的是將判別器訓練成一個分類器,使之可以從儘可能少的有標籤樣本中獲得更高的分類精度,從而減少分類任務對大量標註資料的依賴性。
- 我們將softmax和多元交叉熵損失用於分配真實標籤的有監督任務,將sigmoid和二元交叉熵用於區分真實樣本和偽樣本。
- 我們證明了SGAN對沒見過的測試集資料的分類準確率遠遠優於在相同數量的有標籤樣本上訓練的全監督分類器。