Skip to content

Commit

Permalink
use more types; use Adapter for compat with old importers in hooks
Browse files: browse the repository at this point in the history
  • Loading branch information
yagebu committed Jan 3, 2025
1 parent 8828038 commit 7e222ad
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 52 deletions.
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ build-backend = "setuptools.build_meta"
[tool.black]
line-length = 79

[[tool.mypy.overrides]]
module = ["beancount.*"]
follow_untyped_imports = true

[[tool.mypy.overrides]]
module = ["beangulp.*"]
follow_untyped_imports = true

[tool.ruff]
target-version = "py38"
line-length = 79
Expand Down
29 changes: 25 additions & 4 deletions smart_importer/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,24 @@

import logging
from functools import wraps
from typing import Callable, Sequence

from beancount.core import data
from beangulp import Adapter, Importer, ImporterProtocol

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


class ImporterHook:
"""Interface for an importer hook."""

def __call__(self, importer, file, imported_entries, existing):
def __call__(
self,
importer: Importer,
file: str,
imported_entries: data.Directives,
existing: data.Directives,
) -> data.Directives:
"""Apply the hook and modify the imported entries.
Args:
Expand All @@ -25,20 +35,31 @@ def __call__(self, importer, file, imported_entries, existing):
raise NotImplementedError


def apply_hooks(importer, hooks):
def apply_hooks(
importer: Importer | ImporterProtocol,
hooks: Sequence[
Callable[
[Importer, str, data.Directives, data.Directives], data.Directives
]
],
) -> Importer:
"""Apply a list of importer hooks to an importer.
Args:
importer: An importer instance; one that is not a beangulp `Importer` is wrapped in an `Adapter` first.
hooks: A list of hooks, each a callable object.
"""

if not isinstance(importer, Importer):
importer = Adapter(importer)
unpatched_extract = importer.extract

@wraps(unpatched_extract)
def patched_extract_method(filepath, existing=None):
def patched_extract_method(
filepath: str, existing: data.Directives
) -> data.Directives:
logger.debug("Calling the importer's extract method.")
imported_entries = unpatched_extract(filepath, existing=existing)
imported_entries = unpatched_extract(filepath, existing)

for hook in hooks:
imported_entries = hook(
Expand Down
34 changes: 23 additions & 11 deletions smart_importer/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

import logging
import threading
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable

from beancount.core import data
from beancount.core.data import (
ALL_DIRECTIVES,
Close,
Open,
Transaction,
Expand All @@ -27,6 +27,7 @@
from smart_importer.pipelines import get_pipeline

if TYPE_CHECKING:
from beangulp import Importer
from sklearn import Pipeline

logger = logging.getLogger(__name__) # pylint: disable=invalid-name
Expand All @@ -53,25 +54,31 @@ class EntryPredictor(ImporterHook):

def __init__(
self,
predict=True,
overwrite=False,
predict: bool = True,
overwrite: bool = False,
string_tokenizer: Callable[[str], list] | None = None,
denylist_accounts: list[str] | None = None,
):
) -> None:
super().__init__()
self.training_data = None
self.open_accounts: dict[str, str] = {}
self.denylist_accounts = set(denylist_accounts or [])
self.pipeline: Pipeline | None = None
self.is_fitted = False
self.lock = threading.Lock()
self.account = None
self.account: str | None = None

self.predict = predict
self.overwrite = overwrite
self.string_tokenizer = string_tokenizer

def __call__(self, importer, file, imported_entries, existing_entries):
def __call__(
self,
importer: Importer,
file: str,
imported_entries: data.Directives,
existing_entries: data.Directives,
) -> data.Directives:
"""Predict attributes for imported transactions.
Args:
Expand Down Expand Up @@ -157,7 +164,7 @@ def targets(self):
for entry in self.training_data
]

def define_pipeline(self):
def define_pipeline(self) -> None:
"""Defines the machine learning pipeline based on given weights."""

transformers = [
Expand All @@ -172,7 +179,7 @@ def define_pipeline(self):
SVC(kernel="linear"),
)

def train_pipeline(self):
def train_pipeline(self) -> None:
"""Train the machine learning pipeline."""

self.is_fitted = False
Expand All @@ -187,11 +194,14 @@ def train_pipeline(self):
self.is_fitted = True
logger.debug("Only one target possible.")
else:
assert self.pipeline is not None
self.pipeline.fit(self.training_data, self.targets)
self.is_fitted = True
logger.debug("Trained the machine learning model.")

def process_entries(self, imported_entries) -> list[ALL_DIRECTIVES]:
def process_entries(
self, imported_entries: data.Directives
) -> data.Directives:
"""Process imported entries.
Transactions might be modified, all other entries are left as is.
Expand All @@ -206,7 +216,9 @@ def process_entries(self, imported_entries) -> list[ALL_DIRECTIVES]:
imported_entries, enhanced_transactions
)

def apply_prediction(self, entry, prediction):
def apply_prediction(
self, entry: data.Transaction, prediction: Any
) -> data.Transaction:
"""Apply a single prediction to an entry.
Args:
Expand Down
28 changes: 17 additions & 11 deletions tests/data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,30 @@
import os
import pprint
import re
from typing import Callable

import pytest
from beancount.core import data
from beancount.core.compare import stable_hash_namedtuple
from beancount.parser import parser
from beangulp import Importer

from smart_importer import PredictPostings, apply_hooks


def chinese_string_tokenizer(pre_tokenizer_string):
def chinese_string_tokenizer(pre_tokenizer_string: str) -> list[str]:
jieba = pytest.importorskip("jieba")
jieba.initialize()
return list(jieba.cut(pre_tokenizer_string))


def _hash(entry):
def _hash(entry: data.Directive) -> str:
return stable_hash_namedtuple(entry, ignore={"meta", "units"})


def _load_testset(testset):
def _load_testset(
testset: str,
) -> tuple[data.Directives, data.Directives, data.Directives]:
path = os.path.join(
os.path.dirname(__file__), "data", testset + ".beancount"
)
Expand All @@ -35,7 +39,7 @@ def _load_testset(testset):
assert not errors
parsed_sections.append(entries)
assert len(parsed_sections) == 3
return parsed_sections
return tuple(parsed_sections)


@pytest.mark.parametrize(
Expand All @@ -47,25 +51,27 @@ def _load_testset(testset):
("chinese", chinese_string_tokenizer),
],
)
def test_testset(testset, string_tokenizer):
def test_testset(
testset: str, string_tokenizer: Callable[[str], list[str]]
) -> None:
# pylint: disable=unbalanced-tuple-unpacking
imported, training_data, expected = _load_testset(testset)

class DummyImporter(Importer):
def extract(self, filepath, existing=None):
def extract(
self, filepath: str, existing: data.Directives
) -> data.Directives:
return imported

def account(self, filepath):
def account(self, filepath: str) -> str:
return ""

def identify(self, filepath):
def identify(self, filepath: str) -> bool:
return True

importer = DummyImporter()
apply_hooks(importer, [PredictPostings(string_tokenizer=string_tokenizer)])
imported_transactions = importer.extract(
"dummy-data", existing=training_data
)
imported_transactions = importer.extract("dummy-data", training_data)

for txn1, txn2 in zip(imported_transactions, expected):
if _hash(txn1) != _hash(txn2):
Expand Down
56 changes: 30 additions & 26 deletions tests/predictors_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for the `PredictPayees` and `PredictPostings` decorators."""

# pylint: disable=missing-docstring
from beancount.core import data
from beancount.parser import parser
from beangulp import Importer

Expand Down Expand Up @@ -133,18 +134,20 @@


class BasicTestImporter(Importer):
def extract(self, filepath, existing=None):
def extract(
self, filepath: str, existing: data.Directives
) -> data.Directives:
if filepath == "dummy-data":
return TEST_DATA
if filepath == "empty":
return []
assert False
return []

def account(self, filepath):
def account(self, filepath: str) -> str:
return "Assets:US:BofA:Checking"

def identify(self, filepath):
def identify(self, filepath: str) -> bool:
return True


Expand All @@ -155,39 +158,38 @@ def identify(self, filepath):
)


def test_empty_training_data():
def test_empty_training_data() -> None:
"""
Verifies that the decorator leaves the imported entries unchanged when there is no training data.
"""
assert POSTING_IMPORTER.extract("dummy-data") == TEST_DATA
assert PAYEE_IMPORTER.extract("dummy-data") == TEST_DATA
assert POSTING_IMPORTER.extract("dummy-data", []) == TEST_DATA
assert PAYEE_IMPORTER.extract("dummy-data", []) == TEST_DATA


def test_no_transactions():
def test_no_transactions() -> None:
"""
Should not crash when passed empty list of transactions.
"""
POSTING_IMPORTER.extract("empty")
PAYEE_IMPORTER.extract("empty")
POSTING_IMPORTER.extract("empty", existing=TRAINING_DATA)
PAYEE_IMPORTER.extract("empty", existing=TRAINING_DATA)
POSTING_IMPORTER.extract("empty", [])
PAYEE_IMPORTER.extract("empty", [])
POSTING_IMPORTER.extract("empty", TRAINING_DATA)
PAYEE_IMPORTER.extract("empty", TRAINING_DATA)


def test_unchanged_narrations():
def test_unchanged_narrations() -> None:
"""
Verifies that the decorator leaves the narration intact
"""
correct_narrations = [transaction.narration for transaction in TEST_DATA]
extracted_narrations = [
transaction.narration
for transaction in PAYEE_IMPORTER.extract(
"dummy-data", existing=TRAINING_DATA
)
for transaction in PAYEE_IMPORTER.extract("dummy-data", TRAINING_DATA)
if isinstance(transaction, data.Transaction)
]
assert extracted_narrations == correct_narrations


def test_unchanged_first_posting():
def test_unchanged_first_posting() -> None:
"""
Verifies that the decorator leaves the first posting intact
"""
Expand All @@ -196,30 +198,32 @@ def test_unchanged_first_posting():
]
extracted_first_postings = [
transaction.postings[0]
for transaction in PAYEE_IMPORTER.extract(
"dummy-data", existing=TRAINING_DATA
)
for transaction in PAYEE_IMPORTER.extract("dummy-data", TRAINING_DATA)
if isinstance(transaction, data.Transaction)
]
assert extracted_first_postings == correct_first_postings


def test_payee_predictions():
def test_payee_predictions() -> None:
"""
Verifies that the decorator adds predicted payees.
"""
transactions = PAYEE_IMPORTER.extract("dummy-data", existing=TRAINING_DATA)
predicted_payees = [transaction.payee for transaction in transactions]
transactions = PAYEE_IMPORTER.extract("dummy-data", TRAINING_DATA)
predicted_payees = [
transaction.payee
for transaction in transactions
if isinstance(transaction, data.Transaction)
]
assert predicted_payees == PAYEE_PREDICTIONS


def test_account_predictions():
def test_account_predictions() -> None:
"""
Verifies that the decorator adds predicted postings.
"""
predicted_accounts = [
entry.postings[-1].account
for entry in POSTING_IMPORTER.extract(
"dummy-data", existing=TRAINING_DATA
)
for entry in POSTING_IMPORTER.extract("dummy-data", TRAINING_DATA)
if isinstance(entry, data.Transaction)
]
assert predicted_accounts == ACCOUNT_PREDICTIONS

0 comments on commit 7e222ad

Please sign in to comment.