林 磊,孫 涵
(南京航空航天大學(xué) 計(jì)算機(jī)科學(xué)與技術(shù)學(xué)院/人工智能學(xué)院,江蘇 南京 211106)
隨著深度學(xué)習(xí)的飛速發(fā)展,現(xiàn)實(shí)世界中也存在越來越多的域自適應(yīng)問題,這些問題大部分是由于數(shù)據(jù)標(biāo)注成本昂貴和深度模型對(duì)于不同任務(wù)的欠遷移性所導(dǎo)致的[1]。無監(jiān)督域自適應(yīng)旨在解決有標(biāo)記的訓(xùn)練樣本和無標(biāo)記的測(cè)試樣本來自于不同領(lǐng)域(分別稱為源域和目標(biāo)域)的問題。若用源域的標(biāo)記樣本進(jìn)行訓(xùn)練學(xué)習(xí)到的模型,在不進(jìn)行域自適應(yīng)的情況下直接應(yīng)用于目標(biāo)域樣本,則模型會(huì)出現(xiàn)明顯的性能下降。這種性能的降低主要是由于兩個(gè)領(lǐng)域之間存在因?yàn)閿?shù)據(jù)分布差異而導(dǎo)致的域偏移所造成。不少域自適應(yīng)方法采用偽標(biāo)簽[2]的思想,利用源域數(shù)據(jù)訓(xùn)練出的模型為目標(biāo)域生成偽標(biāo)簽。但是現(xiàn)存的偽標(biāo)簽方法存在兩個(gè)不足之處:偽標(biāo)簽由源域數(shù)據(jù)訓(xùn)練的模型生成,這會(huì)導(dǎo)致偽標(biāo)簽受限于源域數(shù)據(jù),因?yàn)橛蜻w移現(xiàn)象,該偽標(biāo)簽無法完全適配目標(biāo)域數(shù)據(jù);在訓(xùn)練的早期,網(wǎng)絡(luò)可能生成錯(cuò)誤的偽標(biāo)簽,不進(jìn)行修正繼續(xù)訓(xùn)練則會(huì)導(dǎo)致網(wǎng)絡(luò)學(xué)習(xí)的分布與目標(biāo)域分布差異越來越大。根據(jù)目前基于偽標(biāo)簽域自適應(yīng)方法的不足之處,該文從偽標(biāo)簽的選擇和更新兩個(gè)方面提出了改進(jìn)方案。
受到半監(jiān)督學(xué)習(xí)中元偽標(biāo)簽[3]的啟發(fā),提出了基于自糾錯(cuò)偽標(biāo)簽的無監(jiān)督域自適應(yīng)方法(Self-correcting Pseudo Label for unsupervised domain adaptation,SPL)。首先,在偽標(biāo)簽生成階段,主要使用源域子空間對(duì)齊和目標(biāo)域聚類對(duì)齊結(jié)合的方法選擇更優(yōu)的初始偽標(biāo)簽。然后,在訓(xùn)練過程中,利用學(xué)生教師雙網(wǎng)絡(luò)模型進(jìn)行偽標(biāo)簽的更新。具體而言,教師網(wǎng)絡(luò)使用源域和目標(biāo)域數(shù)據(jù)一起訓(xùn)練,生成最優(yōu)偽標(biāo)簽,學(xué)生網(wǎng)絡(luò)則使用目標(biāo)域數(shù)據(jù)和偽標(biāo)簽進(jìn)行有監(jiān)督訓(xùn)練。接著,根據(jù)學(xué)生網(wǎng)絡(luò)在偽標(biāo)簽?zāi)繕?biāo)域數(shù)據(jù)集上的表現(xiàn),同步優(yōu)化教師網(wǎng)絡(luò)的參數(shù),并且相應(yīng)地調(diào)整偽標(biāo)簽,以進(jìn)一步提高學(xué)生網(wǎng)絡(luò)的表現(xiàn)。在多個(gè)標(biāo)準(zhǔn)域自適應(yīng)數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果證明了該方法在域自適應(yīng)問題的有效性。
在無監(jiān)督域自適應(yīng)問題中,因?yàn)樵从虻臉?biāo)記樣本和目標(biāo)域的無標(biāo)記樣本在訓(xùn)練階段都是可用的,所以它是一個(gè)可以進(jìn)行歸納學(xué)習(xí)的過渡性學(xué)習(xí)問題。早期的方法試圖通過學(xué)習(xí)一個(gè)聯(lián)合子空間來調(diào)整源域和目標(biāo)域,從而使任何一個(gè)域的樣本都能被投射到這個(gè)共同的子空間中,然后采用不同的算法來促進(jìn)目標(biāo)域樣本在這個(gè)子空間中的可分離性[4]。然后,使用在大規(guī)模ImageNet數(shù)據(jù)集上預(yù)訓(xùn)練的深度模型提取特征的方法進(jìn)一步促進(jìn)了這些基于特征轉(zhuǎn)換的方法。隨后,梯度反轉(zhuǎn)[5]和對(duì)抗學(xué)習(xí)[6]的方法被用于深度域自適應(yīng),以學(xué)習(xí)端到端方式學(xué)習(xí)域不變特征。另一種有效的方式則是為目標(biāo)域樣本生成偽標(biāo)簽[7-8]。
盡管偽標(biāo)簽方法的性能很好,但是它們都有兩大不足之處。首先,偽標(biāo)簽主要由有標(biāo)記的源域數(shù)據(jù)訓(xùn)練的分類器生成,這樣生成的偽標(biāo)簽會(huì)過度依賴于源域數(shù)據(jù),而源域和目標(biāo)域之間存在域偏移,所以這種偽標(biāo)簽就可能攜帶噪聲信息,導(dǎo)致最后訓(xùn)練出來的網(wǎng)絡(luò)在目標(biāo)域上出現(xiàn)性能下降的情況。此外,目前偽標(biāo)簽主要分為硬偽標(biāo)簽和軟偽標(biāo)簽。硬偽標(biāo)簽的策略是為每一個(gè)未標(biāo)記的目標(biāo)域樣本都分配一個(gè)偽標(biāo)簽,然后將有偽標(biāo)簽的目標(biāo)域樣本和源域樣本一起進(jìn)行學(xué)習(xí)來改進(jìn)分類模型。這種硬偽標(biāo)簽的問題是,在訓(xùn)練初期,弱分類器誤標(biāo)的樣本可能會(huì)對(duì)后續(xù)學(xué)習(xí)過程造成嚴(yán)重傷害。弱偽標(biāo)簽則是為目標(biāo)域樣本分配每個(gè)類的條件概率,從而得到一個(gè)偽標(biāo)簽向量,并且在每次迭代訓(xùn)練的過程中都更新這個(gè)偽標(biāo)簽向量。雖然弱偽標(biāo)簽優(yōu)于硬偽標(biāo)簽,但是如果弱偽標(biāo)簽更新方法的不佳同樣會(huì)導(dǎo)致出現(xiàn)硬偽標(biāo)簽的問題。所以,針對(duì)目前偽標(biāo)簽方法存在的問題,該文提出了一種更加穩(wěn)定的可以進(jìn)行自糾錯(cuò)的偽標(biāo)簽域自適應(yīng)方法。
提出的網(wǎng)絡(luò)主要由學(xué)生網(wǎng)絡(luò)S和教師網(wǎng)絡(luò)T組成,它們對(duì)應(yīng)的參數(shù)分別為θS和θT。教師網(wǎng)絡(luò)的作用是利用有標(biāo)記源域數(shù)據(jù)和無標(biāo)記目標(biāo)域數(shù)據(jù)訓(xùn)練,然后進(jìn)行子空間對(duì)齊,將源域和目標(biāo)域數(shù)據(jù)映射到易于區(qū)分特征的子空間內(nèi)。接著,在該子空間內(nèi)進(jìn)行雙重偽標(biāo)簽生成,分別從源域和目標(biāo)域兩個(gè)角度生成偽標(biāo)簽,綜合考慮源域中可遷移知識(shí)和目標(biāo)域內(nèi)的結(jié)構(gòu)信息。然后,從候選偽標(biāo)簽中選擇最優(yōu)偽標(biāo)簽加入偽標(biāo)簽集。接著,學(xué)生網(wǎng)絡(luò)利用目標(biāo)域數(shù)據(jù)集和偽標(biāo)簽集進(jìn)行訓(xùn)練。但是,這個(gè)時(shí)候的偽標(biāo)簽集與真實(shí)的目標(biāo)域標(biāo)簽集還存在一定的差異,需要進(jìn)行自糾錯(cuò)更新。具體而言,設(shè)置一個(gè)反饋信號(hào),用于反饋學(xué)生網(wǎng)絡(luò)在偽標(biāo)簽集上的表現(xiàn),然后將這一信號(hào)傳遞給教師網(wǎng)絡(luò)更新教師網(wǎng)絡(luò)的參數(shù)。學(xué)生網(wǎng)絡(luò)和教師網(wǎng)絡(luò)是并行訓(xùn)練的:學(xué)生網(wǎng)絡(luò)從教師網(wǎng)絡(luò)生成的偽標(biāo)簽數(shù)據(jù)中進(jìn)行學(xué)習(xí);教師網(wǎng)絡(luò)從反饋信號(hào)中學(xué)習(xí)學(xué)生網(wǎng)絡(luò)在偽標(biāo)簽數(shù)據(jù)上的表現(xiàn),從而更新偽標(biāo)簽。經(jīng)過這樣的自糾錯(cuò)過程,網(wǎng)絡(luò)可以學(xué)習(xí)到越來越貼合目標(biāo)域的偽標(biāo)簽,成功實(shí)現(xiàn)從源域到目標(biāo)域的遷移。
(1)
XHXTa=φa
(2)
所以,降維矩陣A=[a1,a2,…,am1]∈m×m1可以由協(xié)方差矩陣XHXT的前k個(gè)特征值對(duì)應(yīng)的特征向量構(gòu)成。應(yīng)用主成分分析降維后的數(shù)據(jù)m1×n為:
(3)
在這個(gè)子空間S,目的是拉近源域和目標(biāo)域之間的分布,但是因?yàn)橛蚱?,想要整體匹配源域和目標(biāo)域是不現(xiàn)實(shí)的,但是源域和目標(biāo)域的標(biāo)簽空間是一致的。在一致的標(biāo)簽空間內(nèi),可以將源域和目標(biāo)域數(shù)據(jù)一起進(jìn)行類別的對(duì)齊。因?yàn)閷?duì)于同一類別的樣本,無論其來自于哪個(gè)域,在子空間S中投影應(yīng)該是相近的。所以投影矩陣B的優(yōu)化方式如下:
(4)
其中,W表示源域和目標(biāo)域樣本數(shù)據(jù)之間的相似矩陣。因?yàn)樵从驍?shù)據(jù)是有標(biāo)注的,所以這里利用標(biāo)注數(shù)據(jù)對(duì)相似矩陣進(jìn)行優(yōu)化,即同一類別的樣本在映射空間內(nèi)的距離應(yīng)該是相近的,所以它們之間的權(quán)重可以設(shè)為1。
(5)
然后,利用MMD的優(yōu)化方式,投影變換可以優(yōu)化為:
(6)
(7)
(8)
(9)
其中,d(Sy,T)表示目標(biāo)域特征st與類別為y的源域類原型特征的距離。所以,可以利用目標(biāo)域特征st與每一個(gè)類的源域原型特征之間的差異表示目標(biāo)域特征st的條件概率:
根據(jù)源域的標(biāo)簽信息,獲得了基于源域原型的條件概率ps(y|xt)。這一概率主要依賴于源域的類內(nèi)分布,但是目標(biāo)域的類內(nèi)分布和源域之間還是存在一定差異的,所以只是使用源域類別信息的條件概率是不完善的,并且忽略了目標(biāo)域自身的類別結(jié)構(gòu)。為了獲取目標(biāo)域樣本的內(nèi)在結(jié)構(gòu),該文基于目標(biāo)域類別空間生成新的偽標(biāo)簽。
具體而言,使用K-means聚類算法在所有目標(biāo)樣本的投影向量st生成|Y|個(gè)聚類族群,聚類的中心位置使用源域原型進(jìn)行初始化。假設(shè)存在一個(gè)目標(biāo)域映射和源域映射一對(duì)一的相關(guān)矩陣C∈|Y|×|Y|,并且對(duì)于任意的cij∈C存在以下關(guān)系:
(11)
其中,cij表示目標(biāo)域中的第i個(gè)族群與類別j的相關(guān)性。該相關(guān)矩陣的優(yōu)化方式如下:
(12)
(13)
借此,可以得到基于目標(biāo)域類別結(jié)構(gòu)信息的條件概率:
(14)
使用迭代學(xué)習(xí)策略,交替學(xué)習(xí)用于域?qū)R的投影矩陣P和用于目標(biāo)樣本的雙偽標(biāo)簽。盡管上述兩種偽標(biāo)簽方法中的任何一種都能夠?yàn)橄乱淮蔚械耐队皩W(xué)習(xí)提供有用的偽標(biāo)簽,但它們?cè)诒举|(zhì)上是不同的。通過最近的源域類原型進(jìn)行的偽標(biāo)簽傾向于向靠近源數(shù)據(jù)的樣本輸出高概率,而基于目標(biāo)域結(jié)構(gòu)化預(yù)測(cè)則對(duì)靠近目標(biāo)域的聚類中心的樣本具有高置信度,無論它們離源域有多遠(yuǎn)。所以,該文主張通過公式10和公式14的簡(jiǎn)單組合來利用這兩種方法的互補(bǔ)性:
p(y|xt)=max{ps(y|xt),pt(y|xt)}
(15)
(16)
在半監(jiān)督學(xué)習(xí)領(lǐng)域中,學(xué)生教師模型的方法已經(jīng)被廣泛應(yīng)用到偽標(biāo)簽生成中,但是大部分的方法都是學(xué)生和教師之間無反饋的訓(xùn)練。受到元偽標(biāo)簽的啟發(fā),該文在域自適應(yīng)學(xué)生教師模型中引入自糾錯(cuò)的偽標(biāo)簽更新方式。具體而言,存在一批源域數(shù)據(jù)xs及其標(biāo)簽ys,目標(biāo)域數(shù)據(jù)xt,對(duì)于學(xué)生網(wǎng)絡(luò)S和教師網(wǎng)絡(luò)T,可以獲得相應(yīng)的軟預(yù)測(cè)值S(xt;θS),T(xs;θT)和T(xt;θT)。其中,學(xué)生網(wǎng)絡(luò)只使用目標(biāo)域數(shù)據(jù),教師網(wǎng)絡(luò)則同時(shí)使用源域和目標(biāo)域數(shù)據(jù)。所以在有監(jiān)督訓(xùn)練中,可以使用CE(ys,S(xt;θS))作為典型的交叉熵?fù)p失。在偽標(biāo)簽訓(xùn)練過程中,往往通過最小化目標(biāo)域數(shù)據(jù)的交叉熵?fù)p失來優(yōu)化學(xué)生網(wǎng)絡(luò)的參數(shù):
(17)
(18)
?θSExt[CE(T(xt;θT),S(xt;θS))]
(19)
其中,ηS表示學(xué)習(xí)率。所以,可以獲得最終的學(xué)生網(wǎng)絡(luò)優(yōu)化目標(biāo):
?θSExt[CE(T(xt;θT),S(xt;θS))]}
(20)
(21)
LT,s=CE(ys,S(xs;θT))
(22)
(23)
主要使用三組域自適應(yīng)任務(wù)中通用的數(shù)據(jù)集對(duì)上述方法進(jìn)行了對(duì)比實(shí)驗(yàn),分別是Office-31[10]數(shù)據(jù)集、Office-Home[11]數(shù)據(jù)集以及VisDA-2017[12]數(shù)據(jù)集。為了驗(yàn)證提出的基于自糾錯(cuò)偽標(biāo)簽算法的有效性,選擇與多個(gè)無監(jiān)督域自適應(yīng)的方法相比較,這些方法中有的使用了對(duì)抗學(xué)習(xí)的方法,有的則使用了偽標(biāo)簽的思想。
該文主要使用Pytorch深度學(xué)習(xí)框架來實(shí)現(xiàn)基于自糾錯(cuò)偽標(biāo)簽的無監(jiān)督域自適應(yīng)方法。為了公平比較,每次實(shí)驗(yàn)都用相同的網(wǎng)絡(luò)結(jié)構(gòu)。利用不包含最后全連接層的,在ImageNet上進(jìn)行預(yù)訓(xùn)練的ResNet50作為特征提取器。使用了源域中所有的標(biāo)簽數(shù)據(jù)和目標(biāo)域中所有的無標(biāo)簽數(shù)據(jù),最終在目標(biāo)域數(shù)據(jù)集上比較算法的圖像分類的準(zhǔn)確率。主要使用的GPU為Nvidia Titan Xp顯卡,主要環(huán)境是在Ubuntu16.04操作系統(tǒng)下。
根據(jù)表1的實(shí)驗(yàn)結(jié)果,提出的SPL算法在六種遷移任務(wù)上的平均準(zhǔn)確率都優(yōu)于其他對(duì)比算法。SPL算法相較于不進(jìn)行域自適應(yīng)的ResNet-50,在平均準(zhǔn)確率上提高了近13.9百分點(diǎn)。相較于同樣使用聚類偽標(biāo)簽但是偽標(biāo)簽沒有更新的CAT方法也有不錯(cuò)的提升,平均準(zhǔn)確率比CAT高了2.4百分點(diǎn)。具體到每一個(gè)遷移任務(wù),SPL的方法雖然在A→W、D→W和A→D任務(wù)上稍低于REN,但是在遷移任務(wù)更加困難的D→A和W→A上,SPL的方法明顯優(yōu)于REN,分別提升3.1和1.1百分點(diǎn)。
表1 Office-31實(shí)驗(yàn)結(jié)果 %
Office-Home數(shù)據(jù)集的實(shí)驗(yàn)結(jié)果如表2所示。Office-Home數(shù)據(jù)集相較于Office-31數(shù)據(jù)集更加困難。不過基于自糾錯(cuò)偽標(biāo)簽的SPL在面臨這類困難情況時(shí),發(fā)揮出了偽標(biāo)簽自糾錯(cuò)的優(yōu)勢(shì),最終平均準(zhǔn)確率比ResNet-50高了22.6百分點(diǎn),證明該算法具有不錯(cuò)的泛化能力。對(duì)于12個(gè)不同的遷移任務(wù),SPL在多個(gè)遷移任務(wù)上都表現(xiàn)最佳,尤其是Ar→Cl、Cl→Ar和Pr→Cl任務(wù)上,相較于第二名提升了2~3百分點(diǎn)。這是因?yàn)樵谶@些難度較大的任務(wù)上,源域和目標(biāo)域之間的差異過大,之前的方法大部分都過分依賴于源域的標(biāo)簽信息,而SPL充分考慮到了目標(biāo)域的結(jié)構(gòu)信息,根據(jù)源域和目標(biāo)域的結(jié)合提取出了更優(yōu)的偽標(biāo)簽,并且在后續(xù)更新中一直優(yōu)化偽標(biāo)簽,從而遷移能力更強(qiáng)。
表2 Office-Home實(shí)驗(yàn)結(jié)果 %
續(xù)表2
對(duì)于VisDA-2017數(shù)據(jù)集,只使用了其中的合成數(shù)據(jù)集到真實(shí)數(shù)據(jù)集的遷移任務(wù)。因?yàn)閂isDA-2017數(shù)據(jù)集比較復(fù)雜,所以采用ResNet-101作為主干網(wǎng)絡(luò)。因?yàn)樵摂?shù)據(jù)集的復(fù)雜性,大部分對(duì)比算法都無法在所有類別上表現(xiàn)良好,比如MSTN雖然在類別aero和mcycl上取得了最好的效果,但是其在truck的準(zhǔn)確率只有18.5%。文中方法雖然只在類別horse和person上表現(xiàn)最佳,但是基本上在其他類別上效果也不錯(cuò),所以最終的平均準(zhǔn)確率要優(yōu)于其他的對(duì)比算法。
表3 VisDA-2017實(shí)驗(yàn)結(jié)果 %
探討了域自適應(yīng)問題中如何有效地使用偽標(biāo)簽方法,提出了基于自糾錯(cuò)偽標(biāo)簽的無監(jiān)督域自適應(yīng)方法,不僅提出了更優(yōu)的偽標(biāo)簽選擇方案,而且使用了可以自動(dòng)糾錯(cuò)的偽標(biāo)簽更新方案。在偽標(biāo)簽選擇階段,充分考慮了源域的類別信息和目標(biāo)域的內(nèi)在結(jié)構(gòu)信息,從這兩方面出發(fā)提出了雙重偽標(biāo)簽,使生成的偽標(biāo)簽避免受限于源域知識(shí),并且更符合目標(biāo)域的特征分布。鑒于單網(wǎng)絡(luò)生成的偽標(biāo)簽可能無法完全監(jiān)督該網(wǎng)絡(luò)的訓(xùn)練,使用了學(xué)生教師模型,利用教師網(wǎng)絡(luò)同時(shí)訓(xùn)練源域和目標(biāo)域特征,然后為學(xué)生網(wǎng)絡(luò)生成偽標(biāo)簽,使偽標(biāo)簽的生成與使用分離。但是只有教師到學(xué)生網(wǎng)絡(luò)的單向反饋是不足的,該文使用元偽標(biāo)簽的思想,通過學(xué)生網(wǎng)絡(luò)在利用偽標(biāo)簽進(jìn)行訓(xùn)練時(shí)的反饋,反向優(yōu)化教師網(wǎng)絡(luò),使其形成一個(gè)循環(huán)優(yōu)化的過程。提出的SPL方法在Office-31、Office-Home和VisDA-2017數(shù)據(jù)集上進(jìn)行了大量的實(shí)驗(yàn),并且與多個(gè)不同無監(jiān)督域自適應(yīng)方法進(jìn)行了比較,驗(yàn)證了SPL方法的有效性以及遷移能力。