-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path2-validation.R
95 lines (81 loc) · 2.19 KB
/
2-validation.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Validate the selected autoencoder models on ICD-10-CM code embeddings from other years.
library(torch)
library(purrr)
library(stringr)
library(readr)
library(tibble)
library(tidyr)
library(progress)
source("autoencoder.R")
# Device used for all inference below. Only genuine torch device types are
# candidates: cuDNN, MKL-DNN, OpenMP and MKL are compute backends rather than
# devices, so strings such as "openmp" or "mkl" cannot be passed to `$to()`.
# CUDA is detected with cuda_is_available() instead of via cuDNN, so a GPU is
# still used when cuDNN happens to be missing.
default_device = if (cuda_is_available()) {
  "cuda"
} else if (backends_mps_is_available()) {
  "mps"
} else {
  "cpu"
}
# Years whose ICD-10-CM embeddings are used for validation. Defined once so the
# embedding paths and the `year` labels in `xs` cannot drift apart.
validation_years = 2019:2022

# Every file under autoencoder-models/ is loaded as a saved autoencoder.
ae_model_paths = dir("autoencoder-models", full.names = TRUE)
if (length(ae_model_paths) == 0) {
  stop("No autoencoder models found in 'autoencoder-models'.", call. = FALSE)
}

# One character vector of embedding files per year (a list, one element per
# entry of `validation_years`).
icd10_embedding_paths = file.path("icd-10-cm-embeddings", validation_years) |>
  map(\(p) dir(p, full.names = TRUE))

# One row per year: the embedding dataset and the year it belongs to.
xs = tibble(
  embed = map(icd10_embedding_paths, ICD10Embedding),
  year = validation_years
)

# One row per saved autoencoder.
vds = tibble(
  model = map(ae_model_paths, torch_load),
  # NOTE(review): assumes the first 4-digit run in each model path encodes
  # the embedding dimension; confirm against the model file-naming scheme.
  embedding_dim = str_extract(ae_model_paths, "\\d{4}") |> as.integer()
)
# Mean reconstruction error of an autoencoder over a dataset.
#
# d: a torch dataset; each batch produced by `dataloader(d)` must expose the
#    input tensor as `$x` (presumably (batch, features); TODO confirm).
# m: the autoencoder module. NOTE: it is moved to `device` and switched to
#    evaluation mode, which acts on the module object the caller passed in.
# device: torch device string used for inference.
# Returns a single number: the average over all rows of `d` of the mean
# squared difference between a row and its reconstruction (mean over dim 2).
pred_error = function(d, m, device = default_device) {
  m = m$to(device = device)
  # Evaluation mode, so layers such as dropout or batch norm (if the model
  # has any) behave as they should during validation.
  m$eval()
  dl = dataloader(d, batch_size = 100, num_workers = 5)
  pb = progress_bar$new(
    format = "[:bar] :percent eta: :eta",
    total = length(dl)
  )
  # One slot per batch. This avoids re-copying a growing vector with c().
  errs = vector("list", length(dl))
  i = 0L
  # Inference only: skip building the autograd graph to save time and memory.
  with_no_grad({
    loop(for (b in dl) {
      i = i + 1L
      xt = b$x$to(device = device)
      errs[[i]] = torch_mean((xt - m(xt))^2, 2)$to(device = "cpu") |>
        as.numeric()
      pb$tick()
    })
  })
  mean(unlist(errs))
}
# Mean per-row variance of the inputs in a dataset. It is the scale that the
# reconstruction error is compared against.
#
# d: a torch dataset whose batches expose the input tensor as `$x`.
# device: torch device string on which the variances are computed.
# Returns a single number: the average over all rows of `torch_var()` taken
# across dim 2 with torch_var's default settings.
variation = function(d, device = default_device) {
  dl = dataloader(d, batch_size = 100, num_workers = 5)
  pb = progress_bar$new(
    format = "[:bar] :percent eta: :eta",
    total = length(dl)
  )
  # One slot per batch. This avoids re-copying a growing vector with c().
  vars = vector("list", length(dl))
  i = 0L
  loop(for (b in dl) {
    i = i + 1L
    xt = b$x$to(device = device)
    vars[[i]] = torch_var(xt, 2)$to(device = "cpu") |> as.numeric()
    pb$tick()
  })
  mean(unlist(vars))
}
# Score every (model, year) pair: each autoencoder against each year's
# embeddings.
x = expand_grid(vds, xs)
x$pred_error = map_dbl(seq_len(nrow(x)), \(i) {
  print(i)  # progress: row of `x` currently being scored
  pred_error(x$embed[[i]], x$model[[i]])
})
# variation() depends only on the data, not on the model. Compute it once per
# year and broadcast it to every row, instead of repeating it for each model.
year_variation = map_dbl(seq_along(xs$embed), \(i) {
  print(i)  # progress: year (row of `xs`) currently being measured
  variation(xs$embed[[i]])
})
x$variation = year_variation[match(x$year, xs$year)]
# Full results, including the `model` and `embed` list columns.
# NOTE(review): those columns hold torch modules/datasets; saveRDS() may not
# preserve their tensor state. Confirm the raw file is usable, or drop those
# columns / use torch_save() for them.
saveRDS(x, "year-validation-raw.rds")

# Summary table sorted by year (order() is stable, like dplyr::arrange()).
# Built with base R + tibble, so it does not rely on dplyr being attached.
# `cod` is reconstruction error divided by mean per-row variance.
# NOTE(review): that ratio reads like the unexplained share of variance
# (1 - R^2); confirm it, rather than 1 - ratio, is the intended "cod".
ord = order(x$year)
tibble(
  year = x$year[ord],
  embedding_dim = x$embedding_dim[ord],
  pred_error = x$pred_error[ord],
  cod = x$pred_error[ord] / x$variation[ord]
) |>
  saveRDS("year-validation.rds")