mirror of https://github.com/k2-fsa/sherpa-onnx.git (synced 2026-01-09 07:41:06 +08:00)
Fix modified beam search for iOS and android (#76)
* Use Int type for sampling rate
* Fix swift
* Fix iOS
This commit is contained in:
parent 7f72c13d9a
commit 5f31b22c12
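Summary of the change, as inferred from the hunks below: the mobile bindings replace the blocking decodeSamples() call with an explicit acceptWaveform() / isReady() / decode() / inputFinished() API, the recognizer config gains decodingMethod and maxActivePaths (so modified beam search can be selected), and sampling rates switch from Float/float to Int/int32_t throughout. A rough before/after sketch of a Kotlin call site:

    // before: a single blocking call that fed samples and decoded internally
    model.decodeSamples(samples)

    // after: feed samples, then explicitly drain the decoder
    model.acceptWaveform(samples, sampleRate = 16000)
    while (model.isReady()) {
        model.decode()
    }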
.github/scripts/.gitignore (vendored, 1 line changed)

@@ -1,2 +1,3 @@
 Makefile
 *.jar
+hs_err_pid*.log
.github/scripts/Main.kt (vendored, 19 lines changed)

@@ -4,7 +4,7 @@ import android.content.res.AssetManager
 
 fun main() {
     var featConfig = FeatureConfig(
-        sampleRate = 16000.0f,
+        sampleRate = 16000,
         featureDim = 80,
     )
 
@@ -13,7 +13,7 @@ fun main() {
         decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
         joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
         tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
-        numThreads = 4,
+        numThreads = 1,
         debug = false,
     )
 
@@ -24,22 +24,31 @@ fun main() {
         featConfig = featConfig,
         endpointConfig = endpointConfig,
         enableEndpoint = true,
+        decodingMethod = "greedy_search",
+        maxActivePaths = 4,
     )
 
     var model = SherpaOnnx(
         assetManager = AssetManager(),
         config = config,
     )
 
     var samples = WaveReader.readWave(
         assetManager = AssetManager(),
         filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav",
     )
 
-    model.decodeSamples(samples!!)
+    model.acceptWaveform(samples!!, sampleRate=16000)
+    while (model.isReady()) {
+        model.decode()
+    }
 
     var tail_paddings = FloatArray(8000) // 0.5 seconds
-    model.decodeSamples(tail_paddings)
+    model.acceptWaveform(tail_paddings, sampleRate=16000)
+    model.inputFinished()
+    while (model.isReady()) {
+        model.decode()
+    }
 
     println("results: ${model.text}")
 }
.gitignore (vendored, 1 line changed)

@@ -38,3 +38,4 @@ log.txt
 tags
 run-decode-file-python.sh
 android/SherpaOnnx/app/src/main/assets/
+*.ncnn.*
@@ -121,7 +121,10 @@ class MainActivity : AppCompatActivity() {
             val ret = audioRecord?.read(buffer, 0, buffer.size)
             if (ret != null && ret > 0) {
                 val samples = FloatArray(ret) { buffer[it] / 32768.0f }
-                model.decodeSamples(samples)
+                model.acceptWaveform(samples, sampleRate=16000)
+                while (model.isReady()) {
+                    model.decode()
+                }
                 runOnUiThread {
                     val isEndpoint = model.isEndpoint()
                     val text = model.text
@@ -177,33 +180,17 @@ class MainActivity : AppCompatActivity() {
         val type = 0
         println("Select model type ${type}")
         val config = OnlineRecognizerConfig(
-            featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80),
+            featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
             modelConfig = getModelConfig(type = type)!!,
             endpointConfig = getEndpointConfig(),
-            enableEndpoint = true
+            enableEndpoint = true,
+            decodingMethod = "greedy_search",
+            maxActivePaths = 4,
         )
 
         model = SherpaOnnx(
             assetManager = application.assets,
             config = config,
         )
-        /*
-        println("reading samples")
-        val samples = WaveReader.readWave(
-            assetManager = application.assets,
-            // filename = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav",
-            filename = "sherpa-onnx-lstm-zh-2023-02-20/test_wavs/0.wav",
-            // filename="sherpa-onnx-lstm-en-2023-02-17/test_wavs/1089-134686-0001.wav"
-        )
-        println("samples read done!")
-
-        model.decodeSamples(samples!!)
-
-        val tailPaddings = FloatArray(8000) // 0.5 seconds
-        model.decodeSamples(tailPaddings)
-
-        println("result is: ${model.text}")
-        model.reset()
-        */
     }
 }
@@ -24,7 +24,7 @@ data class OnlineTransducerModelConfig(
 )
 
 data class FeatureConfig(
-    var sampleRate: Float = 16000.0f,
+    var sampleRate: Int = 16000,
    var featureDim: Int = 80,
 )
 
@@ -32,7 +32,9 @@ data class OnlineRecognizerConfig(
     var featConfig: FeatureConfig = FeatureConfig(),
     var modelConfig: OnlineTransducerModelConfig,
     var endpointConfig: EndpointConfig = EndpointConfig(),
-    var enableEndpoint: Boolean,
+    var enableEndpoint: Boolean = true,
+    var decodingMethod: String = "greedy_search",
+    var maxActivePaths: Int = 4,
 )
 
 class SherpaOnnx(
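These two new fields are what allow callers to opt into modified beam search, the feature this commit fixes on Android and iOS. A minimal sketch of such a config, using the helper functions that appear elsewhere in this commit (the model type value is illustrative):

    val config = OnlineRecognizerConfig(
        featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
        modelConfig = getModelConfig(type = 0)!!,
        endpointConfig = getEndpointConfig(),
        enableEndpoint = true,
        decodingMethod = "modified_beam_search", // default is "greedy_search"
        maxActivePaths = 4, // number of hypotheses kept during modified_beam_search
    )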
@@ -49,12 +51,14 @@ class SherpaOnnx(
     }
 
 
-    fun decodeSamples(samples: FloatArray) =
-        decodeSamples(ptr, samples, sampleRate = config.featConfig.sampleRate)
+    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
+        acceptWaveform(ptr, samples, sampleRate)
 
+    fun inputFinished() = inputFinished(ptr)
     fun reset() = reset(ptr)
+    fun decode() = decode(ptr)
     fun isEndpoint(): Boolean = isEndpoint(ptr)
     fun isReady(): Boolean = isReady(ptr)
 
     val text: String
         get() = getText(ptr)
@@ -66,11 +70,13 @@ class SherpaOnnx(
         config: OnlineRecognizerConfig,
     ): Long
 
-    private external fun decodeSamples(ptr: Long, samples: FloatArray, sampleRate: Float)
+    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
+    private external fun inputFinished(ptr: Long)
     private external fun getText(ptr: Long): String
     private external fun reset(ptr: Long)
+    private external fun decode(ptr: Long)
     private external fun isEndpoint(ptr: Long): Boolean
     private external fun isReady(ptr: Long): Boolean
 
     companion object {
         init {
@@ -79,7 +85,7 @@ class SherpaOnnx(
         }
     }
 
-fun getFeatureConfig(sampleRate: Float, featureDim: Int): FeatureConfig {
+fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
     return FeatureConfig(sampleRate=sampleRate, featureDim=featureDim)
 }
 
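Taken together, the new Kotlin surface separates feeding audio, draining the decoder, and endpoint handling. A minimal end-to-end sketch built only from the methods above (the endpoint handling mirrors the MainActivity hunk earlier in this commit):

    // feed a chunk of audio normalized to [-1, 1]
    model.acceptWaveform(samples, sampleRate = 16000)

    // decode every feature frame that is ready
    while (model.isReady()) {
        model.decode()
    }

    // on an endpoint, take the result and start a new utterance
    if (model.isEndpoint()) {
        println(model.text)
        model.reset()
    }

    // when no more audio will arrive, flush and drain one last time
    model.inputFinished()
    while (model.isReady()) {
        model.decode()
    }
    println("results: ${model.text}")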
@@ -23,10 +23,10 @@ extension AVAudioPCMBuffer {
 class ViewController: UIViewController {
   @IBOutlet weak var resultLabel: UILabel!
   @IBOutlet weak var recordBtn: UIButton!
-
+
   var audioEngine: AVAudioEngine? = nil
   var recognizer: SherpaOnnxRecognizer! = nil
-
+
   /// It saves the decoded results so far
   var sentences: [String] = [] {
     didSet {
@@ -42,7 +42,7 @@ class ViewController: UIViewController {
     if sentences.isEmpty {
       return "0: \(lastSentence.lowercased())"
     }
-
+
     let start = max(sentences.count - maxSentence, 0)
     if lastSentence.isEmpty {
       return sentences.enumerated().map { (index, s) in "\(index): \(s.lowercased())" }[start...]
@@ -52,23 +52,23 @@ class ViewController: UIViewController {
         .joined(separator: "\n") + "\n\(sentences.count): \(lastSentence.lowercased())"
     }
   }
-
+
   func updateLabel() {
     DispatchQueue.main.async {
       self.resultLabel.text = self.results
     }
   }
-
+
   override func viewDidLoad() {
     super.viewDidLoad()
     // Do any additional setup after loading the view.
-
+
     resultLabel.text = "ASR with Next-gen Kaldi\n\nSee https://github.com/k2-fsa/sherpa-onnx\n\nPress the Start button to run!"
     recordBtn.setTitle("Start", for: .normal)
     initRecognizer()
     initRecorder()
   }
-
+
   @IBAction func onRecordBtnClick(_ sender: UIButton) {
     if recordBtn.currentTitle == "Start" {
       startRecorder()
@@ -78,30 +78,32 @@ class ViewController: UIViewController {
       recordBtn.setTitle("Start", for: .normal)
     }
   }
-
+
   func initRecognizer() {
     // Please select one model that is best suitable for you.
     //
     // You can also modify Model.swift to add new pre-trained models from
     // https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
-
+
     let modelConfig = getBilingualStreamZhEnZipformer20230220()
-
+
     let featConfig = sherpaOnnxFeatureConfig(
       sampleRate: 16000,
       featureDim: 80)
-
+
     var config = sherpaOnnxOnlineRecognizerConfig(
       featConfig: featConfig,
       modelConfig: modelConfig,
       enableEndpoint: true,
       rule1MinTrailingSilence: 2.4,
       rule2MinTrailingSilence: 0.8,
-      rule3MinUtteranceLength: 30
+      rule3MinUtteranceLength: 30,
+      decodingMethod: "greedy_search",
+      maxActivePaths: 4
     )
     recognizer = SherpaOnnxRecognizer(config: &config)
   }
-
+
   func initRecorder() {
     print("init recorder")
     audioEngine = AVAudioEngine()
@@ -112,9 +114,9 @@ class ViewController: UIViewController {
       commonFormat: .pcmFormatFloat32,
       sampleRate: 16000, channels: 1,
       interleaved: false)!
-
+
     let converter = AVAudioConverter(from: inputFormat!, to: outputFormat)!
-
+
     inputNode!.installTap(
       onBus: bus,
       bufferSize: 1024,
@@ -122,34 +124,34 @@ class ViewController: UIViewController {
     ) {
       (buffer: AVAudioPCMBuffer, when: AVAudioTime) in
       var newBufferAvailable = true
-
+
       let inputCallback: AVAudioConverterInputBlock = {
         inNumPackets, outStatus in
         if newBufferAvailable {
           outStatus.pointee = .haveData
           newBufferAvailable = false
-
+
           return buffer
         } else {
          outStatus.pointee = .noDataNow
          return nil
        }
      }
-
+
       let convertedBuffer = AVAudioPCMBuffer(
         pcmFormat: outputFormat,
         frameCapacity:
           AVAudioFrameCount(outputFormat.sampleRate)
           * buffer.frameLength
           / AVAudioFrameCount(buffer.format.sampleRate))!
-
+
       var error: NSError?
       let _ = converter.convert(
         to: convertedBuffer,
         error: &error, withInputFrom: inputCallback)
-
+
       // TODO(fangjun): Handle status != haveData
-
+
       let array = convertedBuffer.array()
       if !array.isEmpty {
         self.recognizer.acceptWaveform(samples: array)
@@ -158,13 +160,13 @@ class ViewController: UIViewController {
       }
       let isEndpoint = self.recognizer.isEndpoint()
       let text = self.recognizer.getResult().text
-
+
       if !text.isEmpty && self.lastSentence != text {
         self.lastSentence = text
         self.updateLabel()
         print(text)
       }
-
+
       if isEndpoint {
         if !text.isEmpty {
           let tmp = self.lastSentence
@@ -175,13 +177,13 @@ class ViewController: UIViewController {
           }
         }
       }
-
+
     }
-
+
   func startRecorder() {
     lastSentence = ""
     sentences = []
-
+
     do {
       try self.audioEngine?.start()
     } catch let error as NSError {
@@ -189,7 +191,7 @@ class ViewController: UIViewController {
     }
     print("started")
   }
-
+
   func stopRecorder() {
     audioEngine?.stop()
     print("stopped")
@@ -76,7 +76,7 @@ SherpaOnnxOnlineStream *CreateOnlineStream(
 
 void DestoryOnlineStream(SherpaOnnxOnlineStream *stream) { delete stream; }
 
-void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate,
+void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
                     const float *samples, int32_t n) {
   stream->impl->AcceptWaveform(sample_rate, samples, n);
 }
@@ -120,7 +120,7 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream);
 /// @param samples A pointer to a 1-D array containing audio samples.
 ///                The range of samples has to be normalized to [-1, 1].
 /// @param n Number of elements in the samples array.
-void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate,
+void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
                     const float *samples, int32_t n);
 
 /// Return 1 if there are enough number of feature frames for decoding.
@@ -48,7 +48,7 @@ class FeatureExtractor::Impl {
     fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
   }
 
-  void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) {
+  void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
     std::lock_guard<std::mutex> lock(mutex_);
     fbank_->AcceptWaveform(sampling_rate, waveform, n);
   }
@@ -107,7 +107,7 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
 
 FeatureExtractor::~FeatureExtractor() = default;
 
-void FeatureExtractor::AcceptWaveform(float sampling_rate,
+void FeatureExtractor::AcceptWaveform(int32_t sampling_rate,
                                       const float *waveform, int32_t n) {
   impl_->AcceptWaveform(sampling_rate, waveform, n);
 }
@@ -14,7 +14,7 @@
 namespace sherpa_onnx {
 
 struct FeatureExtractorConfig {
-  float sampling_rate = 16000;
+  int32_t sampling_rate = 16000;
   int32_t feature_dim = 80;
   int32_t max_feature_vectors = -1;
 
@@ -34,7 +34,7 @@ class FeatureExtractor {
      @param waveform Pointer to a 1-D array of size n
      @param n Number of entries in waveform
    */
-  void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
+  void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
 
   /**
    * InputFinished() tells the class you won't be providing any
@@ -112,7 +112,7 @@ for a list of pre-trained models to download.
 
   param.suggestedLatency = info->defaultLowInputLatency;
   param.hostApiSpecificStreamInfo = nullptr;
-  const float sample_rate = 16000;
+  float sample_rate = 16000;
 
   PaStream *stream;
   PaError err =
@@ -61,7 +61,7 @@ for a list of pre-trained models to download.
 
   sherpa_onnx::OnlineRecognizer recognizer(config);
 
-  float expected_sampling_rate = config.feat_config.sampling_rate;
+  int32_t expected_sampling_rate = config.feat_config.sampling_rate;
 
   bool is_ok = false;
   std::vector<float> samples =
@@ -72,7 +72,7 @@ for a list of pre-trained models to download.
     return -1;
   }
 
-  float duration = samples.size() / expected_sampling_rate;
+  float duration = samples.size() / static_cast<float>(expected_sampling_rate);
 
   fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
   fprintf(stderr, "wav duration (s): %.3f\n", duration);
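The cast matters now: expected_sampling_rate is an int32_t, so samples.size() / expected_sampling_rate would perform integer division and truncate the duration to whole seconds; static_cast<float>(expected_sampling_rate) keeps the fractional part.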
@@ -40,19 +40,18 @@ class SherpaOnnx {
             mgr,
 #endif
             config),
-        stream_(recognizer_.CreateStream()),
-        tail_padding_(16000 * 0.32, 0) {
+        stream_(recognizer_.CreateStream()) {
   }
 
-  void DecodeSamples(float sample_rate, const float *samples, int32_t n) const {
+  void AcceptWaveform(int32_t sample_rate, const float *samples,
+                      int32_t n) const {
     stream_->AcceptWaveform(sample_rate, samples, n);
-    Decode();
   }
 
   void InputFinished() const {
-    stream_->AcceptWaveform(16000, tail_padding_.data(), tail_padding_.size());
+    std::vector<float> tail_padding(16000 * 0.32, 0);
+    stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size());
     stream_->InputFinished();
-    Decode();
   }
 
   const std::string GetText() const {
@@ -62,19 +61,15 @@ class SherpaOnnx {
 
   bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }
 
+  bool IsReady() const { return recognizer_.IsReady(stream_.get()); }
+
   void Reset() const { return recognizer_.Reset(stream_.get()); }
 
- private:
-  void Decode() const {
-    while (recognizer_.IsReady(stream_.get())) {
-      recognizer_.DecodeStream(stream_.get());
-    }
-  }
+  void Decode() const { recognizer_.DecodeStream(stream_.get()); }
 
  private:
   sherpa_onnx::OnlineRecognizer recognizer_;
   std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
-  std::vector<float> tail_padding_;
 };
 
 static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
@@ -86,14 +81,24 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
   // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
   // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
 
+  //---------- decoding ----------
+  fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
+  jstring s = (jstring)env->GetObjectField(config, fid);
+  const char *p = env->GetStringUTFChars(s, nullptr);
+  ans.decoding_method = p;
+  env->ReleaseStringUTFChars(s, p);
+
+  fid = env->GetFieldID(cls, "maxActivePaths", "I");
+  ans.max_active_paths = env->GetIntField(config, fid);
+
   //---------- feat config ----------
   fid = env->GetFieldID(cls, "featConfig",
                         "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
   jobject feat_config = env->GetObjectField(config, fid);
   jclass feat_config_cls = env->GetObjectClass(feat_config);
 
-  fid = env->GetFieldID(feat_config_cls, "sampleRate", "F");
-  ans.feat_config.sampling_rate = env->GetFloatField(feat_config, fid);
+  fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
+  ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
 
   fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
   ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
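The JNI signature strings must stay in sync with the Kotlin field types: "Ljava/lang/String;" matches decodingMethod: String, and sampleRate is now read with "I" and GetIntField instead of "F" and GetFloatField because FeatureConfig.sampleRate changed from Float to Int. For reference, the corresponding Kotlin declaration from earlier in this commit:

    data class FeatureConfig(
        var sampleRate: Int = 16000, // JNI: GetFieldID(..., "sampleRate", "I") + GetIntField
        var featureDim: Int = 80,    // JNI: GetFieldID(..., "featureDim", "I") + GetIntField
    )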
@@ -153,8 +158,8 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
   jclass model_config_cls = env->GetObjectClass(model_config);
 
   fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
-  jstring s = (jstring)env->GetObjectField(model_config, fid);
-  const char *p = env->GetStringUTFChars(s, nullptr);
+  s = (jstring)env->GetObjectField(model_config, fid);
+  p = env->GetStringUTFChars(s, nullptr);
   ans.model_config.encoder_filename = p;
   env->ReleaseStringUTFChars(s, p);
 
@@ -198,6 +203,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
 #endif
 
   auto config = sherpa_onnx::GetConfig(env, _config);
+  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
   auto model = new sherpa_onnx::SherpaOnnx(
 #if __ANDROID_API__ >= 9
       mgr,
@@ -220,6 +226,13 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
   model->Reset();
 }
 
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isReady(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
+  return model->IsReady();
+}
+
 SHERPA_ONNX_EXTERN_C
 JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
     JNIEnv *env, jobject /*obj*/, jlong ptr) {
@@ -228,15 +241,22 @@ JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
 }
 
 SHERPA_ONNX_EXTERN_C
-JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decodeSamples(
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decode(
+    JNIEnv *env, jobject /*obj*/, jlong ptr) {
+  auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
+  model->Decode();
+}
+
+SHERPA_ONNX_EXTERN_C
+JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_acceptWaveform(
     JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
-    jfloat sample_rate) {
+    jint sample_rate) {
   auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
 
   jfloat *p = env->GetFloatArrayElements(samples, nullptr);
   jsize n = env->GetArrayLength(samples);
 
-  model->DecodeSamples(sample_rate, p, n);
+  model->AcceptWaveform(sample_rate, p, n);
 
   env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
 }
@@ -62,11 +62,15 @@ func sherpaOnnxOnlineRecognizerConfig(
   enableEndpoint: Bool = false,
   rule1MinTrailingSilence: Float = 2.4,
   rule2MinTrailingSilence: Float = 1.2,
-  rule3MinUtteranceLength: Float = 30
+  rule3MinUtteranceLength: Float = 30,
+  decodingMethod: String = "greedy_search",
+  maxActivePaths: Int = 4
 ) -> SherpaOnnxOnlineRecognizerConfig{
   return SherpaOnnxOnlineRecognizerConfig(
     feat_config: featConfig,
     model_config: modelConfig,
+    decoding_method: toCPointer(decodingMethod),
+    max_active_paths: Int32(maxActivePaths),
     enable_endpoint: enableEndpoint ? 1 : 0,
     rule1_min_trailing_silence: rule1MinTrailingSilence,
     rule2_min_trailing_silence: rule2MinTrailingSilence,
@@ -128,12 +132,12 @@ class SherpaOnnxRecognizer {
   /// Decode wave samples.
   ///
   /// - Parameters:
-  ///   - samples: Audio samples normalzed to the range [-1, 1]
+  ///   - samples: Audio samples normalized to the range [-1, 1]
   ///   - sampleRate: Sample rate of the input audio samples. Must match
   ///                 the one expected by the model. It must be 16000 for
   ///                 models from icefall.
-  func acceptWaveform(samples: [Float], sampleRate: Float = 16000) {
-    AcceptWaveform(stream, sampleRate, samples, Int32(samples.count))
+  func acceptWaveform(samples: [Float], sampleRate: Int = 16000) {
+    AcceptWaveform(stream, Int32(sampleRate), samples, Int32(samples.count))
   }
 
   func isReady() -> Bool {
@@ -32,7 +32,9 @@ func run() {
   var config = sherpaOnnxOnlineRecognizerConfig(
     featConfig: featConfig,
     modelConfig: modelConfig,
-    enableEndpoint: false
+    enableEndpoint: false,
+    decodingMethod: "modified_beam_search",
+    maxActivePaths: 4
   )
 