Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Teach OutliersTransform to ignore holidays #291

Merged
merged 32 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
7c33aed
add feature
Polzovat123 Mar 28, 2024
8faf807
issue 9043
Polzovat123 Mar 29, 2024
94561b8
fix lint error
Polzovat123 Apr 2, 2024
371aef1
clean code
Polzovat123 Apr 3, 2024
9d62d81
more informative docs
Polzovat123 Apr 3, 2024
e29fc47
add tests
Polzovat123 Apr 3, 2024
cfb7cbd
write CHANGELOG.md
Polzovat123 Apr 3, 2024
ec2d517
clear code
Polzovat123 Apr 3, 2024
3e745df
fix bug
Polzovat123 Apr 3, 2024
6840c75
fix style
Polzovat123 Apr 3, 2024
b28cbc1
fix
Polzovat123 Apr 3, 2024
15f5736
fix CHANGELOG.md
Polzovat123 Apr 4, 2024
b2a0f62
fix use level name instead of index to improve readability
Polzovat123 Apr 4, 2024
f40e0a9
more descriptive name
Polzovat123 Apr 4, 2024
68ffccc
use isna
Polzovat123 Apr 4, 2024
0a2d78b
separate test for error
Polzovat123 Apr 4, 2024
d65f202
clear duplication
Polzovat123 Apr 4, 2024
68a5deb
clear after lint
Polzovat123 Apr 4, 2024
28d7221
clear add test pipeline
Polzovat123 Apr 4, 2024
f73f9bf
fix typo mistake
Polzovat123 Apr 5, 2024
d24e3fa
add link in CHANGELOG.md
Polzovat123 Apr 5, 2024
c9738ab
add link in CHANGELOG.md
Polzovat123 Apr 5, 2024
4eb344a
remove duplicate code
Polzovat123 Apr 5, 2024
51b5044
rewrite pipeline test, that shows real usage example
Polzovat123 Apr 5, 2024
211b681
all -> any
Polzovat123 Apr 5, 2024
8927a7a
test behaviour if ignore_column is regressor/non-regressor
Polzovat123 Apr 5, 2024
9445f50
clear code
Polzovat123 Apr 5, 2024
3b5f167
fix pull, not issue
Polzovat123 Apr 5, 2024
4a00daf
remove useless test
Polzovat123 Apr 5, 2024
6c83b12
remove useless tune_params check
Polzovat123 Apr 5, 2024
d925f0a
Merge branch 'master' into issue-1850009043
Polzovat123 Apr 8, 2024
95e5129
Merge branch 'master' into issue-1850009043
Polzovat123 Apr 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions etna/transforms/outliers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,23 @@
class OutliersTransform(ReversibleTransform, ABC):
"""Finds outliers in specific columns of DataFrame and replaces it with NaNs."""

def __init__(self, in_column: str):
def __init__(self, in_column: str, ignore_flag_column: Optional[str] = None):
"""
Create instance of OutliersTransform.

Parameters
----------
in_column:
name of processed column
ignore_flag_column:
name of column binary flag of holidays
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
"""
super().__init__(required_features=[in_column])
if ignore_flag_column:
super().__init__(required_features=[in_column, ignore_flag_column])
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
else:
super().__init__(required_features=[in_column])
self.in_column = in_column
self.ignore_flag_column = ignore_flag_column

self.segment_outliers: Optional[Dict[str, pd.Series]] = None

Expand Down Expand Up @@ -78,6 +84,12 @@
:
The fitted transform instance.
"""
if self.ignore_flag_column is not None and self.ignore_flag_column not in ts.regressors:
raise ValueError("Name ignore_flag_column not find.")
if self.ignore_flag_column is not None and not all(
ts[:, segment, self.ignore_flag_column].isin([0, 1]).all() for segment in ts.segments
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
):
raise ValueError("Columns ignore_flag contain non binary value")

Check warning on line 92 in etna/transforms/outliers/base.py

View check run for this annotation

Codecov / codecov/patch

etna/transforms/outliers/base.py#L92

Added line #L92 was not covered by tests
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
self.segment_outliers = self.detect_outliers(ts)
self._fit_segments = ts.segments
super().fit(ts=ts)
Expand Down Expand Up @@ -131,8 +143,16 @@
if segment not in segments:
continue
# to locate only present indices
segment_outliers_timestamps = list(index_set.intersection(self.segment_outliers[segment].index.values))
if self.ignore_flag_column:
available_points = set(df[df[segment, self.ignore_flag_column] == 0].index.values)
else:
available_points = index_set
segment_outliers_timestamps = list(
available_points.intersection(self.segment_outliers[segment].index.values)
)

df.loc[segment_outliers_timestamps, pd.IndexSlice[segment, self.in_column]] = np.NaN

return df

def _inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
Expand Down
22 changes: 17 additions & 5 deletions etna/transforms/outliers/point_outliers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Type
from typing import Union

Expand Down Expand Up @@ -32,7 +33,13 @@ class MedianOutliersTransform(OutliersTransform):
it uses information from the whole train part.
"""

def __init__(self, in_column: str, window_size: int = 10, alpha: float = 3):
def __init__(
self,
in_column: str,
window_size: int = 10,
alpha: float = 3,
ignore_flag_column: Optional[str] = None,
):
"""Create instance of MedianOutliersTransform.

Parameters
Expand All @@ -43,10 +50,12 @@ def __init__(self, in_column: str, window_size: int = 10, alpha: float = 3):
number of points in the window
alpha:
coefficient for determining the threshold
ignore_flag_column:
name of column binary flag of holidays
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
"""
self.window_size = window_size
self.alpha = alpha
super().__init__(in_column=in_column)
super().__init__(in_column=in_column, ignore_flag_column=ignore_flag_column)

def detect_outliers(self, ts: TSDataset) -> Dict[str, List[pd.Timestamp]]:
"""Call :py:func:`~etna.analysis.outliers.median_outliers.get_anomalies_median` function with self parameters.
Expand Down Expand Up @@ -97,6 +106,7 @@ def __init__(
distance_coef: float = 3,
n_neighbors: int = 3,
distance_func: Union[Literal["absolute_difference"], Callable[[float, float], float]] = "absolute_difference",
ignore_flag_column: Optional[str] = None,
):
"""Create instance of DensityOutliersTransform.

Expand All @@ -118,7 +128,7 @@ def __init__(
self.distance_coef = distance_coef
self.n_neighbors = n_neighbors
self.distance_func = distance_func
super().__init__(in_column=in_column)
super().__init__(in_column=in_column, ignore_flag_column=ignore_flag_column)

def detect_outliers(self, ts: TSDataset) -> Dict[str, List[pd.Timestamp]]:
"""Call :py:func:`~etna.analysis.outliers.density_outliers.get_anomalies_density` function with self parameters.
Expand Down Expand Up @@ -169,6 +179,7 @@ def __init__(
in_column: str,
model: Union[Literal["prophet"], Literal["sarimax"], Type["ProphetModel"], Type["SARIMAXModel"]],
interval_width: float = 0.95,
ignore_flag_column: Optional[str] = None,
**model_kwargs,
):
"""Create instance of PredictionIntervalOutliersTransform.
Expand All @@ -181,7 +192,8 @@ def __init__(
model for prediction interval estimation
interval_width:
width of the prediction interval

ignore_flag_column:
name of column binary flag of holidays
Notes
-----
For not "target" column only column data will be used for learning.
Expand All @@ -190,7 +202,7 @@ def __init__(
self.interval_width = interval_width
self.model_kwargs = model_kwargs
self._model_type = self._get_model_type(model)
super().__init__(in_column=in_column)
super().__init__(in_column=in_column, ignore_flag_column=ignore_flag_column)

@staticmethod
def _get_model_type(
Expand Down
61 changes: 61 additions & 0 deletions tests/test_transforms/test_outliers/test_outliers_transform.py
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from etna.models import ProphetModel
from etna.models import SARIMAXModel
from etna.transforms import DensityOutliersTransform
from etna.transforms import HolidayTransform
from etna.transforms import MedianOutliersTransform
from etna.transforms import PredictionIntervalOutliersTransform
from tests.test_transforms.utils import assert_sampling_is_valid
Expand Down Expand Up @@ -42,6 +43,34 @@ def outliers_solid_tsds():
return ts


@pytest.fixture()
def outliers_solid_tsds_with_holidays():
"""Create TSDataset with outliers with holidays"""
timestamp = pd.date_range("2021-01-01", end="2021-02-20", freq="D")
target1 = [np.sin(i) for i in range(len(timestamp))]
target1[10] += 10

target2 = [np.sin(i) for i in range(len(timestamp))]
target2[8] += 8
target2[15] = 2
target2[26] -= 12

df1 = pd.DataFrame({"timestamp": timestamp, "target": target1, "segment": "1"})
df2 = pd.DataFrame({"timestamp": timestamp, "target": target2, "segment": "2"})
df = pd.concat([df1, df2], ignore_index=True)
df_exog = df.copy()
df_exog.columns = ["timestamp", "regressor_1", "segment"]
ts = TSDataset(
df=TSDataset.to_dataset(df).iloc[:-10],
df_exog=TSDataset.to_dataset(df_exog),
freq="D",
known_future="all",
)
holiday_transform = HolidayTransform(iso_code="RUS", mode="binary", out_column="is_holiday")
ts = holiday_transform.fit_transform(ts)
return ts


@pytest.mark.parametrize("attribute_name,value_type", (("outliers_timestamps", list), ("original_values", pd.Series)))
def test_density_outliers_deprecated_store_attributes(outliers_solid_tsds, attribute_name, value_type):
transform = DensityOutliersTransform(in_column="target")
Expand Down Expand Up @@ -255,3 +284,35 @@ def test_params_to_tune(transform, outliers_solid_tsds):
ts = outliers_solid_tsds
assert len(transform.params_to_tune()) > 0
assert_sampling_is_valid(transform=transform, ts=ts)


@pytest.mark.parametrize(
"transform",
(
MedianOutliersTransform(in_column="target", ignore_flag_column="is_holiday"),
DensityOutliersTransform(in_column="target", ignore_flag_column="is_holiday"),
# PredictionIntervalOutliersTransform(in_column="target", model="sarimax", ignore_flag_column="is_holiday"),
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved
),
)
def test_correct_ignore_flag(transform, outliers_solid_tsds_with_holidays):
ts = outliers_solid_tsds_with_holidays
assert len(transform.params_to_tune()) > 0
transform.fit(ts)
ts_output = transform.transform(ts)
assert ts_output["2021-01-06":"2021-01-06", "1", "target"][0] != np.nan
brsnw250 marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(
"transform",
(
MedianOutliersTransform(in_column="target", ignore_flag_column="is_holiday"),
DensityOutliersTransform(in_column="target", ignore_flag_column="is_holiday"),
PredictionIntervalOutliersTransform(in_column="target", model="sarimax", ignore_flag_column="is_holiday"),
),
)
def test_incorrect_formats(transform, outliers_solid_tsds):
ts = outliers_solid_tsds
assert len(transform.params_to_tune()) > 0
with pytest.raises(ValueError):
transform.fit(ts)
_ = transform.transform(ts)
Loading