Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions internal/server/account_resolution.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package server

import (
"net/http"
"sort"
"strings"

"github.com/smallfish06/krsec/pkg/broker"
)

const ambiguousAccountIDError = "account_id is ambiguous; use full account_id"

func normalizeBrokerCode(name string) string {
switch strings.ToLower(strings.TrimSpace(name)) {
case broker.CodeKIS, strings.ToLower(broker.NameKIS):
return broker.CodeKIS
case broker.CodeKiwoom, strings.ToLower(broker.NameKiwoom):
return broker.CodeKiwoom
default:
return ""
}
}

func normalizeAccountIDAlias(accountID string) string {
accountID = strings.TrimSpace(accountID)
if strings.HasSuffix(accountID, "-01") {
return strings.TrimSuffix(accountID, "-01")
}
return accountID
}

func isDigits(s string) bool {
if s == "" {
return false
}
for _, r := range s {
if r < '0' || r > '9' {
return false
}
}
return true
}

func kisAccountBase(accountID string) (string, bool) {
accountID = strings.TrimSpace(accountID)
switch {
case len(accountID) == 8 && isDigits(accountID):
return accountID, true
case len(accountID) == 11 && accountID[8] == '-' && isDigits(accountID[:8]) && isDigits(accountID[9:]):
return accountID[:8], true
default:
return "", false
}
}

func sameAccountID(a, b string) bool {
a = strings.TrimSpace(a)
b = strings.TrimSpace(b)
if a == "" || b == "" {
return false
}
if a == b {
return true
}
if baseA, okA := kisAccountBase(a); okA {
if baseB, okB := kisAccountBase(b); okB {
return baseA == baseB
}
}
return normalizeAccountIDAlias(a) == normalizeAccountIDAlias(b)
}

func (s *Server) resolveBrokerByAccountID(accountID string) (broker.Broker, int, string) {
accountID = strings.TrimSpace(accountID)
if accountID == "" {
return nil, http.StatusBadRequest, "account_id is required"
}

if brk, ok := s.brokers[accountID]; ok {
return brk, 0, ""
}

candidates := s.findBrokerAccountCandidates(accountID)
switch len(candidates) {
case 0:
return nil, http.StatusNotFound, "account not found"
case 1:
return s.brokers[candidates[0]], 0, ""
default:
return nil, http.StatusBadRequest, ambiguousAccountIDError
}
}

func (s *Server) getBrokerStrict(accountID string) (broker.Broker, bool) {
brk, status, _ := s.resolveBrokerByAccountID(accountID)
return brk, status == 0
}

func (s *Server) findBrokerAccountCandidates(accountID string) []string {
matches := make([]string, 0, 2)
seen := make(map[string]struct{}, 2)

// Prefer configured account order for deterministic matching.
for _, acc := range s.accounts {
id := strings.TrimSpace(acc.AccountID)
if id == "" {
continue
}
if _, ok := s.brokers[id]; !ok {
continue
}
if !sameAccountID(id, accountID) {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
matches = append(matches, id)
}

extra := make([]string, 0, len(s.brokers))
for id := range s.brokers {
if _, ok := seen[id]; ok {
continue
}
if sameAccountID(id, accountID) {
extra = append(extra, id)
}
}
sort.Strings(extra)
matches = append(matches, extra...)

return matches
}

func (s *Server) resolveAuthBroker(requestedBroker string, sandbox bool) (broker.Broker, int, string) {
brokerCode := normalizeBrokerCode(requestedBroker)
if strings.TrimSpace(requestedBroker) != "" && brokerCode == "" {
return nil, http.StatusBadRequest, "unsupported broker"
}

if brokerCode == "" {
brk := s.getFirstBroker()
if brk == nil {
return nil, http.StatusServiceUnavailable, "no broker available"
}
return brk, 0, ""
}

var fallback broker.Broker
for _, acc := range s.accounts {
if normalizeBrokerCode(acc.Broker) != brokerCode {
continue
}
brk, status, _ := s.resolveBrokerByAccountID(acc.AccountID)
if status != 0 {
continue
}
if acc.Sandbox == sandbox {
return brk, 0, ""
}
if fallback == nil {
fallback = brk
}
}

if fallback != nil {
return fallback, 0, ""
}

ids := make([]string, 0, len(s.brokers))
for id := range s.brokers {
ids = append(ids, id)
}
sort.Strings(ids)
for _, id := range ids {
brk := s.brokers[id]
if normalizeBrokerCode(brk.Name()) == brokerCode {
return brk, 0, ""
}
}

return nil, http.StatusServiceUnavailable, "no " + strings.ToUpper(brokerCode) + " account available"
}
16 changes: 8 additions & 8 deletions internal/server/handler_accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import (
func (s *Server) handleGetBalance(c fuego.ContextNoBody) (Response, error) {
accountID := c.PathParam("account_id")

brk, ok := s.getBrokerStrict(accountID)
if !ok {
return respond(c, http.StatusNotFound, Response{
brk, status, reason := s.resolveBrokerByAccountID(accountID)
if brk == nil {
return respond(c, status, Response{
OK: false,
Error: "account not found",
Error: reason,
})
}

Expand All @@ -39,11 +39,11 @@ func (s *Server) handleGetBalance(c fuego.ContextNoBody) (Response, error) {
func (s *Server) handleGetPositions(c fuego.ContextNoBody) (Response, error) {
accountID := c.PathParam("account_id")

brk, ok := s.getBrokerStrict(accountID)
if !ok {
return respond(c, http.StatusNotFound, Response{
brk, status, reason := s.resolveBrokerByAccountID(accountID)
if brk == nil {
return respond(c, status, Response{
OK: false,
Error: "account not found",
Error: reason,
})
}

Expand Down
28 changes: 28 additions & 0 deletions internal/server/handler_accounts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,31 @@ func TestHandleGetPositions_UnknownAccountReturnsNotFound(t *testing.T) {
t.Fatalf("unexpected error: %s", resp.Error)
}
}

func TestHandleGetBalance_AmbiguousAccountReturnsBadRequest(t *testing.T) {
t.Parallel()

first := newMockBroker(t, "KIS-1")
second := newMockBroker(t, "KIS-2")
s := newOrderTestServer(
map[string]broker.Broker{
"12345678-01": first,
"12345678-02": second,
},
[]config.AccountConfig{
{AccountID: "12345678-01"},
{AccountID: "12345678-02"},
},
)

req := httptest.NewRequest(http.MethodGet, "/accounts/12345678/balance", nil)
rr := performFiberRequest(t, s, req)

if rr.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String())
}
resp := decodeResponse(t, rr)
if resp.Error != ambiguousAccountIDError {
t.Fatalf("unexpected error: %s", resp.Error)
}
}
6 changes: 3 additions & 3 deletions internal/server/handler_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ func (s *Server) handleAuthToken(c fuego.ContextWithBody[AuthTokenRequest]) (Res
})
}

brk := s.getFirstBroker()
brk, status, reason := s.resolveAuthBroker(req.Broker, req.Sandbox)
if brk == nil {
return respond(c, http.StatusServiceUnavailable, Response{
return respond(c, status, Response{
OK: false,
Error: "no broker available",
Error: reason,
})
}

Expand Down
112 changes: 112 additions & 0 deletions internal/server/handler_auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package server

import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/mock"

"github.com/smallfish06/krsec/pkg/broker"
"github.com/smallfish06/krsec/pkg/config"
)

func TestHandleAuthToken_RespectsRequestedBroker(t *testing.T) {
t.Parallel()

kis := newMockBroker(t, "KIS")
kiwoom := newMockBroker(t, "KIWOOM")
kis.On("Authenticate", mock.Anything, broker.Credentials{
AppKey: "k",
AppSecret: "s",
}).Return(&broker.Token{
AccessToken: "kis-token",
TokenType: "Bearer",
ExpiresAt: time.Now().Add(time.Hour),
}, nil).Once()

s := newOrderTestServer(
map[string]broker.Broker{
"kiwoom-acc": kiwoom,
"kis-acc": kis,
},
[]config.AccountConfig{
{AccountID: "kiwoom-acc", Broker: "kiwoom", Sandbox: true},
{AccountID: "kis-acc", Broker: "kis", Sandbox: true},
},
)

body := []byte(`{"broker":"kis","credentials":{"app_key":"k","app_secret":"s"},"sandbox":true}`)
req := httptest.NewRequest(http.MethodPost, "/auth/token", bytes.NewReader(body))
rr := performFiberRequest(t, s, req)

if rr.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String())
}
resp := decodeResponse(t, rr)
if !resp.OK {
t.Fatalf("expected ok=true")
}
if resp.Broker != "KIS" {
t.Fatalf("broker = %q, want KIS", resp.Broker)
}
kiwoom.AssertNotCalled(t, "Authenticate", mock.Anything, mock.Anything)
}

func TestHandleAuthToken_RejectsUnsupportedBroker(t *testing.T) {
t.Parallel()

kis := newMockBroker(t, "KIS")
s := newOrderTestServer(
map[string]broker.Broker{"kis-acc": kis},
[]config.AccountConfig{{AccountID: "kis-acc", Broker: "kis"}},
)

body := []byte(`{"broker":"future","credentials":{"app_key":"k","app_secret":"s"}}`)
req := httptest.NewRequest(http.MethodPost, "/auth/token", bytes.NewReader(body))
rr := performFiberRequest(t, s, req)

if rr.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d body=%s", rr.Code, rr.Body.String())
}
resp := decodeResponse(t, rr)
if resp.Error != "unsupported broker" {
t.Fatalf("unexpected error: %s", resp.Error)
}
kis.AssertNotCalled(t, "Authenticate", mock.Anything, mock.Anything)
}

func TestHandleAuthToken_SelectsSandboxMatchFirst(t *testing.T) {
t.Parallel()

sandboxBroker := newMockBroker(t, "KIS")
prodBroker := newMockBroker(t, "KIS")

sandboxBroker.On("Authenticate", mock.Anything, mock.Anything).Return(&broker.Token{
AccessToken: "sandbox-token",
TokenType: "Bearer",
ExpiresAt: time.Now().Add(time.Hour),
}, nil).Once()

s := newOrderTestServer(
map[string]broker.Broker{
"kis-prod": prodBroker,
"kis-sandbox": sandboxBroker,
},
[]config.AccountConfig{
{AccountID: "kis-prod", Broker: "kis", Sandbox: false},
{AccountID: "kis-sandbox", Broker: "kis", Sandbox: true},
},
)

body := []byte(`{"broker":"kis","credentials":{"app_key":"k","app_secret":"s"},"sandbox":true}`)
req := httptest.NewRequest(http.MethodPost, "/auth/token", bytes.NewReader(body))
rr := performFiberRequest(t, s, req)

if rr.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rr.Code, rr.Body.String())
}
prodBroker.AssertNotCalled(t, "Authenticate", mock.Anything, mock.Anything)
}
10 changes: 7 additions & 3 deletions internal/server/handler_instruments.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@ func (s *Server) handleGetInstrument(c fuego.ContextNoBody) (Response, error) {
symbol := c.PathParam("symbol")
accountID := c.QueryParam("account_id")

var candidates []broker.Broker
if accountID != "" {
if _, ok := s.getBrokerStrict(accountID); !ok {
return respond(c, http.StatusNotFound, Response{OK: false, Error: "account not found"})
brk, status, reason := s.resolveBrokerByAccountID(accountID)
if brk == nil {
return respond(c, status, Response{OK: false, Error: reason})
}
candidates = []broker.Broker{brk}
} else {
candidates = s.orderBrokerCandidates("")
}

candidates := s.orderBrokerCandidates(accountID)
if len(candidates) == 0 {
return respond(c, http.StatusServiceUnavailable, Response{OK: false, Error: "no broker available"})
}
Expand Down
Loading