Add metric for computing accuracy of first tokens in values of parameters #287

Merged · 2 commits · Nov 11, 2024
2 changes: 1 addition & 1 deletion functionary/train/custom_datasets.py
@@ -500,7 +500,7 @@ def map_raw_data_to_input_dic(
invalid_count += 1

t2 = datetime.datetime.now()
avg_time = (t2 - t1).total_seconds() / len(data_points)
avg_time = (t2 - t1).total_seconds() / (len(data_points) + invalid_count)
remaining_time = avg_time * (data_size - len(data_points))
print(
f"{len(data_points)}/{data_size}, avg_time per 1000 data points: {avg_time * 1000}, remaining time: {remaining_time}"
275 changes: 192 additions & 83 deletions functionary/train/metrics.py
@@ -1,8 +1,10 @@
from json_source_map import calculate
from typing import Any, List, Dict
from typing import Any, List, Dict, Tuple
import json
from transformers import AutoTokenizer


def find_first_token_value(index: int, token_indices: List) -> int:
def find_first_token_value(index: int, token_indices: List[Tuple[int, int]]) -> int:
"""This function return the index of token that contains the index
For example: token_indices=[(1, 4), ()]
Args:
@@ -12,81 +14,130 @@ def find_first_token_value(index: int, token_indices: List) -> int:
Returns:
int: index of the token containing `index`, or None if no token span contains it
"""
for start, end, token_index, _ in token_indices:
for i, (start, end) in enumerate(token_indices):
if start <= index and index < end:
return token_index
return i
return None
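A small usage sketch of the updated helper, with hypothetical token spans (starts inclusive, ends exclusive):

```python
# token_indices maps each token to its (start, end) character span.
token_indices = [(0, 2), (2, 10), (10, 12)]
assert find_first_token_value(3, token_indices) == 1      # char 3 lies in (2, 10)
assert find_first_token_value(10, token_indices) == 2     # span starts are inclusive
assert find_first_token_value(99, token_indices) is None  # outside every span
```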


def find_index_of_token_contain_breakline(token_ids, tokenizer):
"""Find index of token that contains breakline, token is not always: '\n' sometimes: '__\n'

def extract_indices_of_first_token_in_argument_values(
argument_token_indices: List[int], argument_text: str, verbose: bool = False
) -> List[int]:
"""this function return indices of first tokens in argument values
for example, argument_text: {"a": 12, "b": {"c": "Hanoi"}}; argument_token_indices = [(0, 1), ... (10, 12)]
--> return the indices of first token of: 12; indices of first token of Hanoi
Args:
token_ids (_type_): _description_
tokenizer (_type_): _description_
argument_token_indices (List[Tuple[int, int]]): list of (start, end) character spans of tokens within argument_text
argument_text (str): the arguments as text, a JSON-encoded dictionary

Returns:
_type_: _description_
List[int]: indices of the first token of each value in argument_text
"""
for i in range(len(token_ids)):
tok = tokenizer.decode([token_ids[i]])
if "\n" in tok:
return i
return None


def extract_indices_of_first_tokens_of_param_values(
arguments_token_ids: List[int], tokenizer: Any, verbose: bool = False
) -> List[int]:
argument_text = tokenizer.decode(arguments_token_ids)
token_strings = [tokenizer.decode(token_id) for token_id in arguments_token_ids]
token_indices = []
pos = 0

# print(f"argument_text: {argument_text}")

for token_index, token_str in enumerate(token_strings):
start = argument_text.find(token_str, pos)
if start == -1:
if verbose:
print("cannot find start")
continue
end = start + len(token_str)
token_indices.append((start, end, token_index, token_str))
pos = end

if verbose:
print("token_indices: ", token_indices)
# locate the key in the dictionary
try:
# this can raise an exception if argument_text is not valid JSON because of truncation
# Calculate the positions of the values in the argument_text
field_dic = calculate(argument_text)
except Exception as e:
if verbose:
print(f"exception using calculate to find key from: {argument_text}")
return []

result = []
for field in field_dic:
if len(field) > 0:
if verbose:
print("find param: ", field)
entry = field_dic[field]
start, end = entry.value_start.position, entry.value_end.position
if argument_text[start] == '"':
if argument_text[start] == '"': # if parameter is string
start += 1
token_index = find_first_token_value(start, argument_token_indices)
if verbose:
print(f"find first token of param: {start}")
token_index = find_first_token_value(start, token_indices)
if token_index:
print(
f"key={field}; at: {start}, {end}; --> token_index: {token_index}"
)
if token_index is not None:
result.append(token_index)
return result
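For readers unfamiliar with `json_source_map`, a hedged sketch of how `calculate` is used above (API as documented for the `json-source-map` package; keys are JSON-pointer paths, and the empty-string key is the root object, which the `len(field) > 0` check skips):

```python
from json_source_map import calculate

argument_text = '{"a": 12, "b": {"c": "Hanoi"}}'
fields = calculate(argument_text)  # maps JSON-pointer paths to position entries
for path, entry in fields.items():
    if path:  # skip '' (the root object), as the function above does
        start = entry.value_start.position  # character offset where the value begins
        end = entry.value_end.position
        print(path, repr(argument_text[start:end]))
# Expected: /a '12', /b '{"c": "Hanoi"}', /b/c '"Hanoi"' -- string values include
# their quotes, hence the start += 1 adjustment in the function above
```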


def get_indices_of_tokens_in_string(
tokenizer: Any, token_ids: List[int], verbose: bool = False
):
text = tokenizer.decode(token_ids)
tokens = [tokenizer.decode(token_id) for token_id in token_ids]
pos = 0
token_indices = []

for token_index, token in enumerate(tokens):
if text[pos:].startswith(token):
start = pos
end = start + len(token)
pos = end
token_indices.append((start, end))
else:
if len(token) > 1 and token[0] == " " and text[pos:].startswith(token[1:]):
start = pos
end = start + len(token) - 1
token_indices.append((start, end))
pos = end
else:
raise Exception(
f"cannot match token_index: {token_index}, token='{token}'"
)

return token_indices, text
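A usage sketch of this span-recovery helper; any Hugging Face tokenizer should work, and the `else` branch covers tokenizers whose single-token decode carries a leading space that is absent at that position in the full decoded string:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical choice for illustration
token_ids = tokenizer.encode('{"location": "Hanoi"}', add_special_tokens=False)
token_indices, text = get_indices_of_tokens_in_string(tokenizer, token_ids)
# Each (start, end) pair slices one token's surface form out of `text`.
print([text[s:e] for s, e in token_indices])
```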


def locate_start_end_indices_of_token(char_start, char_end, token_indices):
token_index_start, token_index_end = -1, -1
for index, (start, end) in enumerate(token_indices):
if char_start >= start and char_start < end:
token_index_start = index
if char_end <= end and char_end > start:
token_index_end = index
return token_index_start, token_index_end
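A quick sketch of the character-range to token-range mapping this helper performs, with hypothetical spans:

```python
token_indices = [(0, 1), (1, 9), (9, 13)]  # hypothetical (start, end) token spans
start_tok, end_tok = locate_start_end_indices_of_token(2, 13, token_indices)
assert (start_tok, end_tok) == (1, 2)  # chars 2..13 overlap tokens 1 and 2
```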


def extract_indices_of_json_objects(text: str) -> List[Tuple[int, int]]:
"""
Extract all indices of JSON objects from a given text.

Parameters:
text (str): The input text containing JSON objects.

Returns:
list: A list of indices ([(start, end), ...]) of extracted JSON objects in text.
"""
json_indices = []
stack = []
start_idx = None

for i, char in enumerate(text):
if char == "{":
if not stack:
start_idx = i # Potential start of JSON object
stack.append(char)
elif char == "}":
if stack:
stack.pop()
if not stack and start_idx is not None:
json_str = text[start_idx : i + 1]
try:
# print("load json: ", json_str)
parsed_json = json.loads(json_str)
json_indices.append((start_idx, i + 1))
except json.JSONDecodeError:
# Invalid JSON, ignore and continue
pass
start_idx = None
return json_indices
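A usage sketch: brace matching proposes candidate spans and `json.loads` rejects the ones that are not valid JSON:

```python
text = 'get_weather\n{"x": 1} more text {"y": {"z": 2}} trailing {broken'
spans = extract_indices_of_json_objects(text)
print([text[s:e] for s, e in spans])
# -> ['{"x": 1}', '{"y": {"z": 2}}']  (the unbalanced '{broken' never closes)
```

One design caveat: the brace counter does not skip braces inside JSON string values, so an object such as `{"a": "}"}` would be mis-spanned and then dropped by the `json.loads` check rather than recovered.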


def extract_indices_of_first_tokens_of_param_values_in_assistant_response(
tokenizer: Any, token_ids: List[int], verbose: bool = False
) -> List[int]:
"""Extract the first tokens of values of parameters in tool call
For example, token_ids of assistant response=get_current_weather\n{"location": "Hanoi"}
this function will extract the indices of tokens associated with: Hanoi & 3
For example, token_ids of assistant response= [27, 1723, 29380, 70464, 89963, 2588, 794, 330, 39, 73803, 498, 330, 817, 669, 301, 5979, 355, 794, 837, 5474, 1723, 29, 128008]
this is for assistant response text= '<function=get_weather>{"location": "Hanoi", "use_celcius": true}</function><|eom_id|>'
this function will extract the indices of the first tokens associated with: Hanoi & true, which are the tokens with ids 39 & 837
Args:
tokenizer (Any): _description_
token_ids (List[int]): token_ids of the assistant
Expand All @@ -95,44 +146,78 @@ def extract_indices_of_first_tokens_of_param_values_in_assistant_response(
Returns:
List[int]: indices (positions in token_ids) of the first token of each parameter value
"""
function_sep = ">>>"
function_sep_id = tokenizer.encode(function_sep, add_special_tokens=False)[0]
break_line = "\n"
brk_line_token_id = tokenizer.encode(break_line, add_special_tokens=False)[0]
# print(f"function_sep_id: {function_sep_id}; brk_line_token_id:{brk_line_token_id}")
sep_indices = [-1]
# print([tokenizer.decode([tok]) for tok in token_ids])
for i in range(len(token_ids)):
if token_ids[i] == function_sep_id:
sep_indices.append(i - 1)

# first we compute the indices of tokens and the response_text from token_indices
# For example token_ids=[27, 1723, 29380, 70464, 89963, 2588, 794, 330, 39, 73803, 498, 330, 817, 669, 301, 5979, 355, 794, 837, 5474, 1723, 29, 128008]
# this is tokens = ['<', 'function', '=get', '_weather', '>{"', 'location', '":', ' "', 'H', 'anoi', '",', ' "', 'use', '_c', 'el', 'ci', 'us', '":', ' true', '}</', 'function', '>', '<|eom_id|>']
# --> response_text=<function=get_weather>{"location": "Hanoi", "use_celcius": true}</function><|eom_id|>
# token_indices=[(0, 1), (1, 9), (9, 13), (13, 21), (21, 24), (24, 32), (32, 34), (34, 36), (36, 37), (37, 41), (41, 43), (43, 45), (45, 48), (48, 50), (50, 52), (52, 54), (54, 56), (56, 58), (58, 63), (63, 66), (66, 74), (74, 75), (75, 85)]
# token_indices is the list of (start, end) spans of each token in response_text
token_indices, response_text = get_indices_of_tokens_in_string(
tokenizer, token_ids, verbose
)
if verbose:
print("sep_indices: ", sep_indices)
print(f"response_text:", response_text)
tokens = [response_text[s:e] for s, e in token_indices]
print(f"tokens: ", tokens)
print("token_indices: ", token_indices)
print("---------------")

# Extract the spans of JSON objects in response_text: a list [(start, end), ...] where response_text[start:end] is a JSON object
json_indices = extract_indices_of_json_objects(response_text)
result = []
for i, sep_index in enumerate(sep_indices):
brk_index = find_index_of_token_contain_breakline(
token_ids[sep_index + 1 :], tokenizer
for start, end in json_indices:
# first find the token_start_ind, token_end_ind associated with (start, end): a mapping from character indices to token indices
token_start_ind, token_end_ind = locate_start_end_indices_of_token(
start, end, token_indices
)
if brk_index >= 0:
brk_index += sep_index + 1
func_name = tokenizer.decode(token_ids[sep_index + 1 : brk_index])
# print(f"func_name:{token_ids[sep_index + 1: brk_index]};{func_name};sep_index={sep_index}, brk_index:{brk_index}")
if func_name != "all":
end_index = len(token_ids) - 2 # exclude eos_token_id for the last call
if i != len(sep_indices) - 1:
end_index = sep_indices[i + 1]
start_argument_index = brk_index + 1
# print(
# f"sep_index={sep_index}; start_argument_index={start_argument_index}; end_index={end_index + 1}"
# )
# = brk_index + 1, end_index
# token_ids[brk_index + 1: ] --> {"car_name": "Tang"}
token_indices = extract_indices_of_first_tokens_of_param_values(
token_ids[start_argument_index : end_index + 1],
tokenizer,
verbose=verbose,
if verbose:
print("------------------------------")
print(
f"extract json: start={start}; end={end}; content: {response_text[start: end]}"
)
print(
f"convert to token_indices: token_start_ind={token_start_ind}({token_indices[token_start_ind]}); token_end_ind={token_end_ind}({token_indices[token_end_ind]})"
)

argument_text = response_text[start:end]
# This is the token_indices inside argument_text
# for example: argument_text={"location": "Hanoi", "use_celcius": true}
# argument_token_indices = [(0, 2), (2, 10), (10, 12), (12, 14), (14, 15), (15, 19), (19, 21), (21, 23), (23, 26), (26, 28), (28, 30), (30, 32), (32, 34), (34, 36), (36, 41)]
argument_token_indices = []
# the first token can start before the JSON does: e.g. in >{"a": 10}, '>{"' is one token while the JSON starts at '{', so we clamp that token's relative start to 0 (treating it as '{"')

for p in token_indices[token_start_ind : token_end_ind + 1]:
# compute the token spans relative to argument_text
# if p[0] < start, token p is like '>{"' while start points at '{' in the middle of the token, so we trim the token to '{"'
argument_token_indices.append(
(p[0] - start if p[0] >= start else 0, p[1] - start)
)
# if the last token extends past `end`, trim it; for example, last token '}</' is trimmed to '}'
argument_token_indices[-1] = (argument_token_indices[-1][0], end - start)

first_token_of_values_indices = (
extract_indices_of_first_token_in_argument_values(
argument_token_indices, argument_text
)
)
if verbose:
print(
f"argument_token_indices: {argument_token_indices}; argument_text: {argument_text}"
)
print(
f"argument_tokens: ",
[argument_text[s:e] for s, e in argument_token_indices],
)
print(f"first_token_of_values_indices={first_token_of_values_indices}")

for index in first_token_of_values_indices:
result.append(index + token_start_ind)
if verbose:
start, end = token_indices[index + token_start_ind]
content = response_text[start:end]
print(
f"the detected token at index: {index + token_start_ind}, token_id={token_ids[index + token_start_ind]}; content={content}"
)
result.extend([start_argument_index + ind for ind in token_indices])
return result


@@ -163,3 +248,27 @@ def extract_unmasked_chunks(labels: List[int], preds: List[int]):
if len(current_label_chunk) > 0:
result.append((current_label_chunk, current_pred_chunk))
return result
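Only the tail of `extract_unmasked_chunks` appears in this hunk; assuming the usual Hugging Face convention that label `-100` marks masked positions, the pairing behaves like this sketch:

```python
labels = [-100, -100, 11, 12, 13, -100, 21, 22]
preds  = [   7,    8, 11, 99, 13,    9, 21, 22]
print(extract_unmasked_chunks(labels, preds))
# Expected: [([11, 12, 13], [11, 99, 13]), ([21, 22], [21, 22])]
```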


def test1():
text = """<function=get_weather>{"location": "Hanoi", "use_celcius": true}</function><|eom_id|>"""
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
token_ids = tokenizer.encode(text, add_special_tokens=False)
print("token_ids: ", token_ids)
extract_indices_of_first_tokens_of_param_values_in_assistant_response(
tokenizer, token_ids, verbose=True
)


def test2():
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
token_ids = [505, 364, 3007, 1025]
extract_indices_of_first_tokens_of_param_values_in_assistant_response(
tokenizer, token_ids, verbose=True
)

# extract_indices_of_first_tokens_of_param_values_in_assistant_response(tokenizer, token_ids, verbose=True)


if __name__ == "__main__":
test2()
21 changes: 21 additions & 0 deletions functionary/train/training_utils.py
@@ -7,6 +7,7 @@
from torch.utils.data import DataLoader
import os
from typing import List
from functionary.train import metrics as train_metrics

LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))

@@ -132,6 +133,23 @@ def compute_metrics(eval_preds, id2token, tokenizer):
if label == pred:
dic[label]["acc"] += 1

# Calculate the accuracy on the first token of each parameter value
unmasked_labels_preds = train_metrics.extract_unmasked_chunks(
label_list, prediction_list
)
first_token_param_value_total, first_token_param_value_acc = 0, 0
for unmasked_labels, pred_result in unmasked_labels_preds:
try:
indices = train_metrics.extract_indices_of_first_tokens_of_param_values_in_assistant_response(
tokenizer, unmasked_labels
)
for index in indices:
first_token_param_value_total += 1
if unmasked_labels[index] == pred_result[index]:
first_token_param_value_acc += 1
except Exception as e:
print_rank0(f"encountered exception: {str(e)}\nFor unmasked_labels: {unmasked_labels}")

# Calculate perplexity
loss = eval_preds.predictions[1].tolist()
loss = sum(loss) / len(loss)
@@ -142,6 +160,9 @@ def compute_metrics(eval_preds, id2token, tokenizer):
"perplexity": perplexity,
"accuracy_first_token": first_token_correct_count / first_token_total_count,
"total_number_first_token": first_token_total_count,
"first_token_param_values": first_token_param_value_acc
/ first_token_param_value_total,
"first_token_param_values_total": first_token_param_value_total,
}

for token_id, stat in sorted(
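Putting the pieces together, the new metric asks: of all first tokens of parameter values in the gold (unmasked) assistant responses, how often did the model predict exactly that token? A condensed, hedged restatement of the loop added above (the diff divides without a zero guard; one is added here for the sketch):

```python
from functionary.train import metrics as train_metrics

def first_token_param_value_accuracy(label_list, prediction_list, tokenizer):
    total, correct = 0, 0
    chunks = train_metrics.extract_unmasked_chunks(label_list, prediction_list)
    for unmasked_labels, pred_result in chunks:
        indices = train_metrics.extract_indices_of_first_tokens_of_param_values_in_assistant_response(
            tokenizer, unmasked_labels
        )
        for i in indices:
            total += 1
            correct += int(unmasked_labels[i] == pred_result[i])
    return correct / total if total else 0.0
```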