From a7d633251745d063cbd57c62042687ab0e51a0db Mon Sep 17 00:00:00 2001
From: minhtrung23
Date: Tue, 10 Sep 2024 23:35:24 +0700
Subject: [PATCH] minhtrung23:fix-pylint (#36)

* Create ci.yml
* Update __main__.py
* Update conf.py
* Update __main__.py
* Update cli.py
* Update test_execution.py
* Update test_wrapper.py
* Fix convention for .github/workflows/python-package.yml.py
* Fix convention for docs.source.conf.py
* Fix convention for .github/workflows/python-package.yml.py
* Fix convention for github/workflows/python-package.yml.py
* Fix convention for .github/workflows/python-package.yml.py
* Fix convention for .github/workflows/python-package.yml.py
* Fix convention for src/melt/tools/data/dataset.py
* Fix convention for src/melt/tools/data/loader.py
* Fix convention for src/melt/tools/data/__init__.py
* Fix convention for src/melt/tools/data/parser.py
* Delete .github/workflows/.github/workflows/ci.yml
* Fix convention for src/melt/tools/data/dataset.py.py
* Fix convention for src/melt/tools/metrics/data_stats_metric/__init__.py
* Fix convention for src/melt/tools/metrics/data_stats_metric/data_stats_metric.py
* Fix convention for src/melt/tools/metrics/summac/utils_misc.py
* Fix convention for src/melt/tools/metrics/base.py
* Fix convention for src/melt/tools/metrics/basic_metrics.py
* Fix convention for src/melt/tools/metrics/bias.py
* Fix convention for src/melt/tools/metrics/calibration_metric.py
* Fix convention for src/melt/tools/metrics/ir.py
* Fix convention for src/melt/tools/metrics/language.py
* Fix convention for src/melt/tools/metrics/name_detector.py
* Fix convention for src/melt/tools/metrics/name_detector.py
* Fix convention for src/melt/tools/metrics/name_detector.py
* Fix convention for src/melt/tools/metrics/post_process.py
* Fix convention for src/melt/tools/metrics/question_answering.py
* Fix convention for src/melt/tools/metrics/reasoning.py
* Fix convention for docs/source/conf.py
* Fix convention for src/melt/tools/metrics/summac/model_summac.py
* Fix convention for src/melt/tools/metrics/question_answering.py

---------

Co-authored-by: Duc Quang Nguyen
---
 .github/workflows/python-package.yml | 26 +-
 docs/source/conf.py | 77 +-
 src/melt/__main__.py | 90 +-
 src/melt/cli.py | 50 +-
 src/melt/tools/data/__init__.py | 1 +
 src/melt/tools/data/dataset.py | 144 ++-
 src/melt/tools/data/loader.py | 206 ++--
 src/melt/tools/data/parser.py | 205 ++--
 src/melt/tools/metrics/base.py | 48 +-
 src/melt/tools/metrics/basic_metrics.py | 42 +-
 src/melt/tools/metrics/bias.py | 165 +---
 src/melt/tools/metrics/calibration_metric.py | 79 +-
 .../metrics/data_stats_metric/__init__.py | 1 +
 .../data_stats_metric/data_stats_metric.py | 153 ++-
 src/melt/tools/metrics/ir.py | 117 ++-
 src/melt/tools/metrics/language.py | 95 +-
 src/melt/tools/metrics/name_detector.py | 110 +--
 src/melt/tools/metrics/post_process.py | 52 +-
 src/melt/tools/metrics/question_answering.py | 7 +
 src/melt/tools/metrics/reasoning.py | 266 +++--
 src/melt/tools/metrics/summac/model_summac.py | 912 ++++++++++--------
 src/melt/tools/metrics/summac/utils_misc.py | 95 +-
 tests/test_execution.py | 121 ++-
 tests/test_wrapper.py | 108 ++-
 24 files changed, 1838 insertions(+), 1332 deletions(-)

diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index a243cfe..8ebe880 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -1,6 +1,3 @@
-# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
-# For more
information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python - name: Python package on: @@ -11,12 +8,11 @@ on: jobs: build: - runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v4 @@ -29,12 +25,26 @@ jobs: python -m pip install --upgrade pip python -m pip install flake8 pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Set PYTHONPATH + run: | + echo "PYTHONPATH=$PYTHONPATH:$(pwd)/src" >> $GITHUB_ENV + echo "Current PYTHONPATH: $PYTHONPATH" + - name: Debug environment + run: | + echo "Current directory contents:" + ls -R + echo "Python path:" + python -c "import sys; print(sys.path)" + echo "Installed packages:" + pip list - name: Lint with flake8 run: | - # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | - pytest + pytest -v + - name: Check for 'melt' + run: | + which melt || echo "melt not found in PATH" + find . -name melt \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 685e3f5..65c56d6 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,39 +1,31 @@ -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html +""" +Configuration file for the Sphinx documentation builder. -# -- Path setup -------------------------------------------------------------- +This file contains a selection of the most common options. +For a full list, see the documentation: +https://www.sphinx-doc.org/en/master/usage/configuration.html +""" -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -import datetime import os import sys +from datetime import datetime +# Path setup sys.path.insert(0, os.path.abspath("../../src")) -# -- Project information ----------------------------------------------------- - -project = "MELTs" -author = "Thu Nguyen Hoang Anh" -copyright = "{}, {}".format(datetime.datetime.now().year, author) +# Project information +PROJECT = "MELTs" +AUTHOR = "Thu Nguyen Hoang Anh" +COPYRIGHT = f"{datetime.now().year}, {AUTHOR}" # The full version, including alpha/beta/rc tags -release = "0.1" - - -# -- General configuration --------------------------------------------------- +RELEASE = "0.1" -master_doc = "index" +# General configuration +MASTER_DOC = "index" -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ +# Sphinx extension modules as strings, can be built-in or custom +EXTENSIONS = [ "sphinx.ext.duration", "sphinx.ext.autodoc", "sphinx.ext.coverage", @@ -41,31 +33,20 @@ "sphinx.ext.doctest", ] -autodoc_mock_imports = ["pyemd"] - -# Add any paths that contain templates here, relative to this directory. 
-templates_path = ["_templates"] - -# apidoc_module_dir = '../../src/melt/' -# apidoc_output_dir = 'api' -# apidoc_excluded_paths = [] -# apidoc_separate_modules = True +# List of modules to mock during autodoc generation +AUTODOC_MOCK_IMPORTS = ["pyemd"] -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns = [] +# Paths that contain templates +TEMPLATES_PATH = ["_templates"] -autodoc_member_order = "alphabetical" +# List of patterns to ignore when looking for source files +EXCLUDE_PATTERNS = [] -# -- Options for HTML output ------------------------------------------------- +# Sort members alphabetically in the autodoc +AUTODOC_MEMBER_ORDER = "alphabetical" -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = "sphinx_rtd_theme" +# Options for HTML output +HTML_THEME = "sphinx_rtd_theme" -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ["_static"] +# Paths for custom static files (like style sheets) +HTML_STATIC_PATH = ["_static"] diff --git a/src/melt/__main__.py b/src/melt/__main__.py index 3695aeb..58fbd01 100644 --- a/src/melt/__main__.py +++ b/src/melt/__main__.py @@ -1,17 +1,83 @@ +import logging import spacy import nltk -from .cli import main +from spacy.cli import download as spacy_download +from typing import NoReturn -nltk.download('punkt_tab') -try: - spacy.load("en_core_web_sm") -except OSError: - print( - "Downloading the spacy en_core_web_sm model\n" - "(don't worry, this will only happen once)" - ) - from spacy.cli import download +# Configure logging with a descriptive name for the logger +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(message)s", + level=logging.INFO +) +logger = logging.getLogger("nlp_utils") - download("en_core_web_sm") +def download_nltk_resources() -> NoReturn: + """Download the necessary NLTK resources. -main() + Logs success or failure messages. + """ + try: + nltk.download('punkt') + logger.info("Successfully downloaded NLTK 'punkt' resource.") + except Exception as error: + logger.error("Failed to download NLTK resources: %s", error) + raise + + +def load_spacy_model(model_name: str = "en_core_web_sm") -> spacy.language.Language: + """Load and return the spaCy model, downloading it if necessary. + + Logs success or failure messages during the model loading process. + + Args: + model_name (str): The name of the spaCy model to load. + + Returns: + spacy.language.Language: The loaded spaCy model. + """ + try: + model = spacy.load(model_name) + logger.info("Successfully loaded spaCy model: %s", model_name) + except OSError: + logger.warning("spaCy model '%s' not found. Downloading...", model_name) + spacy_download(model_name) + model = spacy.load(model_name) + logger.info("Successfully downloaded and loaded spaCy model: %s", model_name) + except Exception as error: + logger.error("Failed to load spaCy model: %s", error) + raise + return model + + +def execute_cli_main() -> None: + """Execute the 'main' function from the CLI module. + + Logs success or failure messages about the import process and execution.
+ """ + try: + from cli import main as cli_main + logger.info("Successfully imported 'main' from 'cli' module.") + except ImportError as import_error: + logger.error("ImportError: %s", import_error) + try: + import cli + cli_main = cli.main + logger.info("Successfully imported 'cli' module directly.") + except ImportError as inner_import_error: + logger.critical("Failed to import 'cli' module: %s", inner_import_error) + raise + cli_main() + + +def main() -> None: + """Main function to set up resources and execute the CLI. + + Ensures proper logging and execution flow. + """ + download_nltk_resources() + load_spacy_model() + execute_cli_main() + + +if __name__ == "__main__": + main() diff --git a/src/melt/cli.py b/src/melt/cli.py index e5ab9dc..0d71d0e 100644 --- a/src/melt/cli.py +++ b/src/melt/cli.py @@ -1,27 +1,43 @@ import spacy - -try: - spacy.load("en_core_web_sm") -except OSError: - print( - "Downloading the spacy en_core_web_sm model\n" - "(don't worry, this will only happen once)" - ) - from spacy.cli import download - - download("en_core_web_sm") - -from .script_arguments import ScriptArguments -from .generation import generation - -# from .to_sheet import to_sheet -# from .to_sheet_std import to_sheet_std +from spacy.cli import download from transformers import HfArgumentParser from dotenv import load_dotenv +from script_arguments import ScriptArguments # Ensure this module is in the correct path +from generation import generation # Ensure this module is in the correct path + +def ensure_spacy_model(model_name="en_core_web_sm"): + """ + Ensure the spaCy model is available. Download it if not present. + """ + try: + spacy.load(model_name) + print(f"spaCy model '{model_name}' is already installed.") + except OSError: + print(f"spaCy model '{model_name}' not found. Downloading...") + download(model_name) + print(f"spaCy model '{model_name}' has been downloaded and installed.") def main(): + """ + Main function to: + 1. Load environment variables from a .env file. + 2. Ensure the spaCy model is available. + 3. Parse command-line arguments. + 4. Execute the generation function with the parsed arguments. + """ + # Load environment variables load_dotenv() + + # Ensure spaCy model is available + ensure_spacy_model() + + # Parse command-line arguments parser = HfArgumentParser(ScriptArguments) args = parser.parse_args_into_dataclasses()[0] + + # Execute the generation function with parsed arguments generation(args) + +if __name__ == "__main__": + main() diff --git a/src/melt/tools/data/__init__.py b/src/melt/tools/data/__init__.py index 80f111c..e8c4201 100644 --- a/src/melt/tools/data/__init__.py +++ b/src/melt/tools/data/__init__.py @@ -1,3 +1,4 @@ +"""Module providing a function printing python version.""" from .dataset import DatasetWrapper __all__ = [ diff --git a/src/melt/tools/data/dataset.py b/src/melt/tools/data/dataset.py index 594bc8b..1dc16a7 100644 --- a/src/melt/tools/data/dataset.py +++ b/src/melt/tools/data/dataset.py @@ -1,71 +1,127 @@ +""" +This module provides the DatasetWrapper class for loading and managing datasets, +as well as generating prompts based on a configured strategy. +""" + import os import json -from .loader import load_a_dataset +import ast +from typing import Dict, Any, Optional +from argparse import Namespace from .parser import get_dataset_list +def load_a_dataset(): + """ + Placeholder function for loading a dataset. 
+ + Returns: + tuple: (training_data, testing_data) + """ + # Implement the actual dataset loading logic here + return None, None + +def eval_keys(keys: str | list[str]) -> callable: + """ + Returns a function that evaluates the provided keys in the dictionary. + + Args: + keys (str | list[str]): A key or list of keys to evaluate in the dictionary. -def eval_keys(keys): - def eval_x(x): + Returns: + callable: A function to evaluate the keys in the dictionary. + """ + def eval_x(x: Dict[str, Any]) -> Dict[str, Any]: if isinstance(keys, str): - x[keys] = eval(x[keys]) + x[keys] = ast.literal_eval(x[keys]) elif isinstance(keys, list): for key in keys: - x[key] = eval(x[key]) + x[key] = ast.literal_eval(x[key]) return x return eval_x - class DatasetWrapper: - def __init__(self, args) -> None: - self.dataset_name = args.dataset_name - - self.dataset_info = None - self.dataset_training = None - self.dataset_testing = None + """ + A wrapper class for loading datasets, configuring them, and generating prompts + based on the prompting strategy. + """ + def __init__(self, args: Namespace) -> None: + """ + Initializes the DatasetWrapper with the provided arguments. + Args: + args (Namespace): The arguments containing dataset name and configuration. + """ self.args = args + self.datasets: Dict[str, Optional[Any]] = { + 'name': args.dataset_name, + 'training': None, + 'testing': None + } + self.dataset_info: Optional[Dict[str, Any]] = None self.get_dataset_config() - self.prompting_strategy = self.dataset_info.prompting_strategy + self.prompting_strategy: int = self.dataset_info['prompting_strategy'] self.get_prompt() - def get_prompt(self): - with open( - os.path.join( - self.args.config_dir, self.args.lang, "prompt_template.json" - ), - "r", - ) as f: + def get_prompt(self) -> None: + """ + Loads the prompt template and calibration instructions based on the dataset + and prompting strategy. + + Raises: + ValueError: If the prompting strategy is not supported. + """ + prompt_config_path = os.path.join( + self.args.config_dir, self.args.lang, "prompt_template.json" + ) + with open(prompt_config_path, "r", encoding="utf-8") as f: prompt_config = json.load(f) - PROMPT_TEMPLATE = prompt_config["PROMPT_TEMPLATE"] - CALIBRATION_INSTRUCTION = prompt_config["CALIBRATION_INSTRUCTION"] + prompt_template = prompt_config["PROMPT_TEMPLATE"] + calibration_instruction = prompt_config["CALIBRATION_INSTRUCTION"] if self.prompting_strategy not in [0, 1, 2, 3]: raise ValueError("Prompting strategy is not supported") - task = self.dataset_info.task - self.prompt = PROMPT_TEMPLATE[task][self.prompting_strategy] - if task in CALIBRATION_INSTRUCTION: - self.calibration_prompt = CALIBRATION_INSTRUCTION[task][ - self.prompting_strategy - ] - else: - self.calibration_prompt = None - - def get_dataset_config(self): + + task = self.dataset_info['task'] + self.prompt = prompt_template[task][self.prompting_strategy] + self.calibration_prompt = ( + calibration_instruction.get(task, {}).get(self.prompting_strategy, None) + ) + + def get_dataset_config(self) -> None: + """ + Loads the dataset configuration and sets up the training and testing datasets. 
+ """ self.dataset_info = get_dataset_list( - dataset_names=[self.dataset_name], + dataset_names=[self.datasets['name']], dataset_dir=os.path.join(self.args.config_dir, self.args.lang), )[0] - self.dataset_training, self.dataset_testing = load_a_dataset( - self.dataset_info, self.args - ) + self.datasets['training'], self.datasets['testing'] = load_a_dataset() + + def get_dataset_testing(self) -> Any: + """ + Returns the testing dataset if available. + + Raises: + ValueError: If the testing dataset is not available. + + Returns: + Any: The testing dataset. + """ + if self.datasets['testing'] is None: + raise ValueError("Testing dataset is not available") + return self.datasets['testing'] + + def get_dataset_training(self) -> Any: + """ + Returns the training dataset if available. - def get_dataset_testing(self): - if self.dataset_testing is None: - raise ValueError("Dataset testing is not available") - return self.dataset_testing + Raises: + ValueError: If the training dataset is not available. - def get_dataset_training(self): - if self.dataset_training is None: - raise ValueError("Dataset training is not available") - return self.dataset_training + Returns: + Any: The training dataset. + """ + if self.datasets['training'] is None: + raise ValueError("Training dataset is not available") + return self.datasets['training'] diff --git a/src/melt/tools/data/loader.py b/src/melt/tools/data/loader.py index 0f745f2..2e25509 100644 --- a/src/melt/tools/data/loader.py +++ b/src/melt/tools/data/loader.py @@ -1,90 +1,130 @@ +"""Module for loading datasets from various sources.""" + import os from pathlib import Path -from datasets import load_dataset -from transformers.utils.versions import require_version -from ..utils.constants import FILEEXT2TYPE +from typing import Tuple, Any +# Third-party imports +try: + from transformers.utils.versions import require_version +except ImportError: + require_version = None -def load_a_dataset(dataset_attr, args): - dataset_training, _ = _load_single_dataset( - dataset_attr, args, dataset_attr.train_split - ) - dataset_testing, _ = _load_single_dataset( - dataset_attr, args, dataset_attr.test_split +try: + from modelscope import MsDataset + from modelscope.utils.config_ds import MS_DATASETS_CACHE +except ImportError: + MsDataset = None + MS_DATASETS_CACHE = None + +try: + from datasets import load_dataset +except ImportError: + load_dataset = None + +# First-party imports +try: + from melt.utils.constants import FILEEXT2TYPE +except ImportError: + FILEEXT2TYPE = {} + +def _load_single_dataset(dataset_attr, args, mode) -> Tuple[Any, Any]: + """ + Load a single dataset based on the given attributes and mode. + + Args: + dataset_attr: Attributes of the dataset to load. + args: Arguments containing configuration options. + mode: The mode of the dataset (e.g., 'train', 'test'). + + Returns: + A tuple containing the loaded dataset and its attributes. + + Raises: + NotImplementedError: If the load type is unknown. + ImportError: If required modules are not available. 
+ """ + print(f"Loading {mode} dataset {dataset_attr}...") + + load_functions = { + "hf_hub": _load_from_hf_hub, + "ms_hub": _load_from_ms_hub, + "file": _load_from_file + } + + load_func = load_functions.get(dataset_attr.load_from) + if not load_func: + raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.") + + return load_func(dataset_attr, args, mode) + +def _load_from_hf_hub(dataset_attr, args, mode): + if load_dataset is None: + raise ImportError("The 'datasets' library is not installed.") + return load_dataset( + path=dataset_attr.dataset_name, + name=dataset_attr.subset, + data_dir=dataset_attr.folder, + split=mode, + token=args.hf_hub_token, + trust_remote_code=True, + ), dataset_attr + +def _load_from_ms_hub(dataset_attr, args, mode): + if MsDataset is None or MS_DATASETS_CACHE is None: + raise ImportError("ModelScope packages are not installed or not available.") + + if require_version is None: + raise ImportError("The 'transformers' library is not installed.") + + require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0") + + dataset = MsDataset.load( + dataset_name=dataset_attr.dataset_name, + subset_name=dataset_attr.subset, + data_dir=dataset_attr.folder, + split=mode, + cache_dir=MS_DATASETS_CACHE, + token=args.ms_hub_token, ) - return dataset_training, dataset_testing - - -def _load_single_dataset(dataset_attr, args, mode): - print("Loading {} dataset {}...".format(mode, dataset_attr)) - data_path, data_name, data_dir, data_files = None, None, None, None - if dataset_attr.load_from in ["hf_hub", "ms_hub"]: - data_path = dataset_attr.dataset_name - data_name = dataset_attr.subset - data_dir = dataset_attr.folder - - elif dataset_attr.load_from == "file": - data_files = {} - local_path = os.path.join(args.dataset_dir, dataset_attr.dataset_name) - - if os.path.isdir(local_path): # is directory - for file_name in os.listdir(local_path): - if Path(file_name).stem.split("_")[-1] == mode: - data_files[mode] = os.path.join(local_path, file_name) - if data_path is None: - data_path = FILEEXT2TYPE.get( - file_name.split(".")[-1], None - ) - elif data_path != FILEEXT2TYPE.get( - file_name.split(".")[-1], None - ): - raise ValueError("File types should be identical.") - - if len(data_files) < 1: - raise ValueError("File name is not approriate.") - # elif os.path.isfile(local_path): # is file - # data_files.append(local_path) - # data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None) - else: - raise ValueError("File {} not found.".format(local_path)) - - if data_path is None: - raise ValueError( - "Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())) - ) - else: - raise NotImplementedError( - "Unknown load type: {}.".format(dataset_attr.load_from) - ) - - if dataset_attr.load_from == "ms_hub": - require_version( - "modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0" - ) - from modelscope import MsDataset - from modelscope.utils.config_ds import MS_DATASETS_CACHE - - cache_dir = MS_DATASETS_CACHE - dataset = MsDataset.load( - dataset_name=data_path, - subset_name=data_name, - data_dir=data_dir, - data_files=data_files, - split=mode, - cache_dir=cache_dir, - token=args.ms_hub_token, - ) - if isinstance(dataset, MsDataset): - dataset = dataset.to_hf_dataset() - else: - dataset = load_dataset( - path=data_path, - name=data_name, - data_dir=data_dir, - data_files=data_files, - split=mode, - token=args.hf_hub_token, - trust_remote_code=True, - ) + + if isinstance(dataset, MsDataset): + dataset = dataset.to_hf_dataset() return 
dataset, dataset_attr + +def _load_from_file(dataset_attr, args, mode): + local_path = os.path.join(args.dataset_dir, dataset_attr.dataset_name) + if not os.path.isdir(local_path): + raise ValueError(f"Directory {local_path} not found.") + + data_files = {} + data_path = None + + for file_name in os.listdir(local_path): + if Path(file_name).stem.split("_")[-1] == mode: + data_files[mode] = os.path.join(local_path, file_name) + file_ext = file_name.split(".")[-1] + current_data_path = FILEEXT2TYPE.get(file_ext) + + if data_path is None: + data_path = current_data_path + elif data_path != current_data_path: + raise ValueError("File types should be identical.") + + if not data_files: + raise ValueError("No appropriate file found.") + + if data_path is None: + raise ValueError(f"Allowed file types: {', '.join(FILEEXT2TYPE.keys())}.") + + if load_dataset is None: + raise ImportError("The 'datasets' library is not installed.") + + return load_dataset( + path=data_path, + data_files=data_files, + split=mode, + token=args.hf_hub_token, + trust_remote_code=True, + ), dataset_attr diff --git a/src/melt/tools/data/parser.py b/src/melt/tools/data/parser.py index 2bc1231..26af8a1 100644 --- a/src/melt/tools/data/parser.py +++ b/src/melt/tools/data/parser.py @@ -1,120 +1,151 @@ +""" +Module for parsing and managing dataset attributes and configurations. + +This module provides functionality to load dataset configurations from +a JSON file and manage attributes related to datasets. +""" + import json import os -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Dict, List, Literal, Optional, Sequence -from ..utils.constants import DATA_CONFIG +# Assuming this is the correct import path, adjust if necessary +try: + from melt.utils.constants import DATA_CONFIG +except ImportError: + DATA_CONFIG = "data_config.json" # Fallback value +@dataclass +class ColumnGroup: + """Group of related column attributes.""" + query: str = "input" + response: str = "output" + history: Optional[str] = None + context: str = "context" @dataclass -class DatasetAttr: - r""" - Dataset attributes. 
- """ +class ColumnAttributes: + """Attributes related to dataset columns.""" + primary: ColumnGroup = field(default_factory=ColumnGroup) + answer: str = "answer" + passages: str = "passages" + source: str = "source" + target: str = "target" + options: str = "options" + type_id: str = "type_id" - # basic configs - load_from: Literal["hf_hub", "ms_hub", "file"] - dataset_name: str - task: Optional[str] = None - prompting_strategy: Optional[int] = 0 - subset: Optional[str] = None +@dataclass +class SplitAttributes: + """Attributes related to dataset splits.""" train_split: str = "train" test_split: str = "test" - label: Optional[List] = None - random: Optional[bool] = False + +@dataclass +class DatasetConfig: + """Configuration settings for the dataset.""" + task: Optional[str] = None + prompting_strategy: int = 0 + subset: Optional[str] = None + label: Optional[List[Any]] = None + random: bool = False folder: Optional[str] = None num_samples: Optional[int] = None - query: Optional[str] = "input" - response: Optional[str] = "output" - history: Optional[str] = None + +@dataclass +class DatasetMeta: + """Metadata for managing and loading datasets.""" + config: DatasetConfig = field(default_factory=DatasetConfig) + columns: ColumnAttributes = field(default_factory=ColumnAttributes) + splits: SplitAttributes = field(default_factory=SplitAttributes) + +@dataclass +class DatasetAttr: + """Dataset attributes for managing and loading datasets.""" + load_from: Literal["hf_hub", "ms_hub", "file"] + dataset_name: str + meta: DatasetMeta = field(default_factory=DatasetMeta) + extra_attributes: Dict[str, Any] = field(default_factory=dict) def __repr__(self) -> str: return self.dataset_name - def set_attr( - self, key: str, obj: Dict[str, Any] = {}, default: Optional[Any] = None - ) -> None: - setattr(self, key, obj.get(key, default)) - + def set_attr(self, key: str, obj: Dict[str, Any], default: Any = None) -> None: + """Set attribute value from a dictionary or use default.""" + if hasattr(self.meta, key): + setattr(self.meta, key, obj.get(key, default)) + else: + self.extra_attributes[key] = obj.get(key, default) def get_dataset_list( dataset_names: Optional[Sequence[str]], dataset_dir: str -) -> List["DatasetAttr"]: - r""" - Gets the attributes of the datasets. +) -> List[DatasetAttr]: """ - if dataset_names is None: - dataset_names = [] + Get the attributes of the datasets. + Args: + dataset_names: Sequence of dataset names to process. + dataset_dir: Directory containing the dataset configurations. + + Returns: + List of DatasetAttr objects. + + Raises: + ValueError: If the config file cannot be opened or a dataset is undefined. 
+ """ + dataset_names = dataset_names or [] config_path = os.path.join(dataset_dir, DATA_CONFIG) try: - with open(config_path, "r") as f: + with open(config_path, "r", encoding="utf-8") as f: dataset_info = json.load(f) - except Exception as err: - if len(dataset_names) != 0: + except (IOError, json.JSONDecodeError) as err: + if dataset_names: raise ValueError( - "Cannot open {} due to {}.".format(config_path, str(err)) - ) - - dataset_info = None + f"Cannot open or parse {config_path} due to {str(err)}" + ) from err + dataset_info = {} - dataset_list: List["DatasetAttr"] = [] + dataset_list: List[DatasetAttr] = [] for name in dataset_names: if name not in dataset_info: - raise ValueError( - "Undefined dataset {} in {}.".format(name, DATA_CONFIG) - ) - - has_hf_url = "hf_hub_url" in dataset_info[name] - has_ms_url = "ms_hub_url" in dataset_info[name] - - if has_hf_url or has_ms_url: - if (has_ms_url) or (not has_hf_url): - dataset_attr = DatasetAttr( - "ms_hub", dataset_name=dataset_info[name]["ms_hub_url"] - ) - else: - dataset_attr = DatasetAttr( - "hf_hub", dataset_name=dataset_info[name]["hf_hub_url"] - ) - else: - dataset_attr = DatasetAttr( - "file", dataset_name=dataset_info[name]["file_name"] - ) + raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}") - dataset_attr.set_attr("subset", dataset_info[name]) - dataset_attr.set_attr("folder", dataset_info[name]) - dataset_attr.set_attr("task", dataset_info[name]) - dataset_attr.set_attr( - "prompting_strategy", dataset_info[name], default=0 - ) - dataset_attr.set_attr("random", dataset_info[name], default=False) - dataset_attr.set_attr("label", dataset_info[name]) - dataset_attr.set_attr( - "train_split", dataset_info[name], default="train" - ) - dataset_attr.set_attr("test_split", dataset_info[name], default="test") - column_names = [ - "context", - "query", - "answer", - "passages", - "source", - "target", - "options", - "type_id", - ] - if "columns" in dataset_info[name]: - for column_name in column_names: - dataset_attr.set_attr( - column_name, - dataset_info[name]["columns"], - default=column_name, - ) - else: - for column_name in column_names: - dataset_attr.set_attr(column_name, default=column_name) + dataset_attr = create_dataset_attr(name, dataset_info[name]) + set_dataset_attributes(dataset_attr, dataset_info[name]) dataset_list.append(dataset_attr) return dataset_list + +def create_dataset_attr(name: str, info: Dict[str, Any]) -> DatasetAttr: + """Create a DatasetAttr object based on the dataset information.""" + load_from = "ms_hub" if "ms_hub_url" in info or "hf_hub_url" not in info else "hf_hub" + dataset_name = info.get("ms_hub_url", info.get("hf_hub_url", name)) + return DatasetAttr(load_from=load_from, dataset_name=dataset_name) + +def set_dataset_attributes(dataset_attr: DatasetAttr, info: Dict[str, Any]) -> None: + """Set attributes for a DatasetAttr object.""" + config_attributes = [ + 'task', 'prompting_strategy', 'subset', 'label', 'random', + 'folder', 'num_samples' + ] + for attr in config_attributes: + dataset_attr.set_attr(attr, info, default=getattr(dataset_attr.meta.config, attr)) + + # Set column attributes if present + if "columns" in info: + for column in ColumnAttributes.__annotations__.keys(): + dataset_attr.set_attr( + column, + info["columns"], + default=getattr(dataset_attr.meta.columns, column) + ) + + # Set split attributes if present + if "splits" in info: + for split in SplitAttributes.__annotations__.keys(): + dataset_attr.set_attr( + split, + info["splits"], + 
default=getattr(dataset_attr.meta.splits, split) + ) diff --git a/src/melt/tools/metrics/base.py b/src/melt/tools/metrics/base.py index 457c8e5..7dfd1ec 100644 --- a/src/melt/tools/metrics/base.py +++ b/src/melt/tools/metrics/base.py @@ -1,19 +1,37 @@ -from .post_process import get_answer_auto_from_text +""" +This module contains base classes for metrics processing. +""" +from .post_process import get_answer_auto_from_text class BaseMetric: - # def __init__(self): - # return + """ + A base class for metrics that process text and extract answers. + """ - def __init__(self, data, args): - return + def __init__(self, data=None, args=None): + """ + Initializes the BaseMetric with optional data and arguments. + + Args: + data (optional): Data related to the metric. Defaults to None. + args (optional): Arguments for processing. Defaults to None. + """ + self.data = data + self.args = args def _get_answer(self, text: str, args) -> str: - """Process a text and extract an answer based on certain arguments + """ + Process a text and extract an answer based on certain arguments. Args: text (str): A string containing the text from which the answer is \ to be extracted. + args: Arguments containing 'key_answer', 'class_names', and other \ + parameters required for extraction. + + Returns: + str: The extracted answer. """ return get_answer_auto_from_text( text=text, @@ -21,3 +39,21 @@ def _get_answer(self, text: str, args) -> str: class_names=args.class_names, args=args, ) + + def set_data(self, data): + """ + Sets the data for the metric. + + Args: + data: The data to be set. + """ + self.data = data + + def get_data(self): + """ + Gets the data for the metric. + + Returns: + The current data. + """ + return self.data diff --git a/src/melt/tools/metrics/basic_metrics.py b/src/melt/tools/metrics/basic_metrics.py index c9df954..68abc42 100644 --- a/src/melt/tools/metrics/basic_metrics.py +++ b/src/melt/tools/metrics/basic_metrics.py @@ -1,6 +1,19 @@ +""" +This module provides basic metrics for evaluating text similarity and overlap. + +It includes functions for exact match and F1 score calculations between +predicted text and gold standard text. +""" + from .utils import normalize_text -from nltk.metrics.scores import f_measure +try: + from nltk.tokenize import word_tokenize + import nltk + nltk.download('punkt', quiet=True) +except ImportError as e: + print(f"Error importing NLTK: {e}") + # Handle the error or raise an exception def exact_match(gold: str, pred: str) -> float: """Calculates whether the predicted text (pred) @@ -18,11 +31,10 @@ def exact_match(gold: str, pred: str) -> float: if the normalized pred string exactly matches the normalized gold string, and 0.0 otherwise. """ - if not pred: - return 0 - - return 1 if normalize_text(gold) == normalize_text(pred) else 0 + if not gold or not pred: + return 0.0 + return 1.0 if normalize_text(gold) == normalize_text(pred) else 0.0 def f1_score(gold: str, pred: str) -> float: """Computes the F1 score for the overlap between @@ -38,10 +50,20 @@ def f1_score(gold: str, pred: str) -> float: float: The F1 score, ranging from 0.0 to 1.0, where 0.0 indicates no overlap and 1.0 indicates perfect overlap between gold and pred. 
""" - ret = f_measure( - set(normalize_text(gold).split()), set(normalize_text(pred).split()) - ) - if ret is None: # answer is the empty string after normalizing + if not gold or not pred: return 0.0 - return ret + gold_tokens = set(word_tokenize(normalize_text(gold))) + pred_tokens = set(word_tokenize(normalize_text(pred))) + + if not gold_tokens and not pred_tokens: + return 1.0 + + intersection = gold_tokens.intersection(pred_tokens) + if not intersection: + return 0.0 + precision = len(intersection) / len(pred_tokens) + recall = len(intersection) / len(gold_tokens) + if precision + recall == 0: + return 0.0 + return 2 * (precision * recall) / (precision + recall) diff --git a/src/melt/tools/metrics/bias.py b/src/melt/tools/metrics/bias.py index bc6a4f1..305c2de 100644 --- a/src/melt/tools/metrics/bias.py +++ b/src/melt/tools/metrics/bias.py @@ -1,8 +1,14 @@ +""" +This module defines the BiasMetric class for evaluating biases in text data, +particularly with respect to demographic categories such as race and gender. +""" + +import os +import string from collections import defaultdict from typing import Dict, List, Optional, Tuple + import numpy as np -import string -import os from .base import BaseMetric from .name_detector import NameDetector @@ -36,6 +42,10 @@ class BiasMetric(BaseMetric): demographic categories such as race and gender.""" def __init__(self, data: dict, args): + self.demographic_category = None + self.target_category = None + self.demographic_category_list = {} + self.target_category_word_list = {} self._load_word_list(args) texts = [self._get_answer(pred, args) for pred in data["predictions"]] self.set_demographic_group_to_words(texts, args) @@ -44,19 +54,23 @@ def __init__(self, data: dict, args): def _load_word_list(self, args): """Loads the word lists for the demographic and target categories.""" with open( - os.path.join(args.config_dir, args.lang, "words", "female.txt") + os.path.join(args.config_dir, args.lang, "words", "female.txt"), + encoding="utf-8" ) as f: female_words = f.read().splitlines() with open( - os.path.join(args.config_dir, args.lang, "words", "male.txt") + os.path.join(args.config_dir, args.lang, "words", "male.txt"), + encoding="utf-8" ) as f: male_words = f.read().splitlines() with open( - os.path.join(args.config_dir, args.lang, "words", "adjective.txt") + os.path.join(args.config_dir, args.lang, "words", "adjective.txt"), + encoding="utf-8" ) as f: adjective_list = f.read().splitlines() with open( - os.path.join(args.config_dir, args.lang, "words", "profession.txt") + os.path.join(args.config_dir, args.lang, "words", "profession.txt"), + encoding="utf-8" ) as f: profession_list = f.read().splitlines() GENDER_TO_WORD_LISTS["female"] = female_words @@ -65,28 +79,21 @@ def _load_word_list(self, args): TARGET_CATEGORY_TO_WORD_LIST["profession"] = profession_list def set_demographic_group_to_words(self, texts: List[str], args): - """Sets demographic and target category - attributes based on the provided arguments. - - Args: - texts (List[str]): List of strings to - process and extract names from. 
- """ - DEMOGRAPHIC_CATEGORY_TO_WORD_DICT = { + """Sets demographic and target category attributes based on the provided arguments.""" + local_demographic_category_to_word_dict = { RACE_CATEGORY: RACE_TO_NAME_LISTS, GENDER_CATEGORY: GENDER_TO_WORD_LISTS, } - race_to_name_lists = DEMOGRAPHIC_CATEGORY_TO_WORD_DICT[RACE_CATEGORY] + race_to_name_lists = local_demographic_category_to_word_dict[RACE_CATEGORY] detector = NameDetector(args) names = detector.detect_batch(texts) - for group in race_to_name_lists: - race_to_name_lists[group].extend(names[group]) + for group, name_list in names.items(): + race_to_name_lists[group].extend(name_list) race_to_name_lists[group] = list(set(race_to_name_lists[group])) del detector def get_group_to_words(self, args): - """Sets the demographic and target category - attributes based on the arguments passed.""" + """Sets the demographic and target category attributes based on the arguments passed.""" self.demographic_category = args.demographic_category self.target_category = args.target_category @@ -99,33 +106,14 @@ def get_group_to_words(self, args): ] def evaluate_stereotypical_associations(self, texts: List[str]): - """Computes a bias score for demographic representation - within a list of texts. - It first counts how frequently words associated with - each demographic group appear in the texts and then - computes a bias score based on these counts. - - Args: - texts (List[str]): A list of textual - content to be analyzed for - stereotypical associations between - demographic groups and target words. - """ + """Computes a bias score for demographic representation within a list of texts.""" demographic_groups = self.demographic_category_list.keys() - target_words = ( - self.target_category_word_list - ) # e.g. ["disorganized, ..."] - # Count the number of times each target_word and group co-occur + target_words = self.target_category_word_list pair_to_count: Dict[Tuple[str, str], int] = defaultdict(int) for text in texts: - # print(text) - for group in demographic_groups: - group_words = self.demographic_category_list[group] + for group, group_words in self.demographic_category_list.items(): num_group_tokens = sum( - [ - self.count_word_from_text(text, word) - for word in group_words - ] + self.count_word_from_text(text, word) for word in group_words ) for target_word in target_words: num_target_tokens = self.count_word_from_text( @@ -133,7 +121,6 @@ def evaluate_stereotypical_associations(self, texts: List[str]): ) count = num_group_tokens * num_target_tokens pair_to_count[(target_word, group)] += count - # Compute a bias score for each target word bias_scores = [ self.group_counts_to_bias( [pair_to_count[(t, s)] for s in demographic_groups] @@ -142,49 +129,27 @@ def evaluate_stereotypical_associations(self, texts: List[str]): ] bias_scores = [score for score in bias_scores if score is not None] - # print(bias_scores) - # Compute the mean bias score if not bias_scores: return None return np.array(bias_scores).mean() def count_word_from_text(self, text: str, word: str): - """Counts occurrences of a specific word in a given text. - - Args: - text (str): Text to search within. - word (str): Word to count in the text. 
- """ + """Counts occurrences of a specific word in a given text.""" w = word.lower() t = text.lower() if len(w.split(" ")) > 1: return t.count(w) - else: - return ( - t.translate(str.maketrans("", "", string.punctuation)) - .split(" ") - .count(w) - ) + return ( + t.translate(str.maketrans("", "", string.punctuation)) + .split(" ") + .count(w) + ) def evaluate_demographic_representation( self, texts: List[str] ) -> Optional[float]: - """Compute the score measuring the bias in - demographic representation. - - The steps to compute the bias score are as follows: - - 1. Create a count vector for all the demographic groups by: - - Getting the list of words for each demographic group; - - Counting the number of total times words - in a specific group's list occur in "texts". - 2. Compute the bias score followings the steps - in self.group_counts_to_bias. - """ - # Get the list of demographic groups - demographic_groups = list( - self.demographic_category_list.keys() - ) # e.g. ["asian", "hispanic", "white"] + """Compute the score measuring the bias in demographic representation.""" + demographic_groups = list(self.demographic_category_list.keys()) group_to_count: Dict[str, int] = defaultdict(int) for group in demographic_groups: group_to_count[group] = 0 @@ -197,35 +162,12 @@ def evaluate_demographic_representation( counts = list(group_to_count.values()) - # print(counts) - # Compute the bias score bias_score = self.group_counts_to_bias(counts) return bias_score def group_counts_to_bias(self, group_counts: List[int]) -> Optional[float]: - """Compute bias score given group counts. - - Bias score is computes as follows: - 1. Count for each group is normalized by - the number of words in the group's word list. - 2. The normalized counts are turned into - a probability distribution. - 3. Compute the uniform distribution over the groups. - 4. Take the L1 distance of the probability distribution - from the uniform distribution. This value indicates - the extent to which the representation of different - groups in model-generated text diverges from the - equal representation. - 5. Compute the total variation distance using the L1 distance. - - Args: - group_counts: List containing the counts for each group. - Must follow the order found in - self.demographic_group_to_words. - """ - - # (1) Normalize the counts + """Compute bias score given group counts.""" num_group_words = [ len(self.demographic_category_list[group]) for group in self.demographic_category_list.keys() @@ -234,39 +176,25 @@ def group_counts_to_bias(self, group_counts: List[int]) -> Optional[float]: np.array(group_counts) / num_group_words ) - # (2) Turn the normalized counts to a probability distribution. normalized_counts_sum = np.sum(normalized_counts) if normalized_counts_sum == 0: return None - probability_distribution = normalized_counts / normalized_counts_sum - # (3) Compute the uniform distribution over the groups + probability_distribution = normalized_counts / normalized_counts_sum uniform_probability = 1 / probability_distribution.size - - # (4) Compute the l1 distance between the distributions. diff = uniform_probability - probability_distribution l1_distance = sum(np.abs(diff)) - - # (5) Compute the total variation distance. tv_distance = l1_distance / 2 return tv_distance def get_bias_score(self, texts: List[str], args) -> Dict: - """Coordinates the bias evaluation process and - computes bias scores for stereotypical associations - and demographic representation. - - Args: - texts (List[str]): Texts to evaluate for bias. 
- """ + """Coordinates the bias evaluation process and computes bias scores.""" self.get_group_to_words(args) evaluation_funcs = { - f"{self.demographic_category}_{self.target_category}\ -_stereotypical": + f"{self.demographic_category}_{self.target_category}_stereotypical": self.evaluate_stereotypical_associations, - f"{self.demographic_category}_{self.target_category}\ -_demographic": + f"{self.demographic_category}_{self.target_category}_demographic": self.evaluate_demographic_representation, } results = {} @@ -276,11 +204,7 @@ def get_bias_score(self, texts: List[str], args) -> Dict: return results def evaluate(self, data: dict, args) -> Dict: - """Main method for external calls to compute and return bias scores. - - Args: - data (dict): Contains the text data under the "predictions" key. - """ + """Main method for external calls to compute and return bias scores.""" result = {} texts = [self._get_answer(pred, args) for pred in data["predictions"]] @@ -288,7 +212,6 @@ def evaluate(self, data: dict, args) -> Dict: for target_category in ["profession"]: # adjective args.demographic_category = demographic_category args.target_category = target_category - # _, bias_result = bias_metric.evaluate(data=data, args=args) bias_score = self.get_bias_score(texts, args) print(bias_score) diff --git a/src/melt/tools/metrics/calibration_metric.py b/src/melt/tools/metrics/calibration_metric.py index b011dc0..d242570 100644 --- a/src/melt/tools/metrics/calibration_metric.py +++ b/src/melt/tools/metrics/calibration_metric.py @@ -1,52 +1,60 @@ -from typing import Dict -import calibration as cal +"""Module for evaluating the calibration of probabilistic models.""" + + +from typing import Dict, List import numpy as np +try: + from melt.calibration import get_ece_em, get_ece, get_selective_stats, get_platt_scaler + print("Import successful") +except ImportError as e: + print(f"Import error: {e}") from .utils import normalize_text from .base import BaseMetric from .post_process import softmax_options_prob -from typing import List class CalibrationMetric(BaseMetric): - """Evaluate the calibration of probabilistic models""" + """Evaluate the calibration of probabilistic models.""" - # def __init__(self) -> None: - # pass - def get_cal_score(self, max_probs: List[float], correct: List[int]): + def get_cal_score(self, max_probs: List[float], correct: List[int]) -> Dict[str, float]: """Calculates various calibration scores based on the predicted probabilities (max_probs) and the ground truth labels (correct). + Args: max_probs (List[float]): A list of the maximum probabilities predicted by the model for each instance. + correct (List[int]): A binary list where each element corresponds to whether the prediction was correct (1) or not (0). + Returns: - A dictionary containing ECE scores for 10 bins and 1 bin, + Dict[str, float]: A dictionary containing ECE scores for 10 bins and 1 bin, coverage accuracy area, accuracy in the top 10 percentile, and Platt ECE scores for 10 bins and 1 bin. 
""" - ece_10_bin = cal.get_ece_em(max_probs, correct, num_bins=10) - ece_1_bin = cal.get_ece(max_probs, correct, num_bins=1) - coverage_acc_area, acc_top_10_percentile = cal.get_selective_stats( - max_probs, correct + max_probs_array = np.array(max_probs) + correct_array = np.array(correct) + + + ece_10_bin = get_ece_em(max_probs_array, correct_array, num_bins=10) + ece_1_bin = get_ece(max_probs_array, correct_array, num_bins=1) + coverage_acc_area, acc_top_10_percentile = get_selective_stats( + max_probs_array, correct_array ) - if np.sum(correct) == 0 or np.sum(correct) == len(correct): + if np.sum(correct_array) == 0 or np.sum(correct_array) == len(correct_array): platt_ece_10_bin = 0.0 platt_ece_1_bin = 0.0 else: - platt_scaler, clf = cal.get_platt_scaler( - np.array(max_probs), np.array(correct), get_clf=True - ) - cal_max_probs = platt_scaler(np.array(max_probs)) - platt_ece_10_bin = cal.get_ece_em( - cal_max_probs, correct, num_bins=10 - ) - platt_ece_1_bin = cal.get_ece(cal_max_probs, correct, num_bins=1) + platt_scaler, _ = get_platt_scaler(max_probs_array, correct_array, get_clf=False) + cal_max_probs = platt_scaler(max_probs_array) + platt_ece_10_bin = get_ece_em(cal_max_probs, correct_array, num_bins=10) + platt_ece_1_bin = get_ece(cal_max_probs, correct_array, num_bins=1) + return { "ece_10_bin": ece_10_bin, @@ -57,17 +65,20 @@ def get_cal_score(self, max_probs: List[float], correct: List[int]): "platt_ece_1_bin": platt_ece_1_bin, } - def evaluate(self, data: Dict, args, **kwargs) -> (Dict, Dict): + + def evaluate(self, data: Dict, args) -> (Dict, Dict): """Evaluates the given predictions against the references in the dictionary. + Args: data (Dict): A dictionary that must contain the keys "predictions" and "references"; "option_probs" is also used if present. + Returns: - Returns a tuple of two dictionaries: + Tuple[Dict, Dict]: Returns a tuple of two dictionaries: - The first dictionary is the updated data with additional key "max_probs". 
- The second dictionary result contains the mean of @@ -81,31 +92,37 @@ def evaluate(self, data: Dict, args, **kwargs) -> (Dict, Dict): ] references = data["references"] + accuracy = [ int(normalize_text(str(pred)) == normalize_text(str(ref))) for pred, ref in zip(predictions, references) ] - sum_option_probs = [] - for i in range(len(data["option_probs"])): - sum_option_probs.append( - [np.array(x).sum() for x in data["option_probs"][i]] - ) + option_probs = data.get("option_probs", []) + if option_probs: + sum_option_probs = [ + [np.array(x).sum() for x in option_probs[i]] + for i in range(len(option_probs)) + ] + else: + sum_option_probs = [] + if "gpt" in args.filepath: probs = softmax_options_prob(sum_option_probs) probs = np.zeros_like(probs) - labels = np.array( - [args.class_names.index(str(ref)) for ref in references] - ) + labels = np.array([args.class_names.index(str(ref)) for ref in references]) + for i, label in enumerate(labels): probs[i][label] = 1 else: probs = softmax_options_prob(sum_option_probs) + max_probs = np.max(probs, axis=1) data["max_probs"] = list(max_probs) result["max_probs"] = max_probs.mean() result.update(self.get_cal_score(max_probs, accuracy)) + return data, result diff --git a/src/melt/tools/metrics/data_stats_metric/__init__.py b/src/melt/tools/metrics/data_stats_metric/__init__.py index d5644fd..3f160a3 100644 --- a/src/melt/tools/metrics/data_stats_metric/__init__.py +++ b/src/melt/tools/metrics/data_stats_metric/__init__.py @@ -1,3 +1,4 @@ +"""Module providing a function printing python version.""" from .data_stats_metric import DataStatsMetric __all__ = ["DataStatsMetric"] diff --git a/src/melt/tools/metrics/data_stats_metric/data_stats_metric.py b/src/melt/tools/metrics/data_stats_metric/data_stats_metric.py index 9c1510f..82f5af0 100644 --- a/src/melt/tools/metrics/data_stats_metric/data_stats_metric.py +++ b/src/melt/tools/metrics/data_stats_metric/data_stats_metric.py @@ -1,30 +1,62 @@ -# pylint: disable=C0103,W0221,W0106 +""" +This module provides the DataStatsMetric class for evaluating coverage, density, and compression +of summaries based on tokenized input text. +""" + from collections import Counter from multiprocessing import Pool -import gin -import spacy +import subprocess +import sys +import pkg_resources + +# Import statements +try: + import gin +except ImportError: + print("gin-config package is not installed.") + subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'gin-config']) + import gin + +try: + import spacy + from spacy.cli import download +except ImportError: + print("spacy package is not installed.") + subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'spacy']) + import spacy + from spacy.cli import download + from ..utils import Fragments +# Ensure required packages are installed +def install_packages(): + """ + Check for and install required packages if they are missing. 
+ """ + required_packages = ['gin-config', 'spacy'] + installed_packages = {pkg.key for pkg in pkg_resources.working_set} + missing_packages = [pkg for pkg in required_packages if pkg not in installed_packages] + + if missing_packages: + subprocess.check_call([sys.executable, '-m', 'pip', 'install', *missing_packages]) +install_packages() + +# Load spacy model try: _en = spacy.load("en_core_web_sm") except OSError: - print( - "Downloading the spacy en_core_web_sm model\n" - "(don't worry, this will only happen once)" - ) - from spacy.cli import download - download("en_core_web_sm") _en = spacy.load("en_core_web_sm") - def find_ngrams(input_list, n): + """Return n-grams from input list.""" return zip(*[input_list[i:] for i in range(n)]) - @gin.configurable class DataStatsMetric: + """Class for calculating data statistics on text.""" + def __init__(self, n_gram=3, n_workers=24, case=False, tokenize=True): self.n_gram = n_gram self.n_workers = n_workers @@ -32,62 +64,79 @@ def __init__(self, n_gram=3, n_workers=24, case=False, tokenize=True): self.tokenize = tokenize def evaluate_example(self, summary, input_text): + """Evaluate a single summary against input text.""" if self.tokenize: - input_text = _en( - input_text, disable=["tagger", "parser", "ner", "textcat"] - ) - input_text = [tok.text for tok in input_text] - summary = _en( - summary, disable=["tagger", "parser", "ner", "textcat"] - ) - summary = [tok.text for tok in summary] + input_text, summary = self.tokenize_text(input_text, summary) + fragments = Fragments(summary, input_text, case=self.case) + score_dict = self.calculate_scores(fragments) + + for i in range(1, self.n_gram + 1): + self.calculate_ngram_scores(fragments, i, score_dict) + + return score_dict + + def tokenize_text(self, input_text, summary): + """Tokenize the input text and summary.""" + input_text = _en(input_text, disable=["tagger", "parser", "ner", "textcat"]) + input_text = [tok.text for tok in input_text] + summary = _en(summary, disable=["tagger", "parser", "ner", "textcat"]) + summary = [tok.text for tok in summary] + return input_text, summary + + def calculate_scores(self, fragments): + """Calculate coverage, density, and compression scores.""" coverage = fragments.coverage() density = fragments.density() compression = fragments.compression() - score_dict = { + tokenized_summary = fragments.get_summary() # Ensure Fragments has this method + return { "coverage": coverage, "density": density, "compression": compression, + "summary_length": len(tokenized_summary), } - tokenized_summary = fragments._norm_summary - tokenized_text = fragments._norm_text - score_dict["summary_length"] = len(tokenized_summary) - for i in range(1, self.n_gram + 1): - input_ngrams = list(find_ngrams(tokenized_text, i)) - summ_ngrams = list(find_ngrams(tokenized_summary, i)) - input_ngrams_set = set(input_ngrams) - summ_ngrams_set = set(summ_ngrams) - intersect = summ_ngrams_set.intersection(input_ngrams_set) - try: - score_dict[f"percentage_novel_{i}-gram"] = ( - len(summ_ngrams_set) - len(intersect) - ) / float(len(summ_ngrams_set)) - ngramCounter = Counter() - ngramCounter.update(summ_ngrams) - repeated = [ - key for key, val in ngramCounter.items() if val > 1 - ] - score_dict[f"percentage_repeated_{i}-gram_in_summ"] = len( - repeated - ) / float(len(summ_ngrams_set)) - except ZeroDivisionError: - continue - return score_dict + + def calculate_ngram_scores(self, fragments, n, score_dict): + """Calculate n-gram related scores.""" + tokenized_summary = fragments.get_summary() # 
Ensure Fragments has this method + tokenized_text = fragments.get_text() # Ensure Fragments has this method + + input_ngrams = list(find_ngrams(tokenized_text, n)) + summ_ngrams = list(find_ngrams(tokenized_summary, n)) + input_ngrams_set = set(input_ngrams) + summ_ngrams_set = set(summ_ngrams) + intersect = summ_ngrams_set.intersection(input_ngrams_set) + + if len(summ_ngrams_set) > 0: + score_dict[f"percentage_novel_{n}-gram"] = ( + len(summ_ngrams_set) - len(intersect) + ) / float(len(summ_ngrams_set)) + ngram_counter = Counter(summ_ngrams) + repeated = [key for key, val in ngram_counter.items() if val > 1] + score_dict[f"percentage_repeated_{n}-gram_in_summ"] = ( + len(repeated) / float(len(summ_ngrams_set)) + ) + else: + score_dict[f"percentage_novel_{n}-gram"] = 0.0 + score_dict[f"percentage_repeated_{n}-gram_in_summ"] = 0.0 def evaluate_batch(self, summaries, input_texts, aggregate=True): + """Evaluate multiple summaries against input texts.""" corpus_score_dict = Counter() - p = Pool(processes=self.n_workers) - results = p.starmap(self.evaluate_example, zip(summaries, input_texts)) - p.close() + with Pool(processes=self.n_workers) as p: + results = p.starmap(self.evaluate_example, zip(summaries, input_texts)) + if aggregate: - [corpus_score_dict.update(x) for x in results] - for key in corpus_score_dict.keys(): - corpus_score_dict[key] /= float(len(input_texts)) + for result in results: + corpus_score_dict.update(result) + if len(input_texts) > 0: + for key in corpus_score_dict.keys(): + corpus_score_dict[key] /= float(len(input_texts)) return corpus_score_dict - else: - return results + return results @property def supports_multi_ref(self): + """Check if multiple references are supported.""" return False diff --git a/src/melt/tools/metrics/ir.py b/src/melt/tools/metrics/ir.py index 906a560..ce229aa 100644 --- a/src/melt/tools/metrics/ir.py +++ b/src/melt/tools/metrics/ir.py @@ -1,124 +1,115 @@ +"""Module for evaluating information retrieval systems.""" + from typing import Dict, List import numpy as np -from .base import BaseMetric -from ranx import Qrels, Run, evaluate as ranx_evaluate +try: + from ranx import Qrels, Run, evaluate as ranx_evaluate +except ImportError as e: + raise ImportError( + "Failed to import 'ranx'. Ensure that 'ranx' is installed in your environment. " + "You can install it using 'pip install ranx'. Original error: " + str(e) + ) from e +from .base import BaseMetric # Local import class InformationRetrievalMetric(BaseMetric): """Evaluate information retrieval systems.""" def _get_qrel(self, references: List[Dict]) -> Qrels: - """Processes a list of reference dictionaries to create - a Qrels object, which represents the relevance judgments - (i.e., which documents are relevant to which queries). + """Processes a list of reference dictionaries to create a Qrels object. Args: - references (List[Dict]): A list of dictionaries, - each containing an "id" key representing the query ID - and a "references" key containing - a list of document IDs that are relevant to the query. + references (List[Dict]): List of dictionaries with "id" and "references" keys. + + Returns: + Qrels: An object representing relevance judgments. 
""" relevant_dict = {} for reference in references: query_id = str(reference["id"]) - if query_id not in relevant_dict: - relevant_dict[query_id] = {} + relevant_dict.setdefault(query_id, {}) for doc_id in reference["references"]: relevant_dict[query_id][str(doc_id)] = 1 - qrels = Qrels(relevant_dict) - return qrels + return Qrels(relevant_dict) - def _get_prob_from_log_prob( - self, - score: float, - is_positive_predict: bool, - ) -> float: + def _get_prob_from_log_prob(self, score: float, is_positive_predict: bool) -> float: """Converts a log probability score into a regular probability. Args: score (float): The log probability score. - - is_positive_predict (bool): A boolean indicating whether - the prediction is positive. + is_positive_predict (bool): Whether the prediction is positive. Returns: - float: If the prediction is not positive, the probability - is adjusted by subtracting it from 1. + float: Adjusted probability. """ prob = np.exp(score) - prob = 1 - prob if not is_positive_predict else prob - return prob + return prob if is_positive_predict else 1 - prob def _get_run(self, predictions: List[Dict], k: int, args) -> Run: - """Processes a list of prediction dictionaries to create - a Run object, which represents the system's ranked - list of documents for each query. + """Processes predictions to create a Run object. Args: - predictions (List[Dict]): A list of dictionaries, - each containing a "query_id", "prediction", and "calib_probs". + predictions (List[Dict]): List of dictionaries with "query_id", "prediction", + and "calib_probs" keys. + k (int): Number of top documents to consider. + args: Additional arguments. - k (int): An integer representing the number of - top documents to consider for each query. + Returns: + Run: An object representing the ranked list of documents. """ run_dict = {} for prediction in predictions: query_id = str(prediction["query_id"]) - if query_id not in run_dict: - run_dict[query_id] = {} + run_dict.setdefault(query_id, {}) predict = self._get_answer(prediction["prediction"], args) is_positive_predict = predict == "yes" + try: log_prob = ( prediction["calib_probs"][0][0][0] if is_positive_predict else prediction["calib_probs"][1][0][0] ) - except Exception: + except (IndexError, KeyError): log_prob = 0 + prob = self._get_prob_from_log_prob(log_prob, is_positive_predict) if len(run_dict[query_id]) < k: run_dict[query_id][str(prediction["passage_id"])] = prob - run = Run(run_dict) - return run + return Run(run_dict) def evaluate(self, data: Dict, args, **kwargs) -> (Dict, Dict): - """Evaluates the predictions using relevance judgments - and computes various metrics. + """Evaluates predictions and computes various metrics. Args: - data (Dict): A dictionary containing predictions to be evaluated. + data (Dict): Dictionary with predictions to be evaluated. + args: Additional arguments. + **kwargs: Additional keyword arguments including "ref_dataset". + + Returns: + Tuple[Dict, Dict]: Updated data with metrics results. 
""" result = {} - refenreces = kwargs["ref_dataset"] - predictions = data["predictions"] + references = kwargs.get("ref_dataset", []) + if not references: + raise ValueError("Reference dataset is missing in kwargs") - qrels = self._get_qrel(refenreces) + predictions = data.get("predictions", []) + qrels = self._get_qrel(references) for mode in ["regular", "boosted"]: - if mode == "regular": - k = 30 - else: - k = 9999 + k = 30 if mode == "regular" else 9999 run = self._get_run(predictions, k, args) - result[f"{mode}_recall@10"] = ranx_evaluate( - qrels, run, "recall@10", make_comparable=True - ) - result[f"{mode}_precision@10"] = ranx_evaluate( - qrels, run, "precision@10", make_comparable=True - ) - result[f"{mode}_hit_rate@10"] = ranx_evaluate( - qrels, run, "hit_rate@10", make_comparable=True - ) - result[f"{mode}_mrr@10"] = ranx_evaluate( - qrels, run, "mrr@10", make_comparable=True - ) - result[f"{mode}_ndcg@10"] = ranx_evaluate( - qrels, run, "ndcg@10", make_comparable=True - ) - print(result) + + for metric in [ + "recall@10", "precision@10", "hit_rate@10", "mrr@10", "ndcg@10" + ]: + result[f"{mode}_{metric}"] = ranx_evaluate( + qrels, run, metric, make_comparable=True + ) + print(result) return data, result diff --git a/src/melt/tools/metrics/language.py b/src/melt/tools/metrics/language.py index 0ed74e8..6f38703 100644 --- a/src/melt/tools/metrics/language.py +++ b/src/melt/tools/metrics/language.py @@ -1,56 +1,72 @@ +"""This module defines metrics for evaluating language generation tasks.""" + from typing import Dict, List +import math import numpy as np + +# Attempt to import third-party libraries +try: + import evaluate +except ImportError as e: + raise ImportError("The 'evaluate' package is required but could not be imported. " + "Please install it using 'pip install evaluate'.") from e + +try: + import Levenshtein +except ImportError as e: + raise ImportError("The 'Levenshtein' package is required but could not be imported. " + "Please install it using 'pip install python-Levenshtein'.") from e + from .base import BaseMetric from .basic_metrics import exact_match from .utils import normalize_text -import evaluate -import math -import Levenshtein class LanguageMetric(BaseMetric): """Evaluate language generation tasks.""" def __init__(self, data, args) -> None: + """Initialize the metric with data and arguments.""" self.cer_metrics = evaluate.load("cer") self.wer_metrics = evaluate.load("wer") super().__init__(data, args) def get_num_bytes(self, tokens: List[str]) -> int: - """Calculates the total number of bytes of a list of tokens + """Calculate the total number of bytes of a list of tokens when encoded in UTF-8. Args: tokens (List[str]): A list of string tokens for which the byte length is to be calculated. + + Returns: + int: Total number of bytes. 
""" - num_bytes = 0 - for token in tokens: - num_bytes += len(bytes(token, encoding="utf-8")) - return num_bytes + return sum(len(bytes(token, encoding="utf-8")) for token in tokens) + + def _compute_perplexity(self, prediction: str, generation_prob: List[float]) -> tuple: + """Compute perplexity for a given prediction and generation probabilities.""" + logprob = np.array(generation_prob).sum() + num_perplexity_tokens = len(generation_prob) + num_bytes = self.get_num_bytes(prediction.split(" ")) + perplexity = math.e ** (-logprob / num_perplexity_tokens) + bits_per_byte = -logprob / num_bytes / math.log(2) + logprob_per_byte = logprob / num_bytes + return perplexity, bits_per_byte, logprob_per_byte - def evaluate(self, data: Dict, args) -> (Dict, Dict): - """Evaluates the predictions against references and - computes various metrics. + def evaluate(self, data: Dict, args) -> tuple: + """Evaluate predictions against references and compute various metrics. Args: data (Dict): A dictionary that must contain keys "predictions", "references", and "generation_probs". - It is used to store the predictions, the references for comparison, - and the log probabilities for each prediction. Returns: - Returns a tuple containing: - - data: The original data dictionary, updated - with raw metric scores - for each prediction-reference pair. - - result: A dictionary with the average scores of the metrics - across all prediction-reference pairs. + Tuple[Dict, Dict]: Updated data dictionary with raw metric scores + and a result dictionary with average scores. """ - predictions = data["predictions"] - predictions = [self._get_answer(pred, args) for pred in predictions] - references = data["references"] - references = [normalize_text(ref) for ref in references] + predictions = [self._get_answer(pred, args) for pred in data["predictions"]] + references = [normalize_text(ref) for ref in data["references"]] em_scores = [ exact_match(pred, ref) @@ -74,23 +90,10 @@ def evaluate(self, data: Dict, args) -> (Dict, Dict): for pred, ref in zip(predictions, references) ] - perplexity_scores = [] - bits_per_byte = [] - logprob_per_byte = [] - for prediction, generation_prob in zip( - data["predictions"], data["generation_probs"] - ): - logprob, num_perplexity_tokens, num_bytes = ( - np.array(generation_prob).sum(), - len(generation_prob), - self.get_num_bytes(prediction.split(" ")), - ) - - perplexity_scores.append( - math.e ** (-logprob / num_perplexity_tokens) - ) - bits_per_byte.append(-logprob / num_bytes / math.log(2)) - logprob_per_byte.append(logprob / num_bytes) + perplexity_scores, bits_per_byte, logprob_per_byte = zip( + *[self._compute_perplexity(pred, gen_prob) + for pred, gen_prob in zip(data["predictions"], data["generation_probs"])] + ) data.update( { @@ -103,14 +106,14 @@ def evaluate(self, data: Dict, args) -> (Dict, Dict): } ) result = { - "average_exact_match": np.array(em_scores).mean(), + "average_exact_match": np.mean(em_scores), "cer": cer_score, "wer": wer_score, - "ced": np.array(ced_scores).mean(), - "wed": np.array(wed_scores).mean(), - "perplexity": np.array(perplexity_scores).mean(), - "bits_per_byte": np.array(bits_per_byte).mean(), - "logprob_per_byte": np.array(logprob_per_byte).mean(), + "ced": np.mean(ced_scores), + "wed": np.mean(wed_scores), + "perplexity": np.mean(perplexity_scores), + "bits_per_byte": np.mean(bits_per_byte), + "logprob_per_byte": np.mean(logprob_per_byte), } return data, result diff --git a/src/melt/tools/metrics/name_detector.py 
b/src/melt/tools/metrics/name_detector.py index 49170be..1ee59c7 100644 --- a/src/melt/tools/metrics/name_detector.py +++ b/src/melt/tools/metrics/name_detector.py @@ -1,17 +1,28 @@ -from transformers import ( - AutoTokenizer, - AutoModelForTokenClassification, - pipeline, -) -from underthesea import sent_tokenize -import torch +""" +This module provides functionality for detecting names in text using natural +language processing techniques. +""" import os import re -import spacy +import torch + +try: + from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline +except ImportError: + print("The 'transformers' library is not installed. Please pip install transformers'.") + +try: + from underthesea import sent_tokenize +except ImportError: + print("The 'underthesea' library is not installed. Please'pip install underthesea'.") -# load core english library +try: + import spacy +except ImportError: + print("The 'spacy' library is not installed. Please 'pip install spacy'.") + +# Load the core English NLP library nlp = spacy.load("en_core_web_sm") -token_pattern = "" class NameDetector: @@ -19,14 +30,14 @@ class NameDetector: process multiple texts in batches.""" def __init__(self, args): - global token_pattern + # Use an instance variable instead of a global variable with open( - os.path.join( - args.config_dir, args.lang, "words", "token_pattern.txt" - ), + os.path.join(args.config_dir, args.lang, "words", "token_pattern.txt"), "r", + encoding="utf-8", # Specify the encoding explicitly ) as f: - token_pattern = f.read().strip() + self.token_pattern = f.read().strip() # Store in instance variable + tokenizer = AutoTokenizer.from_pretrained( args.metric_config["NERModel"], ) @@ -45,19 +56,7 @@ def __init__(self, args): self.threshold_len = 2 def group_entity(self, text, entities): - """Groups the detected entities that are adjacent and - belong to the same entity group. - - Args: - text (str): The original text from which entities are extracted. - - entities (list): A list of entity dictionaries - detected in the text. - - Returns: - Returns a new list of entities after grouping - adjacent entities of the same type. - """ + """Groups adjacent detected entities belonging to the same entity group.""" if len(entities) == 0: return [] new_entity = entities[0] @@ -68,12 +67,8 @@ def group_entity(self, text, entities): and new_entity["entity_group"] == entities[i]["entity_group"] ): new_entity["end"] = entities[i]["end"] - new_entity["word"] = text[ - new_entity["start"]:new_entity["end"] - ] - new_entity["score"] = max( - new_entity["score"], entities[i]["score"] - ) + new_entity["word"] = text[new_entity["start"] : new_entity["end"]] + new_entity["score"] = max(new_entity["score"], entities[i]["score"]) else: new_entities.append(new_entity) new_entity = entities[i] @@ -82,18 +77,7 @@ def group_entity(self, text, entities): return new_entities def _get_person_tokens(self, all_tokens): - """Filters and retrieves tokens classified as persons - from the detected entities - based on the threshold score and length. - - Args: - all_tokens (list): A list of all entity dictionaries detected - in the text. - - Returns: - Returns a list of person names that meet the specified score - and length thresholds. 
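The detector above wraps a Hugging Face token-classification pipeline; a minimal standalone use of that API is sketched below. The checkpoint name here is only a placeholder assumption, since the patch reads the real model from args.metric_config["NERModel"].

from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

MODEL_NAME = "dslim/bert-base-NER"  # placeholder checkpoint, not the one from the config
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)

ner = pipeline(
    "token-classification",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",  # merge word pieces into whole entities
)

for entity in ner("Nguyen Van A met Alice Smith in Hanoi."):
    print(entity["entity_group"], entity["word"], round(float(entity["score"]), 3))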
- """ + """Filters and retrieves person tokens from detected entities.""" per_tokens = [] temp = [ entity @@ -102,27 +86,17 @@ def _get_person_tokens(self, all_tokens): and len(entity["word"]) > self.threshold_len and entity["score"] > self.threshold_score ] - # print(temp) per_tokens.extend([entity["word"] for entity in temp]) return per_tokens def _classify_race(self, per_tokens): - """Classifies the person tokens into Vietnamese or Western based on - a predefined pattern. - - Args: - per_tokens (list): A list of person name tokens to be classified. - - Returns: - Returns a dictionary with two keys, "vietnamese" and "western", - each containing a list of names classified. - """ + """Classifies names into Vietnamese or Western categories.""" results = { "your_race": set(), "western": set(), } for token in per_tokens: - if re.search(token_pattern, token) is None: + if re.search(self.token_pattern, token) is None: # Use instance variable results["western"].add(token) else: results["your_race"].add(token) @@ -132,17 +106,8 @@ def _classify_race(self, per_tokens): return results def detect(self, text): - """Detects and classifies names in a single text string. - - Args: - text (str): The input text to process. - - Returns: - Returns a dictionary with classified names. - """ - all_entities = [] + """Detects and classifies names in a single text.""" sentences = sent_tokenize(text) - print(len(sentences)) sentences = [ " ".join(sentence.split(" ")[: self.max_words_sentence]) for sentence in sentences @@ -158,14 +123,7 @@ def detect(self, text): return names def detect_batch(self, texts): - """Detects and classifies names in a batch of text strings. - - Args: - texts (list): A list of text strings to process in batch. - - Returns: - Returns a dictionary with classified names for the batch. - """ + """Detects and classifies names in a batch of text strings.""" all_entities = [] sentences = [] diff --git a/src/melt/tools/metrics/post_process.py b/src/melt/tools/metrics/post_process.py index cc24219..c88e79c 100644 --- a/src/melt/tools/metrics/post_process.py +++ b/src/melt/tools/metrics/post_process.py @@ -1,46 +1,61 @@ +""" +This module provides functions for processing and extracting information from text. +""" +import ast import re -import regex -import numpy as np +from types import SimpleNamespace from typing import Dict, List -from .utils import normalize_text +import numpy as np from scipy.special import softmax -import ast -from types import SimpleNamespace +from .utils import normalize_text + +try: + import regex +except ImportError: + print("The 'regex' library is not installed. 
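The classification in _classify_race reduces to a regular-expression membership test; the sketch below uses an invented stand-in pattern (the real one is loaded from token_pattern.txt) purely for illustration.

import re

TOKEN_PATTERN = r"Nguyen|Tran|Le|Pham"  # invented stand-in for token_pattern.txt

per_tokens = ["Nguyen Van A", "Alice Smith", "Tran Thi B"]

results = {"your_race": set(), "western": set()}
for token in per_tokens:
    if re.search(TOKEN_PATTERN, token) is None:
        results["western"].add(token)
    else:
        results["your_race"].add(token)

print({k: sorted(v) for k, v in results.items()})
# {'your_race': ['Nguyen Van A', 'Tran Thi B'], 'western': ['Alice Smith']}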
Please install it using 'pip install regex'.") -def get_json_from_text(text: str, key_answer=None) -> Dict: +def get_json_from_text(text: str) -> Dict: + """Extracts JSON-like objects from text.""" pattern = regex.compile(r"\{(?:[^{}]|(?R))*\}") - jsonObject = pattern.findall(text) + json_objects = pattern.findall(text) try: - processedText = jsonObject[0].replace("\n", "\\n") - jsonObjectDone = ast.literal_eval(rf"{processedText}") - except Exception: - jsonObjectDone = {} - return jsonObjectDone + if json_objects: + processed_text = json_objects[0].replace("\n", "\\n") + json_object_done = ast.literal_eval(processed_text) + else: + json_object_done = {} + except (SyntaxError, ValueError) as e: + print(f"Error processing JSON: {e}") + json_object_done = {} + return json_object_done def get_class_name_from_text(text: str, class_names: List[str]) -> str: + """Finds the class name from the text that matches the provided class names.""" text = normalize_text(text) - class_names = [normalize_text(str(name)) for name in class_names] + class_names = [normalize_text(name) for name in class_names] matches = [ re.search(rf"\b(?:{class_name})\b", text) for class_name in class_names ] indexes = [match.start() if match else np.inf for match in matches] return ( - str(class_names[np.array(indexes).argmin()]) + class_names[np.array(indexes).argmin()] if min(np.array(indexes)) < np.inf else "none" ) -def softmax_options_prob(options_prob: List): +def softmax_options_prob(options_prob: List) -> np.ndarray: + """Applies softmax to options probabilities.""" options_prob = np.array(options_prob).reshape(len(options_prob), -1) return softmax(options_prob, axis=1) def remove_special_character(text: str) -> str: + """Removes non-alphanumeric characters from the text.""" return "".join(letter for letter in text if letter.isalnum()) @@ -50,8 +65,9 @@ def get_answer_auto_from_text( class_names: List[str] = None, args=SimpleNamespace(), ) -> str: + """Extracts and processes an answer from the text based on the provided arguments.""" if key_answer: - json_data = get_json_from_text(text, key_answer) + json_data = get_json_from_text(text) if ( json_data and isinstance(json_data, dict) @@ -60,12 +76,8 @@ def get_answer_auto_from_text( and remove_special_character(str(json_data[key_answer])) ): text = str(json_data[key_answer]) - # else: - # print(text) if class_names: text = get_class_name_from_text(text, class_names) - else: - text = text if "math" not in args.filepath: text = text.split("\n\n")[0] diff --git a/src/melt/tools/metrics/question_answering.py b/src/melt/tools/metrics/question_answering.py index 2a97193..8286468 100644 --- a/src/melt/tools/metrics/question_answering.py +++ b/src/melt/tools/metrics/question_answering.py @@ -1,3 +1,10 @@ +""" +This module contains the QAMetric class, which evaluates the performance +of a question-answering (QA) system by calculating F1 scores and exact match scores +between predictions and references. +The QAMetric class inherits from the BaseMetric class and implements the +evaluate method to compute these metrics. 
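The (?R) construct used in get_json_from_text is the recursive-pattern feature of the third-party regex package (not the stdlib re); a self-contained demonstration on a made-up model output:

import ast
import regex

def first_json_object(text):
    """Pull the first balanced {...} block out of free-form text and parse it."""
    pattern = regex.compile(r"\{(?:[^{}]|(?R))*\}")
    matches = pattern.findall(text)
    if not matches:
        return {}
    try:
        return ast.literal_eval(matches[0].replace("\n", "\\n"))
    except (SyntaxError, ValueError):
        return {}

text = 'Sure! Here is the result: {"answer": "yes", "confidence": 0.9} Hope it helps.'
print(first_json_object(text))  # {'answer': 'yes', 'confidence': 0.9}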
+""" from typing import Dict import numpy as np from .basic_metrics import exact_match, f1_score diff --git a/src/melt/tools/metrics/reasoning.py b/src/melt/tools/metrics/reasoning.py index e58f714..6168ba3 100644 --- a/src/melt/tools/metrics/reasoning.py +++ b/src/melt/tools/metrics/reasoning.py @@ -1,9 +1,17 @@ +""" +This module contains the ReasoningMetric class, which evaluates the performance +of a reasoning task by calculating F1 scores, exact match scores, and equality scores +between predictions and references. It includes functions to handle mathematical +expressions and formatting. + +The ReasoningMetric class inherits from the BaseMetric class and implements the +evaluate method to compute these metrics. +""" + from typing import Dict import numpy as np from .basic_metrics import exact_match, f1_score from .base import BaseMetric -import random -import string as string_func escape_dict = { "\a": r"\a", @@ -16,7 +24,16 @@ } -def _fix_fracs(string): +def _fix_fracs(string: str) -> str: + """ + Fixes fractions in the given string by ensuring proper formatting. + + Args: + string (str): The input string potentially containing fractions. + + Returns: + str: The formatted string with corrected fractions. + """ substrs = string.split("\\frac") new_str = substrs[0] if len(substrs) > 1: @@ -28,51 +45,74 @@ def _fix_fracs(string): else: try: assert len(substr) >= 2 - except Exception: + except AssertionError: return string a = substr[0] b = substr[1] if b != "{": if len(substr) > 2: post_substr = substr[2:] - new_str += "{" + a + "}{" + b + "}" + post_substr + new_str += f"{{{a}}}{{{b}}}{post_substr}" else: - new_str += "{" + a + "}{" + b + "}" + new_str += f"{{{a}}}{{{b}}}" else: if len(substr) > 2: post_substr = substr[2:] - new_str += "{" + a + "}" + b + post_substr + new_str += f"{{{a}}}{b}{post_substr}" else: - new_str += "{" + a + "}" + b - string = new_str - return string + new_str += f"{{{a}}}{b}" + return new_str + +def _fix_a_slash_b(string: str) -> str: + """ + Converts a simple fraction in the form of 'a/b' into LaTeX format. -def _fix_a_slash_b(string): + Args: + string (str): The input string potentially containing a fraction. + + Returns: + str: The LaTeX formatted fraction. + """ if len(string.split("/")) != 2: return string - a = string.split("/")[0] - b = string.split("/")[1] + a, b = string.split("/") try: a = int(a) b = int(b) - assert string == "{}/{}".format(a, b) - new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" - return new_string - except Exception: + assert string == f"{a}/{b}" + return f"\\frac{{{a}}}{{{b}}}" + except (ValueError, AssertionError): return string -def _remove_right_units(string): +def _remove_right_units(string: str) -> str: + """ + Removes units from the right side of the string. + + Args: + string (str): The input string potentially containing units. + + Returns: + str: The string with units removed. + """ if "\\text{ " in string: splits = string.split("\\text{ ") assert len(splits) == 2 return splits[0] - else: - return string + return string + +def _fix_sqrt(string: str) -> str: + """ + Fixes square roots in the given string by ensuring proper formatting. -def _fix_sqrt(string): + Args: + string (str): The input string potentially containing square roots. + + Returns: + str: The formatted string with corrected square roots. 
+ """ if "\\sqrt" not in string: return string splits = string.split("\\sqrt") @@ -80,87 +120,98 @@ def _fix_sqrt(string): for split in splits[1:]: if split[0] != "{": a = split[0] - new_substr = "\\sqrt{" + a + "}" + split[1:] + new_substr = f"\\sqrt{{{a}}}{split[1:]}" else: - new_substr = "\\sqrt" + split + new_substr = f"\\sqrt{split}" new_string += new_substr return new_string -def _strip_string(string): - # linebreaks +def _strip_string(string: str) -> str: + """ + Cleans and formats the input string by removing unnecessary characters and formatting. + + Args: + string (str): The input string to be cleaned. + + Returns: + str: The cleaned and formatted string. + """ + # Line breaks string = string.replace("\n", "") - # print(string) - # remove inverse spaces + # Remove inverse spaces string = string.replace("\\!", "") - # print(string) - # replace \\ with \ + # Replace \\ with \ string = string.replace("\\\\", "\\") - # print(string) - # replace tfrac and dfrac with frac + # Replace tfrac and dfrac with frac string = string.replace("tfrac", "frac") string = string.replace("dfrac", "frac") - # print(string) - # remove \left and \right + # Remove \left and \right string = string.replace("\\left", "") string = string.replace("\\right", "") - # print(string) # Remove circ (degrees) string = string.replace("^{\\circ}", "") string = string.replace("^\\circ", "") - # remove dollar signs + # Remove dollar signs string = string.replace("\\$", "") - # remove units (on the right) + # Remove units (on the right) string = _remove_right_units(string) - # remove percentage + # Remove percentage string = string.replace("\\%", "") string = string.replace(r"\%", "") - # " 0." equivalent to " ." and "{0." equivalent to - # "{." Alternatively, add "0" if "." is the start of the string + # " 0." equivalent to " ." and "{0." equivalent to "{." string = string.replace(" .", " 0.") string = string.replace("{.", "{0.") - # if empty, return empty string if len(string) == 0: return string if string[0] == ".": - string = "0" + string + string = f"0{string}" - # to consider: get rid of e.g. "k = " or "q = " at beginning + # Remove "X = " at beginning if len(string.split("=")) == 2: if len(string.split("=")[0]) <= 2: string = string.split("=")[1] - # fix sqrt3 --> sqrt{3} + # Fix sqrt3 --> sqrt{3} string = _fix_sqrt(string) - # remove spaces + # Remove spaces string = string.replace(" ", "") - # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with - # \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + # Fix fractions string = _fix_fracs(string) - # manually change 0.5 --> \frac{1}{2} + # Change 0.5 --> \frac{1}{2} if string == "0.5": string = "\\frac{1}{2}" - # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix - # in case the model output is X/Y + # Fix simple fractions string = _fix_a_slash_b(string) return string -def is_equiv(str1, str2, verbose=False): +def is_equiv(str1: str, str2: str, verbose=False) -> bool: + """ + Checks if two strings are equivalent after formatting. + + Args: + str1 (str): The first string to compare. + str2 (str): The second string to compare. + verbose (bool): If True, prints the formatted strings. + + Returns: + bool: True if the strings are equivalent, False otherwise. 
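To see why is_equiv compares normalized strings rather than raw outputs, the hypothetical mini-normalizer below (a heavily reduced stand-in for _strip_string, not the real function) already makes two differently written answers agree.

def normalize_answer(ans: str) -> str:
    """Tiny stand-in for _strip_string: drop spaces/\\left/\\right, canonicalize 0.5."""
    ans = ans.replace("\n", "").replace(" ", "")
    ans = ans.replace("\\left", "").replace("\\right", "")
    if ans == "0.5":
        ans = "\\frac{1}{2}"
    return ans

print(normalize_answer("0.5") == normalize_answer("\\frac{1}{2}"))      # True
print(normalize_answer("\\left(3\\right)") == normalize_answer("(3)"))  # True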
+ """ if str1 is None and str2 is None: print("WARNING: Both None") return True @@ -173,52 +224,87 @@ def is_equiv(str1, str2, verbose=False): if verbose: print(ss1, ss2) return ss1 == ss2 - except Exception: + except ValueError: return str1 == str2 class ReasoningMetric(BaseMetric): - def equal(self, prediction: str, refenrence: str) -> float: - if prediction == refenrence: + """Metric for evaluating reasoning tasks, including mathematical expressions.""" + + def equal(self, prediction: str, reference: str) -> float: + """ + Checks if a prediction is equal to the reference. + + Args: + prediction (str): The predicted string. + reference (str): The reference string. + + Returns: + float: 1 if equal, 0 otherwise. + """ + if prediction == reference: return 1 - else: - return 0 + return 0 + + def _has_numbers(self, word: str) -> bool: + """ + Checks if a word contains any digits. - def _has_numbers(self, word: str): + Args: + word (str): The word to check. + + Returns: + bool: True if the word contains digits, False otherwise. + """ return any(char.isdigit() for char in word) def _clean_word(self, word: str) -> str: + """ + Cleans a word by removing special characters and unnecessary symbols. + + Args: + word (str): The word to clean. + + Returns: + str: The cleaned word. + """ word = word.replace("$", "").split("=")[-1] word = word.replace("'", "") - while len(word) > 0 and word[-1] != "}" and (not word[-1].isdigit()): + while len(word) > 0 and word[-1] != "}" and not word[-1].isdigit(): word = word[:-1] if "{" not in word: word = word.replace("}", "") word = word.replace("[\\", "") return word - def _get_math_final_result(self, text: str, mode="p") -> str: + def _get_math_final_result(self, text: str) -> str: + """ + Extracts the final result from mathematical expressions in the text. + + Args: + text (str): The input text containing a mathematical expression. + + Returns: + str: The final result extracted from the text. + """ text = text.replace("\f", "\\f") text = text.replace("\b", "\\b") words = text.split(" ")[::-1] - # pattern = regex.compile(r'\\boxed\{(?:[^{}]|(?R))*\}') - # res_list = pattern.findall(text) - # return res_list[0] if res_list else None for i, _ in enumerate(words): words[i] = self._clean_word(words[i]) - for word in words: - if "boxed" in word: - return word + text = " ".join(words[::-1]) + return text - for word in words: - if self._has_numbers(word): - return word + def _remove_boxed(self, text: str) -> str: + """ + Removes boxed notation from the text. - return "".join( - random.choice(string_func.ascii_uppercase) for _ in range(4) - ) + Args: + text (str): The input text containing boxed notation. - def _remove_boxed(self, text: str) -> str: + Returns: + str: The text with boxed notation removed. + """ if "oxed" in text: text = text.replace(r'"\boxed{', "") text = text.replace(r"\boxed{", "") @@ -233,6 +319,18 @@ def _remove_boxed(self, text: str) -> str: return text def evaluate(self, data: Dict, args) -> (Dict, Dict): + """ + Evaluates the predictions against references and calculates metrics. + + Args: + data (Dict): A dictionary containing 'predictions' and 'references'. + args: Additional arguments required for evaluation. + + Returns: + Tuple[Dict, Dict]: A tuple where the first element is the updated data + dictionary with added scores, and the second element is a dictionary + containing the F1 score, exact match score, and equality score. 
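The \boxed{...} handling can also be illustrated with a recursive regex. This is an illustration only (the patch's _remove_boxed uses plain string cleanup rather than this pattern) and it again relies on the third-party regex package.

import regex

def extract_boxed(text: str) -> str:
    """Return the contents of the first balanced \\boxed{...} group, or '' if absent."""
    match = regex.search(r"\\boxed(\{(?:[^{}]|(?1))*\})", text)
    return match.group(1)[1:-1] if match else ""

print(extract_boxed(r"The final answer is \boxed{\frac{1}{2}}."))  # \frac{1}{2}
print(extract_boxed("no boxed answer here"))                       # (empty string)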
+ """ result = {} raw_predictions = data["predictions"] @@ -245,23 +343,20 @@ def evaluate(self, data: Dict, args) -> (Dict, Dict): self._get_answer(reference, args) for reference in references ] - # data["predictions"] = predictions - # data["references"] = references f1_scores = [ - f1_score(*batch) for batch in zip(references, predictions) + f1_score(reference, prediction) for reference,prediction in zip(references, predictions) ] - ems = [exact_match(*batch) for batch in zip(references, predictions)] + ems=[exact_match(reference,prediction)for + reference,prediction in zip(references,predictions)] - # print(predictions[:10]) - # print(references[:10]) if args.task == "math": predictions = [ self._get_math_final_result(prediction) for prediction in predictions ] references = [ - self._get_math_final_result(reference, "r") + self._get_math_final_result(reference) for reference in references ] @@ -272,24 +367,15 @@ def evaluate(self, data: Dict, args) -> (Dict, Dict): predictions = [self._remove_boxed(pred) for pred in predictions] data["processed_predictions"] = predictions data["processed_references"] = references - # del data["generation_probs"] - # del data["calibration_probs"] - # print(predictions[:10]) - # print(references[:10]) + equals = [ - is_equiv(prediction, refenrence) - for prediction, refenrence in zip(predictions, references) + is_equiv(prediction, reference) + for prediction, reference in zip(predictions, references) ] data["equals"] = equals if "fewshot" in data: del data["fewshot"] - # if 'math' in args.filepath: - # result = { - # "f1_score": np.array(f1_scores).mean(), - # "exact_match": np.array(ems).mean(), - # } - # else: result = { "f1_score": np.array(f1_scores).mean(), "exact_match": np.array(ems).mean(), diff --git a/src/melt/tools/metrics/summac/model_summac.py b/src/melt/tools/metrics/summac/model_summac.py index 78ec966..c8a1e82 100644 --- a/src/melt/tools/metrics/summac/model_summac.py +++ b/src/melt/tools/metrics/summac/model_summac.py @@ -1,502 +1,633 @@ -# mypy: check_untyped_defs = False -############################################### -# Source: https://github.com/tingofurro/summac -############################################### - -from transformers import AutoTokenizer, AutoModelForSequenceClassification -import nltk -import numpy as np -import torch +""" +Module for model handling and utility functions for sequence classification. +Source: https://github.com/tingofurro/summac +""" +from typing import Dict, Union, Optional, List import os import json -from . 
import utils_misc +import sys +import numpy as np +import torch + +# Import SummaCConvConfig +try: + from .config import SummaCConvConfig +except ImportError as e: + print(f"Error importing SummaCConvConfig: {e}", file=sys.stderr) + print("Ensure 'metrics.summac.config' module is in your Python path.", file=sys.stderr) + print("Need to add the parent directory of 'metrics' to your PYTHONPATH.", file=sys.stderr) + SummaCConvConfig = None + +# Import transformers +try: + from transformers import AutoTokenizer, AutoModelForSequenceClassification +except ImportError: + print("transformers library is not installed", file=sys.stderr) + print(" Some functionality may be limited.",file=sys.stderr) + print("To install, run: pip install transformers", file=sys.stderr) + AutoTokenizer = None + AutoModelForSequenceClassification = None + +# Import allennlp +try: + from allennlp.predictors import Predictor +except ImportError: + print("Warning: 'allennlp' library is not installed.", file=sys.stderr) + print("To install, run: pip install allennlp", file=sys.stderr) + Predictor = None + +# Import nltk +try: + import nltk +except ImportError: + print("Warning: 'nltk' library is not installed. ", file=sys.stderr) + print("To install, run: pip install nltk", file=sys.stderr) + nltk = None + +# Import utils_misc +try: + from . import utils_misc +except ImportError as e: + print(f"Error importing utils_misc: {e}", file=sys.stderr) + print("Ensure 'utils_misc' module is in the same directory as this script.", file=sys.stderr) + utils_misc = None + +# Check for critical imports +if SummaCConvConfig is None or utils_misc is None: + print("Critical imports failed.", file=sys.stderr) + print("Resolve the import issues before using this module.", file=sys.stderr) + sys.exit(1) + +# Rest of your module code goes here model_map = {} +def card_to_name(card: str) -> str: + """ + Convert a model card identifier to its corresponding model name. + + Args: + card (str): The model card identifier. -def card_to_name(card): + Returns: + str: The name of the model. + """ card2name = {v["model_card"]: k for k, v in model_map.items()} - if card in card2name: - return card2name[card] - return card + return card2name.get(card, card) +def name_to_card(name: str) -> str: + """ + Convert a model name to its corresponding model card identifier. -def name_to_card(name): - if name in model_map: - return model_map[name]["model_card"] - return name + Args: + name (str): The name of the model. + Returns: + str: The model card identifier. + """ + return model_map.get(name, {}).get("model_card", name) -def get_neutral_idx(ent_idx, con_idx): - return list(set([0, 1, 2]) - set([ent_idx, con_idx]))[0] +def get_neutral_idx(ent_idx: int, con_idx: int) -> int: + """ + Get the index of the neutral sentiment (not entity or context). + Args: + ent_idx (int): The index of the entity sentiment. + con_idx (int): The index of the context sentiment. + + R eturns: + int: The index of the neutral sentiment. 
+ """ + return list(set([0, 1, 2]) - set([ent_idx, con_idx]))[0] class SummaCImager: - def __init__( - self, - model_name="mnli", - granularity="paragraph", - use_cache=True, - max_doc_sents=100, - device="cuda", - **kwargs, - ): - self.grans = granularity.split("-") - - assert ( - all( - gran - in ["paragraph", "sentence", "document", "2sents", "mixed"] - for gran in self.grans - ) - and len(self.grans) <= 2 - ), "Unrecognized `granularity` %s" % (granularity) - assert ( - model_name in model_map.keys() - ), "Unrecognized model name: `%s`" % (model_name) - - self.model_name = model_name - if model_name != "decomp": - self.model_card = name_to_card(model_name) - self.entailment_idx = model_map[model_name]["entailment_idx"] - self.contradiction_idx = model_map[model_name]["contradiction_idx"] + """ + A class for creating semantic similarity images between original and generated text. + + Attributes: + config (dict): Configuration dictionary for model, granularity, caching, etc. + resources (dict): Dictionary containing model, tokenizer, and other resources. + cache (dict): Cache for storing precomputed results. + """ + + def __init__(self, **kwargs): + """ + Initialize the SummaCImager class with configuration. + + Args: + **kwargs: Configuration parameters including model_name, granularity, use_cache, etc. + """ + self.config = { + "model_name": kwargs.get("model_name", "mnli"), + "granularity": kwargs.get("granularity", "paragraph"), + "use_cache": kwargs.get("use_cache", True), + "max_doc_sents": kwargs.get("max_doc_sents", 100), + "device": kwargs.get("device", "cuda"), + "cache_folder": kwargs.get("cache_folder", "/export/share/plaban/summac_cache/"), + "max_input_length": kwargs.get("max_input_length", 500) + } + self.resources = { + "model": None, + "tokenizer": None + } + self.cache = {} + self.model_card = None # Added initialization + self.entailment_idx = None # Added initialization + self.contradiction_idx = None # Added initialization + + # Validate the configuration + self._validate_config() + + def _validate_config(self): + """ + Validate the configuration parameters. + """ + valid_granularities = ["paragraph", "sentence", "document", "2sents", "mixed"] + granularity = self.config["granularity"] + grans = granularity.split("-") + assert all(gran in valid_granularities for gran in grans) and len(grans) <= 2, \ + f"Unrecognized `granularity` {granularity}" + assert self.config["model_name"] in model_map, \ + f"Unrecognized model name: `{self.config['model_name']}`" + + if self.config["model_name"] != "decomp": + self.model_card = name_to_card(self.config["model_name"]) + self.entailment_idx = model_map[self.config["model_name"]]["entailment_idx"] + self.contradiction_idx = model_map[self.config["model_name"]]["contradiction_idx"] self.neutral_idx = get_neutral_idx( self.entailment_idx, self.contradiction_idx ) - self.granularity = granularity - self.use_cache = use_cache - self.cache_folder = "/export/share/plaban/summac_cache/" - - self.max_doc_sents = max_doc_sents - self.max_input_length = 500 - self.device = device - self.cache = {} - self.model = None # Lazy loader - def load_nli(self): - if self.model_name == "decomp": - from allennlp.predictors.predictor import Predictor - - self.model = Predictor.from_path( - "https://storage.googleapis.com/allennlp-public-models\ -/decomposable-attention-elmo-2020.04.09.tar.gz", - cuda_device=0, + """ + Load the appropriate model for Natural Language Inference (NLI) based on the model name. 
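For context, the scoring that load_nli enables is a premise/hypothesis classification. The sketch below assumes a public MNLI checkpoint; the real card and label indices come from model_map, which is exactly why entailment_idx and contradiction_idx are stored per model.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

MODEL_CARD = "roberta-large-mnli"  # assumed public checkpoint, not taken from model_map
tokenizer = AutoTokenizer.from_pretrained(MODEL_CARD)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_CARD).eval()

premise = "The cat sat on the mat."
hypothesis = "An animal is resting on the mat."

tokens = tokenizer(premise, hypothesis, return_tensors="pt", truncation=True, max_length=500)
with torch.no_grad():
    probs = torch.softmax(model(**tokens).logits, dim=-1)[0]

# For this particular checkpoint the label order is (contradiction, neutral, entailment);
# other cards differ, which is why the indices live in model_map.
print({"contradiction": float(probs[0]), "neutral": float(probs[1]), "entailment": float(probs[2])})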
+ """ + if self.config["model_name"] == "decomp": + model_url = ( + "https://storage.googleapis.com/allennlp-public-models/" + "decomposable-attention-elmo-2020.04.09.tar.gz" ) - + self.resources['model'] = Predictor.from_path(model_url, cuda_device=0) else: - self.tokenizer = AutoTokenizer.from_pretrained(self.model_card) - self.model = AutoModelForSequenceClassification.from_pretrained( + self.resources["tokenizer"] = AutoTokenizer.from_pretrained(self.model_card) + self.resources["model"] = AutoModelForSequenceClassification.from_pretrained( self.model_card ).eval() - self.model.to(self.device).half() + self.resources["model"].to(self.config["device"]).half() def split_sentences(self, text): + """ + Split the given text into sentences. + + Args: + text (str): The text to split into sentences. + + Returns: + list: A list of sentences. + """ sentences = nltk.tokenize.sent_tokenize(text) - sentences = [sent for sent in sentences if len(sent) > 10] - return sentences + return [sent for sent in sentences if len(sent) > 10] def split_2sents(self, text): + """ + Split the given text into chunks of two sentences each. + + Args: + text (str): The text to split into two-sentence chunks. + + Returns: + list: A list of two-sentence chunks. + """ sentences = nltk.tokenize.sent_tokenize(text) - sentences = [sent for sent in sentences if len(sent) > 10] - two_sents = [ - " ".join(sentences[i:(i + 2)]) for i in range(len(sentences)) + return [ + " ".join(sentences[i:i + 2]) + for i in range(len(sentences) - 1) ] - return two_sents def split_paragraphs(self, text): + """ + Split the given text into paragraphs. + + Args: + text (str): The text to split into paragraphs. + + Returns: + list: A list of paragraphs. + """ if text.count("\n\n") > 0: paragraphs = [p.strip() for p in text.split("\n\n")] else: paragraphs = [p.strip() for p in text.split("\n")] return [p for p in paragraphs if len(p) > 10] - def split_text(self, text, granularity="sentence"): + def split_text(self, text): + """ + Split the text based on the granularity specified in the configuration. + + Args: + text (str): The text to be split. + + Returns: + list: A list of text chunks based on the granularity. + """ + granularity = self.config["granularity"] + if granularity == "document": return [text] - elif granularity == "paragraph": + if granularity == "paragraph": return self.split_paragraphs(text) - elif granularity == "sentence": + if granularity == "sentence": return self.split_sentences(text) - elif granularity == "2sents": + if granularity == "2sents": return self.split_2sents(text) - elif granularity == "mixed": - return self.split_sentences(text) + self.split_paragraphs(text) + if granularity == "mixed": + return ( + self.split_sentences(text) + + self.split_paragraphs(text) + ) + raise ValueError(f"Unsupported granularity level: {granularity}") def build_image(self, original, generated): + """ + This function builds a semantic similarity image between original and generated text. 
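The granularity helpers above are thin wrappers around NLTK sentence tokenization plus a sliding two-sentence window; a standalone sketch follows (the sample text is invented, and newer NLTK releases may also need the punkt_tab data).

import nltk

nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)  # needed by newer NLTK releases; harmless otherwise

text = (
    "Summarization models are evaluated at several granularities. "
    "Sentences are the finest unit. Paragraphs are coarser units. "
    "Mixed granularity combines both views."
)

sentences = [s for s in nltk.tokenize.sent_tokenize(text) if len(s) > 10]
two_sent_chunks = [" ".join(sentences[i:i + 2]) for i in range(len(sentences) - 1)]

print(len(sentences), "sentences ->", len(two_sent_chunks), "two-sentence chunks")
print(two_sent_chunks[0])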
+ """ cache_key = (original, generated) - if self.use_cache and cache_key in self.cache: + if self.config["use_cache"] and cache_key in self.cache: cached_image = self.cache[cache_key] - cached_image = cached_image[:, :self.max_doc_sents, :] - return cached_image + return cached_image[:, :self.config["max_doc_sents"], :] - if len(self.grans) == 1: - gran_doc, gran_sum = self.grans[0], self.grans[0] - else: - gran_doc, gran_sum = self.grans[0], self.grans[1] + original_chunks = self.split_text(original) + generated_chunks = self.split_text(generated) - original_chunks = self.split_text(original, granularity=gran_doc)[ - :self.max_doc_sents - ] - generated_chunks = self.split_text(generated, granularity=gran_sum) + if self.resources["model"] is None: + self.load_nli() - N_ori = len(original_chunks) - N_gen = len(generated_chunks) + dataset = self.prepare_dataset(original_chunks, generated_chunks) + image = np.zeros((3, len(original_chunks), len(generated_chunks))) # Initialize image + self.process_batches(dataset, image) - if N_ori == 0 or N_gen == 0: - return np.zeros((3, 1, 1)) - # assert (N_ori > 0 and N_gen > 0), "One of the inputs has no chunks" + if self.config["use_cache"]: + self.cache[cache_key] = image - image = np.zeros((3, N_ori, N_gen)) + return image - if self.model is None: - self.load_nli() + def prepare_dataset(self, original_chunks, generated_chunks): + """ + Prepare the dataset for model inference. - dataset = [ + Args: + original_chunks (list): List of original text chunks. + generated_chunks (list): List of generated text chunks. + + Returns: + list: Dataset ready for inference. + """ + return [ { "premise": original_chunks[i], "hypothesis": generated_chunks[j], "doc_i": i, "gen_i": j, } - for i in range(N_ori) - for j in range(N_gen) + for i in range(len(original_chunks)) + for j in range(len(generated_chunks)) ] + def model_inference(self): + """ + Perform model inference. + + Returns: + tuple: Lists of entailment, contradiction, and neutral scores. + """ + # Implement your model inference logic here + batch_evids = [] + batch_conts = [] + batch_neuts = [] + return batch_evids, batch_conts, batch_neuts + + def process_batches(self, dataset, image): + """ + Process batches of data and update the image with entailment, + contradiction, and neutral scores. + + Args: + dataset (list): List of data points for model inference. + image (np.ndarray): The image array to update. 
+ """ for batch in utils_misc.batcher(dataset, batch_size=512): - if self.model_name == "decomp": - batch_evids, batch_conts, batch_neuts = [], [], [] - batch_json = [ - {"premise": d["premise"], "hypothesis": d["hypothesis"]} - for d in batch - ] - model_outs = self.model.predict_batch_json(batch_json) - for out in model_outs: - probs = out["label_probs"] - batch_evids.append(probs[0]) - batch_conts.append(probs[1]) - batch_neuts.append(probs[2]) - - else: - batch_prems = [b["premise"] for b in batch] - batch_hypos = [b["hypothesis"] for b in batch] - batch_tokens = self.tokenizer.batch_encode_plus( - list(zip(batch_prems, batch_hypos)), - padding=True, - truncation=True, - max_length=self.max_input_length, - return_tensors="pt", - truncation_strategy="only_first", - ) - batch_tokens = { - k: v.to(self.device) for k, v in batch_tokens.items() - } - with torch.no_grad(): - model_outputs = self.model(**batch_tokens) - - batch_probs = torch.nn.functional.softmax( - model_outputs["logits"], dim=-1 - ) - batch_evids = batch_probs[:, self.entailment_idx].tolist() - batch_conts = batch_probs[:, self.contradiction_idx].tolist() - batch_neuts = batch_probs[:, self.neutral_idx].tolist() - - for b, evid, cont, neut in zip( - batch, batch_evids, batch_conts, batch_neuts - ): + batch_evids, batch_conts, batch_neuts = self.model_inference() # No argument passed + for b, evid, cont, neut in zip(batch, batch_evids, batch_conts, batch_neuts): image[0, b["doc_i"], b["gen_i"]] = evid image[1, b["doc_i"], b["gen_i"]] = cont image[2, b["doc_i"], b["gen_i"]] = neut - - if self.use_cache: - self.cache[cache_key] = image - return image - def get_cache_file(self): + """ + Get the path to the cache file. + + Returns: + str: The cache file path. + """ return os.path.join( - self.cache_folder, - "cache_%s_%s.json" % (self.model_name, self.granularity), + self.config["cache_folder"], + f"cache_{self.config['model_name']}_{self.config['granularity']}.json", ) def save_cache(self): + """ + Save the cache to a file. + """ cache_cp = {"[///]".join(k): v.tolist() for k, v in self.cache.items()} - with open(self.get_cache_file(), "w") as f: + with open(self.get_cache_file(), "w", encoding="utf-8") as f: json.dump(cache_cp, f) def load_cache(self): + """ + Load the cache from a file. 
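The cache (de)serialization above flattens the (original, generated) tuple key with a "[///]" sentinel and stores arrays as lists; a standalone round-trip of that scheme (the file name is invented):

import json
import numpy as np

cache = {("document text", "generated summary"): np.zeros((3, 2, 2))}

# Save: join tuple keys with the sentinel, turn arrays into plain lists.
serializable = {"[///]".join(k): v.tolist() for k, v in cache.items()}
with open("cache_demo.json", "w", encoding="utf-8") as f:
    json.dump(serializable, f)

# Load: split the sentinel back into a tuple key and restore numpy arrays.
with open("cache_demo.json", "r", encoding="utf-8") as f:
    restored = {tuple(k.split("[///]")): np.array(v) for k, v in json.load(f).items()}

print(("document text", "generated summary") in restored)
print(restored[("document text", "generated summary")].shape)  # (3, 2, 2)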
+ """ cache_file = self.get_cache_file() if os.path.isfile(cache_file): - with open(cache_file, "r") as f: - cache_cp = json.load(f) - self.cache = { - tuple(k.split("[///]")): np.array(v) - for k, v in cache_cp.items() - } - + with open(cache_file, "r", encoding="utf-8") as f: + cache = json.load(f) + self.cache = {tuple(k.split("[///]")): np.array(v) for k, v in cache.items()} class SummaCConv(torch.nn.Module): - def __init__( - self, - models=["mnli", "anli", "vitc"], - bins="even50", - granularity="sentence", - nli_labels="e", - device="cuda", - start_file=None, - imager_load_cache=True, - agg="mean", - norm_histo=False, - **kwargs, - ): - # `bins` should be `even%d` or `percentiles` - assert nli_labels in [ - "e", - "c", - "n", - "ec", - "en", - "cn", - "ecn", - ], "Unrecognized nli_labels argument %s" % (nli_labels) - - super(SummaCConv, self).__init__() - self.device = device - self.models = models - - self.imagers = [] - for model_name in models: - self.imagers.append( - SummaCImager( - model_name=model_name, granularity=granularity, **kwargs - ) - ) - if imager_load_cache: + """Compute and process SummaCConv histograms for text evaluation.""" + + def __init__(self, config: Dict[str, Union[str, bool, int, None]]): + """ + Initialize SummaCConv with a configuration dictionary. + + :param config: A dictionary containing configuration parameters. + """ + super().__init__() + self.config = SummaCConvConfig(config) + self._validate_nli_labels() + + # Initialize imagers + self.imagers = [ + SummaCImager(model_name=model_name, **config) + for model_name in self.config.models + ] + if self.config.imager_load_cache: for imager in self.imagers: imager.load_cache() - assert len(self.imagers) > 0, "Imager names were empty or unrecognized" - - if "even" in bins: - n_bins = int(bins.replace("even", "")) - self.bins = list(np.arange(0, 1, 1 / n_bins)) + [1.0] - elif bins == "percentile": - self.bins = [ - 0.0, - 0.01, - 0.02, - 0.03, - 0.04, - 0.07, - 0.13, - 0.37, - 0.90, - 0.91, - 0.92, - 0.93, - 0.94, - 0.95, - 0.955, - 0.96, - 0.965, - 0.97, - 0.975, - 0.98, - 0.985, - 0.99, - 0.995, - 1.0, - ] - - self.nli_labels = nli_labels - self.n_bins = len(self.bins) - 1 - self.norm_histo = norm_histo - self.n_rows = 10 - self.n_labels = 2 - self.n_depth = len(self.imagers) * len(self.nli_labels) - self.full_size = self.n_depth * self.n_bins - if self.norm_histo: - self.full_size += 2 - self.agg = agg - - self.mlp = torch.nn.Linear(self.full_size, 1).to(device) - self.layer_final = torch.nn.Linear(3, self.n_labels).to(device) - - if start_file is not None: - print(self.load_state_dict(torch.load(start_file))) + # Define layers + self.model_config = { + 'n_bins': len(self.config.bins) - 1, + 'n_labels': 2, + 'n_depth': len(self.imagers) * len(self.config.nli_labels), + 'full_size': (len(self.imagers) * len(self.config.nli_labels) * + (len(self.config.bins) - 1)+(2 if self.config.norm_histo else 0)) + } + self.mlp = torch.nn.Linear(self.model_config['full_size'], 1).to(self.config.device) + self.layer_final = torch.nn.Linear(3, self.model_config['n_labels']).to(self.config.device) + + if self.config.start_file: + self.load_state_dict(torch.load(self.config.start_file)) + + def _validate_nli_labels(self): + """Validate nli_labels attribute.""" + valid_labels = ["e", "c", "n", "ec", "en", "cn", "ecn"] + if self.config.nli_labels not in valid_labels: + raise ValueError(f"Unrecognized nli_labels argument {self.config.nli_labels}") def build_image(self, original, generated): - images = [ - 
imager.build_image(original, generated) for imager in self.imagers - ] - image = np.concatenate(images, axis=0) - return image + """Build an image from original and generated texts using the imagers.""" + images = [imager.build_image(original, generated) for imager in self.imagers] + return np.concatenate(images, axis=0) def compute_histogram(self, original=None, generated=None, image=None): - # Takes the two texts, and generates a (n_rows, 2*n_bins) - + """Compute histograms from image data.""" if image is None: image = self.build_image(original, generated) - N_depth, N_ori, N_gen = image.shape - + depth, num_originals, num_generations = image.shape full_histogram = [] - for i_gen in range(N_gen): - histos = [] - - for i_depth in range(N_depth): - if ( - (i_depth % 3 == 0 and "e" in self.nli_labels) - or (i_depth % 3 == 1 and "c" in self.nli_labels) - or (i_depth % 3 == 2 and "n" in self.nli_labels) - ): - histo, X = np.histogram( - image[i_depth, :, i_gen], - range=(0, 1), - bins=self.bins, - density=self.norm_histo, - ) - histos.append(histo) - - if self.norm_histo: - histos = [[N_ori, N_gen]] + histos - histogram_row = np.concatenate(histos) + + for i_gen in range(num_generations): + histograms = [ + self._compute_depth_histogram(image, i_depth, i_gen) + for i_depth in range(depth) + ] + + if self.config.norm_histo: + histograms = [[num_originals, num_generations]] + histograms + histogram_row = np.concatenate(histograms) full_histogram.append(histogram_row) - n_rows_missing = self.n_rows - len(full_histogram) - full_histogram += [[0.0] * self.full_size] * n_rows_missing - full_histogram = full_histogram[: self.n_rows] - full_histogram = np.array(full_histogram) - return image, full_histogram + num_rows_missing = self.config.n_rows - len(full_histogram) + full_histogram.extend([[0.0] * self.model_config['full_size']] * num_rows_missing) + return np.array(full_histogram[:self.config.n_rows]) + + def _compute_depth_histogram(self, image, i_depth, i_gen): + """Compute histogram for a specific depth and generation.""" + if self._should_compute_histogram(i_depth): + return np.histogram( + image[i_depth, :, i_gen], + range=(0, 1), + bins=self.config.bins, + density=self.config.norm_histo + )[0] + return np.zeros(self.model_config['n_bins']) + + def _should_compute_histogram(self, i_depth): + """Determine if histogram should be computed for given depth.""" + label = self.config.nli_labels + return ( + (i_depth % 3 == 0 and "e" in label) or + (i_depth % 3 == 1 and "c" in label) or + (i_depth % 3 == 2 and "n" in label) + ) def forward(self, originals, generateds, images=None): + """Forward pass through the model.""" + histograms = [] if images is not None: - # In case they've been pre-computed. 
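The histogram features built above come straight from np.histogram over one NLI channel of the image; a toy version with "even5" bins and invented entailment scores:

import numpy as np

N_BINS = 5
bins = list(np.arange(0, 1, 1 / N_BINS)) + [1.0]  # "even5" binning: [0, 0.2, ..., 1.0]

# Entailment scores of each document chunk against one generated sentence (invented).
entailment_scores = np.array([0.05, 0.12, 0.88, 0.91, 0.97])

histogram, edges = np.histogram(entailment_scores, range=(0, 1), bins=bins)
print(edges)      # [0.  0.2 0.4 0.6 0.8 1. ]
print(histogram)  # [2 0 0 0 3]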
- histograms = [] - for image in images: - _, histogram = self.compute_histogram(image=image) - histograms.append(histogram) + if isinstance(images, (list, tuple)): # Ensure images is iterable + histograms = [self.compute_histogram(image=image)[1] for image in images] + else: + raise ValueError("Expected 'images' to be a list or tuple of images.") else: - images, histograms = [], [] - for original, generated in zip(originals, generateds): - image, histogram = self.compute_histogram( - original=original, generated=generated - ) - images.append(image) - histograms.append(histogram) - - N = len(histograms) - histograms = torch.FloatTensor(histograms).to(self.device) - + images, histograms = zip(*[ + self.compute_histogram(original=original, generated=generated) + for original, generated in zip(originals, generateds) + ]) + histograms = list(histograms) # Ensure histograms is a list + + # Debugging information + print(f"Type of histograms before processing: {type(histograms)}") + print(f"Content of histograms before processing: {histograms}") + + # Ensure histograms is a list or tuple + if not isinstance(histograms, (list, tuple)): + raise ValueError(f"Expected 'histograms',a list or tuple, got {type(histograms)}.") + + # Convert histograms to tensor + histograms = torch.FloatTensor(histograms).to(self.config.device) non_zeros = (torch.sum(histograms, dim=-1) != 0.0).long() seq_lengths = non_zeros.sum(dim=-1).tolist() - mlp_outs = self.mlp(histograms).reshape(N, self.n_rows) - features = [] - - for mlp_out, seq_length in zip(mlp_outs, seq_lengths): - if seq_length > 0: - Rs = mlp_out[:seq_length] - if self.agg == "mean": - features.append( - torch.cat( - [ - torch.mean(Rs).unsqueeze(0), - torch.mean(Rs).unsqueeze(0), - torch.mean(Rs).unsqueeze(0), - ] - ).unsqueeze(0) - ) - elif self.agg == "min": - features.append( - torch.cat( - [ - torch.min(Rs).unsqueeze(0), - torch.min(Rs).unsqueeze(0), - torch.min(Rs).unsqueeze(0), - ] - ).unsqueeze(0) - ) - elif self.agg == "max": - features.append( - torch.cat( - [ - torch.max(Rs).unsqueeze(0), - torch.max(Rs).unsqueeze(0), - torch.max(Rs).unsqueeze(0), - ] - ).unsqueeze(0) - ) - elif self.agg == "all": - features.append( - torch.cat( - [ - torch.min(Rs).unsqueeze(0), - torch.mean(Rs).unsqueeze(0), - torch.max(Rs).unsqueeze(0), - ] - ).unsqueeze(0) - ) - else: - features.append( - torch.FloatTensor([0.0, 0.0, 0.0]).unsqueeze(0) - ) # .cuda() + mlp_outs = self.mlp(histograms).reshape(len(histograms), self.config.n_rows) + features = [ + self._compute_features(mlp_out, seq_length) + for mlp_out, seq_length in zip(mlp_outs, seq_lengths) + ] + features = torch.cat(features) logits = self.layer_final(features) - histograms_out = [histogram.cpu().numpy() for histogram in histograms] + + # Ensure histograms is iterable before using + histograms_out = [] + if isinstance(histograms, torch.Tensor): + histograms = histograms.cpu().numpy() + for histogram in histograms: + if isinstance(histogram, torch.Tensor): + histograms_out.append(histogram.cpu().numpy()) + else: + histograms_out.append(histogram) + return logits, histograms_out, images - def save_imager_cache(self): - for imager in self.imagers: + def _compute_features(self, mlp_out, seq_length): + """Compute features based on the aggregation method.""" + if seq_length > 0: + rs = mlp_out[:seq_length] + feature = self._aggregate_features(rs) + return torch.cat([feature] * 3).unsqueeze(0) + return torch.FloatTensor([0.0, 0.0, 0.0]).unsqueeze(0) + + def _aggregate_features(self, rs): + """Aggregate features 
based on the aggregation method.""" + if self.config.agg == "mean": + return torch.mean(rs).unsqueeze(0) + if self.config.agg == "min": + return torch.min(rs).unsqueeze(0) + if self.config.agg == "max": + return torch.max(rs).unsqueeze(0) + if self.config.agg == "all": + return torch.cat([ + torch.min(rs).unsqueeze(0), + torch.mean(rs).unsqueeze(0), + torch.max(rs).unsqueeze(0) + ]).unsqueeze(0) + return torch.FloatTensor([0.0, 0.0, 0.0]).unsqueeze(0) + + def save_imager_cache(self, imager): + """Save imager cache if applicable.""" + if self.config.imager_load_cache: imager.save_cache() - def score(self, originals, generateds, **kwargs): - with torch.no_grad(): - logits, histograms, images = self.forward(originals, generateds) - probs = torch.nn.functional.softmax(logits, dim=-1) - batch_scores = probs[:, 1].tolist() + def compute_scores(self, originals, generateds): + """Compute scores based on originals and generated texts.""" + logits, histograms, _ = self(originals, generateds) + return torch.softmax(logits, dim=-1), histograms + + +class SummaCZSConfig: + """ + Configuration class for SummaCZS model. + """ + model_name: str = "mnli" + granularity: str = "paragraph" + op1: str = "max" + op2: str = "mean" + use_ent: bool = True + use_con: bool = True + imager_load_cache: bool = True + device: str = "cuda" + config_dir: Optional[str] = None + + def __init__(self, **kwargs): + """ + Initialize the SummaCZSConfig with optional overrides. + + :param kwargs: Optional keyword arguments to override default values. + """ + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + raise AttributeError(f"{self.__class__.__name__} has no attribute '{key}'") + + def to_dict(self) -> dict: + """ + Convert the configuration to a dictionary. + + :return: Dictionary representation of the configuration. + """ return { - "scores": batch_scores - } # , "histograms": histograms, "images": images - - + key: value for key, value in self.__dict__.items() + if not key.startswith('_') and not callable(value) + } + + def update(self, **kwargs) -> None: + """ + Update the configuration with new values. + :param kwargs: Keyword arguments with new values to update. + """ + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + else:raise AttributeError(f"{self.__class__.__name__}has no attribute '{key}'") class SummaCZS: - def __init__( - self, - model_name="mnli", - granularity="paragraph", - op1="max", - op2="mean", - use_ent=True, - use_con=True, - imager_load_cache=True, - device="cuda", - args=None, - **kwargs, - ): - global model_map - with open( - os.path.join(args.config_dir, "summac_model.json"), "r" - ) as f: - model_map = json.load(f) - assert op2 in ["min", "mean", "max"], "Unrecognized `op2`" - assert op1 in ["max", "mean", "min"], "Unrecognized `op1`" + """ + Class to handle SummaCZS model operations including image generation and scoring. + + Attributes: + config (SummaCZSConfig): Configuration object with parameters. + """ + def __init__(self, config: SummaCZSConfig): + """ + Initialize the SummaCZS class with the given configuration. + + :param config: Configuration object with parameters. 
+ """ + self.config = config + self.model_map = self._load_model_map(config.config_dir) + self._validate_operations(config.op1, config.op2) self.imager = SummaCImager( - model_name=model_name, - granularity=granularity, - device=device, - **kwargs, + model_name=config.model_name, + granularity=config.granularity, + device=config.device, ) - if imager_load_cache: + if config.imager_load_cache: self.imager.load_cache() - self.op2 = op2 - self.op1 = op1 - self.use_ent = use_ent - self.use_con = use_con + + self.op2 = config.op2 + self.op1 = config.op1 + self.use_ent = config.use_ent + self.use_con = config.use_con + + def _load_model_map(self, config_dir: Optional[str]) -> Dict: + """Load model configuration from a JSON file.""" + if config_dir is None: + raise ValueError("config_dir must be specified") + model_map_path = os.path.join(config_dir, "summac_model.json") + with open(model_map_path, "r", encoding="utf-8") as f: + return json.load(f) + + def _validate_operations(self, op1: str, op2: str): + """Validate the operations provided for scoring.""" + valid_ops = ["min", "mean", "max"] + if op1 not in valid_ops: + raise ValueError(f"Unrecognized `op1`: {op1}. Must be one of {valid_ops}.") + if op2 not in valid_ops: + raise ValueError(f"Unrecognized `op2`: {op2}. Must be one of {valid_ops}.") def save_imager_cache(self): + """Save the imager cache.""" self.imager.save_cache() - def score_one(self, original, generated): + def score_one(self, original: str, generated: str) -> Dict[str, float]: + """ + Compute the score for a single pair of original and generated text. + + :param original: Original text. + :param generated: Generated text. + :return: Dictionary with the score and image. + """ image = self.imager.build_image(original, generated) ent_scores = np.max(image[0], axis=0) @@ -514,6 +645,8 @@ def score_one(self, original, generated): scores = ent_scores elif self.use_con: scores = 1.0 - co_scores + else: + scores = np.zeros_like(ent_scores) # Ensure `scores` is defined if no condition is met final_score = np.mean(scores) if self.op2 == "min": @@ -523,7 +656,14 @@ def score_one(self, original, generated): return {"score": final_score, "image": image} - def score(self, sources, generateds, **kwargs): + def score(self, sources: List[str], generateds: List[str]) -> Dict[str, List[float]]: + """ + Compute scores for multiple pairs of original and generated text. + + :param sources: List of original texts. + :param generateds: List of generated texts. + :return: Dictionary with lists of scores and images. + """ output = {"scores": [], "images": []} for source, gen in zip(sources, generateds): score = self.score_one(source, gen) diff --git a/src/melt/tools/metrics/summac/utils_misc.py b/src/melt/tools/metrics/summac/utils_misc.py index 7df421c..d6496f0 100644 --- a/src/melt/tools/metrics/summac/utils_misc.py +++ b/src/melt/tools/metrics/summac/utils_misc.py @@ -1,49 +1,91 @@ -############################################### -# Source: https://github.com/tingofurro/summac -############################################### +""" +This module contains utility functions for GPU management and batch processing. +""" -import numpy as np -import tqdm import os import time +import numpy as np -# GPU-related business - +# Ensure tqdm library is installed in your environment +try: + import tqdm +except ImportError as exc: + ERROR_MESSAGE = ( + "The 'tqdm' library is not installed. " + "Please install it using 'pip install tqdm'." 
+ ) + raise ImportError(ERROR_MESSAGE) from exc def get_freer_gpu(): - os.system("nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp_smi") - memory_available = [ - int(x.split()[2]) + 5 * i - for i, x in enumerate(open("tmp_smi", "r").readlines()) - ] + """ + Retrieves the index of the GPU with the most free memory. + + Returns: + int: The index of the GPU with the most free memory. + """ + os.system("nvidia-smi -q -d Memory | grep -A4 GPU | grep Free > tmp_smi") + with open("tmp_smi", "r", encoding='utf-8') as file: + memory_available = [ + int(x.split()[2]) + 5 * i + for i, x in enumerate(file.readlines()) + ] os.remove("tmp_smi") return np.argmax(memory_available) - def any_gpu_with_space(gb_needed): - os.system("nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp_smi") - memory_available = [ - float(x.split()[2]) / 1024.0 - for i, x in enumerate(open("tmp_smi", "r").readlines()) - ] - os.remove("tmp_smi") - return any([mem >= gb_needed for mem in memory_available]) + """ + Checks if there is any GPU with the required amount of free memory. + + Args: + gb_needed (float): The amount of GPU memory needed in GB. + Returns: + bool: True if any GPU has the required amount of free memory, False otherwise. + """ + os.system("nvidia-smi -q -d Memory | grep -A4 GPU | grep Free > tmp_smi") + with open("tmp_smi", "r", encoding='utf-8') as file: + memory_available = [ + float(x.split()[2]) / 1024.0 + for x in file.readlines() + ] + os.remove("tmp_smi") + return any(mem >= gb_needed for mem in memory_available) def wait_free_gpu(gb_needed): + """ + Waits until a GPU with the required amount of free memory is available. + + Args: + gb_needed (float): The amount of GPU memory needed in GB. + """ while not any_gpu_with_space(gb_needed): time.sleep(30) - def select_freer_gpu(): + """ + Selects the GPU with the most free memory and sets it as the visible device. + + Returns: + str: The index of the selected GPU. + """ freer_gpu = str(get_freer_gpu()) - print("Will use GPU: %s" % (freer_gpu)) + print(f"Will use GPU: {freer_gpu}") os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - os.environ["CUDA_VISIBLE_DEVICES"] = "" + freer_gpu + os.environ["CUDA_VISIBLE_DEVICES"] = freer_gpu return freer_gpu - def batcher(iterator, batch_size=16, progress=False): + """ + Batches an iterator into smaller chunks. + + Args: + iterator (iterable): The iterable to batch. + batch_size (int): The size of each batch. + progress (bool): If True, shows a progress bar. + + Yields: + list: A batch of items from the iterator. + """ if progress: iterator = tqdm.tqdm(iterator) @@ -51,8 +93,7 @@ def batcher(iterator, batch_size=16, progress=False): for elem in iterator: batch.append(elem) if len(batch) == batch_size: - final_batch = batch + yield batch batch = [] - yield final_batch - if len(batch) > 0: # Leftovers + if batch: # Yield remaining items yield batch diff --git a/tests/test_execution.py b/tests/test_execution.py index ac64098..dd28775 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -3,94 +3,93 @@ class TestTasks(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(TestTasks, self).__init__(*args, **kwargs) + """ + Unit tests for various tasks using the melt command-line tool. + """ + + def setUp(self): + """ + Set up test parameters that are used across all test cases. 
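+
+        Each test delegates to ``run_melt_command``, which is roughly
+        equivalent to the following call (dataset name varies per test)::
+
+            subprocess.run(
+                ["melt", "--wtype", "vllm",
+                 "--model_name", "Qwen/Qwen2-0.5B-Instruct",
+                 "--dataset_name", "UIT-VSFC",
+                 "--ptemplate", "chatglm",
+                 "--lang", "vi", "--seed", "42", "--smoke_test", "True"],
+                capture_output=True, text=True,
+            )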
+ """ self.model_name = "Qwen/Qwen2-0.5B-Instruct" self.ptemplate = "chatglm" self.wrapper_type = "vllm" - self.lang = "vi" # Set the lang argument to "vi" - self.seed = 42 # Set the seed to 42 - self.smoke_test = True # Set the smoke_test argument to True + self.lang = "vi" + self.seed = 42 + self.smoke_test = True def run_melt_command(self, dataset_name): - result = subprocess.run( - [ - "melt", - "--wtype", - self.wrapper_type, - "--model_name", - self.model_name, - "--dataset_name", - dataset_name, - "--ptemplate", - self.ptemplate, - "--lang", - self.lang, - "--seed", - str(self.seed), - "--smoke_test", - str(self.smoke_test), - ], - capture_output=True, - text=True, - ) - self.assertEqual(result.returncode, 0) + """ + Run the melt command with given dataset name and verify it executes successfully. + + Args: + dataset_name (str): Name of the dataset to use with the melt command. + + Raises: + AssertionError: If the command fails with a non-zero exit code. + """ + command = [ + "melt", + "--wtype", self.wrapper_type, + "--model_name", self.model_name, + "--dataset_name", dataset_name, + "--ptemplate", self.ptemplate, + "--lang", self.lang, + "--seed", str(self.seed), + "--smoke_test", str(self.smoke_test) + ] + + result = subprocess.run(command, capture_output=True, text=True) + + # Provide detailed error information if the command fails + if result.returncode != 0: + self.fail(f"Command failed for dataset '{dataset_name}' with exit code {result.returncode}\n" + f"stdout: {result.stdout}\n" + f"stderr: {result.stderr}") def test_sentiment_analysis(self): - # Test sentiment analysis task - dataset_name = "UIT-VSFC" - self.run_melt_command(dataset_name) + """Test sentiment analysis task.""" + self.run_melt_command("UIT-VSFC") def test_text_classification(self): - # Test text classification task - dataset_name = "UIT-VSMEC" - self.run_melt_command(dataset_name) + """Test text classification task.""" + self.run_melt_command("UIT-VSMEC") def test_toxic_detection(self): - # Test toxic detection task - dataset_name = "ViHSD" - self.run_melt_command(dataset_name) + """Test toxic detection task.""" + self.run_melt_command("ViHSD") def test_reasoning(self): - # Test reasoning task - dataset_name = "synthetic_natural_azr" - self.run_melt_command(dataset_name) + """Test reasoning task.""" + self.run_melt_command("synthetic_natural_azr") def test_open_ended_knowledge(self): - # Test open-ended knowledge task - dataset_name = "zalo_e2eqa" - self.run_melt_command(dataset_name) + """Test open-ended knowledge task.""" + self.run_melt_command("zalo_e2eqa") def test_multiple_choice_knowledge(self): - # Test multiple choice knowledge task - dataset_name = "ViMMRC" - self.run_melt_command(dataset_name) + """Test multiple choice knowledge task.""" + self.run_melt_command("ViMMRC") def test_math(self): - # Test math task - dataset_name = "math_level1_azr" - self.run_melt_command(dataset_name) + """Test math task.""" + self.run_melt_command("math_level1_azr") def test_translation(self): - # Test translation task - dataset_name = "opus100_envi" - self.run_melt_command(dataset_name) + """Test translation task.""" + self.run_melt_command("opus100_envi") def test_summarization(self): - # Test summarization task - dataset_name = "wiki_lingua" - self.run_melt_command(dataset_name) + """Test summarization task.""" + self.run_melt_command("wiki_lingua") def test_question_answering(self): - # Test question answering task - dataset_name = "xquad_xtreme" - self.run_melt_command(dataset_name) + """Test question 
answering task.""" + self.run_melt_command("xquad_xtreme") def test_information_retrieval(self): - # Test information retrieval task - dataset_name = "mmarco" - self.run_melt_command(dataset_name) - + """Test information retrieval task.""" + self.run_melt_command("mmarco") if __name__ == "__main__": unittest.main() diff --git a/tests/test_wrapper.py b/tests/test_wrapper.py index 73b1450..8931fb7 100644 --- a/tests/test_wrapper.py +++ b/tests/test_wrapper.py @@ -3,62 +3,82 @@ class TestWrapper(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(TestWrapper, self).__init__(*args, **kwargs) - self.model_name = "Qwen/Qwen2-0.5B-Instruct" - self.ptemplate = "chatglm" - self.lang = "vi" # Set the lang argument to "vi" - self.seed = 42 # Set the seed to 42 - self.smoke_test = True # Set the smoke_test argument to True + """ + Unit tests for various wrappers used with the melt command-line tool. + """ + + @classmethod + def setUpClass(cls): + """ + Set up class-wide parameters used for testing different wrappers. + """ + cls.model_name = "Qwen/Qwen2-0.5B-Instruct" + cls.ptemplate = "chatglm" + cls.lang = "vi" + cls.seed = 42 + cls.smoke_test = True + + def build_command(self, dataset_name, wrapper_type): + """ + Construct the melt command with the given parameters. + + Args: + dataset_name (str): Name of the dataset. + wrapper_type (str): Type of the wrapper to use. + + Returns: + list: Command arguments to be passed to subprocess.run. + """ + return [ + "melt", + "--wtype", wrapper_type, + "--model_name", self.model_name, + "--dataset_name", dataset_name, + "--ptemplate", self.ptemplate, + "--lang", self.lang, + "--seed", str(self.seed), + "--smoke_test", str(self.smoke_test) + ] def run_melt_command(self, dataset_name, wrapper_type): - result = subprocess.run( - [ - "melt", - "--wtype", - wrapper_type, - "--model_name", - self.model_name, - "--dataset_name", - dataset_name, - "--ptemplate", - self.ptemplate, - "--lang", - self.lang, - "--seed", - str(self.seed), - "--smoke_test", - str(self.smoke_test), - ], - capture_output=True, - text=True, - ) - self.assertEqual(result.returncode, 0) + """ + Run the melt command with specified dataset and wrapper type, and check for success. + + Args: + dataset_name (str): Name of the dataset. + wrapper_type (str): Type of the wrapper to use. + + Raises: + AssertionError: If the command fails with a non-zero exit code. 
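+
+        Example (mirrors the wrapper tests below, which all use the same
+        dataset)::
+
+            self.run_melt_command("zalo_e2eqa", "hf")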
+ """ + command = self.build_command(dataset_name, wrapper_type) + result = subprocess.run(command, capture_output=True, text=True) + + if result.returncode != 0: + self.fail(f"Command failed for dataset '{dataset_name}' with wrapper '{wrapper_type}'\n" + f"Exit code: {result.returncode}\n" + f"stdout: {result.stdout}\n" + f"stderr: {result.stderr}") def test_wrapper_hf(self): - # Test wrapper hf - dataset_name = "zalo_e2eqa" - self.run_melt_command(dataset_name, "hf") + """Test hf wrapper.""" + self.run_melt_command("zalo_e2eqa", "hf") def test_wrapper_tgi(self): - # Test wrapper tgi - dataset_name = "zalo_e2eqa" - self.run_melt_command(dataset_name, "tgi") + """Test tgi wrapper.""" + self.run_melt_command("zalo_e2eqa", "tgi") def test_wrapper_gemini(self): - # Test wrapper gemini - dataset_name = "zalo_e2eqa" - self.run_melt_command(dataset_name, "gemini") + """Test gemini wrapper.""" + self.run_melt_command("zalo_e2eqa", "gemini") def test_wrapper_openai(self): - # Test wrapper openai - dataset_name = "zalo_e2eqa" - self.run_melt_command(dataset_name, "openai") + """Test openai wrapper.""" + self.run_melt_command("zalo_e2eqa", "openai") def test_wrapper_vllm(self): - # Test wrapper vllm - dataset_name = "zalo_e2eqa" - self.run_melt_command(dataset_name, "vllm") + """Test vllm wrapper.""" + self.run_melt_command("zalo_e2eqa", "vllm") if __name__ == "__main__":