diff --git a/README.md b/README.md index 403b297..168155c 100644 --- a/README.md +++ b/README.md @@ -23,9 +23,9 @@ for seg in result.segments: print(f" [{seg.start:.1f}s - {seg.end:.1f}s] {seg.speaker}") ``` -**~5.0% weighted DER** on VoxConverse dev. Processes audio **~8x faster than real-time** on CPU. Automatically detects the number of speakers. +**~4.8% weighted DER** on VoxConverse dev. Processes audio **~8x faster than real-time** on CPU. Automatically detects the number of speakers. -> Benchmarked on a single dataset ([VoxConverse](https://github.com/joonson/voxconverse)). Cross-dataset validation is [in progress](#roadmap). +> Primary benchmark: [VoxConverse](https://github.com/joonson/voxconverse). Preliminary AMI meeting-domain validation is [in progress](#roadmap). ## How diarize compares @@ -35,7 +35,7 @@ for seg in result.segments: | GPU required | No | No (7x slower on CPU) | No | | HuggingFace account | No | Yes | Yes | | Auto speaker count | Yes | Yes | Yes | -| DER (VoxConverse dev) | **~5.0%** | ~11.2% | ~8.5% | +| DER (VoxConverse dev) | **~4.8%** | ~11.2% | ~8.5% | | CPU speed (RTF) | **0.12** | 0.86 | — | | Install | `pip install diarize` | `pip install pyannote.audio` | `pip install pyannote.audio` | @@ -102,7 +102,7 @@ Evaluated on [VoxConverse](https://github.com/joonson/voxconverse) dev set (216 | System | Weighted DER | Notes | |--------|----------|-------| | pyannote precision-2 | ~8.5% | Commercial license | -| **diarize** | **~5.0%** | **Apache 2.0, CPU-only, no API key** | +| **diarize** | **~4.8%** | **Apache 2.0, CPU-only, no API key** | | pyannote community-1 | ~11.2% | CC-BY-4.0, needs HF token | | pyannote 3.1 (legacy) | ~11.2% | MIT, needs HF token | @@ -111,16 +111,26 @@ Evaluated on [VoxConverse](https://github.com/joonson/voxconverse) dev set (216 | Metric | Result | |--------|--------| | Files | 216 | -| Exact match | 117/216 (54%) | -| Within ±1 | 175/216 (81%) | +| Exact match | 125/216 (58%) | +| Within ±1 | 178/216 (82%) | Many-speaker files remain the weak spot: automatic count estimation degrades above 7 speakers. Pass `num_speakers` when the count is known. +Preliminary AMI meeting-domain check (16 Mix-Headset test files, 4–9 speakers): + +| Metric | Result | +|--------|--------| +| Weighted DER | 14.96% | +| Speaker count exact match | 4/16 (25%) | +| Speaker count within ±1 | 8/16 (50%) | + +AMI confirms that meeting-domain speaker counting is harder: the estimator often collapses 6+ speaker meetings to 4–5 speakers. + Full benchmark results, speed comparison, and methodology: [benchmarks](https://foxnosetech.github.io/diarize/benchmarks/). ## When to use something else -- **You need commercial support or cross-dataset validation.** pyannote's commercial model has published production-oriented benchmarks beyond this single VoxConverse evaluation. If accuracy is the top priority and you have budget, compare on your own data. +- **You need commercial support or broad cross-dataset validation.** pyannote's commercial model has published production-oriented benchmarks beyond this limited VoxConverse/AMI evaluation. If accuracy is the top priority and you have budget, compare on your own data. - **You need very stable speaker labels in transcripts.** Temporal smoothing reduces short label jumps, but diarize can still show speaker fragmentation / label switching: one real speaker may be split across multiple `SPEAKER_XX` labels, especially on noisy real-world audio. - **Your audio has 8+ speakers.** Automatic speaker count estimation degrades above 7 speakers. You can pass `num_speakers` explicitly, but test carefully. - **You need overlapping speech detection.** diarize assigns each segment to one speaker. Overlapping speech is not modeled. @@ -128,9 +138,9 @@ Full benchmark results, speed comparison, and methodology: [benchmarks](https:// ## Roadmap -Current benchmarks are based on VoxConverse dev set only. We are actively working on: +Current benchmarks include VoxConverse dev and preliminary AMI test validation. We are actively working on: -- **Cross-dataset validation** — AMI, DIHARD III, CALLHOME, and other standard benchmarks in isolated environments +- **Cross-dataset validation** — DIHARD III, CALLHOME, and other standard benchmarks in isolated environments - **Speaker count estimation benchmarks** — comparison of speaker counting accuracy against other systems - **Broader system comparison** — NeMo, WhisperX, and other diarization solutions - **Streaming / real-time diarization** — live audio streams with real-time speaker detection diff --git a/docs/benchmarks.md b/docs/benchmarks.md index 9dd8004..cbfdea0 100644 --- a/docs/benchmarks.md +++ b/docs/benchmarks.md @@ -1,15 +1,17 @@ # Benchmarks -Evaluated on the [VoxConverse](https://github.com/joonson/voxconverse) -dev set (216 files, 1--20 speakers per file). +Primary published numbers are evaluated on the +[VoxConverse](https://github.com/joonson/voxconverse) dev set +(216 files, 1--20 speakers per file). We also run preliminary +cross-dataset checks on AMI meetings to track generalisation. ## Speaker Count Estimation | Metric | Result | |--------|--------| | Files | 216 | -| Exact match | 117/216 (54%) | -| Within +/-1 | 175/216 (81%) | +| Exact match | 125/216 (58%) | +| Within +/-1 | 178/216 (82%) | The automatic estimator is usually close, but exact counting remains the main weak spot. Accuracy drops for many-speaker files --- see @@ -23,7 +25,7 @@ DER is the standard metric for speaker diarization, computed with | System | Weighted DER | Median DER | Notes | |--------|----------|------------|-------| | pyannote precision-2 | ~8.5% | -- | Commercial license | -| **diarize** | **~5.0%** | **~2.2%** | **Apache 2.0, CPU-only, no API key** | +| **diarize** | **~4.8%** | **~2.1%** | **Apache 2.0, CPU-only, no API key** | | pyannote community-1 | ~11.2% | -- | CC-BY-4.0, needs HF token | | pyannote 3.1 (legacy) | ~11.2% | -- | MIT, needs HF token | @@ -31,13 +33,34 @@ pyannote DER numbers are self-reported from the [pyannote benchmark page](https://huggingface.co/pyannote/speaker-diarization-3.1) on VoxConverse v0.3. -!!! note "VoxConverse-only result" +!!! note "Dataset-specific result" On this VoxConverse dev evaluation, `diarize` reports lower weighted DER than the published pyannote VoxConverse figures, while requiring no HuggingFace token or account registration. Treat this as a - single-dataset benchmark and compare on your own audio when accuracy + VoxConverse-specific benchmark and compare on your own audio when accuracy is the top priority. +## Cross-Dataset Check: AMI + +Preliminary AMI test-set evaluation uses 16 Mix-Headset meeting +recordings (4--9 speakers per file), RTTM annotations from the +standard AMI speaker-diarization benchmark, and the same DER settings +(``collar=0.25``, ``skip_overlap=True``). + +| Metric | Result | +|--------|--------| +| Files | 16 | +| Weighted DER | 14.96% | +| Mean DER | 14.63% | +| Median DER | 14.18% | +| Speaker count exact match | 4/16 (25%) | +| Speaker count within +/-1 | 8/16 (50%) | + +This confirms that meeting-domain audio is a harder case for automatic +speaker counting. The estimator often collapses 6+ speaker meetings to +4--5 speakers, even when aggregate DER remains moderate because some +ground-truth speakers have little speaking time. + ## CPU Speed (Real Time Factor) RTF = processing_time / audio_duration. Lower is faster; RTF < 1.0 means @@ -76,6 +99,34 @@ Measured on VoxConverse dev files on Apple M2 Pro / M2 Max warm-up. RTF = processing_time / audio_duration. - **Hardware:** Apple M2 Pro, macOS, CPU only (no GPU). +## Reproducing and Extending Benchmarks + +The repository includes a dataset-agnostic RTTM runner for local +experiments: + +```bash +python scripts/benchmark_rttm.py \ + --dataset voxconverse-dev \ + --audio-dir /path/to/voxconverse/dev/audio \ + --rttm-dir /path/to/voxconverse/rttm_annotations/dev \ + --output results_voxconverse_dev.json +``` + +It also supports combined RTTM files and targeted diagnostics: + +```bash +python scripts/benchmark_rttm.py \ + --dataset ami-test \ + --audio-dir /path/to/ami/mix-headset/test \ + --rttm-file /path/to/AMI.SpeakerDiarization.Benchmark.test.rttm \ + --oracle-speakers \ + --file-id IS1009a +``` + +Use ``--oracle-speakers`` to isolate speaker assignment and clustering +quality when the true speaker count is known. Use ``--list-only`` to +verify audio/RTTM matching without running inference. + ## Limitations !!! warning "Speaker count > 7" @@ -108,14 +159,14 @@ Measured on VoxConverse dev files on Apple M2 Pro / M2 Max ## Future Work -!!! info "Single-dataset disclaimer" - All results above are from VoxConverse dev set only. We are actively - expanding evaluation to ensure the algorithm generalises well and is - not overfit to a single benchmark. +!!! info "Cross-dataset validation in progress" + VoxConverse remains the primary published benchmark. AMI is now used + as an additional meeting-domain check, and more datasets are needed + before making broad accuracy claims. **Planned evaluation:** -- **Cross-dataset validation** --- AMI, DIHARD III, CALLHOME, and other +- **Cross-dataset validation** --- DIHARD III, CALLHOME, and other standard benchmarks, run in isolated environments with controlled CPU/memory limits. - **Speaker count estimation comparison** --- dedicated benchmarks comparing diff --git a/docs/how-it-works.md b/docs/how-it-works.md index 463b220..4f5c5b5 100644 --- a/docs/how-it-works.md +++ b/docs/how-it-works.md @@ -76,9 +76,10 @@ speakers while keeping computational cost low. **Step 3 --- Silhouette refinement.** BIC is used as an anchor, then a small neighbourhood around it is scored with silhouette over cosine -distance. The candidate range is clamped by `min_speakers`, -`max_speakers`, and the number of available embeddings. This catches -some BIC undercounts and overcounts without searching the full range. +distance plus a small logarithmic bonus for larger *k*. The candidate +range is clamped by `min_speakers`, `max_speakers`, and the number of +available embeddings. This catches some BIC undercounts and overcounts +without searching the full range. !!! warning For **8 or more speakers** the estimator can undercount. diff --git a/docs/index.md b/docs/index.md index 6dbac20..27a348f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -29,7 +29,7 @@ for seg in result.segments: | GPU required | No | No (7x slower on CPU) | No | | HuggingFace account | No | Yes | Yes | | Auto speaker count | Yes | Yes | Yes | -| DER (VoxConverse dev) | **~5.0%** | ~11.2% | ~8.5% | +| DER (VoxConverse dev) | **~4.8%** | ~11.2% | ~8.5% | | CPU speed (RTF) | **0.12** | 0.86 | --- | DER and speed numbers for pyannote are from their @@ -40,7 +40,7 @@ The diarize number is from the VoxConverse dev evaluation described in ## Next Steps - [How It Works](how-it-works.md) --- pipeline architecture and algorithms -- [Benchmarks](benchmarks.md) --- VoxConverse evaluation, speed comparison, limitations +- [Benchmarks](benchmarks.md) --- VoxConverse, AMI, speed comparison, limitations - [API Reference](api.md) --- full auto-generated API documentation ## License diff --git a/pyproject.toml b/pyproject.toml index 9d9a759..3fb4472 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "diarize" -version = "0.1.1" +version = "0.1.2" description = "Speaker diarization for Python — detect who spoke when in audio files. CPU-only, no GPU, no API keys, no account signup. Automatic speaker count detection." readme = "README.md" license = "Apache-2.0" diff --git a/scripts/benchmark_rttm.py b/scripts/benchmark_rttm.py new file mode 100644 index 0000000..930a450 --- /dev/null +++ b/scripts/benchmark_rttm.py @@ -0,0 +1,514 @@ +#!/usr/bin/env python3 +"""Benchmark diarize against RTTM speaker annotations. + +This runner is intentionally dataset-agnostic: point it at an audio directory +and an RTTM directory and it will match files by stem, run diarize, and report +DER plus speaker-count accuracy. + +Examples: + python scripts/benchmark_rttm.py \ + --dataset voxconverse-dev \ + --audio-dir /Users/lukashov/records/benchmark/audio/audio \ + --rttm-dir /Users/lukashov/records/benchmark/rttm_annotations/dev \ + --output /Users/lukashov/records/benchmark/results_voxconverse_dev.json + + python scripts/benchmark_rttm.py \ + --dataset voxconverse-test \ + --audio-dir /Users/lukashov/records/benchmark/audio/test \ + --rttm-dir /Users/lukashov/records/benchmark/rttm_annotations/test \ + --output /Users/lukashov/records/benchmark/results_voxconverse_test.json + + python scripts/benchmark_rttm.py \ + --dataset ami-test \ + --audio-dir /Users/lukashov/records/benchmark/ami/audio/mix-headset/test \ + --rttm-file /path/to/AMI.SpeakerDiarization.Benchmark.test.rttm \ + --output /Users/lukashov/records/benchmark/ami/results_test.json + + # Isolate assignment/clustering quality when the true speaker count is known. + python scripts/benchmark_rttm.py ... --oracle-speakers +""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +import tempfile +import time +import warnings +from dataclasses import asdict, dataclass +from pathlib import Path +from statistics import mean, median +from typing import Any + +_TEMP_DIR = Path(tempfile.gettempdir()) +os.environ.setdefault("XDG_CACHE_HOME", str(_TEMP_DIR / "diarize-cache")) +os.environ.setdefault("MPLCONFIGDIR", str(_TEMP_DIR / "diarize-matplotlib")) +os.environ.setdefault("LOKY_MAX_CPU_COUNT", str(os.cpu_count() or 1)) +warnings.filterwarnings("ignore", message="'uem' was approximated.*") + +_AUDIO_SUFFIXES = (".wav", ".flac", ".mp3", ".m4a", ".ogg") +_DIARIZE: tuple[Any, Any] | None = None +_PYANNOTE: tuple[Any, Any, Any] | None = None + + +@dataclass +class BenchmarkFile: + file_id: str + audio_path: Path + rttm_source: str + rttm_lines: list[str] + gt_speakers: int + + +@dataclass +class FileResult: + file_id: str + gt_speakers: int + predicted_speakers: int + speaker_delta: int + der: float | None + elapsed_sec: float + duration: float + n_segments: int + error: str | None = None + + +def load_pyannote() -> tuple[Any, Any, Any]: + """Load optional pyannote benchmark dependencies lazily.""" + global _PYANNOTE + if _PYANNOTE is not None: + return _PYANNOTE + + try: + from pyannote.core import Annotation + from pyannote.core import Segment as PySegment + from pyannote.metrics.diarization import DiarizationErrorRate + except ImportError as exc: # pragma: no cover - depends on optional local deps. + raise SystemExit( + "benchmark_rttm.py requires pyannote.metrics. Install it in your benchmark " + "environment, for example: pip install pyannote.metrics" + ) from exc + + _PYANNOTE = (Annotation, PySegment, DiarizationErrorRate) + return _PYANNOTE + + +def load_diarize() -> tuple[Any, Any]: + """Load diarize lazily after benchmark-specific env setup.""" + global _DIARIZE + if _DIARIZE is not None: + return _DIARIZE + + from diarize import diarize + from diarize.utils import get_audio_duration + + _DIARIZE = (diarize, get_audio_duration) + return _DIARIZE + + +def parse_rttm_lines_to_annotation(file_id: str, rttm_lines: list[str]) -> Any: + """Parse a standard RTTM file into a pyannote Annotation.""" + Annotation, PySegment, _ = load_pyannote() + annotation = Annotation(uri=file_id) + for line in rttm_lines: + parts = line.strip().split() + if len(parts) < 8 or parts[0] != "SPEAKER": + continue + start = float(parts[3]) + duration = float(parts[4]) + speaker = parts[7] + annotation[PySegment(start, start + duration)] = speaker + return annotation + + +def rttm_lines_speaker_count(rttm_lines: list[str]) -> int: + """Count unique speakers in an RTTM file.""" + speakers: set[str] = set() + for line in rttm_lines: + parts = line.strip().split() + if len(parts) >= 8 and parts[0] == "SPEAKER": + speakers.add(parts[7]) + return len(speakers) + + +def result_to_annotation(file_id: str, segments) -> Any: + """Convert diarize result segments to pyannote Annotation.""" + Annotation, PySegment, _ = load_pyannote() + annotation = Annotation(uri=file_id) + for segment in segments: + annotation[PySegment(segment.start, segment.end)] = segment.speaker + return annotation + + +def build_audio_index(audio_dir: Path) -> dict[str, Path]: + """Index audio files recursively by stem and simple AMI-style aliases.""" + index: dict[str, Path] = {} + duplicates: set[str] = set() + for path in sorted(audio_dir.rglob("*")): + if not path.is_file() or path.suffix.lower() not in _AUDIO_SUFFIXES: + continue + + aliases = [path.stem] + if "." in path.stem: + aliases.append(path.stem.split(".", 1)[0]) + + for alias in aliases: + if alias in index and index[alias] != path: + duplicates.add(alias) + continue + index[alias] = path + + if duplicates: + duplicate_list = ", ".join(sorted(duplicates)[:10]) + print( + f"Warning: ignored duplicate audio stems ({duplicate_list})", + file=sys.stderr, + ) + return index + + +def load_rttm_groups_from_dir(rttm_dir: Path) -> dict[str, tuple[str, list[str]]]: + """Load one-RTTM-per-recording annotations from a directory.""" + groups: dict[str, tuple[str, list[str]]] = {} + for rttm_path in sorted(rttm_dir.rglob("*.rttm")): + lines = rttm_path.read_text(encoding="utf-8").splitlines() + groups[rttm_path.stem] = (str(rttm_path), lines) + return groups + + +def load_rttm_groups_from_file(rttm_file: Path) -> dict[str, tuple[str, list[str]]]: + """Load a combined RTTM file and group lines by recording id.""" + groups: dict[str, tuple[str, list[str]]] = {} + grouped_lines: dict[str, list[str]] = {} + with rttm_file.open(encoding="utf-8") as file: + for line in file: + parts = line.strip().split() + if len(parts) < 8 or parts[0] != "SPEAKER": + continue + grouped_lines.setdefault(parts[1], []).append(line.rstrip("\n")) + + for file_id, lines in grouped_lines.items(): + groups[file_id] = (str(rttm_file), lines) + return groups + + +def collect_files( + audio_dir: Path, + rttm_dir: Path | None, + rttm_file: Path | None, + *, + gt_min: int, + gt_max: int, +) -> tuple[list[BenchmarkFile], list[str]]: + """Match RTTM files to audio files by stem.""" + audio_index = build_audio_index(audio_dir) + if rttm_file is not None: + rttm_groups = load_rttm_groups_from_file(rttm_file) + elif rttm_dir is not None: + rttm_groups = load_rttm_groups_from_dir(rttm_dir) + else: # pragma: no cover - parse_args/main validate this. + raise ValueError("Either rttm_dir or rttm_file is required") + + files: list[BenchmarkFile] = [] + missing_audio: list[str] = [] + + for file_id, (source, lines) in sorted(rttm_groups.items()): + gt_speakers = rttm_lines_speaker_count(lines) + if gt_speakers < gt_min or gt_speakers > gt_max: + continue + + audio_path = audio_index.get(file_id) + if audio_path is None: + missing_audio.append(file_id) + continue + + files.append( + BenchmarkFile( + file_id=file_id, + audio_path=audio_path, + rttm_source=source, + rttm_lines=lines, + gt_speakers=gt_speakers, + ) + ) + + return files, missing_audio + + +def load_existing_results(output_path: Path | None) -> list[FileResult]: + """Load previous results from a JSON output file.""" + if output_path is None or not output_path.exists(): + return [] + + with output_path.open(encoding="utf-8") as file: + payload = json.load(file) + + rows = payload["files"] if isinstance(payload, dict) and "files" in payload else payload + return [FileResult(**row) for row in rows] + + +def write_results( + output_path: Path | None, + *, + dataset: str, + args: argparse.Namespace, + results: list[FileResult], +) -> None: + """Persist intermediate benchmark results.""" + if output_path is None: + return + + payload = { + "dataset": dataset, + "audio_dir": str(args.audio_dir), + "rttm_dir": str(args.rttm_dir) if args.rttm_dir else None, + "rttm_file": str(args.rttm_file) if args.rttm_file else None, + "collar": args.collar, + "skip_overlap": not args.score_overlap, + "oracle_speakers": args.oracle_speakers, + "min_speakers": args.min_speakers, + "max_speakers": args.max_speakers, + "files": [asdict(result) for result in results], + } + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(json.dumps(payload, indent=2), encoding="utf-8") + + +def run_file( + file: BenchmarkFile, + *, + args: argparse.Namespace, + metric: Any, +) -> FileResult: + """Run diarize and compute metrics for one file.""" + diarize_fn, get_audio_duration_fn = load_diarize() + start_time = time.time() + duration = get_audio_duration_fn(file.audio_path) + num_speakers = file.gt_speakers if args.oracle_speakers else args.num_speakers + + result = diarize_fn( + file.audio_path, + min_speakers=args.min_speakers, + max_speakers=args.max_speakers, + num_speakers=num_speakers, + ) + elapsed = time.time() - start_time + + if not result.segments: + return FileResult( + file_id=file.file_id, + gt_speakers=file.gt_speakers, + predicted_speakers=0, + speaker_delta=-file.gt_speakers, + der=None, + elapsed_sec=round(elapsed, 2), + duration=round(duration, 2), + n_segments=0, + error="no_segments", + ) + + reference = parse_rttm_lines_to_annotation(file.file_id, file.rttm_lines) + hypothesis = result_to_annotation(file.file_id, result.segments) + der = float(metric(reference, hypothesis) * 100) + predicted_speakers = result.num_speakers + + return FileResult( + file_id=file.file_id, + gt_speakers=file.gt_speakers, + predicted_speakers=predicted_speakers, + speaker_delta=predicted_speakers - file.gt_speakers, + der=round(der, 2), + elapsed_sec=round(elapsed, 2), + duration=round(duration or result.audio_duration, 2), + n_segments=len(result.segments), + ) + + +def print_summary(results: list[FileResult]) -> None: + """Print aggregate DER and speaker-count metrics.""" + valid = [result for result in results if result.der is not None] + if not valid: + print("No valid results.") + return + + ders = [result.der for result in valid if result.der is not None] + total_duration = sum(max(result.duration, 0.0) for result in valid) + weighted_der = ( + sum((result.der or 0.0) * result.duration for result in valid) / total_duration + if total_duration > 0 + else mean(ders) + ) + + exact = sum(1 for result in valid if result.speaker_delta == 0) + within_1 = sum(1 for result in valid if abs(result.speaker_delta) <= 1) + + print("\nSummary") + print(f" Files: {len(valid)}") + print(f" Weighted DER: {weighted_der:.2f}%") + print(f" Mean DER: {mean(ders):.2f}%") + print(f" Median DER: {median(ders):.2f}%") + print(f" Exact count: {exact}/{len(valid)} ({100 * exact / len(valid):.0f}%)") + print(f" Within +/-1: {within_1}/{len(valid)} ({100 * within_1 / len(valid):.0f}%)") + + by_gt: dict[int, list[FileResult]] = {} + for result in valid: + by_gt.setdefault(result.gt_speakers, []).append(result) + + print("\nBy ground-truth speaker count") + print(f" {'GT':>3s} {'N':>3s} {'DER':>7s} {'Exact':>7s} {'Bias':>7s}") + for gt_speakers, rows in sorted(by_gt.items()): + row_ders = [row.der for row in rows if row.der is not None] + row_exact = sum(1 for row in rows if row.speaker_delta == 0) + row_bias = mean(row.speaker_delta for row in rows) + print( + f" {gt_speakers:>3d} {len(rows):>3d} " + f"{mean(row_ders):>6.2f}% {row_exact:>3d}/{len(rows):<3d} " + f"{row_bias:>+7.2f}" + ) + + +def print_file_inventory(files: list[BenchmarkFile]) -> None: + """Print matched-file inventory without running inference.""" + if not files: + return + + by_gt: dict[int, int] = {} + for file in files: + by_gt[file.gt_speakers] = by_gt.get(file.gt_speakers, 0) + 1 + + print("\nGround-truth speaker distribution") + print(f" {'GT':>3s} {'N':>3s}") + for gt_speakers, count in sorted(by_gt.items()): + print(f" {gt_speakers:>3d} {count:>3d}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Benchmark diarize on audio + RTTM data.") + parser.add_argument("--dataset", default="dataset", help="Dataset label used in output JSON") + parser.add_argument("--audio-dir", type=Path, required=True, help="Directory with audio files") + parser.add_argument("--rttm-dir", type=Path, default=None, help="Directory with RTTM files") + parser.add_argument("--rttm-file", type=Path, default=None, help="Combined RTTM file") + parser.add_argument("--output", type=Path, default=None, help="Optional JSON output path") + parser.add_argument("--max-files", type=int, default=0, help="Limit number of files (0=all)") + parser.add_argument( + "--file-id", + action="append", + default=[], + help="Only run a specific file id; may be passed multiple times", + ) + parser.add_argument( + "--gt-min", + type=int, + default=0, + help="Only files with at least N speakers", + ) + parser.add_argument( + "--gt-max", + type=int, + default=999, + help="Only files with at most N speakers", + ) + parser.add_argument("--min-speakers", type=int, default=1, help="Minimum auto speaker count") + parser.add_argument("--max-speakers", type=int, default=20, help="Maximum auto speaker count") + parser.add_argument("--num-speakers", type=int, default=None, help="Use fixed speaker count") + parser.add_argument("--list-only", action="store_true", help="Only list matched files") + parser.add_argument( + "--oracle-speakers", + action="store_true", + help="Use each file's RTTM speaker count as num_speakers", + ) + parser.add_argument("--resume", action="store_true", help="Resume from --output results") + parser.add_argument("--collar", type=float, default=0.25, help="DER collar in seconds") + parser.add_argument( + "--score-overlap", + action="store_true", + help="Score overlapped speech instead of skipping it", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if args.oracle_speakers and args.num_speakers is not None: + raise SystemExit("--oracle-speakers and --num-speakers are mutually exclusive") + if (args.rttm_dir is None) == (args.rttm_file is None): + raise SystemExit("Pass exactly one of --rttm-dir or --rttm-file") + if not args.audio_dir.exists(): + raise SystemExit(f"Audio directory not found: {args.audio_dir}") + if args.rttm_dir is not None and not args.rttm_dir.exists(): + raise SystemExit(f"RTTM directory not found: {args.rttm_dir}") + if args.rttm_file is not None and not args.rttm_file.exists(): + raise SystemExit(f"RTTM file not found: {args.rttm_file}") + + files, missing_audio = collect_files( + args.audio_dir, + args.rttm_dir, + args.rttm_file, + gt_min=args.gt_min, + gt_max=args.gt_max, + ) + if args.file_id: + requested_ids = set(args.file_id) + files = [file for file in files if file.file_id in requested_ids] + if args.max_files > 0: + files = files[: args.max_files] + + print(f"Dataset: {args.dataset}") + print(f"Audio dir: {args.audio_dir}") + if args.rttm_dir is not None: + print(f"RTTM dir: {args.rttm_dir}") + else: + print(f"RTTM file: {args.rttm_file}") + print(f"Matched: {len(files)} files") + if missing_audio: + print(f"Missing audio: {len(missing_audio)} RTTM files") + + if args.list_only: + print_file_inventory(files) + return + + results = load_existing_results(args.output) if args.resume else [] + done_ids = {result.file_id for result in results} + if done_ids: + print(f"Resuming: {len(done_ids)} existing results") + + _, _, DiarizationErrorRate = load_pyannote() + metric = DiarizationErrorRate(collar=args.collar, skip_overlap=not args.score_overlap) + total = len(files) + for index, file in enumerate(files, 1): + if file.file_id in done_ids: + print(f"[{index:3d}/{total}] {file.file_id} skipped") + continue + + print(f"[{index:3d}/{total}] {file.file_id} gt={file.gt_speakers}", end=" ", flush=True) + try: + result = run_file(file, args=args, metric=metric) + except Exception as exc: # noqa: BLE001 - benchmark should continue after bad files. + result = FileResult( + file_id=file.file_id, + gt_speakers=file.gt_speakers, + predicted_speakers=-1, + speaker_delta=0, + der=None, + elapsed_sec=0.0, + duration=0.0, + n_segments=0, + error=str(exc), + ) + print(f"error={exc}") + else: + der_text = "n/a" if result.der is None else f"{result.der:.2f}%" + print( + f"pred={result.predicted_speakers} DER={der_text} time={result.elapsed_sec:.1f}s" + ) + + results.append(result) + write_results(args.output, dataset=args.dataset, args=args, results=results) + + print_summary(results) + + +if __name__ == "__main__": + main() diff --git a/src/diarize/__init__.py b/src/diarize/__init__.py index 9c1a383..2755c78 100644 --- a/src/diarize/__init__.py +++ b/src/diarize/__init__.py @@ -42,7 +42,7 @@ class _RawSegment(NamedTuple): speaker: str -__version__ = "0.1.0" +__version__ = "0.1.2" __all__ = [ "diarize", "DiarizeResult", diff --git a/src/diarize/clustering.py b/src/diarize/clustering.py index c7459e1..c372192 100644 --- a/src/diarize/clustering.py +++ b/src/diarize/clustering.py @@ -28,6 +28,8 @@ logger = logging.getLogger(__name__) +_SILHOUETTE_K_BONUS = 0.04 + __all__ = [ "estimate_speakers", "cluster_spectral", @@ -288,6 +290,16 @@ def _silhouette_candidate_counts( return list(range(lower, upper + 1)) +def _speaker_count_score(silhouette: float, k: int) -> float: + """Score speaker-count candidates from silhouette plus a small k prior. + + Raw silhouette tends to prefer fewer, broader clusters on noisy + speaker embeddings. The logarithmic k bonus counteracts that bias + without overwhelming the separation score. + """ + return silhouette + _SILHOUETTE_K_BONUS * np.log(max(k, 1)) + + def cluster_auto( embeddings: np.ndarray, min_speakers: int = 1, @@ -317,19 +329,25 @@ def cluster_auto( candidates = _silhouette_candidate_counts(k, n, min_speakers, max_speakers) if len(candidates) > 1: distance = np.maximum(1 - (cosine_similarity(embeddings) + 1) / 2, 0) - best_k, best_labels, best_sil = k, None, -1.0 + best_k, best_labels, best_score = k, None, -1.0 for c in candidates: labels_c = cluster_spectral(embeddings, c) sil = silhouette_score(distance, labels_c, metric="precomputed") - logger.debug("Silhouette refinement: k=%d sil=%.4f", c, sil) - if sil > best_sil: - best_k, best_labels, best_sil = c, labels_c, sil + score = _speaker_count_score(sil, c) + logger.debug( + "Silhouette refinement: k=%d sil=%.4f score=%.4f", + c, + sil, + score, + ) + if score > best_score: + best_k, best_labels, best_score = c, labels_c, score if best_k != k: logger.info( - "Silhouette refinement: BIC k=%d -> k=%d (sil=%.4f)", + "Silhouette refinement: BIC k=%d -> k=%d (score=%.4f)", k, best_k, - best_sil, + best_score, ) details.best_k = best_k return best_labels, details # type: ignore[return-value] diff --git a/tests/test_diarize.py b/tests/test_diarize.py index dec799d..40f57c9 100644 --- a/tests/test_diarize.py +++ b/tests/test_diarize.py @@ -597,6 +597,12 @@ def test_silhouette_candidates_respect_min_speakers(self): max_speakers=10, ) == [4, 5, 6, 7, 8] + def test_speaker_count_score_applies_small_k_bonus(self): + from diarize.clustering import _speaker_count_score + + assert _speaker_count_score(0.50, 4) > _speaker_count_score(0.50, 2) + assert _speaker_count_score(0.50, 2) > _speaker_count_score(0.44, 4) + def test_auto_updates_details_when_silhouette_changes_k(self): from diarize.clustering import cluster_auto from diarize.utils import SpeakerEstimationDetails @@ -1262,7 +1268,7 @@ def test_import_result_classes(self): def test_version(self): from diarize import __version__ - assert __version__ == "0.1.0" + assert __version__ == "0.1.2" def test_import_pydantic_models(self): from diarize.utils import (