Skip to content

Commit 42e3c1c

Browse files
committed
fix bug in plot_alignment
1 parent d2a6a4d commit 42e3c1c

4 files changed

Lines changed: 72 additions & 47 deletions

File tree

matchmaker/matchmaker.py

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -551,38 +551,45 @@ def run_evaluation(
551551
self,
552552
perf_annotations: Union[PathLike, np.ndarray],
553553
level: str = "note",
554-
tolerances: list = TOLERANCES_IN_MILLISECONDS,
555-
musical_beat: bool = False, # beat annots are difference in some dataset
554+
tolerances: list = None,
555+
musical_beat: bool = False,
556556
debug: bool = False,
557557
save_dir: PathLike = None,
558558
run_name: str = None,
559-
domain: str = "performance", # "score" or "performance"
559+
domain: str = "score",
560560
plot_dist_matrix: bool = True,
561561
) -> dict:
562562
"""
563-
Evaluate the score following process
563+
Evaluate the score following process.
564+
565+
When domain="score" (default), returns beat-based metrics as primary
566+
and ms-based metrics under "ms" key. When domain="performance",
567+
returns ms-based metrics only (legacy behavior).
564568
565569
Parameters
566570
----------
567571
perf_annotations : PathLike or np.ndarray
568-
Path to the performance annotations file (tab-separated),
569-
or numpy array of annotation times in seconds.
572+
Path to the performance annotations file or numpy array of onset times (seconds).
570573
level : str
571-
Level of annotations to use: bar, beat or note
572-
tolerance : list
573-
Tolerances to use for evaluation (in milliseconds)
574+
Annotation level: "beat" or "note"
575+
tolerances : list or None
576+
Tolerances for evaluation. If None, uses default for the domain.
577+
musical_beat : bool
578+
Whether to use musical beat
574579
debug : bool
575-
Whether to save the score and performance audio with beat annotations
580+
Whether to save debug outputs
576581
domain : str
577-
Evaluation domain, either "score" or "performance".
578-
"score" domain evaluates in beat unit, "performance" domain evaluates in second unit. (Default: "performance")
582+
"score" (default, beat-based primary) or "performance" (ms-based, legacy)
579583
580584
Returns
581585
-------
582586
dict
583-
Evaluation results with mean, median, std, skewness, kurtosis, and
584-
accuracy for each tolerance
587+
Evaluation results. If domain="score", includes both beat and ms metrics.
585588
"""
589+
if tolerances is None:
590+
tolerances = (
591+
TOLERANCES_IN_BEATS if domain == "score" else TOLERANCES_IN_MILLISECONDS
592+
)
586593
if not self._has_run:
587594
raise ValueError("Must call run() before evaluation")
588595

@@ -643,26 +650,43 @@ def run_evaluation(
643650
wp_perf_sec,
644651
total_counts=len(wp_score),
645652
tolerances=tolerances,
646-
perf_times=wp_perf_sec,
647-
alignment_duration=self.alignment_duration,
648653
)
649654
else:
650-
# Score domain: compare predicted beats vs GT beats
655+
# Score domain: beat-based (primary) + ms-based (secondary)
651656
score_annots_predicted = transfer_positions(
652657
wp, perf_annots, frame_rate=self.frame_rate, domain="score"
653658
)
654659
score_annots = score_annots[: len(score_annots_predicted)]
655-
if tolerances == TOLERANCES_IN_MILLISECONDS:
656-
tolerances = TOLERANCES_IN_BEATS
657-
eval_results = get_evaluation_results(
660+
beat_tolerances = (
661+
tolerances
662+
if tolerances != TOLERANCES_IN_MILLISECONDS
663+
else TOLERANCES_IN_BEATS
664+
)
665+
beat_results = get_evaluation_results(
658666
score_annots,
659667
score_annots_predicted,
660668
total_counts=original_perf_annots_counts,
661-
tolerances=tolerances,
669+
tolerances=beat_tolerances,
662670
in_seconds=False,
663-
perf_times=perf_annots,
664-
alignment_duration=self.alignment_duration,
665671
)
672+
ms_results = get_evaluation_results(
673+
gt_perf_times,
674+
wp_perf_sec,
675+
total_counts=len(wp_score),
676+
tolerances=TOLERANCES_IN_MILLISECONDS,
677+
)
678+
eval_results = {"beat": beat_results, "ms": ms_results}
679+
680+
# Real-Time Factor (domain-independent)
681+
if self.alignment_duration is not None:
682+
finite_perf = perf_annots[np.isfinite(perf_annots)]
683+
if len(finite_perf) > 0:
684+
perf_duration = float(np.max(finite_perf) - np.min(finite_perf))
685+
if perf_duration > 0:
686+
eval_results["rtf"] = float(
687+
f"{self.alignment_duration / perf_duration:.4f}"
688+
)
689+
666690
if self.input_type == "audio":
667691
latency_results = self.get_latency_stats()
668692
eval_results.update(latency_results)

matchmaker/utils/eval.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,6 @@ def get_evaluation_results(
108108
total_counts,
109109
tolerances=TOLERANCES_IN_MILLISECONDS,
110110
in_seconds=True,
111-
perf_times=None,
112-
alignment_duration=None,
113111
):
114112
if in_seconds:
115113
errors_in_delay = (gt_annots - predicted_annots) * 1000
@@ -141,12 +139,4 @@ def get_evaluation_results(
141139
f"{np.sum(np.abs(errors_in_delay) <= tau) / total_counts:.4f}"
142140
)
143141

144-
# Real-Time Factor (wall-clock alignment_duration / performance_duration).
145-
if alignment_duration is not None and perf_times is not None:
146-
finite_perf = perf_times[np.isfinite(perf_times)]
147-
if len(finite_perf) > 0:
148-
perf_duration = float(np.max(finite_perf) - np.min(finite_perf))
149-
if perf_duration > 0:
150-
results["rtf"] = float(f"{alignment_duration / perf_duration:.4f}")
151-
152142
return results

matchmaker/utils/misc.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,6 @@ def plot_alignment(
584584

585585
# x-axis: performance time in frames
586586
x_gt = gt * float(frame_rate)
587-
x_pred = pred * float(frame_rate)
588587
wp_x = warping_path[1]
589588

590589
# y-axis: score position (beats)
@@ -596,12 +595,21 @@ def plot_alignment(
596595
else:
597596
wp_y = warping_path[0]
598597

598+
# GT score positions (y-axis for annotation dots)
599599
if score_y is not None:
600-
y = np.asarray(score_y, dtype=float)[:n]
600+
y_gt = np.asarray(score_y, dtype=float)[:n]
601601
if show_dist and wp_in_beats and ref_frame_to_beat is not None:
602-
y = _beats_to_frames(y, ref_frame_to_beat)
602+
y_gt = _beats_to_frames(y_gt, ref_frame_to_beat)
603603
else:
604-
y = np.arange(n)
604+
y_gt = np.arange(n)
605+
606+
# Predicted score positions at GT perf times (perf→score direction)
607+
wp_x_sorted = np.asarray(wp_x, dtype=float)
608+
wp_y_sorted = np.asarray(wp_y, dtype=float)
609+
if len(wp_x_sorted) > 1:
610+
y_pred = np.interp(x_gt, wp_x_sorted, wp_y_sorted)
611+
else:
612+
y_pred = y_gt
605613

606614
# Plot layers
607615
ax.plot(
@@ -615,8 +623,8 @@ def plot_alignment(
615623
zorder=2,
616624
)
617625
ax.scatter(
618-
x_pred,
619-
y,
626+
x_gt,
627+
y_pred,
620628
label="predicted",
621629
s=80,
622630
alpha=0.9,
@@ -627,7 +635,7 @@ def plot_alignment(
627635
)
628636
ax.scatter(
629637
x_gt,
630-
y,
638+
y_gt,
631639
label="ground truth",
632640
s=120,
633641
alpha=0.9,
@@ -643,11 +651,14 @@ def plot_alignment(
643651

644652
# Beat tick labels when projected to frame space
645653
if show_dist and wp_in_beats and ref_frame_to_beat is not None:
646-
beat_min, beat_max = ref_frame_to_beat[0], ref_frame_to_beat[-1]
654+
finite_beats = ref_frame_to_beat[np.isfinite(ref_frame_to_beat)]
655+
beat_min, beat_max = (
656+
finite_beats[0],
657+
finite_beats[-1] if len(finite_beats) > 0 else (0, 1),
658+
)
659+
n_ticks = max(2, min(12, int(beat_max - beat_min) + 1))
647660
beat_ticks = np.unique(
648-
np.round(
649-
np.linspace(beat_min, beat_max, min(12, int(beat_max - beat_min) + 1))
650-
).astype(int)
661+
np.round(np.linspace(beat_min, beat_max, n_ticks)).astype(int)
651662
)
652663
ax.set_yticks(_beats_to_frames(beat_ticks.astype(float), ref_frame_to_beat))
653664
ax.set_yticklabels([str(b) for b in beat_ticks])

tests/test_matchmaker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def test_matchmaker_audio_run_with_evaluation(self):
130130

131131
# Then: the results should at least be 0.5
132132
for threshold in ["300ms", "500ms", "1000ms"]:
133-
self.assertGreaterEqual(results[threshold], 0.5)
133+
self.assertGreaterEqual(results["ms"][threshold], 0.5)
134134

135135
def test_matchmaker_audio_run_with_evaluation_cqt(self):
136136
# Given: a Matchmaker instance with audio input
@@ -159,7 +159,7 @@ def test_matchmaker_audio_run_with_evaluation_cqt(self):
159159

160160
# Then: the results should at least be 0.5
161161
for threshold in ["300ms", "500ms", "1000ms"]:
162-
self.assertGreaterEqual(results[threshold], 0.5)
162+
self.assertGreaterEqual(results["ms"][threshold], 0.5)
163163

164164
def test_matchmaker_audio_run_with_evaluation_in_beats(self):
165165
# Given: a Matchmaker instance with audio input
@@ -184,7 +184,7 @@ def test_matchmaker_audio_run_with_evaluation_in_beats(self):
184184

185185
# Then: the results should at least be 0.5
186186
for threshold in ["0.3b", "0.5b", "1b"]:
187-
self.assertGreaterEqual(results[threshold], 0.5)
187+
self.assertGreaterEqual(results["beat"][threshold], 0.5)
188188

189189
def test_matchmaker_audio_run_with_evaluation_before_run(self):
190190
# Given: a Matchmaker instance with audio input

0 commit comments

Comments (0)