diff --git a/packages/backend/app/models.py b/packages/backend/app/models.py index 64d44810..f5f97f9a 100644 --- a/packages/backend/app/models.py +++ b/packages/backend/app/models.py @@ -66,6 +66,43 @@ class RecurringExpense(db.Model): created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) + +class SubscriptionCadence(str, Enum): + WEEKLY = "WEEKLY" + MONTHLY = "MONTHLY" + YEARLY = "YEARLY" + + +class SubscriptionStatus(str, Enum): + DETECTED = "DETECTED" + CONFIRMED = "CONFIRMED" + DISMISSED = "DISMISSED" + + +class Subscription(db.Model): + """Auto-detected subscription from recurring expense patterns.""" + __tablename__ = "subscriptions" + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey("users.id"), nullable=False) + merchant_name = db.Column(db.String(200), nullable=False) + category_id = db.Column(db.Integer, db.ForeignKey("categories.id"), nullable=True) + amount = db.Column(db.Numeric(12, 2), nullable=False) + currency = db.Column(db.String(10), default="INR", nullable=False) + detected_cadence = db.Column(SAEnum(SubscriptionCadence), nullable=False) + confidence_score = db.Column(db.Numeric(3, 2), nullable=False) # 0.00 to 1.00 + occurrence_count = db.Column(db.Integer, default=1, nullable=False) + first_occurrence_date = db.Column(db.Date, nullable=False) + last_occurrence_date = db.Column(db.Date, nullable=False) + next_predicted_date = db.Column(db.Date, nullable=True) + average_amount = db.Column(db.Numeric(12, 2), nullable=True) + amount_variance = db.Column(db.Numeric(12, 2), nullable=True) + status = db.Column(SAEnum(SubscriptionStatus), default=SubscriptionStatus.DETECTED, nullable=False) + notes = db.Column(db.String(500), nullable=True) + created_at = db.Column(db.DateTime, default=datetime.utcnow, nullable=False) + updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) + + class BillCadence(str, Enum): MONTHLY = "MONTHLY" WEEKLY = "WEEKLY" diff --git a/packages/backend/app/routes/__init__.py b/packages/backend/app/routes/__init__.py index f13b0f89..045200b9 100644 --- a/packages/backend/app/routes/__init__.py +++ b/packages/backend/app/routes/__init__.py @@ -7,6 +7,7 @@ from .categories import bp as categories_bp from .docs import bp as docs_bp from .dashboard import bp as dashboard_bp +from .subscriptions import bp as subscriptions_bp def register_routes(app: Flask): @@ -17,4 +18,5 @@ def register_routes(app: Flask): app.register_blueprint(insights_bp, url_prefix="/insights") app.register_blueprint(categories_bp, url_prefix="/categories") app.register_blueprint(docs_bp, url_prefix="/docs") + app.register_blueprint(subscriptions_bp, url_prefix="/subscriptions") app.register_blueprint(dashboard_bp, url_prefix="/dashboard") diff --git a/packages/backend/app/routes/subscriptions.py b/packages/backend/app/routes/subscriptions.py new file mode 100644 index 00000000..2b817344 --- /dev/null +++ b/packages/backend/app/routes/subscriptions.py @@ -0,0 +1,158 @@ +""" +Subscriptions API routes for auto-detected recurring charges. +""" + +from datetime import date +from flask import Blueprint, current_app, jsonify, request +from flask_jwt_extended import jwt_required, get_jwt_identity +from ..extensions import db +from ..models import Subscription, SubscriptionStatus, SubscriptionCadence, User +from ..services.subscription_detector import detect_and_create_subscriptions, SubscriptionDetector +import logging + +bp = Blueprint("subscriptions", __name__) +logger = logging.getLogger("finmind.subscriptions") + + +@bp.get("") +@jwt_required() +def list_subscriptions(): + """ + List subscriptions for the current user. + Query params: + - status: filter by status (DETECTED, CONFIRMED, DISMISSED). Default: all. + """ + uid = int(get_jwt_identity()) + q = db.session.query(Subscription).filter_by(user_id=uid) + + status_filter = request.args.get("status", "").upper() + if status_filter in [s.value for s in SubscriptionStatus]: + q = q.filter_by(status=SubscriptionStatus(status_filter)) + + items = q.order_by(Subscription.created_at.desc()).all() + return jsonify([_subscription_to_dict(s) for s in items]) + + +@bp.post("/detect") +@jwt_required() +def detect_subscriptions(): + """ + Manually trigger subscription detection for the current user. + Returns list of newly detected subscriptions. + """ + uid = int(get_jwt_identity()) + + try: + new_subs = detect_and_create_subscriptions(uid) + return jsonify({ + "detected": len(new_subs), + "subscriptions": [_subscription_to_dict(s) for s in new_subs] + }), 201 + except Exception as e: + logger.exception("Subscription detection failed user=%s", uid) + return jsonify(error="detection failed", details=str(e)), 500 + + +@bp.post("//confirm") +@jwt_required() +def confirm_subscription(subscription_id: int): + """ + Confirm a detected subscription (user accepts it as a real subscription). + """ + uid = int(get_jwt_identity()) + sub = db.session.get(Subscription, subscription_id) + + if not sub or sub.user_id != uid: + return jsonify(error="not found"), 404 + + if sub.status != SubscriptionStatus.DETECTED: + return jsonify(error="only DETECTED subscriptions can be confirmed"), 400 + + sub.status = SubscriptionStatus.CONFIRMED + db.session.commit() + + logger.info("Confirmed subscription id=%s user=%s", subscription_id, uid) + return jsonify(_subscription_to_dict(sub)), 200 + + +@bp.post("//dismiss") +@jwt_required() +def dismiss_subscription(subscription_id: int): + """ + Dismiss a detected subscription (user rejects the detection). + """ + uid = int(get_jwt_identity()) + sub = db.session.get(Subscription, subscription_id) + + if not sub or sub.user_id != uid: + return jsonify(error="not found"), 404 + + if sub.status == SubscriptionStatus.DISMISSED: + return jsonify(error="already dismissed"), 400 + + sub.status = SubscriptionStatus.DISMISSED + db.session.commit() + + logger.info("Dismissed subscription id=%s user=%s", subscription_id, uid) + return jsonify(_subscription_to_dict(sub)), 200 + + +@bp.delete("/") +@jwt_required() +def delete_subscription(subscription_id: int): + """ + Delete a subscription (soft-delete by setting status to DISMISSED). + """ + uid = int(get_jwt_identity()) + sub = db.session.get(Subscription, subscription_id) + + if not sub or sub.user_id != uid: + return jsonify(error="not found"), 404 + + # Soft delete: mark as dismissed + sub.status = SubscriptionStatus.DISMISSED + db.session.commit() + + logger.info("Deleted subscription id=%s user=%s", subscription_id, uid) + return jsonify(message="deleted"), 200 + + +@bp.get("/predictions/refresh") +@jwt_required() +def refresh_predictions(): + """ + Refresh next_predicted_date for all CONFIRMED subscriptions. + This can be run periodically via cron. + """ + uid = int(get_jwt_identity()) + + try: + detector = SubscriptionDetector() + detector.refresh_predictions(uid) + return jsonify(message="predictions refreshed"), 200 + except Exception as e: + logger.exception("Prediction refresh failed user=%s", uid) + return jsonify(error="refresh failed", details=str(e)), 500 + + +def _subscription_to_dict(s: Subscription) -> dict: + """Convert Subscription model to dict for JSON response.""" + return { + "id": s.id, + "merchant_name": s.merchant_name, + "category_id": s.category_id, + "amount": float(s.amount), + "currency": s.currency, + "detected_cadence": s.detected_cadence.value, + "confidence_score": float(s.confidence_score), + "occurrence_count": s.occurrence_count, + "first_occurrence_date": s.first_occurrence_date.isoformat(), + "last_occurrence_date": s.last_occurrence_date.isoformat(), + "next_predicted_date": s.next_predicted_date.isoformat() if s.next_predicted_date else None, + "average_amount": float(s.average_amount) if s.average_amount else None, + "amount_variance": float(s.amount_variance) if s.amount_variance else None, + "status": s.status.value, + "notes": s.notes, + "created_at": s.created_at.isoformat(), + "updated_at": s.updated_at.isoformat(), + } \ No newline at end of file diff --git a/packages/backend/app/services/__init__.py b/packages/backend/app/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/backend/app/services/subscription_detector.py b/packages/backend/app/services/subscription_detector.py new file mode 100644 index 00000000..cc1e2180 --- /dev/null +++ b/packages/backend/app/services/subscription_detector.py @@ -0,0 +1,327 @@ +""" +Subscription detection service for auto-detecting recurring charges from expenses. +""" + +from datetime import date, timedelta, datetime +from decimal import Decimal +from typing import List, Dict, Tuple, Optional +from collections import defaultdict +import logging +from sqlalchemy import func + +from ..extensions import db +from ..models import Expense, Subscription, SubscriptionCadence, SubscriptionStatus + +logger = logging.getLogger("finmind.subscriptions") + + +class SubscriptionDetector: + """Detect subscription patterns from user expenses.""" + + # Minimum data points for reliable detection + MIN_OCCURRENCES = 3 + MIN_CONFIDENCE = 0.6 + + # Date variance thresholds (in days) for pattern matching + WEEKLY_TOLERANCE = 3 # ±3 days from expected weekly date + MONTHLY_TOLERANCE = 5 # ±5 days from expected monthly date + YEARLY_TOLERANCE = 10 # ±10 days from expected yearly date + + # Amount variance threshold (percentage) + AMOUNT_VARIANCE_THRESHOLD = 0.05 # 5% + + def detect_subscriptions(self, user_id: int) -> List[Subscription]: + """ + Analyze user expenses and detect potential subscriptions. + Returns a list of Subscription objects (not yet persisted). + """ + # Get all expenses for the user, ordered by date + expenses = ( + db.session.query(Expense) + .filter_by(user_id=user_id) + .order_by(Expense.spent_at, Expense.notes) + .all() + ) + + if len(expenses) < self.MIN_OCCURRENCES: + return [] + + # Group expenses by (merchant, currency) + groups = self._group_by_merchant_and_currency(expenses) + + subscriptions = [] + for (merchant_name, currency), expense_list in groups.items(): + if len(expense_list) < self.MIN_OCCURRENCES: + continue + + # Analyze each merchant for recurring patterns + sub = self._analyze_merchant_pattern(user_id, merchant_name, currency, expense_list) + if sub and sub.confidence_score >= self.MIN_CONFIDENCE: + subscriptions.append(sub) + + return subscriptions + + def _group_by_merchant_and_currency(self, expenses: List[Expense]) -> Dict[Tuple[str, str], List[Expense]]: + """Group expenses by (normalized merchant name, currency).""" + groups = defaultdict(list) + for exp in expenses: + merchant = self._normalize_merchant_name(exp.notes or "") + if merchant: + key = (merchant, exp.currency) + groups[key].append(exp) + return groups + + def _normalize_merchant_name(self, notes: str) -> str: + """ + Normalize merchant name from expense notes. + Simple approach: lowercase, strip, remove common prefixes. + """ + if not notes: + return "" + + # Basic normalization + normalized = notes.strip().lower() + + # Remove common prefixes that might interfere + prefixes = ["payment to ", "paid to ", "transaction at ", "purchase at "] + for prefix in prefixes: + if normalized.startswith(prefix): + normalized = normalized[len(prefix):] + + # Remove extra whitespace + normalized = " ".join(normalized.split()) + + return normalized if normalized else notes.strip() + + def _analyze_merchant_pattern(self, user_id: int, merchant: str, currency: str, expenses: List[Expense]) -> Optional[Subscription]: + """ + Analyze a group of expenses to determine if they represent a subscription. + """ + if len(expenses) < self.MIN_OCCURRENCES: + return None + + # Sort by date + expenses.sort(key=lambda e: e.spent_at) + + # Calculate date intervals + intervals = self._calculate_intervals(expenses) + if not intervals: + return None + + # Detect cadence + cadence, avg_interval_days, interval_std = self._detect_cadence(intervals) + if not cadence: + return None + + # Check amount consistency + amounts = [float(exp.amount) for exp in expenses] + avg_amount = sum(amounts) / len(amounts) + amount_variance = self._calculate_variance(amounts, avg_amount) + + # Calculate confidence score + confidence = self._calculate_confidence( + len(expenses), avg_interval_days, interval_std, amount_variance, cadence + ) + + # Determine tolerance based on cadence + tolerance = self._get_tolerance_for_cadence(cadence) + + # Check if intervals are regular within tolerance + regularity_score = self._calculate_regularity(intervals, tolerance) + + # Adjust confidence based on regularity + confidence *= regularity_score + + if confidence < self.MIN_CONFIDENCE: + return None + + # Calculate average amount and variance as Decimal + avg_amount_dec = Decimal(str(round(avg_amount, 2))) + variance_dec = Decimal(str(round(amount_variance, 2))) + + # Predict next date + next_date = self._predict_next_date(expenses[-1].spent_at, cadence) + + # Create Subscription object + subscription = Subscription( + user_id=user_id, + merchant_name=merchant.title(), # Nice formatting + amount=avg_amount_dec, + currency=currency, + detected_cadence=cadence, + confidence_score=Decimal(str(round(confidence, 2))), + occurrence_count=len(expenses), + first_occurrence_date=expenses[0].spent_at, + last_occurrence_date=expenses[-1].spent_at, + next_predicted_date=next_date, + average_amount=avg_amount_dec, + amount_variance=variance_dec, + status=SubscriptionStatus.DETECTED, + notes=f"Auto-detected from {len(expenses)} occurrences" + ) + + return subscription + + def _calculate_intervals(self, expenses: List[Expense]) -> List[int]: + """Calculate days between consecutive expenses.""" + intervals = [] + for i in range(1, len(expenses)): + delta = (expenses[i].spent_at - expenses[i-1].spent_at).days + intervals.append(delta) + return intervals + + def _detect_cadence(self, intervals: List[int]) -> Tuple[Optional[SubscriptionCadence], float, float]: + """ + Determine the most likely cadence based on intervals. + Returns (cadence, average_days, standard_deviation). + """ + if not intervals: + return None, 0.0, 0.0 + + avg_interval = sum(intervals) / len(intervals) + std_dev = self._std_deviation(intervals, avg_interval) + + # Check for weekly pattern (6-8 days, allowing for weekend shifts) + if 6 <= avg_interval <= 8 and std_dev <= self.WEEKLY_TOLERANCE: + return SubscriptionCadence.WEEKLY, avg_interval, std_dev + + # Check for monthly pattern (25-35 days) + if 25 <= avg_interval <= 35 and std_dev <= self.MONTHLY_TOLERANCE: + return SubscriptionCadence.MONTHLY, avg_interval, std_dev + + # Check for yearly pattern (350-380 days) + if 350 <= avg_interval <= 380 and std_dev <= self.YEARLY_TOLERANCE: + return SubscriptionCadence.YEARLY, avg_interval, std_dev + + return None, avg_interval, std_dev + + def _calculate_variance(self, values: List[float], mean: float) -> float: + """Calculate variance as percentage of mean.""" + if len(values) < 2: + return 0.0 + squared_diffs = [(v - mean) ** 2 for v in values] + variance = sum(squared_diffs) / len(values) + # Return as percentage of mean + return (variance ** 0.5) / mean if mean > 0 else 0.0 + + def _calculate_confidence(self, count: int, avg_interval: float, std_dev: float, + amount_variance: float, cadence: SubscriptionCadence) -> float: + """ + Calculate confidence score (0.0-1.0) based on: + - Number of occurrences + - Interval regularity + - Amount consistency + """ + score = 0.0 + + # Base score from occurrence count (more data = more confident) + count_score = min(count / 12.0, 1.0) # Cap at 12 occurrences + score += count_score * 0.3 + + # Interval regularity score + tolerance = self._get_tolerance_for_cadence(cadence) + regularity = max(0.0, 1.0 - (std_dev / tolerance)) + score += regularity * 0.4 + + # Amount consistency score + amount_consistency = max(0.0, 1.0 - (amount_variance / self.AMOUNT_VARIANCE_THRESHOLD)) + score += amount_consistency * 0.3 + + return min(score, 1.0) + + def _calculate_regularity(self, intervals: List[int], tolerance: int) -> float: + """ + Calculate how regular the intervals are within tolerance. + Returns score between 0.0 and 1.0. + """ + if not intervals: + return 0.0 + + # For detected cadences, intervals should already be close + # This calculates the proportion of intervals within extended tolerance + extended_tolerance = tolerance * 2 + within_tolerance = sum(1 for i in intervals if abs(i - intervals[0]) <= extended_tolerance) + return within_tolerance / len(intervals) + + def _get_tolerance_for_cadence(self, cadence: SubscriptionCadence) -> int: + """Get day tolerance for a given cadence.""" + if cadence == SubscriptionCadence.WEEKLY: + return self.WEEKLY_TOLERANCE + if cadence == SubscriptionCadence.MONTHLY: + return self.MONTHLY_TOLERANCE + if cadence == SubscriptionCadence.YEARLY: + return self.YEARLY_TOLERANCE + return 5 + + def _predict_next_date(self, last_date: date, cadence: SubscriptionCadence) -> date: + """Predict the next occurrence date based on cadence.""" + if cadence == SubscriptionCadence.WEEKLY: + return last_date + timedelta(days=7) + if cadence == SubscriptionCadence.MONTHLY: + year = last_date.year + (1 if last_date.month == 12 else 0) + month = 1 if last_date.month == 12 else last_date.month + 1 + day = min(last_date.day, self._days_in_month(year, month)) + return date(year, month, day) + if cadence == SubscriptionCadence.YEARLY: + return date(last_date.year + 1, last_date.month, last_date.day) + return last_date + + def _days_in_month(self, year: int, month: int) -> int: + """Get number of days in a month.""" + import calendar + return calendar.monthrange(year, month)[1] + + def _std_deviation(self, values: List[float], mean: float) -> float: + """Calculate standard deviation.""" + if len(values) < 2: + return 0.0 + variance = sum((x - mean) ** 2 for x in values) / len(values) + return variance ** 0.5 + + def refresh_predictions(self, user_id: int) -> None: + """Update next_predicted_date for all confirmed subscriptions.""" + subscriptions = ( + db.session.query(Subscription) + .filter_by(user_id=user_id, status=SubscriptionStatus.CONFIRMED) + .all() + ) + + for sub in subscriptions: + sub.next_predicted_date = self._predict_next_date( + sub.last_occurrence_date, sub.detected_cadence + ) + + db.session.commit() + + +def detect_and_create_subscriptions(user_id: int) -> List[Subscription]: + """ + Run detection for a user and persist new subscriptions. + Returns list of newly created subscriptions. + """ + detector = SubscriptionDetector() + detected = detector.detect_subscriptions(user_id) + + # Check which ones are new (not already in DB with similar pattern) + new_subs = [] + for sub in detected: + # Simple deduplication: check if there's an existing subscription + # with same user, merchant, and cadence that is CONFIRMED or DETECTED + existing = ( + db.session.query(Subscription) + .filter_by( + user_id=user_id, + merchant_name=sub.merchant_name, + detected_cadence=sub.detected_cadence, + ) + .filter(Subscription.status.in_([SubscriptionStatus.DETECTED, SubscriptionStatus.CONFIRMED])) + .first() + ) + + if not existing: + db.session.add(sub) + new_subs.append(sub) + + db.session.commit() + logger.info("Detected %d new subscriptions for user %s", len(new_subs), user_id) + return new_subs \ No newline at end of file diff --git a/packages/backend/tests/conftest.py b/packages/backend/tests/conftest.py index a7315b8c..84cb928c 100644 --- a/packages/backend/tests/conftest.py +++ b/packages/backend/tests/conftest.py @@ -1,16 +1,37 @@ import os import pytest +from unittest.mock import MagicMock, patch + +# Mock Redis BEFORE any app modules are imported +_mock_redis = MagicMock() +_mock_redis.get.return_value = None +_mock_redis.set.return_value = True +_mock_redis.setex.return_value = True +_mock_redis.delete.return_value = True +_mock_redis.scan.return_value = (0, []) +_mock_redis.flushdb.return_value = True +_mock_redis.keys.return_value = [] + +# Patch redis.Redis.from_url to return our mock +_original_from_url = __import__('redis').Redis.from_url + +def mock_from_url(url, **kwargs): + return _mock_redis + +# Apply the patch at import time +import redis +redis.Redis.from_url = mock_from_url + +# Now import app modules - they will get the mocked redis_client from app import create_app from app.config import Settings -from app.extensions import db -from app.extensions import redis_client +from app.extensions import db, redis_client from app import models # noqa: F401 - ensure models are registered class TestSettings(Settings): - # Override defaults for tests database_url: str = "sqlite+pysqlite:///:memory:" - redis_url: str = "redis://localhost:6379/15" # not used in tests + redis_url: str = "redis://localhost:6379/15" jwt_secret: str = "test-secret" @@ -21,7 +42,6 @@ def _setup_db(app): @pytest.fixture() def app_fixture(): - # Ensure a clean env for tests os.environ.setdefault("FLASK_ENV", "testing") settings = TestSettings( database_url="sqlite+pysqlite:///:memory:", @@ -31,18 +51,10 @@ def app_fixture(): app = create_app(settings) app.config.update(TESTING=True) _setup_db(app) - try: - redis_client.flushdb() - except Exception: - pass yield app with app.app_context(): db.session.remove() db.drop_all() - try: - redis_client.flushdb() - except Exception: - pass @pytest.fixture() @@ -52,17 +64,12 @@ def client(app_fixture): @pytest.fixture() def auth_header(client): - # Register and login a default user, return auth header email = "test@example.com" password = "password123" r = client.post("/auth/register", json={"email": email, "password": password}) register_debug = f"register failed: status={r.status_code}, body={r.get_json()}" - assert r.status_code in ( - 200, - 201, - 409, - ), register_debug # 409 if already exists + assert r.status_code in (200, 201, 409), register_debug r = client.post("/auth/login", json={"email": email, "password": password}) assert r.status_code == 200 access = r.get_json()["access_token"] - return {"Authorization": f"Bearer {access}"} + return {"Authorization": f"Bearer {access}"} \ No newline at end of file diff --git a/packages/backend/tests/test_subscriptions.py b/packages/backend/tests/test_subscriptions.py new file mode 100644 index 00000000..6be84eff --- /dev/null +++ b/packages/backend/tests/test_subscriptions.py @@ -0,0 +1,327 @@ +""" +Tests for subscription detection and management. +""" + +import pytest +from datetime import date, timedelta +from decimal import Decimal +from unittest.mock import MagicMock + +from app import create_app +from app.extensions import db, redis_client +from app.models import Expense, Subscription, SubscriptionStatus, SubscriptionCadence, User, Category +from app.services import cache + + +@pytest.fixture(autouse=True) +def mock_redis_operations(monkeypatch): + """Mock Redis operations to avoid connection errors in tests.""" + # Mock redis_client methods + mock_redis = MagicMock() + mock_redis.get.return_value = None + mock_redis.set.return_value = True + mock_redis.setex.return_value = True + mock_redis.delete.return_value = True + mock_redis.scan.return_value = (0, []) + mock_redis.flushdb.return_value = True + + monkeypatch.setattr("app.extensions.redis_client", mock_redis) + + # Also mock cache functions + def mock_cache_delete_patterns(patterns): + pass + + monkeypatch.setattr(cache, "cache_delete_patterns", mock_cache_delete_patterns) + + +def _create_user_and_auth(client, email="test@example.com", password="password123"): + """Helper to create user and return auth header.""" + r = client.post("/auth/register", json={"email": email, "password": password}) + assert r.status_code in (200, 201, 409) + r = client.post("/auth/login", json={"email": email, "password": password}) + assert r.status_code == 200 + access = r.get_json()["access_token"] + return {"Authorization": f"Bearer {access}"} + + +def _create_category(client, auth_header, name="General"): + r = client.post("/categories", json={"name": name}, headers=auth_header) + assert r.status_code in (201, 409) + r = client.get("/categories", headers=auth_header) + assert r.status_code == 200 + return r.get_json()[0]["id"] + + +def test_subscription_detection_weekly_pattern(client, auth_header): + """Test detection of weekly subscription pattern.""" + cat_id = _create_category(client, auth_header, name="Streaming") + + # Create weekly expenses for Netflix: same amount, ~7 days apart + base_date = date(2026, 1, 1) + amounts = [15.99, 15.99, 15.99, 15.99] # 4 occurrences + + for i, amount in enumerate(amounts): + expense_date = base_date + timedelta(days=7 * i) + r = client.post("/expenses", json={ + "amount": float(amount), + "currency": "USD", + "category_id": cat_id, + "description": "Netflix", + "date": expense_date.isoformat(), + }, headers=auth_header) + assert r.status_code == 201 + + # Trigger detection + r = client.post("/subscriptions/detect", headers=auth_header) + assert r.status_code in (201, 200) + result = r.get_json() + + # Should detect at least one subscription + assert result["detected"] >= 1 + subs = result["subscriptions"] + + # Find Netflix subscription + netflix_sub = next((s for s in subs if "netflix" in s["merchant_name"].lower()), None) + assert netflix_sub is not None + assert netflix_sub["detected_cadence"] == "WEEKLY" + assert netflix_sub["confidence_score"] >= 0.7 + assert netflix_sub["occurrence_count"] == 4 + assert netflix_sub["status"] == "DETECTED" + + +def test_subscription_detection_monthly_pattern(client, auth_header): + """Test detection of monthly subscription pattern.""" + cat_id = _create_category(client, auth_header, name="Utilities") + + # Create monthly expenses for Spotify: same amount, ~30 days apart + base_date = date(2026, 1, 15) + amounts = [9.99, 9.99, 9.99] + + for i, amount in enumerate(amounts): + # Add some day variance to test robustness + expense_date = base_date + timedelta(days=30 * i + (2 if i > 0 else 0)) + r = client.post("/expenses", json={ + "amount": float(amount), + "currency": "USD", + "category_id": cat_id, + "description": "Spotify", + "date": expense_date.isoformat(), + }, headers=auth_header) + assert r.status_code == 201 + + r = client.post("/subscriptions/detect", headers=auth_header) + assert r.status_code in (201, 200) + result = r.get_json() + + spotify_sub = next((s for s in result["subscriptions"] if "spotify" in s["merchant_name"].lower()), None) + assert spotify_sub is not None + assert spotify_sub["detected_cadence"] == "MONTHLY" + assert spotify_sub["confidence_score"] >= 0.7 + + +def test_subscription_detection_insufficient_data(client, auth_header): + """Test that detection requires minimum occurrences.""" + cat_id = _create_category(client, auth_header) + + # Only 2 occurrences - should not create subscription + base_date = date(2026, 1, 1) + for i in range(2): + r = client.post("/expenses", json={ + "amount": 10.0, + "currency": "USD", + "category_id": cat_id, + "description": "Rare Service", + "date": (base_date + timedelta(days=30 * i)).isoformat(), + }, headers=auth_header) + assert r.status_code == 201 + + r = client.post("/subscriptions/detect", headers=auth_header) + assert r.status_code in (201, 200) + result = r.get_json() + + # Should not detect (insufficient data) + assert result["detected"] == 0 + + +def test_subscription_list_empty(client, auth_header): + """Test listing subscriptions when none exist.""" + r = client.get("/subscriptions", headers=auth_header) + assert r.status_code == 200 + assert r.get_json() == [] + + +def test_subscription_confirm_and_dismiss(client, auth_header): + """Test confirming and dismissing a detected subscription.""" + cat_id = _create_category(client, auth_header, name="Subscriptions") + + # Create expenses for a clear weekly pattern + base_date = date(2026, 1, 1) + for i in range(4): + r = client.post("/expenses", json={ + "amount": 19.99, + "currency": "USD", + "category_id": cat_id, + "description": "Hulu", + "date": (base_date + timedelta(days=7 * i)).isoformat(), + }, headers=auth_header) + assert r.status_code == 201 + + # Detect + r = client.post("/subscriptions/detect", headers=auth_header) + assert r.status_code in (201, 200) + subs = r.get_json()["subscriptions"] + hulu_sub = next(s for s in subs if "hulu" in s["merchant_name"].lower()) + sub_id = hulu_sub["id"] + assert hulu_sub["status"] == "DETECTED" + + # Confirm + r = client.post(f"/subscriptions/{sub_id}/confirm", headers=auth_header) + assert r.status_code == 200 + confirmed = r.get_json() + assert confirmed["status"] == "CONFIRMED" + + # Try to confirm again - should fail + r = client.post(f"/subscriptions/{sub_id}/confirm", headers=auth_header) + assert r.status_code == 400 + + # Dismiss + r = client.post(f"/subscriptions/{sub_id}/dismiss", headers=auth_header) + assert r.status_code == 200 + dismissed = r.get_json() + assert dismissed["status"] == "DISMISSED" + + # Delete (soft delete) + r = client.delete(f"/subscriptions/{sub_id}", headers=auth_header) + assert r.status_code == 200 + assert r.get_json()["message"] == "deleted" + + +def test_subscription_filter_by_status(client, auth_header): + """Test filtering subscriptions by status.""" + cat_id = _create_category(client, auth_header) + + # Create two different patterns + base_date = date(2026, 1, 1) + + # Pattern 1: weekly (will be DETECTED) + for i in range(4): + client.post("/expenses", json={ + "amount": 10.0, + "currency": "USD", + "category_id": cat_id, + "description": "Service A", + "date": (base_date + timedelta(days=7 * i)).isoformat(), + }, headers=auth_header) + + # Pattern 2: monthly (will be DETECTED) + for i in range(4): + client.post("/expenses", json={ + "amount": 20.0, + "currency": "USD", + "category_id": cat_id, + "description": "Service B", + "date": (base_date + timedelta(days=30 * i)).isoformat(), + }, headers=auth_header) + + client.post("/subscriptions/detect", headers=auth_header) + + # List all + r = client.get("/subscriptions", headers=auth_header) + assert r.status_code == 200 + all_subs = r.get_json() + assert len(all_subs) >= 2 + + # Filter DETECTED + r = client.get("/subscriptions?status=DETECTED", headers=auth_header) + assert r.status_code == 200 + detected = r.get_json() + assert all(s["status"] == "DETECTED" for s in detected) + + +def test_refresh_predictions(client, auth_header): + """Test refreshing next_predicted_date for confirmed subscriptions.""" + cat_id = _create_category(client, auth_header) + + # Create weekly pattern and detect + base_date = date(2026, 1, 1) + for i in range(4): + client.post("/expenses", json={ + "amount": 15.0, + "currency": "USD", + "category_id": cat_id, + "description": "Test Sub", + "date": (base_date + timedelta(days=7 * i)).isoformat(), + }, headers=auth_header) + + r = client.post("/subscriptions/detect", headers=auth_header) + subs = r.get_json()["subscriptions"] + sub_id = subs[0]["id"] + + # Confirm it + client.post(f"/subscriptions/{sub_id}/confirm", headers=auth_header) + + # Refresh predictions + r = client.get("/subscriptions/predictions/refresh", headers=auth_header) + assert r.status_code == 200 + assert r.get_json()["message"] == "predictions refreshed" + + # Verify next_predicted_date is set + r = client.get(f"/subscriptions/{sub_id}", headers=auth_header) + # Note: we'd need a GET /subscriptions/{id} endpoint, but for now we check via list + r = client.get("/subscriptions?status=CONFIRMED", headers=auth_header) + confirmed = r.get_json() + assert len(confirmed) > 0 + assert confirmed[0]["next_predicted_date"] is not None + + +def test_detection_ignores_different_amounts(client, auth_header): + """Test that expenses with varying amounts are less likely to be detected.""" + cat_id = _create_category(client, auth_header) + + base_date = date(2026, 1, 1) + amounts = [10.0, 15.0, 12.0, 20.0] # High variance + + for i, amount in enumerate(amounts): + client.post("/expenses", json={ + "amount": amount, + "currency": "USD", + "category_id": cat_id, + "description": "Variable Service", + "date": (base_date + timedelta(days=7 * i)).isoformat(), + }, headers=auth_header) + + r = client.post("/subscriptions/detect", headers=auth_header) + result = r.get_json() + + # Might detect but confidence should be lower + if result["detected"] > 0: + sub = next((s for s in result["subscriptions"] if "variable" in s["merchant_name"].lower()), None) + if sub: + assert sub["confidence_score"] < 0.8 + assert sub["amount_variance"] is not None + + +def test_merchant_name_normalization(client, auth_header): + """Test that similar merchant names are grouped.""" + cat_id = _create_category(client, auth_header) + + base_date = date(2026, 1, 1) + # Slight variations in name + names = ["Netflix", "Netflix ", " NETFLIX", "netflix"] + + for i, name in enumerate(names): + client.post("/expenses", json={ + "amount": 15.99, + "currency": "USD", + "category_id": cat_id, + "description": name, + "date": (base_date + timedelta(days=7 * i)).isoformat(), + }, headers=auth_header) + + r = client.post("/subscriptions/detect", headers=auth_header) + result = r.get_json() + + # Should group them as one subscription + netflix_subs = [s for s in result["subscriptions"] if "netflix" in s["merchant_name"].lower()] + assert len(netflix_subs) == 1 + assert netflix_subs[0]["occurrence_count"] == 4 \ No newline at end of file