Skip to content

Commit

Permalink
Type check ert.gui.plottery
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Jun 13, 2024
1 parent 634792c commit b844a05
Show file tree
Hide file tree
Showing 13 changed files with 101 additions and 89 deletions.
11 changes: 3 additions & 8 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,6 @@ ignore_errors = True
ignore_missing_imports = True
ignore_errors = True

[mypy-ert.gui.plottery.*]
ignore_missing_imports = True
ignore_errors = True

[mypy-ert.gui.resources.*]
ignore_missing_imports = True
ignore_errors = True

[mypy-ert.gui.simulation.*]
ignore_missing_imports = True
ignore_errors = True
Expand All @@ -59,6 +51,9 @@ ignore_errors = True
ignore_missing_imports = True
ignore_errors = True

[mypy-mpl_toolkits.*]
ignore_missing_imports = True

[mypy-cwrap.*]
ignore_missing_imports = True

Expand Down
27 changes: 17 additions & 10 deletions src/ert/gui/plottery/plot_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@ class PlotConfig:
# The plot_settings input argument is an internalisation of the (quite few) plot
# policy settings which can be set in the configuration file.

def __init__(self, plot_settings=None, title="Unnamed", x_label=None, y_label=None):
def __init__(
self,
plot_settings: Optional[dict[str, Any]] = None,
title: str = "Unnamed",
x_label: Optional[str] = None,
y_label: Optional[str] = None,
):
self._title = title
self._plot_settings = plot_settings
if self._plot_settings is None:
Expand All @@ -25,8 +31,8 @@ def __init__(self, plot_settings=None, title="Unnamed", x_label=None, y_label=No
# ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33",
# "#a65628", "#f781bf" ,"#386CB0", "#7FC97F", "#FDC086", "#F0027F", "#BF5B17"]

self._legend_items = []
self._legend_labels = []
self._legend_items: list[Any] = []
self._legend_labels: list[str] = []

self._x_label = x_label
self._y_label = y_label
Expand Down Expand Up @@ -62,7 +68,7 @@ def __init__(self, plot_settings=None, title="Unnamed", x_label=None, y_label=No
name="Distribution lines", line_style="-", alpha=0.25, width=1.0
)
self._distribution_line_style.setEnabled(False)
self._current_color = None
self._current_color: Optional[str] = None

self._legend_enabled = True
self._grid_enabled = True
Expand All @@ -80,12 +86,13 @@ def __init__(self, plot_settings=None, title="Unnamed", x_label=None, y_label=No

def currentColor(self) -> str:
if self._current_color is None:
self.nextColor()
return self.nextColor()

return self._current_color

def nextColor(self):
self._current_color = next(self._line_color_cycle)
def nextColor(self) -> str:
color = next(self._line_color_cycle)
self._current_color = color
return self._current_color

def setLineColorCycle(self, color_list: List[str]) -> None:
Expand Down Expand Up @@ -150,7 +157,7 @@ def xLabel(self) -> Optional[str]:
def yLabel(self) -> Optional[str]:
return self._y_label

def legendItems(self):
def legendItems(self) -> List[Any]:
return self._legend_items

def legendLabels(self) -> List[str]:
Expand Down Expand Up @@ -183,10 +190,10 @@ def isDistributionLineEnabled(self) -> bool:
def setDistributionLineEnabled(self, enabled: bool) -> None:
self._distribution_line_style.setEnabled(enabled)

def setStandardDeviationFactor(self, value: float) -> None:
def setStandardDeviationFactor(self, value: int) -> None:
self._std_dev_factor = value

def getStandardDeviationFactor(self) -> float:
def getStandardDeviationFactor(self) -> int:
return self._std_dev_factor

def setLegendEnabled(self, enabled: bool) -> None:
Expand Down
4 changes: 3 additions & 1 deletion src/ert/gui/plottery/plot_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from ert.gui.tools.plot.plot_api import EnsembleObject

if TYPE_CHECKING:
from pandas import DataFrame

from ert.gui.plottery import PlotConfig


Expand Down Expand Up @@ -35,7 +37,7 @@ def __init__(
self._key = key
self._ensembles = ensembles
self._plot_config = plot_config
self.history_data = None
self.history_data: Optional[DataFrame] = None
self._log_scale = False
self._layer: Optional[int] = layer

Expand Down
18 changes: 1 addition & 17 deletions src/ert/gui/plottery/plot_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,7 @@ def setEnabled(self, enabled: bool) -> None:
self._enabled = enabled

def isVisible(self) -> bool:
return self.line_style or self.marker

@property
def name(self) -> str:
return self._name

@name.setter
def name(self, name: str) -> None:
self._name = name

@property
def color(self) -> str:
return self._color

@color.setter
def color(self, color):
self._color = color
return bool(self.line_style or self.marker)

@property
def alpha(self) -> float:
Expand Down
32 changes: 16 additions & 16 deletions src/ert/gui/plottery/plots/cesp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, TypedDict
from typing import TYPE_CHECKING, Dict, List, TypedDict

import pandas as pd
from matplotlib.lines import Line2D
Expand Down Expand Up @@ -43,12 +43,12 @@ def __init__(self) -> None:
def plot(
figure: Figure,
plot_context: PlotContext,
ensemble_to_data_map: Dict[str, DataFrame],
_observation_data: DataFrame,
std_dev_images: Any,
ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame],
observation_data: pd.DataFrame,
std_dev_images: Dict[str, bytes],
) -> None:
plotCrossEnsembleStatistics(
figure, plot_context, ensemble_to_data_map, _observation_data
figure, plot_context, ensemble_to_data_map, observation_data
)


Expand Down Expand Up @@ -161,24 +161,24 @@ def _addStatisticsLegend(
plot_config.addLegendItem(style.name, line)


def _assertNumeric(data):
data = data[0]
if data.dtype == "object":
def _assertNumeric(data: pd.DataFrame) -> pd.Series:
data_series = data[0]
if data_series.dtype == "object":
try:
data = pd.to_numeric(data, errors="coerce")
data_series = pd.to_numeric(data_series, errors="coerce")
except AttributeError:
data = data.convert_objects(convert_numeric=True)
data_series = data_series.convert_objects(convert_numeric=True)

if data.dtype == "object":
data = None
return data
if data_series.dtype == "object":
data_series = None
return data_series


def _plotCrossEnsembleStatistics(
axes: "Axes", plot_config: "PlotConfig", data: CcsData, index: int
):
axes.set_xlabel(plot_config.xLabel())
axes.set_ylabel(plot_config.yLabel())
) -> None:
axes.set_xlabel(plot_config.xLabel()) # type: ignore
axes.set_ylabel(plot_config.yLabel()) # type: ignore

style = plot_config.getStatisticsStyle("mean")
if style.isVisible():
Expand Down
19 changes: 9 additions & 10 deletions src/ert/gui/plottery/plots/distribution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Dict, List, Optional

import pandas as pd

Expand All @@ -11,7 +11,6 @@
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from pandas import DataFrame

from ert.gui.plottery import PlotConfig, PlotContext

Expand All @@ -24,11 +23,11 @@ def __init__(self) -> None:
def plot(
figure: Figure,
plot_context: PlotContext,
ensemble_to_data_map: dict[str, DataFrame],
_observation_data: DataFrame,
std_dev_images: Any,
ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame],
observation_data: pd.DataFrame,
std_dev_images: Dict[str, bytes],
) -> None:
plotDistribution(figure, plot_context, ensemble_to_data_map, _observation_data)
plotDistribution(figure, plot_context, ensemble_to_data_map, observation_data)


def plotDistribution(
Expand Down Expand Up @@ -89,12 +88,12 @@ def _plotDistribution(
data: pd.DataFrame,
label: str,
index: int,
previous_data,
):
previous_data: Optional[pd.DataFrame],
) -> None:
data = pd.Series(dtype="float64") if data.empty else data[0]

axes.set_xlabel(plot_config.xLabel())
axes.set_ylabel(plot_config.yLabel())
axes.set_xlabel(plot_config.xLabel()) # type: ignore
axes.set_ylabel(plot_config.yLabel()) # type: ignore

style = plot_config.distributionStyle()

Expand Down
8 changes: 4 additions & 4 deletions src/ert/gui/plottery/plots/gaussian_kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ def __init__(self) -> None:
def plot(
figure: Figure,
plot_context: PlotContext,
ensemble_to_data_map: Dict[str, pd.DataFrame],
_observation_data: Any,
std_dev_images: Any,
ensemble_to_data_map: Dict[EnsembleObject, pd.DataFrame],
observation_data: pd.DataFrame,
std_dev_images: Dict[str, bytes],
) -> None:
plotGaussianKDE(figure, plot_context, ensemble_to_data_map, _observation_data)
plotGaussianKDE(figure, plot_context, ensemble_to_data_map, observation_data)


def plotGaussianKDE(
Expand Down
16 changes: 10 additions & 6 deletions src/ert/gui/plottery/plots/histogram.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from math import ceil, floor, log10, sqrt
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

import numpy
import pandas as pd
Expand All @@ -14,6 +14,7 @@
if TYPE_CHECKING:
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from numpy.typing import NDArray

from ert.gui.plottery import PlotContext, PlotStyle

Expand Down Expand Up @@ -101,8 +102,8 @@ def plotHistogram(
else:
current_min = data[ensemble.name].min()
current_max = data[ensemble.name].max()
minimum = current_min if minimum is None else min(minimum, current_min)
maximum = current_max if maximum is None else max(maximum, current_max)
minimum = current_min if minimum is None else min(minimum, current_min) # type: ignore
maximum = current_max if maximum is None else max(maximum, current_max) # type: ignore
max_element_count = max(max_element_count, len(data[ensemble.name].index))

bin_count = int(ceil(sqrt(max_element_count)))
Expand Down Expand Up @@ -193,11 +194,12 @@ def _plotHistogram(
minimum: Optional[float] = None,
maximum: Optional[float] = None,
) -> Rectangle:
bins: Union[Sequence[float], int]
if minimum is not None and maximum is not None:
if use_log_scale:
bins = _histogramLogBins(bin_count, minimum, maximum)
bins = _histogramLogBins(bin_count, minimum, maximum) # type: ignore
else:
bins = numpy.linspace(minimum, maximum, bin_count)
bins = numpy.linspace(minimum, maximum, bin_count) # type: ignore

if minimum == maximum:
minimum -= 0.5
Expand All @@ -214,7 +216,9 @@ def _plotHistogram(
) # creates rectangle patch for legend use.'


def _histogramLogBins(bin_count: int, minimum: float, maximum: float):
def _histogramLogBins(
bin_count: int, minimum: float, maximum: float
) -> NDArray[numpy.floating[Any]]:
minimum = log10(float(minimum))
maximum = log10(float(maximum))

Expand Down
24 changes: 19 additions & 5 deletions src/ert/gui/plottery/plots/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from datetime import date

from matplotlib.axes import Axes
from matplotlib.figure import Figure

from ert.gui.plottery import PlotContext, PlotLimits
from ert.gui.plottery import PlotContext


class PlotTools:
Expand All @@ -23,7 +25,13 @@ def showLegend(axes: Axes, plot_context: PlotContext) -> None:
axes.legend(config.legendItems(), config.legendLabels(), numpoints=1)

@staticmethod
def _getXAxisLimits(plot_context: PlotContext) -> Optional[PlotLimits]:
def _getXAxisLimits(
plot_context: PlotContext,
) -> Optional[
tuple[Optional[int], Optional[int]]
| tuple[Optional[float], Optional[float]]
| tuple[Optional[date], Optional[date]]
]:
limits = plot_context.plotConfig().limits
axis_name = plot_context.x_axis

Expand All @@ -41,7 +49,13 @@ def _getXAxisLimits(plot_context: PlotContext) -> Optional[PlotLimits]:
return None # No limits set

@staticmethod
def _getYAxisLimits(plot_context: PlotContext) -> Optional[PlotLimits]:
def _getYAxisLimits(
plot_context: PlotContext,
) -> Optional[
tuple[Optional[int], Optional[int]]
| tuple[Optional[float], Optional[float]]
| tuple[Optional[date], Optional[date]]
]:
limits = plot_context.plotConfig().limits
axis_name = plot_context.y_axis

Expand Down Expand Up @@ -72,8 +86,8 @@ def finalizePlot(
PlotTools.__setupLabels(plot_context, default_x_label, default_y_label)

plot_config = plot_context.plotConfig()
axes.set_xlabel(plot_config.xLabel())
axes.set_ylabel(plot_config.yLabel())
axes.set_xlabel(plot_config.xLabel()) # type: ignore
axes.set_ylabel(plot_config.yLabel()) # type: ignore

x_axis_limits = PlotTools._getXAxisLimits(plot_context)
if x_axis_limits is not None:
Expand Down
Loading

0 comments on commit b844a05

Please sign in to comment.