NOTE(review): this patch was recovered from a whitespace-collapsed copy, so line
breaks, indentation and blank context lines are reconstructed; re-check them
against the target files before applying. The `index` hashes are those of the
original patch and do not account for the `win = None` guard added to the
basis.py hunk.

diff --git a/speechscore/basis.py b/speechscore/basis.py
index 87ffe41c..90a55c4f 100644
--- a/speechscore/basis.py
+++ b/speechscore/basis.py
@@ -33,12 +33,21 @@ def scoring(self, data, window=None, score_rate=None):
                 audios[index] = audio
 
         if window is not None:
+            maxlen = len(audios[0])
             framer = Framing(window * score_rate, window * score_rate, maxlen)
             nwin = framer.nwin
             result = {}
+            # Stays None when the framer yields no window, so the tail check
+            # below cannot hit an unbound name.
+            win = None
             for (t, win) in enumerate(framer):
                 result_t = self.windowed_scoring([audio[win] for audio in audios], score_rate)
                 result[t] = result_t
+            # Score the samples left over after the last yielded window.
+            if win is not None and maxlen > win.stop:
+                last_win = slice(win.stop, maxlen)
+                result_t = self.windowed_scoring([audio[last_win] for audio in audios], score_rate)
+                result[nwin] = result_t
         else:
             result = self.windowed_scoring(audios, score_rate)
         return result
diff --git a/speechscore/scores/distill_mos/distill_mos.py b/speechscore/scores/distill_mos/distill_mos.py
index a47e7762..bacabdba 100644
--- a/speechscore/scores/distill_mos/distill_mos.py
+++ b/speechscore/scores/distill_mos/distill_mos.py
@@ -8,11 +8,13 @@ def __init__(self):
         super(DISTILL_MOS, self).__init__(name='DISTILL_MOS')
         self.intrusive = False
         self.score_rate = 16000
-        self.model = ConvTransformerSQAModel()
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model = ConvTransformerSQAModel().to(self.device)
         self.model.eval()
 
 
     def windowed_scoring(self, audios, score_rate):
-        score = self.model(torch.from_numpy(np.expand_dims(audios[0], axis=0)).float())
+        score = self.model(
+            torch.from_numpy(np.expand_dims(audios[0], axis=0)).float().to(self.device))
         score_np = score.detach().cpu().numpy()
         return score_np[0][0]
\ No newline at end of file
diff --git a/speechscore/scores/distill_mos/sqa.py b/speechscore/scores/distill_mos/sqa.py
index ce4aa6cc..4fb8a0a5 100644
--- a/speechscore/scores/distill_mos/sqa.py
+++ b/speechscore/scores/distill_mos/sqa.py
@@ -16,7 +16,7 @@
 SEQ_LEN = 122880
 MAX_HOP_LEN = 16000
 
-DEFAULT_WEIGHTS_CHKPT = os.path.join("scores/distill_mos/weights", "distill_mos_v7.pt")
+DEFAULT_WEIGHTS_CHKPT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "weights", "distill_mos_v7.pt")
 
 
 def _complex_compressed(x, hop_length, win_length):
diff --git a/speechscore/scores/dnsmos/dnsmos.py b/speechscore/scores/dnsmos/dnsmos.py
index 358bdb7e..55098c47 100644
--- a/speechscore/scores/dnsmos/dnsmos.py
+++ b/speechscore/scores/dnsmos/dnsmos.py
@@ -11,14 +11,15 @@
 from basis import ScoreBasis
 
 
+module_dir = os.path.dirname(os.path.abspath(__file__))
 
 class DNSMOS(ScoreBasis):
     def __init__(self):
         super(DNSMOS, self).__init__(name='DNSMOS')
         self.intrusive = True
         self.score_rate = 16000
-        self.p808_model_path = os.path.join('scores/dnsmos/DNSMOS', 'model_v8.onnx')
-        self.primary_model_path = os.path.join('scores/dnsmos/DNSMOS', 'sig_bak_ovr.onnx')
+        self.p808_model_path = os.path.join(module_dir, 'DNSMOS', 'model_v8.onnx')
+        self.primary_model_path = os.path.join(module_dir, 'DNSMOS', 'sig_bak_ovr.onnx')
         self.compute_score = ComputeScore(self.primary_model_path, self.p808_model_path)
 
     def windowed_scoring(self, audios, rate):
@@ -26,8 +27,9 @@ def windowed_scoring(self, audios, rate):
 
 class ComputeScore:
     def __init__(self, primary_model_path, p808_model_path) -> None:
-        self.onnx_sess = ort.InferenceSession(primary_model_path)
-        self.p808_onnx_sess = ort.InferenceSession(p808_model_path)
+        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+        self.onnx_sess = ort.InferenceSession(primary_model_path, providers=providers)
+        self.p808_onnx_sess = ort.InferenceSession(p808_model_path, providers=providers)
 
     def audio_melspec(self, audio, n_mels=120, frame_size=320, hop_length=160, sr=16000, to_db=True):
         mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=frame_size+1, hop_length=hop_length, n_mels=n_mels)
diff --git a/speechscore/scores/nisqa/nisqa.py b/speechscore/scores/nisqa/nisqa.py
index 4c5b619a..0d9a1184 100644
--- a/speechscore/scores/nisqa/nisqa.py
+++ b/speechscore/scores/nisqa/nisqa.py
@@ -1,12 +1,14 @@
 from basis import ScoreBasis
 from scores.nisqa.cal_nisqa import load_nisqa_model
-
+import os
+import torch
 class NISQA(ScoreBasis):
     def __init__(self):
         super(NISQA, self).__init__(name='NISQA')
         self.intrusive = False
         self.score_rate = 48000
-        self.model = load_nisqa_model("scores/nisqa/weights/nisqa.tar", device='cpu')
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model = load_nisqa_model(os.path.join(os.path.dirname(os.path.abspath(__file__)), "weights/nisqa.tar"), device=device)
 
     def windowed_scoring(self, audios, score_rate):
         from scores.nisqa.cal_nisqa import cal_NISQA