Skip to content

Commit

Permalink
MAFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
femtotrader committed Jul 8, 2024
1 parent 92b2cec commit d5a17fc
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 43 deletions.
47 changes: 4 additions & 43 deletions talipp/indicators/ATR.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,10 @@
from typing import List, Any

from talipp.indicator_util import has_valid_values
from talipp.indicators.Indicator import Indicator, InputModifierType
from talipp.indicators.TrueRange import TrueRange
from talipp.input import SamplingPeriodType
from talipp.ohlcv import OHLCV


class WilderMA(Indicator):
"""Wilder's Moving Average.
Input type: `float`
Output type: `float`
Args:
period: Period.
input_values: List of input values.
input_indicator: Input indicator.
input_modifier: Input modifier.
input_sampling: Input sampling type.
"""

def __init__(self, period: int,
input_values: List[float] = None,
input_indicator: Indicator = None,
input_modifier: InputModifierType = None,
input_sampling: SamplingPeriodType = None):
super().__init__(input_modifier=input_modifier,
input_sampling=input_sampling)

self.period = period
self.k = 1.0 / self.period

self.initialize(input_values, input_indicator)

def _calculate_new_value(self) -> Any:
#if not has_valid_values(self.input_values, self.period):
if len(self.input_values) < self.period:
return None
elif has_valid_values(self.input_values, self.period, exact=True):
#elif len(self.input_values) == self.period:
return sum(self.input_values[-self.period:]) / self.period
else:
return float(self.k * self.input_values[-1] + (1.0 - self.k) * self.output_values[-1])
from talipp.ma import MAType, MAFactory


class ATR(Indicator):
Expand All @@ -59,12 +20,12 @@ class ATR(Indicator):
input_indicator: Input indicator.
input_modifier: Input modifier.
input_sampling: Input sampling type.
"""

"""
def __init__(self, period: int,
input_values: List[OHLCV] = None,
input_indicator: Indicator = None,
input_modifier: InputModifierType = None,
ma_type: MAType = MAType.WilderMA,
input_sampling: SamplingPeriodType = None):
super(ATR, self).__init__(input_modifier=input_modifier,
input_sampling=input_sampling)
Expand All @@ -74,7 +35,7 @@ def __init__(self, period: int,
self._tr = TrueRange()
self.add_sub_indicator(self._tr)

self._ma_tr = WilderMA(period, input_indicator=self._tr)
self._ma_tr = MAFactory.get_ma(ma_type, period, input_indicator=self._tr)

self.initialize(input_values, input_indicator)

Expand Down
44 changes: 44 additions & 0 deletions talipp/indicators/WilderMA.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import List, Any

from talipp.indicator_util import has_valid_values
from talipp.indicators.Indicator import Indicator, InputModifierType
from talipp.input import SamplingPeriodType


class WilderMA(Indicator):
"""Wilder's Moving Average.
Input type: `float`
Output type: `float`
Args:
period: Period.
input_values: List of input values.
input_indicator: Input indicator.
input_modifier: Input modifier.
input_sampling: Input sampling type.
"""

def __init__(self, period: int,
input_values: List[float] = None,
input_indicator: Indicator = None,
input_modifier: InputModifierType = None,
input_sampling: SamplingPeriodType = None):
super().__init__(input_modifier=input_modifier,
input_sampling=input_sampling)

self.period = period
self.k = 1.0 / self.period

self.initialize(input_values, input_indicator)

def _calculate_new_value(self) -> Any:
#if not has_valid_values(self.input_values, self.period):
if len(self.input_values) < self.period:
return None
elif has_valid_values(self.input_values, self.period, exact=True):
#elif len(self.input_values) == self.period:
return sum(self.input_values[-self.period:]) / self.period
else:
return float(self.k * self.input_values[-1] + (1.0 - self.k) * self.output_values[-1])
2 changes: 2 additions & 0 deletions talipp/indicators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from .VTX import VTX as VTX
from .VWAP import VWAP as VWAP
from .VWMA import VWMA as VWMA
from .WilderMA import WilderMA as WilderMA
from .WMA import WMA as WMA
from .ZigZag import ZigZag as ZigZag
from .ZLEMA import ZLEMA as ZLEMA
Expand Down Expand Up @@ -114,6 +115,7 @@
"VTX",
"VWAP",
"VWMA",
"WilderMA",
"WMA",
"ZigZag",
"ZLEMA"
Expand Down
6 changes: 6 additions & 0 deletions talipp/ma.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from talipp.indicators.TEMA import TEMA
from talipp.indicators.TRIX import TRIX
from talipp.indicators.VWMA import VWMA
from talipp.indicators.WilderMA import WilderMA
from talipp.indicators.WMA import WMA
from talipp.indicators.ZLEMA import ZLEMA

Expand Down Expand Up @@ -54,6 +55,9 @@ class MAType(Enum):
VWMA = auto()
"""[Volume Weighted Moving Average][talipp.indicators.VWMA]"""

WilderMA = auto()
"""[Wilder's Moving Average][talipp.indicators.WMA]"""

WMA = auto()
"""[Weighted Moving Average][talipp.indicators.WMA]"""

Expand Down Expand Up @@ -101,6 +105,8 @@ def get_ma(ma_type: MAType,
return HMA(period=period, input_values=input_values, input_indicator=input_indicator, input_modifier=input_modifier)
elif ma_type == MAType.VWMA:
return VWMA(period=period, input_values=input_values, input_indicator=input_indicator, input_modifier=input_modifier)
elif ma_type == MAType.WilderMA:
return WilderMA(period=period, input_values=input_values, input_indicator=input_indicator, input_modifier=input_modifier)
elif ma_type == MAType.WMA:
return WMA(period=period, input_values=input_values, input_indicator=input_indicator, input_modifier=input_modifier)
elif ma_type == MAType.T3:
Expand Down
32 changes: 32 additions & 0 deletions test/test_WilderMA.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import unittest

from talipp.indicators import WilderMA

from TalippTest import TalippTest


class Test(TalippTest):
def setUp(self) -> None:
self.input_values = list(TalippTest.CLOSE_TMPL)

def test_init(self):
ind = WilderMA(5, self.input_values)

print(ind)

self.assertAlmostEqual(ind[-3], 9.699400, places = 5)
self.assertAlmostEqual(ind[-2], 9.805521, places = 5)
self.assertAlmostEqual(ind[-1], 9.844417, places = 5)

def test_update(self):
self.assertIndicatorUpdate(WilderMA(5, self.input_values))

def test_delete(self):
self.assertIndicatorDelete(WilderMA(5, self.input_values))

def test_purge_oldest(self):
self.assertIndicatorPurgeOldest(WilderMA(5, self.input_values))


if __name__ == '__main__':
unittest.main()

0 comments on commit d5a17fc

Please sign in to comment.