Add MLBA (WIP) #75

Merged
merged 6 commits into from Jul 3, 2024
Changes from 4 commits
2 changes: 1 addition & 1 deletion .github/workflows/benchmark_pr.yml
@@ -16,7 +16,7 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: "1.8"
version: "1.9"
- uses: julia-actions/cache@v1
- name: Extract Package Name from Project.toml
id: extract-package-name
2 changes: 1 addition & 1 deletion docs/src/mdft.md
@@ -154,7 +154,7 @@ M₂ = [
choices,rts = rand(dist, 10_000, M₂; Δt = .001)
probs2 = map(c -> mean(choices .== c), 1:2)
```
Here, we see that job `A` is prefered over job `B`. Also note, in the code block above, `rand` has a keyword argument `Δt` which controls the precision of the discrete approximation. The default value is `Δt = .001`.
Here, we see that job `A` is preferred over job `B`. Also note, in the code block above, `rand` has a keyword argument `Δt`, which controls the precision of the discrete-time approximation. The default value is `Δt = .001`.

Next, we will simulate the choice between jobs `A`, `B`, and `S`.

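Regarding the `Δt` keyword mentioned in the changed line above, a quick sketch (not part of the diff) reusing `dist` and `M₂` from the surrounding documentation example; a coarser step size runs faster at the cost of a rougher discrete-time approximation:

```julia
# Sketch only: larger Δt trades approximation accuracy for speed.
choices, rts = rand(dist, 10_000, M₂; Δt = .01)
probs_coarse = map(c -> mean(choices .== c), 1:2)
```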
16 changes: 8 additions & 8 deletions src/LBA.jl
@@ -5,11 +5,11 @@ A model object for the linear ballistic accumulator.

# Parameters

- `ν`: a vector of drift rates
- `σ`: a vector of drift rate standard deviation
- `A`: max start point
- `k`: A + k = b, where b is the decision threshold
- `τ`: a encoding-response offset
- `ν::Vector{T}`: a vector of drift rates
- `σ::Vector{T}`: a vector of drift rate standard deviations
- `A::T`: max start point
- `k::T`: A + k = b, where b is the decision threshold
- `τ::T`: an encoding-response offset

# Constructors

@@ -50,13 +50,13 @@ function LBA(ν, σ, A, k, τ)
return LBA(ν, σ, A, k, τ)
end

LBA(; τ = 0.3, A = 0.8, k = 0.5, ν = [2.0, 1.75], σ = fill(1.0, length(ν))) =
LBA(ν, σ, A, k, τ)

function params(d::LBA)
return (d.ν, d.σ, d.A, d.k, d.τ)
end

LBA(; τ = 0.3, A = 0.8, k = 0.5, ν = [2.0, 1.75], σ = fill(1.0, length(ν))) =
LBA(ν, σ, A, k, τ)

function select_winner(dt)
if any(x -> x > 0, dt)
mi, mv = 0, Inf
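For orientation, a minimal usage sketch (not part of the diff) of the keyword constructor shown in this file; the values are the documented defaults, and the `rand` call relies on the generic `SSM2D` method:

```julia
using SequentialSamplingModels

# Construct an LBA with the documented defaults; the decision threshold is b = A + k.
dist = LBA(; ν = [2.0, 1.75], A = 0.8, k = 0.5, τ = 0.3)

# Simulate 1_000 trials; returns a NamedTuple of choices and reaction times.
choices, rts = rand(dist, 1_000)
```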
2 changes: 1 addition & 1 deletion src/LCA.jl
@@ -177,5 +177,5 @@ function simulate(model::AbstractLCA; Δt = 0.001, _...)
push!(evidence, copy(x))
push!(time_steps, t)
end
return time_steps, reduce(vcat, transpose.(evidence))
return time_steps, stack(evidence, dims = 1)
end
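Note that `Base.stack` was introduced in Julia 1.9, consistent with the CI version bump above. The same `reduce(vcat, transpose.(evidence))` → `stack(evidence, dims = 1)` replacement is applied in `MDFT.jl` and `RDM.jl` below. A standalone check of the equivalence (hypothetical data):

```julia
# Each element of `evidence` is the accumulator state at one time step;
# both constructions return a (time steps × accumulators) matrix.
evidence = [rand(3) for _ in 1:5]
old = reduce(vcat, transpose.(evidence))
new = stack(evidence, dims = 1)
old == new   # true
```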
2 changes: 1 addition & 1 deletion src/MDFT.jl
@@ -367,5 +367,5 @@ function simulate(model::MDFT, M::AbstractArray; Δt = 0.001, _...)
push!(evidence, copy(x))
push!(time_steps, t)
end
return time_steps, reduce(vcat, transpose.(evidence))
return time_steps, stack(evidence, dims = 1)
end
148 changes: 148 additions & 0 deletions src/MLBA.jl
@@ -0,0 +1,148 @@
"""
MLBA{T <: Real} <: AbstractMLBA

# Fields

- `ν::Vector{T}`: a vector of drift rates, which is a function of β₀, λₚ, λₙ, γ
- `β₀::T`: baseline input for drift rate
- `λₚ::T`: decay constant for attention weights of positive differences
- `λₙ::T`: decay constant for attention weights of negative differences
- `γ::T`: risk aversion exponent for subjective values
- `σ::Vector{T}`: a vector of drift rate standard deviations
- `A::T`: max start point
- `k::T`: A + k = b, where b is the decision threshold
- `τ::T`: an encoding-response offset

# Constructors

MLBA(ν, β₀, λₚ, λₙ, γ, σ, A, k, τ)

MLBA(;
n_alternatives = 3,
ν = fill(0.0, n_alternatives),
β₀ = 1.0,
λₚ = 1.0,
λₙ = 1.0,
γ = 0.70,
τ = 0.3,
A = 0.8,
k = 0.5,
σ = fill(1.0, n_alternatives)
)

# References

Trueblood, J. S., Brown, S. D., & Heathcote, A. (2014). The multiattribute linear ballistic accumulator model of context effects in multialternative choice. Psychological Review, 121(2), 179.
"""
mutable struct MLBA{T <: Real} <: AbstractMLBA
ν::Vector{T}
β₀::T
λₚ::T
λₙ::T
γ::T
σ::Vector{T}
A::T
k::T
τ::T
end

function MLBA(ν, β₀, λₚ, λₙ, γ, σ, A, k, τ)
_, β₀, λₚ, λₙ, γ, _, A, k, τ = promote(ν[1], β₀, λₚ, λₙ, γ, σ[1], A, k, τ)
ν = convert(Vector{typeof(k)}, ν)
σ = convert(Vector{typeof(k)}, σ)
return MLBA(ν, β₀, λₚ, λₙ, γ, σ, A, k, τ)
end

MLBA(;
n_alternatives = 3,
ν = fill(0.0, n_alternatives),
β₀ = 1.0,
λₚ = 1.0,
λₙ = 1.0,
γ = 0.70,
τ = 0.3,
A = 0.8,
k = 0.5,
σ = fill(1.0, n_alternatives)
) =
MLBA(ν, β₀, λₚ, λₙ, γ, σ, A, k, τ)

function params(d::AbstractMLBA)
return (d.ν, d.β₀, d.λₚ, d.λₙ, d.γ, d.σ, d.A, d.k, d.τ)
end

rand(d::AbstractMLBA, M::AbstractArray) = rand(Random.default_rng(), d, M)

function rand(rng::AbstractRNG, d::AbstractMLBA, M::AbstractArray)
compute_drift_rates!(d, M)
return rand(rng, d)
end

rand(d::AbstractMLBA, n_trials::Int, M::AbstractArray) = rand(Random.default_rng(), d, n_trials, M)

function rand(rng::AbstractRNG, d::AbstractMLBA, n_trials::Int, M::AbstractArray)
compute_drift_rates!(d, M)
return rand(rng, d, n_trials)
end

function compute_drift_rates!(dist::AbstractMLBA, M::AbstractArray)
(; ν, β₀) = dist
n_options = length(ν)
ν .= β₀
utilities = map(x -> compute_utility(dist, x), eachrow(M))
for i ∈ 1:n_options
for j ∈ 1:n_options
i == j ? continue : nothing
ν[i] += compare(dist, utilities[i], utilities[j])
end
end
return nothing
end

function compute_weight(dist::AbstractMLBA, u1, u2)
(; λₙ, λₚ) = dist
λ = u1 ≥ u2 ? λₚ : λₙ
return exp(-λ * abs(u1 - u2))
end

function compute_utility(dist::AbstractMLBA, v)
(; γ) = dist
θ = atan(v[2], v[1])
# x and y intercepts for line passing through v with slope -1
a = sum(v)
u = fill(0.0, 2)
u[1] = a / (tan(θ)^γ + 1)^(1 / γ)
u[2] = a * (1 - (u[1] / a)^γ)^(1 / γ)
#u[2] = (a * tan(θ)) / (1 + tan(θ)^γ)^(1/ γ)
return u
end

function compare(dist::AbstractMLBA, u1, u2)
v = 0.0
for i ∈ 1:2
v += compute_weight(dist, u1[i], u2[i]) * (u1[i] - u2[i])
end
return v
end

"""
simulate(rng::AbstractRNG, model::AbstractMLBA, M::AbstractArray; n_steps = 100)

Returns a matrix containing evidence samples of the MLBA decision process. In the matrix, rows
represent samples of evidence per time step and columns represent different accumulators.

# Arguments

- `model::AbstractMLBA`: a subtype of AbstractMLBA

# Keywords

- `n_steps=100`: number of time steps at which evidence is recorded
"""
function simulate(rng::AbstractRNG, model::AbstractMLBA, M::AbstractArray; n_steps = 100)
compute_drift_rates!(model, M)
return simulate(rng, model; n_steps)
end

simulate(model::AbstractMLBA, M::AbstractArray; n_steps = 100) =
simulate(Random.default_rng(), model, M; n_steps)
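A usage sketch of the new model (not taken from the PR; parameter values are illustrative only). Per `compute_drift_rates!` above, each drift rate is the baseline `β₀` plus the summed pairwise comparisons of the subjective utilities that `compute_utility` derives from the attribute matrix:

```julia
using SequentialSamplingModels

dist = MLBA(; n_alternatives = 3, β₀ = 1.0, λₚ = 1.5, λₙ = 2.5, γ = 1.2)

# One row per alternative, one column per attribute (the layout used in test/mlba.jl).
M = [
    1 3
    3 1
    2 2
]

# Drift rates are computed from M, then trials are sampled as in the standard LBA.
choices, rts = rand(dist, 1_000, M)
```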
10 changes: 4 additions & 6 deletions src/RDM.jl
@@ -187,20 +187,18 @@ function simulate(rng::AbstractRNG, model::AbstractRDM; Δt = 0.001)
z = rand(rng, Uniform(0, A), n)
α = k + A
x = z
ϵ = fill(0.0, n)
evidence = [deepcopy(x)]
time_steps = [t]
while all(x .< α)
t += Δt
increment!(rng, model, x, ϵ, ν; Δt)
increment!(rng, model, x, ν; Δt)
push!(evidence, deepcopy(x))
push!(time_steps, t)
end
return time_steps, reduce(vcat, transpose.(evidence))
return time_steps, stack(evidence, dims = 1)
end

function increment!(rng::AbstractRNG, model::AbstractRDM, x, ϵ, ν; Δt)
ϵ .= rand(rng, Normal(0.0, 1.0), length(ν))
x .+= ν * Δt + ϵ * √(Δt)
function increment!(rng::AbstractRNG, model::AbstractRDM, x, ν; Δt)
x .+= ν * Δt .+ rand(rng, Normal(0.0, √(Δt)), length(ν))
return nothing
end
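The refactored `increment!` drops the preallocated `ϵ` buffer and draws the increment noise directly from `Normal(0, √Δt)`, which is distributionally identical to drawing standard-normal noise and scaling it by `√Δt`. A quick check (not part of the PR):

```julia
using Distributions, Statistics

Δt = 0.001
old_noise = rand(Normal(0.0, 1.0), 10^6) .* √Δt    # previous formulation
new_noise = rand(Normal(0.0, √Δt), 10^6)           # new formulation
isapprox(std(old_noise), std(new_noise); atol = 1e-3)   # both ≈ √Δt ≈ 0.0316
```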
3 changes: 3 additions & 0 deletions src/SequentialSamplingModels.jl
@@ -38,6 +38,7 @@ export AbstractCDDM
export AbstractLBA
export AbstractLCA
export AbstractLNR
export AbstractMLBA
export AbstractMDFT
export AbstractPoissonRace
export AbstractRDM
@@ -53,6 +54,7 @@ export LBA
export LCA
export LNR
export maaDDM
export MLBA
export MDFT
export PoissonRace
export SSM1D
@@ -103,4 +105,5 @@ include("poisson_race.jl")
include("stDDM.jl")
include("MDFT.jl")
include("ClassicMDFT.jl")
include("MLBA.jl")
end
8 changes: 8 additions & 0 deletions src/type_system.jl
@@ -58,6 +58,13 @@ An abstract type for the lognormal race model
"""
abstract type AbstractLNR <: SSM2D end

"""
AbstractMLBA <: AbstractLBA

An abstract type for the multi-attribute linear ballistic accumulator
"""
abstract type AbstractMLBA <: AbstractLBA end

"""
AbstractLCA <: SSM2D

@@ -157,6 +164,7 @@ function rand(rng::AbstractRNG, d::SSM2D, n_trials::Int; kwargs...)
end
return (; choice, rt)
end

rand(d::SSM2D, n_trials::Int; kwargs...) =
rand(Random.default_rng(), d, n_trials; kwargs...)

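Because `AbstractMLBA` is placed under `AbstractLBA`, an `MLBA` instance inherits the generic LBA machinery through dispatch. An illustrative check, using the exports added in `SequentialSamplingModels.jl` above:

```julia
using SequentialSamplingModels

# The default MLBA constructor dispatches on the LBA-level methods.
MLBA() isa AbstractLBA   # true, via AbstractMLBA <: AbstractLBA
```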
95 changes: 95 additions & 0 deletions test/mlba.jl
@@ -0,0 +1,95 @@
@safetestset "MLBA Tests" begin
@safetestset "compute_utility" begin
@safetestset "1" begin
using SequentialSamplingModels
using SequentialSamplingModels: compute_utility
using Test

model = MLBA(; γ = 1)

@test compute_utility(model, [3, 1]) ≈ [3, 1]
@test compute_utility(model, [2, 2]) ≈ [2, 2]
end

@safetestset "2" begin
using SequentialSamplingModels
using SequentialSamplingModels: compute_utility
using Test

model = MLBA(; γ = 0.5)

@test compute_utility(model, [3, 1]) ≈ [1.60770, 0.53590] atol = 1e-4
@test compute_utility(model, [2, 2]) ≈ [1, 1]
end

@safetestset "3" begin
using SequentialSamplingModels
using SequentialSamplingModels: compute_utility
using Test

model = MLBA(; γ = 1.2)

@test compute_utility(model, [3, 1]) ≈ [3.2828, 1.0943] atol = 1e-4
@test compute_utility(model, [2, 2]) ≈ [2.2449, 2.2449] atol = 1e-4
end
end

@safetestset "compute_weight" begin
using SequentialSamplingModels
using SequentialSamplingModels: compute_weight
using Test

model = MLBA(; λₚ = 1.5, λₙ = 2.5)

@test compute_weight(model, 2, 3) ≈ 0.082085 atol = 1e-4
@test compute_weight(model, 3, 2) ≈ 0.22313 atol = 1e-4
end

@safetestset "compare" begin
using SequentialSamplingModels
using SequentialSamplingModels: compare
using Test

model = MLBA(; λₚ = 1.5, λₙ = 2.5)

@test compare(model, [3, 1], [1, 3]) ≈ 0.086098 atol = 1e-4
@test compare(model, [1, 3], [3, 1]) ≈ 0.086098 atol = 1e-4
@test compare(model, [2, 4], [3, 5]) ≈ -0.16417 atol = 1e-4
end

@safetestset "compute_drift_rates" begin
@safetestset "1" begin
using SequentialSamplingModels
using SequentialSamplingModels: compute_drift_rates!
using Test

model = MLBA(; λₚ = 1.5, λₙ = 2.5, β₀ = 1, γ = 1.2)

stimuli = [
1 3
3 1
2 2
]

compute_drift_rates!(model, stimuli)
@test model.ν ≈ [1.2269, 1.2269, 1.2546] atol = 1e-4
end

@safetestset "2" begin
using SequentialSamplingModels
using SequentialSamplingModels: compute_drift_rates!
using Test

model = MLBA(; λₚ = 0.75, λₙ = 3, β₀ = 2, γ = 2)

stimuli = [
1 1
4 2
2 3
]

compute_drift_rates!(model, stimuli)
@test model.ν ≈ [1.9480, 3.0471, 3.3273] atol = 1e-4
end
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -2,4 +2,4 @@ using SafeTestsets

files = filter(f -> f ≠ "runtests.jl", readdir())

include.(files)
include.(files)