Skip to content

Commit 9155272

Browse files
committed
drpcmanager: replace streamBuffer with streamRegistry
Replace the single-stream streamBuffer with a stream registry that maps stream IDs to stream objects. The registry currently holds at most one active stream (two briefly during handoff), but provides the foundation for stream multiplexing where callers will look up streams by ID directly.
1 parent eff22e7 commit 9155272

4 files changed

Lines changed: 264 additions & 82 deletions

File tree

drpcmanager/active_streams.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// Copyright (C) 2026 Cockroach Labs.
2+
// See LICENSE for copying information.
3+
4+
package drpcmanager
5+
6+
import (
7+
"sync"
8+
9+
"storj.io/drpc/drpcstream"
10+
)
11+
12+
// activeStreams is a thread-safe map of stream IDs to stream objects.
13+
// It is used by the Manager to track active streams for lifecycle management.
14+
type activeStreams struct {
15+
mu sync.RWMutex
16+
streams map[uint64]*drpcstream.Stream
17+
closed bool
18+
}
19+
20+
func newActiveStreams() *activeStreams {
21+
return &activeStreams{
22+
streams: make(map[uint64]*drpcstream.Stream),
23+
}
24+
}
25+
26+
// Add adds a stream. It returns an error if the collection is closed or if a
27+
// stream with the same ID already exists.
28+
func (r *activeStreams) Add(id uint64, stream *drpcstream.Stream) error {
29+
if stream == nil {
30+
return managerClosed.New("stream can't be nil")
31+
}
32+
33+
r.mu.Lock()
34+
defer r.mu.Unlock()
35+
36+
if r.closed {
37+
return managerClosed.New("add to closed collection")
38+
}
39+
if _, ok := r.streams[id]; ok {
40+
return managerClosed.New("duplicate stream id")
41+
}
42+
r.streams[id] = stream
43+
return nil
44+
}
45+
46+
// Remove removes a stream. It is a no-op if the stream is not present or if
47+
// the collection has been closed.
48+
func (r *activeStreams) Remove(id uint64) {
49+
r.mu.Lock()
50+
defer r.mu.Unlock()
51+
52+
if r.streams != nil {
53+
delete(r.streams, id)
54+
}
55+
}
56+
57+
// Get returns the stream for the given ID and whether it was found.
58+
func (r *activeStreams) Get(id uint64) (*drpcstream.Stream, bool) {
59+
r.mu.RLock()
60+
defer r.mu.RUnlock()
61+
62+
s, ok := r.streams[id]
63+
return s, ok
64+
}
65+
66+
// GetLatest returns the stream with the highest ID, or nil if empty.
67+
func (r *activeStreams) GetLatest() *drpcstream.Stream {
68+
r.mu.RLock()
69+
defer r.mu.RUnlock()
70+
71+
var latest *drpcstream.Stream
72+
for _, s := range r.streams {
73+
if latest == nil || latest.ID() < s.ID() {
74+
latest = s
75+
}
76+
}
77+
return latest
78+
}
79+
80+
// Close marks the collection as closed, preventing future Add calls.
81+
// It does not cancel any streams.
82+
func (r *activeStreams) Close() {
83+
r.mu.Lock()
84+
defer r.mu.Unlock()
85+
86+
r.closed = true
87+
}
88+
89+
// ForEach calls fn for each active stream. The collection is read-locked
90+
// during iteration.
91+
func (r *activeStreams) ForEach(fn func(*drpcstream.Stream)) {
92+
r.mu.RLock()
93+
defer r.mu.RUnlock()
94+
95+
for _, s := range r.streams {
96+
fn(s)
97+
}
98+
}
99+
100+
// Len returns the number of active streams.
101+
func (r *activeStreams) Len() int {
102+
r.mu.RLock()
103+
defer r.mu.RUnlock()
104+
105+
return len(r.streams)
106+
}

drpcmanager/active_streams_test.go

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
// Copyright (C) 2026 Cockroach Labs.
2+
// See LICENSE for copying information.
3+
4+
package drpcmanager
5+
6+
import (
7+
"context"
8+
"testing"
9+
10+
"github.com/zeebo/assert"
11+
12+
"storj.io/drpc/drpcstream"
13+
"storj.io/drpc/drpcwire"
14+
)
15+
16+
func testStream(id uint64) *drpcstream.Stream {
17+
return drpcstream.New(context.Background(), id, &drpcwire.Writer{})
18+
}
19+
20+
func TestActiveStreams_AddAndGet(t *testing.T) {
21+
streams := newActiveStreams()
22+
s := testStream(1)
23+
24+
assert.NoError(t, streams.Add(1, s))
25+
26+
got, ok := streams.Get(1)
27+
assert.That(t, ok)
28+
assert.Equal(t, got, s)
29+
}
30+
31+
func TestActiveStreams_GetMissing(t *testing.T) {
32+
streams := newActiveStreams()
33+
34+
got, ok := streams.Get(42)
35+
assert.That(t, !ok)
36+
assert.Nil(t, got)
37+
}
38+
39+
func TestActiveStreams_Remove(t *testing.T) {
40+
streams := newActiveStreams()
41+
s := testStream(1)
42+
43+
assert.NoError(t, streams.Add(1, s))
44+
assert.Equal(t, streams.Len(), 1)
45+
46+
streams.Remove(1)
47+
48+
_, ok := streams.Get(1)
49+
assert.That(t, !ok)
50+
assert.Equal(t, streams.Len(), 0)
51+
}
52+
53+
func TestActiveStreams_RemoveIdempotent(t *testing.T) {
54+
streams := newActiveStreams()
55+
56+
// must not panic when removing a non-existent ID
57+
streams.Remove(99)
58+
}
59+
60+
func TestActiveStreams_DuplicateAdd(t *testing.T) {
61+
streams := newActiveStreams()
62+
s1 := testStream(1)
63+
s2 := testStream(1)
64+
65+
assert.NoError(t, streams.Add(1, s1))
66+
assert.Error(t, streams.Add(1, s2))
67+
68+
// original stream is still present
69+
got, ok := streams.Get(1)
70+
assert.That(t, ok)
71+
assert.Equal(t, got, s1)
72+
}
73+
74+
func TestActiveStreams_AddAfterClose(t *testing.T) {
75+
streams := newActiveStreams()
76+
streams.Close()
77+
78+
err := streams.Add(1, testStream(1))
79+
assert.Error(t, err)
80+
}
81+
82+
func TestActiveStreams_RemoveAfterClose(t *testing.T) {
83+
streams := newActiveStreams()
84+
s := testStream(1)
85+
assert.NoError(t, streams.Add(1, s))
86+
87+
streams.Close()
88+
89+
// must not panic
90+
streams.Remove(1)
91+
}
92+
93+
func TestActiveStreams_Len(t *testing.T) {
94+
streams := newActiveStreams()
95+
assert.Equal(t, streams.Len(), 0)
96+
97+
assert.NoError(t, streams.Add(1, testStream(1)))
98+
assert.Equal(t, streams.Len(), 1)
99+
100+
assert.NoError(t, streams.Add(2, testStream(2)))
101+
assert.Equal(t, streams.Len(), 2)
102+
103+
streams.Remove(1)
104+
assert.Equal(t, streams.Len(), 1)
105+
}
106+
107+
func TestActiveStreams_ForEach(t *testing.T) {
108+
streams := newActiveStreams()
109+
s1 := testStream(1)
110+
s2 := testStream(2)
111+
s3 := testStream(3)
112+
113+
assert.NoError(t, streams.Add(1, s1))
114+
assert.NoError(t, streams.Add(2, s2))
115+
assert.NoError(t, streams.Add(3, s3))
116+
117+
seen := make(map[uint64]*drpcstream.Stream)
118+
streams.ForEach(func(s *drpcstream.Stream) {
119+
seen[s.ID()] = s
120+
})
121+
122+
assert.Equal(t, len(seen), 3)
123+
assert.Equal(t, seen[1], s1)
124+
assert.Equal(t, seen[2], s2)
125+
assert.Equal(t, seen[3], s3)
126+
}
127+
128+
func TestActiveStreams_ForEach_Empty(t *testing.T) {
129+
streams := newActiveStreams()
130+
131+
count := 0
132+
streams.ForEach(func(_ *drpcstream.Stream) { count++ })
133+
assert.Equal(t, count, 0)
134+
}

drpcmanager/manager.go

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"net"
1212
"strings"
1313
"sync"
14+
"sync/atomic"
1415
"syscall"
1516
"time"
1617

@@ -81,10 +82,14 @@ type Manager struct {
8182

8283
wg sync.WaitGroup // tracks active manageStream goroutines
8384

84-
sem drpcsignal.Chan // held by the active stream
85-
sbuf streamBuffer // largest stream id created
85+
// streams tracks active streams. Currently holds at most one active stream;
86+
// a second may briefly coexist during stream handoff (old stream's Remove
87+
// races with new stream's Add).
88+
streams *activeStreams
8689

87-
pdone drpcsignal.Chan // signals when NewServerStream has registered the new stream
90+
sem drpcsignal.Chan // held by the active stream
91+
92+
pdone drpcsignal.Chan // signals when NewServerStream has added the new stream
8893
invokes chan invokeInfo // completed invoke info from manageReader to NewServerStream
8994

9095
// Below fields are owned by the manageReader goroutine, used in handleInvokeFrame.
@@ -123,9 +128,6 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Manager {
123128
invokes: make(chan invokeInfo),
124129
}
125130

126-
// initialize the stream buffer
127-
m.sbuf.init()
128-
129131
// this semaphore controls the number of concurrent streams. it MUST be 1.
130132
m.sem.Make(1)
131133

@@ -134,6 +136,7 @@ func NewWithOptions(tr drpc.Transport, opts Options) *Manager {
134136
m.pdone.Make(1)
135137

136138
m.pa = drpcwire.NewPacketAssembler()
139+
m.streams = newActiveStreams()
137140

138141
// set the internal stream options
139142
drpcopts.SetStreamTransport(&m.opts.Stream.Internal, m.tr)
@@ -186,7 +189,7 @@ func (m *Manager) acquireSemaphore(ctx context.Context) error {
186189
// longer make any reads or writes on the transport. It exits early if the
187190
// context is canceled or the manager is terminated.
188191
func (m *Manager) waitForPreviousStream(ctx context.Context) (err error) {
189-
prev := m.sbuf.Get()
192+
prev := m.streams.GetLatest()
190193
if prev == nil {
191194
return nil
192195
}
@@ -217,7 +220,7 @@ func (m *Manager) terminate(err error) {
217220
if m.sigs.term.Set(err) {
218221
m.log("TERM", func() string { return fmt.Sprint(err) })
219222
m.sigs.tport.Set(m.tr.Close())
220-
m.sbuf.Close()
223+
m.streams.Close()
221224
}
222225
}
223226

@@ -249,7 +252,7 @@ func (m *Manager) manageReader() {
249252
return
250253
}
251254

252-
switch curr := m.sbuf.Get(); {
255+
switch curr := m.streams.GetLatest(); {
253256
// If the frame is for the current stream, deliver it.
254257
case curr != nil && incomingFrame.ID.Stream == curr.ID():
255258
if err := curr.HandleFrame(incomingFrame); err != nil {
@@ -272,14 +275,7 @@ func (m *Manager) manageReader() {
272275
}
273276

274277
default:
275-
// A non-invoke frame arrived for a stream that doesn't exist yet
276-
// (curr is nil or incomingFrame.ID.Stream > curr.ID). The first
277-
// frame of a new stream must be KindInvoke or KindInvokeMetadata.
278-
m.terminate(managerClosed.Wrap(drpc.ProtocolError.New(
279-
"first frame of a new stream must be Invoke, got %v (ID:%v)",
280-
incomingFrame.Kind,
281-
incomingFrame.ID)))
282-
return
278+
m.log("DROP", incomingFrame.String)
283279
}
284280
}
285281
}
@@ -319,9 +315,9 @@ func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error {
319315
// Invoke packet completes the sequence. Send to NewServerStream.
320316
select {
321317
case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: m.metadata}:
322-
// Wait for NewServerStream to finish stream creation (including
323-
// sbuf.Set) before reading the next frame. This guarantees curr
324-
// is set for subsequent non-invoke packets.
318+
// Wait for NewServerStream to finish stream creation before reading the
319+
// next frame. This guarantees curr is set for subsequent non-invoke
320+
// packets.
325321
m.pdone.Recv()
326322

327323
m.pa.Reset()
@@ -346,10 +342,13 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKin
346342

347343
stream := drpcstream.NewWithOptions(ctx, sid, m.wr, opts)
348344

345+
if err := m.streams.Add(sid, stream); err != nil {
346+
return nil, err
347+
}
348+
349349
m.wg.Add(1)
350350
go m.manageStream(ctx, stream)
351351

352-
m.sbuf.Set(stream)
353352
m.log("STREAM", stream.String)
354353

355354
return stream, nil
@@ -359,6 +358,7 @@ func (m *Manager) newStream(ctx context.Context, sid uint64, kind drpc.StreamKin
359358
// is finished, canceling the stream if the context is canceled.
360359
func (m *Manager) manageStream(ctx context.Context, stream *drpcstream.Stream) {
361360
defer m.wg.Done()
361+
defer m.streams.Remove(stream.ID())
362362
select {
363363
case <-m.sigs.term.Signal():
364364
err := m.sigs.term.Err()
@@ -429,7 +429,7 @@ func (m *Manager) Closed() <-chan struct{} {
429429
// the return result is only valid until the next call to NewClientStream or
430430
// NewServerStream.
431431
func (m *Manager) Unblocked() <-chan struct{} {
432-
if prev := m.sbuf.Get(); prev != nil {
432+
if prev := m.streams.GetLatest(); prev != nil {
433433
return prev.Context().Done()
434434
}
435435
return closedCh
@@ -506,9 +506,8 @@ func (m *Manager) NewServerStream(ctx context.Context) (stream *drpcstream.Strea
506506
}
507507
}
508508
stream, err := m.newStream(ctx, pkt.sid, drpc.StreamKindServer, rpc)
509-
// Signal pdone only after stream registration so that
510-
// manageReader sees the new stream via sbuf.Get() when it reads
511-
// the next frame.
509+
// Signal pdone only after adding the stream so that manageReader sees
510+
// the new stream in activeStreams when it reads the next frame.
512511
m.pdone.Send()
513512
return stream, rpc, err
514513
}

0 commit comments

Comments
 (0)