From becc888659bb6b032313b403599a0ac9dc5b1c15 Mon Sep 17 00:00:00 2001 From: smallfish06 Date: Sun, 1 Mar 2026 22:02:51 +0900 Subject: [PATCH] Refactor account resolution and broker selection logic - Introduced `resolveBrokerByAccountID` and `normalizeAccountIDAlias` for unified broker resolution. - Enhanced account matching with suffix handling and deterministic sorting. - Replaced `getBrokerStrict` with `resolveBrokerByAccountID` throughout server handlers. - Added new test cases for ambiguous account ID handling, broker selection, and sandbox preference. - Updated OpenAPI documentation to include optional `account_id` query parameter for relevant endpoints. --- internal/server/account_resolution.go | 185 ++++++++++++++++++ internal/server/handler_accounts.go | 16 +- internal/server/handler_accounts_test.go | 28 +++ internal/server/handler_auth.go | 6 +- internal/server/handler_auth_test.go | 112 +++++++++++ internal/server/handler_instruments.go | 10 +- internal/server/handler_kis_proxy.go | 6 +- internal/server/handler_kiwoom_proxy.go | 6 +- internal/server/handler_multiaccounts.go | 14 +- internal/server/handler_multiaccounts_test.go | 39 ++++ internal/server/handler_orders.go | 60 ++---- internal/server/handler_orders_test.go | 30 +++ internal/server/handler_quotes.go | 36 +++- internal/server/handler_quotes_test.go | 81 ++++++++ internal/server/server.go | 32 +-- 15 files changed, 577 insertions(+), 84 deletions(-) create mode 100644 internal/server/account_resolution.go create mode 100644 internal/server/handler_auth_test.go create mode 100644 internal/server/handler_multiaccounts_test.go diff --git a/internal/server/account_resolution.go b/internal/server/account_resolution.go new file mode 100644 index 0000000..2f31a64 --- /dev/null +++ b/internal/server/account_resolution.go @@ -0,0 +1,185 @@ +package server + +import ( + "net/http" + "sort" + "strings" + + "github.com/smallfish06/krsec/pkg/broker" +) + +const ambiguousAccountIDError = "account_id is ambiguous; use full account_id" + +func normalizeBrokerCode(name string) string { + switch strings.ToLower(strings.TrimSpace(name)) { + case broker.CodeKIS, strings.ToLower(broker.NameKIS): + return broker.CodeKIS + case broker.CodeKiwoom, strings.ToLower(broker.NameKiwoom): + return broker.CodeKiwoom + default: + return "" + } +} + +func normalizeAccountIDAlias(accountID string) string { + accountID = strings.TrimSpace(accountID) + if strings.HasSuffix(accountID, "-01") { + return strings.TrimSuffix(accountID, "-01") + } + return accountID +} + +func isDigits(s string) bool { + if s == "" { + return false + } + for _, r := range s { + if r < '0' || r > '9' { + return false + } + } + return true +} + +func kisAccountBase(accountID string) (string, bool) { + accountID = strings.TrimSpace(accountID) + switch { + case len(accountID) == 8 && isDigits(accountID): + return accountID, true + case len(accountID) == 11 && accountID[8] == '-' && isDigits(accountID[:8]) && isDigits(accountID[9:]): + return accountID[:8], true + default: + return "", false + } +} + +func sameAccountID(a, b string) bool { + a = strings.TrimSpace(a) + b = strings.TrimSpace(b) + if a == "" || b == "" { + return false + } + if a == b { + return true + } + if baseA, okA := kisAccountBase(a); okA { + if baseB, okB := kisAccountBase(b); okB { + return baseA == baseB + } + } + return normalizeAccountIDAlias(a) == normalizeAccountIDAlias(b) +} + +func (s *Server) resolveBrokerByAccountID(accountID string) (broker.Broker, int, string) { + accountID = strings.TrimSpace(accountID) + if accountID == "" { + return nil, http.StatusBadRequest, "account_id is required" + } + + if brk, ok := s.brokers[accountID]; ok { + return brk, 0, "" + } + + candidates := s.findBrokerAccountCandidates(accountID) + switch len(candidates) { + case 0: + return nil, http.StatusNotFound, "account not found" + case 1: + return s.brokers[candidates[0]], 0, "" + default: + return nil, http.StatusBadRequest, ambiguousAccountIDError + } +} + +func (s *Server) getBrokerStrict(accountID string) (broker.Broker, bool) { + brk, status, _ := s.resolveBrokerByAccountID(accountID) + return brk, status == 0 +} + +func (s *Server) findBrokerAccountCandidates(accountID string) []string { + matches := make([]string, 0, 2) + seen := make(map[string]struct{}, 2) + + // Prefer configured account order for deterministic matching. + for _, acc := range s.accounts { + id := strings.TrimSpace(acc.AccountID) + if id == "" { + continue + } + if _, ok := s.brokers[id]; !ok { + continue + } + if !sameAccountID(id, accountID) { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + matches = append(matches, id) + } + + extra := make([]string, 0, len(s.brokers)) + for id := range s.brokers { + if _, ok := seen[id]; ok { + continue + } + if sameAccountID(id, accountID) { + extra = append(extra, id) + } + } + sort.Strings(extra) + matches = append(matches, extra...) + + return matches +} + +func (s *Server) resolveAuthBroker(requestedBroker string, sandbox bool) (broker.Broker, int, string) { + brokerCode := normalizeBrokerCode(requestedBroker) + if strings.TrimSpace(requestedBroker) != "" && brokerCode == "" { + return nil, http.StatusBadRequest, "unsupported broker" + } + + if brokerCode == "" { + brk := s.getFirstBroker() + if brk == nil { + return nil, http.StatusServiceUnavailable, "no broker available" + } + return brk, 0, "" + } + + var fallback broker.Broker + for _, acc := range s.accounts { + if normalizeBrokerCode(acc.Broker) != brokerCode { + continue + } + brk, status, _ := s.resolveBrokerByAccountID(acc.AccountID) + if status != 0 { + continue + } + if acc.Sandbox == sandbox { + return brk, 0, "" + } + if fallback == nil { + fallback = brk + } + } + + if fallback != nil { + return fallback, 0, "" + } + + ids := make([]string, 0, len(s.brokers)) + for id := range s.brokers { + ids = append(ids, id) + } + sort.Strings(ids) + for _, id := range ids { + brk := s.brokers[id] + if normalizeBrokerCode(brk.Name()) == brokerCode { + return brk, 0, "" + } + } + + return nil, http.StatusServiceUnavailable, "no " + strings.ToUpper(brokerCode) + " account available" +} diff --git a/internal/server/handler_accounts.go b/internal/server/handler_accounts.go index 7a40e85..be8e422 100644 --- a/internal/server/handler_accounts.go +++ b/internal/server/handler_accounts.go @@ -12,11 +12,11 @@ import ( func (s *Server) handleGetBalance(c fuego.ContextNoBody) (Response, error) { accountID := c.PathParam("account_id") - brk, ok := s.getBrokerStrict(accountID) - if !ok { - return respond(c, http.StatusNotFound, Response{ + brk, status, reason := s.resolveBrokerByAccountID(accountID) + if brk == nil { + return respond(c, status, Response{ OK: false, - Error: "account not found", + Error: reason, }) } @@ -39,11 +39,11 @@ func (s *Server) handleGetBalance(c fuego.ContextNoBody) (Response, error) { func (s *Server) handleGetPositions(c fuego.ContextNoBody) (Response, error) { accountID := c.PathParam("account_id") - brk, ok := s.getBrokerStrict(accountID) - if !ok { - return respond(c, http.StatusNotFound, Response{ + brk, status, reason := s.resolveBrokerByAccountID(accountID) + if brk == nil { + return respond(c, status, Response{ OK: false, - Error: "account not found", + Error: reason, }) } diff --git a/internal/server/handler_accounts_test.go b/internal/server/handler_accounts_test.go index 2b16940..2a76253 100644 --- a/internal/server/handler_accounts_test.go +++ b/internal/server/handler_accounts_test.go @@ -58,3 +58,31 @@ func TestHandleGetPositions_UnknownAccountReturnsNotFound(t *testing.T) { t.Fatalf("unexpected error: %s", resp.Error) } } + +func TestHandleGetBalance_AmbiguousAccountReturnsBadRequest(t *testing.T) { + t.Parallel() + + first := newMockBroker(t, "KIS-1") + second := newMockBroker(t, "KIS-2") + s := newOrderTestServer( + map[string]broker.Broker{ + "12345678-01": first, + "12345678-02": second, + }, + []config.AccountConfig{ + {AccountID: "12345678-01"}, + {AccountID: "12345678-02"}, + }, + ) + + req := httptest.NewRequest(http.MethodGet, "/accounts/12345678/balance", nil) + rr := performFiberRequest(t, s, req) + + if rr.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String()) + } + resp := decodeResponse(t, rr) + if resp.Error != ambiguousAccountIDError { + t.Fatalf("unexpected error: %s", resp.Error) + } +} diff --git a/internal/server/handler_auth.go b/internal/server/handler_auth.go index 2a25776..ff08df5 100644 --- a/internal/server/handler_auth.go +++ b/internal/server/handler_auth.go @@ -25,11 +25,11 @@ func (s *Server) handleAuthToken(c fuego.ContextWithBody[AuthTokenRequest]) (Res }) } - brk := s.getFirstBroker() + brk, status, reason := s.resolveAuthBroker(req.Broker, req.Sandbox) if brk == nil { - return respond(c, http.StatusServiceUnavailable, Response{ + return respond(c, status, Response{ OK: false, - Error: "no broker available", + Error: reason, }) } diff --git a/internal/server/handler_auth_test.go b/internal/server/handler_auth_test.go new file mode 100644 index 0000000..8432d49 --- /dev/null +++ b/internal/server/handler_auth_test.go @@ -0,0 +1,112 @@ +package server + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/mock" + + "github.com/smallfish06/krsec/pkg/broker" + "github.com/smallfish06/krsec/pkg/config" +) + +func TestHandleAuthToken_RespectsRequestedBroker(t *testing.T) { + t.Parallel() + + kis := newMockBroker(t, "KIS") + kiwoom := newMockBroker(t, "KIWOOM") + kis.On("Authenticate", mock.Anything, broker.Credentials{ + AppKey: "k", + AppSecret: "s", + }).Return(&broker.Token{ + AccessToken: "kis-token", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(time.Hour), + }, nil).Once() + + s := newOrderTestServer( + map[string]broker.Broker{ + "kiwoom-acc": kiwoom, + "kis-acc": kis, + }, + []config.AccountConfig{ + {AccountID: "kiwoom-acc", Broker: "kiwoom", Sandbox: true}, + {AccountID: "kis-acc", Broker: "kis", Sandbox: true}, + }, + ) + + body := []byte(`{"broker":"kis","credentials":{"app_key":"k","app_secret":"s"},"sandbox":true}`) + req := httptest.NewRequest(http.MethodPost, "/auth/token", bytes.NewReader(body)) + rr := performFiberRequest(t, s, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String()) + } + resp := decodeResponse(t, rr) + if !resp.OK { + t.Fatalf("expected ok=true") + } + if resp.Broker != "KIS" { + t.Fatalf("broker = %q, want KIS", resp.Broker) + } + kiwoom.AssertNotCalled(t, "Authenticate", mock.Anything, mock.Anything) +} + +func TestHandleAuthToken_RejectsUnsupportedBroker(t *testing.T) { + t.Parallel() + + kis := newMockBroker(t, "KIS") + s := newOrderTestServer( + map[string]broker.Broker{"kis-acc": kis}, + []config.AccountConfig{{AccountID: "kis-acc", Broker: "kis"}}, + ) + + body := []byte(`{"broker":"future","credentials":{"app_key":"k","app_secret":"s"}}`) + req := httptest.NewRequest(http.MethodPost, "/auth/token", bytes.NewReader(body)) + rr := performFiberRequest(t, s, req) + + if rr.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String()) + } + resp := decodeResponse(t, rr) + if resp.Error != "unsupported broker" { + t.Fatalf("unexpected error: %s", resp.Error) + } + kis.AssertNotCalled(t, "Authenticate", mock.Anything, mock.Anything) +} + +func TestHandleAuthToken_SelectsSandboxMatchFirst(t *testing.T) { + t.Parallel() + + sandboxBroker := newMockBroker(t, "KIS") + prodBroker := newMockBroker(t, "KIS") + + sandboxBroker.On("Authenticate", mock.Anything, mock.Anything).Return(&broker.Token{ + AccessToken: "sandbox-token", + TokenType: "Bearer", + ExpiresAt: time.Now().Add(time.Hour), + }, nil).Once() + + s := newOrderTestServer( + map[string]broker.Broker{ + "kis-prod": prodBroker, + "kis-sandbox": sandboxBroker, + }, + []config.AccountConfig{ + {AccountID: "kis-prod", Broker: "kis", Sandbox: false}, + {AccountID: "kis-sandbox", Broker: "kis", Sandbox: true}, + }, + ) + + body := []byte(`{"broker":"kis","credentials":{"app_key":"k","app_secret":"s"},"sandbox":true}`) + req := httptest.NewRequest(http.MethodPost, "/auth/token", bytes.NewReader(body)) + rr := performFiberRequest(t, s, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String()) + } + prodBroker.AssertNotCalled(t, "Authenticate", mock.Anything, mock.Anything) +} diff --git a/internal/server/handler_instruments.go b/internal/server/handler_instruments.go index 4961542..9fd4637 100644 --- a/internal/server/handler_instruments.go +++ b/internal/server/handler_instruments.go @@ -20,13 +20,17 @@ func (s *Server) handleGetInstrument(c fuego.ContextNoBody) (Response, error) { symbol := c.PathParam("symbol") accountID := c.QueryParam("account_id") + var candidates []broker.Broker if accountID != "" { - if _, ok := s.getBrokerStrict(accountID); !ok { - return respond(c, http.StatusNotFound, Response{OK: false, Error: "account not found"}) + brk, status, reason := s.resolveBrokerByAccountID(accountID) + if brk == nil { + return respond(c, status, Response{OK: false, Error: reason}) } + candidates = []broker.Broker{brk} + } else { + candidates = s.orderBrokerCandidates("") } - candidates := s.orderBrokerCandidates(accountID) if len(candidates) == 0 { return respond(c, http.StatusServiceUnavailable, Response{OK: false, Error: "no broker available"}) } diff --git a/internal/server/handler_kis_proxy.go b/internal/server/handler_kis_proxy.go index 8129f1b..bdeb3fe 100644 --- a/internal/server/handler_kis_proxy.go +++ b/internal/server/handler_kis_proxy.go @@ -116,9 +116,9 @@ func (s *Server) handleKISProxyPath(c fuego.ContextWithBody[kisProxyRequest], ra func (s *Server) resolveKISProxyBroker(accountID string) (broker.Broker, int, string) { accountID = strings.TrimSpace(accountID) if accountID != "" { - brk, ok := s.getBrokerStrict(accountID) - if !ok { - return nil, http.StatusNotFound, "account not found" + brk, status, reason := s.resolveBrokerByAccountID(accountID) + if brk == nil { + return nil, status, reason } if !strings.EqualFold(strings.TrimSpace(brk.Name()), broker.NameKIS) { return nil, http.StatusBadRequest, "account broker is not KIS" diff --git a/internal/server/handler_kiwoom_proxy.go b/internal/server/handler_kiwoom_proxy.go index 13e56e4..3e65dd8 100644 --- a/internal/server/handler_kiwoom_proxy.go +++ b/internal/server/handler_kiwoom_proxy.go @@ -127,9 +127,9 @@ func (s *Server) handleKiwoomProxyStatic(path, apiID string) func(fuego.ContextW func (s *Server) resolveKiwoomProxyBroker(accountID string) (broker.Broker, int, string) { accountID = strings.TrimSpace(accountID) if accountID != "" { - brk, ok := s.getBrokerStrict(accountID) - if !ok { - return nil, http.StatusNotFound, "account not found" + brk, status, reason := s.resolveBrokerByAccountID(accountID) + if brk == nil { + return nil, status, reason } if !strings.EqualFold(strings.TrimSpace(brk.Name()), broker.NameKiwoom) { return nil, http.StatusBadRequest, "account broker is not Kiwoom" diff --git a/internal/server/handler_multiaccounts.go b/internal/server/handler_multiaccounts.go index 1c90b84..dd59853 100644 --- a/internal/server/handler_multiaccounts.go +++ b/internal/server/handler_multiaccounts.go @@ -31,16 +31,19 @@ func (s *Server) handleAccountsSummary(c fuego.ContextNoBody) (Response, error) var totalAssets, totalCash, totalProfitLoss float64 balances := make([]broker.Balance, 0, len(s.accounts)) + failed := 0 for _, account := range s.accounts { - brk := s.getBroker(account.AccountID) - if brk == nil { + brk, status, _ := s.resolveBrokerByAccountID(account.AccountID) + if status != 0 || brk == nil { + failed++ continue } balance, err := brk.GetBalance(ctx, account.AccountID) if err != nil { // 에러가 발생해도 계속 진행 + failed++ continue } @@ -50,6 +53,13 @@ func (s *Server) handleAccountsSummary(c fuego.ContextNoBody) (Response, error) totalProfitLoss += balance.ProfitLoss } + if len(s.accounts) > 0 && len(balances) == 0 && failed > 0 { + return respond(c, http.StatusServiceUnavailable, Response{ + OK: false, + Error: "failed to retrieve balances from all accounts", + }) + } + summary := broker.AccountSummary{ TotalAssets: totalAssets, TotalCash: totalCash, diff --git a/internal/server/handler_multiaccounts_test.go b/internal/server/handler_multiaccounts_test.go new file mode 100644 index 0000000..f1e71ba --- /dev/null +++ b/internal/server/handler_multiaccounts_test.go @@ -0,0 +1,39 @@ +package server + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/mock" + + "github.com/smallfish06/krsec/pkg/broker" + "github.com/smallfish06/krsec/pkg/config" +) + +func TestHandleAccountsSummary_ReturnsServiceUnavailableWhenAllBalancesFail(t *testing.T) { + t.Parallel() + + b := newMockBroker(t, "KIS") + b.On("GetBalance", mock.Anything, "acc1").Return((*broker.Balance)(nil), errors.New("upstream unavailable")).Once() + + s := newOrderTestServer( + map[string]broker.Broker{"acc1": b}, + []config.AccountConfig{{AccountID: "acc1"}}, + ) + + req := httptest.NewRequest(http.MethodGet, "/accounts/summary", nil) + rr := performFiberRequest(t, s, req) + + if rr.Code != http.StatusServiceUnavailable { + t.Fatalf("expected 503, got %d body=%s", rr.Code, rr.Body.String()) + } + resp := decodeResponse(t, rr) + if resp.OK { + t.Fatalf("expected ok=false") + } + if resp.Error != "failed to retrieve balances from all accounts" { + t.Fatalf("unexpected error: %s", resp.Error) + } +} diff --git a/internal/server/handler_orders.go b/internal/server/handler_orders.go index 4b386a4..77afd32 100644 --- a/internal/server/handler_orders.go +++ b/internal/server/handler_orders.go @@ -4,7 +4,6 @@ import ( "context" "errors" "net/http" - "strings" "github.com/go-fuego/fuego" @@ -23,9 +22,9 @@ type orderFillsGetter interface { func (s *Server) handleGetOrder(c fuego.ContextNoBody) (Response, error) { accountID := c.PathParam("account_id") orderID := c.PathParam("order_id") - brk, ok := s.getBrokerStrict(accountID) - if !ok { - return respond(c, http.StatusNotFound, Response{OK: false, Error: "account not found"}) + brk, status, reason := s.resolveBrokerByAccountID(accountID) + if brk == nil { + return respond(c, status, Response{OK: false, Error: reason}) } getter, ok := brk.(orderGetter) @@ -61,9 +60,9 @@ func (s *Server) handleGetOrder(c fuego.ContextNoBody) (Response, error) { func (s *Server) handleGetOrderFills(c fuego.ContextNoBody) (Response, error) { accountID := c.PathParam("account_id") orderID := c.PathParam("order_id") - brk, ok := s.getBrokerStrict(accountID) - if !ok { - return respond(c, http.StatusNotFound, Response{OK: false, Error: "account not found"}) + brk, status, reason := s.resolveBrokerByAccountID(accountID) + if brk == nil { + return respond(c, status, Response{OK: false, Error: reason}) } getter, ok := brk.(orderFillsGetter) @@ -99,9 +98,9 @@ func (s *Server) handleGetOrderFills(c fuego.ContextNoBody) (Response, error) { func (s *Server) handlePlaceOrder(c fuego.ContextWithBody[broker.OrderRequest]) (Response, error) { accountID := c.PathParam("account_id") - brk, ok := s.getBrokerStrict(accountID) - if !ok { - return respond(c, http.StatusNotFound, Response{OK: false, Error: "account not found"}) + brk, status, reason := s.resolveBrokerByAccountID(accountID) + if brk == nil { + return respond(c, status, Response{OK: false, Error: reason}) } req, err := c.Body() @@ -140,9 +139,9 @@ func (s *Server) handlePlaceOrder(c fuego.ContextWithBody[broker.OrderRequest]) func (s *Server) handleCancelOrder(c fuego.ContextNoBody) (Response, error) { accountID := c.PathParam("account_id") orderID := c.PathParam("order_id") - brk, ok := s.getBrokerStrict(accountID) - if !ok { - return respond(c, http.StatusNotFound, Response{OK: false, Error: "account not found"}) + brk, status, reason := s.resolveBrokerByAccountID(accountID) + if brk == nil { + return respond(c, status, Response{OK: false, Error: reason}) } err := brk.CancelOrder(c.Context(), orderID) @@ -170,9 +169,9 @@ func (s *Server) handleCancelOrder(c fuego.ContextNoBody) (Response, error) { func (s *Server) handleModifyOrder(c fuego.ContextWithBody[broker.ModifyOrderRequest]) (Response, error) { accountID := c.PathParam("account_id") orderID := c.PathParam("order_id") - brk, ok := s.getBrokerStrict(accountID) - if !ok { - return respond(c, http.StatusNotFound, Response{OK: false, Error: "account not found"}) + brk, status, reason := s.resolveBrokerByAccountID(accountID) + if brk == nil { + return respond(c, status, Response{OK: false, Error: reason}) } req, err := c.Body() @@ -210,7 +209,7 @@ func (s *Server) orderBrokerCandidates(accountID string) []broker.Broker { seen := make(map[broker.Broker]struct{}) if accountID != "" { - if brk, ok := s.getBrokerStrict(accountID); ok { + if brk, status, _ := s.resolveBrokerByAccountID(accountID); status == 0 { out = append(out, brk) seen[brk] = struct{}{} } @@ -226,30 +225,3 @@ func (s *Server) orderBrokerCandidates(accountID string) []broker.Broker { return out } - -func (s *Server) getBrokerStrict(accountID string) (broker.Broker, bool) { - if brk, ok := s.brokers[accountID]; ok { - return brk, true - } - for key, brk := range s.brokers { - if strings.HasPrefix(key, accountID+"-") || strings.HasPrefix(accountID, key+"-") || strings.TrimSuffix(key, "-01") == strings.TrimSuffix(accountID, "-01") { - return brk, true - } - } - return nil, false -} - -func sameAccountID(a, b string) bool { - a = strings.TrimSpace(a) - b = strings.TrimSpace(b) - if a == b { - return true - } - if strings.TrimSuffix(a, "-01") == strings.TrimSuffix(b, "-01") { - return true - } - if strings.HasPrefix(a, b+"-") || strings.HasPrefix(b, a+"-") { - return true - } - return false -} diff --git a/internal/server/handler_orders_test.go b/internal/server/handler_orders_test.go index 38d843c..6a522c8 100644 --- a/internal/server/handler_orders_test.go +++ b/internal/server/handler_orders_test.go @@ -75,6 +75,36 @@ func TestHandlePlaceOrder_BodyAccountMismatchReturnsBadRequest(t *testing.T) { } } +func TestHandlePlaceOrder_BodyDefaultSuffixAliasAccepted(t *testing.T) { + t.Parallel() + + var captured broker.OrderRequest + b := newMockBroker(t, "KIS") + b.On("PlaceOrder", testifymock.Anything, testifymock.Anything).Run(func(args testifymock.Arguments) { + captured = args.Get(1).(broker.OrderRequest) + }).Return(&broker.OrderResult{ + OrderID: "000123", + Status: broker.OrderStatusPending, + Timestamp: time.Now(), + }, nil).Once() + + s := newOrderTestServer( + map[string]broker.Broker{"12345678-01": b}, + []config.AccountConfig{{AccountID: "12345678-01"}}, + ) + + body := []byte(`{"account_id":"12345678","symbol":"005930","market":"KRX","side":"buy","type":"limit","quantity":1,"price":70000}`) + req := httptest.NewRequest(http.MethodPost, "/accounts/12345678-01/orders", bytes.NewReader(body)) + rr := performFiberRequest(t, s, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String()) + } + if captured.AccountID != "12345678-01" { + t.Fatalf("account_id = %q, want 12345678-01", captured.AccountID) + } +} + func TestHandlePlaceOrder_UsesBrokerDomainShape(t *testing.T) { t.Parallel() diff --git a/internal/server/handler_quotes.go b/internal/server/handler_quotes.go index 1d6174b..70f3082 100644 --- a/internal/server/handler_quotes.go +++ b/internal/server/handler_quotes.go @@ -16,8 +16,22 @@ import ( func (s *Server) handleGetQuote(c fuego.ContextNoBody) (Response, error) { market := c.PathParam("market") symbol := c.PathParam("symbol") - - brk := s.getFirstBroker() + accountID := strings.TrimSpace(c.QueryParam("account_id")) + + var brk broker.Broker + if accountID != "" { + var status int + var reason string + brk, status, reason = s.resolveBrokerByAccountID(accountID) + if brk == nil { + return respond(c, status, Response{ + OK: false, + Error: reason, + }) + } + } else { + brk = s.getFirstBroker() + } if brk == nil { return respond(c, http.StatusInternalServerError, Response{ OK: false, @@ -44,8 +58,22 @@ func (s *Server) handleGetQuote(c fuego.ContextNoBody) (Response, error) { func (s *Server) handleGetOHLCV(c fuego.ContextNoBody) (Response, error) { market := c.PathParam("market") symbol := c.PathParam("symbol") - - brk := s.getFirstBroker() + accountID := strings.TrimSpace(c.QueryParam("account_id")) + + var brk broker.Broker + if accountID != "" { + var status int + var reason string + brk, status, reason = s.resolveBrokerByAccountID(accountID) + if brk == nil { + return respond(c, status, Response{ + OK: false, + Error: reason, + }) + } + } else { + brk = s.getFirstBroker() + } if brk == nil { return respond(c, http.StatusInternalServerError, Response{ OK: false, diff --git a/internal/server/handler_quotes_test.go b/internal/server/handler_quotes_test.go index 1a5b554..da403a1 100644 --- a/internal/server/handler_quotes_test.go +++ b/internal/server/handler_quotes_test.go @@ -65,6 +65,87 @@ func TestHandleGetQuote_InvalidSymbolReturnsBadRequest(t *testing.T) { } } +func TestHandleGetQuote_UsesQueryAccountIDBroker(t *testing.T) { + t.Parallel() + + first := newMockBroker(t, "KIS-1") + second := newMockBroker(t, "KIS-2") + second.On("GetQuote", mock.Anything, "NASDAQ", "AAPL").Return(&broker.Quote{ + Symbol: "AAPL", + Price: 250.0, + }, nil).Once() + + s := newOrderTestServer( + map[string]broker.Broker{ + "acc1": first, + "acc2": second, + }, + []config.AccountConfig{ + {AccountID: "acc1"}, + {AccountID: "acc2"}, + }, + ) + + req := httptest.NewRequest(http.MethodGet, "/quotes/NASDAQ/AAPL?account_id=acc2", nil) + rr := performFiberRequest(t, s, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String()) + } + resp := decodeResponse(t, rr) + if !resp.OK { + t.Fatalf("expected ok=true") + } + if resp.Broker != "KIS-2" { + t.Fatalf("broker = %q, want KIS-2", resp.Broker) + } +} + +func TestHandleGetQuote_Returns404WhenQueryAccountMissing(t *testing.T) { + t.Parallel() + + b := newMockBroker(t, "KIS") + s := newOrderTestServer( + map[string]broker.Broker{"acc1": b}, + []config.AccountConfig{{AccountID: "acc1"}}, + ) + + req := httptest.NewRequest(http.MethodGet, "/quotes/KRX/005930?account_id=missing", nil) + rr := performFiberRequest(t, s, req) + + if rr.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d body=%s", rr.Code, rr.Body.String()) + } +} + +func TestHandleGetQuote_Returns400WhenQueryAccountAmbiguous(t *testing.T) { + t.Parallel() + + first := newMockBroker(t, "KIS-1") + second := newMockBroker(t, "KIS-2") + s := newOrderTestServer( + map[string]broker.Broker{ + "12345678-01": first, + "12345678-02": second, + }, + []config.AccountConfig{ + {AccountID: "12345678-01"}, + {AccountID: "12345678-02"}, + }, + ) + + req := httptest.NewRequest(http.MethodGet, "/quotes/KRX/005930?account_id=12345678", nil) + rr := performFiberRequest(t, s, req) + + if rr.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String()) + } + resp := decodeResponse(t, rr) + if resp.Error != ambiguousAccountIDError { + t.Fatalf("unexpected error: %s", resp.Error) + } +} + func TestStatusFromBrokerError_DefaultAndTyped(t *testing.T) { t.Parallel() diff --git a/internal/server/server.go b/internal/server/server.go index f498d50..7947192 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -225,6 +225,7 @@ func (s *Server) routes() { fuego.OptionDescription("Returns daily/weekly/monthly candlestick data."), fuego.OptionPath("market", "Exchange market code", fuego.ParamExample("KRX", "KRX")), fuego.OptionPath("symbol", "Ticker symbol", fuego.ParamExample("Samsung", "005930")), + fuego.OptionQuery("account_id", "Use a specific account's broker (optional)", fuego.ParamExample("KIS account", "12345678-01")), fuego.OptionQuery("interval", "Candle interval: 1d, 1w, 1mo", fuego.ParamDefault("1d"), fuego.ParamExample("daily", "1d"), fuego.ParamExample("weekly", "1w")), fuego.OptionQuery("from", "Start date (YYYY-MM-DD)", fuego.ParamExample("Jan 2026", "2026-01-01")), fuego.OptionQuery("to", "End date (YYYY-MM-DD)", fuego.ParamExample("Feb 2026", "2026-02-28")), @@ -450,28 +451,31 @@ func (s *Server) handleHealth(c fuego.ContextNoBody) (map[string]interface{}, er // getBroker returns the broker for the given account ID func (s *Server) getBroker(accountID string) broker.Broker { - if brk, ok := s.brokers[accountID]; ok { + if brk, status, _ := s.resolveBrokerByAccountID(accountID); status == 0 { return brk } - // Try matching with/without product code suffix (e.g., "73027400" matches "73027400-01") - for key, brk := range s.brokers { - if strings.HasPrefix(key, accountID+"-") || strings.HasPrefix(accountID, key+"-") || strings.TrimSuffix(key, "-01") == strings.TrimSuffix(accountID, "-01") { - return brk - } - } - // If not found, return first broker (legacy compatibility) - if len(s.brokers) > 0 { - for _, brk := range s.brokers { - return brk - } - } return nil } // getFirstBroker returns the first available broker (for legacy endpoints) func (s *Server) getFirstBroker() broker.Broker { if len(s.accounts) > 0 { - return s.getBroker(s.accounts[0].AccountID) + if brk := s.getBroker(s.accounts[0].AccountID); brk != nil { + return brk + } + } + if len(s.brokers) == 0 { + return nil + } + ids := make([]string, 0, len(s.brokers)) + for accountID := range s.brokers { + ids = append(ids, accountID) + } + sort.Strings(ids) + for _, accountID := range ids { + if brk := s.brokers[accountID]; brk != nil { + return brk + } } return nil }