結婚制度の廃止を望んでやまない(^^)

多変量混合ガウスモデルでvariational inference

混合ガウスモデルとNormal-Wishart分布(事前分布)に対してVariational inferenceしてみる.

要するにPRMLの10.2. である. 以下はつまるところその引き写しである.

モデル

データ  \mathbf X = \left\lbrace \mathbf x_1, ..., \mathbf x_N\right\rbrace \mathbf xの次元は D )について, 混合数  K の混合ガウスモデルを考える. 過去何度かやったように, 混合分布モデルではデータが属する分布を表現する隠れ変数を導入するとやりやすい.

クラスタのパラメータ  \boldsymbol \mu_{1...K}  \boldsymbol \Lambda_{1...K} と, クラスタの混合比  \boldsymbol \pi_{1...K} \left( \sum_{k=1}^K \boldsymbol \pi_k = 1 \right) がある.

隠れ変数として, サイズが N \times K, 各行の総和が必ず1である2値行列  \mathbf Z を導入する. \mathbf Z_{nk} が1であるとき, \mathbf x_nクラスタ k に属する.

まずデータの尤度である.  \mathbf Z_{nk} は0 or 1なので, ガウス分布のPDFをこれでべき乗してkについて総乗をとると \mathbf x_n の尤度が得られる.



p\left(\mathbf X \mid \mathbf Z, \boldsymbol \mu, \boldsymbol \Lambda \right)
= \prod_{n=1}^N\prod_{k=1}^K \left\lbrace \mathcal N \left(\mathbf x_n \mid \boldsymbol \mu_k, \boldsymbol \Lambda_k^{-1} \right) \right\rbrace^{z_{nk}} \\

\boldsymbol \mu\boldsymbol \Lambda の共役事前分布. パラメータ  \mathbf m_{0}, \beta_{0},  \mathbf W_{0}, \nu_{0} は全クラスタで共通.



p\left(\boldsymbol \mu, \boldsymbol \Lambda \right) =
    \prod_{k=1}^{K}
    \mathcal N \left(\boldsymbol \mu_{k} \mid \mathbf m_{0}, \left(\beta_{0} \Lambda_{k} \right)^{-1}\right)
    \mathcal W\left(\boldsymbol \Lambda_{k} \mid \mathbf W_{0}, \nu_{0}\right)

\boldsymbol \pi の共役事前分布としてディリクレ分布を用いる. \boldsymbol \alpha_{0} は大きさ K の整数ベクトル.  \boldsymbol \alpha_{0k } は あるクラスタkについて \mathbf Z_{nk} が1であるデータnの個数の総和と同じ意味である. 事前情報があればここに反映させる. なければ全クラスタについて同じ値を与える.



p\left(\boldsymbol \pi \right)
    = {\rm Dir} \left( \boldsymbol \pi \mid \boldsymbol \alpha_0 \right)

さて,  p\left( \mathbf Z \right) が足りない. これについて考える. 混合比 \boldsymbol \pi が有るときに, データ\mathbf x_{n} クラスタ k に属する確率 p(\mathbf z_{n\lbrace1...K \rbrace} \mid Z_{nk}=1 )は混合比 \pi_kとなる.  Z_{n\lbrace 0,1,...,k-1,k+1,..K \rbrace} は0なので, \pi_k = \left( \pi_{k} \right)^{z_{nk}} と変形できる. つまり下となる.



\begin{eqnarray}
p\left( \mathbf Z \mid \boldsymbol \pi \right)
&=&
    \prod_{n=1}^N p\left( \mathbf z_{n\lbrace1...K \rbrace} \right) \\
&=&
    \prod_{n=1}^N \prod_{k=1}^K p\left( z_{nk} \right) \\
&=&
    \prod_{n=1}^N \prod_{k=1}^K \left( \pi_{k} \right)^{z_{nk}} \\
\end{eqnarray}

パラメータの事後分布は下のように書ける.



p\left(\boldsymbol \pi, \boldsymbol \mu, \boldsymbol \Lambda, \mathbf Z \mid \mathbf X \right) =
    \frac {
    p\left(\textbf X \mid \mathbf Z, \boldsymbol \mu, \boldsymbol \Lambda \right)
    p\left(\mathbf Z \mid \boldsymbol \pi \right)
    p\left(\boldsymbol \pi \right)
    p\left(\boldsymbol \mu, \boldsymbol \Lambda \right)
    } {p\left( \mathbf X \right)}

variational inferenceではパラメータの同時確率をパラメータごとに分解可能と近似するわけだが, \boldsymbol \mu\boldsymbol \Lambda前回, 前々回により同時に事後分布が計算できることがわかっているので分ける必要がなく, \mathbf Z \boldsymbol \pi(\boldsymbol {\mu, \Lambda}) が分解できればよい.

つまり以下のような近似を行う.



p\left(\boldsymbol \pi, \boldsymbol \mu, \boldsymbol \Lambda, \mathbf Z \right) =
    q\left(\mathbf Z \right)
    q\left(\boldsymbol \pi \right)
    q\left(\boldsymbol \mu, \boldsymbol \Lambda \right)

それでもって公式で分布を求める.



{\rm ln}\, q_j^{\ast}(\boldsymbol Z_j) =  \mathbb{E}_{i\neq j} \lbrack {\rm ln}\, p(\boldsymbol X, \boldsymbol Z) \rbrack + {\rm const}

q^{\ast} (\mathbf Z)

まず \mathbf Z について.



\begin{eqnarray}
\operatorname {ln} q^{\ast} \left( \mathbf Z \right)
&=&
    \mathbb{E}_{\boldsymbol \pi, \boldsymbol \mu, \boldsymbol \Lambda} \lbrack
        \operatorname {ln} p\left(
            \boldsymbol X, \boldsymbol Z, \boldsymbol \pi, \boldsymbol \mu, \boldsymbol \Lambda
        \right)
    \rbrack + {\rm const} \\
&=&
    \mathbb{E}_{\boldsymbol \pi, \boldsymbol \mu, \boldsymbol \Lambda} \lbrack
        \operatorname {ln} \left\lbrace
            p\left(\textbf X \mid \mathbf Z, \boldsymbol \mu, \boldsymbol \Lambda \right)
            p\left(\mathbf Z \mid \boldsymbol \pi \right)
            p\left(\boldsymbol \pi \right)
            p\left(\boldsymbol \mu, \boldsymbol \Lambda \right)
        \right\rbrace
    \rbrack + {\rm const} \\
&=&
    \mathbb{E}_{\boldsymbol \mu, \boldsymbol \Lambda} \lbrack
        \operatorname {ln}
            p\left(\textbf X \mid \mathbf Z, \boldsymbol \mu, \boldsymbol \Lambda \right)
    \rbrack
    +
    \mathbb{E}_{\boldsymbol \pi} \lbrack
        \operatorname {ln}
            p\left(\mathbf Z \mid \boldsymbol \pi \right)
    \rbrack + {\rm const} \\
&=&
    \mathbb{E}_{\boldsymbol \mu, \boldsymbol \Lambda} \left\lbrack
        \operatorname {ln}
            \prod_{n=1}^N\prod_{k=1}^K \left\lbrace \mathcal N \left(\mathbf x_n \mid \boldsymbol \mu_k, \boldsymbol \Lambda_k^{-1} \right) \right\rbrace^{z_{nk}} 
    \right\rbrack
    +
    \mathbb{E}_{\boldsymbol \pi} \left\lbrack
        \operatorname {ln}
            \prod_{n=1}^N \prod_{k=1}^K \left( \pi_{k} \right)^{z_{nk}} 
    \right\rbrack + {\rm const} \\
&=&
    \sum_{n=1}^N\sum_{k=1}^K z_{nk} \left(
        \mathbb{E}\left\lbrack 
            \operatorname {ln}
                \mathcal N \left(\mathbf x_n \mid \boldsymbol \mu_k, \boldsymbol \Lambda_k^{-1} \right)
        \right\rbrack
        +
        \mathbb{E}\left\lbrack
            \operatorname {ln} \pi_{k}
        \right\rbrack
    \right) + {\rm const} \\
&=&
    \sum_{n=1}^N\sum_{k=1}^K z_{nk} \underbrace {\left(
        \frac {1} {2} \mathbb{E}\left\lbrack 
            \operatorname {ln}\left| \boldsymbol \Lambda_{k} \right|
            \right\rbrack
        - \frac  {D} {2} \operatorname {ln} \left( 2\pi\right)
        - \frac {1} {2} \mathbb{E}_{\boldsymbol \mu_{k},\boldsymbol \Lambda_{k}}\left\lbrack
            \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)^{\rm T}
            \boldsymbol \Lambda_{k}
            \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)
            \right\rbrack
        + \mathbb{E}\left\lbrack
            \operatorname {ln} \pi_{k}
        \right\rbrack
    \right)}_{=: \; \operatorname {ln} \rho_{nk}} + {\rm const} \\
\end{eqnarray}

\rho_{nk} を導入するとよいらしい. 対数を外すとq^{\ast} (\mathbf Z)は下のような形となる.



q^{\ast}\left( \mathbf Z \right) \propto \prod_{n=1}^{N} \prod_{k=1}^{K} \left( \rho_{nk} \right)^{z_{nk}}

正規化係数Cを導入すると以下になる.



\int q^{\ast}\left( \mathbf Z \right)\operatorname d \mathbf Z = 
\int \frac
    {\prod_{n=1}^{N} \prod_{k=1}^{K} \left( \rho_{nk} \right)^{z_{nk}}}
    {C} \operatorname d \mathbf Z = 1

\mathbf Zの状態数を考えると, まず\mathbf ZN \times K の2値ベクトル(0 or 1)なのだが, ある行nについては, すべてのkについて, 値が1となるものが一つだけ存在する. なので状態数は K^{N} = \prod_{n=1}^{N} K である. まあこれは当然である. 正規化係数はすべての状態についてのp(\mathbf Z) の総和なので, 以下のように書ける.



\begin{eqnarray}
C &=& \sum_{k_{1}=1}^K\sum_{k_{2}=1}^K...\sum_{k_{K}=1}^K \rho_{1k_{1}}\rho_{2k_{2}}...\rho_{Nk_{K}}\\
&=& \prod_{n=1}^{N} \sum_{j=1}^{K} \rho_{nj} \\
&=& \prod_{n=1}^{N} \prod_{k=1}^{K} \left(\sum_{j=1}^{K} \rho_{nj}\right)^{z_{nk}}
\end{eqnarray}

最終行は q^{\ast}(\mathbf Z) に戻すための準備である. \mathbf Z_{nk} k にわたっては一つだけ1をとり他はゼロなので prodの中でこのように書けることは, これまでやってきたとおり.

q^{\ast}(\mathbf Z) を正規化する.



\begin{eqnarray}
q^{\ast}\left( \mathbf Z \right) 
&=&
    \frac 
    {\prod_{n=1}^{N} \prod_{k=1}^{K} \left( \rho_{nk} \right)^{z_{nk}}}
    {\prod_{n=1}^{N} \prod_{k=1}^{K} \left(\sum_{j=1}^{K} \rho_{nj}\right)^{z_{nk}}} \\
&=&
    \prod_{n=1}^{N} \prod_{k=1}^{K} \left(\frac {\rho_{nk}} {\sum_{j=1}^{K} \rho_{nj}}\right)^{z_{nk}} \\
&=&
    \prod_{n=1}^{N} \prod_{k=1}^{K} \left(r_{nk}\right)^{z_{nk}} \\
r_{nk} &=& \frac {\rho_{nk}} {\sum_{j=1}^{K} \rho_{nj}}
\end{eqnarray}

\rho_{nk}, r_{nk} を導入しておくのが重要であるのだそうだ.

上から  \operatorname {Pr} \left\lbrack z_{nk} = 1 \right\rbrack = r_{nk} なので, 以下が得られる.



\begin{eqnarray}
\mathbb E\lbrack z_{nk} \rbrack &=& r_{nk}
\end{eqnarray}

また, 以下を定義しておくと, q^{\ast}(\boldsymbol \pi) , q^{\ast}(\boldsymbol \mu, \boldsymbol \lambda) を解くのに便利.



\begin{eqnarray}
N_{k} &=& \sum_{n=1}^N r_{nk} \\
\overline{ \mathbf x}_{k} &=& \frac {1} {N_{k}} \sum_{n=1}^{N} r_{nk} \mathbf x_{n} \\
\mathbf S_{k} &=&
    \frac {1} {N_{k}}
    \sum_{n=1}^{N} r_{nk}
        \left(\mathbf x_{n} - \overline{ \mathbf x}_{k}\right)
        \left(\mathbf x_{n} - \overline{ \mathbf x}_{k}\right)^{\rm T}
    \\
\end{eqnarray}

q^{\ast} (\boldsymbol \pi)



\begin{eqnarray}
\operatorname {ln} q^{\ast} \left( \boldsymbol \pi \right)
&=&
    \mathbb{E}_{\textbf Z, \boldsymbol \mu, \boldsymbol \Lambda} \lbrack
        \operatorname {ln} \left\lbrace
            p\left(\textbf X \mid \mathbf Z, \boldsymbol \mu, \boldsymbol \Lambda \right)
            p\left(\mathbf Z \mid \boldsymbol \pi \right)
            p\left(\boldsymbol \pi \right)
            p\left(\boldsymbol \mu, \boldsymbol \Lambda \right)
        \right\rbrace
    \rbrack + {\rm const} \\
&=&
    \mathbb{E}_{\textbf Z} \lbrack
        \operatorname {ln} \left\lbrace
            p\left(\mathbf Z \mid \boldsymbol \pi \right)
            p\left(\boldsymbol \pi \right)
        \right\rbrace
    \rbrack + {\rm const} \\
&=&
    \mathbb{E}_{\textbf Z} \lbrack
        \operatorname {ln} \left\lbrace
            p\left(\mathbf Z \mid \boldsymbol \pi \right)
        \right\rbrace
    \rbrack
    + \operatorname {ln} \operatorname {Dir} \left(\boldsymbol \pi \mid \alpha_{0} \right)
    + {\rm const} \\
&=&
    \sum_{n=1}^N\sum_{k=1}^K  \operatorname {ln} \pi_{k} \mathbb{E} \lbrack
        z_{nk}
    \rbrack
    + \left(\alpha_{0} -1 \right) \sum_{k=1}^{K} \operatorname {ln} \pi_{k}
    + {\rm const} \\
&=&
    \sum_{k=1}^K
    \operatorname {ln} \pi_{k}\left(
        \sum_{n=1}^N \mathbb{E} \lbrack z_{nk}\rbrack
        + \alpha_{0}
        - 1
    \right)
    + {\rm const} \\
&=&
    \sum_{k=1}^K
    \operatorname {ln} \pi_{k}\left(
    N_{k} + \alpha_{0} - 1
    \right)
    + {\rm const} \\
\end{eqnarray}

これは事前分布と同様にDirichlet分布である.



\begin{eqnarray}
q^{\ast}\left( \boldsymbol \pi \right) &=& \operatorname {Dir} \left( \boldsymbol \alpha \right) \\
\alpha_{k} &=& N_k+\alpha_{0}
\end{eqnarray}

期待値は以下になる.



\mathbb E \left\lbrack \pi_{k} \right\rbrack
= \mathbb E \left\lbrack \operatorname {Dir} \left(\alpha_{k} \right) \right\rbrack
= \frac {\alpha_{k}} {\sum_{j=1}^K \alpha_j}
= \frac {\alpha_{0} + N_{k}} {K\alpha_0 + \sum_{j=1}^K N_j}
= \frac {\alpha_{0} + N_{k}} {K\alpha_0 + N}

q^{\ast} (\boldsymbol \mu_{k}, \boldsymbol \Lambda_{k})



\begin{eqnarray}
\operatorname {ln} q^{\ast} \left( \boldsymbol \mu, \boldsymbol \Lambda \right)
&=&
    \mathbb{E}_{\textbf Z, \boldsymbol \pi} \lbrack
        \operatorname {ln} \left\lbrace
            p\left(\textbf X \mid \mathbf Z, \boldsymbol \mu, \boldsymbol \Lambda \right)
            p\left(\mathbf Z \mid \boldsymbol \pi \right)
            p\left(\boldsymbol \pi \right)
            p\left(\boldsymbol \mu, \boldsymbol \Lambda \right)
        \right\rbrace
    \rbrack + {\rm const} \\
&=&
    \mathbb{E}_{\textbf Z} \left\lbrack
        \sum_{n=1}^N\sum_{k=1}^K 
            z_{nk}
            \operatorname {ln} \mathcal N \left(
                \mathbf x_n \mid \boldsymbol \mu_k, \boldsymbol \Lambda_k^{-1}
            \right)
         \right\rbrack
    + \sum_{k=1}^{K} \operatorname {ln} \left\lbrace
        \mathcal N \left(\boldsymbol \mu_{k} \mid \mathbf m_{0}, \left(\beta_{0} \boldsymbol \Lambda_{k} \right)^{-1}\right)
        \mathcal W\left(\boldsymbol \Lambda_{k} \mid \mathbf W_{0}, \nu_{0}\right)
        \right\rbrace
    + {\rm const } \\
&=&
    \sum_{k=1}^K
        \sum_{n=1}^N
            \mathbb{E}_{\textbf Z} \left\lbrack z_{nk} \right\rbrack
            \operatorname {ln} \mathcal N \left(
                \mathbf x_n \mid \boldsymbol \mu_k, \boldsymbol \Lambda_k^{-1}
            \right)
    + \sum_{k=1}^{K} \operatorname {ln} \left\lbrace
        \mathcal N \left(\boldsymbol \mu_{k} \mid \mathbf m_{0}, \left(\beta_{0} \boldsymbol \Lambda_{k} \right)^{-1}\right)
        \mathcal W\left(\boldsymbol \Lambda_{k} \mid \mathbf W_{0}, \nu_{0}\right)
        \right\rbrace
    + {\rm const } \\
&=&
    \sum_{k=1}^K
        \sum_{n=1}^N
            r_{nk}
            \operatorname {ln} \mathcal N \left(
                \mathbf x_n \mid \boldsymbol \mu_k, \boldsymbol \Lambda_k^{-1}
            \right)
    + \sum_{k=1}^{K} \operatorname {ln} \left\lbrace
        \mathcal N \left(\boldsymbol \mu_{k} \mid \mathbf m_{0}, \left(\beta_{0} \boldsymbol \Lambda_{k} \right)^{-1}\right)
        \mathcal W\left(\boldsymbol \Lambda_{k} \mid \mathbf W_{0}, \nu_{0}\right)
        \right\rbrace
    + {\rm const } \\
\end{eqnarray}

\boldsymbol \mu_{k}\boldsymbol \Lambda_{k}について解く. 前回の, \boldsymbol \mu\boldsymbol \Lambda の事後分布の解き方と同じである.



\begin{eqnarray}
\operatorname {ln} q^{\ast} \left( \boldsymbol \mu_{k}, \boldsymbol \Lambda_{k} \right)
&=&
    \sum_{n=1}^N
        r_{nk}
        \operatorname {ln} \mathcal N \left(
            \mathbf x_n \mid \boldsymbol \mu_k, \boldsymbol \Lambda_k^{-1}
        \right)
    + \operatorname {ln} \left\lbrace
        \mathcal N \left(\boldsymbol \mu_{k} \mid \mathbf m_{0}, \left(\beta_{0} \boldsymbol \Lambda_{k} \right)^{-1}\right)
        \mathcal W\left(\boldsymbol \Lambda_{k} \mid \mathbf W_{0}, \nu_{0}\right)
        \right\rbrace
    + {\rm const } \\
&=&
    \frac {N_{k} + \nu_{0} -D} {2} \operatorname{ln} \left| \boldsymbol \Lambda_{k} \right|
    - \frac {1} {2} \underbrace {\left(
        \sum_{n=1}^{N} r_{nk}
            \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)^{\rm T}
            \boldsymbol \Lambda_{k}
            \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)
        + \beta_{0}
            \left(\boldsymbol \mu_{k} - \mathbf m_{0} \right)^{\rm T}
            \boldsymbol \Lambda_{k}
            \left(\boldsymbol \mu_{k} - \mathbf m_{0} \right)
    \right) }_{=: \; A}
    - \frac {\operatorname {tr} \left(\mathbf W_{0}^{-1} \boldsymbol \Lambda \right)} {2}
\end{eqnarray}


\begin {eqnarray}
A &=&
    \sum_{n=1}^N r_{nk}
        \left(\textbf x_{n} - \boldsymbol \mu_{k}\right)^{\rm T}
        \boldsymbol \Lambda
        \left(\textbf x_{n} - \boldsymbol \mu_{k}\right)
    + \beta_{0}\left(\boldsymbol \mu_{k} - \textbf m_0\right)^{\rm T}
        \boldsymbol \Lambda_{k}
        \beta_{0}\left(\boldsymbol \mu_{k} - \textbf m_0\right) \\
&=&
    \sum_{n=1}^N r_{nk}\left( {\textbf x_{n}}^{\rm T} \boldsymbol \Lambda_{k} \textbf x_{n} \right)
    - \boldsymbol \mu_{k}^{\rm T} \boldsymbol \Lambda_{k} \sum_{n=1}^N r_{nk} \textbf x_{n}
    - \sum_{n=1}^N r_{nk}{\textbf x_{n}}^{\rm T} \boldsymbol \Lambda_{k} \boldsymbol \mu_{k}
    + N_{k} \boldsymbol \mu_{k}^{\rm T} \boldsymbol \Lambda \boldsymbol \mu_{k}
    + \beta_{0} \boldsymbol \mu_{k}^{\rm T} \boldsymbol \Lambda_{k} \boldsymbol \mu_{k}
    - \beta_{0} \mathbf m_{0}^{\rm T} \boldsymbol \Lambda_{k} \boldsymbol \mu_{k}
    - \beta_{0} \boldsymbol \mu_{k}^{\rm T} \boldsymbol \Lambda_{k} \textbf m_0
    + \beta_{0} \mathbf m_{0}^{\rm T} \boldsymbol \Lambda_{k} \textbf m_0 \\
&=&
    \left(N_{k} + \beta_{0}\right)\boldsymbol \mu_{k}^{\rm T} \boldsymbol \Lambda_{k} \boldsymbol \mu_{k}
    - \boldsymbol \mu_{k}^{\rm T} \boldsymbol \Lambda_{k} \left(
        \sum_{n=1}^N r_{nk} \textbf x_n
        + \beta_0 \mathbf m_0
    \right)
    - \left(
        \sum_{n=1}^N r_{nk}  \textbf x_n
        + \beta_0 \mathbf m_0
    \right)^{\rm T} \boldsymbol \Lambda_{k}  \boldsymbol \mu_{k}
    + \sum_{n=1}^N  r_{nk} {\textbf x_n}^{\rm T} \boldsymbol \Lambda_{k} \textbf x_n
    + \beta_0 \mathbf m_{0}^{\rm T} \boldsymbol \Lambda_{k} \mathbf m_{0} \\
&=&
    (N_{k} + \beta_0)
        \left(
            \boldsymbol \mu_{k}
            - \frac {N_{k}\overline {\textbf x_{k}} + \beta_{0} \mathbf m_0} {N_{k} + \beta_0}
        \right)^{\rm T}
        \boldsymbol \Lambda_{k}
        \left(
            \boldsymbol \mu_{k}
            - \frac {N_{k}\overline {\textbf x_{k}} + \beta_0 \textbf m_0} {N_{k} + \beta_0}
        \right) \\
&&
    \underbrace {- \frac {
        \left(N_{k} \overline {\textbf x_{k}} + \beta_0 \mathbf m_0\right)^{\rm T}
        \boldsymbol \Lambda
        \left(N_{k} \overline {\textbf x_{k}} + \beta_0 \mathbf m_0\right)
    } {N_{k} + \beta_0}
    + \sum_{n=1}^N r_{nk} {\textbf x_n}^{\rm T} \boldsymbol \Lambda_{k} \textbf x_n
    + \beta_0 \mathbf m_0^{\rm T} \boldsymbol \Lambda_{k} \mathbf m_0}_{=: B} \\
\end {eqnarray}

\begin{eqnarray}
B &=&
    - \frac {
        \left(N_{k}\overline {\textbf x_{k}} + \beta_0 \mathbf m_0\right)^{\rm T}
        \boldsymbol \Lambda_{k}
        \left(N_{k}\overline {\textbf x_{k}} + \beta_0 \mathbf m_0\right)
    } {N_{k} + \beta_0}
    + \sum_{n=1}^N r_{nk} {\textbf x_n}^{\rm T} \boldsymbol \Lambda \textbf x_n
    + \beta_0 \mathbf m_0^{\rm T} \boldsymbol \Lambda \mathbf m_0 \\
&=&
    - \frac {{\rm tr}\left(
        \left(N_{k}\overline {\textbf x_{k}} + \beta_0 \mathbf m_0\right)
        \left(N_{k}\overline {\textbf x_{k}} + \beta_0 \mathbf m_0\right)^{\rm T}
        \boldsymbol \Lambda_{k}
    \right)} {N_{k} + \beta_0}
    + {\rm tr}\left(
        \sum_{n=1}^N \left( r_{nk} \textbf x_n {\textbf x_n}^{\rm T} \right)
         \boldsymbol \Lambda_{k}
    \right)
    + {\rm tr} \left(
          \beta_0 \mathbf m_0 \mathbf m_0^{\rm T} \boldsymbol \Lambda_{k}
    \right) \\
&=&
    \frac {1}{N_{k}+\beta_0} {\rm tr} \left(\left(
        - N_{k}^2 \overline {\textbf x_{k}}\overline {\textbf x_{k}}^{\rm T}
        - N_{k} \beta_0 \overline {\textbf x_{k}}\mathbf m_0^{\rm T}
        - N_{k} \beta_0 \mathbf m_0 \overline {\textbf x_{k}}^{\rm T}
        - \beta_0^2 \mathbf m_0 \mathbf m_0^{\rm T}
        + \left( N_{k} + \beta_0 \right) \sum_{n=1}^N r_{nk} \textbf x_n {\textbf x_n}^{\rm T}
        + N_{k} \beta_0 \mathbf m_0 \mathbf m_0^{\rm T}
        + \beta_0^2 \mathbf m_0 \mathbf m_0^{\rm T}
     \right)\boldsymbol \Lambda_{k} \right) \\
&=&
    \frac {1}{N_{k}+\beta_0} {\rm tr} \left(\left(
        - N_{k}^2 \overline {\textbf x_{k}}\overline {\textbf x_{k}}^{\rm T}
        - N_{k} \beta_0 \overline {\textbf x}\mathbf m_0^{\rm T}
        - N_{k} \beta_0 \mathbf m_0 \overline {\textbf x_{k}}^{\rm T}
        + \left( N_{k} + \beta_0 \right) \sum_{n=1}^N r_{nk} \textbf x_n {\textbf x_n}^{\rm T}
        + N_{k} \beta_0 \mathbf m_0 \mathbf m_0^{\rm T}
     \right)\boldsymbol \Lambda \right) \\
&=&
    \frac {1}{N_{k}+\beta_0} {\rm tr} \left(\left(
        N_{k}\beta_0 \left(\mathbf m_0 - \overline {\mathbf x_{k}}\right)
        \left(\mathbf m_0 - \overline {\mathbf x_{k}}\right)^{\rm T}
        - N_{k}\left( N_{k} + \beta_0\right) \overline {\textbf x_{k}}\overline {\textbf x_{k}}^{\rm T}
        + \left( N_{k} + \beta_0 \right) \sum_{n=1}^N r_{nk} \textbf x_n {\textbf x_n}^{\rm T}
     \right)\boldsymbol \Lambda_{k} \right) \\
&=&
    \frac {N_{k} \beta_0}{N_{k}+\beta_0} {\rm tr} \left(
        \left(\mathbf m_0 - \overline {\textbf x}\right)
        \left(\mathbf m_0 - \overline {\textbf x}\right)^{\rm T}
        \boldsymbol \Lambda_{k}
    \right)
    + {\rm tr} \left(\left(
        - N_{k} \overline {\textbf x_{k}}\overline {\textbf x_{k}}^{\rm T}
        + \sum_{n=1}^N r_{nk}\textbf x_n {\textbf x_n}^{\rm T}
     \right)\boldsymbol \Lambda_{k} \right) \\
&=&
    \frac {N_{k} \beta_0}{N_{k}+\beta_0} {\rm tr} \left(
        \left(\mathbf m_0 - \overline {\textbf x}\right)
        \left(\mathbf m_0 - \overline {\textbf x}\right)^{\rm T}
        \boldsymbol \Lambda_{k}
    \right) \\
&&
    + {\rm tr} \left(\left(
        \sum_{n=1}^N r_{nk}
            \left(\textbf x_n - \overline {\textbf x_{k}} \right)
            \left(\textbf x_n - \overline {\textbf x_{k}} \right)^{\rm T}
        + \sum_{n=1}^N r_{nk} \textbf x_n \overline {\textbf x_{k}} ^ {\rm T}
        + \sum_{n=1}^N r_{nk} \overline {\textbf x_{k}} \textbf x_n^ {\rm T}
        - \sum_{n=1}^N r_{nk} \overline {\textbf x_{k}} \overline {\textbf x_{k}} ^ {\rm T}
        - N_{k} \overline {\textbf x_{k}} \overline {\textbf x_{k}} ^ {\rm T}
     \right)\boldsymbol \Lambda_{k} \right) \\
&=&
    \frac {N_{k} \beta_0}{N_{k}+\beta_0} {\rm tr} \left(
        \left(\mathbf m_0 - \overline {\textbf x}\right)
        \left(\mathbf m_0 - \overline {\textbf x}\right)^{\rm T}
        \boldsymbol \Lambda_{k}
    \right)
    + {\rm tr} \left(\left(
        \sum_{n=1}^N r_{nk}
        \left(\textbf x_n - \overline {\textbf x_{k}} \right)
        \left(\textbf x_n - \overline {\textbf x_{k}} \right)^{\rm T}
     \right)\boldsymbol \Lambda_{k} \right) \\
&=&
    \frac {N_{k} \beta_0}{N_{k}+\beta_0} {\rm tr} \left(
        \left(\mathbf m_0 - \overline {\textbf x}\right)
        \left(\mathbf m_0 - \overline {\textbf x}\right)^{\rm T}
        \boldsymbol \Lambda_{k}
    \right)
    + {\rm tr} \left(N_{k} \mathbf S_{k} \boldsymbol \Lambda_{k} \right) \\
\end{eqnarray}

最終的に以下が得られる.



\begin{eqnarray}
\operatorname {ln} q^{\ast}\left(\boldsymbol \mu_{k}, \boldsymbol \Lambda_{k} \right)
&=& 
    - \frac {1} {2}
    (N_{k} + \beta_0)
        \left(
            \boldsymbol \mu
            - \frac {N_{k}\overline {\textbf x_{k}} + \beta_0 \mathbf m_0} {N_{k} + \beta_0}
        \right)^{\rm T}
        \boldsymbol \Lambda_{k}
        \left(
            \boldsymbol \mu
            - \frac {N_{k} \overline {\textbf x_{k}} + \beta_0 \mathbf m_0} {N_{k} + \beta_0}
        \right)
    + \frac {N_k + \nu_{0} - D} {2} \operatorname {ln} \left| \boldsymbol \Lambda \right| \\
&&
    - \frac {1} {2} \frac {N_{k} \beta _0}{N_{k}+\beta_0} {\rm tr} \left(
        \left(\mathbf m_0 - \overline {\textbf x_{k}}\right)
        \left(\mathbf m_0 - \overline {\textbf x_{k}}\right)^{\rm T}
        \boldsymbol \Lambda_{k}
    \right)
    - \frac {1} {2} {\rm tr} \left(N_{k} \mathbf S_{k} \boldsymbol \Lambda_{k} \right)
    - \frac {\operatorname {tr} (\mathbf {W_0} ^{-1}\mathbf {\boldsymbol \Lambda_{k}})} {2}
    + {\rm const} \\
\end{eqnarray}

これは予想どおりNormal-Wishart分布である.



\begin{eqnarray}
q^{\ast}\left(\boldsymbol \mu_{k}, \boldsymbol \Lambda_{k} \right)
&=&
    \mathcal N\left(
        \boldsymbol \mu_{k} \mid \mathbf m_k, \left(\beta_{k} \boldsymbol \Lambda_{k} \right)^{-1}\right)
    \mathcal W\left(\boldsymbol \Lambda_{k} \mid \textbf W_{k}^{-1}, \nu_{k} \right) \\
\mathbf m_k &=& \frac {N_{k}\overline {\textbf x_{k}} + \beta_0 \mathbf m_0} {N_{k} + \beta_0} \\
\beta_k &=& N_{k}+\beta_0 \\
\textbf W_k^{-1} &=&
    \textbf W_0^{-1}
    + \frac {N_{k} \beta_0}{N_{k}+\beta_0}
        \left(\overline {\textbf x_{k}} - \mathbf m_0\right)
        \left(\overline {\textbf x_{k}} - \mathbf m_0\right)^{\rm T}
    + N_{k} \mathbf S_{k} \\
\nu_k &=& N_k + \nu_0
\end{eqnarray}

できた.

r_{nk}, \rho_{nk} を更新できる形に変形

さて, 下の  \rho_{nk} に戻る. これをデータ点単位で正規化したものが  r_{nk} である. このままでは期待値の項がよくわからないので, 得られた事後分布のパラメータの式に変形する必要が有る.



\begin{eqnarray}
\ln \rho_{nk} &=&
    \frac {1} {2} \mathbb{E}\left\lbrack 
        \operatorname {ln}\left| \boldsymbol \Lambda_{k} \right|
        \right\rbrack
    - \frac  {D} {2} \operatorname {ln} \left( 2\pi\right)
    - \frac {1} {2} \mathbb{E}_{\boldsymbol \mu_{k},\boldsymbol \Lambda_{k}}\left\lbrack
        \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)^{\rm T}
        \boldsymbol \Lambda_{k}
        \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)
        \right\rbrack
    + \mathbb{E}\left\lbrack
        \operatorname {ln} \pi_{k}
    \right\rbrack \\
\end{eqnarray}

その前に準備. 以下のいずれも 多変量正規分布の特徴Wishart分布の特徴から得られる.



\begin{eqnarray}
\mathbb{E}\left\lbrack \boldsymbol \mu_{k}\right\rbrack
&=&
    \mathbf m_{k} \\
\mathbb{E}\left\lbrack \boldsymbol \mu_{k} \boldsymbol \mu_{k}^{\rm T}\right\rbrack
&=&
    \mathbb{E}\left\lbrack
        \left(\boldsymbol \mu_{k} - \mathbf m_{k} \right)
        \left(\boldsymbol \mu_{k} - \mathbf m_{k} \right)^{\rm T}
    \right\rbrack
    + \mathbf m_{k} \mathbb{E}\left\lbrack \boldsymbol \mu_{k}\right\rbrack^{\rm T}
    + \mathbb{E}\left\lbrack \boldsymbol \mu_{k}\right\rbrack \mathbf m_{k}^{\rm T}
    - \mathbf m_{k}\mathbf m_{k}^{\rm T} \\
&=&
    \beta_{k}^{-1} \boldsymbol \Lambda_{k}^{-1} + \mathbf m_{k}\mathbf m_{k}^{\rm T} \\
\mathbb{E}\left\lbrack \boldsymbol \Lambda_{k}\right\rbrack
&=&
    \nu_{k} \mathbf W_{k} \\
\mathbb{E}\left\lbrack 
    \ln\left| \boldsymbol \Lambda_{k} \right|
    \right\rbrack
&=&
    \psi _{D}\left({\frac {\nu_{k}}{2}}\right)+D\,\ln(2)+\ln |\mathbf {W}_{k} | \\
&=&
    \sum_{i=1}^D \psi \left({\frac {\nu_{k} + 1 - i}{2}}\right)+D\,\ln(2)+\ln |\mathbf {W}_{k} | \\
\end{eqnarray}

一つずつやっつける.



\begin{eqnarray}
\mathbb{E}\left\lbrack 
    \ln\left| \boldsymbol \Lambda_{k} \right|
    \right\rbrack
&=&
    \sum_{i=1}^D \psi \left({\frac {\nu_{k} + 1 - i}{2}}\right)+D\,\ln(2)+\ln |\mathbf {W}_{k} | \\
&{=:}&
    \ln \tilde {\boldsymbol \Lambda}_{k}
\end{eqnarray}

つぎ.



\begin{eqnarray}
\mathbb{E}_{\boldsymbol \mu_{k},\boldsymbol \Lambda_{k}}\left\lbrack
        \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)^{\rm T}
        \boldsymbol \Lambda_{k}
        \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)
        \right\rbrack
&=&
    \int \int
        \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)^{\rm T}
        \boldsymbol \Lambda_{k}
        \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)
    q^{\ast}(\boldsymbol \mu_{k} \mid \boldsymbol \Lambda_{k})
    q^{\ast}(\boldsymbol \Lambda_{k})
    \operatorname d \boldsymbol \mu_{k}
    \operatorname d \boldsymbol \Lambda_{k} \\
&=&
    \int \int
    \operatorname {tr} \left(
        \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)^{\rm T}
        \boldsymbol \Lambda_{k}
        \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)
    \right)
    q^{\ast}(\boldsymbol \mu_{k} \mid \boldsymbol \Lambda_{k})
    q^{\ast}(\boldsymbol \Lambda_{k})
    \operatorname d \boldsymbol \mu_{k}
    \operatorname d \boldsymbol \Lambda_{k} \\
&=&
    \int \int
    \operatorname {tr} \left(
        \boldsymbol \Lambda_{k}
        \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)
        \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)^{\rm T}
    \right)
    q^{\ast}(\boldsymbol \mu_{k} \mid \boldsymbol \Lambda_{k})
    q^{\ast}(\boldsymbol \Lambda_{k})
    \operatorname d \boldsymbol \mu_{k}
    \operatorname d \boldsymbol \Lambda_{k} \\
&=&
    \int
    \operatorname {tr} \left(
        \boldsymbol \Lambda_{k}
         \int \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)
        \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)^{\rm T}
        q^{\ast}(\boldsymbol \mu_{k} \mid \boldsymbol \Lambda_{k})
        \operatorname d \boldsymbol \mu_{k}
    \right)
    q^{\ast}(\boldsymbol \Lambda_{k})
    \operatorname d \boldsymbol \Lambda_{k} \\
\end{eqnarray}

ここで,



\begin{eqnarray}
\int \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)
    \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)^{\rm T}
    q^{\ast}(\boldsymbol \mu_{k} \mid \boldsymbol \Lambda_{k})
    \operatorname d \boldsymbol \mu_{k}
&=&
    \int
    \left(
        \mathbf x_{n} \mathbf x_{n}^{\rm T}
        - \mathbf x_{n} \boldsymbol \mu_{k}^{\rm T}
        - \boldsymbol \mu_{k} \mathbf x_{n}^{\rm T}
        + \boldsymbol \mu_{k} \boldsymbol \mu_{k}^{\rm T}
    \right)
    q^{\ast}(\boldsymbol \mu_{k} \mid \boldsymbol \Lambda_{k})
    \operatorname d \boldsymbol \mu_{k} \\
&=&
    \left(\mathbf x_{n} - \textbf m_{k} \right)
    \left(\mathbf x_{n} - \textbf m_{k} \right)^{\rm T}
    + \beta_{k}^{-1} \boldsymbol \Lambda_{k}^{-1}
    \\
\end{eqnarray}

なので,



\begin{eqnarray}
\mathbb{E}_{\boldsymbol \mu_{k},\boldsymbol \Lambda_{k}}\left\lbrack
        \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)^{\rm T}
        \boldsymbol \Lambda_{k}
        \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)
        \right\rbrack
&=&
    \int
    \operatorname {tr} \left(
        \boldsymbol \Lambda_{k}
         \int \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)
        \left(\mathbf x_{n} - \boldsymbol \mu_{k} \right)^{\rm T}
        q^{\ast}(\boldsymbol \mu_{k} \mid \boldsymbol \Lambda_{k})
        \operatorname d \boldsymbol \mu_{k}
    \right)
    q^{\ast}(\boldsymbol \Lambda_{k})
    \operatorname d \boldsymbol \Lambda_{k} \\
&=&
    \int
    \operatorname {tr} \left(
        \boldsymbol \Lambda_{k}
        \left(\mathbf x_{n} - \textbf m_{k} \right)
        \left(\mathbf x_{n} - \textbf m_{k} \right)^{\rm T}
        + \boldsymbol \Lambda_{k} \beta_{k}^{-1} \boldsymbol \Lambda_{k}^{-1}
    \right)
    q^{\ast}(\boldsymbol \Lambda_{k})
    \operatorname d \boldsymbol \Lambda_{k} \\
&=&
    \int
    \operatorname {tr} \left(
        \left(\mathbf x_{n} - \textbf m_{k} \right)^{\rm T}
        \boldsymbol \Lambda_{k}
        \left(\mathbf x_{n} - \textbf m_{k} \right)
    \right)
    + \operatorname {tr} \left(
        \beta_{k}^{-1}
        \boldsymbol \Lambda_{k}
        \boldsymbol \Lambda_{k}^{-1}
    \right)
    q^{\ast}(\boldsymbol \Lambda_{k})
    \operatorname d \boldsymbol \Lambda_{k} \\
&=&
    \int \left\lbrace
        \left(\mathbf x_{n} - \textbf m_{k} \right)^{\rm T}
            \boldsymbol \Lambda_{k}
            \left(\mathbf x_{n} - \textbf m_{k} \right)
        + D \beta_{k}^{-1}
    \right \rbrace
    q^{\ast}(\boldsymbol \Lambda_{k})
    \operatorname d \boldsymbol \Lambda_{k} \\
&=&
    \nu_{k} \left(\mathbf x_{n} - \textbf m_{k} \right)^{\rm T}
        \mathbf W_{k}
        \left(\mathbf x_{n} - \textbf m_{k} \right)
    + D \beta_{k}^{-1}
\end{eqnarray}

最後. WikipediaのDirichle分布より,



\begin{eqnarray}
\mathbb{E}\left\lbrack \operatorname {ln} \pi_{k} \right\rbrack
&=&
    \psi (\alpha _{k})-\psi \left(\textstyle \sum _{k}\alpha _{k}\right) \\
&{=:}&
    \ln \tilde \pi_{k}
\end{eqnarray}

 \rho_{nk} に戻す.



\begin{eqnarray}
\ln \rho_{nk}
&=&
    \frac {1} {2} \ln \tilde {\boldsymbol \Lambda}_{k}
    - \frac  {D} {2} \operatorname {ln} \left( 2\pi\right)
    - \frac {1} {2}\left(
        \nu_{k} \left(\mathbf x_{n} - \textbf m_{k} \right)^{\rm T}
            \mathbf W_{k}
            \left(\mathbf x_{n} - \textbf m_{k} \right)
        + D \beta_{k}^{-1}
    \right) + \ln \tilde \pi_{k}
\end{eqnarray}

mk, Wk, nuk, bk, ak, r, rho

これを正規化すると \mathbb E \left\lbrack z_{nk} \right \rbrack = r_{nk} = \frac {\rho_{nk}} {\sum_{j=1}^{K} \rho_{nj}} が得られる.

できた. 計算上は \alpha_{k},\;\boldsymbol \mu_{k},\;\beta_{k},\; \mathbf W_{k},\;\nu_{k}に初期値を与えて, r_{nk} (\rho_{nk}) の更新と \alpha_{k},\;\boldsymbol \mu_{k},\;\beta_{k},\; \mathbf W_{k},\;\nu_{k} の更新を順繰りに回す.

事後予測分布

最後に  \hat {\mathbf x} とそれに対応する潜在変数  \hat {\mathbf z} を用いて事後予測分布を導出する.



\begin{eqnarray}
p(\widehat {\mathbf x} \mid \mathbf X) &=& \sum_{\widehat {\mathbf z}} \int \int \int
    p\left(\widehat {\mathbf x} \mid \widehat {\mathbf z}, \boldsymbol \mu, \boldsymbol \Lambda \right)
    p\left(\widehat {\mathbf z} \mid \boldsymbol \pi \right)
    p\left(\boldsymbol \pi, \boldsymbol \mu, \boldsymbol \Lambda \mid \mathbf X\right)
    \operatorname d\boldsymbol \pi
    \operatorname d\boldsymbol \mu
    \operatorname d\boldsymbol \Lambda \\
&=&
    \sum_{k=1}^K \int \int \int
    \pi_{k}
    p\left(\widehat {\mathbf x} \mid \boldsymbol \mu_{k}, \boldsymbol \Lambda_{k} \right)
    p\left(\boldsymbol \pi, \boldsymbol \mu, \boldsymbol \Lambda \mid \mathbf X\right)
    \operatorname d\boldsymbol \pi
    \operatorname d\boldsymbol \mu
    \operatorname d\boldsymbol \Lambda \\
&\simeq& 
    \sum_{k=1}^K \int \int
    \left(\int
    \pi_{k}
    q\left(\boldsymbol \pi\right)
    \operatorname d\boldsymbol \pi
    \right)
    p\left(\widehat {\mathbf x} \mid \boldsymbol \mu_{k}, \boldsymbol \Lambda_{k} \right)
    q\left(\boldsymbol \mu_{k}, \boldsymbol \Lambda_{k}\right)
    \operatorname d\boldsymbol \mu
    \operatorname d\boldsymbol \Lambda \\
&=& 
    \sum_{k=1}^K
    \frac {a_{k}} {\sum_{n=1}^{K} a_n}
    \int \int
    p\left(\widehat {\mathbf x} \mid \boldsymbol \mu_{k}, \boldsymbol \Lambda_{k} \right)
    q\left(\boldsymbol \mu_{k}, \boldsymbol \Lambda_{k}\right)
    \operatorname d\boldsymbol \mu
    \operatorname d\boldsymbol \Lambda \\
&=&
    \sum_{k=1}^K
    \frac {a_{k}} {\sum_{n=1}^{K} a_n}
    \operatorname {St} \left(
        \widehat {\mathbf x}
        \middle|
        \boldsymbol \mu_k, 
        \frac {\beta_k\left(\nu_k + 1 - D\right)}{1+\beta_k}\textbf W_k,
        \nu_k + 1 - D
    \right)
\end{eqnarray}

 q(\boldsymbol \pi) は Dirichlet分布, normal-wishart分布の予測分布の導出は前回やっていて, t分布になる. \operatorname {St} のパラメータは順に, 平均, 精度, 自由度.

可視化

さて可視化する.

import numpy as np
import scipy as sp
from scipy import stats
from scipy import special

import pandas as pd
pd.set_option('display.width', 200)
import matplotlib.pyplot as plt
import matplotlib
plt.style.use('ggplot')
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

import os
import subprocess
import warnings
warnings.filterwarnings('ignore')

テストデータ

適当なテストデータ. 2つの正規分布の3:7の混合です. データ数は100. ちょっとくっついている.

np.random.seed(0)
mean0 = np.array([0, 0])
cov0 = np.array([[5,2],[2,1]])
mean1 = np.array([5, 7])
cov1 = np.array([[8,0],[0,7]])
N = 100
K = 2
a = 0.7
X = pd.DataFrame(
    np.r_[
        sp.random.multivariate_normal(mean0, cov0, int(N*(1-a))),
        sp.random.multivariate_normal(mean1, cov1, int(N*a))
    ],
    columns=['x', 'y']
)
X.plot(kind='scatter', x='x', y='y').set_aspect('equal')
plt.savefig('1.png', bbox_inches='tight', pad_inches=0)

f:id:kazufusa1484:20180803151609p:plain

事前分布

適当な事前分布.

x = X.values.T
D = 2

# 事前分布
b0 = 1
m0 = np.zeros((D, 1))
W0 = np.diag([0.01, 0.01])
nu0 = 1.1
a0 = 1

パラメータ変数と推定関数

パラメータ変数を定義する関数. 初期値も与える. mkの初期値が乱数なのは, これが2分布で同じ値だと事後分布も同じ分布になってしまうため. 別にWkの初期値を乱数で与えてもいいはず.

def init():
    global mk, Wk, nuk, bk, ak, r, rho
    mk = np.random.rand(K, D, 1)
    Wk = np.array([np.diag([0.01, 0.01]) for _ in np.arange(K)])
    nuk = np.repeat(1.1, K)
    bk = np.ones(K)
    ak = np.ones(K)

    r = np.zeros((N, K))
    rho = np.zeros((N, K))
# init()

r_{nk} を更新する関数.

# mk, Wk, bk, nukの更新
def step1():
    for k in np.arange(K):
        ln_Lambda_tilde_k = ( \
            np.sum(sp.special.digamma((nuk[k]+1-i)/2) for i in np.arange(1, D+1)) \
            + D * np.log(2) \
            + np.log(np.linalg.det(Wk[k])) \
        )
        ln_pi_tilde_k = sp.special.digamma(ak[k]) - sp.special.digamma(ak.sum())
        for n in np.arange(N):
            rho[n, k] = \
                1/2*ln_Lambda_tilde_k \
                - D / 2 * np.log(2*np.pi) \
                - 1/2 * nuk[k] * ((x[:, n:n+1] - mk[k]).T@Wk[k]@(x[:, n:n+1] - mk[k]))[0, 0] \
                - 1/2 * D / bk[k] \
                + ln_pi_tilde_k
            rho[n, k] = np.exp(rho[n, k])
    for n in np.arange(N):
        r[n, :] = rho[n, :] / rho[n, :].sum()

# init()
# step1()

\boldsymbol \alpha_{k}, \mathbf m_{k}, \mathbf W_{k}, \boldsymbol \beta_{k}, \boldsymbol \nu_{k} を更新する関数.

# rの更新
def step2():
    Nk = r.sum(axis=0)
    for k in range(K):
        xbark = x@r[:, k:k+1] / Nk[k]
        Sk = 0
        for n in range(N):
            Sk += r[n, k] * (x[:,n:n+1] - xbark)@(x[:,n:n+1] - xbark).T
        Sk /= Nk[k]

        mk[k] = (Nk[k] * xbark + b0 * m0) / (Nk[k] + b0)
        bk[k] = Nk[k] + b0
        Wkinv = np.linalg.inv(W0) + (Nk[k]*b0)/(Nk[k]+b0) * (xbark - m0)@(xbark - m0).T + Nk[k] * Sk
        Wk[k] = np.linalg.inv(Wkinv)
        nuk[k] = Nk[k] + nu0

        ak[k] = Nk[k] + a0

# init()
# step1()
# step2()

推定の実施

200回の繰り返しを実施. 適当にPythonで書いてもとても速い. 同じ問題をstanで解いても数10秒はかかると思う.

%%time
init()
ak_history = pd.DataFrame(
    [[step1(), step2(), ak.copy()][-1] for _ in range(200)],
    columns=[f'Gr.{x}' for x in range(K)]
)
CPU times: user 893 ms, sys: 342 ms, total: 1.24 s
Wall time: 653 ms

データ数の分類のそれぞれの個数と平均. 真の分布と大体同じ.

print(f"混合比: {ak / ak.sum()}")
[print(f"クラスタ{i}の平均: {mk[i][0][0]:3.5f}, {mk[i][1][0]:3.5f}") for i in range(K)];
混合比: [0.67219215 0.32780785]
クラスタ0の平均: 4.80261, 7.41357
クラスタ1の平均: 0.23764, 0.36730

ステップが進行するにつれて推定がよくなってくれることを確認するために, 混合比の履歴を可視化. 収束している感は出ている.

ax = ak_history.plot(title='Ratio history')
ax.set_ylim(0, 100)
ax.set_xlabel('Iteration')
ax.set_ylabel('ratio')
plt.savefig('2.png', bbox_inches='tight', pad_inches=0);

f:id:kazufusa1484:20180803151632p:plain

事後分布のZを可視化. 綺麗に分類できていることがわかる.

P = X.copy()
P['group'] = r.argmax(axis=1)
ax = P[P['group']==0].plot(kind='scatter', x='x', y='y', c='r', label='Gr.0')
for k in range(1, K):
    P[P['group']==k].plot(kind='scatter', x='x', y='y', label=f'Gr.{k}', c=colors[k], ax=ax).set_aspect('equal')
plt.savefig('3.png', bbox_inches='tight', pad_inches=0)

f:id:kazufusa1484:20180803151644p:plain

事後予測分布を可視化.

def t_pdf(x, m, Lambda, df):
    D = len(m)
    m = m.reshape(D, 1)
    x = x.T

    return np.diag(
        sp.special.gamma(D/2. + df/2.) /
        sp.special.gamma(df/2.) *
        np.power(np.linalg.det(Lambda), 1/2.) / 
        np.power(np.pi * df, D/2.) *
        np.power(1 + (x-m).T@Lambda@(x-m)/df, -(D+df)/2.)
    )
fig, ax = plt.subplots(ncols=1, figsize=(9, 7))

ax.set_aspect('equal')
ax.set_title('posterior predictive distribution')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_aspect('equal')

delta = 0.25
xrange = np.array([-10, 15])
yrange = np.array([-10, 15])
gx = np.arange(*(xrange + delta/2), delta)
gy = np.arange(*(yrange + delta/2), delta)
gxx, gyy = np.meshgrid(gx, gy)
gxgy = np.c_[gxx.ravel(), gyy.ravel()]

# parameters of posterior predictive distribution
zz = ak[0]/ak.sum() * t_pdf(
    gxgy,
    mk[0],
    bk[0] * (nuk[0] + 1 - D) * Wk[0] / (1 + bk[0]),
    (nuk[0] + 1 - D)
) + ak[1]/ak.sum() * t_pdf(
    gxgy,
    mk[1],
    bk[1] * (nuk[1] + 1 - D) * Wk[1] / (1 + bk[1]),
    (nuk[1] + 1 - D)
) 

im = ax.imshow(
    zz.reshape(len(gx), len(gy)), 
    interpolation='none', 
    origin='lower',
    extent=list(xrange) + list(yrange),
    cmap=matplotlib.cm.rainbow
)
fig.colorbar(im, ax=ax).set_label('probability')
X.plot(kind='scatter', x='x', y='y', c='black', ax=ax, label='data', alpha=0.5)
plt.savefig('4.png', bbox_inches='tight', pad_inches=0);

f:id:kazufusa1484:20180803151703p:plain

事後予測分布を描くのも簡単である. おわり.

その他

Variational inferenceは速い上に事後分布が共役事前分布と同じ形になるのでとても便利である. ギブスサンプリングだとこうはいかない.

例えば, 予測分布が簡単に計算できるので, データD次元+データの分類N次元の教師データを使って混合正規分布モデルなりで事後分布を得れば, そこから任意データに対して分類N次元の予測分布を簡単に計算できる. 文字認識とかできそうである. 一度はやってみたい文字認識.

そういえば事後分布が事前分布と同じ形ということは, 逐次的にデータが増える環境において, 事後分布を次の事前分布にして逐次学習とかできそうである. 一度はやってみたいオンライン機械学習.

比較用に同じモデル, 事前分布でStanでサンプリングしたのだが, W0がゆるすぎるのか, 片方の分布の分散共分散がぶっとんでしまう. W0が[[0.1, 0], [0, 0.1]]だとよく推定できた. ただしこの場合でも, 混合分布の推定は分布の添字が入れ替わるので出来合いのrhatのスコアが悪く気持ち悪い. うまい方法はないだろうか.

variational inferenceには初期値依存性がある, ハズである. 今回の例では2分布のパラメータの初期値が同じだと事後分布も同じになる. 混合比は0.5となる. 初期値に乱数をふって複数回の推定が必要なのではないか? その場合はどの初期値での事後分布がいいのか比較と取捨選択が必要になる. どうするのだろう. この辺がわからないのはそもそもvariational inferenceの仕組みを理解していないからだ.

おまけ

こんなデータ.

def mean():
    return np.random.uniform(-10,10,2)

def cov():
    cov = np.random.uniform(-4,4,4)
    cov[0] = np.abs(cov[0])
    cov[3] = np.abs(cov[3])
    cov[2] = cov[1]
    return cov.reshape((2,2))

np.random.seed(0)
N = 200
K = 4
pi = [0.1, 0.2, 0.3, 0.4]
X = pd.DataFrame(
    np.vstack([sp.random.multivariate_normal(mean(), cov(), int(N*pi[i])) for i in range(K)]),
    columns=['x', 'y'])

X.plot(kind='scatter', x='x', y='y').set_aspect('equal')
plt.savefig('5.png', bbox_inches='tight', pad_inches=0)

f:id:kazufusa1484:20180803151717p:plain

x = X.values.T

# 事前分布を変更
W0 = np.diag([1, 1])
%%time
init()
ak_history = pd.DataFrame(
    [[step1(), step2(), ak.copy()][-1] for _ in range(200)],
    columns=[f'Gr.{x}' for x in range(K)]
)
print(f"混合比: {ak / ak.sum()}")
[print(f"クラスタ{i}の平均: {mk[i][0][0]:3.5f}, {mk[i][1][0]:3.5f}") for i in range(K)];
混合比: [0.3970738  0.29961274 0.2011257  0.10218777]
クラスタ0の平均: -8.48932, 4.14327
クラスタ1の平均: 9.38159, 6.79947
クラスタ2の平均: 2.76510, -4.92961
クラスタ3の平均: 1.01452, 4.21432
CPU times: user 3.37 s, sys: 1.11 s, total: 4.48 s
Wall time: 2.26 s
P = X.copy()
ax = plt.subplot()
P['group'] = r.argmax(axis=1)
for k in range(K):
    if P[P['group']==k].empty: continue
    P[P['group']==k].plot(kind='scatter', x='x', y='y', label=f'Gr.{k}', c=colors[k], ax=ax).set_aspect('equal')
plt.savefig('6.png', bbox_inches='tight', pad_inches=0)

f:id:kazufusa1484:20180803151731p:plain