exposing online punctuation model support in node-addon-api (#2609)

* exposing online punctuation model support in node-addon-api

* renaming nodejs-addon-examples/test_punctuation.js to test_offline_punctuation.js

* adding test_online_punctuation to nodejs-addon-examples and updating CI to run test_offline_punctuation and test_online_punctuation
This commit is contained in:
colourmebrad 2025-09-19 11:29:55 -04:00 committed by GitHub
parent 26aa2fa932
commit ef5c23e6c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 186 additions and 6 deletions

View File

@ -264,8 +264,16 @@ if [[ $arch != "ia32" && $platform != "win32" && $node_version != 21 ]]; then
tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
node ./test_punctuation.js node ./test_offline_punctuation.js
rm -rf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12 rm -rf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
# Download the online punctuation model, run the Node.js example against it,
# then clean up so later steps start from a clean working directory.
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
tar xvf sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
rm sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
node ./test_online_punctuation.js
# The archive is already gone at this point; remove the extracted model
# directory (the one test_online_punctuation.js reads from).
rm -rf sherpa-onnx-online-punct-en-2024-08-06
fi fi
echo "----------audio tagging----------" echo "----------audio tagging----------"

View File

@ -7,6 +7,7 @@
#include "napi.h" // NOLINT #include "napi.h" // NOLINT
#include "sherpa-onnx/c-api/c-api.h" #include "sherpa-onnx/c-api/c-api.h"
// offline punctuation models
static SherpaOnnxOfflinePunctuationModelConfig GetOfflinePunctuationModelConfig( static SherpaOnnxOfflinePunctuationModelConfig GetOfflinePunctuationModelConfig(
Napi::Object obj) { Napi::Object obj) {
SherpaOnnxOfflinePunctuationModelConfig c; SherpaOnnxOfflinePunctuationModelConfig c;
@ -121,10 +122,136 @@ static Napi::String OfflinePunctuationAddPunctWraper(
return ans; return ans;
} }
// online punctuation models
// Builds a SherpaOnnxOnlinePunctuationModelConfig from the "model" property
// of the JS config object `obj`.
//
// Returns a zero-filled config when `obj.model` is missing or is not an
// object. Recognized keys of `obj.model`: cnnBilstm, bpeVocab, numThreads,
// debug (boolean or number) and provider.
static SherpaOnnxOnlinePunctuationModelConfig GetOnlinePunctuationModelConfig(
    Napi::Object obj) {
  SherpaOnnxOnlinePunctuationModelConfig c;
  memset(&c, 0, sizeof(c));

  const bool has_model = obj.Has("model") && obj.Get("model").IsObject();
  if (!has_model) {
    return c;
  }

  // NOTE(review): the locals are deliberately named `c` and `o`; the
  // SHERPA_ONNX_ASSIGN_ATTR_* macros only receive field names and appear to
  // refer to these two variables implicitly -- confirm before renaming.
  Napi::Object o = obj.Get("model").As<Napi::Object>();

  SHERPA_ONNX_ASSIGN_ATTR_STR(cnn_bilstm, cnnBilstm);
  SHERPA_ONNX_ASSIGN_ATTR_STR(bpe_vocab, bpeVocab);
  SHERPA_ONNX_ASSIGN_ATTR_INT32(num_threads, numThreads);

  // `debug` is accepted either as a JS boolean or as a JS number; any other
  // type leaves the zero-initialized value untouched.
  if (o.Has("debug")) {
    Napi::Value debug = o.Get("debug");
    if (debug.IsBoolean()) {
      c.debug = debug.As<Napi::Boolean>().Value();
    } else if (debug.IsNumber()) {
      c.debug = debug.As<Napi::Number>().Int32Value();
    }
  }

  SHERPA_ONNX_ASSIGN_ATTR_STR(provider, provider);

  return c;
}
// JS entry point: createOnlinePunctuation(config).
//
// Expects exactly one argument, an object whose "model" property is read by
// GetOnlinePunctuationModelConfig(). On a wrong argument count, a non-object
// argument, or a null result from SherpaOnnxCreateOnlinePunctuation(), a JS
// TypeError is thrown and an empty External is returned.
//
// On success, returns an External wrapping the created handle; its finalizer
// calls SherpaOnnxDestroyOnlinePunctuation() on the handle.
static Napi::External<SherpaOnnxOnlinePunctuation>
CreateOnlinePunctuationWrapper(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();

  const size_t num_args = info.Length();
  if (num_args != 1) {
    std::ostringstream os;
    os << "Expect only 1 argument. Given: " << num_args;

    Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();

    return {};
  }

  if (!info[0].IsObject()) {
    Napi::TypeError::New(env, "You should pass an object as the only argument.")
        .ThrowAsJavaScriptException();

    return {};
  }

  SherpaOnnxOnlinePunctuationConfig config;
  memset(&config, 0, sizeof(config));
  config.model = GetOnlinePunctuationModelConfig(info[0].As<Napi::Object>());

  const SherpaOnnxOnlinePunctuation *handle =
      SherpaOnnxCreateOnlinePunctuation(&config);

  // Release the string fields of the model config now that the create call
  // has returned, regardless of whether it succeeded.
  SHERPA_ONNX_DELETE_C_STR(config.model.cnn_bilstm);
  SHERPA_ONNX_DELETE_C_STR(config.model.bpe_vocab);
  SHERPA_ONNX_DELETE_C_STR(config.model.provider);

  if (handle == nullptr) {
    Napi::TypeError::New(env, "Please check your config!")
        .ThrowAsJavaScriptException();

    return {};
  }

  return Napi::External<SherpaOnnxOnlinePunctuation>::New(
      env, const_cast<SherpaOnnxOnlinePunctuation *>(handle),
      [](Napi::Env /*env*/, SherpaOnnxOnlinePunctuation *ptr) {
        SherpaOnnxDestroyOnlinePunctuation(ptr);
      });
}
// JS entry point: onlinePunctuationAddPunct(handle, text).
//
// Arguments:
//   info[0] - External holding a SherpaOnnxOnlinePunctuation pointer, as
//             returned by createOnlinePunctuation().
//   info[1] - the input text, a JS string.
//
// Returns the string produced by SherpaOnnxOnlinePunctuationAddPunct().
// Throws a JS TypeError (and returns an empty value) when the argument count
// or types are wrong, and a JS Error when the C API returns a null pointer.
//
// NOTE(review): only IsExternal() is checked; an External created for a
// different native type is not detected here.
static Napi::String OnlinePunctuationAddPunctWraper(
    const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();

  if (info.Length() != 2) {
    std::ostringstream os;
    os << "Expect only 2 arguments. Given: " << info.Length();

    Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();

    return {};
  }

  if (!info[0].IsExternal()) {
    Napi::TypeError::New(
        env,
        "You should pass an online punctuation pointer as the first argument")
        .ThrowAsJavaScriptException();

    return {};
  }

  if (!info[1].IsString()) {
    Napi::TypeError::New(env, "You should pass a string as the second argument")
        .ThrowAsJavaScriptException();

    return {};
  }

  const SherpaOnnxOnlinePunctuation *punct =
      info[0].As<Napi::External<SherpaOnnxOnlinePunctuation>>().Data();

  Napi::String js_text = info[1].As<Napi::String>();
  std::string text = js_text.Utf8Value();

  const char *punct_text =
      SherpaOnnxOnlinePunctuationAddPunct(punct, text.c_str());

  // Do not hand a null pointer to Napi::String::New(); report the failure
  // with an explicit message instead.
  if (punct_text == nullptr) {
    Napi::Error::New(env, "Failed to add punctuation to the given text")
        .ThrowAsJavaScriptException();

    return {};
  }

  // Copy into a JS string first, then give the buffer back to the C API.
  Napi::String ans = Napi::String::New(env, punct_text);
  SherpaOnnxOnlinePunctuationFreeText(punct_text);

  return ans;
}
// exports
void InitPunctuation(Napi::Env env, Napi::Object exports) { void InitPunctuation(Napi::Env env, Napi::Object exports) {
exports.Set(Napi::String::New(env, "createOfflinePunctuation"), exports.Set(Napi::String::New(env, "createOfflinePunctuation"),
Napi::Function::New(env, CreateOfflinePunctuationWrapper)); Napi::Function::New(env, CreateOfflinePunctuationWrapper));
exports.Set(Napi::String::New(env, "offlinePunctuationAddPunct"), exports.Set(Napi::String::New(env, "offlinePunctuationAddPunct"),
Napi::Function::New(env, OfflinePunctuationAddPunctWraper)); Napi::Function::New(env, OfflinePunctuationAddPunctWraper));
exports.Set(Napi::String::New(env, "createOnlinePunctuation"),
Napi::Function::New(env, CreateOnlinePunctuationWrapper));
exports.Set(Napi::String::New(env, "onlinePunctuationAddPunct"),
Napi::Function::New(env, OnlinePunctuationAddPunctWraper));
} }

View File

@ -14,7 +14,7 @@ function createPunctuation() {
provider: 'cpu', provider: 'cpu',
}, },
}; };
return new sherpa_onnx.Punctuation(config); return new sherpa_onnx.OfflinePunctuation(config);
} }
const punct = createPunctuation(); const punct = createPunctuation();

View File

@ -0,0 +1,33 @@
// Copyright (c) 2023-2024 Xiaomi Corporation (authors: Fangjun Kuang)
const sherpa_onnx = require('sherpa-onnx-node');
// Please download test files from
// https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models
/**
 * Creates an OnlinePunctuation instance configured with the model files
 * under ./sherpa-onnx-online-punct-en-2024-08-06.
 *
 * @returns {object} A new `sherpa_onnx.OnlinePunctuation`.
 */
function createPunctuation() {
  const model = {
    cnnBilstm: './sherpa-onnx-online-punct-en-2024-08-06/model.onnx',
    bpeVocab: './sherpa-onnx-online-punct-en-2024-08-06/bpe.vocab',
    debug: true,
    numThreads: 1,
    provider: 'cpu',
  };

  return new sherpa_onnx.OnlinePunctuation({model});
}
// Demo: run a few unpunctuated sentences through the punctuation object and
// print each input next to the text returned by addPunct().
const punct = createPunctuation();

const sentences = [
  'How are you i am fine thank you',
  'The african blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry',
];

const separator = '---';

console.log(separator);
sentences.forEach((sentence) => {
  const punctuated = punct.addPunct(sentence);
  console.log(`Input: ${sentence}`);
  console.log(`Output: ${punctuated}`);
  console.log(separator);
});

View File

@ -1,6 +1,6 @@
const addon = require('./addon.js'); const addon = require('./addon.js');
class Punctuation { class OfflinePunctuation {
constructor(config) { constructor(config) {
this.handle = addon.createOfflinePunctuation(config); this.handle = addon.createOfflinePunctuation(config);
this.config = config; this.config = config;
@ -10,6 +10,17 @@ class Punctuation {
} }
} }
module.exports = { class OnlinePunctuation {
Punctuation, constructor(config) {
this.handle = addon.createOnlinePunctuation(config);
this.config = config;
}
addPunct(text) {
return addon.onlinePunctuationAddPunct(this.handle, text);
}
}
module.exports = {
OfflinePunctuation,
OnlinePunctuation,
} }

View File

@ -24,7 +24,8 @@ module.exports = {
SpeakerEmbeddingExtractor: sid.SpeakerEmbeddingExtractor, SpeakerEmbeddingExtractor: sid.SpeakerEmbeddingExtractor,
SpeakerEmbeddingManager: sid.SpeakerEmbeddingManager, SpeakerEmbeddingManager: sid.SpeakerEmbeddingManager,
AudioTagging: at.AudioTagging, AudioTagging: at.AudioTagging,
Punctuation: punct.Punctuation, OfflinePunctuation: punct.OfflinePunctuation,
OnlinePunctuation: punct.OnlinePunctuation,
KeywordSpotter: kws.KeywordSpotter, KeywordSpotter: kws.KeywordSpotter,
OfflineSpeakerDiarization: sd.OfflineSpeakerDiarization, OfflineSpeakerDiarization: sd.OfflineSpeakerDiarization,
OfflineSpeechDenoiser: speech_denoiser.OfflineSpeechDenoiser, OfflineSpeechDenoiser: speech_denoiser.OfflineSpeechDenoiser,