Skip to content

Commit

Permalink
WilderMA
Browse files Browse the repository at this point in the history
  • Loading branch information
femtotrader committed Jul 8, 2024
1 parent 0d6392c commit 92b2cec
Showing 1 changed file with 41 additions and 15 deletions.
56 changes: 41 additions & 15 deletions talipp/indicators/ATR.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,45 @@
from talipp.ohlcv import OHLCV


class WilderMA(Indicator):
"""Wilder's Moving Average.
Input type: `float`
Output type: `float`
Args:
period: Period.
input_values: List of input values.
input_indicator: Input indicator.
input_modifier: Input modifier.
input_sampling: Input sampling type.
"""

def __init__(self, period: int,
input_values: List[float] = None,
input_indicator: Indicator = None,
input_modifier: InputModifierType = None,
input_sampling: SamplingPeriodType = None):
super().__init__(input_modifier=input_modifier,
input_sampling=input_sampling)

self.period = period
self.k = 1.0 / self.period

self.initialize(input_values, input_indicator)

def _calculate_new_value(self) -> Any:
#if not has_valid_values(self.input_values, self.period):
if len(self.input_values) < self.period:
return None
elif has_valid_values(self.input_values, self.period, exact=True):
#elif len(self.input_values) == self.period:
return sum(self.input_values[-self.period:]) / self.period
else:
return float(self.k * self.input_values[-1] + (1.0 - self.k) * self.output_values[-1])


class ATR(Indicator):
"""Average True Range
Expand Down Expand Up @@ -35,22 +74,9 @@ def __init__(self, period: int,
self._tr = TrueRange()
self.add_sub_indicator(self._tr)

self.tr = []

self.add_managed_sequence(self.tr)
self._ma_tr = WilderMA(period, input_indicator=self._tr)

self.initialize(input_values, input_indicator)

def _calculate_new_value(self) -> Any:
tr = self._tr.output_values[-1]
if has_valid_values(self.input_values, 1, exact=True):
self.tr.append(tr)
else:
self.tr.append(tr)

if len(self.input_values) < self.period:
return None
elif len(self.input_values) == self.period:
return sum(self.tr) / self.period
else:
return (self.output_values[-1] * (self.period - 1) + self.tr[-1]) / self.period
return self._ma_tr.output_values[-1]

0 comments on commit 92b2cec

Please sign in to comment.