甘 宏
(廣州南方學院,510970,廣州)
近年來,將深度學習技術用于視覺識別任務取得了相當大的進展[1-5]。然而,有監(jiān)督的深度學習模型需要大量的標記樣本和迭代步驟來訓練模型的參數(shù),這嚴重限制了深度學習技術對新出現(xiàn)或罕見類別的適用性,同時收集并標記大量的樣本需要耗費大量的人力物力。相比之下,人類卻擅長通過少量甚至幾個樣本來識別物體,而深度學習技術難以用于每類僅有一個或幾個樣本的學習。受人類具備小樣本學習能力的啟發(fā),使得小樣本學習問題引起了廣泛的關注。
現(xiàn)有的小樣本學習方法大致可以分成3類:度量學習、元學習以及基于數(shù)據(jù)增強方法。度量學習方法利用輔助數(shù)據(jù)集學習得到一個度量空間,使得在該度量空間中同一類樣本的特征向量彼此間的距離較近,而不同類樣本的特征向量距離則較遠,從而實現(xiàn)小樣本學習。文獻[6]將卷積孿生網(wǎng)絡用于單樣本圖像識別,通過有監(jiān)督的方式訓練孿生網(wǎng)絡,然后重用網(wǎng)絡所提取的特征向量進行單樣本學習。文獻[7]提出了匹配網(wǎng)絡,該算法的核心是episode-based的訓練策略,其基本思想是訓練和測試是要在同樣條件下進行,即在訓練的時候讓網(wǎng)絡模型只看每一類的少量樣本,使得訓練和測試的過程保持一致。原型網(wǎng)絡的基本思想是每個類都存在一個原型表達,該類的原型是支撐集在利用度量空間中特征向量的均值作為類表示[8]。F Sung等[9]提出了關系網(wǎng)絡求解小樣本學習問題,該模型2個模塊:嵌入模塊和關系模塊。嵌入式模塊用于提取數(shù)據(jù)樣本的特征表示,而關系模塊用于估計2個特征表示之間的距離。D DAS等[10]添加了一個預訓練階段,利用所有基類的分類任務預訓練模型獲得參數(shù)的初始化。Li等[11]在分類損失函數(shù)中添加一個與任務相關的附加邊際損失,以更好地區(qū)分不同類別的樣本,從而提高分類性能。Zhou等[12]利用貪婪算法選擇與支持集樣本的相似基類,使得度量模型能對新的小樣本任務有較強的適應性。元學習方法通過對多個任務的學習,以使元模型(meta-learner)能夠對新的任務做出快速而準確的學習,該方法包含了2個關鍵問題:訓練得到最優(yōu)初始化參數(shù)和學習有效的參數(shù)更新規(guī)則。FINN等[13]提出了MAML(Model-Agnostic Meta-Learning)的元學習方法,基本思想是訓練一組初始化參數(shù),通過在初始參數(shù)的基礎上進行一或多步的梯度調(diào)整,來達到僅用少量數(shù)據(jù)就能快速適應新任務的目的。K Wang等[14]給出了結合概率推理和元學習的識別模型,以阻止元模型訓練過程中偏向某些具體任務,從而提高元模型對新任務的泛化能力。Meta-SGD算法[15]對MAML算法進一步優(yōu)化,不僅對初始參數(shù)進行了學習,而且對元模型的更新方向和學習速率進行學習。文獻[16]提出了一階元學習算法,該算法采用一階導數(shù)近似表示二階導數(shù),使得元參數(shù)更新過程中不需要像MAML算法一樣計算二階導數(shù),從而提高元模型的訓練效率。數(shù)據(jù)增強方法通過擴充樣本來提高小樣本學習的性能。然而,數(shù)據(jù)生成模型在僅有少數(shù)幾個訓練數(shù)據(jù)時,往往表現(xiàn)不佳。
本文提出算法屬于元學習方法的范疇。針對現(xiàn)有元學習方法對部分訓練任務存在有偏的不足,本文提出基于正則化元學習算法。通過在元學習的目標函數(shù)中添加正則化項,阻止元學習的初始模型偏向現(xiàn)有某些訓練任務,提高元模型對新任務的泛化能力,從而提高小樣本圖像分類的性能。
小樣本分類的目標是找到參數(shù)θ,小樣本分類目標是學習得到參數(shù)θ使得分類器fθ在詢問集中的期望值最大
(1)
為了減小元訓練過程中產(chǎn)生有偏,提高元學習模型的泛化能力。本節(jié)提出了正則化元學習算法(Regularized Meta Learning,REML)。通過在元目標函數(shù)添加正則化約束項,使得模型對訓練任務無偏。針對小樣本圖像分類問題,MAML算法的元目標函數(shù)為:
(2)
(3)
其中LTi(fθ)采用交叉熵損失函數(shù),表示為:
(4)
因此,MAML算法的元目標函數(shù)可以表示為:
(5)
為盡量減小參數(shù)θ對訓練任務有偏,提高元模型的泛化能力。本文引入交叉熵的約束條件,作為原目標函數(shù)的正則化項,使得參數(shù)θ對訓練任務是無偏的。交叉熵表示為:
(6)
以交叉熵作為正則化項,則元目標函數(shù)表示為
(7)
(8)
元目標函數(shù)梯度更新表示為
(9)
(10)
求導涉及到二維求導問題,大大增加了算法的計算量。針對以上不足,利用一階導數(shù)近似二階導數(shù)得到
(11)
則元參數(shù)更新模型(9)可以簡化為
(12)
本節(jié)將給出算法的詳細步驟,詳見算法1。
算法1:正則化元學習算法。
1)While not done do;
2)抽取幾個任務Ti構成任務塊Tbat;
3)for allTiinTbatdo;
4)從Ti中每類選取K個樣本記做D;
5)利用LTi(fθ)和D計算?θL(fθ);
7)從Ti抽取Dval用于元參數(shù)學習;
8)End for;
9)利用Dval和元學習目標函數(shù)L(θ)學習元模型參數(shù)θ,
10)End while。
輸出:元模型參數(shù)θ。
本節(jié)通過在miniImageNet、CUB-200和CIFAR-100這3個典型數(shù)據(jù)集上進行的小樣本分類實驗,來充分驗證本文算法性能,并與MAML、Reptile、Relation Networks和Prototypical Networks等先進算法比較。實驗1比較了不同算法在MiniImageNet數(shù)據(jù)集中的性能,并給出了參數(shù)λ對本文算法的影響;實驗2比較了不同算法在數(shù)據(jù)集CUB-200上的算法性能;實驗3給出了在數(shù)據(jù)集CIFAR-100上不同算法的性能比較。
為方便與其他算法進行比較,在后續(xù)的實驗中本文算法采用了與文獻[8-9,13,16]相同的網(wǎng)絡結構。網(wǎng)絡結構由4個模塊組成,每個模塊包含1個3×3×64的卷積層和1個2×2的池化層,每個卷積層均采用歸一化處理。
MiniImageNet數(shù)據(jù)集包含100個類,其中每個類包含600個樣本。采用與其他算法相同的拆分,其中64個類用于訓練,16個類進行驗證,20個類用于測試。分別進行了5-way 1-shot和5-way 5-shot小樣本圖像分類實驗,表1給出不同算法的分類精度比較。
由表1可以看出,本文算法由于提高了模型對新任務的泛化能力,從而使分類精度得到了一定的提升。
表1 不同算法在數(shù)據(jù)集miniImagenet中分類精度的比較
CUB-200數(shù)據(jù)集[14]包括了200種細分的類。參照文獻[15]中的劃分,隨機選取100個類用于元訓練,50個類用于驗證,50個類進行測試,并將每幅圖像的尺寸大小調(diào)整為84×84。分別進行了5-way 1-shot和5-way 5-shot小樣本圖像分類實驗,表2比較了4種算法的分類精度。
由表2可以看出,本文算法相對于MAML算法的分類精度能有將近4%的提升。
CIFAR-100數(shù)據(jù)集包括了100個類,每個類包含600張尺寸為32×32的圖形。隨機選取64個類進行元訓練,16個類用于驗證,20個類用于小樣本分類性能測試。與其他實驗類似,分別進行了5-way 1-shot和5-way 5-shot小樣本圖像分類實驗,表3比較了不同算法的分類精度。由表3可以看出,本文算法相對于MAML算法精度有3%左右的精度提高。
表2 不同算法在數(shù)據(jù)集CUB-200中分類精度的比較
表3 不同算法在數(shù)據(jù)集CIFAR-100中的分類精度比較
本小節(jié)通過對以上3個數(shù)據(jù)庫的5-way 5-shot小樣本圖像分類實驗,分析平衡參數(shù)λ對算法性能的影響。圖1給出了本文算法(REML)在不同參數(shù)值時的分類精度。由圖1可以看出,當參數(shù)λ取值接近0時,算法識別精度與MAML算法接近;當參數(shù)λ取0.2~0.3之間時能獲得較高的識別精度;當參數(shù)λ大于0.3之后,隨著參數(shù)的增加算法性能逐步下降。
圖1 平衡參數(shù)λ不同時的算法分類精度
針對小樣本學習問題,本文提出了正則化元學習算法(REML)用于求解小樣本圖像分類問題。該算法以交叉熵作為正則化項,以阻止元模型參數(shù)偏向某些具體任務,從而提高元模型的泛化能力,即提高元模型對新任務的適應能力。此外,采用一階導數(shù)近似二階導數(shù)減小元學習模型訓練所需計算量。在miniImageNet、CUB-200和CIFAR-100這3個數(shù)據(jù)集上進行的實驗表明,本文算法的分類性能優(yōu)于現(xiàn)有的同類算法,并表明平衡參數(shù)選擇在0.2~0.3之間時能獲得較高的識別精度。