-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathestimator_compare.jl
87 lines (81 loc) · 3.02 KB
/
estimator_compare.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
using Base.Filesystem
using ComputerAdaptiveTesting
using FittedItemBanks.DummyData: std_normal
using ComputerAdaptiveTesting.Sim
using ComputerAdaptiveTesting.NextItemRules
using ComputerAdaptiveTesting.TerminationConditions
using ComputerAdaptiveTesting.Aggregators
using FittedItemBanks
using PsychometricsBazaarBase.Integrators
using PsychometricsBazaarBase.Optimizers
import PsychometricsBazaarBase.IntegralCoeffs
using CATPlots
using ItemResponseDatasets: prompt_readline
using ItemResponseDatasets.VocabIQ
using GLMakie
using RIrtWrappers.Mirt
"""
    get_item_bank()

Fit a four-parameter logistic (4PL) IRT model to the cached, marked response data and
return the fitted item bank.

The data comes from `get_marked_df_cached()`, presumably provided by
`ItemResponseDatasets.VocabIQ`; TODO confirm. Only the first element of the value
returned by `fit_4pl` is kept. Any further elements are discarded; presumably these are
fit diagnostics (TODO confirm). `TOL=1e-2` is forwarded unchanged to `fit_4pl`,
presumably as its convergence tolerance.
"""
function get_item_bank()
    responses = get_marked_df_cached()
    fitted = fit_4pl(responses; TOL=1e-2)
    return fitted[1]
end
"""
    main()

Run an interactive computerized adaptive test (CAT) on the item bank returned by
`get_item_bank` and, after every response, plot several ability estimators and trackers
side by side so they can be compared.

For each item:

- its fitted parameters are printed and the question is prompted on the terminal;
- whether the answer was correct is printed;
- the current mode ability estimate and its standard deviation are printed;
- a new GLMakie window shows the curves of every estimator/tracker over `-6:0.01:6`.

The user presses enter to continue. The test stops after 45 items
(`FixedItemsTerminationCondition(45)`).
"""
function main()
    item_bank = get_item_bank()

    # Quadrature over the ability range [-6, 6] with 61 points. "GK" is presumably
    # Gauss–Kronrod; TODO confirm in PsychometricsBazaarBase.Integrators.
    integrator = FixedGKIntegrator(-6, 6, 61)
    ability_integrator = AbilityIntegrator(integrator)

    # Two base estimators: the raw likelihood, and the likelihood combined with a
    # standard-normal prior.
    lh_ability_est = LikelihoodAbilityEstimator()
    prior_ability_est = PriorAbilityEstimator(std_normal)

    # Point estimate used during the CAT: the mode of the prior-weighted estimator,
    # found with Nelder–Mead restricted to [-6, 6].
    optimizer = AbilityOptimizer(OneDimOptimOptimizer(-6.0, 6.0, NelderMead()))
    ability_estimator = ModeAbilityEstimator(prior_ability_est, optimizer)

    # Trackers that are updated incrementally as responses arrive. The grid trackers
    # only cover [-3, 3], which is narrower than the integration and plotting range.
    grid = -3:0.1:3
    lh_grid_tracker = GriddedAbilityTracker(lh_ability_est, grid)
    prior_grid_tracker = GriddedAbilityTracker(prior_ability_est, grid)
    closed_normal_tracker = ClosedFormNormalAbilityTracker(prior_ability_est)
    # Alternative tracker that can be enabled for comparison:
    #laplace_normal_tracker = LaplaceAbilityTracker(prior_ability_est)

    rules = CatRules(
        MultiAbilityTracker([
            lh_grid_tracker,
            prior_grid_tracker,
            closed_normal_tracker,
        ]),
        ability_estimator,
        # Item-selection criterion based on the ability variance under the prior
        # estimator; exact selection semantics live in NextItemRules.
        AbilityVarianceStateCriterion(prior_ability_est, ability_integrator),
        FixedItemsTerminationCondition(45),
    )

    # Show the next item's parameters, then ask the user the question.
    function get_response(response_idx, response_name)
        params = item_params(item_bank, response_idx)
        println("Parameters for next question: $params")
        return prompt_readline(VocabIQ.questions[response_idx])
    end

    # Report the outcome and current estimate, then plot all estimators and trackers.
    function new_response_callback(tracked_responses, terminating)
        # A positive stored response value is treated as a correct answer.
        if tracked_responses.responses.values[end] > 0
            println("Correct")
        else
            println("Wrong")
        end

        ability = ability_estimator(tracked_responses)
        variance = variance_given_mean(
            ability_integrator, prior_ability_est, tracked_responses, ability
        )
        # `±` conventionally denotes a standard deviation, so report the square root of
        # the variance. Clamp at zero so that a tiny negative value from numerical
        # integration cannot make `sqrt` throw a DomainError.
        sd = sqrt(max(variance, zero(variance)))
        println("Got ability estimate: $ability ± $sd")

        fig = plot_likelihoods(
            [
                ("Likelihood", lh_ability_est),
                ("Prior", prior_ability_est),
                ("Mode", ability_estimator),
                ("Likelihood grid", lh_grid_tracker),
                ("Prior grid", prior_grid_tracker),
                ("Closed normal (Owen 1975)", closed_normal_tracker),
            ],
            tracked_responses,
            ability_integrator,
            -6:0.01:6,
        )
        # Each response gets its own window, so earlier plots stay available for
        # comparison.
        display(GLMakie.Screen(), fig)
        println("Press enter to continue")
        readline()
        return nothing
    end

    loop_config = CatLoopConfig(
        rules=rules,
        get_response=get_response,
        new_response_callback=new_response_callback,
    )
    return run_cat(loop_config, item_bank)
end

main()