diff --git a/.github/workflows/build_windows_packages.yaml b/.github/workflows/build_windows_packages.yaml index 32861463..d5bef14b 100644 --- a/.github/workflows/build_windows_packages.yaml +++ b/.github/workflows/build_windows_packages.yaml @@ -15,11 +15,7 @@ on: jobs: build: runs-on: windows-latest - strategy: - matrix: - torch_cuda: [cu124, cu128] env: - TORCH_CUDA: ${{ matrix.torch_cuda }} MODELSCOPE_USERNAME: ${{ secrets.MODELSCOPE_USERNAME }} MODELSCOPE_TOKEN: ${{ secrets.MODELSCOPE_TOKEN }} HUGGINGFACE_USERNAME: ${{ secrets.HUGGINGFACE_USERNAME }} diff --git a/.github/workflows/docker-publish.yaml b/.github/workflows/docker-publish.yaml index a00a0a77..1e3fd257 100644 --- a/.github/workflows/docker-publish.yaml +++ b/.github/workflows/docker-publish.yaml @@ -18,68 +18,22 @@ jobs: DATE=$(date +'%Y%m%d') COMMIT=$(git rev-parse --short=6 HEAD) echo "tag=${DATE}-${COMMIT}" >> $GITHUB_OUTPUT - build-amd64: + build-and-publish: needs: generate-meta runs-on: ubuntu-22.04 strategy: matrix: include: - - cuda_version: 12.6 - lite: true + - lite: true torch_base: lite - tag_prefix: cu126-lite - - cuda_version: 12.6 - lite: false + tag_prefix: xpu-lite + - lite: false torch_base: full - tag_prefix: cu126 - - cuda_version: 12.8 - lite: true - torch_base: lite - tag_prefix: cu128-lite - - cuda_version: 12.8 - lite: false - torch_base: full - tag_prefix: cu128 - + tag_prefix: xpu steps: - name: Checkout Code uses: actions/checkout@v4 - - name: Free up disk space - run: | - echo "Before cleanup:" - df -h - - sudo rm -rf /opt/ghc - sudo rm -rf /opt/hostedtoolcache/CodeQL - sudo rm -rf /opt/hostedtoolcache/PyPy - sudo rm -rf /opt/hostedtoolcache/go - sudo rm -rf /opt/hostedtoolcache/node - sudo rm -rf /opt/hostedtoolcache/Ruby - sudo rm -rf /opt/microsoft - sudo rm -rf /opt/pipx - sudo rm -rf /opt/az - sudo rm -rf /opt/google - - - sudo rm -rf /usr/lib/jvm - sudo rm -rf /usr/lib/google-cloud-sdk - sudo rm -rf /usr/lib/dotnet - - sudo rm -rf /usr/local/lib/android - sudo rm -rf /usr/local/.ghcup - sudo rm -rf /usr/local/julia1.11.5 - sudo rm -rf /usr/local/share/powershell - sudo rm -rf /usr/local/share/chromium - - sudo rm -rf /usr/share/swift - sudo rm -rf /usr/share/miniconda - sudo rm -rf /usr/share/az_12.1.0 - sudo rm -rf /usr/share/dotnet - - echo "After cleanup:" - df -h - - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 @@ -89,188 +43,18 @@ jobs: username: ${{ secrets.DOCKER_HUB_USERNAME }} password: ${{ secrets.DOCKER_HUB_PASSWORD }} - - name: Build and Push Docker Image (amd64) + - name: Build and Push Docker Image uses: docker/build-push-action@v5 with: context: . 
file: ./Dockerfile push: true - platforms: linux/amd64 + platforms: linux/amd64,linux/arm64 build-args: | LITE=${{ matrix.lite }} TORCH_BASE=${{ matrix.torch_base }} - CUDA_VERSION=${{ matrix.cuda_version }} WORKFLOW=true tags: | - xxxxrt666/gpt-sovits:${{ matrix.tag_prefix }}-${{ needs.generate-meta.outputs.tag }}-amd64 - xxxxrt666/gpt-sovits:latest-${{ matrix.tag_prefix }}-amd64 - - build-arm64: - needs: generate-meta - runs-on: ubuntu-22.04-arm - strategy: - matrix: - include: - - cuda_version: 12.6 - lite: true - torch_base: lite - tag_prefix: cu126-lite - - cuda_version: 12.6 - lite: false - torch_base: full - tag_prefix: cu126 - - cuda_version: 12.8 - lite: true - torch_base: lite - tag_prefix: cu128-lite - - cuda_version: 12.8 - lite: false - torch_base: full - tag_prefix: cu128 - - steps: - - name: Checkout Code - uses: actions/checkout@v4 - - - name: Free up disk space - run: | - echo "Before cleanup:" - df -h - - sudo rm -rf /opt/ghc - sudo rm -rf /opt/hostedtoolcache/CodeQL - sudo rm -rf /opt/hostedtoolcache/PyPy - sudo rm -rf /opt/hostedtoolcache/go - sudo rm -rf /opt/hostedtoolcache/node - sudo rm -rf /opt/hostedtoolcache/Ruby - sudo rm -rf /opt/microsoft - sudo rm -rf /opt/pipx - sudo rm -rf /opt/az - sudo rm -rf /opt/google - - - sudo rm -rf /usr/lib/jvm - sudo rm -rf /usr/lib/google-cloud-sdk - sudo rm -rf /usr/lib/dotnet - - sudo rm -rf /usr/local/lib/android - sudo rm -rf /usr/local/.ghcup - sudo rm -rf /usr/local/julia1.11.5 - sudo rm -rf /usr/local/share/powershell - sudo rm -rf /usr/local/share/chromium - - sudo rm -rf /usr/share/swift - sudo rm -rf /usr/share/miniconda - sudo rm -rf /usr/share/az_12.1.0 - sudo rm -rf /usr/share/dotnet - - echo "After cleanup:" - df -h - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - - name: Log in to Docker Hub - uses: docker/login-action@v3 - with: - username: ${{ secrets.DOCKER_HUB_USERNAME }} - password: ${{ secrets.DOCKER_HUB_PASSWORD }} - - - name: Build and Push Docker Image (arm64) - uses: docker/build-push-action@v5 - with: - context: . 
- file: ./Dockerfile - push: true - platforms: linux/arm64 - build-args: | - LITE=${{ matrix.lite }} - TORCH_BASE=${{ matrix.torch_base }} - CUDA_VERSION=${{ matrix.cuda_version }} - WORKFLOW=true - tags: | - xxxxrt666/gpt-sovits:${{ matrix.tag_prefix }}-${{ needs.generate-meta.outputs.tag }}-arm64 - xxxxrt666/gpt-sovits:latest-${{ matrix.tag_prefix }}-arm64 - - - merge-and-clean: - needs: - - build-amd64 - - build-arm64 - - generate-meta - runs-on: ubuntu-latest - strategy: - matrix: - include: - - tag_prefix: cu126-lite - - tag_prefix: cu126 - - tag_prefix: cu128-lite - - tag_prefix: cu128 - - steps: - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - - name: Log in to Docker Hub - uses: docker/login-action@v3 - with: - username: ${{ secrets.DOCKER_HUB_USERNAME }} - password: ${{ secrets.DOCKER_HUB_PASSWORD }} - - - name: Merge amd64 and arm64 into multi-arch image - run: | - DATE_TAG=${{ needs.generate-meta.outputs.tag }} - TAG_PREFIX=${{ matrix.tag_prefix }} - - docker buildx imagetools create \ - --tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG} \ - ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG}-amd64 \ - ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG}-arm64 - - docker buildx imagetools create \ - --tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX} \ - ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX}-amd64 \ - ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX}-arm64 - - name: Delete old platform-specific tags via Docker Hub API - env: - DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }} - DOCKER_HUB_TOKEN: ${{ secrets.DOCKER_HUB_PASSWORD }} - TAG_PREFIX: ${{ matrix.tag_prefix }} - DATE_TAG: ${{ needs.generate-meta.outputs.tag }} - run: | - sudo apt-get update && sudo apt-get install -y jq - - TOKEN=$(curl -s -u $DOCKER_HUB_USERNAME:$DOCKER_HUB_TOKEN \ - "https://auth.docker.io/token?service=registry.docker.io&scope=repository:$DOCKER_HUB_USERNAME/gpt-sovits:pull,push,delete" \ - | jq -r .token) - - for PLATFORM in amd64 arm64; do - SAFE_PLATFORM=$(echo $PLATFORM | sed 's/\//-/g') - TAG="${TAG_PREFIX}-${DATE_TAG}-${SAFE_PLATFORM}" - LATEST_TAG="latest-${TAG_PREFIX}-${SAFE_PLATFORM}" - - for DEL_TAG in "$TAG" "$LATEST_TAG"; do - echo "Deleting tag: $DEL_TAG" - curl -X DELETE -H "Authorization: Bearer $TOKEN" https://registry-1.docker.io/v2/$DOCKER_HUB_USERNAME/gpt-sovits/manifests/$DEL_TAG - done - done - create-default: - runs-on: ubuntu-latest - needs: - - merge-and-clean - steps: - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - - name: Log in to Docker Hub - uses: docker/login-action@v3 - with: - username: ${{ secrets.DOCKER_HUB_USERNAME }} - password: ${{ secrets.DOCKER_HUB_PASSWORD }} - - - name: Create Default Tag - run: | - docker buildx imagetools create \ - --tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest \ - ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-cu126-lite + xxxxrt666/gpt-sovits:${{ matrix.tag_prefix }}-${{ needs.generate-meta.outputs.tag }} + xxxxrt666/gpt-sovits:latest-${{ matrix.tag_prefix }} \ No newline at end of file diff --git a/GPT_SoVITS/BigVGAN/inference.py b/GPT_SoVITS/BigVGAN/inference.py index 5f892a3c..c5d125a7 100644 --- a/GPT_SoVITS/BigVGAN/inference.py +++ b/GPT_SoVITS/BigVGAN/inference.py @@ -58,9 +58,11 @@ def main(): parser.add_argument("--input_wavs_dir", default="test_files") parser.add_argument("--output_dir", default="generated_files") 
parser.add_argument("--checkpoint_file", required=True) - parser.add_argument("--use_cuda_kernel", action="store_true", default=False) + # --use_cuda_kernel argument is removed to disable custom CUDA kernels. + # parser.add_argument("--use_cuda_kernel", action="store_true", default=False) a = parser.parse_args() + a.use_cuda_kernel = False config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json") with open(config_file) as f: @@ -72,9 +74,9 @@ def main(): torch.manual_seed(h.seed) global device - if torch.cuda.is_available(): - torch.cuda.manual_seed(h.seed) - device = torch.device("cuda") + if torch.xpu.is_available(): + torch.xpu.manual_seed(h.seed) + device = torch.device("xpu") else: device = torch.device("cpu") diff --git a/GPT_SoVITS/BigVGAN/inference_e2e.py b/GPT_SoVITS/BigVGAN/inference_e2e.py index 9c0df774..9b1c6ddc 100644 --- a/GPT_SoVITS/BigVGAN/inference_e2e.py +++ b/GPT_SoVITS/BigVGAN/inference_e2e.py @@ -73,9 +73,11 @@ def main(): parser.add_argument("--input_mels_dir", default="test_mel_files") parser.add_argument("--output_dir", default="generated_files_from_mel") parser.add_argument("--checkpoint_file", required=True) - parser.add_argument("--use_cuda_kernel", action="store_true", default=False) + # --use_cuda_kernel argument is removed to disable custom CUDA kernels. + # parser.add_argument("--use_cuda_kernel", action="store_true", default=False) a = parser.parse_args() + a.use_cuda_kernel = False config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json") with open(config_file) as f: @@ -87,9 +89,9 @@ def main(): torch.manual_seed(h.seed) global device - if torch.cuda.is_available(): - torch.cuda.manual_seed(h.seed) - device = torch.device("cuda") + if torch.xpu.is_available(): + torch.xpu.manual_seed(h.seed) + device = torch.device("xpu") else: device = torch.device("cpu") diff --git a/GPT_SoVITS/BigVGAN/tests/test_cuda_vs_torch_model.py b/GPT_SoVITS/BigVGAN/tests/test_cuda_vs_torch_model.py deleted file mode 100644 index 8ddb29e5..00000000 --- a/GPT_SoVITS/BigVGAN/tests/test_cuda_vs_torch_model.py +++ /dev/null @@ -1,215 +0,0 @@ -# Copyright (c) 2024 NVIDIA CORPORATION. -# Licensed under the MIT license. 
- -import os -import sys - -# to import modules from parent_dir -parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -sys.path.append(parent_dir) - -import torch -import json -from env import AttrDict -from bigvgan import BigVGAN -from time import time -from tqdm import tqdm -from meldataset import mel_spectrogram, MAX_WAV_VALUE -from scipy.io.wavfile import write -import numpy as np - -import argparse - -torch.backends.cudnn.benchmark = True - -# For easier debugging -torch.set_printoptions(linewidth=200, threshold=10_000) - - -def generate_soundwave(duration=5.0, sr=24000): - t = np.linspace(0, duration, int(sr * duration), False, dtype=np.float32) - - modulation = np.sin(2 * np.pi * t / duration) - - min_freq = 220 - max_freq = 1760 - frequencies = min_freq + (max_freq - min_freq) * (modulation + 1) / 2 - soundwave = np.sin(2 * np.pi * frequencies * t) - - soundwave = soundwave / np.max(np.abs(soundwave)) * 0.95 - - return soundwave, sr - - -def get_mel(x, h): - return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax) - - -def load_checkpoint(filepath, device): - assert os.path.isfile(filepath) - print(f"Loading '{filepath}'") - checkpoint_dict = torch.load(filepath, map_location=device) - print("Complete.") - return checkpoint_dict - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.") - parser.add_argument( - "--checkpoint_file", - type=str, - required=True, - help="Path to the checkpoint file. Assumes config.json exists in the directory.", - ) - - args = parser.parse_args() - - config_file = os.path.join(os.path.split(args.checkpoint_file)[0], "config.json") - with open(config_file) as f: - config = f.read() - json_config = json.loads(config) - h = AttrDict({**json_config}) - - print("loading plain Pytorch BigVGAN") - generator_original = BigVGAN(h).to("cuda") - print("loading CUDA kernel BigVGAN with auto-build") - generator_cuda_kernel = BigVGAN(h, use_cuda_kernel=True).to("cuda") - - state_dict_g = load_checkpoint(args.checkpoint_file, "cuda") - generator_original.load_state_dict(state_dict_g["generator"]) - generator_cuda_kernel.load_state_dict(state_dict_g["generator"]) - - generator_original.remove_weight_norm() - generator_original.eval() - generator_cuda_kernel.remove_weight_norm() - generator_cuda_kernel.eval() - - # define number of samples and length of mel frame to benchmark - num_sample = 10 - num_mel_frame = 16384 - - # CUDA kernel correctness check - diff = 0.0 - for i in tqdm(range(num_sample)): - # Random mel - data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda") - - with torch.inference_mode(): - audio_original = generator_original(data) - - with torch.inference_mode(): - audio_cuda_kernel = generator_cuda_kernel(data) - - # Both outputs should be (almost) the same - test_result = (audio_original - audio_cuda_kernel).abs() - diff += test_result.mean(dim=-1).item() - - diff /= num_sample - if diff <= 2e-3: # We can expect a small difference (~1e-3) which does not affect perceptual quality - print( - f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference" - f"\n > mean_difference={diff}" - f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}" - f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}" - ) - else: - print( - f"\n[Fail] test CUDA fused vs. 
plain torch BigVGAN inference" - f"\n > mean_difference={diff}" - f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, " - f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}" - ) - - del data, audio_original, audio_cuda_kernel - - # Variables for tracking total time and VRAM usage - toc_total_original = 0 - toc_total_cuda_kernel = 0 - vram_used_original_total = 0 - vram_used_cuda_kernel_total = 0 - audio_length_total = 0 - - # Measure Original inference in isolation - for i in tqdm(range(num_sample)): - torch.cuda.reset_peak_memory_stats(device="cuda") - data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda") - torch.cuda.synchronize() - tic = time() - with torch.inference_mode(): - audio_original = generator_original(data) - torch.cuda.synchronize() - toc = time() - tic - toc_total_original += toc - - vram_used_original_total += torch.cuda.max_memory_allocated(device="cuda") - - del data, audio_original - torch.cuda.empty_cache() - - # Measure CUDA kernel inference in isolation - for i in tqdm(range(num_sample)): - torch.cuda.reset_peak_memory_stats(device="cuda") - data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda") - torch.cuda.synchronize() - tic = time() - with torch.inference_mode(): - audio_cuda_kernel = generator_cuda_kernel(data) - torch.cuda.synchronize() - toc = time() - tic - toc_total_cuda_kernel += toc - - audio_length_total += audio_cuda_kernel.shape[-1] - - vram_used_cuda_kernel_total += torch.cuda.max_memory_allocated(device="cuda") - - del data, audio_cuda_kernel - torch.cuda.empty_cache() - - # Calculate metrics - audio_second = audio_length_total / h.sampling_rate - khz_original = audio_length_total / toc_total_original / 1000 - khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000 - vram_used_original_gb = vram_used_original_total / num_sample / (1024**3) - vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3) - - # Print results - print( - f"Original BigVGAN: took {toc_total_original:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_original:.1f}kHz, {audio_second / toc_total_original:.1f} faster than realtime, VRAM used {vram_used_original_gb:.1f} GB" - ) - print( - f"CUDA kernel BigVGAN: took {toc_total_cuda_kernel:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_cuda_kernel:.1f}kHz, {audio_second / toc_total_cuda_kernel:.1f} faster than realtime, VRAM used {vram_used_cuda_kernel_gb:.1f} GB" - ) - print(f"speedup of CUDA kernel: {khz_cuda_kernel / khz_original}") - print(f"VRAM saving of CUDA kernel: {vram_used_original_gb / vram_used_cuda_kernel_gb}") - - # Use artificial sine waves for inference test - audio_real, sr = generate_soundwave(duration=5.0, sr=h.sampling_rate) - audio_real = torch.tensor(audio_real).to("cuda") - # Compute mel spectrogram from the ground truth audio - x = get_mel(audio_real.unsqueeze(0), h) - - with torch.inference_mode(): - y_g_hat_original = generator_original(x) - y_g_hat_cuda_kernel = generator_cuda_kernel(x) - - audio_real = audio_real.squeeze() - audio_real = audio_real * MAX_WAV_VALUE - audio_real = audio_real.cpu().numpy().astype("int16") - - audio_original = y_g_hat_original.squeeze() - audio_original = audio_original * MAX_WAV_VALUE - audio_original = audio_original.cpu().numpy().astype("int16") - - audio_cuda_kernel = y_g_hat_cuda_kernel.squeeze() - audio_cuda_kernel = audio_cuda_kernel * MAX_WAV_VALUE - audio_cuda_kernel = audio_cuda_kernel.cpu().numpy().astype("int16") - - os.makedirs("tmp", exist_ok=True) - 
output_file_real = os.path.join("tmp", "audio_real.wav") - output_file_original = os.path.join("tmp", "audio_generated_original.wav") - output_file_cuda_kernel = os.path.join("tmp", "audio_generated_cuda_kernel.wav") - write(output_file_real, h.sampling_rate, audio_real) - write(output_file_original, h.sampling_rate, audio_original) - write(output_file_cuda_kernel, h.sampling_rate, audio_cuda_kernel) - print("Example generated audios of original vs. fused CUDA kernel written to tmp!") - print("Done") diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py index 1176f0bc..c64248c2 100644 --- a/GPT_SoVITS/s1_train.py +++ b/GPT_SoVITS/s1_train.py @@ -110,15 +110,15 @@ def main(args): os.environ["USE_LIBUV"] = "0" trainer: Trainer = Trainer( max_epochs=config["train"]["epochs"], - accelerator="gpu" if torch.cuda.is_available() else "cpu", + accelerator="xpu" if torch.xpu.is_available() else "cpu", # val_check_interval=9999999999999999999999,###不要验证 # check_val_every_n_epoch=None, limit_val_batches=0, - devices=-1 if torch.cuda.is_available() else 1, + devices=-1 if torch.xpu.is_available() else 1, benchmark=False, fast_dev_run=False, - strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo") - if torch.cuda.is_available() + strategy=DDPStrategy(process_group_backend="ccl" if platform.system() != "Windows" else "gloo") + if torch.xpu.is_available() else "auto", precision=config["train"]["precision"], logger=logger, diff --git a/GPT_SoVITS/s2_train.py b/GPT_SoVITS/s2_train.py index 4b9f6488..513ab257 100644 --- a/GPT_SoVITS/s2_train.py +++ b/GPT_SoVITS/s2_train.py @@ -41,18 +41,18 @@ from process_ckpt import savee torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = False ###反正A100fp32更快,那试试tf32吧 -torch.backends.cuda.matmul.allow_tf32 = True -torch.backends.cudnn.allow_tf32 = True +# torch.backends.cuda.matmul.allow_tf32 = True # XPU does not support this +# torch.backends.cudnn.allow_tf32 = True # XPU does not support this torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响 # from config import pretrained_s2G,pretrained_s2D global_step = 0 -device = "cpu" # cuda以外的设备,等mps优化后加入 +device = "xpu" if torch.xpu.is_available() else "cpu" def main(): - if torch.cuda.is_available(): - n_gpus = torch.cuda.device_count() + if torch.xpu.is_available(): + n_gpus = torch.xpu.device_count() else: n_gpus = 1 os.environ["MASTER_ADDR"] = "localhost" @@ -78,14 +78,14 @@ def run(rank, n_gpus, hps): writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) dist.init_process_group( - backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + backend="gloo" if os.name == "nt" or not torch.xpu.is_available() else "ccl", init_method="env://?use_libuv=False", world_size=n_gpus, rank=rank, ) torch.manual_seed(hps.train.seed) - if torch.cuda.is_available(): - torch.cuda.set_device(rank) + if torch.xpu.is_available(): + torch.xpu.set_device(rank) train_dataset = TextAudioSpeakerLoader(hps.data, version=hps.model.version) train_sampler = DistributedBucketSampler( @@ -132,27 +132,14 @@ def run(rank, n_gpus, hps): # batch_size=1, pin_memory=True, # drop_last=False, collate_fn=collate_fn) - net_g = ( - SynthesizerTrn( - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - n_speakers=hps.data.n_speakers, - **hps.model, - ).cuda(rank) - if torch.cuda.is_available() - else SynthesizerTrn( - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, 
- n_speakers=hps.data.n_speakers, - **hps.model, - ).to(device) - ) + net_g = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ).to(device) - net_d = ( - MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).cuda(rank) - if torch.cuda.is_available() - else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).to(device) - ) + net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).to(device) for name, param in net_g.named_parameters(): if not param.requires_grad: print(name, "not requires_grad") @@ -196,7 +183,7 @@ def run(rank, n_gpus, hps): betas=hps.train.betas, eps=hps.train.eps, ) - if torch.cuda.is_available(): + if torch.xpu.is_available(): net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) else: @@ -238,7 +225,7 @@ def run(rank, n_gpus, hps): torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"], strict=False, ) - if torch.cuda.is_available() + if torch.xpu.is_available() else net_g.load_state_dict( torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"], strict=False, @@ -256,7 +243,7 @@ def run(rank, n_gpus, hps): net_d.module.load_state_dict( torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"], strict=False ) - if torch.cuda.is_available() + if torch.xpu.is_available() else net_d.load_state_dict( torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"], ), @@ -333,42 +320,24 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = data else: ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data - if torch.cuda.is_available(): + if torch.xpu.is_available(): spec, spec_lengths = ( - spec.cuda( - rank, - non_blocking=True, - ), - spec_lengths.cuda( - rank, - non_blocking=True, - ), + spec.to(device, non_blocking=True), + spec_lengths.to(device, non_blocking=True), ) y, y_lengths = ( - y.cuda( - rank, - non_blocking=True, - ), - y_lengths.cuda( - rank, - non_blocking=True, - ), + y.to(device, non_blocking=True), + y_lengths.to(device, non_blocking=True), ) - ssl = ssl.cuda(rank, non_blocking=True) + ssl = ssl.to(device, non_blocking=True) ssl.requires_grad = False - # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) + # ssl_lengths = ssl_lengths.to(device, non_blocking=True) text, text_lengths = ( - text.cuda( - rank, - non_blocking=True, - ), - text_lengths.cuda( - rank, - non_blocking=True, - ), + text.to(device, non_blocking=True), + text_lengths.to(device, non_blocking=True), ) if hps.model.version in {"v2Pro", "v2ProPlus"}: - sv_emb = sv_emb.cuda(rank, non_blocking=True) + sv_emb = sv_emb.to(device, non_blocking=True) else: spec, spec_lengths = spec.to(device), spec_lengths.to(device) y, y_lengths = y.to(device), y_lengths.to(device) @@ -596,11 +565,11 @@ def evaluate(hps, generator, eval_loader, writer_eval): text_lengths, ) in enumerate(eval_loader): print(111) - if torch.cuda.is_available(): - spec, spec_lengths = spec.cuda(), spec_lengths.cuda() - y, y_lengths = y.cuda(), y_lengths.cuda() - ssl = ssl.cuda() - text, text_lengths = text.cuda(), text_lengths.cuda() + if torch.xpu.is_available(): + spec, spec_lengths = spec.to(device), 
spec_lengths.to(device) + y, y_lengths = y.to(device), y_lengths.to(device) + ssl = ssl.to(device) + text, text_lengths = text.to(device), text_lengths.to(device) else: spec, spec_lengths = spec.to(device), spec_lengths.to(device) y, y_lengths = y.to(device), y_lengths.to(device) diff --git a/api.py b/api.py index cc0896a2..1c26a951 100644 --- a/api.py +++ b/api.py @@ -1,1395 +1,1398 @@ -""" -# api.py usage - -` python api.py -dr "123.wav" -dt "一二三。" -dl "zh" ` - -## 执行参数: - -`-s` - `SoVITS模型路径, 可在 config.py 中指定` -`-g` - `GPT模型路径, 可在 config.py 中指定` - -调用请求缺少参考音频时使用 -`-dr` - `默认参考音频路径` -`-dt` - `默认参考音频文本` -`-dl` - `默认参考音频语种, "中文","英文","日文","韩文","粤语,"zh","en","ja","ko","yue"` - -`-d` - `推理设备, "cuda","cpu"` -`-a` - `绑定地址, 默认"127.0.0.1"` -`-p` - `绑定端口, 默认9880, 可在 config.py 中指定` -`-fp` - `覆盖 config.py 使用全精度` -`-hp` - `覆盖 config.py 使用半精度` -`-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"` -·-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"` -·-st` - `返回的音频数据类型, 默认int16, "int16", "int32"` -·-cp` - `文本切分符号设定, 默认为空, 以",.,。"字符串的方式传入` - -`-hb` - `cnhubert路径` -`-b` - `bert路径` - -## 调用: - -### 推理 - -endpoint: `/` - -使用执行参数指定的参考音频: -GET: - `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` -POST: -```json -{ - "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", - "text_language": "zh" -} -``` - -使用执行参数指定的参考音频并设定分割符号: -GET: - `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&cut_punc=,。` -POST: -```json -{ - "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", - "text_language": "zh", - "cut_punc": ",。", -} -``` - -手动指定当次推理所使用的参考音频: -GET: - `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` -POST: -```json -{ - "refer_wav_path": "123.wav", - "prompt_text": "一二三。", - "prompt_language": "zh", - "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", - "text_language": "zh" -} -``` - -RESP: -成功: 直接返回 wav 音频流, http code 200 -失败: 返回包含错误信息的 json, http code 400 - -手动指定当次推理所使用的参考音频,并提供参数: -GET: - `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&inp_refs="456.wav"&inp_refs="789.wav"` -POST: -```json -{ - "refer_wav_path": "123.wav", - "prompt_text": "一二三。", - "prompt_language": "zh", - "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", - "text_language": "zh", - "top_k": 20, - "top_p": 0.6, - "temperature": 0.6, - "speed": 1, - "inp_refs": ["456.wav","789.wav"] -} -``` - -RESP: -成功: 直接返回 wav 音频流, http code 200 -失败: 返回包含错误信息的 json, http code 400 - - -### 更换默认参考音频 - -endpoint: `/change_refer` - -key与推理端一样 - -GET: - `http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh` -POST: -```json -{ - "refer_wav_path": "123.wav", - "prompt_text": "一二三。", - "prompt_language": "zh" -} -``` - -RESP: -成功: json, http code 200 -失败: json, 400 - - -### 命令控制 - -endpoint: `/control` - -command: -"restart": 重新运行 -"exit": 结束运行 - -GET: - `http://127.0.0.1:9880/control?command=restart` -POST: -```json -{ - "command": "restart" -} -``` - -RESP: 无 - -""" - -import argparse -import os -import re -import sys - -now_dir = os.getcwd() -sys.path.append(now_dir) -sys.path.append("%s/GPT_SoVITS" % (now_dir)) - -import signal -from text.LangSegmenter import LangSegmenter -from time import time as ttime -import torch -import torchaudio -import librosa -import soundfile as sf -from fastapi import FastAPI, Request, Query -from fastapi.responses 
import StreamingResponse, JSONResponse -import uvicorn -from transformers import AutoModelForMaskedLM, AutoTokenizer -import numpy as np -from feature_extractor import cnhubert -from io import BytesIO -from module.models import Generator, SynthesizerTrn, SynthesizerTrnV3 -from peft import LoraConfig, get_peft_model -from AR.models.t2s_lightning_module import Text2SemanticLightningModule -from text import cleaned_text_to_sequence -from text.cleaner import clean_text -from module.mel_processing import spectrogram_torch -import config as global_config -import logging -import subprocess - - -class DefaultRefer: - def __init__(self, path, text, language): - self.path = args.default_refer_path - self.text = args.default_refer_text - self.language = args.default_refer_language - - def is_ready(self) -> bool: - return is_full(self.path, self.text, self.language) - - -def is_empty(*items): # 任意一项不为空返回False - for item in items: - if item is not None and item != "": - return False - return True - - -def is_full(*items): # 任意一项为空返回False - for item in items: - if item is None or item == "": - return False - return True - - -bigvgan_model = hifigan_model = sv_cn_model = None - - -def clean_hifigan_model(): - global hifigan_model - if hifigan_model: - hifigan_model = hifigan_model.cpu() - hifigan_model = None - try: - torch.cuda.empty_cache() - except: - pass - - -def clean_bigvgan_model(): - global bigvgan_model - if bigvgan_model: - bigvgan_model = bigvgan_model.cpu() - bigvgan_model = None - try: - torch.cuda.empty_cache() - except: - pass - - -def clean_sv_cn_model(): - global sv_cn_model - if sv_cn_model: - sv_cn_model.embedding_model = sv_cn_model.embedding_model.cpu() - sv_cn_model = None - try: - torch.cuda.empty_cache() - except: - pass - - -def init_bigvgan(): - global bigvgan_model, hifigan_model, sv_cn_model - from BigVGAN import bigvgan - - bigvgan_model = bigvgan.BigVGAN.from_pretrained( - "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), - use_cuda_kernel=False, - ) # if True, RuntimeError: Ninja is required to load C++ extensions - # remove weight norm in the model and set to eval mode - bigvgan_model.remove_weight_norm() - bigvgan_model = bigvgan_model.eval() - - if is_half == True: - bigvgan_model = bigvgan_model.half().to(device) - else: - bigvgan_model = bigvgan_model.to(device) - - -def init_hifigan(): - global hifigan_model, bigvgan_model, sv_cn_model - hifigan_model = Generator( - initial_channel=100, - resblock="1", - resblock_kernel_sizes=[3, 7, 11], - resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], - upsample_rates=[10, 6, 2, 2, 2], - upsample_initial_channel=512, - upsample_kernel_sizes=[20, 12, 4, 4, 4], - gin_channels=0, - is_bias=True, - ) - hifigan_model.eval() - hifigan_model.remove_weight_norm() - state_dict_g = torch.load( - "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), - map_location="cpu", - weights_only=False, - ) - print("loading vocoder", hifigan_model.load_state_dict(state_dict_g)) - if is_half == True: - hifigan_model = hifigan_model.half().to(device) - else: - hifigan_model = hifigan_model.to(device) - - -from sv import SV - - -def init_sv_cn(): - global hifigan_model, bigvgan_model, sv_cn_model - sv_cn_model = SV(device, is_half) - - -resample_transform_dict = {} - - -def resample(audio_tensor, sr0, sr1, device): - global resample_transform_dict - key = "%s-%s-%s" % (sr0, sr1, str(device)) - if key not in resample_transform_dict: - resample_transform_dict[key] = 
torchaudio.transforms.Resample(sr0, sr1).to(device) - return resample_transform_dict[key](audio_tensor) - - -from module.mel_processing import mel_spectrogram_torch - -spec_min = -12 -spec_max = 2 - - -def norm_spec(x): - return (x - spec_min) / (spec_max - spec_min) * 2 - 1 - - -def denorm_spec(x): - return (x + 1) / 2 * (spec_max - spec_min) + spec_min - - -mel_fn = lambda x: mel_spectrogram_torch( - x, - **{ - "n_fft": 1024, - "win_size": 1024, - "hop_size": 256, - "num_mels": 100, - "sampling_rate": 24000, - "fmin": 0, - "fmax": None, - "center": False, - }, -) -mel_fn_v4 = lambda x: mel_spectrogram_torch( - x, - **{ - "n_fft": 1280, - "win_size": 1280, - "hop_size": 320, - "num_mels": 100, - "sampling_rate": 32000, - "fmin": 0, - "fmax": None, - "center": False, - }, -) - - -sr_model = None - - -def audio_sr(audio, sr): - global sr_model - if sr_model == None: - from tools.audio_sr import AP_BWE - - try: - sr_model = AP_BWE(device, DictToAttrRecursive) - except FileNotFoundError: - logger.info("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载") - return audio.cpu().detach().numpy(), sr - return sr_model(audio, sr) - - -class Speaker: - def __init__(self, name, gpt, sovits, phones=None, bert=None, prompt=None): - self.name = name - self.sovits = sovits - self.gpt = gpt - self.phones = phones - self.bert = bert - self.prompt = prompt - - -speaker_list = {} - - -class Sovits: - def __init__(self, vq_model, hps): - self.vq_model = vq_model - self.hps = hps - - -from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new - - -def get_sovits_weights(sovits_path): - from config import pretrained_sovits_name - - path_sovits_v3 = pretrained_sovits_name["v3"] - path_sovits_v4 = pretrained_sovits_name["v4"] - is_exist_s2gv3 = os.path.exists(path_sovits_v3) - is_exist_s2gv4 = os.path.exists(path_sovits_v4) - - version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) - is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4 - path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 - - if if_lora_v3 == True and is_exist == False: - logger.info("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version) - - dict_s2 = load_sovits_new(sovits_path) - hps = dict_s2["config"] - hps = DictToAttrRecursive(hps) - hps.model.semantic_frame_rate = "25hz" - if "enc_p.text_embedding.weight" not in dict_s2["weight"]: - hps.model.version = "v2" # v3model,v2sybomls - elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: - hps.model.version = "v1" - else: - hps.model.version = "v2" - - model_params_dict = vars(hps.model) - if model_version not in {"v3", "v4"}: - if "Pro" in model_version: - hps.model.version = model_version - if sv_cn_model == None: - init_sv_cn() - - vq_model = SynthesizerTrn( - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - n_speakers=hps.data.n_speakers, - **model_params_dict, - ) - else: - hps.model.version = model_version - vq_model = SynthesizerTrnV3( - hps.data.filter_length // 2 + 1, - hps.train.segment_size // hps.data.hop_length, - n_speakers=hps.data.n_speakers, - **model_params_dict, - ) - if model_version == "v3": - init_bigvgan() - if model_version == "v4": - init_hifigan() - - model_version = hps.model.version - logger.info(f"模型版本: {model_version}") - if "pretrained" not in sovits_path: - try: - del vq_model.enc_q - except: - pass - if is_half == True: - vq_model = vq_model.half().to(device) - else: - vq_model = vq_model.to(device) - vq_model.eval() - if if_lora_v3 == False: - 
vq_model.load_state_dict(dict_s2["weight"], strict=False) - else: - path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 - vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False) - lora_rank = dict_s2["lora_rank"] - lora_config = LoraConfig( - target_modules=["to_k", "to_q", "to_v", "to_out.0"], - r=lora_rank, - lora_alpha=lora_rank, - init_lora_weights=True, - ) - vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) - vq_model.load_state_dict(dict_s2["weight"], strict=False) - vq_model.cfm = vq_model.cfm.merge_and_unload() - # torch.save(vq_model.state_dict(),"merge_win.pth") - vq_model.eval() - - sovits = Sovits(vq_model, hps) - return sovits - - -class Gpt: - def __init__(self, max_sec, t2s_model): - self.max_sec = max_sec - self.t2s_model = t2s_model - - -global hz -hz = 50 - - -def get_gpt_weights(gpt_path): - dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False) - config = dict_s1["config"] - max_sec = config["data"]["max_sec"] - t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) - t2s_model.load_state_dict(dict_s1["weight"]) - if is_half == True: - t2s_model = t2s_model.half() - t2s_model = t2s_model.to(device) - t2s_model.eval() - # total = sum([param.nelement() for param in t2s_model.parameters()]) - # logger.info("Number of parameter: %.2fM" % (total / 1e6)) - - gpt = Gpt(max_sec, t2s_model) - return gpt - - -def change_gpt_sovits_weights(gpt_path, sovits_path): - try: - gpt = get_gpt_weights(gpt_path) - sovits = get_sovits_weights(sovits_path) - except Exception as e: - return JSONResponse({"code": 400, "message": str(e)}, status_code=400) - - speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits) - return JSONResponse({"code": 0, "message": "Success"}, status_code=200) - - -def get_bert_feature(text, word2ph): - with torch.no_grad(): - inputs = tokenizer(text, return_tensors="pt") - for i in inputs: - inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model - res = bert_model(**inputs, output_hidden_states=True) - res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] - assert len(word2ph) == len(text) - phone_level_feature = [] - for i in range(len(word2ph)): - repeat_feature = res[i].repeat(word2ph[i], 1) - phone_level_feature.append(repeat_feature) - phone_level_feature = torch.cat(phone_level_feature, dim=0) - # if(is_half==True):phone_level_feature=phone_level_feature.half() - return phone_level_feature.T - - -def clean_text_inf(text, language, version): - language = language.replace("all_", "") - phones, word2ph, norm_text = clean_text(text, language, version) - phones = cleaned_text_to_sequence(phones, version) - return phones, word2ph, norm_text - - -def get_bert_inf(phones, word2ph, norm_text, language): - language = language.replace("all_", "") - if language == "zh": - bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype) - else: - bert = torch.zeros( - (1024, len(phones)), - dtype=torch.float16 if is_half == True else torch.float32, - ).to(device) - - return bert - - -from text import chinese - - -def get_phones_and_bert(text, language, version, final=False): - text = re.sub(r' {2,}', ' ', text) - textlist = [] - langlist = [] - if language == "all_zh": - for tmp in LangSegmenter.getTexts(text,"zh"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_yue": - for tmp in LangSegmenter.getTexts(text,"zh"): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - 
textlist.append(tmp["text"]) - elif language == "all_ja": - for tmp in LangSegmenter.getTexts(text,"ja"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "all_ko": - for tmp in LangSegmenter.getTexts(text,"ko"): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "en": - langlist.append("en") - textlist.append(text) - elif language == "auto": - for tmp in LangSegmenter.getTexts(text): - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - elif language == "auto_yue": - for tmp in LangSegmenter.getTexts(text): - if tmp["lang"] == "zh": - tmp["lang"] = "yue" - langlist.append(tmp["lang"]) - textlist.append(tmp["text"]) - else: - for tmp in LangSegmenter.getTexts(text): - if langlist: - if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"): - textlist[-1] += tmp["text"] - continue - if tmp["lang"] == "en": - langlist.append(tmp["lang"]) - else: - # 因无法区别中日韩文汉字,以用户输入为准 - langlist.append(language) - textlist.append(tmp["text"]) - phones_list = [] - bert_list = [] - norm_text_list = [] - for i in range(len(textlist)): - lang = langlist[i] - phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version) - bert = get_bert_inf(phones, word2ph, norm_text, lang) - phones_list.append(phones) - norm_text_list.append(norm_text) - bert_list.append(bert) - bert = torch.cat(bert_list, dim=1) - phones = sum(phones_list, []) - norm_text = "".join(norm_text_list) - - if not final and len(phones) < 6: - return get_phones_and_bert("." + text, language, version, final=True) - - return phones, bert.to(torch.float16 if is_half == True else torch.float32), norm_text - - -class DictToAttrRecursive(dict): - def __init__(self, input_dict): - super().__init__(input_dict) - for key, value in input_dict.items(): - if isinstance(value, dict): - value = DictToAttrRecursive(value) - self[key] = value - setattr(self, key, value) - - def __getattr__(self, item): - try: - return self[item] - except KeyError: - raise AttributeError(f"Attribute {item} not found") - - def __setattr__(self, key, value): - if isinstance(value, dict): - value = DictToAttrRecursive(value) - super(DictToAttrRecursive, self).__setitem__(key, value) - super().__setattr__(key, value) - - def __delattr__(self, item): - try: - del self[item] - except KeyError: - raise AttributeError(f"Attribute {item} not found") - - -def get_spepc(hps, filename, dtype, device, is_v2pro=False): - sr1 = int(hps.data.sampling_rate) - audio, sr0 = torchaudio.load(filename) - if sr0 != sr1: - audio = audio.to(device) - if audio.shape[0] == 2: - audio = audio.mean(0).unsqueeze(0) - audio = resample(audio, sr0, sr1, device) - else: - audio = audio.to(device) - if audio.shape[0] == 2: - audio = audio.mean(0).unsqueeze(0) - - maxx = audio.abs().max() - if maxx > 1: - audio /= min(2, maxx) - spec = spectrogram_torch( - audio, - hps.data.filter_length, - hps.data.sampling_rate, - hps.data.hop_length, - hps.data.win_length, - center=False, - ) - spec = spec.to(dtype) - if is_v2pro == True: - audio = resample(audio, sr1, 16000, device).to(dtype) - return spec, audio - - -def pack_audio(audio_bytes, data, rate): - if media_type == "ogg": - audio_bytes = pack_ogg(audio_bytes, data, rate) - elif media_type == "aac": - audio_bytes = pack_aac(audio_bytes, data, rate) - else: - # wav无法流式, 先暂存raw - audio_bytes = pack_raw(audio_bytes, data, rate) - - return audio_bytes - - -def pack_ogg(audio_bytes, data, rate): - # Author: AkagawaTsurunaki - # Issue: - # Stack 
overflow probabilistically occurs - # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called - # using the Python library `soundfile` - # Note: - # This is an issue related to `libsndfile`, not this project itself. - # It happens when you generate a large audio tensor (about 499804 frames in my PC) - # and try to convert it to an ogg file. - # Related: - # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 - # https://github.com/libsndfile/libsndfile/issues/1023 - # https://github.com/bastibe/python-soundfile/issues/396 - # Suggestion: - # Or split the whole audio data into smaller audio segment to avoid stack overflow? - - def handle_pack_ogg(): - with sf.SoundFile(audio_bytes, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: - audio_file.write(data) - - import threading - - # See: https://docs.python.org/3/library/threading.html - # The stack size of this thread is at least 32768 - # If stack overflow error still occurs, just modify the `stack_size`. - # stack_size = n * 4096, where n should be a positive integer. - # Here we chose n = 4096. - stack_size = 4096 * 4096 - try: - threading.stack_size(stack_size) - pack_ogg_thread = threading.Thread(target=handle_pack_ogg) - pack_ogg_thread.start() - pack_ogg_thread.join() - except RuntimeError as e: - # If changing the thread stack size is unsupported, a RuntimeError is raised. - print("RuntimeError: {}".format(e)) - print("Changing the thread stack size is unsupported.") - except ValueError as e: - # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified. - print("ValueError: {}".format(e)) - print("The specified stack size is invalid.") - - return audio_bytes - - -def pack_raw(audio_bytes, data, rate): - audio_bytes.write(data.tobytes()) - - return audio_bytes - - -def pack_wav(audio_bytes, rate): - if is_int32: - data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int32) - wav_bytes = BytesIO() - sf.write(wav_bytes, data, rate, format="WAV", subtype="PCM_32") - else: - data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int16) - wav_bytes = BytesIO() - sf.write(wav_bytes, data, rate, format="WAV") - return wav_bytes - - -def pack_aac(audio_bytes, data, rate): - if is_int32: - pcm = "s32le" - bit_rate = "256k" - else: - pcm = "s16le" - bit_rate = "128k" - process = subprocess.Popen( - [ - "ffmpeg", - "-f", - pcm, # 输入16位有符号小端整数PCM - "-ar", - str(rate), # 设置采样率 - "-ac", - "1", # 单声道 - "-i", - "pipe:0", # 从管道读取输入 - "-c:a", - "aac", # 音频编码器为AAC - "-b:a", - bit_rate, # 比特率 - "-vn", # 不包含视频 - "-f", - "adts", # 输出AAC数据流格式 - "pipe:1", # 将输出写入管道 - ], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - out, _ = process.communicate(input=data.tobytes()) - audio_bytes.write(out) - - return audio_bytes - - -def read_clean_buffer(audio_bytes): - audio_chunk = audio_bytes.getvalue() - audio_bytes.truncate(0) - audio_bytes.seek(0) - - return audio_bytes, audio_chunk - - -def cut_text(text, punc): - punc_list = [p for p in punc if p in {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}] - if len(punc_list) > 0: - punds = r"[" + "".join(punc_list) + r"]" - text = text.strip("\n") - items = re.split(f"({punds})", text) - mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])] - # 在句子不存在符号或句尾无符号的时候保证文本完整 - if len(items) % 2 == 1: - mergeitems.append(items[-1]) - text = "\n".join(mergeitems) - - while "\n\n" in text: - text = text.replace("\n\n", "\n") - - return text - - -def only_punc(text): - return not 
any(t.isalnum() or t.isalpha() for t in text) - - -splits = { - ",", - "。", - "?", - "!", - ",", - ".", - "?", - "!", - "~", - ":", - ":", - "—", - "…", -} - - -def get_tts_wav( - ref_wav_path, - prompt_text, - prompt_language, - text, - text_language, - top_k=15, - top_p=0.6, - temperature=0.6, - speed=1, - inp_refs=None, - sample_steps=32, - if_sr=False, - spk="default", -): - infer_sovits = speaker_list[spk].sovits - vq_model = infer_sovits.vq_model - hps = infer_sovits.hps - version = vq_model.version - - infer_gpt = speaker_list[spk].gpt - t2s_model = infer_gpt.t2s_model - max_sec = infer_gpt.max_sec - - if version == "v3": - if sample_steps not in [4, 8, 16, 32, 64, 128]: - sample_steps = 32 - elif version == "v4": - if sample_steps not in [4, 8, 16, 32]: - sample_steps = 8 - - if if_sr and version != "v3": - if_sr = False - - t0 = ttime() - prompt_text = prompt_text.strip("\n") - if prompt_text[-1] not in splits: - prompt_text += "。" if prompt_language != "en" else "." - prompt_language, text = prompt_language, text.strip("\n") - dtype = torch.float16 if is_half == True else torch.float32 - zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) - with torch.no_grad(): - wav16k, sr = librosa.load(ref_wav_path, sr=16000) - wav16k = torch.from_numpy(wav16k) - zero_wav_torch = torch.from_numpy(zero_wav) - if is_half == True: - wav16k = wav16k.half().to(device) - zero_wav_torch = zero_wav_torch.half().to(device) - else: - wav16k = wav16k.to(device) - zero_wav_torch = zero_wav_torch.to(device) - wav16k = torch.cat([wav16k, zero_wav_torch]) - ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() - codes = vq_model.extract_latent(ssl_content) - prompt_semantic = codes[0, 0] - prompt = prompt_semantic.unsqueeze(0).to(device) - - is_v2pro = version in {"v2Pro", "v2ProPlus"} - if version not in {"v3", "v4"}: - refers = [] - if is_v2pro: - sv_emb = [] - if sv_cn_model == None: - init_sv_cn() - if inp_refs: - for path in inp_refs: - try: #####这里加上提取sv的逻辑,要么一堆sv一堆refer,要么单个sv单个refer - refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro) - refers.append(refer) - if is_v2pro: - sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor)) - except Exception as e: - logger.error(e) - if len(refers) == 0: - refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro) - refers = [refers] - if is_v2pro: - sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)] - else: - refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device) - - t1 = ttime() - # os.environ['version'] = version - prompt_language = dict_language[prompt_language.lower()] - text_language = dict_language[text_language.lower()] - phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) - texts = text.split("\n") - audio_bytes = BytesIO() - - for text in texts: - # 简单防止纯符号引发参考音频泄露 - if only_punc(text): - continue - - audio_opt = [] - if text[-1] not in splits: - text += "。" if text_language != "en" else "." 
- phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version) - bert = torch.cat([bert1, bert2], 1) - - all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) - bert = bert.to(device).unsqueeze(0) - all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) - t2 = ttime() - with torch.no_grad(): - pred_semantic, idx = t2s_model.model.infer_panel( - all_phoneme_ids, - all_phoneme_len, - prompt, - bert, - # prompt_phone_len=ph_offset, - top_k=top_k, - top_p=top_p, - temperature=temperature, - early_stop_num=hz * max_sec, - ) - pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) - t3 = ttime() - - if version not in {"v3", "v4"}: - if is_v2pro: - audio = ( - vq_model.decode( - pred_semantic, - torch.LongTensor(phones2).to(device).unsqueeze(0), - refers, - speed=speed, - sv_emb=sv_emb, - ) - .detach() - .cpu() - .numpy()[0, 0] - ) - else: - audio = ( - vq_model.decode( - pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed - ) - .detach() - .cpu() - .numpy()[0, 0] - ) - else: - phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) - phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0) - - fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer) - ref_audio, sr = torchaudio.load(ref_wav_path) - ref_audio = ref_audio.to(device).float() - if ref_audio.shape[0] == 2: - ref_audio = ref_audio.mean(0).unsqueeze(0) - - tgt_sr = 24000 if version == "v3" else 32000 - if sr != tgt_sr: - ref_audio = resample(ref_audio, sr, tgt_sr, device) - mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio) - mel2 = norm_spec(mel2) - T_min = min(mel2.shape[2], fea_ref.shape[2]) - mel2 = mel2[:, :, :T_min] - fea_ref = fea_ref[:, :, :T_min] - Tref = 468 if version == "v3" else 500 - Tchunk = 934 if version == "v3" else 1000 - if T_min > Tref: - mel2 = mel2[:, :, -Tref:] - fea_ref = fea_ref[:, :, -Tref:] - T_min = Tref - chunk_len = Tchunk - T_min - mel2 = mel2.to(dtype) - fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed) - cfm_resss = [] - idx = 0 - while 1: - fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] - if fea_todo_chunk.shape[-1] == 0: - break - idx += chunk_len - fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) - cfm_res = vq_model.cfm.inference( - fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 - ) - cfm_res = cfm_res[:, :, mel2.shape[2] :] - mel2 = cfm_res[:, :, -T_min:] - fea_ref = fea_todo_chunk[:, :, -T_min:] - cfm_resss.append(cfm_res) - cfm_res = torch.cat(cfm_resss, 2) - cfm_res = denorm_spec(cfm_res) - if version == "v3": - if bigvgan_model == None: - init_bigvgan() - else: # v4 - if hifigan_model == None: - init_hifigan() - vocoder_model = bigvgan_model if version == "v3" else hifigan_model - with torch.inference_mode(): - wav_gen = vocoder_model(cfm_res) - audio = wav_gen[0][0].cpu().detach().numpy() - - max_audio = np.abs(audio).max() - if max_audio > 1: - audio /= max_audio - audio_opt.append(audio) - audio_opt.append(zero_wav) - audio_opt = np.concatenate(audio_opt, 0) - t4 = ttime() - - if version in {"v1", "v2", "v2Pro", "v2ProPlus"}: - sr = 32000 - elif version == "v3": - sr = 24000 - else: - sr = 48000 # v4 - - if if_sr and sr == 24000: - audio_opt = torch.from_numpy(audio_opt).float().to(device) - audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr) - max_audio = np.abs(audio_opt).max() - if max_audio > 1: - audio_opt /= max_audio - sr = 48000 - - if 
is_int32: - audio_bytes = pack_audio(audio_bytes, (audio_opt * 2147483647).astype(np.int32), sr) - else: - audio_bytes = pack_audio(audio_bytes, (audio_opt * 32768).astype(np.int16), sr) - # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) - if stream_mode == "normal": - audio_bytes, audio_chunk = read_clean_buffer(audio_bytes) - yield audio_chunk - - if not stream_mode == "normal": - if media_type == "wav": - if version in {"v1", "v2", "v2Pro", "v2ProPlus"}: - sr = 32000 - elif version == "v3": - sr = 48000 if if_sr else 24000 - else: - sr = 48000 # v4 - audio_bytes = pack_wav(audio_bytes, sr) - yield audio_bytes.getvalue() - - -def handle_control(command): - if command == "restart": - os.execl(g_config.python_exec, g_config.python_exec, *sys.argv) - elif command == "exit": - os.kill(os.getpid(), signal.SIGTERM) - exit(0) - - -def handle_change(path, text, language): - if is_empty(path, text, language): - return JSONResponse( - {"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400 - ) - - if path != "" or path is not None: - default_refer.path = path - if text != "" or text is not None: - default_refer.text = text - if language != "" or language is not None: - default_refer.language = language - - logger.info(f"当前默认参考音频路径: {default_refer.path}") - logger.info(f"当前默认参考音频文本: {default_refer.text}") - logger.info(f"当前默认参考音频语种: {default_refer.language}") - logger.info(f"is_ready: {default_refer.is_ready()}") - - return JSONResponse({"code": 0, "message": "Success"}, status_code=200) - - -def handle( - refer_wav_path, - prompt_text, - prompt_language, - text, - text_language, - cut_punc, - top_k, - top_p, - temperature, - speed, - inp_refs, - sample_steps, - if_sr, -): - if ( - refer_wav_path == "" - or refer_wav_path is None - or prompt_text == "" - or prompt_text is None - or prompt_language == "" - or prompt_language is None - ): - refer_wav_path, prompt_text, prompt_language = ( - default_refer.path, - default_refer.text, - default_refer.language, - ) - if not default_refer.is_ready(): - return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) - - if cut_punc == None: - text = cut_text(text, default_cut_punc) - else: - text = cut_text(text, cut_punc) - - return StreamingResponse( - get_tts_wav( - refer_wav_path, - prompt_text, - prompt_language, - text, - text_language, - top_k, - top_p, - temperature, - speed, - inp_refs, - sample_steps, - if_sr, - ), - media_type="audio/" + media_type, - ) - - -# -------------------------------- -# 初始化部分 -# -------------------------------- -dict_language = { - "中文": "all_zh", - "粤语": "all_yue", - "英文": "en", - "日文": "all_ja", - "韩文": "all_ko", - "中英混合": "zh", - "粤英混合": "yue", - "日英混合": "ja", - "韩英混合": "ko", - "多语种混合": "auto", # 多语种启动切分识别语种 - "多语种混合(粤语)": "auto_yue", - "all_zh": "all_zh", - "all_yue": "all_yue", - "en": "en", - "all_ja": "all_ja", - "all_ko": "all_ko", - "zh": "zh", - "yue": "yue", - "ja": "ja", - "ko": "ko", - "auto": "auto", - "auto_yue": "auto_yue", -} - -# logger -logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG) -logger = logging.getLogger("uvicorn") - -# 获取配置 -g_config = global_config.Config() - -# 获取参数 -parser = argparse.ArgumentParser(description="GPT-SoVITS api") - -parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径") -parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径") -parser.add_argument("-dr", "--default_refer_path", type=str, default="", 
help="默认参考音频路径") -parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本") -parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种") -parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu") -parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0") -parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880") -parser.add_argument( - "-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度" -) -parser.add_argument( - "-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度" -) -# bool值的用法为 `python ./api.py -fp ...` -# 此时 full_precision==True, half_precision==False -parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive") -parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac") -parser.add_argument("-st", "--sub_type", type=str, default="int16", help="音频数据类型, int16 / int32") -parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…") -# 切割常用分句符为 `python ./api.py -cp ".?!。?!"` -parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path") -parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path") - -args = parser.parse_args() -sovits_path = args.sovits_path -gpt_path = args.gpt_path -device = args.device -port = args.port -host = args.bind_addr -cnhubert_base_path = args.hubert_path -bert_path = args.bert_path -default_cut_punc = args.cut_punc - -# 应用参数配置 -default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language) - -# 模型路径检查 -if sovits_path == "": - sovits_path = g_config.pretrained_sovits_path - logger.warning(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}") -if gpt_path == "": - gpt_path = g_config.pretrained_gpt_path - logger.warning(f"未指定GPT模型路径, fallback后当前值: {gpt_path}") - -# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用 -if default_refer.path == "" or default_refer.text == "" or default_refer.language == "": - default_refer.path, default_refer.text, default_refer.language = "", "", "" - logger.info("未指定默认参考音频") -else: - logger.info(f"默认参考音频路径: {default_refer.path}") - logger.info(f"默认参考音频文本: {default_refer.text}") - logger.info(f"默认参考音频语种: {default_refer.language}") - -# 获取半精度 -is_half = g_config.is_half -if args.full_precision: - is_half = False -if args.half_precision: - is_half = True -if args.full_precision and args.half_precision: - is_half = g_config.is_half # 炒饭fallback -logger.info(f"半精: {is_half}") - -# 流式返回模式 -if args.stream_mode.lower() in ["normal", "n"]: - stream_mode = "normal" - logger.info("流式返回已开启") -else: - stream_mode = "close" - -# 音频编码格式 -if args.media_type.lower() in ["aac", "ogg"]: - media_type = args.media_type.lower() -elif stream_mode == "close": - media_type = "wav" -else: - media_type = "ogg" -logger.info(f"编码格式: {media_type}") - -# 音频数据类型 -if args.sub_type.lower() == "int32": - is_int32 = True - logger.info("数据类型: int32") -else: - is_int32 = False - logger.info("数据类型: int16") - -# 初始化模型 -cnhubert.cnhubert_base_path = cnhubert_base_path -tokenizer = AutoTokenizer.from_pretrained(bert_path) -bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) -ssl_model = cnhubert.get_model() -if is_half: - bert_model = 
bert_model.half().to(device)
-    ssl_model = ssl_model.half().to(device)
-else:
-    bert_model = bert_model.to(device)
-    ssl_model = ssl_model.to(device)
-change_gpt_sovits_weights(gpt_path=gpt_path, sovits_path=sovits_path)
-
-
-# --------------------------------
-# 接口部分
-# --------------------------------
-app = FastAPI()
-
-
-@app.post("/set_model")
-async def set_model(request: Request):
-    json_post_raw = await request.json()
-    return change_gpt_sovits_weights(
-        gpt_path=json_post_raw.get("gpt_model_path"), sovits_path=json_post_raw.get("sovits_model_path")
-    )
-
-
-@app.get("/set_model")
-async def set_model(
-    gpt_model_path: str = None,
-    sovits_model_path: str = None,
-):
-    return change_gpt_sovits_weights(gpt_path=gpt_model_path, sovits_path=sovits_model_path)
-
-
-@app.post("/control")
-async def control(request: Request):
-    json_post_raw = await request.json()
-    return handle_control(json_post_raw.get("command"))
-
-
-@app.get("/control")
-async def control(command: str = None):
-    return handle_control(command)
-
-
-@app.post("/change_refer")
-async def change_refer(request: Request):
-    json_post_raw = await request.json()
-    return handle_change(
-        json_post_raw.get("refer_wav_path"), json_post_raw.get("prompt_text"), json_post_raw.get("prompt_language")
-    )
-
-
-@app.get("/change_refer")
-async def change_refer(refer_wav_path: str = None, prompt_text: str = None, prompt_language: str = None):
-    return handle_change(refer_wav_path, prompt_text, prompt_language)
-
-
-@app.post("/")
-async def tts_endpoint(request: Request):
-    json_post_raw = await request.json()
-    return handle(
-        json_post_raw.get("refer_wav_path"),
-        json_post_raw.get("prompt_text"),
-        json_post_raw.get("prompt_language"),
-        json_post_raw.get("text"),
-        json_post_raw.get("text_language"),
-        json_post_raw.get("cut_punc"),
-        json_post_raw.get("top_k", 15),
-        json_post_raw.get("top_p", 1.0),
-        json_post_raw.get("temperature", 1.0),
-        json_post_raw.get("speed", 1.0),
-        json_post_raw.get("inp_refs", []),
-        json_post_raw.get("sample_steps", 32),
-        json_post_raw.get("if_sr", False),
-    )
-
-
-@app.get("/")
-async def tts_endpoint(
-    refer_wav_path: str = None,
-    prompt_text: str = None,
-    prompt_language: str = None,
-    text: str = None,
-    text_language: str = None,
-    cut_punc: str = None,
-    top_k: int = 15,
-    top_p: float = 1.0,
-    temperature: float = 1.0,
-    speed: float = 1.0,
-    inp_refs: list = Query(default=[]),
-    sample_steps: int = 32,
-    if_sr: bool = False,
-):
-    return handle(
-        refer_wav_path,
-        prompt_text,
-        prompt_language,
-        text,
-        text_language,
-        cut_punc,
-        top_k,
-        top_p,
-        temperature,
-        speed,
-        inp_refs,
-        sample_steps,
-        if_sr,
-    )
-
-
-if __name__ == "__main__":
-    uvicorn.run(app, host=host, port=port, workers=1)
+"""
+# api.py usage
+
+` python api.py -dr "123.wav" -dt "一二三。" -dl "zh" `
+
+## 执行参数:
+
+`-s` - `SoVITS模型路径, 可在 config.py 中指定`
+`-g` - `GPT模型路径, 可在 config.py 中指定`
+
+调用请求缺少参考音频时使用
+`-dr` - `默认参考音频路径`
+`-dt` - `默认参考音频文本`
+`-dl` - `默认参考音频语种, "中文","英文","日文","韩文","粤语","zh","en","ja","ko","yue"`
+
+`-d` - `推理设备, "xpu","cpu"`
+`-a` - `绑定地址, 默认"0.0.0.0"`
+`-p` - `绑定端口, 默认9880, 可在 config.py 中指定`
+`-fp` - `覆盖 config.py 使用全精度`
+`-hp` - `覆盖 config.py 使用半精度`
+`-sm` - `流式返回模式, 默认不启用, "close","c", "normal","n", "keepalive","k"`
+`-mt` - `返回的音频编码格式, 流式默认ogg, 非流式默认wav, "wav", "ogg", "aac"`
+`-st` - `返回的音频数据类型, 默认int16, "int16", "int32"`
+`-cp` - `文本切分符号设定, 默认为空, 以",.,。"字符串的方式传入`
+
+`-hb` - `cnhubert路径`
+`-b` - `bert路径`
+
+## 调用:
+
+### 推理
+
+endpoint: `/`
+
+使用执行参数指定的参考音频:
+GET:
+    
`http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` +POST: +```json +{ + "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", + "text_language": "zh" +} +``` + +使用执行参数指定的参考音频并设定分割符号: +GET: + `http://127.0.0.1:9880?text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&cut_punc=,。` +POST: +```json +{ + "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", + "text_language": "zh", + "cut_punc": ",。", +} +``` + +手动指定当次推理所使用的参考音频: +GET: + `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh` +POST: +```json +{ + "refer_wav_path": "123.wav", + "prompt_text": "一二三。", + "prompt_language": "zh", + "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", + "text_language": "zh" +} +``` + +RESP: +成功: 直接返回 wav 音频流, http code 200 +失败: 返回包含错误信息的 json, http code 400 + +手动指定当次推理所使用的参考音频,并提供参数: +GET: + `http://127.0.0.1:9880?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh&text=先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。&text_language=zh&top_k=20&top_p=0.6&temperature=0.6&speed=1&inp_refs="456.wav"&inp_refs="789.wav"` +POST: +```json +{ + "refer_wav_path": "123.wav", + "prompt_text": "一二三。", + "prompt_language": "zh", + "text": "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。", + "text_language": "zh", + "top_k": 20, + "top_p": 0.6, + "temperature": 0.6, + "speed": 1, + "inp_refs": ["456.wav","789.wav"] +} +``` + +RESP: +成功: 直接返回 wav 音频流, http code 200 +失败: 返回包含错误信息的 json, http code 400 + + +### 更换默认参考音频 + +endpoint: `/change_refer` + +key与推理端一样 + +GET: + `http://127.0.0.1:9880/change_refer?refer_wav_path=123.wav&prompt_text=一二三。&prompt_language=zh` +POST: +```json +{ + "refer_wav_path": "123.wav", + "prompt_text": "一二三。", + "prompt_language": "zh" +} +``` + +RESP: +成功: json, http code 200 +失败: json, 400 + + +### 命令控制 + +endpoint: `/control` + +command: +"restart": 重新运行 +"exit": 结束运行 + +GET: + `http://127.0.0.1:9880/control?command=restart` +POST: +```json +{ + "command": "restart" +} +``` + +RESP: 无 + +""" + +import argparse +import os +import re +import sys + +now_dir = os.getcwd() +sys.path.append(now_dir) +sys.path.append("%s/GPT_SoVITS" % (now_dir)) + +import signal +from text.LangSegmenter import LangSegmenter +from time import time as ttime +import torch +import torchaudio +import librosa +import soundfile as sf +from fastapi import FastAPI, Request, Query +from fastapi.responses import StreamingResponse, JSONResponse +import uvicorn +from transformers import AutoModelForMaskedLM, AutoTokenizer +import numpy as np +from feature_extractor import cnhubert +from io import BytesIO +from module.models import Generator, SynthesizerTrn, SynthesizerTrnV3 +from peft import LoraConfig, get_peft_model +from AR.models.t2s_lightning_module import Text2SemanticLightningModule +from text import cleaned_text_to_sequence +from text.cleaner import clean_text +from module.mel_processing import spectrogram_torch +import config as global_config +import logging +import subprocess + + +class DefaultRefer: + def __init__(self, path, text, language): + self.path = args.default_refer_path + self.text = args.default_refer_text + self.language = args.default_refer_language + + def is_ready(self) -> bool: + return is_full(self.path, self.text, self.language) + + +def is_empty(*items): # 任意一项不为空返回False + for item in items: + if item is not None and item != "": + return False + return True + + +def is_full(*items): # 任意一项为空返回False + for item in items: + if item is None or item == "": + return False + return True + + +bigvgan_model = hifigan_model = 
sv_cn_model = None + + +def clean_hifigan_model(): + global hifigan_model + if hifigan_model: + hifigan_model = hifigan_model.cpu() + hifigan_model = None + try: + if torch.xpu.is_available(): + torch.xpu.empty_cache() + except: + pass + + +def clean_bigvgan_model(): + global bigvgan_model + if bigvgan_model: + bigvgan_model = bigvgan_model.cpu() + bigvgan_model = None + try: + if torch.xpu.is_available(): + torch.xpu.empty_cache() + except: + pass + + +def clean_sv_cn_model(): + global sv_cn_model + if sv_cn_model: + sv_cn_model.embedding_model = sv_cn_model.embedding_model.cpu() + sv_cn_model = None + try: + if torch.xpu.is_available(): + torch.xpu.empty_cache() + except: + pass + + +def init_bigvgan(): + global bigvgan_model, hifigan_model, sv_cn_model + from BigVGAN import bigvgan + + bigvgan_model = bigvgan.BigVGAN.from_pretrained( + "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), + use_cuda_kernel=False, + ) # if True, RuntimeError: Ninja is required to load C++ extensions + # remove weight norm in the model and set to eval mode + bigvgan_model.remove_weight_norm() + bigvgan_model = bigvgan_model.eval() + + if is_half == True: + bigvgan_model = bigvgan_model.half().to(device) + else: + bigvgan_model = bigvgan_model.to(device) + + +def init_hifigan(): + global hifigan_model, bigvgan_model, sv_cn_model + hifigan_model = Generator( + initial_channel=100, + resblock="1", + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_rates=[10, 6, 2, 2, 2], + upsample_initial_channel=512, + upsample_kernel_sizes=[20, 12, 4, 4, 4], + gin_channels=0, + is_bias=True, + ) + hifigan_model.eval() + hifigan_model.remove_weight_norm() + state_dict_g = torch.load( + "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), + map_location="cpu", + weights_only=False, + ) + print("loading vocoder", hifigan_model.load_state_dict(state_dict_g)) + if is_half == True: + hifigan_model = hifigan_model.half().to(device) + else: + hifigan_model = hifigan_model.to(device) + + +from sv import SV + + +def init_sv_cn(): + global hifigan_model, bigvgan_model, sv_cn_model + sv_cn_model = SV(device, is_half) + + +resample_transform_dict = {} + + +def resample(audio_tensor, sr0, sr1, device): + global resample_transform_dict + key = "%s-%s-%s" % (sr0, sr1, str(device)) + if key not in resample_transform_dict: + resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device) + return resample_transform_dict[key](audio_tensor) + + +from module.mel_processing import mel_spectrogram_torch + +spec_min = -12 +spec_max = 2 + + +def norm_spec(x): + return (x - spec_min) / (spec_max - spec_min) * 2 - 1 + + +def denorm_spec(x): + return (x + 1) / 2 * (spec_max - spec_min) + spec_min + + +mel_fn = lambda x: mel_spectrogram_torch( + x, + **{ + "n_fft": 1024, + "win_size": 1024, + "hop_size": 256, + "num_mels": 100, + "sampling_rate": 24000, + "fmin": 0, + "fmax": None, + "center": False, + }, +) +mel_fn_v4 = lambda x: mel_spectrogram_torch( + x, + **{ + "n_fft": 1280, + "win_size": 1280, + "hop_size": 320, + "num_mels": 100, + "sampling_rate": 32000, + "fmin": 0, + "fmax": None, + "center": False, + }, +) + + +sr_model = None + + +def audio_sr(audio, sr): + global sr_model + if sr_model == None: + from tools.audio_sr import AP_BWE + + try: + sr_model = AP_BWE(device, DictToAttrRecursive) + except FileNotFoundError: + logger.info("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载") + return 
audio.cpu().detach().numpy(), sr + return sr_model(audio, sr) + + +class Speaker: + def __init__(self, name, gpt, sovits, phones=None, bert=None, prompt=None): + self.name = name + self.sovits = sovits + self.gpt = gpt + self.phones = phones + self.bert = bert + self.prompt = prompt + + +speaker_list = {} + + +class Sovits: + def __init__(self, vq_model, hps): + self.vq_model = vq_model + self.hps = hps + + +from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new + + +def get_sovits_weights(sovits_path): + from config import pretrained_sovits_name + + path_sovits_v3 = pretrained_sovits_name["v3"] + path_sovits_v4 = pretrained_sovits_name["v4"] + is_exist_s2gv3 = os.path.exists(path_sovits_v3) + is_exist_s2gv4 = os.path.exists(path_sovits_v4) + + version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) + is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4 + path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 + + if if_lora_v3 == True and is_exist == False: + logger.info("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version) + + dict_s2 = load_sovits_new(sovits_path) + hps = dict_s2["config"] + hps = DictToAttrRecursive(hps) + hps.model.semantic_frame_rate = "25hz" + if "enc_p.text_embedding.weight" not in dict_s2["weight"]: + hps.model.version = "v2" # v3model,v2sybomls + elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: + hps.model.version = "v1" + else: + hps.model.version = "v2" + + model_params_dict = vars(hps.model) + if model_version not in {"v3", "v4"}: + if "Pro" in model_version: + hps.model.version = model_version + if sv_cn_model == None: + init_sv_cn() + + vq_model = SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **model_params_dict, + ) + else: + hps.model.version = model_version + vq_model = SynthesizerTrnV3( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **model_params_dict, + ) + if model_version == "v3": + init_bigvgan() + if model_version == "v4": + init_hifigan() + + model_version = hps.model.version + logger.info(f"模型版本: {model_version}") + if "pretrained" not in sovits_path: + try: + del vq_model.enc_q + except: + pass + if is_half == True: + vq_model = vq_model.half().to(device) + else: + vq_model = vq_model.to(device) + vq_model.eval() + if if_lora_v3 == False: + vq_model.load_state_dict(dict_s2["weight"], strict=False) + else: + path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 + vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False) + lora_rank = dict_s2["lora_rank"] + lora_config = LoraConfig( + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + r=lora_rank, + lora_alpha=lora_rank, + init_lora_weights=True, + ) + vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) + vq_model.load_state_dict(dict_s2["weight"], strict=False) + vq_model.cfm = vq_model.cfm.merge_and_unload() + # torch.save(vq_model.state_dict(),"merge_win.pth") + vq_model.eval() + + sovits = Sovits(vq_model, hps) + return sovits + + +class Gpt: + def __init__(self, max_sec, t2s_model): + self.max_sec = max_sec + self.t2s_model = t2s_model + + +global hz +hz = 50 + + +def get_gpt_weights(gpt_path): + dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False) + config = dict_s1["config"] + max_sec = config["data"]["max_sec"] + t2s_model = Text2SemanticLightningModule(config, 
"****", is_train=False) + t2s_model.load_state_dict(dict_s1["weight"]) + if is_half == True: + t2s_model = t2s_model.half() + t2s_model = t2s_model.to(device) + t2s_model.eval() + # total = sum([param.nelement() for param in t2s_model.parameters()]) + # logger.info("Number of parameter: %.2fM" % (total / 1e6)) + + gpt = Gpt(max_sec, t2s_model) + return gpt + + +def change_gpt_sovits_weights(gpt_path, sovits_path): + try: + gpt = get_gpt_weights(gpt_path) + sovits = get_sovits_weights(sovits_path) + except Exception as e: + return JSONResponse({"code": 400, "message": str(e)}, status_code=400) + + speaker_list["default"] = Speaker(name="default", gpt=gpt, sovits=sovits) + return JSONResponse({"code": 0, "message": "Success"}, status_code=200) + + +def get_bert_feature(text, word2ph): + with torch.no_grad(): + inputs = tokenizer(text, return_tensors="pt") + for i in inputs: + inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model + res = bert_model(**inputs, output_hidden_states=True) + res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] + assert len(word2ph) == len(text) + phone_level_feature = [] + for i in range(len(word2ph)): + repeat_feature = res[i].repeat(word2ph[i], 1) + phone_level_feature.append(repeat_feature) + phone_level_feature = torch.cat(phone_level_feature, dim=0) + # if(is_half==True):phone_level_feature=phone_level_feature.half() + return phone_level_feature.T + + +def clean_text_inf(text, language, version): + language = language.replace("all_", "") + phones, word2ph, norm_text = clean_text(text, language, version) + phones = cleaned_text_to_sequence(phones, version) + return phones, word2ph, norm_text + + +def get_bert_inf(phones, word2ph, norm_text, language): + language = language.replace("all_", "") + if language == "zh": + bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype) + else: + bert = torch.zeros( + (1024, len(phones)), + dtype=torch.float16 if is_half == True else torch.float32, + ).to(device) + + return bert + + +from text import chinese + + +def get_phones_and_bert(text, language, version, final=False): + text = re.sub(r' {2,}', ' ', text) + textlist = [] + langlist = [] + if language == "all_zh": + for tmp in LangSegmenter.getTexts(text,"zh"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_yue": + for tmp in LangSegmenter.getTexts(text,"zh"): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ja": + for tmp in LangSegmenter.getTexts(text,"ja"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "all_ko": + for tmp in LangSegmenter.getTexts(text,"ko"): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "en": + langlist.append("en") + textlist.append(text) + elif language == "auto": + for tmp in LangSegmenter.getTexts(text): + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + elif language == "auto_yue": + for tmp in LangSegmenter.getTexts(text): + if tmp["lang"] == "zh": + tmp["lang"] = "yue" + langlist.append(tmp["lang"]) + textlist.append(tmp["text"]) + else: + for tmp in LangSegmenter.getTexts(text): + if langlist: + if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"): + textlist[-1] += tmp["text"] + continue + if tmp["lang"] == "en": + langlist.append(tmp["lang"]) + else: + # 因无法区别中日韩文汉字,以用户输入为准 + langlist.append(language) + textlist.append(tmp["text"]) + phones_list = [] + 
bert_list = [] + norm_text_list = [] + for i in range(len(textlist)): + lang = langlist[i] + phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version) + bert = get_bert_inf(phones, word2ph, norm_text, lang) + phones_list.append(phones) + norm_text_list.append(norm_text) + bert_list.append(bert) + bert = torch.cat(bert_list, dim=1) + phones = sum(phones_list, []) + norm_text = "".join(norm_text_list) + + if not final and len(phones) < 6: + return get_phones_and_bert("." + text, language, version, final=True) + + return phones, bert.to(torch.float16 if is_half == True else torch.float32), norm_text + + +class DictToAttrRecursive(dict): + def __init__(self, input_dict): + super().__init__(input_dict) + for key, value in input_dict.items(): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + self[key] = value + setattr(self, key, value) + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + def __setattr__(self, key, value): + if isinstance(value, dict): + value = DictToAttrRecursive(value) + super(DictToAttrRecursive, self).__setitem__(key, value) + super().__setattr__(key, value) + + def __delattr__(self, item): + try: + del self[item] + except KeyError: + raise AttributeError(f"Attribute {item} not found") + + +def get_spepc(hps, filename, dtype, device, is_v2pro=False): + sr1 = int(hps.data.sampling_rate) + audio, sr0 = torchaudio.load(filename) + if sr0 != sr1: + audio = audio.to(device) + if audio.shape[0] == 2: + audio = audio.mean(0).unsqueeze(0) + audio = resample(audio, sr0, sr1, device) + else: + audio = audio.to(device) + if audio.shape[0] == 2: + audio = audio.mean(0).unsqueeze(0) + + maxx = audio.abs().max() + if maxx > 1: + audio /= min(2, maxx) + spec = spectrogram_torch( + audio, + hps.data.filter_length, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + center=False, + ) + spec = spec.to(dtype) + if is_v2pro == True: + audio = resample(audio, sr1, 16000, device).to(dtype) + return spec, audio + + +def pack_audio(audio_bytes, data, rate): + if media_type == "ogg": + audio_bytes = pack_ogg(audio_bytes, data, rate) + elif media_type == "aac": + audio_bytes = pack_aac(audio_bytes, data, rate) + else: + # wav无法流式, 先暂存raw + audio_bytes = pack_raw(audio_bytes, data, rate) + + return audio_bytes + + +def pack_ogg(audio_bytes, data, rate): + # Author: AkagawaTsurunaki + # Issue: + # Stack overflow probabilistically occurs + # when the function `sf_writef_short` of `libsndfile_64bit.dll` is called + # using the Python library `soundfile` + # Note: + # This is an issue related to `libsndfile`, not this project itself. + # It happens when you generate a large audio tensor (about 499804 frames in my PC) + # and try to convert it to an ogg file. + # Related: + # https://github.com/RVC-Boss/GPT-SoVITS/issues/1199 + # https://github.com/libsndfile/libsndfile/issues/1023 + # https://github.com/bastibe/python-soundfile/issues/396 + # Suggestion: + # Or split the whole audio data into smaller audio segment to avoid stack overflow? + + def handle_pack_ogg(): + with sf.SoundFile(audio_bytes, mode="w", samplerate=rate, channels=1, format="ogg") as audio_file: + audio_file.write(data) + + import threading + + # See: https://docs.python.org/3/library/threading.html + # The stack size of this thread is at least 32768 + # If stack overflow error still occurs, just modify the `stack_size`. + # stack_size = n * 4096, where n should be a positive integer. 
+ # Here we chose n = 4096. + stack_size = 4096 * 4096 + try: + threading.stack_size(stack_size) + pack_ogg_thread = threading.Thread(target=handle_pack_ogg) + pack_ogg_thread.start() + pack_ogg_thread.join() + except RuntimeError as e: + # If changing the thread stack size is unsupported, a RuntimeError is raised. + print("RuntimeError: {}".format(e)) + print("Changing the thread stack size is unsupported.") + except ValueError as e: + # If the specified stack size is invalid, a ValueError is raised and the stack size is unmodified. + print("ValueError: {}".format(e)) + print("The specified stack size is invalid.") + + return audio_bytes + + +def pack_raw(audio_bytes, data, rate): + audio_bytes.write(data.tobytes()) + + return audio_bytes + + +def pack_wav(audio_bytes, rate): + if is_int32: + data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int32) + wav_bytes = BytesIO() + sf.write(wav_bytes, data, rate, format="WAV", subtype="PCM_32") + else: + data = np.frombuffer(audio_bytes.getvalue(), dtype=np.int16) + wav_bytes = BytesIO() + sf.write(wav_bytes, data, rate, format="WAV") + return wav_bytes + + +def pack_aac(audio_bytes, data, rate): + if is_int32: + pcm = "s32le" + bit_rate = "256k" + else: + pcm = "s16le" + bit_rate = "128k" + process = subprocess.Popen( + [ + "ffmpeg", + "-f", + pcm, # 输入16位有符号小端整数PCM + "-ar", + str(rate), # 设置采样率 + "-ac", + "1", # 单声道 + "-i", + "pipe:0", # 从管道读取输入 + "-c:a", + "aac", # 音频编码器为AAC + "-b:a", + bit_rate, # 比特率 + "-vn", # 不包含视频 + "-f", + "adts", # 输出AAC数据流格式 + "pipe:1", # 将输出写入管道 + ], + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + out, _ = process.communicate(input=data.tobytes()) + audio_bytes.write(out) + + return audio_bytes + + +def read_clean_buffer(audio_bytes): + audio_chunk = audio_bytes.getvalue() + audio_bytes.truncate(0) + audio_bytes.seek(0) + + return audio_bytes, audio_chunk + + +def cut_text(text, punc): + punc_list = [p for p in punc if p in {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}] + if len(punc_list) > 0: + punds = r"[" + "".join(punc_list) + r"]" + text = text.strip("\n") + items = re.split(f"({punds})", text) + mergeitems = ["".join(group) for group in zip(items[::2], items[1::2])] + # 在句子不存在符号或句尾无符号的时候保证文本完整 + if len(items) % 2 == 1: + mergeitems.append(items[-1]) + text = "\n".join(mergeitems) + + while "\n\n" in text: + text = text.replace("\n\n", "\n") + + return text + + +def only_punc(text): + return not any(t.isalnum() or t.isalpha() for t in text) + + +splits = { + ",", + "。", + "?", + "!", + ",", + ".", + "?", + "!", + "~", + ":", + ":", + "—", + "…", +} + + +def get_tts_wav( + ref_wav_path, + prompt_text, + prompt_language, + text, + text_language, + top_k=15, + top_p=0.6, + temperature=0.6, + speed=1, + inp_refs=None, + sample_steps=32, + if_sr=False, + spk="default", +): + infer_sovits = speaker_list[spk].sovits + vq_model = infer_sovits.vq_model + hps = infer_sovits.hps + version = vq_model.version + + infer_gpt = speaker_list[spk].gpt + t2s_model = infer_gpt.t2s_model + max_sec = infer_gpt.max_sec + + if version == "v3": + if sample_steps not in [4, 8, 16, 32, 64, 128]: + sample_steps = 32 + elif version == "v4": + if sample_steps not in [4, 8, 16, 32]: + sample_steps = 8 + + if if_sr and version != "v3": + if_sr = False + + t0 = ttime() + prompt_text = prompt_text.strip("\n") + if prompt_text[-1] not in splits: + prompt_text += "。" if prompt_language != "en" else "." 
+ prompt_language, text = prompt_language, text.strip("\n") + dtype = torch.float16 if is_half == True else torch.float32 + zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32) + with torch.no_grad(): + wav16k, sr = librosa.load(ref_wav_path, sr=16000) + wav16k = torch.from_numpy(wav16k) + zero_wav_torch = torch.from_numpy(zero_wav) + if is_half == True: + wav16k = wav16k.half().to(device) + zero_wav_torch = zero_wav_torch.half().to(device) + else: + wav16k = wav16k.to(device) + zero_wav_torch = zero_wav_torch.to(device) + wav16k = torch.cat([wav16k, zero_wav_torch]) + ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float() + codes = vq_model.extract_latent(ssl_content) + prompt_semantic = codes[0, 0] + prompt = prompt_semantic.unsqueeze(0).to(device) + + is_v2pro = version in {"v2Pro", "v2ProPlus"} + if version not in {"v3", "v4"}: + refers = [] + if is_v2pro: + sv_emb = [] + if sv_cn_model == None: + init_sv_cn() + if inp_refs: + for path in inp_refs: + try: #####这里加上提取sv的逻辑,要么一堆sv一堆refer,要么单个sv单个refer + refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro) + refers.append(refer) + if is_v2pro: + sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor)) + except Exception as e: + logger.error(e) + if len(refers) == 0: + refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro) + refers = [refers] + if is_v2pro: + sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)] + else: + refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device) + + t1 = ttime() + # os.environ['version'] = version + prompt_language = dict_language[prompt_language.lower()] + text_language = dict_language[text_language.lower()] + phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version) + texts = text.split("\n") + audio_bytes = BytesIO() + + for text in texts: + # 简单防止纯符号引发参考音频泄露 + if only_punc(text): + continue + + audio_opt = [] + if text[-1] not in splits: + text += "。" if text_language != "en" else "." 
+ phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version) + bert = torch.cat([bert1, bert2], 1) + + all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0) + bert = bert.to(device).unsqueeze(0) + all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device) + t2 = ttime() + with torch.no_grad(): + pred_semantic, idx = t2s_model.model.infer_panel( + all_phoneme_ids, + all_phoneme_len, + prompt, + bert, + # prompt_phone_len=ph_offset, + top_k=top_k, + top_p=top_p, + temperature=temperature, + early_stop_num=hz * max_sec, + ) + pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) + t3 = ttime() + + if version not in {"v3", "v4"}: + if is_v2pro: + audio = ( + vq_model.decode( + pred_semantic, + torch.LongTensor(phones2).to(device).unsqueeze(0), + refers, + speed=speed, + sv_emb=sv_emb, + ) + .detach() + .cpu() + .numpy()[0, 0] + ) + else: + audio = ( + vq_model.decode( + pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed + ) + .detach() + .cpu() + .numpy()[0, 0] + ) + else: + phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) + phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0) + + fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer) + ref_audio, sr = torchaudio.load(ref_wav_path) + ref_audio = ref_audio.to(device).float() + if ref_audio.shape[0] == 2: + ref_audio = ref_audio.mean(0).unsqueeze(0) + + tgt_sr = 24000 if version == "v3" else 32000 + if sr != tgt_sr: + ref_audio = resample(ref_audio, sr, tgt_sr, device) + mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio) + mel2 = norm_spec(mel2) + T_min = min(mel2.shape[2], fea_ref.shape[2]) + mel2 = mel2[:, :, :T_min] + fea_ref = fea_ref[:, :, :T_min] + Tref = 468 if version == "v3" else 500 + Tchunk = 934 if version == "v3" else 1000 + if T_min > Tref: + mel2 = mel2[:, :, -Tref:] + fea_ref = fea_ref[:, :, -Tref:] + T_min = Tref + chunk_len = Tchunk - T_min + mel2 = mel2.to(dtype) + fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed) + cfm_resss = [] + idx = 0 + while 1: + fea_todo_chunk = fea_todo[:, :, idx : idx + chunk_len] + if fea_todo_chunk.shape[-1] == 0: + break + idx += chunk_len + fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) + cfm_res = vq_model.cfm.inference( + fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 + ) + cfm_res = cfm_res[:, :, mel2.shape[2] :] + mel2 = cfm_res[:, :, -T_min:] + fea_ref = fea_todo_chunk[:, :, -T_min:] + cfm_resss.append(cfm_res) + cfm_res = torch.cat(cfm_resss, 2) + cfm_res = denorm_spec(cfm_res) + if version == "v3": + if bigvgan_model == None: + init_bigvgan() + else: # v4 + if hifigan_model == None: + init_hifigan() + vocoder_model = bigvgan_model if version == "v3" else hifigan_model + with torch.inference_mode(): + wav_gen = vocoder_model(cfm_res) + audio = wav_gen[0][0].cpu().detach().numpy() + + max_audio = np.abs(audio).max() + if max_audio > 1: + audio /= max_audio + audio_opt.append(audio) + audio_opt.append(zero_wav) + audio_opt = np.concatenate(audio_opt, 0) + t4 = ttime() + + if version in {"v1", "v2", "v2Pro", "v2ProPlus"}: + sr = 32000 + elif version == "v3": + sr = 24000 + else: + sr = 48000 # v4 + + if if_sr and sr == 24000: + audio_opt = torch.from_numpy(audio_opt).float().to(device) + audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr) + max_audio = np.abs(audio_opt).max() + if max_audio > 1: + audio_opt /= max_audio + sr = 48000 + + if 
is_int32: + audio_bytes = pack_audio(audio_bytes, (audio_opt * 2147483647).astype(np.int32), sr) + else: + audio_bytes = pack_audio(audio_bytes, (audio_opt * 32768).astype(np.int16), sr) + # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3)) + if stream_mode == "normal": + audio_bytes, audio_chunk = read_clean_buffer(audio_bytes) + yield audio_chunk + + if not stream_mode == "normal": + if media_type == "wav": + if version in {"v1", "v2", "v2Pro", "v2ProPlus"}: + sr = 32000 + elif version == "v3": + sr = 48000 if if_sr else 24000 + else: + sr = 48000 # v4 + audio_bytes = pack_wav(audio_bytes, sr) + yield audio_bytes.getvalue() + + +def handle_control(command): + if command == "restart": + os.execl(g_config.python_exec, g_config.python_exec, *sys.argv) + elif command == "exit": + os.kill(os.getpid(), signal.SIGTERM) + exit(0) + + +def handle_change(path, text, language): + if is_empty(path, text, language): + return JSONResponse( + {"code": 400, "message": '缺少任意一项以下参数: "path", "text", "language"'}, status_code=400 + ) + + if path != "" or path is not None: + default_refer.path = path + if text != "" or text is not None: + default_refer.text = text + if language != "" or language is not None: + default_refer.language = language + + logger.info(f"当前默认参考音频路径: {default_refer.path}") + logger.info(f"当前默认参考音频文本: {default_refer.text}") + logger.info(f"当前默认参考音频语种: {default_refer.language}") + logger.info(f"is_ready: {default_refer.is_ready()}") + + return JSONResponse({"code": 0, "message": "Success"}, status_code=200) + + +def handle( + refer_wav_path, + prompt_text, + prompt_language, + text, + text_language, + cut_punc, + top_k, + top_p, + temperature, + speed, + inp_refs, + sample_steps, + if_sr, +): + if ( + refer_wav_path == "" + or refer_wav_path is None + or prompt_text == "" + or prompt_text is None + or prompt_language == "" + or prompt_language is None + ): + refer_wav_path, prompt_text, prompt_language = ( + default_refer.path, + default_refer.text, + default_refer.language, + ) + if not default_refer.is_ready(): + return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) + + if cut_punc == None: + text = cut_text(text, default_cut_punc) + else: + text = cut_text(text, cut_punc) + + return StreamingResponse( + get_tts_wav( + refer_wav_path, + prompt_text, + prompt_language, + text, + text_language, + top_k, + top_p, + temperature, + speed, + inp_refs, + sample_steps, + if_sr, + ), + media_type="audio/" + media_type, + ) + + +# -------------------------------- +# 初始化部分 +# -------------------------------- +dict_language = { + "中文": "all_zh", + "粤语": "all_yue", + "英文": "en", + "日文": "all_ja", + "韩文": "all_ko", + "中英混合": "zh", + "粤英混合": "yue", + "日英混合": "ja", + "韩英混合": "ko", + "多语种混合": "auto", # 多语种启动切分识别语种 + "多语种混合(粤语)": "auto_yue", + "all_zh": "all_zh", + "all_yue": "all_yue", + "en": "en", + "all_ja": "all_ja", + "all_ko": "all_ko", + "zh": "zh", + "yue": "yue", + "ja": "ja", + "ko": "ko", + "auto": "auto", + "auto_yue": "auto_yue", +} + +# logger +logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG) +logger = logging.getLogger("uvicorn") + +# 获取配置 +g_config = global_config.Config() + +# 获取参数 +parser = argparse.ArgumentParser(description="GPT-SoVITS api") + +parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径") +parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径") +parser.add_argument("-dr", "--default_refer_path", type=str, default="", 
help="默认参考音频路径") +parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本") +parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种") +parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="xpu / cpu") +parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0") +parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880") +parser.add_argument( + "-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度" +) +parser.add_argument( + "-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度" +) +# bool值的用法为 `python ./api.py -fp ...` +# 此时 full_precision==True, half_precision==False +parser.add_argument("-sm", "--stream_mode", type=str, default="close", help="流式返回模式, close / normal / keepalive") +parser.add_argument("-mt", "--media_type", type=str, default="wav", help="音频编码格式, wav / ogg / aac") +parser.add_argument("-st", "--sub_type", type=str, default="int16", help="音频数据类型, int16 / int32") +parser.add_argument("-cp", "--cut_punc", type=str, default="", help="文本切分符号设定, 符号范围,.;?!、,。?!;:…") +# 切割常用分句符为 `python ./api.py -cp ".?!。?!"` +parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path") +parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path") + +args = parser.parse_args() +sovits_path = args.sovits_path +gpt_path = args.gpt_path +device = args.device +port = args.port +host = args.bind_addr +cnhubert_base_path = args.hubert_path +bert_path = args.bert_path +default_cut_punc = args.cut_punc + +# 应用参数配置 +default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, args.default_refer_language) + +# 模型路径检查 +if sovits_path == "": + sovits_path = g_config.pretrained_sovits_path + logger.warning(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}") +if gpt_path == "": + gpt_path = g_config.pretrained_gpt_path + logger.warning(f"未指定GPT模型路径, fallback后当前值: {gpt_path}") + +# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用 +if default_refer.path == "" or default_refer.text == "" or default_refer.language == "": + default_refer.path, default_refer.text, default_refer.language = "", "", "" + logger.info("未指定默认参考音频") +else: + logger.info(f"默认参考音频路径: {default_refer.path}") + logger.info(f"默认参考音频文本: {default_refer.text}") + logger.info(f"默认参考音频语种: {default_refer.language}") + +# 获取半精度 +is_half = g_config.is_half +if args.full_precision: + is_half = False +if args.half_precision: + is_half = True +if args.full_precision and args.half_precision: + is_half = g_config.is_half # 炒饭fallback +logger.info(f"半精: {is_half}") + +# 流式返回模式 +if args.stream_mode.lower() in ["normal", "n"]: + stream_mode = "normal" + logger.info("流式返回已开启") +else: + stream_mode = "close" + +# 音频编码格式 +if args.media_type.lower() in ["aac", "ogg"]: + media_type = args.media_type.lower() +elif stream_mode == "close": + media_type = "wav" +else: + media_type = "ogg" +logger.info(f"编码格式: {media_type}") + +# 音频数据类型 +if args.sub_type.lower() == "int32": + is_int32 = True + logger.info("数据类型: int32") +else: + is_int32 = False + logger.info("数据类型: int16") + +# 初始化模型 +cnhubert.cnhubert_base_path = cnhubert_base_path +tokenizer = AutoTokenizer.from_pretrained(bert_path) +bert_model = AutoModelForMaskedLM.from_pretrained(bert_path) +ssl_model = cnhubert.get_model() +if is_half: + bert_model = 
bert_model.half().to(device) + ssl_model = ssl_model.half().to(device) +else: + bert_model = bert_model.to(device) + ssl_model = ssl_model.to(device) +change_gpt_sovits_weights(gpt_path=gpt_path, sovits_path=sovits_path) + + +# -------------------------------- +# 接口部分 +# -------------------------------- +app = FastAPI() + + +@app.post("/set_model") +async def set_model(request: Request): + json_post_raw = await request.json() + return change_gpt_sovits_weights( + gpt_path=json_post_raw.get("gpt_model_path"), sovits_path=json_post_raw.get("sovits_model_path") + ) + + +@app.get("/set_model") +async def set_model( + gpt_model_path: str = None, + sovits_model_path: str = None, +): + return change_gpt_sovits_weights(gpt_path=gpt_model_path, sovits_path=sovits_model_path) + + +@app.post("/control") +async def control(request: Request): + json_post_raw = await request.json() + return handle_control(json_post_raw.get("command")) + + +@app.get("/control") +async def control(command: str = None): + return handle_control(command) + + +@app.post("/change_refer") +async def change_refer(request: Request): + json_post_raw = await request.json() + return handle_change( + json_post_raw.get("refer_wav_path"), json_post_raw.get("prompt_text"), json_post_raw.get("prompt_language") + ) + + +@app.get("/change_refer") +async def change_refer(refer_wav_path: str = None, prompt_text: str = None, prompt_language: str = None): + return handle_change(refer_wav_path, prompt_text, prompt_language) + + +@app.post("/") +async def tts_endpoint(request: Request): + json_post_raw = await request.json() + return handle( + json_post_raw.get("refer_wav_path"), + json_post_raw.get("prompt_text"), + json_post_raw.get("prompt_language"), + json_post_raw.get("text"), + json_post_raw.get("text_language"), + json_post_raw.get("cut_punc"), + json_post_raw.get("top_k", 15), + json_post_raw.get("top_p", 1.0), + json_post_raw.get("temperature", 1.0), + json_post_raw.get("speed", 1.0), + json_post_raw.get("inp_refs", []), + json_post_raw.get("sample_steps", 32), + json_post_raw.get("if_sr", False), + ) + + +@app.get("/") +async def tts_endpoint( + refer_wav_path: str = None, + prompt_text: str = None, + prompt_language: str = None, + text: str = None, + text_language: str = None, + cut_punc: str = None, + top_k: int = 15, + top_p: float = 1.0, + temperature: float = 1.0, + speed: float = 1.0, + inp_refs: list = Query(default=[]), + sample_steps: int = 32, + if_sr: bool = False, +): + return handle( + refer_wav_path, + prompt_text, + prompt_language, + text, + text_language, + cut_punc, + top_k, + top_p, + temperature, + speed, + inp_refs, + sample_steps, + if_sr, + ) + + +if __name__ == "__main__": + uvicorn.run(app, host=host, port=port, workers=1) diff --git a/config.py b/config.py index fdc11c0a..eee1b21d 100644 --- a/config.py +++ b/config.py @@ -1,218 +1,224 @@ -import os -import re -import sys - -import torch - -from tools.i18n.i18n import I18nAuto - -i18n = I18nAuto(language=os.environ.get("language", "Auto")) - - -pretrained_sovits_name = { - "v1": "GPT_SoVITS/pretrained_models/s2G488k.pth", - "v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", - "v3": "GPT_SoVITS/pretrained_models/s2Gv3.pth", ###v3v4还要检查vocoder,算了。。。 - "v4": "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth", - "v2Pro": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth", - "v2ProPlus": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth", -} - -pretrained_gpt_name = { - "v1": 
"GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", - "v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", - "v3": "GPT_SoVITS/pretrained_models/s1v3.ckpt", - "v4": "GPT_SoVITS/pretrained_models/s1v3.ckpt", - "v2Pro": "GPT_SoVITS/pretrained_models/s1v3.ckpt", - "v2ProPlus": "GPT_SoVITS/pretrained_models/s1v3.ckpt", -} -name2sovits_path = { - # i18n("不训练直接推v1底模!"): "GPT_SoVITS/pretrained_models/s2G488k.pth", - i18n("不训练直接推v2底模!"): "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", - # i18n("不训练直接推v3底模!"): "GPT_SoVITS/pretrained_models/s2Gv3.pth", - # i18n("不训练直接推v4底模!"): "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth", - i18n("不训练直接推v2Pro底模!"): "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth", - i18n("不训练直接推v2ProPlus底模!"): "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth", -} -name2gpt_path = { - # i18n("不训练直接推v1底模!"):"GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", - i18n( - "不训练直接推v2底模!" - ): "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", - i18n("不训练直接推v3底模!"): "GPT_SoVITS/pretrained_models/s1v3.ckpt", -} -SoVITS_weight_root = [ - "SoVITS_weights", - "SoVITS_weights_v2", - "SoVITS_weights_v3", - "SoVITS_weights_v4", - "SoVITS_weights_v2Pro", - "SoVITS_weights_v2ProPlus", -] -GPT_weight_root = [ - "GPT_weights", - "GPT_weights_v2", - "GPT_weights_v3", - "GPT_weights_v4", - "GPT_weights_v2Pro", - "GPT_weights_v2ProPlus", -] -SoVITS_weight_version2root = { - "v1": "SoVITS_weights", - "v2": "SoVITS_weights_v2", - "v3": "SoVITS_weights_v3", - "v4": "SoVITS_weights_v4", - "v2Pro": "SoVITS_weights_v2Pro", - "v2ProPlus": "SoVITS_weights_v2ProPlus", -} -GPT_weight_version2root = { - "v1": "GPT_weights", - "v2": "GPT_weights_v2", - "v3": "GPT_weights_v3", - "v4": "GPT_weights_v4", - "v2Pro": "GPT_weights_v2Pro", - "v2ProPlus": "GPT_weights_v2ProPlus", -} - - -def custom_sort_key(s): - # 使用正则表达式提取字符串中的数字部分和非数字部分 - parts = re.split("(\d+)", s) - # 将数字部分转换为整数,非数字部分保持不变 - parts = [int(part) if part.isdigit() else part for part in parts] - return parts - - -def get_weights_names(): - SoVITS_names = [] - for key in name2sovits_path: - if os.path.exists(name2sovits_path[key]): - SoVITS_names.append(key) - for path in SoVITS_weight_root: - if not os.path.exists(path): - continue - for name in os.listdir(path): - if name.endswith(".pth"): - SoVITS_names.append("%s/%s" % (path, name)) - if not SoVITS_names: - SoVITS_names = [""] - GPT_names = [] - for key in name2gpt_path: - if os.path.exists(name2gpt_path[key]): - GPT_names.append(key) - for path in GPT_weight_root: - if not os.path.exists(path): - continue - for name in os.listdir(path): - if name.endswith(".ckpt"): - GPT_names.append("%s/%s" % (path, name)) - SoVITS_names = sorted(SoVITS_names, key=custom_sort_key) - GPT_names = sorted(GPT_names, key=custom_sort_key) - if not GPT_names: - GPT_names = [""] - return SoVITS_names, GPT_names - - -def change_choices(): - SoVITS_names, GPT_names = get_weights_names() - return {"choices": SoVITS_names, "__type__": "update"}, { - "choices": GPT_names, - "__type__": "update", - } - - -# 推理用的指定模型 -sovits_path = "" -gpt_path = "" -is_half_str = os.environ.get("is_half", "True") -is_half = True if is_half_str.lower() == "true" else False -is_share_str = os.environ.get("is_share", "False") -is_share = True if is_share_str.lower() == "true" else False - -cnhubert_path = 
"GPT_SoVITS/pretrained_models/chinese-hubert-base" -bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large" -pretrained_sovits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth" -pretrained_gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt" - -exp_root = "logs" -python_exec = sys.executable or "python" - -webui_port_main = 9874 -webui_port_uvr5 = 9873 -webui_port_infer_tts = 9872 -webui_port_subfix = 9871 - -api_port = 9880 - - -# Thanks to the contribution of @Karasukaigan and @XXXXRT666 -def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]: - cpu = torch.device("cpu") - cuda = torch.device(f"cuda:{idx}") - if not torch.cuda.is_available(): - return cpu, torch.float32, 0.0, 0.0 - device_idx = idx - capability = torch.cuda.get_device_capability(device_idx) - name = torch.cuda.get_device_name(device_idx) - mem_bytes = torch.cuda.get_device_properties(device_idx).total_memory - mem_gb = mem_bytes / (1024**3) + 0.4 - major, minor = capability - sm_version = major + minor / 10.0 - is_16_series = bool(re.search(r"16\d{2}", name)) and sm_version == 7.5 - if mem_gb < 4 or sm_version < 5.3: - return cpu, torch.float32, 0.0, 0.0 - if sm_version == 6.1 or is_16_series == True: - return cuda, torch.float32, sm_version, mem_gb - if sm_version > 6.1: - return cuda, torch.float16, sm_version, mem_gb - return cpu, torch.float32, 0.0, 0.0 - - -IS_GPU = True -GPU_INFOS: list[str] = [] -GPU_INDEX: set[int] = set() -GPU_COUNT = torch.cuda.device_count() -CPU_INFO: str = "0\tCPU " + i18n("CPU训练,较慢") -tmp: list[tuple[torch.device, torch.dtype, float, float]] = [] -memset: set[float] = set() - -for i in range(max(GPU_COUNT, 1)): - tmp.append(get_device_dtype_sm(i)) - -for j in tmp: - device = j[0] - memset.add(j[3]) - if device.type != "cpu": - GPU_INFOS.append(f"{device.index}\t{torch.cuda.get_device_name(device.index)}") - GPU_INDEX.add(device.index) - -if not GPU_INFOS: - IS_GPU = False - GPU_INFOS.append(CPU_INFO) - GPU_INDEX.add(0) - -infer_device = max(tmp, key=lambda x: (x[2], x[3]))[0] -is_half = any(dtype == torch.float16 for _, dtype, _, _ in tmp) - - -class Config: - def __init__(self): - self.sovits_path = sovits_path - self.gpt_path = gpt_path - self.is_half = is_half - - self.cnhubert_path = cnhubert_path - self.bert_path = bert_path - self.pretrained_sovits_path = pretrained_sovits_path - self.pretrained_gpt_path = pretrained_gpt_path - - self.exp_root = exp_root - self.python_exec = python_exec - self.infer_device = infer_device - - self.webui_port_main = webui_port_main - self.webui_port_uvr5 = webui_port_uvr5 - self.webui_port_infer_tts = webui_port_infer_tts - self.webui_port_subfix = webui_port_subfix - - self.api_port = api_port +import os +import re +import sys + +import torch + +from tools.i18n.i18n import I18nAuto + +i18n = I18nAuto(language=os.environ.get("language", "Auto")) + + +pretrained_sovits_name = { + "v1": "GPT_SoVITS/pretrained_models/s2G488k.pth", + "v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", + "v3": "GPT_SoVITS/pretrained_models/s2Gv3.pth", ###v3v4还要检查vocoder,算了。。。 + "v4": "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth", + "v2Pro": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth", + "v2ProPlus": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth", +} + +pretrained_gpt_name = { + "v1": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", + "v2": 
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", + "v3": "GPT_SoVITS/pretrained_models/s1v3.ckpt", + "v4": "GPT_SoVITS/pretrained_models/s1v3.ckpt", + "v2Pro": "GPT_SoVITS/pretrained_models/s1v3.ckpt", + "v2ProPlus": "GPT_SoVITS/pretrained_models/s1v3.ckpt", +} +name2sovits_path = { + # i18n("不训练直接推v1底模!"): "GPT_SoVITS/pretrained_models/s2G488k.pth", + i18n("不训练直接推v2底模!"): "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", + # i18n("不训练直接推v3底模!"): "GPT_SoVITS/pretrained_models/s2Gv3.pth", + # i18n("不训练直接推v4底模!"): "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth", + i18n("不训练直接推v2Pro底模!"): "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth", + i18n("不训练直接推v2ProPlus底模!"): "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth", +} +name2gpt_path = { + # i18n("不训练直接推v1底模!"):"GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt", + i18n( + "不训练直接推v2底模!" + ): "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", + i18n("不训练直接推v3底模!"): "GPT_SoVITS/pretrained_models/s1v3.ckpt", +} +SoVITS_weight_root = [ + "SoVITS_weights", + "SoVITS_weights_v2", + "SoVITS_weights_v3", + "SoVITS_weights_v4", + "SoVITS_weights_v2Pro", + "SoVITS_weights_v2ProPlus", +] +GPT_weight_root = [ + "GPT_weights", + "GPT_weights_v2", + "GPT_weights_v3", + "GPT_weights_v4", + "GPT_weights_v2Pro", + "GPT_weights_v2ProPlus", +] +SoVITS_weight_version2root = { + "v1": "SoVITS_weights", + "v2": "SoVITS_weights_v2", + "v3": "SoVITS_weights_v3", + "v4": "SoVITS_weights_v4", + "v2Pro": "SoVITS_weights_v2Pro", + "v2ProPlus": "SoVITS_weights_v2ProPlus", +} +GPT_weight_version2root = { + "v1": "GPT_weights", + "v2": "GPT_weights_v2", + "v3": "GPT_weights_v3", + "v4": "GPT_weights_v4", + "v2Pro": "GPT_weights_v2Pro", + "v2ProPlus": "GPT_weights_v2ProPlus", +} + + +def custom_sort_key(s): + # 使用正则表达式提取字符串中的数字部分和非数字部分 + parts = re.split("(\d+)", s) + # 将数字部分转换为整数,非数字部分保持不变 + parts = [int(part) if part.isdigit() else part for part in parts] + return parts + + +def get_weights_names(): + SoVITS_names = [] + for key in name2sovits_path: + if os.path.exists(name2sovits_path[key]): + SoVITS_names.append(key) + for path in SoVITS_weight_root: + if not os.path.exists(path): + continue + for name in os.listdir(path): + if name.endswith(".pth"): + SoVITS_names.append("%s/%s" % (path, name)) + if not SoVITS_names: + SoVITS_names = [""] + GPT_names = [] + for key in name2gpt_path: + if os.path.exists(name2gpt_path[key]): + GPT_names.append(key) + for path in GPT_weight_root: + if not os.path.exists(path): + continue + for name in os.listdir(path): + if name.endswith(".ckpt"): + GPT_names.append("%s/%s" % (path, name)) + SoVITS_names = sorted(SoVITS_names, key=custom_sort_key) + GPT_names = sorted(GPT_names, key=custom_sort_key) + if not GPT_names: + GPT_names = [""] + return SoVITS_names, GPT_names + + +def change_choices(): + SoVITS_names, GPT_names = get_weights_names() + return {"choices": SoVITS_names, "__type__": "update"}, { + "choices": GPT_names, + "__type__": "update", + } + + +# 推理用的指定模型 +sovits_path = "" +gpt_path = "" +is_half_str = os.environ.get("is_half", "True") +is_half = True if is_half_str.lower() == "true" else False +is_share_str = os.environ.get("is_share", "False") +is_share = True if is_share_str.lower() == "true" else False + +cnhubert_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base" +bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large" 
+pretrained_sovits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth" +pretrained_gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt" + +exp_root = "logs" +python_exec = sys.executable or "python" + +webui_port_main = 9874 +webui_port_uvr5 = 9873 +webui_port_infer_tts = 9872 +webui_port_subfix = 9871 + +api_port = 9880 + + +# Thanks to the contribution of @Karasukaigan and @XXXXRT666 +# Modified for Intel GPU (XPU) +def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]: + cpu = torch.device("cpu") + try: + if not torch.xpu.is_available(): + return cpu, torch.float32, 0.0, 0.0 + except AttributeError: + return cpu, torch.float32, 0.0, 0.0 + + xpu_device = torch.device(f"xpu:{idx}") + properties = torch.xpu.get_device_properties(idx) + mem_bytes = properties.total_memory + mem_gb = mem_bytes / (1024**3) + + # Simplified logic for XPU, assuming FP16/BF16 is generally supported. + # The complex SM version check is CUDA-specific. + if mem_gb < 4: # Example threshold + return cpu, torch.float32, 0.0, 0.0 + + # For Intel GPUs, we can generally assume float16 is available. + # The 'sm_version' equivalent is not straightforward, so we use a placeholder value (e.g., 1.0) + # for compatibility with the downstream logic that sorts devices. + return xpu_device, torch.float16, 1.0, mem_gb + + +IS_GPU = True +GPU_INFOS: list[str] = [] +GPU_INDEX: set[int] = set() +try: + GPU_COUNT = torch.xpu.device_count() if torch.xpu.is_available() else 0 +except AttributeError: + GPU_COUNT = 0 +CPU_INFO: str = "0\tCPU " + i18n("CPU训练,较慢") +tmp: list[tuple[torch.device, torch.dtype, float, float]] = [] +memset: set[float] = set() + +for i in range(max(GPU_COUNT, 1)): + tmp.append(get_device_dtype_sm(i)) + +for j in tmp: + device = j[0] + memset.add(j[3]) + if device.type == "xpu": + GPU_INFOS.append(f"{device.index}\t{torch.xpu.get_device_name(device.index)}") + GPU_INDEX.add(device.index) + +if not GPU_INFOS: + IS_GPU = False + GPU_INFOS.append(CPU_INFO) + GPU_INDEX.add(0) + +infer_device = max(tmp, key=lambda x: (x[2], x[3]))[0] +is_half = any(dtype == torch.float16 for _, dtype, _, _ in tmp) + + +class Config: + def __init__(self): + self.sovits_path = sovits_path + self.gpt_path = gpt_path + self.is_half = is_half + + self.cnhubert_path = cnhubert_path + self.bert_path = bert_path + self.pretrained_sovits_path = pretrained_sovits_path + self.pretrained_gpt_path = pretrained_gpt_path + + self.exp_root = exp_root + self.python_exec = python_exec + self.infer_device = infer_device + + self.webui_port_main = webui_port_main + self.webui_port_uvr5 = webui_port_uvr5 + self.webui_port_infer_tts = webui_port_infer_tts + self.webui_port_subfix = webui_port_subfix + + self.api_port = api_port diff --git a/docker_build.sh b/docker_build.sh index 354599d2..30035267 100644 --- a/docker_build.sh +++ b/docker_build.sh @@ -14,49 +14,29 @@ fi trap 'echo "Error Occured at \"$BASH_COMMAND\" with exit code $?"; exit 1' ERR LITE=false -CUDA_VERSION=12.6 print_help() { echo "Usage: bash docker_build.sh [OPTIONS]" echo "" echo "Options:" - echo " --cuda 12.6|12.8 Specify the CUDA VERSION (REQUIRED)" echo " --lite Build a Lite Image" echo " -h, --help Show this help message and exit" echo "" echo "Examples:" - echo " bash docker_build.sh --cuda 12.6 --funasr --faster-whisper" + echo " bash docker_build.sh --lite" } -# Show help if no arguments provided -if [[ $# -eq 0 ]]; then - print_help - exit 0 -fi - # Parse arguments while [[ $# -gt 0 ]]; do case "$1" in - --cuda) - 
case "$2" in - 12.6) - CUDA_VERSION=12.6 - ;; - 12.8) - CUDA_VERSION=12.8 - ;; - *) - echo "Error: Invalid CUDA_VERSION: $2" - echo "Choose From: [12.6, 12.8]" - exit 1 - ;; - esac - shift 2 - ;; --lite) LITE=true shift ;; + -h|--help) + print_help + exit 0 + ;; *) echo "Unknown Argument: $1" echo "Use -h or --help to see available options." @@ -74,7 +54,6 @@ else fi docker build \ - --build-arg CUDA_VERSION=$CUDA_VERSION \ --build-arg LITE=$LITE \ --build-arg TARGETPLATFORM="$TARGETPLATFORM" \ --build-arg TORCH_BASE=$TORCH_BASE \ diff --git a/requirements.txt b/requirements.txt index 90e4957d..39cd7293 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,9 @@ tensorboard librosa==0.10.2 numba pytorch-lightning>=2.4 +torch==2.9 +intel-extension-for-pytorch +torchvision gradio<5 ffmpeg-python onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64" diff --git a/tools/asr/fasterwhisper_asr.py b/tools/asr/fasterwhisper_asr.py index a2ebe975..30329371 100644 --- a/tools/asr/fasterwhisper_asr.py +++ b/tools/asr/fasterwhisper_asr.py @@ -85,7 +85,7 @@ def execute_asr(input_folder, output_folder, model_path, language, precision): if language == "auto": language = None # 不设置语种由模型自动输出概率最高的语种 print("loading faster whisper model:", model_path, model_path) - device = "cuda" if torch.cuda.is_available() else "cpu" + device = "xpu" if torch.xpu.is_available() else "cpu" model = WhisperModel(model_path, device=device, compute_type=precision) input_file_names = os.listdir(input_folder) @@ -128,8 +128,6 @@ def execute_asr(input_folder, output_folder, model_path, language, precision): return output_file_path -load_cudnn() - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( diff --git a/tools/my_utils.py b/tools/my_utils.py index 04f1a98a..b66878f6 100644 --- a/tools/my_utils.py +++ b/tools/my_utils.py @@ -1,231 +1,137 @@ -import ctypes -import os -import sys -from pathlib import Path - -import ffmpeg -import gradio as gr -import numpy as np -import pandas as pd - -from tools.i18n.i18n import I18nAuto - -i18n = I18nAuto(language=os.environ.get("language", "Auto")) - - -def load_audio(file, sr): - try: - # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26 - # This launches a subprocess to decode audio while down-mixing and resampling as necessary. - # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. 
- file = clean_path(file) # 防止小白拷路径头尾带了空格和"和回车 - if os.path.exists(file) is False: - raise RuntimeError("You input a wrong audio path that does not exists, please fix it!") - out, _ = ( - ffmpeg.input(file, threads=0) - .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr) - .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) - ) - except Exception: - out, _ = ( - ffmpeg.input(file, threads=0) - .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr) - .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True) - ) # Expose the Error - raise RuntimeError(i18n("音频加载失败")) - - return np.frombuffer(out, np.float32).flatten() - - -def clean_path(path_str: str): - if path_str.endswith(("\\", "/")): - return clean_path(path_str[0:-1]) - path_str = path_str.replace("/", os.sep).replace("\\", os.sep) - return path_str.strip( - " '\n\"\u202a" - ) # path_str.strip(" ").strip('\'').strip("\n").strip('"').strip(" ").strip("\u202a") - - -def check_for_existance(file_list: list = None, is_train=False, is_dataset_processing=False): - files_status = [] - if is_train == True and file_list: - file_list.append(os.path.join(file_list[0], "2-name2text.txt")) - file_list.append(os.path.join(file_list[0], "3-bert")) - file_list.append(os.path.join(file_list[0], "4-cnhubert")) - file_list.append(os.path.join(file_list[0], "5-wav32k")) - file_list.append(os.path.join(file_list[0], "6-name2semantic.tsv")) - for file in file_list: - if os.path.exists(file): - files_status.append(True) - else: - files_status.append(False) - if sum(files_status) != len(files_status): - if is_train: - for file, status in zip(file_list, files_status): - if status: - pass - else: - gr.Warning(file) - gr.Warning(i18n("以下文件或文件夹不存在")) - return False - elif is_dataset_processing: - if files_status[0]: - return True - elif not files_status[0]: - gr.Warning(file_list[0]) - elif not files_status[1] and file_list[1]: - gr.Warning(file_list[1]) - gr.Warning(i18n("以下文件或文件夹不存在")) - return False - else: - if file_list[0]: - gr.Warning(file_list[0]) - gr.Warning(i18n("以下文件或文件夹不存在")) - else: - gr.Warning(i18n("路径不能为空")) - return False - return True - - -def check_details(path_list=None, is_train=False, is_dataset_processing=False): - if is_dataset_processing: - list_path, audio_path = path_list - if not list_path.endswith(".list"): - gr.Warning(i18n("请填入正确的List路径")) - return - if audio_path: - if not os.path.isdir(audio_path): - gr.Warning(i18n("请填入正确的音频文件夹路径")) - return - with open(list_path, "r", encoding="utf8") as f: - line = f.readline().strip("\n").split("\n") - wav_name, _, __, ___ = line[0].split("|") - wav_name = clean_path(wav_name) - if audio_path != "" and audio_path != None: - wav_name = os.path.basename(wav_name) - wav_path = "%s/%s" % (audio_path, wav_name) - else: - wav_path = wav_name - if os.path.exists(wav_path): - ... - else: - gr.Warning(wav_path + i18n("路径错误")) - return - if is_train: - path_list.append(os.path.join(path_list[0], "2-name2text.txt")) - path_list.append(os.path.join(path_list[0], "4-cnhubert")) - path_list.append(os.path.join(path_list[0], "5-wav32k")) - path_list.append(os.path.join(path_list[0], "6-name2semantic.tsv")) - phone_path, hubert_path, wav_path, semantic_path = path_list[1:] - with open(phone_path, "r", encoding="utf-8") as f: - if f.read(1): - ... - else: - gr.Warning(i18n("缺少音素数据集")) - if os.listdir(hubert_path): - ... - else: - gr.Warning(i18n("缺少Hubert数据集")) - if os.listdir(wav_path): - ... 
- else: - gr.Warning(i18n("缺少音频数据集")) - df = pd.read_csv(semantic_path, delimiter="\t", encoding="utf-8") - if len(df) >= 1: - ... - else: - gr.Warning(i18n("缺少语义数据集")) - - -def load_cudnn(): - import torch - - if not torch.cuda.is_available(): - print("[INFO] CUDA is not available, skipping cuDNN setup.") - return - - if sys.platform == "win32": - torch_lib_dir = Path(torch.__file__).parent / "lib" - if torch_lib_dir.exists(): - os.add_dll_directory(str(torch_lib_dir)) - print(f"[INFO] Added DLL directory: {torch_lib_dir}") - matching_files = sorted(torch_lib_dir.glob("cudnn_cnn*.dll")) - if not matching_files: - print(f"[ERROR] No cudnn_cnn*.dll found in {torch_lib_dir}") - return - for dll_path in matching_files: - dll_name = os.path.basename(dll_path) - try: - ctypes.CDLL(dll_name) - print(f"[INFO] Loaded: {dll_name}") - except OSError as e: - print(f"[WARNING] Failed to load {dll_name}: {e}") - else: - print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}") - - elif sys.platform == "linux": - site_packages = Path(torch.__file__).resolve().parents[1] - cudnn_dir = site_packages / "nvidia" / "cudnn" / "lib" - - if not cudnn_dir.exists(): - print(f"[ERROR] cudnn dir not found: {cudnn_dir}") - return - - matching_files = sorted(cudnn_dir.glob("libcudnn_cnn*.so*")) - if not matching_files: - print(f"[ERROR] No libcudnn_cnn*.so* found in {cudnn_dir}") - return - - for so_path in matching_files: - try: - ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) # type: ignore - print(f"[INFO] Loaded: {so_path}") - except OSError as e: - print(f"[WARNING] Failed to load {so_path}: {e}") - - -def load_nvrtc(): - import torch - - if not torch.cuda.is_available(): - print("[INFO] CUDA is not available, skipping nvrtc setup.") - return - - if sys.platform == "win32": - torch_lib_dir = Path(torch.__file__).parent / "lib" - if torch_lib_dir.exists(): - os.add_dll_directory(str(torch_lib_dir)) - print(f"[INFO] Added DLL directory: {torch_lib_dir}") - matching_files = sorted(torch_lib_dir.glob("nvrtc*.dll")) - if not matching_files: - print(f"[ERROR] No nvrtc*.dll found in {torch_lib_dir}") - return - for dll_path in matching_files: - dll_name = os.path.basename(dll_path) - try: - ctypes.CDLL(dll_name) - print(f"[INFO] Loaded: {dll_name}") - except OSError as e: - print(f"[WARNING] Failed to load {dll_name}: {e}") - else: - print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}") - - elif sys.platform == "linux": - site_packages = Path(torch.__file__).resolve().parents[1] - nvrtc_dir = site_packages / "nvidia" / "cuda_nvrtc" / "lib" - - if not nvrtc_dir.exists(): - print(f"[ERROR] nvrtc dir not found: {nvrtc_dir}") - return - - matching_files = sorted(nvrtc_dir.glob("libnvrtc*.so*")) - if not matching_files: - print(f"[ERROR] No libnvrtc*.so* found in {nvrtc_dir}") - return - - for so_path in matching_files: - try: - ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) # type: ignore - print(f"[INFO] Loaded: {so_path}") - except OSError as e: - print(f"[WARNING] Failed to load {so_path}: {e}") +import ctypes +import os +import sys +from pathlib import Path + +import ffmpeg +import gradio as gr +import numpy as np +import pandas as pd + +from tools.i18n.i18n import I18nAuto + +i18n = I18nAuto(language=os.environ.get("language", "Auto")) + + +def load_audio(file, sr): + try: + # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26 + # This launches a subprocess to decode audio while down-mixing and resampling as necessary. 
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. + file = clean_path(file) # 防止小白拷路径头尾带了空格和"和回车 + if os.path.exists(file) is False: + raise RuntimeError("You input a wrong audio path that does not exists, please fix it!") + out, _ = ( + ffmpeg.input(file, threads=0) + .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr) + .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) + ) + except Exception: + out, _ = ( + ffmpeg.input(file, threads=0) + .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr) + .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True) + ) # Expose the Error + raise RuntimeError(i18n("音频加载失败")) + + return np.frombuffer(out, np.float32).flatten() + + +def clean_path(path_str: str): + if path_str.endswith(("\\", "/")): + return clean_path(path_str[0:-1]) + path_str = path_str.replace("/", os.sep).replace("\\", os.sep) + return path_str.strip( + " '\n\"\u202a" + ) # path_str.strip(" ").strip('\'').strip("\n").strip('"').strip(" ").strip("\u202a") + + +def check_for_existance(file_list: list = None, is_train=False, is_dataset_processing=False): + files_status = [] + if is_train == True and file_list: + file_list.append(os.path.join(file_list[0], "2-name2text.txt")) + file_list.append(os.path.join(file_list[0], "3-bert")) + file_list.append(os.path.join(file_list[0], "4-cnhubert")) + file_list.append(os.path.join(file_list[0], "5-wav32k")) + file_list.append(os.path.join(file_list[0], "6-name2semantic.tsv")) + for file in file_list: + if os.path.exists(file): + files_status.append(True) + else: + files_status.append(False) + if sum(files_status) != len(files_status): + if is_train: + for file, status in zip(file_list, files_status): + if status: + pass + else: + gr.Warning(file) + gr.Warning(i18n("以下文件或文件夹不存在")) + return False + elif is_dataset_processing: + if files_status[0]: + return True + elif not files_status[0]: + gr.Warning(file_list[0]) + elif not files_status[1] and file_list[1]: + gr.Warning(file_list[1]) + gr.Warning(i18n("以下文件或文件夹不存在")) + return False + else: + if file_list[0]: + gr.Warning(file_list[0]) + gr.Warning(i18n("以下文件或文件夹不存在")) + else: + gr.Warning(i18n("路径不能为空")) + return False + return True + + +def check_details(path_list=None, is_train=False, is_dataset_processing=False): + if is_dataset_processing: + list_path, audio_path = path_list + if not list_path.endswith(".list"): + gr.Warning(i18n("请填入正确的List路径")) + return + if audio_path: + if not os.path.isdir(audio_path): + gr.Warning(i18n("请填入正确的音频文件夹路径")) + return + with open(list_path, "r", encoding="utf8") as f: + line = f.readline().strip("\n").split("\n") + wav_name, _, __, ___ = line[0].split("|") + wav_name = clean_path(wav_name) + if audio_path != "" and audio_path != None: + wav_name = os.path.basename(wav_name) + wav_path = "%s/%s" % (audio_path, wav_name) + else: + wav_path = wav_name + if os.path.exists(wav_path): + ... + else: + gr.Warning(wav_path + i18n("路径错误")) + return + if is_train: + path_list.append(os.path.join(path_list[0], "2-name2text.txt")) + path_list.append(os.path.join(path_list[0], "4-cnhubert")) + path_list.append(os.path.join(path_list[0], "5-wav32k")) + path_list.append(os.path.join(path_list[0], "6-name2semantic.tsv")) + phone_path, hubert_path, wav_path, semantic_path = path_list[1:] + with open(phone_path, "r", encoding="utf-8") as f: + if f.read(1): + ... + else: + gr.Warning(i18n("缺少音素数据集")) + if os.listdir(hubert_path): + ... + else: + gr.Warning(i18n("缺少Hubert数据集")) + if os.listdir(wav_path): + ... 
+ else: + gr.Warning(i18n("缺少音频数据集")) + df = pd.read_csv(semantic_path, delimiter="\t", encoding="utf-8") + if len(df) >= 1: + ... + else: + gr.Warning(i18n("缺少语义数据集"))
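
Note on verifying the new XPU path: the snippet below is a minimal sanity check, not part of the patch. It uses only the torch.xpu calls the patched config.py already relies on (is_available, device_count, get_device_name, get_device_properties) together with the torch==2.9 / intel-extension-for-pytorch pins added to requirements.txt; the optional IPEX import and the fp16 matmul are illustrative assumptions, not prescribed by this change.

import torch

try:
    # Optional on PyTorch builds with native XPU support; some stacks still need IPEX
    # to register the backend (assumption: package installed per requirements.txt).
    import intel_extension_for_pytorch  # noqa: F401
except ImportError:
    pass

if torch.xpu.is_available():
    # Enumerate devices the same way config.py does when building GPU_INFOS.
    for idx in range(torch.xpu.device_count()):
        props = torch.xpu.get_device_properties(idx)
        print(f"xpu:{idx}  {torch.xpu.get_device_name(idx)}  "
              f"{props.total_memory / (1024 ** 3):.1f} GiB")
    # Quick fp16 compute test on the first device, mirroring the is_half=True path.
    x = torch.randn(64, 64, device="xpu:0", dtype=torch.float16)
    print("fp16 matmul ok:", float((x @ x).float().mean()))
else:
    print("No XPU detected; config.py falls back to CPU / float32.")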