Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade for beancount3/beangulp fixes #135 #136

Merged
merged 5 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
Changelog
=========

v0.6 (2025-01-06)
-----------------

Upgrade to Beancount v3 and beangulp.


v0.5 (2024-01-21)
-----------------

Expand Down
2 changes: 1 addition & 1 deletion pylintrc
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[MESSAGES CONTROL]
disable = too-few-public-methods
disable = too-few-public-methods,cyclic-import
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ build-backend = "setuptools.build_meta"
[tool.black]
line-length = 79

[[tool.mypy.overrides]]
module = ["beancount.*"]
follow_untyped_imports = true

[[tool.mypy.overrides]]
module = ["beangulp.*"]
follow_untyped_imports = true

[tool.ruff]
target-version = "py38"
line-length = 79
Expand Down
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ packages = find:
setup_requires =
setuptools_scm
install_requires =
beancount>=2.3.5,<3.0.0
beancount>=3
beangulp
scikit-learn>=1.0
numpy>=1.18.0
typing-extensions>=4.9

[options.packages.find]
exclude =
Expand Down
29 changes: 24 additions & 5 deletions smart_importer/detector.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
"""Duplicate detector hook."""

from __future__ import annotations

import logging
from typing import Callable

from beancount.ingest import similar
from beancount.core import data
from beangulp import Importer, similar
from typing_extensions import deprecated

from smart_importer.hooks import ImporterHook

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


@deprecated(
"Use or override the deduplicate method on beangulp.Importer directly."
)
class DuplicateDetector(ImporterHook):
"""Class for duplicate detector importer helpers.

Expand All @@ -18,17 +26,28 @@ class DuplicateDetector(ImporterHook):
entries to classify against.
"""

def __init__(self, comparator=None, window_days=2):
def __init__(
self,
comparator: Callable[[data.Directive, data.Directive], bool]
| None = None,
window_days: int = 2,
) -> None:
super().__init__()
self.comparator = comparator
self.window_days = window_days

def __call__(self, importer, file, imported_entries, existing_entries):
def __call__(
self,
importer: Importer,
file: str,
imported_entries: data.Directives,
existing: data.Directives,
) -> data.Directives:
"""Add duplicate metadata for imported transactions.

Args:
imported_entries: The list of imported entries.
existing_entries: The list of existing entries as passed to the
existing: The list of existing entries as passed to the
importer.

Returns:
Expand All @@ -37,7 +56,7 @@ def __call__(self, importer, file, imported_entries, existing_entries):

duplicate_pairs = similar.find_similar_entries(
imported_entries,
existing_entries,
existing,
self.comparator,
self.window_days,
)
Expand Down
49 changes: 40 additions & 9 deletions smart_importer/hooks.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,34 @@
"""Importer decorators."""

from __future__ import annotations

import logging
from functools import wraps
from typing import Callable, Sequence

from beancount.core import data
from beangulp import Adapter, Importer, ImporterProtocol

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


class ImporterHook:
    """Interface for an importer hook.

    A hook is a callable that is applied to the list of entries produced
    by an importer's ``extract`` method and returns a (possibly modified)
    list of entries.
    """

    def __call__(
        self,
        importer: Importer,
        file: str,
        imported_entries: data.Directives,
        existing: data.Directives,
    ) -> data.Directives:
        """Apply the hook and modify the imported entries.

        Args:
            importer: The importer that this hook is being applied to.
            file: The file that is being imported.
            imported_entries: The current list of imported entries.
            existing: The existing entries, as passed to the extract
                function.

        Returns:
            The updated list of imported entries.
        """
        # Subclasses must override __call__ with the actual transformation.
        raise NotImplementedError


def apply_hooks(
    importer: Importer | ImporterProtocol,
    hooks: Sequence[
        Callable[
            [Importer, str, data.Directives, data.Directives], data.Directives
        ]
    ],
) -> Importer:
    """Apply a list of importer hooks to an importer.

    The importer's ``extract`` method is monkey-patched so that each hook
    is called, in order, on the entries it returns. Legacy (beancount v2
    style) importers implementing ``ImporterProtocol`` are wrapped in a
    beangulp ``Adapter`` first.

    Args:
        importer: An importer instance (beangulp ``Importer`` or legacy
            ``ImporterProtocol``).
        hooks: A list of hooks, each a callable object.

    Returns:
        The patched beangulp ``Importer``.
    """

    if not isinstance(importer, Importer):
        importer = Adapter(importer)
    unpatched_extract = importer.extract

    @wraps(unpatched_extract)
    def patched_extract_method(
        filepath: str, existing: data.Directives
    ) -> data.Directives:
        logger.debug("Calling the importer's extract method.")
        imported_entries = unpatched_extract(filepath, existing)

        # Thread the entries through every hook in order; each hook
        # receives the output of the previous one.
        for hook in hooks:
            imported_entries = hook(
                importer, filepath, imported_entries, existing
            )

        return imported_entries

    importer.extract = patched_extract_method  # type: ignore

    # Imported here (not at module top level) to avoid a circular import
    # between hooks.py and detector.py.
    # pylint: disable=import-outside-toplevel
    from smart_importer.detector import DuplicateDetector

    if any(isinstance(hook, DuplicateDetector) for hook in hooks):
        logger.warning(
            "Use of DuplicateDetector detected - this is deprecated, "
            "please use the beangulp.Importer.deduplicate method directly."
        )
        # Disable beangulp's built-in deduplication so the deprecated
        # DuplicateDetector hook remains the only deduplication step.
        importer.deduplicate = lambda entries, existing: None  # type: ignore
    return importer
36 changes: 24 additions & 12 deletions smart_importer/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

import logging
import threading
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable

from beancount.core import data
from beancount.core.data import (
ALL_DIRECTIVES,
Close,
Open,
Transaction,
Expand All @@ -27,6 +27,7 @@
from smart_importer.pipelines import get_pipeline

if TYPE_CHECKING:
from beangulp import Importer
from sklearn import Pipeline

logger = logging.getLogger(__name__) # pylint: disable=invalid-name
Expand All @@ -53,25 +54,31 @@ class EntryPredictor(ImporterHook):

def __init__(
self,
predict=True,
overwrite=False,
predict: bool = True,
overwrite: bool = False,
string_tokenizer: Callable[[str], list] | None = None,
denylist_accounts: list[str] | None = None,
):
) -> None:
super().__init__()
self.training_data = None
self.open_accounts: dict[str, str] = {}
self.denylist_accounts = set(denylist_accounts or [])
self.pipeline: Pipeline | None = None
self.is_fitted = False
self.lock = threading.Lock()
self.account = None
self.account: str | None = None

self.predict = predict
self.overwrite = overwrite
self.string_tokenizer = string_tokenizer

def __call__(self, importer, file, imported_entries, existing_entries):
def __call__(
self,
importer: Importer,
file: str,
imported_entries: data.Directives,
existing_entries: data.Directives,
) -> data.Directives:
"""Predict attributes for imported transactions.

Args:
Expand All @@ -83,7 +90,7 @@ def __call__(self, importer, file, imported_entries, existing_entries):
A list of entries, modified by this predictor.
"""
logging.debug("Running %s for file %s", self.__class__.__name__, file)
self.account = importer.file_account(file)
self.account = importer.account(file)
self.load_training_data(existing_entries)
with self.lock:
self.define_pipeline()
Expand Down Expand Up @@ -157,7 +164,7 @@ def targets(self):
for entry in self.training_data
]

def define_pipeline(self):
def define_pipeline(self) -> None:
"""Defines the machine learning pipeline based on given weights."""

transformers = [
Expand All @@ -172,7 +179,7 @@ def define_pipeline(self):
SVC(kernel="linear"),
)

def train_pipeline(self):
def train_pipeline(self) -> None:
"""Train the machine learning pipeline."""

self.is_fitted = False
Expand All @@ -187,11 +194,14 @@ def train_pipeline(self):
self.is_fitted = True
logger.debug("Only one target possible.")
else:
assert self.pipeline is not None
self.pipeline.fit(self.training_data, self.targets)
self.is_fitted = True
logger.debug("Trained the machine learning model.")

def process_entries(self, imported_entries) -> list[ALL_DIRECTIVES]:
def process_entries(
self, imported_entries: data.Directives
) -> data.Directives:
"""Process imported entries.

Transactions might be modified, all other entries are left as is.
Expand All @@ -206,7 +216,9 @@ def process_entries(self, imported_entries) -> list[ALL_DIRECTIVES]:
imported_entries, enhanced_transactions
)

def apply_prediction(self, entry, prediction):
def apply_prediction(
self, entry: data.Transaction, prediction: Any
) -> data.Transaction:
"""Apply a single prediction to an entry.

Args:
Expand Down
36 changes: 25 additions & 11 deletions tests/data_test.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,35 @@
"""Tests for the `PredictPostings` decorator"""

from __future__ import annotations

# pylint: disable=missing-docstring
import os
import pprint
import re
from typing import Callable

import pytest
from beancount.core import data
from beancount.core.compare import stable_hash_namedtuple
from beancount.ingest.importer import ImporterProtocol
from beancount.parser import parser
from beangulp import Importer

from smart_importer import PredictPostings, apply_hooks


def chinese_string_tokenizer(pre_tokenizer_string: str) -> list[str]:
    """Tokenize a string into words with the ``jieba`` Chinese segmenter.

    Skips the calling test (via ``pytest.importorskip``) when ``jieba``
    is not installed.
    """
    jieba = pytest.importorskip("jieba")
    jieba.initialize()
    return list(jieba.cut(pre_tokenizer_string))


def _hash(entry):
def _hash(entry: data.Directive) -> str:
return stable_hash_namedtuple(entry, ignore={"meta", "units"})


def _load_testset(testset):
def _load_testset(
testset: str,
) -> tuple[data.Directives, data.Directives, data.Directives]:
path = os.path.join(
os.path.dirname(__file__), "data", testset + ".beancount"
)
Expand All @@ -35,7 +41,7 @@ def _load_testset(testset):
assert not errors
parsed_sections.append(entries)
assert len(parsed_sections) == 3
return parsed_sections
return tuple(parsed_sections)


@pytest.mark.parametrize(
Expand All @@ -47,19 +53,27 @@ def _load_testset(testset):
("chinese", chinese_string_tokenizer),
],
)
def test_testset(testset, string_tokenizer):
def test_testset(
testset: str, string_tokenizer: Callable[[str], list[str]]
) -> None:
# pylint: disable=unbalanced-tuple-unpacking
imported, training_data, expected = _load_testset(testset)

class DummyImporter(ImporterProtocol):
def extract(self, file, existing_entries=None):
class DummyImporter(Importer):
def extract(
self, filepath: str, existing: data.Directives
) -> data.Directives:
return imported

def account(self, filepath: str) -> str:
return ""

def identify(self, filepath: str) -> bool:
return True

importer = DummyImporter()
apply_hooks(importer, [PredictPostings(string_tokenizer=string_tokenizer)])
imported_transactions = importer.extract(
"dummy-data", existing_entries=training_data
)
imported_transactions = importer.extract("dummy-data", training_data)

for txn1, txn2 in zip(imported_transactions, expected):
if _hash(txn1) != _hash(txn2):
Expand Down
Loading
Loading