Skip to content

Commit

Permalink
Merge pull request #194 from tiran/module-logger
Browse files: browse the repository at this point in the history
 Remove calls to logging.basicConfig on import
  • Loading branch information
markmc authored Jul 27, 2024
2 parents 3362e18 + 81a69bb commit d9cc9b7
Show file tree
Hide file tree
Showing 10 changed files with 32 additions and 50 deletions.
6 changes: 2 additions & 4 deletions src/instructlab/sdg/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
# Standard
from abc import ABC
from typing import Any, Dict, Union
import logging
import os.path

# Third Party
import yaml

# Local
from .logger_config import setup_logger

logger = setup_logger(__name__)
logger = logging.getLogger(__name__)


# This is part of the public API.
Expand Down
4 changes: 2 additions & 2 deletions src/instructlab/sdg/datamixing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Standard
from typing import Optional
import json
import logging
import os.path
import random
import uuid
Expand All @@ -10,11 +11,10 @@
import yaml

# First Party
from instructlab.sdg.logger_config import setup_logger
from instructlab.sdg.utils import GenerateException, pandas

ALLOWED_COLS = ["id", "messages", "metadata"]
logger = setup_logger(__name__)
logger = logging.getLogger(__name__)


def _adjust_train_sample_size(ds: Dataset, num_samples: int):
Expand Down
6 changes: 2 additions & 4 deletions src/instructlab/sdg/eval_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Standard
from importlib import resources
from typing import Any
import logging
import re

# Third Party
Expand All @@ -10,10 +11,7 @@
# First Party
from instructlab.sdg.pipeline import EVAL_PIPELINES_PKG, Pipeline

# Local
from .logger_config import setup_logger

logger = setup_logger(__name__)
logger = logging.getLogger(__name__)


def _extract_options(text: str) -> list[Any]:
Expand Down
4 changes: 2 additions & 2 deletions src/instructlab/sdg/filterblock.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
import logging
import operator

# Third Party
from datasets import Dataset

# Local
from .block import Block
from .logger_config import setup_logger

logger = setup_logger(__name__)
logger = logging.getLogger(__name__)


# This is part of the public API.
Expand Down
24 changes: 12 additions & 12 deletions src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional
import dataclasses
import json
import logging
import os
import time

Expand Down Expand Up @@ -36,6 +37,8 @@
read_taxonomy_leaf_nodes,
)

logger = logging.getLogger(__name__)

_SYS_PROMPT = "I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant."


Expand Down Expand Up @@ -74,9 +77,7 @@ def _convert_to_messages(sample):
return sample


def _gen_train_data(
logger, machine_instruction_data, output_file_train, output_file_messages
):
def _gen_train_data(machine_instruction_data, output_file_train, output_file_messages):
"""
Generate training data in the legacy system/user/assistant format
used in train_*.jsonl as well as the legacy messages format used
Expand Down Expand Up @@ -262,9 +263,9 @@ def _mixer_init(ctx, output_dir, date_suffix):

# This is part of the public API, and used by instructlab.
# TODO - parameter removal needs to be done in sync with a CLI change.
# pylint: disable=unused-argument
# to be removed: logger, prompt_file_path, rouge_threshold, tls_*
def generate_data(
logger,
logger: logging.Logger = logger, # pylint: disable=redefined-outer-name
api_base: Optional[str] = None,
api_key: Optional[str] = None,
model_family: Optional[str] = None,
Expand All @@ -275,9 +276,9 @@ def generate_data(
taxonomy_base: Optional[str] = None,
output_dir: Optional[str] = None,
# TODO - not used and should be removed from the CLI
prompt_file_path: Optional[str] = None,
prompt_file_path: Optional[str] = None, # pylint: disable=unused-argument
# TODO - probably should be removed
rouge_threshold: Optional[float] = None,
rouge_threshold: Optional[float] = None, # pylint: disable=unused-argument
console_output=True,
yaml_rules: Optional[str] = None,
chunk_word_count=None,
Expand Down Expand Up @@ -392,9 +393,9 @@ def generate_data(
else:
sdg = sdg_freeform_skill

logger.debug("Samples: %s" % samples)
logger.debug("Samples: %s", samples)
ds = Dataset.from_list(samples)
logger.debug("Dataset: %s" % ds)
logger.debug("Dataset: %s", ds)
new_generated_data = sdg.generate(ds)
if len(new_generated_data) == 0:
raise EmptyDatasetError(
Expand All @@ -405,8 +406,8 @@ def generate_data(
if generated_data is None
else generated_data + [new_generated_data]
)
logger.info("Generated %d samples" % len(generated_data))
logger.debug("Generated data: %s" % generated_data)
logger.info("Generated %d samples", len(generated_data))
logger.debug("Generated data: %s", generated_data)

if is_knowledge:
# generate mmlubench data for the current leaf node
Expand All @@ -424,7 +425,6 @@ def generate_data(
generated_data = []

_gen_train_data(
logger,
generated_data,
os.path.join(output_dir, output_file_train),
os.path.join(output_dir, output_file_messages),
Expand Down
6 changes: 4 additions & 2 deletions src/instructlab/sdg/importblock.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
import logging

# Third Party
from datasets import Dataset

# Local
from .block import Block
from .logger_config import setup_logger

logger = setup_logger(__name__)
logger = logging.getLogger(__name__)


# This is part of the public API.
Expand Down
4 changes: 2 additions & 2 deletions src/instructlab/sdg/llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Standard
from collections import ChainMap
from typing import Any, Dict
import logging
import re

# Third Party
Expand All @@ -10,9 +11,8 @@

# Local
from .block import Block
from .logger_config import setup_logger

logger = setup_logger(__name__)
logger = logging.getLogger(__name__)

MODEL_FAMILY_MIXTRAL = "mixtral"
MODEL_FAMILY_MERLINITE = "merlinite"
Expand Down
18 changes: 0 additions & 18 deletions src/instructlab/sdg/logger_config.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/instructlab/sdg/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from importlib import resources
from typing import Iterable, Optional
import logging
import math
import os.path

Expand All @@ -19,9 +20,8 @@
# Local
from . import filterblock, importblock, llmblock, utilblocks
from .block import Block
from .logger_config import setup_logger

logger = setup_logger(__name__)
logger = logging.getLogger(__name__)


# This is part of the public API.
Expand Down
6 changes: 4 additions & 2 deletions src/instructlab/sdg/utilblocks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
import logging

# Third Party
from datasets import Dataset

Expand All @@ -7,9 +10,8 @@

# Local
from .block import Block
from .logger_config import setup_logger

logger = setup_logger(__name__)
logger = logging.getLogger(__name__)


# This is part of the public API.
Expand Down

0 comments on commit d9cc9b7

Please sign in to comment.