ギブスサンプリングでベイズ推定
ギブスサンプリングは確率分布からのサンプリング手法の一つ、各変数について条件付き確率さえわかればサンプリングできる。ベイズ推定での事後分布サンプリングに使われることが有る(Just another Gibbs samplerというそのものズバリなベイズ推定のオープンソースプロダクトもある)。
以下はギブスサンプリングに関する参考文献。2は可視化もあって大変分かりやすい。直感的に理解できるのがギブスサンプリングのいいところだと思う。
ここではギブスサンプリングを実装してベイズ推定を実施する。ギブスサンプリングは簡単なアルゴリズムなので、多分、実装も簡単であろうと思う。
実装は以下(Campbell)を参考にした。これを読めば誰でもギブスサンプリングでベイズ推定ができる、とても分かりやすい記事。
Gibbs sampling for Bayesian linear regression in Python, Kieran Campbell
テストデータ、モデルは以下の記事(Memorandum)のものをそのまま利用した。記事中のRstanのコードは推定結果の答え合わせに利用した。このブログは好奇心をそそられる記事に溢れており、眺めるだけで賢くなった気分になれる。
WAICとWBICを事後分布から計算する, StatModeling Memorandum
データと統計モデル
2成分混合正規分布から100点を生成しテストデータとする。MemorandumがRを使っているので、それにならう。
from subprocess import Popen, PIPE R = Popen(["R", "--vanilla", "--silent"], stdin=PIPE, stdout=PIPE) R.stdin.write(b''' set.seed(1) N <- 100 a_true <- 0.4 mean0 <- 0 mean1 <- 3 sd0 <- 1 sd1 <- 1 Y <- c(rnorm((1-a_true)*N, mean0, sd0), rnorm(a_true*N, mean1, sd1)) write.table(Y, file="points.csv", sep=",", row.names=F, col.names=F) q() ''') R.communicate() R.kill()
統計モデルは2つ。それぞれについてベイズ推定を行う。
Pythonセットアップ
CampbellにならってPython3を用いる。pandasというのはRのdata.frame的なものらしい。多彩なデータ加工ができるようだが趣味人には使い方が分からない。Anacondaは怖いので使わない。
$ python -m venv .venv $ source .venv/bin/activate $ pip install numpy scipy pandas seaborn
以下はpythonの共通設定的なもの。
import numpy as np np.random.seed(0) import scipy as sp import pandas as pd pd.set_option('display.width', 200) import seaborn as sns from seaborn import plt plt.rcParams['figure.figsize'] = (12, 6)
可視化のテスト、R経由で作成したテストデータを可視化してみる。
Y = pd.read_csv("points.csv", header=None) p = sns.distplot(Y, bins=20) plt.savefig("points.png")
混合分布ってのはこういう形をしているようだ。
ギブスサンプリングでベイズ推定のおさらい
Campbellの記事でギブスサンプリングを復習。
今手元に 個のモデルパラメタ とデータ があり、ここから事後分布 を得たいとする。ギブスサンプリングで事後分布をサンプリングするには、条件付き確率 を用いて以下の手順をとる。
- 適当な初期値 を与える.
- すべてのについて をサンプリング.
- 2.を予め決めたイテレーション数だけ繰り返す.
イテレーション数が十分であれば、各パラメタのサンプリング結果の密度分布を、各パラメタの周辺化した事後分布として扱うことができる、らしい。
確率モデルの実装(モデル1: 単一の正規分布モデル)
モデルの概要
ここではデータ が従う正規分布の平均 と標準偏差 を推定する。
データ の条件付き確率。
パラメタには事前分布は、簡単のため無情報事前分布とする。
このときパラメタの確率は一定になる。
のサンプリング
何はなくとも条件付き確率がわからなければサンプリングできない。 定数項と に独立な項を消去する。
データ の条件付き確率が出てきた。
最後の二行で、規格化されていない尤度関数からうまいこと正規分布が得られた。
これはつまり、ある確率変数 の尤度が exp(x の二次式) であるとき、確率変数 は正規分布をとる、ということである。対数尤度が の二次式なら正規分布!、と憶えるとよさそう。
このようにして の条件付き確率がわかった。
のサンプリングコード、正規分布からのサンプリングなのでとても簡単。
def sample_mu(y, N, s): mean = np.sum(y) / N variance = s * s / N return np.random.normal(mean, np.sqrt(variance))
のサンプリング
つづいて 。 と同様に条件付き確率の比例成分として の事後分布が得られる。このままでは計算しづらいので、対数尤度 を解き、 がどのような分布に従うか調べる。
対数尤度が の二次式ではないので、とりあえず正規分布ではない。ここでガンマ分布(Gamma distribution:wikipedia)を導入する。
ガンマ分布は と の2パラメタで定義される分布であり、確率変数 がガンマ分布に従うとき確率密度は となる。 の依存項のみを抜き出した対数尤度は以下になる。
と似たような形をしている。ここで、 から ( は精度と呼ばれるもの)への変数変換を行い、
より、 の条件付き確率は以下のガンマ分布に従うことが示せた。
サンプリングコード、 をサンプリングし、後に に変換する、という実装。
def sample_s(y, N, mu): alpha = N / 2 + 1 residuals = y - mu beta = np.sum(residuals * residuals) / 2 tau = np.random.gamma(alpha, 1 / beta) return 1 / np.sqrt(tau)
サンプリングの実施
条件付き確率が得られたので、Campbellを参考にサンプラを実装する。
イテレーションごとに 各 の対数尤度(log_likelihoods)を計算している。これはMemorandumを参考にしたもので、WAICというモデルの汎化性能を評価する指標があり、その算出に必要。ここではMemorandumの結果と一致することを確認するために計算している。
def model1(y, iters, init): mu = init["mu"] s = init["s"] N = len(y) trace = np.zeros((iters, 2)) log_likelihoods = np.zeros((iters, N)) for i in range(iters): mu = sample_mu(y, N, s) s = sample_s(y, N, mu) trace[i, :] = np.array((mu, s)) norm = sp.stats.norm(mu, s) log_likelihoods[i, :] = np.array([np.log(norm.pdf(x)) for x in y]) trace = pd.DataFrame(trace) trace.columns = ['mu', 's'] log_likelihoods = pd.DataFrame(log_likelihoods) return trace, log_likelihoods
初期値。
init = {"mu": np.random.uniform(-100, 100), "s": np.random.uniform(0, 100)} print(init)
{'mu': 9.762700785464943, 's': 71.51893663724195}
サンプリング実施。
iters = 5000 y = np.array([float(x) for x in open("points.csv", "r").read().strip().split("\n")]) trace, log_likelihoods = model1(y, iters, init)
結果
サンプリング結果をイテレーションを横軸としてプロットしたものをトレースプロットと呼ぶ。これが一定の範囲内で上下ふらふらしていればひとまずサンプリングは成功と言える。見た感じは良さそう。1
traceplot = trace.plot() traceplot.set_xlabel("Iteration") traceplot.set_ylabel("Parameter value") traceplot.get_figure().savefig("model1_trace.png")
パラメタごとのサンプリング結果のヒストグラム。サンプリング結果の後半だけをtrace_burntに切り出している。ギブスサンプリングを含めたマルコフ連鎖モンテカルロサンプリングでは、サンプリングの序盤は暖機運転扱いで、結果から除外するのが通例であるらしい2。
trace_burnt = trace[int(len(trace)/2):] hist_plot = trace_burnt.hist(bins = 30, layout = (1,2)) traceplot.get_figure().savefig("model1_hist.png")
MemorandumのRコードからWAIC算出関数を移植。
def waic(log_likelihoods): training_error = -np.log(np.exp(log_likelihoods).mean(axis=0)).mean() functional_variance_div_N = (np.power(log_likelihoods, 2).mean(axis=0) - np.power(log_likelihoods.mean(axis=0),2)).mean() return training_error + functional_variance_div_N
統計量等と合わせて表示する。WAICはMemorandumでは1.980なので、よく一致している。成功したらしい。
print(trace_burnt.describe().T) print("waic:", waic(log_likelihoods)) print("np.mean(y):", np.mean(y)) print("np.std(y):", np.std(y))
count mean std min 25% 50% 75% max mu 2500.0 1.307089 0.173329 0.596588 1.193417 1.307739 1.419458 1.945270 s 2500.0 1.728932 0.119035 1.411500 1.644560 1.725093 1.802547 2.212252 waic: 1.98051651897 np.mean(y): 1.30888736691 np.std(y): 1.7214151253
確率モデルの実装(モデル2: 2成分混合モデル)
次のモデルに進む。2つの正規分布の混合モデル。
モデルの概要
何も考えずにモデルを書くと下になる、と思う。
ところがこのモデルはどういじっても条件付き確率が導出できない。どうも正規分布の足し算が難易度高いらしい。無理。
困ったのでmixture、gaussian、gibbs等で適当にググると、混合分布モデルの場合はデータ点が属するカテゴリを推定するのがよい、と書いてあるらしい難しい内容の記事・文献がいくつか出てた3。要するにデータ の属する正規分布カテゴリを表す確率変数 を導入し、
とするといいらしい。Bernoulliは確率aで1、(1-a)で0となる離散確率分布、曲がったコインのトスみたいなもの。これはつまり、今回のテストデータ作成をそのままモデル化したのと同じことである。 の条件付き確率は、
となり、掛け算なので取り扱いが楽そう。確かに先に進めそうである。
は既知とする。 と の事前分布は以下とする。
のサンプリング
の条件付き確率。
考え方はモデル1の と同じで、正規分布になる。ただしデータは 全体ではなく、 の正規分布に属するデータ()のみを用いる。
サンプリングコード。コード中のzは の配列だが、要素が0か1なので、 の総和や総数は以下のように楽ちんに計算できる。
def sample_mu_1(y, z): N1 = np.sum(z) mean = np.sum(y * z) / N1 variance = 1 / N1 return np.random.normal(mean, np.sqrt(variance))
のサンプリング
の条件付き確率。 は0か1の離散値なので、総和を取ると に属するデータ点の総数になる。
例によって上の二行目は規格化すると名前のある分布になるのだろうと期待される。ベータ分布 であるようだ。
def sample_a(z, a): alpha = np.sum(z) + 1 beta = len(z) - alpha + 2 return np.random.beta(alpha, beta)
のサンプリング
の条件付き確率。 は0 or 1 の離散値なので、尤度さえ分かればサンプリング可能。
コードでは一つの関数で全 をサンプリングしている。 は使いまわせるのでグローバルに定義している。 zのサンプリング結果を返すのではなく、zの配列を更新しているため、関数名がこれまでと異なる。
norm0 = sp.stats.norm(0, 1) def sample_and_update_z(y, mu_1, a, z): norm1 = sp.stats.norm(mu_1, 1) for i in range(len(y)): d0 = (1-a) * norm0.pdf(y[i]) d1 = a * norm1.pdf(y[i]) z[i] = 0 if np.random.uniform() * (d0+d1) < d0 else 1
サンプリングの実施
サンプリング実行コードと初期値。
def model2(y, iters, init): mu_1 = init["mu_1"] a = init["a"] z = init["z"] N = len(y) trace = np.zeros((iters, 2)) log_likelihoods = np.zeros((iters, N)) for i in range(iters): mu_1 = sample_mu_1(y, z) a = sample_a(z, a) sample_and_update_z(y, mu_1, a, z) trace[i, :] = np.array((mu_1, a)) norm1 = sp.stats.norm(mu_1, 1) log_likelihoods[i, :] = \ np.array([np.log((1-a)*norm0.pdf(x) + a*norm1.pdf(x)) for x in y]) trace = pd.DataFrame(trace) trace.columns = ['mu_1', 'a'] log_likelihoods = pd.DataFrame(log_likelihoods) return trace, log_likelihoods
init = {"mu_1": np.random.uniform(-100, 100), "a": np.random.rand(), "z": np.random.randint(0, 2, 100)} print(init)
{'mu_1': -99.69811311464245, 'a': 0.7177148886002629, 'z': array([1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1,
1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0,
1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0,
1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1,
0, 1, 0, 1, 0, 0, 1, 1])}
サンプリングを実施。
iters = 5000 y = np.array([float(x) for x in open("points.csv", "r").read().strip().split("\n")]) trace, log_likelihoods = model2(y, iters, init)
結果
traceplot = trace.plot() traceplot.set_xlabel("Iteration") traceplot.set_ylabel("Parameter value") traceplot.get_figure().savefig("model2_trace.png")
trace_burnt = trace[int(len(trace)/2):] hist_plot = trace_burnt.hist(bins = 30, layout = (1,2)) traceplot.get_figure().savefig("model2_hist.png")
統計量とWAIC。Memorandumでは1.913とのこと。ほぼ同じである。
print(trace_burnt.describe().T) print("waic:", waic(log_likelihoods))
count mean std min 25% 50% 75% max
mu_1 2500.0 3.083692 0.205765 2.199600 2.940599 3.085343 3.223398 3.768049
a 2500.0 0.397030 0.055852 0.214459 0.360779 0.395766 0.433929 0.601889
waic: 1.91467884523
Rstanでの結果
推定結果は良さそうだし、WAICもMemorandumとほぼ同じだ。ついでにMemorandumに書かれたRstanコードを実行し、推定値の統計量を比較する。
from subprocess import check_output print(check_output(["Rscript", "model2.r"]).decode("utf8"))
# results of the mixture normal distribution model with Rstan
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
mu 3.0914330 0.0012293010 0.20128997 2.6958750 2.955486 3.0909330 3.2286982 3.485634 26811.91 1.000001
a 0.3965491 0.0003816836 0.05627476 0.2907154 0.357534 0.3955335 0.4343038 0.509682 21738.04 1.000274
WAIC: 1.914037
これもよく一致した。やはりうまくいったようだ。
上で使用したRコードはこちら。
options(width=200) Y <- read.table("points.csv") data <- list(N=length(Y), Y=Y) model2 <- " data { int<lower=1> N; vector[N] Y; } parameters { real<lower=0, upper=1> a; real<lower=-50, upper=50> mu; } model { for(n in 1:N){ target += log_sum_exp( log(1-a) + normal_lpdf(Y[n] | 0, 1), log(a) + normal_lpdf(Y[n] | mu, 1) ); } } generated quantities { vector[N] log_likelihood; int index; real y_pred; for(n in 1:N) log_likelihood[n] = log_sum_exp( log(1-a) + normal_lpdf(Y[n] | 0, 1), log(a) + normal_lpdf(Y[n] | mu, 1) ); index = bernoulli_rng(a); y_pred = normal_rng(index == 1 ? mu: 0, 1); } " sink(file="/dev/null") suppressMessages({ library(rstan) fit <- stan(model_code=model2, data=data, iter=11000, warmup=1000, seed=123) }) sink() cat("# results of the mixture normal distribution model with Rstan\n") print(summary(fit)$summary[c("mu", "a"), ]) waic <- function(log_likelihood) { training_error <- - mean(log(colMeans(exp(log_likelihood)))) functional_variance_div_N <- mean(colMeans(log_likelihood^2) - colMeans(log_likelihood)^2) waic <- training_error + functional_variance_div_N return(waic) } cat(sprintf("WAIC: %f", waic(extract(fit)$log_likelihood)))
おわり。
-
見た感じでは許されない状況にある場合は、2つ以上のギブスサンプリングを回して、どれだけ結果が似ているかを評価する。Rhat(potential scale reduction factor)という指標がメジャー。簡単な統計量なのに実装はいくつかあってどれを使えばいいのかよく分からない。どれでもいいのかもしれない。↩
-
他に、サンプリングとサンプリングの間に何回かのイテレーションを挟むことがある。 目的は不明。たくさんイテレーションを回したいが、サンプリング数が多すぎると結果の処理が辛いため? ↩
-
日本語だとベイズ混合モデルにおける近似推論③ ~崩壊型ギブスサンプリング~, 作って遊ぶ機械学習。 があった。内容は難しくて3%も理解できない。↩