袁培森 吳茂盛 翟肇裕 楊承林 徐煥良
(1.南京農(nóng)業(yè)大學(xué)信息科學(xué)技術(shù)學(xué)院, 南京 210095; 2.馬德里理工大學(xué)技術(shù)工程和電信系統(tǒng)高級(jí)學(xué)院, 馬德里 28040)
表型(Phenotype)研究核心是獲取高質(zhì)量的性狀數(shù)據(jù),進(jìn)而對(duì)基因型和環(huán)境互作效應(yīng)(Genotype-by-Environment) 進(jìn)行分析[1-2],表型組學(xué)近年來(lái)發(fā)展迅猛,已成為分子育種和農(nóng)業(yè)應(yīng)用中的重要技術(shù)支撐[3-4]。然而,植物表型數(shù)據(jù)的獲取需搭建實(shí)驗(yàn)環(huán)境,并需昂貴的數(shù)據(jù)采集工具,具有周期長(zhǎng)、代價(jià)高昂等特點(diǎn)[1,5-6]。當(dāng)前,以大數(shù)據(jù)為基礎(chǔ)的深度學(xué)習(xí)正在成為表型數(shù)據(jù)分析的有力工具[7-8],深度學(xué)習(xí)相關(guān)算法的有效性在很大程度上取決于標(biāo)記樣本的數(shù)量,因此限制了其在小樣本量環(huán)境中的應(yīng)用[9]。數(shù)據(jù)的非均衡性是生物表型數(shù)據(jù)具有挑戰(zhàn)性的問(wèn)題[10-13]。
為了提升非均衡數(shù)據(jù)分析的性能和質(zhì)量,文獻(xiàn)[14-15]提出了數(shù)據(jù)生成的方法。然而,過(guò)采樣技術(shù)SMOTE[15]、ADASYN[16]等對(duì)于處理經(jīng)典學(xué)習(xí)系統(tǒng)中的類不平衡有效,但是此類方法生成的數(shù)據(jù)不能直接應(yīng)用于深度學(xué)習(xí)系統(tǒng)[17]。近年來(lái),生成式對(duì)抗網(wǎng)絡(luò)(Generative adversarial networks,GAN)[18]的出現(xiàn)為計(jì)算機(jī)視覺(jué)應(yīng)用提供了新的技術(shù)和手段,GAN采用零和博弈與對(duì)抗訓(xùn)練的思想生成高質(zhì)量的樣本,具有比傳統(tǒng)機(jī)器學(xué)習(xí)算法更強(qiáng)大的特征學(xué)習(xí)和特征表達(dá)能力[19],是一種基于深度學(xué)習(xí)的學(xué)習(xí)模型,可以用于海量數(shù)據(jù)的智能生成,已經(jīng)廣泛用于圖像、文本、語(yǔ)音、語(yǔ)言等領(lǐng)域[20-21]。
有學(xué)者提出將GAN網(wǎng)絡(luò)技術(shù)用于生物學(xué)等領(lǐng)域的數(shù)據(jù)生成問(wèn)題[9,22-25],結(jié)果顯示生成數(shù)據(jù)的質(zhì)量有顯著提高。目前,記錄約8萬(wàn)種真菌、近1 500種野生蘑菇種類的圖像數(shù)據(jù)集,這對(duì)種類繁多和分布非均衡的菌類識(shí)別和分類具有重要的生態(tài)意義[26-28]。
本文提出基于生成對(duì)抗網(wǎng)絡(luò)的菌菇表型數(shù)據(jù)生成方法(Mushroom phenotypic based on generative adversarial network, MPGAN)。以菌菇表型為研究對(duì)象,在特定目標(biāo)域上訓(xùn)練GAN網(wǎng)絡(luò),作為GAN發(fā)生器網(wǎng)絡(luò)的輸入給出潛在模型,以期生成可控制和高質(zhì)量的蘑菇圖像。
GAN[18]的核心思想來(lái)源于博弈論的納什均衡,它設(shè)定雙方分別為生成器和判別器,生成器的目的是盡量學(xué)習(xí)真實(shí)的數(shù)據(jù)分布,而判別器的目的是盡量正確判別輸入數(shù)據(jù)是來(lái)自真實(shí)數(shù)據(jù)還是來(lái)自生成器。GAN中的生成器和判別器需要不斷優(yōu)化,各自提高生成能力和判別能力,其學(xué)習(xí)優(yōu)化過(guò)程就是尋找二者之間的一個(gè)納什均衡[29]。
GAN系統(tǒng)一般框架如圖1所示,系統(tǒng)結(jié)構(gòu)主要包括:生成器(用于生成虛擬圖像),它通過(guò)接收隨機(jī)噪聲z,通過(guò)這個(gè)噪聲生成網(wǎng)絡(luò)G(z)。判別器是負(fù)責(zé)判斷圖像真假,輸入圖像x,輸出對(duì)該圖像的判別結(jié)果D(x)。
圖1 一般的GAN框架Fig.1 Framework of GAN
首先,在給定生成器G的情況下,最優(yōu)化判別器D。采用基于Sigmoid的二分類模型的訓(xùn)練方式,判別器D的訓(xùn)練是最小化交叉熵的過(guò)程,其損失函數(shù)表示為
(1)
式中x——采樣于真實(shí)數(shù)據(jù)分布Pdata(x)
z——采樣于先驗(yàn)分布Pz(z),例如高斯噪聲分布
E(·)——計(jì)算期望值
式(1)中判別器的訓(xùn)練數(shù)據(jù)集來(lái)源于真實(shí)數(shù)據(jù)集分布Pdata(x)(標(biāo)注為1) 和生成器數(shù)據(jù)分布Pg(x)(標(biāo)注為0)。
給定生成器G,最小化式(1)得到最優(yōu)解。對(duì)于任意的非零實(shí)數(shù)m和n,且實(shí)數(shù)值y∈[0,1],表達(dá)式為
Φ=-mlgy-nlg(1-y)
(2)
(3)
D(x)代表x來(lái)源于真實(shí)數(shù)據(jù)而非生成數(shù)據(jù)的概率。當(dāng)輸入數(shù)據(jù)采樣自真實(shí)數(shù)據(jù)x時(shí),D的目標(biāo)是使得輸出概率D(x)趨近于1,而當(dāng)輸入來(lái)自生成數(shù)據(jù)G(z)時(shí),D的目標(biāo)是正確判斷數(shù)據(jù)來(lái)源,使得D(G(z))趨近于0,同時(shí)G的目標(biāo)是使得其趨近于1。生成器G損失函數(shù)可表示為
OG(θG)=-OD(θD,θG)
(4)
其優(yōu)化問(wèn)題是一個(gè)極值問(wèn)題,GAN的目標(biāo)函數(shù)可以描述為
min(G)max(D){f(D,G)=Ex~Pdata(x)lgD(x)+Ez~Pz(z)lg(1-D(G(z)))}
(5)
GAN模型需要訓(xùn)練模型D最大化判別數(shù)據(jù)來(lái)源于真實(shí)數(shù)據(jù)或者偽數(shù)據(jù)分布G(z)的準(zhǔn)確率,同時(shí),需要訓(xùn)練模型G最小化lg(1-D(G(z)))。
GAN學(xué)習(xí)優(yōu)化的方法為:先固定生成器G,優(yōu)化判別器D,使得D的判別準(zhǔn)確率最大化;然后固定判別器D,優(yōu)化生成器G,使得D的判別準(zhǔn)確率最小化。當(dāng)且僅當(dāng)Pdata=Pg時(shí)達(dá)到全局最優(yōu)解。
MPGAN系統(tǒng)的框架如圖2所示,蘑菇圖像的生成過(guò)程為:生成器G(z)使用截?cái)嗟揭欢ǚ秶鷥?nèi)的隨機(jī)正態(tài)分布數(shù)據(jù)作為輸入,輸入到卷積網(wǎng)絡(luò)(Convolutional neural network, CNN),最后輸出生成圖像數(shù)據(jù)。判別器D(x)根據(jù)真實(shí)圖像數(shù)據(jù)和生成圖像數(shù)據(jù)輸出判別結(jié)果,并對(duì)神經(jīng)網(wǎng)絡(luò)的所有參數(shù)進(jìn)行反向更新操作。
圖2 蘑菇表型數(shù)據(jù)生成的MPGAN框架Fig.2 MPGAN framework for mushroom phenotypic data generation
圖3 生成器神經(jīng)網(wǎng)絡(luò)框架Fig.3 Neural network framework of generator
2.1.1生成器
生成器卷積神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)的作用是通過(guò)輸入隨機(jī)數(shù)據(jù)生成128×128×3的圖像,128表示像素?cái)?shù),3表示RGB的通道數(shù)。圖3是生成器的框架。
生成器采用8層的卷積神經(jīng)網(wǎng)絡(luò),首先是Input數(shù)據(jù)輸入層,第2層是全連接層(Fully connected, FC),然后是連續(xù)5個(gè)反卷積層(Deconvolution, DeConv),其中分為DC反卷積層、BN批歸一化層(Batch normalization,BN)和激活函數(shù),批歸一化層是對(duì)于同一批次數(shù)據(jù)按照給定的系數(shù)進(jìn)行規(guī)范化處理,以防止梯度彌散,最后是Output數(shù)據(jù)輸出層。生成器的反卷積層如圖4所示,各層具體描述如下:
(1)FC全連接層設(shè)計(jì)輸入為生成100個(gè)圖像的隨機(jī)數(shù)據(jù),經(jīng)過(guò)全連接層的8 192個(gè)神經(jīng)元處理以及形狀重塑后變?yōu)?×4×512大小的數(shù)據(jù),再經(jīng)過(guò)批歸一化層及ReLU激活函數(shù)后將結(jié)果輸出到下一層。
(2)生成器中包括5個(gè)反卷積層,卷積核的移動(dòng)步長(zhǎng)為2,卷積核尺寸為5×5,1~4層的每一層經(jīng)過(guò)批歸一化層及ReLU激活函數(shù)后將結(jié)果輸出到下一層,其中:
第1層輸入數(shù)據(jù)為4×4×512。反卷積層的卷積核數(shù)為256個(gè),經(jīng)過(guò)反卷積后得到的數(shù)據(jù)為8×8×256。
第2層輸入數(shù)據(jù)為8×8×256。反卷積層的卷積核數(shù)為128個(gè),經(jīng)過(guò)反卷積后得到的數(shù)據(jù)為16×16×128。
第3層輸入數(shù)據(jù)為16×16×128。反卷積層的卷積核數(shù)為64個(gè),經(jīng)過(guò)反卷積后得到的數(shù)據(jù)為32×32×64。
第4層輸入數(shù)據(jù)為32×32×64。反卷積層的卷積核數(shù)為32個(gè),經(jīng)過(guò)反卷積后得到的數(shù)據(jù)為64×64×32。
圖4 生成器的反卷積層Fig.4 Deconvolution layer of generator
第5層輸入數(shù)據(jù)為64×64×32。反卷積層的卷積核數(shù)為3個(gè)。輸入數(shù)據(jù)經(jīng)過(guò)反卷積后得到的數(shù)據(jù)為128×128×3,再經(jīng)過(guò)批歸一化層及tanh激活函數(shù)后將結(jié)果輸出到下一層。tanh函數(shù)表達(dá)式為
(6)
式中a——參數(shù)
不使用傳統(tǒng)的Sigmod函數(shù)進(jìn)行Output輸出層,而是直接將上一層輸入結(jié)果輸出。生成器網(wǎng)絡(luò)參數(shù)如表1所示。
表1 生成器網(wǎng)絡(luò)參數(shù)Tab.1 Summary of generator network parameters
圖5 判別器神經(jīng)網(wǎng)絡(luò)框架Fig.5 Neural network framework of discriminator
2.1.2判別器
判別器的作用是盡量擬合樣本之間的Wasserstein距離,從而將分類任務(wù)轉(zhuǎn)換成回歸任務(wù)。判別器采用7層的卷積神經(jīng)網(wǎng)絡(luò),首先是Input數(shù)據(jù)入層,接著是連續(xù)4個(gè)卷積層(Convolution,Conv),其中分為卷積層、歸一化層和激活函數(shù),然后是全連接層FC,最后是數(shù)據(jù)輸出層Output。判別器的架構(gòu)如圖5所示。
判別器的Conv卷積層設(shè)計(jì)如圖6所示。判別器共有4個(gè)卷積層,卷積核的移動(dòng)步長(zhǎng)為2,卷積核尺寸為5×5,經(jīng)過(guò)歸一化層及Leaky ReLU激活函數(shù)后將結(jié)果輸出到下一層。
第1層輸入數(shù)據(jù)為128×128×3。卷積層的卷積核數(shù)為64個(gè),經(jīng)過(guò)卷積后得到的數(shù)據(jù)為64×64×64。
第2層輸入數(shù)據(jù)為64×64×64。卷積層的卷積核數(shù)為128個(gè),經(jīng)過(guò)卷積后得到的數(shù)據(jù)為32×32×128。
圖6 判別器的卷積層操作Fig.6 Convolution layer of discriminator
第3層輸入數(shù)據(jù)為32×32×128。卷積層的卷積核數(shù)為256個(gè),經(jīng)過(guò)卷積后得到的數(shù)據(jù)為16×16×256。
第4層輸入數(shù)據(jù)為16×16×256。卷積層的卷積核數(shù)為512個(gè),經(jīng)過(guò)卷積后得到的數(shù)據(jù)為8×8×512。
FC全連接層設(shè)計(jì)的輸入數(shù)據(jù)為8×8×512,經(jīng)過(guò)全連接層處理以及形狀重塑后變?yōu)榇笮?的蘑菇圖像,并將結(jié)果輸出。判別器的網(wǎng)絡(luò)參數(shù)如表2所示。
表2 判別器網(wǎng)絡(luò)參數(shù)Tab.2 Summary of discriminator network parameters
2.2.1Wasserstein距離
MPGAN系統(tǒng)采用帶有梯度懲罰的Wasserstein距離[30],Wasserstein距離[9,31-32]又叫推土機(jī)(Earth-mover,EM)距離,定義為
(7)
式中Pr——真實(shí)數(shù)據(jù)分布
Pg——生成數(shù)據(jù)分布
r——真實(shí)樣本
y——生成樣本
γ——聯(lián)合分布
∏(Pr,Pg)——Pr和Pg組合起來(lái)的所有可能的聯(lián)合分布的集合
對(duì)于每個(gè)可能的聯(lián)合分布γ而言,采樣(x,y)~γ得到一個(gè)真實(shí)樣本x和一個(gè)生成樣本y,并計(jì)算這對(duì)樣本之間的距離‖x-y‖,計(jì)算該聯(lián)合分布γ下樣本對(duì)距離的期望值E(x,y)~γ(‖x-y‖)。Wasserstein距離定義為在所有可能的聯(lián)合分布中能夠?qū)@個(gè)期望值的下界[31]。
2.2.2系統(tǒng)損失函數(shù)
設(shè)定fw代表判別器網(wǎng)絡(luò),根據(jù)Lipschitz連續(xù)性條件的要求,該判別器網(wǎng)絡(luò)含參數(shù)w,并且參數(shù)w不超過(guò)某個(gè)范圍,根據(jù)式(7)定義的Wasserstein距離,MPGAN系統(tǒng)判別器的目的是近似擬合Wasserstein距離,因此判別器的損失函數(shù)可以表示為
LD=Ex~Pg(fw(x))-Ex~Pr(fw(x))
(8)
MPGAN系統(tǒng)生成器的目的是近似地最小化Wasserstein距離,即最小化式(8),因此生成器的損失函數(shù)可以表示為
LG=Ex~Pr(fw(x))-Ex~Pg(fw(x))
(9)
GULRAJANI等[30]提出的帶有梯度懲罰的Wasserstein距離來(lái)滿足Lipschitz連續(xù)性。當(dāng)生成數(shù)據(jù)分布Pg接近真實(shí)數(shù)據(jù)分布Pr時(shí),Lipschitz連續(xù)性可表示為
‖D(Pg)-D(Pr)‖≤K‖Pg-Pr‖
(10)
式(10)可轉(zhuǎn)換為
(11)
式中Pc——生成數(shù)據(jù)分布與真實(shí)數(shù)據(jù)分布的差值
K——整數(shù)常量
先對(duì)真假樣本的數(shù)據(jù)分布進(jìn)行隨機(jī)差值采樣,即產(chǎn)生一對(duì)真假樣本Xr和Xg,采樣公式為
X=ξXr+(1-ξ)Xg
(12)
式中ξ——[0,1]區(qū)間的隨機(jī)數(shù)
(13)
式中λ——調(diào)節(jié)梯度懲罰項(xiàng)大小的參數(shù)
K為使得Lipschitz連續(xù)性條件成立的常量,設(shè)定K為1,MPGAN系統(tǒng)的判別器損失函數(shù)式(9)和梯度懲罰項(xiàng)式(13),損失函數(shù)可表示為
(14)
根據(jù)GAN網(wǎng)絡(luò)的框架和優(yōu)化過(guò)程,MPGAN系統(tǒng)的訓(xùn)練過(guò)程如圖7所示。
圖7 MPGAN系統(tǒng)的訓(xùn)練過(guò)程Fig.7 Training procedure of MPGAN system
圖7中的訓(xùn)練過(guò)程描述如下:
(1)采用方差為0.02的截?cái)嗾龖B(tài)分布初始化網(wǎng)絡(luò)中的權(quán)值參數(shù)W和卷積核初始化網(wǎng)絡(luò)的偏置值b,初始化學(xué)習(xí)率η,即每次參數(shù)更新幅度。在訓(xùn)練過(guò)程中,參數(shù)更新向著損失函數(shù)梯度下降的方向,表示為
Wn+1=Wn-ηΔ
(15)
式中Δ——梯度,即損失函數(shù)的導(dǎo)數(shù)
(2)采用區(qū)間為[-1,1]的均勻分布初始化隨機(jī)噪聲。
(3)采用數(shù)據(jù)集中隨機(jī)獲取批次大小的訓(xùn)練樣本,并在輸入隊(duì)列中進(jìn)行數(shù)據(jù)預(yù)處理。
(4)將步驟(2)中生成的隨機(jī)噪聲輸入到生成器網(wǎng)絡(luò),生成虛擬圖像數(shù)據(jù),將生成的虛擬圖像數(shù)據(jù)輸入判別器,得到生成圖像判別結(jié)果;將步驟(3)中獲取的訓(xùn)練樣本使用批歸一化操作輸入判別器,得到真實(shí)圖像判別結(jié)果;計(jì)算判別器損失并反向更新判別器參數(shù)。
(5)計(jì)算梯度懲罰項(xiàng),為判別器損失施加懲罰,然后使用優(yōu)化器反向更新判別器參數(shù),使用梯度懲罰項(xiàng),替換原來(lái)的權(quán)重截?cái)嗖呗浴?/p>
(6)判斷是否達(dá)到指定判別器優(yōu)化次數(shù),即每?jī)?yōu)化一次生成器時(shí)優(yōu)化N次判別器,若是則進(jìn)入步驟(7),若否則重新進(jìn)入步驟(3)。其中N由用戶設(shè)定。
(7)將步驟(2)中生成的隨機(jī)噪聲輸入到生成器網(wǎng)絡(luò),計(jì)算生成器損失并使用優(yōu)化器反向更新判別器參數(shù)。
(8)判斷是否達(dá)到指定迭代次數(shù),即是否遍歷完全部樣本,若是則進(jìn)入步驟(9),否則重新進(jìn)入步驟(2)。
(9)判斷是否達(dá)到EPOCH次數(shù),EPOCH為總共訓(xùn)練的輪次,若是則結(jié)束,否則重新進(jìn)入步驟(2)。
實(shí)驗(yàn)平臺(tái)為Windows 10系統(tǒng),16 GB內(nèi)存,256 GB SSD,1 TB HD,Intel QuadCore i7-8700, 4.2 GHz, Nvidia GTX 1070,8 GB。算法采用Tensorflow V1.1 GPU框架[33]和Python 3.6實(shí)現(xiàn)。
采用兩類數(shù)據(jù)集:開源蘑菇數(shù)據(jù)集Fungi[28],選擇了其中375幅圖像;私有數(shù)據(jù)集,共138幅圖像。圖像預(yù)處理方法包括隨機(jī)翻轉(zhuǎn)、隨機(jī)亮度變換、隨機(jī)對(duì)比度變換和圖像歸一化,前面幾種預(yù)處理方法主要是為了增加樣本數(shù)量,而圖像歸一化是為了降低幾何變換帶來(lái)的影響。
圖8為開源數(shù)據(jù)集Fungi蘑菇示例圖像,該數(shù)據(jù)集環(huán)境噪聲大且背景復(fù)雜,背景中有草地、林地、樹葉、木塊等多種干擾物。
圖8 開源數(shù)據(jù)集示例Fig.8 Examples of public dataset
私有蘑菇數(shù)據(jù)集采用鳳尾菇作為對(duì)象,該數(shù)據(jù)集采用黑色作為背景,背景噪聲小,且蘑菇形狀不同,適合菌菇表型圖像生成。圖9為私有蘑菇數(shù)據(jù)集的示例圖像。
圖9 私有蘑菇數(shù)據(jù)集示例Fig.9 Examples of private dataset
MPGAN系統(tǒng)默認(rèn)使用Adam優(yōu)化器[34],優(yōu)化器超參數(shù)β1=0.5、β2=0.9、ε=1×10-8,學(xué)習(xí)率η默認(rèn)為0.000 3,判別器優(yōu)化次數(shù)N=5。
3.2.1生成器參數(shù)設(shè)置
由于生成器的輸出層直接將前一層的值作為輸入,最后激活函數(shù)選擇tanh激活函數(shù),該激活函數(shù)可以將輸出層的輸出約束到區(qū)間[-1,1]。
為了保證數(shù)據(jù)分布的一致性,并防止反向傳播權(quán)值更新時(shí)發(fā)生梯度彌散并加速收斂,采用批歸一化(Local response normalization),對(duì)同一批次數(shù)據(jù)按照給定的系數(shù)進(jìn)行規(guī)范化處理。其處理步驟如下:
(1)沿通道計(jì)算同一批次內(nèi)所有圖像的均值μB,計(jì)算式為
(16)
(17)
(3)對(duì)圖像做歸一化處理,計(jì)算式為
(18)
ω——防止方差為0的參數(shù)
(4)加入縮放變量γ和平移變量φ,得出結(jié)果
yi=γi+φ≡BNγ,φ(xi)
(19)
式中yi——加入縮放變量γ和平移變量φ處理結(jié)果
3.2.2判別器參數(shù)設(shè)置
選擇Leaky ReLU激活函數(shù)作為判別器激活函數(shù),確保梯度更新整個(gè)圖像。Leaky ReLU激活函數(shù)表達(dá)式為
(20)
式中α——(1,+∞)區(qū)間內(nèi)的參數(shù)
MPGAN系統(tǒng)生成式對(duì)抗網(wǎng)絡(luò)模型的梯度懲罰策略采用層歸一化函數(shù)(Layer normalization,LN)。
在學(xué)習(xí)率η為0.000 3時(shí),使用開源數(shù)據(jù)集和私有數(shù)據(jù)集作為訓(xùn)練數(shù)據(jù)集,MPGAN系統(tǒng)的Wasserstein距離與EPOCH的關(guān)系如圖10所示。
圖10 Wasserstein距離收斂曲線Fig.10 Wasserstein distance convergence curves
由圖10a可知,在開源數(shù)據(jù)集,EPOCH大于2 000后逐漸開始學(xué)習(xí)到真實(shí)圖像的數(shù)據(jù)分布,在EPOCH達(dá)到10 000后逐漸趨于穩(wěn)定,在這個(gè)階段數(shù)據(jù)集本身噪聲較大導(dǎo)致模型的學(xué)習(xí)能力有所下降,所以模型學(xué)習(xí)的特征被背景所干擾,并且在曲線尾部的振蕩程度明顯增大,此時(shí)減小學(xué)習(xí)率η可以使模型訓(xùn)練更加穩(wěn)定。
由圖10b可知,Wasserstein距離在EPOCH達(dá)到2 000后不斷收斂,在10 000左右有小幅振蕩,EPOCH在超過(guò)35 000之后,振蕩幅度減小,模型比較穩(wěn)定。
由圖10可知,不同數(shù)據(jù)集訓(xùn)練的EPOCH次數(shù)不同,開源數(shù)據(jù)集的噪聲較大,模型不容易收斂,并且相似度衡量指標(biāo)Wasserstein距離在EPOCH為12 000時(shí)開始穩(wěn)定在一個(gè)較高的程度;私有數(shù)據(jù)集上的噪聲較小,當(dāng)在該數(shù)據(jù)集,模型收斂更加快速,Wasserstein距離在EPOCH大于35 000時(shí)開始逐漸收斂穩(wěn)定。
基于開源數(shù)據(jù)集的學(xué)習(xí)率與EPOCH關(guān)系如圖11所示。從圖11可看出,提高學(xué)習(xí)率η時(shí),模型的收斂速度有明顯的提升并在EPOCH為1 000后逐漸穩(wěn)定,但是隨著學(xué)習(xí)率的提高,收斂的振蕩程度也在加大,因此可以在訓(xùn)練初期使用較大的學(xué)習(xí)率提高初始收斂速度,然后逐漸減小學(xué)習(xí)率保證訓(xùn)練過(guò)程穩(wěn)定。由于在私有數(shù)據(jù)集上的結(jié)果類似,因此僅報(bào)告了開源數(shù)據(jù)集上的測(cè)試結(jié)果。
圖11 基于開源數(shù)據(jù)集的學(xué)習(xí)率與EPOCH關(guān)系Fig.11 Learning rate and EPOCH relationship based on open source dataset
首先,系統(tǒng)測(cè)試了數(shù)據(jù)中的scalpturatum口蘑,EPOCH為1 000時(shí),學(xué)習(xí)率η為0.000 1~0.000 5生成圖像如圖12所示。圖12a為原始圖像,從圖12b可看出,學(xué)習(xí)率η為0.000 3時(shí),生成的菌菇圖像相對(duì)較好。
圖12 不同學(xué)習(xí)率的菌菇圖像生成結(jié)果對(duì)比Fig.12 Mushroom image generation results comparison at different learning rates
當(dāng)學(xué)習(xí)率η為0.000 3時(shí),在開源數(shù)據(jù)集和私有數(shù)據(jù)集上,測(cè)試了系統(tǒng)菌菇圖像生成結(jié)果,生成圖像尺寸設(shè)置為64像素×64像素,結(jié)果分別如圖13和圖14所示。圖13為EPOCH為15 000時(shí),開源數(shù)據(jù)集上的生成結(jié)果。圖13b的生成圖像能夠清晰地顯示出原始菌菇的表型特征。
圖14為EPOCH為50 000時(shí),私有數(shù)據(jù)集上的生成結(jié)果。圖14b的生成圖像能夠清晰地顯示出原始菌菇的表型特征。
圖13 基于開源數(shù)據(jù)集上的蘑菇生成圖像Fig.13 Illustration of generating Fungi images based on public dataset
圖14 基于私有數(shù)據(jù)集上的蘑菇生成圖像Fig.14 Illustration of generating Fungi images based on private dataset
對(duì)比圖13b和圖14b可以看出,圖14b質(zhì)量?jī)?yōu)于圖13b,表明高質(zhì)量的菌菇訓(xùn)練數(shù)據(jù)對(duì)圖菌菇表型圖像的生成有重要影響。
(1)研究了菌菇表型數(shù)據(jù)生成技術(shù),設(shè)計(jì)了用于菌菇表型數(shù)據(jù)生成的生成式對(duì)抗網(wǎng)絡(luò)結(jié)構(gòu)。使用Wasserstein距離和帶有梯度懲罰的損失函數(shù)。
(2)利用開源數(shù)據(jù)和私有數(shù)據(jù)集進(jìn)行了測(cè)試,結(jié)果表明,數(shù)據(jù)集噪聲越小越好,噪聲越小則損失越容易收斂,否則背景和主體目標(biāo)發(fā)生混淆時(shí),損失會(huì)在一個(gè)較大程度上振蕩。
(3)測(cè)試了學(xué)習(xí)率η、EPOCH與Wasserstein距離關(guān)系,系統(tǒng)生成的菌菇表型數(shù)據(jù)可為后期菌菇數(shù)據(jù)分類與識(shí)別提供大數(shù)據(jù)基礎(chǔ),為解決菌菇分類的數(shù)據(jù)非均衡、長(zhǎng)尾分布等問(wèn)題提供研究基礎(chǔ)。