Add ability to log model instead of saving to folder #1683

Open · wants to merge 18 commits into main

121 changes: 82 additions & 39 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -137,6 +137,7 @@ def _log_model_with_multi_process(
registered_model_name: Optional[str],
await_registration_for: int,
mlflow_logging_config: dict[str, Any],
+ register: bool,
):
"""Call MLFlowLogger.log_model.

@@ -206,7 +207,7 @@ def save_model_patch(*args: Any, **kwargs: Any):
transformers_model=transformers_model,
flavor='transformers',
artifact_path=artifact_path,
- registered_model_name=register_model_path,
+ registered_model_name=register_model_path if register else None,
run_id=mlflow_logger._run_id,
await_registration_for=await_registration_for,
**mlflow_logging_config,
@@ -246,13 +247,14 @@ class HuggingFaceCheckpointer(Callback):
"""Save a huggingface formatted checkpoint during training.

Args:
- save_folder (str): Top level folder to save checkpoints to (can be a
-     URI). It is likely that this would be the same as your save_folder.
save_interval: Union[str, int, Time]: The interval describing how often
checkpoints should be saved. If an integer, it will be assumed to be
in :attr:`.TimeUnit.EPOCH`. Otherwise, the unit must be either
:attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
+ save_folder (Optional[str]): Top level folder to save checkpoints to (can be a
+     URI). It is likely that this would be the same as your save_folder. If
+     set to None, the model will be logged to MLFlow.
huggingface_folder_name (str): Folder to save each checkpoint under (can
be a format string). Default is ``ba{batch}``.
precision: The precision to save the model in. Default is ``float32``.
@@ -278,8 +280,8 @@ class HuggingFaceCheckpointer(Callback):

def __init__(
self,
- save_folder: str,
save_interval: Union[str, int, Time],
+ save_folder: Optional[str] = None,
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
overwrite: bool = True,
@@ -289,7 +291,11 @@ def __init__(
final_register_only: bool = False,
register_wait_seconds: int = 7200,
):
- _, _, self.save_dir_format_str = parse_uri(save_folder)
+ if save_folder is None and mlflow_registered_model_name is None:
+     raise ValueError(
+         'No `save_folder` or `mlflow_registered_model_name` set. Please set at least one of them.',
+     )

self.overwrite = overwrite
self.precision = precision
self.dtype = {
@@ -365,12 +371,20 @@ def __init__(
self.save_interval,
include_end_of_training=True,
)
- self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
-     save_folder,
-     loggers=[],
- )
- if self.remote_ud is not None:
-     self.remote_ud._num_concurrent_uploads = 4

+ self.save_folder = save_folder
+
+ self.remote_ud = None
+ if self.save_folder is not None:
+     _, _, self.save_dir_format_str = parse_uri(self.save_folder)
+     self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
+         self.save_folder,
+         loggers=[],
+     )
+     if self.remote_ud is not None:
+         self.remote_ud._num_concurrent_uploads = 4
+ else:
+     self.save_dir_format_str = tempfile.mkdtemp()

self.last_checkpoint_batch: Optional[Time] = None
self.mlflow_loggers = []
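
With these changes, a `save_folder` is no longer required as long as an MLflow registered model name is given. A minimal usage sketch of the two supported modes (the URI and model name below are hypothetical):

```python
from llmfoundry.callbacks import HuggingFaceCheckpointer

# Mode 1, as before: write HF checkpoints to a (possibly remote) folder.
folder_ckpt = HuggingFaceCheckpointer(
    save_interval='1000ba',
    save_folder='s3://my-bucket/hf-checkpoints',  # hypothetical URI
)

# Mode 2, new: no save_folder; the final model is staged in a temporary
# directory and logged to MLflow instead.
mlflow_ckpt = HuggingFaceCheckpointer(
    save_interval='1000ba',
    mlflow_registered_model_name='my-model',  # hypothetical name
)

# Passing neither save_folder nor mlflow_registered_model_name raises a
# ValueError at construction time.
```
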
@@ -389,12 +403,16 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
self._save_checkpoint(
state,
logger,
- register_to_mlflow=(
-     self.mlflow_registered_model_name is not None and
-     is_last_batch
+ log_to_mlflow=(
+     self.save_folder is None or (
+         self.mlflow_registered_model_name is not None and
+         is_last_batch
+     )
+ ),
- upload_to_save_folder=not self.final_register_only or
-     not is_last_batch,
+ upload_to_save_folder=self.save_folder is not None and
+     (not self.final_register_only or not is_last_batch),
+ register=self.mlflow_registered_model_name is not None and
+     is_last_batch,  # Register only on the last batch
)
elif event == Event.INIT:
if not isinstance(state.model, HuggingFaceModel):
@@ -457,15 +475,24 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
if self._any_register_processes_error(
state.device,
) and self.final_register_only:
- log.error(
-     'An error occurred in one or more registration processes. Fallback to saving the HuggingFace checkpoint.',
- )
- self._save_checkpoint(
-     state,
-     logger,
-     upload_to_save_folder=True,
-     register_to_mlflow=False,
- )
+ if self.save_folder:
+     log.error(
+         'An error occurred in one or more registration processes. Fallback to saving the HuggingFace checkpoint.',
+     )
+     self._save_checkpoint(
+         state,
+         logger,
+         upload_to_save_folder=True,
+         log_to_mlflow=False,
+         register=False,
+     )
+ else:
+     # Clean up temporary save directory and raise an error.
+     if self.temp_save_dir is not None:
+         shutil.rmtree(self.temp_save_dir)
+     raise Exception(
+         'An error occurred in one or more registration processes.',
+     )

# Clean up temporary save directory; all processes are done with it.
if self.temp_save_dir is not None:
@@ -587,18 +614,28 @@ def _save_checkpoint(
state: State,
logger: Logger,
upload_to_save_folder: bool,
- register_to_mlflow: bool,
+ log_to_mlflow: bool,
+ register: bool,
):
"""Save a HuggingFace formatted checkpoint.

Args:
state (State): The training state.
logger (Logger): The logger.
upload_to_save_folder (bool): Whether to upload the HF checkpoint to the save folder.
- register_to_mlflow (bool): Whether to register the model to MLFlow
+ log_to_mlflow (bool): Whether to log the model to MLFlow
+ register (bool): Whether to register the model when logging to MLFlow
+
+ Raises:
+     ValueError: If `register` is True but `log_to_mlflow` is False.
"""
del logger # unused

+ if log_to_mlflow is False and register is True:
+     raise ValueError(
+         f'Got {log_to_mlflow=} and {register=}. Cannot register the model if it is not logged.',
+     )

self.last_checkpoint_batch = state.timestamp.batch

log.info('Saving HuggingFace formatted checkpoint')
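
The new `Raises` contract is small enough to mirror standalone. A sketch of the same guard (not the callback itself), assuming nothing beyond the stdlib:

```python
def check_save_flags(log_to_mlflow: bool, register: bool) -> None:
    # Same precondition as _save_checkpoint: registering implies logging.
    if log_to_mlflow is False and register is True:
        raise ValueError(
            f'Got {log_to_mlflow=} and {register=}. '
            'Cannot register the model if it is not logged.',
        )

check_save_flags(log_to_mlflow=True, register=True)    # ok: log and register
check_save_flags(log_to_mlflow=False, register=False)  # ok: folder-only save
# check_save_flags(log_to_mlflow=False, register=True) would raise ValueError
```
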
@@ -619,7 +656,7 @@ def _save_checkpoint(

# Use a temporary directory if save_dir is remote.
use_temp_dir = self.remote_ud is not None
- temp_save_dir = tempfile.mkdtemp() if use_temp_dir else save_dir
+ local_save_dir = tempfile.mkdtemp() if use_temp_dir else save_dir

log.debug('Gathering state dict')

@@ -744,26 +781,26 @@ def tensor_hook(
)
with context_manager:
new_model_instance.save_pretrained(
- temp_save_dir,
+ local_save_dir,
max_shard_size='1GB',
)
if original_tokenizer is not None:
assert isinstance(
original_tokenizer,
PreTrainedTokenizerBase,
)
- original_tokenizer.save_pretrained(temp_save_dir)
+ original_tokenizer.save_pretrained(local_save_dir)

# Only need to edit files for MPT because it has custom code
if new_model_instance.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(
- temp_save_dir,
+ local_save_dir,
self.flatten_imports,
)

if self.remote_ud is not None:
- for filename in os.listdir(temp_save_dir):
+ for filename in os.listdir(local_save_dir):
remote_file_name = os.path.join(save_dir, filename)
remote_file_uri = self.remote_ud.remote_backend.get_uri(
remote_file_name,
Expand All @@ -775,15 +812,15 @@ def tensor_hook(
state=state,
remote_file_name=remote_file_name,
file_path=Path(
- os.path.join(temp_save_dir, filename),
+ os.path.join(local_save_dir, filename),
),
overwrite=self.overwrite,
)

dist.barrier()

if dist.get_global_rank() == 0:
- if register_to_mlflow:
+ if log_to_mlflow:
assert new_model_instance is not None
new_model_instance = self.transform_model_pre_registration(
new_model_instance,
Expand All @@ -795,11 +832,11 @@ def tensor_hook(
state,
new_model_instance,
original_tokenizer,
- temp_save_dir,
+ local_save_dir,
)
else:
register_save_dir = os.path.join(
- temp_save_dir,
+ local_save_dir,
'register_save',
)
new_model_instance.save_pretrained(
@@ -829,7 +866,11 @@ def tensor_hook(
'transformers_model':
register_save_dir,
'artifact_path':
- 'final_model_checkpoint',
+ format_name_with_dist_and_time(
+     self.huggingface_folder_name_fstr,
+     state.run_name,
+     state.timestamp,
+ ),
'pretrained_model_name':
self.pretrained_model_name,
'registered_model_name':
self.mlflow_registered_model_name,
'await_registration_for':
3600,
'mlflow_logging_config':
self.mlflow_logging_config,
+ 'register':
+     register,
},
)

@@ -846,11 +889,11 @@ def tensor_hook(

# Save the temporary directory to be cleaned up later.
if use_temp_dir:
- self.temp_save_dir = temp_save_dir
+ self.temp_save_dir = local_save_dir
else:
# Clean up the temporary directory if we don't need to register to mlflow.
if use_temp_dir:
- shutil.rmtree(temp_save_dir)
+ shutil.rmtree(local_save_dir)
dist.barrier()

def _save_and_register_peft_model(
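
For reference, the three flags that `run_event` now threads into `_save_checkpoint` reduce to a small pure function. A sketch of that logic (standalone, not library code):

```python
from typing import Optional


def derive_flags(
    save_folder: Optional[str],
    mlflow_registered_model_name: Optional[str],
    final_register_only: bool,
    is_last_batch: bool,
) -> dict:
    register = mlflow_registered_model_name is not None and is_last_batch
    return {
        # Log when MLflow is the only destination, or when registering
        # the final checkpoint.
        'log_to_mlflow': save_folder is None or register,
        # Upload only if a folder was given; final_register_only skips
        # the last upload.
        'upload_to_save_folder': save_folder is not None and
        (not final_register_only or not is_last_batch),
        # Register only on the last batch.
        'register': register,
    }


# Last batch with both destinations configured: upload, log, and register.
assert derive_flags('s3://bucket/ckpts', 'my-model', False, True) == {
    'log_to_mlflow': True,
    'upload_to_save_folder': True,
    'register': True,
}
```
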
3 changes: 2 additions & 1 deletion llmfoundry/command_utils/train.py
@@ -618,7 +618,8 @@ def train(cfg: DictConfig) -> Trainer:
trainer.state,
trainer.logger,
upload_to_save_folder=True,
- register_to_mlflow=True,
+ log_to_mlflow=True,
+ register=True,
)
return trainer

20 changes: 13 additions & 7 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -450,30 +450,34 @@ def test_final_register_only(
assert checkpointer_callback._save_checkpoint.call_count == 2
assert checkpointer_callback._save_checkpoint.call_args_list[
0].kwargs == {
- 'register_to_mlflow': True,
+ 'log_to_mlflow': True,
'upload_to_save_folder': False,
+ 'register': True,
}
assert checkpointer_callback._save_checkpoint.call_args_list[
1].kwargs == {
- 'register_to_mlflow': False,
+ 'log_to_mlflow': False,
'upload_to_save_folder': True,
+ 'register': False,
}
else:
# No mlflow_registry_error, so we should only register the model
assert checkpointer_callback._save_checkpoint.call_count == 1
assert checkpointer_callback._save_checkpoint.call_args_list[
0].kwargs == {
- 'register_to_mlflow': True,
+ 'log_to_mlflow': True,
'upload_to_save_folder': False,
+ 'register': True,
}
else:
# No mlflow_registered_model_name, so we should only save the checkpoint
assert mlflow_logger_mock.log_model.call_count == 0
assert checkpointer_callback._save_checkpoint.call_count == 1
assert checkpointer_callback._save_checkpoint.call_args_list[
0].kwargs == {
- 'register_to_mlflow': False,
+ 'log_to_mlflow': False,
'upload_to_save_folder': True,
+ 'register': False,
}


@@ -549,7 +553,7 @@ def test_huggingface_conversion_callback_interval(
mlflow_logger_mock.log_model.assert_called_with(
transformers_model=ANY,
flavor='transformers',
- artifact_path='final_model_checkpoint',
+ artifact_path=f'huggingface/ba{trainer.state.timestamp.batch.value}',
registered_model_name='dummy-registered-name',
run_id='mlflow-run-id',
await_registration_for=3600,
@@ -1173,7 +1177,8 @@ def transform_model_pre_registration(self, model: PreTrainedModel):
state=state,
logger=logger,
upload_to_save_folder=True,
- register_to_mlflow=True,
+ log_to_mlflow=True,
+ register=True,
)

checkpointer._save_and_register_peft_model.assert_not_called()
@@ -1776,5 +1781,6 @@ def __init__(self, config: PretrainedConfig):
state=state,
logger=logger,
upload_to_save_folder=False,
- register_to_mlflow=False,
+ log_to_mlflow=False,
+ register=False,
)
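
The expected `artifact_path` in the interval test above is just the checkpointer's folder-name template formatted with the run's timestamp. A rough sketch of that formatting, assuming composer's `format_name_with_dist_and_time` and `Timestamp`, and assuming `huggingface_folder_name_fstr` is the default `'ba{batch}'` joined under `'huggingface/'` (inferred from the test's expected value):

```python
from composer.core import Timestamp
from composer.utils import format_name_with_dist_and_time

# Template assumed from the default huggingface_folder_name.
template = 'huggingface/ba{batch}'

artifact_path = format_name_with_dist_and_time(
    template,
    run_name='my-run',  # hypothetical run name
    timestamp=Timestamp(batch=4),
)
print(artifact_path)  # huggingface/ba4, replacing the fixed 'final_model_checkpoint'
```
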
2 changes: 1 addition & 1 deletion tests/a_scripts/train/test_train.py
@@ -118,7 +118,7 @@ def test_sort_callbacks():
trainer_mock = Mock()
trainer_mock.state.callbacks = [
CheckpointSaver(),
- HuggingFaceCheckpointer('save-folder', '1ba'),
+ HuggingFaceCheckpointer(save_interval='1ba', save_folder='save-folder'),
RunTimeoutCallback(),
]
_sort_callbacks(trainer_mock)
8 changes: 6 additions & 2 deletions tests/callbacks/test_callbacks.py
@@ -7,8 +7,11 @@
import pytest
from composer.core import Callback

- from llmfoundry.callbacks.async_eval_callback import AsyncEval
- from llmfoundry.callbacks.curriculum_learning_callback import CurriculumLearning
+ from llmfoundry.callbacks import (
+     AsyncEval,
+     CurriculumLearning,
+     HuggingFaceCheckpointer,
+ )
from llmfoundry.interfaces.callback_with_config import CallbackWithConfig
from llmfoundry.registry import callbacks, callbacks_with_config
from llmfoundry.utils.builders import build_callback
@@ -20,6 +23,7 @@
skip_callbacks = [
AsyncEval,
CurriculumLearning,
+ HuggingFaceCheckpointer,
]

