From ed15c03997d6962428acf6d7b28873e89bdb1afe Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Thu, 21 Nov 2024 22:45:36 -0800 Subject: [PATCH 01/24] delta to mds script v1 --- llmfoundry/command_utils/__init__.py | 4 + .../data_prep/convert_delta_to_json.py | 115 ++++++++++-------- .../data_prep/convert_delta_to_mds.py | 88 ++++++++++++++ scripts/data_prep/convert_delta_to_mds.py | 32 +++++ 4 files changed, 186 insertions(+), 53 deletions(-) create mode 100644 llmfoundry/command_utils/data_prep/convert_delta_to_mds.py create mode 100644 scripts/data_prep/convert_delta_to_mds.py diff --git a/llmfoundry/command_utils/__init__.py b/llmfoundry/command_utils/__init__.py index 756e611a88..a48efe11ff 100644 --- a/llmfoundry/command_utils/__init__.py +++ b/llmfoundry/command_utils/__init__.py @@ -15,6 +15,9 @@ convert_delta_to_json_from_args, fetch_DT, ) +from llmfoundry.command_utils.data_prep.convert_delta_to_mds import ( + convert_delta_to_mds_from_args, +) from llmfoundry.command_utils.data_prep.convert_finetuning_dataset import ( convert_finetuning_dataset, convert_finetuning_dataset_from_args, @@ -53,5 +56,6 @@ 'convert_text_to_mds', 'convert_text_to_mds_from_args', 'convert_delta_to_json_from_args', + 'convert_delta_to_mds_from_args', 'fetch_DT', ] diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index 1a0e575850..5c4deedcb5 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -433,6 +433,48 @@ def get_columns_info( return columns, order_by, columns_str +def handle_fetch_exception( + e: Exception, + tablename: str, +) -> None: + from databricks.sql.exc import ServerOperationError + from pyspark.errors import AnalysisException + + if isinstance(e, (AnalysisException, ServerOperationError)): + error_message = str(e) + if 'INSUFFICIENT_PERMISSIONS' in error_message: + raise InsufficientPermissionsError(error_message) from e + elif 'UC_NOT_ENABLED' in error_message: + raise UCNotEnabledError() from e + elif 'UNRESOLVED_COLUMN.WITH_SUGGESTION' in error_message: + raise MalformedUCTableError(error_message) from e + elif 'Delta table' in str(e) and "doesn't exist" in str(e): + # Error processing `catalog`.`volume_name`.`table_name`: + # Delta table `volume_name`.`table_name` doesn't exist. + # --- + parts = error_message.split('`') + if len(parts) < 7: + # Failed to parse error, our codebase is brittle + # with respect to the string representations of + # errors in the spark library. 
+ catalog_name, volume_name, table_name = ['unknown'] * 3 + else: + catalog_name = parts[1] + volume_name = parts[3] + table_name = parts[5] + raise DeltaTableNotFoundError( + catalog_name, + volume_name, + table_name, + ) from e + + if isinstance(e, InsufficientPermissionsError): + raise + + # For any other exception, raise a general error + raise RuntimeError(f'Error processing {tablename}: {str(e)}') from e + + def fetch( method: str, tablename: str, @@ -499,42 +541,7 @@ def fetch( ) except Exception as e: - from databricks.sql.exc import ServerOperationError - from pyspark.errors import AnalysisException - - if isinstance(e, (AnalysisException, ServerOperationError)): - error_message = str(e) - if 'INSUFFICIENT_PERMISSIONS' in error_message: - raise InsufficientPermissionsError(error_message) from e - elif 'UC_NOT_ENABLED' in error_message: - raise UCNotEnabledError() from e - elif 'UNRESOLVED_COLUMN.WITH_SUGGESTION' in error_message: - raise MalformedUCTableError(error_message) from e - elif 'Delta table' in str(e) and "doesn't exist" in str(e): - # Error processing `catalog`.`volume_name`.`table_name`: - # Delta table `volume_name`.`table_name` doesn't exist. - # --- - parts = error_message.split('`') - if len(parts) < 7: - # Failed to parse error, our codebase is brittle - # with respect to the string representations of - # errors in the spark library. - catalog_name, volume_name, table_name = ['unknown'] * 3 - else: - catalog_name = parts[1] - volume_name = parts[3] - table_name = parts[5] - raise DeltaTableNotFoundError( - catalog_name, - volume_name, - table_name, - ) from e - - if isinstance(e, InsufficientPermissionsError): - raise - - # For any other exception, raise a general error - raise RuntimeError(f'Error processing {tablename}: {str(e)}') from e + handle_fetch_exception(e, tablename) finally: if cursor is not None: @@ -647,6 +654,24 @@ def validate_and_get_cluster_info( return method, dbsql, sparkSession +def validate_output_folder(output_folder: str) -> None: + obj = urllib.parse.urlparse(output_folder) + if obj.scheme != '': + raise ValueError( + 'Check the output folder and verify it is a local path!', + ) + + if os.path.exists(output_folder): + if not os.path.isdir(output_folder) or os.listdir(output_folder,): + raise RuntimeError( + f'Output folder {output_folder} already exists and is not empty. Please remove it and retry.', + ) + + os.makedirs(output_folder, exist_ok=True) + log.info(f'Directory {output_folder} created.') + + + def fetch_DT( delta_table_name: str, json_output_folder: str, @@ -662,26 +687,10 @@ def fetch_DT( """Fetch UC Delta Table to local as jsonl.""" log.info(f'Start .... Convert delta to json') - obj = urllib.parse.urlparse(json_output_folder) - if obj.scheme != '': - raise ValueError( - 'Check the json_output_folder and verify it is a local path!', - ) - - if os.path.exists(json_output_folder): - if not os.path.isdir(json_output_folder) or os.listdir( - json_output_folder, - ): - raise RuntimeError( - f'Output folder {json_output_folder} already exists and is not empty. Please remove it and retry.', - ) - - os.makedirs(json_output_folder, exist_ok=True) - if not json_output_filename.endswith('.jsonl'): raise ValueError('json_output_filename needs to be a jsonl file') - log.info(f'Directory {json_output_folder} created.') + validate_output_folder(json_output_folder) # Validate_and_get_cluster_info allows cluster_id to be None if use_serverless is True. 
     method, dbsql, sparkSession = validate_and_get_cluster_info(
diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
new file mode 100644
index 0000000000..13e835ad5f
--- /dev/null
+++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
@@ -0,0 +1,88 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import os
+import time
+
+from streaming.base.converters import dataframe_to_mds
+
+from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
+    _check_imports,
+    format_tablename,
+    handle_fetch_exception,
+    validate_and_get_cluster_info,
+    validate_output_folder,
+)
+
+log = logging.getLogger(__name__)
+
+
+def fetch_DT_mds(
+    delta_table_name: str,
+    mds_output_folder: str,
+    DATABRICKS_HOST: str,
+    DATABRICKS_TOKEN: str,
+) -> None:
+    """Fetch UC Delta Table and convert to MDS shards."""
+    log.info(f'Converting Delta Table {delta_table_name} to MDS shards.')
+
+    validate_output_folder(mds_output_folder)
+
+    method, _, sparkSession = validate_and_get_cluster_info(
+        cluster_id=None,
+        databricks_host=DATABRICKS_HOST,
+        databricks_token=DATABRICKS_TOKEN,
+        http_path=None,
+        use_serverless=True,
+    )
+
+    formatted_delta_table_name = format_tablename(delta_table_name)
+
+    try:
+        if method == 'dbconnect' and sparkSession is not None:
+            df = sparkSession.table(formatted_delta_table_name)
+
+            with open(mds_output_folder):
+                mds_kwargs = {
+                    'out': mds_output_folder,
+                    'keep_local': True,
+                    'compression': None,
+                }
+                dataframe_to_mds(
+                    df,
+                    merge_index=True,
+                    mds_kwargs=mds_kwargs,
+                )
+        else:
+            raise NotImplementedError('Currently only dbconnect is supported.')
+    except Exception as e:
+        handle_fetch_exception(e, formatted_delta_table_name)
+
+
+def convert_delta_to_mds_from_args(
+    delta_table_name: str,
+    mds_output_folder: str,
+) -> None:
+    """A wrapper for convert_delta_to_mds that parses arguments. Currently only
+    supports dbconnect on serverless compute and not dbsql.
+
+    Args:
+        delta_table_name (str): The name of the delta table to convert.
+        mds_output_folder (str): The folder to output MDS shards.
+    """
+    os.environ['WORLD_SIZE'] = '1'
+    _check_imports()
+    from databricks.sdk import WorkspaceClient
+    w = WorkspaceClient()
+    DATABRICKS_HOST = w.config.host
+    DATABRICKS_TOKEN = w.config.token
+
+    tik = time.time()
+    fetch_DT_mds(
+        delta_table_name=delta_table_name,
+        mds_output_folder=mds_output_folder,
+        DATABRICKS_HOST=DATABRICKS_HOST,
+        DATABRICKS_TOKEN=DATABRICKS_TOKEN,
+    )
+    log.info(f'convert_delta_to_mds took {time.time() - tik} seconds.')
diff --git a/scripts/data_prep/convert_delta_to_mds.py b/scripts/data_prep/convert_delta_to_mds.py
new file mode 100644
index 0000000000..4a42441276
--- /dev/null
+++ b/scripts/data_prep/convert_delta_to_mds.py
@@ -0,0 +1,32 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from argparse import ArgumentParser
+
+from llmfoundry.command_utils import convert_delta_to_mds_from_args
+
+log = logging.getLogger(__name__)
+
+if __name__ == '__main__':
+    parser = ArgumentParser(
+        description=
+        'Download Delta Table from UC and save as MDS shards in local folder',
+    )
+    parser.add_argument(
+        '--delta_table_name',
+        required=True,
+        type=str,
+        help='UC table <catalog>.<schema>.<table>',
+    )
+    parser.add_argument(
+        '--mds_output_folder',
+        required=True,
+        type=str,
+        help='Local path to save the converted MDS shards',
+    )
+    args = parser.parse_args()
+    convert_delta_to_mds_from_args(
+        delta_table_name=args.delta_table_name,
+        mds_output_folder=args.mds_output_folder,
+    )

From 5379d5b16d93862aeb984e92e79041eaff2f4419 Mon Sep 17 00:00:00 2001
From: Matthew Ding
Date: Fri, 22 Nov 2024 00:42:37 -0800
Subject: [PATCH 02/24] remove open folder

---
 .../data_prep/convert_delta_to_mds.py         | 21 +++++++++----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
index 13e835ad5f..4aeab6d551 100644
--- a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
+++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
@@ -43,17 +43,16 @@ def fetch_DT_mds(
         if method == 'dbconnect' and sparkSession is not None:
             df = sparkSession.table(formatted_delta_table_name)
 
-            with open(mds_output_folder):
-                mds_kwargs = {
-                    'out': mds_output_folder,
-                    'keep_local': True,
-                    'compression': None,
-                }
-                dataframe_to_mds(
-                    df,
-                    merge_index=True,
-                    mds_kwargs=mds_kwargs,
-                )
+            mds_kwargs = {
+                'out': mds_output_folder,
+                'keep_local': True,
+                'compression': None,
+            }
+            dataframe_to_mds(
+                df,
+                merge_index=True,
+                mds_kwargs=mds_kwargs,
+            )
         else:
             raise NotImplementedError('Currently only dbconnect is supported.')
     except Exception as e:

From 48d26e4e6e43e7077eaca65bc206a4d1959e37ba Mon Sep 17 00:00:00 2001
From: Matthew Ding
Date: Fri, 22 Nov 2024 01:29:59 -0800
Subject: [PATCH 03/24] debug

---
 llmfoundry/command_utils/data_prep/convert_delta_to_mds.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
index 4aeab6d551..1514943ce7 100644
--- a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
+++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
@@ -43,6 +43,8 @@ def fetch_DT_mds(
         if method == 'dbconnect' and sparkSession is not None:
             df = sparkSession.table(formatted_delta_table_name)
 
+            print(type(df))
+
             mds_kwargs = {
                 'out': mds_output_folder,
                 'keep_local': True,

From aa9edbfd7861bec7b3afa47fa7424710a717cbff Mon Sep 17 00:00:00 2001
From: Matthew Ding
Date: Sun,
24 Nov 2024 22:30:00 -0800 Subject: [PATCH 04/24] added intermediate jsonl --- .../data_prep/convert_delta_to_mds.py | 185 ++++++++++++------ 1 file changed, 130 insertions(+), 55 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py index 1514943ce7..f4f07c4e04 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py @@ -1,89 +1,164 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import json import logging import os -import time +import tempfile +from typing import Callable, Optional -from streaming.base.converters import dataframe_to_mds +import numpy as np +from streaming import MDSWriter from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( _check_imports, + fetch_DT, format_tablename, - handle_fetch_exception, + get_columns_info, validate_and_get_cluster_info, - validate_output_folder, ) -log = logging.getLogger(__name__) +logger = logging.getLogger(__name__) -def fetch_DT_mds( - delta_table_name: str, - mds_output_folder: str, - DATABRICKS_HOST: str, - DATABRICKS_TOKEN: str, -) -> None: - """Fetch UC Delta Table and convert to MDS shards.""" - log.info(f'Converting Delta Table {delta_table_name} to MDS shards.') - - validate_output_folder(mds_output_folder) - - method, _, sparkSession = validate_and_get_cluster_info( - cluster_id=None, - databricks_host=DATABRICKS_HOST, - databricks_token=DATABRICKS_TOKEN, - http_path=None, - use_serverless=True, - ) - - formatted_delta_table_name = format_tablename(delta_table_name) +def get_conversion_config( + columns: list[str], + provided_dtypes: Optional[dict], +) -> tuple[dict, Callable]: + """If no dtypes is provided, attempts to infer config based on column names. - try: - if method == 'dbconnect' and sparkSession is not None: - df = sparkSession.table(formatted_delta_table_name) - - print(type(df)) - - mds_kwargs = { - 'out': mds_output_folder, - 'keep_local': True, - 'compression': None, - } - dataframe_to_mds( - df, - merge_index=True, - mds_kwargs=mds_kwargs, - ) - else: - raise NotImplementedError('Currently only dbconnect is supported.') - except Exception as e: - handle_fetch_exception(e, formatted_delta_table_name) + Args: + columns (List[str]): The list of column names. + provided_dtypes (Optional[Dict]): The provided dtypes. + """ + if provided_dtypes is not None: + return provided_dtypes, lambda x: x + + if len(columns) != 1: + raise ValueError( + 'Unable to infer dtypes from columns and no dtypes provided.', + ) + + if 'turns' in columns[0]: + logging.info('Identified IFT data') + dtypes = { + 'input_ids': 'ndarray', + 'attention_mask': 'ndarray', + 'labels': 'ndarray', + } + convert_x = lambda x: { + # join the turns into a single array + 'input_ids': + np.concatenate([ + np.array(turn['input_ids']) for turn in x['turns'] + ]), + 'attention_mask': + np.concatenate([ + np.array(turn['attention_mask']) for turn in x['turns'] + ]), + 'labels': + np. 
+                concatenate([np.array(turn['labels']) for turn in x['turns']]),
+        }
+    elif 'tokens' in columns[0]:
+        logging.info('Identified CPT data')
+        dtypes = {
+            'tokens': 'ndarray',
+        }
+        convert_x = lambda x: x
+    else:
+        raise ValueError(
+            'Unable to infer dtypes from columns and no dtypes provided.',
+        )
+
+    return dtypes, convert_x


 def convert_delta_to_mds_from_args(
     delta_table_name: str,
     mds_output_folder: str,
+    http_path: Optional[str],
+    cluster_id: Optional[str],
+    use_serverless: bool,
+    batch_size: int,
+    processes: int,
+    dtypes: Optional[dict[str, str]],
 ) -> None:
-    """A wrapper for convert_delta_to_mds that parses arguments. Currently only
-    supports dbconnect on serverless compute and not dbsql.
+    """A wrapper for convert_delta_to_mds that parses arguments.
 
     Args:
         delta_table_name (str): The name of the delta table to convert.
         mds_output_folder (str): The folder to output MDS shards.
+        http_path (Optional[str]): If set, the dbsql method is used
+        batch_size (int): Row chunks to transmit at a time to avoid OOM
+        processes (int): Number of processes to use
+        cluster_id (Optional[str]): Cluster ID of a cluster whose runtime is newer than 14.1.0 and whose access mode is either assigned or shared, so that databricks-connect can be used.
+        use_serverless (bool): Use serverless or not. Make sure the workspace is entitled to use serverless compute.
+        dtypes (Optional[Dict[str, str]]): Mapping between column name and dtype, where dtype is supported for MDS conversion.
+            If not provided, the function will attempt to infer the dtype.
     """
-    os.environ['WORLD_SIZE'] = '1'
     _check_imports()
     from databricks.sdk import WorkspaceClient
     w = WorkspaceClient()
     DATABRICKS_HOST = w.config.host
     DATABRICKS_TOKEN = w.config.token
 
-    tik = time.time()
-    fetch_DT_mds(
-        delta_table_name=delta_table_name,
-        mds_output_folder=mds_output_folder,
-        DATABRICKS_HOST=DATABRICKS_HOST,
-        DATABRICKS_TOKEN=DATABRICKS_TOKEN,
+    method, dbsql, sparkSession = validate_and_get_cluster_info(
+        cluster_id=cluster_id,
+        databricks_host=DATABRICKS_HOST,
+        databricks_token=DATABRICKS_TOKEN,
+        http_path=http_path,
+        use_serverless=use_serverless,
     )
-    log.info(f'convert_delta_to_mds took {time.time() - tik} seconds.')
+    cursor = dbsql.cursor() if dbsql is not None else None
+    columns, _, _ = get_columns_info(
+        tablename=format_tablename(delta_table_name),
+        method=method,
+        cursor=cursor,
+        sparkSession=sparkSession,
+    )
+    logger.info(f'Columns: {columns}')
+
+    dtypes, convert_x = get_conversion_config(columns, dtypes)
+
+    compression = 'zstd:7'
+    hashes = ['sha1']
+    limit = '10mb'
+
+    logging.info(f'Fetching data from Delta Table {delta_table_name}...')
+
+    with tempfile.TemporaryDirectory() as json_out_folder:
+        json_out_filename = 'temp.jsonl'
+        json_full_filepath = os.path.join(json_out_folder, json_out_filename)
+        try:
+            fetch_DT(
+                delta_table_name=delta_table_name,
+                json_output_folder=json_out_folder,
+                http_path=http_path,
+                batch_size=batch_size,
+                processes=processes,
+                cluster_id=cluster_id,
+                use_serverless=use_serverless,
+                json_output_filename=json_out_filename,
+                DATABRICKS_HOST=DATABRICKS_HOST,
+                DATABRICKS_TOKEN=DATABRICKS_TOKEN,
+            )
+        except Exception as e:
+            logger.error(f'Error fetching data from Delta Table: {e}')
+            raise e
+        with MDSWriter(
+            out=mds_output_folder,
+            columns=dtypes,
+            compression=compression,
+            hashes=hashes,
+            size_limit=limit,
+        ) as out:
+            try:
+                with open(json_full_filepath, 'r') as f:
+                    for line in f:
+                        out.write(convert_x(json.loads(line)))
+            except FileNotFoundError as e:
+                logger.error(f'JSON output file not found: {e}')
+                raise
e
+
+    logging.info(f'Wrote to MDS at {mds_output_folder}')

From fd54b5958bd286202f5d4e5581da024220dd0aac Mon Sep 17 00:00:00 2001
From: Matthew Ding
Date: Sun, 24 Nov 2024 22:36:08 -0800
Subject: [PATCH 05/24] update script

---
 scripts/data_prep/convert_delta_to_mds.py | 53 ++++++++++++++++++++++-
 1 file changed, 51 insertions(+), 2 deletions(-)

diff --git a/scripts/data_prep/convert_delta_to_mds.py b/scripts/data_prep/convert_delta_to_mds.py
index 4a42441276..16e6d11e70 100644
--- a/scripts/data_prep/convert_delta_to_mds.py
+++ b/scripts/data_prep/convert_delta_to_mds.py
@@ -1,6 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import json
 import logging
 from argparse import ArgumentParser
 
@@ -23,10 +24,58 @@
         '--mds_output_folder',
         required=True,
         type=str,
-        help='Local path to save the converted MDS shards',
+        help='Local path to save the MDS shards',
+    )
+    parser.add_argument(
+        '--http_path',
+        required=False,
+        type=str,
+        help='if http_path is set, the dbsql method is used',
+    )
+    parser.add_argument(
+        "--batch_size",
+        required=False,
+        type=int,
+        default=1 << 30,
+        help='row chunks to transmit at a time to avoid OOM',
+    )
+    parser.add_argument(
+        "--processes",
+        required=False,
+        type=int,
+        default=1,
+        help='number of processes to use',
+    )
+    parser.add_argument(
+        "--cluster_id",
+        required=False,
+        type=str,
+        help='cluster id to use for databricks-connect',
+    )
+    parser.add_argument(
+        "--use_serverless",
+        required=False,
+        action='store_true',
+        help='use serverless cluster',
+    )
+    parser.add_argument(
+        "--dtypes",
+        required=False,
+        type=str,
+        help='JSON mapping of column names to dtypes',
     )
     args = parser.parse_args()
+
+    if args.dtypes is not None:
+        args.dtypes = json.loads(args.dtypes)
+
     convert_delta_to_mds_from_args(
         delta_table_name=args.delta_table_name,
         mds_output_folder=args.mds_output_folder,
-    )
+        http_path=args.http_path,
+        cluster_id=args.cluster_id,
+        use_serverless=args.use_serverless,
+        batch_size=args.batch_size,
+        processes=args.processes,
+        dtypes=args.dtypes,
+    )
\ No newline at end of file

From 2095115b01bdc6a76e6e30c738aeb557dc162eff Mon Sep 17 00:00:00 2001
From: Matthew Ding
Date: Sun, 24 Nov 2024 22:52:49 -0800
Subject: [PATCH 06/24] cast to ndarray

---
 llmfoundry/command_utils/data_prep/convert_delta_to_mds.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
index f4f07c4e04..61448bd50f 100644
--- a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
+++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
@@ -65,7 +65,7 @@ def get_conversion_config(
         dtypes = {
             'tokens': 'ndarray',
         }
-        convert_x = lambda x: x
+        convert_x = lambda x: np.array(x['tokens'])

From 6a75da57d803951baa8c0d8e4c487ec167caf046 Mon Sep 17 00:00:00 2001
From: Matthew Ding
Date: Sun, 24 Nov 2024 23:02:25 -0800
Subject: [PATCH 07/24] nit

---
 llmfoundry/command_utils/data_prep/convert_delta_to_mds.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
index 61448bd50f..4338533fcb 100644
--- a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
+++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
@@ -65,7 +65,7 @@ def get_conversion_config(
         dtypes = {
             'tokens':
'ndarray', } - convert_x = lambda x: np.array(x['tokens']) + convert_x = lambda x: {'tokens': np.array(x['tokens'])} else: raise ValueError( 'Unable to infer dtypes from columns and no dtypes provided.', From a1a5274d655fe800cf7e9f8b76c178b64b6755de Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Sun, 24 Nov 2024 23:08:09 -0800 Subject: [PATCH 08/24] revert delta->jsonl refactor --- .../data_prep/convert_delta_to_json.py | 117 ++++++++---------- 1 file changed, 54 insertions(+), 63 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py index 5c4deedcb5..46d500773d 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py @@ -433,48 +433,6 @@ def get_columns_info( return columns, order_by, columns_str -def handle_fetch_exception( - e: Exception, - tablename: str, -) -> None: - from databricks.sql.exc import ServerOperationError - from pyspark.errors import AnalysisException - - if isinstance(e, (AnalysisException, ServerOperationError)): - error_message = str(e) - if 'INSUFFICIENT_PERMISSIONS' in error_message: - raise InsufficientPermissionsError(error_message) from e - elif 'UC_NOT_ENABLED' in error_message: - raise UCNotEnabledError() from e - elif 'UNRESOLVED_COLUMN.WITH_SUGGESTION' in error_message: - raise MalformedUCTableError(error_message) from e - elif 'Delta table' in str(e) and "doesn't exist" in str(e): - # Error processing `catalog`.`volume_name`.`table_name`: - # Delta table `volume_name`.`table_name` doesn't exist. - # --- - parts = error_message.split('`') - if len(parts) < 7: - # Failed to parse error, our codebase is brittle - # with respect to the string representations of - # errors in the spark library. - catalog_name, volume_name, table_name = ['unknown'] * 3 - else: - catalog_name = parts[1] - volume_name = parts[3] - table_name = parts[5] - raise DeltaTableNotFoundError( - catalog_name, - volume_name, - table_name, - ) from e - - if isinstance(e, InsufficientPermissionsError): - raise - - # For any other exception, raise a general error - raise RuntimeError(f'Error processing {tablename}: {str(e)}') from e - - def fetch( method: str, tablename: str, @@ -541,7 +499,42 @@ def fetch( ) except Exception as e: - handle_fetch_exception(e, tablename) + from databricks.sql.exc import ServerOperationError + from pyspark.errors import AnalysisException + + if isinstance(e, (AnalysisException, ServerOperationError)): + error_message = str(e) + if 'INSUFFICIENT_PERMISSIONS' in error_message: + raise InsufficientPermissionsError(error_message) from e + elif 'UC_NOT_ENABLED' in error_message: + raise UCNotEnabledError() from e + elif 'UNRESOLVED_COLUMN.WITH_SUGGESTION' in error_message: + raise MalformedUCTableError(error_message) from e + elif 'Delta table' in str(e) and "doesn't exist" in str(e): + # Error processing `catalog`.`volume_name`.`table_name`: + # Delta table `volume_name`.`table_name` doesn't exist. + # --- + parts = error_message.split('`') + if len(parts) < 7: + # Failed to parse error, our codebase is brittle + # with respect to the string representations of + # errors in the spark library. 
+ catalog_name, volume_name, table_name = ['unknown'] * 3 + else: + catalog_name = parts[1] + volume_name = parts[3] + table_name = parts[5] + raise DeltaTableNotFoundError( + catalog_name, + volume_name, + table_name, + ) from e + + if isinstance(e, InsufficientPermissionsError): + raise + + # For any other exception, raise a general error + raise RuntimeError(f'Error processing {tablename}: {str(e)}') from e finally: if cursor is not None: @@ -654,24 +647,6 @@ def validate_and_get_cluster_info( return method, dbsql, sparkSession -def validate_output_folder(output_folder: str) -> None: - obj = urllib.parse.urlparse(output_folder) - if obj.scheme != '': - raise ValueError( - 'Check the output folder and verify it is a local path!', - ) - - if os.path.exists(output_folder): - if not os.path.isdir(output_folder) or os.listdir(output_folder,): - raise RuntimeError( - f'Output folder {output_folder} already exists and is not empty. Please remove it and retry.', - ) - - os.makedirs(output_folder, exist_ok=True) - log.info(f'Directory {output_folder} created.') - - - def fetch_DT( delta_table_name: str, json_output_folder: str, @@ -687,10 +662,26 @@ def fetch_DT( """Fetch UC Delta Table to local as jsonl.""" log.info(f'Start .... Convert delta to json') + obj = urllib.parse.urlparse(json_output_folder) + if obj.scheme != '': + raise ValueError( + 'Check the json_output_folder and verify it is a local path!', + ) + + if os.path.exists(json_output_folder): + if not os.path.isdir(json_output_folder) or os.listdir( + json_output_folder, + ): + raise RuntimeError( + f'Output folder {json_output_folder} already exists and is not empty. Please remove it and retry.', + ) + + os.makedirs(json_output_folder, exist_ok=True) + if not json_output_filename.endswith('.jsonl'): raise ValueError('json_output_filename needs to be a jsonl file') - validate_output_folder(json_output_folder) + log.info(f'Directory {json_output_folder} created.') # Validate_and_get_cluster_info allows cluster_id to be None if use_serverless is True. 
     method, dbsql, sparkSession = validate_and_get_cluster_info(
@@ -879,4 +870,4 @@ def convert_delta_to_json_from_args(
         DATABRICKS_HOST=DATABRICKS_HOST,
         DATABRICKS_TOKEN=DATABRICKS_TOKEN,
     )
-    log.info(f'Elapsed time {time.time() - tik}')
+    log.info(f'Elapsed time {time.time() - tik}')
\ No newline at end of file

From 02dfcb5656a8565ecba9a740df42013c4ec42f2b Mon Sep 17 00:00:00 2001
From: Matthew Ding
Date: Sun, 24 Nov 2024 23:10:21 -0800
Subject: [PATCH 09/24] nit

---
 .../data_prep/convert_delta_to_json.py        |  2 +-
 scripts/data_prep/convert_delta_to_mds.py     | 14 +++++++-------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py
index 46d500773d..1a0e575850 100644
--- a/llmfoundry/command_utils/data_prep/convert_delta_to_json.py
+++ b/llmfoundry/command_utils/data_prep/convert_delta_to_json.py
@@ -870,4 +870,4 @@ def convert_delta_to_json_from_args(
         DATABRICKS_HOST=DATABRICKS_HOST,
         DATABRICKS_TOKEN=DATABRICKS_TOKEN,
     )
-    log.info(f'Elapsed time {time.time() - tik}')
\ No newline at end of file
+    log.info(f'Elapsed time {time.time() - tik}')
diff --git a/scripts/data_prep/convert_delta_to_mds.py b/scripts/data_prep/convert_delta_to_mds.py
index 16e6d11e70..0b89d7649d 100644
--- a/scripts/data_prep/convert_delta_to_mds.py
+++ b/scripts/data_prep/convert_delta_to_mds.py
@@ -33,39 +33,39 @@
         help='if http_path is set, the dbsql method is used',
     )
     parser.add_argument(
-        "--batch_size",
+        '--batch_size',
         required=False,
         type=int,
         default=1 << 30,
         help='row chunks to transmit at a time to avoid OOM',
     )
     parser.add_argument(
-        "--processes",
+        '--processes',
         required=False,
         type=int,
         default=1,
         help='number of processes to use',
     )
     parser.add_argument(
-        "--cluster_id",
+        '--cluster_id',
         required=False,
         type=str,
         help='cluster id to use for databricks-connect',
     )
    parser.add_argument(
-        "--use_serverless",
+        '--use_serverless',
         required=False,
         action='store_true',
         help='use serverless cluster',
     )
     parser.add_argument(
-        "--dtypes",
+        '--dtypes',
         required=False,
         type=str,
         help='JSON mapping of column names to dtypes',
     )
     args = parser.parse_args()
-
+
     if args.dtypes is not None:
         args.dtypes = json.loads(args.dtypes)
 
     convert_delta_to_mds_from_args(
         delta_table_name=args.delta_table_name,
         mds_output_folder=args.mds_output_folder,
         http_path=args.http_path,
         cluster_id=args.cluster_id,
         use_serverless=args.use_serverless,
         batch_size=args.batch_size,
         processes=args.processes,
         dtypes=args.dtypes,
-    )
\ No newline at end of file
+    )

From 4cf6d2353ae7634786754f0f78e90035f669f36e Mon Sep 17 00:00:00 2001
From: Matthew Ding
Date: Tue, 26 Nov 2024 08:34:12 -0800
Subject: [PATCH 10/24] update col name

---
 llmfoundry/command_utils/data_prep/convert_delta_to_mds.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
index 4338533fcb..5e3abdbfb4 100644
--- a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
+++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py
@@ -60,12 +60,12 @@ def get_conversion_config(
             np.
concatenate([np.array(turn['labels']) for turn in x['turns']]), } - elif 'tokens' in columns[0]: + elif 'concat_tokens' in columns[0]: logging.info('Identified CPT data') dtypes = { 'tokens': 'ndarray', } - convert_x = lambda x: {'tokens': np.array(x['tokens'])} + convert_x = lambda x: {'tokens': np.array(x['concat_tokens'])} else: raise ValueError( 'Unable to infer dtypes from columns and no dtypes provided.', From 2932a9b68484090b7479849515a7e854e2ee9dcb Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Fri, 29 Nov 2024 22:32:28 -0800 Subject: [PATCH 11/24] use dtypes --- llmfoundry/command_utils/data_prep/convert_delta_to_mds.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py index 5e3abdbfb4..3da51faa52 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py @@ -32,7 +32,10 @@ def get_conversion_config( provided_dtypes (Optional[Dict]): The provided dtypes. """ if provided_dtypes is not None: - return provided_dtypes, lambda x: x + convert_x = lambda x: { + k: np.array(v, dtype=provided_dtypes.get(k)) for k, v in x.items() + } + return provided_dtypes, convert_x if len(columns) != 1: raise ValueError( From 04e628eb39f373ccdd79f38d6865890ad520bd8a Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Mon, 2 Dec 2024 15:58:26 -0800 Subject: [PATCH 12/24] dbugging message --- llmfoundry/data/finetuning/dataloader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 661729ff8a..0808677318 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -180,6 +180,7 @@ def build_finetuning_dataloader( given a starting workload YAML. 
""" dataset_cfg = dataset + print("🦖🦖🦖🦖 build finetuning dataloader called with dataset_cfg: ", dataset_cfg) is_streaming = ( dataset_cfg.get('remote') is not None or dataset_cfg.get('streams') is not None From 08bc526e112c14aec490c1e613e4451eeb5f52cc Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Mon, 2 Dec 2024 16:52:55 -0800 Subject: [PATCH 13/24] test bugfix --- llmfoundry/data/packing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index 5eacced549..f452b5a4d0 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -485,7 +485,8 @@ def profile_packing( tmp_path_to_broadcast = tempfile.TemporaryDirectory().name gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) tmp_path = gathered_paths[local_rank_zero] - stream_config['local'] = tmp_path + if stream_config.get('local') is not None: + stream_config['local'] = tmp_path # Determine the packing_ratio values we'll try packing_ratios, raw_batch_sizes = [], [] From 21abadaf8e9dc21bf0911ee12383f354bc92e043 Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Mon, 2 Dec 2024 17:14:08 -0800 Subject: [PATCH 14/24] logic is hard --- llmfoundry/data/packing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index f452b5a4d0..1b5858fac6 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -485,7 +485,7 @@ def profile_packing( tmp_path_to_broadcast = tempfile.TemporaryDirectory().name gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) tmp_path = gathered_paths[local_rank_zero] - if stream_config.get('local') is not None: + if stream_config.get('local') is None: stream_config['local'] = tmp_path # Determine the packing_ratio values we'll try From 819c11264a7e932cd38d701f7e246960613b628c Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Mon, 2 Dec 2024 17:59:58 -0800 Subject: [PATCH 15/24] more testing --- llmfoundry/data/packing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index 1b5858fac6..d5bdd5735f 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -474,7 +474,8 @@ def profile_packing( # If streaming dataset, use a temporary local folder for profiling local_rank_zero = dist.get_global_rank() - dist.get_local_rank() - if dataset_cfg.get('remote') is not None: + if dataset_cfg.get('remote' + ) is not None and dataset_cfg.get('local') is None: tmp_path_to_broadcast = tempfile.TemporaryDirectory().name gathered_paths = dist.all_gather_object(tmp_path_to_broadcast) tmp_path = gathered_paths[local_rank_zero] From 19bf0a4997dd640d29b941427605ace44f8f1a8f Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Thu, 5 Dec 2024 15:36:41 -0800 Subject: [PATCH 16/24] remove debug msg --- llmfoundry/data/finetuning/dataloader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 22f9d1bd2b..fce694f160 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -180,7 +180,6 @@ def build_finetuning_dataloader( given a starting workload YAML. 
""" dataset_cfg = dataset - print("🦖🦖🦖🦖 build finetuning dataloader called with dataset_cfg: ", dataset_cfg) is_streaming = ( dataset_cfg.get('remote') is not None or dataset_cfg.get('streams') is not None From b5bf28cf0dd9ef87e1d677ea0c080e048a35b06e Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Thu, 5 Dec 2024 16:54:11 -0800 Subject: [PATCH 17/24] assume single turn input --- .../data_prep/convert_delta_to_mds.py | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py index 3da51faa52..9db204e71a 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py @@ -43,26 +43,19 @@ def get_conversion_config( ) if 'turns' in columns[0]: - logging.info('Identified IFT data') + logging.info('Identified IFT/CHAT data') dtypes = { 'input_ids': 'ndarray', 'attention_mask': 'ndarray', 'labels': 'ndarray', } - convert_x = lambda x: { - # join the turns into a single array - 'input_ids': - np.concatenate([ - np.array(turn['input_ids']) for turn in x['turns'] - ]), - 'attention_mask': - np.concatenate([ - np.array(turn['attention_mask']) for turn in x['turns'] - ]), - 'labels': - np. - concatenate([np.array(turn['labels']) for turn in x['turns']]), - } + convert_x = lambda x: ( + ValueError('More than one turn found') if len(x['turns']) > 1 else { + 'input_ids': np.array(x['turns'][0]['input_ids']), + 'attention_mask': np.array(x['turns'][0]['attention_mask']), + 'labels': np.array(x['turns'][0]['labels']), + } + ) elif 'concat_tokens' in columns[0]: logging.info('Identified CPT data') dtypes = { From 9372f488507d55991493bac52f324192a82df281 Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Fri, 6 Dec 2024 18:20:44 -0800 Subject: [PATCH 18/24] reuse convert_ft_dataset fn --- .../data_prep/convert_delta_to_mds.py | 98 +++++-------------- 1 file changed, 23 insertions(+), 75 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py index 9db204e71a..3acdecfec0 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py @@ -1,14 +1,10 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -import json import logging import os import tempfile -from typing import Callable, Optional - -import numpy as np -from streaming import MDSWriter +from typing import Optional from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( _check_imports, @@ -17,59 +13,12 @@ get_columns_info, validate_and_get_cluster_info, ) +from llmfoundry.command_utils.data_prep.convert_finetuning_dataset import \ + convert_finetuning_dataset_from_args logger = logging.getLogger(__name__) -def get_conversion_config( - columns: list[str], - provided_dtypes: Optional[dict], -) -> tuple[dict, Callable]: - """If no dtypes is provided, attempts to infer config based on column names. - - Args: - columns (List[str]): The list of column names. - provided_dtypes (Optional[Dict]): The provided dtypes. 
- """ - if provided_dtypes is not None: - convert_x = lambda x: { - k: np.array(v, dtype=provided_dtypes.get(k)) for k, v in x.items() - } - return provided_dtypes, convert_x - - if len(columns) != 1: - raise ValueError( - 'Unable to infer dtypes from columns and no dtypes provided.', - ) - - if 'turns' in columns[0]: - logging.info('Identified IFT/CHAT data') - dtypes = { - 'input_ids': 'ndarray', - 'attention_mask': 'ndarray', - 'labels': 'ndarray', - } - convert_x = lambda x: ( - ValueError('More than one turn found') if len(x['turns']) > 1 else { - 'input_ids': np.array(x['turns'][0]['input_ids']), - 'attention_mask': np.array(x['turns'][0]['attention_mask']), - 'labels': np.array(x['turns'][0]['labels']), - } - ) - elif 'concat_tokens' in columns[0]: - logging.info('Identified CPT data') - dtypes = { - 'tokens': 'ndarray', - } - convert_x = lambda x: {'tokens': np.array(x['concat_tokens'])} - else: - raise ValueError( - 'Unable to infer dtypes from columns and no dtypes provided.', - ) - - return dtypes, convert_x - - def convert_delta_to_mds_from_args( delta_table_name: str, mds_output_folder: str, @@ -115,12 +64,6 @@ def convert_delta_to_mds_from_args( ) logger.info(f'Columns: {columns}') - dtypes, convert_x = get_conversion_config(columns, dtypes) - - compression = 'zstd:7' - hashes = ['sha1'] - limit = '10mb' - logging.info(f'Fetching data from Delta Table {delta_table_name}...') with tempfile.TemporaryDirectory() as json_out_folder: @@ -142,19 +85,24 @@ def convert_delta_to_mds_from_args( except Exception as e: logger.error(f'Error fetching data from Delta Table: {e}') raise e - with MDSWriter( - out=mds_output_folder, - columns=dtypes, - compression=compression, - hashes=hashes, - size_limit=limit, - ) as out: - try: - with open(json_full_filepath, 'r') as f: - for line in f: - out.write(convert_x(json.loads(line))) - except FileNotFoundError as e: - logger.error(f'JSON output file not found: {e}') - raise e - logging.info(f'Wrote to MDS at {mds_output_folder}') + convert_finetuning_dataset_from_args( + dataset='json', + data_subset=None, + splits=[''], + preprocessor=None, + data_files=[json_full_filepath], + skip_preprocessing=True, + out_root=mds_output_folder, + local=None, + compression='zstd:7', + num_workers=processes, + tokenizer=None, + tokenizer_kwargs=None, + max_seq_len=-1, + target_prompts='', + target_responses='', + encoder_decoder=False, + ) + + logging.info(f'Wrote to MDS at {mds_output_folder}') From 408b96feec831723c0b4f5787e5e16e0a673f43a Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Fri, 6 Dec 2024 18:33:07 -0800 Subject: [PATCH 19/24] update for ft --- llmfoundry/command_utils/data_prep/convert_delta_to_mds.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py index 3acdecfec0..5088455ba8 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py @@ -100,8 +100,8 @@ def convert_delta_to_mds_from_args( tokenizer=None, tokenizer_kwargs=None, max_seq_len=-1, - target_prompts='', - target_responses='', + target_prompts='none', + target_responses='all', encoder_decoder=False, ) From f47cfab3995ed48ed5e43f0048302eed2f106495 Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Fri, 6 Dec 2024 18:42:04 -0800 Subject: [PATCH 20/24] fix split --- llmfoundry/command_utils/data_prep/convert_delta_to_mds.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py index 5088455ba8..bdc8681d9e 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py @@ -89,7 +89,7 @@ def convert_delta_to_mds_from_args( convert_finetuning_dataset_from_args( dataset='json', data_subset=None, - splits=[''], + splits=['train'], preprocessor=None, data_files=[json_full_filepath], skip_preprocessing=True, From 46fc2d0ba244fcde42a80bcbff0bb8017b64e427 Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Fri, 6 Dec 2024 18:50:43 -0800 Subject: [PATCH 21/24] revert a few commits to not break --- .../data_prep/convert_delta_to_mds.py | 98 ++++++++++++++----- 1 file changed, 75 insertions(+), 23 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py index bdc8681d9e..41f205056f 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py @@ -1,10 +1,14 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import json import logging import os import tempfile -from typing import Optional +from typing import Callable, Optional + +import numpy as np +from streaming import MDSWriter from llmfoundry.command_utils.data_prep.convert_delta_to_json import ( _check_imports, @@ -13,12 +17,59 @@ get_columns_info, validate_and_get_cluster_info, ) -from llmfoundry.command_utils.data_prep.convert_finetuning_dataset import \ - convert_finetuning_dataset_from_args logger = logging.getLogger(__name__) +def get_conversion_config( + columns: list[str], + provided_dtypes: Optional[dict], +) -> tuple[dict, Callable]: + """If no dtypes is provided, attempts to infer config based on column names. + + Args: + columns (List[str]): The list of column names. + provided_dtypes (Optional[Dict]): The provided dtypes. 
+ """ + if provided_dtypes is not None: + convert_x = lambda x: { + k: np.array(v, dtype=provided_dtypes.get(k)) for k, v in x.items() + } + return provided_dtypes, convert_x + + if len(columns) != 1: + raise ValueError( + 'Unable to infer dtypes from columns and no dtypes provided.', + ) + + if 'turns' in columns[0]: + logging.info('Identified IFT/CHAT data') + dtypes = { + 'input_ids': 'ndarray', + 'attention_mask': 'ndarray', + 'labels': 'ndarray', + } + convert_x = lambda x: ( + ValueError('More than one turn found') if len(x['turns']) > 1 else { + 'input_ids': np.array(x['turns'][0]['input_ids']), + 'attention_mask': np.array(x['turns'][0]['attention_mask']), + 'labels': np.array(x['turns'][0]['labels']), + } + ) + elif 'concat_tokens' in columns[0]: + logging.info('Identified CPT data') + dtypes = { + 'tokens': 'ndarray', + } + convert_x = lambda x: {'tokens': np.array(x['concat_tokens'])} + else: + raise ValueError( + 'Unable to infer dtypes from columns and no dtypes provided.', + ) + + return dtypes, convert_x + + def convert_delta_to_mds_from_args( delta_table_name: str, mds_output_folder: str, @@ -64,6 +115,12 @@ def convert_delta_to_mds_from_args( ) logger.info(f'Columns: {columns}') + dtypes, convert_x = get_conversion_config(columns, dtypes) + + compression = 'zstd:7' + hashes = ['sha1'] + limit = '10mb' + logging.info(f'Fetching data from Delta Table {delta_table_name}...') with tempfile.TemporaryDirectory() as json_out_folder: @@ -85,24 +142,19 @@ def convert_delta_to_mds_from_args( except Exception as e: logger.error(f'Error fetching data from Delta Table: {e}') raise e + with MDSWriter( + out=mds_output_folder, + columns=dtypes, + compression=compression, + hashes=hashes, + size_limit=limit, + ) as out: + try: + with open(json_full_filepath, 'r') as f: + for line in f: + out.write(convert_x(json.loads(line))) + except FileNotFoundError as e: + logger.error(f'JSON output file not found: {e}') + raise e - convert_finetuning_dataset_from_args( - dataset='json', - data_subset=None, - splits=['train'], - preprocessor=None, - data_files=[json_full_filepath], - skip_preprocessing=True, - out_root=mds_output_folder, - local=None, - compression='zstd:7', - num_workers=processes, - tokenizer=None, - tokenizer_kwargs=None, - max_seq_len=-1, - target_prompts='none', - target_responses='all', - encoder_decoder=False, - ) - - logging.info(f'Wrote to MDS at {mds_output_folder}') + logging.info(f'Wrote to MDS at {mds_output_folder}') \ No newline at end of file From bb757c7eb555bea03397ef1c17b729eea4d2195d Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Wed, 18 Dec 2024 17:48:07 -0800 Subject: [PATCH 22/24] rename file to train.jsonl --- llmfoundry/command_utils/data_prep/convert_delta_to_mds.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py index 41f205056f..923b5ae95a 100644 --- a/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py +++ b/llmfoundry/command_utils/data_prep/convert_delta_to_mds.py @@ -124,7 +124,7 @@ def convert_delta_to_mds_from_args( logging.info(f'Fetching data from Delta Table {delta_table_name}...') with tempfile.TemporaryDirectory() as json_out_folder: - json_out_filename = 'temp.jsonl' + json_out_filename = 'train.jsonl' json_full_filepath = os.path.join(json_out_folder, json_out_filename) try: fetch_DT( @@ -157,4 +157,4 @@ def convert_delta_to_mds_from_args( logger.error(f'JSON output file not found: {e}') raise e 
- logging.info(f'Wrote to MDS at {mds_output_folder}') \ No newline at end of file + logging.info(f'Wrote to MDS at {mds_output_folder}') From 6c3e0a7b7e4b2a1b6488705f874857197b8244f8 Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Sun, 22 Dec 2024 22:22:03 -0800 Subject: [PATCH 23/24] add debugging statement --- llmfoundry/data/finetuning/dataloader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index fce694f160..74e2ff1185 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -179,6 +179,8 @@ def build_finetuning_dataloader( padding/waste rates for different `cfg.dataset.packing_ratio` choices, given a starting workload YAML. """ + print(f"🚨🚨🚨 build_finetuning_dataloader was called with the following arguments: tokenizer={tokenizer}, device_batch_size={device_batch_size}, dataset={dataset}, num_workers={num_workers}, drop_last={drop_last}, pin_memory={pin_memory}, prefetch_factor={prefetch_factor}, persistent_workers={persistent_workers}, timeout={timeout}") + dataset_cfg = dataset is_streaming = ( dataset_cfg.get('remote') is not None or From 2712e2b0e5008d16f03f8263640cf59070e009d6 Mon Sep 17 00:00:00 2001 From: Matthew Ding Date: Fri, 27 Dec 2024 12:42:14 -0800 Subject: [PATCH 24/24] change debugging statement --- llmfoundry/data/finetuning/dataloader.py | 2 -- llmfoundry/utils/config_utils.py | 5 +++++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 74e2ff1185..fce694f160 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -179,8 +179,6 @@ def build_finetuning_dataloader( padding/waste rates for different `cfg.dataset.packing_ratio` choices, given a starting workload YAML. """ - print(f"🚨🚨🚨 build_finetuning_dataloader was called with the following arguments: tokenizer={tokenizer}, device_batch_size={device_batch_size}, dataset={dataset}, num_workers={num_workers}, drop_last={drop_last}, pin_memory={pin_memory}, prefetch_factor={prefetch_factor}, persistent_workers={persistent_workers}, timeout={timeout}") - dataset_cfg = dataset is_streaming = ( dataset_cfg.get('remote') is not None or diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 997273de7f..e0acef93c1 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -429,6 +429,7 @@ def calculate_batch_size_info( int, Literal['auto']]]: world_size = dist.get_world_size() + print(f"🚨🚨🚨 world_size: {world_size}") if world_size % data_replication_degree != 0: raise ValueError( f'World size {world_size} is not divisible by data replication degree {data_replication_degree}.', @@ -457,6 +458,10 @@ def calculate_batch_size_info( ) else: raise ValueError(f'Not sure how to parse {device_microbatch_size=}') + + print(f"🚨🚨🚨 device_batch_size: {device_batch_size}") + print(f"🚨🚨🚨 device_microbatch_size: {device_microbatch_size}") + print(f"🚨🚨🚨 device_grad_accum: {device_grad_accum}") return device_batch_size, device_microbatch_size, device_grad_accum