Skip to content

Commit

Permalink
refactor: remove unused generate_data argument
Browse files Browse the repository at this point in the history
The logger argument is already removed from the cli.

Signed-off-by: Costa Shulyupin <[email protected]>
  • Loading branch information
makelinux committed Dec 12, 2024
1 parent 31ff971 commit ae516c3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 34 deletions.
3 changes: 0 additions & 3 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,8 @@ def _mixer_init(


# This is part of the public API, and used by instructlab.
# TODO - parameter removal needs to be done in sync with a CLI change.
# to be removed: logger
def generate_data(
client: openai.OpenAI,
logger: logging.Logger = logger, # pylint: disable=redefined-outer-name
system_prompt: Optional[str] = None,
use_legacy_pretraining_format: Optional[bool] = True,
model_family: Optional[str] = None,
Expand Down
62 changes: 31 additions & 31 deletions tests/test_generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@

# First Party
from instructlab.sdg import LLMBlock, PipelineContext
from instructlab.sdg.generate_data import _context_init, _sdg_init, generate_data
from instructlab.sdg.generate_data import (
_context_init,
_sdg_init,
generate_data,
logger,
)

TEST_SYS_PROMPT = "I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant."

Expand Down Expand Up @@ -308,19 +313,17 @@ def setUp(self):
)

def test_generate(self):
with patch("logging.Logger.info") as mocked_logger:
generate_data(
client=MagicMock(),
logger=mocked_logger,
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
taxonomy=self.test_taxonomy.root,
taxonomy_base=TEST_TAXONOMY_BASE,
output_dir=self.tmp_path,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)
generate_data(
client=MagicMock(),
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
taxonomy=self.test_taxonomy.root,
taxonomy_base=TEST_TAXONOMY_BASE,
output_dir=self.tmp_path,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)

for name in ["test_*.jsonl", "train_*.jsonl", "messages_*.jsonl"]:
matches = glob.glob(os.path.join(self.tmp_path, name))
Expand Down Expand Up @@ -391,21 +394,19 @@ def setUp(self):
self.expected_train_samples = generate_train_samples(test_valid_knowledge_skill)

def test_generate(self):
with patch("logging.Logger.info") as mocked_logger:
generate_data(
client=MagicMock(),
logger=mocked_logger,
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
taxonomy=self.test_taxonomy.root,
taxonomy_base=TEST_TAXONOMY_BASE,
output_dir=self.tmp_path,
chunk_word_count=1000,
server_ctx_size=4096,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)
generate_data(
client=MagicMock(),
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
taxonomy=self.test_taxonomy.root,
taxonomy_base=TEST_TAXONOMY_BASE,
output_dir=self.tmp_path,
chunk_word_count=1000,
server_ctx_size=4096,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)

for name in ["test_*.jsonl", "train_*.jsonl", "messages_*.jsonl"]:
matches = glob.glob(os.path.join(self.tmp_path, name))
Expand Down Expand Up @@ -495,10 +496,9 @@ def setUp(self):
)

def test_generate(self):
with patch("logging.Logger.info") as mocked_logger:
with patch("instructlab.sdg.generate_data.logger") as mocked_logger:
generate_data(
client=MagicMock(),
logger=mocked_logger,
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
Expand Down

0 comments on commit ae516c3

Please sign in to comment.