Twitterで質問があったので,書いておきます。
Stanでは,functions{}ブロックでユーザーが関数を定義できます。
普通の関数は,Rと同じように作ればいいのでそれほど難しくはないというか,違和感はないと思いますが,Stanではサンプリングのための関数も定義できて,それでMCMCもできてしまいます。
以下では,自分で正規分布の式を書いて定義したnormal2という確率分布の関数を作って,それでMCMCをする,ということをやってみます。
まず,普通にStanに最初から入ってる正規分布を使って平均と標準偏差を推定するコードを書いておきます。これをnormal1.stanという名前で保存します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
data{ int N; //サンプルサイズ real y[N]; //データ } parameters{ real mu; //平均値 real } model{ mu ~ normal(0,100); //平均値の事前分布 sigma ~ cauchy(0,5); //標準偏差の事前分布 for(n in 1:N){ y[n] ~ normal(mu,sigma); } } |
これを次のRコードで走らせます。乱数は正規分布から100個,平均=5,SD=2を母数として生成します。
1 2 3 4 5 6 7 |
library(rstan) set.seed(123) y <- rnorm(100,5,2) fit.normal <- stan("normal1.stan",data=list(N=100,y=y)) fit.normal |
結果は次のようになります。
1 2 3 4 |
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat mu 5.18 0.00 0.19 4.82 5.06 5.18 5.31 5.55 2253 1 sigma 1.84 0.00 0.14 1.60 1.75 1.83 1.93 2.15 2326 1 lp__ -110.28 0.03 1.08 -113.18 -110.69 -109.95 -109.50 -109.25 1198 1 |
ちゃんと推定できています。
続いて,自分でわざわざ正規分布を定義するバージョンです。このStanコードをfunction.stanで保存します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
functions{ //手作り正規分布を定義 real normal2_log(real y, real mu, real sigma){ real temp; temp <- 1/(2*pi()*sigma^2)^0.5 * exp(-(y-mu)^2/(2*sigma^2)); return log(temp); } } data{ int N; //サンプルサイズ real y[N]; //データ } parameters{ real mu; //平均値 real } model{ mu ~ normal(0,100); //平均値の事前分布 sigma ~ cauchy(0,5); //標準偏差の事前分布 for(n in 1:N){ y[n] ~ normal2(mu,sigma);//手作り正規分布を使う } } |
ポイントは,関数名は必ず最後に_logをつけます。こうすることでStanはサンプリング用の関数だと判断してくれます。あとその名前からわかるように,戻り値は対数確率(対数尤度)である必要があります。なので,正規分布の確率のlogを戻り値にする点に注意です。
model{}での使い方は,普通のStanのサンプリング用の関数と同じで,
y[n] ~ normal2(mu, sigma);
という感じで,_logを省略して,~を使ってサンプリングすることができます。
もちろん,increment_log_prob()を使って,
increment_log_prob(normal2_log(y[n], mu, sigma));
としてもよいです。
Rコードは以下です。
1 2 3 4 5 |
set.seed(123) y <- rnorm(100,5,2) fit.function <- stan("function.stan",data=list(N=100,y=y)) fit.function |
結果は,
1 2 3 4 |
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat mu 5.18 0.00 0.19 4.82 5.06 5.18 5.31 5.56 2109 1 sigma 1.85 0.00 0.13 1.62 1.75 1.84 1.93 2.13 2250 1 lp__ -202.14 0.03 1.03 -204.97 -202.53 -201.81 -201.41 -201.14 1056 1 |
このように,さっきと同じ結果が得られています。
ただ,手作り関数の方は正規化定数の省略とかができないので,最初から入ってる関数に比べてスピードは遅いです。normal1.stanは2000回のMCMCで0.07秒程度ですが,function.stanは0.28秒と,4倍遅いです。ただ,もしかしたら関数の作り方を工夫すると早くできるのかもしれません。
なにはともあれ,この機能を使うと,Stanに入っていない関数でサンプリングができるようになるわけです。いざというときには便利ですね。