Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-12-16 09:16:59 +08:00)
feat: Migrate from CUDA to XPU for Intel GPU support
This commit migrates the project from NVIDIA CUDA to Intel XPU for GPU acceleration, based on the PyTorch 2.9 release. Key changes include:

- Replaced `torch.cuda` with `torch.xpu` for device checks, memory management, and distributed training.
- Updated device strings from "cuda" to "xpu" across the codebase.
- Switched the distributed training backend from "nccl" to "ccl" for Intel GPUs.
- Disabled custom CUDA kernels in the `BigVGAN` module by setting `use_cuda_kernel=False`.
- Updated `requirements.txt` to include `torch==2.9` and `intel-extension-for-pytorch`.
- Modified CI/CD pipelines and build scripts to remove CUDA dependencies and build for an XPU target.
parent 11aa78bd9b
commit d3b8f7e09e
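The sketch below illustrates, outside of any particular file, the substitution pattern this commit applies everywhere: probe for an XPU instead of CUDA, seed through torch.xpu, and pick "ccl" instead of "nccl" as the distributed backend. It is a minimal illustrative example, not code from this repository; the helper names (pick_device, seed_everything, pick_dist_backend) are hypothetical, and it assumes a PyTorch 2.9 build with XPU support, with intel-extension-for-pytorch and Intel's oneCCL bindings for PyTorch installed wherever the "xpu" device and the "ccl" backend are actually exercised.

# Minimal sketch of the CUDA -> XPU substitution pattern used throughout this commit.
# Assumptions: PyTorch 2.9 built with XPU support; intel-extension-for-pytorch and
# Intel's oneCCL bindings for PyTorch are installed where the "xpu" device and the
# "ccl" distributed backend are actually used. Helper names are hypothetical.
import platform

import torch

try:
    # Registers Intel GPU optimizations if present; harmless to skip otherwise.
    import intel_extension_for_pytorch  # noqa: F401
except ImportError:
    pass


def pick_device() -> torch.device:
    # Before the migration: torch.cuda.is_available() / torch.device("cuda")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")


def seed_everything(seed: int) -> None:
    torch.manual_seed(seed)
    # Before the migration: torch.cuda.manual_seed(seed)
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.manual_seed(seed)


def pick_dist_backend() -> str:
    # Before the migration: "nccl" on Linux with CUDA, "gloo" on Windows.
    # "ccl" assumes the oneCCL bindings are importable by torch.distributed.
    if platform.system() == "Windows" or not (hasattr(torch, "xpu") and torch.xpu.is_available()):
        return "gloo"
    return "ccl"


if __name__ == "__main__":
    device = pick_device()
    seed_everything(1234)
    x = torch.randn(2, 3, device=device)
    print(device, pick_dist_backend(), tuple(x.shape))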
@@ -15,11 +15,7 @@ on:
jobs:
  build:
    runs-on: windows-latest
-    strategy:
-      matrix:
-        torch_cuda: [cu124, cu128]
    env:
-      TORCH_CUDA: ${{ matrix.torch_cuda }}
      MODELSCOPE_USERNAME: ${{ secrets.MODELSCOPE_USERNAME }}
      MODELSCOPE_TOKEN: ${{ secrets.MODELSCOPE_TOKEN }}
      HUGGINGFACE_USERNAME: ${{ secrets.HUGGINGFACE_USERNAME }}
.github/workflows/docker-publish.yaml (vendored): 234 lines changed
@@ -18,68 +18,22 @@ jobs:
          DATE=$(date +'%Y%m%d')
          COMMIT=$(git rev-parse --short=6 HEAD)
          echo "tag=${DATE}-${COMMIT}" >> $GITHUB_OUTPUT

-  build-amd64:
+  build-and-publish:
    needs: generate-meta
    runs-on: ubuntu-22.04
    strategy:
      matrix:
        include:
-          - cuda_version: 12.6
-            lite: true
-            torch_base: lite
-            tag_prefix: cu126-lite
-          - cuda_version: 12.6
-            lite: false
-            torch_base: full
-            tag_prefix: cu126
-          - cuda_version: 12.8
-            lite: true
-            torch_base: lite
-            tag_prefix: cu128-lite
-          - cuda_version: 12.8
-            lite: false
-            torch_base: full
-            tag_prefix: cu128
+          - lite: true
+            torch_base: lite
+            tag_prefix: xpu-lite
+          - lite: false
+            torch_base: full
+            tag_prefix: xpu

    steps:
      - name: Checkout Code
        uses: actions/checkout@v4

-      - name: Free up disk space
-        run: |
-          echo "Before cleanup:"
-          df -h
-
-          sudo rm -rf /opt/ghc
-          sudo rm -rf /opt/hostedtoolcache/CodeQL
-          sudo rm -rf /opt/hostedtoolcache/PyPy
-          sudo rm -rf /opt/hostedtoolcache/go
-          sudo rm -rf /opt/hostedtoolcache/node
-          sudo rm -rf /opt/hostedtoolcache/Ruby
-          sudo rm -rf /opt/microsoft
-          sudo rm -rf /opt/pipx
-          sudo rm -rf /opt/az
-          sudo rm -rf /opt/google
-
-          sudo rm -rf /usr/lib/jvm
-          sudo rm -rf /usr/lib/google-cloud-sdk
-          sudo rm -rf /usr/lib/dotnet
-
-          sudo rm -rf /usr/local/lib/android
-          sudo rm -rf /usr/local/.ghcup
-          sudo rm -rf /usr/local/julia1.11.5
-          sudo rm -rf /usr/local/share/powershell
-          sudo rm -rf /usr/local/share/chromium
-
-          sudo rm -rf /usr/share/swift
-          sudo rm -rf /usr/share/miniconda
-          sudo rm -rf /usr/share/az_12.1.0
-          sudo rm -rf /usr/share/dotnet
-
-          echo "After cleanup:"
-          df -h
-
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
@@ -89,188 +43,18 @@ jobs:
          username: ${{ secrets.DOCKER_HUB_USERNAME }}
          password: ${{ secrets.DOCKER_HUB_PASSWORD }}

-      - name: Build and Push Docker Image (amd64)
+      - name: Build and Push Docker Image
        uses: docker/build-push-action@v5
        with:
          context: .
          file: ./Dockerfile
          push: true
-          platforms: linux/amd64
+          platforms: linux/amd64,linux/arm64
          build-args: |
            LITE=${{ matrix.lite }}
            TORCH_BASE=${{ matrix.torch_base }}
-            CUDA_VERSION=${{ matrix.cuda_version }}
            WORKFLOW=true
          tags: |
-            xxxxrt666/gpt-sovits:${{ matrix.tag_prefix }}-${{ needs.generate-meta.outputs.tag }}-amd64
-            xxxxrt666/gpt-sovits:latest-${{ matrix.tag_prefix }}-amd64
+            xxxxrt666/gpt-sovits:${{ matrix.tag_prefix }}-${{ needs.generate-meta.outputs.tag }}
+            xxxxrt666/gpt-sovits:latest-${{ matrix.tag_prefix }}
-
-  build-arm64:
-    needs: generate-meta
-    runs-on: ubuntu-22.04-arm
-    strategy:
-      matrix:
-        include:
-          - cuda_version: 12.6
-            lite: true
-            torch_base: lite
-            tag_prefix: cu126-lite
-          - cuda_version: 12.6
-            lite: false
-            torch_base: full
-            tag_prefix: cu126
-          - cuda_version: 12.8
-            lite: true
-            torch_base: lite
-            tag_prefix: cu128-lite
-          - cuda_version: 12.8
-            lite: false
-            torch_base: full
-            tag_prefix: cu128
-
-    steps:
-      - name: Checkout Code
-        uses: actions/checkout@v4
-
-      - name: Free up disk space
-        run: |
-          echo "Before cleanup:"
-          df -h
-
-          sudo rm -rf /opt/ghc
-          sudo rm -rf /opt/hostedtoolcache/CodeQL
-          sudo rm -rf /opt/hostedtoolcache/PyPy
-          sudo rm -rf /opt/hostedtoolcache/go
-          sudo rm -rf /opt/hostedtoolcache/node
-          sudo rm -rf /opt/hostedtoolcache/Ruby
-          sudo rm -rf /opt/microsoft
-          sudo rm -rf /opt/pipx
-          sudo rm -rf /opt/az
-          sudo rm -rf /opt/google
-
-          sudo rm -rf /usr/lib/jvm
-          sudo rm -rf /usr/lib/google-cloud-sdk
-          sudo rm -rf /usr/lib/dotnet
-
-          sudo rm -rf /usr/local/lib/android
-          sudo rm -rf /usr/local/.ghcup
-          sudo rm -rf /usr/local/julia1.11.5
-          sudo rm -rf /usr/local/share/powershell
-          sudo rm -rf /usr/local/share/chromium
-
-          sudo rm -rf /usr/share/swift
-          sudo rm -rf /usr/share/miniconda
-          sudo rm -rf /usr/share/az_12.1.0
-          sudo rm -rf /usr/share/dotnet
-
-          echo "After cleanup:"
-          df -h
-
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
-
-      - name: Log in to Docker Hub
-        uses: docker/login-action@v3
-        with:
-          username: ${{ secrets.DOCKER_HUB_USERNAME }}
-          password: ${{ secrets.DOCKER_HUB_PASSWORD }}
-
-      - name: Build and Push Docker Image (arm64)
-        uses: docker/build-push-action@v5
-        with:
-          context: .
-          file: ./Dockerfile
-          push: true
-          platforms: linux/arm64
-          build-args: |
-            LITE=${{ matrix.lite }}
-            TORCH_BASE=${{ matrix.torch_base }}
-            CUDA_VERSION=${{ matrix.cuda_version }}
-            WORKFLOW=true
-          tags: |
-            xxxxrt666/gpt-sovits:${{ matrix.tag_prefix }}-${{ needs.generate-meta.outputs.tag }}-arm64
-            xxxxrt666/gpt-sovits:latest-${{ matrix.tag_prefix }}-arm64
-
-  merge-and-clean:
-    needs:
-      - build-amd64
-      - build-arm64
-      - generate-meta
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        include:
-          - tag_prefix: cu126-lite
-          - tag_prefix: cu126
-          - tag_prefix: cu128-lite
-          - tag_prefix: cu128
-
-    steps:
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
-
-      - name: Log in to Docker Hub
-        uses: docker/login-action@v3
-        with:
-          username: ${{ secrets.DOCKER_HUB_USERNAME }}
-          password: ${{ secrets.DOCKER_HUB_PASSWORD }}
-
-      - name: Merge amd64 and arm64 into multi-arch image
-        run: |
-          DATE_TAG=${{ needs.generate-meta.outputs.tag }}
-          TAG_PREFIX=${{ matrix.tag_prefix }}
-
-          docker buildx imagetools create \
-            --tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG} \
-            ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG}-amd64 \
-            ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG}-arm64
-
-          docker buildx imagetools create \
-            --tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX} \
-            ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX}-amd64 \
-            ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX}-arm64
-
-      - name: Delete old platform-specific tags via Docker Hub API
-        env:
-          DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
-          DOCKER_HUB_TOKEN: ${{ secrets.DOCKER_HUB_PASSWORD }}
-          TAG_PREFIX: ${{ matrix.tag_prefix }}
-          DATE_TAG: ${{ needs.generate-meta.outputs.tag }}
-        run: |
-          sudo apt-get update && sudo apt-get install -y jq
-
-          TOKEN=$(curl -s -u $DOCKER_HUB_USERNAME:$DOCKER_HUB_TOKEN \
-            "https://auth.docker.io/token?service=registry.docker.io&scope=repository:$DOCKER_HUB_USERNAME/gpt-sovits:pull,push,delete" \
-            | jq -r .token)
-
-          for PLATFORM in amd64 arm64; do
-            SAFE_PLATFORM=$(echo $PLATFORM | sed 's/\//-/g')
-            TAG="${TAG_PREFIX}-${DATE_TAG}-${SAFE_PLATFORM}"
-            LATEST_TAG="latest-${TAG_PREFIX}-${SAFE_PLATFORM}"
-
-            for DEL_TAG in "$TAG" "$LATEST_TAG"; do
-              echo "Deleting tag: $DEL_TAG"
-              curl -X DELETE -H "Authorization: Bearer $TOKEN" https://registry-1.docker.io/v2/$DOCKER_HUB_USERNAME/gpt-sovits/manifests/$DEL_TAG
-            done
-          done
-
-  create-default:
-    runs-on: ubuntu-latest
-    needs:
-      - merge-and-clean
-    steps:
-      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
-
-      - name: Log in to Docker Hub
-        uses: docker/login-action@v3
-        with:
-          username: ${{ secrets.DOCKER_HUB_USERNAME }}
-          password: ${{ secrets.DOCKER_HUB_PASSWORD }}
-
-      - name: Create Default Tag
-        run: |
-          docker buildx imagetools create \
-            --tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest \
-            ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-cu126-lite
@@ -58,9 +58,11 @@ def main():
    parser.add_argument("--input_wavs_dir", default="test_files")
    parser.add_argument("--output_dir", default="generated_files")
    parser.add_argument("--checkpoint_file", required=True)
-    parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
+    # --use_cuda_kernel argument is removed to disable custom CUDA kernels.
+    # parser.add_argument("--use_cuda_kernel", action="store_true", default=False)

    a = parser.parse_args()
+    a.use_cuda_kernel = False

    config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
    with open(config_file) as f:
@@ -72,9 +74,9 @@ def main():

    torch.manual_seed(h.seed)
    global device
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(h.seed)
-        device = torch.device("cuda")
+    if torch.xpu.is_available():
+        torch.xpu.manual_seed(h.seed)
+        device = torch.device("xpu")
    else:
        device = torch.device("cpu")
@@ -73,9 +73,11 @@ def main():
    parser.add_argument("--input_mels_dir", default="test_mel_files")
    parser.add_argument("--output_dir", default="generated_files_from_mel")
    parser.add_argument("--checkpoint_file", required=True)
-    parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
+    # --use_cuda_kernel argument is removed to disable custom CUDA kernels.
+    # parser.add_argument("--use_cuda_kernel", action="store_true", default=False)

    a = parser.parse_args()
+    a.use_cuda_kernel = False

    config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
    with open(config_file) as f:
@@ -87,9 +89,9 @@ def main():

    torch.manual_seed(h.seed)
    global device
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(h.seed)
-        device = torch.device("cuda")
+    if torch.xpu.is_available():
+        torch.xpu.manual_seed(h.seed)
+        device = torch.device("xpu")
    else:
        device = torch.device("cpu")
@@ -1,215 +0,0 @@ (entire file deleted)
# Copyright (c) 2024 NVIDIA CORPORATION.
#   Licensed under the MIT license.

import os
import sys

# to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)

import torch
import json
from env import AttrDict
from bigvgan import BigVGAN
from time import time
from tqdm import tqdm
from meldataset import mel_spectrogram, MAX_WAV_VALUE
from scipy.io.wavfile import write
import numpy as np

import argparse

torch.backends.cudnn.benchmark = True

# For easier debugging
torch.set_printoptions(linewidth=200, threshold=10_000)


def generate_soundwave(duration=5.0, sr=24000):
    t = np.linspace(0, duration, int(sr * duration), False, dtype=np.float32)

    modulation = np.sin(2 * np.pi * t / duration)

    min_freq = 220
    max_freq = 1760
    frequencies = min_freq + (max_freq - min_freq) * (modulation + 1) / 2
    soundwave = np.sin(2 * np.pi * frequencies * t)

    soundwave = soundwave / np.max(np.abs(soundwave)) * 0.95

    return soundwave, sr


def get_mel(x, h):
    return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)


def load_checkpoint(filepath, device):
    assert os.path.isfile(filepath)
    print(f"Loading '{filepath}'")
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
    parser.add_argument(
        "--checkpoint_file",
        type=str,
        required=True,
        help="Path to the checkpoint file. Assumes config.json exists in the directory.",
    )

    args = parser.parse_args()

    config_file = os.path.join(os.path.split(args.checkpoint_file)[0], "config.json")
    with open(config_file) as f:
        config = f.read()
    json_config = json.loads(config)
    h = AttrDict({**json_config})

    print("loading plain Pytorch BigVGAN")
    generator_original = BigVGAN(h).to("cuda")
    print("loading CUDA kernel BigVGAN with auto-build")
    generator_cuda_kernel = BigVGAN(h, use_cuda_kernel=True).to("cuda")

    state_dict_g = load_checkpoint(args.checkpoint_file, "cuda")
    generator_original.load_state_dict(state_dict_g["generator"])
    generator_cuda_kernel.load_state_dict(state_dict_g["generator"])

    generator_original.remove_weight_norm()
    generator_original.eval()
    generator_cuda_kernel.remove_weight_norm()
    generator_cuda_kernel.eval()

    # define number of samples and length of mel frame to benchmark
    num_sample = 10
    num_mel_frame = 16384

    # CUDA kernel correctness check
    diff = 0.0
    for i in tqdm(range(num_sample)):
        # Random mel
        data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")

        with torch.inference_mode():
            audio_original = generator_original(data)

        with torch.inference_mode():
            audio_cuda_kernel = generator_cuda_kernel(data)

        # Both outputs should be (almost) the same
        test_result = (audio_original - audio_cuda_kernel).abs()
        diff += test_result.mean(dim=-1).item()

    diff /= num_sample
    if diff <= 2e-3:  # We can expect a small difference (~1e-3) which does not affect perceptual quality
        print(
            f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}"
            f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test CUDA fused vs. plain torch BigVGAN inference"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, "
            f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
        )

    del data, audio_original, audio_cuda_kernel

    # Variables for tracking total time and VRAM usage
    toc_total_original = 0
    toc_total_cuda_kernel = 0
    vram_used_original_total = 0
    vram_used_cuda_kernel_total = 0
    audio_length_total = 0

    # Measure Original inference in isolation
    for i in tqdm(range(num_sample)):
        torch.cuda.reset_peak_memory_stats(device="cuda")
        data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
        torch.cuda.synchronize()
        tic = time()
        with torch.inference_mode():
            audio_original = generator_original(data)
        torch.cuda.synchronize()
        toc = time() - tic
        toc_total_original += toc

        vram_used_original_total += torch.cuda.max_memory_allocated(device="cuda")

        del data, audio_original
        torch.cuda.empty_cache()

    # Measure CUDA kernel inference in isolation
    for i in tqdm(range(num_sample)):
        torch.cuda.reset_peak_memory_stats(device="cuda")
        data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
        torch.cuda.synchronize()
        tic = time()
        with torch.inference_mode():
            audio_cuda_kernel = generator_cuda_kernel(data)
        torch.cuda.synchronize()
        toc = time() - tic
        toc_total_cuda_kernel += toc

        audio_length_total += audio_cuda_kernel.shape[-1]

        vram_used_cuda_kernel_total += torch.cuda.max_memory_allocated(device="cuda")

        del data, audio_cuda_kernel
        torch.cuda.empty_cache()

    # Calculate metrics
    audio_second = audio_length_total / h.sampling_rate
    khz_original = audio_length_total / toc_total_original / 1000
    khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
    vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
    vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)

    # Print results
    print(
        f"Original BigVGAN: took {toc_total_original:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_original:.1f}kHz, {audio_second / toc_total_original:.1f} faster than realtime, VRAM used {vram_used_original_gb:.1f} GB"
    )
    print(
        f"CUDA kernel BigVGAN: took {toc_total_cuda_kernel:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_cuda_kernel:.1f}kHz, {audio_second / toc_total_cuda_kernel:.1f} faster than realtime, VRAM used {vram_used_cuda_kernel_gb:.1f} GB"
    )
    print(f"speedup of CUDA kernel: {khz_cuda_kernel / khz_original}")
    print(f"VRAM saving of CUDA kernel: {vram_used_original_gb / vram_used_cuda_kernel_gb}")

    # Use artificial sine waves for inference test
    audio_real, sr = generate_soundwave(duration=5.0, sr=h.sampling_rate)
    audio_real = torch.tensor(audio_real).to("cuda")
    # Compute mel spectrogram from the ground truth audio
    x = get_mel(audio_real.unsqueeze(0), h)

    with torch.inference_mode():
        y_g_hat_original = generator_original(x)
        y_g_hat_cuda_kernel = generator_cuda_kernel(x)

    audio_real = audio_real.squeeze()
    audio_real = audio_real * MAX_WAV_VALUE
    audio_real = audio_real.cpu().numpy().astype("int16")

    audio_original = y_g_hat_original.squeeze()
    audio_original = audio_original * MAX_WAV_VALUE
    audio_original = audio_original.cpu().numpy().astype("int16")

    audio_cuda_kernel = y_g_hat_cuda_kernel.squeeze()
    audio_cuda_kernel = audio_cuda_kernel * MAX_WAV_VALUE
    audio_cuda_kernel = audio_cuda_kernel.cpu().numpy().astype("int16")

    os.makedirs("tmp", exist_ok=True)
    output_file_real = os.path.join("tmp", "audio_real.wav")
    output_file_original = os.path.join("tmp", "audio_generated_original.wav")
    output_file_cuda_kernel = os.path.join("tmp", "audio_generated_cuda_kernel.wav")
    write(output_file_real, h.sampling_rate, audio_real)
    write(output_file_original, h.sampling_rate, audio_original)
    write(output_file_cuda_kernel, h.sampling_rate, audio_cuda_kernel)
    print("Example generated audios of original vs. fused CUDA kernel written to tmp!")
    print("Done")
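Since the fused-kernel benchmark above is deleted and the inference scripts now force use_cuda_kernel to False, the vocoder is expected to take the plain PyTorch path on whichever device is available. Below is a hedged sketch of that construction, reusing the BigVGAN(h, use_cuda_kernel=...) signature visible in the removed script; the build_vocoder helper name is hypothetical and the bigvgan import is assumed to resolve the same way it did there.

# Minimal sketch: instantiate BigVGAN without the custom CUDA kernel and move it to
# XPU when present, otherwise CPU. Assumes the same bigvgan module and AttrDict
# config `h` used by the scripts above; exact import paths depend on repo layout.
import torch

from bigvgan import BigVGAN  # assumption: same import used by the deleted test script


def build_vocoder(h) -> torch.nn.Module:
    device = torch.device("xpu") if (hasattr(torch, "xpu") and torch.xpu.is_available()) else torch.device("cpu")
    generator = BigVGAN(h, use_cuda_kernel=False).to(device)  # plain PyTorch ops only
    generator.remove_weight_norm()
    generator.eval()
    return generator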
@@ -110,15 +110,15 @@ def main(args):
        os.environ["USE_LIBUV"] = "0"
    trainer: Trainer = Trainer(
        max_epochs=config["train"]["epochs"],
-        accelerator="gpu" if torch.cuda.is_available() else "cpu",
+        accelerator="xpu" if torch.xpu.is_available() else "cpu",
        # val_check_interval=9999999999999999999999,###不要验证
        # check_val_every_n_epoch=None,
        limit_val_batches=0,
-        devices=-1 if torch.cuda.is_available() else 1,
+        devices=-1 if torch.xpu.is_available() else 1,
        benchmark=False,
        fast_dev_run=False,
-        strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
-        if torch.cuda.is_available()
+        strategy=DDPStrategy(process_group_backend="ccl" if platform.system() != "Windows" else "gloo")
+        if torch.xpu.is_available()
        else "auto",
        precision=config["train"]["precision"],
        logger=logger,
@@ -41,18 +41,18 @@ from process_ckpt import savee
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
###反正A100fp32更快,那试试tf32吧
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
+# torch.backends.cuda.matmul.allow_tf32 = True  # XPU does not support this
+# torch.backends.cudnn.allow_tf32 = True  # XPU does not support this
torch.set_float32_matmul_precision("medium")  # 最低精度但最快(也就快一丁点),对于结果造成不了影响
# from config import pretrained_s2G,pretrained_s2D
global_step = 0

-device = "cpu"  # cuda以外的设备,等mps优化后加入
+device = "xpu" if torch.xpu.is_available() else "cpu"


def main():
-    if torch.cuda.is_available():
-        n_gpus = torch.cuda.device_count()
+    if torch.xpu.is_available():
+        n_gpus = torch.xpu.device_count()
    else:
        n_gpus = 1
    os.environ["MASTER_ADDR"] = "localhost"
@@ -78,14 +78,14 @@ def run(rank, n_gpus, hps):
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

    dist.init_process_group(
-        backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+        backend="gloo" if os.name == "nt" or not torch.xpu.is_available() else "ccl",
        init_method="env://?use_libuv=False",
        world_size=n_gpus,
        rank=rank,
    )
    torch.manual_seed(hps.train.seed)
-    if torch.cuda.is_available():
-        torch.cuda.set_device(rank)
+    if torch.xpu.is_available():
+        torch.xpu.set_device(rank)

    train_dataset = TextAudioSpeakerLoader(hps.data, version=hps.model.version)
    train_sampler = DistributedBucketSampler(
@@ -132,27 +132,14 @@ def run(rank, n_gpus, hps):
    #     batch_size=1, pin_memory=True,
    #     drop_last=False, collate_fn=collate_fn)

-    net_g = (
-        SynthesizerTrn(
-            hps.data.filter_length // 2 + 1,
-            hps.train.segment_size // hps.data.hop_length,
-            n_speakers=hps.data.n_speakers,
-            **hps.model,
-        ).cuda(rank)
-        if torch.cuda.is_available()
-        else SynthesizerTrn(
-            hps.data.filter_length // 2 + 1,
-            hps.train.segment_size // hps.data.hop_length,
-            n_speakers=hps.data.n_speakers,
-            **hps.model,
-        ).to(device)
-    )
+    net_g = SynthesizerTrn(
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model,
+    ).to(device)

-    net_d = (
-        MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).cuda(rank)
-        if torch.cuda.is_available()
-        else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).to(device)
-    )
+    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).to(device)
    for name, param in net_g.named_parameters():
        if not param.requires_grad:
            print(name, "not requires_grad")
@@ -196,7 +183,7 @@ def run(rank, n_gpus, hps):
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
-    if torch.cuda.is_available():
+    if torch.xpu.is_available():
        net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
        net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
    else:
@@ -238,7 +225,7 @@ def run(rank, n_gpus, hps):
            torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
            strict=False,
        )
-        if torch.cuda.is_available()
+        if torch.xpu.is_available()
        else net_g.load_state_dict(
            torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
            strict=False,
@@ -256,7 +243,7 @@ def run(rank, n_gpus, hps):
        net_d.module.load_state_dict(
            torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"], strict=False
        )
-        if torch.cuda.is_available()
+        if torch.xpu.is_available()
        else net_d.load_state_dict(
            torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"],
        ),
@@ -333,42 +320,24 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
            ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = data
        else:
            ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data
-        if torch.cuda.is_available():
+        if torch.xpu.is_available():
            spec, spec_lengths = (
-                spec.cuda(
-                    rank,
-                    non_blocking=True,
-                ),
-                spec_lengths.cuda(
-                    rank,
-                    non_blocking=True,
-                ),
+                spec.to(device, non_blocking=True),
+                spec_lengths.to(device, non_blocking=True),
            )
            y, y_lengths = (
-                y.cuda(
-                    rank,
-                    non_blocking=True,
-                ),
-                y_lengths.cuda(
-                    rank,
-                    non_blocking=True,
-                ),
+                y.to(device, non_blocking=True),
+                y_lengths.to(device, non_blocking=True),
            )
-            ssl = ssl.cuda(rank, non_blocking=True)
+            ssl = ssl.to(device, non_blocking=True)
            ssl.requires_grad = False
-            # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
+            # ssl_lengths = ssl_lengths.to(device, non_blocking=True)
            text, text_lengths = (
-                text.cuda(
-                    rank,
-                    non_blocking=True,
-                ),
-                text_lengths.cuda(
-                    rank,
-                    non_blocking=True,
-                ),
+                text.to(device, non_blocking=True),
+                text_lengths.to(device, non_blocking=True),
            )
            if hps.model.version in {"v2Pro", "v2ProPlus"}:
-                sv_emb = sv_emb.cuda(rank, non_blocking=True)
+                sv_emb = sv_emb.to(device, non_blocking=True)
        else:
            spec, spec_lengths = spec.to(device), spec_lengths.to(device)
            y, y_lengths = y.to(device), y_lengths.to(device)
@@ -596,11 +565,11 @@ def evaluate(hps, generator, eval_loader, writer_eval):
                text_lengths,
            ) in enumerate(eval_loader):
                print(111)
-                if torch.cuda.is_available():
-                    spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
-                    y, y_lengths = y.cuda(), y_lengths.cuda()
-                    ssl = ssl.cuda()
-                    text, text_lengths = text.cuda(), text_lengths.cuda()
+                if torch.xpu.is_available():
+                    spec, spec_lengths = spec.to(device), spec_lengths.to(device)
+                    y, y_lengths = y.to(device), y_lengths.to(device)
+                    ssl = ssl.to(device)
+                    text, text_lengths = text.to(device), text_lengths.to(device)
                else:
                    spec, spec_lengths = spec.to(device), spec_lengths.to(device)
                    y, y_lengths = y.to(device), y_lengths.to(device)
config.py: 442 lines changed
@@ -1,218 +1,224 @@
import os
import re
import sys

import torch

from tools.i18n.i18n import I18nAuto

i18n = I18nAuto(language=os.environ.get("language", "Auto"))


pretrained_sovits_name = {
    "v1": "GPT_SoVITS/pretrained_models/s2G488k.pth",
    "v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
    "v3": "GPT_SoVITS/pretrained_models/s2Gv3.pth",  ###v3v4还要检查vocoder,算了。。。
    "v4": "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
    "v2Pro": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth",
    "v2ProPlus": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth",
}

pretrained_gpt_name = {
    "v1": "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
    "v2": "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
    "v3": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
    "v4": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
    "v2Pro": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
    "v2ProPlus": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
}
name2sovits_path = {
    # i18n("不训练直接推v1底模!"): "GPT_SoVITS/pretrained_models/s2G488k.pth",
    i18n("不训练直接推v2底模!"): "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
    # i18n("不训练直接推v3底模!"): "GPT_SoVITS/pretrained_models/s2Gv3.pth",
    # i18n("不训练直接推v4底模!"): "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
    i18n("不训练直接推v2Pro底模!"): "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth",
    i18n("不训练直接推v2ProPlus底模!"): "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth",
}
name2gpt_path = {
    # i18n("不训练直接推v1底模!"):"GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
    i18n(
        "不训练直接推v2底模!"
    ): "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
    i18n("不训练直接推v3底模!"): "GPT_SoVITS/pretrained_models/s1v3.ckpt",
}
SoVITS_weight_root = [
    "SoVITS_weights",
    "SoVITS_weights_v2",
    "SoVITS_weights_v3",
    "SoVITS_weights_v4",
    "SoVITS_weights_v2Pro",
    "SoVITS_weights_v2ProPlus",
]
GPT_weight_root = [
    "GPT_weights",
    "GPT_weights_v2",
    "GPT_weights_v3",
    "GPT_weights_v4",
    "GPT_weights_v2Pro",
    "GPT_weights_v2ProPlus",
]
SoVITS_weight_version2root = {
    "v1": "SoVITS_weights",
    "v2": "SoVITS_weights_v2",
    "v3": "SoVITS_weights_v3",
    "v4": "SoVITS_weights_v4",
    "v2Pro": "SoVITS_weights_v2Pro",
    "v2ProPlus": "SoVITS_weights_v2ProPlus",
}
GPT_weight_version2root = {
    "v1": "GPT_weights",
    "v2": "GPT_weights_v2",
    "v3": "GPT_weights_v3",
    "v4": "GPT_weights_v4",
    "v2Pro": "GPT_weights_v2Pro",
    "v2ProPlus": "GPT_weights_v2ProPlus",
}


def custom_sort_key(s):
    # 使用正则表达式提取字符串中的数字部分和非数字部分
    parts = re.split("(\d+)", s)
    # 将数字部分转换为整数,非数字部分保持不变
    parts = [int(part) if part.isdigit() else part for part in parts]
    return parts


def get_weights_names():
    SoVITS_names = []
    for key in name2sovits_path:
        if os.path.exists(name2sovits_path[key]):
            SoVITS_names.append(key)
    for path in SoVITS_weight_root:
        if not os.path.exists(path):
            continue
        for name in os.listdir(path):
            if name.endswith(".pth"):
                SoVITS_names.append("%s/%s" % (path, name))
    if not SoVITS_names:
        SoVITS_names = [""]
    GPT_names = []
    for key in name2gpt_path:
        if os.path.exists(name2gpt_path[key]):
            GPT_names.append(key)
    for path in GPT_weight_root:
        if not os.path.exists(path):
            continue
        for name in os.listdir(path):
            if name.endswith(".ckpt"):
                GPT_names.append("%s/%s" % (path, name))
    SoVITS_names = sorted(SoVITS_names, key=custom_sort_key)
    GPT_names = sorted(GPT_names, key=custom_sort_key)
    if not GPT_names:
        GPT_names = [""]
    return SoVITS_names, GPT_names


def change_choices():
    SoVITS_names, GPT_names = get_weights_names()
    return {"choices": SoVITS_names, "__type__": "update"}, {
        "choices": GPT_names,
        "__type__": "update",
    }


# 推理用的指定模型
sovits_path = ""
gpt_path = ""
is_half_str = os.environ.get("is_half", "True")
is_half = True if is_half_str.lower() == "true" else False
is_share_str = os.environ.get("is_share", "False")
is_share = True if is_share_str.lower() == "true" else False

cnhubert_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
pretrained_sovits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
pretrained_gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"

exp_root = "logs"
python_exec = sys.executable or "python"

webui_port_main = 9874
webui_port_uvr5 = 9873
webui_port_infer_tts = 9872
webui_port_subfix = 9871

api_port = 9880


# Thanks to the contribution of @Karasukaigan and @XXXXRT666
-def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
-    cpu = torch.device("cpu")
-    cuda = torch.device(f"cuda:{idx}")
-    if not torch.cuda.is_available():
-        return cpu, torch.float32, 0.0, 0.0
-    device_idx = idx
-    capability = torch.cuda.get_device_capability(device_idx)
-    name = torch.cuda.get_device_name(device_idx)
-    mem_bytes = torch.cuda.get_device_properties(device_idx).total_memory
-    mem_gb = mem_bytes / (1024**3) + 0.4
-    major, minor = capability
-    sm_version = major + minor / 10.0
-    is_16_series = bool(re.search(r"16\d{2}", name)) and sm_version == 7.5
-    if mem_gb < 4 or sm_version < 5.3:
-        return cpu, torch.float32, 0.0, 0.0
-    if sm_version == 6.1 or is_16_series == True:
-        return cuda, torch.float32, sm_version, mem_gb
-    if sm_version > 6.1:
-        return cuda, torch.float16, sm_version, mem_gb
-    return cpu, torch.float32, 0.0, 0.0
+# Modified for Intel GPU (XPU)
+def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
+    cpu = torch.device("cpu")
+    try:
+        if not torch.xpu.is_available():
+            return cpu, torch.float32, 0.0, 0.0
+    except AttributeError:
+        return cpu, torch.float32, 0.0, 0.0
+
+    xpu_device = torch.device(f"xpu:{idx}")
+    properties = torch.xpu.get_device_properties(idx)
+    mem_bytes = properties.total_memory
+    mem_gb = mem_bytes / (1024**3)
+
+    # Simplified logic for XPU, assuming FP16/BF16 is generally supported.
+    # The complex SM version check is CUDA-specific.
+    if mem_gb < 4:  # Example threshold
+        return cpu, torch.float32, 0.0, 0.0
+
+    # For Intel GPUs, we can generally assume float16 is available.
+    # The 'sm_version' equivalent is not straightforward, so we use a placeholder value (e.g., 1.0)
+    # for compatibility with the downstream logic that sorts devices.
+    return xpu_device, torch.float16, 1.0, mem_gb


IS_GPU = True
GPU_INFOS: list[str] = []
GPU_INDEX: set[int] = set()
-GPU_COUNT = torch.cuda.device_count()
+try:
+    GPU_COUNT = torch.xpu.device_count() if torch.xpu.is_available() else 0
+except AttributeError:
+    GPU_COUNT = 0
CPU_INFO: str = "0\tCPU " + i18n("CPU训练,较慢")
tmp: list[tuple[torch.device, torch.dtype, float, float]] = []
memset: set[float] = set()

for i in range(max(GPU_COUNT, 1)):
    tmp.append(get_device_dtype_sm(i))

for j in tmp:
    device = j[0]
    memset.add(j[3])
-    if device.type != "cpu":
-        GPU_INFOS.append(f"{device.index}\t{torch.cuda.get_device_name(device.index)}")
+    if device.type == "xpu":
+        GPU_INFOS.append(f"{device.index}\t{torch.xpu.get_device_name(device.index)}")
        GPU_INDEX.add(device.index)

if not GPU_INFOS:
    IS_GPU = False
    GPU_INFOS.append(CPU_INFO)
    GPU_INDEX.add(0)

infer_device = max(tmp, key=lambda x: (x[2], x[3]))[0]
is_half = any(dtype == torch.float16 for _, dtype, _, _ in tmp)


class Config:
    def __init__(self):
        self.sovits_path = sovits_path
        self.gpt_path = gpt_path
        self.is_half = is_half

        self.cnhubert_path = cnhubert_path
        self.bert_path = bert_path
        self.pretrained_sovits_path = pretrained_sovits_path
        self.pretrained_gpt_path = pretrained_gpt_path

        self.exp_root = exp_root
        self.python_exec = python_exec
        self.infer_device = infer_device

        self.webui_port_main = webui_port_main
        self.webui_port_uvr5 = webui_port_uvr5
        self.webui_port_infer_tts = webui_port_infer_tts
        self.webui_port_subfix = webui_port_subfix

        self.api_port = api_port
@@ -14,49 +14,29 @@ fi
trap 'echo "Error Occured at \"$BASH_COMMAND\" with exit code $?"; exit 1' ERR

LITE=false
-CUDA_VERSION=12.6

print_help() {
    echo "Usage: bash docker_build.sh [OPTIONS]"
    echo ""
    echo "Options:"
-    echo "  --cuda    12.6|12.8    Specify the CUDA VERSION (REQUIRED)"
    echo "  --lite                 Build a Lite Image"
    echo "  -h, --help             Show this help message and exit"
    echo ""
    echo "Examples:"
-    echo "  bash docker_build.sh --cuda 12.6 --funasr --faster-whisper"
+    echo "  bash docker_build.sh --lite"
}

-# Show help if no arguments provided
-if [[ $# -eq 0 ]]; then
-    print_help
-    exit 0
-fi
-
# Parse arguments
while [[ $# -gt 0 ]]; do
    case "$1" in
-    --cuda)
-        case "$2" in
-        12.6)
-            CUDA_VERSION=12.6
-            ;;
-        12.8)
-            CUDA_VERSION=12.8
-            ;;
-        *)
-            echo "Error: Invalid CUDA_VERSION: $2"
-            echo "Choose From: [12.6, 12.8]"
-            exit 1
-            ;;
-        esac
-        shift 2
-        ;;
    --lite)
        LITE=true
        shift
        ;;
+    -h|--help)
+        print_help
+        exit 0
+        ;;
    *)
        echo "Unknown Argument: $1"
        echo "Use -h or --help to see available options."
@@ -74,7 +54,6 @@ else
fi

docker build \
-    --build-arg CUDA_VERSION=$CUDA_VERSION \
    --build-arg LITE=$LITE \
    --build-arg TARGETPLATFORM="$TARGETPLATFORM" \
    --build-arg TORCH_BASE=$TORCH_BASE \
@@ -5,6 +5,9 @@ tensorboard
librosa==0.10.2
numba
pytorch-lightning>=2.4
+torch==2.9
+intel-extension-for-pytorch
+torchvision
gradio<5
ffmpeg-python
onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64"
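For context, a quick way to confirm that an environment actually picked up the packages this hunk adds (torch==2.9, intel-extension-for-pytorch, torchvision) is the small check below. It is an illustrative sketch only; package and module names are taken from the requirements entries above, and Intel's own install instructions may point to a dedicated wheel index rather than default PyPI.

# Illustrative environment check for the newly added requirements.
import torch

print("torch", torch.__version__)
print("xpu available:", hasattr(torch, "xpu") and torch.xpu.is_available())

try:
    import intel_extension_for_pytorch as ipex  # module name of the intel-extension-for-pytorch package

    print("intel-extension-for-pytorch", ipex.__version__)
except ImportError:
    print("intel-extension-for-pytorch is not installed")

try:
    import torchvision

    print("torchvision", torchvision.__version__)
except ImportError:
    print("torchvision is not installed")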
@@ -85,7 +85,7 @@ def execute_asr(input_folder, output_folder, model_path, language, precision):
    if language == "auto":
        language = None  # 不设置语种由模型自动输出概率最高的语种
    print("loading faster whisper model:", model_path, model_path)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = "xpu" if torch.xpu.is_available() else "cpu"
    model = WhisperModel(model_path, device=device, compute_type=precision)

    input_file_names = os.listdir(input_folder)
@@ -128,8 +128,6 @@ def execute_asr(input_folder, output_folder, model_path, language, precision):
    return output_file_path


-load_cudnn()
-
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
|
|||||||
import ctypes
|
import ctypes
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import ffmpeg
|
import ffmpeg
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from tools.i18n.i18n import I18nAuto
|
from tools.i18n.i18n import I18nAuto
|
||||||
|
|
||||||
i18n = I18nAuto(language=os.environ.get("language", "Auto"))
|
i18n = I18nAuto(language=os.environ.get("language", "Auto"))
|
||||||
|
|
||||||
|
|
||||||
def load_audio(file, sr):
|
def load_audio(file, sr):
|
||||||
try:
|
try:
|
||||||
# https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
|
# https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
|
||||||
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
||||||
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
|
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
|
||||||
file = clean_path(file) # 防止小白拷路径头尾带了空格和"和回车
|
file = clean_path(file) # 防止小白拷路径头尾带了空格和"和回车
|
||||||
if os.path.exists(file) is False:
|
if os.path.exists(file) is False:
|
||||||
raise RuntimeError("You input a wrong audio path that does not exists, please fix it!")
|
raise RuntimeError("You input a wrong audio path that does not exists, please fix it!")
|
||||||
out, _ = (
|
out, _ = (
|
||||||
ffmpeg.input(file, threads=0)
|
ffmpeg.input(file, threads=0)
|
||||||
.output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
|
.output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
|
||||||
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
|
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
out, _ = (
|
out, _ = (
|
||||||
ffmpeg.input(file, threads=0)
|
ffmpeg.input(file, threads=0)
|
||||||
.output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
|
.output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
|
||||||
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True)
|
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True)
|
||||||
) # Expose the Error
|
) # Expose the Error
|
||||||
raise RuntimeError(i18n("音频加载失败"))
|
raise RuntimeError(i18n("音频加载失败"))
|
||||||
|
|
||||||
return np.frombuffer(out, np.float32).flatten()
|
return np.frombuffer(out, np.float32).flatten()
|
||||||
|
|
||||||
|
|
||||||
def clean_path(path_str: str):
|
def clean_path(path_str: str):
|
||||||
if path_str.endswith(("\\", "/")):
|
if path_str.endswith(("\\", "/")):
|
||||||
return clean_path(path_str[0:-1])
|
return clean_path(path_str[0:-1])
|
||||||
path_str = path_str.replace("/", os.sep).replace("\\", os.sep)
|
path_str = path_str.replace("/", os.sep).replace("\\", os.sep)
|
||||||
return path_str.strip(
|
return path_str.strip(
|
||||||
" '\n\"\u202a"
|
" '\n\"\u202a"
|
||||||
) # path_str.strip(" ").strip('\'').strip("\n").strip('"').strip(" ").strip("\u202a")
|
) # path_str.strip(" ").strip('\'').strip("\n").strip('"').strip(" ").strip("\u202a")
|
||||||
|
|
||||||
|
|
||||||
def check_for_existance(file_list: list = None, is_train=False, is_dataset_processing=False):
|
def check_for_existance(file_list: list = None, is_train=False, is_dataset_processing=False):
|
||||||
files_status = []
|
files_status = []
|
||||||
if is_train == True and file_list:
|
if is_train == True and file_list:
|
||||||
file_list.append(os.path.join(file_list[0], "2-name2text.txt"))
|
file_list.append(os.path.join(file_list[0], "2-name2text.txt"))
|
||||||
file_list.append(os.path.join(file_list[0], "3-bert"))
|
file_list.append(os.path.join(file_list[0], "3-bert"))
|
||||||
file_list.append(os.path.join(file_list[0], "4-cnhubert"))
|
file_list.append(os.path.join(file_list[0], "4-cnhubert"))
|
||||||
file_list.append(os.path.join(file_list[0], "5-wav32k"))
|
file_list.append(os.path.join(file_list[0], "5-wav32k"))
|
||||||
file_list.append(os.path.join(file_list[0], "6-name2semantic.tsv"))
|
file_list.append(os.path.join(file_list[0], "6-name2semantic.tsv"))
|
||||||
for file in file_list:
|
for file in file_list:
|
||||||
if os.path.exists(file):
|
if os.path.exists(file):
|
||||||
files_status.append(True)
|
files_status.append(True)
|
||||||
else:
|
else:
|
||||||
files_status.append(False)
|
files_status.append(False)
|
||||||
if sum(files_status) != len(files_status):
|
if sum(files_status) != len(files_status):
|
||||||
if is_train:
|
if is_train:
|
||||||
for file, status in zip(file_list, files_status):
|
for file, status in zip(file_list, files_status):
|
||||||
if status:
|
if status:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
gr.Warning(file)
|
gr.Warning(file)
|
||||||
gr.Warning(i18n("以下文件或文件夹不存在"))
|
gr.Warning(i18n("以下文件或文件夹不存在"))
|
||||||
return False
|
return False
|
||||||
elif is_dataset_processing:
|
elif is_dataset_processing:
|
||||||
if files_status[0]:
|
if files_status[0]:
|
||||||
return True
|
return True
|
||||||
elif not files_status[0]:
|
elif not files_status[0]:
|
||||||
gr.Warning(file_list[0])
|
gr.Warning(file_list[0])
|
||||||
elif not files_status[1] and file_list[1]:
|
elif not files_status[1] and file_list[1]:
|
||||||
gr.Warning(file_list[1])
|
gr.Warning(file_list[1])
|
||||||
gr.Warning(i18n("以下文件或文件夹不存在"))
|
gr.Warning(i18n("以下文件或文件夹不存在"))
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
if file_list[0]:
|
if file_list[0]:
|
||||||
gr.Warning(file_list[0])
|
gr.Warning(file_list[0])
|
||||||
gr.Warning(i18n("以下文件或文件夹不存在"))
|
gr.Warning(i18n("以下文件或文件夹不存在"))
|
||||||
else:
|
else:
|
||||||
gr.Warning(i18n("路径不能为空"))
|
gr.Warning(i18n("路径不能为空"))
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def check_details(path_list=None, is_train=False, is_dataset_processing=False):
|
def check_details(path_list=None, is_train=False, is_dataset_processing=False):
|
||||||
if is_dataset_processing:
|
if is_dataset_processing:
|
||||||
list_path, audio_path = path_list
|
list_path, audio_path = path_list
|
||||||
if not list_path.endswith(".list"):
|
if not list_path.endswith(".list"):
|
||||||
gr.Warning(i18n("请填入正确的List路径"))
|
gr.Warning(i18n("请填入正确的List路径"))
|
||||||
return
|
return
|
||||||
if audio_path:
|
if audio_path:
|
||||||
if not os.path.isdir(audio_path):
|
if not os.path.isdir(audio_path):
|
||||||
gr.Warning(i18n("请填入正确的音频文件夹路径"))
|
gr.Warning(i18n("请填入正确的音频文件夹路径"))
|
||||||
return
|
return
|
||||||
with open(list_path, "r", encoding="utf8") as f:
|
with open(list_path, "r", encoding="utf8") as f:
|
||||||
line = f.readline().strip("\n").split("\n")
|
line = f.readline().strip("\n").split("\n")
|
||||||
wav_name, _, __, ___ = line[0].split("|")
|
wav_name, _, __, ___ = line[0].split("|")
|
||||||
wav_name = clean_path(wav_name)
|
wav_name = clean_path(wav_name)
|
||||||
if audio_path != "" and audio_path != None:
|
if audio_path != "" and audio_path != None:
|
||||||
wav_name = os.path.basename(wav_name)
|
wav_name = os.path.basename(wav_name)
|
||||||
wav_path = "%s/%s" % (audio_path, wav_name)
|
wav_path = "%s/%s" % (audio_path, wav_name)
|
||||||
else:
|
else:
|
||||||
wav_path = wav_name
|
wav_path = wav_name
|
||||||
if os.path.exists(wav_path):
|
if os.path.exists(wav_path):
|
||||||
...
|
...
|
||||||
else:
|
else:
|
||||||
gr.Warning(wav_path + i18n("路径错误"))
|
gr.Warning(wav_path + i18n("路径错误"))
|
||||||
return
|
return
|
||||||
if is_train:
|
if is_train:
|
||||||
path_list.append(os.path.join(path_list[0], "2-name2text.txt"))
|
path_list.append(os.path.join(path_list[0], "2-name2text.txt"))
|
||||||
path_list.append(os.path.join(path_list[0], "4-cnhubert"))
|
path_list.append(os.path.join(path_list[0], "4-cnhubert"))
|
||||||
path_list.append(os.path.join(path_list[0], "5-wav32k"))
|
path_list.append(os.path.join(path_list[0], "5-wav32k"))
|
||||||
path_list.append(os.path.join(path_list[0], "6-name2semantic.tsv"))
|
path_list.append(os.path.join(path_list[0], "6-name2semantic.tsv"))
|
||||||
phone_path, hubert_path, wav_path, semantic_path = path_list[1:]
|
phone_path, hubert_path, wav_path, semantic_path = path_list[1:]
|
||||||
with open(phone_path, "r", encoding="utf-8") as f:
|
with open(phone_path, "r", encoding="utf-8") as f:
|
||||||
if f.read(1):
|
if f.read(1):
|
||||||
...
|
...
|
||||||
else:
|
else:
|
||||||
gr.Warning(i18n("缺少音素数据集"))
|
gr.Warning(i18n("缺少音素数据集"))
|
||||||
if os.listdir(hubert_path):
|
if os.listdir(hubert_path):
|
||||||
...
|
...
|
||||||
else:
|
else:
|
||||||
gr.Warning(i18n("缺少Hubert数据集"))
|
gr.Warning(i18n("缺少Hubert数据集"))
|
||||||
if os.listdir(wav_path):
|
if os.listdir(wav_path):
|
||||||
...
|
...
|
||||||
else:
|
else:
|
||||||
gr.Warning(i18n("缺少音频数据集"))
|
gr.Warning(i18n("缺少音频数据集"))
|
||||||
df = pd.read_csv(semantic_path, delimiter="\t", encoding="utf-8")
|
df = pd.read_csv(semantic_path, delimiter="\t", encoding="utf-8")
|
||||||
if len(df) >= 1:
|
if len(df) >= 1:
|
||||||
...
|
...
|
||||||
else:
|
else:
|
||||||
gr.Warning(i18n("缺少语义数据集"))
|
gr.Warning(i18n("缺少语义数据集"))
|
||||||
|
|
||||||
|
|
||||||
def load_cudnn():
|
|
||||||
import torch
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("[INFO] CUDA is not available, skipping cuDNN setup.")
|
|
||||||
return
|
|
||||||
|
|
||||||
if sys.platform == "win32":
|
|
||||||
torch_lib_dir = Path(torch.__file__).parent / "lib"
|
|
||||||
if torch_lib_dir.exists():
|
|
||||||
os.add_dll_directory(str(torch_lib_dir))
|
|
||||||
print(f"[INFO] Added DLL directory: {torch_lib_dir}")
|
|
||||||
matching_files = sorted(torch_lib_dir.glob("cudnn_cnn*.dll"))
|
|
||||||
if not matching_files:
|
|
||||||
print(f"[ERROR] No cudnn_cnn*.dll found in {torch_lib_dir}")
|
|
||||||
return
|
|
||||||
for dll_path in matching_files:
|
|
||||||
dll_name = os.path.basename(dll_path)
|
|
||||||
try:
|
|
||||||
ctypes.CDLL(dll_name)
|
|
||||||
print(f"[INFO] Loaded: {dll_name}")
|
|
||||||
except OSError as e:
|
|
||||||
print(f"[WARNING] Failed to load {dll_name}: {e}")
|
|
||||||
else:
|
|
||||||
print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
|
|
||||||
|
|
||||||
elif sys.platform == "linux":
|
|
||||||
site_packages = Path(torch.__file__).resolve().parents[1]
|
|
||||||
cudnn_dir = site_packages / "nvidia" / "cudnn" / "lib"
|
|
||||||
|
|
||||||
if not cudnn_dir.exists():
|
|
||||||
print(f"[ERROR] cudnn dir not found: {cudnn_dir}")
|
|
||||||
return
|
|
||||||
|
|
||||||
matching_files = sorted(cudnn_dir.glob("libcudnn_cnn*.so*"))
|
|
||||||
if not matching_files:
|
|
||||||
print(f"[ERROR] No libcudnn_cnn*.so* found in {cudnn_dir}")
|
|
||||||
return
|
|
||||||
|
|
||||||
for so_path in matching_files:
|
|
||||||
try:
|
|
||||||
ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) # type: ignore
|
|
||||||
print(f"[INFO] Loaded: {so_path}")
|
|
||||||
except OSError as e:
|
|
||||||
print(f"[WARNING] Failed to load {so_path}: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def load_nvrtc():
|
|
||||||
import torch
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("[INFO] CUDA is not available, skipping nvrtc setup.")
|
|
||||||
return
|
|
||||||
|
|
||||||
if sys.platform == "win32":
|
|
||||||
torch_lib_dir = Path(torch.__file__).parent / "lib"
|
|
||||||
if torch_lib_dir.exists():
|
|
||||||
os.add_dll_directory(str(torch_lib_dir))
|
|
||||||
print(f"[INFO] Added DLL directory: {torch_lib_dir}")
|
|
||||||
matching_files = sorted(torch_lib_dir.glob("nvrtc*.dll"))
|
|
||||||
if not matching_files:
|
|
||||||
print(f"[ERROR] No nvrtc*.dll found in {torch_lib_dir}")
|
|
||||||
return
|
|
||||||
for dll_path in matching_files:
|
|
||||||
dll_name = os.path.basename(dll_path)
|
|
||||||
try:
|
|
||||||
ctypes.CDLL(dll_name)
|
|
||||||
print(f"[INFO] Loaded: {dll_name}")
|
|
||||||
except OSError as e:
|
|
||||||
print(f"[WARNING] Failed to load {dll_name}: {e}")
|
|
||||||
else:
|
|
||||||
print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
|
|
||||||
|
|
||||||
elif sys.platform == "linux":
|
|
||||||
site_packages = Path(torch.__file__).resolve().parents[1]
|
|
||||||
nvrtc_dir = site_packages / "nvidia" / "cuda_nvrtc" / "lib"
|
|
||||||
|
|
||||||
if not nvrtc_dir.exists():
|
|
||||||
print(f"[ERROR] nvrtc dir not found: {nvrtc_dir}")
|
|
||||||
return
|
|
||||||
|
|
||||||
matching_files = sorted(nvrtc_dir.glob("libnvrtc*.so*"))
|
|
||||||
if not matching_files:
|
|
||||||
print(f"[ERROR] No libnvrtc*.so* found in {nvrtc_dir}")
|
|
||||||
return
|
|
||||||
|
|
||||||
for so_path in matching_files:
|
|
||||||
try:
|
|
||||||
ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) # type: ignore
|
|
||||||
print(f"[INFO] Loaded: {so_path}")
|
|
||||||
except OSError as e:
|
|
||||||
print(f"[WARNING] Failed to load {so_path}: {e}")
|
|
||||||
|
|||||||
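
The helpers kept above are untouched by the migration and stay device-agnostic; only the two CUDA-specific loaders disappear. As a quick orientation for load_audio, a minimal usage sketch follows; the module path, file name, and 16 kHz target are illustrative assumptions, not part of the commit:

from tools.my_utils import load_audio  # assumed location of the helpers shown above

# Decode any input the ffmpeg CLI understands into mono float32 samples at 16 kHz,
# a common layout for ASR front ends.
samples = load_audio("examples/ref.wav", sr=16000)
print(samples.dtype, samples.shape)  # float32, (n_samples,)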