Skip to content

Commit

Permalink
Use bounded typevar to avoid casting .duplicate()
Browse files Browse the repository at this point in the history
  • Loading branch information
qwhelan committed May 14, 2024
1 parent 24695e5 commit 3d144f5
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 21 deletions.
8 changes: 5 additions & 3 deletions src/rp2/abstract_entry_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from copy import copy
from datetime import date, datetime
from typing import Dict, Iterable, Iterator, List, Optional, Set
from typing import Dict, Iterable, Iterator, List, Optional, Set, TypeVar

from rp2.abstract_entry import AbstractEntry
from rp2.configuration import MAX_DATE, MIN_DATE, Configuration
Expand All @@ -24,6 +24,8 @@
from rp2.out_transaction import OutTransaction
from rp2.rp2_error import RP2TypeError, RP2ValueError

T = TypeVar("T", bound="AbstractEntrySet")


class AbstractEntrySet(Iterable[AbstractEntry]):
def __init__(
Expand All @@ -49,9 +51,9 @@ def __init__(
self._entry_to_parent: Dict[AbstractEntry, Optional[AbstractEntry]] = {}
self.__is_sorted: bool = False

def duplicate(self, from_date: date = MIN_DATE, to_date: date = MAX_DATE) -> "AbstractEntrySet":
def duplicate(self: T, from_date: date = MIN_DATE, to_date: date = MAX_DATE) -> T:
# pylint: disable=protected-access
result: AbstractEntrySet = copy(self)
result: T = copy(self)
result._from_date = from_date
result._to_date = to_date
# Force sort to recompute fields that are affected by time filter
Expand Down
13 changes: 7 additions & 6 deletions src/rp2/balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
from dataclasses import dataclass
from datetime import date, datetime
from decimal import Decimal
from typing import Callable, Dict, List, Optional, Union, cast
from typing import Callable, Dict, List, Optional

from prezzemolo.utility import to_string

from rp2.abstract_entry import AbstractEntry
from rp2.configuration import Configuration
from rp2.in_transaction import InTransaction
from rp2.input_data import InputData
Expand Down Expand Up @@ -118,11 +119,11 @@ def __init__(
from_account: Account
to_account: Account

in_transactions: List[InTransaction] = cast(List[InTransaction], list(self.__input_data.unfiltered_in_transaction_set))
intra_transactions: List[IntraTransaction] = cast(List[IntraTransaction], list(self.__input_data.unfiltered_intra_transaction_set))
out_transactions: List[OutTransaction] = cast(List[OutTransaction], list(self.__input_data.unfiltered_out_transaction_set))
in_transactions = list(self.__input_data.unfiltered_in_transaction_set)
intra_transactions = list(self.__input_data.unfiltered_intra_transaction_set)
out_transactions = list(self.__input_data.unfiltered_out_transaction_set)

transactions: List[Union[InTransaction, IntraTransaction, OutTransaction]] = in_transactions + intra_transactions + out_transactions
transactions = in_transactions + intra_transactions + out_transactions
transactions = sorted(
transactions,
key=_transaction_time_sort_key,
Expand Down Expand Up @@ -244,5 +245,5 @@ def _balance_sort_key(balance: Balance) -> str:
return f"{balance.exchange}_{balance.holder}"


def _transaction_time_sort_key(transaction: Union[InTransaction, IntraTransaction, OutTransaction]) -> datetime:
def _transaction_time_sort_key(transaction: AbstractEntry) -> datetime:
return transaction.timestamp
4 changes: 2 additions & 2 deletions src/rp2/computed_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ def __init__(
TransactionSet.type_check("taxable_event_set", unfiltered_taxable_event_set, EntrySetType.MIXED, asset, True)
GainLossSet.type_check("gain_loss_set", unfiltered_gain_loss_set)

self.__filtered_taxable_event_set: TransactionSet = cast(TransactionSet, unfiltered_taxable_event_set.duplicate(from_date=from_date, to_date=to_date))
self.__filtered_gain_loss_set: GainLossSet = cast(GainLossSet, unfiltered_gain_loss_set.duplicate(from_date=from_date, to_date=to_date))
self.__filtered_taxable_event_set: TransactionSet = unfiltered_taxable_event_set.duplicate(from_date=from_date, to_date=to_date)
self.__filtered_gain_loss_set: GainLossSet = unfiltered_gain_loss_set.duplicate(from_date=from_date, to_date=to_date)

yearly_gain_loss_list: List[YearlyGainLoss] = self._create_yearly_gain_loss_list(unfiltered_gain_loss_set, to_date)
LOGGER.debug("%s: Created yearly gain-loss list", input_data.asset)
Expand Down
13 changes: 3 additions & 10 deletions src/rp2/input_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

from datetime import date
from typing import cast

from rp2.configuration import MAX_DATE, MIN_DATE, Configuration
from rp2.entry_types import EntrySetType
Expand Down Expand Up @@ -53,15 +52,9 @@ def __init__(
if not isinstance(to_date, date):
raise RP2TypeError("Parameter 'to_date' is not of type date")

self.__filtered_in_transaction_set: TransactionSet = cast(
TransactionSet, self.__unfiltered_in_transaction_set.duplicate(from_date=from_date, to_date=to_date)
)
self.__filtered_out_transaction_set: TransactionSet = cast(
TransactionSet, self.__unfiltered_out_transaction_set.duplicate(from_date=from_date, to_date=to_date)
)
self.__filtered_intra_transaction_set: TransactionSet = cast(
TransactionSet, self.__unfiltered_intra_transaction_set.duplicate(from_date=from_date, to_date=to_date)
)
self.__filtered_in_transaction_set: TransactionSet = self.__unfiltered_in_transaction_set.duplicate(from_date=from_date, to_date=to_date)
self.__filtered_out_transaction_set: TransactionSet = self.__unfiltered_out_transaction_set.duplicate(from_date=from_date, to_date=to_date)
self.__filtered_intra_transaction_set: TransactionSet = self.__unfiltered_intra_transaction_set.duplicate(from_date=from_date, to_date=to_date)

@property
def asset(self) -> str:
Expand Down

0 comments on commit 3d144f5

Please sign in to comment.