diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..b0dc0dd --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,95 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +This is a Go implementation of the Interactive Brokers TWS (Trader Workstation) API, providing idiomatic Go interfaces to IB Gateway functionality. The library targets TWS API version 10.12 and is not backwards compatible with older versions. + +Currently implements real-time market data functionality with plans to expand to order execution and account management. + +## Key Architecture + +The codebase follows a clean, layered architecture: + +- **client.go**: Main API client with connection management and request routing +- **transport.go**: TCP socket communication layer handling raw message exchange with IB Gateway +- **encoders.go/decoders.go**: Protocol message encoding/decoding for IB's proprietary format +- **messages.go**: Message type constants and protocol definitions +- **models.go**: Data structures (Contract, Bar, TickData, etc.) +- **versions.go**: Server version compatibility constants + +## Common Development Commands + +### Build and Test +```bash +# Run all tests with coverage and race detection +go test -v -race -coverprofile=coverage.txt -covermode=atomic ./... + +# Run specific test +go test -v -run TestRealTimeBarsEncoder + +# Build the example application +go build -o ibapi-example cmd/main.go + +# Check for issues +go vet ./... +``` + +### Module Management +```bash +# Update dependencies +go mod tidy + +# Download dependencies +go mod download + +# Verify dependencies +go mod verify +``` + +### Release Process +```bash +# Create and push a new version tag +git tag vx.y.z +git push origin vx.y.z + +# Publish to package repository +GOPROXY=proxy.golang.org go list -m github.com/wboayue/ibapi@vx.y.z +``` + +## Implementation Patterns + +### Request/Response Pattern +All client methods follow a consistent pattern: +1. Generate unique request ID using `nextRequestId()` +2. Create response channel and register it in `client.channels` map +3. Encode and send request message +4. Handle responses asynchronously via goroutines +5. Clean up channels on context cancellation + +### Thread Safety +- Use `sync.Mutex` for protecting shared state +- Separate mutexes for different concerns (requestIdMutex, contractDetailsMutex) +- Channels for safe concurrent message passing + +### Error Handling +- Use `fmt.Errorf` with `%w` verb for error wrapping (Go 1.13+ standard) +- Return errors immediately, don't panic +- Check for IB error responses in message handlers + +## Testing Approach + +Tests use `github.com/stretchr/testify` for assertions. Test files follow Go convention with `*_test.go` suffix. + +Key test patterns: +- Unit tests for encoders/decoders with known message formats +- Mock MessageBus interface for testing client logic without network +- Table-driven tests for multiple scenarios + +## CI/CD + +GitHub Actions workflow (`.github/workflows/ci.yml`) runs on all PRs and pushes to main: +- Tests with race detection enabled +- Coverage reporting to Codecov +- Go 1.17+ required \ No newline at end of file diff --git a/client.go b/client.go index 7902892..8516463 100644 --- a/client.go +++ b/client.go @@ -8,8 +8,6 @@ import ( "strings" "sync" "time" - - "github.com/palantir/stacktrace" ) const ( @@ -82,27 +80,27 @@ func (c *IbClient) handshake() error { version := fmt.Sprintf("v%d..%d", minClientVer, maxClientVer) if err := c.MessageBus.Write(prefix); err != nil { - return stacktrace.Propagate(err, "error sending prefix") + return fmt.Errorf("error sending prefix: %w", err) } if err := c.MessageBus.WritePacket(version); err != nil { - return stacktrace.Propagate(err, "error sending version") + return fmt.Errorf("error sending version: %w", err) } fields, err := c.readFirstPacket() if err != nil { - return stacktrace.Propagate(err, "error reading first packet") + return fmt.Errorf("error reading first packet: %w", err) } c.ServerVersion, err = strconv.Atoi(fields[0]) if err != nil { - return stacktrace.Propagate(err, "error parsing server version: %v", fields[0]) + return fmt.Errorf("error parsing server version %v: %w", fields[0], err) } log.Printf("server version: %d", c.ServerVersion) c.ServerTime, err = time.Parse(ibDateLayout, fields[1]) if err != nil { - return stacktrace.Propagate(err, "error parsing server time: %v", fields[1]) + return fmt.Errorf("error parsing server time %v: %w", fields[1], err) } log.Printf("server time: %s", c.ServerTime) @@ -130,7 +128,7 @@ func (c *IbClient) nextRequestId() int { func (c *IbClient) readFields() ([]string, error) { data, err := c.MessageBus.ReadPacket() if err != nil { - return nil, stacktrace.Propagate(err, "error reading packet") + return nil, fmt.Errorf("error reading packet: %w", err) } return strings.Split(string(data[:len(data)-1]), "\x00"), nil } @@ -138,14 +136,14 @@ func (c *IbClient) readFields() ([]string, error) { func (c *IbClient) readFirstPacket() ([]string, error) { fields, err := c.readFields() if err != nil { - return nil, stacktrace.Propagate(err, "error reading fields") + return nil, fmt.Errorf("error reading fields: %w", err) } if len(fields) != 2 { for _, field := range fields { fmt.Println("-" + field) } - return nil, stacktrace.NewError("expected 2 fields, got %d: %v", len(fields), fields) + return nil, fmt.Errorf("expected 2 fields, got %d: %v", len(fields), fields) } return fields, nil @@ -267,11 +265,11 @@ func (c *IbClient) handleErrorMessage(scanner *parser, fields []string) { // useRth - use regular trading hours func (c *IbClient) RealTimeBars(ctx context.Context, contract Contract, whatToShow string, useRth bool) (<-chan Bar, error) { if c.ServerVersion < minServerVersionRealTimeBars { - return nil, stacktrace.NewError("server version %d does not support real time bars", c.ServerVersion) + return nil, fmt.Errorf("server version %d does not support real time bars", c.ServerVersion) } if c.ServerVersion < minServerVersionTradingClass { - return nil, stacktrace.NewError("server version %d does not support TradingClass or ContractId fields", c.ServerVersion) + return nil, fmt.Errorf("server version %d does not support TradingClass or ContractId fields", c.ServerVersion) } encoder := realTimeBarsEncoder{ @@ -287,7 +285,7 @@ func (c *IbClient) RealTimeBars(ctx context.Context, contract Contract, whatToSh err := c.MessageBus.WritePacket(encoder.encode()) if err != nil { - return nil, stacktrace.Propagate(err, "error sending request market data message") + return nil, fmt.Errorf("error sending request market data message: %w", err) } // process response @@ -330,7 +328,7 @@ func (c *IbClient) RealTimeBars(ctx context.Context, contract Contract, whatToSh // cancelRealTimeBar cancels a request for real time bars. func (c *IbClient) cancelRealTimeBars(ctx context.Context, requestId int) error { if c.ServerVersion < minServerVersionRealTimeBars { - return stacktrace.NewError("server version %d does not support real time bars cancellation", c.ServerVersion) + return fmt.Errorf("server version %d does not support real time bars cancellation", c.ServerVersion) } log.Printf("canceling real time bar request %v.", requestId) @@ -344,7 +342,7 @@ func (c *IbClient) cancelRealTimeBars(ctx context.Context, requestId int) error // interface for this if err := c.MessageBus.WritePacket(message.Encode()); err != nil { - return stacktrace.Propagate(err, "error sending request to cancel market data") + return fmt.Errorf("error sending request to cancel market data: %w", err) } return nil @@ -353,11 +351,11 @@ func (c *IbClient) cancelRealTimeBars(ctx context.Context, requestId int) error // TickByTickTrades requests tick by tick trades. func (c *IbClient) TickByTickTrades(ctx context.Context, contract Contract) (chan Trade, error) { if c.ServerVersion < minServerVerTickByTick { - return nil, stacktrace.NewError("server version %d does not support tick-by-tick data requests.", c.ServerVersion) + return nil, fmt.Errorf("server version %d does not support tick-by-tick data requests", c.ServerVersion) } if c.ServerVersion < minServerVerTickByTickIgnoreSize { - return nil, stacktrace.NewError("server version %d does not support ignore_size and number_of_ticks parameters in tick-by-tick data requests.", c.ServerVersion) + return nil, fmt.Errorf("server version %d does not support ignore_size and number_of_ticks parameters in tick-by-tick data requests", c.ServerVersion) } encoder := tickByTickEncoder{ @@ -373,7 +371,7 @@ func (c *IbClient) TickByTickTrades(ctx context.Context, contract Contract) (cha err := c.MessageBus.WritePacket(encoder.encode()) if err != nil { - return nil, stacktrace.Propagate(err, "error sending request for tick by tick trades") + return nil, fmt.Errorf("error sending request for tick by tick trades: %w", err) } // process response @@ -416,7 +414,7 @@ func (c *IbClient) TickByTickTrades(ctx context.Context, contract Contract) (cha // cancelTickByTickData cancels a request for tick by tick data. func (c *IbClient) cancelTickByTickData(ctx context.Context, requestId int) error { if c.ServerVersion < minServerVerTickByTick { - return stacktrace.NewError("server version %d does not support tick by tick cancellation", c.ServerVersion) + return fmt.Errorf("server version %d does not support tick by tick cancellation", c.ServerVersion) } log.Printf("canceling tick by tick data request %v.", requestId) @@ -427,7 +425,7 @@ func (c *IbClient) cancelTickByTickData(ctx context.Context, requestId int) erro message.addInt(requestId) if err := c.MessageBus.WritePacket(message.Encode()); err != nil { - return stacktrace.Propagate(err, "error sending request to cancel tick by tick data") + return fmt.Errorf("error sending request to cancel tick by tick data: %w", err) } return nil @@ -436,11 +434,11 @@ func (c *IbClient) cancelTickByTickData(ctx context.Context, requestId int) erro // TickByTickBidAsk requests tick-by-tick bid/ask. func (c *IbClient) TickByTickBidAsk(ctx context.Context, contract Contract) (chan BidAsk, error) { if c.ServerVersion < minServerVerTickByTick { - return nil, stacktrace.NewError("server version %d does not support tick-by-tick data requests.", c.ServerVersion) + return nil, fmt.Errorf("server version %d does not support tick-by-tick data requests", c.ServerVersion) } if c.ServerVersion < minServerVerTickByTickIgnoreSize { - return nil, stacktrace.NewError("server version %d does not support ignore_size and number_of_ticks parameters in tick-by-tick data requests.", c.ServerVersion) + return nil, fmt.Errorf("server version %d does not support ignore_size and number_of_ticks parameters in tick-by-tick data requests", c.ServerVersion) } encoder := tickByTickEncoder{ @@ -456,7 +454,7 @@ func (c *IbClient) TickByTickBidAsk(ctx context.Context, contract Contract) (cha err := c.MessageBus.WritePacket(encoder.encode()) if err != nil { - return nil, stacktrace.Propagate(err, "error sending request for tick by tick bid/ask") + return nil, fmt.Errorf("error sending request for tick by tick bid/ask: %w", err) } // process response @@ -505,15 +503,15 @@ func (c *IbClient) TickByTickBidAsk(ctx context.Context, contract Contract) (cha // It can also be used to retrieve complete options and futures chains. func (c *IbClient) ContractDetails(ctx context.Context, contract Contract) ([]ContractDetails, error) { if c.ServerVersion < minServerVersionSecurityIdType { - return nil, stacktrace.NewError("server version %d does not support SecurityIdType or SecurityId fields", c.ServerVersion) + return nil, fmt.Errorf("server version %d does not support SecurityIdType or SecurityId fields", c.ServerVersion) } if c.ServerVersion < minServerVersionTradingClass { - return nil, stacktrace.NewError("server version %d does not support TradingClass field in Contract", c.ServerVersion) + return nil, fmt.Errorf("server version %d does not support TradingClass field in Contract", c.ServerVersion) } if c.ServerVersion < minServerVersionLinking { - return nil, stacktrace.NewError("server version %d does not support PrimaryExchange field in Contract", c.ServerVersion) + return nil, fmt.Errorf("server version %d does not support PrimaryExchange field in Contract", c.ServerVersion) } c.contractDetailsMutex.Lock() @@ -532,7 +530,7 @@ func (c *IbClient) ContractDetails(ctx context.Context, contract Contract) ([]Co err := c.MessageBus.WritePacket(encoder.encode()) if err != nil { - return nil, stacktrace.Propagate(err, "error sending contract details request") + return nil, fmt.Errorf("error sending contract details request: %w", err) } // process response diff --git a/go.mod b/go.mod index c1896b9..b852042 100644 --- a/go.mod +++ b/go.mod @@ -2,13 +2,10 @@ module github.com/wboayue/ibapi go 1.17 -require ( - github.com/palantir/stacktrace v0.0.0-20161112013806-78658fd2d177 - github.com/stretchr/testify v1.7.0 -) +require github.com/stretchr/testify v1.11.0 require ( - github.com/davecgh/go-spew v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - gopkg.in/yaml.v3 v3.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index a63ae45..1c95799 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,19 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/palantir/stacktrace v0.0.0-20161112013806-78658fd2d177 h1:nRlQD0u1871kaznCnn1EvYiMbum36v7hw1DLPEjds4o= -github.com/palantir/stacktrace v0.0.0-20161112013806-78658fd2d177/go.mod h1:ao5zGxj8Z4x60IOVYZUbDSmt3R8Ddo080vEgPosHpak= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.11.0 h1:ib4sjIrwZKxE5u/Japgo/7SJV3PvgjGiRNAvTVGqQl8= +github.com/stretchr/testify v1.11.0/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0 h1:hjy8E9ON/egN1tAYqKb61G10WtihqetD4sz2H+8nIeA= -gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/transport.go b/transport.go index f7fa2b8..c461dff 100644 --- a/transport.go +++ b/transport.go @@ -5,8 +5,6 @@ import ( "fmt" "io" "net" - - "github.com/palantir/stacktrace" ) // TcpMessageBus implements the MessageBus over TCP @@ -26,7 +24,7 @@ func (b *TcpMessageBus) Connect(host string, port int, clientId int) error { var err error b.socket, err = net.Dial("tcp", fmt.Sprintf("%s:%d", host, port)) if err != nil { - return stacktrace.Propagate(err, "error dialing %s:%d", host, port) + return fmt.Errorf("error dialing %s:%d: %w", host, port, err) } return nil @@ -45,7 +43,7 @@ func (b *TcpMessageBus) Close() error { func (b *TcpMessageBus) Write(data string) error { _, err := b.socket.Write([]byte(data)) if err != nil { - return stacktrace.Propagate(err, "error writing bytes") + return fmt.Errorf("error writing bytes: %w", err) } return nil @@ -58,7 +56,7 @@ func (b *TcpMessageBus) WritePacket(data string) error { _, err := b.socket.Write(header) if err != nil { - return stacktrace.Propagate(err, "error writing packet") + return fmt.Errorf("error writing packet: %w", err) } _, err = b.socket.Write([]byte(data)) @@ -74,7 +72,7 @@ func (b *TcpMessageBus) ReadPacket() (string, error) { header := make([]byte, 4) _, err := io.ReadFull(b.socket, header) if err != nil { - return "", stacktrace.Propagate(err, "error reading packet header") + return "", fmt.Errorf("error reading packet header: %w", err) } size := binary.BigEndian.Uint32(header) @@ -82,7 +80,7 @@ func (b *TcpMessageBus) ReadPacket() (string, error) { data := make([]byte, size) _, err = io.ReadFull(b.socket, data) if err != nil { - return "", stacktrace.Propagate(err, "error reading packet body") + return "", fmt.Errorf("error reading packet body: %w", err) } return string(data), nil