吳豪杰 ,王妍潔 ,蔡文炳 ,王 飛 ,劉 洋 ,蒲 鵬 ,林紹輝
(1.中國電子科技集團公司第二十七研究所,鄭州 450047;2.北京跟蹤與通信技術(shù)研究所,北京 100094;3.中國人民解放軍63726 部隊,銀川 750004;4.華東師范大學 計算機科學與技術(shù)學院,上海 200062;5.華東師范大學 數(shù)據(jù)科學與工程學院,上海 200062)
近年來,隨著深度學習與圖形處理器(Graphics Processing Unit,GPU)硬件的不斷發(fā)展,卷積神經(jīng)網(wǎng)絡(luò)(Convolutional Neural Networks,CNNs)已經(jīng)在諸多人工智能領(lǐng)域取得了顯著的成效,如區(qū)塊鏈[1]、圖像分類[2]、目標檢測[3]等.得益于其大規(guī)模的數(shù)據(jù)量與強大的特征提取能力,CNNs 在某些任務(wù)上甚至已經(jīng)超過了人類識別的準確率[4].同時,GPU 硬件的高速發(fā)展大大提高了網(wǎng)絡(luò)模型的計算效率.
隨著網(wǎng)絡(luò)模型性能的提升,其計算開銷與存儲量也在不斷增加.如AlexNet[2]模型,其具有0.61 億網(wǎng)絡(luò)參數(shù)和7.29 億次浮點計算量(Floating-point Operations per Second,FLOPs),占用約240 MB 的存儲空間.對于被廣為使用的152 層殘差網(wǎng)絡(luò)(Residual Network-152,ResNet-152)[4]具有0.57 億網(wǎng)絡(luò)參數(shù)和113 億次浮點計算量,占用約230 MB 的存儲空間.龐大的網(wǎng)絡(luò)參數(shù)意味著更大的內(nèi)存占用,而巨大的浮點計算量意味著高昂的訓練代價與較小的推理速度.這使得如此高存儲、高功耗模型無法直接在資源有限的應(yīng)用場景下應(yīng)用,如手機、無人機、機器人等邊緣嵌入式設(shè)備.因此,在保持模型識別準確率的前提下,對于網(wǎng)絡(luò)模型進行壓縮與加速,以適應(yīng)邊緣設(shè)備的實際要求,成為了當前計算機視覺領(lǐng)域火熱的研究課題.與此同時,也有研究表明[5],在巨大的網(wǎng)絡(luò)參數(shù)內(nèi)部,并不是所有的結(jié)構(gòu)和參數(shù)對于網(wǎng)絡(luò)的識別預測能力都起到?jīng)Q定性作用,這使得模型壓縮技術(shù),即移除冗余性參數(shù)和計算量成為了一種有效的解決方案.
當前主流的模型壓縮方法可以分為5 種,分別為參數(shù)剪枝、參數(shù)量化、低秩分解、輕量型網(wǎng)絡(luò)結(jié)構(gòu)設(shè)計和知識蒸餾(Knowledge Distillation,KD).知識蒸餾方法可以直接設(shè)定壓縮后模型的結(jié)構(gòu)、計算量和參數(shù)量,以及不引入額外的計算算子,這使得知識蒸餾技術(shù)得到了廣泛關(guān)注.因此,本文也著重研究基于知識蒸餾的模型壓縮方法.知識蒸餾方法將較大和較小的網(wǎng)絡(luò)分別定義為教師網(wǎng)絡(luò)和學生網(wǎng)絡(luò) (也稱之為壓縮后網(wǎng)絡(luò)).其主要思想在于,通過最小化該兩個網(wǎng)絡(luò)輸出分布差異,來實現(xiàn)網(wǎng)絡(luò)間的知識遷移,使得學生網(wǎng)絡(luò)盡可能地獲得教師網(wǎng)絡(luò)的知識,提高學生網(wǎng)絡(luò)的準確率.從而,學生網(wǎng)絡(luò)可以在維持其參數(shù)量不變的情況下提升性能,盡可能逼近甚至有可能超越教師網(wǎng)絡(luò)的性能.傳統(tǒng)的知識蒸餾方法是將網(wǎng)絡(luò)的輸出分布作為知識在網(wǎng)絡(luò)間進行遷移,隨著該研究領(lǐng)域的進一步發(fā)展,研究發(fā)現(xiàn)[6],利用其他一些具有代表性的表征信息或知識在網(wǎng)絡(luò)間進行遷移或蒸餾,可以獲得比傳統(tǒng)知識蒸餾方法更好的效果.知識蒸餾方法大致又可以分為: ①基于網(wǎng)絡(luò)輸出層的知識蒸餾方法;② 基于網(wǎng)絡(luò)中間層的知識蒸餾方法;③基于樣本關(guān)系之間的知識蒸餾方法.
本文提出了一種新的基于隱層相關(guān)聯(lián)算子的知識蒸餾 (Correlation Operation Based Knowledge Distillation,CorrKD) 方法,通過計算教師網(wǎng)絡(luò)與學生網(wǎng)絡(luò)各自隱含層之間的關(guān)聯(lián)性,挖掘出更有效的知識表征,從而將教師的知識表征遷移到學生的知識表征中,提高學生網(wǎng)絡(luò)的判別性.該方法的核心是利用了被廣泛應(yīng)用于光流[7-8]、圖像匹配[9]等領(lǐng)域內(nèi)的相關(guān)聯(lián)算子,用于提取網(wǎng)絡(luò)中間層的知識表征.相關(guān)聯(lián)算子的特性在于,可以很好地表征兩個特征之間的匹配程度,并反映其特征的變化過程.首先,本文對于網(wǎng)絡(luò)中每個階段的輸入特征與輸出特征,利用相關(guān)聯(lián)算子進行建模與知識提取,有效獲得了圖像特征的學習變化信息.然后,將教師網(wǎng)絡(luò)每階段通過相關(guān)聯(lián)算子提取出的表征信息作為知識,遷移到學生網(wǎng)絡(luò)中,提升學生網(wǎng)絡(luò)判別性和學習有效性.
在CIFAR-10 和CIFAR-100 分類數(shù)據(jù)集評測結(jié)果中,相比其他中間層知識蒸餾方法,本文所提出的方法取得了較好的效果.同時,本文所提出的方法在減小網(wǎng)絡(luò)的計算量和參數(shù)量的同時,能夠有效逼近原始網(wǎng)絡(luò)的準確率.
除本文將詳細介紹的知識蒸餾方法外,其他主流的模型壓縮方法有: ①參數(shù)剪枝[10-11],該方法的主要思想在于,通過對已訓練好的深度神經(jīng)網(wǎng)絡(luò)模型移除冗余、信息量較少的權(quán)值,減少網(wǎng)絡(luò)模型的參數(shù),進而增大模型的計算速度和減小模型所占用的存儲空間,實現(xiàn)模型壓縮;② 參數(shù)量化[12-14],該方法的主要思想是一種將多個參數(shù)實現(xiàn)共享的直接表示形式,其核心思想在于,利用較低的位來代替原始32 位的浮點型參數(shù),從而縮減網(wǎng)絡(luò)存儲和浮點計算次數(shù);③低秩分解[15-16],該方法的核心思想在于,利用矩陣或張量的分解技術(shù)對網(wǎng)絡(luò)模型中的原始卷積核進行分解.一般來說,卷積計算是網(wǎng)絡(luò)中復雜度最高且最為普遍的計算操作,通過對張量進行分解從而減小模型內(nèi)部冗余性,實現(xiàn)模型壓縮;④ 輕量型網(wǎng)絡(luò)結(jié)構(gòu)設(shè)計,輕量型網(wǎng)絡(luò)結(jié)構(gòu)設(shè)計的方法主要是改變了卷積神經(jīng)網(wǎng)絡(luò)的結(jié)構(gòu)特征,提出了一些新穎的輕量計算模塊或操作,從而精簡網(wǎng)絡(luò)結(jié)構(gòu),增大處理速度.如基于深度可分離卷積的MobileNet[17],利用神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)搜索得到的EfficientNet[18]等.
知識蒸餾方法[19]指利用教師網(wǎng)絡(luò)中的知識表征為學生網(wǎng)絡(luò)提供指導,以提高學生網(wǎng)絡(luò)的性能.傳統(tǒng)的知識蒸餾方法通過最小化教師網(wǎng)絡(luò)和學生網(wǎng)絡(luò)類別輸出分布的KL (Kullback-Leibler)散度來實現(xiàn)蒸餾.除了在輸出層外,網(wǎng)絡(luò)中間層的特征信息也被應(yīng)用到知識蒸餾方法中.
中間層特征知識的構(gòu)造.Romero 等[20]提出的FitNet 是較早利用中間特征信息進行知識蒸餾的方法,其目標是使經(jīng)過奇異值分解的學生網(wǎng)絡(luò)盡可能學習教師網(wǎng)絡(luò)中間層的特征信息.隨后,Zagoruyko 等[21]提出在網(wǎng)絡(luò)中間層引入注意力機制,將每層的注意力特征作為可學習的知識遷移到學生網(wǎng)絡(luò)中.近年來,隨著自注意力模型被廣泛運用到變形器[22]中,進而獲得人工智能領(lǐng)域各項任務(wù)的性能突破,相關(guān)知識蒸餾方法[23-24]通過對齊教師與學生的自注意力矩陣實現(xiàn)知識遷移.Yim 等[25]提出了FSP (Flow of Solution Procedure)方法,將網(wǎng)絡(luò)中每層之間的數(shù)據(jù)流動關(guān)系作為知識,由教師網(wǎng)絡(luò)遷移到學生網(wǎng)絡(luò)中.除此之外,樣本之間的關(guān)系特征也被發(fā)現(xiàn)可以凝煉出更好的知識表示.例如,Park 等[26]提出RKD (Relational Knowledge Distillation)知識蒸餾框架,對于不同樣本網(wǎng)絡(luò)輸出的結(jié)構(gòu)關(guān)系進行建模,將關(guān)系特征進行知識遷移.此外,Liu 等[27]通過將教師網(wǎng)絡(luò)特征空間映射到由頂點與邊構(gòu)成的圖表示空間中,然后對齊教師與學生網(wǎng)絡(luò)的頂點以及它們邊的對應(yīng)信息實現(xiàn)知識蒸餾.Tung 等[28]利用網(wǎng)絡(luò)中間層每個樣本之間的相似度信息進行知識遷移.Kim 等[29]提出在教師網(wǎng)絡(luò)的最后一層特征中提取便于學生網(wǎng)絡(luò)理解的轉(zhuǎn)移因子,將知識傳遞給學生網(wǎng)絡(luò).對于教師網(wǎng)絡(luò)和學生網(wǎng)絡(luò)中間層特征不一致的情況,Heo 等[30]提出了使用 1×1 卷積進行維度對齊,并構(gòu)建教師網(wǎng)絡(luò)激活邊界作為中間層知識遷移到學生網(wǎng)絡(luò)中.不僅如此,特征圖的雅可比梯度信息[31]也可以作為中間層特征知識表示.近年來,出現(xiàn)了一些在輸出層特征進行對比學習[32]或基于自監(jiān)督[33]的知識蒸餾方法,分別用于挖掘教師網(wǎng)絡(luò)和學生網(wǎng)絡(luò)對于不同樣本之間的關(guān)系,從而將教師網(wǎng)絡(luò)的關(guān)系知識遷移到學生網(wǎng)絡(luò)中.不同于以上知識蒸餾方法,本文所提出的基于相關(guān)聯(lián)系數(shù)的知識蒸餾方法作用于每階段中間層特征信息,從而獲得每階段中間特征變化信息,能更好構(gòu)建知識表征,提高學生網(wǎng)絡(luò)的學習性能.
使用優(yōu)化訓練策略進行中間層知識蒸餾.近年來,大量生成對抗思想被應(yīng)用到中間層知識蒸餾中,提高知識蒸餾性能.例如,Su 等[34]引入了任務(wù)驅(qū)動的注意力機制,將教師網(wǎng)絡(luò)和學生網(wǎng)絡(luò)各自高層信息嵌入低層中,實現(xiàn)中間層信息的遷移,同時加入判別器用于增強學生網(wǎng)絡(luò)最后輸出特征的魯棒性.類似地,Shen 等[35]提出了基于對抗學習的多教師網(wǎng)絡(luò)集成蒸餾框架,利用自適應(yīng)池化操作對齊一個學生與多個教師集成網(wǎng)絡(luò)的中間層輸出維度,同時利用生成對抗策略對池化的中間層特征進行對抗訓練,提高了知識蒸餾性能.Chung 等[36]提出了基于中間層特征圖的在線對抗蒸餾框架,設(shè)計教師網(wǎng)絡(luò)和學生網(wǎng)絡(luò)的判別器,用于共同學習和對齊這兩個網(wǎng)絡(luò)在訓練過程中的特征圖分布的變化情況.Jin 等[37]提出了一種路線限制優(yōu)化策略,預先設(shè)定好教師網(wǎng)絡(luò)訓練的中間模型狀態(tài),并通過逐步對齊學生網(wǎng)絡(luò)與其中間層特征分布,使得學生網(wǎng)絡(luò)獲得更好的局部最優(yōu)解.
知識蒸餾方法[19]認為在數(shù)據(jù)的網(wǎng)絡(luò)輸出中,每一個數(shù)據(jù)的預測概率結(jié)果都可以看作是一個分布,不僅關(guān)注于置信度最高的類別所對應(yīng)的結(jié)果,而且對于預測錯誤結(jié)果的置信度概率也具備一定的網(wǎng)絡(luò)知識.在傳統(tǒng)分類任務(wù)所使用的交叉熵損失函數(shù)中,只會關(guān)注對應(yīng)于正確類別的概率值,對于其他類別所對應(yīng)的概率是直接丟棄,沒有利用的,Hinton 等[19]將其稱作是暗知識.在知識蒸餾的過程中,學生網(wǎng)絡(luò)所學習到的,不僅是預測正確的類別所對應(yīng)的概率值結(jié)果,而且包括教師網(wǎng)絡(luò)所學習到的暗知識.
在具體的實現(xiàn)過程中,將教師網(wǎng)絡(luò)記為ft,學生網(wǎng)絡(luò)記為fs,將輸入記作x,教師網(wǎng)絡(luò)和學生網(wǎng)絡(luò)的模型輸出結(jié)果分別記為zt和zs,且zt=ft(x),zs=fs(x),zt,zs∈Rd,d為總類別數(shù).對于網(wǎng)絡(luò)得到的輸出分布,利用 Softmax對此進行歸一化,得到概率分布.同時,還引入了溫度分布參數(shù)τ用來平滑該層的輸出分布,以強化網(wǎng)絡(luò)輸出的概率分布中所學習到的知識,通過溫度平滑后的網(wǎng)絡(luò)輸出被稱為軟目標.對此,以教師網(wǎng)絡(luò)為例,對于第i個輸入樣本xi,其軟目標用公式表示為
式(2)中:n表示樣本總個數(shù),KL(ps||pt) 定義為學生網(wǎng)絡(luò)輸出分布與教師網(wǎng)絡(luò)輸出分布之間差異,具體公式表示為
所以,在學生網(wǎng)絡(luò)訓練的過程中,教師網(wǎng)絡(luò)的軟目標與真實標簽共同起到監(jiān)督作用.傳統(tǒng)知識蒸餾損失函數(shù)為
式(3)中:LCE為傳統(tǒng)的學生網(wǎng)絡(luò)輸出與真實標簽的交叉熵損失函數(shù);α為平衡因子,用于權(quán)衡LCE和LKL的重要性比例.
相關(guān)聯(lián)算子[7]被廣泛應(yīng)用到光流、圖像匹配、目標跟蹤領(lǐng)域中,用于描述兩張圖像或兩個特征之間的匹配程度(圖1).對于三維的圖像特征張量A和B,其尺寸為C×H ×W,C、H和W分別表示其特征圖的通道數(shù)、高度與寬度.特征張量A中給定位置 (i,j)的特征為PA(i,j)∈RC,需要計算其與特征張量B中所對應(yīng)位置圖像塊的特征相似度,這里所對應(yīng)的圖像塊以 (i,j)為中心,大小為k×k,將該區(qū)域內(nèi)的像素位置記為 (i′,j′) ,所對應(yīng)的特征為PB(i′,j′) ,與PA(i,j)類似,該像素特征均為C維向量.因此,可以通過計算內(nèi)積的方式得到對應(yīng)像素特征之間的相似度,由此得到相關(guān)聯(lián)算子φ,其計算公式為
圖1 相關(guān)聯(lián)算子示意圖Fig.1 Illustration of correlation operation
式(4)中:⊙表示向量內(nèi)積,為歸一化系數(shù).由此,可以得到特征張量A和B之間的相關(guān)聯(lián)算子,可以將其記為φ(A,B)∈Rk2×H×W.所以,對于給定的兩個三維圖像特征張量,可以通過計算像素特征與圖像塊中每個像素之間的相似度,得到尺寸為k2×H ×W的相關(guān)聯(lián)算子,用于反映特征之間的相似程度或匹配程度.
借助相關(guān)聯(lián)算子,可以計算網(wǎng)絡(luò)模型隱層中尺度相同的兩個特征張量之間的特征,用以反映特征的匹配相似程度,并利用其進行知識遷移 (圖2).圖2 中的KL 損失LKL和LCor損失分別被定義于式(2)和式(5)中,xi和分別為第i個輸入樣本和該樣本增強變化后的表示.
圖2 基于隱層相關(guān)聯(lián)算子蒸餾方法的整體框架Fig.2 Illustration of intermediate CorrKD framework
通常,網(wǎng)絡(luò)模型會根據(jù)其特征圖空間尺寸大小的不同而劃分成不同的階段,換句話說,在相同的網(wǎng)絡(luò)階段內(nèi),其中間特征的維度尺寸都是相同的.因此,可以將每個階段的第一層特征與最后一層輸出特征作為相關(guān)聯(lián)算子中的特征張量A和B.該相關(guān)聯(lián)算子的計算可以很好地反映出模型每個階段對于數(shù)據(jù)的處理變化過程,成為非常有效的知識表征.因此,可以將相關(guān)聯(lián)算子計算結(jié)果用作知識蒸餾的表征信息,由教師網(wǎng)絡(luò)對學生網(wǎng)絡(luò)進行指導.假設(shè)網(wǎng)絡(luò)有N個階段,教師網(wǎng)絡(luò)和學生網(wǎng)絡(luò)的第i個階段的第一層輸入特征分別記為,最后一層的輸出特征分別記為Fit2和Fis2,其知識遷移的過程可以利用LCor損失進行約束,對此,基于隱層相關(guān)聯(lián)算子的知識遷移損失函數(shù)可以表示為
式(5)中:λi,i=1,2,···,N表示第i階段的權(quán)重因子,||·||2為L2范數(shù).為了更好形成多樣的知識表征,在本文中引入數(shù)據(jù)增強和變化[4](如旋轉(zhuǎn)、翻轉(zhuǎn)、顏色變化等),可以更有效地將隱含層的相關(guān)聯(lián)算子的知識遷移到學生網(wǎng)絡(luò)中,從而產(chǎn)生更好的效果.通過結(jié)合了教師網(wǎng)絡(luò)中傳統(tǒng)知識蒸餾損失函數(shù) (式 (3))和隱層相關(guān)聯(lián)算子的知識遷移損失函數(shù) (式 (5)),可以得到該知識蒸餾方法完整的訓練損失函數(shù)公式為
式(6)中:β為超參數(shù),用于控制3 個損失 (LCE、LKL和LCor) 的平衡性.在訓練過程中,本文直接使用梯度下降法優(yōu)化式 (6),選擇學生網(wǎng)絡(luò)進行測試,并計算出學生網(wǎng)絡(luò)的準確率作為該方法的評測效果.
本文在兩個經(jīng)典的分類公開數(shù)據(jù)集CIFAR-10 與CIFAR-100 上進行了實驗,均包含6 萬張長寬尺寸均為32 的圖像,其中5 萬張用于訓練,剩下的1 萬張用于測試,他們的分類類別數(shù)分別為10和100.
本文所提出的方法使用Pytorch 在單張GPU 上進行實現(xiàn),對于兩種數(shù)據(jù)集均采用隨機梯度下降方法進行優(yōu)化.在訓練中,圖像批量大小設(shè)置為64,學習率設(shè)置為0.05,動量設(shè)置為0.9,權(quán)重衰減系數(shù)為0.000 5.對于教師網(wǎng)絡(luò),利用標準交叉熵損失函數(shù)進行訓練,訓練迭代次數(shù)為240,其學習率分別在第150、180、210 次迭代時,分別縮小為原來的1/10,訓練完成后將教師網(wǎng)絡(luò)進行保存,存儲于本地磁盤中.
對于學生網(wǎng)絡(luò),需要先讀取教師網(wǎng)絡(luò)的模型參數(shù),利用所提出的損失函數(shù)式 (6) 進行訓練,模型訓練優(yōu)化器與學習率設(shè)置均與教師網(wǎng)絡(luò)一致,訓練迭代次數(shù)設(shè)為300,其學習率分別在第180,220,260 次迭代時,分別縮小為原來的1/10.
在相關(guān)聯(lián)算子的計算過程中,需要引入數(shù)據(jù)增強,首先,對于圖像進行隨機旋轉(zhuǎn)與翻轉(zhuǎn).其次,在圖像色彩上從灰度轉(zhuǎn)化、色彩抖動、高斯模糊等操作中隨機選取一種對圖像進行色彩上的增強.在相關(guān)聯(lián)算子的計算過程中,參數(shù)k=7,對于所選取的網(wǎng)絡(luò)模型,其結(jié)構(gòu)均為4 個階段,也就是式 (5) 中的N=4,同時將每個階段的權(quán)重設(shè)為相等,也就是λi=1 .設(shè)置式 (1) 中的τ=4 .最后,設(shè)置式 (6)中的α=0.2,β=5 .
本文所提出的方法在多種模型結(jié)構(gòu)上進行實驗驗證,選取ResNet[4]與WideResNet[38](WRN)作為網(wǎng)絡(luò)主干,并在多種教師網(wǎng)絡(luò)與學生網(wǎng)絡(luò)組合上進行實驗.表1 總結(jié)了4 組教師網(wǎng)絡(luò)與學生網(wǎng)絡(luò)的參數(shù)量與計算量信息.在表2 和表3 中,總結(jié)了本文所提出方法的性能效果,其中本文所提出的基于隱層相關(guān)聯(lián)算子的知識蒸餾方法記為CorrKD,僅利用中間層式 (5) 與交叉熵損失訓練得到的學生網(wǎng)絡(luò)方法簡稱為Corr,KD 表示僅利用式 (3) 進行訓練的傳統(tǒng)知識蒸餾訓練結(jié)果.注意到表2 與表3 中的第3 和第4 列分別表示教師網(wǎng)絡(luò)與學生網(wǎng)絡(luò)在正常情況下訓練得到的基準準確率結(jié)果 (即只使用交叉熵損失函數(shù)).KD 展示了學生網(wǎng)絡(luò)在利用式 (3) 訓練得到的傳統(tǒng)知識蒸餾方法的結(jié)果.
表1 實驗所用模型參數(shù)量與計算量信息Tab.1 Model parameters and FLOPs information used in the experiment
從實驗結(jié)果來看,單純基于中間隱層相關(guān)聯(lián)算子的知識遷移方法可以對于學生網(wǎng)絡(luò)的訓練帶來一定的促進作用,但效果并不明顯.通過結(jié)合了輸出層的傳統(tǒng)知識蒸餾方法KD 之后,在學生網(wǎng)絡(luò)的分類正確率上,獲得了很好的性能提升.在蒸餾教師網(wǎng)絡(luò)WRN40-2 時,在CIFAR-10 上學生網(wǎng)絡(luò)WRN16-2 的網(wǎng)絡(luò)參數(shù)和網(wǎng)絡(luò)計算量都約為原來教師網(wǎng)絡(luò)WRN40-2 的31.8%,即參數(shù)量 (教師網(wǎng)絡(luò)參數(shù)量為2.2 M,學生網(wǎng)絡(luò)參數(shù)量為0.7 M,教師網(wǎng)絡(luò)計算量為329.0 M,學生網(wǎng)絡(luò)計算量為 101.6 M).如表2 所示,由本文所提出的CorrKD 方法得到的學生網(wǎng)絡(luò)準確率只下降了0.5 百分點 (教師網(wǎng)絡(luò)準確率為95.2%,學生網(wǎng)絡(luò)使用CorrKD 方法準確率為94.7%).對于類別個數(shù)更多的CIFAR-100 上,同樣蒸餾的網(wǎng)絡(luò)選擇,由本文所提出的CorrKD 方法壓縮WRN40-2 后的網(wǎng)絡(luò)計算量和參數(shù)量約是壓縮前的31.8% (表1),準確率只下降1 百分點 (表3 中教師網(wǎng)絡(luò)準確率為76.8%,由CorrKD 方法得到的準確率為75.8%).由此可見,本文所提出的方法在準確率有限下降的情況下,模型能夠獲得顯著的壓縮比,壓縮后形成的學生網(wǎng)絡(luò)能夠有效嵌入受限移動設(shè)備端中.
表2 CorrKD 在CIFAR-10 上實驗結(jié)果Tab.2 Experimental results of CorrKD on CIFAR-10
表3 CorrKD 在CIFAR-100 上實驗結(jié)果Tab.3 Experimental results of CorrKD on CIFAR-100
在CIFAR-100 上,也可視化了本文所提出的CorrKD 方法對于蒸餾WRN16-2 的訓練損失的變化以及測試準確率的變化.如圖3 所示,隨著訓練的回合數(shù)的增加,完整訓練損失Lo逐步減小,同時測試準確率逐漸提升.該訓練結(jié)果驗證了本文所提出的方法在訓練上的穩(wěn)定性與有效性.
圖3 完整訓練損失Lo 和測試準確率變化曲線Fig.3 Curves of overall training loss Loa nd test accuracy with respect to the epoch number
在CIFAR-100 評測數(shù)據(jù)集上并以WideResNet 為主干網(wǎng)絡(luò),將本文所提出的方法與其他經(jīng)典基于中間層的知識蒸餾方法進行對比,包括FitNet[20],AT (Attention Transfer)[21],SP (Similarity-Preserving)[28]和FT (Factor Transfer)[29].為保證公平性,上述中間層蒸餾方法都展示與傳統(tǒng)KD 相結(jié)合訓練的實驗結(jié)果,各方法所得到的結(jié)果對比如表4 所示.從實驗結(jié)果來看,本文所提出的知識蒸餾方法在WideResNet 模型結(jié)構(gòu)上,和其他中間層的知識蒸餾方法相比,取得了較好水平.例如,在學生網(wǎng)絡(luò)為WRN16-1 時,本文所提出的方法和AT 方法相比,準確率提高了0.1 百分點 (CorrKD 準確率為74.6%,AT 準確率為74.5%),同時,與教師網(wǎng)絡(luò)WRN40-2 相比,準確率降低2.2 百分點 (CorrKD 準確率為74.6%,WRN40-2 準確率為76.8% (表3)).
表4 CorrKD 與其他知識蒸餾方法在CIFAR-100 上準確率對比Tab.4 Accuracy comparison between different KD methods and CorrKD on CIFAR-100
本節(jié)主要探索部分超參數(shù)對于實驗效果的影響,主要包括相關(guān)聯(lián)算子中參數(shù)k的影響以及完整的訓練損失函數(shù)中參數(shù)α,β的影響.實驗均在CIFAR-100 上進行,教師網(wǎng)絡(luò)結(jié)構(gòu)選取WRN40-2,學生網(wǎng)絡(luò)結(jié)構(gòu)選取WRN16-2.對于3 組參數(shù)的實驗結(jié)果分別如表5 和表6 所示,“教師網(wǎng)絡(luò)→學生網(wǎng)絡(luò)”表示教師網(wǎng)絡(luò)蒸餾學生網(wǎng)絡(luò)所使用的網(wǎng)絡(luò)模型.在k相關(guān)的實驗中,固定α=0.2,β=5 ;同理,在α與β相關(guān)的實驗中,固定其他兩個參數(shù).從實驗結(jié)果看出,實驗中所選取的參數(shù)k=7,α=0.2,β=5 均為最佳參數(shù).
表5 相關(guān)聯(lián)算子參數(shù) k 實驗結(jié)果對比Tab.5 Comparison with different values of k in the correlation operation
表6 完整訓練損失 Lo 中參數(shù) α ,β 實驗結(jié)果對比Tab.6 Comparison with different values of α ,β in the overall training lossLo
本文提出了一種新的基于隱層相關(guān)聯(lián)算子的知識蒸餾方法,首次將用于光流中的相關(guān)聯(lián)算子計算操作運用到模型中間隱含層的特征提取中,相關(guān)聯(lián)算子可以對特征之間的匹配程度或變化過程進行有效建模,反映模型中間層的表征信息.同時在數(shù)據(jù)增強的作用下,進行中間層的知識遷移,結(jié)合輸出層的傳統(tǒng)知識蒸餾方法,構(gòu)成了本文所提出的全新知識蒸餾框架.實驗表明,本文所提出的知識蒸餾方法在兩種公開數(shù)據(jù)集上均取得了優(yōu)越性能,并在WideResNet 模型上取得了同類型中間層知識蒸餾方法中的最優(yōu)水平.在未來的研究中,可以考慮將該模型中間層表征知識提取方法利用到更多視覺領(lǐng)域下游任務(wù)的蒸餾中,并在多個任務(wù)上驗證本文所提出方法的壓縮效果.