Quantcast
Channel: プログラミング
Viewing all articles
Browse latest Browse all 8031

ゼロから作るDeep Learning 5(生成モデル編)をRで再現してみた(第五章) - GRGと金融工学・統計解析

$
0
0

前回

grg.hatenablog.com

はじめに

前回は混合ガウスモデル(GMM)について紹介しました。今回はEMアルゴリズムについて説明します。EMアルゴリズム隠れマルコフモデル(HMM)など潜在変数を持つモデルのパラメータ推定にも使用されるなど多くの場面で活用されています。参考書(ゼロから作るDeep Learning 5)にはEMアルゴリズムの導出について丁寧に書かれていますが、このブログで同じように全て記載してしまうと参考書の良さが丸パクリになってしまうので、詳細な導出については参考本をご覧になっていただければと思います。本稿では、メインの目的がRコードで同じような結果を出すことなので、数式やその説明などは必要最低限の内容だけ記載します。

STEP2:EMアルゴリズム

EMアルゴリズムの導出のためにはカルバック・ライブラー情報量(KLダイバージェンス)を上手い具合に活用して尤度関数を作っていきます。 まずパラメータ \thetaを持つ分布を考えます。その対数尤度関数は以下のように書くことができます。

 \displaystyle
\log p_\theta (x) = \log p_\theta (x) \sum_{z} q(z)
 \displaystyle
= \sum_{z} q(z) \log p_\theta (x)
 \displaystyle
= \sum_{z} q(z) \left(\log \frac{p_\theta (x,z)}{q(z)}+\log \frac{q(z)}{p_\theta (z | x)} \right)
 \displaystyle
= \sum_{z} q(z) \log \frac{p_\theta (x,z)}{q(z)} + \sum_{z} \log \frac{q(z)}{p_\theta (z | x)}

ここで q(z)は任意の確率密度関数になります。この q(z)を上手く設定することで対数尤度関数を計算しやすくすることができます。 そしてこの第二項目がKLダイバージェンスと呼ばれる項目で必ず0以上の数値を取ります。また、第一項目がエビデンスの下界(ELBO)と呼ばれます。 このKLダイバージェンスとELBOを上手く交互に更新していくことで対数尤度関数が大きくなっていくことが証明されており、その交互に更新していくことをEMアルゴリズムと呼ばれています。 具体的にはKLダイバージェンスを0に近づけるように設定し、ELBOを最大になるようにパラメータを更新していきます。KLダイバージェンス q(z) = p_\theta (z | x)となるように q(z)を設定できれば0にすることができます。 また、ELBOは今まで通り関数をパラメータ変数で微分して0となる数値を求めることでパラメータを更新していきます。GMMの場合はELBOをパラメータ変数で微分して0となる数値は解析的に求めることができるため、実装も簡単にできます。 最後に、EMアルゴリズムでは、KLダイバージェンスを更新するステップをEステップ、ELBOを更新することをMステップと呼びます。

では、このEMアルゴリズムをGMMに適用してみましょう。詳細については参考書をご覧になっていただき、ここでは結論だけ記載します。 まずEステップですが、以下のように更新します。

 \displaystyle
q^{(n)}(z = k)  = \frac{\phi_k N(x^{(n)}; \mu_k, \Sigma_k)}{\sum_{j=1}^{K} \phi_j N(x^{(n)}; \mu_j, \Sigma_j)}

次にMステップは以下のように更新していきます。

 \displaystyle
\mu_k = \frac{\sum_{n=1}^{N} q^{(n)}(k)x^{(n)}}{\sum_{n=1}^{N} q^{(n)}(k)}
 \displaystyle
\Sigma_k = \frac{\sum_{n=1}^{N} q^{(n)}(k)(x^{(n)}-\mu_k)(x^{(n)}-\mu_k)^{T}}{\sum_{n=1}^{N} q^{(n)}(k)}
 \displaystyle
\phi_k = \frac{1}{N} \sum_{n=1}^{N} q^{(n)}(k)

このような形でGMMのパラメータを更新(EMアルゴリズム)すると、以下の対数尤度が高くなっていくはずです。

 \displaystyle
\frac{1}{N} \sum_{n=1}^{N} \log \sum_{j=1}^{K} \phi_j N(x^{(n)}; \mu_j, \Sigma_j)

対数尤度の上昇幅が無視できるレベルで小さくなれば結果が収束したと考えられるので、そこでパラメータの更新をストップします。 では、このEMアルゴリズムをRで実装してみます。推定に使用するデータですが、前回GMMを紹介した際に作成したGMMからサンプリングするコードを活用して、そこから25,000個データをサンプリングしました。 また、多変量正規分布確率密度関数はmnormtパッケージのdmnorm関数を使用しています(前回実装した自前の多変量正規分布確率密度関数でも可)。

#パラメータ初期値
phis =list()
mus =list()
sigmas =list()
phis[[1]]=0.5
phis[[2]]=0.5
mus[[1]]=c(0,50.0)
mus[[2]]=c(0,100.0)
sigmas[[1]]=diag(c(1,1),nrow =2, ncol =2)
sigmas[[2]]=diag(c(1,1),nrow =2, ncol =2)

K =2#潜在変数の次元
N =nrow(dat)#データ数 
MAX_ITER =100
THRESHOLD =10^-4#データ1個当たりのGMM尤度
GMM =function(x,mus,sigmas,phis){
  K =length(mus)
  y =0for(i in1:K){
    phi = phis[[i]]
    mu = mus[[i]]
    sigma = sigmas[[i]]
    tmp =  phi *dmnorm(x,mean = mu, varcov = sigma,log =FALSE)
    y = y + tmp
  }return(y)}#GMMの対数尤度
GMM_likelihood =function(dat,mus,sigmas,phis){
  eps =10^-8
  L =0
  N =nrow(dat)for(i in1:N){
    y =GMM(dat[i,],mus,sigmas,phis)
    L = L +log(y + eps)}return(L / N)}#EMアルゴリズム
current_likelihood =GMM_likelihood(dat,mus,sigmas,phis)for(iter in1:MAX_ITER){#E STEP
  qs =matrix(0, nrow = N, ncol = K)for(n in1:N){
    x = dat[n,]for(k in1:K){
      phi = phis[[k]]
      mu = mus[[k]]
      sigma = sigmas[[k]]
      qs[n,k]= phi *multivariate_normal(x,mu,sigma)}
      qs[n,]= qs[n,]/rep(GMM(x,mus,sigmas,phis),K)}#M STEP
  qs_sum =colSums(qs)for(k in1:K){#phi更新
    phis[[k]]= qs_sum[k]/ N
    
    #mu更新
    c =0for(n in1:N){
      c = c + qs[n,k]* dat[n,]}
    mus[[k]]= c / qs_sum[k]#sigma更新
    c =0for(n in1:N){
      z =matrix(dat[n,]- mus[[k]])
      c = c + qs[n,k]* z %*%t(z)}
    sigmas[[k]]= c / qs_sum[k]}#終了条件print(paste0("current_likelihood-->",current_likelihood))
  next_likelihood =GMM_likelihood(dat,mus,sigmas,phis)
  diff_lik =abs(next_likelihood - current_likelihood)if(diff_lik < THRESHOLD){print("尤度関数の変化幅が閾値以下になったため終了")break}
  current_likelihood = next_likelihood
}print(phis)print(mus)print(sigmas)

この実行結果が以下の通りです。

[1]"current_likelihood-->-15.2745878992021"[1]"current_likelihood-->-4.40040150152027"[1]"current_likelihood-->-4.18465653638498"[1]"current_likelihood-->-4.02958516934097"[1]"current_likelihood-->-4.02374867082427"[1]"尤度関数の変化幅が閾値以下になったため終了">print(phis)[[1]][1]0.6511936[[2]][1]0.3488064>print(mus)[[1]][1]2.00040754.468064[[2]][1]4.30143280.010580>print(sigmas)[[1]][,1][,2][1,]0.070316210.4381252[2,]0.4381252033.6305036[[2]][,1][,2][1,]0.17041470.9244299[2,]0.924429935.4103630

推定された結果を見る限り、GMMのデータサンプリングでインプットしていたパラメータの数値とかなり近しい数値となっていることから、うまく推定できているように見えます。 しかし、今回は25,000個のデータから推定した結果をお見せしましたが、500個くらいだとうまく推定できない場合が発生したため、やはりGMMのような潜在変数モデルはデータ数を多く確保することが重要だと思います。 または、パラメータの初期値を極端な値ではなく、より妥当な数値に設定することで収束を早めることができる可能性があります。ここら辺はトライ&エラーだという肌感覚です。

まとめ

今回はEMアルゴリズムについて簡単に紹介しました。次回は参考書のタイトルであるDeep Learningの世界に踏み込んでいき、ニューラルネットワークについてみていきます。 そして、今まで紹介してきたGMMのロジックとニューラルネットワークを徐々に組み合わせていき、より複雑なモデルを構築していきます。


Viewing all articles
Browse latest Browse all 8031