mirror of
https://github.com/k2-fsa/sherpa-onnx.git
synced 2026-01-09 07:41:06 +08:00
Add UVR models for source separation. (#2266)
This commit is contained in:
parent
93e2819c18
commit
921f0f40cb
3
.github/workflows/as_cmake_sub_project.yaml
vendored
3
.github/workflows/as_cmake_sub_project.yaml
vendored
@ -4,9 +4,6 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
|
||||
98
.github/workflows/export-uvr-to-onnx.yaml
vendored
Normal file
98
.github/workflows/export-uvr-to-onnx.yaml
vendored
Normal file
@ -0,0 +1,98 @@
|
||||
name: export-uvr-to-onnx
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- uvr
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: export-uvr-to-onnx-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
export-uvr-to-onnx:
|
||||
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
|
||||
name: export UVR to ONNX
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos-latest]
|
||||
python-version: ["3.10"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
pip install "numpy<2" onnx==1.17.0 onnxruntime==1.17.1 onnxmltools kaldi-native-fbank librosa soundfile
|
||||
|
||||
- name: Run
|
||||
shell: bash
|
||||
run: |
|
||||
cd scripts/uvr_mdx
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/source-separation-models/audio_example.wav
|
||||
ls -lh audio_example.wav
|
||||
./run.sh
|
||||
|
||||
- name: Collect mp3 files
|
||||
shell: bash
|
||||
run: |
|
||||
mv -v scripts/uvr_mdx/*.mp3 ./
|
||||
ls -lh *.mp3
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: generated-mp3
|
||||
path: ./*.mp3
|
||||
|
||||
- name: Collect models
|
||||
shell: bash
|
||||
run: |
|
||||
mv -v scripts/uvr_mdx/models/*.onnx ./
|
||||
ls -lh *.onnx
|
||||
|
||||
- name: Release
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
file_glob: true
|
||||
file: ./*.onnx
|
||||
overwrite: true
|
||||
repo_name: k2-fsa/sherpa-onnx
|
||||
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||
tag: source-separation-models
|
||||
|
||||
- name: Publish to huggingface
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
max_attempts: 20
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
|
||||
export GIT_LFS_SKIP_SMUDGE=1
|
||||
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||
|
||||
rm -rf huggingface
|
||||
git clone https://huggingface.co/k2-fsa/sherpa-onnx-models huggingface
|
||||
cd huggingface
|
||||
mkdir -p source-separation-models
|
||||
cp -av ../*.onnx ./source-separation-models
|
||||
git lfs track "*.onnx"
|
||||
git status
|
||||
git add .
|
||||
ls -lh
|
||||
git status
|
||||
git commit -m "add source separation models"
|
||||
git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/sherpa-onnx-models main
|
||||
5
scripts/uvr_mdx/READEME.md
Normal file
5
scripts/uvr_mdx/READEME.md
Normal file
@ -0,0 +1,5 @@
|
||||
# Introduction
|
||||
|
||||
This folder contains scripts for converting models from
|
||||
https://github.com/TRvlvr/model_repo/releases/tag/all_public_uvr_models
|
||||
to sherpa-onnx.
|
||||
118
scripts/uvr_mdx/add_meta_data_and_quantize.py
Executable file
118
scripts/uvr_mdx/add_meta_data_and_quantize.py
Executable file
@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import onnx
|
||||
import onnxmltools
|
||||
import onnxruntime
|
||||
from onnxmltools.utils.float16_converter import convert_float_to_float16
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to onnx model",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
|
||||
onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
|
||||
onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
|
||||
onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
|
||||
|
||||
|
||||
def validate(model: onnxruntime.InferenceSession):
|
||||
for i in model.get_inputs():
|
||||
print(i)
|
||||
|
||||
print("-----")
|
||||
|
||||
for i in model.get_outputs():
|
||||
print(i)
|
||||
|
||||
assert len(model.get_inputs()) == 1, len(model.get_inputs())
|
||||
assert len(model.get_outputs()) == 1, len(model.get_outputs())
|
||||
|
||||
inp = model.get_inputs()[0]
|
||||
outp = model.get_outputs()[0]
|
||||
|
||||
assert len(inp.shape) == 4, inp.shape
|
||||
assert len(outp.shape) == 4, outp.shape
|
||||
|
||||
assert inp.shape[1:] == outp.shape[1:], (inp.shape, outp.shape)
|
||||
|
||||
|
||||
def add_meta_data(filename, meta_data):
|
||||
model = onnx.load(filename)
|
||||
|
||||
print(model.metadata_props)
|
||||
|
||||
while len(model.metadata_props):
|
||||
model.metadata_props.pop()
|
||||
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = str(value)
|
||||
print("--------------------")
|
||||
|
||||
print(model.metadata_props)
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
filename = Path(args.filename)
|
||||
if not filename.is_file():
|
||||
raise ValueError(f"{filename} does not exist")
|
||||
|
||||
name = filename.stem
|
||||
print("name", name)
|
||||
|
||||
model = onnx.load(str(filename))
|
||||
|
||||
session_opts = onnxruntime.SessionOptions()
|
||||
session_opts.log_severity_level = 3
|
||||
sess = onnxruntime.InferenceSession(
|
||||
str(filename), session_opts, providers=["CPUExecutionProvider"]
|
||||
)
|
||||
validate(sess)
|
||||
|
||||
inp = sess.get_inputs()[0]
|
||||
outp = sess.get_outputs()[0]
|
||||
|
||||
meta_data = {
|
||||
"model_type": "UVR",
|
||||
"model_name": name,
|
||||
"sample_rate": 44100,
|
||||
"comment": "This model is downloaded from https://github.com/TRvlvr/model_repo/releases",
|
||||
"n_fft": inp.shape[2] * 2,
|
||||
"center": 1,
|
||||
"window_type": "hann",
|
||||
"win_length": inp.shape[2] * 2,
|
||||
"hop_length": 1024,
|
||||
"dim_t": inp.shape[3],
|
||||
"dim_f": inp.shape[2],
|
||||
"dim_c": inp.shape[1],
|
||||
"stems": 2,
|
||||
}
|
||||
add_meta_data(str(filename), meta_data)
|
||||
|
||||
filename_fp16 = f"./{name}.fp16.onnx"
|
||||
export_onnx_fp16(filename, filename_fp16)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
54
scripts/uvr_mdx/run.sh
Executable file
54
scripts/uvr_mdx/run.sh
Executable file
@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env bash
|
||||
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
set -ex
|
||||
|
||||
|
||||
# Please see https://github.com/TRvlvr/model_repo/releases/tag/all_public_uvr_models
|
||||
models=(
|
||||
UVR-MDX-NET-Inst_1.onnx
|
||||
UVR-MDX-NET-Inst_2.onnx
|
||||
UVR-MDX-NET-Inst_3.onnx
|
||||
UVR-MDX-NET-Inst_HQ_1.onnx
|
||||
UVR-MDX-NET-Inst_HQ_2.onnx
|
||||
UVR-MDX-NET-Inst_HQ_3.onnx
|
||||
UVR-MDX-NET-Inst_HQ_4.onnx
|
||||
UVR-MDX-NET-Inst_HQ_5.onnx
|
||||
UVR-MDX-NET-Inst_Main.onnx
|
||||
UVR-MDX-NET-Voc_FT.onnx
|
||||
UVR-MDX-NET_Crowd_HQ_1.onnx
|
||||
UVR_MDXNET_1_9703.onnx
|
||||
UVR_MDXNET_2_9682.onnx
|
||||
UVR_MDXNET_3_9662.onnx
|
||||
UVR_MDXNET_9482.onnx
|
||||
UVR_MDXNET_KARA.onnx
|
||||
UVR_MDXNET_KARA_2.onnx
|
||||
UVR_MDXNET_Main.onnx
|
||||
)
|
||||
|
||||
mkdir -p models
|
||||
for m in ${models[@]}; do
|
||||
if [ ! -f models/$m ]; then
|
||||
curl -SL --output models/$m https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/$m
|
||||
fi
|
||||
done
|
||||
|
||||
ls -lh models
|
||||
|
||||
for m in ${models[@]}; do
|
||||
echo "----------$m----------"
|
||||
python3 ./add_meta_data_and_quantize.py --filename models/$m
|
||||
|
||||
ls -lh models/
|
||||
done
|
||||
|
||||
if [ -f ./audio_example.wav ]; then
|
||||
for m in ${models[@]}; do
|
||||
./test.py --model-filename ./models/$m --audio-filename ./audio_example.wav
|
||||
name=$(basename -s .onnx $m)
|
||||
mv -v vocals.mp3 ${name}_vocals.mp3
|
||||
mv -v non_vocals.mp3 ${name}_non_vocals.mp3
|
||||
done
|
||||
|
||||
ls -lh *.mp3
|
||||
fi
|
||||
39
scripts/uvr_mdx/show.py
Executable file
39
scripts/uvr_mdx/show.py
Executable file
@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import onnxruntime
|
||||
import onnx
|
||||
|
||||
"""
|
||||
[]
|
||||
NodeArg(name='input', type='tensor(float)', shape=['batch_size', 4, 3072, 256])
|
||||
-----
|
||||
NodeArg(name='output', type='tensor(float)', shape=['batch_size', 4, 3072, 256])
|
||||
"""
|
||||
|
||||
|
||||
def show(filename):
|
||||
model = onnx.load(filename)
|
||||
print(model.metadata_props)
|
||||
|
||||
session_opts = onnxruntime.SessionOptions()
|
||||
session_opts.log_severity_level = 3
|
||||
sess = onnxruntime.InferenceSession(
|
||||
filename, session_opts, providers=["CPUExecutionProvider"]
|
||||
)
|
||||
for i in sess.get_inputs():
|
||||
print(i)
|
||||
|
||||
print("-----")
|
||||
|
||||
for i in sess.get_outputs():
|
||||
print(i)
|
||||
|
||||
|
||||
def main():
|
||||
# show("./UVR-MDX-NET-Voc_FT.onnx")
|
||||
show("./UVR_MDXNET_1_9703.onnx")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
261
scripts/uvr_mdx/test.py
Executable file
261
scripts/uvr_mdx/test.py
Executable file
@ -0,0 +1,261 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
import time
|
||||
|
||||
import argparse
|
||||
import kaldi_native_fbank as knf
|
||||
import librosa
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to onnx model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--audio-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to input audio file",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(self, filename):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 4
|
||||
session_opts.intra_op_num_threads = 4
|
||||
|
||||
self.session_opts = session_opts
|
||||
self.model = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=self.session_opts,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
self.dim_t = self.model.get_outputs()[0].shape[3]
|
||||
|
||||
self.dim_f = self.model.get_outputs()[0].shape[2]
|
||||
|
||||
self.n_fft = self.dim_f * 2
|
||||
|
||||
self.dim_c = self.model.get_outputs()[0].shape[1]
|
||||
assert self.dim_c == 4, self.dim_c
|
||||
|
||||
self.hop = 1024
|
||||
self.n_bins = self.n_fft // 2 + 1
|
||||
self.chunk_size = self.hop * (self.dim_t - 1)
|
||||
|
||||
self.freq_pad = np.zeros([1, self.dim_c, self.n_bins - self.dim_f, self.dim_t])
|
||||
|
||||
print(f"----------inputs for {filename}----------")
|
||||
for i in self.model.get_inputs():
|
||||
print(i)
|
||||
|
||||
print(f"----------outputs for {filename}----------")
|
||||
|
||||
for i in self.model.get_outputs():
|
||||
print(i)
|
||||
print(i.shape)
|
||||
print("--------------------")
|
||||
|
||||
def __call__(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: (batch_size, 4, self.dim_f, self.dim_t)
|
||||
Returns:
|
||||
spec: (batch_size, 4, self.dim_f, self.dim_t)
|
||||
"""
|
||||
spec = self.model.run(
|
||||
[
|
||||
self.model.get_outputs()[0].name,
|
||||
],
|
||||
{
|
||||
self.model.get_inputs()[0].name: x,
|
||||
},
|
||||
)[0]
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
m = OnnxModel(args.model_filename)
|
||||
|
||||
stft_config = knf.StftConfig(
|
||||
n_fft=m.n_fft,
|
||||
hop_length=m.hop,
|
||||
win_length=m.n_fft,
|
||||
center=True,
|
||||
window_type="hann",
|
||||
)
|
||||
knf_stft = knf.Stft(stft_config)
|
||||
knf_istft = knf.IStft(stft_config)
|
||||
|
||||
sample_rate = 44100
|
||||
|
||||
samples, rate = librosa.load(args.audio_filename, mono=False, sr=sample_rate)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
assert rate == sample_rate, (rate, sample_rate)
|
||||
|
||||
# samples: (2, 479832) , (num_channels, num_samples), 44100, 10.88
|
||||
print("samples", samples.shape, rate, samples.shape[1] / rate)
|
||||
|
||||
assert samples.ndim == 2, samples.shape
|
||||
assert samples.shape[0] == 2, samples.shape
|
||||
|
||||
margin = sample_rate
|
||||
|
||||
num_chunks = 15
|
||||
chunk_size = num_chunks * sample_rate
|
||||
|
||||
# if they are too few samples, reset chunk_size
|
||||
if samples.shape[1] < chunk_size:
|
||||
chunk_size = samples.shape[1]
|
||||
|
||||
if margin > chunk_size:
|
||||
margin = chunk_size
|
||||
|
||||
segments = []
|
||||
for skip in range(0, samples.shape[1], chunk_size):
|
||||
start = max(0, skip - margin)
|
||||
end = min(skip + chunk_size + margin, samples.shape[1])
|
||||
segments.append(samples[:, start:end])
|
||||
if end == samples.shape[1]:
|
||||
break
|
||||
|
||||
sources = []
|
||||
for kk, s in enumerate(segments):
|
||||
num_samples = s.shape[1]
|
||||
trim = m.n_fft // 2
|
||||
gen_size = m.chunk_size - 2 * trim
|
||||
pad = gen_size - s.shape[1] % gen_size
|
||||
mix_p = np.concatenate(
|
||||
(
|
||||
np.zeros((2, trim)),
|
||||
s,
|
||||
np.zeros((2, pad)),
|
||||
np.zeros((2, trim)),
|
||||
),
|
||||
axis=1,
|
||||
)
|
||||
|
||||
chunk_list = []
|
||||
i = 0
|
||||
while i < s.shape[1] + pad:
|
||||
chunk_list.append(mix_p[:, i : i + m.chunk_size])
|
||||
i += gen_size
|
||||
|
||||
mix_waves = np.array(chunk_list)
|
||||
|
||||
mix_waves_reshaped = mix_waves.reshape(-1, m.chunk_size)
|
||||
stft_results = []
|
||||
for w in mix_waves_reshaped:
|
||||
stft = knf_stft(w)
|
||||
stft_results.append(stft)
|
||||
real = np.array(
|
||||
[np.array(s.real).reshape(s.num_frames, -1) for s in stft_results],
|
||||
dtype=np.float32,
|
||||
)[:, :, :-1]
|
||||
# real: (6, 256, 3072)
|
||||
|
||||
real = real.transpose(0, 2, 1)
|
||||
# real: (6, 3072, 256)
|
||||
|
||||
imag = np.array(
|
||||
[np.array(s.imag).reshape(s.num_frames, -1) for s in stft_results],
|
||||
dtype=np.float32,
|
||||
)[:, :, :-1]
|
||||
imag = imag.transpose(0, 2, 1)
|
||||
# imag: (6, 3072, 256)
|
||||
|
||||
x = np.stack([real, imag], axis=1)
|
||||
# x: (6, 2, 3072, 256) -> (batch_size, real_imag, 3072, 256)
|
||||
x = x.reshape(-1, m.dim_c, m.dim_f, m.dim_t)
|
||||
# x: (3, 4, 3072, 256)
|
||||
spec = m(x)
|
||||
|
||||
freq_pad = np.repeat(m.freq_pad, spec.shape[0], axis=0)
|
||||
|
||||
x = np.concatenate([spec, freq_pad], axis=2)
|
||||
# x: (3, 4, 3073, 256)
|
||||
x = x.reshape(-1, 2, m.n_bins, m.dim_t)
|
||||
# x: (6, 2, 3073, 256)
|
||||
x = x.transpose(0, 1, 3, 2)
|
||||
# x: (6, 2, 256, 3073)
|
||||
num_frames = x.shape[2]
|
||||
|
||||
x = x.reshape(x.shape[0], x.shape[1], -1)
|
||||
wav_list = []
|
||||
for k in range(x.shape[0]):
|
||||
istft_result = knf.StftResult(
|
||||
real=x[k, 0].reshape(-1).tolist(),
|
||||
imag=x[k, 1].reshape(-1).tolist(),
|
||||
num_frames=num_frames,
|
||||
)
|
||||
wav = knf_istft(istft_result)
|
||||
wav_list.append(wav)
|
||||
wav = np.array(wav_list, dtype=np.float32)
|
||||
# wav: (6, 261120)
|
||||
|
||||
wav = wav.reshape(-1, 2, wav.shape[-1])
|
||||
# wav: (3, 2, 261120)
|
||||
|
||||
wav = wav[:, :, trim:-trim]
|
||||
# wav: (3, 2, 254976)
|
||||
|
||||
wav = wav.transpose(1, 0, 2)
|
||||
# wav: (2, 3, 254976)
|
||||
|
||||
wav = wav.reshape(2, -1)
|
||||
# wav: (2, 764928)
|
||||
|
||||
wav = wav[:, :-pad]
|
||||
# wav: 2, 705600)
|
||||
if kk == 0:
|
||||
start = 0
|
||||
else:
|
||||
start = margin
|
||||
|
||||
if kk == len(segments) - 1:
|
||||
end = None
|
||||
else:
|
||||
end = -margin
|
||||
|
||||
sources.append(wav[:, start:end])
|
||||
|
||||
sources = np.concatenate(sources, axis=-1)
|
||||
|
||||
vocals = sources
|
||||
non_vocals = samples - vocals
|
||||
end_time = time.time()
|
||||
elapsed_seconds = end_time - start_time
|
||||
print(f"Elapsed seconds: {elapsed_seconds:.3f}")
|
||||
|
||||
audio_duration = samples.shape[1] / sample_rate
|
||||
real_time_factor = elapsed_seconds / audio_duration
|
||||
print(f"Elapsed seconds: {elapsed_seconds:.3f}")
|
||||
print(f"Audio duration in seconds: {audio_duration:.3f}")
|
||||
print(f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}")
|
||||
|
||||
sf.write(f"./vocals.mp3", np.transpose(vocals), sample_rate)
|
||||
sf.write(f"./non_vocals.mp3", np.transpose(non_vocals), sample_rate)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
x
Reference in New Issue
Block a user