diff --git a/README.md b/README.md
index 49592bff..64abf465 100644
--- a/README.md
+++ b/README.md
@@ -193,3 +193,23 @@ finmind/
 ---
 
 MIT Licensed. Built with ❤️.
+
+## Rule-Based Auto-Tagging
+
+Users can define rules to automatically categorize transactions.
+
+### Features
+- Rule Fields: payee, amount, description, notes
+- Operators: contains, equals, regex, gt, lt, gte, lte, startswith, endswith
+- Multi-Condition Rules with AND/OR logic
+- Priority Ordering
+- Auto-Apply on expense creation
+
+### API Endpoints
+- GET /rules - List all rules
+- GET /rules/{id} - Get a rule
+- POST /rules - Create a rule
+- PATCH /rules/{id} - Update a rule
+- DELETE /rules/{id} - Delete a rule
+- POST /rules/{id}/conditions - Add condition
+- POST /rules/apply/{expense_id} - Manually apply
diff --git a/packages/backend/app/db/schema.sql b/packages/backend/app/db/schema.sql
index 410189de..12594f1b 100644
--- a/packages/backend/app/db/schema.sql
+++ b/packages/backend/app/db/schema.sql
@@ -123,3 +123,50 @@ CREATE TABLE IF NOT EXISTS audit_logs (
   action VARCHAR(100) NOT NULL,
   created_at TIMESTAMP NOT NULL DEFAULT NOW()
 );
+
+
+-- Rule-based auto-tagging
+-- Enum types are created inside DO blocks so re-running the schema is idempotent.
+DO $$ BEGIN
+  CREATE TYPE rule_field AS ENUM ('payee','amount','description','notes');
+EXCEPTION WHEN duplicate_object THEN NULL;
+END $$;
+
+DO $$ BEGIN
+  CREATE TYPE rule_operator AS ENUM ('contains','equals','regex','gt','lt','gte','lte','startswith','endswith');
+EXCEPTION WHEN duplicate_object THEN NULL;
+END $$;
+
+DO $$ BEGIN
+  CREATE TYPE condition_type AS ENUM ('AND','OR');
+EXCEPTION WHEN duplicate_object THEN NULL;
+END $$;
+
+-- Primary rule: one field/operator/value test plus the category/tag to assign.
+CREATE TABLE IF NOT EXISTS categorization_rules (
+  id SERIAL PRIMARY KEY,
+  user_id INT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+  name VARCHAR(100) NOT NULL,
+  field rule_field NOT NULL,
+  operator rule_operator NOT NULL,
+  value VARCHAR(500) NOT NULL,
+  category_id INT REFERENCES categories(id) ON DELETE SET NULL,
+  tag VARCHAR(100),
+  priority INT NOT NULL DEFAULT 0,
+  condition_type condition_type NOT NULL DEFAULT 'AND',
+  active BOOLEAN NOT NULL DEFAULT TRUE,
+  created_at TIMESTAMP NOT NULL DEFAULT NOW(),
+  updated_at TIMESTAMP NOT NULL DEFAULT NOW()
+);
+CREATE INDEX IF NOT EXISTS idx_categorization_rules_user_priority ON categorization_rules(user_id, priority DESC);
+
+-- Extra conditions combined with the rule's own test via its condition_type.
+CREATE TABLE IF NOT EXISTS rule_conditions (
+  id SERIAL PRIMARY KEY,
+  rule_id INT NOT NULL REFERENCES categorization_rules(id) ON DELETE CASCADE,
+  field rule_field NOT NULL,
+  operator rule_operator NOT NULL,
+  value VARCHAR(500) NOT NULL,
+  created_at TIMESTAMP NOT NULL DEFAULT NOW()
+);
+CREATE INDEX IF NOT EXISTS idx_rule_conditions_rule ON rule_conditions(rule_id);
diff --git a/packages/backend/app/models.py b/packages/backend/app/models.py
index 64d44810..f6fda72c 100644
--- a/packages/backend/app/models.py
+++ b/packages/backend/app/models.py
@@ -133,3 +133,82 @@ class AuditLog(db.Model):
     user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=True)
     action = db.Column(db.String(100), nullable=False)
     created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
+
+
+class RuleField(str, Enum):
+    """Expense attribute that a rule test is evaluated against."""
+    PAYEE = 'payee'
+    AMOUNT = 'amount'
+    DESCRIPTION = 'description'
+    NOTES = 'notes'
+
+class RuleOperator(str, Enum):
+    """Comparison applied between the expense attribute and the rule value."""
+    CONTAINS = 'contains'
+    EQUALS = 'equals'
+    REGEX = 'regex'
+    GT = 'gt'
+    LT = 'lt'
+    GTE = 'gte'
+    LTE = 'lte'
+    STARTSWITH = 'startswith'
+    ENDSWITH = 'endswith'
+
+class ConditionType(str, Enum):
+    """How a rule's own test and its extra conditions are combined."""
+    AND = 'AND'
+    OR = 'OR'
+
+def _enum_values(enum_cls):
+    """Return member values so SAEnum persists 'payee' rather than 'PAYEE'."""
+    return [member.value for member in enum_cls]
+
+class CategorizationRule(db.Model):
+    """A user-defined rule that assigns a category and/or tag to matching expenses."""
+    __tablename__ = 'categorization_rules'
+    id = db.Column(db.Integer, primary_key=True)
+    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
+    name = db.Column(db.String(100), nullable=False)
+    field = db.Column(SAEnum(RuleField, name='rule_field', values_callable=_enum_values), nullable=False)
+    operator = db.Column(SAEnum(RuleOperator, name='rule_operator', values_callable=_enum_values), nullable=False)
+    value = db.Column(db.String(500), nullable=False)
+    category_id = db.Column(db.Integer, db.ForeignKey('categories.id'), nullable=True)
+    tag = db.Column(db.String(100), nullable=True)
+    priority = db.Column(db.Integer, default=0, nullable=False)
+    condition_type = db.Column(SAEnum(ConditionType, name='condition_type', values_callable=_enum_values), default=ConditionType.AND, nullable=False)
+    active = db.Column(db.Boolean, default=True, nullable=False)
+    created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
+    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False)
+
+    def to_dict(self):
+        """Return a JSON-ready dict: enums as their values, timestamps as ISO strings."""
+        return {
+            'id': self.id, 'name': self.name,
+            'field': self.field.value if self.field else None,
+            'operator': self.operator.value if self.operator else None,
+            'value': self.value, 'category_id': self.category_id,
+            'tag': self.tag, 'priority': self.priority,
+            'condition_type': self.condition_type.value if self.condition_type else None,
+            'active': self.active,
+            'created_at': self.created_at.isoformat() if self.created_at else None,
+            'updated_at': self.updated_at.isoformat() if self.updated_at else None,
+        }
+
+class RuleCondition(db.Model):
+    """An extra field/operator/value test attached to a CategorizationRule."""
+    __tablename__ = 'rule_conditions'
+    id = db.Column(db.Integer, primary_key=True)
+    rule_id = db.Column(db.Integer, db.ForeignKey('categorization_rules.id', ondelete='CASCADE'), nullable=False)
+    field = db.Column(SAEnum(RuleField, name='rule_field', values_callable=_enum_values), nullable=False)
+    operator = db.Column(SAEnum(RuleOperator, name='rule_operator', values_callable=_enum_values), nullable=False)
+    value = db.Column(db.String(500), nullable=False)
+    created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False)
+
+    def to_dict(self):
+        """Return a JSON-ready dict with enums rendered as their string values."""
+        return {
+            'id': self.id, 'rule_id': self.rule_id,
+            'field': self.field.value if self.field else None,
+            'operator': self.operator.value if self.operator else None,
+            'value': self.value,
+        }
diff --git a/packages/backend/app/routes/__init__.py b/packages/backend/app/routes/__init__.py
index f13b0f89..475c9305 100644
--- a/packages/backend/app/routes/__init__.py
+++ b/packages/backend/app/routes/__init__.py
@@ -7,6 +7,8 @@
 from .categories import bp as categories_bp
 from .docs import bp as docs_bp
 from .dashboard import bp as dashboard_bp
+from .savings_opportunities import bp as savings_opportunities_bp
+from .rules import bp as rules_bp
 
 
 def register_routes(app: Flask):
@@ -18,3 +20,5 @@ def register_routes(app: Flask):
     app.register_blueprint(categories_bp, url_prefix="/categories")
     app.register_blueprint(docs_bp, url_prefix="/docs")
     app.register_blueprint(dashboard_bp, url_prefix="/dashboard")
+    app.register_blueprint(savings_opportunities_bp, url_prefix="/savings-opportunities")
+    app.register_blueprint(rules_bp, url_prefix="/rules")
diff --git a/packages/backend/app/routes/expenses.py b/packages/backend/app/routes/expenses.py
index 1376d46f..def59625 100644
--- a/packages/backend/app/routes/expenses.py
+++ b/packages/backend/app/routes/expenses.py
@@ -75,6 +75,13 @@ def create_expense():
         spent_at=date.fromisoformat(raw_date) if raw_date else date.today(),
     )
     db.session.add(e)
+    # Auto-apply categorization rules if no category specified
+    if not e.category_id:
+        try:
+            from .rules import apply_rules
+            apply_rules(e, uid)
+        except Exception as ex:
+            logger.warning('Failed to auto-apply rules: %s', ex)
     db.session.commit()
     logger.info("Created expense id=%s user=%s amount=%s", e.id, uid, e.amount)
     # Invalidate caches
diff --git a/packages/backend/app/routes/rules.py b/packages/backend/app/routes/rules.py
new file mode 100644
index 00000000..1027851b
--- /dev/null
+++ b/packages/backend/app/routes/rules.py
@@ -0,0 +1,333 @@
+import logging
+import re
+from decimal import Decimal, InvalidOperation
+from flask import Blueprint, jsonify, request
+from flask_jwt_extended import jwt_required, get_jwt_identity
+from ..extensions import db
+from ..models import CategorizationRule, RuleField, RuleOperator, ConditionType, RuleCondition, Category
+
+bp = Blueprint('rules', __name__)
+logger = logging.getLogger('finmind.rules')
+
+_NUMERIC_OPERATORS = (RuleOperator.GT, RuleOperator.LT, RuleOperator.GTE, RuleOperator.LTE)
+
+
+@bp.get('')
+@jwt_required()
+def list_rules():
+    """Return the caller's rules, highest priority first."""
+    uid = int(get_jwt_identity())
+    rules = (
+        db.session.query(CategorizationRule)
+        .filter_by(user_id=uid)
+        .order_by(CategorizationRule.priority.desc())
+        .all()
+    )
+    return jsonify([r.to_dict() for r in rules])
+
+
+@bp.get('/<int:rule_id>')
+@jwt_required()
+def get_rule(rule_id):
+    """Return one rule; 404 when it is missing or owned by another user."""
+    uid = int(get_jwt_identity())
+    rule = db.session.get(CategorizationRule, rule_id)
+    if not rule or rule.user_id != uid:
+        return jsonify(error='not found'), 404
+    return jsonify(rule.to_dict())
+
+
+@bp.post('')
+@jwt_required()
+def create_rule():
+    """Create a rule and any valid entries of the optional 'conditions' list.
+
+    Responds 400 for missing or invalid input, 404 for a category the caller
+    does not own, and 201 with the new rule otherwise.
+    """
+    uid = int(get_jwt_identity())
+    data = request.get_json() or {}
+    name = str(data.get('name') or '').strip()
+    if not name:
+        return jsonify(error='name required'), 400
+    field = _parse_field(data.get('field'))
+    if not field:
+        return jsonify(error='valid field required'), 400
+    operator = _parse_operator(data.get('operator'))
+    if not operator:
+        return jsonify(error='valid operator required'), 400
+    value = _clean_value(data.get('value'))
+    if not value:
+        return jsonify(error='value required'), 400
+    regex_error = _regex_error(operator, value)
+    if regex_error is not None:
+        return jsonify(error='invalid regex: ' + regex_error), 400
+    category_id = data.get('category_id')
+    if category_id and not _owns_category(category_id, uid):
+        return jsonify(error='category not found'), 404
+    try:
+        priority = int(data.get('priority', 0))
+    except (TypeError, ValueError):
+        return jsonify(error='invalid priority'), 400
+    condition_type = _parse_condition_type(data.get('condition_type'))
+    if not condition_type:
+        return jsonify(error='valid condition_type required'), 400
+    rule = CategorizationRule(
+        user_id=uid, name=name, field=field, operator=operator, value=value,
+        category_id=category_id, tag=data.get('tag'), priority=priority,
+        condition_type=condition_type, active=bool(data.get('active', True)),
+    )
+    db.session.add(rule)
+    # Flush (not commit) to obtain rule.id, so the rule and its conditions
+    # are persisted together in a single transaction.
+    db.session.flush()
+    for cd in data.get('conditions') or []:
+        if not isinstance(cd, dict):
+            continue
+        cf = _parse_field(cd.get('field'))
+        co = _parse_operator(cd.get('operator'))
+        cv = _clean_value(cd.get('value'))
+        # Incomplete or invalid extra conditions are skipped, not rejected.
+        if not (cf and co and cv) or _regex_error(co, cv) is not None:
+            continue
+        db.session.add(RuleCondition(rule_id=rule.id, field=cf, operator=co, value=cv))
+    db.session.commit()
+    logger.info('Created rule id=%s user=%s', rule.id, uid)
+    return jsonify(rule.to_dict()), 201
+
+
+@bp.patch('/<int:rule_id>')
+@jwt_required()
+def update_rule(rule_id):
+    """Partially update a rule.
+
+    Missing, empty or unrecognised name/field/operator/value/condition_type
+    entries leave the stored value unchanged. All input is validated before
+    the rule is modified.
+    """
+    uid = int(get_jwt_identity())
+    rule = db.session.get(CategorizationRule, rule_id)
+    if not rule or rule.user_id != uid:
+        return jsonify(error='not found'), 404
+    data = request.get_json() or {}
+    operator = _parse_operator(data.get('operator')) or rule.operator
+    value = _clean_value(data.get('value')) or rule.value
+    regex_error = _regex_error(operator, value)
+    if regex_error is not None:
+        return jsonify(error='invalid regex: ' + regex_error), 400
+    category_id = data.get('category_id')
+    if category_id and not _owns_category(category_id, uid):
+        return jsonify(error='category not found'), 404
+    priority = rule.priority
+    if 'priority' in data:
+        try:
+            priority = int(data['priority'])
+        except (TypeError, ValueError):
+            return jsonify(error='invalid priority'), 400
+    rule.name = str(data.get('name') or '').strip() or rule.name
+    rule.field = _parse_field(data.get('field')) or rule.field
+    rule.operator = operator
+    rule.value = value
+    rule.priority = priority
+    if 'category_id' in data:
+        rule.category_id = category_id
+    if 'tag' in data:
+        rule.tag = data.get('tag')
+    if 'condition_type' in data:
+        parsed = _parse_enum(ConditionType, data['condition_type'], str.upper)
+        rule.condition_type = parsed or rule.condition_type
+    if 'active' in data:
+        rule.active = bool(data.get('active'))
+    db.session.commit()
+    return jsonify(rule.to_dict())
+
+
+@bp.delete('/<int:rule_id>')
+@jwt_required()
+def delete_rule(rule_id):
+    """Delete one of the caller's rules."""
+    uid = int(get_jwt_identity())
+    rule = db.session.get(CategorizationRule, rule_id)
+    if not rule or rule.user_id != uid:
+        return jsonify(error='not found'), 404
+    db.session.delete(rule)
+    db.session.commit()
+    return jsonify(message='deleted')
+
+
+@bp.post('/<int:rule_id>/conditions')
+@jwt_required()
+def add_condition(rule_id):
+    """Attach an extra condition to one of the caller's rules."""
+    uid = int(get_jwt_identity())
+    rule = db.session.get(CategorizationRule, rule_id)
+    if not rule or rule.user_id != uid:
+        return jsonify(error='not found'), 404
+    data = request.get_json() or {}
+    field = _parse_field(data.get('field'))
+    operator = _parse_operator(data.get('operator'))
+    value = _clean_value(data.get('value'))
+    if not field or not operator or not value:
+        return jsonify(error='field, operator, value required'), 400
+    regex_error = _regex_error(operator, value)
+    if regex_error is not None:
+        return jsonify(error='invalid regex: ' + regex_error), 400
+    cond = RuleCondition(rule_id=rule_id, field=field, operator=operator, value=value)
+    db.session.add(cond)
+    db.session.commit()
+    return jsonify(cond.to_dict()), 201
+
+
+@bp.post('/apply/<int:expense_id>')
+@jwt_required()
+def apply_rules_to_expense(expense_id):
+    """Run the caller's rules against one of their expenses and report the result."""
+    from ..models import Expense
+    uid = int(get_jwt_identity())
+    exp = db.session.get(Expense, expense_id)
+    if not exp or exp.user_id != uid:
+        return jsonify(error='expense not found'), 404
+    return jsonify(apply_rules(exp, uid))
+
+
+def _parse_enum(enum_cls, raw, normalise):
+    """Convert raw input to an enum_cls member; None if missing or unrecognised."""
+    if not raw:
+        return None
+    try:
+        return enum_cls(normalise(str(raw)).strip())
+    except ValueError:
+        return None
+
+
+def _parse_field(raw):
+    """Parse a rule field name case-insensitively; None when invalid."""
+    return _parse_enum(RuleField, raw, str.lower)
+
+
+def _parse_operator(raw):
+    """Parse a rule operator name case-insensitively; None when invalid."""
+    return _parse_enum(RuleOperator, raw, str.lower)
+
+
+def _parse_condition_type(raw):
+    """Parse 'AND'/'OR'; defaults to AND when omitted, None when invalid."""
+    if not raw:
+        return ConditionType.AND
+    return _parse_enum(ConditionType, raw, str.upper)
+
+
+def _clean_value(raw):
+    """Return a JSON string or number as a stripped string ('' for None)."""
+    return '' if raw is None else str(raw).strip()
+
+
+def _regex_error(operator, value):
+    """Return the compile error text when value is an invalid regex, else None."""
+    if operator != RuleOperator.REGEX:
+        return None
+    try:
+        re.compile(value)
+    except re.error as exc:
+        return str(exc)
+    return None
+
+
+def _owns_category(category_id, user_id):
+    """Tell whether category_id names a category belonging to user_id."""
+    cat = db.session.get(Category, category_id)
+    return cat is not None and cat.user_id == user_id
+
+
+def apply_rules(expense, user_id):
+    """Apply the user's active rules to an expense, highest priority first.
+
+    The first matching rule with a category sets expense.category_id, and the
+    first matching rule with a tag supplies the reported tag (the tag is only
+    returned, not stored on the expense). Evaluation stops at the first
+    matching rule that has no extra conditions. The session is committed
+    when at least one rule matched.
+
+    Returns a dict with expense_id, category_id, tag and applied_rules.
+    """
+    rules = (
+        db.session.query(CategorizationRule)
+        .filter_by(user_id=user_id, active=True)
+        .order_by(CategorizationRule.priority.desc())
+        .all()
+    )
+    applied = []
+    cat_id = None
+    tag = None
+    for rule in rules:
+        # Load the extra conditions once; they drive both matching and the stop test.
+        conds = db.session.query(RuleCondition).filter_by(rule_id=rule.id).all()
+        if not _evaluate_rule(expense, rule, conds):
+            continue
+        applied.append(rule.to_dict())
+        if rule.category_id and not cat_id:
+            cat_id = rule.category_id
+            expense.category_id = cat_id
+        if rule.tag and not tag:
+            tag = rule.tag
+        if not conds:
+            break
+    if applied:
+        db.session.commit()
+    return {'expense_id': expense.id, 'category_id': cat_id, 'tag': tag, 'applied_rules': applied}
+
+
+def _evaluate_rule(expense, rule, conds=None):
+    """Tell whether an expense satisfies a rule.
+
+    The rule's own test and its extra conditions are combined with all() for
+    AND rules and any() otherwise. conds may be passed to avoid re-querying.
+    """
+    if conds is None:
+        conds = db.session.query(RuleCondition).filter_by(rule_id=rule.id).all()
+    results = [_evaluate_condition(expense, rule.field, rule.operator, rule.value)]
+    results.extend(_evaluate_condition(expense, c.field, c.operator, c.value) for c in conds)
+    return all(results) if rule.condition_type == ConditionType.AND else any(results)
+
+
+def _field_text(expense, field):
+    """Return the expense text that a rule field is compared against."""
+    if field == RuleField.AMOUNT:
+        return str(expense.amount)
+    # Prefer the expense attribute named after the field (payee, description,
+    # notes); fall back to notes when the model has no such attribute or it is
+    # empty, which keeps the previous notes-only matching working.
+    raw = getattr(expense, getattr(field, 'value', str(field)), None)
+    return str(raw or expense.notes or '')
+
+
+def _evaluate_condition(expense, field, operator, value):
+    """Evaluate one field/operator/value test; text comparisons ignore case."""
+    fv = _field_text(expense, field)
+    if operator in _NUMERIC_OPERATORS:
+        try:
+            fn, vn = Decimal(fv), Decimal(value)
+            if operator == RuleOperator.GT:
+                return fn > vn
+            if operator == RuleOperator.LT:
+                return fn < vn
+            if operator == RuleOperator.GTE:
+                return fn >= vn
+            return fn <= vn
+        except (InvalidOperation, TypeError, ValueError):
+            return False  # non-numeric or NaN operand: the condition cannot match
+    if operator == RuleOperator.REGEX:
+        try:
+            return bool(re.search(value, fv, re.I))
+        except re.error:
+            return False
+    haystack, needle = fv.lower(), value.lower()
+    if operator == RuleOperator.CONTAINS:
+        return needle in haystack
+    if operator == RuleOperator.EQUALS:
+        return needle == haystack
+    if operator == RuleOperator.STARTSWITH:
+        return haystack.startswith(needle)
+    if operator == RuleOperator.ENDSWITH:
+        return haystack.endswith(needle)
+    return False
diff --git a/packages/backend/tests/test_rules.py b/packages/backend/tests/test_rules.py
new file mode 100644
index 00000000..e4772582
--- /dev/null
+++ b/packages/backend/tests/test_rules.py
@@ -0,0 +1,39 @@
+
+def test_update_rule(client, auth_header):
+    r = client.post('/rules', json={'name': 'Test Rule', 'field': 'description', 'operator': 'contains', 'value': 'test'}, headers=auth_header)
+    assert r.status_code == 201
+    rule = r.get_json()
+    r = client.patch('/rules/' + str(rule['id']), json={'name': 'Updated Rule', 'value': 'updated', 'priority': 5}, headers=auth_header)
+    assert r.status_code == 200
+    updated = r.get_json()
+    assert updated['name'] == 'Updated Rule'
+
+def test_delete_rule(client, auth_header):
+    r = client.post('/rules', json={'name': 'To Delete', 'field': 'description', 'operator': 'contains', 'value': 'delete'}, headers=auth_header)
+    assert r.status_code == 201
+    rule = r.get_json()
+    r = client.delete('/rules/' + str(rule['id']), headers=auth_header)
+    assert r.status_code == 200
+
+def test_rule_applies_to_expense(client, auth_header):
+    r = client.post('/categories', json={'name': 'Shopping'}, headers=auth_header)
+    cat = r.get_json()
+    r = client.post('/rules', json={'name': 'Amazon', 'field': 'description', 'operator': 'contains', 'value': 'amazon', 'category_id': cat['id'], 'priority': 10}, headers=auth_header)
+    r = client.post('/expenses', json={'amount': 50.00, 'description': 'Amazon purchase', 'date': '2026-01-15'}, headers=auth_header)
+    expense = r.get_json()
+    assert expense['category_id'] == cat['id']
+
+def test_priority_ordering(client, auth_header):
+    r = client.post('/categories', json={'name': 'High'}, headers=auth_header)
+    cat_high = r.get_json()
+    r = client.post('/categories', json={'name': 'Low'}, headers=auth_header)
+    cat_low = r.get_json()
+    r = client.post('/rules', json={'name': 'Low', 'field': 'description', 'operator': 'contains', 'value': 'test', 'category_id': cat_low['id'], 'priority': 1}, headers=auth_header)
+    r = client.post('/rules', json={'name': 'High', 'field': 'description', 'operator': 'contains', 'value': 'test', 'category_id': cat_high['id'], 'priority': 10}, headers=auth_header)
+    r = client.post('/expenses', json={'amount': 50.00, 'description': 'test expense', 'date': '2026-01-15'}, headers=auth_header)
+    expense = r.get_json()
+    assert expense['category_id'] == cat_high['id']
+
+def test_invalid_field_rejected(client, auth_header):
+    r = client.post('/rules', json={'name': 'Bad', 'field': 'bogus', 'operator': 'contains', 'value': 'x'}, headers=auth_header)
+    assert r.status_code == 400