forked from zhongbin1/bert_tokenization_for_java
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathWordpieceTokenizer.java
84 lines (69 loc) · 2.17 KB
/
WordpieceTokenizer.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package bert;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
/**
 * Greedy longest-match-first WordPiece tokenizer (port of the reference BERT
 * implementation). Each whitespace-delimited word is split into the longest
 * vocabulary sub-tokens; continuation pieces carry a "##" prefix.
 */
public class WordpieceTokenizer {

    /** Token-to-id vocabulary; only key membership is consulted here. */
    private final Map<String, Integer> vocab;

    /** Emitted for words that cannot be decomposed into known pieces. */
    private final String unkToken = "[UNK]";

    /** Words longer than this are mapped directly to unkToken. */
    private final int maxInputCharsPerWord = 200;

    /**
     * @param vocab WordPiece vocabulary (token -> id); must not be null
     */
    public WordpieceTokenizer(Map<String, Integer> vocab) {
        this.vocab = vocab;
    }

    /**
     * Tokenizes text into WordPiece sub-tokens.
     *
     * For example:
     *   input  = "unaffable"
     *   output = ["un", "##aff", "##able"]
     *
     * @param text whitespace-separated text (assumed already basic-tokenized);
     *             may be null or empty, yielding an empty list
     * @return list of sub-tokens; words with no vocab decomposition become unkToken
     */
    public List<String> tokenize(String text) {
        List<String> outputTokens = new ArrayList<String>();
        for (String token : whiteSpaceTokenize(text)) {
            int length = token.length();
            // Overlong words are never matched piecewise (mirrors the BERT reference).
            if (length > this.maxInputCharsPerWord) {
                outputTokens.add(this.unkToken);
                continue;
            }
            boolean isBad = false;
            int start = 0;
            List<String> subTokens = new ArrayList<String>();
            while (start < length) {
                // Greedy: try the longest remaining substring first and shrink
                // from the right until a vocabulary entry matches.
                int end = length;
                String curSubStr = null;
                while (start < end) {
                    String subStr = token.substring(start, end);
                    if (start > 0) {
                        subStr = "##" + subStr; // continuation-piece marker
                    }
                    if (this.vocab.containsKey(subStr)) {
                        curSubStr = subStr;
                        break;
                    }
                    end -= 1;
                }
                if (curSubStr == null) {
                    // No prefix of the remainder is in the vocab -> whole word is unknown.
                    isBad = true;
                    break;
                }
                subTokens.add(curSubStr);
                start = end;
            }
            if (isBad) {
                outputTokens.add(this.unkToken);
            } else {
                outputTokens.addAll(subTokens);
            }
        }
        return outputTokens;
    }

    /**
     * Splits text on runs of whitespace.
     *
     * Fixes two defects in the original: the null check ran AFTER
     * text.trim() (guaranteed NPE on null input, leaving the check dead),
     * and split(" ") produced empty tokens for consecutive spaces or tabs.
     *
     * @param text input text; may be null
     * @return mutable list of non-empty tokens; empty list for null/blank input
     */
    private List<String> whiteSpaceTokenize(String text) {
        List<String> result = new ArrayList<String>();
        if (text == null) {
            return result;
        }
        text = text.trim();
        if (text.isEmpty()) {
            return result;
        }
        // \s+ collapses runs of any whitespace (spaces, tabs, newlines),
        // matching the reference Python whitespace_tokenize (str.split()).
        result.addAll(Arrays.asList(text.split("\\s+")));
        return result;
    }
}