diff --git a/.mypy.ini b/.mypy.ini index 00f3b68e757..ee120cac22f 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -43,14 +43,6 @@ ignore_errors = True ignore_missing_imports = True ignore_errors = True -[mypy-ert.gui.plottery.*] -ignore_missing_imports = True -ignore_errors = True - -[mypy-ert.gui.resources.*] -ignore_missing_imports = True -ignore_errors = True - [mypy-ert.gui.simulation.*] ignore_missing_imports = True ignore_errors = True @@ -59,6 +51,9 @@ ignore_errors = True ignore_missing_imports = True ignore_errors = True +[mypy-mpl_toolkits.*] +ignore_missing_imports = True + [mypy-cwrap.*] ignore_missing_imports = True diff --git a/src/ert/gui/plottery/plot_config.py b/src/ert/gui/plottery/plot_config.py index 0c4f812eaf0..e30b9056c89 100644 --- a/src/ert/gui/plottery/plot_config.py +++ b/src/ert/gui/plottery/plot_config.py @@ -11,7 +11,13 @@ class PlotConfig: # The plot_settings input argument is an internalisation of the (quite few) plot # policy settings which can be set in the configuration file. - def __init__(self, plot_settings=None, title="Unnamed", x_label=None, y_label=None): + def __init__( + self, + plot_settings: Optional[dict[str, Any]] = None, + title: str = "Unnamed", + x_label: Optional[str] = None, + y_label: Optional[str] = None, + ): self._title = title self._plot_settings = plot_settings if self._plot_settings is None: @@ -25,8 +31,8 @@ def __init__(self, plot_settings=None, title="Unnamed", x_label=None, y_label=No # ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33", # "#a65628", "#f781bf" ,"#386CB0", "#7FC97F", "#FDC086", "#F0027F", "#BF5B17"] - self._legend_items = [] - self._legend_labels = [] + self._legend_items: list[Any] = [] + self._legend_labels: list[str] = [] self._x_label = x_label self._y_label = y_label @@ -62,7 +68,7 @@ def __init__(self, plot_settings=None, title="Unnamed", x_label=None, y_label=No name="Distribution lines", line_style="-", alpha=0.25, width=1.0 ) self._distribution_line_style.setEnabled(False) - self._current_color = None + self._current_color: Optional[str] = None self._legend_enabled = True self._grid_enabled = True @@ -80,12 +86,13 @@ def __init__(self, plot_settings=None, title="Unnamed", x_label=None, y_label=No def currentColor(self) -> str: if self._current_color is None: - self.nextColor() + return self.nextColor() return self._current_color - def nextColor(self): - self._current_color = next(self._line_color_cycle) + def nextColor(self) -> str: + color = next(self._line_color_cycle) + self._current_color = color return self._current_color def setLineColorCycle(self, color_list: List[str]) -> None: @@ -150,7 +157,7 @@ def xLabel(self) -> Optional[str]: def yLabel(self) -> Optional[str]: return self._y_label - def legendItems(self): + def legendItems(self) -> List[Any]: return self._legend_items def legendLabels(self) -> List[str]: @@ -183,10 +190,10 @@ def isDistributionLineEnabled(self) -> bool: def setDistributionLineEnabled(self, enabled: bool) -> None: self._distribution_line_style.setEnabled(enabled) - def setStandardDeviationFactor(self, value: float) -> None: + def setStandardDeviationFactor(self, value: int) -> None: self._std_dev_factor = value - def getStandardDeviationFactor(self) -> float: + def getStandardDeviationFactor(self) -> int: return self._std_dev_factor def setLegendEnabled(self, enabled: bool) -> None: diff --git a/src/ert/gui/plottery/plot_context.py b/src/ert/gui/plottery/plot_context.py index 0ee2636164d..ecde8c4e432 100644 --- a/src/ert/gui/plottery/plot_context.py +++ b/src/ert/gui/plottery/plot_context.py @@ -5,6 +5,8 @@ from ert.gui.tools.plot.plot_api import EnsembleObject if TYPE_CHECKING: + from pandas import DataFrame + from ert.gui.plottery import PlotConfig @@ -35,7 +37,7 @@ def __init__( self._key = key self._ensembles = ensembles self._plot_config = plot_config - self.history_data = None + self.history_data: Optional[DataFrame] = None self._log_scale = False self._layer: Optional[int] = layer diff --git a/src/ert/gui/plottery/plot_style.py b/src/ert/gui/plottery/plot_style.py index 45fdbfa60cd..009e510ae22 100644 --- a/src/ert/gui/plottery/plot_style.py +++ b/src/ert/gui/plottery/plot_style.py @@ -46,23 +46,7 @@ def setEnabled(self, enabled: bool) -> None: self._enabled = enabled def isVisible(self) -> bool: - return self.line_style or self.marker - - @property - def name(self) -> str: - return self._name - - @name.setter - def name(self, name: str) -> None: - self._name = name - - @property - def color(self) -> str: - return self._color - - @color.setter - def color(self, color): - self._color = color + return bool(self.line_style or self.marker) @property def alpha(self) -> float: diff --git a/src/ert/gui/plottery/plots/cesp.py b/src/ert/gui/plottery/plots/cesp.py index e124356b741..a8c8763a278 100644 --- a/src/ert/gui/plottery/plots/cesp.py +++ b/src/ert/gui/plottery/plots/cesp.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, TypedDict +from typing import TYPE_CHECKING, Dict, List, TypedDict import pandas as pd from matplotlib.lines import Line2D @@ -43,12 +43,12 @@ def __init__(self) -> None: def plot( figure: Figure, plot_context: PlotContext, - ensemble_to_data_map: Dict[str, DataFrame], - _observation_data: DataFrame, - std_dev_images: Any, + ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + observation_data: pd.DataFrame, + std_dev_images: Dict[str, bytes], ) -> None: plotCrossEnsembleStatistics( - figure, plot_context, ensemble_to_data_map, _observation_data + figure, plot_context, ensemble_to_data_map, observation_data ) @@ -161,24 +161,24 @@ def _addStatisticsLegend( plot_config.addLegendItem(style.name, line) -def _assertNumeric(data): - data = data[0] - if data.dtype == "object": +def _assertNumeric(data: pd.DataFrame) -> pd.Series: + data_series = data[0] + if data_series.dtype == "object": try: - data = pd.to_numeric(data, errors="coerce") + data_series = pd.to_numeric(data_series, errors="coerce") except AttributeError: - data = data.convert_objects(convert_numeric=True) + data_series = data_series.convert_objects(convert_numeric=True) - if data.dtype == "object": - data = None - return data + if data_series.dtype == "object": + data_series = None + return data_series def _plotCrossEnsembleStatistics( axes: "Axes", plot_config: "PlotConfig", data: CcsData, index: int -): - axes.set_xlabel(plot_config.xLabel()) - axes.set_ylabel(plot_config.yLabel()) +) -> None: + axes.set_xlabel(plot_config.xLabel()) # type: ignore + axes.set_ylabel(plot_config.yLabel()) # type: ignore style = plot_config.getStatisticsStyle("mean") if style.isVisible(): diff --git a/src/ert/gui/plottery/plots/distribution.py b/src/ert/gui/plottery/plots/distribution.py index aa91bc4eac4..6d527a4c8b0 100644 --- a/src/ert/gui/plottery/plots/distribution.py +++ b/src/ert/gui/plottery/plots/distribution.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List +from typing import TYPE_CHECKING, Dict, List, Optional import pandas as pd @@ -11,7 +11,6 @@ if TYPE_CHECKING: from matplotlib.axes import Axes from matplotlib.figure import Figure - from pandas import DataFrame from ert.gui.plottery import PlotConfig, PlotContext @@ -24,11 +23,11 @@ def __init__(self) -> None: def plot( figure: Figure, plot_context: PlotContext, - ensemble_to_data_map: dict[str, DataFrame], - _observation_data: DataFrame, - std_dev_images: Any, + ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + observation_data: pd.DataFrame, + std_dev_images: Dict[str, bytes], ) -> None: - plotDistribution(figure, plot_context, ensemble_to_data_map, _observation_data) + plotDistribution(figure, plot_context, ensemble_to_data_map, observation_data) def plotDistribution( @@ -89,12 +88,12 @@ def _plotDistribution( data: pd.DataFrame, label: str, index: int, - previous_data, -): + previous_data: Optional[pd.DataFrame], +) -> None: data = pd.Series(dtype="float64") if data.empty else data[0] - axes.set_xlabel(plot_config.xLabel()) - axes.set_ylabel(plot_config.yLabel()) + axes.set_xlabel(plot_config.xLabel()) # type: ignore + axes.set_ylabel(plot_config.yLabel()) # type: ignore style = plot_config.distributionStyle() diff --git a/src/ert/gui/plottery/plots/gaussian_kde.py b/src/ert/gui/plottery/plots/gaussian_kde.py index 8ead59350ec..0e38ae87612 100644 --- a/src/ert/gui/plottery/plots/gaussian_kde.py +++ b/src/ert/gui/plottery/plots/gaussian_kde.py @@ -25,11 +25,11 @@ def __init__(self) -> None: def plot( figure: Figure, plot_context: PlotContext, - ensemble_to_data_map: Dict[str, pd.DataFrame], - _observation_data: Any, - std_dev_images: Any, + ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + observation_data: pd.DataFrame, + std_dev_images: Dict[str, bytes], ) -> None: - plotGaussianKDE(figure, plot_context, ensemble_to_data_map, _observation_data) + plotGaussianKDE(figure, plot_context, ensemble_to_data_map, observation_data) def plotGaussianKDE( diff --git a/src/ert/gui/plottery/plots/histogram.py b/src/ert/gui/plottery/plots/histogram.py index 51e5be44b12..f8ecc56b6c3 100644 --- a/src/ert/gui/plottery/plots/histogram.py +++ b/src/ert/gui/plottery/plots/histogram.py @@ -1,7 +1,7 @@ from __future__ import annotations from math import ceil, floor, log10, sqrt -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union import numpy import pandas as pd @@ -14,6 +14,7 @@ if TYPE_CHECKING: from matplotlib.axes import Axes from matplotlib.figure import Figure + from numpy.typing import NDArray from ert.gui.plottery import PlotContext, PlotStyle @@ -101,8 +102,8 @@ def plotHistogram( else: current_min = data[ensemble.name].min() current_max = data[ensemble.name].max() - minimum = current_min if minimum is None else min(minimum, current_min) - maximum = current_max if maximum is None else max(maximum, current_max) + minimum = current_min if minimum is None else min(minimum, current_min) # type: ignore + maximum = current_max if maximum is None else max(maximum, current_max) # type: ignore max_element_count = max(max_element_count, len(data[ensemble.name].index)) bin_count = int(ceil(sqrt(max_element_count))) @@ -193,11 +194,12 @@ def _plotHistogram( minimum: Optional[float] = None, maximum: Optional[float] = None, ) -> Rectangle: + bins: Union[Sequence[float], int] if minimum is not None and maximum is not None: if use_log_scale: - bins = _histogramLogBins(bin_count, minimum, maximum) + bins = _histogramLogBins(bin_count, minimum, maximum) # type: ignore else: - bins = numpy.linspace(minimum, maximum, bin_count) + bins = numpy.linspace(minimum, maximum, bin_count) # type: ignore if minimum == maximum: minimum -= 0.5 @@ -214,7 +216,9 @@ def _plotHistogram( ) # creates rectangle patch for legend use.' -def _histogramLogBins(bin_count: int, minimum: float, maximum: float): +def _histogramLogBins( + bin_count: int, minimum: float, maximum: float +) -> NDArray[numpy.floating[Any]]: minimum = log10(float(minimum)) maximum = log10(float(maximum)) diff --git a/src/ert/gui/plottery/plots/plot_tools.py b/src/ert/gui/plottery/plots/plot_tools.py index 694e4c3a56a..9ef2f883b0e 100644 --- a/src/ert/gui/plottery/plots/plot_tools.py +++ b/src/ert/gui/plottery/plots/plot_tools.py @@ -3,10 +3,12 @@ from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: + from datetime import date + from matplotlib.axes import Axes from matplotlib.figure import Figure - from ert.gui.plottery import PlotContext, PlotLimits + from ert.gui.plottery import PlotContext class PlotTools: @@ -23,7 +25,13 @@ def showLegend(axes: Axes, plot_context: PlotContext) -> None: axes.legend(config.legendItems(), config.legendLabels(), numpoints=1) @staticmethod - def _getXAxisLimits(plot_context: PlotContext) -> Optional[PlotLimits]: + def _getXAxisLimits( + plot_context: PlotContext, + ) -> Optional[ + tuple[Optional[int], Optional[int]] + | tuple[Optional[float], Optional[float]] + | tuple[Optional[date], Optional[date]] + ]: limits = plot_context.plotConfig().limits axis_name = plot_context.x_axis @@ -41,7 +49,13 @@ def _getXAxisLimits(plot_context: PlotContext) -> Optional[PlotLimits]: return None # No limits set @staticmethod - def _getYAxisLimits(plot_context: PlotContext) -> Optional[PlotLimits]: + def _getYAxisLimits( + plot_context: PlotContext, + ) -> Optional[ + tuple[Optional[int], Optional[int]] + | tuple[Optional[float], Optional[float]] + | tuple[Optional[date], Optional[date]] + ]: limits = plot_context.plotConfig().limits axis_name = plot_context.y_axis @@ -72,8 +86,8 @@ def finalizePlot( PlotTools.__setupLabels(plot_context, default_x_label, default_y_label) plot_config = plot_context.plotConfig() - axes.set_xlabel(plot_config.xLabel()) - axes.set_ylabel(plot_config.yLabel()) + axes.set_xlabel(plot_config.xLabel()) # type: ignore + axes.set_ylabel(plot_config.yLabel()) # type: ignore x_axis_limits = PlotTools._getXAxisLimits(plot_context) if x_axis_limits is not None: diff --git a/src/ert/gui/plottery/plots/statistics.py b/src/ert/gui/plottery/plots/statistics.py index c7f9574a306..9c20e033268 100644 --- a/src/ert/gui/plottery/plots/statistics.py +++ b/src/ert/gui/plottery/plots/statistics.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Dict from matplotlib.lines import Line2D from matplotlib.patches import Rectangle @@ -17,6 +17,8 @@ from matplotlib.axes import Axes from matplotlib.figure import Figure + from ert.gui.tools.plot.plot_api import EnsembleObject + class StatisticsPlot: def __init__(self) -> None: @@ -26,10 +28,10 @@ def __init__(self) -> None: def plot( figure: Figure, plot_context: PlotContext, - ensemble_to_data_map, - _observation_data, - std_dev_images, - ): + ensemble_to_data_map: Dict[EnsembleObject, DataFrame], + _observation_data: DataFrame, + std_dev_images: Dict[str, bytes], + ) -> None: config = plot_context.plotConfig() axes = figure.add_subplot(111) diff --git a/src/ert/gui/plottery/plots/std_dev.py b/src/ert/gui/plottery/plots/std_dev.py index 6f7767836ba..05087ce4424 100644 --- a/src/ert/gui/plottery/plots/std_dev.py +++ b/src/ert/gui/plottery/plots/std_dev.py @@ -2,10 +2,13 @@ from typing import Any, Dict import matplotlib.pyplot as plt +import pandas as pd +from matplotlib.collections import QuadMesh from matplotlib.figure import Figure from mpl_toolkits.axes_grid1 import make_axes_locatable from ert.gui.plottery import PlotContext +from ert.gui.tools.plot.plot_api import EnsembleObject class StdDevPlot: @@ -16,8 +19,8 @@ def plot( self, figure: Figure, plot_context: PlotContext, - ensemble_to_data_map, - _observation_data, + ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], + observation_data: pd.DataFrame, std_dev_images: Dict[str, bytes], ) -> None: ensemble_count = len(plot_context.ensembles()) @@ -45,11 +48,13 @@ def plot( self._colorbar(p) @staticmethod - def _colorbar(mappable) -> Any: + def _colorbar(mappable: QuadMesh) -> Any: # https://joseph-long.com/writing/colorbars/ last_axes = plt.gca() ax = mappable.axes + assert ax is not None fig = ax.figure + assert fig is not None divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) cbar = fig.colorbar(mappable, cax=cax) diff --git a/src/ert/gui/tools/plot/plot_api.py b/src/ert/gui/tools/plot/plot_api.py index 1b3bb6c2dd8..ff3adecb91a 100644 --- a/src/ert/gui/tools/plot/plot_api.py +++ b/src/ert/gui/tools/plot/plot_api.py @@ -175,7 +175,7 @@ def data_for_key(self, ensemble_name: str, key: str) -> pd.DataFrame: except ValueError: return df - def observations_for_key(self, ensemble_name, key) -> pd.DataFrame: + def observations_for_key(self, ensemble_name: str, key: str) -> pd.DataFrame: """Returns a pandas DataFrame with the datapoints for a given observation key for a given ensemble. The row index is the realization number, and the column index is a multi-index with (obs_key, index/date, obs_index), where index/date is @@ -213,7 +213,7 @@ def observations_for_key(self, ensemble_name, key) -> pd.DataFrame: } return pd.DataFrame(data_struct).T - def history_data(self, key, ensembles: Optional[List[str]]) -> pd.DataFrame: + def history_data(self, key: str, ensembles: Optional[List[str]]) -> pd.DataFrame: """Returns a pandas DataFrame with the data points for the history for a given data key, if any. The row index is the index/date and the column index is the key.""" diff --git a/src/ert/gui/tools/plot/plot_widget.py b/src/ert/gui/tools/plot/plot_widget.py index d995db060b2..ff63ea3cae9 100644 --- a/src/ert/gui/tools/plot/plot_widget.py +++ b/src/ert/gui/tools/plot/plot_widget.py @@ -136,8 +136,8 @@ def updatePlot( self, plot_context: "PlotContext", ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame], - observations: Optional[pd.DataFrame] = None, - std_dev_images: Optional[bytes] = None, + observations: pd.DataFrame, + std_dev_images: Dict[str, bytes], ): self.resetPlot() try: