Endogeneity: Bayesian Analysis 2, Stan
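Stan is another tool for Bayesian inference.
Whereas JAGS obtains the posterior distribution through Gibbs sampling,
Stan obtains it through HMC (Hamiltonian Monte Carlo).
Stan, or rather the HMC sampling that Stan implements, was designed to address the shortcomings of Gibbs sampling,
but that does not mean Stan is always better than JAGS.
See the blog post JAGS and Stan, although it is fairly old by now.
The reference below also compares Gibbs sampling, Metropolis-Hastings (MH) sampling, and HMC on a model called DINA:
HMC did better than MH, but in terms of time efficiency Gibbs sampling was far better than NUTS (the No-U-Turn Sampler, a method that improves on the shortcomings of HMC).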
- Reference
da Silva, M. A., de Oliveira, E. S., von Davier, A. A., & Bazan, J. L. (2018). Estimating the DINA model parameters using the No-U-Turn Sampler. Biometrical Journal, 60(2), 352-368.
Even so, the earlier JAGS model can be implemented in Stan as follows.
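## STAN MODEL
model <- '
// Simple linear regression with an endogeneity correction term
data {
  int<lower=0> N;   // number of subjects
  real x[N];        // regressor (height); Stan 2.33+ needs: array[N] real x;
  real y[N];        // outcome (ability)
  real meanx;       // sample mean of x
  real varx;        // sample variance of x
}
parameters {
  real covXE;                       // covariance between x and the error term
  real beta0;                       // intercept
  real beta1;                       // slope on x
  real<lower=0, upper=100> sigma2;  // residual SD (used as a scale, not a variance)
}
transformed parameters {
  real rho = covXE / varx;               // slope of the error term on x
  real varE = rho^2 * varx + sigma2^2;   // total variance of the error term
}
model {
  covXE ~ normal(5, 1);      // informative prior on cov(x, error)
  beta0 ~ normal(0, 10);
  beta1 ~ normal(0, 10);
  sigma2 ~ uniform(0, 100);  // matches the declared bounds
  for (i in 1:N)
    y[i] ~ normal(beta0 + beta1 * x[i] + rho * (x[i] - meanx), sigma2);
}
'
cat(model, file = "currentModel.stan")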
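The Stan program above can be read as the following regression. This is only a sketch in the notation of the code: the symbols $e_i$, $u_i$ and $\bar{x}$ (for `meanx`) are labels introduced here, and $u_i$ is assumed to be independent of $x_i$.

```latex
\begin{aligned}
y_i &= \beta_0 + \beta_1 x_i + e_i, \\
e_i &= \rho\,(x_i - \bar{x}) + u_i, \qquad u_i \sim \mathcal{N}(0, \sigma^2), \\
\rho &= \frac{\operatorname{Cov}(x, e)}{\operatorname{Var}(x)}
      = \frac{\texttt{covXE}}{\texttt{varx}}, \\
\operatorname{Var}(e) &= \rho^2 \operatorname{Var}(x) + \sigma^2
      \quad (\texttt{varE} = \texttt{rho}^2 \cdot \texttt{varx} + \texttt{sigma2}^2).
\end{aligned}
```

Here $\sigma$ corresponds to `sigma2`, which the code uses as a standard deviation. Because $\beta_1$ and $\rho$ both multiply $x_i$, the likelihood informs only their sum, so it is the `normal(5, 1)` prior on `covXE` that separates the two.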
## DATA
set.seed(1)
N <- 1000
gene <- rnorm(N)
power <- gene + (e1 = rnorm(N))
height <- 180 + 5*gene + (e2 = rnorm(N))
ability <- power + height/10 + (e3 = rnorm(N))
summary(lm(ability ~ height)) # predictive model
##
## Call:
## lm(formula = ability ~ height)
##
## Residuals:
##     Min      1Q  Median      3Q     Max
## -5.1616 -1.0461 -0.0145  1.0365  5.0071
##
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)
## (Intercept) -35.56398    1.59503  -22.30   <2e-16 ***
## height        0.29756    0.00886   33.59   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 1.491 on 998 degrees of freedom
## Multiple R-squared:  0.5306, Adjusted R-squared:  0.5301
## F-statistic:  1128 on 1 and 998 DF,  p-value: < 2.2e-16
summary(lm(ability ~ height + power)) # causal model
##
## Call:
## lm(formula = ability ~ height + power)
##
## Residuals:
##     Min      1Q  Median      3Q     Max
## -3.2322 -0.7124 -0.0099  0.7173  3.0623
##
## Coefficients:
##              Estimate Std. Error t value Pr(>|t|)
## (Intercept) -0.372789   1.551017   -0.24     0.81
## height       0.102167   0.008614   11.86   <2e-16 ***
## power        1.013926   0.031168   32.53   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 1.039 on 997 degrees of freedom
## Multiple R-squared:  0.7723, Adjusted R-squared:  0.7718
## F-statistic:  1691 on 2 and 997 DF,  p-value: < 2.2e-16
cov(height, power + e3)
## [1] 5.604122
x <- height
y <- ability
meanx <- mean(x)
varx <- var(x)*(N-1)/N
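To see where the prior `covXE ~ normal(5, 1)` and the inflated slope of about 0.30 in the predictive model come from, the following sketch re-creates the simulated data and decomposes the OLS slope. The names `err`, `ols_slope`, `bias`, `decomposed`, and `pop_bias` are introduced here for illustration and do not appear in the original code.

```r
# Re-create the simulated data (same seed and same order of random draws).
set.seed(1)
N <- 1000
gene <- rnorm(N)
e1 <- rnorm(N); power <- gene + e1
e2 <- rnorm(N); height <- 180 + 5 * gene + e2
e3 <- rnorm(N); ability <- power + height / 10 + e3

# Error term when height is the only regressor: everything except height / 10.
err <- power + e3

# The OLS slope equals the true effect (0.1) plus cov(height, err) / var(height).
ols_slope  <- unname(coef(lm(ability ~ height))["height"])
bias       <- cov(height, err) / var(height)
decomposed <- 0.1 + bias

# Population counterparts: cov(height, err) = 5 * Var(gene) = 5,
# Var(height) = 25 * Var(gene) + Var(e2) = 26.
pop_bias <- 5 / 26

print(c(ols_slope = ols_slope, decomposed = decomposed, pop_slope = 0.1 + pop_bias))
```

In the population the covariance between `height` and the omitted part is 5, which is what the prior on `covXE` is centered on, and the implied slope is 0.1 + 5/26, roughly 0.29.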
## MCMC chains
library(rstan)
stanfit <- stan('currentModel.stan',
                data = list(N = N, x = x, y = y, meanx = meanx, varx = varx),
                chains = 3, iter = 5000, warmup = 1000, cores = 3)
## recompiling to avoid crashing R session
## Warning: There were 193 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
## http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
## Warning: Examine the pairs() plot to diagnose sampling problems
# algorithm can be "NUTS" (No-U-Turn Sampler) or "HMC" (Hamiltonian Monte Carlo)
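The treedepth warning suggests raising `max_treedepth`. A minimal sketch of how that could be done through the `control` argument of `stan()` follows; the value 15 and the object name `stanfit_deep` are arbitrary choices made here, not settings from the original post, and the data objects are assumed to exist as defined in the DATA section.

```r
library(rstan)

# Refit with a deeper maximum tree depth for NUTS (the default is 10).
# The value 15 is an assumption for illustration, not a tuned setting.
stanfit_deep <- stan("currentModel.stan",
                     data = list(N = N, x = x, y = y,
                                 meanx = meanx, varx = varx),
                     chains = 3, iter = 5000, warmup = 1000, cores = 3,
                     control = list(max_treedepth = 15))

# Report divergences, treedepth saturation, and energy diagnostics.
check_hmc_diagnostics(stanfit_deep)
```

A deeper tree lets each iteration take more leapfrog steps, so the run is likely to be slower; exceeding the treedepth is an efficiency concern rather than evidence of a biased posterior.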
## ANALYSIS
summary(stanfit)
## $summary
##                mean      se_mean         sd          2.5%           25%
## covXE     5.1501611 0.0174778331 0.85786425    3.43605107    4.57195789
## beta0    -2.8003156 0.1150685314 5.55869719  -13.99516612   -6.44751117
## beta1     0.1155023 0.0006393507 0.03089100    0.05497617    0.09479741
## sigma2    1.4936135 0.0005221238 0.03380906    1.42917567    1.47036680
## rho       0.1817408 0.0006167644 0.03027264    0.12125268    0.16133699
## varE      3.1939866 0.0065668009 0.33152352    2.60603701    2.96444921
## lp__   -900.4169546 0.0258137199 1.44914920 -904.10287033 -901.14463963
##                 50%          75%        97.5%    n_eff     Rhat
## covXE     5.1601522    5.7112963    6.8408076 2409.140 1.000817
## beta0    -2.7969643    0.9192211    8.0940037 2333.635 1.000713
## beta1     0.1154543    0.1358261    0.1778146 2334.456 1.000714
## sigma2    1.4930532    1.5163966    1.5605360 4192.947 1.000251
## rho       0.1820934    0.2015424    0.2414010 2409.140 1.000817
## varE      3.1737907    3.3940699    3.9076323 2548.713 1.000422
## lp__   -900.0959384 -899.3501613 -898.5974498 3151.556 1.001013
##
## $c_summary
## , , chains = chain:1
##
##          stats
## parameter         mean         sd          2.5%           25%          50%
##    covXE     5.1313721 0.83957044    3.43605107    4.57636912    5.1636069
##    beta0    -2.8873359 5.44414805  -13.91700534   -6.40549743   -2.8121955
##    beta1     0.1159898 0.03026158    0.05624835    0.09563365    0.1155479
##    sigma2    1.4948183 0.03441096    1.43132862    1.47114592    1.4936091
##    rho       0.1810778 0.02962708    0.12125268    0.16149265    0.1822153
##    varE      3.1897111 0.32232100    2.60716723    2.97199407    3.1743552
##    lp__   -900.4027924 1.46621153 -904.17018310 -901.11980287 -900.0697335
##          stats
## parameter          75%        97.5%
##    covXE     5.6728616    6.7689017
##    beta0     0.7610194    7.8425654
##    beta1     0.1356727    0.1773804
##    sigma2    1.5176589    1.5658123
##    rho       0.2001861    0.2388636
##    varE      3.3881855    3.8883487
##    lp__   -899.3267717 -898.5869159
##
## , , chains = chain:2
##
##          stats
## parameter         mean         sd          2.5%           25%          50%
##    covXE     5.1387402 0.89118778    3.42898787    4.53155807    5.1133255
##    beta0    -2.8978242 5.79000796  -14.15229461   -6.78341718   -3.0083340
##    beta1     0.1160433 0.03217695    0.05341483    0.09414703    0.1166375
##    sigma2    1.4928036 0.03417094    1.42832408    1.46931422    1.4920719
##    rho       0.1813378 0.03144857    0.12100343    0.15991134    0.1804410
##    varE      3.1894974 0.34591082    2.59669282    2.94753741    3.1594956
##    lp__   -900.4541403 1.47146276 -904.16599684 -901.15405705 -900.1263251
##          stats
## parameter          75%        97.5%
##    covXE     5.7268232    6.9402935
##    beta0     1.0321895    8.3591987
##    beta1     0.1374552    0.1786248
##    sigma2    1.5160543    1.5597252
##    rho       0.2020903    0.2449117
##    varE      3.3879482    3.9451043
##    lp__   -899.3832946 -898.6125729
##
## , , chains = chain:3
##
##          stats
## parameter         mean         sd          2.5%           25%          50%
##    covXE     5.1803710 0.84123163    3.46440549    4.61098802    5.1982324
##    beta0    -2.6157865 5.43127173  -13.87442133   -6.19829787   -2.5462290
##    beta1     0.1144738 0.03017502    0.05534193    0.09448747    0.1139699
##    sigma2    1.4932187 0.03279785    1.42972154    1.47046485    1.4936441
##    rho       0.1828069 0.02968570    0.12225326    0.16271430    0.1834372
##    varE      3.2027513 0.32575772    2.60808247    2.97595951    3.1897279
##    lp__   -900.3939311 1.40856582 -903.94113798 -901.15037669 -900.0874199
##          stats
## parameter          75%        97.5%
##    covXE     5.7278314    6.8324713
##    beta0     0.9824730    7.9993459
##    beta1     0.1344158    0.1769755
##    sigma2    1.5156740    1.5570008
##    rho       0.2021259    0.2411068
##    varE      3.4040959    3.8959928
##    lp__   -899.3439732 -898.6011902
plot(stanfit)
## ci_level: 0.8 (80% intervals)
## outer_level: 0.95 (95% intervals)
pairs(stanfit, pars= c("beta1", "covXE", "sigma2"))
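A quick aside. In the Stan model above, the chains converged properly even though \(x\) was not mean-centered. Stan takes longer to produce the same number of samples, so JAGS and Stan each have their own strengths and weaknesses. Stan is said to be particularly strong on complex models, where convergence is hard to achieve.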
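One way to see why mean-centering matters is the dependence between the intercept and the slope. The sketch below is a rough frequentist analogue, not the Stan fit itself: it re-creates the simulated data and compares the correlation between the OLS intercept and slope estimates with and without centering `height`. The names `height_c`, `fit_raw`, `fit_centered`, `cor_raw`, and `cor_centered` are introduced here for illustration.

```r
# Re-create the simulated data (same seed and same order of random draws).
set.seed(1)
N <- 1000
gene <- rnorm(N)
power <- gene + rnorm(N)
height <- 180 + 5 * gene + rnorm(N)
ability <- power + height / 10 + rnorm(N)

# Fit the same regression with the raw and the mean-centered regressor.
fit_raw <- lm(ability ~ height)
height_c <- height - mean(height)
fit_centered <- lm(ability ~ height_c)

# Correlation between the intercept and slope estimates in each fit.
cor_raw <- cov2cor(vcov(fit_raw))[1, 2]
cor_centered <- cov2cor(vcov(fit_centered))[1, 2]

print(c(raw = cor_raw, centered = cor_centered))
```

With a regressor whose mean is far from zero (about 180 here), the two estimates are almost perfectly negatively correlated, and centering removes that correlation while leaving the slope unchanged. A similarly elongated posterior is a plausible reason for the treedepth warnings reported earlier, although that is an assumption rather than something verified on this fit.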
traceplot(stanfit)