fix loo_compare in docs, simplify turing model, update project.toml
itsdfish committed Jul 10, 2024
1 parent 9e14830 commit 77423c1
Showing 2 changed files with 19 additions and 24 deletions.
docs/Project.toml: 2 changes (2 additions, 0 deletions)
@@ -5,6 +5,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
+ParetoSmooth = "a68b5a21-f429-434e-8bfa-46b447300aac"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
@@ -18,6 +19,7 @@ Colors = "0.12.0"
DataFrames = "1.0.0"
Documenter = "1"
KernelDensity = "0.6.0"
+ParetoSmooth = "0.7.0"
Plots = "1.0.0"
StatsBase = "0.33.0,0.34.0"
StatsModels = "0.7.0"
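For anyone building the docs locally, here is a minimal sketch (not part of the commit) of picking up the newly added ParetoSmooth dependency; the working directory is an assumption:

```julia
# Hypothetical local check of the updated docs environment.
# Assumes the repository root as the current working directory.
using Pkg

Pkg.activate("docs")   # activate docs/Project.toml
Pkg.instantiate()      # resolve and install dependencies, including ParetoSmooth 0.7.x
```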
docs/src/loo_compare.md: 41 changes (17 additions, 24 deletions)
@@ -10,20 +10,19 @@ First, we'll simulate data from a [`LBA`](@ref) with fixed parameters. Then, we'

Before proceeding, we will load the required packages.

-```@setup loo_compare
+```@example loo_example
using Random
using SequentialSamplingModels
using LinearAlgebra
using ParetoSmooth
using Turing
```

## Data-Generating Model

The next step is to generate simulated data for comparing the models. Here, we'll use an LBA as the true data-generating model:

-```@example loo_compare
+```@example loo_example
Random.seed!(5000)
dist = LBA(ν=[3.0, 2.0], A=0.8, k=0.3, τ=0.3)
@@ -34,52 +33,46 @@ data = rand(dist, 1000)

The following code block defines the model along with its prior distributions using Turing.jl. We'll use this model with different fixed values for the `k` parameter.

-```@example loo_compare
-@model function model_LBA(data, k; min_rt=0.2)
+```@example loo_example
+@model function model_LBA(data, k; min_rt=minimum(data.rt))
    # Priors
    ν ~ MvNormal(fill(3.0, 2), I * 2)
    A ~ truncated(Normal(0.8, 0.4), 0.0, Inf)
    τ ~ Uniform(0.0, min_rt)
    # Likelihood
-    for i in 1:length(data)
-        data[i] ~ LBA(; ν, A, k, τ)
-    end
+    data ~ LBA(; ν, A, k, τ)
end
```

## Estimate the Parameters

Now we'll estimate the parameters using three different fixed `k` values:

-```@example loo_compare
-#prepare the data for model fitting
-dat = [(choice=data.choice[i], rt=data.rt[i]) for i in 1:length(data.rt)]
-min_rt = minimum(data.rt)
-chain_LBA1 = sample(model_LBA(dat, 2, min_rt=min_rt), NUTS(), 1000)
-chain_LBA2 = sample(model_LBA(dat, 0.3,min_rt=min_rt), NUTS(), 1000)
-chain_LBA3 = sample(model_LBA(dat, 1, min_rt=min_rt), NUTS(), 1000)
+```@example loo_example
+chain_LBA1 = sample(model_LBA(data, 2.0), NUTS(), 1000)
+chain_LBA2 = sample(model_LBA(data, 0.3), NUTS(), 1000)
+chain_LBA3 = sample(model_LBA(data, 1.0), NUTS(), 1000)
```

## Compute PSIS-LOO

Next we will use the `psis_loo` function to compute the PSIS-LOO for each model:

-```@example loo_compare
-res1 = psis_loo(model_LBA(dat, 2, min_rt=min_rt), chain_LBA1)
-res2 = psis_loo(model_LBA(dat, 0.3, min_rt=min_rt), chain_LBA2)
-res3 = psis_loo(model_LBA(dat, 1, min_rt=min_rt), chain_LBA3)
-show(stdout, MIME"text/plain"(), ans) #hide
+```@example loo_example
+res1 = psis_loo(model_LBA(data, 2.0), chain_LBA1)
+res2 = psis_loo(model_LBA(data, 0.3), chain_LBA2)
+res3 = psis_loo(model_LBA(data, 1.0), chain_LBA3)
+show(stdout, MIME"text/plain"(), ans)
```

## Compare Models

Finally, we can compare the models using the `loo_compare` function:

-```@example loo_compare
-loo_compare((LBA1=res1, LBA2=res2, LBA3 = res3))
-show(stdout, MIME"text/plain"(), ans) # hide
+```@example loo_example
+loo_compare((LBA1 = res1, LBA2 = res2, LBA3 = res3))
+show(stdout, MIME"text/plain"(), ans)
```

Here we indeed correctly identified the generative model we simulated. It is of note that some researchers have criticized using model comparison metrics such as leave-one-out cross-validation. See Gronau et al. (2019) for more information.
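
For reference, `psis_loo` estimates the expected log pointwise predictive density (ELPD) with Pareto-smoothed importance sampling (Vehtari, Gelman, & Gabry, 2017), and `loo_compare` reports differences in this estimate between models. Below, `y_i` is observation `i`, `θ^s` is posterior draw `s` of `S` total draws, and `w_i^s` is the Pareto-smoothed importance weight of draw `s` for observation `i`:

```math
\widehat{\mathrm{elpd}}_{\mathrm{loo}} = \sum_{i=1}^{n} \log \hat{p}(y_i \mid y_{-i}),
\qquad
\hat{p}(y_i \mid y_{-i}) = \frac{\sum_{s=1}^{S} w_i^{s} \, p(y_i \mid \theta^{s})}{\sum_{s=1}^{S} w_i^{s}}
```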
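On the "simplify turing model" part of the commit message: the per-trial loop is replaced by a single vectorized observation statement. Here is a minimal sketch of the intended equivalence, assuming SequentialSamplingModels provides a per-trial `logpdf(dist, choice, rt)` and a joint `loglikelihood(dist, data)` method for the `(choice, rt)` NamedTuple returned by `rand` (the joint method name is an assumption):

```julia
# Sketch: the vectorized statement `data ~ LBA(...)` in the simplified model
# should contribute the same total log-density that the removed loop
# accumulated one trial at a time.
using Distributions: logpdf, loglikelihood
using Random
using SequentialSamplingModels

Random.seed!(5000)
dist = LBA(ν = [3.0, 2.0], A = 0.8, k = 0.3, τ = 0.3)
data = rand(dist, 10)   # NamedTuple with fields `choice` and `rt`

# Old style: sum per-trial log-densities, as the removed loop did.
ll_loop = sum(logpdf(dist, data.choice[i], data.rt[i]) for i in 1:length(data.rt))

# New style: score the whole data set at once (assumed joint method).
ll_joint = loglikelihood(dist, data)

ll_loop ≈ ll_joint   # expected to hold
```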
