diff --git a/src/instructlab/sdg/generate_data.py b/src/instructlab/sdg/generate_data.py
index 533db868..11ff5807 100644
--- a/src/instructlab/sdg/generate_data.py
+++ b/src/instructlab/sdg/generate_data.py
@@ -281,11 +281,8 @@ def _mixer_init(
 
 
 # This is part of the public API, and used by instructlab.
-# TODO - parameter removal needs to be done in sync with a CLI change.
-# to be removed: logger
 def generate_data(
     client: openai.OpenAI,
-    logger: logging.Logger = logger,  # pylint: disable=redefined-outer-name
     system_prompt: Optional[str] = None,
     use_legacy_pretraining_format: Optional[bool] = True,
     model_family: Optional[str] = None,
diff --git a/tests/test_generate_data.py b/tests/test_generate_data.py
index c5415636..b57ce01a 100644
--- a/tests/test_generate_data.py
+++ b/tests/test_generate_data.py
@@ -21,7 +21,12 @@
 
 # First Party
 from instructlab.sdg import LLMBlock, PipelineContext
-from instructlab.sdg.generate_data import _context_init, _sdg_init, generate_data
+from instructlab.sdg.generate_data import (
+    _context_init,
+    _sdg_init,
+    generate_data,
+    logger,
+)
 
 TEST_SYS_PROMPT = "I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant."
 
@@ -308,19 +313,17 @@ def setUp(self):
         )
 
     def test_generate(self):
-        with patch("logging.Logger.info") as mocked_logger:
-            generate_data(
-                client=MagicMock(),
-                logger=mocked_logger,
-                model_family="merlinite",
-                model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
-                num_instructions_to_generate=10,
-                taxonomy=self.test_taxonomy.root,
-                taxonomy_base=TEST_TAXONOMY_BASE,
-                output_dir=self.tmp_path,
-                pipeline="simple",
-                system_prompt=TEST_SYS_PROMPT,
-            )
+        generate_data(
+            client=MagicMock(),
+            model_family="merlinite",
+            model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
+            num_instructions_to_generate=10,
+            taxonomy=self.test_taxonomy.root,
+            taxonomy_base=TEST_TAXONOMY_BASE,
+            output_dir=self.tmp_path,
+            pipeline="simple",
+            system_prompt=TEST_SYS_PROMPT,
+        )
 
         for name in ["test_*.jsonl", "train_*.jsonl", "messages_*.jsonl"]:
             matches = glob.glob(os.path.join(self.tmp_path, name))
@@ -391,21 +394,19 @@ def setUp(self):
         self.expected_train_samples = generate_train_samples(test_valid_knowledge_skill)
 
     def test_generate(self):
-        with patch("logging.Logger.info") as mocked_logger:
-            generate_data(
-                client=MagicMock(),
-                logger=mocked_logger,
-                model_family="merlinite",
-                model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
-                num_instructions_to_generate=10,
-                taxonomy=self.test_taxonomy.root,
-                taxonomy_base=TEST_TAXONOMY_BASE,
-                output_dir=self.tmp_path,
-                chunk_word_count=1000,
-                server_ctx_size=4096,
-                pipeline="simple",
-                system_prompt=TEST_SYS_PROMPT,
-            )
+        generate_data(
+            client=MagicMock(),
+            model_family="merlinite",
+            model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
+            num_instructions_to_generate=10,
+            taxonomy=self.test_taxonomy.root,
+            taxonomy_base=TEST_TAXONOMY_BASE,
+            output_dir=self.tmp_path,
+            chunk_word_count=1000,
+            server_ctx_size=4096,
+            pipeline="simple",
+            system_prompt=TEST_SYS_PROMPT,
+        )
 
         for name in ["test_*.jsonl", "train_*.jsonl", "messages_*.jsonl"]:
             matches = glob.glob(os.path.join(self.tmp_path, name))
@@ -495,10 +496,9 @@ def setUp(self):
         )
 
     def test_generate(self):
-        with patch("logging.Logger.info") as mocked_logger:
+        with patch("instructlab.sdg.generate_data.logger") as mocked_logger:
             generate_data(
                 client=MagicMock(),
-                logger=mocked_logger,
                 model_family="merlinite",
                 model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
                 num_instructions_to_generate=10,