From 3bf32744e42ada0483f6d9de9f70536c301847c6 Mon Sep 17 00:00:00 2001 From: Syed Ali Mohsin Bukhari Date: Thu, 28 Nov 2024 11:22:44 +0500 Subject: [PATCH 01/11] [testing-minor-1] 1. added error docstrings to `eIO.py`, `eList.py`, `eOS.py`and `ePlotting.py` 2. added function docstrings to `uDict.py` and `uList.py` --- src/mpyez/backend/eIO.py | 18 ++- src/mpyez/backend/eList.py | 60 +++++++++ src/mpyez/backend/eOS.py | 12 ++ src/mpyez/backend/ePlotting.py | 29 ++++- src/mpyez/backend/uDict.py | 128 +++++++++++++++++--- src/mpyez/backend/uList.py | 214 +++++++++++++++++++++++++++------ 6 files changed, 399 insertions(+), 62 deletions(-) diff --git a/src/mpyez/backend/eIO.py b/src/mpyez/backend/eIO.py index c7f9424..2ade60c 100644 --- a/src/mpyez/backend/eIO.py +++ b/src/mpyez/backend/eIO.py @@ -2,10 +2,24 @@ class EzFileErrs(Exception): - """Base class for exceptions in EzFile.""" + """ + Base class for exceptions in EzFile. + + Notes + ----- + This serves as the base class for all exceptions related to file operations + in the EzFile module. Specific exceptions should inherit from this class. + """ pass class LineNumberOutOfBounds(EzFileErrs): - """Exception raised when a line number is out of bounds.""" + """ + Raised when a specified line number is out of the bounds of the file. + + Notes + ----- + This error occurs when an operation attempts to access a line number + that exceeds the number of lines in the file or is less than 1. + """ pass diff --git a/src/mpyez/backend/eList.py b/src/mpyez/backend/eList.py index 11b373a..afe101d 100644 --- a/src/mpyez/backend/eList.py +++ b/src/mpyez/backend/eList.py @@ -2,32 +2,92 @@ class EzListErrs(Exception): + """ + Base class for custom exceptions related to EzList operations. + All specific errors inherit from this class. + """ pass class AlphabetFound(EzListErrs): + """ + Raised when an unexpected alphabet character is found in the list. + + Notes + ----- + This error is typically raised when numeric operations are attempted + on a list that contains alphabetic strings. + """ pass class IndexOutOfList(EzListErrs): + """ + Raised when an index is out of the valid range of the list. + + Notes + ----- + This error occurs when an attempt is made to access or modify a list + using an index that exceeds the bounds of the list. + """ pass class UnequalElements(EzListErrs): + """ + Raised when lists have unequal elements where equality is expected. + + Notes + ----- + This error may be raised when operations requiring equal-length lists + or matching elements are attempted but the lists differ in size or content. + """ pass class GotAnUnknownValue(EzListErrs): + """ + Raised when an unknown or unexpected value is encountered. + + Notes + ----- + This error is raised when a list contains a value that does not conform + to the expected data type or range. + """ pass class ChildListLengthError(EzListErrs): + """ + Raised when a child list has an invalid length. + + Notes + ----- + This error is raised in scenarios where nested or child lists are + expected to meet specific length constraints but fail to do so. + """ pass class StringNotPassed(EzListErrs): + """ + Raised when a string input is expected but not provided. + + Notes + ----- + This error is raised when a function or operation requires a string + input but receives a non-string value instead. + """ pass class InvalidInputParameter(EzListErrs): + """ + Raised when an invalid parameter is passed to a function or method. + + Notes + ----- + This error indicates that one or more input arguments to a function + do not meet the expected type, range, or format. + """ pass diff --git a/src/mpyez/backend/eOS.py b/src/mpyez/backend/eOS.py index 292ac01..4852d06 100644 --- a/src/mpyez/backend/eOS.py +++ b/src/mpyez/backend/eOS.py @@ -2,8 +2,20 @@ class EzOsErrs(Exception): + """ + Base class for custom exceptions related to EzOs operations. + All specific errors inherit from this class. + """ pass class FileNotPresent(EzOsErrs): + """ + Raised when a specified file is not found in the expected location. + + Notes + ----- + This error is typically raised when a file operation (e.g., reading or writing) + is attempted on a file that does not exist. + """ pass diff --git a/src/mpyez/backend/ePlotting.py b/src/mpyez/backend/ePlotting.py index 8a77af6..d94bf32 100644 --- a/src/mpyez/backend/ePlotting.py +++ b/src/mpyez/backend/ePlotting.py @@ -2,15 +2,38 @@ class PlotError(Exception): - """Basic PlotError class""" + """ + Base class for exceptions related to plotting operations. + + Notes + ----- + This serves as the parent class for all plotting-related errors. + Specific exceptions related to plot configuration or data issues + should inherit from this class. + """ pass class NoXYLabels(PlotError): - """Custom class for missing x or y labels""" + """ + Raised when x or y labels are missing in a plot. + + Notes + ----- + This error occurs when a plot is expected to have labels for both + the x-axis and y-axis, but one or both are missing. Proper labeling + is often required for clarity in visualizations. + """ pass class OrientationError(PlotError): - """Custom class for wrong orientation""" + """ + Raised when an invalid or unexpected orientation is used in a plot. + + Notes + ----- + This error occurs when the orientation parameter for a plot is set + incorrectly or does not match the expected format. + """ pass diff --git a/src/mpyez/backend/uDict.py b/src/mpyez/backend/uDict.py index e444ac8..80c22f3 100644 --- a/src/mpyez/backend/uDict.py +++ b/src/mpyez/backend/uDict.py @@ -1,41 +1,131 @@ """Created on Aug 17 23:51:58 2022.""" +from typing import Any, Dict, List, Union -def change_value_to_list(input_dictionary): + +def change_value_to_list(input_dictionary: Dict[Any, Any]) -> Dict[Any, List[Any]]: + """ + Converts all non-list values in a dictionary to lists. + + Parameters + ---------- + input_dictionary : dict + The dictionary whose values need to be converted to lists. + + Returns + ------- + dict + A dictionary where all values are guaranteed to be lists. + """ for key, value in input_dictionary.items(): if not isinstance(value, list): - input_dictionary[key] = [input_dictionary[key]] - + input_dictionary[key] = [value] return input_dictionary -def change_list_to_values(input_dictionary): - for key, value in input_dictionary.items(): - if len(input_dictionary[key]) == 1: - input_dictionary[key] = input_dictionary[key][0] +def change_list_to_values(input_dictionary: Dict[Any, List[Any]]) -> Dict[Any, Union[Any, List[Any]]]: + """ + Converts single-element lists in a dictionary back to their single values. + Parameters + ---------- + input_dictionary : dict + The dictionary whose single-element list values need to be simplified. + + Returns + ------- + dict + A dictionary where single-element lists are replaced by their sole values. + """ + for key, value in input_dictionary.items(): + if len(value) == 1: + input_dictionary[key] = value[0] return input_dictionary class PrettyPrint: - def __init__(self, input_dictionary: dict): + """ + A class for displaying dictionaries in a tabular format. + + Parameters + ---------- + input_dictionary : dict + The dictionary to be formatted and displayed. + column_width : int, optional + Custom width for table columns (default is dynamically calculated). + alignment : str, optional + Alignment for table cells: 'left', 'center', or 'right' (default is 'center'). + + Attributes + ---------- + inp_dict : dict + The input dictionary stored for formatting and display. + column_width : int + Width of each column in the table. + alignment : str + Alignment configuration for table cells. + """ + + def __init__(self, input_dictionary: Dict[Any, Any], column_width: int = None, alignment: str = "center"): self.inp_dict = input_dictionary + self.column_width = column_width + self.alignment = alignment + + def __get_max_width(self) -> int: + """ + Computes the maximum column width for formatting if not provided. + + Returns + ------- + int + The maximum width for the table columns. + """ + if self.column_width: + return self.column_width - def __get_max_width(self): value_widths = [len(str(value)) for value in self.inp_dict.values()] max_width = max(value_widths) return max(max_width + 1, 71) if max_width % 2 == 0 else max(max_width, 71) - def __str__(self): - max_width = self.__get_max_width() + def __align_text(self, text: str, width: int) -> str: + """ + Aligns text within a given width based on the alignment setting. - width = (max_width - 1) // 2 - 1 - pline = '-' * max_width + '\n' + Parameters + ---------- + text : str + The text to align. + width : int + The width of the column. + + Returns + ------- + str + The aligned text. + """ + if self.alignment == "left": + return text.ljust(width) + elif self.alignment == "right": + return text.rjust(width) + else: # Default to center + return text.center(width) + + def __str__(self) -> str: + """ + Returns the formatted string representation of the dictionary. + + Returns + ------- + str + A tabular representation of the dictionary with enhanced aesthetics. + """ + max_width = self.__get_max_width() + column_width = (max_width - 1) // 2 - 1 + separator_line = "+" + "-" * column_width + "+" + "-" * column_width + "+\n" - out = pline - out += f"|{'dict_key'.center(width)}|{'dict_value'.center(width)}|\n" - out += pline - out += '\n'.join([f"|{k.center(width)}|{v.center(width)}|" for k, v in self.inp_dict.items()]) + '\n' - out += pline + header = f"|{self.__align_text('dict_key', column_width)}|{self.__align_text('dict_value', column_width)}|\n" + rows = '\n'.join([f"|{self.__align_text(str(key), column_width)}|" + f"{self.__align_text(str(value), column_width)}|" + for key, value in self.inp_dict.items()]) - return out + return separator_line + header + separator_line + rows + '\n' + separator_line diff --git a/src/mpyez/backend/uList.py b/src/mpyez/backend/uList.py index 35b5152..159eb95 100644 --- a/src/mpyez/backend/uList.py +++ b/src/mpyez/backend/uList.py @@ -2,86 +2,224 @@ from copy import deepcopy from itertools import compress -from typing import List, Union +from typing import Any, Dict, List, Union from . import eList as eL -from .eList import GotAnUnknownValue, IndexOutOfList, UnequalElements -def equalizing_list_length(primary_list, secondary_list, names): - primary_len = len(primary_list) - secondary_len = len(secondary_list) +def equalizing_list_length(primary_list: List, secondary_list: List) -> List: + """ + Adjusts the length of the secondary list to match the length of the primary list. + + If the secondary list is shorter, it is extended by repeating its elements cyclically. + If the secondary list is longer, an error is raised. - if secondary_len > primary_len: - raise UnequalElements(f'The number of elements in the {names[0]} list is greater than that of {names[1]}. ' - f'Cannot perform replacement in this case.') - elif secondary_len < primary_len: - secondary_list += secondary_list[:primary_len - secondary_len] + Parameters + ---------- + primary_list : list + The reference list whose length needs to be matched. + secondary_list : list + The list to be adjusted to the length of the primary list. + + Returns + ------- + list + The adjusted secondary list, with a length equal to the primary list. + + Raises + ------ + UnequalElements + If the secondary list has more elements than the primary list. + """ + primary_length = len(primary_list) + secondary_length = len(secondary_list) + + if secondary_length > primary_length: + raise eL.UnequalElements(f"The secondary list ({secondary_length} elements) is longer than the primary list ({primary_length} elements).") + elif secondary_length < primary_length: + # Extend secondary_list cyclically to match the length of primary_list + secondary_list += secondary_list[:primary_length - secondary_length] return secondary_list -def replace_at_index(input_list, index, value, new_list=False): - index, value = map(lambda x: [x] if not isinstance(x, list) else x, (index, value)) - value += value[:len(index) - len(value)] +def replace_at_index(input_list: List, index: Union[int, List[int]], value: Union[Any, List[Any]], new_list: bool = False) -> List: + """ + Replaces elements in a list at specified indices with new values. - for i, v in zip(index, value): - if i > len(input_list) - 1: - raise IndexOutOfList(f'Index {i} is out of bound for a list of length {len(input_list)}.') + Parameters + ---------- + input_list : list + The original list whose elements need to be replaced. + index : int or list of int + The index or indices of the elements to replace. + value : any or list of any + The new value(s) to insert at the specified index/indices. + new_list : bool, optional + If True, returns a modified copy of the original list. If False, modifies + the list in place (default is False). + + Returns + ------- + list + The modified list with replaced values. + + Raises + ------ + IndexOutOfList + If any index in `index` is out of bounds for the input list. + ValueError + If the number of indices does not match the number of values. + + Examples + -------- + >>> input_ = [1, 2, 3, 4] + >>> replace_at_index(input_, [1, 3], [9, 10]) + [1, 9, 3, 10] + + >>> replace_at_index([1, 2, 3, 4], 2, 99) + [1, 2, 99, 4] + """ + # Ensure index and value are lists for uniform processing + index = [index] if not isinstance(index, list) else index + value = [value] if not isinstance(value, list) else value + + # Check if the number of indices matches the number of values + if len(index) != len(value): + raise ValueError(f"The number of indices ({len(index)}) must match the number of values ({len(value)}).") + + # Validate indices are within bounds + for i in index: + if i >= len(input_list) or i < 0: + raise eL.IndexOutOfList(f"Index {i} is out of bounds for a list of length {len(input_list)}.") + # Create a new list if requested if new_list: - input_list = input_list[:] + input_list = deepcopy(input_list) + # Replace elements at the specified indices for i, v in zip(index, value): input_list[i] = v return input_list -def replace_element(input_list, old_element, new_element, new_list=False): - old_element, new_element = map(lambda x: [x] if not isinstance(x, list) else x, (old_element, new_element)) - new_element += new_element[:len(old_element) - len(new_element)] +def replace_element(input_list: List[Union[int, float, str]], + old_elements: Union[List[Union[int, float, str]], Union[int, float, str]], + new_elements: Union[List[Union[int, float, str]], Union[int, float, str]], + new_list: bool = False) -> List[Union[int, float, str]]: + """ + Replaces elements in a list with new values at corresponding indices. - index = [] - for i, old in enumerate(old_element): + Parameters + ---------- + input_list : list of int, float, or str + The original list in which elements will be replaced. + old_elements : int, float, str, or list of int, float, str + The element(s) to be replaced in the input_list. + new_elements : int, float, str, or list of int, float, str + The new value(s) to replace the old_elements. + new_list : bool, optional + If True, returns a modified copy of the original list. If False, modifies + the list in place (default is False). + + Returns + ------- + list of int, float, or str + The modified list with the replaced elements. + + Raises + ------ + GotAnUnknownValue + If any value in old_elements does not exist in the input_list. + UnequalElements + If old_elements and new_elements lists have different lengths. + + Notes + ----- + - The lengths of old_elements and new_elements must match if they are provided as lists. + - If a single element is provided in old_elements or new_elements, it will be applied to all occurrences of old_elements in input_list. + """ + if not isinstance(old_elements, list): + old_elements = [old_elements] + if not isinstance(new_elements, list): + new_elements = [new_elements] + + if len(old_elements) != len(new_elements): + raise eL.UnequalElements(f'The number of elements in old_elements ({len(old_elements)}) does not match ' + f'the number of elements in new_elements ({len(new_elements)}).') + + indices = [] + for old in old_elements: if old not in input_list: - raise GotAnUnknownValue(f'The value {old} given in old_element does not exist in the input_list.') - index.append(input_list.index(old)) + raise eL.GotAnUnknownValue(f'The value {old} given in old_elements does not exist in the input_list.') + indices.append(input_list.index(old)) if new_list: input_list = input_list[:] - for i, new in zip(index, new_element): + for i, new in zip(indices, new_elements): input_list[i] = new return input_list class CountObjectsInList: + """Class to count objects in the given list.""" - def __init__(self, counter_dict): + def __init__(self, counter_dict: Dict[Union[str, int], int]): + """ + Initialize the CountObjectsInList with a dictionary containing items and their counts. + + Parameters + ---------- + counter_dict : dict + A dictionary where keys are items (could be strings or other types), + and values are their corresponding counts. + """ self.counter_dict = counter_dict self.__counter_dict = sorted(self.counter_dict.items(), key=lambda x: x[1], reverse=True) - self.counter = 0 + def __str__(self) -> str: + """Return a formatted string representing the counts of the objects in the list. The items and their counts are displayed in a table format. - def __str__(self): + Returns + ------- + str + A string representation of the object with formatted counts. + """ out = '-' * 50 + '\n' out += f'|{"items":^30}|{"counts":^17}|\n' out += '-' * 50 + '\n' - out += '\n'.join([f'|{key:^30}|{value:^17}|' - if not isinstance(key, str) - else f"|\'{key}\':^30|{value:^17}|" - for key, value in self.counter_dict.items()]) + '\n' + + for key, value in self.counter_dict.items(): + if isinstance(key, str): + out += f"|\'{key}\':^30|{value:^17}|\n" + else: + out += f"|{key:^30}|{value:^17}|\n" out += '-' * 50 + '\n' return out - def __getitem__(self, item): - _get = self.__counter_dict[item] - try: - return CountObjectsInList({element[0]: element[1] for element in _get}) - except TypeError: - return CountObjectsInList({_get[0]: _get[1]}) + def __getitem__(self, item: int) -> 'CountObjectsInList': + """ + Retrieve a specific item from the sorted counter list and return a new CountObjectsInList instance. + + Parameters + ---------- + item : int + The index of the item in the sorted counter list. + + Returns + ------- + CountObjectsInList + A new CountObjectsInList instance with the corresponding item and its count. + """ + if item < 0 or item >= len(self.__counter_dict): + raise IndexError("Index out of bounds.") + + selected_item = self.__counter_dict[item] + + return CountObjectsInList({selected_item[0]: selected_item[1]}) def numeric_list_to_string(num_list: List[int]) -> List[str]: From 569b4e82b5a5aebe9c7ca6127f0c1cfd1626a9e8 Mon Sep 17 00:00:00 2001 From: Syed Ali Mohsin Bukhari Date: Thu, 28 Nov 2024 13:24:37 +0500 Subject: [PATCH 02/11] [mp-minor-1] 1. initial implementation of `ezMultiprocessing.py` 2. updated version to match with the main branch --- src/mpyez/ezMultiprocessing.py | 106 +++++++++++++++++++++++++++++++++ src/mpyez/version.py | 3 +- 2 files changed, 108 insertions(+), 1 deletion(-) diff --git a/src/mpyez/ezMultiprocessing.py b/src/mpyez/ezMultiprocessing.py index 70e15e9..824e6ff 100644 --- a/src/mpyez/ezMultiprocessing.py +++ b/src/mpyez/ezMultiprocessing.py @@ -1 +1,107 @@ """Created on Jun 12 13:48:59 2024""" + +from collections import defaultdict +from multiprocessing import Pool +from typing import Any, Callable, Dict, List + +import numpy as np + + +def _reshape(lst: List[Any], n: int) -> List[List[Any]]: + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i:i + n] + + +class MultiProcessor: + """Generalized multiprocessing functionality.""" + + def __init__(self, func: Callable[..., Dict], args: List[List[Any]], n_processors: int = 3, result_handler: Callable = None): + """ + Initialize the `MultiProcessor` class. + + Parameters + ---------- + func: Callable + Function to be processed. + args: list of lists + A list of arguments to be passed onto the given function. + n_processors: int, optional + The number of processors to use for multiprocessing. Default is 3. + result_handler: Callable, optional + A function to handle results after processing. Defaults to `defaultdict(list)`. + + Raises + ------ + ValueError + If `n_processors` is less than 1, or if `args` is malformed. + """ + if n_processors < 1: + raise ValueError("`n_processors` must be at least 1.") + if not all(isinstance(arg, list) for arg in args): + raise ValueError("Each element of `args` must be a list.") + + self.func = func + self.args = args + self.n_proc = n_processors + self.n_values = len(self.args[0]) + self.result_handler = result_handler or (lambda results: defaultdict(list)) + + def run(self) -> Dict: + # Reshape the arguments to split the work across processors + reshaped_args = np.array([list(_reshape(i, self.n_proc)) for i in self.args]) + reshaped_args = reshaped_args.transpose() + flattened_args = [list(chunk) for chunk in reshaped_args] + + with Pool(self.n_proc) as pool: + try: + # Use starmap to correctly pass the arguments to the function + results = pool.starmap(self.func, flattened_args) + except Exception as e: + raise RuntimeError(f"Error occurred during parallel execution: {e}") + + new_dict = self.result_handler(results) + return new_dict + + +def sine_wave(t_low, t_high, f, a, phi): + """Generate a sine wave. + + Parameters + ---------- + t_low: + The start time for the wave. + t_high: + The end time for the wave. + f: + The frequency for the wave. + a: + The amplitude of the wave. + phi: + The phase of the wave. + + Returns + ------- + dict + A dictionary with sine wave values for each time `t`. + """ + t = np.linspace(t_low, t_high, 1000) + return {f"wave_{i}": a * np.sin(2 * np.pi * f * t[i] + phi) for i in range(len(t))} + + +# Time and other parameters for the waves +f_ = np.linspace(0, 1, 16).tolist() # Frequencies for the waves (in Hz) +a_ = np.linspace(0, 1, 16).tolist() # Amplitudes for each wave +phi_ = np.linspace(-2 * np.pi, 2 * np.pi, 16).tolist() # Phases for the waves +t_l = np.repeat(0, len(f_)).tolist() # Start time for each sine wave +t_h = np.repeat(1, len(f_)).tolist() # End time for each sine wave + +# Arguments to be passed to the function +arguments = [t_l, t_h, f_, a_, phi_] + +# Create the MultiProcessor instance and run it +processor = MultiProcessor(func=sine_wave, args=arguments, n_processors=4) +result = processor.run() + +# Print the results +print(result) diff --git a/src/mpyez/version.py b/src/mpyez/version.py index 054cc11..8e2e5d5 100644 --- a/src/mpyez/version.py +++ b/src/mpyez/version.py @@ -1,6 +1,7 @@ """Created on Nov 02 21:24:07 2024""" -__version__ = "0.0.9a0" +__name__ = 'mpyez' +__version__ = "0.0.9" __author__ = "Syed Ali Mohsin Bukhari" __email__ = "syedali.b@outlook.com" __license__ = "MIT" From e20467f419a51221b3470c1f0c7d7d350b9c12b2 Mon Sep 17 00:00:00 2001 From: Syed Ali Mohsin Bukhari Date: Sat, 30 Nov 2024 10:29:53 +0500 Subject: [PATCH 03/11] [mp-minor-1-1] 1. removed numpy strict dependency --- environment.yaml | 6 +----- pyproject.toml | 2 +- requirements.txt | 3 ++- setup.py | 3 +-- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/environment.yaml b/environment.yaml index a36e2f1..98f460a 100644 --- a/environment.yaml +++ b/environment.yaml @@ -2,8 +2,4 @@ name: mpyez channels: - conda-forge - defaults -dependencies: - - python=3.9.* - - numpy==1.26.4 - - matplotlib - - setuptools +dependencies: [ python=3.9, numpy==2, matplotlib, setuptools ] diff --git a/pyproject.toml b/pyproject.toml index 72087d1..407caa0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ [build-system] -requires = ["setuptools", "numpy==1.26.4", "matplotlib"] +requires = ["setuptools", "numpy", "matplotlib"] build-backend = "setuptools.build_meta:__legacy__" diff --git a/requirements.txt b/requirements.txt index 1061981..7a0b215 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ -numpy==1.26.4 +numpy matplotlib +setuptools diff --git a/setup.py b/setup.py index 5fc6ed4..e285353 100644 --- a/setup.py +++ b/setup.py @@ -30,8 +30,7 @@ def load_metadata(): long_description=readme, long_description_content_type="text/markdown", python_requires=">=3.9", - install_requires=["numpy==1.26.4", "matplotlib", "setuptools"], - include_package_data=True, + install_requires=["numpy", "matplotlib", "setuptools"], classifiers=["License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", From d26836695dc5b011ebed6084fb2a8bb7f38a3f29 Mon Sep 17 00:00:00 2001 From: Syed Ali Mohsin Bukhari Date: Sat, 30 Nov 2024 10:31:18 +0500 Subject: [PATCH 04/11] [mp-minor-1-2] 1. version updated for pypi --- src/mpyez/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mpyez/version.py b/src/mpyez/version.py index 8e2e5d5..1355232 100644 --- a/src/mpyez/version.py +++ b/src/mpyez/version.py @@ -1,7 +1,7 @@ """Created on Nov 02 21:24:07 2024""" __name__ = 'mpyez' -__version__ = "0.0.9" +__version__ = "0.0.9a2" __author__ = "Syed Ali Mohsin Bukhari" __email__ = "syedali.b@outlook.com" __license__ = "MIT" From 16754cd40ad16e2a3b36d42e3089faed1c86953c Mon Sep 17 00:00:00 2001 From: Syed Ali Mohsin Bukhari Date: Sat, 30 Nov 2024 11:21:59 +0500 Subject: [PATCH 05/11] [mp-minor-1-3] 1. for some reason, the last update overwrote previous color fix. --- src/mpyez/backend/uPlotting.py | 4 ++-- src/mpyez/version.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mpyez/backend/uPlotting.py b/src/mpyez/backend/uPlotting.py index 9ad9aaf..e6f4f1d 100644 --- a/src/mpyez/backend/uPlotting.py +++ b/src/mpyez/backend/uPlotting.py @@ -152,7 +152,7 @@ def __init__(self, line_style=None, line_width=None, color=None, alpha=None, alpha=alpha, marker=marker, marker_size=marker_size, marker_edge_color=marker_edge_color, marker_face_color=marker_face_color, marker_edge_width=marker_edge_width) - self.color = get_color() + self.color = color or get_color() def __repr__(self): param_str = ', '.join(f"{key}={value!r}" for key, value in self.to_dict().items()) @@ -221,7 +221,7 @@ def __init__(self, color=None, alpha=None, marker=None, size=None, cmap=None, fa super().__init__(color=color, alpha=alpha, marker=marker, size=size, cmap=cmap, face_color=face_color) if self.color is None: - self.color = get_color() + self.color = color or get_color() def __repr__(self): param_str = ', '.join(f"{key}={value!r}" for key, value in self.to_dict().items()) diff --git a/src/mpyez/version.py b/src/mpyez/version.py index 1355232..61ffd58 100644 --- a/src/mpyez/version.py +++ b/src/mpyez/version.py @@ -1,7 +1,7 @@ """Created on Nov 02 21:24:07 2024""" __name__ = 'mpyez' -__version__ = "0.0.9a2" +__version__ = "0.0.9a3" __author__ = "Syed Ali Mohsin Bukhari" __email__ = "syedali.b@outlook.com" __license__ = "MIT" From 258b75f384fe32f3ee2a832d5fae11260580a289 Mon Sep 17 00:00:00 2001 From: Syed Ali Mohsin Bukhari Date: Sun, 1 Dec 2024 03:05:05 +0500 Subject: [PATCH 06/11] [mp-minor-2] 1. added `evaluate_with_broadcast` to `ezArray.py` 2. minor docstring changes in `ezPlotting.py` --- src/mpyez/ezArray.py | 43 +++++++++++++++++++++++++++++++++++++++++ src/mpyez/ezPlotting.py | 14 +++++++------- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/src/mpyez/ezArray.py b/src/mpyez/ezArray.py index 984463f..79047ea 100644 --- a/src/mpyez/ezArray.py +++ b/src/mpyez/ezArray.py @@ -62,3 +62,46 @@ def moving_average(array: np.ndarray, window_size: int) -> np.ndarray: An array containing the moving averages. The length of this array will be `len(array) - window_size + 1`. """ return np.convolve(array, np.ones(window_size), 'valid') / window_size + + +def evaluate_with_broadcast(func, constant_array, **param_arrays): + """ + Evaluate a function with a constant array and multiple parameter arrays using broadcasting. + + Parameters + ---------- + func : callable + The function to evaluate. It must support broadcasting of inputs. The function signature should match `func(constant_array, **param_arrays)`. + constant_array : ndarray + The constant array to evaluate the function over. + **param_arrays : dict + Keyword arguments representing parameter arrays. Each array should be compatible with broadcasting. + + Returns + ------- + ndarray + The result of the function evaluation. The shape of the result is `(len(constant_array), ...)`, + where the additional dimensions correspond to the shapes of the parameter arrays. + + Examples + -------- + Evaluate a sine wave function over time with varying frequencies and amplitudes: + + >>> import numpy as np + >>> def sine_wave(t, f, a): + ... return a * np.sin(2 * np.pi * f * t) + >>> time = np.linspace(0, 1, 100) + >>> frequencies = np.array([1, 2, 3]) + >>> amplitudes = np.array([0.5, 1.0, 1.5]) + >>> result = evaluate_with_broadcast(sine_wave, time, f=frequencies, a=amplitudes) + >>> result.shape + (100, 3) + """ + # Expand constant_array to add a broadcastable dimension + expanded_constant = constant_array[:, np.newaxis] + + # Expand all parameter arrays + broadcasted_params = {key: value[np.newaxis, :] for key, value in param_arrays.items()} + + # Evaluate the function using broadcasting + return func(expanded_constant, **broadcasted_params) diff --git a/src/mpyez/ezPlotting.py b/src/mpyez/ezPlotting.py index 65304da..9770cb6 100644 --- a/src/mpyez/ezPlotting.py +++ b/src/mpyez/ezPlotting.py @@ -1,4 +1,4 @@ -"""Created on Jul 23 23:41:18 2024""" +"""Created on Jul 23 23:41:18 2024.""" __all__ = ['plot_two_column_file', 'plot_xy', 'plot_xyy', 'plot_with_dual_axes', 'two_subplots', 'n_plotter'] @@ -27,7 +27,7 @@ def plot_two_column_file(file_name: str, plot_dictionary: Optional[plot_dictionary_type] = None, axis: Optional[Axes] = None) -> axis_return: """ - Reads a two-column file (x, y) and plots the data. + Read a two-column file (x, y) and plots the data. This function reads a file containing two columns (e.g., x and y values) and plots them using either a line plot or scatter plot based on the user's preference. @@ -47,7 +47,7 @@ def plot_two_column_file(file_name: str, is_scatter : bool, optional If True, creates a scatter plot. Otherwise, creates a line plot. Default is False. plot_dictionary: Union[LinePlot, ScatterPlot], optional - An object representing the plot data, either a `LinePlot` or `ScatterPlot`, to be passed to the matplotlib plotting library. + An object representing the plot data, either a `LinePlot` or `ScatterPlot`, to be passed to the matplotlib plotting library. If None, a default plot type will be used. axis: Optional[Axes] The axis object to draw the plots on. If not passed, a new axis object will be created internally. @@ -77,7 +77,7 @@ def plot_xy(x_data: np.ndarray, y_data: np.ndarray, plot_dictionary: Optional[plot_dictionary_type] = None, axis: Optional[Axes] = None) -> axis_return: """ - Plots x_data against y_data with customizable options. + Plot the x_data against y_data with customizable options. This function accepts two arrays of data (x and y) and plots them using either a line plot or scatter plot, with options for labels and figure size. @@ -189,7 +189,7 @@ def plot_with_dual_axes(x1_data: np.ndarray, y1_data: np.ndarray, plot_dictionary: Optional[plot_dictionary_type] = None, axis: Optional[Axes] = None) -> axis_return: """ - Plots data with options for dual axes (x or y) or single axis. + Plot the data with options for dual axes (x or y) or single axis. Parameters ---------- @@ -295,7 +295,7 @@ def two_subplots(x_data: List[np.ndarray], y_data: List[np.ndarray], subplot_dictionary: Optional[uPl.SubPlots] = None, plot_dictionary: Optional[Union[uPl.LinePlot, uPl.ScatterPlot]] = None) -> None: """ - Creates two subplots arranged horizontally or vertically, with optional customization. + Create two subplots arranged horizontally or vertically, with optional customization. This function internally calls `n_plotter` to handle the plotting of each subplot. `n_plotter` arranges the subplots and applies relevant plot and subplot dictionaries. @@ -359,7 +359,7 @@ def n_plotter(x_data: List[np.ndarray], y_data: List[np.ndarray], subplot_dictionary: Optional[uPl.SubPlots] = None, plot_dictionary: Optional[Union[uPl.LinePlot, uPl.ScatterPlot]] = None) -> Union[plt.figure, Axes]: """ - Plots multiple subplots in a grid with optional customization for each subplot. + Plot multiple subplots in a grid with optional customization for each subplot. Parameters ---------- From a8209c85838af0e102e5b7cdeeef3d14d0e54730 Mon Sep 17 00:00:00 2001 From: Syed Ali Mohsin Bukhari Date: Sat, 7 Dec 2024 03:00:26 +0500 Subject: [PATCH 07/11] [mp-minor-3] 1. added a simple implementation of `MultiProcessing` function in `ezMultiprocessing.py` --- src/mpyez/ezMultiprocessing.py | 91 ++++++---------------------------- 1 file changed, 15 insertions(+), 76 deletions(-) diff --git a/src/mpyez/ezMultiprocessing.py b/src/mpyez/ezMultiprocessing.py index 824e6ff..b85b3c5 100644 --- a/src/mpyez/ezMultiprocessing.py +++ b/src/mpyez/ezMultiprocessing.py @@ -1,107 +1,46 @@ """Created on Jun 12 13:48:59 2024""" -from collections import defaultdict from multiprocessing import Pool -from typing import Any, Callable, Dict, List +from typing import Callable, Dict, List import numpy as np -def _reshape(lst: List[Any], n: int) -> List[List[Any]]: - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i:i + n] - - class MultiProcessor: - """Generalized multiprocessing functionality.""" + """Generalized multiprocessing functionality for processing dictionary inputs.""" - def __init__(self, func: Callable[..., Dict], args: List[List[Any]], n_processors: int = 3, result_handler: Callable = None): + def __init__(self, func: Callable, args: Dict[str, List], n_processors: int = 3): """ Initialize the `MultiProcessor` class. Parameters ---------- func: Callable - Function to be processed. - args: list of lists - A list of arguments to be passed onto the given function. + Function to be processed. The function should take arguments corresponding to the keys in the dictionary. + args: Dict[str, List] + A dictionary where keys are argument names, and values are lists of values for those arguments. n_processors: int, optional The number of processors to use for multiprocessing. Default is 3. - result_handler: Callable, optional - A function to handle results after processing. Defaults to `defaultdict(list)`. - - Raises - ------ - ValueError - If `n_processors` is less than 1, or if `args` is malformed. """ if n_processors < 1: raise ValueError("`n_processors` must be at least 1.") - if not all(isinstance(arg, list) for arg in args): - raise ValueError("Each element of `args` must be a list.") self.func = func self.args = args self.n_proc = n_processors - self.n_values = len(self.args[0]) - self.result_handler = result_handler or (lambda results: defaultdict(list)) - def run(self) -> Dict: - # Reshape the arguments to split the work across processors - reshaped_args = np.array([list(_reshape(i, self.n_proc)) for i in self.args]) - reshaped_args = reshaped_args.transpose() - flattened_args = [list(chunk) for chunk in reshaped_args] + lengths = [len(v) for v in args.values()] + if len(set(lengths)) != 1: + raise ValueError("All argument lists must have the same length.") + + self.arg_tuples = list(zip(*self.args.values())) + def run(self) -> np.ndarray: + """Run the multiprocessing task.""" with Pool(self.n_proc) as pool: try: - # Use starmap to correctly pass the arguments to the function - results = pool.starmap(self.func, flattened_args) + results = pool.starmap(self.func, self.arg_tuples) except Exception as e: raise RuntimeError(f"Error occurred during parallel execution: {e}") - new_dict = self.result_handler(results) - return new_dict - - -def sine_wave(t_low, t_high, f, a, phi): - """Generate a sine wave. - - Parameters - ---------- - t_low: - The start time for the wave. - t_high: - The end time for the wave. - f: - The frequency for the wave. - a: - The amplitude of the wave. - phi: - The phase of the wave. - - Returns - ------- - dict - A dictionary with sine wave values for each time `t`. - """ - t = np.linspace(t_low, t_high, 1000) - return {f"wave_{i}": a * np.sin(2 * np.pi * f * t[i] + phi) for i in range(len(t))} - - -# Time and other parameters for the waves -f_ = np.linspace(0, 1, 16).tolist() # Frequencies for the waves (in Hz) -a_ = np.linspace(0, 1, 16).tolist() # Amplitudes for each wave -phi_ = np.linspace(-2 * np.pi, 2 * np.pi, 16).tolist() # Phases for the waves -t_l = np.repeat(0, len(f_)).tolist() # Start time for each sine wave -t_h = np.repeat(1, len(f_)).tolist() # End time for each sine wave - -# Arguments to be passed to the function -arguments = [t_l, t_h, f_, a_, phi_] - -# Create the MultiProcessor instance and run it -processor = MultiProcessor(func=sine_wave, args=arguments, n_processors=4) -result = processor.run() - -# Print the results -print(result) + return np.array(results) From c2cd30b92aebf17e65908211261b96dcde307a11 Mon Sep 17 00:00:00 2001 From: Syed Ali Mohsin Bukhari Date: Sat, 7 Dec 2024 03:22:19 +0500 Subject: [PATCH 08/11] [mp-minor-4] 1. updated numpy requirement in `requirements.txt`, `setup.py` and `pyproject.toml` 2. cosmetic changes in `ezArray.py` --- pyproject.toml | 2 +- requirements.txt | 2 +- setup.py | 2 +- src/mpyez/ezArray.py | 23 +++-------------------- 4 files changed, 6 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 407caa0..bee597f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ [build-system] -requires = ["setuptools", "numpy", "matplotlib"] +requires = ["setuptools", "numpy<2.1.0", "matplotlib"] build-backend = "setuptools.build_meta:__legacy__" diff --git a/requirements.txt b/requirements.txt index 7a0b215..f42e317 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -numpy +numpy<2.1.0 matplotlib setuptools diff --git a/setup.py b/setup.py index e285353..ccce88b 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ def load_metadata(): long_description=readme, long_description_content_type="text/markdown", python_requires=">=3.9", - install_requires=["numpy", "matplotlib", "setuptools"], + install_requires=["numpy<2.1.0", "matplotlib", "setuptools"], classifiers=["License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", diff --git a/src/mpyez/ezArray.py b/src/mpyez/ezArray.py index 79047ea..d8a0dfd 100644 --- a/src/mpyez/ezArray.py +++ b/src/mpyez/ezArray.py @@ -82,26 +82,9 @@ def evaluate_with_broadcast(func, constant_array, **param_arrays): ndarray The result of the function evaluation. The shape of the result is `(len(constant_array), ...)`, where the additional dimensions correspond to the shapes of the parameter arrays. - - Examples - -------- - Evaluate a sine wave function over time with varying frequencies and amplitudes: - - >>> import numpy as np - >>> def sine_wave(t, f, a): - ... return a * np.sin(2 * np.pi * f * t) - >>> time = np.linspace(0, 1, 100) - >>> frequencies = np.array([1, 2, 3]) - >>> amplitudes = np.array([0.5, 1.0, 1.5]) - >>> result = evaluate_with_broadcast(sine_wave, time, f=frequencies, a=amplitudes) - >>> result.shape - (100, 3) """ - # Expand constant_array to add a broadcastable dimension + # Expand constant_array to add a broadcast dimension expanded_constant = constant_array[:, np.newaxis] + params_ = {key: value[np.newaxis, :] for key, value in param_arrays.items()} - # Expand all parameter arrays - broadcasted_params = {key: value[np.newaxis, :] for key, value in param_arrays.items()} - - # Evaluate the function using broadcasting - return func(expanded_constant, **broadcasted_params) + return func(expanded_constant, **params_) From f634f41bdc9f7e622a9a5a8cb1790703273f46f5 Mon Sep 17 00:00:00 2001 From: Syed Ali Mohsin Bukhari Date: Sat, 7 Dec 2024 03:23:48 +0500 Subject: [PATCH 09/11] [mp-minor-4-1] 1. updated `version.py` for pypi --- src/mpyez/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mpyez/version.py b/src/mpyez/version.py index 61ffd58..90612e8 100644 --- a/src/mpyez/version.py +++ b/src/mpyez/version.py @@ -1,7 +1,7 @@ """Created on Nov 02 21:24:07 2024""" __name__ = 'mpyez' -__version__ = "0.0.9a3" +__version__ = "0.1.0" __author__ = "Syed Ali Mohsin Bukhari" __email__ = "syedali.b@outlook.com" __license__ = "MIT" From 46c931ed438add4f3732221bfc9ee899425faf8c Mon Sep 17 00:00:00 2001 From: Syed Ali Mohsin Bukhari Date: Sun, 8 Dec 2024 08:54:14 +0500 Subject: [PATCH 10/11] [testing-minor-2] 1. updated `ezPlotting.py` with correct use of `auto_label` arguments 2. minor changes in `uPlotting.py` --- src/mpyez/backend/uPlotting.py | 75 +---------------- src/mpyez/ezPlotting.py | 148 +++++++++++++++++++-------------- 2 files changed, 90 insertions(+), 133 deletions(-) diff --git a/src/mpyez/backend/uPlotting.py b/src/mpyez/backend/uPlotting.py index 1bbddd3..15f2074 100644 --- a/src/mpyez/backend/uPlotting.py +++ b/src/mpyez/backend/uPlotting.py @@ -1,16 +1,13 @@ """Created on Oct 29 09:33:06 2024""" -__all__ = ['LinePlot', 'ScatterPlot', 'SubPlots', 'label_handler', 'plot_or_scatter', 'plot_dictionary_handler', - 'split_dictionary', 'dual_axes_data_validation', 'dual_axes_label_management'] +__all__ = ['LinePlot', 'ScatterPlot', 'SubPlots', 'plot_or_scatter', 'plot_dictionary_handler', 'split_dictionary', 'dual_axes_data_validation', + 'dual_axes_label_management'] -import warnings from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np from matplotlib import pyplot as plt, rcParams -from .ePlotting import NoXYLabels - def get_color(): """ @@ -282,69 +279,6 @@ def _all_parameters(self): return [self.share_x, self.share_y, self.fig_size] -def label_handler(x_labels: Optional[List[str]], y_labels: Optional[List[str]], - n_rows: int, n_cols: int, auto_label: bool) -> Tuple[List[str], List[str]]: - """ - Handles the generation or validation of x and y labels for a subplot configuration. - - Parameters - ---------- - x_labels : list of str or None - The labels for the x-axis. If `None`, labels may be auto-generated based on `auto_label`. - y_labels : list of str or None - The labels for the y-axis. If `None`, labels may be auto-generated based on `auto_label`. - n_rows : int - Number of rows in the subplot grid. - n_cols : int - Number of columns in the subplot grid. - auto_label : bool - If `True`, generates x and y labels automatically when not provided. - - Returns - ------- - Tuple[List[str], List[str]] - The x and y labels for the subplot grid. - - Raises - ------ - NoXYLabels - If both `x_labels` and `y_labels` are `None` and `auto_label` is `False`. - - Warnings - -------- - UserWarning - If one of `x_labels` or `y_labels` is missing when `auto_label` is enabled, or if - there is a mismatch in the number of provided labels. - - """ - if not auto_label and (x_labels is None or y_labels is None): - raise NoXYLabels("Both x_labels and y_labels are required without the auto_label parameter.") - - elif auto_label and (x_labels is None or y_labels is None): - if x_labels is None and y_labels is None: - pass - else: - if x_labels is None: - warnings.warn("y_labels given but x_labels is missing, applying auto-labeling...", UserWarning) - if y_labels is None: - warnings.warn("x_labels given but y_labels is missing, applying auto-labeling...", UserWarning) - - if auto_label: - if x_labels and y_labels: - start = "auto_label selected with x_labels and y_labels provided" - if len(x_labels) != n_rows * n_cols or len(y_labels) != n_rows * n_cols: - warnings.warn(f"{start}, mismatch found, using auto-generated labels...", UserWarning) - x_labels = [fr'X$_{i + 1}$' for i in range(n_cols * n_rows)] - y_labels = [fr'Y$_{i + 1}$' for i in range(n_cols * n_rows)] - else: - print(f"{start}, using user-provided labels...") - else: - x_labels = [fr'X$_{i + 1}$' for i in range(n_cols * n_rows)] - y_labels = [fr'Y$_{i + 1}$' for i in range(n_cols * n_rows)] - - return x_labels, y_labels - - def plot_or_scatter(axes: plt.axis, scatter: bool): """ Returns the plot or scatter method based on the specified plot type. @@ -413,10 +347,7 @@ def split_dictionary(plot_instance: Union[LinePlot, ScatterPlot]) -> _split: # Split each parameter into two separate dictionaries for the two instances for param_name, values in parameters.items(): - if isinstance(values, (list, tuple)) and len(values) == 2: - params_instance1[param_name], params_instance2[param_name] = values - else: - raise ValueError(f"Parameter '{param_name}' must be a list or tuple with exactly two elements.") + params_instance1[param_name], params_instance2[param_name] = values[:2] instance1 = plot_instance.__class__.populate(params_instance1) instance2 = plot_instance.__class__.populate(params_instance2) diff --git a/src/mpyez/ezPlotting.py b/src/mpyez/ezPlotting.py index 9770cb6..23dd9bf 100644 --- a/src/mpyez/ezPlotting.py +++ b/src/mpyez/ezPlotting.py @@ -21,16 +21,15 @@ def plot_two_column_file(file_name: str, delimiter: str = ',', skip_header: bool = False, - data_label: str = 'X vs Y', + x_label: Optional[str] = None, + y_label: Optional[str] = None, + data_label: Optional[str] = None, + plot_title: Optional[str] = None, auto_label: bool = False, is_scatter: bool = False, plot_dictionary: Optional[plot_dictionary_type] = None, axis: Optional[Axes] = None) -> axis_return: - """ - Read a two-column file (x, y) and plots the data. - - This function reads a file containing two columns (e.g., x and y values) and plots them - using either a line plot or scatter plot based on the user's preference. + """Read a two-column file (x, y) and plot the data. Parameters ---------- @@ -40,8 +39,14 @@ def plot_two_column_file(file_name: str, The delimiter used in the file (default is ','). skip_header: bool, optional If True, skips the first row in the given data file, otherwise does nothing. Default is False. - data_label: str + x_label: str, optional + The label for the x-axis. + y_label: str, optional + The label for the y-axis. + data_label: str, optional Data label for the plot to put in the legend. Defaults to 'X vs Y'. + plot_title: str, optional + The title for the plot. auto_label : bool, optional If True, automatically sets the x-axis label, y-axis label, and plot title. Default is False. is_scatter : bool, optional @@ -59,6 +64,7 @@ def plot_two_column_file(file_name: str, """ # CHANGELIST: # - Removed `fig_size` and added `data_label` parameter + # - Added `x_label`, `y_label`, and `plot_title` data = np.genfromtxt(file_name, delimiter=delimiter, skip_header=skip_header) if data.shape[1] != 2: @@ -66,21 +72,17 @@ def plot_two_column_file(file_name: str, x_data, y_data = data.T - return plot_with_dual_axes(x1_data=x_data, y1_data=y_data, x1y1_label=data_label, auto_label=auto_label, - is_scatter=is_scatter, plot_dictionary=plot_dictionary, axis=axis) + return plot_with_dual_axes(x1_data=x_data, y1_data=y_data, x1y1_label=data_label, auto_label=auto_label, axis_labels=[x_label, y_label, None], + plot_title=plot_title, is_scatter=is_scatter, plot_dictionary=plot_dictionary, axis=axis) def plot_xy(x_data: np.ndarray, y_data: np.ndarray, - x_label: str = 'X', y_label: str = 'Y', plot_title: str = 'XY plot', - data_label: str = 'X vs Y', + x_label: Optional[str] = None, y_label: Optional[str] = None, plot_title: Optional[str] = None, + data_label: Optional[str] = None, auto_label: bool = False, is_scatter: bool = False, plot_dictionary: Optional[plot_dictionary_type] = None, axis: Optional[Axes] = None) -> axis_return: - """ - Plot the x_data against y_data with customizable options. - - This function accepts two arrays of data (x and y) and plots them using either - a line plot or scatter plot, with options for labels and figure size. + """Plot the x_data against y_data with customizable options. Parameters ---------- @@ -88,13 +90,13 @@ def plot_xy(x_data: np.ndarray, y_data: np.ndarray, The data for the x-axis. y_data : np.ndarray The data for the y-axis. - x_label: str + x_label: str, optional The label for the x-axis. - y_label: str + y_label: str, optional The label for the y-axis. - plot_title: str + plot_title: str, optional The title for the plot. - data_label: str + data_label: str, optional Data label for the plot to put in the legend. Defaults to 'X vs Y'. auto_label : bool, optional If True, automatically sets x and y-axis labels and the plot title. Default is False. @@ -102,7 +104,7 @@ def plot_xy(x_data: np.ndarray, y_data: np.ndarray, If True, creates a scatter plot. Otherwise, creates a line plot. Default is False. plot_dictionary: Union[LinePlot, ScatterPlot], optional An object representing the plot data, either a `LinePlot` or `ScatterPlot`, to be passed to the matplotlib plotting library. - If None, a default plot type will be used. + If None, a default plot type will be used. axis: Optional[Axes] The axis object to draw the plots on. If not passed, a new axis object will be created internally. @@ -114,19 +116,26 @@ def plot_xy(x_data: np.ndarray, y_data: np.ndarray, # CHANGELIST: # - Removed `fig_size` parameter # - Added `x_label`, `y_label` and `plot_title` for respective plot arguments - axis_labels = [x_label, y_label, ''] - return plot_with_dual_axes(x1_data=x_data, y1_data=y_data, x1y1_label=data_label, auto_label=auto_label, - axis_labels=axis_labels, plot_title=plot_title, - is_scatter=is_scatter, plot_dictionary=plot_dictionary, axis=axis) + # - Replaced the argument labels to None for better handling + # - Correct handling of `auto_label` argument with default labels + if auto_label: + x_label = 'X' + y_label = 'Y' + plot_title = 'Plot' + data_label = 'X vs Y' + + axis_labels = [x_label, y_label, None] + return plot_with_dual_axes(x1_data=x_data, y1_data=y_data, x1y1_label=data_label, auto_label=auto_label, axis_labels=axis_labels, + plot_title=plot_title, is_scatter=is_scatter, plot_dictionary=plot_dictionary, axis=axis) def plot_xyy(x_data: np.ndarray, y1_data: np.ndarray, y2_data: np.ndarray, - x_label: str = 'X', y1_label: str = 'Y1', y2_label: str = 'Y2', plot_title: str = 'XYY plot', - data_labels: Optional[List[str]] = None, auto_label: bool = False, + x_label: Optional[str] = None, y1_label: Optional[str] = None, y2_label: Optional[str] = None, + plot_title: Optional[str] = None, data_labels: Optional[List[str]] = (None, None), + use_twin_x: bool = True, auto_label: bool = False, is_scatter: bool = False, plot_dictionary: plot_dictionary_type = None, axis: Optional[Axes] = None) -> Axes: - """ - Plot two sets of y-data (`y1_data` and `y2_data`) against the same x-data (`x_data`) on the same plot. + """Plot two sets of y-data (`y1_data` and `y2_data`) against the same x-data (`x_data`) on the same plot. Parameters ---------- @@ -146,6 +155,8 @@ def plot_xyy(x_data: np.ndarray, y1_data: np.ndarray, y2_data: np.ndarray, The title for the plot. data_labels : list of str, optional The labels for the two datasets. Default is `['X vs Y1', 'X vs Y2']`. + use_twin_x : bool, optional + If True, creates dual y-axis plot. If False, creates dual x-axis plot. Default is True. auto_label : bool, optional Whether to automatically label the plot. Default is `False`. is_scatter : bool, optional @@ -164,16 +175,26 @@ def plot_xyy(x_data: np.ndarray, y1_data: np.ndarray, y2_data: np.ndarray, # - Removed `fig_size` parameter # - Added `x_label`, `y_label` and `plot_title` for respective plot arguments # - Fixed None `plot_dictionary` - data_labels = ['X vs Y1', 'X vs Y2'] if data_labels is None else data_labels + # - Handles `auto_label` correctly + if auto_label: + x_label = 'X' + y1_label = 'Y1' + y2_label = 'Y2' + plot_title = 'XYY plot' + data_labels = ['X vs Y1', 'X vs Y2'] + plot_config_1, plot_config_2 = uPl.split_dictionary(plot_dictionary if plot_dictionary else uPl.ScatterPlot() if is_scatter else uPl.LinePlot()) - ax_labels1 = [x_label, y1_label, ''] - ax_labels2 = [x_label, y2_label, ''] - axis = plot_with_dual_axes(x1_data=x_data, y1_data=y1_data, x1y1_label=data_labels[0], auto_label=auto_label, axis_labels=ax_labels1, - is_scatter=is_scatter, plot_dictionary=plot_config_1, axis=axis) - return plot_with_dual_axes(x1_data=x_data, y1_data=y2_data, x1y1_label=data_labels[1], auto_label=auto_label, axis_labels=ax_labels2, - plot_title=plot_title, is_scatter=is_scatter, plot_dictionary=plot_config_2, axis=axis) + axis = plot_with_dual_axes(x1_data=x_data, y1_data=y1_data, x1y1_label=data_labels[0], auto_label=auto_label, + axis_labels=[x_label, y1_label, None], is_scatter=is_scatter, plot_dictionary=plot_config_1, axis=axis) + # remove the prior title set by plot_title = None + axis.set_title('') + # use twin x-axis (or not) + axis = axis.twinx() if use_twin_x else axis + return plot_with_dual_axes(x1_data=x_data, y1_data=y2_data, x1y1_label=data_labels[1], auto_label=auto_label, + axis_labels=[x_label, y2_label, None], plot_title=plot_title, is_scatter=is_scatter, plot_dictionary=plot_config_2, + axis=axis) def plot_with_dual_axes(x1_data: np.ndarray, y1_data: np.ndarray, @@ -188,8 +209,7 @@ def plot_with_dual_axes(x1_data: np.ndarray, y1_data: np.ndarray, is_scatter: bool = False, plot_dictionary: Optional[plot_dictionary_type] = None, axis: Optional[Axes] = None) -> axis_return: - """ - Plot the data with options for dual axes (x or y) or single axis. + """Plot the data with options for dual axes (x or y) or single axis. Parameters ---------- @@ -237,14 +257,12 @@ def plot_with_dual_axes(x1_data: np.ndarray, y1_data: np.ndarray, # - Handles empty labels correctly as well # - Can deal with labels and data validations - labels = uPl.dual_axes_label_management(x1y1_label=x1y1_label, x1y2_label=x1y2_label, x2y1_label=x2y1_label, - auto_label=auto_label, axis_labels=axis_labels, plot_title=plot_title, - use_twin_x=use_twin_x) + labels = uPl.dual_axes_label_management(x1y1_label=x1y1_label, x1y2_label=x1y2_label, x2y1_label=x2y1_label, auto_label=auto_label, + axis_labels=axis_labels, plot_title=plot_title, use_twin_x=use_twin_x) x1y1_label, x1y2_label, x2y1_label, plot_title, axis_labels = labels - uPl.dual_axes_data_validation(x1_data=x1_data, x2_data=x2_data, y1_data=y1_data, y2_data=y2_data, - use_twin_x=use_twin_x, axis_labels=axis_labels) + uPl.dual_axes_data_validation(x1_data=x1_data, x2_data=x2_data, y1_data=y1_data, y2_data=y2_data, use_twin_x=use_twin_x, axis_labels=axis_labels) if axis: ax1 = axis @@ -288,17 +306,13 @@ def plot_with_dual_axes(x1_data: np.ndarray, y1_data: np.ndarray, def two_subplots(x_data: List[np.ndarray], y_data: List[np.ndarray], - x_labels: List[str], y_labels: List[str], data_labels: List[str], + x_labels: Optional[List[str]] = None, y_labels: Optional[List[str]] = None, data_labels: Optional[List[str]] = None, orientation: str = 'h', auto_label: bool = False, is_scatter: bool = False, subplot_dictionary: Optional[uPl.SubPlots] = None, - plot_dictionary: Optional[Union[uPl.LinePlot, uPl.ScatterPlot]] = None) -> None: - """ - Create two subplots arranged horizontally or vertically, with optional customization. - - This function internally calls `n_plotter` to handle the plotting of each subplot. - `n_plotter` arranges the subplots and applies relevant plot and subplot dictionaries. + plot_dictionary: Optional[Union[uPl.LinePlot, uPl.ScatterPlot]] = None) -> Union[plt.figure, Axes]: + """Create two subplots arranged horizontally or vertically, with optional customization. Parameters ---------- @@ -343,17 +357,13 @@ def two_subplots(x_data: List[np.ndarray], y_data: List[np.ndarray], else: raise ePl.OrientationError("The orientation must be either \'h\' or \'v\'.") - return n_plotter(x_data=x_data, y_data=y_data, - n_rows=n_rows, n_cols=n_cols, - x_labels=x_labels, y_labels=y_labels, data_labels=data_labels, - auto_label=auto_label, is_scatter=is_scatter, - subplot_dictionary=subplot_dictionary, plot_dictionary=plot_dictionary) + return n_plotter(x_data=x_data, y_data=y_data, n_rows=n_rows, n_cols=n_cols, x_labels=x_labels, y_labels=y_labels, data_labels=data_labels, + auto_label=auto_label, is_scatter=is_scatter, subplot_dictionary=subplot_dictionary, plot_dictionary=plot_dictionary) def n_plotter(x_data: List[np.ndarray], y_data: List[np.ndarray], n_rows: int, n_cols: int, - x_labels: Optional[List[str]] = None, y_labels: Optional[List[str]] = None, - data_labels: Optional[List[str]] = None, + x_labels=None, y_labels=None, data_labels=None, plot_title=None, auto_label: bool = False, is_scatter: bool = False, subplot_dictionary: Optional[uPl.SubPlots] = None, @@ -377,8 +387,10 @@ def n_plotter(x_data: List[np.ndarray], y_data: List[np.ndarray], List of labels for the y-axes of each subplot. data_labels : list of str, optional List of labels for the data series in each subplot. + plot_title: str, optional + Title of the plot. auto_label : bool, default False - Automatically assigns labels to subplots if `True`. + Automatically assigns labels to subplots if `True`. If `True`, it overwrites user provided labels. Defaults to False. is_scatter : bool, default False If `True`, plots data as scatter plots; otherwise, plots as line plots. subplot_dictionary : dict, optional @@ -396,6 +408,8 @@ def n_plotter(x_data: List[np.ndarray], y_data: List[np.ndarray], # - Improved subplot decoration for multi-row and multi-column cases (x/y labels, ticks, etc.). # - If `share_y = True`, other y-axes are no longer shown to avoid clutter. # - Removed `plot_on_dual_axes` or `plot_xy` dependency, instead uses simple plot/scatter functionality. + # - Added fail-safe labels to the function + # - Efficient handling of `auto_label` argument sp_dict = subplot_dictionary.get() if subplot_dictionary else uPl.SubPlots().get() @@ -406,8 +420,18 @@ def n_plotter(x_data: List[np.ndarray], y_data: List[np.ndarray], main_dict = [{key: value[c] for key, value in plot_items} for c in range(n_cols * n_rows)] - x_labels, y_labels = uPl.label_handler(x_labels=x_labels, y_labels=y_labels, n_rows=n_rows, n_cols=n_cols, - auto_label=auto_label) + if auto_label: + x_labels = [fr'X$_{i + 1}$' for i in range(n_cols * n_rows)] + y_labels = [fr'Y$_{i + 1}$' for i in range(n_cols * n_rows)] + + data_labels = [f'{i} vs {j}' for i, j in zip(x_labels, y_labels)] + plot_title = 'N Plotter' + # safeguard from `None` iterations in case if no label is provided and auto_label is false + else: + empty_ = [None for _ in range(n_cols * n_rows)] + x_labels = x_labels if x_labels else empty_ + y_labels = y_labels if y_labels else empty_ + data_labels = data_labels if data_labels else empty_ shared_y = sp_dict.get('sharey') shared_x1 = sp_dict.get('sharex') @@ -422,7 +446,9 @@ def n_plotter(x_data: List[np.ndarray], y_data: List[np.ndarray], ax.set_xlabel(j) if not (shared_y and index % n_cols != 0): ax.set_ylabel(k) - ax.legend(loc='best') + fig.suptitle(plot_title) + if label: + ax.legend(loc='best') fig.tight_layout() From 7817b3bb50a7ea8f44e74731d4a8ef2d2579dcec Mon Sep 17 00:00:00 2001 From: Syed Ali Mohsin Bukhari Date: Sun, 8 Dec 2024 09:46:46 +0500 Subject: [PATCH 11/11] [subplots-minor-1] 1. added subplot title and subplot dictionary usage in `ezPlotting.py` --- src/mpyez/ezPlotting.py | 97 ++++++++++++++++++++++++++--------------- 1 file changed, 61 insertions(+), 36 deletions(-) diff --git a/src/mpyez/ezPlotting.py b/src/mpyez/ezPlotting.py index 23dd9bf..b6fa63c 100644 --- a/src/mpyez/ezPlotting.py +++ b/src/mpyez/ezPlotting.py @@ -6,14 +6,11 @@ import matplotlib.pyplot as plt import numpy as np -from matplotlib import rcParams from matplotlib.axes import Axes from .backend import ePlotting as ePl, uPlotting as uPl # safeguard -line_plot = "LinePlot" -scatter_plot = "ScatterPlot" plot_dictionary_type = Optional[Union[uPl.LinePlot, uPl.ScatterPlot]] axis_return = Union[List[Axes], Axes] @@ -27,7 +24,8 @@ def plot_two_column_file(file_name: str, plot_title: Optional[str] = None, auto_label: bool = False, is_scatter: bool = False, - plot_dictionary: Optional[plot_dictionary_type] = None, + plot_dictionary: plot_dictionary_type = None, + subplot_dictionary: uPl.SubPlots = None, axis: Optional[Axes] = None) -> axis_return: """Read a two-column file (x, y) and plot the data. @@ -54,6 +52,8 @@ def plot_two_column_file(file_name: str, plot_dictionary: Union[LinePlot, ScatterPlot], optional An object representing the plot data, either a `LinePlot` or `ScatterPlot`, to be passed to the matplotlib plotting library. If None, a default plot type will be used. + subplot_dictionary + Dictionary of parameters for subplot configuration. axis: Optional[Axes] The axis object to draw the plots on. If not passed, a new axis object will be created internally. @@ -65,6 +65,8 @@ def plot_two_column_file(file_name: str, # CHANGELIST: # - Removed `fig_size` and added `data_label` parameter # - Added `x_label`, `y_label`, and `plot_title` + # - Added use of `subplot_dictionary` + data = np.genfromtxt(file_name, delimiter=delimiter, skip_header=skip_header) if data.shape[1] != 2: @@ -73,14 +75,15 @@ def plot_two_column_file(file_name: str, x_data, y_data = data.T return plot_with_dual_axes(x1_data=x_data, y1_data=y_data, x1y1_label=data_label, auto_label=auto_label, axis_labels=[x_label, y_label, None], - plot_title=plot_title, is_scatter=is_scatter, plot_dictionary=plot_dictionary, axis=axis) + plot_title=plot_title, is_scatter=is_scatter, plot_dictionary=plot_dictionary, subplot_dictionary=subplot_dictionary, + axis=axis) def plot_xy(x_data: np.ndarray, y_data: np.ndarray, x_label: Optional[str] = None, y_label: Optional[str] = None, plot_title: Optional[str] = None, data_label: Optional[str] = None, auto_label: bool = False, is_scatter: bool = False, - plot_dictionary: Optional[plot_dictionary_type] = None, + plot_dictionary: plot_dictionary_type = None, subplot_dictionary: Optional[uPl.SubPlots] = None, axis: Optional[Axes] = None) -> axis_return: """Plot the x_data against y_data with customizable options. @@ -105,6 +108,8 @@ def plot_xy(x_data: np.ndarray, y_data: np.ndarray, plot_dictionary: Union[LinePlot, ScatterPlot], optional An object representing the plot data, either a `LinePlot` or `ScatterPlot`, to be passed to the matplotlib plotting library. If None, a default plot type will be used. + subplot_dictionary: SubPlots, optional + Dictionary of parameters for subplot configuration. axis: Optional[Axes] The axis object to draw the plots on. If not passed, a new axis object will be created internally. @@ -118,6 +123,7 @@ def plot_xy(x_data: np.ndarray, y_data: np.ndarray, # - Added `x_label`, `y_label` and `plot_title` for respective plot arguments # - Replaced the argument labels to None for better handling # - Correct handling of `auto_label` argument with default labels + # - Added use of `subplot_dictionary` if auto_label: x_label = 'X' y_label = 'Y' @@ -126,14 +132,15 @@ def plot_xy(x_data: np.ndarray, y_data: np.ndarray, axis_labels = [x_label, y_label, None] return plot_with_dual_axes(x1_data=x_data, y1_data=y_data, x1y1_label=data_label, auto_label=auto_label, axis_labels=axis_labels, - plot_title=plot_title, is_scatter=is_scatter, plot_dictionary=plot_dictionary, axis=axis) + plot_title=plot_title, is_scatter=is_scatter, plot_dictionary=plot_dictionary, subplot_dictionary=subplot_dictionary, + axis=axis) def plot_xyy(x_data: np.ndarray, y1_data: np.ndarray, y2_data: np.ndarray, x_label: Optional[str] = None, y1_label: Optional[str] = None, y2_label: Optional[str] = None, plot_title: Optional[str] = None, data_labels: Optional[List[str]] = (None, None), use_twin_x: bool = True, auto_label: bool = False, - is_scatter: bool = False, plot_dictionary: plot_dictionary_type = None, + is_scatter: bool = False, plot_dictionary: plot_dictionary_type = None, subplot_dictionary: Optional[uPl.SubPlots] = None, axis: Optional[Axes] = None) -> Axes: """Plot two sets of y-data (`y1_data` and `y2_data`) against the same x-data (`x_data`) on the same plot. @@ -161,8 +168,10 @@ def plot_xyy(x_data: np.ndarray, y1_data: np.ndarray, y2_data: np.ndarray, Whether to automatically label the plot. Default is `False`. is_scatter : bool, optional Whether to create a scatter plot (`True`) or a line plot (`False`). Default is `False`. - plot_dictionary : dict, optional - A dictionary containing plot configuration parameters for the two datasets. Default is `None`. + plot_dictionary: Union[LinePlot, ScatterPlot], optional + An object representing the plot data, either a `LinePlot` or `ScatterPlot`, to be passed to the matplotlib plotting library. + subplot_dictionary: SubPlots + Dictionary of parameters for subplot configuration. axis : Axes, optional A Matplotlib axis to plot on. If `None`, a new axis is created. Default is `None`. @@ -176,25 +185,18 @@ def plot_xyy(x_data: np.ndarray, y1_data: np.ndarray, y2_data: np.ndarray, # - Added `x_label`, `y_label` and `plot_title` for respective plot arguments # - Fixed None `plot_dictionary` # - Handles `auto_label` correctly + # - Simplified by using a single `plot_with_dual_axis` instance + # - Added use of `subplot_dictionary` if auto_label: x_label = 'X' - y1_label = 'Y1' - y2_label = 'Y2' + y1_label = r'Y$_1$' + y2_label = r'Y$_2$' plot_title = 'XYY plot' - data_labels = ['X vs Y1', 'X vs Y2'] - - plot_config_1, plot_config_2 = uPl.split_dictionary(plot_dictionary if plot_dictionary - else uPl.ScatterPlot() if is_scatter else uPl.LinePlot()) - - axis = plot_with_dual_axes(x1_data=x_data, y1_data=y1_data, x1y1_label=data_labels[0], auto_label=auto_label, - axis_labels=[x_label, y1_label, None], is_scatter=is_scatter, plot_dictionary=plot_config_1, axis=axis) - # remove the prior title set by plot_title = None - axis.set_title('') - # use twin x-axis (or not) - axis = axis.twinx() if use_twin_x else axis - return plot_with_dual_axes(x1_data=x_data, y1_data=y2_data, x1y1_label=data_labels[1], auto_label=auto_label, - axis_labels=[x_label, y2_label, None], plot_title=plot_title, is_scatter=is_scatter, plot_dictionary=plot_config_2, - axis=axis) + data_labels = [r'X vs Y$_1$', r'X vs Y$_2$'] + + return plot_with_dual_axes(x1_data=x_data, y1_data=y1_data, y2_data=y2_data, x1y1_label=data_labels[0], x1y2_label=data_labels[1], + auto_label=auto_label, plot_title=plot_title, use_twin_x=use_twin_x, axis_labels=[x_label, y1_label, y2_label], + is_scatter=is_scatter, plot_dictionary=plot_dictionary, subplot_dictionary=subplot_dictionary, axis=axis) def plot_with_dual_axes(x1_data: np.ndarray, y1_data: np.ndarray, @@ -207,7 +209,8 @@ def plot_with_dual_axes(x1_data: np.ndarray, y1_data: np.ndarray, axis_labels: List[str] = None, plot_title: str = None, is_scatter: bool = False, - plot_dictionary: Optional[plot_dictionary_type] = None, + plot_dictionary: plot_dictionary_type = None, + subplot_dictionary: Optional[uPl.SubPlots] = None, axis: Optional[Axes] = None) -> axis_return: """Plot the data with options for dual axes (x or y) or single axis. @@ -240,6 +243,8 @@ def plot_with_dual_axes(x1_data: np.ndarray, y1_data: np.ndarray, If True, creates scatter plot; otherwise, line plot. Default is False. plot_dictionary: Union[LinePlot, ScatterPlot], optional An object representing the plot data, either a `LinePlot` or `ScatterPlot`, to be passed to the matplotlib plotting library. + subplot_dictionary: SubPlots + Dictionary of parameters for subplot configuration. axis: Optional[Axis] The axis object to draw the plots on. If not passed, a new axis object will be created internally. @@ -256,6 +261,7 @@ def plot_with_dual_axes(x1_data: np.ndarray, y1_data: np.ndarray, # - Handles `auto_label`, `axis_labels` and `plot_title` separately # - Handles empty labels correctly as well # - Can deal with labels and data validations + # - Added use of `subplot_dictionary` labels = uPl.dual_axes_label_management(x1y1_label=x1y1_label, x1y2_label=x1y2_label, x2y1_label=x2y1_label, auto_label=auto_label, axis_labels=axis_labels, plot_title=plot_title, use_twin_x=use_twin_x) @@ -267,7 +273,8 @@ def plot_with_dual_axes(x1_data: np.ndarray, y1_data: np.ndarray, if axis: ax1 = axis else: - _, ax1 = plt.subplots(figsize=rcParams["figure.figsize"]) + sp_dict = subplot_dictionary.get() if subplot_dictionary else uPl.SubPlots().get() + _, ax1 = plt.subplots(1, 1, **sp_dict) plot_items = uPl.plot_dictionary_handler(plot_dictionary=plot_dictionary) dict1 = {key: (value[0] if isinstance(value, list) else value) for key, value in plot_items} @@ -307,6 +314,7 @@ def plot_with_dual_axes(x1_data: np.ndarray, y1_data: np.ndarray, def two_subplots(x_data: List[np.ndarray], y_data: List[np.ndarray], x_labels: Optional[List[str]] = None, y_labels: Optional[List[str]] = None, data_labels: Optional[List[str]] = None, + plot_title: Optional[str] = None, subplot_title: Optional[List[str]] = None, orientation: str = 'h', auto_label: bool = False, is_scatter: bool = False, @@ -326,6 +334,10 @@ def two_subplots(x_data: List[np.ndarray], y_data: List[np.ndarray], List of labels for the y-axes in each subplot. data_labels : list of str List of labels for the data series in each subplot. + plot_title: str, optional + Title of the plot. + subplot_title: list of str, optional + Titles for the subplots, if required. orientation : str, optional, default='h' Orientation of the subplots, either 'h' for horizontal or 'v' for vertical. auto_label : bool, default False @@ -349,6 +361,7 @@ def two_subplots(x_data: List[np.ndarray], y_data: List[np.ndarray], # - Returns the axes object for better integration with other plotting functions. # - Adapts to `n_plotter` for enhanced plot flexibility. # - Removed the redundant `axes` variable for a cleaner implementation. + # - Can handle `plot_title` if orientation == 'h': n_rows, n_cols = 1, 2 @@ -358,12 +371,13 @@ def two_subplots(x_data: List[np.ndarray], y_data: List[np.ndarray], raise ePl.OrientationError("The orientation must be either \'h\' or \'v\'.") return n_plotter(x_data=x_data, y_data=y_data, n_rows=n_rows, n_cols=n_cols, x_labels=x_labels, y_labels=y_labels, data_labels=data_labels, - auto_label=auto_label, is_scatter=is_scatter, subplot_dictionary=subplot_dictionary, plot_dictionary=plot_dictionary) + plot_title=plot_title, subplot_title=subplot_title, auto_label=auto_label, is_scatter=is_scatter, + subplot_dictionary=subplot_dictionary, plot_dictionary=plot_dictionary) def n_plotter(x_data: List[np.ndarray], y_data: List[np.ndarray], n_rows: int, n_cols: int, - x_labels=None, y_labels=None, data_labels=None, plot_title=None, + x_labels=None, y_labels=None, data_labels=None, plot_title=None, subplot_title=None, auto_label: bool = False, is_scatter: bool = False, subplot_dictionary: Optional[uPl.SubPlots] = None, @@ -389,6 +403,8 @@ def n_plotter(x_data: List[np.ndarray], y_data: List[np.ndarray], List of labels for the data series in each subplot. plot_title: str, optional Title of the plot. + subplot_title: list of str, optional + Titles for the subplots, if required. auto_label : bool, default False Automatically assigns labels to subplots if `True`. If `True`, it overwrites user provided labels. Defaults to False. is_scatter : bool, default False @@ -410,6 +426,7 @@ def n_plotter(x_data: List[np.ndarray], y_data: List[np.ndarray], # - Removed `plot_on_dual_axes` or `plot_xy` dependency, instead uses simple plot/scatter functionality. # - Added fail-safe labels to the function # - Efficient handling of `auto_label` argument + # - Can handle `subplot_title` sp_dict = subplot_dictionary.get() if subplot_dictionary else uPl.SubPlots().get() @@ -425,31 +442,39 @@ def n_plotter(x_data: List[np.ndarray], y_data: List[np.ndarray], y_labels = [fr'Y$_{i + 1}$' for i in range(n_cols * n_rows)] data_labels = [f'{i} vs {j}' for i, j in zip(x_labels, y_labels)] - plot_title = 'N Plotter' + subplot_title = [f'Subplot {i}' for i in range(n_cols * n_rows)] + plot_title = f'{n_cols * n_rows} Plotter' # safeguard from `None` iterations in case if no label is provided and auto_label is false else: empty_ = [None for _ in range(n_cols * n_rows)] x_labels = x_labels if x_labels else empty_ y_labels = y_labels if y_labels else empty_ + plot_title = plot_title if plot_title else None data_labels = data_labels if data_labels else empty_ + subplot_title = subplot_title if subplot_title else empty_ shared_y = sp_dict.get('sharey') shared_x1 = sp_dict.get('sharex') shared_x2 = len(axs) - int(len(axs) / n_rows if n_rows > n_cols else n_cols) - for index, ax, j, k in zip(range(n_cols * n_rows), axs, x_labels, y_labels): + + # use column stack instead of zip + zipped = np.column_stack([range(n_cols * n_rows), axs, x_labels, y_labels, subplot_title]) + for index, ax, x_, y_, sp_ in zipped: label = f'{x_labels[index]} vs {y_labels[index]}' if data_labels is None else data_labels[index] uPl.plot_or_scatter(axes=ax, scatter=is_scatter)(x_data[index], y_data[index], label=label, **main_dict[index]) if shared_x1: if not index < shared_x2: - ax.set_xlabel(j) + ax.set_xlabel(x_) else: - ax.set_xlabel(j) + ax.set_xlabel(x_) if not (shared_y and index % n_cols != 0): - ax.set_ylabel(k) - fig.suptitle(plot_title) + ax.set_ylabel(y_) if label: ax.legend(loc='best') + ax.set_title(sp_) + fig.suptitle(plot_title) + fig.tight_layout() return fig, axs