---
title: "README"
output: github_document
---
dartMachine
===========

This is a fork of the [bartMachine](https://github.com/kapelner/bartMachine/) package, modified for use in high-dimensional, sparse settings. Any bugs here are my own. For details on the method, see [Linero (2016). Bayesian Regression Trees for High Dimensional Prediction and Variable Selection. Journal of the American Statistical Association. To appear](http://www.tandfonline.com/doi/full/10.1080/01621459.2016.1264957).

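
For intuition about what makes DART sparse: the paper replaces BART's uniform choice of splitting variable with a Dirichlet prior on the vector of splitting proportions, s ~ Dirichlet(α/P, …, α/P), where P is the number of predictors. The base-R sketch below does not use dartMachine. It draws from that prior with the standard Gamma-normalization construction, using a hypothetical helper `rdirichlet_sym()` and a fixed α = 1 chosen only for illustration (whether this matches the package's `alpha_0` setting is an assumption). With P = 100, nearly all of the prior mass typically falls on a handful of predictors.

```{r}
# Hypothetical helper: one draw from a symmetric Dirichlet(alpha / P, ..., alpha / P),
# built by normalizing independent Gamma(alpha / P, 1) variates
rdirichlet_sym <- function(P, alpha) {
  g <- rgamma(P, shape = alpha / P, rate = 1)
  g / sum(g)
}

set.seed(1)
s <- rdirichlet_sym(P = 100, alpha = 1)

# Smallest number of predictors whose splitting proportions cover 95% of the mass
sum(cumsum(sort(s, decreasing = TRUE)) < 0.95) + 1
```

Larger values of α spread the prior mass much more evenly across predictors, which moves the prior back toward BART's behavior of treating all predictors alike.
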
Installation
============

For detailed installation instructions, see the [bartMachine](https://github.com/kapelner/bartMachine/) repository. This package is not available on CRAN and must be installed from source.

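Illustration
============

The example below reproduces parts of Figure 4 from Linero (2016). It simulates data from Friedman's test function with 100 predictors, only the first five of which affect the response, and adds Gaussian noise with variance 10. It then fits BART and DART to 250 training observations, compares their RMSE against the true regression function on 1,000 test points, and plots per-predictor summaries of the DART posterior.
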
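
As a quick sanity check on the sparse setup, the base-R snippet here confirms that Friedman's function ignores every column after the fifth: redrawing columns 6 to 100 leaves f(X) unchanged, while redrawing column 1 changes it. It repeats the `Fried()` definition from the main example so that it runs on its own; the sample size and seed are arbitrary.

```{r}
# Friedman's test function (same definition as in the main example):
# f(x) = 10 sin(pi x1 x2) + 20 (x3 - 0.5)^2 + 10 x4 + 5 x5
Fried <- function(X) {
  10 * sin(pi * X[, 1] * X[, 2]) + 20 * (X[, 3] - 0.5)^2 + 10 * X[, 4] + 5 * X[, 5]
}

set.seed(42)
X <- matrix(runif(50 * 100), 50, 100)

# Redraw the 95 noise columns: f(X) should not change
X_noise <- X
X_noise[, 6:100] <- runif(50 * 95)
isTRUE(all.equal(Fried(X), Fried(X_noise)))

# Redraw a relevant column: f(X) should change
X_signal <- X
X_signal[, 1] <- runif(50)
isTRUE(all.equal(Fried(X), Fried(X_signal)))
```
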
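```{r, message=FALSE}
# Give the JVM a 2 GB heap; this must be set before the package is loaded
options(java.parameters = "-Xmx2g")
library(dartMachine)

# Friedman's test function: only the first five columns of X matter
Fried <- function(X) {
  10 * sin(pi * X[, 1] * X[, 2]) + 20 * (X[, 3] - 0.5)^2 + 10 * X[, 4] + 5 * X[, 5]
}

# Simulate n observations with dim_x uniform predictors and Gaussian noise
FriedSamp <- function(n, dim_x, sigma = 1) {
  X <- matrix(runif(n * dim_x), n, dim_x)
  mu <- Fried(X)
  Y <- rnorm(n, mu, sigma)
  return(list(X = X, Y = Y))
}

RMSE <- function(x, y) sqrt(mean((x - y)^2))

set.seed(123)
train <- FriedSamp(250, 100, sqrt(10))
test <- FriedSamp(1000, 100, sqrt(10))

## Fit BART
bart <- bartMachine(X = as.data.frame(train$X), y = train$Y, seed = 1234, num_trees = 200)
bart_pred <- predict(bart, as.data.frame(test$X))
rm(bart); gc()  # release the memory held by the fitted model

## Fit DART, with alpha_0 = 1
dart <- bartMachine(X = as.data.frame(train$X), y = train$Y,
                    do_ard = TRUE,
                    num_trees = 200,
                    num_burn_in = 5000,
                    num_iterations_after_burn_in = 5000,
                    alpha_0 = 1,
                    seed = 1234)
dart_pred <- predict(dart, as.data.frame(test$X))
# Number of splits on each predictor (columns) in each posterior sample (rows)
split_counts <- get_var_counts_over_chain(dart, type = "splits")
# Posterior samples of the per-predictor splitting probabilities (dartMachine-specific)
s_samples <- get_cov_prior_select(dart)
rm(dart); gc()

## RMSE against the true regression function
RMSE(bart_pred, Fried(test$X))
RMSE(dart_pred, Fried(test$X))

## Reproduction of figure in paper
# Fraction of posterior samples in which each predictor is used in at least one split
plot(colMeans(split_counts > 0), cex = .2)
abline(h = 0.5)
# Posterior mean of each predictor's splitting probability
plot(colMeans(s_samples))
```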
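
The horizontal line at 0.5 in the first plot suggests a simple selection rule: keep the predictors that are split on in more than half of the posterior samples. Reading the plot as a median-probability-model style rule is an interpretation, not a feature of the package. The sketch below wraps the rule in a hypothetical helper, `select_by_inclusion()`, and applies it to a simulated count matrix laid out like the output of `get_var_counts_over_chain()` (rows are posterior samples, columns are predictors); the toy data do not come from a dartMachine fit.

```{r}
# Hypothetical helper: indices of predictors whose posterior inclusion
# proportion (fraction of samples with at least one split on them)
# exceeds `threshold`
select_by_inclusion <- function(split_counts, threshold = 0.5) {
  inclusion <- colMeans(split_counts > 0)
  which(inclusion > threshold)
}

# Simulated 1000 x 10 count matrix: predictors 1-3 are split on often,
# predictors 4-10 rarely (illustrative data only)
set.seed(2)
rates <- c(3, 2, 1, rep(0.05, 7))
toy_counts <- sapply(rates, function(r) rpois(1000, r))
select_by_inclusion(toy_counts)
```

With the objects from the main example, `select_by_inclusion(split_counts)` returns the predictors above the 0.5 line; since only the first five predictors enter `Fried()`, the ideal answer there is `1:5`.
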