forked from SimonCMills/reduce_sum_example
---
title: "Reduce_sum applied to 2D data"
date: "22 June 2020"
output:
  github_document:
    pandoc_args: --webtex
---
In the 2.23 update, Stan introduced ```reduce_sum```, which allows us to parallelise within-chain computation. The 1D case, where the response variable is a vector, has a well-worked example ([here](https://mc-stan.org/users/documentation/case-studies/reduce_sum_tutorial.html)), and the speedups for these simple models, where most of the computation can be placed inside the parallel sum component, are reasonably well documented. It took a bit of experimentation to figure out how to apply ```reduce_sum``` to a 2D case, and I'm not aware of equivalent worked examples or example speedups for such cases. This document therefore aims to provide a minimal example of a 2D application of ```reduce_sum``` and to take a look at some of the speedups.
It contains:
1. A brief description of the model and the data structure
2. Fitting this model using ```reduce_sum```
3. A comparison of speedups with increasing number of threads
The following resources are referred to at various points throughout, but are collected here for reference:
- The ```reduce_sum``` documentation: https://mc-stan.org/docs/2_23/stan-users-guide/reduce-sum.html
- [The 1D worked example (logistic regression)](https://mc-stan.org/users/documentation/case-studies/reduce_sum_tutorial.html), which has details that aren't included here
- [https://jsocolar.github.io/occupancyModels/](https://jsocolar.github.io/occupancyModels/) as an introduction/explanation to the model in Stan
# 1. (Brief) model description
Occupancy models are familiar to many ecologists. They come in a lot of flavours, but, very broadly, they attempt to relate binary presence/absence to characteristics of the environment and/or species, usually using data collected from a number of different spatial locations.
A further broad subset of these are detection-occupancy models, which attempt to separate observed occupancy into distinct detection and occupancy components (i.e. they acknowledge that the probability of observing a given species depends both on the underlying probability of presence and on how likely it is to be observed, given presence). In order to do this, two things are required: first, there needs to be some form of replicated sampling, and second, we need to be able to reasonably assume that if a species is observed on a single one of these replicates, it is 'present' (in some sense) on every other replicate. The resulting mixture of 0s and 1s then allows us to say something about both detection and occupancy.
We'll take a look at these detection-occupancy models for two reasons. First, working with them is what motivated the need to parallelise: the datasets can quickly become large, and they aren't obviously reduced to a 1D observation vector for modelling. Second, they are somewhat distinct from the various flavours of GLM that are already documented.
These models are readily fit in Stan (note though that they are slightly differently formulated than in JAGS/BUGS, see [here](https://jsocolar.github.io/occupancyModels/)).
## Data simulation
For our simulated species, there is variation in occupancy both across species (i.e. some species are more common than others), and also across an environmental gradient. Across all species, there is an average tendency for occupancy probability to increase across the environmental gradient, but there is variation in the strength of this association across species, such that some species can decrease in occupancy probability along this gradient. Detection varies across species, with some species more readily detected than others, and also with a visit-level covariate, which, for the sake of argument, we'll call time of day. Again, there is an average (negative) time of day effect across species, but species also vary in the strength of this association.
More formally:
$$\psi_{ik} = \beta_{0k} + \beta_{1k}.x_i$$
$$\theta_{ijk} = \alpha_{0k} + \alpha_{1k}.t_{ij}$$
$$Z_{ik} \sim bernoulli(logit^{-1}(\psi_{ik}))$$
$$y_{ijk} \sim bernoulli(logit^{-1}(\theta_{ijk}) \times Z_{ik})$$
where $i$ indexes points, $j$ indexes visits, $k$ indexes species. $\psi_{ik}$ is the occupancy probability (on the logit scale), which varies according to a species-level intercept and a species-level environmental association (where $x_i$ is the environmental variable). $\theta_{ijk}$ is the detection component (also on the logit scale), and varies according to a species-level intercept and a species-level time of day effect (where $t_{ij}$ is the time of day of visit $j$ to point $i$). All species-level effects are simulated and modelled as normally-distributed random effects, each with its own mean and standard deviation, i.e. eight hyperparameters in total.
Data are generated from this model in the following script:
```{r, message=F}
source("simulate_data.R")
```
For the timings, I'll generate 50 species across 50 sites and 4 visits (i.e. 10,000 observations in total).
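The contents of `simulate_data.R` are not reproduced here. As a rough illustration only, the generative model above could be simulated along the following lines. The hyperparameter values, the seed, and the object names (`a0`, `a1`, `x`, `tod`, `y`) are assumptions made for this sketch, not the values used for the timings.

```r
set.seed(42)
n_species <- 50
n_point <- 50
n_visit <- 4

# Species-level random effects (hypothetical hyperparameter values)
b0 <- rnorm(n_species, mean = 0, sd = 1)     # occupancy intercepts
b1 <- rnorm(n_species, mean = 1, sd = 1)     # environmental slopes
a0 <- rnorm(n_species, mean = 0, sd = 1)     # detection intercepts
a1 <- rnorm(n_species, mean = -1, sd = 0.5)  # time-of-day slopes

x <- rnorm(n_point)                                        # x_i, one value per point
tod <- matrix(runif(n_point * n_visit), n_point, n_visit)  # t_ij, scaled to [0, 1]

# psi[i, k] on the logit scale, and the latent occupancy state Z[i, k]
psi <- outer(x, b1) + matrix(b0, n_point, n_species, byrow = TRUE)
Z <- matrix(rbinom(length(psi), 1, plogis(psi)), n_point, n_species)

# y[i, j, k]: a detection is only possible where Z[i, k] == 1
y <- array(NA_integer_, dim = c(n_point, n_visit, n_species))
for (k in seq_len(n_species)) {
  theta_k <- a0[k] + a1[k] * tod  # [i, j] on the logit scale
  y[, , k] <- rbinom(length(theta_k), 1, plogis(theta_k) * Z[, k])
}
```

Note that `Z` is drawn once per point:species combination and shared across visits, which is what encodes the assumption that a species present on one visit is present on all of them.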
Visually, the model looks something like this (for nine example species):
```{r figs, echo=F, fig.align = "center", fig.cap = "**Figure 1:** Nine example species taken from the full simulation. Solid line indicates the underlying occupancy relationship, while the dashed line indicates the probability of detecting a given species on a point, estimated using the upper and lower bounds on the detection effect (at time=0 and time=1 respectively), and accounting for the species' overall detectability. Note the variation across species in the environmental association, the detection offset, and the strength of the time-of-day effect."}
# plot
library(dplyr)
df_sim_eg <- expand.grid(env_var = seq(-2, 2, len=50), #i = 1:n_site,
                         j = 1:n_visit,
                         k = 1:9,
                         time = c(0, 1)) %>%
    as_tibble %>%
    mutate(psi = b0[k] + b1[k]*env_var,
           Z = rbinom(n(), 1, inv.logit(psi)),
           theta = d0[k] + d1[k]*time,
           det_sim = rbinom(n(), Z, inv.logit(theta)))
ggplot(df_sim_eg, aes(env_var, inv.logit(psi), group=time)) +
    geom_line() +
    geom_line(aes(y=inv.logit(psi)*inv.logit(theta), col=factor(time)), lty=2) +
    theme(axis.text = element_text(colour="black"),
          strip.text = element_text(hjust=0, face="bold"),
          strip.background = element_blank(),
          panel.grid.minor= element_blank(),
          aspect.ratio=1) +
    ylim(0,1) +
    labs(x="Environmental gradient", y="Probability of occupancy/detection",
         colour = "Time of day (scaled)") +
    facet_wrap(~k)
```
## Data structure
Observations inherently have a 3D structure, indexed by the point that was surveyed (point $i$), the visit to the point (visit $j$), and the species that was observed or not observed (species $k$). However, these are readily collapsed to a 2D structure by stacking the species dimension onto the point dimension, producing a dataframe with dimensions ($n_{points} \times n_{species}$, $n_{visits}$) in which each row indexes a point:species combination.
```{r, comment=NA}
head(det_sim_wide)
```
We have two columns that identify the species:point combination, followed by 4 columns that give the observation (0 or 1) across each of 4 visits to the point.
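As an illustration of this collapse, the sketch below reshapes a small 3D detection array into the wide point:species layout. The array `y`, the toy dimensions, and the column names are hypothetical and are not taken from `simulate_data.R`.

```r
set.seed(1)
n_point <- 3
n_visit <- 4
n_species <- 2

# Toy detection array indexed [point i, visit j, species k]
y <- array(rbinom(n_point * n_visit * n_species, 1, 0.5),
           dim = c(n_point, n_visit, n_species))

# One row per point:species combination
idx <- expand.grid(i = seq_len(n_point), k = seq_len(n_species))

# Pull out the visit history for each combination; mapply() returns
# visits in rows, so transpose to get one row per combination
visits <- t(mapply(function(i, k) y[i, , k], idx$i, idx$k))
colnames(visits) <- paste0("visit_", seq_len(n_visit))

wide <- cbind(idx, visits)
head(wide)
```

The result has $n_{points} \times n_{species}$ rows, with two identifier columns followed by one column per visit.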
The time of day predictor has the same dimensions. It could be reduced to a smaller [i, j] array, as time of day only depends upon when the point was visited, not upon which species were observed, but it's more straightforward to just produce it as an array of times with the same dimensions as the detection data.
```{r, comment=NA}
head(time_sim_wide)
```
```{r, comment=NA}
head(env_var)
```
Finally, there is an environmental variable, which varies only between points.
# 2. The Stan model
The Stan file is a little complicated if it's not a model you are already familiar with. The main oddity is the if-statement asking if Q==1: this refers to the central assumption mentioned at the outset, where a species detected at least once at a point is treated as present (but undetected) on all other visits to that point. The converse, where a species is never observed at a point, is more complicated, because now we have to acknowledge both that it might have been there all along but never observed, and that it might simply not be there. Again, [https://jsocolar.github.io/occupancyModels/](https://jsocolar.github.io/occupancyModels/) explains this in more detail.
```{r, comment=NA, echo=F}
cat(readLines('occupancy_example.stan'), sep = '\n')
```
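For readers less familiar with the model, the logic behind the `Q==1` branch can be written out for a single point:species row. This is a plain R sketch of the standard marginalised occupancy likelihood described above, not a transcription of the Stan file; the function and argument names are hypothetical.

```r
# Log-likelihood contribution of one point:species row.
#   det         : 0/1 detections across visits
#   logit_psi   : occupancy linear predictor (scalar)
#   logit_theta : detection linear predictor (one value per visit)
row_loglik <- function(det, logit_psi, logit_theta) {
  log_psi <- plogis(logit_psi, log.p = TRUE)     # log(psi)
  log1m_psi <- plogis(-logit_psi, log.p = TRUE)  # log(1 - psi)
  # log-probability of the detection history, given presence
  ll_det <- sum(dbinom(det, 1, plogis(logit_theta), log = TRUE))
  if (any(det == 1)) {
    # Detected at least once: the species must be present
    log_psi + ll_det
  } else {
    # Never detected: present but missed, OR absent (log-sum-exp)
    a <- log_psi + ll_det
    m <- max(a, log1m_psi)
    m + log(exp(a - m) + exp(log1m_psi - m))
  }
}

row_loglik(c(1, 0, 0, 0), logit_psi = 0, logit_theta = rep(0, 4))
row_loglik(c(0, 0, 0, 0), logit_psi = 0, logit_theta = rep(0, 4))
```

The never-detected branch is the reason the latent state $Z$ does not appear as a parameter: it is summed out of the likelihood rather than sampled.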
## ```reduce_sum``` formulation
To get the parallelisation done, we need to break up the computation such that it can be passed out to multiple workers. To do this, the model itself is wrapped in ```partial_sum``` (a user-written function), which allows the full task of calculating all the log-probabilities to be broken up into chunks (the size of which is controlled by the grainsize argument). ```reduce_sum``` will automate the choice about how to chunk up the data, and pass these data chunks out to be computed in parallel.
```{r, comment=NA, echo=F}
cat(readLines('occupancy_example_redsum.stan'), sep = '\n')
```
This is largely a repetition of the previous Stan model; the only difference is that the model has been shifted into the ```partial_sum``` function block. Note that a lot of stuff has to get passed to this function: you could simplify it by doing more computation outside (e.g. transforming some parameters) and passing a reduced set of objects, but then you would be shifting more of the computation *outside* of the bit that will be run in parallel. As the parallelisation speedup would take a hit from doing this, it's better to do it this way, despite the unwieldiness of the ensuing function.
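The idea behind the chunking can be illustrated in plain R. The sketch below is only an analogy: it runs serially, uses fixed-size chunks (whereas `reduce_sum` schedules its slices dynamically), and the vector `row_ll` is a made-up stand-in for the per-row log-probability terms. What it shows is that the partial sums over slices add up to the same total as a single pass over all rows.

```r
set.seed(1)
# Stand-in for the log-probability contribution of each point:species row
row_ll <- rnorm(2500, mean = -2, sd = 0.5)

# Analogue of the user-written partial sum: the total over rows start..end
partial_sum <- function(start, end) sum(row_ll[start:end])

# Cut the rows into slices of (at most) `grainsize` rows
grainsize <- 100
starts <- seq(1, length(row_ll), by = grainsize)
ends <- pmin(starts + grainsize - 1, length(row_ll))

# Each slice could be evaluated by a different worker; here they run in turn
total <- sum(mapply(partial_sum, starts, ends))

all.equal(total, sum(row_ll))
```

Because the slices are independent given the parameters, the order in which they are evaluated does not matter, which is what makes the sum safe to distribute across threads.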
## Running the multi-threaded model
When compiling the model, we need to specify that it is threaded (cpp_options), and then specify the threads_per_chain at the sampling step. The model is run with ```cmdstanr``` (at some point I presume multi-threading will be rolled out to ```rstan``` as well).
```{r, eval=F}
library(cmdstanr)

# data
stan_data <- list(n_tot = nrow(det_sim_wide),
                  n_visit = 4,
                  n_species = length(unique(det_sim_wide$k)),
                  n_point = length(unique(det_sim_wide$i)),
                  id_sp = det_sim_wide$k,
                  id_pt = det_sim_wide$i,
                  det = det_sim_wide[,3:6],
                  vis_cov = time_sim_wide[,3:6],
                  # Q flags rows with at least one detection (the model tests Q == 1),
                  # so convert the row sums (0 to 4) into a 0/1 indicator
                  Q = as.integer(rowSums(det_sim_wide[,3:6]) > 0),
                  env_var = env_var,
                  grainsize = 1)

## Run mod ----
# compile with threading support enabled
mod <- cmdstan_model("occupancy_example_redsum.stan",
                     cpp_options = list(stan_threads = TRUE))
samps <- mod$sample(data = stan_data,
                    chains = 1,
                    threads_per_chain = 8,
                    iter_warmup = 1000,
                    iter_sampling = 1000)
print(samps$time())
```
# 3. Timings
```{r fig2, echo=F, warning=F, fig.align = "center", fig.cap = "**Figure 2:** Timings across 1, 2, 4, and 8 cores, using a simulated dataset with 50 species, 50 sites, and 4 visits (i.e. 10,000 observations in total). Model fitting was repeated 6 times each (in sequence of 8, 4, 2, 1 threads, repeated 6 times). Left hand panel presents the average speedup across runs, while the right hand panel presents each individual run, with Time on the y-axis. Speedup figure given above the points (average in the left-hand panel)."}
dat <- readRDS("timings.rds") %>%
    bind_rows(., data.frame(run = "6", cpu=c(1, 2, 4, 8),
                            time = c(1860.0, 801.7, 660.1, 337.7))) %>%
    group_by(run) %>%
    mutate(speedup = time[1]/time,
           t1_rep = time[1])
dat_summ <- dat %>% group_by(cpu) %>% summarise(speedup = mean(speedup))
dat_summ2 <- dat %>% group_by(run, cpu) %>% summarise(oneToOne = t1_rep/cpu)
p1 <- ggplot(dat, aes(log(cpu), log(speedup), group=run)) +
    geom_point(alpha=.2) +
    geom_line(alpha=.2) +
    scale_x_continuous(breaks=log(dat$cpu), labels=dat$cpu) +
    geom_line(aes(y=log(cpu)), lty=2) +
    scale_y_continuous(trans="reverse", limits = c(NA, log(0.9)),
                       name = "Relative speedup",
                       breaks = log(c(1, 2, 4, 8)),
                       labels = c(1, 2, 4, 8)) +
    coord_equal() +
    geom_line(data=dat_summ, aes(log(cpu), log(speedup)), inherit.aes=F) +
    geom_point(data=dat_summ, aes(log(cpu), log(speedup)), pch=21, fill="white", size=2, inherit.aes=F) +
    geom_text(data=dat_summ, aes(log(cpu), log(speedup), label=sprintf("%0.1f", round(speedup, 1))), inherit.aes=F,
              vjust=-.6, fontface="bold") +
    theme(panel.grid.minor = element_blank())
p2 <- ggplot(dat, aes(log(cpu), log(time), group=run)) +
    geom_point() +
    geom_line() +
    scale_x_continuous(breaks=log(dat$cpu), labels=dat$cpu, expand=c(.1, .1)) +
    geom_line(data=dat_summ2, aes(y=log(oneToOne)), lty=2) +
    scale_y_continuous("Time (s)", limits=c(NA, log(2000)),
                       breaks = log(c(200, 400, 800, 1600)),
                       labels = c(200, 400, 800, 1600), expand = c(.1, .1)) +
    coord_equal() +
    geom_text(aes(log(cpu), log(time), label=sprintf("%0.1f", round(speedup, 1))), inherit.aes=F,
              vjust=-.6, fontface="bold") +
    theme(panel.grid.minor = element_blank(),
          strip.text=element_text(hjust=0, face="bold"),
          strip.background = element_blank()) +
    facet_wrap(~paste0("Run: ", run))
egg::ggarrange(p1, p2, ncol=2)
```
As you would hope, we are getting good average speedups by parallelising the model. While the dataset simulated here is fairly small, with a larger dataset, a 4.5x speedup on a chain that would otherwise take >1 week to run is not trivial at all! The speedup depends on the proportion of the model that can be placed within the ```partial_sum``` function (see [here](https://statmodeling.stat.columbia.edu/2020/05/05/easy-within-chain-parallelisation-in-stan/)). With increased model complexity, as long as the relative proportions of 'stuff in the ```partial_sum```' to 'stuff outside the ```partial_sum```' remain constant, we should expect to see similar speedups. Anecdotally this seems to be the case, and I've seen fairly equivalent speedups running these models with more complexity (e.g. more terms in the detection and occupancy components).
For models where almost all the computation can be placed inside the ```partial_sum```, we can achieve a 1:1 speedup (in some cases, marginally better!). Given that this model has a bunch of centering of hyper-parameters outside of the main model block, it is to be expected that we should see speedups slightly below this line. Some individual runs substantially exceeded the 1:1 line, which presumably arises due to a very slow 1-thread run (with the speedup then being calculated relative to this).
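The dependence of speedup on the proportion of work that sits inside `partial_sum` is commonly summarised by Amdahl's law. The sketch below computes the speedups and per-core efficiency for one set of timings (the run 6 values that feed into Figure 2) and sets them beside Amdahl's bound. The parallel fraction of 0.95 is an arbitrary illustrative value, not something measured for this model, and the helper name `amdahl` is hypothetical.

```r
cpu <- c(1, 2, 4, 8)
time <- c(1860.0, 801.7, 660.1, 337.7)  # seconds, run 6

speedup <- time[1] / time     # relative to the 1-thread run
efficiency <- speedup / cpu   # 1 corresponds to the 1:1 line

# Amdahl's law: upper bound on the speedup with n threads when a
# fraction p of the single-threaded run time can be parallelised
amdahl <- function(p, n) 1 / ((1 - p) + p / n)
p_assumed <- 0.95  # assumed for illustration only

data.frame(cpu, time,
           speedup = round(speedup, 2),
           efficiency = round(efficiency, 2),
           amdahl_bound = round(amdahl(p_assumed, cpu), 2))
```

A single run like this one is noisy (note the 2-thread speedup above 2, which fits the slow 1-thread explanation), so the averages in the left-hand panel are the better guide.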
The other thing to mention is that when comparing speedups, it's important to give the computer a proper job to really see the speedup from parallelisation. If it only takes 30 seconds to run the model, then the speedups are all over the place (though the tendency to get faster with increasing cores is present). Initially I had been simulating 15 species at 15 sites across 4 visits (i.e. 900 observations): speedups (a) were not very repeatable, varying substantially between runs, and (b) didn't scale well. Presumably this largely arises because the overhead paid for parallelising makes up a greater proportion of the total run time, and it's difficult to get large speed-gains on the bit that is being run in parallel.
## When to parallelise?
Ultimately, we care about the number of effective samples per unit time. With extra cores available, we could also increase this just by running extra chains simultaneously. In this case the warmup overhead would be paid an extra time for each additional chain, unlike in the multi-threaded version, where it is just being paid for the original number of chains (assuming you put all your cores into increased threads rather than chains). The obvious question is: what speedup would you need to see to make running multiple threads better than running multiple chains?
As a rough heuristic, we can do the following calculation. In the first place we need to run multiple chains as the baseline comparison, as we ideally need multiple chains to check convergence diagnostics, so I would always run 4 chains in the first place, irrespective of how many cores are available. The question then becomes: with more cores, beyond those used by the 4 parallel chains in the first place, how should they be allocated (to threads or chains)? A back-of-the-envelope calculation would be that the number of post-warmup samples obtained, $n_{samp}$, in a fixed unit of time is:
$$n_{samp} = (n_{tot} \times r - n_{wu}) \times n_{core}$$
where $n_{tot}$ is the total number of samples (warmup included) obtained per chain, modified by $r$, the rate at which samples are obtained, $n_{wu}$ is the number of warmup samples, and $n_{core}$ is the number of cores, each running its own chain. Faster sampling due to more threads, i.e. $r > 1$, will result in more total samples, while the warmup overhead remains the same.
The number of samples obtained after modifying either the rate or the number of cores is therefore:
$$n_{diff} = (n_{tot} \times r_{diff} - n_{tot} \times p_{wu}) \times c_{diff}$$
where both $r_{diff}$ and $c_{diff}$ are proportional increases (i.e. 4 to 8 cores would be $c_{diff}=2$, and a rate doubling would be $r_{diff}=2$). I've also rewritten the warmup as a proportion of the total number of samples, $p_{wu}$. From this we can see that the answer will depend upon the speedup, $r_{diff}$, and the proportion of samples spent on warmup. If we assume you spend half of the total samples on warmup (i.e. $n_{wu}$ is half $n_{tot}$), then doubling the number of chains yields $(1 - 0.5) \times 2 = 1$ (in units of $n_{tot}$), while doubling the number of threads yields $r_{diff} - 0.5$. So as long as the speedup from doubling the threads is >1.5, increasing the number of threads rather than the number of chains is best. If the warmup is less than half the total samples, this speedup 'break-even-point' will increase, as increasing the number of chains is penalised less by the warmup overhead.
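The break-even point can be checked numerically. The sketch below codes the expression for $n_{diff}$ in units of $n_{tot}$ and solves for the speedup at which extra threads and extra chains give the same number of samples; the function names are hypothetical, and this inherits all the simplifications of the heuristic.

```r
# Post-warmup samples (in units of n_tot) for a given rate increase,
# core (chain) increase, and warmup proportion
rel_samples <- function(r_diff, c_diff, p_wu) (r_diff - p_wu) * c_diff

# Speedup at which putting the extra cores into threads matches
# putting them into chains:
#   (r - p_wu) * 1 = (1 - p_wu) * c_diff  =>  r = c_diff * (1 - p_wu) + p_wu
break_even <- function(c_diff, p_wu) c_diff * (1 - p_wu) + p_wu

# Doubling the cores, with half of all iterations spent on warmup
break_even(c_diff = 2, p_wu = 0.5)
rel_samples(r_diff = 1.5, c_diff = 1, p_wu = 0.5)  # extra threads
rel_samples(r_diff = 1, c_diff = 2, p_wu = 0.5)    # extra chains

# A shorter warmup raises the bar for multi-threading
break_even(c_diff = 2, p_wu = 0.25)
```

Comparing `break_even()` against the measured speedup for the same increase in cores gives a quick indication of which allocation is likely to yield more samples.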
Multi-threading will also get diminishing returns, as the parallelisable portion of the model becomes sliced ever more ways. Thus $r_{diff}$ isn't constant, but at some point will start to plateau with increasing threads. So if the model is already on a bunch of threads, then it may become worth running additional chains rather than increasing the number of threads further.
Importantly, this is just a vague heuristic, and there are likely subtleties in the relative speed of sampling. During warmup the HMC is also adapting, which presumably means the rate of sampling is slower, which would adjust this relationship further in favour of running multiple threads. This caveat aside, running this model on a dataset of this size appears to come out in favour of multi-threading up to about 4 threads, and approximately break-even at around 8 threads. Beyond this, to get more samples in a fixed time it would be best to add chains.