From c3211949f2daabb3431a1ccef3c8e4e16f2ef902 Mon Sep 17 00:00:00 2001 From: Adam Fisk Date: Mon, 23 Feb 2026 13:09:16 -0700 Subject: [PATCH] Add per-peer bandwidth attribution to egress server Previously, the egress server tracked ingress bytes as a single global counter. This change attributes bytes to individual peers so the data can be consumed by a future reward oracle. Changes: - Add PeerID field to EgressOptions (UUID by default) - Include peer ID in WebSocket subprotocol header alongside consumer session ID (backwards compatible: old clients fall back to CSID) - Replace global nIngressBytes atomic with per-peer sync.Map of atomic counters, shared across all connections from the same peer - Add "ingress-bytes-by-peer" OTel metric with peer_id attribute - Preserve existing "ingress-bytes" total metric for backward compat Co-Authored-By: Claude Opus 4.6 --- clientcore/jit_egress_consumer.go | 2 +- clientcore/settings.go | 2 ++ common/resource.go | 17 ++++++---- egress/egresslib.go | 56 ++++++++++++++++++++++++------- egress/websocket.go | 14 ++++---- 5 files changed, 66 insertions(+), 25 deletions(-) diff --git a/clientcore/jit_egress_consumer.go b/clientcore/jit_egress_consumer.go index bae9064..1d1e33c 100644 --- a/clientcore/jit_egress_consumer.go +++ b/clientcore/jit_egress_consumer.go @@ -69,7 +69,7 @@ func NewJITEgressConsumer(options *EgressOptions, wg *sync.WaitGroup) *WorkerFSM defer cancel() dialOpts := &websocket.DialOptions{ - Subprotocols: common.NewSubprotocolsRequest(consumerInfoMsg.SessionID, common.Version), + Subprotocols: common.NewSubprotocolsRequest(consumerInfoMsg.SessionID, options.PeerID, common.Version), } // TODO: WSS diff --git a/clientcore/settings.go b/clientcore/settings.go index 84d5427..df2df64 100644 --- a/clientcore/settings.go +++ b/clientcore/settings.go @@ -49,6 +49,7 @@ type EgressOptions struct { Endpoint string ConnectTimeout time.Duration ErrorBackoff time.Duration + PeerID string } func NewDefaultEgressOptions() *EgressOptions { @@ -57,6 +58,7 @@ func NewDefaultEgressOptions() *EgressOptions { Endpoint: "/ws", ConnectTimeout: 5 * time.Second, ErrorBackoff: 5 * time.Second, + PeerID: uuid.NewString(), } } diff --git a/common/resource.go b/common/resource.go index 0ffb046..61c99c9 100644 --- a/common/resource.go +++ b/common/resource.go @@ -188,15 +188,20 @@ func DecodeSignalMsg(raw []byte) (string, interface{}, error) { // coder/websocket API, to pass arbitrary data. Note that a server receiving a populated // Sec-Websocket-Protocols header must reply with a reciprocal header containing some selected // protocol from the request. -func NewSubprotocolsRequest(csid, version string) []string { - return []string{subprotocolsMagicCookie, csid, version} +func NewSubprotocolsRequest(csid, peerID, version string) []string { + return []string{subprotocolsMagicCookie, csid, peerID, version} } -func ParseSubprotocolsRequest(s []string) (csid string, version string, ok bool) { - if len(s) != 3 { - return "", "", false +func ParseSubprotocolsRequest(s []string) (csid string, peerID string, version string, ok bool) { + switch len(s) { + case 4: + return s[1], s[2], s[3], true + case 3: + // Backwards compat: old clients don't send peerID + return s[1], "", s[2], true + default: + return "", "", "", false } - return s[1], s[2], true } func NewSubprotocolsResponse() []string { diff --git a/egress/egresslib.go b/egress/egresslib.go index e1f3612..3e6bb0b 100644 --- a/egress/egresslib.go +++ b/egress/egresslib.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "strings" + "sync" "sync/atomic" "time" @@ -14,6 +15,7 @@ import ( "github.com/google/uuid" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" "github.com/getlantern/broflake/common" @@ -34,8 +36,9 @@ var nClients uint64 // nQUICStreams is the number of open QUIC streams (not to be confused with QUIC connections) var nQUICStreams uint64 -// nIngressBytes is the number of bytes received over all WebSocket connections since the last otel measurement callback -var nIngressBytes uint64 +// peerIngressBytes tracks ingress bytes per peer ID since the last otel measurement callback. +// Keys are peer ID strings, values are *uint64 (atomic counters). +var peerIngressBytes sync.Map // Otel instruments var nClientsCounter metric.Int64UpDownCounter @@ -44,6 +47,7 @@ var nClientsCounter metric.Int64UpDownCounter var nQUICConnectionsCounter metric.Int64UpDownCounter var nQUICStreamsCounter metric.Int64UpDownCounter var nIngressBytesCounter metric.Int64ObservableUpDownCounter +var nIngressBytesByPeerCounter metric.Int64ObservableUpDownCounter type proxyListener struct { net.Listener @@ -85,7 +89,7 @@ func (l proxyListener) handleWebsocket(w http.ResponseWriter, r *http.Request) { } } - consumerSessionID, version, ok := common.ParseSubprotocolsRequest(subprotocols) + consumerSessionID, peerID, version, ok := common.ParseSubprotocolsRequest(subprotocols) if !ok { common.Debugf("Refused WebSocket connection, missing subprotocols") return @@ -111,6 +115,11 @@ func (l proxyListener) handleWebsocket(w http.ResponseWriter, r *http.Request) { return } + // Old clients don't send a peer ID; fall back to consumer session ID so bytes are still tracked + if peerID == "" { + peerID = consumerSessionID + } + c, err := websocket.Accept( w, r, @@ -131,12 +140,18 @@ func (l proxyListener) handleWebsocket(w http.ResponseWriter, r *http.Request) { return } + // Get or create the per-peer ingress byte counter + counterPtr := new(uint64) + actual, _ := peerIngressBytes.LoadOrStore(peerID, counterPtr) + wspconn := errorlessWebSocketPacketConn{ - w: c, - addr: common.DebugAddr(fmt.Sprintf("WebSocket connection %v", uuid.NewString())), - keepalive: websocketKeepalive, - tcpAddr: tcpAddr, - readError: make(chan error), + w: c, + addr: common.DebugAddr(fmt.Sprintf("WebSocket connection %v", uuid.NewString())), + keepalive: websocketKeepalive, + tcpAddr: tcpAddr, + readError: make(chan error), + peerID: peerID, + ingressBytes: actual.(*uint64), } defer wspconn.Close() @@ -246,15 +261,32 @@ func NewListener(ctx context.Context, ll net.Listener, tlsConfig *tls.Config) (n return nil, err } + nIngressBytesByPeerCounter, err = m.Int64ObservableUpDownCounter("ingress-bytes-by-peer") + if err != nil { + closeFuncMetric(ctx) + return nil, err + } + _, err = m.RegisterCallback( func(ctx context.Context, o metric.Observer) error { - b := atomic.LoadUint64(&nIngressBytes) - o.ObserveInt64(nIngressBytesCounter, int64(b)) - common.Debugf("Ingress bytes: %v", b) - atomic.StoreUint64(&nIngressBytes, uint64(0)) + var total int64 + peerIngressBytes.Range(func(key, value any) bool { + pid := key.(string) + ptr := value.(*uint64) + b := int64(atomic.SwapUint64(ptr, 0)) + if b > 0 { + o.ObserveInt64(nIngressBytesByPeerCounter, b, + metric.WithAttributes(attribute.String("peer_id", pid))) + } + total += b + return true + }) + o.ObserveInt64(nIngressBytesCounter, total) + common.Debugf("Ingress bytes: %v", total) return nil }, nIngressBytesCounter, + nIngressBytesByPeerCounter, ) if err != nil { closeFuncMetric(ctx) diff --git a/egress/websocket.go b/egress/websocket.go index 3b64de1..1ff639d 100644 --- a/egress/websocket.go +++ b/egress/websocket.go @@ -23,11 +23,13 @@ import ( // at some point in the future. Intercepted *read* errors are sent over the readError channel. // Currently, intercepted *write* errors are simply discarded. type errorlessWebSocketPacketConn struct { - w *websocket.Conn - addr net.Addr - keepalive time.Duration - tcpAddr *net.TCPAddr - readError chan error + w *websocket.Conn + addr net.Addr + keepalive time.Duration + tcpAddr *net.TCPAddr + readError chan error + peerID string + ingressBytes *uint64 // per-peer atomic counter, shared across all connections from the same peer } func (q errorlessWebSocketPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { @@ -84,7 +86,7 @@ func (q errorlessWebSocketPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, } copy(p, b) - atomic.AddUint64(&nIngressBytes, uint64(len(b))) + atomic.AddUint64(q.ingressBytes, uint64(len(b))) return len(b), q.tcpAddr, err }