-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtokenizer_template_switch.py
95 lines (79 loc) · 3.54 KB
/
tokenizer_template_switch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import re
from transformers import AutoTokenizer
def extract_separators(template):
    """
    Extracts separators used in the tokenization template.

    A separator is the expression that a Jinja ``{{ ... }}`` output block
    places directly in front of ``+ message['content']``.

    Args:
        template: Jinja chat template source. ``None`` or an empty string
            yields an empty list.

    Returns:
        A list of separator strings. When one of the separators embeds
        ``message['role']``, the first such separator is expanded into three
        strings (system, user, assistant, in that order) with all single
        quotes removed. Otherwise the raw matches are returned unchanged,
        still carrying the quoting they had in the template.
    """
    if not template:
        return []

    # Capture whatever sits between '{{' and "+ message['content']"
    pattern = r"\{\{\s*([^{}]+?)\s*\+ message\['content'\]"
    separators = [match.strip() for match in re.findall(pattern, template)]

    # Expand the separator that actually carries the role placeholder; it is
    # not necessarily the first match in the template.
    role_separator = next(
        (candidate for candidate in separators if "message['role']" in candidate),
        None,
    )
    if role_separator is None:
        return separators

    roles = ["system", "user", "assistant"]
    return [
        role_separator.replace(" + message['role'] + ", role).replace("'", "")
        for role in roles
    ]
def detect_eos_token(jinja_template, tokenizer):
    """
    Picks the end-of-message token that the chat template appears to use.

    Literal markers found in the template text win, checked in the order
    "<|im_end|>" then "</s>". If neither is present but the template
    mentions ``eos_token``, the tokenizer's own ``eos_token`` is returned.
    When nothing matches, "<|endoftext|>" is used as the fallback.
    """
    # Ordered on purpose: the first literal marker present in the template wins
    for literal_marker in ("<|im_end|>", "</s>"):
        if literal_marker in jinja_template:
            return literal_marker

    return tokenizer.eos_token if "eos_token" in jinja_template else "<|endoftext|>"
def recover_messages(formatted_message, separators, eos_token):
    """
    Recovers the original messages from the formatted message string.

    Args:
        formatted_message: Chat text produced by rendering a chat template.
        separators: Role prefixes to remove from each message, in the form
            returned by ``extract_separators``; surrounding single quotes
            are ignored.
        eos_token: Token that terminates each message. When it is empty or
            ``None`` the text cannot be split, so the whole string is
            treated as a single message.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dictionaries. Roles are
        assigned purely by position: the first message is "system" and the
        following ones alternate between "user" and "assistant".
    """
    # str.split(None) would silently split on whitespace and str.split("")
    # raises ValueError, so only split when there is a real token to split on.
    if eos_token:
        segments = formatted_message.split(eos_token)
    else:
        segments = [formatted_message]

    # Drop the empty tail left behind by a trailing end-of-message token
    if segments and segments[-1].strip() == "":
        segments.pop()

    # The quote-stripped separators are the same for every message
    markers = [separator.strip("'") for separator in separators]
    follow_up_roles = ("user", "assistant")

    recovered_messages = []
    for index, segment in enumerate(segments):
        role = "system" if index == 0 else follow_up_roles[(index - 1) % 2]

        # Remove the first occurrence of each separator, trimming the
        # whitespace left behind after every removal
        content = segment.strip()
        for marker in markers:
            content = content.replace(marker, "", 1).strip()

        recovered_messages.append({"role": role, "content": content})
    return recovered_messages
def recover_chat_messages(tokenized_chat, tokenizer):
    """
    Rebuilds the list of message dictionaries from a rendered chat string,
    using the chat template attached to the given tokenizer.
    """
    chat_template = tokenizer.chat_template
    # The template supplies both the role prefixes and the end-of-message token
    role_separators = extract_separators(chat_template)
    end_marker = detect_eos_token(chat_template, tokenizer)
    return recover_messages(tokenized_chat, role_separators, end_marker)
# Example usage: render a short chat with the model's template, then recover it
if __name__ == "__main__":
    checkpoint = "Qwen/Qwen1.5-0.5B"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    system_turn = {
        "role": "system",
        "content": "You are a friendly chatbot who always responds in the style of a pirate",
    }
    user_turn = {
        "role": "user",
        "content": "How many helicopters can a human eat in one sitting?",
    }
    messages = [system_turn, user_turn]

    # tokenize=False returns the rendered text instead of token ids
    tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=False)
    print(tokenized_chat)

    # Round-trip the rendered text back into role/content dictionaries
    print(recover_chat_messages(tokenized_chat, tokenizer))