diff --git a/src/adapters/stock.py b/src/adapters/stock.py index 63f3178..5cf6551 100644 --- a/src/adapters/stock.py +++ b/src/adapters/stock.py @@ -13,7 +13,9 @@ def __init__(self, mongo_client: MongoClient, database_name: str = "stock_db"): self.collection = self.db["stocks"] def create(self, stock: CreateStock) -> str: - result = self.collection.insert_one(stock.as_dict()) + stock_dict = stock.as_dict() + stock_dict["updated_at"] = stock_dict["created_at"] + result = self.collection.insert_one(stock_dict) return str(result.inserted_id) def list(self, user_id: int) -> List[Stock]: diff --git a/src/domain/portfolio.py b/src/domain/portfolio.py index a5d835e..ac45d8b 100644 --- a/src/domain/portfolio.py +++ b/src/domain/portfolio.py @@ -50,8 +50,6 @@ class Portfolio: updated_at: datetime def __post_init__(self): - if self.cash_balance < 0: - raise ValueError("cash_balance cannot be negative") if self.total_money_in < 0: raise ValueError("total_money_in cannot be negative") diff --git a/src/handler/stock.py b/src/handler/stock.py index 135995f..058603c 100644 --- a/src/handler/stock.py +++ b/src/handler/stock.py @@ -1,10 +1,11 @@ +from typing import List as ListType import logging from datetime import datetime, timezone import grpc import proto.stock_pb2 as stock_pb2 import proto.stock_pb2_grpc as stock_pb2_grpc from usecase.base import AbstractStockUsecase -from domain.stock import CreateStock +from domain.stock import CreateStock, Stock from domain.enum import ActionType, ACTION_MAP, StockType, STOCK_MAP @@ -58,6 +59,27 @@ def List(self, request, context): context.set_details("Internal server error") raise grpc.RpcError("Internal server error") + def GetPortfolioInfo(self, request, context): + try: + user_id = request.user_id + info = self.stock_usecase.get_portfolio_info(user_id=user_id) + + return stock_pb2.GetPortfolioInfoResp( + user_id=user_id, + total_portfolio_value=info.total_portfolio_value, + total_gain=info.total_gain, + roi=info.roi, + ) + except Exception as e: + logging.error( + "Failed to get portfolio info for user_id=%s: %s", + request.user_id, + str(e), + ) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details("Internal server error") + raise grpc.RpcError("Internal server error") + def _map_action_type(self, action: int) -> ActionType: if action not in ACTION_MAP: raise ValueError(f"Invalid action type: {action}. Must be 1 (BUY), 2 (SELL), or 3 (TRANSFER).") @@ -68,7 +90,7 @@ def _map_stock_type(self, stock_type: int) -> StockType: raise ValueError(f"Invalid stock type: {stock_type}. Must be 1 (STOCKS), 2 (ETF).") return STOCK_MAP[stock_type] - def _convert_to_proto_stock_list(self, stock_list): + def _convert_to_proto_stock_list(self, stock_list: ListType[Stock]): return [ stock_pb2.Stock( id=stock.id, diff --git a/src/index.py b/src/index.py index 6b0fdc9..01dac7d 100644 --- a/src/index.py +++ b/src/index.py @@ -5,6 +5,7 @@ from dotenv import load_dotenv from handler.stock import StockService from adapters.stock import StockRepository +from adapters.portfolio import PortfolioRepository from usecase.stock import StockUsecase @@ -14,7 +15,8 @@ def serve(): client = MongoClient("mongodb://localhost:27017") stock_repo = StockRepository(client, "stock_db") - stock_usecase = StockUsecase(stock_repo) + portfolio_repo = PortfolioRepository(client, "stock_db") + stock_usecase = StockUsecase(stock_repo, portfolio_repo) server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) stock_pb2_grpc.add_StockServiceServicer_to_server(StockService(stock_usecase), server) server.add_insecure_port("[::]:50051") diff --git a/src/proto/stock.proto b/src/proto/stock.proto index 04b8437..ae761a8 100644 --- a/src/proto/stock.proto +++ b/src/proto/stock.proto @@ -56,8 +56,20 @@ message ListResp { repeated Stock stock_list = 1 [json_name = "stock_list"]; } +message GetPortfolioInfoReq { + int32 user_id = 1 [json_name = "user_id"]; +} + +message GetPortfolioInfoResp { + int32 user_id = 1 [json_name = "user_id"]; + double total_portfolio_value = 2 [json_name = "total_portfolio_value"]; + double total_gain = 3 [json_name = "total_gain"]; + double roi = 4 [json_name = "roi"]; +} + + service StockService { rpc Create (CreateReq) returns (CreateResp) {} - rpc List (ListReq) returns (ListResp) {} + rpc GetPortfolioInfo (GetPortfolioInfoReq) returns (GetPortfolioInfoResp) {} } \ No newline at end of file diff --git a/src/proto/stock_pb2.py b/src/proto/stock_pb2.py index e3ec87f..1f46b81 100644 --- a/src/proto/stock_pb2.py +++ b/src/proto/stock_pb2.py @@ -25,7 +25,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11proto/stock.proto\x12\x05stock\x1a\x1fgoogle/protobuf/timestamp.proto\"B\n\x06\x41\x63tion\"8\n\x04Type\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\x07\n\x03\x42UY\x10\x01\x12\x08\n\x04SELL\x10\x02\x12\x0c\n\x08TRANSFER\x10\x03\"9\n\tStockType\",\n\x04Type\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\n\n\x06STOCKS\x10\x01\x12\x07\n\x03\x45TF\x10\x02\"\xab\x02\n\x05Stock\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x18\n\x07user_id\x18\x02 \x01(\x05R\x07user_id\x12\x16\n\x06symbol\x18\x03 \x01(\tR\x06symbol\x12\x14\n\x05price\x18\x04 \x01(\x01R\x05price\x12\x1a\n\x08quantity\x18\x05 \x01(\x05R\x08quantity\x12\x16\n\x06\x61\x63tion\x18\x06 \x01(\tR\x06\x61\x63tion\x12\x1e\n\nstock_type\x18\x07 \x01(\tR\nstock_type\x12:\n\ncreated_at\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ncreated_at\x12:\n\nupdated_at\x18\t \x01(\x0b\x32\x1a.google.protobuf.TimestampR\nupdated_at\"\xca\x02\n\tCreateReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\x12\x16\n\x06symbol\x18\x02 \x01(\tR\x06symbol\x12\x14\n\x05price\x18\x03 \x01(\x01R\x05price\x12\x1a\n\x08quantity\x18\x04 \x01(\x05R\x08quantity\x12*\n\x06\x61\x63tion\x18\x05 \x01(\x0e\x32\x12.stock.Action.TypeR\x06\x61\x63tion\x12\x35\n\nstock_type\x18\x06 \x01(\x0e\x32\x15.stock.StockType.TypeR\nstock_type\x12:\n\ncreated_at\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ncreated_at\x12:\n\nupdated_at\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\nupdated_at\"\x1c\n\nCreateResp\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\"#\n\x07ListReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\"8\n\x08ListResp\x12,\n\nstock_list\x18\x01 \x03(\x0b\x32\x0c.stock.StockR\nstock_list2j\n\x0cStockService\x12/\n\x06\x43reate\x12\x10.stock.CreateReq\x1a\x11.stock.CreateResp\"\x00\x12)\n\x04List\x12\x0e.stock.ListReq\x1a\x0f.stock.ListResp\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11proto/stock.proto\x12\x05stock\x1a\x1fgoogle/protobuf/timestamp.proto\"B\n\x06\x41\x63tion\"8\n\x04Type\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\x07\n\x03\x42UY\x10\x01\x12\x08\n\x04SELL\x10\x02\x12\x0c\n\x08TRANSFER\x10\x03\"9\n\tStockType\",\n\x04Type\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\n\n\x06STOCKS\x10\x01\x12\x07\n\x03\x45TF\x10\x02\"\xab\x02\n\x05Stock\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x18\n\x07user_id\x18\x02 \x01(\x05R\x07user_id\x12\x16\n\x06symbol\x18\x03 \x01(\tR\x06symbol\x12\x14\n\x05price\x18\x04 \x01(\x01R\x05price\x12\x1a\n\x08quantity\x18\x05 \x01(\x05R\x08quantity\x12\x16\n\x06\x61\x63tion\x18\x06 \x01(\tR\x06\x61\x63tion\x12\x1e\n\nstock_type\x18\x07 \x01(\tR\nstock_type\x12:\n\ncreated_at\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ncreated_at\x12:\n\nupdated_at\x18\t \x01(\x0b\x32\x1a.google.protobuf.TimestampR\nupdated_at\"\xca\x02\n\tCreateReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\x12\x16\n\x06symbol\x18\x02 \x01(\tR\x06symbol\x12\x14\n\x05price\x18\x03 \x01(\x01R\x05price\x12\x1a\n\x08quantity\x18\x04 \x01(\x05R\x08quantity\x12*\n\x06\x61\x63tion\x18\x05 \x01(\x0e\x32\x12.stock.Action.TypeR\x06\x61\x63tion\x12\x35\n\nstock_type\x18\x06 \x01(\x0e\x32\x15.stock.StockType.TypeR\nstock_type\x12:\n\ncreated_at\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ncreated_at\x12:\n\nupdated_at\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\nupdated_at\"\x1c\n\nCreateResp\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\"#\n\x07ListReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\"8\n\x08ListResp\x12,\n\nstock_list\x18\x01 \x03(\x0b\x32\x0c.stock.StockR\nstock_list\"/\n\x13GetPortfolioInfoReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\"\x98\x01\n\x14GetPortfolioInfoResp\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\x12\x34\n\x15total_portfolio_value\x18\x02 \x01(\x01R\x15total_portfolio_value\x12\x1e\n\ntotal_gain\x18\x03 \x01(\x01R\ntotal_gain\x12\x10\n\x03roi\x18\x04 \x01(\x01R\x03roi2\xb9\x01\n\x0cStockService\x12/\n\x06\x43reate\x12\x10.stock.CreateReq\x1a\x11.stock.CreateResp\"\x00\x12)\n\x04List\x12\x0e.stock.ListReq\x1a\x0f.stock.ListResp\"\x00\x12M\n\x10GetPortfolioInfo\x12\x1a.stock.GetPortfolioInfoReq\x1a\x1b.stock.GetPortfolioInfoResp\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -50,6 +50,10 @@ _globals['_LISTREQ']._serialized_end=888 _globals['_LISTRESP']._serialized_start=890 _globals['_LISTRESP']._serialized_end=946 - _globals['_STOCKSERVICE']._serialized_start=948 - _globals['_STOCKSERVICE']._serialized_end=1054 + _globals['_GETPORTFOLIOINFOREQ']._serialized_start=948 + _globals['_GETPORTFOLIOINFOREQ']._serialized_end=995 + _globals['_GETPORTFOLIOINFORESP']._serialized_start=998 + _globals['_GETPORTFOLIOINFORESP']._serialized_end=1150 + _globals['_STOCKSERVICE']._serialized_start=1153 + _globals['_STOCKSERVICE']._serialized_end=1338 # @@protoc_insertion_point(module_scope) diff --git a/src/proto/stock_pb2_grpc.py b/src/proto/stock_pb2_grpc.py index c4cf380..628bb3d 100644 --- a/src/proto/stock_pb2_grpc.py +++ b/src/proto/stock_pb2_grpc.py @@ -44,6 +44,11 @@ def __init__(self, channel): request_serializer=proto_dot_stock__pb2.ListReq.SerializeToString, response_deserializer=proto_dot_stock__pb2.ListResp.FromString, _registered_method=True) + self.GetPortfolioInfo = channel.unary_unary( + '/stock.StockService/GetPortfolioInfo', + request_serializer=proto_dot_stock__pb2.GetPortfolioInfoReq.SerializeToString, + response_deserializer=proto_dot_stock__pb2.GetPortfolioInfoResp.FromString, + _registered_method=True) class StockServiceServicer(object): @@ -61,6 +66,12 @@ def List(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetPortfolioInfo(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_StockServiceServicer_to_server(servicer, server): rpc_method_handlers = { @@ -74,6 +85,11 @@ def add_StockServiceServicer_to_server(servicer, server): request_deserializer=proto_dot_stock__pb2.ListReq.FromString, response_serializer=proto_dot_stock__pb2.ListResp.SerializeToString, ), + 'GetPortfolioInfo': grpc.unary_unary_rpc_method_handler( + servicer.GetPortfolioInfo, + request_deserializer=proto_dot_stock__pb2.GetPortfolioInfoReq.FromString, + response_serializer=proto_dot_stock__pb2.GetPortfolioInfoResp.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'stock.StockService', rpc_method_handlers) @@ -138,3 +154,30 @@ def List(request, timeout, metadata, _registered_method=True) + + @staticmethod + def GetPortfolioInfo(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/stock.StockService/GetPortfolioInfo', + proto_dot_stock__pb2.GetPortfolioInfoReq.SerializeToString, + proto_dot_stock__pb2.GetPortfolioInfoResp.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/tests/test_stock_handler.py b/src/tests/test_stock_handler.py index a81ac61..404142b 100644 --- a/src/tests/test_stock_handler.py +++ b/src/tests/test_stock_handler.py @@ -6,6 +6,7 @@ from handler.stock import StockService from usecase.base import AbstractStockUsecase from domain.stock import CreateStock, Stock +from domain.portfolio import PortfolioInfo from domain.enum import ActionType, StockType @@ -258,3 +259,62 @@ def test_internal_error(self, mock_stock_usecase, mock_context, valid_request): mock_context.set_code.assert_called_once_with(grpc.StatusCode.INTERNAL) mock_context.set_details.assert_called_once_with("Internal server error") mock_stock_usecase.list.assert_called_once_with(1) + + +class TestStockServiceGetPortfolioInfo: + # Fixture to create a mock stock_usecase + @pytest.fixture + def mock_stock_usecase(self): + usecase = Mock(spec=AbstractStockUsecase) + usecase.get_portfolio_info.return_value = PortfolioInfo( + user_id=1, + total_portfolio_value=2500.0, + total_gain=500.0, + roi=25.0, + ) + return usecase + + # Fixture to create a mock gRPC context + @pytest.fixture + def mock_context(self): + context = Mock() + context.set_code = Mock() + context.set_details = Mock() + return context + + # Fixture to create a valid gRPC request + @pytest.fixture + def valid_request(self): + request = Mock() + request.user_id = 1 + return request + + def test_success(self, mock_stock_usecase, mock_context, valid_request): + # Arrange + service = StockService(mock_stock_usecase) + + # Action + response = service.GetPortfolioInfo(valid_request, mock_context) + + # Assertion + assert isinstance(response, stock_pb2.GetPortfolioInfoResp) + assert response.user_id == 1 + assert response.total_portfolio_value == 2500.0 + assert response.total_gain == 500.0 + assert response.roi == 25.0 + mock_stock_usecase.get_portfolio_info.assert_called_once_with(user_id=1) + mock_context.set_code.assert_not_called() + mock_context.set_details.assert_not_called() + + def test_internal_error(self, mock_stock_usecase, mock_context, valid_request): + # Arrange + service = StockService(mock_stock_usecase) + mock_stock_usecase.get_portfolio_info.side_effect = Exception("Database error") # Simulate internal error + + # Act/Assertion + with pytest.raises(grpc.RpcError) as exc_info: + service.GetPortfolioInfo(valid_request, mock_context) + assert str(exc_info.value) == "Internal server error" + mock_context.set_code.assert_called_once_with(grpc.StatusCode.INTERNAL) + mock_context.set_details.assert_called_once_with("Internal server error") + mock_stock_usecase.get_portfolio_info.assert_called_once_with(user_id=1) diff --git a/src/usecase/base.py b/src/usecase/base.py index 399c24d..0f4bc88 100644 --- a/src/usecase/base.py +++ b/src/usecase/base.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from domain.stock import CreateStock, Stock +from domain.portfolio import PortfolioInfo class AbstractStockUsecase(ABC): @@ -11,3 +12,6 @@ def create(self, stock: CreateStock) -> str: def list(self, user_id: int) -> List[Stock]: """List all stock by user id""" + + def get_portfolio_info(self, user_id: int) -> PortfolioInfo: + """Get portfolio info"""