diff --git a/.gitignore b/.gitignore index 18b6a5c..d5ee13f 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ *~ /bin/cron-control-runner .env +kb/ \ No newline at end of file diff --git a/main.go b/main.go index fbe7e4c..e150b86 100644 --- a/main.go +++ b/main.go @@ -29,6 +29,10 @@ type options struct { remoteToken string useWebsockets bool eventsWebhookURL string + maxHandshakeBytes int + handshakeInitialTTL time.Duration + handshakeIdleTTL time.Duration + maxHandshakeSessions int useLocker bool dataConfigPath string lockerRefreshInterval time.Duration @@ -78,7 +82,19 @@ func main() { // Setup the remote CLI module if enabled. if 0 < len(options.remoteToken) { // TODO: This module could definitely use some general refactoring, but namely a graceful shutdown would be good. - remote.Setup(options.remoteToken, options.useWebsockets, options.wpCLIPath, options.wpPath, options.eventsWebhookURL) + remote.SetupWithOptions( + options.remoteToken, + options.useWebsockets, + options.wpCLIPath, + options.wpPath, + options.eventsWebhookURL, + remote.SetupOptions{ + MaxHandshakeBytes: options.maxHandshakeBytes, + HandshakeInitialTimeout: options.handshakeInitialTTL, + HandshakeIdleTimeout: options.handshakeIdleTTL, + MaxConcurrentHandshakes: options.maxHandshakeSessions, + }, + ) go remote.ListenForConnections() } @@ -107,6 +123,10 @@ func getCliOptions() options { }, remoteToken: "", useWebsockets: false, + maxHandshakeBytes: 64 * 1024, + handshakeInitialTTL: 15 * time.Second, + handshakeIdleTTL: 200 * time.Millisecond, + maxHandshakeSessions: 256, useLocker: false, dataConfigPath: "/etc/wpvip-data-config/config.json", lockerRefreshInterval: 10 * time.Second, @@ -138,6 +158,10 @@ func getCliOptions() options { flag.StringVar(&(options.remoteToken), "token", options.remoteToken, "Token to authenticate remote WP CLI requests") flag.BoolVar(&(options.useWebsockets), "use-websockets", options.useWebsockets, "Use the websocket listener instead of raw tcp for remote WP CLI 
requests") flag.StringVar(&(options.eventsWebhookURL), "events-webhook-url", options.eventsWebhookURL, "Webhook URL used to send WP CLI events") + flag.IntVar(&(options.maxHandshakeBytes), "max-handshake-bytes", options.maxHandshakeBytes, "Maximum number of bytes accepted during remote handshake") + flag.DurationVar(&(options.handshakeInitialTTL), "handshake-initial-timeout", options.handshakeInitialTTL, "Absolute timeout for finishing remote handshake") + flag.DurationVar(&(options.handshakeIdleTTL), "handshake-idle-timeout", options.handshakeIdleTTL, "Idle timeout between handshake packets") + flag.IntVar(&(options.maxHandshakeSessions), "max-concurrent-handshakes", options.maxHandshakeSessions, "Maximum number of concurrent in-progress remote handshakes") // NOTE: this will exit if options are invalid or if help is requested, etc. flag.Parse() diff --git a/readme.md b/readme.md index 44c1059..f34f204 100644 --- a/readme.md +++ b/readme.md @@ -59,6 +59,30 @@ It's helpful to specify some environment variables (e.g. in an `.env` file): - `-use-mock-data` - use the mock performer for testing +## Remote WP CLI Options +- `-token` string + - Token to authenticate remote WP CLI requests. +- `-use-websockets` + - Use the websocket listener instead of raw tcp for remote WP CLI requests. +- `-events-webhook-url` string + - Webhook URL used to send WP CLI events. +- `-max-handshake-bytes` int + - Maximum size in bytes of the newline-terminated remote handshake, counting the terminating newline (default 65536). +- `-handshake-initial-timeout` duration + - Absolute timeout to finish the handshake, regardless of trickle traffic (default 15s). +- `-handshake-idle-timeout` duration + - Maximum idle time between handshake packets (default 200ms). +- `-max-concurrent-handshakes` int + - Maximum number of concurrent in-progress handshakes (default 256). + +### Recommended Production Baseline +- Keep `-max-handshake-bytes` at `65536` unless clients require larger metadata. 
+- Keep `-handshake-initial-timeout` at `15s`; increase only if legitimate clients regularly exceed it. +- Keep `-handshake-idle-timeout` at `200ms`; this drops clients that stall between handshake packets, while `-handshake-initial-timeout` bounds slow-loris style trickling that stays under the idle limit. +- Set `-max-concurrent-handshakes` to match host capacity and expected peak auth bursts. + - If memory pressure is observed during spikes, lower this value. + - If legitimate clients are rejected during deploy bursts, raise this value gradually. + ## Architecture ![runner diagram](https://d.pr/i/1THmhI+) diff --git a/remote/remote.go b/remote/remote.go index 5fdbbbe..f1f2d51 100644 --- a/remote/remote.go +++ b/remote/remote.go @@ -37,6 +37,11 @@ import ( const ( shutdownErrorCode = 4001 // WebSocket close code when a shutdown signal is detected + + defaultMaxHandshakeBytes = 64 * 1024 + defaultHandshakeInitialTimeout = 15 * time.Second + defaultHandshakeIdleTimeout = 200 * time.Millisecond + defaultMaxConcurrentHandshakes = 256 ) var nonUTF8Replacement = []byte(string(unicode.ReplacementChar)) @@ -54,17 +59,29 @@ type wpCLIProcess struct { } var ( - gGUIDLength = 36 - gGUIDttys map[string]*wpCLIProcess - padlock *sync.Mutex - guidRegex *regexp.Regexp + gGUIDLength = 36 + gGUIDttys map[string]*wpCLIProcess + padlock *sync.Mutex + guidRegex *regexp.Regexp + handshakeSem chan struct{} ) type config struct { - remoteToken string - useWebsockets bool - wpCLIPath string - wpPath string + remoteToken string + useWebsockets bool + wpCLIPath string + wpPath string + maxHandshakeBytes int + handshakeInitialTimeout time.Duration + handshakeIdleTimeout time.Duration + maxConcurrentHandshakes int +} + +type SetupOptions struct { + MaxHandshakeBytes int + HandshakeInitialTimeout time.Duration + HandshakeIdleTimeout time.Duration + MaxConcurrentHandshakes int } var remoteConfig config @@ -72,11 +89,39 @@ var wpCliEventSender EventSender // Setup configures the module (not super ideal, but this module needs some reworking to make it better) func Setup(remoteToken string, useWebsockets 
bool, wpCLIPath string, wpPath string, eventsWebhookURL string) { + SetupWithOptions(remoteToken, useWebsockets, wpCLIPath, wpPath, eventsWebhookURL, SetupOptions{}) +} + +func SetupWithOptions(remoteToken string, useWebsockets bool, wpCLIPath string, wpPath string, eventsWebhookURL string, options SetupOptions) { + maxHandshakeBytes := options.MaxHandshakeBytes + if maxHandshakeBytes <= 0 { + maxHandshakeBytes = defaultMaxHandshakeBytes + } + + handshakeInitialTimeout := options.HandshakeInitialTimeout + if handshakeInitialTimeout <= 0 { + handshakeInitialTimeout = defaultHandshakeInitialTimeout + } + + handshakeIdleTimeout := options.HandshakeIdleTimeout + if handshakeIdleTimeout <= 0 { + handshakeIdleTimeout = defaultHandshakeIdleTimeout + } + + maxConcurrentHandshakes := options.MaxConcurrentHandshakes + if maxConcurrentHandshakes <= 0 { + maxConcurrentHandshakes = defaultMaxConcurrentHandshakes + } + remoteConfig = config{ - remoteToken: remoteToken, - useWebsockets: useWebsockets, - wpCLIPath: wpCLIPath, - wpPath: wpPath, + remoteToken: remoteToken, + useWebsockets: useWebsockets, + wpCLIPath: wpCLIPath, + wpPath: wpPath, + maxHandshakeBytes: maxHandshakeBytes, + handshakeInitialTimeout: handshakeInitialTimeout, + handshakeIdleTimeout: handshakeIdleTimeout, + maxConcurrentHandshakes: maxConcurrentHandshakes, } wpCliEventSender = setupWebhookSender( @@ -105,6 +150,7 @@ func setupWebhookSender(remoteToken string, eventsWebhookURL string) EventSender func ListenForConnections() { gGUIDttys = make(map[string]*wpCLIProcess) padlock = &sync.Mutex{} + handshakeSem = make(chan struct{}, effectiveMaxConcurrentHandshakes()) guidRegex = regexp.MustCompile("^[a-fA-F0-9\\-]+$") if nil == guidRegex { @@ -152,53 +198,155 @@ func ListenForConnections() { for { log.Println("listening...") conn, err := listener.AcceptTCP() - log.Printf("connection from %s\n", conn.RemoteAddr().String()) if err != nil { log.Printf("error accepting connection: %s\n", err.Error()) continue } + 
log.Printf("connection from %s\n", conn.RemoteAddr().String()) + go authConn(conn) } } -func authConn(conn net.Conn) { - var rows, cols uint16 - var offset int64 - var token, GUID, cmd string - var read int - var err error - var data []byte - buf := make([]byte, 65535) +func tryAcquireHandshakeSlot() bool { + if handshakeSem == nil { + return true + } - log.Println("waiting for auth data") + select { + case handshakeSem <- struct{}{}: + return true + default: + return false + } +} - conn.SetReadDeadline(time.Now().Add(time.Duration(5000 * time.Millisecond.Nanoseconds()))) - bufReader := bufio.NewReader(conn) +func releaseHandshakeSlot() { + if handshakeSem == nil { + return + } + + select { + case <-handshakeSem: + default: + } +} + +func minTime(a, b time.Time) time.Time { + if a.Before(b) { + return a + } + + return b +} + +func effectiveMaxHandshakeBytes() int { + if remoteConfig.maxHandshakeBytes <= 0 { + return defaultMaxHandshakeBytes + } + + return remoteConfig.maxHandshakeBytes +} + +func effectiveHandshakeInitialTimeout() time.Duration { + if remoteConfig.handshakeInitialTimeout <= 0 { + return defaultHandshakeInitialTimeout + } + + return remoteConfig.handshakeInitialTimeout +} + +func effectiveHandshakeIdleTimeout() time.Duration { + if remoteConfig.handshakeIdleTimeout <= 0 { + return defaultHandshakeIdleTimeout + } + + return remoteConfig.handshakeIdleTimeout +} + +func effectiveMaxConcurrentHandshakes() int { + if remoteConfig.maxConcurrentHandshakes <= 0 { + return defaultMaxConcurrentHandshakes + } + + return remoteConfig.maxConcurrentHandshakes +} + +func readHandshakeData(conn net.Conn, bufReader *bufio.Reader) ([]byte, error) { + data := make([]byte, 0, 1024) + buf := make([]byte, 4096) + handshakeDeadline := time.Now().Add(effectiveHandshakeInitialTimeout()) + handshakeIdleTimeout := effectiveHandshakeIdleTimeout() + maxHandshakeBytes := effectiveMaxHandshakeBytes() + + if err := conn.SetReadDeadline(minTime(handshakeDeadline, 
time.Now().Add(handshakeIdleTimeout))); err != nil { + return nil, err + } for { - read, err = bufReader.Read(buf) + read, err := bufReader.Read(buf) - if nil != err && !strings.Contains(err.Error(), "i/o timeout") { - conn.Write([]byte("error during handshaking\n")) - log.Printf("error handshaking: %s\n", err.Error()) - conn.Close() - return + if read > 0 { + if len(data)+read > maxHandshakeBytes { + return nil, fmt.Errorf("error handshake exceeds maximum size of %d bytes", maxHandshakeBytes) + } + + data = append(data, buf[:read]...) + if data[len(data)-1] == '\n' { + break + } } - if 0 != read { - if nil == data { - data = make([]byte, read) - copy(data, buf[:read]) - } else { - data = append(data, buf[:read]...) + if err != nil { + if errors.Is(err, io.EOF) { + return nil, errors.New("error handshake terminated before delimiter") } - } else if 0 == bufReader.Buffered() { - break + + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return nil, errors.New("error handshake timed out") + } + + return nil, err + } + + if err := conn.SetReadDeadline(minTime(handshakeDeadline, time.Now().Add(handshakeIdleTimeout))); err != nil { + return nil, err } + } + + return data, nil +} + +func authConn(conn net.Conn) { + var rows, cols uint16 + var offset int64 + var token, GUID, cmd string + var err error + var data []byte + handshakeSlotHeld := false - conn.SetReadDeadline(time.Now().Add(time.Duration(200 * time.Millisecond.Nanoseconds()))) + if !tryAcquireHandshakeSlot() { + conn.Write([]byte("server busy, try again")) + conn.Close() + return + } + handshakeSlotHeld = true + defer func() { + if handshakeSlotHeld { + releaseHandshakeSlot() + } + }() + + log.Println("waiting for auth data") + + bufReader := bufio.NewReader(conn) + data, err = readHandshakeData(conn, bufReader) + if nil != err { + conn.Write([]byte("error during handshaking\n")) + log.Printf("error handshaking: %s\n", err.Error()) + conn.Close() + return } - buf = nil size := len(data) 
log.Printf("size of handshake %d\n", size) @@ -238,6 +386,8 @@ func authConn(conn net.Conn) { return } + handshakeSlotHeld = false + releaseHandshakeSlot() log.Println("handshake complete!") conn.SetReadDeadline(time.Time{}) diff --git a/remote/remote_test.go b/remote/remote_test.go index bdcb6c5..c292450 100644 --- a/remote/remote_test.go +++ b/remote/remote_test.go @@ -1,10 +1,53 @@ package remote import ( + "bufio" + "bytes" + "io" + "net" "reflect" + "strings" "testing" + "time" ) +type mockNetConn struct { + deadlines []time.Time +} + +func (m *mockNetConn) Read(_ []byte) (int, error) { + return 0, io.EOF +} + +func (m *mockNetConn) Write(b []byte) (int, error) { + return len(b), nil +} + +func (m *mockNetConn) Close() error { + return nil +} + +func (m *mockNetConn) LocalAddr() net.Addr { + return &net.TCPAddr{} +} + +func (m *mockNetConn) RemoteAddr() net.Addr { + return &net.TCPAddr{} +} + +func (m *mockNetConn) SetDeadline(_ time.Time) error { + return nil +} + +func (m *mockNetConn) SetReadDeadline(t time.Time) error { + m.deadlines = append(m.deadlines, t) + return nil +} + +func (m *mockNetConn) SetWriteDeadline(_ time.Time) error { + return nil +} + func TestValidateCommand(t *testing.T) { tests := map[string]struct { errString string @@ -95,3 +138,127 @@ func TestGetCleanWpCliArgumentArray(t *testing.T) { }) } } + +func TestReadHandshakeData(t *testing.T) { + t.Run("returns payload", func(t *testing.T) { + SetupWithOptions("token", false, "/tmp/wp", "/tmp", "", SetupOptions{ + MaxHandshakeBytes: 1024, + HandshakeInitialTimeout: 5 * time.Second, + HandshakeIdleTimeout: 200 * time.Millisecond, + MaxConcurrentHandshakes: 16, + }) + + conn := &mockNetConn{} + payload := []byte("token-guid-rows-cols-cmd\n") + bufReader := bufio.NewReader(bytes.NewReader(payload)) + + data, err := readHandshakeData(conn, bufReader) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if string(data) != string(payload) { + t.Fatalf("unexpected payload: %q", 
string(data)) + } + + if len(conn.deadlines) == 0 { + t.Fatal("expected read deadline to be set") + } + }) + + t.Run("rejects oversized handshake", func(t *testing.T) { + SetupWithOptions("token", false, "/tmp/wp", "/tmp", "", SetupOptions{ + MaxHandshakeBytes: 16, + HandshakeInitialTimeout: 5 * time.Second, + HandshakeIdleTimeout: 200 * time.Millisecond, + MaxConcurrentHandshakes: 16, + }) + + conn := &mockNetConn{} + payload := bytes.Repeat([]byte("x"), 17) + bufReader := bufio.NewReader(bytes.NewReader(payload)) + + _, err := readHandshakeData(conn, bufReader) + if err == nil { + t.Fatal("expected oversized handshake error") + } + + if !strings.Contains(err.Error(), "exceeds maximum size of 16 bytes") { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("rejects handshake timeout without delimiter", func(t *testing.T) { + SetupWithOptions("token", false, "/tmp/wp", "/tmp", "", SetupOptions{ + MaxHandshakeBytes: 1024, + HandshakeInitialTimeout: 200 * time.Millisecond, + HandshakeIdleTimeout: 20 * time.Millisecond, + MaxConcurrentHandshakes: 16, + }) + + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + go func() { + clientConn.Write([]byte("token-guid")) + // Keep the connection open so the server side hits a read timeout. 
+ time.Sleep(50 * time.Millisecond) + }() + + _, err := readHandshakeData(serverConn, bufio.NewReader(serverConn)) + if err == nil { + t.Fatal("expected timeout error") + } + + if !strings.Contains(err.Error(), "timed out") { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("rejects handshake terminated without delimiter", func(t *testing.T) { + SetupWithOptions("token", false, "/tmp/wp", "/tmp", "", SetupOptions{ + MaxHandshakeBytes: 1024, + HandshakeInitialTimeout: 200 * time.Millisecond, + HandshakeIdleTimeout: 50 * time.Millisecond, + MaxConcurrentHandshakes: 16, + }) + + conn := &mockNetConn{} + bufReader := bufio.NewReader(bytes.NewReader([]byte("token-guid"))) + + _, err := readHandshakeData(conn, bufReader) + if err == nil { + t.Fatal("expected delimiter error") + } + + if !strings.Contains(err.Error(), "before delimiter") { + t.Fatalf("unexpected error: %v", err) + } + }) +} + +func TestSetupWithOptions_DefaultsInvalidValues(t *testing.T) { + SetupWithOptions("token", false, "/tmp/wp", "/tmp", "", SetupOptions{ + MaxHandshakeBytes: -1, + HandshakeInitialTimeout: -1, + HandshakeIdleTimeout: -1, + MaxConcurrentHandshakes: -1, + }) + + if got, want := remoteConfig.maxHandshakeBytes, defaultMaxHandshakeBytes; got != want { + t.Fatalf("maxHandshakeBytes=%d, want %d", got, want) + } + + if got, want := remoteConfig.handshakeInitialTimeout, defaultHandshakeInitialTimeout; got != want { + t.Fatalf("handshakeInitialTimeout=%s, want %s", got, want) + } + + if got, want := remoteConfig.handshakeIdleTimeout, defaultHandshakeIdleTimeout; got != want { + t.Fatalf("handshakeIdleTimeout=%s, want %s", got, want) + } + + if got, want := remoteConfig.maxConcurrentHandshakes, defaultMaxConcurrentHandshakes; got != want { + t.Fatalf("maxConcurrentHandshakes=%d, want %d", got, want) + } +}