diff --git a/fuse/data/tokenizers/modular_tokenizer/modular_tokenizer.py b/fuse/data/tokenizers/modular_tokenizer/modular_tokenizer.py
index c34fde47..5762250c 100644
--- a/fuse/data/tokenizers/modular_tokenizer/modular_tokenizer.py
+++ b/fuse/data/tokenizers/modular_tokenizer/modular_tokenizer.py
@@ -1,29 +1,37 @@
 import random
 from torch import Tensor
-from typing import Dict
 from collections.abc import Iterable
 from tokenizers import Tokenizer, Encoding
 import tokenizers
 from warnings import warn
-from typing import Optional, List, Set, Union, Tuple, Any, Iterator
+from typing import Optional, List, Set, Union, Tuple, Any, Iterator, Dict
 import json
 import transformers
 import os
 from omegaconf import OmegaConf
-import collections
 import omegaconf
 import copy
 import traceback
 import re
 from fuse.data.tokenizers.modular_tokenizer.special_tokens import special_wrap_input
+from dataclasses import dataclass
 
-TypedInput = collections.namedtuple(
-    "TypedInput", ["input_type", "input_string", "max_len", "truncate_mode"]
-)
+
+@dataclass
+class ModularTokenizerInput:
+    input_type: str  # sub-tokenizer name
+    input_string: str  # the string to tokenize
+    max_len: Optional[int] = None  # max length (in tokens), used for truncation only; the smaller of this and the config-defined sub-tokenizer max_len applies. If None, the config value is used.
+    truncate_mode: Optional[
+        str
+    ] = None  # by default truncates from the right; set to "RAND" to randomly crop a sub-sequence of length=max_len
 
 
-def list_to_tokenizer_string(lst: List[TypedInput]) -> str:
+# for backward compatibility
+TypedInput = ModularTokenizerInput
+
+
+def list_to_tokenizer_string(lst: List[ModularTokenizerInput]) -> str:
     out = ""
     # prev_tokenizer = None
     for in_named_tuple in lst:
@@ -983,7 +991,7 @@ def count_unknowns(
 
     def encode_list(
         self,
-        typed_input_list: List,
+        typed_input_list: List[ModularTokenizerInput],
        max_len: Optional[int] = None,
        padding_token_id: Optional[int] = None,
        padding_token: Optional[str] = "",
@@ -1001,13 +1009,7 @@ def encode_list(
        """_summary_
 
        Args:
-            typed_input_list (List): list of collections.namedtuple("input_type", ["input_string", "max_len"]), with
-                input type: the name of input type,
-                input_string: the string to be encoded
-                max_len: maximal length of the encoding (in tokens). Only relevant for truncation, as we do not need to
-                    pad individual sub-tokenizer encodings - we only pad the final encoding of the ModularTokenizer.
-                    The smallest value between config-defined and tuple-defined is used. If None, the max_len
-                    that was defined for the sub-tokenizer in the config is used.
+            typed_input_list (List[ModularTokenizerInput]): the inputs to encode, in order. Per-field semantics are documented on ModularTokenizerInput.
            max_len (Optional[int], optional): _description_. Defaults to None.
            padding_token_id (Optional[str], optional): _description_. Defaults to 0. TODO: default to None and infer it
            padding_token (Optional[str], optional): _description_. Defaults to "".
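
A quick usage sketch for reviewers, showing the namedtuple-to-dataclass migration from the call side. Only names defined in this diff (ModularTokenizerInput, TypedInput, encode_list) are relied on; the sub-tokenizer names and input strings are illustrative placeholders, and the encode_list call is shown commented out because constructing a ModularTokenizer requires a config not included here.

    from fuse.data.tokenizers.modular_tokenizer.modular_tokenizer import (
        ModularTokenizerInput,
        TypedInput,  # alias kept for backward compatibility; same class
    )

    # Keyword construction, exercising the new optional fields.
    inputs = [
        ModularTokenizerInput(input_type="AA", input_string="ACDEFGHIKLMNPQRSTVWY"),
        ModularTokenizerInput(
            input_type="SMILES",
            input_string="CC(=O)OC1=CC=CC=C1C(=O)O",
            max_len=64,            # per-sub-sequence truncation cap
            truncate_mode="RAND",  # random crop of length max_len instead of right-truncation
        ),
    ]

    # Positional construction in field order still works, so old
    # namedtuple-style call sites remain valid:
    legacy = TypedInput("AA", "ACDE", 100, None)
    assert isinstance(legacy, ModularTokenizerInput)

    # tokenizer = ...  # a ModularTokenizer loaded from its config (not shown)
    # encoding = tokenizer.encode_list(typed_input_list=inputs, max_len=256)

One behavioral difference worth a grep during review: namedtuple instances support indexing and unpacking (t[0], *t), while a plain dataclass does not, so any call site that treated TypedInput as a tuple rather than accessing attributes would break under this change.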