1 change: 1 addition & 0 deletions drpcconn/conn.go
@@ -89,6 +89,7 @@ func (c *Conn) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, ou
}

func (c *Conn) doInvoke(stream *drpcstream.Stream, enc drpc.Encoding, rpc string, data []byte, metadata []byte, out drpc.Message) (err error) {
defer func() { err = stream.CheckCancelError(err) }()
if err := stream.WriteInvoke(rpc, metadata); err != nil {
return err
}
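
The new defer is the whole change here: because err is a named return, the deferred CheckCancelError call can rewrite the error on the way out of doInvoke. A minimal, self-contained sketch of the pattern; fakeStream and its CheckCancelError are stand-ins, assuming the method prefers the stream's cancellation cause over a generic transport error:

package main

import (
	"context"
	"errors"
	"fmt"
)

// fakeStream stands in for *drpcstream.Stream. CheckCancelError here mimics
// the assumed behavior: once the stream is cancelled, surface the
// cancellation cause instead of a less useful transport error.
type fakeStream struct{ cancelErr error }

func (s *fakeStream) CheckCancelError(err error) error {
	if s.cancelErr != nil && err != nil {
		return s.cancelErr
	}
	return err
}

func doInvoke(s *fakeStream) (err error) {
	// Same shape as the patched drpcconn.doInvoke: the deferred call can
	// replace err because it is a named return value.
	defer func() { err = s.CheckCancelError(err) }()
	return errors.New("use of closed connection")
}

func main() {
	s := &fakeStream{cancelErr: context.Canceled}
	fmt.Println(doInvoke(s)) // prints "context canceled"
}
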
26 changes: 16 additions & 10 deletions drpcconn/conn_test.go
@@ -40,31 +40,32 @@ func TestConn_InvokeFlushesSendClose(t *testing.T) {
invokeDone := make(chan struct{})

ctx.Run(func(ctx context.Context) {
wr := drpcwire.NewWriter(ps, 64)
wr := drpcwire.NewMuxWriter(ps, nil)
defer func() { wr.Stop(nil); <-wr.Done() }()
rd := drpcwire.NewReader(ps)

_, _ = rd.ReadFrame() // Invoke
_, _ = rd.ReadFrame() // Message
pkt, _ := rd.ReadFrame() // CloseSend

_ = wr.WritePacket(drpcwire.Packet{
_ = wr.WriteFrame(drpcwire.Frame{
Data: []byte("qux"),
ID: drpcwire.ID{Stream: pkt.ID.Stream, Message: 1},
Kind: drpcwire.KindMessage,
Done: true,
})
_ = wr.Flush()

_, _ = rd.ReadFrame() // Close
<-invokeDone // wait for invoke to return

// ensure that any later packets are dropped by writing one
// before closing the transport.
for i := 0; i < 5; i++ {
_ = wr.WritePacket(drpcwire.Packet{
for range 5 {
_ = wr.WriteFrame(drpcwire.Frame{
ID: drpcwire.ID{Stream: pkt.ID.Stream, Message: 2},
Kind: drpcwire.KindCloseSend,
Done: true,
})
_ = wr.Flush()
}

_ = ps.Close()
@@ -78,7 +79,7 @@ func TestConn_InvokeFlushesSendClose(t *testing.T) {

invokeDone <- struct{}{} // signal invoke has returned

// we should eventually notice the transport is closed
// we should eventually notice the transport is closed due to ps.Close()
select {
case <-conn.Closed():
case <-time.After(1 * time.Second):
@@ -95,7 +96,8 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) {
defer func() { assert.NoError(t, ps.Close()) }()

ctx.Run(func(ctx context.Context) {
wr := drpcwire.NewWriter(ps, 64)
wr := drpcwire.NewMuxWriter(ps, nil)
defer func() { wr.Stop(nil); <-wr.Done() }()
rd := drpcwire.NewReader(ps)

md, err := rd.ReadFrame() // Metadata
@@ -114,12 +116,12 @@ func TestConn_InvokeSendsGrpcAndDrpcMetadata(t *testing.T) {
_, _ = rd.ReadFrame() // Message
pkt, _ := rd.ReadFrame() // CloseSend

_ = wr.WritePacket(drpcwire.Packet{
_ = wr.WriteFrame(drpcwire.Frame{
Data: []byte("qux"),
ID: drpcwire.ID{Stream: pkt.ID.Stream, Message: 1},
Kind: drpcwire.KindMessage,
Done: true,
})
_ = wr.Flush()

_, _ = rd.ReadFrame() // Close
})
@@ -181,6 +183,10 @@ func TestConn_NewStreamSendsGrpcAndDrpcMetadata(t *testing.T) {
s, err := conn.NewStream(ctx, "/com.example.Foo/Bar", testEncoding{})
assert.NoError(t, err)
_ = s.CloseSend()

// Wait for the server goroutine to read all frames before defers
// close the pipe. With MuxWriter, writes are asynchronous.
ctx.Wait()
}

func TestConn_encodeMetadata(t *testing.T) {
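
Every updated test follows the same writer lifecycle: construct the MuxWriter, defer Stop plus a wait on Done, and write frames without explicit flushes. A condensed sketch using only the calls visible in this diff (NewMuxWriter, WriteFrame, Stop, Done; the exact signatures are inferred from the changed lines):

package main

import (
	"io"

	"storj.io/drpc/drpcwire"
)

func main() {
	// NewMuxWriter takes a destination and an error callback (nil here,
	// matching the tests above). Writes happen on a background goroutine.
	wr := drpcwire.NewMuxWriter(io.Discard, nil)
	// Stop rejects further writes; Done reports when the drain goroutine
	// has exited, so teardown is deterministic.
	defer func() { wr.Stop(nil); <-wr.Done() }()

	_ = wr.WriteFrame(drpcwire.Frame{
		Data: []byte("qux"),
		ID:   drpcwire.ID{Stream: 1, Message: 1},
		Kind: drpcwire.KindMessage,
		Done: true,
	})
}

The dropped Flush calls and the new ctx.Wait() fall out of the same property: the MuxWriter writes from a background goroutine, so tests must synchronize on the reader side instead of flushing on the writer side.
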
10 changes: 6 additions & 4 deletions drpcmanager/active_streams.go
@@ -12,9 +12,10 @@ import (
// activeStreams is a thread-safe map of stream IDs to stream objects.
// It is used by the Manager to track active streams for lifecycle management.
type activeStreams struct {
mu sync.RWMutex
streams map[uint64]*drpcstream.Stream
closed bool
mu sync.RWMutex
streams map[uint64]*drpcstream.Stream
closed bool
closeErr error
}

func newActiveStreams() *activeStreams {
@@ -34,7 +35,7 @@ func (r *activeStreams) Add(id uint64, stream *drpcstream.Stream) error {
defer r.mu.Unlock()

if r.closed {
return managerClosed.New("add to closed collection")
return r.closeErr
}
if _, ok := r.streams[id]; ok {
return managerClosed.New("duplicate stream id")
@@ -73,6 +74,7 @@ func (r *activeStreams) Close(err error) {
defer r.mu.Unlock()

r.closed = true
r.closeErr = err
for id, s := range r.streams {
s.Cancel(err)
delete(r.streams, id)
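
The new closeErr field changes what a late Add reports: callers now see the reason the collection was closed rather than the generic "add to closed collection" message. A toy illustration of the same behavior (registry is a stand-in, not the real activeStreams, which also holds a mutex and cancels streams):

package main

import (
	"errors"
	"fmt"
)

type registry struct {
	closed   bool
	closeErr error
	items    map[uint64]bool
}

func (r *registry) Add(id uint64) error {
	if r.closed {
		return r.closeErr // report why we closed, not just that we did
	}
	r.items[id] = true
	return nil
}

func (r *registry) Close(err error) {
	r.closed = true
	r.closeErr = err
	clear(r.items)
}

func main() {
	r := &registry{items: map[uint64]bool{}}
	r.Close(errors.New("transport broken: connection reset"))
	fmt.Println(r.Add(1)) // transport broken: connection reset
}
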
28 changes: 17 additions & 11 deletions drpcmanager/active_streams_test.go
@@ -6,6 +6,7 @@ package drpcmanager
import (
"context"
"errors"
"io"
"testing"

"github.com/zeebo/assert"
@@ -14,13 +15,19 @@ import (
"storj.io/drpc/drpcwire"
)

func testStream(id uint64) *drpcstream.Stream {
return drpcstream.New(context.Background(), id, &drpcwire.Writer{})
func testMuxWriter(t *testing.T) *drpcwire.MuxWriter {
mw := drpcwire.NewMuxWriter(io.Discard, func(error) {})
t.Cleanup(func() { mw.Stop(nil); <-mw.Done() })
return mw
}

func testStream(t *testing.T, id uint64) *drpcstream.Stream {
return drpcstream.New(context.Background(), id, testMuxWriter(t))
}

func TestActiveStreams_AddAndGet(t *testing.T) {
streams := newActiveStreams()
s := testStream(1)
s := testStream(t, 1)

assert.NoError(t, streams.Add(1, s))

@@ -39,7 +46,7 @@ func TestActiveStreams_GetMissing(t *testing.T) {

func TestActiveStreams_Remove(t *testing.T) {
streams := newActiveStreams()
s := testStream(1)
s := testStream(t, 1)

assert.NoError(t, streams.Add(1, s))
assert.Equal(t, streams.Len(), 1)
@@ -60,8 +67,8 @@ func TestActiveStreams_RemoveIdempotent(t *testing.T) {

func TestActiveStreams_DuplicateAdd(t *testing.T) {
streams := newActiveStreams()
s1 := testStream(1)
s2 := testStream(1)
s1 := testStream(t, 1)
s2 := testStream(t, 1)

assert.NoError(t, streams.Add(1, s1))
assert.Error(t, streams.Add(1, s2))
@@ -76,13 +83,13 @@ func TestActiveStreams_AddAfterClose(t *testing.T) {
streams := newActiveStreams()
streams.Close(errors.New("closed"))

err := streams.Add(1, testStream(1))
err := streams.Add(1, testStream(t, 1))
assert.Error(t, err)
}

func TestActiveStreams_RemoveAfterClose(t *testing.T) {
streams := newActiveStreams()
s := testStream(1)
s := testStream(t, 1)
assert.NoError(t, streams.Add(1, s))

streams.Close(errors.New("closed"))
@@ -95,13 +102,12 @@ func TestActiveStreams_Len(t *testing.T) {
streams := newActiveStreams()
assert.Equal(t, streams.Len(), 0)

assert.NoError(t, streams.Add(1, testStream(1)))
assert.NoError(t, streams.Add(1, testStream(t, 1)))
assert.Equal(t, streams.Len(), 1)

assert.NoError(t, streams.Add(2, testStream(2)))
assert.NoError(t, streams.Add(2, testStream(t, 2)))
assert.Equal(t, streams.Len(), 2)

streams.Remove(1)
assert.Equal(t, streams.Len(), 1)
}

64 changes: 38 additions & 26 deletions drpcmanager/manager.go
@@ -30,10 +30,6 @@ var managerClosed = errs.Class("manager closed")

// Options controls configuration settings for a manager.
type Options struct {
// WriterBufferSize controls the size of the buffer that we will fill before
// flushing. Normal writes to streams typically issue a flush explicitly.
WriterBufferSize int

// Reader options are passed to any readers the manager creates.
Reader drpcwire.ReaderOptions

@@ -55,7 +51,7 @@
// to the appropriate stream.
type Manager struct {
tr drpc.Transport
wr *drpcwire.Writer
wr *drpcwire.MuxWriter
rd *drpcwire.Reader
opts Options

@@ -70,9 +66,10 @@ type Manager struct {
pdone drpcsignal.Chan // signals when NewServerStream has registered the new stream
invokes chan invokeInfo // completed invoke info from manageReader to NewServerStream

// invokesAssembler is owned by the manageReader goroutine, used in
// handleInvokeFrame.
invokesAssembler map[uint64]*invokeAssembler
// pendingStreams is owned by the manageReader goroutine, used in
// handleInvokeFrame. It tracks streams that are being assembled from
// invoke/metadata frames but haven't been fully created yet.
pendingStreams map[uint64]*pendingStream

sigs struct {
term drpcsignal.Signal // set when the manager should start terminating
@@ -91,7 +88,10 @@
Server
)

type invokeAssembler struct {
// pendingStream accumulates invoke and metadata frames for a stream that is
// being set up but hasn't been fully created yet. Once the invoke packet
// arrives, the pending stream is forwarded to NewServerStream.
type pendingStream struct {
metadata map[string]string // accumulated invoke metadata
pa drpcwire.PacketAssembler // assembles invoke/metadata frames into packets
}
@@ -114,19 +114,20 @@ func New(tr drpc.Transport, kind ManagerKind) *Manager {
func NewWithOptions(tr drpc.Transport, kind ManagerKind, opts Options) *Manager {
m := &Manager{
tr: tr,
wr: drpcwire.NewWriter(tr, opts.WriterBufferSize),
rd: drpcwire.NewReaderWithOptions(tr, opts.Reader),
opts: opts,

invokes: make(chan invokeInfo),
kind: kind,
}

m.wr = drpcwire.NewMuxWriter(tr, m.terminate)

// a buffer of size 1 allows NewServerStream to signal it is done creating a
// new server stream without having to coordinate with manageReader.
m.pdone.Make(1)

m.invokesAssembler = make(map[uint64]*invokeAssembler)
m.pendingStreams = make(map[uint64]*pendingStream)

m.streams = newActiveStreams()

@@ -148,17 +149,21 @@ func (m *Manager) log(what string, cb func() string) {
}

// terminate puts the Manager into a terminal state and closes any resources
// that need to be closed to signal the state change.
// that need to be closed to signal the state change. The mux writer is stopped
// before closing the transport so that WriteFrame immediately rejects new
// writes; the subsequent transport close unblocks any in-flight Write in the
// drain goroutine.
func (m *Manager) terminate(err error) {
if m.sigs.term.Set(err) {
m.log("TERM", func() string { return fmt.Sprint(err) })
m.sigs.tport.Set(m.tr.Close())
if errors.Is(err, io.EOF) {
err = context.Canceled
if m.kind == Client {
err = drpc.ClosedError.New("connection closed")
}
}
m.wr.Stop(err)
m.sigs.tport.Set(m.tr.Close())
m.streams.Close(err)
}
}
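
The behavioral change in terminate is ordering. Previously the transport was closed first; now the writer is stopped, then the transport is closed, then the streams are cancelled, as the updated doc comment explains. A schematic of that sequence with stand-in interfaces rather than the real drpc types:

package main

import "errors"

type writer interface{ Stop(error) }
type transport interface{ Close() error }
type streamSet interface{ Close(error) }

// terminateOrder mirrors the sequence in the patched Manager.terminate.
func terminateOrder(wr writer, tr transport, streams streamSet, err error) {
	wr.Stop(err)       // 1. new WriteFrame calls now fail fast
	_ = tr.Close()     // 2. unblocks an in-flight Write in the drain goroutine
	streams.Close(err) // 3. every active stream is cancelled with the cause
}

type nopWriter struct{}

func (nopWriter) Stop(error) {}

type nopTransport struct{}

func (nopTransport) Close() error { return nil }

type nopStreams struct{}

func (nopStreams) Close(error) {}

func main() {
	terminateOrder(nopWriter{}, nopTransport{}, nopStreams{}, errors.New("shutdown"))
}
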
@@ -190,7 +195,7 @@ func (m *Manager) manageReader() {

switch {
// if the packet is for an active stream, deliver it.
case ok && stream != nil:
case ok:
if err := stream.HandleFrame(incomingFrame); err != nil {
m.terminate(managerClosed.Wrap(err))
return
@@ -210,15 +215,15 @@
}

// handleInvokeFrame assembles invoke/metadata frames into complete packets and
// forwards the finished invoke info to NewServerStream via m.newServerStreamInfo.
// Metadata packets are accumulated; the invoke packet triggers the send.
// forwards the finished invoke info to NewServerStream. Metadata packets are
// accumulated; the invoke packet triggers the send.
func (m *Manager) handleInvokeFrame(fr drpcwire.Frame) error {
ia, ok := m.invokesAssembler[fr.ID.Stream]
ps, ok := m.pendingStreams[fr.ID.Stream]
if !ok {
ia = &invokeAssembler{pa: drpcwire.NewPacketAssembler()}
m.invokesAssembler[fr.ID.Stream] = ia
ps = &pendingStream{pa: drpcwire.NewPacketAssembler()}
m.pendingStreams[fr.ID.Stream] = ps
}
pkt, packetReady, err := ia.pa.AppendFrame(fr)
pkt, packetReady, err := ps.pa.AppendFrame(fr)
if err != nil {
return err
}
@@ -232,19 +237,19 @@
if err != nil {
return err
}
ia.metadata = meta
ps.metadata = meta
return nil
}

// Invoke packet completes the sequence. Send to NewServerStream.
select {
case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: ia.metadata}:
case m.invokes <- invokeInfo{sid: pkt.ID.Stream, data: pkt.Data, metadata: ps.metadata}:
// Wait for NewServerStream to finish stream creation before reading the
// next frame. This guarantees curr is set for subsequent non-invoke
// packets.
m.pdone.Recv()
// TODO: reuse invoke assembler
delete(m.invokesAssembler, fr.ID.Stream)
// TODO: reuse pending stream
delete(m.pendingStreams, fr.ID.Stream)
case <-m.sigs.term.Signal():
}
return nil
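
handleInvokeFrame is a small per-stream state machine: metadata packets accumulate on the pendingStream, the invoke packet publishes the result on m.invokes, and the buffered pdone channel makes manageReader wait until NewServerStream has registered the stream. A schematic of that handshake using plain channels, independent of the drpc types:

package main

import "fmt"

type invokeInfo struct {
	sid      uint64
	metadata map[string]string
}

func main() {
	invokes := make(chan invokeInfo)
	pdone := make(chan struct{}, 1) // size 1: signal without coordinating

	// Stand-in for NewServerStream: registers the stream, then signals.
	go func() {
		info := <-invokes
		fmt.Println("registered stream", info.sid, info.metadata)
		pdone <- struct{}{}
	}()

	// Stand-in for the manageReader side of handleInvokeFrame: publish the
	// assembled invoke, then wait so the next frame sees a registered stream.
	invokes <- invokeInfo{sid: 1, metadata: map[string]string{"k": "v"}}
	<-pdone
	fmt.Println("safe to read the next frame")
}
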
@@ -309,6 +314,9 @@ func (m *Manager) Closed() <-chan struct{} {
// Unblocked returns a channel that is closed when the manager is no longer
// blocked. With multiplexing, multiple streams run concurrently and this
// channel is always closed immediately.
//
// TODO(shubham): audit whether this is still useful in a multiplexing world.
// The only meaningful caller is Pool.Take.
func (m *Manager) Unblocked() <-chan struct{} {
return closedCh
}
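
closedCh itself is outside this diff, but the comment implies the standard idiom: a package-level channel closed once at init, so every receive completes immediately. A sketch of that idiom:

package main

import "fmt"

// A channel closed at package init always selects immediately, which makes
// "never blocked" cheap to represent.
var closedCh = func() chan struct{} {
	ch := make(chan struct{})
	close(ch)
	return ch
}()

func main() {
	select {
	case <-closedCh:
		fmt.Println("unblocked immediately")
	}
}
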
@@ -317,15 +325,19 @@ func (m *Manager) Unblocked() <-chan struct{} {
func (m *Manager) Close() error {
m.terminate(managerClosed.New("Close called"))

m.wg.Wait() // wait for all stream goroutines
m.sigs.read.Wait()
<-m.wr.Done() // wait for writer goroutine to exit
m.wg.Wait() // wait for all stream goroutines
m.sigs.read.Wait() // wait for reader goroutine to exit
m.sigs.tport.Wait()

return m.sigs.tport.Err()
}

// NewClientStream starts a stream on the managed transport for use by a client.
func (m *Manager) NewClientStream(ctx context.Context, rpc string) (stream *drpcstream.Stream, err error) {
if err := ctx.Err(); err != nil {
return nil, err
}
return m.newStream(ctx, m.lastStreamID.Add(1), drpc.StreamKindClient, rpc)
}
