李飛 高曉光 萬開方
?
基于動(dòng)態(tài)Gibbs采樣的RBM訓(xùn)練算法研究
李飛1高曉光1萬開方1
目前大部分受限玻爾茲曼機(jī)(Restricted Boltzmann machines,RBMs)訓(xùn)練算法都是以多步Gibbs采樣為基礎(chǔ)的采樣算法.本文針對(duì)多步Gibbs采樣過程中出現(xiàn)的采樣發(fā)散和訓(xùn)練速度過慢的問題,首先,對(duì)問題進(jìn)行實(shí)驗(yàn)描述,給出了問題的具體形式;然后,從馬爾科夫采樣的角度對(duì)多步Gibbs采樣的收斂性質(zhì)進(jìn)行了理論分析,證明了多步Gibbs采樣在受限玻爾茲曼機(jī)訓(xùn)練初期較差的收斂性質(zhì)是造成采樣發(fā)散和訓(xùn)練速度過慢的主要原因;最后,提出了動(dòng)態(tài)Gibbs采樣算法,給出了對(duì)比仿真實(shí)驗(yàn).實(shí)驗(yàn)結(jié)果表明,動(dòng)態(tài)Gibbs采樣算法可以有效地克服采樣發(fā)散的問題,并且能夠以微小的運(yùn)行時(shí)間為代價(jià)獲得更高的訓(xùn)練精度.
受限玻爾茲曼機(jī),Gibbs采樣,采樣算法,馬爾科夫理論
引用格式李飛,高曉光,萬開方.基于動(dòng)態(tài)Gibbs采樣的RBM訓(xùn)練算法研究.自動(dòng)化學(xué)報(bào),2016,42(6):931-942
自2006年Hinton等[1]提出第一個(gè)深度置信網(wǎng)絡(luò)開始,經(jīng)過十年的發(fā)展,深度學(xué)習(xí)已逐漸成為機(jī)器學(xué)習(xí)研究領(lǐng)域的前沿?zé)狳c(diǎn).深度置信網(wǎng)絡(luò)[2]、深度卷積神經(jīng)網(wǎng)絡(luò)[3]、深度自動(dòng)編碼器[4]等深度網(wǎng)絡(luò)也廣泛應(yīng)用于機(jī)器學(xué)習(xí)的各個(gè)領(lǐng)域,如圖像識(shí)別、語(yǔ)音分析、文本分析等[5-7].相對(duì)于傳統(tǒng)的機(jī)器學(xué)習(xí)網(wǎng)絡(luò),深度網(wǎng)絡(luò)取得了更好的效果,極大地推動(dòng)了技術(shù)發(fā)展水平(State-of-the-art)[8].尤其在大數(shù)據(jù)背景下,針對(duì)海量無標(biāo)簽數(shù)據(jù)的學(xué)習(xí),深度網(wǎng)絡(luò)具有明顯的優(yōu)勢(shì)[9].
受限玻爾茲曼機(jī)(Restricted Boltzmann ma-chine,RBM)[10]是深度學(xué)習(xí)領(lǐng)域中的一個(gè)重要模型,也是構(gòu)成諸多深度網(wǎng)絡(luò)的基本單元之一.由于RBM較難訓(xùn)練,所以在很多大數(shù)據(jù)量任務(wù)上使用較少.但相對(duì)于其他基本模型,RBM具備較強(qiáng)的理論分析優(yōu)勢(shì)和可解釋性,是幫助我們理解深度網(wǎng)絡(luò)和其他基本模型內(nèi)在機(jī)理的重要模型,而且在某些特殊數(shù)據(jù)集上,RBM可以獲得更好的學(xué)習(xí)效果.所以,研究RBM仍然很有意義.RBM具有兩層結(jié)構(gòu),在無監(jiān)督學(xué)習(xí)下,隱層單元可以對(duì)輸入層單元進(jìn)行抽象,提取輸入層數(shù)據(jù)的抽象特征.當(dāng)多個(gè)RBM或RBM與其他基本單元以堆棧的方式構(gòu)成深度網(wǎng)絡(luò)時(shí),RBM隱層單元提取到的抽象特征可以作為其他單元的輸入,繼續(xù)進(jìn)行特征提取.通過這種方式,深度網(wǎng)絡(luò)可以提取到抽象度非常高的數(shù)據(jù)特征.當(dāng)采用逐層貪婪(Greedy layer-wise)[11]訓(xùn)練方法對(duì)深度網(wǎng)絡(luò)進(jìn)行訓(xùn)練時(shí),各個(gè)基本單元是逐一被訓(xùn)練的.因此,RBM訓(xùn)練的優(yōu)劣將直接影響整個(gè)深度網(wǎng)絡(luò)的性能.
2006年,Hinton等提出了對(duì)比散度 (Contrastive divergence,CD)算法[12]用以訓(xùn)練RBM網(wǎng)絡(luò).在每次訓(xùn)練迭代時(shí),CD算法以數(shù)據(jù)樣本為初始值,通過多步Gibbs迭代獲得目標(biāo)分布的近似采樣,然后通過該近似采樣來近似目標(biāo)梯度,取得了較好的效果,是目前RBM訓(xùn)練的標(biāo)準(zhǔn)算法.但研究表明,CD算法對(duì)目標(biāo)梯度的估計(jì)是有偏估計(jì)[13],而且每次迭代時(shí)都需要重新啟動(dòng)Gibbs采樣鏈,這降低了CD算法的訓(xùn)練性能.為此,Tieleman等以CD算法為基礎(chǔ),于2008年提出了持續(xù)對(duì)比散度(Persistent contrastive divergence,PCD)算法[14].在學(xué)習(xí)率足夠小的前提下,每次參數(shù)更新后,RBM模型的變化不大,可以認(rèn)為RBM網(wǎng)絡(luò)分布基本不變.基于此假設(shè),PCD算法只運(yùn)行一條獨(dú)立的采樣鏈,以上次采樣迭代的采樣值作為下次采樣迭代的初值繼續(xù)迭代,而不是像CD算法那樣每次采樣都以樣本數(shù)據(jù)為采樣初值,取得了比CD算法更好的訓(xùn)練效果.為了加速PCD算法,Tieleman又于2009年提出了加速持續(xù)對(duì)比散度(Fast persistent contrastive divergence,F(xiàn)PCD)算法[15],引入了額外的加速參數(shù)來提高訓(xùn)練速度.PCD算法和FPCD算法雖然訓(xùn)練性能較CD算法有所提高,但并沒有從本質(zhì)上提高CD算法的混合率[16].不管是CD算法,還是以CD算法為基礎(chǔ)的PCD算法、FPCD算法,都是通過一條Gibbs采樣鏈來逼近目標(biāo)分布,對(duì)于目標(biāo)分布較簡(jiǎn)單的數(shù)據(jù),可以取得較好的效果.但當(dāng)數(shù)據(jù)分布復(fù)雜,尤其為多模分布時(shí),即目標(biāo)分布函數(shù)存在多個(gè)峰值,Gibbs采樣鏈很容易陷入局部極小域,導(dǎo)致樣本不能描述數(shù)據(jù)分布的整體結(jié)構(gòu)[17].為克服這個(gè)問題,Desjardins(2010)等[18]、Cho(2010)等[19]、Brakel(2012)等[20]等分別提出應(yīng)用并行回火算法(Parallel tempering,PT)來訓(xùn)練RBM.PT算法并行化多條溫度鏈,每條溫度鏈上進(jìn)行多步Gibbs迭代.高溫鏈采樣目標(biāo)總體分布的結(jié)構(gòu)信息,低溫鏈采樣目標(biāo)局部分布的精確信息.不同溫度鏈之間以一定的交換概率進(jìn)行交換,不斷迭代,最后低溫鏈就可以精確獲得目標(biāo)分布的總體信息.對(duì)于多模分布數(shù)據(jù),PT算法的訓(xùn)練效果要明顯優(yōu)于CD算法[21].
通過以上描述可知,不管是CD算法還是PT算法,本質(zhì)上都是以Gibbs采樣來獲得關(guān)于目標(biāo)分布的采樣樣本.因此,Gibbs采樣性能的優(yōu)劣將直接影響以上算法的訓(xùn)練效果.本文研究發(fā)現(xiàn),當(dāng)采用多步Gibbs采樣時(shí),在訓(xùn)練初期會(huì)發(fā)生采樣發(fā)散現(xiàn)象,嚴(yán)重影響網(wǎng)絡(luò)收斂速度,而且算法運(yùn)行速度較慢;當(dāng)采用單步Gibbs采樣時(shí),前期網(wǎng)絡(luò)收斂性質(zhì)較好,且算法運(yùn)行速度較快,但后期采樣精度不高.如何在前期保證良好的收斂性質(zhì),同時(shí)在后期保證網(wǎng)絡(luò)訓(xùn)練精度并提高算法運(yùn)行速度,是目前基于Gibbs采樣的RBM訓(xùn)練算法亟需解決的問題,從現(xiàn)有文獻(xiàn)來看,尚無人對(duì)以上問題進(jìn)行研究.因此,本文將從馬爾科夫采樣理論的角度對(duì)以上問題進(jìn)行分析,并提出了動(dòng)態(tài)Gibbs采樣算法,最后給出了仿真驗(yàn)證.
受限玻爾茲曼機(jī)是一個(gè)馬爾科夫隨機(jī)場(chǎng)模型[22],它具有兩層結(jié)構(gòu),如圖1所示.下層為輸入層,包含m個(gè)輸入單元vi,用來表示輸入數(shù)據(jù),每個(gè)輸入單元包含一個(gè)實(shí)值偏置量ai;上層為隱層,包含n個(gè)隱層單元hj,表示受限玻爾茲曼機(jī)提取到的輸入數(shù)據(jù)的特征,每個(gè)隱層單元包含一個(gè)實(shí)值偏置bj.受限玻爾茲曼機(jī)具有層內(nèi)無連接,層間全連接的特點(diǎn).即同層內(nèi)各節(jié)點(diǎn)之間沒有連線,每個(gè)節(jié)點(diǎn)與相鄰層所有節(jié)點(diǎn)全連接,連線上有實(shí)值權(quán)重矩陣wij.這一性質(zhì)保證了各層之間的條件獨(dú)立性.
圖1 RBM結(jié)構(gòu)Fig.1 Configuration of RBM
本文研究二值受限玻爾茲曼機(jī)[23],即隨機(jī)變量(V,H)取值(v,h)∈{0,1}.由二值受限玻爾茲曼機(jī)定義的聯(lián)合分布滿足Gibbs分布,其中θ為網(wǎng)絡(luò)參數(shù)Eθ(v,h)為網(wǎng)絡(luò)的能量函數(shù):
Zθ為配分函數(shù):.輸入層節(jié)點(diǎn)v的概率分布P(v)為:.由受限玻爾茲曼機(jī)各層之間的條件獨(dú)立性可知,當(dāng)給定輸入層數(shù)據(jù)時(shí),輸出層節(jié)點(diǎn)取值滿足如下條件概率:相應(yīng)地,當(dāng)輸出層數(shù)據(jù)確定后,輸入層節(jié)點(diǎn)取值的條件概率為
給定一組訓(xùn)練樣本S= {v1,v2,···,vn},訓(xùn)練RBM 意味著調(diào)整參數(shù)θ,以擬合給定的訓(xùn)練樣本,使得該參數(shù)下由相應(yīng)RBM表示的概率分布盡可能地與訓(xùn)練數(shù)據(jù)的經(jīng)驗(yàn)分布相符合.本文應(yīng)用最大似然估計(jì)的方法對(duì)網(wǎng)絡(luò)參數(shù)進(jìn)行估計(jì).這樣,訓(xùn)練RBM的目標(biāo)就是最大化網(wǎng)絡(luò)的似然函數(shù):.為簡(jiǎn)化計(jì)算,將其改寫為對(duì)數(shù)形式:.進(jìn)一步推導(dǎo)對(duì)數(shù)似然函數(shù)的參數(shù)梯度
得到對(duì)數(shù)似然函數(shù)的參數(shù)梯度后,可以由梯度上升法求解其最大值.但由于數(shù)據(jù)分布P(v)未知,且包含配分函數(shù)Zθ,因此,無法給出梯度的解析解.現(xiàn)有訓(xùn)練算法主要是基于采樣的方法,首先,構(gòu)造以P(v)為平穩(wěn)分布的馬爾科夫鏈,獲得滿足P(v)分布的樣本;然后,通過蒙特卡洛迭代來近似梯度:
步驟1.設(shè)定網(wǎng)絡(luò)參數(shù)初值.
步驟2.將訓(xùn)練數(shù)據(jù)輸入到輸入層節(jié)點(diǎn),由式(2)對(duì)隱層節(jié)點(diǎn)值進(jìn)行采樣,
步驟3.根據(jù)式(3)對(duì)輸入層節(jié)點(diǎn)進(jìn)行采樣.再以此采樣值作為輸入層節(jié)點(diǎn)的值重復(fù)步驟2,這樣就完成了一步Gibbs采樣.
步驟4.步驟2和步驟3重復(fù)k次,完成k步
步驟5.將步驟4獲得的采樣值帶入式(5)中,計(jì)算參數(shù)梯度.
步驟6.將步驟5中獲得的參數(shù)梯度帶入式(6)中,對(duì)參數(shù)進(jìn)行更新.
步驟7.更新訓(xùn)練數(shù)據(jù),重復(fù)步驟2~6,直到達(dá)到額定迭代次數(shù).
相應(yīng)的偽代碼如算法1所示:
其中,a為可見層偏置向量,b為隱層偏置向量,w為網(wǎng)絡(luò)權(quán)值矩陣,η為學(xué)習(xí)率.
1.1問題實(shí)驗(yàn)描述
1)實(shí)驗(yàn)設(shè)計(jì)
本文采用的數(shù)據(jù)集是MNIST數(shù)據(jù)集,它是二值手寫數(shù)據(jù)集,也是目前訓(xùn)練RBM網(wǎng)絡(luò)的標(biāo)準(zhǔn)數(shù)據(jù)集.它總共包含60000個(gè)訓(xùn)練樣本和10000個(gè)測(cè)試樣本,每個(gè)樣本是一幅28像素×28像素的灰度圖.所采用的RBM網(wǎng)絡(luò)有784×500個(gè)節(jié)點(diǎn),輸入層有784個(gè)可見單元,對(duì)應(yīng)灰度圖的784個(gè)像素點(diǎn);輸出層有500個(gè)隱層節(jié)點(diǎn),這是目前實(shí)驗(yàn)顯示的訓(xùn)練效果較好的隱層節(jié)點(diǎn)數(shù)目.具體的網(wǎng)絡(luò)參數(shù)初始值設(shè)定如表1.
表1 網(wǎng)絡(luò)參數(shù)初值Table 1 Initial value of parameters
本文設(shè)計(jì)了6組對(duì)比實(shí)驗(yàn),用60000個(gè)訓(xùn)練樣本對(duì)RBM進(jìn)行訓(xùn)練,分別迭代1000次,如表2所示.其中CD_k表示進(jìn)行k步Gibbs迭代.用于顯示的樣本數(shù)據(jù)的原始圖片如圖2所示.實(shí)驗(yàn)結(jié)束后,我們比較了各組實(shí)驗(yàn)的重構(gòu)誤差,并給出了最終的誤差圖.
表2 實(shí)驗(yàn)分組Table 2 Experimental grouping
圖2 原始數(shù)據(jù)灰度圖Fig.2 Gray image of initial data
2)仿真結(jié)果圖3表示整個(gè)迭代過程中各組CD算法的重構(gòu)誤差圖,圖4給出了各組實(shí)驗(yàn)的訓(xùn)練時(shí)間,圖5~圖10分別給出了各組實(shí)驗(yàn)的采樣灰度圖.
圖3 重構(gòu)誤差圖Fig.3 Reconstruction error diagram
圖4 運(yùn)行時(shí)間圖Fig.4 Runtime diagram
圖5 CD_1采樣灰度圖Fig.5 Gray image of CD_1 sampling
圖6 CD_5采樣灰度圖Fig.6 Gray image of CD_5 sampling
圖7 CD_10采樣灰度圖Fig.7 Gray image of CD_10 sampling
圖8 CD_100采樣灰度圖Fig.8 Gray image of CD_100 sampling
圖9 CD_500采樣灰度圖Fig.9 Gray image of CD_500 sampling
圖10 CD_1000采樣灰度圖Fig.10 Gray image of CD_1000 sampling
1.2問題歸納描述
上節(jié)實(shí)驗(yàn)給出了CD算法在不同Gibbs采樣步數(shù)下的仿真圖,可以看出,當(dāng)RBM網(wǎng)絡(luò)采用多步Gibbs算法進(jìn)行采樣迭代時(shí),會(huì)出現(xiàn)如下問題:
問題1.訓(xùn)練初始階段,得到的每幅重構(gòu)采樣圖幾乎完全相同.
如圖11、圖12所示,在訓(xùn)練初始階段,多步Gibbs采樣出現(xiàn)了各組采樣數(shù)據(jù)同分布的現(xiàn)象,這表明各組樣本幾乎完全相同,這與事實(shí)相左.在訓(xùn)練初期,大約0~100次迭代期間,這種現(xiàn)象持續(xù)存在.
圖11 CD_500采樣灰度圖Fig.11 Gray image of CD_500 sampling
圖12 CD_1000采樣灰度圖Fig.12 Gray image of CD_1000 sampling
問題2.采樣誤差分布集中,在批量訓(xùn)練時(shí),存在全0全1現(xiàn)象.
如圖13、圖14所示,當(dāng)進(jìn)行多步Gibbs采樣時(shí),出現(xiàn)了誤差分布集中的現(xiàn)象:有些樣本采樣幾乎全為1,而其他的樣本采樣幾乎全為0.由仿真實(shí)驗(yàn)可知,在0~100次迭代期間,這種現(xiàn)象在迭代初期持續(xù)存在.
圖13 CD_500采樣灰度圖Fig.13 Gray image of CD_500 sampling
問題3.一步Gibbs采樣初始誤差小,訓(xùn)練速度快,但后期訓(xùn)練精度低;多步Gibbs采樣初始誤差大,訓(xùn)練速度慢,但后期訓(xùn)練精高.
如圖15、圖16所示,只進(jìn)行一步Gibbs采樣的CD_1算法在開始時(shí)訓(xùn)練誤差較小,很快便收斂到較好值,但訓(xùn)練后期精度不如CD_10等進(jìn)行多步Gibbs迭代的CD算法;進(jìn)行多步Gibbs采樣的CD_k迭代算法,在訓(xùn)練初期誤差較大,且不斷振蕩,而且訓(xùn)練時(shí)間較慢,但到訓(xùn)練后期,它們可以達(dá)到極高的精度.
圖14 CD_1000采樣灰度圖Fig.14 Gray image of CD_1000 sampling
圖15 采樣誤差局部放大圖Fig.15 Local enlarged drawing of reconstruction error in initial phase
圖16 采樣誤差局部放大圖Fig.16 Local enlarged drawing of reconstruction error in later stage
以上實(shí)驗(yàn)表明,CD算法雖然對(duì)RBM具有良好的訓(xùn)練能力,但Gibbs采樣的步數(shù)對(duì)訓(xùn)練性能造成了明顯的影響.我們將在下節(jié)研究這種影響,并對(duì)以上問題給出理論分析.
Gibbs采樣是馬爾科夫鏈蒙特卡洛(Markov chain Monte Carlo,MCMC)采樣算法的一種.在RBM訓(xùn)練中,它的轉(zhuǎn)移核是Sigmoid函數(shù).隱層節(jié)點(diǎn)和輸入層節(jié)點(diǎn)交替采樣,公式如下:
由馬爾科夫鏈?zhǔn)諗慷ɡ砜芍?,?dāng)n→+∞時(shí),Gibbs采樣鏈會(huì)收斂到平衡分布,即:
其中,π(x)為樣本x的平衡分布.同時(shí),由細(xì)致平衡準(zhǔn)則可得:
即Gibbs采樣的平穩(wěn)分布與迭代初始值無關(guān),只與轉(zhuǎn)移概率有關(guān).由上面給出的RBM交替采樣概率公式可知,當(dāng)用Gibbs采樣對(duì)RBM進(jìn)行采樣訓(xùn)練時(shí),其平穩(wěn)分布是網(wǎng)絡(luò)參數(shù)的函數(shù):
從這個(gè)角度講,訓(xùn)練RBM的目的就是調(diào)節(jié)網(wǎng)絡(luò)參數(shù),使由網(wǎng)絡(luò)參數(shù)確定的平穩(wěn)分布等于樣本的真實(shí)分布.
基于以上描述,下面對(duì)第2節(jié)中提出的問題給出理論解釋.
問題1.訓(xùn)練初始階段,得到的每幅重構(gòu)采樣圖幾乎完全相同.
初始時(shí)刻,網(wǎng)絡(luò)參數(shù)初值相同,在早期迭代過程中,網(wǎng)絡(luò)參數(shù)值的變動(dòng)也不大,滿足如下公式:
ε為一極小正值.由網(wǎng)絡(luò)參數(shù)決定的平穩(wěn)分布也近乎相同:
即各樣本的平穩(wěn)分布相等.因此,當(dāng)進(jìn)行多步Gibbs采樣時(shí),各訓(xùn)練樣本的采樣樣本逐漸收斂到相同的平穩(wěn)分布,這時(shí)就出現(xiàn)了問題1描述的現(xiàn)象,各樣本的重構(gòu)采樣圖幾乎完全相同.
問題2.采樣誤差分布集中,在批量訓(xùn)練時(shí),存在全0全1現(xiàn)象.
由上一部分分析可知,在訓(xùn)練初期,網(wǎng)絡(luò)參數(shù)改變不大,由RBM參數(shù)決定的平衡分布幾乎同構(gòu),即各采樣概率收斂到相同平衡分布值.上述對(duì)比實(shí)驗(yàn)中,網(wǎng)絡(luò)參數(shù)的初始值為θ=(a,b,w)=(0,0,0.1),此時(shí)網(wǎng)絡(luò)平衡分布收斂在0.5附近,樣本數(shù)據(jù)的收斂概率將在0.5附近浮動(dòng),即一部分樣本的采樣概率略小于0.5,另一部分樣本的采樣概率略大于0.5,即:
其中,ε為一極小正值.這時(shí)基于隨機(jī)數(shù)對(duì)樣本進(jìn)行采樣,一部分樣本的采樣值將全為0,另一部分的采樣值將全為1,即全0全1現(xiàn)象.
問題3.一步Gibbs采樣初始誤差小,訓(xùn)練速度快,但后期訓(xùn)練精度低;多步Gibbs采樣初始誤差大,訓(xùn)練速度慢,但后期訓(xùn)練精高.
在網(wǎng)絡(luò)訓(xùn)練早期,網(wǎng)絡(luò)參數(shù)差較大,由網(wǎng)絡(luò)參數(shù)定義的平穩(wěn)分布與真實(shí)分布相差也較大,即Δπ=.此時(shí),如果對(duì)樣本進(jìn)行多步迭代采樣,采樣樣本將偏離真實(shí)分布,從而不能收斂到真實(shí)分布,而是收斂到與真實(shí)分布相差較大的其他分布.因此,在迭代初期,CD_1000、CD_500等算法的采樣誤差非常大,而且運(yùn)行時(shí)間較長(zhǎng).而CD_1算法由于只進(jìn)行了一次采樣迭代,不僅運(yùn)行速度加快,而且由于采樣樣本的分布沒有偏離真實(shí)分布太多,使得這時(shí)候的CD_1算法的采樣誤差非常小.由實(shí)驗(yàn)可知,此時(shí)采樣誤差的大小關(guān)系為:CD_1< CD_5<CD_10<CD_100<CD_500<CD_1000.到了網(wǎng)絡(luò)訓(xùn)練后期,由于網(wǎng)絡(luò)參數(shù)差非常小,網(wǎng)絡(luò)參數(shù)的實(shí)際值已經(jīng)非常接近真實(shí)值,這時(shí)候進(jìn)行多步Gibbs迭代能很好地逼近樣本真實(shí)分布,所以這一階段,CD_k算法的采樣精度要比CD_1高.但由于網(wǎng)絡(luò)參數(shù)差一直存在,所以,Gibbs迭代步數(shù)也不宜過高,如實(shí)驗(yàn)所示,CD_1000在采樣到最后,采樣誤差仍高于CD_10.
在現(xiàn)有以Gibbs采樣為基礎(chǔ)的RBM訓(xùn)練算法中,Gibbs采樣的采樣步數(shù)多為固定值,即在整個(gè)訓(xùn)練過程中,每次迭代采樣時(shí)都進(jìn)行固定步數(shù)的Gibbs采樣,這樣就難以兼顧訓(xùn)練精度和訓(xùn)練速度這兩個(gè)訓(xùn)練指標(biāo).當(dāng)進(jìn)行多步Gibbs采樣時(shí),容易在訓(xùn)練前期發(fā)生誤差發(fā)散的現(xiàn)象,且算法運(yùn)行時(shí)間較長(zhǎng);一步Gibbs采樣算法運(yùn)行較快,但后期訓(xùn)練精度不高,基于此,本文提出了動(dòng)態(tài)Gibbs采樣(Dynamic Gibbs sampling,DGS)算法.
定義1.動(dòng)態(tài)Gibbs采樣是指在迭代訓(xùn)練過程中的不同階段,根據(jù)網(wǎng)絡(luò)的訓(xùn)練誤差,動(dòng)態(tài)地調(diào)整Gibbs采樣的步數(shù),以達(dá)到最優(yōu)訓(xùn)練效果.
通過上節(jié)分析可知,在網(wǎng)絡(luò)訓(xùn)練初期,網(wǎng)絡(luò)參數(shù)幾乎相等,各樣本的平穩(wěn)分布也近乎相等,而且網(wǎng)絡(luò)參數(shù)差較大,樣本的平穩(wěn)分布與真實(shí)分布相差也較大,因此,這一階段應(yīng)盡量減少采樣次數(shù),克服多步Gibbs采樣引起的誤差發(fā)散,提高訓(xùn)練速度,使網(wǎng)絡(luò)參數(shù)盡快逼近真實(shí)值;當(dāng)網(wǎng)絡(luò)參數(shù)逼近真實(shí)值時(shí),此時(shí)應(yīng)加大采樣迭代次數(shù),提高訓(xùn)練精度.
基于以上定義和描述,DGS算法的操作步驟如下:
步驟1.設(shè)定網(wǎng)絡(luò)參數(shù)初值和動(dòng)態(tài)策略M.
步驟2.在1~m1迭代范圍內(nèi),設(shè)置Gibbs采樣步數(shù)k1=Gibbs_N1.
步驟3.將訓(xùn)練數(shù)據(jù)輸入到輸入層節(jié)點(diǎn),由式(2)對(duì)隱層節(jié)點(diǎn)值進(jìn)行采樣.
步驟4.根據(jù)式(3)對(duì)輸入層節(jié)點(diǎn)進(jìn)行采樣.再以此采樣值作為輸入層節(jié)點(diǎn)的值重復(fù)步驟3,這樣就完成了一步Gibbs采樣.
步驟5.步驟3和步驟4重復(fù)k1次,完成k1步Gibbs采樣.
步驟6.將步驟5獲得的采樣值帶入式(5)中,計(jì)算參數(shù)梯度.
步驟7.將步驟6中獲得的參數(shù)梯度帶入式(6)中,對(duì)參數(shù)進(jìn)行更新.
步驟8.更新訓(xùn)練數(shù)據(jù),重復(fù)步驟3到步驟7,直到迭代次數(shù)達(dá)到m1.
步驟9.在m1~m2迭代范圍內(nèi),設(shè)置Gibbs采樣步數(shù)k2=Gibbs_N2.
步驟10.重復(fù)步驟3到步驟8,直到迭代次數(shù)達(dá)到m2.
步驟11.在m2~I(xiàn)ter迭代范圍內(nèi),設(shè)置Gibbs采樣步數(shù)k3=Gibbs_N3.
步驟12.重復(fù)步驟3到步驟8,直到迭代次數(shù)達(dá)到最大迭代次數(shù)Iter.
相應(yīng)的偽代碼如算法2所示.
算法2.DGS算法偽代碼
其中,M=(m1,m2)為動(dòng)態(tài)策略,且滿足m2>m1.Iter為總的迭代次數(shù),iter為當(dāng)前迭代次數(shù). Gibbs_Ni為Gibbs采樣,Ni表示采樣次數(shù),且滿足Nn>Nn-1.其中Gibbs采樣次數(shù)N與網(wǎng)絡(luò)訓(xùn)練迭代次數(shù)M 之間的大致關(guān)系如下:
本節(jié)設(shè)計(jì)了7組對(duì)比實(shí)驗(yàn),第1~6組實(shí)驗(yàn)采用固定Gibbs采樣步數(shù)的CD_k算法進(jìn)行訓(xùn)練仿真,第6組實(shí)驗(yàn)用DGS算法對(duì)網(wǎng)絡(luò)進(jìn)行訓(xùn)練仿真,如表3所示.兩組實(shí)驗(yàn)使用相同的數(shù)據(jù)集MNIST,網(wǎng)絡(luò)結(jié)構(gòu)相同,網(wǎng)絡(luò)參數(shù)初始值相同,如表4所示.本文設(shè)計(jì)的動(dòng)態(tài)采樣策略如表5所示.下面給出仿真實(shí)驗(yàn)結(jié)果和分析.
表3 實(shí)驗(yàn)分組Table 3 Experimental grouping
表4 網(wǎng)絡(luò)參數(shù)初值Table 4 Initial values of parameters
表5 DGS迭代策略Table 5 Iterative strategy of DGS
4.1重構(gòu)誤差對(duì)比分析
圖17給出了所有算法的重構(gòu)誤差對(duì)比圖.對(duì)比結(jié)果顯示,本文設(shè)計(jì)的DGS算法可以很好地訓(xùn)練RBM網(wǎng)絡(luò),從而證明了本文算法的有效性.
在迭代初期,DGS算法只進(jìn)行一次Gibbs采樣迭代,避免了采樣發(fā)散,從而迅速收斂到較好的值,由誤差對(duì)比圖初始階段的局部放大圖(圖18)可以看出,此時(shí)誤差滿足:
在迭代后期,網(wǎng)絡(luò)參數(shù)值已非常接近真實(shí)值,此時(shí)DGS逐步增大了Gibbs采樣的迭代步數(shù),獲得了采樣精度更高的目標(biāo)樣本,最終獲得了更高的訓(xùn)練精度,即:
如圖19所示.
圖17 重構(gòu)誤差對(duì)比圖Fig.17 Contrast of reconstruction error
圖18 訓(xùn)練初期局部放大圖Fig.18 Local enlarged drawing of reconstruction error in initial phase
圖19 訓(xùn)練后期局部放大圖Fig.19 Local enlarged drawing of reconstruction error in later stage
4.2運(yùn)行時(shí)間對(duì)比分析
圖20給出了所有算法的運(yùn)行時(shí)間對(duì)比圖.從圖中可以看出,在整個(gè)訓(xùn)練過程中,DGS算法、CD_1算法、CD_5算法和CD_10算法的運(yùn)行速度都明顯比其他算法快.因此,下面根據(jù)本文設(shè)計(jì)的動(dòng)態(tài)策略,對(duì)各個(gè)迭代區(qū)間內(nèi)這4種算法的運(yùn)行速度進(jìn)行分析:
圖20 運(yùn)行時(shí)間對(duì)比圖Fig.20 Contrast of runtime
在1~300迭代范圍內(nèi),DGS算法的Gibbs采樣步數(shù)k設(shè)為1,與CD_1算法相同.所以,此時(shí)的DGS算法的運(yùn)行速度與CD_1相同,且快于其他兩種算法,如圖21所示.
圖21 運(yùn)行時(shí)間對(duì)比圖Fig.21 Contrast of runtime
在300~900迭代范圍內(nèi),DGS算法的Gibbs采樣步數(shù)k設(shè)為5.由圖22可以看出,此時(shí)DGS算法的運(yùn)行速度逐漸放緩,運(yùn)行時(shí)間明顯上升,逐漸大于CD_1算法.
在900~1000迭代范圍內(nèi),DGS算法的Gibbs采樣步數(shù)k設(shè)為10.所以,這個(gè)時(shí)期的DGS運(yùn)行時(shí)間持續(xù)放緩.但從圖23中可以看出,即便到了訓(xùn)練后期,DGS算法的運(yùn)行時(shí)間仍然小于CD_5算法和其他CD_k(k>5)算法.這說明,DGS算法在后期提高訓(xùn)練精度的同時(shí),只付出了微小的時(shí)間代價(jià).
圖22 運(yùn)行時(shí)間對(duì)比圖Fig.22 Contrast of runtime
圖23 運(yùn)行時(shí)間對(duì)比圖Fig.23 Contrast of runtime
4.3采樣效果圖
圖24~圖28分別給出了DGS算法在不同迭代次數(shù)下的采樣重構(gòu)圖.對(duì)比圖11、圖12,可以看出,DGS在訓(xùn)練迭代50次以內(nèi)就可以很好地重構(gòu)輸入樣本,而且沒有出現(xiàn)全0全1現(xiàn)象和采樣圖同構(gòu)現(xiàn)象,從而克服了第2.2節(jié)問題1和問題2中描述的問題.
圖24 DGS迭代10次采樣灰度圖Fig.24 Gray image of DGS by 10 iterations
圖25 DGS迭代20次采樣灰度圖Fig.25 Gray image of DGS by 20 iterations
圖26 DGS迭代30次采樣灰度圖Fig.26 Gray image of DGS by 30 iterations
圖27 DGS迭代40次采樣灰度圖Fig.27 Gray image of DGS by 40 iterations
圖28 DGS迭代50次采樣灰度圖Fig.28 Gray image of DGS by 50 iterations
圖29顯示了DGS訓(xùn)練結(jié)束后的重構(gòu)灰度圖,圖中幾乎沒有噪點(diǎn).可見,采用DGS算法訓(xùn)練網(wǎng)絡(luò)可以獲得更高的訓(xùn)練精度,從而解決了第2.2節(jié)中問題3描述的問題.
圖29 DGS重構(gòu)灰度圖Fig.29 Gray image of DGS
綜上所述,本文設(shè)計(jì)的DGS算法在訓(xùn)練初期克服了多步Gibbs采樣發(fā)散的缺點(diǎn),在訓(xùn)練后期獲得更高的精度,而且在保證收斂精度的情況下大幅度提高了訓(xùn)練速度,獲得了較好的效果.
本文首先通過仿真實(shí)驗(yàn),給出了現(xiàn)有基于Gibbs采樣的RBM訓(xùn)練算法在訓(xùn)練初期誤差發(fā)散和后期訓(xùn)練精度不高等問題的具體描述,然后從馬爾科夫采樣理論的角度對(duì)Gibbs采樣誤差進(jìn)行理論分析.證明在RBM網(wǎng)絡(luò)下,多步Gibbs采樣較差的收斂性質(zhì)是導(dǎo)致前期采樣發(fā)散和算法運(yùn)行速度較低的主要原因;單步Gibbs采樣是造成后期訓(xùn)練精度不高的主要原因.基于此,本文提出了動(dòng)態(tài)Gibbs采樣算法,并給出了驗(yàn)證實(shí)驗(yàn).實(shí)驗(yàn)表明,本文提出的動(dòng)態(tài)Gibbs采樣算法在訓(xùn)練初期克服了多步Gibbs采樣引起的誤差發(fā)散,后期克服了單步Gibbs采樣帶來的訓(xùn)練精度低的問題,同時(shí)提高了訓(xùn)練速度,以上特點(diǎn)可以彌補(bǔ)現(xiàn)有以Gibbs采樣為基礎(chǔ)的RBM訓(xùn)練算法的不足.
關(guān)于Gibbs采樣步數(shù)、訓(xùn)練迭代次數(shù)與訓(xùn)練精度之間的關(guān)系,本文在理論分析部分只給出了定性分析;在動(dòng)態(tài)Gibbs采樣算法設(shè)計(jì)階段,本文只是根據(jù)實(shí)驗(yàn)分析,給出Gibbs采樣步數(shù)和訓(xùn)練迭代次數(shù)之間的經(jīng)驗(yàn)區(qū)間.Gibbs采樣步數(shù)、訓(xùn)練迭代次數(shù)以及網(wǎng)絡(luò)訓(xùn)練精度之間是否存在精確的數(shù)學(xué)關(guān)系,如果存在,其數(shù)學(xué)模型如何構(gòu)建.以上問題仍有待進(jìn)一步研究.
References
1 Hinton G E,Salakhutdinov R R.Reducing the dimensionality of data with neural networks.Science,2006,313(5786):504-507
2 Le Roux N,Heess N,Shotton J,Winn J.Learning a generative model of images by factoring appearance and shape. Neural Computation,2011,23(3):593-650
3 Su Lian-Cheng,Zhu Feng.Design of a novel omnidirectional stereo vision system.Acta Automatica Sinica,2006,32(1):67-72(蘇連成,朱楓.一種新的全向立體視覺系統(tǒng)的設(shè)計(jì).自動(dòng)化學(xué)報(bào),2006,32(1):67-72)
4 Bengio Y.Learning deep architectures for AI.Foundations and Trends?in Machine Learning,2009,2(1):1-127
5 Deng L,Abdel-Hamid O,Yu D.A deep convolutional neural network using heterogeneous pooling for trading acoustic invariance with phonetic confusion.In:Proceedings of the 2013 IEEE International Conference on Acoustics,Speech and Signal Processing(ICASSP).Vancouver,BC:IEEE,2013.6669-6673
6 Deng L.Design and learning of output representations for speech recognition.In:Proceedings of the Neural Information Processing Systems(NIPS)Workshop on Learning OutputRepresentations[Online],available: http://research.microsoft.com/apps/pubs/default.aspx?id=204702,July 14,2015
7 Chet C C,Eswaran C.Reconstruction and recognition of face and digit images using autoencoders.Neural Computing and Applications,2010,19(7):1069-1079
8 Deng L,Hinton G,Kingsbury B.New types of deep neural network learning for speech recognition and related applications:an overview.In:Proceedings of the 2013 IEEE International Conference on Acoustics,Speech and Signal Processing(ICASSP).Vancouver,BC:IEEE,2013.8599-8603
9 Erhan D,Courville A,Bengio Y,Vincent P.Why does unsupervised pre-training help deep learning?In:Proceedings of the 13th International Conference on Artificial Intelligence and Statistics(AISTATS 2010).Sardinia,Italy,2010. 201-208
10 Salakhutdinov R,Hinton G.Deep Boltzmann machines.In:Proceedings of the 12th International Conference on Artificial Intelligence and Statistics(AISTATS 2009).Florida,USA,2009.448-455
11 Swersky K,Chen B,Marlin B,de Freitas N.A tutorial on stochastic approximation algorithms for training restricted Boltzmann machines and deep belief nets.In:Proceedings of the 2010 Information Theory and Applications Workshop (ITA).San Diego,CA:IEEE,2010.1-10
12 Hinton G E,Osindero S,Teh Y W.A fast learning algorithm for deep belief nets.Neural Computation,2006,18(7):1527-1554
13 Fischer A,Igel C.Bounding the bias of contrastive divergence learning.Neural Computation,2011,23(3):664-673
14 Tieleman T.Training restricted Boltzmann machines using approximations to the likelihood gradient.In:Proceedings of the 25th International Conference on Machine Learning (ICML).New York:ACM,2008.1064-1071
15 Tieleman T,Hinton G E.Using fast weights to improve persistent contrastive divergence.In:Proceedings of the 26th Annual International Conference on Machine Learning (ICML).New York:ACM,2009.1033-1040
16 Sutskever I,Tieleman T.On the convergence properties of contrastive divergence.In:Proceedings of the 13th International Conference on Artificial Intelligence and Statistics (AISTATS 2010).Sardinia,Italy,2010.789-795
17 FischerA,IgelC.Paralleltempering, importance sampling, andrestrictedBoltzmannmachines.In:Proceedingsof5thWorkshoponTheoryofRandomizedSearchHeuristics(ThRaSH),[Online],available:http://www2.imm.dtu.dk/projects/thrashworkshop/schedule.php,August 20,2015
18 Desjardins G,Courville A,Bengio Y.Adaptive parallel tempering for stochastic maximum likelihood learning of RBMs. In:Proceedings of NIPS 2010 Workshop on Deep Learning and Unsupervised Feature Learning.Granada,Spain,2010.
19 Cho K,Raiko T,Ilin A.Parallel tempering is efficient for learning restricted Boltzmann machines.In:Proceedings of the WCCI 2010 IEEE World Congress on Computational Intelligence.Barcelona,Spain:IEEE,2010.3246-3253
20 Brakel P,Dieleman S,Schrauwen B.Training restricted Boltzmann machines with multi-tempering:harnessing parallelization.In:Proceedings of the 22nd International Conference on Artificial Neural Networks.Lausanne,Switzerland:Springer,2012.92-99
21 Desjardins G,Courville A,Bengio Y,Vincent P,Delalleau O.Tempered Markov chain Monte Carlo for training of restricted Boltzmann machines.In:Proceedings of the 13th International Conference on Artificial Intelligence and Statistics(AISTATS 2010).Sardinia,Italy,2010.145-152
22 Fischer A,Igel C.Training restricted Boltzmann machines:an introduction.Pattern Recognition,2014,47(1):25-39
23 Hinton G E.A practical guide to training restricted Boltzmann machines.Neural Networks:Tricks of the Trade(2nd Edition).Berlin Heidelberg:Springer,2012.599-619
李 飛西北工業(yè)大學(xué)電子信息學(xué)院博士研究生.2011年獲得西北工業(yè)大學(xué)系統(tǒng)工程專業(yè)學(xué)士學(xué)位.主要研究方向?yàn)闄C(jī)器學(xué)習(xí)和深度學(xué)習(xí).
E-mail:nwpulf@mail.nwpu.edu.cn
(LIFeiPh.D.candidateatthe School of Electronics and Information,Northwestern Polytechnical University. He received his bachelor degree in system engineering from Northwestern Polytechnical University in 2011.His research interest covers machine learning and deep learning.)
高曉光西北工業(yè)大學(xué)電子信息學(xué)院教授.1989年獲得西北工業(yè)大學(xué)飛行器導(dǎo)航與控制系統(tǒng)博士學(xué)位.主要研究方向?yàn)樨惾~斯和航空火力控制.本文通信作者.
E-mail:cxg2012@nwpu.edu.cn
(GAO Xiao-GuangProfessor at the School of Electronics and Information,Northwestern Polytechnical University.She received her Ph.D.degree in aircraft navigation and control system from Northwestern Polytechnical University in 1989.Her research interest covers Bayes and airborne fire control.Corresponding author of this paper.)
萬開方西北工業(yè)大學(xué)電子信息學(xué)院博士研究生.2010年獲得西北工業(yè)大學(xué)系統(tǒng)工程專業(yè)學(xué)士學(xué)位.主要研究方向?yàn)楹娇栈鹆刂?
E-mail:yibai_2003@126.com
(WAN Kai-FangPh.D.candidate at the School of Electronics and Information,Northwestern Polytechnical University.He received his bachelor degree in system engineering from Northwestern Polytechnical University in 2010.His main research interest is airborne fire control.)
Research on RBM Training Algorithm with Dynamic Gibbs Sampling
LI Fei1GAO Xiao-Guang1WAN Kai-Fang1
Currently,most algorithms for training restricted Boltzmann machines(RBMs)are based on the multi-step Gibbs sampling.This article focuses on the problems of sampling divergence and the low training speed associated with the multi-step Gibbs sampling process.Firstly,these problems are illustrated and described by experiments.Then,the convergence property of the Gibbs sampling procedure is theoretically analyzed from the prospective of the Markov sampling.It is proved that the poor convergence property of the multi-step Gibbs sampling is the main cause of the sampling divergence and the low training speed when training an RBM.Furthermore,a new dynamic Gibbs sampling algorithm is proposed and its simulation results are given.It has been demonstrated that the dynamic Gibbs sampling algorithm can effectively tackle the issue of sampling divergence and can achieve a higher training accuracy at a reasonable expense of computation time.
Restricted Boltzmann machine(RBM),Gibbs sampling,sampling algorithm,Markov theory
10.16383/j.aas.2016.c150645
Li Fei,Gao Xiao-Guang,Wan Kai-Fang.Research on RBM training algorithm with dynamic Gibbs sampling. Acta Automatica Sinica,2016,42(6):931-942
2015-10-19錄用日期2016-05-03
Manuscript received October 19,2015;accepted May 3,2016
國(guó)家自然科學(xué)基金(61305133,61573285)資助
Supported by National Natural Science Foundation of China (61305133,61573285)
本文責(zé)任編委柯登峰
Recommended by Associate Editor KE Deng-Feng
1.西北工業(yè)大學(xué)電子信息學(xué)院西安710129
1.School of Electronics and Information,Northwestern Polytechnical University,Xi′an 710129