NumPyroは軽量な確率プログラミングライブラリです。PyroをバックエンドにしたNumPyを提供します。自動微分にJAX, JITコンパイラをCPU/GPUに用います。NumPyroは開発中です。そのため設計は発展するので、不安定さ、バグ、APIの変更に注意してください。
インストール
JAXの最新バージョンとNumPyroをインストールするには、pipが使えます。
pip install numpyro
NumPyroをGPUで使うためには、最初にcudaをインストールする必要があります。その後、以下のpipコマンドを使ってください。
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
ソースからNumPyroをインストールすることができます。
git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# install jax/jaxlib first for CUDA support
pip install -e .[dev] # contains additional dependencies for NumPyro development
condaからインストールすることもできます。
conda install -c conda-forge numpyro
例 8スクール
簡単な例を使って、NumPyroを探検しましょう。私たちは、Gelman et al., Bayesian Data Analysis:Sec 5.5 2003 から8スクールの例を使います。それは、8つのスクールのSATパフォーマンスのコーチングの結果を調査したものです。
データは、以下で与えられます。
>>> import numpy as np
>>> J = 8
>>> y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
>>> sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
ここで、y は処置の効果、そして、sigmaは、標準エラー。私たちは、研究のために仮想モデルを構築します。そこでは各スクールのためにグループレベルパラメータthetaは、未知の平均値 mu, 標準偏差 tau,の正規分布からサンプルされます。一方、観測データは順番に、thetaとsigmaで与えられる平均と標準偏差の正規分布から生成されます。これは、すべての観測値からの蓄えにより集団レベルのパラメータmuとtauを見積もることを私たちに許容します。それでもまだグループレベルのthetaパラメータを使って学校の間の個別の変化を許容します。
>>> import numpyro
>>> import numpyro.distributions as dist
>>> # Eight Schools example
... def eight_schools(J, sigma, y=None):
... mu = numpyro.sample('mu', dist.Normal(0, 5))
... tau = numpyro.sample('tau', dist.HalfCauchy(5))
... with numpyro.plate('J', J):
... theta = numpyro.sample('theta', dist.Normal(mu, tau))
... numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
No-U-Turn(NUTS)サンプラーを使ってMCMCを走らせることで私たちのモデルの未知のパラメータを推定しましょう。MCMC実行時の引数extra_fieldの使い方に注意してください。デフォルトでは、MCMCを使って推定を走らせるときに、対象の(事後)分布からサンプルを収集するだけです。しかし、潜在エネルギーやサンプルの受容確率のような追加フィールドを集めることは、extra_fields引数を使うことによって簡単にアーカイブすることができます。収集できるフィールドのリストは、HMCStateオブジェクトを見てください。この例では、私たちは、追加で各サンプルのためにpotential_energyを収集することができます。
>>> from jax import random
>>> from numpyro.infer import MCMC, NUTS
>>> nuts_kernel = NUTS(eight_schools)
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
>>> rng_key = random.PRNGKey(0)
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
sample: 100%|████| 1500/1500 [00:01<00:00, 1101.91it/s, 47 steps of size 1.72e-01. acc. prob=0.87]
私たちは、MCMC実行の結果を出力することができます。そして、推定中にいくつかの不一致が観測された場合、調査します。追加で、各サンプルの潜在エネルギーを収集したので、私たちは簡単に予測対数結合密度を計算することができます。
>>> mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
mu 4.41 3.13 4.24 0.33 10.17 182.43 1.04
tau 4.41 3.31 3.39 0.61 8.74 125.48 1.00
theta[0] 6.82 5.95 5.91 -2.21 16.29 263.78 1.02
theta[1] 4.97 4.85 4.76 -2.44 13.30 356.13 1.01
theta[2] 3.77 5.91 4.06 -5.12 12.75 351.54 1.01
theta[3] 4.71 4.77 4.70 -2.85 12.12 359.95 1.02
theta[4] 3.28 4.77 3.63 -4.18 11.12 377.46 1.01
theta[5] 4.07 5.23 4.12 -4.11 12.30 427.85 1.01
theta[6] 6.58 5.03 5.83 -1.35 14.57 256.32 1.02
theta[7] 5.24 5.51 4.80 -3.16 12.94 379.81 1.01
Number of divergences: 11
>>> pe = mcmc.get_extra_fields()['potential_energy']
>>> print('Expected log joint density: {:.2f}'.format(np.mean(-pe))) # doctest: +SKIP
Expected log joint density: -55.58
分割Gelman Rubin診断(r_hat)のための1を上回る値は、チェインが完全に収束しないことを示しています。効果的なサンプルサイズ(n_eff)にとっての低い値、特にtau と不一致の変換の数は、解決が難しく見えます。幸いにもこれは、私たちのモデルのtau に中心のないパラメータ化を使うことによって訂正できる共通の問題です。これは、TransformedDistribution インスタンス、reparameterization効果ハンドラーと一緒にを使うことでNumPyro で単純に実行します。私たちに同じモデルを書き換えさせてください。しかし、Normal(mu,tau)から thetaをサンプリングする代わりに、AffineTransform を使って変換される基礎のNormal(0,1)分布から代わりにサンプルします。そうすることによって、NumPyroは、代わりに基礎Normal(0,1)分布のためにtheta_baseからサンプルを生成することによって、HMCを走らせることに注意してください。私たちは、結果のチェインが同じ問題から苦しむことはないことがわかります。それはすべてのパラメータにとってGelman Rubin診断が1であり、そして効果的なサンプルサイズが、完全に適しているのが見えます。
>>> from numpyro.infer.reparam import TransformReparam
>>> # Eight Schools example - Non-centered Reparametrization
... def eight_schools_noncentered(J, sigma, y=None):
... mu = numpyro.sample('mu', dist.Normal(0, 5))
... tau = numpyro.sample('tau', dist.HalfCauchy(5))
... with numpyro.plate('J', J):
... with numpyro.handlers.reparam(config={'theta': TransformReparam()}):
... theta = numpyro.sample(
... 'theta',
... dist.TransformedDistribution(dist.Normal(0., 1.),
... dist.transforms.AffineTransform(mu, tau)))
... numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)
>>> nuts_kernel = NUTS(eight_schools_noncentered)
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
>>> rng_key = random.PRNGKey(0)
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))
sample: 100%|█████| 1500/1500 [00:01<00:00, 1108.63it/s, 7 steps of size 5.72e-01. acc. prob=0.82]
>>> mcmc.print_summary(exclude_deterministic=False) # doctest: +SKIP
mean std median 5.0% 95.0% n_eff r_hat
mu 4.33 3.36 4.50 -0.50 10.30 940.18 1.00
tau 3.65 3.37 2.72 0.00 8.01 435.85 1.00
theta[0] 6.12 5.59 5.46 -2.18 15.39 874.95 1.00
theta[1] 4.86 5.11 4.81 -4.03 12.31 1044.76 1.00
theta[2] 3.56 5.39 3.99 -5.04 11.41 880.21 1.00
theta[3] 4.69 4.47 4.76 -2.88 11.12 1104.25 1.00
theta[4] 3.67 4.81 3.82 -2.77 12.51 1168.40 1.00
theta[5] 3.90 4.78 3.86 -3.68 12.04 953.79 1.00
theta[6] 6.46 5.23 5.94 -1.94 14.23 838.76 1.00
theta[7] 4.69 5.25 4.68 -3.55 11.83 793.00 1.00
theta_base[0] 0.28 0.95 0.27 -1.29 1.94 1216.01 1.00
theta_base[1] 0.11 0.96 0.13 -1.56 1.51 1268.95 1.00
theta_base[2] -0.09 0.93 -0.10 -1.57 1.45 1983.66 1.00
theta_base[3] 0.06 0.93 0.07 -1.55 1.49 1606.13 1.00
theta_base[4] -0.13 0.98 -0.13 -1.70 1.45 1565.53 1.00
theta_base[5] -0.08 0.95 -0.11 -1.43 1.62 1280.40 1.00
theta_base[6] 0.39 0.93 0.42 -1.15 1.90 1114.74 1.00
theta_base[7] 0.05 0.94 0.07 -1.48 1.66 1375.19 1.00
Number of divergences: 1
>>> pe = mcmc.get_extra_fields()['potential_energy']
>>> # Compare with the earlier value
>>> print('Expected log joint density: {:.2f}'.format(np.mean(-pe))) # doctest: +SKIP
Expected log joint density: -46.17
Normal, Cauchy, Student-Tのようなloc,scale パラメータを伴う分布のクラスのために、私たちは、同じ目的でアーカイブするために LocScaleReparam再パラメータ化を提供することに注意してください。一致するコードは、以下のとおり
with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
theta = numpyro.sample('theta', dist.Normal(mu, tau))
ここで、私たちは、どのようなテストスコアも観測していない新しいスクールを持つことを仮定させてください。任意の観測データの欠測に注意してください。私たちは、予測を生成するために集団レベルのパラメータを簡単に使うことができます。観測していないmuとtauのPredictive ユーティリティ条件は、最後のMCMC実行の事後分布から出力される値を配置します。そして、予測を生成するためにモデルを前に進めます。
>>> from numpyro.infer import Predictive
>>> # New School
... def new_school():
... mu = numpyro.sample('mu', dist.Normal(0, 5))
... tau = numpyro.sample('tau', dist.HalfCauchy(5))
... return numpyro.sample('obs', dist.Normal(mu, tau))
>>> predictive = Predictive(new_school, mcmc.get_samples())
>>> samples_predictive = predictive(random.PRNGKey(1))
>>> print(np.mean(samples_predictive['obs'])) # doctest: +SKIP
4.5596056
このサンプルコードはNumPyroを使った実装例ですが、8スクールのアルゴリズムは、以下にも同様にStanを使った実装例としてサンプルを上げています。