diff --git a/src/adapters/portfolio.py b/src/adapters/portfolio.py index b5582d8..ec4de53 100644 --- a/src/adapters/portfolio.py +++ b/src/adapters/portfolio.py @@ -3,7 +3,7 @@ from pymongo import MongoClient from pymongo.database import Database from .base import AbstractPortfolioRepository -from domain.portfolio import Portfolio +from domain.portfolio import Portfolio, Holding class PortfolioRepository(AbstractPortfolioRepository): @@ -12,6 +12,23 @@ def __init__(self, mongo_client: MongoClient, database_name: str = "stock_db"): self.db: Database = self.client[database_name] self.collection = self.db["portfolio"] + def get(self, user_id: int) -> Portfolio: + result = self.collection.find_one({"user_id": user_id}) + if result is None: + return None + + return Portfolio( + user_id=result["user_id"], + cash_balance=result["cash_balance"], + total_money_in=result["total_money_in"], + holdings=[ + Holding(symbol=holding["symbol"], shares=holding["shares"], total_cost=holding["total_cost"]) + for holding in result["holdings"] + ], + created_at=result["created_at"], + updated_at=result["updated_at"], + ) + def update(self, portfolio: Portfolio) -> None: portfolio.updated_at = datetime.now(timezone.utc) self.collection.replace_one({"user_id": portfolio.user_id}, asdict(portfolio), upsert=True) diff --git a/src/tests/test_portfolio_adapters.py b/src/tests/test_portfolio_adapters.py index 7148a62..3b49b80 100644 --- a/src/tests/test_portfolio_adapters.py +++ b/src/tests/test_portfolio_adapters.py @@ -1,5 +1,6 @@ import pytest from datetime import datetime, timezone +from unittest.mock import ANY from dataclasses import asdict from pymongo import MongoClient from domain.portfolio import Portfolio, Holding @@ -60,7 +61,7 @@ def test_update_existing_portfolio(self, portfolio_repository): created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) - result = portfolio_repository.collection.insert_one(asdict(initial_portfolio)) + portfolio_repository.collection.insert_one(asdict(initial_portfolio)) updated_portfolio = Portfolio( user_id=1, @@ -84,3 +85,49 @@ def test_update_existing_portfolio(self, portfolio_repository): assert result["holdings"][0]["symbol"] == "AAPL" assert result["holdings"][0]["shares"] == 20 assert result["holdings"][0]["total_cost"] == 3000.0 + + def test_get_existing_portfolio(self, portfolio_repository): + # Arrange + created_at = datetime.now(timezone.utc) + portfolio = Portfolio( + user_id=1, + cash_balance=1000.0, + total_money_in=1000.0, + holdings=[Holding(symbol="AAPL", shares=10, total_cost=1500.0)], + created_at=created_at, + updated_at=created_at, + ) + portfolio_repository.collection.insert_one(asdict(portfolio)) + expected_result = Portfolio( + user_id=1, + cash_balance=1000.0, + total_money_in=1000.0, + holdings=[Holding(symbol="AAPL", shares=10, total_cost=1500.0)], + created_at=ANY, + updated_at=ANY, + ) + + # Action + result = portfolio_repository.get(user_id=1) + + # Assertion + assert result == expected_result + + def test_get_non_existent_portfolio(self, portfolio_repository): + # Arrange + created_at = datetime.now(timezone.utc) + portfolio = Portfolio( + user_id=1, + cash_balance=1000.0, + total_money_in=1000.0, + holdings=[Holding(symbol="AAPL", shares=10, total_cost=1500.0)], + created_at=created_at, + updated_at=created_at, + ) + portfolio_repository.collection.insert_one(asdict(portfolio)) + + # Action + result = portfolio_repository.get(user_id=999) + + # Assertion + assert result is None