Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion src/handler/stock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import proto.stock_pb2 as stock_pb2
import proto.stock_pb2_grpc as stock_pb2_grpc
from usecase.base import AbstractStockUsecase
from domain.stock import CreateStock, Stock
from domain.stock import CreateStock, Stock, StockInfo
from domain.enum import ActionType, ACTION_MAP, StockType, STOCK_MAP


Expand Down Expand Up @@ -80,6 +80,22 @@ def GetPortfolioInfo(self, request, context):
context.set_details("Internal server error")
raise grpc.RpcError("Internal server error")

def GetStockInfo(self, request, context):
try:
user_id = request.user_id
stock_info = self.stock_usecase.get_stock_info(user_id=user_id)

return self._convert_to_proto_stock_info(stock_info=stock_info)
except Exception as e:
logging.error(
"Failed to get stock info for user_id=%s: %s",
request.user_id,
str(e),
)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details("Internal server error")
raise grpc.RpcError("Internal server error")

def _map_action_type(self, action: int) -> ActionType:
if action not in ACTION_MAP:
raise ValueError(f"Invalid action type: {action}. Must be 1 (BUY), 2 (SELL), or 3 (TRANSFER).")
Expand All @@ -105,3 +121,22 @@ def _convert_to_proto_stock_list(self, stock_list: ListType[Stock]):
)
for stock in stock_list
]

def _convert_to_proto_stock_info(self, stock_info: StockInfo):
return stock_pb2.GetStockInfoResp(
stocks=self._convert_to_proto_stock_info_list(stock_info[StockType.STOCKS.value]),
etf=self._convert_to_proto_stock_info_list(stock_info[StockType.ETF.value]),
cash=self._convert_to_proto_stock_info_list(stock_info["CASH"]),
)

def _convert_to_proto_stock_info_list(self, stock_info_list: ListType[StockInfo]):
return [
stock_pb2.StockInfo(
symbol=stock_info.symbol,
quantity=stock_info.quantity,
price=stock_info.price,
avg_cost=stock_info.avg_cost,
percentage=stock_info.percentage,
)
for stock_info in stock_info_list
]
22 changes: 20 additions & 2 deletions src/proto/stock.proto
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,21 @@ message Stock {
google.protobuf.Timestamp updated_at = 9 [json_name = "updated_at"];
}

message StockInfo {
string symbol = 1 [json_name = "symbol"];
int32 quantity = 2 [json_name = "quantity"];
double price = 3 [json_name = "price"];
double avg_cost = 4 [json_name = "avg_cost"];
double percentage = 5 [json_name = "percentage"];
}

message CreateReq {
int32 user_id = 1 [json_name = "user_id"];
string symbol = 2 [json_name = "symbol"];
double price = 3 [json_name = "price"];
int32 quantity = 4 [json_name = "quantity"];
Action.Type action = 5 [json_name = "action"]; // add validation rules
StockType.Type stock_type = 6 [json_name = "stock_type"]; // add validation rules
Action.Type action = 5 [json_name = "action"];
StockType.Type stock_type = 6 [json_name = "stock_type"];
google.protobuf.Timestamp created_at = 7 [json_name = "created_at"];
google.protobuf.Timestamp updated_at = 8 [json_name = "updated_at"];
}
Expand Down Expand Up @@ -67,9 +75,19 @@ message GetPortfolioInfoResp {
double roi = 4 [json_name = "roi"];
}

message GetStockInfoReq {
int32 user_id = 1 [json_name = "user_id"];
}

message GetStockInfoResp {
repeated StockInfo stocks = 1 [json_name = "STOCKS"];
repeated StockInfo etf = 2 [json_name = "ETF"];
repeated StockInfo cash = 3 [json_name = "CASH"];
}

service StockService {
rpc Create (CreateReq) returns (CreateResp) {}
rpc List (ListReq) returns (ListResp) {}
rpc GetPortfolioInfo (GetPortfolioInfoReq) returns (GetPortfolioInfoResp) {}
rpc GetStockInfo (GetStockInfoReq) returns (GetStockInfoResp) {}
}
36 changes: 21 additions & 15 deletions src/proto/stock_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 43 additions & 0 deletions src/proto/stock_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def __init__(self, channel):
request_serializer=proto_dot_stock__pb2.GetPortfolioInfoReq.SerializeToString,
response_deserializer=proto_dot_stock__pb2.GetPortfolioInfoResp.FromString,
_registered_method=True)
self.GetStockInfo = channel.unary_unary(
'/stock.StockService/GetStockInfo',
request_serializer=proto_dot_stock__pb2.GetStockInfoReq.SerializeToString,
response_deserializer=proto_dot_stock__pb2.GetStockInfoResp.FromString,
_registered_method=True)


class StockServiceServicer(object):
Expand All @@ -72,6 +77,12 @@ def GetPortfolioInfo(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetStockInfo(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_StockServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand All @@ -90,6 +101,11 @@ def add_StockServiceServicer_to_server(servicer, server):
request_deserializer=proto_dot_stock__pb2.GetPortfolioInfoReq.FromString,
response_serializer=proto_dot_stock__pb2.GetPortfolioInfoResp.SerializeToString,
),
'GetStockInfo': grpc.unary_unary_rpc_method_handler(
servicer.GetStockInfo,
request_deserializer=proto_dot_stock__pb2.GetStockInfoReq.FromString,
response_serializer=proto_dot_stock__pb2.GetStockInfoResp.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'stock.StockService', rpc_method_handlers)
Expand Down Expand Up @@ -181,3 +197,30 @@ def GetPortfolioInfo(request,
timeout,
metadata,
_registered_method=True)

@staticmethod
def GetStockInfo(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/stock.StockService/GetStockInfo',
proto_dot_stock__pb2.GetStockInfoReq.SerializeToString,
proto_dot_stock__pb2.GetStockInfoResp.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
125 changes: 124 additions & 1 deletion src/tests/test_stock_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from unittest.mock import Mock
from handler.stock import StockService
from usecase.base import AbstractStockUsecase
from domain.stock import CreateStock, Stock
from domain.stock import CreateStock, Stock, StockInfo
from domain.portfolio import PortfolioInfo
from domain.enum import ActionType, StockType

Expand Down Expand Up @@ -318,3 +318,126 @@ def test_internal_error(self, mock_stock_usecase, mock_context, valid_request):
mock_context.set_code.assert_called_once_with(grpc.StatusCode.INTERNAL)
mock_context.set_details.assert_called_once_with("Internal server error")
mock_stock_usecase.get_portfolio_info.assert_called_once_with(user_id=1)


class TestStockServiceGetStockInfo:
# Fixture to create a mock stock_usecase
@pytest.fixture
def mock_stock_usecase(self):
usecase = Mock(spec=AbstractStockUsecase)
usecase.get_stock_info.return_value = {
StockType.STOCKS.value: [
StockInfo(
symbol="AAPL",
quantity=10,
price=100.0,
avg_cost=95.0,
percentage=5.26,
),
StockInfo(
symbol="GOOGL",
quantity=5,
price=1500.0,
avg_cost=1400.0,
percentage=7.14,
),
],
StockType.ETF.value: [
StockInfo(
symbol="SPY",
quantity=20,
price=400.0,
avg_cost=390.0,
percentage=2.56,
),
],
"CASH": [
StockInfo(
symbol="USD",
quantity=1,
price=1000.0,
avg_cost=1.0,
percentage=0.0,
),
],
}
return usecase

# Fixture to create a mock gRPC context
@pytest.fixture
def mock_context(self):
context = Mock()
context.set_code = Mock()
context.set_details = Mock()
return context

# Fixture to create a valid gRPC request
@pytest.fixture
def valid_request(self):
request = Mock()
request.user_id = 1
return request

def test_success(self, mock_stock_usecase, mock_context, valid_request):
# Arrange
service = StockService(mock_stock_usecase)

expected_result = stock_pb2.GetStockInfoResp(
stocks=[
stock_pb2.StockInfo(
symbol="AAPL",
quantity=10,
price=100.0,
avg_cost=95.0,
percentage=5.26,
),
stock_pb2.StockInfo(
symbol="GOOGL",
quantity=5,
price=1500.0,
avg_cost=1400.0,
percentage=7.14,
),
],
etf=[
stock_pb2.StockInfo(
symbol="SPY",
quantity=20,
price=400.0,
avg_cost=390.0,
percentage=2.56,
),
],
cash=[
stock_pb2.StockInfo(
symbol="USD",
quantity=1,
price=1000.0,
avg_cost=1.0,
percentage=0.0,
),
],
)

# Action
response = service.GetStockInfo(valid_request, mock_context)

# Assertion
assert isinstance(response, stock_pb2.GetStockInfoResp)
assert response == expected_result
mock_stock_usecase.get_stock_info.assert_called_once_with(user_id=1)
mock_context.set_code.assert_not_called()
mock_context.set_details.assert_not_called()

def test_internal_error(self, mock_stock_usecase, mock_context, valid_request):
# Arrange
service = StockService(mock_stock_usecase)
mock_stock_usecase.get_stock_info.side_effect = Exception("Database error") # Simulate internal error

# Act/Assertion
with pytest.raises(grpc.RpcError) as exc_info:
service.GetStockInfo(valid_request, mock_context)
assert str(exc_info.value) == "Internal server error"
mock_context.set_code.assert_called_once_with(grpc.StatusCode.INTERNAL)
mock_context.set_details.assert_called_once_with("Internal server error")
mock_stock_usecase.get_stock_info.assert_called_once_with(user_id=1)
Loading