鐘坤華,秦小林,陳敏,陳芋文
(1.中國科學院成都計算機應用研究所,四川成都 610041;2.中國科學院重慶綠色智能技術(shù)研究院,重慶 400714;3.中國科學院大學,北京 100049)
機器學習正在成為醫(yī)療保健領(lǐng)域越來越重要的技術(shù)手段。一些基于機器學習算法的人工智能系統(tǒng)在癌癥分類檢測、糖尿病視網(wǎng)膜病變檢測方面的水平已經(jīng)接近甚至超過了人類專家。毫無疑問,人工智能將重塑醫(yī)學的未來。然而,目前已成功應用于醫(yī)療問題的機器學習方法僅基于關(guān)聯(lián)而非因果關(guān)系。在統(tǒng)計學中,關(guān)聯(lián)在邏輯上并不意味著因果關(guān)系。相關(guān)性與因果關(guān)系之間的關(guān)系由Reichenbach正式確定為共同原因原則,即如果兩個隨機變量X和Y在統(tǒng)計上相互依賴,則必須持有以下因果解釋之一:①X是Y的直接原因;②有一個隨機變量Z是X和Y的共同原因。因此,與關(guān)聯(lián)相比,因果關(guān)系進一步探索了變量之間更本質(zhì)的關(guān)系。
隨著現(xiàn)代醫(yī)學技術(shù)的飛速發(fā)展,針對患者采集的臨床數(shù)據(jù)越來越多,這種增長對疾病預測模型的性能以及檢測效率均提出了巨大挑戰(zhàn)。理論上使用的特征越多,模型訓練效果越好,而在測試集中效果不理想的現(xiàn)象可解釋為非相關(guān)特征過度擬合,導致模型性能和泛化能力降低。但事實上,變量越多并不意味著信息越有用,預測效果越好。因此,為了減小數(shù)據(jù)集規(guī)模、提高模型預測性能,減少特征數(shù)量非常必要。在機器學習中,特征選擇是獲得良好預測效果的重要步驟之一。近年來,人們不僅對基于信息選擇特征進行預測感興趣,還希望了解這些特征與研究目標的相互作用。在這種背景下,一些研究者開發(fā)了一些理論,試圖將圖(Graph)與因果關(guān)系的概念引入到特征選擇中,目的是找到能夠生成數(shù)據(jù)的因果關(guān)系,以便更好地理解數(shù)據(jù)集的底層機制。以癌癥為例,我們需要知道其是什么原因?qū)е碌模枰褂媚男┳兞恐斡?/p>
因果特征選擇作為一種新興的特征濾波方法,其為特征與類屬性之間的關(guān)系提供了因果解釋,從而更好地理解數(shù)據(jù)背后的機制。與非因果特征選擇相比,因果特征選擇在理論上是最優(yōu)的,回答了最優(yōu)特征選擇包含哪些核心特征,以及特征濾波方法在什么條件下能夠輸出最優(yōu)特征的問題。
傳統(tǒng)的因果特征選擇是在因果貝葉斯網(wǎng)絡(Causal Bayesian Network,CBN)中尋找類屬性的馬爾可夫毯(Markov Blanket,MB),其中邊X→Y表示X為Y的直接原因(父親節(jié)點),Y為X的直接結(jié)果(孩子節(jié)點)。目標變量(例如類標簽)的MB由父節(jié)點、子節(jié)點以及子節(jié)點的父節(jié)點(配偶節(jié)點)構(gòu)成。MB提供了圍繞局部因果結(jié)構(gòu)的完整結(jié)構(gòu),即MB是最小的特征集,其使類屬性在統(tǒng)計上條件獨立于所有的其他屬性。在該研究領(lǐng)域,Koller等首先引入MBs進行特征選擇,并提出Koller-Sahami(KS)算法,但KS算法并不能保證找到真正的MB;Margaritis等設(shè)計了一種GS(Growing-Shrinking)算法,可用于貝葉斯網(wǎng)絡結(jié)構(gòu)學習;Tsamardinos等改良了GS算法,并提出一系列用于最優(yōu)特征選擇的MB發(fā)現(xiàn)算法,從而形成了IAMB(Incremental Association-based MB)算法家族,包括IAMB、interIAMB、IAMBnPC和FastIAMB等;Goudet等提出因果生成神經(jīng)網(wǎng)絡(Causal Generative Neural Networks,CGNNs),利用條件獨立性和分布不對稱性探索雙變量和多變量的因果結(jié)構(gòu);Kalainathan等提出結(jié)構(gòu)不可知建模(Structural Agnostic Modeling,SAM)方法,該法基于不同參與者之間的博弈,結(jié)合分布估計、稀疏性和非循環(huán)性約束的學習準則,通過隨機梯度下降方法進行端到端的參數(shù)學習。
本文參考文獻[13][14]的研究成果,提出一種基于生成神經(jīng)網(wǎng)絡和強化學習的因果特征選擇和預測模型,框架如圖1所示。該模型包含一個因果門網(wǎng)絡和一個因果預測網(wǎng)絡,其中因果門網(wǎng)絡輸入原始數(shù)據(jù),輸出選擇因果概率,然后根據(jù)這些概率對選擇向量進行采樣;因果預測網(wǎng)絡接收所選特征并進行預測。兩個網(wǎng)絡基于真實標簽進行反向傳播的訓練,然后從預測網(wǎng)絡的損失中減去基線網(wǎng)絡損失,用于因果門網(wǎng)絡的更新。
X
,…X
,X
],表示d+1維隨機變量向量;P(X)
為聯(lián)合概率分布;X′
=[X
,…,X
,X
],表示d
個隨機特征空間變量向量;X
為離散的標簽空間變量。基于觀察因果發(fā)現(xiàn)從分布P(X)
中采集獨立同分布的樣本D={X(1),…,X(j),…,X(n)},X(j)=(X
,…X
,X
)。為了更清楚地表示患者數(shù)據(jù),將X
表示為患者的疾病標簽Y。f
,…f
,f
),為一組d+1的因果機制。函數(shù)因果模型假設(shè)每個變量滿足如下關(guān)系:f
。如圖3所示,深層神經(jīng)網(wǎng)絡的因果機制由H隱層神經(jīng)網(wǎng)絡實現(xiàn),其中c
=(c
,c
,…,c
)為因果系數(shù)。如果使用變量X
生成Y,即X
→Y在圖G中有一條邊,因此認為X
為Y的原因,c
為1,否則c
為0;E
為高斯噪聲。網(wǎng)絡結(jié)構(gòu)的數(shù)學表達式為:Fig.1 The proposed model framework圖1 本文模型框架
c
⊙X
表示兩個向量之間對應元素相乘,[c
⊙X
,E
]為連接c
⊙X
和噪聲的d+1維向量,L為隱層中的代數(shù)變換。Fig.2 Example of functional causal model on X(Left:causal graph G;Right:causal mechanisms)圖2 在X上的函數(shù)因果模型示例(左:因果圖,右:因果機制)
Fig.3 Neural network causal mechanisms圖3 神經(jīng)網(wǎng)絡因果機制圖
如函數(shù)因果模型所描述,特征選擇的目標是找到一個盡可能小的X子集,使基于X的最優(yōu)子集與X全集具有相同的效應,表示為:
c
⊙X
,Y的條件分布與給定所有X、Y的條件分布相同。本文使用Kullback-Leibler(KL)散度將式(3)轉(zhuǎn)換為式(4),以最小化兩個分布的距離,表示為:本文模型的改良在于設(shè)計了因果門結(jié)構(gòu),主要基于強化學習框架對特征進行因果選擇預測,學習率為0.000 1,激活函數(shù)為ReLu,batch_size為100。
f
為因果門特征選擇網(wǎng)絡,稱為Actor,是由3層隱藏層組成的全連接網(wǎng)絡,輸入節(jié)點根據(jù)實際輸入數(shù)據(jù)確定。f
:X→{0,1},該網(wǎng)絡輸出每個特征的選擇概率,給定特征選擇向量的概率為c∈[0,1],則有:因果門特征選擇網(wǎng)絡的損失函數(shù)表示為:
f
為因果預測器網(wǎng)絡,稱為Critical。該網(wǎng)絡為3層全連接網(wǎng)絡,每層隱藏層有200個節(jié)點,輸入節(jié)點根據(jù)實際輸入數(shù)據(jù)確定。接受選擇的因果特征向量作為輸入,在c維輸出空間中輸出概率分布。該網(wǎng)絡的損失函數(shù)表示為:
y
為y的第i
個分量編碼,c
⊙X
為因果門選擇的特征。f
為預測網(wǎng)絡,結(jié)構(gòu)與f
因果預測器網(wǎng)絡(Critical)一致,隱藏層為200個節(jié)點的3層全連接前饋神經(jīng)網(wǎng)絡,并經(jīng)過訓練以最小化。該網(wǎng)絡使用所有觀察到的患者數(shù)據(jù)進行直接預測,損失函數(shù)表示為:使用BP反向傳播算法組合上述3個損失函數(shù)對3個神經(jīng)網(wǎng)絡進行端到端的訓練,將患者觀察數(shù)據(jù)輸入訓練后的模型,得到特征的最優(yōu)子集和預測結(jié)果。
在合成數(shù)據(jù)、開源數(shù)據(jù)和真實世界醫(yī)學數(shù)據(jù)上進行驗證實驗,從特征選擇的相關(guān)性和預測的準確性兩方面評估模型性能。將特征選擇模型與LIME和Shapley兩種方法進行比較,將預測模型與XGBoost和Lasso正則化線性模型進行比較。
服務器搭載Ubuntu 16.04 LTS操作系統(tǒng)、Intel Xeon e5-2650 V4處理器和Nvidia GTX 1080 Ti GPU,內(nèi)存64G。基于Pytorch框架構(gòu)建模型,編程工具為Python3.6。
針對每個數(shù)據(jù)集生成40 000個樣本,其中20 000個用于訓練,20 000個用于測試。特征選擇時使用真陽性率(TPR,越高越好)和錯誤率(FDR,越低越好)評估算法性能,具體定義見表1和式(12)、式(13);使用接受者操作特征曲線下面積(Area Under the Receiver Operating Characteristic curve,AUROC)、精確召回曲線下面積(Area Under Precision-Recall Curve,AUPRC)評估預測準確度。
使用Adam優(yōu)化器進行訓練,初始學習率為0.000 1,并采用stepLR學習率變化策略,每10步更新1次學習率,共訓練100epoch。
Tabel 1 Definition of TPR and FDR表1 TPR和FDR的定義
分析特征選擇作為預測預處理步驟的效果。首先進行特征選擇,然后訓練一個3層全連接的神經(jīng)網(wǎng)絡,在特征選擇的數(shù)據(jù)上執(zhí)行預測。如表2所示,本文模型的TPR和FDR均明顯優(yōu)于LIME和Shapely算法,能有效檢測相關(guān)特征。如表3所示,當丟棄所有不相關(guān)特征時,本文模型性能有顯著提高,但XGBoost和Lasso算法性能提升不明顯。
Table 2 Feature selection result for synthetic datasets表2 合成數(shù)據(jù)集的特征選擇結(jié)果
Table 3 Prediction performance results表3 預測性能結(jié)果
使用開源醫(yī)療數(shù)據(jù)集進行驗證實驗,該數(shù)據(jù)集為根據(jù)墨西哥、秘魯和哥倫比亞居民的飲食習慣和身體狀況估計肥胖水平的數(shù)據(jù),包含15個屬性和2 111條記錄。該數(shù)據(jù)集中77%的數(shù)據(jù)是使用Weka工具和SMOTE過濾器綜合生成的,23%的數(shù)據(jù)是通過Web平臺直接從用戶處收集的。所有數(shù)據(jù)均被標記,類變量的值分別為normal和abnormal。數(shù)據(jù)集的具體屬性見表4。
從表5可以看出,本文模型在肥胖預測能力方面與全特征預測方法的性能基本一致。原因可能是特征數(shù)量很小,并且所選特征與預測標簽之間有很強的相關(guān)性,因此本文特征選擇模型的優(yōu)勢沒有得以體現(xiàn)。此外,本文繪制了測試患者特征選擇概率的熱圖,如圖4(彩圖掃OSID碼可見,下同)所示,模型預測肥胖患者的主要原因為Weight、FHWO、CAEC和FAF變量。
Table 4 Obesity levelsdata set attributes表4 肥胖水平數(shù)據(jù)集屬性
Table5 Prediction performance results表5 預測性能結(jié)果
Fig.4 Feature selection probabilistic heat map圖4 特征選擇概率熱圖
使用心力衰竭數(shù)據(jù)集進行驗證實驗,數(shù)據(jù)來源于第三軍醫(yī)大第一附屬醫(yī)院2014-2018年間住院的1 452例患者,包含66個測量特征,標簽為心力衰竭。數(shù)據(jù)集的具體屬性見表6。
Table 6 Heart failure data set attributes表6 心力衰竭數(shù)據(jù)集屬性
續(xù)表
如表7所示,當丟棄所有不相關(guān)特性時,本文模型預測性能相較全特征預測方法有輕微提高。圖5描述了男性和女性心力衰竭患者所選特征平均概率熱圖。可以看出,導致成年男性和女性心力衰竭的因素是相同的,這與醫(yī)生的判斷基本一致。
Table 7 Prediction performance results表7 預測性能結(jié)果
Fig.5 Female and male heart failure patients'features selected for average probability heat maps(a:Female,b:Male)圖5 女性和男性心力衰竭患者所選特征平均概率熱圖(a:女性,b:男性)
本文針對特征選擇與預測問題,從因果特征分析的視角提供了一種新方法。首先,從定性的角度進行特征選擇,然后在強化學習框架下,設(shè)計可解釋的實例特征選擇與預測模型,最后在合成數(shù)據(jù)、開源數(shù)據(jù)以及真實數(shù)據(jù)集上進行了實驗評估,結(jié)果表明該方法可有效選擇屬性進行疾病預測。本文研究成果能在一定程度上拓展醫(yī)療問題的分析角度,并進一步回答病因與疾病的因果關(guān)系,例如醫(yī)療健康領(lǐng)域存在哪些反事實問題,哪些反事實問題能夠得以解決,以及醫(yī)療健康領(lǐng)域是否存在因果鏈等。本文研究也存在一定的局限性,例如關(guān)注的只是患者的靜態(tài)屬性數(shù)據(jù),尚不能應用于動態(tài)的時間序列數(shù)據(jù),如圍術(shù)期的監(jiān)護數(shù)據(jù)。后續(xù)將嘗試進行動態(tài)數(shù)據(jù)、混合數(shù)據(jù)的因果分析,例如采用循環(huán)神經(jīng)網(wǎng)絡替換本文模型中的網(wǎng)絡,以適用于醫(yī)療健康領(lǐng)域中的時間序列數(shù)據(jù)研究。