diff --git a/src/adapters/base.py b/src/adapters/base.py index c54d517..972c670 100644 --- a/src/adapters/base.py +++ b/src/adapters/base.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from domain.stock import CreateStock, Stock +from domain.portfolio import Portfolio class AbstractStockRepository(ABC): @@ -9,5 +10,12 @@ class AbstractStockRepository(ABC): def create(self, stock: CreateStock) -> str: """Create a new stock entry in the repository.""" + @abstractmethod def list(self, user_id: int) -> List[Stock]: """List all stock by user id""" + + +class AbstractPortfolioRepository(ABC): + @abstractmethod + def update(self, portfolio: Portfolio) -> None: + """Update portfolio in the repository""" diff --git a/src/adapters/portfolio.py b/src/adapters/portfolio.py new file mode 100644 index 0000000..b5582d8 --- /dev/null +++ b/src/adapters/portfolio.py @@ -0,0 +1,20 @@ +from dataclasses import asdict +from datetime import datetime, timezone +from pymongo import MongoClient +from pymongo.database import Database +from .base import AbstractPortfolioRepository +from domain.portfolio import Portfolio + + +class PortfolioRepository(AbstractPortfolioRepository): + def __init__(self, mongo_client: MongoClient, database_name: str = "stock_db"): + self.client = mongo_client + self.db: Database = self.client[database_name] + self.collection = self.db["portfolio"] + + def update(self, portfolio: Portfolio) -> None: + portfolio.updated_at = datetime.now(timezone.utc) + self.collection.replace_one({"user_id": portfolio.user_id}, asdict(portfolio), upsert=True) + + def __del__(self): + self.client.close() diff --git a/src/domain/portfolio.py b/src/domain/portfolio.py new file mode 100644 index 0000000..b1d6610 --- /dev/null +++ b/src/domain/portfolio.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import List + + +@dataclass +class Holding: + symbol: str + shares: int + total_cost: float + + +@dataclass +class Portfolio: + user_id: int + cash_balance: float + total_money_in: float + holdings: List[Holding] + created_at: datetime + updated_at: datetime diff --git a/src/tests/test_portfolio_adapters.py b/src/tests/test_portfolio_adapters.py new file mode 100644 index 0000000..7148a62 --- /dev/null +++ b/src/tests/test_portfolio_adapters.py @@ -0,0 +1,86 @@ +import pytest +from datetime import datetime, timezone +from dataclasses import asdict +from pymongo import MongoClient +from domain.portfolio import Portfolio, Holding +from adapters.portfolio import PortfolioRepository + + +@pytest.fixture(scope="module") +def mongo_client(): + client = MongoClient("mongodb://localhost:27015") + yield client + client.drop_database("test_stock_db") + client.close() + + +@pytest.fixture(scope="module") +def portfolio_repository(mongo_client): + return PortfolioRepository(mongo_client, database_name="test_stock_db") + + +@pytest.fixture(scope="function", autouse=True) +def clear_collection(portfolio_repository): + portfolio_repository.collection.delete_many({}) + + +class TestPortfolioRepository: + def test_update_new_portfolio(self, portfolio_repository): + # Arrange + created_at = datetime.now(timezone.utc) + portfolio = Portfolio( + user_id=1, + cash_balance=1000.0, + total_money_in=1000.0, + holdings=[Holding(symbol="AAPL", shares=10, total_cost=1500.0)], + created_at=created_at, + updated_at=created_at, + ) + + # Action + portfolio_repository.update(portfolio) + + # Assertion + result = portfolio_repository.collection.find_one({"user_id": 1}) + assert result["user_id"] == portfolio.user_id + assert result["cash_balance"] == portfolio.cash_balance + assert result["total_money_in"] == portfolio.total_money_in + assert len(result["holdings"]) == 1 + assert result["holdings"][0]["symbol"] == "AAPL" + assert result["holdings"][0]["shares"] == 10 + assert result["holdings"][0]["total_cost"] == 1500.0 + + def test_update_existing_portfolio(self, portfolio_repository): + # Arrange + initial_portfolio = Portfolio( + user_id=1, + cash_balance=1000.0, + total_money_in=1000.0, + holdings=[Holding(symbol="AAPL", shares=10, total_cost=1500.0)], + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + result = portfolio_repository.collection.insert_one(asdict(initial_portfolio)) + + updated_portfolio = Portfolio( + user_id=1, + cash_balance=2000.0, + total_money_in=2000.0, + holdings=[Holding(symbol="AAPL", shares=20, total_cost=3000.0)], + created_at=initial_portfolio.created_at, + updated_at=datetime.now(timezone.utc), + ) + + # Action + portfolio_repository.update(updated_portfolio) + + # Assertion + result = portfolio_repository.collection.find_one({"user_id": 1}) + assert result is not None + assert result["user_id"] == updated_portfolio.user_id + assert result["cash_balance"] == 2000.0 + assert result["total_money_in"] == 2000.0 + assert len(result["holdings"]) == 1 + assert result["holdings"][0]["symbol"] == "AAPL" + assert result["holdings"][0]["shares"] == 20 + assert result["holdings"][0]["total_cost"] == 3000.0