diff --git a/.gitignore b/.gitignore index 0710f85..1e825cf 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,8 @@ test/harness/* *.swp *~ +docs/plans/* + # Environment .env .env.* diff --git a/cmd/prism/main.go b/cmd/prism/main.go index e570090..1591501 100644 --- a/cmd/prism/main.go +++ b/cmd/prism/main.go @@ -19,6 +19,7 @@ import ( "github.com/zsiec/prism/certs" "github.com/zsiec/prism/distribution" "github.com/zsiec/prism/ingest" + dashingest "github.com/zsiec/prism/ingest/dash" srtingest "github.com/zsiec/prism/ingest/srt" "github.com/zsiec/prism/pipeline" "github.com/zsiec/prism/stream" @@ -80,6 +81,7 @@ func main() { a.handleNewStream(ctx, key, input, format) }) a.srtCaller = srtingest.NewCaller(a.registry, nil) + a.dashPuller = dashingest.NewPuller(nil) var distErr error a.distSrv, distErr = distribution.NewServer(distribution.ServerConfig{ @@ -96,7 +98,19 @@ func main() { SRTStop: func(streamKey string) error { return a.srtCaller.Stop(streamKey) }, - SRTList: a.listSRTPulls, + SRTList: a.listSRTPulls, + DASHPull: func(url, streamKey, videoRepID, audioRepID string) error { + return a.dashPuller.Pull(ctx, dashingest.PullRequest{ + URL: url, + StreamKey: streamKey, + VideoRepID: videoRepID, + AudioRepID: audioRepID, + }, a.distSrv, &streamManagerAdapter{a.mgr}) + }, + DASHStop: func(streamKey string) error { + return a.dashPuller.Stop(streamKey) + }, + DASHList: a.listDASHPulls, StreamLister: a.listStreams, IngestLookup: a.lookupIngest, }) @@ -138,6 +152,18 @@ func main() { return a.distSrv.Start(ctx) }) + if dashURL := os.Getenv("DASH_SOURCE"); dashURL != "" { + g.Go(func() error { + <-time.After(500 * time.Millisecond) + return a.dashPuller.Pull(ctx, dashingest.PullRequest{ + URL: dashURL, + StreamKey: envOr("DASH_STREAM_KEY", "dash-live"), + VideoRepID: os.Getenv("DASH_VIDEO_REP"), + AudioRepID: os.Getenv("DASH_AUDIO_REP"), + }, a.distSrv, &streamManagerAdapter{a.mgr}) + }) + } + if err := g.Wait(); err != nil { slog.Error("server error", "error", err) os.Exit(1) @@ -145,10 +171,11 @@ func main() { } type app struct { - mgr *stream.Manager - registry *ingest.Registry - srtCaller *srtingest.Caller - distSrv *distribution.Server + mgr *stream.Manager + registry *ingest.Registry + srtCaller *srtingest.Caller + dashPuller *dashingest.Puller + distSrv *distribution.Server } func (a *app) listSRTPulls() []distribution.SRTPullInfo { @@ -164,6 +191,20 @@ func (a *app) listSRTPulls() []distribution.SRTPullInfo { return out } +func (a *app) listDASHPulls() []distribution.DASHPullInfo { + pulls := a.dashPuller.ActivePulls() + out := make([]distribution.DASHPullInfo, len(pulls)) + for i, p := range pulls { + out[i] = distribution.DASHPullInfo{ + URL: p.URL, + StreamKey: p.StreamKey, + VideoRepID: p.VideoRepID, + AudioRepID: p.AudioRepID, + } + } + return out +} + func (a *app) listStreams() []distribution.StreamInfo { streams := a.mgr.List() infos := make([]distribution.StreamInfo, len(streams)) @@ -281,3 +322,18 @@ func buildStreamDescription(info distribution.StreamInfo) string { return strings.Join(parts, " · ") } + +// streamManagerAdapter adapts *stream.Manager to the dash.StreamManager +// interface, which uses interface{} in the Create return to avoid an import +// cycle with the stream package. +type streamManagerAdapter struct { + mgr *stream.Manager +} + +func (a *streamManagerAdapter) Create(key string) (interface{}, bool) { + return a.mgr.Create(key) +} + +func (a *streamManagerAdapter) Remove(key string) { + a.mgr.Remove(key) +} diff --git a/demux/av1.go b/demux/av1.go new file mode 100644 index 0000000..7374e13 --- /dev/null +++ b/demux/av1.go @@ -0,0 +1,518 @@ +package demux + +import ( + "errors" + "fmt" +) + +// AV1 OBU type constants as defined in AV1 spec §6.2.2. +const ( + OBUSequenceHeader = 1 + OBUTemporalDelimiter = 2 + OBUFrameHeader = 3 + OBUFrame = 6 +) + +// AV1 frame type constants as defined in AV1 spec §6.8.2. +const ( + av1FrameKey = 0 + av1FrameInter = 1 + av1FrameIntraOnly = 2 + av1FrameSwitch = 3 +) + +var ( + errAV1TooShort = errors.New("AV1 data too short") + errAV1InvalidOBU = errors.New("invalid AV1 OBU header") + errAV1LEB128 = errors.New("invalid LEB128 encoding") + errAV1SeqHdrParse = errors.New("failed to parse AV1 sequence header") + errAV1ForbiddenBit = errors.New("AV1 OBU forbidden bit set") +) + +// AV1SequenceHeader holds parameters extracted from an AV1 Sequence Header OBU. +type AV1SequenceHeader struct { + SeqProfile int + SeqLevelIdx int + SeqTier int // 0=Main, 1=High + BitDepth int + Width int + Height int + ChromaSubsamplingX int + ChromaSubsamplingY int + Monochrome bool +} + +// CodecString returns the RFC 6381 codec parameter string for AV1. +// Format: "av01.P.LLT.DD" where P=profile, LL=zero-padded level, +// T=tier(M/H), DD=zero-padded bit depth. +func (s AV1SequenceHeader) CodecString() string { + tier := "M" + if s.SeqTier == 1 { + tier = "H" + } + return fmt.Sprintf("av01.%d.%02d%s.%02d", s.SeqProfile, s.SeqLevelIdx, tier, s.BitDepth) +} + +// obuHeader represents a parsed AV1 OBU header. +type obuHeader struct { + obuType int + hasExtension bool + hasSizeField bool + temporalID int + spatialID int + headerLen int // total header bytes (1 or 2) + payloadSize int // only valid when hasSizeField is true + totalSize int // headerLen + leb128 size bytes + payloadSize +} + +// parseOBUHeader parses an OBU header from the given data. Returns the +// parsed header and total bytes consumed (header + optional size field). +func parseOBUHeader(data []byte) (obuHeader, error) { + if len(data) < 1 { + return obuHeader{}, errAV1TooShort + } + + b := data[0] + if b&0x80 != 0 { + return obuHeader{}, errAV1ForbiddenBit + } + + h := obuHeader{ + obuType: int((b >> 3) & 0x0F), + hasExtension: b&0x04 != 0, + hasSizeField: b&0x02 != 0, + headerLen: 1, + } + + pos := 1 + if h.hasExtension { + if len(data) < 2 { + return obuHeader{}, errAV1TooShort + } + ext := data[1] + h.temporalID = int((ext >> 5) & 0x07) + h.spatialID = int((ext >> 3) & 0x03) + h.headerLen = 2 + pos = 2 + } + + if h.hasSizeField { + size, n, err := readLEB128(data[pos:]) + if err != nil { + return obuHeader{}, err + } + h.payloadSize = size + h.totalSize = pos + n + size + } else { + // Without size field, OBU extends to end of temporal unit. + h.payloadSize = len(data) - pos + h.totalSize = len(data) + } + + return h, nil +} + +// readLEB128 reads a LEB128 encoded unsigned integer from data. +// Returns the value, number of bytes consumed, and any error. +func readLEB128(data []byte) (int, int, error) { + val := 0 + for i := 0; i < len(data) && i < 8; i++ { + b := data[i] + val |= int(b&0x7F) << (i * 7) + if b&0x80 == 0 { + return val, i + 1, nil + } + } + return 0, 0, errAV1LEB128 +} + +// av1BitReader is a MSB-first bit reader that tracks errors internally. +// After an overflow, all subsequent reads return zero and the error is +// preserved in the err field. +type av1BitReader struct { + data []byte + pos int + bit int + err error +} + +func newAV1BitReader(data []byte) *av1BitReader { + return &av1BitReader{data: data} +} + +func (br *av1BitReader) readBits(n int) uint { + if br.err != nil { + return 0 + } + var val uint + for i := 0; i < n; i++ { + if br.pos >= len(br.data) { + br.err = errAV1TooShort + return 0 + } + val = (val << 1) | uint((br.data[br.pos]>>(7-br.bit))&1) + br.bit++ + if br.bit == 8 { + br.bit = 0 + br.pos++ + } + } + return val +} + +// readUvlc reads an unsigned variable-length code (AV1 spec §4.10.3). +func (br *av1BitReader) readUvlc() uint { + if br.err != nil { + return 0 + } + leadingZeros := 0 + for { + b := br.readBits(1) + if br.err != nil { + return 0 + } + if b == 1 { + break + } + leadingZeros++ + if leadingZeros >= 32 { + br.err = errAV1SeqHdrParse + return 0 + } + } + if leadingZeros == 0 { + return 0 + } + val := br.readBits(leadingZeros) + return val + (1 << leadingZeros) - 1 +} + +// ParseAV1SequenceHeader parses an AV1 Sequence Header OBU (including OBU +// header bytes) and extracts codec parameters. The input must start with +// the OBU header byte. +func ParseAV1SequenceHeader(data []byte) (*AV1SequenceHeader, error) { + if len(data) < 2 { + return nil, errAV1TooShort + } + + h, err := parseOBUHeader(data) + if err != nil { + return nil, err + } + + if h.obuType != OBUSequenceHeader { + return nil, fmt.Errorf("expected OBU_SEQUENCE_HEADER (type %d), got type %d", OBUSequenceHeader, h.obuType) + } + + // Payload starts after header + LEB128 size + payloadStart := h.totalSize - h.payloadSize + if payloadStart >= len(data) || payloadStart+h.payloadSize > len(data) { + return nil, errAV1TooShort + } + payload := data[payloadStart : payloadStart+h.payloadSize] + + return parseSequenceHeaderPayload(payload) +} + +// parseSequenceHeaderPayload parses the raw payload of a sequence_header_obu. +func parseSequenceHeaderPayload(payload []byte) (*AV1SequenceHeader, error) { + br := newAV1BitReader(payload) + hdr := &AV1SequenceHeader{} + + hdr.SeqProfile = int(br.readBits(3)) + stillPicture := br.readBits(1) + reducedStillPicture := br.readBits(1) + _ = stillPicture + + var decoderModelInfoPresent bool + var bufferDelayLengthMinus1 uint + + if reducedStillPicture == 1 { + hdr.SeqLevelIdx = int(br.readBits(5)) + hdr.SeqTier = 0 + } else { + timingInfoPresent := br.readBits(1) + + if timingInfoPresent == 1 { + // timing_info() + br.readBits(32) // num_units_in_display_tick + br.readBits(32) // time_scale + equalPictureInterval := br.readBits(1) + if equalPictureInterval == 1 { + br.readUvlc() // num_ticks_per_picture_minus_1 + } + + dmi := br.readBits(1) + decoderModelInfoPresent = dmi == 1 + if decoderModelInfoPresent { + // decoder_model_info() + bufferDelayLengthMinus1 = br.readBits(5) + br.readBits(32) // num_units_in_decoding_tick + br.readBits(5) // buffer_removal_time_length_minus_1 + br.readBits(5) // frame_presentation_time_length_minus_1 + } + } + + initialDisplayDelayPresent := br.readBits(1) + operatingPointsCntMinus1 := br.readBits(5) + + for i := uint(0); i <= operatingPointsCntMinus1; i++ { + br.readBits(12) // operating_point_idc + levelIdx := br.readBits(5) + var tier uint + if levelIdx > 7 { + tier = br.readBits(1) + } + if i == 0 { + hdr.SeqLevelIdx = int(levelIdx) + hdr.SeqTier = int(tier) + } + if decoderModelInfoPresent { + decoderModelPresentForThisOp := br.readBits(1) + if decoderModelPresentForThisOp == 1 { + // operating_parameters_info(): skip decoder/encoder buffer delays and flag + n := int(bufferDelayLengthMinus1) + 1 + br.readBits(n) // decoder_buffer_delay[i] + br.readBits(n) // encoder_buffer_delay[i] + br.readBits(1) // low_delay_mode_flag[i] + } + } + if initialDisplayDelayPresent == 1 { + flag := br.readBits(1) + if flag == 1 { + br.readBits(4) // initial_display_delay_minus_1 + } + } + } + } + + if br.err != nil { + return nil, fmt.Errorf("%w: %v", errAV1SeqHdrParse, br.err) + } + + frameWidthBitsMinus1 := br.readBits(4) + frameHeightBitsMinus1 := br.readBits(4) + maxFrameWidthMinus1 := br.readBits(int(frameWidthBitsMinus1) + 1) + maxFrameHeightMinus1 := br.readBits(int(frameHeightBitsMinus1) + 1) + hdr.Width = int(maxFrameWidthMinus1) + 1 + hdr.Height = int(maxFrameHeightMinus1) + 1 + + if reducedStillPicture == 0 { + frameIDNumbersPresent := br.readBits(1) + if frameIDNumbersPresent == 1 { + br.readBits(4) // delta_frame_id_length_minus_2 + br.readBits(3) // additional_frame_id_length_minus_1 + } + } + + br.readBits(1) // use_128x128_superblock + br.readBits(1) // enable_filter_intra + br.readBits(1) // enable_intra_edge_filter + + if reducedStillPicture == 0 { + br.readBits(1) // enable_interintra_compound + br.readBits(1) // enable_masked_compound + br.readBits(1) // enable_warped_motion + br.readBits(1) // enable_dual_filter + enableOrderHint := br.readBits(1) + + if enableOrderHint == 1 { + br.readBits(1) // enable_jnt_comp + br.readBits(1) // enable_ref_frame_mvs + } + + seqChooseScreenContentTools := br.readBits(1) + var seqForceScreenContentTools uint + if seqChooseScreenContentTools == 0 { + seqForceScreenContentTools = br.readBits(1) + } else { + seqForceScreenContentTools = 2 // SELECT_SCREEN_CONTENT_TOOLS + } + + if seqForceScreenContentTools > 0 { + seqChooseIntegerMV := br.readBits(1) + if seqChooseIntegerMV == 0 { + br.readBits(1) // seq_force_integer_mv + } + } + + if enableOrderHint == 1 { + br.readBits(3) // order_hint_bits_minus_1 + } + } + + br.readBits(1) // enable_superres + br.readBits(1) // enable_cdef + br.readBits(1) // enable_restoration + + if br.err != nil { + return nil, fmt.Errorf("%w: %v", errAV1SeqHdrParse, br.err) + } + + // color_config() + parseColorConfig(br, hdr) + + if br.err != nil { + return nil, fmt.Errorf("%w: %v", errAV1SeqHdrParse, br.err) + } + + return hdr, nil +} + +// parseColorConfig parses the color_config() section of a sequence header. +func parseColorConfig(br *av1BitReader, hdr *AV1SequenceHeader) { + highBitDepth := br.readBits(1) + + if hdr.SeqProfile == 2 && highBitDepth == 1 { + twelveBit := br.readBits(1) + if twelveBit == 1 { + hdr.BitDepth = 12 + } else { + hdr.BitDepth = 10 + } + } else if highBitDepth == 1 { + hdr.BitDepth = 10 + } else { + hdr.BitDepth = 8 + } + + var monoChrome uint + if hdr.SeqProfile == 1 { + monoChrome = 0 + } else { + monoChrome = br.readBits(1) + } + hdr.Monochrome = monoChrome == 1 + + colorDescriptionPresent := br.readBits(1) + var colorPrimaries, transferCharacteristics, matrixCoefficients uint + if colorDescriptionPresent == 1 { + colorPrimaries = br.readBits(8) + transferCharacteristics = br.readBits(8) + matrixCoefficients = br.readBits(8) + } else { + colorPrimaries = 2 // CP_UNSPECIFIED + transferCharacteristics = 2 // TC_UNSPECIFIED + matrixCoefficients = 2 // MC_UNSPECIFIED + } + + if monoChrome == 1 { + br.readBits(1) // color_range + hdr.ChromaSubsamplingX = 1 + hdr.ChromaSubsamplingY = 1 + return + } + + // Check for sRGB/sYCC + if colorPrimaries == 1 && transferCharacteristics == 13 && matrixCoefficients == 0 { + // color_range is implicitly 1 (full) + hdr.ChromaSubsamplingX = 0 + hdr.ChromaSubsamplingY = 0 + return + } + + br.readBits(1) // color_range + + if hdr.SeqProfile == 0 { + hdr.ChromaSubsamplingX = 1 + hdr.ChromaSubsamplingY = 1 + } else if hdr.SeqProfile == 1 { + hdr.ChromaSubsamplingX = 0 + hdr.ChromaSubsamplingY = 0 + } else { + // Profile 2 + if hdr.BitDepth == 12 { + sx := br.readBits(1) + hdr.ChromaSubsamplingX = int(sx) + if sx == 1 { + hdr.ChromaSubsamplingY = int(br.readBits(1)) + } else { + hdr.ChromaSubsamplingY = 0 + } + } else { + hdr.ChromaSubsamplingX = 1 + hdr.ChromaSubsamplingY = 0 + } + } + + if hdr.ChromaSubsamplingX == 1 && hdr.ChromaSubsamplingY == 1 { + br.readBits(2) // chroma_sample_position + } + + br.readBits(1) // separate_uv_delta_q +} + +// IsAV1Keyframe scans OBU headers in a temporal unit for a key frame. +// It checks OBU_FRAME (type 6) and OBU_FRAME_HEADER (type 3) for +// frame_type == KEY_FRAME (0). Returns false for nil/empty input. +func IsAV1Keyframe(temporalUnit []byte) bool { + if len(temporalUnit) == 0 { + return false + } + + pos := 0 + for pos < len(temporalUnit) { + h, err := parseOBUHeader(temporalUnit[pos:]) + if err != nil { + return false + } + + if h.obuType == OBUFrame || h.obuType == OBUFrameHeader { + // Payload starts after header bytes and optional LEB128 size + payloadStart := pos + h.totalSize - h.payloadSize + if payloadStart >= len(temporalUnit) || h.payloadSize < 1 { + return false + } + // frame_header_obu(): first bit is show_existing_frame + // If show_existing_frame==1, this is not a new frame. + // If show_existing_frame==0, next 2 bits are frame_type. + payload := temporalUnit[payloadStart:] + showExisting := (payload[0] >> 7) & 1 + if showExisting == 1 { + return false + } + frameType := (payload[0] >> 5) & 0x03 + return frameType == av1FrameKey + } + + if h.hasSizeField { + pos += h.totalSize + } else { + // No size field — OBU extends to end + break + } + } + return false +} + +// FindSequenceHeaderOBU scans a temporal unit for an OBU_SEQUENCE_HEADER +// (type 1) and returns the complete OBU bytes (header + payload). +// Returns nil if no sequence header is found. +func FindSequenceHeaderOBU(temporalUnit []byte) []byte { + if len(temporalUnit) == 0 { + return nil + } + + pos := 0 + for pos < len(temporalUnit) { + h, err := parseOBUHeader(temporalUnit[pos:]) + if err != nil { + return nil + } + + if h.obuType == OBUSequenceHeader { + end := pos + h.totalSize + if end > len(temporalUnit) { + end = len(temporalUnit) + } + return temporalUnit[pos:end] + } + + if h.hasSizeField { + pos += h.totalSize + } else { + break + } + } + return nil +} diff --git a/demux/av1_test.go b/demux/av1_test.go new file mode 100644 index 0000000..c4652e6 --- /dev/null +++ b/demux/av1_test.go @@ -0,0 +1,862 @@ +package demux + +import ( + "testing" +) + +// Test data generated with ffmpeg using libsvtav1. +// 1920x1080, profile 0, level 8, 8-bit, 4:2:0. +var testSeqHdrOBU1080p = []byte{ + 0x0a, 0x0b, // OBU header (type=1, has_size=1) + LEB128 size=11 + 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71, + 0x2b, 0xe4, 0x01, +} + +// 640x480, profile 0, level 4, 10-bit, 4:2:0. +var testSeqHdrOBU480p10 = []byte{ + 0x0a, 0x0b, // OBU header (type=1, has_size=1) + LEB128 size=11 + 0x00, 0x00, 0x00, 0x24, 0xc4, 0xff, 0xdf, 0x12, + 0xbe, 0x50, 0x10, +} + +// Temporal unit containing TD + seq header + key frame (1920x1080 8-bit). +var testTemporalUnitKeyframe = []byte{ + 0x12, 0x00, // OBU_TEMPORAL_DELIMITER, has_size=1, size=0 + 0x0a, 0x0b, // OBU_SEQUENCE_HEADER, has_size=1, size=11 + 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71, + 0x2b, 0xe4, 0x01, + 0x32, 0x2d, // OBU_FRAME, has_size=1, size=45 + 0x10, 0x00, 0x8b, 0x00, 0x81, 0x45, 0x08, 0x20, + 0x40, 0x00, 0x00, 0x03, 0x24, 0xfe, 0x6a, 0x75, + 0xf5, 0xb5, 0x7d, 0xaa, 0x9a, 0x90, 0x9b, 0xe8, + 0x41, 0xd4, 0x1a, 0x98, 0x69, 0x56, 0xfa, 0xf0, + 0xa8, 0xe6, 0xda, 0x81, 0x0c, 0x1c, 0xe9, 0xc2, + 0xf3, 0x84, 0xbd, 0x86, 0xab, +} + +// Temporal unit containing TD + inter frame (no sequence header). +var testTemporalUnitInterFrame = []byte{ + 0x12, 0x00, // OBU_TEMPORAL_DELIMITER, has_size=1, size=0 + 0x32, 0x18, // OBU_FRAME, has_size=1, size=24 + 0x30, 0x02, 0x04, 0x09, 0x24, 0x92, 0x22, 0x46, + 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0xbc, + 0x32, 0xc1, 0x64, 0xdc, 0x91, 0xe8, 0x51, 0xe8, +} + +func TestParseAV1SequenceHeader1080p(t *testing.T) { + t.Parallel() + + hdr, err := ParseAV1SequenceHeader(testSeqHdrOBU1080p) + if err != nil { + t.Fatalf("ParseAV1SequenceHeader error: %v", err) + } + + if hdr.SeqProfile != 0 { + t.Errorf("SeqProfile: got %d, want 0", hdr.SeqProfile) + } + if hdr.SeqLevelIdx != 8 { + t.Errorf("SeqLevelIdx: got %d, want 8", hdr.SeqLevelIdx) + } + if hdr.SeqTier != 0 { + t.Errorf("SeqTier: got %d, want 0", hdr.SeqTier) + } + if hdr.BitDepth != 8 { + t.Errorf("BitDepth: got %d, want 8", hdr.BitDepth) + } + if hdr.Width != 1920 { + t.Errorf("Width: got %d, want 1920", hdr.Width) + } + if hdr.Height != 1080 { + t.Errorf("Height: got %d, want 1080", hdr.Height) + } + if hdr.ChromaSubsamplingX != 1 { + t.Errorf("ChromaSubsamplingX: got %d, want 1", hdr.ChromaSubsamplingX) + } + if hdr.ChromaSubsamplingY != 1 { + t.Errorf("ChromaSubsamplingY: got %d, want 1", hdr.ChromaSubsamplingY) + } + if hdr.Monochrome { + t.Error("Monochrome: got true, want false") + } +} + +func TestParseAV1SequenceHeader480p10Bit(t *testing.T) { + t.Parallel() + + hdr, err := ParseAV1SequenceHeader(testSeqHdrOBU480p10) + if err != nil { + t.Fatalf("ParseAV1SequenceHeader error: %v", err) + } + + if hdr.SeqProfile != 0 { + t.Errorf("SeqProfile: got %d, want 0", hdr.SeqProfile) + } + if hdr.SeqLevelIdx != 4 { + t.Errorf("SeqLevelIdx: got %d, want 4", hdr.SeqLevelIdx) + } + if hdr.SeqTier != 0 { + t.Errorf("SeqTier: got %d, want 0", hdr.SeqTier) + } + if hdr.BitDepth != 10 { + t.Errorf("BitDepth: got %d, want 10", hdr.BitDepth) + } + if hdr.Width != 640 { + t.Errorf("Width: got %d, want 640", hdr.Width) + } + if hdr.Height != 480 { + t.Errorf("Height: got %d, want 480", hdr.Height) + } + if hdr.ChromaSubsamplingX != 1 { + t.Errorf("ChromaSubsamplingX: got %d, want 1", hdr.ChromaSubsamplingX) + } + if hdr.ChromaSubsamplingY != 1 { + t.Errorf("ChromaSubsamplingY: got %d, want 1", hdr.ChromaSubsamplingY) + } + if hdr.Monochrome { + t.Error("Monochrome: got true, want false") + } +} + +func TestParseAV1SequenceHeaderNil(t *testing.T) { + t.Parallel() + + _, err := ParseAV1SequenceHeader(nil) + if err == nil { + t.Error("expected error for nil input") + } +} + +func TestParseAV1SequenceHeaderEmpty(t *testing.T) { + t.Parallel() + + _, err := ParseAV1SequenceHeader([]byte{}) + if err == nil { + t.Error("expected error for empty input") + } +} + +func TestParseAV1SequenceHeaderTruncated(t *testing.T) { + t.Parallel() + + // Just the OBU header, no payload + _, err := ParseAV1SequenceHeader([]byte{0x0a, 0x0b}) + if err == nil { + t.Error("expected error for truncated input") + } + + // Partial payload + _, err = ParseAV1SequenceHeader([]byte{0x0a, 0x0b, 0x00, 0x00}) + if err == nil { + t.Error("expected error for truncated payload") + } +} + +func TestParseAV1SequenceHeaderWrongOBUType(t *testing.T) { + t.Parallel() + + // OBU_FRAME (type 6) instead of SEQUENCE_HEADER + _, err := ParseAV1SequenceHeader([]byte{0x32, 0x01, 0x00}) + if err == nil { + t.Error("expected error for wrong OBU type") + } +} + +func TestAV1CodecString1080p(t *testing.T) { + t.Parallel() + + hdr, err := ParseAV1SequenceHeader(testSeqHdrOBU1080p) + if err != nil { + t.Fatalf("ParseAV1SequenceHeader error: %v", err) + } + + want := "av01.0.08M.08" + got := hdr.CodecString() + if got != want { + t.Errorf("CodecString: got %q, want %q", got, want) + } +} + +func TestAV1CodecString480p10Bit(t *testing.T) { + t.Parallel() + + hdr, err := ParseAV1SequenceHeader(testSeqHdrOBU480p10) + if err != nil { + t.Fatalf("ParseAV1SequenceHeader error: %v", err) + } + + want := "av01.0.04M.10" + got := hdr.CodecString() + if got != want { + t.Errorf("CodecString: got %q, want %q", got, want) + } +} + +func TestAV1CodecStringHighTier(t *testing.T) { + t.Parallel() + + hdr := &AV1SequenceHeader{ + SeqProfile: 1, + SeqLevelIdx: 13, + SeqTier: 1, + BitDepth: 10, + } + + want := "av01.1.13H.10" + got := hdr.CodecString() + if got != want { + t.Errorf("CodecString: got %q, want %q", got, want) + } +} + +func TestAV1CodecStringZeroPadding(t *testing.T) { + t.Parallel() + + hdr := &AV1SequenceHeader{ + SeqProfile: 0, + SeqLevelIdx: 1, + SeqTier: 0, + BitDepth: 8, + } + + want := "av01.0.01M.08" + got := hdr.CodecString() + if got != want { + t.Errorf("CodecString: got %q, want %q", got, want) + } +} + +func TestIsAV1KeyframeTrue(t *testing.T) { + t.Parallel() + + if !IsAV1Keyframe(testTemporalUnitKeyframe) { + t.Error("expected true for keyframe temporal unit") + } +} + +func TestIsAV1KeyframeFalseInterFrame(t *testing.T) { + t.Parallel() + + if IsAV1Keyframe(testTemporalUnitInterFrame) { + t.Error("expected false for inter frame temporal unit") + } +} + +func TestIsAV1KeyframeFalseNil(t *testing.T) { + t.Parallel() + + if IsAV1Keyframe(nil) { + t.Error("expected false for nil input") + } +} + +func TestIsAV1KeyframeFalseEmpty(t *testing.T) { + t.Parallel() + + if IsAV1Keyframe([]byte{}) { + t.Error("expected false for empty input") + } +} + +func TestIsAV1KeyframeFalseTruncated(t *testing.T) { + t.Parallel() + + // Just a temporal delimiter, no frame OBU + if IsAV1Keyframe([]byte{0x12, 0x00}) { + t.Error("expected false for temporal delimiter only") + } +} + +func TestFindSequenceHeaderOBUPresent(t *testing.T) { + t.Parallel() + + obu := FindSequenceHeaderOBU(testTemporalUnitKeyframe) + if obu == nil { + t.Fatal("expected non-nil OBU") + } + + // Verify the found OBU is a valid sequence header + hdr, err := ParseAV1SequenceHeader(obu) + if err != nil { + t.Fatalf("ParseAV1SequenceHeader on found OBU: %v", err) + } + if hdr.Width != 1920 || hdr.Height != 1080 { + t.Errorf("resolution: got %dx%d, want 1920x1080", hdr.Width, hdr.Height) + } +} + +func TestFindSequenceHeaderOBUAbsent(t *testing.T) { + t.Parallel() + + obu := FindSequenceHeaderOBU(testTemporalUnitInterFrame) + if obu != nil { + t.Errorf("expected nil for inter frame, got %d bytes", len(obu)) + } +} + +func TestFindSequenceHeaderOBUNil(t *testing.T) { + t.Parallel() + + obu := FindSequenceHeaderOBU(nil) + if obu != nil { + t.Error("expected nil for nil input") + } +} + +func TestFindSequenceHeaderOBUEmpty(t *testing.T) { + t.Parallel() + + obu := FindSequenceHeaderOBU([]byte{}) + if obu != nil { + t.Error("expected nil for empty input") + } +} + +func TestReadLEB128(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + data []byte + want int + wantN int + wantErr bool + }{ + { + name: "zero", + data: []byte{0x00}, + want: 0, + wantN: 1, + }, + { + name: "one byte value 11", + data: []byte{0x0b}, + want: 11, + wantN: 1, + }, + { + name: "one byte value 127", + data: []byte{0x7f}, + want: 127, + wantN: 1, + }, + { + name: "two byte value 128", + data: []byte{0x80, 0x01}, + want: 128, + wantN: 2, + }, + { + name: "two byte value 300", + data: []byte{0xac, 0x02}, + want: 300, + wantN: 2, + }, + { + name: "empty", + data: []byte{}, + wantErr: true, + }, + { + name: "unterminated", + data: []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80}, + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + val, n, err := readLEB128(tt.data) + if tt.wantErr { + if err == nil { + t.Error("expected error") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val != tt.want { + t.Errorf("value: got %d, want %d", val, tt.want) + } + if n != tt.wantN { + t.Errorf("bytes consumed: got %d, want %d", n, tt.wantN) + } + }) + } +} + +func TestParseOBUHeader(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + data []byte + wantType int + wantHasSize bool + wantHasExt bool + wantErr bool + }{ + { + name: "sequence header with size", + data: []byte{0x0a, 0x0b, 0x00, 0x00}, + wantType: OBUSequenceHeader, + wantHasSize: true, + }, + { + name: "temporal delimiter with size", + data: []byte{0x12, 0x00}, + wantType: OBUTemporalDelimiter, + wantHasSize: true, + }, + { + name: "frame with size", + data: []byte{0x32, 0x2d, 0x10}, + wantType: OBUFrame, + wantHasSize: true, + }, + { + name: "forbidden bit set", + data: []byte{0x8a}, + wantErr: true, + }, + { + name: "empty", + data: []byte{}, + wantErr: true, + }, + { + name: "extension flag", + data: []byte{0x0e, 0x40, 0x0b, 0x00, 0x00}, + wantType: OBUSequenceHeader, + wantHasExt: true, + wantHasSize: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + h, err := parseOBUHeader(tt.data) + if tt.wantErr { + if err == nil { + t.Error("expected error") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h.obuType != tt.wantType { + t.Errorf("obuType: got %d, want %d", h.obuType, tt.wantType) + } + if h.hasSizeField != tt.wantHasSize { + t.Errorf("hasSizeField: got %v, want %v", h.hasSizeField, tt.wantHasSize) + } + if h.hasExtension != tt.wantHasExt { + t.Errorf("hasExtension: got %v, want %v", h.hasExtension, tt.wantHasExt) + } + }) + } +} + +func TestAV1BitReaderOverflow(t *testing.T) { + t.Parallel() + + br := newAV1BitReader([]byte{0xFF}) + _ = br.readBits(8) // should succeed + _ = br.readBits(1) // should set error + + if br.err == nil { + t.Error("expected error after reading past end") + } + + // Subsequent reads should return 0 without panicking + val := br.readBits(8) + if val != 0 { + t.Errorf("expected 0 after error, got %d", val) + } +} + +func TestAV1BitReaderReadBits(t *testing.T) { + t.Parallel() + + // 0xA5 = 10100101 + br := newAV1BitReader([]byte{0xA5}) + + got := br.readBits(3) + if got != 5 { // 101 + t.Errorf("first 3 bits: got %d, want 5", got) + } + + got = br.readBits(5) + if got != 5 { // 00101 + t.Errorf("next 5 bits: got %d, want 5", got) + } + + if br.err != nil { + t.Errorf("unexpected error: %v", br.err) + } +} + +func TestAV1BitReaderUvlc(t *testing.T) { + t.Parallel() + + // UVLC encoding of 0: just a single 1-bit → value 0 + br := newAV1BitReader([]byte{0x80}) // 10000000 + got := br.readUvlc() + if got != 0 { + t.Errorf("uvlc(0): got %d, want 0", got) + } + + // UVLC encoding of 1: 0,1,0 → leadingZeros=1, val=0, result = 0 + (1<<1) - 1 = 1 + br = newAV1BitReader([]byte{0x40}) // 01000000 + got = br.readUvlc() + if got != 1 { + t.Errorf("uvlc(1): got %d, want 1", got) + } + + // UVLC encoding of 2: 0,1,1 → leadingZeros=1, val=1, result = 1 + (1<<1) - 1 = 2 + br = newAV1BitReader([]byte{0x60}) // 01100000 + got = br.readUvlc() + if got != 2 { + t.Errorf("uvlc(2): got %d, want 2", got) + } +} + +func TestReducedStillPictureHeader(t *testing.T) { + t.Parallel() + + // Hand-craft a minimal sequence header with reduced_still_picture_header=1 + // Profile 0, still_picture=1, reduced=1, level_idx=1 (5 bits) + // Then frame_width_bits_minus_1=3, frame_height_bits_minus_1=3 (4+4 bits) + // Then width=128-1=127 (4 bits), height=96-1=95 (4 bits) + // No frame_id_numbers_present (since reduced) + // use_128x128_superblock=0, enable_filter_intra=0, enable_intra_edge_filter=0 + // No interintra/masked/etc (since reduced) + // enable_superres=0, enable_cdef=0, enable_restoration=0 + // color_config: high_bitdepth=0 → 8-bit, mono_chrome=0 + // color_description_present=0 → unspecified + // color_range=0 + // chroma_sample_position=0 (2 bits, since profile 0 → 4:2:0) + // separate_uv_delta_q=0 + + // Bits: profile(3)=000, still(1)=1, reduced(1)=1 + // level_idx(5)=00001 + // fw_bits(4)=0011, fh_bits(4)=0011 + // width(4)=0111 1111, height(4)=0101 1111 + // Wait, width is fw_bits+1=4 bits, so 127 in 4 bits doesn't fit. + // Let's use smaller: width=16, height=16 + // fw_bits_m1=3 → 4 bits for width, fh_bits_m1=3 → 4 bits for height + // width-1=15=0b1111, height-1=15=0b1111 + // Then: use_128x128_sb=0, filter_intra=0, intra_edge_filter=0 (3 bits) + // enable_superres=0, enable_cdef=0, enable_restoration=0 (3 bits) + // high_bitdepth=0 (1 bit) + // mono_chrome=0 (1 bit, profile != 1) + // color_desc_present=0 (1 bit) + // color_range=0 (1 bit) + // chroma_sample_position=00 (2 bits) + // separate_uv_delta_q=0 (1 bit) + + // Layout: + // byte 0: 000 1 1 000 = 0x18 + // byte 1: 01 0011 00 = 0x4C (level_idx=00001, fw_bits=0011, start of fh_bits) + // byte 2: 11 1111 11 = 0xFF (fh_bits last 2 bits=11, width=1111, height first 2=11) + // byte 3: 11 000 000 = 0xC0 (height last 2=11, sb=0, filter_intra=0, intra_edge=0, superres=0, cdef=0) + // byte 4: 0 0 0 0 0 0 00 = 0x00 (restoration=0, high_bitdepth=0, mono=0, cdp=0, color_range=0, csp=00, sep_uv=0) + + // Actually let me be more careful: + // Bit stream (MSB first): + // [0] profile: 000 + // [3] still_picture: 1 + // [4] reduced_still_picture_header: 1 + // [5] seq_level_idx_0: 00001 + // [10] frame_width_bits_minus_1: 0011 + // [14] frame_height_bits_minus_1: 0011 + // [18] max_frame_width_minus_1 (4 bits): 1111 + // [22] max_frame_height_minus_1 (4 bits): 1111 + // [26] use_128x128_superblock: 0 + // [27] enable_filter_intra: 0 + // [28] enable_intra_edge_filter: 0 + // [29] enable_superres: 0 + // [30] enable_cdef: 0 + // [31] enable_restoration: 0 + // [32] high_bitdepth: 0 + // [33] mono_chrome: 0 + // [34] color_description_present_flag: 0 + // [35] color_range: 0 + // [36] chroma_sample_position: 00 + // [38] separate_uv_delta_q: 0 + + // Byte layout: + // bits 0-7: 000 1 1 000 = 0x18 + // bits 8-15: 01 0011 00 = 0x4C + // bits 16-23: 11 1111 11 = 0xFF + // bits 24-31: 11 000000 = 0xC0 + // bits 32-39: 00000 000 = 0x00 + + payload := []byte{0x18, 0x4C, 0xFF, 0xC0, 0x00} + + // Wrap in OBU header: type=1, has_size=1 + obu := make([]byte, 0, 2+len(payload)) + obu = append(obu, 0x0a) // OBU header: type=1, has_size=1 + obu = append(obu, byte(len(payload))) // LEB128 size + obu = append(obu, payload...) + + hdr, err := ParseAV1SequenceHeader(obu) + if err != nil { + t.Fatalf("ParseAV1SequenceHeader error: %v", err) + } + + if hdr.SeqProfile != 0 { + t.Errorf("SeqProfile: got %d, want 0", hdr.SeqProfile) + } + if hdr.SeqLevelIdx != 1 { + t.Errorf("SeqLevelIdx: got %d, want 1", hdr.SeqLevelIdx) + } + if hdr.SeqTier != 0 { + t.Errorf("SeqTier: got %d, want 0", hdr.SeqTier) + } + if hdr.Width != 16 { + t.Errorf("Width: got %d, want 16", hdr.Width) + } + if hdr.Height != 16 { + t.Errorf("Height: got %d, want 16", hdr.Height) + } + if hdr.BitDepth != 8 { + t.Errorf("BitDepth: got %d, want 8", hdr.BitDepth) + } +} + +func TestParseSequenceHeaderWithDecoderModelInfo(t *testing.T) { + t.Parallel() + + // Hand-craft a sequence header payload with decoder_model_info and + // operating_parameters_info to exercise the code path where + // decoder_model_present_for_this_op=1 requires reading additional fields. + // + // Profile 0, still_picture=0, reduced_still_picture_header=0 + // timing_info_present=1, num_units_in_display_tick=1(32b), time_scale=30(32b) + // equal_picture_interval=0 + // decoder_model_info_present_flag=1 + // buffer_delay_length_minus_1=4 (5b, so delays are 5 bits each) + // num_units_in_decoding_tick=1 (32b) + // buffer_removal_time_length_minus_1=0 (5b) + // frame_presentation_time_length_minus_1=0 (5b) + // initial_display_delay_present_flag=0 + // operating_points_cnt_minus_1=0 (5b = 00000, so 1 operating point) + // operating_point_idc[0]=0 (12b) + // seq_level_idx[0]=8 (5b=01000), tier not read (level<=7 check: 8>7 so tier read) + // seq_tier[0]=0 (1b) + // decoder_model_present_for_this_op[0]=1 (1b) + // decoder_buffer_delay[0]=0 (5b, since buffer_delay_length_minus_1=4) + // encoder_buffer_delay[0]=0 (5b) + // low_delay_mode_flag[0]=0 (1b) + // frame_width_bits_minus_1=3 (4b=0011), frame_height_bits_minus_1=3 (4b=0011) + // max_frame_width_minus_1=15 (4b=1111), max_frame_height_minus_1=15 (4b=1111) + // frame_id_numbers_present=0 + // use_128x128_superblock=0, enable_filter_intra=0, enable_intra_edge_filter=0 + // enable_interintra_compound=0, enable_masked_compound=0, enable_warped_motion=0 + // enable_dual_filter=0, enable_order_hint=0 + // seq_choose_screen_content_tools=1 + // enable_superres=0, enable_cdef=0, enable_restoration=0 + // color_config: high_bitdepth=0, mono_chrome=0, color_description_present=0 + // color_range=0, chroma_sample_position=00, separate_uv_delta_q=0 + + // Bit-by-bit layout: + // [0-2] seq_profile: 000 + // [3] still_picture: 0 + // [4] reduced_still_picture_header: 0 + // [5] timing_info_present_flag: 1 + // [6-37] num_units_in_display_tick: 00000000 00000000 00000000 00000001 + // [38-69] time_scale: 00000000 00000000 00000000 00011110 + // [70] equal_picture_interval: 0 + // [71] decoder_model_info_present_flag: 1 + // [72-76] buffer_delay_length_minus_1: 00100 (=4, so delays are 5 bits) + // [77-108] num_units_in_decoding_tick: 00000000 00000000 00000000 00000001 + // [109-113] buffer_removal_time_length_minus_1: 00000 + // [114-118] frame_presentation_time_length_minus_1: 00000 + // [119] initial_display_delay_present_flag: 0 + // [120-124] operating_points_cnt_minus_1: 00000 + // [125-136] operating_point_idc[0]: 000000000000 + // [137-141] seq_level_idx[0]: 01000 (=8) + // [142] seq_tier[0]: 0 (level>7 so tier is read) + // [143] decoder_model_present_for_this_op[0]: 1 + // [144-148] decoder_buffer_delay[0]: 00000 + // [149-153] encoder_buffer_delay[0]: 00000 + // [154] low_delay_mode_flag[0]: 0 + // [155-158] frame_width_bits_minus_1: 0011 + // [159-162] frame_height_bits_minus_1: 0011 + // [163-166] max_frame_width_minus_1: 1111 + // [167-170] max_frame_height_minus_1: 1111 + // [171] frame_id_numbers_present_flag: 0 + // [172] use_128x128_superblock: 0 + // [173] enable_filter_intra: 0 + // [174] enable_intra_edge_filter: 0 + // [175] enable_interintra_compound: 0 + // [176] enable_masked_compound: 0 + // [177] enable_warped_motion: 0 + // [178] enable_dual_filter: 0 + // [179] enable_order_hint: 0 + // [180] seq_choose_screen_content_tools: 1 + // [181] enable_superres: 0 + // [182] enable_cdef: 0 + // [183] enable_restoration: 0 + // [184] high_bitdepth: 0 + // [185] mono_chrome: 0 + // [186] color_description_present_flag: 0 + // [187] color_range: 0 + // [188-189] chroma_sample_position: 00 + // [190] separate_uv_delta_q: 0 + + // Convert bit positions to bytes: + // byte 0 [0-7]: 00000 1 00 = 0x04 + // byte 1 [8-15]: 00000000 + // byte 2 [16-23]: 00000000 + // byte 3 [24-31]: 00000000 + // byte 4 [32-39]: 00000001 + // byte 5 [40-47]: 00000000 + // byte 6 [48-55]: 00000000 + // byte 7 [56-63]: 00000000 + // byte 8 [64-71]: 00011110 = 0x1E (bits 64-69 = time_scale last, 70=equal_pic=0, 71=dmi=1) + // Wait, let me redo this more carefully. + + // Let me just build it with a helper to be precise. + bits := []uint{} // each element is 0 or 1 + + appendBits := func(val uint, n int) { + for i := n - 1; i >= 0; i-- { + bits = append(bits, (val>>uint(i))&1) + } + } + + appendBits(0, 3) // seq_profile = 0 + appendBits(0, 1) // still_picture = 0 + appendBits(0, 1) // reduced_still_picture_header = 0 + appendBits(1, 1) // timing_info_present_flag = 1 + appendBits(1, 32) // num_units_in_display_tick = 1 + appendBits(30, 32) // time_scale = 30 + appendBits(0, 1) // equal_picture_interval = 0 + appendBits(1, 1) // decoder_model_info_present_flag = 1 + appendBits(4, 5) // buffer_delay_length_minus_1 = 4 + appendBits(1, 32) // num_units_in_decoding_tick = 1 + appendBits(0, 5) // buffer_removal_time_length_minus_1 = 0 + appendBits(0, 5) // frame_presentation_time_length_minus_1 = 0 + appendBits(0, 1) // initial_display_delay_present_flag = 0 + appendBits(0, 5) // operating_points_cnt_minus_1 = 0 + appendBits(0, 12) // operating_point_idc[0] = 0 + appendBits(8, 5) // seq_level_idx[0] = 8 + appendBits(0, 1) // seq_tier[0] = 0 (read because level > 7) + appendBits(1, 1) // decoder_model_present_for_this_op[0] = 1 + appendBits(0, 5) // decoder_buffer_delay[0] = 0 (5 bits) + appendBits(0, 5) // encoder_buffer_delay[0] = 0 (5 bits) + appendBits(0, 1) // low_delay_mode_flag[0] = 0 + appendBits(3, 4) // frame_width_bits_minus_1 = 3 + appendBits(3, 4) // frame_height_bits_minus_1 = 3 + appendBits(15, 4) // max_frame_width_minus_1 = 15 (width=16) + appendBits(15, 4) // max_frame_height_minus_1 = 15 (height=16) + appendBits(0, 1) // frame_id_numbers_present_flag = 0 + appendBits(0, 1) // use_128x128_superblock = 0 + appendBits(0, 1) // enable_filter_intra = 0 + appendBits(0, 1) // enable_intra_edge_filter = 0 + appendBits(0, 1) // enable_interintra_compound = 0 + appendBits(0, 1) // enable_masked_compound = 0 + appendBits(0, 1) // enable_warped_motion = 0 + appendBits(0, 1) // enable_dual_filter = 0 + appendBits(0, 1) // enable_order_hint = 0 + appendBits(1, 1) // seq_choose_screen_content_tools = 1 (SELECT) + // seqForceScreenContentTools = 2 (SELECT), which is > 0 + appendBits(1, 1) // seq_choose_integer_mv = 1 (SELECT) + appendBits(0, 1) // enable_superres = 0 + appendBits(0, 1) // enable_cdef = 0 + appendBits(0, 1) // enable_restoration = 0 + appendBits(0, 1) // high_bitdepth = 0 → 8-bit + appendBits(0, 1) // mono_chrome = 0 + appendBits(0, 1) // color_description_present_flag = 0 + appendBits(0, 1) // color_range = 0 + appendBits(0, 2) // chroma_sample_position = 0 + appendBits(0, 1) // separate_uv_delta_q = 0 + + // Pack bits into bytes + payload := make([]byte, (len(bits)+7)/8) + for i, b := range bits { + if b == 1 { + payload[i/8] |= 1 << (7 - uint(i%8)) + } + } + + // Wrap in OBU header + obu := make([]byte, 0, 2+len(payload)) + obu = append(obu, 0x0a) // OBU header: type=1, has_size=1 + obu = append(obu, byte(len(payload))) // LEB128 size + obu = append(obu, payload...) + + hdr, err := ParseAV1SequenceHeader(obu) + if err != nil { + t.Fatalf("ParseAV1SequenceHeader error: %v", err) + } + + if hdr.SeqProfile != 0 { + t.Errorf("SeqProfile: got %d, want 0", hdr.SeqProfile) + } + if hdr.SeqLevelIdx != 8 { + t.Errorf("SeqLevelIdx: got %d, want 8", hdr.SeqLevelIdx) + } + if hdr.SeqTier != 0 { + t.Errorf("SeqTier: got %d, want 0", hdr.SeqTier) + } + if hdr.Width != 16 { + t.Errorf("Width: got %d, want 16", hdr.Width) + } + if hdr.Height != 16 { + t.Errorf("Height: got %d, want 16", hdr.Height) + } + if hdr.BitDepth != 8 { + t.Errorf("BitDepth: got %d, want 8", hdr.BitDepth) + } +} + +func TestIsAV1KeyframeFrameHeaderOBU(t *testing.T) { + t.Parallel() + + // Construct a temporal unit with OBU_FRAME_HEADER (type 3) instead of OBU_FRAME + // OBU_FRAME_HEADER: type=3, has_size=1 → (3<<3)|2 = 0x1A + // Payload: show_existing=0, frame_type=0 (KEY) → first byte = 0b0_00_xxxxx = 0x00 + tu := []byte{ + 0x12, 0x00, // TD + 0x1a, 0x02, // OBU_FRAME_HEADER, size=2 + 0x00, 0x00, // show_existing=0, frame_type=0 (KEY) + } + + if !IsAV1Keyframe(tu) { + t.Error("expected true for OBU_FRAME_HEADER with KEY_FRAME type") + } +} + +func TestIsAV1KeyframeShowExistingFrame(t *testing.T) { + t.Parallel() + + // show_existing_frame = 1 → first payload byte: 1_xx_xxxxx = 0x80 + tu := []byte{ + 0x12, 0x00, // TD + 0x32, 0x02, // OBU_FRAME, size=2 + 0x80, 0x00, // show_existing_frame=1 + } + + if IsAV1Keyframe(tu) { + t.Error("expected false for show_existing_frame=1") + } +} + +func TestFindSequenceHeaderOBUTruncated(t *testing.T) { + t.Parallel() + + // Sequence header OBU header but truncated payload + data := []byte{0x0a, 0x80} // has_size with unterminated LEB128 + obu := FindSequenceHeaderOBU(data) + if obu != nil { + t.Error("expected nil for truncated OBU") + } +} diff --git a/distribution/moq_writer.go b/distribution/moq_writer.go index 71cc942..b176f07 100644 --- a/distribution/moq_writer.go +++ b/distribution/moq_writer.go @@ -88,12 +88,19 @@ func (m *moqWriter) WriteVideoFrame(w io.Writer, frame *media.VideoFrame) (int64 } // Video Config on keyframes (ID 13, odd → length-prefixed bytes) - if frame.IsKeyframe && frame.SPS != nil && frame.PPS != nil { + if frame.IsKeyframe && frame.SPS != nil { var configData []byte - if frame.Codec == "h265" && frame.VPS != nil { - configData = moq.BuildHEVCDecoderConfig(frame.VPS, frame.SPS, frame.PPS) - } else { - configData = moq.BuildAVCDecoderConfig(frame.SPS, frame.PPS) + switch frame.Codec { + case "av1": + configData = moq.BuildAV1DecoderConfig(frame.SPS) + case "h265": + if frame.VPS != nil && frame.PPS != nil { + configData = moq.BuildHEVCDecoderConfig(frame.VPS, frame.SPS, frame.PPS) + } + default: + if frame.PPS != nil { + configData = moq.BuildAVCDecoderConfig(frame.SPS, frame.PPS) + } } if configData != nil { exts = quicvarint.Append(exts, locExtVideoConfig) diff --git a/distribution/moq_writer_test.go b/distribution/moq_writer_test.go index 19ba36e..f4acace 100644 --- a/distribution/moq_writer_test.go +++ b/distribution/moq_writer_test.go @@ -488,6 +488,142 @@ func TestMoQWriterBytesWritten(t *testing.T) { } } +func TestMoQWriterVideoFrameAV1Keyframe(t *testing.T) { + t.Parallel() + w := NewMoQWriter(1, 0) + var buf bytes.Buffer + + if err := w.WriteStreamHeader(&buf, TrackIDVideo, 1, 0); err != nil { + t.Fatalf("WriteStreamHeader failed: %v", err) + } + buf.Reset() + + // AV1 sequence header OBU for 1920x1080 8-bit (from demux/av1_test.go) + seqHdrOBU := []byte{ + 0x0a, 0x0b, + 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71, + 0x2b, 0xe4, 0x01, + } + + // Raw OBU payload representing the frame data + wireData := []byte{0x32, 0x10, 0xAA, 0xBB, 0xCC, 0xDD} + + frame := &media.VideoFrame{ + PTS: 50000, + IsKeyframe: true, + Codec: "av1", + SPS: seqHdrOBU, + WireData: wireData, + } + + n, err := w.WriteVideoFrame(&buf, frame) + if err != nil { + t.Fatalf("WriteVideoFrame failed: %v", err) + } + + if n != int64(buf.Len()) { + t.Errorf("bytes written: got %d, actual buffer %d", n, buf.Len()) + } + + data := buf.Bytes() + pos := 0 + + // Object ID + objectID, nn, err := quicvarint.Parse(data[pos:]) + if err != nil { + t.Fatalf("parse object ID: %v", err) + } + if objectID != 0 { + t.Errorf("object ID: got %d, want 0", objectID) + } + pos += nn + + // Extension headers length + extLen, nn, err := quicvarint.Parse(data[pos:]) + if err != nil { + t.Fatalf("parse ext length: %v", err) + } + pos += nn + + extEnd := pos + int(extLen) + foundTimestamp := false + foundMarking := false + foundConfig := false + var configData []byte + + for pos < extEnd { + extID, nn, err := quicvarint.Parse(data[pos:]) + if err != nil { + t.Fatalf("parse ext ID: %v", err) + } + pos += nn + + if extID%2 == 0 { + val, nn, err := quicvarint.Parse(data[pos:]) + if err != nil { + t.Fatalf("parse ext value: %v", err) + } + pos += nn + + switch extID { + case locExtCaptureTimestamp: + foundTimestamp = true + if val != 50000 { + t.Errorf("capture timestamp: got %d, want 50000", val) + } + case locExtVideoFrameMarking: + foundMarking = true + if val != vfmKeyframe { + t.Errorf("video frame marking: got 0x%x, want 0x%x", val, vfmKeyframe) + } + } + } else { + valLen, nn, err := quicvarint.Parse(data[pos:]) + if err != nil { + t.Fatalf("parse ext value length: %v", err) + } + pos += nn + + if extID == locExtVideoConfig { + foundConfig = true + configData = data[pos : pos+int(valLen)] + } + pos += int(valLen) + } + } + + if !foundTimestamp { + t.Error("missing capture timestamp extension") + } + if !foundMarking { + t.Error("missing video frame marking extension") + } + if !foundConfig { + t.Error("missing video config extension") + } + + // AV1CodecConfigurationRecord starts with marker|version = 0x81 + if len(configData) < 4 { + t.Fatalf("config data too short: %d bytes", len(configData)) + } + if configData[0] != 0x81 { + t.Errorf("AV1 config marker|version: got 0x%02x, want 0x81", configData[0]) + } + + // Payload length + payloadLen, nn, err := quicvarint.Parse(data[pos:]) + if err != nil { + t.Fatalf("parse payload length: %v", err) + } + pos += nn + + // Payload should be the raw WireData (pass-through, no AnnexB conversion) + payload := data[pos : pos+int(payloadLen)] + if !bytes.Equal(payload, wireData) { + t.Errorf("payload mismatch: got %x, want %x", payload, wireData) + } +} + func TestMoQWriterStreamHeaderSize(t *testing.T) { t.Parallel() w := NewMoQWriter(1, 0) diff --git a/distribution/server.go b/distribution/server.go index 6c59cba..442a882 100644 --- a/distribution/server.go +++ b/distribution/server.go @@ -104,6 +104,23 @@ type SRTPullInfo struct { StreamID string `json:"streamId,omitempty"` } +// DASHPullFunc initiates a DASH pull from a remote MPD URL. +type DASHPullFunc func(url, streamKey, videoRepID, audioRepID string) error + +// DASHStopFunc stops an active DASH pull by stream key. +type DASHStopFunc func(streamKey string) error + +// DASHListFunc returns all active DASH pulls. +type DASHListFunc func() []DASHPullInfo + +// DASHPullInfo describes an active DASH pull. +type DASHPullInfo struct { + URL string `json:"url"` + StreamKey string `json:"streamKey"` + VideoRepID string `json:"videoRepresentationId,omitempty"` + AudioRepID string `json:"audioRepresentationId,omitempty"` +} + // WebTransport session close error codes sent to clients via CloseWithError. const ( wtErrStreamNotFound webtransport.SessionErrorCode = 1 @@ -131,6 +148,9 @@ type ServerConfig struct { SRTPull SRTPullFunc SRTStop SRTStopFunc SRTList SRTListFunc + DASHPull DASHPullFunc + DASHStop DASHStopFunc + DASHList DASHListFunc ExtraRoutes func(mux *http.ServeMux) // OnStreamRegistered is called after a new stream relay is created @@ -274,6 +294,10 @@ func (s *Server) registerAPIRoutes(mux *http.ServeMux) { mux.HandleFunc("POST /api/srt-pull", s.handleSRTPullCreate) mux.HandleFunc("DELETE /api/srt-pull", s.handleSRTPullStop) mux.HandleFunc("OPTIONS /api/srt-pull", s.handleSRTPullOptions) + mux.HandleFunc("GET /api/dash-pull", s.handleDASHPullList) + mux.HandleFunc("POST /api/dash-pull", s.handleDASHPullCreate) + mux.HandleFunc("DELETE /api/dash-pull", s.handleDASHPullStop) + mux.HandleFunc("OPTIONS /api/dash-pull", s.handleDASHPullOptions) if s.config.ExtraRoutes != nil { s.config.ExtraRoutes(mux) diff --git a/distribution/server_dash_handlers.go b/distribution/server_dash_handlers.go new file mode 100644 index 0000000..2c7edd1 --- /dev/null +++ b/distribution/server_dash_handlers.go @@ -0,0 +1,63 @@ +package distribution + +import ( + "encoding/json" + "net/http" +) + +func (s *Server) handleDASHPullList(w http.ResponseWriter, _ *http.Request) { + if s.config.DASHList == nil { + writeJSON(w, http.StatusOK, []DASHPullInfo{}) + return + } + writeJSON(w, http.StatusOK, s.config.DASHList()) +} + +func (s *Server) handleDASHPullCreate(w http.ResponseWriter, r *http.Request) { + if s.config.DASHPull == nil { + writeError(w, http.StatusNotImplemented, "DASH pull not configured") + return + } + var req struct { + URL string `json:"url"` + StreamKey string `json:"streamKey"` + VideoRepID string `json:"videoRepresentationId,omitempty"` + AudioRepID string `json:"audioRepresentationId,omitempty"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + if req.URL == "" || req.StreamKey == "" { + writeError(w, http.StatusBadRequest, "url and streamKey are required") + return + } + if err := s.config.DASHPull(req.URL, req.StreamKey, req.VideoRepID, req.AudioRepID); err != nil { + writeError(w, http.StatusConflict, err.Error()) + return + } + writeJSON(w, http.StatusCreated, map[string]string{"status": "pulling", "streamKey": req.StreamKey}) +} + +func (s *Server) handleDASHPullStop(w http.ResponseWriter, r *http.Request) { + if s.config.DASHStop == nil { + writeError(w, http.StatusNotImplemented, "DASH pull not configured") + return + } + streamKey := r.URL.Query().Get("streamKey") + if streamKey == "" { + writeError(w, http.StatusBadRequest, "streamKey query parameter required") + return + } + if err := s.config.DASHStop(streamKey); err != nil { + writeError(w, http.StatusNotFound, err.Error()) + return + } + writeJSON(w, http.StatusOK, map[string]string{"status": "stopped", "streamKey": streamKey}) +} + +func (s *Server) handleDASHPullOptions(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + w.WriteHeader(http.StatusNoContent) +} diff --git a/distribution/server_test.go b/distribution/server_test.go index b5469d7..8ee4f30 100644 --- a/distribution/server_test.go +++ b/distribution/server_test.go @@ -409,3 +409,58 @@ func TestStreamLifecycleCallbacks(t *testing.T) { // If we get here without panicking, the test passes. }) } + +func TestHandleDASHPullList(t *testing.T) { + t.Parallel() + + srv := newTestServer(t) + handler := srv.APIHandler() + + req := httptest.NewRequest("GET", "/api/dash-pull", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + // DASHList is nil, should return empty array. + body := strings.TrimSpace(rec.Body.String()) + if body != "[]" { + t.Fatalf("body = %q, want %q", body, "[]") + } +} + +func TestHandleDASHPullCreateMissingFields(t *testing.T) { + t.Parallel() + + srv := newTestServer(t) + srv.config.DASHPull = func(_, _, _, _ string) error { return nil } + handler := srv.APIHandler() + + req := httptest.NewRequest("POST", "/api/dash-pull", strings.NewReader(`{"url":""}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestHandleDASHPullNotConfigured(t *testing.T) { + t.Parallel() + + srv := newTestServer(t) + // DASHPull is nil. + handler := srv.APIHandler() + + req := httptest.NewRequest("POST", "/api/dash-pull", strings.NewReader(`{"url":"https://example.com/live.mpd","streamKey":"test"}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotImplemented { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotImplemented) + } +} diff --git a/go.mod b/go.mod index 784de9d..0731c5e 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,10 @@ module github.com/zsiec/prism go 1.24.3 require ( + github.com/Eyevinn/mp4ff v0.51.0 github.com/quic-go/quic-go v0.59.0 github.com/quic-go/webtransport-go v0.10.0 + github.com/unki2aut/go-mpd v0.0.0-20250610073145-8336a8d84ee7 github.com/zsiec/ccx v0.2.0 github.com/zsiec/srtgo v0.2.4 golang.org/x/sync v0.17.0 @@ -13,6 +15,7 @@ require ( require ( github.com/dunglas/httpsfv v1.1.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect + github.com/unki2aut/go-xsd-types v0.0.0-20200220223938-30e5405398f8 // indirect golang.org/x/crypto v0.42.0 // indirect golang.org/x/net v0.44.0 // indirect golang.org/x/sys v0.36.0 // indirect diff --git a/go.sum b/go.sum index 43ddee4..851b3e5 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,15 @@ +github.com/Eyevinn/mp4ff v0.51.0 h1:ZYdHFXEcB3kJkCeCHMHl/tbCm64FJsD2XOU0Sj+ME2M= +github.com/Eyevinn/mp4ff v0.51.0/go.mod h1:hJNUUqOBryLAzUW9wpCJyw2HaI+TCd2rUPhafoS5lgg= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dunglas/httpsfv v1.1.0 h1:Jw76nAyKWKZKFrpMMcL76y35tOpYHqQPzHQiwDvpe54= github.com/dunglas/httpsfv v1.1.0/go.mod h1:zID2mqw9mFsnt7YC3vYQ9/cjq30q41W+1AnDwH8TiMg= +github.com/go-test/deep v1.1.0 h1:WOcxcdHcvdgThNXjw0t76K42FXTU7HpNQWHpA2HHNlg= +github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= @@ -10,8 +18,14 @@ github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SA github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= github.com/quic-go/webtransport-go v0.10.0 h1:LqXXPOXuETY5Xe8ITdGisBzTYmUOy5eSj+9n4hLTjHI= github.com/quic-go/webtransport-go v0.10.0/go.mod h1:LeGIXr5BQKE3UsynwVBeQrU1TPrbh73MGoC6jd+V7ow= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/unki2aut/go-mpd v0.0.0-20250610073145-8336a8d84ee7 h1:FsfEp+/I/2p/WObIlPlB10kITDkOjzi/IDqM9GYpYJc= +github.com/unki2aut/go-mpd v0.0.0-20250610073145-8336a8d84ee7/go.mod h1:LITqXLCxxmcoHtOMgZh5NbcfS4RCrrADQXPVkYwF/cc= +github.com/unki2aut/go-xsd-types v0.0.0-20200220223938-30e5405398f8 h1:u0Bi6Mf8BKPQnxGJ7QubdMyhb0SJjnQU7kX0BA9eASk= +github.com/unki2aut/go-xsd-types v0.0.0-20200220223938-30e5405398f8/go.mod h1:uIeMfpmWIZ8SGp+fTfwDBWiiRn3aJm4b7rFSro9s++Q= github.com/zsiec/ccx v0.2.0 h1:RaCC4a0ng9wa6AUvT9Kg9kfiEy9svfXL5mmRbi6Ykmo= github.com/zsiec/ccx v0.2.0/go.mod h1:Y30W1TCZX7HAXM0miCzG18kSyKtJvqbOk7zrFcTGpU4= github.com/zsiec/srtgo v0.2.4 h1:WzQfUMSiQglWJDilcXgFvW/23IGLr7FXNRdrbNKJyS8= @@ -28,5 +42,7 @@ golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ingest/dash/mpd.go b/ingest/dash/mpd.go new file mode 100644 index 0000000..4316699 --- /dev/null +++ b/ingest/dash/mpd.go @@ -0,0 +1,262 @@ +package dash + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "time" + + "github.com/unki2aut/go-mpd" +) + +type mpdInfo struct { + IsDynamic bool + AvailabilityStartTime time.Time + MinUpdatePeriod time.Duration + BaseURL string + VideoAdaptations []adaptationInfo + AudioAdaptations []adaptationInfo +} + +type adaptationInfo struct { + MimeType string + Representations []representationInfo +} + +type representationInfo struct { + ID string + Bandwidth int + Width int + Height int + Codecs string + SegmentTemplate segmentTemplate +} + +type segmentTemplate struct { + InitializationPattern string + MediaPattern string + StartNumber int + Timescale int + Duration int +} + +// parseMPD decodes MPD XML and extracts video/audio adaptation sets with +// their representations and segment templates. +func parseMPD(data []byte) (*mpdInfo, error) { + var m mpd.MPD + if err := m.Decode(data); err != nil { + return nil, fmt.Errorf("decode MPD: %w", err) + } + + info := &mpdInfo{} + + // Type + if m.Type != nil && *m.Type == "dynamic" { + info.IsDynamic = true + } + + // AvailabilityStartTime + if m.AvailabilityStartTime != nil { + info.AvailabilityStartTime = time.Time(*m.AvailabilityStartTime) + } + + // MinimumUpdatePeriod + if m.MinimumUpdatePeriod != nil { + ns, err := m.MinimumUpdatePeriod.ToNanoseconds() + if err == nil { + info.MinUpdatePeriod = time.Duration(ns) + } + } + + // BaseURL (take first if present) + if len(m.BaseURL) > 0 { + info.BaseURL = m.BaseURL[0].Value + } + + // Walk periods and adaptation sets + for _, period := range m.Period { + if period == nil { + continue + } + for _, as := range period.AdaptationSets { + if as == nil { + continue + } + // Determine content type from mimeType, contentType, or Representation mimeType + contentType := classifyAdaptationSet(as) + ai := adaptationInfo{ + MimeType: contentType, + } + for _, rep := range as.Representations { + ri := representationInfo{} + if rep.ID != nil { + ri.ID = *rep.ID + } + if rep.Bandwidth != nil { + ri.Bandwidth = int(*rep.Bandwidth) + } + if rep.Width != nil { + ri.Width = int(*rep.Width) + } + if rep.Height != nil { + ri.Height = int(*rep.Height) + } + if rep.Codecs != nil { + ri.Codecs = *rep.Codecs + } else if as.Codecs != nil { + ri.Codecs = *as.Codecs + } + + // SegmentTemplate: check representation level first, then adaptation set level + st := rep.SegmentTemplate + if st == nil { + st = as.SegmentTemplate + } + if st != nil { + ri.SegmentTemplate = extractSegmentTemplate(st) + } + + ai.Representations = append(ai.Representations, ri) + } + + if strings.HasPrefix(contentType, "video") { + info.VideoAdaptations = append(info.VideoAdaptations, ai) + } else if strings.HasPrefix(contentType, "audio") { + info.AudioAdaptations = append(info.AudioAdaptations, ai) + } + } + } + + return info, nil +} + +func extractSegmentTemplate(st *mpd.SegmentTemplate) segmentTemplate { + t := segmentTemplate{} + if st.Initialization != nil { + t.InitializationPattern = *st.Initialization + } + if st.Media != nil { + t.MediaPattern = *st.Media + } + if st.StartNumber != nil { + t.StartNumber = int(*st.StartNumber) + } + if st.Timescale != nil { + t.Timescale = int(*st.Timescale) + } + if st.Duration != nil { + t.Duration = int(*st.Duration) + } + return t +} + +// selectRepresentation picks a representation from the given adaptation sets. +// If repID is empty, the highest bandwidth representation is selected. +// If repID is given, an exact match is required. +func selectRepresentation(adaptations []adaptationInfo, repID string) (representationInfo, segmentTemplate, error) { + if len(adaptations) == 0 { + return representationInfo{}, segmentTemplate{}, fmt.Errorf("no adaptation sets available") + } + + if repID == "" { + // Select highest bandwidth across all adaptations + var best representationInfo + found := false + for _, a := range adaptations { + for _, r := range a.Representations { + if !found || r.Bandwidth > best.Bandwidth { + best = r + found = true + } + } + } + if !found { + return representationInfo{}, segmentTemplate{}, fmt.Errorf("no representations available") + } + return best, best.SegmentTemplate, nil + } + + // Exact match + for _, a := range adaptations { + for _, r := range a.Representations { + if r.ID == repID { + return r, r.SegmentTemplate, nil + } + } + } + return representationInfo{}, segmentTemplate{}, fmt.Errorf("representation %q not found", repID) +} + +// resolveInitURL replaces $RepresentationID$ in the initialization pattern +// and prepends the base URL. +func resolveInitURL(tmpl segmentTemplate, repID, baseURL string) string { + s := strings.ReplaceAll(tmpl.InitializationPattern, "$RepresentationID$", repID) + return baseURL + s +} + +// numberPattern matches $Number$ or $Number%0Nd$ (printf-style zero-padded). +var numberPattern = regexp.MustCompile(`\$Number(%0(\d+)d)?\$`) + +// resolveMediaURL replaces $RepresentationID$ and $Number$/$Number%05d$ in +// the media pattern and prepends the base URL. +func resolveMediaURL(tmpl segmentTemplate, repID string, number int, baseURL string) string { + s := strings.ReplaceAll(tmpl.MediaPattern, "$RepresentationID$", repID) + s = numberPattern.ReplaceAllStringFunc(s, func(match string) string { + sub := numberPattern.FindStringSubmatch(match) + if sub[2] != "" { + width, _ := strconv.Atoi(sub[2]) + return fmt.Sprintf("%0*d", width, number) + } + return strconv.Itoa(number) + }) + return baseURL + s +} + +// computeSegmentNumber computes the latest fully-available segment number +// (one behind the live edge) based on the current time, availability start +// time, and the segment template parameters. +func computeSegmentNumber(now time.Time, ast time.Time, tmpl segmentTemplate) int { + if tmpl.Timescale == 0 || tmpl.Duration == 0 { + return tmpl.StartNumber + } + elapsed := now.Sub(ast).Seconds() + segmentDuration := float64(tmpl.Duration) / float64(tmpl.Timescale) + if elapsed < segmentDuration { + return tmpl.StartNumber + } + // Total segments elapsed from AST + totalSegments := int(elapsed / segmentDuration) + // One behind live edge + liveSegment := totalSegments - 1 + if liveSegment < 0 { + liveSegment = 0 + } + return tmpl.StartNumber + liveSegment +} + +// classifyAdaptationSet determines if an AdaptationSet is "video" or "audio" +// by checking (in order): mimeType attr, contentType attr, codec prefix heuristic. +func classifyAdaptationSet(as *mpd.AdaptationSet) string { + // 1. AdaptationSet mimeType (e.g., "video/mp4") + if as.MimeType != "" { + return as.MimeType + } + // 2. AdaptationSet contentType (e.g., "video", "audio") + if as.ContentType != nil && *as.ContentType != "" { + return *as.ContentType + } + // 3. Infer from codec string on first Representation + for _, rep := range as.Representations { + if rep.Codecs != nil { + c := *rep.Codecs + if strings.HasPrefix(c, "av01") || strings.HasPrefix(c, "avc") || strings.HasPrefix(c, "hev") || strings.HasPrefix(c, "hvc") || strings.HasPrefix(c, "vp0") { + return "video" + } + if strings.HasPrefix(c, "mp4a") || strings.HasPrefix(c, "opus") { + return "audio" + } + } + } + return "" +} diff --git a/ingest/dash/mpd_test.go b/ingest/dash/mpd_test.go new file mode 100644 index 0000000..460e718 --- /dev/null +++ b/ingest/dash/mpd_test.go @@ -0,0 +1,315 @@ +package dash + +import ( + "testing" + "time" +) + +const testMPDDynamic = ` + + + + + + + + + + + + + + + + +` + +const testMPDStatic = ` + + + + + + + + +` + +func TestParseMPD(t *testing.T) { + info, err := parseMPD([]byte(testMPDDynamic)) + if err != nil { + t.Fatalf("parseMPD: %v", err) + } + + if !info.IsDynamic { + t.Error("expected IsDynamic=true") + } + + expectedAST := time.Date(2026, 3, 18, 0, 0, 0, 0, time.UTC) + if !info.AvailabilityStartTime.Equal(expectedAST) { + t.Errorf("AvailabilityStartTime = %v, want %v", info.AvailabilityStartTime, expectedAST) + } + + if info.MinUpdatePeriod != 2*time.Second { + t.Errorf("MinUpdatePeriod = %v, want 2s", info.MinUpdatePeriod) + } + + // Video adaptations + if len(info.VideoAdaptations) != 1 { + t.Fatalf("video adaptation count = %d, want 1", len(info.VideoAdaptations)) + } + va := info.VideoAdaptations[0] + if len(va.Representations) != 2 { + t.Fatalf("video rep count = %d, want 2", len(va.Representations)) + } + if va.Representations[0].ID != "1080p" { + t.Errorf("video rep[0] ID = %q, want 1080p", va.Representations[0].ID) + } + if va.Representations[1].ID != "720p" { + t.Errorf("video rep[1] ID = %q, want 720p", va.Representations[1].ID) + } + if va.Representations[0].Width != 1920 || va.Representations[0].Height != 1080 { + t.Errorf("video rep[0] dims = %dx%d, want 1920x1080", + va.Representations[0].Width, va.Representations[0].Height) + } + if va.Representations[0].Codecs != "av01.0.09M.08" { + t.Errorf("video rep[0] codecs = %q, want av01.0.09M.08", va.Representations[0].Codecs) + } + + // Audio adaptations + if len(info.AudioAdaptations) != 1 { + t.Fatalf("audio adaptation count = %d, want 1", len(info.AudioAdaptations)) + } + aa := info.AudioAdaptations[0] + if len(aa.Representations) != 1 { + t.Fatalf("audio rep count = %d, want 1", len(aa.Representations)) + } + if aa.Representations[0].ID != "aac-128k" { + t.Errorf("audio rep[0] ID = %q, want aac-128k", aa.Representations[0].ID) + } + + // Segment template on video rep + st := va.Representations[0].SegmentTemplate + if st.Timescale != 90000 { + t.Errorf("timescale = %d, want 90000", st.Timescale) + } + if st.Duration != 180000 { + t.Errorf("duration = %d, want 180000", st.Duration) + } + if st.StartNumber != 1 { + t.Errorf("startNumber = %d, want 1", st.StartNumber) + } +} + +func TestSelectRepresentation(t *testing.T) { + info, err := parseMPD([]byte(testMPDDynamic)) + if err != nil { + t.Fatalf("parseMPD: %v", err) + } + + t.Run("explicit ID", func(t *testing.T) { + rep, tmpl, err := selectRepresentation(info.VideoAdaptations, "720p") + if err != nil { + t.Fatalf("selectRepresentation: %v", err) + } + if rep.ID != "720p" { + t.Errorf("ID = %q, want 720p", rep.ID) + } + if rep.Bandwidth != 2000000 { + t.Errorf("Bandwidth = %d, want 2000000", rep.Bandwidth) + } + if tmpl.Timescale != 90000 { + t.Errorf("Timescale = %d, want 90000", tmpl.Timescale) + } + }) + + t.Run("empty ID selects highest bandwidth", func(t *testing.T) { + rep, _, err := selectRepresentation(info.VideoAdaptations, "") + if err != nil { + t.Fatalf("selectRepresentation: %v", err) + } + if rep.ID != "1080p" { + t.Errorf("ID = %q, want 1080p", rep.ID) + } + if rep.Bandwidth != 4000000 { + t.Errorf("Bandwidth = %d, want 4000000", rep.Bandwidth) + } + }) + + t.Run("nonexistent ID", func(t *testing.T) { + _, _, err := selectRepresentation(info.VideoAdaptations, "4k") + if err == nil { + t.Error("expected error for nonexistent ID") + } + }) +} + +func TestResolveInitURL(t *testing.T) { + tmpl := segmentTemplate{ + InitializationPattern: "video/$RepresentationID$/init.mp4", + MediaPattern: "video/$RepresentationID$/seg-$Number$.m4s", + StartNumber: 1, + Timescale: 90000, + Duration: 180000, + } + + got := resolveInitURL(tmpl, "1080p", "") + want := "video/1080p/init.mp4" + if got != want { + t.Errorf("resolveInitURL = %q, want %q", got, want) + } + + got = resolveInitURL(tmpl, "720p", "https://cdn.example.com/") + want = "https://cdn.example.com/video/720p/init.mp4" + if got != want { + t.Errorf("resolveInitURL with baseURL = %q, want %q", got, want) + } +} + +func TestResolveMediaURL(t *testing.T) { + tmpl := segmentTemplate{ + InitializationPattern: "video/$RepresentationID$/init.mp4", + MediaPattern: "video/$RepresentationID$/seg-$Number$.m4s", + StartNumber: 1, + Timescale: 90000, + Duration: 180000, + } + + got := resolveMediaURL(tmpl, "1080p", 42, "") + want := "video/1080p/seg-42.m4s" + if got != want { + t.Errorf("resolveMediaURL = %q, want %q", got, want) + } + + got = resolveMediaURL(tmpl, "720p", 1, "https://cdn.example.com/") + want = "https://cdn.example.com/video/720p/seg-1.m4s" + if got != want { + t.Errorf("resolveMediaURL with baseURL = %q, want %q", got, want) + } +} + +func TestComputeSegmentNumber(t *testing.T) { + ast := time.Date(2026, 3, 18, 0, 0, 0, 0, time.UTC) + tmpl := segmentTemplate{ + StartNumber: 1, + Timescale: 90000, + Duration: 180000, // 2 seconds per segment + } + + tests := []struct { + name string + offset time.Duration + want int + }{ + {"before first segment", 1 * time.Second, 1}, + {"exactly one segment", 2 * time.Second, 1}, + {"two segments elapsed", 4 * time.Second, 1 + 1}, + {"ten segments elapsed", 20 * time.Second, 1 + 9}, + {"one hour", 1 * time.Hour, 1 + 1799}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + now := ast.Add(tt.offset) + got := computeSegmentNumber(now, ast, tmpl) + if got != tt.want { + t.Errorf("computeSegmentNumber(offset=%v) = %d, want %d", tt.offset, got, tt.want) + } + }) + } +} + +func TestParseMPDStatic(t *testing.T) { + info, err := parseMPD([]byte(testMPDStatic)) + if err != nil { + t.Fatalf("parseMPD: %v", err) + } + if info.IsDynamic { + t.Error("expected IsDynamic=false for static MPD") + } + if len(info.VideoAdaptations) != 1 { + t.Errorf("video adaptation count = %d, want 1", len(info.VideoAdaptations)) + } +} + +func TestParseMPDInvalid(t *testing.T) { + _, err := parseMPD([]byte("this is not xml at all {{{")) + if err == nil { + t.Error("expected error for invalid XML") + } +} + +func TestResolveMediaURLZeroPadded(t *testing.T) { + t.Parallel() + // ffmpeg uses $Number%05d$ for zero-padded segment numbers + tmpl := segmentTemplate{ + MediaPattern: "chunk-stream$RepresentationID$-$Number%05d$.m4s", + } + got := resolveMediaURL(tmpl, "0", 3, "http://localhost:8080/") + want := "http://localhost:8080/chunk-stream0-00003.m4s" + if got != want { + t.Errorf("resolveMediaURL zero-padded = %q, want %q", got, want) + } +} + +// Test MPD with contentType on AdaptationSet and mimeType on Representation +// (ffmpeg pattern) +const testMPDContentType = ` + + + + + + + + + + + + + +` + +func TestParseMPDContentTypeAttr(t *testing.T) { + t.Parallel() + info, err := parseMPD([]byte(testMPDContentType)) + if err != nil { + t.Fatalf("parseMPD: %v", err) + } + if !info.IsDynamic { + t.Error("expected dynamic MPD") + } + if len(info.VideoAdaptations) != 1 { + t.Fatalf("video adaptations = %d, want 1", len(info.VideoAdaptations)) + } + if len(info.AudioAdaptations) != 1 { + t.Fatalf("audio adaptations = %d, want 1", len(info.AudioAdaptations)) + } + videoRep := info.VideoAdaptations[0].Representations[0] + if videoRep.Codecs != "av01.0.05M.08" { + t.Errorf("video codecs = %q, want av01.0.05M.08", videoRep.Codecs) + } +} diff --git a/ingest/dash/pipeline.go b/ingest/dash/pipeline.go new file mode 100644 index 0000000..33a8275 --- /dev/null +++ b/ingest/dash/pipeline.go @@ -0,0 +1,147 @@ +package dash + +import ( + "sync/atomic" + "time" + + "github.com/zsiec/prism/distribution" + "github.com/zsiec/prism/media" +) + +// Compile-time interface check. +var _ distribution.StatsProvider = (*dashPipeline)(nil) + +// broadcaster is the subset of distribution.Relay the DASH pipeline uses. +// It omits BroadcastCaptions because DASH sources do not carry CEA-608/708. +type broadcaster interface { + BroadcastVideo(frame *media.VideoFrame) + BroadcastAudio(frame *media.AudioFrame) + SetVideoInfo(info distribution.VideoInfo) + SetAudioTrackCount(count int) + AudioTrackCount() int + SetAudioInfo(info distribution.AudioInfo) + ViewerCount() int + ViewerStatsAll() []distribution.ViewerStats +} + +// dashPipeline bridges fMP4 mediaSamples to the frame-based Relay for fan-out +// to MoQ viewers. It converts decoded DASH segments into media.VideoFrame and +// media.AudioFrame values and broadcasts them through the Relay. +type dashPipeline struct { + streamKey string + relay broadcaster + seqHdrOBU []byte // AV1 sequence header OBU, prepended to keyframes as SPS + startTime time.Time + groupID uint32 + + videoForwarded atomic.Int64 + audioForwarded atomic.Int64 +} + +// newDASHPipeline creates a pipeline that converts DASH mediaSamples into +// VideoFrame/AudioFrame values and broadcasts them via relay. The seqHdrOBU +// is the AV1 sequence header OBU extracted from the init segment, attached +// to keyframes so downstream decoders can initialize. +func newDASHPipeline(streamKey string, relay broadcaster, seqHdrOBU []byte) *dashPipeline { + return &dashPipeline{ + streamKey: streamKey, + relay: relay, + seqHdrOBU: seqHdrOBU, + startTime: time.Now(), + } +} + +// processVideoSamples converts video mediaSamples into VideoFrames and +// broadcasts them through the relay. Keyframes start a new group and carry +// the AV1 sequence header OBU as SPS so that late-joining decoders can +// configure themselves. +// +// Frames are paced based on PTS deltas to avoid bursting an entire segment +// worth of frames at once (which causes jittery playback at the viewer). +func (p *dashPipeline) processVideoSamples(samples []mediaSample) { + if len(samples) == 0 { + return + } + + basePTS := samples[0].PTS + wallStart := time.Now() + + for i := range samples { + s := &samples[i] + + // Pace: sleep until this frame's presentation time relative to the + // first frame in the batch, so frames arrive at ~real-time cadence. + if i > 0 { + ptsDelta := time.Duration(s.PTS-basePTS) * time.Microsecond + wallElapsed := time.Since(wallStart) + if sleep := ptsDelta - wallElapsed; sleep > 0 { + time.Sleep(sleep) + } + } + + frame := &media.VideoFrame{ + PTS: s.PTS, + DTS: s.DTS, + IsKeyframe: s.IsKeyframe, + Codec: "av1", + WireData: s.Data, + } + + if s.IsKeyframe { + p.groupID++ + frame.SPS = p.seqHdrOBU + } + frame.GroupID = p.groupID + + p.relay.BroadcastVideo(frame) + p.videoForwarded.Add(1) + } +} + +// processAudioSamples converts audio mediaSamples into AudioFrames and +// broadcasts them through the relay. Paced by PTS like video. +func (p *dashPipeline) processAudioSamples(samples []mediaSample) { + if len(samples) == 0 { + return + } + + basePTS := samples[0].PTS + wallStart := time.Now() + + for i := range samples { + s := &samples[i] + + if i > 0 { + ptsDelta := time.Duration(s.PTS-basePTS) * time.Microsecond + wallElapsed := time.Since(wallStart) + if sleep := ptsDelta - wallElapsed; sleep > 0 { + time.Sleep(sleep) + } + } + + frame := &media.AudioFrame{ + PTS: s.PTS, + Data: s.Data, + } + + p.relay.BroadcastAudio(frame) + p.audioForwarded.Add(1) + } +} + +// StreamSnapshot returns a point-in-time snapshot of stream health metrics, +// implementing the distribution.StatsProvider interface for the stats overlay +// and REST API. +func (p *dashPipeline) StreamSnapshot() distribution.StreamSnapshot { + return distribution.StreamSnapshot{ + Timestamp: time.Now().UnixMilli(), + UptimeMs: time.Since(p.startTime).Milliseconds(), + Protocol: "DASH", + Video: distribution.VideoStats{ + Codec: "AV1", + TotalFrames: p.videoForwarded.Load(), + }, + ViewerCount: p.relay.ViewerCount(), + Viewers: p.relay.ViewerStatsAll(), + } +} diff --git a/ingest/dash/pipeline_test.go b/ingest/dash/pipeline_test.go new file mode 100644 index 0000000..9d071b5 --- /dev/null +++ b/ingest/dash/pipeline_test.go @@ -0,0 +1,266 @@ +package dash + +import ( + "testing" + "time" + + "github.com/zsiec/prism/distribution" + "github.com/zsiec/prism/media" +) + +// mockBroadcaster captures all calls made by the dashPipeline for test assertions. +type mockBroadcaster struct { + videoFrames []*media.VideoFrame + audioFrames []*media.AudioFrame + videoInfo *distribution.VideoInfo + audioInfo *distribution.AudioInfo + trackCount int +} + +func (m *mockBroadcaster) BroadcastVideo(frame *media.VideoFrame) { + m.videoFrames = append(m.videoFrames, frame) +} + +func (m *mockBroadcaster) BroadcastAudio(frame *media.AudioFrame) { + m.audioFrames = append(m.audioFrames, frame) +} + +func (m *mockBroadcaster) SetVideoInfo(info distribution.VideoInfo) { + m.videoInfo = &info +} + +func (m *mockBroadcaster) SetAudioTrackCount(count int) { + m.trackCount = count +} + +func (m *mockBroadcaster) AudioTrackCount() int { + if m.trackCount == 0 { + return 1 + } + return m.trackCount +} + +func (m *mockBroadcaster) SetAudioInfo(info distribution.AudioInfo) { + m.audioInfo = &info +} + +func (m *mockBroadcaster) ViewerCount() int { + return 0 +} + +func (m *mockBroadcaster) ViewerStatsAll() []distribution.ViewerStats { + return nil +} + +func TestDASHPipelineProcessVideoSamples(t *testing.T) { + t.Parallel() + + mock := &mockBroadcaster{} + seqHdr := []byte{0x0A, 0x0B, 0x0C} + p := newDASHPipeline("test-stream", mock, seqHdr) + + samples := []mediaSample{ + {PTS: 0, DTS: 0, Data: []byte{0x01, 0x02}, IsKeyframe: true}, + {PTS: 33333, DTS: 33333, Data: []byte{0x03, 0x04}, IsKeyframe: false}, + {PTS: 66666, DTS: 66666, Data: []byte{0x05, 0x06}, IsKeyframe: false}, + } + + p.processVideoSamples(samples) + + if len(mock.videoFrames) != 3 { + t.Fatalf("expected 3 video frames, got %d", len(mock.videoFrames)) + } + + // All frames should have codec "av1". + for i, f := range mock.videoFrames { + if f.Codec != "av1" { + t.Errorf("frame %d: codec = %q, want %q", i, f.Codec, "av1") + } + } + + // First frame: keyframe with SPS set. + kf := mock.videoFrames[0] + if !kf.IsKeyframe { + t.Error("frame 0: expected IsKeyframe=true") + } + if kf.SPS == nil { + t.Error("frame 0: expected SPS to be set for keyframe") + } + if len(kf.SPS) != 3 || kf.SPS[0] != 0x0A { + t.Errorf("frame 0: SPS = %v, want %v", kf.SPS, seqHdr) + } + + // Delta frames: no SPS. + for i := 1; i < 3; i++ { + df := mock.videoFrames[i] + if df.IsKeyframe { + t.Errorf("frame %d: expected IsKeyframe=false", i) + } + if df.SPS != nil { + t.Errorf("frame %d: expected SPS=nil for delta frame, got %v", i, df.SPS) + } + } + + // GroupID should be 1 (incremented once for the keyframe). + if kf.GroupID != 1 { + t.Errorf("frame 0: GroupID = %d, want 1", kf.GroupID) + } + for i := 1; i < 3; i++ { + if mock.videoFrames[i].GroupID != 1 { + t.Errorf("frame %d: GroupID = %d, want 1", i, mock.videoFrames[i].GroupID) + } + } + + // WireData should match sample data. + for i, s := range samples { + f := mock.videoFrames[i] + if len(f.WireData) != len(s.Data) { + t.Errorf("frame %d: WireData length = %d, want %d", i, len(f.WireData), len(s.Data)) + } + for j := range s.Data { + if f.WireData[j] != s.Data[j] { + t.Errorf("frame %d: WireData[%d] = %d, want %d", i, j, f.WireData[j], s.Data[j]) + } + } + } + + // Forwarded counter. + if got := p.videoForwarded.Load(); got != 3 { + t.Errorf("videoForwarded = %d, want 3", got) + } +} + +func TestDASHPipelineProcessAudioSamples(t *testing.T) { + t.Parallel() + + mock := &mockBroadcaster{} + p := newDASHPipeline("test-stream", mock, nil) + + samples := []mediaSample{ + {PTS: 0, Data: []byte{0xAA, 0xBB}}, + {PTS: 21333, Data: []byte{0xCC, 0xDD, 0xEE}}, + } + + p.processAudioSamples(samples) + + if len(mock.audioFrames) != 2 { + t.Fatalf("expected 2 audio frames, got %d", len(mock.audioFrames)) + } + + for i, s := range samples { + f := mock.audioFrames[i] + if f.PTS != s.PTS { + t.Errorf("frame %d: PTS = %d, want %d", i, f.PTS, s.PTS) + } + if len(f.Data) != len(s.Data) { + t.Errorf("frame %d: Data length = %d, want %d", i, len(f.Data), len(s.Data)) + continue + } + for j := range s.Data { + if f.Data[j] != s.Data[j] { + t.Errorf("frame %d: Data[%d] = %d, want %d", i, j, f.Data[j], s.Data[j]) + } + } + } + + if got := p.audioForwarded.Load(); got != 2 { + t.Errorf("audioForwarded = %d, want 2", got) + } +} + +func TestDASHPipelineGroupIDIncrement(t *testing.T) { + t.Parallel() + + mock := &mockBroadcaster{} + seqHdr := []byte{0x01} + p := newDASHPipeline("test-stream", mock, seqHdr) + + // First GOP: keyframe + delta. + p.processVideoSamples([]mediaSample{ + {PTS: 0, DTS: 0, IsKeyframe: true, Data: []byte{0x10}}, + {PTS: 33333, DTS: 33333, IsKeyframe: false, Data: []byte{0x11}}, + }) + + // Second GOP: another keyframe + delta. + p.processVideoSamples([]mediaSample{ + {PTS: 66666, DTS: 66666, IsKeyframe: true, Data: []byte{0x20}}, + {PTS: 99999, DTS: 99999, IsKeyframe: false, Data: []byte{0x21}}, + }) + + // Third GOP: yet another keyframe. + p.processVideoSamples([]mediaSample{ + {PTS: 133332, DTS: 133332, IsKeyframe: true, Data: []byte{0x30}}, + }) + + if len(mock.videoFrames) != 5 { + t.Fatalf("expected 5 video frames, got %d", len(mock.videoFrames)) + } + + // Group IDs: first keyframe=1, second keyframe=2, third keyframe=3. + expectedGroupIDs := []uint32{1, 1, 2, 2, 3} + for i, f := range mock.videoFrames { + if f.GroupID != expectedGroupIDs[i] { + t.Errorf("frame %d: GroupID = %d, want %d", i, f.GroupID, expectedGroupIDs[i]) + } + } + + // All keyframes should have SPS. + keyframeIndices := []int{0, 2, 4} + for _, idx := range keyframeIndices { + if mock.videoFrames[idx].SPS == nil { + t.Errorf("frame %d (keyframe): expected SPS to be set", idx) + } + } + + // All delta frames should not have SPS. + deltaIndices := []int{1, 3} + for _, idx := range deltaIndices { + if mock.videoFrames[idx].SPS != nil { + t.Errorf("frame %d (delta): expected SPS=nil", idx) + } + } +} + +func TestDASHPipelineStreamSnapshot(t *testing.T) { + t.Parallel() + + mock := &mockBroadcaster{} + p := newDASHPipeline("test-stream", mock, nil) + + // Push startTime back so UptimeMs is reliably > 0. + p.startTime = p.startTime.Add(-10 * time.Millisecond) + + // Process some samples so counters are non-zero. + p.processVideoSamples([]mediaSample{ + {PTS: 0, DTS: 0, IsKeyframe: true, Data: []byte{0x01}}, + }) + p.processAudioSamples([]mediaSample{ + {PTS: 0, Data: []byte{0x02}}, + }) + + snap := p.StreamSnapshot() + + if snap.Protocol != "DASH" { + t.Errorf("Protocol = %q, want %q", snap.Protocol, "DASH") + } + + if snap.UptimeMs <= 0 { + t.Errorf("UptimeMs = %d, want > 0", snap.UptimeMs) + } + + if snap.Timestamp <= 0 { + t.Errorf("Timestamp = %d, want > 0", snap.Timestamp) + } + + if snap.Video.Codec != "AV1" { + t.Errorf("Video.Codec = %q, want %q", snap.Video.Codec, "AV1") + } + + if snap.Video.TotalFrames != 1 { + t.Errorf("Video.TotalFrames = %d, want 1", snap.Video.TotalFrames) + } + + if snap.ViewerCount != 0 { + t.Errorf("ViewerCount = %d, want 0", snap.ViewerCount) + } +} diff --git a/ingest/dash/puller.go b/ingest/dash/puller.go new file mode 100644 index 0000000..83133b4 --- /dev/null +++ b/ingest/dash/puller.go @@ -0,0 +1,389 @@ +package dash + +import ( + "context" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/zsiec/prism/demux" + "github.com/zsiec/prism/distribution" + "github.com/zsiec/prism/moq" +) + +// PullRequest describes a DASH source to pull from. +type PullRequest struct { + URL string `json:"url"` + StreamKey string `json:"streamKey"` + VideoRepID string `json:"videoRepresentationId,omitempty"` + AudioRepID string `json:"audioRepresentationId,omitempty"` +} + +// DistributionServer is the subset of distribution.Server needed by the puller. +type DistributionServer interface { + RegisterStream(key string) *distribution.Relay + UnregisterStream(key string) + SetPipeline(key string, p distribution.StatsProvider) +} + +// StreamManager is the subset of stream.Manager needed by the puller. +type StreamManager interface { + Create(key string) (interface{}, bool) + Remove(key string) +} + +type activePull struct { + req PullRequest + cancel context.CancelFunc +} + +// Puller manages active DASH pull connections. +type Puller struct { + log *slog.Logger + client *http.Client + mu sync.Mutex + pulls map[string]*activePull +} + +// NewPuller creates a Puller with a default HTTP client. If log is nil, +// slog.Default() is used. +func NewPuller(log *slog.Logger) *Puller { + if log == nil { + log = slog.Default() + } + return &Puller{ + log: log.With("component", "dash-puller"), + client: &http.Client{ + Timeout: 30 * time.Second, + }, + pulls: make(map[string]*activePull), + } +} + +// Pull validates the request, fetches the MPD manifest synchronously, and +// starts a background goroutine that continuously fetches and processes +// DASH segments. It returns immediately after the MPD is validated. +func (p *Puller) Pull(ctx context.Context, req PullRequest, distSrv DistributionServer, mgr StreamManager) error { + if req.URL == "" { + return fmt.Errorf("url is required") + } + if req.StreamKey == "" { + return fmt.Errorf("streamKey is required") + } + + p.mu.Lock() + if _, exists := p.pulls[req.StreamKey]; exists { + p.mu.Unlock() + return fmt.Errorf("pull already active for stream key %q", req.StreamKey) + } + p.mu.Unlock() + + // Fetch and validate MPD synchronously so the caller gets immediate + // feedback on bad URLs or malformed manifests. + p.log.Info("fetching MPD", "url", req.URL, "stream_key", req.StreamKey) + mpdData, err := p.fetchURL(ctx, req.URL) + if err != nil { + return fmt.Errorf("fetch MPD: %w", err) + } + + info, err := parseMPD(mpdData) + if err != nil { + return fmt.Errorf("parse MPD: %w", err) + } + + if !info.IsDynamic { + return fmt.Errorf("MPD is not dynamic (live); static MPDs are not supported") + } + + // Select video representation. + videoRep, videoTmpl, err := selectRepresentation(info.VideoAdaptations, req.VideoRepID) + if err != nil { + return fmt.Errorf("select video representation: %w", err) + } + + // Select audio representation (optional — some streams are video-only). + var audioRep representationInfo + var audioTmpl segmentTemplate + hasAudio := len(info.AudioAdaptations) > 0 + if hasAudio { + audioRep, audioTmpl, err = selectRepresentation(info.AudioAdaptations, req.AudioRepID) + if err != nil { + return fmt.Errorf("select audio representation: %w", err) + } + } + + // Double-check for duplicate before committing. + pullCtx, cancel := context.WithCancel(ctx) + + p.mu.Lock() + if _, exists := p.pulls[req.StreamKey]; exists { + p.mu.Unlock() + cancel() + return fmt.Errorf("pull already active for stream key %q", req.StreamKey) + } + p.pulls[req.StreamKey] = &activePull{req: req, cancel: cancel} + p.mu.Unlock() + + p.log.Info("starting DASH pull", + "stream_key", req.StreamKey, + "video_rep", videoRep.ID, + "video_dims", fmt.Sprintf("%dx%d", videoRep.Width, videoRep.Height), + "has_audio", hasAudio, + ) + + go p.runPullLoop(pullCtx, req, info, videoRep, videoTmpl, audioRep, audioTmpl, hasAudio, distSrv, mgr) + + return nil +} + +// Stop cancels an active pull by stream key. +func (p *Puller) Stop(streamKey string) error { + p.mu.Lock() + ap, ok := p.pulls[streamKey] + p.mu.Unlock() + + if !ok { + return fmt.Errorf("no active pull for stream key %q", streamKey) + } + + ap.cancel() + return nil +} + +// ActivePulls returns a snapshot of all active pull requests. +func (p *Puller) ActivePulls() []PullRequest { + p.mu.Lock() + defer p.mu.Unlock() + + out := make([]PullRequest, 0, len(p.pulls)) + for _, ap := range p.pulls { + out = append(out, ap.req) + } + return out +} + +// fetchURL performs an HTTP GET with context support and returns the response body. +func (p *Puller) fetchURL(ctx context.Context, rawURL string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + resp, err := p.client.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP GET %s: %w", rawURL, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP GET %s: status %d", rawURL, resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response body: %w", err) + } + return body, nil +} + +// deriveBaseURL computes the base URL for resolving relative segment URLs. +// If the MPD contains a BaseURL, that is used; otherwise the MPD URL's +// directory is used. +func deriveBaseURL(mpdURL string, mpdBaseURL string) string { + if mpdBaseURL != "" { + return mpdBaseURL + } + // Strip filename from MPD URL to get directory. + u, err := url.Parse(mpdURL) + if err != nil { + return "" + } + idx := strings.LastIndex(u.Path, "/") + if idx >= 0 { + u.Path = u.Path[:idx+1] + } + return u.String() +} + +// runPullLoop is the background goroutine that continuously fetches DASH +// segments and feeds them into the distribution pipeline. +func (p *Puller) runPullLoop( + ctx context.Context, + req PullRequest, + info *mpdInfo, + videoRep representationInfo, + videoTmpl segmentTemplate, + audioRep representationInfo, + audioTmpl segmentTemplate, + hasAudio bool, + distSrv DistributionServer, + mgr StreamManager, +) { + defer func() { + distSrv.UnregisterStream(req.StreamKey) + if mgr != nil { + mgr.Remove(req.StreamKey) + } + p.mu.Lock() + delete(p.pulls, req.StreamKey) + p.mu.Unlock() + p.log.Info("DASH pull ended", "stream_key", req.StreamKey) + }() + + baseURL := deriveBaseURL(req.URL, info.BaseURL) + + // Download and parse video init segment. + videoInitURL := resolveInitURL(videoTmpl, videoRep.ID, baseURL) + videoInitData, err := p.fetchURL(ctx, videoInitURL) + if err != nil { + p.log.Error("fetch video init segment", "url", videoInitURL, "error", err) + return + } + + initInfo, err := parseInitSegment(videoInitData) + if err != nil { + p.log.Error("parse video init segment", "error", err) + return + } + + // If audio uses a different init URL, download that too. + if hasAudio { + audioInitURL := resolveInitURL(audioTmpl, audioRep.ID, baseURL) + if audioInitURL != videoInitURL { + audioInitData, err := p.fetchURL(ctx, audioInitURL) + if err != nil { + p.log.Warn("fetch audio init segment", "url", audioInitURL, "error", err) + hasAudio = false + } else { + audioInit, err := parseInitSegment(audioInitData) + if err != nil { + p.log.Warn("parse audio init segment", "error", err) + hasAudio = false + } else { + // Merge audio info into initInfo. + initInfo.AudioCodec = audioInit.AudioCodec + initInfo.SampleRate = audioInit.SampleRate + initInfo.Channels = audioInit.Channels + initInfo.AudioTimescale = audioInit.AudioTimescale + initInfo.AudioTrackID = audioInit.AudioTrackID + initInfo.audioTrex = audioInit.audioTrex + } + } + } + } + + // Register stream with distribution server. + relay := distSrv.RegisterStream(req.StreamKey) + if mgr != nil { + mgr.Create(req.StreamKey) + } + + // Set video info on relay with proper RFC 6381 codec string and decoder config. + vi := distribution.VideoInfo{ + Codec: initInfo.VideoCodec, + Width: initInfo.Width, + Height: initInfo.Height, + } + if initInfo.SeqHeaderOBU != nil { + if hdr, err := demux.ParseAV1SequenceHeader(initInfo.SeqHeaderOBU); err == nil { + vi.Codec = hdr.CodecString() + vi.Width = hdr.Width + vi.Height = hdr.Height + } + vi.DecoderConfig = moq.BuildAV1DecoderConfig(initInfo.SeqHeaderOBU) + } + relay.SetVideoInfo(vi) + + // Set audio info if available. + if hasAudio { + relay.SetAudioTrackCount(1) + relay.SetAudioInfo(distribution.AudioInfo{ + Codec: initInfo.AudioCodec, + SampleRate: initInfo.SampleRate, + Channels: initInfo.Channels, + }) + } + + // Create pipeline and register as stats provider. + pipeline := newDASHPipeline(req.StreamKey, relay, initInfo.SeqHeaderOBU) + distSrv.SetPipeline(req.StreamKey, pipeline) + + // Segment fetch loop. + lastSegNum := -1 + segmentDuration := time.Duration(0) + if videoTmpl.Timescale > 0 && videoTmpl.Duration > 0 { + segmentDuration = time.Duration(float64(videoTmpl.Duration) / float64(videoTmpl.Timescale) * float64(time.Second)) + } + + for { + if ctx.Err() != nil { + return + } + + segNum := computeSegmentNumber(time.Now(), info.AvailabilityStartTime, videoTmpl) + + if segNum == lastSegNum { + // Sleep until the next segment boundary. + sleepDur := segmentDuration / 4 + if sleepDur < 100*time.Millisecond { + sleepDur = 100 * time.Millisecond + } + select { + case <-ctx.Done(): + return + case <-time.After(sleepDur): + continue + } + } + lastSegNum = segNum + + // Fetch video segment. + videoMediaURL := resolveMediaURL(videoTmpl, videoRep.ID, segNum, baseURL) + videoData, err := p.fetchURL(ctx, videoMediaURL) + if err != nil { + if ctx.Err() != nil { + return + } + p.log.Warn("fetch video segment", "seg", segNum, "url", videoMediaURL, "error", err) + continue + } + + videoSamples, err := parseMediaSegment(videoData, initInfo, true) + if err != nil { + p.log.Warn("parse video segment", "seg", segNum, "error", err) + continue + } + // Fetch audio segment in parallel with video pacing. + var audioSamples []mediaSample + if hasAudio { + audioSegNum := computeSegmentNumber(time.Now(), info.AvailabilityStartTime, audioTmpl) + audioMediaURL := resolveMediaURL(audioTmpl, audioRep.ID, audioSegNum, baseURL) + audioData, err := p.fetchURL(ctx, audioMediaURL) + if err != nil { + if ctx.Err() != nil { + return + } + p.log.Warn("fetch audio segment", "seg", audioSegNum, "url", audioMediaURL, "error", err) + } else { + audioSamples, err = parseMediaSegment(audioData, initInfo, false) + if err != nil { + p.log.Warn("parse audio segment", "seg", audioSegNum, "error", err) + } + } + } + + // Pace video and audio concurrently so they interleave naturally. + done := make(chan struct{}) + go func() { + pipeline.processAudioSamples(audioSamples) + close(done) + }() + pipeline.processVideoSamples(videoSamples) + <-done + } +} diff --git a/ingest/dash/puller_test.go b/ingest/dash/puller_test.go new file mode 100644 index 0000000..d6406ba --- /dev/null +++ b/ingest/dash/puller_test.go @@ -0,0 +1,451 @@ +package dash + +import ( + "context" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/zsiec/prism/distribution" +) + +func TestNewPuller(t *testing.T) { + t.Parallel() + + t.Run("with logger", func(t *testing.T) { + log := slog.Default() + p := NewPuller(log) + if p == nil { + t.Fatal("NewPuller returned nil") + } + if p.pulls == nil { + t.Error("pulls map is nil") + } + if p.client == nil { + t.Error("client is nil") + } + }) + + t.Run("nil logger uses default", func(t *testing.T) { + p := NewPuller(nil) + if p == nil { + t.Fatal("NewPuller returned nil") + } + if p.log == nil { + t.Error("log is nil") + } + }) +} + +func TestPullerValidation(t *testing.T) { + t.Parallel() + + p := NewPuller(slog.Default()) + + t.Run("empty URL", func(t *testing.T) { + err := p.Pull(context.Background(), PullRequest{ + StreamKey: "test", + }, nil, nil) + if err == nil { + t.Error("expected error for empty URL") + } + }) + + t.Run("empty StreamKey", func(t *testing.T) { + err := p.Pull(context.Background(), PullRequest{ + URL: "http://example.com/manifest.mpd", + }, nil, nil) + if err == nil { + t.Error("expected error for empty StreamKey") + } + }) +} + +func TestPullerDuplicateStreamKey(t *testing.T) { + t.Parallel() + + p := NewPuller(slog.Default()) + + // Insert a fake active pull directly into the map. + p.mu.Lock() + p.pulls["test-stream"] = &activePull{ + req: PullRequest{URL: "http://example.com/dash.mpd", StreamKey: "test-stream"}, + cancel: func() {}, + } + p.mu.Unlock() + + err := p.Pull(context.Background(), PullRequest{ + URL: "http://example.com/dash.mpd", + StreamKey: "test-stream", + }, nil, nil) + if err == nil { + t.Error("expected error for duplicate stream key") + } +} + +func TestPullerStopNonexistent(t *testing.T) { + t.Parallel() + + p := NewPuller(slog.Default()) + err := p.Stop("nonexistent") + if err == nil { + t.Error("expected error for nonexistent stream key") + } +} + +func TestPullerActivePulls(t *testing.T) { + t.Parallel() + + p := NewPuller(slog.Default()) + + // Empty initially. + pulls := p.ActivePulls() + if len(pulls) != 0 { + t.Errorf("expected 0 active pulls, got %d", len(pulls)) + } + + // Insert fake pulls. + p.mu.Lock() + p.pulls["stream-a"] = &activePull{ + req: PullRequest{URL: "http://a.com/dash.mpd", StreamKey: "stream-a"}, + cancel: func() {}, + } + p.pulls["stream-b"] = &activePull{ + req: PullRequest{URL: "http://b.com/dash.mpd", StreamKey: "stream-b"}, + cancel: func() {}, + } + p.mu.Unlock() + + pulls = p.ActivePulls() + if len(pulls) != 2 { + t.Errorf("expected 2 active pulls, got %d", len(pulls)) + } + + // Verify both keys are present. + keys := map[string]bool{} + for _, pr := range pulls { + keys[pr.StreamKey] = true + } + if !keys["stream-a"] || !keys["stream-b"] { + t.Errorf("unexpected keys: %v", keys) + } +} + +func TestPullerStop(t *testing.T) { + t.Parallel() + + p := NewPuller(slog.Default()) + + cancelled := false + p.mu.Lock() + p.pulls["test-stream"] = &activePull{ + req: PullRequest{URL: "http://example.com/dash.mpd", StreamKey: "test-stream"}, + cancel: func() { cancelled = true }, + } + p.mu.Unlock() + + err := p.Stop("test-stream") + if err != nil { + t.Fatalf("Stop returned error: %v", err) + } + if !cancelled { + t.Error("cancel function was not called") + } +} + +func TestDeriveBaseURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mpdURL string + mpdBaseURL string + want string + }{ + { + name: "MPD has BaseURL", + mpdURL: "http://cdn.example.com/live/manifest.mpd", + mpdBaseURL: "http://cdn.example.com/live/", + want: "http://cdn.example.com/live/", + }, + { + name: "derive from MPD URL", + mpdURL: "http://cdn.example.com/live/manifest.mpd", + mpdBaseURL: "", + want: "http://cdn.example.com/live/", + }, + { + name: "MPD URL with deeper path", + mpdURL: "http://cdn.example.com/a/b/c/dash.mpd", + mpdBaseURL: "", + want: "http://cdn.example.com/a/b/c/", + }, + { + name: "MPD URL at root", + mpdURL: "http://cdn.example.com/manifest.mpd", + mpdBaseURL: "", + want: "http://cdn.example.com/", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := deriveBaseURL(tt.mpdURL, tt.mpdBaseURL) + if got != tt.want { + t.Errorf("deriveBaseURL(%q, %q) = %q, want %q", tt.mpdURL, tt.mpdBaseURL, got, tt.want) + } + }) + } +} + +func TestPullerFetchURL(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/ok": + w.WriteHeader(http.StatusOK) + w.Write([]byte("hello")) + case "/error": + w.WriteHeader(http.StatusInternalServerError) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + p := NewPuller(slog.Default()) + + t.Run("success", func(t *testing.T) { + data, err := p.fetchURL(context.Background(), srv.URL+"/ok") + if err != nil { + t.Fatalf("fetchURL: %v", err) + } + if string(data) != "hello" { + t.Errorf("body = %q, want %q", string(data), "hello") + } + }) + + t.Run("server error", func(t *testing.T) { + _, err := p.fetchURL(context.Background(), srv.URL+"/error") + if err == nil { + t.Error("expected error for 500 status") + } + }) + + t.Run("cancelled context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := p.fetchURL(ctx, srv.URL+"/ok") + if err == nil { + t.Error("expected error for cancelled context") + } + }) +} + +// mockDistributionServer implements DistributionServer for testing. +type mockDistributionServer struct { + registered map[string]*distribution.Relay + unregistered []string + pipelines map[string]distribution.StatsProvider +} + +func newMockDistributionServer() *mockDistributionServer { + return &mockDistributionServer{ + registered: make(map[string]*distribution.Relay), + pipelines: make(map[string]distribution.StatsProvider), + } +} + +func (m *mockDistributionServer) RegisterStream(key string) *distribution.Relay { + r := distribution.NewRelay() + m.registered[key] = r + return r +} + +func (m *mockDistributionServer) UnregisterStream(key string) { + m.unregistered = append(m.unregistered, key) + delete(m.registered, key) +} + +func (m *mockDistributionServer) SetPipeline(key string, p distribution.StatsProvider) { + m.pipelines[key] = p +} + +// mockStreamManager implements StreamManager for testing. +type mockStreamManager struct { + created []string + removed []string +} + +func (m *mockStreamManager) Create(key string) (interface{}, bool) { + m.created = append(m.created, key) + return nil, true +} + +func (m *mockStreamManager) Remove(key string) { + m.removed = append(m.removed, key) +} + +func TestPullerPullStaticMPDRejected(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(testMPDStatic)) + })) + defer srv.Close() + + p := NewPuller(slog.Default()) + distSrv := newMockDistributionServer() + + err := p.Pull(context.Background(), PullRequest{ + URL: srv.URL + "/manifest.mpd", + StreamKey: "test-static", + }, distSrv, nil) + + if err == nil { + t.Fatal("expected error for static MPD") + } + + // No stream should have been registered. + if len(distSrv.registered) != 0 { + t.Errorf("expected 0 registered streams, got %d", len(distSrv.registered)) + } +} + +func TestPullerPullBadURL(t *testing.T) { + t.Parallel() + + p := NewPuller(slog.Default()) + distSrv := newMockDistributionServer() + + err := p.Pull(context.Background(), PullRequest{ + URL: "http://localhost:1/nonexistent.mpd", + StreamKey: "test-bad", + }, distSrv, nil) + + if err == nil { + t.Fatal("expected error for unreachable URL") + } +} + +func TestPullerPullWithHTTP(t *testing.T) { + t.Parallel() + + // Serve a dynamic MPD. The pull loop will try to fetch init segments, + // which will fail, causing the loop to exit. We verify the lifecycle: + // stream registered, then cleaned up after the loop exits. + requestCount := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/manifest.mpd": + w.WriteHeader(http.StatusOK) + w.Write([]byte(testMPDDynamic)) + requestCount++ + default: + // Init/media segments: return 404 so the loop exits. + w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + p := NewPuller(slog.Default()) + distSrv := newMockDistributionServer() + mgr := &mockStreamManager{} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err := p.Pull(ctx, PullRequest{ + URL: srv.URL + "/manifest.mpd", + StreamKey: "test-live", + }, distSrv, mgr) + + if err != nil { + t.Fatalf("Pull returned error: %v", err) + } + + // The pull should be registered immediately. + pulls := p.ActivePulls() + if len(pulls) != 1 { + t.Fatalf("expected 1 active pull, got %d", len(pulls)) + } + if pulls[0].StreamKey != "test-live" { + t.Errorf("stream key = %q, want %q", pulls[0].StreamKey, "test-live") + } + + // The background goroutine will fail on init segment fetch (404) and + // clean up. Wait briefly for that. + time.Sleep(500 * time.Millisecond) + + // After loop exits, pull should be removed from active list. + pulls = p.ActivePulls() + if len(pulls) != 0 { + t.Errorf("expected 0 active pulls after loop exit, got %d", len(pulls)) + } + + // Distribution server should have been unregistered. + if len(distSrv.unregistered) == 0 { + t.Error("expected stream to be unregistered after loop exit") + } + + // Stream manager should have been cleaned up. + if len(mgr.removed) == 0 { + t.Error("expected stream manager Remove to be called") + } +} + +func TestPullerPullAndStop(t *testing.T) { + t.Parallel() + + // Serve a dynamic MPD and valid-ish init segments that won't parse + // as fMP4 — the loop will exit on parse error. Instead, we'll use + // a server that delays responses so we can test Stop(). + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/manifest.mpd": + w.WriteHeader(http.StatusOK) + w.Write([]byte(testMPDDynamic)) + default: + // Block until the request context is cancelled (simulating + // a slow CDN so we can test Stop cancellation). + <-r.Context().Done() + w.WriteHeader(http.StatusServiceUnavailable) + } + })) + defer srv.Close() + + p := NewPuller(slog.Default()) + distSrv := newMockDistributionServer() + + ctx := context.Background() + + err := p.Pull(ctx, PullRequest{ + URL: srv.URL + "/manifest.mpd", + StreamKey: "stop-test", + }, distSrv, nil) + if err != nil { + t.Fatalf("Pull returned error: %v", err) + } + + // Verify active. + if len(p.ActivePulls()) != 1 { + t.Fatalf("expected 1 active pull, got %d", len(p.ActivePulls())) + } + + // Stop the pull. + err = p.Stop("stop-test") + if err != nil { + t.Fatalf("Stop returned error: %v", err) + } + + // Wait for cleanup. + time.Sleep(500 * time.Millisecond) + + if len(p.ActivePulls()) != 0 { + t.Errorf("expected 0 active pulls after Stop, got %d", len(p.ActivePulls())) + } +} diff --git a/ingest/dash/segment.go b/ingest/dash/segment.go new file mode 100644 index 0000000..eadf175 --- /dev/null +++ b/ingest/dash/segment.go @@ -0,0 +1,232 @@ +package dash + +import ( + "bytes" + "fmt" + + "github.com/Eyevinn/mp4ff/mp4" + "github.com/zsiec/prism/demux" +) + +// initSegmentInfo holds codec and track metadata extracted from an fMP4 init +// segment (ftyp + moov). +type initSegmentInfo struct { + VideoCodec string // "av1", "h264", etc. + Width int + Height int + SeqHeaderOBU []byte // AV1: sequence header OBU from av1C ConfigOBUs + AudioCodec string // e.g. "mp4a.40.2" + SampleRate int + Channels int + VideoTimescale uint32 + AudioTimescale uint32 + VideoTrackID uint32 + AudioTrackID uint32 + videoTrex *mp4.TrexBox // stored for GetFullSamples + audioTrex *mp4.TrexBox +} + +// mediaSample is a decoded audio or video sample with timestamps in +// microseconds. +type mediaSample struct { + PTS int64 // microseconds + DTS int64 // microseconds + Data []byte + IsKeyframe bool +} + +// parseInitSegment decodes an fMP4 init segment (ftyp+moov) and extracts +// video/audio track metadata including codec configuration. +func parseInitSegment(data []byte) (*initSegmentInfo, error) { + if len(data) == 0 { + return nil, fmt.Errorf("empty init segment data") + } + + parsed, err := mp4.DecodeFile(bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("decode init segment: %w", err) + } + + // The moov box can be at parsed.Moov (when DecodeFile sees ftyp+moov) + // or at parsed.Init.Moov (when it's recognized as a fragmented init). + moov := parsed.Moov + if moov == nil && parsed.Init != nil { + moov = parsed.Init.Moov + } + if moov == nil { + return nil, fmt.Errorf("no moov box found in init segment") + } + + if len(moov.Traks) == 0 { + return nil, fmt.Errorf("no tracks found in init segment") + } + + info := &initSegmentInfo{} + + for _, trak := range moov.Traks { + if trak.Mdia == nil || trak.Mdia.Hdlr == nil { + continue + } + + trackID := trak.Tkhd.TrackID + handler := trak.Mdia.Hdlr.HandlerType + + switch handler { + case "vide": + info.VideoTrackID = trackID + if trak.Mdia.Mdhd != nil { + info.VideoTimescale = trak.Mdia.Mdhd.Timescale + } + if err := extractVideoInfo(trak, info); err != nil { + return nil, fmt.Errorf("extract video info: %w", err) + } + case "soun": + info.AudioTrackID = trackID + if trak.Mdia.Mdhd != nil { + info.AudioTimescale = trak.Mdia.Mdhd.Timescale + } + if err := extractAudioInfo(trak, info); err != nil { + return nil, fmt.Errorf("extract audio info: %w", err) + } + } + } + + // Resolve trex boxes from mvex. + if moov.Mvex != nil { + resolveTrex(moov.Mvex, info) + } + + return info, nil +} + +// extractVideoInfo populates video fields on info from the given video trak. +func extractVideoInfo(trak *mp4.TrakBox, info *initSegmentInfo) error { + stsd := getStsd(trak) + if stsd == nil { + return fmt.Errorf("no stsd box in video track") + } + + if stsd.Av01 != nil { + info.VideoCodec = "av1" + info.Width = int(stsd.Av01.Width) + info.Height = int(stsd.Av01.Height) + if stsd.Av01.Av1C != nil { + configOBUs := stsd.Av01.Av1C.ConfigOBUs + seqHdr := demux.FindSequenceHeaderOBU(configOBUs) + if seqHdr != nil { + info.SeqHeaderOBU = make([]byte, len(seqHdr)) + copy(info.SeqHeaderOBU, seqHdr) + } + } + } else if stsd.AvcX != nil { + info.VideoCodec = "h264" + info.Width = int(stsd.AvcX.Width) + info.Height = int(stsd.AvcX.Height) + } else { + return fmt.Errorf("unsupported video codec (no av01 or avcX in stsd)") + } + return nil +} + +// extractAudioInfo populates audio fields on info from the given audio trak. +func extractAudioInfo(trak *mp4.TrakBox, info *initSegmentInfo) error { + stsd := getStsd(trak) + if stsd == nil { + return fmt.Errorf("no stsd box in audio track") + } + + if stsd.Mp4a != nil { + info.AudioCodec = "mp4a.40.2" + info.SampleRate = int(stsd.Mp4a.SampleRate) + info.Channels = int(stsd.Mp4a.ChannelCount) + } else { + return fmt.Errorf("unsupported audio codec (no mp4a in stsd)") + } + return nil +} + +// getStsd returns the stsd box from a track, or nil if the path is incomplete. +func getStsd(trak *mp4.TrakBox) *mp4.StsdBox { + if trak.Mdia == nil || trak.Mdia.Minf == nil || trak.Mdia.Minf.Stbl == nil { + return nil + } + return trak.Mdia.Minf.Stbl.Stsd +} + +// resolveTrex finds TrexBox entries matching video and audio track IDs. +func resolveTrex(mvex *mp4.MvexBox, info *initSegmentInfo) { + // mp4ff stores a single trex in Trex and multiple in Trexs. + trexList := mvex.Trexs + if len(trexList) == 0 && mvex.Trex != nil { + trexList = []*mp4.TrexBox{mvex.Trex} + } + for _, trex := range trexList { + switch trex.TrackID { + case info.VideoTrackID: + info.videoTrex = trex + case info.AudioTrackID: + info.audioTrex = trex + } + } +} + +// parseMediaSegment decodes an fMP4 media segment and extracts samples for the +// specified track type. The init parameter must have been obtained from +// parseInitSegment on the corresponding init segment. +func parseMediaSegment(data []byte, init *initSegmentInfo, isVideo bool) ([]mediaSample, error) { + if len(data) == 0 { + return nil, fmt.Errorf("empty media segment data") + } + if init == nil { + return nil, fmt.Errorf("nil init segment info") + } + + parsed, err := mp4.DecodeFile(bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("decode media segment: %w", err) + } + + var trex *mp4.TrexBox + var timescale uint32 + if isVideo { + trex = init.videoTrex + timescale = init.VideoTimescale + } else { + trex = init.audioTrex + timescale = init.AudioTimescale + } + + if timescale == 0 { + return nil, fmt.Errorf("timescale is zero") + } + + var samples []mediaSample + + for _, seg := range parsed.Segments { + for _, frag := range seg.Fragments { + fullSamples, err := frag.GetFullSamples(trex) + if err != nil { + return nil, fmt.Errorf("get full samples: %w", err) + } + + for i := range fullSamples { + fs := &fullSamples[i] + ms := mediaSample{ + DTS: scaleToMicroseconds(fs.DecodeTime, timescale), + PTS: scaleToMicroseconds(uint64(fs.PresentationTime()), timescale), + Data: fs.Data, + IsKeyframe: fs.IsSync(), + } + samples = append(samples, ms) + } + } + } + + return samples, nil +} + +// scaleToMicroseconds converts a timestamp in the given timescale to +// microseconds. +func scaleToMicroseconds(ts uint64, timescale uint32) int64 { + return int64(ts * 1_000_000 / uint64(timescale)) +} diff --git a/ingest/dash/segment_test.go b/ingest/dash/segment_test.go new file mode 100644 index 0000000..1c14305 --- /dev/null +++ b/ingest/dash/segment_test.go @@ -0,0 +1,396 @@ +package dash + +import ( + "bytes" + "testing" + + "github.com/Eyevinn/mp4ff/aac" + "github.com/Eyevinn/mp4ff/av1" + "github.com/Eyevinn/mp4ff/mp4" +) + +func TestScaleToMicroseconds(t *testing.T) { + tests := []struct { + name string + ts uint64 + timescale uint32 + want int64 + }{ + {"zero", 0, 90000, 0}, + {"one second at 90kHz", 90000, 90000, 1_000_000}, + {"half second at 48kHz", 24000, 48000, 500_000}, + {"two seconds at 90kHz", 180000, 90000, 2_000_000}, + {"one frame at 30fps/90kHz", 3000, 90000, 33333}, + {"one sample at 44100", 1, 44100, 22}, // 1_000_000/44100 = 22.67 -> 22 + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := scaleToMicroseconds(tt.ts, tt.timescale) + if got != tt.want { + t.Errorf("scaleToMicroseconds(%d, %d) = %d, want %d", + tt.ts, tt.timescale, got, tt.want) + } + }) + } +} + +func TestParseInitSegmentNil(t *testing.T) { + _, err := parseInitSegment(nil) + if err == nil { + t.Error("expected error for nil data") + } + + _, err = parseInitSegment([]byte{}) + if err == nil { + t.Error("expected error for empty data") + } +} + +func TestParseInitSegmentInvalid(t *testing.T) { + _, err := parseInitSegment([]byte("not valid mp4 data at all")) + if err == nil { + t.Error("expected error for invalid data") + } +} + +func TestParseMediaSegmentNil(t *testing.T) { + _, err := parseMediaSegment(nil, &initSegmentInfo{}, true) + if err == nil { + t.Error("expected error for nil data") + } + + _, err = parseMediaSegment([]byte{}, &initSegmentInfo{}, true) + if err == nil { + t.Error("expected error for empty data") + } +} + +func TestParseMediaSegmentNilInit(t *testing.T) { + _, err := parseMediaSegment([]byte{0x00}, nil, true) + if err == nil { + t.Error("expected error for nil init info") + } +} + +// buildAV1InitSegment creates a minimal AV1+AAC init segment using mp4ff +// builder APIs. Returns the encoded bytes. +func buildAV1InitSegment(t *testing.T) []byte { + t.Helper() + + init := mp4.CreateEmptyInit() + + // Add video track (AV1, 1920x1080, 90kHz timescale) + videoTrak := init.AddEmptyTrack(90000, "video", "und") + videoTrak.Tkhd.Width = mp4.Fixed32(1920 << 16) + videoTrak.Tkhd.Height = mp4.Fixed32(1080 << 16) + + // Create an Av1C box with a minimal AV1 codec config + av1C := &mp4.Av1CBox{ + CodecConfRec: av1.CodecConfRec{ + Version: 1, + SeqProfile: 0, + SeqLevelIdx0: 9, + SeqTier0: 0, + HighBitdepth: 0, + // Build a minimal sequence header OBU for ConfigOBUs. + // OBU type 1 (seq header), has_size=1: + // header byte: 0b0_0001_0_1_0 = 0x0A + // size: LEB128 for 4 bytes = 0x04 + // payload: 4 bytes of minimal data + ConfigOBUs: buildMinimalSeqHeaderOBU(), + }, + } + + // Create AV1 visual sample entry and add to stsd + av01 := mp4.CreateVisualSampleEntryBox("av01", 1920, 1080, av1C) + videoTrak.Mdia.Minf.Stbl.Stsd.AddChild(av01) + + // Add audio track (AAC, 48kHz, stereo) + audioTrak := init.AddEmptyTrack(48000, "audio", "und") + err := audioTrak.SetAACDescriptor(aac.AAClc, 48000) + if err != nil { + t.Fatalf("SetAACDescriptor: %v", err) + } + + var buf bytes.Buffer + err = init.Encode(&buf) + if err != nil { + t.Fatalf("encode init segment: %v", err) + } + return buf.Bytes() +} + +// buildMinimalSeqHeaderOBU returns a minimal AV1 sequence header OBU with +// size field. This is not a fully valid sequence header but contains enough +// structure for FindSequenceHeaderOBU to locate it. +func buildMinimalSeqHeaderOBU() []byte { + // OBU header: type=1 (seq header), has_size=1 => 0x0A + // Size in LEB128: we'll use a small payload + // A minimal sequence_header_obu payload (not fully decodeable but + // structurally valid enough for FindSequenceHeaderOBU to find and return) + payload := []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, + } + obu := []byte{0x0A, byte(len(payload))} + obu = append(obu, payload...) + return obu +} + +func TestParseInitSegmentAV1(t *testing.T) { + data := buildAV1InitSegment(t) + + info, err := parseInitSegment(data) + if err != nil { + t.Fatalf("parseInitSegment: %v", err) + } + + if info.VideoCodec != "av1" { + t.Errorf("VideoCodec = %q, want av1", info.VideoCodec) + } + if info.Width != 1920 { + t.Errorf("Width = %d, want 1920", info.Width) + } + if info.Height != 1080 { + t.Errorf("Height = %d, want 1080", info.Height) + } + if info.VideoTimescale != 90000 { + t.Errorf("VideoTimescale = %d, want 90000", info.VideoTimescale) + } + if info.VideoTrackID == 0 { + t.Error("VideoTrackID should not be zero") + } + if info.SeqHeaderOBU == nil { + t.Error("SeqHeaderOBU should not be nil") + } + + // Audio track + if info.AudioCodec != "mp4a.40.2" { + t.Errorf("AudioCodec = %q, want mp4a.40.2", info.AudioCodec) + } + if info.SampleRate != 48000 { + t.Errorf("SampleRate = %d, want 48000", info.SampleRate) + } + if info.Channels != 2 { + t.Errorf("Channels = %d, want 2", info.Channels) + } + if info.AudioTimescale != 48000 { + t.Errorf("AudioTimescale = %d, want 48000", info.AudioTimescale) + } + if info.AudioTrackID == 0 { + t.Error("AudioTrackID should not be zero") + } + + // Trex boxes should be resolved + if info.videoTrex == nil { + t.Error("videoTrex should not be nil") + } + if info.audioTrex == nil { + t.Error("audioTrex should not be nil") + } +} + +// buildVideoMediaSegment creates a minimal fMP4 media segment with a single +// video fragment containing 3 samples. +func buildVideoMediaSegment(t *testing.T, trackID uint32) []byte { + t.Helper() + + seg := mp4.NewMediaSegment() + frag, err := mp4.CreateFragment(1, trackID) + if err != nil { + t.Fatalf("CreateFragment: %v", err) + } + + // Add 3 samples: first is a keyframe, others are not + sampleData := []byte{0x01, 0x02, 0x03, 0x04} + baseDecodeTime := uint64(0) + + // Keyframe sample + frag.AddFullSample(mp4.FullSample{ + Sample: mp4.Sample{ + Flags: mp4.SyncSampleFlags, + Dur: 3000, // 1 frame at 30fps in 90kHz + Size: uint32(len(sampleData)), + CompositionTimeOffset: 0, + }, + DecodeTime: baseDecodeTime, + Data: sampleData, + }) + + // Non-keyframe sample + frag.AddFullSample(mp4.FullSample{ + Sample: mp4.Sample{ + Flags: mp4.NonSyncSampleFlags, + Dur: 3000, + Size: uint32(len(sampleData)), + CompositionTimeOffset: 3000, // PTS > DTS + }, + DecodeTime: baseDecodeTime + 3000, + Data: sampleData, + }) + + // Another non-keyframe sample + frag.AddFullSample(mp4.FullSample{ + Sample: mp4.Sample{ + Flags: mp4.NonSyncSampleFlags, + Dur: 3000, + Size: uint32(len(sampleData)), + CompositionTimeOffset: 0, + }, + DecodeTime: baseDecodeTime + 6000, + Data: sampleData, + }) + + seg.AddFragment(frag) + + var buf bytes.Buffer + err = seg.Encode(&buf) + if err != nil { + t.Fatalf("encode media segment: %v", err) + } + return buf.Bytes() +} + +func TestParseMediaSegmentVideo(t *testing.T) { + // First build an init segment to get metadata + initData := buildAV1InitSegment(t) + info, err := parseInitSegment(initData) + if err != nil { + t.Fatalf("parseInitSegment: %v", err) + } + + // Build a video media segment + segData := buildVideoMediaSegment(t, info.VideoTrackID) + + samples, err := parseMediaSegment(segData, info, true) + if err != nil { + t.Fatalf("parseMediaSegment: %v", err) + } + + if len(samples) != 3 { + t.Fatalf("sample count = %d, want 3", len(samples)) + } + + // First sample: keyframe at DTS=0 + if !samples[0].IsKeyframe { + t.Error("sample[0] should be a keyframe") + } + if samples[0].DTS != 0 { + t.Errorf("sample[0].DTS = %d, want 0", samples[0].DTS) + } + if samples[0].PTS != 0 { + t.Errorf("sample[0].PTS = %d, want 0", samples[0].PTS) + } + if len(samples[0].Data) != 4 { + t.Errorf("sample[0].Data length = %d, want 4", len(samples[0].Data)) + } + + // Second sample: non-keyframe at DTS=3000/90000 * 1e6 = 33333us + if samples[1].IsKeyframe { + t.Error("sample[1] should not be a keyframe") + } + expectedDTS := scaleToMicroseconds(3000, 90000) + if samples[1].DTS != expectedDTS { + t.Errorf("sample[1].DTS = %d, want %d", samples[1].DTS, expectedDTS) + } + // PTS = (3000 + 3000) / 90000 * 1e6 = 66666us + expectedPTS := scaleToMicroseconds(6000, 90000) + if samples[1].PTS != expectedPTS { + t.Errorf("sample[1].PTS = %d, want %d", samples[1].PTS, expectedPTS) + } + + // Third sample: non-keyframe at DTS=6000/90000 * 1e6 = 66666us + expectedDTS3 := scaleToMicroseconds(6000, 90000) + if samples[2].DTS != expectedDTS3 { + t.Errorf("sample[2].DTS = %d, want %d", samples[2].DTS, expectedDTS3) + } +} + +func TestParseMediaSegmentZeroTimescale(t *testing.T) { + initInfo := &initSegmentInfo{ + VideoTimescale: 0, + videoTrex: &mp4.TrexBox{TrackID: 1}, + } + _, err := parseMediaSegment([]byte{0x00}, initInfo, true) + if err == nil { + t.Error("expected error for zero timescale") + } +} + +// buildAudioMediaSegment creates a minimal fMP4 media segment with audio +// samples. +func buildAudioMediaSegment(t *testing.T, trackID uint32) []byte { + t.Helper() + + seg := mp4.NewMediaSegment() + frag, err := mp4.CreateFragment(1, trackID) + if err != nil { + t.Fatalf("CreateFragment: %v", err) + } + + // Add 2 audio samples (AAC frames are all sync) + sampleData := make([]byte, 128) + for i := range sampleData { + sampleData[i] = byte(i) + } + + frag.AddFullSample(mp4.FullSample{ + Sample: mp4.Sample{ + Flags: mp4.SyncSampleFlags, + Dur: 1024, // typical AAC frame duration at 48kHz + Size: uint32(len(sampleData)), + }, + DecodeTime: 0, + Data: sampleData, + }) + + frag.AddFullSample(mp4.FullSample{ + Sample: mp4.Sample{ + Flags: mp4.SyncSampleFlags, + Dur: 1024, + Size: uint32(len(sampleData)), + }, + DecodeTime: 1024, + Data: sampleData, + }) + + seg.AddFragment(frag) + + var buf bytes.Buffer + err = seg.Encode(&buf) + if err != nil { + t.Fatalf("encode media segment: %v", err) + } + return buf.Bytes() +} + +func TestParseMediaSegmentAudio(t *testing.T) { + initData := buildAV1InitSegment(t) + info, err := parseInitSegment(initData) + if err != nil { + t.Fatalf("parseInitSegment: %v", err) + } + + segData := buildAudioMediaSegment(t, info.AudioTrackID) + + samples, err := parseMediaSegment(segData, info, false) + if err != nil { + t.Fatalf("parseMediaSegment: %v", err) + } + + if len(samples) != 2 { + t.Fatalf("sample count = %d, want 2", len(samples)) + } + + // First audio sample: DTS=0 + if samples[0].DTS != 0 { + t.Errorf("sample[0].DTS = %d, want 0", samples[0].DTS) + } + + // Second audio sample: DTS = 1024/48000 * 1e6 = 21333us + expectedDTS := scaleToMicroseconds(1024, 48000) + if samples[1].DTS != expectedDTS { + t.Errorf("sample[1].DTS = %d, want %d", samples[1].DTS, expectedDTS) + } +} diff --git a/moq/format.go b/moq/format.go index 2a55d1c..0c387a9 100644 --- a/moq/format.go +++ b/moq/format.go @@ -86,6 +86,60 @@ func BuildAVCDecoderConfig(sps, pps []byte) []byte { return buf } +// BuildAV1DecoderConfig builds an AV1CodecConfigurationRecord +// (https://aomediacodec.github.io/av1-isobmff/#av1codecconfigurationrecord) +// from a raw Sequence Header OBU (including OBU header bytes). +// The sequence header OBU is appended as configOBUs. +// Returns nil if the input is nil, too short, or fails to parse. +func BuildAV1DecoderConfig(seqHeaderOBU []byte) []byte { + if len(seqHeaderOBU) < 2 { + return nil + } + + hdr, err := demux.ParseAV1SequenceHeader(seqHeaderOBU) + if err != nil { + return nil + } + + buf := make([]byte, 0, 4+len(seqHeaderOBU)) + + // Byte 0: marker(1)=1 | version(7)=1 + buf = append(buf, 0x81) + + // Byte 1: seq_profile(3) | seq_level_idx_0(5) + buf = append(buf, byte(hdr.SeqProfile&0x07)<<5|byte(hdr.SeqLevelIdx&0x1F)) + + // Byte 2: seq_tier_0(1) | high_bitdepth(1) | twelve_bit(1) | monochrome(1) | + // chroma_subsampling_x(1) | chroma_subsampling_y(1) | chroma_sample_position(2)=0 + var highBitDepth, twelveBit byte + if hdr.BitDepth >= 10 { + highBitDepth = 1 + } + if hdr.BitDepth == 12 { + twelveBit = 1 + } + var mono byte + if hdr.Monochrome { + mono = 1 + } + b2 := byte(hdr.SeqTier&0x01)<<7 | + highBitDepth<<6 | + twelveBit<<5 | + mono<<4 | + byte(hdr.ChromaSubsamplingX&0x01)<<3 | + byte(hdr.ChromaSubsamplingY&0x01)<<2 + // chroma_sample_position (2 bits) = 0 + buf = append(buf, b2) + + // Byte 3: reserved(3)=0 | initial_presentation_delay_present(1)=0 | reserved(4)=0 + buf = append(buf, 0x00) + + // configOBUs: the sequence header OBU itself + buf = append(buf, seqHeaderOBU...) + + return buf +} + // BuildHEVCDecoderConfig builds an HEVCDecoderConfigurationRecord // (ISO 14496-15 §8.3.3.1.2) from raw VPS, SPS, and PPS NAL data // (without start codes). The SPS must include the 2-byte NAL header. diff --git a/moq/format_test.go b/moq/format_test.go index c9f26ad..f79f3e8 100644 --- a/moq/format_test.go +++ b/moq/format_test.go @@ -279,6 +279,127 @@ func TestBuildHEVCDecoderConfig(t *testing.T) { } } +func TestBuildAV1DecoderConfig(t *testing.T) { + t.Parallel() + + // 1920x1080, profile 0, level 8, 8-bit, 4:2:0 (from demux/av1_test.go) + seqHeaderOBU := []byte{ + 0x0a, 0x0b, // OBU header (type=1, has_size=1) + LEB128 size=11 + 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71, + 0x2b, 0xe4, 0x01, + } + + config := BuildAV1DecoderConfig(seqHeaderOBU) + if config == nil { + t.Fatal("expected non-nil config") + } + + // Total length = 4-byte header + len(seqHeaderOBU) + expectedLen := 4 + len(seqHeaderOBU) + if len(config) != expectedLen { + t.Fatalf("total length: got %d, want %d", len(config), expectedLen) + } + + // Byte 0: marker(1)=1 | version(7)=1 → 0x81 + if config[0] != 0x81 { + t.Errorf("byte 0 (marker+version): got 0x%02x, want 0x81", config[0]) + } + + // Byte 1: seq_profile(3)=0 | seq_level_idx_0(5)=8 → 0x08 + if config[1] != 0x08 { + t.Errorf("byte 1 (profile|level): got 0x%02x, want 0x08", config[1]) + } + + // Verify profile extracted correctly (top 3 bits) + profile := config[1] >> 5 + if profile != 0 { + t.Errorf("profile: got %d, want 0", profile) + } + + // Verify level extracted correctly (bottom 5 bits) + level := config[1] & 0x1F + if level != 8 { + t.Errorf("level: got %d, want 8", level) + } + + // Byte 2: tier=0, high_bitdepth=0, twelve_bit=0, mono=0, csx=1, csy=1, csp=0 → 0x0C + if config[2] != 0x0C { + t.Errorf("byte 2 (tier/bd/chroma): got 0x%02x, want 0x0C", config[2]) + } + + // Byte 3: reserved = 0x00 + if config[3] != 0x00 { + t.Errorf("byte 3 (reserved): got 0x%02x, want 0x00", config[3]) + } + + // configOBUs (bytes 4+) must match input seqHeaderOBU + if !bytes.Equal(config[4:], seqHeaderOBU) { + t.Error("configOBUs do not match input seqHeaderOBU") + } +} + +func TestBuildAV1DecoderConfig10Bit(t *testing.T) { + t.Parallel() + + // 640x480, profile 0, level 4, 10-bit, 4:2:0 (from demux/av1_test.go) + seqHeaderOBU := []byte{ + 0x0a, 0x0b, // OBU header (type=1, has_size=1) + LEB128 size=11 + 0x00, 0x00, 0x00, 0x24, 0xc4, 0xff, 0xdf, 0x12, + 0xbe, 0x50, 0x10, + } + + config := BuildAV1DecoderConfig(seqHeaderOBU) + if config == nil { + t.Fatal("expected non-nil config") + } + + // Total length = 4-byte header + len(seqHeaderOBU) + expectedLen := 4 + len(seqHeaderOBU) + if len(config) != expectedLen { + t.Fatalf("total length: got %d, want %d", len(config), expectedLen) + } + + // Byte 0: marker + version + if config[0] != 0x81 { + t.Errorf("byte 0: got 0x%02x, want 0x81", config[0]) + } + + // Byte 1: profile=0 | level=4 → 0x04 + if config[1] != 0x04 { + t.Errorf("byte 1 (profile|level): got 0x%02x, want 0x04", config[1]) + } + + // Byte 2: tier=0, high_bitdepth=1, twelve_bit=0, mono=0, csx=1, csy=1, csp=0 + // = 0<<7 | 1<<6 | 0<<5 | 0<<4 | 1<<3 | 1<<2 | 0 = 0x4C + if config[2] != 0x4C { + t.Errorf("byte 2 (tier/bd/chroma): got 0x%02x, want 0x4C", config[2]) + } + + // configOBUs (bytes 4+) must match input + if !bytes.Equal(config[4:], seqHeaderOBU) { + t.Error("configOBUs do not match input seqHeaderOBU") + } +} + +func TestBuildAV1DecoderConfigNil(t *testing.T) { + t.Parallel() + + if BuildAV1DecoderConfig(nil) != nil { + t.Error("expected nil for nil input") + } + if BuildAV1DecoderConfig([]byte{}) != nil { + t.Error("expected nil for empty input") + } + if BuildAV1DecoderConfig([]byte{0x0a}) != nil { + t.Error("expected nil for truncated input") + } + + // Wrong OBU type (OBU_FRAME instead of SEQUENCE_HEADER) + if BuildAV1DecoderConfig([]byte{0x32, 0x01, 0x00}) != nil { + t.Error("expected nil for wrong OBU type") + } +} + func TestBuildHEVCDecoderConfigNil(t *testing.T) { t.Parallel() if BuildHEVCDecoderConfig(nil, []byte{0x42, 0x01, 0x01, 0x01}, []byte{0x44}) != nil { diff --git a/pipeline/pipeline.go b/pipeline/pipeline.go index 3dd88ba..89befbe 100644 --- a/pipeline/pipeline.go +++ b/pipeline/pipeline.go @@ -228,7 +228,19 @@ func (p *Pipeline) forwardVideo(frame *media.VideoFrame) { // including decoder configuration record for the catalog. func (p *Pipeline) buildVideoInfo(frame *media.VideoFrame) (distribution.VideoInfo, bool) { var vi distribution.VideoInfo - if frame.Codec == "h265" { + switch frame.Codec { + case "av1": + info, err := demux.ParseAV1SequenceHeader(frame.SPS) + if err != nil { + return vi, false + } + vi = distribution.VideoInfo{ + Codec: info.CodecString(), + Width: info.Width, + Height: info.Height, + } + vi.DecoderConfig = moq.BuildAV1DecoderConfig(frame.SPS) + case "h265": info, err := demux.ParseHEVCSPS(frame.SPS) if err != nil { return vi, false @@ -241,7 +253,7 @@ func (p *Pipeline) buildVideoInfo(frame *media.VideoFrame) (distribution.VideoIn if frame.VPS != nil { vi.DecoderConfig = moq.BuildHEVCDecoderConfig(frame.VPS, frame.SPS, frame.PPS) } - } else { + default: info, err := demux.ParseSPS(frame.SPS) if err != nil { return vi, false diff --git a/pipeline/pipeline_test.go b/pipeline/pipeline_test.go index 719e33f..61f393c 100644 --- a/pipeline/pipeline_test.go +++ b/pipeline/pipeline_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/zsiec/prism/distribution" + "github.com/zsiec/prism/media" ) func TestNew(t *testing.T) { @@ -71,3 +72,51 @@ func TestDemuxStats(t *testing.T) { t.Fatal("expected non-nil DemuxStats") } } + +func TestBuildVideoInfoAV1(t *testing.T) { + t.Parallel() + + relay := distribution.NewRelay() + p := New("test-stream", strings.NewReader(""), relay) + + // AV1 sequence header OBU for 1920x1080, profile 0, level 8, 8-bit + // (same test data as demux/av1_test.go testSeqHdrOBU1080p) + seqHdrOBU := []byte{ + 0x0a, 0x0b, + 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71, + 0x2b, 0xe4, 0x01, + } + + frame := &media.VideoFrame{ + PTS: 0, + IsKeyframe: true, + Codec: "av1", + SPS: seqHdrOBU, + WireData: []byte{0x32, 0x10, 0xAA, 0xBB}, + } + + vi, ok := p.buildVideoInfo(frame) + if !ok { + t.Fatal("buildVideoInfo returned false") + } + + if vi.Width != 1920 { + t.Errorf("Width: got %d, want 1920", vi.Width) + } + if vi.Height != 1080 { + t.Errorf("Height: got %d, want 1080", vi.Height) + } + + wantPrefix := "av01." + if len(vi.Codec) < len(wantPrefix) || vi.Codec[:len(wantPrefix)] != wantPrefix { + t.Errorf("Codec: got %q, want prefix %q", vi.Codec, wantPrefix) + } + + if vi.DecoderConfig == nil { + t.Fatal("DecoderConfig is nil") + } + // AV1CodecConfigurationRecord starts with marker|version = 0x81 + if vi.DecoderConfig[0] != 0x81 { + t.Errorf("DecoderConfig[0]: got 0x%02x, want 0x81", vi.DecoderConfig[0]) + } +}