From c691318b95a7da5ebd5b1781559575d7c90d764f Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 12 Sep 2025 10:46:38 +0800 Subject: [PATCH] Support RK NPU for SenseVoice non-streaming ASR models (#2589) This PR adds RK NPU support for SenseVoice non-streaming ASR models by implementing a new RKNN backend with greedy CTC decoding. - Adds offline RKNN implementation for SenseVoice models including model loading, feature processing, and CTC decoding - Introduces export tools to convert SenseVoice models from PyTorch to ONNX and then to RKNN format - Implements provider-aware validation to prevent mismatched model and provider usage --- .gitignore | 4 + scripts/sense-voice/export-onnx.py | 3 + scripts/sense-voice/rknn/export-onnx.py | 164 +++++ scripts/sense-voice/rknn/export-rknn.py | 158 +++++ scripts/sense-voice/rknn/test_onnx.py | 236 +++++++ scripts/sense-voice/rknn/torch_model.py | 604 ++++++++++++++++++ sherpa-onnx/csrc/CMakeLists.txt | 2 + sherpa-onnx/csrc/offline-model-config.cc | 35 +- sherpa-onnx/csrc/offline-recognizer-impl.cc | 40 ++ .../offline-recognizer-sense-voice-impl.h | 7 +- sherpa-onnx/csrc/online-recognizer-impl.cc | 1 - .../offline-ctc-greedy-search-decoder-rknn.cc | 38 ++ .../offline-ctc-greedy-search-decoder-rknn.h | 28 + ...offline-recognizer-sense-voice-rknn-impl.h | 138 ++++ .../rknn/offline-sense-voice-model-rknn.cc | 244 +++++++ .../rknn/offline-sense-voice-model-rknn.h | 45 ++ .../rknn/online-zipformer-ctc-model-rknn.cc | 2 - .../online-zipformer-transducer-model-rknn.cc | 2 - 18 files changed, 1740 insertions(+), 11 deletions(-) create mode 100755 scripts/sense-voice/rknn/export-onnx.py create mode 100755 scripts/sense-voice/rknn/export-rknn.py create mode 100755 scripts/sense-voice/rknn/test_onnx.py create mode 100644 scripts/sense-voice/rknn/torch_model.py create mode 100644 sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.cc create mode 100644 sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h create mode 100644 sherpa-onnx/csrc/rknn/offline-recognizer-sense-voice-rknn-impl.h create mode 100644 sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.cc create mode 100644 sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h diff --git a/.gitignore b/.gitignore index 8adf2300..f0a4c52e 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,7 @@ vocab.json *.so sherpa-onnx-streaming-t-one-russian-2025-09-08 sherpa-onnx-wenetspeech-yue-u2pp-conformer-ctc-zh-en-cantonese-int8-2025-09-10 +am.mvn +*bpe.model +config.yaml +configuration.json diff --git a/scripts/sense-voice/export-onnx.py b/scripts/sense-voice/export-onnx.py index 48b68636..0153cebd 100755 --- a/scripts/sense-voice/export-onnx.py +++ b/scripts/sense-voice/export-onnx.py @@ -118,8 +118,11 @@ def display_params(params): os.system(f"cat {params['config']}") +@torch.no_grad() def main(): model, params = SenseVoiceSmall.from_pretrained(model="iic/SenseVoiceSmall", device="cpu") + model.eval() + display_params(params) generate_tokens(params) diff --git a/scripts/sense-voice/rknn/export-onnx.py b/scripts/sense-voice/rknn/export-onnx.py new file mode 100755 index 00000000..2adcefa1 --- /dev/null +++ b/scripts/sense-voice/rknn/export-onnx.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +# Copyright 2025 Xiaomi Corp. 
(authors: Fangjun Kuang) + +import argparse +import os +from typing import Any, Dict, List, Tuple + +import onnx +import sentencepiece as spm +import torch + +from torch_model import SenseVoiceSmall + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--input-len-in-seconds", + type=int, + required=True, + help="""RKNN does not support dynamic shape, so we need to hard-code + how long the model can process. + """, + ) + return parser.parse_args() + + +def add_meta_data(filename: str, meta_data: Dict[str, Any]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + while len(model.metadata_props): + model.metadata_props.pop() + + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +def load_cmvn(filename) -> Tuple[List[float], List[float]]: + neg_mean = None + inv_stddev = None + + with open(filename) as f: + for line in f: + if not line.startswith(""): + continue + t = line.split()[3:-1] + + if neg_mean is None: + neg_mean = list(map(lambda x: float(x), t)) + else: + inv_stddev = list(map(lambda x: float(x), t)) + + return neg_mean, inv_stddev + + +def generate_tokens(sp): + with open("tokens.txt", "w", encoding="utf-8") as f: + for i in range(sp.vocab_size()): + f.write(f"{sp.id_to_piece(i)} {i}\n") + print("saved to tokens.txt") + + +@torch.no_grad() +def main(): + args = get_args() + print(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load("./chn_jpn_yue_eng_ko_spectok.bpe.model") + vocab_size = sp.vocab_size() + generate_tokens(sp) + + print("loading model") + + state_dict = torch.load("./model.pt") + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + + neg_mean, inv_stddev = load_cmvn("./am.mvn") + + neg_mean = torch.tensor(neg_mean, dtype=torch.float32) + inv_stddev = torch.tensor(inv_stddev, dtype=torch.float32) + + model = SenseVoiceSmall(neg_mean=neg_mean, inv_stddev=inv_stddev) + model.load_state_dict(state_dict) + model.eval() + del state_dict + + lfr_window_size = 7 + lfr_window_shift = 6 + + # frame shift is 10ms, 1 second has about 100 feature frames + input_len_in_seconds = int(args.input_len_in_seconds) + num_frames = input_len_in_seconds * 100 + print("num_frames", num_frames) + + # num_input_frames is an approximate number + num_input_frames = int(num_frames / lfr_window_shift + 0.5) + print("num_input_frames", num_input_frames) + + x = torch.randn(1, num_input_frames, 560, dtype=torch.float32) + + language = 3 + text_norm = 15 + prompt = torch.tensor([language, 1, 2, text_norm], dtype=torch.int32) + + opset_version = 13 + filename = f"model-{input_len_in_seconds}-seconds.onnx" + torch.onnx.export( + model, + (x, prompt), + filename, + opset_version=opset_version, + input_names=["x", "prompt"], + output_names=["logits"], + dynamic_axes={}, + ) + + model_author = os.environ.get("model_author", "iic") + comment = os.environ.get("comment", "iic/SenseVoiceSmall") + url = os.environ.get("url", "https://huggingface.co/FunAudioLLM/SenseVoiceSmall") + + meta_data = { + "lfr_window_size": lfr_window_size, + "lfr_window_shift": lfr_window_shift, + "num_input_frames": num_input_frames, + "normalize_samples": 0, # input should be in the range [-32768, 32767] + "model_type": "sense_voice_ctc", + "version": "1", + 
"model_author": model_author, + "maintainer": "k2-fsa", + "vocab_size": vocab_size, + "comment": comment, + "lang_auto": model.lid_dict["auto"], + "lang_zh": model.lid_dict["zh"], + "lang_en": model.lid_dict["en"], + "lang_yue": model.lid_dict["yue"], # cantonese + "lang_ja": model.lid_dict["ja"], + "lang_ko": model.lid_dict["ko"], + "lang_nospeech": model.lid_dict["nospeech"], + "with_itn": model.textnorm_dict["withitn"], + "without_itn": model.textnorm_dict["woitn"], + "url": url, + } + add_meta_data(filename=filename, meta_data=meta_data) + + +if __name__ == "__main__": + torch.manual_seed(20250717) + main() diff --git a/scripts/sense-voice/rknn/export-rknn.py b/scripts/sense-voice/rknn/export-rknn.py new file mode 100755 index 00000000..744a3b0e --- /dev/null +++ b/scripts/sense-voice/rknn/export-rknn.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025 Xiaomi Corporation (authors: Fangjun Kuang) + +import argparse +import logging +from pathlib import Path + +from rknn.api import RKNN + +logging.basicConfig(level=logging.WARNING) + +g_platforms = [ + # "rv1103", + # "rv1103b", + # "rv1106", + # "rk2118", + "rk3562", + "rk3566", + "rk3568", + "rk3576", + "rk3588", +] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--target-platform", + type=str, + required=True, + help=f"Supported values are: {','.join(g_platforms)}", + ) + + parser.add_argument( + "--in-model", + type=str, + required=True, + help="Path to the input onnx model", + ) + + parser.add_argument( + "--out-model", + type=str, + required=True, + help="Path to the output rknn model", + ) + + return parser + + +def get_meta_data(model: str): + import onnxruntime + + session_opts = onnxruntime.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + m = onnxruntime.InferenceSession( + model, + sess_options=session_opts, + providers=["CPUExecutionProvider"], + ) + + for i in m.get_inputs(): + print(i) + + print("-----") + + for i in m.get_outputs(): + print(i) + print() + + meta = m.get_modelmeta().custom_metadata_map + s = "" + sep = "" + for key, value in meta.items(): + if key in ("neg_mean", "inv_stddev"): + continue + s = s + sep + f"{key}={value}" + sep = ";" + assert len(s) < 1024, len(s) + + print("len(s)", len(s), s) + + return s + + +def export_rknn(rknn, filename): + ret = rknn.export_rknn(filename) + if ret != 0: + exit(f"Export rknn model to {filename} failed!") + + +def init_model(filename: str, target_platform: str, custom_string=None): + rknn = RKNN(verbose=False) + + rknn.config( + optimization_level=0, + target_platform=target_platform, + custom_string=custom_string, + ) + if not Path(filename).is_file(): + exit(f"{filename} does not exist") + + ret = rknn.load_onnx(model=filename) + if ret != 0: + exit(f"Load model {filename} failed!") + + ret = rknn.build(do_quantization=False) + if ret != 0: + exit(f"Build model {filename} failed!") + + return rknn + + +class RKNNModel: + def __init__( + self, + model: str, + target_platform: str, + ): + meta = get_meta_data(model) + print(meta) + + self.model = init_model( + model, + target_platform=target_platform, + custom_string=meta, + ) + + def export_rknn(self, model): + export_rknn(self.model, model) + + def release(self): + self.model.release() + + +def main(): + args = get_parser().parse_args() + print(vars(args)) + + model = RKNNModel( + model=args.in_model, + target_platform=args.target_platform, + ) + + 
model.export_rknn( + model=args.out_model, + ) + + model.release() + + +if __name__ == "__main__": + main() diff --git a/scripts/sense-voice/rknn/test_onnx.py b/scripts/sense-voice/rknn/test_onnx.py new file mode 100755 index 00000000..d4ac38bc --- /dev/null +++ b/scripts/sense-voice/rknn/test_onnx.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) + +import argparse +from typing import Tuple + +import kaldi_native_fbank as knf +import numpy as np +import onnxruntime +import onnxruntime as ort +import soundfile as sf +import torch + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to model.onnx", + ) + + parser.add_argument( + "--tokens", + type=str, + required=True, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--wave", + type=str, + required=True, + help="The input wave to be recognized", + ) + + parser.add_argument( + "--language", + type=str, + default="auto", + help="the language of the input wav file. Supported values: zh, en, ja, ko, yue, auto", + ) + + parser.add_argument( + "--use-itn", + type=int, + default=0, + help="1 to use inverse text normalization. 0 to not use inverse text normalization", + ) + + return parser.parse_args() + + +class OnnxModel: + def __init__(self, filename): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.session_opts = session_opts + + self.model = ort.InferenceSession( + filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + + meta = self.model.get_modelmeta().custom_metadata_map + + self.window_size = int(meta["lfr_window_size"]) # lfr_m + self.window_shift = int(meta["lfr_window_shift"]) # lfr_n + + lang_zh = int(meta["lang_zh"]) + lang_en = int(meta["lang_en"]) + lang_ja = int(meta["lang_ja"]) + lang_ko = int(meta["lang_ko"]) + lang_yue = int(meta["lang_yue"]) + lang_auto = int(meta["lang_auto"]) + + self.lang_id = { + "zh": lang_zh, + "en": lang_en, + "ja": lang_ja, + "ko": lang_ko, + "yue": lang_yue, + "auto": lang_auto, + } + self.with_itn = int(meta["with_itn"]) + self.without_itn = int(meta["without_itn"]) + + self.max_len = self.model.get_inputs()[0].shape[1] + + def __call__(self, x, prompt): + logits = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: prompt.numpy(), + }, + )[0] + + return torch.from_numpy(logits) + + +def load_audio(filename: str) -> Tuple[np.ndarray, int]: + data, sample_rate = sf.read( + filename, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + samples = np.ascontiguousarray(data) + return samples, sample_rate + + +def load_tokens(filename): + ans = dict() + i = 0 + with open(filename, encoding="utf-8") as f: + for line in f: + ans[i] = line.strip().split()[0] + i += 1 + return ans + + +def compute_feat( + samples, + sample_rate, + max_len: int, + window_size: int = 7, # lfr_m + window_shift: int = 6, # lfr_n +): + opts = knf.FbankOptions() + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.window_type = "hamming" + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + + online_fbank = knf.OnlineFbank(opts) + online_fbank.accept_waveform(sample_rate, (samples * 32768).tolist()) + online_fbank.input_finished() + + 
features = np.stack( + [online_fbank.get_frame(i) for i in range(online_fbank.num_frames_ready)] + ) + assert features.data.contiguous is True + assert features.dtype == np.float32, features.dtype + + T = (features.shape[0] - window_size) // window_shift + 1 + features = np.lib.stride_tricks.as_strided( + features, + shape=(T, features.shape[1] * window_size), + strides=((window_shift * features.shape[1]) * 4, 4), + ) + + print("features.shape", features.shape) + + if features.shape[0] > max_len: + features = features[:max_len] + elif features.shape[0] < max_len: + features = np.pad( + features, + ((0, max_len - features.shape[0]), (0, 0)), + mode="constant", + constant_values=0, + ) + + print("features.shape", features.shape) + + return features + + +def main(): + args = get_args() + print(vars(args)) + samples, sample_rate = load_audio(args.wave) + if sample_rate != 16000: + import librosa + + samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000) + sample_rate = 16000 + + model = OnnxModel(filename=args.model) + + features = compute_feat( + samples=samples, + sample_rate=sample_rate, + max_len=model.max_len, + window_size=model.window_size, + window_shift=model.window_shift, + ) + + features = torch.from_numpy(features).unsqueeze(0) + + language = model.lang_id["auto"] + if args.language in model.lang_id: + language = model.lang_id[args.language] + else: + print(f"Invalid language: '{args.language}'") + print("Use auto") + + if args.use_itn: + text_norm = model.with_itn + else: + text_norm = model.without_itn + + prompt = torch.tensor([language, 1, 2, text_norm], dtype=torch.int32) + + logits = model( + x=features, + prompt=prompt, + ) + + idx = logits.squeeze(0).argmax(dim=-1) + # idx is of shape (T,) + idx = torch.unique_consecutive(idx) + + blank_id = 0 + idx = idx[idx != blank_id].tolist() + + tokens = load_tokens(args.tokens) + text = "".join([tokens[i] for i in idx]) + + text = text.replace("▁", " ") + print(text) + + +if __name__ == "__main__": + main() diff --git a/scripts/sense-voice/rknn/torch_model.py b/scripts/sense-voice/rknn/torch_model.py new file mode 100644 index 00000000..b1a750d7 --- /dev/null +++ b/scripts/sense-voice/rknn/torch_model.py @@ -0,0 +1,604 @@ +# This file is modified from +# https://github.com/modelscope/FunASR/blob/main/funasr/models/sense_voice/model.py + +import torch +import torch.nn +import torch.nn as nn +import torch.nn.functional as F + + +class SinusoidalPositionEncoder(nn.Module): + """ """ + + def __init__(self, d_model=80, dropout_rate=0.1): + pass + + def encode( + self, + positions: torch.Tensor = None, + depth: int = None, + dtype: torch.dtype = torch.float32, + ): + """ + Args: + positions: (batch_size, ) + """ + batch_size = positions.size(0) + positions = positions.type(dtype) + device = positions.device + log_timescale_increment = torch.log( + torch.tensor([10000], dtype=dtype, device=device) + ) / (depth / 2 - 1) + inv_timescales = torch.exp( + torch.arange(depth / 2, device=device).type(dtype) + * (-log_timescale_increment) + ) + inv_timescales = torch.reshape(inv_timescales, [batch_size, -1]) + scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape( + inv_timescales, [1, 1, -1] + ) + encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) + return encoding.type(dtype) + + def forward(self, x): + batch_size, timesteps, input_dim = x.size() + positions = torch.arange(1, timesteps + 1, device=x.device)[None, :] + position_encoding = self.encode(positions, input_dim, 
x.dtype).to(x.device) + + return x + position_encoding + + +class PositionwiseFeedForward(nn.Module): + """Positionwise feed forward layer. + + Args: + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, idim, hidden_units, dropout_rate, activation=None): + super().__init__() + self.w_1 = torch.nn.Linear(idim, hidden_units) + self.w_2 = torch.nn.Linear(hidden_units, idim) + self.dropout = torch.nn.Dropout(dropout_rate) + if activation is None: + activation = torch.nn.ReLU() + self.activation = activation + + def forward(self, x): + """Forward function.""" + return self.w_2(self.dropout(self.activation(self.w_1(x)))) + + +class MultiHeadedAttentionSANM(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + + def __init__( + self, + n_head, + in_feat, + n_feat, + dropout_rate, + kernel_size, + sanm_shfit=0, + lora_list=None, + lora_rank=8, + lora_alpha=16, + lora_dropout=0.1, + ): + super().__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_out = nn.Linear(n_feat, n_feat) + self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3) + self.attn = None + self.dropout = nn.Dropout(p=dropout_rate) + + self.fsmn_block = nn.Conv1d( + n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False + ) + # padding + left_padding = (kernel_size - 1) // 2 + if sanm_shfit > 0: + left_padding = left_padding + sanm_shfit + right_padding = kernel_size - 1 - left_padding + self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) + + def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None): + b, t, d = inputs.size() + if mask is not None: + mask = torch.reshape(mask, (b, -1, 1)) + if mask_shfit_chunk is not None: + mask = mask * mask_shfit_chunk + inputs = inputs * mask + + x = inputs.transpose(1, 2) + x = self.pad_fn(x) + x = self.fsmn_block(x) + x = x.transpose(1, 2) + x += inputs + x = self.dropout(x) + if mask is not None: + x = x * mask + return x + + def forward_qkv(self, x): + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). + + """ + b, t, d = x.size() + q_k_v = self.linear_q_k_v(x) + q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1) + q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose( + 1, 2 + ) # (batch, head, time2, d_k) + v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose( + 1, 2 + ) # (batch, head, time2, d_k) + + return q_h, k_h, v_h, v + + def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None): + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). 
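+            mask_att_chunk_encoder (torch.Tensor, optional): Extra mask that is
+                multiplied element-wise with ``mask`` before the scores are masked.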
+ + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + if mask is not None: + if mask_att_chunk_encoder is not None: + mask = mask * mask_att_chunk_encoder + + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + + min_value = -float( + "inf" + ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min) + scores = scores.masked_fill(mask, min_value) + attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0 + ) # (batch, head, time1, time2) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None): + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + q_h, k_h, v_h, v = self.forward_qkv(x) + fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk) + q_h = q_h * self.d_k ** (-0.5) + scores = torch.matmul(q_h, k_h.transpose(-2, -1)) + att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) + return att_outs + fsmn_memory + + +class EncoderLayerSANM(nn.Module): + def __init__( + self, + in_size, + size, + self_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + stochastic_depth_rate=0.0, + ): + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(in_size) + self.norm2 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.in_size = in_size + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + self.stochastic_depth_rate = stochastic_depth_rate + self.dropout_rate = dropout_rate + + def forward( + self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None + ): + """Compute encoded features. + + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + + """ + skip_layer = False + # with stochastic depth, residual connection `x + f(x)` becomes + # `x <- x + 1 / (1 - p) * f(x)` at training time. 
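+        # At inference/export time self.training is False, so skip_layer stays
+        # False and stoch_layer_coeff stays 1.0, i.e. this block is a no-op.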
+ stoch_layer_coeff = 1.0 + if self.training and self.stochastic_depth_rate > 0: + skip_layer = torch.rand(1).item() < self.stochastic_depth_rate + stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) + + if skip_layer: + if cache is not None: + x = torch.cat([cache, x], dim=1) + return x, mask + + residual = x + if self.normalize_before: + x = self.norm1(x) + + if self.concat_after: + x_concat = torch.cat( + ( + x, + self.self_attn( + x, + mask, + mask_shfit_chunk=mask_shfit_chunk, + mask_att_chunk_encoder=mask_att_chunk_encoder, + ), + ), + dim=-1, + ) + if self.in_size == self.size: + x = residual + stoch_layer_coeff * self.concat_linear(x_concat) + else: + x = stoch_layer_coeff * self.concat_linear(x_concat) + else: + if self.in_size == self.size: + x = residual + stoch_layer_coeff * self.dropout( + self.self_attn( + x, + mask, + mask_shfit_chunk=mask_shfit_chunk, + mask_att_chunk_encoder=mask_att_chunk_encoder, + ) + ) + else: + x = stoch_layer_coeff * self.dropout( + self.self_attn( + x, + mask, + mask_shfit_chunk=mask_shfit_chunk, + mask_att_chunk_encoder=mask_att_chunk_encoder, + ) + ) + return x, mask + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder + + +class LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None): + if maxlen is None: + maxlen = lengths.max() + row_vector = torch.arange(0, maxlen, 1).to(lengths.device) + matrix = torch.unsqueeze(lengths, dim=-1) + mask = row_vector < matrix + mask = mask.detach() + + return mask.type(dtype).to(device) if device is not None else mask.type(dtype) + + +class SenseVoiceEncoderSmall(nn.Module): + def __init__(self): + super().__init__() + self.input_size = 80 * 7 + self.output_size = 512 + self.attention_heads = 4 + self.linear_units = 2048 + self.num_blocks = 50 + self.tp_blocks = 20 + self.input_layer = "pe" + self.pos_enc_class = "SinusoidalPositionEncoder" + self.normalize_before = True + self.kernel_size = 11 + self.sanm_shfit = 0 + self.concat_after = False + self.positionwise_layer_type = "linear" + self.positionwise_conv_kernel_size = 1 + self.padding_idx = -1 + self.selfattention_layer_type = "sanm" + self.dropout_rate = 0.1 + self.attention_dropout_rate = 0.1 + + self._output_size = self.output_size + + self.embed = SinusoidalPositionEncoder() + + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + self.output_size, + self.linear_units, + self.dropout_rate, + ) + + encoder_selfattn_layer = MultiHeadedAttentionSANM + encoder_selfattn_layer_args0 = ( + self.attention_heads, + self.input_size, + self.output_size, + self.attention_dropout_rate, + self.kernel_size, + self.sanm_shfit, + ) + encoder_selfattn_layer_args = ( + self.attention_heads, + self.output_size, + self.output_size, + self.attention_dropout_rate, + self.kernel_size, + self.sanm_shfit, + ) + + self.encoders0 = nn.ModuleList( + [ + EncoderLayerSANM( + self.input_size, + self.output_size, + 
encoder_selfattn_layer(*encoder_selfattn_layer_args0), + positionwise_layer(*positionwise_layer_args), + self.dropout_rate, + ) + for i in range(1) + ] + ) + + self.encoders = nn.ModuleList( + [ + EncoderLayerSANM( + self.output_size, + self.output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + self.dropout_rate, + ) + for i in range(self.num_blocks - 1) + ] + ) + + self.tp_encoders = nn.ModuleList( + [ + EncoderLayerSANM( + self.output_size, + self.output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + self.dropout_rate, + ) + for i in range(self.tp_blocks) + ] + ) + + self.after_norm = LayerNorm(self.output_size) + + self.tp_norm = LayerNorm(self.output_size) + + def forward( + self, + xs_pad: torch.Tensor, + ): + masks = None + + xs_pad *= self.output_size**0.5 + + xs_pad = self.embed(xs_pad) + + # forward encoder1 + for layer_idx, encoder_layer in enumerate(self.encoders0): + encoder_outs = encoder_layer(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + for layer_idx, encoder_layer in enumerate(self.encoders): + encoder_outs = encoder_layer(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + xs_pad = self.after_norm(xs_pad) + + for layer_idx, encoder_layer in enumerate(self.tp_encoders): + encoder_outs = encoder_layer(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + xs_pad = self.tp_norm(xs_pad) + return xs_pad + + +class CTC(nn.Module): + def __init__( + self, + odim: int, + encoder_output_size: int, + dropout_rate: float = 0.0, + ctc_type: str = "builtin", + reduce: bool = True, + ignore_nan_grad: bool = True, + extra_linear: bool = True, + ): + super().__init__() + eprojs = encoder_output_size + self.dropout_rate = dropout_rate + + if extra_linear: + self.ctc_lo = torch.nn.Linear(eprojs, odim) + else: + self.ctc_lo = None + + def softmax(self, hs_pad): + """softmax of frame activations + + Args: + Tensor hs_pad: 3d tensor (B, Tmax, eprojs) + Returns: + torch.Tensor: softmax applied 3d tensor (B, Tmax, odim) + """ + if self.ctc_lo is not None: + return F.softmax(self.ctc_lo(hs_pad), dim=2) + else: + return F.softmax(hs_pad, dim=2) + + def log_softmax(self, hs_pad): + """log_softmax of frame activations + + Args: + Tensor hs_pad: 3d tensor (B, Tmax, eprojs) + Returns: + torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim) + """ + if self.ctc_lo is not None: + return F.log_softmax(self.ctc_lo(hs_pad), dim=2) + else: + return F.log_softmax(hs_pad, dim=2) + + def argmax(self, hs_pad): + """argmax of frame activations + + Args: + torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) + Returns: + torch.Tensor: argmax applied 2d tensor (B, Tmax) + """ + if self.ctc_lo is not None: + return torch.argmax(self.ctc_lo(hs_pad), dim=2) + else: + return torch.argmax(hs_pad, dim=2) + + +class SenseVoiceSmall(nn.Module): + def __init__(self, neg_mean: torch.Tensor, inv_stddev: torch.Tensor): + super().__init__() + self.sos = 1 + self.eos = 2 + self.length_normalized_loss = True + self.ignore_id = -1 + self.blank_id = 0 + self.input_size = 80 * 7 + self.vocab_size = 25055 + + self.neg_mean = neg_mean.unsqueeze(0).unsqueeze(0) + self.inv_stddev = inv_stddev.unsqueeze(0).unsqueeze(0) + + self.lid_dict = { + "auto": 0, + "zh": 3, + "en": 4, + "yue": 7, + "ja": 11, + "ko": 12, + "nospeech": 13, + } + self.lid_int_dict = { + 24884: 3, + 24885: 4, + 24888: 7, + 24892: 11, + 24896: 12, + 24992: 13, + } + 
self.textnorm_dict = {"withitn": 14, "woitn": 15} + self.textnorm_int_dict = {25016: 14, 25017: 15} + + self.emo_dict = { + "unk": 25009, + "happy": 25001, + "sad": 25002, + "angry": 25003, + "neutral": 25004, + } + + self.encoder = SenseVoiceEncoderSmall() + self.ctc = CTC( + odim=self.vocab_size, + encoder_output_size=self.encoder.output_size, + ) + self.embed = torch.nn.Embedding( + 7 + len(self.lid_dict) + len(self.textnorm_dict), self.input_size + ) + + def forward(self, x, prompt): + input_query = self.embed(prompt).unsqueeze(0) + + # for export, we always assume x and self.neg_mean are on CPU + x = (x + self.neg_mean) * self.inv_stddev + x = torch.cat((input_query, x), dim=1) + + encoder_out = self.encoder(x) + logits = self.ctc.ctc_lo(encoder_out) + + return logits diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 96bd6396..fc80ac9b 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -173,6 +173,8 @@ list(APPEND sources ) if(SHERPA_ONNX_ENABLE_RKNN) list(APPEND sources + ./rknn/offline-ctc-greedy-search-decoder-rknn.cc + ./rknn/offline-sense-voice-model-rknn.cc ./rknn/online-stream-rknn.cc ./rknn/online-transducer-greedy-search-decoder-rknn.cc ./rknn/online-transducer-modified-beam-search-decoder-rknn.cc diff --git a/sherpa-onnx/csrc/offline-model-config.cc b/sherpa-onnx/csrc/offline-model-config.cc index 493309fc..68f98dd2 100644 --- a/sherpa-onnx/csrc/offline-model-config.cc +++ b/sherpa-onnx/csrc/offline-model-config.cc @@ -7,6 +7,7 @@ #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { @@ -57,9 +58,37 @@ void OfflineModelConfig::Register(ParseOptions *po) { } bool OfflineModelConfig::Validate() const { - if (num_threads < 1) { - SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); - return false; + // For RK NPU, we reinterpret num_threads: + // + // For RK3588 only + // num_threads == 1 -> Select a core randomly + // num_threads == 0 -> Use NPU core 0 + // num_threads == -1 -> Use NPU core 1 + // num_threads == -2 -> Use NPU core 2 + // num_threads == -3 -> Use NPU core 0 and core 1 + // num_threads == -4 -> Use NPU core 0, core 1, and core 2 + if (provider != "rknn") { + if (num_threads < 1) { + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); + return false; + } + if (!sense_voice.model.empty() && (EndsWith(sense_voice.model, ".rknn"))) { + SHERPA_ONNX_LOGE( + "--provider is %s, which is not rknn, but you pass a rknn model " + "filename. model: '%s'", + provider.c_str(), sense_voice.model.c_str()); + return false; + } + } + + if (provider == "rknn") { + if (!sense_voice.model.empty() && (EndsWith(sense_voice.model, ".onnx"))) { + SHERPA_ONNX_LOGE( + "--provider is rknn, but you pass an onnx model " + "filename. 
model: '%s'", + sense_voice.model.c_str()); + return false; + } } if (!FileExists(tokens)) { diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index 837bcc89..11e7c590 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -35,10 +35,32 @@ #include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h" #include "sherpa-onnx/csrc/text-utils.h" +#if SHERPA_ONNX_ENABLE_RKNN +#include "sherpa-onnx/csrc/rknn/offline-recognizer-sense-voice-rknn-impl.h" +#endif + namespace sherpa_onnx { std::unique_ptr OfflineRecognizerImpl::Create( const OfflineRecognizerConfig &config) { + if (config.model_config.provider == "rknn") { +#if SHERPA_ONNX_ENABLE_RKNN + if (config.model_config.sense_voice.model.empty()) { + SHERPA_ONNX_LOGE( + "Only SenseVoice models are currently supported " + "by rknn for non-streaming ASR. Fallback to CPU"); + } else if (!config.model_config.sense_voice.model.empty()) { + return std::make_unique(config); + } +#else + SHERPA_ONNX_LOGE( + "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you " + "want to use rknn."); + SHERPA_ONNX_EXIT(-1); + return nullptr; +#endif + } + if (!config.model_config.sense_voice.model.empty()) { return std::make_unique(config); } @@ -229,6 +251,24 @@ std::unique_ptr OfflineRecognizerImpl::Create( template std::unique_ptr OfflineRecognizerImpl::Create( Manager *mgr, const OfflineRecognizerConfig &config) { + if (config.model_config.provider == "rknn") { +#if SHERPA_ONNX_ENABLE_RKNN + if (config.model_config.sense_voice.model.empty()) { + SHERPA_ONNX_LOGE( + "Only SenseVoice models are currently supported " + "by rknn for non-streaming ASR. Fallback to CPU"); + } else if (!config.model_config.sense_voice.model.empty()) { + return std::make_unique(mgr, config); + } +#else + SHERPA_ONNX_LOGE( + "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you " + "want to use rknn."); + SHERPA_ONNX_EXIT(-1); + return nullptr; +#endif + } + if (!config.model_config.sense_voice.model.empty()) { return std::make_unique(mgr, config); } diff --git a/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h b/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h index 82266cd7..c703f129 100644 --- a/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-sense-voice-impl.h @@ -11,6 +11,7 @@ #include #include +#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h" #include "sherpa-onnx/csrc/offline-model-config.h" #include "sherpa-onnx/csrc/offline-recognizer-impl.h" @@ -21,7 +22,7 @@ namespace sherpa_onnx { -static OfflineRecognitionResult ConvertSenseVoiceResult( +OfflineRecognitionResult ConvertSenseVoiceResult( const OfflineCtcDecoderResult &src, const SymbolTable &sym_table, int32_t frame_shift_ms, int32_t subsampling_factor) { OfflineRecognitionResult r; @@ -72,7 +73,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl { } else { SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", config.decoding_method.c_str()); - exit(-1); + SHERPA_ONNX_EXIT(-1); } InitFeatConfig(); @@ -93,7 +94,7 @@ class OfflineRecognizerSenseVoiceImpl : public OfflineRecognizerImpl { } else { SHERPA_ONNX_LOGE("Only greedy_search is supported at present. 
Given %s", config.decoding_method.c_str()); - exit(-1); + SHERPA_ONNX_EXIT(-1); } InitFeatConfig(); diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index 011751c3..c558cb4e 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -37,7 +37,6 @@ std::unique_ptr OnlineRecognizerImpl::Create( const OnlineRecognizerConfig &config) { if (config.model_config.provider_config.provider == "rknn") { #if SHERPA_ONNX_ENABLE_RKNN - // Currently, only zipformer v1 is suported for rknn if (config.model_config.transducer.encoder.empty() && config.model_config.zipformer2_ctc.model.empty()) { SHERPA_ONNX_LOGE( diff --git a/sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.cc b/sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.cc new file mode 100644 index 00000000..3f5b3252 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.cc @@ -0,0 +1,38 @@ +// sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +OfflineCtcDecoderResult OfflineCtcGreedySearchDecoderRknn::Decode( + const float *logits, int32_t num_frames, int32_t vocab_size) { + OfflineCtcDecoderResult ans; + + int64_t prev_id = -1; + + for (int32_t t = 0; t != num_frames; ++t) { + auto y = static_cast(std::distance( + static_cast(logits), + std::max_element(static_cast(logits), + static_cast(logits) + vocab_size))); + logits += vocab_size; + + if (y != blank_id_ && y != prev_id) { + ans.tokens.push_back(y); + ans.timestamps.push_back(t); + } + prev_id = y; + } // for (int32_t t = 0; ...) 
+ + return ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h b/sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h new file mode 100644 index 00000000..eda99040 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_ +#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_ + +#include + +#include "sherpa-onnx/csrc/offline-ctc-decoder.h" + +namespace sherpa_onnx { + +class OfflineCtcGreedySearchDecoderRknn { + public: + explicit OfflineCtcGreedySearchDecoderRknn(int32_t blank_id) + : blank_id_(blank_id) {} + + OfflineCtcDecoderResult Decode(const float *logits, int32_t num_frames, + int32_t vocab_size); + + private: + int32_t blank_id_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RKNN_OFFLINE_CTC_GREEDY_SEARCH_DECODER_RKNN_H_ diff --git a/sherpa-onnx/csrc/rknn/offline-recognizer-sense-voice-rknn-impl.h b/sherpa-onnx/csrc/rknn/offline-recognizer-sense-voice-rknn-impl.h new file mode 100644 index 00000000..d6f2f208 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/offline-recognizer-sense-voice-rknn-impl.h @@ -0,0 +1,138 @@ +// sherpa-onnx/csrc/offline-recognizer-sense-voice-rknn-impl.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_ +#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_ + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-model-config.h" +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" +#include "sherpa-onnx/csrc/offline-recognizer.h" +#include "sherpa-onnx/csrc/rknn/offline-ctc-greedy-search-decoder-rknn.h" +#include "sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h" +#include "sherpa-onnx/csrc/symbol-table.h" + +namespace sherpa_onnx { + +// defined in ../online-recognizer-sense-voice-impl.h +OfflineRecognitionResult ConvertSenseVoiceResult( + const OfflineCtcDecoderResult &src, const SymbolTable &sym_table, + int32_t frame_shift_ms, int32_t subsampling_factor); + +class OfflineRecognizerSenseVoiceRknnImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerSenseVoiceRknnImpl( + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(config), + config_(config), + symbol_table_(config_.model_config.tokens), + model_( + std::make_unique(config.model_config)) { + const auto &meta_data = model_->GetModelMetadata(); + if (config.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + meta_data.blank_id); + } else { + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", + config.decoding_method.c_str()); + SHERPA_ONNX_EXIT(-1); + } + + InitFeatConfig(); + } + + template + OfflineRecognizerSenseVoiceRknnImpl(Manager *mgr, + const OfflineRecognizerConfig &config) + : OfflineRecognizerImpl(mgr, config), + config_(config), + symbol_table_(mgr, config_.model_config.tokens), + model_(std::make_unique( + mgr, config.model_config)) { + const auto &meta_data = model_->GetModelMetadata(); + if (config.decoding_method == "greedy_search") { + decoder_ = std::make_unique( + meta_data.blank_id); + } else { + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. 
Given %s", + config.decoding_method.c_str()); + SHERPA_ONNX_EXIT(-1); + } + + InitFeatConfig(); + } + + std::unique_ptr CreateStream() const override { + return std::make_unique(config_.feat_config); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + for (int32_t i = 0; i < n; ++i) { + DecodeOneStream(ss[i]); + } + } + + OfflineRecognizerConfig GetConfig() const override { return config_; } + + private: + void InitFeatConfig() { + const auto &meta_data = model_->GetModelMetadata(); + + config_.feat_config.normalize_samples = meta_data.normalize_samples; + config_.feat_config.window_type = "hamming"; + config_.feat_config.high_freq = 0; + config_.feat_config.snip_edges = true; + } + + void DecodeOneStream(OfflineStream *s) const { + const auto &meta_data = model_->GetModelMetadata(); + + std::vector f = s->GetFrames(); + + int32_t language = 0; + if (config_.model_config.sense_voice.language.empty()) { + language = 0; + } else if (meta_data.lang2id.count( + config_.model_config.sense_voice.language)) { + language = + meta_data.lang2id.at(config_.model_config.sense_voice.language); + } else { + SHERPA_ONNX_LOGE("Unknown language: %s. Use 0 instead.", + config_.model_config.sense_voice.language.c_str()); + } + + int32_t text_norm = config_.model_config.sense_voice.use_itn + ? meta_data.with_itn_id + : meta_data.without_itn_id; + + std::vector logits = model_->Run(std::move(f), language, text_norm); + int32_t num_out_frames = logits.size() / meta_data.vocab_size; + + auto result = + decoder_->Decode(logits.data(), num_out_frames, meta_data.vocab_size); + + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = meta_data.window_shift; + auto r = ConvertSenseVoiceResult(result, symbol_table_, frame_shift_ms, + subsampling_factor); + + r.text = ApplyInverseTextNormalization(std::move(r.text)); + r.text = ApplyHomophoneReplacer(std::move(r.text)); + s->SetResult(r); + } + + private: + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RKNN_OFFLINE_RECOGNIZER_SENSE_VOICE_RKNN_IMPL_H_ diff --git a/sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.cc b/sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.cc new file mode 100644 index 00000000..e879c0bd --- /dev/null +++ b/sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.cc @@ -0,0 +1,244 @@ +// sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h" + +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/rknn/macros.h" +#include "sherpa-onnx/csrc/rknn/utils.h" + +namespace sherpa_onnx { + +class OfflineSenseVoiceModelRknn::Impl { + public: + ~Impl() { + auto ret = rknn_destroy(ctx_); + if (ret != RKNN_SUCC) { + SHERPA_ONNX_LOGE("Failed to destroy the context"); + } + } + + explicit Impl(const OfflineModelConfig &config) : config_(config) { + { + auto buf = ReadFile(config_.sense_voice.model); + Init(buf.data(), buf.size()); + } + + SetCoreMask(ctx_, config_.num_threads); + } + + template + Impl(Manager *mgr, const OfflineModelConfig &config) : config_(config) { + { + auto buf = ReadFile(mgr, config_.sense_voice.model); + Init(buf.data(), 
buf.size()); + } + + SetCoreMask(ctx_, config_.num_threads); + } + + const OfflineSenseVoiceModelMetaData &GetModelMetadata() const { + return meta_data_; + } + + std::vector Run(std::vector features, int32_t language, + int32_t text_norm) { + features = ApplyLFR(std::move(features)); + + std::vector inputs(input_attrs_.size()); + + std::array prompt{language, 1, 2, text_norm}; + + inputs[0].index = input_attrs_[0].index; + inputs[0].type = RKNN_TENSOR_FLOAT32; + inputs[0].fmt = input_attrs_[0].fmt; + inputs[0].buf = reinterpret_cast(features.data()); + inputs[0].size = features.size() * sizeof(float); + + inputs[1].index = input_attrs_[1].index; + inputs[1].type = RKNN_TENSOR_INT32; + inputs[1].fmt = input_attrs_[1].fmt; + inputs[1].buf = reinterpret_cast(prompt.data()); + inputs[1].size = prompt.size() * sizeof(int32_t); + + std::vector out(output_attrs_[0].n_elems); + + std::vector outputs(output_attrs_.size()); + outputs[0].index = output_attrs_[0].index; + outputs[0].is_prealloc = 1; + outputs[0].want_float = 1; + outputs[0].size = out.size() * sizeof(float); + outputs[0].buf = reinterpret_cast(out.data()); + + rknn_context ctx = 0; + auto ret = rknn_dup_context(&ctx_, &ctx); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to duplicate the ctx"); + + ret = rknn_inputs_set(ctx, inputs.size(), inputs.data()); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs"); + + ret = rknn_run(ctx, nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model"); + + ret = rknn_outputs_get(ctx, outputs.size(), outputs.data(), nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output"); + + rknn_destroy(ctx); + + return out; + } + + private: + void Init(void *model_data, size_t model_data_length) { + InitContext(model_data, model_data_length, config_.debug, &ctx_); + + InitInputOutputAttrs(ctx_, config_.debug, &input_attrs_, &output_attrs_); + + rknn_custom_string custom_string = GetCustomString(ctx_, config_.debug); + + auto meta = Parse(custom_string, config_.debug); + +#define SHERPA_ONNX_RKNN_READ_META_DATA_INT(dst, src_key) \ + do { \ + if (!meta.count(#src_key)) { \ + SHERPA_ONNX_LOGE("'%s' does not exist in the custom_string", #src_key); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + \ + dst = atoi(meta.at(#src_key).c_str()); \ + } while (0) + + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.with_itn_id, with_itn); + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.without_itn_id, without_itn); + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.window_size, + lfr_window_size); + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.window_shift, + lfr_window_shift); + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.vocab_size, vocab_size); + SHERPA_ONNX_RKNN_READ_META_DATA_INT(meta_data_.normalize_samples, + normalize_samples); + + int32_t lang_auto = 0; + int32_t lang_zh = 0; + int32_t lang_en = 0; + int32_t lang_ja = 0; + int32_t lang_ko = 0; + int32_t lang_yue = 0; + + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_auto, lang_auto); + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_zh, lang_zh); + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_en, lang_en); + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_ja, lang_ja); + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_ko, lang_ko); + SHERPA_ONNX_RKNN_READ_META_DATA_INT(lang_yue, lang_yue); + + meta_data_.lang2id = { + {"auto", lang_auto}, {"zh", lang_zh}, {"en", lang_en}, + {"ja", lang_ja}, {"ko", lang_ko}, {"yue", lang_yue}, + }; + + // for rknn models, neg_mean and inv_stddev are stored inside the model + +#undef SHERPA_ONNX_RKNN_READ_META_DATA_INT + + num_input_frames_ = 
input_attrs_[0].dims[1]; + } + + std::vector ApplyLFR(std::vector in) const { + int32_t lfr_window_size = meta_data_.window_size; + int32_t lfr_window_shift = meta_data_.window_shift; + int32_t in_feat_dim = 80; + + int32_t in_num_frames = in.size() / in_feat_dim; + int32_t out_num_frames = + (in_num_frames - lfr_window_size) / lfr_window_shift + 1; + + if (out_num_frames > num_input_frames_) { + SHERPA_ONNX_LOGE( + "Number of input frames %d is too large. Truncate it to %d frames.", + out_num_frames, num_input_frames_); + + SHERPA_ONNX_LOGE( + "Recognition result may be truncated/incomplete. Please select a " + "model accepting longer audios."); + + out_num_frames = num_input_frames_; + } + + int32_t out_feat_dim = in_feat_dim * lfr_window_size; + + std::vector out(num_input_frames_ * out_feat_dim); + + const float *p_in = in.data(); + float *p_out = out.data(); + + for (int32_t i = 0; i != out_num_frames; ++i) { + std::copy(p_in, p_in + out_feat_dim, p_out); + + p_out += out_feat_dim; + p_in += lfr_window_shift * in_feat_dim; + } + + return out; + } + + private: + OfflineModelConfig config_; + + rknn_context ctx_ = 0; + + std::vector input_attrs_; + std::vector output_attrs_; + + OfflineSenseVoiceModelMetaData meta_data_; + int32_t num_input_frames_ = -1; +}; + +OfflineSenseVoiceModelRknn::~OfflineSenseVoiceModelRknn() = default; + +OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn( + const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn( + Manager *mgr, const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +std::vector OfflineSenseVoiceModelRknn::Run(std::vector features, + int32_t language, + int32_t text_norm) const { + return impl_->Run(std::move(features), language, text_norm); +} + +const OfflineSenseVoiceModelMetaData & +OfflineSenseVoiceModelRknn::GetModelMetadata() const { + return impl_->GetModelMetadata(); +} + +#if __ANDROID_API__ >= 9 +template OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn( + AAssetManager *mgr, const OfflineModelConfig &config); +#endif + +#if __OHOS__ +template OfflineSenseVoiceModelRknn::OfflineSenseVoiceModelRknn( + NativeResourceManager *mgr, const OfflineModelConfig &config); +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h b/sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h new file mode 100644 index 00000000..ddbc86b1 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h @@ -0,0 +1,45 @@ +// sherpa-onnx/csrc/rknn/offline-sense-voice-model-rknn.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_ +#define SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_ + +#include +#include +#include + +#include "rknn_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-model-config.h" +#include "sherpa-onnx/csrc/offline-sense-voice-model-meta-data.h" + +namespace sherpa_onnx { + +class OfflineSenseVoiceModelRknn { + public: + ~OfflineSenseVoiceModelRknn(); + + explicit OfflineSenseVoiceModelRknn(const OfflineModelConfig &config); + + template + OfflineSenseVoiceModelRknn(Manager *mgr, const OfflineModelConfig &config); + + /** + * @param features A tensor of shape (num_frames, feature_dim) + * before applying LFR. 
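+   *                 Each frame is an 80-dim fbank feature; Run() applies LFR
+   *                 internally and zero-pads or truncates the result to the
+   *                 fixed number of input frames baked into the rknn model.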
+ * @param language + * @param text_norm + * @returns Return a tensor of shape (num_output_frames, vocab_size) + */ + std::vector Run(std::vector features, int32_t language, + int32_t text_norm) const; + + const OfflineSenseVoiceModelMetaData &GetModelMetadata() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RKNN_OFFLINE_SENSE_VOICE_MODEL_RKNN_H_ diff --git a/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc b/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc index adfce0c9..bd3da85e 100644 --- a/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc +++ b/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc @@ -55,8 +55,6 @@ class OnlineZipformerCtcModelRknn::Impl { SetCoreMask(ctx_, config_.num_threads); } - // TODO(fangjun): Support Android - std::vector> GetInitStates() const { // input_attrs_[0] is for the feature // input_attrs_[1:] is for states diff --git a/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc b/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc index 5e2fdf5a..4d34cafb 100644 --- a/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc +++ b/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc @@ -89,8 +89,6 @@ class OnlineZipformerTransducerModelRknn::Impl { SetCoreMask(joiner_ctx_, config_.num_threads); } - // TODO(fangjun): Support Android - std::vector> GetEncoderInitStates() const { // encoder_input_attrs_[0] is for the feature // encoder_input_attrs_[1:] is for states
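
Usage sketch (illustrative only, not added by this patch): once a SenseVoice .rknn model has been exported with the scripts above, the new backend is reached through the existing OfflineRecognizer C++ API by setting provider to "rknn". The model and token paths below are placeholders; the config fields are the ones already defined in OfflineRecognizerConfig, OfflineModelConfig, and OfflineSenseVoiceModelConfig in this repository.

// hypothetical example, not part of this patch
#include <vector>

#include "sherpa-onnx/csrc/offline-recognizer.h"

int main() {
  sherpa_onnx::OfflineRecognizerConfig config;
  config.model_config.provider = "rknn";  // routes to OfflineRecognizerSenseVoiceRknnImpl
  config.model_config.sense_voice.model = "./model.rknn";  // placeholder path
  config.model_config.sense_voice.language = "auto";
  config.model_config.sense_voice.use_itn = true;
  config.model_config.tokens = "./tokens.txt";  // placeholder path
  config.model_config.num_threads = 1;  // see the rk3588 core-mask mapping above
  config.decoding_method = "greedy_search";

  sherpa_onnx::OfflineRecognizer recognizer(config);

  // 16 kHz mono samples in [-1, 1]; one second of silence as a stand-in.
  std::vector<float> samples(16000, 0.0f);

  auto s = recognizer.CreateStream();
  s->AcceptWaveform(16000, samples.data(), static_cast<int32_t>(samples.size()));
  recognizer.DecodeStream(s.get());

  const auto &result = s->GetResult();
  (void)result.text;  // recognized text
  return 0;
}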