diff --git a/src/rp2/abstract_entry_set.py b/src/rp2/abstract_entry_set.py index 0886c41..5c796ce 100644 --- a/src/rp2/abstract_entry_set.py +++ b/src/rp2/abstract_entry_set.py @@ -14,7 +14,7 @@ from copy import copy from datetime import date, datetime -from typing import Dict, Iterable, Iterator, List, Optional, Set +from typing import Dict, Iterable, Iterator, List, Optional, Set, TypeVar from rp2.abstract_entry import AbstractEntry from rp2.configuration import MAX_DATE, MIN_DATE, Configuration @@ -24,6 +24,8 @@ from rp2.out_transaction import OutTransaction from rp2.rp2_error import RP2TypeError, RP2ValueError +AbstractEntrySetSubclass = TypeVar("AbstractEntrySetSubclass", bound="AbstractEntrySet") + class AbstractEntrySet(Iterable[AbstractEntry]): def __init__( @@ -49,9 +51,9 @@ def __init__( self._entry_to_parent: Dict[AbstractEntry, Optional[AbstractEntry]] = {} self.__is_sorted: bool = False - def duplicate(self, from_date: date = MIN_DATE, to_date: date = MAX_DATE) -> "AbstractEntrySet": + def duplicate(self: AbstractEntrySetSubclass, from_date: date = MIN_DATE, to_date: date = MAX_DATE) -> AbstractEntrySetSubclass: # pylint: disable=protected-access - result: AbstractEntrySet = copy(self) + result: AbstractEntrySetSubclass = copy(self) result._from_date = from_date result._to_date = to_date # Force sort to recompute fields that are affected by time filter diff --git a/src/rp2/balance.py b/src/rp2/balance.py index f2dfb6f..0d286ff 100644 --- a/src/rp2/balance.py +++ b/src/rp2/balance.py @@ -15,10 +15,11 @@ from dataclasses import dataclass from datetime import date, datetime from decimal import Decimal -from typing import Callable, Dict, List, Optional, Union, cast +from typing import Callable, Dict, List, Optional from prezzemolo.utility import to_string +from rp2.abstract_entry import AbstractEntry from rp2.configuration import Configuration from rp2.in_transaction import InTransaction from rp2.input_data import InputData @@ -118,11 +119,11 @@ def __init__( from_account: Account to_account: Account - in_transactions: List[InTransaction] = cast(List[InTransaction], list(self.__input_data.unfiltered_in_transaction_set)) - intra_transactions: List[IntraTransaction] = cast(List[IntraTransaction], list(self.__input_data.unfiltered_intra_transaction_set)) - out_transactions: List[OutTransaction] = cast(List[OutTransaction], list(self.__input_data.unfiltered_out_transaction_set)) + in_transactions = list(self.__input_data.unfiltered_in_transaction_set) + intra_transactions = list(self.__input_data.unfiltered_intra_transaction_set) + out_transactions = list(self.__input_data.unfiltered_out_transaction_set) - transactions: List[Union[InTransaction, IntraTransaction, OutTransaction]] = in_transactions + intra_transactions + out_transactions + transactions = in_transactions + intra_transactions + out_transactions transactions = sorted( transactions, key=_transaction_time_sort_key, @@ -244,5 +245,5 @@ def _balance_sort_key(balance: Balance) -> str: return f"{balance.exchange}_{balance.holder}" -def _transaction_time_sort_key(transaction: Union[InTransaction, IntraTransaction, OutTransaction]) -> datetime: +def _transaction_time_sort_key(transaction: AbstractEntry) -> datetime: return transaction.timestamp diff --git a/src/rp2/computed_data.py b/src/rp2/computed_data.py index 47f0ab7..e6a929d 100644 --- a/src/rp2/computed_data.py +++ b/src/rp2/computed_data.py @@ -208,8 +208,8 @@ def __init__( TransactionSet.type_check("taxable_event_set", unfiltered_taxable_event_set, EntrySetType.MIXED, asset, True) GainLossSet.type_check("gain_loss_set", unfiltered_gain_loss_set) - self.__filtered_taxable_event_set: TransactionSet = cast(TransactionSet, unfiltered_taxable_event_set.duplicate(from_date=from_date, to_date=to_date)) - self.__filtered_gain_loss_set: GainLossSet = cast(GainLossSet, unfiltered_gain_loss_set.duplicate(from_date=from_date, to_date=to_date)) + self.__filtered_taxable_event_set: TransactionSet = unfiltered_taxable_event_set.duplicate(from_date=from_date, to_date=to_date) + self.__filtered_gain_loss_set: GainLossSet = unfiltered_gain_loss_set.duplicate(from_date=from_date, to_date=to_date) yearly_gain_loss_list: List[YearlyGainLoss] = self._create_yearly_gain_loss_list(unfiltered_gain_loss_set, to_date) LOGGER.debug("%s: Created yearly gain-loss list", input_data.asset) diff --git a/src/rp2/input_data.py b/src/rp2/input_data.py index 8e18c5a..59d881e 100644 --- a/src/rp2/input_data.py +++ b/src/rp2/input_data.py @@ -13,7 +13,6 @@ # limitations under the License. from datetime import date -from typing import cast from rp2.configuration import MAX_DATE, MIN_DATE, Configuration from rp2.entry_types import EntrySetType @@ -53,15 +52,9 @@ def __init__( if not isinstance(to_date, date): raise RP2TypeError("Parameter 'to_date' is not of type date") - self.__filtered_in_transaction_set: TransactionSet = cast( - TransactionSet, self.__unfiltered_in_transaction_set.duplicate(from_date=from_date, to_date=to_date) - ) - self.__filtered_out_transaction_set: TransactionSet = cast( - TransactionSet, self.__unfiltered_out_transaction_set.duplicate(from_date=from_date, to_date=to_date) - ) - self.__filtered_intra_transaction_set: TransactionSet = cast( - TransactionSet, self.__unfiltered_intra_transaction_set.duplicate(from_date=from_date, to_date=to_date) - ) + self.__filtered_in_transaction_set: TransactionSet = self.__unfiltered_in_transaction_set.duplicate(from_date=from_date, to_date=to_date) + self.__filtered_out_transaction_set: TransactionSet = self.__unfiltered_out_transaction_set.duplicate(from_date=from_date, to_date=to_date) + self.__filtered_intra_transaction_set: TransactionSet = self.__unfiltered_intra_transaction_set.duplicate(from_date=from_date, to_date=to_date) @property def asset(self) -> str: