Skip to content

Commit 34f0485

Browse files
committed
fix per-frame evaluation based on eval_symbolic
1 parent 796d012 commit 34f0485

1 file changed

Lines changed: 41 additions & 64 deletions

File tree

matchmaker/matchmaker.py

Lines changed: 41 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88
import partitura
9+
import scipy.interpolate
910
from partitura.io.exportmidi import get_ppq
1011
from partitura.musicanalysis.performance_codec import get_time_maps_from_alignment
1112
from partitura.score import Part, merge_parts
@@ -73,7 +74,7 @@
7374
"step_size": 3,
7475
},
7576
"audio_outerhmm": {
76-
"sample_rate": 16000,
77+
"sample_rate": 24000,
7778
"frame_rate": 25,
7879
"s_j": 0.0,
7980
},
@@ -594,81 +595,57 @@ def run_evaluation(
594595
# perf_annots = perf_annots[:min_length]
595596

596597
wp = self.score_follower.warping_path
597-
mode = "state" if self.input_type == "midi" else "beat"
598+
score_annots_beats = self.build_score_annotations(
599+
level, musical_beat, return_type="beats"
600+
)
598601

599-
if mode == "beat":
600-
# Beat mode forward needs beat positions to match against wp[0].
601-
score_annots_beats = self.build_score_annotations(
602-
level, musical_beat, return_type="beats"
603-
)
604-
perf_annots_predicted = transfer_from_score_to_predicted_perf(
605-
wp,
606-
score_annots_beats,
607-
frame_rate=self.frame_rate,
608-
mode=mode,
609-
)
602+
# --- Per-frame evaluation ---
603+
# Build GT interpolator: score beat → perf time (seconds)
604+
valid_gt = np.isfinite(perf_annots)
605+
gt_interp = scipy.interpolate.interp1d(
606+
score_annots_beats[valid_gt],
607+
perf_annots[valid_gt],
608+
bounds_error=False,
609+
fill_value=np.nan,
610+
)
611+
612+
wp_score = wp[0].astype(float)
613+
wp_perf = wp[1].astype(float)
614+
615+
# Convert wp perf axis to seconds
616+
if self.input_type == "midi":
617+
# MIDI: wp_perf is IOI-accumulated from 0; shift by first note onset
618+
_perf = partitura.load_performance_midi(self.performance_file)
619+
midi_offset = float(_perf.note_array()["onset_sec"].min())
620+
wp_perf_sec = wp_perf + midi_offset
610621
else:
611-
perf_annots_predicted = transfer_from_score_to_predicted_perf(
612-
wp,
613-
score_annots,
614-
frame_rate=self.frame_rate,
615-
mode=mode,
616-
)
622+
# Audio: wp_perf is frame index
623+
wp_perf_sec = wp_perf / self.frame_rate
617624

618-
score_annots_predicted = transfer_from_perf_to_predicted_score(
625+
# For each wp entry: GT perf time for predicted beat vs actual perf time
626+
gt_perf_times = gt_interp(wp_score)
627+
perf_annots_predicted = transfer_from_score_to_predicted_perf(
619628
wp,
620-
perf_annots,
629+
score_annots_beats,
621630
frame_rate=self.frame_rate,
622-
mode=mode,
631+
mode="beat",
623632
)
624-
score_annots = score_annots[: len(score_annots_predicted)]
625633

626-
if original_perf_annots_counts != len(perf_annots_predicted):
627-
print(
628-
f"Length of the annotation changed: {original_perf_annots_counts} -> {len(perf_annots_predicted)}"
629-
)
630-
631-
# Evaluation metrics
632634
if domain == "performance":
633635
eval_results = get_evaluation_results(
634-
perf_annots,
635-
perf_annots_predicted,
636-
total_counts=original_perf_annots_counts,
636+
gt_perf_times,
637+
wp_perf_sec,
638+
total_counts=len(wp_score),
637639
tolerances=tolerances,
638-
perf_times=perf_annots,
640+
perf_times=wp_perf_sec,
639641
alignment_duration=self.alignment_duration,
640642
)
641643
else:
642-
if mode == "beat":
643-
# Beat mode reverse already returns beat positions directly.
644-
# score_annots_predicted was already computed above with mode="beat",
645-
# which returns beats directly from transfer_positions.
646-
pass
647-
elif mode == "state":
648-
# State mode reverse returns state indices (via causal lookup).
649-
# Map state indices directly to beats through state_space.
650-
state_space = self.score_follower.state_space
651-
raw_states = transfer_from_perf_to_predicted_score(
652-
wp,
653-
perf_annots,
654-
frame_rate=self.frame_rate,
655-
mode=mode,
656-
output="frames",
657-
)
658-
score_annots_predicted = np.array(
659-
[
660-
(
661-
float(state_space[int(s)])
662-
if not np.isnan(s) and 0 <= int(s) < len(state_space)
663-
else np.nan
664-
)
665-
for s in raw_states
666-
]
667-
)
668-
else:
669-
score_annots_predicted = self.convert_timestamps_to_beats(
670-
score_annots_predicted
671-
)
644+
# Score domain: compare predicted beats vs GT beats
645+
score_annots_predicted = transfer_from_perf_to_predicted_score(
646+
wp, perf_annots, frame_rate=self.frame_rate, mode="beat"
647+
)
648+
score_annots = score_annots[: len(score_annots_predicted)]
672649
if tolerances == TOLERANCES_IN_MILLISECONDS:
673650
tolerances = TOLERANCES_IN_BEATS
674651
eval_results = get_evaluation_results(
@@ -687,7 +664,7 @@ def run_evaluation(
687664
# Debug: save warping path TSV, results JSON, and plots
688665
if debug and save_dir is not None:
689666
# For plot y-axis: use beats when wp[0] is in beats
690-
debug_score_annots = score_annots_beats if mode == "beat" else score_annots
667+
debug_score_annots = score_annots_beats
691668
save_debug_results(
692669
warping_path=self.score_follower.warping_path,
693670
score_annots=debug_score_annots,

0 commit comments

Comments
 (0)