From e8011a27e77948f27571a4a250aba6f3b4e08215 Mon Sep 17 00:00:00 2001
From: Khai Mai
Date: Fri, 8 Nov 2024 10:33:55 +0700
Subject: [PATCH 1/2] implement computing accuracy of the first token in
 parameter values

---
 functionary/train/metrics.py        | 257 +++++++++++++++++++---------
 functionary/train/training_utils.py |  18 ++
 tests/test_metrics.py               |   6 +-
 3 files changed, 196 insertions(+), 85 deletions(-)

diff --git a/functionary/train/metrics.py b/functionary/train/metrics.py
index c706f0c..356ca32 100644
--- a/functionary/train/metrics.py
+++ b/functionary/train/metrics.py
@@ -1,8 +1,10 @@
 from json_source_map import calculate
-from typing import Any, List, Dict
+from typing import Any, List, Dict, Tuple
+import json
+from transformers import AutoTokenizer
 
 
-def find_first_token_value(index: int, token_indices: List) -> int:
+def find_first_token_value(index: int, token_indices: List[Tuple[int, int]]) -> int:
     """This function return the index of token that contains the index
     For example: token_indices=[(1, 4), ()]
     Args:
@@ -12,81 +14,122 @@ def find_first_token_value(index: int, token_indices: List) -> int:
     Returns:
         int: _description_
     """
-    for start, end, token_index, _ in token_indices:
+    for i, (start, end) in enumerate(token_indices):
         if start <= index and index < end:
-            return token_index
+            return i
     return None
 
 
-def find_index_of_token_contain_breakline(token_ids, tokenizer):
-    """Find index of token that contains breakline, token is not always: '\n' sometimes: '__\n'
-
+def extract_indices_of_first_token_in_argument_values(
+    argument_token_indices: List[Tuple[int, int]], argument_text: str, verbose: bool = False
+) -> List[int]:
+    """This function returns the indices of the first tokens in argument values.
+    For example, argument_text: {"a": 12, "b": {"c": "Hanoi"}};
+    argument_token_indices = [(0, 1), ... (10, 12)]
+    --> returns the indices of the first token of 12 and of the first token of Hanoi
     Args:
-        token_ids (_type_): _description_
-        tokenizer (_type_): _description_
+        argument_token_indices (List[Tuple[int, int]]): the (start, end) character spans of the tokens within argument_text
+        argument_text (str): the text of the arguments, a JSON dictionary
 
     Returns:
-        _type_: _description_
+        List[int]: the indices of the first tokens of the values in argument_text
     """
-    for i in range(len(token_ids)):
-        tok = tokenizer.decode([token_ids[i]])
-        if "\n" in tok:
-            return i
-    return None
-
-
-def extract_indices_of_first_tokens_of_param_values(
-    arguments_token_ids: List[int], tokenizer: Any, verbose: bool = False
-) -> List[int]:
-    argument_text = tokenizer.decode(arguments_token_ids)
-    token_strings = [tokenizer.decode(token_id) for token_id in arguments_token_ids]
-    token_indices = []
-    pos = 0
-
-    # print(f"argument_text: {argument_text}")
-
-    for token_index, token_str in enumerate(token_strings):
-        start = argument_text.find(token_str, pos)
-        if start == -1:
-            if verbose:
-                print("cannot find start")
-            continue
-        end = start + len(token_str)
-        token_indices.append((start, end, token_index, token_str))
-        pos = end
-
-    if verbose:
-        print("token_indices: ", token_indices)
-    # locate the key in the dictionary
     try:
-        # this can run into error if argument_text is not a valid json because of being truncated
+        # Calculate the positions of the values in the argument_text
        field_dic = calculate(argument_text)
     except Exception as e:
+        if verbose:
+            print(f"exception using calculate to find key from: {argument_text}")
         return []
 
     result = []
     for field in field_dic:
         if len(field) > 0:
-            if verbose:
-                print("find param: ", field)
             entry = field_dic[field]
             start, end = entry.value_start.position, entry.value_end.position
-            if argument_text[start] == '"':
+            if argument_text[start] == '"':  # a string value: skip the opening quote
                 start += 1
+            token_index = find_first_token_value(start, argument_token_indices)
             if verbose:
-                print(f"find first token of param: {start}")
-            token_index = find_first_token_value(start, token_indices)
-            if token_index:
+                print(
+                    f"key={field}; at: {start}, {end}; --> token_index: {token_index}"
+                )
+            if token_index is not None:
                 result.append(token_index)
     return result
 
 
+def get_indices_of_tokens_in_string(
+    tokenizer: Any, token_ids: List[int], verbose: bool = False
+):
+    text = tokenizer.decode(token_ids)
+    tokens = [tokenizer.decode(token_id) for token_id in token_ids]
+    pos = 0
+    token_indices = []
+
+    for token_index, token in enumerate(tokens):
+        start = text.find(token, pos)
+        if start == -1:
+            if verbose:
+                print("cannot find start")
+            raise Exception(f"cannot find token: '{token}' in {text[pos: ]}")
+        end = start + len(token)
+        token_indices.append((start, end))
+        pos = end
+    return token_indices, text
+
+
+def locate_start_end_indices_of_token(char_start, char_end, token_indices):
+    token_index_start, token_index_end = -1, -1
+    for index, (start, end) in enumerate(token_indices):
+        if char_start >= start and char_start < end:
+            token_index_start = index
+        if char_end <= end and char_end > start:
+            token_index_end = index
+    return token_index_start, token_index_end
+
+
+def extract_indices_of_json_objects(text: str) -> List[Tuple[int, int]]:
+    """
+    Extract all indices of JSON objects from a given text.
+
+    Parameters:
+    text (str): The input text containing JSON objects.
+
+    Returns:
+    list: A list of indices ([(start, end), ...]) of extracted JSON objects in text.
+    """
+    json_indices = []
+    stack = []
+    start_idx = None
+
+    for i, char in enumerate(text):
+        if char == "{":
+            if not stack:
+                start_idx = i  # Potential start of JSON object
+            stack.append(char)
+        elif char == "}":
+            if stack:
+                stack.pop()
+                if not stack and start_idx is not None:
+                    json_str = text[start_idx : i + 1]
+                    try:
+                        print("load json: ", json_str)
+                        parsed_json = json.loads(json_str)
+                        json_indices.append((start_idx, i + 1))
+                    except json.JSONDecodeError:
+                        # Invalid JSON, ignore and continue
+                        pass
+                    start_idx = None
+    return json_indices
+
+
 def extract_indices_of_first_tokens_of_param_values_in_assistant_response(
     tokenizer: Any, token_ids: List[int], verbose: bool = False
 ) -> List[int]:
     """Extract the first tokens of values of parameters in tool call
-    For example, token_ids of assistant response=get_current_weather\n{"location": "Hanoi"}
-    this function will extract the indices of tokens associated with: Hanoi & 3
+    For example, token_ids of assistant response= [27, 1723, 29380, 70464, 89963, 2588, 794, 330, 39, 73803, 498, 330, 817, 669, 301, 5979, 355, 794, 837, 5474, 1723, 29, 128008]
+    this is for assistant response text= '{"location": "Hanoi", "use_celcius": true}<|eom_id|>'
+    this function will extract the indices of the first tokens associated with: Hanoi & true, which are tokens 39 & 837
     Args:
         tokenizer (Any): _description_
         token_ids (List[int]): token_ids of the assistant
@@ -95,44 +138,78 @@ def extract_indices_of_first_tokens_of_param_values_in_assistant_response(
     Returns:
         _type_: _description_
     """
-    function_sep = ">>>"
-    function_sep_id = tokenizer.encode(function_sep, add_special_tokens=False)[0]
-    break_line = "\n"
-    brk_line_token_id = tokenizer.encode(break_line, add_special_tokens=False)[0]
-    # print(f"function_sep_id: {function_sep_id}; brk_line_token_id:{brk_line_token_id}")
-    sep_indices = [-1]
-    # print([tokenizer.decode([tok]) for tok in token_ids])
-    for i in range(len(token_ids)):
-        if token_ids[i] == function_sep_id:
-            sep_indices.append(i - 1)
-
+    # first we compute the character indices of the tokens and the response_text from token_ids
+    # For example token_ids=[27, 1723, 29380, 70464, 89963, 2588, 794, 330, 39, 73803, 498, 330, 817, 669, 301, 5979, 355, 794, 837, 5474, 1723, 29, 128008]
+    # this is tokens = ['<', 'function', '=get', '_weather', '>{"', 'location', '":', ' "', 'H', 'anoi', '",', ' "', 'use', '_c', 'el', 'ci', 'us', '":', ' true', '}', '<|eom_id|>']
+    # --> response_text={"location": "Hanoi", "use_celcius": true}<|eom_id|>
+    # token_indices=[(0, 1), (1, 9), (9, 13), (13, 21), (21, 24), (24, 32), (32, 34), (34, 36), (36, 37), (37, 41), (41, 43), (43, 45), (45, 48), (48, 50), (50, 52), (52, 54), (54, 56), (56, 58), (58, 63), (63, 66), (66, 74), (74, 75), (75, 85)]
+    # token_indices is the list of (start, end) indices of the tokens in response_text
+    token_indices, response_text = get_indices_of_tokens_in_string(
+        tokenizer, token_ids, verbose
+    )
     if verbose:
-        print("sep_indices: ", sep_indices)
+        print(f"response_text:", response_text)
+        tokens = [response_text[s:e] for s, e in token_indices]
+        print(f"tokens: ", tokens)
+        print("token_indices: ", token_indices)
+        print("---------------")
+
+    # Extract the indices of the JSON objects in response_text; a list [(start, end), ...]
+    # where response_text[start:end] is a JSON object
+    json_indices = extract_indices_of_json_objects(response_text)
 
     result = []
-    for i, sep_index in enumerate(sep_indices):
-        brk_index = find_index_of_token_contain_breakline(
-            token_ids[sep_index + 1 :], tokenizer
+    for start, end in json_indices:
+        # first find the token_start_ind, token_end_ind associated with start, end; this maps character indices --> token indices
+        token_start_ind, token_end_ind = locate_start_end_indices_of_token(
+            start, end, token_indices
+        )
+        if verbose:
+            print("------------------------------")
+            print(
+                f"extract json: start={start}; end={end}; content: {response_text[start: end]}"
+            )
+            print(
+                f"convert to token_indices: token_start_ind={token_start_ind}({token_indices[token_start_ind]}); token_end_ind={token_end_ind}({token_indices[token_end_ind]})"
+            )
+
+        argument_text = response_text[start:end]
+        # These are the token indices inside argument_text
+        # for example: argument_text={"location": "Hanoi", "use_celcius": true}
+        # argument_token_indices = [(0, 2), (2, 10), (10, 12), (12, 14), (14, 15), (15, 19), (19, 21), (21, 23), (23, 26), (26, 28), (28, 30), (30, 32), (32, 34), (34, 36), (36, 41)]
+        argument_token_indices = []
+        # in the best case the first offset is 0; for example with >{"a": 10}, '>{"' is one token while the JSON starts at {, so we temporarily treat this token as {"
+
+        for p in token_indices[token_start_ind : token_end_ind + 1]:
+            # compute the indices of the original tokens relative to argument_text
+            # if p[0] < start, the token p is something like '>{"' while start points at { in the middle of the token, so we trim the token to {"
+            argument_token_indices.append(
+                (p[0] - start if p[0] >= start else 0, p[1] - start)
+            )
+        # check if the last token extends beyond end --> trim it.
+        # For example, a last token that runs past the closing } is trimmed back to }
+        argument_token_indices[-1] = (argument_token_indices[-1][0], end - start)
+
+        first_token_of_values_indices = (
+            extract_indices_of_first_token_in_argument_values(
+                argument_token_indices, argument_text
+            )
         )
-        if brk_index >= 0:
-            brk_index += sep_index + 1
-            func_name = tokenizer.decode(token_ids[sep_index + 1 : brk_index])
-            # print(f"func_name:{token_ids[sep_index + 1: brk_index]};{func_name};sep_index={sep_index}, brk_index:{brk_index}")
-            if func_name != "all":
-                end_index = len(token_ids) - 2  # exclude eos_token_id for the last call
-                if i != len(sep_indices) - 1:
-                    end_index = sep_indices[i + 1]
-                start_argument_index = brk_index + 1
-                # print(
-                #     f"sep_index={sep_index}; start_argument_index={start_argument_index}; end_index={end_index + 1}"
-                # )
-                # = brk_index + 1, end_index
-                # token_ids[brk_index + 1: ] --> {"car_name": "Tang"}
-                token_indices = extract_indices_of_first_tokens_of_param_values(
-                    token_ids[start_argument_index : end_index + 1],
-                    tokenizer,
-                    verbose=verbose,
+        if verbose:
+            print(
+                f"argument_token_indices: {argument_token_indices}; argument_text: {argument_text}"
+            )
+            print(
+                f"argument_tokens: ",
+                [argument_text[s:e] for s, e in argument_token_indices],
+            )
+            print(f"first_token_of_values_indices={first_token_of_values_indices}")
+
+        for index in first_token_of_values_indices:
+            result.append(index + token_start_ind)
+            if verbose:
+                start, end = token_indices[index + token_start_ind]
+                content = response_text[start:end]
+                print(
+                    f"the detected token at index: {index + token_start_ind}, token_id={token_ids[index + token_start_ind]}; content={content}"
                 )
-                result.extend([start_argument_index + ind for ind in token_indices])
     return result
 
 
@@ -163,3 +240,17 @@ def extract_unmasked_chunks(labels: List[int], preds: List[int]):
     if len(current_label_chunk) > 0:
         result.append((current_label_chunk, current_pred_chunk))
     return result
+
+
+def test():
+    text = """{"location": "Hanoi", "use_celcius": true}<|eom_id|>"""
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+    token_ids = tokenizer.encode(text, add_special_tokens=False)
+    print("token_ids: ", token_ids)
+    extract_indices_of_first_tokens_of_param_values_in_assistant_response(
+        tokenizer, token_ids, verbose=True
+    )
+
+
+if __name__ == "__main__":
+    test()
diff --git a/functionary/train/training_utils.py b/functionary/train/training_utils.py
index a828831..52fb6eb 100644
--- a/functionary/train/training_utils.py
+++ b/functionary/train/training_utils.py
@@ -7,6 +7,7 @@
 from torch.utils.data import DataLoader
 import os
 from typing import List
+from functionary.train import metrics as train_metrics
 
 LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
 
@@ -132,6 +133,20 @@ def compute_metrics(eval_preds, id2token, tokenizer):
             if label == pred:
                 dic[label]["acc"] += 1
 
+    # Calculate the accuracy of the first token of parameter values
+    unmasked_labels_preds = train_metrics.extract_unmasked_chunks(
+        label_list, prediction_list
+    )
+    first_token_param_value_total, first_token_param_value_acc = 0, 0
+    for unmasked_labels, pred_result in unmasked_labels_preds:
+        indices = train_metrics.extract_indices_of_first_tokens_of_param_values_in_assistant_response(
+            tokenizer, unmasked_labels
+        )
+        for index in indices:
+            first_token_param_value_total += 1
+            if unmasked_labels[index] == pred_result[index]:
+                first_token_param_value_acc += 1
+
     # Calculate perplexity
     loss = eval_preds.predictions[1].tolist()
     loss = sum(loss) / len(loss)
@@ -142,6 +157,9 @@ def compute_metrics(eval_preds, id2token, tokenizer):
         "perplexity": perplexity,
         "accuracy_first_token": first_token_correct_count / first_token_total_count,
         "total_number_first_token": first_token_total_count,
+        "first_token_param_values": first_token_param_value_acc
+        / first_token_param_value_total,
+        "first_token_param-values_total": first_token_param_value_total,
     }
 
     for token_id, stat in sorted(
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index 62d7128..6ac2872 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -7,6 +7,8 @@ class TestMetrics(unittest.TestCase):
 
     def test_first_argument_value_token_prediction(self):
+        # Each test case includes assistant_response and target_tokens
+        # the test makes sure that `extract_indices_of_first_tokens_of_param_values_in_assistant_response` extracts the first token of each target value
         test_cases = [
             {
                 "assistant_response": """all
@@ -17,7 +19,7 @@ def test_first_argument_value_token_prediction(self):
             },
             {"assistant_response": "all\ngood morning<|eot_id|>", "target_tokens": []},
             {
-                "assistant_response": 'function_name\n{"a": "a", "b": {"b1": "value1", "b2": 10, "b3": 1.4, "b4": false, "b5": ["abc", 10]}}<|eot_id|>',
+                "assistant_response": '{"a": "a", "b": {"b1": "value1", "b2": 10, "b3": 1.4, "b4": false, "b5": ["abc", 10]}}<|eom_id|>',
                 "target_tokens": [
                     "a",
                     ' {"',
@@ -34,7 +36,7 @@ def test_first_argument_value_token_prediction(self):
 
         tokenizer = AutoTokenizer.from_pretrained("lmms-lab/llama3-llava-next-8b")
 
-        for case in test_cases[2:]:
+        for case in test_cases:
             token_ids = tokenizer.encode(
                 case["assistant_response"], add_special_tokens=False
             )

From d93bba2560cd880930daf7d3f4fd5816dedb676f Mon Sep 17 00:00:00 2001
From: Khai Mai
Date: Fri, 8 Nov 2024 08:20:48 +0000
Subject: [PATCH 2/2] fix the function mapping between token indices and
 character indices

---
 functionary/train/custom_datasets.py |  2 +-
 functionary/train/metrics.py         | 40 ++++++++++++++++++++--------
 functionary/train/training_utils.py  | 19 +++++++------
 3 files changed, 41 insertions(+), 20 deletions(-)

diff --git a/functionary/train/custom_datasets.py b/functionary/train/custom_datasets.py
index 8ea509e..a820f5c 100644
--- a/functionary/train/custom_datasets.py
+++ b/functionary/train/custom_datasets.py
@@ -500,7 +500,7 @@ def map_raw_data_to_input_dic(
             invalid_count += 1
 
     t2 = datetime.datetime.now()
-    avg_time = (t2 - t1).total_seconds() / len(data_points)
+    avg_time = (t2 - t1).total_seconds() / (len(data_points) + invalid_count)
     remaining_time = avg_time * (data_size - len(data_points))
     print(
         f"{len(data_points)}/{data_size}, avg_time per 1000 data points: {avg_time * 1000}, remaining time: {remaining_time}"
     )
diff --git a/functionary/train/metrics.py b/functionary/train/metrics.py
index 356ca32..b7f04d0 100644
--- a/functionary/train/metrics.py
+++ b/functionary/train/metrics.py
@@ -67,14 +67,22 @@ def get_indices_of_tokens_in_string(
     token_indices = []
 
     for token_index, token in enumerate(tokens):
-        start = text.find(token, pos)
-        if start == -1:
-            if verbose:
-                print("cannot find start")
-            raise Exception(f"cannot find token: '{token}' in {text[pos: ]}")
-        end = start + len(token)
-        token_indices.append((start, end))
-        pos = end
+        if text[pos:].startswith(token):
+            start = pos
+            end = start + len(token)
+            pos = end
+            token_indices.append((start, end))
+        else:
+            if len(token) > 1 and token[0] == " " and text[pos:].startswith(token[1:]):
+                start = pos
+                end = start + len(token) - 1
+                token_indices.append((start, end))
+                pos = end
+            else:
+                raise Exception(
+                    f"cannot match token_index: {token_index}, token='{token}'"
+                )
+
     return token_indices, text
 
 
@@ -113,7 +121,7 @@ def extract_indices_of_json_objects(text: str) -> List[Tuple[int, int]]:
                 if not stack and start_idx is not None:
                     json_str = text[start_idx : i + 1]
                     try:
-                        print("load json: ", json_str)
+                        # print("load json: ", json_str)
                         parsed_json = json.loads(json_str)
                         json_indices.append((start_idx, i + 1))
                     except json.JSONDecodeError:
@@ -242,7 +250,7 @@ def extract_unmasked_chunks(labels: List[int], preds: List[int]):
     return result
 
 
-def test():
+def test1():
     text = """{"location": "Hanoi", "use_celcius": true}<|eom_id|>"""
     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
     token_ids = tokenizer.encode(text, add_special_tokens=False)
     print("token_ids: ", token_ids)
@@ -252,5 +260,15 @@ def test():
     )
 
 
+def test2():
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+    token_ids = [505, 364, 3007, 1025]
+    extract_indices_of_first_tokens_of_param_values_in_assistant_response(
+        tokenizer, token_ids, verbose=True
+    )
+
+    # extract_indices_of_first_tokens_of_param_values_in_assistant_response(tokenizer, token_ids, verbose=True)
+
+
 if __name__ == "__main__":
-    test()
+    test2()
diff --git a/functionary/train/training_utils.py b/functionary/train/training_utils.py
index 52fb6eb..6671be6 100644
--- a/functionary/train/training_utils.py
+++ b/functionary/train/training_utils.py
@@ -139,13 +139,16 @@ def compute_metrics(eval_preds, id2token, tokenizer):
     )
     first_token_param_value_total, first_token_param_value_acc = 0, 0
     for unmasked_labels, pred_result in unmasked_labels_preds:
-        indices = train_metrics.extract_indices_of_first_tokens_of_param_values_in_assistant_response(
-            tokenizer, unmasked_labels
-        )
-        for index in indices:
-            first_token_param_value_total += 1
-            if unmasked_labels[index] == pred_result[index]:
-                first_token_param_value_acc += 1
+        try:
+            indices = train_metrics.extract_indices_of_first_tokens_of_param_values_in_assistant_response(
+                tokenizer, unmasked_labels
+            )
+            for index in indices:
+                first_token_param_value_total += 1
+                if unmasked_labels[index] == pred_result[index]:
+                    first_token_param_value_acc += 1
+        except Exception as e:
+            print_rank0(f"encounter exception: {str(e)}\nFor unmasked_labels: {unmasked_labels}")
 
     # Calculate perplexity
     loss = eval_preds.predictions[1].tolist()
@@ -159,7 +162,7 @@ def compute_metrics(eval_preds, id2token, tokenizer):
         "total_number_first_token": first_token_total_count,
         "first_token_param_values": first_token_param_value_acc
         / first_token_param_value_total,
-        "first_token_param-values_total": first_token_param_value_total,
+        "first_token_param_values_total": first_token_param_value_total,
     }
 
     for token_id, stat in sorted(
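
The metric added by these patches hinges on a two-step lookup: json-source-map's calculate() reports the character position at which each argument value starts, and find_first_token_value() maps that position to the token whose (start, end) span covers it. Below is a minimal standalone sketch of the character-level step, using the same json-source-map dependency that metrics.py imports; the helper name first_value_char_positions is illustrative, not part of the patch.

    # Sketch only: mirrors the value-locating logic of
    # extract_indices_of_first_token_in_argument_values;
    # first_value_char_positions is a hypothetical name, not from the repo.
    from typing import List

    from json_source_map import calculate


    def first_value_char_positions(argument_text: str) -> List[int]:
        positions = []
        for path, entry in calculate(argument_text).items():
            if not path:  # the empty path is the root object itself; skip it
                continue
            start = entry.value_start.position
            if argument_text[start] == '"':  # string value: skip the opening quote
                start += 1
            positions.append(start)
        return positions


    # 'H' of "Hanoi" is at index 14 and 't' of true at index 37 --> [14, 37];
    # find_first_token_value then maps each position to the covering token.
    print(first_value_char_positions('{"location": "Hanoi", "use_celcius": true}'))

Patch 2 hardens the other half of the lookup: text.find() can silently skip ahead when a decoded token does not appear verbatim at the running position, misaligning every span after it, so get_indices_of_tokens_in_string() now matches each token strictly with startswith() and additionally handles a decoded token whose leading space is absent from the fully decoded text.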