Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

refactor: remove unused generate_data argument #411

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,8 @@ def _mixer_init(


# This is part of the public API, and used by instructlab.
# TODO - parameter removal needs to be done in sync with a CLI change.
# to be removed: logger
def generate_data(
client: openai.OpenAI,
logger: logging.Logger = logger, # pylint: disable=redefined-outer-name
system_prompt: Optional[str] = None,
use_legacy_pretraining_format: Optional[bool] = True,
model_family: Optional[str] = None,
Expand Down
62 changes: 31 additions & 31 deletions tests/test_generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@

# First Party
from instructlab.sdg import LLMBlock, PipelineContext
from instructlab.sdg.generate_data import _context_init, _sdg_init, generate_data
from instructlab.sdg.generate_data import (
_context_init,
_sdg_init,
generate_data,
logger,
)

TEST_SYS_PROMPT = "I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant."

Expand Down Expand Up @@ -308,19 +313,17 @@ def setUp(self):
)

def test_generate(self):
with patch("logging.Logger.info") as mocked_logger:
generate_data(
client=MagicMock(),
logger=mocked_logger,
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
taxonomy=self.test_taxonomy.root,
taxonomy_base=TEST_TAXONOMY_BASE,
output_dir=self.tmp_path,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)
generate_data(
client=MagicMock(),
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
taxonomy=self.test_taxonomy.root,
taxonomy_base=TEST_TAXONOMY_BASE,
output_dir=self.tmp_path,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)

for name in ["test_*.jsonl", "train_*.jsonl", "messages_*.jsonl"]:
matches = glob.glob(os.path.join(self.tmp_path, name))
Expand Down Expand Up @@ -391,21 +394,19 @@ def setUp(self):
self.expected_train_samples = generate_train_samples(test_valid_knowledge_skill)

def test_generate(self):
with patch("logging.Logger.info") as mocked_logger:
generate_data(
client=MagicMock(),
logger=mocked_logger,
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
taxonomy=self.test_taxonomy.root,
taxonomy_base=TEST_TAXONOMY_BASE,
output_dir=self.tmp_path,
chunk_word_count=1000,
server_ctx_size=4096,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)
generate_data(
client=MagicMock(),
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
taxonomy=self.test_taxonomy.root,
taxonomy_base=TEST_TAXONOMY_BASE,
output_dir=self.tmp_path,
chunk_word_count=1000,
server_ctx_size=4096,
pipeline="simple",
system_prompt=TEST_SYS_PROMPT,
)

for name in ["test_*.jsonl", "train_*.jsonl", "messages_*.jsonl"]:
matches = glob.glob(os.path.join(self.tmp_path, name))
Expand Down Expand Up @@ -495,10 +496,9 @@ def setUp(self):
)

def test_generate(self):
with patch("logging.Logger.info") as mocked_logger:
with patch("instructlab.sdg.generate_data.logger") as mocked_logger:
generate_data(
client=MagicMock(),
logger=mocked_logger,
model_family="merlinite",
model_name="models/merlinite-7b-lab-Q4_K_M.gguf",
num_instructions_to_generate=10,
Expand Down
Loading