Support RK NPU for SenseVoice non-streaming ASR models (#2589)

This PR adds RK NPU support for SenseVoice non-streaming ASR models by implementing a new RKNN backend with greedy CTC decoding.

- Adds an offline RKNN implementation for SenseVoice models, covering model loading, feature processing, and CTC greedy-search decoding
- Introduces export tools to convert SenseVoice models from PyTorch to ONNX and then to RKNN format
- Implements provider-aware validation to prevent mismatched model and provider combinations
Fangjun Kuang 2025-09-12 10:46:38 +08:00 committed by GitHub
parent 926b288525
commit c691318b95
18 changed files with 1740 additions and 11 deletions
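
A minimal usage sketch (not part of this commit) of decoding a converted model through the sherpa-onnx Python API. The `provider` keyword and all file names below are assumptions; check the installed sherpa-onnx version for the exact `from_sense_voice` signature.

```python
# Hedged usage sketch: decode one mono wave file with a converted SenseVoice
# RKNN model via the sherpa-onnx Python API. File names are placeholders and
# the `provider` argument is assumed to route to the new RKNN backend.
import sherpa_onnx
import soundfile as sf

recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
    model="./model.rknn",   # hypothetical path to the converted model
    tokens="./tokens.txt",
    provider="rknn",        # assumption: selects the RKNN backend added here
    num_threads=1,          # on RK3588, reinterpreted as NPU core selection (see below)
    language="auto",
    use_itn=True,
)

samples, sample_rate = sf.read("./test.wav", dtype="float32")  # assumes a mono wave
stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, samples)
recognizer.decode_stream(stream)
print(stream.result.text)
```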

.gitignore

@ -152,3 +152,7 @@ vocab.json
*.so
sherpa-onnx-streaming-t-one-russian-2025-09-08
sherpa-onnx-wenetspeech-yue-u2pp-conformer-ctc-zh-en-cantonese-int8-2025-09-10
am.mvn
*bpe.model
config.yaml
configuration.json


@ -118,8 +118,11 @@ def display_params(params):
os.system(f"cat {params['config']}")
@torch.no_grad()
def main():
model, params = SenseVoiceSmall.from_pretrained(model="iic/SenseVoiceSmall", device="cpu")
model.eval()
display_params(params)
generate_tokens(params)


@ -0,0 +1,164 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
import os
from typing import Any, Dict, List, Tuple
import onnx
import sentencepiece as spm
import torch
from torch_model import SenseVoiceSmall
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--input-len-in-seconds",
type=int,
required=True,
help="""RKNN does not support dynamic shape, so we need to hard-code
how long the model can process.
""",
)
return parser.parse_args()
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
"""Add meta data to an ONNX model. It is changed in-place.
Args:
filename:
Filename of the ONNX model to be changed.
meta_data:
Key-value pairs.
"""
model = onnx.load(filename)
while len(model.metadata_props):
model.metadata_props.pop()
for key, value in meta_data.items():
meta = model.metadata_props.add()
meta.key = key
meta.value = str(value)
onnx.save(model, filename)
def load_cmvn(filename) -> Tuple[List[float], List[float]]:
neg_mean = None
inv_stddev = None
with open(filename) as f:
for line in f:
if not line.startswith("<LearnRateCoef>"):
continue
t = line.split()[3:-1]
if neg_mean is None:
neg_mean = list(map(lambda x: float(x), t))
else:
inv_stddev = list(map(lambda x: float(x), t))
return neg_mean, inv_stddev
def generate_tokens(sp):
with open("tokens.txt", "w", encoding="utf-8") as f:
for i in range(sp.vocab_size()):
f.write(f"{sp.id_to_piece(i)} {i}\n")
print("saved to tokens.txt")
@torch.no_grad()
def main():
args = get_args()
print(vars(args))
sp = spm.SentencePieceProcessor()
sp.load("./chn_jpn_yue_eng_ko_spectok.bpe.model")
vocab_size = sp.vocab_size()
generate_tokens(sp)
print("loading model")
state_dict = torch.load("./model.pt")
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
neg_mean, inv_stddev = load_cmvn("./am.mvn")
neg_mean = torch.tensor(neg_mean, dtype=torch.float32)
inv_stddev = torch.tensor(inv_stddev, dtype=torch.float32)
model = SenseVoiceSmall(neg_mean=neg_mean, inv_stddev=inv_stddev)
model.load_state_dict(state_dict)
model.eval()
del state_dict
lfr_window_size = 7
lfr_window_shift = 6
# frame shift is 10ms, 1 second has about 100 feature frames
input_len_in_seconds = int(args.input_len_in_seconds)
num_frames = input_len_in_seconds * 100
print("num_frames", num_frames)
# num_input_frames is an approximate number
num_input_frames = int(num_frames / lfr_window_shift + 0.5)
print("num_input_frames", num_input_frames)
x = torch.randn(1, num_input_frames, 560, dtype=torch.float32)
language = 3
text_norm = 15
prompt = torch.tensor([language, 1, 2, text_norm], dtype=torch.int32)
opset_version = 13
filename = f"model-{input_len_in_seconds}-seconds.onnx"
torch.onnx.export(
model,
(x, prompt),
filename,
opset_version=opset_version,
input_names=["x", "prompt"],
output_names=["logits"],
dynamic_axes={},
)
model_author = os.environ.get("model_author", "iic")
comment = os.environ.get("comment", "iic/SenseVoiceSmall")
url = os.environ.get("url", "https://huggingface.co/FunAudioLLM/SenseVoiceSmall")
meta_data = {
"lfr_window_size": lfr_window_size,
"lfr_window_shift": lfr_window_shift,
"num_input_frames": num_input_frames,
"normalize_samples": 0, # input should be in the range [-32768, 32767]
"model_type": "sense_voice_ctc",
"version": "1",
"model_author": model_author,
"maintainer": "k2-fsa",
"vocab_size": vocab_size,
"comment": comment,
"lang_auto": model.lid_dict["auto"],
"lang_zh": model.lid_dict["zh"],
"lang_en": model.lid_dict["en"],
"lang_yue": model.lid_dict["yue"], # cantonese
"lang_ja": model.lid_dict["ja"],
"lang_ko": model.lid_dict["ko"],
"lang_nospeech": model.lid_dict["nospeech"],
"with_itn": model.textnorm_dict["withitn"],
"without_itn": model.textnorm_dict["woitn"],
"url": url,
}
add_meta_data(filename=filename, meta_data=meta_data)
if __name__ == "__main__":
torch.manual_seed(20250717)
main()
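
Since RKNN graphs have static shapes, the script above bakes the maximum audio length into the exported model. A small worked example of the frame arithmetic for a hypothetical 10-second export:

```python
# Illustrative arithmetic only; mirrors the computation in the export script.
lfr_window_size = 7    # each LFR frame stacks 7 fbank frames of 80 dims -> 560
lfr_window_shift = 6   # the LFR window advances 6 fbank frames at a time

input_len_in_seconds = 10
num_frames = input_len_in_seconds * 100                      # 10 ms frame shift
num_input_frames = int(num_frames / lfr_window_shift + 0.5)  # approximate, as noted above

print(num_frames)            # 1000 fbank frames
print(num_input_frames)      # 167 LFR frames seen by the fixed-shape model
print(lfr_window_size * 80)  # 560, the per-frame feature dimension
```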


@ -0,0 +1,158 @@
#!/usr/bin/env python3
# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang)
import argparse
import logging
from pathlib import Path
from rknn.api import RKNN
logging.basicConfig(level=logging.WARNING)
g_platforms = [
# "rv1103",
# "rv1103b",
# "rv1106",
# "rk2118",
"rk3562",
"rk3566",
"rk3568",
"rk3576",
"rk3588",
]
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--target-platform",
type=str,
required=True,
help=f"Supported values are: {','.join(g_platforms)}",
)
parser.add_argument(
"--in-model",
type=str,
required=True,
help="Path to the input onnx model",
)
parser.add_argument(
"--out-model",
type=str,
required=True,
help="Path to the output rknn model",
)
return parser
def get_meta_data(model: str):
import onnxruntime
session_opts = onnxruntime.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
m = onnxruntime.InferenceSession(
model,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)
for i in m.get_inputs():
print(i)
print("-----")
for i in m.get_outputs():
print(i)
print()
meta = m.get_modelmeta().custom_metadata_map
s = ""
sep = ""
for key, value in meta.items():
if key in ("neg_mean", "inv_stddev"):
continue
s = s + sep + f"{key}={value}"
sep = ";"
assert len(s) < 1024, len(s)
print("len(s)", len(s), s)
return s
def export_rknn(rknn, filename):
ret = rknn.export_rknn(filename)
if ret != 0:
exit(f"Export rknn model to {filename} failed!")
def init_model(filename: str, target_platform: str, custom_string=None):
rknn = RKNN(verbose=False)
rknn.config(
optimization_level=0,
target_platform=target_platform,
custom_string=custom_string,
)
if not Path(filename).is_file():
exit(f"{filename} does not exist")
ret = rknn.load_onnx(model=filename)
if ret != 0:
exit(f"Load model {filename} failed!")
ret = rknn.build(do_quantization=False)
if ret != 0:
exit(f"Build model {filename} failed!")
return rknn
class RKNNModel:
def __init__(
self,
model: str,
target_platform: str,
):
meta = get_meta_data(model)
print(meta)
self.model = init_model(
model,
target_platform=target_platform,
custom_string=meta,
)
def export_rknn(self, model):
export_rknn(self.model, model)
def release(self):
self.model.release()
def main():
args = get_parser().parse_args()
print(vars(args))
model = RKNNModel(
model=args.in_model,
target_platform=args.target_platform,
)
model.export_rknn(
model=args.out_model,
)
model.release()
if __name__ == "__main__":
main()
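
The converter flattens the ONNX metadata into a single `key=value;key=value` string (kept under 1024 bytes) and stores it via RKNN's `custom_string`; the C++ runtime parses it back later (see the `Parse` call in `offline-sense-voice-model-rknn.cc` below). An illustrative round trip, using only a few of the real keys for brevity:

```python
# Round-trip sketch of the custom_string packing; the real script also skips
# the neg_mean/inv_stddev entries before packing.
meta = {"model_type": "sense_voice_ctc", "lfr_window_size": 7, "lfr_window_shift": 6}

packed = ";".join(f"{k}={v}" for k, v in meta.items())
assert len(packed) < 1024, len(packed)  # size limit checked in get_meta_data above

parsed = dict(kv.split("=", 1) for kv in packed.split(";"))
print(packed)                      # model_type=sense_voice_ctc;lfr_window_size=7;...
print(parsed["lfr_window_shift"])  # "6" (values come back as strings)
```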


@ -0,0 +1,236 @@
#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
import argparse
from typing import Tuple
import kaldi_native_fbank as knf
import numpy as np
import onnxruntime
import onnxruntime as ort
import soundfile as sf
import torch
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--model",
type=str,
required=True,
help="Path to model.onnx",
)
parser.add_argument(
"--tokens",
type=str,
required=True,
help="Path to tokens.txt",
)
parser.add_argument(
"--wave",
type=str,
required=True,
help="The input wave to be recognized",
)
parser.add_argument(
"--language",
type=str,
default="auto",
help="the language of the input wav file. Supported values: zh, en, ja, ko, yue, auto",
)
parser.add_argument(
"--use-itn",
type=int,
default=0,
help="1 to use inverse text normalization. 0 to not use inverse text normalization",
)
return parser.parse_args()
class OnnxModel:
def __init__(self, filename):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1
self.session_opts = session_opts
self.model = ort.InferenceSession(
filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
meta = self.model.get_modelmeta().custom_metadata_map
self.window_size = int(meta["lfr_window_size"]) # lfr_m
self.window_shift = int(meta["lfr_window_shift"]) # lfr_n
lang_zh = int(meta["lang_zh"])
lang_en = int(meta["lang_en"])
lang_ja = int(meta["lang_ja"])
lang_ko = int(meta["lang_ko"])
lang_yue = int(meta["lang_yue"])
lang_auto = int(meta["lang_auto"])
self.lang_id = {
"zh": lang_zh,
"en": lang_en,
"ja": lang_ja,
"ko": lang_ko,
"yue": lang_yue,
"auto": lang_auto,
}
self.with_itn = int(meta["with_itn"])
self.without_itn = int(meta["without_itn"])
self.max_len = self.model.get_inputs()[0].shape[1]
def __call__(self, x, prompt):
logits = self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: x.numpy(),
self.model.get_inputs()[1].name: prompt.numpy(),
},
)[0]
return torch.from_numpy(logits)
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
data, sample_rate = sf.read(
filename,
always_2d=True,
dtype="float32",
)
data = data[:, 0] # use only the first channel
samples = np.ascontiguousarray(data)
return samples, sample_rate
def load_tokens(filename):
ans = dict()
i = 0
with open(filename, encoding="utf-8") as f:
for line in f:
ans[i] = line.strip().split()[0]
i += 1
return ans
def compute_feat(
samples,
sample_rate,
max_len: int,
window_size: int = 7, # lfr_m
window_shift: int = 6, # lfr_n
):
opts = knf.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.window_type = "hamming"
opts.frame_opts.samp_freq = sample_rate
opts.mel_opts.num_bins = 80
online_fbank = knf.OnlineFbank(opts)
online_fbank.accept_waveform(sample_rate, (samples * 32768).tolist())
online_fbank.input_finished()
features = np.stack(
[online_fbank.get_frame(i) for i in range(online_fbank.num_frames_ready)]
)
assert features.data.contiguous is True
assert features.dtype == np.float32, features.dtype
T = (features.shape[0] - window_size) // window_shift + 1
features = np.lib.stride_tricks.as_strided(
features,
shape=(T, features.shape[1] * window_size),
strides=((window_shift * features.shape[1]) * 4, 4),
)
print("features.shape", features.shape)
if features.shape[0] > max_len:
features = features[:max_len]
elif features.shape[0] < max_len:
features = np.pad(
features,
((0, max_len - features.shape[0]), (0, 0)),
mode="constant",
constant_values=0,
)
print("features.shape", features.shape)
return features
def main():
args = get_args()
print(vars(args))
samples, sample_rate = load_audio(args.wave)
if sample_rate != 16000:
import librosa
samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000)
sample_rate = 16000
model = OnnxModel(filename=args.model)
features = compute_feat(
samples=samples,
sample_rate=sample_rate,
max_len=model.max_len,
window_size=model.window_size,
window_shift=model.window_shift,
)
features = torch.from_numpy(features).unsqueeze(0)
language = model.lang_id["auto"]
if args.language in model.lang_id:
language = model.lang_id[args.language]
else:
print(f"Invalid language: '{args.language}'")
print("Use auto")
if args.use_itn:
text_norm = model.with_itn
else:
text_norm = model.without_itn
prompt = torch.tensor([language, 1, 2, text_norm], dtype=torch.int32)
logits = model(
x=features,
prompt=prompt,
)
idx = logits.squeeze(0).argmax(dim=-1)
# idx is of shape (T,)
idx = torch.unique_consecutive(idx)
blank_id = 0
idx = idx[idx != blank_id].tolist()
tokens = load_tokens(args.tokens)
text = "".join([tokens[i] for i in idx])
text = text.replace("▁", " ")
print(text)
if __name__ == "__main__":
main()
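
The `as_strided` call in `compute_feat` above builds the stacked LFR features without copying: for float32 data, strides of `(window_shift * 80 * 4, 4)` make each output row read 7 consecutive 80-dim frames. A naive reference that produces the same result, handy for verifying the stride arithmetic:

```python
# Naive LFR stacking, equivalent to the as_strided view in compute_feat above
# (illustrative reference only).
import numpy as np

def lfr_naive(features: np.ndarray, window_size: int = 7, window_shift: int = 6):
    T = (features.shape[0] - window_size) // window_shift + 1
    out = np.empty((T, features.shape[1] * window_size), dtype=features.dtype)
    for i in range(T):
        start = i * window_shift
        out[i] = features[start : start + window_size].reshape(-1)
    return out

feats = np.random.rand(100, 80).astype(np.float32)
print(lfr_naive(feats).shape)  # (16, 560)
```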


@ -0,0 +1,604 @@
# This file is modified from
# https://github.com/modelscope/FunASR/blob/main/funasr/models/sense_voice/model.py
import torch
import torch.nn
import torch.nn as nn
import torch.nn.functional as F
class SinusoidalPositionEncoder(nn.Module):
""" """
def __init__(self, d_model=80, dropout_rate=0.1):
pass
def encode(
self,
positions: torch.Tensor = None,
depth: int = None,
dtype: torch.dtype = torch.float32,
):
"""
Args:
positions: (batch_size, )
"""
batch_size = positions.size(0)
positions = positions.type(dtype)
device = positions.device
log_timescale_increment = torch.log(
torch.tensor([10000], dtype=dtype, device=device)
) / (depth / 2 - 1)
inv_timescales = torch.exp(
torch.arange(depth / 2, device=device).type(dtype)
* (-log_timescale_increment)
)
inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
inv_timescales, [1, 1, -1]
)
encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
return encoding.type(dtype)
def forward(self, x):
batch_size, timesteps, input_dim = x.size()
positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
return x + position_encoding
class PositionwiseFeedForward(nn.Module):
"""Positionwise feed forward layer.
Args:
idim (int): Input dimension.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim, hidden_units, dropout_rate, activation=None):
super().__init__()
self.w_1 = torch.nn.Linear(idim, hidden_units)
self.w_2 = torch.nn.Linear(hidden_units, idim)
self.dropout = torch.nn.Dropout(dropout_rate)
if activation is None:
activation = torch.nn.ReLU()
self.activation = activation
def forward(self, x):
"""Forward function."""
return self.w_2(self.dropout(self.activation(self.w_1(x))))
class MultiHeadedAttentionSANM(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(
self,
n_head,
in_feat,
n_feat,
dropout_rate,
kernel_size,
sanm_shfit=0,
lora_list=None,
lora_rank=8,
lora_alpha=16,
lora_dropout=0.1,
):
super().__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_out = nn.Linear(n_feat, n_feat)
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
self.fsmn_block = nn.Conv1d(
n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
)
# padding
left_padding = (kernel_size - 1) // 2
if sanm_shfit > 0:
left_padding = left_padding + sanm_shfit
right_padding = kernel_size - 1 - left_padding
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
b, t, d = inputs.size()
if mask is not None:
mask = torch.reshape(mask, (b, -1, 1))
if mask_shfit_chunk is not None:
mask = mask * mask_shfit_chunk
inputs = inputs * mask
x = inputs.transpose(1, 2)
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x += inputs
x = self.dropout(x)
if mask is not None:
x = x * mask
return x
def forward_qkv(self, x):
"""Transform query, key and value.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
Returns:
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
"""
b, t, d = x.size()
q_k_v = self.linear_q_k_v(x)
q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
1, 2
) # (batch, head, time1, d_k)
k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
1, 2
) # (batch, head, time2, d_k)
v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
1, 2
) # (batch, head, time2, d_k)
return q_h, k_h, v_h, v
def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
if mask is not None:
if mask_att_chunk_encoder is not None:
mask = mask * mask_att_chunk_encoder
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = -float(
"inf"
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
scores = scores.masked_fill(mask, min_value)
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
q_h, k_h, v_h, v = self.forward_qkv(x)
fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
q_h = q_h * self.d_k ** (-0.5)
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
return att_outs + fsmn_memory
class EncoderLayerSANM(nn.Module):
def __init__(
self,
in_size,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(in_size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.in_size = in_size
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
self.dropout_rate = dropout_rate
def forward(
self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None
):
"""Compute encoded features.
Args:
x_input (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_concat = torch.cat(
(
x,
self.self_attn(
x,
mask,
mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder,
),
),
dim=-1,
)
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = stoch_layer_coeff * self.concat_linear(x_concat)
else:
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(
x,
mask,
mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder,
)
)
else:
x = stoch_layer_coeff * self.dropout(
self.self_attn(
x,
mask,
mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder,
)
)
return x, mask
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
class LayerNorm(nn.LayerNorm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.layer_norm(
input.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
if maxlen is None:
maxlen = lengths.max()
row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
matrix = torch.unsqueeze(lengths, dim=-1)
mask = row_vector < matrix
mask = mask.detach()
return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
class SenseVoiceEncoderSmall(nn.Module):
def __init__(self):
super().__init__()
self.input_size = 80 * 7
self.output_size = 512
self.attention_heads = 4
self.linear_units = 2048
self.num_blocks = 50
self.tp_blocks = 20
self.input_layer = "pe"
self.pos_enc_class = "SinusoidalPositionEncoder"
self.normalize_before = True
self.kernel_size = 11
self.sanm_shfit = 0
self.concat_after = False
self.positionwise_layer_type = "linear"
self.positionwise_conv_kernel_size = 1
self.padding_idx = -1
self.selfattention_layer_type = "sanm"
self.dropout_rate = 0.1
self.attention_dropout_rate = 0.1
self._output_size = self.output_size
self.embed = SinusoidalPositionEncoder()
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
self.output_size,
self.linear_units,
self.dropout_rate,
)
encoder_selfattn_layer = MultiHeadedAttentionSANM
encoder_selfattn_layer_args0 = (
self.attention_heads,
self.input_size,
self.output_size,
self.attention_dropout_rate,
self.kernel_size,
self.sanm_shfit,
)
encoder_selfattn_layer_args = (
self.attention_heads,
self.output_size,
self.output_size,
self.attention_dropout_rate,
self.kernel_size,
self.sanm_shfit,
)
self.encoders0 = nn.ModuleList(
[
EncoderLayerSANM(
self.input_size,
self.output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
self.dropout_rate,
)
for i in range(1)
]
)
self.encoders = nn.ModuleList(
[
EncoderLayerSANM(
self.output_size,
self.output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
self.dropout_rate,
)
for i in range(self.num_blocks - 1)
]
)
self.tp_encoders = nn.ModuleList(
[
EncoderLayerSANM(
self.output_size,
self.output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
self.dropout_rate,
)
for i in range(self.tp_blocks)
]
)
self.after_norm = LayerNorm(self.output_size)
self.tp_norm = LayerNorm(self.output_size)
def forward(
self,
xs_pad: torch.Tensor,
):
masks = None
xs_pad *= self.output_size**0.5
xs_pad = self.embed(xs_pad)
# forward encoder1
for layer_idx, encoder_layer in enumerate(self.encoders0):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
xs_pad = self.after_norm(xs_pad)
for layer_idx, encoder_layer in enumerate(self.tp_encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
xs_pad = self.tp_norm(xs_pad)
return xs_pad
class CTC(nn.Module):
def __init__(
self,
odim: int,
encoder_output_size: int,
dropout_rate: float = 0.0,
ctc_type: str = "builtin",
reduce: bool = True,
ignore_nan_grad: bool = True,
extra_linear: bool = True,
):
super().__init__()
eprojs = encoder_output_size
self.dropout_rate = dropout_rate
if extra_linear:
self.ctc_lo = torch.nn.Linear(eprojs, odim)
else:
self.ctc_lo = None
def softmax(self, hs_pad):
"""softmax of frame activations
Args:
Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Returns:
torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
"""
if self.ctc_lo is not None:
return F.softmax(self.ctc_lo(hs_pad), dim=2)
else:
return F.softmax(hs_pad, dim=2)
def log_softmax(self, hs_pad):
"""log_softmax of frame activations
Args:
Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Returns:
torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
"""
if self.ctc_lo is not None:
return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
else:
return F.log_softmax(hs_pad, dim=2)
def argmax(self, hs_pad):
"""argmax of frame activations
Args:
torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
Returns:
torch.Tensor: argmax applied 2d tensor (B, Tmax)
"""
if self.ctc_lo is not None:
return torch.argmax(self.ctc_lo(hs_pad), dim=2)
else:
return torch.argmax(hs_pad, dim=2)
class SenseVoiceSmall(nn.Module):
def __init__(self, neg_mean: torch.Tensor, inv_stddev: torch.Tensor):
super().__init__()
self.sos = 1
self.eos = 2
self.length_normalized_loss = True
self.ignore_id = -1
self.blank_id = 0
self.input_size = 80 * 7
self.vocab_size = 25055
self.neg_mean = neg_mean.unsqueeze(0).unsqueeze(0)
self.inv_stddev = inv_stddev.unsqueeze(0).unsqueeze(0)
self.lid_dict = {
"auto": 0,
"zh": 3,
"en": 4,
"yue": 7,
"ja": 11,
"ko": 12,
"nospeech": 13,
}
self.lid_int_dict = {
24884: 3,
24885: 4,
24888: 7,
24892: 11,
24896: 12,
24992: 13,
}
self.textnorm_dict = {"withitn": 14, "woitn": 15}
self.textnorm_int_dict = {25016: 14, 25017: 15}
self.emo_dict = {
"unk": 25009,
"happy": 25001,
"sad": 25002,
"angry": 25003,
"neutral": 25004,
}
self.encoder = SenseVoiceEncoderSmall()
self.ctc = CTC(
odim=self.vocab_size,
encoder_output_size=self.encoder.output_size,
)
self.embed = torch.nn.Embedding(
7 + len(self.lid_dict) + len(self.textnorm_dict), self.input_size
)
def forward(self, x, prompt):
input_query = self.embed(prompt).unsqueeze(0)
# for export, we always assume x and self.neg_mean are on CPU
x = (x + self.neg_mean) * self.inv_stddev
x = torch.cat((input_query, x), dim=1)
encoder_out = self.encoder(x)
logits = self.ctc.ctc_lo(encoder_out)
return logits
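
A quick shape check of the forward pass, assuming the file above is saved as `torch_model.py` (the name the export script imports). The weights here are randomly initialized, so only the shapes are meaningful:

```python
# Shape-check sketch: the 4-token prompt is embedded and prepended to the
# normalized features, so the CTC logits have 4 more frames than the input.
import torch
from torch_model import SenseVoiceSmall  # the file shown above

neg_mean = torch.zeros(560)
inv_stddev = torch.ones(560)
model = SenseVoiceSmall(neg_mean=neg_mean, inv_stddev=inv_stddev).eval()

x = torch.randn(1, 167, 560)                             # 167 LFR frames, 7 * 80 dims
prompt = torch.tensor([3, 1, 2, 15], dtype=torch.int32)  # [language, sos, eos, text_norm]
with torch.no_grad():
    logits = model(x, prompt)
print(logits.shape)  # torch.Size([1, 171, 25055]); 171 = 167 + 4 prompt frames
```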


@ -173,6 +173,8 @@ list(APPEND sources
)
if(SHERPA_ONNX_ENABLE_RKNN)
list(APPEND sources
./rknn/offline-ctc-greedy-search-decoder-rknn.cc
./rknn/offline-sense-voice-model-rknn.cc
./rknn/online-stream-rknn.cc
./rknn/online-transducer-greedy-search-decoder-rknn.cc
./rknn/online-transducer-modified-beam-search-decoder-rknn.cc


@ -7,6 +7,7 @@
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
@ -57,9 +58,37 @@ void OfflineModelConfig::Register(ParseOptions *po) {
}
bool OfflineModelConfig::Validate() const {
if (num_threads < 1) {
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
return false;
// For RK NPU, we reinterpret num_threads:
//
// For RK3588 only
// num_threads == 1 -> Select a core randomly
// num_threads == 0 -> Use NPU core 0
// num_threads == -1 -> Use NPU core 1
// num_threads == -2 -> Use NPU core 2
// num_threads == -3 -> Use NPU core 0 and core 1
// num_threads == -4 -> Use NPU core 0, core 1, and core 2
if (provider != "rknn") {
if (num_threads < 1) {
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
return false;
}
if (!sense_voice.model.empty() && (EndsWith(sense_voice.model, ".rknn"))) {
SHERPA_ONNX_LOGE(
"--provider is %s, which is not rknn, but you pass a rknn model "
"filename. model: '%s'",
provider.c_str(), sense_voice.model.c_str());
return false;
}
}
if (provider == "rknn") {
if (!sense_voice.model.empty() && (EndsWith(sense_voice.model, ".onnx"))) {
SHERPA_ONNX_LOGE(
"--provider is rknn, but you pass an onnx model "
"filename. model: '%s'",
sense_voice.model.c_str());
return false;
}
}
if (!FileExists(tokens)) {
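
For the rknn provider, `num_threads` is repurposed as an NPU core selector (the table in the comment above applies to RK3588, the only listed chip with three NPU cores); the translation into an actual core mask happens in the RKNN backend's `SetCoreMask` helper used by the model code below. The same table as a lookup, for reference:

```python
# Reference-only mapping of --num-threads values when --provider=rknn on
# RK3588; this restates the comment in OfflineModelConfig::Validate.
RK3588_NUM_THREADS_TO_NPU_CORES = {
     1: "pick one core at random",
     0: "NPU core 0",
    -1: "NPU core 1",
    -2: "NPU core 2",
    -3: "NPU cores 0 and 1",
    -4: "NPU cores 0, 1 and 2",
}
```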


@ -35,10 +35,32 @@
#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
#include "sherpa-onnx/csrc/text-utils.h"
#if SHERPA_ONNX_ENABLE_RKNN
#include "sherpa-onnx/csrc/rknn/offline-recognizer-sense-voice-rknn-impl.h"
#endif
namespace sherpa_onnx {
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
const OfflineRecognizerConfig &config) {
if (config.model_config.provider == "rknn") {
#if SHERPA_ONNX_ENABLE_RKNN
if (config.model_config.sense_voice.model.empty()) {
SHERPA_ONNX_LOGE(
"Only SenseVoice models are currently supported "
"by rknn for non-streaming ASR. Fallback to CPU");
} else if (!config.model_config.sense_voice.model.empty()) {
return std::make_unique<OfflineRecognizerSenseVoiceRknnImpl>(config);
}
#else
SHERPA_ONNX_LOGE(
"Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
"want to use rknn.");
SHERPA_ONNX_EXIT(-1);
return nullptr;
#endif
}
if (!config.model_config.sense_voice.model.empty()) {
return std::make_unique<OfflineRecognizerSenseVoiceImpl>(config);
}
@ -229,6 +251,24 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
template <typename Manager>
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
Manager *mgr, const OfflineRecognizerConfig &config) {
if (config.model_config.provider == "rknn") {
#if SHERPA_ONNX_ENABLE_RKNN
if (config.model_config.sense_voice.model.empty()) {
SHERPA_ONNX_LOGE(
"Only SenseVoice models are currently supported "
"by rknn for non-streaming ASR. Fallback to CPU");
} else if (!config.model_config.sense_voice.model.empty()) {
return std::make_unique<OfflineRecognizerSenseVoiceRknnImpl>(mgr, config);
}
#else
SHERPA_ONNX_LOGE(
"Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
"want to use rknn.");
SHERPA_ONNX_EXIT(-1);
return nullptr;
#endif
}
if (!config.model_config.sense_voice.model.empty()) {
return std::make_unique<OfflineRecognizerSenseVoiceImpl>(mgr, config);
}


@ -11,6 +11,7 @@
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
@ -21,7 +22,7 @@
namespace sherpa_onnx {
static OfflineRecognitionResult ConvertSenseVoiceResult(
OfflineRecognitionResult ConvertSenseVoiceResult(
const OfflineCtcDecoderResult &src, const SymbolTable &sym_table,
int32_t frame_shift_ms, int32_t subsampling_factor) {
OfflineRecognitionResult r;
@ -72,7 +73,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
} else {
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
config.decoding_method.c_str());
exit(-1);
SHERPA_ONNX_EXIT(-1);
}
InitFeatConfig();
@ -93,7 +94,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl {
} else {
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
config.decoding_method.c_str());
exit(-1);
SHERPA_ONNX_EXIT(-1);
}
InitFeatConfig();


@ -37,7 +37,6 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
const OnlineRecognizerConfig &config) {
if (config.model_config.provider_config.provider == "rknn") {
#if SHERPA_ONNX_ENABLE_RKNN
// Currently, only zipformer v1 is supported for rknn
if (config.model_config.transducer.encoder.empty() &&
config.model_config.zipformer2_ctc.model.empty()) {
SHERPA_ONNX_LOGE(


@ -0,0 +1,38 @@
// sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h"
#include <algorithm>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
OfflineCtcDecoderResult OfflineCtcGreedySearchDecoderRknn::Decode(
const float *logits, int32_t num_frames, int32_t vocab_size) {
OfflineCtcDecoderResult ans;
int64_t prev_id = -1;
for (int32_t t = 0; t != num_frames; ++t) {
auto y = static_cast<int64_t>(std::distance(
static_cast<const float *>(logits),
std::max_element(static_cast<const float *>(logits),
static_cast<const float *>(logits) + vocab_size)));
logits += vocab_size;
if (y != blank_id_ && y != prev_id) {
ans.tokens.push_back(y);
ans.timestamps.push_back(t);
}
prev_id = y;
} // for (int32_t t = 0; ...)
return ans;
}
} // namespace sherpa_onnx
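
The decoder above is plain greedy CTC: take the argmax token of every output frame, collapse consecutive repeats, drop the blank (id 0), and keep the frame index of each surviving token as its timestamp. A NumPy reference that mirrors the C++ loop, useful for cross-checking:

```python
# NumPy mirror of OfflineCtcGreedySearchDecoderRknn::Decode (illustrative).
import numpy as np

def ctc_greedy_search(logits: np.ndarray, blank_id: int = 0):
    """logits: (num_frames, vocab_size) -> (token_ids, frame_indices)."""
    tokens, timestamps = [], []
    prev_id = -1
    for t, frame in enumerate(logits):
        y = int(frame.argmax())
        if y != blank_id and y != prev_id:
            tokens.append(y)
            timestamps.append(t)
        prev_id = y
    return tokens, timestamps

fake_logits = np.random.rand(50, 25055).astype(np.float32)
print(ctc_greedy_search(fake_logits))
```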


@ -0,0 +1,28 @@
// sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_
#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_
#include <vector>
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
namespace sherpa_onnx {
class OfflineCtcGreedySearchDecoderRknn {
public:
explicit OfflineCtcGreedySearchDecoderRknn(int32_t blank_id)
: blank_id_(blank_id) {}
OfflineCtcDecoderResult Decode(const float *logits, int32_t num_frames,
int32_t vocab_size);
private:
int32_t blank_id_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_


@ -0,0 +1,138 @@
// sherpa-onnx/csrc/rknn/offline-recognizer-sense-voice-rknn-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_
#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_
#include <memory>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h"
#include "sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h"
#include "sherpa-onnx/csrc/symbol-table.h"
namespace sherpa_onnx {
// defined in ../offline-recognizer-sense-voice-impl.h
OfflineRecognitionResult ConvertSenseVoiceResult(
const OfflineCtcDecoderResult &src, const SymbolTable &sym_table,
int32_t frame_shift_ms, int32_t subsampling_factor);
class OfflineRecognizerSenseVoiceRknnImpl : public OfflineRecognizerImpl {
public:
explicit OfflineRecognizerSenseVoiceRknnImpl(
const OfflineRecognizerConfig &config)
: OfflineRecognizerImpl(config),
config_(config),
symbol_table_(config_.model_config.tokens),
model_(
std::make_unique<OfflineSenseVoiceModelRknn>(config.model_config)) {
const auto &meta_data = model_->GetModelMetadata();
if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoderRknn>(
meta_data.blank_id);
} else {
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
config.decoding_method.c_str());
SHERPA_ONNX_EXIT(-1);
}
InitFeatConfig();
}
template <typename Manager>
OfflineRecognizerSenseVoiceRknnImpl(Manager *mgr,
const OfflineRecognizerConfig &config)
: OfflineRecognizerImpl(mgr, config),
config_(config),
symbol_table_(mgr, config_.model_config.tokens),
model_(std::make_unique<OfflineSenseVoiceModelRknn>(
mgr, config.model_config)) {
const auto &meta_data = model_->GetModelMetadata();
if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoderRknn>(
meta_data.blank_id);
} else {
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
config.decoding_method.c_str());
SHERPA_ONNX_EXIT(-1);
}
InitFeatConfig();
}
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(config_.feat_config);
}
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
for (int32_t i = 0; i < n; ++i) {
DecodeOneStream(ss[i]);
}
}
OfflineRecognizerConfig GetConfig() const override { return config_; }
private:
void InitFeatConfig() {
const auto &meta_data = model_->GetModelMetadata();
config_.feat_config.normalize_samples = meta_data.normalize_samples;
config_.feat_config.window_type = "hamming";
config_.feat_config.high_freq = 0;
config_.feat_config.snip_edges = true;
}
void DecodeOneStream(OfflineStream *s) const {
const auto &meta_data = model_->GetModelMetadata();
std::vector<float> f = s->GetFrames();
int32_t language = 0;
if (config_.model_config.sense_voice.language.empty()) {
language = 0;
} else if (meta_data.lang2id.count(
config_.model_config.sense_voice.language)) {
language =
meta_data.lang2id.at(config_.model_config.sense_voice.language);
} else {
SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.",
config_.model_config.sense_voice.language.c_str());
}
int32_t text_norm = config_.model_config.sense_voice.use_itn
? meta_data.with_itn_id
: meta_data.without_itn_id;
std::vector<float> logits = model_->Run(std::move(f), language, text_norm);
int32_t num_out_frames = logits.size() / meta_data.vocab_size;
auto result =
decoder_->Decode(logits.data(), num_out_frames, meta_data.vocab_size);
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = meta_data.window_shift;
auto r = ConvertSenseVoiceResult(result, symbol_table_, frame_shift_ms,
subsampling_factor);
r.text = ApplyInverseTextNormalization(std::move(r.text));
r.text = ApplyHomophoneReplacer(std::move(r.text));
s->SetResult(r);
}
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;
std::unique_ptr<OfflineSenseVoiceModelRknn> model_;
std::unique_ptr<OfflineCtcGreedySearchDecoderRknn> decoder_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_
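
In `DecodeOneStream`, the decoded frame indices are converted to time using `frame_shift_ms = 10` and `subsampling_factor = lfr_window_shift`, so one CTC output frame covers roughly 60 ms of audio. A small arithmetic sketch; the final timestamps (including any handling of the 4 prepended prompt frames) come from the shared `ConvertSenseVoiceResult` helper:

```python
# Rough frame-index-to-seconds arithmetic used for SenseVoice timestamps
# (illustrative; ConvertSenseVoiceResult produces the actual values).
frame_shift_ms = 10      # fbank frame shift
subsampling_factor = 6   # lfr_window_shift from the model metadata

def frame_to_seconds(t: int) -> float:
    return t * subsampling_factor * frame_shift_ms / 1000.0

print(frame_to_seconds(25))  # 1.5, i.e. a token at output frame 25 sits near 1.5 s
```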


@ -0,0 +1,244 @@
// sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h"
#include <algorithm>
#include <array>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/rknn/macros.h"
#include "sherpa-onnx/csrc/rknn/utils.h"
namespace sherpa_onnx {
class OfflineSenseVoiceModelRknn::Impl {
public:
~Impl() {
auto ret = rknn_destroy(ctx_);
if (ret != RKNN_SUCC) {
SHERPA_ONNX_LOGE("Failed to destroy the context");
}
}
explicit Impl(const OfflineModelConfig &config) : config_(config) {
{
auto buf = ReadFile(config_.sense_voice.model);
Init(buf.data(), buf.size());
}
SetCoreMask(ctx_, config_.num_threads);
}
template <typename Manager>
Impl(Manager *mgr, const OfflineModelConfig &config) : config_(config) {
{
auto buf = ReadFile(mgr, config_.sense_voice.model);
Init(buf.data(), buf.size());
}
SetCoreMask(ctx_, config_.num_threads);
}
const OfflineSenseVoiceModelMetaData &GetModelMetadata() const {
return meta_data_;
}
std::vector<float> Run(std::vector<float> features, int32_t language,
int32_t text_norm) {
features = ApplyLFR(std::move(features));
std::vector<rknn_input> inputs(input_attrs_.size());
std::array<int32_t, 4> prompt{language, 1, 2, text_norm};
inputs[0].index = input_attrs_[0].index;
inputs[0].type = RKNN_TENSOR_FLOAT32;
inputs[0].fmt = input_attrs_[0].fmt;
inputs[0].buf = reinterpret_cast<void *>(features.data());
inputs[0].size = features.size() * sizeof(float);
inputs[1].index = input_attrs_[1].index;
inputs[1].type = RKNN_TENSOR_INT32;
inputs[1].fmt = input_attrs_[1].fmt;
inputs[1].buf = reinterpret_cast<void *>(prompt.data());
inputs[1].size = prompt.size() * sizeof(int32_t);
std::vector<float> out(output_attrs_[0].n_elems);
std::vector<rknn_output> outputs(output_attrs_.size());
outputs[0].index = output_attrs_[0].index;
outputs[0].is_prealloc = 1;
outputs[0].want_float = 1;
outputs[0].size = out.size() * sizeof(float);
outputs[0].buf = reinterpret_cast<void *>(out.data());
rknn_context ctx = 0;
auto ret = rknn_dup_context(&ctx_, &ctx);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the ctx");
ret = rknn_inputs_set(ctx, inputs.size(), inputs.data());
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");
ret = rknn_run(ctx, nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");
ret = rknn_outputs_get(ctx, outputs.size(), outputs.data(), nullptr);
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");
rknn_destroy(ctx);
return out;
}
private:
void Init(void *model_data, size_t model_data_length) {
InitContext(model_data, model_data_length, config_.debug, &ctx_);
InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_);
rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug);
auto meta = Parse(custom_string, config_.debug);
#define SHERPA_ONNX_RKNN_READ_META_DATA_INT(dst, src_key) \
do { \
if (!meta.count(#src_key)) { \
SHERPA_ONNX_LOGE("'%s' does not exist in the custom_string", #src_key); \
SHERPA_ONNX_EXIT(-1); \
} \
\
dst = atoi(meta.at(#src_key).c_str()); \
} while (0)
SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.with_itn_id, with_itn);
SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.without_itn_id, without_itn);
SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.window_size,
lfr_window_size);
SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.window_shift,
lfr_window_shift);
SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.vocab_size, vocab_size);
SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.normalize_samples,
normalize_samples);
int32_t lang_auto = 0;
int32_t lang_zh = 0;
int32_t lang_en = 0;
int32_t lang_ja = 0;
int32_t lang_ko = 0;
int32_t lang_yue = 0;
SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_auto, lang_auto);
SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_zh, lang_zh);
SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_en, lang_en);
SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_ja, lang_ja);
SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_ko, lang_ko);
SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_yue, lang_yue);
meta_data_.lang2id = {
{"auto", lang_auto}, {"zh", lang_zh}, {"en", lang_en},
{"ja", lang_ja}, {"ko", lang_ko}, {"yue", lang_yue},
};
// for rknn models, neg_mean and inv_stddev are stored inside the model
#undef SHERPA_ONNX_RKNN_READ_META_DATA_INT
num_input_frames_ = input_attrs_[0].dims[1];
}
std::vector<float> ApplyLFR(std::vector<float> in) const {
int32_t lfr_window_size = meta_data_.window_size;
int32_t lfr_window_shift = meta_data_.window_shift;
int32_t in_feat_dim = 80;
int32_t in_num_frames = in.size() / in_feat_dim;
int32_t out_num_frames =
(in_num_frames - lfr_window_size) / lfr_window_shift + 1;
if (out_num_frames > num_input_frames_) {
SHERPA_ONNX_LOGE(
"Number of input frames %d is too large. Truncate it to %d frames.",
out_num_frames, num_input_frames_);
SHERPA_ONNX_LOGE(
"Recognition result may be truncated/incomplete. Please select a "
"model accepting longer audios.");
out_num_frames = num_input_frames_;
}
int32_t out_feat_dim = in_feat_dim * lfr_window_size;
std::vector<float> out(num_input_frames_ * out_feat_dim);
const float *p_in = in.data();
float *p_out = out.data();
for (int32_t i = 0; i != out_num_frames; ++i) {
std::copy(p_in, p_in + out_feat_dim, p_out);
p_out += out_feat_dim;
p_in += lfr_window_shift * in_feat_dim;
}
return out;
}
private:
OfflineModelConfig config_;
rknn_context ctx_ = 0;
std::vector<rknn_tensor_attr> input_attrs_;
std::vector<rknn_tensor_attr> output_attrs_;
OfflineSenseVoiceModelMetaData meta_data_;
int32_t num_input_frames_ = -1;
};
OfflineSenseVoiceModelRknn::~OfflineSenseVoiceModelRknn() = default;
OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn(
const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
template <typename Manager>
OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn(
Manager *mgr, const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
std::vector<float> OfflineSenseVoiceModelRknn::Run(std::vector<float> features,
int32_t language,
int32_t text_norm) const {
return impl_->Run(std::move(features), language, text_norm);
}
const OfflineSenseVoiceModelMetaData &
OfflineSenseVoiceModelRknn::GetModelMetadata() const {
return impl_->GetModelMetadata();
}
#if __ANDROID_API__ >= 9
template OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn(
AAssetManager *mgr, const OfflineModelConfig &config);
#endif
#if __OHOS__
template OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn(
NativeResourceManager *mgr, const OfflineModelConfig &config);
#endif
} // namespace sherpa_onnx
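
`ApplyLFR` always returns exactly `num_input_frames` rows because the RKNN graph has a fixed input shape: longer inputs are truncated (with the warning above) and shorter ones are zero padded. A Python sketch of the same logic:

```python
# Python sketch of OfflineSenseVoiceModelRknn::Impl::ApplyLFR (illustrative).
import numpy as np

def apply_lfr(feats: np.ndarray, num_input_frames: int,
              window_size: int = 7, window_shift: int = 6) -> np.ndarray:
    """feats: (T, 80) float32 -> (num_input_frames, 80 * window_size)."""
    T = (feats.shape[0] - window_size) // window_shift + 1
    out = np.zeros((num_input_frames, feats.shape[1] * window_size), dtype=np.float32)
    n = min(T, num_input_frames)  # truncate audio that is too long for the model
    for i in range(n):
        start = i * window_shift
        out[i] = feats[start : start + window_size].reshape(-1)
    return out

print(apply_lfr(np.random.rand(1000, 80).astype(np.float32), 167).shape)  # (167, 560)
```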


@ -0,0 +1,45 @@
// sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_
#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_
#include <memory>
#include <utility>
#include <vector>
#include "rknn_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h"
namespace sherpa_onnx {
class OfflineSenseVoiceModelRknn {
public:
~OfflineSenseVoiceModelRknn();
explicit OfflineSenseVoiceModelRknn(const OfflineModelConfig &config);
template <typename Manager>
OfflineSenseVoiceModelRknn(Manager *mgr, const OfflineModelConfig &config);
/**
* @param features A tensor of shape (num_frames, feature_dim)
* before applying LFR.
* @param language
* @param text_norm
* @returns Return a tensor of shape (num_output_frames, vocab_size)
*/
std::vector<float> Run(std::vector<float> features, int32_t language,
int32_t text_norm) const;
const OfflineSenseVoiceModelMetaData &GetModelMetadata() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_


@ -55,8 +55,6 @@ class OnlineZipformerCtcModelRknn::Impl {
SetCoreMask(ctx_, config_.num_threads);
}
// TODO(fangjun): Support Android
std::vector<std::vector<uint8_t>> GetInitStates() const {
// input_attrs_[0] is for the feature
// input_attrs_[1:] is for states


@ -89,8 +89,6 @@ class OnlineZipformerTransducerModelRknn::Impl {
SetCoreMask(joiner_ctx_, config_.num_threads);
}
// TODO(fangjun): Support Android
std::vector<std::vector<uint8_t>> GetEncoderInitStates() const {
// encoder_input_attrs_[0] is for the feature
// encoder_input_attrs_[1:] is for states