From a3f1b474db59a768db092742f8b02ab56fac4bd0 Mon Sep 17 00:00:00 2001 From: Eyo Chen Date: Sat, 28 Jun 2025 15:39:16 +0800 Subject: [PATCH 1/4] feat: add proto --- src/proto/stock.proto | 22 +++++++++++++++++-- src/proto/stock_pb2.py | 36 ++++++++++++++++++------------- src/proto/stock_pb2_grpc.py | 43 +++++++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 17 deletions(-) diff --git a/src/proto/stock.proto b/src/proto/stock.proto index ae761a8..7ace8eb 100644 --- a/src/proto/stock.proto +++ b/src/proto/stock.proto @@ -33,13 +33,21 @@ message Stock { google.protobuf.Timestamp updated_at = 9 [json_name = "updated_at"]; } +message StockInfo { + string symbol = 1 [json_name = "symbol"]; + int32 quantity = 2 [json_name = "quantity"]; + double price = 3 [json_name = "price"]; + double avg_cost = 4 [json_name = "avg_cost"]; + double percentage = 5 [json_name = "percentage"]; +} + message CreateReq { int32 user_id = 1 [json_name = "user_id"]; string symbol = 2 [json_name = "symbol"]; double price = 3 [json_name = "price"]; int32 quantity = 4 [json_name = "quantity"]; - Action.Type action = 5 [json_name = "action"]; // add validation rules - StockType.Type stock_type = 6 [json_name = "stock_type"]; // add validation rules + Action.Type action = 5 [json_name = "action"]; + StockType.Type stock_type = 6 [json_name = "stock_type"]; google.protobuf.Timestamp created_at = 7 [json_name = "created_at"]; google.protobuf.Timestamp updated_at = 8 [json_name = "updated_at"]; } @@ -67,9 +75,19 @@ message GetPortfolioInfoResp { double roi = 4 [json_name = "roi"]; } +message GetStockInfoReq { + int32 user_id = 1 [json_name = "user_id"]; +} + +message GetStockInfoResp { + repeated StockInfo stocks = 1 [json_name = "STOCKS"]; + repeated StockInfo etf = 2 [json_name = "ETF"]; + repeated StockInfo cash = 3 [json_name = "CASH"]; +} service StockService { rpc Create (CreateReq) returns (CreateResp) {} rpc List (ListReq) returns (ListResp) {} rpc GetPortfolioInfo (GetPortfolioInfoReq) returns (GetPortfolioInfoResp) {} + rpc GetStockInfo (GetStockInfoReq) returns (GetStockInfoResp) {} } \ No newline at end of file diff --git a/src/proto/stock_pb2.py b/src/proto/stock_pb2.py index 1f46b81..197f01d 100644 --- a/src/proto/stock_pb2.py +++ b/src/proto/stock_pb2.py @@ -25,7 +25,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11proto/stock.proto\x12\x05stock\x1a\x1fgoogle/protobuf/timestamp.proto\"B\n\x06\x41\x63tion\"8\n\x04Type\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\x07\n\x03\x42UY\x10\x01\x12\x08\n\x04SELL\x10\x02\x12\x0c\n\x08TRANSFER\x10\x03\"9\n\tStockType\",\n\x04Type\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\n\n\x06STOCKS\x10\x01\x12\x07\n\x03\x45TF\x10\x02\"\xab\x02\n\x05Stock\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x18\n\x07user_id\x18\x02 \x01(\x05R\x07user_id\x12\x16\n\x06symbol\x18\x03 \x01(\tR\x06symbol\x12\x14\n\x05price\x18\x04 \x01(\x01R\x05price\x12\x1a\n\x08quantity\x18\x05 \x01(\x05R\x08quantity\x12\x16\n\x06\x61\x63tion\x18\x06 \x01(\tR\x06\x61\x63tion\x12\x1e\n\nstock_type\x18\x07 \x01(\tR\nstock_type\x12:\n\ncreated_at\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ncreated_at\x12:\n\nupdated_at\x18\t \x01(\x0b\x32\x1a.google.protobuf.TimestampR\nupdated_at\"\xca\x02\n\tCreateReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\x12\x16\n\x06symbol\x18\x02 \x01(\tR\x06symbol\x12\x14\n\x05price\x18\x03 \x01(\x01R\x05price\x12\x1a\n\x08quantity\x18\x04 \x01(\x05R\x08quantity\x12*\n\x06\x61\x63tion\x18\x05 \x01(\x0e\x32\x12.stock.Action.TypeR\x06\x61\x63tion\x12\x35\n\nstock_type\x18\x06 \x01(\x0e\x32\x15.stock.StockType.TypeR\nstock_type\x12:\n\ncreated_at\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ncreated_at\x12:\n\nupdated_at\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\nupdated_at\"\x1c\n\nCreateResp\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\"#\n\x07ListReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\"8\n\x08ListResp\x12,\n\nstock_list\x18\x01 \x03(\x0b\x32\x0c.stock.StockR\nstock_list\"/\n\x13GetPortfolioInfoReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\"\x98\x01\n\x14GetPortfolioInfoResp\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\x12\x34\n\x15total_portfolio_value\x18\x02 \x01(\x01R\x15total_portfolio_value\x12\x1e\n\ntotal_gain\x18\x03 \x01(\x01R\ntotal_gain\x12\x10\n\x03roi\x18\x04 \x01(\x01R\x03roi2\xb9\x01\n\x0cStockService\x12/\n\x06\x43reate\x12\x10.stock.CreateReq\x1a\x11.stock.CreateResp\"\x00\x12)\n\x04List\x12\x0e.stock.ListReq\x1a\x0f.stock.ListResp\"\x00\x12M\n\x10GetPortfolioInfo\x12\x1a.stock.GetPortfolioInfoReq\x1a\x1b.stock.GetPortfolioInfoResp\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11proto/stock.proto\x12\x05stock\x1a\x1fgoogle/protobuf/timestamp.proto\"B\n\x06\x41\x63tion\"8\n\x04Type\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\x07\n\x03\x42UY\x10\x01\x12\x08\n\x04SELL\x10\x02\x12\x0c\n\x08TRANSFER\x10\x03\"9\n\tStockType\",\n\x04Type\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\n\n\x06STOCKS\x10\x01\x12\x07\n\x03\x45TF\x10\x02\"\xab\x02\n\x05Stock\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x18\n\x07user_id\x18\x02 \x01(\x05R\x07user_id\x12\x16\n\x06symbol\x18\x03 \x01(\tR\x06symbol\x12\x14\n\x05price\x18\x04 \x01(\x01R\x05price\x12\x1a\n\x08quantity\x18\x05 \x01(\x05R\x08quantity\x12\x16\n\x06\x61\x63tion\x18\x06 \x01(\tR\x06\x61\x63tion\x12\x1e\n\nstock_type\x18\x07 \x01(\tR\nstock_type\x12:\n\ncreated_at\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ncreated_at\x12:\n\nupdated_at\x18\t \x01(\x0b\x32\x1a.google.protobuf.TimestampR\nupdated_at\"\x91\x01\n\tStockInfo\x12\x16\n\x06symbol\x18\x01 \x01(\tR\x06symbol\x12\x1a\n\x08quantity\x18\x02 \x01(\x05R\x08quantity\x12\x14\n\x05price\x18\x03 \x01(\x01R\x05price\x12\x1a\n\x08\x61vg_cost\x18\x04 \x01(\x01R\x08\x61vg_cost\x12\x1e\n\npercentage\x18\x05 \x01(\x01R\npercentage\"\xca\x02\n\tCreateReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\x12\x16\n\x06symbol\x18\x02 \x01(\tR\x06symbol\x12\x14\n\x05price\x18\x03 \x01(\x01R\x05price\x12\x1a\n\x08quantity\x18\x04 \x01(\x05R\x08quantity\x12*\n\x06\x61\x63tion\x18\x05 \x01(\x0e\x32\x12.stock.Action.TypeR\x06\x61\x63tion\x12\x35\n\nstock_type\x18\x06 \x01(\x0e\x32\x15.stock.StockType.TypeR\nstock_type\x12:\n\ncreated_at\x18\x07 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ncreated_at\x12:\n\nupdated_at\x18\x08 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\nupdated_at\"\x1c\n\nCreateResp\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\"#\n\x07ListReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\"8\n\x08ListResp\x12,\n\nstock_list\x18\x01 \x03(\x0b\x32\x0c.stock.StockR\nstock_list\"/\n\x13GetPortfolioInfoReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\"\x98\x01\n\x14GetPortfolioInfoResp\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\x12\x34\n\x15total_portfolio_value\x18\x02 \x01(\x01R\x15total_portfolio_value\x12\x1e\n\ntotal_gain\x18\x03 \x01(\x01R\ntotal_gain\x12\x10\n\x03roi\x18\x04 \x01(\x01R\x03roi\"+\n\x0fGetStockInfoReq\x12\x18\n\x07user_id\x18\x01 \x01(\x05R\x07user_id\"\x86\x01\n\x10GetStockInfoResp\x12(\n\x06stocks\x18\x01 \x03(\x0b\x32\x10.stock.StockInfoR\x06STOCKS\x12\"\n\x03\x65tf\x18\x02 \x03(\x0b\x32\x10.stock.StockInfoR\x03\x45TF\x12$\n\x04\x63\x61sh\x18\x03 \x03(\x0b\x32\x10.stock.StockInfoR\x04\x43\x41SH2\xfc\x01\n\x0cStockService\x12/\n\x06\x43reate\x12\x10.stock.CreateReq\x1a\x11.stock.CreateResp\"\x00\x12)\n\x04List\x12\x0e.stock.ListReq\x1a\x0f.stock.ListResp\"\x00\x12M\n\x10GetPortfolioInfo\x12\x1a.stock.GetPortfolioInfoReq\x1a\x1b.stock.GetPortfolioInfoResp\"\x00\x12\x41\n\x0cGetStockInfo\x12\x16.stock.GetStockInfoReq\x1a\x17.stock.GetStockInfoResp\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -42,18 +42,24 @@ _globals['_STOCKTYPE_TYPE']._serialized_end=186 _globals['_STOCK']._serialized_start=189 _globals['_STOCK']._serialized_end=488 - _globals['_CREATEREQ']._serialized_start=491 - _globals['_CREATEREQ']._serialized_end=821 - _globals['_CREATERESP']._serialized_start=823 - _globals['_CREATERESP']._serialized_end=851 - _globals['_LISTREQ']._serialized_start=853 - _globals['_LISTREQ']._serialized_end=888 - _globals['_LISTRESP']._serialized_start=890 - _globals['_LISTRESP']._serialized_end=946 - _globals['_GETPORTFOLIOINFOREQ']._serialized_start=948 - _globals['_GETPORTFOLIOINFOREQ']._serialized_end=995 - _globals['_GETPORTFOLIOINFORESP']._serialized_start=998 - _globals['_GETPORTFOLIOINFORESP']._serialized_end=1150 - _globals['_STOCKSERVICE']._serialized_start=1153 - _globals['_STOCKSERVICE']._serialized_end=1338 + _globals['_STOCKINFO']._serialized_start=491 + _globals['_STOCKINFO']._serialized_end=636 + _globals['_CREATEREQ']._serialized_start=639 + _globals['_CREATEREQ']._serialized_end=969 + _globals['_CREATERESP']._serialized_start=971 + _globals['_CREATERESP']._serialized_end=999 + _globals['_LISTREQ']._serialized_start=1001 + _globals['_LISTREQ']._serialized_end=1036 + _globals['_LISTRESP']._serialized_start=1038 + _globals['_LISTRESP']._serialized_end=1094 + _globals['_GETPORTFOLIOINFOREQ']._serialized_start=1096 + _globals['_GETPORTFOLIOINFOREQ']._serialized_end=1143 + _globals['_GETPORTFOLIOINFORESP']._serialized_start=1146 + _globals['_GETPORTFOLIOINFORESP']._serialized_end=1298 + _globals['_GETSTOCKINFOREQ']._serialized_start=1300 + _globals['_GETSTOCKINFOREQ']._serialized_end=1343 + _globals['_GETSTOCKINFORESP']._serialized_start=1346 + _globals['_GETSTOCKINFORESP']._serialized_end=1480 + _globals['_STOCKSERVICE']._serialized_start=1483 + _globals['_STOCKSERVICE']._serialized_end=1735 # @@protoc_insertion_point(module_scope) diff --git a/src/proto/stock_pb2_grpc.py b/src/proto/stock_pb2_grpc.py index 628bb3d..c136e7a 100644 --- a/src/proto/stock_pb2_grpc.py +++ b/src/proto/stock_pb2_grpc.py @@ -49,6 +49,11 @@ def __init__(self, channel): request_serializer=proto_dot_stock__pb2.GetPortfolioInfoReq.SerializeToString, response_deserializer=proto_dot_stock__pb2.GetPortfolioInfoResp.FromString, _registered_method=True) + self.GetStockInfo = channel.unary_unary( + '/stock.StockService/GetStockInfo', + request_serializer=proto_dot_stock__pb2.GetStockInfoReq.SerializeToString, + response_deserializer=proto_dot_stock__pb2.GetStockInfoResp.FromString, + _registered_method=True) class StockServiceServicer(object): @@ -72,6 +77,12 @@ def GetPortfolioInfo(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetStockInfo(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_StockServiceServicer_to_server(servicer, server): rpc_method_handlers = { @@ -90,6 +101,11 @@ def add_StockServiceServicer_to_server(servicer, server): request_deserializer=proto_dot_stock__pb2.GetPortfolioInfoReq.FromString, response_serializer=proto_dot_stock__pb2.GetPortfolioInfoResp.SerializeToString, ), + 'GetStockInfo': grpc.unary_unary_rpc_method_handler( + servicer.GetStockInfo, + request_deserializer=proto_dot_stock__pb2.GetStockInfoReq.FromString, + response_serializer=proto_dot_stock__pb2.GetStockInfoResp.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'stock.StockService', rpc_method_handlers) @@ -181,3 +197,30 @@ def GetPortfolioInfo(request, timeout, metadata, _registered_method=True) + + @staticmethod + def GetStockInfo(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/stock.StockService/GetStockInfo', + proto_dot_stock__pb2.GetStockInfoReq.SerializeToString, + proto_dot_stock__pb2.GetStockInfoResp.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) From 4ba42a40ffa82716bdd9564698bd34b479961827 Mon Sep 17 00:00:00 2001 From: Eyo Chen Date: Sat, 28 Jun 2025 15:40:20 +0800 Subject: [PATCH 2/4] refactor: renaming --- src/tests/test_stock_usecase.py | 18 +++++++++--------- src/usecase/base.py | 8 +++++--- src/usecase/stock.py | 2 +- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/tests/test_stock_usecase.py b/src/tests/test_stock_usecase.py index 683807d..25b9659 100644 --- a/src/tests/test_stock_usecase.py +++ b/src/tests/test_stock_usecase.py @@ -662,9 +662,9 @@ def test_get_portfolio_info_with_valid_holdings(self, mock_get_stock_price, stoc assert result == expected_result -class TestStockUsecaseGetStockInfoList: +class TestStockUsecaseGetStockInfo: @patch.object(StockUsecase, "_get_stock_price") - def test_get_stock_info_list_no_portfolio(self, mock_get_stock_price, stock_usecase): + def test_get_stock_info_no_portfolio(self, mock_get_stock_price, stock_usecase): # Arrange usecase, _, portfolio_repo = stock_usecase user_id = 1 @@ -672,7 +672,7 @@ def test_get_stock_info_list_no_portfolio(self, mock_get_stock_price, stock_usec expected_result = {StockType.ETF.value: [], StockType.STOCKS.value: [], "CASH": []} # Act - result = usecase.get_stock_info_list(user_id) + result = usecase.get_stock_info(user_id) # Assert portfolio_repo.get.assert_called_once_with(user_id=user_id) @@ -680,7 +680,7 @@ def test_get_stock_info_list_no_portfolio(self, mock_get_stock_price, stock_usec assert result == expected_result @patch.object(StockUsecase, "_get_stock_price") - def test_get_stock_info_list_empty_portfolio(self, mock_get_stock_price, stock_usecase): + def test_get_stock_info_empty_portfolio(self, mock_get_stock_price, stock_usecase): # Arrange usecase, _, portfolio_repo = stock_usecase user_id = 1 @@ -696,7 +696,7 @@ def test_get_stock_info_list_empty_portfolio(self, mock_get_stock_price, stock_u expected_result = {StockType.ETF.value: [], StockType.STOCKS.value: [], "CASH": []} # Act - result = usecase.get_stock_info_list(user_id) + result = usecase.get_stock_info(user_id) # Assert portfolio_repo.get.assert_called_once_with(user_id=user_id) @@ -704,7 +704,7 @@ def test_get_stock_info_list_empty_portfolio(self, mock_get_stock_price, stock_u assert result == expected_result @patch.object(StockUsecase, "_get_stock_price") - def test_get_stock_info_list_no_valid_holdings(self, mock_get_stock_price, stock_usecase): + def test_get_stock_info_no_valid_holdings(self, mock_get_stock_price, stock_usecase): # Arrange usecase, _, portfolio_repo = stock_usecase user_id = 1 @@ -720,7 +720,7 @@ def test_get_stock_info_list_no_valid_holdings(self, mock_get_stock_price, stock expected_result = {StockType.ETF.value: [], StockType.STOCKS.value: [], "CASH": []} # Act - result = usecase.get_stock_info_list(user_id) + result = usecase.get_stock_info(user_id) # Assert portfolio_repo.get.assert_called_once_with(user_id=user_id) @@ -728,7 +728,7 @@ def test_get_stock_info_list_no_valid_holdings(self, mock_get_stock_price, stock assert result == expected_result @patch.object(StockUsecase, "_get_stock_price") - def test_get_stock_info_list_with_valid_holdings(self, mock_get_stock_price, stock_usecase): + def test_get_stock_info_with_valid_holdings(self, mock_get_stock_price, stock_usecase): # Arrange usecase, _, portfolio_repo = stock_usecase user_id = 1 @@ -793,7 +793,7 @@ def test_get_stock_info_list_with_valid_holdings(self, mock_get_stock_price, sto } # Act - result = usecase.get_stock_info_list(user_id) + result = usecase.get_stock_info(user_id) # Assert portfolio_repo.get.assert_called_once_with(user_id=user_id) diff --git a/src/usecase/base.py b/src/usecase/base.py index 0f4bc88..6aae35b 100644 --- a/src/usecase/base.py +++ b/src/usecase/base.py @@ -1,7 +1,6 @@ -from typing import List +from typing import List, Dict from abc import ABC, abstractmethod - -from domain.stock import CreateStock, Stock +from domain.stock import CreateStock, Stock, StockInfo from domain.portfolio import PortfolioInfo @@ -15,3 +14,6 @@ def list(self, user_id: int) -> List[Stock]: def get_portfolio_info(self, user_id: int) -> PortfolioInfo: """Get portfolio info""" + + def get_stock_info(self, user_id: int) -> Dict[str, List[StockInfo]]: + """Get stock info by user id""" diff --git a/src/usecase/stock.py b/src/usecase/stock.py index d988b7c..43085fd 100644 --- a/src/usecase/stock.py +++ b/src/usecase/stock.py @@ -105,7 +105,7 @@ def get_portfolio_info(self, user_id: int) -> PortfolioInfo: roi=roi, ) - def get_stock_info_list(self, user_id: int) -> Dict[str, List[StockInfo]]: + def get_stock_info(self, user_id: int) -> Dict[str, List[StockInfo]]: result = {StockType.ETF.value: [], StockType.STOCKS.value: [], "CASH": []} portfolio = self.portfolio_repo.get(user_id=user_id) if portfolio is None or portfolio.total_money_in == 0.0: From a8aa1d22f039e69ec77cc88b5604784af7a6a7c8 Mon Sep 17 00:00:00 2001 From: Eyo Chen Date: Sat, 28 Jun 2025 15:42:22 +0800 Subject: [PATCH 3/4] feat: add handler --- src/handler/stock.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/handler/stock.py b/src/handler/stock.py index 058603c..50fa6b1 100644 --- a/src/handler/stock.py +++ b/src/handler/stock.py @@ -5,7 +5,7 @@ import proto.stock_pb2 as stock_pb2 import proto.stock_pb2_grpc as stock_pb2_grpc from usecase.base import AbstractStockUsecase -from domain.stock import CreateStock, Stock +from domain.stock import CreateStock, Stock, StockInfo from domain.enum import ActionType, ACTION_MAP, StockType, STOCK_MAP @@ -80,6 +80,22 @@ def GetPortfolioInfo(self, request, context): context.set_details("Internal server error") raise grpc.RpcError("Internal server error") + def GetStockInfo(self, request, context): + try: + user_id = request.user_id + stock_info = self.stock_usecase.get_stock_info(user_id=user_id) + + return self._convert_to_proto_stock_info(stock_info=stock_info) + except Exception as e: + logging.error( + "Failed to get stock info for user_id=%s: %s", + request.user_id, + str(e), + ) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details("Internal server error") + raise grpc.RpcError("Internal server error") + def _map_action_type(self, action: int) -> ActionType: if action not in ACTION_MAP: raise ValueError(f"Invalid action type: {action}. Must be 1 (BUY), 2 (SELL), or 3 (TRANSFER).") @@ -105,3 +121,22 @@ def _convert_to_proto_stock_list(self, stock_list: ListType[Stock]): ) for stock in stock_list ] + + def _convert_to_proto_stock_info(self, stock_info: StockInfo): + return stock_pb2.GetStockInfoResp( + stocks=self._convert_to_proto_stock_info_list(stock_info[StockType.STOCKS.value]), + etf=self._convert_to_proto_stock_info_list(stock_info[StockType.ETF.value]), + cash=self._convert_to_proto_stock_info_list(stock_info["CASH"]), + ) + + def _convert_to_proto_stock_info_list(self, stock_info_list: ListType[StockInfo]): + return [ + stock_pb2.StockInfo( + symbol=stock_info.symbol, + quantity=stock_info.quantity, + price=stock_info.price, + avg_cost=stock_info.avg_cost, + percentage=stock_info.percentage, + ) + for stock_info in stock_info_list + ] From 35c47b55a9affd2aa012582b12efa54448d5992e Mon Sep 17 00:00:00 2001 From: Eyo Chen Date: Sat, 28 Jun 2025 15:42:49 +0800 Subject: [PATCH 4/4] test: add unit testing --- src/tests/test_stock_handler.py | 125 +++++++++++++++++++++++++++++++- 1 file changed, 124 insertions(+), 1 deletion(-) diff --git a/src/tests/test_stock_handler.py b/src/tests/test_stock_handler.py index 404142b..9cc08a6 100644 --- a/src/tests/test_stock_handler.py +++ b/src/tests/test_stock_handler.py @@ -5,7 +5,7 @@ from unittest.mock import Mock from handler.stock import StockService from usecase.base import AbstractStockUsecase -from domain.stock import CreateStock, Stock +from domain.stock import CreateStock, Stock, StockInfo from domain.portfolio import PortfolioInfo from domain.enum import ActionType, StockType @@ -318,3 +318,126 @@ def test_internal_error(self, mock_stock_usecase, mock_context, valid_request): mock_context.set_code.assert_called_once_with(grpc.StatusCode.INTERNAL) mock_context.set_details.assert_called_once_with("Internal server error") mock_stock_usecase.get_portfolio_info.assert_called_once_with(user_id=1) + + +class TestStockServiceGetStockInfo: + # Fixture to create a mock stock_usecase + @pytest.fixture + def mock_stock_usecase(self): + usecase = Mock(spec=AbstractStockUsecase) + usecase.get_stock_info.return_value = { + StockType.STOCKS.value: [ + StockInfo( + symbol="AAPL", + quantity=10, + price=100.0, + avg_cost=95.0, + percentage=5.26, + ), + StockInfo( + symbol="GOOGL", + quantity=5, + price=1500.0, + avg_cost=1400.0, + percentage=7.14, + ), + ], + StockType.ETF.value: [ + StockInfo( + symbol="SPY", + quantity=20, + price=400.0, + avg_cost=390.0, + percentage=2.56, + ), + ], + "CASH": [ + StockInfo( + symbol="USD", + quantity=1, + price=1000.0, + avg_cost=1.0, + percentage=0.0, + ), + ], + } + return usecase + + # Fixture to create a mock gRPC context + @pytest.fixture + def mock_context(self): + context = Mock() + context.set_code = Mock() + context.set_details = Mock() + return context + + # Fixture to create a valid gRPC request + @pytest.fixture + def valid_request(self): + request = Mock() + request.user_id = 1 + return request + + def test_success(self, mock_stock_usecase, mock_context, valid_request): + # Arrange + service = StockService(mock_stock_usecase) + + expected_result = stock_pb2.GetStockInfoResp( + stocks=[ + stock_pb2.StockInfo( + symbol="AAPL", + quantity=10, + price=100.0, + avg_cost=95.0, + percentage=5.26, + ), + stock_pb2.StockInfo( + symbol="GOOGL", + quantity=5, + price=1500.0, + avg_cost=1400.0, + percentage=7.14, + ), + ], + etf=[ + stock_pb2.StockInfo( + symbol="SPY", + quantity=20, + price=400.0, + avg_cost=390.0, + percentage=2.56, + ), + ], + cash=[ + stock_pb2.StockInfo( + symbol="USD", + quantity=1, + price=1000.0, + avg_cost=1.0, + percentage=0.0, + ), + ], + ) + + # Action + response = service.GetStockInfo(valid_request, mock_context) + + # Assertion + assert isinstance(response, stock_pb2.GetStockInfoResp) + assert response == expected_result + mock_stock_usecase.get_stock_info.assert_called_once_with(user_id=1) + mock_context.set_code.assert_not_called() + mock_context.set_details.assert_not_called() + + def test_internal_error(self, mock_stock_usecase, mock_context, valid_request): + # Arrange + service = StockService(mock_stock_usecase) + mock_stock_usecase.get_stock_info.side_effect = Exception("Database error") # Simulate internal error + + # Act/Assertion + with pytest.raises(grpc.RpcError) as exc_info: + service.GetStockInfo(valid_request, mock_context) + assert str(exc_info.value) == "Internal server error" + mock_context.set_code.assert_called_once_with(grpc.StatusCode.INTERNAL) + mock_context.set_details.assert_called_once_with("Internal server error") + mock_stock_usecase.get_stock_info.assert_called_once_with(user_id=1)