やってみた。ドラゴンボールの戦闘力推定

ドラゴンボールの勝敗結果から戦闘力を推定する, StatModeling Memorandumという実験記事があります。面白そうなのでやってみました。

準備

RとStanを使いました。tidyrとdplyrはRで使われる変な名前のデータ整形ライブラリです。

set.seed(1)

library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
library(tidyr)
library(dplyr)

データ

games.txtは戦闘データ、senshis.txtはキャラクタ名と戦闘力の文献値が入っています。値はMemorandumの記事から拝借しました。

%>%はtidyrやdplyrで使われる、メソッドチェイン的な構文が使える演算子です。joinを使って戦績にsenshi.txtを連結した結果を出しています。

games <- read.table("games.txt", header=T)
senshis <- read.table("senshis.txt", header=T)

games %>%
    left_join(senshis, c("winner"="ID")) %>%
    left_join(senshis, c("loser"="ID"), suffix = c(".winner", ".loser"))
winnerlosername.winnerpower.winnername.loserpower.loser
1 2 ベジータ18000 悟空 8000
1 3 ベジータ18000 ナッパ 4000
1 5 ベジータ18000 栽培マン1200
2 3 悟空 8000 ナッパ 4000
3 4 ナッパ 4000 クリリン1770
4 5 クリリン 1770 栽培マン1200

モデル

モデルです。元記事のBUGSのモデルをまったくそのままStanに書き直したつもりです。

model <- "
data {
  int N_member;
  int N_game;
  int Winner[N_game];
  int Loser[N_game];
}
parameters {
  real winner_p[N_game];
  real loser_p[N_game];
  real power[N_member];
}
model {
  for (game in 1:N_game) {
    target += normal_lpdf(winner_p[game] | power[Winner[game]], 1);
    target += normal_lpdf(loser_p[game] | power[Loser[game]], 1);
    target += bernoulli_lpmf(1 | step(winner_p[game] - loser_p[game]));
  }
  for (member in 1:N_member)
    target += normal_lpdf(power[member] | 0, 100);
}
"

実行と結果

stan関数でサンプリングを実行します。初期値を与えないとサンプリングに失敗しました。とはいえ下で使用したinit関数は適当すぎるかもしれませんが、とりあえず気にしないことにします。

N_member <-  nrow(senshis)
N_game <- nrow(games)

data <- list(
  N_member = N_member,
  N_game   = N_game,
  Winner   = games$winner,
  Loser    = games$loser
)

init <- function(...) {
  list(
    winner_p = rep(1, N_game),
    loser_p  = rep(0, N_game),
    power    = rep(0, N_member)
  )
}

chains <- 4
 
fit <- stan(
  model_code = model,
  data       = data,
  init       = lapply(1:chains, init),
  pars       = c('power'),
  iter       = 22000,
  warmup     = 2000,
  thin       = 10,
  chains     = chains
)

結果です。powerしか見ていませんが、Rhatが1.1以下なのでよく収束していると言えそうです。

print(fit)
Inference for Stan model: 174a3ef46412238a4997765f45f7964d.
4 chains, each with iter=22000; warmup=2000; thin=10; 
post-warmup draws per chain=2000, total post-warmup draws=8000.

            mean se_mean    sd    2.5%     25%     50%    75%  97.5% n_eff Rhat
power[1]  117.90    3.06 65.08    0.42   73.67  113.45 157.94 258.26   452 1.00
power[2]   52.57    2.89 55.15  -53.83   14.25   52.60  89.81 165.15   365 1.01
power[3]    5.47    2.82 54.18 -100.56  -32.49    7.71  42.99 106.10   370 1.01
power[4]  -44.23    2.64 55.49 -156.75  -82.17  -42.78  -5.11  57.88   443 1.00
power[5] -110.90    2.85 65.68 -251.55 -152.42 -105.90 -63.64   5.47   531 1.00
lp__      -47.05    0.07  2.87  -53.58  -48.74  -46.74 -45.00 -42.42  1850 1.00

Samples were drawn using NUTS(diag_e) at Sun Jul 30 23:08:37 2017.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

ただしサンプリング時に以下の警告が出ます。意味はわかりませんがモデルが良くないと言われているような気がします。試しにadapt_deltaを増やしたサンプリングもやってみましたが、エラーは消えませんでした。これも今回は気にしないことにします。

 警告メッセージ:
1: There were 5873 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See
http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
2: There were 520 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
3: Examine the pairs() plot to diagnose sampling problems

可視化

さて楽しい楽しい可視化の時間です。むしろ可視化がしたいがためにこの記事を書いたようなものです。

ggplot2というRではメジャーな可視化ライブラリを使って、Memorandumの確率密度プロットを再現しました。バイオリンプロットと呼ぶそうです。

まずstanのサンプリング結果オブジェクトからサンプリング値にアクセスします。extract関数を使いますが、extractはtidyrロード時に同名の別関数でマスクされているので、::でnamespaceを指定して呼び出します。strは文字列変換ではなく、オブジェクトの構造を出力する関数です。stringではなくstructureです。

str(rstan::extract(fit, pars="power"))
List of 1
 $ power: num [1:8000, 1:5] 115.5 141.9 125.3 104.5 57.7 ...
  ..- attr(*, "dimnames")=List of 2
  .. ..$ iterations: NULL
  .. ..$           : NULL

可視化用にデータを加工します。戦闘力の文献値は常用対数にします。

senshis$power <- log10(senshis$power)

D <- rstan::extract(fit, pars="power")$power %>%
  data.frame %>%
  setNames(senshis$name) %>%
  gather(name, power) %>%
  inner_join(senshis, c("name"="name"), suffix=c('.est.','.lit.'))

D$name = factor(D$name, levels=senshis$name)
head(D)

CI <- D %>%
  group_by(name) %>%
  summarise(power.lit.=mean(power.lit.),
            median=median(power.est.),
            lower_95=sort(power.est.)[length(power.est.)*0.025],
            upper_95=sort(power.est.)[length(power.est.)*0.975])
CI
namepower.est.IDpower.lit.
ベジータ 115.497261 4.255273
ベジータ 141.919341 4.255273
ベジータ 125.329281 4.255273
ベジータ 104.522831 4.255273
ベジータ 57.720201 4.255273
ベジータ 53.393671 4.255273
namepower.lit.medianlower_95upper_95
ベジータ 4.255273 113.446818 0.04747007258.253949
悟空 3.903090 52.597184 -53.84137102165.144239
ナッパ 3.602060 7.713342 -100.58709283106.098659
クリリン 3.247973 -42.779770 -157.13586821 57.882540
栽培マン 3.079181 -105.896344 -251.56942070 5.471327

ggplot2で可視化します。まったく慣れていないので大変でしたが、ググれば情報は出て来るのでやりようはあります。

p <- ggplot(D, aes(power.lit., power.est., fill=name)) +
  geom_violin(size=0, adjust=1.5) +
  geom_pointrange(CI, mapping=aes(x=power.lit., y=median, ymin=lower_95, ymax=upper_95, colour=name), size=1.5) +
  scale_fill_manual(values=rep("grey70", 5)) +
  scale_y_continuous(limits=c(-500, 500), breaks=seq(-500, 500, 250)) +
  xlab("log10(Literature Power Level)") +
  ylab("Estimated Power Level") +
  theme(plot.title=element_text(size=18)) +
  theme(axis.text.x=element_text(size=14)) +
  theme(axis.text.y=element_text(size=14)) +
  theme(axis.title.x=element_text(size=18)) +
  theme(axis.title.y=element_text(size=18)) +
  theme(legend.title=element_text(colour="black",size=18)) +
  theme(legend.text=element_text(colour="black",size=18))

png(filename='result.png', width=700, height=600)
plot(p)
garbage <- dev.off()

可視化

Memorundomと同じようなグラフがかけました。ベジータの上方向と栽培マンの下方向が伸び伸び、ベジータ-悟空間、ナッパ-クリリン間の戦闘力差が小さいなど、特徴もよく似ています。Stanの結果も可視化処理も問題ないようです。

おわり。次は階層ベイズモデルで勝敗データからプロ棋士の強さを推定する, StatModeling Memorandumをやってみたいです。