高慧敏 楊 磊 朱軍龍 張明川 吳慶濤*
①(河南科技大學(xué)信息工程學(xué)院 洛陽 471023)
②(中信重工機(jī)械股份有限公司信息技術(shù)管理中心 洛陽 471003)
隨著移動互聯(lián)網(wǎng)和物聯(lián)網(wǎng)的快速發(fā)展,越來越多的數(shù)據(jù)由網(wǎng)絡(luò)中的邊緣設(shè)備產(chǎn)生和存儲。利用機(jī)器學(xué)習(xí)技術(shù),這些分散的數(shù)據(jù)集通常被用于模型訓(xùn)練,但這種集中式訓(xùn)練方式在將物聯(lián)網(wǎng)設(shè)備產(chǎn)生的數(shù)據(jù)傳輸?shù)皆贫嘶蛘咧行姆?wù)器的過程中面臨著諸多挑戰(zhàn),例如能量和帶寬限制,通信效率和隱私問題等。為解決上述問題,2016年谷歌學(xué)者提出了一個新的機(jī)器學(xué)習(xí)范式,即聯(lián)邦學(xué)習(xí)[1],它聯(lián)合多個用戶在云或中心服務(wù)器的協(xié)調(diào)下訓(xùn)練一個預(yù)測模型,同時用戶之間不用交換數(shù)據(jù),這在一定程度上保護(hù)了用戶數(shù)據(jù)的隱私性。不同于去中心化的多節(jié)點(diǎn)協(xié)作學(xué)習(xí)[2],聯(lián)邦學(xué)習(xí)與分布式學(xué)習(xí)中的參數(shù)服務(wù)器具有相同的架構(gòu),但是聯(lián)邦學(xué)習(xí)中的用戶對本地數(shù)據(jù)具有絕對的自治權(quán)限,可以自主決定是否參與和退出模型訓(xùn)練,而且數(shù)據(jù)之間可能有著不同的分布。此外,分布式學(xué)習(xí)一般在數(shù)據(jù)中心進(jìn)行并行訓(xùn)練,通信代價相對較小。而聯(lián)邦學(xué)習(xí)中的邊緣設(shè)備一般依賴互聯(lián)網(wǎng),其通信代價會更高[3]。面對這些新的問題,聯(lián)邦學(xué)習(xí)提供了一個有效的解決方案,并被廣泛應(yīng)用到不同的領(lǐng)域,例如移動邊緣網(wǎng)絡(luò)[4]、健康醫(yī)療[5]和智能交通系統(tǒng)[6]等。
通信開銷不僅是聯(lián)邦學(xué)習(xí)的一個研究熱點(diǎn),同時也是一個亟待解決的問題。在每一輪的模型訓(xùn)練中,云端或中心服務(wù)器將當(dāng)前的全局模型下發(fā)給參與訓(xùn)練的客戶端;之后,這些客戶端在本地執(zhí)行多步隨機(jī)梯度下降后生成本地模型;最后再將本地模型發(fā)送到云端或中心服務(wù)器。其中,通信開銷主要是由客戶端和云端或中心服務(wù)器之間經(jīng)過網(wǎng)絡(luò)連接傳輸模型參數(shù)產(chǎn)生。在實(shí)際場景中,隨著網(wǎng)絡(luò)中終端設(shè)備的增多、模型參數(shù)維度的增大,客戶端向中心服務(wù)器發(fā)送信息時會遭受嚴(yán)重的網(wǎng)絡(luò)時延和有限帶寬的限制。
為了緩解客戶端和云端或中心服務(wù)器之間的通信壓力,減少傳輸?shù)谋忍財?shù)是一個有效的方法,并由此延伸出許多新的壓縮方法,例如通過稀疏[7]或者量化[8,9]梯度更新或者模型更新,或者將二者結(jié)合[10]使用。這類方法在保證一定模型精確度的基礎(chǔ)上,減少了客戶端之間的通信成本,因而在具有高維模型參數(shù)的深度學(xué)習(xí)領(lǐng)域有著廣泛的應(yīng)用。Li等人[11]提出加速壓縮梯度下降算法,加快模型收斂的同時,節(jié)約了通信開銷。分別針對同構(gòu)和異構(gòu)數(shù)據(jù),Haddadpour等人[12]研究了一類通信有效的優(yōu)化算法,并引入了梯度追蹤方案減少客戶端之間由于數(shù)據(jù)異構(gòu)而產(chǎn)生的殘差,該方法得到了更好的通信復(fù)雜度和更精確的收斂速率。
上述方法只考慮在每輪通信中減少傳輸?shù)男畔⒘?,而通信的次?shù)并沒有改變。在實(shí)際訓(xùn)練中,連續(xù)發(fā)送的模型之間可能并沒有較大的變化,這樣頻繁的通信是一種資源的浪費(fèi)。近年來,在多智能體系統(tǒng),研究人員利用事件觸發(fā)機(jī)制[13,14]減少智能體之間不必要的通信,緩解網(wǎng)絡(luò)中的通信壓力。其中,如何設(shè)計觸發(fā)機(jī)制是關(guān)鍵,以保證算法能夠收斂,并達(dá)到減少通信開銷的目的。Hsieh等人[15]利用梯度衡量更新的重要性,進(jìn)而消除不重要的通信,其閾值設(shè)置為迭代次數(shù)平方根的倒數(shù)。文獻(xiàn)[14]利用模型參數(shù)之間的差異確定節(jié)點(diǎn)的觸發(fā)時刻,不同于分段參數(shù)的閾值調(diào)度方案[16],其閾值與學(xué)習(xí)率和神經(jīng)網(wǎng)絡(luò)參數(shù)的數(shù)量有關(guān)。
為了降低聯(lián)邦學(xué)習(xí)中客戶端和中心服務(wù)器之間的通信代價,本文引入事件觸發(fā)的通信機(jī)制,當(dāng)連續(xù)更新的本地模型之間變化較大時,客戶端的通信行為受到觸發(fā)并將模型差異壓縮后發(fā)送給中心服務(wù)器,中心服務(wù)器將收到的信息聚合后執(zhí)行新一輪全局模型更新。分別在不同目標(biāo)函數(shù)特征條件下,本文證明了所提方法(Federated learning w ith Event-T riggered comm unication,FedET)的收斂性,并給出了相應(yīng)的理論分析。最后通過仿真實(shí)驗(yàn)驗(yàn)證了所提方法的可行性和有效性。
在以數(shù)據(jù)為驅(qū)動的機(jī)器學(xué)習(xí)特別是深度學(xué)習(xí)中,分散的數(shù)據(jù)被收集到云端或數(shù)據(jù)中心進(jìn)行模型訓(xùn)練。在這個過程中,數(shù)據(jù)的傳輸不僅受到不同網(wǎng)絡(luò)環(huán)境特別是有限帶寬的制約,而且面臨著隱私泄露和被惡意篡改的風(fēng)險。為了解決上述問題,基于數(shù)據(jù)不出本地,聯(lián)邦學(xué)習(xí)聯(lián)合多個用戶學(xué)習(xí)一個共享模型,進(jìn)而在網(wǎng)絡(luò)邊緣實(shí)現(xiàn)分布式機(jī)器學(xué)習(xí)。本文考慮由一個中心服務(wù)器和n個客戶端組成的系統(tǒng)模型,每個客戶端i擁有自己的本地數(shù)據(jù)Di,其中{1,2,...,n}表示客戶端的集合,如圖1所示,中心服務(wù)器給每個客戶端分發(fā)當(dāng)前的全局模型。在此基礎(chǔ)上,客戶端通過執(zhí)行多步隨機(jī)梯度下降生成本地模型,并將其上傳到中心服務(wù)器。然后,中心服務(wù)器聚合所有客戶端上傳的模型參數(shù),更新產(chǎn)生下一輪的全局模型。在中心服務(wù)器的協(xié)調(diào)下,重復(fù)以上過程,這些客戶端通過協(xié)作方式解決以下有限和形式的優(yōu)化問題
圖1 聯(lián)邦學(xué)習(xí)的系統(tǒng)架構(gòu)
其中,向量x表示維度是d的模型參數(shù),f i:Rd→R為局部代價函數(shù),通常定義為第i個 客戶端上數(shù)據(jù)樣本的期望損失,表示為:fi(x)=Ez~P i[L i(x,z)],這里L(fēng) i是關(guān)于z的樣本損失函數(shù),它評估給定模型參數(shù)x的性能好壞,其中z是服從概率分布為Pi的隨機(jī)變量,i∈{1,2,...,n}。一般來說,當(dāng)Pi=P j時,稱客戶端i和j之間的數(shù)據(jù)是獨(dú)立同分布的。當(dāng)P i=P j時,稱客戶端i和j之間的數(shù)據(jù)是非獨(dú)立同分布或者異質(zhì)的。
本文用 Rd表示d維實(shí)向量空間?!ぁ硎練W氏范數(shù),‖·‖1表 示?1范數(shù)。gi表示第i個客戶端在模型參數(shù)x處的全梯度?f(x,D i),簡寫為?f i(x)。表示第i個客戶端在模型參數(shù)x處的隨機(jī)梯度?f(x,ξi),其中ξi?D i是采樣的最小批數(shù)據(jù)樣本。分別表示第i個 客戶端在模型參數(shù)處的梯度和隨機(jī)梯度估計,其中k表示全局模型的迭代次數(shù),h是本地模型的迭代次數(shù)。E [·]表示條件期望。
考慮到網(wǎng)絡(luò)中有限帶寬的限制,本節(jié)提出Fed-ET算法解決式(1)中的優(yōu)化問題。該算法是一個通信有效的聯(lián)邦學(xué)習(xí)算法,它利用事件觸發(fā)的通信策略,結(jié)合壓縮技術(shù),減小客戶端和中心服務(wù)器之間的通信負(fù)擔(dān)。
為了減少客戶端與中心服務(wù)器之間的通信,每個客戶端i∈{1,2,...,n}需要維持本地更新xi的一個估計,并根據(jù)事件觸發(fā)機(jī)制將信息壓縮后上傳至中心服務(wù)器。假定所提出的算法運(yùn)行K次全局模型更新,也即客戶端和中心服務(wù)器總的通信輪數(shù)。由于聯(lián)邦學(xué)習(xí)中客戶端的模型聚合是在同步環(huán)境下進(jìn)行的,定義模型參數(shù)的聚合時刻集合為{0,τ,...,Kτ},與文獻(xiàn)[8,12]中設(shè)定類似,τ表示每個客戶端本地模型更新的次數(shù)。此外,由于采樣是在固定次數(shù)的局部迭代之后進(jìn)行的,可知采樣周期是離散的,而且文中給出至少τ倍采樣周期的事件間隔的下界,因此Zeno現(xiàn)象在本文的設(shè)置中不會出現(xiàn)。
FedET算法的偽代碼如算法1所示。在有中心的聯(lián)邦學(xué)習(xí)架構(gòu)中,首先中心服務(wù)器將全局模型的初始值x0下發(fā)給每個客戶端,設(shè)置,允許每個客戶端在第1輪和中心服務(wù)器通信。第k次更新時,每個客戶端i通過隨機(jī)采樣本地數(shù)據(jù)計算梯度估計值,然后本地模型參數(shù)基于隨機(jī)梯度下降進(jìn)行τ次迭代得到參數(shù)向量,如算法1中第(4)~(6)行所示。在完成該輪次的本地計算之后,利用事件觸發(fā)機(jī)制,每個客戶端計算當(dāng)前模型和最近一次發(fā)送的模型之間的差異ei(k),定義為
其中,η是客戶端的本地學(xué)習(xí)率,,k ∈,其中=0,集合表示第i個客戶端觸發(fā)時刻的集合。如果滿足觸發(fā)條件,則客戶端發(fā)送壓縮后的模型差異給中心服務(wù)器,并更新本地模型的估計,如算法1中第(9)和(10)行所示。否則客戶端i和中心服務(wù)器在本輪不進(jìn)行通信。最后,中心服務(wù)器按照下式更新下一輪的全局模型,
其中,γ是中心服務(wù)器上的學(xué)習(xí)率,C(·)是隨機(jī)壓縮算子。
為了進(jìn)一步分析算法的性能,首先介紹函數(shù)光滑性和隨機(jī)梯度假設(shè)[17]及本文要用到的假設(shè)條件。
假設(shè)1對于代價函數(shù)f i(·):Rd →R,如果存在常數(shù)L>0,對任意的向量x1,x2∈Rd,不等式
成立,則稱函數(shù)f i(·)滿 足利普希茨條件,常數(shù)L被稱為利普希茨常數(shù)。本文假設(shè)f*是函數(shù)f(x)的最優(yōu)值。
假設(shè)2對于每一個i∈{1,2,...,n},參數(shù)σ>0,隨機(jī)梯度,若式(5)成立,
假設(shè)3對于隨機(jī)壓縮算子C(·),參數(shù)ω>0。對于輸入向量u∈Rd,若式(6)成立,
則稱隨機(jī)壓縮算子C(·)是無偏的,其方差隨著輸入向量范數(shù)的平方而變化。
假設(shè)4對于函數(shù)f(x),參數(shù)μ>0,如果對任意的x∈Rd,不等式成立,則稱函數(shù)f(x)滿足PL(Polyak-?ojasiew icz)條件。
事實(shí)上,根據(jù)文獻(xiàn)[18]可知,任意μ-強(qiáng)凸函數(shù)都滿足PL條件。
假設(shè)5對于函數(shù)f(x),參數(shù)μ>0,對任意的x,y ∈Rd,如果不等式
成立,則稱函數(shù)f(x)是μ-強(qiáng)凸的。
假設(shè)6對于序列{αk},假設(shè)αk=m0/(k+1)δ,其中m0>0是參數(shù),δ≥1/4。
假設(shè)2是關(guān)于隨機(jī)梯度的常見假設(shè)。假設(shè)3是采用壓縮方法的文獻(xiàn)[8,11]中常用的條件。假設(shè)6中的δ在不同的目標(biāo)函數(shù)下取值范圍不同,注1給出了說明。本文的結(jié)果都是在假設(shè)它們成立的前提下得到的。
本節(jié)給出以下3個主要結(jié)論,其詳細(xì)證明見第5節(jié)。首先針對非凸的目標(biāo)函數(shù),當(dāng)本地學(xué)習(xí)率η和全局學(xué)習(xí)率γ滿足一定的條件,利用函數(shù)的光滑性,下面的定理給出了算法的收斂性。
定理1如果假設(shè)1—假設(shè)3、假設(shè)6均成立,學(xué)習(xí)率γ和η滿足
則K次迭代后,梯度范數(shù)平方的平均值有界,且滿足
其中,f*?f(x*)是 函數(shù)在最優(yōu)點(diǎn)x*處的函數(shù)值,x0是模型參數(shù)的初始值,M0表示
下面定理主要對滿足PL條件的函數(shù)進(jìn)行分析,而任意μ-強(qiáng)凸函數(shù)都滿足PL條件,故只需分析滿足PL條件的函數(shù)的收斂性即可。
定理2如果假設(shè)1—假設(shè)4和假設(shè)6成立,或者假設(shè)1—假設(shè)3和假設(shè)5、假設(shè)6成立,且δ>。全局學(xué)習(xí)率γ和本地學(xué)習(xí)率η滿足式(8),則不等式(10)成立
其中,f*表示最優(yōu)值f(x*),x0是模型參數(shù)的初始值,M0表示
本節(jié)對所提出的基于事件觸發(fā)的聯(lián)邦優(yōu)化算法FedET進(jìn)行測試和性能評估,并通過圖像多級分類任務(wù)的仿真實(shí)驗(yàn)與FedAvg[1]和FedPAQ[8]進(jìn)行對比。仿真實(shí)驗(yàn)主要使用MNIST[19]和Fashion MNIST[20]這兩個具有代表性的數(shù)據(jù)集。其中,MNIST是手寫數(shù)字集,包含0~9 10個類別。Fashion MNIST包含10個類別服裝的正面圖片。每個數(shù)據(jù)集都包含60 000個訓(xùn)練數(shù)據(jù)和10 000個測試數(shù)據(jù)。
為了模擬真實(shí)世界中多方聯(lián)合學(xué)習(xí)的場景,每次實(shí)驗(yàn)假設(shè)10個客戶端在中心服務(wù)器的協(xié)作下聯(lián)合學(xué)習(xí)。本文使用PyTorch[21]作為分布式機(jī)器學(xué)習(xí)訓(xùn)練庫,并基于Python3來實(shí)現(xiàn)文中所提出的算法。實(shí)驗(yàn)環(huán)境是:Intel i7-7500U CPU@2.70 GHz的計算機(jī),包含一塊NV IDIA GeForce 940MX GPU。針對MNIST數(shù)據(jù)集和Fashion MNIST數(shù)據(jù)集,本文使用卷積神經(jīng)網(wǎng)絡(luò)CNN作為訓(xùn)練模型,該模型包含兩個卷積層、兩個最大池化層、1個全連接層和最后的softmax輸出層,使用ReLU作為激活函數(shù)。
實(shí)驗(yàn)假設(shè)所有客戶端使用相同的網(wǎng)絡(luò)模型,客戶端之間的數(shù)據(jù)均衡且服從獨(dú)立同分布:數(shù)據(jù)被隨機(jī)打亂之后,每個客戶端獲得6000個訓(xùn)練樣本。每個客戶端通過本地數(shù)據(jù)來訓(xùn)練神經(jīng)網(wǎng)絡(luò)模型,也就是權(quán)和偏置,初始學(xué)習(xí)率設(shè)定為0.01,數(shù)據(jù)最小批大小為64,本地迭代次數(shù)為5。實(shí)驗(yàn)采用文獻(xiàn)[8]中的壓縮方法,將模型從32位浮點(diǎn)量化為8位整數(shù)。閾值選取方法和文獻(xiàn)[14]中相同,閾值m0=v0N p,v0是參數(shù),通過選擇不同的v0可以改變m0的大小,N p表示神經(jīng)網(wǎng)絡(luò)模型參數(shù)總的個數(shù)。為了深入理解閾值對算法性能的影響,這里分別選取3個不同的閾值進(jìn)行實(shí)驗(yàn),閾值T1,T2,T3分別為0.00005Np,0.00009Np,0.00018Np。
圖2是MNIST數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果。從圖2(a)觀察到,在訓(xùn)練前期,對于較小的閾值T1,其損失曲線(藍(lán)色)波動較大,這是由于客戶端和中心服務(wù)之間頻繁的通信引起的。同理,對于較大的閾值T3,客戶端和中心服務(wù)器之間信息交流相對較少,其訓(xùn)練損失曲線(紅色)波動較小。閾值T2的訓(xùn)練損失曲線位于二者之間,在訓(xùn)練后期,雖然不同閾值情況下訓(xùn)練損失曲線最后均趨于穩(wěn)定,但是從測試精度和迭代次數(shù)關(guān)系圖2(b)不難看出,相對于較小的閾值T1和較大的閾值T3,閾值T2下的模型具有更高的精度值。從圖2(c)通信開銷角度來說,當(dāng)客戶端與中心服務(wù)器信息交流較少時,雖然通信開銷較少,但是無法達(dá)到理想的精度;客戶端與中心服務(wù)器之間充分交流雖然能保證精度,但是會造成大量的通信開銷;適當(dāng)?shù)男畔鬏?,如閾值T2情況下,可以以較少的通信代價最大程度的提高精度??偟膩碚f,閾值較小時,需要昂貴的通信代價來換取算法的精度,而閾值較大時,通信開銷減小,但算法精度又較差。因此,選擇合適的閾值才能達(dá)到精度和通信之間最好的權(quán)衡。
圖2 不同閾值對算法性能的影響
為了證明基于事件觸發(fā)的聯(lián)邦優(yōu)化算法的優(yōu)勢,以下實(shí)驗(yàn)對比了不同算法下的收斂速度、平均通信開銷和測試精度,包括聯(lián)邦平均Fed Avg和FedPAQ算法。圖3分別刻畫了在不同數(shù)據(jù)集上全局模型的平均訓(xùn)練損失和迭代輪次之間的關(guān)系,其中圖3(a)是在MNIST數(shù)據(jù)集上的運(yùn)行結(jié)果,可以看出FedET算法達(dá)到了和FedAvg和FedPAQ算法相似的收斂速度,雖然模型變化不大時客戶端和中心服務(wù)器不進(jìn)行通信,但獨(dú)立同分布的數(shù)據(jù)特征使算法快速收斂。在Fashion MNIST數(shù)據(jù)集上有類似的結(jié)果,如圖3(b)所示。
圖3 訓(xùn)練損失和迭代次數(shù)之間的關(guān)系
圖4顯示了客戶端和中心服務(wù)器之間上行鏈路中平均累計通信開銷和訓(xùn)練損失值的變化關(guān)系。從圖4(a)可以觀察到,當(dāng)達(dá)到最小損失時,F(xiàn)edAvg的通信比特數(shù)最多,F(xiàn)edPAQ次之,F(xiàn)edET的通信代價最少。這是由于算法分別采取了不同的通信有效技術(shù)引起的。FedAvg算法中客戶端向中心服務(wù)器發(fā)送整個模型參數(shù),F(xiàn)edPAQ算法將模型之間的差異壓縮后上傳給中心服務(wù)器,而FedET使用壓縮和事件觸發(fā)機(jī)制,和前兩個算法相比分別節(jié)約了64%和10%左右的通信。在Fashion MNIST數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果如圖4(b)所示,當(dāng)達(dá)到0.2的訓(xùn)練損失值時,F(xiàn)edET分別比FedAvg,FedPAQ節(jié)約了75%和35%左右的通信量。
圖4 訓(xùn)練損失和通信量之間的關(guān)系
測試精度是衡量模型泛化性能的一個重要指標(biāo)。圖5刻畫了測試集上全局模型的預(yù)測精度在平均累計通信開銷下的變化趨勢。從MNIST數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果圖5(a)不難看出,當(dāng)達(dá)到大約99%的精度時,F(xiàn)ed ET使用最少的通信比特值,F(xiàn)ash ion MNIST數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果和上述結(jié)論是類似的,如圖5(b)所示。在資源有限的網(wǎng)絡(luò)環(huán)境下,事件觸發(fā)機(jī)制的使用降低了對通信帶寬的需求。兩個數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果驗(yàn)證了本文所提的聯(lián)邦優(yōu)化算法FedET的可行性和有效性。
本文基于事件觸發(fā)機(jī)制提出一個通信有效的聯(lián)邦優(yōu)化算法,用于解決聯(lián)邦學(xué)習(xí)中的通信開銷問題。該算法通過減少上行鏈路中客戶端到中心服務(wù)器不必要的信息傳輸,同時結(jié)合信息壓縮技術(shù)減少每輪發(fā)送的信息比特數(shù),緩解聯(lián)邦學(xué)習(xí)中的通信瓶頸。在客戶端數(shù)據(jù)獨(dú)立同分布時,針對不同的目標(biāo)函數(shù)特征,本文從理論上分析了所提算法的收斂性,并給出了相應(yīng)的數(shù)學(xué)證明。最后,在MNIST和Fashion MNIST這兩個公共數(shù)據(jù)集上執(zhí)行仿真實(shí)驗(yàn),結(jié)果表明,所提算法能在合適的閾值下達(dá)到與FedAvg和FedPAQ算法相似的精度,同時節(jié)約了通信成本。