diff --git a/src/rp2/abstract_entry_set.py b/src/rp2/abstract_entry_set.py index b629bc0..5c796ce 100644 --- a/src/rp2/abstract_entry_set.py +++ b/src/rp2/abstract_entry_set.py @@ -14,7 +14,7 @@ from copy import copy from datetime import date, datetime -from typing import Dict, List, Optional, Set +from typing import Dict, Iterable, Iterator, List, Optional, Set, TypeVar from rp2.abstract_entry import AbstractEntry from rp2.configuration import MAX_DATE, MIN_DATE, Configuration @@ -24,8 +24,10 @@ from rp2.out_transaction import OutTransaction from rp2.rp2_error import RP2TypeError, RP2ValueError +AbstractEntrySetSubclass = TypeVar("AbstractEntrySetSubclass", bound="AbstractEntrySet") -class AbstractEntrySet: + +class AbstractEntrySet(Iterable[AbstractEntry]): def __init__( self, configuration: Configuration, @@ -49,9 +51,9 @@ def __init__( self._entry_to_parent: Dict[AbstractEntry, Optional[AbstractEntry]] = {} self.__is_sorted: bool = False - def duplicate(self, from_date: date = MIN_DATE, to_date: date = MAX_DATE) -> "AbstractEntrySet": + def duplicate(self: AbstractEntrySetSubclass, from_date: date = MIN_DATE, to_date: date = MAX_DATE) -> AbstractEntrySetSubclass: # pylint: disable=protected-access - result: AbstractEntrySet = copy(self) + result: AbstractEntrySetSubclass = copy(self) result._from_date = from_date result._to_date = to_date # Force sort to recompute fields that are affected by time filter @@ -167,7 +169,7 @@ def __iter__(self) -> "EntrySetIterator": return EntrySetIterator(self) -class EntrySetIterator: +class EntrySetIterator(Iterator[AbstractEntry]): def __init__(self, entry_set: AbstractEntrySet) -> None: self.__entry_set: AbstractEntrySet = entry_set self.__entry_set_size: int = self.__entry_set.count diff --git a/src/rp2/balance.py b/src/rp2/balance.py index d1fa071..0d286ff 100644 --- a/src/rp2/balance.py +++ b/src/rp2/balance.py @@ -13,12 +13,13 @@ # limitations under the License. from dataclasses import dataclass -from datetime import date +from datetime import date, datetime from decimal import Decimal -from typing import Callable, Dict, List, Optional, cast +from typing import Callable, Dict, List, Optional from prezzemolo.utility import to_string +from rp2.abstract_entry import AbstractEntry from rp2.configuration import Configuration from rp2.in_transaction import InTransaction from rp2.input_data import InputData @@ -28,7 +29,6 @@ from rp2.rp2_decimal import ZERO, RP2Decimal from rp2.rp2_error import RP2TypeError, RP2ValueError - CRYPTO_BALANCE_DECIMAL_MASK: Decimal = Decimal("1." + "0" * 10) @@ -119,53 +119,60 @@ def __init__( from_account: Account to_account: Account - # Balances for bought and earned currency - for transaction in self.__input_data.unfiltered_in_transaction_set: - if transaction.timestamp.date() > to_date: - break - in_transaction: InTransaction = cast(InTransaction, transaction) - to_account = Account(in_transaction.exchange, in_transaction.holder) - acquired_balances[to_account] = acquired_balances.get(to_account, ZERO) + in_transaction.crypto_in - final_balances[to_account] = final_balances.get(to_account, ZERO) + in_transaction.crypto_in + in_transactions = list(self.__input_data.unfiltered_in_transaction_set) + intra_transactions = list(self.__input_data.unfiltered_intra_transaction_set) + out_transactions = list(self.__input_data.unfiltered_out_transaction_set) - # Balances for currency that is moved across accounts - for transaction in self.__input_data.unfiltered_intra_transaction_set: - if transaction.timestamp.date() > to_date: - break - intra_transaction: IntraTransaction = cast(IntraTransaction, transaction) - from_account = Account(intra_transaction.from_exchange, intra_transaction.from_holder) - to_account = Account(intra_transaction.to_exchange, intra_transaction.to_holder) - sent_balances[from_account] = sent_balances.get(from_account, ZERO) + intra_transaction.crypto_sent - received_balances[to_account] = received_balances.get(to_account, ZERO) + intra_transaction.crypto_received - final_balances[from_account] = final_balances.get(from_account, ZERO) - intra_transaction.crypto_sent - final_balances[to_account] = final_balances.get(to_account, ZERO) + intra_transaction.crypto_received - if ( - not RP2Decimal.is_equal_within_precision(final_balances[from_account], ZERO, CRYPTO_BALANCE_DECIMAL_MASK) - and final_balances[from_account] < ZERO - and not configuration.allow_negative_balances - ): - raise RP2ValueError( - f'{intra_transaction.asset} balance of account "{from_account.exchange}" (holder "{from_account.holder}") went negative ' - f'({final_balances[from_account]}) on the following transaction: {intra_transaction}' - ) - - # Balances for sold and gifted currency - for transaction in self.__input_data.unfiltered_out_transaction_set: + transactions = in_transactions + intra_transactions + out_transactions + transactions = sorted( + transactions, + key=_transaction_time_sort_key, + ) + + # Balances for bought and earned currency + for transaction in transactions: if transaction.timestamp.date() > to_date: break - out_transaction: OutTransaction = cast(OutTransaction, transaction) - from_account = Account(out_transaction.exchange, out_transaction.holder) - sent_balances[from_account] = sent_balances.get(from_account, ZERO) + out_transaction.crypto_out_no_fee + out_transaction.crypto_fee - final_balances[from_account] = final_balances.get(from_account, ZERO) - out_transaction.crypto_out_no_fee - out_transaction.crypto_fee - if ( - not RP2Decimal.is_equal_within_precision(final_balances[from_account], ZERO, CRYPTO_BALANCE_DECIMAL_MASK) - and final_balances[from_account] < ZERO - and not configuration.allow_negative_balances - ): - raise RP2ValueError( - f'{out_transaction.asset} balance of account "{from_account.exchange}" (holder "{from_account.holder}") went negative ' - f'({final_balances[from_account]}) on the following transaction: {out_transaction}' - ) + if isinstance(transaction, InTransaction): + in_transaction: InTransaction = transaction + to_account = Account(in_transaction.exchange, in_transaction.holder) + acquired_balances[to_account] = acquired_balances.get(to_account, ZERO) + in_transaction.crypto_in + final_balances[to_account] = final_balances.get(to_account, ZERO) + in_transaction.crypto_in + + # Balances for currency that is moved across accounts + if isinstance(transaction, IntraTransaction): + intra_transaction: IntraTransaction = transaction + from_account = Account(intra_transaction.from_exchange, intra_transaction.from_holder) + to_account = Account(intra_transaction.to_exchange, intra_transaction.to_holder) + sent_balances[from_account] = sent_balances.get(from_account, ZERO) + intra_transaction.crypto_sent + received_balances[to_account] = received_balances.get(to_account, ZERO) + intra_transaction.crypto_received + final_balances[from_account] = final_balances.get(from_account, ZERO) - intra_transaction.crypto_sent + final_balances[to_account] = final_balances.get(to_account, ZERO) + intra_transaction.crypto_received + if ( + not RP2Decimal.is_equal_within_precision(final_balances[from_account], ZERO, CRYPTO_BALANCE_DECIMAL_MASK) + and final_balances[from_account] < ZERO + and not configuration.allow_negative_balances + ): + raise RP2ValueError( + f'{intra_transaction.asset} balance of account "{from_account.exchange}" (holder "{from_account.holder}") went negative ' + f"({final_balances[from_account]}) on the following transaction: {intra_transaction}" + ) + + # Balances for sold and gifted currency + if isinstance(transaction, OutTransaction): + out_transaction: OutTransaction = transaction + from_account = Account(out_transaction.exchange, out_transaction.holder) + sent_balances[from_account] = sent_balances.get(from_account, ZERO) + out_transaction.crypto_out_no_fee + out_transaction.crypto_fee + final_balances[from_account] = final_balances.get(from_account, ZERO) - out_transaction.crypto_out_no_fee - out_transaction.crypto_fee + if ( + not RP2Decimal.is_equal_within_precision(final_balances[from_account], ZERO, CRYPTO_BALANCE_DECIMAL_MASK) + and final_balances[from_account] < ZERO + and not configuration.allow_negative_balances + ): + raise RP2ValueError( + f'{out_transaction.asset} balance of account "{from_account.exchange}" (holder "{from_account.holder}") went negative ' + f"({final_balances[from_account]}) on the following transaction: {out_transaction}" + ) for account, final_balance in final_balances.items(): balance = Balance( @@ -236,3 +243,7 @@ def __next__(self) -> Balance: def _balance_sort_key(balance: Balance) -> str: return f"{balance.exchange}_{balance.holder}" + + +def _transaction_time_sort_key(transaction: AbstractEntry) -> datetime: + return transaction.timestamp diff --git a/src/rp2/computed_data.py b/src/rp2/computed_data.py index 47f0ab7..e6a929d 100644 --- a/src/rp2/computed_data.py +++ b/src/rp2/computed_data.py @@ -208,8 +208,8 @@ def __init__( TransactionSet.type_check("taxable_event_set", unfiltered_taxable_event_set, EntrySetType.MIXED, asset, True) GainLossSet.type_check("gain_loss_set", unfiltered_gain_loss_set) - self.__filtered_taxable_event_set: TransactionSet = cast(TransactionSet, unfiltered_taxable_event_set.duplicate(from_date=from_date, to_date=to_date)) - self.__filtered_gain_loss_set: GainLossSet = cast(GainLossSet, unfiltered_gain_loss_set.duplicate(from_date=from_date, to_date=to_date)) + self.__filtered_taxable_event_set: TransactionSet = unfiltered_taxable_event_set.duplicate(from_date=from_date, to_date=to_date) + self.__filtered_gain_loss_set: GainLossSet = unfiltered_gain_loss_set.duplicate(from_date=from_date, to_date=to_date) yearly_gain_loss_list: List[YearlyGainLoss] = self._create_yearly_gain_loss_list(unfiltered_gain_loss_set, to_date) LOGGER.debug("%s: Created yearly gain-loss list", input_data.asset) diff --git a/src/rp2/input_data.py b/src/rp2/input_data.py index 8e18c5a..59d881e 100644 --- a/src/rp2/input_data.py +++ b/src/rp2/input_data.py @@ -13,7 +13,6 @@ # limitations under the License. from datetime import date -from typing import cast from rp2.configuration import MAX_DATE, MIN_DATE, Configuration from rp2.entry_types import EntrySetType @@ -53,15 +52,9 @@ def __init__( if not isinstance(to_date, date): raise RP2TypeError("Parameter 'to_date' is not of type date") - self.__filtered_in_transaction_set: TransactionSet = cast( - TransactionSet, self.__unfiltered_in_transaction_set.duplicate(from_date=from_date, to_date=to_date) - ) - self.__filtered_out_transaction_set: TransactionSet = cast( - TransactionSet, self.__unfiltered_out_transaction_set.duplicate(from_date=from_date, to_date=to_date) - ) - self.__filtered_intra_transaction_set: TransactionSet = cast( - TransactionSet, self.__unfiltered_intra_transaction_set.duplicate(from_date=from_date, to_date=to_date) - ) + self.__filtered_in_transaction_set: TransactionSet = self.__unfiltered_in_transaction_set.duplicate(from_date=from_date, to_date=to_date) + self.__filtered_out_transaction_set: TransactionSet = self.__unfiltered_out_transaction_set.duplicate(from_date=from_date, to_date=to_date) + self.__filtered_intra_transaction_set: TransactionSet = self.__unfiltered_intra_transaction_set.duplicate(from_date=from_date, to_date=to_date) @property def asset(self) -> str: diff --git a/src/rp2/plugin/report/jp/tax_report_jp.py b/src/rp2/plugin/report/jp/tax_report_jp.py index 538c414..2eca443 100644 --- a/src/rp2/plugin/report/jp/tax_report_jp.py +++ b/src/rp2/plugin/report/jp/tax_report_jp.py @@ -17,10 +17,9 @@ from enum import Enum from itertools import chain from pathlib import Path -from typing import Any, Dict, List, NamedTuple, Optional, Set, cast +from typing import Any, Dict, List, NamedTuple, Optional, Set from rp2.abstract_country import AbstractCountry -from rp2.abstract_entry import AbstractEntry from rp2.abstract_transaction import AbstractTransaction from rp2.computed_data import ComputedData from rp2.configuration import MAX_DATE, MIN_DATE @@ -169,15 +168,14 @@ def __generate_asset(self, computed_data: ComputedData, output_file: Any) -> Non in_transaction_set: TransactionSet = computed_data.in_transaction_set out_transaction_set: TransactionSet = computed_data.out_transaction_set intra_transaction_set: TransactionSet = computed_data.intra_transaction_set - entry: AbstractEntry + entry: AbstractTransaction year: int years_2_transaction_sets: Dict[int, List[AbstractTransaction]] = {} previous_year_row_offset: int = 0 # Sort all in and out transactions by year, the fee from intra transactions must be reported for entry in chain(in_transaction_set, out_transaction_set, intra_transaction_set): # type: ignore - transaction: AbstractTransaction = cast(AbstractTransaction, entry) - years_2_transaction_sets.setdefault(transaction.timestamp.year, []).append(entry) + years_2_transaction_sets.setdefault(entry.timestamp.year, []).append(entry) for year, transaction_set in years_2_transaction_sets.items(): # Sort the transactions by timestamp and generate sheet by year diff --git a/tests/test_balance.py b/tests/test_balance.py new file mode 100644 index 0000000..a31f186 --- /dev/null +++ b/tests/test_balance.py @@ -0,0 +1,168 @@ +# Copyright 2024 qwhelan +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from datetime import date + +from rp2.balance import BalanceSet +from rp2.configuration import Configuration +from rp2.in_transaction import InTransaction +from rp2.input_data import InputData +from rp2.out_transaction import OutTransaction +from rp2.plugin.country.us import US +from rp2.rp2_decimal import RP2Decimal +from rp2.rp2_error import RP2ValueError +from rp2.transaction_set import TransactionSet + + +class TestBalanceSet(unittest.TestCase): + _configuration: Configuration + + @classmethod + def setUpClass(cls) -> None: + TestBalanceSet._configuration = Configuration("./config/test_data.ini", US()) + + def setUp(self) -> None: + self.maxDiff = None # pylint: disable=invalid-name + + def test_easy_negative_case(self) -> None: + """ + Check that an exception is raised in the case where sum(OUT) > sum(IN) + """ + asset = "B1" + end_date = date(2024, 1, 1) + in_transaction_set: TransactionSet = TransactionSet(self._configuration, "IN", asset) + out_transaction_set: TransactionSet = TransactionSet(self._configuration, "OUT", asset) + intra_transaction_set: TransactionSet = TransactionSet(self._configuration, "INTRA", asset) + + transaction1: InTransaction = InTransaction( + self._configuration, + "1/8/2021 8:42:43.883 -04:00", + asset, + "Coinbase", + "Alice", + "BUY", + RP2Decimal("1000"), + RP2Decimal("3.0002"), + fiat_fee=RP2Decimal("20"), + fiat_in_no_fee=RP2Decimal("3000.2"), + fiat_in_with_fee=RP2Decimal("3020.2"), + internal_id=30, + ) + in_transaction_set.add_entry(transaction1) + + transaction2: OutTransaction = OutTransaction( + self._configuration, + "1/9/2021 8:42:43.883 -04:00", + asset, + "Coinbase", + "Alice", + "SELL", + RP2Decimal("1000"), + RP2Decimal("4.0000"), + crypto_fee=RP2Decimal("0.0002"), + fiat_out_no_fee=RP2Decimal("4000.0"), + internal_id=31, + ) + out_transaction_set.add_entry(transaction2) + + input_data = InputData(asset, in_transaction_set, out_transaction_set, intra_transaction_set) + + with self.assertRaisesRegex( + RP2ValueError, r'B1 balance of account "Coinbase" \(holder "Alice"\) went negative \(-1.0000\) on the following transaction: .*' + ): + BalanceSet(self._configuration, input_data, end_date) + + def test_hard_negative_case(self) -> None: + """ + Check that an exception is raised in the case where sum(OUT) > sum(IN) only briefly + """ + asset = "B1" + end_date = date(2024, 1, 1) + in_transaction_set: TransactionSet = TransactionSet(self._configuration, "IN", asset) + out_transaction_set: TransactionSet = TransactionSet(self._configuration, "OUT", asset) + intra_transaction_set: TransactionSet = TransactionSet(self._configuration, "INTRA", asset) + + transaction1: InTransaction = InTransaction( + self._configuration, + "1/8/2021 8:42:43.883 -04:00", + asset, + "Coinbase", + "Alice", + "BUY", + RP2Decimal("1000"), + RP2Decimal("3.0002"), + fiat_fee=RP2Decimal("20"), + fiat_in_no_fee=RP2Decimal("3000.2"), + fiat_in_with_fee=RP2Decimal("3020.2"), + internal_id=30, + ) + in_transaction_set.add_entry(transaction1) + + transaction2: OutTransaction = OutTransaction( + self._configuration, + "1/9/2021 8:42:43.883 -04:00", + asset, + "Coinbase", + "Alice", + "SELL", + RP2Decimal("1000"), + RP2Decimal("4.0000"), + crypto_fee=RP2Decimal("0.0002"), + fiat_out_no_fee=RP2Decimal("6000.0"), + internal_id=31, + ) + out_transaction_set.add_entry(transaction2) + + transaction3: InTransaction = InTransaction( + self._configuration, + "1/10/2021 8:42:43.883 -04:00", + asset, + "Coinbase", + "Alice", + "BUY", + RP2Decimal("1000"), + RP2Decimal("3.0002"), + fiat_fee=RP2Decimal("20"), + fiat_in_no_fee=RP2Decimal("3000.2"), + fiat_in_with_fee=RP2Decimal("3020.2"), + internal_id=32, + ) + in_transaction_set.add_entry(transaction3) + + transaction4: OutTransaction = OutTransaction( + self._configuration, + "1/11/2021 8:42:43.883 -04:00", + asset, + "Coinbase", + "Alice", + "SELL", + RP2Decimal("1000"), + RP2Decimal("2.0000"), + crypto_fee=RP2Decimal("0.0002"), + fiat_out_no_fee=RP2Decimal("2000.0"), + internal_id=33, + ) + out_transaction_set.add_entry(transaction4) + + input_data = InputData(asset, in_transaction_set, out_transaction_set, intra_transaction_set) + + with self.assertRaisesRegex( + RP2ValueError, r'B1 balance of account "Coinbase" \(holder "Alice"\) went negative \(-1.0000\) on the following transaction: .*' + ): + BalanceSet(self._configuration, input_data, end_date) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_localized_output.py b/tests/test_localized_output.py index 4c780be..c6ec42d 100644 --- a/tests/test_localized_output.py +++ b/tests/test_localized_output.py @@ -34,7 +34,15 @@ def setUpClass(cls) -> None: # To test localization plumbing, we generate Japanese taxes for test_data in Kalaallisut language. Note that the localization # file (locales/kl/LC_MESSAGES/messages.po) doesn't contain real Kalaallisut translations, but only placeholder strings starting # with "__test_": this is good enough to test localization plumbing (and it would work in the same way with a real translation). - cls._generate(cls.output_dir, test_name="test_data", config="test_data", method="fifo", generation_language="kl", country="jp") + cls._generate( + cls.output_dir, + test_name="test_data", + config="test_data", + method="fifo", + generation_language="kl", + country="jp", + allow_negative_balances=True, + ) def setUp(self) -> None: self.maxDiff = None # pylint: disable=invalid-name diff --git a/tests/test_ods_output_diff.py b/tests/test_ods_output_diff.py index af6bd6a..2ea13a5 100644 --- a/tests/test_ods_output_diff.py +++ b/tests/test_ods_output_diff.py @@ -36,9 +36,9 @@ def setUpClass(cls) -> None: AbstractTestODSOutputDiff._generate( cls.output_dir, test_name="crypto_example", config="crypto_example", method=method, allow_negative_balances=True ) - AbstractTestODSOutputDiff._generate(cls.output_dir, test_name="test_data", config="test_data", method=method) + AbstractTestODSOutputDiff._generate(cls.output_dir, test_name="test_data", config="test_data", method=method, allow_negative_balances=True) AbstractTestODSOutputDiff._generate(cls.output_dir, test_name="test_data2", config="test_data", method=method, allow_negative_balances=True) - AbstractTestODSOutputDiff._generate(cls.output_dir, test_name="test_data3", config="test_data", method=method) + AbstractTestODSOutputDiff._generate(cls.output_dir, test_name="test_data3", config="test_data", method=method, allow_negative_balances=True) AbstractTestODSOutputDiff._generate(cls.output_dir, test_name="test_data4", config="test_data4", method=method) AbstractTestODSOutputDiff._generate(cls.output_dir, test_name="test_hifo", config="test_data", method=method, allow_negative_balances=True) AbstractTestODSOutputDiff._generate(cls.output_dir, test_name="test_hifo2", config="test_data", method=method, allow_negative_balances=True) diff --git a/tests/test_ods_output_diff_es.py b/tests/test_ods_output_diff_es.py index 4ec69c1..5d3ba49 100644 --- a/tests/test_ods_output_diff_es.py +++ b/tests/test_ods_output_diff_es.py @@ -39,11 +39,15 @@ def setUpClass(cls) -> None: generation_language="es", allow_negative_balances=True, ) - AbstractTestODSOutputDiff._generate(cls.output_dir, test_name="test_data", config="test_data", method="fifo", country="es", generation_language="es") + AbstractTestODSOutputDiff._generate( + cls.output_dir, test_name="test_data", config="test_data", method="fifo", country="es", generation_language="es", allow_negative_balances=True + ) AbstractTestODSOutputDiff._generate( cls.output_dir, test_name="test_data2", config="test_data", method="fifo", country="es", generation_language="es", allow_negative_balances=True ) - AbstractTestODSOutputDiff._generate(cls.output_dir, test_name="test_data3", config="test_data", method="fifo", country="es", generation_language="es") + AbstractTestODSOutputDiff._generate( + cls.output_dir, test_name="test_data3", config="test_data", method="fifo", country="es", generation_language="es", allow_negative_balances=True + ) AbstractTestODSOutputDiff._generate(cls.output_dir, test_name="test_data4", config="test_data4", method="fifo", country="es", generation_language="es") AbstractTestODSOutputDiff._generate( cls.output_dir, diff --git a/tests/test_ods_output_diff_jp.py b/tests/test_ods_output_diff_jp.py index 8f2e3c0..504c700 100755 --- a/tests/test_ods_output_diff_jp.py +++ b/tests/test_ods_output_diff_jp.py @@ -40,11 +40,15 @@ def setUpClass(cls) -> None: generation_language="en", allow_negative_balances=True, ) - AbstractTestODSOutputDiff._generate(cls.output_dir, test_name="test_data", config="test_data", method="fifo", country="jp", generation_language="en") + AbstractTestODSOutputDiff._generate( + cls.output_dir, test_name="test_data", config="test_data", method="fifo", country="jp", generation_language="en", allow_negative_balances=True + ) AbstractTestODSOutputDiff._generate( cls.output_dir, test_name="test_data2", config="test_data", method="fifo", country="jp", generation_language="en", allow_negative_balances=True ) - AbstractTestODSOutputDiff._generate(cls.output_dir, test_name="test_data3", config="test_data", method="fifo", country="jp", generation_language="en") + AbstractTestODSOutputDiff._generate( + cls.output_dir, test_name="test_data3", config="test_data", method="fifo", country="jp", generation_language="en", allow_negative_balances=True + ) AbstractTestODSOutputDiff._generate(cls.output_dir, test_name="test_data4", config="test_data4", method="fifo", country="jp", generation_language="en") AbstractTestODSOutputDiff._generate( cls.output_dir, diff --git a/tests/test_tax_engine.py b/tests/test_tax_engine.py index 6ac45a7..03debeb 100644 --- a/tests/test_tax_engine.py +++ b/tests/test_tax_engine.py @@ -33,12 +33,14 @@ class TestTaxEngine(unittest.TestCase): _good_input_configuration: Configuration + _good_input_allow_negative_balance_configuration: Configuration _bad_input_configuration: Configuration _accounting_engine: AccountingEngine @classmethod def setUpClass(cls) -> None: TestTaxEngine._good_input_configuration = Configuration("./config/test_data.ini", US()) + TestTaxEngine._good_input_allow_negative_balance_configuration = Configuration("./config/test_data.ini", US(), allow_negative_balances=True) TestTaxEngine._bad_input_configuration = Configuration("./config/test_bad_data.ini", US()) years_2_methods = AVLTree[int, AbstractAccountingMethod]() years_2_methods.insert_node(MIN_DATE.year, AccountingMethod()) @@ -49,19 +51,24 @@ def setUp(self) -> None: def test_good_input(self) -> None: self._verify_good_output("B1") - self._verify_good_output("B2") - self._verify_good_output("B3") - self._verify_good_output("B4") + self._verify_good_output("B2", allow_negative_balances=True) + self._verify_good_output("B3", allow_negative_balances=True) + self._verify_good_output("B4", allow_negative_balances=True) + + def _verify_good_output(self, sheet_name: str, allow_negative_balances: bool = False) -> None: + if allow_negative_balances: + config = self._good_input_allow_negative_balance_configuration + else: + config = self._good_input_configuration - def _verify_good_output(self, sheet_name: str) -> None: asset = sheet_name # Parser is tested separately (on same input) in test_input_parser.py - input_file_handle: object = open_ods(self._good_input_configuration, "./input/test_data.ods") - input_data: InputData = parse_ods(self._good_input_configuration, asset, input_file_handle) + input_file_handle: object = open_ods(config, "./input/test_data.ods") + input_data: InputData = parse_ods(config, asset, input_file_handle) # In table is always present - computed_data: ComputedData = compute_tax(self._good_input_configuration, self._accounting_engine, input_data) + computed_data: ComputedData = compute_tax(config, self._accounting_engine, input_data) if asset in RP2_TEST_OUTPUT: self.assertEqual(str(computed_data.gain_loss_set), RP2_TEST_OUTPUT[asset])