diff --git a/python/README.md b/python/README.md
index 0d4a311..884a1f8 100644
--- a/python/README.md
+++ b/python/README.md
@@ -5,7 +5,9 @@
-- The model comes from Alibaba DAMO Academy's [Paraformer语音识别-中文-通用-16k-离线-large-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
+- The ASR model comes from Alibaba DAMO Academy's [Paraformer语音识别-中文-通用-16k-离线-large-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
+- The VAD model (FSMN-VAD) comes from Alibaba DAMO Academy's [FSMN语音端点检测-中文-通用-16k](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary)
+- The punctuation model (CT-Transformer) comes from Alibaba DAMO Academy's [CT-Transformer标点-中文-通用-pytorch](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary)
- 🎉 The core code of this project has been merged into [FunASR](https://github.com/alibaba-damo-academy/FunASR)
- This repository only converts the models and runs inference with ONNXRuntime alone
@@ -52,17 +54,18 @@
3. Run the demo
```python
from rapid_paraformer import RapidParaformer
config_path = 'resources/config.yaml'
paraformer = RapidParaformer(config_path)
-
+
 # Input: supports Union[str, np.ndarray, List[str]]
 # Output: List[asr_res]
wav_path = [
'test_wavs/0478_00017.wav',
]
-
+
result = paraformer(wav_path)
print(result)
```
@@ -71,3 +74,14 @@
['呃说不配合就不配合的好以上的话呢我们摘取八九十三条因为这三条的话呢比较典型啊一些数字比较明确尤其是时间那么我们要投资者就是了解这一点啊不要轻信这个市场可以快速回来啊这些配市公司啊后期又利好了可
以快速快速攻能包括像前一段时间啊有些媒体在二三月份的时候']
```
+
+Update notes:
+
+1. Added VAD and punctuation (Punc) support.
+
+Most of the newly added code comes from [FunASR](https://github.com/alibaba-damo-academy/FunASR).
+
+For model export, see [here](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export); put the exported model.onnx into the corresponding folder.
+
+The demo combines the three models. The VAD output is currently not very good, so the demo simply splits the audio into fixed 30-second chunks, recognizes each chunk, and concatenates the results; see the sketch below.
+
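+A minimal sketch of the chunked pipeline (the wav path is a placeholder; the exported ONNX models must already be in place):
+
+```python
+import librosa
+
+from rapid_paraformer import RapidParaformer
+from rapid_paraformer.rapid_punc import PuncParaformer
+
+paraformer = RapidParaformer()
+punc = PuncParaformer()
+
+# Decode at 16 kHz and split into fixed 30 s chunks
+y, _ = librosa.load('test_wavs/long_audio.wav', sr=16000)
+chunks = (y[i:i + 16000 * 30] for i in range(0, len(y), 16000 * 30))
+
+# Recognize each chunk, concatenate, then restore punctuation
+parts = []
+for chunk in chunks:
+    res = paraformer(chunk)
+    if res and res[0]:
+        parts.append(res[0][0])
+text, _ = punc(''.join(parts))
+print(text)
+```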
diff --git a/python/demo.py b/python/demo.py
index c5170e6..8943300 100644
--- a/python/demo.py
+++ b/python/demo.py
@@ -1,23 +1,170 @@
# -*- encoding: utf-8 -*-
-# @Author: SWHL
-# @Contact: liekkaskono@163.com
+
from rapid_paraformer import RapidParaformer
+from rapid_paraformer.rapid_punc import PuncParaformer
+from rapid_paraformer.rapid_vad import RapidVad
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor
+
+import librosa
+import moviepy.editor as mp
+
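+# Instantiate the three models once; the underlying ONNX sessions are reusable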
+vad_model = RapidVad()
+paraformer = RapidParaformer()
+punc = PuncParaformer()
+
+# Decorator that prints a function's wall-clock running time
+def timeit(func):
+    def wrapper(*args, **kwargs):
+        start = time.time()
+        print(f"function name: {func.__name__}")
+        result = func(*args, **kwargs)
+        elapsed = time.time() - start
+        print(f"cost time: {elapsed:.3f}s")
+        return result
+    return wrapper
+# Return the audio duration formatted as HH:MM:SS
+@timeit
+def get_audio_duration(wav_path):
+    import wave
+    with wave.open(wav_path, 'rb') as f:
+        nchannels, sampwidth, framerate, nframes = f.getparams()[:4]
+    duration = nframes / framerate
+    # Convert to HH:MM:SS
+    m, s = divmod(duration, 60)
+    h, m = divmod(m, 60)
+    return "%02d:%02d:%02d" % (h, m, s)
+
+# Load audio and split it into fixed 30 s chunks
+@timeit
+def load_audio(wav_path):
+    '''
+    Load an audio file.
+    :param wav_path: path to the audio file
+    :return: generator of 30 s waveform chunks
+    '''
+    # mp4 inputs are converted to wav beforehand (see __main__ below)
+    y, sr = librosa.load(wav_path, sr=16000)  # decode the audio into a 16 kHz waveform array
+    # Lazily yield fixed 30 s chunks (16000 samples per second)
+    wav_list = (y[i:i + 16000 * 30] for i in range(0, len(y), 16000 * 30))
+    return wav_list
+
+def split_string(text, length):
+ """
+    Split a string into chunks of the given length and return them as a list.
+ """
+ result = []
+ start = 0
+ while start < len(text):
+ end = start + length
+ result.append(text[start:end])
+ start = end
+ return result
+
+def text_process(text):
+    # Insert a line break after each full stop and question mark
+ text = text.replace('。', '。\n')
+ text = text.replace('?', '?\n')
+ return text
+
+
+
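+# Alternative loader: slice the decoded waveform into fixed chunks using worker threads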
+@timeit
+def load_and_cut_audio(wav_path, num_threads=4, chunk_size=30):
+ y, sr = librosa.load(wav_path, sr=16000)
+ n_samples = len(y)
+ chunk_size_samples = 16000 * chunk_size
+ num_chunks = n_samples // chunk_size_samples + (n_samples % chunk_size_samples > 0)
+ chunk_indices = [(i * chunk_size_samples, min(n_samples, (i + 1) * chunk_size_samples)) for i in range(num_chunks)]
+
+ results = [None] * num_chunks
+
+ def load_and_cut_thread(start_index, end_index, result_list, index):
+ result_list[index] = y[start_index:end_index]
+
+ threads = []
+ for i, (start_index, end_index) in enumerate(chunk_indices):
+ t = threading.Thread(target=load_and_cut_thread, args=(start_index, end_index, results, i))
+ threads.append(t)
+ t.start()
+
+ if i % num_threads == num_threads - 1:
+ for thread in threads[i - num_threads + 1:i + 1]:
+ thread.join()
+
+ for thread in threads[(num_chunks - 1) // num_threads * num_threads:]:
+ thread.join()
+
+ return results
+
+
+
+@timeit
+def vad(vad_model, wav_path):
+ return vad_model(wav_path)
+
+def asr_single(wav):
+    try:
+        result_text = paraformer(wav)[0][0]
+    except Exception:
+        result_text = ''
+    return result_text
+
+@timeit
+def asr(wav_path):
+    wav_list = load_audio(wav_path)
+    # Recognize the 30 s chunks concurrently with a thread pool
+    with ThreadPoolExecutor() as executor:
+        result = executor.map(asr_single, wav_list)
+
+    return result
+
+if __name__ == '__main__':
+    wave_path = r'C:\Users\ADMINI~1\AppData\Local\Temp\gradio\d5e738ea910657f76c96e6fbfb74f7de8c6fdb11\11.mp3'
+    # wave_path = r'E:\10分钟.wav'
+    if wave_path.endswith('.mp4'):
+        wav_path = wave_path.replace('.mp4', '.wav')
+        clip = mp.VideoFileClip(wave_path)
+        clip.audio.write_audiofile(wav_path, fps=22050, bitrate='64k')  # write the clip's audio track to a wav file
+        print('mp4 -> wav done')
+        print(wav_path)
+        print(clip.duration)
+        wave_path = wav_path  # recognize the converted wav instead of the mp4
+
+    # Audio duration
+    # duration = get_audio_duration(wave_path)
+    # print(f"duration: {duration}")
+    # VAD
+    # vad_result = vad(vad_model, wave_path)
+    # print(f"vad result: {vad_result}")
+
+    # ASR
+    asr_result = list(asr(wave_path))
+    print('asr done')
+    print(asr_result)
+    # Punctuation restoration
+    new_text = punc(''.join(asr_result))
+    processed_text = text_process(new_text[0])
+    print(f"punctuated result: {processed_text}")
+    # Write the result to a txt file named after the audio
+    # with open(f'{wave_path.replace(".mp4", "")}.txt', 'w') as f:
+    #     f.write(processed_text)
-config_path = 'resources/config.yaml'
-paraformer = RapidParaformer(config_path)
-wav_path = [
- 'test_wavs/0478_00017.wav',
- 'test_wavs/asr_example_zh.wav',
- 'test_wavs/0478_00017.wav',
- 'test_wavs/asr_example_zh.wav',
- 'test_wavs/0478_00017.wav',
- 'test_wavs/asr_example_zh.wav',
-]
-print(wav_path)
-# wav_path = 'test_wavs/0478_00017.wav'
-result = paraformer(wav_path)
-print(result)
diff --git a/python/rapid_paraformer/punc_model/punc.yaml b/python/rapid_paraformer/punc_model/punc.yaml
new file mode 100644
index 0000000..2c0f6ac
--- /dev/null
+++ b/python/rapid_paraformer/punc_model/punc.yaml
@@ -0,0 +1,19 @@
+punc_list:
+-
+- _
+- ','
+- 。
+- '?'
+- 、
+
+TokenIDConverter:
+ token_path: punc_model/punc_token_list.pkl
+ unk_symbol:
+Model:
+ model_path: punc_model/model.onnx
+ use_cuda: false
+ CUDAExecutionProvider:
+ device_id: 0
+ arena_extend_strategy: kNextPowerOfTwo
+ cudnn_conv_algo_search: EXHAUSTIVE
+ do_copy_in_default_stream: true
\ No newline at end of file
diff --git a/python/rapid_paraformer/punc_model/punc_token_list.pkl b/python/rapid_paraformer/punc_model/punc_token_list.pkl
new file mode 100644
index 0000000..4dc4756
Binary files /dev/null and b/python/rapid_paraformer/punc_model/punc_token_list.pkl differ
diff --git a/python/rapid_paraformer/rapid_paraformer.py b/python/rapid_paraformer/rapid_paraformer.py
index 34b3692..bad2843 100644
--- a/python/rapid_paraformer/rapid_paraformer.py
+++ b/python/rapid_paraformer/rapid_paraformer.py
@@ -1,62 +1,88 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
-import traceback
from pathlib import Path
-from typing import List, Union, Tuple
+from typing import List, Union
import librosa
import numpy as np
-from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
- OrtInferSession, TokenIDConverter, WavFrontend, get_logger,
- read_yaml)
+from .utils import (CharTokenizer, Hypothesis, OrtInferSession,
+ TokenIDConverter, WavFrontend, read_yaml)
-logging = get_logger()
+cur_dir = Path(__file__).resolve().parent
class RapidParaformer():
- def __init__(self, config_path: Union[str, Path]) -> None:
- if not Path(config_path).exists():
- raise FileNotFoundError(f'{config_path} does not exist.')
-
- config = read_yaml(config_path)
+    def __init__(self, config_path: str = None) -> None:
+        # Fall back to the bundled config when no path is given
+        if config_path is None:
+            config_path = cur_dir / 'config.yaml'
+        config = read_yaml(config_path)
self.converter = TokenIDConverter(**config['TokenIDConverter'])
self.tokenizer = CharTokenizer(**config['CharTokenizer'])
- self.frontend = WavFrontend(
+ self.frontend_asr = WavFrontend(
cmvn_file=config['WavFrontend']['cmvn_file'],
**config['WavFrontend']['frontend_conf']
)
self.ort_infer = OrtInferSession(config['Model'])
- self.batch_size = config['Model']['batch_size']
- def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List:
- waveform_list = self.load_data(wav_content)
- waveform_nums = len(waveform_list)
+    def __call__(self, wav_path: Union[str, np.ndarray]) -> List:
+        if isinstance(wav_path, str):
+            # Load the file and add a batch dim; resample to the model's 16 kHz rate
+            waveform = librosa.load(wav_path, sr=16000)[0][None, ...]
+        elif isinstance(wav_path, np.ndarray):
+            waveform = self.load_data(wav_path)[0][None, ...]  # also accept a raw numpy waveform
+        else:
+            raise TypeError('wav_path must be str or numpy.ndarray')
+
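+        # Pipeline: fbank features -> LFR + CMVN -> acoustic model scores -> greedy decode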
+ speech, _ = self.frontend_asr.forward_fbank(waveform)
+ feats, feats_len = self.frontend_asr.forward_lfr_cmvn(speech)
+        try:
+            am_scores = self.ort_infer(input_content=[feats, feats_len])
+        except Exception:
+            # Inference failed (e.g. the input is silence or noise); return an empty result
+            return [[]]
+
+
+ results = []
+ for am_score in am_scores:
+ pred_res = self.infer_one_feat(am_score)
+ results.append(pred_res)
+ return results
+
+ def infer_one_feat(self, am_score: np.ndarray) -> List[str]:
+ yseq = am_score.argmax(axis=-1)
+ score = am_score.max(axis=-1)
+ score = np.sum(score, axis=-1)
+
+ # pad with mask tokens to ensure compatibility with sos/eos tokens
+ # asr_model.sos:1 asr_model.eos:2
+ yseq = np.array([1] + yseq.tolist() + [2])
+ nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
- asr_res = []
- for beg_idx in range(0, waveform_nums, self.batch_size):
- end_idx = min(waveform_nums, beg_idx + self.batch_size)
+ infer_res = []
+ for hyp in nbest_hyps:
+ # remove sos/eos and get results
+ last_pos = -1
+ token_int = hyp.yseq[1:last_pos].tolist()
- feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x not in (0, 2), token_int))
- try:
- am_scores, valid_token_lens = self.infer(feats, feats_len)
- except ONNXRuntimeError:
- logging.warning("input wav is silence or noise")
- preds = []
- else:
- preds = self.decode(am_scores, valid_token_lens)
+ # Change integer-ids to tokens
+ token = self.converter.ids2tokens(token_int)
- asr_res.extend(preds)
- return asr_res
+ text = self.tokenizer.tokens2text(token)
+ infer_res.append(text)
+ # print(infer_res)
+ return infer_res
def load_data(self,
- wav_content: Union[str, np.ndarray, List[str]]) -> List:
+ wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
def load_wav(path: str) -> np.ndarray:
- waveform, _ = librosa.load(path, sr=None)
- return waveform[None, ...]
+ waveform, _ = librosa.load(path, sr=fs)
+ return waveform
if isinstance(wav_content, np.ndarray):
return [wav_content]
@@ -69,71 +95,10 @@ def load_wav(path: str) -> np.ndarray:
raise TypeError(
f'The type of {wav_content} is not in [str, np.ndarray, list]')
-
- def extract_feat(self,
- waveform_list: List[np.ndarray]
- ) -> Tuple[np.ndarray, np.ndarray]:
- feats, feats_len = [], []
- for waveform in waveform_list:
- speech, _ = self.frontend.fbank(waveform)
- feat, feat_len = self.frontend.lfr_cmvn(speech)
- feats.append(feat)
- feats_len.append(feat_len)
-
- feats = self.pad_feats(feats, np.max(feats_len))
- feats_len = np.array(feats_len).astype(np.int32)
- return feats, feats_len
-
- @staticmethod
- def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
- def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
- pad_width = ((0, max_feat_len - cur_len), (0, 0))
- return np.pad(feat, pad_width, 'constant', constant_values=0)
-
- feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
- feats = np.array(feat_res).astype(np.float32)
- return feats
-
- def infer(self, feats: np.ndarray,
- feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
- am_scores, token_nums = self.ort_infer([feats, feats_len])
- return am_scores, token_nums
-
- def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
- return [self.decode_one(am_score, token_num)
- for am_score, token_num in zip(am_scores, token_nums)]
-
- def decode_one(self,
- am_score: np.ndarray,
- valid_token_num: int) -> List[str]:
- yseq = am_score.argmax(axis=-1)
- score = am_score.max(axis=-1)
- score = np.sum(score, axis=-1)
-
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- # asr_model.sos:1 asr_model.eos:2
- yseq = np.array([1] + yseq.tolist() + [2])
- hyp = Hypothesis(yseq=yseq, score=score)
-
- # remove sos/eos and get results
- last_pos = -1
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x not in (0, 2), token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
- text = self.tokenizer.tokens2text(token)
- return text[:valid_token_num-1]
-
-
if __name__ == '__main__':
- project_dir = Path(__file__).resolve().parent.parent
- cfg_path = project_dir / 'resources' / 'config.yaml'
- paraformer = RapidParaformer(cfg_path)
+ paraformer = RapidParaformer()
- wav_file = '0478_00017.wav'
+    wav_file = 'test_wavs/0478_00017.wav'  # path relative to the project root
for i in range(1000):
result = paraformer(wav_file)
print(result)
diff --git a/python/rapid_paraformer/rapid_punc.py b/python/rapid_paraformer/rapid_punc.py
new file mode 100644
index 0000000..514e8ff
--- /dev/null
+++ b/python/rapid_paraformer/rapid_punc.py
@@ -0,0 +1,116 @@
+# -*- coding: UTF-8 -*-
+'''
+Project -> File :RapidASR-2.0.0 -> rapid_punc.py
+Author :standy
+Date :2023/5/3 11:45
+'''
+
+from pathlib import Path
+from typing import Union, Tuple
+import numpy as np
+
+from .utils import (OrtInferSession, read_yaml)
+from .utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words)
+import warnings
+from typeguard import check_argument_types
+cur_dir = Path(__file__).resolve().parent
+
+class PuncParaformer():
+
+ def __init__(self, config_path: str = None) -> None:
+ config = read_yaml(cur_dir / 'punc_model/punc.yaml')
+
+ if config_path:
+ config = read_yaml(config_path)
+
+        self.converter = TokenIDConverter(**config['TokenIDConverter'])  # token <-> id converter
+        self.ort_infer = OrtInferSession(config['Model'])  # ONNX inference session
+        self.batch_size = 1
+ self.punc_list = config['punc_list']
+ self.period = 0
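+        # Normalize half-width ',' and '?' to full-width, and record the index of '。'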
+ for i in range(len(self.punc_list)):
+ if self.punc_list[i] == ",":
+ self.punc_list[i] = ","
+ elif self.punc_list[i] == "?":
+ self.punc_list[i] = "?"
+ elif self.punc_list[i] == "。":
+ self.period = i
+
+ def __call__(self, text: Union[list, str], split_size=20):
+ check_argument_types()
+ split_text = code_mix_split_words(text)
+ split_text_id = self.converter.tokens2ids(split_text)
+ mini_sentences = split_to_mini_sentence(split_text, split_size)
+ mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
+ assert len(mini_sentences) == len(mini_sentences_id)
+ cache_sent = []
+ cache_sent_id = []
+ new_mini_sentence = ""
+ new_mini_sentence_punc = []
+ cache_pop_trigger_limit = 200
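+        # Process the tokens in mini-sentences, carrying any unfinished tail over in a cache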
+ for mini_sentence_i in range(len(mini_sentences)):
+ mini_sentence = mini_sentences[mini_sentence_i]
+ mini_sentence_id = mini_sentences_id[mini_sentence_i]
+ mini_sentence = cache_sent + mini_sentence
+ mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
+ data = {
+ "text": mini_sentence_id[None,:],
+ "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
+ }
+ try:
+                outputs = self.infer(data['text'], data['text_lengths'])  # run the punctuation model
+                y = outputs[0]  # outputs: (1, seq_len, num_class) -> y: (seq_len, num_class)
+                punctuations = np.argmax(y, axis=-1)  # most likely punctuation id per token
+                assert punctuations.size == len(mini_sentence)  # one prediction per input token
+            except Exception as e:
+                warnings.warn(f'Error occurs when processing {mini_sentence}. Error message: {e}')
+                # Fall back to "no punctuation" so the loop below still has valid indices
+                punctuations = np.zeros(len(mini_sentence), dtype='int64')
+            # Find the last full stop / question mark; the text after it is cached for the next mini-sentence
+ if mini_sentence_i < len(mini_sentences) - 1:
+ sentenceEnd = -1
+ last_comma_index = -1
+ for i in range(len(punctuations) - 2, 1, -1):
+ if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?":
+ sentenceEnd = i
+ break
+ if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
+ last_comma_index = i
+
+ if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                    # The sentence is too long; promote the last comma to a full stop and cut there
+ sentenceEnd = last_comma_index
+ punctuations[sentenceEnd] = self.period
+ cache_sent = mini_sentence[sentenceEnd + 1:]
+ cache_sent_id = mini_sentence_id[sentenceEnd + 1:].tolist()
+ mini_sentence = mini_sentence[0:sentenceEnd + 1]
+ punctuations = punctuations[0:sentenceEnd + 1]
+ # print(punctuations)
+ new_mini_sentence_punc += [int(x) for x in punctuations]
+ words_with_punc = []
+ for i in range(len(mini_sentence)):
+ if i > 0:
+ if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
+ mini_sentence[i] = " " + mini_sentence[i]
+ words_with_punc.append(mini_sentence[i])
+ if self.punc_list[punctuations[i]] != "_":
+ words_with_punc.append(self.punc_list[punctuations[i]])
+ new_mini_sentence += "".join(words_with_punc)
+            # Ensure the final sentence ends with a full stop
+ new_mini_sentence_out = new_mini_sentence
+ new_mini_sentence_punc_out = new_mini_sentence_punc
+ if mini_sentence_i == len(mini_sentences) - 1:
+ if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
+ new_mini_sentence_out = new_mini_sentence[:-1] + "。"
+ new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+ elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?":
+ new_mini_sentence_out = new_mini_sentence + "。"
+ new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+ # print(new_mini_sentence_out, new_mini_sentence_punc_out)
+ return new_mini_sentence_out, new_mini_sentence_punc_out
+
+ def infer(self, feats: np.ndarray,
+ feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ outputs = self.ort_infer([feats, feats_len])
+ return outputs
+
+
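+# Minimal usage sketch (assumes punc_model/model.onnx is present):
+#
+#     punc = PuncParaformer()
+#     text, punc_ids = punc('今天天气怎么样我们去公园吧')
+#     # text is the input with punctuation inserted and a trailing 。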
diff --git a/python/rapid_paraformer/rapid_vad.py b/python/rapid_paraformer/rapid_vad.py
new file mode 100644
index 0000000..c9f991b
--- /dev/null
+++ b/python/rapid_paraformer/rapid_vad.py
@@ -0,0 +1,144 @@
+
+import librosa
+import numpy as np
+import warnings
+from .utils import (E2EVadModel, OrtInferSession, WavFrontend, read_yaml)
+from pathlib import Path
+from typing import Union, Tuple, List
+cur_dir = Path(__file__).resolve().parent
+
+class RapidVad():
+    def __init__(self, config_path: str = None, max_end_sil: int = None) -> None:
+ config = read_yaml(cur_dir / 'vad_model/vad.yaml')
+
+ if config_path:
+ config = read_yaml(config_path)
+
+        cmvn_file = cur_dir / 'vad_model/vad.mvn'
+
+ self.frontend_vad = WavFrontend(
+ cmvn_file=cmvn_file,
+ **config['frontend_conf']
+ )
+
+ # self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
+ self.ort_infer = OrtInferSession(config['Model'])
+ self.batch_size = config['batch_size']
+ self.vad_scorer = E2EVadModel(config["vad_post_conf"])
+ self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
+ self.encoder_conf = config["encoder_conf"]
+
+    def prepare_cache(self, in_cache: list = None):
+        # Use None instead of a mutable default argument
+        if in_cache:
+            return in_cache
+        in_cache = []
+        fsmn_layers = self.encoder_conf["fsmn_layers"]
+        proj_dim = self.encoder_conf["proj_dim"]
+        lorder = self.encoder_conf["lorder"]
+        # One zero-initialized cache tensor per FSMN layer
+        for _ in range(fsmn_layers):
+            in_cache.append(np.zeros((1, proj_dim, lorder - 1, 1), dtype=np.float32))
+        return in_cache
+
+    def __call__(self, audio_in: Union[str, np.ndarray, List[str]], **kwargs) -> List:
+        # Load the audio at 16 kHz and add a batch dim
+        waveform = librosa.load(audio_in, sr=16000)[0][None, ...]
+
+        segments = [[] for _ in range(self.batch_size)]  # independent lists per batch item
+        speech, _ = self.frontend_vad.forward_fbank(waveform)  # fbank features
+        feats, feats_len = self.frontend_vad.forward_lfr_cmvn(speech)  # LFR + CMVN
+
+        is_final = kwargs.get('is_final', False)
+        waveform = np.array(waveform)
+        param_dict = kwargs.get('param_dict', dict())
+        in_cache = param_dict.get('in_cache', list())
+        in_cache = self.prepare_cache(in_cache)
+        try:
+            step = int(min(feats_len.max(), 6000))  # at most 6000 frames (60 s) per package
+            for t_offset in range(0, int(feats_len), step):
+                if t_offset + step >= feats_len - 1:
+                    step = feats_len - t_offset
+                    is_final = True
+                else:
+                    is_final = False
+                feats_package = feats[:, t_offset:int(t_offset + step), :]
+                # 160 samples = 10 ms frame shift, 400 samples = 25 ms window at 16 kHz
+                waveform_package = waveform[:,
+                                   t_offset * 160:min(waveform.shape[-1], (int(t_offset + step) - 1) * 160 + 400)]
+
+                inputs = [feats_package]
+                inputs.extend(in_cache)
+                scores, out_caches = self.infer(inputs)
+                # The session wrapper exposes no FSMN caches, so rebuild zero caches
+                # for the next package instead of feeding an empty list back in
+                in_cache = self.prepare_cache(out_caches)
+                segments_part = self.vad_scorer(scores, waveform_package, is_final=is_final,
+                                                max_end_sil=self.max_end_sil, online=False)
+
+ if segments_part:
+ for batch_num in range(0, self.batch_size):
+ segments[batch_num] += segments_part[batch_num]
+
+        except Exception:
+            segments = []
+            warnings.warn("input wav is silence or noise")
+
+        return segments
+
+ def load_data(self,
+ wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
+ def load_wav(path: str) -> np.ndarray:
+ waveform, _ = librosa.load(path, sr=fs)
+ return waveform
+
+ if isinstance(wav_content, np.ndarray):
+ return [wav_content]
+
+ if isinstance(wav_content, str):
+ return [load_wav(wav_content)]
+
+ if isinstance(wav_content, list):
+ return [load_wav(path) for path in wav_content]
+
+ raise TypeError(
+ f'The type of {wav_content} is not in [str, np.ndarray, list]')
+
+ def extract_feat(self,
+ waveform_list: List[np.ndarray]
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ feats, feats_len = [], []
+
+ for waveform in waveform_list:
+            speech, _ = self.frontend_vad.fbank(waveform)
+            feat, feat_len = self.frontend_vad.lfr_cmvn(speech)
+ feats.append(feat)
+ feats_len.append(feat_len)
+
+        feats = self.pad_feats(feats, np.max(feats_len))  # pad all features to the same length
+ feats_len = np.array(feats_len).astype(np.int32)[0]
+ return feats, feats_len
+
+ @staticmethod
+ def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
+        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
+            pad_width = ((0, max_feat_len - cur_len), (0, 0))
+            feat = np.squeeze(feat, axis=0)  # drop the leading batch dim
+            return np.pad(feat, pad_width, 'constant', constant_values=0)
+
+        feat_res = [pad_feat(feat, feat.shape[1]) for feat in feats]
+ feats = np.array(feat_res).astype(np.float32)
+ return feats
+
+    def infer(self, feats: List) -> Tuple[np.ndarray, List]:
+        outputs = self.ort_infer(feats)
+        # The ONNX wrapper returns only the score tensor; no caches are exposed
+        return outputs, []
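+
+# Minimal usage sketch (assumes vad_model/model.onnx and vad.mvn are present;
+# the wav path is a placeholder):
+#
+#     vad = RapidVad()
+#     segments = vad('test_wavs/long_audio.wav')
+#     # segments[0] -> [[start_ms, end_ms], ...] for the first (only) batch item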
diff --git a/python/rapid_paraformer/utils.py b/python/rapid_paraformer/utils.py
index 829e36d..f2fbc53 100644
--- a/python/rapid_paraformer/utils.py
+++ b/python/rapid_paraformer/utils.py
@@ -1,32 +1,28 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
-import functools
import logging
import pickle
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
-
+import warnings
import numpy as np
import yaml
from onnxruntime import (GraphOptimizationLevel, InferenceSession,
SessionOptions, get_available_providers, get_device)
from typeguard import check_argument_types
-
+import kaldi_native_fbank as knf
from .kaldifeat import compute_fbank_feats
root_dir = Path(__file__).resolve().parent
-logger_initialized = {}
-
class TokenIDConverter():
def __init__(self, token_path: Union[Path, str],
unk_symbol: str = "",):
- check_argument_types()
-
- self.token_list = self.load_token(token_path)
- self.unk_symbol = unk_symbol
+        check_argument_types()  # validate argument types
+        self.token_list = self.load_token(root_dir / token_path)  # load the token list
+        self.unk_symbol = unk_symbol  # symbol used for unknown tokens
@staticmethod
def load_token(file_path: Union[Path, str]) -> List:
@@ -148,30 +144,38 @@ def __init__(
self.filter_length_max = filter_length_max
self.lfr_m = lfr_m
self.lfr_n = lfr_n
- self.cmvn_file = cmvn_file
+ self.cmvn_file = root_dir / cmvn_file
self.dither = dither
-
- if self.cmvn_file:
- self.cmvn = self.load_cmvn()
+ self.fbank_fn = None
def fbank(self,
- input_content: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
- waveform_len = input_content.shape[1]
- waveform = input_content[0][:waveform_len]
- waveform = waveform * (1 << 15)
- mat = compute_fbank_feats(waveform,
- num_mel_bins=self.n_mels,
- frame_length=self.frame_length,
- frame_shift=self.frame_shift,
- dither=self.dither,
- energy_floor=0.0,
- window_type=self.window,
- sample_frequency=self.fs)
+ waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        '''
+        Compute fbank features.
+        :param waveform: audio signal
+        :return: fbank features and the number of frames
+        '''
+
+        waveform = waveform * (1 << 15)  # scale to 16-bit integer range
+
+        # self.opts (kaldi_native_fbank options) is expected to be prepared in __init__
+        assert isinstance(waveform, np.ndarray), 'waveform must be a numpy.ndarray'
+        self.fbank_fn = knf.OnlineFbank(self.opts)  # create the fbank computer
+        self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())  # feed the samples
+        frames = self.fbank_fn.num_frames_ready  # number of frames ready
+        mat = np.empty([frames, self.opts.mel_opts.num_bins])  # feature matrix
+        for i in range(frames):
+            mat[i, :] = self.fbank_fn.get_frame(i)  # fetch each frame
feat = mat.astype(np.float32)
feat_len = np.array(mat.shape[0]).astype(np.int32)
- return feat, feat_len
+        return feat, feat_len  # return the features and the frame count
def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        '''
+        Apply LFR and CMVN to the fbank features.
+        :param feat: input fbank features
+        :return: features after LFR and CMVN
+        '''
if self.lfr_m != 1 or self.lfr_n != 1:
feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
@@ -181,6 +185,62 @@ def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
feat_len = np.array(feat.shape[0]).astype(np.int32)
return feat, feat_len
+ def forward_fbank(self,
+ input_content: np.ndarray,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+        '''
+        Convert a batch of waveforms into fbank features.
+        :param input_content: input waveforms, shape (batch_size, time_steps)
+        '''
+ feats, feats_lens = [], []
+
+ batch_size = input_content.shape[0]
+
+ input_lengths = np.array([input_content.shape[1]])
+ for i in range(batch_size):
+ waveform_length = input_lengths[i]
+ waveform = input_content[i][:waveform_length]
+ waveform = waveform * (1 << 15)
+ mat = compute_fbank_feats(waveform,
+ num_mel_bins=self.n_mels,
+ frame_length=self.frame_length,
+ frame_shift=self.frame_shift,
+ dither=self.dither,
+ energy_floor=0.0,
+ sample_frequency=self.fs)
+ feats.append(mat)
+ feats_lens.append(mat.shape[0])
+
+ feats_pad = np.array(feats).astype(np.float32)
+ feats_lens = np.array(feats_lens).astype(np.int64)
+ return feats_pad, feats_lens
+
+ def forward_lfr_cmvn(self,
+ input_content: np.ndarray,
+ ) -> Tuple[np.ndarray, np.ndarray]:
+ feats, feats_lens = [], []
+ batch_size = input_content.shape[0]
+
+ if self.cmvn_file:
+ cmvn = self.load_cmvn()
+
+ input_lengths = np.array([input_content.shape[1]])
+ for i in range(batch_size):
+ mat = input_content[i, :input_lengths[i], :]
+
+ if self.lfr_m != 1 or self.lfr_n != 1:
+ mat = self.apply_lfr(mat, self.lfr_m, self.lfr_n)
+
+ if self.cmvn_file:
+ mat = self.apply_cmvn(mat, cmvn)
+
+ feats.append(mat)
+ feats_lens.append(mat.shape[0])
+
+ feats_pad = np.array(feats).astype(np.float32)
+ feats_lens = np.array(feats_lens).astype(np.int32)
+ return feats_pad, feats_lens
+
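+    # LFR stacks lfr_m consecutive frames into one vector and keeps every lfr_n-th
+    # stack, lowering the frame rate while widening the context per frame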
@staticmethod
def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
LFR_inputs = []
@@ -205,13 +265,13 @@ def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
return LFR_outputs
- def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
+ def apply_cmvn(self, inputs: np.ndarray, cmvn: np.ndarray) -> np.ndarray:
"""
Apply CMVN with mvn data
"""
frame, dim = inputs.shape
- means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
- vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
+ means = np.tile(cmvn[0:1, :dim], (frame, 1))
+ vars = np.tile(cmvn[1:2, :dim], (frame, 1))
inputs = (inputs + means) * vars
return inputs
@@ -263,15 +323,12 @@ class TokenIDConverterError(Exception):
pass
-class ONNXRuntimeError(Exception):
- pass
-
-
class OrtInferSession():
def __init__(self, config):
sess_opt = SessionOptions()
- sess_opt.log_severity_level = 4
- sess_opt.enable_cpu_mem_arena = False
+        sess_opt.log_severity_level = 4  # log fatal errors only
+        sess_opt.intra_op_num_threads = 4  # intra-op thread count
+        sess_opt.enable_cpu_mem_arena = False  # disable the CPU memory arena to reduce memory usage
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
cuda_ep = 'CUDAExecutionProvider'
@@ -286,7 +343,7 @@ def __init__(self, config):
EP_list = [(cuda_ep, config[cuda_ep])]
EP_list.append((cpu_ep, cpu_provider_options))
- config['model_path'] = config['model_path']
+ config['model_path'] = str(root_dir / config['model_path'])
self._verify_model(config['model_path'])
self.session = InferenceSession(config['model_path'],
sess_options=sess_opt,
@@ -300,12 +357,9 @@ def __init__(self, config):
RuntimeWarning)
def __call__(self,
- input_content: List[np.ndarray]) -> np.ndarray:
+                 input_content: List[np.ndarray]) -> np.ndarray:
input_dict = dict(zip(self.get_input_names(), input_content))
- try:
- return self.session.run(None, input_dict)
- except Exception as e:
- raise ONNXRuntimeError('ONNXRuntime inference failed.') from e
+        # Note: only the first model output is returned; callers handle errors themselves
+        return self.session.run(None, input_dict)[0]
def get_input_names(self, ):
return [v.name for v in self.session.get_inputs()]
@@ -330,6 +384,51 @@ def _verify_model(model_path):
if not model_path.is_file():
raise FileExistsError(f'{model_path} is not a file.')
+def split_to_mini_sentence(words: list, word_limit: int = 20):
+    '''
+    Split a list of words into mini-sentences of at most word_limit words.
+    e.g. split_to_mini_sentence(['a', 'b', 'c', 'd', 'e'], 2) -> [['a', 'b'], ['c', 'd'], ['e']]
+    :param words: list of words
+    :param word_limit: maximum number of words per mini-sentence
+    :return: list of mini-sentences
+    '''
+ assert word_limit > 1
+ if len(words) <= word_limit:
+ return [words]
+ sentences = []
+ length = len(words)
+ sentence_len = length // word_limit
+ for i in range(sentence_len):
+ sentences.append(words[i * word_limit:(i + 1) * word_limit])
+ if length % word_limit > 0:
+ sentences.append(words[sentence_len * word_limit:])
+ return sentences
+
+def code_mix_split_words(text: str):
+    '''
+    Split text into tokens: each CJK character becomes its own token, while
+    consecutive ASCII characters are grouped into a single token.
+    e.g. code_mix_split_words('今天weather不错') -> ['今', '天', 'weather', '不', '错']
+    :param text: input text
+    :return: list of tokens
+    '''
+ words = []
+
+ segs = text.split()
+ for seg in segs:
+ # There is no space in seg.
+ current_word = ""
+ for c in seg:
+ if len(c.encode()) == 1:
+ # This is an ASCII char.
+ current_word += c
+ else:
+ # This is a Chinese char.
+ if len(current_word) > 0:
+ words.append(current_word)
+ current_word = ""
+ words.append(c)
+ if len(current_word) > 0:
+ words.append(current_word)
+ return words
+
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
if not Path(yaml_path).exists():
@@ -340,33 +439,624 @@ def read_yaml(yaml_path: Union[str, Path]) -> Dict:
return data
-@functools.lru_cache()
-def get_logger(name='rapdi_paraformer'):
- """Initialize and get a logger by name.
- If the logger has not been initialized, this method will initialize the
- logger by adding one or two handlers, otherwise the initialized logger will
- be directly returned. During initialization, a StreamHandler will always be
- added.
- Args:
- name (str): Logger name.
- Returns:
- logging.Logger: The expected logger.
- """
- logger = logging.getLogger(name)
- if name in logger_initialized:
- return logger
-
- for logger_name in logger_initialized:
- if name.startswith(logger_name):
- return logger
-
- formatter = logging.Formatter(
- '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
- datefmt="%Y/%m/%d %H:%M:%S")
-
- sh = logging.StreamHandler()
- sh.setFormatter(formatter)
- logger.addHandler(sh)
- logger_initialized[name] = True
- logger.propagate = False
- return logger
+
+# ---- VAD post-processing (ported from FunASR) ----
+import math
+from enum import Enum
+
+class VadStateMachine(Enum):
+ kVadInStateStartPointNotDetected = 1
+ kVadInStateInSpeechSegment = 2
+ kVadInStateEndPointDetected = 3
+
+
+class FrameState(Enum):
+ kFrameStateInvalid = -1
+ kFrameStateSpeech = 1
+ kFrameStateSil = 0
+
+
+# final voice/unvoice state per frame
+class AudioChangeState(Enum):
+ kChangeStateSpeech2Speech = 0
+ kChangeStateSpeech2Sil = 1
+ kChangeStateSil2Sil = 2
+ kChangeStateSil2Speech = 3
+ kChangeStateNoBegin = 4
+ kChangeStateInvalid = 5
+
+
+class VadDetectMode(Enum):
+ kVadSingleUtteranceDetectMode = 0
+ kVadMutipleUtteranceDetectMode = 1
+
+
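+# The detector is a small state machine: per-frame speech/silence decisions
+# (FrameState) are smoothed over a sliding window (WindowDetector), and the
+# resulting transitions (AudioChangeState) move E2EVadModel between "start
+# point not detected", "in speech segment" and "end point detected".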
+class VADXOptions:
+ def __init__(
+ self,
+        sample_rate: int = 16000,  # sample rate in Hz
+        detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
+        snr_mode: int = 0,
+        max_end_silence_time: int = 800,  # max trailing silence (ms) before an end point is forced
+        max_start_silence_time: int = 3000,  # max leading silence (ms) before giving up on a start point
+        do_start_point_detection: bool = True,  # enable start-point detection
+        do_end_point_detection: bool = True,  # enable end-point detection
+        window_size_ms: int = 200,  # sliding-window size (ms)
+        sil_to_speech_time_thres: int = 150,  # silence -> speech time threshold (ms)
+        speech_to_sil_time_thres: int = 150,  # speech -> silence time threshold (ms)
+        speech_2_noise_ratio: float = 1.0,  # speech-to-noise weighting
+        do_extend: int = 1,
+        lookback_time_start_point: int = 200,
+        lookahead_time_end_point: int = 100,
+        max_single_segment_time: int = 60000,
+        nn_eval_block_size: int = 8,
+        dcd_block_size: int = 4,
+        snr_thres: float = -100.0,
+        noise_frame_num_used_for_snr: int = 100,
+        decibel_thres: float = -100.0,
+        speech_noise_thres: float = 0.6,
+        fe_prior_thres: float = 1e-4,
+        silence_pdf_num: int = 1,
+        sil_pdf_ids: List[int] = [0],
+        speech_noise_thresh_low: float = -0.1,
+        speech_noise_thresh_high: float = 0.3,
+        output_frame_probs: bool = False,
+        frame_in_ms: int = 10,
+        frame_length_ms: int = 25,
+ ):
+ self.sample_rate = sample_rate
+ self.detect_mode = detect_mode
+ self.snr_mode = snr_mode
+ self.max_end_silence_time = max_end_silence_time
+ self.max_start_silence_time = max_start_silence_time
+ self.do_start_point_detection = do_start_point_detection
+ self.do_end_point_detection = do_end_point_detection
+ self.window_size_ms = window_size_ms
+ self.sil_to_speech_time_thres = sil_to_speech_time_thres
+ self.speech_to_sil_time_thres = speech_to_sil_time_thres
+ self.speech_2_noise_ratio = speech_2_noise_ratio
+ self.do_extend = do_extend
+ self.lookback_time_start_point = lookback_time_start_point
+ self.lookahead_time_end_point = lookahead_time_end_point
+ self.max_single_segment_time = max_single_segment_time
+ self.nn_eval_block_size = nn_eval_block_size
+ self.dcd_block_size = dcd_block_size
+ self.snr_thres = snr_thres
+ self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
+ self.decibel_thres = decibel_thres
+ self.speech_noise_thres = speech_noise_thres
+ self.fe_prior_thres = fe_prior_thres
+ self.silence_pdf_num = silence_pdf_num
+ self.sil_pdf_ids = sil_pdf_ids
+ self.speech_noise_thresh_low = speech_noise_thresh_low
+ self.speech_noise_thresh_high = speech_noise_thresh_high
+ self.output_frame_probs = output_frame_probs
+ self.frame_in_ms = frame_in_ms
+ self.frame_length_ms = frame_length_ms
+
+
+class E2EVadSpeechBufWithDoa(object):
+ def __init__(self):
+ self.start_ms = 0
+ self.end_ms = 0
+ self.buffer = []
+ self.contain_seg_start_point = False
+ self.contain_seg_end_point = False
+ self.doa = 0
+
+ def Reset(self):
+ self.start_ms = 0
+ self.end_ms = 0
+ self.buffer = []
+ self.contain_seg_start_point = False
+ self.contain_seg_end_point = False
+ self.doa = 0
+
+
+class E2EVadFrameProb(object):
+ def __init__(self):
+ self.noise_prob = 0.0
+ self.speech_prob = 0.0
+ self.score = 0.0
+ self.frame_id = 0
+ self.frm_state = 0
+
+
+class WindowDetector(object):
+ def __init__(self, window_size_ms: int, sil_to_speech_time: int,
+ speech_to_sil_time: int, frame_size_ms: int):
+ self.window_size_ms = window_size_ms
+ self.sil_to_speech_time = sil_to_speech_time
+ self.speech_to_sil_time = speech_to_sil_time
+ self.frame_size_ms = frame_size_ms
+
+ self.win_size_frame = int(window_size_ms / frame_size_ms)
+ self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame  # per-frame speech/silence flags in the window
+
+ self.cur_win_pos = 0
+ self.pre_frame_state = FrameState.kFrameStateSil
+ self.cur_frame_state = FrameState.kFrameStateSil
+ self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
+ self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
+
+ self.voice_last_frame_count = 0
+ self.noise_last_frame_count = 0
+ self.hydre_frame_count = 0
+
+ def Reset(self) -> None:
+ self.cur_win_pos = 0
+ self.win_sum = 0
+ self.win_state = [0] * self.win_size_frame
+ self.pre_frame_state = FrameState.kFrameStateSil
+ self.cur_frame_state = FrameState.kFrameStateSil
+ self.voice_last_frame_count = 0
+ self.noise_last_frame_count = 0
+ self.hydre_frame_count = 0
+
+ def GetWinSize(self) -> int:
+ return int(self.win_size_frame)
+
+ def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
+ cur_frame_state = FrameState.kFrameStateSil
+ if frameState == FrameState.kFrameStateSpeech:
+ cur_frame_state = 1
+ elif frameState == FrameState.kFrameStateSil:
+ cur_frame_state = 0
+ else:
+ return AudioChangeState.kChangeStateInvalid
+ self.win_sum -= self.win_state[self.cur_win_pos]
+ self.win_sum += cur_frame_state
+ self.win_state[self.cur_win_pos] = cur_frame_state
+ self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
+
+ if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
+ self.pre_frame_state = FrameState.kFrameStateSpeech
+ return AudioChangeState.kChangeStateSil2Speech
+
+ if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
+ self.pre_frame_state = FrameState.kFrameStateSil
+ return AudioChangeState.kChangeStateSpeech2Sil
+
+ if self.pre_frame_state == FrameState.kFrameStateSil:
+ return AudioChangeState.kChangeStateSil2Sil
+ if self.pre_frame_state == FrameState.kFrameStateSpeech:
+ return AudioChangeState.kChangeStateSpeech2Speech
+ return AudioChangeState.kChangeStateInvalid
+
+ def FrameSizeMs(self) -> int:
+ return int(self.frame_size_ms)
+
+
+class E2EVadModel():
+ def __init__(self, vad_post_args: Dict[str, Any]):
+ super(E2EVadModel, self).__init__()
+ self.vad_opts = VADXOptions(**vad_post_args)
+ self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
+ self.vad_opts.sil_to_speech_time_thres,
+ self.vad_opts.speech_to_sil_time_thres,
+ self.vad_opts.frame_in_ms)
+ # self.encoder = encoder
+ # init variables
+ self.is_final = False
+ self.data_buf_start_frame = 0
+ self.frm_cnt = 0
+ self.latest_confirmed_speech_frame = 0
+ self.lastest_confirmed_silence_frame = -1
+ self.continous_silence_frame_count = 0
+ self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+ self.confirmed_start_frame = -1
+ self.confirmed_end_frame = -1
+ self.number_end_time_detected = 0
+ self.sil_frame = 0
+ self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
+ self.noise_average_decibel = -100.0
+ self.pre_end_silence_detected = False
+ self.next_seg = True
+
+ self.output_data_buf = []
+ self.output_data_buf_offset = 0
+ self.frame_probs = []
+ self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
+ self.speech_noise_thres = self.vad_opts.speech_noise_thres
+ self.scores = None
+ self.max_time_out = False
+ self.decibel = []
+ self.data_buf = None
+ self.data_buf_all = None
+ self.waveform = None
+ self.ResetDetection()
+
+ def AllResetDetection(self):
+ self.is_final = False
+ self.data_buf_start_frame = 0
+ self.frm_cnt = 0
+ self.latest_confirmed_speech_frame = 0
+ self.lastest_confirmed_silence_frame = -1
+ self.continous_silence_frame_count = 0
+ self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+ self.confirmed_start_frame = -1
+ self.confirmed_end_frame = -1
+ self.number_end_time_detected = 0
+ self.sil_frame = 0
+ self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
+ self.noise_average_decibel = -100.0
+ self.pre_end_silence_detected = False
+ self.next_seg = True
+
+ self.output_data_buf = []
+ self.output_data_buf_offset = 0
+ self.frame_probs = []
+ self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
+ self.speech_noise_thres = self.vad_opts.speech_noise_thres
+ self.scores = None
+ self.max_time_out = False
+ self.decibel = []
+ self.data_buf = None
+ self.data_buf_all = None
+ self.waveform = None
+ self.ResetDetection()
+
+ def ResetDetection(self):
+ self.continous_silence_frame_count = 0
+ self.latest_confirmed_speech_frame = 0
+ self.lastest_confirmed_silence_frame = -1
+ self.confirmed_start_frame = -1
+ self.confirmed_end_frame = -1
+ self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+ self.windows_detector.Reset()
+ self.sil_frame = 0
+ self.frame_probs = []
+
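+    # Per-frame energy in dB: 10 * log10(sum(x^2) + 1e-6) over 25 ms frames with a 10 ms shift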
+ def ComputeDecibel(self) -> None:
+ frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
+ frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+ if self.data_buf_all is None:
+ self.data_buf_all = self.waveform[0] # self.data_buf is pointed to self.waveform[0]
+ self.data_buf = self.data_buf_all
+ else:
+ self.data_buf_all = np.concatenate((self.data_buf_all, self.waveform[0]))
+ for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
+ self.decibel.append(
+ 10 * math.log10(np.square((self.waveform[0][offset: offset + frame_sample_length])).sum() + \
+ 0.000001))
+
+ def ComputeScores(self, scores: np.ndarray) -> None:
+ # scores = self.encoder(feats, in_cache) # return B * T * D
+ self.vad_opts.nn_eval_block_size = scores.shape[1]
+ self.frm_cnt += scores.shape[1] # count total frames
+ if self.scores is None:
+ self.scores = scores # the first calculation
+ else:
+ self.scores = np.concatenate((self.scores, scores), axis=1)
+ # print("scores.shape: ", self.scores.shape)
+
+ def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again
+ while self.data_buf_start_frame < frame_idx:
+ if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
+ self.data_buf_start_frame += 1
+ self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
+ self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+
+ def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
+ last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
+ self.PopDataBufTillFrame(start_frm)
+ expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
+ if last_frm_is_end_point:
+ extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
+ self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
+ expected_sample_number += int(extra_sample)
+ if end_point_is_sent_end:
+ expected_sample_number = max(expected_sample_number, len(self.data_buf))
+ if len(self.data_buf) < expected_sample_number:
+ print('error in calling pop data_buf\n')
+
+ if len(self.output_data_buf) == 0 or first_frm_is_start_point:
+ self.output_data_buf.append(E2EVadSpeechBufWithDoa())
+ self.output_data_buf[-1].Reset()
+ self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
+ self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
+ self.output_data_buf[-1].doa = 0
+ cur_seg = self.output_data_buf[-1]
+ if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+ print('warning\n')
+        out_pos = len(cur_seg.buffer)  # cur_seg.buffer is never populated here (kept for parity with the original implementation)
+ data_to_pop = 0
+ if end_point_is_sent_end:
+ data_to_pop = expected_sample_number
+ else:
+ data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+ if data_to_pop > len(self.data_buf):
+ print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
+ data_to_pop = len(self.data_buf)
+ expected_sample_number = len(self.data_buf)
+
+ cur_seg.doa = 0
+ for sample_cpy_out in range(0, data_to_pop):
+ # cur_seg.buffer[out_pos ++] = data_buf_.back();
+ out_pos += 1
+ for sample_cpy_out in range(data_to_pop, expected_sample_number):
+ # cur_seg.buffer[out_pos++] = data_buf_.back()
+ out_pos += 1
+ if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+ print('Something wrong with the VAD algorithm\n')
+ self.data_buf_start_frame += frm_cnt
+ cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
+ if first_frm_is_start_point:
+ cur_seg.contain_seg_start_point = True
+ if last_frm_is_end_point:
+ cur_seg.contain_seg_end_point = True
+
+ def OnSilenceDetected(self, valid_frame: int):
+ self.lastest_confirmed_silence_frame = valid_frame
+ if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+ self.PopDataBufTillFrame(valid_frame)
+ # silence_detected_callback_
+ # pass
+
+ def OnVoiceDetected(self, valid_frame: int) -> None:
+ self.latest_confirmed_speech_frame = valid_frame
+ self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
+
+ def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
+ if self.vad_opts.do_start_point_detection:
+ pass
+ if self.confirmed_start_frame != -1:
+ print('not reset vad properly\n')
+ else:
+ self.confirmed_start_frame = start_frame
+
+ if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+ self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
+
+ def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
+ for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
+ self.OnVoiceDetected(t)
+ if self.vad_opts.do_end_point_detection:
+ pass
+ if self.confirmed_end_frame != -1:
+ print('not reset vad properly\n')
+ else:
+ self.confirmed_end_frame = end_frame
+ if not fake_result:
+ self.sil_frame = 0
+ self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
+ self.number_end_time_detected += 1
+
+ def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
+ if is_final_frame:
+ self.OnVoiceEnd(cur_frm_idx, False, True)
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+
+ def GetLatency(self) -> int:
+ return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
+
+ def LatencyFrmNumAtStartPoint(self) -> int:
+ vad_latency = self.windows_detector.GetWinSize()
+ if self.vad_opts.do_extend:
+ vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
+ return vad_latency
+
+ def GetFrameState(self, t: int) -> FrameState:
+ frame_state = FrameState.kFrameStateInvalid
+ cur_decibel = self.decibel[t]
+ cur_snr = cur_decibel - self.noise_average_decibel
+ # for each frame, calc log posterior probability of each state
+ if cur_decibel < self.vad_opts.decibel_thres:
+ frame_state = FrameState.kFrameStateSil
+ self.DetectOneFrame(frame_state, t, False)
+ return frame_state
+
+ sum_score = 0.0
+ noise_prob = 0.0
+ assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
+ if len(self.sil_pdf_ids) > 0:
+
+            assert len(self.scores) == 1  # only batch_size = 1 is supported
+ sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
+ sum_score = sum(sil_pdf_scores)
+ noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
+ total_score = 1.0
+ sum_score = total_score - sum_score
+ speech_prob = math.log(sum_score)
+ if self.vad_opts.output_frame_probs:
+ frame_prob = E2EVadFrameProb()
+ frame_prob.noise_prob = noise_prob
+ frame_prob.speech_prob = speech_prob
+ frame_prob.score = sum_score
+ frame_prob.frame_id = t
+ self.frame_probs.append(frame_prob)
+ if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
+ if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
+ frame_state = FrameState.kFrameStateSpeech
+ else:
+ frame_state = FrameState.kFrameStateSil
+ else:
+ frame_state = FrameState.kFrameStateSil
+ if self.noise_average_decibel < -99.9:
+ self.noise_average_decibel = cur_decibel
+ else:
+ self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * (
+ self.vad_opts.noise_frame_num_used_for_snr
+ - 1)) / self.vad_opts.noise_frame_num_used_for_snr
+
+ return frame_state
+
+ def __call__(self, score: np.ndarray, waveform: np.ndarray,
+ is_final: bool = False, max_end_sil: int = 800, online: bool = False
+ ):
+ self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
+ self.waveform = waveform # compute decibel for each frame
+ self.ComputeDecibel()
+ # print('score shape: ', score.shape)
+ self.ComputeScores(score)
+ if not is_final:
+ self.DetectCommonFrames()
+ else:
+ self.DetectLastFrames()
+ segments = []
+ for batch_num in range(0, score.shape[0]): # only support batch_size = 1 now
+ segment_batch = []
+ if len(self.output_data_buf) > 0:
+ for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
+ if online:
+ if not self.output_data_buf[i].contain_seg_start_point:
+ continue
+ if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
+ continue
+ start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
+ if self.output_data_buf[i].contain_seg_end_point:
+ end_ms = self.output_data_buf[i].end_ms
+ self.next_seg = True
+ self.output_data_buf_offset += 1
+ else:
+ end_ms = -1
+ self.next_seg = False
+ else:
+ if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+ i].contain_seg_end_point):
+ continue
+ start_ms = self.output_data_buf[i].start_ms
+ end_ms = self.output_data_buf[i].end_ms
+ self.output_data_buf_offset += 1
+ segment = [start_ms, end_ms]
+ segment_batch.append(segment)
+
+ if segment_batch:
+ segments.append(segment_batch)
+ if is_final:
+ # reset class variables and clear the dict for the next query
+ self.AllResetDetection()
+ return segments
+
+ def DetectCommonFrames(self) -> int:
+ if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+ return 0
+ for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+ frame_state = FrameState.kFrameStateInvalid
+ frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
+ self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+
+ return 0
+
+ def DetectLastFrames(self) -> int:
+ if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+ return 0
+ for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+ # frame_state = FrameState.kFrameStateInvalid
+ frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
+ if i != 0:
+ self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+ else:
+ self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
+
+ return 0
+
+ def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
+ tmp_cur_frm_state = FrameState.kFrameStateInvalid
+ if cur_frm_state == FrameState.kFrameStateSpeech:
+ if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
+ tmp_cur_frm_state = FrameState.kFrameStateSpeech
+ else:
+ tmp_cur_frm_state = FrameState.kFrameStateSil
+ elif cur_frm_state == FrameState.kFrameStateSil:
+ tmp_cur_frm_state = FrameState.kFrameStateSil
+ state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
+ frm_shift_in_ms = self.vad_opts.frame_in_ms
+ if AudioChangeState.kChangeStateSil2Speech == state_change:
+ silence_frame_count = self.continous_silence_frame_count
+ self.continous_silence_frame_count = 0
+ self.pre_end_silence_detected = False
+ start_frame = 0
+ if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+ start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint())
+ self.OnVoiceStart(start_frame)
+ self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
+ for t in range(start_frame + 1, cur_frm_idx + 1):
+ self.OnVoiceDetected(t)
+ elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+ for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
+ self.OnVoiceDetected(t)
+ if cur_frm_idx - self.confirmed_start_frame + 1 > \
+ self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+ self.OnVoiceEnd(cur_frm_idx, False, False)
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+ elif not is_final_frame:
+ self.OnVoiceDetected(cur_frm_idx)
+ else:
+ self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+ else:
+ pass
+ elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
+ self.continous_silence_frame_count = 0
+ if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+ pass
+ elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+ if cur_frm_idx - self.confirmed_start_frame + 1 > \
+ self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+ self.OnVoiceEnd(cur_frm_idx, False, False)
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+ elif not is_final_frame:
+ self.OnVoiceDetected(cur_frm_idx)
+ else:
+ self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+ else:
+ pass
+ elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
+ self.continous_silence_frame_count = 0
+ if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+ if cur_frm_idx - self.confirmed_start_frame + 1 > \
+ self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+ self.max_time_out = True
+ self.OnVoiceEnd(cur_frm_idx, False, False)
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+ elif not is_final_frame:
+ self.OnVoiceDetected(cur_frm_idx)
+ else:
+ self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+ else:
+ pass
+ elif AudioChangeState.kChangeStateSil2Sil == state_change:
+ self.continous_silence_frame_count += 1
+ if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+ # silence timeout, return zero length decision
+ if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
+ self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
+ or (is_final_frame and self.number_end_time_detected == 0):
+ for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
+ self.OnSilenceDetected(t)
+ self.OnVoiceStart(0, True)
+                    self.OnVoiceEnd(0, True, False)
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+ else:
+ if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
+ self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
+ elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+ if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh:
+ lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
+ if self.vad_opts.do_extend:
+ lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
+ lookback_frame -= 1
+ lookback_frame = max(0, lookback_frame)
+ self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+ elif cur_frm_idx - self.confirmed_start_frame + 1 > \
+ self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+ self.OnVoiceEnd(cur_frm_idx, False, False)
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+ elif self.vad_opts.do_extend and not is_final_frame:
+ if self.continous_silence_frame_count <= int(
+ self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
+ self.OnVoiceDetected(cur_frm_idx)
+ else:
+ self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+ else:
+ pass
+
+ if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
+ self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
+ self.ResetDetection()
+
diff --git a/python/rapid_paraformer/vad_model/vad.mvn b/python/rapid_paraformer/vad_model/vad.mvn
new file mode 100644
index 0000000..77070ed
--- /dev/null
+++ b/python/rapid_paraformer/vad_model/vad.mvn
@@ -0,0 +1,8 @@
+<Nnet>
+<Splice> 400 400
+[ 0 ]
+<AddShift> 400 400
+<LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
+<Rescale> 400 400
+<LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
+</Nnet>
\ No newline at end of file
diff --git a/python/rapid_paraformer/vad_model/vad.yaml b/python/rapid_paraformer/vad_model/vad.yaml
new file mode 100644
index 0000000..713f640
--- /dev/null
+++ b/python/rapid_paraformer/vad_model/vad.yaml
@@ -0,0 +1,68 @@
+input_size: null
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ dither: 0.0
+ lfr_m: 5
+ lfr_n: 1
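+ # 80 mel bins stacked over lfr_m=5 frames give the 400-dim features expected by encoder_conf.input_dim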
+
+Model:
+ model_path: vad_model/model.onnx
+ use_cuda: false
+ intra_op_num_threads: 1
+ CUDAExecutionProvider:
+ device_id: 0
+ arena_extend_strategy: kNextPowerOfTwo
+ cudnn_conv_algo_search: EXHAUSTIVE
+ do_copy_in_default_stream: true
+
+batch_size: 1
+model: e2evad
+encoder: fsmn
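+# the 4-layer FSMN scores each frame over output_dim=248 pdfs; sil_pdf_ids in vad_post_conf selects which pdfs count as silence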
+encoder_conf:
+ input_dim: 400
+ input_affine_dim: 140
+ fsmn_layers: 4
+ linear_dim: 250
+ proj_dim: 128
+ lorder: 20
+ rorder: 0
+ lstride: 1
+ rstride: 0
+ output_affine_dim: 140
+ output_dim: 248
+
+
+
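+# frame-level endpointing parameters; all *_time and *_ms values are in milliseconds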
+vad_post_conf:
+ sample_rate: 16000
+ detect_mode: 1
+ snr_mode: 0
+ max_end_silence_time: 800
+ max_start_silence_time: 3000
+ do_start_point_detection: True
+ do_end_point_detection: True
+ window_size_ms: 200
+ sil_to_speech_time_thres: 150
+ speech_to_sil_time_thres: 150
+ speech_2_noise_ratio: 1.0
+ do_extend: 1
+ lookback_time_start_point: 200
+ lookahead_time_end_point: 100
+ max_single_segment_time: 60000
+ snr_thres: -100.0
+ noise_frame_num_used_for_snr: 100
+ decibel_thres: -100.0
+ speech_noise_thres: 0.6
+ fe_prior_thres: 0.0001
+ silence_pdf_num: 1
+ sil_pdf_ids: [0]
+ speech_noise_thresh_low: -0.1
+ speech_noise_thresh_high: 0.3
+ output_frame_probs: False
+ frame_in_ms: 10
+ frame_length_ms: 25
diff --git a/python/tests/test_punc.py b/python/tests/test_punc.py
new file mode 100644
index 0000000..ce6d8a1
--- /dev/null
+++ b/python/tests/test_punc.py
@@ -0,0 +1,18 @@
+# -*- coding: UTF-8 -*-
+'''
+Project -> File :RapidASR-2.0.0 -> test_punc
+Author :standy
+Date :2023/5/3 11:55
+function : test the punctuation model
+'''
+
+from rapid_paraformer.rapid_punc import PuncParaformer
+
+model = PuncParaformer()
+
+text_in=['一二年十一月这个时候小熊对于如何投资银行业也有比较成熟的方法了正好当时中国银行业整体创下了新低所以我们就把中国银行业作为一次检验银行投资方法的实验吧我们通过分析中国银行业来熟悉一下应该如何分析银行第一点银行业的根本银行业的根本是什么就是净资产率也就是公司净资产除以总资产这个值的导数就是杠杆比因为杠杆比的计算公式就是总资产除以净资产检验中国银行业最新年度的净资产率是我们分析这些公司的第一步净资产', '率只有最新年度甚至最新季度的净资产率这个数字才是有意义的历史数据当然意义不大的这里我们随便挑五家中国银行来对比一下我们这里就挑选民生银行招商银行中国银行工商银行和兴业银行他们的净资产率在这张表里面可以看', '从表中我们可以看到国有银行呢确实比较保守或者反过来说那些股份制商业银行就比较激进了民生和兴业银行的净资产率就比较低相对而言风险度就比较低是二银行业的盈利能力用什么值来考察呢是总', '资产收益率大家应该还记得初级课教的roe净资产收益率吧进阶课教的用roic投资回报率来分析公司的盈利率但是银行业却不同啊首先它不能用roe为什么因为roe会受到杠杆率的影响如果银行的杠杆率越高', '那他的roe也会被抬高的其次roic在银行业也不好使你很难找到有形资产因为银行所有的资产基本上都是有形资产那这样的话最好的办法我们就应该用什么应该用roa总资产收益率的因为对于银行这样的高杠杆公司来', '说无论是他借来的钱还是他的自由资金也就是净资产其实都是一样的都是用来生利息赚钱的那在这种情况下我们用roa来衡量银行业的盈利能力就是很恰当因为roa的这个分母就是总资产所有的资产只要是能够产生利息的资产我们还是选取上面五家公司来看他们的r', '行易对比一下不怕不是货就怕货币货从二零一二年的rv来看招商银行和和工商银行的盈利能力确实领先而兴业银行的盈利能力很不幸的垫比了当然这只是一年的业绩不能说明什么但是如果你计算出连续十年的阿v某家银行一直持续领先', '间的话那就说明它的盈利能力确实很强啊第三点银行的估值我们已经知道了用pb来给银行业估值那我们来看一下这几个银行在二零一三年八月份的时候各自的市净率pb通过对比我们很清楚的可以看到在这里市净率的话兴业银行在二零', '一三年八月份的时候它的估值是最低的这个说实话倒跟它的盈利能力倒是比较匹配的而中国银行的估值次值这里值得注意的是民生银行当时民生银行的估值已经领先所有的银行但是从这里我们并没有发现哪一个银行有特别的低估比如像买入零点六倍的富国零点三二倍批比的巴克莱伊', '一样的机会在二零一三年八月的中国银行业并没有出现所以只能说这些银行轻微的被低估了但并不是严重的低估好我们迅速的完成了银行的净资产率盈利能力和估值但是这些的话只能说是粗略的分析了银行那更多的我们需要分析什么呢', '需要分析以下几点首先是银行的资金成本不同的银行资金成本各不相同因为它们的资金来源各不相同那哪些资金它的成本是最低的呢你想一下对就是活期储蓄活期储蓄它的利率几乎就是零所以说成本非常低', '所以在一家银行的存款资金当中活期储蓄的比例越高那这家银行的资金成本就越低这个数据可以在银行的年报当中找到我们来看一下这五家银行的资金成本在二零一二年的时候我们很清楚的可以看到招商银行的活期储蓄率遥遥领先第二名是工商银行', '这两家银行确实是零售银行当中的翘楚而相对的民生银行的活期储蓄率就偏低了这一点所以它这个时候它的资金成本两比较高好第二个我们需要分析的是盈利来源银行的利润来源于利息收入和非息收入两个部分利息收入基本就是靠着中国银行业的垄', '断地位在吃饭而非吸收入占比高能够提高其收入的稳定性所以银行的非息收入比例的高低也值得我们注意这五家银行的非息收入的比例在下面可以这张表里面可以看到中国银行因为独享对外人民币业务你要换外币的话只能去中国银行在这种情况下中', '国银行的非息收入当然是遥遥领先的民生银行因为大力发展非息业务所以也能名列前茅那第三点银行业的利息收入它是主要来自于贷款的那贷问题就是这些贷款到底给谁了呢不同的贷款人他的风风险也各不相同', '一般来说个人贷款特别是房贷风险度就比较低而相比较的话公司贷款的风险度就比较高了所以零售贷款率也就是贷款中的个人贷款的比例也是衡量银行风险度的一个很重要的指标在这里我们也可以看到民生银行招商银行和', '中国银行的零售贷款比例都很高特别招商银行可以说是一枝独秀好到了这里的话分析好了之后有人会说衡量银行业风险最重要的数字应该是不良贷款比例啊那你应该怎么分析不良贷款比例呢说实话我个人非常不相信这个数字因为很少有银行会成', '承认自己的不良贷款比例太高了只有在万不得已银行快倒闭的时候银行的管理层才会嘟嘟囔囔的说哦我们的不良贷款好像啊对吧稍微的有有利有点高哦在一九九七年中国银行业大洗牌之前各大银行记录的不良贷款比例非常低直到一九九八年嘿一下子', '不良贷款比例都跳升了几十倍把时任总理朱隆基都吓了一大跳所以不良贷款比例这个数字更多是用来娱乐大众的连带不良贷款多倍率也很有问题了从所以从劳埃德银行的教训我们认为我个人认为银行业的风险最好的工具还是净资产率', '如果真的要分析分析银行业贷款的风险的话那不能仅仅盯住不良贷款率这一个指标而应该把眼光扩大包括那些逾期贷款等等全部都要考虑进去当然在这里我们就不会说了最后以上我们分析了这么多数据那如何从这些数据中得出我们的结论呢或者说完成定', '定量到定性的过程从盈利能力来看招行和工行领先兴业垫底道理很简单招行和工行是中国最著名的两家零售银行中国现维持这么高的息差的话那招行和工行的日子自然非常好过从活期储蓄率来看招行和工行也领先民生垫底', '民生银行的资金成本比较高从非吸收入来看中民生和中行遥遥领先而兴业又变从零售贷款占比来看招行和民生领先兴业和工行遍底把这些总结一下很清楚的就可以看到招商银行领先这五家银行而工商银行屈居第二', '这两家银行在资金成本和盈利能力上领先其他银行所以对于招行的话我们可以估值比较高的市净率匹比富国这样的极度优秀的银行如果是两倍市净率的话哇那招行我可以接受一点五倍的pd市净率而工行的话那才差不多是一点二倍的市净率相反像兴业这样垫底的银行我只能接', '受一倍甚至是零点八倍的pp来估值了最后检验风险度也就是净资产率中行和工行领先民生和兴业便利这样的话中行和工行我们要求的安全边际就比较小而民生和兴业我们要求的安全边际就比较高了好看到了吗银行业就是这么分析的', '并不是很难吧', 'thethethethethethethethethethethethethethe意嗯']
+
+
+
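+# join the chunked ASR outputs into one string; the punctuation model re-punctuates it in a single pass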
+processed_text = model(''.join(text_in))
+print(processed_text[0])
\ No newline at end of file
diff --git a/python/tests/test_vad.py b/python/tests/test_vad.py
new file mode 100644
index 0000000..4bb708e
--- /dev/null
+++ b/python/tests/test_vad.py
@@ -0,0 +1,16 @@
+# -*- coding: UTF-8 -*-
+'''
+Project -> File :RapidASR-2.0.0 -> test_vad
+Author :standy
+Date :2023/5/3 15:57
+function : test VAD
+'''
+
+from rapid_paraformer.rapid_vad import RapidVad
+# model_dir = "/Users/laichunping/Documents/ASR/FunASR/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+wav_path = "/Users/laichunping/Documents/ASR/RapidASR-2.0.0/test_wavs/0478_00017.wav"
+model = RapidVad()
+
+# offline VAD
+result = model(wav_path)
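+# result is expected to hold [start_ms, end_ms] speech segments (FunASR convention);
+# assuming that shape, the segments could be printed in seconds like:
+#   for beg, end in result[0]:
+#       print(f"speech {beg / 1000:.2f}s -> {end / 1000:.2f}s")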
+print(result)
\ No newline at end of file