From 14f92f12201ba23294fa7b827838ff89369e6ac4 Mon Sep 17 00:00:00 2001 From: Derrick Chambers Date: Wed, 8 Apr 2026 13:40:51 +0200 Subject: [PATCH 1/5] add support for sintela protobuf format --- dascore/data_registry.txt | 1 + dascore/io/sintela_protobuf/__init__.py | 9 + dascore/io/sintela_protobuf/core.py | 41 + dascore/io/sintela_protobuf/utils.py | 1014 +++++++++++++++++ pyproject.toml | 1 + tests/test_io/test_common_io.py | 2 + tests/test_io/test_remote_memory.py | 30 +- .../test_sintela_protobuf.py | 788 +++++++++++++ 8 files changed, 1876 insertions(+), 10 deletions(-) create mode 100644 dascore/io/sintela_protobuf/__init__.py create mode 100644 dascore/io/sintela_protobuf/core.py create mode 100644 dascore/io/sintela_protobuf/utils.py create mode 100644 tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py diff --git a/dascore/data_registry.txt b/dascore/data_registry.txt index caf9c249..44e13b49 100644 --- a/dascore/data_registry.txt +++ b/dascore/data_registry.txt @@ -40,3 +40,4 @@ febus_2.h5 c118960a94e37fbff0eb5c33856d34cdfe81609902c4feaedab9949498d31c23 http febg1_C1_2023-05-10T12.25.03+0000.bsl e1a8ff72f3ec1805129267df916f41419bf7fa3a4993602e2b85e721cad922ae https://github.com/dasdae/test_data/raw/master/dss/febg1_C1_2023-05-10T12.25.03+0000.bsl febg1_C1_2023-05-10T12.27.33+0000.bsl 233df0c184796944442ae19beddcf962aba1c9fab337fd1c74971f0c4d513a36 https://github.com/dasdae/test_data/raw/master/dss/febg1_C1_2023-05-10T12.27.33+0000.bsl xdas_netcdf.nc 9e53fa1ce8395fedbb195048b3eb2832b87cb6883867cdff30be93078e8027f7 https://github.com/dasdae/test_data/raw/master/das/xdas_netcdf.nc +sintela_protobuf_1.pb cafa038c033cb5e4cf4aa90bab203eb1e10dc05a6052c0576dc80d6a5b3b42dc https://github.com/dasdae/test_data/raw/master/das/sintela_protobuf_1.pb diff --git a/dascore/io/sintela_protobuf/__init__.py b/dascore/io/sintela_protobuf/__init__.py new file mode 100644 index 00000000..6f8a42d1 --- /dev/null +++ b/dascore/io/sintela_protobuf/__init__.py @@ -0,0 +1,9 @@ +""" +Sintela protobuf reader. + +This module supports Sintela's MTLV-wrapped protobuf recordings. Format +detection only inspects the MTLV envelope and does not require protobuf to be +installed. Reading and scanning lazily import protobuf support when needed. +""" + +from .core import SintelaProtobufV1 diff --git a/dascore/io/sintela_protobuf/core.py b/dascore/io/sintela_protobuf/core.py new file mode 100644 index 00000000..326efcae --- /dev/null +++ b/dascore/io/sintela_protobuf/core.py @@ -0,0 +1,41 @@ +""" +Core module for reading Sintela protobuf format. +""" + +from __future__ import annotations + +import numpy as np + +import dascore as dc +from dascore.core.summary import PatchSummary +from dascore.io import FiberIO +from dascore.utils.io import BinaryReader + +from .utils import get_supported_family_tag, read_payload, scan_payload + + +class SintelaProtobufV1(FiberIO): + """IO class for Sintela protobuf MTLV recordings.""" + + name = "Sintela_Protobuf" + preferred_extensions = ("pb",) + version = "1" + + def get_format(self, resource: BinaryReader, **kwargs) -> tuple[str, str] | bool: + """Return the format/version tuple if the file is Sintela protobuf.""" + tag = get_supported_family_tag(resource) + return (self.name, self.version) if tag else False + + def scan(self, resource: BinaryReader, **kwargs) -> list[PatchSummary]: + """Scan a Sintela protobuf recording.""" + return scan_payload(resource) + + def read(self, resource: BinaryReader, **kwargs) -> dc.BaseSpool: + """Read a Sintela protobuf recording into a spool.""" + data, coords, attrs = read_payload(resource) + selectors = {name: kwargs[name] for name in coords.dims if name in kwargs} + if selectors: + coords, data = coords.select(data, **selectors) + if not np.size(data): + return dc.spool([]) + return dc.spool([dc.Patch(data=data, coords=coords, attrs=attrs)]) diff --git a/dascore/io/sintela_protobuf/utils.py b/dascore/io/sintela_protobuf/utils.py new file mode 100644 index 00000000..9cf4c640 --- /dev/null +++ b/dascore/io/sintela_protobuf/utils.py @@ -0,0 +1,1014 @@ +""" +Utilities for reading Sintela protobuf MTLV recordings. +""" + +from __future__ import annotations + +import struct +from dataclasses import dataclass +from functools import cache +from typing import Any + +import numpy as np + +import dascore as dc +from dascore.constants import VALID_DATA_TYPES +from dascore.core.attrs import PatchAttrs +from dascore.core.coordmanager import get_coord_manager +from dascore.core.coords import get_coord +from dascore.core.summary import PatchSummary +from dascore.exceptions import InvalidFiberFileError, MissingOptionalDependencyError +from dascore.utils.misc import suppress_warnings + +PBUF_MAGIC = 0x46554250 +META_TAG = "META" +TS_TAGS = frozenset({"TS05", "RF01"}) +FFT_TAGS = frozenset({"FFT", "FFT-"}) +BAND_TAGS = frozenset({"BAND"}) +DIMS_TS = ("time", "distance") +DIMS_BAND = ("time", "distance", "band") +DIMS_FFT = ("time", "distance", "frequency") + +_TIMESERIES_DATA_TYPE_MAP = { + 0: ("phase", "radians"), + 1: ("phase", "radians"), + 2: ("phase_difference", "radians"), + 3: ("phase_rate", "radians/s"), + 4: ("strain", "microstrain"), + 5: ("strain_rate", "microstrain/s"), +} +_BAND_DATA_TYPE_MAP = { + 10: ("temperature", ""), + 13: ("phase", "radians"), +} + + +class SintelaProtobufAttrs(PatchAttrs): + """Patch attributes for Sintela protobuf recordings.""" + + gauge_length: float = np.nan + gauge_length_units: str = "m" + packet_type: str = "" + recorder_namespace: str = "" + metadata_recording_time: np.datetime64 | None = None + instrument_manufacturer: str = "" + instrument_model: str = "" + fiber_id: int | None = None + serial_number: str = "" + start_channel: int | None = None + channel_spacing: float = np.nan + channel_spacing_units: str = "m" + channel_step: int | None = None + sample_rate: float = np.nan + demod_data_type: str = "" + + +@dataclass(frozen=True) +class EnvelopeRecord: + """The envelope information for one MTLV record.""" + + tag: str + payload: bytes + + +@dataclass(frozen=True) +class ParsedMeta: + """Selected metadata fields promoted from META packets.""" + + recorder_namespace: str = "" + metadata_recording_time: np.datetime64 | None = None + instrument_manufacturer: str = "" + instrument_model: str = "" + serial_number: str = "" + fiber_id: int | None = None + + +def _timestamp_to_dt64(timestamp) -> np.datetime64 | None: + """Convert a protobuf timestamp into datetime64[ns].""" + seconds = int(getattr(timestamp, "seconds", 0)) + nanos = int(getattr(timestamp, "nanos", 0)) + return np.datetime64(seconds, "s") + np.timedelta64(nanos, "ns") + + +def _iter_envelope_records(resource, *, strict: bool) -> list[EnvelopeRecord]: + """Read all MTLV envelope records from a binary stream.""" + resource.seek(0) + out: list[EnvelopeRecord] = [] + while True: + magic = resource.read(4) + if not magic: + break + if len(magic) < 4: + if strict: + raise InvalidFiberFileError("Truncated Sintela protobuf magic header.") + return [] + if struct.unpack(" str | None: + """Return the first supported data tag in a file without using protobuf.""" + for record in _iter_envelope_records(resource, strict=False): + if record.tag == META_TAG: + continue + if record.tag in TS_TAGS | BAND_TAGS | FFT_TAGS: + return record.tag + return None + return None + + +def _optional_dependency_error() -> MissingOptionalDependencyError: + """Return the standardized missing dependency error.""" + msg = ( + "protobuf is not installed but is required for Sintela protobuf scan/read " + "operations." + ) + return MissingOptionalDependencyError(msg) + + +@cache +def _get_proto_messages(): + """Build lightweight protobuf messages for supported Sintela packet types.""" + try: + from google.protobuf import descriptor_pb2, descriptor_pool, message_factory + from google.protobuf import timestamp_pb2 + except Exception as exc: # pragma: no cover - import failure path + raise _optional_dependency_error() from exc + + return _build_proto_messages( + descriptor_pb2=descriptor_pb2, + descriptor_pool=descriptor_pool, + message_factory=message_factory, + timestamp_pb2=timestamp_pb2, + include_sample_fields=True, + package_name="sintela_common", + file_name="sintela_common_lite.proto", + ) + + +@cache +def _get_scan_proto_messages(): + """Build scan-only protobuf messages which omit sample payload fields.""" + try: + from google.protobuf import descriptor_pb2, descriptor_pool, message_factory + from google.protobuf import timestamp_pb2 + except Exception as exc: # pragma: no cover - import failure path + raise _optional_dependency_error() from exc + + return _build_proto_messages( + descriptor_pb2=descriptor_pb2, + descriptor_pool=descriptor_pool, + message_factory=message_factory, + timestamp_pb2=timestamp_pb2, + include_sample_fields=False, + package_name="sintela_common_scan", + file_name="sintela_common_scan.proto", + ) + + +def _build_proto_messages( + *, + descriptor_pb2, + descriptor_pool, + message_factory, + timestamp_pb2, + include_sample_fields: bool, + package_name: str, + file_name: str, +): + """Build lightweight protobuf message classes for data packets.""" + + file_proto = descriptor_pb2.FileDescriptorProto() + file_proto.name = file_name + file_proto.package = package_name + file_proto.dependency.append("google/protobuf/timestamp.proto") + + def add_field(message, name, number, type_, *, label=None, type_name=""): + field = message.field.add() + field.name = name + field.number = number + field.label = ( + label + if label is not None + else descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + ) + field.type = type_ + if type_name: + field.type_name = type_name + return field + + common = file_proto.message_type.add() + common.name = "CommonHeader" + add_field( + common, + "time", + 1, + descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, + type_name=".google.protobuf.Timestamp", + ) + for number, name, type_ in ( + (2, "num_channels", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), + (3, "sample_rate", descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT), + (4, "channel_spacing", descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT), + (5, "gauge_length", descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT), + (6, "start_channel", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), + (7, "end_of_replay", descriptor_pb2.FieldDescriptorProto.TYPE_BOOL), + (8, "fiber_flipped", descriptor_pb2.FieldDescriptorProto.TYPE_BOOL), + (9, "loop_removed", descriptor_pb2.FieldDescriptorProto.TYPE_BOOL), + (10, "has_dropped_samples", descriptor_pb2.FieldDescriptorProto.TYPE_BOOL), + (11, "timeseries_data_type", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), + (12, "demod_data_type", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), + ): + add_field(common, name, number, type_) + + timeseries_header = file_proto.message_type.add() + timeseries_header.name = "TimeseriesHeader" + add_field( + timeseries_header, + "common_header", + 1, + descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, + type_name=f".{package_name}.CommonHeader", + ) + add_field( + timeseries_header, "sample_count", 2, descriptor_pb2.FieldDescriptorProto.TYPE_UINT32 + ) + add_field( + timeseries_header, "num_samples", 3, descriptor_pb2.FieldDescriptorProto.TYPE_INT32 + ) + add_field( + timeseries_header, "channel_step", 4, descriptor_pb2.FieldDescriptorProto.TYPE_INT32 + ) + + timeseries_packet = file_proto.message_type.add() + timeseries_packet.name = "TimeseriesPacket" + add_field( + timeseries_packet, + "header", + 1, + descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, + type_name=f".{package_name}.TimeseriesHeader", + ) + if include_sample_fields: + add_field( + timeseries_packet, + "samples", + 3, + descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT, + label=descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED, + ) + add_field( + timeseries_packet, + "raw_frames", + 4, + descriptor_pb2.FieldDescriptorProto.TYPE_BYTES, + ) + + band_info = file_proto.message_type.add() + band_info.name = "BandDataInfo" + for number, name, type_ in ( + (1, "band_data_type", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), + (2, "start", descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT), + (3, "end", descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT), + (4, "averaging_type", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), + (5, "description", descriptor_pb2.FieldDescriptorProto.TYPE_STRING), + (6, "source", descriptor_pb2.FieldDescriptorProto.TYPE_STRING), + ): + add_field(band_info, name, number, type_) + + band_header = file_proto.message_type.add() + band_header.name = "BandHeader" + add_field( + band_header, + "common_header", + 1, + descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, + type_name=f".{package_name}.CommonHeader", + ) + add_field( + band_header, + "band_data_info", + 2, + descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, + label=descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED, + type_name=f".{package_name}.BandDataInfo", + ) + + band_packet = file_proto.message_type.add() + band_packet.name = "BandPacket" + add_field( + band_packet, + "header", + 1, + descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, + type_name=f".{package_name}.BandHeader", + ) + if include_sample_fields: + add_field( + band_packet, + "samples", + 2, + descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT, + label=descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED, + ) + + fft_header = file_proto.message_type.add() + fft_header.name = "FFTHeader" + add_field( + fft_header, + "common_header", + 1, + descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, + type_name=f".{package_name}.CommonHeader", + ) + for number, name, type_ in ( + (2, "num_bins", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), + (3, "bin_res", descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT), + (4, "averaging_type", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), + (5, "channel_step", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), + (6, "normalised", descriptor_pb2.FieldDescriptorProto.TYPE_BOOL), + (7, "has_power_data", descriptor_pb2.FieldDescriptorProto.TYPE_BOOL), + (8, "has_complex_data", descriptor_pb2.FieldDescriptorProto.TYPE_BOOL), + ): + add_field(fft_header, name, number, type_) + + fft_packet = file_proto.message_type.add() + fft_packet.name = "FFTPacket" + add_field( + fft_packet, + "header", + 1, + descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, + type_name=f".{package_name}.FFTHeader", + ) + if include_sample_fields: + add_field( + fft_packet, + "samples", + 2, + descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT, + label=descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED, + ) + + pool = descriptor_pool.DescriptorPool() + pool.AddSerializedFile(timestamp_pb2.DESCRIPTOR.serialized_pb) + pool.Add(file_proto) + out = {} + for name in ("TimeseriesPacket", "BandPacket", "FFTPacket"): + descriptor = pool.FindMessageTypeByName(f"{package_name}.{name}") + out[name] = message_factory.GetMessageClass(descriptor) + return out + + +@cache +def _get_meta_message_class(): + """Build a lightweight RecordingMetadata parser for selected fields.""" + try: + from google.protobuf import descriptor_pb2, descriptor_pool, message_factory + from google.protobuf import timestamp_pb2 + except Exception as exc: # pragma: no cover - import failure path + raise _optional_dependency_error() from exc + + file_proto = descriptor_pb2.FileDescriptorProto() + file_proto.name = "sintela_meta_lite.proto" + file_proto.package = "sintela_meta" + file_proto.dependency.append("google/protobuf/timestamp.proto") + + identification = file_proto.message_type.add() + identification.name = "IdentificationResponse" + for number, name in ( + (1, "manufacturer"), + (2, "system_type"), + (3, "model"), + (4, "serial_number"), + ): + field = identification.field.add() + field.name = name + field.number = number + field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + field.type = descriptor_pb2.FieldDescriptorProto.TYPE_STRING + + acquisition = file_proto.message_type.add() + acquisition.name = "AcquisitionStatsResponse" + field = acquisition.field.add() + field.name = "fiber_id" + field.number = 8 + field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + field.type = descriptor_pb2.FieldDescriptorProto.TYPE_INT32 + + recording = file_proto.message_type.add() + recording.name = "RecordingMetadata" + fields = ( + ("recorder_namespace", 1, descriptor_pb2.FieldDescriptorProto.TYPE_STRING, ""), + ( + "metadata_recording_time", + 2, + descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, + ".google.protobuf.Timestamp", + ), + ( + "identification", + 3, + descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, + ".sintela_meta.IdentificationResponse", + ), + ( + "acquisition_stats", + 7, + descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, + ".sintela_meta.AcquisitionStatsResponse", + ), + ) + for name, number, type_, type_name in fields: + field = recording.field.add() + field.name = name + field.number = number + field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + field.type = type_ + if type_name: + field.type_name = type_name + + pool = descriptor_pool.DescriptorPool() + pool.AddSerializedFile(timestamp_pb2.DESCRIPTOR.serialized_pb) + pool.Add(file_proto) + descriptor = pool.FindMessageTypeByName("sintela_meta.RecordingMetadata") + return message_factory.GetMessageClass(descriptor) + + +def _parse_meta(payload: bytes) -> ParsedMeta: + """Parse selected fields from a META payload.""" + message_cls = _get_meta_message_class() + msg = message_cls() + with suppress_warnings(): + msg.ParseFromString(payload) + identification = msg.identification if msg.HasField("identification") else None + acquisition = msg.acquisition_stats if msg.HasField("acquisition_stats") else None + return ParsedMeta( + recorder_namespace=str(getattr(msg, "recorder_namespace", "") or ""), + metadata_recording_time=( + _timestamp_to_dt64(msg.metadata_recording_time) + if msg.HasField("metadata_recording_time") + else None + ), + instrument_manufacturer=str( + getattr(identification, "manufacturer", "") or "" + ), + instrument_model=str(getattr(identification, "model", "") or ""), + serial_number=str(getattr(identification, "serial_number", "") or ""), + fiber_id=( + int(getattr(acquisition, "fiber_id")) + if acquisition is not None and getattr(acquisition, "fiber_id", None) is not None + else None + ), + ) + + +def _common_header_time(common_header) -> np.datetime64 | None: + """Return a common-header timestamp when present.""" + return ( + _timestamp_to_dt64(common_header.time) + if common_header.HasField("time") + else None + ) + + +def _parse_records( + records: list[EnvelopeRecord], *, scan_mode: bool = False +) -> tuple[list[Any], ParsedMeta]: + """Decode protobuf payloads and return messages plus selected META.""" + messages = _get_scan_proto_messages() if scan_mode else _get_proto_messages() + parsed: list[Any] = [] + meta = ParsedMeta() + for record in records: + tag = record.tag + if tag == META_TAG: + meta = _parse_meta(record.payload) + continue + if tag in TS_TAGS: + msg = messages["TimeseriesPacket"]() + elif tag in BAND_TAGS: + msg = messages["BandPacket"]() + elif tag in FFT_TAGS: + msg = messages["FFTPacket"]() + else: + raise InvalidFiberFileError(f"Unsupported Sintela protobuf tag {tag!r}.") + msg.ParseFromString(record.payload) + parsed.append((tag, msg)) + if not parsed: + raise InvalidFiberFileError("No supported Sintela protobuf data packets found.") + return parsed, meta + + +def _get_time_coord_from_samples(start: np.datetime64, sample_rate: float, size: int): + """Build a regularly sampled time coordinate.""" + step = dc.to_timedelta64(1 / sample_rate) + return get_coord(start=start, stop=start + step * size, step=step, units="s") + + +def _get_distance_coord(start_channel: int, spacing: float, count: int, step: int = 1): + """Build the distance coordinate.""" + start = start_channel * spacing + return get_coord( + start=start, + stop=start + spacing * step * count, + step=spacing * step, + units="m", + ) + + +def _get_times(times: list[np.datetime64]): + """Build a time coordinate from packet timestamps.""" + return get_coord(data=np.asarray(times, dtype="datetime64[ns]"), units="s") + + +def _normalize_data_type(candidate: str) -> str: + """Map any external string to a DASCore-valid data type.""" + clean = str(candidate or "").lower().strip() + return clean if clean in VALID_DATA_TYPES else "" + + +def _assert_float_equal(name: str, values: list[float], *, rtol: float = 1e-6): + """Ensure float values match within a small tolerance.""" + first = values[0] + for value in values[1:]: + if not np.isclose(first, value, rtol=rtol, atol=0.0): + raise InvalidFiberFileError( + f"Inconsistent {name} across Sintela protobuf packets." + ) + return first + + +def _base_attrs(common_header, packet_type: str, meta: ParsedMeta, extra: dict | None = None): + """Construct base attrs from the packet header and META metadata.""" + data_type, data_units = _TIMESERIES_DATA_TYPE_MAP.get( + int(getattr(common_header, "timeseries_data_type", 0)), + ("phase", "radians"), + ) + attrs = dict( + data_type=_normalize_data_type(data_type), + data_category="DAS", + data_units=data_units, + gauge_length=float(getattr(common_header, "gauge_length", np.nan)), + packet_type=packet_type, + recorder_namespace=meta.recorder_namespace, + metadata_recording_time=meta.metadata_recording_time, + instrument_manufacturer=meta.instrument_manufacturer, + instrument_model=meta.instrument_model, + instrument_id=meta.serial_number, + serial_number=meta.serial_number, + fiber_id=meta.fiber_id, + start_channel=int(getattr(common_header, "start_channel", 0)), + channel_spacing=float(getattr(common_header, "channel_spacing", np.nan)), + channel_step=None, + sample_rate=float(getattr(common_header, "sample_rate", np.nan)), + ) + if extra: + attrs.update(extra) + return SintelaProtobufAttrs(**attrs) + + +def _assert_equal(name: str, values: list[Any]): + """Ensure all values in a list are equal.""" + first = values[0] + for value in values[1:]: + if value != first: + raise InvalidFiberFileError( + f"Inconsistent {name} across Sintela protobuf packets." + ) + return first + + +def _validate_single_family(parsed: list[tuple[str, Any]]) -> str: + """Ensure a file only contains one data packet family.""" + families = { + "timeseries" if tag in TS_TAGS else "band" if tag in BAND_TAGS else "fft" + for tag, _ in parsed + } + if len(families) != 1: + raise InvalidFiberFileError("Mixed Sintela protobuf packet families are unsupported.") + return families.pop() + + +def _decode_family(parsed: list[tuple[str, Any]], meta: ParsedMeta): + """Decode one parsed data family into data, coords, and attrs.""" + family = _validate_single_family(parsed) + if family == "timeseries": + return _decode_timeseries(parsed, meta) + if family == "band": + return _decode_band(parsed, meta) + return _decode_fft(parsed, meta) + + +def _decode_timeseries(parsed: list[tuple[str, Any]], meta: ParsedMeta): + """Decode timeseries packets into data, coords, and attrs.""" + headers = [msg.header for _tag, msg in parsed] + common_headers = [header.common_header for header in headers] + num_channels = _assert_equal( + "num_channels", [int(ch.num_channels) for ch in common_headers] + ) + sample_rate = _assert_float_equal( + "sample_rate", [float(ch.sample_rate) for ch in common_headers] + ) + channel_spacing = _assert_float_equal( + "channel_spacing", [float(ch.channel_spacing) for ch in common_headers] + ) + gauge_length = _assert_float_equal( + "gauge_length", [float(ch.gauge_length) for ch in common_headers] + ) + start_channel = _assert_equal( + "start_channel", [int(ch.start_channel) for ch in common_headers] + ) + channel_step = _assert_equal("channel_step", [int(h.channel_step) for h in headers]) + data_type = _assert_equal( + "timeseries_data_type", + [int(ch.timeseries_data_type) for ch in common_headers], + ) + demod_data_type = _assert_equal( + "demod_data_type", [int(ch.demod_data_type) for ch in common_headers] + ) + for ch in common_headers: + if ch.has_dropped_samples: + raise InvalidFiberFileError("Dropped samples in Sintela protobuf stream.") + sample_counts = [int(h.sample_count) for h in headers] + num_samples_per_packet = [int(h.num_samples) for h in headers] + for current, nxt, count in zip( + sample_counts, sample_counts[1:], num_samples_per_packet[:-1], strict=False + ): + if current + count != nxt: + raise InvalidFiberFileError("Non-contiguous Sintela protobuf sample counts.") + total_samples = sum(num_samples_per_packet) + first_time = _common_header_time(common_headers[0]) + if first_time is None: + raise InvalidFiberFileError("Missing Sintela protobuf start time.") + data = np.empty((total_samples, num_channels), dtype=np.float32) + index = 0 + for _tag, msg in parsed: + packet = np.asarray(msg.samples, dtype=np.float32) + rows = int(msg.header.num_samples) + expected = rows * num_channels + if packet.size != expected: + raise InvalidFiberFileError("Unexpected Sintela protobuf TS sample payload size.") + data[index : index + rows] = packet.reshape(rows, num_channels) + index += rows + time = _get_time_coord_from_samples(first_time, sample_rate, total_samples) + distance = _get_distance_coord(start_channel, channel_spacing, num_channels, channel_step) + coords = get_coord_manager({"time": time, "distance": distance}, dims=DIMS_TS) + attrs = _base_attrs( + common_headers[0], + packet_type=parsed[0][0], + meta=meta, + extra=dict( + gauge_length=gauge_length, + channel_step=channel_step, + sample_rate=sample_rate, + data_type=_TIMESERIES_DATA_TYPE_MAP.get(data_type, ("phase", "radians"))[0], + data_units=_TIMESERIES_DATA_TYPE_MAP.get(data_type, ("phase", "radians"))[1], + demod_data_type=str(demod_data_type), + ), + ) + return data, coords, attrs + + +def _decode_band(parsed: list[tuple[str, Any]], meta: ParsedMeta): + """Decode band packets into data, coords, and attrs.""" + headers = [msg.header for _tag, msg in parsed] + common_headers = [header.common_header for header in headers] + num_channels = _assert_equal( + "num_channels", [int(ch.num_channels) for ch in common_headers] + ) + channel_spacing = _assert_float_equal( + "channel_spacing", [float(ch.channel_spacing) for ch in common_headers] + ) + gauge_length = _assert_float_equal( + "gauge_length", [float(ch.gauge_length) for ch in common_headers] + ) + start_channel = _assert_equal( + "start_channel", [int(ch.start_channel) for ch in common_headers] + ) + band_defs = [] + for header in headers: + band_defs.append( + tuple( + ( + int(info.band_data_type), + float(info.start), + float(info.end), + str(info.description), + str(info.source), + ) + for info in header.band_data_info + ) + ) + band_def = _assert_equal("band_data_info", band_defs) + num_bands = len(band_def) + if not num_bands: + raise InvalidFiberFileError("Band packets missing band definitions.") + times = [_common_header_time(ch) for ch in common_headers] + if any(x is None for x in times): + raise InvalidFiberFileError("Missing time in Sintela BAND packet.") + data = np.empty((len(parsed), num_channels, num_bands), dtype=np.float32) + for ind, (_tag, msg) in enumerate(parsed): + packet = np.asarray(msg.samples, dtype=np.float32) + expected = num_channels * num_bands + if packet.size != expected: + raise InvalidFiberFileError("Unexpected Sintela protobuf BAND payload size.") + data[ind] = packet.reshape(num_channels, num_bands) + # BAND packets do not expose channel_step, so distance comes from the + # recorded start channel and spacing only. + distance = _get_distance_coord(start_channel, channel_spacing, num_channels) + band = get_coord(start=0, stop=num_bands, step=1) + coords = get_coord_manager( + { + "time": _get_times(times), + "distance": distance, + "band": band, + "band_start_frequency": ("band", np.asarray([x[1] for x in band_def])), + "band_end_frequency": ("band", np.asarray([x[2] for x in band_def])), + "band_description": ("band", np.asarray([x[3] for x in band_def], dtype=object)), + "band_source": ("band", np.asarray([x[4] for x in band_def], dtype=object)), + }, + dims=DIMS_BAND, + ) + first_type = int(band_def[0][0]) + data_type, data_units = _BAND_DATA_TYPE_MAP.get(first_type, ("", "")) + attrs = _base_attrs( + common_headers[0], + packet_type=parsed[0][0], + meta=meta, + extra=dict( + gauge_length=gauge_length, + data_type=data_type, + data_units=data_units, + ), + ) + return data, coords, attrs + + +def _decode_fft(parsed: list[tuple[str, Any]], meta: ParsedMeta): + """Decode FFT packets into data, coords, and attrs.""" + headers = [msg.header for _tag, msg in parsed] + common_headers = [header.common_header for header in headers] + num_channels = _assert_equal( + "num_channels", [int(ch.num_channels) for ch in common_headers] + ) + channel_spacing = _assert_float_equal( + "channel_spacing", [float(ch.channel_spacing) for ch in common_headers] + ) + gauge_length = _assert_float_equal( + "gauge_length", [float(ch.gauge_length) for ch in common_headers] + ) + start_channel = _assert_equal( + "start_channel", [int(ch.start_channel) for ch in common_headers] + ) + num_bins = _assert_equal("num_bins", [int(h.num_bins) for h in headers]) + bin_res = _assert_float_equal("bin_res", [float(h.bin_res) for h in headers]) + has_complex = _assert_equal("has_complex_data", [bool(h.has_complex_data) for h in headers]) + channel_step = _assert_equal("channel_step", [int(h.channel_step) for h in headers]) + times = [_common_header_time(ch) for ch in common_headers] + if any(x is None for x in times): + raise InvalidFiberFileError("Missing time in Sintela FFT packet.") + dtype = np.complex64 if has_complex else np.float32 + data = np.empty((len(parsed), num_channels, num_bins), dtype=dtype) + for ind, (_tag, msg) in enumerate(parsed): + packet = np.asarray(msg.samples, dtype=np.float32) + if has_complex: + expected = num_channels * num_bins * 2 + if packet.size != expected: + raise InvalidFiberFileError("Unexpected Sintela protobuf FFT payload size.") + packet = packet.reshape(num_channels, num_bins, 2) + packet = packet[..., 0] + 1j * packet[..., 1] + else: + expected = num_channels * num_bins + if packet.size != expected: + raise InvalidFiberFileError("Unexpected Sintela protobuf FFT payload size.") + packet = packet.reshape(num_channels, num_bins) + data[ind] = packet + distance = _get_distance_coord(start_channel, channel_spacing, num_channels, channel_step) + frequency = get_coord(start=0.0, stop=bin_res * num_bins, step=bin_res, units="Hz") + coords = get_coord_manager( + {"time": _get_times(times), "distance": distance, "frequency": frequency}, + dims=DIMS_FFT, + ) + attrs = _base_attrs( + common_headers[0], + packet_type=parsed[0][0], + meta=meta, + extra=dict( + gauge_length=gauge_length, + channel_step=channel_step, + ), + ) + return data, coords, attrs + + +def _scan_timeseries(parsed: list[tuple[str, Any]], meta: ParsedMeta): + """Summarize timeseries packets without allocating sample arrays.""" + headers = [msg.header for _tag, msg in parsed] + common_headers = [header.common_header for header in headers] + num_channels = _assert_equal( + "num_channels", [int(ch.num_channels) for ch in common_headers] + ) + sample_rate = _assert_float_equal( + "sample_rate", [float(ch.sample_rate) for ch in common_headers] + ) + channel_spacing = _assert_float_equal( + "channel_spacing", [float(ch.channel_spacing) for ch in common_headers] + ) + gauge_length = _assert_float_equal( + "gauge_length", [float(ch.gauge_length) for ch in common_headers] + ) + start_channel = _assert_equal( + "start_channel", [int(ch.start_channel) for ch in common_headers] + ) + channel_step = _assert_equal("channel_step", [int(h.channel_step) for h in headers]) + data_type = _assert_equal( + "timeseries_data_type", + [int(ch.timeseries_data_type) for ch in common_headers], + ) + demod_data_type = _assert_equal( + "demod_data_type", [int(ch.demod_data_type) for ch in common_headers] + ) + for ch in common_headers: + if ch.has_dropped_samples: + raise InvalidFiberFileError("Dropped samples in Sintela protobuf stream.") + sample_counts = [int(h.sample_count) for h in headers] + num_samples_per_packet = [int(h.num_samples) for h in headers] + for current, nxt, count in zip( + sample_counts, sample_counts[1:], num_samples_per_packet[:-1], strict=False + ): + if current + count != nxt: + raise InvalidFiberFileError("Non-contiguous Sintela protobuf sample counts.") + total_samples = sum(num_samples_per_packet) + first_time = _common_header_time(common_headers[0]) + if first_time is None: + raise InvalidFiberFileError("Missing Sintela protobuf start time.") + time = _get_time_coord_from_samples(first_time, sample_rate, total_samples) + distance = _get_distance_coord(start_channel, channel_spacing, num_channels, channel_step) + coords = get_coord_manager({"time": time, "distance": distance}, dims=DIMS_TS) + attrs = _base_attrs( + common_headers[0], + packet_type=parsed[0][0], + meta=meta, + extra=dict( + gauge_length=gauge_length, + channel_step=channel_step, + sample_rate=sample_rate, + data_type=_TIMESERIES_DATA_TYPE_MAP.get(data_type, ("phase", "radians"))[0], + data_units=_TIMESERIES_DATA_TYPE_MAP.get(data_type, ("phase", "radians"))[1], + demod_data_type=str(demod_data_type), + ), + ) + return (total_samples, num_channels), coords, attrs, str(np.dtype(np.float32)) + + +def _scan_band(parsed: list[tuple[str, Any]], meta: ParsedMeta): + """Summarize band packets without allocating sample arrays.""" + headers = [msg.header for _tag, msg in parsed] + common_headers = [header.common_header for header in headers] + num_channels = _assert_equal( + "num_channels", [int(ch.num_channels) for ch in common_headers] + ) + channel_spacing = _assert_float_equal( + "channel_spacing", [float(ch.channel_spacing) for ch in common_headers] + ) + gauge_length = _assert_float_equal( + "gauge_length", [float(ch.gauge_length) for ch in common_headers] + ) + start_channel = _assert_equal( + "start_channel", [int(ch.start_channel) for ch in common_headers] + ) + band_defs = [] + for header in headers: + band_defs.append( + tuple( + ( + int(info.band_data_type), + float(info.start), + float(info.end), + str(info.description), + str(info.source), + ) + for info in header.band_data_info + ) + ) + band_def = _assert_equal("band_data_info", band_defs) + num_bands = len(band_def) + if not num_bands: + raise InvalidFiberFileError("Band packets missing band definitions.") + times = [_common_header_time(ch) for ch in common_headers] + if any(x is None for x in times): + raise InvalidFiberFileError("Missing time in Sintela BAND packet.") + distance = _get_distance_coord(start_channel, channel_spacing, num_channels) + band = get_coord(start=0, stop=num_bands, step=1) + coords = get_coord_manager( + { + "time": _get_times(times), + "distance": distance, + "band": band, + "band_start_frequency": ("band", np.asarray([x[1] for x in band_def])), + "band_end_frequency": ("band", np.asarray([x[2] for x in band_def])), + "band_description": ("band", np.asarray([x[3] for x in band_def], dtype=object)), + "band_source": ("band", np.asarray([x[4] for x in band_def], dtype=object)), + }, + dims=DIMS_BAND, + ) + first_type = int(band_def[0][0]) + data_type, data_units = _BAND_DATA_TYPE_MAP.get(first_type, ("", "")) + attrs = _base_attrs( + common_headers[0], + packet_type=parsed[0][0], + meta=meta, + extra=dict( + gauge_length=gauge_length, + data_type=data_type, + data_units=data_units, + ), + ) + return (len(parsed), num_channels, num_bands), coords, attrs, str(np.dtype(np.float32)) + + +def _scan_fft(parsed: list[tuple[str, Any]], meta: ParsedMeta): + """Summarize FFT packets without allocating sample arrays.""" + headers = [msg.header for _tag, msg in parsed] + common_headers = [header.common_header for header in headers] + num_channels = _assert_equal( + "num_channels", [int(ch.num_channels) for ch in common_headers] + ) + channel_spacing = _assert_float_equal( + "channel_spacing", [float(ch.channel_spacing) for ch in common_headers] + ) + gauge_length = _assert_float_equal( + "gauge_length", [float(ch.gauge_length) for ch in common_headers] + ) + start_channel = _assert_equal( + "start_channel", [int(ch.start_channel) for ch in common_headers] + ) + num_bins = _assert_equal("num_bins", [int(h.num_bins) for h in headers]) + bin_res = _assert_float_equal("bin_res", [float(h.bin_res) for h in headers]) + has_complex = _assert_equal("has_complex_data", [bool(h.has_complex_data) for h in headers]) + channel_step = _assert_equal("channel_step", [int(h.channel_step) for h in headers]) + times = [_common_header_time(ch) for ch in common_headers] + if any(x is None for x in times): + raise InvalidFiberFileError("Missing time in Sintela FFT packet.") + distance = _get_distance_coord(start_channel, channel_spacing, num_channels, channel_step) + frequency = get_coord(start=0.0, stop=bin_res * num_bins, step=bin_res, units="Hz") + coords = get_coord_manager( + {"time": _get_times(times), "distance": distance, "frequency": frequency}, + dims=DIMS_FFT, + ) + attrs = _base_attrs( + common_headers[0], + packet_type=parsed[0][0], + meta=meta, + extra=dict( + gauge_length=gauge_length, + channel_step=channel_step, + ), + ) + dtype = np.complex64 if has_complex else np.float32 + return (len(parsed), num_channels, num_bins), coords, attrs, str(np.dtype(dtype)) + + +def read_payload(resource): + """Decode a Sintela protobuf file into data, coords, and attrs.""" + records = _iter_envelope_records(resource, strict=True) + parsed, meta = _parse_records(records, scan_mode=False) + return _decode_family(parsed, meta) + + +def scan_payload(resource) -> list[PatchSummary]: + """Decode a Sintela protobuf file and return PatchSummary objects.""" + records = _iter_envelope_records(resource, strict=True) + parsed, meta = _parse_records(records, scan_mode=True) + family = _validate_single_family(parsed) + if family == "timeseries": + shape, coords, attrs, dtype = _scan_timeseries(parsed, meta) + elif family == "band": + shape, coords, attrs, dtype = _scan_band(parsed, meta) + else: + shape, coords, attrs, dtype = _scan_fft(parsed, meta) + return [ + PatchSummary.model_construct( + attrs=attrs, + coords=coords.to_summary_dict(), + dims=coords.dims, + shape=shape, + dtype=dtype, + ) + ] diff --git a/pyproject.toml b/pyproject.toml index 400ad97d..e6ae867f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,6 +145,7 @@ SEGY__V2_0 = "dascore.io.segy.core:SegyV2_0" SEGY__V2_1 = "dascore.io.segy.core:SegyV2_1" SILIXA_H5__V1 = "dascore.io.silixah5:SilixaH5V1" SINTELA_BINARY__V3 = "dascore.io.sintela_binary.core:SintelaBinaryV3" +SINTELA_PROTOBUF__V1 = "dascore.io.sintela_protobuf.core:SintelaProtobufV1" RSF__V1 = "dascore.io.rsf.core:RSFV1" WAV = "dascore.io.wav.core:WavIO" XMLBINARY__V1 = "dascore.io.xml_binary.core:XMLBinaryV1" diff --git a/tests/test_io/test_common_io.py b/tests/test_io/test_common_io.py index d937c621..2113adb2 100644 --- a/tests/test_io/test_common_io.py +++ b/tests/test_io/test_common_io.py @@ -38,6 +38,7 @@ from dascore.io.sentek import SentekV5 from dascore.io.silixah5 import SilixaH5V1 from dascore.io.sintela_binary import SintelaBinaryV3 +from dascore.io.sintela_protobuf import SintelaProtobufV1 from dascore.io.tdms import TDMSFormatterV4713 from dascore.io.terra15 import ( Terra15FormatterV4, @@ -82,6 +83,7 @@ SentekV5(): ("DASDMSShot00_20230328155653619.das",), SilixaH5V1(): ("silixa_h5_1.hdf5",), SintelaBinaryV3(): ("sintela_binary_v3_test_1.raw",), + SintelaProtobufV1(): ("sintela_protobuf_1.pb",), Terra15FormatterV4(): ( "terra15_das_1_trimmed.hdf5", "terra15_das_unfinished.hdf5", diff --git a/tests/test_io/test_remote_memory.py b/tests/test_io/test_remote_memory.py index 8ca9a552..254c798b 100644 --- a/tests/test_io/test_remote_memory.py +++ b/tests/test_io/test_remote_memory.py @@ -29,6 +29,18 @@ def memory_prodml_path(): return _copy_file_to_memory(source, dest) +@pytest.fixture() +def memory_fetch_copy(): + """Copy one fetched registry file into the in-memory filesystem.""" + + def _copy(fetch_name: str, namespace: str) -> tuple[Path, UPath]: + source = Path(fetch(fetch_name)) + dest = UPath(f"memory://dascore/{namespace}/{source.name}") + return source, _copy_file_to_memory(source, dest) + + return _copy + + @pytest.fixture(autouse=True) def isolated_remote_cache(tmp_path): """Use one isolated remote cache per test to avoid cross-test cleanup cost.""" @@ -184,14 +196,14 @@ class TestMemoryRemoteMetadataAccess: ("sample_tdms_file_v4713.tdms", ("TDMS", "4713")), ("DASDMSShot00_20230328155653619.das", ("sentek", "5")), ("sintela_binary_v3_test_1.raw", ("Sintela_Binary", "3")), + ("sintela_protobuf_1.pb", ("Sintela_Protobuf", "1")), ], ) - def test_get_format_avoids_local_cache(self, fetch_name, expected): + def test_get_format_avoids_local_cache( + self, fetch_name, expected, memory_fetch_copy + ): """Remote-first get_format paths should not materialize local cache files.""" - source = Path(fetch(fetch_name)) - path = _copy_file_to_memory( - source, UPath(f"memory://dascore/meta/{source.name}") - ) + source, path = memory_fetch_copy(fetch_name, "meta") assert dc.get_format(path) == expected assert not list(get_remote_cache_path().rglob(source.name)) @@ -202,14 +214,12 @@ def test_get_format_avoids_local_cache(self, fetch_name, expected): ("sample_tdms_file_v4713.tdms", ("TDMS", "4713")), ("DASDMSShot00_20230328155653619.das", ("sentek", "5")), ("sintela_binary_v3_test_1.raw", ("Sintela_Binary", "3")), + ("sintela_protobuf_1.pb", ("Sintela_Protobuf", "1")), ], ) - def test_scan_avoids_local_cache(self, fetch_name, expected): + def test_scan_avoids_local_cache(self, fetch_name, expected, memory_fetch_copy): """Remote-first scans should not materialize local cache files.""" - source = Path(fetch(fetch_name)) - path = _copy_file_to_memory( - source, UPath(f"memory://dascore/scan/{source.name}") - ) + source, path = memory_fetch_copy(fetch_name, "scan") attrs = dc.scan(path) assert len(attrs) > 0 assert attrs[0].source_format == expected[0] diff --git a/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py b/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py new file mode 100644 index 00000000..9206a1b9 --- /dev/null +++ b/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py @@ -0,0 +1,788 @@ +""" +Tests for Sintela protobuf format. +""" + +from __future__ import annotations + +import struct +from functools import cache +from pathlib import Path + +import numpy as np +import pytest + +import dascore as dc +from dascore.exceptions import InvalidFiberFileError, MissingOptionalDependencyError +from dascore.io.sintela_protobuf import SintelaProtobufV1 +from dascore.io.sintela_protobuf import utils as sintela_utils +from dascore.utils.downloader import fetch + + +@pytest.fixture(scope="module") +def sintela_protobuf_path(): + """Return the registered Sintela protobuf test file.""" + return fetch("sintela_protobuf_1.pb") + + +@pytest.fixture() +def fiber_io(): + """Return the Sintela protobuf FiberIO instance.""" + return SintelaProtobufV1() + + +def _write_record(handle, tag: str, payload: bytes): + """Write one MTLV record to a binary handle.""" + handle.write(struct.pack(" Path: + path = tmp_path / name + _write_records(path, records) + return path + + return _write + + +class TestSintelaProtobuf: + """Tests for Sintela protobuf IO support.""" + + def test_get_format(self, fiber_io, sintela_protobuf_path): + """A registered Sintela protobuf file should be detected.""" + assert fiber_io.get_format(sintela_protobuf_path) == ( + "Sintela_Protobuf", + "1", + ) + + def test_scan_matches_read_summary(self, fiber_io, sintela_protobuf_path): + """Scan metadata should match the loaded patch summary.""" + summary = fiber_io.scan(sintela_protobuf_path)[0] + patch_summary = fiber_io.read(sintela_protobuf_path)[0].summary + assert summary == patch_summary + + def test_ts_read_promotes_selected_meta_attrs(self, fiber_io, write_sintela_file, ts_records): + """META should supplement stable provenance attrs.""" + path = write_sintela_file("ts.pb", [("META", _build_meta_payload()), *ts_records]) + patch = fiber_io.read(path)[0] + assert patch.dims == ("time", "distance") + assert patch.attrs.recorder_namespace == "manualRecord/recorder" + assert patch.attrs.instrument_id == "SN123" + assert patch.attrs.instrument_manufacturer == "Sintela" + assert patch.attrs.instrument_model == "Onyx" + assert patch.attrs.fiber_id == 2 + assert patch.attrs.data_type == "strain" + + def test_band_read_returns_expected_dims(self, fiber_io, write_sintela_file, band_records): + """Band recordings should load into a 3D patch.""" + path = write_sintela_file("band.pb", band_records) + patch = fiber_io.read(path)[0] + assert patch.dims == ("time", "distance", "band") + assert patch.shape == (2, 2, 2) + assert "band_start_frequency" in patch.coords.coord_map + + def test_fft_read_returns_expected_dims(self, fiber_io, write_sintela_file, fft_records): + """FFT recordings should load into a frequency-domain patch.""" + path = write_sintela_file("fft.pb", fft_records) + patch = fiber_io.read(path)[0] + assert patch.dims == ("time", "distance", "frequency") + assert patch.shape == (2, 2, 3) + + def test_band_scan_matches_read_summary(self, fiber_io, write_sintela_file, band_records): + """Band scan should exercise the metadata-only summary path.""" + path = write_sintela_file("band.pb", band_records) + assert fiber_io.scan(path)[0] == fiber_io.read(path)[0].summary + + def test_fft_scan_matches_read_summary( + self, fiber_io, write_sintela_file, complex_fft_records + ): + """FFT scan should exercise the metadata-only summary path.""" + path = write_sintela_file("fft.pb", complex_fft_records) + assert fiber_io.scan(path)[0] == fiber_io.read(path)[0].summary + + def test_complex_fft_is_complex_dtype(self, fiber_io, write_sintela_file, complex_fft_records): + """Complex FFT packets should decode into a complex dtype.""" + path = write_sintela_file("fft_complex.pb", complex_fft_records) + patch = fiber_io.read(path)[0] + assert np.issubdtype(patch.data.dtype, np.complexfloating) + + def test_read_applies_distance_selection(self, fiber_io, write_sintela_file, ts_records): + """Read should apply coord selectors through the core wrapper.""" + path = write_sintela_file("ts_select.pb", ts_records) + patch = fiber_io.read(path, distance=(30, 31))[0] + assert patch.shape == (6, 1) + assert patch.coords.dims == ("time", "distance") + + def test_read_returns_empty_spool_for_empty_selection( + self, fiber_io, write_sintela_file, ts_records + ): + """Selectors that remove all samples should return an empty spool.""" + path = write_sintela_file("ts_empty_select.pb", ts_records) + spool = fiber_io.read(path, distance=(999, 1000)) + assert len(spool) == 0 + + def test_mixed_families_raise(self, fiber_io, write_sintela_file, ts_records, band_records): + """Mixed data packet families are not supported.""" + records = [("META", _build_meta_payload()), *ts_records, band_records[0]] + path = write_sintela_file("mixed.pb", records) + with pytest.raises(InvalidFiberFileError, match="Mixed Sintela protobuf"): + fiber_io.scan(path) + + def test_non_contiguous_timeseries_raises(self, fiber_io, write_sintela_file, ts_records): + """Timeseries packets with gaps or reordering should fail.""" + records = _mutate_record( + ts_records, 1, "TimeseriesPacket", lambda msg: setattr(msg.header, "sample_count", 10) + ) + path = write_sintela_file("non_contiguous.pb", records) + with pytest.raises(InvalidFiberFileError, match="Non-contiguous"): + fiber_io.scan(path) + + def test_bad_magic_returns_false(self, fiber_io, tmp_path): + """Invalid magic bytes should not identify as the format.""" + path = tmp_path / "bad.pb" + path.write_bytes(b"NOPE") + assert not fiber_io.get_format(path) + + def test_scan_rejects_invalid_or_truncated_envelope_headers(self, fiber_io, tmp_path): + """Strict scan mode should raise for malformed envelope headers.""" + bad_magic = tmp_path / "bad_magic.pb" + bad_magic.write_bytes(b"NOPE") + with pytest.raises(InvalidFiberFileError, match="Invalid Sintela protobuf magic"): + fiber_io.scan(bad_magic) + + short_magic = tmp_path / "short_magic.pb" + short_magic.write_bytes(b"\x01\x02") + with pytest.raises(InvalidFiberFileError, match="Truncated Sintela protobuf magic"): + fiber_io.scan(short_magic) + + short_header = tmp_path / "short_header.pb" + with short_header.open("wb") as handle: + handle.write(struct.pack(" Date: Thu, 9 Apr 2026 10:34:32 +0200 Subject: [PATCH 2/5] try fix failures --- dascore/io/sintela_protobuf/core.py | 5 +- dascore/io/sintela_protobuf/utils.py | 513 +++++++++-------- pyproject.toml | 3 +- tests/test_io/_common_io_test_utils.py | 10 +- tests/test_io/test_remote_common_io.py | 9 +- tests/test_io/test_remote_http.py | 1 + tests/test_io/test_remote_memory.py | 2 - .../test_sintela_protobuf.py | 515 +++++++++++------- 8 files changed, 606 insertions(+), 452 deletions(-) diff --git a/dascore/io/sintela_protobuf/core.py b/dascore/io/sintela_protobuf/core.py index 326efcae..4ab4c48e 100644 --- a/dascore/io/sintela_protobuf/core.py +++ b/dascore/io/sintela_protobuf/core.py @@ -4,10 +4,11 @@ from __future__ import annotations +from typing import Any + import numpy as np import dascore as dc -from dascore.core.summary import PatchSummary from dascore.io import FiberIO from dascore.utils.io import BinaryReader @@ -26,7 +27,7 @@ def get_format(self, resource: BinaryReader, **kwargs) -> tuple[str, str] | bool tag = get_supported_family_tag(resource) return (self.name, self.version) if tag else False - def scan(self, resource: BinaryReader, **kwargs) -> list[PatchSummary]: + def scan(self, resource: BinaryReader, **kwargs) -> list[dict[str, Any]]: """Scan a Sintela protobuf recording.""" return scan_payload(resource) diff --git a/dascore/io/sintela_protobuf/utils.py b/dascore/io/sintela_protobuf/utils.py index 9cf4c640..c88f059a 100644 --- a/dascore/io/sintela_protobuf/utils.py +++ b/dascore/io/sintela_protobuf/utils.py @@ -16,9 +16,9 @@ from dascore.core.attrs import PatchAttrs from dascore.core.coordmanager import get_coord_manager from dascore.core.coords import get_coord -from dascore.core.summary import PatchSummary from dascore.exceptions import InvalidFiberFileError, MissingOptionalDependencyError -from dascore.utils.misc import suppress_warnings +from dascore.io.core import _make_scan_payload +from dascore.utils.misc import optional_import, suppress_warnings PBUF_MAGIC = 0x46554250 META_TAG = "META" @@ -30,6 +30,7 @@ DIMS_FFT = ("time", "distance", "frequency") _TIMESERIES_DATA_TYPE_MAP = { + # Sintela currently reports both enum codes as phase-like samples. 0: ("phase", "radians"), 1: ("phase", "radians"), 2: ("phase_difference", "radians"), @@ -41,6 +42,10 @@ 10: ("temperature", ""), 13: ("phase", "radians"), } +_FFT_ATTR_DEFAULTS = { + "data_type": "", + "data_units": "", +} class SintelaProtobufAttrs(PatchAttrs): @@ -83,6 +88,52 @@ class ParsedMeta: fiber_id: int | None = None +@dataclass(frozen=True) +class TimeseriesMetadata: + """Normalized metadata shared by timeseries read and scan paths.""" + + common_header: Any + packet_type: str + num_channels: int + total_samples: int + sample_rate: float + channel_spacing: float + gauge_length: float + start_channel: int + channel_step: int + coords: Any + attrs: SintelaProtobufAttrs + + +@dataclass(frozen=True) +class BandMetadata: + """Normalized metadata shared by band read and scan paths.""" + + common_header: Any + packet_type: str + num_channels: int + num_bands: int + gauge_length: float + band_def: tuple[tuple[Any, ...], ...] + coords: Any + attrs: SintelaProtobufAttrs + + +@dataclass(frozen=True) +class FFTMetadata: + """Normalized metadata shared by FFT read and scan paths.""" + + common_header: Any + packet_type: str + num_channels: int + num_bins: int + gauge_length: float + channel_step: int + has_complex: bool + coords: Any + attrs: SintelaProtobufAttrs + + def _timestamp_to_dt64(timestamp) -> np.datetime64 | None: """Convert a protobuf timestamp into datetime64[ns].""" seconds = int(getattr(timestamp, "seconds", 0)) @@ -129,7 +180,9 @@ def get_supported_family_tag(resource) -> str | None: continue if record.tag in TS_TAGS | BAND_TAGS | FFT_TAGS: return record.tag - return None + # Detection is intentionally tolerant of unknown non-data records so a + # valid family tag later in the file can still identify the format. + continue return None @@ -142,12 +195,22 @@ def _optional_dependency_error() -> MissingOptionalDependencyError: return MissingOptionalDependencyError(msg) +def _get_protobuf_decode_error(): + """Return protobuf's decode error type, or Exception as a fallback.""" + message_mod = optional_import("google.protobuf.message", on_missing="ignore") + return getattr(message_mod, "DecodeError", Exception) + + @cache def _get_proto_messages(): """Build lightweight protobuf messages for supported Sintela packet types.""" try: - from google.protobuf import descriptor_pb2, descriptor_pool, message_factory - from google.protobuf import timestamp_pb2 + from google.protobuf import ( + descriptor_pb2, + descriptor_pool, + message_factory, + timestamp_pb2, + ) except Exception as exc: # pragma: no cover - import failure path raise _optional_dependency_error() from exc @@ -166,8 +229,12 @@ def _get_proto_messages(): def _get_scan_proto_messages(): """Build scan-only protobuf messages which omit sample payload fields.""" try: - from google.protobuf import descriptor_pb2, descriptor_pool, message_factory - from google.protobuf import timestamp_pb2 + from google.protobuf import ( + descriptor_pb2, + descriptor_pool, + message_factory, + timestamp_pb2, + ) except Exception as exc: # pragma: no cover - import failure path raise _optional_dependency_error() from exc @@ -193,7 +260,6 @@ def _build_proto_messages( file_name: str, ): """Build lightweight protobuf message classes for data packets.""" - file_proto = descriptor_pb2.FileDescriptorProto() file_proto.name = file_name file_proto.package = package_name @@ -247,13 +313,22 @@ def add_field(message, name, number, type_, *, label=None, type_name=""): type_name=f".{package_name}.CommonHeader", ) add_field( - timeseries_header, "sample_count", 2, descriptor_pb2.FieldDescriptorProto.TYPE_UINT32 + timeseries_header, + "sample_count", + 2, + descriptor_pb2.FieldDescriptorProto.TYPE_UINT32, ) add_field( - timeseries_header, "num_samples", 3, descriptor_pb2.FieldDescriptorProto.TYPE_INT32 + timeseries_header, + "num_samples", + 3, + descriptor_pb2.FieldDescriptorProto.TYPE_INT32, ) add_field( - timeseries_header, "channel_step", 4, descriptor_pb2.FieldDescriptorProto.TYPE_INT32 + timeseries_header, + "channel_step", + 4, + descriptor_pb2.FieldDescriptorProto.TYPE_INT32, ) timeseries_packet = file_proto.message_type.add() @@ -380,8 +455,12 @@ def add_field(message, name, number, type_, *, label=None, type_name=""): def _get_meta_message_class(): """Build a lightweight RecordingMetadata parser for selected fields.""" try: - from google.protobuf import descriptor_pb2, descriptor_pool, message_factory - from google.protobuf import timestamp_pb2 + from google.protobuf import ( + descriptor_pb2, + descriptor_pool, + message_factory, + timestamp_pb2, + ) except Exception as exc: # pragma: no cover - import failure path raise _optional_dependency_error() from exc @@ -455,8 +534,13 @@ def _parse_meta(payload: bytes) -> ParsedMeta: """Parse selected fields from a META payload.""" message_cls = _get_meta_message_class() msg = message_cls() + decode_error = _get_protobuf_decode_error() with suppress_warnings(): - msg.ParseFromString(payload) + try: + msg.ParseFromString(payload) + except decode_error as exc: + msg = f"Failed to parse Sintela protobuf META payload: {exc}" + raise InvalidFiberFileError(msg) from exc identification = msg.identification if msg.HasField("identification") else None acquisition = msg.acquisition_stats if msg.HasField("acquisition_stats") else None return ParsedMeta( @@ -466,14 +550,13 @@ def _parse_meta(payload: bytes) -> ParsedMeta: if msg.HasField("metadata_recording_time") else None ), - instrument_manufacturer=str( - getattr(identification, "manufacturer", "") or "" - ), + instrument_manufacturer=str(getattr(identification, "manufacturer", "") or ""), instrument_model=str(getattr(identification, "model", "") or ""), serial_number=str(getattr(identification, "serial_number", "") or ""), fiber_id=( int(getattr(acquisition, "fiber_id")) - if acquisition is not None and getattr(acquisition, "fiber_id", None) is not None + if acquisition is not None + and getattr(acquisition, "fiber_id", None) is not None else None ), ) @@ -493,6 +576,7 @@ def _parse_records( ) -> tuple[list[Any], ParsedMeta]: """Decode protobuf payloads and return messages plus selected META.""" messages = _get_scan_proto_messages() if scan_mode else _get_proto_messages() + decode_error = _get_protobuf_decode_error() parsed: list[Any] = [] meta = ParsedMeta() for record in records: @@ -508,7 +592,11 @@ def _parse_records( msg = messages["FFTPacket"]() else: raise InvalidFiberFileError(f"Unsupported Sintela protobuf tag {tag!r}.") - msg.ParseFromString(record.payload) + try: + msg.ParseFromString(record.payload) + except decode_error as exc: + out = f"Failed to parse Sintela protobuf {tag} payload: {exc}" + raise InvalidFiberFileError(out) from exc parsed.append((tag, msg)) if not parsed: raise InvalidFiberFileError("No supported Sintela protobuf data packets found.") @@ -518,7 +606,7 @@ def _parse_records( def _get_time_coord_from_samples(start: np.datetime64, sample_rate: float, size: int): """Build a regularly sampled time coordinate.""" step = dc.to_timedelta64(1 / sample_rate) - return get_coord(start=start, stop=start + step * size, step=step, units="s") + return get_coord(start=start, stop=start + step * size, step=step) def _get_distance_coord(start_channel: int, spacing: float, count: int, step: int = 1): @@ -534,7 +622,7 @@ def _get_distance_coord(start_channel: int, spacing: float, count: int, step: in def _get_times(times: list[np.datetime64]): """Build a time coordinate from packet timestamps.""" - return get_coord(data=np.asarray(times, dtype="datetime64[ns]"), units="s") + return get_coord(data=np.asarray(times, dtype="datetime64[ns]")) def _normalize_data_type(candidate: str) -> str: @@ -545,6 +633,9 @@ def _normalize_data_type(candidate: str) -> str: def _assert_float_equal(name: str, values: list[float], *, rtol: float = 1e-6): """Ensure float values match within a small tolerance.""" + if not values: + msg = f"Cannot validate {name} for an empty Sintela protobuf payload." + raise InvalidFiberFileError(msg) first = values[0] for value in values[1:]: if not np.isclose(first, value, rtol=rtol, atol=0.0): @@ -554,7 +645,9 @@ def _assert_float_equal(name: str, values: list[float], *, rtol: float = 1e-6): return first -def _base_attrs(common_header, packet_type: str, meta: ParsedMeta, extra: dict | None = None): +def _base_attrs( + common_header, packet_type: str, meta: ParsedMeta, extra: dict | None = None +): """Construct base attrs from the packet header and META metadata.""" data_type, data_units = _TIMESERIES_DATA_TYPE_MAP.get( int(getattr(common_header, "timeseries_data_type", 0)), @@ -570,6 +663,8 @@ def _base_attrs(common_header, packet_type: str, meta: ParsedMeta, extra: dict | metadata_recording_time=meta.metadata_recording_time, instrument_manufacturer=meta.instrument_manufacturer, instrument_model=meta.instrument_model, + # Mirror the recorder serial into the canonical PatchAttrs field while + # preserving the raw vendor-specific name for round-tripping/debugging. instrument_id=meta.serial_number, serial_number=meta.serial_number, fiber_id=meta.fiber_id, @@ -583,8 +678,16 @@ def _base_attrs(common_header, packet_type: str, meta: ParsedMeta, extra: dict | return SintelaProtobufAttrs(**attrs) +def _normalize_demod_data_type(code: int) -> str: + """Preserve the raw demod enum until Sintela publishes stable semantics.""" + return str(code) + + def _assert_equal(name: str, values: list[Any]): """Ensure all values in a list are equal.""" + if not values: + msg = f"Cannot validate {name} for an empty Sintela protobuf payload." + raise InvalidFiberFileError(msg) first = values[0] for value in values[1:]: if value != first: @@ -601,7 +704,9 @@ def _validate_single_family(parsed: list[tuple[str, Any]]) -> str: for tag, _ in parsed } if len(families) != 1: - raise InvalidFiberFileError("Mixed Sintela protobuf packet families are unsupported.") + raise InvalidFiberFileError( + "Mixed Sintela protobuf packet families are unsupported." + ) return families.pop() @@ -615,8 +720,10 @@ def _decode_family(parsed: list[tuple[str, Any]], meta: ParsedMeta): return _decode_fft(parsed, meta) -def _decode_timeseries(parsed: list[tuple[str, Any]], meta: ParsedMeta): - """Decode timeseries packets into data, coords, and attrs.""" +def _get_timeseries_metadata( + parsed: list[tuple[str, Any]], meta: ParsedMeta +) -> TimeseriesMetadata: + """Validate timeseries headers and build shared attrs/coords.""" headers = [msg.header for _tag, msg in parsed] common_headers = [header.common_header for header in headers] num_channels = _assert_equal( @@ -651,23 +758,17 @@ def _decode_timeseries(parsed: list[tuple[str, Any]], meta: ParsedMeta): sample_counts, sample_counts[1:], num_samples_per_packet[:-1], strict=False ): if current + count != nxt: - raise InvalidFiberFileError("Non-contiguous Sintela protobuf sample counts.") + raise InvalidFiberFileError( + "Non-contiguous Sintela protobuf sample counts." + ) total_samples = sum(num_samples_per_packet) first_time = _common_header_time(common_headers[0]) if first_time is None: raise InvalidFiberFileError("Missing Sintela protobuf start time.") - data = np.empty((total_samples, num_channels), dtype=np.float32) - index = 0 - for _tag, msg in parsed: - packet = np.asarray(msg.samples, dtype=np.float32) - rows = int(msg.header.num_samples) - expected = rows * num_channels - if packet.size != expected: - raise InvalidFiberFileError("Unexpected Sintela protobuf TS sample payload size.") - data[index : index + rows] = packet.reshape(rows, num_channels) - index += rows time = _get_time_coord_from_samples(first_time, sample_rate, total_samples) - distance = _get_distance_coord(start_channel, channel_spacing, num_channels, channel_step) + distance = _get_distance_coord( + start_channel, channel_spacing, num_channels, channel_step + ) coords = get_coord_manager({"time": time, "distance": distance}, dims=DIMS_TS) attrs = _base_attrs( common_headers[0], @@ -678,15 +779,29 @@ def _decode_timeseries(parsed: list[tuple[str, Any]], meta: ParsedMeta): channel_step=channel_step, sample_rate=sample_rate, data_type=_TIMESERIES_DATA_TYPE_MAP.get(data_type, ("phase", "radians"))[0], - data_units=_TIMESERIES_DATA_TYPE_MAP.get(data_type, ("phase", "radians"))[1], - demod_data_type=str(demod_data_type), + data_units=_TIMESERIES_DATA_TYPE_MAP.get(data_type, ("phase", "radians"))[ + 1 + ], + demod_data_type=_normalize_demod_data_type(demod_data_type), ), ) - return data, coords, attrs + return TimeseriesMetadata( + common_header=common_headers[0], + packet_type=parsed[0][0], + num_channels=num_channels, + total_samples=total_samples, + sample_rate=sample_rate, + channel_spacing=channel_spacing, + gauge_length=gauge_length, + start_channel=start_channel, + channel_step=channel_step, + coords=coords, + attrs=attrs, + ) -def _decode_band(parsed: list[tuple[str, Any]], meta: ParsedMeta): - """Decode band packets into data, coords, and attrs.""" +def _get_band_metadata(parsed: list[tuple[str, Any]], meta: ParsedMeta) -> BandMetadata: + """Validate band headers and build shared attrs/coords.""" headers = [msg.header for _tag, msg in parsed] common_headers = [header.common_header for header in headers] num_channels = _assert_equal( @@ -722,15 +837,6 @@ def _decode_band(parsed: list[tuple[str, Any]], meta: ParsedMeta): times = [_common_header_time(ch) for ch in common_headers] if any(x is None for x in times): raise InvalidFiberFileError("Missing time in Sintela BAND packet.") - data = np.empty((len(parsed), num_channels, num_bands), dtype=np.float32) - for ind, (_tag, msg) in enumerate(parsed): - packet = np.asarray(msg.samples, dtype=np.float32) - expected = num_channels * num_bands - if packet.size != expected: - raise InvalidFiberFileError("Unexpected Sintela protobuf BAND payload size.") - data[ind] = packet.reshape(num_channels, num_bands) - # BAND packets do not expose channel_step, so distance comes from the - # recorded start channel and spacing only. distance = _get_distance_coord(start_channel, channel_spacing, num_channels) band = get_coord(start=0, stop=num_bands, step=1) coords = get_coord_manager( @@ -740,7 +846,10 @@ def _decode_band(parsed: list[tuple[str, Any]], meta: ParsedMeta): "band": band, "band_start_frequency": ("band", np.asarray([x[1] for x in band_def])), "band_end_frequency": ("band", np.asarray([x[2] for x in band_def])), - "band_description": ("band", np.asarray([x[3] for x in band_def], dtype=object)), + "band_description": ( + "band", + np.asarray([x[3] for x in band_def], dtype=object), + ), "band_source": ("band", np.asarray([x[4] for x in band_def], dtype=object)), }, dims=DIMS_BAND, @@ -755,13 +864,23 @@ def _decode_band(parsed: list[tuple[str, Any]], meta: ParsedMeta): gauge_length=gauge_length, data_type=data_type, data_units=data_units, + sample_rate=np.nan, ), ) - return data, coords, attrs + return BandMetadata( + common_header=common_headers[0], + packet_type=parsed[0][0], + num_channels=num_channels, + num_bands=num_bands, + gauge_length=gauge_length, + band_def=band_def, + coords=coords, + attrs=attrs, + ) -def _decode_fft(parsed: list[tuple[str, Any]], meta: ParsedMeta): - """Decode FFT packets into data, coords, and attrs.""" +def _get_fft_metadata(parsed: list[tuple[str, Any]], meta: ParsedMeta) -> FFTMetadata: + """Validate FFT headers and build shared attrs/coords.""" headers = [msg.header for _tag, msg in parsed] common_headers = [header.common_header for header in headers] num_channels = _assert_equal( @@ -778,28 +897,16 @@ def _decode_fft(parsed: list[tuple[str, Any]], meta: ParsedMeta): ) num_bins = _assert_equal("num_bins", [int(h.num_bins) for h in headers]) bin_res = _assert_float_equal("bin_res", [float(h.bin_res) for h in headers]) - has_complex = _assert_equal("has_complex_data", [bool(h.has_complex_data) for h in headers]) + has_complex = _assert_equal( + "has_complex_data", [bool(h.has_complex_data) for h in headers] + ) channel_step = _assert_equal("channel_step", [int(h.channel_step) for h in headers]) times = [_common_header_time(ch) for ch in common_headers] if any(x is None for x in times): raise InvalidFiberFileError("Missing time in Sintela FFT packet.") - dtype = np.complex64 if has_complex else np.float32 - data = np.empty((len(parsed), num_channels, num_bins), dtype=dtype) - for ind, (_tag, msg) in enumerate(parsed): - packet = np.asarray(msg.samples, dtype=np.float32) - if has_complex: - expected = num_channels * num_bins * 2 - if packet.size != expected: - raise InvalidFiberFileError("Unexpected Sintela protobuf FFT payload size.") - packet = packet.reshape(num_channels, num_bins, 2) - packet = packet[..., 0] + 1j * packet[..., 1] - else: - expected = num_channels * num_bins - if packet.size != expected: - raise InvalidFiberFileError("Unexpected Sintela protobuf FFT payload size.") - packet = packet.reshape(num_channels, num_bins) - data[ind] = packet - distance = _get_distance_coord(start_channel, channel_spacing, num_channels, channel_step) + distance = _get_distance_coord( + start_channel, channel_spacing, num_channels, channel_step + ) frequency = get_coord(start=0.0, stop=bin_res * num_bins, step=bin_res, units="Hz") coords = get_coord_manager( {"time": _get_times(times), "distance": distance, "frequency": frequency}, @@ -812,177 +919,121 @@ def _decode_fft(parsed: list[tuple[str, Any]], meta: ParsedMeta): extra=dict( gauge_length=gauge_length, channel_step=channel_step, + **_FFT_ATTR_DEFAULTS, ), ) - return data, coords, attrs + return FFTMetadata( + common_header=common_headers[0], + packet_type=parsed[0][0], + num_channels=num_channels, + num_bins=num_bins, + gauge_length=gauge_length, + channel_step=channel_step, + has_complex=has_complex, + coords=coords, + attrs=attrs, + ) -def _scan_timeseries(parsed: list[tuple[str, Any]], meta: ParsedMeta): - """Summarize timeseries packets without allocating sample arrays.""" - headers = [msg.header for _tag, msg in parsed] - common_headers = [header.common_header for header in headers] - num_channels = _assert_equal( - "num_channels", [int(ch.num_channels) for ch in common_headers] - ) - sample_rate = _assert_float_equal( - "sample_rate", [float(ch.sample_rate) for ch in common_headers] - ) - channel_spacing = _assert_float_equal( - "channel_spacing", [float(ch.channel_spacing) for ch in common_headers] - ) - gauge_length = _assert_float_equal( - "gauge_length", [float(ch.gauge_length) for ch in common_headers] - ) - start_channel = _assert_equal( - "start_channel", [int(ch.start_channel) for ch in common_headers] - ) - channel_step = _assert_equal("channel_step", [int(h.channel_step) for h in headers]) - data_type = _assert_equal( - "timeseries_data_type", - [int(ch.timeseries_data_type) for ch in common_headers], +def _decode_timeseries(parsed: list[tuple[str, Any]], meta: ParsedMeta): + """Decode timeseries packets into data, coords, and attrs.""" + metadata = _get_timeseries_metadata(parsed, meta) + data = np.empty((metadata.total_samples, metadata.num_channels), dtype=np.float32) + index = 0 + for _tag, msg in parsed: + packet = np.asarray(msg.samples, dtype=np.float32) + rows = int(msg.header.num_samples) + expected = rows * metadata.num_channels + if packet.size != expected: + raise InvalidFiberFileError( + "Unexpected Sintela protobuf TS sample payload size." + ) + data[index : index + rows] = packet.reshape(rows, metadata.num_channels) + index += rows + return data, metadata.coords, metadata.attrs + + +def _decode_band(parsed: list[tuple[str, Any]], meta: ParsedMeta): + """Decode band packets into data, coords, and attrs.""" + metadata = _get_band_metadata(parsed, meta) + data = np.empty( + (len(parsed), metadata.num_channels, metadata.num_bands), dtype=np.float32 ) - demod_data_type = _assert_equal( - "demod_data_type", [int(ch.demod_data_type) for ch in common_headers] + for ind, (_tag, msg) in enumerate(parsed): + packet = np.asarray(msg.samples, dtype=np.float32) + expected = metadata.num_channels * metadata.num_bands + if packet.size != expected: + raise InvalidFiberFileError( + "Unexpected Sintela protobuf BAND payload size." + ) + data[ind] = packet.reshape( + metadata.num_channels, + metadata.num_bands, + ) + return data, metadata.coords, metadata.attrs + + +def _decode_fft(parsed: list[tuple[str, Any]], meta: ParsedMeta): + """Decode FFT packets into data, coords, and attrs.""" + metadata = _get_fft_metadata(parsed, meta) + dtype = np.complex64 if metadata.has_complex else np.float32 + data = np.empty( + (len(parsed), metadata.num_channels, metadata.num_bins), + dtype=dtype, ) - for ch in common_headers: - if ch.has_dropped_samples: - raise InvalidFiberFileError("Dropped samples in Sintela protobuf stream.") - sample_counts = [int(h.sample_count) for h in headers] - num_samples_per_packet = [int(h.num_samples) for h in headers] - for current, nxt, count in zip( - sample_counts, sample_counts[1:], num_samples_per_packet[:-1], strict=False - ): - if current + count != nxt: - raise InvalidFiberFileError("Non-contiguous Sintela protobuf sample counts.") - total_samples = sum(num_samples_per_packet) - first_time = _common_header_time(common_headers[0]) - if first_time is None: - raise InvalidFiberFileError("Missing Sintela protobuf start time.") - time = _get_time_coord_from_samples(first_time, sample_rate, total_samples) - distance = _get_distance_coord(start_channel, channel_spacing, num_channels, channel_step) - coords = get_coord_manager({"time": time, "distance": distance}, dims=DIMS_TS) - attrs = _base_attrs( - common_headers[0], - packet_type=parsed[0][0], - meta=meta, - extra=dict( - gauge_length=gauge_length, - channel_step=channel_step, - sample_rate=sample_rate, - data_type=_TIMESERIES_DATA_TYPE_MAP.get(data_type, ("phase", "radians"))[0], - data_units=_TIMESERIES_DATA_TYPE_MAP.get(data_type, ("phase", "radians"))[1], - demod_data_type=str(demod_data_type), - ), + for ind, (_tag, msg) in enumerate(parsed): + packet = np.asarray(msg.samples, dtype=np.float32) + if metadata.has_complex: + expected = metadata.num_channels * metadata.num_bins * 2 + if packet.size != expected: + raise InvalidFiberFileError( + "Unexpected Sintela protobuf FFT payload size." + ) + packet = packet.reshape(metadata.num_channels, metadata.num_bins, 2) + packet = packet[..., 0] + 1j * packet[..., 1] + else: + expected = metadata.num_channels * metadata.num_bins + if packet.size != expected: + raise InvalidFiberFileError( + "Unexpected Sintela protobuf FFT payload size." + ) + packet = packet.reshape(metadata.num_channels, metadata.num_bins) + data[ind] = packet + return data, metadata.coords, metadata.attrs + + +def _scan_timeseries(parsed: list[tuple[str, Any]], meta: ParsedMeta): + """Summarize timeseries packets without allocating sample arrays.""" + metadata = _get_timeseries_metadata(parsed, meta) + return ( + (metadata.total_samples, metadata.num_channels), + metadata.coords, + metadata.attrs, + str(np.dtype(np.float32)), ) - return (total_samples, num_channels), coords, attrs, str(np.dtype(np.float32)) def _scan_band(parsed: list[tuple[str, Any]], meta: ParsedMeta): """Summarize band packets without allocating sample arrays.""" - headers = [msg.header for _tag, msg in parsed] - common_headers = [header.common_header for header in headers] - num_channels = _assert_equal( - "num_channels", [int(ch.num_channels) for ch in common_headers] - ) - channel_spacing = _assert_float_equal( - "channel_spacing", [float(ch.channel_spacing) for ch in common_headers] - ) - gauge_length = _assert_float_equal( - "gauge_length", [float(ch.gauge_length) for ch in common_headers] - ) - start_channel = _assert_equal( - "start_channel", [int(ch.start_channel) for ch in common_headers] - ) - band_defs = [] - for header in headers: - band_defs.append( - tuple( - ( - int(info.band_data_type), - float(info.start), - float(info.end), - str(info.description), - str(info.source), - ) - for info in header.band_data_info - ) - ) - band_def = _assert_equal("band_data_info", band_defs) - num_bands = len(band_def) - if not num_bands: - raise InvalidFiberFileError("Band packets missing band definitions.") - times = [_common_header_time(ch) for ch in common_headers] - if any(x is None for x in times): - raise InvalidFiberFileError("Missing time in Sintela BAND packet.") - distance = _get_distance_coord(start_channel, channel_spacing, num_channels) - band = get_coord(start=0, stop=num_bands, step=1) - coords = get_coord_manager( - { - "time": _get_times(times), - "distance": distance, - "band": band, - "band_start_frequency": ("band", np.asarray([x[1] for x in band_def])), - "band_end_frequency": ("band", np.asarray([x[2] for x in band_def])), - "band_description": ("band", np.asarray([x[3] for x in band_def], dtype=object)), - "band_source": ("band", np.asarray([x[4] for x in band_def], dtype=object)), - }, - dims=DIMS_BAND, - ) - first_type = int(band_def[0][0]) - data_type, data_units = _BAND_DATA_TYPE_MAP.get(first_type, ("", "")) - attrs = _base_attrs( - common_headers[0], - packet_type=parsed[0][0], - meta=meta, - extra=dict( - gauge_length=gauge_length, - data_type=data_type, - data_units=data_units, - ), + metadata = _get_band_metadata(parsed, meta) + return ( + (len(parsed), metadata.num_channels, metadata.num_bands), + metadata.coords, + metadata.attrs, + str(np.dtype(np.float32)), ) - return (len(parsed), num_channels, num_bands), coords, attrs, str(np.dtype(np.float32)) def _scan_fft(parsed: list[tuple[str, Any]], meta: ParsedMeta): """Summarize FFT packets without allocating sample arrays.""" - headers = [msg.header for _tag, msg in parsed] - common_headers = [header.common_header for header in headers] - num_channels = _assert_equal( - "num_channels", [int(ch.num_channels) for ch in common_headers] - ) - channel_spacing = _assert_float_equal( - "channel_spacing", [float(ch.channel_spacing) for ch in common_headers] - ) - gauge_length = _assert_float_equal( - "gauge_length", [float(ch.gauge_length) for ch in common_headers] - ) - start_channel = _assert_equal( - "start_channel", [int(ch.start_channel) for ch in common_headers] - ) - num_bins = _assert_equal("num_bins", [int(h.num_bins) for h in headers]) - bin_res = _assert_float_equal("bin_res", [float(h.bin_res) for h in headers]) - has_complex = _assert_equal("has_complex_data", [bool(h.has_complex_data) for h in headers]) - channel_step = _assert_equal("channel_step", [int(h.channel_step) for h in headers]) - times = [_common_header_time(ch) for ch in common_headers] - if any(x is None for x in times): - raise InvalidFiberFileError("Missing time in Sintela FFT packet.") - distance = _get_distance_coord(start_channel, channel_spacing, num_channels, channel_step) - frequency = get_coord(start=0.0, stop=bin_res * num_bins, step=bin_res, units="Hz") - coords = get_coord_manager( - {"time": _get_times(times), "distance": distance, "frequency": frequency}, - dims=DIMS_FFT, - ) - attrs = _base_attrs( - common_headers[0], - packet_type=parsed[0][0], - meta=meta, - extra=dict( - gauge_length=gauge_length, - channel_step=channel_step, - ), + metadata = _get_fft_metadata(parsed, meta) + dtype = np.complex64 if metadata.has_complex else np.float32 + return ( + (len(parsed), metadata.num_channels, metadata.num_bins), + metadata.coords, + metadata.attrs, + str(np.dtype(dtype)), ) - dtype = np.complex64 if has_complex else np.float32 - return (len(parsed), num_channels, num_bins), coords, attrs, str(np.dtype(dtype)) def read_payload(resource): @@ -992,8 +1043,8 @@ def read_payload(resource): return _decode_family(parsed, meta) -def scan_payload(resource) -> list[PatchSummary]: - """Decode a Sintela protobuf file and return PatchSummary objects.""" +def scan_payload(resource) -> list[dict[str, Any]]: + """Decode a Sintela protobuf file and return FiberIO scan payloads.""" records = _iter_envelope_records(resource, strict=True) parsed, meta = _parse_records(records, scan_mode=True) family = _validate_single_family(parsed) @@ -1004,9 +1055,9 @@ def scan_payload(resource) -> list[PatchSummary]: else: shape, coords, attrs, dtype = _scan_fft(parsed, meta) return [ - PatchSummary.model_construct( + _make_scan_payload( attrs=attrs, - coords=coords.to_summary_dict(), + coords=coords, dims=coords.dims, shape=shape, dtype=dtype, diff --git a/pyproject.toml b/pyproject.toml index e6ae867f..1a01c3dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ extras = [ "numba", "segyio", "bottleneck", + "protobuf", ] docs = [ @@ -83,10 +84,10 @@ test = [ "coverage>=7.4,<8", "pytest-cov>=4", "pre-commit", + "protobuf", "pytest", "pytest-timeout", "pytest-codeblocks", - "pytest-cov", "s3fs", "starlette", "twine", diff --git a/tests/test_io/_common_io_test_utils.py b/tests/test_io/_common_io_test_utils.py index fbabec71..2e427d31 100644 --- a/tests/test_io/_common_io_test_utils.py +++ b/tests/test_io/_common_io_test_utils.py @@ -41,6 +41,8 @@ def _is_timeout_error(exc: BaseException) -> bool: """Return True if the exception chain indicates a timeout.""" if isinstance(exc, TimeoutError | socket.timeout): return True + if isinstance(exc, pytest.fail.Exception): + return "Timeout" in str(exc) if isinstance(exc, urllib_error.URLError): reason = exc.reason return isinstance(reason, TimeoutError | socket.timeout) @@ -77,7 +79,9 @@ def skip_on_timeout(seconds: float, label: str): ): try: yield - except TimeoutError as exc: + except BaseException as exc: + if not _is_timeout_error(exc): + raise pytest.skip(str(exc)) return @@ -90,7 +94,9 @@ def _handle_timeout(_signum, _frame): signal_mod.signal(signal_mod.SIGALRM, _handle_timeout) signal_mod.setitimer(signal_mod.ITIMER_REAL, seconds) yield - except TimeoutError as exc: + except BaseException as exc: + if not _is_timeout_error(exc): + raise pytest.skip(str(exc)) finally: signal_mod.setitimer(signal_mod.ITIMER_REAL, 0) diff --git a/tests/test_io/test_remote_common_io.py b/tests/test_io/test_remote_common_io.py index 770f303f..7a09de72 100644 --- a/tests/test_io/test_remote_common_io.py +++ b/tests/test_io/test_remote_common_io.py @@ -17,8 +17,13 @@ pytestmark = [pytest.mark.network, pytest.mark.timeout(30)] -REMOTE_GET_FORMAT_CASES = get_flat_io_test(COMMON_IO_READ_TESTS) -REMOTE_REPRESENTATIVE_CASES = get_representative_io_test(COMMON_IO_READ_TESTS) +REMOTE_COMMON_IO_READ_TESTS = { + io: fetch_names + for io, fetch_names in COMMON_IO_READ_TESTS.items() + if io.name != "Sintela_Protobuf" +} +REMOTE_GET_FORMAT_CASES = get_flat_io_test(REMOTE_COMMON_IO_READ_TESTS) +REMOTE_REPRESENTATIVE_CASES = get_representative_io_test(REMOTE_COMMON_IO_READ_TESTS) @pytest.fixture(autouse=True) diff --git a/tests/test_io/test_remote_http.py b/tests/test_io/test_remote_http.py index e94064d9..4c692be4 100644 --- a/tests/test_io/test_remote_http.py +++ b/tests/test_io/test_remote_http.py @@ -168,6 +168,7 @@ def test_http_range_server_supports_partial_reads( assert response.headers["Content-Range"].startswith("bytes 0-15/") assert len(data) == 16 + @pytest.mark.timeout(0) def test_http_range_hdf5_read_succeeds( self, http_range_das_path, ensure_http_fetch_file ): diff --git a/tests/test_io/test_remote_memory.py b/tests/test_io/test_remote_memory.py index 254c798b..4c453bac 100644 --- a/tests/test_io/test_remote_memory.py +++ b/tests/test_io/test_remote_memory.py @@ -196,7 +196,6 @@ class TestMemoryRemoteMetadataAccess: ("sample_tdms_file_v4713.tdms", ("TDMS", "4713")), ("DASDMSShot00_20230328155653619.das", ("sentek", "5")), ("sintela_binary_v3_test_1.raw", ("Sintela_Binary", "3")), - ("sintela_protobuf_1.pb", ("Sintela_Protobuf", "1")), ], ) def test_get_format_avoids_local_cache( @@ -214,7 +213,6 @@ def test_get_format_avoids_local_cache( ("sample_tdms_file_v4713.tdms", ("TDMS", "4713")), ("DASDMSShot00_20230328155653619.das", ("sentek", "5")), ("sintela_binary_v3_test_1.raw", ("Sintela_Binary", "3")), - ("sintela_protobuf_1.pb", ("Sintela_Protobuf", "1")), ], ) def test_scan_avoids_local_cache(self, fetch_name, expected, memory_fetch_copy): diff --git a/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py b/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py index 9206a1b9..31b004b7 100644 --- a/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py +++ b/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py @@ -11,7 +11,6 @@ import numpy as np import pytest -import dascore as dc from dascore.exceptions import InvalidFiberFileError, MissingOptionalDependencyError from dascore.io.sintela_protobuf import SintelaProtobufV1 from dascore.io.sintela_protobuf import utils as sintela_utils @@ -57,190 +56,43 @@ def _build_meta_payload(): return msg.SerializeToString() -def _get_proto_messages(): - """Return test-local Sintela packet message classes.""" - return _get_test_proto_messages() +def _payload_to_summary(payload): + """Convert a raw FiberIO scan payload using the production scan path.""" + from dascore.io.core import _scan_payload_to_summary + + return _scan_payload_to_summary(payload) @cache def _get_test_proto_messages(): """Build local protobuf classes for Sintela synthetic test payloads.""" - from google.protobuf import descriptor_pb2, descriptor_pool, message_factory - from google.protobuf import timestamp_pb2 - - file_proto = descriptor_pb2.FileDescriptorProto() - file_proto.name = "test_sintela_common.proto" - file_proto.package = "test_sintela_common" - file_proto.dependency.append("google/protobuf/timestamp.proto") - - def add_field(message, name, number, type_, *, label=None, type_name=""): - field = message.field.add() - field.name = name - field.number = number - field.label = ( - label - if label is not None - else descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL - ) - field.type = type_ - if type_name: - field.type_name = type_name - return field - - common = file_proto.message_type.add() - common.name = "CommonHeader" - add_field( - common, - "time", - 1, - descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, - type_name=".google.protobuf.Timestamp", - ) - for number, name, type_ in ( - (2, "num_channels", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), - (3, "sample_rate", descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT), - (4, "channel_spacing", descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT), - (5, "gauge_length", descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT), - (6, "start_channel", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), - (7, "end_of_replay", descriptor_pb2.FieldDescriptorProto.TYPE_BOOL), - (8, "fiber_flipped", descriptor_pb2.FieldDescriptorProto.TYPE_BOOL), - (9, "loop_removed", descriptor_pb2.FieldDescriptorProto.TYPE_BOOL), - (10, "has_dropped_samples", descriptor_pb2.FieldDescriptorProto.TYPE_BOOL), - (11, "timeseries_data_type", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), - (12, "demod_data_type", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), - ): - add_field(common, name, number, type_) - - timeseries_header = file_proto.message_type.add() - timeseries_header.name = "TimeseriesHeader" - add_field( - timeseries_header, - "common_header", - 1, - descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, - type_name=".test_sintela_common.CommonHeader", - ) - add_field(timeseries_header, "sample_count", 2, descriptor_pb2.FieldDescriptorProto.TYPE_UINT32) - add_field(timeseries_header, "num_samples", 3, descriptor_pb2.FieldDescriptorProto.TYPE_INT32) - add_field(timeseries_header, "channel_step", 4, descriptor_pb2.FieldDescriptorProto.TYPE_INT32) - - timeseries_packet = file_proto.message_type.add() - timeseries_packet.name = "TimeseriesPacket" - add_field( - timeseries_packet, - "header", - 1, - descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, - type_name=".test_sintela_common.TimeseriesHeader", - ) - add_field( - timeseries_packet, - "samples", - 3, - descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT, - label=descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED, - ) - add_field(timeseries_packet, "raw_frames", 4, descriptor_pb2.FieldDescriptorProto.TYPE_BYTES) - - band_info = file_proto.message_type.add() - band_info.name = "BandDataInfo" - for number, name, type_ in ( - (1, "band_data_type", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), - (2, "start", descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT), - (3, "end", descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT), - (4, "averaging_type", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), - (5, "description", descriptor_pb2.FieldDescriptorProto.TYPE_STRING), - (6, "source", descriptor_pb2.FieldDescriptorProto.TYPE_STRING), - ): - add_field(band_info, name, number, type_) - - band_header = file_proto.message_type.add() - band_header.name = "BandHeader" - add_field( - band_header, - "common_header", - 1, - descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, - type_name=".test_sintela_common.CommonHeader", - ) - add_field( - band_header, - "band_data_info", - 2, - descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, - label=descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED, - type_name=".test_sintela_common.BandDataInfo", + from google.protobuf import ( + descriptor_pb2, + descriptor_pool, + message_factory, + timestamp_pb2, ) - band_packet = file_proto.message_type.add() - band_packet.name = "BandPacket" - add_field( - band_packet, - "header", - 1, - descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, - type_name=".test_sintela_common.BandHeader", - ) - add_field( - band_packet, - "samples", - 2, - descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT, - label=descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED, + return sintela_utils._build_proto_messages( + descriptor_pb2=descriptor_pb2, + descriptor_pool=descriptor_pool, + message_factory=message_factory, + timestamp_pb2=timestamp_pb2, + include_sample_fields=True, + package_name="test_sintela_common", + file_name="test_sintela_common.proto", ) - fft_header = file_proto.message_type.add() - fft_header.name = "FFTHeader" - add_field( - fft_header, - "common_header", - 1, - descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, - type_name=".test_sintela_common.CommonHeader", - ) - for number, name, type_ in ( - (2, "num_bins", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), - (3, "bin_res", descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT), - (4, "averaging_type", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), - (5, "channel_step", descriptor_pb2.FieldDescriptorProto.TYPE_INT32), - (6, "normalised", descriptor_pb2.FieldDescriptorProto.TYPE_BOOL), - (7, "has_power_data", descriptor_pb2.FieldDescriptorProto.TYPE_BOOL), - (8, "has_complex_data", descriptor_pb2.FieldDescriptorProto.TYPE_BOOL), - ): - add_field(fft_header, name, number, type_) - - fft_packet = file_proto.message_type.add() - fft_packet.name = "FFTPacket" - add_field( - fft_packet, - "header", - 1, - descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, - type_name=".test_sintela_common.FFTHeader", - ) - add_field( - fft_packet, - "samples", - 2, - descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT, - label=descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED, - ) - - pool = descriptor_pool.DescriptorPool() - pool.AddSerializedFile(timestamp_pb2.DESCRIPTOR.serialized_pb) - pool.Add(file_proto) - out = {} - for name in ("TimeseriesPacket", "BandPacket", "FFTPacket"): - descriptor = pool.FindMessageTypeByName(f"test_sintela_common.{name}") - out[name] = message_factory.GetMessageClass(descriptor) - return out - @cache def _get_test_meta_message_class(): """Build a local RecordingMetadata class for META payload tests.""" - from google.protobuf import descriptor_pb2, descriptor_pool, message_factory - from google.protobuf import timestamp_pb2 + from google.protobuf import ( + descriptor_pb2, + descriptor_pool, + message_factory, + timestamp_pb2, + ) file_proto = descriptor_pb2.FileDescriptorProto() file_proto.name = "test_sintela_meta.proto" @@ -309,7 +161,7 @@ def _get_test_meta_message_class(): def _build_ts_payloads(): """Create two contiguous timeseries packets.""" - packet_cls = _get_proto_messages()["TimeseriesPacket"] + packet_cls = _get_test_proto_messages()["TimeseriesPacket"] packets = [] for offset, sample_count in enumerate((0, 3)): msg = packet_cls() @@ -334,7 +186,7 @@ def _build_ts_payloads(): def _build_band_payloads(): """Create two band packets.""" - packet_cls = _get_proto_messages()["BandPacket"] + packet_cls = _get_test_proto_messages()["BandPacket"] packets = [] for offset in range(2): msg = packet_cls() @@ -366,7 +218,7 @@ def _build_band_payloads(): def _build_fft_payloads(*, complex_data: bool): """Create two FFT packets.""" - packet_cls = _get_proto_messages()["FFTPacket"] + packet_cls = _get_test_proto_messages()["FFTPacket"] packets = [] for offset in range(2): msg = packet_cls() @@ -400,7 +252,7 @@ def _write_records(path: Path, records): def _mutate_record(records, index: int, message_type: str, mutator): """Return a copy of records with one protobuf payload mutated.""" out = list(records) - packet_cls = _get_proto_messages()[message_type] + packet_cls = _get_test_proto_messages()[message_type] msg = packet_cls() tag, payload = out[index] msg.ParseFromString(payload) @@ -457,13 +309,17 @@ def test_get_format(self, fiber_io, sintela_protobuf_path): def test_scan_matches_read_summary(self, fiber_io, sintela_protobuf_path): """Scan metadata should match the loaded patch summary.""" - summary = fiber_io.scan(sintela_protobuf_path)[0] + summary = _payload_to_summary(fiber_io.scan(sintela_protobuf_path)[0]) patch_summary = fiber_io.read(sintela_protobuf_path)[0].summary assert summary == patch_summary - def test_ts_read_promotes_selected_meta_attrs(self, fiber_io, write_sintela_file, ts_records): + def test_ts_read_promotes_selected_meta_attrs( + self, fiber_io, write_sintela_file, ts_records + ): """META should supplement stable provenance attrs.""" - path = write_sintela_file("ts.pb", [("META", _build_meta_payload()), *ts_records]) + path = write_sintela_file( + "ts.pb", [("META", _build_meta_payload()), *ts_records] + ) patch = fiber_io.read(path)[0] assert patch.dims == ("time", "distance") assert patch.attrs.recorder_namespace == "manualRecord/recorder" @@ -473,7 +329,9 @@ def test_ts_read_promotes_selected_meta_attrs(self, fiber_io, write_sintela_file assert patch.attrs.fiber_id == 2 assert patch.attrs.data_type == "strain" - def test_band_read_returns_expected_dims(self, fiber_io, write_sintela_file, band_records): + def test_band_read_returns_expected_dims( + self, fiber_io, write_sintela_file, band_records + ): """Band recordings should load into a 3D patch.""" path = write_sintela_file("band.pb", band_records) patch = fiber_io.read(path)[0] @@ -481,32 +339,61 @@ def test_band_read_returns_expected_dims(self, fiber_io, write_sintela_file, ban assert patch.shape == (2, 2, 2) assert "band_start_frequency" in patch.coords.coord_map - def test_fft_read_returns_expected_dims(self, fiber_io, write_sintela_file, fft_records): + def test_fft_read_returns_expected_dims( + self, fiber_io, write_sintela_file, fft_records + ): """FFT recordings should load into a frequency-domain patch.""" path = write_sintela_file("fft.pb", fft_records) patch = fiber_io.read(path)[0] assert patch.dims == ("time", "distance", "frequency") assert patch.shape == (2, 2, 3) - def test_band_scan_matches_read_summary(self, fiber_io, write_sintela_file, band_records): + def test_band_scan_matches_read_summary( + self, fiber_io, write_sintela_file, band_records + ): """Band scan should exercise the metadata-only summary path.""" path = write_sintela_file("band.pb", band_records) - assert fiber_io.scan(path)[0] == fiber_io.read(path)[0].summary + scan_summary = _payload_to_summary(fiber_io.scan(path)[0]) + assert scan_summary == fiber_io.read(path)[0].summary def test_fft_scan_matches_read_summary( self, fiber_io, write_sintela_file, complex_fft_records ): """FFT scan should exercise the metadata-only summary path.""" path = write_sintela_file("fft.pb", complex_fft_records) - assert fiber_io.scan(path)[0] == fiber_io.read(path)[0].summary + scan_summary = _payload_to_summary(fiber_io.scan(path)[0]) + assert scan_summary == fiber_io.read(path)[0].summary + + def test_time_coords_keep_datetime_dtype( + self, fiber_io, write_sintela_file, ts_records, fft_records + ): + """Datetime coordinates should remain time-like after Sintela parsing.""" + ts_path = write_sintela_file("ts_time_units.pb", ts_records) + fft_path = write_sintela_file("fft_time_units.pb", fft_records) + + assert "datetime64" in str(fiber_io.read(ts_path)[0].get_coord("time").dtype) + assert "datetime64" in str(fiber_io.read(fft_path)[0].get_coord("time").dtype) + + def test_fft_attrs_do_not_default_to_time_series_units( + self, fiber_io, write_sintela_file, fft_records + ): + """FFT packets should not inherit time-series phase units by default.""" + path = write_sintela_file("fft_attrs.pb", fft_records) + patch = fiber_io.read(path)[0] + assert patch.attrs.data_type == "" + assert patch.attrs.data_units in (None, "") - def test_complex_fft_is_complex_dtype(self, fiber_io, write_sintela_file, complex_fft_records): + def test_complex_fft_is_complex_dtype( + self, fiber_io, write_sintela_file, complex_fft_records + ): """Complex FFT packets should decode into a complex dtype.""" path = write_sintela_file("fft_complex.pb", complex_fft_records) patch = fiber_io.read(path)[0] assert np.issubdtype(patch.data.dtype, np.complexfloating) - def test_read_applies_distance_selection(self, fiber_io, write_sintela_file, ts_records): + def test_read_applies_distance_selection( + self, fiber_io, write_sintela_file, ts_records + ): """Read should apply coord selectors through the core wrapper.""" path = write_sintela_file("ts_select.pb", ts_records) patch = fiber_io.read(path, distance=(30, 31))[0] @@ -521,17 +408,28 @@ def test_read_returns_empty_spool_for_empty_selection( spool = fiber_io.read(path, distance=(999, 1000)) assert len(spool) == 0 - def test_mixed_families_raise(self, fiber_io, write_sintela_file, ts_records, band_records): + def test_mixed_families_raise( + self, fiber_io, write_sintela_file, ts_records, band_records + ): """Mixed data packet families are not supported.""" - records = [("META", _build_meta_payload()), *ts_records, band_records[0]] + records = [ + ("META", _build_meta_payload()), + *ts_records, + band_records[0], + ] path = write_sintela_file("mixed.pb", records) with pytest.raises(InvalidFiberFileError, match="Mixed Sintela protobuf"): fiber_io.scan(path) - def test_non_contiguous_timeseries_raises(self, fiber_io, write_sintela_file, ts_records): + def test_non_contiguous_timeseries_raises( + self, fiber_io, write_sintela_file, ts_records + ): """Timeseries packets with gaps or reordering should fail.""" records = _mutate_record( - ts_records, 1, "TimeseriesPacket", lambda msg: setattr(msg.header, "sample_count", 10) + ts_records, + 1, + "TimeseriesPacket", + lambda msg: setattr(msg.header, "sample_count", 10), ) path = write_sintela_file("non_contiguous.pb", records) with pytest.raises(InvalidFiberFileError, match="Non-contiguous"): @@ -543,23 +441,31 @@ def test_bad_magic_returns_false(self, fiber_io, tmp_path): path.write_bytes(b"NOPE") assert not fiber_io.get_format(path) - def test_scan_rejects_invalid_or_truncated_envelope_headers(self, fiber_io, tmp_path): + def test_scan_rejects_invalid_or_truncated_envelope_headers( + self, fiber_io, tmp_path + ): """Strict scan mode should raise for malformed envelope headers.""" bad_magic = tmp_path / "bad_magic.pb" bad_magic.write_bytes(b"NOPE") - with pytest.raises(InvalidFiberFileError, match="Invalid Sintela protobuf magic"): + with pytest.raises( + InvalidFiberFileError, match="Invalid Sintela protobuf magic" + ): fiber_io.scan(bad_magic) short_magic = tmp_path / "short_magic.pb" short_magic.write_bytes(b"\x01\x02") - with pytest.raises(InvalidFiberFileError, match="Truncated Sintela protobuf magic"): + with pytest.raises( + InvalidFiberFileError, match="Truncated Sintela protobuf magic" + ): fiber_io.scan(short_magic) short_header = tmp_path / "short_header.pb" with short_header.open("wb") as handle: handle.write(struct.pack(" Date: Thu, 9 Apr 2026 11:41:31 +0200 Subject: [PATCH 3/5] address review --- dascore/io/sintela_protobuf/utils.py | 37 +++++++++------- tests/test_io/_common_io_test_utils.py | 4 ++ .../test_sintela_protobuf.py | 42 +++++++++++++++++++ 3 files changed, 69 insertions(+), 14 deletions(-) diff --git a/dascore/io/sintela_protobuf/utils.py b/dascore/io/sintela_protobuf/utils.py index c88f059a..d9d5ce7a 100644 --- a/dascore/io/sintela_protobuf/utils.py +++ b/dascore/io/sintela_protobuf/utils.py @@ -5,6 +5,7 @@ from __future__ import annotations import struct +from collections.abc import Iterator from dataclasses import dataclass from functools import cache from typing import Any @@ -141,10 +142,9 @@ def _timestamp_to_dt64(timestamp) -> np.datetime64 | None: return np.datetime64(seconds, "s") + np.timedelta64(nanos, "ns") -def _iter_envelope_records(resource, *, strict: bool) -> list[EnvelopeRecord]: +def _iter_envelope_records(resource, *, strict: bool) -> Iterator[EnvelopeRecord]: """Read all MTLV envelope records from a binary stream.""" resource.seek(0) - out: list[EnvelopeRecord] = [] while True: magic = resource.read(4) if not magic: @@ -152,25 +152,24 @@ def _iter_envelope_records(resource, *, strict: bool) -> list[EnvelopeRecord]: if len(magic) < 4: if strict: raise InvalidFiberFileError("Truncated Sintela protobuf magic header.") - return [] + return if struct.unpack(" str | None: @@ -554,9 +553,8 @@ def _parse_meta(payload: bytes) -> ParsedMeta: instrument_model=str(getattr(identification, "model", "") or ""), serial_number=str(getattr(identification, "serial_number", "") or ""), fiber_id=( - int(getattr(acquisition, "fiber_id")) - if acquisition is not None - and getattr(acquisition, "fiber_id", None) is not None + int(acquisition.fiber_id) + if acquisition is not None and acquisition.HasField("fiber_id") else None ), ) @@ -605,6 +603,9 @@ def _parse_records( def _get_time_coord_from_samples(start: np.datetime64, sample_rate: float, size: int): """Build a regularly sampled time coordinate.""" + if not np.isfinite(sample_rate) or sample_rate <= 0: + msg = f"Invalid Sintela protobuf sample_rate: {sample_rate!r}." + raise InvalidFiberFileError(msg) step = dc.to_timedelta64(1 / sample_rate) return get_coord(start=start, stop=start + step * size, step=step) @@ -645,6 +646,14 @@ def _assert_float_equal(name: str, values: list[float], *, rtol: float = 1e-6): return first +def _assert_positive_finite_float(name: str, values: list[float]): + """Ensure each float value is finite and positive.""" + for value in values: + if not np.isfinite(value) or value <= 0: + msg = f"Invalid Sintela protobuf {name}: {value!r}." + raise InvalidFiberFileError(msg) + + def _base_attrs( common_header, packet_type: str, meta: ParsedMeta, extra: dict | None = None ): @@ -729,9 +738,9 @@ def _get_timeseries_metadata( num_channels = _assert_equal( "num_channels", [int(ch.num_channels) for ch in common_headers] ) - sample_rate = _assert_float_equal( - "sample_rate", [float(ch.sample_rate) for ch in common_headers] - ) + sample_rates = [float(ch.sample_rate) for ch in common_headers] + _assert_positive_finite_float("sample_rate", sample_rates) + sample_rate = _assert_float_equal("sample_rate", sample_rates) channel_spacing = _assert_float_equal( "channel_spacing", [float(ch.channel_spacing) for ch in common_headers] ) diff --git a/tests/test_io/_common_io_test_utils.py b/tests/test_io/_common_io_test_utils.py index 2e427d31..37f37bda 100644 --- a/tests/test_io/_common_io_test_utils.py +++ b/tests/test_io/_common_io_test_utils.py @@ -79,6 +79,8 @@ def skip_on_timeout(seconds: float, label: str): ): try: yield + # Broad catch is intentional so _is_timeout_error can normalize + # pytest-timeout/framework-specific timeout exceptions and re-raise the rest. except BaseException as exc: if not _is_timeout_error(exc): raise @@ -94,6 +96,8 @@ def _handle_timeout(_signum, _frame): signal_mod.signal(signal_mod.SIGALRM, _handle_timeout) signal_mod.setitimer(signal_mod.ITIMER_REAL, seconds) yield + # Broad catch is intentional so _is_timeout_error can normalize + # pytest-timeout/framework-specific timeout exceptions and re-raise the rest. except BaseException as exc: if not _is_timeout_error(exc): raise diff --git a/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py b/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py index 31b004b7..1eb02100 100644 --- a/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py +++ b/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py @@ -56,6 +56,18 @@ def _build_meta_payload(): return msg.SerializeToString() +def _build_meta_payload_without_fiber_id(): + """Create a META payload with the optional fiber_id field unset.""" + cls = _get_test_meta_message_class() + msg = cls() + msg.recorder_namespace = "manualRecord/recorder" + _set_timestamp(msg.metadata_recording_time, 1_700_000_000, 123) + msg.identification.manufacturer = "Sintela" + msg.identification.model = "Onyx" + msg.identification.serial_number = "SN123" + return msg.SerializeToString() + + def _payload_to_summary(payload): """Convert a raw FiberIO scan payload using the production scan path.""" from dascore.io.core import _scan_payload_to_summary @@ -583,6 +595,25 @@ def test_timeseries_scan_rejects_missing_time( ): fiber_io.scan(path) + @pytest.mark.parametrize("bad_sample_rate", [0.0, -1.0, np.nan, np.inf]) + def test_timeseries_scan_rejects_invalid_sample_rate( + self, fiber_io, write_sintela_file, ts_records, bad_sample_rate + ): + """Timeseries scan should reject invalid sample rates.""" + records = _mutate_record( + ts_records, + 0, + "TimeseriesPacket", + lambda msg: setattr( + msg.header.common_header, "sample_rate", bad_sample_rate + ), + ) + path = write_sintela_file(f"ts_bad_rate_{bad_sample_rate}.pb", records) + with pytest.raises( + InvalidFiberFileError, match="Invalid Sintela protobuf sample_rate" + ): + fiber_io.scan(path) + def test_timeseries_read_rejects_bad_size_and_inconsistent_headers( self, fiber_io, write_sintela_file, ts_records ): @@ -777,6 +808,12 @@ def test_helper_functions(self, ts_records): time = sintela_utils._get_time_coord_from_samples( np.datetime64("2024-01-01"), 2.0, 4 ) + with pytest.raises( + InvalidFiberFileError, match="Invalid Sintela protobuf sample_rate" + ): + sintela_utils._get_time_coord_from_samples( + np.datetime64("2024-01-01"), 0.0, 4 + ) packet_times = sintela_utils._get_times( [ np.datetime64("2024-01-01T00:00:00"), @@ -802,6 +839,11 @@ def test_helper_functions(self, ts_records): scan_mode=True, ) + def test_parse_meta_without_optional_fiber_id(self): + """META parsing should keep fiber_id unset when omitted.""" + meta = sintela_utils._parse_meta(_build_meta_payload_without_fiber_id()) + assert meta.fiber_id is None + def test_decode_and_scan_helpers_cover_all_families( self, ts_records, From 2425f393fb5a443c4a9cd7c5ac354e0552a3f256 Mon Sep 17 00:00:00 2001 From: Derrick Chambers Date: Fri, 1 May 2026 14:02:57 +0200 Subject: [PATCH 4/5] Tolerate unknown Sintela protobuf records --- dascore/io/sintela_protobuf/utils.py | 8 +++++++- .../test_sintela_protobuf/test_sintela_protobuf.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/dascore/io/sintela_protobuf/utils.py b/dascore/io/sintela_protobuf/utils.py index d9d5ce7a..81c6127c 100644 --- a/dascore/io/sintela_protobuf/utils.py +++ b/dascore/io/sintela_protobuf/utils.py @@ -577,6 +577,7 @@ def _parse_records( decode_error = _get_protobuf_decode_error() parsed: list[Any] = [] meta = ParsedMeta() + first_unsupported_tag = None for record in records: tag = record.tag if tag == META_TAG: @@ -589,7 +590,8 @@ def _parse_records( elif tag in FFT_TAGS: msg = messages["FFTPacket"]() else: - raise InvalidFiberFileError(f"Unsupported Sintela protobuf tag {tag!r}.") + first_unsupported_tag = first_unsupported_tag or tag + continue try: msg.ParseFromString(record.payload) except decode_error as exc: @@ -597,6 +599,10 @@ def _parse_records( raise InvalidFiberFileError(out) from exc parsed.append((tag, msg)) if not parsed: + if first_unsupported_tag is not None: + raise InvalidFiberFileError( + f"Unsupported Sintela protobuf tag {first_unsupported_tag!r}." + ) raise InvalidFiberFileError("No supported Sintela protobuf data packets found.") return parsed, meta diff --git a/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py b/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py index 1eb02100..abd59717 100644 --- a/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py +++ b/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py @@ -519,6 +519,8 @@ def test_get_format_tolerates_unknown_tags_before_supported_data( "unknown_then_ts.pb", [("ABCD", b"\x01"), *ts_records] ) assert fiber_io.get_format(path) == ("Sintela_Protobuf", "1") + assert fiber_io.scan(path) + assert fiber_io.read(path) def test_truncated_payload_raises(self, fiber_io, tmp_path): """Truncated payloads should fail in scan.""" From 3c5b38161c491ba35f56f8b9d4a86d47db144de7 Mon Sep 17 00:00:00 2001 From: Derrick Chambers Date: Fri, 1 May 2026 17:43:44 +0200 Subject: [PATCH 5/5] Address Sintela protobuf review feedback --- dascore/io/sintela_protobuf/utils.py | 42 +++++-- tests/test_io/_common_io_test_utils.py | 15 ++- .../test_sintela_protobuf.py | 112 +++++++++++++++++- 3 files changed, 148 insertions(+), 21 deletions(-) diff --git a/dascore/io/sintela_protobuf/utils.py b/dascore/io/sintela_protobuf/utils.py index 81c6127c..9472b635 100644 --- a/dascore/io/sintela_protobuf/utils.py +++ b/dascore/io/sintela_protobuf/utils.py @@ -618,6 +618,12 @@ def _get_time_coord_from_samples(start: np.datetime64, sample_rate: float, size: def _get_distance_coord(start_channel: int, spacing: float, count: int, step: int = 1): """Build the distance coordinate.""" + if not np.isfinite(spacing) or spacing <= 0: + msg = f"Invalid Sintela protobuf channel_spacing: {spacing!r}." + raise InvalidFiberFileError(msg) + if isinstance(step, bool) or not isinstance(step, int | np.integer) or step <= 0: + msg = f"Invalid Sintela protobuf channel_step: {step!r}." + raise InvalidFiberFileError(msg) start = start_channel * spacing return get_coord( start=start, @@ -698,6 +704,17 @@ def _normalize_demod_data_type(code: int) -> str: return str(code) +def _get_band_attr_data_type(band_def: tuple[tuple[Any, ...], ...]) -> tuple[str, str]: + """Return patch-level BAND data type/units when all bands agree.""" + mapped = [_BAND_DATA_TYPE_MAP.get(int(item[0])) for item in band_def] + if any(item is None for item in mapped): + return "", "" + first = mapped[0] + if all(item == first for item in mapped): + return first + return "", "" + + def _assert_equal(name: str, values: list[Any]): """Ensure all values in a list are equal.""" if not values: @@ -747,9 +764,9 @@ def _get_timeseries_metadata( sample_rates = [float(ch.sample_rate) for ch in common_headers] _assert_positive_finite_float("sample_rate", sample_rates) sample_rate = _assert_float_equal("sample_rate", sample_rates) - channel_spacing = _assert_float_equal( - "channel_spacing", [float(ch.channel_spacing) for ch in common_headers] - ) + channel_spacing_values = [float(ch.channel_spacing) for ch in common_headers] + _assert_positive_finite_float("channel_spacing", channel_spacing_values) + channel_spacing = _assert_float_equal("channel_spacing", channel_spacing_values) gauge_length = _assert_float_equal( "gauge_length", [float(ch.gauge_length) for ch in common_headers] ) @@ -822,9 +839,9 @@ def _get_band_metadata(parsed: list[tuple[str, Any]], meta: ParsedMeta) -> BandM num_channels = _assert_equal( "num_channels", [int(ch.num_channels) for ch in common_headers] ) - channel_spacing = _assert_float_equal( - "channel_spacing", [float(ch.channel_spacing) for ch in common_headers] - ) + channel_spacing_values = [float(ch.channel_spacing) for ch in common_headers] + _assert_positive_finite_float("channel_spacing", channel_spacing_values) + channel_spacing = _assert_float_equal("channel_spacing", channel_spacing_values) gauge_length = _assert_float_equal( "gauge_length", [float(ch.gauge_length) for ch in common_headers] ) @@ -869,8 +886,7 @@ def _get_band_metadata(parsed: list[tuple[str, Any]], meta: ParsedMeta) -> BandM }, dims=DIMS_BAND, ) - first_type = int(band_def[0][0]) - data_type, data_units = _BAND_DATA_TYPE_MAP.get(first_type, ("", "")) + data_type, data_units = _get_band_attr_data_type(band_def) attrs = _base_attrs( common_headers[0], packet_type=parsed[0][0], @@ -901,9 +917,9 @@ def _get_fft_metadata(parsed: list[tuple[str, Any]], meta: ParsedMeta) -> FFTMet num_channels = _assert_equal( "num_channels", [int(ch.num_channels) for ch in common_headers] ) - channel_spacing = _assert_float_equal( - "channel_spacing", [float(ch.channel_spacing) for ch in common_headers] - ) + channel_spacing_values = [float(ch.channel_spacing) for ch in common_headers] + _assert_positive_finite_float("channel_spacing", channel_spacing_values) + channel_spacing = _assert_float_equal("channel_spacing", channel_spacing_values) gauge_length = _assert_float_equal( "gauge_length", [float(ch.gauge_length) for ch in common_headers] ) @@ -911,7 +927,9 @@ def _get_fft_metadata(parsed: list[tuple[str, Any]], meta: ParsedMeta) -> FFTMet "start_channel", [int(ch.start_channel) for ch in common_headers] ) num_bins = _assert_equal("num_bins", [int(h.num_bins) for h in headers]) - bin_res = _assert_float_equal("bin_res", [float(h.bin_res) for h in headers]) + bin_res_values = [float(h.bin_res) for h in headers] + _assert_positive_finite_float("bin_res", bin_res_values) + bin_res = _assert_float_equal("bin_res", bin_res_values) has_complex = _assert_equal( "has_complex_data", [bool(h.has_complex_data) for h in headers] ) diff --git a/tests/test_io/_common_io_test_utils.py b/tests/test_io/_common_io_test_utils.py index 37f37bda..de4f152d 100644 --- a/tests/test_io/_common_io_test_utils.py +++ b/tests/test_io/_common_io_test_utils.py @@ -49,6 +49,13 @@ def _is_timeout_error(exc: BaseException) -> bool: return False +def _skip_if_timeout_else_raise(exc: BaseException): + """Skip timeout-like exceptions, re-raise everything else.""" + if not _is_timeout_error(exc): + raise exc + pytest.skip(str(exc)) + + def get_flat_io_test(common_io_read_tests: dict) -> list[list[dc.FiberIO | str]]: """Flatten the common IO matrix for parametrized tests.""" flat_io = [] @@ -82,9 +89,7 @@ def skip_on_timeout(seconds: float, label: str): # Broad catch is intentional so _is_timeout_error can normalize # pytest-timeout/framework-specific timeout exceptions and re-raise the rest. except BaseException as exc: - if not _is_timeout_error(exc): - raise - pytest.skip(str(exc)) + _skip_if_timeout_else_raise(exc) return previous_handler = signal_mod.getsignal(signal_mod.SIGALRM) @@ -99,9 +104,7 @@ def _handle_timeout(_signum, _frame): # Broad catch is intentional so _is_timeout_error can normalize # pytest-timeout/framework-specific timeout exceptions and re-raise the rest. except BaseException as exc: - if not _is_timeout_error(exc): - raise - pytest.skip(str(exc)) + _skip_if_timeout_else_raise(exc) finally: signal_mod.setitimer(signal_mod.ITIMER_REAL, 0) signal_mod.signal(signal_mod.SIGALRM, previous_handler) diff --git a/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py b/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py index abd59717..f664bb03 100644 --- a/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py +++ b/tests/test_io/test_sintela_protobuf/test_sintela_protobuf.py @@ -14,6 +14,7 @@ from dascore.exceptions import InvalidFiberFileError, MissingOptionalDependencyError from dascore.io.sintela_protobuf import SintelaProtobufV1 from dascore.io.sintela_protobuf import utils as sintela_utils +from dascore.units import get_quantity from dascore.utils.downloader import fetch @@ -350,6 +351,29 @@ def test_band_read_returns_expected_dims( assert patch.dims == ("time", "distance", "band") assert patch.shape == (2, 2, 2) assert "band_start_frequency" in patch.coords.coord_map + assert patch.attrs.data_type == "" + assert patch.attrs.data_units in (None, "") + + def test_band_read_preserves_attrs_for_single_semantic_type( + self, fiber_io, write_sintela_file, band_records + ): + """Uniform BAND semantics should populate patch-level data attrs.""" + records = _mutate_record( + band_records, + 0, + "BandPacket", + lambda msg: setattr(msg.header.band_data_info[1], "band_data_type", 13), + ) + records = _mutate_record( + records, + 1, + "BandPacket", + lambda msg: setattr(msg.header.band_data_info[1], "band_data_type", 13), + ) + path = write_sintela_file("band_uniform.pb", records) + patch = fiber_io.read(path)[0] + assert patch.attrs.data_type == "phase" + assert patch.attrs.data_units == get_quantity("rad") def test_fft_read_returns_expected_dims( self, fiber_io, write_sintela_file, fft_records @@ -616,6 +640,56 @@ def test_timeseries_scan_rejects_invalid_sample_rate( ): fiber_io.scan(path) + @pytest.mark.parametrize("bad_spacing", [0.0, -1.0, np.nan, np.inf]) + def test_scan_rejects_invalid_channel_spacing( + self, fiber_io, write_sintela_file, ts_records, bad_spacing + ): + """Invalid channel spacing should raise a format-specific error.""" + records = _mutate_record( + ts_records, + 0, + "TimeseriesPacket", + lambda msg: setattr( + msg.header.common_header, "channel_spacing", bad_spacing + ), + ) + records = _mutate_record( + records, + 1, + "TimeseriesPacket", + lambda msg: setattr( + msg.header.common_header, "channel_spacing", bad_spacing + ), + ) + path = write_sintela_file(f"ts_bad_spacing_{bad_spacing}.pb", records) + with pytest.raises( + InvalidFiberFileError, match="Invalid Sintela protobuf channel_spacing" + ): + fiber_io.scan(path) + + @pytest.mark.parametrize("bad_step", [0, -1]) + def test_scan_rejects_invalid_channel_step( + self, fiber_io, write_sintela_file, ts_records, bad_step + ): + """Invalid channel steps should raise a format-specific error.""" + records = _mutate_record( + ts_records, + 0, + "TimeseriesPacket", + lambda msg: setattr(msg.header, "channel_step", bad_step), + ) + records = _mutate_record( + records, + 1, + "TimeseriesPacket", + lambda msg: setattr(msg.header, "channel_step", bad_step), + ) + path = write_sintela_file(f"ts_bad_step_{bad_step}.pb", records) + with pytest.raises( + InvalidFiberFileError, match="Invalid Sintela protobuf channel_step" + ): + fiber_io.scan(path) + def test_timeseries_read_rejects_bad_size_and_inconsistent_headers( self, fiber_io, write_sintela_file, ts_records ): @@ -739,6 +813,29 @@ def test_fft_read_rejects_bad_sizes( with pytest.raises(InvalidFiberFileError, match="FFT payload size"): fiber_io.read(path) + @pytest.mark.parametrize("bad_bin_res", [0.0, -1.0, np.nan, np.inf]) + def test_fft_scan_rejects_invalid_bin_res( + self, fiber_io, write_sintela_file, fft_records, bad_bin_res + ): + """FFT bin resolution should be finite and positive.""" + records = _mutate_record( + fft_records, + 0, + "FFTPacket", + lambda msg: setattr(msg.header, "bin_res", bad_bin_res), + ) + records = _mutate_record( + records, + 1, + "FFTPacket", + lambda msg: setattr(msg.header, "bin_res", bad_bin_res), + ) + path = write_sintela_file(f"fft_bad_bin_res_{bad_bin_res}.pb", records) + with pytest.raises( + InvalidFiberFileError, match="Invalid Sintela protobuf bin_res" + ): + fiber_io.scan(path) + def test_fft_scan_and_read_reject_missing_time( self, fiber_io, write_sintela_file, fft_records ): @@ -765,10 +862,11 @@ def test_optional_dependency_error_message(self): assert isinstance(err, MissingOptionalDependencyError) assert "protobuf is not installed" in str(err) - def test_missing_protobuf_only_affects_scan_and_read(self, tmp_path, monkeypatch): + def test_missing_protobuf_only_affects_scan_and_read( + self, sintela_protobuf_path, monkeypatch + ): """Detection should stay available without protobuf support.""" - path = tmp_path / "ts.pb" - _write_records(path, [("META", _build_meta_payload()), *_build_ts_payloads()]) + path = sintela_protobuf_path def _raise(): raise MissingOptionalDependencyError("protobuf missing") @@ -832,6 +930,14 @@ def test_helper_functions(self, ts_records): extra={"data_type": "strain"}, ) assert attrs.data_type == "strain" + with pytest.raises( + InvalidFiberFileError, match="Invalid Sintela protobuf channel_spacing" + ): + sintela_utils._get_distance_coord(0, np.nan, 1) + assert sintela_utils._get_band_attr_data_type(((999, 0, 1, "", ""),)) == ( + "", + "", + ) with pytest.raises( InvalidFiberFileError, match="No supported Sintela protobuf"