From af5629c75dadfc70b108ea555b412e517330f0fd Mon Sep 17 00:00:00 2001 From: Thomas Symborski Date: Wed, 18 Mar 2026 16:20:37 -0400 Subject: [PATCH 01/14] feat: add AV1 OBU parsing for sequence headers and keyframe detection Implement AV1 bitstream parsing in demux/av1.go with four exported functions: - ParseAV1SequenceHeader: extracts profile, level, tier, bit depth, resolution, and chroma subsampling from sequence header OBUs - CodecString: generates RFC 6381 codec strings (e.g. "av01.0.08M.08") - IsAV1Keyframe: scans temporal units for KEY_FRAME in OBU_FRAME/FRAME_HEADER - FindSequenceHeaderOBU: locates sequence header OBU within a temporal unit Internal helpers include OBU header parsing, LEB128 decoding, and an error-tracking MSB-first bit reader with UVLC support. --- demux/av1.go | 520 +++++++++++++++++++++++++++++++++++ demux/av1_test.go | 674 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1194 insertions(+) create mode 100644 demux/av1.go create mode 100644 demux/av1_test.go diff --git a/demux/av1.go b/demux/av1.go new file mode 100644 index 0000000..0d4346f --- /dev/null +++ b/demux/av1.go @@ -0,0 +1,520 @@ +package demux + +import ( + "errors" + "fmt" +) + +// AV1 OBU type constants as defined in AV1 spec §6.2.2. +const ( + OBUSequenceHeader = 1 + OBUTemporalDelimiter = 2 + OBUFrameHeader = 3 + OBUTileGroup = 4 + OBUMetadata = 5 + OBUFrame = 6 + OBURedundantFrameHeader = 7 + OBUTileList = 8 + OBUPadding = 15 +) + +// AV1 frame type constants as defined in AV1 spec §6.8.2. 
+const ( + av1FrameKey = 0 + av1FrameInter = 1 + av1FrameIntraOnly = 2 + av1FrameSwitch = 3 +) + +var ( + errAV1TooShort = errors.New("AV1 data too short") + errAV1InvalidOBU = errors.New("invalid AV1 OBU header") + errAV1LEB128 = errors.New("invalid LEB128 encoding") + errAV1SeqHdrParse = errors.New("failed to parse AV1 sequence header") + errAV1ForbiddenBit = errors.New("AV1 OBU forbidden bit set") +) + +// AV1SequenceHeader holds parameters extracted from an AV1 Sequence Header OBU. +type AV1SequenceHeader struct { + SeqProfile int + SeqLevelIdx int + SeqTier int // 0=Main, 1=High + BitDepth int + Width int + Height int + ChromaSubsamplingX int + ChromaSubsamplingY int + Monochrome bool +} + +// CodecString returns the RFC 6381 codec parameter string for AV1. +// Format: "av01.P.LLT.DD" where P=profile, LL=zero-padded level, +// T=tier(M/H), DD=zero-padded bit depth. +func (s *AV1SequenceHeader) CodecString() string { + tier := "M" + if s.SeqTier == 1 { + tier = "H" + } + return fmt.Sprintf("av01.%d.%02d%s.%02d", s.SeqProfile, s.SeqLevelIdx, tier, s.BitDepth) +} + +// obuHeader represents a parsed AV1 OBU header. +type obuHeader struct { + obuType int + hasExtension bool + hasSizeField bool + temporalID int + spatialID int + headerLen int // total header bytes (1 or 2) + payloadSize int // from the size field; remaining input when hasSizeField is false + totalSize int // headerLen + leb128 size bytes + payloadSize +} + +// parseOBUHeader parses an OBU header from the given data. The returned +// header's totalSize spans the header, optional size field, and payload. 
+func parseOBUHeader(data []byte) (obuHeader, error) { + if len(data) < 1 { + return obuHeader{}, errAV1TooShort + } + + b := data[0] + if b&0x80 != 0 { + return obuHeader{}, errAV1ForbiddenBit + } + + h := obuHeader{ + obuType: int((b >> 3) & 0x0F), + hasExtension: b&0x04 != 0, + hasSizeField: b&0x02 != 0, + headerLen: 1, + } + + pos := 1 + if h.hasExtension { + if len(data) < 2 { + return obuHeader{}, errAV1TooShort + } + ext := data[1] + h.temporalID = int((ext >> 5) & 0x07) + h.spatialID = int((ext >> 3) & 0x03) + h.headerLen = 2 + pos = 2 + } + + if h.hasSizeField { + size, n, err := readLEB128(data[pos:]) + if err != nil { + return obuHeader{}, err + } + h.payloadSize = size + h.totalSize = pos + n + size + } else { + // Without size field, OBU extends to end of temporal unit. + h.payloadSize = len(data) - pos + h.totalSize = len(data) + } + + return h, nil +} + +// readLEB128 reads a LEB128 encoded unsigned integer from data. +// Returns the value, number of bytes consumed, and any error. +func readLEB128(data []byte) (int, int, error) { + val := 0 + for i := 0; i < len(data) && i < 8; i++ { + b := data[i] + val |= int(b&0x7F) << (i * 7) + if b&0x80 == 0 { + return val, i + 1, nil + } + } + return 0, 0, errAV1LEB128 +} + +// av1BitReader is a MSB-first bit reader that tracks errors internally. +// After an overflow, all subsequent reads return zero and the error is +// preserved in the err field. 
+type av1BitReader struct { + data []byte + pos int + bit int + err error +} + +func newAV1BitReader(data []byte) *av1BitReader { + return &av1BitReader{data: data} +} + +func (br *av1BitReader) readBits(n int) uint { + if br.err != nil { + return 0 + } + var val uint + for i := 0; i < n; i++ { + if br.pos >= len(br.data) { + br.err = errAV1TooShort + return 0 + } + val = (val << 1) | uint((br.data[br.pos]>>(7-br.bit))&1) + br.bit++ + if br.bit == 8 { + br.bit = 0 + br.pos++ + } + } + return val +} + +// readUvlc reads an unsigned variable-length code (AV1 spec §4.10.3). +func (br *av1BitReader) readUvlc() uint { + if br.err != nil { + return 0 + } + leadingZeros := 0 + for { + b := br.readBits(1) + if br.err != nil { + return 0 + } + if b == 1 { + break + } + leadingZeros++ + if leadingZeros >= 32 { + br.err = errAV1SeqHdrParse + return 0 + } + } + if leadingZeros == 0 { + return 0 + } + val := br.readBits(leadingZeros) + return val + (1 << leadingZeros) - 1 +} + +// ParseAV1SequenceHeader parses an AV1 Sequence Header OBU (including OBU +// header bytes) and extracts codec parameters. The input must start with +// the OBU header byte. +func ParseAV1SequenceHeader(data []byte) (*AV1SequenceHeader, error) { + if len(data) < 2 { + return nil, errAV1TooShort + } + + h, err := parseOBUHeader(data) + if err != nil { + return nil, err + } + + if h.obuType != OBUSequenceHeader { + return nil, fmt.Errorf("expected OBU_SEQUENCE_HEADER (type %d), got type %d", OBUSequenceHeader, h.obuType) + } + + // Payload starts after header + LEB128 size + payloadStart := h.totalSize - h.payloadSize + if payloadStart >= len(data) || payloadStart+h.payloadSize > len(data) { + return nil, errAV1TooShort + } + payload := data[payloadStart : payloadStart+h.payloadSize] + + return parseSequenceHeaderPayload(payload) +} + +// parseSequenceHeaderPayload parses the raw payload of a sequence_header_obu. 
+func parseSequenceHeaderPayload(payload []byte) (*AV1SequenceHeader, error) { + br := newAV1BitReader(payload) + hdr := &AV1SequenceHeader{} + + hdr.SeqProfile = int(br.readBits(3)) + stillPicture := br.readBits(1) + reducedStillPicture := br.readBits(1) + _ = stillPicture + + var decoderModelInfoPresent bool + + if reducedStillPicture == 1 { + hdr.SeqLevelIdx = int(br.readBits(5)) + hdr.SeqTier = 0 + } else { + timingInfoPresent := br.readBits(1) + + if timingInfoPresent == 1 { + // timing_info() + br.readBits(32) // num_units_in_display_tick + br.readBits(32) // time_scale + equalPictureInterval := br.readBits(1) + if equalPictureInterval == 1 { + br.readUvlc() // num_ticks_per_picture_minus_1 + } + + dmi := br.readBits(1) + decoderModelInfoPresent = dmi == 1 + if decoderModelInfoPresent { + br.readBits(5) // buffer_delay_length_minus_1 + br.readBits(32) // num_units_in_decoding_tick + br.readBits(5) // buffer_removal_time_length_minus_1 + br.readBits(5) // frame_presentation_time_length_minus_1 + } + } + + initialDisplayDelayPresent := br.readBits(1) + operatingPointsCntMinus1 := br.readBits(5) + + for i := uint(0); i <= operatingPointsCntMinus1; i++ { + br.readBits(12) // operating_point_idc + levelIdx := br.readBits(5) + var tier uint + if levelIdx > 7 { + tier = br.readBits(1) + } + if i == 0 { + hdr.SeqLevelIdx = int(levelIdx) + hdr.SeqTier = int(tier) + } + if decoderModelInfoPresent { + br.readBits(1) // decoder_model_present_for_this_op + } + if initialDisplayDelayPresent == 1 { + flag := br.readBits(1) + if flag == 1 { + br.readBits(4) // initial_display_delay_minus_1 + } + } + } + } + + if br.err != nil { + return nil, fmt.Errorf("%w: %v", errAV1SeqHdrParse, br.err) + } + + frameWidthBitsMinus1 := br.readBits(4) + frameHeightBitsMinus1 := br.readBits(4) + maxFrameWidthMinus1 := br.readBits(int(frameWidthBitsMinus1) + 1) + maxFrameHeightMinus1 := br.readBits(int(frameHeightBitsMinus1) + 1) + hdr.Width = int(maxFrameWidthMinus1) + 1 + hdr.Height = 
int(maxFrameHeightMinus1) + 1 + + if reducedStillPicture == 0 { + frameIDNumbersPresent := br.readBits(1) + if frameIDNumbersPresent == 1 { + br.readBits(4) // delta_frame_id_length_minus_2 + br.readBits(3) // additional_frame_id_length_minus_1 + } + } + + br.readBits(1) // use_128x128_superblock + br.readBits(1) // enable_filter_intra + br.readBits(1) // enable_intra_edge_filter + + if reducedStillPicture == 0 { + br.readBits(1) // enable_interintra_compound + br.readBits(1) // enable_masked_compound + br.readBits(1) // enable_warped_motion + br.readBits(1) // enable_dual_filter + enableOrderHint := br.readBits(1) + + if enableOrderHint == 1 { + br.readBits(1) // enable_jnt_comp + br.readBits(1) // enable_ref_frame_mvs + } + + seqChooseScreenContentTools := br.readBits(1) + var seqForceScreenContentTools uint + if seqChooseScreenContentTools == 0 { + seqForceScreenContentTools = br.readBits(1) + } else { + seqForceScreenContentTools = 2 // SELECT_SCREEN_CONTENT_TOOLS + } + + if seqForceScreenContentTools > 0 { + seqChooseIntegerMV := br.readBits(1) + if seqChooseIntegerMV == 0 { + br.readBits(1) // seq_force_integer_mv + } + } + + if enableOrderHint == 1 { + br.readBits(3) // order_hint_bits_minus_1 + } + } + + br.readBits(1) // enable_superres + br.readBits(1) // enable_cdef + br.readBits(1) // enable_restoration + + if br.err != nil { + return nil, fmt.Errorf("%w: %v", errAV1SeqHdrParse, br.err) + } + + // color_config() + parseColorConfig(br, hdr) + + if br.err != nil { + return nil, fmt.Errorf("%w: %v", errAV1SeqHdrParse, br.err) + } + + return hdr, nil +} + +// parseColorConfig parses the color_config() section of a sequence header. 
+func parseColorConfig(br *av1BitReader, hdr *AV1SequenceHeader) { + highBitDepth := br.readBits(1) + + if hdr.SeqProfile == 2 && highBitDepth == 1 { + twelveBit := br.readBits(1) + if twelveBit == 1 { + hdr.BitDepth = 12 + } else { + hdr.BitDepth = 10 + } + } else if highBitDepth == 1 { + hdr.BitDepth = 10 + } else { + hdr.BitDepth = 8 + } + + var monoChrome uint + if hdr.SeqProfile == 1 { + monoChrome = 0 + } else { + monoChrome = br.readBits(1) + } + hdr.Monochrome = monoChrome == 1 + + colorDescriptionPresent := br.readBits(1) + var colorPrimaries, transferCharacteristics, matrixCoefficients uint + if colorDescriptionPresent == 1 { + colorPrimaries = br.readBits(8) + transferCharacteristics = br.readBits(8) + matrixCoefficients = br.readBits(8) + } else { + colorPrimaries = 2 // CP_UNSPECIFIED + transferCharacteristics = 2 // TC_UNSPECIFIED + matrixCoefficients = 2 // MC_UNSPECIFIED + } + + if monoChrome == 1 { + br.readBits(1) // color_range + hdr.ChromaSubsamplingX = 1 + hdr.ChromaSubsamplingY = 1 + return + } + + // Check for sRGB/sYCC + if colorPrimaries == 1 && transferCharacteristics == 13 && matrixCoefficients == 0 { + // color_range is implicitly 1 (full) + hdr.ChromaSubsamplingX = 0 + hdr.ChromaSubsamplingY = 0 + return + } + + br.readBits(1) // color_range + + if hdr.SeqProfile == 0 { + hdr.ChromaSubsamplingX = 1 + hdr.ChromaSubsamplingY = 1 + } else if hdr.SeqProfile == 1 { + hdr.ChromaSubsamplingX = 0 + hdr.ChromaSubsamplingY = 0 + } else { + // Profile 2 + if hdr.BitDepth == 12 { + sx := br.readBits(1) + hdr.ChromaSubsamplingX = int(sx) + if sx == 1 { + hdr.ChromaSubsamplingY = int(br.readBits(1)) + } else { + hdr.ChromaSubsamplingY = 0 + } + } else { + hdr.ChromaSubsamplingX = 1 + hdr.ChromaSubsamplingY = 0 + } + } + + if hdr.ChromaSubsamplingX == 1 && hdr.ChromaSubsamplingY == 1 { + br.readBits(2) // chroma_sample_position + } + + br.readBits(1) // separate_uv_delta_q +} + +// IsAV1Keyframe scans OBU headers in a temporal unit for a key frame. 
+// It checks the first OBU_FRAME (type 6) or OBU_FRAME_HEADER (type 3) for +// frame_type == KEY_FRAME (0). Returns false for nil/empty input. +func IsAV1Keyframe(temporalUnit []byte) bool { + if len(temporalUnit) == 0 { + return false + } + + pos := 0 + for pos < len(temporalUnit) { + h, err := parseOBUHeader(temporalUnit[pos:]) + if err != nil { + return false + } + + if h.obuType == OBUFrame || h.obuType == OBUFrameHeader { + // Payload starts after header bytes and optional LEB128 size + payloadStart := pos + h.totalSize - h.payloadSize + if payloadStart >= len(temporalUnit) || h.payloadSize < 1 { + return false + } + // frame_header_obu(): first bit is show_existing_frame + // If show_existing_frame==1, this is not a new frame. + // If show_existing_frame==0, next 2 bits are frame_type. + payload := temporalUnit[payloadStart:] + if len(payload) < 1 { + return false + } + showExisting := (payload[0] >> 7) & 1 + if showExisting == 1 { + return false + } + if len(payload) < 1 { + return false + } + frameType := (payload[0] >> 5) & 0x03 + return frameType == av1FrameKey + } + + if h.hasSizeField { + pos += h.totalSize + } else { + // No size field — OBU extends to end + break + } + } + return false +} + +// FindSequenceHeaderOBU scans a temporal unit for an OBU_SEQUENCE_HEADER +// (type 1) and returns its bytes (header + payload, clamped to the input length). +// Returns nil if none is found or an OBU header fails to parse. 
+func FindSequenceHeaderOBU(temporalUnit []byte) []byte { + if len(temporalUnit) == 0 { + return nil + } + + pos := 0 + for pos < len(temporalUnit) { + h, err := parseOBUHeader(temporalUnit[pos:]) + if err != nil { + return nil + } + + if h.obuType == OBUSequenceHeader { + end := pos + h.totalSize + if end > len(temporalUnit) { + end = len(temporalUnit) + } + return temporalUnit[pos:end] + } + + if h.hasSizeField { + pos += h.totalSize + } else { + break + } + } + return nil +} diff --git a/demux/av1_test.go b/demux/av1_test.go new file mode 100644 index 0000000..38a3271 --- /dev/null +++ b/demux/av1_test.go @@ -0,0 +1,674 @@ +package demux + +import ( + "testing" +) + +// Test data generated with ffmpeg using libsvtav1. +// 1920x1080, profile 0, level 8, 8-bit, 4:2:0. +var testSeqHdrOBU1080p = []byte{ + 0x0a, 0x0b, // OBU header (type=1, has_size=1) + LEB128 size=11 + 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71, + 0x2b, 0xe4, 0x01, +} + +// 640x480, profile 0, level 4, 10-bit, 4:2:0. +var testSeqHdrOBU480p10 = []byte{ + 0x0a, 0x0b, // OBU header (type=1, has_size=1) + LEB128 size=11 + 0x00, 0x00, 0x00, 0x24, 0xc4, 0xff, 0xdf, 0x12, + 0xbe, 0x50, 0x10, +} + +// Temporal unit containing TD + seq header + key frame (1920x1080 8-bit). +var testTemporalUnitKeyframe = []byte{ + 0x12, 0x00, // OBU_TEMPORAL_DELIMITER, has_size=1, size=0 + 0x0a, 0x0b, // OBU_SEQUENCE_HEADER, has_size=1, size=11 + 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71, + 0x2b, 0xe4, 0x01, + 0x32, 0x2d, // OBU_FRAME, has_size=1, size=45 + 0x10, 0x00, 0x8b, 0x00, 0x81, 0x45, 0x08, 0x20, + 0x40, 0x00, 0x00, 0x03, 0x24, 0xfe, 0x6a, 0x75, + 0xf5, 0xb5, 0x7d, 0xaa, 0x9a, 0x90, 0x9b, 0xe8, + 0x41, 0xd4, 0x1a, 0x98, 0x69, 0x56, 0xfa, 0xf0, + 0xa8, 0xe6, 0xda, 0x81, 0x0c, 0x1c, 0xe9, 0xc2, + 0xf3, 0x84, 0xbd, 0x86, 0xab, +} + +// Temporal unit containing TD + inter frame (no sequence header). 
+var testTemporalUnitInterFrame = []byte{ + 0x12, 0x00, // OBU_TEMPORAL_DELIMITER, has_size=1, size=0 + 0x32, 0x18, // OBU_FRAME, has_size=1, size=24 + 0x30, 0x02, 0x04, 0x09, 0x24, 0x92, 0x22, 0x46, + 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0xbc, + 0x32, 0xc1, 0x64, 0xdc, 0x91, 0xe8, 0x51, 0xe8, +} + +func TestParseAV1SequenceHeader1080p(t *testing.T) { + t.Parallel() + + hdr, err := ParseAV1SequenceHeader(testSeqHdrOBU1080p) + if err != nil { + t.Fatalf("ParseAV1SequenceHeader error: %v", err) + } + + if hdr.SeqProfile != 0 { + t.Errorf("SeqProfile: got %d, want 0", hdr.SeqProfile) + } + if hdr.SeqLevelIdx != 8 { + t.Errorf("SeqLevelIdx: got %d, want 8", hdr.SeqLevelIdx) + } + if hdr.SeqTier != 0 { + t.Errorf("SeqTier: got %d, want 0", hdr.SeqTier) + } + if hdr.BitDepth != 8 { + t.Errorf("BitDepth: got %d, want 8", hdr.BitDepth) + } + if hdr.Width != 1920 { + t.Errorf("Width: got %d, want 1920", hdr.Width) + } + if hdr.Height != 1080 { + t.Errorf("Height: got %d, want 1080", hdr.Height) + } + if hdr.ChromaSubsamplingX != 1 { + t.Errorf("ChromaSubsamplingX: got %d, want 1", hdr.ChromaSubsamplingX) + } + if hdr.ChromaSubsamplingY != 1 { + t.Errorf("ChromaSubsamplingY: got %d, want 1", hdr.ChromaSubsamplingY) + } + if hdr.Monochrome { + t.Error("Monochrome: got true, want false") + } +} + +func TestParseAV1SequenceHeader480p10Bit(t *testing.T) { + t.Parallel() + + hdr, err := ParseAV1SequenceHeader(testSeqHdrOBU480p10) + if err != nil { + t.Fatalf("ParseAV1SequenceHeader error: %v", err) + } + + if hdr.SeqProfile != 0 { + t.Errorf("SeqProfile: got %d, want 0", hdr.SeqProfile) + } + if hdr.SeqLevelIdx != 4 { + t.Errorf("SeqLevelIdx: got %d, want 4", hdr.SeqLevelIdx) + } + if hdr.SeqTier != 0 { + t.Errorf("SeqTier: got %d, want 0", hdr.SeqTier) + } + if hdr.BitDepth != 10 { + t.Errorf("BitDepth: got %d, want 10", hdr.BitDepth) + } + if hdr.Width != 640 { + t.Errorf("Width: got %d, want 640", hdr.Width) + } + if hdr.Height != 480 { + t.Errorf("Height: got %d, want 
480", hdr.Height) + } + if hdr.ChromaSubsamplingX != 1 { + t.Errorf("ChromaSubsamplingX: got %d, want 1", hdr.ChromaSubsamplingX) + } + if hdr.ChromaSubsamplingY != 1 { + t.Errorf("ChromaSubsamplingY: got %d, want 1", hdr.ChromaSubsamplingY) + } + if hdr.Monochrome { + t.Error("Monochrome: got true, want false") + } +} + +func TestParseAV1SequenceHeaderNil(t *testing.T) { + t.Parallel() + + _, err := ParseAV1SequenceHeader(nil) + if err == nil { + t.Error("expected error for nil input") + } +} + +func TestParseAV1SequenceHeaderEmpty(t *testing.T) { + t.Parallel() + + _, err := ParseAV1SequenceHeader([]byte{}) + if err == nil { + t.Error("expected error for empty input") + } +} + +func TestParseAV1SequenceHeaderTruncated(t *testing.T) { + t.Parallel() + + // Just the OBU header, no payload + _, err := ParseAV1SequenceHeader([]byte{0x0a, 0x0b}) + if err == nil { + t.Error("expected error for truncated input") + } + + // Partial payload + _, err = ParseAV1SequenceHeader([]byte{0x0a, 0x0b, 0x00, 0x00}) + if err == nil { + t.Error("expected error for truncated payload") + } +} + +func TestParseAV1SequenceHeaderWrongOBUType(t *testing.T) { + t.Parallel() + + // OBU_FRAME (type 6) instead of SEQUENCE_HEADER + _, err := ParseAV1SequenceHeader([]byte{0x32, 0x01, 0x00}) + if err == nil { + t.Error("expected error for wrong OBU type") + } +} + +func TestAV1CodecString1080p(t *testing.T) { + t.Parallel() + + hdr, err := ParseAV1SequenceHeader(testSeqHdrOBU1080p) + if err != nil { + t.Fatalf("ParseAV1SequenceHeader error: %v", err) + } + + want := "av01.0.08M.08" + got := hdr.CodecString() + if got != want { + t.Errorf("CodecString: got %q, want %q", got, want) + } +} + +func TestAV1CodecString480p10Bit(t *testing.T) { + t.Parallel() + + hdr, err := ParseAV1SequenceHeader(testSeqHdrOBU480p10) + if err != nil { + t.Fatalf("ParseAV1SequenceHeader error: %v", err) + } + + want := "av01.0.04M.10" + got := hdr.CodecString() + if got != want { + t.Errorf("CodecString: got %q, want 
%q", got, want) + } +} + +func TestAV1CodecStringHighTier(t *testing.T) { + t.Parallel() + + hdr := &AV1SequenceHeader{ + SeqProfile: 1, + SeqLevelIdx: 13, + SeqTier: 1, + BitDepth: 10, + } + + want := "av01.1.13H.10" + got := hdr.CodecString() + if got != want { + t.Errorf("CodecString: got %q, want %q", got, want) + } +} + +func TestAV1CodecStringZeroPadding(t *testing.T) { + t.Parallel() + + hdr := &AV1SequenceHeader{ + SeqProfile: 0, + SeqLevelIdx: 1, + SeqTier: 0, + BitDepth: 8, + } + + want := "av01.0.01M.08" + got := hdr.CodecString() + if got != want { + t.Errorf("CodecString: got %q, want %q", got, want) + } +} + +func TestIsAV1KeyframeTrue(t *testing.T) { + t.Parallel() + + if !IsAV1Keyframe(testTemporalUnitKeyframe) { + t.Error("expected true for keyframe temporal unit") + } +} + +func TestIsAV1KeyframeFalseInterFrame(t *testing.T) { + t.Parallel() + + if IsAV1Keyframe(testTemporalUnitInterFrame) { + t.Error("expected false for inter frame temporal unit") + } +} + +func TestIsAV1KeyframeFalseNil(t *testing.T) { + t.Parallel() + + if IsAV1Keyframe(nil) { + t.Error("expected false for nil input") + } +} + +func TestIsAV1KeyframeFalseEmpty(t *testing.T) { + t.Parallel() + + if IsAV1Keyframe([]byte{}) { + t.Error("expected false for empty input") + } +} + +func TestIsAV1KeyframeFalseTruncated(t *testing.T) { + t.Parallel() + + // Just a temporal delimiter, no frame OBU + if IsAV1Keyframe([]byte{0x12, 0x00}) { + t.Error("expected false for temporal delimiter only") + } +} + +func TestFindSequenceHeaderOBUPresent(t *testing.T) { + t.Parallel() + + obu := FindSequenceHeaderOBU(testTemporalUnitKeyframe) + if obu == nil { + t.Fatal("expected non-nil OBU") + } + + // Verify the found OBU is a valid sequence header + hdr, err := ParseAV1SequenceHeader(obu) + if err != nil { + t.Fatalf("ParseAV1SequenceHeader on found OBU: %v", err) + } + if hdr.Width != 1920 || hdr.Height != 1080 { + t.Errorf("resolution: got %dx%d, want 1920x1080", hdr.Width, hdr.Height) + } +} + 
+func TestFindSequenceHeaderOBUAbsent(t *testing.T) { + t.Parallel() + + obu := FindSequenceHeaderOBU(testTemporalUnitInterFrame) + if obu != nil { + t.Errorf("expected nil for inter frame, got %d bytes", len(obu)) + } +} + +func TestFindSequenceHeaderOBUNil(t *testing.T) { + t.Parallel() + + obu := FindSequenceHeaderOBU(nil) + if obu != nil { + t.Error("expected nil for nil input") + } +} + +func TestFindSequenceHeaderOBUEmpty(t *testing.T) { + t.Parallel() + + obu := FindSequenceHeaderOBU([]byte{}) + if obu != nil { + t.Error("expected nil for empty input") + } +} + +func TestReadLEB128(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + data []byte + want int + wantN int + wantErr bool + }{ + { + name: "zero", + data: []byte{0x00}, + want: 0, + wantN: 1, + }, + { + name: "one byte value 11", + data: []byte{0x0b}, + want: 11, + wantN: 1, + }, + { + name: "one byte value 127", + data: []byte{0x7f}, + want: 127, + wantN: 1, + }, + { + name: "two byte value 128", + data: []byte{0x80, 0x01}, + want: 128, + wantN: 2, + }, + { + name: "two byte value 300", + data: []byte{0xac, 0x02}, + want: 300, + wantN: 2, + }, + { + name: "empty", + data: []byte{}, + wantErr: true, + }, + { + name: "unterminated", + data: []byte{0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80}, + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + val, n, err := readLEB128(tt.data) + if tt.wantErr { + if err == nil { + t.Error("expected error") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val != tt.want { + t.Errorf("value: got %d, want %d", val, tt.want) + } + if n != tt.wantN { + t.Errorf("bytes consumed: got %d, want %d", n, tt.wantN) + } + }) + } +} + +func TestParseOBUHeader(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + data []byte + wantType int + wantHasSize bool + wantHasExt bool + wantErr bool + }{ + { + name: "sequence header with size", + 
data: []byte{0x0a, 0x0b, 0x00, 0x00}, + wantType: OBUSequenceHeader, + wantHasSize: true, + }, + { + name: "temporal delimiter with size", + data: []byte{0x12, 0x00}, + wantType: OBUTemporalDelimiter, + wantHasSize: true, + }, + { + name: "frame with size", + data: []byte{0x32, 0x2d, 0x10}, + wantType: OBUFrame, + wantHasSize: true, + }, + { + name: "forbidden bit set", + data: []byte{0x8a}, + wantErr: true, + }, + { + name: "empty", + data: []byte{}, + wantErr: true, + }, + { + name: "extension flag", + data: []byte{0x0e, 0x40, 0x0b, 0x00, 0x00}, + wantType: OBUSequenceHeader, + wantHasExt: true, + wantHasSize: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + h, err := parseOBUHeader(tt.data) + if tt.wantErr { + if err == nil { + t.Error("expected error") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h.obuType != tt.wantType { + t.Errorf("obuType: got %d, want %d", h.obuType, tt.wantType) + } + if h.hasSizeField != tt.wantHasSize { + t.Errorf("hasSizeField: got %v, want %v", h.hasSizeField, tt.wantHasSize) + } + if h.hasExtension != tt.wantHasExt { + t.Errorf("hasExtension: got %v, want %v", h.hasExtension, tt.wantHasExt) + } + }) + } +} + +func TestAV1BitReaderOverflow(t *testing.T) { + t.Parallel() + + br := newAV1BitReader([]byte{0xFF}) + _ = br.readBits(8) // should succeed + _ = br.readBits(1) // should set error + + if br.err == nil { + t.Error("expected error after reading past end") + } + + // Subsequent reads should return 0 without panicking + val := br.readBits(8) + if val != 0 { + t.Errorf("expected 0 after error, got %d", val) + } +} + +func TestAV1BitReaderReadBits(t *testing.T) { + t.Parallel() + + // 0xA5 = 10100101 + br := newAV1BitReader([]byte{0xA5}) + + got := br.readBits(3) + if got != 5 { // 101 + t.Errorf("first 3 bits: got %d, want 5", got) + } + + got = br.readBits(5) + if got != 5 { // 00101 + t.Errorf("next 5 bits: got %d, want 5", got) + 
} + + if br.err != nil { + t.Errorf("unexpected error: %v", br.err) + } +} + +func TestAV1BitReaderUvlc(t *testing.T) { + t.Parallel() + + // UVLC encoding of 0: just a single 1-bit → value 0 + br := newAV1BitReader([]byte{0x80}) // 10000000 + got := br.readUvlc() + if got != 0 { + t.Errorf("uvlc(0): got %d, want 0", got) + } + + // UVLC encoding of 1: 0,1,0 → leadingZeros=1, val=0, result = 0 + (1<<1) - 1 = 1 + br = newAV1BitReader([]byte{0x40}) // 01000000 + got = br.readUvlc() + if got != 1 { + t.Errorf("uvlc(1): got %d, want 1", got) + } + + // UVLC encoding of 2: 0,1,1 → leadingZeros=1, val=1, result = 1 + (1<<1) - 1 = 2 + br = newAV1BitReader([]byte{0x60}) // 01100000 + got = br.readUvlc() + if got != 2 { + t.Errorf("uvlc(2): got %d, want 2", got) + } +} + +func TestReducedStillPictureHeader(t *testing.T) { + t.Parallel() + + // Hand-craft a minimal sequence header with reduced_still_picture_header=1 + // Profile 0, still_picture=1, reduced=1, level_idx=1 (5 bits) + // Then frame_width_bits_minus_1=3, frame_height_bits_minus_1=3 (4+4 bits) + // Then width=128-1=127 (4 bits), height=96-1=95 (4 bits) + // No frame_id_numbers_present (since reduced) + // use_128x128_superblock=0, enable_filter_intra=0, enable_intra_edge_filter=0 + // No interintra/masked/etc (since reduced) + // enable_superres=0, enable_cdef=0, enable_restoration=0 + // color_config: high_bitdepth=0 → 8-bit, mono_chrome=0 + // color_description_present=0 → unspecified + // color_range=0 + // chroma_sample_position=0 (2 bits, since profile 0 → 4:2:0) + // separate_uv_delta_q=0 + + // Bits: profile(3)=000, still(1)=1, reduced(1)=1 + // level_idx(5)=00001 + // fw_bits(4)=0011, fh_bits(4)=0011 + // width(4)=0111 1111, height(4)=0101 1111 + // Wait, width is fw_bits+1=4 bits, so 127 in 4 bits doesn't fit. 
+ // Let's use smaller: width=16, height=16 + // fw_bits_m1=3 → 4 bits for width, fh_bits_m1=3 → 4 bits for height + // width-1=15=0b1111, height-1=15=0b1111 + // Then: use_128x128_sb=0, filter_intra=0, intra_edge_filter=0 (3 bits) + // enable_superres=0, enable_cdef=0, enable_restoration=0 (3 bits) + // high_bitdepth=0 (1 bit) + // mono_chrome=0 (1 bit, profile != 1) + // color_desc_present=0 (1 bit) + // color_range=0 (1 bit) + // chroma_sample_position=00 (2 bits) + // separate_uv_delta_q=0 (1 bit) + + // Layout: + // byte 0: 000 1 1 000 = 0x18 + // byte 1: 01 0011 00 = 0x4C (level_idx=00001, fw_bits=0011, start of fh_bits) + // byte 2: 11 1111 11 = 0xFF (fh_bits last 2 bits=11, width=1111, height first 2=11) + // byte 3: 11 000 000 = 0xC0 (height last 2=11, sb=0, filter_intra=0, intra_edge=0, superres=0, cdef=0) + // byte 4: 0 0 0 0 0 0 00 = 0x00 (restoration=0, high_bitdepth=0, mono=0, cdp=0, color_range=0, csp=00, sep_uv=0) + + // Actually let me be more careful: + // Bit stream (MSB first): + // [0] profile: 000 + // [3] still_picture: 1 + // [4] reduced_still_picture_header: 1 + // [5] seq_level_idx_0: 00001 + // [10] frame_width_bits_minus_1: 0011 + // [14] frame_height_bits_minus_1: 0011 + // [18] max_frame_width_minus_1 (4 bits): 1111 + // [22] max_frame_height_minus_1 (4 bits): 1111 + // [26] use_128x128_superblock: 0 + // [27] enable_filter_intra: 0 + // [28] enable_intra_edge_filter: 0 + // [29] enable_superres: 0 + // [30] enable_cdef: 0 + // [31] enable_restoration: 0 + // [32] high_bitdepth: 0 + // [33] mono_chrome: 0 + // [34] color_description_present_flag: 0 + // [35] color_range: 0 + // [36] chroma_sample_position: 00 + // [38] separate_uv_delta_q: 0 + + // Byte layout: + // bits 0-7: 000 1 1 000 = 0x18 + // bits 8-15: 01 0011 00 = 0x4C + // bits 16-23: 11 1111 11 = 0xFF + // bits 24-31: 11 000000 = 0xC0 + // bits 32-39: 00000 000 = 0x00 + + payload := []byte{0x18, 0x4C, 0xFF, 0xC0, 0x00} + + // Wrap in OBU header: type=1, has_size=1 + obu := 
make([]byte, 0, 2+len(payload)) + obu = append(obu, 0x0a) // OBU header: type=1, has_size=1 + obu = append(obu, byte(len(payload))) // LEB128 size + obu = append(obu, payload...) + + hdr, err := ParseAV1SequenceHeader(obu) + if err != nil { + t.Fatalf("ParseAV1SequenceHeader error: %v", err) + } + + if hdr.SeqProfile != 0 { + t.Errorf("SeqProfile: got %d, want 0", hdr.SeqProfile) + } + if hdr.SeqLevelIdx != 1 { + t.Errorf("SeqLevelIdx: got %d, want 1", hdr.SeqLevelIdx) + } + if hdr.SeqTier != 0 { + t.Errorf("SeqTier: got %d, want 0", hdr.SeqTier) + } + if hdr.Width != 16 { + t.Errorf("Width: got %d, want 16", hdr.Width) + } + if hdr.Height != 16 { + t.Errorf("Height: got %d, want 16", hdr.Height) + } + if hdr.BitDepth != 8 { + t.Errorf("BitDepth: got %d, want 8", hdr.BitDepth) + } +} + +func TestIsAV1KeyframeFrameHeaderOBU(t *testing.T) { + t.Parallel() + + // Construct a temporal unit with OBU_FRAME_HEADER (type 3) instead of OBU_FRAME + // OBU_FRAME_HEADER: type=3, has_size=1 → (3<<3)|2 = 0x1A + // Payload: show_existing=0, frame_type=0 (KEY) → first byte = 0b0_00_xxxxx = 0x00 + tu := []byte{ + 0x12, 0x00, // TD + 0x1a, 0x02, // OBU_FRAME_HEADER, size=2 + 0x00, 0x00, // show_existing=0, frame_type=0 (KEY) + } + + if !IsAV1Keyframe(tu) { + t.Error("expected true for OBU_FRAME_HEADER with KEY_FRAME type") + } +} + +func TestIsAV1KeyframeShowExistingFrame(t *testing.T) { + t.Parallel() + + // show_existing_frame = 1 → first payload byte: 1_xx_xxxxx = 0x80 + tu := []byte{ + 0x12, 0x00, // TD + 0x32, 0x02, // OBU_FRAME, size=2 + 0x80, 0x00, // show_existing_frame=1 + } + + if IsAV1Keyframe(tu) { + t.Error("expected false for show_existing_frame=1") + } +} + +func TestFindSequenceHeaderOBUTruncated(t *testing.T) { + t.Parallel() + + // Sequence header OBU header but truncated payload + data := []byte{0x0a, 0x80} // has_size with unterminated LEB128 + obu := FindSequenceHeaderOBU(data) + if obu != nil { + t.Error("expected nil for truncated OBU") + } +} From 
290a6ba9d8a68994b5ee9de5982be6071c927a36 Mon Sep 17 00:00:00 2001 From: Thomas Symborski Date: Wed, 18 Mar 2026 16:24:57 -0400 Subject: [PATCH 02/14] fix: consume operating_parameters_info in AV1 sequence header parsing When decoder_model_present_for_this_op is 1, the AV1 spec (section 5.5.1) requires reading decoder_buffer_delay, encoder_buffer_delay, and low_delay_mode_flag. The parser was reading the flag but not consuming these fields, causing the bit reader to be misaligned for all subsequent fields in the sequence header. --- demux/av1.go | 13 +++- demux/av1_test.go | 188 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 199 insertions(+), 2 deletions(-) diff --git a/demux/av1.go b/demux/av1.go index 0d4346f..763741a 100644 --- a/demux/av1.go +++ b/demux/av1.go @@ -230,6 +230,7 @@ func parseSequenceHeaderPayload(payload []byte) (*AV1SequenceHeader, error) { _ = stillPicture var decoderModelInfoPresent bool + var bufferDelayLengthMinus1 uint if reducedStillPicture == 1 { hdr.SeqLevelIdx = int(br.readBits(5)) @@ -249,7 +250,8 @@ func parseSequenceHeaderPayload(payload []byte) (*AV1SequenceHeader, error) { dmi := br.readBits(1) decoderModelInfoPresent = dmi == 1 if decoderModelInfoPresent { - br.readBits(5) // buffer_delay_length_minus_1 + // decoder_model_info() + bufferDelayLengthMinus1 = br.readBits(5) br.readBits(32) // num_units_in_decoding_tick br.readBits(5) // buffer_removal_time_length_minus_1 br.readBits(5) // frame_presentation_time_length_minus_1 @@ -271,7 +273,14 @@ func parseSequenceHeaderPayload(payload []byte) (*AV1SequenceHeader, error) { hdr.SeqTier = int(tier) } if decoderModelInfoPresent { - br.readBits(1) // decoder_model_present_for_this_op + decoderModelPresentForThisOp := br.readBits(1) + if decoderModelPresentForThisOp == 1 { + // operating_parameters_info(): skip decoder/encoder buffer delays and flag + n := int(bufferDelayLengthMinus1) + 1 + br.readBits(n) // decoder_buffer_delay[i] + br.readBits(n) // 
encoder_buffer_delay[i] + br.readBits(1) // low_delay_mode_flag[i] + } } if initialDisplayDelayPresent == 1 { flag := br.readBits(1) diff --git a/demux/av1_test.go b/demux/av1_test.go index 38a3271..f041512 100644 --- a/demux/av1_test.go +++ b/demux/av1_test.go @@ -630,6 +630,194 @@ func TestReducedStillPictureHeader(t *testing.T) { } } +func TestParseSequenceHeaderWithDecoderModelInfo(t *testing.T) { + t.Parallel() + + // Hand-craft a sequence header payload with decoder_model_info and + // operating_parameters_info to exercise the code path where + // decoder_model_present_for_this_op=1 requires reading additional fields. + // + // Profile 0, still_picture=0, reduced_still_picture_header=0 + // timing_info_present=1, num_units_in_display_tick=1(32b), time_scale=30(32b) + // equal_picture_interval=0 + // decoder_model_info_present_flag=1 + // buffer_delay_length_minus_1=4 (5b, so delays are 5 bits each) + // num_units_in_decoding_tick=1 (32b) + // buffer_removal_time_length_minus_1=0 (5b) + // frame_presentation_time_length_minus_1=0 (5b) + // initial_display_delay_present_flag=0 + // operating_points_cnt_minus_1=0 (5b = 00000, so 1 operating point) + // operating_point_idc[0]=0 (12b) + // seq_level_idx[0]=8 (5b=01000); level > 7, so seq_tier is read + // seq_tier[0]=0 (1b) + // decoder_model_present_for_this_op[0]=1 (1b) + // decoder_buffer_delay[0]=0 (5b, since buffer_delay_length_minus_1=4) + // encoder_buffer_delay[0]=0 (5b) + // low_delay_mode_flag[0]=0 (1b) + // frame_width_bits_minus_1=3 (4b=0011), frame_height_bits_minus_1=3 (4b=0011) + // max_frame_width_minus_1=15 (4b=1111), max_frame_height_minus_1=15 (4b=1111) + // frame_id_numbers_present=0 + // use_128x128_superblock=0, enable_filter_intra=0, enable_intra_edge_filter=0 + // enable_interintra_compound=0, enable_masked_compound=0, enable_warped_motion=0 + // enable_dual_filter=0, enable_order_hint=0 + // seq_choose_screen_content_tools=1, seq_choose_integer_mv=1 + // enable_superres=0, enable_cdef=0, 
enable_restoration=0 + // color_config: high_bitdepth=0, mono_chrome=0, color_description_present=0 + // color_range=0, chroma_sample_position=00, separate_uv_delta_q=0 + + // Bit-by-bit layout: + // [0-2] seq_profile: 000 + // [3] still_picture: 0 + // [4] reduced_still_picture_header: 0 + // [5] timing_info_present_flag: 1 + // [6-37] num_units_in_display_tick: 00000000 00000000 00000000 00000001 + // [38-69] time_scale: 00000000 00000000 00000000 00011110 + // [70] equal_picture_interval: 0 + // [71] decoder_model_info_present_flag: 1 + // [72-76] buffer_delay_length_minus_1: 00100 (=4, so delays are 5 bits) + // [77-108] num_units_in_decoding_tick: 00000000 00000000 00000000 00000001 + // [109-113] buffer_removal_time_length_minus_1: 00000 + // [114-118] frame_presentation_time_length_minus_1: 00000 + // [119] initial_display_delay_present_flag: 0 + // [120-124] operating_points_cnt_minus_1: 00000 + // [125-136] operating_point_idc[0]: 000000000000 + // [137-141] seq_level_idx[0]: 01000 (=8) + // [142] seq_tier[0]: 0 (level>7 so tier is read) + // [143] decoder_model_present_for_this_op[0]: 1 + // [144-148] decoder_buffer_delay[0]: 00000 + // [149-153] encoder_buffer_delay[0]: 00000 + // [154] low_delay_mode_flag[0]: 0 + // [155-158] frame_width_bits_minus_1: 0011 + // [159-162] frame_height_bits_minus_1: 0011 + // [163-166] max_frame_width_minus_1: 1111 + // [167-170] max_frame_height_minus_1: 1111 + // [171] frame_id_numbers_present_flag: 0 + // [172] use_128x128_superblock: 0 + // [173] enable_filter_intra: 0 + // [174] enable_intra_edge_filter: 0 + // [175] enable_interintra_compound: 0 + // [176] enable_masked_compound: 0 + // [177] enable_warped_motion: 0 + // [178] enable_dual_filter: 0 + // [179] enable_order_hint: 0 + // [180] seq_choose_screen_content_tools: 1 (SELECT); [181] seq_choose_integer_mv: 1 + // [182] enable_superres: 0 + // [183] enable_cdef: 0 + // [184] enable_restoration: 0 + // [185] high_bitdepth: 0 + // [186] mono_chrome: 0 + // [187] color_description_present_flag: 0 + // 
[188] color_range: 0 + // [189-190] chroma_sample_position: 00 + // [191] separate_uv_delta_q: 0 + + // Convert bit positions to bytes (first nine bytes shown for reference): + // byte 0 [0-7]: 00000 1 00 = 0x04 + // byte 1 [8-15]: 00000000 + // byte 2 [16-23]: 00000000 + // byte 3 [24-31]: 00000000 + // byte 4 [32-39]: 00000100 = 0x04 (bits 32-37 = end of num_units_in_display_tick, 38-39 = start of time_scale) + // byte 5 [40-47]: 00000000 + // byte 6 [48-55]: 00000000 + // byte 7 [56-63]: 00000000 + // byte 8 [64-71]: 01111001 = 0x79 (bits 64-69 = end of time_scale, 70=equal_picture_interval=0, 71=decoder_model_info_present_flag=1) + // Hand-packing the remaining bytes is error-prone, so the table stops here. + + // Build the payload bit-by-bit with a helper instead. + bits := []uint{} // each element is 0 or 1 + + appendBits := func(val uint, n int) { + for i := n - 1; i >= 0; i-- { + bits = append(bits, (val>>uint(i))&1) + } + } + + appendBits(0, 3) // seq_profile = 0 + appendBits(0, 1) // still_picture = 0 + appendBits(0, 1) // reduced_still_picture_header = 0 + appendBits(1, 1) // timing_info_present_flag = 1 + appendBits(1, 32) // num_units_in_display_tick = 1 + appendBits(30, 32) // time_scale = 30 + appendBits(0, 1) // equal_picture_interval = 0 + appendBits(1, 1) // decoder_model_info_present_flag = 1 + appendBits(4, 5) // buffer_delay_length_minus_1 = 4 + appendBits(1, 32) // num_units_in_decoding_tick = 1 + appendBits(0, 5) // buffer_removal_time_length_minus_1 = 0 + appendBits(0, 5) // frame_presentation_time_length_minus_1 = 0 + appendBits(0, 1) // initial_display_delay_present_flag = 0 + appendBits(0, 5) // operating_points_cnt_minus_1 = 0 + appendBits(0, 12) // operating_point_idc[0] = 0 + appendBits(8, 5) // seq_level_idx[0] = 8 + appendBits(0, 1) // seq_tier[0] = 0 (read because level > 7) + appendBits(1, 1) // decoder_model_present_for_this_op[0] = 1 + appendBits(0, 5) // decoder_buffer_delay[0] = 0 (5 bits) + appendBits(0, 5) // encoder_buffer_delay[0] = 0 (5 bits) + appendBits(0, 1) // low_delay_mode_flag[0] = 0 + appendBits(3, 4) // frame_width_bits_minus_1 = 3 + appendBits(3, 4) // frame_height_bits_minus_1 = 3 + appendBits(15, 4) // 
max_frame_width_minus_1 = 15 (width=16) + appendBits(15, 4) // max_frame_height_minus_1 = 15 (height=16) + appendBits(0, 1) // frame_id_numbers_present_flag = 0 + appendBits(0, 1) // use_128x128_superblock = 0 + appendBits(0, 1) // enable_filter_intra = 0 + appendBits(0, 1) // enable_intra_edge_filter = 0 + appendBits(0, 1) // enable_interintra_compound = 0 + appendBits(0, 1) // enable_masked_compound = 0 + appendBits(0, 1) // enable_warped_motion = 0 + appendBits(0, 1) // enable_dual_filter = 0 + appendBits(0, 1) // enable_order_hint = 0 + appendBits(1, 1) // seq_choose_screen_content_tools = 1 (SELECT) + // seqForceScreenContentTools = 2 (SELECT), which is > 0 + appendBits(1, 1) // seq_choose_integer_mv = 1 (SELECT) + appendBits(0, 1) // enable_superres = 0 + appendBits(0, 1) // enable_cdef = 0 + appendBits(0, 1) // enable_restoration = 0 + appendBits(0, 1) // high_bitdepth = 0 → 8-bit + appendBits(0, 1) // mono_chrome = 0 + appendBits(0, 1) // color_description_present_flag = 0 + appendBits(0, 1) // color_range = 0 + appendBits(0, 2) // chroma_sample_position = 0 + appendBits(0, 1) // separate_uv_delta_q = 0 + + // Pack bits into bytes + payload := make([]byte, (len(bits)+7)/8) + for i, b := range bits { + if b == 1 { + payload[i/8] |= 1 << (7 - uint(i%8)) + } + } + + // Wrap in OBU header + obu := make([]byte, 0, 2+len(payload)) + obu = append(obu, 0x0a) // OBU header: type=1, has_size=1 + obu = append(obu, byte(len(payload))) // LEB128 size + obu = append(obu, payload...) 
+ + hdr, err := ParseAV1SequenceHeader(obu) + if err != nil { + t.Fatalf("ParseAV1SequenceHeader error: %v", err) + } + + if hdr.SeqProfile != 0 { + t.Errorf("SeqProfile: got %d, want 0", hdr.SeqProfile) + } + if hdr.SeqLevelIdx != 8 { + t.Errorf("SeqLevelIdx: got %d, want 8", hdr.SeqLevelIdx) + } + if hdr.SeqTier != 0 { + t.Errorf("SeqTier: got %d, want 0", hdr.SeqTier) + } + if hdr.Width != 16 { + t.Errorf("Width: got %d, want 16", hdr.Width) + } + if hdr.Height != 16 { + t.Errorf("Height: got %d, want 16", hdr.Height) + } + if hdr.BitDepth != 8 { + t.Errorf("BitDepth: got %d, want 8", hdr.BitDepth) + } +} + func TestIsAV1KeyframeFrameHeaderOBU(t *testing.T) { t.Parallel() From 4d386ae5fafa3f1aec5c665286b3e17b1463d4e7 Mon Sep 17 00:00:00 2001 From: Thomas Symborski Date: Wed, 18 Mar 2026 16:28:01 -0400 Subject: [PATCH 03/14] cleanup: remove unused OBU constants, dead bounds checks, and use value receiver in av1.go - Remove OBUMetadata, OBUTileGroup, OBURedundantFrameHeader, OBUTileList, OBUPadding (unused) - Remove redundant len(payload) < 1 checks in IsAV1Keyframe (already guarded by payloadSize check) - Change CodecString() to value receiver for consistency with SPSInfo/HEVCSPSInfo --- demux/av1.go | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/demux/av1.go b/demux/av1.go index 763741a..225dff5 100644 --- a/demux/av1.go +++ b/demux/av1.go @@ -7,15 +7,10 @@ import ( // AV1 OBU type constants as defined in AV1 spec §6.2.2. const ( - OBUSequenceHeader = 1 - OBUTemporalDelimiter = 2 - OBUFrameHeader = 3 - OBUTileGroup = 4 - OBUMetadata = 5 - OBUFrame = 6 - OBURedundantFrameHeader = 7 - OBUTileList = 8 - OBUPadding = 15 + OBUSequenceHeader = 1 + OBUTemporalDelimiter = 2 + OBUFrameHeader = 3 + OBUFrame = 6 ) // AV1 frame type constants as defined in AV1 spec §6.8.2. @@ -50,7 +45,7 @@ type AV1SequenceHeader struct { // CodecString returns the RFC 6381 codec parameter string for AV1. 
// Format: "av01.P.LLT.DD" where P=profile, LL=zero-padded level, // T=tier(M/H), DD=zero-padded bit depth. -func (s *AV1SequenceHeader) CodecString() string { +func (s AV1SequenceHeader) CodecString() string { tier := "M" if s.SeqTier == 1 { tier = "H" @@ -472,16 +467,10 @@ func IsAV1Keyframe(temporalUnit []byte) bool { // If show_existing_frame==1, this is not a new frame. // If show_existing_frame==0, next 2 bits are frame_type. payload := temporalUnit[payloadStart:] - if len(payload) < 1 { - return false - } showExisting := (payload[0] >> 7) & 1 if showExisting == 1 { return false } - if len(payload) < 1 { - return false - } frameType := (payload[0] >> 5) & 0x03 return frameType == av1FrameKey } From 145181452e64b88b7fb5073a983c1f44d2849133 Mon Sep 17 00:00:00 2001 From: Thomas Symborski Date: Wed, 18 Mar 2026 16:29:46 -0400 Subject: [PATCH 04/14] feat: add BuildAV1DecoderConfig for AV1CodecConfigurationRecord Build the 4-byte AV1 codec configuration header from a parsed sequence header OBU per the AV1-ISOBMFF spec, appending the raw OBU as configOBUs. Tests cover 8-bit 1080p, 10-bit 480p, and nil/invalid inputs. --- moq/format.go | 54 ++++++++++++++++++++ moq/format_test.go | 121 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 175 insertions(+) diff --git a/moq/format.go b/moq/format.go index 2a55d1c..0c387a9 100644 --- a/moq/format.go +++ b/moq/format.go @@ -86,6 +86,60 @@ func BuildAVCDecoderConfig(sps, pps []byte) []byte { return buf } +// BuildAV1DecoderConfig builds an AV1CodecConfigurationRecord +// (https://aomediacodec.github.io/av1-isobmff/#av1codecconfigurationrecord) +// from a raw Sequence Header OBU (including OBU header bytes). +// The sequence header OBU is appended as configOBUs. +// Returns nil if the input is nil, too short, or fails to parse. 
+func BuildAV1DecoderConfig(seqHeaderOBU []byte) []byte { + if len(seqHeaderOBU) < 2 { + return nil + } + + hdr, err := demux.ParseAV1SequenceHeader(seqHeaderOBU) + if err != nil { + return nil + } + + buf := make([]byte, 0, 4+len(seqHeaderOBU)) + + // Byte 0: marker(1)=1 | version(7)=1 + buf = append(buf, 0x81) + + // Byte 1: seq_profile(3) | seq_level_idx_0(5) + buf = append(buf, byte(hdr.SeqProfile&0x07)<<5|byte(hdr.SeqLevelIdx&0x1F)) + + // Byte 2: seq_tier_0(1) | high_bitdepth(1) | twelve_bit(1) | monochrome(1) | + // chroma_subsampling_x(1) | chroma_subsampling_y(1) | chroma_sample_position(2)=0 + var highBitDepth, twelveBit byte + if hdr.BitDepth >= 10 { + highBitDepth = 1 + } + if hdr.BitDepth == 12 { + twelveBit = 1 + } + var mono byte + if hdr.Monochrome { + mono = 1 + } + b2 := byte(hdr.SeqTier&0x01)<<7 | + highBitDepth<<6 | + twelveBit<<5 | + mono<<4 | + byte(hdr.ChromaSubsamplingX&0x01)<<3 | + byte(hdr.ChromaSubsamplingY&0x01)<<2 + // chroma_sample_position (2 bits) = 0 + buf = append(buf, b2) + + // Byte 3: reserved(3)=0 | initial_presentation_delay_present(1)=0 | reserved(4)=0 + buf = append(buf, 0x00) + + // configOBUs: the sequence header OBU itself + buf = append(buf, seqHeaderOBU...) + + return buf +} + // BuildHEVCDecoderConfig builds an HEVCDecoderConfigurationRecord // (ISO 14496-15 §8.3.3.1.2) from raw VPS, SPS, and PPS NAL data // (without start codes). The SPS must include the 2-byte NAL header. 
diff --git a/moq/format_test.go b/moq/format_test.go index c9f26ad..f79f3e8 100644 --- a/moq/format_test.go +++ b/moq/format_test.go @@ -279,6 +279,127 @@ func TestBuildHEVCDecoderConfig(t *testing.T) { } } +func TestBuildAV1DecoderConfig(t *testing.T) { + t.Parallel() + + // 1920x1080, profile 0, level 8, 8-bit, 4:2:0 (from demux/av1_test.go) + seqHeaderOBU := []byte{ + 0x0a, 0x0b, // OBU header (type=1, has_size=1) + LEB128 size=11 + 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71, + 0x2b, 0xe4, 0x01, + } + + config := BuildAV1DecoderConfig(seqHeaderOBU) + if config == nil { + t.Fatal("expected non-nil config") + } + + // Total length = 4-byte header + len(seqHeaderOBU) + expectedLen := 4 + len(seqHeaderOBU) + if len(config) != expectedLen { + t.Fatalf("total length: got %d, want %d", len(config), expectedLen) + } + + // Byte 0: marker(1)=1 | version(7)=1 → 0x81 + if config[0] != 0x81 { + t.Errorf("byte 0 (marker+version): got 0x%02x, want 0x81", config[0]) + } + + // Byte 1: seq_profile(3)=0 | seq_level_idx_0(5)=8 → 0x08 + if config[1] != 0x08 { + t.Errorf("byte 1 (profile|level): got 0x%02x, want 0x08", config[1]) + } + + // Verify profile extracted correctly (top 3 bits) + profile := config[1] >> 5 + if profile != 0 { + t.Errorf("profile: got %d, want 0", profile) + } + + // Verify level extracted correctly (bottom 5 bits) + level := config[1] & 0x1F + if level != 8 { + t.Errorf("level: got %d, want 8", level) + } + + // Byte 2: tier=0, high_bitdepth=0, twelve_bit=0, mono=0, csx=1, csy=1, csp=0 → 0x0C + if config[2] != 0x0C { + t.Errorf("byte 2 (tier/bd/chroma): got 0x%02x, want 0x0C", config[2]) + } + + // Byte 3: reserved = 0x00 + if config[3] != 0x00 { + t.Errorf("byte 3 (reserved): got 0x%02x, want 0x00", config[3]) + } + + // configOBUs (bytes 4+) must match input seqHeaderOBU + if !bytes.Equal(config[4:], seqHeaderOBU) { + t.Error("configOBUs do not match input seqHeaderOBU") + } +} + +func TestBuildAV1DecoderConfig10Bit(t *testing.T) { + t.Parallel() + + 
// 640x480, profile 0, level 4, 10-bit, 4:2:0 (from demux/av1_test.go) + seqHeaderOBU := []byte{ + 0x0a, 0x0b, // OBU header (type=1, has_size=1) + LEB128 size=11 + 0x00, 0x00, 0x00, 0x24, 0xc4, 0xff, 0xdf, 0x12, + 0xbe, 0x50, 0x10, + } + + config := BuildAV1DecoderConfig(seqHeaderOBU) + if config == nil { + t.Fatal("expected non-nil config") + } + + // Total length = 4-byte header + len(seqHeaderOBU) + expectedLen := 4 + len(seqHeaderOBU) + if len(config) != expectedLen { + t.Fatalf("total length: got %d, want %d", len(config), expectedLen) + } + + // Byte 0: marker + version + if config[0] != 0x81 { + t.Errorf("byte 0: got 0x%02x, want 0x81", config[0]) + } + + // Byte 1: profile=0 | level=4 → 0x04 + if config[1] != 0x04 { + t.Errorf("byte 1 (profile|level): got 0x%02x, want 0x04", config[1]) + } + + // Byte 2: tier=0, high_bitdepth=1, twelve_bit=0, mono=0, csx=1, csy=1, csp=0 + // = 0<<7 | 1<<6 | 0<<5 | 0<<4 | 1<<3 | 1<<2 | 0 = 0x4C + if config[2] != 0x4C { + t.Errorf("byte 2 (tier/bd/chroma): got 0x%02x, want 0x4C", config[2]) + } + + // configOBUs (bytes 4+) must match input + if !bytes.Equal(config[4:], seqHeaderOBU) { + t.Error("configOBUs do not match input seqHeaderOBU") + } +} + +func TestBuildAV1DecoderConfigNil(t *testing.T) { + t.Parallel() + + if BuildAV1DecoderConfig(nil) != nil { + t.Error("expected nil for nil input") + } + if BuildAV1DecoderConfig([]byte{}) != nil { + t.Error("expected nil for empty input") + } + if BuildAV1DecoderConfig([]byte{0x0a}) != nil { + t.Error("expected nil for truncated input") + } + + // Wrong OBU type (OBU_FRAME instead of SEQUENCE_HEADER) + if BuildAV1DecoderConfig([]byte{0x32, 0x01, 0x00}) != nil { + t.Error("expected nil for wrong OBU type") + } +} + func TestBuildHEVCDecoderConfigNil(t *testing.T) { t.Parallel() if BuildHEVCDecoderConfig(nil, []byte{0x42, 0x01, 0x01, 0x01}, []byte{0x44}) != nil { From e311ddfa961efa71750eea053cfb372da1cd5e49 Mon Sep 17 00:00:00 2001 From: Thomas Symborski Date: Wed, 18 Mar 2026 
16:32:41 -0400 Subject: [PATCH 05/14] feat: wire AV1 codec support into MoQ writer and pipeline Add AV1 as a recognized codec in WriteVideoFrame() and buildVideoInfo(), alongside existing H.264/H.265 handling. The MoQ writer now emits AV1CodecConfigurationRecord on keyframes, and the pipeline extracts resolution and codec string from AV1 sequence header OBUs. AV1 frames use raw OBU pass-through (no AnnexB conversion needed). --- distribution/moq_writer.go | 17 ++-- distribution/moq_writer_test.go | 136 ++++++++++++++++++++++++++++++++ pipeline/pipeline.go | 16 +++- pipeline/pipeline_test.go | 49 ++++++++++++ 4 files changed, 211 insertions(+), 7 deletions(-) diff --git a/distribution/moq_writer.go b/distribution/moq_writer.go index 71cc942..b176f07 100644 --- a/distribution/moq_writer.go +++ b/distribution/moq_writer.go @@ -88,12 +88,19 @@ func (m *moqWriter) WriteVideoFrame(w io.Writer, frame *media.VideoFrame) (int64 } // Video Config on keyframes (ID 13, odd → length-prefixed bytes) - if frame.IsKeyframe && frame.SPS != nil && frame.PPS != nil { + if frame.IsKeyframe && frame.SPS != nil { var configData []byte - if frame.Codec == "h265" && frame.VPS != nil { - configData = moq.BuildHEVCDecoderConfig(frame.VPS, frame.SPS, frame.PPS) - } else { - configData = moq.BuildAVCDecoderConfig(frame.SPS, frame.PPS) + switch frame.Codec { + case "av1": + configData = moq.BuildAV1DecoderConfig(frame.SPS) + case "h265": + if frame.VPS != nil && frame.PPS != nil { + configData = moq.BuildHEVCDecoderConfig(frame.VPS, frame.SPS, frame.PPS) + } + default: + if frame.PPS != nil { + configData = moq.BuildAVCDecoderConfig(frame.SPS, frame.PPS) + } } if configData != nil { exts = quicvarint.Append(exts, locExtVideoConfig) diff --git a/distribution/moq_writer_test.go b/distribution/moq_writer_test.go index 19ba36e..f4acace 100644 --- a/distribution/moq_writer_test.go +++ b/distribution/moq_writer_test.go @@ -488,6 +488,142 @@ func TestMoQWriterBytesWritten(t *testing.T) { } } +func 
TestMoQWriterVideoFrameAV1Keyframe(t *testing.T) { + t.Parallel() + w := NewMoQWriter(1, 0) + var buf bytes.Buffer + + if err := w.WriteStreamHeader(&buf, TrackIDVideo, 1, 0); err != nil { + t.Fatalf("WriteStreamHeader failed: %v", err) + } + buf.Reset() + + // AV1 sequence header OBU for 1920x1080 8-bit (from demux/av1_test.go) + seqHdrOBU := []byte{ + 0x0a, 0x0b, + 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71, + 0x2b, 0xe4, 0x01, + } + + // Raw OBU payload representing the frame data + wireData := []byte{0x32, 0x10, 0xAA, 0xBB, 0xCC, 0xDD} + + frame := &media.VideoFrame{ + PTS: 50000, + IsKeyframe: true, + Codec: "av1", + SPS: seqHdrOBU, + WireData: wireData, + } + + n, err := w.WriteVideoFrame(&buf, frame) + if err != nil { + t.Fatalf("WriteVideoFrame failed: %v", err) + } + + if n != int64(buf.Len()) { + t.Errorf("bytes written: got %d, actual buffer %d", n, buf.Len()) + } + + data := buf.Bytes() + pos := 0 + + // Object ID + objectID, nn, err := quicvarint.Parse(data[pos:]) + if err != nil { + t.Fatalf("parse object ID: %v", err) + } + if objectID != 0 { + t.Errorf("object ID: got %d, want 0", objectID) + } + pos += nn + + // Extension headers length + extLen, nn, err := quicvarint.Parse(data[pos:]) + if err != nil { + t.Fatalf("parse ext length: %v", err) + } + pos += nn + + extEnd := pos + int(extLen) + foundTimestamp := false + foundMarking := false + foundConfig := false + var configData []byte + + for pos < extEnd { + extID, nn, err := quicvarint.Parse(data[pos:]) + if err != nil { + t.Fatalf("parse ext ID: %v", err) + } + pos += nn + + if extID%2 == 0 { + val, nn, err := quicvarint.Parse(data[pos:]) + if err != nil { + t.Fatalf("parse ext value: %v", err) + } + pos += nn + + switch extID { + case locExtCaptureTimestamp: + foundTimestamp = true + if val != 50000 { + t.Errorf("capture timestamp: got %d, want 50000", val) + } + case locExtVideoFrameMarking: + foundMarking = true + if val != vfmKeyframe { + t.Errorf("video frame marking: got 0x%x, want 
0x%x", val, vfmKeyframe) + } + } + } else { + valLen, nn, err := quicvarint.Parse(data[pos:]) + if err != nil { + t.Fatalf("parse ext value length: %v", err) + } + pos += nn + + if extID == locExtVideoConfig { + foundConfig = true + configData = data[pos : pos+int(valLen)] + } + pos += int(valLen) + } + } + + if !foundTimestamp { + t.Error("missing capture timestamp extension") + } + if !foundMarking { + t.Error("missing video frame marking extension") + } + if !foundConfig { + t.Error("missing video config extension") + } + + // AV1CodecConfigurationRecord starts with marker|version = 0x81 + if len(configData) < 4 { + t.Fatalf("config data too short: %d bytes", len(configData)) + } + if configData[0] != 0x81 { + t.Errorf("AV1 config marker|version: got 0x%02x, want 0x81", configData[0]) + } + + // Payload length + payloadLen, nn, err := quicvarint.Parse(data[pos:]) + if err != nil { + t.Fatalf("parse payload length: %v", err) + } + pos += nn + + // Payload should be the raw WireData (pass-through, no AnnexB conversion) + payload := data[pos : pos+int(payloadLen)] + if !bytes.Equal(payload, wireData) { + t.Errorf("payload mismatch: got %x, want %x", payload, wireData) + } +} + func TestMoQWriterStreamHeaderSize(t *testing.T) { t.Parallel() w := NewMoQWriter(1, 0) diff --git a/pipeline/pipeline.go b/pipeline/pipeline.go index 3dd88ba..89befbe 100644 --- a/pipeline/pipeline.go +++ b/pipeline/pipeline.go @@ -228,7 +228,19 @@ func (p *Pipeline) forwardVideo(frame *media.VideoFrame) { // including decoder configuration record for the catalog. 
func (p *Pipeline) buildVideoInfo(frame *media.VideoFrame) (distribution.VideoInfo, bool) { var vi distribution.VideoInfo - if frame.Codec == "h265" { + switch frame.Codec { + case "av1": + info, err := demux.ParseAV1SequenceHeader(frame.SPS) + if err != nil { + return vi, false + } + vi = distribution.VideoInfo{ + Codec: info.CodecString(), + Width: info.Width, + Height: info.Height, + } + vi.DecoderConfig = moq.BuildAV1DecoderConfig(frame.SPS) + case "h265": info, err := demux.ParseHEVCSPS(frame.SPS) if err != nil { return vi, false @@ -241,7 +253,7 @@ func (p *Pipeline) buildVideoInfo(frame *media.VideoFrame) (distribution.VideoIn if frame.VPS != nil { vi.DecoderConfig = moq.BuildHEVCDecoderConfig(frame.VPS, frame.SPS, frame.PPS) } - } else { + default: info, err := demux.ParseSPS(frame.SPS) if err != nil { return vi, false diff --git a/pipeline/pipeline_test.go b/pipeline/pipeline_test.go index 719e33f..61f393c 100644 --- a/pipeline/pipeline_test.go +++ b/pipeline/pipeline_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/zsiec/prism/distribution" + "github.com/zsiec/prism/media" ) func TestNew(t *testing.T) { @@ -71,3 +72,51 @@ func TestDemuxStats(t *testing.T) { t.Fatal("expected non-nil DemuxStats") } } + +func TestBuildVideoInfoAV1(t *testing.T) { + t.Parallel() + + relay := distribution.NewRelay() + p := New("test-stream", strings.NewReader(""), relay) + + // AV1 sequence header OBU for 1920x1080, profile 0, level 8, 8-bit + // (same test data as demux/av1_test.go testSeqHdrOBU1080p) + seqHdrOBU := []byte{ + 0x0a, 0x0b, + 0x00, 0x00, 0x00, 0x42, 0xab, 0xbf, 0xc3, 0x71, + 0x2b, 0xe4, 0x01, + } + + frame := &media.VideoFrame{ + PTS: 0, + IsKeyframe: true, + Codec: "av1", + SPS: seqHdrOBU, + WireData: []byte{0x32, 0x10, 0xAA, 0xBB}, + } + + vi, ok := p.buildVideoInfo(frame) + if !ok { + t.Fatal("buildVideoInfo returned false") + } + + if vi.Width != 1920 { + t.Errorf("Width: got %d, want 1920", vi.Width) + } + if vi.Height != 1080 { + t.Errorf("Height: 
got %d, want 1080", vi.Height) + } + + wantPrefix := "av01." + if len(vi.Codec) < len(wantPrefix) || vi.Codec[:len(wantPrefix)] != wantPrefix { + t.Errorf("Codec: got %q, want prefix %q", vi.Codec, wantPrefix) + } + + if vi.DecoderConfig == nil { + t.Fatal("DecoderConfig is nil") + } + // AV1CodecConfigurationRecord starts with marker|version = 0x81 + if vi.DecoderConfig[0] != 0x81 { + t.Errorf("DecoderConfig[0]: got 0x%02x, want 0x81", vi.DecoderConfig[0]) + } +} From 7b5dc9391a032e4b850f5971eb58fee2f22cd339 Mon Sep 17 00:00:00 2001 From: Thomas Symborski Date: Wed, 18 Mar 2026 16:35:58 -0400 Subject: [PATCH 06/14] feat: add DASH MPD parser package for AV1-DASH ingest Add ingest/dash/ package with MPD XML parsing using go-mpd library, representation selection (by ID or highest bandwidth), segment template URL resolution, and live segment number computation. Includes tests for dynamic/static MPDs, template substitution, and edge cases. --- .gitignore | 2 + go.mod | 3 + go.sum | 7 ++ ingest/dash/mpd.go | 223 +++++++++++++++++++++++++++++++++++ ingest/dash/mpd_test.go | 255 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 490 insertions(+) create mode 100644 ingest/dash/mpd.go create mode 100644 ingest/dash/mpd_test.go diff --git a/.gitignore b/.gitignore index 0710f85..1e825cf 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,8 @@ test/harness/* *.swp *~ +docs/plans/* + # Environment .env .env.* diff --git a/go.mod b/go.mod index 784de9d..8a8402b 100644 --- a/go.mod +++ b/go.mod @@ -11,8 +11,11 @@ require ( ) require ( + github.com/Eyevinn/mp4ff v0.51.0 // indirect github.com/dunglas/httpsfv v1.1.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect + github.com/unki2aut/go-mpd v0.0.0-20250610073145-8336a8d84ee7 // indirect + github.com/unki2aut/go-xsd-types v0.0.0-20200220223938-30e5405398f8 // indirect golang.org/x/crypto v0.42.0 // indirect golang.org/x/net v0.44.0 // indirect golang.org/x/sys v0.36.0 // indirect diff --git a/go.sum b/go.sum 
index 43ddee4..c50174c 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,10 @@ +github.com/Eyevinn/mp4ff v0.51.0 h1:ZYdHFXEcB3kJkCeCHMHl/tbCm64FJsD2XOU0Sj+ME2M= +github.com/Eyevinn/mp4ff v0.51.0/go.mod h1:hJNUUqOBryLAzUW9wpCJyw2HaI+TCd2rUPhafoS5lgg= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dunglas/httpsfv v1.1.0 h1:Jw76nAyKWKZKFrpMMcL76y35tOpYHqQPzHQiwDvpe54= github.com/dunglas/httpsfv v1.1.0/go.mod h1:zID2mqw9mFsnt7YC3vYQ9/cjq30q41W+1AnDwH8TiMg= +github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= @@ -12,6 +15,10 @@ github.com/quic-go/webtransport-go v0.10.0 h1:LqXXPOXuETY5Xe8ITdGisBzTYmUOy5eSj+ github.com/quic-go/webtransport-go v0.10.0/go.mod h1:LeGIXr5BQKE3UsynwVBeQrU1TPrbh73MGoC6jd+V7ow= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/unki2aut/go-mpd v0.0.0-20250610073145-8336a8d84ee7 h1:FsfEp+/I/2p/WObIlPlB10kITDkOjzi/IDqM9GYpYJc= +github.com/unki2aut/go-mpd v0.0.0-20250610073145-8336a8d84ee7/go.mod h1:LITqXLCxxmcoHtOMgZh5NbcfS4RCrrADQXPVkYwF/cc= +github.com/unki2aut/go-xsd-types v0.0.0-20200220223938-30e5405398f8 h1:u0Bi6Mf8BKPQnxGJ7QubdMyhb0SJjnQU7kX0BA9eASk= +github.com/unki2aut/go-xsd-types v0.0.0-20200220223938-30e5405398f8/go.mod h1:uIeMfpmWIZ8SGp+fTfwDBWiiRn3aJm4b7rFSro9s++Q= github.com/zsiec/ccx v0.2.0 h1:RaCC4a0ng9wa6AUvT9Kg9kfiEy9svfXL5mmRbi6Ykmo= github.com/zsiec/ccx v0.2.0/go.mod h1:Y30W1TCZX7HAXM0miCzG18kSyKtJvqbOk7zrFcTGpU4= github.com/zsiec/srtgo v0.2.4 
h1:WzQfUMSiQglWJDilcXgFvW/23IGLr7FXNRdrbNKJyS8= diff --git a/ingest/dash/mpd.go b/ingest/dash/mpd.go new file mode 100644 index 0000000..395c68f --- /dev/null +++ b/ingest/dash/mpd.go @@ -0,0 +1,223 @@ +package dash + +import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/unki2aut/go-mpd" +) + +type mpdInfo struct { + IsDynamic bool + AvailabilityStartTime time.Time + MinUpdatePeriod time.Duration + BaseURL string + VideoAdaptations []adaptationInfo + AudioAdaptations []adaptationInfo +} + +type adaptationInfo struct { + MimeType string + Representations []representationInfo +} + +type representationInfo struct { + ID string + Bandwidth int + Width int + Height int + Codecs string + SegmentTemplate segmentTemplate +} + +type segmentTemplate struct { + InitializationPattern string + MediaPattern string + StartNumber int + Timescale int + Duration int +} + +// parseMPD decodes MPD XML and extracts video/audio adaptation sets with +// their representations and segment templates. 
+func parseMPD(data []byte) (*mpdInfo, error) { + var m mpd.MPD + if err := m.Decode(data); err != nil { + return nil, fmt.Errorf("decode MPD: %w", err) + } + + info := &mpdInfo{} + + // Type + if m.Type != nil && *m.Type == "dynamic" { + info.IsDynamic = true + } + + // AvailabilityStartTime + if m.AvailabilityStartTime != nil { + info.AvailabilityStartTime = time.Time(*m.AvailabilityStartTime) + } + + // MinimumUpdatePeriod + if m.MinimumUpdatePeriod != nil { + ns, err := m.MinimumUpdatePeriod.ToNanoseconds() + if err == nil { + info.MinUpdatePeriod = time.Duration(ns) + } + } + + // BaseURL (take first if present) + if len(m.BaseURL) > 0 { + info.BaseURL = m.BaseURL[0].Value + } + + // Walk periods and adaptation sets + for _, period := range m.Period { + if period == nil { + continue + } + for _, as := range period.AdaptationSets { + if as == nil { + continue + } + ai := adaptationInfo{ + MimeType: as.MimeType, + } + for _, rep := range as.Representations { + ri := representationInfo{} + if rep.ID != nil { + ri.ID = *rep.ID + } + if rep.Bandwidth != nil { + ri.Bandwidth = int(*rep.Bandwidth) + } + if rep.Width != nil { + ri.Width = int(*rep.Width) + } + if rep.Height != nil { + ri.Height = int(*rep.Height) + } + if rep.Codecs != nil { + ri.Codecs = *rep.Codecs + } else if as.Codecs != nil { + ri.Codecs = *as.Codecs + } + + // SegmentTemplate: check representation level first, then adaptation set level + st := rep.SegmentTemplate + if st == nil { + st = as.SegmentTemplate + } + if st != nil { + ri.SegmentTemplate = extractSegmentTemplate(st) + } + + ai.Representations = append(ai.Representations, ri) + } + + if strings.HasPrefix(as.MimeType, "video") { + info.VideoAdaptations = append(info.VideoAdaptations, ai) + } else if strings.HasPrefix(as.MimeType, "audio") { + info.AudioAdaptations = append(info.AudioAdaptations, ai) + } + } + } + + return info, nil +} + +func extractSegmentTemplate(st *mpd.SegmentTemplate) segmentTemplate { + t := segmentTemplate{} + if 
st.Initialization != nil { + t.InitializationPattern = *st.Initialization + } + if st.Media != nil { + t.MediaPattern = *st.Media + } + if st.StartNumber != nil { + t.StartNumber = int(*st.StartNumber) + } + if st.Timescale != nil { + t.Timescale = int(*st.Timescale) + } + if st.Duration != nil { + t.Duration = int(*st.Duration) + } + return t +} + +// selectRepresentation picks a representation from the given adaptation sets. +// If repID is empty, the highest bandwidth representation is selected. +// If repID is given, an exact match is required. +func selectRepresentation(adaptations []adaptationInfo, repID string) (representationInfo, segmentTemplate, error) { + if len(adaptations) == 0 { + return representationInfo{}, segmentTemplate{}, fmt.Errorf("no adaptation sets available") + } + + if repID == "" { + // Select highest bandwidth across all adaptations + var best representationInfo + found := false + for _, a := range adaptations { + for _, r := range a.Representations { + if !found || r.Bandwidth > best.Bandwidth { + best = r + found = true + } + } + } + if !found { + return representationInfo{}, segmentTemplate{}, fmt.Errorf("no representations available") + } + return best, best.SegmentTemplate, nil + } + + // Exact match + for _, a := range adaptations { + for _, r := range a.Representations { + if r.ID == repID { + return r, r.SegmentTemplate, nil + } + } + } + return representationInfo{}, segmentTemplate{}, fmt.Errorf("representation %q not found", repID) +} + +// resolveInitURL replaces $RepresentationID$ in the initialization pattern +// and prepends the base URL. +func resolveInitURL(tmpl segmentTemplate, repID, baseURL string) string { + s := strings.ReplaceAll(tmpl.InitializationPattern, "$RepresentationID$", repID) + return baseURL + s +} + +// resolveMediaURL replaces $RepresentationID$ and $Number$ in the media +// pattern and prepends the base URL. 
+func resolveMediaURL(tmpl segmentTemplate, repID string, number int, baseURL string) string { + s := strings.ReplaceAll(tmpl.MediaPattern, "$RepresentationID$", repID) + s = strings.ReplaceAll(s, "$Number$", strconv.Itoa(number)) + return baseURL + s +} + +// computeSegmentNumber computes the latest fully-available segment number +// (one behind the live edge) based on the current time, availability start +// time, and the segment template parameters. +func computeSegmentNumber(now time.Time, ast time.Time, tmpl segmentTemplate) int { + if tmpl.Timescale == 0 || tmpl.Duration == 0 { + return tmpl.StartNumber + } + elapsed := now.Sub(ast).Seconds() + segmentDuration := float64(tmpl.Duration) / float64(tmpl.Timescale) + if elapsed < segmentDuration { + return tmpl.StartNumber + } + // Total segments elapsed from AST + totalSegments := int(elapsed / segmentDuration) + // One behind live edge + liveSegment := totalSegments - 1 + if liveSegment < 0 { + liveSegment = 0 + } + return tmpl.StartNumber + liveSegment +} diff --git a/ingest/dash/mpd_test.go b/ingest/dash/mpd_test.go new file mode 100644 index 0000000..4cd9c49 --- /dev/null +++ b/ingest/dash/mpd_test.go @@ -0,0 +1,255 @@ +package dash + +import ( + "testing" + "time" +) + +const testMPDDynamic = ` + + + + + + + + + + + + + + + + +` + +const testMPDStatic = ` + + + + + + + + +` + +func TestParseMPD(t *testing.T) { + info, err := parseMPD([]byte(testMPDDynamic)) + if err != nil { + t.Fatalf("parseMPD: %v", err) + } + + if !info.IsDynamic { + t.Error("expected IsDynamic=true") + } + + expectedAST := time.Date(2026, 3, 18, 0, 0, 0, 0, time.UTC) + if !info.AvailabilityStartTime.Equal(expectedAST) { + t.Errorf("AvailabilityStartTime = %v, want %v", info.AvailabilityStartTime, expectedAST) + } + + if info.MinUpdatePeriod != 2*time.Second { + t.Errorf("MinUpdatePeriod = %v, want 2s", info.MinUpdatePeriod) + } + + // Video adaptations + if len(info.VideoAdaptations) != 1 { + t.Fatalf("video adaptation count = %d, want 
1", len(info.VideoAdaptations)) + } + va := info.VideoAdaptations[0] + if len(va.Representations) != 2 { + t.Fatalf("video rep count = %d, want 2", len(va.Representations)) + } + if va.Representations[0].ID != "1080p" { + t.Errorf("video rep[0] ID = %q, want 1080p", va.Representations[0].ID) + } + if va.Representations[1].ID != "720p" { + t.Errorf("video rep[1] ID = %q, want 720p", va.Representations[1].ID) + } + if va.Representations[0].Width != 1920 || va.Representations[0].Height != 1080 { + t.Errorf("video rep[0] dims = %dx%d, want 1920x1080", + va.Representations[0].Width, va.Representations[0].Height) + } + if va.Representations[0].Codecs != "av01.0.09M.08" { + t.Errorf("video rep[0] codecs = %q, want av01.0.09M.08", va.Representations[0].Codecs) + } + + // Audio adaptations + if len(info.AudioAdaptations) != 1 { + t.Fatalf("audio adaptation count = %d, want 1", len(info.AudioAdaptations)) + } + aa := info.AudioAdaptations[0] + if len(aa.Representations) != 1 { + t.Fatalf("audio rep count = %d, want 1", len(aa.Representations)) + } + if aa.Representations[0].ID != "aac-128k" { + t.Errorf("audio rep[0] ID = %q, want aac-128k", aa.Representations[0].ID) + } + + // Segment template on video rep + st := va.Representations[0].SegmentTemplate + if st.Timescale != 90000 { + t.Errorf("timescale = %d, want 90000", st.Timescale) + } + if st.Duration != 180000 { + t.Errorf("duration = %d, want 180000", st.Duration) + } + if st.StartNumber != 1 { + t.Errorf("startNumber = %d, want 1", st.StartNumber) + } +} + +func TestSelectRepresentation(t *testing.T) { + info, err := parseMPD([]byte(testMPDDynamic)) + if err != nil { + t.Fatalf("parseMPD: %v", err) + } + + t.Run("explicit ID", func(t *testing.T) { + rep, tmpl, err := selectRepresentation(info.VideoAdaptations, "720p") + if err != nil { + t.Fatalf("selectRepresentation: %v", err) + } + if rep.ID != "720p" { + t.Errorf("ID = %q, want 720p", rep.ID) + } + if rep.Bandwidth != 2000000 { + t.Errorf("Bandwidth = %d, want 
2000000", rep.Bandwidth) + } + if tmpl.Timescale != 90000 { + t.Errorf("Timescale = %d, want 90000", tmpl.Timescale) + } + }) + + t.Run("empty ID selects highest bandwidth", func(t *testing.T) { + rep, _, err := selectRepresentation(info.VideoAdaptations, "") + if err != nil { + t.Fatalf("selectRepresentation: %v", err) + } + if rep.ID != "1080p" { + t.Errorf("ID = %q, want 1080p", rep.ID) + } + if rep.Bandwidth != 4000000 { + t.Errorf("Bandwidth = %d, want 4000000", rep.Bandwidth) + } + }) + + t.Run("nonexistent ID", func(t *testing.T) { + _, _, err := selectRepresentation(info.VideoAdaptations, "4k") + if err == nil { + t.Error("expected error for nonexistent ID") + } + }) +} + +func TestResolveInitURL(t *testing.T) { + tmpl := segmentTemplate{ + InitializationPattern: "video/$RepresentationID$/init.mp4", + MediaPattern: "video/$RepresentationID$/seg-$Number$.m4s", + StartNumber: 1, + Timescale: 90000, + Duration: 180000, + } + + got := resolveInitURL(tmpl, "1080p", "") + want := "video/1080p/init.mp4" + if got != want { + t.Errorf("resolveInitURL = %q, want %q", got, want) + } + + got = resolveInitURL(tmpl, "720p", "https://cdn.example.com/") + want = "https://cdn.example.com/video/720p/init.mp4" + if got != want { + t.Errorf("resolveInitURL with baseURL = %q, want %q", got, want) + } +} + +func TestResolveMediaURL(t *testing.T) { + tmpl := segmentTemplate{ + InitializationPattern: "video/$RepresentationID$/init.mp4", + MediaPattern: "video/$RepresentationID$/seg-$Number$.m4s", + StartNumber: 1, + Timescale: 90000, + Duration: 180000, + } + + got := resolveMediaURL(tmpl, "1080p", 42, "") + want := "video/1080p/seg-42.m4s" + if got != want { + t.Errorf("resolveMediaURL = %q, want %q", got, want) + } + + got = resolveMediaURL(tmpl, "720p", 1, "https://cdn.example.com/") + want = "https://cdn.example.com/video/720p/seg-1.m4s" + if got != want { + t.Errorf("resolveMediaURL with baseURL = %q, want %q", got, want) + } +} + +func TestComputeSegmentNumber(t *testing.T) 
{ + ast := time.Date(2026, 3, 18, 0, 0, 0, 0, time.UTC) + tmpl := segmentTemplate{ + StartNumber: 1, + Timescale: 90000, + Duration: 180000, // 2 seconds per segment + } + + tests := []struct { + name string + offset time.Duration + want int + }{ + {"before first segment", 1 * time.Second, 1}, + {"exactly one segment", 2 * time.Second, 1}, + {"two segments elapsed", 4 * time.Second, 1 + 1}, + {"ten segments elapsed", 20 * time.Second, 1 + 9}, + {"one hour", 1 * time.Hour, 1 + 1799}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + now := ast.Add(tt.offset) + got := computeSegmentNumber(now, ast, tmpl) + if got != tt.want { + t.Errorf("computeSegmentNumber(offset=%v) = %d, want %d", tt.offset, got, tt.want) + } + }) + } +} + +func TestParseMPDStatic(t *testing.T) { + info, err := parseMPD([]byte(testMPDStatic)) + if err != nil { + t.Fatalf("parseMPD: %v", err) + } + if info.IsDynamic { + t.Error("expected IsDynamic=false for static MPD") + } + if len(info.VideoAdaptations) != 1 { + t.Errorf("video adaptation count = %d, want 1", len(info.VideoAdaptations)) + } +} + +func TestParseMPDInvalid(t *testing.T) { + _, err := parseMPD([]byte("this is not xml at all {{{")) + if err == nil { + t.Error("expected error for invalid XML") + } +} From 0d211582fccbb849b511c5211216a09b9e34fc9b Mon Sep 17 00:00:00 2001 From: Thomas Symborski Date: Wed, 18 Mar 2026 16:42:30 -0400 Subject: [PATCH 07/14] feat: add fMP4 segment parser for DASH AV1 ingest Parse fMP4 init segments to extract AV1/AAC codec configuration, track metadata, and trex boxes. Parse media segments to extract individual samples with microsecond-precision PTS/DTS timestamps and keyframe flags using the mp4ff library. 
--- go.mod | 4 +- go.sum | 9 + ingest/dash/segment.go | 232 +++++++++++++++++++++ ingest/dash/segment_test.go | 396 ++++++++++++++++++++++++++++++++++++ 4 files changed, 639 insertions(+), 2 deletions(-) create mode 100644 ingest/dash/segment.go create mode 100644 ingest/dash/segment_test.go diff --git a/go.mod b/go.mod index 8a8402b..0731c5e 100644 --- a/go.mod +++ b/go.mod @@ -3,18 +3,18 @@ module github.com/zsiec/prism go 1.24.3 require ( + github.com/Eyevinn/mp4ff v0.51.0 github.com/quic-go/quic-go v0.59.0 github.com/quic-go/webtransport-go v0.10.0 + github.com/unki2aut/go-mpd v0.0.0-20250610073145-8336a8d84ee7 github.com/zsiec/ccx v0.2.0 github.com/zsiec/srtgo v0.2.4 golang.org/x/sync v0.17.0 ) require ( - github.com/Eyevinn/mp4ff v0.51.0 // indirect github.com/dunglas/httpsfv v1.1.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect - github.com/unki2aut/go-mpd v0.0.0-20250610073145-8336a8d84ee7 // indirect github.com/unki2aut/go-xsd-types v0.0.0-20200220223938-30e5405398f8 // indirect golang.org/x/crypto v0.42.0 // indirect golang.org/x/net v0.44.0 // indirect diff --git a/go.sum b/go.sum index c50174c..851b3e5 100644 --- a/go.sum +++ b/go.sum @@ -4,7 +4,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dunglas/httpsfv v1.1.0 h1:Jw76nAyKWKZKFrpMMcL76y35tOpYHqQPzHQiwDvpe54= github.com/dunglas/httpsfv v1.1.0/go.mod h1:zID2mqw9mFsnt7YC3vYQ9/cjq30q41W+1AnDwH8TiMg= +github.com/go-test/deep v1.1.0 h1:WOcxcdHcvdgThNXjw0t76K42FXTU7HpNQWHpA2HHNlg= github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod 
h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= @@ -13,6 +18,8 @@ github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SA github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= github.com/quic-go/webtransport-go v0.10.0 h1:LqXXPOXuETY5Xe8ITdGisBzTYmUOy5eSj+9n4hLTjHI= github.com/quic-go/webtransport-go v0.10.0/go.mod h1:LeGIXr5BQKE3UsynwVBeQrU1TPrbh73MGoC6jd+V7ow= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/unki2aut/go-mpd v0.0.0-20250610073145-8336a8d84ee7 h1:FsfEp+/I/2p/WObIlPlB10kITDkOjzi/IDqM9GYpYJc= @@ -35,5 +42,7 @@ golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/ingest/dash/segment.go b/ingest/dash/segment.go new file mode 100644 index 0000000..eadf175 --- /dev/null +++ b/ingest/dash/segment.go @@ -0,0 
+1,232 @@ +package dash + +import ( + "bytes" + "fmt" + + "github.com/Eyevinn/mp4ff/mp4" + "github.com/zsiec/prism/demux" +) + +// initSegmentInfo holds codec and track metadata extracted from an fMP4 init +// segment (ftyp + moov). +type initSegmentInfo struct { + VideoCodec string // "av1", "h264", etc. + Width int + Height int + SeqHeaderOBU []byte // AV1: sequence header OBU from av1C ConfigOBUs + AudioCodec string // e.g. "mp4a.40.2" + SampleRate int + Channels int + VideoTimescale uint32 + AudioTimescale uint32 + VideoTrackID uint32 + AudioTrackID uint32 + videoTrex *mp4.TrexBox // stored for GetFullSamples + audioTrex *mp4.TrexBox +} + +// mediaSample is a decoded audio or video sample with timestamps in +// microseconds. +type mediaSample struct { + PTS int64 // microseconds + DTS int64 // microseconds + Data []byte + IsKeyframe bool +} + +// parseInitSegment decodes an fMP4 init segment (ftyp+moov) and extracts +// video/audio track metadata including codec configuration. +func parseInitSegment(data []byte) (*initSegmentInfo, error) { + if len(data) == 0 { + return nil, fmt.Errorf("empty init segment data") + } + + parsed, err := mp4.DecodeFile(bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("decode init segment: %w", err) + } + + // The moov box can be at parsed.Moov (when DecodeFile sees ftyp+moov) + // or at parsed.Init.Moov (when it's recognized as a fragmented init). 
+ moov := parsed.Moov + if moov == nil && parsed.Init != nil { + moov = parsed.Init.Moov + } + if moov == nil { + return nil, fmt.Errorf("no moov box found in init segment") + } + + if len(moov.Traks) == 0 { + return nil, fmt.Errorf("no tracks found in init segment") + } + + info := &initSegmentInfo{} + + for _, trak := range moov.Traks { + if trak.Mdia == nil || trak.Mdia.Hdlr == nil { + continue + } + + trackID := trak.Tkhd.TrackID + handler := trak.Mdia.Hdlr.HandlerType + + switch handler { + case "vide": + info.VideoTrackID = trackID + if trak.Mdia.Mdhd != nil { + info.VideoTimescale = trak.Mdia.Mdhd.Timescale + } + if err := extractVideoInfo(trak, info); err != nil { + return nil, fmt.Errorf("extract video info: %w", err) + } + case "soun": + info.AudioTrackID = trackID + if trak.Mdia.Mdhd != nil { + info.AudioTimescale = trak.Mdia.Mdhd.Timescale + } + if err := extractAudioInfo(trak, info); err != nil { + return nil, fmt.Errorf("extract audio info: %w", err) + } + } + } + + // Resolve trex boxes from mvex. + if moov.Mvex != nil { + resolveTrex(moov.Mvex, info) + } + + return info, nil +} + +// extractVideoInfo populates video fields on info from the given video trak. 
+func extractVideoInfo(trak *mp4.TrakBox, info *initSegmentInfo) error { + stsd := getStsd(trak) + if stsd == nil { + return fmt.Errorf("no stsd box in video track") + } + + if stsd.Av01 != nil { + info.VideoCodec = "av1" + info.Width = int(stsd.Av01.Width) + info.Height = int(stsd.Av01.Height) + if stsd.Av01.Av1C != nil { + configOBUs := stsd.Av01.Av1C.ConfigOBUs + seqHdr := demux.FindSequenceHeaderOBU(configOBUs) + if seqHdr != nil { + info.SeqHeaderOBU = make([]byte, len(seqHdr)) + copy(info.SeqHeaderOBU, seqHdr) + } + } + } else if stsd.AvcX != nil { + info.VideoCodec = "h264" + info.Width = int(stsd.AvcX.Width) + info.Height = int(stsd.AvcX.Height) + } else { + return fmt.Errorf("unsupported video codec (no av01 or avcX in stsd)") + } + return nil +} + +// extractAudioInfo populates audio fields on info from the given audio trak. +func extractAudioInfo(trak *mp4.TrakBox, info *initSegmentInfo) error { + stsd := getStsd(trak) + if stsd == nil { + return fmt.Errorf("no stsd box in audio track") + } + + if stsd.Mp4a != nil { + info.AudioCodec = "mp4a.40.2" + info.SampleRate = int(stsd.Mp4a.SampleRate) + info.Channels = int(stsd.Mp4a.ChannelCount) + } else { + return fmt.Errorf("unsupported audio codec (no mp4a in stsd)") + } + return nil +} + +// getStsd returns the stsd box from a track, or nil if the path is incomplete. +func getStsd(trak *mp4.TrakBox) *mp4.StsdBox { + if trak.Mdia == nil || trak.Mdia.Minf == nil || trak.Mdia.Minf.Stbl == nil { + return nil + } + return trak.Mdia.Minf.Stbl.Stsd +} + +// resolveTrex finds TrexBox entries matching video and audio track IDs. +func resolveTrex(mvex *mp4.MvexBox, info *initSegmentInfo) { + // mp4ff stores a single trex in Trex and multiple in Trexs. 
+ trexList := mvex.Trexs + if len(trexList) == 0 && mvex.Trex != nil { + trexList = []*mp4.TrexBox{mvex.Trex} + } + for _, trex := range trexList { + switch trex.TrackID { + case info.VideoTrackID: + info.videoTrex = trex + case info.AudioTrackID: + info.audioTrex = trex + } + } +} + +// parseMediaSegment decodes an fMP4 media segment and extracts samples for the +// specified track type. The init parameter must have been obtained from +// parseInitSegment on the corresponding init segment. +func parseMediaSegment(data []byte, init *initSegmentInfo, isVideo bool) ([]mediaSample, error) { + if len(data) == 0 { + return nil, fmt.Errorf("empty media segment data") + } + if init == nil { + return nil, fmt.Errorf("nil init segment info") + } + + parsed, err := mp4.DecodeFile(bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("decode media segment: %w", err) + } + + var trex *mp4.TrexBox + var timescale uint32 + if isVideo { + trex = init.videoTrex + timescale = init.VideoTimescale + } else { + trex = init.audioTrex + timescale = init.AudioTimescale + } + + if timescale == 0 { + return nil, fmt.Errorf("timescale is zero") + } + + var samples []mediaSample + + for _, seg := range parsed.Segments { + for _, frag := range seg.Fragments { + fullSamples, err := frag.GetFullSamples(trex) + if err != nil { + return nil, fmt.Errorf("get full samples: %w", err) + } + + for i := range fullSamples { + fs := &fullSamples[i] + ms := mediaSample{ + DTS: scaleToMicroseconds(fs.DecodeTime, timescale), + PTS: scaleToMicroseconds(uint64(fs.PresentationTime()), timescale), + Data: fs.Data, + IsKeyframe: fs.IsSync(), + } + samples = append(samples, ms) + } + } + } + + return samples, nil +} + +// scaleToMicroseconds converts a timestamp in the given timescale to +// microseconds. 
// scaleToMicroseconds converts a timestamp expressed in ticks of the given
// timescale (ticks per second) to microseconds, truncating toward zero.
//
// The conversion is done on the whole-second part and the sub-second
// remainder separately. Multiplying the raw tick count by 1_000_000 first
// would overflow uint64 for large timestamps, for example epoch-based decode
// times in live streams at 90 kHz. The remainder is always below 2^32, so its
// product with 1_000_000 cannot overflow.
//
// A zero timescale yields 0 instead of a division-by-zero panic.
func scaleToMicroseconds(ts uint64, timescale uint32) int64 {
	if timescale == 0 {
		return 0
	}
	scale := uint64(timescale)
	seconds := ts / scale
	remainder := ts % scale
	return int64(seconds*1_000_000 + remainder*1_000_000/scale)
}
// buildAV1InitSegment creates a minimal AV1+AAC init segment using mp4ff
// builder APIs and returns the encoded bytes.
//
// The video track is added first with a 90 kHz timescale and a 1920x1080
// av01 sample entry; the audio track follows with a 48 kHz timescale and an
// AAC-LC descriptor. Any builder or encoding failure aborts the test via
// t.Fatalf.
func buildAV1InitSegment(t *testing.T) []byte {
	t.Helper()

	init := mp4.CreateEmptyInit()

	// Add video track (AV1, 1920x1080, 90kHz timescale)
	videoTrak := init.AddEmptyTrack(90000, "video", "und")
	// tkhd dimensions are stored as 16.16 fixed point, hence the shift.
	videoTrak.Tkhd.Width = mp4.Fixed32(1920 << 16)
	videoTrak.Tkhd.Height = mp4.Fixed32(1080 << 16)

	// Create an Av1C box with a minimal AV1 codec config
	av1C := &mp4.Av1CBox{
		CodecConfRec: av1.CodecConfRec{
			Version:      1,
			SeqProfile:   0,
			SeqLevelIdx0: 9,
			SeqTier0:     0,
			HighBitdepth: 0,
			// ConfigOBUs carries a placeholder sequence header OBU:
			//   header byte: 0b0_0001_0_1_0 = 0x0A (OBU type 1, has_size=1)
			//   size:        one byte holding the payload length
			//   payload:     zero-filled bytes from buildMinimalSeqHeaderOBU
			// The payload is not a decodable sequence header; it only needs
			// to be locatable by the parser under test.
			ConfigOBUs: buildMinimalSeqHeaderOBU(),
		},
	}

	// Create AV1 visual sample entry and add to stsd
	av01 := mp4.CreateVisualSampleEntryBox("av01", 1920, 1080, av1C)
	videoTrak.Mdia.Minf.Stbl.Stsd.AddChild(av01)

	// Add audio track (AAC, 48kHz, stereo)
	audioTrak := init.AddEmptyTrack(48000, "audio", "und")
	err := audioTrak.SetAACDescriptor(aac.AAClc, 48000)
	if err != nil {
		t.Fatalf("SetAACDescriptor: %v", err)
	}

	var buf bytes.Buffer
	err = init.Encode(&buf)
	if err != nil {
		t.Fatalf("encode init segment: %v", err)
	}
	return buf.Bytes()
}
// buildMinimalSeqHeaderOBU returns a placeholder AV1 sequence header OBU with
// a size field: one header byte, one size byte and a zero-filled payload.
// The payload is not a decodable sequence header; it only gives
// FindSequenceHeaderOBU something of the right type and length to return.
func buildMinimalSeqHeaderOBU() []byte {
	const (
		headerByte  = 0x0A // OBU type 1 (sequence header), has_size=1
		payloadSize = 10   // fits in a single size byte
	)

	// make zero-fills the buffer, which doubles as the all-zero payload.
	obu := make([]byte, 2+payloadSize)
	obu[0] = headerByte
	obu[1] = payloadSize
	return obu
}
+func buildVideoMediaSegment(t *testing.T, trackID uint32) []byte { + t.Helper() + + seg := mp4.NewMediaSegment() + frag, err := mp4.CreateFragment(1, trackID) + if err != nil { + t.Fatalf("CreateFragment: %v", err) + } + + // Add 3 samples: first is a keyframe, others are not + sampleData := []byte{0x01, 0x02, 0x03, 0x04} + baseDecodeTime := uint64(0) + + // Keyframe sample + frag.AddFullSample(mp4.FullSample{ + Sample: mp4.Sample{ + Flags: mp4.SyncSampleFlags, + Dur: 3000, // 1 frame at 30fps in 90kHz + Size: uint32(len(sampleData)), + CompositionTimeOffset: 0, + }, + DecodeTime: baseDecodeTime, + Data: sampleData, + }) + + // Non-keyframe sample + frag.AddFullSample(mp4.FullSample{ + Sample: mp4.Sample{ + Flags: mp4.NonSyncSampleFlags, + Dur: 3000, + Size: uint32(len(sampleData)), + CompositionTimeOffset: 3000, // PTS > DTS + }, + DecodeTime: baseDecodeTime + 3000, + Data: sampleData, + }) + + // Another non-keyframe sample + frag.AddFullSample(mp4.FullSample{ + Sample: mp4.Sample{ + Flags: mp4.NonSyncSampleFlags, + Dur: 3000, + Size: uint32(len(sampleData)), + CompositionTimeOffset: 0, + }, + DecodeTime: baseDecodeTime + 6000, + Data: sampleData, + }) + + seg.AddFragment(frag) + + var buf bytes.Buffer + err = seg.Encode(&buf) + if err != nil { + t.Fatalf("encode media segment: %v", err) + } + return buf.Bytes() +} + +func TestParseMediaSegmentVideo(t *testing.T) { + // First build an init segment to get metadata + initData := buildAV1InitSegment(t) + info, err := parseInitSegment(initData) + if err != nil { + t.Fatalf("parseInitSegment: %v", err) + } + + // Build a video media segment + segData := buildVideoMediaSegment(t, info.VideoTrackID) + + samples, err := parseMediaSegment(segData, info, true) + if err != nil { + t.Fatalf("parseMediaSegment: %v", err) + } + + if len(samples) != 3 { + t.Fatalf("sample count = %d, want 3", len(samples)) + } + + // First sample: keyframe at DTS=0 + if !samples[0].IsKeyframe { + t.Error("sample[0] should be a keyframe") + } + 
if samples[0].DTS != 0 { + t.Errorf("sample[0].DTS = %d, want 0", samples[0].DTS) + } + if samples[0].PTS != 0 { + t.Errorf("sample[0].PTS = %d, want 0", samples[0].PTS) + } + if len(samples[0].Data) != 4 { + t.Errorf("sample[0].Data length = %d, want 4", len(samples[0].Data)) + } + + // Second sample: non-keyframe at DTS=3000/90000 * 1e6 = 33333us + if samples[1].IsKeyframe { + t.Error("sample[1] should not be a keyframe") + } + expectedDTS := scaleToMicroseconds(3000, 90000) + if samples[1].DTS != expectedDTS { + t.Errorf("sample[1].DTS = %d, want %d", samples[1].DTS, expectedDTS) + } + // PTS = (3000 + 3000) / 90000 * 1e6 = 66666us + expectedPTS := scaleToMicroseconds(6000, 90000) + if samples[1].PTS != expectedPTS { + t.Errorf("sample[1].PTS = %d, want %d", samples[1].PTS, expectedPTS) + } + + // Third sample: non-keyframe at DTS=6000/90000 * 1e6 = 66666us + expectedDTS3 := scaleToMicroseconds(6000, 90000) + if samples[2].DTS != expectedDTS3 { + t.Errorf("sample[2].DTS = %d, want %d", samples[2].DTS, expectedDTS3) + } +} + +func TestParseMediaSegmentZeroTimescale(t *testing.T) { + initInfo := &initSegmentInfo{ + VideoTimescale: 0, + videoTrex: &mp4.TrexBox{TrackID: 1}, + } + _, err := parseMediaSegment([]byte{0x00}, initInfo, true) + if err == nil { + t.Error("expected error for zero timescale") + } +} + +// buildAudioMediaSegment creates a minimal fMP4 media segment with audio +// samples. 
+func buildAudioMediaSegment(t *testing.T, trackID uint32) []byte { + t.Helper() + + seg := mp4.NewMediaSegment() + frag, err := mp4.CreateFragment(1, trackID) + if err != nil { + t.Fatalf("CreateFragment: %v", err) + } + + // Add 2 audio samples (AAC frames are all sync) + sampleData := make([]byte, 128) + for i := range sampleData { + sampleData[i] = byte(i) + } + + frag.AddFullSample(mp4.FullSample{ + Sample: mp4.Sample{ + Flags: mp4.SyncSampleFlags, + Dur: 1024, // typical AAC frame duration at 48kHz + Size: uint32(len(sampleData)), + }, + DecodeTime: 0, + Data: sampleData, + }) + + frag.AddFullSample(mp4.FullSample{ + Sample: mp4.Sample{ + Flags: mp4.SyncSampleFlags, + Dur: 1024, + Size: uint32(len(sampleData)), + }, + DecodeTime: 1024, + Data: sampleData, + }) + + seg.AddFragment(frag) + + var buf bytes.Buffer + err = seg.Encode(&buf) + if err != nil { + t.Fatalf("encode media segment: %v", err) + } + return buf.Bytes() +} + +func TestParseMediaSegmentAudio(t *testing.T) { + initData := buildAV1InitSegment(t) + info, err := parseInitSegment(initData) + if err != nil { + t.Fatalf("parseInitSegment: %v", err) + } + + segData := buildAudioMediaSegment(t, info.AudioTrackID) + + samples, err := parseMediaSegment(segData, info, false) + if err != nil { + t.Fatalf("parseMediaSegment: %v", err) + } + + if len(samples) != 2 { + t.Fatalf("sample count = %d, want 2", len(samples)) + } + + // First audio sample: DTS=0 + if samples[0].DTS != 0 { + t.Errorf("sample[0].DTS = %d, want 0", samples[0].DTS) + } + + // Second audio sample: DTS = 1024/48000 * 1e6 = 21333us + expectedDTS := scaleToMicroseconds(1024, 48000) + if samples[1].DTS != expectedDTS { + t.Errorf("sample[1].DTS = %d, want %d", samples[1].DTS, expectedDTS) + } +} From baab1d708f26c074301851b357c129309fd70f38 Mon Sep 17 00:00:00 2001 From: Thomas Symborski Date: Wed, 18 Mar 2026 16:45:29 -0400 Subject: [PATCH 08/14] feat: add DASH pipeline to bridge fMP4 samples to Relay fan-out Introduces dashPipeline which 
// broadcaster is the subset of distribution.Relay the DASH pipeline uses.
// It omits BroadcastCaptions because DASH sources do not carry CEA-608/708.
// Depending on this interface rather than the concrete relay also lets tests
// substitute a recording fake.
type broadcaster interface {
	// BroadcastVideo is called once per video sample by processVideoSamples.
	BroadcastVideo(frame *media.VideoFrame)
	// BroadcastAudio is called once per audio sample by processAudioSamples.
	BroadcastAudio(frame *media.AudioFrame)
	// SetVideoInfo, SetAudioTrackCount, AudioTrackCount and SetAudioInfo are
	// not called by the pipeline methods themselves; presumably the ingest
	// setup code uses them to publish track metadata — verify against caller.
	SetVideoInfo(info distribution.VideoInfo)
	SetAudioTrackCount(count int)
	AudioTrackCount() int
	SetAudioInfo(info distribution.AudioInfo)
	// ViewerCount and ViewerStatsAll feed the viewer fields of StreamSnapshot.
	ViewerCount() int
	ViewerStatsAll() []distribution.ViewerStats
}
+type dashPipeline struct { + streamKey string + relay broadcaster + seqHdrOBU []byte // AV1 sequence header OBU, prepended to keyframes as SPS + startTime time.Time + groupID uint32 + + videoForwarded atomic.Int64 + audioForwarded atomic.Int64 +} + +// newDASHPipeline creates a pipeline that converts DASH mediaSamples into +// VideoFrame/AudioFrame values and broadcasts them via relay. The seqHdrOBU +// is the AV1 sequence header OBU extracted from the init segment, attached +// to keyframes so downstream decoders can initialize. +func newDASHPipeline(streamKey string, relay broadcaster, seqHdrOBU []byte) *dashPipeline { + return &dashPipeline{ + streamKey: streamKey, + relay: relay, + seqHdrOBU: seqHdrOBU, + startTime: time.Now(), + } +} + +// processVideoSamples converts video mediaSamples into VideoFrames and +// broadcasts them through the relay. Keyframes start a new group and carry +// the AV1 sequence header OBU as SPS so that late-joining decoders can +// configure themselves. +func (p *dashPipeline) processVideoSamples(samples []mediaSample) { + for i := range samples { + s := &samples[i] + + frame := &media.VideoFrame{ + PTS: s.PTS, + DTS: s.DTS, + IsKeyframe: s.IsKeyframe, + Codec: "av1", + WireData: s.Data, + } + + if s.IsKeyframe { + p.groupID++ + frame.SPS = p.seqHdrOBU + } + frame.GroupID = p.groupID + + p.relay.BroadcastVideo(frame) + p.videoForwarded.Add(1) + } +} + +// processAudioSamples converts audio mediaSamples into AudioFrames and +// broadcasts them through the relay. +func (p *dashPipeline) processAudioSamples(samples []mediaSample) { + for i := range samples { + s := &samples[i] + + frame := &media.AudioFrame{ + PTS: s.PTS, + Data: s.Data, + } + + p.relay.BroadcastAudio(frame) + p.audioForwarded.Add(1) + } +} + +// StreamSnapshot returns a point-in-time snapshot of stream health metrics, +// implementing the distribution.StatsProvider interface for the stats overlay +// and REST API. 
+func (p *dashPipeline) StreamSnapshot() distribution.StreamSnapshot { + return distribution.StreamSnapshot{ + Timestamp: time.Now().UnixMilli(), + UptimeMs: time.Since(p.startTime).Milliseconds(), + Protocol: "DASH", + Video: distribution.VideoStats{ + Codec: "AV1", + TotalFrames: p.videoForwarded.Load(), + }, + ViewerCount: p.relay.ViewerCount(), + Viewers: p.relay.ViewerStatsAll(), + } +} diff --git a/ingest/dash/pipeline_test.go b/ingest/dash/pipeline_test.go new file mode 100644 index 0000000..9d071b5 --- /dev/null +++ b/ingest/dash/pipeline_test.go @@ -0,0 +1,266 @@ +package dash + +import ( + "testing" + "time" + + "github.com/zsiec/prism/distribution" + "github.com/zsiec/prism/media" +) + +// mockBroadcaster captures all calls made by the dashPipeline for test assertions. +type mockBroadcaster struct { + videoFrames []*media.VideoFrame + audioFrames []*media.AudioFrame + videoInfo *distribution.VideoInfo + audioInfo *distribution.AudioInfo + trackCount int +} + +func (m *mockBroadcaster) BroadcastVideo(frame *media.VideoFrame) { + m.videoFrames = append(m.videoFrames, frame) +} + +func (m *mockBroadcaster) BroadcastAudio(frame *media.AudioFrame) { + m.audioFrames = append(m.audioFrames, frame) +} + +func (m *mockBroadcaster) SetVideoInfo(info distribution.VideoInfo) { + m.videoInfo = &info +} + +func (m *mockBroadcaster) SetAudioTrackCount(count int) { + m.trackCount = count +} + +func (m *mockBroadcaster) AudioTrackCount() int { + if m.trackCount == 0 { + return 1 + } + return m.trackCount +} + +func (m *mockBroadcaster) SetAudioInfo(info distribution.AudioInfo) { + m.audioInfo = &info +} + +func (m *mockBroadcaster) ViewerCount() int { + return 0 +} + +func (m *mockBroadcaster) ViewerStatsAll() []distribution.ViewerStats { + return nil +} + +func TestDASHPipelineProcessVideoSamples(t *testing.T) { + t.Parallel() + + mock := &mockBroadcaster{} + seqHdr := []byte{0x0A, 0x0B, 0x0C} + p := newDASHPipeline("test-stream", mock, seqHdr) + + samples := 
[]mediaSample{ + {PTS: 0, DTS: 0, Data: []byte{0x01, 0x02}, IsKeyframe: true}, + {PTS: 33333, DTS: 33333, Data: []byte{0x03, 0x04}, IsKeyframe: false}, + {PTS: 66666, DTS: 66666, Data: []byte{0x05, 0x06}, IsKeyframe: false}, + } + + p.processVideoSamples(samples) + + if len(mock.videoFrames) != 3 { + t.Fatalf("expected 3 video frames, got %d", len(mock.videoFrames)) + } + + // All frames should have codec "av1". + for i, f := range mock.videoFrames { + if f.Codec != "av1" { + t.Errorf("frame %d: codec = %q, want %q", i, f.Codec, "av1") + } + } + + // First frame: keyframe with SPS set. + kf := mock.videoFrames[0] + if !kf.IsKeyframe { + t.Error("frame 0: expected IsKeyframe=true") + } + if kf.SPS == nil { + t.Error("frame 0: expected SPS to be set for keyframe") + } + if len(kf.SPS) != 3 || kf.SPS[0] != 0x0A { + t.Errorf("frame 0: SPS = %v, want %v", kf.SPS, seqHdr) + } + + // Delta frames: no SPS. + for i := 1; i < 3; i++ { + df := mock.videoFrames[i] + if df.IsKeyframe { + t.Errorf("frame %d: expected IsKeyframe=false", i) + } + if df.SPS != nil { + t.Errorf("frame %d: expected SPS=nil for delta frame, got %v", i, df.SPS) + } + } + + // GroupID should be 1 (incremented once for the keyframe). + if kf.GroupID != 1 { + t.Errorf("frame 0: GroupID = %d, want 1", kf.GroupID) + } + for i := 1; i < 3; i++ { + if mock.videoFrames[i].GroupID != 1 { + t.Errorf("frame %d: GroupID = %d, want 1", i, mock.videoFrames[i].GroupID) + } + } + + // WireData should match sample data. + for i, s := range samples { + f := mock.videoFrames[i] + if len(f.WireData) != len(s.Data) { + t.Errorf("frame %d: WireData length = %d, want %d", i, len(f.WireData), len(s.Data)) + } + for j := range s.Data { + if f.WireData[j] != s.Data[j] { + t.Errorf("frame %d: WireData[%d] = %d, want %d", i, j, f.WireData[j], s.Data[j]) + } + } + } + + // Forwarded counter. 
+ if got := p.videoForwarded.Load(); got != 3 { + t.Errorf("videoForwarded = %d, want 3", got) + } +} + +func TestDASHPipelineProcessAudioSamples(t *testing.T) { + t.Parallel() + + mock := &mockBroadcaster{} + p := newDASHPipeline("test-stream", mock, nil) + + samples := []mediaSample{ + {PTS: 0, Data: []byte{0xAA, 0xBB}}, + {PTS: 21333, Data: []byte{0xCC, 0xDD, 0xEE}}, + } + + p.processAudioSamples(samples) + + if len(mock.audioFrames) != 2 { + t.Fatalf("expected 2 audio frames, got %d", len(mock.audioFrames)) + } + + for i, s := range samples { + f := mock.audioFrames[i] + if f.PTS != s.PTS { + t.Errorf("frame %d: PTS = %d, want %d", i, f.PTS, s.PTS) + } + if len(f.Data) != len(s.Data) { + t.Errorf("frame %d: Data length = %d, want %d", i, len(f.Data), len(s.Data)) + continue + } + for j := range s.Data { + if f.Data[j] != s.Data[j] { + t.Errorf("frame %d: Data[%d] = %d, want %d", i, j, f.Data[j], s.Data[j]) + } + } + } + + if got := p.audioForwarded.Load(); got != 2 { + t.Errorf("audioForwarded = %d, want 2", got) + } +} + +func TestDASHPipelineGroupIDIncrement(t *testing.T) { + t.Parallel() + + mock := &mockBroadcaster{} + seqHdr := []byte{0x01} + p := newDASHPipeline("test-stream", mock, seqHdr) + + // First GOP: keyframe + delta. + p.processVideoSamples([]mediaSample{ + {PTS: 0, DTS: 0, IsKeyframe: true, Data: []byte{0x10}}, + {PTS: 33333, DTS: 33333, IsKeyframe: false, Data: []byte{0x11}}, + }) + + // Second GOP: another keyframe + delta. + p.processVideoSamples([]mediaSample{ + {PTS: 66666, DTS: 66666, IsKeyframe: true, Data: []byte{0x20}}, + {PTS: 99999, DTS: 99999, IsKeyframe: false, Data: []byte{0x21}}, + }) + + // Third GOP: yet another keyframe. + p.processVideoSamples([]mediaSample{ + {PTS: 133332, DTS: 133332, IsKeyframe: true, Data: []byte{0x30}}, + }) + + if len(mock.videoFrames) != 5 { + t.Fatalf("expected 5 video frames, got %d", len(mock.videoFrames)) + } + + // Group IDs: first keyframe=1, second keyframe=2, third keyframe=3. 
+ expectedGroupIDs := []uint32{1, 1, 2, 2, 3} + for i, f := range mock.videoFrames { + if f.GroupID != expectedGroupIDs[i] { + t.Errorf("frame %d: GroupID = %d, want %d", i, f.GroupID, expectedGroupIDs[i]) + } + } + + // All keyframes should have SPS. + keyframeIndices := []int{0, 2, 4} + for _, idx := range keyframeIndices { + if mock.videoFrames[idx].SPS == nil { + t.Errorf("frame %d (keyframe): expected SPS to be set", idx) + } + } + + // All delta frames should not have SPS. + deltaIndices := []int{1, 3} + for _, idx := range deltaIndices { + if mock.videoFrames[idx].SPS != nil { + t.Errorf("frame %d (delta): expected SPS=nil", idx) + } + } +} + +func TestDASHPipelineStreamSnapshot(t *testing.T) { + t.Parallel() + + mock := &mockBroadcaster{} + p := newDASHPipeline("test-stream", mock, nil) + + // Push startTime back so UptimeMs is reliably > 0. + p.startTime = p.startTime.Add(-10 * time.Millisecond) + + // Process some samples so counters are non-zero. + p.processVideoSamples([]mediaSample{ + {PTS: 0, DTS: 0, IsKeyframe: true, Data: []byte{0x01}}, + }) + p.processAudioSamples([]mediaSample{ + {PTS: 0, Data: []byte{0x02}}, + }) + + snap := p.StreamSnapshot() + + if snap.Protocol != "DASH" { + t.Errorf("Protocol = %q, want %q", snap.Protocol, "DASH") + } + + if snap.UptimeMs <= 0 { + t.Errorf("UptimeMs = %d, want > 0", snap.UptimeMs) + } + + if snap.Timestamp <= 0 { + t.Errorf("Timestamp = %d, want > 0", snap.Timestamp) + } + + if snap.Video.Codec != "AV1" { + t.Errorf("Video.Codec = %q, want %q", snap.Video.Codec, "AV1") + } + + if snap.Video.TotalFrames != 1 { + t.Errorf("Video.TotalFrames = %d, want 1", snap.Video.TotalFrames) + } + + if snap.ViewerCount != 0 { + t.Errorf("ViewerCount = %d, want 0", snap.ViewerCount) + } +} From 35acd59b839fc5ad0b3cd9a350001efd287a2540 Mon Sep 17 00:00:00 2001 From: Thomas Symborski Date: Wed, 18 Mar 2026 16:48:44 -0400 Subject: [PATCH 09/14] feat: add DASH puller orchestrator for live AV1-DASH ingest Implements the Puller 
type that manages active DASH pull connections. Pull() fetches and validates the MPD manifest synchronously, then starts a background goroutine that continuously downloads segments, parses them via parseInitSegment/parseMediaSegment, and feeds samples through the dashPipeline to the distribution Relay. Stop() cancels an active pull, and ActivePulls() returns a snapshot of all pulls. Interfaces (DistributionServer, StreamManager) decouple from concrete types. The fetchURL helper uses http.NewRequestWithContext for proper cancellation. Tests cover validation, duplicate detection, lifecycle management, HTTP fetching, and end-to-end pull/stop with httptest. --- ingest/dash/puller.go | 373 ++++++++++++++++++++++++++++++ ingest/dash/puller_test.go | 451 +++++++++++++++++++++++++++++++++++++ 2 files changed, 824 insertions(+) create mode 100644 ingest/dash/puller.go create mode 100644 ingest/dash/puller_test.go diff --git a/ingest/dash/puller.go b/ingest/dash/puller.go new file mode 100644 index 0000000..08bbc51 --- /dev/null +++ b/ingest/dash/puller.go @@ -0,0 +1,373 @@ +package dash + +import ( + "context" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/zsiec/prism/distribution" +) + +// PullRequest describes a DASH source to pull from. +type PullRequest struct { + URL string `json:"url"` + StreamKey string `json:"streamKey"` + VideoRepID string `json:"videoRepresentationId,omitempty"` + AudioRepID string `json:"audioRepresentationId,omitempty"` +} + +// DistributionServer is the subset of distribution.Server needed by the puller. +type DistributionServer interface { + RegisterStream(key string) *distribution.Relay + UnregisterStream(key string) + SetPipeline(key string, p distribution.StatsProvider) +} + +// StreamManager is the subset of stream.Manager needed by the puller. 
+type StreamManager interface { + Create(key string) (interface{}, bool) + Remove(key string) +} + +type activePull struct { + req PullRequest + cancel context.CancelFunc +} + +// Puller manages active DASH pull connections. +type Puller struct { + log *slog.Logger + client *http.Client + mu sync.Mutex + pulls map[string]*activePull +} + +// NewPuller creates a Puller with a default HTTP client. If log is nil, +// slog.Default() is used. +func NewPuller(log *slog.Logger) *Puller { + if log == nil { + log = slog.Default() + } + return &Puller{ + log: log.With("component", "dash-puller"), + client: &http.Client{ + Timeout: 30 * time.Second, + }, + pulls: make(map[string]*activePull), + } +} + +// Pull validates the request, fetches the MPD manifest synchronously, and +// starts a background goroutine that continuously fetches and processes +// DASH segments. It returns immediately after the MPD is validated. +func (p *Puller) Pull(ctx context.Context, req PullRequest, distSrv DistributionServer, mgr StreamManager) error { + if req.URL == "" { + return fmt.Errorf("url is required") + } + if req.StreamKey == "" { + return fmt.Errorf("streamKey is required") + } + + p.mu.Lock() + if _, exists := p.pulls[req.StreamKey]; exists { + p.mu.Unlock() + return fmt.Errorf("pull already active for stream key %q", req.StreamKey) + } + p.mu.Unlock() + + // Fetch and validate MPD synchronously so the caller gets immediate + // feedback on bad URLs or malformed manifests. + p.log.Info("fetching MPD", "url", req.URL, "stream_key", req.StreamKey) + mpdData, err := p.fetchURL(ctx, req.URL) + if err != nil { + return fmt.Errorf("fetch MPD: %w", err) + } + + info, err := parseMPD(mpdData) + if err != nil { + return fmt.Errorf("parse MPD: %w", err) + } + + if !info.IsDynamic { + return fmt.Errorf("MPD is not dynamic (live); static MPDs are not supported") + } + + // Select video representation. 
+ videoRep, videoTmpl, err := selectRepresentation(info.VideoAdaptations, req.VideoRepID) + if err != nil { + return fmt.Errorf("select video representation: %w", err) + } + + // Select audio representation (optional — some streams are video-only). + var audioRep representationInfo + var audioTmpl segmentTemplate + hasAudio := len(info.AudioAdaptations) > 0 + if hasAudio { + audioRep, audioTmpl, err = selectRepresentation(info.AudioAdaptations, req.AudioRepID) + if err != nil { + return fmt.Errorf("select audio representation: %w", err) + } + } + + // Double-check for duplicate before committing. + pullCtx, cancel := context.WithCancel(ctx) + + p.mu.Lock() + if _, exists := p.pulls[req.StreamKey]; exists { + p.mu.Unlock() + cancel() + return fmt.Errorf("pull already active for stream key %q", req.StreamKey) + } + p.pulls[req.StreamKey] = &activePull{req: req, cancel: cancel} + p.mu.Unlock() + + p.log.Info("starting DASH pull", + "stream_key", req.StreamKey, + "video_rep", videoRep.ID, + "video_dims", fmt.Sprintf("%dx%d", videoRep.Width, videoRep.Height), + "has_audio", hasAudio, + ) + + go p.runPullLoop(pullCtx, req, info, videoRep, videoTmpl, audioRep, audioTmpl, hasAudio, distSrv, mgr) + + return nil +} + +// Stop cancels an active pull by stream key. +func (p *Puller) Stop(streamKey string) error { + p.mu.Lock() + ap, ok := p.pulls[streamKey] + p.mu.Unlock() + + if !ok { + return fmt.Errorf("no active pull for stream key %q", streamKey) + } + + ap.cancel() + return nil +} + +// ActivePulls returns a snapshot of all active pull requests. +func (p *Puller) ActivePulls() []PullRequest { + p.mu.Lock() + defer p.mu.Unlock() + + out := make([]PullRequest, 0, len(p.pulls)) + for _, ap := range p.pulls { + out = append(out, ap.req) + } + return out +} + +// fetchURL performs an HTTP GET with context support and returns the response body. 
+func (p *Puller) fetchURL(ctx context.Context, rawURL string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + resp, err := p.client.Do(req) + if err != nil { + return nil, fmt.Errorf("HTTP GET %s: %w", rawURL, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP GET %s: status %d", rawURL, resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response body: %w", err) + } + return body, nil +} + +// deriveBaseURL computes the base URL for resolving relative segment URLs. +// If the MPD contains a BaseURL, that is used; otherwise the MPD URL's +// directory is used. +func deriveBaseURL(mpdURL string, mpdBaseURL string) string { + if mpdBaseURL != "" { + return mpdBaseURL + } + // Strip filename from MPD URL to get directory. + u, err := url.Parse(mpdURL) + if err != nil { + return "" + } + idx := strings.LastIndex(u.Path, "/") + if idx >= 0 { + u.Path = u.Path[:idx+1] + } + return u.String() +} + +// runPullLoop is the background goroutine that continuously fetches DASH +// segments and feeds them into the distribution pipeline. +func (p *Puller) runPullLoop( + ctx context.Context, + req PullRequest, + info *mpdInfo, + videoRep representationInfo, + videoTmpl segmentTemplate, + audioRep representationInfo, + audioTmpl segmentTemplate, + hasAudio bool, + distSrv DistributionServer, + mgr StreamManager, +) { + defer func() { + distSrv.UnregisterStream(req.StreamKey) + if mgr != nil { + mgr.Remove(req.StreamKey) + } + p.mu.Lock() + delete(p.pulls, req.StreamKey) + p.mu.Unlock() + p.log.Info("DASH pull ended", "stream_key", req.StreamKey) + }() + + baseURL := deriveBaseURL(req.URL, info.BaseURL) + + // Download and parse video init segment. 
+ videoInitURL := resolveInitURL(videoTmpl, videoRep.ID, baseURL) + videoInitData, err := p.fetchURL(ctx, videoInitURL) + if err != nil { + p.log.Error("fetch video init segment", "url", videoInitURL, "error", err) + return + } + + initInfo, err := parseInitSegment(videoInitData) + if err != nil { + p.log.Error("parse video init segment", "error", err) + return + } + + // If audio uses a different init URL, download that too. + if hasAudio { + audioInitURL := resolveInitURL(audioTmpl, audioRep.ID, baseURL) + if audioInitURL != videoInitURL { + audioInitData, err := p.fetchURL(ctx, audioInitURL) + if err != nil { + p.log.Warn("fetch audio init segment", "url", audioInitURL, "error", err) + hasAudio = false + } else { + audioInit, err := parseInitSegment(audioInitData) + if err != nil { + p.log.Warn("parse audio init segment", "error", err) + hasAudio = false + } else { + // Merge audio info into initInfo. + initInfo.AudioCodec = audioInit.AudioCodec + initInfo.SampleRate = audioInit.SampleRate + initInfo.Channels = audioInit.Channels + initInfo.AudioTimescale = audioInit.AudioTimescale + initInfo.AudioTrackID = audioInit.AudioTrackID + initInfo.audioTrex = audioInit.audioTrex + } + } + } + } + + // Register stream with distribution server. + relay := distSrv.RegisterStream(req.StreamKey) + if mgr != nil { + mgr.Create(req.StreamKey) + } + + // Set video info on relay. + relay.SetVideoInfo(distribution.VideoInfo{ + Codec: initInfo.VideoCodec, + Width: initInfo.Width, + Height: initInfo.Height, + }) + + // Set audio info if available. + if hasAudio { + relay.SetAudioTrackCount(1) + relay.SetAudioInfo(distribution.AudioInfo{ + Codec: initInfo.AudioCodec, + SampleRate: initInfo.SampleRate, + Channels: initInfo.Channels, + }) + } + + // Create pipeline and register as stats provider. + pipeline := newDASHPipeline(req.StreamKey, relay, initInfo.SeqHeaderOBU) + distSrv.SetPipeline(req.StreamKey, pipeline) + + // Segment fetch loop. 
+ lastSegNum := -1 + segmentDuration := time.Duration(0) + if videoTmpl.Timescale > 0 && videoTmpl.Duration > 0 { + segmentDuration = time.Duration(float64(videoTmpl.Duration) / float64(videoTmpl.Timescale) * float64(time.Second)) + } + + for { + if ctx.Err() != nil { + return + } + + segNum := computeSegmentNumber(time.Now(), info.AvailabilityStartTime, videoTmpl) + + if segNum == lastSegNum { + // Sleep until the next segment boundary. + sleepDur := segmentDuration / 4 + if sleepDur < 100*time.Millisecond { + sleepDur = 100 * time.Millisecond + } + select { + case <-ctx.Done(): + return + case <-time.After(sleepDur): + continue + } + } + lastSegNum = segNum + + // Fetch video segment. + videoMediaURL := resolveMediaURL(videoTmpl, videoRep.ID, segNum, baseURL) + videoData, err := p.fetchURL(ctx, videoMediaURL) + if err != nil { + if ctx.Err() != nil { + return + } + p.log.Warn("fetch video segment", "seg", segNum, "url", videoMediaURL, "error", err) + continue + } + + videoSamples, err := parseMediaSegment(videoData, initInfo, true) + if err != nil { + p.log.Warn("parse video segment", "seg", segNum, "error", err) + continue + } + pipeline.processVideoSamples(videoSamples) + + // Fetch audio segment. 
+ if hasAudio { + audioSegNum := computeSegmentNumber(time.Now(), info.AvailabilityStartTime, audioTmpl) + audioMediaURL := resolveMediaURL(audioTmpl, audioRep.ID, audioSegNum, baseURL) + audioData, err := p.fetchURL(ctx, audioMediaURL) + if err != nil { + if ctx.Err() != nil { + return + } + p.log.Warn("fetch audio segment", "seg", audioSegNum, "url", audioMediaURL, "error", err) + continue + } + + audioSamples, err := parseMediaSegment(audioData, initInfo, false) + if err != nil { + p.log.Warn("parse audio segment", "seg", audioSegNum, "error", err) + continue + } + pipeline.processAudioSamples(audioSamples) + } + } +} diff --git a/ingest/dash/puller_test.go b/ingest/dash/puller_test.go new file mode 100644 index 0000000..d6406ba --- /dev/null +++ b/ingest/dash/puller_test.go @@ -0,0 +1,451 @@ +package dash + +import ( + "context" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/zsiec/prism/distribution" +) + +func TestNewPuller(t *testing.T) { + t.Parallel() + + t.Run("with logger", func(t *testing.T) { + log := slog.Default() + p := NewPuller(log) + if p == nil { + t.Fatal("NewPuller returned nil") + } + if p.pulls == nil { + t.Error("pulls map is nil") + } + if p.client == nil { + t.Error("client is nil") + } + }) + + t.Run("nil logger uses default", func(t *testing.T) { + p := NewPuller(nil) + if p == nil { + t.Fatal("NewPuller returned nil") + } + if p.log == nil { + t.Error("log is nil") + } + }) +} + +func TestPullerValidation(t *testing.T) { + t.Parallel() + + p := NewPuller(slog.Default()) + + t.Run("empty URL", func(t *testing.T) { + err := p.Pull(context.Background(), PullRequest{ + StreamKey: "test", + }, nil, nil) + if err == nil { + t.Error("expected error for empty URL") + } + }) + + t.Run("empty StreamKey", func(t *testing.T) { + err := p.Pull(context.Background(), PullRequest{ + URL: "http://example.com/manifest.mpd", + }, nil, nil) + if err == nil { + t.Error("expected error for empty StreamKey") + } + }) +} + 
+func TestPullerDuplicateStreamKey(t *testing.T) { + t.Parallel() + + p := NewPuller(slog.Default()) + + // Insert a fake active pull directly into the map. + p.mu.Lock() + p.pulls["test-stream"] = &activePull{ + req: PullRequest{URL: "http://example.com/dash.mpd", StreamKey: "test-stream"}, + cancel: func() {}, + } + p.mu.Unlock() + + err := p.Pull(context.Background(), PullRequest{ + URL: "http://example.com/dash.mpd", + StreamKey: "test-stream", + }, nil, nil) + if err == nil { + t.Error("expected error for duplicate stream key") + } +} + +func TestPullerStopNonexistent(t *testing.T) { + t.Parallel() + + p := NewPuller(slog.Default()) + err := p.Stop("nonexistent") + if err == nil { + t.Error("expected error for nonexistent stream key") + } +} + +func TestPullerActivePulls(t *testing.T) { + t.Parallel() + + p := NewPuller(slog.Default()) + + // Empty initially. + pulls := p.ActivePulls() + if len(pulls) != 0 { + t.Errorf("expected 0 active pulls, got %d", len(pulls)) + } + + // Insert fake pulls. + p.mu.Lock() + p.pulls["stream-a"] = &activePull{ + req: PullRequest{URL: "http://a.com/dash.mpd", StreamKey: "stream-a"}, + cancel: func() {}, + } + p.pulls["stream-b"] = &activePull{ + req: PullRequest{URL: "http://b.com/dash.mpd", StreamKey: "stream-b"}, + cancel: func() {}, + } + p.mu.Unlock() + + pulls = p.ActivePulls() + if len(pulls) != 2 { + t.Errorf("expected 2 active pulls, got %d", len(pulls)) + } + + // Verify both keys are present. 
+ keys := map[string]bool{} + for _, pr := range pulls { + keys[pr.StreamKey] = true + } + if !keys["stream-a"] || !keys["stream-b"] { + t.Errorf("unexpected keys: %v", keys) + } +} + +func TestPullerStop(t *testing.T) { + t.Parallel() + + p := NewPuller(slog.Default()) + + cancelled := false + p.mu.Lock() + p.pulls["test-stream"] = &activePull{ + req: PullRequest{URL: "http://example.com/dash.mpd", StreamKey: "test-stream"}, + cancel: func() { cancelled = true }, + } + p.mu.Unlock() + + err := p.Stop("test-stream") + if err != nil { + t.Fatalf("Stop returned error: %v", err) + } + if !cancelled { + t.Error("cancel function was not called") + } +} + +func TestDeriveBaseURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mpdURL string + mpdBaseURL string + want string + }{ + { + name: "MPD has BaseURL", + mpdURL: "http://cdn.example.com/live/manifest.mpd", + mpdBaseURL: "http://cdn.example.com/live/", + want: "http://cdn.example.com/live/", + }, + { + name: "derive from MPD URL", + mpdURL: "http://cdn.example.com/live/manifest.mpd", + mpdBaseURL: "", + want: "http://cdn.example.com/live/", + }, + { + name: "MPD URL with deeper path", + mpdURL: "http://cdn.example.com/a/b/c/dash.mpd", + mpdBaseURL: "", + want: "http://cdn.example.com/a/b/c/", + }, + { + name: "MPD URL at root", + mpdURL: "http://cdn.example.com/manifest.mpd", + mpdBaseURL: "", + want: "http://cdn.example.com/", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := deriveBaseURL(tt.mpdURL, tt.mpdBaseURL) + if got != tt.want { + t.Errorf("deriveBaseURL(%q, %q) = %q, want %q", tt.mpdURL, tt.mpdBaseURL, got, tt.want) + } + }) + } +} + +func TestPullerFetchURL(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/ok": + w.WriteHeader(http.StatusOK) + w.Write([]byte("hello")) + case "/error": + w.WriteHeader(http.StatusInternalServerError) + default: 
+ w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + p := NewPuller(slog.Default()) + + t.Run("success", func(t *testing.T) { + data, err := p.fetchURL(context.Background(), srv.URL+"/ok") + if err != nil { + t.Fatalf("fetchURL: %v", err) + } + if string(data) != "hello" { + t.Errorf("body = %q, want %q", string(data), "hello") + } + }) + + t.Run("server error", func(t *testing.T) { + _, err := p.fetchURL(context.Background(), srv.URL+"/error") + if err == nil { + t.Error("expected error for 500 status") + } + }) + + t.Run("cancelled context", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := p.fetchURL(ctx, srv.URL+"/ok") + if err == nil { + t.Error("expected error for cancelled context") + } + }) +} + +// mockDistributionServer implements DistributionServer for testing. +type mockDistributionServer struct { + registered map[string]*distribution.Relay + unregistered []string + pipelines map[string]distribution.StatsProvider +} + +func newMockDistributionServer() *mockDistributionServer { + return &mockDistributionServer{ + registered: make(map[string]*distribution.Relay), + pipelines: make(map[string]distribution.StatsProvider), + } +} + +func (m *mockDistributionServer) RegisterStream(key string) *distribution.Relay { + r := distribution.NewRelay() + m.registered[key] = r + return r +} + +func (m *mockDistributionServer) UnregisterStream(key string) { + m.unregistered = append(m.unregistered, key) + delete(m.registered, key) +} + +func (m *mockDistributionServer) SetPipeline(key string, p distribution.StatsProvider) { + m.pipelines[key] = p +} + +// mockStreamManager implements StreamManager for testing. 
+type mockStreamManager struct { + created []string + removed []string +} + +func (m *mockStreamManager) Create(key string) (interface{}, bool) { + m.created = append(m.created, key) + return nil, true +} + +func (m *mockStreamManager) Remove(key string) { + m.removed = append(m.removed, key) +} + +func TestPullerPullStaticMPDRejected(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(testMPDStatic)) + })) + defer srv.Close() + + p := NewPuller(slog.Default()) + distSrv := newMockDistributionServer() + + err := p.Pull(context.Background(), PullRequest{ + URL: srv.URL + "/manifest.mpd", + StreamKey: "test-static", + }, distSrv, nil) + + if err == nil { + t.Fatal("expected error for static MPD") + } + + // No stream should have been registered. + if len(distSrv.registered) != 0 { + t.Errorf("expected 0 registered streams, got %d", len(distSrv.registered)) + } +} + +func TestPullerPullBadURL(t *testing.T) { + t.Parallel() + + p := NewPuller(slog.Default()) + distSrv := newMockDistributionServer() + + err := p.Pull(context.Background(), PullRequest{ + URL: "http://localhost:1/nonexistent.mpd", + StreamKey: "test-bad", + }, distSrv, nil) + + if err == nil { + t.Fatal("expected error for unreachable URL") + } +} + +func TestPullerPullWithHTTP(t *testing.T) { + t.Parallel() + + // Serve a dynamic MPD. The pull loop will try to fetch init segments, + // which will fail, causing the loop to exit. We verify the lifecycle: + // stream registered, then cleaned up after the loop exits. + requestCount := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/manifest.mpd": + w.WriteHeader(http.StatusOK) + w.Write([]byte(testMPDDynamic)) + requestCount++ + default: + // Init/media segments: return 404 so the loop exits. 
+ w.WriteHeader(http.StatusNotFound) + } + })) + defer srv.Close() + + p := NewPuller(slog.Default()) + distSrv := newMockDistributionServer() + mgr := &mockStreamManager{} + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err := p.Pull(ctx, PullRequest{ + URL: srv.URL + "/manifest.mpd", + StreamKey: "test-live", + }, distSrv, mgr) + + if err != nil { + t.Fatalf("Pull returned error: %v", err) + } + + // The pull should be registered immediately. + pulls := p.ActivePulls() + if len(pulls) != 1 { + t.Fatalf("expected 1 active pull, got %d", len(pulls)) + } + if pulls[0].StreamKey != "test-live" { + t.Errorf("stream key = %q, want %q", pulls[0].StreamKey, "test-live") + } + + // The background goroutine will fail on init segment fetch (404) and + // clean up. Wait briefly for that. + time.Sleep(500 * time.Millisecond) + + // After loop exits, pull should be removed from active list. + pulls = p.ActivePulls() + if len(pulls) != 0 { + t.Errorf("expected 0 active pulls after loop exit, got %d", len(pulls)) + } + + // Distribution server should have been unregistered. + if len(distSrv.unregistered) == 0 { + t.Error("expected stream to be unregistered after loop exit") + } + + // Stream manager should have been cleaned up. + if len(mgr.removed) == 0 { + t.Error("expected stream manager Remove to be called") + } +} + +func TestPullerPullAndStop(t *testing.T) { + t.Parallel() + + // Serve a dynamic MPD and valid-ish init segments that won't parse + // as fMP4 — the loop will exit on parse error. Instead, we'll use + // a server that delays responses so we can test Stop(). + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.URL.Path == "/manifest.mpd": + w.WriteHeader(http.StatusOK) + w.Write([]byte(testMPDDynamic)) + default: + // Block until the request context is cancelled (simulating + // a slow CDN so we can test Stop cancellation). 
+ <-r.Context().Done() + w.WriteHeader(http.StatusServiceUnavailable) + } + })) + defer srv.Close() + + p := NewPuller(slog.Default()) + distSrv := newMockDistributionServer() + + ctx := context.Background() + + err := p.Pull(ctx, PullRequest{ + URL: srv.URL + "/manifest.mpd", + StreamKey: "stop-test", + }, distSrv, nil) + if err != nil { + t.Fatalf("Pull returned error: %v", err) + } + + // Verify active. + if len(p.ActivePulls()) != 1 { + t.Fatalf("expected 1 active pull, got %d", len(p.ActivePulls())) + } + + // Stop the pull. + err = p.Stop("stop-test") + if err != nil { + t.Fatalf("Stop returned error: %v", err) + } + + // Wait for cleanup. + time.Sleep(500 * time.Millisecond) + + if len(p.ActivePulls()) != 0 { + t.Errorf("expected 0 active pulls after Stop, got %d", len(p.ActivePulls())) + } +} From b667dc899d7e53a85ab1ae8d97d7c511f4c039a1 Mon Sep 17 00:00:00 2001 From: Thomas Symborski Date: Wed, 18 Mar 2026 16:52:27 -0400 Subject: [PATCH 10/14] feat: add DASH pull REST API endpoints and wire puller into server Add /api/dash-pull endpoints (GET, POST, DELETE, OPTIONS) mirroring the existing SRT pull API pattern. Wire the DASH puller into cmd/prism/main.go with DASHPull/DASHStop/DASHList callbacks, a streamManagerAdapter to bridge the interface mismatch, and DASH_SOURCE env var for startup pulls. 
--- cmd/prism/main.go | 64 ++++++++++++++++++++++++++-- distribution/server.go | 24 +++++++++++ distribution/server_dash_handlers.go | 63 +++++++++++++++++++++++++++ distribution/server_test.go | 55 ++++++++++++++++++++++++ 4 files changed, 202 insertions(+), 4 deletions(-) create mode 100644 distribution/server_dash_handlers.go diff --git a/cmd/prism/main.go b/cmd/prism/main.go index e570090..78ebeb5 100644 --- a/cmd/prism/main.go +++ b/cmd/prism/main.go @@ -19,6 +19,7 @@ import ( "github.com/zsiec/prism/certs" "github.com/zsiec/prism/distribution" "github.com/zsiec/prism/ingest" + dashingest "github.com/zsiec/prism/ingest/dash" srtingest "github.com/zsiec/prism/ingest/srt" "github.com/zsiec/prism/pipeline" "github.com/zsiec/prism/stream" @@ -80,6 +81,7 @@ func main() { a.handleNewStream(ctx, key, input, format) }) a.srtCaller = srtingest.NewCaller(a.registry, nil) + a.dashPuller = dashingest.NewPuller(nil) var distErr error a.distSrv, distErr = distribution.NewServer(distribution.ServerConfig{ @@ -97,6 +99,18 @@ func main() { return a.srtCaller.Stop(streamKey) }, SRTList: a.listSRTPulls, + DASHPull: func(url, streamKey, videoRepID, audioRepID string) error { + return a.dashPuller.Pull(ctx, dashingest.PullRequest{ + URL: url, + StreamKey: streamKey, + VideoRepID: videoRepID, + AudioRepID: audioRepID, + }, a.distSrv, &streamManagerAdapter{a.mgr}) + }, + DASHStop: func(streamKey string) error { + return a.dashPuller.Stop(streamKey) + }, + DASHList: a.listDASHPulls, StreamLister: a.listStreams, IngestLookup: a.lookupIngest, }) @@ -138,6 +152,18 @@ func main() { return a.distSrv.Start(ctx) }) + if dashURL := os.Getenv("DASH_SOURCE"); dashURL != "" { + g.Go(func() error { + <-time.After(500 * time.Millisecond) + return a.dashPuller.Pull(ctx, dashingest.PullRequest{ + URL: dashURL, + StreamKey: envOr("DASH_STREAM_KEY", "dash-live"), + VideoRepID: os.Getenv("DASH_VIDEO_REP"), + AudioRepID: os.Getenv("DASH_AUDIO_REP"), + }, a.distSrv, &streamManagerAdapter{a.mgr}) + }) + 
} + if err := g.Wait(); err != nil { slog.Error("server error", "error", err) os.Exit(1) @@ -145,10 +171,11 @@ func main() { } type app struct { - mgr *stream.Manager - registry *ingest.Registry - srtCaller *srtingest.Caller - distSrv *distribution.Server + mgr *stream.Manager + registry *ingest.Registry + srtCaller *srtingest.Caller + dashPuller *dashingest.Puller + distSrv *distribution.Server } func (a *app) listSRTPulls() []distribution.SRTPullInfo { @@ -164,6 +191,20 @@ func (a *app) listSRTPulls() []distribution.SRTPullInfo { return out } +func (a *app) listDASHPulls() []distribution.DASHPullInfo { + pulls := a.dashPuller.ActivePulls() + out := make([]distribution.DASHPullInfo, len(pulls)) + for i, p := range pulls { + out[i] = distribution.DASHPullInfo{ + URL: p.URL, + StreamKey: p.StreamKey, + VideoRepID: p.VideoRepID, + AudioRepID: p.AudioRepID, + } + } + return out +} + func (a *app) listStreams() []distribution.StreamInfo { streams := a.mgr.List() infos := make([]distribution.StreamInfo, len(streams)) @@ -281,3 +322,18 @@ func buildStreamDescription(info distribution.StreamInfo) string { return strings.Join(parts, " · ") } + +// streamManagerAdapter adapts *stream.Manager to the dash.StreamManager +// interface, which uses interface{} in the Create return to avoid an import +// cycle with the stream package. +type streamManagerAdapter struct { + mgr *stream.Manager +} + +func (a *streamManagerAdapter) Create(key string) (interface{}, bool) { + return a.mgr.Create(key) +} + +func (a *streamManagerAdapter) Remove(key string) { + a.mgr.Remove(key) +} diff --git a/distribution/server.go b/distribution/server.go index 6c59cba..442a882 100644 --- a/distribution/server.go +++ b/distribution/server.go @@ -104,6 +104,23 @@ type SRTPullInfo struct { StreamID string `json:"streamId,omitempty"` } +// DASHPullFunc initiates a DASH pull from a remote MPD URL. 
+type DASHPullFunc func(url, streamKey, videoRepID, audioRepID string) error + +// DASHStopFunc stops an active DASH pull by stream key. +type DASHStopFunc func(streamKey string) error + +// DASHListFunc returns all active DASH pulls. +type DASHListFunc func() []DASHPullInfo + +// DASHPullInfo describes an active DASH pull. +type DASHPullInfo struct { + URL string `json:"url"` + StreamKey string `json:"streamKey"` + VideoRepID string `json:"videoRepresentationId,omitempty"` + AudioRepID string `json:"audioRepresentationId,omitempty"` +} + // WebTransport session close error codes sent to clients via CloseWithError. const ( wtErrStreamNotFound webtransport.SessionErrorCode = 1 @@ -131,6 +148,9 @@ type ServerConfig struct { SRTPull SRTPullFunc SRTStop SRTStopFunc SRTList SRTListFunc + DASHPull DASHPullFunc + DASHStop DASHStopFunc + DASHList DASHListFunc ExtraRoutes func(mux *http.ServeMux) // OnStreamRegistered is called after a new stream relay is created @@ -274,6 +294,10 @@ func (s *Server) registerAPIRoutes(mux *http.ServeMux) { mux.HandleFunc("POST /api/srt-pull", s.handleSRTPullCreate) mux.HandleFunc("DELETE /api/srt-pull", s.handleSRTPullStop) mux.HandleFunc("OPTIONS /api/srt-pull", s.handleSRTPullOptions) + mux.HandleFunc("GET /api/dash-pull", s.handleDASHPullList) + mux.HandleFunc("POST /api/dash-pull", s.handleDASHPullCreate) + mux.HandleFunc("DELETE /api/dash-pull", s.handleDASHPullStop) + mux.HandleFunc("OPTIONS /api/dash-pull", s.handleDASHPullOptions) if s.config.ExtraRoutes != nil { s.config.ExtraRoutes(mux) diff --git a/distribution/server_dash_handlers.go b/distribution/server_dash_handlers.go new file mode 100644 index 0000000..2c7edd1 --- /dev/null +++ b/distribution/server_dash_handlers.go @@ -0,0 +1,63 @@ +package distribution + +import ( + "encoding/json" + "net/http" +) + +func (s *Server) handleDASHPullList(w http.ResponseWriter, _ *http.Request) { + if s.config.DASHList == nil { + writeJSON(w, http.StatusOK, []DASHPullInfo{}) + return + } + 
writeJSON(w, http.StatusOK, s.config.DASHList()) +} + +func (s *Server) handleDASHPullCreate(w http.ResponseWriter, r *http.Request) { + if s.config.DASHPull == nil { + writeError(w, http.StatusNotImplemented, "DASH pull not configured") + return + } + var req struct { + URL string `json:"url"` + StreamKey string `json:"streamKey"` + VideoRepID string `json:"videoRepresentationId,omitempty"` + AudioRepID string `json:"audioRepresentationId,omitempty"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + if req.URL == "" || req.StreamKey == "" { + writeError(w, http.StatusBadRequest, "url and streamKey are required") + return + } + if err := s.config.DASHPull(req.URL, req.StreamKey, req.VideoRepID, req.AudioRepID); err != nil { + writeError(w, http.StatusConflict, err.Error()) + return + } + writeJSON(w, http.StatusCreated, map[string]string{"status": "pulling", "streamKey": req.StreamKey}) +} + +func (s *Server) handleDASHPullStop(w http.ResponseWriter, r *http.Request) { + if s.config.DASHStop == nil { + writeError(w, http.StatusNotImplemented, "DASH pull not configured") + return + } + streamKey := r.URL.Query().Get("streamKey") + if streamKey == "" { + writeError(w, http.StatusBadRequest, "streamKey query parameter required") + return + } + if err := s.config.DASHStop(streamKey); err != nil { + writeError(w, http.StatusNotFound, err.Error()) + return + } + writeJSON(w, http.StatusOK, map[string]string{"status": "stopped", "streamKey": streamKey}) +} + +func (s *Server) handleDASHPullOptions(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + w.WriteHeader(http.StatusNoContent) +} diff --git a/distribution/server_test.go b/distribution/server_test.go index b5469d7..8ee4f30 100644 --- a/distribution/server_test.go +++ b/distribution/server_test.go @@ 
-409,3 +409,58 @@ func TestStreamLifecycleCallbacks(t *testing.T) { // If we get here without panicking, the test passes. }) } + +func TestHandleDASHPullList(t *testing.T) { + t.Parallel() + + srv := newTestServer(t) + handler := srv.APIHandler() + + req := httptest.NewRequest("GET", "/api/dash-pull", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + // DASHList is nil, should return empty array. + body := strings.TrimSpace(rec.Body.String()) + if body != "[]" { + t.Fatalf("body = %q, want %q", body, "[]") + } +} + +func TestHandleDASHPullCreateMissingFields(t *testing.T) { + t.Parallel() + + srv := newTestServer(t) + srv.config.DASHPull = func(_, _, _, _ string) error { return nil } + handler := srv.APIHandler() + + req := httptest.NewRequest("POST", "/api/dash-pull", strings.NewReader(`{"url":""}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} + +func TestHandleDASHPullNotConfigured(t *testing.T) { + t.Parallel() + + srv := newTestServer(t) + // DASHPull is nil. 
+ handler := srv.APIHandler() + + req := httptest.NewRequest("POST", "/api/dash-pull", strings.NewReader(`{"url":"https://example.com/live.mpd","streamKey":"test"}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotImplemented { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotImplemented) + } +} From d984f1e032f0e46d2199f3dbf84e48790d82d334 Mon Sep 17 00:00:00 2001 From: Thomas Symborski Date: Wed, 18 Mar 2026 16:53:48 -0400 Subject: [PATCH 11/14] style: apply gofmt -s formatting fixes --- cmd/prism/main.go | 4 +-- demux/av1.go | 2 +- demux/av1_test.go | 80 ++++++++++++++++++++--------------------- ingest/dash/pipeline.go | 6 ++-- 4 files changed, 46 insertions(+), 46 deletions(-) diff --git a/cmd/prism/main.go b/cmd/prism/main.go index 78ebeb5..1591501 100644 --- a/cmd/prism/main.go +++ b/cmd/prism/main.go @@ -98,7 +98,7 @@ func main() { SRTStop: func(streamKey string) error { return a.srtCaller.Stop(streamKey) }, - SRTList: a.listSRTPulls, + SRTList: a.listSRTPulls, DASHPull: func(url, streamKey, videoRepID, audioRepID string) error { return a.dashPuller.Pull(ctx, dashingest.PullRequest{ URL: url, @@ -110,7 +110,7 @@ func main() { DASHStop: func(streamKey string) error { return a.dashPuller.Stop(streamKey) }, - DASHList: a.listDASHPulls, + DASHList: a.listDASHPulls, StreamLister: a.listStreams, IngestLookup: a.lookupIngest, }) diff --git a/demux/av1.go b/demux/av1.go index 225dff5..7374e13 100644 --- a/demux/av1.go +++ b/demux/av1.go @@ -391,7 +391,7 @@ func parseColorConfig(br *av1BitReader, hdr *AV1SequenceHeader) { transferCharacteristics = br.readBits(8) matrixCoefficients = br.readBits(8) } else { - colorPrimaries = 2 // CP_UNSPECIFIED + colorPrimaries = 2 // CP_UNSPECIFIED transferCharacteristics = 2 // TC_UNSPECIFIED matrixCoefficients = 2 // MC_UNSPECIFIED } diff --git a/demux/av1_test.go b/demux/av1_test.go index f041512..c4652e6 100644 --- 
a/demux/av1_test.go +++ b/demux/av1_test.go @@ -427,10 +427,10 @@ func TestParseOBUHeader(t *testing.T) { wantErr: true, }, { - name: "extension flag", - data: []byte{0x0e, 0x40, 0x0b, 0x00, 0x00}, - wantType: OBUSequenceHeader, - wantHasExt: true, + name: "extension flag", + data: []byte{0x0e, 0x40, 0x0b, 0x00, 0x00}, + wantType: OBUSequenceHeader, + wantHasExt: true, wantHasSize: true, }, } @@ -601,7 +601,7 @@ func TestReducedStillPictureHeader(t *testing.T) { // Wrap in OBU header: type=1, has_size=1 obu := make([]byte, 0, 2+len(payload)) - obu = append(obu, 0x0a) // OBU header: type=1, has_size=1 + obu = append(obu, 0x0a) // OBU header: type=1, has_size=1 obu = append(obu, byte(len(payload))) // LEB128 size obu = append(obu, payload...) @@ -732,41 +732,41 @@ func TestParseSequenceHeaderWithDecoderModelInfo(t *testing.T) { } } - appendBits(0, 3) // seq_profile = 0 - appendBits(0, 1) // still_picture = 0 - appendBits(0, 1) // reduced_still_picture_header = 0 - appendBits(1, 1) // timing_info_present_flag = 1 - appendBits(1, 32) // num_units_in_display_tick = 1 + appendBits(0, 3) // seq_profile = 0 + appendBits(0, 1) // still_picture = 0 + appendBits(0, 1) // reduced_still_picture_header = 0 + appendBits(1, 1) // timing_info_present_flag = 1 + appendBits(1, 32) // num_units_in_display_tick = 1 appendBits(30, 32) // time_scale = 30 - appendBits(0, 1) // equal_picture_interval = 0 - appendBits(1, 1) // decoder_model_info_present_flag = 1 - appendBits(4, 5) // buffer_delay_length_minus_1 = 4 - appendBits(1, 32) // num_units_in_decoding_tick = 1 - appendBits(0, 5) // buffer_removal_time_length_minus_1 = 0 - appendBits(0, 5) // frame_presentation_time_length_minus_1 = 0 - appendBits(0, 1) // initial_display_delay_present_flag = 0 - appendBits(0, 5) // operating_points_cnt_minus_1 = 0 - appendBits(0, 12) // operating_point_idc[0] = 0 - appendBits(8, 5) // seq_level_idx[0] = 8 - appendBits(0, 1) // seq_tier[0] = 0 (read because level > 7) - appendBits(1, 1) // 
decoder_model_present_for_this_op[0] = 1 - appendBits(0, 5) // decoder_buffer_delay[0] = 0 (5 bits) - appendBits(0, 5) // encoder_buffer_delay[0] = 0 (5 bits) - appendBits(0, 1) // low_delay_mode_flag[0] = 0 - appendBits(3, 4) // frame_width_bits_minus_1 = 3 - appendBits(3, 4) // frame_height_bits_minus_1 = 3 - appendBits(15, 4) // max_frame_width_minus_1 = 15 (width=16) - appendBits(15, 4) // max_frame_height_minus_1 = 15 (height=16) - appendBits(0, 1) // frame_id_numbers_present_flag = 0 - appendBits(0, 1) // use_128x128_superblock = 0 - appendBits(0, 1) // enable_filter_intra = 0 - appendBits(0, 1) // enable_intra_edge_filter = 0 - appendBits(0, 1) // enable_interintra_compound = 0 - appendBits(0, 1) // enable_masked_compound = 0 - appendBits(0, 1) // enable_warped_motion = 0 - appendBits(0, 1) // enable_dual_filter = 0 - appendBits(0, 1) // enable_order_hint = 0 - appendBits(1, 1) // seq_choose_screen_content_tools = 1 (SELECT) + appendBits(0, 1) // equal_picture_interval = 0 + appendBits(1, 1) // decoder_model_info_present_flag = 1 + appendBits(4, 5) // buffer_delay_length_minus_1 = 4 + appendBits(1, 32) // num_units_in_decoding_tick = 1 + appendBits(0, 5) // buffer_removal_time_length_minus_1 = 0 + appendBits(0, 5) // frame_presentation_time_length_minus_1 = 0 + appendBits(0, 1) // initial_display_delay_present_flag = 0 + appendBits(0, 5) // operating_points_cnt_minus_1 = 0 + appendBits(0, 12) // operating_point_idc[0] = 0 + appendBits(8, 5) // seq_level_idx[0] = 8 + appendBits(0, 1) // seq_tier[0] = 0 (read because level > 7) + appendBits(1, 1) // decoder_model_present_for_this_op[0] = 1 + appendBits(0, 5) // decoder_buffer_delay[0] = 0 (5 bits) + appendBits(0, 5) // encoder_buffer_delay[0] = 0 (5 bits) + appendBits(0, 1) // low_delay_mode_flag[0] = 0 + appendBits(3, 4) // frame_width_bits_minus_1 = 3 + appendBits(3, 4) // frame_height_bits_minus_1 = 3 + appendBits(15, 4) // max_frame_width_minus_1 = 15 (width=16) + appendBits(15, 4) // 
max_frame_height_minus_1 = 15 (height=16) + appendBits(0, 1) // frame_id_numbers_present_flag = 0 + appendBits(0, 1) // use_128x128_superblock = 0 + appendBits(0, 1) // enable_filter_intra = 0 + appendBits(0, 1) // enable_intra_edge_filter = 0 + appendBits(0, 1) // enable_interintra_compound = 0 + appendBits(0, 1) // enable_masked_compound = 0 + appendBits(0, 1) // enable_warped_motion = 0 + appendBits(0, 1) // enable_dual_filter = 0 + appendBits(0, 1) // enable_order_hint = 0 + appendBits(1, 1) // seq_choose_screen_content_tools = 1 (SELECT) // seqForceScreenContentTools = 2 (SELECT), which is > 0 appendBits(1, 1) // seq_choose_integer_mv = 1 (SELECT) appendBits(0, 1) // enable_superres = 0 @@ -789,7 +789,7 @@ func TestParseSequenceHeaderWithDecoderModelInfo(t *testing.T) { // Wrap in OBU header obu := make([]byte, 0, 2+len(payload)) - obu = append(obu, 0x0a) // OBU header: type=1, has_size=1 + obu = append(obu, 0x0a) // OBU header: type=1, has_size=1 obu = append(obu, byte(len(payload))) // LEB128 size obu = append(obu, payload...) diff --git a/ingest/dash/pipeline.go b/ingest/dash/pipeline.go index 2292049..4e5aa00 100644 --- a/ingest/dash/pipeline.go +++ b/ingest/dash/pipeline.go @@ -99,9 +99,9 @@ func (p *dashPipeline) processAudioSamples(samples []mediaSample) { // and REST API. 
func (p *dashPipeline) StreamSnapshot() distribution.StreamSnapshot { return distribution.StreamSnapshot{ - Timestamp: time.Now().UnixMilli(), - UptimeMs: time.Since(p.startTime).Milliseconds(), - Protocol: "DASH", + Timestamp: time.Now().UnixMilli(), + UptimeMs: time.Since(p.startTime).Milliseconds(), + Protocol: "DASH", Video: distribution.VideoStats{ Codec: "AV1", TotalFrames: p.videoForwarded.Load(), From be61dd6893bc0d2b33bf251596fde66dccbf9014 Mon Sep 17 00:00:00 2001 From: Thomas Symborski Date: Wed, 18 Mar 2026 17:01:25 -0400 Subject: [PATCH 12/14] fix: handle ffmpeg DASH output format in MPD parser - Support contentType attr on AdaptationSet (ffmpeg puts mimeType on Representation, not AdaptationSet) - Support $Number%05d$ printf-style zero-padded segment numbers in SegmentTemplate media patterns - Fall back to codec-prefix heuristic when neither mimeType nor contentType is available --- ingest/dash/mpd.go | 51 ++++++++++++++++++++++++++++++----- ingest/dash/mpd_test.go | 60 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 6 deletions(-) diff --git a/ingest/dash/mpd.go b/ingest/dash/mpd.go index 395c68f..4316699 100644 --- a/ingest/dash/mpd.go +++ b/ingest/dash/mpd.go @@ -2,6 +2,7 @@ package dash import ( "fmt" + "regexp" "strconv" "strings" "time" @@ -82,8 +83,10 @@ func parseMPD(data []byte) (*mpdInfo, error) { if as == nil { continue } + // Determine content type from AdaptationSet mimeType, contentType, or Representation codec prefix + contentType := classifyAdaptationSet(as) ai := adaptationInfo{ - MimeType: as.MimeType, + MimeType: contentType, } for _, rep := range as.Representations { ri := representationInfo{} @@ -117,9 +120,9 @@ func parseMPD(data []byte) (*mpdInfo, error) { ai.Representations = append(ai.Representations, ri) } - if strings.HasPrefix(as.MimeType, "video") { + if strings.HasPrefix(contentType, "video") { info.VideoAdaptations = append(info.VideoAdaptations, ai) - } else if strings.HasPrefix(as.MimeType, "audio") { + }
else if strings.HasPrefix(contentType, "audio") { info.AudioAdaptations = append(info.AudioAdaptations, ai) } } @@ -192,11 +195,21 @@ func resolveInitURL(tmpl segmentTemplate, repID, baseURL string) string { return baseURL + s } -// resolveMediaURL replaces $RepresentationID$ and $Number$ in the media -// pattern and prepends the base URL. +// numberPattern matches $Number$ or $Number%0Nd$ (printf-style zero-padded). +var numberPattern = regexp.MustCompile(`\$Number(%0(\d+)d)?\$`) + +// resolveMediaURL replaces $RepresentationID$ and $Number$/$Number%05d$ in +// the media pattern and prepends the base URL. func resolveMediaURL(tmpl segmentTemplate, repID string, number int, baseURL string) string { s := strings.ReplaceAll(tmpl.MediaPattern, "$RepresentationID$", repID) - s = strings.ReplaceAll(s, "$Number$", strconv.Itoa(number)) + s = numberPattern.ReplaceAllStringFunc(s, func(match string) string { + sub := numberPattern.FindStringSubmatch(match) + if sub[2] != "" { + width, _ := strconv.Atoi(sub[2]) + return fmt.Sprintf("%0*d", width, number) + } + return strconv.Itoa(number) + }) return baseURL + s } @@ -221,3 +234,29 @@ func computeSegmentNumber(now time.Time, ast time.Time, tmpl segmentTemplate) in } return tmpl.StartNumber + liveSegment } + +// classifyAdaptationSet determines if an AdaptationSet is "video" or "audio" +// by checking (in order): mimeType attr, contentType attr, codec prefix heuristic. +func classifyAdaptationSet(as *mpd.AdaptationSet) string { + // 1. AdaptationSet mimeType (e.g., "video/mp4") + if as.MimeType != "" { + return as.MimeType + } + // 2. AdaptationSet contentType (e.g., "video", "audio") + if as.ContentType != nil && *as.ContentType != "" { + return *as.ContentType + } + // 3. 
Infer from codec string on first Representation + for _, rep := range as.Representations { + if rep.Codecs != nil { + c := *rep.Codecs + if strings.HasPrefix(c, "av01") || strings.HasPrefix(c, "avc") || strings.HasPrefix(c, "hev") || strings.HasPrefix(c, "hvc") || strings.HasPrefix(c, "vp0") { + return "video" + } + if strings.HasPrefix(c, "mp4a") || strings.HasPrefix(c, "opus") { + return "audio" + } + } + } + return "" +} diff --git a/ingest/dash/mpd_test.go b/ingest/dash/mpd_test.go index 4cd9c49..460e718 100644 --- a/ingest/dash/mpd_test.go +++ b/ingest/dash/mpd_test.go @@ -253,3 +253,63 @@ func TestParseMPDInvalid(t *testing.T) { t.Error("expected error for invalid XML") } } + +func TestResolveMediaURLZeroPadded(t *testing.T) { + t.Parallel() + // ffmpeg uses $Number%05d$ for zero-padded segment numbers + tmpl := segmentTemplate{ + MediaPattern: "chunk-stream$RepresentationID$-$Number%05d$.m4s", + } + got := resolveMediaURL(tmpl, "0", 3, "http://localhost:8080/") + want := "http://localhost:8080/chunk-stream0-00003.m4s" + if got != want { + t.Errorf("resolveMediaURL zero-padded = %q, want %q", got, want) + } +} + +// Test MPD with contentType on AdaptationSet and mimeType on Representation +// (ffmpeg pattern) +const testMPDContentType = ` + + + + + + + + + + + + + +` + +func TestParseMPDContentTypeAttr(t *testing.T) { + t.Parallel() + info, err := parseMPD([]byte(testMPDContentType)) + if err != nil { + t.Fatalf("parseMPD: %v", err) + } + if !info.IsDynamic { + t.Error("expected dynamic MPD") + } + if len(info.VideoAdaptations) != 1 { + t.Fatalf("video adaptations = %d, want 1", len(info.VideoAdaptations)) + } + if len(info.AudioAdaptations) != 1 { + t.Fatalf("audio adaptations = %d, want 1", len(info.AudioAdaptations)) + } + videoRep := info.VideoAdaptations[0].Representations[0] + if videoRep.Codecs != "av01.0.05M.08" { + t.Errorf("video codecs = %q, want av01.0.05M.08", videoRep.Codecs) + } +} From 93e4aebcc80e511642c52fc75e6b760b72ccc812 Mon Sep 17 
00:00:00 2001 From: Thomas Symborski Date: Wed, 18 Mar 2026 17:06:39 -0400 Subject: [PATCH 13/14] fix: build proper AV1 codec string and decoder config in DASH puller SetVideoInfo was passing "av1" instead of the RFC 6381 codec string (e.g., "av01.0.05M.08") and omitting the AV1CodecConfigurationRecord. This caused the browser's WebCodecs VideoDecoder to reject the config with "Unknown or ambiguous codec name". Now parses the sequence header OBU from the init segment to generate the correct codec string and decoder config for the MoQ catalog. --- ingest/dash/puller.go | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/ingest/dash/puller.go b/ingest/dash/puller.go index 08bbc51..1944fce 100644 --- a/ingest/dash/puller.go +++ b/ingest/dash/puller.go @@ -11,7 +11,9 @@ import ( "sync" "time" + "github.com/zsiec/prism/demux" "github.com/zsiec/prism/distribution" + "github.com/zsiec/prism/moq" ) // PullRequest describes a DASH source to pull from. @@ -281,12 +283,21 @@ func (p *Puller) runPullLoop( mgr.Create(req.StreamKey) } - // Set video info on relay. - relay.SetVideoInfo(distribution.VideoInfo{ + // Set video info on relay with proper RFC 6381 codec string and decoder config. + vi := distribution.VideoInfo{ Codec: initInfo.VideoCodec, Width: initInfo.Width, Height: initInfo.Height, - }) + } + if initInfo.SeqHeaderOBU != nil { + if hdr, err := demux.ParseAV1SequenceHeader(initInfo.SeqHeaderOBU); err == nil { + vi.Codec = hdr.CodecString() + vi.Width = hdr.Width + vi.Height = hdr.Height + } + vi.DecoderConfig = moq.BuildAV1DecoderConfig(initInfo.SeqHeaderOBU) + } + relay.SetVideoInfo(vi) // Set audio info if available. if hasAudio { From bd0a2629e011e375cc5107c2bf0536f85bac5b1a Mon Sep 17 00:00:00 2001 From: Thomas Symborski Date: Wed, 18 Mar 2026 17:11:18 -0400 Subject: [PATCH 14/14] fix: pace DASH segment frames at real-time cadence DASH segments arrive as bursts (2s of frames at once). 
Without pacing, the player sees burst-pause-burst-pause jitter. Now processVideoSamples and processAudioSamples sleep between frames based on PTS deltas, and run concurrently so video and audio interleave naturally. --- ingest/dash/pipeline.go | 37 ++++++++++++++++++++++++++++++++++++- ingest/dash/puller.go | 25 +++++++++++++++---------- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/ingest/dash/pipeline.go b/ingest/dash/pipeline.go index 4e5aa00..33a8275 100644 --- a/ingest/dash/pipeline.go +++ b/ingest/dash/pipeline.go @@ -55,10 +55,30 @@ func newDASHPipeline(streamKey string, relay broadcaster, seqHdrOBU []byte) *das // broadcasts them through the relay. Keyframes start a new group and carry // the AV1 sequence header OBU as SPS so that late-joining decoders can // configure themselves. +// +// Frames are paced based on PTS deltas to avoid bursting an entire segment +// worth of frames at once (which causes jittery playback at the viewer). func (p *dashPipeline) processVideoSamples(samples []mediaSample) { + if len(samples) == 0 { + return + } + + basePTS := samples[0].PTS + wallStart := time.Now() + for i := range samples { s := &samples[i] + // Pace: sleep until this frame's presentation time relative to the + // first frame in the batch, so frames arrive at ~real-time cadence. + if i > 0 { + ptsDelta := time.Duration(s.PTS-basePTS) * time.Microsecond + wallElapsed := time.Since(wallStart) + if sleep := ptsDelta - wallElapsed; sleep > 0 { + time.Sleep(sleep) + } + } + frame := &media.VideoFrame{ PTS: s.PTS, DTS: s.DTS, @@ -79,11 +99,26 @@ func (p *dashPipeline) processVideoSamples(samples []mediaSample) { } // processAudioSamples converts audio mediaSamples into AudioFrames and -// broadcasts them through the relay. +// broadcasts them through the relay. Paced by PTS like video. 
func (p *dashPipeline) processAudioSamples(samples []mediaSample) { + if len(samples) == 0 { + return + } + + basePTS := samples[0].PTS + wallStart := time.Now() + for i := range samples { s := &samples[i] + if i > 0 { + ptsDelta := time.Duration(s.PTS-basePTS) * time.Microsecond + wallElapsed := time.Since(wallStart) + if sleep := ptsDelta - wallElapsed; sleep > 0 { + time.Sleep(sleep) + } + } + frame := &media.AudioFrame{ PTS: s.PTS, Data: s.Data, diff --git a/ingest/dash/puller.go b/ingest/dash/puller.go index 1944fce..83133b4 100644 --- a/ingest/dash/puller.go +++ b/ingest/dash/puller.go @@ -358,9 +358,8 @@ func (p *Puller) runPullLoop( p.log.Warn("parse video segment", "seg", segNum, "error", err) continue } - pipeline.processVideoSamples(videoSamples) - - // Fetch audio segment. + // Fetch and parse the audio segment before pacing starts, so video and audio can be paced together below. + var audioSamples []mediaSample if hasAudio { audioSegNum := computeSegmentNumber(time.Now(), info.AvailabilityStartTime, audioTmpl) audioMediaURL := resolveMediaURL(audioTmpl, audioRep.ID, audioSegNum, baseURL) @@ -370,15 +369,21 @@ func (p *Puller) runPullLoop( return } p.log.Warn("fetch audio segment", "seg", audioSegNum, "url", audioMediaURL, "error", err) - continue + } else { + audioSamples, err = parseMediaSegment(audioData, initInfo, false) + if err != nil { + p.log.Warn("parse audio segment", "seg", audioSegNum, "error", err) + } } + } - audioSamples, err := parseMediaSegment(audioData, initInfo, false) - if err != nil { - p.log.Warn("parse audio segment", "seg", audioSegNum, "error", err) - continue - } + // Pace video and audio concurrently so they interleave naturally. + done := make(chan struct{}) + go func() { pipeline.processAudioSamples(audioSamples) - } + close(done) + }() + pipeline.processVideoSamples(videoSamples) + <-done } }