StanのThreading動作確認
TL; DR
Stanのmap_rect使うchain内並列化は、たぶんデータ点が10万ぐらいはないとご利益ないような…。その10万個を切って並列処理する感じ。
— Kentaro Matsuura (@hankagosa) 2018年10月10日
導入
PyStanは2.18.0系からStanに搭載されたThreading機能を使うことができるようになりました。公式のドキュメントはこちらになります*1。
この機能は、今までのMCMC chainごとにスレッドをフォークして並列化するのではなく、MCMCサンプル内での並列化です。先程のリンク先のExampleの欄にも書かれているとおりに、chain数×設定スレッド数分の並列化ができるように成りました。ただし、まだかなり作りかけの機能であるため、運用上は幾つか疑問点があります。
- 並列化することでどれほど早くなるのか?
- 単一のMCMC chainの場合でも並列化されるのか?
- 並列化したことで推論結果が悪くなることはあるのか?
- 設定スレッドはモデルのコンパイル時のみ影響するのか?サンプリング時にも影響するのか?
以上の疑問を解決するために、モデルを実際に動かしてみました。
手法
環境
次のとおりです。
- pystan 2.18.0
- python 3.6.5
- CPU数 32
import numpy as np import pandas as pd import matplotlib.pyplot as plt import os from scipy.stats import norm from pystan import StanModel import pystan print(pystan.__version__) print(os.cpu_count())
モデル
Exampleに記載されているeight-school problemのデータとモデルをそのまま利用します。
model_code = """ functions { vector bl_glm(vector mu_sigma, vector beta, real[] x, int[] y) { vector[2] mu = mu_sigma[1:2]; vector[2] sigma = mu_sigma[3:4]; real lp = normal_lpdf(beta | mu, sigma); real ll = bernoulli_logit_lpmf(y | beta[1] + beta[2] * to_vector(x)); return [lp + ll]'; } } data { int<lower = 0> K; int<lower = 0> N; vector[N] x; int<lower = 0, upper = 1> y[N]; } transformed data { int<lower = 0> J = N / K; real x_r[K, J]; int<lower = 0, upper = 1> x_i[K, J]; { int pos = 1; for (k in 1:K) { int end = pos + J - 1; x_r[k] = to_array_1d(x[pos:end]); x_i[k] = y[pos:end]; pos += J; } } } parameters { vector[2] beta[K]; vector[2] mu; vector<lower=0>[2] sigma; } model { mu ~ normal(0, 2); sigma ~ normal(0, 2); target += sum(map_rect(bl_glm, append_row(mu, sigma), beta, x_r, x_i)); } """ stan_data = dict( K = 4, N = 12, x = [1.204, -0.573, -1.35, -1.157, -1.29, 0.515, 1.496, 0.918, 0.517, 1.092, -0.485, -2.157], y = [1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1] )
パラメータ
モデルは次の4つをコンパイルしました。
- extra_args等の一切を無指定(以下Vanillaと呼称)
- extra_argsを指定 + STAN_NUM_THREADS=1と指定(THREADS1と呼称)
- extra_argsを指定 + STAN_NUM_THREADS=4と指定(THREADS4と呼称)
- extra_argsを指定 + STAN_NUM_THREADS=8と指定(THREADS8と呼称)
これとは別に以下のパラメータをサンプリング時に指定しました。
- サンプリング直前にSTAN_NUM_THREADSを再設定
- n_jobs(公式APIでは"Sample in parallel. If -1 all CPUs are used. If 1, no parallel computing code is used at all, which is useful for debugging."と書かれているが、これがThreadingに影響するのか、MCMC chainだけに影響するのか不透明)
- iter(MCMCの反復数)
- chains(MCMCのchain数)
実験
上述のパラメータを設定した上で、jupyter labの%timeit
オプションを付けて実行しました。各パラメータセットごとに7回反復されて、平均が取られました。notebookはこちらです。
結果
index | compile num_threads | sampling num_threads | n_jobs | iter | chains | avarage time | std time |
---|---|---|---|---|---|---|---|
1 | Vanilla | 1 | 4 | 2000 | 2 | 0.883 | 0.121 |
2 | Vanilla | 1 | 2 | 2000 | 2 | 0.834 | 0.0624 |
3 | Vanilla | 1 | 4 | 2000 | 2 | 0.935 | 0.181 |
4 | Vanilla | 1 | 1 | 2000 | 2 | 1.21 | 0.413 |
5 | Vanilla | 1 | 1 | 2000 | 1 | 0.554 | 0.0567 |
6 | Vanilla | 1 | 1 | 20000 | 1 | 4.36 | 0.296 |
7 | Vanilla | 1 | 8 | 2000 | 2 | 0.82 | 0.0896 |
8 | 1 | 1 | 4 | 2000 | 2 | 0.936 | 0.0743 |
9 | 1 | 1 | 2 | 2000 | 2 | 0.967 | 0.101 |
10 | 1 | 1 | 4 | 2000 | 2 | 0.937 | 0.0693 |
11 | 1 | 1 | 1 | 2000 | 2 | 1.27 | 0.0501 |
12 | 1 | 1 | 1 | 2000 | 1 | 0.678 | 0.0833 |
13 | 1 | 1 | 1 | 20000 | 1 | 5.34 | 0.804 |
14 | 1 | 1 | 8 | 2000 | 2 | 0.962 | 0.0777 |
15 | 4 | 4 | 4 | 2000 | 2 | 3.63 | 0.547 |
16 | 4 | 4 | 2 | 2000 | 2 | 3.54 | 0.434 |
17 | 4 | 1 | 4 | 2000 | 2 | 0.99 | 0.0883 |
18 | 4 | 4 | 1 | 2000 | 2 | 5.4 | 0.576 |
19 | 4 | 4 | 1 | 2000 | 1 | 2.68 | 0.216 |
20 | 4 | 4 | 1 | 20000 | 1 | 21.1 | 1.11 |
21 | 4 | 4 | 8 | 2000 | 2 | 3.27 | 0.274 |
22 | 8 | 8 | 4 | 2000 | 2 | 3.28 | 0.456 |
23 | 8 | 8 | 2 | 2000 | 2 | 3.43 | 0.263 |
24 | 8 | 1 | 4 | 2000 | 2 | 0.865 | 0.13 |
25 | 8 | 8 | 1 | 2000 | 2 | 4.91 | 0.434 |
26 | 8 | 8 | 1 | 2000 | 1 | 2.52 | 0.314 |
27 | 8 | 8 | 1 | 20000 | 1 | 22.8 | 4.12 |
28 | 8 | 8 | 8 | 2000 | 2 | 3.31 | 0.341 |
average timeとstd timeの単位は秒です。
VanillaはSTAN_NUM_THREADSが1の場合と殆ど変わりません。中途処理があるからか、0.1秒ほど遅くなっています(Compare 1, 2, 7... vs 8, 9, ..., 14)。
n_jobsはchain数より大きい値を入れても早くなりません(Compare 15 vs 16等)。
- 一方で、chain数より小さな値を入れれば、並列でMCMCできなくなるので遅くなります(Compare 15 vs 18等)。
- また、STAN_NUM_THREADSのパフォーマンスにも影響しないようです(Compare 22 vs 23)。
環境変数に設定したSTAN_NUM_THREADSは、サンプリング時にも影響するようです(Compare 15 vs 17等)。
推論結果は、見た感じ悪化しませんでした(実行結果を参照)。
一番大事なことですが、 Threadingをしても遅くなるだけです (Comapare 2 vs 9 vs 16 vs 23)。
- モデルが単純でデータが少ないからでしょうか?iter数を増やしてみればThreadingの恩恵があるかと思いましたが、傾向は変わりませんでした(Compare 6 vs 13 vs 20 vs 27)。
- ぜひどなたかに、複雑なモデルで実験してほしいです(自分には適当なものが思いつきませんでした)。
%timeit
がマルチスレッドの場合は時間×スレッド数をはじき出すかと思い、time
モジュール+自分の時間間隔でも計測しましたが、やはり遅かったです。
サンプリング時に
htop
コマンドでCPU利用数を見ていましたが、高々200%ぐらいになるだけで、本来理想としていた800%等になることはありませんでした。
追記(2018-10-10)
@nan_makersstat さんや id:StatModeling (@hankagosa)さんから次のような指摘をいただきました。
ちょうど仕事でMCMC高速化をすべく使っていて、全然速くならないなと思ってたら案の定「遅くなる」との結論だった…
— 📊🔨📝 (narrowly*) (@nan_makersstat) 2018年10月10日
Stanのmap_rect使うchain内並列化は、たぶんデータ点が10万ぐらいはないとご利益ないような…。その10万個を切って並列処理する感じ。
— Kentaro Matsuura (@hankagosa) 2018年10月10日
というわけで、軽く実験してみます。モデルは簡単な一次回帰です。シミュレーション的にデータを10万点作成します。 注: 以下のモデルはThreadingで必要なmap_rect関数が使われていないので早くなるはずはありません。map_rect用のモデルは後述しました(2018-10-19)
model_code = """ data { int I ; real x[I] ; real y[I] ; } parameters { real a ; real b ; } model { for (i in 1:I){ y[i] ~ normal(a * x[i], b) ; } } generated quantities { real log_likelihood[I] ; for (i in 1:I){ log_likelihood[i] = normal_lpdf(y[i] | a * x[i], b) ; } } """ I = 100000 a = 1 b = 1 x = np.linspace(-1, 1, I) y = norm.rvs(loc=a*x, scale=b) stan_data = { "I": I, "x": x, "y": y }
これをVanilla, STAN_NUM_THREADS=1, STAN_NUM_THREADS=4で投入してみました。計測方法は同じで7回の平均です。
index | compile num_threads | sampling num_threads | n_jobs | iter | chains | avarage time | std time |
---|---|---|---|---|---|---|---|
1 | Vanilla | 1 | 2 | 100000 | 2 | 111 | 5.5 |
2 | 1 | 1 | 2 | 100000 | 2 | 113 | 5.65 |
3 | 4 | 4 | 2 | 100000 | 2 | 108 | 5.71 |
うーん、ちょっとだけ早くなっています。しかしt検定で有意差が出る気がしません。おそらくデータ数を増やせばより効果があるのではないでしょうか。
何れにせよ、小標本なのにThreadingが必要なほどサンプリングで時間がかかりすぎている場合は、モデルが適切でない可能性をまず考えるべきです。丁寧に、メカニズムを考えて、モデルを改良していきましょう!
追記(2018-10-18)
discourseで質問したところ、ThreadingやMPIのような操作をするためには、データやモデルをそれ専用に作り変える必要があるとの回答を頂きました。 まだ試せていませんが、cmdstanを使った例についてはこのリポジトリが参考になります。ご参考までに。
追記(2018-10-19)
map_rect用のモデルを別に書いて実験しました。ついでにデータ数は合成に100万まで増やしてみました。実験につかったnotebookはこちらです。
index | compile num_threads | sampling num_threads | n_jobs | iter | chains | avarage time | std time |
---|---|---|---|---|---|---|---|
1 | Vanilla | 1 | 1 | 500 | 1 | 128 | 6.71 |
2 | 4 | 4 | 4 | 500 | 1 | 173 | 5.71 |
なんと、やっぱり遅くなってしまいました。しかし、やはりtop等のコマンドで確認してもCPU使用率が100%までいかないんですよね。別の原因がある気がしますが、現状不明です。
*1:この機能は単体のマシンでの並列化であり、MPI 環境を使った並列化はまた別に実装されています