-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path tfm_ner.jac
63 lines (57 loc) · 1.35 KB
/
tfm_ner.jac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
node tfm_ner {
    // Graph node that hosts the tfm_ner actions. Walkers invoke its abilities
    // (here::train / here::infer) and exchange inputs/outputs through their
    // own fields, read and written here via `visitor`.
    can tfm_ner.extract_entity;
    can tfm_ner.train;

    // Runs entity extraction on visitor.query and stores the action's return
    // value in visitor.prediction.
    can infer {
        visitor.prediction = tfm_ner.extract_entity(text = visitor.query);
    }

    // Loads examples from the JSON file named by visitor.train_file and trains
    // in "default" mode for visitor.num_train_epochs epochs.
    // NOTE(review): the same examples are passed as val_data, so validation
    // results reflect fit to the training set rather than generalization.
    can train {
        samples = file.load_json(visitor.train_file);
        tfm_ner.train(
            mode = "default",
            epochs = visitor.num_train_epochs,
            train_data = samples,
            val_data = samples
        );
    }
}
walker train {
    // Trains the tfm_ner model by visiting a tfm_ner node hanging off root and
    // invoking its `train` ability.
    //   train_file       - path of the JSON training file (read by tfm_ner::train)
    //   num_train_epochs - forwarded as the `epochs` argument (default 50)
    //   from_scratch     - NOTE(review): declared but not read anywhere in this
    //                      file; tfm_ner::train always uses mode "default".
    has train_file;
    has num_train_epochs = 50, from_scratch = true;
    root {
        // Reuse an existing tfm_ner node instead of spawning one per run.
        // Spawning unconditionally made nodes accumulate on a persistent graph,
        // and `take --> node::tfm_ner` then visited every one of them, running
        // training once per accumulated node.
        take --> node::tfm_ner else {
            spawn here ++> node::tfm_ner;
            take --> node::tfm_ner;
        }
    }
    tfm_ner: here::train;
}
walker infer {
    // Runs entity extraction through a tfm_ner node hanging off root.
    //   query       - input text; overwritten from stdin in interactive mode
    //   interactive - true (default): read queries from stdin in an endless loop
    //                 and print each prediction; false: run once on `query`
    //                 and report the prediction
    //   prediction  - filled in by the node's `infer` ability
    //   labels      - NOTE(review): declared but not used in this file
    has query, interactive = true;
    has labels, prediction;
    root {
        // Reuse an existing tfm_ner node instead of spawning one per run.
        // Spawning unconditionally made nodes accumulate on a persistent graph,
        // so the walker visited each of them and reported one prediction per node.
        take --> node::tfm_ner else {
            spawn here ++> node::tfm_ner;
            take --> node::tfm_ner;
        }
    }
    tfm_ner {
        if (interactive) {
            // There is no break: the loop only ends when the process is
            // interrupted (the prompt suggests Ctrl-C).
            while true {
                query = std.input("Enter input text (Ctrl-C to exit)> ");
                here::infer;
                std.out(prediction);
            }
        } else {
            here::infer;
            report prediction;
        }
    }
}
walker save_model {
    // Calls the tfm_ner.save_model action with the caller-supplied model_path
    // (presumably writing the current model there; confirm in the action docs).
    // The action's return value is discarded; nothing is reported.
    has model_path;
    can tfm_ner.save_model;
    tfm_ner.save_model(
        model_path = model_path
    );
}
walker load_model {
    // Calls the tfm_ner.load_model action with the caller-supplied model_path
    // (presumably restoring a previously saved model; confirm in the action docs).
    // The action's return value is discarded; nothing is reported.
    has model_path;
    can tfm_ner.load_model;
    tfm_ner.load_model(
        model_path = model_path
    );
}