左一鵬, 陳 輝
(上海電力大學 自動化工程學院, 上海 200090)
機器學習中的數據分類問題在機器人導航、無人駕駛、垃圾郵件過濾、手寫數字的識別等領域中扮演著不可或缺的角色[1]。常見的數據分類算法有K近鄰[2]、樸素貝葉斯[3]、決策樹[4]、Logistic回歸[5]和支持向量機(Support Vector Machines,SVM)[6]。目前應用最為廣泛的是SVM分類器,其核心思想是線性分類器,最初解決線性可分問題,然后拓展到非線性映射,將低維空間內不可分的樣本映射成高維空間線性可分的樣本,通過超平面將樣本分割成不同類別。SVM分類器的優(yōu)勢在高維空間內非常有效,即使在數據維度比樣本數量大的情況下仍然可高效利用內存。SVM的核函數的通用性較強,可根據數據集的特點選擇核函數類型以及參數。SVM分類器存在的問題是對小量數據分類時易出現過擬合現象,泛化性不強。帶有核函數的SVM一般依靠經驗選擇參數,且僅能達到局部最優(yōu),手動調參方式的效率低下,實際應用困難。針對SVM參數調整問題,目前學者們已提出很多種優(yōu)化算法,如人工蜂群算法結合SVM[7],基于蝙蝠算法優(yōu)化的方法結合SVM[8]和結合KNN算法的改進[9]等。但是針對SVM使用過程中的參數選擇,沒有一個固定的方法。
本文從提高實際使用時效率的角度出發(fā),針對SVM參數調整及優(yōu)化問題,提出一種基于Scikit-Learn的SVM參數調整優(yōu)化方法。結合網格搜索以及交叉驗證的方法對最優(yōu)參數范圍進行搜索,利用Python機器學習庫Scikit-Learn對不同參數、不同核函數的分類結果進行可視化觀察,并在網格上顯示其最優(yōu)參數范圍,尋找準確率高的參數分布,再通過迭代的方法提高對最優(yōu)參數的求解,最終實現數據的準確分類。
SVM分類器按監(jiān)督學習方式對數據進行二元分類,通過分類超平面能使分類間隔最大化[10]。在分類問題中,令給定的輸入數據樣本集為:X={X1,X2,X3,…,Xm},學習目標為:y={-1,1},每個樣本Xi(i=1,2,3,…)包含n個特征,即Xi=[x1,x2,x3,…,xn]。針對在二維平面上分布的兩種數據類型,需要一個超平面對數據進行分割。超平面可以用線性方程表示為
ωTXi+b=0
(1)
式中:ω——法向量,決定超平面的方向;
b——偏移量,決定超平面與原點之間的距離。
對于線性分類問題,可以通過設定f(Xi)=ωTXi+b為決策函數進行處理。對于所有樣本點,均滿足如下條件:即y=1時,認為ωTXi+b≥1;如果y=-1時,則ωTXi+b≤-1。
根據點到面的距離公式,可知空間中任意一點樣本Xi到超平面(ω,b)的距離為
(2)
(3)
圖1為以鳶尾花數據集為例的SVM二分類最大間隔分類示意圖,橫坐標為維吉尼亞鳶尾、山鳶尾花萼長度,縱坐標為花萼寬度。 SVM的二分類就是在兩個類別不同特征的分布中找到一個使得兩個類別區(qū)分間隔最大的分割線。
圖1 SVM二分類最大間隔分類示意
通常線性不可分問題可能是非線性可分的,即在特征空間中存在超曲面將正類和負類分開。通過非線性函數將可分問題從原始特征空間映射至更高維希爾伯特空間,從而轉化為線性可分問題。
設χ是輸入空間,即Xi∈χ,χ是Rn的子集或離散集合,又設H為特征空間,如果存在一個χ到H的映射:
φ(x):χ→H
(4)
使得對所有x,z∈χ,函數K(x,z)滿足條件
K(x,z)=〈φ(x),φ(z)〉
(5)
則稱K為核函數[11],其中φ(x)為映射函數。核函數主要包括:多項式核、徑向基函數核、拉普拉斯核和Sigmoid核等。
針對數據集的多樣化,SVM分類器可進行針對性分類,對于不能直接線性分類的數據集需要使用非線性SVM分類[12]。通過特征添加,將n維平面空間映射到n+1維空間,將二維空間平面映射到三維空間以便于類別的區(qū)分,從而將數據集變?yōu)榭删€性分離。
在數據線性不可分的情況下,SVM首先在低維空間中完成計算,然后通過核函數將輸入空間映射到高維特征空間,最終在高維特征空間中構造出最優(yōu)分離超平面,將不易分離的非線性數據分類。高斯徑向基核函數(Radial Basis Function,RBF)[13]是一種局部性強的核函數,可將一個樣本映射到一個更高維的空間內。其應用廣泛,通用性強,更適合在無先驗知識情況下使用。本文選用RBF核作為核函數進行數據分類。其公式為
(6)
網格搜索法[14]是通過在指定的數值范圍內,按照一定次數將參數進行調整試驗,以網格形式表示,得到相應的最優(yōu)參數組合。但網格搜索法受限于數據量的大小以及選定產生網格數據的大小,數據大時難以得出結果。針對該問題,可利用多次提取網格搜索的參數最優(yōu)值在相鄰范圍內繼續(xù)進行網格搜索以及適當擴大網格搜索的數據量,使計算結果逐漸趨近于最優(yōu)參數的值[15]。本文通過逐漸趨近的方式進行網格搜索調參,構建可以自動迭代的網格搜索法。
基于手動調整網格搜索思想,通過設置自動循環(huán)進行網格調參。在每次網格搜索結束后,利用函數調用對搜索C和gamma的最優(yōu)參數值進行讀取,得到網格中的最優(yōu)參數值,觀察網格搜索結果分布,確定最優(yōu)參數值。對于數值較小的gamma值,選擇以10倍基準為上限,0.1倍基準為下限;對于參數較大的C以2倍基準為上限,0.01倍基準為下限,以此為下一次的迭代區(qū)域進行搜索。該方法首先根據網格參數的搜索得到最優(yōu)值,再通過最優(yōu)值設定上下限,并在區(qū)域內進行進一步搜索,持續(xù)迭代至準確率穩(wěn)定。
為了驗證所提方法的有效性,以鳶尾花數據集為測試對象(包含150個數據,分為3類,每類50個數據,每個數據包含4個屬性),進行分類測試。實驗在Ubuntu16.04操作系統(tǒng)、Scikit-Learn開發(fā)環(huán)境下進行,程序為Python3.7語言??赏ㄟ^花萼長度、花萼寬度、花瓣長度、花瓣寬度4個屬性預測鳶尾花卉屬于維吉尼亞鳶尾、山鳶尾、雜色鳶尾3個種類中的哪一類。在本次實驗測試中,選取花萼長度和花萼寬度兩種特征進行分類。
在鳶尾花數據集中,從花萼長度和花萼寬度兩項數據的分布,觀察3種花的不同。維吉尼亞鳶尾花萼長度分布在4.9~7.9 cm,雜色鳶尾花萼長度分布在4.9~7.0 cm,山鳶尾花萼長度分布在4.3~5.8 cm。維吉尼亞鳶尾花萼寬度分布在2.2~3.8 cm,雜色鳶尾花萼寬度分布在2.0~3.4 cm,山鳶尾花萼寬度分布在2.3~4.4 cm。從整體的分布上看,花萼長度由低到高排序依次為山鳶尾、雜色鳶尾、維吉尼亞鳶尾,花萼寬度由低到高的排序依次為雜色鳶尾、維吉尼亞鳶尾、山鳶尾。在這些數據中會有部分不同種類花的花萼長度、寬度分布在同一區(qū)間。機器學習分類過程中,會根據實際樣本的數據分布情況以及參數的設定,確定一個能夠較好的將兩者分類的界限。實驗中,選擇了4種不同類型的核函數進行分類結果比較。分類結果如圖2所示。
圖2 不同核函數的分類結果
其中,SVC是SVM中的一類,用于樣例數量少于10 000時的二元和多元分類。由圖2可以看出:圖2(a)、圖2(c)、圖2(d)的分類結果中,在兩種類別的分割上不是很合理,其中圖2(a)和圖2(c)的分割近乎直線方式,分類結果對于實際數據集的效果不理想,主要是因為實際分類的對象是具有某種特征范圍內的一類,而不是僅僅具有線性關系就可以分類;圖2(b)和圖2(d)的分類結果比較平滑,但是圖2(d)的類別劃分中,有的近似于直線,有的類別劃分延伸較長。同時結合表1可以看出,帶有RBF核函數的SVC分類準確率略高一些。
表1 不同核函數下SVM分類結果對比
為了提高帶有RBF核函數的SVM分類器的準確率,進一步調整參數。將gamma參數設置為5,并保持不變,僅對參數C進行調整,采用數量級調參的方式,即改變C參數的數量級,觀察這種變化對分類準確率的影響,結果如表2和圖3所示。
表2 改變參數C對分類準確率的影響
圖3 RBF核函數不同參數值的分類結果
由表2可知,C的數值由0.01變?yōu)?.1后,分類準確率提升幅度較大,這是因為將C值調高后,提高了對總誤差的懲罰系數,在程序運行中,會增加對準確率的權重,因此準確率大幅上升。由圖3可知,改變C參數的數量級在分類結果的改進中并不是非常明顯,僅圖3(d)的分類結果區(qū)分比較明顯,不過這種方式效率比較低。手動更改gamma和C的參數值對于分類的結果提高并不顯著,因此需要一種可以高效的求解最優(yōu)參數的方法。
在初次進行網格搜索時,設置參數C的范圍為10-2~1010,gamma的范圍為10-9~103,都分成13等份。輸出網格排列的準確率圖像以及準確率最高一組的C與gamma的參數值。根據Scikit-Learn的使用特點,并為了便于可視化,選擇雜色鳶尾和維吉尼亞鳶尾兩種進行二分類。
圖4表示經過網絡搜索后,不同參數下的分類結果,其中花萼長度(單位cm)作為橫軸,花萼寬度(單位cm)作為縱軸。藍色數據點表示雜色鳶尾數據,藍色區(qū)域表示分類后對判斷屬于這一花類范圍的劃分;紅色數據點表示維吉尼亞鳶尾,紅色區(qū)域表示分類后對屬于這一花類范圍的劃分。
圖4 初次網絡搜索的分類結果
由圖4可知,C=1,gamma= 0.1時,準確率達到97%,為手動調整參數的最高準確率,但是不能確定是否為全局最高,需要進一步進行優(yōu)化,以確定全局最優(yōu),并且提高尋找效率。
參數C與gamma對應的準確率分布如圖5所示。
圖5 參數C與gamma對應的準確率分布
由圖5可知,當C為1時,觀察不同的gamma數值從0.001~1變化,發(fā)現準確率僅在一定范圍內增加或者降低,沒有得到全局最優(yōu)參數值。
采用網格搜索法調整參數,5次網格搜索結果如表3所示。
表3 5次網格搜索結果
由表3可以看到,在圖5中的一個區(qū)域范圍內,手動調整參數準確率比較高。由于第一次設置的數值間隔大且所分組數較少,得到的可產生最高準確率的一組參數只是與最優(yōu)參數相近的數值。因此,采用逐漸逼近法進行網格搜索,逐漸縮小參數的搜索范圍,同時適當增加所分數組數量。圖6是針對表3的參數進行實驗所得出的分類效果,花萼長度(單位cm)作為橫軸,花萼寬度(單位cm)作為縱軸。藍色數據點表示雜色鳶尾數據,藍色區(qū)域表示分類后對判斷屬于這一花類范圍的劃分;紅色數據點表示維吉尼亞鳶尾,紅色區(qū)域表示分類后對屬于這一花類范圍的劃分。
網格搜索在一定范圍內的最優(yōu)參數求解效果明顯,但受網格搜索大小限制,只能先求出部分參數,然后逐步手動輸入下一個范圍,增加了工作量,且效率較低。因此,本文根據分類結果圖中所示意的最優(yōu)參數分布,采用自動循環(huán)迭代的網格搜索進行參數的選擇。圖5中C=10,gamma=0.01時,對應的白色表示在該網格的搜索下準確率最高一組參數。參數gamma較小,選擇以10倍基準為上限,0.1倍基準為下限,對于較大的參數C以2倍基準為上限,0.01倍基準作為下限得出gamma和C的最佳參數范圍,通過進行迭代以進一步求得更精確的最優(yōu)參數。本文中使用的數據集取自動循環(huán)次數為40次。初次設定gamma范圍為0.001~0.1,C的范圍為0.1~20,自動循環(huán)的網格搜索運行。迭代次數為1,2,16,24,32,40時的準確率如表4所示。由表4可知,在運行至第2次后,得到了98%的準確率。但本文進行了40次的循環(huán),這是因為在較高的準確率下,尋找更優(yōu)參數的難度增加,為了確定最高準確率的可靠性,在后續(xù)的循環(huán)中得到準確率一致的情況下,可以認為得到了全局最高準確率。
圖6 3組優(yōu)化參數不同搭配的分類結果
針對一種數據集分類時,采用傳統(tǒng)SVM分類器,需要調整兩個參數的數值,運行時間約為10 min,實際應用不方便。先進行網格搜索方式確定最優(yōu)參數的大致范圍后,可以進一步通過自動迭代的方式求取最優(yōu)參數精確數值。該方法準確率高,且求解方式便捷,整個求解時間僅為26.77 s,大大縮減了尋找最優(yōu)參數的時間。
表4 取40次網格搜索結果中的5次結果
本文提出了一種SVM算法參數調整優(yōu)化方法,且與手動調整參數以及網格搜索選擇參數進行了實驗比較研究。在選擇最優(yōu)參數(C,gamma)的過程中,以網格搜索法為基礎,再通過在選定參數范圍的基礎上,采用自動迭代方式進行更精確的調參。實驗結果表明,采用該方法,不僅分類的精度由原來的最高97.667%提高至98.00%,改進了學習性能,而且相較用原始方法調參使用10 min相比,該方法只用了26.77 s,大大提高了實際應用的效率。