mirror of
https://github.com/k2-fsa/sherpa-onnx.git
synced 2026-01-09 07:41:06 +08:00
Support RK NPU for SenseVoice non-streaming ASR models (#2589)
This PR adds RK NPU support for SenseVoice non-streaming ASR models by implementing a new RKNN backend with greedy CTC decoding.

- Adds an offline RKNN implementation for SenseVoice models, including model loading, feature processing, and CTC decoding
- Introduces export tools to convert SenseVoice models from PyTorch to ONNX and then to RKNN format
- Implements provider-aware validation to prevent mismatched model and provider usage
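For orientation, the end-to-end flow the new scripts implement is sketched below. This is an illustrative sketch, not part of the PR: it assumes you run it inside scripts/sense-voice/rknn with the SenseVoice checkpoint (model.pt), am.mvn, and the SentencePiece bpe model already downloaded there; the script names and flags are the ones added in this PR.

import subprocess

# 1. Export a fixed-shape ONNX model (RKNN does not support dynamic
#    shapes), hard-coded here to 10 seconds of audio. This writes
#    model-10-seconds.onnx and tokens.txt.
subprocess.run(["./export-onnx.py", "--input-len-in-seconds", "10"], check=True)

# 2. Convert the ONNX model to RKNN format for a specific NPU.
subprocess.run(
    [
        "./export-rknn.py",
        "--target-platform", "rk3588",
        "--in-model", "model-10-seconds.onnx",
        "--out-model", "model-10-seconds.rknn",
    ],
    check=True,
)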
This commit is contained in:
parent 926b288525
commit c691318b95
4 .gitignore (vendored)
@@ -152,3 +152,7 @@ vocab.json
 *.so
 sherpa-onnx-streaming-t-one-russian-2025-09-08
 sherpa-onnx-wenetspeech-yue-u2pp-conformer-ctc-zh-en-cantonese-int8-2025-09-10
+am.mvn
+*bpe.model
+config.yaml
+configuration.json
@@ -118,8 +118,11 @@ def display_params(params):
    os.system(f"cat {params['config']}")


@torch.no_grad()
def main():
    model, params = SenseVoiceSmall.from_pretrained(model="iic/SenseVoiceSmall", device="cpu")
    model.eval()

    display_params(params)

    generate_tokens(params)
164 scripts/sense-voice/rknn/export-onnx.py (new executable file)
@@ -0,0 +1,164 @@
#!/usr/bin/env python3
# Copyright      2025  Xiaomi Corp.        (authors: Fangjun Kuang)

import argparse
import os
from typing import Any, Dict, List, Tuple

import onnx
import sentencepiece as spm
import torch

from torch_model import SenseVoiceSmall


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--input-len-in-seconds",
        type=int,
        required=True,
        help="""RKNN does not support dynamic shape, so we need to hard-code
        how long the model can process.
        """,
    )
    return parser.parse_args()


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


def load_cmvn(filename) -> Tuple[List[float], List[float]]:
    neg_mean = None
    inv_stddev = None

    with open(filename) as f:
        for line in f:
            if not line.startswith("<LearnRateCoef>"):
                continue
            t = line.split()[3:-1]

            if neg_mean is None:
                neg_mean = list(map(lambda x: float(x), t))
            else:
                inv_stddev = list(map(lambda x: float(x), t))

    return neg_mean, inv_stddev


def generate_tokens(sp):
    with open("tokens.txt", "w", encoding="utf-8") as f:
        for i in range(sp.vocab_size()):
            f.write(f"{sp.id_to_piece(i)} {i}\n")
    print("saved to tokens.txt")


@torch.no_grad()
def main():
    args = get_args()
    print(vars(args))

    sp = spm.SentencePieceProcessor()
    sp.load("./chn_jpn_yue_eng_ko_spectok.bpe.model")
    vocab_size = sp.vocab_size()
    generate_tokens(sp)

    print("loading model")

    state_dict = torch.load("./model.pt")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    neg_mean, inv_stddev = load_cmvn("./am.mvn")

    neg_mean = torch.tensor(neg_mean, dtype=torch.float32)
    inv_stddev = torch.tensor(inv_stddev, dtype=torch.float32)

    model = SenseVoiceSmall(neg_mean=neg_mean, inv_stddev=inv_stddev)
    model.load_state_dict(state_dict)
    model.eval()
    del state_dict

    lfr_window_size = 7
    lfr_window_shift = 6

    # frame shift is 10ms, so 1 second has about 100 feature frames
    input_len_in_seconds = int(args.input_len_in_seconds)
    num_frames = input_len_in_seconds * 100
    print("num_frames", num_frames)

    # num_input_frames is an approximate number
    num_input_frames = int(num_frames / lfr_window_shift + 0.5)
    print("num_input_frames", num_input_frames)

    x = torch.randn(1, num_input_frames, 560, dtype=torch.float32)

    language = 3
    text_norm = 15
    prompt = torch.tensor([language, 1, 2, text_norm], dtype=torch.int32)

    opset_version = 13
    filename = f"model-{input_len_in_seconds}-seconds.onnx"
    torch.onnx.export(
        model,
        (x, prompt),
        filename,
        opset_version=opset_version,
        input_names=["x", "prompt"],
        output_names=["logits"],
        dynamic_axes={},
    )

    model_author = os.environ.get("model_author", "iic")
    comment = os.environ.get("comment", "iic/SenseVoiceSmall")
    url = os.environ.get("url", "https://huggingface.co/FunAudioLLM/SenseVoiceSmall")

    meta_data = {
        "lfr_window_size": lfr_window_size,
        "lfr_window_shift": lfr_window_shift,
        "num_input_frames": num_input_frames,
        "normalize_samples": 0,  # input should be in the range [-32768, 32767]
        "model_type": "sense_voice_ctc",
        "version": "1",
        "model_author": model_author,
        "maintainer": "k2-fsa",
        "vocab_size": vocab_size,
        "comment": comment,
        "lang_auto": model.lid_dict["auto"],
        "lang_zh": model.lid_dict["zh"],
        "lang_en": model.lid_dict["en"],
        "lang_yue": model.lid_dict["yue"],  # cantonese
        "lang_ja": model.lid_dict["ja"],
        "lang_ko": model.lid_dict["ko"],
        "lang_nospeech": model.lid_dict["nospeech"],
        "with_itn": model.textnorm_dict["withitn"],
        "without_itn": model.textnorm_dict["woitn"],
        "url": url,
    }
    add_meta_data(filename=filename, meta_data=meta_data)


if __name__ == "__main__":
    torch.manual_seed(20250717)
    main()
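To make the fixed-shape bookkeeping concrete, here is the arithmetic the script performs for --input-len-in-seconds 10, as a standalone snippet (values derived from the code above):

input_len_in_seconds = 10
num_frames = input_len_in_seconds * 100  # 10 ms frame shift -> 1000 fbank frames
lfr_window_shift = 6
num_input_frames = int(num_frames / lfr_window_shift + 0.5)  # 167 (approximate)
feature_dim = 80 * 7  # each LFR frame stacks 7 consecutive 80-dim fbank frames

print(num_input_frames, feature_dim)  # 167 560
# The exported model therefore always expects x of shape (1, 167, 560).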
158 scripts/sense-voice/rknn/export-rknn.py (new executable file)
@@ -0,0 +1,158 @@
#!/usr/bin/env python3
# Copyright (c)  2025  Xiaomi Corporation (authors: Fangjun Kuang)

import argparse
import logging
from pathlib import Path

from rknn.api import RKNN

logging.basicConfig(level=logging.WARNING)

g_platforms = [
    #  "rv1103",
    #  "rv1103b",
    #  "rv1106",
    #  "rk2118",
    "rk3562",
    "rk3566",
    "rk3568",
    "rk3576",
    "rk3588",
]


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--target-platform",
        type=str,
        required=True,
        help=f"Supported values are: {','.join(g_platforms)}",
    )

    parser.add_argument(
        "--in-model",
        type=str,
        required=True,
        help="Path to the input onnx model",
    )

    parser.add_argument(
        "--out-model",
        type=str,
        required=True,
        help="Path to the output rknn model",
    )

    return parser


def get_meta_data(model: str):
    import onnxruntime

    session_opts = onnxruntime.SessionOptions()
    session_opts.inter_op_num_threads = 1
    session_opts.intra_op_num_threads = 1

    m = onnxruntime.InferenceSession(
        model,
        sess_options=session_opts,
        providers=["CPUExecutionProvider"],
    )

    for i in m.get_inputs():
        print(i)

    print("-----")

    for i in m.get_outputs():
        print(i)
    print()

    meta = m.get_modelmeta().custom_metadata_map
    s = ""
    sep = ""
    for key, value in meta.items():
        if key in ("neg_mean", "inv_stddev"):
            continue
        s = s + sep + f"{key}={value}"
        sep = ";"
    assert len(s) < 1024, len(s)

    print("len(s)", len(s), s)

    return s


def export_rknn(rknn, filename):
    ret = rknn.export_rknn(filename)
    if ret != 0:
        exit(f"Export rknn model to {filename} failed!")


def init_model(filename: str, target_platform: str, custom_string=None):
    rknn = RKNN(verbose=False)

    rknn.config(
        optimization_level=0,
        target_platform=target_platform,
        custom_string=custom_string,
    )
    if not Path(filename).is_file():
        exit(f"{filename} does not exist")

    ret = rknn.load_onnx(model=filename)
    if ret != 0:
        exit(f"Load model {filename} failed!")

    ret = rknn.build(do_quantization=False)
    if ret != 0:
        exit(f"Build model {filename} failed!")

    return rknn


class RKNNModel:
    def __init__(
        self,
        model: str,
        target_platform: str,
    ):
        meta = get_meta_data(model)
        print(meta)

        self.model = init_model(
            model,
            target_platform=target_platform,
            custom_string=meta,
        )

    def export_rknn(self, model):
        export_rknn(self.model, model)

    def release(self):
        self.model.release()


def main():
    args = get_parser().parse_args()
    print(vars(args))

    model = RKNNModel(
        model=args.in_model,
        target_platform=args.target_platform,
    )

    model.export_rknn(
        model=args.out_model,
    )

    model.release()


if __name__ == "__main__":
    main()
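The notable trick above is that get_meta_data() flattens the ONNX metadata into a single key=value;key=value string and stores it via RKNN's custom_string option, since RKNN models have no metadata map of their own; the Parse() helper used in offline-sense-voice-model-rknn.cc below reads it back at load time. A minimal sketch of the round trip, with made-up values:

meta = {"model_type": "sense_voice_ctc", "vocab_size": 25055, "with_itn": 14}

# pack, as get_meta_data() does (the assert above caps it at 1024 characters)
s = ";".join(f"{k}={v}" for k, v in meta.items())

# unpack, as the C++ runtime conceptually does
parsed = dict(kv.split("=", 1) for kv in s.split(";") if kv)
assert parsed["vocab_size"] == "25055"  # values come back as strings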
236 scripts/sense-voice/rknn/test_onnx.py (new executable file)
@@ -0,0 +1,236 @@
#!/usr/bin/env python3
# Copyright      2025  Xiaomi Corp.        (authors: Fangjun Kuang)

import argparse
from typing import Tuple

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to model.onnx",
    )

    parser.add_argument(
        "--tokens",
        type=str,
        required=True,
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--wave",
        type=str,
        required=True,
        help="The input wave to be recognized",
    )

    parser.add_argument(
        "--language",
        type=str,
        default="auto",
        help="The language of the input wav file. Supported values: zh, en, ja, ko, yue, auto",
    )

    parser.add_argument(
        "--use-itn",
        type=int,
        default=0,
        help="1 to use inverse text normalization. 0 to not use inverse text normalization",
    )

    return parser.parse_args()


class OnnxModel:
    def __init__(self, filename):
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        self.session_opts = session_opts

        self.model = ort.InferenceSession(
            filename,
            sess_options=self.session_opts,
            providers=["CPUExecutionProvider"],
        )

        meta = self.model.get_modelmeta().custom_metadata_map

        self.window_size = int(meta["lfr_window_size"])  # lfr_m
        self.window_shift = int(meta["lfr_window_shift"])  # lfr_n

        lang_zh = int(meta["lang_zh"])
        lang_en = int(meta["lang_en"])
        lang_ja = int(meta["lang_ja"])
        lang_ko = int(meta["lang_ko"])
        lang_yue = int(meta["lang_yue"])
        lang_auto = int(meta["lang_auto"])

        self.lang_id = {
            "zh": lang_zh,
            "en": lang_en,
            "ja": lang_ja,
            "ko": lang_ko,
            "yue": lang_yue,
            "auto": lang_auto,
        }
        self.with_itn = int(meta["with_itn"])
        self.without_itn = int(meta["without_itn"])

        self.max_len = self.model.get_inputs()[0].shape[1]

    def __call__(self, x, prompt):
        logits = self.model.run(
            [
                self.model.get_outputs()[0].name,
            ],
            {
                self.model.get_inputs()[0].name: x.numpy(),
                self.model.get_inputs()[1].name: prompt.numpy(),
            },
        )[0]

        return torch.from_numpy(logits)


def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    data, sample_rate = sf.read(
        filename,
        always_2d=True,
        dtype="float32",
    )
    data = data[:, 0]  # use only the first channel
    samples = np.ascontiguousarray(data)
    return samples, sample_rate


def load_tokens(filename):
    ans = dict()
    i = 0
    with open(filename, encoding="utf-8") as f:
        for line in f:
            ans[i] = line.strip().split()[0]
            i += 1
    return ans


def compute_feat(
    samples,
    sample_rate,
    max_len: int,
    window_size: int = 7,  # lfr_m
    window_shift: int = 6,  # lfr_n
):
    opts = knf.FbankOptions()
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.window_type = "hamming"
    opts.frame_opts.samp_freq = sample_rate
    opts.mel_opts.num_bins = 80

    online_fbank = knf.OnlineFbank(opts)
    online_fbank.accept_waveform(sample_rate, (samples * 32768).tolist())
    online_fbank.input_finished()

    features = np.stack(
        [online_fbank.get_frame(i) for i in range(online_fbank.num_frames_ready)]
    )
    assert features.data.contiguous is True
    assert features.dtype == np.float32, features.dtype

    T = (features.shape[0] - window_size) // window_shift + 1
    features = np.lib.stride_tricks.as_strided(
        features,
        shape=(T, features.shape[1] * window_size),
        strides=((window_shift * features.shape[1]) * 4, 4),
    )

    print("features.shape", features.shape)

    if features.shape[0] > max_len:
        features = features[:max_len]
    elif features.shape[0] < max_len:
        features = np.pad(
            features,
            ((0, max_len - features.shape[0]), (0, 0)),
            mode="constant",
            constant_values=0,
        )

    print("features.shape", features.shape)

    return features


def main():
    args = get_args()
    print(vars(args))
    samples, sample_rate = load_audio(args.wave)
    if sample_rate != 16000:
        import librosa

        samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000

    model = OnnxModel(filename=args.model)

    features = compute_feat(
        samples=samples,
        sample_rate=sample_rate,
        max_len=model.max_len,
        window_size=model.window_size,
        window_shift=model.window_shift,
    )

    features = torch.from_numpy(features).unsqueeze(0)

    language = model.lang_id["auto"]
    if args.language in model.lang_id:
        language = model.lang_id[args.language]
    else:
        print(f"Invalid language: '{args.language}'")
        print("Use auto")

    if args.use_itn:
        text_norm = model.with_itn
    else:
        text_norm = model.without_itn

    prompt = torch.tensor([language, 1, 2, text_norm], dtype=torch.int32)

    logits = model(
        x=features,
        prompt=prompt,
    )

    idx = logits.squeeze(0).argmax(dim=-1)
    # idx is of shape (T,)
    idx = torch.unique_consecutive(idx)

    blank_id = 0
    idx = idx[idx != blank_id].tolist()

    tokens = load_tokens(args.tokens)
    text = "".join([tokens[i] for i in idx])

    text = text.replace("▁", " ")
    print(text)


if __name__ == "__main__":
    main()
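The as_strided call in compute_feat() is a compact implementation of LFR (low frame rate) stacking: output row i is window_size consecutive 80-dim fbank frames flattened, and consecutive rows start window_shift frames apart. An equivalent naive loop, useful as a self-contained sanity check (illustrative, not part of the script):

import numpy as np


def lfr_naive(features, window_size=7, window_shift=6):
    T = (features.shape[0] - window_size) // window_shift + 1
    out = np.empty((T, features.shape[1] * window_size), dtype=features.dtype)
    for i in range(T):
        start = i * window_shift
        out[i] = features[start : start + window_size].reshape(-1)
    return out


feats = np.random.rand(100, 80).astype(np.float32)  # 100 frames, 80 mel bins
T = (feats.shape[0] - 7) // 6 + 1
strided = np.lib.stride_tricks.as_strided(
    feats, shape=(T, 80 * 7), strides=((6 * 80) * 4, 4)  # 4 = float32 itemsize
)
assert np.array_equal(strided, lfr_naive(feats))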
604 scripts/sense-voice/rknn/torch_model.py (new file)
@@ -0,0 +1,604 @@
# This file is modified from
# https://github.com/modelscope/FunASR/blob/main/funasr/models/sense_voice/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class SinusoidalPositionEncoder(nn.Module):
    """ """

    def __init__(self, d_model=80, dropout_rate=0.1):
        super().__init__()

    def encode(
        self,
        positions: torch.Tensor = None,
        depth: int = None,
        dtype: torch.dtype = torch.float32,
    ):
        """
        Args:
          positions: (batch_size, )
        """
        batch_size = positions.size(0)
        positions = positions.type(dtype)
        device = positions.device
        log_timescale_increment = torch.log(
            torch.tensor([10000], dtype=dtype, device=device)
        ) / (depth / 2 - 1)
        inv_timescales = torch.exp(
            torch.arange(depth / 2, device=device).type(dtype)
            * (-log_timescale_increment)
        )
        inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
            inv_timescales, [1, 1, -1]
        )
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        return encoding.type(dtype)

    def forward(self, x):
        batch_size, timesteps, input_dim = x.size()
        positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)

        return x + position_encoding


class PositionwiseFeedForward(nn.Module):
    """Positionwise feed forward layer.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=None):
        super().__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        if activation is None:
            activation = torch.nn.ReLU()
        self.activation = activation

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))


class MultiHeadedAttentionSANM(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(
        self,
        n_head,
        in_feat,
        n_feat,
        dropout_rate,
        kernel_size,
        sanm_shfit=0,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
    ):
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # padding
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask

        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time1, d_k)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
            1, 2
        )  # (batch, head, time2, d_k)

        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder

            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)

            min_value = -float(
                "inf"
            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory


class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(
        self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None
    ):
        """Compute encoded features.

        Args:
            x_input (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat(
                (
                    x,
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    ),
                ),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
                return x, mask
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder


class LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    mask = mask.detach()

    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)


class SenseVoiceEncoderSmall(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_size = 80 * 7
        self.output_size = 512
        self.attention_heads = 4
        self.linear_units = 2048
        self.num_blocks = 50
        self.tp_blocks = 20
        self.input_layer = "pe"
        self.pos_enc_class = "SinusoidalPositionEncoder"
        self.normalize_before = True
        self.kernel_size = 11
        self.sanm_shfit = 0
        self.concat_after = False
        self.positionwise_layer_type = "linear"
        self.positionwise_conv_kernel_size = 1
        self.padding_idx = -1
        self.selfattention_layer_type = "sanm"
        self.dropout_rate = 0.1
        self.attention_dropout_rate = 0.1

        self._output_size = self.output_size

        self.embed = SinusoidalPositionEncoder()

        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            self.output_size,
            self.linear_units,
            self.dropout_rate,
        )

        encoder_selfattn_layer = MultiHeadedAttentionSANM
        encoder_selfattn_layer_args0 = (
            self.attention_heads,
            self.input_size,
            self.output_size,
            self.attention_dropout_rate,
            self.kernel_size,
            self.sanm_shfit,
        )
        encoder_selfattn_layer_args = (
            self.attention_heads,
            self.output_size,
            self.output_size,
            self.attention_dropout_rate,
            self.kernel_size,
            self.sanm_shfit,
        )

        self.encoders0 = nn.ModuleList(
            [
                EncoderLayerSANM(
                    self.input_size,
                    self.output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                    positionwise_layer(*positionwise_layer_args),
                    self.dropout_rate,
                )
                for i in range(1)
            ]
        )

        self.encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    self.output_size,
                    self.output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    self.dropout_rate,
                )
                for i in range(self.num_blocks - 1)
            ]
        )

        self.tp_encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    self.output_size,
                    self.output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    self.dropout_rate,
                )
                for i in range(self.tp_blocks)
            ]
        )

        self.after_norm = LayerNorm(self.output_size)

        self.tp_norm = LayerNorm(self.output_size)

    def forward(
        self,
        xs_pad: torch.Tensor,
    ):
        masks = None

        xs_pad *= self.output_size**0.5

        xs_pad = self.embed(xs_pad)

        # forward encoder1
        for layer_idx, encoder_layer in enumerate(self.encoders0):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.after_norm(xs_pad)

        for layer_idx, encoder_layer in enumerate(self.tp_encoders):
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.tp_norm(xs_pad)
        return xs_pad


class CTC(nn.Module):
    def __init__(
        self,
        odim: int,
        encoder_output_size: int,
        dropout_rate: float = 0.0,
        ctc_type: str = "builtin",
        reduce: bool = True,
        ignore_nan_grad: bool = True,
        extra_linear: bool = True,
    ):
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate

        if extra_linear:
            self.ctc_lo = torch.nn.Linear(eprojs, odim)
        else:
            self.ctc_lo = None

    def softmax(self, hs_pad):
        """softmax of frame activations

        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
        """
        if self.ctc_lo is not None:
            return F.softmax(self.ctc_lo(hs_pad), dim=2)
        else:
            return F.softmax(hs_pad, dim=2)

    def log_softmax(self, hs_pad):
        """log_softmax of frame activations

        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        if self.ctc_lo is not None:
            return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
        else:
            return F.log_softmax(hs_pad, dim=2)

    def argmax(self, hs_pad):
        """argmax of frame activations

        Args:
            torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        if self.ctc_lo is not None:
            return torch.argmax(self.ctc_lo(hs_pad), dim=2)
        else:
            return torch.argmax(hs_pad, dim=2)


class SenseVoiceSmall(nn.Module):
    def __init__(self, neg_mean: torch.Tensor, inv_stddev: torch.Tensor):
        super().__init__()
        self.sos = 1
        self.eos = 2
        self.length_normalized_loss = True
        self.ignore_id = -1
        self.blank_id = 0
        self.input_size = 80 * 7
        self.vocab_size = 25055

        self.neg_mean = neg_mean.unsqueeze(0).unsqueeze(0)
        self.inv_stddev = inv_stddev.unsqueeze(0).unsqueeze(0)

        self.lid_dict = {
            "auto": 0,
            "zh": 3,
            "en": 4,
            "yue": 7,
            "ja": 11,
            "ko": 12,
            "nospeech": 13,
        }
        self.lid_int_dict = {
            24884: 3,
            24885: 4,
            24888: 7,
            24892: 11,
            24896: 12,
            24992: 13,
        }
        self.textnorm_dict = {"withitn": 14, "woitn": 15}
        self.textnorm_int_dict = {25016: 14, 25017: 15}

        self.emo_dict = {
            "unk": 25009,
            "happy": 25001,
            "sad": 25002,
            "angry": 25003,
            "neutral": 25004,
        }

        self.encoder = SenseVoiceEncoderSmall()
        self.ctc = CTC(
            odim=self.vocab_size,
            encoder_output_size=self.encoder.output_size,
        )
        self.embed = torch.nn.Embedding(
            7 + len(self.lid_dict) + len(self.textnorm_dict), self.input_size
        )

    def forward(self, x, prompt):
        input_query = self.embed(prompt).unsqueeze(0)

        # for export, we always assume x and self.neg_mean are on CPU
        x = (x + self.neg_mean) * self.inv_stddev
        x = torch.cat((input_query, x), dim=1)

        encoder_out = self.encoder(x)
        logits = self.ctc.ctc_lo(encoder_out)

        return logits
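A quick smoke test of the model's input/output contract (a sketch with random weights and placeholder CMVN stats, not part of the file): the four prompt ids are embedded and prepended to the features, so the logits carry four extra frames.

import torch
from torch_model import SenseVoiceSmall

neg_mean = torch.zeros(560)  # placeholders; real values come from am.mvn
inv_stddev = torch.ones(560)
model = SenseVoiceSmall(neg_mean=neg_mean, inv_stddev=inv_stddev).eval()

x = torch.randn(1, 166, 560)  # (batch, LFR frames, 80 * 7)
prompt = torch.tensor([3, 1, 2, 15], dtype=torch.int32)  # zh, without ITN
with torch.no_grad():
    logits = model(x, prompt)
print(logits.shape)  # torch.Size([1, 170, 25055]): 4 prompt frames + 166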
@@ -173,6 +173,8 @@ list(APPEND sources
 )
 if(SHERPA_ONNX_ENABLE_RKNN)
   list(APPEND sources
+    ./rknn/offline-ctc-greedy-search-decoder-rknn.cc
+    ./rknn/offline-sense-voice-model-rknn.cc
     ./rknn/online-stream-rknn.cc
     ./rknn/online-transducer-greedy-search-decoder-rknn.cc
     ./rknn/online-transducer-modified-beam-search-decoder-rknn.cc
@@ -7,6 +7,7 @@
 
 #include "sherpa-onnx/csrc/file-utils.h"
 #include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/text-utils.h"
 
 namespace sherpa_onnx {
 
@@ -57,9 +58,37 @@ void OfflineModelConfig::Register(ParseOptions *po) {
 }
 
 bool OfflineModelConfig::Validate() const {
-  if (num_threads < 1) {
-    SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
-    return false;
+  // For RK NPU, we reinterpret num_threads:
+  //
+  // For RK3588 only
+  //  num_threads == 1 -> Select a core randomly
+  //  num_threads == 0 -> Use NPU core 0
+  //  num_threads == -1 -> Use NPU core 1
+  //  num_threads == -2 -> Use NPU core 2
+  //  num_threads == -3 -> Use NPU core 0 and core 1
+  //  num_threads == -4 -> Use NPU core 0, core 1, and core 2
+  if (provider != "rknn") {
+    if (num_threads < 1) {
+      SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
+      return false;
+    }
+    if (!sense_voice.model.empty() && (EndsWith(sense_voice.model, ".rknn"))) {
+      SHERPA_ONNX_LOGE(
+          "--provider is %s, which is not rknn, but you pass a rknn model "
+          "filename. model: '%s'",
+          provider.c_str(), sense_voice.model.c_str());
+      return false;
+    }
+  }
+
+  if (provider == "rknn") {
+    if (!sense_voice.model.empty() && (EndsWith(sense_voice.model, ".onnx"))) {
+      SHERPA_ONNX_LOGE(
+          "--provider is rknn, but you pass an onnx model "
+          "filename. model: '%s'",
+          sense_voice.model.c_str());
+      return false;
+    }
+  }
 
   if (!FileExists(tokens)) {
 
@@ -35,10 +35,32 @@
 #include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
 #include "sherpa-onnx/csrc/text-utils.h"
 
+#if SHERPA_ONNX_ENABLE_RKNN
+#include "sherpa-onnx/csrc/rknn/offline-recognizer-sense-voice-rknn-impl.h"
+#endif
+
 namespace sherpa_onnx {
 
 std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     const OfflineRecognizerConfig &config) {
+  if (config.model_config.provider == "rknn") {
+#if SHERPA_ONNX_ENABLE_RKNN
+    if (config.model_config.sense_voice.model.empty()) {
+      SHERPA_ONNX_LOGE(
+          "Only SenseVoice models are currently supported "
+          "by rknn for non-streaming ASR. Fallback to CPU");
+    } else if (!config.model_config.sense_voice.model.empty()) {
+      return std::make_unique<OfflineRecognizerSenseVoiceRknnImpl>(config);
+    }
+#else
+    SHERPA_ONNX_LOGE(
+        "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
+        "want to use rknn.");
+    SHERPA_ONNX_EXIT(-1);
+    return nullptr;
+#endif
+  }
+
   if (!config.model_config.sense_voice.model.empty()) {
     return std::make_unique<OfflineRecognizerSenseVoiceImpl>(config);
   }
@@ -229,6 +251,24 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
 template <typename Manager>
 std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
     Manager *mgr, const OfflineRecognizerConfig &config) {
+  if (config.model_config.provider == "rknn") {
+#if SHERPA_ONNX_ENABLE_RKNN
+    if (config.model_config.sense_voice.model.empty()) {
+      SHERPA_ONNX_LOGE(
+          "Only SenseVoice models are currently supported "
+          "by rknn for non-streaming ASR. Fallback to CPU");
+    } else if (!config.model_config.sense_voice.model.empty()) {
+      return std::make_unique<OfflineRecognizerSenseVoiceRknnImpl>(mgr, config);
+    }
+#else
+    SHERPA_ONNX_LOGE(
+        "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
+        "want to use rknn.");
+    SHERPA_ONNX_EXIT(-1);
+    return nullptr;
+#endif
+  }
+
   if (!config.model_config.sense_voice.model.empty()) {
     return std::make_unique<OfflineRecognizerSenseVoiceImpl>(mgr, config);
   }
@@ -11,6 +11,7 @@
 #include <utility>
 #include <vector>
 
+#include "sherpa-onnx/csrc/macros.h"
 #include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
 #include "sherpa-onnx/csrc/offline-model-config.h"
 #include "sherpa-onnx/csrc/offline-recognizer-impl.h"
@@ -21,7 +22,7 @@
 
 namespace sherpa_onnx {
 
-static OfflineRecognitionResult ConvertSenseVoiceResult(
+OfflineRecognitionResult ConvertSenseVoiceResult(
     const OfflineCtcDecoderResult &src, const SymbolTable &sym_table,
     int32_t frame_shift_ms, int32_t subsampling_factor) {
   OfflineRecognitionResult r;
@@ -72,7 +73,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
     } else {
       SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
                        config.decoding_method.c_str());
-      exit(-1);
+      SHERPA_ONNX_EXIT(-1);
     }
 
     InitFeatConfig();
@@ -93,7 +94,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
     } else {
       SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
                        config.decoding_method.c_str());
-      exit(-1);
+      SHERPA_ONNX_EXIT(-1);
     }
 
     InitFeatConfig();
@@ -37,7 +37,6 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
     const OnlineRecognizerConfig &config) {
   if (config.model_config.provider_config.provider == "rknn") {
 #if SHERPA_ONNX_ENABLE_RKNN
-    // Currently, only zipformer v1 is suported for rknn
     if (config.model_config.transducer.encoder.empty() &&
         config.model_config.zipformer2_ctc.model.empty()) {
       SHERPA_ONNX_LOGE(
@@ -0,0 +1,38 @@
// sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.cc
//
// Copyright (c)  2025  Xiaomi Corporation

#include "sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

OfflineCtcDecoderResult OfflineCtcGreedySearchDecoderRknn::Decode(
    const float *logits, int32_t num_frames, int32_t vocab_size) {
  OfflineCtcDecoderResult ans;

  int64_t prev_id = -1;

  for (int32_t t = 0; t != num_frames; ++t) {
    auto y = static_cast<int64_t>(std::distance(
        static_cast<const float *>(logits),
        std::max_element(static_cast<const float *>(logits),
                         static_cast<const float *>(logits) + vocab_size)));
    logits += vocab_size;

    if (y != blank_id_ && y != prev_id) {
      ans.tokens.push_back(y);
      ans.timestamps.push_back(t);
    }
    prev_id = y;
  }  // for (int32_t t = 0; ...)

  return ans;
}

}  // namespace sherpa_onnx
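The decode rule above is standard CTC greedy search: take the per-frame argmax, emit a token only when it differs from the previous frame's id, and never emit the blank. A tiny worked example with illustrative ids (blank_id = 0):

ids = [0, 5, 5, 0, 7, 7, 7, 0, 5]  # per-frame argmax over the logits

tokens, timestamps, prev = [], [], -1
for t, y in enumerate(ids):
    if y != 0 and y != prev:  # skip blanks and repeats
        tokens.append(y)
        timestamps.append(t)
    prev = y

print(tokens, timestamps)  # [5, 7, 5] [1, 4, 8]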
@@ -0,0 +1,28 @@
// sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h
//
// Copyright (c)  2025  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_
#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_

#include <vector>

#include "sherpa-onnx/csrc/offline-ctc-decoder.h"

namespace sherpa_onnx {

class OfflineCtcGreedySearchDecoderRknn {
 public:
  explicit OfflineCtcGreedySearchDecoderRknn(int32_t blank_id)
      : blank_id_(blank_id) {}

  OfflineCtcDecoderResult Decode(const float *logits, int32_t num_frames,
                                 int32_t vocab_size);

 private:
  int32_t blank_id_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_
138 sherpa-onnx/csrc/rknn/offline-recognizer-sense-voice-rknn-impl.h (new file)
@@ -0,0 +1,138 @@
// sherpa-onnx/csrc/rknn/offline-recognizer-sense-voice-rknn-impl.h
//
// Copyright (c)  2025  Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_
#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_

#include <memory>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h"
#include "sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h"
#include "sherpa-onnx/csrc/symbol-table.h"

namespace sherpa_onnx {

// defined in ../offline-recognizer-sense-voice-impl.h
OfflineRecognitionResult ConvertSenseVoiceResult(
    const OfflineCtcDecoderResult &src, const SymbolTable &sym_table,
    int32_t frame_shift_ms, int32_t subsampling_factor);

class OfflineRecognizerSenseVoiceRknnImpl : public OfflineRecognizerImpl {
 public:
  explicit OfflineRecognizerSenseVoiceRknnImpl(
      const OfflineRecognizerConfig &config)
      : OfflineRecognizerImpl(config),
        config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(
            std::make_unique<OfflineSenseVoiceModelRknn>(config.model_config)) {
    const auto &meta_data = model_->GetModelMetadata();
    if (config.decoding_method == "greedy_search") {
      decoder_ = std::make_unique<OfflineCtcGreedySearchDecoderRknn>(
          meta_data.blank_id);
    } else {
      SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
                       config.decoding_method.c_str());
      SHERPA_ONNX_EXIT(-1);
    }

    InitFeatConfig();
  }

  template <typename Manager>
  OfflineRecognizerSenseVoiceRknnImpl(Manager *mgr,
                                      const OfflineRecognizerConfig &config)
      : OfflineRecognizerImpl(mgr, config),
        config_(config),
        symbol_table_(mgr, config_.model_config.tokens),
        model_(std::make_unique<OfflineSenseVoiceModelRknn>(
            mgr, config.model_config)) {
    const auto &meta_data = model_->GetModelMetadata();
    if (config.decoding_method == "greedy_search") {
      decoder_ = std::make_unique<OfflineCtcGreedySearchDecoderRknn>(
          meta_data.blank_id);
    } else {
      SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
                       config.decoding_method.c_str());
      SHERPA_ONNX_EXIT(-1);
    }

    InitFeatConfig();
  }

  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(config_.feat_config);
  }

  void DecodeStreams(OfflineStream **ss, int32_t n) const override {
    for (int32_t i = 0; i < n; ++i) {
      DecodeOneStream(ss[i]);
    }
  }

  OfflineRecognizerConfig GetConfig() const override { return config_; }

 private:
  void InitFeatConfig() {
    const auto &meta_data = model_->GetModelMetadata();

    config_.feat_config.normalize_samples = meta_data.normalize_samples;
    config_.feat_config.window_type = "hamming";
    config_.feat_config.high_freq = 0;
    config_.feat_config.snip_edges = true;
  }

  void DecodeOneStream(OfflineStream *s) const {
    const auto &meta_data = model_->GetModelMetadata();

    std::vector<float> f = s->GetFrames();

    int32_t language = 0;
    if (config_.model_config.sense_voice.language.empty()) {
      language = 0;
    } else if (meta_data.lang2id.count(
                   config_.model_config.sense_voice.language)) {
      language =
          meta_data.lang2id.at(config_.model_config.sense_voice.language);
    } else {
      SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.",
                       config_.model_config.sense_voice.language.c_str());
    }

    int32_t text_norm = config_.model_config.sense_voice.use_itn
                            ? meta_data.with_itn_id
                            : meta_data.without_itn_id;

    std::vector<float> logits = model_->Run(std::move(f), language, text_norm);
    int32_t num_out_frames = logits.size() / meta_data.vocab_size;

    auto result =
        decoder_->Decode(logits.data(), num_out_frames, meta_data.vocab_size);

    int32_t frame_shift_ms = 10;
    int32_t subsampling_factor = meta_data.window_shift;
    auto r = ConvertSenseVoiceResult(result, symbol_table_, frame_shift_ms,
                                     subsampling_factor);

    r.text = ApplyInverseTextNormalization(std::move(r.text));
    r.text = ApplyHomophoneReplacer(std::move(r.text));
    s->SetResult(r);
  }

 private:
  OfflineRecognizerConfig config_;
  SymbolTable symbol_table_;
  std::unique_ptr<OfflineSenseVoiceModelRknn> model_;
  std::unique_ptr<OfflineCtcGreedySearchDecoderRknn> decoder_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_
244 sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.cc (new file)
@@ -0,0 +1,244 @@
// sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.cc
//
// Copyright (c)  2025  Xiaomi Corporation

#include "sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h"

#include <algorithm>
#include <array>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/rknn/macros.h"
#include "sherpa-onnx/csrc/rknn/utils.h"

namespace sherpa_onnx {

class OfflineSenseVoiceModelRknn::Impl {
 public:
  ~Impl() {
    auto ret = rknn_destroy(ctx_);
    if (ret != RKNN_SUCC) {
      SHERPA_ONNX_LOGE("Failed to destroy the context");
    }
  }

  explicit Impl(const OfflineModelConfig &config) : config_(config) {
    {
      auto buf = ReadFile(config_.sense_voice.model);
      Init(buf.data(), buf.size());
    }

    SetCoreMask(ctx_, config_.num_threads);
  }

  template <typename Manager>
  Impl(Manager *mgr, const OfflineModelConfig &config) : config_(config) {
    {
      auto buf = ReadFile(mgr, config_.sense_voice.model);
      Init(buf.data(), buf.size());
    }

    SetCoreMask(ctx_, config_.num_threads);
  }

  const OfflineSenseVoiceModelMetaData &GetModelMetadata() const {
    return meta_data_;
  }

  std::vector<float> Run(std::vector<float> features, int32_t language,
                         int32_t text_norm) {
    features = ApplyLFR(std::move(features));

    std::vector<rknn_input> inputs(input_attrs_.size());

    std::array<int32_t, 4> prompt{language, 1, 2, text_norm};

    inputs[0].index = input_attrs_[0].index;
    inputs[0].type = RKNN_TENSOR_FLOAT32;
    inputs[0].fmt = input_attrs_[0].fmt;
    inputs[0].buf = reinterpret_cast<void *>(features.data());
    inputs[0].size = features.size() * sizeof(float);

    inputs[1].index = input_attrs_[1].index;
    inputs[1].type = RKNN_TENSOR_INT32;
    inputs[1].fmt = input_attrs_[1].fmt;
    inputs[1].buf = reinterpret_cast<void *>(prompt.data());
    inputs[1].size = prompt.size() * sizeof(int32_t);

    std::vector<float> out(output_attrs_[0].n_elems);

    std::vector<rknn_output> outputs(output_attrs_.size());
    outputs[0].index = output_attrs_[0].index;
    outputs[0].is_prealloc = 1;
    outputs[0].want_float = 1;
    outputs[0].size = out.size() * sizeof(float);
    outputs[0].buf = reinterpret_cast<void *>(out.data());

    rknn_context ctx = 0;
    auto ret = rknn_dup_context(&ctx_, &ctx);
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the ctx");

    ret = rknn_inputs_set(ctx, inputs.size(), inputs.data());
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");

    ret = rknn_run(ctx, nullptr);
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");

    ret = rknn_outputs_get(ctx, outputs.size(), outputs.data(), nullptr);
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");

    rknn_destroy(ctx);

    return out;
  }

 private:
  void Init(void *model_data, size_t model_data_length) {
    InitContext(model_data, model_data_length, config_.debug, &ctx_);

    InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_);

    rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug);

    auto meta = Parse(custom_string, config_.debug);

#define SHERPA_ONNX_RKNN_READ_META_DATA_INT(dst, src_key)                     \
  do {                                                                        \
    if (!meta.count(#src_key)) {                                              \
      SHERPA_ONNX_LOGE("'%s' does not exist in the custom_string", #src_key); \
      SHERPA_ONNX_EXIT(-1);                                                   \
    }                                                                         \
                                                                              \
    dst = atoi(meta.at(#src_key).c_str());                                    \
  } while (0)

    SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.with_itn_id, with_itn);
    SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.without_itn_id, without_itn);
    SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.window_size,
                                        lfr_window_size);
    SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.window_shift,
                                        lfr_window_shift);
    SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.vocab_size, vocab_size);
    SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.normalize_samples,
                                        normalize_samples);

    int32_t lang_auto = 0;
    int32_t lang_zh = 0;
    int32_t lang_en = 0;
    int32_t lang_ja = 0;
    int32_t lang_ko = 0;
    int32_t lang_yue = 0;

    SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_auto, lang_auto);
    SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_zh, lang_zh);
    SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_en, lang_en);
    SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_ja, lang_ja);
    SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_ko, lang_ko);
    SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_yue, lang_yue);

    meta_data_.lang2id = {
        {"auto", lang_auto}, {"zh", lang_zh}, {"en", lang_en},
        {"ja", lang_ja},     {"ko", lang_ko}, {"yue", lang_yue},
    };

    // for rknn models, neg_mean and inv_stddev are stored inside the model

#undef SHERPA_ONNX_RKNN_READ_META_DATA_INT

    num_input_frames_ = input_attrs_[0].dims[1];
  }

  std::vector<float> ApplyLFR(std::vector<float> in) const {
    int32_t lfr_window_size = meta_data_.window_size;
    int32_t lfr_window_shift = meta_data_.window_shift;
    int32_t in_feat_dim = 80;

    int32_t in_num_frames = in.size() / in_feat_dim;
    int32_t out_num_frames =
        (in_num_frames - lfr_window_size) / lfr_window_shift + 1;

    if (out_num_frames > num_input_frames_) {
      SHERPA_ONNX_LOGE(
          "Number of input frames %d is too large. Truncate it to %d frames.",
          out_num_frames, num_input_frames_);

      SHERPA_ONNX_LOGE(
          "Recognition result may be truncated/incomplete. Please select a "
          "model accepting longer audios.");

      out_num_frames = num_input_frames_;
    }

    int32_t out_feat_dim = in_feat_dim * lfr_window_size;

    std::vector<float> out(num_input_frames_ * out_feat_dim);

    const float *p_in = in.data();
    float *p_out = out.data();

    for (int32_t i = 0; i != out_num_frames; ++i) {
      std::copy(p_in, p_in + out_feat_dim, p_out);

      p_out += out_feat_dim;
      p_in += lfr_window_shift * in_feat_dim;
    }

    return out;
  }

 private:
  OfflineModelConfig config_;

  rknn_context ctx_ = 0;

  std::vector<rknn_tensor_attr> input_attrs_;
  std::vector<rknn_tensor_attr> output_attrs_;

  OfflineSenseVoiceModelMetaData meta_data_;
  int32_t num_input_frames_ = -1;
};

OfflineSenseVoiceModelRknn::~OfflineSenseVoiceModelRknn() = default;

OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn(
    const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

template <typename Manager>
OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn(
    Manager *mgr, const OfflineModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}

std::vector<float> OfflineSenseVoiceModelRknn::Run(std::vector<float> features,
                                                   int32_t language,
                                                   int32_t text_norm) const {
  return impl_->Run(std::move(features), language, text_norm);
}

const OfflineSenseVoiceModelMetaData &
OfflineSenseVoiceModelRknn::GetModelMetadata() const {
  return impl_->GetModelMetadata();
}

#if __ANDROID_API__ >= 9
template OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn(
    AAssetManager *mgr, const OfflineModelConfig &config);
#endif

#if __OHOS__
template OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn(
    NativeResourceManager *mgr, const OfflineModelConfig &config);
#endif

}  // namespace sherpa_onnx
45 sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h (new file)
@@ -0,0 +1,45 @@
// sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h
//
// Copyright (c)  2025  Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_
#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_

#include <memory>
#include <utility>
#include <vector>

#include "rknn_api.h"  // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h"

namespace sherpa_onnx {

class OfflineSenseVoiceModelRknn {
 public:
  ~OfflineSenseVoiceModelRknn();

  explicit OfflineSenseVoiceModelRknn(const OfflineModelConfig &config);

  template <typename Manager>
  OfflineSenseVoiceModelRknn(Manager *mgr, const OfflineModelConfig &config);

  /**
   * @param features A tensor of shape (num_frames, feature_dim)
   *                 before applying LFR.
   * @param language
   * @param text_norm
   * @returns Return a tensor of shape (num_output_frames, vocab_size)
   */
  std::vector<float> Run(std::vector<float> features, int32_t language,
                         int32_t text_norm) const;

  const OfflineSenseVoiceModelMetaData &GetModelMetadata() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_
@@ -55,8 +55,6 @@ class OnlineZipformerCtcModelRknn::Impl {
     SetCoreMask(ctx_, config_.num_threads);
   }
 
-  // TODO(fangjun): Support Android
-
   std::vector<std::vector<uint8_t>> GetInitStates() const {
     // input_attrs_[0] is for the feature
     // input_attrs_[1:] is for states
@@ -89,8 +89,6 @@ class OnlineZipformerTransducerModelRknn::Impl {
     SetCoreMask(joiner_ctx_, config_.num_threads);
   }
 
-  // TODO(fangjun): Support Android
-
   std::vector<std::vector<uint8_t>> GetEncoderInitStates() const {
     // encoder_input_attrs_[0] is for the feature
     // encoder_input_attrs_[1:] is for states