Mirror of https://github.com/k2-fsa/sherpa-onnx.git (synced 2026-01-09 07:41:06 +08:00)

Add C++ API for streaming zipformer ASR on RK NPU (#1908)

This commit is contained in: parent bafd1103d0, commit 4d79e6a007
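For context, the RKNN path added here is reached through the existing C++ streaming-recognizer API by setting the provider to "rknn". The sketch below is not part of the commit; it assumes the public sherpa-onnx C++ headers and a streaming zipformer transducer already converted to .rknn format, and all file names are placeholders.

  #include "sherpa-onnx/csrc/online-recognizer.h"

  int main() {
    sherpa_onnx::OnlineRecognizerConfig config;
    config.model_config.transducer.encoder = "encoder.rknn";  // placeholder
    config.model_config.transducer.decoder = "decoder.rknn";  // placeholder
    config.model_config.transducer.joiner = "joiner.rknn";    // placeholder
    config.model_config.tokens = "tokens.txt";                // placeholder
    config.model_config.provider_config.provider = "rknn";    // run on the RK NPU

    sherpa_onnx::OnlineRecognizer recognizer(config);
    auto stream = recognizer.CreateStream();

    // Feed 16 kHz samples as they arrive, e.g.
    // stream->AcceptWaveform(16000, samples.data(), samples.size());
    while (recognizer.IsReady(stream.get())) {
      recognizer.DecodeStream(stream.get());
    }
    auto result = recognizer.GetResult(stream.get());
    return 0;
  }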
@@ -199,6 +199,7 @@ jobs:
timeout_seconds: 200
shell: bash
command: |
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"

@@ -207,9 +208,10 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs huggingface

cd huggingface
mkdir -p aarch64
dst=aarch64/$SHERPA_ONNX_VERSION
mkdir -p $dst

cp -v ../sherpa-onnx-*-shared*.tar.bz2 ./aarch64
cp -v ../sherpa-onnx-*-shared*.tar.bz2 $dst/

git status
git lfs track "*.bz2"

@@ -124,6 +124,7 @@ jobs:
timeout_seconds: 200
shell: bash
command: |
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
git config --global user.email "csukuangfj@gmail.com"
git config --global user.name "Fangjun Kuang"

@@ -132,9 +133,10 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs huggingface

cd huggingface
mkdir -p aarch64
dst=aarch64/$SHERPA_ONNX_VERSION
mkdir -p $dst

cp -v ../sherpa-onnx-*-static.tar.bz2 ./aarch64
cp -v ../sherpa-onnx-*-static.tar.bz2 $dst/

git status
git lfs track "*.bz2"
.github/workflows/android-static.yaml (vendored, 8 changed lines)
@ -140,6 +140,7 @@ jobs:
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
du -h -d1 .
|
||||
@ -150,8 +151,10 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
dst=$SHERPA_ONNX_VERSION
|
||||
mkdir -p $dst
|
||||
|
||||
cp -v ../sherpa-onnx-*-android*.tar.bz2 ./
|
||||
cp -v ../sherpa-onnx-*-android*.tar.bz2 $dst/
|
||||
|
||||
git status
|
||||
git lfs track "*.bz2"
|
||||
@ -263,6 +266,7 @@ jobs:
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
du -h -d1 .
|
||||
@ -273,7 +277,7 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
dst=android/aar
|
||||
dst=android/aar/$SHERPA_ONNX_VERSION
|
||||
mkdir -p $dst
|
||||
|
||||
cp -v ../*.aar $dst
|
||||
|
||||
.github/workflows/arm-linux-gnueabihf.yaml (vendored, 6 changed lines)
@ -199,6 +199,7 @@ jobs:
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
|
||||
@ -206,9 +207,10 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
mkdir -p arm32
|
||||
dst=arm32/$SHERPA_ONNX_VERSION
|
||||
mkdir -p $dst
|
||||
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 ./arm32
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 $dst/
|
||||
|
||||
git status
|
||||
git lfs track "*.bz2"
|
||||
|
||||
.github/workflows/dot-net.yaml (vendored, 6 changed lines)
@ -83,6 +83,7 @@ jobs:
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
|
||||
@ -95,9 +96,10 @@ jobs:
|
||||
cd huggingface
|
||||
git fetch
|
||||
git pull
|
||||
mkdir -p windows-for-dotnet
|
||||
dst=windows-for-dotnet/$SHERPA_ONNX_VERSION
|
||||
mkdir -p $dst
|
||||
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 ./windows-for-dotnet
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 $dst/
|
||||
|
||||
git status
|
||||
git lfs track "*.bz2"
|
||||
|
||||
.github/workflows/linux-jni-aarch64.yaml (vendored, 8 changed lines)
@ -138,6 +138,7 @@ jobs:
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
|
||||
@ -146,10 +147,11 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
mkdir -p jni
|
||||
dst=jni/$SHERPA_ONNX_VERSION
|
||||
mkdir -p $dst
|
||||
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 ./jni
|
||||
cp -v ../*.jar ./jni
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 $dst/
|
||||
cp -v ../*.jar $dst/
|
||||
|
||||
git status
|
||||
git lfs track "*.bz2"
|
||||
|
||||
.github/workflows/linux-jni.yaml (vendored, 8 changed lines)
@ -197,6 +197,7 @@ jobs:
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
|
||||
@ -205,10 +206,11 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
mkdir -p jni
|
||||
dst=jni/$SHERPA_ONNX_VERSION
|
||||
mkdir -p $dst
|
||||
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 ./jni
|
||||
cp -v ../*.jar ./jni
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 $dst/
|
||||
cp -v ../*.jar $dst/
|
||||
|
||||
git status
|
||||
git lfs track "*.bz2"
|
||||
|
||||
.github/workflows/macos-jni.yaml (vendored, 6 changed lines)
@ -113,6 +113,7 @@ jobs:
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
|
||||
@ -121,9 +122,10 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
mkdir -p jni
|
||||
dst=jni/$SHERPA_ONNX_VERSION
|
||||
mkdir -p $dst
|
||||
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 ./jni
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 $dst
|
||||
|
||||
git status
|
||||
git lfs track "*.bz2"
|
||||
|
||||
.github/workflows/riscv64-linux.yaml (vendored, 6 changed lines)
@ -239,6 +239,7 @@ jobs:
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
|
||||
@ -248,9 +249,10 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
mkdir -p riscv64
|
||||
dst=riscv64/$SHERPA_ONNX_VERSION
|
||||
mkdir -p $dst
|
||||
|
||||
cp -v ../sherpa-onnx-*-shared.tar.bz2 ./riscv64
|
||||
cp -v ../sherpa-onnx-*-shared.tar.bz2 $dst/
|
||||
|
||||
git status
|
||||
git lfs track "*.bz2"
|
||||
|
||||
.github/workflows/rknn-linux-aarch64.yaml (new file, vendored, 241 lines)
@ -0,0 +1,241 @@
|
||||
name: rknn-linux-aarch64
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- rknn-c-api-2
|
||||
tags:
|
||||
- 'v[0-9]+.[0-9]+.[0-9]+*'
|
||||
paths:
|
||||
- '.github/workflows/rknn-linux-aarch64.yaml'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
- 'sherpa-onnx/csrc/rknn/*'
|
||||
- 'sherpa-onnx/c-api/*'
|
||||
- 'toolchains/aarch64-linux-gnu.toolchain.cmake'
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
paths:
|
||||
- '.github/workflows/rknn-linux-aarch64.yaml'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
- 'sherpa-onnx/csrc/rknn/*'
|
||||
- 'sherpa-onnx/c-api/*'
|
||||
- 'toolchains/aarch64-linux-gnu.toolchain.cmake'
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: aarch64-linux-gnu-shared-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
rknn_linux_aarch64:
|
||||
runs-on: ${{ matrix.os }}
|
||||
name: rknn shared ${{ matrix.shared }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- os: ubuntu-22.04-arm
|
||||
shared: ON
|
||||
- os: ubuntu-22.04-arm
|
||||
shared: OFF
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2
|
||||
with:
|
||||
key: ${{ matrix.os }}-${{ matrix.shared }}-rknn-linux-aarch64
|
||||
|
||||
- name: Download rknn-toolkit2
|
||||
shell: bash
|
||||
run: |
|
||||
git clone --depth 1 https://github.com/airockchip/rknn-toolkit2
|
||||
|
||||
- name: Build sherpa-onnx
|
||||
shell: bash
|
||||
run: |
|
||||
export CMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
|
||||
cmake --version
|
||||
|
||||
echo "config: ${{ matrix.config }}"
|
||||
uname -a
|
||||
which gcc
|
||||
|
||||
gcc --version
|
||||
g++ --version
|
||||
|
||||
echo "pwd"
|
||||
|
||||
ls -lh
|
||||
|
||||
git clone --depth 1 --branch v1.2.12 https://github.com/alsa-project/alsa-lib
|
||||
pushd alsa-lib
|
||||
./gitcompile
|
||||
popd
|
||||
|
||||
export SHERPA_ONNX_RKNN_TOOLKIT2_PATH=$PWD/rknn-toolkit2
|
||||
export SHERPA_ONNX_RKNN_TOOLKIT2_LIB_DIR=$SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/aarch64
|
||||
export CPLUS_INCLUDE_PATH=$SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/include:$CPLUS_INCLUDE_PATH
|
||||
export CPLUS_INCLUDE_PATH=$PWD/alsa-lib/include:$CPLUS_INCLUDE_PATH
|
||||
export SHERPA_ONNX_ALSA_LIB_DIR=$PWD/alsa-lib/src/.libs
|
||||
|
||||
mkdir build
|
||||
cd build
|
||||
|
||||
BUILD_SHARED_LIBS=${{ matrix.shared }}
|
||||
|
||||
cmake \
|
||||
-DBUILD_SHARED_LIBS=ON \
|
||||
-DCMAKE_INSTALL_PREFIX=./install \
|
||||
-DSHERPA_ONNX_ENABLE_RKNN=ON \
|
||||
-DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \
|
||||
..
|
||||
|
||||
make -j4 install
|
||||
|
||||
rm -rf install/lib/pkgconfig
|
||||
rm -fv install/lib/cargs.h
|
||||
rm -fv install/lib/libcargs.so
|
||||
|
||||
- name: Display system info
|
||||
shell: bash
|
||||
run: |
|
||||
uname -a
|
||||
gcc --version
|
||||
g++ --version
|
||||
|
||||
- name: Display generated files
|
||||
shell: bash
|
||||
run: |
|
||||
export SHERPA_ONNX_RKNN_TOOLKIT2_PATH=$PWD/rknn-toolkit2
|
||||
export LD_LIBRARY_PATH=$SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/aarch64:$LD_LIBRARY_PATH
|
||||
|
||||
cd build/install
|
||||
|
||||
ls -lh bin
|
||||
|
||||
echo "---"
|
||||
|
||||
ls -lh lib
|
||||
|
||||
file bin/sherpa-onnx
|
||||
|
||||
readelf -d bin/sherpa-onnx
|
||||
|
||||
ldd bin/sherpa-onnx
|
||||
|
||||
./bin/sherpa-onnx --help
|
||||
|
||||
- name: Copy files
|
||||
shell: bash
|
||||
run: |
|
||||
SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
|
||||
if [[ ${{ matrix.shared }} == ON ]]; then
|
||||
suffix=shared
|
||||
else
|
||||
suffix=static
|
||||
fi
|
||||
|
||||
dst=sherpa-onnx-${SHERPA_ONNX_VERSION}-rknn-linux-aarch64-$suffix
|
||||
mkdir $dst
|
||||
|
||||
cp -a build/install/bin $dst/
|
||||
|
||||
if [[ ${{ matrix.shared }} == ON ]]; then
|
||||
cp -v build/install/lib/lib*.so $dst/
|
||||
fi
|
||||
|
||||
ls -lh build/install/lib
|
||||
ls -lh build/install/bin
|
||||
|
||||
ls -lh $dst/bin/
|
||||
echo "strip"
|
||||
strip $dst/bin/*
|
||||
|
||||
echo "after strip"
|
||||
ls -lh $dst/bin/
|
||||
|
||||
tree $dst
|
||||
|
||||
tar cjvf ${dst}.tar.bz2 $dst
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: sherpa-onnx-linux-linux-aarch64-shared-${{ matrix.shared }}
|
||||
path: sherpa-onnx-*linux-aarch64*.tar.bz2
|
||||
|
||||
# https://huggingface.co/docs/hub/spaces-github-actions
|
||||
- name: Publish to huggingface
|
||||
if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event_name == 'push' || github.event_name == 'workflow_dispatch')
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
max_attempts: 20
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
|
||||
rm -rf huggingface
|
||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
dst=rknn-linux-aarch64/$SHERPA_ONNX_VERSION
|
||||
mkdir -p $dst
|
||||
|
||||
cp -v ../sherpa-onnx-*rknn*-*.tar.bz2 $dst
|
||||
|
||||
git status
|
||||
git lfs track "*.bz2"
|
||||
|
||||
git add .
|
||||
|
||||
git commit -m "upload sherpa-onnx-${SHERPA_ONNX_VERSION}-rknn-linux-aarch64.tar.bz2"
|
||||
|
||||
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs main
|
||||
|
||||
- name: Release pre-compiled binaries and libs for rknn linux aarch64
|
||||
if: github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/')
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
file_glob: true
|
||||
overwrite: true
|
||||
file: sherpa-onnx-*linux-aarch64*.tar.bz2
|
||||
|
||||
- name: Release pre-compiled binaries and libs for rknn linux aarch64
|
||||
# if: github.repository_owner == 'csukuangfj' && github.event_name == 'push' && contains(github.ref, 'refs/tags/')
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
file_glob: true
|
||||
overwrite: true
|
||||
file: sherpa-onnx-*linux-aarch64*.tar.bz2
|
||||
repo_name: k2-fsa/sherpa-onnx
|
||||
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||
tag: v1.10.45
|
||||
|
||||
- name: Test offline Moonshine
|
||||
if: matrix.build_type != 'Debug'
|
||||
shell: bash
|
||||
run: |
|
||||
du -h -d1 .
|
||||
export PATH=$PWD/build/install/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
readelf -d build/bin/sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-offline-moonshine.sh
|
||||
.github/workflows/windows-x64-jni.yaml (vendored, 6 changed lines)
@ -92,6 +92,7 @@ jobs:
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
|
||||
@ -100,9 +101,10 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
mkdir -p jni
|
||||
dst=jni/$SHERPA_ONNX_VERSION
|
||||
mkdir -p $dst
|
||||
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 ./jni
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 $dst
|
||||
|
||||
git status
|
||||
git lfs track "*.bz2"
|
||||
|
||||
.github/workflows/windows-x64.yaml (vendored, 6 changed lines)
@ -128,6 +128,7 @@ jobs:
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
|
||||
@ -136,9 +137,10 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
mkdir -p win64
|
||||
dst=win64/$SHERPA_ONNX_VERSION
|
||||
mkdir -p $dst
|
||||
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 ./win64
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 $dst
|
||||
|
||||
git status
|
||||
git lfs track "*.bz2"
|
||||
|
||||
.github/workflows/windows-x86.yaml (vendored, 4 changed lines)
@ -131,6 +131,7 @@ jobs:
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
|
||||
@ -139,7 +140,8 @@ jobs:
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface
|
||||
|
||||
cd huggingface
|
||||
mkdir -p win32
|
||||
dst=win32/$SHERPA_ONNX_VERSION
|
||||
mkdir -p $dst
|
||||
|
||||
cp -v ../sherpa-onnx-*.tar.bz2 ./win32
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically.
option(SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE "True to use pre-installed onnxruntime if available" ON)
option(SHERPA_ONNX_ENABLE_SANITIZER "Whether to enable ubsan and asan" OFF)
option(SHERPA_ONNX_BUILD_C_API_EXAMPLES "Whether to enable C API examples" ON)
option(SHERPA_ONNX_ENABLE_RKNN "Whether to build for RKNN NPU " OFF)

set(SHERPA_ONNX_LINUX_ARM64_GPU_ONNXRUNTIME_VERSION "1.11.0" CACHE STRING "Used only for Linux ARM64 GPU. If you use Jetson nano b01, then please set it to 1.11.0. If you use Jetson Orin NX, then set it to 1.16.0.If you use NVIDIA Jetson Orin Nano Engineering Reference Developer Kit
Super - Jetpack 6.2 [L4T 36.4.3], then set it to 1.18.1")

@@ -155,6 +156,7 @@ message(STATUS "SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY ${SHERPA_ONNX_LINK_LIBSTDC
message(STATUS "SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE ${SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE}")
message(STATUS "SHERPA_ONNX_ENABLE_SANITIZER: ${SHERPA_ONNX_ENABLE_SANITIZER}")
message(STATUS "SHERPA_ONNX_BUILD_C_API_EXAMPLES: ${SHERPA_ONNX_BUILD_C_API_EXAMPLES}")
message(STATUS "SHERPA_ONNX_ENABLE_RKNN: ${SHERPA_ONNX_ENABLE_RKNN}")

if(BUILD_SHARED_LIBS OR SHERPA_ONNX_ENABLE_JNI)
  set(CMAKE_CXX_VISIBILITY_PRESET hidden)

@@ -199,7 +201,7 @@ if(SHERPA_ONNX_ENABLE_DIRECTML)
  message(STATUS "DirectML is enabled")
  add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=1)
else()
  message(WARNING "DirectML is disabled")
  message(STATUS "DirectML is disabled")
  add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=0)
endif()

@@ -267,6 +269,10 @@ message(STATUS "C++ Standard version: ${CMAKE_CXX_STANDARD}")

include(CheckIncludeFileCXX)

if(SHERPA_ONNX_ENABLE_RKNN)
  add_definitions(-DSHERPA_ONNX_ENABLE_RKNN=1)
endif()

if(UNIX AND NOT APPLE AND NOT SHERPA_ONNX_ENABLE_WASM AND NOT CMAKE_SYSTEM_NAME STREQUAL Android AND NOT CMAKE_SYSTEM_NAME STREQUAL OHOS)
  check_include_file_cxx(alsa/asoundlib.h SHERPA_ONNX_HAS_ALSA)
  if(SHERPA_ONNX_HAS_ALSA)
@@ -74,7 +74,7 @@ if [[ x"$BUILD_SHARED_LIBS" == x"" ]]; then
fi

if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"" ]]; then
# By default, use CPU
# By default, don't use CPU
SHERPA_ONNX_ENABLE_GPU=OFF
fi
build-rknn-linux-aarch64.sh (new executable file, 101 lines)
@@ -0,0 +1,101 @@
#!/usr/bin/env bash

set -ex

# Before you run this file, make sure you have first cloned
# https://github.com/airockchip/rknn-toolkit2
# and set the environment variable SHERPA_ONNX_RKNN_TOOLKIT2_PATH

if [ -z $SHERPA_ONNX_RKNN_TOOLKIT2_PATH ]; then
  SHERPA_ONNX_RKNN_TOOLKIT2_PATH=/star-fj/fangjun/open-source/rknn-toolkit2
fi

if [ ! -d $SHERPA_ONNX_RKNN_TOOLKIT2_PATH ]; then
  echo "Please first clone https://github.com/airockchip/rknn-toolkit2"
  echo "You can use"
  echo " git clone --depth 1 https://github.com/airockchip/rknn-toolkit2"
  echo " export SHERPA_ONNX_RKNN_TOOLKIT2_PATH=$PWD/rknn-toolkit2"

  exit 1
fi

if [ ! -f $SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/include/rknn_api.h ]; then
  echo "$SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/include/rknn_api.h does not exist"
  exit 1
fi

if [ ! -f $SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/aarch64/librknnrt.so ]; then
  echo "$SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/aarch64/librknnrt.so does not exist"
  exit 1
fi

export SHERPA_ONNX_RKNN_TOOLKIT2_LIB_DIR=$SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/aarch64

export CPLUS_INCLUDE_PATH=$SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/include:$CPLUS_INCLUDE_PATH

if ! command -v aarch64-linux-gnu-gcc &> /dev/null; then
  echo "Please install a toolchain for cross-compiling."
  echo "You can refer to: "
  echo " https://k2-fsa.github.io/sherpa/onnx/install/rknn-linux-aarch64.html"
  echo "for help."
  exit 1
fi


dir=$PWD/build-rknn-linux-aarch64
mkdir -p $dir

cd $dir

if [ ! -f alsa-lib/src/.libs/libasound.so ]; then
  echo "Start to cross-compile alsa-lib"
  if [ ! -d alsa-lib ]; then
    git clone --depth 1 --branch v1.2.12 https://github.com/alsa-project/alsa-lib
  fi
  # If it shows:
  #  ./gitcompile: line 79: libtoolize: command not found
  # Please use:
  #  sudo apt-get install libtool m4 automake
  #
  pushd alsa-lib
  CC=aarch64-linux-gnu-gcc ./gitcompile --host=aarch64-linux-gnu
  popd
  echo "Finish cross-compiling alsa-lib"
fi

export CPLUS_INCLUDE_PATH=$PWD/alsa-lib/include:$CPLUS_INCLUDE_PATH
export SHERPA_ONNX_ALSA_LIB_DIR=$PWD/alsa-lib/src/.libs

if [[ x"$BUILD_SHARED_LIBS" == x"" ]]; then
  # By default, use shared link
  BUILD_SHARED_LIBS=ON
fi

cmake \
  -DBUILD_PIPER_PHONMIZE_EXE=OFF \
  -DBUILD_PIPER_PHONMIZE_TESTS=OFF \
  -DBUILD_ESPEAK_NG_EXE=OFF \
  -DBUILD_ESPEAK_NG_TESTS=OFF \
  -DCMAKE_INSTALL_PREFIX=./install \
  -DCMAKE_BUILD_TYPE=Release \
  -DSHERPA_ONNX_ENABLE_GPU=OFF \
  -DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \
  -DSHERPA_ONNX_ENABLE_TESTS=OFF \
  -DSHERPA_ONNX_ENABLE_PYTHON=OFF \
  -DSHERPA_ONNX_ENABLE_CHECK=OFF \
  -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
  -DSHERPA_ONNX_ENABLE_JNI=OFF \
  -DSHERPA_ONNX_ENABLE_C_API=ON \
  -DSHERPA_ONNX_ENABLE_WEBSOCKET=ON \
  -DSHERPA_ONNX_ENABLE_RKNN=ON \
  -DCMAKE_TOOLCHAIN_FILE=../toolchains/aarch64-linux-gnu.toolchain.cmake \
  ..

make VERBOSE=1 -j4
make install/strip

# Enable it only if needed
# cp -v $SHERPA_ONNX_ALSA_LIB_DIR/libasound.so* ./install/lib/

# See also
# https://github.com/airockchip/rknn-toolkit2/blob/master/rknpu2/examples/rknn_api_demo/build-linux.sh
@@ -150,7 +150,7 @@ if [ ! -f $src_dir/windows-x86/sherpa-onnx-c-api.dll ]; then
  if [ -f $windows_x86_wheel ]; then
    cp -v $windows_x86_wheel .
  else
    curl -OL https://$HF_MIRROR/csukuangfj/sherpa-onnx-libs/resolve/main/windows-for-dotnet/$windows_x86_wheel_filename
    curl -OL https://$HF_MIRROR/csukuangfj/sherpa-onnx-libs/resolve/main/windows-for-dotnet/$SHERPA_ONNX_VERSION/$windows_x86_wheel_filename
  fi
  tar xvf $windows_x86_wheel_filename
  cp -v sherpa-onnx-${SHERPA_ONNX_VERSION}-win-x86/*dll ../

@@ -169,7 +169,7 @@ if [ ! -f $src_dir/windows-arm64/sherpa-onnx-c-api.dll ]; then
  if [ -f $windows_arm64_wheel ]; then
    cp -v $windows_arm64_wheel .
  else
    curl -OL https://$HF_MIRROR/csukuangfj/sherpa-onnx-libs/resolve/main/windows-for-dotnet/$windows_arm64_wheel_filename
    curl -OL https://$HF_MIRROR/csukuangfj/sherpa-onnx-libs/resolve/main/windows-for-dotnet/$SHERPA_ONNX_VERSION/$windows_arm64_wheel_filename
  fi
  tar xvf $windows_arm64_wheel_filename
  cp -v sherpa-onnx-${SHERPA_ONNX_VERSION}-win-arm64/*dll ../
@@ -151,6 +151,14 @@ list(APPEND sources
  online-punctuation-model-config.cc
  online-punctuation.cc
)
if(SHERPA_ONNX_ENABLE_RKNN)
  list(APPEND sources
    ./rknn/online-stream-rknn.cc
    ./rknn/online-transducer-greedy-search-decoder-rknn.cc
    ./rknn/online-zipformer-transducer-model-rknn.cc
  )

endif()

if(SHERPA_ONNX_ENABLE_TTS)
  list(APPEND sources

@@ -230,6 +238,14 @@ if(SHERPA_ONNX_ENABLE_GPU)
  )
endif()

if(SHERPA_ONNX_ENABLE_RKNN)
  if(DEFINED ENV{SHERPA_ONNX_RKNN_TOOLKIT2_LIB_DIR})
    target_link_libraries(sherpa-onnx-core -L$ENV{SHERPA_ONNX_RKNN_TOOLKIT2_LIB_DIR} -lrknnrt)
  else()
    target_link_libraries(sherpa-onnx-core rknnrt)
  endif()
endif()

if(BUILD_SHARED_LIBS AND NOT DEFINED onnxruntime_lib_files)
  target_link_libraries(sherpa-onnx-core onnxruntime)
else()
@@ -5,6 +5,7 @@
#include "sherpa-onnx/csrc/file-utils.h"

#include <fstream>
#include <memory>
#include <string>

#include "sherpa-onnx/csrc/macros.h"

@@ -22,4 +23,61 @@ void AssertFileExists(const std::string &filename) {
  }
}

std::vector<char> ReadFile(const std::string &filename) {
  std::ifstream input(filename, std::ios::binary);
  std::vector<char> buffer(std::istreambuf_iterator<char>(input), {});
  return buffer;
}

#if __ANDROID_API__ >= 9
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
  AAsset *asset = AAssetManager_open(mgr, filename.c_str(), AASSET_MODE_BUFFER);
  if (!asset) {
    __android_log_print(ANDROID_LOG_FATAL, "sherpa-onnx",
                        "Read binary file: Load %s failed", filename.c_str());
    exit(-1);
  }

  auto p = reinterpret_cast<const char *>(AAsset_getBuffer(asset));
  size_t asset_length = AAsset_getLength(asset);

  std::vector<char> buffer(p, p + asset_length);
  AAsset_close(asset);

  return buffer;
}
#endif

#if __OHOS__
std::vector<char> ReadFile(NativeResourceManager *mgr,
                           const std::string &filename) {
  std::unique_ptr<RawFile, decltype(&OH_ResourceManager_CloseRawFile)> fp(
      OH_ResourceManager_OpenRawFile(mgr, filename.c_str()),
      OH_ResourceManager_CloseRawFile);

  if (!fp) {
    std::ostringstream os;
    os << "Read file '" << filename << "' failed.";
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
    return {};
  }

  auto len = static_cast<int32_t>(OH_ResourceManager_GetRawFileSize(fp.get()));

  std::vector<char> buffer(len);

  int32_t n = OH_ResourceManager_ReadRawFile(fp.get(), buffer.data(), len);

  if (n != len) {
    std::ostringstream os;
    os << "Read file '" << filename << "' failed. Number of bytes read: " << n
       << ". Expected bytes to read: " << len;
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
    return {};
  }

  return buffer;
}
#endif

} // namespace sherpa_onnx
@@ -7,6 +7,16 @@

#include <fstream>
#include <string>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

namespace sherpa_onnx {

@@ -23,6 +33,17 @@ bool FileExists(const std::string &filename);
 */
void AssertFileExists(const std::string &filename);

std::vector<char> ReadFile(const std::string &filename);

#if __ANDROID_API__ >= 9
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
#endif

#if __OHOS__
std::vector<char> ReadFile(NativeResourceManager *mgr,
                           const std::string &filename);
#endif

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_FILE_UTILS_H_
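One plausible reason ReadFile moves from onnx-utils into file-utils is that the RKNN path loads a model as an in-memory byte buffer rather than through onnxruntime. The sketch below is illustrative only; it assumes the rknn_api.h interface shipped in rknn-toolkit2, and "encoder.rknn" is a placeholder path.

  // Illustrative: read the whole .rknn blob, then hand the bytes to the runtime.
  std::vector<char> blob = sherpa_onnx::ReadFile("encoder.rknn");  // placeholder path
  rknn_context ctx = 0;
  // rknn_init() parses the model bytes; compare its return value against RKNN_SUCC.
  int ret = rknn_init(&ctx, blob.data(), blob.size(), 0, nullptr);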
@ -17,6 +17,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
#include "fst/extensions/far/far.h"
|
||||
#include "kaldifst/csrc/kaldi-fst-io.h"
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h"
|
||||
@ -31,7 +32,6 @@
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
@ -23,6 +23,7 @@
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@@ -51,9 +51,20 @@ void OnlineModelConfig::Register(ParseOptions *po) {
}

bool OnlineModelConfig::Validate() const {
  if (num_threads < 1) {
    SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
    return false;
  // For RK NPU, we reinterpret num_threads:
  //
  // For RK3588 only
  // num_threads == 1 -> Select a core randomly
  // num_threads == 0 -> Use NPU core 0
  // num_threads == -1 -> Use NPU core 1
  // num_threads == -2 -> Use NPU core 2
  // num_threads == -3 -> Use NPU core 0 and core 1
  // num_threads == -4 -> Use NPU core 0, core 1, and core 2
  if (provider_config.provider != "rknn") {
    if (num_threads < 1) {
      SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
      return false;
    }
  }

  if (!tokens_buf.empty() && FileExists(tokens)) {
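With this change, a non-positive num_threads is accepted only when the provider is "rknn", where it acts as an NPU core selector rather than a thread count. A hedged configuration sketch (field names follow the existing OnlineModelConfig; the core mapping is the RK3588 table quoted in the comment above):

  sherpa_onnx::OnlineRecognizerConfig config;
  config.model_config.provider_config.provider = "rknn";
  // RK3588 only: -4 uses NPU cores 0+1+2, -3 uses cores 0+1, 1 picks a core at random.
  config.model_config.num_threads = -4;
  // With any other provider, Validate() still rejects values below 1.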
@ -18,6 +18,7 @@
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@@ -1,6 +1,6 @@
// sherpa-onnx/csrc/online-recognizer-impl.cc
//
// Copyright (c) 2023 Xiaomi Corporation
// Copyright (c) 2023-2025 Xiaomi Corporation

#include "sherpa-onnx/csrc/online-recognizer-impl.h"

@@ -26,10 +26,31 @@
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"

#if SHERPA_ONNX_ENABLE_RKNN
#include "sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h"
#endif

namespace sherpa_onnx {

std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
    const OnlineRecognizerConfig &config) {
  if (config.model_config.provider_config.provider == "rknn") {
#if SHERPA_ONNX_ENABLE_RKNN
    // Currently, only zipformer v1 is supported for rknn
    if (config.model_config.transducer.encoder.empty()) {
      SHERPA_ONNX_LOGE(
          "Only Zipformer transducers are currently supported by rknn. "
          "Fallback to CPU");
    } else {
      return std::make_unique<OnlineRecognizerTransducerRknnImpl>(config);
    }
#else
    SHERPA_ONNX_LOGE(
        "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you "
        "want to use rknn. Fallback to CPU");
#endif
  }

  if (!config.model_config.transducer.encoder.empty()) {
    Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
#include <vector>

#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
|
||||
@@ -23,7 +23,8 @@ class OnlineStream {
 public:
  explicit OnlineStream(const FeatureExtractorConfig &config = {},
                        ContextGraphPtr context_graph = nullptr);
  ~OnlineStream();

  virtual ~OnlineStream();
|
||||
|
||||
/**
|
||||
@param sampling_rate The sampling_rate of the input waveform. If it does
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
|
||||
|
||||
@ -25,6 +25,7 @@
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -23,6 +23,7 @@
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
@ -25,6 +25,7 @@
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@ -13,15 +13,8 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#include "android/log.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@ -305,63 +298,6 @@ void Print4D(const Ort::Value *v) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
std::vector<char> ReadFile(const std::string &filename) {
|
||||
std::ifstream input(filename, std::ios::binary);
|
||||
std::vector<char> buffer(std::istreambuf_iterator<char>(input), {});
|
||||
return buffer;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
|
||||
AAsset *asset = AAssetManager_open(mgr, filename.c_str(), AASSET_MODE_BUFFER);
|
||||
if (!asset) {
|
||||
__android_log_print(ANDROID_LOG_FATAL, "sherpa-onnx",
|
||||
"Read binary file: Load %s failed", filename.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
auto p = reinterpret_cast<const char *>(AAsset_getBuffer(asset));
|
||||
size_t asset_length = AAsset_getLength(asset);
|
||||
|
||||
std::vector<char> buffer(p, p + asset_length);
|
||||
AAsset_close(asset);
|
||||
|
||||
return buffer;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
std::vector<char> ReadFile(NativeResourceManager *mgr,
|
||||
const std::string &filename) {
|
||||
std::unique_ptr<RawFile, decltype(&OH_ResourceManager_CloseRawFile)> fp(
|
||||
OH_ResourceManager_OpenRawFile(mgr, filename.c_str()),
|
||||
OH_ResourceManager_CloseRawFile);
|
||||
|
||||
if (!fp) {
|
||||
std::ostringstream os;
|
||||
os << "Read file '" << filename << "' failed.";
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
auto len = static_cast<int32_t>(OH_ResourceManager_GetRawFileSize(fp.get()));
|
||||
|
||||
std::vector<char> buffer(len);
|
||||
|
||||
int32_t n = OH_ResourceManager_ReadRawFile(fp.get(), buffer.data(), len);
|
||||
|
||||
if (n != len) {
|
||||
std::ostringstream os;
|
||||
os << "Read file '" << filename << "' failed. Number of bytes read: " << n
|
||||
<< ". Expected bytes to read: " << len;
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
return buffer;
|
||||
}
|
||||
#endif
|
||||
|
||||
Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
|
||||
const std::vector<int32_t> &hyps_num_split) {
|
||||
std::vector<int64_t> cur_encoder_out_shape =
|
||||
|
||||
@ -17,15 +17,6 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
#include "rawfile/raw_file_manager.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@ -101,17 +92,6 @@ void Fill(Ort::Value *tensor, T value) {
|
||||
std::fill(p, p + n, value);
|
||||
}
|
||||
|
||||
std::vector<char> ReadFile(const std::string &filename);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
|
||||
#endif
|
||||
|
||||
#if __OHOS__
|
||||
std::vector<char> ReadFile(NativeResourceManager *mgr,
|
||||
const std::string &filename);
|
||||
#endif
|
||||
|
||||
// TODO(fangjun): Document it
|
||||
Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
|
||||
const std::vector<int32_t> &hyps_num_split);
|
||||
|
||||
sherpa-onnx/csrc/rknn/macros.h (new file, 19 lines)
@@ -0,0 +1,19 @@
// sherpa-onnx/csrc/rknn/macros.h
//
// Copyright 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_RKNN_MACROS_H_
#define SHERPA_ONNX_CSRC_RKNN_MACROS_H_

#include "sherpa-onnx/csrc/macros.h"

#define SHERPA_ONNX_RKNN_CHECK(ret, msg, ...) \
  do {                                        \
    if (ret != RKNN_SUCC) {                   \
      SHERPA_ONNX_LOGE("Return code is: %d", ret); \
      SHERPA_ONNX_LOGE(msg, ##__VA_ARGS__);   \
      SHERPA_ONNX_EXIT(-1);                   \
    }                                         \
  } while (0)

#endif // SHERPA_ONNX_CSRC_RKNN_MACROS_H_
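A short usage sketch for the new check macro, assuming the rknn_api.h interface from rknn-toolkit2 (rknn_query, rknn_input_output_num, and RKNN_QUERY_IN_OUT_NUM come from that header; ctx is an already-initialized rknn_context):

  // Hypothetical call site: log and abort unless the RKNN call returns RKNN_SUCC.
  rknn_input_output_num io_num;
  int ret = rknn_query(ctx, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
  SHERPA_ONNX_RKNN_CHECK(ret, "Failed to query the number of model inputs/outputs");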
sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h (new file, 231 lines)
@@ -0,0 +1,231 @@
// sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_TRANSDUCER_RKNN_IMPL_H_
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_TRANSDUCER_RKNN_IMPL_H_

#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/rknn/online-stream-rknn.h"
#include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h"
#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"
#include "sherpa-onnx/csrc/symbol-table.h"

namespace sherpa_onnx {

OnlineRecognizerResult Convert(const OnlineTransducerDecoderResultRknn &src,
                               const SymbolTable &sym_table,
                               float frame_shift_ms, int32_t subsampling_factor,
                               int32_t segment, int32_t frames_since_start) {
  OnlineRecognizerResult r;
  r.tokens.reserve(src.tokens.size());
  r.timestamps.reserve(src.tokens.size());

  std::string text;
  for (auto i : src.tokens) {
    auto sym = sym_table[i];

    text.append(sym);

    if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) {
      // for bpe models with byte_fallback
      // (but don't rewrite printable characters 0x20..0x7e,
      //  which collide with standard BPE units)
      std::ostringstream os;
      os << "<0x" << std::hex << std::uppercase
         << (static_cast<int32_t>(sym[0]) & 0xff) << ">";
      sym = os.str();
    }

    r.tokens.push_back(std::move(sym));
  }

  if (sym_table.IsByteBpe()) {
    text = sym_table.DecodeByteBpe(text);
  }

  r.text = std::move(text);

  float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
  for (auto t : src.timestamps) {
    float time = frame_shift_s * t;
    r.timestamps.push_back(time);
  }

  r.segment = segment;
  r.start_time = frames_since_start * frame_shift_ms / 1000.;

  return r;
}

class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl {
 public:
  explicit OnlineRecognizerTransducerRknnImpl(
      const OnlineRecognizerConfig &config)
      : OnlineRecognizerImpl(config),
        config_(config),
        endpoint_(config_.endpoint_config),
        model_(std::make_unique<OnlineZipformerTransducerModelRknn>(
            config.model_config)) {
    if (!config.model_config.tokens_buf.empty()) {
      sym_ = SymbolTable(config.model_config.tokens_buf, false);
    } else {
      /// assuming tokens_buf and tokens are guaranteed not being both empty
      sym_ = SymbolTable(config.model_config.tokens, true);
    }

    if (sym_.Contains("<unk>")) {
      unk_id_ = sym_["<unk>"];
    }

    decoder_ = std::make_unique<OnlineTransducerGreedySearchDecoderRknn>(
        model_.get(), unk_id_);
  }

  template <typename Manager>
  explicit OnlineRecognizerTransducerRknnImpl(
      Manager *mgr, const OnlineRecognizerConfig &config)
      : OnlineRecognizerImpl(mgr, config),
        config_(config),
        endpoint_(config_.endpoint_config),
        model_(
            std::make_unique<OnlineZipformerTransducerModelRknn>(mgr, config)) {
    // TODO(fangjun): Support Android
  }

  std::unique_ptr<OnlineStream> CreateStream() const override {
    auto stream = std::make_unique<OnlineStreamRknn>(config_.feat_config);
    auto r = decoder_->GetEmptyResult();
    stream->SetZipformerResult(std::move(r));
    stream->SetZipformerEncoderStates(model_->GetEncoderInitStates());
    return stream;
  }

  std::unique_ptr<OnlineStream> CreateStream(
      const std::string &hotwords) const override {
    SHERPA_ONNX_LOGE("Hotwords for RKNN is not supported now.");
    return CreateStream();
  }

  bool IsReady(OnlineStream *s) const override {
    return s->GetNumProcessedFrames() + model_->ChunkSize() <
           s->NumFramesReady();
  }

  // Warmping up engine with wp: warm_up count and max-batch-size

  void DecodeStreams(OnlineStream **ss, int32_t n) const override {
    for (int32_t i = 0; i < n; ++i) {
      DecodeStream(reinterpret_cast<OnlineStreamRknn *>(ss[i]));
    }
  }

  OnlineRecognizerResult GetResult(OnlineStream *s) const override {
    OnlineTransducerDecoderResultRknn decoder_result =
        reinterpret_cast<OnlineStreamRknn *>(s)->GetZipformerResult();
    decoder_->StripLeadingBlanks(&decoder_result);
    // TODO(fangjun): Remember to change these constants if needed
    int32_t frame_shift_ms = 10;
    int32_t subsampling_factor = 4;
    auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
                     s->GetCurrentSegment(), s->GetNumFramesSinceStart());
    r.text = ApplyInverseTextNormalization(std::move(r.text));
    return r;
  }

  bool IsEndpoint(OnlineStream *s) const override {
    if (!config_.enable_endpoint) {
      return false;
    }

    int32_t num_processed_frames = s->GetNumProcessedFrames();

    // frame shift is 10 milliseconds
    float frame_shift_in_seconds = 0.01;

    // subsampling factor is 4
    int32_t trailing_silence_frames = reinterpret_cast<OnlineStreamRknn *>(s)
                                          ->GetZipformerResult()
                                          .num_trailing_blanks *
                                      4;

    return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
                                frame_shift_in_seconds);
  }

  void Reset(OnlineStream *s) const override {
    int32_t context_size = model_->ContextSize();

    {
      // segment is incremented only when the last
      // result is not empty, contains non-blanks and longer than context_size)
      const auto &r =
          reinterpret_cast<OnlineStreamRknn *>(s)->GetZipformerResult();
      if (!r.tokens.empty() && r.tokens.back() != 0 &&
          r.tokens.size() > context_size) {
        s->GetCurrentSegment() += 1;
      }
    }

    // reset encoder states
    // reinterpret_cast<OnlineStreamRknn*>(s)->SetZipformerEncoderStates(model_->GetEncoderInitStates());
    auto r = decoder_->GetEmptyResult();
    auto last_result =
        reinterpret_cast<OnlineStreamRknn *>(s)->GetZipformerResult();

    // if last result is not empty, then
    // preserve last tokens as the context for next result
    if (static_cast<int32_t>(last_result.tokens.size()) > context_size) {
      r.tokens = {last_result.tokens.end() - context_size,
                  last_result.tokens.end()};
    }
    reinterpret_cast<OnlineStreamRknn *>(s)->SetZipformerResult(std::move(r));

    // Note: We only update counters. The underlying audio samples
    // are not discarded.
    s->Reset();
  }

 private:
  void DecodeStream(OnlineStreamRknn *s) const {
    int32_t chunk_size = model_->ChunkSize();
    int32_t chunk_shift = model_->ChunkShift();

    int32_t feature_dim = s->FeatureDim();

    const auto num_processed_frames = s->GetNumProcessedFrames();

    std::vector<float> features =
        s->GetFrames(num_processed_frames, chunk_size);
    s->GetNumProcessedFrames() += chunk_shift;

    auto &states = s->GetZipformerEncoderStates();

    auto p = model_->RunEncoder(features, std::move(states));
    states = std::move(p.second);

    auto &r = s->GetZipformerResult();
    decoder_->Decode(std::move(p.first), &r);
  }

 private:
  OnlineRecognizerConfig config_;
  SymbolTable sym_;
  Endpoint endpoint_;
  int32_t unk_id_ = -1;
  std::unique_ptr<OnlineZipformerTransducerModelRknn> model_;
  std::unique_ptr<OnlineTransducerGreedySearchDecoderRknn> decoder_;
};

} // namespace sherpa_onnx

#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_TRANSDUCER_RKNN_IMPL_H_
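To show how the pieces above cooperate at runtime, here is a hedged sketch of an endpoint-driven decoding loop against the public recognizer API; recognizer and stream are assumed to have been created as in the earlier sketch.

  while (recognizer.IsReady(stream.get())) {
    recognizer.DecodeStream(stream.get());  // dispatches to DecodeStreams(): RunEncoder + greedy search
  }
  if (recognizer.IsEndpoint(stream.get())) {
    auto result = recognizer.GetResult(stream.get());  // leading blanks stripped, inverse text normalization applied
    recognizer.Reset(stream.get());  // keeps the last context_size tokens, clears the rest of the result
  }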
sherpa-onnx/csrc/rknn/online-stream-rknn.cc (new file, 60 lines)
@@ -0,0 +1,60 @@
// sherpa-onnx/csrc/rknn/online-stream-rknn.cc
//
// Copyright (c) 2025 Xiaomi Corporation

#include "sherpa-onnx/csrc/rknn/online-stream-rknn.h"

#include <utility>
#include <vector>

namespace sherpa_onnx {

class OnlineStreamRknn::Impl {
 public:
  void SetZipformerEncoderStates(std::vector<std::vector<uint8_t>> states) {
    states_ = std::move(states);
  }

  std::vector<std::vector<uint8_t>> &GetZipformerEncoderStates() {
    return states_;
  }

  void SetZipformerResult(OnlineTransducerDecoderResultRknn r) {
    result_ = std::move(r);
  }

  OnlineTransducerDecoderResultRknn &GetZipformerResult() { return result_; }

 private:
  std::vector<std::vector<uint8_t>> states_;
  OnlineTransducerDecoderResultRknn result_;
};

OnlineStreamRknn::OnlineStreamRknn(
    const FeatureExtractorConfig &config /*= {}*/,
    ContextGraphPtr context_graph /*= nullptr*/)
    : OnlineStream(config, context_graph), impl_(std::make_unique<Impl>()) {}

OnlineStreamRknn::~OnlineStreamRknn() = default;

void OnlineStreamRknn::SetZipformerEncoderStates(
    std::vector<std::vector<uint8_t>> states) const {
  impl_->SetZipformerEncoderStates(std::move(states));
}

std::vector<std::vector<uint8_t>> &OnlineStreamRknn::GetZipformerEncoderStates()
    const {
  return impl_->GetZipformerEncoderStates();
}

void OnlineStreamRknn::SetZipformerResult(
    OnlineTransducerDecoderResultRknn r) const {
  impl_->SetZipformerResult(std::move(r));
}

OnlineTransducerDecoderResultRknn &OnlineStreamRknn::GetZipformerResult()
    const {
  return impl_->GetZipformerResult();
}

} // namespace sherpa_onnx
38
sherpa-onnx/csrc/rknn/online-stream-rknn.h
Normal file
38
sherpa-onnx/csrc/rknn/online-stream-rknn.h
Normal file
@ -0,0 +1,38 @@
// sherpa-onnx/csrc/rknn/online-stream-rknn.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_STREAM_RKNN_H_
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_STREAM_RKNN_H_
#include <memory>
#include <vector>

#include "rknn_api.h"  // NOLINT
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h"

namespace sherpa_onnx {

class OnlineStreamRknn : public OnlineStream {
 public:
  explicit OnlineStreamRknn(const FeatureExtractorConfig &config = {},
                            ContextGraphPtr context_graph = nullptr);

  ~OnlineStreamRknn();

  void SetZipformerEncoderStates(
      std::vector<std::vector<uint8_t>> states) const;

  std::vector<std::vector<uint8_t>> &GetZipformerEncoderStates() const;

  void SetZipformerResult(OnlineTransducerDecoderResultRknn r) const;

  OnlineTransducerDecoderResultRknn &GetZipformerResult() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_RKNN_ONLINE_STREAM_RKNN_H_
@ -0,0 +1,94 @@
// sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.cc
//
// Copyright (c) 2025 Xiaomi Corporation

#include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

OnlineTransducerDecoderResultRknn
OnlineTransducerGreedySearchDecoderRknn::GetEmptyResult() const {
  int32_t context_size = model_->ContextSize();
  int32_t blank_id = 0;  // always 0
  OnlineTransducerDecoderResultRknn r;
  r.tokens.resize(context_size, -1);
  r.tokens.back() = blank_id;

  return r;
}

void OnlineTransducerGreedySearchDecoderRknn::StripLeadingBlanks(
    OnlineTransducerDecoderResultRknn *r) const {
  int32_t context_size = model_->ContextSize();

  auto start = r->tokens.begin() + context_size;
  auto end = r->tokens.end();

  r->tokens = std::vector<int64_t>(start, end);
}

void OnlineTransducerGreedySearchDecoderRknn::Decode(
    std::vector<float> encoder_out,
    OnlineTransducerDecoderResultRknn *result) const {
  auto &r = result[0];
  auto attr = model_->GetEncoderOutAttr();
  int32_t num_frames = attr.dims[1];
  int32_t encoder_out_dim = attr.dims[2];

  int32_t vocab_size = model_->VocabSize();
  int32_t context_size = model_->ContextSize();

  std::vector<int64_t> decoder_input;
  std::vector<float> decoder_out;

  if (r.previous_decoder_out.empty()) {
    decoder_input = {r.tokens.begin() + (r.tokens.size() - context_size),
                     r.tokens.end()};
    decoder_out = model_->RunDecoder(std::move(decoder_input));

  } else {
    decoder_out = std::move(r.previous_decoder_out);
  }

  const float *p_encoder_out = encoder_out.data();
  for (int32_t t = 0; t != num_frames; ++t) {
    auto logit = model_->RunJoiner(p_encoder_out, decoder_out.data());
    p_encoder_out += encoder_out_dim;

    bool emitted = false;
    if (blank_penalty_ > 0.0) {
      logit[0] -= blank_penalty_;  // assuming blank id is 0
    }

    auto y = static_cast<int32_t>(std::distance(
        logit.data(),
        std::max_element(logit.data(), logit.data() + vocab_size)));
    // blank id is hardcoded to 0
    // also, it treats unk as blank
    if (y != 0 && y != unk_id_) {
      emitted = true;
      r.tokens.push_back(y);
      r.timestamps.push_back(t + r.frame_offset);
      r.num_trailing_blanks = 0;
    } else {
      ++r.num_trailing_blanks;
    }

    if (emitted) {
      decoder_input = {r.tokens.begin() + (r.tokens.size() - context_size),
                       r.tokens.end()};
      decoder_out = model_->RunDecoder(std::move(decoder_input));
    }
  }

  r.frame_offset += num_frames;
  r.previous_decoder_out = std::move(decoder_out);
}

}  // namespace sherpa_onnx
@ -0,0 +1,52 @@
// sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h
//
// Copyright (c) 2025 Xiaomi Corporation

#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_RKNN_H_
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_RKNN_H_

#include <vector>

#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"

namespace sherpa_onnx {

struct OnlineTransducerDecoderResultRknn {
  /// Number of frames after subsampling we have decoded so far
  int32_t frame_offset = 0;

  /// The decoded token IDs so far
  std::vector<int64_t> tokens;

  /// number of trailing blank frames decoded so far
  int32_t num_trailing_blanks = 0;

  /// timestamps[i] contains the output frame index where tokens[i] is decoded.
  std::vector<int32_t> timestamps;

  std::vector<float> previous_decoder_out;
};

class OnlineTransducerGreedySearchDecoderRknn {
 public:
  explicit OnlineTransducerGreedySearchDecoderRknn(
      OnlineZipformerTransducerModelRknn *model, int32_t unk_id = 2,
      float blank_penalty = 0.0)
      : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}

  OnlineTransducerDecoderResultRknn GetEmptyResult() const;

  void StripLeadingBlanks(OnlineTransducerDecoderResultRknn *r) const;

  void Decode(std::vector<float> encoder_out,
              OnlineTransducerDecoderResultRknn *result) const;

 private:
  OnlineZipformerTransducerModelRknn *model_;  // Not owned
  int32_t unk_id_;
  float blank_penalty_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_RKNN_H_
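Editor's note: the following is a minimal usage sketch, not part of this commit, showing how the decoder declared above is typically driven together with the model class. The function name and calling convention are assumptions based only on the APIs declared in this diff.

#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h"
#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"

// Hypothetical helper: decode one chunk of fbank frames.
void DecodeOneChunkSketch(
    sherpa_onnx::OnlineZipformerTransducerModelRknn *model,
    sherpa_onnx::OnlineTransducerGreedySearchDecoderRknn *decoder,
    std::vector<float> features,                         // one chunk of frames
    std::vector<std::vector<uint8_t>> *states,           // encoder states, reused
    sherpa_onnx::OnlineTransducerDecoderResultRknn *r) { // running result
  // Run the streaming encoder on this chunk; it returns the encoder output
  // and the updated states to feed back in for the next chunk.
  auto p = model->RunEncoder(std::move(features), std::move(*states));
  *states = std::move(p.second);

  // Greedy search over the encoder output frames; `r` accumulates tokens,
  // timestamps, and the trailing-blank count used for endpointing.
  decoder->Decode(std::move(p.first), r);
}

The result `r` would normally be initialized once with decoder->GetEmptyResult() and passed through StripLeadingBlanks() before converting tokens to text.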
781
sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
Normal file
781
sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
Normal file
@ -0,0 +1,781 @@
// sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc
//
// Copyright (c) 2023 Xiaomi Corporation

#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"

#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/rknn/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"

namespace sherpa_onnx {

// chw -> hwc
static void Transpose(const float *src, int32_t n, int32_t channel,
                      int32_t height, int32_t width, float *dst) {
  for (int32_t i = 0; i < n; ++i) {
    for (int32_t h = 0; h < height; ++h) {
      for (int32_t w = 0; w < width; ++w) {
        for (int32_t c = 0; c < channel; ++c) {
          // dst[h, w, c] = src[c, h, w]
          dst[i * height * width * channel + h * width * channel + w * channel +
              c] = src[i * height * width * channel + c * height * width +
                       h * width + w];
        }
      }
    }
  }
}

static std::string ToString(const rknn_tensor_attr &attr) {
  std::ostringstream os;
  os << "{";
  os << attr.index;
  os << ", name: " << attr.name;
  os << ", shape: (";
  std::string sep;
  for (int32_t i = 0; i < static_cast<int32_t>(attr.n_dims); ++i) {
    os << sep << attr.dims[i];
    sep = ",";
  }
  os << ")";
  os << ", n_elems: " << attr.n_elems;
  os << ", size: " << attr.size;
  os << ", fmt: " << get_format_string(attr.fmt);
  os << ", type: " << get_type_string(attr.type);
  os << ", pass_through: " << (attr.pass_through ? "true" : "false");
  os << "}";
  return os.str();
}

static std::unordered_map<std::string, std::string> Parse(
    const rknn_custom_string &custom_string) {
  std::unordered_map<std::string, std::string> ans;
  std::vector<std::string> fields;
  SplitStringToVector(custom_string.string, ";", false, &fields);

  std::vector<std::string> tmp;
  for (const auto &f : fields) {
    SplitStringToVector(f, "=", false, &tmp);
    if (tmp.size() != 2) {
      SHERPA_ONNX_LOGE("Invalid custom string %s for %s", custom_string.string,
                       f.c_str());
      SHERPA_ONNX_EXIT(-1);
    }
    ans[std::move(tmp[0])] = std::move(tmp[1]);
  }

  return ans;
}

class OnlineZipformerTransducerModelRknn::Impl {
 public:
  ~Impl() {
    auto ret = rknn_destroy(encoder_ctx_);
    if (ret != RKNN_SUCC) {
      SHERPA_ONNX_LOGE("Failed to destroy the encoder context");
    }

    ret = rknn_destroy(decoder_ctx_);
    if (ret != RKNN_SUCC) {
      SHERPA_ONNX_LOGE("Failed to destroy the decoder context");
    }

    ret = rknn_destroy(joiner_ctx_);
    if (ret != RKNN_SUCC) {
      SHERPA_ONNX_LOGE("Failed to destroy the joiner context");
    }
  }

  explicit Impl(const OnlineModelConfig &config) : config_(config) {
    {
      auto buf = ReadFile(config.transducer.encoder);
      InitEncoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.transducer.decoder);
      InitDecoder(buf.data(), buf.size());
    }

    {
      auto buf = ReadFile(config.transducer.joiner);
      InitJoiner(buf.data(), buf.size());
    }

    // Now select which core to run for RK3588
    int32_t ret_encoder = RKNN_SUCC;
    int32_t ret_decoder = RKNN_SUCC;
    int32_t ret_joiner = RKNN_SUCC;
    switch (config_.num_threads) {
      case 1:
        ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_AUTO);
        ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_AUTO);
        ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_AUTO);
        break;
      case 0:
        ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0);
        ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0);
        ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0);
        break;
      case -1:
        ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_1);
        ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_1);
        ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_1);
        break;
      case -2:
        ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_2);
        ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_2);
        ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_2);
        break;
      case -3:
        ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0_1);
        ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0_1);
        ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0_1);
        break;
      case -4:
        ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0_1_2);
        ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0_1_2);
        ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0_1_2);
        break;
      default:
        SHERPA_ONNX_LOGE(
            "Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core "
            "1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d",
            config_.num_threads);
        break;
    }
    if (ret_encoder != RKNN_SUCC) {
      SHERPA_ONNX_LOGE(
          "Failed to select npu core to run encoder (You can ignore it if you "
          "are not using RK3588).");
    }

    if (ret_decoder != RKNN_SUCC) {
      SHERPA_ONNX_LOGE(
          "Failed to select npu core to run decoder (You can ignore it if you "
          "are not using RK3588).");
    }

    if (ret_joiner != RKNN_SUCC) {
      SHERPA_ONNX_LOGE(
          "Failed to select npu core to run joiner (You can ignore it if you "
          "are not using RK3588).");
    }
  }

  // TODO(fangjun): Support Android

  std::vector<std::vector<uint8_t>> GetEncoderInitStates() const {
    // encoder_input_attrs_[0] is for the feature
    // encoder_input_attrs_[1:] is for states
    // so we use -1 here
    std::vector<std::vector<uint8_t>> states(encoder_input_attrs_.size() - 1);

    int32_t i = -1;
    for (auto &attr : encoder_input_attrs_) {
      i += 1;
      if (i == 0) {
        // skip processing the attr for features.
        continue;
      }

      if (attr.type == RKNN_TENSOR_FLOAT16) {
        states[i - 1].resize(attr.n_elems * sizeof(float));
      } else if (attr.type == RKNN_TENSOR_INT64) {
        states[i - 1].resize(attr.n_elems * sizeof(int64_t));
      } else {
        SHERPA_ONNX_LOGE("Unsupported tensor type: %d, %s", attr.type,
                         get_type_string(attr.type));
        SHERPA_ONNX_EXIT(-1);
      }
    }

    return states;
  }

  std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> RunEncoder(
      std::vector<float> features,
      std::vector<std::vector<uint8_t>> states) const {
    std::vector<rknn_input> inputs(encoder_input_attrs_.size());

    for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
      auto &input = inputs[i];
      auto &attr = encoder_input_attrs_[i];
      input.index = attr.index;

      if (attr.type == RKNN_TENSOR_FLOAT16) {
        input.type = RKNN_TENSOR_FLOAT32;
      } else if (attr.type == RKNN_TENSOR_INT64) {
        input.type = RKNN_TENSOR_INT64;
      } else {
        SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
                         get_type_string(attr.type));
        SHERPA_ONNX_EXIT(-1);
      }

      input.fmt = attr.fmt;
      if (i == 0) {
        input.buf = reinterpret_cast<void *>(features.data());
        input.size = features.size() * sizeof(float);
      } else {
        input.buf = reinterpret_cast<void *>(states[i - 1].data());
        input.size = states[i - 1].size();
      }
    }

    std::vector<float> encoder_out(encoder_output_attrs_[0].n_elems);

    // Note(fangjun): We can reuse the memory from input argument `states`
    // auto next_states = GetEncoderInitStates();
    auto &next_states = states;

    std::vector<rknn_output> outputs(encoder_output_attrs_.size());
    for (int32_t i = 0; i < outputs.size(); ++i) {
      auto &output = outputs[i];
      auto &attr = encoder_output_attrs_[i];
      output.index = attr.index;
      output.is_prealloc = 1;

      if (attr.type == RKNN_TENSOR_FLOAT16) {
        output.want_float = 1;
      } else if (attr.type == RKNN_TENSOR_INT64) {
        output.want_float = 0;
      } else {
        SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
                         get_type_string(attr.type));
        SHERPA_ONNX_EXIT(-1);
      }

      if (i == 0) {
        output.size = encoder_out.size() * sizeof(float);
        output.buf = reinterpret_cast<void *>(encoder_out.data());
      } else {
        output.size = next_states[i - 1].size();
        output.buf = reinterpret_cast<void *>(next_states[i - 1].data());
      }
    }

    auto ret = rknn_inputs_set(encoder_ctx_, inputs.size(), inputs.data());
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set encoder inputs");

    ret = rknn_run(encoder_ctx_, nullptr);
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run encoder");

    ret =
        rknn_outputs_get(encoder_ctx_, outputs.size(), outputs.data(), nullptr);
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get encoder output");

    for (int32_t i = 0; i < next_states.size(); ++i) {
      const auto &attr = encoder_input_attrs_[i + 1];
      if (attr.n_dims == 4) {
        // TODO(fangjun): The transpose is copied from
        // https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22
        // I don't understand why we need to do that.
        std::vector<uint8_t> dst(next_states[i].size());
        int32_t n = attr.dims[0];
        int32_t h = attr.dims[1];
        int32_t w = attr.dims[2];
        int32_t c = attr.dims[3];
        Transpose(reinterpret_cast<const float *>(next_states[i].data()), n, c,
                  h, w, reinterpret_cast<float *>(dst.data()));
        next_states[i] = std::move(dst);
      }
    }

    return {std::move(encoder_out), std::move(next_states)};
  }

  std::vector<float> RunDecoder(std::vector<int64_t> decoder_input) const {
    auto &attr = decoder_input_attrs_[0];
    rknn_input input;

    input.index = 0;
    input.type = RKNN_TENSOR_INT64;
    input.fmt = attr.fmt;
    input.buf = decoder_input.data();
    input.size = decoder_input.size() * sizeof(int64_t);

    std::vector<float> decoder_out(decoder_output_attrs_[0].n_elems);
    rknn_output output;
    output.index = decoder_output_attrs_[0].index;
    output.is_prealloc = 1;
    output.want_float = 1;
    output.size = decoder_out.size() * sizeof(float);
    output.buf = decoder_out.data();

    auto ret = rknn_inputs_set(decoder_ctx_, 1, &input);
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set decoder inputs");

    ret = rknn_run(decoder_ctx_, nullptr);
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run decoder");

    ret = rknn_outputs_get(decoder_ctx_, 1, &output, nullptr);
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get decoder output");

    return decoder_out;
  }

  std::vector<float> RunJoiner(const float *encoder_out,
                               const float *decoder_out) const {
    std::vector<rknn_input> inputs(2);
    inputs[0].index = 0;
    inputs[0].type = RKNN_TENSOR_FLOAT32;
    inputs[0].fmt = joiner_input_attrs_[0].fmt;
    inputs[0].buf = const_cast<float *>(encoder_out);
    inputs[0].size = joiner_input_attrs_[0].n_elems * sizeof(float);

    inputs[1].index = 1;
    inputs[1].type = RKNN_TENSOR_FLOAT32;
    inputs[1].fmt = joiner_input_attrs_[1].fmt;
    inputs[1].buf = const_cast<float *>(decoder_out);
    inputs[1].size = joiner_input_attrs_[1].n_elems * sizeof(float);

    std::vector<float> joiner_out(joiner_output_attrs_[0].n_elems);
    rknn_output output;
    output.index = joiner_output_attrs_[0].index;
    output.is_prealloc = 1;
    output.want_float = 1;
    output.size = joiner_out.size() * sizeof(float);
    output.buf = joiner_out.data();

    auto ret = rknn_inputs_set(joiner_ctx_, inputs.size(), inputs.data());
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set joiner inputs");

    ret = rknn_run(joiner_ctx_, nullptr);
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run joiner");

    ret = rknn_outputs_get(joiner_ctx_, 1, &output, nullptr);
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get joiner output");

    return joiner_out;
  }

  int32_t ContextSize() const { return context_size_; }

  int32_t ChunkSize() const { return T_; }

  int32_t ChunkShift() const { return decode_chunk_len_; }

  int32_t VocabSize() const { return vocab_size_; }

  rknn_tensor_attr GetEncoderOutAttr() const {
    return encoder_output_attrs_[0];
  }

 private:
  void InitEncoder(void *model_data, size_t model_data_length) {
    auto ret =
        rknn_init(&encoder_ctx_, model_data, model_data_length, 0, nullptr);
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init encoder '%s'",
                           config_.transducer.encoder.c_str());

    if (config_.debug) {
      rknn_sdk_version v;
      ret = rknn_query(encoder_ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");

      SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s",
                       v.api_version, v.drv_version);
    }

    rknn_input_output_num io_num;
    ret = rknn_query(encoder_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num,
                     sizeof(io_num));
    SHERPA_ONNX_RKNN_CHECK(ret,
                           "Failed to get I/O information for the encoder");

    if (config_.debug) {
      SHERPA_ONNX_LOGE("encoder: %d inputs, %d outputs",
                       static_cast<int32_t>(io_num.n_input),
                       static_cast<int32_t>(io_num.n_output));
    }

    encoder_input_attrs_.resize(io_num.n_input);
    encoder_output_attrs_.resize(io_num.n_output);

    int32_t i = 0;
    for (auto &attr : encoder_input_attrs_) {
      memset(&attr, 0, sizeof(attr));
      attr.index = i;
      ret =
          rknn_query(encoder_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for encoder input %d", i);
      i += 1;
    }

    if (config_.debug) {
      std::ostringstream os;
      std::string sep;
      for (auto &attr : encoder_input_attrs_) {
        os << sep << ToString(attr);
        sep = "\n";
      }
      SHERPA_ONNX_LOGE("\n----------Encoder inputs info----------\n%s",
                       os.str().c_str());
    }

    i = 0;
    for (auto &attr : encoder_output_attrs_) {
      memset(&attr, 0, sizeof(attr));
      attr.index = i;
      ret =
          rknn_query(encoder_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for encoder output %d",
                             i);
      i += 1;
    }

    if (config_.debug) {
      std::ostringstream os;
      std::string sep;
      for (auto &attr : encoder_output_attrs_) {
        os << sep << ToString(attr);
        sep = "\n";
      }
      SHERPA_ONNX_LOGE("\n----------Encoder outputs info----------\n%s",
                       os.str().c_str());
    }

    rknn_custom_string custom_string;
    ret = rknn_query(encoder_ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
                     sizeof(custom_string));
    SHERPA_ONNX_RKNN_CHECK(
        ret, "Failed to read custom string from the encoder model");
    if (config_.debug) {
      SHERPA_ONNX_LOGE("custom string: %s", custom_string.string);
    }
    auto meta = Parse(custom_string);

    for (const auto &p : meta) {
      SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
    }

    if (meta.count("encoder_dims")) {
      SplitStringToIntegers(meta.at("encoder_dims"), ",", false,
                            &encoder_dims_);
    }

    if (meta.count("attention_dims")) {
      SplitStringToIntegers(meta.at("attention_dims"), ",", false,
                            &attention_dims_);
    }

    if (meta.count("num_encoder_layers")) {
      SplitStringToIntegers(meta.at("num_encoder_layers"), ",", false,
                            &num_encoder_layers_);
    }

    if (meta.count("cnn_module_kernels")) {
      SplitStringToIntegers(meta.at("cnn_module_kernels"), ",", false,
                            &cnn_module_kernels_);
    }

    if (meta.count("left_context_len")) {
      SplitStringToIntegers(meta.at("left_context_len"), ",", false,
                            &left_context_len_);
    }

    if (meta.count("T")) {
      T_ = atoi(meta.at("T").c_str());
    }

    if (meta.count("decode_chunk_len")) {
      decode_chunk_len_ = atoi(meta.at("decode_chunk_len").c_str());
    }

    if (meta.count("context_size")) {
      context_size_ = atoi(meta.at("context_size").c_str());
    }

    if (config_.debug) {
      auto print = [](const std::vector<int32_t> &v, const char *name) {
        std::ostringstream os;
        os << name << ": ";
        for (auto i : v) {
          os << i << " ";
        }
#if __OHOS__
        SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
        SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
      };
      print(encoder_dims_, "encoder_dims");
      print(attention_dims_, "attention_dims");
      print(num_encoder_layers_, "num_encoder_layers");
      print(cnn_module_kernels_, "cnn_module_kernels");
      print(left_context_len_, "left_context_len");
#if __OHOS__
      SHERPA_ONNX_LOGE("T: %{public}d", T_);
      SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_);
      SHERPA_ONNX_LOGE("context_size: %{public}d", context_size_);
#else
      SHERPA_ONNX_LOGE("T: %d", T_);
      SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
      SHERPA_ONNX_LOGE("context_size: %d", context_size_);
#endif
    }
  }

  void InitDecoder(void *model_data, size_t model_data_length) {
    auto ret =
        rknn_init(&decoder_ctx_, model_data, model_data_length, 0, nullptr);
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init decoder '%s'",
                           config_.transducer.decoder.c_str());

    rknn_input_output_num io_num;
    ret = rknn_query(decoder_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num,
                     sizeof(io_num));
    SHERPA_ONNX_RKNN_CHECK(ret,
                           "Failed to get I/O information for the decoder");

    if (io_num.n_input != 1) {
      SHERPA_ONNX_LOGE("Expect only 1 decoder input. Given %d",
                       static_cast<int32_t>(io_num.n_input));
      SHERPA_ONNX_EXIT(-1);
    }

    if (io_num.n_output != 1) {
      SHERPA_ONNX_LOGE("Expect only 1 decoder output. Given %d",
                       static_cast<int32_t>(io_num.n_output));
      SHERPA_ONNX_EXIT(-1);
    }

    if (config_.debug) {
      SHERPA_ONNX_LOGE("decoder: %d inputs, %d outputs",
                       static_cast<int32_t>(io_num.n_input),
                       static_cast<int32_t>(io_num.n_output));
    }

    decoder_input_attrs_.resize(io_num.n_input);
    decoder_output_attrs_.resize(io_num.n_output);

    int32_t i = 0;
    for (auto &attr : decoder_input_attrs_) {
      memset(&attr, 0, sizeof(attr));
      attr.index = i;
      ret =
          rknn_query(decoder_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for decoder input %d", i);
      i += 1;
    }

    if (config_.debug) {
      std::ostringstream os;
      std::string sep;
      for (auto &attr : decoder_input_attrs_) {
        os << sep << ToString(attr);
        sep = "\n";
      }
      SHERPA_ONNX_LOGE("\n----------Decoder inputs info----------\n%s",
                       os.str().c_str());
    }

    if (decoder_input_attrs_[0].type != RKNN_TENSOR_INT64) {
      SHERPA_ONNX_LOGE("Expect int64 for decoder input. Given: %d, %s",
                       decoder_input_attrs_[0].type,
                       get_type_string(decoder_input_attrs_[0].type));
      SHERPA_ONNX_EXIT(-1);
    }

    i = 0;
    for (auto &attr : decoder_output_attrs_) {
      memset(&attr, 0, sizeof(attr));
      attr.index = i;
      ret =
          rknn_query(decoder_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for decoder output %d",
                             i);
      i += 1;
    }

    if (config_.debug) {
      std::ostringstream os;
      std::string sep;
      for (auto &attr : decoder_output_attrs_) {
        os << sep << ToString(attr);
        sep = "\n";
      }
      SHERPA_ONNX_LOGE("\n----------Decoder outputs info----------\n%s",
                       os.str().c_str());
    }
  }

  void InitJoiner(void *model_data, size_t model_data_length) {
    auto ret =
        rknn_init(&joiner_ctx_, model_data, model_data_length, 0, nullptr);
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init joiner '%s'",
                           config_.transducer.joiner.c_str());

    rknn_input_output_num io_num;
    ret =
        rknn_query(joiner_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
    SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the joiner");

    if (config_.debug) {
      SHERPA_ONNX_LOGE("joiner: %d inputs, %d outputs",
                       static_cast<int32_t>(io_num.n_input),
                       static_cast<int32_t>(io_num.n_output));
    }

    joiner_input_attrs_.resize(io_num.n_input);
    joiner_output_attrs_.resize(io_num.n_output);

    int32_t i = 0;
    for (auto &attr : joiner_input_attrs_) {
      memset(&attr, 0, sizeof(attr));
      attr.index = i;
      ret = rknn_query(joiner_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for joiner input %d", i);
      i += 1;
    }

    if (config_.debug) {
      std::ostringstream os;
      std::string sep;
      for (auto &attr : joiner_input_attrs_) {
        os << sep << ToString(attr);
        sep = "\n";
      }
      SHERPA_ONNX_LOGE("\n----------Joiner inputs info----------\n%s",
                       os.str().c_str());
    }

    i = 0;
    for (auto &attr : joiner_output_attrs_) {
      memset(&attr, 0, sizeof(attr));
      attr.index = i;
      ret =
          rknn_query(joiner_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
      SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for joiner output %d", i);
      i += 1;
    }

    if (config_.debug) {
      std::ostringstream os;
      std::string sep;
      for (auto &attr : joiner_output_attrs_) {
        os << sep << ToString(attr);
        sep = "\n";
      }
      SHERPA_ONNX_LOGE("\n----------Joiner outputs info----------\n%s",
                       os.str().c_str());
    }

    vocab_size_ = joiner_output_attrs_[0].dims[1];
    if (config_.debug) {
      SHERPA_ONNX_LOGE("vocab_size: %d", vocab_size_);
    }
  }

 private:
  OnlineModelConfig config_;
  rknn_context encoder_ctx_ = 0;
  rknn_context decoder_ctx_ = 0;
  rknn_context joiner_ctx_ = 0;

  std::vector<rknn_tensor_attr> encoder_input_attrs_;
  std::vector<rknn_tensor_attr> encoder_output_attrs_;

  std::vector<rknn_tensor_attr> decoder_input_attrs_;
  std::vector<rknn_tensor_attr> decoder_output_attrs_;

  std::vector<rknn_tensor_attr> joiner_input_attrs_;
  std::vector<rknn_tensor_attr> joiner_output_attrs_;

  std::vector<int32_t> encoder_dims_;
  std::vector<int32_t> attention_dims_;
  std::vector<int32_t> num_encoder_layers_;
  std::vector<int32_t> cnn_module_kernels_;
  std::vector<int32_t> left_context_len_;

  int32_t T_ = 0;
  int32_t decode_chunk_len_ = 0;

  int32_t context_size_ = 2;
  int32_t vocab_size_ = 0;
};

OnlineZipformerTransducerModelRknn::~OnlineZipformerTransducerModelRknn() =
    default;

OnlineZipformerTransducerModelRknn::OnlineZipformerTransducerModelRknn(
    const OnlineModelConfig &config)
    : impl_(std::make_unique<Impl>(config)) {}

template <typename Manager>
OnlineZipformerTransducerModelRknn::OnlineZipformerTransducerModelRknn(
    Manager *mgr, const OnlineModelConfig &config)
    : impl_(std::make_unique<OnlineZipformerTransducerModelRknn>(mgr, config)) {
}

std::vector<std::vector<uint8_t>>
OnlineZipformerTransducerModelRknn::GetEncoderInitStates() const {
  return impl_->GetEncoderInitStates();
}

std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>>
OnlineZipformerTransducerModelRknn::RunEncoder(
    std::vector<float> features,
    std::vector<std::vector<uint8_t>> states) const {
  return impl_->RunEncoder(std::move(features), std::move(states));
}

std::vector<float> OnlineZipformerTransducerModelRknn::RunDecoder(
    std::vector<int64_t> decoder_input) const {
  return impl_->RunDecoder(std::move(decoder_input));
}

std::vector<float> OnlineZipformerTransducerModelRknn::RunJoiner(
    const float *encoder_out, const float *decoder_out) const {
  return impl_->RunJoiner(encoder_out, decoder_out);
}

int32_t OnlineZipformerTransducerModelRknn::ContextSize() const {
  return impl_->ContextSize();
}

int32_t OnlineZipformerTransducerModelRknn::ChunkSize() const {
  return impl_->ChunkSize();
}

int32_t OnlineZipformerTransducerModelRknn::ChunkShift() const {
  return impl_->ChunkShift();
}

int32_t OnlineZipformerTransducerModelRknn::VocabSize() const {
  return impl_->VocabSize();
}

rknn_tensor_attr OnlineZipformerTransducerModelRknn::GetEncoderOutAttr() const {
  return impl_->GetEncoderOutAttr();
}

#if __ANDROID_API__ >= 9
template OnlineZipformerTransducerModelRknn::OnlineZipformerTransducerModelRknn(
    AAssetManager *mgr, const OnlineModelConfig &config);
#endif

#if __OHOS__
template OnlineZipformerTransducerModelRknn::OnlineZipformerTransducerModelRknn(
    NativeResourceManager *mgr, const OnlineModelConfig &config);
#endif

}  // namespace sherpa_onnx
@ -0,0 +1,57 @@
// sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h
//
// Copyright (c) 2025 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_RKNN_H_
#define SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_RKNN_H_

#include <memory>
#include <utility>
#include <vector>

#include "rknn_api.h"  // NOLINT
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model.h"

namespace sherpa_onnx {

// this is for zipformer v1, i.e., the folder
// pruned_transducer_stateless7_streaming from icefall
class OnlineZipformerTransducerModelRknn {
 public:
  ~OnlineZipformerTransducerModelRknn();

  explicit OnlineZipformerTransducerModelRknn(const OnlineModelConfig &config);

  template <typename Manager>
  OnlineZipformerTransducerModelRknn(Manager *mgr,
                                     const OnlineModelConfig &config);

  std::vector<std::vector<uint8_t>> GetEncoderInitStates() const;

  std::pair<std::vector<float>, std::vector<std::vector<uint8_t>>> RunEncoder(
      std::vector<float> features,
      std::vector<std::vector<uint8_t>> states) const;

  std::vector<float> RunDecoder(std::vector<int64_t> decoder_input) const;

  std::vector<float> RunJoiner(const float *encoder_out,
                               const float *decoder_out) const;

  int32_t ContextSize() const;

  int32_t ChunkSize() const;

  int32_t ChunkShift() const;

  int32_t VocabSize() const;

  rknn_tensor_attr GetEncoderOutAttr() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_RKNN_H_
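Editor's note: a minimal construction sketch, not part of this commit, showing how the model class above is typically set up. The file paths and config values are placeholders; on RK3588, config.num_threads selects the NPU core mask (1 = auto, 0 = core 0, -1 = core 1, -2 = core 2, -3 = cores 0+1, -4 = cores 0+1+2) as implemented in the .cc file earlier in this diff.

#include <memory>

#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h"

std::unique_ptr<sherpa_onnx::OnlineZipformerTransducerModelRknn>
MakeRknnModelSketch() {
  sherpa_onnx::OnlineModelConfig config;
  // Placeholder paths; point these at your exported .rknn files.
  config.transducer.encoder = "./encoder.rknn";
  config.transducer.decoder = "./decoder.rknn";
  config.transducer.joiner = "./joiner.rknn";
  config.num_threads = 1;  // 1 maps to RKNN_NPU_CORE_AUTO on RK3588
  config.debug = true;     // print tensor attrs and model metadata

  auto model =
      std::make_unique<sherpa_onnx::OnlineZipformerTransducerModelRknn>(config);

  // The encoder states must be seeded once, e.g. via GetEncoderInitStates(),
  // before the first chunk is passed to RunEncoder().
  return model;
}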
@ -17,6 +17,7 @@
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"

@ -12,6 +12,7 @@
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h"

@ -17,6 +17,7 @@
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"

@ -17,6 +17,7 @@
#include "rawfile/raw_file_manager.h"
#endif

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"

@ -10,6 +10,7 @@
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h"