forked from Sherlock-Voice/Sherlock-Voice_Model
-
Notifications
You must be signed in to change notification settings - Fork 0
/
BERTSentenceTransform.py
70 lines (57 loc) · 2.24 KB
/
BERTSentenceTransform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
# Modified BERTSentenceTransform
class BERTSentenceTransform:
    r"""BERT style data transformation.

    Converts one sentence (or one sentence pair) into the three arrays a
    BERT-style model consumes: token ids, valid length and segment ids.

    Parameters
    ----------
    tokenizer : BERTTokenizer.
        Tokenizer for the sentences. It must provide ``tokenize(text)``
        and ``convert_tokens_to_ids(tokens)``.
    max_seq_length : int.
        Maximum sequence length of the sentences, special tokens included.
    vocab : vocabulary object.
        Must expose ``cls_token``, ``sep_token`` and ``padding_token``, and
        return the padding id for ``vocab[vocab.padding_token]``.
    pad : bool, default True
        Whether to pad the sentences to maximum length.
    pair : bool, default True
        Whether to transform sentences or sentence pairs.
    """

    def __init__(self, tokenizer, max_seq_length, vocab, pad=True, pair=True):
        self._tokenizer = tokenizer
        self._max_seq_length = max_seq_length
        self._pad = pad
        self._pair = pair
        self._vocab = vocab

    def __call__(self, line):
        """Transform one input record.

        Parameters
        ----------
        line : sequence of str
            ``line[0]`` is the first sentence. When ``pair`` is True the
            record must have exactly two items and ``line[1]`` is the
            second sentence.

        Returns
        -------
        tuple of np.ndarray
            ``(input_ids, valid_length, segment_ids)``, all ``int32``.
            ``valid_length`` is a zero-dimensional array holding the number
            of non-padding ids.
        """
        text_a = line[0]
        if self._pair:
            assert len(line) == 2
            text_b = line[1]

        # Copy into plain lists so the in-place truncation below never
        # mutates an object owned by the tokenizer.
        tokens_a = list(self._tokenizer.tokenize(text_a))
        tokens_b = None
        if self._pair:
            # Use the same tokenize() entry point as for the first sentence;
            # calling the tokenizer object itself may return something other
            # than a list of tokens.
            tokens_b = list(self._tokenizer.tokenize(text_b))

        if tokens_b:
            # Reserve three slots for [CLS], [SEP], [SEP].
            self._truncate_seq_pair(tokens_a, tokens_b,
                                    self._max_seq_length - 3)
        else:
            # Reserve two slots for [CLS] and [SEP].
            if len(tokens_a) > self._max_seq_length - 2:
                tokens_a = tokens_a[0:(self._max_seq_length - 2)]

        vocab = self._vocab
        tokens = [vocab.cls_token]
        tokens.extend(tokens_a)
        tokens.append(vocab.sep_token)
        # Segment 0 covers [CLS], the first sentence and its [SEP].
        segment_ids = [0] * len(tokens)

        if tokens_b:
            tokens.extend(tokens_b)
            tokens.append(vocab.sep_token)
            # Segment 1 covers the second sentence and the final [SEP].
            segment_ids.extend([1] * (len(tokens) - len(segment_ids)))

        input_ids = self._tokenizer.convert_tokens_to_ids(tokens)

        # The valid length of sentences. Only real tokens are attended to.
        valid_length = len(input_ids)

        if self._pad:
            # Pad up to the sequence length with the padding token id;
            # padded positions get segment id 0.
            padding_length = self._max_seq_length - valid_length
            input_ids.extend([vocab[vocab.padding_token]] * padding_length)
            segment_ids.extend([0] * padding_length)

        return np.array(input_ids, dtype='int32'), np.array(valid_length, dtype='int32'),\
            np.array(segment_ids, dtype='int32')

    def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
        """Truncate a token-list pair in place to at most ``max_length`` tokens.

        One token at a time is removed from the end of whichever list is
        currently longer (the second list on ties), so the shorter sentence
        keeps as much of its content as possible.
        """
        while len(tokens_a) + len(tokens_b) > max_length:
            longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
            if not longer:
                # Both lists are empty; nothing left to trim.
                break
            longer.pop()