From 8a305b42b156cb7a1a43cddf2ccca9d2909592b4 Mon Sep 17 00:00:00 2001 From: Juan Pedro Sugg Date: Thu, 6 Nov 2025 20:22:47 -0300 Subject: [PATCH] Add regression tests and uv-based tooling --- .github/workflows/run-tests.yml | 26 +++ Pipfile | 24 --- pyproject.toml | 91 +++++++++ requirements.txt | 64 ------ ser/__main__.py | 75 +++++++- ser/utils/subtitles.py | 141 ++++++++++---- tests/conftest.py | 182 ++++++++++++++++++ tests/suites/e2e/.gitkeep | 1 + tests/suites/e2e/test_cli.py | 79 ++++++++ tests/suites/integration/.gitkeep | 1 + tests/suites/integration/test_timeline_csv.py | 20 ++ tests/suites/unit/.gitkeep | 1 + tests/suites/unit/test_audio_utils.py | 56 ++++++ tests/suites/unit/test_emotion_model.py | 18 ++ tests/suites/unit/test_subtitles.py | 89 +++++++++ tests/suites/unit/test_timeline_utils.py | 16 ++ 16 files changed, 753 insertions(+), 131 deletions(-) create mode 100644 .github/workflows/run-tests.yml delete mode 100644 Pipfile create mode 100644 pyproject.toml delete mode 100644 requirements.txt create mode 100644 tests/conftest.py create mode 100644 tests/suites/e2e/.gitkeep create mode 100644 tests/suites/e2e/test_cli.py create mode 100644 tests/suites/integration/.gitkeep create mode 100644 tests/suites/integration/test_timeline_csv.py create mode 100644 tests/suites/unit/.gitkeep create mode 100644 tests/suites/unit/test_audio_utils.py create mode 100644 tests/suites/unit/test_emotion_model.py create mode 100644 tests/suites/unit/test_subtitles.py create mode 100644 tests/suites/unit/test_timeline_utils.py diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml new file mode 100644 index 0000000..358dff5 --- /dev/null +++ b/.github/workflows/run-tests.yml @@ -0,0 +1,26 @@ +name: Run tests + +on: + pull_request: + push: + branches: + - main + +jobs: + tests: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - uses: astral-sh/setup-uv@v2 + - name: Install dependencies + run: uv sync --dev + - name: Run tests + run: uv run pytest diff --git a/Pipfile b/Pipfile deleted file mode 100644 index f98c2f4..0000000 --- a/Pipfile +++ /dev/null @@ -1,24 +0,0 @@ -[[source]] -url = "https://pypi.org/simple" -verify_ssl = true -name = "pypi" - -[packages] -numpy = "1.26.2" -ffmpeg-python = "0.2.0" -scikit-learn = "1.3.2" -soundfile = "0.12.1" -tqdm = "4.66.1" -openai-whisper = "20231106" -stable-ts = "2.13.3" -typing-extensions = "4.8.0" -demucs = "*" -librosa = "*" -colored = "*" -psutil = "*" -python-dotenv = "*" - -[dev-packages] - -[requires] -python_version = "3.10" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f1631e7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,91 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "ser" +version = "0.1.0" +description = "Speech Emotion Recognition toolkit" +readme = "README.md" +requires-python = ">=3.9" +license = {text = "MIT"} +authors = [{name = "Juan Sugg", email = "juanpedrosugg@gmail.com"}] +dependencies = [ + "antlr4-python3-runtime==4.9.3", + "audioread==3.0.1; python_version >= '3.6'", + "certifi==2024.2.2; python_version >= '3.6'", + "cffi==1.16.0; python_version >= '3.8'", + "charset-normalizer==3.3.2; python_full_version >= '3.7.0'", + "cloudpickle==3.0.0; python_version >= '3.8'", + "colored==2.2.4; python_version >= '3.9'", + "decorator==5.1.1; python_version >= '3.5'", + "demucs==4.0.1; python_full_version >= '3.8.0'", + "dora-search==0.1.12; python_full_version >= '3.7.0'", + "einops==0.8.0; python_version >= '3.8'", + "ffmpeg-python==0.2.0", + "filelock==3.14.0; python_version >= '3.8'", + "fsspec==2024.5.0; python_version >= '3.8'", + "future==1.0.0; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2'", + "huggingface-hub==0.23.1; python_full_version >= '3.8.0'", + "idna==3.7; python_version >= '3.5'", + "jinja2==3.1.4; python_version >= '3.7'", + "joblib==1.4.2; python_version >= '3.8'", + "julius==0.2.7; python_full_version >= '3.6.0'", + "lameenc==1.7.0", + "lazy-loader==0.4; python_version >= '3.7'", + "librosa==0.10.2.post1; python_version >= '3.7'", + "llvmlite==0.42.0; python_version >= '3.9'", + "markupsafe==2.1.5; python_version >= '3.7'", + "more-itertools==10.2.0; python_version >= '3.8'", + "mpmath==1.3.0", + "msgpack==1.0.8; python_version >= '3.8'", + "networkx==3.3; python_version >= '3.10'", + "numba==0.59.1; python_version >= '3.9'", + "numpy==1.26.2; python_version >= '3.9'", + "omegaconf==2.3.0; python_version >= '3.6'", + "openai-whisper==20231106; python_version >= '3.8'", + "openunmix==1.3.0; python_version >= '3.9'", + "packaging==24.0; python_version >= '3.7'", + "platformdirs==4.2.2; python_version >= '3.8'", + "pooch==1.8.1; python_version >= '3.7'", + "psutil==5.9.8; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5'", + "pycparser==2.22; python_version >= '3.8'", + "python-dotenv==1.0.1; python_version >= '3.8'", + "pyyaml==6.0.1; python_version >= '3.6'", + "regex==2024.5.15; python_version >= '3.8'", + "requests==2.32.2; python_version >= '3.8'", + "retrying==1.3.4", + "safetensors==0.4.3; python_version >= '3.7'", + "scikit-learn==1.3.2; python_version >= '3.8'", + "scipy==1.13.0; python_version >= '3.9'", + "six==1.16.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2'", + "soundfile==0.12.1", + "soxr==0.3.7; python_version >= '3.6'", + "stable-ts==2.13.3; python_version >= '3.8'", + "submitit==1.5.1; python_version >= '3.8'", + "sympy==1.12; python_version >= '3.8'", + "threadpoolctl==3.5.0; python_version >= '3.8'", + "tiktoken==0.7.0; python_version >= '3.8'", + "tokenizers==0.19.1; python_version >= '3.7'", + "torch==2.2.2; python_full_version >= '3.8.0'", + "torchaudio==2.2.2", + "tqdm==4.66.1; python_version >= '3.7'", + "transformers==4.41.1; python_full_version >= '3.8.0'", + "treetable==0.2.5; python_full_version >= '3.6.0'", + "typing-extensions==4.8.0; python_version >= '3.8'", + "urllib3==2.2.1; python_version >= '3.8'", +] + +[project.scripts] +ser = "ser.__main__:main" + +[tool.uv] +dev-dependencies = [ + "pytest>=8.2", + "pytest-cov>=5.0", +] + +[tool.pytest.ini_options] +addopts = "-ra" +testpaths = ["tests"] + diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 58e3864..0000000 --- a/requirements.txt +++ /dev/null @@ -1,64 +0,0 @@ --i https://pypi.org/simple -antlr4-python3-runtime==4.9.3 -audioread==3.0.1; python_version >= '3.6' -certifi==2024.2.2; python_version >= '3.6' -cffi==1.16.0; python_version >= '3.8' -charset-normalizer==3.3.2; python_full_version >= '3.7.0' -cloudpickle==3.0.0; python_version >= '3.8' -colored==2.2.4; python_version >= '3.9' -decorator==5.1.1; python_version >= '3.5' -demucs==4.0.1; python_full_version >= '3.8.0' -dora-search==0.1.12; python_full_version >= '3.7.0' -einops==0.8.0; python_version >= '3.8' -ffmpeg-python==0.2.0 -filelock==3.14.0; python_version >= '3.8' -fsspec==2024.5.0; python_version >= '3.8' -future==1.0.0; python_version >= '2.6' and python_version not in '3.0, 3.1, 3.2' -huggingface-hub==0.23.1; python_full_version >= '3.8.0' -idna==3.7; python_version >= '3.5' -jinja2==3.1.4; python_version >= '3.7' -joblib==1.4.2; python_version >= '3.8' -julius==0.2.7; python_full_version >= '3.6.0' -lameenc==1.7.0 -lazy-loader==0.4; python_version >= '3.7' -librosa==0.10.2.post1; python_version >= '3.7' -llvmlite==0.42.0; python_version >= '3.9' -markupsafe==2.1.5; python_version >= '3.7' -more-itertools==10.2.0; python_version >= '3.8' -mpmath==1.3.0 -msgpack==1.0.8; python_version >= '3.8' -networkx==3.3; python_version >= '3.10' -numba==0.59.1; python_version >= '3.9' -numpy==1.26.2; python_version >= '3.9' -omegaconf==2.3.0; python_version >= '3.6' -openai-whisper==20231106; python_version >= '3.8' -openunmix==1.3.0; python_version >= '3.9' -packaging==24.0; python_version >= '3.7' -platformdirs==4.2.2; python_version >= '3.8' -pooch==1.8.1; python_version >= '3.7' -psutil==5.9.8; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2, 3.3, 3.4, 3.5' -pycparser==2.22; python_version >= '3.8' -python-dotenv==1.0.1; python_version >= '3.8' -pyyaml==6.0.1; python_version >= '3.6' -regex==2024.5.15; python_version >= '3.8' -requests==2.32.2; python_version >= '3.8' -retrying==1.3.4 -safetensors==0.4.3; python_version >= '3.7' -scikit-learn==1.3.2; python_version >= '3.8' -scipy==1.13.0; python_version >= '3.9' -six==1.16.0; python_version >= '2.7' and python_version not in '3.0, 3.1, 3.2' -soundfile==0.12.1 -soxr==0.3.7; python_version >= '3.6' -stable-ts==2.13.3; python_version >= '3.8' -submitit==1.5.1; python_version >= '3.8' -sympy==1.12; python_version >= '3.8' -threadpoolctl==3.5.0; python_version >= '3.8' -tiktoken==0.7.0; python_version >= '3.8' -tokenizers==0.19.1; python_version >= '3.7' -torch==2.2.2; python_full_version >= '3.8.0' -torchaudio==2.2.2 -tqdm==4.66.1; python_version >= '3.7' -transformers==4.41.1; python_full_version >= '3.8.0' -treetable==0.2.5; python_full_version >= '3.6.0' -typing-extensions==4.8.0; python_version >= '3.8' -urllib3==2.2.1; python_version >= '3.8' diff --git a/ser/__main__.py b/ser/__main__.py index 8ce7cd0..cc74edd 100644 --- a/ser/__main__.py +++ b/ser/__main__.py @@ -18,10 +18,10 @@ """ import argparse +import logging import sys import time -import logging -from typing import List, Tuple +from pathlib import Path from ser.models.emotion_model import predict_emotions, train_model from ser.transcript import extract_transcript @@ -31,6 +31,7 @@ print_timeline, save_timeline_to_csv, ) +from ser.utils.subtitles import SubtitleGenerator, FORMATTERS, timeline_to_subtitles from ser.config import Config @@ -65,6 +66,22 @@ def main() -> None: action="store_true", help="Save the transcript to a CSV file", ) + parser.add_argument( + "--subtitle-format", + choices=tuple(FORMATTERS.keys()), + help=( + "Export the generated timeline as subtitles in the chosen format. " + "If omitted, the format is inferred from --subtitle-output when possible." + ), + ) + parser.add_argument( + "--subtitle-output", + type=str, + help=( + "File path for the exported subtitle file. The format is inferred from " + "the extension when --subtitle-format is not provided." + ), + ) args: argparse.Namespace = parser.parse_args() if args.train: @@ -82,13 +99,56 @@ def main() -> None: logger.info(msg="Starting emotion prediction...") start_time = time.time() - emotions: List[Tuple[str, float, float]] = predict_emotions(args.file) - transcript: List[Tuple[str, float, float]] = extract_transcript( + emotions: list[tuple[str, float, float]] = predict_emotions(args.file) + transcript: list[tuple[str, float, float]] = extract_transcript( args.file, args.language ) - timeline: list = build_timeline(transcript, emotions) + timeline: list[tuple[float, str, str]] = build_timeline(transcript, emotions) print_timeline(timeline) + if args.subtitle_format or args.subtitle_output: + if not args.subtitle_output: + logger.error( + msg="--subtitle-output is required to export subtitles.", + ) + sys.exit(1) + + subtitle_format: str | None = args.subtitle_format + if not subtitle_format: + subtitle_format = _infer_subtitle_format(args.subtitle_output) + if not subtitle_format: + logger.error( + "Unable to infer subtitle format from %s. Provide --subtitle-format.", + args.subtitle_output, + ) + sys.exit(1) + else: + inferred_format: str | None = _infer_subtitle_format(args.subtitle_output) + if inferred_format and inferred_format != subtitle_format: + logger.info( + "Using subtitle format %s (overriding inferred format %s from output path)", + subtitle_format, + inferred_format, + ) + + subtitles: list[tuple[float, float, str, str]] = timeline_to_subtitles(timeline) + if not subtitles: + logger.warning("Timeline did not produce any subtitle entries to export.") + else: + try: + generator = SubtitleGenerator(FORMATTERS[subtitle_format]) + generator.generate_file(subtitles, args.subtitle_output) + logger.info( + "Subtitle file exported to %s", + args.subtitle_output, + ) + except Exception as err: + logger.error( + msg=f"Failed to export subtitles: {err}", + exc_info=True, + ) + sys.exit(1) + if args.save_transcript: csv_file_name: str = save_timeline_to_csv(timeline, args.file) logger.info(msg=f"Timeline saved to {csv_file_name}") @@ -98,5 +158,10 @@ def main() -> None: ) +def _infer_subtitle_format(output_path: str) -> str | None: + suffix: str = Path(output_path).suffix.lower().lstrip(".") + return suffix if suffix in FORMATTERS else None + + if __name__ == "__main__": main() diff --git a/ser/utils/subtitles.py b/ser/utils/subtitles.py index 222cc02..d858162 100644 --- a/ser/utils/subtitles.py +++ b/ser/utils/subtitles.py @@ -1,15 +1,11 @@ import argparse import logging from abc import ABC, abstractmethod -from typing import List, Tuple from ser.utils.logger import get_logger logger: logging.Logger = get_logger(__name__) -# Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - class SubtitleFormatter(ABC): """Abstract base class for subtitle formatters.""" @@ -24,7 +20,7 @@ def generate_entry(self, index: int, start: float, end: float, text: str, emotio pass @abstractmethod - def generate_file(self, subtitles: List[Tuple[float, float, str, str]], output_file: str) -> None: + def generate_file(self, subtitles: list[tuple[float, float, str, str]], output_file: str) -> None: """Generate a subtitle file from a list of subtitles.""" pass @@ -58,19 +54,25 @@ def generate_entry(self, index: int, start: float, end: float, text: str, emotio """Generate a single ASS subtitle entry.""" start_time: str = self.format_time(start) end_time: str = self.format_time(end) - logging.debug(f"ASS Entry: Start {start_time}, End {end_time}, Text {text}, Emotion {emotion}") + logger.debug( + "ASS Entry: Start %s, End %s, Text %s, Emotion %s", + start_time, + end_time, + text, + emotion, + ) return f"Dialogue: 0,{start_time},{end_time},Default,,0,0,0,,{text} ({emotion})" - def generate_file(self, subtitles: List[Tuple[float, float, str, str]], output_file: str) -> None: + def generate_file(self, subtitles: list[tuple[float, float, str, str]], output_file: str) -> None: """Generate an ASS file from a list of subtitles.""" - logging.info(f"Generating ASS file: {output_file}") + logger.info("Generating ASS file: %s", output_file) with open(output_file, 'w', encoding='utf-8') as f: f.write(self.ASS_HEADER) for i, (start, duration, text, emotion) in enumerate(subtitles, 1): end: float = start + duration entry: str = self.generate_entry(i, start, end, text, emotion) f.write(entry + '\n') - logging.info(f"ASS file generated successfully: {output_file}") + logger.info("ASS file generated successfully: %s", output_file) class SRTFormatter(SubtitleFormatter): """Formatter for SRT subtitles.""" @@ -87,18 +89,24 @@ def generate_entry(self, index: int, start: float, end: float, text: str, emotio """Generate a single SRT subtitle entry.""" start_time: str = self.format_time(start) end_time: str = self.format_time(end) - logging.debug(f"SRT Entry: Start {start_time}, End {end_time}, Text {text}, Emotion {emotion}") + logger.debug( + "SRT Entry: Start %s, End %s, Text %s, Emotion %s", + start_time, + end_time, + text, + emotion, + ) return f"{index}\n{start_time} --> {end_time}\n{text} ({emotion})\n" - def generate_file(self, subtitles: List[Tuple[float, float, str, str]], output_file: str) -> None: + def generate_file(self, subtitles: list[tuple[float, float, str, str]], output_file: str) -> None: """Generate an SRT file from a list of subtitles.""" - logging.info(f"Generating SRT file: {output_file}") + logger.info("Generating SRT file: %s", output_file) with open(output_file, 'w', encoding='utf-8') as f: for i, (start, duration, text, emotion) in enumerate(subtitles, 1): end: float = start + duration entry: str = self.generate_entry(i, start, end, text, emotion) f.write(entry + '\n') - logging.info(f"SRT file generated successfully: {output_file}") + logger.info("SRT file generated successfully: %s", output_file) class VTTFormatter(SubtitleFormatter): """Formatter for WebVTT subtitles.""" @@ -115,19 +123,72 @@ def generate_entry(self, index: int, start: float, end: float, text: str, emotio """Generate a single WebVTT subtitle entry.""" start_time: str = self.format_time(start) end_time: str = self.format_time(end) - logging.debug(f"VTT Entry: Start {start_time}, End {end_time}, Text {text}, Emotion {emotion}") + logger.debug( + "VTT Entry: Start %s, End %s, Text %s, Emotion %s", + start_time, + end_time, + text, + emotion, + ) return f"{start_time} --> {end_time}\n{text} ({emotion})\n" - def generate_file(self, subtitles: List[Tuple[float, float, str, str]], output_file: str) -> None: + def generate_file(self, subtitles: list[tuple[float, float, str, str]], output_file: str) -> None: """Generate a WebVTT file from a list of subtitles.""" - logging.info(f"Generating WebVTT file: {output_file}") + logger.info("Generating WebVTT file: %s", output_file) with open(output_file, 'w', encoding='utf-8') as f: f.write("WEBVTT\n\n") for i, (start, duration, text, emotion) in enumerate(subtitles, 1): end: float = start + duration entry: str = self.generate_entry(i, start, end, text, emotion) f.write(entry + '\n') - logging.info(f"WebVTT file generated successfully: {output_file}") + logger.info("WebVTT file generated successfully: %s", output_file) + + +DEFAULT_SUBTITLE_DURATION = 1.0 + + +def timeline_to_subtitles( + timeline: list[tuple[float, str, str]], + default_duration: float = DEFAULT_SUBTITLE_DURATION, +) -> list[tuple[float, float, str, str]]: + """Convert a timeline of transcript/emotion entries to subtitle tuples.""" + if not timeline: + logger.debug("Received empty timeline for subtitle conversion") + return [] + + sorted_timeline: list[tuple[float, str, str]] = sorted( + timeline, + key=lambda entry: entry[0], + ) + + subtitles: list[tuple[float, float, str, str]] = [] + for index, (timestamp, emotion, text) in enumerate(sorted_timeline): + cleaned_text: str = text.strip() + if not cleaned_text: + continue + + next_timestamp: float | None = None + if index + 1 < len(sorted_timeline): + next_timestamp = float(sorted_timeline[index + 1][0]) + + if next_timestamp is not None: + duration: float = max(next_timestamp - float(timestamp), 0.0) + if duration == 0.0: + duration = default_duration + else: + duration = default_duration + + subtitles.append((float(timestamp), duration, cleaned_text, emotion)) + logger.debug( + "Subtitle entry prepared: Start %s, Duration %s, Text %s, Emotion %s", + timestamp, + duration, + cleaned_text, + emotion, + ) + + return subtitles + class SubtitleGenerator: """Main class to generate subtitle files in different formats.""" @@ -135,12 +196,16 @@ class SubtitleGenerator: def __init__(self, formatter: SubtitleFormatter) -> None: self.formatter: SubtitleFormatter = formatter - def generate_file(self, subtitles: List[Tuple[float, float, str, str]], output_file: str) -> None: + def generate_file(self, subtitles: list[tuple[float, float, str, str]], output_file: str) -> None: """Generate a subtitle file using the provided formatter.""" - try: - self.formatter.generate_file(subtitles, output_file) - except Exception as e: - logging.error(f"Failed to generate file: {output_file}, error: {e}") + self.formatter.generate_file(subtitles, output_file) + + +FORMATTERS: dict[str, SubtitleFormatter] = { + "ass": ASSFormatter(), + "srt": SRTFormatter(), + "vtt": VTTFormatter(), +} def parse_arguments() -> argparse.Namespace: """Parse command line arguments.""" @@ -149,9 +214,9 @@ def parse_arguments() -> argparse.Namespace: formatter_class=argparse.RawTextHelpFormatter ) parser.add_argument( - "format", - type=str, - choices=["ass", "srt", "vtt"], + "format", + type=str, + choices=tuple(FORMATTERS.keys()), help="Output subtitle format. Choices are:\n" " ass: Generate Advanced SubStation Alpha (.ass) file\n" " srt: Generate SubRip Subtitle (.srt) file\n" @@ -170,33 +235,33 @@ def parse_arguments() -> argparse.Namespace: ) return parser.parse_args() -def parse_subtitles(subtitles_str: str) -> List[Tuple[float, float, str, str]]: +def parse_subtitles(subtitles_str: str) -> list[tuple[float, float, str, str]]: """Parse the input string of subtitles into a list of tuples.""" - subtitles: List[Tuple[float, float, str, str]] = [] + subtitles: list[tuple[float, float, str, str]] = [] for subtitle in subtitles_str.split(';'): try: start_str, duration_str, text, emotion = subtitle.split(',') start: float = float(start_str) duration: float = float(duration_str) subtitles.append((start, duration, text, emotion)) - logging.debug(f"Parsed subtitle: Start {start}, Duration {duration}, Text {text}, Emotion {emotion}") + logger.debug( + "Parsed subtitle: Start %s, Duration %s, Text %s, Emotion %s", + start, + duration, + text, + emotion, + ) except ValueError: - logging.error(f"Invalid subtitle format: {subtitle}") + logger.error("Invalid subtitle format: %s", subtitle) continue return subtitles def main() -> None: """Main entry point for the CLI.""" args: argparse.Namespace = parse_arguments() - subtitles: List[Tuple[float, float, str, str]] = parse_subtitles(args.subtitles) - - formatters: Dict[str, SubtitleFormatter] = { - "ass": ASSFormatter(), - "srt": SRTFormatter(), - "vtt": VTTFormatter() - } - - formatter: SubtitleFormatter = formatters[args.format] + subtitles: list[tuple[float, float, str, str]] = parse_subtitles(args.subtitles) + + formatter: SubtitleFormatter = FORMATTERS[args.format] generator: SubtitleGenerator = SubtitleGenerator(formatter) generator.generate_file(subtitles, args.output) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..38823b4 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,182 @@ +import contextlib +import sys +import types +from pathlib import Path + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +_dummy_colored = types.ModuleType("colored") +_dummy_colored.attr = lambda name: "" +_dummy_colored.bg = lambda color: "" +_dummy_colored.fg = lambda color: "" +sys.modules.setdefault("colored", _dummy_colored) + +class _AutoHalo: + def __enter__(self): + return self + def __exit__(self, exc_type, exc, tb): + return False + +_dummy_halo = types.ModuleType("halo") +_dummy_halo.Halo = _AutoHalo +sys.modules.setdefault("halo", _dummy_halo) + +class _Array(list): + def tolist(self): + return list(self) + + def __itruediv__(self, other): + for index, value in enumerate(self): + self[index] = value / other + return self + +_dummy_numpy = types.ModuleType("numpy") +_dummy_numpy.ndarray = _Array +_dummy_numpy.float32 = float +_dummy_numpy.bool_ = bool + +def _as_array(data): + return _Array(data) if not isinstance(data, _Array) else data + +def _array(data, dtype=None): + return _Array(list(data)) + +_dummy_numpy.array = _array +def _zeros(n, dtype=None): + return _Array(0.0 for _ in range(int(n))) + +_dummy_numpy.zeros = _zeros +def _zeros_like(data): + return _Array(0.0 for _ in data) + +_dummy_numpy.zeros_like = _zeros_like +def _abs(data): + return _Array(abs(x) for x in data) + +_dummy_numpy.abs = _abs +_dummy_numpy.max = lambda data: max(data) +_dummy_numpy.isscalar = lambda obj: isinstance(obj, (int, float, str, bool)) + +sys.modules.setdefault("numpy", _dummy_numpy) + +_dummy_librosa = types.ModuleType("librosa") +_dummy_librosa.load = lambda *args, **kwargs: (_Array([]), 16000) +_dummy_librosa.get_duration = lambda y: float(len(y)) +sys.modules.setdefault("librosa", _dummy_librosa) + +class _SoundFile: + def __init__(self, *args, **kwargs): + self._data = _Array([0.0]) + self.samplerate = 16000 + def __enter__(self): + return self + def __exit__(self, exc_type, exc, tb): + return False + def read(self, dtype=None): + return _Array(self._data) + +_dummy_soundfile = types.ModuleType("soundfile") +_dummy_soundfile.SoundFile = _SoundFile +sys.modules.setdefault("soundfile", _dummy_soundfile) + +class _DummyMLP: + def __init__(self, *args, **kwargs): + pass + def fit(self, X, y): + return self + def predict(self, X): + return [0 for _ in range(len(X))] + +_dummy_sklearn_metrics = types.ModuleType("sklearn.metrics") +_dummy_sklearn_metrics.accuracy_score = lambda y_true, y_pred: 1.0 +sys.modules.setdefault("sklearn.metrics", _dummy_sklearn_metrics) + +_dummy_sklearn_nn = types.ModuleType("sklearn.neural_network") +_dummy_sklearn_nn.MLPClassifier = _DummyMLP +sys.modules.setdefault("sklearn.neural_network", _dummy_sklearn_nn) + +_dummy_sklearn_model_selection = types.ModuleType("sklearn.model_selection") +_dummy_sklearn_model_selection.train_test_split = lambda X, y, test_size=0.2, random_state=None: (X, X, y, y) +sys.modules.setdefault("sklearn.model_selection", _dummy_sklearn_model_selection) + +_dummy_sklearn = types.ModuleType("sklearn") +_dummy_sklearn.metrics = _dummy_sklearn_metrics +_dummy_sklearn.neural_network = _dummy_sklearn_nn +_dummy_sklearn.model_selection = _dummy_sklearn_model_selection +sys.modules.setdefault("sklearn", _dummy_sklearn) + +_dummy_dotenv = types.ModuleType("dotenv") +_dummy_dotenv.load_dotenv = lambda *args, **kwargs: None +sys.modules.setdefault("dotenv", _dummy_dotenv) + +_dummy_stable_whisper = types.ModuleType("stable_whisper") +_dummy_stable_whisper.load_model = lambda *args, **kwargs: object() +_dummy_stable_whisper_result = types.ModuleType("stable_whisper.result") +class _DummyWhisperResult: + segments = [] + +_dummy_stable_whisper_result.WhisperResult = _DummyWhisperResult +sys.modules.setdefault("stable_whisper", _dummy_stable_whisper) +sys.modules.setdefault("stable_whisper.result", _dummy_stable_whisper_result) + +_dummy_whisper_model = types.ModuleType("whisper.model") +class _DummyWhisper: + def transcribe(self, *args, **kwargs): + return None + +_dummy_whisper_model.Whisper = _DummyWhisper +sys.modules.setdefault("whisper.model", _dummy_whisper_model) + +import io +import logging +from typing import Sequence + +import pytest + +import ser.__main__ as ser_main + + +@pytest.fixture +def run_cli(monkeypatch): + """Run the SER CLI with a custom argv list.""" + + def _run_cli(args: Sequence[str], *, expect_exit: bool = True) -> tuple[int, str]: + argv = ["ser", *args] + monkeypatch.setattr(sys, "argv", argv) + stdout = io.StringIO() + with contextlib.redirect_stdout(stdout), contextlib.redirect_stderr(stdout): + try: + ser_main.main() + except SystemExit as exc: # pragma: no cover - exercised in tests + return exc.code, stdout.getvalue() + if expect_exit: + raise AssertionError("CLI did not exit as expected") + return 0, stdout.getvalue() + + return _run_cli + + +@pytest.fixture(autouse=True) +def _silence_halo(monkeypatch): + """Replace Halo spinners with a no-op context manager for tests.""" + + class _DummyHalo: + def __init__(self, *args, **kwargs): + self.text = kwargs.get("text") + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("ser.utils.timeline_utils.Halo", _DummyHalo, raising=False) + monkeypatch.setattr("ser.models.emotion_model.Halo", _DummyHalo, raising=False) + + +@pytest.fixture +def caplog_info(caplog): + caplog.set_level(logging.INFO) + return caplog diff --git a/tests/suites/e2e/.gitkeep b/tests/suites/e2e/.gitkeep new file mode 100644 index 0000000..ff550f2 --- /dev/null +++ b/tests/suites/e2e/.gitkeep @@ -0,0 +1 @@ +# Placeholder to retain directory structure in Git diff --git a/tests/suites/e2e/test_cli.py b/tests/suites/e2e/test_cli.py new file mode 100644 index 0000000..fe1564d --- /dev/null +++ b/tests/suites/e2e/test_cli.py @@ -0,0 +1,79 @@ +import logging +import ser.__main__ as cli + + +def _stub_processing(monkeypatch, subtitles=None): + monkeypatch.setattr(cli, "predict_emotions", lambda *_: [("happy", 0.0, 1.0)]) + monkeypatch.setattr(cli, "extract_transcript", lambda *args: [("hello", 0.0, 0.5)]) + monkeypatch.setattr(cli, "build_timeline", lambda *args: [(0.0, "happy", "hello")]) + monkeypatch.setattr(cli, "print_timeline", lambda *args: None) + + captured = {"generated": None, "formatter": None} + + def fake_timeline_to_subtitles(timeline): + return subtitles or [(0.0, 1.0, "hello", "happy")] + + class DummySubtitleGenerator: + def __init__(self, formatter): + captured["formatter"] = formatter + + def generate_file(self, subs, output_path): + captured["generated"] = (subs, output_path) + + monkeypatch.setattr(cli, "timeline_to_subtitles", fake_timeline_to_subtitles) + monkeypatch.setattr(cli, "SubtitleGenerator", DummySubtitleGenerator) + return captured + + +def test_cli_without_file_exits_with_error(run_cli, caplog): + caplog.set_level(logging.ERROR) + exit_code, _ = run_cli([]) + assert exit_code == 1 + assert any("No audio file provided" in message for message in caplog.messages) + + +def test_cli_requires_subtitle_output_when_flags_present(run_cli, monkeypatch, caplog): + caplog.set_level(logging.ERROR) + _stub_processing(monkeypatch) + exit_code, _ = run_cli(["--file", "input.wav", "--subtitle-format", "srt"]) + assert exit_code == 1 + assert any("--subtitle-output is required" in message for message in caplog.messages) + + +def test_cli_inferrs_format_from_output_extension(run_cli, monkeypatch, tmp_path, caplog): + caplog.set_level(logging.INFO) + captured = _stub_processing(monkeypatch) + output_path = tmp_path / "export.srt" + + exit_code, _ = run_cli( + ["--file", "input.wav", "--subtitle-output", output_path.as_posix()], + expect_exit=False, + ) + + assert exit_code == 0 + assert captured["generated"] == ([(0.0, 1.0, "hello", "happy")], output_path.as_posix()) + assert captured["formatter"].__class__.__name__.lower().startswith("srt") + + +def test_cli_format_override_logs_precedence(run_cli, monkeypatch, tmp_path, caplog): + caplog.set_level(logging.INFO) + captured = _stub_processing(monkeypatch) + output_path = tmp_path / "export.vtt" + + exit_code, _ = run_cli( + [ + "--file", + "input.wav", + "--subtitle-output", + output_path.as_posix(), + "--subtitle-format", + "ass", + ], + expect_exit=False, + ) + + assert exit_code == 0 + assert captured["formatter"].__class__.__name__.lower().startswith("ass") + assert any( + "overriding inferred format" in message for message in caplog.messages + ) diff --git a/tests/suites/integration/.gitkeep b/tests/suites/integration/.gitkeep new file mode 100644 index 0000000..ff550f2 --- /dev/null +++ b/tests/suites/integration/.gitkeep @@ -0,0 +1 @@ +# Placeholder to retain directory structure in Git diff --git a/tests/suites/integration/test_timeline_csv.py b/tests/suites/integration/test_timeline_csv.py new file mode 100644 index 0000000..5e502f7 --- /dev/null +++ b/tests/suites/integration/test_timeline_csv.py @@ -0,0 +1,20 @@ +from pathlib import Path + +from ser.config import Config +from ser.utils.timeline_utils import save_timeline_to_csv + + +def test_save_timeline_to_csv_writes_headers_and_rows(tmp_path, monkeypatch): + target_dir = tmp_path / "transcripts" + target_dir.mkdir() + monkeypatch.setitem(Config.TIMELINE_CONFIG, "folder", target_dir.as_posix()) + + timeline = [(0.1234, "happy", "Hello world")] + + csv_path = save_timeline_to_csv(timeline, "sample.wav") + + written = Path(csv_path) + assert written.exists() + rows = written.read_text(encoding="utf-8").splitlines() + assert rows[0] == "Time (s),Emotion,Speech" + assert rows[1] == "0.12,happy,Hello world" diff --git a/tests/suites/unit/.gitkeep b/tests/suites/unit/.gitkeep new file mode 100644 index 0000000..ff550f2 --- /dev/null +++ b/tests/suites/unit/.gitkeep @@ -0,0 +1 @@ +# Placeholder to retain directory structure in Git diff --git a/tests/suites/unit/test_audio_utils.py b/tests/suites/unit/test_audio_utils.py new file mode 100644 index 0000000..8c42682 --- /dev/null +++ b/tests/suites/unit/test_audio_utils.py @@ -0,0 +1,56 @@ +from unittest import mock + +import numpy as np +import pytest + +from ser.utils.audio_utils import read_audio_file + + +@pytest.fixture(autouse=True) +def _patch_sleep(monkeypatch): + monkeypatch.setattr("ser.utils.audio_utils.time.sleep", lambda *_: None) + + +def test_read_audio_file_falls_back_and_eventually_raises(monkeypatch): + audio_data = np.array([0.0, 0.5, -0.5], dtype=np.float32) + + load_calls = {} + + def failing_librosa_load(*args, **kwargs): + load_calls["librosa"] = load_calls.get("librosa", 0) + 1 + raise RuntimeError("boom") + + class DummySoundFile: + def __init__(self, *args, **kwargs): + load_calls["soundfile"] = load_calls.get("soundfile", 0) + 1 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def read(self, dtype=None): + return audio_data + + @property + def samplerate(self): + return 16000 + + monkeypatch.setattr("ser.utils.audio_utils.librosa.load", failing_librosa_load) + monkeypatch.setattr("ser.utils.audio_utils.sf.SoundFile", DummySoundFile) + + audio, rate = read_audio_file("fake.wav") + + assert pytest.approx(audio.tolist()) == [0.0, 1.0, -1.0] + assert rate == 16000 + assert load_calls["librosa"] == 1 + assert load_calls["soundfile"] == 1 + + def failing_soundfile(*args, **kwargs): + raise OSError("nope") + + monkeypatch.setattr("ser.utils.audio_utils.sf.SoundFile", failing_soundfile) + + with pytest.raises(IOError): + read_audio_file("fake.wav") diff --git a/tests/suites/unit/test_emotion_model.py b/tests/suites/unit/test_emotion_model.py new file mode 100644 index 0000000..e950f95 --- /dev/null +++ b/tests/suites/unit/test_emotion_model.py @@ -0,0 +1,18 @@ +import numpy as np + +from ser.models import emotion_model + + +def test_predict_emotions_collapses_repeated_segments(monkeypatch): + class DummyModel: + def predict(self, features): + return np.array(["happy", "happy", "sad", "sad", "sad"], dtype=object) + + monkeypatch.setattr(emotion_model, "load_model", lambda: DummyModel()) + monkeypatch.setattr(emotion_model, "extended_extract_feature", lambda file: [np.zeros(1)]) + monkeypatch.setattr(emotion_model, "read_audio_file", lambda file: (np.zeros(5), 16000)) + monkeypatch.setattr(emotion_model.librosa, "get_duration", lambda y: 10.0) + + segments = emotion_model.predict_emotions("fake.wav") + + assert segments == [("happy", 0, 4.0), ("sad", 4.0, 10.0)] diff --git a/tests/suites/unit/test_subtitles.py b/tests/suites/unit/test_subtitles.py new file mode 100644 index 0000000..ef56e37 --- /dev/null +++ b/tests/suites/unit/test_subtitles.py @@ -0,0 +1,89 @@ +from pathlib import Path + +import pytest + +from ser.utils.subtitles import ( + ASSFormatter, + SRTFormatter, + VTTFormatter, + parse_subtitles, + timeline_to_subtitles, +) + + +def test_timeline_to_subtitles_sorts_and_computes_durations(): + timeline = [ + (5.0, "happy", "Third"), + (1.0, "neutral", "First"), + (3.0, "sad", "Second"), + ] + + result = timeline_to_subtitles(timeline, default_duration=2.0) + + assert [start for start, *_ in result] == [1.0, 3.0, 5.0] + assert result[0][1] == pytest.approx(2.0) + assert result[1][1] == pytest.approx(2.0) + + +def test_timeline_to_subtitles_uses_default_duration_for_last_entry(): + timeline = [(0.0, "calm", "Only line")] + + result = timeline_to_subtitles(timeline, default_duration=3.5) + + assert result == [(0.0, 3.5, "Only line", "calm")] + + +def test_timeline_to_subtitles_skips_blank_text(): + timeline = [ + (0.0, "happy", " Hello "), + (1.0, "sad", " "), + (2.0, "angry", "World"), + ] + + result = timeline_to_subtitles(timeline, default_duration=1.0) + + assert result == [ + (0.0, 1.0, "Hello", "happy"), + (2.0, 1.0, "World", "angry"), + ] + + +def test_parse_subtitles_handles_invalid_entries(caplog): + caplog.set_level("ERROR") + parsed = parse_subtitles("0.0,1.0,Hello,Happy;invalid;2.0,1.5,Bye,Sad") + + assert parsed == [(0.0, 1.0, "Hello", "Happy"), (2.0, 1.5, "Bye", "Sad")] + assert any("Invalid subtitle format" in record.message for record in caplog.records) + + +def test_ass_formatter_writes_header_and_entries(tmp_path: Path): + output = tmp_path / "example.ass" + subtitles = [(0.0, 2.0, "Hello", "Happy")] + + ASSFormatter().generate_file(subtitles, output.as_posix()) + + content = output.read_text(encoding="utf-8") + assert content.startswith(ASSFormatter.ASS_HEADER) + assert "Dialogue: 0,0:00:00.00,0:00:02.00,Default" in content + + +def test_srt_formatter_entry_format(tmp_path: Path): + output = tmp_path / "example.srt" + subtitles = [(0.0, 1.25, "Hello", "Happy"), (1.25, 0.75, "Bye", "Sad")] + + SRTFormatter().generate_file(subtitles, output.as_posix()) + + content = output.read_text(encoding="utf-8") + assert "1\n00:00:00,000 --> 00:00:01,250\nHello (Happy)" in content + assert "2\n00:00:01,250 --> 00:00:02,000\nBye (Sad)" in content + + +def test_vtt_formatter_includes_header_and_arrows(tmp_path: Path): + output = tmp_path / "example.vtt" + subtitles = [(0.0, 1.0, "Hello", "Happy")] + + VTTFormatter().generate_file(subtitles, output.as_posix()) + + content = output.read_text(encoding="utf-8") + assert content.startswith("WEBVTT\n\n") + assert "00:00:00.000 --> 00:00:01.000" in content diff --git a/tests/suites/unit/test_timeline_utils.py b/tests/suites/unit/test_timeline_utils.py new file mode 100644 index 0000000..69ad544 --- /dev/null +++ b/tests/suites/unit/test_timeline_utils.py @@ -0,0 +1,16 @@ +from ser.utils.timeline_utils import build_timeline + + +def test_build_timeline_merges_transcript_and_emotions(): + transcript = [("hello", 0.0, 0.5), ("world", 2.0, 2.5)] + emotions = [("happy", 0.0, 1.0), ("sad", 3.0, 4.0)] + + timeline = build_timeline(transcript, emotions) + + assert timeline == [ + (0.0, "happy", "hello"), + (1.0, "", ""), + (2.0, "", "world"), + (3.0, "sad", ""), + (4.0, "", ""), + ]