生成型神經網路

周見智發表於2022-03-25

最簡單的X->Y對映

讀書時代課本上的函式其實就是我們最早接觸關於“對映”的概念,f(x)以“某種邏輯”將橫座標軸上的值對映到縱座標軸。後來學習程式設計之後接觸了函式(或者叫方法)的概念,它以“某種邏輯”將函式輸入對映成輸出,對映邏輯就是函式本身的實現過程。對映就是將若干輸入以某種邏輯轉換成若干輸出,課本上的函式是這樣,程式設計中的函式是這樣,深度學習中神經網路同樣是這樣。

圖1 對映關係 

最簡單的對映邏輯是線性對映,類似f(x)=w*x + b這樣,它的函式影像是一條直線,其中w和b是對映過程用到的引數。高中數學中的拋物線f(x)=a*x*x + b*x +c 是一種比線性對映更復雜的對映關係,同樣a、b、c是該對映過程要用到的引數。 這裡不管是w還是abc這樣的引數,同樣存在於神經網路這種高複雜度的對映過程之中,事實上,我們完全可以將神經網路類比為我們更為熟悉的直線、拋物線結構。

與我們讀書學習直線、拋物線不同的是,書本中直線、拋物線函式的引數大部分時候都是已知的,而神經網路中的引數是未知的,或者說不是最優的,我們需要使用已知的樣本(輸入/輸出)去擬合網路,從而得出最優的引數值,然後將擬合(學習)到的引數應用到新資料中(只有輸入),擬合的這個過程叫做“訓練模型”或者“學習”。

圖2 線性對映與神經網路類比

 

神經網路的輸入輸出

前面說到對映過程有輸入和輸出,同樣神經網路也有輸入和輸出。就計算機視覺領域相關任務來講,有影像分類網路,有目標檢測、分割等網路,也有專門用於影像特徵提取的網路,這些網路的輸入基本都是圖片,輸出要麼是分類概率值,要麼是跟目標有關的畫素座標值,或者是高維特徵向量值,這些輸出有一個共同的特點,都屬於數字類輸出,就像我們剛接觸機器學習時預測房價的案例,演算法輸出是房價,也是數字型別。

 

圖3 普通網路的輸入輸出

那麼神經網路是否還有更復雜的輸出格式呢?答案是有,神經網路可以輸出更高維的資料結構,比如圖片。雖然嚴格來講,圖片也是由數字組成,但是為了與前面提到的簡單輸出型別區分,我們可以將輸出圖片的這類神經網路稱為“生成型網路”。顧名思義,生成型網路的特點是可以生成類似圖片這種直觀結果(或者類似網路輸入那種更高維資料)。

 

圖4 生成型網路的輸入輸出(舉例)

生成型網路的訓練方式一般比較特別,大部分都是無監督學習,也就是事先無需對樣本做人工標註。事實上,瞭解生成型網路是我們熟悉無監督學習的一種非常好的途徑。

生成型網路的實現方式常見有兩種,一個是AutoEncoder,翻譯成中文叫“自編碼器”,基本原理就是將輸入圖片先編碼,生成一個特徵表達,然後再基於該特徵重新解碼生成原來的輸入圖片,編碼和解碼是兩個獨立的環節,訓練的時候合併、使用的時候拆開再交叉組合。前幾年非常火的AI換臉就是基於AutoEncoder實現的,程式碼非常簡單,之前的一篇文章對它有非常詳細的介紹。另一個是GAN,全稱Generative Adversarial Networks,翻譯成中文是“生成對抗型網路”,它也是由兩部分組成,一個負責生成一張圖片,一個負責對圖片進行分類(判斷真假),訓練的時候一起訓練,使用的時候只用前面的生成網路。GAN也是本篇文章後面介紹的重點內容。  

 

GAN生成對抗型網路工作原理

首先要明確地是,最原始的GAN是2014年提出來的,它可以基於一個無任何標註的圖片樣本集,生成類似風格的全新圖片,就像一個畫家,看了一些示例後,可以畫出類似風格的畫。2014年提出來的GAN結構相對簡單,也存在一些缺陷,比如生成的圖片解析度太低、網路訓練很難收斂。後面有人陸陸續續提出了各種改進版本的GAN(變種),雖然各方面都有所調整,但是整體思想仍然沿用第一版,這裡也是介紹最原始GAN網路的工作原理。

GAN網路主要由兩子網路組成,一個學名叫Generative Network(顧名思義,生成網路),輸入隨機數,輸出一張固定尺寸的圖片(比如64*64大小),另一個學名叫Discriminative Network(顧名思義,辨別網路),輸入前一個網路生成的圖片以及訓練樣本集中的圖片,輸出真和假(生成的為假、樣本集中的為真)。如果單獨分開來看,這兩沒什麼特別的,一個負責生成圖片、一個負責判別真假。但是,GAN高妙之處就在於這兩在訓練過程中並不是獨立的。我們先來看下GAN的結構:

 

圖5 GAN網路結構

上面顯示的GAN網路結構其實非常簡單,但是隻看這張圖的人可能會有一個疑問?訓練樣本集只作為辨別網路(Discriminative network,簡稱D)的輸入並參與訓練,它是怎樣對生成網路(Generative network,簡稱G)起到作用的?答案在GAN的訓練過程中,GAN的訓練方式主要可以歸納成3個要點:

1、將G的輸出作為D的輸入,並以False標籤來訓練D網路,更新D網路的引數,讓D知道什麼是假的;

2、將訓練樣本集輸入D,並以True標籤來訓練D網路,更新D網路的引數,讓D知道什麼是真的;

3、D的引數凍結,再將D和G兩個網路串起來,G的輸出還是作為D的輸入,但是這次以True標籤來訓練整個組合網路。由於D的引數已經凍結,所以整個訓練過程只會更新前面G網路的引數,讓G知道如何生成看起來像“真的”圖片。

上面的過程1和過程2讓D越來越聰明,知道什麼是真的、什麼是假的。過程3讓G也越來越聰明,知道如何調整自己的引數,才能讓生成的圖片更能欺騙D。隨著上面的過程不斷反覆,G不斷的更新自己的引數,雖然輸入的一直是隨機數,但是輸出的圖片卻越來越像訓練樣本集中的圖片,到此完成了對G的訓練。下面用圖來說明上面的過程:

 

GAN(變種網路)可以解決哪些任務?

基於GAN思想的生成型網路有很多應用,比如素材生成,當你已經有一定數量的圖片樣本後,你可以利用GAN生成更多類似但完全不同的圖片,提升樣本庫的豐富性。目前比較成熟的風格遷移(style transfer)也用到了GAN思想,可以將一張手機拍攝的自然風景圖片轉換成山水畫、卡通畫風格。下面是利用Pro-GAN網路生成的圖公路畫素材,剛開始生成的影像模糊,隨著訓練時間增加,生成的影像越來越清晰、並且能夠保持與已有樣本集一致的風格。

 

相關文章