From 4d79e6a00722daf1d9cdfdf8e6f9ee262a80a8c4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 24 Feb 2025 19:07:37 +0800 Subject: [PATCH] Add C++ API for streaming zipformer ASR on RK NPU (#1908) --- .../workflows/aarch64-linux-gnu-shared.yaml | 6 +- .../workflows/aarch64-linux-gnu-static.yaml | 6 +- .github/workflows/android-static.yaml | 8 +- .github/workflows/arm-linux-gnueabihf.yaml | 6 +- .github/workflows/dot-net.yaml | 6 +- .github/workflows/linux-jni-aarch64.yaml | 8 +- .github/workflows/linux-jni.yaml | 8 +- .github/workflows/macos-jni.yaml | 6 +- .github/workflows/riscv64-linux.yaml | 6 +- .github/workflows/rknn-linux-aarch64.yaml | 241 ++++++ .github/workflows/windows-x64-jni.yaml | 6 +- .github/workflows/windows-x64.yaml | 6 +- .github/workflows/windows-x86.yaml | 4 +- CMakeLists.txt | 8 +- build-aarch64-linux-gnu.sh | 2 +- build-rknn-linux-aarch64.sh | 101 +++ scripts/dotnet/run.sh | 4 +- sherpa-onnx/csrc/CMakeLists.txt | 16 + sherpa-onnx/csrc/file-utils.cc | 58 ++ sherpa-onnx/csrc/file-utils.h | 21 + sherpa-onnx/csrc/hifigan-vocoder.cc | 1 + sherpa-onnx/csrc/offline-ced-model.cc | 1 + .../csrc/offline-ct-transformer-model.cc | 1 + sherpa-onnx/csrc/offline-ctc-model.cc | 1 + .../csrc/offline-fire-red-asr-model.cc | 1 + sherpa-onnx/csrc/offline-moonshine-model.cc | 1 + .../csrc/offline-nemo-enc-dec-ctc-model.cc | 1 + sherpa-onnx/csrc/offline-paraformer-model.cc | 1 + sherpa-onnx/csrc/offline-recognizer-impl.cc | 2 +- sherpa-onnx/csrc/offline-rnn-lm.cc | 1 + sherpa-onnx/csrc/offline-sense-voice-model.cc | 1 + ...ine-speaker-segmentation-pyannote-model.cc | 1 + sherpa-onnx/csrc/offline-tdnn-ctc-model.cc | 1 + .../csrc/offline-telespeech-ctc-model.cc | 1 + sherpa-onnx/csrc/offline-transducer-model.cc | 1 + .../csrc/offline-transducer-nemo-model.cc | 1 + sherpa-onnx/csrc/offline-tts-kokoro-model.cc | 1 + sherpa-onnx/csrc/offline-tts-matcha-model.cc | 1 + sherpa-onnx/csrc/offline-tts-vits-model.cc | 1 + sherpa-onnx/csrc/offline-wenet-ctc-model.cc | 1 + sherpa-onnx/csrc/offline-whisper-model.cc | 1 + .../offline-zipformer-audio-tagging-model.cc | 1 + .../csrc/offline-zipformer-ctc-model.cc | 1 + sherpa-onnx/csrc/online-cnn-bilstm-model.cc | 1 + .../csrc/online-conformer-transducer-model.cc | 1 + .../csrc/online-lstm-transducer-model.cc | 1 + sherpa-onnx/csrc/online-model-config.cc | 17 +- sherpa-onnx/csrc/online-nemo-ctc-model.cc | 1 + sherpa-onnx/csrc/online-paraformer-model.cc | 1 + sherpa-onnx/csrc/online-recognizer-impl.cc | 23 +- sherpa-onnx/csrc/online-rnn-lm.cc | 1 + sherpa-onnx/csrc/online-stream.h | 3 +- sherpa-onnx/csrc/online-transducer-model.cc | 1 + .../csrc/online-transducer-nemo-model.cc | 1 + sherpa-onnx/csrc/online-wenet-ctc-model.cc | 1 + .../csrc/online-zipformer-transducer-model.cc | 1 + .../csrc/online-zipformer2-ctc-model.cc | 1 + .../online-zipformer2-transducer-model.cc | 1 + sherpa-onnx/csrc/onnx-utils.cc | 66 +- sherpa-onnx/csrc/onnx-utils.h | 20 - sherpa-onnx/csrc/rknn/macros.h | 19 + .../online-recognizer-transducer-rknn-impl.h | 231 ++++++ sherpa-onnx/csrc/rknn/online-stream-rknn.cc | 60 ++ sherpa-onnx/csrc/rknn/online-stream-rknn.h | 38 + ...e-transducer-greedy-search-decoder-rknn.cc | 94 +++ ...ne-transducer-greedy-search-decoder-rknn.h | 52 ++ .../online-zipformer-transducer-model-rknn.cc | 781 ++++++++++++++++++ .../online-zipformer-transducer-model-rknn.h | 57 ++ sherpa-onnx/csrc/silero-vad-model.cc | 1 + .../csrc/speaker-embedding-extractor-impl.cc | 1 + .../csrc/speaker-embedding-extractor-model.cc | 1 + .../speaker-embedding-extractor-nemo-model.cc | 1 + .../spoken-language-identification-impl.cc | 1 + 73 files changed, 1909 insertions(+), 120 deletions(-) create mode 100644 .github/workflows/rknn-linux-aarch64.yaml create mode 100755 build-rknn-linux-aarch64.sh create mode 100644 sherpa-onnx/csrc/rknn/macros.h create mode 100644 sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h create mode 100644 sherpa-onnx/csrc/rknn/online-stream-rknn.cc create mode 100644 sherpa-onnx/csrc/rknn/online-stream-rknn.h create mode 100644 sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.cc create mode 100644 sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h create mode 100644 sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc create mode 100644 sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h diff --git a/.github/workflows/aarch64-linux-gnu-shared.yaml b/.github/workflows/aarch64-linux-gnu-shared.yaml index 8ea8ccde..5747ad09 100644 --- a/.github/workflows/aarch64-linux-gnu-shared.yaml +++ b/.github/workflows/aarch64-linux-gnu-shared.yaml @@ -199,6 +199,7 @@ jobs: timeout_seconds: 200 shell: bash command: | + SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" @@ -207,9 +208,10 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - mkdir -p aarch64 + dst=aarch64/$SHERPA_ONNX_VERSION + mkdir -p $dst - cp -v ../sherpa-onnx-*-shared*.tar.bz2 ./aarch64 + cp -v ../sherpa-onnx-*-shared*.tar.bz2 $dst/ git status git lfs track "*.bz2" diff --git a/.github/workflows/aarch64-linux-gnu-static.yaml b/.github/workflows/aarch64-linux-gnu-static.yaml index 48b0bbe2..376d4f51 100644 --- a/.github/workflows/aarch64-linux-gnu-static.yaml +++ b/.github/workflows/aarch64-linux-gnu-static.yaml @@ -124,6 +124,7 @@ jobs: timeout_seconds: 200 shell: bash command: | + SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" @@ -132,9 +133,10 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - mkdir -p aarch64 + dst=aarch64/$SHERPA_ONNX_VERSION + mkdir -p $dst - cp -v ../sherpa-onnx-*-static.tar.bz2 ./aarch64 + cp -v ../sherpa-onnx-*-static.tar.bz2 $dst/ git status git lfs track "*.bz2" diff --git a/.github/workflows/android-static.yaml b/.github/workflows/android-static.yaml index 7dad8128..957c51e8 100644 --- a/.github/workflows/android-static.yaml +++ b/.github/workflows/android-static.yaml @@ -140,6 +140,7 @@ jobs: timeout_seconds: 200 shell: bash command: | + SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" du -h -d1 . @@ -150,8 +151,10 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface + dst=$SHERPA_ONNX_VERSION + mkdir -p $dst - cp -v ../sherpa-onnx-*-android*.tar.bz2 ./ + cp -v ../sherpa-onnx-*-android*.tar.bz2 $dst/ git status git lfs track "*.bz2" @@ -263,6 +266,7 @@ jobs: timeout_seconds: 200 shell: bash command: | + SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" du -h -d1 . @@ -273,7 +277,7 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - dst=android/aar + dst=android/aar/$SHERPA_ONNX_VERSION mkdir -p $dst cp -v ../*.aar $dst diff --git a/.github/workflows/arm-linux-gnueabihf.yaml b/.github/workflows/arm-linux-gnueabihf.yaml index 63a5cf41..6b1aa5c4 100644 --- a/.github/workflows/arm-linux-gnueabihf.yaml +++ b/.github/workflows/arm-linux-gnueabihf.yaml @@ -199,6 +199,7 @@ jobs: timeout_seconds: 200 shell: bash command: | + SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" @@ -206,9 +207,10 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - mkdir -p arm32 + dst=arm32/$SHERPA_ONNX_VERSION + mkdir -p $dst - cp -v ../sherpa-onnx-*.tar.bz2 ./arm32 + cp -v ../sherpa-onnx-*.tar.bz2 $dst/ git status git lfs track "*.bz2" diff --git a/.github/workflows/dot-net.yaml b/.github/workflows/dot-net.yaml index 899cb999..54fcdc16 100644 --- a/.github/workflows/dot-net.yaml +++ b/.github/workflows/dot-net.yaml @@ -83,6 +83,7 @@ jobs: timeout_seconds: 200 shell: bash command: | + SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" @@ -95,9 +96,10 @@ jobs: cd huggingface git fetch git pull - mkdir -p windows-for-dotnet + dst=windows-for-dotnet/$SHERPA_ONNX_VERSION + mkdir -p $dst - cp -v ../sherpa-onnx-*.tar.bz2 ./windows-for-dotnet + cp -v ../sherpa-onnx-*.tar.bz2 $dst/ git status git lfs track "*.bz2" diff --git a/.github/workflows/linux-jni-aarch64.yaml b/.github/workflows/linux-jni-aarch64.yaml index 4b1df3f5..1877b2a6 100644 --- a/.github/workflows/linux-jni-aarch64.yaml +++ b/.github/workflows/linux-jni-aarch64.yaml @@ -138,6 +138,7 @@ jobs: timeout_seconds: 200 shell: bash command: | + SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" @@ -146,10 +147,11 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - mkdir -p jni + dst=jni/$SHERPA_ONNX_VERSION + mkdir -p $dst - cp -v ../sherpa-onnx-*.tar.bz2 ./jni - cp -v ../*.jar ./jni + cp -v ../sherpa-onnx-*.tar.bz2 $dst/ + cp -v ../*.jar $dst/ git status git lfs track "*.bz2" diff --git a/.github/workflows/linux-jni.yaml b/.github/workflows/linux-jni.yaml index b56a6ccf..be2a3f7a 100644 --- a/.github/workflows/linux-jni.yaml +++ b/.github/workflows/linux-jni.yaml @@ -197,6 +197,7 @@ jobs: timeout_seconds: 200 shell: bash command: | + SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" @@ -205,10 +206,11 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - mkdir -p jni + dst=jni/$SHERPA_ONNX_VERSION + mkdir -p $dst - cp -v ../sherpa-onnx-*.tar.bz2 ./jni - cp -v ../*.jar ./jni + cp -v ../sherpa-onnx-*.tar.bz2 $dst/ + cp -v ../*.jar $dst/ git status git lfs track "*.bz2" diff --git a/.github/workflows/macos-jni.yaml b/.github/workflows/macos-jni.yaml index 80ba5c53..a59c1c38 100644 --- a/.github/workflows/macos-jni.yaml +++ b/.github/workflows/macos-jni.yaml @@ -113,6 +113,7 @@ jobs: timeout_seconds: 200 shell: bash command: | + SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" @@ -121,9 +122,10 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - mkdir -p jni + dst=jni/$SHERPA_ONNX_VERSION + mkdir -p $dst - cp -v ../sherpa-onnx-*.tar.bz2 ./jni + cp -v ../sherpa-onnx-*.tar.bz2 $dst git status git lfs track "*.bz2" diff --git a/.github/workflows/riscv64-linux.yaml b/.github/workflows/riscv64-linux.yaml index f81d5cb2..7f21f98c 100644 --- a/.github/workflows/riscv64-linux.yaml +++ b/.github/workflows/riscv64-linux.yaml @@ -239,6 +239,7 @@ jobs: timeout_seconds: 200 shell: bash command: | + SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" @@ -248,9 +249,10 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - mkdir -p riscv64 + dst=riscv64/$SHERPA_ONNX_VERSION + mkdir -p $dst - cp -v ../sherpa-onnx-*-shared.tar.bz2 ./riscv64 + cp -v ../sherpa-onnx-*-shared.tar.bz2 $dst/ git status git lfs track "*.bz2" diff --git a/.github/workflows/rknn-linux-aarch64.yaml b/.github/workflows/rknn-linux-aarch64.yaml new file mode 100644 index 00000000..32b2452f --- /dev/null +++ b/.github/workflows/rknn-linux-aarch64.yaml @@ -0,0 +1,241 @@ +name: rknn-linux-aarch64 + +on: + push: + branches: + - master + - rknn-c-api-2 + tags: + - 'v[0-9]+.[0-9]+.[0-9]+*' + paths: + - '.github/workflows/rknn-linux-aarch64.yaml' + - 'cmake/**' + - 'sherpa-onnx/csrc/*' + - 'sherpa-onnx/csrc/rknn/*' + - 'sherpa-onnx/c-api/*' + - 'toolchains/aarch64-linux-gnu.toolchain.cmake' + pull_request: + branches: + - master + paths: + - '.github/workflows/rknn-linux-aarch64.yaml' + - 'cmake/**' + - 'sherpa-onnx/csrc/*' + - 'sherpa-onnx/csrc/rknn/*' + - 'sherpa-onnx/c-api/*' + - 'toolchains/aarch64-linux-gnu.toolchain.cmake' + + workflow_dispatch: + +concurrency: + group: aarch64-linux-gnu-shared-${{ github.ref }} + cancel-in-progress: true + +jobs: + rknn_linux_aarch64: + runs-on: ${{ matrix.os }} + name: rknn shared ${{ matrix.shared }} + strategy: + fail-fast: false + matrix: + include: + - os: ubuntu-22.04-arm + shared: ON + - os: ubuntu-22.04-arm + shared: OFF + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2 + with: + key: ${{ matrix.os }}-${{ matrix.shared }}-rknn-linux-aarch64 + + - name: Download rknn-toolkit2 + shell: bash + run: | + git clone --depth 1 https://github.com/airockchip/rknn-toolkit2 + + - name: Build sherpa-onnx + shell: bash + run: | + export CMAKE_CXX_COMPILER_LAUNCHER=ccache + export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH" + cmake --version + + echo "config: ${{ matrix.config }}" + uname -a + which gcc + + gcc --version + g++ --version + + echo "pwd" + + ls -lh + + git clone --depth 1 --branch v1.2.12 https://github.com/alsa-project/alsa-lib + pushd alsa-lib + ./gitcompile + popd + + export SHERPA_ONNX_RKNN_TOOLKIT2_PATH=$PWD/rknn-toolkit2 + export SHERPA_ONNX_RKNN_TOOLKIT2_LIB_DIR=$SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/aarch64 + export CPLUS_INCLUDE_PATH=$SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/include:$CPLUS_INCLUDE_PATH + export CPLUS_INCLUDE_PATH=$PWD/alsa-lib/include:$CPLUS_INCLUDE_PATH + export SHERPA_ONNX_ALSA_LIB_DIR=$PWD/alsa-lib/src/.libs + + mkdir build + cd build + + BUILD_SHARED_LIBS=${{ matrix.shared }} + + cmake \ + -DBUILD_SHARED_LIBS=ON \ + -DCMAKE_INSTALL_PREFIX=./install \ + -DSHERPA_ONNX_ENABLE_RKNN=ON \ + -DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \ + .. + + make -j4 install + + rm -rf install/lib/pkgconfig + rm -fv install/lib/cargs.h + rm -fv install/lib/libcargs.so + + - name: Display system info + shell: bash + run: | + uname -a + gcc --version + g++ --version + + - name: Display generated files + shell: bash + run: | + export SHERPA_ONNX_RKNN_TOOLKIT2_PATH=$PWD/rknn-toolkit2 + export LD_LIBRARY_PATH=$SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/aarch64:$LD_LIBRARY_PATH + + cd build/install + + ls -lh bin + + echo "---" + + ls -lh lib + + file bin/sherpa-onnx + + readelf -d bin/sherpa-onnx + + ldd bin/sherpa-onnx + + ./bin/sherpa-onnx --help + + - name: Copy files + shell: bash + run: | + SHERPA_ONNX_VERSION=v$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) + + if [[ ${{ matrix.shared }} == ON ]]; then + suffix=shared + else + suffix=static + fi + + dst=sherpa-onnx-${SHERPA_ONNX_VERSION}-rknn-linux-aarch64-$suffix + mkdir $dst + + cp -a build/install/bin $dst/ + + if [[ ${{ matrix.shared }} == ON ]]; then + cp -v build/install/lib/lib*.so $dst/ + fi + + ls -lh build/install/lib + ls -lh build/install/bin + + ls -lh $dst/bin/ + echo "strip" + strip $dst/bin/* + + echo "after strip" + ls -lh $dst/bin/ + + tree $dst + + tar cjvf ${dst}.tar.bz2 $dst + + - uses: actions/upload-artifact@v4 + with: + name: sherpa-onnx-linux-linux-aarch64-shared-${{ matrix.shared }} + path: sherpa-onnx-*linux-aarch64*.tar.bz2 + + # https://huggingface.co/docs/hub/spaces-github-actions + - name: Publish to huggingface + if: (github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa') && (github.event_name == 'push' || github.event_name == 'workflow_dispatch') + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + uses: nick-fields/retry@v3 + with: + max_attempts: 20 + timeout_seconds: 200 + shell: bash + command: | + SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) + + git config --global user.email "csukuangfj@gmail.com" + git config --global user.name "Fangjun Kuang" + + rm -rf huggingface + export GIT_CLONE_PROTECTION_ACTIVE=false + GIT_LFS_SKIP_SMUDGE=1 git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs huggingface + + cd huggingface + dst=rknn-linux-aarch64/$SHERPA_ONNX_VERSION + mkdir -p $dst + + cp -v ../sherpa-onnx-*rknn*-*.tar.bz2 $dst + + git status + git lfs track "*.bz2" + + git add . + + git commit -m "upload sherpa-onnx-${SHERPA_ONNX_VERSION}-rknn-linux-aarch64.tar.bz2" + + git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-libs main + + - name: Release pre-compiled binaries and libs for rknn linux aarch64 + if: github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + overwrite: true + file: sherpa-onnx-*linux-aarch64*.tar.bz2 + + - name: Release pre-compiled binaries and libs for rknn linux aarch64 + # if: github.repository_owner == 'csukuangfj' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + overwrite: true + file: sherpa-onnx-*linux-aarch64*.tar.bz2 + repo_name: k2-fsa/sherpa-onnx + repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }} + tag: v1.10.45 + + - name: Test offline Moonshine + if: matrix.build_type != 'Debug' + shell: bash + run: | + du -h -d1 . + export PATH=$PWD/build/install/bin:$PATH + export EXE=sherpa-onnx-offline + + readelf -d build/bin/sherpa-onnx-offline + + .github/scripts/test-offline-moonshine.sh diff --git a/.github/workflows/windows-x64-jni.yaml b/.github/workflows/windows-x64-jni.yaml index 6bb801b9..99cecf15 100644 --- a/.github/workflows/windows-x64-jni.yaml +++ b/.github/workflows/windows-x64-jni.yaml @@ -92,6 +92,7 @@ jobs: timeout_seconds: 200 shell: bash command: | + SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" @@ -100,9 +101,10 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - mkdir -p jni + dst=jni/$SHERPA_ONNX_VERSION + mkdir -p $dst - cp -v ../sherpa-onnx-*.tar.bz2 ./jni + cp -v ../sherpa-onnx-*.tar.bz2 $dst git status git lfs track "*.bz2" diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index 6cf9f579..498f0b5c 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -128,6 +128,7 @@ jobs: timeout_seconds: 200 shell: bash command: | + SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" @@ -136,9 +137,10 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - mkdir -p win64 + dst=win64/$SHERPA_ONNX_VERSION + mkdir -p $dst - cp -v ../sherpa-onnx-*.tar.bz2 ./win64 + cp -v ../sherpa-onnx-*.tar.bz2 $dst git status git lfs track "*.bz2" diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index b69c8244..2dedc05e 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -131,6 +131,7 @@ jobs: timeout_seconds: 200 shell: bash command: | + SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2) git config --global user.email "csukuangfj@gmail.com" git config --global user.name "Fangjun Kuang" @@ -139,7 +140,8 @@ jobs: GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-libs huggingface cd huggingface - mkdir -p win32 + dst=win32/$SHERPA_ONNX_VERSION + mkdir -p $dst cp -v ../sherpa-onnx-*.tar.bz2 ./win32 diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f7998ce..49e0e77a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,6 +45,7 @@ option(SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY "True to link libstdc++ statically. option(SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE "True to use pre-installed onnxruntime if available" ON) option(SHERPA_ONNX_ENABLE_SANITIZER "Whether to enable ubsan and asan" OFF) option(SHERPA_ONNX_BUILD_C_API_EXAMPLES "Whether to enable C API examples" ON) +option(SHERPA_ONNX_ENABLE_RKNN "Whether to build for RKNN NPU " OFF) set(SHERPA_ONNX_LINUX_ARM64_GPU_ONNXRUNTIME_VERSION "1.11.0" CACHE STRING "Used only for Linux ARM64 GPU. If you use Jetson nano b01, then please set it to 1.11.0. If you use Jetson Orin NX, then set it to 1.16.0.If you use NVIDIA Jetson Orin Nano Engineering Reference Developer Kit Super - Jetpack 6.2 [L4T 36.4.3], then set it to 1.18.1") @@ -155,6 +156,7 @@ message(STATUS "SHERPA_ONNX_LINK_LIBSTDCPP_STATICALLY ${SHERPA_ONNX_LINK_LIBSTDC message(STATUS "SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE ${SHERPA_ONNX_USE_PRE_INSTALLED_ONNXRUNTIME_IF_AVAILABLE}") message(STATUS "SHERPA_ONNX_ENABLE_SANITIZER: ${SHERPA_ONNX_ENABLE_SANITIZER}") message(STATUS "SHERPA_ONNX_BUILD_C_API_EXAMPLES: ${SHERPA_ONNX_BUILD_C_API_EXAMPLES}") +message(STATUS "SHERPA_ONNX_ENABLE_RKNN: ${SHERPA_ONNX_ENABLE_RKNN}") if(BUILD_SHARED_LIBS OR SHERPA_ONNX_ENABLE_JNI) set(CMAKE_CXX_VISIBILITY_PRESET hidden) @@ -199,7 +201,7 @@ if(SHERPA_ONNX_ENABLE_DIRECTML) message(STATUS "DirectML is enabled") add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=1) else() - message(WARNING "DirectML is disabled") + message(STATUS "DirectML is disabled") add_definitions(-DSHERPA_ONNX_ENABLE_DIRECTML=0) endif() @@ -267,6 +269,10 @@ message(STATUS "C++ Standard version: ${CMAKE_CXX_STANDARD}") include(CheckIncludeFileCXX) +if(SHERPA_ONNX_ENABLE_RKNN) + add_definitions(-DSHERPA_ONNX_ENABLE_RKNN=1) +endif() + if(UNIX AND NOT APPLE AND NOT SHERPA_ONNX_ENABLE_WASM AND NOT CMAKE_SYSTEM_NAME STREQUAL Android AND NOT CMAKE_SYSTEM_NAME STREQUAL OHOS) check_include_file_cxx(alsa/asoundlib.h SHERPA_ONNX_HAS_ALSA) if(SHERPA_ONNX_HAS_ALSA) diff --git a/build-aarch64-linux-gnu.sh b/build-aarch64-linux-gnu.sh index a5e1a8d1..86abcc94 100755 --- a/build-aarch64-linux-gnu.sh +++ b/build-aarch64-linux-gnu.sh @@ -74,7 +74,7 @@ if [[ x"$BUILD_SHARED_LIBS" == x"" ]]; then fi if [[ x"$SHERPA_ONNX_ENABLE_GPU" == x"" ]]; then - # By default, use CPU + # By default, don't use CPU SHERPA_ONNX_ENABLE_GPU=OFF fi diff --git a/build-rknn-linux-aarch64.sh b/build-rknn-linux-aarch64.sh new file mode 100755 index 00000000..bf664aa8 --- /dev/null +++ b/build-rknn-linux-aarch64.sh @@ -0,0 +1,101 @@ +#!/usr/bin/env bash + +set -ex + +# Before you run this file, make sure you have first cloned +# https://github.com/airockchip/rknn-toolkit2 +# and set the environment variable SHERPA_ONNX_RKNN_TOOLKIT2_PATH + +if [ -z $SHERPA_ONNX_RKNN_TOOLKIT2_PATH ]; then + SHERPA_ONNX_RKNN_TOOLKIT2_PATH=/star-fj/fangjun/open-source/rknn-toolkit2 +fi + +if [ ! -d $SHERPA_ONNX_RKNN_TOOLKIT2_PATH ]; then + echo "Please first clone https://github.com/airockchip/rknn-toolkit2" + echo "You can use" + echo " git clone --depth 1 https://github.com/airockchip/rknn-toolkit2" + echo " export SHERPA_ONNX_RKNN_TOOLKIT2_PATH=$PWD/rknn-toolkit2" + + exit 1 +fi + +if [ ! -f $SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/include/rknn_api.h ]; then + echo "$SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/include/rknn_api.h does not exist" + exit 1 +fi + +if [ ! -f $SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/aarch64/librknnrt.so ]; then + echo "$SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/aarch64/librknnrt.so does not exist" + exit 1 +fi + +export SHERPA_ONNX_RKNN_TOOLKIT2_LIB_DIR=$SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/aarch64 + +export CPLUS_INCLUDE_PATH=$SHERPA_ONNX_RKNN_TOOLKIT2_PATH/rknpu2/runtime/Linux/librknn_api/include:$CPLUS_INCLUDE_PATH + +if ! command -v aarch64-linux-gnu-gcc &> /dev/null; then + echo "Please install a toolchain for cross-compiling." + echo "You can refer to: " + echo " https://k2-fsa.github.io/sherpa/onnx/install/rknn-linux-aarch64.html" + echo "for help." + exit 1 +fi + + +dir=$PWD/build-rknn-linux-aarch64 +mkdir -p $dir + +cd $dir + +if [ ! -f alsa-lib/src/.libs/libasound.so ]; then + echo "Start to cross-compile alsa-lib" + if [ ! -d alsa-lib ]; then + git clone --depth 1 --branch v1.2.12 https://github.com/alsa-project/alsa-lib + fi + # If it shows: + # ./gitcompile: line 79: libtoolize: command not found + # Please use: + # sudo apt-get install libtool m4 automake + # + pushd alsa-lib + CC=aarch64-linux-gnu-gcc ./gitcompile --host=aarch64-linux-gnu + popd + echo "Finish cross-compiling alsa-lib" +fi + +export CPLUS_INCLUDE_PATH=$PWD/alsa-lib/include:$CPLUS_INCLUDE_PATH +export SHERPA_ONNX_ALSA_LIB_DIR=$PWD/alsa-lib/src/.libs + +if [[ x"$BUILD_SHARED_LIBS" == x"" ]]; then + # By default, use shared link + BUILD_SHARED_LIBS=ON +fi + +cmake \ + -DBUILD_PIPER_PHONMIZE_EXE=OFF \ + -DBUILD_PIPER_PHONMIZE_TESTS=OFF \ + -DBUILD_ESPEAK_NG_EXE=OFF \ + -DBUILD_ESPEAK_NG_TESTS=OFF \ + -DCMAKE_INSTALL_PREFIX=./install \ + -DCMAKE_BUILD_TYPE=Release \ + -DSHERPA_ONNX_ENABLE_GPU=OFF \ + -DBUILD_SHARED_LIBS=$BUILD_SHARED_LIBS \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DSHERPA_ONNX_ENABLE_CHECK=OFF \ + -DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \ + -DSHERPA_ONNX_ENABLE_JNI=OFF \ + -DSHERPA_ONNX_ENABLE_C_API=ON \ + -DSHERPA_ONNX_ENABLE_WEBSOCKET=ON \ + -DSHERPA_ONNX_ENABLE_RKNN=ON \ + -DCMAKE_TOOLCHAIN_FILE=../toolchains/aarch64-linux-gnu.toolchain.cmake \ + .. + +make VERBOSE=1 -j4 +make install/strip + +# Enable it if only needed +# cp -v $SHERPA_ONNX_ALSA_LIB_DIR/libasound.so* ./install/lib/ + +# See also +# https://github.com/airockchip/rknn-toolkit2/blob/master/rknpu2/examples/rknn_api_demo/build-linux.sh diff --git a/scripts/dotnet/run.sh b/scripts/dotnet/run.sh index 8fe448a9..dd525546 100755 --- a/scripts/dotnet/run.sh +++ b/scripts/dotnet/run.sh @@ -150,7 +150,7 @@ if [ ! -f $src_dir/windows-x86/sherpa-onnx-c-api.dll ]; then if [ -f $windows_x86_wheel ]; then cp -v $windows_x86_wheel . else - curl -OL https://$HF_MIRROR/csukuangfj/sherpa-onnx-libs/resolve/main/windows-for-dotnet/$windows_x86_wheel_filename + curl -OL https://$HF_MIRROR/csukuangfj/sherpa-onnx-libs/resolve/main/windows-for-dotnet/$SHERPA_ONNX_VERSION/$windows_x86_wheel_filename fi tar xvf $windows_x86_wheel_filename cp -v sherpa-onnx-${SHERPA_ONNX_VERSION}-win-x86/*dll ../ @@ -169,7 +169,7 @@ if [ ! -f $src_dir/windows-arm64/sherpa-onnx-c-api.dll ]; then if [ -f $windows_arm64_wheel ]; then cp -v $windows_arm64_wheel . else - curl -OL https://$HF_MIRROR/csukuangfj/sherpa-onnx-libs/resolve/main/windows-for-dotnet/$windows_arm64_wheel_filename + curl -OL https://$HF_MIRROR/csukuangfj/sherpa-onnx-libs/resolve/main/windows-for-dotnet/$SHERPA_ONNX_VERSION/$windows_arm64_wheel_filename fi tar xvf $windows_arm64_wheel_filename cp -v sherpa-onnx-${SHERPA_ONNX_VERSION}-win-arm64/*dll ../ diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 40f0ee6c..aa50c4ab 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -151,6 +151,14 @@ list(APPEND sources online-punctuation-model-config.cc online-punctuation.cc ) +if(SHERPA_ONNX_ENABLE_RKNN) + list(APPEND sources + ./rknn/online-stream-rknn.cc + ./rknn/online-transducer-greedy-search-decoder-rknn.cc + ./rknn/online-zipformer-transducer-model-rknn.cc + ) + +endif() if(SHERPA_ONNX_ENABLE_TTS) list(APPEND sources @@ -230,6 +238,14 @@ if(SHERPA_ONNX_ENABLE_GPU) ) endif() +if(SHERPA_ONNX_ENABLE_RKNN) + if(DEFINED ENV{SHERPA_ONNX_RKNN_TOOLKIT2_LIB_DIR}) + target_link_libraries(sherpa-onnx-core -L$ENV{SHERPA_ONNX_RKNN_TOOLKIT2_LIB_DIR} -lrknnrt) + else() + target_link_libraries(sherpa-onnx-core rknnrt) + endif() +endif() + if(BUILD_SHARED_LIBS AND NOT DEFINED onnxruntime_lib_files) target_link_libraries(sherpa-onnx-core onnxruntime) else() diff --git a/sherpa-onnx/csrc/file-utils.cc b/sherpa-onnx/csrc/file-utils.cc index 8d87fd19..bc7881f2 100644 --- a/sherpa-onnx/csrc/file-utils.cc +++ b/sherpa-onnx/csrc/file-utils.cc @@ -5,6 +5,7 @@ #include "sherpa-onnx/csrc/file-utils.h" #include +#include #include #include "sherpa-onnx/csrc/macros.h" @@ -22,4 +23,61 @@ void AssertFileExists(const std::string &filename) { } } +std::vector ReadFile(const std::string &filename) { + std::ifstream input(filename, std::ios::binary); + std::vector buffer(std::istreambuf_iterator(input), {}); + return buffer; +} + +#if __ANDROID_API__ >= 9 +std::vector ReadFile(AAssetManager *mgr, const std::string &filename) { + AAsset *asset = AAssetManager_open(mgr, filename.c_str(), AASSET_MODE_BUFFER); + if (!asset) { + __android_log_print(ANDROID_LOG_FATAL, "sherpa-onnx", + "Read binary file: Load %s failed", filename.c_str()); + exit(-1); + } + + auto p = reinterpret_cast(AAsset_getBuffer(asset)); + size_t asset_length = AAsset_getLength(asset); + + std::vector buffer(p, p + asset_length); + AAsset_close(asset); + + return buffer; +} +#endif + +#if __OHOS__ +std::vector ReadFile(NativeResourceManager *mgr, + const std::string &filename) { + std::unique_ptr fp( + OH_ResourceManager_OpenRawFile(mgr, filename.c_str()), + OH_ResourceManager_CloseRawFile); + + if (!fp) { + std::ostringstream os; + os << "Read file '" << filename << "' failed."; + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + return {}; + } + + auto len = static_cast(OH_ResourceManager_GetRawFileSize(fp.get())); + + std::vector buffer(len); + + int32_t n = OH_ResourceManager_ReadRawFile(fp.get(), buffer.data(), len); + + if (n != len) { + std::ostringstream os; + os << "Read file '" << filename << "' failed. Number of bytes read: " << n + << ". Expected bytes to read: " << len; + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + return {}; + } + + return buffer; +} +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/file-utils.h b/sherpa-onnx/csrc/file-utils.h index a41f6c9c..27167e7f 100644 --- a/sherpa-onnx/csrc/file-utils.h +++ b/sherpa-onnx/csrc/file-utils.h @@ -7,6 +7,16 @@ #include #include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif namespace sherpa_onnx { @@ -23,6 +33,17 @@ bool FileExists(const std::string &filename); */ void AssertFileExists(const std::string &filename); +std::vector ReadFile(const std::string &filename); + +#if __ANDROID_API__ >= 9 +std::vector ReadFile(AAssetManager *mgr, const std::string &filename); +#endif + +#if __OHOS__ +std::vector ReadFile(NativeResourceManager *mgr, + const std::string &filename); +#endif + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_FILE_UTILS_H_ diff --git a/sherpa-onnx/csrc/hifigan-vocoder.cc b/sherpa-onnx/csrc/hifigan-vocoder.cc index b2ff2078..d6edae71 100644 --- a/sherpa-onnx/csrc/hifigan-vocoder.cc +++ b/sherpa-onnx/csrc/hifigan-vocoder.cc @@ -17,6 +17,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-ced-model.cc b/sherpa-onnx/csrc/offline-ced-model.cc index d6dd3529..241f03b3 100644 --- a/sherpa-onnx/csrc/offline-ced-model.cc +++ b/sherpa-onnx/csrc/offline-ced-model.cc @@ -7,6 +7,7 @@ #include #include +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" #include "sherpa-onnx/csrc/text-utils.h" diff --git a/sherpa-onnx/csrc/offline-ct-transformer-model.cc b/sherpa-onnx/csrc/offline-ct-transformer-model.cc index d616484b..ee016285 100644 --- a/sherpa-onnx/csrc/offline-ct-transformer-model.cc +++ b/sherpa-onnx/csrc/offline-ct-transformer-model.cc @@ -7,6 +7,7 @@ #include #include +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" #include "sherpa-onnx/csrc/text-utils.h" diff --git a/sherpa-onnx/csrc/offline-ctc-model.cc b/sherpa-onnx/csrc/offline-ctc-model.cc index 6ca5f005..10748829 100644 --- a/sherpa-onnx/csrc/offline-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-ctc-model.cc @@ -18,6 +18,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" diff --git a/sherpa-onnx/csrc/offline-fire-red-asr-model.cc b/sherpa-onnx/csrc/offline-fire-red-asr-model.cc index bf453994..f2103892 100644 --- a/sherpa-onnx/csrc/offline-fire-red-asr-model.cc +++ b/sherpa-onnx/csrc/offline-fire-red-asr-model.cc @@ -20,6 +20,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-moonshine-model.cc b/sherpa-onnx/csrc/offline-moonshine-model.cc index dbd18a92..7c66b351 100644 --- a/sherpa-onnx/csrc/offline-moonshine-model.cc +++ b/sherpa-onnx/csrc/offline-moonshine-model.cc @@ -17,6 +17,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc index 18db415b..7759c101 100644 --- a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc @@ -13,6 +13,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-paraformer-model.cc b/sherpa-onnx/csrc/offline-paraformer-model.cc index c8c65c8c..5b8586ef 100644 --- a/sherpa-onnx/csrc/offline-paraformer-model.cc +++ b/sherpa-onnx/csrc/offline-paraformer-model.cc @@ -17,6 +17,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index 59e74d4c..b74bbbbb 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -22,6 +22,7 @@ #include "fst/extensions/far/far.h" #include "kaldifst/csrc/kaldi-fst-io.h" #include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-fire-red-asr-impl.h" @@ -31,7 +32,6 @@ #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-transducer-nemo-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h" -#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { diff --git a/sherpa-onnx/csrc/offline-rnn-lm.cc b/sherpa-onnx/csrc/offline-rnn-lm.cc index 665b775b..8f9425da 100644 --- a/sherpa-onnx/csrc/offline-rnn-lm.cc +++ b/sherpa-onnx/csrc/offline-rnn-lm.cc @@ -18,6 +18,7 @@ #endif #include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-sense-voice-model.cc b/sherpa-onnx/csrc/offline-sense-voice-model.cc index 04e7cd22..95664e22 100644 --- a/sherpa-onnx/csrc/offline-sense-voice-model.cc +++ b/sherpa-onnx/csrc/offline-sense-voice-model.cc @@ -17,6 +17,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc index 093e871b..53671ec4 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc @@ -17,6 +17,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc b/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc index de441c48..ca23210f 100644 --- a/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc @@ -15,6 +15,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-telespeech-ctc-model.cc b/sherpa-onnx/csrc/offline-telespeech-ctc-model.cc index d87e47a0..f6b2574b 100644 --- a/sherpa-onnx/csrc/offline-telespeech-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-telespeech-ctc-model.cc @@ -13,6 +13,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-transducer-model.cc b/sherpa-onnx/csrc/offline-transducer-model.cc index da519cc2..a08854fe 100644 --- a/sherpa-onnx/csrc/offline-transducer-model.cc +++ b/sherpa-onnx/csrc/offline-transducer-model.cc @@ -17,6 +17,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-transducer-decoder.h" #include "sherpa-onnx/csrc/onnx-utils.h" diff --git a/sherpa-onnx/csrc/offline-transducer-nemo-model.cc b/sherpa-onnx/csrc/offline-transducer-nemo-model.cc index bd6f1ab5..6fbb2b61 100644 --- a/sherpa-onnx/csrc/offline-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/offline-transducer-nemo-model.cc @@ -18,6 +18,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-transducer-decoder.h" #include "sherpa-onnx/csrc/onnx-utils.h" diff --git a/sherpa-onnx/csrc/offline-tts-kokoro-model.cc b/sherpa-onnx/csrc/offline-tts-kokoro-model.cc index 9f77207b..96b14f79 100644 --- a/sherpa-onnx/csrc/offline-tts-kokoro-model.cc +++ b/sherpa-onnx/csrc/offline-tts-kokoro-model.cc @@ -18,6 +18,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-tts-matcha-model.cc b/sherpa-onnx/csrc/offline-tts-matcha-model.cc index fccb9012..b90b1302 100644 --- a/sherpa-onnx/csrc/offline-tts-matcha-model.cc +++ b/sherpa-onnx/csrc/offline-tts-matcha-model.cc @@ -18,6 +18,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-tts-vits-model.cc b/sherpa-onnx/csrc/offline-tts-vits-model.cc index 3587a109..d8bf7325 100644 --- a/sherpa-onnx/csrc/offline-tts-vits-model.cc +++ b/sherpa-onnx/csrc/offline-tts-vits-model.cc @@ -18,6 +18,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-wenet-ctc-model.cc b/sherpa-onnx/csrc/offline-wenet-ctc-model.cc index 5a939717..31ffe5b6 100644 --- a/sherpa-onnx/csrc/offline-wenet-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-wenet-ctc-model.cc @@ -13,6 +13,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-whisper-model.cc b/sherpa-onnx/csrc/offline-whisper-model.cc index 360374cd..5443d37a 100644 --- a/sherpa-onnx/csrc/offline-whisper-model.cc +++ b/sherpa-onnx/csrc/offline-whisper-model.cc @@ -20,6 +20,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc index 7ddf6d9b..f464074e 100644 --- a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc +++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc @@ -7,6 +7,7 @@ #include #include +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" #include "sherpa-onnx/csrc/text-utils.h" diff --git a/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc b/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc index 8cfa30c2..356537e9 100644 --- a/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc @@ -15,6 +15,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/online-cnn-bilstm-model.cc b/sherpa-onnx/csrc/online-cnn-bilstm-model.cc index f4fb3c8f..2ca270ca 100644 --- a/sherpa-onnx/csrc/online-cnn-bilstm-model.cc +++ b/sherpa-onnx/csrc/online-cnn-bilstm-model.cc @@ -7,6 +7,7 @@ #include #include +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" #include "sherpa-onnx/csrc/text-utils.h" diff --git a/sherpa-onnx/csrc/online-conformer-transducer-model.cc b/sherpa-onnx/csrc/online-conformer-transducer-model.cc index 519d1a93..17ad7b25 100644 --- a/sherpa-onnx/csrc/online-conformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-conformer-transducer-model.cc @@ -23,6 +23,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/cat.h" +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/onnx-utils.h" diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc index 91b499fd..9cb84bfd 100644 --- a/sherpa-onnx/csrc/online-lstm-transducer-model.cc +++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc @@ -22,6 +22,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/cat.h" +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/onnx-utils.h" diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc index 5592c8d0..3f51a7f1 100644 --- a/sherpa-onnx/csrc/online-model-config.cc +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -51,9 +51,20 @@ void OnlineModelConfig::Register(ParseOptions *po) { } bool OnlineModelConfig::Validate() const { - if (num_threads < 1) { - SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); - return false; + // For RK NPU, we reinterpret num_threads: + // + // For RK3588 only + // num_threads == 1 -> Select a core randomly + // num_threads == 0 -> Use NPU core 0 + // num_threads == -1 -> Use NPU core 1 + // num_threads == -2 -> Use NPU core 2 + // num_threads == -3 -> Use NPU core 0 and core 1 + // num_threads == -4 -> Use NPU core 0, core 1, and core 2 + if (provider_config.provider != "rknn") { + if (num_threads < 1) { + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); + return false; + } } if (!tokens_buf.empty() && FileExists(tokens)) { diff --git a/sherpa-onnx/csrc/online-nemo-ctc-model.cc b/sherpa-onnx/csrc/online-nemo-ctc-model.cc index 716c7ee7..bbe4f1c7 100644 --- a/sherpa-onnx/csrc/online-nemo-ctc-model.cc +++ b/sherpa-onnx/csrc/online-nemo-ctc-model.cc @@ -18,6 +18,7 @@ #endif #include "sherpa-onnx/csrc/cat.h" +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/online-paraformer-model.cc b/sherpa-onnx/csrc/online-paraformer-model.cc index b21bb9bc..e75b70a4 100644 --- a/sherpa-onnx/csrc/online-paraformer-model.cc +++ b/sherpa-onnx/csrc/online-paraformer-model.cc @@ -17,6 +17,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index 652ed211..810e0a17 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -1,6 +1,6 @@ // sherpa-onnx/csrc/online-recognizer-impl.cc // -// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2023-2025 Xiaomi Corporation #include "sherpa-onnx/csrc/online-recognizer-impl.h" @@ -26,10 +26,31 @@ #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/text-utils.h" +#if SHERPA_ONNX_ENABLE_RKNN +#include "sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h" +#endif + namespace sherpa_onnx { std::unique_ptr OnlineRecognizerImpl::Create( const OnlineRecognizerConfig &config) { + if (config.model_config.provider_config.provider == "rknn") { +#if SHERPA_ONNX_ENABLE_RKNN + // Currently, only zipformer v1 is suported for rknn + if (config.model_config.transducer.encoder.empty()) { + SHERPA_ONNX_LOGE( + "Only Zipformer transducers are currently supported by rknn. " + "Fallback to CPU"); + } else { + return std::make_unique(config); + } +#else + SHERPA_ONNX_LOGE( + "Please rebuild sherpa-onnx with -DSHERPA_ONNX_ENABLE_RKNN=ON if you " + "want to use rknn. Fallback to CPU"); +#endif + } + if (!config.model_config.transducer.encoder.empty()) { Ort::Env env(ORT_LOGGING_LEVEL_ERROR); diff --git a/sherpa-onnx/csrc/online-rnn-lm.cc b/sherpa-onnx/csrc/online-rnn-lm.cc index 2a44ddbe..4e5261ce 100644 --- a/sherpa-onnx/csrc/online-rnn-lm.cc +++ b/sherpa-onnx/csrc/online-rnn-lm.cc @@ -11,6 +11,7 @@ #include #include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h index d1183b29..71600db6 100644 --- a/sherpa-onnx/csrc/online-stream.h +++ b/sherpa-onnx/csrc/online-stream.h @@ -23,7 +23,8 @@ class OnlineStream { public: explicit OnlineStream(const FeatureExtractorConfig &config = {}, ContextGraphPtr context_graph = nullptr); - ~OnlineStream(); + + virtual ~OnlineStream(); /** @param sampling_rate The sampling_rate of the input waveform. If it does diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc index 9ebe4037..66838225 100644 --- a/sherpa-onnx/csrc/online-transducer-model.cc +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -18,6 +18,7 @@ #include #include +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-conformer-transducer-model.h" #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" diff --git a/sherpa-onnx/csrc/online-transducer-nemo-model.cc b/sherpa-onnx/csrc/online-transducer-nemo-model.cc index 73c23fe3..a656b74a 100644 --- a/sherpa-onnx/csrc/online-transducer-nemo-model.cc +++ b/sherpa-onnx/csrc/online-transducer-nemo-model.cc @@ -25,6 +25,7 @@ #endif #include "sherpa-onnx/csrc/cat.h" +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/onnx-utils.h" diff --git a/sherpa-onnx/csrc/online-wenet-ctc-model.cc b/sherpa-onnx/csrc/online-wenet-ctc-model.cc index bf468484..9024481d 100644 --- a/sherpa-onnx/csrc/online-wenet-ctc-model.cc +++ b/sherpa-onnx/csrc/online-wenet-ctc-model.cc @@ -17,6 +17,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc index e572feed..7dfaa30a 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc @@ -23,6 +23,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/cat.h" +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/onnx-utils.h" diff --git a/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc b/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc index 298b9052..1cda9a62 100644 --- a/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc +++ b/sherpa-onnx/csrc/online-zipformer2-ctc-model.cc @@ -20,6 +20,7 @@ #endif #include "sherpa-onnx/csrc/cat.h" +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc index bd79cfc4..9370c6f8 100644 --- a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc @@ -25,6 +25,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/cat.h" +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-transducer-decoder.h" #include "sherpa-onnx/csrc/onnx-utils.h" diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 6d7e2684..a32f92f9 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -13,15 +13,8 @@ #include #include -#include "sherpa-onnx/csrc/macros.h" - -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#include "android/log.h" -#endif - #include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/macros.h" namespace sherpa_onnx { @@ -305,63 +298,6 @@ void Print4D(const Ort::Value *v) { fprintf(stderr, "\n"); } -std::vector ReadFile(const std::string &filename) { - std::ifstream input(filename, std::ios::binary); - std::vector buffer(std::istreambuf_iterator(input), {}); - return buffer; -} - -#if __ANDROID_API__ >= 9 -std::vector ReadFile(AAssetManager *mgr, const std::string &filename) { - AAsset *asset = AAssetManager_open(mgr, filename.c_str(), AASSET_MODE_BUFFER); - if (!asset) { - __android_log_print(ANDROID_LOG_FATAL, "sherpa-onnx", - "Read binary file: Load %s failed", filename.c_str()); - exit(-1); - } - - auto p = reinterpret_cast(AAsset_getBuffer(asset)); - size_t asset_length = AAsset_getLength(asset); - - std::vector buffer(p, p + asset_length); - AAsset_close(asset); - - return buffer; -} -#endif - -#if __OHOS__ -std::vector ReadFile(NativeResourceManager *mgr, - const std::string &filename) { - std::unique_ptr fp( - OH_ResourceManager_OpenRawFile(mgr, filename.c_str()), - OH_ResourceManager_CloseRawFile); - - if (!fp) { - std::ostringstream os; - os << "Read file '" << filename << "' failed."; - SHERPA_ONNX_LOGE("%s", os.str().c_str()); - return {}; - } - - auto len = static_cast(OH_ResourceManager_GetRawFileSize(fp.get())); - - std::vector buffer(len); - - int32_t n = OH_ResourceManager_ReadRawFile(fp.get(), buffer.data(), len); - - if (n != len) { - std::ostringstream os; - os << "Read file '" << filename << "' failed. Number of bytes read: " << n - << ". Expected bytes to read: " << len; - SHERPA_ONNX_LOGE("%s", os.str().c_str()); - return {}; - } - - return buffer; -} -#endif - Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, const std::vector &hyps_num_split) { std::vector cur_encoder_out_shape = diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index c978fbb7..e695532a 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -17,15 +17,6 @@ #include #include -#if __ANDROID_API__ >= 9 -#include "android/asset_manager.h" -#include "android/asset_manager_jni.h" -#endif - -#if __OHOS__ -#include "rawfile/raw_file_manager.h" -#endif - #include "onnxruntime_cxx_api.h" // NOLINT namespace sherpa_onnx { @@ -101,17 +92,6 @@ void Fill(Ort::Value *tensor, T value) { std::fill(p, p + n, value); } -std::vector ReadFile(const std::string &filename); - -#if __ANDROID_API__ >= 9 -std::vector ReadFile(AAssetManager *mgr, const std::string &filename); -#endif - -#if __OHOS__ -std::vector ReadFile(NativeResourceManager *mgr, - const std::string &filename); -#endif - // TODO(fangjun): Document it Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out, const std::vector &hyps_num_split); diff --git a/sherpa-onnx/csrc/rknn/macros.h b/sherpa-onnx/csrc/rknn/macros.h new file mode 100644 index 00000000..f4afade7 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/macros.h @@ -0,0 +1,19 @@ +// sherpa-onnx/csrc/macros.h +// +// Copyright 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_RKNN_MACROS_H_ +#define SHERPA_ONNX_CSRC_RKNN_MACROS_H_ + +#include "sherpa-onnx/csrc/macros.h" + +#define SHERPA_ONNX_RKNN_CHECK(ret, msg, ...) \ + do { \ + if (ret != RKNN_SUCC) { \ + SHERPA_ONNX_LOGE("Return code is: %d", ret); \ + SHERPA_ONNX_LOGE(msg, ##__VA_ARGS__); \ + SHERPA_ONNX_EXIT(-1); \ + } \ + } while (0) + +#endif // SHERPA_ONNX_CSRC_RKNN_MACROS_H_ diff --git a/sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h b/sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h new file mode 100644 index 00000000..9cc29507 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h @@ -0,0 +1,231 @@ +// sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_TRANSDUCER_RKNN_IMPL_H_ +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_TRANSDUCER_RKNN_IMPL_H_ + +#include +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-recognizer-impl.h" +#include "sherpa-onnx/csrc/online-recognizer.h" +#include "sherpa-onnx/csrc/rknn/online-stream-rknn.h" +#include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h" +#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" +#include "sherpa-onnx/csrc/symbol-table.h" + +namespace sherpa_onnx { + +OnlineRecognizerResult Convert(const OnlineTransducerDecoderResultRknn &src, + const SymbolTable &sym_table, + float frame_shift_ms, int32_t subsampling_factor, + int32_t segment, int32_t frames_since_start) { + OnlineRecognizerResult r; + r.tokens.reserve(src.tokens.size()); + r.timestamps.reserve(src.tokens.size()); + + std::string text; + for (auto i : src.tokens) { + auto sym = sym_table[i]; + + text.append(sym); + + if (sym.size() == 1 && (sym[0] < 0x20 || sym[0] > 0x7e)) { + // for bpe models with byte_fallback + // (but don't rewrite printable characters 0x20..0x7e, + // which collide with standard BPE units) + std::ostringstream os; + os << "<0x" << std::hex << std::uppercase + << (static_cast(sym[0]) & 0xff) << ">"; + sym = os.str(); + } + + r.tokens.push_back(std::move(sym)); + } + + if (sym_table.IsByteBpe()) { + text = sym_table.DecodeByteBpe(text); + } + + r.text = std::move(text); + + float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor; + for (auto t : src.timestamps) { + float time = frame_shift_s * t; + r.timestamps.push_back(time); + } + + r.segment = segment; + r.start_time = frames_since_start * frame_shift_ms / 1000.; + + return r; +} + +class OnlineRecognizerTransducerRknnImpl : public OnlineRecognizerImpl { + public: + explicit OnlineRecognizerTransducerRknnImpl( + const OnlineRecognizerConfig &config) + : OnlineRecognizerImpl(config), + config_(config), + endpoint_(config_.endpoint_config), + model_(std::make_unique( + config.model_config)) { + if (!config.model_config.tokens_buf.empty()) { + sym_ = SymbolTable(config.model_config.tokens_buf, false); + } else { + /// assuming tokens_buf and tokens are guaranteed not being both empty + sym_ = SymbolTable(config.model_config.tokens, true); + } + + if (sym_.Contains("")) { + unk_id_ = sym_[""]; + } + + decoder_ = std::make_unique( + model_.get(), unk_id_); + } + + template + explicit OnlineRecognizerTransducerRknnImpl( + Manager *mgr, const OnlineRecognizerConfig &config) + : OnlineRecognizerImpl(mgr, config), + config_(config), + endpoint_(config_.endpoint_config), + model_( + std::make_unique(mgr, config)) { + // TODO(fangjun): Support Android + } + + std::unique_ptr CreateStream() const override { + auto stream = std::make_unique(config_.feat_config); + auto r = decoder_->GetEmptyResult(); + stream->SetZipformerResult(std::move(r)); + stream->SetZipformerEncoderStates(model_->GetEncoderInitStates()); + return stream; + } + + std::unique_ptr CreateStream( + const std::string &hotwords) const override { + SHERPA_ONNX_LOGE("Hotwords for RKNN is not supported now."); + return CreateStream(); + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() + model_->ChunkSize() < + s->NumFramesReady(); + } + + // Warmping up engine with wp: warm_up count and max-batch-size + + void DecodeStreams(OnlineStream **ss, int32_t n) const override { + for (int32_t i = 0; i < n; ++i) { + DecodeStream(reinterpret_cast(ss[i])); + } + } + + OnlineRecognizerResult GetResult(OnlineStream *s) const override { + OnlineTransducerDecoderResultRknn decoder_result = + reinterpret_cast(s)->GetZipformerResult(); + decoder_->StripLeadingBlanks(&decoder_result); + // TODO(fangjun): Remember to change these constants if needed + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = 4; + auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, + s->GetCurrentSegment(), s->GetNumFramesSinceStart()); + r.text = ApplyInverseTextNormalization(std::move(r.text)); + return r; + } + + bool IsEndpoint(OnlineStream *s) const override { + if (!config_.enable_endpoint) { + return false; + } + + int32_t num_processed_frames = s->GetNumProcessedFrames(); + + // frame shift is 10 milliseconds + float frame_shift_in_seconds = 0.01; + + // subsampling factor is 4 + int32_t trailing_silence_frames = reinterpret_cast(s) + ->GetZipformerResult() + .num_trailing_blanks * + 4; + + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, + frame_shift_in_seconds); + } + + void Reset(OnlineStream *s) const override { + int32_t context_size = model_->ContextSize(); + + { + // segment is incremented only when the last + // result is not empty, contains non-blanks and longer than context_size) + const auto &r = + reinterpret_cast(s)->GetZipformerResult(); + if (!r.tokens.empty() && r.tokens.back() != 0 && + r.tokens.size() > context_size) { + s->GetCurrentSegment() += 1; + } + } + + // reset encoder states + // reinterpret_cast(s)->SetZipformerEncoderStates(model_->GetEncoderInitStates()); + auto r = decoder_->GetEmptyResult(); + auto last_result = + reinterpret_cast(s)->GetZipformerResult(); + + // if last result is not empty, then + // preserve last tokens as the context for next result + if (static_cast(last_result.tokens.size()) > context_size) { + r.tokens = {last_result.tokens.end() - context_size, + last_result.tokens.end()}; + } + reinterpret_cast(s)->SetZipformerResult(std::move(r)); + + // Note: We only update counters. The underlying audio samples + // are not discarded. + s->Reset(); + } + + private: + void DecodeStream(OnlineStreamRknn *s) const { + int32_t chunk_size = model_->ChunkSize(); + int32_t chunk_shift = model_->ChunkShift(); + + int32_t feature_dim = s->FeatureDim(); + + const auto num_processed_frames = s->GetNumProcessedFrames(); + + std::vector features = + s->GetFrames(num_processed_frames, chunk_size); + s->GetNumProcessedFrames() += chunk_shift; + + auto &states = s->GetZipformerEncoderStates(); + + auto p = model_->RunEncoder(features, std::move(states)); + states = std::move(p.second); + + auto &r = s->GetZipformerResult(); + decoder_->Decode(std::move(p.first), &r); + } + + private: + OnlineRecognizerConfig config_; + SymbolTable sym_; + Endpoint endpoint_; + int32_t unk_id_ = -1; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_TRANSDUCER_RKNN_IMPL_H_ diff --git a/sherpa-onnx/csrc/rknn/online-stream-rknn.cc b/sherpa-onnx/csrc/rknn/online-stream-rknn.cc new file mode 100644 index 00000000..f72dc822 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/online-stream-rknn.cc @@ -0,0 +1,60 @@ +// sherpa-onnx/csrc/rknn/online-stream-rknn.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/rknn/online-stream-rknn.h" + +#include +#include + +namespace sherpa_onnx { + +class OnlineStreamRknn::Impl { + public: + void SetZipformerEncoderStates(std::vector> states) { + states_ = std::move(states); + } + + std::vector> &GetZipformerEncoderStates() { + return states_; + } + + void SetZipformerResult(OnlineTransducerDecoderResultRknn r) { + result_ = std::move(r); + } + + OnlineTransducerDecoderResultRknn &GetZipformerResult() { return result_; } + + private: + std::vector> states_; + OnlineTransducerDecoderResultRknn result_; +}; + +OnlineStreamRknn::OnlineStreamRknn( + const FeatureExtractorConfig &config /*= {}*/, + ContextGraphPtr context_graph /*= nullptr*/) + : OnlineStream(config, context_graph), impl_(std::make_unique()) {} + +OnlineStreamRknn::~OnlineStreamRknn() = default; + +void OnlineStreamRknn::SetZipformerEncoderStates( + std::vector> states) const { + impl_->SetZipformerEncoderStates(std::move(states)); +} + +std::vector> &OnlineStreamRknn::GetZipformerEncoderStates() + const { + return impl_->GetZipformerEncoderStates(); +} + +void OnlineStreamRknn::SetZipformerResult( + OnlineTransducerDecoderResultRknn r) const { + impl_->SetZipformerResult(std::move(r)); +} + +OnlineTransducerDecoderResultRknn &OnlineStreamRknn::GetZipformerResult() + const { + return impl_->GetZipformerResult(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/rknn/online-stream-rknn.h b/sherpa-onnx/csrc/rknn/online-stream-rknn.h new file mode 100644 index 00000000..fe249d5b --- /dev/null +++ b/sherpa-onnx/csrc/rknn/online-stream-rknn.h @@ -0,0 +1,38 @@ +// sherpa-onnx/csrc/rknn/online-stream-rknn.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_STREAM_RKNN_H_ +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_STREAM_RKNN_H_ +#include +#include + +#include "rknn_api.h" // NOLINT +#include "sherpa-onnx/csrc/online-stream.h" +#include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h" + +namespace sherpa_onnx { + +class OnlineStreamRknn : public OnlineStream { + public: + explicit OnlineStreamRknn(const FeatureExtractorConfig &config = {}, + ContextGraphPtr context_graph = nullptr); + + ~OnlineStreamRknn(); + + void SetZipformerEncoderStates( + std::vector> states) const; + + std::vector> &GetZipformerEncoderStates() const; + + void SetZipformerResult(OnlineTransducerDecoderResultRknn r) const; + + OnlineTransducerDecoderResultRknn &GetZipformerResult() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_STREAM_RKNN_H_ diff --git a/sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.cc b/sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.cc new file mode 100644 index 00000000..62a59ed2 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.cc @@ -0,0 +1,94 @@ +// sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +OnlineTransducerDecoderResultRknn +OnlineTransducerGreedySearchDecoderRknn::GetEmptyResult() const { + int32_t context_size = model_->ContextSize(); + int32_t blank_id = 0; // always 0 + OnlineTransducerDecoderResultRknn r; + r.tokens.resize(context_size, -1); + r.tokens.back() = blank_id; + + return r; +} + +void OnlineTransducerGreedySearchDecoderRknn::StripLeadingBlanks( + OnlineTransducerDecoderResultRknn *r) const { + int32_t context_size = model_->ContextSize(); + + auto start = r->tokens.begin() + context_size; + auto end = r->tokens.end(); + + r->tokens = std::vector(start, end); +} + +void OnlineTransducerGreedySearchDecoderRknn::Decode( + std::vector encoder_out, + OnlineTransducerDecoderResultRknn *result) const { + auto &r = result[0]; + auto attr = model_->GetEncoderOutAttr(); + int32_t num_frames = attr.dims[1]; + int32_t encoder_out_dim = attr.dims[2]; + + int32_t vocab_size = model_->VocabSize(); + int32_t context_size = model_->ContextSize(); + + std::vector decoder_input; + std::vector decoder_out; + + if (r.previous_decoder_out.empty()) { + decoder_input = {r.tokens.begin() + (r.tokens.size() - context_size), + r.tokens.end()}; + decoder_out = model_->RunDecoder(std::move(decoder_input)); + + } else { + decoder_out = std::move(r.previous_decoder_out); + } + + const float *p_encoder_out = encoder_out.data(); + for (int32_t t = 0; t != num_frames; ++t) { + auto logit = model_->RunJoiner(p_encoder_out, decoder_out.data()); + p_encoder_out += encoder_out_dim; + + bool emitted = false; + if (blank_penalty_ > 0.0) { + logit[0] -= blank_penalty_; // assuming blank id is 0 + } + + auto y = static_cast(std::distance( + logit.data(), + std::max_element(logit.data(), logit.data() + vocab_size))); + // blank id is hardcoded to 0 + // also, it treats unk as blank + if (y != 0 && y != unk_id_) { + emitted = true; + r.tokens.push_back(y); + r.timestamps.push_back(t + r.frame_offset); + r.num_trailing_blanks = 0; + } else { + ++r.num_trailing_blanks; + } + + if (emitted) { + decoder_input = {r.tokens.begin() + (r.tokens.size() - context_size), + r.tokens.end()}; + decoder_out = model_->RunDecoder(std::move(decoder_input)); + } + } + + r.frame_offset += num_frames; + r.previous_decoder_out = std::move(decoder_out); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h b/sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h new file mode 100644 index 00000000..6699b75f --- /dev/null +++ b/sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h @@ -0,0 +1,52 @@ +// sherpa-onnx/csrc/rknn/online-transducer-greedy-search-decoder-rknn.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_RKNN_H_ +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_RKNN_H_ + +#include + +#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" + +namespace sherpa_onnx { + +struct OnlineTransducerDecoderResultRknn { + /// Number of frames after subsampling we have decoded so far + int32_t frame_offset = 0; + + /// The decoded token IDs so far + std::vector tokens; + + /// number of trailing blank frames decoded so far + int32_t num_trailing_blanks = 0; + + /// timestamps[i] contains the output frame index where tokens[i] is decoded. + std::vector timestamps; + + std::vector previous_decoder_out; +}; + +class OnlineTransducerGreedySearchDecoderRknn { + public: + explicit OnlineTransducerGreedySearchDecoderRknn( + OnlineZipformerTransducerModelRknn *model, int32_t unk_id = 2, + float blank_penalty = 0.0) + : model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {} + + OnlineTransducerDecoderResultRknn GetEmptyResult() const; + + void StripLeadingBlanks(OnlineTransducerDecoderResultRknn *r) const; + + void Decode(std::vector encoder_out, + OnlineTransducerDecoderResultRknn *result) const; + + private: + OnlineZipformerTransducerModelRknn *model_; // Not owned + int32_t unk_id_; + float blank_penalty_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_RKNN_H_ diff --git a/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc b/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc new file mode 100644 index 00000000..a918fa1b --- /dev/null +++ b/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc @@ -0,0 +1,781 @@ +// sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" + +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/rknn/macros.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +// chw -> hwc +static void Transpose(const float *src, int32_t n, int32_t channel, + int32_t height, int32_t width, float *dst) { + for (int32_t i = 0; i < n; ++i) { + for (int32_t h = 0; h < height; ++h) { + for (int32_t w = 0; w < width; ++w) { + for (int32_t c = 0; c < channel; ++c) { + // dst[h, w, c] = src[c, h, w] + dst[i * height * width * channel + h * width * channel + w * channel + + c] = src[i * height * width * channel + c * height * width + + h * width + w]; + } + } + } + } +} + +static std::string ToString(const rknn_tensor_attr &attr) { + std::ostringstream os; + os << "{"; + os << attr.index; + os << ", name: " << attr.name; + os << ", shape: ("; + std::string sep; + for (int32_t i = 0; i < static_cast(attr.n_dims); ++i) { + os << sep << attr.dims[i]; + sep = ","; + } + os << ")"; + os << ", n_elems: " << attr.n_elems; + os << ", size: " << attr.size; + os << ", fmt: " << get_format_string(attr.fmt); + os << ", type: " << get_type_string(attr.type); + os << ", pass_through: " << (attr.pass_through ? "true" : "false"); + os << "}"; + return os.str(); +} + +static std::unordered_map Parse( + const rknn_custom_string &custom_string) { + std::unordered_map ans; + std::vector fields; + SplitStringToVector(custom_string.string, ";", false, &fields); + + std::vector tmp; + for (const auto &f : fields) { + SplitStringToVector(f, "=", false, &tmp); + if (tmp.size() != 2) { + SHERPA_ONNX_LOGE("Invalid custom string %s for %s", custom_string.string, + f.c_str()); + SHERPA_ONNX_EXIT(-1); + } + ans[std::move(tmp[0])] = std::move(tmp[1]); + } + + return ans; +} + +class OnlineZipformerTransducerModelRknn::Impl { + public: + ~Impl() { + auto ret = rknn_destroy(encoder_ctx_); + if (ret != RKNN_SUCC) { + SHERPA_ONNX_LOGE("Failed to destroy the encoder context"); + } + + ret = rknn_destroy(decoder_ctx_); + if (ret != RKNN_SUCC) { + SHERPA_ONNX_LOGE("Failed to destroy the decoder context"); + } + + ret = rknn_destroy(joiner_ctx_); + if (ret != RKNN_SUCC) { + SHERPA_ONNX_LOGE("Failed to destroy the joiner context"); + } + } + + explicit Impl(const OnlineModelConfig &config) : config_(config) { + { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } + + // Now select which core to run for RK3588 + int32_t ret_encoder = RKNN_SUCC; + int32_t ret_decoder = RKNN_SUCC; + int32_t ret_joiner = RKNN_SUCC; + switch (config_.num_threads) { + case 1: + ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_AUTO); + ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_AUTO); + ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_AUTO); + break; + case 0: + ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0); + ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0); + ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0); + break; + case -1: + ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_1); + ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_1); + ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_1); + break; + case -2: + ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_2); + ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_2); + ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_2); + break; + case -3: + ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0_1); + ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0_1); + ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0_1); + break; + case -4: + ret_encoder = rknn_set_core_mask(encoder_ctx_, RKNN_NPU_CORE_0_1_2); + ret_decoder = rknn_set_core_mask(decoder_ctx_, RKNN_NPU_CORE_0_1_2); + ret_joiner = rknn_set_core_mask(joiner_ctx_, RKNN_NPU_CORE_0_1_2); + break; + default: + SHERPA_ONNX_LOGE( + "Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core " + "1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d", + config_.num_threads); + break; + } + if (ret_encoder != RKNN_SUCC) { + SHERPA_ONNX_LOGE( + "Failed to select npu core to run encoder (You can ignore it if you " + "are not using RK3588."); + } + + if (ret_decoder != RKNN_SUCC) { + SHERPA_ONNX_LOGE( + "Failed to select npu core to run decoder (You can ignore it if you " + "are not using RK3588."); + } + + if (ret_decoder != RKNN_SUCC) { + SHERPA_ONNX_LOGE( + "Failed to select npu core to run joiner (You can ignore it if you " + "are not using RK3588."); + } + } + + // TODO(fangjun): Support Android + + std::vector> GetEncoderInitStates() const { + // encoder_input_attrs_[0] is for the feature + // encoder_input_attrs_[1:] is for states + // so we use -1 here + std::vector> states(encoder_input_attrs_.size() - 1); + + int32_t i = -1; + for (auto &attr : encoder_input_attrs_) { + i += 1; + if (i == 0) { + // skip processing the attr for features. + continue; + } + + if (attr.type == RKNN_TENSOR_FLOAT16) { + states[i - 1].resize(attr.n_elems * sizeof(float)); + } else if (attr.type == RKNN_TENSOR_INT64) { + states[i - 1].resize(attr.n_elems * sizeof(int64_t)); + } else { + SHERPA_ONNX_LOGE("Unsupported tensor type: %d, %s", attr.type, + get_type_string(attr.type)); + SHERPA_ONNX_EXIT(-1); + } + } + + return states; + } + + std::pair, std::vector>> RunEncoder( + std::vector features, + std::vector> states) const { + std::vector inputs(encoder_input_attrs_.size()); + + for (int32_t i = 0; i < static_cast(inputs.size()); ++i) { + auto &input = inputs[i]; + auto &attr = encoder_input_attrs_[i]; + input.index = attr.index; + + if (attr.type == RKNN_TENSOR_FLOAT16) { + input.type = RKNN_TENSOR_FLOAT32; + } else if (attr.type == RKNN_TENSOR_INT64) { + input.type = RKNN_TENSOR_INT64; + } else { + SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type, + get_type_string(attr.type)); + SHERPA_ONNX_EXIT(-1); + } + + input.fmt = attr.fmt; + if (i == 0) { + input.buf = reinterpret_cast(features.data()); + input.size = features.size() * sizeof(float); + } else { + input.buf = reinterpret_cast(states[i - 1].data()); + input.size = states[i - 1].size(); + } + } + + std::vector encoder_out(encoder_output_attrs_[0].n_elems); + + // Note(fangjun): We can reuse the memory from input argument `states` + // auto next_states = GetEncoderInitStates(); + auto &next_states = states; + + std::vector outputs(encoder_output_attrs_.size()); + for (int32_t i = 0; i < outputs.size(); ++i) { + auto &output = outputs[i]; + auto &attr = encoder_output_attrs_[i]; + output.index = attr.index; + output.is_prealloc = 1; + + if (attr.type == RKNN_TENSOR_FLOAT16) { + output.want_float = 1; + } else if (attr.type == RKNN_TENSOR_INT64) { + output.want_float = 0; + } else { + SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type, + get_type_string(attr.type)); + SHERPA_ONNX_EXIT(-1); + } + + if (i == 0) { + output.size = encoder_out.size() * sizeof(float); + output.buf = reinterpret_cast(encoder_out.data()); + } else { + output.size = next_states[i - 1].size(); + output.buf = reinterpret_cast(next_states[i - 1].data()); + } + } + + auto ret = rknn_inputs_set(encoder_ctx_, inputs.size(), inputs.data()); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set encoder inputs"); + + ret = rknn_run(encoder_ctx_, nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run encoder"); + + ret = + rknn_outputs_get(encoder_ctx_, outputs.size(), outputs.data(), nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get encoder output"); + + for (int32_t i = 0; i < next_states.size(); ++i) { + const auto &attr = encoder_input_attrs_[i + 1]; + if (attr.n_dims == 4) { + // TODO(fangjun): The transpose is copied from + // https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22 + // I don't understand why we need to do that. + std::vector dst(next_states[i].size()); + int32_t n = attr.dims[0]; + int32_t h = attr.dims[1]; + int32_t w = attr.dims[2]; + int32_t c = attr.dims[3]; + Transpose(reinterpret_cast(next_states[i].data()), n, c, + h, w, reinterpret_cast(dst.data())); + next_states[i] = std::move(dst); + } + } + + return {std::move(encoder_out), std::move(next_states)}; + } + + std::vector RunDecoder(std::vector decoder_input) const { + auto &attr = decoder_input_attrs_[0]; + rknn_input input; + + input.index = 0; + input.type = RKNN_TENSOR_INT64; + input.fmt = attr.fmt; + input.buf = decoder_input.data(); + input.size = decoder_input.size() * sizeof(int64_t); + + std::vector decoder_out(decoder_output_attrs_[0].n_elems); + rknn_output output; + output.index = decoder_output_attrs_[0].index; + output.is_prealloc = 1; + output.want_float = 1; + output.size = decoder_out.size() * sizeof(float); + output.buf = decoder_out.data(); + + auto ret = rknn_inputs_set(decoder_ctx_, 1, &input); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set decoder inputs"); + + ret = rknn_run(decoder_ctx_, nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run decoder"); + + ret = rknn_outputs_get(decoder_ctx_, 1, &output, nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get decoder output"); + + return decoder_out; + } + + std::vector RunJoiner(const float *encoder_out, + const float *decoder_out) const { + std::vector inputs(2); + inputs[0].index = 0; + inputs[0].type = RKNN_TENSOR_FLOAT32; + inputs[0].fmt = joiner_input_attrs_[0].fmt; + inputs[0].buf = const_cast(encoder_out); + inputs[0].size = joiner_input_attrs_[0].n_elems * sizeof(float); + + inputs[1].index = 1; + inputs[1].type = RKNN_TENSOR_FLOAT32; + inputs[1].fmt = joiner_input_attrs_[1].fmt; + inputs[1].buf = const_cast(decoder_out); + inputs[1].size = joiner_input_attrs_[1].n_elems * sizeof(float); + + std::vector joiner_out(joiner_output_attrs_[0].n_elems); + rknn_output output; + output.index = joiner_output_attrs_[0].index; + output.is_prealloc = 1; + output.want_float = 1; + output.size = joiner_out.size() * sizeof(float); + output.buf = joiner_out.data(); + + auto ret = rknn_inputs_set(joiner_ctx_, inputs.size(), inputs.data()); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set joiner inputs"); + + ret = rknn_run(joiner_ctx_, nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run joiner"); + + ret = rknn_outputs_get(joiner_ctx_, 1, &output, nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get joiner output"); + + return joiner_out; + } + + int32_t ContextSize() const { return context_size_; } + + int32_t ChunkSize() const { return T_; } + + int32_t ChunkShift() const { return decode_chunk_len_; } + + int32_t VocabSize() const { return vocab_size_; } + + rknn_tensor_attr GetEncoderOutAttr() const { + return encoder_output_attrs_[0]; + } + + private: + void InitEncoder(void *model_data, size_t model_data_length) { + auto ret = + rknn_init(&encoder_ctx_, model_data, model_data_length, 0, nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init encoder '%s'", + config_.transducer.encoder.c_str()); + + if (config_.debug) { + rknn_sdk_version v; + ret = rknn_query(encoder_ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version"); + + SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version, + v.drv_version); + } + + rknn_input_output_num io_num; + ret = rknn_query(encoder_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, + sizeof(io_num)); + SHERPA_ONNX_RKNN_CHECK(ret, + "Failed to get I/O information for the encoder"); + + if (config_.debug) { + SHERPA_ONNX_LOGE("encoder: %d inputs, %d outputs", + static_cast(io_num.n_input), + static_cast(io_num.n_output)); + } + + encoder_input_attrs_.resize(io_num.n_input); + encoder_output_attrs_.resize(io_num.n_output); + + int32_t i = 0; + for (auto &attr : encoder_input_attrs_) { + memset(&attr, 0, sizeof(attr)); + attr.index = i; + ret = + rknn_query(encoder_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for encoder input %d", i); + i += 1; + } + + if (config_.debug) { + std::ostringstream os; + std::string sep; + for (auto &attr : encoder_input_attrs_) { + os << sep << ToString(attr); + sep = "\n"; + } + SHERPA_ONNX_LOGE("\n----------Encoder inputs info----------\n%s", + os.str().c_str()); + } + + i = 0; + for (auto &attr : encoder_output_attrs_) { + memset(&attr, 0, sizeof(attr)); + attr.index = i; + ret = + rknn_query(encoder_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for encoder output %d", + i); + i += 1; + } + + if (config_.debug) { + std::ostringstream os; + std::string sep; + for (auto &attr : encoder_output_attrs_) { + os << sep << ToString(attr); + sep = "\n"; + } + SHERPA_ONNX_LOGE("\n----------Encoder outputs info----------\n%s", + os.str().c_str()); + } + + rknn_custom_string custom_string; + ret = rknn_query(encoder_ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string, + sizeof(custom_string)); + SHERPA_ONNX_RKNN_CHECK( + ret, "Failed to read custom string from the encoder model"); + if (config_.debug) { + SHERPA_ONNX_LOGE("customs string: %s", custom_string.string); + } + auto meta = Parse(custom_string); + + for (const auto &p : meta) { + SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str()); + } + + if (meta.count("encoder_dims")) { + SplitStringToIntegers(meta.at("encoder_dims"), ",", false, + &encoder_dims_); + } + + if (meta.count("attention_dims")) { + SplitStringToIntegers(meta.at("attention_dims"), ",", false, + &attention_dims_); + } + + if (meta.count("num_encoder_layers")) { + SplitStringToIntegers(meta.at("num_encoder_layers"), ",", false, + &num_encoder_layers_); + } + + if (meta.count("cnn_module_kernels")) { + SplitStringToIntegers(meta.at("cnn_module_kernels"), ",", false, + &cnn_module_kernels_); + } + + if (meta.count("left_context_len")) { + SplitStringToIntegers(meta.at("left_context_len"), ",", false, + &left_context_len_); + } + + if (meta.count("T")) { + T_ = atoi(meta.at("T").c_str()); + } + + if (meta.count("decode_chunk_len")) { + decode_chunk_len_ = atoi(meta.at("decode_chunk_len").c_str()); + } + + if (meta.count("context_size")) { + context_size_ = atoi(meta.at("context_size").c_str()); + } + + if (config_.debug) { + auto print = [](const std::vector &v, const char *name) { + std::ostringstream os; + os << name << ": "; + for (auto i : v) { + os << i << " "; + } +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); +#endif + }; + print(encoder_dims_, "encoder_dims"); + print(attention_dims_, "attention_dims"); + print(num_encoder_layers_, "num_encoder_layers"); + print(cnn_module_kernels_, "cnn_module_kernels"); + print(left_context_len_, "left_context_len"); +#if __OHOS__ + SHERPA_ONNX_LOGE("T: %{public}d", T_); + SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_); + SHERPA_ONNX_LOGE("context_size: %{public}d", context_size_); +#else + SHERPA_ONNX_LOGE("T: %d", T_); + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); + SHERPA_ONNX_LOGE("context_size: %d", context_size_); +#endif + } + } + + void InitDecoder(void *model_data, size_t model_data_length) { + auto ret = + rknn_init(&decoder_ctx_, model_data, model_data_length, 0, nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init decoder '%s'", + config_.transducer.decoder.c_str()); + + rknn_input_output_num io_num; + ret = rknn_query(decoder_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, + sizeof(io_num)); + SHERPA_ONNX_RKNN_CHECK(ret, + "Failed to get I/O information for the decoder"); + + if (io_num.n_input != 1) { + SHERPA_ONNX_LOGE("Expect only 1 decoder input. Given %d", + static_cast(io_num.n_input)); + SHERPA_ONNX_EXIT(-1); + } + + if (io_num.n_output != 1) { + SHERPA_ONNX_LOGE("Expect only 1 decoder output. Given %d", + static_cast(io_num.n_output)); + SHERPA_ONNX_EXIT(-1); + } + + if (config_.debug) { + SHERPA_ONNX_LOGE("decoder: %d inputs, %d outputs", + static_cast(io_num.n_input), + static_cast(io_num.n_output)); + } + + decoder_input_attrs_.resize(io_num.n_input); + decoder_output_attrs_.resize(io_num.n_output); + + int32_t i = 0; + for (auto &attr : decoder_input_attrs_) { + memset(&attr, 0, sizeof(attr)); + attr.index = i; + ret = + rknn_query(decoder_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for decoder input %d", i); + i += 1; + } + + if (config_.debug) { + std::ostringstream os; + std::string sep; + for (auto &attr : decoder_input_attrs_) { + os << sep << ToString(attr); + sep = "\n"; + } + SHERPA_ONNX_LOGE("\n----------Decoder inputs info----------\n%s", + os.str().c_str()); + } + + if (decoder_input_attrs_[0].type != RKNN_TENSOR_INT64) { + SHERPA_ONNX_LOGE("Expect int64 for decoder input. Given: %d, %s", + decoder_input_attrs_[0].type, + get_type_string(decoder_input_attrs_[0].type)); + SHERPA_ONNX_EXIT(-1); + } + + i = 0; + for (auto &attr : decoder_output_attrs_) { + memset(&attr, 0, sizeof(attr)); + attr.index = i; + ret = + rknn_query(decoder_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for decoder output %d", + i); + i += 1; + } + + if (config_.debug) { + std::ostringstream os; + std::string sep; + for (auto &attr : decoder_output_attrs_) { + os << sep << ToString(attr); + sep = "\n"; + } + SHERPA_ONNX_LOGE("\n----------Decoder outputs info----------\n%s", + os.str().c_str()); + } + } + + void InitJoiner(void *model_data, size_t model_data_length) { + auto ret = + rknn_init(&joiner_ctx_, model_data, model_data_length, 0, nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init joiner '%s'", + config_.transducer.joiner.c_str()); + + rknn_input_output_num io_num; + ret = + rknn_query(joiner_ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the joiner"); + + if (config_.debug) { + SHERPA_ONNX_LOGE("joiner: %d inputs, %d outputs", + static_cast(io_num.n_input), + static_cast(io_num.n_output)); + } + + joiner_input_attrs_.resize(io_num.n_input); + joiner_output_attrs_.resize(io_num.n_output); + + int32_t i = 0; + for (auto &attr : joiner_input_attrs_) { + memset(&attr, 0, sizeof(attr)); + attr.index = i; + ret = rknn_query(joiner_ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for joiner input %d", i); + i += 1; + } + + if (config_.debug) { + std::ostringstream os; + std::string sep; + for (auto &attr : joiner_input_attrs_) { + os << sep << ToString(attr); + sep = "\n"; + } + SHERPA_ONNX_LOGE("\n----------Joiner inputs info----------\n%s", + os.str().c_str()); + } + + i = 0; + for (auto &attr : joiner_output_attrs_) { + memset(&attr, 0, sizeof(attr)); + attr.index = i; + ret = + rknn_query(joiner_ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for joiner output %d", i); + i += 1; + } + + if (config_.debug) { + std::ostringstream os; + std::string sep; + for (auto &attr : joiner_output_attrs_) { + os << sep << ToString(attr); + sep = "\n"; + } + SHERPA_ONNX_LOGE("\n----------Joiner outputs info----------\n%s", + os.str().c_str()); + } + + vocab_size_ = joiner_output_attrs_[0].dims[1]; + if (config_.debug) { + SHERPA_ONNX_LOGE("vocab_size: %d", vocab_size_); + } + } + + private: + OnlineModelConfig config_; + rknn_context encoder_ctx_ = 0; + rknn_context decoder_ctx_ = 0; + rknn_context joiner_ctx_ = 0; + + std::vector encoder_input_attrs_; + std::vector encoder_output_attrs_; + + std::vector decoder_input_attrs_; + std::vector decoder_output_attrs_; + + std::vector joiner_input_attrs_; + std::vector joiner_output_attrs_; + + std::vector encoder_dims_; + std::vector attention_dims_; + std::vector num_encoder_layers_; + std::vector cnn_module_kernels_; + std::vector left_context_len_; + + int32_t T_ = 0; + int32_t decode_chunk_len_ = 0; + + int32_t context_size_ = 2; + int32_t vocab_size_ = 0; +}; + +OnlineZipformerTransducerModelRknn::~OnlineZipformerTransducerModelRknn() = + default; + +OnlineZipformerTransducerModelRknn::OnlineZipformerTransducerModelRknn( + const OnlineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OnlineZipformerTransducerModelRknn::OnlineZipformerTransducerModelRknn( + Manager *mgr, const OnlineModelConfig &config) + : impl_(std::make_unique(mgr, config)) { +} + +std::vector> +OnlineZipformerTransducerModelRknn::GetEncoderInitStates() const { + return impl_->GetEncoderInitStates(); +} + +std::pair, std::vector>> +OnlineZipformerTransducerModelRknn::RunEncoder( + std::vector features, + std::vector> states) const { + return impl_->RunEncoder(std::move(features), std::move(states)); +} + +std::vector OnlineZipformerTransducerModelRknn::RunDecoder( + std::vector decoder_input) const { + return impl_->RunDecoder(std::move(decoder_input)); +} + +std::vector OnlineZipformerTransducerModelRknn::RunJoiner( + const float *encoder_out, const float *decoder_out) const { + return impl_->RunJoiner(encoder_out, decoder_out); +} + +int32_t OnlineZipformerTransducerModelRknn::ContextSize() const { + return impl_->ContextSize(); +} + +int32_t OnlineZipformerTransducerModelRknn::ChunkSize() const { + return impl_->ChunkSize(); +} + +int32_t OnlineZipformerTransducerModelRknn::ChunkShift() const { + return impl_->ChunkShift(); +} + +int32_t OnlineZipformerTransducerModelRknn::VocabSize() const { + return impl_->VocabSize(); +} + +rknn_tensor_attr OnlineZipformerTransducerModelRknn::GetEncoderOutAttr() const { + return impl_->GetEncoderOutAttr(); +} + +#if __ANDROID_API__ >= 9 +template OnlineZipformerTransducerModelRknn::OnlineZipformerTransducerModelRknn( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineZipformerTransducerModelRknn::OnlineZipformerTransducerModelRknn( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h b/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h new file mode 100644 index 00000000..bc821afa --- /dev/null +++ b/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h @@ -0,0 +1,57 @@ +// sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_RKNN_H_ +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_RKNN_H_ + +#include +#include +#include + +#include "rknn_api.h" // NOLINT +#include "sherpa-onnx/csrc/online-model-config.h" +#include "sherpa-onnx/csrc/online-transducer-model.h" + +namespace sherpa_onnx { + +// this is for zipformer v1, i.e., the folder +// pruned_transducer_statelss7_streaming from icefall +class OnlineZipformerTransducerModelRknn { + public: + ~OnlineZipformerTransducerModelRknn(); + + explicit OnlineZipformerTransducerModelRknn(const OnlineModelConfig &config); + + template + OnlineZipformerTransducerModelRknn(Manager *mgr, + const OnlineModelConfig &config); + + std::vector> GetEncoderInitStates() const; + + std::pair, std::vector>> RunEncoder( + std::vector features, + std::vector> states) const; + + std::vector RunDecoder(std::vector decoder_input) const; + + std::vector RunJoiner(const float *encoder_out, + const float *decoder_out) const; + + int32_t ContextSize() const; + + int32_t ChunkSize() const; + + int32_t ChunkShift() const; + + int32_t VocabSize() const; + + rknn_tensor_attr GetEncoderOutAttr() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_TRANSDUCER_MODEL_RKNN_H_ diff --git a/sherpa-onnx/csrc/silero-vad-model.cc b/sherpa-onnx/csrc/silero-vad-model.cc index 1b281e5d..bc4154f7 100644 --- a/sherpa-onnx/csrc/silero-vad-model.cc +++ b/sherpa-onnx/csrc/silero-vad-model.cc @@ -17,6 +17,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc index 650b1576..3abb21a2 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc @@ -12,6 +12,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h" diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-model.cc b/sherpa-onnx/csrc/speaker-embedding-extractor-model.cc index 48d7f19e..b2192431 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-model.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-model.cc @@ -17,6 +17,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc index 3983e1cb..66f79aac 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc @@ -17,6 +17,7 @@ #include "rawfile/raw_file_manager.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" diff --git a/sherpa-onnx/csrc/spoken-language-identification-impl.cc b/sherpa-onnx/csrc/spoken-language-identification-impl.cc index 5b29df48..1a287c45 100644 --- a/sherpa-onnx/csrc/spoken-language-identification-impl.cc +++ b/sherpa-onnx/csrc/spoken-language-identification-impl.cc @@ -10,6 +10,7 @@ #include "android/asset_manager_jni.h" #endif +#include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h"