Easy to type

個人的な勉強の記録です。データ分析、可視化などをメイントピックとしています。

はじめての 統計データ分析 ―ベイズ的〈ポストp値時代〉の統計学― その4.1

対応のあるt検定を行う前に、3章の内容を等分散モデルで書くとどうなるか、をちゃんと検証します。

番外編 3章の等分散モデルと変分ベイズによる推定

等分散モデルによる差の推定

といっても、Stanのモデルをちょっと変えるだけなのでした。

まずコンパイル

import os
import pystan
import pickle

# Stanのモデルを読み込んでコンパイルする
# 等分散モデル
stan_file = os.path.join("stan", "g2_ind_equ.stan")
stan_file_c = os.path.join("stan", "g2_ind_equ.pkl")
model = pystan.StanModel(file=stan_file)
with open(stan_file_c, "wb") as f:
    pickle.dump(model, f)

Stanのモデルはこんな感じです。

data {
    int<lower=0> n_a ;
    int<lower=0> n_b ;
    real<lower=0> a[n_a] ;
    real<lower=0> b[n_b] ;
    real<lower=0> c_mu_diff ;
    real<lower=0> c_es ;
    real<lower=0> c_cohenu ;
    real<lower=0> c_pod ;
    real<lower=0> c_pbt ;
    real<lower=0> cdash_pbt ;
}

parameters {
    real<lower=0> mu_a ;
    real<lower=0> sigma ;
    real<lower=0> mu_b ;
}

model {
    a ~ normal(mu_a, sigma) ;
    b ~ normal(mu_b, sigma) ;
}

generated quantities {
    vector[n_a] log_lik ;
    real mu_diff ;
    real es ;
    real cohenu ;
    real pod ;
    real pbt ;
    int<lower=0, upper=1> prob_mu_diff_upper_0 ;
    int<lower=0, upper=1> prob_mu_diff_upper_c ;
    int<lower=0, upper=1> prob_es_upper_c ;
    int<lower=0, upper=1> prob_cohenu_upper_c ;
    int<lower=0, upper=1> prob_pod_upper_c ;
    int<lower=0, upper=1> prob_pbt_upper_cdash ;

    for(i in 1:n_a){
        log_lik[i] = normal_lpdf(a[i] | mu_a, sigma) + normal_lpdf(b[i] | mu_b, sigma) ;
    }
    mu_diff = mu_a - mu_b ;
    es = mu_diff / sigma ;
    cohenu = normal_cdf(mu_a, mu_b, sigma) ;
    pod = normal_cdf(es / sqrt(2), 0, 1) ;
    pbt = normal_cdf((mu_diff - c_pbt) / ( sqrt(2) * sigma ), 0, 1) ;
    prob_mu_diff_upper_0 = mu_diff > 0 ? 1 : 0 ;
    prob_mu_diff_upper_c = mu_diff > c_mu_diff ? 1 : 0 ;
    prob_es_upper_c = es > c_es ? 1 : 0 ;
    prob_cohenu_upper_c = cohenu > c_cohenu ? 1 : 0 ;
    prob_pod_upper_c = pod > c_pod ? 1 : 0 ;
    prob_pbt_upper_cdash = pbt > cdash_pbt ? 1 : 0 ;
}

後はサンプリングを同様に行うだけ。

import pandas as pd
import pickle
import pystan
import matplotlib
import os
import matplotlib.pyplot as plt
from IPython.core.display import display
%matplotlib inline
matplotlib.rcParams['figure.figsize'] = (10, 50)
# 等分散モデル

# Stanで使うデータの用意
stan_data = {"n_a": a.size,
             "n_b": b.size,
             "a": a,
             "b": b,
             "c_mu_diff": 14, #標本平均の差を使ってみる
             "c_es": 3.0, # 標本効果量で、a群からみたもの
             "c_cohenu": 0.95, # a群から見た非重複度
             "c_pod": 0.95, # 優越率
             "c_pbt": 10, # 閾上率の基準値
             "cdash_pbt": 0.60} #閾上率
# 興味のあるパラメータの設定
pars = ["mu_a",
       "sigma",
       "mu_b",
       "log_lik",
       "mu_diff",
       "es",
       "cohenu",
       "pod",
       "pbt",
       "prob_mu_diff_upper_0",
       "prob_mu_diff_upper_c",
       "prob_es_upper_c",
       "prob_cohenu_upper_c",
       "prob_pod_upper_c",
       "prob_pbt_upper_cdash"]
prob = [0.025, 0.05, 0.25, 0.5, 0.75, 0.95, 0.975]

# モデルの読み込み
stan_file_c = os.path.join("stan", "g2_ind_equ.pkl")
with open(stan_file_c, "rb") as f:
    model = pickle.load(f)

# MCMCでサンプリング
fit = model.sampling(data=stan_data,
                     pars=pars,
                     iter=21000,
                     chains=5,
                     warmup=1000,
                     seed=1234,
                     algorithm="NUTS")


# 事後分布の表を取得
summary = fit.summary(pars=pars, probs=prob)
summary_df = pd.DataFrame(summary["summary"],
                          index=summary["summary_rownames"],
                          columns=summary["summary_colnames"])
display(summary_df)

# 事後分布の可視化
for par in summary_df.index[summary_df["sd"] == 0]:
    pars.remove(par)
fit.traceplot(pars)
plt.show()

# WAICの計算
log_lik = fit.extract("log_lik")["log_lik"]
waic = -2 * np.sum(np.log(np.mean(np.exp(log_lik), axis=0))) + 2 * np.sum(np.var(log_lik, axis=0))
logger.info("WAICの値は{0}です".format(waic))

これを実行すると、s.d.についてはEAPが8.63と推定されます。a群とb群のちょうど中間ぐらいですね。 両方の分布に同じ分散を使っているので、推定される値としては納得できるものです。 ただ尤度については716と不等分散モデルより高くなってしまいました。よってデータにはフィッティングしていないと捉えます。

変分ベイズによる推定

これまでは事後分布の推定を、NUTSアルゴリズムによるサンプリングで行ってきました。 ところでStanには、MCMCサンプリング以外にも変分ベイズによる分布の推定も一応実装されています。ついでにこれも試してみましょう。モデルのコンパイル部分までは、上述のものと共通です。

# サンプリングをADVIでやってみる

import pandas as pd
import pickle
import pystan
import matplotlib
import os
import matplotlib.pyplot as plt
from IPython.core.display import display
from collections import OrderedDict
from pystan.external.pymc import plots

%matplotlib inline
matplotlib.rcParams['figure.figsize'] = (10, 50)
# 等分散モデル

# Stanで使うデータの用意
stan_data = {"n_a": a.size,
             "n_b": b.size,
             "a": a,
             "b": b,
             "c_mu_diff": 14, #標本平均の差を使ってみる
             "c_es": 3.0, # 標本効果量で、a群からみたもの
             "c_cohenu": 0.95, # a群から見た非重複度
             "c_pod": 0.95, # 優越率
             "c_pbt": 10, # 閾上率の基準値
             "cdash_pbt": 0.60} #閾上率
# 興味のあるパラメータの設定
pars = ["mu_a",
       "sigma",
       "mu_b",
       "log_lik",
       "mu_diff",
       "es",
       "cohenu",
       "pod",
       "pbt",
       "prob_mu_diff_upper_0",
       "prob_mu_diff_upper_c",
       "prob_es_upper_c",
       "prob_cohenu_upper_c",
       "prob_pod_upper_c",
       "prob_pbt_upper_cdash"]
prob = [0.025, 0.05, 0.25, 0.5, 0.75, 0.95, 0.975]

# モデルの読み込み
stan_file_c = os.path.join("stan", "g2_ind_equ.pkl")
with open(stan_file_c, "rb") as f:
    model = pickle.load(f)

# MCMCでADVIサンプリング
fit = model.vb(data=stan_data,
               pars=pars,
               iter=100000,
               seed=1234)


# 事後分布の表を取得
# ADVIはAPIがまだ完成されていないので、summaryの表を作る方式が違う
# chainを利用したサンプリングもしていないため、Rhatも計算できない
vb_sample = pd.read_csv(fit["args"]["sample_file"].decode("utf-8"), comment="#")
vb_sample = vb_sample.drop("lp__", 1)
summary_df = vb_sample.describe(percentiles=prob).T
display(summary_df)

# カーネル密度推定出来ないパラメータの削除
for par in summary_df.index[summary_df["std"] == 0]:
    pars.remove(par)
## 事後分布の可視化
od = OrderedDict()
for i, par in enumerate(fit["sampler_param_names"]):
    par_s = par.split(".")
    if len(par_s) == 1:
        od[par] = np.array(fit["sampler_params"][i])
    else:
        par = par_s[0]
        if par in od.keys():
            od[par] = np.vstack([od[par], np.array(fit["sampler_params"][i])])
        else:
            od[par] = np.array(fit["sampler_params"][i])
plots.traceplot(od, pars)
plt.show()

# WAICの計算
log_lik = od["log_lik"] 
waic = -2 * np.sum(np.log(np.mean(np.exp(log_lik), axis=0))) + 2 * np.sum(np.var(log_lik, axis=0))
logger.info("WAICの値は{0}です".format(waic))

注意すべきところは、PyStanの変分ベイズAPIが完成されていないためsamplingとやり方がぜんぜん違う点です。 サンプリング結果の集計にはpandasのdescribeメソッドを使っています。

結果は次のとおりです。

count mean std min 2.5% 5% 25% 50% 75% 95% 97.5% max
mu_a 1001 56.4035 1.14076 52.6137 54.1746 54.6317 55.6578 56.3594 57.2304 58.3143 58.6232 59.8892
sigma 1001 7.8148 0.593153 5.79781 6.71698 6.90312 7.43062 7.78376 8.18896 8.82866 9.04493 10.1979
mu_b 1001 40.3831 1.19016 36.5435 38.1027 38.3704 39.5665 40.4121 41.1783 42.2879 42.7135 44.4097
log_lik.1 1001 -6.42276 0.16823 -7.06181 -6.76429 -6.6899 -6.53466 -6.42269 -6.31037 -6.15472 -6.10705 -5.86712
log_lik.2 1001 -6.07869 0.150718 -6.67179 -6.3649 -6.31169 -6.17576 -6.08422 -5.98092 -5.81884 -5.78335 -5.58263
log_lik.3 1001 -8.86784 0.466876 -10.957 -9.91909 -9.69209 -9.1477 -8.81625 -8.52498 -8.19867 -8.08506 -7.54151
log_lik.4 1001 -6.23111 0.153535 -6.66425 -6.54282 -6.48627 -6.33461 -6.2304 -6.12679 -5.96821 -5.93276 -5.71942
log_lik.5 1001 -6.0363 0.152214 -6.58138 -6.32926 -6.28372 -6.13988 -6.03141 -5.93468 -5.77738 -5.73873 -5.49402
log_lik.6 1001 -9.86775 0.610265 -12.5208 -11.2156 -11.0038 -10.2397 -9.81467 -9.41468 -9.01542 -8.86263 -8.08728
log_lik.7 1001 -6.96474 0.214733 -7.70444 -7.41328 -7.32943 -7.10373 -6.95857 -6.81509 -6.63545 -6.57965 -6.26646
log_lik.8 1001 -7.96041 0.353478 -10.3051 -8.72111 -8.56551 -8.16825 -7.93375 -7.71524 -7.44256 -7.35968 -7.04291
log_lik.9 1001 -6.55416 0.179014 -7.26925 -6.9163 -6.838 -6.66726 -6.55088 -6.42996 -6.27155 -6.21939 -6.00683
log_lik.10 1001 -6.20992 0.155187 -6.83945 -6.50715 -6.45275 -6.31807 -6.21339 -6.10742 -5.94728 -5.89824 -5.662
log_lik.11 1001 -7.74777 0.313827 -9.05947 -8.43614 -8.28447 -7.94685 -7.71843 -7.52812 -7.30075 -7.21452 -6.86988
log_lik.12 1001 -6.23652 0.159109 -6.74122 -6.5479 -6.49699 -6.34863 -6.22906 -6.12679 -5.9821 -5.94032 -5.73857
log_lik.13 1001 -6.24006 0.160162 -6.76222 -6.54558 -6.50324 -6.35045 -6.23568 -6.134 -5.98247 -5.93949 -5.70956
log_lik.14 1001 -6.3109 0.159837 -6.75209 -6.61471 -6.57248 -6.42249 -6.30973 -6.19668 -6.06357 -6.02272 -5.77449
log_lik.15 1001 -6.76248 0.194958 -7.7655 -7.14874 -7.08354 -6.89143 -6.75921 -6.62693 -6.45886 -6.40091 -6.21029
log_lik.16 1001 -9.50787 0.558317 -12.0117 -10.7875 -10.4896 -9.84995 -9.46429 -9.09205 -8.71395 -8.58059 -7.86643
log_lik.17 1001 -6.23197 0.159713 -6.77176 -6.55355 -6.51252 -6.33496 -6.22916 -6.12582 -5.97284 -5.91782 -5.47086
log_lik.18 1001 -10.937 0.792785 -16.4895 -12.6988 -12.3467 -11.3937 -10.8907 -10.4062 -9.72509 -9.5612 -8.87464
log_lik.19 1001 -6.55719 0.178717 -7.10137 -6.92829 -6.86605 -6.66928 -6.54786 -6.43631 -6.2802 -6.22669 -5.77744
log_lik.20 1001 -6.92278 0.21248 -7.83707 -7.3648 -7.28035 -7.0519 -6.91023 -6.78217 -6.60749 -6.55017 -6.33407
log_lik.21 1001 -7.55595 0.295133 -9.58726 -8.17106 -8.06674 -7.72853 -7.5413 -7.3452 -7.10152 -7.04768 -6.85252
log_lik.22 1001 -7.55767 0.281643 -9.09015 -8.19547 -8.05543 -7.72132 -7.53802 -7.36901 -7.14595 -7.075 -6.81468
log_lik.23 1001 -7.35302 0.270502 -8.89604 -7.91456 -7.82611 -7.51521 -7.32474 -7.17505 -6.95945 -6.88067 -6.60574
log_lik.24 1001 -6.94368 0.215861 -8.22401 -7.36714 -7.31273 -7.08547 -6.93466 -6.79251 -6.616 -6.54842 -6.35364
log_lik.25 1001 -6.71402 0.191982 -7.69412 -7.09788 -7.03066 -6.84066 -6.71253 -6.58261 -6.40948 -6.34514 -6.1755
log_lik.26 1001 -9.53694 0.584227 -13.4754 -10.843 -10.567 -9.86342 -9.49629 -9.13708 -8.68395 -8.54657 -8.1407
log_lik.27 1001 -6.78402 0.201555 -7.52225 -7.20148 -7.13041 -6.91297 -6.78015 -6.64539 -6.46899 -6.42453 -6.18105
log_lik.28 1001 -6.44432 0.167778 -6.96219 -6.78804 -6.72768 -6.55666 -6.44535 -6.32604 -6.18597 -6.13832 -5.85676
log_lik.29 1001 -7.69239 0.315538 -9.81127 -8.36852 -8.23455 -7.87711 -7.67732 -7.4738 -7.2188 -7.15845 -6.87702
log_lik.30 1001 -7.7927 0.319736 -9.13938 -8.50667 -8.34904 -7.99599 -7.76069 -7.5676 -7.33473 -7.24417 -6.90979
log_lik.31 1001 -6.16013 0.157118 -6.71291 -6.45853 -6.41783 -6.26426 -6.15737 -6.05708 -5.90701 -5.86343 -5.49712
log_lik.32 1001 -6.13017 0.151897 -6.74355 -6.41578 -6.37076 -6.23383 -6.13207 -6.0305 -5.87105 -5.83521 -5.59475
log_lik.33 1001 -7.00073 0.223585 -7.81207 -7.48043 -7.40014 -7.14758 -6.99534 -6.84432 -6.65199 -6.60361 -6.3401
log_lik.34 1001 -6.92585 0.215031 -7.65724 -7.38092 -7.30914 -7.06934 -6.91029 -6.78167 -6.59183 -6.54201 -6.25093
log_lik.35 1001 -6.78402 0.201555 -7.52225 -7.20148 -7.13041 -6.91297 -6.78015 -6.64539 -6.46899 -6.42453 -6.18105
log_lik.36 1001 -6.45221 0.172966 -7.0252 -6.79691 -6.752 -6.56508 -6.448 -6.33613 -6.17511 -6.12069 -5.69448
log_lik.37 1001 -6.43874 0.169585 -6.94675 -6.7849 -6.73926 -6.54679 -6.43412 -6.32283 -6.16711 -6.11841 -5.65075
log_lik.38 1001 -5.98581 0.149329 -6.52304 -6.27182 -6.22623 -6.0843 -5.98738 -5.88368 -5.74347 -5.69836 -5.48248
log_lik.39 1001 -6.22339 0.155309 -6.81162 -6.52915 -6.47324 -6.33043 -6.22686 -6.12094 -5.9492 -5.9203 -5.72975
log_lik.40 1001 -6.38489 0.163779 -6.94275 -6.72587 -6.64711 -6.49179 -6.38321 -6.27196 -6.11557 -6.0736 -5.90043
log_lik.41 1001 -6.66083 0.18857 -7.36905 -7.064 -6.97551 -6.77774 -6.65256 -6.5262 -6.36894 -6.30808 -6.14803
log_lik.42 1001 -6.03561 0.149328 -6.59567 -6.32253 -6.27471 -6.13574 -6.03972 -5.93328 -5.79028 -5.74892 -5.56119
log_lik.43 1001 -7.9404 0.350003 -10.3391 -8.68739 -8.53891 -8.1428 -7.92091 -7.69262 -7.41976 -7.33845 -7.05618
log_lik.44 1001 -11.5909 0.862142 -15.4103 -13.5076 -13.1583 -12.1299 -11.5137 -10.9734 -10.3584 -10.1675 -9.0104
log_lik.45 1001 -7.77097 0.315238 -9.01645 -8.47879 -8.3403 -7.97236 -7.74171 -7.5439 -7.3051 -7.22634 -6.98188
log_lik.46 1001 -7.6209 0.303068 -9.7381 -8.25764 -8.14321 -7.80149 -7.6036 -7.40669 -7.16314 -7.0904 -6.8879
log_lik.47 1001 -6.97629 0.217257 -7.94213 -7.42008 -7.33636 -7.10676 -6.9649 -6.82924 -6.64425 -6.58657 -6.38334
log_lik.48 1001 -6.66957 0.186988 -7.39538 -7.06688 -6.98032 -6.78594 -6.65865 -6.54712 -6.38723 -6.3303 -6.12713
log_lik.49 1001 -6.63225 0.187684 -7.31207 -7.0097 -6.94733 -6.75122 -6.63039 -6.50275 -6.33472 -6.2794 -6.0357
log_lik.50 1001 -6.32385 0.159621 -6.76867 -6.63689 -6.58101 -6.43788 -6.32294 -6.2061 -6.07091 -6.03085 -5.80358
mu_diff 1001 16.0204 1.66137 10.9645 12.6538 13.2746 14.9254 16.0157 17.0902 18.7446 19.2714 20.8654
es 1001 2.06108 0.259806 1.27133 1.58195 1.65403 1.87671 2.05535 2.23764 2.50222 2.59257 2.90178
cohenu 1001 0.977041 0.0139309 0.898194 0.94317 0.95094 0.969721 0.980078 0.987378 0.993829 0.995237 0.998145
pod 1001 0.924156 0.0256932 0.815664 0.868346 0.878915 0.907751 0.926937 0.943204 0.961581 0.966615 0.979909
pbt 1001 0.705715 0.0529771 0.538187 0.593225 0.61623 0.671596 0.708586 0.741193 0.786155 0.803224 0.849742
prob_mu_diff_upper_0 1001 1 0 1 1 1 1 1 1 1 1 1
prob_mu_diff_upper_c 1001 0.879121 0.32615 0 0 0 1 1 1 1 1 1
prob_es_upper_c 1001 0 0 0 0 0 0 0 0 0 0 0
prob_cohenu_upper_c 1001 0.951049 0.215874 0 0 1 1 1 1 1 1 1
prob_pod_upper_c 1001 0.152847 0.36002 0 0 0 0 0 0 1 1 1
prob_pbt_upper_cdash 1001 0.97003 0.17059 0 0 1 1 1 1 1 1 1

f:id:ajhjhaf:20170525183254p:plain