丁柏楠, 王淏, 袁文翠, 吳圣潔
(1.東北石油大學 計算機與信息技術學院,黑龍江 大慶 163319;2.吉林大學 物理學院,吉林 長春 130012)
在過去的幾年里,深度神經(jīng)網(wǎng)絡在具有充足數(shù)據(jù)的情況下在多種任務上取得了舉世矚目的進展。它已經(jīng)被運用到了各個領域,例如圖像分類、機器翻譯、自然語言處理、圖像合成、語音處理等。無論在哪個領域,都必須使用巨量的數(shù)據(jù)來訓練網(wǎng)絡獲得令人滿意的結(jié)果。但在很多現(xiàn)實情況下,一些數(shù)據(jù)難以收集,或是收集大量數(shù)據(jù)的代價過于高昂,因此需要在有限的數(shù)據(jù)下達到期望的目標。而數(shù)據(jù)不足會導致神經(jīng)網(wǎng)絡在訓練集上出現(xiàn)過擬合,在測試集上出現(xiàn)較差的泛化能力的情況。
近些年,有許多技術被開發(fā)出來幫助對抗訓練時的過擬合問題,例如Dropout、Batch Normalization、Layer Normalization等方法。但在樣本過少的訓練中,即使使用了這些方法,也不能顯著地提升網(wǎng)絡訓練的穩(wěn)定性。故為了應對小樣本數(shù)據(jù)集訓練問題,一個主要措施就是數(shù)據(jù)增強方法。
數(shù)據(jù)增強[1]方法通常用于處理分類問題,通過合成或者轉(zhuǎn)換的方式, 從已有的數(shù)據(jù)中生成新的數(shù)據(jù),目前圖像領域的數(shù)據(jù)增強技術分為兩種,一種是有監(jiān)督的數(shù)據(jù)增強方法,采用預設的數(shù)據(jù)變換規(guī)則,例如隨機旋轉(zhuǎn)、隨機裁剪、色彩抖動、添加高斯噪聲等,但這些改變只局限于當前標簽類別,沒有從根本上解決數(shù)據(jù)不足的問題。另一種是無監(jiān)督的數(shù)據(jù)增強方法,通過使用模型學習現(xiàn)有數(shù)據(jù)的分布,生成與訓練數(shù)據(jù)集分布一致的數(shù)據(jù),例如基于生成對抗網(wǎng)絡的數(shù)據(jù)增強方法[2]。這種基于模型的方法可以合成相比傳統(tǒng)數(shù)據(jù)增強方法更加多樣的新數(shù)據(jù),但仍然只是針對單一類別進行生成。本文展示的是基于無監(jiān)督的多模態(tài)圖像轉(zhuǎn)換(Multimodal Unsupervised Image-to-image Translation,MUNIT)的數(shù)據(jù)增強技術,并將其應用于小樣本訓練的分類任務中。
MUNIT是由Xun Huang等[3]提出的多模態(tài)圖像轉(zhuǎn)換模型,是基于UNIT[4]網(wǎng)絡經(jīng)過改進后能夠生成非單一映射的圖像轉(zhuǎn)換網(wǎng)絡,其在圖像風格轉(zhuǎn)換的多樣性和生成的圖像質(zhì)量上的表現(xiàn)十分出色。MUNIT的根本思想將圖像分解為域不變的內(nèi)容編碼和樣式編碼,將多個域圖像的內(nèi)容存放于一個共享空間C中,從目標域捕獲樣式代碼,將內(nèi)容與樣式相組合,以此實現(xiàn)了將圖像轉(zhuǎn)換到另一個域的目標。
MUNIT網(wǎng)絡結(jié)構(gòu)如圖1所示。其中圖1(a)圖像重構(gòu)網(wǎng)絡表示對應每個圖像域,都有一組內(nèi)容編碼器和風格編碼器,以及用于重構(gòu)圖像的解碼器;其中的圖1(b)交叉圖像域的圖像轉(zhuǎn)換網(wǎng)絡展示了從x1圖像域轉(zhuǎn)換成x2圖像域的過程,由x1的內(nèi)容編碼器,x2的風格編碼器,和一個生成對抗網(wǎng)絡(Generative Adversarial Network, GAN)[5]構(gòu)成。
(a)圖像重構(gòu)網(wǎng)絡
在很多的圖像分類任務中,不同類別之間的底層特征十分相近,例如輪廓、大小等,但其高層特征會在很大程度上區(qū)分為不同的類別,例如不同的貓科動物之間的分類任務。本文針對這類問題,通過改進MUNIT網(wǎng)絡的部分結(jié)構(gòu)和訓練算法,設計了一種基于MUNIT的數(shù)據(jù)增強方法。
將MUNIT應用于數(shù)據(jù)增強的網(wǎng)絡結(jié)構(gòu)設計稱之為DA-MUNIT(Data Augmentation MUNIT)網(wǎng)絡,如圖2所示。
圖2 DA-MUNIT網(wǎng)絡結(jié)構(gòu)
該數(shù)據(jù)增強結(jié)構(gòu)首先將小樣本數(shù)據(jù)集中的不同標簽類別的圖像分別經(jīng)自動編碼器訓練生成其內(nèi)容空間Ci和風格空間Si,再針對每一個標簽類別的內(nèi)容,以除此之外的標簽圖像數(shù)據(jù)作為目標域,從其風格空間S中隨機采樣服從分布q(sj)~N(0,1)的風格潛在編碼sj,經(jīng)生成器G生成新的合成圖像xi→j=G(ci,sj),將所有新生成的圖像X的標簽類別標記為其風格潛在編碼所在圖像域的標簽類別。最后將所有新生成的已有標簽的圖像分別加入到原數(shù)據(jù)集相同標簽類別下,作為增強后的新數(shù)據(jù)集。新生成的圖像的底層特征是X1圖像域和X2圖像域共享的內(nèi)容空間,即該底層特征在兩個圖像域中的區(qū)分度不是很高,這可能造成分類器在通過未經(jīng)過數(shù)據(jù)增強的數(shù)據(jù)上訓練時,于底層特征分類上的表現(xiàn)較差。而加入了新生成的圖像后,由于新生成的圖像的高層特征屬于X2圖像域,底層特征屬于X1圖像域,并且其標簽類別與X2圖像相同,這就使分類器在底層內(nèi)容特征相似的數(shù)據(jù)中能夠更好地區(qū)別不同類別。
DA-MUNIT網(wǎng)絡的訓練分為2個階段,第一階段為原數(shù)據(jù)圖像重構(gòu)訓練階段。該訓練階段的實施過程如圖3所示。
圖3 圖像重構(gòu)訓練過程
該階段的目的就是通過一個圖像重構(gòu)網(wǎng)絡的訓練來得到數(shù)據(jù)集中各種類別的內(nèi)容代碼Ci和風格代碼Si。在該階段,我們將原MUNIT網(wǎng)絡的下采樣和上采樣部分改為了Pix2Pix圖像轉(zhuǎn)換網(wǎng)絡[6],最終的辨別器使用了Patch Discriminator,該辨別器用于使生成網(wǎng)絡更快地得到內(nèi)容編碼和風格編碼。其中,內(nèi)容編碼器的下采樣塊我們使用了8個編碼塊,每塊包含一個卷積層,BatchNorm層和激活函數(shù)LeakyRelu。解碼器的上采樣塊也包含8個解碼塊,每塊包含一個Trans Conv層、BatchNorm層和激活函數(shù)Relu。編碼器的每塊都與對應的解碼塊進行了跳轉(zhuǎn)鏈接,以此來給解碼器提供更多的內(nèi)容信息。
第一階段的損失函數(shù)包括兩部分,即雙向重構(gòu)損失和潛在重構(gòu)損失。
雙向重構(gòu)損失(Bidirectional reconstruction loss),即經(jīng)過重構(gòu)的圖像與原圖像像素距離損失,為式(1)。
(1)
潛在重構(gòu)損失(Latent reconstruction loss),分為目標域生成圖像對要轉(zhuǎn)換的圖像的內(nèi)容損失函數(shù)和目標域生成圖像對要采用的潛在風格域?qū)獔D像的風格損失函數(shù),2個損失均使用L1損失,因其更支持輸出清晰的圖像。2個損失函數(shù)為式(2)。
(2)
訓練的第二階段為使用第一階段得到的各種類圖片的風格潛在編碼S和共享的內(nèi)容空間C來訓練生成各個圖像域?qū)D(zhuǎn)換的圖像數(shù)據(jù)。為了得到Xi→j的轉(zhuǎn)換圖像,該階段從Xi下采樣得到Xi圖像域的內(nèi)容編碼,同時從Xj下采樣得到Xj圖像域的風格編碼,最終通過生成對抗網(wǎng)絡生成Xi→j的轉(zhuǎn)換圖像。該階段的訓練過程如圖4所示。其中風格編碼器部分設計均和圖3相同。該部分將不使用跳轉(zhuǎn)鏈接和Patch Discriminator。
圖4 圖像轉(zhuǎn)換過程
該階段的損失函數(shù)為對抗損失(Adversarial loss)。最終使用GAN來匹配轉(zhuǎn)換后圖像的分布到目標圖像域的分布。Xun Huang等提出的MUNIT網(wǎng)絡中該部分的對抗損失使用的是二元交叉熵BCE損失函數(shù),而我們提出使用最小二乘損失函數(shù),因為這樣可以在訓練過程中增加一定的穩(wěn)定性[7],并且相比使用BCE損失函數(shù),生成的圖像更加真實。對抗損失函數(shù)為式(3)。
(3)
綜上,DA-MUNIT網(wǎng)絡的總損失為式(4)。
(4)
其中,λx、λc和λs為控制重構(gòu)項對輸出結(jié)果的影響性的權重。
為了驗證DA-MUNIT的新樣本生成能力以及生成的樣本能否在一定程度上提高小樣本分類器的分類準確率,我們在ImageNet數(shù)據(jù)集上人工選取了2 000張貓科動物圖片,其中1 400張作為訓練集,600張為測試集數(shù)據(jù)。圖像分辨率為32×32,總共可分為5類。
為了研究對于不同大小的小樣本數(shù)據(jù)集下數(shù)據(jù)增強方法的表現(xiàn),我們?nèi)藶榈貜脑摂?shù)據(jù)集中抽取不同數(shù)量的子數(shù)據(jù)集, 每類從50到1 000不等。實驗主要對比以下幾種不同的數(shù)據(jù)增強方式:(1)不采用任何的數(shù)據(jù)增強方式(D);(2)傳統(tǒng)的基于仿射變換和圖像操作的數(shù)據(jù)增強方式(D_aug);(3)本文所提出的基于MUNIT的數(shù)據(jù)增強方法(DA-MUNIT_aug)。實驗對比了不同方法下訓練出來的分類器在測試集上的分類準確率(Acc),結(jié)果如表1所示。
表1 不同數(shù)據(jù)增強方法在ImageNet貓科動物數(shù)據(jù)集上測試集的準確率/%
從實驗結(jié)果可以看出,DA-MUNIT_aug是所有方法中對分類器的分類結(jié)果準確率提升效果最好的。于1 000張原數(shù)據(jù)和經(jīng)DA-MUNIT網(wǎng)絡數(shù)據(jù)增強后的訓練集準確率變化如圖5所示。在測試集上的準確率變化如圖6所示,訓練時的損失函數(shù)變化如圖7所示。
圖5 訓練集準確率對比圖
圖6 驗證集準確率對比圖
從圖5-圖7中可明顯看出經(jīng)DA-MUNIT網(wǎng)絡進行數(shù)據(jù)增強后的訓練更加穩(wěn)定且測試準確率更高,且在一定程度上減少了過擬合。
圖7 訓練損失變化對比圖
上述1 000張數(shù)據(jù)實驗結(jié)果說明基于MUNIT的數(shù)據(jù)增強方法對一些小樣本的圖像分類任務中的模型進行很大程度的提升,在一定程度上解決了數(shù)據(jù)過少導致的模型表現(xiàn)較差或過擬合問題。
以數(shù)據(jù)增強為目標, 本文通過使用和改進MUNIT網(wǎng)絡,設計了一種基于MUNIT的數(shù)據(jù)增強方法,以MUNIT生成的圖像做數(shù)據(jù)增強的方法相比于傳統(tǒng)的數(shù)據(jù)增強方法,生成的新樣本和原始數(shù)據(jù)的分布基本相同,還可以做到不同類之間圖像風格交叉轉(zhuǎn)換,以此提供內(nèi)容相同但風格不同的不同類別標簽的新樣本,使小樣本分類器能夠在加上合成的新數(shù)據(jù)后很大程度地提升了訓練穩(wěn)定性,減少了過擬合。