Easy to type

個人的な勉強の記録です。データ分析、可視化などをメイントピックとしています。

StanのThreading動作確認

TL; DR

導入

PyStanは2.18.0系からStanに搭載されたThreading機能を使うことができるようになりました。公式のドキュメントはこちらになります*1

この機能は、今までのMCMC chainごとにスレッドをフォークして並列化するのではなく、MCMCサンプル内での並列化です。先程のリンク先のExampleの欄にも書かれているとおりに、chain数×設定スレッド数分の並列化ができるように成りました。ただし、まだかなり作りかけの機能であるため、運用上は幾つか疑問点があります。

  • 並列化することでどれほど早くなるのか?
  • 単一のMCMC chainの場合でも並列化されるのか?
  • 並列化したことで推論結果が悪くなることはあるのか?
  • 設定スレッドはモデルのコンパイル時のみ影響するのか?サンプリング時にも影響するのか?

以上の疑問を解決するために、モデルを実際に動かしてみました。

手法

環境

次のとおりです。

  • pystan 2.18.0
  • python 3.6.5
  • CPU数 32
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from scipy.stats import norm
from pystan import StanModel
import pystan

print(pystan.__version__)
print(os.cpu_count())

モデル

Exampleに記載されているeight-school problemのデータとモデルをそのまま利用します。

model_code = """
functions {
  vector bl_glm(vector mu_sigma, vector beta,
                real[] x, int[] y) {
    vector[2] mu = mu_sigma[1:2];
    vector[2] sigma = mu_sigma[3:4];
    real lp = normal_lpdf(beta | mu, sigma);
    real ll = bernoulli_logit_lpmf(y | beta[1] + beta[2] * to_vector(x));
    return [lp + ll]';
  }
}
data {
  int<lower = 0> K;
  int<lower = 0> N;
  vector[N] x;
  int<lower = 0, upper = 1> y[N];
}
transformed data {
  int<lower = 0> J = N / K;
  real x_r[K, J];
  int<lower = 0, upper = 1> x_i[K, J];
  {
    int pos = 1;
    for (k in 1:K) {
      int end = pos + J - 1;
      x_r[k] = to_array_1d(x[pos:end]);
      x_i[k] = y[pos:end];
      pos += J;
    }
  }
}
parameters {
  vector[2] beta[K];
  vector[2] mu;
  vector<lower=0>[2] sigma;
}
model {
  mu ~ normal(0, 2);
  sigma ~ normal(0, 2);
  target += sum(map_rect(bl_glm, append_row(mu, sigma),
                         beta, x_r, x_i));
}
"""

stan_data = dict(
    K = 4,
    N = 12,
    x = [1.204, -0.573, -1.35, -1.157,
         -1.29, 0.515, 1.496, 0.918,
         0.517, 1.092, -0.485, -2.157],
    y = [1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1]
)

パラメータ

モデルは次の4つをコンパイルしました。

  • extra_args等の一切を無指定(以下Vanillaと呼称)
  • extra_argsを指定 + STAN_NUM_THREADS=1と指定(THREADS1と呼称)
  • extra_argsを指定 + STAN_NUM_THREADS=4と指定(THREADS4と呼称)
  • extra_argsを指定 + STAN_NUM_THREADS=8と指定(THREADS8と呼称)

これとは別に以下のパラメータをサンプリング時に指定しました。

  • サンプリング直前にSTAN_NUM_THREADSを再設定
  • n_jobs(公式APIでは"Sample in parallel. If -1 all CPUs are used. If 1, no parallel computing code is used at all, which is useful for debugging."と書かれているが、これがThreadingに影響するのか、MCMC chainだけに影響するのか不透明)
  • iter(MCMCの反復数)
  • chains(MCMCのchain数)

実験

上述のパラメータを設定した上で、jupyter labの%timeitオプションを付けて実行しました。各パラメータセットごとに7回反復されて、平均が取られました。notebookはこちらです。

結果

index compile num_threads sampling num_threads n_jobs iter chains avarage time std time
1 Vanilla 1 4 2000 2 0.883 0.121
2 Vanilla 1 2 2000 2 0.834 0.0624
3 Vanilla 1 4 2000 2 0.935 0.181
4 Vanilla 1 1 2000 2 1.21 0.413
5 Vanilla 1 1 2000 1 0.554 0.0567
6 Vanilla 1 1 20000 1 4.36 0.296
7 Vanilla 1 8 2000 2 0.82 0.0896
8 1 1 4 2000 2 0.936 0.0743
9 1 1 2 2000 2 0.967 0.101
10 1 1 4 2000 2 0.937 0.0693
11 1 1 1 2000 2 1.27 0.0501
12 1 1 1 2000 1 0.678 0.0833
13 1 1 1 20000 1 5.34 0.804
14 1 1 8 2000 2 0.962 0.0777
15 4 4 4 2000 2 3.63 0.547
16 4 4 2 2000 2 3.54 0.434
17 4 1 4 2000 2 0.99 0.0883
18 4 4 1 2000 2 5.4 0.576
19 4 4 1 2000 1 2.68 0.216
20 4 4 1 20000 1 21.1 1.11
21 4 4 8 2000 2 3.27 0.274
22 8 8 4 2000 2 3.28 0.456
23 8 8 2 2000 2 3.43 0.263
24 8 1 4 2000 2 0.865 0.13
25 8 8 1 2000 2 4.91 0.434
26 8 8 1 2000 1 2.52 0.314
27 8 8 1 20000 1 22.8 4.12
28 8 8 8 2000 2 3.31 0.341

average timeとstd timeの単位は秒です。

  • VanillaはSTAN_NUM_THREADSが1の場合と殆ど変わりません。中途処理があるからか、0.1秒ほど遅くなっています(Compare 1, 2, 7... vs 8, 9, ..., 14)。

  • n_jobsはchain数より大きい値を入れても早くなりません(Compare 15 vs 16等)。

    • 一方で、chain数より小さな値を入れれば、並列でMCMCできなくなるので遅くなります(Compare 15 vs 18等)。
    • また、STAN_NUM_THREADSのパフォーマンスにも影響しないようです(Compare 22 vs 23)。
  • 環境変数に設定したSTAN_NUM_THREADSは、サンプリング時にも影響するようです(Compare 15 vs 17等)。

    • コンパイル時の値は上限スレッド数というだけで、実効値は環境変数を参照するのでしょう。
    • 即ち、SGEなどのqueuing systemを使っているときには、サンプリング時に環境変数をセットする必要があります。
  • 推論結果は、見た感じ悪化しませんでした(実行結果を参照)。

  • 一番大事なことですが、 Threadingをしても遅くなるだけです (Comapare 2 vs 9 vs 16 vs 23)。

    • モデルが単純でデータが少ないからでしょうか?iter数を増やしてみればThreadingの恩恵があるかと思いましたが、傾向は変わりませんでした(Compare 6 vs 13 vs 20 vs 27)。
    • ぜひどなたかに、複雑なモデルで実験してほしいです(自分には適当なものが思いつきませんでした)。
    • %timeitがマルチスレッドの場合は時間×スレッド数をはじき出すかと思い、timeモジュール+自分の時間間隔でも計測しましたが、やはり遅かったです。
  • サンプリング時にhtopコマンドでCPU利用数を見ていましたが、高々200%ぐらいになるだけで、本来理想としていた800%等になることはありませんでした。


追記(2018-10-10)

@nan_makersstat さんや id:StatModeling (@hankagosa)さんから次のような指摘をいただきました。

というわけで、軽く実験してみます。モデルは簡単な一次回帰です。シミュレーション的にデータを10万点作成します。 注: 以下のモデルはThreadingで必要なmap_rect関数が使われていないので早くなるはずはありません。map_rect用のモデルは後述しました(2018-10-19)

model_code = """
  data {
    int I ;
    real x[I] ;
    real y[I] ;
  }
  
  parameters {
    real a ;
    real b ;
  }
  
  model {
    for (i in 1:I){
      y[i] ~ normal(a * x[i], b) ;
    }
  }
  
  generated quantities {
    real log_likelihood[I] ;
    for (i in 1:I){
      log_likelihood[i] = normal_lpdf(y[i] | a * x[i], b) ;
    }
  }
"""

I = 100000
a = 1
b = 1
x = np.linspace(-1, 1, I)
y = norm.rvs(loc=a*x, scale=b)

stan_data = {
    "I": I,
    "x": x,
    "y": y
}

これをVanilla, STAN_NUM_THREADS=1, STAN_NUM_THREADS=4で投入してみました。計測方法は同じで7回の平均です。

index compile num_threads sampling num_threads n_jobs iter chains avarage time std time
1 Vanilla 1 2 100000 2 111 5.5
2 1 1 2 100000 2 113 5.65
3 4 4 2 100000 2 108 5.71

うーん、ちょっとだけ早くなっています。しかしt検定で有意差が出る気がしません。おそらくデータ数を増やせばより効果があるのではないでしょうか。 何れにせよ、小標本なのにThreadingが必要なほどサンプリングで時間がかかりすぎている場合は、モデルが適切でない可能性をまず考えるべきです。丁寧に、メカニズムを考えて、モデルを改良していきましょう!


追記(2018-10-18)

discourseで質問したところ、ThreadingやMPIのような操作をするためには、データやモデルをそれ専用に作り変える必要があるとの回答を頂きました。 まだ試せていませんが、cmdstanを使った例についてはこのリポジトリが参考になります。ご参考までに。


追記(2018-10-19)

map_rect用のモデルを別に書いて実験しました。ついでにデータ数は合成に100万まで増やしてみました。実験につかったnotebookはこちらです。

index compile num_threads sampling num_threads n_jobs iter chains avarage time std time
1 Vanilla 1 1 500 1 128 6.71
2 4 4 4 500 1 173 5.71

なんと、やっぱり遅くなってしまいました。しかし、やはりtop等のコマンドで確認してもCPU使用率が100%までいかないんですよね。別の原因がある気がしますが、現状不明です。

*1:この機能は単体のマシンでの並列化であり、MPI 環境を使った並列化はまた別に実装されています