"""Example of a multiclass active learning text classification with pytorch.
Note:
This examples requires gensim>=4.0.0 which is used for obtaining word2vec embeddings.
"""
import torch
import numpy as np

from collections import Counter

from small_text import (
    ActiveLearnerException,
    EmptyPoolException,
    BreakingTies,
    KimCNNClassifierFactory,
    PoolBasedActiveLearner,
    PoolExhaustedException,
    random_initialization_stratified
)

from examplecode.data.example_data_multiclass import (
    get_train_test,
    preprocess_data
)
from examplecode.shared import evaluate

try:
    import gensim.downloader as api
except ImportError:
    raise ActiveLearnerException('This example requires the gensim library. '
                                 'Please install gensim>=4.0.0 to run this example.')
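

# The main entry point ties the example together: pretrained word2vec
# embeddings feed a KimCNN classifier factory, which is combined with a
# breaking-ties query strategy in a pool-based active learner.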
def main(num_iterations=10, device='cuda'):
    from small_text.integrations.pytorch.classifiers.base import AMPArguments

    pretrained_vectors = api.load('word2vec-google-news-300')

    train, test = get_train_test()
    # TODO: use another dataset
    train, test, tokenizer = preprocess_data(train, test, pretrained_vectors)
    num_classes = len(np.unique(train.y))

    # Active learning parameters
    # TODO: the selection of embedding vectors can still be improved
    classifier_kwargs = {
        'embedding_matrix': load_gensim_embedding(train.data, tokenizer, pretrained_vectors),
        'max_seq_len': 512,
        'num_epochs': 4,
        'device': device,
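        # AMPArguments enables automatic mixed precision during training;
        # with device_type fixed to 'cuda', this only takes effect on a CUDA device.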
        'amp_args': AMPArguments(use_amp=True, device_type='cuda')
    }
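
    # The factory lets the active learner construct a fresh KimCNN classifier
    # whenever it retrains; BreakingTies selects the samples with the smallest
    # margin between the two most likely classes.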
    clf_factory = KimCNNClassifierFactory(num_classes, classifier_kwargs)
    query_strategy = BreakingTies()

    # Active learner
    active_learner = PoolBasedActiveLearner(clf_factory, query_strategy, train)
    indices_labeled = initialize_active_learner(active_learner, train.y)

    try:
        perform_active_learning(active_learner, train, indices_labeled, test, num_iterations)
    except PoolExhaustedException:
        print('Error! Not enough samples left to handle the query.')
    except EmptyPoolException:
        print('Error! No more samples left. (Unlabeled pool is empty)')
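

# One iteration = query 20 samples, label them (here: look up the true labels
# to simulate an annotator), pass the labels back via update(), then evaluate
# the retrained classifier on the test set.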
def perform_active_learning(active_learner, train, indices_labeled, test, num_iterations):
    # Perform the configured number of active learning iterations...
    for i in range(num_iterations):
        # ...where each iteration consists of labelling 20 samples
        indices_queried = active_learner.query(num_samples=20, representation=train)

        # Simulate user interaction here. Replace this for real-world usage.
        y = train.y[indices_queried]

        # Return the labels for the current query to the active learner.
        active_learner.update(y)
        indices_labeled = np.concatenate([indices_queried, indices_labeled])

        print('Iteration #{:d} ({} samples)'.format(i, len(indices_labeled)))
        evaluate(active_learner, train[indices_labeled], test)
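

# Builds the embedding matrix for the KimCNN: zero vectors for special tokens,
# pretrained word2vec vectors for known tokens, and random vectors for
# out-of-vocabulary tokens that occur at least min_freq times.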
def load_gensim_embedding(texts, tokenizer, pretrained_vectors, min_freq=1, num_special_tokens=2):
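    # Reserve zero vectors for the special tokens (e.g. padding/unknown).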
    vectors = [
        np.zeros(pretrained_vectors.vectors.shape[1])
        for _ in range(num_special_tokens)
    ]
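
    # For every regular token, take its pretrained vector if available,
    # falling back to a zero vector for now.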
    vocab = tokenizer.get_vocab()
    vectors += [
        pretrained_vectors.get_vector(tokenizer.id_to_token(i))
        if pretrained_vectors.has_index_for(tokenizer.id_to_token(i))
        else np.zeros(pretrained_vectors.vectors.shape[1])
        for i in range(num_special_tokens, len(vocab))
    ]
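
    # Count token frequencies over the training set; tokens without a
    # pretrained vector that occur at least min_freq times get a small
    # random vector instead of zeros.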
    token_id_list = [text[0].cpu().numpy().tolist() for text in texts]
    word_frequencies = Counter([token for tokens in token_id_list for token in tokens])

    for i in range(num_special_tokens, len(vocab)):
        is_in_vocab = pretrained_vectors.has_index_for(tokenizer.id_to_token(i))
        # word_frequencies is keyed by token id, so the lookup uses the id i.
        if not is_in_vocab and word_frequencies[i] >= min_freq:
            vectors[i] = np.random.uniform(-0.25, 0.25, pretrained_vectors.vectors.shape[1])

    return torch.as_tensor(np.stack(vectors))
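

# Draws a stratified random sample of 20 instances to provide the initial
# labels for the first model.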
def initialize_active_learner(active_learner, y_train):
    indices_initial = random_initialization_stratified(y_train, 20)
    active_learner.initialize(indices_initial)

    return indices_initial


if __name__ == '__main__':
    import argparse
    import logging
    logging.getLogger('small_text').setLevel(logging.INFO)

    parser = argparse.ArgumentParser(description='An example that shows active learning '
                                                 'for multi-class text classification.')
    parser.add_argument('--num_iterations', type=int, default=10,
                        help='number of active learning iterations')

    args = parser.parse_args()
    main(num_iterations=args.num_iterations)