diff --git a/ser/_internal/transcription/compatibility.py b/ser/_internal/transcription/compatibility.py index 1e57c373..2a6cd66d 100644 --- a/ser/_internal/transcription/compatibility.py +++ b/ser/_internal/transcription/compatibility.py @@ -4,15 +4,17 @@ import logging from collections.abc import Callable -from typing import Literal, Protocol, TypeVar +from typing import TYPE_CHECKING, Literal, Protocol, TypeVar from ser.config import AppConfig from ser.profiles import TranscriptionBackendId -from ser.transcript.backends import ( - BackendRuntimeRequest, - CompatibilityIssueImpact, - CompatibilityReport, -) + +if TYPE_CHECKING: + from ser.transcript.backends.base import ( + BackendRuntimeRequest, + CompatibilityIssueImpact, + CompatibilityReport, + ) type _CompatibilityIssueKind = Literal["noise", "operational"] type _EmittedIssueKeySet = set[tuple[str, str, str]] diff --git a/ser/_internal/transcription/process_isolation.py b/ser/_internal/transcription/process_isolation.py index 85e46c44..f98b083e 100644 --- a/ser/_internal/transcription/process_isolation.py +++ b/ser/_internal/transcription/process_isolation.py @@ -4,9 +4,10 @@ import logging from collections.abc import Callable +from numbers import Real from pathlib import Path from types import ModuleType -from typing import Literal, Never, Protocol, TypeVar, cast +from typing import TYPE_CHECKING, Literal, Never, Protocol, TypeVar, cast from ser.config import AppConfig from ser.profiles import TranscriptionBackendId @@ -20,11 +21,14 @@ log_phase_failed, log_phase_started, ) -from ser.transcript.backends import BackendRuntimeRequest + +if TYPE_CHECKING: + from ser.transcript.backends.base import BackendRuntimeRequest type _WorkerPhase = Literal["setup_complete", "model_loaded"] type WorkerPhaseMessage = tuple[Literal["phase"], _WorkerPhase] -type WorkerSuccessMessage = tuple[Literal["ok"], list[tuple[str, float, float]]] +type SerializedTranscriptWord = tuple[str, float, float] +type WorkerSuccessMessage = tuple[Literal["ok"], list[SerializedTranscriptWord]] type WorkerErrorMessage = tuple[Literal["err"], str, str, str] type WorkerMessage = WorkerPhaseMessage | WorkerSuccessMessage | WorkerErrorMessage @@ -247,6 +251,71 @@ def transcribe( type _AdapterResolver = Callable[[TranscriptionBackendId], object] +def _is_real_timestamp(value: object) -> bool: + """Returns whether one worker timestamp payload is numeric and non-bool.""" + return isinstance(value, Real) and not isinstance(value, bool) + + +def _serialize_transcript_words( + transcript_words: list[object], +) -> list[SerializedTranscriptWord]: + """Converts backend transcript words into a validated worker payload.""" + serialized_words: list[SerializedTranscriptWord] = [] + for index, raw_word in enumerate(transcript_words): + word = cast(_TranscriptWordLike, raw_word) + if not isinstance(word.word, str): + raise TypeError( + "Transcription worker produced a non-string token " f"at index {index}." + ) + if not _is_real_timestamp(word.start_seconds) or not _is_real_timestamp(word.end_seconds): + raise TypeError( + "Transcription worker produced non-numeric timestamps " f"at index {index}." + ) + serialized_words.append( + ( + word.word, + float(word.start_seconds), + float(word.end_seconds), + ) + ) + return serialized_words + + +def _deserialize_transcript_words( + serialized_words: object, + *, + transcript_word_factory: Callable[[str, float, float], _TTranscriptWord], + error_factory: _ErrorFactory, +) -> list[_TTranscriptWord]: + """Builds transcript words from validated worker payload tuples.""" + if not isinstance(serialized_words, list): + raise error_factory("Transcription worker returned malformed transcript payload.") + transcript_words: list[_TTranscriptWord] = [] + for index, raw_word in enumerate(serialized_words): + if not isinstance(raw_word, tuple) or len(raw_word) != 3: + raise error_factory( + "Transcription worker returned malformed transcript payload " f"at index {index}." + ) + word, start_seconds, end_seconds = raw_word + if not isinstance(word, str): + raise error_factory( + "Transcription worker returned non-string transcript token " f"at index {index}." + ) + if not _is_real_timestamp(start_seconds) or not _is_real_timestamp(end_seconds): + raise error_factory( + "Transcription worker returned non-numeric transcript timestamps " + f"at index {index}." + ) + transcript_words.append( + transcript_word_factory( + word, + float(start_seconds), + float(end_seconds), + ) + ) + return transcript_words + + def should_use_process_isolated_path(profile: _ProcessIsolatedProfile) -> bool: """Returns whether one transcription profile should use worker-process isolation.""" return profile.backend_id == "faster_whisper" @@ -260,6 +329,8 @@ def runtime_request_for_isolated_faster_whisper( logger: logging.Logger, ) -> BackendRuntimeRequest: """Builds one faster-whisper runtime request without importing torch in worker.""" + from ser.transcript.backends.base import BackendRuntimeRequest + if profile.backend_id != "faster_whisper": raise error_factory( "Process-isolated runtime request only supports faster-whisper backend." @@ -388,14 +459,7 @@ def transcription_worker_entry( language=payload.language, settings=settings, ) - serialized_words = [ - ( - cast(_TranscriptWordLike, word).word, - cast(_TranscriptWordLike, word).start_seconds, - cast(_TranscriptWordLike, word).end_seconds, - ) - for word in transcript_words - ] + serialized_words = _serialize_transcript_words(transcript_words) connection.send(("ok", serialized_words)) except BaseException as err: connection.send(("err", stage, type(err).__name__, str(err))) @@ -488,9 +552,7 @@ def run_faster_whisper_process_isolated( isinstance(completion_message, tuple) and len(completion_message) == 2 and completion_message[0] == "ok" - and isinstance(completion_message[1], list) ): - serialized_words = completion_message[1] if transcription_started_at is None: raise error_factory( "Transcription worker completed transcription before phase timer start." @@ -500,10 +562,11 @@ def run_faster_whisper_process_isolated( phase_name=PHASE_TRANSCRIPTION, started_at=transcription_started_at, ) - return [ - transcript_word_factory(word, start_seconds, end_seconds) - for word, start_seconds, end_seconds in serialized_words - ] + return _deserialize_transcript_words( + completion_message[1], + transcript_word_factory=transcript_word_factory, + error_factory=error_factory, + ) raise_worker_error_fn(completion_message) except Exception: if transcription_started_at is not None: diff --git a/ser/_internal/transcription/process_worker.py b/ser/_internal/transcription/process_worker.py index 25b7ba15..6b1ac758 100644 --- a/ser/_internal/transcription/process_worker.py +++ b/ser/_internal/transcription/process_worker.py @@ -8,11 +8,13 @@ from dataclasses import dataclass from pathlib import Path from types import ModuleType -from typing import Protocol +from typing import TYPE_CHECKING, Protocol from ser.config import AppConfig from ser.profiles import TranscriptionBackendId -from ser.transcript.backends import BackendRuntimeRequest + +if TYPE_CHECKING: + from ser.transcript.backends.base import BackendRuntimeRequest class ProcessIsolatedProfileLike(Protocol): diff --git a/ser/_internal/transcription/public_boundary_process.py b/ser/_internal/transcription/public_boundary_process.py index c0169b29..e62db3b4 100644 --- a/ser/_internal/transcription/public_boundary_process.py +++ b/ser/_internal/transcription/public_boundary_process.py @@ -6,16 +6,18 @@ from collections.abc import Callable from multiprocessing.connection import Connection from multiprocessing.process import BaseProcess -from typing import Never, Protocol, TypeVar, cast +from typing import TYPE_CHECKING, Never, Protocol, TypeVar, cast from ser.config import AppConfig from ser.domain import TranscriptWord from ser.profiles import TranscriptionBackendId -from ser.transcript.backends import BackendRuntimeRequest from .process_isolation import WorkerMessage from .process_worker import TranscriptionProcessPayload +if TYPE_CHECKING: + from ser.transcript.backends.base import BackendRuntimeRequest + class _ProcessIsolatedProfile(Protocol): """Minimal profile contract for public-boundary worker execution.""" diff --git a/ser/_internal/transcription/public_boundary_runtime.py b/ser/_internal/transcription/public_boundary_runtime.py index 7a690568..06e659e0 100644 --- a/ser/_internal/transcription/public_boundary_runtime.py +++ b/ser/_internal/transcription/public_boundary_runtime.py @@ -4,12 +4,14 @@ import logging from collections.abc import Callable -from typing import cast +from typing import TYPE_CHECKING, cast from ser.config import AppConfig from ser.domain import TranscriptWord from ser.profiles import TranscriptionBackendId -from ser.transcript.backends import BackendRuntimeRequest, CompatibilityReport + +if TYPE_CHECKING: + from ser.transcript.backends.base import BackendRuntimeRequest, CompatibilityReport type _ResolveTranscriptionProfileImpl = Callable[..., object] type _ResolveBackendId = Callable[..., TranscriptionBackendId] diff --git a/ser/_internal/transcription/public_boundary_support.py b/ser/_internal/transcription/public_boundary_support.py index 75ba5227..dfd2c169 100644 --- a/ser/_internal/transcription/public_boundary_support.py +++ b/ser/_internal/transcription/public_boundary_support.py @@ -5,7 +5,7 @@ import logging import multiprocessing as mp from collections.abc import Callable -from typing import Literal, Never, Protocol, cast +from typing import TYPE_CHECKING, Literal, Never, Protocol, cast from ser._internal.transcription.compatibility import ( check_adapter_compatibility as _check_adapter_compatibility_impl, @@ -113,16 +113,15 @@ from ser.config import AppConfig from ser.domain import TranscriptWord from ser.profiles import TranscriptionBackendId -from ser.transcript.backends import ( - BackendRuntimeRequest, - CompatibilityReport, - resolve_transcription_backend_adapter, -) +from ser.transcript.backends.factory import resolve_transcription_backend_adapter from ser.transcript.runtime_policy import ( DEFAULT_MPS_LOW_MEMORY_THRESHOLD_GB, resolve_transcription_runtime_policy, ) +if TYPE_CHECKING: + from ser.transcript.backends.base import BackendRuntimeRequest, CompatibilityReport + _CompatibilityIssueKind = Literal["noise", "operational"] diff --git a/ser/_internal/transcription/runtime_profile.py b/ser/_internal/transcription/runtime_profile.py index 4bb0f739..16ffa2e0 100644 --- a/ser/_internal/transcription/runtime_profile.py +++ b/ser/_internal/transcription/runtime_profile.py @@ -3,11 +3,13 @@ from __future__ import annotations from collections.abc import Callable -from typing import Protocol, cast +from typing import TYPE_CHECKING, Protocol, cast from ser.config import AppConfig from ser.profiles import TranscriptionBackendId -from ser.transcript.backends import BackendRuntimeRequest + +if TYPE_CHECKING: + from ser.transcript.backends.base import BackendRuntimeRequest class _RuntimeProfile(Protocol): @@ -119,6 +121,8 @@ def runtime_request_from_profile( default_mps_low_memory_threshold_gb: float, ) -> BackendRuntimeRequest: """Builds one backend runtime request from transcription profile settings.""" + from ser.transcript.backends.base import BackendRuntimeRequest + torch_runtime = getattr(settings, "torch_runtime", None) transcription_settings = getattr(settings, "transcription", None) requested_device = getattr(torch_runtime, "device", "cpu") diff --git a/tests/suites/integration/docs/test_architecture_docs.py b/tests/suites/integration/docs/test_architecture_docs.py index 07cb7664..68bbef7a 100644 --- a/tests/suites/integration/docs/test_architecture_docs.py +++ b/tests/suites/integration/docs/test_architecture_docs.py @@ -13,11 +13,6 @@ r"https://github\.com/jsugg/ser/(?:blob|tree)/main/(docs/[A-Za-z0-9_./-]+)" ) _ARCHITECTURE_RELATIVE_LINK_PATTERN = re.compile(r"\(([^)]+)\)") -_ARCHITECTURE_COUNT_PATTERN = re.compile( - r"- (?P