Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@
from abc import ABC, abstractmethod

from domain.stock import CreateStock, Stock
from domain.portfolio import Portfolio


class AbstractStockRepository(ABC):
@abstractmethod
def create(self, stock: CreateStock) -> str:
"""Create a new stock entry in the repository."""

@abstractmethod
def list(self, user_id: int) -> List[Stock]:
"""List all stock by user id"""


class AbstractPortfolioRepository(ABC):
@abstractmethod
def update(self, portfolio: Portfolio) -> None:
"""Update portfolio in the repository"""
20 changes: 20 additions & 0 deletions src/adapters/portfolio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from dataclasses import asdict
from datetime import datetime, timezone
from pymongo import MongoClient
from pymongo.database import Database
from .base import AbstractPortfolioRepository
from domain.portfolio import Portfolio


class PortfolioRepository(AbstractPortfolioRepository):
def __init__(self, mongo_client: MongoClient, database_name: str = "stock_db"):
self.client = mongo_client
self.db: Database = self.client[database_name]
self.collection = self.db["portfolio"]

def update(self, portfolio: Portfolio) -> None:
portfolio.updated_at = datetime.now(timezone.utc)
self.collection.replace_one({"user_id": portfolio.user_id}, asdict(portfolio), upsert=True)

def __del__(self):
self.client.close()
20 changes: 20 additions & 0 deletions src/domain/portfolio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from dataclasses import dataclass
from datetime import datetime
from typing import List


@dataclass
class Holding:
symbol: str
shares: int
total_cost: float


@dataclass
class Portfolio:
user_id: int
cash_balance: float
total_money_in: float
holdings: List[Holding]
created_at: datetime
updated_at: datetime
86 changes: 86 additions & 0 deletions src/tests/test_portfolio_adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest
from datetime import datetime, timezone
from dataclasses import asdict
from pymongo import MongoClient
from domain.portfolio import Portfolio, Holding
from adapters.portfolio import PortfolioRepository


@pytest.fixture(scope="module")
def mongo_client():
client = MongoClient("mongodb://localhost:27015")
yield client
client.drop_database("test_stock_db")
client.close()


@pytest.fixture(scope="module")
def portfolio_repository(mongo_client):
return PortfolioRepository(mongo_client, database_name="test_stock_db")


@pytest.fixture(scope="function", autouse=True)
def clear_collection(portfolio_repository):
portfolio_repository.collection.delete_many({})


class TestPortfolioRepository:
def test_update_new_portfolio(self, portfolio_repository):
# Arrange
created_at = datetime.now(timezone.utc)
portfolio = Portfolio(
user_id=1,
cash_balance=1000.0,
total_money_in=1000.0,
holdings=[Holding(symbol="AAPL", shares=10, total_cost=1500.0)],
created_at=created_at,
updated_at=created_at,
)

# Action
portfolio_repository.update(portfolio)

# Assertion
result = portfolio_repository.collection.find_one({"user_id": 1})
assert result["user_id"] == portfolio.user_id
assert result["cash_balance"] == portfolio.cash_balance
assert result["total_money_in"] == portfolio.total_money_in
assert len(result["holdings"]) == 1
assert result["holdings"][0]["symbol"] == "AAPL"
assert result["holdings"][0]["shares"] == 10
assert result["holdings"][0]["total_cost"] == 1500.0

def test_update_existing_portfolio(self, portfolio_repository):
# Arrange
initial_portfolio = Portfolio(
user_id=1,
cash_balance=1000.0,
total_money_in=1000.0,
holdings=[Holding(symbol="AAPL", shares=10, total_cost=1500.0)],
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
result = portfolio_repository.collection.insert_one(asdict(initial_portfolio))

updated_portfolio = Portfolio(
user_id=1,
cash_balance=2000.0,
total_money_in=2000.0,
holdings=[Holding(symbol="AAPL", shares=20, total_cost=3000.0)],
created_at=initial_portfolio.created_at,
updated_at=datetime.now(timezone.utc),
)

# Action
portfolio_repository.update(updated_portfolio)

# Assertion
result = portfolio_repository.collection.find_one({"user_id": 1})
assert result is not None
assert result["user_id"] == updated_portfolio.user_id
assert result["cash_balance"] == 2000.0
assert result["total_money_in"] == 2000.0
assert len(result["holdings"]) == 1
assert result["holdings"][0]["symbol"] == "AAPL"
assert result["holdings"][0]["shares"] == 20
assert result["holdings"][0]["total_cost"] == 3000.0