mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-12-17 01:59:08 +08:00
feat: Migrate from CUDA to XPU for Intel GPU support
This commit migrates the project from using NVIDIA CUDA to Intel XPU for GPU acceleration, based on the PyTorch 2.9 release. Key changes include: - Replaced `torch.cuda` with `torch.xpu` for device checks, memory management, and distributed training. - Updated device strings from "cuda" to "xpu" across the codebase. - Switched the distributed training backend from "nccl" to "ccl" for Intel GPUs. - Disabled custom CUDA kernels in the `BigVGAN` module by setting `use_cuda_kernel=False`. - Updated `requirements.txt` to include `torch==2.9` and `intel-extension-for-pytorch`. - Modified CI/CD pipelines and build scripts to remove CUDA dependencies and build for an XPU target.
This commit is contained in:
parent
11aa78bd9b
commit
d3b8f7e09e
@ -15,11 +15,7 @@ on:
|
|||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
torch_cuda: [cu124, cu128]
|
|
||||||
env:
|
env:
|
||||||
TORCH_CUDA: ${{ matrix.torch_cuda }}
|
|
||||||
MODELSCOPE_USERNAME: ${{ secrets.MODELSCOPE_USERNAME }}
|
MODELSCOPE_USERNAME: ${{ secrets.MODELSCOPE_USERNAME }}
|
||||||
MODELSCOPE_TOKEN: ${{ secrets.MODELSCOPE_TOKEN }}
|
MODELSCOPE_TOKEN: ${{ secrets.MODELSCOPE_TOKEN }}
|
||||||
HUGGINGFACE_USERNAME: ${{ secrets.HUGGINGFACE_USERNAME }}
|
HUGGINGFACE_USERNAME: ${{ secrets.HUGGINGFACE_USERNAME }}
|
||||||
|
|||||||
234
.github/workflows/docker-publish.yaml
vendored
234
.github/workflows/docker-publish.yaml
vendored
@ -18,68 +18,22 @@ jobs:
|
|||||||
DATE=$(date +'%Y%m%d')
|
DATE=$(date +'%Y%m%d')
|
||||||
COMMIT=$(git rev-parse --short=6 HEAD)
|
COMMIT=$(git rev-parse --short=6 HEAD)
|
||||||
echo "tag=${DATE}-${COMMIT}" >> $GITHUB_OUTPUT
|
echo "tag=${DATE}-${COMMIT}" >> $GITHUB_OUTPUT
|
||||||
build-amd64:
|
build-and-publish:
|
||||||
needs: generate-meta
|
needs: generate-meta
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda_version: 12.6
|
- lite: true
|
||||||
lite: true
|
|
||||||
torch_base: lite
|
torch_base: lite
|
||||||
tag_prefix: cu126-lite
|
tag_prefix: xpu-lite
|
||||||
- cuda_version: 12.6
|
- lite: false
|
||||||
lite: false
|
|
||||||
torch_base: full
|
torch_base: full
|
||||||
tag_prefix: cu126
|
tag_prefix: xpu
|
||||||
- cuda_version: 12.8
|
|
||||||
lite: true
|
|
||||||
torch_base: lite
|
|
||||||
tag_prefix: cu128-lite
|
|
||||||
- cuda_version: 12.8
|
|
||||||
lite: false
|
|
||||||
torch_base: full
|
|
||||||
tag_prefix: cu128
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout Code
|
- name: Checkout Code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Free up disk space
|
|
||||||
run: |
|
|
||||||
echo "Before cleanup:"
|
|
||||||
df -h
|
|
||||||
|
|
||||||
sudo rm -rf /opt/ghc
|
|
||||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
|
||||||
sudo rm -rf /opt/hostedtoolcache/PyPy
|
|
||||||
sudo rm -rf /opt/hostedtoolcache/go
|
|
||||||
sudo rm -rf /opt/hostedtoolcache/node
|
|
||||||
sudo rm -rf /opt/hostedtoolcache/Ruby
|
|
||||||
sudo rm -rf /opt/microsoft
|
|
||||||
sudo rm -rf /opt/pipx
|
|
||||||
sudo rm -rf /opt/az
|
|
||||||
sudo rm -rf /opt/google
|
|
||||||
|
|
||||||
|
|
||||||
sudo rm -rf /usr/lib/jvm
|
|
||||||
sudo rm -rf /usr/lib/google-cloud-sdk
|
|
||||||
sudo rm -rf /usr/lib/dotnet
|
|
||||||
|
|
||||||
sudo rm -rf /usr/local/lib/android
|
|
||||||
sudo rm -rf /usr/local/.ghcup
|
|
||||||
sudo rm -rf /usr/local/julia1.11.5
|
|
||||||
sudo rm -rf /usr/local/share/powershell
|
|
||||||
sudo rm -rf /usr/local/share/chromium
|
|
||||||
|
|
||||||
sudo rm -rf /usr/share/swift
|
|
||||||
sudo rm -rf /usr/share/miniconda
|
|
||||||
sudo rm -rf /usr/share/az_12.1.0
|
|
||||||
sudo rm -rf /usr/share/dotnet
|
|
||||||
|
|
||||||
echo "After cleanup:"
|
|
||||||
df -h
|
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
@ -89,188 +43,18 @@ jobs:
|
|||||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||||
|
|
||||||
- name: Build and Push Docker Image (amd64)
|
- name: Build and Push Docker Image
|
||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
file: ./Dockerfile
|
file: ./Dockerfile
|
||||||
push: true
|
push: true
|
||||||
platforms: linux/amd64
|
platforms: linux/amd64,linux/arm64
|
||||||
build-args: |
|
build-args: |
|
||||||
LITE=${{ matrix.lite }}
|
LITE=${{ matrix.lite }}
|
||||||
TORCH_BASE=${{ matrix.torch_base }}
|
TORCH_BASE=${{ matrix.torch_base }}
|
||||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
|
||||||
WORKFLOW=true
|
WORKFLOW=true
|
||||||
tags: |
|
tags: |
|
||||||
xxxxrt666/gpt-sovits:${{ matrix.tag_prefix }}-${{ needs.generate-meta.outputs.tag }}-amd64
|
xxxxrt666/gpt-sovits:${{ matrix.tag_prefix }}-${{ needs.generate-meta.outputs.tag }}
|
||||||
xxxxrt666/gpt-sovits:latest-${{ matrix.tag_prefix }}-amd64
|
xxxxrt666/gpt-sovits:latest-${{ matrix.tag_prefix }}
|
||||||
|
|
||||||
build-arm64:
|
|
||||||
needs: generate-meta
|
|
||||||
runs-on: ubuntu-22.04-arm
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
include:
|
|
||||||
- cuda_version: 12.6
|
|
||||||
lite: true
|
|
||||||
torch_base: lite
|
|
||||||
tag_prefix: cu126-lite
|
|
||||||
- cuda_version: 12.6
|
|
||||||
lite: false
|
|
||||||
torch_base: full
|
|
||||||
tag_prefix: cu126
|
|
||||||
- cuda_version: 12.8
|
|
||||||
lite: true
|
|
||||||
torch_base: lite
|
|
||||||
tag_prefix: cu128-lite
|
|
||||||
- cuda_version: 12.8
|
|
||||||
lite: false
|
|
||||||
torch_base: full
|
|
||||||
tag_prefix: cu128
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Checkout Code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Free up disk space
|
|
||||||
run: |
|
|
||||||
echo "Before cleanup:"
|
|
||||||
df -h
|
|
||||||
|
|
||||||
sudo rm -rf /opt/ghc
|
|
||||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
|
||||||
sudo rm -rf /opt/hostedtoolcache/PyPy
|
|
||||||
sudo rm -rf /opt/hostedtoolcache/go
|
|
||||||
sudo rm -rf /opt/hostedtoolcache/node
|
|
||||||
sudo rm -rf /opt/hostedtoolcache/Ruby
|
|
||||||
sudo rm -rf /opt/microsoft
|
|
||||||
sudo rm -rf /opt/pipx
|
|
||||||
sudo rm -rf /opt/az
|
|
||||||
sudo rm -rf /opt/google
|
|
||||||
|
|
||||||
|
|
||||||
sudo rm -rf /usr/lib/jvm
|
|
||||||
sudo rm -rf /usr/lib/google-cloud-sdk
|
|
||||||
sudo rm -rf /usr/lib/dotnet
|
|
||||||
|
|
||||||
sudo rm -rf /usr/local/lib/android
|
|
||||||
sudo rm -rf /usr/local/.ghcup
|
|
||||||
sudo rm -rf /usr/local/julia1.11.5
|
|
||||||
sudo rm -rf /usr/local/share/powershell
|
|
||||||
sudo rm -rf /usr/local/share/chromium
|
|
||||||
|
|
||||||
sudo rm -rf /usr/share/swift
|
|
||||||
sudo rm -rf /usr/share/miniconda
|
|
||||||
sudo rm -rf /usr/share/az_12.1.0
|
|
||||||
sudo rm -rf /usr/share/dotnet
|
|
||||||
|
|
||||||
echo "After cleanup:"
|
|
||||||
df -h
|
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
|
|
||||||
- name: Log in to Docker Hub
|
|
||||||
uses: docker/login-action@v3
|
|
||||||
with:
|
|
||||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
|
||||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
|
||||||
|
|
||||||
- name: Build and Push Docker Image (arm64)
|
|
||||||
uses: docker/build-push-action@v5
|
|
||||||
with:
|
|
||||||
context: .
|
|
||||||
file: ./Dockerfile
|
|
||||||
push: true
|
|
||||||
platforms: linux/arm64
|
|
||||||
build-args: |
|
|
||||||
LITE=${{ matrix.lite }}
|
|
||||||
TORCH_BASE=${{ matrix.torch_base }}
|
|
||||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
|
||||||
WORKFLOW=true
|
|
||||||
tags: |
|
|
||||||
xxxxrt666/gpt-sovits:${{ matrix.tag_prefix }}-${{ needs.generate-meta.outputs.tag }}-arm64
|
|
||||||
xxxxrt666/gpt-sovits:latest-${{ matrix.tag_prefix }}-arm64
|
|
||||||
|
|
||||||
|
|
||||||
merge-and-clean:
|
|
||||||
needs:
|
|
||||||
- build-amd64
|
|
||||||
- build-arm64
|
|
||||||
- generate-meta
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
strategy:
|
|
||||||
matrix:
|
|
||||||
include:
|
|
||||||
- tag_prefix: cu126-lite
|
|
||||||
- tag_prefix: cu126
|
|
||||||
- tag_prefix: cu128-lite
|
|
||||||
- tag_prefix: cu128
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
|
|
||||||
- name: Log in to Docker Hub
|
|
||||||
uses: docker/login-action@v3
|
|
||||||
with:
|
|
||||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
|
||||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
|
||||||
|
|
||||||
- name: Merge amd64 and arm64 into multi-arch image
|
|
||||||
run: |
|
|
||||||
DATE_TAG=${{ needs.generate-meta.outputs.tag }}
|
|
||||||
TAG_PREFIX=${{ matrix.tag_prefix }}
|
|
||||||
|
|
||||||
docker buildx imagetools create \
|
|
||||||
--tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG} \
|
|
||||||
${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG}-amd64 \
|
|
||||||
${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG}-arm64
|
|
||||||
|
|
||||||
docker buildx imagetools create \
|
|
||||||
--tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX} \
|
|
||||||
${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX}-amd64 \
|
|
||||||
${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX}-arm64
|
|
||||||
- name: Delete old platform-specific tags via Docker Hub API
|
|
||||||
env:
|
|
||||||
DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
|
|
||||||
DOCKER_HUB_TOKEN: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
|
||||||
TAG_PREFIX: ${{ matrix.tag_prefix }}
|
|
||||||
DATE_TAG: ${{ needs.generate-meta.outputs.tag }}
|
|
||||||
run: |
|
|
||||||
sudo apt-get update && sudo apt-get install -y jq
|
|
||||||
|
|
||||||
TOKEN=$(curl -s -u $DOCKER_HUB_USERNAME:$DOCKER_HUB_TOKEN \
|
|
||||||
"https://auth.docker.io/token?service=registry.docker.io&scope=repository:$DOCKER_HUB_USERNAME/gpt-sovits:pull,push,delete" \
|
|
||||||
| jq -r .token)
|
|
||||||
|
|
||||||
for PLATFORM in amd64 arm64; do
|
|
||||||
SAFE_PLATFORM=$(echo $PLATFORM | sed 's/\//-/g')
|
|
||||||
TAG="${TAG_PREFIX}-${DATE_TAG}-${SAFE_PLATFORM}"
|
|
||||||
LATEST_TAG="latest-${TAG_PREFIX}-${SAFE_PLATFORM}"
|
|
||||||
|
|
||||||
for DEL_TAG in "$TAG" "$LATEST_TAG"; do
|
|
||||||
echo "Deleting tag: $DEL_TAG"
|
|
||||||
curl -X DELETE -H "Authorization: Bearer $TOKEN" https://registry-1.docker.io/v2/$DOCKER_HUB_USERNAME/gpt-sovits/manifests/$DEL_TAG
|
|
||||||
done
|
|
||||||
done
|
|
||||||
create-default:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
needs:
|
|
||||||
- merge-and-clean
|
|
||||||
steps:
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
|
|
||||||
- name: Log in to Docker Hub
|
|
||||||
uses: docker/login-action@v3
|
|
||||||
with:
|
|
||||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
|
||||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
|
||||||
|
|
||||||
- name: Create Default Tag
|
|
||||||
run: |
|
|
||||||
docker buildx imagetools create \
|
|
||||||
--tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest \
|
|
||||||
${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-cu126-lite
|
|
||||||
|
|
||||||
@ -58,9 +58,11 @@ def main():
|
|||||||
parser.add_argument("--input_wavs_dir", default="test_files")
|
parser.add_argument("--input_wavs_dir", default="test_files")
|
||||||
parser.add_argument("--output_dir", default="generated_files")
|
parser.add_argument("--output_dir", default="generated_files")
|
||||||
parser.add_argument("--checkpoint_file", required=True)
|
parser.add_argument("--checkpoint_file", required=True)
|
||||||
parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
|
# --use_cuda_kernel argument is removed to disable custom CUDA kernels.
|
||||||
|
# parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
|
||||||
|
|
||||||
a = parser.parse_args()
|
a = parser.parse_args()
|
||||||
|
a.use_cuda_kernel = False
|
||||||
|
|
||||||
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
|
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
|
||||||
with open(config_file) as f:
|
with open(config_file) as f:
|
||||||
@ -72,9 +74,9 @@ def main():
|
|||||||
|
|
||||||
torch.manual_seed(h.seed)
|
torch.manual_seed(h.seed)
|
||||||
global device
|
global device
|
||||||
if torch.cuda.is_available():
|
if torch.xpu.is_available():
|
||||||
torch.cuda.manual_seed(h.seed)
|
torch.xpu.manual_seed(h.seed)
|
||||||
device = torch.device("cuda")
|
device = torch.device("xpu")
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
|||||||
@ -73,9 +73,11 @@ def main():
|
|||||||
parser.add_argument("--input_mels_dir", default="test_mel_files")
|
parser.add_argument("--input_mels_dir", default="test_mel_files")
|
||||||
parser.add_argument("--output_dir", default="generated_files_from_mel")
|
parser.add_argument("--output_dir", default="generated_files_from_mel")
|
||||||
parser.add_argument("--checkpoint_file", required=True)
|
parser.add_argument("--checkpoint_file", required=True)
|
||||||
parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
|
# --use_cuda_kernel argument is removed to disable custom CUDA kernels.
|
||||||
|
# parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
|
||||||
|
|
||||||
a = parser.parse_args()
|
a = parser.parse_args()
|
||||||
|
a.use_cuda_kernel = False
|
||||||
|
|
||||||
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
|
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
|
||||||
with open(config_file) as f:
|
with open(config_file) as f:
|
||||||
@ -87,9 +89,9 @@ def main():
|
|||||||
|
|
||||||
torch.manual_seed(h.seed)
|
torch.manual_seed(h.seed)
|
||||||
global device
|
global device
|
||||||
if torch.cuda.is_available():
|
if torch.xpu.is_available():
|
||||||
torch.cuda.manual_seed(h.seed)
|
torch.xpu.manual_seed(h.seed)
|
||||||
device = torch.device("cuda")
|
device = torch.device("xpu")
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
|||||||
@ -1,215 +0,0 @@
|
|||||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
|
||||||
# Licensed under the MIT license.
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# to import modules from parent_dir
|
|
||||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|
||||||
sys.path.append(parent_dir)
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import json
|
|
||||||
from env import AttrDict
|
|
||||||
from bigvgan import BigVGAN
|
|
||||||
from time import time
|
|
||||||
from tqdm import tqdm
|
|
||||||
from meldataset import mel_spectrogram, MAX_WAV_VALUE
|
|
||||||
from scipy.io.wavfile import write
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
|
||||||
|
|
||||||
# For easier debugging
|
|
||||||
torch.set_printoptions(linewidth=200, threshold=10_000)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_soundwave(duration=5.0, sr=24000):
|
|
||||||
t = np.linspace(0, duration, int(sr * duration), False, dtype=np.float32)
|
|
||||||
|
|
||||||
modulation = np.sin(2 * np.pi * t / duration)
|
|
||||||
|
|
||||||
min_freq = 220
|
|
||||||
max_freq = 1760
|
|
||||||
frequencies = min_freq + (max_freq - min_freq) * (modulation + 1) / 2
|
|
||||||
soundwave = np.sin(2 * np.pi * frequencies * t)
|
|
||||||
|
|
||||||
soundwave = soundwave / np.max(np.abs(soundwave)) * 0.95
|
|
||||||
|
|
||||||
return soundwave, sr
|
|
||||||
|
|
||||||
|
|
||||||
def get_mel(x, h):
|
|
||||||
return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
|
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(filepath, device):
|
|
||||||
assert os.path.isfile(filepath)
|
|
||||||
print(f"Loading '{filepath}'")
|
|
||||||
checkpoint_dict = torch.load(filepath, map_location=device)
|
|
||||||
print("Complete.")
|
|
||||||
return checkpoint_dict
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--checkpoint_file",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Path to the checkpoint file. Assumes config.json exists in the directory.",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
config_file = os.path.join(os.path.split(args.checkpoint_file)[0], "config.json")
|
|
||||||
with open(config_file) as f:
|
|
||||||
config = f.read()
|
|
||||||
json_config = json.loads(config)
|
|
||||||
h = AttrDict({**json_config})
|
|
||||||
|
|
||||||
print("loading plain Pytorch BigVGAN")
|
|
||||||
generator_original = BigVGAN(h).to("cuda")
|
|
||||||
print("loading CUDA kernel BigVGAN with auto-build")
|
|
||||||
generator_cuda_kernel = BigVGAN(h, use_cuda_kernel=True).to("cuda")
|
|
||||||
|
|
||||||
state_dict_g = load_checkpoint(args.checkpoint_file, "cuda")
|
|
||||||
generator_original.load_state_dict(state_dict_g["generator"])
|
|
||||||
generator_cuda_kernel.load_state_dict(state_dict_g["generator"])
|
|
||||||
|
|
||||||
generator_original.remove_weight_norm()
|
|
||||||
generator_original.eval()
|
|
||||||
generator_cuda_kernel.remove_weight_norm()
|
|
||||||
generator_cuda_kernel.eval()
|
|
||||||
|
|
||||||
# define number of samples and length of mel frame to benchmark
|
|
||||||
num_sample = 10
|
|
||||||
num_mel_frame = 16384
|
|
||||||
|
|
||||||
# CUDA kernel correctness check
|
|
||||||
diff = 0.0
|
|
||||||
for i in tqdm(range(num_sample)):
|
|
||||||
# Random mel
|
|
||||||
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
audio_original = generator_original(data)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
audio_cuda_kernel = generator_cuda_kernel(data)
|
|
||||||
|
|
||||||
# Both outputs should be (almost) the same
|
|
||||||
test_result = (audio_original - audio_cuda_kernel).abs()
|
|
||||||
diff += test_result.mean(dim=-1).item()
|
|
||||||
|
|
||||||
diff /= num_sample
|
|
||||||
if diff <= 2e-3: # We can expect a small difference (~1e-3) which does not affect perceptual quality
|
|
||||||
print(
|
|
||||||
f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
|
|
||||||
f"\n > mean_difference={diff}"
|
|
||||||
f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}"
|
|
||||||
f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print(
|
|
||||||
f"\n[Fail] test CUDA fused vs. plain torch BigVGAN inference"
|
|
||||||
f"\n > mean_difference={diff}"
|
|
||||||
f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, "
|
|
||||||
f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
del data, audio_original, audio_cuda_kernel
|
|
||||||
|
|
||||||
# Variables for tracking total time and VRAM usage
|
|
||||||
toc_total_original = 0
|
|
||||||
toc_total_cuda_kernel = 0
|
|
||||||
vram_used_original_total = 0
|
|
||||||
vram_used_cuda_kernel_total = 0
|
|
||||||
audio_length_total = 0
|
|
||||||
|
|
||||||
# Measure Original inference in isolation
|
|
||||||
for i in tqdm(range(num_sample)):
|
|
||||||
torch.cuda.reset_peak_memory_stats(device="cuda")
|
|
||||||
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
tic = time()
|
|
||||||
with torch.inference_mode():
|
|
||||||
audio_original = generator_original(data)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
toc = time() - tic
|
|
||||||
toc_total_original += toc
|
|
||||||
|
|
||||||
vram_used_original_total += torch.cuda.max_memory_allocated(device="cuda")
|
|
||||||
|
|
||||||
del data, audio_original
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# Measure CUDA kernel inference in isolation
|
|
||||||
for i in tqdm(range(num_sample)):
|
|
||||||
torch.cuda.reset_peak_memory_stats(device="cuda")
|
|
||||||
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
tic = time()
|
|
||||||
with torch.inference_mode():
|
|
||||||
audio_cuda_kernel = generator_cuda_kernel(data)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
toc = time() - tic
|
|
||||||
toc_total_cuda_kernel += toc
|
|
||||||
|
|
||||||
audio_length_total += audio_cuda_kernel.shape[-1]
|
|
||||||
|
|
||||||
vram_used_cuda_kernel_total += torch.cuda.max_memory_allocated(device="cuda")
|
|
||||||
|
|
||||||
del data, audio_cuda_kernel
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# Calculate metrics
|
|
||||||
audio_second = audio_length_total / h.sampling_rate
|
|
||||||
khz_original = audio_length_total / toc_total_original / 1000
|
|
||||||
khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
|
|
||||||
vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
|
|
||||||
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)
|
|
||||||
|
|
||||||
# Print results
|
|
||||||
print(
|
|
||||||
f"Original BigVGAN: took {toc_total_original:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_original:.1f}kHz, {audio_second / toc_total_original:.1f} faster than realtime, VRAM used {vram_used_original_gb:.1f} GB"
|
|
||||||
)
|
|
||||||
print(
|
|
||||||
f"CUDA kernel BigVGAN: took {toc_total_cuda_kernel:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_cuda_kernel:.1f}kHz, {audio_second / toc_total_cuda_kernel:.1f} faster than realtime, VRAM used {vram_used_cuda_kernel_gb:.1f} GB"
|
|
||||||
)
|
|
||||||
print(f"speedup of CUDA kernel: {khz_cuda_kernel / khz_original}")
|
|
||||||
print(f"VRAM saving of CUDA kernel: {vram_used_original_gb / vram_used_cuda_kernel_gb}")
|
|
||||||
|
|
||||||
# Use artificial sine waves for inference test
|
|
||||||
audio_real, sr = generate_soundwave(duration=5.0, sr=h.sampling_rate)
|
|
||||||
audio_real = torch.tensor(audio_real).to("cuda")
|
|
||||||
# Compute mel spectrogram from the ground truth audio
|
|
||||||
x = get_mel(audio_real.unsqueeze(0), h)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
y_g_hat_original = generator_original(x)
|
|
||||||
y_g_hat_cuda_kernel = generator_cuda_kernel(x)
|
|
||||||
|
|
||||||
audio_real = audio_real.squeeze()
|
|
||||||
audio_real = audio_real * MAX_WAV_VALUE
|
|
||||||
audio_real = audio_real.cpu().numpy().astype("int16")
|
|
||||||
|
|
||||||
audio_original = y_g_hat_original.squeeze()
|
|
||||||
audio_original = audio_original * MAX_WAV_VALUE
|
|
||||||
audio_original = audio_original.cpu().numpy().astype("int16")
|
|
||||||
|
|
||||||
audio_cuda_kernel = y_g_hat_cuda_kernel.squeeze()
|
|
||||||
audio_cuda_kernel = audio_cuda_kernel * MAX_WAV_VALUE
|
|
||||||
audio_cuda_kernel = audio_cuda_kernel.cpu().numpy().astype("int16")
|
|
||||||
|
|
||||||
os.makedirs("tmp", exist_ok=True)
|
|
||||||
output_file_real = os.path.join("tmp", "audio_real.wav")
|
|
||||||
output_file_original = os.path.join("tmp", "audio_generated_original.wav")
|
|
||||||
output_file_cuda_kernel = os.path.join("tmp", "audio_generated_cuda_kernel.wav")
|
|
||||||
write(output_file_real, h.sampling_rate, audio_real)
|
|
||||||
write(output_file_original, h.sampling_rate, audio_original)
|
|
||||||
write(output_file_cuda_kernel, h.sampling_rate, audio_cuda_kernel)
|
|
||||||
print("Example generated audios of original vs. fused CUDA kernel written to tmp!")
|
|
||||||
print("Done")
|
|
||||||
@ -110,15 +110,15 @@ def main(args):
|
|||||||
os.environ["USE_LIBUV"] = "0"
|
os.environ["USE_LIBUV"] = "0"
|
||||||
trainer: Trainer = Trainer(
|
trainer: Trainer = Trainer(
|
||||||
max_epochs=config["train"]["epochs"],
|
max_epochs=config["train"]["epochs"],
|
||||||
accelerator="gpu" if torch.cuda.is_available() else "cpu",
|
accelerator="xpu" if torch.xpu.is_available() else "cpu",
|
||||||
# val_check_interval=9999999999999999999999,###不要验证
|
# val_check_interval=9999999999999999999999,###不要验证
|
||||||
# check_val_every_n_epoch=None,
|
# check_val_every_n_epoch=None,
|
||||||
limit_val_batches=0,
|
limit_val_batches=0,
|
||||||
devices=-1 if torch.cuda.is_available() else 1,
|
devices=-1 if torch.xpu.is_available() else 1,
|
||||||
benchmark=False,
|
benchmark=False,
|
||||||
fast_dev_run=False,
|
fast_dev_run=False,
|
||||||
strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
|
strategy=DDPStrategy(process_group_backend="ccl" if platform.system() != "Windows" else "gloo")
|
||||||
if torch.cuda.is_available()
|
if torch.xpu.is_available()
|
||||||
else "auto",
|
else "auto",
|
||||||
precision=config["train"]["precision"],
|
precision=config["train"]["precision"],
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
|||||||
@ -41,18 +41,18 @@ from process_ckpt import savee
|
|||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
torch.backends.cudnn.deterministic = False
|
torch.backends.cudnn.deterministic = False
|
||||||
###反正A100fp32更快,那试试tf32吧
|
###反正A100fp32更快,那试试tf32吧
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
# torch.backends.cuda.matmul.allow_tf32 = True # XPU does not support this
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
# torch.backends.cudnn.allow_tf32 = True # XPU does not support this
|
||||||
torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响
|
torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响
|
||||||
# from config import pretrained_s2G,pretrained_s2D
|
# from config import pretrained_s2G,pretrained_s2D
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
|
||||||
device = "cpu" # cuda以外的设备,等mps优化后加入
|
device = "xpu" if torch.xpu.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if torch.cuda.is_available():
|
if torch.xpu.is_available():
|
||||||
n_gpus = torch.cuda.device_count()
|
n_gpus = torch.xpu.device_count()
|
||||||
else:
|
else:
|
||||||
n_gpus = 1
|
n_gpus = 1
|
||||||
os.environ["MASTER_ADDR"] = "localhost"
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
@ -78,14 +78,14 @@ def run(rank, n_gpus, hps):
|
|||||||
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
|
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
|
||||||
|
|
||||||
dist.init_process_group(
|
dist.init_process_group(
|
||||||
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
backend="gloo" if os.name == "nt" or not torch.xpu.is_available() else "ccl",
|
||||||
init_method="env://?use_libuv=False",
|
init_method="env://?use_libuv=False",
|
||||||
world_size=n_gpus,
|
world_size=n_gpus,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
)
|
)
|
||||||
torch.manual_seed(hps.train.seed)
|
torch.manual_seed(hps.train.seed)
|
||||||
if torch.cuda.is_available():
|
if torch.xpu.is_available():
|
||||||
torch.cuda.set_device(rank)
|
torch.xpu.set_device(rank)
|
||||||
|
|
||||||
train_dataset = TextAudioSpeakerLoader(hps.data, version=hps.model.version)
|
train_dataset = TextAudioSpeakerLoader(hps.data, version=hps.model.version)
|
||||||
train_sampler = DistributedBucketSampler(
|
train_sampler = DistributedBucketSampler(
|
||||||
@ -132,27 +132,14 @@ def run(rank, n_gpus, hps):
|
|||||||
# batch_size=1, pin_memory=True,
|
# batch_size=1, pin_memory=True,
|
||||||
# drop_last=False, collate_fn=collate_fn)
|
# drop_last=False, collate_fn=collate_fn)
|
||||||
|
|
||||||
net_g = (
|
net_g = SynthesizerTrn(
|
||||||
SynthesizerTrn(
|
|
||||||
hps.data.filter_length // 2 + 1,
|
|
||||||
hps.train.segment_size // hps.data.hop_length,
|
|
||||||
n_speakers=hps.data.n_speakers,
|
|
||||||
**hps.model,
|
|
||||||
).cuda(rank)
|
|
||||||
if torch.cuda.is_available()
|
|
||||||
else SynthesizerTrn(
|
|
||||||
hps.data.filter_length // 2 + 1,
|
hps.data.filter_length // 2 + 1,
|
||||||
hps.train.segment_size // hps.data.hop_length,
|
hps.train.segment_size // hps.data.hop_length,
|
||||||
n_speakers=hps.data.n_speakers,
|
n_speakers=hps.data.n_speakers,
|
||||||
**hps.model,
|
**hps.model,
|
||||||
).to(device)
|
).to(device)
|
||||||
)
|
|
||||||
|
|
||||||
net_d = (
|
net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).to(device)
|
||||||
MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).cuda(rank)
|
|
||||||
if torch.cuda.is_available()
|
|
||||||
else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).to(device)
|
|
||||||
)
|
|
||||||
for name, param in net_g.named_parameters():
|
for name, param in net_g.named_parameters():
|
||||||
if not param.requires_grad:
|
if not param.requires_grad:
|
||||||
print(name, "not requires_grad")
|
print(name, "not requires_grad")
|
||||||
@ -196,7 +183,7 @@ def run(rank, n_gpus, hps):
|
|||||||
betas=hps.train.betas,
|
betas=hps.train.betas,
|
||||||
eps=hps.train.eps,
|
eps=hps.train.eps,
|
||||||
)
|
)
|
||||||
if torch.cuda.is_available():
|
if torch.xpu.is_available():
|
||||||
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
|
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
|
||||||
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
|
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
|
||||||
else:
|
else:
|
||||||
@ -238,7 +225,7 @@ def run(rank, n_gpus, hps):
|
|||||||
torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
|
torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
|
||||||
strict=False,
|
strict=False,
|
||||||
)
|
)
|
||||||
if torch.cuda.is_available()
|
if torch.xpu.is_available()
|
||||||
else net_g.load_state_dict(
|
else net_g.load_state_dict(
|
||||||
torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
|
torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
|
||||||
strict=False,
|
strict=False,
|
||||||
@ -256,7 +243,7 @@ def run(rank, n_gpus, hps):
|
|||||||
net_d.module.load_state_dict(
|
net_d.module.load_state_dict(
|
||||||
torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"], strict=False
|
torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"], strict=False
|
||||||
)
|
)
|
||||||
if torch.cuda.is_available()
|
if torch.xpu.is_available()
|
||||||
else net_d.load_state_dict(
|
else net_d.load_state_dict(
|
||||||
torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"],
|
torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"],
|
||||||
),
|
),
|
||||||
@ -333,42 +320,24 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
|||||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = data
|
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = data
|
||||||
else:
|
else:
|
||||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data
|
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data
|
||||||
if torch.cuda.is_available():
|
if torch.xpu.is_available():
|
||||||
spec, spec_lengths = (
|
spec, spec_lengths = (
|
||||||
spec.cuda(
|
spec.to(device, non_blocking=True),
|
||||||
rank,
|
spec_lengths.to(device, non_blocking=True),
|
||||||
non_blocking=True,
|
|
||||||
),
|
|
||||||
spec_lengths.cuda(
|
|
||||||
rank,
|
|
||||||
non_blocking=True,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
y, y_lengths = (
|
y, y_lengths = (
|
||||||
y.cuda(
|
y.to(device, non_blocking=True),
|
||||||
rank,
|
y_lengths.to(device, non_blocking=True),
|
||||||
non_blocking=True,
|
|
||||||
),
|
|
||||||
y_lengths.cuda(
|
|
||||||
rank,
|
|
||||||
non_blocking=True,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
ssl = ssl.cuda(rank, non_blocking=True)
|
ssl = ssl.to(device, non_blocking=True)
|
||||||
ssl.requires_grad = False
|
ssl.requires_grad = False
|
||||||
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
|
# ssl_lengths = ssl_lengths.to(device, non_blocking=True)
|
||||||
text, text_lengths = (
|
text, text_lengths = (
|
||||||
text.cuda(
|
text.to(device, non_blocking=True),
|
||||||
rank,
|
text_lengths.to(device, non_blocking=True),
|
||||||
non_blocking=True,
|
|
||||||
),
|
|
||||||
text_lengths.cuda(
|
|
||||||
rank,
|
|
||||||
non_blocking=True,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
if hps.model.version in {"v2Pro", "v2ProPlus"}:
|
if hps.model.version in {"v2Pro", "v2ProPlus"}:
|
||||||
sv_emb = sv_emb.cuda(rank, non_blocking=True)
|
sv_emb = sv_emb.to(device, non_blocking=True)
|
||||||
else:
|
else:
|
||||||
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
|
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
|
||||||
y, y_lengths = y.to(device), y_lengths.to(device)
|
y, y_lengths = y.to(device), y_lengths.to(device)
|
||||||
@ -596,11 +565,11 @@ def evaluate(hps, generator, eval_loader, writer_eval):
|
|||||||
text_lengths,
|
text_lengths,
|
||||||
) in enumerate(eval_loader):
|
) in enumerate(eval_loader):
|
||||||
print(111)
|
print(111)
|
||||||
if torch.cuda.is_available():
|
if torch.xpu.is_available():
|
||||||
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
|
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
|
||||||
y, y_lengths = y.cuda(), y_lengths.cuda()
|
y, y_lengths = y.to(device), y_lengths.to(device)
|
||||||
ssl = ssl.cuda()
|
ssl = ssl.to(device)
|
||||||
text, text_lengths = text.cuda(), text_lengths.cuda()
|
text, text_lengths = text.to(device), text_lengths.to(device)
|
||||||
else:
|
else:
|
||||||
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
|
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
|
||||||
y, y_lengths = y.to(device), y_lengths.to(device)
|
y, y_lengths = y.to(device), y_lengths.to(device)
|
||||||
|
|||||||
13
api.py
13
api.py
@ -13,7 +13,7 @@
|
|||||||
`-dt` - `默认参考音频文本`
|
`-dt` - `默认参考音频文本`
|
||||||
`-dl` - `默认参考音频语种, "中文","英文","日文","韩文","粤语,"zh","en","ja","ko","yue"`
|
`-dl` - `默认参考音频语种, "中文","英文","日文","韩文","粤语,"zh","en","ja","ko","yue"`
|
||||||
|
|
||||||
`-d` - `推理设备, "cuda","cpu"`
|
`-d` - `推理设备, "xpu","cpu"`
|
||||||
`-a` - `绑定地址, 默认"127.0.0.1"`
|
`-a` - `绑定地址, 默认"127.0.0.1"`
|
||||||
`-p` - `绑定端口, 默认9880, 可在 config.py 中指定`
|
`-p` - `绑定端口, 默认9880, 可在 config.py 中指定`
|
||||||
`-fp` - `覆盖 config.py 使用全精度`
|
`-fp` - `覆盖 config.py 使用全精度`
|
||||||
@ -207,7 +207,8 @@ def clean_hifigan_model():
|
|||||||
hifigan_model = hifigan_model.cpu()
|
hifigan_model = hifigan_model.cpu()
|
||||||
hifigan_model = None
|
hifigan_model = None
|
||||||
try:
|
try:
|
||||||
torch.cuda.empty_cache()
|
if torch.xpu.is_available():
|
||||||
|
torch.xpu.empty_cache()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -218,7 +219,8 @@ def clean_bigvgan_model():
|
|||||||
bigvgan_model = bigvgan_model.cpu()
|
bigvgan_model = bigvgan_model.cpu()
|
||||||
bigvgan_model = None
|
bigvgan_model = None
|
||||||
try:
|
try:
|
||||||
torch.cuda.empty_cache()
|
if torch.xpu.is_available():
|
||||||
|
torch.xpu.empty_cache()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -229,7 +231,8 @@ def clean_sv_cn_model():
|
|||||||
sv_cn_model.embedding_model = sv_cn_model.embedding_model.cpu()
|
sv_cn_model.embedding_model = sv_cn_model.embedding_model.cpu()
|
||||||
sv_cn_model = None
|
sv_cn_model = None
|
||||||
try:
|
try:
|
||||||
torch.cuda.empty_cache()
|
if torch.xpu.is_available():
|
||||||
|
torch.xpu.empty_cache()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -1195,7 +1198,7 @@ parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, hel
|
|||||||
parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径")
|
parser.add_argument("-dr", "--default_refer_path", type=str, default="", help="默认参考音频路径")
|
||||||
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
|
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
|
||||||
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
|
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
|
||||||
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
|
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="xpu / cpu")
|
||||||
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
|
parser.add_argument("-a", "--bind_addr", type=str, default="0.0.0.0", help="default: 0.0.0.0")
|
||||||
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
|
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
42
config.py
42
config.py
@ -146,32 +146,38 @@ api_port = 9880
|
|||||||
|
|
||||||
|
|
||||||
# Thanks to the contribution of @Karasukaigan and @XXXXRT666
|
# Thanks to the contribution of @Karasukaigan and @XXXXRT666
|
||||||
|
# Modified for Intel GPU (XPU)
|
||||||
def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
|
def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
cuda = torch.device(f"cuda:{idx}")
|
try:
|
||||||
if not torch.cuda.is_available():
|
if not torch.xpu.is_available():
|
||||||
return cpu, torch.float32, 0.0, 0.0
|
return cpu, torch.float32, 0.0, 0.0
|
||||||
device_idx = idx
|
except AttributeError:
|
||||||
capability = torch.cuda.get_device_capability(device_idx)
|
|
||||||
name = torch.cuda.get_device_name(device_idx)
|
|
||||||
mem_bytes = torch.cuda.get_device_properties(device_idx).total_memory
|
|
||||||
mem_gb = mem_bytes / (1024**3) + 0.4
|
|
||||||
major, minor = capability
|
|
||||||
sm_version = major + minor / 10.0
|
|
||||||
is_16_series = bool(re.search(r"16\d{2}", name)) and sm_version == 7.5
|
|
||||||
if mem_gb < 4 or sm_version < 5.3:
|
|
||||||
return cpu, torch.float32, 0.0, 0.0
|
return cpu, torch.float32, 0.0, 0.0
|
||||||
if sm_version == 6.1 or is_16_series == True:
|
|
||||||
return cuda, torch.float32, sm_version, mem_gb
|
xpu_device = torch.device(f"xpu:{idx}")
|
||||||
if sm_version > 6.1:
|
properties = torch.xpu.get_device_properties(idx)
|
||||||
return cuda, torch.float16, sm_version, mem_gb
|
mem_bytes = properties.total_memory
|
||||||
|
mem_gb = mem_bytes / (1024**3)
|
||||||
|
|
||||||
|
# Simplified logic for XPU, assuming FP16/BF16 is generally supported.
|
||||||
|
# The complex SM version check is CUDA-specific.
|
||||||
|
if mem_gb < 4: # Example threshold
|
||||||
return cpu, torch.float32, 0.0, 0.0
|
return cpu, torch.float32, 0.0, 0.0
|
||||||
|
|
||||||
|
# For Intel GPUs, we can generally assume float16 is available.
|
||||||
|
# The 'sm_version' equivalent is not straightforward, so we use a placeholder value (e.g., 1.0)
|
||||||
|
# for compatibility with the downstream logic that sorts devices.
|
||||||
|
return xpu_device, torch.float16, 1.0, mem_gb
|
||||||
|
|
||||||
|
|
||||||
IS_GPU = True
|
IS_GPU = True
|
||||||
GPU_INFOS: list[str] = []
|
GPU_INFOS: list[str] = []
|
||||||
GPU_INDEX: set[int] = set()
|
GPU_INDEX: set[int] = set()
|
||||||
GPU_COUNT = torch.cuda.device_count()
|
try:
|
||||||
|
GPU_COUNT = torch.xpu.device_count() if torch.xpu.is_available() else 0
|
||||||
|
except AttributeError:
|
||||||
|
GPU_COUNT = 0
|
||||||
CPU_INFO: str = "0\tCPU " + i18n("CPU训练,较慢")
|
CPU_INFO: str = "0\tCPU " + i18n("CPU训练,较慢")
|
||||||
tmp: list[tuple[torch.device, torch.dtype, float, float]] = []
|
tmp: list[tuple[torch.device, torch.dtype, float, float]] = []
|
||||||
memset: set[float] = set()
|
memset: set[float] = set()
|
||||||
@ -182,8 +188,8 @@ for i in range(max(GPU_COUNT, 1)):
|
|||||||
for j in tmp:
|
for j in tmp:
|
||||||
device = j[0]
|
device = j[0]
|
||||||
memset.add(j[3])
|
memset.add(j[3])
|
||||||
if device.type != "cpu":
|
if device.type == "xpu":
|
||||||
GPU_INFOS.append(f"{device.index}\t{torch.cuda.get_device_name(device.index)}")
|
GPU_INFOS.append(f"{device.index}\t{torch.xpu.get_device_name(device.index)}")
|
||||||
GPU_INDEX.add(device.index)
|
GPU_INDEX.add(device.index)
|
||||||
|
|
||||||
if not GPU_INFOS:
|
if not GPU_INFOS:
|
||||||
|
|||||||
@ -14,49 +14,29 @@ fi
|
|||||||
trap 'echo "Error Occured at \"$BASH_COMMAND\" with exit code $?"; exit 1' ERR
|
trap 'echo "Error Occured at \"$BASH_COMMAND\" with exit code $?"; exit 1' ERR
|
||||||
|
|
||||||
LITE=false
|
LITE=false
|
||||||
CUDA_VERSION=12.6
|
|
||||||
|
|
||||||
print_help() {
|
print_help() {
|
||||||
echo "Usage: bash docker_build.sh [OPTIONS]"
|
echo "Usage: bash docker_build.sh [OPTIONS]"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Options:"
|
echo "Options:"
|
||||||
echo " --cuda 12.6|12.8 Specify the CUDA VERSION (REQUIRED)"
|
|
||||||
echo " --lite Build a Lite Image"
|
echo " --lite Build a Lite Image"
|
||||||
echo " -h, --help Show this help message and exit"
|
echo " -h, --help Show this help message and exit"
|
||||||
echo ""
|
echo ""
|
||||||
echo "Examples:"
|
echo "Examples:"
|
||||||
echo " bash docker_build.sh --cuda 12.6 --funasr --faster-whisper"
|
echo " bash docker_build.sh --lite"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Show help if no arguments provided
|
|
||||||
if [[ $# -eq 0 ]]; then
|
|
||||||
print_help
|
|
||||||
exit 0
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Parse arguments
|
# Parse arguments
|
||||||
while [[ $# -gt 0 ]]; do
|
while [[ $# -gt 0 ]]; do
|
||||||
case "$1" in
|
case "$1" in
|
||||||
--cuda)
|
|
||||||
case "$2" in
|
|
||||||
12.6)
|
|
||||||
CUDA_VERSION=12.6
|
|
||||||
;;
|
|
||||||
12.8)
|
|
||||||
CUDA_VERSION=12.8
|
|
||||||
;;
|
|
||||||
*)
|
|
||||||
echo "Error: Invalid CUDA_VERSION: $2"
|
|
||||||
echo "Choose From: [12.6, 12.8]"
|
|
||||||
exit 1
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--lite)
|
--lite)
|
||||||
LITE=true
|
LITE=true
|
||||||
shift
|
shift
|
||||||
;;
|
;;
|
||||||
|
-h|--help)
|
||||||
|
print_help
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Unknown Argument: $1"
|
echo "Unknown Argument: $1"
|
||||||
echo "Use -h or --help to see available options."
|
echo "Use -h or --help to see available options."
|
||||||
@ -74,7 +54,6 @@ else
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
docker build \
|
docker build \
|
||||||
--build-arg CUDA_VERSION=$CUDA_VERSION \
|
|
||||||
--build-arg LITE=$LITE \
|
--build-arg LITE=$LITE \
|
||||||
--build-arg TARGETPLATFORM="$TARGETPLATFORM" \
|
--build-arg TARGETPLATFORM="$TARGETPLATFORM" \
|
||||||
--build-arg TORCH_BASE=$TORCH_BASE \
|
--build-arg TORCH_BASE=$TORCH_BASE \
|
||||||
|
|||||||
@ -5,6 +5,9 @@ tensorboard
|
|||||||
librosa==0.10.2
|
librosa==0.10.2
|
||||||
numba
|
numba
|
||||||
pytorch-lightning>=2.4
|
pytorch-lightning>=2.4
|
||||||
|
torch==2.9
|
||||||
|
intel-extension-for-pytorch
|
||||||
|
torchvision
|
||||||
gradio<5
|
gradio<5
|
||||||
ffmpeg-python
|
ffmpeg-python
|
||||||
onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64"
|
onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64"
|
||||||
|
|||||||
@ -85,7 +85,7 @@ def execute_asr(input_folder, output_folder, model_path, language, precision):
|
|||||||
if language == "auto":
|
if language == "auto":
|
||||||
language = None # 不设置语种由模型自动输出概率最高的语种
|
language = None # 不设置语种由模型自动输出概率最高的语种
|
||||||
print("loading faster whisper model:", model_path, model_path)
|
print("loading faster whisper model:", model_path, model_path)
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "xpu" if torch.xpu.is_available() else "cpu"
|
||||||
model = WhisperModel(model_path, device=device, compute_type=precision)
|
model = WhisperModel(model_path, device=device, compute_type=precision)
|
||||||
|
|
||||||
input_file_names = os.listdir(input_folder)
|
input_file_names = os.listdir(input_folder)
|
||||||
@ -128,8 +128,6 @@ def execute_asr(input_folder, output_folder, model_path, language, precision):
|
|||||||
return output_file_path
|
return output_file_path
|
||||||
|
|
||||||
|
|
||||||
load_cudnn()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@ -135,97 +135,3 @@ def check_details(path_list=None, is_train=False, is_dataset_processing=False):
|
|||||||
...
|
...
|
||||||
else:
|
else:
|
||||||
gr.Warning(i18n("缺少语义数据集"))
|
gr.Warning(i18n("缺少语义数据集"))
|
||||||
|
|
||||||
|
|
||||||
def load_cudnn():
|
|
||||||
import torch
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("[INFO] CUDA is not available, skipping cuDNN setup.")
|
|
||||||
return
|
|
||||||
|
|
||||||
if sys.platform == "win32":
|
|
||||||
torch_lib_dir = Path(torch.__file__).parent / "lib"
|
|
||||||
if torch_lib_dir.exists():
|
|
||||||
os.add_dll_directory(str(torch_lib_dir))
|
|
||||||
print(f"[INFO] Added DLL directory: {torch_lib_dir}")
|
|
||||||
matching_files = sorted(torch_lib_dir.glob("cudnn_cnn*.dll"))
|
|
||||||
if not matching_files:
|
|
||||||
print(f"[ERROR] No cudnn_cnn*.dll found in {torch_lib_dir}")
|
|
||||||
return
|
|
||||||
for dll_path in matching_files:
|
|
||||||
dll_name = os.path.basename(dll_path)
|
|
||||||
try:
|
|
||||||
ctypes.CDLL(dll_name)
|
|
||||||
print(f"[INFO] Loaded: {dll_name}")
|
|
||||||
except OSError as e:
|
|
||||||
print(f"[WARNING] Failed to load {dll_name}: {e}")
|
|
||||||
else:
|
|
||||||
print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
|
|
||||||
|
|
||||||
elif sys.platform == "linux":
|
|
||||||
site_packages = Path(torch.__file__).resolve().parents[1]
|
|
||||||
cudnn_dir = site_packages / "nvidia" / "cudnn" / "lib"
|
|
||||||
|
|
||||||
if not cudnn_dir.exists():
|
|
||||||
print(f"[ERROR] cudnn dir not found: {cudnn_dir}")
|
|
||||||
return
|
|
||||||
|
|
||||||
matching_files = sorted(cudnn_dir.glob("libcudnn_cnn*.so*"))
|
|
||||||
if not matching_files:
|
|
||||||
print(f"[ERROR] No libcudnn_cnn*.so* found in {cudnn_dir}")
|
|
||||||
return
|
|
||||||
|
|
||||||
for so_path in matching_files:
|
|
||||||
try:
|
|
||||||
ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) # type: ignore
|
|
||||||
print(f"[INFO] Loaded: {so_path}")
|
|
||||||
except OSError as e:
|
|
||||||
print(f"[WARNING] Failed to load {so_path}: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def load_nvrtc():
|
|
||||||
import torch
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("[INFO] CUDA is not available, skipping nvrtc setup.")
|
|
||||||
return
|
|
||||||
|
|
||||||
if sys.platform == "win32":
|
|
||||||
torch_lib_dir = Path(torch.__file__).parent / "lib"
|
|
||||||
if torch_lib_dir.exists():
|
|
||||||
os.add_dll_directory(str(torch_lib_dir))
|
|
||||||
print(f"[INFO] Added DLL directory: {torch_lib_dir}")
|
|
||||||
matching_files = sorted(torch_lib_dir.glob("nvrtc*.dll"))
|
|
||||||
if not matching_files:
|
|
||||||
print(f"[ERROR] No nvrtc*.dll found in {torch_lib_dir}")
|
|
||||||
return
|
|
||||||
for dll_path in matching_files:
|
|
||||||
dll_name = os.path.basename(dll_path)
|
|
||||||
try:
|
|
||||||
ctypes.CDLL(dll_name)
|
|
||||||
print(f"[INFO] Loaded: {dll_name}")
|
|
||||||
except OSError as e:
|
|
||||||
print(f"[WARNING] Failed to load {dll_name}: {e}")
|
|
||||||
else:
|
|
||||||
print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")
|
|
||||||
|
|
||||||
elif sys.platform == "linux":
|
|
||||||
site_packages = Path(torch.__file__).resolve().parents[1]
|
|
||||||
nvrtc_dir = site_packages / "nvidia" / "cuda_nvrtc" / "lib"
|
|
||||||
|
|
||||||
if not nvrtc_dir.exists():
|
|
||||||
print(f"[ERROR] nvrtc dir not found: {nvrtc_dir}")
|
|
||||||
return
|
|
||||||
|
|
||||||
matching_files = sorted(nvrtc_dir.glob("libnvrtc*.so*"))
|
|
||||||
if not matching_files:
|
|
||||||
print(f"[ERROR] No libnvrtc*.so* found in {nvrtc_dir}")
|
|
||||||
return
|
|
||||||
|
|
||||||
for so_path in matching_files:
|
|
||||||
try:
|
|
||||||
ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) # type: ignore
|
|
||||||
print(f"[INFO] Loaded: {so_path}")
|
|
||||||
except OSError as e:
|
|
||||||
print(f"[WARNING] Failed to load {so_path}: {e}")
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user