diff --git a/environment.yaml b/environment.yaml index a36e2f1..98f460a 100644 --- a/environment.yaml +++ b/environment.yaml @@ -2,8 +2,4 @@ name: mpyez channels: - conda-forge - defaults -dependencies: - - python=3.9.* - - numpy==1.26.4 - - matplotlib - - setuptools +dependencies: [ python=3.9, numpy==2, matplotlib, setuptools ] diff --git a/pyproject.toml b/pyproject.toml index 72087d1..bee597f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ [build-system] -requires = ["setuptools", "numpy==1.26.4", "matplotlib"] +requires = ["setuptools", "numpy<2.1.0", "matplotlib"] build-backend = "setuptools.build_meta:__legacy__" diff --git a/requirements.txt b/requirements.txt index 1061981..f42e317 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ -numpy==1.26.4 +numpy<2.1.0 matplotlib +setuptools diff --git a/setup.py b/setup.py index 5fc6ed4..ccce88b 100644 --- a/setup.py +++ b/setup.py @@ -30,8 +30,7 @@ def load_metadata(): long_description=readme, long_description_content_type="text/markdown", python_requires=">=3.9", - install_requires=["numpy==1.26.4", "matplotlib", "setuptools"], - include_package_data=True, + install_requires=["numpy<2.1.0", "matplotlib", "setuptools"], classifiers=["License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", diff --git a/src/mpyez/backend/uPlotting.py b/src/mpyez/backend/uPlotting.py index 9ad9aaf..e6f4f1d 100644 --- a/src/mpyez/backend/uPlotting.py +++ b/src/mpyez/backend/uPlotting.py @@ -152,7 +152,7 @@ def __init__(self, line_style=None, line_width=None, color=None, alpha=None, alpha=alpha, marker=marker, marker_size=marker_size, marker_edge_color=marker_edge_color, marker_face_color=marker_face_color, marker_edge_width=marker_edge_width) - self.color = get_color() + self.color = color or get_color() def __repr__(self): param_str = ', '.join(f"{key}={value!r}" for key, value in self.to_dict().items()) @@ -221,7 +221,7 @@ def __init__(self, color=None, alpha=None, marker=None, size=None, cmap=None, fa super().__init__(color=color, alpha=alpha, marker=marker, size=size, cmap=cmap, face_color=face_color) if self.color is None: - self.color = get_color() + self.color = color or get_color() def __repr__(self): param_str = ', '.join(f"{key}={value!r}" for key, value in self.to_dict().items()) diff --git a/src/mpyez/ezArray.py b/src/mpyez/ezArray.py index 984463f..d8a0dfd 100644 --- a/src/mpyez/ezArray.py +++ b/src/mpyez/ezArray.py @@ -62,3 +62,29 @@ def moving_average(array: np.ndarray, window_size: int) -> np.ndarray: An array containing the moving averages. The length of this array will be `len(array) - window_size + 1`. """ return np.convolve(array, np.ones(window_size), 'valid') / window_size + + +def evaluate_with_broadcast(func, constant_array, **param_arrays): + """ + Evaluate a function with a constant array and multiple parameter arrays using broadcasting. + + Parameters + ---------- + func : callable + The function to evaluate. It must support broadcasting of inputs. The function signature should match `func(constant_array, **param_arrays)`. + constant_array : ndarray + The constant array to evaluate the function over. + **param_arrays : dict + Keyword arguments representing parameter arrays. Each array should be compatible with broadcasting. + + Returns + ------- + ndarray + The result of the function evaluation. The shape of the result is `(len(constant_array), ...)`, + where the additional dimensions correspond to the shapes of the parameter arrays. + """ + # Expand constant_array to add a broadcast dimension + expanded_constant = constant_array[:, np.newaxis] + params_ = {key: value[np.newaxis, :] for key, value in param_arrays.items()} + + return func(expanded_constant, **params_) diff --git a/src/mpyez/ezMultiprocessing.py b/src/mpyez/ezMultiprocessing.py index 70e15e9..b85b3c5 100644 --- a/src/mpyez/ezMultiprocessing.py +++ b/src/mpyez/ezMultiprocessing.py @@ -1 +1,46 @@ """Created on Jun 12 13:48:59 2024""" + +from multiprocessing import Pool +from typing import Callable, Dict, List + +import numpy as np + + +class MultiProcessor: + """Generalized multiprocessing functionality for processing dictionary inputs.""" + + def __init__(self, func: Callable, args: Dict[str, List], n_processors: int = 3): + """ + Initialize the `MultiProcessor` class. + + Parameters + ---------- + func: Callable + Function to be processed. The function should take arguments corresponding to the keys in the dictionary. + args: Dict[str, List] + A dictionary where keys are argument names, and values are lists of values for those arguments. + n_processors: int, optional + The number of processors to use for multiprocessing. Default is 3. + """ + if n_processors < 1: + raise ValueError("`n_processors` must be at least 1.") + + self.func = func + self.args = args + self.n_proc = n_processors + + lengths = [len(v) for v in args.values()] + if len(set(lengths)) != 1: + raise ValueError("All argument lists must have the same length.") + + self.arg_tuples = list(zip(*self.args.values())) + + def run(self) -> np.ndarray: + """Run the multiprocessing task.""" + with Pool(self.n_proc) as pool: + try: + results = pool.starmap(self.func, self.arg_tuples) + except Exception as e: + raise RuntimeError(f"Error occurred during parallel execution: {e}") + + return np.array(results) diff --git a/src/mpyez/ezPlotting.py b/src/mpyez/ezPlotting.py index 65304da..9770cb6 100644 --- a/src/mpyez/ezPlotting.py +++ b/src/mpyez/ezPlotting.py @@ -1,4 +1,4 @@ -"""Created on Jul 23 23:41:18 2024""" +"""Created on Jul 23 23:41:18 2024.""" __all__ = ['plot_two_column_file', 'plot_xy', 'plot_xyy', 'plot_with_dual_axes', 'two_subplots', 'n_plotter'] @@ -27,7 +27,7 @@ def plot_two_column_file(file_name: str, plot_dictionary: Optional[plot_dictionary_type] = None, axis: Optional[Axes] = None) -> axis_return: """ - Reads a two-column file (x, y) and plots the data. + Read a two-column file (x, y) and plots the data. This function reads a file containing two columns (e.g., x and y values) and plots them using either a line plot or scatter plot based on the user's preference. @@ -47,7 +47,7 @@ def plot_two_column_file(file_name: str, is_scatter : bool, optional If True, creates a scatter plot. Otherwise, creates a line plot. Default is False. plot_dictionary: Union[LinePlot, ScatterPlot], optional - An object representing the plot data, either a `LinePlot` or `ScatterPlot`, to be passed to the matplotlib plotting library. + An object representing the plot data, either a `LinePlot` or `ScatterPlot`, to be passed to the matplotlib plotting library. If None, a default plot type will be used. axis: Optional[Axes] The axis object to draw the plots on. If not passed, a new axis object will be created internally. @@ -77,7 +77,7 @@ def plot_xy(x_data: np.ndarray, y_data: np.ndarray, plot_dictionary: Optional[plot_dictionary_type] = None, axis: Optional[Axes] = None) -> axis_return: """ - Plots x_data against y_data with customizable options. + Plot the x_data against y_data with customizable options. This function accepts two arrays of data (x and y) and plots them using either a line plot or scatter plot, with options for labels and figure size. @@ -189,7 +189,7 @@ def plot_with_dual_axes(x1_data: np.ndarray, y1_data: np.ndarray, plot_dictionary: Optional[plot_dictionary_type] = None, axis: Optional[Axes] = None) -> axis_return: """ - Plots data with options for dual axes (x or y) or single axis. + Plot the data with options for dual axes (x or y) or single axis. Parameters ---------- @@ -295,7 +295,7 @@ def two_subplots(x_data: List[np.ndarray], y_data: List[np.ndarray], subplot_dictionary: Optional[uPl.SubPlots] = None, plot_dictionary: Optional[Union[uPl.LinePlot, uPl.ScatterPlot]] = None) -> None: """ - Creates two subplots arranged horizontally or vertically, with optional customization. + Create two subplots arranged horizontally or vertically, with optional customization. This function internally calls `n_plotter` to handle the plotting of each subplot. `n_plotter` arranges the subplots and applies relevant plot and subplot dictionaries. @@ -359,7 +359,7 @@ def n_plotter(x_data: List[np.ndarray], y_data: List[np.ndarray], subplot_dictionary: Optional[uPl.SubPlots] = None, plot_dictionary: Optional[Union[uPl.LinePlot, uPl.ScatterPlot]] = None) -> Union[plt.figure, Axes]: """ - Plots multiple subplots in a grid with optional customization for each subplot. + Plot multiple subplots in a grid with optional customization for each subplot. Parameters ---------- diff --git a/src/mpyez/version.py b/src/mpyez/version.py index 054cc11..90612e8 100644 --- a/src/mpyez/version.py +++ b/src/mpyez/version.py @@ -1,6 +1,7 @@ """Created on Nov 02 21:24:07 2024""" -__version__ = "0.0.9a0" +__name__ = 'mpyez' +__version__ = "0.1.0" __author__ = "Syed Ali Mohsin Bukhari" __email__ = "syedali.b@outlook.com" __license__ = "MIT"