Skip to content

Commit

Permalink
refactor: remove unused generate_data argument
Browse files Browse the repository at this point in the history
The logger argument has already been removed from the CLI.

Signed-off-by: Costa Shulyupin <[email protected]>
  • Loading branch information
makelinux committed Nov 26, 2024
1 parent eb8119c commit 07e5d74
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 34 deletions.
3 changes: 0 additions & 3 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,8 @@ def _mixer_init(


# This is part of the public API, and used by instructlab.
# TODO - parameter removal needs to be done in sync with a CLI change.
# to be removed: logger
def generate_data(
client: openai.OpenAI,
logger: logging.Logger = logger, # pylint: disable=redefined-outer-name
system_prompt: Optional[str] = None,
use_legacy_pretraining_format: Optional[bool] = True,
model_family: Optional[str] = None,
Expand Down
62 changes: 31 additions & 31 deletions tests/test_generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
import yaml

# First Party
from instructlab.sdg.generate_data import _context_init, _sdg_init, generate_data
from instructlab.sdg.generate_data import (
_context_init,
_sdg_init,
generate_data,
logger,
)
from instructlab.sdg.llmblock import LLMBlock
from instructlab.sdg.pipeline import PipelineContext

Expand Down Expand Up @@ -309,19 +314,17 @@ def setUp(self):
)

def test_generate(self):
with patch("logging.Logger.info") as mocked_logger:
generate_data(
client=MagicMock(),
logger=mocked_logger,
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
taxonomy=self.test_taxonomy.root,
taxonomy_base=TEST_TAXONOMY_BASE,
output_dir=self.tmp_path,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)
generate_data(
client=MagicMock(),
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
taxonomy=self.test_taxonomy.root,
taxonomy_base=TEST_TAXONOMY_BASE,
output_dir=self.tmp_path,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)

for name in ["test_*.jsonl", "train_*.jsonl", "messages_*.jsonl"]:
matches = glob.glob(os.path.join(self.tmp_path, name))
Expand Down Expand Up @@ -386,21 +389,19 @@ def setUp(self):
self.expected_train_samples = generate_train_samples(test_valid_knowledge_skill)

def test_generate(self):
with patch("logging.Logger.info") as mocked_logger:
generate_data(
client=MagicMock(),
logger=mocked_logger,
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
taxonomy=self.test_taxonomy.root,
taxonomy_base=TEST_TAXONOMY_BASE,
output_dir=self.tmp_path,
chunk_word_count=1000,
server_ctx_size=4096,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)
generate_data(
client=MagicMock(),
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
taxonomy=self.test_taxonomy.root,
taxonomy_base=TEST_TAXONOMY_BASE,
output_dir=self.tmp_path,
chunk_word_count=1000,
server_ctx_size=4096,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)

for name in ["test_*.jsonl", "train_*.jsonl", "messages_*.jsonl"]:
matches = glob.glob(os.path.join(self.tmp_path, name))
Expand Down Expand Up @@ -484,10 +485,9 @@ def setUp(self):
)

def test_generate(self):
with patch("logging.Logger.info") as mocked_logger:
with patch("instructlab.sdg.generate_data.logger") as mocked_logger:
generate_data(
client=MagicMock(),
logger=mocked_logger,
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
Expand Down

0 comments on commit 07e5d74

Please sign in to comment.