config.py
import logging
import time
from enum import Enum
from typing import List, Optional, Tuple

from sklearn.calibration import CalibratedClassifierCV
from sklearn.svm import LinearSVC
from transformers import BertTokenizer

from corpora.taxonomy import Taxonomy

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ISO_DA")

class Model(Enum):
    SVM = "SVM"
    TRANSFORMER = "Transformer"

class Config:
    def __init__(self, model_type: Model, taxonomy: Taxonomy, out_folder: Optional[str] = None):
        current_timestamp = time.time()
        self.model_type = model_type
        # Default to a timestamped output folder so separate runs do not overwrite each other.
        if out_folder is None:
            out_folder = f"models/{current_timestamp}/"
        self.out_folder = out_folder
        self.acceptance_threshold = 0.5  # default acceptance threshold for predicted labels
        self.taxonomy = taxonomy

    @staticmethod
    def from_dict(dict_):
        raise NotImplementedError()

    def to_dict(self):
        raise NotImplementedError()
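
# Usage sketch of the defaults above (Taxonomy.ISO is an illustrative name, not a
# guaranteed member of the Taxonomy enum): when out_folder is omitted, Config
# derives a timestamped folder, and acceptance_threshold starts at 0.5.
#
#   config = SVMConfig(taxonomy=Taxonomy.ISO, indexed_pos=True, indexed_dep=False,
#                      ngrams=True, dep=True, pos=True, prev=True)
#   config.out_folder            # -> "models/<timestamp>/"
#   config.acceptance_threshold  # -> 0.5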

class SVMConfig(Config):
    def __init__(
        self,
        taxonomy: Taxonomy,
        indexed_pos: bool,
        indexed_dep: bool,
        ngrams: bool,
        dep: bool,
        pos: bool,
        prev: bool,
        pipeline_files: Optional[List[str]] = None,
        out_folder: Optional[str] = None,
    ):
        Config.__init__(self, Model.SVM, taxonomy, out_folder)
        if pipeline_files is None:
            pipeline_files = []
        # Feature switches for the SVM pipeline.
        self.indexed_pos = indexed_pos
        self.indexed_dep = indexed_dep
        self.ngrams = ngrams
        self.dep = dep
        self.pos = pos
        self.prev = prev
        self.taxonomy = taxonomy
        self.pipeline_files = pipeline_files

    @staticmethod
    def create_classifier():
        # Wrap the linear SVM in probability calibration so predict_proba is available.
        return CalibratedClassifierCV(LinearSVC(C=0.1), cv=3)

    @staticmethod
    def from_dict(dict_):
        svm_config = SVMConfig(
            indexed_pos=dict_["indexed_pos"],
            dep=dict_["dep"],
            pos=dict_["pos"],
            prev=dict_["prev"],
            indexed_dep=dict_["indexed_dep"],
            ngrams=dict_["ngrams"],
            taxonomy=Taxonomy.from_str(dict_["taxonomy"]),
            pipeline_files=dict_["pipeline_files"],
        )
        svm_config.out_folder = dict_["out_folder"]
        return svm_config

    def to_dict(self):
        return {
            "indexed_pos": self.indexed_pos,
            "dep": self.dep,
            "pos": self.pos,
            "prev": self.prev,
            "indexed_dep": self.indexed_dep,
            "ngrams": self.ngrams,
            "taxonomy": self.taxonomy.to_str(),
            "pipeline_files": self.pipeline_files,
            "out_folder": self.out_folder,
        }
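
# Usage sketch: build an SVMConfig, obtain its calibrated classifier, and
# round-trip the config through to_dict()/from_dict(). Taxonomy.ISO is an
# illustrative assumption, not a guaranteed member of the Taxonomy enum.
def _example_svm_config() -> SVMConfig:
    config = SVMConfig(
        taxonomy=Taxonomy.ISO,  # illustrative member; substitute a real Taxonomy value
        indexed_pos=True,
        indexed_dep=False,
        ngrams=True,
        dep=True,
        pos=True,
        prev=True,
    )
    # Calibrated LinearSVC; exposes probability estimates once fitted.
    classifier = SVMConfig.create_classifier()
    # Serialization round-trip: every key written by to_dict() is read back by from_dict().
    restored = SVMConfig.from_dict(config.to_dict())
    assert restored.ngrams == config.ngrams
    assert restored.out_folder == config.out_folder
    return restored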

class TransformerConfig(Config):
    def __init__(
        self,
        taxonomy: Taxonomy,
        device: str,
        optimizer: type,
        lr: float,
        batch_size: int,
        max_seq_len: int,
        n_epochs: int,
        pipeline_files: Optional[List[str]] = None,
        out_folder: Optional[str] = None,
    ):
        Config.__init__(self, Model.TRANSFORMER, taxonomy, out_folder)
        if pipeline_files is None:
            pipeline_files = []
        self.device = device
        self.lr = lr
        self.optimizer = optimizer  # an optimizer class (the parameter is a type, not an instance)
        self.taxonomy = taxonomy
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        # Tokenizer is fixed to bert-base-uncased; the padding and unknown-token ids are cached.
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.pad_index = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
        self.unk_index = self.tokenizer.convert_tokens_to_ids(self.tokenizer.unk_token)
        self.n_epochs = n_epochs
        self.pipeline_files = pipeline_files

    @staticmethod
    def from_dict(dict_):
        transformer_config = TransformerConfig(
            taxonomy=Taxonomy.from_str(dict_["taxonomy"]),
            device=dict_["device"],
            optimizer=dict_["optimizer"],
            lr=dict_["lr"],
            batch_size=dict_["batch_size"],
            max_seq_len=dict_["max_seq_len"],
            n_epochs=dict_["n_epochs"],
            pipeline_files=dict_["pipeline_files"],
        )
        transformer_config.out_folder = dict_["out_folder"]
        return transformer_config

    def to_dict(self):
        return {
            "device": self.device,
            "lr": self.lr,
            "optimizer": self.optimizer,  # stored as the class object itself
            "taxonomy": self.taxonomy.to_str(),
            "out_folder": self.out_folder,
            "batch_size": self.batch_size,
            "max_seq_len": self.max_seq_len,
            "pad_index": self.pad_index,
            "unk_index": self.unk_index,
            "n_epochs": self.n_epochs,
            "pipeline_files": self.pipeline_files,
        }
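
# Usage sketch: build a TransformerConfig and round-trip it through
# to_dict()/from_dict(). torch.optim.Adam is used as a typical optimizer class
# and Taxonomy.ISO as an illustrative taxonomy member; both are assumptions of
# this sketch rather than requirements of the class.
def _example_transformer_config() -> TransformerConfig:
    import torch  # assumed to be available alongside transformers

    config = TransformerConfig(
        taxonomy=Taxonomy.ISO,  # illustrative member; substitute a real Taxonomy value
        device="cuda" if torch.cuda.is_available() else "cpu",
        optimizer=torch.optim.Adam,  # passed as a class, matching the `type` annotation
        lr=2e-5,
        batch_size=32,
        max_seq_len=128,
        n_epochs=3,
    )
    # Serialization round-trip (note: reconstructing the config reloads the tokenizer).
    restored = TransformerConfig.from_dict(config.to_dict())
    assert restored.max_seq_len == config.max_seq_len
    return restored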