やってみた。ドラゴンボールの戦闘力推定
ドラゴンボールの勝敗結果から戦闘力を推定する, StatModeling Memorandumという実験記事があります。面白そうなのでやってみました。
準備
RとStanを使いました。tidyrとdplyrはRで使われる変な名前のデータ整形ライブラリです。
set.seed(1) library(rstan) rstan_options(auto_write = TRUE) options(mc.cores = parallel::detectCores()) library(tidyr) library(dplyr)
データ
games.txtは戦闘データ、senshis.txtはキャラクタ名と戦闘力の文献値が入っています。値はMemorandumの記事から拝借しました。
%>%
はtidyrやdplyrで使われる、メソッドチェイン的な構文が使える演算子です。joinを使って戦績にsenshi.txtを連結した結果を出しています。
games <- read.table("games.txt", header=T) senshis <- read.table("senshis.txt", header=T) games %>% left_join(senshis, c("winner"="ID")) %>% left_join(senshis, c("loser"="ID"), suffix = c(".winner", ".loser"))
winner | loser | name.winner | power.winner | name.loser | power.loser |
---|---|---|---|---|---|
1 | 2 | ベジータ | 18000 | 悟空 | 8000 |
1 | 3 | ベジータ | 18000 | ナッパ | 4000 |
1 | 5 | ベジータ | 18000 | 栽培マン | 1200 |
2 | 3 | 悟空 | 8000 | ナッパ | 4000 |
3 | 4 | ナッパ | 4000 | クリリン | 1770 |
4 | 5 | クリリン | 1770 | 栽培マン | 1200 |
モデル
モデルです。元記事のBUGSのモデルをまったくそのままStanに書き直したつもりです。
model <- " data { int N_member; int N_game; int Winner[N_game]; int Loser[N_game]; } parameters { real winner_p[N_game]; real loser_p[N_game]; real power[N_member]; } model { for (game in 1:N_game) { target += normal_lpdf(winner_p[game] | power[Winner[game]], 1); target += normal_lpdf(loser_p[game] | power[Loser[game]], 1); target += bernoulli_lpmf(1 | step(winner_p[game] - loser_p[game])); } for (member in 1:N_member) target += normal_lpdf(power[member] | 0, 100); } "
実行と結果
stan関数でサンプリングを実行します。初期値を与えないとサンプリングに失敗しました。とはいえ下で使用したinit関数は適当すぎるかもしれませんが、とりあえず気にしないことにします。
N_member <- nrow(senshis) N_game <- nrow(games) data <- list( N_member = N_member, N_game = N_game, Winner = games$winner, Loser = games$loser ) init <- function(...) { list( winner_p = rep(1, N_game), loser_p = rep(0, N_game), power = rep(0, N_member) ) } chains <- 4 fit <- stan( model_code = model, data = data, init = lapply(1:chains, init), pars = c('power'), iter = 22000, warmup = 2000, thin = 10, chains = chains )
結果です。powerしか見ていませんが、Rhatが1.1以下なのでよく収束していると言えそうです。
print(fit)
Inference for Stan model: 174a3ef46412238a4997765f45f7964d. 4 chains, each with iter=22000; warmup=2000; thin=10; post-warmup draws per chain=2000, total post-warmup draws=8000. mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat power[1] 117.90 3.06 65.08 0.42 73.67 113.45 157.94 258.26 452 1.00 power[2] 52.57 2.89 55.15 -53.83 14.25 52.60 89.81 165.15 365 1.01 power[3] 5.47 2.82 54.18 -100.56 -32.49 7.71 42.99 106.10 370 1.01 power[4] -44.23 2.64 55.49 -156.75 -82.17 -42.78 -5.11 57.88 443 1.00 power[5] -110.90 2.85 65.68 -251.55 -152.42 -105.90 -63.64 5.47 531 1.00 lp__ -47.05 0.07 2.87 -53.58 -48.74 -46.74 -45.00 -42.42 1850 1.00 Samples were drawn using NUTS(diag_e) at Sun Jul 30 23:08:37 2017. For each parameter, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence, Rhat=1).
ただしサンプリング時に以下の警告が出ます。意味はわかりませんがモデルが良くないと言われているような気がします。試しにadapt_deltaを増やしたサンプリングもやってみましたが、エラーは消えませんでした。これも今回は気にしないことにします。
警告メッセージ: 1: There were 5873 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup 2: There were 520 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded 3: Examine the pairs() plot to diagnose sampling problems
可視化
さて楽しい楽しい可視化の時間です。むしろ可視化がしたいがためにこの記事を書いたようなものです。
ggplot2というRではメジャーな可視化ライブラリを使って、Memorandumの確率密度プロットを再現しました。バイオリンプロットと呼ぶそうです。
まずstanのサンプリング結果オブジェクトからサンプリング値にアクセスします。extract関数を使いますが、extractはtidyrロード時に同名の別関数でマスクされているので、::
でnamespaceを指定して呼び出します。strは文字列変換ではなく、オブジェクトの構造を出力する関数です。stringではなくstructureです。
str(rstan::extract(fit, pars="power"))
List of 1 $ power: num [1:8000, 1:5] 115.5 141.9 125.3 104.5 57.7 ... ..- attr(*, "dimnames")=List of 2 .. ..$ iterations: NULL .. ..$ : NULL
可視化用にデータを加工します。戦闘力の文献値は常用対数にします。
senshis$power <- log10(senshis$power) D <- rstan::extract(fit, pars="power")$power %>% data.frame %>% setNames(senshis$name) %>% gather(name, power) %>% inner_join(senshis, c("name"="name"), suffix=c('.est.','.lit.')) D$name = factor(D$name, levels=senshis$name) head(D) CI <- D %>% group_by(name) %>% summarise(power.lit.=mean(power.lit.), median=median(power.est.), lower_95=sort(power.est.)[length(power.est.)*0.025], upper_95=sort(power.est.)[length(power.est.)*0.975]) CI
name | power.est. | ID | power.lit. |
---|---|---|---|
ベジータ | 115.49726 | 1 | 4.255273 |
ベジータ | 141.91934 | 1 | 4.255273 |
ベジータ | 125.32928 | 1 | 4.255273 |
ベジータ | 104.52283 | 1 | 4.255273 |
ベジータ | 57.72020 | 1 | 4.255273 |
ベジータ | 53.39367 | 1 | 4.255273 |
name | power.lit. | median | lower_95 | upper_95 |
---|---|---|---|---|
ベジータ | 4.255273 | 113.446818 | 0.04747007 | 258.253949 |
悟空 | 3.903090 | 52.597184 | -53.84137102 | 165.144239 |
ナッパ | 3.602060 | 7.713342 | -100.58709283 | 106.098659 |
クリリン | 3.247973 | -42.779770 | -157.13586821 | 57.882540 |
栽培マン | 3.079181 | -105.896344 | -251.56942070 | 5.471327 |
ggplot2で可視化します。まったく慣れていないので大変でしたが、ググれば情報は出て来るのでやりようはあります。
p <- ggplot(D, aes(power.lit., power.est., fill=name)) + geom_violin(size=0, adjust=1.5) + geom_pointrange(CI, mapping=aes(x=power.lit., y=median, ymin=lower_95, ymax=upper_95, colour=name), size=1.5) + scale_fill_manual(values=rep("grey70", 5)) + scale_y_continuous(limits=c(-500, 500), breaks=seq(-500, 500, 250)) + xlab("log10(Literature Power Level)") + ylab("Estimated Power Level") + theme(plot.title=element_text(size=18)) + theme(axis.text.x=element_text(size=14)) + theme(axis.text.y=element_text(size=14)) + theme(axis.title.x=element_text(size=18)) + theme(axis.title.y=element_text(size=18)) + theme(legend.title=element_text(colour="black",size=18)) + theme(legend.text=element_text(colour="black",size=18)) png(filename='result.png', width=700, height=600) plot(p) garbage <- dev.off()
Memorundomと同じようなグラフがかけました。ベジータの上方向と栽培マンの下方向が伸び伸び、ベジータ-悟空間、ナッパ-クリリン間の戦闘力差が小さいなど、特徴もよく似ています。Stanの結果も可視化処理も問題ないようです。
おわり。次は階層ベイズモデルで勝敗データからプロ棋士の強さを推定する, StatModeling Memorandumをやってみたいです。