Add C++ and Python support for T-one streaming Russian ASR models (#2575)
This PR adds support for T-one streaming Russian ASR models in both the C++ and Python APIs. The T-one model is a CTC-based Russian speech recognition model with specific characteristics, including float16 state handling, 300 ms frame lengths, and an 8 kHz sampling rate.

- Added a new OnlineToneCtcModel implementation with specialized processing for T-one models
- Integrated T-one support into the existing CTC model pipeline and the Python bindings
- Added a Python example and test scripts for the new functionality
This commit is contained in:
parent e4f48ce6a6
commit 858b5052a2
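In practice the new functionality boils down to one new Python entry point. A minimal sketch of decoding with it (the paths are the ones used by the bundled example below; this mirrors python-api-examples/online-t-one-ctc-decode-files.py rather than adding anything new):

    import sherpa_onnx

    recognizer = sherpa_onnx.OnlineRecognizer.from_t_one_ctc(
        model="./sherpa-onnx-streaming-t-one-russian-2025-09-08/model.onnx",
        tokens="./sherpa-onnx-streaming-t-one-russian-2025-09-08/tokens.txt",
    )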
10 .github/scripts/test-python.sh vendored
@@ -8,6 +8,16 @@ log() {
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "test T-one"

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
tar xvf sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
rm sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2

python3 ./python-api-examples/online-t-one-ctc-decode-files.py

rm -rf sherpa-onnx-streaming-t-one-russian-2025-09-08

log "test nemo canary"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
tar xvf sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8.tar.bz2
1 .gitignore vendored
@@ -149,3 +149,4 @@ kitten-nano-en-v0_1-fp16
*.egg-info
*.jar
vocab.json
*.so
@@ -2,7 +2,8 @@
// Copyright (c) 2025 Xiaomi Corporation

// To use punctuation model:
// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
// wget
// https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
// tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
// rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2

@@ -15,14 +16,17 @@ int32_t main() {
  using namespace sherpa_onnx::cxx;  // NOLINT

  OfflinePunctuationConfig punctuation_config;
  punctuation_config.model.ct_transformer = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx";
  punctuation_config.model.ct_transformer =
      "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/"
      "model.onnx";
  punctuation_config.model.num_threads = 1;
  punctuation_config.model.debug = false;
  punctuation_config.model.provider = "cpu";

  OfflinePunctuation punct = OfflinePunctuation::Create(punctuation_config);
  if (!punct.Get()) {
    std::cerr << "Failed to create punctuation model. Please check your config\n";
    std::cerr
        << "Failed to create punctuation model. Please check your config\n";
    return -1;
  }
75 python-api-examples/online-t-one-ctc-decode-files.py Executable file
@@ -0,0 +1,75 @@
#!/usr/bin/env python3

"""
This file shows how to use a streaming CTC model from T-one
to decode files.

Please download model files from
https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models

The example model is converted from
https://github.com/voicekit-team/T-one
using
https://github.com/k2-fsa/sherpa-onnx/tree/master/scripts/t-one

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
tar xvf sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
rm sherpa-onnx-streaming-t-one-russian-2025-09-08.tar.bz2
"""

from pathlib import Path

import numpy as np
import sherpa_onnx
import soundfile as sf


def create_recognizer():
    model = "./sherpa-onnx-streaming-t-one-russian-2025-09-08/model.onnx"
    tokens = "./sherpa-onnx-streaming-t-one-russian-2025-09-08/tokens.txt"
    test_wav = "./sherpa-onnx-streaming-t-one-russian-2025-09-08/0.wav"

    if not Path(model).is_file() or not Path(test_wav).is_file():
        raise ValueError(
            """Please download model files from
            https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models
            """
        )
    return (
        sherpa_onnx.OnlineRecognizer.from_t_one_ctc(
            model=model,
            tokens=tokens,
            debug=True,
        ),
        test_wav,
    )


def main():
    recognizer, wave_filename = create_recognizer()

    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel

    # audio is a 1-D float32 numpy array normalized to the range [-1, 1]
    # sample_rate does not need to be 8000 Hz

    stream = recognizer.create_stream()
    left_paddings = np.zeros(int(0.3 * sample_rate), dtype=np.float32)
    stream.accept_waveform(sample_rate, left_paddings)

    stream.accept_waveform(sample_rate, audio)

    tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
    stream.accept_waveform(sample_rate, tail_paddings)
    stream.input_finished()

    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)
    print(wave_filename)
    print(recognizer.get_result_all(stream))


if __name__ == "__main__":
    main()
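The example above pushes the whole file into the stream at once; the same API also works incrementally. A sketch of a chunked variant of main() (chunk size is arbitrary; all calls are the ones used above):

    # inside main(), after stream = recognizer.create_stream():
    chunk_size = int(0.3 * sample_rate)  # feed 300 ms at a time
    for start in range(0, len(audio), chunk_size):
        stream.accept_waveform(sample_rate, audio[start : start + chunk_size])
        while recognizer.is_ready(stream):
            recognizer.decode_stream(stream)

    stream.input_finished()
    while recognizer.is_ready(stream):
        recognizer.decode_stream(stream)
    print(recognizer.get_result_all(stream))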
@@ -147,14 +147,13 @@ def main():
    sample_rate = model.sample_rate

    # Pad 0.5 seconds
    samples = np.pad(samples, (0, 4000))
    samples = np.pad(samples, (2400, 2400))

    features = compute_feat(
        samples=samples,
        sample_rate=sample_rate,
        frame_length_ms=model.frame_length_ms,
    )
    print(features.shape)

    id2token = load_tokens(args.tokens)
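For reference, the arithmetic behind the new padding: at the model's 8 kHz sample rate, 2400 samples is 0.3 s on each side, while the old (0, 4000) tail was the 0.5 s that the untouched comment still describes:

    sample_rate = 8000
    print(2400 / sample_rate)  # 0.3 s of padding on each side
    print(4000 / sample_rate)  # 0.5 s, the old tail-only padding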
@@ -95,6 +95,8 @@ set(sources
  online-recognizer.cc
  online-rnn-lm.cc
  online-stream.cc
  online-t-one-ctc-model-config.cc
  online-t-one-ctc-model.cc
  online-transducer-decoder.cc
  online-transducer-greedy-search-decoder.cc
  online-transducer-greedy-search-nemo-decoder.cc
@@ -7,8 +7,10 @@
#include <algorithm>
#include <functional>
#include <numeric>
#include <sstream>
#include <utility>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"

namespace sherpa_onnx {
@@ -27,10 +29,12 @@ static bool Compare(const std::vector<int64_t> &a,
}

static void PrintShape(const std::vector<int64_t> &a) {
  std::ostringstream os;
  for (auto i : a) {
    fprintf(stderr, "%d ", static_cast<int32_t>(i));
    os << i << " ";
  }
  fprintf(stderr, "\n");
  os << "\n";
  SHERPA_ONNX_LOGE("%s", os.str().c_str());
}

template <typename T /*=float*/>
@@ -51,15 +55,15 @@ Ort::Value Cat(OrtAllocator *allocator,

    bool ret = Compare(v0_shape, s, dim);
    if (!ret) {
      fprintf(stderr, "Incorrect shape in Cat !\n");
      SHERPA_ONNX_LOGE("Incorrect shape in Cat !\n");

      fprintf(stderr, "Shape for tensor 0: ");
      SHERPA_ONNX_LOGE("Shape for tensor 0: ");
      PrintShape(v0_shape);

      fprintf(stderr, "Shape for tensor %d: ", i);
      SHERPA_ONNX_LOGE("Shape for tensor %d: ", i);
      PrintShape(s);

      exit(-1);
      SHERPA_ONNX_EXIT(-1);
    }
  }

@@ -99,8 +103,77 @@ template Ort::Value Cat<float>(OrtAllocator *allocator,
                               const std::vector<const Ort::Value *> &values,
                               int32_t dim);

template Ort::Value Cat<uint16_t>(OrtAllocator *allocator,
                                  const std::vector<const Ort::Value *> &values,
                                  int32_t dim);

template Ort::Value Cat<int64_t>(OrtAllocator *allocator,
                                 const std::vector<const Ort::Value *> &values,
                                 int32_t dim);

Ort::Value CatFloat16(OrtAllocator *allocator,
                      const std::vector<const Ort::Value *> &values,
                      int32_t dim) {
  if (values.size() == 1u) {
    return Clone(allocator, values[0]);
  }

  std::vector<int64_t> v0_shape =
      values[0]->GetTensorTypeAndShapeInfo().GetShape();

  int64_t total_dim = v0_shape[dim];

  for (int32_t i = 1; i != static_cast<int32_t>(values.size()); ++i) {
    auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape();
    total_dim += s[dim];

    bool ret = Compare(v0_shape, s, dim);
    if (!ret) {
      SHERPA_ONNX_LOGE("Incorrect shape in Cat !\n");

      SHERPA_ONNX_LOGE("Shape for tensor 0: ");
      PrintShape(v0_shape);

      SHERPA_ONNX_LOGE("Shape for tensor %d: ", i);
      PrintShape(s);

      SHERPA_ONNX_EXIT(-1);
    }
  }

  std::vector<int64_t> ans_shape;
  ans_shape.reserve(v0_shape.size());
  ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
  ans_shape.push_back(total_dim);
  ans_shape.insert(ans_shape.end(), v0_shape.data() + dim + 1,
                   v0_shape.data() + v0_shape.size());

  auto leading_size = static_cast<int32_t>(std::accumulate(
      v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));

  auto trailing_size = static_cast<int32_t>(
      std::accumulate(v0_shape.begin() + dim + 1, v0_shape.end(), 1,
                      std::multiplies<int64_t>()));

  Ort::Value ans =
      Ort::Value::CreateTensor(allocator, ans_shape.data(), ans_shape.size(),
                               ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
  using T = uint16_t;

  T *dst = ans.GetTensorMutableData<T>();

  for (int32_t i = 0; i != leading_size; ++i) {
    for (auto value : values) {
      auto this_dim = value->GetTensorTypeAndShapeInfo().GetShape()[dim];
      const T *src = value->GetTensorData<T>();
      src += i * this_dim * trailing_size;

      std::copy(src, src + this_dim * trailing_size, dst);
      dst += this_dim * trailing_size;
    }
  }

  return ans;
}

}  // namespace sherpa_onnx
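CatFloat16 cannot go through the Cat<T> template because there is no native C++ float16 type here, so the tensor payload is moved as raw uint16 words; the loop is the generic concatenate-along-a-dimension pattern (leading x dim x trailing). An illustrative numpy sketch of the same logic (not part of the commit):

    import numpy as np

    def cat_float16(values, dim):
        # Treat float16 payloads as raw uint16 words, as CatFloat16 does.
        v0 = values[0]
        leading = int(np.prod(v0.shape[:dim], dtype=np.int64))
        trailing = int(np.prod(v0.shape[dim + 1:], dtype=np.int64))
        total = sum(v.shape[dim] for v in values)
        out_shape = v0.shape[:dim] + (total,) + v0.shape[dim + 1:]
        dst = np.empty(int(np.prod(out_shape, dtype=np.int64)), dtype=np.uint16)
        pos = 0
        for i in range(leading):  # copy one slab per tensor, per leading index
            for v in values:
                block = v.shape[dim] * trailing
                src = v.view(np.uint16).reshape(-1)
                dst[pos:pos + block] = src[i * block:(i + 1) * block]
                pos += block
        return dst.reshape(out_shape).view(np.float16)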
@@ -23,6 +23,10 @@ template <typename T = float>
Ort::Value Cat(OrtAllocator *allocator,
               const std::vector<const Ort::Value *> &values, int32_t dim);

Ort::Value CatFloat16(OrtAllocator *allocator,
                      const std::vector<const Ort::Value *> &values,
                      int32_t dim);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_CAT_H_
@@ -62,6 +62,8 @@ class FeatureExtractor::Impl {
      InitMfcc();
    } else if (config_.is_whisper) {
      InitWhisper();
    } else if (config_.is_t_one) {
      InitRawAudioSamples();
    } else {
      InitFbank();
    }
@@ -135,6 +137,9 @@ class FeatureExtractor::Impl {
    } else if (whisper_fbank_) {
      whisper_fbank_->InputFinished();
      return;
    } else if (raw_audio_) {
      raw_audio_->InputFinished();
      return;
    } else if (mfcc_) {
      mfcc_->InputFinished();
      return;
@@ -149,6 +154,8 @@ class FeatureExtractor::Impl {
      return fbank_->NumFramesReady();
    } else if (whisper_fbank_) {
      return whisper_fbank_->NumFramesReady();
    } else if (raw_audio_) {
      return raw_audio_->NumFramesReady();
    } else if (mfcc_) {
      return mfcc_->NumFramesReady();
    }
@@ -163,6 +170,8 @@ class FeatureExtractor::Impl {
      return fbank_->IsLastFrame(frame);
    } else if (whisper_fbank_) {
      return whisper_fbank_->IsLastFrame(frame);
    } else if (raw_audio_) {
      return raw_audio_->IsLastFrame(frame);
    } else if (mfcc_) {
      return mfcc_->IsLastFrame(frame);
    }
@@ -209,6 +218,8 @@ class FeatureExtractor::Impl {
      return opts_.mel_opts.num_bins;
    } else if (mfcc_) {
      return mfcc_opts_.num_ceps;
    } else if (raw_audio_) {
      return raw_audio_->Dim();
    }

    SHERPA_ONNX_LOGE("unreachable code");
@@ -225,6 +236,9 @@ class FeatureExtractor::Impl {
    } else if (whisper_fbank_) {
      whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n);
      return;
    } else if (raw_audio_) {
      raw_audio_->AcceptWaveform(sampling_rate, waveform, n);
      return;
    } else if (mfcc_) {
      mfcc_->AcceptWaveform(sampling_rate, waveform, n);
      return;
@@ -239,6 +253,8 @@ class FeatureExtractor::Impl {
      return fbank_->GetFrame(frame_index);
    } else if (whisper_fbank_) {
      return whisper_fbank_->GetFrame(frame_index);
    } else if (raw_audio_) {
      return raw_audio_->GetFrame(frame_index);
    } else if (mfcc_) {
      return mfcc_->GetFrame(frame_index);
    }
@@ -255,6 +271,9 @@ class FeatureExtractor::Impl {
    } else if (whisper_fbank_) {
      whisper_fbank_->Pop(discard_num);
      return;
    } else if (raw_audio_) {
      raw_audio_->Pop(discard_num);
      return;
    } else if (mfcc_) {
      mfcc_->Pop(discard_num);
      return;
@@ -322,11 +341,21 @@ class FeatureExtractor::Impl {
      config_.sampling_rate = opts_.frame_opts.samp_freq;
    }

  void InitRawAudioSamples() {
    opts_raw_audio_.frame_opts.samp_freq = config_.sampling_rate;
    opts_raw_audio_.frame_opts.frame_length_ms = config_.frame_length_ms;
    opts_raw_audio_.frame_opts.frame_shift_ms = config_.frame_shift_ms;

    raw_audio_ = std::make_unique<knf::OnlineRawAudioSamples>(opts_raw_audio_);
  }

 private:
  std::unique_ptr<knf::OnlineFbank> fbank_;
  std::unique_ptr<knf::OnlineMfcc> mfcc_;
  std::unique_ptr<knf::OnlineWhisperFbank> whisper_fbank_;
  std::unique_ptr<knf::OnlineRawAudioSamples> raw_audio_;
  knf::FbankOptions opts_;
  knf::RawAudioSamplesOptions opts_raw_audio_;
  knf::MfccOptions mfcc_opts_;
  FeatureExtractorConfig config_;
  mutable std::mutex mutex_;
@@ -81,6 +81,8 @@ struct FeatureExtractorConfig {

  bool is_whisper = false;

  bool is_t_one = false;

  bool round_to_power_of_two = true;

  std::string ToString() const;
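With is_t_one set, feature extraction is effectively a pass-through: each "feature frame" is one window of raw samples, so its dimension is presumably sampling_rate * frame_length_ms / 1000. A quick check with the values the recognizer configures for T-one (see PostInit() further down):

    sampling_rate = 8000   # fixed for T-one models
    frame_length_ms = 300
    samples_per_frame = sampling_rate * frame_length_ms // 1000
    print(samples_per_frame)  # 2400 raw samples per feature frame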
@@ -4,6 +4,7 @@

#include "sherpa-onnx/csrc/jieba-lexicon.h"

#include <algorithm>
#include <fstream>
#include <regex>  // NOLINT
#include <strstream>
@@ -38,7 +38,8 @@ struct OfflineRecognitionResult {
  /// timestamps[i] records the time in seconds when tokens[i] is decoded.
  std::vector<float> timestamps;

  /// durations[i] contains the duration (in seconds) for tokens[i] (TDT models only)
  /// durations[i] contains the duration (in seconds) for tokens[i] (TDT models
  /// only)
  std::vector<float> durations;

  std::vector<int32_t> words;
@@ -4,6 +4,7 @@
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_TTS_ZIPVOICE_IMPL_H_

#include <algorithm>
#include <cmath>
#include <memory>
#include <string>
@@ -104,7 +104,8 @@ class OfflineTtsZipvoiceModel::Impl {
    int64_t feat_dim = meta_data_.feat_dim;

    std::vector<float> x_data(batch_size * num_frames * feat_dim);
    std::default_random_engine rng(std::random_device{}());
    std::random_device rd;
    std::default_random_engine rng(rd());
    std::normal_distribution<float> norm(0, 1);
    for (auto &v : x_data) v = norm(rng);
    std::vector<int64_t> x_shape = {batch_size, num_frames, feat_dim};
@@ -7,6 +7,7 @@
#include <cmath>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
@@ -28,6 +28,13 @@ void OnlineCtcGreedySearchDecoder::Decode(
    auto &r = (*results)[b];

    int32_t prev_id = -1;
    if (!r.tokens.empty()) {
      if (r.num_trailing_blanks > 0) {
        prev_id = blank_id_;
      } else {
        prev_id = r.tokens.back();
      }
    }

    for (int32_t t = 0; t != num_frames; ++t, p += vocab_size) {
      int32_t y = static_cast<int32_t>(std::distance(
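The new prev_id block carries decoding state across chunks: a non-blank token at the start of a chunk is only emitted if it differs from the last emission, unless blanks intervened. A Python sketch of the rule (illustrative only; frame scores as numpy rows, blank id 0 assumed):

    def ctc_greedy_chunk(log_probs, tokens, num_trailing_blanks, blank_id=0):
        # tokens / num_trailing_blanks carry state across chunks, like r above
        prev_id = -1
        if tokens:
            prev_id = blank_id if num_trailing_blanks > 0 else tokens[-1]
        for frame in log_probs:
            y = int(frame.argmax())
            if y != blank_id and y != prev_id:
                tokens.append(y)
                num_trailing_blanks = 0
            elif y == blank_id:
                num_trailing_blanks += 1
            prev_id = y
        return tokens, num_trailing_blanks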
@@ -20,6 +20,7 @@

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-nemo-ctc-model.h"
#include "sherpa-onnx/csrc/online-t-one-ctc-model.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model.h"
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
@@ -34,9 +35,11 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
    return std::make_unique<OnlineZipformer2CtcModel>(config);
  } else if (!config.nemo_ctc.model.empty()) {
    return std::make_unique<OnlineNeMoCtcModel>(config);
  } else if (!config.t_one_ctc.model.empty()) {
    return std::make_unique<OnlineToneCtcModel>(config);
  } else {
    SHERPA_ONNX_LOGE("Please specify a CTC model");
    exit(-1);
    SHERPA_ONNX_EXIT(-1);
  }
}

@@ -49,9 +52,11 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
    return std::make_unique<OnlineZipformer2CtcModel>(mgr, config);
  } else if (!config.nemo_ctc.model.empty()) {
    return std::make_unique<OnlineNeMoCtcModel>(mgr, config);
  } else if (!config.t_one_ctc.model.empty()) {
    return std::make_unique<OnlineToneCtcModel>(mgr, config);
  } else {
    SHERPA_ONNX_LOGE("Please specify a CTC model");
    exit(-1);
    SHERPA_ONNX_EXIT(-1);
  }
}
@@ -17,6 +17,7 @@ void OnlineModelConfig::Register(ParseOptions *po) {
  wenet_ctc.Register(po);
  zipformer2_ctc.Register(po);
  nemo_ctc.Register(po);
  t_one_ctc.Register(po);
  provider_config.Register(po);

  po->Register("tokens", &tokens, "Path to tokens.txt");
@@ -149,6 +150,10 @@ bool OnlineModelConfig::Validate() const {
    return nemo_ctc.Validate();
  }

  if (!t_one_ctc.model.empty()) {
    return t_one_ctc.Validate();
  }

  if (!provider_config.Validate()) {
    return false;
  }
@@ -165,6 +170,7 @@ std::string OnlineModelConfig::ToString() const {
  os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
  os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
  os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
  os << "t_one_ctc=" << t_one_ctc.ToString() << ", ";
  os << "provider_config=" << provider_config.ToString() << ", ";
  os << "tokens=\"" << tokens << "\", ";
  os << "num_threads=" << num_threads << ", ";
@@ -8,6 +8,7 @@

#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/csrc/online-t-one-ctc-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model-config.h"
@@ -21,6 +22,7 @@ struct OnlineModelConfig {
  OnlineWenetCtcModelConfig wenet_ctc;
  OnlineZipformer2CtcModelConfig zipformer2_ctc;
  OnlineNeMoCtcModelConfig nemo_ctc;
  OnlineToneCtcModelConfig t_one_ctc;
  ProviderConfig provider_config;
  std::string tokens;
  int32_t num_threads = 1;
@@ -56,6 +58,7 @@ struct OnlineModelConfig {
                    const OnlineWenetCtcModelConfig &wenet_ctc,
                    const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
                    const OnlineNeMoCtcModelConfig &nemo_ctc,
                    const OnlineToneCtcModelConfig &t_one_ctc,
                    const ProviderConfig &provider_config,
                    const std::string &tokens, int32_t num_threads,
                    int32_t warm_up, bool debug, const std::string &model_type,
@@ -66,6 +69,7 @@ struct OnlineModelConfig {
        wenet_ctc(wenet_ctc),
        zipformer2_ctc(zipformer2_ctc),
        nemo_ctc(nemo_ctc),
        t_one_ctc(t_one_ctc),
        provider_config(provider_config),
        tokens(tokens),
        num_threads(num_threads),
@@ -6,6 +6,7 @@
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_CTC_IMPL_H_

#include <algorithm>
#include <cassert>
#include <ios>
#include <memory>
#include <sstream>
@@ -79,24 +80,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
        config_(config),
        model_(OnlineCtcModel::Create(config.model_config)),
        endpoint_(config_.endpoint_config) {
    if (!config.model_config.tokens_buf.empty()) {
      sym_ = SymbolTable(config.model_config.tokens_buf, false);
    } else {
      /// assuming tokens_buf and tokens are guaranteed not being both empty
      sym_ = SymbolTable(config.model_config.tokens, true);
    }

    if (!config.model_config.wenet_ctc.model.empty()) {
      // WeNet CTC models assume input samples are in the range
      // [-32768, 32767], so we set normalize_samples to false
      config_.feat_config.normalize_samples = false;
    }

    if (model_->UseWhisperFeature()) {
      config_.feat_config.is_whisper = true;
    }

    InitDecoder();
    PostInit();
  }

  template <typename Manager>
@@ -107,17 +91,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
        model_(OnlineCtcModel::Create(mgr, config.model_config)),
        sym_(mgr, config.model_config.tokens),
        endpoint_(config_.endpoint_config) {
    if (!config.model_config.wenet_ctc.model.empty()) {
      // WeNet CTC models assume input samples are in the range
      // [-32768, 32767], so we set normalize_samples to false
      config_.feat_config.normalize_samples = false;
    }

    if (model_->UseWhisperFeature()) {
      config_.feat_config.is_whisper = true;
    }

    InitDecoder();
    PostInit();
  }

  std::unique_ptr<OnlineStream> CreateStream() const override {
@@ -211,6 +185,14 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
    // TODO(fangjun): Remember to change these constants if needed
    int32_t frame_shift_ms = 10;
    int32_t subsampling_factor = 4;
    if (!config_.model_config.t_one_ctc.model.empty()) {
      // each input frame is of 300ms long, which produces 10 output frames.
      // so frame_shift_ms is 300/10 = 30ms
      //
      frame_shift_ms = 30;
      subsampling_factor = 1;
    }

    auto r =
        ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor,
                   s->GetCurrentSegment(), s->GetNumFramesSinceStart());
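Worked through: a 300 ms input frame yields 10 CTC output frames, so each output frame advances time by 30 ms, and with subsampling_factor = 1 an emission at output frame t lands at t * 30 ms:

    frame_shift_ms = 30       # 300 ms input frame / 10 output frames
    subsampling_factor = 1
    t = 42                    # a hypothetical output-frame index
    print(t * subsampling_factor * frame_shift_ms / 1000.0)  # 1.26 s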
@@ -258,6 +240,33 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
  }

 private:
  void PostInit() {
    if (!config_.model_config.tokens_buf.empty()) {
      sym_ = SymbolTable(config_.model_config.tokens_buf, false);
    } else {
      /// assuming tokens_buf and tokens are guaranteed not being both empty
      sym_ = SymbolTable(config_.model_config.tokens, true);
    }

    if (!config_.model_config.wenet_ctc.model.empty()) {
      // WeNet CTC models assume input samples are in the range
      // [-32768, 32767], so we set normalize_samples to false
      config_.feat_config.normalize_samples = false;
    }

    if (!config_.model_config.t_one_ctc.model.empty()) {
      config_.feat_config.is_t_one = true;
      config_.feat_config.frame_length_ms = 300;
      config_.feat_config.frame_shift_ms = 300;
      config_.feat_config.sampling_rate = 8000;
    }

    if (model_->UseWhisperFeature()) {
      config_.feat_config.is_whisper = true;
    }

    InitDecoder();
  }
  void InitDecoder() {
    if (!sym_.Contains("<blk>") && !sym_.Contains("<eps>") &&
        !sym_.Contains("<blank>")) {
@@ -83,12 +83,13 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(

  if (!config.model_config.wenet_ctc.model.empty() ||
      !config.model_config.zipformer2_ctc.model.empty() ||
      !config.model_config.nemo_ctc.model.empty()) {
      !config.model_config.nemo_ctc.model.empty() ||
      !config.model_config.t_one_ctc.model.empty()) {
    return std::make_unique<OnlineRecognizerCtcImpl>(config);
  }

  SHERPA_ONNX_LOGE("Please specify a model");
  exit(-1);
  SHERPA_ONNX_EXIT(-1);
}

template <typename Manager>
@@ -142,12 +143,13 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(

  if (!config.model_config.wenet_ctc.model.empty() ||
      !config.model_config.zipformer2_ctc.model.empty() ||
      !config.model_config.nemo_ctc.model.empty()) {
      !config.model_config.nemo_ctc.model.empty() ||
      !config.model_config.t_one_ctc.model.empty()) {
    return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config);
  }

  SHERPA_ONNX_LOGE("Please specify a model");
  exit(-1);
  SHERPA_ONNX_EXIT(-1);
}

OnlineRecognizerImpl::OnlineRecognizerImpl(const OnlineRecognizerConfig &config)
36 sherpa-onnx/csrc/online-t-one-ctc-model-config.cc Normal file
@@ -0,0 +1,36 @@
// sherpa-onnx/csrc/online-t-one-ctc-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation

#include "sherpa-onnx/csrc/online-t-one-ctc-model-config.h"

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

void OnlineToneCtcModelConfig::Register(ParseOptions *po) {
  po->Register("t-one-ctc-model", &model,
               "Path to CTC model.onnx from T-one. Please see "
               "https://github.com/k2-fsa/sherpa-onnx/pull/2571");
}

bool OnlineToneCtcModelConfig::Validate() const {
  if (!FileExists(model)) {
    SHERPA_ONNX_LOGE("T-one CTC model '%s' does not exist", model.c_str());
    return false;
  }

  return true;
}

std::string OnlineToneCtcModelConfig::ToString() const {
  std::ostringstream os;

  os << "OnlineToneCtcModelConfig(";
  os << "model=\"" << model << "\")";

  return os.str();
}

}  // namespace sherpa_onnx
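On the command line the option registered above surfaces as --t-one-ctc-model; from Python the same struct is exposed via the pybind wrapper added later in this commit. A sketch using only what the bindings define (an init arg named model and __str__):

    from sherpa_onnx.lib._sherpa_onnx import OnlineToneCtcModelConfig

    cfg = OnlineToneCtcModelConfig(
        model="./sherpa-onnx-streaming-t-one-russian-2025-09-08/model.onnx"
    )
    print(cfg)  # OnlineToneCtcModelConfig(model="...")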
28 sherpa-onnx/csrc/online-t-one-ctc-model-config.h Normal file
@@ -0,0 +1,28 @@
// sherpa-onnx/csrc/online-t-one-ctc-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_

#include <string>

#include "sherpa-onnx/csrc/parse-options.h"

namespace sherpa_onnx {

struct OnlineToneCtcModelConfig {
  std::string model;

  OnlineToneCtcModelConfig() = default;

  explicit OnlineToneCtcModelConfig(const std::string &model) : model(model) {}

  void Register(ParseOptions *po);
  bool Validate() const;

  std::string ToString() const;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
274 sherpa-onnx/csrc/online-t-one-ctc-model.cc Normal file
@@ -0,0 +1,274 @@
// sherpa-onnx/csrc/online-t-one-ctc-model.cc
//
// Copyright (c) 2025 Xiaomi Corporation

#include "sherpa-onnx/csrc/online-t-one-ctc-model.h"

#include <algorithm>
#include <cmath>
#include <string>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/unbind.h"

namespace sherpa_onnx {

class OnlineToneCtcModel::Impl {
 public:
  explicit Impl(const OnlineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(config.t_one_ctc.model);
      Init(buf.data(), buf.size());
    }
  }

  template <typename Manager>
  Impl(Manager *mgr, const OnlineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(mgr, config.t_one_ctc.model);
      Init(buf.data(), buf.size());
    }
  }

  std::vector<Ort::Value> Forward(Ort::Value x,
                                  std::vector<Ort::Value> states) {
    // shape0 is (batch_size, 1, num_samples)
    auto shape0 = x.GetTensorTypeAndShapeInfo().GetShape();
    std::array<int64_t, 3> shape = {shape0[0], shape0[2], shape0[1]};
    std::vector<int32_t> samples(shape[0] * shape[1] * shape[2]);
    const float *px = x.GetTensorData<float>();

    for (int32_t i = 0; i < samples.size(); ++i) {
      float f = px[i];
      f = f > 1 ? 1 : f;
      f = f < -1 ? -1 : f;
      samples[i] = static_cast<int32_t>(f * 32767);
    }

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    Ort::Value xx =
        Ort::Value::CreateTensor(memory_info, samples.data(), samples.size(),
                                 shape.data(), shape.size());

    std::array<Ort::Value, 2> inputs = {std::move(xx), std::move(states[0])};

    auto out =
        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
                   output_names_ptr_.data(), output_names_ptr_.size());
    // out[0]: log_probs
    // out[1] next_states

    return out;
  }

  int32_t VocabSize() const { return vocab_size_; }

  int32_t ChunkLength() const { return 1; }

  int32_t ChunkShift() const { return 1; }

  OrtAllocator *Allocator() { return allocator_; }

  // Return a vector containing 1 tensor
  // - state_
  std::vector<Ort::Value> GetInitStates() {
    std::vector<Ort::Value> ans;
    ans.push_back(View(&state_));

    return ans;
  }

  std::vector<Ort::Value> StackStates(
      std::vector<std::vector<Ort::Value>> states) {
    int32_t batch_size = static_cast<int32_t>(states.size());
    if (batch_size == 1) {
      return std::move(states[0]);
    }

    std::vector<Ort::Value> ans;
    ans.reserve(1);

    std::vector<const Ort::Value *> buf;
    buf.reserve(batch_size);

    for (int32_t b = 0; b != batch_size; ++b) {
      buf.push_back(&states[b][0]);
    }

    Ort::Value c{nullptr};
    c = CatFloat16(allocator_, buf, 0);

    ans.push_back(std::move(c));

    return ans;
  }

  std::vector<std::vector<Ort::Value>> UnStackStates(
      std::vector<Ort::Value> states) const {
    auto allocator = const_cast<Impl *>(this)->allocator_;

    std::vector<std::vector<Ort::Value>> ans;

    auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
    int32_t batch_size = shape[0];
    ans.resize(batch_size);

    if (batch_size == 1) {
      ans[0] = std::move(states);
      return ans;
    }

    std::vector<Ort::Value> v;
    v = UnbindFloat16(allocator, &states[0], 0);

    for (int32_t b = 0; b != batch_size; ++b) {
      ans[b].push_back(std::move(v[b]));
    }

    return ans;
  }

 private:
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
#if __OHOS__
      SHERPA_ONNX_LOGE("%{public}s", os.str().c_str());
#else
      SHERPA_ONNX_LOGE("%s", os.str().c_str());
#endif
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(frame_length_ms_, "frame_length_ms");
    SHERPA_ONNX_READ_META_DATA(state_dim_, "state_dim");
    SHERPA_ONNX_READ_META_DATA(sample_rate_, "sample_rate");

    InitStates();

    vocab_size_ = sess_->GetOutputTypeInfo(0)
                      .GetTensorTypeAndShapeInfo()
                      .GetShape()
                      .back();
  }

  void InitStates() {
    std::array<int64_t, 2> state_shape{1, state_dim_};

    state_ = Ort::Value::CreateTensor(allocator_, state_shape.data(),
                                      state_shape.size(),
                                      ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);

    auto p = state_.GetTensorMutableData<uint16_t>();
    std::fill(p, p + state_dim_, 0);
  }

 private:
  OnlineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  // One input frame is of length is 300ms
  // For each input frame, there are 10 output frames,
  // so each output frame is 30ms
  int32_t frame_length_ms_ = 0;
  int32_t state_dim_ = 0;
  int32_t sample_rate_ = 0;
  int32_t vocab_size_ = 0;

  Ort::Value state_{nullptr};
};

OnlineToneCtcModel::OnlineToneCtcModel(const OnlineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

template <typename Manager>
OnlineToneCtcModel::OnlineToneCtcModel(Manager *mgr,
                                       const OnlineModelConfig &config)
    : impl_(std::make_unique<Impl>(mgr, config)) {}

OnlineToneCtcModel::~OnlineToneCtcModel() = default;

std::vector<Ort::Value> OnlineToneCtcModel::Forward(
    Ort::Value x, std::vector<Ort::Value> states) const {
  return impl_->Forward(std::move(x), std::move(states));
}

int32_t OnlineToneCtcModel::VocabSize() const { return impl_->VocabSize(); }

int32_t OnlineToneCtcModel::ChunkLength() const { return impl_->ChunkLength(); }

int32_t OnlineToneCtcModel::ChunkShift() const { return impl_->ChunkShift(); }

OrtAllocator *OnlineToneCtcModel::Allocator() const {
  return impl_->Allocator();
}

std::vector<Ort::Value> OnlineToneCtcModel::GetInitStates() const {
  return impl_->GetInitStates();
}

std::vector<Ort::Value> OnlineToneCtcModel::StackStates(
    std::vector<std::vector<Ort::Value>> states) const {
  return impl_->StackStates(std::move(states));
}

std::vector<std::vector<Ort::Value>> OnlineToneCtcModel::UnStackStates(
    std::vector<Ort::Value> states) const {
  return impl_->UnStackStates(std::move(states));
}

#if __ANDROID_API__ >= 9
template OnlineToneCtcModel::OnlineToneCtcModel(
    AAssetManager *mgr, const OnlineModelConfig &config);
#endif

#if __OHOS__
template OnlineToneCtcModel::OnlineToneCtcModel(
    NativeResourceManager *mgr, const OnlineModelConfig &config);
#endif

}  // namespace sherpa_onnx
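Note how Forward() undoes the usual normalization: the recognizer hands in float samples in [-1, 1], but the T-one graph expects int32 PCM, so samples are clipped and rescaled by 32767, and the (batch, 1, num_samples) buffer is re-declared as (batch, num_samples, 1) without moving data. An equivalent numpy sketch (illustrative only):

    import numpy as np

    def to_t_one_input(x):
        # x: float32 array of shape (batch, 1, num_samples) in [-1, 1]
        samples = (np.clip(x, -1.0, 1.0) * 32767).astype(np.int32)
        batch, _, num_samples = samples.shape
        # middle dim is 1, so a plain reshape matches Forward()'s shape swap
        return samples.reshape(batch, num_samples, 1)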
73 sherpa-onnx/csrc/online-t-one-ctc-model.h Normal file
@@ -0,0 +1,73 @@
// sherpa-onnx/csrc/online-t-one-ctc-model.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_H_

#include <memory>
#include <utility>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/online-ctc-model.h"
#include "sherpa-onnx/csrc/online-model-config.h"

namespace sherpa_onnx {

class OnlineToneCtcModel : public OnlineCtcModel {
 public:
  explicit OnlineToneCtcModel(const OnlineModelConfig &config);

  template <typename Manager>
  OnlineToneCtcModel(Manager *mgr, const OnlineModelConfig &config);

  ~OnlineToneCtcModel() override;

  // A list of 1 tensor:
  // - (batch_size, state_dim)
  std::vector<Ort::Value> GetInitStates() const override;

  std::vector<Ort::Value> StackStates(
      std::vector<std::vector<Ort::Value>> states) const override;

  std::vector<std::vector<Ort::Value>> UnStackStates(
      std::vector<Ort::Value> states) const override;

  /**
   *
   * @param x A 3-D tensor of shape (batch_size, num_samples).
   * @param states It is from GetInitStates() or returned from this method.
   *
   * @return Return a list of tensors
   *    - ans[0] contains log_probs, of shape (N, T, C)
   *    - ans[1:] contains next_states
   */
  std::vector<Ort::Value> Forward(
      Ort::Value x, std::vector<Ort::Value> states) const override;

  /** Return the vocabulary size of the model
   */
  int32_t VocabSize() const override;

  /** Return an allocator for allocating memory
   */
  OrtAllocator *Allocator() const override;

  // The model accepts this number of frames before subsampling as input
  int32_t ChunkLength() const override;

  // Similar to frame_shift in feature extractor, after processing
  // ChunkLength() frames, we advance by ChunkShift() frames
  // before we process the next chunk.
  int32_t ChunkShift() const override;

  bool SupportBatchProcessing() const override { return true; }

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_ONLINE_T_ONE_CTC_MODEL_H_
@@ -155,10 +155,30 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v) {
      std::copy(start, end, dst);
      return ans;
    }
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: {
      Ort::Value ans =
          Ort::Value::CreateTensor(allocator, shape.data(), shape.size(),
                                   ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
      const auto *start = v->GetTensorData<uint16_t>();
      const auto *end = start + type_and_shape.GetElementCount();
      auto *dst = ans.GetTensorMutableData<uint16_t>();
      std::copy(start, end, dst);
      return ans;
    }
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: {
      Ort::Value ans = Ort::Value::CreateTensor<uint16_t>(
          allocator, shape.data(), shape.size());
      const auto *start = v->GetTensorData<uint16_t>();
      const auto *end = start + type_and_shape.GetElementCount();
      auto *dst = ans.GetTensorMutableData<uint16_t>();
      std::copy(start, end, dst);
      return ans;
    }

    default:
      fprintf(stderr, "Unsupported type: %d\n",
              static_cast<int32_t>(type_and_shape.GetElementType()));
      exit(-1);
      SHERPA_ONNX_LOGE("Unsupported type: %d\n",
                       static_cast<int32_t>(type_and_shape.GetElementType()));
      SHERPA_ONNX_EXIT(-1);
      // unreachable code
      return Ort::Value{nullptr};
  }
@@ -183,14 +203,23 @@ Ort::Value View(Ort::Value *v) {
      return Ort::Value::CreateTensor(
          memory_info, v->GetTensorMutableData<float>(),
          type_and_shape.GetElementCount(), shape.data(), shape.size());
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
      return Ort::Value::CreateTensor(
          memory_info, v->GetTensorMutableData<uint16_t>(),
          type_and_shape.GetElementCount() * sizeof(uint16_t), shape.data(),
          shape.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
      return Ort::Value::CreateTensor(
          memory_info, v->GetTensorMutableData<uint16_t>(),
          type_and_shape.GetElementCount(), shape.data(), shape.size());
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
      return Ort::Value::CreateTensor(
          memory_info, v->GetTensorMutableData<bool>(),
          type_and_shape.GetElementCount(), shape.data(), shape.size());
    default:
      fprintf(stderr, "Unsupported type: %d\n",
              static_cast<int32_t>(type_and_shape.GetElementType()));
      exit(-1);
      SHERPA_ONNX_LOGE("Unsupported type: %d\n",
                       static_cast<int32_t>(type_and_shape.GetElementType()));
      SHERPA_ONNX_EXIT(-1);
      // unreachable code
      return Ort::Value{nullptr};
  }
@@ -11,6 +11,7 @@
#include <locale>
#endif

#include <algorithm>
#include <cassert>
#include <ostream>
#include <string>
@@ -117,6 +117,11 @@ for a list of pre-trained models to download.
  const float duration = samples.size() / static_cast<float>(sampling_rate);

  auto s = recognizer.CreateStream();

  std::vector<float> left_paddings(static_cast<int>(0.3 * sampling_rate));
  s->AcceptWaveform(sampling_rate, left_paddings.data(),
                    left_paddings.size());

  s->AcceptWaveform(sampling_rate, samples.data(), samples.size());

  std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
@@ -4,7 +4,7 @@

#include "sherpa-onnx/csrc/text-utils.h"

#include <regex>
#include <regex>  // NOLINT
#include <sstream>

#include "gtest/gtest.h"
@@ -68,4 +68,49 @@ template std::vector<Ort::Value> Unbind<int64_t>(OrtAllocator *allocator,
                                                 const Ort::Value *value,
                                                 int32_t dim);

std::vector<Ort::Value> UnbindFloat16(OrtAllocator *allocator,
                                      const Ort::Value *value, int32_t dim) {
  std::vector<int64_t> shape = value->GetTensorTypeAndShapeInfo().GetShape();
  assert(dim >= 0);
  assert(dim < static_cast<int32_t>(shape.size()));
  int32_t n = static_cast<int32_t>(shape[dim]);
  if (n == 1) {
    std::vector<Ort::Value> ans;
    ans.push_back(Clone(allocator, value));
    return ans;
  }

  std::vector<int64_t> ans_shape = shape;
  ans_shape[dim] = 1;  // // Unlike torch, we keep the dim to 1

  // allocator tensors
  std::vector<Ort::Value> ans;
  ans.reserve(n);
  for (int32_t i = 0; i != n; ++i) {
    Ort::Value t =
        Ort::Value::CreateTensor(allocator, ans_shape.data(), ans_shape.size(),
                                 ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
    ans.push_back(std::move(t));
  }

  auto leading_size = static_cast<int32_t>(std::accumulate(
      shape.begin(), shape.begin() + dim, 1, std::multiplies<int64_t>()));

  auto trailing_size = static_cast<int32_t>(std::accumulate(
      shape.begin() + dim + 1, shape.end(), 1, std::multiplies<int64_t>()));

  using T = uint16_t;
  const T *src = value->GetTensorData<T>();

  for (int32_t i = 0; i != leading_size; ++i) {
    for (int32_t k = 0; k != n; ++k) {
      T *dst = ans[k].GetTensorMutableData<T>() + i * trailing_size;
      std::copy(src, src + trailing_size, dst);
      src += trailing_size;
    }
  }

  return ans;
}

}  // namespace sherpa_onnx
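Like CatFloat16 above, UnbindFloat16 moves float16 payloads as raw uint16 words, which is lossless since only the type tag changes. A numpy demonstration (illustrative only; 512 is a made-up state_dim):

    import numpy as np

    state = np.zeros((2, 512), dtype=np.float16)  # a stacked 2-stream state
    bits = state.view(np.uint16)                  # same bytes, no conversion
    rows = [bits[i:i + 1] for i in range(bits.shape[0])]  # unbind, keep dim
    rejoined = np.concatenate(rows, axis=0).view(np.float16)
    assert np.array_equal(rejoined, state)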
@@ -23,6 +23,9 @@ template <typename T = float>
std::vector<Ort::Value> Unbind(OrtAllocator *allocator, const Ort::Value *value,
                               int32_t dim);

std::vector<Ort::Value> UnbindFloat16(OrtAllocator *allocator,
                                      const Ort::Value *value, int32_t dim);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_UNBIND_H_
@@ -42,6 +42,7 @@ set(srcs
  online-punctuation.cc
  online-recognizer.cc
  online-stream.cc
  online-t-one-ctc-model-config.cc
  online-transducer-model-config.cc
  online-wenet-ctc-model-config.cc
  online-zipformer2-ctc-model-config.cc
@@ -5,6 +5,7 @@

#include <algorithm>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/python/csrc/offline-tts-model-config.h"
@@ -12,6 +12,7 @@
#include "sherpa-onnx/csrc/provider-config.h"
#include "sherpa-onnx/python/csrc/online-nemo-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/python/csrc/online-t-one-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/python/csrc/online-wenet-ctc-model-config.h"
#include "sherpa-onnx/python/csrc/online-zipformer2-ctc-model-config.h"
@@ -25,6 +26,7 @@ void PybindOnlineModelConfig(py::module *m) {
  PybindOnlineWenetCtcModelConfig(m);
  PybindOnlineZipformer2CtcModelConfig(m);
  PybindOnlineNeMoCtcModelConfig(m);
  PybindOnlineToneCtcModelConfig(m);
  PybindProviderConfig(m);

  using PyClass = OnlineModelConfig;
@@ -34,17 +36,18 @@ void PybindOnlineModelConfig(py::module *m) {
                    const OnlineWenetCtcModelConfig &,
                    const OnlineZipformer2CtcModelConfig &,
                    const OnlineNeMoCtcModelConfig &,
                    const ProviderConfig &,
                    const std::string &, int32_t, int32_t,
                    bool, const std::string &, const std::string &,
                    const OnlineToneCtcModelConfig &, const ProviderConfig &,
                    const std::string &, int32_t, int32_t, bool,
                    const std::string &, const std::string &,
                    const std::string &>(),
       py::arg("transducer") = OnlineTransducerModelConfig(),
       py::arg("paraformer") = OnlineParaformerModelConfig(),
       py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
       py::arg("zipformer2_ctc") = OnlineZipformer2CtcModelConfig(),
       py::arg("nemo_ctc") = OnlineNeMoCtcModelConfig(),
       py::arg("provider_config") = ProviderConfig(),
       py::arg("tokens"), py::arg("num_threads"), py::arg("warm_up") = 0,
       py::arg("t_one_ctc") = OnlineToneCtcModelConfig(),
       py::arg("provider_config") = ProviderConfig(), py::arg("tokens"),
       py::arg("num_threads"), py::arg("warm_up") = 0,
       py::arg("debug") = false, py::arg("model_type") = "",
       py::arg("modeling_unit") = "", py::arg("bpe_vocab") = "")
      .def_readwrite("transducer", &PyClass::transducer)
@@ -52,6 +55,7 @@ void PybindOnlineModelConfig(py::module *m) {
      .def_readwrite("wenet_ctc", &PyClass::wenet_ctc)
      .def_readwrite("zipformer2_ctc", &PyClass::zipformer2_ctc)
      .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
      .def_readwrite("t_one_ctc", &PyClass::t_one_ctc)
      .def_readwrite("provider_config", &PyClass::provider_config)
      .def_readwrite("tokens", &PyClass::tokens)
      .def_readwrite("num_threads", &PyClass::num_threads)
22 sherpa-onnx/python/csrc/online-t-one-ctc-model-config.cc Normal file
@@ -0,0 +1,22 @@
// sherpa-onnx/python/csrc/online-t-one-ctc-model-config.cc
//
// Copyright (c) 2025 Xiaomi Corporation

#include "sherpa-onnx/python/csrc/online-t-one-ctc-model-config.h"

#include <string>
#include <vector>

#include "sherpa-onnx/csrc/online-t-one-ctc-model-config.h"

namespace sherpa_onnx {

void PybindOnlineToneCtcModelConfig(py::module *m) {
  using PyClass = OnlineToneCtcModelConfig;
  py::class_<PyClass>(*m, "OnlineToneCtcModelConfig")
      .def(py::init<const std::string &>(), py::arg("model"))
      .def_readwrite("model", &PyClass::model)
      .def("__str__", &PyClass::ToString);
}

}  // namespace sherpa_onnx
16 sherpa-onnx/python/csrc/online-t-one-ctc-model-config.h Normal file
@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/online-t-one-ctc-model-config.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_

#include "sherpa-onnx/python/csrc/sherpa-onnx.h"

namespace sherpa_onnx {

void PybindOnlineToneCtcModelConfig(py::module *m);

}

#endif  // SHERPA_ONNX_PYTHON_CSRC_ONLINE_T_ONE_CTC_MODEL_CONFIG_H_
@@ -18,6 +18,7 @@ from sherpa_onnx.lib._sherpa_onnx import (
    OnlineRecognizerConfig,
    OnlineRecognizerResult,
    OnlineStream,
    OnlineToneCtcModelConfig,
    OnlineTransducerModelConfig,
    OnlineWenetCtcModelConfig,
    OnlineZipformer2CtcModelConfig,
@@ -602,6 +603,132 @@ class OnlineRecognizer(object):
        self.config = recognizer_config
        return self

    @classmethod
    def from_t_one_ctc(
        cls,
        tokens: str,
        model: str,
        num_threads: int = 2,
        sample_rate: float = 8000,
        feature_dim: int = 80,
        enable_endpoint_detection: bool = False,
        rule1_min_trailing_silence: float = 2.4,
        rule2_min_trailing_silence: float = 1.2,
        rule3_min_utterance_length: float = 20.0,
        decoding_method: str = "greedy_search",
        provider: str = "cpu",
        debug: bool = False,
        rule_fsts: str = "",
        rule_fars: str = "",
        device: int = 0,
        hr_dict_dir: str = "",
        hr_rule_fsts: str = "",
        hr_lexicon: str = "",
    ):
        """
        Please refer to
        `<https://github.com/k2-fsa/sherpa-onnx/releases/tag/asr-models>`_
        to download pre-trained models.

        Args:
          tokens:
            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
            columns::

                symbol integer_id

          model:
            Path to ``model.onnx``.
          num_threads:
            Number of threads for neural network computation.
          sample_rate:
            Sample rate of the training data used to train the model.
          feature_dim:
            Dimension of the feature used to train the model.
          enable_endpoint_detection:
            True to enable endpoint detection. False to disable endpoint
            detection.
          rule1_min_trailing_silence:
            Used only when enable_endpoint_detection is True. If the duration
            of trailing silence in seconds is larger than this value, we assume
            an endpoint is detected.
          rule2_min_trailing_silence:
            Used only when enable_endpoint_detection is True. If we have decoded
            something that is nonsilence and if the duration of trailing silence
            in seconds is larger than this value, we assume an endpoint is
            detected.
          rule3_min_utterance_length:
            Used only when enable_endpoint_detection is True. If the utterance
            length in seconds is larger than this value, we assume an endpoint
            is detected.
          decoding_method:
            The only valid value is greedy_search.
          provider:
            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
          debug:
            True to show meta data in the model.
          rule_fsts:
            If not empty, it specifies fsts for inverse text normalization.
            If there are multiple fsts, they are separated by a comma.
          rule_fars:
            If not empty, it specifies fst archives for inverse text normalization.
            If there are multiple archives, they are separated by a comma.
          device:
            onnxruntime cuda device index.
        """
        self = cls.__new__(cls)
        _assert_file_exists(tokens)
        _assert_file_exists(model)

        assert num_threads > 0, num_threads

        t_one_ctc_config = OnlineToneCtcModelConfig(
            model=model,
        )

        provider_config = ProviderConfig(
            provider=provider,
            device=device,
        )

        model_config = OnlineModelConfig(
            t_one_ctc=t_one_ctc_config,
            tokens=tokens,
            num_threads=num_threads,
            provider_config=provider_config,
            debug=debug,
        )

        feat_config = FeatureExtractorConfig(
            sampling_rate=sample_rate,
            feature_dim=feature_dim,
        )

        endpoint_config = EndpointConfig(
            rule1_min_trailing_silence=rule1_min_trailing_silence,
            rule2_min_trailing_silence=rule2_min_trailing_silence,
            rule3_min_utterance_length=rule3_min_utterance_length,
        )

        recognizer_config = OnlineRecognizerConfig(
            feat_config=feat_config,
            model_config=model_config,
            endpoint_config=endpoint_config,
            enable_endpoint=enable_endpoint_detection,
            decoding_method=decoding_method,
            rule_fsts=rule_fsts,
            rule_fars=rule_fars,
            hr=HomophoneReplacerConfig(
                dict_dir=hr_dict_dir,
                lexicon=hr_lexicon,
                rule_fsts=hr_rule_fsts,
            ),
        )

        self.recognizer = _Recognizer(recognizer_config)
        self.config = recognizer_config
        return self

    @classmethod
    def from_nemo_ctc(
        cls,
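A sketch of how the new constructor composes with the endpointing options documented above (the model paths assume the example archive from this PR):

    import sherpa_onnx

    recognizer = sherpa_onnx.OnlineRecognizer.from_t_one_ctc(
        tokens="./sherpa-onnx-streaming-t-one-russian-2025-09-08/tokens.txt",
        model="./sherpa-onnx-streaming-t-one-russian-2025-09-08/model.onnx",
        enable_endpoint_detection=True,
        rule1_min_trailing_silence=2.4,  # endpoint after 2.4 s of trailing silence
        rule2_min_trailing_silence=1.2,  # 1.2 s of silence after decoded speech
    )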