Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 76 additions & 69 deletions src/ocr_router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ class DocumentRouter:

def __init__(self, config: dict):
self.config = config
self.categories = config.get('categories', {})
self.route_templates = config.get('route_templates', {})
self.owners = config.get('owners', [])
self.categories = config.get("categories", {})
self.route_templates = config.get("route_templates", {})
self.owners = config.get("owners", [])

# ------------------------------------------------------------------
# Classification
Expand All @@ -21,13 +21,13 @@ def __init__(self, config: dict):
def classify_document(self, text: str) -> Optional[str]:
"""Score each category by keyword matches; return best match above threshold."""
text_lower = text.lower()
min_score: int = self.config.get('min_classification_score', 2)
min_score: int = self.config.get("min_classification_score", 2)
scores: dict[str, int] = {}
for category, keywords in self.categories.items():
hits = sum(1 for kw in keywords if kw.lower() in text_lower)
if hits >= min_score:
scores[category] = hits
return max(scores, key=scores.get) if scores else 'Uncategorized'
return max(scores, key=scores.get) if scores else "Uncategorized"

# ------------------------------------------------------------------
# Route path
Expand All @@ -37,57 +37,56 @@ def build_route_path(self, category: str, metadata: dict) -> str:
"""Build destination folder path from template, stripping Unknown segments."""
template = self.route_templates.get(
category,
self.route_templates.get('default', '{category}/{issuer}/{year}'),
self.route_templates.get("default", "{category}/{issuer}/{year}"),
)

year = 'Unknown'
if metadata.get('date') and len(metadata['date']) >= 4:
year = metadata['date'][:4]
year = "Unknown"
if metadata.get("date") and len(metadata["date"]) >= 4:
year = metadata["date"][:4]

# Blank/template IRS forms: Tax Returns + year-only date + no amount
# → route to Tax Returns\Forms instead of the flat root
if (category == 'Tax Returns'
and metadata.get('date_year_only')
and not metadata.get('amount')):
return 'Tax Returns\\Forms'
if (
category == "Tax Returns"
and metadata.get("date_year_only")
and not metadata.get("amount")
):
return "Tax Returns\\Forms"

# CC/Bank contracts (no amount, contrato/contract keyword) → issuer root, no year
contract_keywords = self.config.get('contract_keywords',
['contrato', 'contract terms', 'cardmember agreement', 'account agreement',
'terms and conditions', 'hoja resumen'])
text_lower = metadata.get('_text_lower', '') # populated by classify if available
is_contract = (
category in ('Credit Card Statements', 'Bank Account & Statements')
and not metadata.get('amount')
)
is_contract = category in (
"Credit Card Statements",
"Bank Account & Statements",
) and not metadata.get("amount")
if is_contract:
issuer_seg = metadata.get('issuer') or 'Unknown'
template = self.route_templates.get(category,
self.route_templates.get('default', '{category}/{issuer}/{year}'))
issuer_seg = metadata.get("issuer") or "Unknown"
template = self.route_templates.get(
category, self.route_templates.get("default", "{category}/{issuer}/{year}")
)
# Use issuer-level path (drop {year} segment)
path = template.replace('/{year}', '').replace('{issuer}', issuer_seg)
path = template.replace("/{year}", "").replace("{issuer}", issuer_seg)
# strip any remaining Unknown segments
parts = [p for p in path.replace('\\', '/').split('/') if p and p != 'Unknown']
return '\\'.join(parts)
parts = [p for p in path.replace("\\", "/").split("/") if p and p != "Unknown"]
return "\\".join(parts)

replacements = {
'category': category or 'Uncategorized',
'issuer': metadata.get('issuer') or 'Unknown',
'owner': metadata.get('owner') or 'Unknown',
'account': metadata.get('account') or 'Unknown',
'year': year,
'date': metadata.get('date') or '',
'amount': metadata.get('amount') or '',
"category": category or "Uncategorized",
"issuer": metadata.get("issuer") or "Unknown",
"owner": metadata.get("owner") or "Unknown",
"account": metadata.get("account") or "Unknown",
"year": year,
"date": metadata.get("date") or "",
"amount": metadata.get("amount") or "",
}

path = template
for key, value in replacements.items():
path = path.replace(f'{{{key}}}', str(value))
path = path.replace(f"{{{key}}}", str(value))

# Remove any path segment whose value is 'Unknown'
parts = re.split(r'[/\\]', path)
parts = [p for p in parts if p and p != 'Unknown']
return '\\'.join(parts)
parts = re.split(r"[/\\]", path)
parts = [p for p in parts if p and p != "Unknown"]
return "\\".join(parts)

# ------------------------------------------------------------------
# Filename normalization
Expand All @@ -105,34 +104,42 @@ def normalize_filename(self, filename: str, metadata: dict) -> str:
- Amount appended as $X.XX when present
- Parts joined with ' - '
"""
category = metadata.get('category', '')
category = metadata.get("category", "")
ext = Path(filename).suffix

monthly_cats = set(self.config.get('monthly_categories', ['Bills']))
account_cats = set(self.config.get('account_in_filename_categories', [
'Bills', 'Credit Card Statements', 'Mortgage & Home Equity Accounts',
]))
doc_types: dict[str, str] = self.config.get('doc_types', {})
monthly_cats = set(self.config.get("monthly_categories", ["Bills"]))
account_cats = set(
self.config.get(
"account_in_filename_categories",
[
"Bills",
"Credit Card Statements",
"Mortgage & Home Equity Accounts",
],
)
)
doc_types: dict[str, str] = self.config.get("doc_types", {})

# Contracts: no amount + CC/Bank → override doc type, force dated (not monthly) format
is_contract = (
category in ('Credit Card Statements', 'Bank Account & Statements')
and not metadata.get('amount')
is_contract = category in (
"Credit Card Statements",
"Bank Account & Statements",
) and not metadata.get("amount")
effective_doc_type = "Contract" if is_contract else doc_types.get(category, "")
effective_monthly = monthly_cats - (
{"Credit Card Statements", "Bank Account & Statements"} if is_contract else set()
)
effective_doc_type = 'Contract' if is_contract else doc_types.get(category, '')
effective_monthly = monthly_cats - ({'Credit Card Statements', 'Bank Account & Statements'}
if is_contract else set())

# --- Date component ---
date_str = metadata.get('date') or '' # ISO YYYY-MM-DD
date_year_only = metadata.get('date_year_only', False)
date_part = ''
date_str = metadata.get("date") or "" # ISO YYYY-MM-DD
date_year_only = metadata.get("date_year_only", False)
date_part = ""
if len(date_str) >= 4:
year = date_str[:4]
month = date_str[5:7] if len(date_str) >= 7 else ''
day = date_str[8:10] if len(date_str) >= 10 else ''
month = date_str[5:7] if len(date_str) >= 7 else ""
day = date_str[8:10] if len(date_str) >= 10 else ""
if date_year_only:
date_part = year # e.g. 2021 (tax form, year only)
date_part = year # e.g. 2021 (tax form, year only)
elif category in effective_monthly:
date_part = f"{year}.{month}" if month else year
elif day:
Expand All @@ -143,27 +150,27 @@ def normalize_filename(self, filename: str, metadata: dict) -> str:
date_part = year

# --- Smart name: Issuer + DocType ---
issuer = (metadata.get('issuer') or '').strip()
issuer = (metadata.get("issuer") or "").strip()
doc_type = effective_doc_type
name_parts = [p for p in [issuer, doc_type] if p]
smart_name = ' '.join(name_parts) if name_parts else Path(filename).stem[:50]
smart_name = " ".join(name_parts) if name_parts else Path(filename).stem[:50]

# --- Account component (only for applicable categories) ---
account_part = ''
account_part = ""
if category in account_cats:
account = (metadata.get('account') or '').strip()
account = (metadata.get("account") or "").strip()
if account:
if metadata.get('account_masked'):
n = metadata.get('account_digits') or 4
account_part = f"(Last{n} {account})" # e.g. (Last4 1234)
if metadata.get("account_masked"):
n = metadata.get("account_digits") or 4
account_part = f"(Last{n} {account})" # e.g. (Last4 1234)
else:
account_part = f"({account})"

# --- Amount component (always 2 decimal places) ---
amount_part = ''
no_amount_cats = set(self.config.get('no_amount_categories', []))
raw_amount = metadata.get('amount')
currency = metadata.get('currency', '$')
amount_part = ""
no_amount_cats = set(self.config.get("no_amount_categories", []))
raw_amount = metadata.get("amount")
currency = metadata.get("currency", "$")
if raw_amount and category not in no_amount_cats:
try:
amount_part = f"{currency}{float(raw_amount):.2f}"
Expand All @@ -172,9 +179,9 @@ def normalize_filename(self, filename: str, metadata: dict) -> str:

# --- Assemble ---
components = [p for p in [date_part, smart_name, account_part, amount_part] if p]
normalized = ' - '.join(components)
normalized = " - ".join(components)

# Strip characters invalid in Windows filenames (preserve $, parens, dots, spaces)
normalized = re.sub(r'[<>:"/\\|?*]', '_', normalized)
normalized = re.sub(r'[<>:"/\\|?*]', "_", normalized)

return f"{normalized}{ext}"
27 changes: 27 additions & 0 deletions tests/test_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Regression tests for router behavior."""

from ocr_router.router import DocumentRouter


def test_contract_route_path_drops_year_segment():
"""Contract-like CC docs should route to issuer root without year."""
config = {
"categories": {},
"route_templates": {
"default": "{category}/{issuer}/{year}",
"Credit Card Statements": "Credit Card Statements/{issuer}/{year}",
},
"owners": [],
}

router = DocumentRouter(config)
path = router.build_route_path(
"Credit Card Statements",
{
"issuer": "AMEX (OZ)",
"amount": "",
"date": "2024-01-10",
},
)

assert path == "Credit Card Statements\\AMEX (OZ)"