-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathinfer_tokenize.py
73 lines (62 loc) · 2.86 KB
/
infer_tokenize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Delimiters placed around each human turn by tokenize_Cllama2.
B_INST = "[INST]"
E_INST = "[/INST]"
# Markers that enclose the system prompt in tokenize_Cllama2.
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"

# Default conversation template:
#   "system" - fallback system prompt used when a sample has no instruction;
#   "roles"  - maps dataset speaker tags to the role labels tokenize_baichuan emits.
simple_vision_conv_multimodal = dict(
    system=(
        "You are a helpful language and vision assistant. "
        "You are able to understand the visual content that the user provides, "
        "and assist the user with a variety of tasks using natural language."
    ),
    roles=dict(human="USER", gpt="ASSISTANT"),
)
def tokenize_baichuan(item, tokenizer):
    """Tokenize a multi-turn conversation into the Baichuan prompt layout.

    The sequence is the system prompt followed, for every turn, by the ids of
    ``"<ROLE>:"`` and of the stripped turn text. Non-empty assistant replies get
    the literal string ``'</s>'`` appended before encoding.

    Args:
        item: Sample mapping with a ``"conversations"`` list of
            ``{"from": ..., "value": ...}`` turns and an optional
            ``"instruction"`` that overrides the default system prompt.
            A missing, empty or ``None`` instruction uses the default.
        tokenizer: Tokenizer exposing ``encode``, ``model_max_length``,
            ``bos_token_id`` and, optionally, ``add_bos_token``.

    Returns:
        list[int]: Token ids, cut to the first ``tokenizer.model_max_length``
        entries.
    """
    roles = simple_vision_conv_multimodal["roles"]
    max_len = tokenizer.model_max_length

    # `or` also covers a None instruction, which previously crashed in len().
    system = item.get("instruction") or simple_vision_conv_multimodal["system"]
    input_ids = list(tokenizer.encode(system, add_special_tokens=False))

    for turn in item["conversations"]:
        # Unknown speaker tags are labelled as the user.
        role = roles.get(turn["from"], "USER")
        content = turn["value"].strip()
        if role == "ASSISTANT" and content:
            # NOTE(review): presumably the tokenizer maps this text to its EOS
            # token; it can also be cut off by the per-turn truncation below.
            content += "</s>"
        input_ids += tokenizer.encode(role + ":", add_special_tokens=False)
        # Each turn's content is individually capped at model_max_length ids.
        input_ids += tokenizer.encode(content, add_special_tokens=False,
                                      truncation=True, max_length=max_len)

    # Not every tokenizer defines `add_bos_token`; treat its absence as False.
    if getattr(tokenizer, "add_bos_token", False):
        input_ids = [tokenizer.bos_token_id] + input_ids
    # Keep the head of the sequence (tokenize_Cllama2 keeps the tail instead).
    return input_ids[:max_len]
def tokenize_Cllama2(item, tokenizer):
    """Tokenize a conversation into the ``[INST]`` / ``<<SYS>>`` prompt layout.

    The system prompt, wrapped in ``B_SYS``/``E_SYS``, is prefixed to the text
    of the first turn. Human turns are encoded as ``"[INST] <text> [/INST] "``
    with the tokenizer's default special tokens. Every other non-empty turn is
    encoded as ``"<text> "`` without special tokens and followed by
    ``tokenizer.eos_token_id``. Empty non-human turns contribute nothing.

    Args:
        item: Sample mapping with a ``"conversations"`` list of
            ``{"from": ..., "value": ...}`` turns and an optional
            ``"instruction"`` that overrides the default system prompt.
            A missing, empty or ``None`` instruction uses the default.
            ``item`` is not modified.
        tokenizer: Tokenizer exposing ``encode``, ``eos_token_id`` and
            ``model_max_length``.

    Returns:
        list[int]: Token ids, keeping only the last
        ``tokenizer.model_max_length`` entries. An empty conversation
        yields ``[]``.
    """
    system = item.get("instruction") or simple_vision_conv_multimodal["system"]
    system = B_SYS + system + E_SYS

    input_ids = []
    for idx, turn in enumerate(item["conversations"]):
        text = turn["value"]
        if idx == 0:
            # Build the first turn's text locally instead of writing it back
            # into `item`, so re-tokenizing the same sample is idempotent.
            text = system + text
        content = text.strip()
        if turn["from"] == "human":
            # Default special tokens: presumably a BOS before every [INST].
            input_ids += tokenizer.encode(f"{B_INST} {content} {E_INST} ")
        elif content:
            # Non-human turns are treated as replies: no BOS, explicit EOS.
            input_ids += tokenizer.encode(f"{content} ", add_special_tokens=False)
            input_ids.append(tokenizer.eos_token_id)
    # Keep the tail; NOTE(review): this can drop the leading BOS/system prompt.
    return input_ids[-tokenizer.model_max_length:]
def tokenize(item, tokenizer, llm_type):
    """Route ``item`` to the tokenization routine for ``llm_type``.

    Args:
        item: Conversation sample, forwarded unchanged to the chosen routine.
        tokenizer: Tokenizer, forwarded unchanged to the chosen routine.
        llm_type: ``"Chinese_llama2"`` or ``"baichuan"``.

    Returns:
        list[int]: The token ids produced by the chosen routine.

    Raises:
        ValueError: If ``llm_type`` is not one of the supported names.
    """
    if llm_type == "Chinese_llama2":
        return tokenize_Cllama2(item, tokenizer)
    if llm_type == "baichuan":
        return tokenize_baichuan(item, tokenizer)
    # Unsupported model family: report the valid choices.
    message = f"Invalid llm type {llm_type}, please choose in ['Chinese_llama2', 'baichuan']"
    raise ValueError(message)