Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions ser/_internal/transcription/compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@

import logging
from collections.abc import Callable
from typing import Literal, Protocol, TypeVar
from typing import TYPE_CHECKING, Literal, Protocol, TypeVar

from ser.config import AppConfig
from ser.profiles import TranscriptionBackendId
from ser.transcript.backends import (
BackendRuntimeRequest,
CompatibilityIssueImpact,
CompatibilityReport,
)

if TYPE_CHECKING:
from ser.transcript.backends.base import (
BackendRuntimeRequest,
CompatibilityIssueImpact,
CompatibilityReport,
)

type _CompatibilityIssueKind = Literal["noise", "operational"]
type _EmittedIssueKeySet = set[tuple[str, str, str]]
Expand Down
97 changes: 80 additions & 17 deletions ser/_internal/transcription/process_isolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import logging
from collections.abc import Callable
from numbers import Real
from pathlib import Path
from types import ModuleType
from typing import Literal, Never, Protocol, TypeVar, cast
from typing import TYPE_CHECKING, Literal, Never, Protocol, TypeVar, cast

from ser.config import AppConfig
from ser.profiles import TranscriptionBackendId
Expand All @@ -20,11 +21,14 @@
log_phase_failed,
log_phase_started,
)
from ser.transcript.backends import BackendRuntimeRequest

if TYPE_CHECKING:
from ser.transcript.backends.base import BackendRuntimeRequest

type _WorkerPhase = Literal["setup_complete", "model_loaded"]
type WorkerPhaseMessage = tuple[Literal["phase"], _WorkerPhase]
type WorkerSuccessMessage = tuple[Literal["ok"], list[tuple[str, float, float]]]
type SerializedTranscriptWord = tuple[str, float, float]
type WorkerSuccessMessage = tuple[Literal["ok"], list[SerializedTranscriptWord]]
type WorkerErrorMessage = tuple[Literal["err"], str, str, str]
type WorkerMessage = WorkerPhaseMessage | WorkerSuccessMessage | WorkerErrorMessage

Expand Down Expand Up @@ -247,6 +251,71 @@ def transcribe(
type _AdapterResolver = Callable[[TranscriptionBackendId], object]


def _is_real_timestamp(value: object) -> bool:
"""Returns whether one worker timestamp payload is numeric and non-bool."""
return isinstance(value, Real) and not isinstance(value, bool)


def _serialize_transcript_words(
    transcript_words: list[object],
) -> list[SerializedTranscriptWord]:
    """Validates backend transcript words and packs them into plain tuples.

    Each word must expose a string token plus numeric, non-bool start/end
    timestamps; any violation raises ``TypeError`` so the worker fails
    loudly instead of shipping a corrupt payload over the pipe.
    """
    payload: list[SerializedTranscriptWord] = []
    for position, candidate in enumerate(transcript_words):
        entry = cast(_TranscriptWordLike, candidate)
        token = entry.word
        if not isinstance(token, str):
            raise TypeError(
                f"Transcription worker produced a non-string token at index {position}."
            )
        begin, finish = entry.start_seconds, entry.end_seconds
        if not (_is_real_timestamp(begin) and _is_real_timestamp(finish)):
            raise TypeError(
                f"Transcription worker produced non-numeric timestamps at index {position}."
            )
        payload.append((token, float(begin), float(finish)))
    return payload


def _deserialize_transcript_words(
    serialized_words: object,
    *,
    transcript_word_factory: Callable[[str, float, float], _TTranscriptWord],
    error_factory: _ErrorFactory,
) -> list[_TTranscriptWord]:
    """Rebuilds domain transcript words from the worker's tuple payload.

    The payload crosses a process boundary, so every level is re-validated:
    the outer value must be a list, and each item a 3-tuple of a string
    token plus two numeric, non-bool timestamps. Violations are reported
    through ``error_factory`` so the caller controls the exception type.
    """
    if not isinstance(serialized_words, list):
        raise error_factory("Transcription worker returned malformed transcript payload.")
    rebuilt: list[_TTranscriptWord] = []
    for position, item in enumerate(serialized_words):
        if not (isinstance(item, tuple) and len(item) == 3):
            raise error_factory(
                f"Transcription worker returned malformed transcript payload at index {position}."
            )
        token, begin, finish = item
        if not isinstance(token, str):
            raise error_factory(
                f"Transcription worker returned non-string transcript token at index {position}."
            )
        if not (_is_real_timestamp(begin) and _is_real_timestamp(finish)):
            raise error_factory(
                "Transcription worker returned non-numeric transcript timestamps "
                f"at index {position}."
            )
        rebuilt.append(transcript_word_factory(token, float(begin), float(finish)))
    return rebuilt


def should_use_process_isolated_path(profile: _ProcessIsolatedProfile) -> bool:
    """Reports whether this transcription profile must run in an isolated worker.

    Only the faster-whisper backend is executed behind process isolation.
    """
    backend = profile.backend_id
    return backend == "faster_whisper"
Expand All @@ -260,6 +329,8 @@ def runtime_request_for_isolated_faster_whisper(
logger: logging.Logger,
) -> BackendRuntimeRequest:
"""Builds one faster-whisper runtime request without importing torch in worker."""
from ser.transcript.backends.base import BackendRuntimeRequest

if profile.backend_id != "faster_whisper":
raise error_factory(
"Process-isolated runtime request only supports faster-whisper backend."
Expand Down Expand Up @@ -388,14 +459,7 @@ def transcription_worker_entry(
language=payload.language,
settings=settings,
)
serialized_words = [
(
cast(_TranscriptWordLike, word).word,
cast(_TranscriptWordLike, word).start_seconds,
cast(_TranscriptWordLike, word).end_seconds,
)
for word in transcript_words
]
serialized_words = _serialize_transcript_words(transcript_words)
connection.send(("ok", serialized_words))
except BaseException as err:
connection.send(("err", stage, type(err).__name__, str(err)))
Expand Down Expand Up @@ -488,9 +552,7 @@ def run_faster_whisper_process_isolated(
isinstance(completion_message, tuple)
and len(completion_message) == 2
and completion_message[0] == "ok"
and isinstance(completion_message[1], list)
):
serialized_words = completion_message[1]
if transcription_started_at is None:
raise error_factory(
"Transcription worker completed transcription before phase timer start."
Expand All @@ -500,10 +562,11 @@ def run_faster_whisper_process_isolated(
phase_name=PHASE_TRANSCRIPTION,
started_at=transcription_started_at,
)
return [
transcript_word_factory(word, start_seconds, end_seconds)
for word, start_seconds, end_seconds in serialized_words
]
return _deserialize_transcript_words(
completion_message[1],
transcript_word_factory=transcript_word_factory,
error_factory=error_factory,
)
raise_worker_error_fn(completion_message)
except Exception:
if transcription_started_at is not None:
Expand Down
6 changes: 4 additions & 2 deletions ser/_internal/transcription/process_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from dataclasses import dataclass
from pathlib import Path
from types import ModuleType
from typing import Protocol
from typing import TYPE_CHECKING, Protocol

from ser.config import AppConfig
from ser.profiles import TranscriptionBackendId
from ser.transcript.backends import BackendRuntimeRequest

if TYPE_CHECKING:
from ser.transcript.backends.base import BackendRuntimeRequest


class ProcessIsolatedProfileLike(Protocol):
Expand Down
6 changes: 4 additions & 2 deletions ser/_internal/transcription/public_boundary_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@
from collections.abc import Callable
from multiprocessing.connection import Connection
from multiprocessing.process import BaseProcess
from typing import Never, Protocol, TypeVar, cast
from typing import TYPE_CHECKING, Never, Protocol, TypeVar, cast

from ser.config import AppConfig
from ser.domain import TranscriptWord
from ser.profiles import TranscriptionBackendId
from ser.transcript.backends import BackendRuntimeRequest

from .process_isolation import WorkerMessage
from .process_worker import TranscriptionProcessPayload

if TYPE_CHECKING:
from ser.transcript.backends.base import BackendRuntimeRequest


class _ProcessIsolatedProfile(Protocol):
"""Minimal profile contract for public-boundary worker execution."""
Expand Down
6 changes: 4 additions & 2 deletions ser/_internal/transcription/public_boundary_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

import logging
from collections.abc import Callable
from typing import cast
from typing import TYPE_CHECKING, cast

from ser.config import AppConfig
from ser.domain import TranscriptWord
from ser.profiles import TranscriptionBackendId
from ser.transcript.backends import BackendRuntimeRequest, CompatibilityReport

if TYPE_CHECKING:
from ser.transcript.backends.base import BackendRuntimeRequest, CompatibilityReport

type _ResolveTranscriptionProfileImpl = Callable[..., object]
type _ResolveBackendId = Callable[..., TranscriptionBackendId]
Expand Down
11 changes: 5 additions & 6 deletions ser/_internal/transcription/public_boundary_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import multiprocessing as mp
from collections.abc import Callable
from typing import Literal, Never, Protocol, cast
from typing import TYPE_CHECKING, Literal, Never, Protocol, cast

from ser._internal.transcription.compatibility import (
check_adapter_compatibility as _check_adapter_compatibility_impl,
Expand Down Expand Up @@ -113,16 +113,15 @@
from ser.config import AppConfig
from ser.domain import TranscriptWord
from ser.profiles import TranscriptionBackendId
from ser.transcript.backends import (
BackendRuntimeRequest,
CompatibilityReport,
resolve_transcription_backend_adapter,
)
from ser.transcript.backends.factory import resolve_transcription_backend_adapter
from ser.transcript.runtime_policy import (
DEFAULT_MPS_LOW_MEMORY_THRESHOLD_GB,
resolve_transcription_runtime_policy,
)

if TYPE_CHECKING:
from ser.transcript.backends.base import BackendRuntimeRequest, CompatibilityReport

_CompatibilityIssueKind = Literal["noise", "operational"]


Expand Down
8 changes: 6 additions & 2 deletions ser/_internal/transcription/runtime_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from __future__ import annotations

from collections.abc import Callable
from typing import Protocol, cast
from typing import TYPE_CHECKING, Protocol, cast

from ser.config import AppConfig
from ser.profiles import TranscriptionBackendId
from ser.transcript.backends import BackendRuntimeRequest

if TYPE_CHECKING:
from ser.transcript.backends.base import BackendRuntimeRequest


class _RuntimeProfile(Protocol):
Expand Down Expand Up @@ -119,6 +121,8 @@ def runtime_request_from_profile(
default_mps_low_memory_threshold_gb: float,
) -> BackendRuntimeRequest:
"""Builds one backend runtime request from transcription profile settings."""
from ser.transcript.backends.base import BackendRuntimeRequest

torch_runtime = getattr(settings, "torch_runtime", None)
transcription_settings = getattr(settings, "transcription", None)
requested_device = getattr(torch_runtime, "device", "cpu")
Expand Down
33 changes: 0 additions & 33 deletions tests/suites/integration/docs/test_architecture_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@
r"https://github\.com/jsugg/ser/(?:blob|tree)/main/(docs/[A-Za-z0-9_./-]+)"
)
_ARCHITECTURE_RELATIVE_LINK_PATTERN = re.compile(r"\(([^)]+)\)")
_ARCHITECTURE_COUNT_PATTERN = re.compile(
r"- (?P<label>Source modules under `ser/`|Test modules under `tests/`|"
r"Public modules outside `_internal/`|Internal owner/helper modules under `_internal/`): "
r"`(?P<count>\d+)`"
)
_REMOVED_TRACKER_REFERENCES = (
"ser_refactor_implementation_journal.md",
"ser_refactor_status.md",
Expand All @@ -29,22 +24,6 @@ def _repo_root() -> Path:
return Path(__file__).resolve().parents[4]


def _expected_codebase_counts(root: Path) -> dict[str, int]:
"""Builds the current architecture snapshot counts from the working tree."""
source_files = list((root / "ser").rglob("*.py"))
test_files = list((root / "tests").rglob("*.py"))
return {
"Source modules under `ser/`": len(source_files),
"Test modules under `tests/`": len(test_files),
"Public modules outside `_internal/`": sum(
1 for path in source_files if "_internal" not in path.parts
),
"Internal owner/helper modules under `_internal/`": sum(
1 for path in source_files if "_internal" in path.parts
),
}


def test_readme_architecture_links_resolve_to_existing_docs() -> None:
"""README architecture links should point at docs artifacts that exist in-tree."""
root = _repo_root()
Expand All @@ -70,18 +49,6 @@ def test_architecture_index_links_resolve_to_existing_docs() -> None:
assert all(target.is_file() for target in resolved_targets)


def test_codebase_architecture_snapshot_counts_match_current_tree() -> None:
    """The published snapshot counts must agree with a fresh scan of the tree."""
    root = _repo_root()
    snapshot_text = (root / "docs" / "codebase-architecture.md").read_text(encoding="utf-8")
    reported_counts: dict[str, int] = {}
    for match in _ARCHITECTURE_COUNT_PATTERN.finditer(snapshot_text):
        reported_counts[match.group("label")] = int(match.group("count"))

    assert reported_counts == _expected_codebase_counts(root)


def test_compatibility_matrix_does_not_reference_removed_tracker_docs() -> None:
"""Compatibility matrix should not point contributors at removed tracker files."""
root = _repo_root()
Expand Down
28 changes: 18 additions & 10 deletions tests/suites/integration/test_process_isolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from multiprocessing.connection import Connection
from multiprocessing.process import BaseProcess
from pathlib import Path
from typing import cast
from types import SimpleNamespace
from typing import TYPE_CHECKING, cast

import pytest
from tests.utils.helpers import process_spawn_support
Expand All @@ -20,9 +21,13 @@
from ser._internal.transcription import public_boundary_support as transcription_boundary_support
from ser.config import AppConfig, reload_settings
from ser.domain import TranscriptWord
from ser.transcript.backends import BackendRuntimeRequest
from ser.transcript.transcript_extractor import TranscriptionProfile

if TYPE_CHECKING:
from ser.transcript.backends.base import BackendRuntimeRequest
else:
BackendRuntimeRequest = object

pytestmark = [
pytest.mark.integration,
pytest.mark.process_isolation,
Expand Down Expand Up @@ -136,14 +141,17 @@ def _parse_worker_completion_message(message: tuple[object, ...]) -> str:


def _transcription_runtime_request() -> BackendRuntimeRequest:
return BackendRuntimeRequest(
model_name="tiny",
use_demucs=False,
use_vad=True,
device_spec="cpu",
device_type="cpu",
precision_candidates=("float32",),
memory_tier="not_applicable",
return cast(
BackendRuntimeRequest,
SimpleNamespace(
model_name="tiny",
use_demucs=False,
use_vad=True,
device_spec="cpu",
device_type="cpu",
precision_candidates=("float32",),
memory_tier="not_applicable",
),
)


Expand Down
Loading