-
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathclassifier.py
executable file
·72 lines (58 loc) · 2.41 KB
/
classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python3
import lzma
import logging
from typing import Union, Dict
from pickle import load
from sklearn.model_selection import train_test_split
from sklearn.ensemble import VotingClassifier
# Module-level logger; the original call discarded the logger it created.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.WARNING)


class Classifier(object):
    """
    A classifier for categorising text data using a VotingClassifier.

    This class loads a pre-trained, pickled VotingClassifier from an
    xz-compressed file, exposes it through ``model()`` and supports
    re-fitting it with new data through ``train()``.

    Attributes:
        clf (VotingClassifier | None): The loaded model, or ``None`` when the
            model file could not be opened or read.
    """

    # Location used when no explicit path is passed to the constructor.
    DEFAULT_MODEL_PATH = './data/voting_classifier.pickle.xz'

    def __init__(self, model_path: str = DEFAULT_MODEL_PATH) -> None:
        """
        Initialise the instance by loading a pickled model from ``model_path``.

        If the file cannot be opened or read (``OSError``), the failure is
        logged and ``clf`` is left as ``None`` instead of being undefined.
        Other exceptions (for example a corrupt archive) still propagate.

        Parameters:
            - model_path (str): Path to the xz-compressed pickle file.
              Defaults to ``DEFAULT_MODEL_PATH``.
        """
        # Always define the attribute so later access never raises
        # AttributeError after a failed load.
        self.clf = None
        try:
            with lzma.open(model_path, 'rb') as fh:
                # NOTE: unpickling can execute arbitrary code; only load
                # model files that come from a trusted source.
                self.clf = load(fh)
        except OSError:
            logger.exception("Unable to load file")
        else:
            # Only report completion when the load actually succeeded.
            logger.info("Done loading file")

    def model(self) -> Union["VotingClassifier", None]:
        """
        Provide access to the loaded model.

        Returns:
            The loaded model, or ``None`` if loading failed.
        """
        return self.clf

    def train(self, data: Dict[str, dict]) -> Union[object, Dict[str, str]]:
        """
        Train the classifier with the provided data.

        The data is split into training and testing sets (80/20,
        ``random_state=0``) and the classifier is fitted on the training
        part. This method does not raise: any exception during splitting or
        fitting is logged and reported through the returned dictionary.

        Parameters:
            - data (dict): A dictionary with two keys, 'body' (the features)
              and 'categories' (the labels).

        Returns:
            - object: The trained classifier on success, otherwise a
              dictionary with the key 'error' and the error message as its
              value.
        """
        if self.clf is None:
            return {'error': 'Model is not loaded'}
        try:
            # NOTE(review): the inputs are passed on as dicts; scikit-learn
            # selects samples by integer position, so this presumably relies
            # on keys 0..n-1 -- confirm against the caller's payload.
            xtrain, _xtest, ytrain, _ytest = train_test_split(
                dict(data['body']), dict(data['categories']),
                test_size=0.2, random_state=0)
            self.clf.fit(xtrain, ytrain)
        except Exception as e:
            logger.exception("Training failed")
            return {'error': str(e)}
        return self.clf