Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 132 additions & 4 deletions Sources/TranscriptedCore/Speaker/EmbeddingClusterer.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// EmbeddingClusterer.swift
// Post-processes diarization speaker segments to fix two failure modes:
// Post-processes diarization speaker segments to fix three failure modes:
//
// Supports both Sortformer (streaming) and PyAnnote (offline) pipelines.
//
Expand All @@ -9,7 +9,15 @@
// Note: Skipped for PyAnnote offline output, where VBx clustering already
// handles speaker merging/fragmentation.
//
// 2. Merging: Different speakers collapsed into one diarizer ID.
// 2. Over-segmentation: One real voice split across several clusters that each
// accumulate enough speech to survive small-cluster absorption. This is why
// a one-on-one call can surface 4-7 "speakers" to name for a single person.
// Fixed by same-voice consolidation — agglomeratively merge clusters whose
// mean embeddings are more similar than the "same known person" auto-accept
// bar, recomputing centroids after each merge so genuinely distinct speakers
// in a crowded meeting do not chain-collapse.
//
// 3. Merging: Different speakers collapsed into one diarizer ID.
// Fixed by DB-informed split — compare per-segment embeddings against
// known speaker profiles and split clusters that contain 2+ distinct voices.

Expand All @@ -24,13 +32,21 @@ public enum EmbeddingClusterer {
///
/// - Parameter pairwiseMergeThreshold: Cosine similarity threshold for merging
/// fragmented speaker clusters. Pass `nil` to skip only the pairwise merge
/// phase; small-cluster absorption and DB-informed split still run.
/// phase; small-cluster absorption, same-voice consolidation, and
/// DB-informed split still run.
/// Sortformer default: 0.85 (conservative). Offline PyAnnote callers pass
/// `nil` because VBx already handles the base merge/fragmentation case.
/// - Parameter consolidationThreshold: Cosine similarity threshold for the
/// same-voice consolidation pass that collapses over-segmented large
/// clusters of one speaker. Pass `nil` to skip it. Defaults to the
/// `SpeakerNamingPolicy` auto-accept bar (0.88) so two clusters only merge
/// when they are more similar than we'd demand to auto-accept them as the
/// same known person.
public static func postProcess(
segments: [SpeakerSegment],
existingProfiles: [SpeakerProfile],
pairwiseMergeThreshold: Float? = 0.85
pairwiseMergeThreshold: Float? = 0.85,
consolidationThreshold: Float? = 0.88
) -> [SpeakerSegment] {
guard segments.count >= 2 else { return segments }
var result: [SpeakerSegment]
Expand All @@ -40,6 +56,9 @@ public enum EmbeddingClusterer {
result = segments
}
result = absorbSmallClusters(segments: result)
if let consolidationThreshold {
result = consolidateSameVoiceClusters(segments: result, threshold: consolidationThreshold)
}
result = dbInformedSplit(segments: result, profiles: existingProfiles)
return result
}
Expand Down Expand Up @@ -263,6 +282,115 @@ public enum EmbeddingClusterer {
}
}

// MARK: - Same-Voice Consolidation

/// Collapses speaker clusters whose mean embeddings are near-identical, no
/// matter how much speech each cluster holds.
///
/// Motivation: `absorbSmallClusters` only folds short clusters into large
/// ones, so one voice that the diarizer split into several large clusters
/// survives it and a one-on-one call can surface 4-7 "speakers" to name for
/// a single person.
///
/// How it works:
/// - Embeddings are grouped by speaker ID; segments whose embedding is `nil`
///   or empty are ignored. Samples with `qualityScore >= 0.3` and
///   `duration >= 1.0` are preferred; a speaker with no such samples falls
///   back to all of its embeddings so it still gets a centroid.
/// - Each round picks the pair of live clusters with the highest centroid
///   cosine similarity that is strictly greater than `threshold` (ties go to
///   the first pair in ascending-ID order), folds the higher ID into the
///   lower ID, and recomputes the survivor's centroid from the combined
///   samples. Rounds repeat until no pair qualifies.
/// - Because the centroid is recomputed after every merge, a third cluster
///   must clear `threshold` against the *combined* centroid, which avoids a
///   transitive A≈B, B≈C → A+B+C collapse.
///
/// - Parameter segments: Diarized segments to relabel.
/// - Parameter threshold: Similarity a pair must exceed to merge. Defaults to
///   0.88, described in this file as the `SpeakerNamingPolicy` auto-accept bar.
/// - Returns: The input unchanged when fewer than two speaker IDs exist, fewer
///   than two speakers have usable embeddings, or nothing merged; otherwise
///   the segments with merged speaker IDs rewritten to their surviving ID.
static func consolidateSameVoiceClusters(
    segments: [SpeakerSegment],
    threshold: Float = 0.88
) -> [SpeakerSegment] {
    guard Set(segments.map { $0.speakerId }).count >= 2 else { return segments }

    // Bucket embeddings per speaker: every usable sample goes into the
    // fallback bucket, quality-gated samples additionally into the preferred one.
    var preferredSamples: [Int: [[Float]]] = [:]
    var fallbackSamples: [Int: [[Float]]] = [:]
    for candidate in segments {
        guard let vector = candidate.embedding, !vector.isEmpty else { continue }
        fallbackSamples[candidate.speakerId, default: []].append(vector)
        let passesQualityGate = candidate.qualityScore >= 0.3 && candidate.duration >= 1.0
        if passesQualityGate {
            preferredSamples[candidate.speakerId, default: []].append(vector)
        }
    }

    // Raw samples backing each live cluster. Preferred samples win whenever a
    // speaker has any; speakers without embeddings never appear here.
    var samplesByCluster = fallbackSamples.merging(preferredSamples) { _, preferred in preferred }
    guard samplesByCluster.count >= 2 else { return segments }

    var centroidByCluster: [Int: [Float]] = samplesByCluster.reduce(into: [:]) { table, entry in
        table[entry.key] = Transcription.computeMeanEmbedding(entry.value)
    }

    // Original speaker ID → surviving ID; starts as the identity mapping.
    var canonicalId = Dictionary(uniqueKeysWithValues: samplesByCluster.keys.map { ($0, $0) })

    // Agglomerative loop: one merge per round, best pair first.
    while centroidByCluster.count >= 2 {
        let orderedIds = centroidByCluster.keys.sorted()
        var winningSimilarity = threshold
        var winner: (keep: Int, drop: Int)?

        for (offset, lowerId) in orderedIds.enumerated() {
            guard let lowerCentroid = centroidByCluster[lowerId] else { continue }
            for higherId in orderedIds.dropFirst(offset + 1) {
                guard let higherCentroid = centroidByCluster[higherId] else { continue }
                let similarity = Float(Transcription.cosineSimilarityStatic(lowerCentroid, higherCentroid))
                // Strict comparison: a pair sitting exactly on the bar does not merge.
                guard similarity > winningSimilarity else { continue }
                winningSimilarity = similarity
                winner = (keep: lowerId, drop: higherId) // IDs are sorted, so lowerId < higherId
            }
        }

        guard let winner else { break }
        let survivor = winner.keep
        let absorbed = winner.drop

        let absorbedSamples = samplesByCluster.removeValue(forKey: absorbed) ?? []
        samplesByCluster[survivor, default: []].append(contentsOf: absorbedSamples)
        centroidByCluster[survivor] = Transcription.computeMeanEmbedding(samplesByCluster[survivor] ?? [])
        centroidByCluster[absorbed] = nil

        // Anything that previously resolved to the absorbed cluster now
        // resolves to the survivor.
        canonicalId = canonicalId.mapValues { $0 == absorbed ? survivor : $0 }

        AppLogger.transcription.info("Consolidated same-voice clusters", [
            "merged": "spk\(absorbed)",
            "into": "spk\(survivor)",
            "similarity": String(format: "%.3f", winningSimilarity)
        ])
    }

    let anythingMerged = canonicalId.contains { $0.key != $0.value }
    guard anythingMerged else { return segments }

    return segments.map { original in
        guard let target = canonicalId[original.speakerId], target != original.speakerId else {
            return original
        }
        return SpeakerSegment(
            speakerId: target,
            startTime: original.startTime,
            endTime: original.endTime,
            embedding: original.embedding,
            qualityScore: original.qualityScore
        )
    }
}

// MARK: - DB-Informed Split

/// Split clusters that contain 2+ known DB voices.
Expand Down
101 changes: 101 additions & 0 deletions Tests/TranscriptedCoreTests/EmbeddingClustererTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,107 @@ final class EmbeddingClustererTests: XCTestCase {
XCTAssertEqual(ids.filter { $0 == 2 }.count, 2)
}

/// Four large clusters with near-identical embeddings (cosine 0.95-1.0 to the
/// x axis) stand in for one voice the diarizer over-segmented; consolidation
/// must leave a single speaker ID.
func testConsolidateMergesOverSegmentedSameVoice() {
    let overSegmented = [
        segment(speakerId: 1, startTime: 0, endTime: 40, embedding: [1.0, 0.0]),
        segment(speakerId: 2, startTime: 40, endTime: 80, embedding: unitVector(cosineToXAxis: 0.99)),
        segment(speakerId: 3, startTime: 80, endTime: 120, embedding: unitVector(cosineToXAxis: 0.97)),
        segment(speakerId: 4, startTime: 120, endTime: 160, embedding: unitVector(cosineToXAxis: 0.95)),
    ]

    let consolidated = EmbeddingClusterer.consolidateSameVoiceClusters(
        segments: overSegmented,
        threshold: 0.88
    )

    let survivingIds = Set(consolidated.map(\.speakerId))
    XCTAssertEqual(survivingIds.count, 1)
}

/// Three voices spaced 66° apart are far below the 0.88 bar, so every
/// speaker ID must survive consolidation.
func testConsolidatePreservesDistinctSpeakers() {
    let distinctVoices = [
        segment(speakerId: 1, startTime: 0, endTime: 40, embedding: unitVector(degrees: 0)),
        segment(speakerId: 2, startTime: 40, endTime: 80, embedding: unitVector(degrees: 66)),
        segment(speakerId: 3, startTime: 80, endTime: 120, embedding: unitVector(degrees: 132)),
    ]

    let untouched = EmbeddingClusterer.consolidateSameVoiceClusters(
        segments: distinctVoices,
        threshold: 0.88
    )

    let survivingIds = Set(untouched.map(\.speakerId))
    XCTAssertEqual(survivingIds.count, 3)
}

/// A pair whose similarity sits on the 0.88 threshold must stay separate: the
/// consolidation pass only merges strictly above the bar, mirroring the
/// auto-accept edge so borderline voices still get reviewed.
func testConsolidateDoesNotMergeAtAutoAcceptBoundary() {
    let borderlinePair = [
        segment(speakerId: 1, startTime: 0, endTime: 40, embedding: [1.0, 0.0]),
        segment(speakerId: 2, startTime: 40, endTime: 80, embedding: unitVector(cosineToXAxis: 0.88)),
    ]

    let untouched = EmbeddingClusterer.consolidateSameVoiceClusters(
        segments: borderlinePair,
        threshold: 0.88
    )

    let survivingIds = Set(untouched.map(\.speakerId))
    XCTAssertEqual(survivingIds.count, 2)
}

/// Neighbours at 0°/20° and 20°/40° each clear the bar, but the endpoints do
/// not. After the first merge the recomputed centroid must keep the remaining
/// cluster out, leaving two speakers rather than one.
func testConsolidateDoesNotChainCollapseAcrossDissimilarEndpoints() {
    let chainOfNeighbours = [
        segment(speakerId: 1, startTime: 0, endTime: 40, embedding: unitVector(degrees: 0)),
        segment(speakerId: 2, startTime: 40, endTime: 80, embedding: unitVector(degrees: 20)),
        segment(speakerId: 3, startTime: 80, endTime: 120, embedding: unitVector(degrees: 40)),
    ]

    let consolidated = EmbeddingClusterer.consolidateSameVoiceClusters(
        segments: chainOfNeighbours,
        threshold: 0.88
    )

    let survivingIds = Set(consolidated.map(\.speakerId))
    XCTAssertEqual(
        survivingIds.count,
        2,
        "Recomputed centroids stop A≈B, B≈C from chain-collapsing into one speaker"
    )
}

/// End-to-end check through `postProcess`: one remote voice split into four
/// 40-second clusters collapses to a single speaker with the default
/// consolidation threshold, and stays split when consolidation is disabled.
func testPostProcessConsolidatesOneOnOneCallToSingleSpeaker() {
    // Four near-identical embeddings, one per over-segmented cluster.
    let voices: [[Float]] = [
        [1.0, 0.0],
        unitVector(cosineToXAxis: 0.99),
        unitVector(cosineToXAxis: 0.98),
        unitVector(cosineToXAxis: 0.97),
    ]
    let clusterSeconds = 40
    let segments = zip(voices.indices, voices).map { slot, voice in
        segment(
            speakerId: slot + 1,
            startTime: Double(slot * clusterSeconds),
            endTime: Double((slot + 1) * clusterSeconds),
            embedding: voice
        )
    }

    // Default consolidation threshold: everything merges.
    let consolidated = EmbeddingClusterer.postProcess(
        segments: segments,
        existingProfiles: [],
        pairwiseMergeThreshold: nil
    )
    let consolidatedIds = Set(consolidated.map(\.speakerId))
    XCTAssertEqual(
        consolidatedIds.count,
        1,
        "An over-segmented single remote voice should collapse to one speaker to name"
    )

    // Opting out with nil leaves the over-segmentation in place.
    let optedOut = EmbeddingClusterer.postProcess(
        segments: segments,
        existingProfiles: [],
        pairwiseMergeThreshold: nil,
        consolidationThreshold: nil
    )
    let optedOutIds = Set(optedOut.map(\.speakerId))
    XCTAssertEqual(optedOutIds.count, 4)
}

private func segment(
speakerId: Int,
startTime: Double,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ final class SpeakerNamingSimulationRunnerTests: XCTestCase {
segments: [
segment(.system, 1, truth: "alex", expected: "Alex Rivera", text: "alex one", start: 0, embedding: alex),
segment(.system, 2, truth: "blair", expected: "Blair Stone", text: "blair one", start: 3, embedding: blair),
segment(.system, 3, truth: "alex", expected: "Alex Rivera", text: "alex two", start: 6, embedding: alex)
segment(.system, 3, truth: "alex", expected: "Alex Rivera", text: "alex two", start: 6, embedding: casey)
],
actions: [
.name(channel: .system, diarizerSpeakerId: 1, as: "Alex Rivera"),
Expand Down