Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
from __future__ import annotations

from bankstatements_core.domain.models.extraction_result import ExtractionResult
from bankstatements_core.domain.models.extraction_scoring_config import (
ExtractionScoringConfig,
)
from bankstatements_core.domain.models.extraction_warning import ExtractionWarning
from bankstatements_core.domain.models.transaction import Transaction

__all__ = ["Transaction", "ExtractionResult", "ExtractionWarning"]
__all__ = [
"Transaction",
"ExtractionResult",
"ExtractionWarning",
"ExtractionScoringConfig",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Extraction scoring configuration for Transaction.confidence_score computation.

Holds the per-signal penalty weights used by RowPostProcessor to reduce
a transaction's confidence score when extraction anomalies are detected.
"""

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class ExtractionScoringConfig:
"""Injectable penalty weights for Transaction.confidence_score.

All weights must be non-negative. The score starts at 1.0 and is
decremented by each applicable penalty, then clamped to [0.0, 1.0].

Attributes:
penalty_date_propagated: Applied when a date is filled in from a
prior row or the filename rather than read directly from the row.
penalty_missing_balance: Applied when the balance field is absent or
empty for a transaction row.

Examples:
>>> cfg = ExtractionScoringConfig.default()
>>> cfg.penalty_date_propagated
0.1
>>> cfg.penalty_missing_balance
0.2
"""

penalty_date_propagated: float = 0.1
penalty_missing_balance: float = 0.2

def __post_init__(self) -> None:
for name, val in [
("penalty_date_propagated", self.penalty_date_propagated),
("penalty_missing_balance", self.penalty_missing_balance),
]:
if val < 0.0:
raise ValueError(f"{name} must be >= 0.0, got {val}")

@classmethod
def default(cls) -> "ExtractionScoringConfig":
"""Return the default production scoring configuration."""
return cls()
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# Machine-readable warning codes
CODE_DATE_PROPAGATED = "DATE_PROPAGATED"
CODE_CREDIT_CARD_SKIPPED = "CREDIT_CARD_SKIPPED"
CODE_MISSING_BALANCE = "MISSING_BALANCE"


@dataclass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

from bankstatements_core.domain import ExtractionResult
from bankstatements_core.domain.converters import dicts_to_transactions
from bankstatements_core.domain.models.extraction_scoring_config import (
ExtractionScoringConfig,
)
from bankstatements_core.domain.models.extraction_warning import (
CODE_CREDIT_CARD_SKIPPED,
ExtractionWarning,
Expand Down Expand Up @@ -54,6 +57,7 @@ def __init__(
pdf_reader: "IPDFReader | None" = None,
extraction_config: "Any | None" = None,
template: "Any | None" = None,
scoring_config: ExtractionScoringConfig | None = None,
):
self.columns = columns
self.table_top_y = table_top_y
Expand All @@ -64,6 +68,7 @@ def __init__(
self.header_check_top_y = header_check_top_y
self.extraction_config = extraction_config
self.template = template
self.scoring_config = scoring_config

self._row_classifier = create_row_classifier_chain()
self._row_builder = RowBuilder(columns, self._row_classifier)
Expand Down Expand Up @@ -99,6 +104,7 @@ def extract(self, pdf_path: Path) -> ExtractionResult:
template=self.template,
filename_date=filename_date,
filename=pdf_path.name,
scoring_config=self.scoring_config,
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@
from datetime import datetime
from typing import TYPE_CHECKING

from bankstatements_core.domain.models.extraction_scoring_config import (
ExtractionScoringConfig,
)
from bankstatements_core.domain.models.extraction_warning import (
CODE_DATE_PROPAGATED,
CODE_MISSING_BALANCE,
ExtractionWarning,
)
from bankstatements_core.extraction.column_identifier import ColumnTypeIdentifier
Expand Down Expand Up @@ -53,13 +57,22 @@ def __init__(
template: "BankTemplate | None",
filename_date: str,
filename: str,
scoring_config: ExtractionScoringConfig | None = None,
) -> None:
self._columns = columns
self._row_classifier = row_classifier
self._template = template
self._filename_date = filename_date
self._filename = filename
self._scoring_config = (
scoring_config
if scoring_config is not None
else ExtractionScoringConfig.default()
)
self._date_col = ColumnTypeIdentifier.find_first_column_of_type(columns, "date")
self._balance_col = ColumnTypeIdentifier.find_first_column_of_type(
columns, "balance"
)
self._last_source: str = ""

def process(self, row: dict, current_date: str) -> str:
Expand All @@ -82,6 +95,9 @@ def process(self, row: dict, current_date: str) -> str:
if self._row_classifier.classify(row, self._columns) != "transaction":
return current_date

score = 1.0
warnings: list[dict] = []

# Date propagation
if self._date_col and row.get(self._date_col):
current_date = row[self._date_col]
Expand All @@ -92,11 +108,27 @@ def process(self, row: dict, current_date: str) -> str:
if not current_date:
current_date = fallback_date
self._last_source = "propagated"
warning = ExtractionWarning(
code=CODE_DATE_PROPAGATED,
message=f"date propagated from previous row ('{fallback_date}')",
score -= self._scoring_config.penalty_date_propagated
warnings.append(
ExtractionWarning(
code=CODE_DATE_PROPAGATED,
message=f"date propagated from previous row ('{fallback_date}')",
).to_dict()
)
row["extraction_warnings"] = json.dumps([warning.to_dict()])

# Missing balance
if self._balance_col and not row.get(self._balance_col, "").strip():
score -= self._scoring_config.penalty_missing_balance
warnings.append(
ExtractionWarning(
code=CODE_MISSING_BALANCE,
message="balance field is missing or empty",
).to_dict()
)

row["confidence_score"] = str(max(0.0, min(1.0, score)))
if warnings:
row["extraction_warnings"] = json.dumps(warnings)

# Metadata tagging
row["Filename"] = self._filename
Expand Down
Loading
Loading