diff --git a/.gitignore b/.gitignore
index 0710f85..1e825cf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -31,6 +31,8 @@ test/harness/*
*.swp
*~
+docs/plans/*
+
# Environment
.env
.env.*
diff --git a/cmd/prism/main.go b/cmd/prism/main.go
index e570090..1591501 100644
--- a/cmd/prism/main.go
+++ b/cmd/prism/main.go
@@ -19,6 +19,7 @@ import (
"github.com/zsiec/prism/certs"
"github.com/zsiec/prism/distribution"
"github.com/zsiec/prism/ingest"
+ dashingest "github.com/zsiec/prism/ingest/dash"
srtingest "github.com/zsiec/prism/ingest/srt"
"github.com/zsiec/prism/pipeline"
"github.com/zsiec/prism/stream"
@@ -80,6 +81,7 @@ func main() {
a.handleNewStream(ctx, key, input, format)
})
a.srtCaller = srtingest.NewCaller(a.registry, nil)
+ a.dashPuller = dashingest.NewPuller(nil)
var distErr error
a.distSrv, distErr = distribution.NewServer(distribution.ServerConfig{
@@ -96,7 +98,19 @@ func main() {
SRTStop: func(streamKey string) error {
return a.srtCaller.Stop(streamKey)
},
- SRTList: a.listSRTPulls,
+ SRTList: a.listSRTPulls,
+ DASHPull: func(url, streamKey, videoRepID, audioRepID string) error {
+ return a.dashPuller.Pull(ctx, dashingest.PullRequest{
+ URL: url,
+ StreamKey: streamKey,
+ VideoRepID: videoRepID,
+ AudioRepID: audioRepID,
+ }, a.distSrv, &streamManagerAdapter{a.mgr})
+ },
+ DASHStop: func(streamKey string) error {
+ return a.dashPuller.Stop(streamKey)
+ },
+ DASHList: a.listDASHPulls,
StreamLister: a.listStreams,
IngestLookup: a.lookupIngest,
})
@@ -138,6 +152,18 @@ func main() {
return a.distSrv.Start(ctx)
})
+ if dashURL := os.Getenv("DASH_SOURCE"); dashURL != "" {
+ g.Go(func() error {
+ <-time.After(500 * time.Millisecond)
+ return a.dashPuller.Pull(ctx, dashingest.PullRequest{
+ URL: dashURL,
+ StreamKey: envOr("DASH_STREAM_KEY", "dash-live"),
+ VideoRepID: os.Getenv("DASH_VIDEO_REP"),
+ AudioRepID: os.Getenv("DASH_AUDIO_REP"),
+ }, a.distSrv, &streamManagerAdapter{a.mgr})
+ })
+ }
+
if err := g.Wait(); err != nil {
slog.Error("server error", "error", err)
os.Exit(1)
@@ -145,10 +171,11 @@ func main() {
}
type app struct {
- mgr *stream.Manager
- registry *ingest.Registry
- srtCaller *srtingest.Caller
- distSrv *distribution.Server
+ mgr *stream.Manager
+ registry *ingest.Registry
+ srtCaller *srtingest.Caller
+ dashPuller *dashingest.Puller
+ distSrv *distribution.Server
}
func (a *app) listSRTPulls() []distribution.SRTPullInfo {
@@ -164,6 +191,20 @@ func (a *app) listSRTPulls() []distribution.SRTPullInfo {
return out
}
+func (a *app) listDASHPulls() []distribution.DASHPullInfo {
+ pulls := a.dashPuller.ActivePulls()
+ out := make([]distribution.DASHPullInfo, len(pulls))
+ for i, p := range pulls {
+ out[i] = distribution.DASHPullInfo{
+ URL: p.URL,
+ StreamKey: p.StreamKey,
+ VideoRepID: p.VideoRepID,
+ AudioRepID: p.AudioRepID,
+ }
+ }
+ return out
+}
+
func (a *app) listStreams() []distribution.StreamInfo {
streams := a.mgr.List()
infos := make([]distribution.StreamInfo, len(streams))
@@ -281,3 +322,18 @@ func buildStreamDescription(info distribution.StreamInfo) string {
return strings.Join(parts, " · ")
}
+
+// streamManagerAdapter adapts *stream.Manager to the dash.StreamManager
+// interface, which uses interface{} in the Create return to avoid an import
+// cycle with the stream package.
+type streamManagerAdapter struct {
+ mgr *stream.Manager
+}
+
+func (a *streamManagerAdapter) Create(key string) (interface{}, bool) {
+ return a.mgr.Create(key)
+}
+
+func (a *streamManagerAdapter) Remove(key string) {
+ a.mgr.Remove(key)
+}
diff --git a/demux/av1.go b/demux/av1.go
new file mode 100644
index 0000000..7374e13
--- /dev/null
+++ b/demux/av1.go
@@ -0,0 +1,518 @@
+package demux
+
+import (
+ "errors"
+ "fmt"
+)
+
+// AV1 OBU type constants as defined in AV1 spec §6.2.2.
+const (
+ OBUSequenceHeader = 1
+ OBUTemporalDelimiter = 2
+ OBUFrameHeader = 3
+ OBUFrame = 6
+)
+
+// AV1 frame type constants as defined in AV1 spec §6.8.2.
+const (
+ av1FrameKey = 0
+ av1FrameInter = 1
+ av1FrameIntraOnly = 2
+ av1FrameSwitch = 3
+)
+
+var (
+ errAV1TooShort = errors.New("AV1 data too short")
+ errAV1InvalidOBU = errors.New("invalid AV1 OBU header")
+ errAV1LEB128 = errors.New("invalid LEB128 encoding")
+ errAV1SeqHdrParse = errors.New("failed to parse AV1 sequence header")
+ errAV1ForbiddenBit = errors.New("AV1 OBU forbidden bit set")
+)
+
+// AV1SequenceHeader holds parameters extracted from an AV1 Sequence Header OBU.
+type AV1SequenceHeader struct {
+ SeqProfile int
+ SeqLevelIdx int
+ SeqTier int // 0=Main, 1=High
+ BitDepth int
+ Width int
+ Height int
+ ChromaSubsamplingX int
+ ChromaSubsamplingY int
+ Monochrome bool
+}
+
+// CodecString returns the RFC 6381 codec parameter string for AV1.
+// Format: "av01.P.LLT.DD" where P=profile, LL=zero-padded level,
+// T=tier(M/H), DD=zero-padded bit depth.
+func (s AV1SequenceHeader) CodecString() string {
+ tier := "M"
+ if s.SeqTier == 1 {
+ tier = "H"
+ }
+ return fmt.Sprintf("av01.%d.%02d%s.%02d", s.SeqProfile, s.SeqLevelIdx, tier, s.BitDepth)
+}
+
+// obuHeader represents a parsed AV1 OBU header.
+type obuHeader struct {
+ obuType int
+ hasExtension bool
+ hasSizeField bool
+ temporalID int
+ spatialID int
+ headerLen int // total header bytes (1 or 2)
+ payloadSize int // only valid when hasSizeField is true
+ totalSize int // headerLen + leb128 size bytes + payloadSize
+}
+
+// parseOBUHeader parses an OBU header from the given data. Returns the
+// parsed header and total bytes consumed (header + optional size field).
+func parseOBUHeader(data []byte) (obuHeader, error) {
+ if len(data) < 1 {
+ return obuHeader{}, errAV1TooShort
+ }
+
+ b := data[0]
+ if b&0x80 != 0 {
+ return obuHeader{}, errAV1ForbiddenBit
+ }
+
+ h := obuHeader{
+ obuType: int((b >> 3) & 0x0F),
+ hasExtension: b&0x04 != 0,
+ hasSizeField: b&0x02 != 0,
+ headerLen: 1,
+ }
+
+ pos := 1
+ if h.hasExtension {
+ if len(data) < 2 {
+ return obuHeader{}, errAV1TooShort
+ }
+ ext := data[1]
+ h.temporalID = int((ext >> 5) & 0x07)
+ h.spatialID = int((ext >> 3) & 0x03)
+ h.headerLen = 2
+ pos = 2
+ }
+
+ if h.hasSizeField {
+ size, n, err := readLEB128(data[pos:])
+ if err != nil {
+ return obuHeader{}, err
+ }
+ h.payloadSize = size
+ h.totalSize = pos + n + size
+ } else {
+ // Without size field, OBU extends to end of temporal unit.
+ h.payloadSize = len(data) - pos
+ h.totalSize = len(data)
+ }
+
+ return h, nil
+}
+
+// readLEB128 reads a LEB128 encoded unsigned integer from data.
+// Returns the value, number of bytes consumed, and any error.
+func readLEB128(data []byte) (int, int, error) {
+ val := 0
+ for i := 0; i < len(data) && i < 8; i++ {
+ b := data[i]
+ val |= int(b&0x7F) << (i * 7)
+ if b&0x80 == 0 {
+ return val, i + 1, nil
+ }
+ }
+ return 0, 0, errAV1LEB128
+}
+
+// av1BitReader is a MSB-first bit reader that tracks errors internally.
+// After an overflow, all subsequent reads return zero and the error is
+// preserved in the err field.
+type av1BitReader struct {
+ data []byte
+ pos int
+ bit int
+ err error
+}
+
+func newAV1BitReader(data []byte) *av1BitReader {
+ return &av1BitReader{data: data}
+}
+
+func (br *av1BitReader) readBits(n int) uint {
+ if br.err != nil {
+ return 0
+ }
+ var val uint
+ for i := 0; i < n; i++ {
+ if br.pos >= len(br.data) {
+ br.err = errAV1TooShort
+ return 0
+ }
+ val = (val << 1) | uint((br.data[br.pos]>>(7-br.bit))&1)
+ br.bit++
+ if br.bit == 8 {
+ br.bit = 0
+ br.pos++
+ }
+ }
+ return val
+}
+
+// readUvlc reads an unsigned variable-length code (AV1 spec §4.10.3).
+func (br *av1BitReader) readUvlc() uint {
+ if br.err != nil {
+ return 0
+ }
+ leadingZeros := 0
+ for {
+ b := br.readBits(1)
+ if br.err != nil {
+ return 0
+ }
+ if b == 1 {
+ break
+ }
+ leadingZeros++
+ if leadingZeros >= 32 {
+ br.err = errAV1SeqHdrParse
+ return 0
+ }
+ }
+ if leadingZeros == 0 {
+ return 0
+ }
+ val := br.readBits(leadingZeros)
+ return val + (1 << leadingZeros) - 1
+}
+
+// ParseAV1SequenceHeader parses an AV1 Sequence Header OBU (including OBU
+// header bytes) and extracts codec parameters. The input must start with
+// the OBU header byte.
+func ParseAV1SequenceHeader(data []byte) (*AV1SequenceHeader, error) {
+ if len(data) < 2 {
+ return nil, errAV1TooShort
+ }
+
+ h, err := parseOBUHeader(data)
+ if err != nil {
+ return nil, err
+ }
+
+ if h.obuType != OBUSequenceHeader {
+ return nil, fmt.Errorf("expected OBU_SEQUENCE_HEADER (type %d), got type %d", OBUSequenceHeader, h.obuType)
+ }
+
+ // Payload starts after header + LEB128 size
+ payloadStart := h.totalSize - h.payloadSize
+ if payloadStart >= len(data) || payloadStart+h.payloadSize > len(data) {
+ return nil, errAV1TooShort
+ }
+ payload := data[payloadStart : payloadStart+h.payloadSize]
+
+ return parseSequenceHeaderPayload(payload)
+}
+
+// parseSequenceHeaderPayload parses the raw payload of a sequence_header_obu.
+func parseSequenceHeaderPayload(payload []byte) (*AV1SequenceHeader, error) {
+ br := newAV1BitReader(payload)
+ hdr := &AV1SequenceHeader{}
+
+ hdr.SeqProfile = int(br.readBits(3))
+ stillPicture := br.readBits(1)
+ reducedStillPicture := br.readBits(1)
+ _ = stillPicture
+
+ var decoderModelInfoPresent bool
+ var bufferDelayLengthMinus1 uint
+
+ if reducedStillPicture == 1 {
+ hdr.SeqLevelIdx = int(br.readBits(5))
+ hdr.SeqTier = 0
+ } else {
+ timingInfoPresent := br.readBits(1)
+
+ if timingInfoPresent == 1 {
+ // timing_info()
+ br.readBits(32) // num_units_in_display_tick
+ br.readBits(32) // time_scale
+ equalPictureInterval := br.readBits(1)
+ if equalPictureInterval == 1 {
+ br.readUvlc() // num_ticks_per_picture_minus_1
+ }
+
+ dmi := br.readBits(1)
+ decoderModelInfoPresent = dmi == 1
+ if decoderModelInfoPresent {
+ // decoder_model_info()
+ bufferDelayLengthMinus1 = br.readBits(5)
+ br.readBits(32) // num_units_in_decoding_tick
+ br.readBits(5) // buffer_removal_time_length_minus_1
+ br.readBits(5) // frame_presentation_time_length_minus_1
+ }
+ }
+
+ initialDisplayDelayPresent := br.readBits(1)
+ operatingPointsCntMinus1 := br.readBits(5)
+
+ for i := uint(0); i <= operatingPointsCntMinus1; i++ {
+ br.readBits(12) // operating_point_idc
+ levelIdx := br.readBits(5)
+ var tier uint
+ if levelIdx > 7 {
+ tier = br.readBits(1)
+ }
+ if i == 0 {
+ hdr.SeqLevelIdx = int(levelIdx)
+ hdr.SeqTier = int(tier)
+ }
+ if decoderModelInfoPresent {
+ decoderModelPresentForThisOp := br.readBits(1)
+ if decoderModelPresentForThisOp == 1 {
+ // operating_parameters_info(): skip decoder/encoder buffer delays and flag
+ n := int(bufferDelayLengthMinus1) + 1
+ br.readBits(n) // decoder_buffer_delay[i]
+ br.readBits(n) // encoder_buffer_delay[i]
+ br.readBits(1) // low_delay_mode_flag[i]
+ }
+ }
+ if initialDisplayDelayPresent == 1 {
+ flag := br.readBits(1)
+ if flag == 1 {
+ br.readBits(4) // initial_display_delay_minus_1
+ }
+ }
+ }
+ }
+
+ if br.err != nil {
+ return nil, fmt.Errorf("%w: %v", errAV1SeqHdrParse, br.err)
+ }
+
+ frameWidthBitsMinus1 := br.readBits(4)
+ frameHeightBitsMinus1 := br.readBits(4)
+ maxFrameWidthMinus1 := br.readBits(int(frameWidthBitsMinus1) + 1)
+ maxFrameHeightMinus1 := br.readBits(int(frameHeightBitsMinus1) + 1)
+ hdr.Width = int(maxFrameWidthMinus1) + 1
+ hdr.Height = int(maxFrameHeightMinus1) + 1
+
+ if reducedStillPicture == 0 {
+ frameIDNumbersPresent := br.readBits(1)
+ if frameIDNumbersPresent == 1 {
+ br.readBits(4) // delta_frame_id_length_minus_2
+ br.readBits(3) // additional_frame_id_length_minus_1
+ }
+ }
+
+ br.readBits(1) // use_128x128_superblock
+ br.readBits(1) // enable_filter_intra
+ br.readBits(1) // enable_intra_edge_filter
+
+ if reducedStillPicture == 0 {
+ br.readBits(1) // enable_interintra_compound
+ br.readBits(1) // enable_masked_compound
+ br.readBits(1) // enable_warped_motion
+ br.readBits(1) // enable_dual_filter
+ enableOrderHint := br.readBits(1)
+
+ if enableOrderHint == 1 {
+ br.readBits(1) // enable_jnt_comp
+ br.readBits(1) // enable_ref_frame_mvs
+ }
+
+ seqChooseScreenContentTools := br.readBits(1)
+ var seqForceScreenContentTools uint
+ if seqChooseScreenContentTools == 0 {
+ seqForceScreenContentTools = br.readBits(1)
+ } else {
+ seqForceScreenContentTools = 2 // SELECT_SCREEN_CONTENT_TOOLS
+ }
+
+ if seqForceScreenContentTools > 0 {
+ seqChooseIntegerMV := br.readBits(1)
+ if seqChooseIntegerMV == 0 {
+ br.readBits(1) // seq_force_integer_mv
+ }
+ }
+
+ if enableOrderHint == 1 {
+ br.readBits(3) // order_hint_bits_minus_1
+ }
+ }
+
+ br.readBits(1) // enable_superres
+ br.readBits(1) // enable_cdef
+ br.readBits(1) // enable_restoration
+
+ if br.err != nil {
+ return nil, fmt.Errorf("%w: %v", errAV1SeqHdrParse, br.err)
+ }
+
+ // color_config()
+ parseColorConfig(br, hdr)
+
+ if br.err != nil {
+ return nil, fmt.Errorf("%w: %v", errAV1SeqHdrParse, br.err)
+ }
+
+ return hdr, nil
+}
+
+// parseColorConfig parses the color_config() section of a sequence header.
+func parseColorConfig(br *av1BitReader, hdr *AV1SequenceHeader) {
+ highBitDepth := br.readBits(1)
+
+ if hdr.SeqProfile == 2 && highBitDepth == 1 {
+ twelveBit := br.readBits(1)
+ if twelveBit == 1 {
+ hdr.BitDepth = 12
+ } else {
+ hdr.BitDepth = 10
+ }
+ } else if highBitDepth == 1 {
+ hdr.BitDepth = 10
+ } else {
+ hdr.BitDepth = 8
+ }
+
+ var monoChrome uint
+ if hdr.SeqProfile == 1 {
+ monoChrome = 0
+ } else {
+ monoChrome = br.readBits(1)
+ }
+ hdr.Monochrome = monoChrome == 1
+
+ colorDescriptionPresent := br.readBits(1)
+ var colorPrimaries, transferCharacteristics, matrixCoefficients uint
+ if colorDescriptionPresent == 1 {
+ colorPrimaries = br.readBits(8)
+ transferCharacteristics = br.readBits(8)
+ matrixCoefficients = br.readBits(8)
+ } else {
+ colorPrimaries = 2 // CP_UNSPECIFIED
+ transferCharacteristics = 2 // TC_UNSPECIFIED
+ matrixCoefficients = 2 // MC_UNSPECIFIED
+ }
+
+ if monoChrome == 1 {
+ br.readBits(1) // color_range
+ hdr.ChromaSubsamplingX = 1
+ hdr.ChromaSubsamplingY = 1
+ return
+ }
+
+ // Check for sRGB/sYCC
+ if colorPrimaries == 1 && transferCharacteristics == 13 && matrixCoefficients == 0 {
+ // color_range is implicitly 1 (full)
+ hdr.ChromaSubsamplingX = 0
+ hdr.ChromaSubsamplingY = 0
+ return
+ }
+
+ br.readBits(1) // color_range
+
+ if hdr.SeqProfile == 0 {
+ hdr.ChromaSubsamplingX = 1
+ hdr.ChromaSubsamplingY = 1
+ } else if hdr.SeqProfile == 1 {
+ hdr.ChromaSubsamplingX = 0
+ hdr.ChromaSubsamplingY = 0
+ } else {
+ // Profile 2
+ if hdr.BitDepth == 12 {
+ sx := br.readBits(1)
+ hdr.ChromaSubsamplingX = int(sx)
+ if sx == 1 {
+ hdr.ChromaSubsamplingY = int(br.readBits(1))
+ } else {
+ hdr.ChromaSubsamplingY = 0
+ }
+ } else {
+ hdr.ChromaSubsamplingX = 1
+ hdr.ChromaSubsamplingY = 0
+ }
+ }
+
+ if hdr.ChromaSubsamplingX == 1 && hdr.ChromaSubsamplingY == 1 {
+ br.readBits(2) // chroma_sample_position
+ }
+
+ br.readBits(1) // separate_uv_delta_q
+}
+
+// IsAV1Keyframe scans OBU headers in a temporal unit for a key frame.
+// It checks OBU_FRAME (type 6) and OBU_FRAME_HEADER (type 3) for
+// frame_type == KEY_FRAME (0). Returns false for nil/empty input.
+func IsAV1Keyframe(temporalUnit []byte) bool {
+ if len(temporalUnit) == 0 {
+ return false
+ }
+
+ pos := 0
+ for pos < len(temporalUnit) {
+ h, err := parseOBUHeader(temporalUnit[pos:])
+ if err != nil {
+ return false
+ }
+
+ if h.obuType == OBUFrame || h.obuType == OBUFrameHeader {
+ // Payload starts after header bytes and optional LEB128 size
+ payloadStart := pos + h.totalSize - h.payloadSize
+ if payloadStart >= len(temporalUnit) || h.payloadSize < 1 {
+ return false
+ }
+ // frame_header_obu(): first bit is show_existing_frame
+ // If show_existing_frame==1, this is not a new frame.
+ // If show_existing_frame==0, next 2 bits are frame_type.
+ payload := temporalUnit[payloadStart:]
+ showExisting := (payload[0] >> 7) & 1
+ if showExisting == 1 {
+ return false
+ }
+ frameType := (payload[0] >> 5) & 0x03
+ return frameType == av1FrameKey
+ }
+
+ if h.hasSizeField {
+ pos += h.totalSize
+ } else {
+ // No size field — OBU extends to end
+ break
+ }
+ }
+ return false
+}
+
+// FindSequenceHeaderOBU scans a temporal unit for an OBU_SEQUENCE_HEADER
+// (type 1) and returns the complete OBU bytes (header + payload).
+// Returns nil if no sequence header is found.
+func FindSequenceHeaderOBU(temporalUnit []byte) []byte {
+ if len(temporalUnit) == 0 {
+ return nil
+ }
+
+ pos := 0
+ for pos < len(temporalUnit) {
+ h, err := parseOBUHeader(temporalUnit[pos:])
+ if err != nil {
+ return nil
+ }
+
+ if h.obuType == OBUSequenceHeader {
+ end := pos + h.totalSize
+ if end > len(temporalUnit) {
+ end = len(temporalUnit)
+ }
+ return temporalUnit[pos:end]
+ }
+
+ if h.hasSizeField {
+ pos += h.totalSize
+ } else {
+ break
+ }
+ }
+ return nil
+}
diff --git a/demux/av1_test.go b/demux/av1_test.go
new file mode 100644
index 0000000..c4652e6
--- /dev/null
+++ b/demux/av1_test.go
@@ -0,0 +1,862 @@
+package demux
+
+import (
+ "testing"
+)
+
+// Test data generated with ffmpeg using libsvtav1.
+// 1920x1080, profile 0, level 8, 8-bit, 4:2:0.
+var testSeqHdrOBU1080p = []byte{
+ 0x0a, 0x0b, // OBU header (type=1, has_size=1) + LEB128 size=11
+ 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71,
+ 0x2b, 0xe4, 0x01,
+}
+
+// 640x480, profile 0, level 4, 10-bit, 4:2:0.
+var testSeqHdrOBU480p10 = []byte{
+ 0x0a, 0x0b, // OBU header (type=1, has_size=1) + LEB128 size=11
+ 0x00, 0x00, 0x00, 0x24, 0xc4, 0xff, 0xdf, 0x12,
+ 0xbe, 0x50, 0x10,
+}
+
+// Temporal unit containing TD + seq header + key frame (1920x1080 8-bit).
+var testTemporalUnitKeyframe = []byte{
+ 0x12, 0x00, // OBU_TEMPORAL_DELIMITER, has_size=1, size=0
+ 0x0a, 0x0b, // OBU_SEQUENCE_HEADER, has_size=1, size=11
+ 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71,
+ 0x2b, 0xe4, 0x01,
+ 0x32, 0x2d, // OBU_FRAME, has_size=1, size=45
+ 0x10, 0x00, 0x8b, 0x00, 0x81, 0x45, 0x08, 0x20,
+ 0x40, 0x00, 0x00, 0x03, 0x24, 0xfe, 0x6a, 0x75,
+ 0xf5, 0xb5, 0x7d, 0xaa, 0x9a, 0x90, 0x9b, 0xe8,
+ 0x41, 0xd4, 0x1a, 0x98, 0x69, 0x56, 0xfa, 0xf0,
+ 0xa8, 0xe6, 0xda, 0x81, 0x0c, 0x1c, 0xe9, 0xc2,
+ 0xf3, 0x84, 0xbd, 0x86, 0xab,
+}
+
+// Temporal unit containing TD + inter frame (no sequence header).
+var testTemporalUnitInterFrame = []byte{
+ 0x12, 0x00, // OBU_TEMPORAL_DELIMITER, has_size=1, size=0
+ 0x32, 0x18, // OBU_FRAME, has_size=1, size=24
+ 0x30, 0x02, 0x04, 0x09, 0x24, 0x92, 0x22, 0x46,
+ 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0xbc,
+ 0x32, 0xc1, 0x64, 0xdc, 0x91, 0xe8, 0x51, 0xe8,
+}
+
+func TestParseAV1SequenceHeader1080p(t *testing.T) {
+ t.Parallel()
+
+ hdr, err := ParseAV1SequenceHeader(testSeqHdrOBU1080p)
+ if err != nil {
+ t.Fatalf("ParseAV1SequenceHeader error: %v", err)
+ }
+
+ if hdr.SeqProfile != 0 {
+ t.Errorf("SeqProfile: got %d, want 0", hdr.SeqProfile)
+ }
+ if hdr.SeqLevelIdx != 8 {
+ t.Errorf("SeqLevelIdx: got %d, want 8", hdr.SeqLevelIdx)
+ }
+ if hdr.SeqTier != 0 {
+ t.Errorf("SeqTier: got %d, want 0", hdr.SeqTier)
+ }
+ if hdr.BitDepth != 8 {
+ t.Errorf("BitDepth: got %d, want 8", hdr.BitDepth)
+ }
+ if hdr.Width != 1920 {
+ t.Errorf("Width: got %d, want 1920", hdr.Width)
+ }
+ if hdr.Height != 1080 {
+ t.Errorf("Height: got %d, want 1080", hdr.Height)
+ }
+ if hdr.ChromaSubsamplingX != 1 {
+ t.Errorf("ChromaSubsamplingX: got %d, want 1", hdr.ChromaSubsamplingX)
+ }
+ if hdr.ChromaSubsamplingY != 1 {
+ t.Errorf("ChromaSubsamplingY: got %d, want 1", hdr.ChromaSubsamplingY)
+ }
+ if hdr.Monochrome {
+ t.Error("Monochrome: got true, want false")
+ }
+}
+
+func TestParseAV1SequenceHeader480p10Bit(t *testing.T) {
+ t.Parallel()
+
+ hdr, err := ParseAV1SequenceHeader(testSeqHdrOBU480p10)
+ if err != nil {
+ t.Fatalf("ParseAV1SequenceHeader error: %v", err)
+ }
+
+ if hdr.SeqProfile != 0 {
+ t.Errorf("SeqProfile: got %d, want 0", hdr.SeqProfile)
+ }
+ if hdr.SeqLevelIdx != 4 {
+ t.Errorf("SeqLevelIdx: got %d, want 4", hdr.SeqLevelIdx)
+ }
+ if hdr.SeqTier != 0 {
+ t.Errorf("SeqTier: got %d, want 0", hdr.SeqTier)
+ }
+ if hdr.BitDepth != 10 {
+ t.Errorf("BitDepth: got %d, want 10", hdr.BitDepth)
+ }
+ if hdr.Width != 640 {
+ t.Errorf("Width: got %d, want 640", hdr.Width)
+ }
+ if hdr.Height != 480 {
+ t.Errorf("Height: got %d, want 480", hdr.Height)
+ }
+ if hdr.ChromaSubsamplingX != 1 {
+ t.Errorf("ChromaSubsamplingX: got %d, want 1", hdr.ChromaSubsamplingX)
+ }
+ if hdr.ChromaSubsamplingY != 1 {
+ t.Errorf("ChromaSubsamplingY: got %d, want 1", hdr.ChromaSubsamplingY)
+ }
+ if hdr.Monochrome {
+ t.Error("Monochrome: got true, want false")
+ }
+}
+
+func TestParseAV1SequenceHeaderNil(t *testing.T) {
+ t.Parallel()
+
+ _, err := ParseAV1SequenceHeader(nil)
+ if err == nil {
+ t.Error("expected error for nil input")
+ }
+}
+
+func TestParseAV1SequenceHeaderEmpty(t *testing.T) {
+ t.Parallel()
+
+ _, err := ParseAV1SequenceHeader([]byte{})
+ if err == nil {
+ t.Error("expected error for empty input")
+ }
+}
+
+func TestParseAV1SequenceHeaderTruncated(t *testing.T) {
+ t.Parallel()
+
+ // Just the OBU header, no payload
+ _, err := ParseAV1SequenceHeader([]byte{0x0a, 0x0b})
+ if err == nil {
+ t.Error("expected error for truncated input")
+ }
+
+ // Partial payload
+ _, err = ParseAV1SequenceHeader([]byte{0x0a, 0x0b, 0x00, 0x00})
+ if err == nil {
+ t.Error("expected error for truncated payload")
+ }
+}
+
+func TestParseAV1SequenceHeaderWrongOBUType(t *testing.T) {
+ t.Parallel()
+
+ // OBU_FRAME (type 6) instead of SEQUENCE_HEADER
+ _, err := ParseAV1SequenceHeader([]byte{0x32, 0x01, 0x00})
+ if err == nil {
+ t.Error("expected error for wrong OBU type")
+ }
+}
+
+func TestAV1CodecString1080p(t *testing.T) {
+ t.Parallel()
+
+ hdr, err := ParseAV1SequenceHeader(testSeqHdrOBU1080p)
+ if err != nil {
+ t.Fatalf("ParseAV1SequenceHeader error: %v", err)
+ }
+
+ want := "av01.0.08M.08"
+ got := hdr.CodecString()
+ if got != want {
+ t.Errorf("CodecString: got %q, want %q", got, want)
+ }
+}
+
+func TestAV1CodecString480p10Bit(t *testing.T) {
+ t.Parallel()
+
+ hdr, err := ParseAV1SequenceHeader(testSeqHdrOBU480p10)
+ if err != nil {
+ t.Fatalf("ParseAV1SequenceHeader error: %v", err)
+ }
+
+ want := "av01.0.04M.10"
+ got := hdr.CodecString()
+ if got != want {
+ t.Errorf("CodecString: got %q, want %q", got, want)
+ }
+}
+
+func TestAV1CodecStringHighTier(t *testing.T) {
+ t.Parallel()
+
+ hdr := &AV1SequenceHeader{
+ SeqProfile: 1,
+ SeqLevelIdx: 13,
+ SeqTier: 1,
+ BitDepth: 10,
+ }
+
+ want := "av01.1.13H.10"
+ got := hdr.CodecString()
+ if got != want {
+ t.Errorf("CodecString: got %q, want %q", got, want)
+ }
+}
+
+func TestAV1CodecStringZeroPadding(t *testing.T) {
+ t.Parallel()
+
+ hdr := &AV1SequenceHeader{
+ SeqProfile: 0,
+ SeqLevelIdx: 1,
+ SeqTier: 0,
+ BitDepth: 8,
+ }
+
+ want := "av01.0.01M.08"
+ got := hdr.CodecString()
+ if got != want {
+ t.Errorf("CodecString: got %q, want %q", got, want)
+ }
+}
+
+func TestIsAV1KeyframeTrue(t *testing.T) {
+ t.Parallel()
+
+ if !IsAV1Keyframe(testTemporalUnitKeyframe) {
+ t.Error("expected true for keyframe temporal unit")
+ }
+}
+
+func TestIsAV1KeyframeFalseInterFrame(t *testing.T) {
+ t.Parallel()
+
+ if IsAV1Keyframe(testTemporalUnitInterFrame) {
+ t.Error("expected false for inter frame temporal unit")
+ }
+}
+
+func TestIsAV1KeyframeFalseNil(t *testing.T) {
+ t.Parallel()
+
+ if IsAV1Keyframe(nil) {
+ t.Error("expected false for nil input")
+ }
+}
+
+func TestIsAV1KeyframeFalseEmpty(t *testing.T) {
+ t.Parallel()
+
+ if IsAV1Keyframe([]byte{}) {
+ t.Error("expected false for empty input")
+ }
+}
+
+func TestIsAV1KeyframeFalseTruncated(t *testing.T) {
+ t.Parallel()
+
+ // Just a temporal delimiter, no frame OBU
+ if IsAV1Keyframe([]byte{0x12, 0x00}) {
+ t.Error("expected false for temporal delimiter only")
+ }
+}
+
+func TestFindSequenceHeaderOBUPresent(t *testing.T) {
+ t.Parallel()
+
+ obu := FindSequenceHeaderOBU(testTemporalUnitKeyframe)
+ if obu == nil {
+ t.Fatal("expected non-nil OBU")
+ }
+
+ // Verify the found OBU is a valid sequence header
+ hdr, err := ParseAV1SequenceHeader(obu)
+ if err != nil {
+ t.Fatalf("ParseAV1SequenceHeader on found OBU: %v", err)
+ }
+ if hdr.Width != 1920 || hdr.Height != 1080 {
+ t.Errorf("resolution: got %dx%d, want 1920x1080", hdr.Width, hdr.Height)
+ }
+}
+
+func TestFindSequenceHeaderOBUAbsent(t *testing.T) {
+ t.Parallel()
+
+ obu := FindSequenceHeaderOBU(testTemporalUnitInterFrame)
+ if obu != nil {
+ t.Errorf("expected nil for inter frame, got %d bytes", len(obu))
+ }
+}
+
+func TestFindSequenceHeaderOBUNil(t *testing.T) {
+ t.Parallel()
+
+ obu := FindSequenceHeaderOBU(nil)
+ if obu != nil {
+ t.Error("expected nil for nil input")
+ }
+}
+
+func TestFindSequenceHeaderOBUEmpty(t *testing.T) {
+ t.Parallel()
+
+ obu := FindSequenceHeaderOBU([]byte{})
+ if obu != nil {
+ t.Error("expected nil for empty input")
+ }
+}
+
+func TestReadLEB128(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ data []byte
+ want int
+ wantN int
+ wantErr bool
+ }{
+ {
+ name: "zero",
+ data: []byte{0x00},
+ want: 0,
+ wantN: 1,
+ },
+ {
+ name: "one byte value 11",
+ data: []byte{0x0b},
+ want: 11,
+ wantN: 1,
+ },
+ {
+ name: "one byte value 127",
+ data: []byte{0x7f},
+ want: 127,
+ wantN: 1,
+ },
+ {
+ name: "two byte value 128",
+ data: []byte{0x80, 0x01},
+ want: 128,
+ wantN: 2,
+ },
+ {
+ name: "two byte value 300",
+ data: []byte{0xac, 0x02},
+ want: 300,
+ wantN: 2,
+ },
+ {
+ name: "empty",
+ data: []byte{},
+ wantErr: true,
+ },
+ {
+ name: "unterminated",
+ data: []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80},
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ val, n, err := readLEB128(tt.data)
+ if tt.wantErr {
+ if err == nil {
+ t.Error("expected error")
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if val != tt.want {
+ t.Errorf("value: got %d, want %d", val, tt.want)
+ }
+ if n != tt.wantN {
+ t.Errorf("bytes consumed: got %d, want %d", n, tt.wantN)
+ }
+ })
+ }
+}
+
+func TestParseOBUHeader(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ data []byte
+ wantType int
+ wantHasSize bool
+ wantHasExt bool
+ wantErr bool
+ }{
+ {
+ name: "sequence header with size",
+ data: []byte{0x0a, 0x0b, 0x00, 0x00},
+ wantType: OBUSequenceHeader,
+ wantHasSize: true,
+ },
+ {
+ name: "temporal delimiter with size",
+ data: []byte{0x12, 0x00},
+ wantType: OBUTemporalDelimiter,
+ wantHasSize: true,
+ },
+ {
+ name: "frame with size",
+ data: []byte{0x32, 0x2d, 0x10},
+ wantType: OBUFrame,
+ wantHasSize: true,
+ },
+ {
+ name: "forbidden bit set",
+ data: []byte{0x8a},
+ wantErr: true,
+ },
+ {
+ name: "empty",
+ data: []byte{},
+ wantErr: true,
+ },
+ {
+ name: "extension flag",
+ data: []byte{0x0e, 0x40, 0x0b, 0x00, 0x00},
+ wantType: OBUSequenceHeader,
+ wantHasExt: true,
+ wantHasSize: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ h, err := parseOBUHeader(tt.data)
+ if tt.wantErr {
+ if err == nil {
+ t.Error("expected error")
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if h.obuType != tt.wantType {
+ t.Errorf("obuType: got %d, want %d", h.obuType, tt.wantType)
+ }
+ if h.hasSizeField != tt.wantHasSize {
+ t.Errorf("hasSizeField: got %v, want %v", h.hasSizeField, tt.wantHasSize)
+ }
+ if h.hasExtension != tt.wantHasExt {
+ t.Errorf("hasExtension: got %v, want %v", h.hasExtension, tt.wantHasExt)
+ }
+ })
+ }
+}
+
+func TestAV1BitReaderOverflow(t *testing.T) {
+ t.Parallel()
+
+ br := newAV1BitReader([]byte{0xFF})
+ _ = br.readBits(8) // should succeed
+ _ = br.readBits(1) // should set error
+
+ if br.err == nil {
+ t.Error("expected error after reading past end")
+ }
+
+ // Subsequent reads should return 0 without panicking
+ val := br.readBits(8)
+ if val != 0 {
+ t.Errorf("expected 0 after error, got %d", val)
+ }
+}
+
+func TestAV1BitReaderReadBits(t *testing.T) {
+ t.Parallel()
+
+ // 0xA5 = 10100101
+ br := newAV1BitReader([]byte{0xA5})
+
+ got := br.readBits(3)
+ if got != 5 { // 101
+ t.Errorf("first 3 bits: got %d, want 5", got)
+ }
+
+ got = br.readBits(5)
+ if got != 5 { // 00101
+ t.Errorf("next 5 bits: got %d, want 5", got)
+ }
+
+ if br.err != nil {
+ t.Errorf("unexpected error: %v", br.err)
+ }
+}
+
+func TestAV1BitReaderUvlc(t *testing.T) {
+ t.Parallel()
+
+ // UVLC encoding of 0: just a single 1-bit → value 0
+ br := newAV1BitReader([]byte{0x80}) // 10000000
+ got := br.readUvlc()
+ if got != 0 {
+ t.Errorf("uvlc(0): got %d, want 0", got)
+ }
+
+ // UVLC encoding of 1: 0,1,0 → leadingZeros=1, val=0, result = 0 + (1<<1) - 1 = 1
+ br = newAV1BitReader([]byte{0x40}) // 01000000
+ got = br.readUvlc()
+ if got != 1 {
+ t.Errorf("uvlc(1): got %d, want 1", got)
+ }
+
+ // UVLC encoding of 2: 0,1,1 → leadingZeros=1, val=1, result = 1 + (1<<1) - 1 = 2
+ br = newAV1BitReader([]byte{0x60}) // 01100000
+ got = br.readUvlc()
+ if got != 2 {
+ t.Errorf("uvlc(2): got %d, want 2", got)
+ }
+}
+
+func TestReducedStillPictureHeader(t *testing.T) {
+ t.Parallel()
+
+ // Hand-craft a minimal sequence header with reduced_still_picture_header=1
+ // Profile 0, still_picture=1, reduced=1, level_idx=1 (5 bits)
+ // Then frame_width_bits_minus_1=3, frame_height_bits_minus_1=3 (4+4 bits)
+ // Then width=128-1=127 (4 bits), height=96-1=95 (4 bits)
+ // No frame_id_numbers_present (since reduced)
+ // use_128x128_superblock=0, enable_filter_intra=0, enable_intra_edge_filter=0
+ // No interintra/masked/etc (since reduced)
+ // enable_superres=0, enable_cdef=0, enable_restoration=0
+ // color_config: high_bitdepth=0 → 8-bit, mono_chrome=0
+ // color_description_present=0 → unspecified
+ // color_range=0
+ // chroma_sample_position=0 (2 bits, since profile 0 → 4:2:0)
+ // separate_uv_delta_q=0
+
+ // Bits: profile(3)=000, still(1)=1, reduced(1)=1
+ // level_idx(5)=00001
+ // fw_bits(4)=0011, fh_bits(4)=0011
+ // width(4)=0111 1111, height(4)=0101 1111
+ // Wait, width is fw_bits+1=4 bits, so 127 in 4 bits doesn't fit.
+ // Let's use smaller: width=16, height=16
+ // fw_bits_m1=3 → 4 bits for width, fh_bits_m1=3 → 4 bits for height
+ // width-1=15=0b1111, height-1=15=0b1111
+ // Then: use_128x128_sb=0, filter_intra=0, intra_edge_filter=0 (3 bits)
+ // enable_superres=0, enable_cdef=0, enable_restoration=0 (3 bits)
+ // high_bitdepth=0 (1 bit)
+ // mono_chrome=0 (1 bit, profile != 1)
+ // color_desc_present=0 (1 bit)
+ // color_range=0 (1 bit)
+ // chroma_sample_position=00 (2 bits)
+ // separate_uv_delta_q=0 (1 bit)
+
+ // Layout:
+ // byte 0: 000 1 1 000 = 0x18
+ // byte 1: 01 0011 00 = 0x4C (level_idx=00001, fw_bits=0011, start of fh_bits)
+ // byte 2: 11 1111 11 = 0xFF (fh_bits last 2 bits=11, width=1111, height first 2=11)
+ // byte 3: 11 000 000 = 0xC0 (height last 2=11, sb=0, filter_intra=0, intra_edge=0, superres=0, cdef=0)
+ // byte 4: 0 0 0 0 0 0 00 = 0x00 (restoration=0, high_bitdepth=0, mono=0, cdp=0, color_range=0, csp=00, sep_uv=0)
+
+ // Actually let me be more careful:
+ // Bit stream (MSB first):
+ // [0] profile: 000
+ // [3] still_picture: 1
+ // [4] reduced_still_picture_header: 1
+ // [5] seq_level_idx_0: 00001
+ // [10] frame_width_bits_minus_1: 0011
+ // [14] frame_height_bits_minus_1: 0011
+ // [18] max_frame_width_minus_1 (4 bits): 1111
+ // [22] max_frame_height_minus_1 (4 bits): 1111
+ // [26] use_128x128_superblock: 0
+ // [27] enable_filter_intra: 0
+ // [28] enable_intra_edge_filter: 0
+ // [29] enable_superres: 0
+ // [30] enable_cdef: 0
+ // [31] enable_restoration: 0
+ // [32] high_bitdepth: 0
+ // [33] mono_chrome: 0
+ // [34] color_description_present_flag: 0
+ // [35] color_range: 0
+ // [36] chroma_sample_position: 00
+ // [38] separate_uv_delta_q: 0
+
+ // Byte layout:
+ // bits 0-7: 000 1 1 000 = 0x18
+ // bits 8-15: 01 0011 00 = 0x4C
+ // bits 16-23: 11 1111 11 = 0xFF
+ // bits 24-31: 11 000000 = 0xC0
+ // bits 32-39: 00000 000 = 0x00
+
+ payload := []byte{0x18, 0x4C, 0xFF, 0xC0, 0x00}
+
+ // Wrap in OBU header: type=1, has_size=1
+ obu := make([]byte, 0, 2+len(payload))
+ obu = append(obu, 0x0a) // OBU header: type=1, has_size=1
+ obu = append(obu, byte(len(payload))) // LEB128 size
+ obu = append(obu, payload...)
+
+ hdr, err := ParseAV1SequenceHeader(obu)
+ if err != nil {
+ t.Fatalf("ParseAV1SequenceHeader error: %v", err)
+ }
+
+ if hdr.SeqProfile != 0 {
+ t.Errorf("SeqProfile: got %d, want 0", hdr.SeqProfile)
+ }
+ if hdr.SeqLevelIdx != 1 {
+ t.Errorf("SeqLevelIdx: got %d, want 1", hdr.SeqLevelIdx)
+ }
+ if hdr.SeqTier != 0 {
+ t.Errorf("SeqTier: got %d, want 0", hdr.SeqTier)
+ }
+ if hdr.Width != 16 {
+ t.Errorf("Width: got %d, want 16", hdr.Width)
+ }
+ if hdr.Height != 16 {
+ t.Errorf("Height: got %d, want 16", hdr.Height)
+ }
+ if hdr.BitDepth != 8 {
+ t.Errorf("BitDepth: got %d, want 8", hdr.BitDepth)
+ }
+}
+
+func TestParseSequenceHeaderWithDecoderModelInfo(t *testing.T) {
+ t.Parallel()
+
+ // Hand-craft a sequence header payload with decoder_model_info and
+ // operating_parameters_info to exercise the code path where
+ // decoder_model_present_for_this_op=1 requires reading additional fields.
+ //
+ // Profile 0, still_picture=0, reduced_still_picture_header=0
+ // timing_info_present=1, num_units_in_display_tick=1(32b), time_scale=30(32b)
+ // equal_picture_interval=0
+ // decoder_model_info_present_flag=1
+ // buffer_delay_length_minus_1=4 (5b, so delays are 5 bits each)
+ // num_units_in_decoding_tick=1 (32b)
+ // buffer_removal_time_length_minus_1=0 (5b)
+ // frame_presentation_time_length_minus_1=0 (5b)
+ // initial_display_delay_present_flag=0
+ // operating_points_cnt_minus_1=0 (5b = 00000, so 1 operating point)
+ // operating_point_idc[0]=0 (12b)
+ // seq_level_idx[0]=8 (5b=01000), tier not read (level<=7 check: 8>7 so tier read)
+ // seq_tier[0]=0 (1b)
+ // decoder_model_present_for_this_op[0]=1 (1b)
+ // decoder_buffer_delay[0]=0 (5b, since buffer_delay_length_minus_1=4)
+ // encoder_buffer_delay[0]=0 (5b)
+ // low_delay_mode_flag[0]=0 (1b)
+ // frame_width_bits_minus_1=3 (4b=0011), frame_height_bits_minus_1=3 (4b=0011)
+ // max_frame_width_minus_1=15 (4b=1111), max_frame_height_minus_1=15 (4b=1111)
+ // frame_id_numbers_present=0
+ // use_128x128_superblock=0, enable_filter_intra=0, enable_intra_edge_filter=0
+ // enable_interintra_compound=0, enable_masked_compound=0, enable_warped_motion=0
+ // enable_dual_filter=0, enable_order_hint=0
+ // seq_choose_screen_content_tools=1
+ // enable_superres=0, enable_cdef=0, enable_restoration=0
+ // color_config: high_bitdepth=0, mono_chrome=0, color_description_present=0
+ // color_range=0, chroma_sample_position=00, separate_uv_delta_q=0
+
+ // Bit-by-bit layout:
+ // [0-2] seq_profile: 000
+ // [3] still_picture: 0
+ // [4] reduced_still_picture_header: 0
+ // [5] timing_info_present_flag: 1
+ // [6-37] num_units_in_display_tick: 00000000 00000000 00000000 00000001
+ // [38-69] time_scale: 00000000 00000000 00000000 00011110
+ // [70] equal_picture_interval: 0
+ // [71] decoder_model_info_present_flag: 1
+ // [72-76] buffer_delay_length_minus_1: 00100 (=4, so delays are 5 bits)
+ // [77-108] num_units_in_decoding_tick: 00000000 00000000 00000000 00000001
+ // [109-113] buffer_removal_time_length_minus_1: 00000
+ // [114-118] frame_presentation_time_length_minus_1: 00000
+ // [119] initial_display_delay_present_flag: 0
+ // [120-124] operating_points_cnt_minus_1: 00000
+ // [125-136] operating_point_idc[0]: 000000000000
+ // [137-141] seq_level_idx[0]: 01000 (=8)
+ // [142] seq_tier[0]: 0 (level>7 so tier is read)
+ // [143] decoder_model_present_for_this_op[0]: 1
+ // [144-148] decoder_buffer_delay[0]: 00000
+ // [149-153] encoder_buffer_delay[0]: 00000
+ // [154] low_delay_mode_flag[0]: 0
+ // [155-158] frame_width_bits_minus_1: 0011
+ // [159-162] frame_height_bits_minus_1: 0011
+ // [163-166] max_frame_width_minus_1: 1111
+ // [167-170] max_frame_height_minus_1: 1111
+ // [171] frame_id_numbers_present_flag: 0
+ // [172] use_128x128_superblock: 0
+ // [173] enable_filter_intra: 0
+ // [174] enable_intra_edge_filter: 0
+ // [175] enable_interintra_compound: 0
+ // [176] enable_masked_compound: 0
+ // [177] enable_warped_motion: 0
+ // [178] enable_dual_filter: 0
+ // [179] enable_order_hint: 0
+ // [180] seq_choose_screen_content_tools: 1
+ // [181] enable_superres: 0
+ // [182] enable_cdef: 0
+ // [183] enable_restoration: 0
+ // [184] high_bitdepth: 0
+ // [185] mono_chrome: 0
+ // [186] color_description_present_flag: 0
+ // [187] color_range: 0
+ // [188-189] chroma_sample_position: 00
+ // [190] separate_uv_delta_q: 0
+
+ // Convert bit positions to bytes:
+ // byte 0 [0-7]: 00000 1 00 = 0x04
+ // byte 1 [8-15]: 00000000
+ // byte 2 [16-23]: 00000000
+ // byte 3 [24-31]: 00000000
+ // byte 4 [32-39]: 00000001
+ // byte 5 [40-47]: 00000000
+ // byte 6 [48-55]: 00000000
+ // byte 7 [56-63]: 00000000
+ // byte 8 [64-71]: 00011110 = 0x1E (bits 64-69 = time_scale last, 70=equal_pic=0, 71=dmi=1)
+ // Wait, let me redo this more carefully.
+
+ // Let me just build it with a helper to be precise.
+ bits := []uint{} // each element is 0 or 1
+
+ appendBits := func(val uint, n int) {
+ for i := n - 1; i >= 0; i-- {
+ bits = append(bits, (val>>uint(i))&1)
+ }
+ }
+
+ appendBits(0, 3) // seq_profile = 0
+ appendBits(0, 1) // still_picture = 0
+ appendBits(0, 1) // reduced_still_picture_header = 0
+ appendBits(1, 1) // timing_info_present_flag = 1
+ appendBits(1, 32) // num_units_in_display_tick = 1
+ appendBits(30, 32) // time_scale = 30
+ appendBits(0, 1) // equal_picture_interval = 0
+ appendBits(1, 1) // decoder_model_info_present_flag = 1
+ appendBits(4, 5) // buffer_delay_length_minus_1 = 4
+ appendBits(1, 32) // num_units_in_decoding_tick = 1
+ appendBits(0, 5) // buffer_removal_time_length_minus_1 = 0
+ appendBits(0, 5) // frame_presentation_time_length_minus_1 = 0
+ appendBits(0, 1) // initial_display_delay_present_flag = 0
+ appendBits(0, 5) // operating_points_cnt_minus_1 = 0
+ appendBits(0, 12) // operating_point_idc[0] = 0
+ appendBits(8, 5) // seq_level_idx[0] = 8
+ appendBits(0, 1) // seq_tier[0] = 0 (read because level > 7)
+ appendBits(1, 1) // decoder_model_present_for_this_op[0] = 1
+ appendBits(0, 5) // decoder_buffer_delay[0] = 0 (5 bits)
+ appendBits(0, 5) // encoder_buffer_delay[0] = 0 (5 bits)
+ appendBits(0, 1) // low_delay_mode_flag[0] = 0
+ appendBits(3, 4) // frame_width_bits_minus_1 = 3
+ appendBits(3, 4) // frame_height_bits_minus_1 = 3
+ appendBits(15, 4) // max_frame_width_minus_1 = 15 (width=16)
+ appendBits(15, 4) // max_frame_height_minus_1 = 15 (height=16)
+ appendBits(0, 1) // frame_id_numbers_present_flag = 0
+ appendBits(0, 1) // use_128x128_superblock = 0
+ appendBits(0, 1) // enable_filter_intra = 0
+ appendBits(0, 1) // enable_intra_edge_filter = 0
+ appendBits(0, 1) // enable_interintra_compound = 0
+ appendBits(0, 1) // enable_masked_compound = 0
+ appendBits(0, 1) // enable_warped_motion = 0
+ appendBits(0, 1) // enable_dual_filter = 0
+ appendBits(0, 1) // enable_order_hint = 0
+ appendBits(1, 1) // seq_choose_screen_content_tools = 1 (SELECT)
+ // seqForceScreenContentTools = 2 (SELECT), which is > 0
+ appendBits(1, 1) // seq_choose_integer_mv = 1 (SELECT)
+ appendBits(0, 1) // enable_superres = 0
+ appendBits(0, 1) // enable_cdef = 0
+ appendBits(0, 1) // enable_restoration = 0
+ appendBits(0, 1) // high_bitdepth = 0 → 8-bit
+ appendBits(0, 1) // mono_chrome = 0
+ appendBits(0, 1) // color_description_present_flag = 0
+ appendBits(0, 1) // color_range = 0
+ appendBits(0, 2) // chroma_sample_position = 0
+ appendBits(0, 1) // separate_uv_delta_q = 0
+
+ // Pack bits into bytes
+ payload := make([]byte, (len(bits)+7)/8)
+ for i, b := range bits {
+ if b == 1 {
+ payload[i/8] |= 1 << (7 - uint(i%8))
+ }
+ }
+
+ // Wrap in OBU header
+ obu := make([]byte, 0, 2+len(payload))
+ obu = append(obu, 0x0a) // OBU header: type=1, has_size=1
+ obu = append(obu, byte(len(payload))) // LEB128 size
+ obu = append(obu, payload...)
+
+ hdr, err := ParseAV1SequenceHeader(obu)
+ if err != nil {
+ t.Fatalf("ParseAV1SequenceHeader error: %v", err)
+ }
+
+ if hdr.SeqProfile != 0 {
+ t.Errorf("SeqProfile: got %d, want 0", hdr.SeqProfile)
+ }
+ if hdr.SeqLevelIdx != 8 {
+ t.Errorf("SeqLevelIdx: got %d, want 8", hdr.SeqLevelIdx)
+ }
+ if hdr.SeqTier != 0 {
+ t.Errorf("SeqTier: got %d, want 0", hdr.SeqTier)
+ }
+ if hdr.Width != 16 {
+ t.Errorf("Width: got %d, want 16", hdr.Width)
+ }
+ if hdr.Height != 16 {
+ t.Errorf("Height: got %d, want 16", hdr.Height)
+ }
+ if hdr.BitDepth != 8 {
+ t.Errorf("BitDepth: got %d, want 8", hdr.BitDepth)
+ }
+}
+
+func TestIsAV1KeyframeFrameHeaderOBU(t *testing.T) {
+ t.Parallel()
+
+ // Construct a temporal unit with OBU_FRAME_HEADER (type 3) instead of OBU_FRAME
+ // OBU_FRAME_HEADER: type=3, has_size=1 → (3<<3)|2 = 0x1A
+ // Payload: show_existing=0, frame_type=0 (KEY) → first byte = 0b0_00_xxxxx = 0x00
+ tu := []byte{
+ 0x12, 0x00, // TD
+ 0x1a, 0x02, // OBU_FRAME_HEADER, size=2
+ 0x00, 0x00, // show_existing=0, frame_type=0 (KEY)
+ }
+
+ if !IsAV1Keyframe(tu) {
+ t.Error("expected true for OBU_FRAME_HEADER with KEY_FRAME type")
+ }
+}
+
+func TestIsAV1KeyframeShowExistingFrame(t *testing.T) {
+ t.Parallel()
+
+ // show_existing_frame = 1 → first payload byte: 1_xx_xxxxx = 0x80
+ tu := []byte{
+ 0x12, 0x00, // TD
+ 0x32, 0x02, // OBU_FRAME, size=2
+ 0x80, 0x00, // show_existing_frame=1
+ }
+
+ if IsAV1Keyframe(tu) {
+ t.Error("expected false for show_existing_frame=1")
+ }
+}
+
+func TestFindSequenceHeaderOBUTruncated(t *testing.T) {
+ t.Parallel()
+
+ // Sequence header OBU header but truncated payload
+ data := []byte{0x0a, 0x80} // has_size with unterminated LEB128
+ obu := FindSequenceHeaderOBU(data)
+ if obu != nil {
+ t.Error("expected nil for truncated OBU")
+ }
+}
diff --git a/distribution/moq_writer.go b/distribution/moq_writer.go
index 71cc942..b176f07 100644
--- a/distribution/moq_writer.go
+++ b/distribution/moq_writer.go
@@ -88,12 +88,19 @@ func (m *moqWriter) WriteVideoFrame(w io.Writer, frame *media.VideoFrame) (int64
}
// Video Config on keyframes (ID 13, odd → length-prefixed bytes)
- if frame.IsKeyframe && frame.SPS != nil && frame.PPS != nil {
+ if frame.IsKeyframe && frame.SPS != nil {
var configData []byte
- if frame.Codec == "h265" && frame.VPS != nil {
- configData = moq.BuildHEVCDecoderConfig(frame.VPS, frame.SPS, frame.PPS)
- } else {
- configData = moq.BuildAVCDecoderConfig(frame.SPS, frame.PPS)
+ switch frame.Codec {
+ case "av1":
+ configData = moq.BuildAV1DecoderConfig(frame.SPS)
+ case "h265":
+ if frame.VPS != nil && frame.PPS != nil {
+ configData = moq.BuildHEVCDecoderConfig(frame.VPS, frame.SPS, frame.PPS)
+ }
+ default:
+ if frame.PPS != nil {
+ configData = moq.BuildAVCDecoderConfig(frame.SPS, frame.PPS)
+ }
}
if configData != nil {
exts = quicvarint.Append(exts, locExtVideoConfig)
diff --git a/distribution/moq_writer_test.go b/distribution/moq_writer_test.go
index 19ba36e..f4acace 100644
--- a/distribution/moq_writer_test.go
+++ b/distribution/moq_writer_test.go
@@ -488,6 +488,142 @@ func TestMoQWriterBytesWritten(t *testing.T) {
}
}
+func TestMoQWriterVideoFrameAV1Keyframe(t *testing.T) {
+ t.Parallel()
+ w := NewMoQWriter(1, 0)
+ var buf bytes.Buffer
+
+ if err := w.WriteStreamHeader(&buf, TrackIDVideo, 1, 0); err != nil {
+ t.Fatalf("WriteStreamHeader failed: %v", err)
+ }
+ buf.Reset()
+
+ // AV1 sequence header OBU for 1920x1080 8-bit (from demux/av1_test.go)
+ seqHdrOBU := []byte{
+ 0x0a, 0x0b,
+ 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71,
+ 0x2b, 0xe4, 0x01,
+ }
+
+ // Raw OBU payload representing the frame data
+ wireData := []byte{0x32, 0x10, 0xAA, 0xBB, 0xCC, 0xDD}
+
+ frame := &media.VideoFrame{
+ PTS: 50000,
+ IsKeyframe: true,
+ Codec: "av1",
+ SPS: seqHdrOBU,
+ WireData: wireData,
+ }
+
+ n, err := w.WriteVideoFrame(&buf, frame)
+ if err != nil {
+ t.Fatalf("WriteVideoFrame failed: %v", err)
+ }
+
+ if n != int64(buf.Len()) {
+ t.Errorf("bytes written: got %d, actual buffer %d", n, buf.Len())
+ }
+
+ data := buf.Bytes()
+ pos := 0
+
+ // Object ID
+ objectID, nn, err := quicvarint.Parse(data[pos:])
+ if err != nil {
+ t.Fatalf("parse object ID: %v", err)
+ }
+ if objectID != 0 {
+ t.Errorf("object ID: got %d, want 0", objectID)
+ }
+ pos += nn
+
+ // Extension headers length
+ extLen, nn, err := quicvarint.Parse(data[pos:])
+ if err != nil {
+ t.Fatalf("parse ext length: %v", err)
+ }
+ pos += nn
+
+ extEnd := pos + int(extLen)
+ foundTimestamp := false
+ foundMarking := false
+ foundConfig := false
+ var configData []byte
+
+ for pos < extEnd {
+ extID, nn, err := quicvarint.Parse(data[pos:])
+ if err != nil {
+ t.Fatalf("parse ext ID: %v", err)
+ }
+ pos += nn
+
+ if extID%2 == 0 {
+ val, nn, err := quicvarint.Parse(data[pos:])
+ if err != nil {
+ t.Fatalf("parse ext value: %v", err)
+ }
+ pos += nn
+
+ switch extID {
+ case locExtCaptureTimestamp:
+ foundTimestamp = true
+ if val != 50000 {
+ t.Errorf("capture timestamp: got %d, want 50000", val)
+ }
+ case locExtVideoFrameMarking:
+ foundMarking = true
+ if val != vfmKeyframe {
+ t.Errorf("video frame marking: got 0x%x, want 0x%x", val, vfmKeyframe)
+ }
+ }
+ } else {
+ valLen, nn, err := quicvarint.Parse(data[pos:])
+ if err != nil {
+ t.Fatalf("parse ext value length: %v", err)
+ }
+ pos += nn
+
+ if extID == locExtVideoConfig {
+ foundConfig = true
+ configData = data[pos : pos+int(valLen)]
+ }
+ pos += int(valLen)
+ }
+ }
+
+ if !foundTimestamp {
+ t.Error("missing capture timestamp extension")
+ }
+ if !foundMarking {
+ t.Error("missing video frame marking extension")
+ }
+ if !foundConfig {
+ t.Error("missing video config extension")
+ }
+
+ // AV1CodecConfigurationRecord starts with marker|version = 0x81
+ if len(configData) < 4 {
+ t.Fatalf("config data too short: %d bytes", len(configData))
+ }
+ if configData[0] != 0x81 {
+ t.Errorf("AV1 config marker|version: got 0x%02x, want 0x81", configData[0])
+ }
+
+ // Payload length
+ payloadLen, nn, err := quicvarint.Parse(data[pos:])
+ if err != nil {
+ t.Fatalf("parse payload length: %v", err)
+ }
+ pos += nn
+
+ // Payload should be the raw WireData (pass-through, no AnnexB conversion)
+ payload := data[pos : pos+int(payloadLen)]
+ if !bytes.Equal(payload, wireData) {
+ t.Errorf("payload mismatch: got %x, want %x", payload, wireData)
+ }
+}
+
func TestMoQWriterStreamHeaderSize(t *testing.T) {
t.Parallel()
w := NewMoQWriter(1, 0)
diff --git a/distribution/server.go b/distribution/server.go
index 6c59cba..442a882 100644
--- a/distribution/server.go
+++ b/distribution/server.go
@@ -104,6 +104,23 @@ type SRTPullInfo struct {
StreamID string `json:"streamId,omitempty"`
}
+// DASHPullFunc initiates a DASH pull from a remote MPD URL.
+type DASHPullFunc func(url, streamKey, videoRepID, audioRepID string) error
+
+// DASHStopFunc stops an active DASH pull by stream key.
+type DASHStopFunc func(streamKey string) error
+
+// DASHListFunc returns all active DASH pulls.
+type DASHListFunc func() []DASHPullInfo
+
+// DASHPullInfo describes an active DASH pull.
+type DASHPullInfo struct {
+ URL string `json:"url"`
+ StreamKey string `json:"streamKey"`
+ VideoRepID string `json:"videoRepresentationId,omitempty"`
+ AudioRepID string `json:"audioRepresentationId,omitempty"`
+}
+
// WebTransport session close error codes sent to clients via CloseWithError.
const (
wtErrStreamNotFound webtransport.SessionErrorCode = 1
@@ -131,6 +148,9 @@ type ServerConfig struct {
SRTPull SRTPullFunc
SRTStop SRTStopFunc
SRTList SRTListFunc
+ DASHPull DASHPullFunc
+ DASHStop DASHStopFunc
+ DASHList DASHListFunc
ExtraRoutes func(mux *http.ServeMux)
// OnStreamRegistered is called after a new stream relay is created
@@ -274,6 +294,10 @@ func (s *Server) registerAPIRoutes(mux *http.ServeMux) {
mux.HandleFunc("POST /api/srt-pull", s.handleSRTPullCreate)
mux.HandleFunc("DELETE /api/srt-pull", s.handleSRTPullStop)
mux.HandleFunc("OPTIONS /api/srt-pull", s.handleSRTPullOptions)
+ mux.HandleFunc("GET /api/dash-pull", s.handleDASHPullList)
+ mux.HandleFunc("POST /api/dash-pull", s.handleDASHPullCreate)
+ mux.HandleFunc("DELETE /api/dash-pull", s.handleDASHPullStop)
+ mux.HandleFunc("OPTIONS /api/dash-pull", s.handleDASHPullOptions)
if s.config.ExtraRoutes != nil {
s.config.ExtraRoutes(mux)
diff --git a/distribution/server_dash_handlers.go b/distribution/server_dash_handlers.go
new file mode 100644
index 0000000..2c7edd1
--- /dev/null
+++ b/distribution/server_dash_handlers.go
@@ -0,0 +1,63 @@
+package distribution
+
+import (
+ "encoding/json"
+ "net/http"
+)
+
+func (s *Server) handleDASHPullList(w http.ResponseWriter, _ *http.Request) {
+ if s.config.DASHList == nil {
+ writeJSON(w, http.StatusOK, []DASHPullInfo{})
+ return
+ }
+ writeJSON(w, http.StatusOK, s.config.DASHList())
+}
+
+func (s *Server) handleDASHPullCreate(w http.ResponseWriter, r *http.Request) {
+ if s.config.DASHPull == nil {
+ writeError(w, http.StatusNotImplemented, "DASH pull not configured")
+ return
+ }
+ var req struct {
+ URL string `json:"url"`
+ StreamKey string `json:"streamKey"`
+ VideoRepID string `json:"videoRepresentationId,omitempty"`
+ AudioRepID string `json:"audioRepresentationId,omitempty"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ writeError(w, http.StatusBadRequest, err.Error())
+ return
+ }
+ if req.URL == "" || req.StreamKey == "" {
+ writeError(w, http.StatusBadRequest, "url and streamKey are required")
+ return
+ }
+ if err := s.config.DASHPull(req.URL, req.StreamKey, req.VideoRepID, req.AudioRepID); err != nil {
+ writeError(w, http.StatusConflict, err.Error())
+ return
+ }
+ writeJSON(w, http.StatusCreated, map[string]string{"status": "pulling", "streamKey": req.StreamKey})
+}
+
+func (s *Server) handleDASHPullStop(w http.ResponseWriter, r *http.Request) {
+ if s.config.DASHStop == nil {
+ writeError(w, http.StatusNotImplemented, "DASH pull not configured")
+ return
+ }
+ streamKey := r.URL.Query().Get("streamKey")
+ if streamKey == "" {
+ writeError(w, http.StatusBadRequest, "streamKey query parameter required")
+ return
+ }
+ if err := s.config.DASHStop(streamKey); err != nil {
+ writeError(w, http.StatusNotFound, err.Error())
+ return
+ }
+ writeJSON(w, http.StatusOK, map[string]string{"status": "stopped", "streamKey": streamKey})
+}
+
+func (s *Server) handleDASHPullOptions(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
+ w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
+ w.WriteHeader(http.StatusNoContent)
+}
diff --git a/distribution/server_test.go b/distribution/server_test.go
index b5469d7..8ee4f30 100644
--- a/distribution/server_test.go
+++ b/distribution/server_test.go
@@ -409,3 +409,58 @@ func TestStreamLifecycleCallbacks(t *testing.T) {
// If we get here without panicking, the test passes.
})
}
+
+func TestHandleDASHPullList(t *testing.T) {
+ t.Parallel()
+
+ srv := newTestServer(t)
+ handler := srv.APIHandler()
+
+ req := httptest.NewRequest("GET", "/api/dash-pull", nil)
+ rec := httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
+ }
+
+ // DASHList is nil, should return empty array.
+ body := strings.TrimSpace(rec.Body.String())
+ if body != "[]" {
+ t.Fatalf("body = %q, want %q", body, "[]")
+ }
+}
+
+func TestHandleDASHPullCreateMissingFields(t *testing.T) {
+ t.Parallel()
+
+ srv := newTestServer(t)
+ srv.config.DASHPull = func(_, _, _, _ string) error { return nil }
+ handler := srv.APIHandler()
+
+ req := httptest.NewRequest("POST", "/api/dash-pull", strings.NewReader(`{"url":""}`))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusBadRequest {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
+ }
+}
+
+func TestHandleDASHPullNotConfigured(t *testing.T) {
+ t.Parallel()
+
+ srv := newTestServer(t)
+ // DASHPull is nil.
+ handler := srv.APIHandler()
+
+ req := httptest.NewRequest("POST", "/api/dash-pull", strings.NewReader(`{"url":"https://example.com/live.mpd","streamKey":"test"}`))
+ req.Header.Set("Content-Type", "application/json")
+ rec := httptest.NewRecorder()
+ handler.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusNotImplemented {
+ t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotImplemented)
+ }
+}
diff --git a/go.mod b/go.mod
index 784de9d..0731c5e 100644
--- a/go.mod
+++ b/go.mod
@@ -3,8 +3,10 @@ module github.com/zsiec/prism
go 1.24.3
require (
+ github.com/Eyevinn/mp4ff v0.51.0
github.com/quic-go/quic-go v0.59.0
github.com/quic-go/webtransport-go v0.10.0
+ github.com/unki2aut/go-mpd v0.0.0-20250610073145-8336a8d84ee7
github.com/zsiec/ccx v0.2.0
github.com/zsiec/srtgo v0.2.4
golang.org/x/sync v0.17.0
@@ -13,6 +15,7 @@ require (
require (
github.com/dunglas/httpsfv v1.1.0 // indirect
github.com/quic-go/qpack v0.6.0 // indirect
+ github.com/unki2aut/go-xsd-types v0.0.0-20200220223938-30e5405398f8 // indirect
golang.org/x/crypto v0.42.0 // indirect
golang.org/x/net v0.44.0 // indirect
golang.org/x/sys v0.36.0 // indirect
diff --git a/go.sum b/go.sum
index 43ddee4..851b3e5 100644
--- a/go.sum
+++ b/go.sum
@@ -1,7 +1,15 @@
+github.com/Eyevinn/mp4ff v0.51.0 h1:ZYdHFXEcB3kJkCeCHMHl/tbCm64FJsD2XOU0Sj+ME2M=
+github.com/Eyevinn/mp4ff v0.51.0/go.mod h1:hJNUUqOBryLAzUW9wpCJyw2HaI+TCd2rUPhafoS5lgg=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dunglas/httpsfv v1.1.0 h1:Jw76nAyKWKZKFrpMMcL76y35tOpYHqQPzHQiwDvpe54=
github.com/dunglas/httpsfv v1.1.0/go.mod h1:zID2mqw9mFsnt7YC3vYQ9/cjq30q41W+1AnDwH8TiMg=
+github.com/go-test/deep v1.1.0 h1:WOcxcdHcvdgThNXjw0t76K42FXTU7HpNQWHpA2HHNlg=
+github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
+github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
+github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
+github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
@@ -10,8 +18,14 @@ github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SA
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
github.com/quic-go/webtransport-go v0.10.0 h1:LqXXPOXuETY5Xe8ITdGisBzTYmUOy5eSj+9n4hLTjHI=
github.com/quic-go/webtransport-go v0.10.0/go.mod h1:LeGIXr5BQKE3UsynwVBeQrU1TPrbh73MGoC6jd+V7ow=
+github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
+github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
+github.com/unki2aut/go-mpd v0.0.0-20250610073145-8336a8d84ee7 h1:FsfEp+/I/2p/WObIlPlB10kITDkOjzi/IDqM9GYpYJc=
+github.com/unki2aut/go-mpd v0.0.0-20250610073145-8336a8d84ee7/go.mod h1:LITqXLCxxmcoHtOMgZh5NbcfS4RCrrADQXPVkYwF/cc=
+github.com/unki2aut/go-xsd-types v0.0.0-20200220223938-30e5405398f8 h1:u0Bi6Mf8BKPQnxGJ7QubdMyhb0SJjnQU7kX0BA9eASk=
+github.com/unki2aut/go-xsd-types v0.0.0-20200220223938-30e5405398f8/go.mod h1:uIeMfpmWIZ8SGp+fTfwDBWiiRn3aJm4b7rFSro9s++Q=
github.com/zsiec/ccx v0.2.0 h1:RaCC4a0ng9wa6AUvT9Kg9kfiEy9svfXL5mmRbi6Ykmo=
github.com/zsiec/ccx v0.2.0/go.mod h1:Y30W1TCZX7HAXM0miCzG18kSyKtJvqbOk7zrFcTGpU4=
github.com/zsiec/srtgo v0.2.4 h1:WzQfUMSiQglWJDilcXgFvW/23IGLr7FXNRdrbNKJyS8=
@@ -28,5 +42,7 @@ golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/ingest/dash/mpd.go b/ingest/dash/mpd.go
new file mode 100644
index 0000000..4316699
--- /dev/null
+++ b/ingest/dash/mpd.go
@@ -0,0 +1,262 @@
+package dash
+
+import (
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/unki2aut/go-mpd"
+)
+
+type mpdInfo struct {
+ IsDynamic bool
+ AvailabilityStartTime time.Time
+ MinUpdatePeriod time.Duration
+ BaseURL string
+ VideoAdaptations []adaptationInfo
+ AudioAdaptations []adaptationInfo
+}
+
+type adaptationInfo struct {
+ MimeType string
+ Representations []representationInfo
+}
+
+type representationInfo struct {
+ ID string
+ Bandwidth int
+ Width int
+ Height int
+ Codecs string
+ SegmentTemplate segmentTemplate
+}
+
+type segmentTemplate struct {
+ InitializationPattern string
+ MediaPattern string
+ StartNumber int
+ Timescale int
+ Duration int
+}
+
+// parseMPD decodes MPD XML and extracts video/audio adaptation sets with
+// their representations and segment templates.
+func parseMPD(data []byte) (*mpdInfo, error) {
+ var m mpd.MPD
+ if err := m.Decode(data); err != nil {
+ return nil, fmt.Errorf("decode MPD: %w", err)
+ }
+
+ info := &mpdInfo{}
+
+ // Type
+ if m.Type != nil && *m.Type == "dynamic" {
+ info.IsDynamic = true
+ }
+
+ // AvailabilityStartTime
+ if m.AvailabilityStartTime != nil {
+ info.AvailabilityStartTime = time.Time(*m.AvailabilityStartTime)
+ }
+
+ // MinimumUpdatePeriod
+ if m.MinimumUpdatePeriod != nil {
+ ns, err := m.MinimumUpdatePeriod.ToNanoseconds()
+ if err == nil {
+ info.MinUpdatePeriod = time.Duration(ns)
+ }
+ }
+
+ // BaseURL (take first if present)
+ if len(m.BaseURL) > 0 {
+ info.BaseURL = m.BaseURL[0].Value
+ }
+
+ // Walk periods and adaptation sets
+ for _, period := range m.Period {
+ if period == nil {
+ continue
+ }
+ for _, as := range period.AdaptationSets {
+ if as == nil {
+ continue
+ }
+ // Determine content type from mimeType, contentType, or Representation mimeType
+ contentType := classifyAdaptationSet(as)
+ ai := adaptationInfo{
+ MimeType: contentType,
+ }
+ for _, rep := range as.Representations {
+ ri := representationInfo{}
+ if rep.ID != nil {
+ ri.ID = *rep.ID
+ }
+ if rep.Bandwidth != nil {
+ ri.Bandwidth = int(*rep.Bandwidth)
+ }
+ if rep.Width != nil {
+ ri.Width = int(*rep.Width)
+ }
+ if rep.Height != nil {
+ ri.Height = int(*rep.Height)
+ }
+ if rep.Codecs != nil {
+ ri.Codecs = *rep.Codecs
+ } else if as.Codecs != nil {
+ ri.Codecs = *as.Codecs
+ }
+
+ // SegmentTemplate: check representation level first, then adaptation set level
+ st := rep.SegmentTemplate
+ if st == nil {
+ st = as.SegmentTemplate
+ }
+ if st != nil {
+ ri.SegmentTemplate = extractSegmentTemplate(st)
+ }
+
+ ai.Representations = append(ai.Representations, ri)
+ }
+
+ if strings.HasPrefix(contentType, "video") {
+ info.VideoAdaptations = append(info.VideoAdaptations, ai)
+ } else if strings.HasPrefix(contentType, "audio") {
+ info.AudioAdaptations = append(info.AudioAdaptations, ai)
+ }
+ }
+ }
+
+ return info, nil
+}
+
+func extractSegmentTemplate(st *mpd.SegmentTemplate) segmentTemplate {
+ t := segmentTemplate{}
+ if st.Initialization != nil {
+ t.InitializationPattern = *st.Initialization
+ }
+ if st.Media != nil {
+ t.MediaPattern = *st.Media
+ }
+ if st.StartNumber != nil {
+ t.StartNumber = int(*st.StartNumber)
+ }
+ if st.Timescale != nil {
+ t.Timescale = int(*st.Timescale)
+ }
+ if st.Duration != nil {
+ t.Duration = int(*st.Duration)
+ }
+ return t
+}
+
+// selectRepresentation picks a representation from the given adaptation sets.
+// If repID is empty, the highest bandwidth representation is selected.
+// If repID is given, an exact match is required.
+func selectRepresentation(adaptations []adaptationInfo, repID string) (representationInfo, segmentTemplate, error) {
+ if len(adaptations) == 0 {
+ return representationInfo{}, segmentTemplate{}, fmt.Errorf("no adaptation sets available")
+ }
+
+ if repID == "" {
+ // Select highest bandwidth across all adaptations
+ var best representationInfo
+ found := false
+ for _, a := range adaptations {
+ for _, r := range a.Representations {
+ if !found || r.Bandwidth > best.Bandwidth {
+ best = r
+ found = true
+ }
+ }
+ }
+ if !found {
+ return representationInfo{}, segmentTemplate{}, fmt.Errorf("no representations available")
+ }
+ return best, best.SegmentTemplate, nil
+ }
+
+ // Exact match
+ for _, a := range adaptations {
+ for _, r := range a.Representations {
+ if r.ID == repID {
+ return r, r.SegmentTemplate, nil
+ }
+ }
+ }
+ return representationInfo{}, segmentTemplate{}, fmt.Errorf("representation %q not found", repID)
+}
+
+// resolveInitURL replaces $RepresentationID$ in the initialization pattern
+// and prepends the base URL.
+func resolveInitURL(tmpl segmentTemplate, repID, baseURL string) string {
+ s := strings.ReplaceAll(tmpl.InitializationPattern, "$RepresentationID$", repID)
+ return baseURL + s
+}
+
+// numberPattern matches $Number$ or $Number%0Nd$ (printf-style zero-padded).
+var numberPattern = regexp.MustCompile(`\$Number(%0(\d+)d)?\$`)
+
+// resolveMediaURL replaces $RepresentationID$ and $Number$/$Number%05d$ in
+// the media pattern and prepends the base URL.
+func resolveMediaURL(tmpl segmentTemplate, repID string, number int, baseURL string) string {
+ s := strings.ReplaceAll(tmpl.MediaPattern, "$RepresentationID$", repID)
+ s = numberPattern.ReplaceAllStringFunc(s, func(match string) string {
+ sub := numberPattern.FindStringSubmatch(match)
+ if sub[2] != "" {
+ width, _ := strconv.Atoi(sub[2])
+ return fmt.Sprintf("%0*d", width, number)
+ }
+ return strconv.Itoa(number)
+ })
+ return baseURL + s
+}
+
+// computeSegmentNumber computes the latest fully-available segment number
+// (one behind the live edge) based on the current time, availability start
+// time, and the segment template parameters.
+func computeSegmentNumber(now time.Time, ast time.Time, tmpl segmentTemplate) int {
+ if tmpl.Timescale == 0 || tmpl.Duration == 0 {
+ return tmpl.StartNumber
+ }
+ elapsed := now.Sub(ast).Seconds()
+ segmentDuration := float64(tmpl.Duration) / float64(tmpl.Timescale)
+ if elapsed < segmentDuration {
+ return tmpl.StartNumber
+ }
+ // Total segments elapsed from AST
+ totalSegments := int(elapsed / segmentDuration)
+ // One behind live edge
+ liveSegment := totalSegments - 1
+ if liveSegment < 0 {
+ liveSegment = 0
+ }
+ return tmpl.StartNumber + liveSegment
+}
+
+// classifyAdaptationSet determines if an AdaptationSet is "video" or "audio"
+// by checking (in order): mimeType attr, contentType attr, codec prefix heuristic.
+func classifyAdaptationSet(as *mpd.AdaptationSet) string {
+ // 1. AdaptationSet mimeType (e.g., "video/mp4")
+ if as.MimeType != "" {
+ return as.MimeType
+ }
+ // 2. AdaptationSet contentType (e.g., "video", "audio")
+ if as.ContentType != nil && *as.ContentType != "" {
+ return *as.ContentType
+ }
+ // 3. Infer from codec string on first Representation
+ for _, rep := range as.Representations {
+ if rep.Codecs != nil {
+ c := *rep.Codecs
+ if strings.HasPrefix(c, "av01") || strings.HasPrefix(c, "avc") || strings.HasPrefix(c, "hev") || strings.HasPrefix(c, "hvc") || strings.HasPrefix(c, "vp0") {
+ return "video"
+ }
+ if strings.HasPrefix(c, "mp4a") || strings.HasPrefix(c, "opus") {
+ return "audio"
+ }
+ }
+ }
+ return ""
+}
diff --git a/ingest/dash/mpd_test.go b/ingest/dash/mpd_test.go
new file mode 100644
index 0000000..460e718
--- /dev/null
+++ b/ingest/dash/mpd_test.go
@@ -0,0 +1,315 @@
+package dash
+
+import (
+ "testing"
+ "time"
+)
+
+const testMPDDynamic = `
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+`
+
+const testMPDStatic = `
+
+
+
+
+
+
+
+
+`
+
+func TestParseMPD(t *testing.T) {
+ info, err := parseMPD([]byte(testMPDDynamic))
+ if err != nil {
+ t.Fatalf("parseMPD: %v", err)
+ }
+
+ if !info.IsDynamic {
+ t.Error("expected IsDynamic=true")
+ }
+
+ expectedAST := time.Date(2026, 3, 18, 0, 0, 0, 0, time.UTC)
+ if !info.AvailabilityStartTime.Equal(expectedAST) {
+ t.Errorf("AvailabilityStartTime = %v, want %v", info.AvailabilityStartTime, expectedAST)
+ }
+
+ if info.MinUpdatePeriod != 2*time.Second {
+ t.Errorf("MinUpdatePeriod = %v, want 2s", info.MinUpdatePeriod)
+ }
+
+ // Video adaptations
+ if len(info.VideoAdaptations) != 1 {
+ t.Fatalf("video adaptation count = %d, want 1", len(info.VideoAdaptations))
+ }
+ va := info.VideoAdaptations[0]
+ if len(va.Representations) != 2 {
+ t.Fatalf("video rep count = %d, want 2", len(va.Representations))
+ }
+ if va.Representations[0].ID != "1080p" {
+ t.Errorf("video rep[0] ID = %q, want 1080p", va.Representations[0].ID)
+ }
+ if va.Representations[1].ID != "720p" {
+ t.Errorf("video rep[1] ID = %q, want 720p", va.Representations[1].ID)
+ }
+ if va.Representations[0].Width != 1920 || va.Representations[0].Height != 1080 {
+ t.Errorf("video rep[0] dims = %dx%d, want 1920x1080",
+ va.Representations[0].Width, va.Representations[0].Height)
+ }
+ if va.Representations[0].Codecs != "av01.0.09M.08" {
+ t.Errorf("video rep[0] codecs = %q, want av01.0.09M.08", va.Representations[0].Codecs)
+ }
+
+ // Audio adaptations
+ if len(info.AudioAdaptations) != 1 {
+ t.Fatalf("audio adaptation count = %d, want 1", len(info.AudioAdaptations))
+ }
+ aa := info.AudioAdaptations[0]
+ if len(aa.Representations) != 1 {
+ t.Fatalf("audio rep count = %d, want 1", len(aa.Representations))
+ }
+ if aa.Representations[0].ID != "aac-128k" {
+ t.Errorf("audio rep[0] ID = %q, want aac-128k", aa.Representations[0].ID)
+ }
+
+ // Segment template on video rep
+ st := va.Representations[0].SegmentTemplate
+ if st.Timescale != 90000 {
+ t.Errorf("timescale = %d, want 90000", st.Timescale)
+ }
+ if st.Duration != 180000 {
+ t.Errorf("duration = %d, want 180000", st.Duration)
+ }
+ if st.StartNumber != 1 {
+ t.Errorf("startNumber = %d, want 1", st.StartNumber)
+ }
+}
+
+func TestSelectRepresentation(t *testing.T) {
+ info, err := parseMPD([]byte(testMPDDynamic))
+ if err != nil {
+ t.Fatalf("parseMPD: %v", err)
+ }
+
+ t.Run("explicit ID", func(t *testing.T) {
+ rep, tmpl, err := selectRepresentation(info.VideoAdaptations, "720p")
+ if err != nil {
+ t.Fatalf("selectRepresentation: %v", err)
+ }
+ if rep.ID != "720p" {
+ t.Errorf("ID = %q, want 720p", rep.ID)
+ }
+ if rep.Bandwidth != 2000000 {
+ t.Errorf("Bandwidth = %d, want 2000000", rep.Bandwidth)
+ }
+ if tmpl.Timescale != 90000 {
+ t.Errorf("Timescale = %d, want 90000", tmpl.Timescale)
+ }
+ })
+
+ t.Run("empty ID selects highest bandwidth", func(t *testing.T) {
+ rep, _, err := selectRepresentation(info.VideoAdaptations, "")
+ if err != nil {
+ t.Fatalf("selectRepresentation: %v", err)
+ }
+ if rep.ID != "1080p" {
+ t.Errorf("ID = %q, want 1080p", rep.ID)
+ }
+ if rep.Bandwidth != 4000000 {
+ t.Errorf("Bandwidth = %d, want 4000000", rep.Bandwidth)
+ }
+ })
+
+ t.Run("nonexistent ID", func(t *testing.T) {
+ _, _, err := selectRepresentation(info.VideoAdaptations, "4k")
+ if err == nil {
+ t.Error("expected error for nonexistent ID")
+ }
+ })
+}
+
+func TestResolveInitURL(t *testing.T) {
+ tmpl := segmentTemplate{
+ InitializationPattern: "video/$RepresentationID$/init.mp4",
+ MediaPattern: "video/$RepresentationID$/seg-$Number$.m4s",
+ StartNumber: 1,
+ Timescale: 90000,
+ Duration: 180000,
+ }
+
+ got := resolveInitURL(tmpl, "1080p", "")
+ want := "video/1080p/init.mp4"
+ if got != want {
+ t.Errorf("resolveInitURL = %q, want %q", got, want)
+ }
+
+ got = resolveInitURL(tmpl, "720p", "https://cdn.example.com/")
+ want = "https://cdn.example.com/video/720p/init.mp4"
+ if got != want {
+ t.Errorf("resolveInitURL with baseURL = %q, want %q", got, want)
+ }
+}
+
+func TestResolveMediaURL(t *testing.T) {
+ tmpl := segmentTemplate{
+ InitializationPattern: "video/$RepresentationID$/init.mp4",
+ MediaPattern: "video/$RepresentationID$/seg-$Number$.m4s",
+ StartNumber: 1,
+ Timescale: 90000,
+ Duration: 180000,
+ }
+
+ got := resolveMediaURL(tmpl, "1080p", 42, "")
+ want := "video/1080p/seg-42.m4s"
+ if got != want {
+ t.Errorf("resolveMediaURL = %q, want %q", got, want)
+ }
+
+ got = resolveMediaURL(tmpl, "720p", 1, "https://cdn.example.com/")
+ want = "https://cdn.example.com/video/720p/seg-1.m4s"
+ if got != want {
+ t.Errorf("resolveMediaURL with baseURL = %q, want %q", got, want)
+ }
+}
+
+func TestComputeSegmentNumber(t *testing.T) {
+ ast := time.Date(2026, 3, 18, 0, 0, 0, 0, time.UTC)
+ tmpl := segmentTemplate{
+ StartNumber: 1,
+ Timescale: 90000,
+ Duration: 180000, // 2 seconds per segment
+ }
+
+ tests := []struct {
+ name string
+ offset time.Duration
+ want int
+ }{
+ {"before first segment", 1 * time.Second, 1},
+ {"exactly one segment", 2 * time.Second, 1},
+ {"two segments elapsed", 4 * time.Second, 1 + 1},
+ {"ten segments elapsed", 20 * time.Second, 1 + 9},
+ {"one hour", 1 * time.Hour, 1 + 1799},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ now := ast.Add(tt.offset)
+ got := computeSegmentNumber(now, ast, tmpl)
+ if got != tt.want {
+ t.Errorf("computeSegmentNumber(offset=%v) = %d, want %d", tt.offset, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestParseMPDStatic(t *testing.T) {
+ info, err := parseMPD([]byte(testMPDStatic))
+ if err != nil {
+ t.Fatalf("parseMPD: %v", err)
+ }
+ if info.IsDynamic {
+ t.Error("expected IsDynamic=false for static MPD")
+ }
+ if len(info.VideoAdaptations) != 1 {
+ t.Errorf("video adaptation count = %d, want 1", len(info.VideoAdaptations))
+ }
+}
+
+func TestParseMPDInvalid(t *testing.T) {
+ _, err := parseMPD([]byte("this is not xml at all {{{"))
+ if err == nil {
+ t.Error("expected error for invalid XML")
+ }
+}
+
+func TestResolveMediaURLZeroPadded(t *testing.T) {
+ t.Parallel()
+ // ffmpeg uses $Number%05d$ for zero-padded segment numbers
+ tmpl := segmentTemplate{
+ MediaPattern: "chunk-stream$RepresentationID$-$Number%05d$.m4s",
+ }
+ got := resolveMediaURL(tmpl, "0", 3, "http://localhost:8080/")
+ want := "http://localhost:8080/chunk-stream0-00003.m4s"
+ if got != want {
+ t.Errorf("resolveMediaURL zero-padded = %q, want %q", got, want)
+ }
+}
+
+// Test MPD with contentType on AdaptationSet and mimeType on Representation
+// (ffmpeg pattern)
+const testMPDContentType = `
+
+
+
+
+
+
+
+
+
+
+
+
+
+`
+
+func TestParseMPDContentTypeAttr(t *testing.T) {
+ t.Parallel()
+ info, err := parseMPD([]byte(testMPDContentType))
+ if err != nil {
+ t.Fatalf("parseMPD: %v", err)
+ }
+ if !info.IsDynamic {
+ t.Error("expected dynamic MPD")
+ }
+ if len(info.VideoAdaptations) != 1 {
+ t.Fatalf("video adaptations = %d, want 1", len(info.VideoAdaptations))
+ }
+ if len(info.AudioAdaptations) != 1 {
+ t.Fatalf("audio adaptations = %d, want 1", len(info.AudioAdaptations))
+ }
+ videoRep := info.VideoAdaptations[0].Representations[0]
+ if videoRep.Codecs != "av01.0.05M.08" {
+ t.Errorf("video codecs = %q, want av01.0.05M.08", videoRep.Codecs)
+ }
+}
diff --git a/ingest/dash/pipeline.go b/ingest/dash/pipeline.go
new file mode 100644
index 0000000..33a8275
--- /dev/null
+++ b/ingest/dash/pipeline.go
@@ -0,0 +1,147 @@
+package dash
+
+import (
+ "sync/atomic"
+ "time"
+
+ "github.com/zsiec/prism/distribution"
+ "github.com/zsiec/prism/media"
+)
+
+// Compile-time interface check.
+var _ distribution.StatsProvider = (*dashPipeline)(nil)
+
+// broadcaster is the subset of distribution.Relay the DASH pipeline uses.
+// It omits BroadcastCaptions because DASH sources do not carry CEA-608/708.
+type broadcaster interface {
+ BroadcastVideo(frame *media.VideoFrame)
+ BroadcastAudio(frame *media.AudioFrame)
+ SetVideoInfo(info distribution.VideoInfo)
+ SetAudioTrackCount(count int)
+ AudioTrackCount() int
+ SetAudioInfo(info distribution.AudioInfo)
+ ViewerCount() int
+ ViewerStatsAll() []distribution.ViewerStats
+}
+
+// dashPipeline bridges fMP4 mediaSamples to the frame-based Relay for fan-out
+// to MoQ viewers. It converts decoded DASH segments into media.VideoFrame and
+// media.AudioFrame values and broadcasts them through the Relay.
+type dashPipeline struct {
+ streamKey string
+ relay broadcaster
+ seqHdrOBU []byte // AV1 sequence header OBU, prepended to keyframes as SPS
+ startTime time.Time
+ groupID uint32
+
+ videoForwarded atomic.Int64
+ audioForwarded atomic.Int64
+}
+
+// newDASHPipeline creates a pipeline that converts DASH mediaSamples into
+// VideoFrame/AudioFrame values and broadcasts them via relay. The seqHdrOBU
+// is the AV1 sequence header OBU extracted from the init segment, attached
+// to keyframes so downstream decoders can initialize.
+func newDASHPipeline(streamKey string, relay broadcaster, seqHdrOBU []byte) *dashPipeline {
+ return &dashPipeline{
+ streamKey: streamKey,
+ relay: relay,
+ seqHdrOBU: seqHdrOBU,
+ startTime: time.Now(),
+ }
+}
+
+// processVideoSamples converts video mediaSamples into VideoFrames and
+// broadcasts them through the relay. Keyframes start a new group and carry
+// the AV1 sequence header OBU as SPS so that late-joining decoders can
+// configure themselves.
+//
+// Frames are paced based on PTS deltas to avoid bursting an entire segment
+// worth of frames at once (which causes jittery playback at the viewer).
+func (p *dashPipeline) processVideoSamples(samples []mediaSample) {
+ if len(samples) == 0 {
+ return
+ }
+
+ basePTS := samples[0].PTS
+ wallStart := time.Now()
+
+ for i := range samples {
+ s := &samples[i]
+
+ // Pace: sleep until this frame's presentation time relative to the
+ // first frame in the batch, so frames arrive at ~real-time cadence.
+ if i > 0 {
+ ptsDelta := time.Duration(s.PTS-basePTS) * time.Microsecond
+ wallElapsed := time.Since(wallStart)
+ if sleep := ptsDelta - wallElapsed; sleep > 0 {
+ time.Sleep(sleep)
+ }
+ }
+
+ frame := &media.VideoFrame{
+ PTS: s.PTS,
+ DTS: s.DTS,
+ IsKeyframe: s.IsKeyframe,
+ Codec: "av1",
+ WireData: s.Data,
+ }
+
+ if s.IsKeyframe {
+ p.groupID++
+ frame.SPS = p.seqHdrOBU
+ }
+ frame.GroupID = p.groupID
+
+ p.relay.BroadcastVideo(frame)
+ p.videoForwarded.Add(1)
+ }
+}
+
+// processAudioSamples converts audio mediaSamples into AudioFrames and
+// broadcasts them through the relay. Paced by PTS like video.
+func (p *dashPipeline) processAudioSamples(samples []mediaSample) {
+ if len(samples) == 0 {
+ return
+ }
+
+ basePTS := samples[0].PTS
+ wallStart := time.Now()
+
+ for i := range samples {
+ s := &samples[i]
+
+ if i > 0 {
+ ptsDelta := time.Duration(s.PTS-basePTS) * time.Microsecond
+ wallElapsed := time.Since(wallStart)
+ if sleep := ptsDelta - wallElapsed; sleep > 0 {
+ time.Sleep(sleep)
+ }
+ }
+
+ frame := &media.AudioFrame{
+ PTS: s.PTS,
+ Data: s.Data,
+ }
+
+ p.relay.BroadcastAudio(frame)
+ p.audioForwarded.Add(1)
+ }
+}
+
+// StreamSnapshot returns a point-in-time snapshot of stream health metrics,
+// implementing the distribution.StatsProvider interface for the stats overlay
+// and REST API.
+func (p *dashPipeline) StreamSnapshot() distribution.StreamSnapshot {
+ return distribution.StreamSnapshot{
+ Timestamp: time.Now().UnixMilli(),
+ UptimeMs: time.Since(p.startTime).Milliseconds(),
+ Protocol: "DASH",
+ Video: distribution.VideoStats{
+ Codec: "AV1",
+ TotalFrames: p.videoForwarded.Load(),
+ },
+ ViewerCount: p.relay.ViewerCount(),
+ Viewers: p.relay.ViewerStatsAll(),
+ }
+}
diff --git a/ingest/dash/pipeline_test.go b/ingest/dash/pipeline_test.go
new file mode 100644
index 0000000..9d071b5
--- /dev/null
+++ b/ingest/dash/pipeline_test.go
@@ -0,0 +1,266 @@
+package dash
+
+import (
+ "testing"
+ "time"
+
+ "github.com/zsiec/prism/distribution"
+ "github.com/zsiec/prism/media"
+)
+
+// mockBroadcaster captures all calls made by the dashPipeline for test assertions.
+type mockBroadcaster struct {
+ videoFrames []*media.VideoFrame
+ audioFrames []*media.AudioFrame
+ videoInfo *distribution.VideoInfo
+ audioInfo *distribution.AudioInfo
+ trackCount int
+}
+
+func (m *mockBroadcaster) BroadcastVideo(frame *media.VideoFrame) {
+ m.videoFrames = append(m.videoFrames, frame)
+}
+
+func (m *mockBroadcaster) BroadcastAudio(frame *media.AudioFrame) {
+ m.audioFrames = append(m.audioFrames, frame)
+}
+
+func (m *mockBroadcaster) SetVideoInfo(info distribution.VideoInfo) {
+ m.videoInfo = &info
+}
+
+func (m *mockBroadcaster) SetAudioTrackCount(count int) {
+ m.trackCount = count
+}
+
+func (m *mockBroadcaster) AudioTrackCount() int {
+ if m.trackCount == 0 {
+ return 1
+ }
+ return m.trackCount
+}
+
+func (m *mockBroadcaster) SetAudioInfo(info distribution.AudioInfo) {
+ m.audioInfo = &info
+}
+
+func (m *mockBroadcaster) ViewerCount() int {
+ return 0
+}
+
+func (m *mockBroadcaster) ViewerStatsAll() []distribution.ViewerStats {
+ return nil
+}
+
+func TestDASHPipelineProcessVideoSamples(t *testing.T) {
+ t.Parallel()
+
+ mock := &mockBroadcaster{}
+ seqHdr := []byte{0x0A, 0x0B, 0x0C}
+ p := newDASHPipeline("test-stream", mock, seqHdr)
+
+ samples := []mediaSample{
+ {PTS: 0, DTS: 0, Data: []byte{0x01, 0x02}, IsKeyframe: true},
+ {PTS: 33333, DTS: 33333, Data: []byte{0x03, 0x04}, IsKeyframe: false},
+ {PTS: 66666, DTS: 66666, Data: []byte{0x05, 0x06}, IsKeyframe: false},
+ }
+
+ p.processVideoSamples(samples)
+
+ if len(mock.videoFrames) != 3 {
+ t.Fatalf("expected 3 video frames, got %d", len(mock.videoFrames))
+ }
+
+ // All frames should have codec "av1".
+ for i, f := range mock.videoFrames {
+ if f.Codec != "av1" {
+ t.Errorf("frame %d: codec = %q, want %q", i, f.Codec, "av1")
+ }
+ }
+
+ // First frame: keyframe with SPS set.
+ kf := mock.videoFrames[0]
+ if !kf.IsKeyframe {
+ t.Error("frame 0: expected IsKeyframe=true")
+ }
+ if kf.SPS == nil {
+ t.Error("frame 0: expected SPS to be set for keyframe")
+ }
+ if len(kf.SPS) != 3 || kf.SPS[0] != 0x0A {
+ t.Errorf("frame 0: SPS = %v, want %v", kf.SPS, seqHdr)
+ }
+
+ // Delta frames: no SPS.
+ for i := 1; i < 3; i++ {
+ df := mock.videoFrames[i]
+ if df.IsKeyframe {
+ t.Errorf("frame %d: expected IsKeyframe=false", i)
+ }
+ if df.SPS != nil {
+ t.Errorf("frame %d: expected SPS=nil for delta frame, got %v", i, df.SPS)
+ }
+ }
+
+ // GroupID should be 1 (incremented once for the keyframe).
+ if kf.GroupID != 1 {
+ t.Errorf("frame 0: GroupID = %d, want 1", kf.GroupID)
+ }
+ for i := 1; i < 3; i++ {
+ if mock.videoFrames[i].GroupID != 1 {
+ t.Errorf("frame %d: GroupID = %d, want 1", i, mock.videoFrames[i].GroupID)
+ }
+ }
+
+ // WireData should match sample data.
+ for i, s := range samples {
+ f := mock.videoFrames[i]
+ if len(f.WireData) != len(s.Data) {
+ t.Errorf("frame %d: WireData length = %d, want %d", i, len(f.WireData), len(s.Data))
+ }
+ for j := range s.Data {
+ if f.WireData[j] != s.Data[j] {
+ t.Errorf("frame %d: WireData[%d] = %d, want %d", i, j, f.WireData[j], s.Data[j])
+ }
+ }
+ }
+
+ // Forwarded counter.
+ if got := p.videoForwarded.Load(); got != 3 {
+ t.Errorf("videoForwarded = %d, want 3", got)
+ }
+}
+
+func TestDASHPipelineProcessAudioSamples(t *testing.T) {
+ t.Parallel()
+
+ mock := &mockBroadcaster{}
+ p := newDASHPipeline("test-stream", mock, nil)
+
+ samples := []mediaSample{
+ {PTS: 0, Data: []byte{0xAA, 0xBB}},
+ {PTS: 21333, Data: []byte{0xCC, 0xDD, 0xEE}},
+ }
+
+ p.processAudioSamples(samples)
+
+ if len(mock.audioFrames) != 2 {
+ t.Fatalf("expected 2 audio frames, got %d", len(mock.audioFrames))
+ }
+
+ for i, s := range samples {
+ f := mock.audioFrames[i]
+ if f.PTS != s.PTS {
+ t.Errorf("frame %d: PTS = %d, want %d", i, f.PTS, s.PTS)
+ }
+ if len(f.Data) != len(s.Data) {
+ t.Errorf("frame %d: Data length = %d, want %d", i, len(f.Data), len(s.Data))
+ continue
+ }
+ for j := range s.Data {
+ if f.Data[j] != s.Data[j] {
+ t.Errorf("frame %d: Data[%d] = %d, want %d", i, j, f.Data[j], s.Data[j])
+ }
+ }
+ }
+
+ if got := p.audioForwarded.Load(); got != 2 {
+ t.Errorf("audioForwarded = %d, want 2", got)
+ }
+}
+
+func TestDASHPipelineGroupIDIncrement(t *testing.T) {
+ t.Parallel()
+
+ mock := &mockBroadcaster{}
+ seqHdr := []byte{0x01}
+ p := newDASHPipeline("test-stream", mock, seqHdr)
+
+ // First GOP: keyframe + delta.
+ p.processVideoSamples([]mediaSample{
+ {PTS: 0, DTS: 0, IsKeyframe: true, Data: []byte{0x10}},
+ {PTS: 33333, DTS: 33333, IsKeyframe: false, Data: []byte{0x11}},
+ })
+
+ // Second GOP: another keyframe + delta.
+ p.processVideoSamples([]mediaSample{
+ {PTS: 66666, DTS: 66666, IsKeyframe: true, Data: []byte{0x20}},
+ {PTS: 99999, DTS: 99999, IsKeyframe: false, Data: []byte{0x21}},
+ })
+
+ // Third GOP: yet another keyframe.
+ p.processVideoSamples([]mediaSample{
+ {PTS: 133332, DTS: 133332, IsKeyframe: true, Data: []byte{0x30}},
+ })
+
+ if len(mock.videoFrames) != 5 {
+ t.Fatalf("expected 5 video frames, got %d", len(mock.videoFrames))
+ }
+
+ // Group IDs: first keyframe=1, second keyframe=2, third keyframe=3.
+ expectedGroupIDs := []uint32{1, 1, 2, 2, 3}
+ for i, f := range mock.videoFrames {
+ if f.GroupID != expectedGroupIDs[i] {
+ t.Errorf("frame %d: GroupID = %d, want %d", i, f.GroupID, expectedGroupIDs[i])
+ }
+ }
+
+ // All keyframes should have SPS.
+ keyframeIndices := []int{0, 2, 4}
+ for _, idx := range keyframeIndices {
+ if mock.videoFrames[idx].SPS == nil {
+ t.Errorf("frame %d (keyframe): expected SPS to be set", idx)
+ }
+ }
+
+ // All delta frames should not have SPS.
+ deltaIndices := []int{1, 3}
+ for _, idx := range deltaIndices {
+ if mock.videoFrames[idx].SPS != nil {
+ t.Errorf("frame %d (delta): expected SPS=nil", idx)
+ }
+ }
+}
+
+func TestDASHPipelineStreamSnapshot(t *testing.T) {
+ t.Parallel()
+
+ mock := &mockBroadcaster{}
+ p := newDASHPipeline("test-stream", mock, nil)
+
+ // Push startTime back so UptimeMs is reliably > 0.
+ p.startTime = p.startTime.Add(-10 * time.Millisecond)
+
+ // Process some samples so counters are non-zero.
+ p.processVideoSamples([]mediaSample{
+ {PTS: 0, DTS: 0, IsKeyframe: true, Data: []byte{0x01}},
+ })
+ p.processAudioSamples([]mediaSample{
+ {PTS: 0, Data: []byte{0x02}},
+ })
+
+ snap := p.StreamSnapshot()
+
+ if snap.Protocol != "DASH" {
+ t.Errorf("Protocol = %q, want %q", snap.Protocol, "DASH")
+ }
+
+ if snap.UptimeMs <= 0 {
+ t.Errorf("UptimeMs = %d, want > 0", snap.UptimeMs)
+ }
+
+ if snap.Timestamp <= 0 {
+ t.Errorf("Timestamp = %d, want > 0", snap.Timestamp)
+ }
+
+ if snap.Video.Codec != "AV1" {
+ t.Errorf("Video.Codec = %q, want %q", snap.Video.Codec, "AV1")
+ }
+
+ if snap.Video.TotalFrames != 1 {
+ t.Errorf("Video.TotalFrames = %d, want 1", snap.Video.TotalFrames)
+ }
+
+ if snap.ViewerCount != 0 {
+ t.Errorf("ViewerCount = %d, want 0", snap.ViewerCount)
+ }
+}
diff --git a/ingest/dash/puller.go b/ingest/dash/puller.go
new file mode 100644
index 0000000..83133b4
--- /dev/null
+++ b/ingest/dash/puller.go
@@ -0,0 +1,389 @@
+package dash
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/zsiec/prism/demux"
+ "github.com/zsiec/prism/distribution"
+ "github.com/zsiec/prism/moq"
+)
+
+// PullRequest describes a DASH source to pull from.
+type PullRequest struct {
+ URL string `json:"url"`
+ StreamKey string `json:"streamKey"`
+ VideoRepID string `json:"videoRepresentationId,omitempty"`
+ AudioRepID string `json:"audioRepresentationId,omitempty"`
+}
+
+// DistributionServer is the subset of distribution.Server needed by the puller.
+type DistributionServer interface {
+ RegisterStream(key string) *distribution.Relay
+ UnregisterStream(key string)
+ SetPipeline(key string, p distribution.StatsProvider)
+}
+
+// StreamManager is the subset of stream.Manager needed by the puller.
+type StreamManager interface {
+ Create(key string) (interface{}, bool)
+ Remove(key string)
+}
+
+type activePull struct {
+ req PullRequest
+ cancel context.CancelFunc
+}
+
+// Puller manages active DASH pull connections.
+type Puller struct {
+ log *slog.Logger
+ client *http.Client
+ mu sync.Mutex
+ pulls map[string]*activePull
+}
+
+// NewPuller creates a Puller with a default HTTP client. If log is nil,
+// slog.Default() is used.
+func NewPuller(log *slog.Logger) *Puller {
+ if log == nil {
+ log = slog.Default()
+ }
+ return &Puller{
+ log: log.With("component", "dash-puller"),
+ client: &http.Client{
+ Timeout: 30 * time.Second,
+ },
+ pulls: make(map[string]*activePull),
+ }
+}
+
+// Pull validates the request, fetches the MPD manifest synchronously, and
+// starts a background goroutine that continuously fetches and processes
+// DASH segments. It returns immediately after the MPD is validated.
+func (p *Puller) Pull(ctx context.Context, req PullRequest, distSrv DistributionServer, mgr StreamManager) error {
+ if req.URL == "" {
+ return fmt.Errorf("url is required")
+ }
+ if req.StreamKey == "" {
+ return fmt.Errorf("streamKey is required")
+ }
+
+ p.mu.Lock()
+ if _, exists := p.pulls[req.StreamKey]; exists {
+ p.mu.Unlock()
+ return fmt.Errorf("pull already active for stream key %q", req.StreamKey)
+ }
+ p.mu.Unlock()
+
+ // Fetch and validate MPD synchronously so the caller gets immediate
+ // feedback on bad URLs or malformed manifests.
+ p.log.Info("fetching MPD", "url", req.URL, "stream_key", req.StreamKey)
+ mpdData, err := p.fetchURL(ctx, req.URL)
+ if err != nil {
+ return fmt.Errorf("fetch MPD: %w", err)
+ }
+
+ info, err := parseMPD(mpdData)
+ if err != nil {
+ return fmt.Errorf("parse MPD: %w", err)
+ }
+
+ if !info.IsDynamic {
+ return fmt.Errorf("MPD is not dynamic (live); static MPDs are not supported")
+ }
+
+ // Select video representation.
+ videoRep, videoTmpl, err := selectRepresentation(info.VideoAdaptations, req.VideoRepID)
+ if err != nil {
+ return fmt.Errorf("select video representation: %w", err)
+ }
+
+ // Select audio representation (optional — some streams are video-only).
+ var audioRep representationInfo
+ var audioTmpl segmentTemplate
+ hasAudio := len(info.AudioAdaptations) > 0
+ if hasAudio {
+ audioRep, audioTmpl, err = selectRepresentation(info.AudioAdaptations, req.AudioRepID)
+ if err != nil {
+ return fmt.Errorf("select audio representation: %w", err)
+ }
+ }
+
+ // Double-check for duplicate before committing.
+ pullCtx, cancel := context.WithCancel(ctx)
+
+ p.mu.Lock()
+ if _, exists := p.pulls[req.StreamKey]; exists {
+ p.mu.Unlock()
+ cancel()
+ return fmt.Errorf("pull already active for stream key %q", req.StreamKey)
+ }
+ p.pulls[req.StreamKey] = &activePull{req: req, cancel: cancel}
+ p.mu.Unlock()
+
+ p.log.Info("starting DASH pull",
+ "stream_key", req.StreamKey,
+ "video_rep", videoRep.ID,
+ "video_dims", fmt.Sprintf("%dx%d", videoRep.Width, videoRep.Height),
+ "has_audio", hasAudio,
+ )
+
+ go p.runPullLoop(pullCtx, req, info, videoRep, videoTmpl, audioRep, audioTmpl, hasAudio, distSrv, mgr)
+
+ return nil
+}
+
+// Stop cancels an active pull by stream key.
+func (p *Puller) Stop(streamKey string) error {
+ p.mu.Lock()
+ ap, ok := p.pulls[streamKey]
+ p.mu.Unlock()
+
+ if !ok {
+ return fmt.Errorf("no active pull for stream key %q", streamKey)
+ }
+
+ ap.cancel()
+ return nil
+}
+
+// ActivePulls returns a snapshot of all active pull requests.
+func (p *Puller) ActivePulls() []PullRequest {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ out := make([]PullRequest, 0, len(p.pulls))
+ for _, ap := range p.pulls {
+ out = append(out, ap.req)
+ }
+ return out
+}
+
+// fetchURL performs an HTTP GET with context support and returns the response body.
+func (p *Puller) fetchURL(ctx context.Context, rawURL string) ([]byte, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
+ if err != nil {
+ return nil, fmt.Errorf("create request: %w", err)
+ }
+
+ resp, err := p.client.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("HTTP GET %s: %w", rawURL, err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("HTTP GET %s: status %d", rawURL, resp.StatusCode)
+ }
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("read response body: %w", err)
+ }
+ return body, nil
+}
+
+// deriveBaseURL computes the base URL for resolving relative segment URLs.
+// If the MPD contains a BaseURL, that is used; otherwise the MPD URL's
+// directory is used.
+func deriveBaseURL(mpdURL string, mpdBaseURL string) string {
+ if mpdBaseURL != "" {
+ return mpdBaseURL
+ }
+ // Strip filename from MPD URL to get directory.
+ u, err := url.Parse(mpdURL)
+ if err != nil {
+ return ""
+ }
+ idx := strings.LastIndex(u.Path, "/")
+ if idx >= 0 {
+ u.Path = u.Path[:idx+1]
+ }
+ return u.String()
+}
+
+// runPullLoop is the background goroutine that continuously fetches DASH
+// segments and feeds them into the distribution pipeline.
+func (p *Puller) runPullLoop(
+ ctx context.Context,
+ req PullRequest,
+ info *mpdInfo,
+ videoRep representationInfo,
+ videoTmpl segmentTemplate,
+ audioRep representationInfo,
+ audioTmpl segmentTemplate,
+ hasAudio bool,
+ distSrv DistributionServer,
+ mgr StreamManager,
+) {
+ defer func() {
+ distSrv.UnregisterStream(req.StreamKey)
+ if mgr != nil {
+ mgr.Remove(req.StreamKey)
+ }
+ p.mu.Lock()
+ delete(p.pulls, req.StreamKey)
+ p.mu.Unlock()
+ p.log.Info("DASH pull ended", "stream_key", req.StreamKey)
+ }()
+
+ baseURL := deriveBaseURL(req.URL, info.BaseURL)
+
+ // Download and parse video init segment.
+ videoInitURL := resolveInitURL(videoTmpl, videoRep.ID, baseURL)
+ videoInitData, err := p.fetchURL(ctx, videoInitURL)
+ if err != nil {
+ p.log.Error("fetch video init segment", "url", videoInitURL, "error", err)
+ return
+ }
+
+ initInfo, err := parseInitSegment(videoInitData)
+ if err != nil {
+ p.log.Error("parse video init segment", "error", err)
+ return
+ }
+
+ // If audio uses a different init URL, download that too.
+ if hasAudio {
+ audioInitURL := resolveInitURL(audioTmpl, audioRep.ID, baseURL)
+ if audioInitURL != videoInitURL {
+ audioInitData, err := p.fetchURL(ctx, audioInitURL)
+ if err != nil {
+ p.log.Warn("fetch audio init segment", "url", audioInitURL, "error", err)
+ hasAudio = false
+ } else {
+ audioInit, err := parseInitSegment(audioInitData)
+ if err != nil {
+ p.log.Warn("parse audio init segment", "error", err)
+ hasAudio = false
+ } else {
+ // Merge audio info into initInfo.
+ initInfo.AudioCodec = audioInit.AudioCodec
+ initInfo.SampleRate = audioInit.SampleRate
+ initInfo.Channels = audioInit.Channels
+ initInfo.AudioTimescale = audioInit.AudioTimescale
+ initInfo.AudioTrackID = audioInit.AudioTrackID
+ initInfo.audioTrex = audioInit.audioTrex
+ }
+ }
+ }
+ }
+
+ // Register stream with distribution server.
+ relay := distSrv.RegisterStream(req.StreamKey)
+ if mgr != nil {
+ mgr.Create(req.StreamKey)
+ }
+
+ // Set video info on relay with proper RFC 6381 codec string and decoder config.
+ vi := distribution.VideoInfo{
+ Codec: initInfo.VideoCodec,
+ Width: initInfo.Width,
+ Height: initInfo.Height,
+ }
+ if initInfo.SeqHeaderOBU != nil {
+ if hdr, err := demux.ParseAV1SequenceHeader(initInfo.SeqHeaderOBU); err == nil {
+ vi.Codec = hdr.CodecString()
+ vi.Width = hdr.Width
+ vi.Height = hdr.Height
+ }
+ vi.DecoderConfig = moq.BuildAV1DecoderConfig(initInfo.SeqHeaderOBU)
+ }
+ relay.SetVideoInfo(vi)
+
+ // Set audio info if available.
+ if hasAudio {
+ relay.SetAudioTrackCount(1)
+ relay.SetAudioInfo(distribution.AudioInfo{
+ Codec: initInfo.AudioCodec,
+ SampleRate: initInfo.SampleRate,
+ Channels: initInfo.Channels,
+ })
+ }
+
+ // Create pipeline and register as stats provider.
+ pipeline := newDASHPipeline(req.StreamKey, relay, initInfo.SeqHeaderOBU)
+ distSrv.SetPipeline(req.StreamKey, pipeline)
+
+ // Segment fetch loop.
+ lastSegNum := -1
+ segmentDuration := time.Duration(0)
+ if videoTmpl.Timescale > 0 && videoTmpl.Duration > 0 {
+ segmentDuration = time.Duration(float64(videoTmpl.Duration) / float64(videoTmpl.Timescale) * float64(time.Second))
+ }
+
+ for {
+ if ctx.Err() != nil {
+ return
+ }
+
+ segNum := computeSegmentNumber(time.Now(), info.AvailabilityStartTime, videoTmpl)
+
+ if segNum == lastSegNum {
+ // Sleep until the next segment boundary.
+ sleepDur := segmentDuration / 4
+ if sleepDur < 100*time.Millisecond {
+ sleepDur = 100 * time.Millisecond
+ }
+ select {
+ case <-ctx.Done():
+ return
+ case <-time.After(sleepDur):
+ continue
+ }
+ }
+ lastSegNum = segNum
+
+ // Fetch video segment.
+ videoMediaURL := resolveMediaURL(videoTmpl, videoRep.ID, segNum, baseURL)
+ videoData, err := p.fetchURL(ctx, videoMediaURL)
+ if err != nil {
+ if ctx.Err() != nil {
+ return
+ }
+ p.log.Warn("fetch video segment", "seg", segNum, "url", videoMediaURL, "error", err)
+ continue
+ }
+
+ videoSamples, err := parseMediaSegment(videoData, initInfo, true)
+ if err != nil {
+ p.log.Warn("parse video segment", "seg", segNum, "error", err)
+ continue
+ }
+ // Fetch audio segment in parallel with video pacing.
+ var audioSamples []mediaSample
+ if hasAudio {
+ audioSegNum := computeSegmentNumber(time.Now(), info.AvailabilityStartTime, audioTmpl)
+ audioMediaURL := resolveMediaURL(audioTmpl, audioRep.ID, audioSegNum, baseURL)
+ audioData, err := p.fetchURL(ctx, audioMediaURL)
+ if err != nil {
+ if ctx.Err() != nil {
+ return
+ }
+ p.log.Warn("fetch audio segment", "seg", audioSegNum, "url", audioMediaURL, "error", err)
+ } else {
+ audioSamples, err = parseMediaSegment(audioData, initInfo, false)
+ if err != nil {
+ p.log.Warn("parse audio segment", "seg", audioSegNum, "error", err)
+ }
+ }
+ }
+
+ // Pace video and audio concurrently so they interleave naturally.
+ done := make(chan struct{})
+ go func() {
+ pipeline.processAudioSamples(audioSamples)
+ close(done)
+ }()
+ pipeline.processVideoSamples(videoSamples)
+ <-done
+ }
+}
diff --git a/ingest/dash/puller_test.go b/ingest/dash/puller_test.go
new file mode 100644
index 0000000..d6406ba
--- /dev/null
+++ b/ingest/dash/puller_test.go
@@ -0,0 +1,451 @@
+package dash
+
+import (
+ "context"
+ "log/slog"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/zsiec/prism/distribution"
+)
+
+func TestNewPuller(t *testing.T) {
+ t.Parallel()
+
+ t.Run("with logger", func(t *testing.T) {
+ log := slog.Default()
+ p := NewPuller(log)
+ if p == nil {
+ t.Fatal("NewPuller returned nil")
+ }
+ if p.pulls == nil {
+ t.Error("pulls map is nil")
+ }
+ if p.client == nil {
+ t.Error("client is nil")
+ }
+ })
+
+ t.Run("nil logger uses default", func(t *testing.T) {
+ p := NewPuller(nil)
+ if p == nil {
+ t.Fatal("NewPuller returned nil")
+ }
+ if p.log == nil {
+ t.Error("log is nil")
+ }
+ })
+}
+
+func TestPullerValidation(t *testing.T) {
+ t.Parallel()
+
+ p := NewPuller(slog.Default())
+
+ t.Run("empty URL", func(t *testing.T) {
+ err := p.Pull(context.Background(), PullRequest{
+ StreamKey: "test",
+ }, nil, nil)
+ if err == nil {
+ t.Error("expected error for empty URL")
+ }
+ })
+
+ t.Run("empty StreamKey", func(t *testing.T) {
+ err := p.Pull(context.Background(), PullRequest{
+ URL: "http://example.com/manifest.mpd",
+ }, nil, nil)
+ if err == nil {
+ t.Error("expected error for empty StreamKey")
+ }
+ })
+}
+
+func TestPullerDuplicateStreamKey(t *testing.T) {
+ t.Parallel()
+
+ p := NewPuller(slog.Default())
+
+ // Insert a fake active pull directly into the map.
+ p.mu.Lock()
+ p.pulls["test-stream"] = &activePull{
+ req: PullRequest{URL: "http://example.com/dash.mpd", StreamKey: "test-stream"},
+ cancel: func() {},
+ }
+ p.mu.Unlock()
+
+ err := p.Pull(context.Background(), PullRequest{
+ URL: "http://example.com/dash.mpd",
+ StreamKey: "test-stream",
+ }, nil, nil)
+ if err == nil {
+ t.Error("expected error for duplicate stream key")
+ }
+}
+
+func TestPullerStopNonexistent(t *testing.T) {
+ t.Parallel()
+
+ p := NewPuller(slog.Default())
+ err := p.Stop("nonexistent")
+ if err == nil {
+ t.Error("expected error for nonexistent stream key")
+ }
+}
+
+func TestPullerActivePulls(t *testing.T) {
+ t.Parallel()
+
+ p := NewPuller(slog.Default())
+
+ // Empty initially.
+ pulls := p.ActivePulls()
+ if len(pulls) != 0 {
+ t.Errorf("expected 0 active pulls, got %d", len(pulls))
+ }
+
+ // Insert fake pulls.
+ p.mu.Lock()
+ p.pulls["stream-a"] = &activePull{
+ req: PullRequest{URL: "http://a.com/dash.mpd", StreamKey: "stream-a"},
+ cancel: func() {},
+ }
+ p.pulls["stream-b"] = &activePull{
+ req: PullRequest{URL: "http://b.com/dash.mpd", StreamKey: "stream-b"},
+ cancel: func() {},
+ }
+ p.mu.Unlock()
+
+ pulls = p.ActivePulls()
+ if len(pulls) != 2 {
+ t.Errorf("expected 2 active pulls, got %d", len(pulls))
+ }
+
+ // Verify both keys are present.
+ keys := map[string]bool{}
+ for _, pr := range pulls {
+ keys[pr.StreamKey] = true
+ }
+ if !keys["stream-a"] || !keys["stream-b"] {
+ t.Errorf("unexpected keys: %v", keys)
+ }
+}
+
+func TestPullerStop(t *testing.T) {
+ t.Parallel()
+
+ p := NewPuller(slog.Default())
+
+ cancelled := false
+ p.mu.Lock()
+ p.pulls["test-stream"] = &activePull{
+ req: PullRequest{URL: "http://example.com/dash.mpd", StreamKey: "test-stream"},
+ cancel: func() { cancelled = true },
+ }
+ p.mu.Unlock()
+
+ err := p.Stop("test-stream")
+ if err != nil {
+ t.Fatalf("Stop returned error: %v", err)
+ }
+ if !cancelled {
+ t.Error("cancel function was not called")
+ }
+}
+
+func TestDeriveBaseURL(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ mpdURL string
+ mpdBaseURL string
+ want string
+ }{
+ {
+ name: "MPD has BaseURL",
+ mpdURL: "http://cdn.example.com/live/manifest.mpd",
+ mpdBaseURL: "http://cdn.example.com/live/",
+ want: "http://cdn.example.com/live/",
+ },
+ {
+ name: "derive from MPD URL",
+ mpdURL: "http://cdn.example.com/live/manifest.mpd",
+ mpdBaseURL: "",
+ want: "http://cdn.example.com/live/",
+ },
+ {
+ name: "MPD URL with deeper path",
+ mpdURL: "http://cdn.example.com/a/b/c/dash.mpd",
+ mpdBaseURL: "",
+ want: "http://cdn.example.com/a/b/c/",
+ },
+ {
+ name: "MPD URL at root",
+ mpdURL: "http://cdn.example.com/manifest.mpd",
+ mpdBaseURL: "",
+ want: "http://cdn.example.com/",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := deriveBaseURL(tt.mpdURL, tt.mpdBaseURL)
+ if got != tt.want {
+ t.Errorf("deriveBaseURL(%q, %q) = %q, want %q", tt.mpdURL, tt.mpdBaseURL, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestPullerFetchURL(t *testing.T) {
+ t.Parallel()
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/ok":
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte("hello"))
+ case "/error":
+ w.WriteHeader(http.StatusInternalServerError)
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }))
+ defer srv.Close()
+
+ p := NewPuller(slog.Default())
+
+ t.Run("success", func(t *testing.T) {
+ data, err := p.fetchURL(context.Background(), srv.URL+"/ok")
+ if err != nil {
+ t.Fatalf("fetchURL: %v", err)
+ }
+ if string(data) != "hello" {
+ t.Errorf("body = %q, want %q", string(data), "hello")
+ }
+ })
+
+ t.Run("server error", func(t *testing.T) {
+ _, err := p.fetchURL(context.Background(), srv.URL+"/error")
+ if err == nil {
+ t.Error("expected error for 500 status")
+ }
+ })
+
+ t.Run("cancelled context", func(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ _, err := p.fetchURL(ctx, srv.URL+"/ok")
+ if err == nil {
+ t.Error("expected error for cancelled context")
+ }
+ })
+}
+
+// mockDistributionServer implements DistributionServer for testing.
+type mockDistributionServer struct {
+ registered map[string]*distribution.Relay
+ unregistered []string
+ pipelines map[string]distribution.StatsProvider
+}
+
+func newMockDistributionServer() *mockDistributionServer {
+ return &mockDistributionServer{
+ registered: make(map[string]*distribution.Relay),
+ pipelines: make(map[string]distribution.StatsProvider),
+ }
+}
+
+func (m *mockDistributionServer) RegisterStream(key string) *distribution.Relay {
+ r := distribution.NewRelay()
+ m.registered[key] = r
+ return r
+}
+
+func (m *mockDistributionServer) UnregisterStream(key string) {
+ m.unregistered = append(m.unregistered, key)
+ delete(m.registered, key)
+}
+
+func (m *mockDistributionServer) SetPipeline(key string, p distribution.StatsProvider) {
+ m.pipelines[key] = p
+}
+
+// mockStreamManager implements StreamManager for testing.
+type mockStreamManager struct {
+ created []string
+ removed []string
+}
+
+func (m *mockStreamManager) Create(key string) (interface{}, bool) {
+ m.created = append(m.created, key)
+ return nil, true
+}
+
+func (m *mockStreamManager) Remove(key string) {
+ m.removed = append(m.removed, key)
+}
+
+func TestPullerPullStaticMPDRejected(t *testing.T) {
+ t.Parallel()
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(testMPDStatic))
+ }))
+ defer srv.Close()
+
+ p := NewPuller(slog.Default())
+ distSrv := newMockDistributionServer()
+
+ err := p.Pull(context.Background(), PullRequest{
+ URL: srv.URL + "/manifest.mpd",
+ StreamKey: "test-static",
+ }, distSrv, nil)
+
+ if err == nil {
+ t.Fatal("expected error for static MPD")
+ }
+
+ // No stream should have been registered.
+ if len(distSrv.registered) != 0 {
+ t.Errorf("expected 0 registered streams, got %d", len(distSrv.registered))
+ }
+}
+
+func TestPullerPullBadURL(t *testing.T) {
+ t.Parallel()
+
+ p := NewPuller(slog.Default())
+ distSrv := newMockDistributionServer()
+
+ err := p.Pull(context.Background(), PullRequest{
+ URL: "http://localhost:1/nonexistent.mpd",
+ StreamKey: "test-bad",
+ }, distSrv, nil)
+
+ if err == nil {
+ t.Fatal("expected error for unreachable URL")
+ }
+}
+
+func TestPullerPullWithHTTP(t *testing.T) {
+ t.Parallel()
+
+ // Serve a dynamic MPD. The pull loop will try to fetch init segments,
+ // which will fail, causing the loop to exit. We verify the lifecycle:
+ // stream registered, then cleaned up after the loop exits.
+ requestCount := 0
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case r.URL.Path == "/manifest.mpd":
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(testMPDDynamic))
+ requestCount++
+ default:
+ // Init/media segments: return 404 so the loop exits.
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }))
+ defer srv.Close()
+
+ p := NewPuller(slog.Default())
+ distSrv := newMockDistributionServer()
+ mgr := &mockStreamManager{}
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ err := p.Pull(ctx, PullRequest{
+ URL: srv.URL + "/manifest.mpd",
+ StreamKey: "test-live",
+ }, distSrv, mgr)
+
+ if err != nil {
+ t.Fatalf("Pull returned error: %v", err)
+ }
+
+ // The pull should be registered immediately.
+ pulls := p.ActivePulls()
+ if len(pulls) != 1 {
+ t.Fatalf("expected 1 active pull, got %d", len(pulls))
+ }
+ if pulls[0].StreamKey != "test-live" {
+ t.Errorf("stream key = %q, want %q", pulls[0].StreamKey, "test-live")
+ }
+
+ // The background goroutine will fail on init segment fetch (404) and
+ // clean up. Wait briefly for that.
+ time.Sleep(500 * time.Millisecond)
+
+ // After loop exits, pull should be removed from active list.
+ pulls = p.ActivePulls()
+ if len(pulls) != 0 {
+ t.Errorf("expected 0 active pulls after loop exit, got %d", len(pulls))
+ }
+
+ // Distribution server should have been unregistered.
+ if len(distSrv.unregistered) == 0 {
+ t.Error("expected stream to be unregistered after loop exit")
+ }
+
+ // Stream manager should have been cleaned up.
+ if len(mgr.removed) == 0 {
+ t.Error("expected stream manager Remove to be called")
+ }
+}
+
+func TestPullerPullAndStop(t *testing.T) {
+ t.Parallel()
+
+ // Serve a dynamic MPD and valid-ish init segments that won't parse
+ // as fMP4 — the loop will exit on parse error. Instead, we'll use
+ // a server that delays responses so we can test Stop().
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch {
+ case r.URL.Path == "/manifest.mpd":
+ w.WriteHeader(http.StatusOK)
+ w.Write([]byte(testMPDDynamic))
+ default:
+ // Block until the request context is cancelled (simulating
+ // a slow CDN so we can test Stop cancellation).
+ <-r.Context().Done()
+ w.WriteHeader(http.StatusServiceUnavailable)
+ }
+ }))
+ defer srv.Close()
+
+ p := NewPuller(slog.Default())
+ distSrv := newMockDistributionServer()
+
+ ctx := context.Background()
+
+ err := p.Pull(ctx, PullRequest{
+ URL: srv.URL + "/manifest.mpd",
+ StreamKey: "stop-test",
+ }, distSrv, nil)
+ if err != nil {
+ t.Fatalf("Pull returned error: %v", err)
+ }
+
+ // Verify active.
+ if len(p.ActivePulls()) != 1 {
+ t.Fatalf("expected 1 active pull, got %d", len(p.ActivePulls()))
+ }
+
+ // Stop the pull.
+ err = p.Stop("stop-test")
+ if err != nil {
+ t.Fatalf("Stop returned error: %v", err)
+ }
+
+ // Wait for cleanup.
+ time.Sleep(500 * time.Millisecond)
+
+ if len(p.ActivePulls()) != 0 {
+ t.Errorf("expected 0 active pulls after Stop, got %d", len(p.ActivePulls()))
+ }
+}
diff --git a/ingest/dash/segment.go b/ingest/dash/segment.go
new file mode 100644
index 0000000..eadf175
--- /dev/null
+++ b/ingest/dash/segment.go
@@ -0,0 +1,232 @@
+package dash
+
+import (
+ "bytes"
+ "fmt"
+
+ "github.com/Eyevinn/mp4ff/mp4"
+ "github.com/zsiec/prism/demux"
+)
+
+// initSegmentInfo holds codec and track metadata extracted from an fMP4 init
+// segment (ftyp + moov).
+type initSegmentInfo struct {
+ VideoCodec string // "av1", "h264", etc.
+ Width int
+ Height int
+ SeqHeaderOBU []byte // AV1: sequence header OBU from av1C ConfigOBUs
+ AudioCodec string // e.g. "mp4a.40.2"
+ SampleRate int
+ Channels int
+ VideoTimescale uint32
+ AudioTimescale uint32
+ VideoTrackID uint32
+ AudioTrackID uint32
+ videoTrex *mp4.TrexBox // stored for GetFullSamples
+ audioTrex *mp4.TrexBox
+}
+
+// mediaSample is a decoded audio or video sample with timestamps in
+// microseconds.
+type mediaSample struct {
+ PTS int64 // microseconds
+ DTS int64 // microseconds
+ Data []byte
+ IsKeyframe bool
+}
+
+// parseInitSegment decodes an fMP4 init segment (ftyp+moov) and extracts
+// video/audio track metadata including codec configuration.
+func parseInitSegment(data []byte) (*initSegmentInfo, error) {
+ if len(data) == 0 {
+ return nil, fmt.Errorf("empty init segment data")
+ }
+
+ parsed, err := mp4.DecodeFile(bytes.NewReader(data))
+ if err != nil {
+ return nil, fmt.Errorf("decode init segment: %w", err)
+ }
+
+ // The moov box can be at parsed.Moov (when DecodeFile sees ftyp+moov)
+ // or at parsed.Init.Moov (when it's recognized as a fragmented init).
+ moov := parsed.Moov
+ if moov == nil && parsed.Init != nil {
+ moov = parsed.Init.Moov
+ }
+ if moov == nil {
+ return nil, fmt.Errorf("no moov box found in init segment")
+ }
+
+ if len(moov.Traks) == 0 {
+ return nil, fmt.Errorf("no tracks found in init segment")
+ }
+
+ info := &initSegmentInfo{}
+
+ for _, trak := range moov.Traks {
+ if trak.Mdia == nil || trak.Mdia.Hdlr == nil {
+ continue
+ }
+
+ trackID := trak.Tkhd.TrackID
+ handler := trak.Mdia.Hdlr.HandlerType
+
+ switch handler {
+ case "vide":
+ info.VideoTrackID = trackID
+ if trak.Mdia.Mdhd != nil {
+ info.VideoTimescale = trak.Mdia.Mdhd.Timescale
+ }
+ if err := extractVideoInfo(trak, info); err != nil {
+ return nil, fmt.Errorf("extract video info: %w", err)
+ }
+ case "soun":
+ info.AudioTrackID = trackID
+ if trak.Mdia.Mdhd != nil {
+ info.AudioTimescale = trak.Mdia.Mdhd.Timescale
+ }
+ if err := extractAudioInfo(trak, info); err != nil {
+ return nil, fmt.Errorf("extract audio info: %w", err)
+ }
+ }
+ }
+
+ // Resolve trex boxes from mvex.
+ if moov.Mvex != nil {
+ resolveTrex(moov.Mvex, info)
+ }
+
+ return info, nil
+}
+
+// extractVideoInfo populates video fields on info from the given video trak.
+func extractVideoInfo(trak *mp4.TrakBox, info *initSegmentInfo) error {
+ stsd := getStsd(trak)
+ if stsd == nil {
+ return fmt.Errorf("no stsd box in video track")
+ }
+
+ if stsd.Av01 != nil {
+ info.VideoCodec = "av1"
+ info.Width = int(stsd.Av01.Width)
+ info.Height = int(stsd.Av01.Height)
+ if stsd.Av01.Av1C != nil {
+ configOBUs := stsd.Av01.Av1C.ConfigOBUs
+ seqHdr := demux.FindSequenceHeaderOBU(configOBUs)
+ if seqHdr != nil {
+ info.SeqHeaderOBU = make([]byte, len(seqHdr))
+ copy(info.SeqHeaderOBU, seqHdr)
+ }
+ }
+ } else if stsd.AvcX != nil {
+ info.VideoCodec = "h264"
+ info.Width = int(stsd.AvcX.Width)
+ info.Height = int(stsd.AvcX.Height)
+ } else {
+ return fmt.Errorf("unsupported video codec (no av01 or avcX in stsd)")
+ }
+ return nil
+}
+
+// extractAudioInfo populates audio fields on info from the given audio trak.
+func extractAudioInfo(trak *mp4.TrakBox, info *initSegmentInfo) error {
+ stsd := getStsd(trak)
+ if stsd == nil {
+ return fmt.Errorf("no stsd box in audio track")
+ }
+
+ if stsd.Mp4a != nil {
+ info.AudioCodec = "mp4a.40.2"
+ info.SampleRate = int(stsd.Mp4a.SampleRate)
+ info.Channels = int(stsd.Mp4a.ChannelCount)
+ } else {
+ return fmt.Errorf("unsupported audio codec (no mp4a in stsd)")
+ }
+ return nil
+}
+
+// getStsd returns the stsd box from a track, or nil if the path is incomplete.
+func getStsd(trak *mp4.TrakBox) *mp4.StsdBox {
+ if trak.Mdia == nil || trak.Mdia.Minf == nil || trak.Mdia.Minf.Stbl == nil {
+ return nil
+ }
+ return trak.Mdia.Minf.Stbl.Stsd
+}
+
+// resolveTrex finds TrexBox entries matching video and audio track IDs.
+func resolveTrex(mvex *mp4.MvexBox, info *initSegmentInfo) {
+ // mp4ff stores a single trex in Trex and multiple in Trexs.
+ trexList := mvex.Trexs
+ if len(trexList) == 0 && mvex.Trex != nil {
+ trexList = []*mp4.TrexBox{mvex.Trex}
+ }
+ for _, trex := range trexList {
+ switch trex.TrackID {
+ case info.VideoTrackID:
+ info.videoTrex = trex
+ case info.AudioTrackID:
+ info.audioTrex = trex
+ }
+ }
+}
+
+// parseMediaSegment decodes an fMP4 media segment and extracts samples for the
+// specified track type. The init parameter must have been obtained from
+// parseInitSegment on the corresponding init segment.
+func parseMediaSegment(data []byte, init *initSegmentInfo, isVideo bool) ([]mediaSample, error) {
+ if len(data) == 0 {
+ return nil, fmt.Errorf("empty media segment data")
+ }
+ if init == nil {
+ return nil, fmt.Errorf("nil init segment info")
+ }
+
+ parsed, err := mp4.DecodeFile(bytes.NewReader(data))
+ if err != nil {
+ return nil, fmt.Errorf("decode media segment: %w", err)
+ }
+
+ var trex *mp4.TrexBox
+ var timescale uint32
+ if isVideo {
+ trex = init.videoTrex
+ timescale = init.VideoTimescale
+ } else {
+ trex = init.audioTrex
+ timescale = init.AudioTimescale
+ }
+
+ if timescale == 0 {
+ return nil, fmt.Errorf("timescale is zero")
+ }
+
+ var samples []mediaSample
+
+ for _, seg := range parsed.Segments {
+ for _, frag := range seg.Fragments {
+ fullSamples, err := frag.GetFullSamples(trex)
+ if err != nil {
+ return nil, fmt.Errorf("get full samples: %w", err)
+ }
+
+ for i := range fullSamples {
+ fs := &fullSamples[i]
+ ms := mediaSample{
+ DTS: scaleToMicroseconds(fs.DecodeTime, timescale),
+ PTS: scaleToMicroseconds(uint64(fs.PresentationTime()), timescale),
+ Data: fs.Data,
+ IsKeyframe: fs.IsSync(),
+ }
+ samples = append(samples, ms)
+ }
+ }
+ }
+
+ return samples, nil
+}
+
+// scaleToMicroseconds converts a timestamp in the given timescale to
+// microseconds.
+func scaleToMicroseconds(ts uint64, timescale uint32) int64 {
+ return int64(ts * 1_000_000 / uint64(timescale))
+}
diff --git a/ingest/dash/segment_test.go b/ingest/dash/segment_test.go
new file mode 100644
index 0000000..1c14305
--- /dev/null
+++ b/ingest/dash/segment_test.go
@@ -0,0 +1,396 @@
+package dash
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/Eyevinn/mp4ff/aac"
+ "github.com/Eyevinn/mp4ff/av1"
+ "github.com/Eyevinn/mp4ff/mp4"
+)
+
+func TestScaleToMicroseconds(t *testing.T) {
+ tests := []struct {
+ name string
+ ts uint64
+ timescale uint32
+ want int64
+ }{
+ {"zero", 0, 90000, 0},
+ {"one second at 90kHz", 90000, 90000, 1_000_000},
+ {"half second at 48kHz", 24000, 48000, 500_000},
+ {"two seconds at 90kHz", 180000, 90000, 2_000_000},
+ {"one frame at 30fps/90kHz", 3000, 90000, 33333},
+ {"one sample at 44100", 1, 44100, 22}, // 1_000_000/44100 = 22.67 -> 22
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := scaleToMicroseconds(tt.ts, tt.timescale)
+ if got != tt.want {
+ t.Errorf("scaleToMicroseconds(%d, %d) = %d, want %d",
+ tt.ts, tt.timescale, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestParseInitSegmentNil(t *testing.T) {
+ _, err := parseInitSegment(nil)
+ if err == nil {
+ t.Error("expected error for nil data")
+ }
+
+ _, err = parseInitSegment([]byte{})
+ if err == nil {
+ t.Error("expected error for empty data")
+ }
+}
+
+func TestParseInitSegmentInvalid(t *testing.T) {
+ _, err := parseInitSegment([]byte("not valid mp4 data at all"))
+ if err == nil {
+ t.Error("expected error for invalid data")
+ }
+}
+
+func TestParseMediaSegmentNil(t *testing.T) {
+ _, err := parseMediaSegment(nil, &initSegmentInfo{}, true)
+ if err == nil {
+ t.Error("expected error for nil data")
+ }
+
+ _, err = parseMediaSegment([]byte{}, &initSegmentInfo{}, true)
+ if err == nil {
+ t.Error("expected error for empty data")
+ }
+}
+
+func TestParseMediaSegmentNilInit(t *testing.T) {
+ _, err := parseMediaSegment([]byte{0x00}, nil, true)
+ if err == nil {
+ t.Error("expected error for nil init info")
+ }
+}
+
+// buildAV1InitSegment creates a minimal AV1+AAC init segment using mp4ff
+// builder APIs. Returns the encoded bytes.
+func buildAV1InitSegment(t *testing.T) []byte {
+ t.Helper()
+
+ init := mp4.CreateEmptyInit()
+
+ // Add video track (AV1, 1920x1080, 90kHz timescale)
+ videoTrak := init.AddEmptyTrack(90000, "video", "und")
+ videoTrak.Tkhd.Width = mp4.Fixed32(1920 << 16)
+ videoTrak.Tkhd.Height = mp4.Fixed32(1080 << 16)
+
+ // Create an Av1C box with a minimal AV1 codec config
+ av1C := &mp4.Av1CBox{
+ CodecConfRec: av1.CodecConfRec{
+ Version: 1,
+ SeqProfile: 0,
+ SeqLevelIdx0: 9,
+ SeqTier0: 0,
+ HighBitdepth: 0,
+ // Build a minimal sequence header OBU for ConfigOBUs.
+ // OBU type 1 (seq header), has_size=1:
+ // header byte: 0b0_0001_0_1_0 = 0x0A
+ // size: LEB128 for 4 bytes = 0x04
+ // payload: 4 bytes of minimal data
+ ConfigOBUs: buildMinimalSeqHeaderOBU(),
+ },
+ }
+
+ // Create AV1 visual sample entry and add to stsd
+ av01 := mp4.CreateVisualSampleEntryBox("av01", 1920, 1080, av1C)
+ videoTrak.Mdia.Minf.Stbl.Stsd.AddChild(av01)
+
+ // Add audio track (AAC, 48kHz, stereo)
+ audioTrak := init.AddEmptyTrack(48000, "audio", "und")
+ err := audioTrak.SetAACDescriptor(aac.AAClc, 48000)
+ if err != nil {
+ t.Fatalf("SetAACDescriptor: %v", err)
+ }
+
+ var buf bytes.Buffer
+ err = init.Encode(&buf)
+ if err != nil {
+ t.Fatalf("encode init segment: %v", err)
+ }
+ return buf.Bytes()
+}
+
+// buildMinimalSeqHeaderOBU returns a minimal AV1 sequence header OBU with
+// size field. This is not a fully valid sequence header but contains enough
+// structure for FindSequenceHeaderOBU to locate it.
+func buildMinimalSeqHeaderOBU() []byte {
+ // OBU header: type=1 (seq header), has_size=1 => 0x0A
+ // Size in LEB128: we'll use a small payload
+ // A minimal sequence_header_obu payload (not fully decodeable but
+ // structurally valid enough for FindSequenceHeaderOBU to find and return)
+ payload := []byte{
+ 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00,
+ }
+ obu := []byte{0x0A, byte(len(payload))}
+ obu = append(obu, payload...)
+ return obu
+}
+
+func TestParseInitSegmentAV1(t *testing.T) {
+ data := buildAV1InitSegment(t)
+
+ info, err := parseInitSegment(data)
+ if err != nil {
+ t.Fatalf("parseInitSegment: %v", err)
+ }
+
+ if info.VideoCodec != "av1" {
+ t.Errorf("VideoCodec = %q, want av1", info.VideoCodec)
+ }
+ if info.Width != 1920 {
+ t.Errorf("Width = %d, want 1920", info.Width)
+ }
+ if info.Height != 1080 {
+ t.Errorf("Height = %d, want 1080", info.Height)
+ }
+ if info.VideoTimescale != 90000 {
+ t.Errorf("VideoTimescale = %d, want 90000", info.VideoTimescale)
+ }
+ if info.VideoTrackID == 0 {
+ t.Error("VideoTrackID should not be zero")
+ }
+ if info.SeqHeaderOBU == nil {
+ t.Error("SeqHeaderOBU should not be nil")
+ }
+
+ // Audio track
+ if info.AudioCodec != "mp4a.40.2" {
+ t.Errorf("AudioCodec = %q, want mp4a.40.2", info.AudioCodec)
+ }
+ if info.SampleRate != 48000 {
+ t.Errorf("SampleRate = %d, want 48000", info.SampleRate)
+ }
+ if info.Channels != 2 {
+ t.Errorf("Channels = %d, want 2", info.Channels)
+ }
+ if info.AudioTimescale != 48000 {
+ t.Errorf("AudioTimescale = %d, want 48000", info.AudioTimescale)
+ }
+ if info.AudioTrackID == 0 {
+ t.Error("AudioTrackID should not be zero")
+ }
+
+ // Trex boxes should be resolved
+ if info.videoTrex == nil {
+ t.Error("videoTrex should not be nil")
+ }
+ if info.audioTrex == nil {
+ t.Error("audioTrex should not be nil")
+ }
+}
+
+// buildVideoMediaSegment creates a minimal fMP4 media segment with a single
+// video fragment containing 3 samples.
+func buildVideoMediaSegment(t *testing.T, trackID uint32) []byte {
+ t.Helper()
+
+ seg := mp4.NewMediaSegment()
+ frag, err := mp4.CreateFragment(1, trackID)
+ if err != nil {
+ t.Fatalf("CreateFragment: %v", err)
+ }
+
+ // Add 3 samples: first is a keyframe, others are not
+ sampleData := []byte{0x01, 0x02, 0x03, 0x04}
+ baseDecodeTime := uint64(0)
+
+ // Keyframe sample
+ frag.AddFullSample(mp4.FullSample{
+ Sample: mp4.Sample{
+ Flags: mp4.SyncSampleFlags,
+ Dur: 3000, // 1 frame at 30fps in 90kHz
+ Size: uint32(len(sampleData)),
+ CompositionTimeOffset: 0,
+ },
+ DecodeTime: baseDecodeTime,
+ Data: sampleData,
+ })
+
+ // Non-keyframe sample
+ frag.AddFullSample(mp4.FullSample{
+ Sample: mp4.Sample{
+ Flags: mp4.NonSyncSampleFlags,
+ Dur: 3000,
+ Size: uint32(len(sampleData)),
+ CompositionTimeOffset: 3000, // PTS > DTS
+ },
+ DecodeTime: baseDecodeTime + 3000,
+ Data: sampleData,
+ })
+
+ // Another non-keyframe sample
+ frag.AddFullSample(mp4.FullSample{
+ Sample: mp4.Sample{
+ Flags: mp4.NonSyncSampleFlags,
+ Dur: 3000,
+ Size: uint32(len(sampleData)),
+ CompositionTimeOffset: 0,
+ },
+ DecodeTime: baseDecodeTime + 6000,
+ Data: sampleData,
+ })
+
+ seg.AddFragment(frag)
+
+ var buf bytes.Buffer
+ err = seg.Encode(&buf)
+ if err != nil {
+ t.Fatalf("encode media segment: %v", err)
+ }
+ return buf.Bytes()
+}
+
+func TestParseMediaSegmentVideo(t *testing.T) {
+ // First build an init segment to get metadata
+ initData := buildAV1InitSegment(t)
+ info, err := parseInitSegment(initData)
+ if err != nil {
+ t.Fatalf("parseInitSegment: %v", err)
+ }
+
+ // Build a video media segment
+ segData := buildVideoMediaSegment(t, info.VideoTrackID)
+
+ samples, err := parseMediaSegment(segData, info, true)
+ if err != nil {
+ t.Fatalf("parseMediaSegment: %v", err)
+ }
+
+ if len(samples) != 3 {
+ t.Fatalf("sample count = %d, want 3", len(samples))
+ }
+
+ // First sample: keyframe at DTS=0
+ if !samples[0].IsKeyframe {
+ t.Error("sample[0] should be a keyframe")
+ }
+ if samples[0].DTS != 0 {
+ t.Errorf("sample[0].DTS = %d, want 0", samples[0].DTS)
+ }
+ if samples[0].PTS != 0 {
+ t.Errorf("sample[0].PTS = %d, want 0", samples[0].PTS)
+ }
+ if len(samples[0].Data) != 4 {
+ t.Errorf("sample[0].Data length = %d, want 4", len(samples[0].Data))
+ }
+
+ // Second sample: non-keyframe at DTS=3000/90000 * 1e6 = 33333us
+ if samples[1].IsKeyframe {
+ t.Error("sample[1] should not be a keyframe")
+ }
+ expectedDTS := scaleToMicroseconds(3000, 90000)
+ if samples[1].DTS != expectedDTS {
+ t.Errorf("sample[1].DTS = %d, want %d", samples[1].DTS, expectedDTS)
+ }
+ // PTS = (3000 + 3000) / 90000 * 1e6 = 66666us
+ expectedPTS := scaleToMicroseconds(6000, 90000)
+ if samples[1].PTS != expectedPTS {
+ t.Errorf("sample[1].PTS = %d, want %d", samples[1].PTS, expectedPTS)
+ }
+
+ // Third sample: non-keyframe at DTS=6000/90000 * 1e6 = 66666us
+ expectedDTS3 := scaleToMicroseconds(6000, 90000)
+ if samples[2].DTS != expectedDTS3 {
+ t.Errorf("sample[2].DTS = %d, want %d", samples[2].DTS, expectedDTS3)
+ }
+}
+
+func TestParseMediaSegmentZeroTimescale(t *testing.T) {
+ initInfo := &initSegmentInfo{
+ VideoTimescale: 0,
+ videoTrex: &mp4.TrexBox{TrackID: 1},
+ }
+ _, err := parseMediaSegment([]byte{0x00}, initInfo, true)
+ if err == nil {
+ t.Error("expected error for zero timescale")
+ }
+}
+
+// buildAudioMediaSegment creates a minimal fMP4 media segment with audio
+// samples.
+func buildAudioMediaSegment(t *testing.T, trackID uint32) []byte {
+ t.Helper()
+
+ seg := mp4.NewMediaSegment()
+ frag, err := mp4.CreateFragment(1, trackID)
+ if err != nil {
+ t.Fatalf("CreateFragment: %v", err)
+ }
+
+ // Add 2 audio samples (AAC frames are all sync)
+ sampleData := make([]byte, 128)
+ for i := range sampleData {
+ sampleData[i] = byte(i)
+ }
+
+ frag.AddFullSample(mp4.FullSample{
+ Sample: mp4.Sample{
+ Flags: mp4.SyncSampleFlags,
+ Dur: 1024, // typical AAC frame duration at 48kHz
+ Size: uint32(len(sampleData)),
+ },
+ DecodeTime: 0,
+ Data: sampleData,
+ })
+
+ frag.AddFullSample(mp4.FullSample{
+ Sample: mp4.Sample{
+ Flags: mp4.SyncSampleFlags,
+ Dur: 1024,
+ Size: uint32(len(sampleData)),
+ },
+ DecodeTime: 1024,
+ Data: sampleData,
+ })
+
+ seg.AddFragment(frag)
+
+ var buf bytes.Buffer
+ err = seg.Encode(&buf)
+ if err != nil {
+ t.Fatalf("encode media segment: %v", err)
+ }
+ return buf.Bytes()
+}
+
+func TestParseMediaSegmentAudio(t *testing.T) {
+ initData := buildAV1InitSegment(t)
+ info, err := parseInitSegment(initData)
+ if err != nil {
+ t.Fatalf("parseInitSegment: %v", err)
+ }
+
+ segData := buildAudioMediaSegment(t, info.AudioTrackID)
+
+ samples, err := parseMediaSegment(segData, info, false)
+ if err != nil {
+ t.Fatalf("parseMediaSegment: %v", err)
+ }
+
+ if len(samples) != 2 {
+ t.Fatalf("sample count = %d, want 2", len(samples))
+ }
+
+ // First audio sample: DTS=0
+ if samples[0].DTS != 0 {
+ t.Errorf("sample[0].DTS = %d, want 0", samples[0].DTS)
+ }
+
+ // Second audio sample: DTS = 1024/48000 * 1e6 = 21333us
+ expectedDTS := scaleToMicroseconds(1024, 48000)
+ if samples[1].DTS != expectedDTS {
+ t.Errorf("sample[1].DTS = %d, want %d", samples[1].DTS, expectedDTS)
+ }
+}
diff --git a/moq/format.go b/moq/format.go
index 2a55d1c..0c387a9 100644
--- a/moq/format.go
+++ b/moq/format.go
@@ -86,6 +86,60 @@ func BuildAVCDecoderConfig(sps, pps []byte) []byte {
return buf
}
+// BuildAV1DecoderConfig builds an AV1CodecConfigurationRecord
+// (https://aomediacodec.github.io/av1-isobmff/#av1codecconfigurationrecord)
+// from a raw Sequence Header OBU (including OBU header bytes).
+// The sequence header OBU is appended as configOBUs.
+// Returns nil if the input is nil, too short, or fails to parse.
+func BuildAV1DecoderConfig(seqHeaderOBU []byte) []byte {
+ if len(seqHeaderOBU) < 2 {
+ return nil
+ }
+
+ hdr, err := demux.ParseAV1SequenceHeader(seqHeaderOBU)
+ if err != nil {
+ return nil
+ }
+
+ buf := make([]byte, 0, 4+len(seqHeaderOBU))
+
+ // Byte 0: marker(1)=1 | version(7)=1
+ buf = append(buf, 0x81)
+
+ // Byte 1: seq_profile(3) | seq_level_idx_0(5)
+ buf = append(buf, byte(hdr.SeqProfile&0x07)<<5|byte(hdr.SeqLevelIdx&0x1F))
+
+ // Byte 2: seq_tier_0(1) | high_bitdepth(1) | twelve_bit(1) | monochrome(1) |
+ // chroma_subsampling_x(1) | chroma_subsampling_y(1) | chroma_sample_position(2)=0
+ var highBitDepth, twelveBit byte
+ if hdr.BitDepth >= 10 {
+ highBitDepth = 1
+ }
+ if hdr.BitDepth == 12 {
+ twelveBit = 1
+ }
+ var mono byte
+ if hdr.Monochrome {
+ mono = 1
+ }
+ b2 := byte(hdr.SeqTier&0x01)<<7 |
+ highBitDepth<<6 |
+ twelveBit<<5 |
+ mono<<4 |
+ byte(hdr.ChromaSubsamplingX&0x01)<<3 |
+ byte(hdr.ChromaSubsamplingY&0x01)<<2
+ // chroma_sample_position (2 bits) = 0
+ buf = append(buf, b2)
+
+ // Byte 3: reserved(3)=0 | initial_presentation_delay_present(1)=0 | reserved(4)=0
+ buf = append(buf, 0x00)
+
+ // configOBUs: the sequence header OBU itself
+ buf = append(buf, seqHeaderOBU...)
+
+ return buf
+}
+
// BuildHEVCDecoderConfig builds an HEVCDecoderConfigurationRecord
// (ISO 14496-15 §8.3.3.1.2) from raw VPS, SPS, and PPS NAL data
// (without start codes). The SPS must include the 2-byte NAL header.
diff --git a/moq/format_test.go b/moq/format_test.go
index c9f26ad..f79f3e8 100644
--- a/moq/format_test.go
+++ b/moq/format_test.go
@@ -279,6 +279,127 @@ func TestBuildHEVCDecoderConfig(t *testing.T) {
}
}
+func TestBuildAV1DecoderConfig(t *testing.T) {
+ t.Parallel()
+
+ // 1920x1080, profile 0, level 8, 8-bit, 4:2:0 (from demux/av1_test.go)
+ seqHeaderOBU := []byte{
+ 0x0a, 0x0b, // OBU header (type=1, has_size=1) + LEB128 size=11
+ 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71,
+ 0x2b, 0xe4, 0x01,
+ }
+
+ config := BuildAV1DecoderConfig(seqHeaderOBU)
+ if config == nil {
+ t.Fatal("expected non-nil config")
+ }
+
+ // Total length = 4-byte header + len(seqHeaderOBU)
+ expectedLen := 4 + len(seqHeaderOBU)
+ if len(config) != expectedLen {
+ t.Fatalf("total length: got %d, want %d", len(config), expectedLen)
+ }
+
+ // Byte 0: marker(1)=1 | version(7)=1 → 0x81
+ if config[0] != 0x81 {
+ t.Errorf("byte 0 (marker+version): got 0x%02x, want 0x81", config[0])
+ }
+
+ // Byte 1: seq_profile(3)=0 | seq_level_idx_0(5)=8 → 0x08
+ if config[1] != 0x08 {
+ t.Errorf("byte 1 (profile|level): got 0x%02x, want 0x08", config[1])
+ }
+
+ // Verify profile extracted correctly (top 3 bits)
+ profile := config[1] >> 5
+ if profile != 0 {
+ t.Errorf("profile: got %d, want 0", profile)
+ }
+
+ // Verify level extracted correctly (bottom 5 bits)
+ level := config[1] & 0x1F
+ if level != 8 {
+ t.Errorf("level: got %d, want 8", level)
+ }
+
+ // Byte 2: tier=0, high_bitdepth=0, twelve_bit=0, mono=0, csx=1, csy=1, csp=0 → 0x0C
+ if config[2] != 0x0C {
+ t.Errorf("byte 2 (tier/bd/chroma): got 0x%02x, want 0x0C", config[2])
+ }
+
+ // Byte 3: reserved = 0x00
+ if config[3] != 0x00 {
+ t.Errorf("byte 3 (reserved): got 0x%02x, want 0x00", config[3])
+ }
+
+ // configOBUs (bytes 4+) must match input seqHeaderOBU
+ if !bytes.Equal(config[4:], seqHeaderOBU) {
+ t.Error("configOBUs do not match input seqHeaderOBU")
+ }
+}
+
+func TestBuildAV1DecoderConfig10Bit(t *testing.T) {
+ t.Parallel()
+
+ // 640x480, profile 0, level 4, 10-bit, 4:2:0 (from demux/av1_test.go)
+ seqHeaderOBU := []byte{
+ 0x0a, 0x0b, // OBU header (type=1, has_size=1) + LEB128 size=11
+ 0x00, 0x00, 0x00, 0x24, 0xc4, 0xff, 0xdf, 0x12,
+ 0xbe, 0x50, 0x10,
+ }
+
+ config := BuildAV1DecoderConfig(seqHeaderOBU)
+ if config == nil {
+ t.Fatal("expected non-nil config")
+ }
+
+ // Total length = 4-byte header + len(seqHeaderOBU)
+ expectedLen := 4 + len(seqHeaderOBU)
+ if len(config) != expectedLen {
+ t.Fatalf("total length: got %d, want %d", len(config), expectedLen)
+ }
+
+ // Byte 0: marker + version
+ if config[0] != 0x81 {
+ t.Errorf("byte 0: got 0x%02x, want 0x81", config[0])
+ }
+
+ // Byte 1: profile=0 | level=4 → 0x04
+ if config[1] != 0x04 {
+ t.Errorf("byte 1 (profile|level): got 0x%02x, want 0x04", config[1])
+ }
+
+ // Byte 2: tier=0, high_bitdepth=1, twelve_bit=0, mono=0, csx=1, csy=1, csp=0
+ // = 0<<7 | 1<<6 | 0<<5 | 0<<4 | 1<<3 | 1<<2 | 0 = 0x4C
+ if config[2] != 0x4C {
+ t.Errorf("byte 2 (tier/bd/chroma): got 0x%02x, want 0x4C", config[2])
+ }
+
+ // configOBUs (bytes 4+) must match input
+ if !bytes.Equal(config[4:], seqHeaderOBU) {
+ t.Error("configOBUs do not match input seqHeaderOBU")
+ }
+}
+
+func TestBuildAV1DecoderConfigNil(t *testing.T) {
+ t.Parallel()
+
+ if BuildAV1DecoderConfig(nil) != nil {
+ t.Error("expected nil for nil input")
+ }
+ if BuildAV1DecoderConfig([]byte{}) != nil {
+ t.Error("expected nil for empty input")
+ }
+ if BuildAV1DecoderConfig([]byte{0x0a}) != nil {
+ t.Error("expected nil for truncated input")
+ }
+
+ // Wrong OBU type (OBU_FRAME instead of SEQUENCE_HEADER)
+ if BuildAV1DecoderConfig([]byte{0x32, 0x01, 0x00}) != nil {
+ t.Error("expected nil for wrong OBU type")
+ }
+}
+
func TestBuildHEVCDecoderConfigNil(t *testing.T) {
t.Parallel()
if BuildHEVCDecoderConfig(nil, []byte{0x42, 0x01, 0x01, 0x01}, []byte{0x44}) != nil {
diff --git a/pipeline/pipeline.go b/pipeline/pipeline.go
index 3dd88ba..89befbe 100644
--- a/pipeline/pipeline.go
+++ b/pipeline/pipeline.go
@@ -228,7 +228,19 @@ func (p *Pipeline) forwardVideo(frame *media.VideoFrame) {
// including decoder configuration record for the catalog.
func (p *Pipeline) buildVideoInfo(frame *media.VideoFrame) (distribution.VideoInfo, bool) {
var vi distribution.VideoInfo
- if frame.Codec == "h265" {
+ switch frame.Codec {
+ case "av1":
+ info, err := demux.ParseAV1SequenceHeader(frame.SPS)
+ if err != nil {
+ return vi, false
+ }
+ vi = distribution.VideoInfo{
+ Codec: info.CodecString(),
+ Width: info.Width,
+ Height: info.Height,
+ }
+ vi.DecoderConfig = moq.BuildAV1DecoderConfig(frame.SPS)
+ case "h265":
info, err := demux.ParseHEVCSPS(frame.SPS)
if err != nil {
return vi, false
@@ -241,7 +253,7 @@ func (p *Pipeline) buildVideoInfo(frame *media.VideoFrame) (distribution.VideoIn
if frame.VPS != nil {
vi.DecoderConfig = moq.BuildHEVCDecoderConfig(frame.VPS, frame.SPS, frame.PPS)
}
- } else {
+ default:
info, err := demux.ParseSPS(frame.SPS)
if err != nil {
return vi, false
diff --git a/pipeline/pipeline_test.go b/pipeline/pipeline_test.go
index 719e33f..61f393c 100644
--- a/pipeline/pipeline_test.go
+++ b/pipeline/pipeline_test.go
@@ -6,6 +6,7 @@ import (
"testing"
"github.com/zsiec/prism/distribution"
+ "github.com/zsiec/prism/media"
)
func TestNew(t *testing.T) {
@@ -71,3 +72,51 @@ func TestDemuxStats(t *testing.T) {
t.Fatal("expected non-nil DemuxStats")
}
}
+
+func TestBuildVideoInfoAV1(t *testing.T) {
+ t.Parallel()
+
+ relay := distribution.NewRelay()
+ p := New("test-stream", strings.NewReader(""), relay)
+
+ // AV1 sequence header OBU for 1920x1080, profile 0, level 8, 8-bit
+ // (same test data as demux/av1_test.go testSeqHdrOBU1080p)
+ seqHdrOBU := []byte{
+ 0x0a, 0x0b,
+ 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71,
+ 0x2b, 0xe4, 0x01,
+ }
+
+ frame := &media.VideoFrame{
+ PTS: 0,
+ IsKeyframe: true,
+ Codec: "av1",
+ SPS: seqHdrOBU,
+ WireData: []byte{0x32, 0x10, 0xAA, 0xBB},
+ }
+
+ vi, ok := p.buildVideoInfo(frame)
+ if !ok {
+ t.Fatal("buildVideoInfo returned false")
+ }
+
+ if vi.Width != 1920 {
+ t.Errorf("Width: got %d, want 1920", vi.Width)
+ }
+ if vi.Height != 1080 {
+ t.Errorf("Height: got %d, want 1080", vi.Height)
+ }
+
+ wantPrefix := "av01."
+ if len(vi.Codec) < len(wantPrefix) || vi.Codec[:len(wantPrefix)] != wantPrefix {
+ t.Errorf("Codec: got %q, want prefix %q", vi.Codec, wantPrefix)
+ }
+
+ if vi.DecoderConfig == nil {
+ t.Fatal("DecoderConfig is nil")
+ }
+ // AV1CodecConfigurationRecord starts with marker|version = 0x81
+ if vi.DecoderConfig[0] != 0x81 {
+ t.Errorf("DecoderConfig[0]: got 0x%02x, want 0x81", vi.DecoderConfig[0])
+ }
+}