diff --git a/src/adapters/base.py b/src/adapters/base.py index 972c670..6823a62 100644 --- a/src/adapters/base.py +++ b/src/adapters/base.py @@ -19,3 +19,7 @@ class AbstractPortfolioRepository(ABC): @abstractmethod def update(self, portfolio: Portfolio) -> None: """Update portfolio in the repository""" + + @abstractmethod + def get(self, user_id: int) -> Portfolio: + """Get Portfolio""" diff --git a/src/tests/test_stock_usecase.py b/src/tests/test_stock_usecase.py index 9ff9fd2..6c085e3 100644 --- a/src/tests/test_stock_usecase.py +++ b/src/tests/test_stock_usecase.py @@ -1,64 +1,368 @@ import pytest from datetime import datetime, timezone -from unittest.mock import Mock +from unittest.mock import Mock, ANY from usecase.stock import StockUsecase from domain.stock import CreateStock, ActionType, Stock +from domain.portfolio import Portfolio, Holding @pytest.fixture def stock_usecase(): - mock_repo = Mock() - usecase = StockUsecase(stock_repo=mock_repo) - return usecase, mock_repo + stock_repo = Mock() + portfolio_repo = Mock() + usecase = StockUsecase(stock_repo=stock_repo, portfolio_repo=portfolio_repo) + return usecase, stock_repo, portfolio_repo class TestStockUsecase: - def test_create(self, stock_usecase): + def test_create_transfer_new_portfolio(self, stock_usecase): # Arrange - usecase, mock_repo = stock_usecase - mock_stock = CreateStock( - user_id=1, + usecase, stock_repo, portfolio_repo = stock_usecase + user_id, stock_id = 1, "123" + stock = CreateStock( + user_id=user_id, + symbol="", + price=3000.0, + quantity=1, + action_type=ActionType.TRANSFER, + created_at=ANY, + ) + portfolio_repo.get.return_value = None + stock_repo.create.return_value = stock_id + portfolio = Portfolio( + user_id=user_id, + cash_balance=3000.0, + total_money_in=3000.0, + holdings=[], + created_at=ANY, + updated_at=ANY, + ) + + # Act + result = usecase.create(stock) + + # Assert + portfolio_repo.get.assert_called_once_with(user_id) + portfolio_repo.update.assert_called_once_with(portfolio=portfolio) + stock_repo.create.assert_called_once_with(stock) + assert result == stock_id + + def test_create_transfer_existing_portfolio(self, stock_usecase): + # Arrange + usecase, stock_repo, portfolio_repo = stock_usecase + user_id, stock_id = 1, "123" + stock = CreateStock( + user_id=user_id, + symbol="", + price=3000.0, + quantity=1, + action_type=ActionType.TRANSFER, + created_at=ANY, + ) + existing_portfolio = Portfolio( + user_id=user_id, + cash_balance=3000.0, + total_money_in=3000.0, + holdings=[], + created_at=ANY, + updated_at=ANY, + ) + portfolio_repo.get.return_value = existing_portfolio + stock_repo.create.return_value = stock_id + updated_portfolio = Portfolio( + user_id=existing_portfolio.user_id, + cash_balance=6000.0, + total_money_in=6000.0, + holdings=[], + created_at=ANY, + updated_at=ANY, + ) + + # Act + result = usecase.create(stock) + + # Assert + portfolio_repo.get.assert_called_once_with(user_id) + portfolio_repo.update.assert_called_once_with(portfolio=updated_portfolio) + stock_repo.create.assert_called_once_with(stock) + assert result == stock_id + + def test_create_buy_new_stock(self, stock_usecase): + # Arrange + usecase, stock_repo, portfolio_repo = stock_usecase + user_id, stock_id = 1, "123" + stock = CreateStock( + user_id=user_id, + symbol="TSLA", + price=2000.0, + quantity=2, + action_type=ActionType.BUY.value, + created_at=ANY, + ) + existing_portfolio = Portfolio( + user_id=user_id, + cash_balance=5000.0, + total_money_in=5000.0, + holdings=[], + created_at=ANY, + updated_at=ANY, + ) + portfolio_repo.get.return_value = existing_portfolio + stock_repo.create.return_value = stock_id + updated_portfolio = Portfolio( + user_id=existing_portfolio.user_id, + cash_balance=1000.0, + total_money_in=5000.0, + holdings=[Holding(symbol="TSLA", shares=2, total_cost=4000.0)], + created_at=ANY, + updated_at=ANY, + ) + + # Act + result = usecase.create(stock) + + # Assert + portfolio_repo.get.assert_called_once_with(user_id) + portfolio_repo.update.assert_called_once_with(portfolio=updated_portfolio) + stock_repo.create.assert_called_once_with(stock) + assert result == stock_id + + def test_create_buy_existing_holding(self, stock_usecase): + # Arrange + usecase, stock_repo, portfolio_repo = stock_usecase + user_id, stock_id = 1, "stock_126" + stock = CreateStock( + user_id=user_id, symbol="AAPL", - price=150.25, - quantity=100, - action_type=ActionType.BUY, - created_at=datetime.now(timezone.utc), + price=150.0, + quantity=3, + action_type=ActionType.BUY.value, + created_at=ANY, + ) + existing_portfolio = Portfolio( + user_id=user_id, + cash_balance=1000.0, + total_money_in=1000.0, + holdings=[Holding(symbol="AAPL", shares=5, total_cost=750.0)], + created_at=ANY, + updated_at=ANY, + ) + portfolio_repo.get.return_value = existing_portfolio + stock_repo.create.return_value = stock_id + updated_portfolio = Portfolio( + user_id=user_id, + cash_balance=550.0, # 1000 - (150 * 3) + total_money_in=1000.0, + holdings=[Holding(symbol="AAPL", shares=8, total_cost=1200.0)], # 750 + (150 * 3) + created_at=ANY, + updated_at=ANY, ) - # Define the expected return value from the mock repository - expected_id = "stock_123" - mock_repo.create.return_value = expected_id + # Act + result = usecase.create(stock) + + # Assert + portfolio_repo.get.assert_called_once_with(user_id) + portfolio_repo.update.assert_called_once_with(portfolio=updated_portfolio) + stock_repo.create.assert_called_once_with(stock) + assert result == stock_id + + def test_create_sell_existing_holding_partial(self, stock_usecase): + # Arrange + usecase, stock_repo, portfolio_repo = stock_usecase + user_id, stock_id = 1, "stock_127" + stock = CreateStock( + user_id=user_id, + symbol="AAPL", + price=200.0, + quantity=2, + action_type=ActionType.SELL.value, + created_at=ANY, + ) + existing_portfolio = Portfolio( + user_id=user_id, + cash_balance=1000.0, + total_money_in=1000.0, + holdings=[Holding(symbol="AAPL", shares=5, total_cost=750.0)], + created_at=ANY, + updated_at=ANY, + ) + portfolio_repo.get.return_value = existing_portfolio + stock_repo.create.return_value = stock_id + updated_portfolio = Portfolio( + user_id=user_id, + cash_balance=1400.0, # 1000 + (200 * 2) + total_money_in=1000.0, + holdings=[Holding(symbol="AAPL", shares=3, total_cost=450.0)], # 750 - (150 * 2) + created_at=ANY, + updated_at=ANY, + ) # Act - result = usecase.create(mock_stock) + result = usecase.create(stock) - # Assertion - mock_repo.create.assert_called_once_with(mock_stock) - assert result == expected_id + # Assert + portfolio_repo.get.assert_called_once_with(user_id) + portfolio_repo.update.assert_called_once_with(portfolio=updated_portfolio) + stock_repo.create.assert_called_once_with(stock) + assert result == stock_id - def test_create_handles_repository_error(self, stock_usecase): + def test_create_sell_existing_holding_all_shares(self, stock_usecase): # Arrange - usecase, mock_repo = stock_usecase - mock_stock = CreateStock( - user_id=1, + usecase, stock_repo, portfolio_repo = stock_usecase + user_id, stock_id = 1, "stock_128" + stock = CreateStock( + user_id=user_id, symbol="AAPL", - price=150.25, - quantity=100, - action_type=ActionType.BUY, - created_at=datetime.now(timezone.utc), + price=300.0, + quantity=5, + action_type=ActionType.SELL.value, + created_at=ANY, + ) + existing_portfolio = Portfolio( + user_id=user_id, + cash_balance=1000.0, + total_money_in=1000.0, + holdings=[Holding(symbol="AAPL", shares=5, total_cost=750.0)], + created_at=ANY, + updated_at=ANY, + ) + portfolio_repo.get.return_value = existing_portfolio + stock_repo.create.return_value = stock_id + updated_portfolio = Portfolio( + user_id=user_id, + cash_balance=2500.0, # 1000 + (300 * 5) + total_money_in=1000.0, + holdings=[], + created_at=ANY, + updated_at=ANY, ) - # Simulate an exception from the repository - mock_repo.create.side_effect = Exception("Repository error") + # Act + result = usecase.create(stock) - # Act/Assertion - with pytest.raises(Exception, match="Repository error"): - usecase.create(mock_stock) - mock_repo.create.assert_called_once_with(mock_stock) + # Assert + portfolio_repo.get.assert_called_once_with(user_id) + portfolio_repo.update.assert_called_once_with(portfolio=updated_portfolio) + stock_repo.create.assert_called_once_with(stock) + assert result == stock_id + + def test_create_sell_non_existent_holding(self, stock_usecase): + # Arrange + usecase, stock_repo, portfolio_repo = stock_usecase + user_id, stock_id = 1, "stock_129" + stock = CreateStock( + user_id=user_id, + symbol="AAPL", + price=150.0, + quantity=5, + action_type=ActionType.SELL.value, + created_at=ANY, + ) + existing_portfolio = Portfolio( + user_id=user_id, + cash_balance=1000.0, + total_money_in=1000.0, + holdings=[], + created_at=ANY, + updated_at=ANY, + ) + portfolio_repo.get.return_value = existing_portfolio + stock_repo.create.return_value = stock_id + + # Act/Assert + with pytest.raises(Exception, match="Can not sell non-exist stock"): + usecase.create(stock) + portfolio_repo.get.assert_called_once_with(user_id) + portfolio_repo.update.assert_not_called() + stock_repo.create.assert_not_called() + + def test_create_handles_repository_error_on_get(self, stock_usecase): + # Arrange + usecase, stock_repo, portfolio_repo = stock_usecase + user_id, stock_id = 1, "stock_130" + stock = CreateStock( + user_id=user_id, + symbol="AAPL", + price=150.0, + quantity=10, + action_type=ActionType.BUY.value, + created_at=ANY, + ) + portfolio_repo.get.side_effect = Exception("Portfolio repository error") + stock_repo.create.return_value = stock_id + + # Act/Assert + with pytest.raises(Exception, match="Portfolio repository error"): + usecase.create(stock) + portfolio_repo.get.assert_called_once_with(user_id) + portfolio_repo.update.assert_not_called() + stock_repo.create.assert_not_called() + + def test_create_handles_repository_error_on_update(self, stock_usecase): + # Arrange + usecase, stock_repo, portfolio_repo = stock_usecase + user_id, stock_id = 1, "stock_131" + stock = CreateStock( + user_id=user_id, + symbol="AAPL", + price=150.0, + quantity=10, + action_type=ActionType.BUY.value, + created_at=ANY, + ) + existing_portfolio = Portfolio( + user_id=user_id, + cash_balance=2000.0, + total_money_in=2000.0, + holdings=[], + created_at=ANY, + updated_at=ANY, + ) + portfolio_repo.get.return_value = existing_portfolio + portfolio_repo.update.side_effect = Exception("Portfolio update error") + stock_repo.create.return_value = stock_id + + # Act/Assert + with pytest.raises(Exception, match="Portfolio update error"): + usecase.create(stock) + portfolio_repo.get.assert_called_once_with(user_id) + portfolio_repo.update.assert_called_once_with(portfolio=ANY) # Portfolio may vary, so use ANY + stock_repo.create.assert_not_called() + + def test_create_handles_repository_error_on_stock_create(self, stock_usecase): + # Arrange + usecase, stock_repo, portfolio_repo = stock_usecase + user_id = 1 + stock = CreateStock( + user_id=user_id, + symbol="AAPL", + price=150.0, + quantity=10, + action_type=ActionType.BUY.value, + created_at=ANY, + ) + existing_portfolio = Portfolio( + user_id=user_id, + cash_balance=2000.0, + total_money_in=2000.0, + holdings=[], + created_at=ANY, + updated_at=ANY, + ) + portfolio_repo.get.return_value = existing_portfolio + stock_repo.create.side_effect = Exception("Stock create error") + + # Act/Assert + with pytest.raises(Exception, match="Stock create error"): + usecase.create(stock) + portfolio_repo.get.assert_called_once_with(user_id) + portfolio_repo.update.assert_called_once_with(portfolio=ANY) # Portfolio may vary, so use ANY + stock_repo.create.assert_called_once_with(stock) def test_list(self, stock_usecase): # Arrange - usecase, mock_repo = stock_usecase + usecase, mock_repo, _ = stock_usecase user_id = 1 created_at = datetime.now(timezone.utc) mock_stocks = [ @@ -98,7 +402,7 @@ def test_list(self, stock_usecase): def test_list_handles_repository_error(self, stock_usecase): # Arrange - usecase, mock_repo = stock_usecase + usecase, mock_repo, _ = stock_usecase user_id = 1 mock_repo.list.side_effect = Exception("Repository error") diff --git a/src/usecase/stock.py b/src/usecase/stock.py index 83e0db4..3f1f27a 100644 --- a/src/usecase/stock.py +++ b/src/usecase/stock.py @@ -1,15 +1,64 @@ from typing import List - -from domain.stock import CreateStock, Stock -from adapters.base import AbstractStockRepository +from datetime import datetime, timezone +from domain.stock import CreateStock, Stock, ActionType +from domain.portfolio import Portfolio, Holding +from adapters.base import AbstractStockRepository, AbstractPortfolioRepository from .base import AbstractStockUsecase class StockUsecase(AbstractStockUsecase): - def __init__(self, stock_repo: AbstractStockRepository): + def __init__(self, stock_repo: AbstractStockRepository, portfolio_repo: AbstractPortfolioRepository): self.stock_repo = stock_repo + self.portfolio_repo = portfolio_repo + + def create(self, stock: CreateStock) -> str: + portfolio = self.portfolio_repo.get(stock.user_id) + if portfolio is None: + created_at = datetime.now(timezone.utc) + portfolio = Portfolio( + user_id=stock.user_id, + cash_balance=0.0, + total_money_in=0.0, + holdings=[], + created_at=created_at, + updated_at=created_at, + ) + + symbol = stock.symbol + price = stock.price + quantity = stock.quantity + action_type = ActionType(stock.action_type) + + if action_type == ActionType.TRANSFER: + portfolio.cash_balance += price * quantity + portfolio.total_money_in += price * quantity + elif action_type == ActionType.BUY: + portfolio.cash_balance -= price * quantity + + holding = next((h for h in portfolio.holdings if h.symbol == symbol), None) # Find the first holding + if not holding: + holding = Holding(symbol=symbol, shares=0, total_cost=0.0) + portfolio.holdings.append(holding) + + holding.shares += quantity + holding.total_cost += price * quantity + else: + portfolio.cash_balance += price * quantity + + holding = next((h for h in portfolio.holdings if h.symbol == symbol), None) # Find the first holding + if holding is None: + raise Exception("Can not sell non-exist stock") + + holding.shares -= quantity + if holding.shares > 0: + # Adjust total_cost proportionally (using average cost) + avg_cost = holding.total_cost / (holding.shares + quantity) + holding.total_cost -= avg_cost * quantity + else: + holding.total_cost = 0.0 + portfolio.holdings = [h for h in portfolio.holdings if h.shares > 0] - def create(self, stock: CreateStock): + self.portfolio_repo.update(portfolio=portfolio) return self.stock_repo.create(stock) def list(self, user_id: int) -> List[Stock]: diff --git a/tools/tests/run_tests.sh b/tools/tests/run_tests.sh index 8938261..048814d 100755 --- a/tools/tests/run_tests.sh +++ b/tools/tests/run_tests.sh @@ -27,7 +27,10 @@ cleanup() { print_emoji_line "<=" "${BLUE}" echo -e "<= Cleaning up resources..." print_emoji_line "<=" "${BLUE}" - docker-compose down -v mongodb-test + # Only stop MongoDB if it was started + if [ "$MONGO_STARTED" = "true" ]; then + docker-compose down -v mongodb-test + fi rm -rf __pycache__ tests/__pycache__ coverage.out .coverage # Only remove coverage.xml if it exists [ -f coverage.xml ] && rm -f coverage.xml @@ -36,28 +39,44 @@ cleanup() { # Set trap to call cleanup on exit (success or failure) trap cleanup EXIT -echo -e "${YELLOW}" -print_emoji_line "=>" "${YELLOW}" -echo -e "=> Starting test environment..." -print_emoji_line "=>" "${YELLOW}" -docker-compose up -d mongodb-test -uv run tools/tests/wait_for_mongo.py - -# Conditionally set coverage report options based on CI environment -if [ "$CI" = "true" ]; then - cov_report="--cov-report=term-missing --cov-report=xml" -else - cov_report="--cov-report=term-missing" -fi - -# Check if a test file path is provided as an argument +# Check if a test file path is provided and if it contains "_adapters" TEST_FILE=$1 +MONGO_STARTED="false" + if [ -n "$TEST_FILE" ]; then echo -e "Running tests for: $TEST_FILE" TEST_PATH="$TEST_FILE" + # Check if the test file contains "_adapters" + if [[ "$TEST_FILE" == *"_adapters"* ]]; then + echo -e "${YELLOW}" + print_emoji_line "=>" "${YELLOW}" + echo -e "=> Starting test environment with MongoDB..." + print_emoji_line "=>" "${YELLOW}" + docker-compose up -d mongodb-test + uv run tools/tests/wait_for_mongo.py + MONGO_STARTED="true" + else + echo -e "${YELLOW}" + print_emoji_line "=>" "${YELLOW}" + echo -e "=> Running tests without MongoDB..." + print_emoji_line "=>" "${YELLOW}" + fi else - echo -e "Running all tests under src/tests" + echo -e "${YELLOW}" + print_emoji_line "=>" "${YELLOW}" + echo -e "=> Starting test environment with MongoDB for all tests..." + print_emoji_line "=>" "${YELLOW}" + docker-compose up -d mongodb-test + uv run tools/tests/wait_for_mongo.py TEST_PATH="src/tests/" + MONGO_STARTED="true" +fi + +# Conditionally set coverage report options based on CI environment +if [ "$CI" = "true" ]; then + cov_report="--cov-report=term-missing --cov-report=xml" +else + cov_report="--cov-report=term-missing" fi # Run pytest with coverage