diff --git a/assets/config.example.toml b/assets/config.example.toml index 02c3b7e..22c841c 100644 --- a/assets/config.example.toml +++ b/assets/config.example.toml @@ -11,11 +11,60 @@ tls_key_file = "/etc/rsync-proxy/tls/server.key" motd = "Served by rsync-proxy (https://github.com/ustclug/rsync-proxy)" +# Idle timeout (seconds) applied during the bidirectional relay phase +# of a client connection. If no data flows in either direction for +# this duration, the connection is terminated. 0 (the default) +# disables the timeout. This mirrors rsyncd's "timeout" setting in +# rsyncd.conf(5); 600 is a common choice for public mirrors. +#relay_idle_timeout = 0 + +# Hard upper bound (seconds) on the total wall-clock duration of the +# bidirectional relay phase. When exceeded the proxy closes the +# connection regardless of activity. rsync clients typically resume +# on reconnect. 0 (the default) disables this cap. +#relay_max_duration = 0 + +# TCP keepalive period (seconds) applied to accepted client +# connections and dialed upstream connections. Helps detect +# half-open connections (peer crashed, NAT entry reaped) within +# minutes rather than the OS default (~2 hours). 0 (the default) +# leaves the OS-default keepalive behavior in place. +#tcp_keepalive = 0 + +# Default per-IP per-upstream concurrent active connection cap. +# Limits how many simultaneous active relay connections a single +# client IP may have to any one upstream, preventing a single client +# from monopolizing upstream slots. 0 (the default) disables the +# limit. Each upstream may override this via its own +# per_ip_max_active_connections setting below. +#per_ip_max_active_connections = 0 + +# Connect timeout (seconds) when dialing upstreams. Bounds how long +# the proxy waits for an unresponsive upstream before failing fast +# and incrementing rsync_proxy_upstream_dial_errors_total. 0 (the +# default) leaves the OS connect-attempt behavior in place +# (typically ~75s of SYN retries on Linux). +#dial_timeout = 0 + +# Throughput floor enforced during the relay phase. A connection +# that transfers fewer than min_throughput_bytes over the most +# recent min_throughput_window seconds is treated as slow-leeching +# and terminated. min_throughput_grace gives a fresh connection +# this many seconds of ramp-up before the floor is enforced; +# defaults to min_throughput_window when unset. Set +# min_throughput_bytes to 0 (the default) to disable the check. +#min_throughput_bytes = 0 +#min_throughput_window = 60 +#min_throughput_grace = 60 + [upstreams.u1] address = "127.0.0.1:1234" modules = ["foo"] max_active_connections = 60 max_queued_connections = 60 +# Override the proxy-wide per_ip_max_active_connections for this +# upstream. 0 means inherit the proxy-wide default. +#per_ip_max_active_connections = 4 [upstreams.u1_auto] address = "127.0.0.1:1234" diff --git a/pkg/server/config.go b/pkg/server/config.go index f9ed120..c6aabd3 100644 --- a/pkg/server/config.go +++ b/pkg/server/config.go @@ -15,6 +15,11 @@ type Upstream struct { UseProxyProtocol bool `toml:"use_proxy_protocol"` MaxActiveConns int `toml:"max_active_connections"` MaxQueuedConns int `toml:"max_queued_connections"` + // PerIPMaxActiveConns overrides the proxy-wide + // per_ip_max_active_connections setting for this upstream. A + // value of 0 (the default, i.e. field omitted) means the + // upstream inherits the proxy-wide value. + PerIPMaxActiveConns int `toml:"per_ip_max_active_connections"` } type ProxySettings struct { @@ -26,6 +31,55 @@ type ProxySettings struct { ErrorLog string `toml:"error_log"` TLSCertFile string `toml:"tls_cert_file"` TLSKeyFile string `toml:"tls_key_file"` + // RelayIdleTimeoutSecs is the idle timeout (in seconds) applied + // during the bidirectional relay phase of a connection. If no data + // flows in either direction for this duration, the connection is + // terminated. 0 (the default) disables the timeout. + // + // This mirrors the semantics of the rsyncd "timeout" setting (see + // rsyncd.conf(5)), which is an I/O timeout. A common choice for + // public mirrors is 600. + RelayIdleTimeoutSecs int `toml:"relay_idle_timeout"` + // RelayMaxDurationSecs is the maximum total wall-clock duration + // (in seconds) of the bidirectional relay phase. When exceeded + // the proxy closes the connection regardless of activity. 0 (the + // default) disables this hard cap. rsync clients will typically + // reconnect and resume on the next run. + RelayMaxDurationSecs int `toml:"relay_max_duration"` + // TCPKeepAliveSecs enables TCP keepalive on accepted client + // connections and on dialed upstream connections. The value is + // the keepalive period in seconds; 0 (the default) leaves the + // OS-default keepalive behavior in place (typically: disabled or + // ~2 hours). Enabling this helps detect half-open connections + // (peer crashed, NAT reaped) within minutes rather than hours. + TCPKeepAliveSecs int `toml:"tcp_keepalive"` + // PerIPMaxActiveConns is the proxy-wide default for the per-IP + // per-upstream concurrency cap applied during the relay phase. + // 0 (the default) disables the limit. Each upstream may + // override this via [upstreams.X].per_ip_max_active_connections. + PerIPMaxActiveConns int `toml:"per_ip_max_active_connections"` + // DialTimeoutSecs caps how long the proxy waits when dialing an + // upstream rsync server. A value of 0 (the default) leaves the + // OS-default TCP connect behavior in place (~75s on Linux). When + // an upstream is unreachable this surfaces a fast failure to + // the client and increments rsync_proxy_upstream_dial_errors_total + // without tying up the listener for the full kernel SYN-retry + // budget. + DialTimeoutSecs int `toml:"dial_timeout"` + // MinThroughputBytes, MinThroughputWindowSecs and + // MinThroughputGraceSecs together implement a throughput floor + // during the relay phase. Within any window of + // min_throughput_window seconds the connection must transfer at + // least min_throughput_bytes (counting both directions); if it + // does not, the proxy closes the connection. The grace period + // suppresses the check for the first min_throughput_grace seconds + // after the relay starts to avoid killing a slow-start session. + // Setting min_throughput_bytes or min_throughput_window to 0 (the + // default) disables the floor. min_throughput_grace defaults to + // the value of min_throughput_window when 0. + MinThroughputBytes int64 `toml:"min_throughput_bytes"` + MinThroughputWindowSecs int `toml:"min_throughput_window"` + MinThroughputGraceSecs int `toml:"min_throughput_grace"` } type Config struct { diff --git a/pkg/server/metrics.go b/pkg/server/metrics.go index fedfeb1..37ee1dc 100644 --- a/pkg/server/metrics.go +++ b/pkg/server/metrics.go @@ -97,6 +97,38 @@ func (s *Server) writePrometheusMetrics(w io.Writer, now time.Time) { prometheusEscapeLabelValue(u.Name), c.dialError.Load()) } + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_per_ip_rejected_total Total connections rejected by the per-IP per-upstream concurrency cap.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_per_ip_rejected_total counter") + for _, u := range upstreams { + c := s.getUpstreamCounters(u.Name) + _, _ = fmt.Fprintf(w, "rsync_proxy_per_ip_rejected_total{upstream=\"%s\"} %d\n", + prometheusEscapeLabelValue(u.Name), c.perIPRejected.Load()) + } + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_relay_idle_timeout_terminated_total Total relay connections terminated by the idle timeout watcher per upstream.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_relay_idle_timeout_terminated_total counter") + for _, u := range upstreams { + c := s.getUpstreamCounters(u.Name) + _, _ = fmt.Fprintf(w, "rsync_proxy_relay_idle_timeout_terminated_total{upstream=\"%s\"} %d\n", + prometheusEscapeLabelValue(u.Name), c.idleTerminated.Load()) + } + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_relay_max_duration_terminated_total Total relay connections terminated for exceeding the max-duration cap per upstream.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_relay_max_duration_terminated_total counter") + for _, u := range upstreams { + c := s.getUpstreamCounters(u.Name) + _, _ = fmt.Fprintf(w, "rsync_proxy_relay_max_duration_terminated_total{upstream=\"%s\"} %d\n", + prometheusEscapeLabelValue(u.Name), c.maxDurationTerminated.Load()) + } + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_throughput_floor_terminated_total Total relay connections terminated for falling below the configured throughput floor per upstream.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_throughput_floor_terminated_total counter") + for _, u := range upstreams { + c := s.getUpstreamCounters(u.Name) + _, _ = fmt.Fprintf(w, "rsync_proxy_throughput_floor_terminated_total{upstream=\"%s\"} %d\n", + prometheusEscapeLabelValue(u.Name), c.throughputFloorTerminated.Load()) + } + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_unknown_module_requests_total Total requests for unknown modules.") _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_unknown_module_requests_total counter") _, _ = fmt.Fprintf(w, "rsync_proxy_unknown_module_requests_total %d\n", s.unknownModuleCount.Load()) diff --git a/pkg/server/server.go b/pkg/server/server.go index be46f45..4be212a 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -119,12 +119,29 @@ type upstreamConfig struct { DiscoverModules bool MaxActiveConns int MaxQueuedConns int + // PerIPMaxActiveConns is the resolved (effective) limit on the + // number of concurrent active relay connections from a single + // client IP to this upstream. 0 means no per-IP cap. Computed at + // load time as: per-upstream override (Upstream.PerIPMaxActiveConns) + // or, if zero, the proxy-wide default (Proxy.PerIPMaxActiveConns). + PerIPMaxActiveConns int } // upstreamCounters holds per-upstream failure counters. type upstreamCounters struct { - queueFull atomic.Uint64 - dialError atomic.Uint64 + queueFull atomic.Uint64 + dialError atomic.Uint64 + perIPRejected atomic.Uint64 + idleTerminated atomic.Uint64 + maxDurationTerminated atomic.Uint64 + throughputFloorTerminated atomic.Uint64 +} + +// perIPCountKey identifies a (upstream, client IP) pair for tracking +// per-IP per-upstream concurrent active relay connections. +type perIPCountKey struct { + upstream string + ip string } type Server struct { @@ -138,6 +155,38 @@ type Server struct { ReadTimeout time.Duration WriteTimeout time.Duration + // RelayIdleTimeout is the idle (no I/O activity in either + // direction) timeout applied during the bidirectional relay phase + // of a connection. A value of 0 (the default) disables the + // timeout, matching rsyncd's "timeout = 0" behavior. See + // rsyncd.conf(5). + RelayIdleTimeout time.Duration + + // RelayMaxDuration is a hard cap on the total wall-clock + // duration of the bidirectional relay phase of a connection. + // When exceeded the proxy closes both directions regardless of + // activity. 0 (the default) disables the cap. + RelayMaxDuration time.Duration + + // TCPKeepAlive is the keepalive period applied to accepted + // client connections and to dialed upstream connections. 0 (the + // default) leaves the OS-default keepalive behavior in place + // (typically: disabled, or ~2 hours). + TCPKeepAlive time.Duration + + // MinThroughputBytes, MinThroughputWindow and + // MinThroughputGrace together implement a throughput floor + // during the relay phase. Within any window of + // MinThroughputWindow the connection must transfer at least + // MinThroughputBytes (counting both directions); otherwise the + // proxy closes the connection. The grace period suppresses the + // check for the first MinThroughputGrace after the relay + // starts. Setting MinThroughputBytes or MinThroughputWindow to + // zero disables the floor. + MinThroughputBytes int64 + MinThroughputWindow time.Duration + MinThroughputGrace time.Duration + Motd string // --- End of options section @@ -166,6 +215,11 @@ type Server struct { upstreamCounters sync.Map unknownModuleCount atomic.Uint64 + // Per-(upstream, client IP) active-connection counters. + // Lazy-initialized via getPerIPCounter. + // map key is perIPCountKey. Value is *atomic.Int64. + perIPCounts sync.Map + TCPListener net.Listener TLSListener net.Listener HTTPListener net.Listener @@ -174,12 +228,20 @@ type Server struct { type countingReader struct { reader io.Reader counter *atomic.Int64 + // lastActivity is the UnixNano timestamp of the most recent + // successful read (n > 0). It is updated atomically so that an + // idle watcher goroutine can observe activity without locking. + // May be nil when activity tracking is not needed. + lastActivity *atomic.Int64 } func (cr *countingReader) Read(p []byte) (n int, err error) { n, err = cr.reader.Read(p) if n > 0 { cr.counter.Add(int64(n)) + if cr.lastActivity != nil { + cr.lastActivity.Store(time.Now().UnixNano()) + } } return n, err } @@ -226,6 +288,28 @@ func (s *Server) loadConfig(c *Config, openLog bool) error { tlsCertificate = &cert } + if c.Proxy.RelayMaxDurationSecs < 0 { + return fmt.Errorf("relay_max_duration must be non-negative, got %d", c.Proxy.RelayMaxDurationSecs) + } + if c.Proxy.TCPKeepAliveSecs < 0 { + return fmt.Errorf("tcp_keepalive must be non-negative, got %d", c.Proxy.TCPKeepAliveSecs) + } + if c.Proxy.PerIPMaxActiveConns < 0 { + return fmt.Errorf("per_ip_max_active_connections must be non-negative, got %d", c.Proxy.PerIPMaxActiveConns) + } + if c.Proxy.DialTimeoutSecs < 0 { + return fmt.Errorf("dial_timeout must be non-negative, got %d", c.Proxy.DialTimeoutSecs) + } + if c.Proxy.MinThroughputBytes < 0 { + return fmt.Errorf("min_throughput_bytes must be non-negative, got %d", c.Proxy.MinThroughputBytes) + } + if c.Proxy.MinThroughputWindowSecs < 0 { + return fmt.Errorf("min_throughput_window must be non-negative, got %d", c.Proxy.MinThroughputWindowSecs) + } + if c.Proxy.MinThroughputGraceSecs < 0 { + return fmt.Errorf("min_throughput_grace must be non-negative, got %d", c.Proxy.MinThroughputGraceSecs) + } + upstreams := make([]upstreamConfig, 0, len(c.Upstreams)) upstreamNames := make([]string, 0, len(c.Upstreams)) for upstreamName := range c.Upstreams { @@ -237,17 +321,27 @@ func (s *Server) loadConfig(c *Config, openLog bool) error { if len(v.Modules) == 0 && !v.DiscoverModules { return fmt.Errorf("upstream=%s must set modules or discover_modules", upstreamName) } + if v.PerIPMaxActiveConns < 0 { + return fmt.Errorf("upstream=%s: per_ip_max_active_connections must be non-negative, got %d", upstreamName, v.PerIPMaxActiveConns) + } addr := v.Address if err := validateTCPOrUnixAddr(addr); err != nil { return fmt.Errorf("resolve address: %w, upstream=%s, address=%s", err, upstreamName, addr) } + // Resolve effective per-IP cap: per-upstream override wins; + // fall back to proxy-wide default; 0 means no cap. + effectivePerIP := v.PerIPMaxActiveConns + if effectivePerIP == 0 { + effectivePerIP = c.Proxy.PerIPMaxActiveConns + } upstreams = append(upstreams, upstreamConfig{ - Name: upstreamName, - Target: Target{Upstream: upstreamName, Addr: addr, UseProxyProtocol: v.UseProxyProtocol}, - Modules: slices.Clone(v.Modules), - DiscoverModules: v.DiscoverModules, - MaxActiveConns: v.MaxActiveConns, - MaxQueuedConns: v.MaxQueuedConns, + Name: upstreamName, + Target: Target{Upstream: upstreamName, Addr: addr, UseProxyProtocol: v.UseProxyProtocol}, + Modules: slices.Clone(v.Modules), + DiscoverModules: v.DiscoverModules, + MaxActiveConns: v.MaxActiveConns, + MaxQueuedConns: v.MaxQueuedConns, + PerIPMaxActiveConns: effectivePerIP, }) } @@ -282,6 +376,27 @@ func (s *Server) loadConfig(c *Config, openLog bool) error { } } s.Motd = c.Proxy.Motd + if c.Proxy.RelayIdleTimeoutSecs < 0 { + return fmt.Errorf("relay_idle_timeout must be non-negative, got %d", c.Proxy.RelayIdleTimeoutSecs) + } + s.RelayIdleTimeout = time.Duration(c.Proxy.RelayIdleTimeoutSecs) * time.Second + s.RelayMaxDuration = time.Duration(c.Proxy.RelayMaxDurationSecs) * time.Second + s.TCPKeepAlive = time.Duration(c.Proxy.TCPKeepAliveSecs) * time.Second + s.MinThroughputBytes = c.Proxy.MinThroughputBytes + s.MinThroughputWindow = time.Duration(c.Proxy.MinThroughputWindowSecs) * time.Second + graceSecs := c.Proxy.MinThroughputGraceSecs + if graceSecs == 0 { + // Default the grace period to the window itself: a fresh + // connection gets one full window to ramp up before the + // floor is enforced. + graceSecs = c.Proxy.MinThroughputWindowSecs + } + s.MinThroughputGrace = time.Duration(graceSecs) * time.Second + // Reflect the new keepalive and dial-timeout settings on the + // dialer used to dial upstreams. The dialer is consulted under + // reloadLock via getDialer() to avoid racing with reloads. + s.dialer.KeepAlive = s.TCPKeepAlive + s.dialer.Timeout = time.Duration(c.Proxy.DialTimeoutSecs) * time.Second s.modules = modules s.upstreams = resolvedUpstreams s.upstreamQueues = s.updateUpstreamQueuesLocked(resolvedUpstreams) @@ -297,12 +412,13 @@ func resolveUpstreams(upstreams []upstreamConfig, discovered map[string][]string modules = slices.Clone(discovered[upstream.Name]) } resolved = append(resolved, upstreamConfig{ - Name: upstream.Name, - Target: upstream.Target, - Modules: modules, - DiscoverModules: upstream.DiscoverModules, - MaxActiveConns: upstream.MaxActiveConns, - MaxQueuedConns: upstream.MaxQueuedConns, + Name: upstream.Name, + Target: upstream.Target, + Modules: modules, + DiscoverModules: upstream.DiscoverModules, + MaxActiveConns: upstream.MaxActiveConns, + MaxQueuedConns: upstream.MaxQueuedConns, + PerIPMaxActiveConns: upstream.PerIPMaxActiveConns, }) } return resolved @@ -346,6 +462,93 @@ func (s *Server) getUpstreamCounters(name string) *upstreamCounters { return v.(*upstreamCounters) } +// getPerIPCounter returns the (upstream, ip) active-connection counter, +// creating it lazily. Safe for concurrent use. +func (s *Server) getPerIPCounter(upstream, ip string) *atomic.Int64 { + key := perIPCountKey{upstream: upstream, ip: ip} + if v, ok := s.perIPCounts.Load(key); ok { + return v.(*atomic.Int64) + } + v, _ := s.perIPCounts.LoadOrStore(key, &atomic.Int64{}) + return v.(*atomic.Int64) +} + +// getDialer returns a copy of the current upstream dialer, taking the +// reload lock to avoid racing with config reloads that mutate the +// dialer (e.g. KeepAlive). The returned value is safe to pass by value +// to dialContextTCPOrUnix. +func (s *Server) getDialer() net.Dialer { + s.reloadLock.RLock() + defer s.reloadLock.RUnlock() + return s.dialer +} + +// relayTimings is a snapshot of all relay-phase timing settings, +// captured atomically under the reload lock. Callers should consume +// this snapshot for the lifetime of a single relay connection so the +// behavior remains consistent across reloads. +type relayTimings struct { + idle time.Duration + maxDuration time.Duration + tcpKeepAlive time.Duration + minBytes int64 + minWindow time.Duration + minGrace time.Duration +} + +// getRelayTimings returns a snapshot of the relay-phase timing +// settings under the reload lock. These fields are mutated by config +// reloads, so callers must not read them directly. +func (s *Server) getRelayTimings() relayTimings { + s.reloadLock.RLock() + defer s.reloadLock.RUnlock() + return relayTimings{ + idle: s.RelayIdleTimeout, + maxDuration: s.RelayMaxDuration, + tcpKeepAlive: s.TCPKeepAlive, + minBytes: s.MinThroughputBytes, + minWindow: s.MinThroughputWindow, + minGrace: s.MinThroughputGrace, + } +} + +// getTCPKeepAlive returns the configured TCP keepalive period under +// the reload lock. +func (s *Server) getTCPKeepAlive() time.Duration { + s.reloadLock.RLock() + defer s.reloadLock.RUnlock() + return s.TCPKeepAlive +} + +// getPerIPLimitForUpstream returns the configured per-IP active +// connection cap for the named upstream, or 0 if none is set or the +// upstream is unknown. Safe for concurrent use. +func (s *Server) getPerIPLimitForUpstream(name string) int { + s.reloadLock.RLock() + defer s.reloadLock.RUnlock() + for i := range s.upstreams { + if s.upstreams[i].Name == name { + return s.upstreams[i].PerIPMaxActiveConns + } + } + return 0 +} + +// applyTCPKeepAlive enables TCP keepalive on the given connection if +// it is a *net.TCPConn and a positive period is provided. Other +// connection types (e.g. unix sockets) are silently ignored. +func applyTCPKeepAlive(conn net.Conn, period time.Duration) { + if period <= 0 { + return + } + tc, ok := conn.(*net.TCPConn) + if !ok { + return + } + _ = tc.SetKeepAlive(true) + _ = tc.SetKeepAlivePeriod(period) +} + func buildModuleTargets(upstreams []upstreamConfig) map[string][]Target { modules := map[string][]Target{} for _, upstream := range upstreams { @@ -613,6 +816,23 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err return fmt.Errorf("no queue configured for upstream %s", target.Upstream) } + // Per-IP per-upstream concurrency cap. A positive cap means the + // same client IP may not have more than N simultaneous active + // relay connections to this upstream. Counted before queue + // admission so the cap also bounds queueing, preventing a single + // IP from monopolizing both the active slots and the queue. + if perIPLimit := s.getPerIPLimitForUpstream(target.Upstream); perIPLimit > 0 { + counter := s.getPerIPCounter(target.Upstream, ip) + if n := counter.Add(1); n > int64(perIPLimit) { + counter.Add(-1) + s.getUpstreamCounters(target.Upstream).perIPRejected.Add(1) + s.accessLog.F("client %s rejected for upstream %s module %s: per-IP cap of %d reached", ip, target.Upstream, moduleName, perIPLimit) + _, _ = writeWithTimeout(downConn, fmt.Appendf(nil, "@ERROR: per-IP connection limit of %d for upstream %s reached, retry later\n", perIPLimit, target.Upstream), writeTimeout) + return nil + } + defer counter.Add(-1) + } + handle := upstreamQueue.Acquire() defer handle.Release() status := <-handle.C @@ -651,12 +871,19 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err } } - upConn, err := dialContextTCPOrUnix(ctx, s.dialer, upstreamAddr) + upConn, err := dialContextTCPOrUnix(ctx, s.getDialer(), upstreamAddr) if err != nil { s.getUpstreamCounters(target.Upstream).dialError.Add(1) return fmt.Errorf("dial to upstream: %s: %w", upstreamAddr, err) } defer upConn.Close() + // Enable TCP keepalive on the upstream-side connection so that a + // dead/half-open peer is detected within the configured period + // rather than relying on the OS default (commonly ~2 hours, or + // disabled). dialer.KeepAlive only takes effect for the initial + // SYN window; explicitly setting it on the resulting *TCPConn + // is portable and idempotent. + applyTCPKeepAlive(upConn, s.getTCPKeepAlive()) upAddr := netAddrToString(upConn.RemoteAddr()) if useProxyProtocol { err := writeProxyProtocolHeader(upConn, downConn.RemoteAddr(), upConn.RemoteAddr(), s.WriteTimeout) @@ -705,19 +932,143 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err downReader := &countingReader{reader: downConn, counter: &info.ReceivedBytes} upReader := &countingReader{reader: upConn, counter: &info.SentBytes} + // Optional watchers on the bidirectional relay phase: + // * RelayIdleTimeout terminates a connection that has had no + // data flow in either direction for the configured duration + // (rsyncd "timeout" semantics, rsyncd.conf(5)). + // * RelayMaxDuration is a hard wall-clock cap on the entire + // relay phase; any connection still alive past this point is + // terminated regardless of activity. + // * MinThroughputBytes/MinThroughputWindow enforce a minimum + // average throughput on the relay; a connection that has + // transferred fewer than minBytes over the last window is + // considered to be slow-leeching and terminated. The grace + // period suppresses the check during initial ramp-up so + // short rsync exchanges (file lists, etc.) are not killed. + // A single goroutine handles all checks. The shared + // watcherClosed flag is used by the io.Copy goroutines to + // suppress the expected "use of closed network connection" error + // log when the watcher initiated the close. The reason for the + // closure is recorded directly by the watcher into per-upstream + // counters and access log. + timings := s.getRelayTimings() + idleTimeout := timings.idle + maxDuration := timings.maxDuration + minBytes := timings.minBytes + minWindow := timings.minWindow + minGrace := timings.minGrace + throughputEnabled := minBytes > 0 && minWindow > 0 + relayStartedAt := time.Now() + var idleTimedOut atomic.Bool sentClosed := make(chan struct{}) receivedClosed := make(chan struct{}) + if idleTimeout > 0 || maxDuration > 0 || throughputEnabled { + var lastActivity *atomic.Int64 + if idleTimeout > 0 { + lastActivity = &atomic.Int64{} + lastActivity.Store(relayStartedAt.UnixNano()) + downReader.lastActivity = lastActivity + upReader.lastActivity = lastActivity + } + + // Pick a wakeup interval as 1/4 of the smallest enabled window + // so the worst-case detection latency is roughly 1.25x the + // configured value, while keeping wakeup overhead negligible. + // A 1s floor is then applied: detection latency therefore + // degrades for sub-4-second settings (e.g. a 500ms idle + // timeout will detect inside ~1s + the 500ms slack rather + // than 125ms). In practice all production-meaningful values + // are well above 4 seconds, so this only matters for tests. + interval := time.Duration(0) + bumpInterval := func(d time.Duration) { + if d <= 0 { + return + } + if interval == 0 || d < interval { + interval = d + } + } + if idleTimeout > 0 { + bumpInterval(idleTimeout / 4) + } + if maxDuration > 0 { + bumpInterval(maxDuration / 4) + } + if throughputEnabled { + bumpInterval(minWindow / 4) + } + if interval < time.Second { + interval = time.Second + } + + upstreamName := target.Upstream + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + // Sliding-window throughput state. lastSampleTime moves + // forward each time we reach a full window, with + // lastSampleBytes capturing the cumulative byte count at + // that moment. delta over the next window must meet + // minBytes or the connection is terminated. + lastSampleTime := relayStartedAt + lastSampleBytes := int64(0) + for { + select { + case <-sentClosed: + return + case <-receivedClosed: + return + case now := <-ticker.C: + if idleTimeout > 0 { + last := time.Unix(0, lastActivity.Load()) + if now.Sub(last) >= idleTimeout { + idleTimedOut.Store(true) + s.getUpstreamCounters(upstreamName).idleTerminated.Add(1) + s.accessLog.F("client %s idle for module %s exceeds %s, closing", ip, moduleName, idleTimeout) + _ = upConn.Close() + _ = downConn.Close() + return + } + } + if maxDuration > 0 && now.Sub(relayStartedAt) >= maxDuration { + idleTimedOut.Store(true) + s.getUpstreamCounters(upstreamName).maxDurationTerminated.Add(1) + s.accessLog.F("client %s exceeded max relay duration %s for module %s, closing", ip, maxDuration, moduleName) + _ = upConn.Close() + _ = downConn.Close() + return + } + if throughputEnabled && now.Sub(relayStartedAt) >= minGrace && now.Sub(lastSampleTime) >= minWindow { + curBytes := info.SentBytes.Load() + info.ReceivedBytes.Load() + delta := curBytes - lastSampleBytes + if delta < minBytes { + idleTimedOut.Store(true) + s.getUpstreamCounters(upstreamName).throughputFloorTerminated.Add(1) + s.accessLog.F("client %s for module %s below throughput floor (%d bytes < %d bytes in %s), closing", + ip, moduleName, delta, minBytes, minWindow) + _ = upConn.Close() + _ = downConn.Close() + return + } + lastSampleTime = now + lastSampleBytes = curBytes + } + } + } + }() + } + go func() { _, err := io.Copy(upConn, downReader) - if err != nil { + if err != nil && !idleTimedOut.Load() { s.errorLog.F("copy from downstream to upstream: %v", err) } close(sentClosed) }() go func() { _, err := io.Copy(downConn, upReader) - if err != nil { + if err != nil && !idleTimedOut.Load() { s.errorLog.F("copy from upstream to downstream: %v", err) } close(receivedClosed) @@ -941,6 +1292,12 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn) { s.acceptedConnCount.Add(1) connIndex := s.connIndex.Add(1) + // Apply TCP keepalive on the accepted client connection so that + // half-open connections (peer crashed, NAT entry reaped) are + // detected within the configured period rather than waiting for + // the OS default. No-op on unix-domain or non-TCP listeners. + applyTCPKeepAlive(conn, s.getTCPKeepAlive()) + defer func() { err := recover() if err != nil { diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 9b374aa..5a0995f 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -209,6 +209,144 @@ func TestClientReadTimeout(t *testing.T) { r.Equal(expected, string(allData)) } +// TestRelayIdleTimeoutClosesIdleConnection verifies that when +// RelayIdleTimeout is configured and no data flows in either +// direction during the bidirectional relay phase, the proxy tears the +// connection down. This mirrors rsyncd's "timeout" behavior in +// rsyncd.conf(5). +func TestRelayIdleTimeoutClosesIdleConnection(t *testing.T) { + srv := startServer(t) + defer srv.Close() + srv.RelayIdleTimeout = 500 * time.Millisecond + accessLogPath := setupAccessLog(t, srv) + + r := require.New(t) + + upstreamReady := make(chan struct{}) + upstreamDone := make(chan struct{}) + + fakeRsync := rsync.NewServer(func(conn *rsync.Conn) { + defer conn.Close() + defer close(upstreamDone) + + _, _, err := doServerHandshake(conn, RsyncdServerVersion) + r.NoError(err, "upstream handshake") + close(upstreamReady) + + // Stay quiet so the relay phase has no I/O. The proxy must + // close us once the idle timeout elapses; ReadAll then + // returns with an EOF / closed-connection error. + _, _ = io.ReadAll(conn) + }) + fakeRsync.Start() + defer fakeRsync.Close() + + srv.modules = map[string][]Target{ + "fake": {{Upstream: "u1", Addr: fakeRsync.Listener.Addr().String()}}, + } + srv.upstreams = []upstreamConfig{{Name: "u1"}} + srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} + + rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + r.NoError(err) + conn := rsync.NewConn(rawConn) + defer conn.Close() + + _, err = doClientHandshake(conn, RsyncdServerVersion, "fake") + r.NoError(err) + + <-upstreamReady + + start := time.Now() + // ReadAll should return shortly after the proxy closes our + // connection due to the idle timeout firing on the relay side. + _, err = io.ReadAll(conn) + r.NoError(err) + elapsed := time.Since(start) + + // Allow generous slack: must be at least the configured timeout, + // and not pathologically long (e.g. waiting forever). + r.GreaterOrEqual(elapsed, srv.RelayIdleTimeout, "should have waited at least the idle timeout") + r.Less(elapsed, 5*time.Second, "should not block far beyond the idle timeout") + + select { + case <-upstreamDone: + case <-time.After(2 * time.Second): + t.Fatal("upstream connection was not closed after idle timeout") + } + + logData, err := os.ReadFile(accessLogPath) + r.NoError(err) + assert.Contains(t, string(logData), "idle for module fake exceeds") + + // The watcher must record the termination reason on the + // per-upstream counter and surface it via /metrics so that + // operators can distinguish a hung-and-killed connection from a + // normal completion. + r.Eventually(func() bool { + return srv.getUpstreamCounters("u1").idleTerminated.Load() == 1 + }, time.Second, 10*time.Millisecond) + + resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") + r.NoError(err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + r.NoError(err) + text := string(body) + assert.Contains(t, text, "# HELP rsync_proxy_relay_idle_timeout_terminated_total") + assert.Contains(t, text, "# TYPE rsync_proxy_relay_idle_timeout_terminated_total counter") + assert.Contains(t, text, "rsync_proxy_relay_idle_timeout_terminated_total{upstream=\"u1\"} 1\n") +} + +// TestRelayIdleTimeoutNotTriggeredWhenActive verifies that the idle +// timeout is reset whenever data flows, so a slow but continuously +// active stream does not get cut. The fake upstream sends data at an +// interval well below the idle timeout for several iterations. +func TestRelayIdleTimeoutNotTriggeredWhenActive(t *testing.T) { + srv := startServer(t) + defer srv.Close() + srv.RelayIdleTimeout = 2 * time.Second + + r := require.New(t) + + const iterations = 5 + const interval = 200 * time.Millisecond + + fakeRsync := rsync.NewServer(func(conn *rsync.Conn) { + defer conn.Close() + _, _, err := doServerHandshake(conn, RsyncdServerVersion) + r.NoError(err, "upstream handshake") + + for i := 0; i < iterations; i++ { + _, err = conn.Write([]byte("data\n")) + r.NoError(err, "write data") + time.Sleep(interval) + } + }) + fakeRsync.Start() + defer fakeRsync.Close() + + srv.modules = map[string][]Target{ + "fake": {{Upstream: "u1", Addr: fakeRsync.Listener.Addr().String()}}, + } + srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} + + rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + r.NoError(err) + conn := rsync.NewConn(rawConn) + defer conn.Close() + + _, err = doClientHandshake(conn, RsyncdServerVersion, "fake") + r.NoError(err) + + allData, err := io.ReadAll(conn) + r.NoError(err) + + expected := strings.Repeat("data\n", iterations) + r.Equal(expected, string(allData), + "steadily flowing traffic must not be interrupted by the idle timeout") +} + func TestTLSRsyncListener(t *testing.T) { r := require.New(t) @@ -1101,3 +1239,486 @@ func TestDiscoverModulesFromTrailingModuleBlock(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{"bar", "foo"}, modules) } + +// TestPerIPCapRejectsConnectionsBeyondLimit verifies that when a +// per-IP active connection cap is configured for an upstream, a +// second simultaneous connection from the same client IP to the same +// upstream is rejected with an @ERROR message, the perIPRejected +// counter is incremented, and the rejection is reflected in the +// /metrics output. +func TestPerIPCapRejectsConnectionsBeyondLimit(t *testing.T) { + srv := startServer(t) + defer srv.Close() + accessLogPath := setupAccessLog(t, srv) + + var release sync.WaitGroup + release.Add(1) + + upstream := rsync.NewServer(func(conn *rsync.Conn) { + defer conn.Close() + _, _, err := doServerHandshake(conn, RsyncdServerVersion) + require.NoError(t, err) + release.Wait() + }) + upstream.Start() + defer upstream.Close() + + srv.reloadLock.Lock() + srv.upstreams = []upstreamConfig{ + {Name: "u1", PerIPMaxActiveConns: 1}, + } + srv.modules = map[string][]Target{ + "fake": {{Upstream: "u1", Addr: upstream.Listener.Addr().String()}}, + } + srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} + srv.reloadLock.Unlock() + + // First connection occupies the per-IP slot. + c1Raw, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + require.NoError(t, err) + c1 := rsync.NewConn(c1Raw) + defer c1.Close() + _, err = doClientHandshake(c1, RsyncdServerVersion, "fake") + require.NoError(t, err) + + // Second connection from the same IP must be rejected. + c2Raw, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + require.NoError(t, err) + c2 := rsync.NewConn(c2Raw) + defer c2.Close() + _, err = doClientHandshake(c2, RsyncdServerVersion, "fake") + require.NoError(t, err) + + body, err := io.ReadAll(c2) + require.NoError(t, err) + assert.Contains(t, string(body), "@ERROR: per-IP connection limit") + assert.Contains(t, string(body), "for upstream u1") + + require.Eventually(t, func() bool { + return srv.getUpstreamCounters("u1").perIPRejected.Load() == 1 + }, time.Second, 10*time.Millisecond) + + resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") + require.NoError(t, err) + defer resp.Body.Close() + metricsBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + text := string(metricsBody) + assert.Contains(t, text, "# HELP rsync_proxy_per_ip_rejected_total") + assert.Contains(t, text, "# TYPE rsync_proxy_per_ip_rejected_total counter") + assert.Contains(t, text, "rsync_proxy_per_ip_rejected_total{upstream=\"u1\"} 1\n") + + logData, err := os.ReadFile(accessLogPath) + require.NoError(t, err) + assert.Contains(t, string(logData), "per-IP cap of 1 reached") + + release.Done() +} + +// TestRelayMaxDurationClosesLongConnection verifies that when +// RelayMaxDuration is configured, an otherwise-active relay (no idle +// timeout) is forcibly torn down once it exceeds the configured +// wall-clock cap. +func TestRelayMaxDurationClosesLongConnection(t *testing.T) { + srv := startServer(t) + defer srv.Close() + srv.RelayIdleTimeout = 0 + srv.RelayMaxDuration = 500 * time.Millisecond + accessLogPath := setupAccessLog(t, srv) + + r := require.New(t) + + upstreamReady := make(chan struct{}) + upstreamDone := make(chan struct{}) + + fakeRsync := rsync.NewServer(func(conn *rsync.Conn) { + defer conn.Close() + defer close(upstreamDone) + + _, _, err := doServerHandshake(conn, RsyncdServerVersion) + r.NoError(err, "upstream handshake") + close(upstreamReady) + + // Stay quiet so the relay phase has no I/O. With idle timeout + // disabled, only the max-duration watcher can tear this + // connection down. + _, _ = io.ReadAll(conn) + }) + fakeRsync.Start() + defer fakeRsync.Close() + + srv.modules = map[string][]Target{ + "fake": {{Upstream: "u1", Addr: fakeRsync.Listener.Addr().String()}}, + } + srv.upstreams = []upstreamConfig{{Name: "u1"}} + srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} + + rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + r.NoError(err) + conn := rsync.NewConn(rawConn) + defer conn.Close() + + _, err = doClientHandshake(conn, RsyncdServerVersion, "fake") + r.NoError(err) + + <-upstreamReady + + start := time.Now() + _, err = io.ReadAll(conn) + r.NoError(err) + elapsed := time.Since(start) + + r.GreaterOrEqual(elapsed, srv.RelayMaxDuration, "should have waited at least the max duration") + r.Less(elapsed, 5*time.Second, "should not block far beyond the max duration") + + select { + case <-upstreamDone: + case <-time.After(2 * time.Second): + t.Fatal("upstream connection was not closed after max duration") + } + + logData, err := os.ReadFile(accessLogPath) + r.NoError(err) + assert.Contains(t, string(logData), "exceeded max relay duration") + assert.Contains(t, string(logData), "for module fake") + + r.Eventually(func() bool { + return srv.getUpstreamCounters("u1").maxDurationTerminated.Load() == 1 + }, time.Second, 10*time.Millisecond) + + resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") + r.NoError(err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + r.NoError(err) + text := string(body) + assert.Contains(t, text, "# HELP rsync_proxy_relay_max_duration_terminated_total") + assert.Contains(t, text, "# TYPE rsync_proxy_relay_max_duration_terminated_total counter") + assert.Contains(t, text, "rsync_proxy_relay_max_duration_terminated_total{upstream=\"u1\"} 1\n") +} + +// TestApplyTCPKeepAliveOnTCPConn exercises the applyTCPKeepAlive +// helper end-to-end on a real *net.TCPConn. We cannot portably read +// SO_KEEPALIVE/TCP_KEEPIDLE back via the standard library, so we +// verify that the helper does not error and is safe to invoke with +// zero or negative periods. +func TestApplyTCPKeepAliveOnTCPConn(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + accepted := make(chan net.Conn, 1) + go func() { + c, err := ln.Accept() + if err != nil { + accepted <- nil + return + } + accepted <- c + }() + + clientConn, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + defer clientConn.Close() + + srvConn := <-accepted + require.NotNil(t, srvConn) + defer srvConn.Close() + + // Should be a no-op for a zero/negative period. + applyTCPKeepAlive(srvConn, 0) + applyTCPKeepAlive(srvConn, -1*time.Second) + + // Should successfully apply the keepalive on a real *net.TCPConn. + applyTCPKeepAlive(srvConn, 45*time.Second) + _, ok := srvConn.(*net.TCPConn) + assert.True(t, ok, "test sanity: accepted conn should be *net.TCPConn") +} + +// TestLoadConfigPropagatesTCPKeepAliveToDialer verifies that the +// tcp_keepalive proxy setting is parsed, validated, and propagated to +// both the Server.TCPKeepAlive field and the underlying dialer used +// for upstream connections. +func TestLoadConfigPropagatesTCPKeepAliveToDialer(t *testing.T) { + configContent := ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +tcp_keepalive = 45 +relay_max_duration = 7200 +per_ip_max_active_connections = 4 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +` + srv := New() + require.NoError(t, srv.ReadConfig(strings.NewReader(configContent), false)) + + assert.Equal(t, 45*time.Second, srv.TCPKeepAlive) + assert.Equal(t, 45*time.Second, srv.dialer.KeepAlive) + assert.Equal(t, 2*time.Hour, srv.RelayMaxDuration) + + srv.reloadLock.RLock() + defer srv.reloadLock.RUnlock() + require.Len(t, srv.upstreams, 1) + assert.Equal(t, 4, srv.upstreams[0].PerIPMaxActiveConns) +} + +// TestLoadConfigRejectsNegativeTimings verifies that loadConfig +// rejects negative values for the new connection-control settings. +func TestLoadConfigRejectsNegativeTimings(t *testing.T) { + cases := []struct { + name string + config string + wantMsg string + }{ + { + name: "relay_max_duration", + config: ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +relay_max_duration = -1 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +`, + wantMsg: "relay_max_duration must be non-negative", + }, + { + name: "tcp_keepalive", + config: ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +tcp_keepalive = -5 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +`, + wantMsg: "tcp_keepalive must be non-negative", + }, + { + name: "per_ip_max_active_connections", + config: ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +per_ip_max_active_connections = -2 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +`, + wantMsg: "per_ip_max_active_connections must be non-negative", + }, + { + name: "dial_timeout", + config: ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +dial_timeout = -1 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +`, + wantMsg: "dial_timeout must be non-negative", + }, + { + name: "min_throughput_bytes", + config: ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +min_throughput_bytes = -1 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +`, + wantMsg: "min_throughput_bytes must be non-negative", + }, + { + name: "min_throughput_window", + config: ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +min_throughput_window = -1 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +`, + wantMsg: "min_throughput_window must be non-negative", + }, + { + name: "min_throughput_grace", + config: ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +min_throughput_grace = -1 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +`, + wantMsg: "min_throughput_grace must be non-negative", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + srv := New() + err := srv.ReadConfig(strings.NewReader(tc.config), false) + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantMsg) + }) + } +} + +// TestThroughputFloorTerminatesSlowConnection verifies that when a +// throughput floor is configured, an otherwise-active relay (no idle +// timeout, no max-duration) which transfers fewer than MinThroughputBytes +// per MinThroughputWindow is forcibly torn down once the grace period +// elapses, the per-upstream counter is incremented, and the rejection +// is reflected in /metrics and the access log. +func TestThroughputFloorTerminatesSlowConnection(t *testing.T) { + srv := startServer(t) + defer srv.Close() + srv.RelayIdleTimeout = 0 + srv.RelayMaxDuration = 0 + // Demand effectively infeasible throughput in a 200ms window so + // the floor always trips. Grace = 0 so the very first sample + // after the window elapses is enough. + srv.MinThroughputBytes = 1 << 30 + srv.MinThroughputWindow = 200 * time.Millisecond + srv.MinThroughputGrace = 0 + accessLogPath := setupAccessLog(t, srv) + + r := require.New(t) + + upstreamReady := make(chan struct{}) + upstreamDone := make(chan struct{}) + + fakeRsync := rsync.NewServer(func(conn *rsync.Conn) { + defer conn.Close() + defer close(upstreamDone) + + _, _, err := doServerHandshake(conn, RsyncdServerVersion) + r.NoError(err, "upstream handshake") + close(upstreamReady) + + // Stay quiet so the relay phase has effectively zero throughput. + _, _ = io.ReadAll(conn) + }) + fakeRsync.Start() + defer fakeRsync.Close() + + srv.modules = map[string][]Target{ + "fake": {{Upstream: "u1", Addr: fakeRsync.Listener.Addr().String()}}, + } + srv.upstreams = []upstreamConfig{{Name: "u1"}} + srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} + + rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + r.NoError(err) + conn := rsync.NewConn(rawConn) + defer conn.Close() + + _, err = doClientHandshake(conn, RsyncdServerVersion, "fake") + r.NoError(err) + + <-upstreamReady + + start := time.Now() + _, err = io.ReadAll(conn) + r.NoError(err) + elapsed := time.Since(start) + + r.GreaterOrEqual(elapsed, srv.MinThroughputWindow, + "should have waited at least one throughput window before tearing down") + r.Less(elapsed, 5*time.Second, "should not block far beyond the throughput window") + + select { + case <-upstreamDone: + case <-time.After(2 * time.Second): + t.Fatal("upstream connection was not closed after throughput floor trip") + } + + r.Eventually(func() bool { + return srv.getUpstreamCounters("u1").throughputFloorTerminated.Load() == 1 + }, time.Second, 10*time.Millisecond) + + resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") + r.NoError(err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + r.NoError(err) + text := string(body) + assert.Contains(t, text, "# HELP rsync_proxy_throughput_floor_terminated_total") + assert.Contains(t, text, "# TYPE rsync_proxy_throughput_floor_terminated_total counter") + assert.Contains(t, text, "rsync_proxy_throughput_floor_terminated_total{upstream=\"u1\"} 1\n") + + logData, err := os.ReadFile(accessLogPath) + r.NoError(err) + assert.Contains(t, string(logData), "below throughput floor") + assert.Contains(t, string(logData), "for module fake") +} + +// TestLoadConfigPropagatesDialTimeoutAndThroughputSettings verifies +// that dial_timeout and the min_throughput_* settings are parsed, +// validated, and propagated into the dialer and Server fields. It also +// verifies that an unset min_throughput_grace falls back to the +// configured min_throughput_window, mirroring the documented default. +func TestLoadConfigPropagatesDialTimeoutAndThroughputSettings(t *testing.T) { + t.Run("explicit grace", func(t *testing.T) { + configContent := ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +dial_timeout = 5 +min_throughput_bytes = 65536 +min_throughput_window = 60 +min_throughput_grace = 30 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +` + srv := New() + require.NoError(t, srv.ReadConfig(strings.NewReader(configContent), false)) + + assert.Equal(t, 5*time.Second, srv.dialer.Timeout) + assert.Equal(t, int64(65536), srv.MinThroughputBytes) + assert.Equal(t, 60*time.Second, srv.MinThroughputWindow) + assert.Equal(t, 30*time.Second, srv.MinThroughputGrace) + }) + + t.Run("grace defaults to window", func(t *testing.T) { + configContent := ` +[proxy] +listen = "127.0.0.1:0" +listen_http = "127.0.0.1:0" +min_throughput_bytes = 1024 +min_throughput_window = 90 + +[upstreams.u1] +address = "127.0.0.1:8730" +modules = ["m1"] +` + srv := New() + require.NoError(t, srv.ReadConfig(strings.NewReader(configContent), false)) + + assert.Equal(t, int64(1024), srv.MinThroughputBytes) + assert.Equal(t, 90*time.Second, srv.MinThroughputWindow) + assert.Equal(t, 90*time.Second, srv.MinThroughputGrace, + "unset min_throughput_grace must default to min_throughput_window") + }) +}