摘要:小樣本學(xué)習(xí)是圖像分類(lèi)任務(wù)中的一個(gè)重要挑戰(zhàn),能夠有效解決因數(shù)據(jù)量較少而產(chǎn)生的模型準(zhǔn)確率降低的問(wèn)題。針對(duì)小樣本學(xué)習(xí)難以準(zhǔn)確獲取類(lèi)內(nèi)共有特征的問(wèn)題,提出一種基于類(lèi)注意力的原型網(wǎng)絡(luò)改進(jìn)方法。利用掩膜圖像進(jìn)行數(shù)據(jù)預(yù)處理和圖像增強(qiáng),以提高原始數(shù)據(jù)質(zhì)量;引入注意力機(jī)制,選擇性地關(guān)注特征圖中的重要信息,以增強(qiáng)特征提取能力;設(shè)計(jì)類(lèi)注意力模塊,提取具有注意力信息的類(lèi)別原型。實(shí)驗(yàn)結(jié)果表明,在miniImageNet數(shù)據(jù)集上,該方法的分類(lèi)準(zhǔn)確率在基線基礎(chǔ)上提高了2%,驗(yàn)證了其有效性。
關(guān)鍵詞:原型網(wǎng)絡(luò);小樣本學(xué)習(xí);數(shù)據(jù)增強(qiáng);類(lèi)注意力;圖像分類(lèi)
中圖分類(lèi)號(hào):TP183""""""""""""文獻(xiàn)標(biāo)志碼:A """""""""文章編號(hào):1674-2605(2025)01-0009-07
DOI:10.3969/j.issn.1674-2605.2025.01.009"""nbsp;""""""""""""""""開(kāi)放獲取
Improvement Method of Prototype Network Based on Class Attention
CAO Zenghui CHEN Hao CAO"Yahui
(1.Guangdong University of Technology, Guangzhou 510000, China
2.Zhengzhou Vocational College of Industrial Safety,"Zhengzhou 450000, China)
Abstract:"Small sample learning is an important challenge in image classification tasks, which can effectively solve the problem of reduced model accuracy due to limited data volume. A prototype network improvement method based on class attention is proposed to address the problem of difficulty in accurately obtaining common features within classes in small sample learning. Using mask images for data preprocessing and image enhancement to improve the quality of raw data; Introducing attention mechanism to selectively focus on important information in feature maps to enhance feature extraction capability; Design a class attention module to extract class prototypes with attention information. The experimental results show that on the miniImageNet dataset, the classification accuracy of this method has improved by 2% compared to the baseline, verifying its effectiveness.
Keywords:"prototype network; small sample learning; data enhancement; class attention; image classification
0 引言
在計(jì)算機(jī)視覺(jué)領(lǐng)域,圖像分類(lèi)是一個(gè)重要且具有挑戰(zhàn)性的研究方向。傳統(tǒng)的圖像分類(lèi)方法,如K近鄰算法、決策樹(shù)、隨機(jī)森林等,在小樣本場(chǎng)景下泛化能力和準(zhǔn)確率有限。而小樣本學(xué)習(xí)在模型訓(xùn)練階段僅用少量的標(biāo)簽樣本即可完成分類(lèi)任務(wù),解決了因樣本數(shù)量較少而導(dǎo)致的模型準(zhǔn)確率下降的問(wèn)題。然而,小樣本學(xué)習(xí)存在泛化能力不足、過(guò)擬合、類(lèi)別不平衡等問(wèn)
題。為此,學(xué)者們提出了一系列的解決方案。其中,原型網(wǎng)絡(luò)[1]作為一種有效的模型框架被廣泛研究和應(yīng)用。
原型網(wǎng)絡(luò)通過(guò)學(xué)習(xí)類(lèi)別原型的特征,求取各個(gè)類(lèi)別原型的表示,通過(guò)樣本與類(lèi)別原型之間的距離進(jìn)行分類(lèi),初步解決了類(lèi)別不平衡的問(wèn)題,但仍然存在因樣本數(shù)量較少而導(dǎo)致的難以準(zhǔn)確獲取類(lèi)內(nèi)共有特征的問(wèn)題。文獻(xiàn)[2]通過(guò)對(duì)訓(xùn)練樣本的特征進(jìn)行收縮和擴(kuò)
展,生成額外的樣本,提高了模型的泛化能力。文獻(xiàn)[3]通過(guò)在特征空間進(jìn)行隨機(jī)變換和插值操作,生成多樣化的樣本,幫助模型更好地學(xué)習(xí)特征。文獻(xiàn)[4]結(jié)合半監(jiān)督學(xué)習(xí)與數(shù)據(jù)增強(qiáng),通過(guò)弱增強(qiáng)生成偽標(biāo)簽,強(qiáng)增強(qiáng)優(yōu)化模型的一致性。以上文獻(xiàn)利用不同的圖像增強(qiáng)方法來(lái)增加樣本數(shù)量,但簡(jiǎn)單的圖像變換無(wú)法有效增加樣本的多樣性。
針對(duì)上述現(xiàn)狀,本文提出一種基于類(lèi)注意力的原型網(wǎng)絡(luò)改進(jìn)方法。采用掩膜圖像進(jìn)行數(shù)據(jù)預(yù)處理,增強(qiáng)圖像的質(zhì)量和信息,改善小樣本數(shù)據(jù)質(zhì)量;引入注意力機(jī)制區(qū)分無(wú)關(guān)特征和相關(guān)特征;設(shè)計(jì)類(lèi)注意力模塊,提取具有注意力信息的類(lèi)別原型表示,從而提高原型網(wǎng)絡(luò)在小樣本學(xué)習(xí)中的分類(lèi)性能和泛化能力。
1 相關(guān)工作
1.1 原型網(wǎng)絡(luò)
原型網(wǎng)絡(luò)是一種基于距離度量的分類(lèi)器[5],其先通過(guò)學(xué)習(xí)每個(gè)類(lèi)別的原型向量來(lái)表示不同類(lèi)別之間的關(guān)系,再通過(guò)計(jì)算查詢(xún)集樣本的特征向量與支持集每個(gè)類(lèi)別原型向量之間的歐氏距離進(jìn)行分類(lèi)[6]。傳統(tǒng)的類(lèi)別原型向量通常由每個(gè)類(lèi)別所有樣本的特征向量進(jìn)行均值計(jì)算得到。
1.2 數(shù)據(jù)增強(qiáng)方法
數(shù)據(jù)增強(qiáng)通過(guò)對(duì)訓(xùn)練數(shù)據(jù)進(jìn)行變換和擴(kuò)充,增加數(shù)據(jù)的多樣性和數(shù)量,從而改善模型的泛化能力和魯棒性。常用的數(shù)據(jù)增強(qiáng)方法包括平移、旋轉(zhuǎn)、縮放、翻轉(zhuǎn)等幾何變換[8],以及亮度、對(duì)比度、色彩等顏色變換[9]。數(shù)據(jù)增強(qiáng)不僅可通過(guò)對(duì)原始圖像進(jìn)行隨機(jī)變換來(lái)生成更多的訓(xùn)練數(shù)據(jù),還可通過(guò)剪切、填充、仿射等操作,改變?cè)紙D像的形狀和結(jié)構(gòu)。
近年來(lái),數(shù)據(jù)增強(qiáng)技術(shù)在深度學(xué)習(xí)領(lǐng)域取得了較大進(jìn)展。文獻(xiàn)[10]提出一種RandAugment數(shù)據(jù)增強(qiáng)方法,通過(guò)一系列的隨機(jī)變換來(lái)擴(kuò)充訓(xùn)練數(shù)據(jù)集;在ImageNet數(shù)據(jù)集上,模型的準(zhǔn)確率在基線基礎(chǔ)上提升了1.3%。文獻(xiàn)[11]提出一種Mixup數(shù)據(jù)增強(qiáng)方法,通過(guò)在訓(xùn)練樣本之間進(jìn)行線性插值來(lái)生成新的樣本,有效地增加了樣本的多樣性。
1.3 注意力機(jī)制
注意力機(jī)制是指在神經(jīng)網(wǎng)絡(luò)中,通過(guò)對(duì)輸入數(shù)據(jù)的不同部分進(jìn)行加權(quán)處理,使網(wǎng)絡(luò)更加關(guān)注有用的信息,廣泛應(yīng)用于自然語(yǔ)言處理、計(jì)算機(jī)視覺(jué)、語(yǔ)音識(shí)別等領(lǐng)域[12]。文獻(xiàn)[13]提出一種用于深度神經(jīng)網(wǎng)絡(luò)的注意力機(jī)制,可自適應(yīng)地調(diào)整輸入數(shù)據(jù)的通道權(quán)重,從而提高模型性能。文獻(xiàn)[14]提出一種高效通道注意力(efficient channel attention, ECA)模塊,利用自適應(yīng)卷積核計(jì)算每個(gè)通道的權(quán)重,避免了傳統(tǒng)通道注意力機(jī)制因采用全局平均池化操作而導(dǎo)致的信息損失。文獻(xiàn)[15]提出一種基于空間注意力和通道注意力機(jī)制的網(wǎng)絡(luò)模塊,利用一組卷積核來(lái)學(xué)習(xí)每個(gè)空間位置的權(quán)重,并結(jié)合通道注意力機(jī)制來(lái)提高特征圖的表達(dá)能力。文獻(xiàn)[16]提出一種Non-local注意力機(jī)制,利用所有位置的特征信息計(jì)算每個(gè)位置的權(quán)重,以實(shí)現(xiàn)不同空間位置特征的加權(quán),模型準(zhǔn)確率在基線基礎(chǔ)上提高了2.3%。
在小樣本場(chǎng)景下,文獻(xiàn)[17]引入自適應(yīng)注意力機(jī)制,根據(jù)樣本的重要性動(dòng)態(tài)調(diào)整模型的注意力,提高了模型對(duì)關(guān)鍵樣本的學(xué)習(xí)能力。文獻(xiàn)[18]設(shè)計(jì)了元權(quán)重生成器和空間注意力生成器結(jié)構(gòu),并將分類(lèi)預(yù)測(cè)得分改為對(duì)稱(chēng)形式,以提高模型的泛化能力。文獻(xiàn)[19]通過(guò)引入多級(jí)注意力機(jī)制、特征金字塔結(jié)構(gòu)、細(xì)粒度的注意力加權(quán)和端到端的訓(xùn)練策略,有效改進(jìn)了小樣本學(xué)習(xí)任務(wù)中的特征提取和分類(lèi)性能,使模型能夠更好地適應(yīng)小樣本的學(xué)習(xí)任務(wù)。
2 本文方法
2.1 訓(xùn)練策略
2.2 原型網(wǎng)絡(luò)改進(jìn)模型
在圖2的網(wǎng)絡(luò)模型中,將支持集圖像和查詢(xún)集圖像輸入同一特征提取模塊,獲取圖像的特征向量。支持集特征向量通過(guò)類(lèi)注意力模塊獲取關(guān)注類(lèi)內(nèi)共同信息的類(lèi)原型向量,通過(guò)計(jì)算查詢(xún)集樣本的特征向量與每個(gè)類(lèi)原型向量的歐氏距離進(jìn)行分類(lèi)。
2.2.1 數(shù)據(jù)增強(qiáng)模塊
數(shù)據(jù)增強(qiáng)技術(shù)在小樣本學(xué)習(xí)中被廣泛采用[21]。由于數(shù)據(jù)集樣本具有主體位置不定、大小不等、背景復(fù)雜等特點(diǎn),本文采用掩膜圖像對(duì)支持集圖像進(jìn)行隨機(jī)區(qū)域掩膜,提升原型網(wǎng)絡(luò)對(duì)局部信息的補(bǔ)全,以及不完全信息圖像的識(shí)別能力。掩膜效果圖如圖3所示。
掩膜圖像方法獨(dú)立于參數(shù)學(xué)習(xí)過(guò)程,因此可以嵌入到任何基于卷積神經(jīng)網(wǎng)絡(luò)(convolutional neural networks, CNN)的識(shí)別模型中。
2.2.2 特征提取模塊
將數(shù)據(jù)增強(qiáng)后的支持集圖像和查詢(xún)集圖像一起輸入到特征提取模塊,將所有支持集中的D維向量數(shù)據(jù)映射到新的Z維特征空間。特征提取模塊的特征提取器采用Vgg16模型作為主干網(wǎng)絡(luò),并引入了注意力機(jī)制,以重點(diǎn)關(guān)注提取圖像中的重要信息。
2.2.3 類(lèi)注意力模塊
類(lèi)注意力模塊將支持集圖像進(jìn)行類(lèi)注意力信息的提取,得到帶有權(quán)值的類(lèi)別原型表示。本文提出的類(lèi)注意力模塊主要包括Extract和Interaction"2個(gè)模塊,如圖4所示。
Extract模塊用于壓縮、提取圖像數(shù)據(jù)。經(jīng)過(guò)編碼后的類(lèi)內(nèi)K個(gè)C×H×W維度的特征向量,通過(guò)全局平均池化壓縮為K個(gè)C通道、1×1維的特征圖,即將每個(gè)樣本、每個(gè)通道內(nèi)H×W維的圖像轉(zhuǎn)化為一個(gè)數(shù)字表示,得到K×C個(gè)類(lèi)別內(nèi)所有樣本的權(quán)值。提取圖像數(shù)據(jù)的計(jì)算公式為
2.2.4 距離度量模塊
距離度量模塊基于度量的方式來(lái)計(jì)算查詢(xún)集樣本的特征向量與支持集每個(gè)類(lèi)別原型向量之間的距離,再轉(zhuǎn)化為相似性度量,從而判斷樣本類(lèi)別。
3 實(shí)驗(yàn)與評(píng)估
3.1 數(shù)據(jù)集
本實(shí)驗(yàn)數(shù)據(jù)集采用miniImageNet,其包含60"000幅圖像,分為100個(gè)類(lèi)別。采用文獻(xiàn)[20]的數(shù)據(jù)集劃分方式將訓(xùn)練集、驗(yàn)證集、測(cè)試集分別劃分為64、16、20個(gè)類(lèi)別,同時(shí)將輸入圖像處理為84×84像素。
3.2 實(shí)驗(yàn)環(huán)境
在Ubuntu操作系統(tǒng)上,采用開(kāi)源深度學(xué)習(xí)框架PyTorch搭建模型,利用GPU進(jìn)行實(shí)驗(yàn)計(jì)算,以提高模型的迭代速度。為保證實(shí)驗(yàn)的嚴(yán)謹(jǐn)性,設(shè)置固定的隨機(jī)順序來(lái)保證每次對(duì)比實(shí)驗(yàn)抽取的樣本一致。采用Vgg16模型作為主干網(wǎng)絡(luò)進(jìn)行訓(xùn)練,并確保每次實(shí)驗(yàn)僅有驗(yàn)證項(xiàng)發(fā)生改變。實(shí)驗(yàn)環(huán)境如表1所示,實(shí)驗(yàn)參數(shù)如表2所示。
3.3 評(píng)價(jià)指標(biāo)
本實(shí)驗(yàn)采用5-way"1-shot和5-way"5-shot的驗(yàn)證模式,即在支持集中每次隨機(jī)選擇5個(gè)支持集類(lèi)別,每個(gè)支持集類(lèi)別分別有1個(gè)樣本和5個(gè)樣本進(jìn)行實(shí)驗(yàn)。利用查詢(xún)集中樣本的準(zhǔn)確率來(lái)評(píng)估模型性能。準(zhǔn)確率的計(jì)算公式為
3.4 實(shí)驗(yàn)結(jié)果
3.4.1 "數(shù)據(jù)增強(qiáng)方法驗(yàn)證實(shí)驗(yàn)
選取翻轉(zhuǎn)、旋轉(zhuǎn)、隨機(jī)裁剪等不同的數(shù)據(jù)增強(qiáng)方法進(jìn)行驗(yàn)證實(shí)驗(yàn)。其中,RandomCrop方法根據(jù)設(shè)置的參數(shù)隨機(jī)裁剪原始圖像;RandomHorizontalFlip、RandomVerticalFlip方法水平、垂直翻轉(zhuǎn)原始圖像;ColorJitter方法隨機(jī)修改原始圖像的亮度、對(duì)比度和飽和度;RandomRotation方法隨機(jī)角度旋轉(zhuǎn)原始圖像。實(shí)驗(yàn)結(jié)果如表3所示。
由表3可以看出:RandomCrop方法的準(zhǔn)確率在基線基礎(chǔ)上下降約8%,而RandomHorizontalFlip、RandomVerticalFlip、ColorJitter、RandomRotation、本文方法的準(zhǔn)確率在基線基礎(chǔ)上分別提高了0.97%、0.51%、0.82%、0.05%、1.58%,本文方法的準(zhǔn)確率提高最為顯著,表明本文數(shù)據(jù)增強(qiáng)方法有效。
3.4.2 "小樣本學(xué)習(xí)方法對(duì)比實(shí)驗(yàn)
將匹配網(wǎng)絡(luò)(matching networks, MN)、關(guān)系網(wǎng)絡(luò)(relation networks,"RN)、記憶匹配網(wǎng)絡(luò)(memory matching networks,"MMN)、注意力吸引網(wǎng)絡(luò)(attention attractor networks,"AAN)、模型無(wú)關(guān)的元學(xué)習(xí)(model--agnostic meta-learning,"MAML)、Reptile、文獻(xiàn)[27]、文獻(xiàn)[28]、Prototypical network等9種經(jīng)典的小樣本學(xué)習(xí)方法與本文方法進(jìn)行對(duì)比實(shí)驗(yàn),結(jié)果如表4所示。
由表4可以看出:本文方法在5-way 1-shot任務(wù)上取得了53.42%的準(zhǔn)確率,與其他方法相比處于較高水平;在5-way 5-shot任務(wù)上則取得了70.33%的準(zhǔn)確率,優(yōu)于表中所有對(duì)比方法,說(shuō)明本文方法在少量樣本的場(chǎng)景下具有更出色的泛化能力。
4 結(jié)論
本文受原型網(wǎng)絡(luò)和注意力機(jī)制的啟發(fā),利用數(shù)據(jù)增強(qiáng)方法增加樣本的多樣性,引入注意力機(jī)制提升網(wǎng)絡(luò)特的征提取能力,利用類(lèi)注意力模塊改進(jìn)原型網(wǎng)絡(luò),解決小樣本學(xué)習(xí)因樣本多樣性不足導(dǎo)致的類(lèi)內(nèi)共有特征難以準(zhǔn)確獲取的問(wèn)題。實(shí)驗(yàn)結(jié)果表明,數(shù)據(jù)增強(qiáng)方法能夠較好地增加數(shù)據(jù)樣本,提升模型對(duì)不同樣本的辨識(shí)性;類(lèi)注意力機(jī)制能較好提取類(lèi)內(nèi)信息,更好地表示類(lèi)別原型。
?The author(s) 2024. This is an open access article under the CC BY-NC-ND 4.0 License (https://creativecommons.org/licenses/ by-nc-nd/4.0/)
參考文獻(xiàn)
[1] 趙凱琳,靳小龍,王元卓.小樣本學(xué)習(xí)研究綜述[J].軟件學(xué)報(bào),nbsp;2021,32(2):349-369.
[2] HARIHARAN B, GIRSHICK"R."Low-shot visual recognition by shrinking and hallucinating features[J]."IEEE Transactions on Pattern Analysis and Machine Intelligence, 2017,39(8): 1653-1667.
[3] DEVRIES T, TAYLOR"G W. Dataset augmentation in feature space"[J]. arXiv preprint arXiv:1702.05538, 2017.
[4] CUBUK E D, ZOPH B, MANE"D, et al. Autoaugment: Learning augmentation policies from data[J]. arXiv preprint arXiv:1805."09501, 2018.
[5] 王圣杰,王鐸,梁秋金,等.小樣本學(xué)習(xí)綜述[J].空間控制技術(shù)與應(yīng)用,2023,49(5):1-10.
[6] 陳良臣,傅德印.面向小樣本數(shù)據(jù)的機(jī)器學(xué)習(xí)方法研究綜述[J].計(jì)算機(jī)工程,2022,48(11):1-13.
[7] SNELL J, SWERSKY K, ZEMEL R S."Prototypical networks for few-shot learning[J]. Advances in Neural Information pro-cessing Systems, 2017:30.
[8] SIMARD P Y, STEINKRAUS D, PLATT J C. Best practices for convolutional neural networks applied to visual document analysis[C]//7th International Conference on Document Anal-ysis and Recognition (ICDAR)."Edinburgh, UK: IEEE, 2003.
[9] KRIZHEVSKY A, SUTSKEVER I, HINTON"G."ImageNet classification with deep convolutional neural networks[J]."Communications of the ACM, 2017,60(6):84-90.
[10] CUBUK E D, ZOPH B, SHLENS"J, et al. Randaugment: Practical automated data augmentation with a reduced search space[C]//Proceedings of the IEEE/CVF Conference on Com-puter Vision and Pattern Recognition Workshops,"2020:702-703.
[11] ZHANG H, CISSE M, DAUPHIN Y N,"et al."Mixup: Beyond Empirical Risk Minimization[J]. arXiv preprint arXiv:1710. 09412, 2017.
[12] 彭云聰,秦小林,張力戈,等.面向圖像分類(lèi)的小樣本學(xué)習(xí)算法綜述[J].計(jì)算機(jī)科學(xué),2022,49(5):1-9.
[13] HU J, SHEN L, SUN G."Squeeze-and-Excitation networks[C]//"Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition,"2018:7132-7141.
[14] WANG Q, WU B, ZHU"P,"et al."ECA-Net: Efficient channel attention for deep convolutional neural networks[C]// Proceed-ings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition,"2020:11534-11542.
[15] WOO"S, PARK"J, LEE"J Y, et al. Cbam: Convolutional block attention module[C]//Proceedings of the European Conference on Computer Vision (ECCV),"2018:3-19.
[16]"WANG"X", GIRSHICK R, GUPTA A, et al. Non-local neural networks[C]//Proceedings of the IEEE Conference on Com-puter Vision and Pattern Recognition,"2018:7794-7803.
[17] XING C, ROSTAMZADEH N, ORESHKIN B N,"et al."Adap-tive cross-modal few-shot learning[C]. Advances in Neural In-formation Processing Systems, 2019.
[18] JIANG Z, KANG B, ZHOU K, et al. Few-shot classification via adaptive attention[J]. arXiv preprint arXiv:2008.02465, 2020.
[19] 汪榮貴,韓夢(mèng)雅,楊娟,等.多級(jí)注意力特征網(wǎng)絡(luò)的小樣本學(xué)習(xí)[J].電子與信息學(xué)報(bào),2020,42(3):772-778.
[20] VINYALS O, BLUNDELL C, LILLICRAP T,"et al."Matching networks for one shot learning[J]."Advances in Neural Infor-mation Processing Systems, 2016:29.
[21] LI B, HOU Y, CHE"W."Data augmentation approaches in natural language processing: A survey[J]."AI Open, 2022,3:71-90.
[22] SUNG F, YANG Y, ZHANG L,"et al."Learning to compare: relation network for few-shot learning[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern"Recogni-tion,"2018:1199-1208.
[23] CAI Q, PAN Y W, YAO T,"et al. Memory matching net-works"for"one-shot image recognition[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recogni-tion,"2018:4080-4088.
[24] REN M, LIAO R, FETAYA"E,"et al."Incremental Few-Shot Learning with Attention Attractor Networks[C]. Advances in Neural Information Processing Systems, 2019.
[25] FINN C, ABBEEL P, LEVINE S. Model-agnostic meta-learning for fast adaptation of deep networks[C]//International Conference on Machine Learning. PMLR, 2017:1126-1135.
[26] NICHOL A, SCHULMAN J."Reptile: A"scalable metalearning algorithm[J]. arXiv preprint arXiv:1803.02999, 2018,2(3):4.
[27] RAVI S, LAROCHELLE H. Optimization as a model for few--shot learning[C]//International Conference on Learning Repre--sentations,"2017.
[28] YE H J, CHAO W L."How to train your"MAML to excel in few-shot classification[J]. arXiv preprint arXiv:2106.16245, 2021.
作者簡(jiǎn)介:
曹增輝,男,1997年生,碩士研究生,主要研究方向:圖像處理和小樣本圖像分類(lèi)。E-mail:"czh258biu@163.com
陳浩,男,2000年生,碩士研究生,主要研究方向:人工智能和原型網(wǎng)絡(luò)。E-mail:"chenhao_gd@163.com
曹雅慧,女,2003年生,專(zhuān)科,主要研究方向:人工智能。E-mail:"15103814269@163.com