diff --git a/distribution/control_broadcast.go b/distribution/control_broadcast.go new file mode 100644 index 0000000..df1464a --- /dev/null +++ b/distribution/control_broadcast.go @@ -0,0 +1,92 @@ +package distribution + +import ( + "context" + "sync" +) + +// viewerControlBuffer is the per-subscriber channel buffer size for control +// messages. Matches viewerCaptionBuffer since control messages are similarly +// low-frequency. +const viewerControlBuffer = 10 + +// ControlBroadcaster fans out messages from a single source channel to +// multiple subscriber channels. Each subscriber gets its own buffered +// channel; slow subscribers have messages dropped (non-blocking send) +// rather than blocking other subscribers or the source. +type ControlBroadcaster struct { + mu sync.RWMutex + subscribers map[string]chan []byte +} + +// NewControlBroadcaster creates a ControlBroadcaster ready for use. +func NewControlBroadcaster() *ControlBroadcaster { + return &ControlBroadcaster{ + subscribers: make(map[string]chan []byte), + } +} + +// Subscribe creates a per-subscriber buffered channel and returns it. +// The caller must call Unsubscribe when done. If a channel already +// exists for the given id, it is closed and replaced. +func (b *ControlBroadcaster) Subscribe(id string) <-chan []byte { + ch := make(chan []byte, viewerControlBuffer) + b.mu.Lock() + if old, ok := b.subscribers[id]; ok { + close(old) + } + b.subscribers[id] = ch + b.mu.Unlock() + return ch +} + +// Unsubscribe removes and closes the subscriber's channel. It is safe +// to call multiple times for the same id. +func (b *ControlBroadcaster) Unsubscribe(id string) { + b.mu.Lock() + ch, ok := b.subscribers[id] + if ok { + delete(b.subscribers, id) + close(ch) + } + b.mu.Unlock() +} + +// Run reads from the source channel and fans out each message to all +// subscribers. It blocks until ctx is cancelled or the source channel +// is closed. Non-blocking sends: if a subscriber's channel is full, +// the message is dropped for that subscriber (matching the Viewer +// drop pattern). +func (b *ControlBroadcaster) Run(ctx context.Context, source <-chan []byte) { + for { + select { + case <-ctx.Done(): + b.closeAll() + return + case data, ok := <-source: + if !ok { + b.closeAll() + return + } + b.mu.RLock() + for _, ch := range b.subscribers { + select { + case ch <- data: + default: + // subscriber is slow; drop message + } + } + b.mu.RUnlock() + } + } +} + +// closeAll closes and removes all subscriber channels. +func (b *ControlBroadcaster) closeAll() { + b.mu.Lock() + for id, ch := range b.subscribers { + close(ch) + delete(b.subscribers, id) + } + b.mu.Unlock() +} diff --git a/distribution/control_broadcast_test.go b/distribution/control_broadcast_test.go new file mode 100644 index 0000000..adcac88 --- /dev/null +++ b/distribution/control_broadcast_test.go @@ -0,0 +1,146 @@ +package distribution + +import ( + "context" + "testing" + "time" +) + +func TestControlBroadcasterFanOut(t *testing.T) { + t.Parallel() + b := NewControlBroadcaster() + source := make(chan []byte, 10) + go b.Run(context.Background(), source) + + ch1 := b.Subscribe("s1") + ch2 := b.Subscribe("s2") + + source <- []byte(`{"state":"live"}`) + + // Both subscribers should receive the same message. + for _, tc := range []struct { + name string + ch <-chan []byte + }{ + {"s1", ch1}, + {"s2", ch2}, + } { + select { + case data := <-tc.ch: + if string(data) != `{"state":"live"}` { + t.Fatalf("%s got %q", tc.name, data) + } + case <-time.After(time.Second): + t.Fatalf("%s timeout", tc.name) + } + } +} + +func TestControlBroadcasterUnsubscribe(t *testing.T) { + t.Parallel() + b := NewControlBroadcaster() + source := make(chan []byte, 10) + go b.Run(context.Background(), source) + + ch1 := b.Subscribe("s1") + _ = b.Subscribe("s2") + + b.Unsubscribe("s2") + + source <- []byte(`{"state":"off"}`) + + // s1 should still receive. + select { + case <-ch1: + case <-time.After(time.Second): + t.Fatal("s1 timeout after s2 unsubscribe") + } +} + +func TestControlBroadcasterUnsubscribeIdempotent(t *testing.T) { + t.Parallel() + b := NewControlBroadcaster() + source := make(chan []byte, 10) + go b.Run(context.Background(), source) + + b.Subscribe("s1") + b.Unsubscribe("s1") + b.Unsubscribe("s1") // second call must not panic +} + +func TestControlBroadcasterSourceClose(t *testing.T) { + t.Parallel() + b := NewControlBroadcaster() + source := make(chan []byte) + go b.Run(context.Background(), source) + + ch := b.Subscribe("s1") + close(source) + + // Subscriber channel should be closed. + select { + case _, ok := <-ch: + if ok { + t.Fatal("expected channel to be closed") + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for channel close") + } +} + +func TestControlBroadcasterContextCancel(t *testing.T) { + t.Parallel() + b := NewControlBroadcaster() + source := make(chan []byte) // unbuffered, never closed + + ctx, cancel := context.WithCancel(context.Background()) + go b.Run(ctx, source) + + ch := b.Subscribe("s1") + cancel() + + // Subscriber channel should be closed when context is cancelled. + select { + case _, ok := <-ch: + if ok { + t.Fatal("expected channel to be closed") + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for channel close after ctx cancel") + } +} + +func TestControlBroadcasterDropOnFull(t *testing.T) { + t.Parallel() + b := NewControlBroadcaster() + source := make(chan []byte, 100) + go b.Run(context.Background(), source) + + ch := b.Subscribe("s1") + + // Send more messages than the buffer can hold. + for i := 0; i < viewerControlBuffer+5; i++ { + source <- []byte(`{"i":"data"}`) + } + + // Allow broadcaster goroutine to process. + time.Sleep(50 * time.Millisecond) + + // Should have exactly viewerControlBuffer messages (extras dropped). + if len(ch) != viewerControlBuffer { + t.Fatalf("channel length = %d, want %d", len(ch), viewerControlBuffer) + } +} + +func TestControlBroadcasterNoSubscribers(t *testing.T) { + t.Parallel() + b := NewControlBroadcaster() + source := make(chan []byte, 10) + go b.Run(context.Background(), source) + + // Send with no subscribers — should not block or panic. + source <- []byte(`{"empty":"room"}`) + + // Allow processing, then close cleanly. + close(source) +} diff --git a/distribution/moq_catalog.go b/distribution/moq_catalog.go index 8358e0c..cab1f64 100644 --- a/distribution/moq_catalog.go +++ b/distribution/moq_catalog.go @@ -42,7 +42,9 @@ type moqSelectionParams struct { } // buildMoQCatalog assembles the catalog JSON for a stream. -func buildMoQCatalog(streamKey string, relay *Relay) ([]byte, error) { +// If controlEnabled is true, a "control" track is included in the catalog +// for application-level state broadcast (e.g., switcher control room state). +func buildMoQCatalog(streamKey string, relay *Relay, controlEnabled bool) ([]byte, error) { vi := relay.VideoInfo() ai := relay.AudioInfo() @@ -98,6 +100,16 @@ func buildMoQCatalog(streamKey string, relay *Relay) ([]byte, error) { }, }) + // Control track (application-level state broadcast as JSON) + if controlEnabled { + catalog.Tracks = append(catalog.Tracks, moqCatalogTrack{ + Name: "control", + SelectionParams: moqSelectionParams{ + Codec: "application/json", + }, + }) + } + return json.Marshal(catalog) } diff --git a/distribution/moq_catalog_test.go b/distribution/moq_catalog_test.go index 75fd8b3..dbca0cf 100644 --- a/distribution/moq_catalog_test.go +++ b/distribution/moq_catalog_test.go @@ -8,7 +8,7 @@ import ( func TestBuildMoQCatalogBasic(t *testing.T) { t.Parallel() relay := NewRelay() - data, err := buildMoQCatalog("teststream", relay) + data, err := buildMoQCatalog("teststream", relay, false) if err != nil { t.Fatal(err) } @@ -83,7 +83,7 @@ func TestBuildMoQCatalogMultiAudio(t *testing.T) { relay := NewRelay() relay.SetAudioTrackCount(3) - data, err := buildMoQCatalog("multi", relay) + data, err := buildMoQCatalog("multi", relay, false) if err != nil { t.Fatal(err) } @@ -114,7 +114,7 @@ func TestBuildMoQCatalogCustomVideoInfo(t *testing.T) { relay.videoInfoSet = true relay.mu.Unlock() - data, err := buildMoQCatalog("4k", relay) + data, err := buildMoQCatalog("4k", relay, false) if err != nil { t.Fatal(err) } @@ -136,7 +136,7 @@ func TestBuildMoQCatalogCustomVideoInfo(t *testing.T) { func TestBuildMoQCatalogJSONFieldNames(t *testing.T) { t.Parallel() relay := NewRelay() - data, err := buildMoQCatalog("test", relay) + data, err := buildMoQCatalog("test", relay, false) if err != nil { t.Fatal(err) } @@ -168,7 +168,7 @@ func TestBuildMoQCatalogCustomAudioInfo(t *testing.T) { relay := NewRelay() relay.SetAudioInfo(AudioInfo{Codec: "mp4a.40.05", SampleRate: 44100, Channels: 1}) - data, err := buildMoQCatalog("custom-audio", relay) + data, err := buildMoQCatalog("custom-audio", relay, false) if err != nil { t.Fatal(err) } @@ -189,3 +189,43 @@ func TestBuildMoQCatalogCustomAudioInfo(t *testing.T) { t.Fatalf("audio channelConfig = %q", ap.ChannelConfig) } } + +func TestBuildMoQCatalogControlTrack(t *testing.T) { + t.Parallel() + relay := NewRelay() + + // Without control enabled: 4 tracks (video + audio0 + captions + stats) + dataNoControl, err := buildMoQCatalog("test", relay, false) + if err != nil { + t.Fatal(err) + } + var catNoControl moqCatalog + if err := json.Unmarshal(dataNoControl, &catNoControl); err != nil { + t.Fatal(err) + } + if len(catNoControl.Tracks) != 4 { + t.Fatalf("without control: track count = %d, want 4", len(catNoControl.Tracks)) + } + + // With control enabled: 5 tracks (video + audio0 + captions + stats + control) + dataWithControl, err := buildMoQCatalog("test", relay, true) + if err != nil { + t.Fatal(err) + } + var catWithControl moqCatalog + if err := json.Unmarshal(dataWithControl, &catWithControl); err != nil { + t.Fatal(err) + } + if len(catWithControl.Tracks) != 5 { + t.Fatalf("with control: track count = %d, want 5", len(catWithControl.Tracks)) + } + + // Verify the control track is last and has the right codec + controlTrack := catWithControl.Tracks[4] + if controlTrack.Name != "control" { + t.Fatalf("control track name = %q, want %q", controlTrack.Name, "control") + } + if controlTrack.SelectionParams.Codec != "application/json" { + t.Fatalf("control track codec = %q, want %q", controlTrack.SelectionParams.Codec, "application/json") + } +} diff --git a/distribution/moq_session.go b/distribution/moq_session.go index 4dd4213..f7e5535 100644 --- a/distribution/moq_session.go +++ b/distribution/moq_session.go @@ -42,15 +42,16 @@ type StatsProviderFunc func(streamKey string) StatsProvider // interface so the Relay can fan out frames to it. Internally, it dispatches // frames to per-track subscriptions, each with its own write loop and moqWriter. type MoQSession struct { - id string - log *slog.Logger - streamKey string - session *webtransport.Session - control webtransport.Stream - controlReader *bufio.Reader // persistent buffered reader for control stream - relay *Relay - statsProvider StatsProviderFunc - controlMu sync.Mutex + id string + log *slog.Logger + streamKey string + session *webtransport.Session + control webtransport.Stream + controlReader *bufio.Reader // persistent buffered reader for control stream + relay *Relay + statsProvider StatsProviderFunc + controlBroadcaster *ControlBroadcaster + controlMu sync.Mutex mu sync.RWMutex subscriptions map[string]*moqTrackSub // key: trackName @@ -72,26 +73,28 @@ type MoQSession struct { // MoQSessionConfig holds the parameters for creating a new MoQ session. type MoQSessionConfig struct { - ID string - Session *webtransport.Session - Control webtransport.Stream - StreamKey string - Relay *Relay - StatsProvider StatsProviderFunc + ID string + Session *webtransport.Session + Control webtransport.Stream + StreamKey string + Relay *Relay + StatsProvider StatsProviderFunc + ControlBroadcaster *ControlBroadcaster } // NewMoQSession creates a new MoQ session for the given stream key. func NewMoQSession(cfg MoQSessionConfig) *MoQSession { return &MoQSession{ - id: cfg.ID, - log: slog.With("session", cfg.ID, "stream", cfg.StreamKey), - streamKey: cfg.StreamKey, - session: cfg.Session, - control: cfg.Control, - controlReader: bufio.NewReader(cfg.Control), - relay: cfg.Relay, - statsProvider: cfg.StatsProvider, - subscriptions: make(map[string]*moqTrackSub), + id: cfg.ID, + log: slog.With("session", cfg.ID, "stream", cfg.StreamKey), + streamKey: cfg.StreamKey, + session: cfg.Session, + control: cfg.Control, + controlReader: bufio.NewReader(cfg.Control), + relay: cfg.Relay, + statsProvider: cfg.StatsProvider, + controlBroadcaster: cfg.ControlBroadcaster, + subscriptions: make(map[string]*moqTrackSub), } } @@ -260,6 +263,9 @@ func (m *MoQSession) handleSubscribe(ctx context.Context, sub moq.Subscribe) { case "stats": m.handleStatsSubscribe(ctx, sub, alias) + case "control": + m.handleControlSubscribe(ctx, sub, alias) + default: // Check for audio tracks: "audio0", "audio1", etc. if suffix, ok := strings.CutPrefix(trackName, "audio"); ok { @@ -274,7 +280,7 @@ func (m *MoQSession) handleSubscribe(ctx context.Context, sub moq.Subscribe) { // handleCatalogSubscribe builds and delivers the catalog, then sends SUBSCRIBE_OK. func (m *MoQSession) handleCatalogSubscribe(ctx context.Context, sub moq.Subscribe, alias uint64) { - catalogJSON, err := buildMoQCatalog(m.streamKey, m.relay) + catalogJSON, err := buildMoQCatalog(m.streamKey, m.relay, m.controlBroadcaster != nil) if err != nil { m.sendSubscribeError(sub.RequestID, 500, "catalog build failed") return @@ -592,7 +598,7 @@ func (m *MoQSession) writeCaptionLoop(ctx context.Context, sub *moqTrackSub) { } data := frame.Serialize() - n, err := sub.writer.WriteCaptionFrame(stream, data, tsMS) + n, err := sub.writer.WriteDataObject(stream, data, tsMS) if err != nil { stream.Close() m.log.Debug("caption frame write failed", "error", err) @@ -679,7 +685,7 @@ func (m *MoQSession) writeStatsLoop(ctx context.Context, sub *moqTrackSub) { return } - n, err := sub.writer.WriteCaptionFrame(stream, data, tsMS) + n, err := sub.writer.WriteDataObject(stream, data, tsMS) if err != nil { stream.Close() m.log.Debug("stats write failed", "error", err) @@ -692,3 +698,87 @@ func (m *MoQSession) writeStatsLoop(ctx context.Context, sub *moqTrackSub) { } } } + +// handleControlSubscribe sets up the control track subscription and starts the write loop. +// The control track delivers application-level JSON state updates (e.g., switcher control +// room state) to connected browsers. It is only available when ControlCh is configured +// in ServerConfig. Each viewer gets its own channel from the ControlBroadcaster. +func (m *MoQSession) handleControlSubscribe(ctx context.Context, sub moq.Subscribe, alias uint64) { + if m.controlBroadcaster == nil { + m.sendSubscribeError(sub.RequestID, 404, "control track not available") + return + } + + subCtx, subCancel := context.WithCancel(ctx) + + // Subscribe to the broadcaster to get a per-session channel. + ch := m.controlBroadcaster.Subscribe(m.id) + + trackSub := &moqTrackSub{ + requestID: sub.RequestID, + trackAlias: alias, + trackName: "control", + writer: NewMoQWriter(alias, priorityControl), + cancel: func() { + subCancel() + m.controlBroadcaster.Unsubscribe(m.id) + }, + } + + m.mu.Lock() + m.subscriptions["control"] = trackSub + m.mu.Unlock() + + m.sendSubscribeOK(sub.RequestID, alias, moq.GroupOrderAscending, false, 0, 0) + + go m.writeControlLoop(subCtx, trackSub, ch) + + m.log.Debug("control track subscribed", + "alias", alias, + "requestID", sub.RequestID) +} + +// writeControlLoop reads JSON state snapshots from the control channel and sends +// each as a MoQ group on a new uni-stream, following the same pattern as writeStatsLoop. +func (m *MoQSession) writeControlLoop(ctx context.Context, sub *moqTrackSub, ch <-chan []byte) { + var groupID uint32 + + for { + select { + case <-ctx.Done(): + return + case data, ok := <-ch: + if !ok { + return + } + + if m.closed.Load() { + return + } + + stream, err := m.session.OpenUniStreamSync(ctx) + if err != nil { + m.log.Debug("control stream open failed", "error", err) + return + } + + tsMS := uint32(time.Now().UnixMilli()) + if err := sub.writer.WriteStreamHeader(stream, 0, groupID, tsMS); err != nil { + stream.Close() + m.log.Debug("control header write failed", "error", err) + return + } + + n, err := sub.writer.WriteDataObject(stream, data, tsMS) + if err != nil { + stream.Close() + m.log.Debug("control write failed", "error", err) + return + } + + m.bytesSent.Add(n + sub.writer.StreamHeaderSize()) + groupID++ + stream.Close() + } + } +} diff --git a/distribution/moq_session_test.go b/distribution/moq_session_test.go index 94278e2..9f1cbd9 100644 --- a/distribution/moq_session_test.go +++ b/distribution/moq_session_test.go @@ -644,6 +644,103 @@ func TestMoQSessionStats(t *testing.T) { } } +func TestMoQSessionHandleControlSubscribeNoBroadcaster(t *testing.T) { + t.Parallel() + relay := NewRelay() + responseBuf := &bytes.Buffer{} + controlStream := &mockControlStream{ + Reader: &bytes.Buffer{}, + Writer: responseBuf, + } + + session := &MoQSession{ + id: "test-session", + streamKey: "live", + control: controlStream, + log: slog.With("session", "test-session"), + relay: relay, + subscriptions: make(map[string]*moqTrackSub), + // controlBroadcaster is nil + } + + sub := moq.Subscribe{ + RequestID: 10, + Namespace: []string{"prism", "live"}, + TrackName: "control", + FilterType: moq.FilterNextGroupStart, + } + + session.handleSubscribe(context.Background(), sub) + + msgType, payload, err := moq.ReadControlMsg(responseBuf) + if err != nil { + t.Fatal(err) + } + if msgType != moq.MsgSubscribeError { + t.Fatalf("response type = %#x, want SUBSCRIBE_ERROR", msgType) + } + + reqID, off := readVarint(payload, 0) + errCode, _ := readVarint(payload, off) + if reqID != 10 { + t.Fatalf("requestID = %d, want 10", reqID) + } + if errCode != 404 { + t.Fatalf("errorCode = %d, want 404", errCode) + } +} + +func TestMoQSessionHandleControlSubscribeWithBroadcaster(t *testing.T) { + t.Parallel() + relay := NewRelay() + responseBuf := &bytes.Buffer{} + controlStream := &mockControlStream{ + Reader: &bytes.Buffer{}, + Writer: responseBuf, + } + + broadcaster := NewControlBroadcaster() + source := make(chan []byte, 10) + go broadcaster.Run(context.Background(), source) + defer close(source) + + session := &MoQSession{ + id: "test-session", + streamKey: "live", + control: controlStream, + log: slog.With("session", "test-session"), + relay: relay, + controlBroadcaster: broadcaster, + subscriptions: make(map[string]*moqTrackSub), + } + + sub := moq.Subscribe{ + RequestID: 11, + Namespace: []string{"prism", "live"}, + TrackName: "control", + FilterType: moq.FilterNextGroupStart, + } + + session.handleSubscribe(context.Background(), sub) + + // Should get SUBSCRIBE_OK + msgType, _, err := moq.ReadControlMsg(responseBuf) + if err != nil { + t.Fatal(err) + } + if msgType != moq.MsgSubscribeOK { + t.Fatalf("response type = %#x, want SUBSCRIBE_OK", msgType) + } + + // Verify subscription was created + session.mu.RLock() + controlSub := session.subscriptions["control"] + session.mu.RUnlock() + if controlSub == nil { + t.Fatal("control subscription not created") + } +} + // mockControlStream implements webtransport.Stream for test purposes. // It uses separate Reader/Writer to simulate the control stream. type mockControlStream struct { diff --git a/distribution/moq_writer.go b/distribution/moq_writer.go index 5120cbf..71cc942 100644 --- a/distribution/moq_writer.go +++ b/distribution/moq_writer.go @@ -115,7 +115,7 @@ func (m *moqWriter) WriteAudioFrame(w io.Writer, data []byte, timestampMS uint32 return m.writeObject(w, exts, payload) } -func (m *moqWriter) WriteCaptionFrame(w io.Writer, data []byte, timestampMS uint32) (int64, error) { +func (m *moqWriter) WriteDataObject(w io.Writer, data []byte, timestampMS uint32) (int64, error) { var exts []byte exts = quicvarint.Append(exts, locExtCaptureTimestamp) exts = quicvarint.Append(exts, uint64(timestampMS)*1000) diff --git a/distribution/moq_writer_test.go b/distribution/moq_writer_test.go index 4bbff73..19ba36e 100644 --- a/distribution/moq_writer_test.go +++ b/distribution/moq_writer_test.go @@ -351,7 +351,7 @@ func TestMoQWriterAudioFrame(t *testing.T) { } } -func TestMoQWriterCaptionFrame(t *testing.T) { +func TestMoQWriterDataObject(t *testing.T) { t.Parallel() w := NewMoQWriter(10, 200) var buf bytes.Buffer @@ -363,9 +363,9 @@ func TestMoQWriterCaptionFrame(t *testing.T) { captionData := []byte{0xCC, 0x02, 0x01, 0x02, 0x03} - n, err := w.WriteCaptionFrame(&buf, captionData, 3000) + n, err := w.WriteDataObject(&buf, captionData, 3000) if err != nil { - t.Fatalf("WriteCaptionFrame failed: %v", err) + t.Fatalf("WriteDataObject failed: %v", err) } if n != int64(buf.Len()) { @@ -479,12 +479,12 @@ func TestMoQWriterBytesWritten(t *testing.T) { } buf.Reset() - n, err = w.WriteCaptionFrame(&buf, []byte{0x01, 0x02, 0x03}, 200) + n, err = w.WriteDataObject(&buf, []byte{0x01, 0x02, 0x03}, 200) if err != nil { - t.Fatalf("WriteCaptionFrame failed: %v", err) + t.Fatalf("WriteDataObject failed: %v", err) } if n != int64(buf.Len()) { - t.Errorf("WriteCaptionFrame: returned %d, buffer has %d", n, buf.Len()) + t.Errorf("WriteDataObject: returned %d, buffer has %d", n, buf.Len()) } } diff --git a/distribution/protocol.go b/distribution/protocol.go index 106db06..767cf58 100644 --- a/distribution/protocol.go +++ b/distribution/protocol.go @@ -31,6 +31,7 @@ const ( priorityVideo = 128 priorityAudio = 128 priorityCaptions = 200 + priorityControl = 210 priorityStats = 220 ) @@ -55,9 +56,10 @@ type StreamFrameWriter interface { // returning the total bytes written. WriteAudioFrame(w io.Writer, data []byte, timestampMS uint32) (int64, error) - // WriteCaptionFrame writes a single caption frame (header + payload) to w, - // returning the total bytes written. - WriteCaptionFrame(w io.Writer, data []byte, timestampMS uint32) (int64, error) + // WriteDataObject writes a single data object (caption, stats JSON, + // control JSON, etc.) with header + payload to w, returning the total + // bytes written. + WriteDataObject(w io.Writer, data []byte, timestampMS uint32) (int64, error) // StreamHeaderSize returns the byte size of the stream header written // by WriteStreamHeader, used for accurate byte accounting. diff --git a/distribution/server.go b/distribution/server.go index b55bba7..6113c94 100644 --- a/distribution/server.go +++ b/distribution/server.go @@ -132,6 +132,24 @@ type ServerConfig struct { SRTStop SRTStopFunc SRTList SRTListFunc ExtraRoutes func(mux *http.ServeMux) + + // OnStreamRegistered is called after a new stream relay is created + // and added to the server's stream map. It is NOT called when + // RegisterStream returns an existing relay for a duplicate key. + // The callback is invoked outside the server's mutex. + OnStreamRegistered func(key string, relay *Relay) + + // OnStreamUnregistered is called after a stream is removed from + // the server's stream map. It is NOT called if the stream key + // was not present. The callback is invoked outside the server's mutex. + OnStreamUnregistered func(key string) + + // ControlCh receives JSON-encoded control state. If set, a "control" + // track is advertised in the MoQ catalog and subscribers receive state + // updates as JSON objects. Each send produces one MoQ group. + // Messages are internally broadcast to all connected viewers via + // ControlBroadcaster. + ControlCh <-chan []byte } // streamResources bundles the relay and stats provider for a single live @@ -150,6 +168,8 @@ type Server struct { mu sync.RWMutex streams map[string]*streamResources + + controlBroadcaster *ControlBroadcaster // nil if ControlCh not configured } // NewServer creates a distribution Server with the given configuration. @@ -161,30 +181,53 @@ func NewServer(config ServerConfig) (*Server, error) { if config.Addr == "" { return nil, errors.New("distribution: Addr is required") } - return &Server{ + s := &Server{ config: config, streams: make(map[string]*streamResources), - }, nil + } + if config.ControlCh != nil { + s.controlBroadcaster = NewControlBroadcaster() + } + return s, nil } // RegisterStream creates a Relay for the given stream key and returns it. // If the stream already has a relay, the existing one is returned. +// For new streams, OnStreamRegistered is called (if set) after releasing +// the lock. Concurrent calls with the same key are safe (only one creates +// a relay), but the callback may observe transient inconsistency if a +// concurrent UnregisterStream for the same key interleaves between the +// lock release and the callback invocation. func (s *Server) RegisterStream(streamKey string) *Relay { s.mu.Lock() - defer s.mu.Unlock() if sr, ok := s.streams[streamKey]; ok { + s.mu.Unlock() return sr.relay } r := NewRelay() s.streams[streamKey] = &streamResources{relay: r} + s.mu.Unlock() + + if s.config.OnStreamRegistered != nil { + s.config.OnStreamRegistered(streamKey, r) + } return r } // UnregisterStream removes the relay and pipeline for a stream key. +// If the stream existed, OnStreamUnregistered is called (if set) after +// releasing the lock. If a concurrent RegisterStream for the same key +// races with this call, the callback may fire after a new relay has +// already been registered. func (s *Server) UnregisterStream(streamKey string) { s.mu.Lock() - defer s.mu.Unlock() + _, existed := s.streams[streamKey] delete(s.streams, streamKey) + s.mu.Unlock() + + if existed && s.config.OnStreamUnregistered != nil { + s.config.OnStreamUnregistered(streamKey) + } } // SetPipeline associates a StatsProvider with a stream key. The stream @@ -301,6 +344,10 @@ func (s *Server) Start(ctx context.Context) error { }, } + if s.controlBroadcaster != nil { + go s.controlBroadcaster.Run(ctx, s.config.ControlCh) + } + slog.Info("WebTransport server listening", "addr", s.config.Addr) stop := context.AfterFunc(ctx, func() { s.wtSrv.Close() }) @@ -384,12 +431,13 @@ func (s *Server) setupMoQ(r *http.Request, session *webtransport.Session, contro } moqSession := NewMoQSession(MoQSessionConfig{ - ID: fmt.Sprintf("moq-%s-%s", streamKey, r.RemoteAddr), - Session: session, - Control: controlStream, - StreamKey: streamKey, - Relay: relay, - StatsProvider: s.GetPipeline, + ID: fmt.Sprintf("moq-%s-%s", streamKey, r.RemoteAddr), + Session: session, + Control: controlStream, + StreamKey: streamKey, + Relay: relay, + StatsProvider: s.GetPipeline, + ControlBroadcaster: s.controlBroadcaster, }) pathKey, err := moqSession.handleSetup() diff --git a/distribution/server_test.go b/distribution/server_test.go index 96e7c6d..b5469d7 100644 --- a/distribution/server_test.go +++ b/distribution/server_test.go @@ -292,3 +292,120 @@ func TestNewServerValidation(t *testing.T) { } }) } + +func TestStreamLifecycleCallbacks(t *testing.T) { + t.Parallel() + + cert, err := certs.Generate(24 * 60 * 60 * 1e9) + if err != nil { + t.Fatalf("certs.Generate: %v", err) + } + + t.Run("OnStreamRegistered fires on new stream", func(t *testing.T) { + t.Parallel() + + var gotKey string + var gotRelay *Relay + srv, err := NewServer(ServerConfig{ + Addr: ":0", + Cert: cert, + OnStreamRegistered: func(key string, relay *Relay) { + gotKey = key + gotRelay = relay + }, + }) + if err != nil { + t.Fatalf("NewServer: %v", err) + } + + relay := srv.RegisterStream("cam1") + if gotKey != "cam1" { + t.Fatalf("callback key = %q, want %q", gotKey, "cam1") + } + if gotRelay != relay { + t.Fatal("callback relay does not match returned relay") + } + }) + + t.Run("OnStreamRegistered does NOT fire on duplicate", func(t *testing.T) { + t.Parallel() + + callCount := 0 + srv, err := NewServer(ServerConfig{ + Addr: ":0", + Cert: cert, + OnStreamRegistered: func(key string, relay *Relay) { + callCount++ + }, + }) + if err != nil { + t.Fatalf("NewServer: %v", err) + } + + srv.RegisterStream("cam1") + srv.RegisterStream("cam1") // duplicate + if callCount != 1 { + t.Fatalf("callback called %d times, want 1", callCount) + } + }) + + t.Run("OnStreamUnregistered fires on removal", func(t *testing.T) { + t.Parallel() + + var gotKey string + srv, err := NewServer(ServerConfig{ + Addr: ":0", + Cert: cert, + OnStreamUnregistered: func(key string) { + gotKey = key + }, + }) + if err != nil { + t.Fatalf("NewServer: %v", err) + } + + srv.RegisterStream("cam1") + srv.UnregisterStream("cam1") + if gotKey != "cam1" { + t.Fatalf("callback key = %q, want %q", gotKey, "cam1") + } + }) + + t.Run("OnStreamUnregistered does NOT fire if stream missing", func(t *testing.T) { + t.Parallel() + + called := false + srv, err := NewServer(ServerConfig{ + Addr: ":0", + Cert: cert, + OnStreamUnregistered: func(key string) { + called = true + }, + }) + if err != nil { + t.Fatalf("NewServer: %v", err) + } + + srv.UnregisterStream("nonexistent") + if called { + t.Fatal("callback should not fire for nonexistent stream") + } + }) + + t.Run("nil callbacks do not panic", func(t *testing.T) { + t.Parallel() + + srv, err := NewServer(ServerConfig{ + Addr: ":0", + Cert: cert, + // No callbacks set — should not panic. + }) + if err != nil { + t.Fatalf("NewServer: %v", err) + } + + srv.RegisterStream("cam1") + srv.UnregisterStream("cam1") + // If we get here without panicking, the test passes. + }) +} diff --git a/web/src/main.ts b/web/src/main.ts index 09a2e6c..3efea3b 100644 --- a/web/src/main.ts +++ b/web/src/main.ts @@ -38,6 +38,7 @@ let currentMode: "single" | "multi" = "single"; let singlePlayer: PrismPlayer | null = null; let multiview: MultiviewManager | null = null; let cachedStreams: StreamListEntry[] = []; +let lastSingleStreamKey: string | null = null; function hideEmptyState(): void { const el = document.getElementById("emptyState"); @@ -66,6 +67,7 @@ function initSingleMode(): void { onStreamDisconnected: (key) => { statusEl.textContent = `Disconnected from "${key}". Reconnecting...`; connectBtn.textContent = "Watch"; + connectBtn.disabled = false; }, }); singlePlayer.setMaxResolution(cap); @@ -111,8 +113,17 @@ function switchMode(mode: "single" | "multi"): void { } initSingleMode(); - statusEl.textContent = "Enter a stream key and click Watch to connect."; - connectBtn.textContent = "Watch"; + + if (lastSingleStreamKey) { + hideEmptyState(); + connectBtn.disabled = true; + statusEl.textContent = `Connecting to "${lastSingleStreamKey}"...`; + singlePlayer!.connect(lastSingleStreamKey); + } else { + statusEl.textContent = "Enter a stream key and click Watch to connect."; + connectBtn.textContent = "Watch"; + connectBtn.disabled = false; + } } else { singleModeEl.style.display = "none"; multiModeEl.style.display = "block"; @@ -141,6 +152,7 @@ connectBtn.addEventListener("click", () => { if (singlePlayer?.isConnected()) { singlePlayer.disconnect(); + lastSingleStreamKey = null; statusEl.textContent = "Disconnected."; connectBtn.textContent = "Watch"; connectBtn.disabled = false; @@ -157,6 +169,7 @@ connectBtn.addEventListener("click", () => { hideEmptyState(); connectBtn.disabled = true; statusEl.textContent = `Connecting to "${streamKey}"...`; + lastSingleStreamKey = streamKey; singlePlayer!.connect(streamKey); }); @@ -230,6 +243,7 @@ srtPullConnect.addEventListener("click", async () => { if (currentMode === "single") { streamKeyInput.value = streamKey; + lastSingleStreamKey = streamKey; initSingleMode(); singlePlayer!.connect(streamKey); statusEl.textContent = `Connected to SRT pull from ${address}`; @@ -302,6 +316,7 @@ async function fetchStreams(): Promise { tag.addEventListener("click", () => { if (currentMode === "single") { streamKeyInput.value = stream.key; + lastSingleStreamKey = stream.key; initSingleMode(); hideEmptyState(); singlePlayer!.connect(stream.key); @@ -354,6 +369,7 @@ function showClickToStart(onStart: () => void, target: HTMLElement): void { function showClickToPlay(streamKey: string): void { hideEmptyState(); showClickToStart(() => { + lastSingleStreamKey = streamKey; singlePlayer!.connect(streamKey); connectBtn.disabled = true; statusEl.textContent = `Connecting to "${streamKey}"...`; diff --git a/web/src/player.ts b/web/src/player.ts index e90f940..02805fa 100644 --- a/web/src/player.ts +++ b/web/src/player.ts @@ -98,6 +98,7 @@ export class PrismPlayer { private destroyed = false; private globalMute: boolean; private reconnectDelay = 2000; + private resumeAudioHandler: (() => void) | null = null; constructor(container: HTMLElement, options: PlayerOptions = {}) { this.container = container; @@ -141,14 +142,16 @@ export class PrismPlayer { container.appendChild(this.statsEl); container.appendChild(this.captionsEl); - // Resume suspended AudioContext on first user gesture (browser autoplay policy). - const resumeAudio = () => { + // Resume suspended AudioContext on user gesture (browser autoplay policy). + // Not { once: true } — the context is recreated on each connection, so we + // need this to keep working across disconnect/reconnect cycles. + this.resumeAudioHandler = () => { if (this.sharedAudioContext && this.sharedAudioContext.state === "suspended") { this.sharedAudioContext.resume(); } }; - document.addEventListener("click", resumeAudio, { once: true }); - document.addEventListener("keydown", resumeAudio, { once: true }); + document.addEventListener("click", this.resumeAudioHandler); + document.addEventListener("keydown", this.resumeAudioHandler); this.playerUI = new PlayerUI({ container: this.container, @@ -488,6 +491,11 @@ export class PrismPlayer { this.inspector?.destroy(); if (this.fullscreenBtn) this.fullscreenBtn.destroy(); this.playerUI.destroy(); + if (this.resumeAudioHandler) { + document.removeEventListener("click", this.resumeAudioHandler); + document.removeEventListener("keydown", this.resumeAudioHandler); + this.resumeAudioHandler = null; + } this.container.innerHTML = ""; } diff --git a/web/src/renderer.ts b/web/src/renderer.ts index 1c7316e..5363ae3 100644 --- a/web/src/renderer.ts +++ b/web/src/renderer.ts @@ -351,5 +351,14 @@ export class PrismRenderer { this.lastDrawnFrame.close(); this.lastDrawnFrame = null; } + // Reset timing state so the next start() begins fresh rather + // than pacing against stale PTS values from a prior session. + this.currentVideoPTS = -1; + this.currentAudioPTS = -1; + this.freeRunStart = -1; + this.freeRunBasePTS = -1; + this.lastAudioAdvanceTime = 0; + this.audioStallFreeRunStart = -1; + this.audioStallFreeRunBasePTS = -1; } }