From 88fbb7d6e67d81567bd92155fc030f4993a2b0c8 Mon Sep 17 00:00:00 2001 From: yaoge123 Date: Sun, 31 May 2026 10:55:21 +0800 Subject: [PATCH 1/4] feat: add relay-phase idle timeout Once a client connection passes the rsync handshake and enters the bidirectional relay phase, the existing ReadTimeout/WriteTimeout no longer apply: the proxy explicitly clears deadlines and just runs io.Copy in both directions. A misbehaving or stuck client could then hold a connection open indefinitely without transferring data. Add an opt-in idle timeout for the relay phase, mirroring rsyncd's "timeout" semantics (rsyncd.conf(5)): * New config option [proxy].relay_idle_timeout (seconds). 0 (default) disables the timeout, matching the previous behavior. * countingReader now optionally records a UnixNano timestamp of the last non-zero read. A single timestamp is shared between the two directions, so any traffic in either direction resets the clock. * relay() starts a watcher goroutine when the timeout is enabled. It ticks at idle_timeout/4 (minimum 1s) and closes both endpoints once the configured duration has elapsed without activity, logging an access-log entry. * The expected I/O errors that follow such a forced close are suppressed from the error log via an idleTimedOut atomic flag. Tests cover both the timeout-fires path (idle upstream is torn down and an access-log line is written) and the steady-traffic path (a stream slower than the timeout but continuously active is not cut). --- assets/config.example.toml | 7 +++ pkg/server/config.go | 9 +++ pkg/server/server.go | 68 ++++++++++++++++++++- pkg/server/server_test.go | 119 +++++++++++++++++++++++++++++++++++++ 4 files changed, 201 insertions(+), 2 deletions(-) diff --git a/assets/config.example.toml b/assets/config.example.toml index 02c3b7e..d0b2938 100644 --- a/assets/config.example.toml +++ b/assets/config.example.toml @@ -11,6 +11,13 @@ tls_key_file = "/etc/rsync-proxy/tls/server.key" motd = "Served by rsync-proxy (https://github.com/ustclug/rsync-proxy)" +# Idle timeout (seconds) applied during the bidirectional relay phase +# of a client connection. If no data flows in either direction for +# this duration, the connection is terminated. 0 (the default) +# disables the timeout. This mirrors rsyncd's "timeout" setting in +# rsyncd.conf(5); 600 is a common choice for public mirrors. +#relay_idle_timeout = 0 + [upstreams.u1] address = "127.0.0.1:1234" modules = ["foo"] diff --git a/pkg/server/config.go b/pkg/server/config.go index f9ed120..ffe8efd 100644 --- a/pkg/server/config.go +++ b/pkg/server/config.go @@ -26,6 +26,15 @@ type ProxySettings struct { ErrorLog string `toml:"error_log"` TLSCertFile string `toml:"tls_cert_file"` TLSKeyFile string `toml:"tls_key_file"` + // RelayIdleTimeoutSecs is the idle timeout (in seconds) applied + // during the bidirectional relay phase of a connection. If no data + // flows in either direction for this duration, the connection is + // terminated. 0 (the default) disables the timeout. + // + // This mirrors the semantics of the rsyncd "timeout" setting (see + // rsyncd.conf(5)), which is an I/O timeout. A common choice for + // public mirrors is 600. + RelayIdleTimeoutSecs int `toml:"relay_idle_timeout"` } type Config struct { diff --git a/pkg/server/server.go b/pkg/server/server.go index be46f45..951cf90 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -138,6 +138,13 @@ type Server struct { ReadTimeout time.Duration WriteTimeout time.Duration + // RelayIdleTimeout is the idle (no I/O activity in either + // direction) timeout applied during the bidirectional relay phase + // of a connection. A value of 0 (the default) disables the + // timeout, matching rsyncd's "timeout = 0" behavior. See + // rsyncd.conf(5). + RelayIdleTimeout time.Duration + Motd string // --- End of options section @@ -174,12 +181,20 @@ type Server struct { type countingReader struct { reader io.Reader counter *atomic.Int64 + // lastActivity is the UnixNano timestamp of the most recent + // successful read (n > 0). It is updated atomically so that an + // idle watcher goroutine can observe activity without locking. + // May be nil when activity tracking is not needed. + lastActivity *atomic.Int64 } func (cr *countingReader) Read(p []byte) (n int, err error) { n, err = cr.reader.Read(p) if n > 0 { cr.counter.Add(int64(n)) + if cr.lastActivity != nil { + cr.lastActivity.Store(time.Now().UnixNano()) + } } return n, err } @@ -282,6 +297,10 @@ func (s *Server) loadConfig(c *Config, openLog bool) error { } } s.Motd = c.Proxy.Motd + if c.Proxy.RelayIdleTimeoutSecs < 0 { + return fmt.Errorf("relay_idle_timeout must be non-negative, got %d", c.Proxy.RelayIdleTimeoutSecs) + } + s.RelayIdleTimeout = time.Duration(c.Proxy.RelayIdleTimeoutSecs) * time.Second s.modules = modules s.upstreams = resolvedUpstreams s.upstreamQueues = s.updateUpstreamQueuesLocked(resolvedUpstreams) @@ -705,19 +724,64 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err downReader := &countingReader{reader: downConn, counter: &info.ReceivedBytes} upReader := &countingReader{reader: upConn, counter: &info.SentBytes} + // If an idle timeout is configured, share a single lastActivity + // timestamp between the two directions and start a watcher + // goroutine that tears the connections down when no data has + // flowed in either direction for the configured duration. This + // mirrors rsyncd's "timeout" semantics (rsyncd.conf(5)). + idleTimeout := s.RelayIdleTimeout + var idleTimedOut atomic.Bool sentClosed := make(chan struct{}) receivedClosed := make(chan struct{}) + if idleTimeout > 0 { + lastActivity := &atomic.Int64{} + lastActivity.Store(time.Now().UnixNano()) + downReader.lastActivity = lastActivity + upReader.lastActivity = lastActivity + + // Wake at most a few times per timeout window so a stalled + // connection is detected within roughly 1.25x the configured + // timeout in the worst case, while keeping wakeup overhead + // negligible. + interval := idleTimeout / 4 + if interval < time.Second { + interval = time.Second + } + + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-sentClosed: + return + case <-receivedClosed: + return + case now := <-ticker.C: + last := time.Unix(0, lastActivity.Load()) + if now.Sub(last) >= idleTimeout { + idleTimedOut.Store(true) + s.accessLog.F("client %s idle for module %s exceeds %s, closing", ip, moduleName, idleTimeout) + _ = upConn.Close() + _ = downConn.Close() + return + } + } + } + }() + } + go func() { _, err := io.Copy(upConn, downReader) - if err != nil { + if err != nil && !idleTimedOut.Load() { s.errorLog.F("copy from downstream to upstream: %v", err) } close(sentClosed) }() go func() { _, err := io.Copy(downConn, upReader) - if err != nil { + if err != nil && !idleTimedOut.Load() { s.errorLog.F("copy from upstream to downstream: %v", err) } close(receivedClosed) diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 9b374aa..d2b3d2f 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -209,6 +209,125 @@ func TestClientReadTimeout(t *testing.T) { r.Equal(expected, string(allData)) } +// TestRelayIdleTimeoutClosesIdleConnection verifies that when +// RelayIdleTimeout is configured and no data flows in either +// direction during the bidirectional relay phase, the proxy tears the +// connection down. This mirrors rsyncd's "timeout" behavior in +// rsyncd.conf(5). +func TestRelayIdleTimeoutClosesIdleConnection(t *testing.T) { + srv := startServer(t) + defer srv.Close() + srv.RelayIdleTimeout = 500 * time.Millisecond + accessLogPath := setupAccessLog(t, srv) + + r := require.New(t) + + upstreamReady := make(chan struct{}) + upstreamDone := make(chan struct{}) + + fakeRsync := rsync.NewServer(func(conn *rsync.Conn) { + defer conn.Close() + defer close(upstreamDone) + + _, _, err := doServerHandshake(conn, RsyncdServerVersion) + r.NoError(err, "upstream handshake") + close(upstreamReady) + + // Stay quiet so the relay phase has no I/O. The proxy must + // close us once the idle timeout elapses; ReadAll then + // returns with an EOF / closed-connection error. + _, _ = io.ReadAll(conn) + }) + fakeRsync.Start() + defer fakeRsync.Close() + + srv.modules = map[string][]Target{ + "fake": {{Upstream: "u1", Addr: fakeRsync.Listener.Addr().String()}}, + } + srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} + + rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + r.NoError(err) + conn := rsync.NewConn(rawConn) + defer conn.Close() + + _, err = doClientHandshake(conn, RsyncdServerVersion, "fake") + r.NoError(err) + + <-upstreamReady + + start := time.Now() + // ReadAll should return shortly after the proxy closes our + // connection due to the idle timeout firing on the relay side. + _, err = io.ReadAll(conn) + r.NoError(err) + elapsed := time.Since(start) + + // Allow generous slack: must be at least the configured timeout, + // and not pathologically long (e.g. waiting forever). + r.GreaterOrEqual(elapsed, srv.RelayIdleTimeout, "should have waited at least the idle timeout") + r.Less(elapsed, 5*time.Second, "should not block far beyond the idle timeout") + + select { + case <-upstreamDone: + case <-time.After(2 * time.Second): + t.Fatal("upstream connection was not closed after idle timeout") + } + + logData, err := os.ReadFile(accessLogPath) + r.NoError(err) + assert.Contains(t, string(logData), "idle for module fake exceeds") +} + +// TestRelayIdleTimeoutNotTriggeredWhenActive verifies that the idle +// timeout is reset whenever data flows, so a slow but continuously +// active stream does not get cut. The fake upstream sends data at an +// interval well below the idle timeout for several iterations. +func TestRelayIdleTimeoutNotTriggeredWhenActive(t *testing.T) { + srv := startServer(t) + defer srv.Close() + srv.RelayIdleTimeout = 2 * time.Second + + r := require.New(t) + + const iterations = 5 + const interval = 200 * time.Millisecond + + fakeRsync := rsync.NewServer(func(conn *rsync.Conn) { + defer conn.Close() + _, _, err := doServerHandshake(conn, RsyncdServerVersion) + r.NoError(err, "upstream handshake") + + for i := 0; i < iterations; i++ { + _, err = conn.Write([]byte("data\n")) + r.NoError(err, "write data") + time.Sleep(interval) + } + }) + fakeRsync.Start() + defer fakeRsync.Close() + + srv.modules = map[string][]Target{ + "fake": {{Upstream: "u1", Addr: fakeRsync.Listener.Addr().String()}}, + } + srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} + + rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + r.NoError(err) + conn := rsync.NewConn(rawConn) + defer conn.Close() + + _, err = doClientHandshake(conn, RsyncdServerVersion, "fake") + r.NoError(err) + + allData, err := io.ReadAll(conn) + r.NoError(err) + + expected := strings.Repeat("data\n", iterations) + r.Equal(expected, string(allData), + "steadily flowing traffic must not be interrupted by the idle timeout") +} + func TestTLSRsyncListener(t *testing.T) { r := require.New(t) From fe615e25f751bbb0cb747d760f758a38dd6cec66 Mon Sep 17 00:00:00 2001 From: yaoge123 Date: Sun, 31 May 2026 11:43:38 +0800 Subject: [PATCH 2/4] feat: add connection controls (TCP keepalive, per-IP cap, relay max duration) Three new mechanisms to mitigate clients holding upstream slots without making progress ("squatters"): * tcp_keepalive (proxy-wide, seconds): enables TCP keepalive with the configured period on both accepted client connections and dialed upstream connections, so a half-open peer is detected within minutes rather than the OS default (~2 hours, or disabled). 0 keeps OS-default behavior. * per_ip_max_active_connections (proxy-wide default with per-upstream override): caps the number of simultaneous active relay connections a single client IP may have to one upstream. Counted before queue admission so the cap also bounds queueing. Rejected requests get an @ERROR line and are tracked via the new rsync_proxy_per_ip_rejected_total{upstream} counter. * relay_max_duration (proxy-wide, seconds): hard wall-clock cap on the bidirectional relay phase. Connections still alive past this duration are closed regardless of activity, complementing the existing relay_idle_timeout (which only fires on no-I/O). The existing relay watcher goroutine now handles both checks. The reload-mutated timing fields (RelayIdleTimeout, RelayMaxDuration, TCPKeepAlive, dialer.KeepAlive) are now read under the reload lock via getRelayTimings/getTCPKeepAlive/getDialer helpers to avoid data races with config reloads. Tests cover per-IP rejection (counter + metric + access log), max-duration enforcement (with idle disabled), the keepalive helper, config propagation to the dialer, and rejection of negative values for the three new settings. --- assets/config.example.toml | 24 ++++ pkg/server/config.go | 23 ++++ pkg/server/metrics.go | 8 ++ pkg/server/server.go | 253 +++++++++++++++++++++++++++++----- pkg/server/server_test.go | 272 +++++++++++++++++++++++++++++++++++++ 5 files changed, 545 insertions(+), 35 deletions(-) diff --git a/assets/config.example.toml b/assets/config.example.toml index d0b2938..a7fd223 100644 --- a/assets/config.example.toml +++ b/assets/config.example.toml @@ -18,11 +18,35 @@ motd = "Served by rsync-proxy (https://github.com/ustclug/rsync-proxy)" # rsyncd.conf(5); 600 is a common choice for public mirrors. #relay_idle_timeout = 0 +# Hard upper bound (seconds) on the total wall-clock duration of the +# bidirectional relay phase. When exceeded the proxy closes the +# connection regardless of activity. rsync clients typically resume +# on reconnect. 0 (the default) disables this cap. +#relay_max_duration = 0 + +# TCP keepalive period (seconds) applied to accepted client +# connections and dialed upstream connections. Helps detect +# half-open connections (peer crashed, NAT entry reaped) within +# minutes rather than the OS default (~2 hours). 0 (the default) +# leaves the OS-default keepalive behavior in place. +#tcp_keepalive = 0 + +# Default per-IP per-upstream concurrent active connection cap. +# Limits how many simultaneous active relay connections a single +# client IP may have to any one upstream, preventing a single client +# from monopolizing upstream slots. 0 (the default) disables the +# limit. Each upstream may override this via its own +# per_ip_max_active_connections setting below. +#per_ip_max_active_connections = 0 + [upstreams.u1] address = "127.0.0.1:1234" modules = ["foo"] max_active_connections = 60 max_queued_connections = 60 +# Override the proxy-wide per_ip_max_active_connections for this +# upstream. 0 means inherit the proxy-wide default. +#per_ip_max_active_connections = 4 [upstreams.u1_auto] address = "127.0.0.1:1234" diff --git a/pkg/server/config.go b/pkg/server/config.go index ffe8efd..776ee0d 100644 --- a/pkg/server/config.go +++ b/pkg/server/config.go @@ -15,6 +15,11 @@ type Upstream struct { UseProxyProtocol bool `toml:"use_proxy_protocol"` MaxActiveConns int `toml:"max_active_connections"` MaxQueuedConns int `toml:"max_queued_connections"` + // PerIPMaxActiveConns overrides the proxy-wide + // per_ip_max_active_connections setting for this upstream. A + // value of 0 (the default, i.e. field omitted) means the + // upstream inherits the proxy-wide value. + PerIPMaxActiveConns int `toml:"per_ip_max_active_connections"` } type ProxySettings struct { @@ -35,6 +40,24 @@ type ProxySettings struct { // rsyncd.conf(5)), which is an I/O timeout. A common choice for // public mirrors is 600. RelayIdleTimeoutSecs int `toml:"relay_idle_timeout"` + // RelayMaxDurationSecs is the maximum total wall-clock duration + // (in seconds) of the bidirectional relay phase. When exceeded + // the proxy closes the connection regardless of activity. 0 (the + // default) disables this hard cap. rsync clients will typically + // reconnect and resume on the next run. + RelayMaxDurationSecs int `toml:"relay_max_duration"` + // TCPKeepAliveSecs enables TCP keepalive on accepted client + // connections and on dialed upstream connections. The value is + // the keepalive period in seconds; 0 (the default) leaves the + // OS-default keepalive behavior in place (typically: disabled or + // ~2 hours). Enabling this helps detect half-open connections + // (peer crashed, NAT reaped) within minutes rather than hours. + TCPKeepAliveSecs int `toml:"tcp_keepalive"` + // PerIPMaxActiveConns is the proxy-wide default for the per-IP + // per-upstream concurrency cap applied during the relay phase. + // 0 (the default) disables the limit. Each upstream may + // override this via [upstreams.X].per_ip_max_active_connections. + PerIPMaxActiveConns int `toml:"per_ip_max_active_connections"` } type Config struct { diff --git a/pkg/server/metrics.go b/pkg/server/metrics.go index fedfeb1..96be47f 100644 --- a/pkg/server/metrics.go +++ b/pkg/server/metrics.go @@ -97,6 +97,14 @@ func (s *Server) writePrometheusMetrics(w io.Writer, now time.Time) { prometheusEscapeLabelValue(u.Name), c.dialError.Load()) } + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_per_ip_rejected_total Total connections rejected by the per-IP per-upstream concurrency cap.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_per_ip_rejected_total counter") + for _, u := range upstreams { + c := s.getUpstreamCounters(u.Name) + _, _ = fmt.Fprintf(w, "rsync_proxy_per_ip_rejected_total{upstream=\"%s\"} %d\n", + prometheusEscapeLabelValue(u.Name), c.perIPRejected.Load()) + } + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_unknown_module_requests_total Total requests for unknown modules.") _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_unknown_module_requests_total counter") _, _ = fmt.Fprintf(w, "rsync_proxy_unknown_module_requests_total %d\n", s.unknownModuleCount.Load()) diff --git a/pkg/server/server.go b/pkg/server/server.go index 951cf90..b0da128 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -119,12 +119,26 @@ type upstreamConfig struct { DiscoverModules bool MaxActiveConns int MaxQueuedConns int + // PerIPMaxActiveConns is the resolved (effective) limit on the + // number of concurrent active relay connections from a single + // client IP to this upstream. 0 means no per-IP cap. Computed at + // load time as: per-upstream override (Upstream.PerIPMaxActiveConns) + // or, if zero, the proxy-wide default (Proxy.PerIPMaxActiveConns). + PerIPMaxActiveConns int } // upstreamCounters holds per-upstream failure counters. type upstreamCounters struct { - queueFull atomic.Uint64 - dialError atomic.Uint64 + queueFull atomic.Uint64 + dialError atomic.Uint64 + perIPRejected atomic.Uint64 +} + +// perIPCountKey identifies a (upstream, client IP) pair for tracking +// per-IP per-upstream concurrent active relay connections. +type perIPCountKey struct { + upstream string + ip string } type Server struct { @@ -145,6 +159,18 @@ type Server struct { // rsyncd.conf(5). RelayIdleTimeout time.Duration + // RelayMaxDuration is a hard cap on the total wall-clock + // duration of the bidirectional relay phase of a connection. + // When exceeded the proxy closes both directions regardless of + // activity. 0 (the default) disables the cap. + RelayMaxDuration time.Duration + + // TCPKeepAlive is the keepalive period applied to accepted + // client connections and to dialed upstream connections. 0 (the + // default) leaves the OS-default keepalive behavior in place + // (typically: disabled, or ~2 hours). + TCPKeepAlive time.Duration + Motd string // --- End of options section @@ -173,6 +199,11 @@ type Server struct { upstreamCounters sync.Map unknownModuleCount atomic.Uint64 + // Per-(upstream, client IP) active-connection counters. + // Lazy-initialized via getPerIPCounter. + // map key is perIPCountKey. Value is *atomic.Int64. + perIPCounts sync.Map + TCPListener net.Listener TLSListener net.Listener HTTPListener net.Listener @@ -241,6 +272,16 @@ func (s *Server) loadConfig(c *Config, openLog bool) error { tlsCertificate = &cert } + if c.Proxy.RelayMaxDurationSecs < 0 { + return fmt.Errorf("relay_max_duration must be non-negative, got %d", c.Proxy.RelayMaxDurationSecs) + } + if c.Proxy.TCPKeepAliveSecs < 0 { + return fmt.Errorf("tcp_keepalive must be non-negative, got %d", c.Proxy.TCPKeepAliveSecs) + } + if c.Proxy.PerIPMaxActiveConns < 0 { + return fmt.Errorf("per_ip_max_active_connections must be non-negative, got %d", c.Proxy.PerIPMaxActiveConns) + } + upstreams := make([]upstreamConfig, 0, len(c.Upstreams)) upstreamNames := make([]string, 0, len(c.Upstreams)) for upstreamName := range c.Upstreams { @@ -252,17 +293,27 @@ func (s *Server) loadConfig(c *Config, openLog bool) error { if len(v.Modules) == 0 && !v.DiscoverModules { return fmt.Errorf("upstream=%s must set modules or discover_modules", upstreamName) } + if v.PerIPMaxActiveConns < 0 { + return fmt.Errorf("upstream=%s: per_ip_max_active_connections must be non-negative, got %d", upstreamName, v.PerIPMaxActiveConns) + } addr := v.Address if err := validateTCPOrUnixAddr(addr); err != nil { return fmt.Errorf("resolve address: %w, upstream=%s, address=%s", err, upstreamName, addr) } + // Resolve effective per-IP cap: per-upstream override wins; + // fall back to proxy-wide default; 0 means no cap. + effectivePerIP := v.PerIPMaxActiveConns + if effectivePerIP == 0 { + effectivePerIP = c.Proxy.PerIPMaxActiveConns + } upstreams = append(upstreams, upstreamConfig{ - Name: upstreamName, - Target: Target{Upstream: upstreamName, Addr: addr, UseProxyProtocol: v.UseProxyProtocol}, - Modules: slices.Clone(v.Modules), - DiscoverModules: v.DiscoverModules, - MaxActiveConns: v.MaxActiveConns, - MaxQueuedConns: v.MaxQueuedConns, + Name: upstreamName, + Target: Target{Upstream: upstreamName, Addr: addr, UseProxyProtocol: v.UseProxyProtocol}, + Modules: slices.Clone(v.Modules), + DiscoverModules: v.DiscoverModules, + MaxActiveConns: v.MaxActiveConns, + MaxQueuedConns: v.MaxQueuedConns, + PerIPMaxActiveConns: effectivePerIP, }) } @@ -301,6 +352,12 @@ func (s *Server) loadConfig(c *Config, openLog bool) error { return fmt.Errorf("relay_idle_timeout must be non-negative, got %d", c.Proxy.RelayIdleTimeoutSecs) } s.RelayIdleTimeout = time.Duration(c.Proxy.RelayIdleTimeoutSecs) * time.Second + s.RelayMaxDuration = time.Duration(c.Proxy.RelayMaxDurationSecs) * time.Second + s.TCPKeepAlive = time.Duration(c.Proxy.TCPKeepAliveSecs) * time.Second + // Reflect the new keepalive setting on the dialer used to dial + // upstreams. The dialer is consulted under reloadLock via + // getDialer() to avoid racing with reloads. + s.dialer.KeepAlive = s.TCPKeepAlive s.modules = modules s.upstreams = resolvedUpstreams s.upstreamQueues = s.updateUpstreamQueuesLocked(resolvedUpstreams) @@ -316,12 +373,13 @@ func resolveUpstreams(upstreams []upstreamConfig, discovered map[string][]string modules = slices.Clone(discovered[upstream.Name]) } resolved = append(resolved, upstreamConfig{ - Name: upstream.Name, - Target: upstream.Target, - Modules: modules, - DiscoverModules: upstream.DiscoverModules, - MaxActiveConns: upstream.MaxActiveConns, - MaxQueuedConns: upstream.MaxQueuedConns, + Name: upstream.Name, + Target: upstream.Target, + Modules: modules, + DiscoverModules: upstream.DiscoverModules, + MaxActiveConns: upstream.MaxActiveConns, + MaxQueuedConns: upstream.MaxQueuedConns, + PerIPMaxActiveConns: upstream.PerIPMaxActiveConns, }) } return resolved @@ -365,6 +423,73 @@ func (s *Server) getUpstreamCounters(name string) *upstreamCounters { return v.(*upstreamCounters) } +// getPerIPCounter returns the (upstream, ip) active-connection counter, +// creating it lazily. Safe for concurrent use. +func (s *Server) getPerIPCounter(upstream, ip string) *atomic.Int64 { + key := perIPCountKey{upstream: upstream, ip: ip} + if v, ok := s.perIPCounts.Load(key); ok { + return v.(*atomic.Int64) + } + v, _ := s.perIPCounts.LoadOrStore(key, &atomic.Int64{}) + return v.(*atomic.Int64) +} + +// getDialer returns a copy of the current upstream dialer, taking the +// reload lock to avoid racing with config reloads that mutate the +// dialer (e.g. KeepAlive). The returned value is safe to pass by value +// to dialContextTCPOrUnix. +func (s *Server) getDialer() net.Dialer { + s.reloadLock.RLock() + defer s.reloadLock.RUnlock() + return s.dialer +} + +// getRelayTimings returns a snapshot of the relay-phase timing +// settings under the reload lock. These fields are mutated by config +// reloads, so callers must not read them directly. +func (s *Server) getRelayTimings() (idle, maxDuration, tcpKeepAlive time.Duration) { + s.reloadLock.RLock() + defer s.reloadLock.RUnlock() + return s.RelayIdleTimeout, s.RelayMaxDuration, s.TCPKeepAlive +} + +// getTCPKeepAlive returns the configured TCP keepalive period under +// the reload lock. +func (s *Server) getTCPKeepAlive() time.Duration { + s.reloadLock.RLock() + defer s.reloadLock.RUnlock() + return s.TCPKeepAlive +} + +// getPerIPLimitForUpstream returns the configured per-IP active +// connection cap for the named upstream, or 0 if none is set or the +// upstream is unknown. Safe for concurrent use. +func (s *Server) getPerIPLimitForUpstream(name string) int { + s.reloadLock.RLock() + defer s.reloadLock.RUnlock() + for i := range s.upstreams { + if s.upstreams[i].Name == name { + return s.upstreams[i].PerIPMaxActiveConns + } + } + return 0 +} + +// applyTCPKeepAlive enables TCP keepalive on the given connection if +// it is a *net.TCPConn and a positive period is provided. Other +// connection types (e.g. unix sockets) are silently ignored. +func applyTCPKeepAlive(conn net.Conn, period time.Duration) { + if period <= 0 { + return + } + tc, ok := conn.(*net.TCPConn) + if !ok { + return + } + _ = tc.SetKeepAlive(true) + _ = tc.SetKeepAlivePeriod(period) +} + func buildModuleTargets(upstreams []upstreamConfig) map[string][]Target { modules := map[string][]Target{} for _, upstream := range upstreams { @@ -632,6 +757,23 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err return fmt.Errorf("no queue configured for upstream %s", target.Upstream) } + // Per-IP per-upstream concurrency cap. A positive cap means the + // same client IP may not have more than N simultaneous active + // relay connections to this upstream. Counted before queue + // admission so the cap also bounds queueing, preventing a single + // IP from monopolizing both the active slots and the queue. + if perIPLimit := s.getPerIPLimitForUpstream(target.Upstream); perIPLimit > 0 { + counter := s.getPerIPCounter(target.Upstream, ip) + if n := counter.Add(1); int(n) > perIPLimit { + counter.Add(-1) + s.getUpstreamCounters(target.Upstream).perIPRejected.Add(1) + s.accessLog.F("client %s rejected for upstream %s module %s: per-IP cap of %d reached", ip, target.Upstream, moduleName, perIPLimit) + _, _ = writeWithTimeout(downConn, fmt.Appendf(nil, "@ERROR: per-IP connection limit of %d for upstream %s reached, retry later\n", perIPLimit, target.Upstream), writeTimeout) + return nil + } + defer counter.Add(-1) + } + handle := upstreamQueue.Acquire() defer handle.Release() status := <-handle.C @@ -670,12 +812,19 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err } } - upConn, err := dialContextTCPOrUnix(ctx, s.dialer, upstreamAddr) + upConn, err := dialContextTCPOrUnix(ctx, s.getDialer(), upstreamAddr) if err != nil { s.getUpstreamCounters(target.Upstream).dialError.Add(1) return fmt.Errorf("dial to upstream: %s: %w", upstreamAddr, err) } defer upConn.Close() + // Enable TCP keepalive on the upstream-side connection so that a + // dead/half-open peer is detected within the configured period + // rather than relying on the OS default (commonly ~2 hours, or + // disabled). dialer.KeepAlive only takes effect for the initial + // SYN window; explicitly setting it on the resulting *TCPConn + // is portable and idempotent. + applyTCPKeepAlive(upConn, s.getTCPKeepAlive()) upAddr := netAddrToString(upConn.RemoteAddr()) if useProxyProtocol { err := writeProxyProtocolHeader(upConn, downConn.RemoteAddr(), upConn.RemoteAddr(), s.WriteTimeout) @@ -724,27 +873,46 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err downReader := &countingReader{reader: downConn, counter: &info.ReceivedBytes} upReader := &countingReader{reader: upConn, counter: &info.SentBytes} - // If an idle timeout is configured, share a single lastActivity - // timestamp between the two directions and start a watcher - // goroutine that tears the connections down when no data has - // flowed in either direction for the configured duration. This - // mirrors rsyncd's "timeout" semantics (rsyncd.conf(5)). - idleTimeout := s.RelayIdleTimeout + // Optional watchers on the bidirectional relay phase: + // * RelayIdleTimeout terminates a connection that has had no + // data flow in either direction for the configured duration + // (rsyncd "timeout" semantics, rsyncd.conf(5)). + // * RelayMaxDuration is a hard wall-clock cap on the entire + // relay phase; any connection still alive past this point is + // terminated regardless of activity. + // A single goroutine handles both checks. The shared + // idleTimedOut flag is used by the io.Copy goroutines to + // suppress the expected "use of closed network connection" error + // log when the watcher initiated the close. + idleTimeout, maxDuration, _ := s.getRelayTimings() + relayStartedAt := time.Now() var idleTimedOut atomic.Bool sentClosed := make(chan struct{}) receivedClosed := make(chan struct{}) - if idleTimeout > 0 { - lastActivity := &atomic.Int64{} - lastActivity.Store(time.Now().UnixNano()) - downReader.lastActivity = lastActivity - upReader.lastActivity = lastActivity - - // Wake at most a few times per timeout window so a stalled - // connection is detected within roughly 1.25x the configured - // timeout in the worst case, while keeping wakeup overhead - // negligible. - interval := idleTimeout / 4 + if idleTimeout > 0 || maxDuration > 0 { + var lastActivity *atomic.Int64 + if idleTimeout > 0 { + lastActivity = &atomic.Int64{} + lastActivity.Store(relayStartedAt.UnixNano()) + downReader.lastActivity = lastActivity + upReader.lastActivity = lastActivity + } + + // Wake at most a few times per the smaller of the two + // configured windows so timeouts are detected within roughly + // 1.25x the configured value in the worst case, while keeping + // wakeup overhead negligible. + interval := time.Duration(0) + if idleTimeout > 0 { + interval = idleTimeout / 4 + } + if maxDuration > 0 { + d := maxDuration / 4 + if interval == 0 || d < interval { + interval = d + } + } if interval < time.Second { interval = time.Second } @@ -759,10 +927,19 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err case <-receivedClosed: return case now := <-ticker.C: - last := time.Unix(0, lastActivity.Load()) - if now.Sub(last) >= idleTimeout { + if idleTimeout > 0 { + last := time.Unix(0, lastActivity.Load()) + if now.Sub(last) >= idleTimeout { + idleTimedOut.Store(true) + s.accessLog.F("client %s idle for module %s exceeds %s, closing", ip, moduleName, idleTimeout) + _ = upConn.Close() + _ = downConn.Close() + return + } + } + if maxDuration > 0 && now.Sub(relayStartedAt) >= maxDuration { idleTimedOut.Store(true) - s.accessLog.F("client %s idle for module %s exceeds %s, closing", ip, moduleName, idleTimeout) + s.accessLog.F("client %s exceeded max relay duration %s for module %s, closing", ip, maxDuration, moduleName) _ = upConn.Close() _ = downConn.Close() return @@ -1005,6 +1182,12 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn) { s.acceptedConnCount.Add(1) connIndex := s.connIndex.Add(1) + // Apply TCP keepalive on the accepted client connection so that + // half-open connections (peer crashed, NAT entry reaped) are + // detected within the configured period rather than waiting for + // the OS default. No-op on unix-domain or non-TCP listeners. + applyTCPKeepAlive(conn, s.getTCPKeepAlive()) + defer func() { err := recover() if err != nil { diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index d2b3d2f..df1d230 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -1220,3 +1220,275 @@ func TestDiscoverModulesFromTrailingModuleBlock(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{"bar", "foo"}, modules) } + +// TestPerIPCapRejectsConnectionsBeyondLimit verifies that when a +// per-IP active connection cap is configured for an upstream, a +// second simultaneous connection from the same client IP to the same +// upstream is rejected with an @ERROR message, the perIPRejected +// counter is incremented, and the rejection is reflected in the +// /metrics output. +func TestPerIPCapRejectsConnectionsBeyondLimit(t *testing.T) { + srv := startServer(t) + defer srv.Close() + accessLogPath := setupAccessLog(t, srv) + + var release sync.WaitGroup + release.Add(1) + + upstream := rsync.NewServer(func(conn *rsync.Conn) { + defer conn.Close() + _, _, err := doServerHandshake(conn, RsyncdServerVersion) + require.NoError(t, err) + release.Wait() + }) + upstream.Start() + defer upstream.Close() + + srv.reloadLock.Lock() + srv.upstreams = []upstreamConfig{ + {Name: "u1", PerIPMaxActiveConns: 1}, + } + srv.modules = map[string][]Target{ + "fake": {{Upstream: "u1", Addr: upstream.Listener.Addr().String()}}, + } + srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} + srv.reloadLock.Unlock() + + // First connection occupies the per-IP slot. + c1Raw, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + require.NoError(t, err) + c1 := rsync.NewConn(c1Raw) + defer c1.Close() + _, err = doClientHandshake(c1, RsyncdServerVersion, "fake") + require.NoError(t, err) + + // Second connection from the same IP must be rejected. + c2Raw, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + require.NoError(t, err) + c2 := rsync.NewConn(c2Raw) + defer c2.Close() + _, err = doClientHandshake(c2, RsyncdServerVersion, "fake") + require.NoError(t, err) + + body, err := io.ReadAll(c2) + require.NoError(t, err) + assert.Contains(t, string(body), "@ERROR: per-IP connection limit") + assert.Contains(t, string(body), "for upstream u1") + + require.Eventually(t, func() bool { + return srv.getUpstreamCounters("u1").perIPRejected.Load() == 1 + }, time.Second, 10*time.Millisecond) + + resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") + require.NoError(t, err) + defer resp.Body.Close() + metricsBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + text := string(metricsBody) + assert.Contains(t, text, "# HELP rsync_proxy_per_ip_rejected_total") + assert.Contains(t, text, "# TYPE rsync_proxy_per_ip_rejected_total counter") + assert.Contains(t, text, "rsync_proxy_per_ip_rejected_total{upstream=\"u1\"} 1\n") + + logData, err := os.ReadFile(accessLogPath) + require.NoError(t, err) + assert.Contains(t, string(logData), "per-IP cap of 1 reached") + + release.Done() +} + +// TestRelayMaxDurationClosesLongConnection verifies that when +// RelayMaxDuration is configured, an otherwise-active relay (no idle +// timeout) is forcibly torn down once it exceeds the configured +// wall-clock cap. +func TestRelayMaxDurationClosesLongConnection(t *testing.T) { + srv := startServer(t) + defer srv.Close() + srv.RelayIdleTimeout = 0 + srv.RelayMaxDuration = 500 * time.Millisecond + accessLogPath := setupAccessLog(t, srv) + + r := require.New(t) + + upstreamReady := make(chan struct{}) + upstreamDone := make(chan struct{}) + + fakeRsync := rsync.NewServer(func(conn *rsync.Conn) { + defer conn.Close() + defer close(upstreamDone) + + _, _, err := doServerHandshake(conn, RsyncdServerVersion) + r.NoError(err, "upstream handshake") + close(upstreamReady) + + // Stay quiet so the relay phase has no I/O. With idle timeout + // disabled, only the max-duration watcher can tear this + // connection down. + _, _ = io.ReadAll(conn) + }) + fakeRsync.Start() + defer fakeRsync.Close() + + srv.modules = map[string][]Target{ + "fake": {{Upstream: "u1", Addr: fakeRsync.Listener.Addr().String()}}, + } + srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} + + rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + r.NoError(err) + conn := rsync.NewConn(rawConn) + defer conn.Close() + + _, err = doClientHandshake(conn, RsyncdServerVersion, "fake") + r.NoError(err) + + <-upstreamReady + + start := time.Now() + _, err = io.ReadAll(conn) + r.NoError(err) + elapsed := time.Since(start) + + r.GreaterOrEqual(elapsed, srv.RelayMaxDuration, "should have waited at least the max duration") + r.Less(elapsed, 5*time.Second, "should not block far beyond the max duration") + + select { + case <-upstreamDone: + case <-time.After(2 * time.Second): + t.Fatal("upstream connection was not closed after max duration") + } + + logData, err := os.ReadFile(accessLogPath) + r.NoError(err) + assert.Contains(t, string(logData), "exceeded max relay duration") + assert.Contains(t, string(logData), "for module fake") +} + +// TestApplyTCPKeepAliveOnTCPConn exercises the applyTCPKeepAlive +// helper end-to-end on a real *net.TCPConn. We cannot portably read +// SO_KEEPALIVE/TCP_KEEPIDLE back via the standard library, so we +// verify that the helper does not error and is safe to invoke with +// zero or negative periods. +func TestApplyTCPKeepAliveOnTCPConn(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + accepted := make(chan net.Conn, 1) + go func() { + c, err := ln.Accept() + if err != nil { + accepted <- nil + return + } + accepted <- c + }() + + clientConn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + defer clientConn.Close() + + srvConn := <-accepted + require.NotNil(t, srvConn) + defer srvConn.Close() + + // Should be a no-op for a zero/negative period. + applyTCPKeepAlive(srvConn, 0) + applyTCPKeepAlive(srvConn, -1*time.Second) + + // Should successfully apply the keepalive on a real *net.TCPConn. + applyTCPKeepAlive(srvConn, 45*time.Second) + _, ok := srvConn.(*net.TCPConn) + assert.True(t, ok, "test sanity: accepted conn should be *net.TCPConn") +} + +// TestLoadConfigPropagatesTCPKeepAliveToDialer verifies that the +// tcp_keepalive proxy setting is parsed, validated, and propagated to +// both the Server.TCPKeepAlive field and the underlying dialer used +// for upstream connections. +func TestLoadConfigPropagatesTCPKeepAliveToDialer(t *testing.T) { + configContent := ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +tcp_keepalive = 45 +relay_max_duration = 7200 +per_ip_max_active_connections = 4 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +` + srv := New() + require.NoError(t, srv.ReadConfig(strings.NewReader(configContent), false)) + + assert.Equal(t, 45*time.Second, srv.TCPKeepAlive) + assert.Equal(t, 45*time.Second, srv.dialer.KeepAlive) + assert.Equal(t, 2*time.Hour, srv.RelayMaxDuration) + + srv.reloadLock.RLock() + defer srv.reloadLock.RUnlock() + require.Len(t, srv.upstreams, 1) + assert.Equal(t, 4, srv.upstreams[0].PerIPMaxActiveConns) +} + +// TestLoadConfigRejectsNegativeTimings verifies that loadConfig +// rejects negative values for the new connection-control settings. +func TestLoadConfigRejectsNegativeTimings(t *testing.T) { + cases := []struct { + name string + config string + wantMsg string + }{ + { + name: "relay_max_duration", + config: ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +relay_max_duration = -1 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +`, + wantMsg: "relay_max_duration must be non-negative", + }, + { + name: "tcp_keepalive", + config: ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +tcp_keepalive = -5 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +`, + wantMsg: "tcp_keepalive must be non-negative", + }, + { + name: "per_ip_max_active_connections", + config: ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +per_ip_max_active_connections = -2 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +`, + wantMsg: "per_ip_max_active_connections must be non-negative", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + srv := New() + err := srv.ReadConfig(strings.NewReader(tc.config), false) + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantMsg) + }) + } +} From 14abc91f8f13080953b0233b8c4565af6dd97fe3 Mon Sep 17 00:00:00 2001 From: yaoge123 Date: Sun, 31 May 2026 14:23:02 +0800 Subject: [PATCH 3/4] feat: add dial timeout, throughput floor, and termination-cause counters Three additions to the connection-controls family that close gaps the existing idle/max-duration/per-IP/keepalive controls leave open and make the existing teardown paths observable. dial_timeout (proxy-wide) Bounds upstream Dial latency, sourced from net.Dialer.Timeout. The previous default behaviour relied on the kernel SYN-retry budget (~75s on Linux), which let unreachable backends pile up handshake- blocked goroutines and exhaust per-upstream queues. min_throughput_bytes / min_throughput_window / min_throughput_grace Sliding-window throughput floor for the relay phase. Idle timeout cannot catch slow drips (a client that sips a few hundred bytes every second never goes idle). The watcher samples info.SentBytes+info.ReceivedBytes once per window after grace and tears the connection down if the delta falls below the configured byte budget. min_throughput_grace defaults to min_throughput_window when unset to avoid false positives during handshake. Termination-cause counters rsync_proxy_relay_idle_timeout_terminated_total{upstream} rsync_proxy_relay_max_duration_terminated_total{upstream} rsync_proxy_throughput_floor_terminated_total{upstream} Until now an operator could not distinguish a watcher-killed connection from a normal close on the lifetime counter; these separate the three forced-shutdown paths. Watcher restructured to share one ticker across all three checks; the interval is min(idle/4, maxDuration/4, window/4) bounded at 1s. New config values are validated for non-negative seconds. The race-safe helper getRelayTimings() now returns a relayTimings struct so the RLock-guarded snapshot covers all six fields atomically. Tests cover counter increments and metrics output for idle, max- duration and throughput-floor paths, propagation of dial_timeout and min_throughput_* settings into the dialer and Server fields, the grace-defaults-to-window fallback, and rejection of negative values for all four new settings. --- assets/config.example.toml | 18 +++ pkg/server/config.go | 22 ++++ pkg/server/metrics.go | 24 ++++ pkg/server/server.go | 144 ++++++++++++++++++++--- pkg/server/server_test.go | 230 +++++++++++++++++++++++++++++++++++++ 5 files changed, 419 insertions(+), 19 deletions(-) diff --git a/assets/config.example.toml b/assets/config.example.toml index a7fd223..22c841c 100644 --- a/assets/config.example.toml +++ b/assets/config.example.toml @@ -39,6 +39,24 @@ motd = "Served by rsync-proxy (https://github.com/ustclug/rsync-proxy)" # per_ip_max_active_connections setting below. #per_ip_max_active_connections = 0 +# Connect timeout (seconds) when dialing upstreams. Bounds how long +# the proxy waits for an unresponsive upstream before failing fast +# and incrementing rsync_proxy_upstream_dial_errors_total. 0 (the +# default) leaves the OS connect-attempt behavior in place +# (typically ~75s of SYN retries on Linux). +#dial_timeout = 0 + +# Throughput floor enforced during the relay phase. A connection +# that transfers fewer than min_throughput_bytes over the most +# recent min_throughput_window seconds is treated as slow-leeching +# and terminated. min_throughput_grace gives a fresh connection +# this many seconds of ramp-up before the floor is enforced; +# defaults to min_throughput_window when unset. Set +# min_throughput_bytes to 0 (the default) to disable the check. +#min_throughput_bytes = 0 +#min_throughput_window = 60 +#min_throughput_grace = 60 + [upstreams.u1] address = "127.0.0.1:1234" modules = ["foo"] diff --git a/pkg/server/config.go b/pkg/server/config.go index 776ee0d..c6aabd3 100644 --- a/pkg/server/config.go +++ b/pkg/server/config.go @@ -58,6 +58,28 @@ type ProxySettings struct { // 0 (the default) disables the limit. Each upstream may // override this via [upstreams.X].per_ip_max_active_connections. PerIPMaxActiveConns int `toml:"per_ip_max_active_connections"` + // DialTimeoutSecs caps how long the proxy waits when dialing an + // upstream rsync server. A value of 0 (the default) leaves the + // OS-default TCP connect behavior in place (~75s on Linux). When + // an upstream is unreachable this surfaces a fast failure to + // the client and increments rsync_proxy_upstream_dial_errors_total + // without tying up the listener for the full kernel SYN-retry + // budget. + DialTimeoutSecs int `toml:"dial_timeout"` + // MinThroughputBytes, MinThroughputWindowSecs and + // MinThroughputGraceSecs together implement a throughput floor + // during the relay phase. Within any window of + // min_throughput_window seconds the connection must transfer at + // least min_throughput_bytes (counting both directions); if it + // does not, the proxy closes the connection. The grace period + // suppresses the check for the first min_throughput_grace seconds + // after the relay starts to avoid killing a slow-start session. + // Setting min_throughput_bytes or min_throughput_window to 0 (the + // default) disables the floor. min_throughput_grace defaults to + // the value of min_throughput_window when 0. + MinThroughputBytes int64 `toml:"min_throughput_bytes"` + MinThroughputWindowSecs int `toml:"min_throughput_window"` + MinThroughputGraceSecs int `toml:"min_throughput_grace"` } type Config struct { diff --git a/pkg/server/metrics.go b/pkg/server/metrics.go index 96be47f..37ee1dc 100644 --- a/pkg/server/metrics.go +++ b/pkg/server/metrics.go @@ -105,6 +105,30 @@ func (s *Server) writePrometheusMetrics(w io.Writer, now time.Time) { prometheusEscapeLabelValue(u.Name), c.perIPRejected.Load()) } + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_relay_idle_timeout_terminated_total Total relay connections terminated by the idle timeout watcher per upstream.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_relay_idle_timeout_terminated_total counter") + for _, u := range upstreams { + c := s.getUpstreamCounters(u.Name) + _, _ = fmt.Fprintf(w, "rsync_proxy_relay_idle_timeout_terminated_total{upstream=\"%s\"} %d\n", + prometheusEscapeLabelValue(u.Name), c.idleTerminated.Load()) + } + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_relay_max_duration_terminated_total Total relay connections terminated for exceeding the max-duration cap per upstream.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_relay_max_duration_terminated_total counter") + for _, u := range upstreams { + c := s.getUpstreamCounters(u.Name) + _, _ = fmt.Fprintf(w, "rsync_proxy_relay_max_duration_terminated_total{upstream=\"%s\"} %d\n", + prometheusEscapeLabelValue(u.Name), c.maxDurationTerminated.Load()) + } + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_throughput_floor_terminated_total Total relay connections terminated for falling below the configured throughput floor per upstream.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_throughput_floor_terminated_total counter") + for _, u := range upstreams { + c := s.getUpstreamCounters(u.Name) + _, _ = fmt.Fprintf(w, "rsync_proxy_throughput_floor_terminated_total{upstream=\"%s\"} %d\n", + prometheusEscapeLabelValue(u.Name), c.throughputFloorTerminated.Load()) + } + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_unknown_module_requests_total Total requests for unknown modules.") _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_unknown_module_requests_total counter") _, _ = fmt.Fprintf(w, "rsync_proxy_unknown_module_requests_total %d\n", s.unknownModuleCount.Load()) diff --git a/pkg/server/server.go b/pkg/server/server.go index b0da128..f29ef1e 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -129,9 +129,12 @@ type upstreamConfig struct { // upstreamCounters holds per-upstream failure counters. type upstreamCounters struct { - queueFull atomic.Uint64 - dialError atomic.Uint64 - perIPRejected atomic.Uint64 + queueFull atomic.Uint64 + dialError atomic.Uint64 + perIPRejected atomic.Uint64 + idleTerminated atomic.Uint64 + maxDurationTerminated atomic.Uint64 + throughputFloorTerminated atomic.Uint64 } // perIPCountKey identifies a (upstream, client IP) pair for tracking @@ -171,6 +174,19 @@ type Server struct { // (typically: disabled, or ~2 hours). TCPKeepAlive time.Duration + // MinThroughputBytes, MinThroughputWindow and + // MinThroughputGrace together implement a throughput floor + // during the relay phase. Within any window of + // MinThroughputWindow the connection must transfer at least + // MinThroughputBytes (counting both directions); otherwise the + // proxy closes the connection. The grace period suppresses the + // check for the first MinThroughputGrace after the relay + // starts. Setting MinThroughputBytes or MinThroughputWindow to + // zero disables the floor. + MinThroughputBytes int64 + MinThroughputWindow time.Duration + MinThroughputGrace time.Duration + Motd string // --- End of options section @@ -281,6 +297,18 @@ func (s *Server) loadConfig(c *Config, openLog bool) error { if c.Proxy.PerIPMaxActiveConns < 0 { return fmt.Errorf("per_ip_max_active_connections must be non-negative, got %d", c.Proxy.PerIPMaxActiveConns) } + if c.Proxy.DialTimeoutSecs < 0 { + return fmt.Errorf("dial_timeout must be non-negative, got %d", c.Proxy.DialTimeoutSecs) + } + if c.Proxy.MinThroughputBytes < 0 { + return fmt.Errorf("min_throughput_bytes must be non-negative, got %d", c.Proxy.MinThroughputBytes) + } + if c.Proxy.MinThroughputWindowSecs < 0 { + return fmt.Errorf("min_throughput_window must be non-negative, got %d", c.Proxy.MinThroughputWindowSecs) + } + if c.Proxy.MinThroughputGraceSecs < 0 { + return fmt.Errorf("min_throughput_grace must be non-negative, got %d", c.Proxy.MinThroughputGraceSecs) + } upstreams := make([]upstreamConfig, 0, len(c.Upstreams)) upstreamNames := make([]string, 0, len(c.Upstreams)) @@ -354,10 +382,21 @@ func (s *Server) loadConfig(c *Config, openLog bool) error { s.RelayIdleTimeout = time.Duration(c.Proxy.RelayIdleTimeoutSecs) * time.Second s.RelayMaxDuration = time.Duration(c.Proxy.RelayMaxDurationSecs) * time.Second s.TCPKeepAlive = time.Duration(c.Proxy.TCPKeepAliveSecs) * time.Second - // Reflect the new keepalive setting on the dialer used to dial - // upstreams. The dialer is consulted under reloadLock via - // getDialer() to avoid racing with reloads. + s.MinThroughputBytes = c.Proxy.MinThroughputBytes + s.MinThroughputWindow = time.Duration(c.Proxy.MinThroughputWindowSecs) * time.Second + graceSecs := c.Proxy.MinThroughputGraceSecs + if graceSecs == 0 { + // Default the grace period to the window itself: a fresh + // connection gets one full window to ramp up before the + // floor is enforced. + graceSecs = c.Proxy.MinThroughputWindowSecs + } + s.MinThroughputGrace = time.Duration(graceSecs) * time.Second + // Reflect the new keepalive and dial-timeout settings on the + // dialer used to dial upstreams. The dialer is consulted under + // reloadLock via getDialer() to avoid racing with reloads. s.dialer.KeepAlive = s.TCPKeepAlive + s.dialer.Timeout = time.Duration(c.Proxy.DialTimeoutSecs) * time.Second s.modules = modules s.upstreams = resolvedUpstreams s.upstreamQueues = s.updateUpstreamQueuesLocked(resolvedUpstreams) @@ -444,13 +483,33 @@ func (s *Server) getDialer() net.Dialer { return s.dialer } +// relayTimings is a snapshot of all relay-phase timing settings, +// captured atomically under the reload lock. Callers should consume +// this snapshot for the lifetime of a single relay connection so the +// behavior remains consistent across reloads. +type relayTimings struct { + idle time.Duration + maxDuration time.Duration + tcpKeepAlive time.Duration + minBytes int64 + minWindow time.Duration + minGrace time.Duration +} + // getRelayTimings returns a snapshot of the relay-phase timing // settings under the reload lock. These fields are mutated by config // reloads, so callers must not read them directly. -func (s *Server) getRelayTimings() (idle, maxDuration, tcpKeepAlive time.Duration) { +func (s *Server) getRelayTimings() relayTimings { s.reloadLock.RLock() defer s.reloadLock.RUnlock() - return s.RelayIdleTimeout, s.RelayMaxDuration, s.TCPKeepAlive + return relayTimings{ + idle: s.RelayIdleTimeout, + maxDuration: s.RelayMaxDuration, + tcpKeepAlive: s.TCPKeepAlive, + minBytes: s.MinThroughputBytes, + minWindow: s.MinThroughputWindow, + minGrace: s.MinThroughputGrace, + } } // getTCPKeepAlive returns the configured TCP keepalive period under @@ -880,17 +939,31 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err // * RelayMaxDuration is a hard wall-clock cap on the entire // relay phase; any connection still alive past this point is // terminated regardless of activity. - // A single goroutine handles both checks. The shared - // idleTimedOut flag is used by the io.Copy goroutines to + // * MinThroughputBytes/MinThroughputWindow enforce a minimum + // average throughput on the relay; a connection that has + // transferred fewer than minBytes over the last window is + // considered to be slow-leeching and terminated. The grace + // period suppresses the check during initial ramp-up so + // short rsync exchanges (file lists, etc.) are not killed. + // A single goroutine handles all checks. The shared + // watcherClosed flag is used by the io.Copy goroutines to // suppress the expected "use of closed network connection" error - // log when the watcher initiated the close. - idleTimeout, maxDuration, _ := s.getRelayTimings() + // log when the watcher initiated the close. The reason for the + // closure is recorded directly by the watcher into per-upstream + // counters and access log. + timings := s.getRelayTimings() + idleTimeout := timings.idle + maxDuration := timings.maxDuration + minBytes := timings.minBytes + minWindow := timings.minWindow + minGrace := timings.minGrace + throughputEnabled := minBytes > 0 && minWindow > 0 relayStartedAt := time.Now() var idleTimedOut atomic.Bool sentClosed := make(chan struct{}) receivedClosed := make(chan struct{}) - if idleTimeout > 0 || maxDuration > 0 { + if idleTimeout > 0 || maxDuration > 0 || throughputEnabled { var lastActivity *atomic.Int64 if idleTimeout > 0 { lastActivity = &atomic.Int64{} @@ -899,27 +972,43 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err upReader.lastActivity = lastActivity } - // Wake at most a few times per the smaller of the two + // Wake at most a few times per the smallest of the // configured windows so timeouts are detected within roughly // 1.25x the configured value in the worst case, while keeping // wakeup overhead negligible. interval := time.Duration(0) - if idleTimeout > 0 { - interval = idleTimeout / 4 - } - if maxDuration > 0 { - d := maxDuration / 4 + bumpInterval := func(d time.Duration) { + if d <= 0 { + return + } if interval == 0 || d < interval { interval = d } } + if idleTimeout > 0 { + bumpInterval(idleTimeout / 4) + } + if maxDuration > 0 { + bumpInterval(maxDuration / 4) + } + if throughputEnabled { + bumpInterval(minWindow / 4) + } if interval < time.Second { interval = time.Second } + upstreamName := target.Upstream go func() { ticker := time.NewTicker(interval) defer ticker.Stop() + // Sliding-window throughput state. lastSampleTime moves + // forward each time we reach a full window, with + // lastSampleBytes capturing the cumulative byte count at + // that moment. delta over the next window must meet + // minBytes or the connection is terminated. + lastSampleTime := relayStartedAt + lastSampleBytes := int64(0) for { select { case <-sentClosed: @@ -931,6 +1020,7 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err last := time.Unix(0, lastActivity.Load()) if now.Sub(last) >= idleTimeout { idleTimedOut.Store(true) + s.getUpstreamCounters(upstreamName).idleTerminated.Add(1) s.accessLog.F("client %s idle for module %s exceeds %s, closing", ip, moduleName, idleTimeout) _ = upConn.Close() _ = downConn.Close() @@ -939,11 +1029,27 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err } if maxDuration > 0 && now.Sub(relayStartedAt) >= maxDuration { idleTimedOut.Store(true) + s.getUpstreamCounters(upstreamName).maxDurationTerminated.Add(1) s.accessLog.F("client %s exceeded max relay duration %s for module %s, closing", ip, maxDuration, moduleName) _ = upConn.Close() _ = downConn.Close() return } + if throughputEnabled && now.Sub(relayStartedAt) >= minGrace && now.Sub(lastSampleTime) >= minWindow { + curBytes := info.SentBytes.Load() + info.ReceivedBytes.Load() + delta := curBytes - lastSampleBytes + if delta < minBytes { + idleTimedOut.Store(true) + s.getUpstreamCounters(upstreamName).throughputFloorTerminated.Add(1) + s.accessLog.F("client %s for module %s below throughput floor (%d bytes < %d bytes in %s), closing", + ip, moduleName, delta, minBytes, minWindow) + _ = upConn.Close() + _ = downConn.Close() + return + } + lastSampleTime = now + lastSampleBytes = curBytes + } } } }() diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index df1d230..5a0995f 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -244,6 +244,7 @@ func TestRelayIdleTimeoutClosesIdleConnection(t *testing.T) { srv.modules = map[string][]Target{ "fake": {{Upstream: "u1", Addr: fakeRsync.Listener.Addr().String()}}, } + srv.upstreams = []upstreamConfig{{Name: "u1"}} srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) @@ -277,6 +278,24 @@ func TestRelayIdleTimeoutClosesIdleConnection(t *testing.T) { logData, err := os.ReadFile(accessLogPath) r.NoError(err) assert.Contains(t, string(logData), "idle for module fake exceeds") + + // The watcher must record the termination reason on the + // per-upstream counter and surface it via /metrics so that + // operators can distinguish a hung-and-killed connection from a + // normal completion. + r.Eventually(func() bool { + return srv.getUpstreamCounters("u1").idleTerminated.Load() == 1 + }, time.Second, 10*time.Millisecond) + + resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") + r.NoError(err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + r.NoError(err) + text := string(body) + assert.Contains(t, text, "# HELP rsync_proxy_relay_idle_timeout_terminated_total") + assert.Contains(t, text, "# TYPE rsync_proxy_relay_idle_timeout_terminated_total counter") + assert.Contains(t, text, "rsync_proxy_relay_idle_timeout_terminated_total{upstream=\"u1\"} 1\n") } // TestRelayIdleTimeoutNotTriggeredWhenActive verifies that the idle @@ -1331,6 +1350,7 @@ func TestRelayMaxDurationClosesLongConnection(t *testing.T) { srv.modules = map[string][]Target{ "fake": {{Upstream: "u1", Addr: fakeRsync.Listener.Addr().String()}}, } + srv.upstreams = []upstreamConfig{{Name: "u1"}} srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) @@ -1361,6 +1381,20 @@ func TestRelayMaxDurationClosesLongConnection(t *testing.T) { r.NoError(err) assert.Contains(t, string(logData), "exceeded max relay duration") assert.Contains(t, string(logData), "for module fake") + + r.Eventually(func() bool { + return srv.getUpstreamCounters("u1").maxDurationTerminated.Load() == 1 + }, time.Second, 10*time.Millisecond) + + resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") + r.NoError(err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + r.NoError(err) + text := string(body) + assert.Contains(t, text, "# HELP rsync_proxy_relay_max_duration_terminated_total") + assert.Contains(t, text, "# TYPE rsync_proxy_relay_max_duration_terminated_total counter") + assert.Contains(t, text, "rsync_proxy_relay_max_duration_terminated_total{upstream=\"u1\"} 1\n") } // TestApplyTCPKeepAliveOnTCPConn exercises the applyTCPKeepAlive @@ -1481,6 +1515,62 @@ modules = ["m1"] `, wantMsg: "per_ip_max_active_connections must be non-negative", }, + { + name: "dial_timeout", + config: ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +dial_timeout = -1 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +`, + wantMsg: "dial_timeout must be non-negative", + }, + { + name: "min_throughput_bytes", + config: ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +min_throughput_bytes = -1 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +`, + wantMsg: "min_throughput_bytes must be non-negative", + }, + { + name: "min_throughput_window", + config: ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +min_throughput_window = -1 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +`, + wantMsg: "min_throughput_window must be non-negative", + }, + { + name: "min_throughput_grace", + config: ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +min_throughput_grace = -1 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +`, + wantMsg: "min_throughput_grace must be non-negative", + }, } for _, tc := range cases { @@ -1492,3 +1582,143 @@ modules = ["m1"] }) } } + +// TestThroughputFloorTerminatesSlowConnection verifies that when a +// throughput floor is configured, an otherwise-active relay (no idle +// timeout, no max-duration) which transfers fewer than MinThroughputBytes +// per MinThroughputWindow is forcibly torn down once the grace period +// elapses, the per-upstream counter is incremented, and the rejection +// is reflected in /metrics and the access log. +func TestThroughputFloorTerminatesSlowConnection(t *testing.T) { + srv := startServer(t) + defer srv.Close() + srv.RelayIdleTimeout = 0 + srv.RelayMaxDuration = 0 + // Demand effectively infeasible throughput in a 200ms window so + // the floor always trips. Grace = 0 so the very first sample + // after the window elapses is enough. + srv.MinThroughputBytes = 1 << 30 + srv.MinThroughputWindow = 200 * time.Millisecond + srv.MinThroughputGrace = 0 + accessLogPath := setupAccessLog(t, srv) + + r := require.New(t) + + upstreamReady := make(chan struct{}) + upstreamDone := make(chan struct{}) + + fakeRsync := rsync.NewServer(func(conn *rsync.Conn) { + defer conn.Close() + defer close(upstreamDone) + + _, _, err := doServerHandshake(conn, RsyncdServerVersion) + r.NoError(err, "upstream handshake") + close(upstreamReady) + + // Stay quiet so the relay phase has effectively zero throughput. + _, _ = io.ReadAll(conn) + }) + fakeRsync.Start() + defer fakeRsync.Close() + + srv.modules = map[string][]Target{ + "fake": {{Upstream: "u1", Addr: fakeRsync.Listener.Addr().String()}}, + } + srv.upstreams = []upstreamConfig{{Name: "u1"}} + srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} + + rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + r.NoError(err) + conn := rsync.NewConn(rawConn) + defer conn.Close() + + _, err = doClientHandshake(conn, RsyncdServerVersion, "fake") + r.NoError(err) + + <-upstreamReady + + start := time.Now() + _, err = io.ReadAll(conn) + r.NoError(err) + elapsed := time.Since(start) + + r.GreaterOrEqual(elapsed, srv.MinThroughputWindow, + "should have waited at least one throughput window before tearing down") + r.Less(elapsed, 5*time.Second, "should not block far beyond the throughput window") + + select { + case <-upstreamDone: + case <-time.After(2 * time.Second): + t.Fatal("upstream connection was not closed after throughput floor trip") + } + + r.Eventually(func() bool { + return srv.getUpstreamCounters("u1").throughputFloorTerminated.Load() == 1 + }, time.Second, 10*time.Millisecond) + + resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") + r.NoError(err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + r.NoError(err) + text := string(body) + assert.Contains(t, text, "# HELP rsync_proxy_throughput_floor_terminated_total") + assert.Contains(t, text, "# TYPE rsync_proxy_throughput_floor_terminated_total counter") + assert.Contains(t, text, "rsync_proxy_throughput_floor_terminated_total{upstream=\"u1\"} 1\n") + + logData, err := os.ReadFile(accessLogPath) + r.NoError(err) + assert.Contains(t, string(logData), "below throughput floor") + assert.Contains(t, string(logData), "for module fake") +} + +// TestLoadConfigPropagatesDialTimeoutAndThroughputSettings verifies +// that dial_timeout and the min_throughput_* settings are parsed, +// validated, and propagated into the dialer and Server fields. It also +// verifies that an unset min_throughput_grace falls back to the +// configured min_throughput_window, mirroring the documented default. +func TestLoadConfigPropagatesDialTimeoutAndThroughputSettings(t *testing.T) { + t.Run("explicit grace", func(t *testing.T) { + configContent := ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +dial_timeout = 5 +min_throughput_bytes = 65536 +min_throughput_window = 60 +min_throughput_grace = 30 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +` + srv := New() + require.NoError(t, srv.ReadConfig(strings.NewReader(configContent), false)) + + assert.Equal(t, 5*time.Second, srv.dialer.Timeout) + assert.Equal(t, int64(65536), srv.MinThroughputBytes) + assert.Equal(t, 60*time.Second, srv.MinThroughputWindow) + assert.Equal(t, 30*time.Second, srv.MinThroughputGrace) + }) + + t.Run("grace defaults to window", func(t *testing.T) { + configContent := ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +min_throughput_bytes = 1024 +min_throughput_window = 90 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +` + srv := New() + require.NoError(t, srv.ReadConfig(strings.NewReader(configContent), false)) + + assert.Equal(t, int64(1024), srv.MinThroughputBytes) + assert.Equal(t, 90*time.Second, srv.MinThroughputWindow) + assert.Equal(t, 90*time.Second, srv.MinThroughputGrace, + "unset min_throughput_grace must default to min_throughput_window") + }) +} From af5c3350dc3da58c85680a500cccf3c09a8093a0 Mon Sep 17 00:00:00 2001 From: yaoge123 Date: Tue, 2 Jun 2026 20:53:10 +0800 Subject: [PATCH 4/4] fix: address review comments on PR #37 - Use int64 comparison for per-IP cap. perIPLimit values from config are always small (typical caps are single-digit to low-double-digit), so the practical risk of overflow on 32-bit GOARCH was zero. Still, comparing the int64 counter against int64(perIPLimit) is more correct and removes the conversion entirely from review noise. - Clarify the watcher interval comment. The 1.25x worst-case detection claim only holds when all enabled timeouts are above 4 seconds, because of the explicit 1s floor on the wakeup interval. Document the trade-off explicitly so future maintainers know that sub-second timeouts (used only in tests) trade detection latency for kernel timer overhead. --- pkg/server/server.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pkg/server/server.go b/pkg/server/server.go index f29ef1e..4be212a 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -823,7 +823,7 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err // IP from monopolizing both the active slots and the queue. if perIPLimit := s.getPerIPLimitForUpstream(target.Upstream); perIPLimit > 0 { counter := s.getPerIPCounter(target.Upstream, ip) - if n := counter.Add(1); int(n) > perIPLimit { + if n := counter.Add(1); n > int64(perIPLimit) { counter.Add(-1) s.getUpstreamCounters(target.Upstream).perIPRejected.Add(1) s.accessLog.F("client %s rejected for upstream %s module %s: per-IP cap of %d reached", ip, target.Upstream, moduleName, perIPLimit) @@ -972,10 +972,14 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err upReader.lastActivity = lastActivity } - // Wake at most a few times per the smallest of the - // configured windows so timeouts are detected within roughly - // 1.25x the configured value in the worst case, while keeping - // wakeup overhead negligible. + // Pick a wakeup interval as 1/4 of the smallest enabled window + // so the worst-case detection latency is roughly 1.25x the + // configured value, while keeping wakeup overhead negligible. + // A 1s floor is then applied: detection latency therefore + // degrades for sub-4-second settings (e.g. a 500ms idle + // timeout will detect inside ~1s + the 500ms slack rather + // than 125ms). In practice all production-meaningful values + // are well above 4 seconds, so this only matters for tests. interval := time.Duration(0) bumpInterval := func(d time.Duration) { if d <= 0 {