#!/usr/bin/env python3
# Copyright      2022-2023  Xiaomi Corp.
"""
A server for non-streaming speech recognition. Non-streaming means you send
all the content of the audio at once for recognition.

It supports multiple clients sending requests at the same time.

Usage:
    ./non_streaming_server.py --help

Please refer to

https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html

for pre-trained models to download.

Usage examples:

(1) Use a non-streaming transducer model

cd /path/to/sherpa-onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zipformer-en-2023-06-26.tar.bz2
tar xvf sherpa-onnx-zipformer-en-2023-06-26.tar.bz2
rm sherpa-onnx-zipformer-en-2023-06-26.tar.bz2

python3 ./python-api-examples/non_streaming_server.py \
  --encoder ./sherpa-onnx-zipformer-en-2023-06-26/encoder-epoch-99-avg-1.onnx \
  --decoder ./sherpa-onnx-zipformer-en-2023-06-26/decoder-epoch-99-avg-1.onnx \
  --joiner ./sherpa-onnx-zipformer-en-2023-06-26/joiner-epoch-99-avg-1.onnx \
  --tokens ./sherpa-onnx-zipformer-en-2023-06-26/tokens.txt \
  --port 6006

(2) Use a non-streaming paraformer

cd /path/to/sherpa-onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
rm sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2

python3 ./python-api-examples/non_streaming_server.py \
  --paraformer ./sherpa-onnx-paraformer-zh-2023-09-14/model.int8.onnx \
  --tokens ./sherpa-onnx-paraformer-zh-2023-09-14/tokens.txt

(3) Use a non-streaming CTC model from NeMo

cd /path/to/sherpa-onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-nemo-ctc-en-conformer-medium.tar.bz2
tar xvf sherpa-onnx-nemo-ctc-en-conformer-medium.tar.bz2
rm sherpa-onnx-nemo-ctc-en-conformer-medium.tar.bz2

python3 ./python-api-examples/non_streaming_server.py \
  --nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
  --tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt

(4) Use a non-streaming CTC model from WeNet

cd /path/to/sherpa-onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-zh-wenet-wenetspeech.tar.bz2
tar xvf sherpa-onnx-zh-wenet-wenetspeech.tar.bz2
rm sherpa-onnx-zh-wenet-wenetspeech.tar.bz2

python3 ./python-api-examples/non_streaming_server.py \
  --wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
  --tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt

(5) Use a Moonshine model

cd /path/to/sherpa-onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
tar xvf sherpa-onnx-moonshine-tiny-en-int8.tar.bz2
rm sherpa-onnx-moonshine-tiny-en-int8.tar.bz2

python3 ./python-api-examples/non_streaming_server.py \
  --moonshine-preprocessor=./sherpa-onnx-moonshine-tiny-en-int8/preprocess.onnx \
  --moonshine-encoder=./sherpa-onnx-moonshine-tiny-en-int8/encode.int8.onnx \
  --moonshine-uncached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/uncached_decode.int8.onnx \
  --moonshine-cached-decoder=./sherpa-onnx-moonshine-tiny-en-int8/cached_decode.int8.onnx \
  --tokens=./sherpa-onnx-moonshine-tiny-en-int8/tokens.txt

(6) Use a Whisper model

cd /path/to/sherpa-onnx
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.en.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.en.tar.bz2
rm sherpa-onnx-whisper-tiny.en.tar.bz2

python3 ./python-api-examples/non_streaming_server.py \
  --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \
  --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
  --tokens=./sherpa-onnx-whisper-tiny.en/tiny.en-tokens.txt

(7) Use a tdnn model of the yesno recipe from icefall

cd /path/to/sherpa-onnx

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-tdnn-yesno.tar.bz2
tar xvf sherpa-onnx-tdnn-yesno.tar.bz2
rm sherpa-onnx-tdnn-yesno.tar.bz2

python3 ./python-api-examples/non_streaming_server.py \
  --sample-rate=8000 \
  --feat-dim=23 \
  --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
  --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt

(8) Use a non-streaming SenseVoice model

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2

python3 ./python-api-examples/non_streaming_server.py \
  --sense-voice=./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx \
  --tokens=./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt

(9) Use a non-streaming TeleSpeech CTC model

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
tar xvf sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2
rm sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04.tar.bz2

python3 ./python-api-examples/non_streaming_server.py \
  --telespeech-ctc=./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/model.int8.onnx \
  --tokens=./sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04/tokens.txt

----

To use a certificate so that you can use https, please run:

python3 ./python-api-examples/non_streaming_server.py \
  --whisper-encoder=./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.onnx \
  --whisper-decoder=./sherpa-onnx-whisper-tiny.en/tiny.en-decoder.onnx \
  --certificate=/path/to/your/cert.pem

If you don't have a certificate, please run:

cd ./python-api-examples/web
./generate-certificate.py

It will generate 3 files, one of which is the required `cert.pem`.
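
Once the server is running, you can test it with one of the websocket
clients in this directory, e.g. (a sketch; the exact flags may differ, so
please check the client's --help):

python3 ./python-api-examples/offline-websocket-client-decode-files-sequential.py \
  --server-addr localhost \
  --server-port 6006 \
  /path/to/foo.wav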
"""  # noqa

import argparse
import asyncio
import http
import logging
import socket
import ssl
import sys
import warnings
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import sherpa_onnx

import websockets

from http_server import HttpServer


def setup_logger(
    log_filename: str,
    log_level: str = "info",
    use_console: bool = True,
) -> None:
    """Set up the logger.

    Args:
      log_filename:
        The filename to save the log.
      log_level:
        The log level to use, e.g., "debug", "info", "warning", "error",
        "critical"
      use_console:
        True to also print logs to the console.
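
    Example (illustrative):
      setup_logger("log/server") writes log records to a file named
      log/server-<date>-<time>.txt and also echoes them to the console.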
    """
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    log_filename = f"{log_filename}-{date_time}.txt"

    Path(log_filename).parent.mkdir(parents=True, exist_ok=True)

    level = logging.ERROR
    if log_level == "debug":
        level = logging.DEBUG
    elif log_level == "info":
        level = logging.INFO
    elif log_level == "warning":
        level = logging.WARNING
    elif log_level == "critical":
        level = logging.CRITICAL

    logging.basicConfig(
        filename=log_filename,
        format=formatter,
        level=level,
        filemode="w",
    )
    if use_console:
        console = logging.StreamHandler()
        console.setLevel(level)
        console.setFormatter(logging.Formatter(formatter))
        logging.getLogger("").addHandler(console)


def add_transducer_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--encoder",
        default="",
        type=str,
        help="Path to the transducer encoder model",
    )

    parser.add_argument(
        "--decoder",
        default="",
        type=str,
        help="Path to the transducer decoder model",
    )

    parser.add_argument(
        "--joiner",
        default="",
        type=str,
        help="Path to the transducer joiner model",
    )


def add_paraformer_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--paraformer",
        default="",
        type=str,
        help="Path to the model.onnx from Paraformer",
    )


def add_sense_voice_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--sense-voice",
        default="",
        type=str,
        help="Path to the model.onnx from SenseVoice",
    )


def add_nemo_ctc_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--nemo-ctc",
        default="",
        type=str,
        help="Path to the model.onnx from NeMo CTC",
    )


def add_telespeech_ctc_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--telespeech-ctc",
        default="",
        type=str,
        help="Path to the model.onnx from TeleSpeech CTC",
    )


def add_wenet_ctc_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--wenet-ctc",
        default="",
        type=str,
        help="Path to the model.onnx from WeNet CTC",
    )


def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--tdnn-model",
        default="",
        type=str,
        help="Path to the model.onnx for the tdnn model of the yesno recipe",
    )


def add_moonshine_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--moonshine-preprocessor",
        default="",
        type=str,
        help="Path to the Moonshine preprocessor model",
    )

    parser.add_argument(
        "--moonshine-encoder",
        default="",
        type=str,
        help="Path to the Moonshine encoder model",
    )

    parser.add_argument(
        "--moonshine-uncached-decoder",
        default="",
        type=str,
        help="Path to the Moonshine uncached decoder model",
    )

    parser.add_argument(
        "--moonshine-cached-decoder",
        default="",
        type=str,
        help="Path to the Moonshine cached decoder model",
    )


def add_whisper_model_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--whisper-encoder",
        default="",
        type=str,
        help="Path to the Whisper encoder model",
    )

    parser.add_argument(
        "--whisper-decoder",
        default="",
        type=str,
        help="Path to the Whisper decoder model",
    )

    parser.add_argument(
        "--whisper-language",
        default="",
        type=str,
        help="""It specifies the spoken language in the input audio file.
        Example values: en, fr, de, zh, ja.
        Available languages for multilingual models can be found at
        https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
        If not specified, we infer the language from the input audio file.
        """,
    )

    parser.add_argument(
        "--whisper-task",
        default="transcribe",
        choices=["transcribe", "translate"],
        type=str,
        help="""For multilingual models, if you specify translate, the output
        will be in English.
        """,
    )

    parser.add_argument(
        "--whisper-tail-paddings",
        default=-1,
        type=int,
        help="""Number of tail padding frames.
        We have removed the 30-second constraint from Whisper, so you need to
        choose the number of tail padding frames yourself.
        Use -1 for a default amount of tail padding.
        """,
    )


def add_model_args(parser: argparse.ArgumentParser):
    add_transducer_model_args(parser)
    add_paraformer_model_args(parser)
    add_sense_voice_model_args(parser)
    add_nemo_ctc_model_args(parser)
    add_wenet_ctc_model_args(parser)
    add_telespeech_ctc_model_args(parser)
    add_tdnn_ctc_model_args(parser)
    add_whisper_model_args(parser)
    add_moonshine_model_args(parser)

    parser.add_argument(
        "--tokens",
        type=str,
        help="Path to tokens.txt",
    )

    parser.add_argument(
        "--num-threads",
        type=int,
        default=2,
        help="Number of threads to run the neural network model",
    )

    parser.add_argument(
        "--provider",
        type=str,
        default="cpu",
        help="Valid values: cpu, cuda, coreml",
    )


def add_feature_config_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="Sample rate of the data used to train the model.",
    )

    parser.add_argument(
        "--feat-dim",
        type=int,
        default=80,
        help="Feature dimension of the model",
    )


def add_decoding_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Decoding method to use. Currently supported methods are:
        - greedy_search
        - modified_beam_search (for transducer models only)
        """,
    )

    add_modified_beam_search_args(parser)


def add_modified_beam_search_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--max-active-paths",
        type=int,
        default=4,
        help="""Used only when --decoding-method is modified_beam_search.
        It specifies the number of active paths to keep during decoding.
        """,
    )


def add_hotwords_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--hotwords-file",
        type=str,
        default="",
        help="""
        The file containing hotwords, one word/phrase per line, with the
        BPE pieces/CJK characters of each phrase separated by spaces.
        For example:

        ▁HE LL O ▁WORLD
        你 好 世 界
        """,
    )

    parser.add_argument(
        "--hotwords-score",
        type=float,
        default=1.5,
        help="""
        The hotword score of each token for biasing words/phrases. Used only
        if --hotwords-file is given.
        """,
    )


def add_blank_penalty_args(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--blank-penalty",
        type=float,
        default=0.0,
        help="""
        The penalty applied to the blank symbol during decoding.
        Note: It is a positive value subtracted from the blank logit, i.e.,
        `logits[:, 0] -= blank_penalty` (assuming logits.shape is
        [batch_size, vocab] and the blank id is 0).
        """,
    )


def check_args(args):
    if not Path(args.tokens).is_file():
        raise ValueError(f"{args.tokens} does not exist")

    if args.decoding_method not in (
        "greedy_search",
        "modified_beam_search",
    ):
        raise ValueError(f"Unsupported decoding method {args.decoding_method}")

    if args.decoding_method == "modified_beam_search":
        assert args.max_active_paths > 0, args.max_active_paths
        assert Path(args.encoder).is_file(), args.encoder
        assert Path(args.decoder).is_file(), args.decoder
        assert Path(args.joiner).is_file(), args.joiner

    if args.hotwords_file != "":
        assert args.decoding_method == "modified_beam_search", args.decoding_method
        assert Path(args.hotwords_file).is_file(), args.hotwords_file


def get_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    add_model_args(parser)
    add_feature_config_args(parser)
    add_decoding_args(parser)
    add_hotwords_args(parser)
    add_blank_penalty_args(parser)

    parser.add_argument(
        "--port",
        type=int,
        default=6006,
        help="The server will listen on this port",
    )

    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=3,
        help="""Max batch size for computation. Note that if there are not
        enough requests in the queue, it waits for up to max_wait_ms. After
        that, even if there are not enough requests, it still sends the
        available requests in the queue for computation.
        """,
    )

    parser.add_argument(
        "--max-wait-ms",
        type=float,
        default=5,
        help="""Max time in milliseconds to wait to build batches for
        inference. If there are not enough requests in the feature queue to
        build a batch of max_batch_size, it waits up to this time before
        fetching the available requests for computation.
        """,
    )

    parser.add_argument(
        "--nn-pool-size",
        type=int,
        default=1,
        help="Number of threads for NN computation and decoding.",
    )

    parser.add_argument(
        "--max-message-size",
        type=int,
        default=(1 << 20),
        help="""Max message size in bytes.
        The max size per message cannot exceed this limit.
        """,
    )

    parser.add_argument(
        "--max-queue-size",
        type=int,
        default=32,
        help="Max number of messages in the queue for each connection.",
    )

    parser.add_argument(
        "--max-active-connections",
        type=int,
        default=200,
        help="""Maximum number of active connections. The server refuses
        to accept new connections once the current number of active
        connections equals this limit.
        """,
    )

    parser.add_argument(
        "--certificate",
        type=str,
        help="""Path to the X.509 certificate. You need it only if you want
        to use a secure websocket connection, i.e., use wss:// instead of
        ws://. You can use ./web/generate-certificate.py to generate the
        certificate `cert.pem`.
        Note: ./web/generate-certificate.py generates three files, but you
        only need to pass the generated cert.pem to this option.
        """,
    )

    parser.add_argument(
        "--doc-root",
        type=str,
        default="./python-api-examples/web",
        help="Path to the web root",
    )

    return parser.parse_args()


class NonStreamingServer:
    def __init__(
        self,
        recognizer: sherpa_onnx.OfflineRecognizer,
        max_batch_size: int,
        max_wait_ms: float,
        nn_pool_size: int,
        max_message_size: int,
        max_queue_size: int,
        max_active_connections: int,
        doc_root: str,
        certificate: Optional[str] = None,
    ):
        """
        Args:
          recognizer:
            An instance of sherpa_onnx.OfflineRecognizer.
          max_batch_size:
            Max batch size for inference.
          max_wait_ms:
            Max wait time in milliseconds in order to build a batch of
            `max_batch_size`.
          nn_pool_size:
            Number of threads for the thread pool that is used for NN
            computation and decoding.
          max_message_size:
            Max size in bytes per message.
          max_queue_size:
            Max number of messages in the queue for each connection.
          max_active_connections:
            Max number of active connections. Once the number of active
            clients equals this limit, the server refuses to accept new
            connections.
          doc_root:
            Path to the directory where files like index.html for the HTTP
            server are located.
          certificate:
            Optional. If not None, the server uses a secure websocket.
            You can use ./web/generate-certificate.py to generate
            it (the default generated filename is `cert.pem`).
        """
        self.recognizer = recognizer

        self.certificate = certificate
        self.http_server = HttpServer(doc_root)

        self.nn_pool_size = nn_pool_size
        self.nn_pool = ThreadPoolExecutor(
            max_workers=nn_pool_size,
            thread_name_prefix="nn",
        )

        self.stream_queue = asyncio.Queue()

        self.max_wait_ms = max_wait_ms
        self.max_batch_size = max_batch_size
        self.max_message_size = max_message_size
        self.max_queue_size = max_queue_size
        self.max_active_connections = max_active_connections
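
        # Note: this counter is read and modified only from coroutines
        # running on the event loop thread (see process_request and
        # handle_connection), so no extra locking should be needed.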
        self.current_active_connections = 0
        self.sample_rate = int(recognizer.config.feat_config.sampling_rate)

    async def process_request(
        self,
        path: str,
        request_headers: websockets.Headers,
    ) -> Optional[Tuple[http.HTTPStatus, websockets.Headers, bytes]]:
        if "sec-websocket-key" not in (
            request_headers.headers  # For new-style request_headers
            if hasattr(request_headers, "headers")
            else request_headers  # For old-style request_headers
        ):
            # This is a normal HTTP request
            if path == "/":
                path = "/index.html"
            if path[-1] == "?":
                path = path[:-1]

            if path == "/streaming_record.html":
                response = r"""
<!doctype html><html><head>
<title>Speech recognition with next-gen Kaldi</title><body>
<h2>Only
<a href="/upload.html">/upload.html</a>
and
<a href="/offline_record.html">/offline_record.html</a>
are available for the non-streaming server.</h2>
<br/>
<br/>
Go back to <a href="/upload.html">/upload.html</a>
or <a href="/offline_record.html">/offline_record.html</a>
</body></head></html>
"""
                found = True
                mime_type = "text/html"
            else:
                found, response, mime_type = self.http_server.process_request(path)
            if isinstance(response, str):
                response = response.encode("utf-8")

            if not found:
                status = http.HTTPStatus.NOT_FOUND
            else:
                status = http.HTTPStatus.OK
            header = {"Content-Type": mime_type}
            return status, header, response

        if self.current_active_connections < self.max_active_connections:
            self.current_active_connections += 1
            return None

        # Refuse new connections
        status = http.HTTPStatus.SERVICE_UNAVAILABLE  # 503
        header = {"Hint": "The server is overloaded. Please retry later."}
        response = b"The server is busy. Please retry later."

        return status, header, response

    async def run(self, port: int):
        logging.info("started")

        tasks = []
        for _ in range(self.nn_pool_size):
            tasks.append(asyncio.create_task(self.stream_consumer_task()))

        if self.certificate:
            logging.info(f"Using certificate: {self.certificate}")
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            ssl_context.load_cert_chain(self.certificate)
        else:
            ssl_context = None
            logging.info("No certificate provided")

        async with websockets.serve(
            self.handle_connection,
            host="",
            port=port,
            max_size=self.max_message_size,
            max_queue=self.max_queue_size,
            process_request=self.process_request,
            ssl=ssl_context,
        ):
            ip_list = ["localhost"]
            if ssl_context:
                ip_list += ["0.0.0.0", "127.0.0.1"]
            ip_list.append(socket.gethostbyname(socket.gethostname()))

            proto = "http://" if ssl_context is None else "https://"
            s = "Please visit one of the following addresses:\n\n"
            for p in ip_list:
                s += "  " + proto + p + f":{port}" + "\n"
            logging.info(s)

            await asyncio.Future()  # run forever

        await asyncio.gather(*tasks)  # not reachable

    async def recv_audio_samples(
        self,
        socket: websockets.WebSocketServerProtocol,
    ) -> Tuple[Optional[np.ndarray], Optional[float]]:
        """Receive audio samples from the client.

        Each message from the client is a **bytes** buffer.

        The first message is either "Done", meaning the client won't send
        anything further, or a buffer containing at least 8 bytes.
        The first 4 bytes, in little endian, specify the sample rate of the
        audio samples; the next 4 bytes, in little endian, specify the
        number of bytes of audio data that the client will send in this and
        subsequent messages. Since the websocket protocol imposes a limit on
        the message size, the client may split a large audio file across
        multiple messages.

        The second and remaining messages contain audio samples.

        Please refer to ./offline-websocket-client-decode-files-paralell.py
        and ./offline-websocket-client-decode-files-sequential.py
        for how the client sends the messages.
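
        A minimal sketch of what a client might send, given an open
        websocket connection ``ws`` and a 16 kHz float32 waveform in a
        NumPy array ``samples`` (illustrative only; see the client scripts
        above for the real implementation):

            import struct
            header = struct.pack("<ii", 16000, samples.nbytes)
            await ws.send(header + samples.tobytes())
            text = await ws.recv()   # the recognition result
            await ws.send("Done")    # tell the server we are finished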

        Args:
          socket:
            The socket for communicating with the client.
        Returns:
          Return a tuple containing:
            - a 1-D np.float32 array containing the audio samples
            - the sample rate of the audio samples
          or return (None, None), indicating the end of the utterance.
        """
        header = await socket.recv()
        if header == "Done":
            return None, None

        assert len(header) >= 8, (
            "The first message should contain at least 8 bytes. "
            + f"Given {len(header)}"
        )

        sample_rate = int.from_bytes(header[:4], "little", signed=True)
        expected_num_bytes = int.from_bytes(header[4:8], "little", signed=True)

        received = []
        num_received_bytes = 0
        if len(header) > 8:
            received.append(header[8:])
            num_received_bytes += len(header) - 8

        if num_received_bytes < expected_num_bytes:
            async for message in socket:
                received.append(message)
                num_received_bytes += len(message)
                if num_received_bytes >= expected_num_bytes:
                    break

        assert num_received_bytes == expected_num_bytes, (
            num_received_bytes,
            expected_num_bytes,
        )

        samples = b"".join(received)
        array = np.frombuffer(samples, dtype=np.float32)
        return array, sample_rate

    async def stream_consumer_task(self):
        """Extract streams from the queue, batch them up, and send
        them to the model for computation and decoding.
        """
        while True:
            if self.stream_queue.empty():
                await asyncio.sleep(self.max_wait_ms / 1000)
                continue

            batch = []
            try:
                while len(batch) < self.max_batch_size:
                    item = self.stream_queue.get_nowait()

                    batch.append(item)
            except asyncio.QueueEmpty:
                pass

            stream_list = [b[0] for b in batch]
            future_list = [b[1] for b in batch]

            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                self.nn_pool,
                self.recognizer.decode_streams,
                stream_list,
            )

            for f in future_list:
                self.stream_queue.task_done()
                f.set_result(None)

    async def compute_and_decode(
        self,
        stream: sherpa_onnx.OfflineStream,
    ) -> None:
        """Put the stream into the queue and wait for it to be processed
        by the consumer task.

        Args:
          stream:
            The stream to be processed. Note: It is changed in-place.
        """
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        await self.stream_queue.put((stream, future))
        await future

    async def handle_connection(
        self,
        socket: websockets.WebSocketServerProtocol,
    ):
        """Receive audio samples from the client, process them, and send
        the decoding results back to the client.

        Args:
          socket:
            The socket for communicating with the client.
        """
        try:
            await self.handle_connection_impl(socket)
        except websockets.exceptions.ConnectionClosedError:
            logging.info(f"{socket.remote_address} disconnected")
        finally:
            # Decrement so that the server can accept new connections
            self.current_active_connections -= 1

            logging.info(
                f"Disconnected: {socket.remote_address}. "
                f"Number of connections: {self.current_active_connections}/{self.max_active_connections}"  # noqa
            )

    async def handle_connection_impl(
        self,
        socket: websockets.WebSocketServerProtocol,
    ):
        """Receive audio samples from the client, process them, and send
        the decoding results back to the client.

        Args:
          socket:
            The socket for communicating with the client.
        """
        logging.info(
            f"Connected: {socket.remote_address}. "
            f"Number of connections: {self.current_active_connections}/{self.max_active_connections}"  # noqa
        )

        while True:
            stream = self.recognizer.create_stream()
            samples, sample_rate = await self.recv_audio_samples(socket)
            if samples is None:
                break
            # stream.accept_waveform() runs in the main thread

            stream.accept_waveform(sample_rate, samples)

            await self.compute_and_decode(stream)
            result = stream.result.text
            logging.info(f"result: {result}")

            if result:
                await socket.send(result)
            else:
                # If the result is an empty string, send something to the
                # client. Otherwise, socket.send() is a no-op and the client
                # will wait for a reply indefinitely.
                await socket.send("<EMPTY>")


def assert_file_exists(filename: str):
    assert Path(filename).is_file(), (
        f"{filename} does not exist!\n"
        "Please refer to "
        "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
    )


def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
    if args.encoder:
        assert len(args.paraformer) == 0, args.paraformer
        assert len(args.sense_voice) == 0, args.sense_voice
        assert len(args.nemo_ctc) == 0, args.nemo_ctc
        assert len(args.wenet_ctc) == 0, args.wenet_ctc
        assert len(args.telespeech_ctc) == 0, args.telespeech_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder
        assert len(args.tdnn_model) == 0, args.tdnn_model
        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
        assert (
            len(args.moonshine_uncached_decoder) == 0
        ), args.moonshine_uncached_decoder
        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

        assert_file_exists(args.encoder)
        assert_file_exists(args.decoder)
        assert_file_exists(args.joiner)

        recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
            encoder=args.encoder,
            decoder=args.decoder,
            joiner=args.joiner,
            tokens=args.tokens,
            num_threads=args.num_threads,
            sample_rate=args.sample_rate,
            feature_dim=args.feat_dim,
            decoding_method=args.decoding_method,
            max_active_paths=args.max_active_paths,
            hotwords_file=args.hotwords_file,
            hotwords_score=args.hotwords_score,
            blank_penalty=args.blank_penalty,
            provider=args.provider,
        )
    elif args.paraformer:
        assert len(args.sense_voice) == 0, args.sense_voice
        assert len(args.nemo_ctc) == 0, args.nemo_ctc
        assert len(args.wenet_ctc) == 0, args.wenet_ctc
        assert len(args.telespeech_ctc) == 0, args.telespeech_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder
        assert len(args.tdnn_model) == 0, args.tdnn_model
        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
        assert (
            len(args.moonshine_uncached_decoder) == 0
        ), args.moonshine_uncached_decoder
        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

        assert_file_exists(args.paraformer)

        recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
            paraformer=args.paraformer,
            tokens=args.tokens,
            num_threads=args.num_threads,
            sample_rate=args.sample_rate,
            feature_dim=args.feat_dim,
            decoding_method=args.decoding_method,
            provider=args.provider,
        )
    elif args.sense_voice:
        assert len(args.nemo_ctc) == 0, args.nemo_ctc
        assert len(args.wenet_ctc) == 0, args.wenet_ctc
        assert len(args.telespeech_ctc) == 0, args.telespeech_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder
        assert len(args.tdnn_model) == 0, args.tdnn_model
        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
        assert (
            len(args.moonshine_uncached_decoder) == 0
        ), args.moonshine_uncached_decoder
        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

        assert_file_exists(args.sense_voice)
        recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
            model=args.sense_voice,
            tokens=args.tokens,
            num_threads=args.num_threads,
            use_itn=True,
        )
    elif args.nemo_ctc:
        assert len(args.wenet_ctc) == 0, args.wenet_ctc
        assert len(args.telespeech_ctc) == 0, args.telespeech_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder
        assert len(args.tdnn_model) == 0, args.tdnn_model
        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
        assert (
            len(args.moonshine_uncached_decoder) == 0
        ), args.moonshine_uncached_decoder
        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

        assert_file_exists(args.nemo_ctc)

        recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
            model=args.nemo_ctc,
            tokens=args.tokens,
            num_threads=args.num_threads,
            sample_rate=args.sample_rate,
            feature_dim=args.feat_dim,
            decoding_method=args.decoding_method,
            provider=args.provider,
        )
    elif args.wenet_ctc:
        assert len(args.telespeech_ctc) == 0, args.telespeech_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder
        assert len(args.tdnn_model) == 0, args.tdnn_model
        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
        assert (
            len(args.moonshine_uncached_decoder) == 0
        ), args.moonshine_uncached_decoder
        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

        assert_file_exists(args.wenet_ctc)

        recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc(
            model=args.wenet_ctc,
            tokens=args.tokens,
            num_threads=args.num_threads,
            sample_rate=args.sample_rate,
            feature_dim=args.feat_dim,
            decoding_method=args.decoding_method,
            provider=args.provider,
        )
    elif args.telespeech_ctc:
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder
        assert len(args.tdnn_model) == 0, args.tdnn_model
        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
        assert (
            len(args.moonshine_uncached_decoder) == 0
        ), args.moonshine_uncached_decoder
        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

        assert_file_exists(args.telespeech_ctc)

        recognizer = sherpa_onnx.OfflineRecognizer.from_telespeech_ctc(
            model=args.telespeech_ctc,
            tokens=args.tokens,
            num_threads=args.num_threads,
            sample_rate=args.sample_rate,
            feature_dim=args.feat_dim,
            decoding_method=args.decoding_method,
            provider=args.provider,
        )
    elif args.whisper_encoder:
        assert len(args.tdnn_model) == 0, args.tdnn_model
        assert_file_exists(args.whisper_encoder)
        assert_file_exists(args.whisper_decoder)
        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
        assert (
            len(args.moonshine_uncached_decoder) == 0
        ), args.moonshine_uncached_decoder
        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

        recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
            encoder=args.whisper_encoder,
            decoder=args.whisper_decoder,
            tokens=args.tokens,
            num_threads=args.num_threads,
            decoding_method=args.decoding_method,
            language=args.whisper_language,
            task=args.whisper_task,
            tail_paddings=args.whisper_tail_paddings,
            provider=args.provider,
        )
    elif args.tdnn_model:
        assert_file_exists(args.tdnn_model)
        assert len(args.moonshine_preprocessor) == 0, args.moonshine_preprocessor
        assert len(args.moonshine_encoder) == 0, args.moonshine_encoder
        assert (
            len(args.moonshine_uncached_decoder) == 0
        ), args.moonshine_uncached_decoder
        assert len(args.moonshine_cached_decoder) == 0, args.moonshine_cached_decoder

        recognizer = sherpa_onnx.OfflineRecognizer.from_tdnn_ctc(
            model=args.tdnn_model,
            tokens=args.tokens,
            sample_rate=args.sample_rate,
            feature_dim=args.feat_dim,
            num_threads=args.num_threads,
            decoding_method=args.decoding_method,
            provider=args.provider,
        )
    elif args.moonshine_preprocessor:
        assert_file_exists(args.moonshine_preprocessor)
        assert_file_exists(args.moonshine_encoder)
        assert_file_exists(args.moonshine_uncached_decoder)
        assert_file_exists(args.moonshine_cached_decoder)

        recognizer = sherpa_onnx.OfflineRecognizer.from_moonshine(
            preprocessor=args.moonshine_preprocessor,
            encoder=args.moonshine_encoder,
            uncached_decoder=args.moonshine_uncached_decoder,
            cached_decoder=args.moonshine_cached_decoder,
            tokens=args.tokens,
            num_threads=args.num_threads,
            decoding_method=args.decoding_method,
        )
    else:
        raise ValueError("Please specify at least one model")

    return recognizer
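
# A minimal decoding sketch (illustrative, not used by the server): given a
# recognizer from create_recognizer() and a 1-D float32 NumPy array `samples`
# sampled at 16 kHz, you could decode it directly like this:
#
#     stream = recognizer.create_stream()
#     stream.accept_waveform(16000, samples)
#     recognizer.decode_stream(stream)
#     print(stream.result.text)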


def main():
    args = get_args()
    logging.info(vars(args))
    check_args(args)

    recognizer = create_recognizer(args)

    port = args.port
    max_wait_ms = args.max_wait_ms
    max_batch_size = args.max_batch_size
    nn_pool_size = args.nn_pool_size
    max_message_size = args.max_message_size
    max_queue_size = args.max_queue_size
    max_active_connections = args.max_active_connections
    certificate = args.certificate
    doc_root = args.doc_root

    if certificate and not Path(certificate).is_file():
        raise ValueError(f"{certificate} does not exist")

    if not Path(doc_root).is_dir():
        raise ValueError(f"Directory {doc_root} does not exist")

    non_streaming_server = NonStreamingServer(
        recognizer=recognizer,
        max_wait_ms=max_wait_ms,
        max_batch_size=max_batch_size,
        nn_pool_size=nn_pool_size,
        max_message_size=max_message_size,
        max_queue_size=max_queue_size,
        max_active_connections=max_active_connections,
        certificate=certificate,
        doc_root=doc_root,
    )
    asyncio.run(non_streaming_server.run(port))


if __name__ == "__main__":
    log_filename = "log/log-non-streaming-server"
    setup_logger(log_filename)
    main()