From 72be145051f03ef6d1f6f29e392cfae1c93602c0 Mon Sep 17 00:00:00 2001
From: KakaruHayate
Date: Fri, 10 Oct 2025 21:14:27 +0800
Subject: [PATCH] Support on MUSA device.

fix
Update musa_utils.py
Update config.py
fix
rollback S1 train
DDP only support S4000
fix
---
 GPT_SoVITS/AR/data/bucket_sampler.py          |  2 +-
 GPT_SoVITS/TTS_infer_pack/TTS.py              | 14 ++-
 GPT_SoVITS/eres2net/kaldi.py                  | 12 +++
 GPT_SoVITS/inference_webui.py                 | 13 ++-
 GPT_SoVITS/inference_webui_fast.py            |  8 +-
 GPT_SoVITS/module/models.py                   |  6 +-
 GPT_SoVITS/prepare_datasets/1-get-text.py     |  7 +-
 .../prepare_datasets/2-get-hubert-wav32k.py   |  7 +-
 GPT_SoVITS/prepare_datasets/2-get-sv.py       |  7 +-
 GPT_SoVITS/prepare_datasets/3-get-semantic.py |  7 +-
 GPT_SoVITS/s1_train.py                        |  2 +-
 GPT_SoVITS/s2_train.py                        | 85 ++++++++++++++--
 GPT_SoVITS/s2_train_v3.py                     | 60 ++++++++++-
 GPT_SoVITS/s2_train_v3_lora.py                | 17 +++-
 api.py                                        |  3 +
 config.py                                     | 21 +++-
 musa_utils.py                                 | 99 +++++++++++++++++++
 tools/asr/fasterwhisper_asr.py                |  8 ++
 tools/uvr5/bsroformer.py                      | 14 ++-
 tools/uvr5/webui.py                           |  9 ++
 webui.py                                      | 22 +++++
 21 files changed, 393 insertions(+), 30 deletions(-)
 create mode 100644 musa_utils.py

diff --git a/GPT_SoVITS/AR/data/bucket_sampler.py b/GPT_SoVITS/AR/data/bucket_sampler.py
index d8457334..a1ab6338 100644
--- a/GPT_SoVITS/AR/data/bucket_sampler.py
+++ b/GPT_SoVITS/AR/data/bucket_sampler.py
@@ -146,4 +146,4 @@ class DistributedBucketSampler(Sampler[T_co]):
         Args:
             epoch (int): Epoch number.
         """
-        self.epoch = epoch
+        self.epoch = epoch
\ No newline at end of file
diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 0c1d2484..e8aaa3f0 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -36,6 +36,8 @@
 from TTS_infer_pack.text_segmentation_method import splits
 from TTS_infer_pack.TextPreprocessor import TextPreprocessor
 from sv import SV
+import musa_utils
+
 
 resample_transform_dict = {}
 
@@ -209,6 +211,10 @@ def set_seed(seed: int):
             # 开启后会影响精度
             torch.backends.cuda.matmul.allow_tf32 = False
             torch.backends.cudnn.allow_tf32 = False
+        elif musa_utils.is_available():
+            musa_utils.manual_seed(seed)
+            musa_utils.manual_seed_all(seed)
+            torch.backends.mudnn.allow_tf32 = False
     except:
         pass
     return seed
@@ -310,8 +316,10 @@ class TTS_Config:
         self.default_configs = deepcopy(configs_)
         self.device = self.configs.get("device", torch.device("cpu"))
 
-        if "cuda" in str(self.device) and not torch.cuda.is_available():
-            print("Warning: CUDA is not available, set device to CPU.")
+        cuda_mismatch = "cuda" in str(self.device) and not torch.cuda.is_available()
+        musa_mismatch = "musa" in str(self.device) and not musa_utils.is_available()
+        if cuda_mismatch or musa_mismatch:
+            print(f"Warning: Requested device '{self.device}' is not available, falling back to CPU.")
             self.device = torch.device("cpu")
 
         self.is_half = self.configs.get("is_half", False)
@@ -1369,6 +1377,8 @@ class TTS:
             gc.collect()  # 触发gc的垃圾回收。避免内存一直增长。
             if "cuda" in str(self.configs.device):
                 torch.cuda.empty_cache()
+            elif "musa" in str(self.configs.device):
+                torch.musa.empty_cache()
             elif str(self.configs.device) == "mps":
                 torch.mps.empty_cache()
         except:
diff --git a/GPT_SoVITS/eres2net/kaldi.py b/GPT_SoVITS/eres2net/kaldi.py
index a80e5e6b..4891f149 100644
--- a/GPT_SoVITS/eres2net/kaldi.py
+++ b/GPT_SoVITS/eres2net/kaldi.py
@@ -5,6 +5,8 @@ import torch
 import torchaudio
 from torch import Tensor
 
+import musa_utils
+
 __all__ = [
     "get_mel_banks",
     "inverse_mel_scale",
@@ -305,7 +307,12 @@ def spectrogram(
     )  # size (m, padded_window_size // 2 + 1, 2)
 
+    if musa_utils.is_available():  # torch.fft.rfft still has no MUSA kernel; run it on CPU
+        ori_device = strided_input.device
+        strided_input = strided_input.cpu()
     fft = torch.fft.rfft(strided_input)
+    if musa_utils.is_available():
+        fft = fft.to(ori_device)
 
     # Convert the FFT into a power spectrum
     power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log()  # size (m, padded_window_size // 2 + 1)
@@ -618,7 +625,12 @@ def fbank(
     )  # size (m, padded_window_size // 2 + 1)
 
+    if musa_utils.is_available():  # same workaround: rfft is not yet implemented on MUSA
+        ori_device = strided_input.device
+        strided_input = strided_input.cpu()
     spectrum = torch.fft.rfft(strided_input).abs()
+    if musa_utils.is_available():
+        spectrum = spectrum.to(ori_device)
     if use_power:
         spectrum = spectrum.pow(2.0)
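Note on the two kaldi.py hunks: they repeat the same fall-back-to-CPU dance for torch.fft.rfft, which currently has no MUSA kernel. A one-line helper in this spirit (hypothetical, not part of the patch) would keep both call sites tidy:

import torch

import musa_utils

def rfft_with_cpu_fallback(x: torch.Tensor) -> torch.Tensor:
    # rfft lacks a MUSA kernel for now: round-trip through host memory instead.
    if musa_utils.is_available() and x.device.type == "musa":
        return torch.fft.rfft(x.cpu()).to(x.device)
    return torch.fft.rfft(x)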
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index a361ed58..25915dfc 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -49,6 +49,8 @@ from config import change_choices, get_weights_names, name2gpt_path, name2sovits
 SoVITS_names, GPT_names = get_weights_names()
 from config import pretrained_sovits_name
 
+import musa_utils
+
 path_sovits_v3 = pretrained_sovits_name["v3"]
 path_sovits_v4 = pretrained_sovits_name["v4"]
 is_exist_s2gv3 = os.path.exists(path_sovits_v3)
@@ -87,7 +89,9 @@
 is_share = os.environ.get("is_share", "False")
 is_share = eval(is_share)
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
 # is_half=False
 punctuation = set(["!", "?", "…", ",", ".", "-", " "])
 import gradio as gr
@@ -112,6 +116,8 @@ def set_seed(seed):
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    if musa_utils.is_available():
+        musa_utils.manual_seed(seed)
 
 
 # set_seed(42)
@@ -134,6 +140,8 @@ i18n = I18nAuto(language=language)
 
 if torch.cuda.is_available():
     device = "cuda"
+elif musa_utils.is_available():
+    device = "musa"
 else:
     device = "cpu"
 
@@ -411,6 +419,7 @@ def clean_hifigan_model():
     hifigan_model = None
     try:
         torch.cuda.empty_cache()
+        torch.musa.empty_cache()
     except:
         pass
@@ -422,6 +431,7 @@ def clean_bigvgan_model():
     bigvgan_model = None
     try:
         torch.cuda.empty_cache()
+        torch.musa.empty_cache()
     except:
         pass
@@ -433,6 +443,7 @@ def clean_sv_cn_model():
     sv_cn_model = None
     try:
         torch.cuda.empty_cache()
+        torch.musa.empty_cache()
     except:
         pass
diff --git a/GPT_SoVITS/inference_webui_fast.py b/GPT_SoVITS/inference_webui_fast.py
index 51a120f1..01dd9922 100644
--- a/GPT_SoVITS/inference_webui_fast.py
+++ b/GPT_SoVITS/inference_webui_fast.py
@@ -29,6 +29,8 @@ import sys
 
 import torch
 
+import musa_utils
+
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 sys.path.append("%s/GPT_SoVITS" % (now_dir))
@@ -48,8 +50,10 @@
 is_share = os.environ.get("is_share", "False")
 is_share = eval(is_share)
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
 
-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
 gpt_path = os.environ.get("gpt_path", None)
 sovits_path = os.environ.get("sovits_path", None)
 cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
@@ -72,6 +76,8 @@ i18n = I18nAuto(language=language)
 
 if torch.cuda.is_available():
     device = "cuda"
+elif musa_utils.is_available():
+    device = "musa"
 # elif torch.backends.mps.is_available():
 #     device = "mps"
 else:
     device = "cpu"
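Both WebUIs (and every prepare_datasets script below) repeat the same cuda → musa → cpu ladder. A shared picker would be the natural consolidation; sketch only, since the patch keeps the ladders inline to stay close to the existing code:

import torch

import musa_utils

def pick_device(index: int = 0) -> str:
    # Preference order used throughout this patch: CUDA first, then MUSA, then CPU.
    if torch.cuda.is_available():
        return f"cuda:{index}"
    if musa_utils.is_available():
        return f"musa:{index}"
    return "cpu"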
diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py
index 1c8e662f..f5263273 100644
--- a/GPT_SoVITS/module/models.py
+++ b/GPT_SoVITS/module/models.py
@@ -7,6 +7,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
+import musa_utils
 from module import commons
 from module import modules
 from module import attentions
@@ -20,7 +21,10 @@ from module.quantize import ResidualVectorQuantizer
 # from text import symbols
 from text import symbols as symbols_v1
 from text import symbols2 as symbols_v2
-from torch.cuda.amp import autocast
+if musa_utils.is_available():
+    autocast = torch.musa.amp.autocast
+else:
+    from torch.cuda.amp import autocast  # importable on CUDA and CPU builds alike
 import contextlib
 import random
diff --git a/GPT_SoVITS/prepare_datasets/1-get-text.py b/GPT_SoVITS/prepare_datasets/1-get-text.py
index 8d83e79a..bb0a1064 100644
--- a/GPT_SoVITS/prepare_datasets/1-get-text.py
+++ b/GPT_SoVITS/prepare_datasets/1-get-text.py
@@ -9,11 +9,14 @@ i_part = os.environ.get("i_part")
 all_parts = os.environ.get("all_parts")
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
 opt_dir = os.environ.get("opt_dir")
 bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
 import torch
+import musa_utils
 
-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
 version = os.environ.get("version", None)
 import traceback
 import os.path
@@ -50,6 +53,8 @@ if os.path.exists(txt_path) == False:
     os.makedirs(bert_dir, exist_ok=True)
     if torch.cuda.is_available():
         device = "cuda:0"
+    elif musa_utils.is_available():
+        device = "musa:0"
     # elif torch.backends.mps.is_available():
     #     device = "mps"
     else:
         device = "cpu"
diff --git a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
index 3a84c014..1c6d78b2 100644
--- a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
+++ b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py
@@ -10,13 +10,16 @@ i_part = os.environ.get("i_part")
 all_parts = os.environ.get("all_parts")
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
 from feature_extractor import cnhubert
 
 opt_dir = os.environ.get("opt_dir")
 cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
 import torch
+import musa_utils
 
-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
 import traceback
 import numpy as np
@@ -61,6 +64,8 @@
 maxx = 0.95
 alpha = 0.5
 if torch.cuda.is_available():
     device = "cuda:0"
+elif musa_utils.is_available():
+    device = "musa:0"
 # elif torch.backends.mps.is_available():
 #     device = "mps"
 else:
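A note on the amp import in models.py above (and the trainers below): whichever branch runs, the name autocast must end up defined, because call sites use it unconditionally. The full selection, assuming only that torch_musa mirrors the torch.cuda.amp surface under torch.musa.amp (which this patch already relies on elsewhere):

import torch

import musa_utils

if musa_utils.is_available():
    autocast = torch.musa.amp.autocast
    GradScaler = torch.musa.amp.GradScaler
else:
    # torch.cuda.amp is importable even on CPU-only builds, so this never breaks.
    from torch.cuda.amp import GradScaler, autocast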
diff --git a/GPT_SoVITS/prepare_datasets/2-get-sv.py b/GPT_SoVITS/prepare_datasets/2-get-sv.py
index 80b0ad69..33b9b916 100644
--- a/GPT_SoVITS/prepare_datasets/2-get-sv.py
+++ b/GPT_SoVITS/prepare_datasets/2-get-sv.py
@@ -10,12 +10,15 @@ i_part = os.environ.get("i_part")
 all_parts = os.environ.get("all_parts")
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
 opt_dir = os.environ.get("opt_dir")
 sv_path = os.environ.get("sv_path")
 import torch
+import musa_utils
 
-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
 import traceback
 import torchaudio
@@ -49,6 +52,8 @@
 maxx = 0.95
 alpha = 0.5
 if torch.cuda.is_available():
     device = "cuda:0"
+elif musa_utils.is_available():
+    device = "musa:0"
 # elif torch.backends.mps.is_available():
 #     device = "mps"
 else:
diff --git a/GPT_SoVITS/prepare_datasets/3-get-semantic.py b/GPT_SoVITS/prepare_datasets/3-get-semantic.py
index ddb0607c..40366de9 100644
--- a/GPT_SoVITS/prepare_datasets/3-get-semantic.py
+++ b/GPT_SoVITS/prepare_datasets/3-get-semantic.py
@@ -6,6 +6,8 @@ i_part = os.environ.get("i_part")
 all_parts = os.environ.get("all_parts")
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
 opt_dir = os.environ.get("opt_dir")
 pretrained_s2G = os.environ.get("pretrained_s2G")
 s2config_path = os.environ.get("s2config_path")
@@ -27,8 +29,9 @@ elif size < 700 * 1024 * 1024:
 else:
     version = "v3"
 import torch
+import musa_utils
 
-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
 import traceback
 import sys
@@ -61,6 +64,8 @@ if os.path.exists(semantic_path) == False:
     if torch.cuda.is_available():
         device = "cuda"
+    elif musa_utils.is_available():
+        device = "musa:0"
     # elif torch.backends.mps.is_available():
     #     device = "mps"
     else:
diff --git a/GPT_SoVITS/s1_train.py b/GPT_SoVITS/s1_train.py
index 1176f0bc..3eeb57c1 100644
--- a/GPT_SoVITS/s1_train.py
+++ b/GPT_SoVITS/s1_train.py
@@ -168,4 +168,4 @@ if __name__ == "__main__":
     args = parser.parse_args()
     logging.info(str(args))
 
-    main(args)
+    main(args)
\ No newline at end of file
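The stage-2 trainers that follow gate multi-process training on musa_utils.should_ddp(); per the commit message, DDP is only supported on S4000-class cards. The intended consumption pattern, roughly:

import musa_utils

# S4000-class boards get one spawned worker per device; older cards train single-device.
if musa_utils.is_available() and musa_utils.should_ddp():
    n_gpus = musa_utils.device_count()
else:
    n_gpus = 1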
diff --git a/GPT_SoVITS/s2_train.py b/GPT_SoVITS/s2_train.py
index 4b9f6488..7d504edd 100644
--- a/GPT_SoVITS/s2_train.py
+++ b/GPT_SoVITS/s2_train.py
@@ -4,15 +4,23 @@ warnings.filterwarnings("ignore")
 import os
 
 import utils
+import musa_utils
 
 hps = utils.get_hparams(stage=2)
 os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+os.environ["MUSA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
 import logging
 
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from torch.cuda.amp import GradScaler, autocast
+from torch.cuda.amp import GradScaler
+
+musa_ddp = False
+if musa_utils.is_available():
+    autocast = torch.musa.amp.autocast
+    musa_ddp = musa_utils.should_ddp()
+else:
+    from torch.cuda.amp import autocast  # importable on CUDA and CPU builds alike
 from torch.nn import functional as F
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
@@ -43,16 +51,22 @@
 torch.backends.cudnn.deterministic = False
 ###反正A100fp32更快,那试试tf32吧
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
+if musa_utils.is_available():
+    torch.backends.mudnn.allow_tf32 = True
 torch.set_float32_matmul_precision("medium")  # 最低精度但最快(也就快一丁点),对于结果造成不了影响
 # from config import pretrained_s2G,pretrained_s2D
 global_step = 0
 
 device = "cpu"  # cuda以外的设备,等mps优化后加入
+if musa_utils.is_available() and not musa_ddp:
+    device = "musa"  # single-device MUSA path; DDP is reserved for should_ddp() cards
 
 
 def main():
     if torch.cuda.is_available():
         n_gpus = torch.cuda.device_count()
+    elif musa_ddp:
+        n_gpus = musa_utils.device_count()
     else:
         n_gpus = 1
     os.environ["MASTER_ADDR"] = "localhost"
@@ -78,7 +92,7 @@ def run(rank, n_gpus, hps):
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
 
     dist.init_process_group(
-        backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+        backend="mccl" if musa_ddp else ("gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"),
         init_method="env://?use_libuv=False",
         world_size=n_gpus,
         rank=rank,
@@ -86,6 +100,8 @@
     torch.manual_seed(hps.train.seed)
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
+    elif musa_ddp:
+        musa_utils.set_device(rank)
 
     train_dataset = TextAudioSpeakerLoader(hps.data, version=hps.model.version)
     train_sampler = DistributedBucketSampler(
@@ -140,6 +156,13 @@ def run(rank, n_gpus, hps):
             **hps.model,
         ).cuda(rank)
         if torch.cuda.is_available()
+        else SynthesizerTrn(
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model,
+        ).musa(rank)
+        if musa_ddp
         else SynthesizerTrn(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
@@ -151,6 +174,8 @@
     net_d = (
         MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).cuda(rank)
         if torch.cuda.is_available()
+        else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).musa(rank)
+        if musa_ddp
         else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).to(device)
     )
     for name, param in net_g.named_parameters():
@@ -196,7 +221,7 @@
         betas=hps.train.betas,
         eps=hps.train.eps,
     )
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or musa_ddp:
         net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
         net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
     else:
@@ -238,7 +263,7 @@
                     torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
                     strict=False,
                 )
-                if torch.cuda.is_available()
+                if torch.cuda.is_available() or musa_ddp
                 else net_g.load_state_dict(
                     torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
                     strict=False,
@@ -256,7 +281,7 @@
                 net_d.module.load_state_dict(
                     torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"], strict=False
                 )
-                if torch.cuda.is_available()
+                if torch.cuda.is_available() or musa_ddp
                 else net_d.load_state_dict(
                     torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"],
                 ),
@@ -279,7 +304,10 @@ def run(rank, n_gpus, hps):
         scheduler_g.step()
         scheduler_d.step()
 
-    scaler = GradScaler(enabled=hps.train.fp16_run)
+    if musa_utils.is_available():
+        scaler = torch.musa.amp.GradScaler(enabled=hps.train.fp16_run)
+    else:
+        scaler = GradScaler(enabled=hps.train.fp16_run)
 
     print("start training from epoch %s" % epoch_str)
     for epoch in range(epoch_str, hps.train.epochs + 1):
@@ -369,6 +397,42 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
             )
             if hps.model.version in {"v2Pro", "v2ProPlus"}:
                 sv_emb = sv_emb.cuda(rank, non_blocking=True)
+        elif musa_ddp:
+            spec, spec_lengths = (
+                spec.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+                spec_lengths.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+            )
+            y, y_lengths = (
+                y.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+                y_lengths.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+            )
+            ssl = ssl.musa(rank, non_blocking=True)
+            ssl.requires_grad = False
+            # ssl_lengths = ssl_lengths.musa(rank, non_blocking=True)
+            text, text_lengths = (
+                text.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+                text_lengths.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+            )
+            if hps.model.version in {"v2Pro", "v2ProPlus"}:
+                sv_emb = sv_emb.musa(rank, non_blocking=True)
         else:
             spec, spec_lengths = spec.to(device), spec_lengths.to(device)
             y, y_lengths = y.to(device), y_lengths.to(device)
@@ -595,12 +659,17 @@ def evaluate(hps, generator, eval_loader, writer_eval):
             text,
             text_lengths,
         ) in enumerate(eval_loader):
-            print(111)
+            print("eval loop is running")
             if torch.cuda.is_available():
                 spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
                 y, y_lengths = y.cuda(), y_lengths.cuda()
                 ssl = ssl.cuda()
                 text, text_lengths = text.cuda(), text_lengths.cuda()
+            elif musa_utils.is_available():
+                spec, spec_lengths = spec.musa(), spec_lengths.musa()
+                y, y_lengths = y.musa(), y_lengths.musa()
+                ssl = ssl.musa()
+                text, text_lengths = text.musa(), text_lengths.musa()
             else:
                 spec, spec_lengths = spec.to(device), spec_lengths.to(device)
                 y, y_lengths = y.to(device), y_lengths.to(device)
@@ -616,7 +685,7 @@
                     text_lengths,
                     test=test,
                 )
-                if torch.cuda.is_available()
+                if torch.cuda.is_available() or musa_utils.is_available()
                 else generator.infer(
                     ssl,
                     spec,
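The process-group backend in s2_train.py is now a three-way choice packed into one conditional expression. Spelled out, the equivalent logic (hypothetical refactor, same behavior):

import os

import torch

import musa_utils

def ddp_backend() -> str:
    if musa_utils.is_available() and musa_utils.should_ddp():
        return "mccl"  # Moore Threads' collective library, the NCCL counterpart for MUSA
    if os.name == "nt" or not torch.cuda.is_available():
        return "gloo"  # Windows and CPU runs fall back to gloo
    return "nccl"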
diff --git a/GPT_SoVITS/s2_train_v3.py b/GPT_SoVITS/s2_train_v3.py
index aa8dae7f..ea3accdf 100644
--- a/GPT_SoVITS/s2_train_v3.py
+++ b/GPT_SoVITS/s2_train_v3.py
@@ -4,15 +4,22 @@ warnings.filterwarnings("ignore")
 import os
 
 import utils
+import musa_utils
 
 hps = utils.get_hparams(stage=2)
 os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+os.environ["MUSA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
 import logging
 
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from torch.cuda.amp import GradScaler, autocast
+from torch.cuda.amp import GradScaler
+
+musa_ddp = False
+if musa_utils.is_available():
+    autocast = torch.musa.amp.autocast
+    musa_ddp = musa_utils.should_ddp()
+else:
+    from torch.cuda.amp import autocast  # importable on CUDA and CPU builds alike
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
@@ -43,16 +50,22 @@
 torch.backends.cudnn.deterministic = False
 ###反正A100fp32更快,那试试tf32吧
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
+if musa_utils.is_available():
+    torch.backends.mudnn.allow_tf32 = True
 torch.set_float32_matmul_precision("medium")  # 最低精度但最快(也就快一丁点),对于结果造成不了影响
 # from config import pretrained_s2G,pretrained_s2D
 global_step = 0
 
 device = "cpu"  # cuda以外的设备,等mps优化后加入
+if musa_utils.is_available() and not musa_ddp:
+    device = "musa"  # single-device MUSA path; DDP is reserved for should_ddp() cards
 
 
 def main():
     if torch.cuda.is_available():
         n_gpus = torch.cuda.device_count()
+    elif musa_ddp:
+        n_gpus = musa_utils.device_count()
     else:
         n_gpus = 1
     os.environ["MASTER_ADDR"] = "localhost"
@@ -78,7 +91,7 @@ def run(rank, n_gpus, hps):
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
 
     dist.init_process_group(
-        backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+        backend="mccl" if musa_ddp else ("gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"),
        init_method="env://?use_libuv=False",
         world_size=n_gpus,
         rank=rank,
@@ -86,6 +99,8 @@
     torch.manual_seed(hps.train.seed)
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
+    elif musa_ddp:
+        musa_utils.set_device(rank)
 
     train_dataset = TextAudioSpeakerLoader(hps.data)  ########
     train_sampler = DistributedBucketSampler(
@@ -140,6 +155,13 @@ def run(rank, n_gpus, hps):
             **hps.model,
         ).cuda(rank)
         if torch.cuda.is_available()
+        else SynthesizerTrn(
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model,
+        ).musa(rank)
+        if musa_ddp
         else SynthesizerTrn(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
@@ -165,7 +187,7 @@
     #     betas=hps.train.betas,
     #     eps=hps.train.eps,
     # )
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or musa_ddp:
         net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
         # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
     else:
@@ -207,7 +229,7 @@
                     torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
                     strict=False,
                 )
-                if torch.cuda.is_available()
+                if torch.cuda.is_available() or musa_ddp
                 else net_g.load_state_dict(
                     torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
                     strict=False,
@@ -235,7 +257,10 @@
         scheduler_g.step()
         # scheduler_d.step()
 
-    scaler = GradScaler(enabled=hps.train.fp16_run)
+    if musa_utils.is_available():
+        scaler = torch.musa.amp.GradScaler(enabled=hps.train.fp16_run)
+    else:
+        scaler = GradScaler(enabled=hps.train.fp16_run)
 
     net_d = optim_d = scheduler_d = None
     print("start training from epoch %s" % epoch_str)
@@ -334,6 +359,31 @@ def train_and_evaluate(
                     non_blocking=True,
                 ),
             )
+        elif musa_ddp:
+            spec, spec_lengths = (
+                spec.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+                spec_lengths.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+            )
+            mel, mel_lengths = mel.musa(rank, non_blocking=True), mel_lengths.musa(rank, non_blocking=True)
+            ssl = ssl.musa(rank, non_blocking=True)
+            ssl.requires_grad = False
+            # ssl_lengths = ssl_lengths.musa(rank, non_blocking=True)
+            text, text_lengths = (
+                text.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+                text_lengths.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+            )
         else:
             spec, spec_lengths = spec.to(device), spec_lengths.to(device)
             mel, mel_lengths = mel.to(device), mel_lengths.to(device)
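The per-tensor .musa(rank, non_blocking=True) blocks above mirror the existing .cuda blocks line for line. If the style were up for grabs, a device-agnostic mover would erase the branch entirely (sketch only; Tensor.to() accepts cuda, musa, and cpu targets alike):

import torch

def move_batch(tensors, device: torch.device, non_blocking: bool = True):
    # One code path for every backend instead of parallel .cuda()/.musa() branches.
    return tuple(t.to(device, non_blocking=non_blocking) for t in tensors)

# e.g. spec, y, ssl, text = move_batch((spec, y, ssl, text), torch.device(f"musa:{rank}"))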
os.environ["MASTER_ADDR"] = "localhost" @@ -78,7 +91,7 @@ def run(rank, n_gpus, hps): writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) dist.init_process_group( - backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + backend= "mccl" if torch.musa.is_available() else ("gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"), init_method="env://?use_libuv=False", world_size=n_gpus, rank=rank, @@ -86,6 +99,8 @@ def run(rank, n_gpus, hps): torch.manual_seed(hps.train.seed) if torch.cuda.is_available(): torch.cuda.set_device(rank) + elif musa_ddp: + musa_utils.set_device(rank) train_dataset = TextAudioSpeakerLoader(hps.data) ######## train_sampler = DistributedBucketSampler( @@ -140,6 +155,13 @@ def run(rank, n_gpus, hps): **hps.model, ).cuda(rank) if torch.cuda.is_available() + else SynthesizerTrn( + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model, + ).musa(rank) + if musa_ddp else SynthesizerTrn( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, @@ -165,7 +187,7 @@ def run(rank, n_gpus, hps): # betas=hps.train.betas, # eps=hps.train.eps, # ) - if torch.cuda.is_available(): + if torch.cuda.is_available() or musa_ddp: net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) else: @@ -207,7 +229,7 @@ def run(rank, n_gpus, hps): torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"], strict=False, ) - if torch.cuda.is_available() + if torch.cuda.is_available() or musa_ddp else net_g.load_state_dict( torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"], strict=False, @@ -235,7 +257,10 @@ def run(rank, n_gpus, hps): scheduler_g.step() # scheduler_d.step() - scaler = GradScaler(enabled=hps.train.fp16_run) + if musa_utils.is_available(): + scaler = torch.musa.amp.GradScaler(enabled=hps.train.fp16_run) + else: + scaler = GradScaler(enabled=hps.train.fp16_run) net_d = optim_d = scheduler_d = None print("start training from epoch %s" % epoch_str) @@ -334,6 +359,31 @@ def train_and_evaluate( non_blocking=True, ), ) + elif musa_ddp: + spec, spec_lengths = ( + spec.musa( + rank, + non_blocking=True, + ), + spec_lengths.musa( + rank, + non_blocking=True, + ), + ) + mel, mel_lengths = mel.musa(rank, non_blocking=True), mel_lengths.musa(rank, non_blocking=True) + ssl = ssl.musa(rank, non_blocking=True) + ssl.requires_grad = False + # ssl_lengths = ssl_lengths.musa(rank, non_blocking=True) + text, text_lengths = ( + text.musa( + rank, + non_blocking=True, + ), + text_lengths.musa( + rank, + non_blocking=True, + ), + ) else: spec, spec_lengths = spec.to(device), spec_lengths.to(device) mel, mel_lengths = mel.to(device), mel_lengths.to(device) diff --git a/GPT_SoVITS/s2_train_v3_lora.py b/GPT_SoVITS/s2_train_v3_lora.py index ba9e4ed4..db502a7b 100644 --- a/GPT_SoVITS/s2_train_v3_lora.py +++ b/GPT_SoVITS/s2_train_v3_lora.py @@ -4,15 +4,21 @@ warnings.filterwarnings("ignore") import os import utils +import musa_utils hps = utils.get_hparams(stage=2) os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") +os.environ["MUSA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",") import logging import torch import torch.distributed as dist import torch.multiprocessing as mp -from torch.cuda.amp import GradScaler, autocast +from torch.cuda.amp import GradScaler +if 
diff --git a/api.py b/api.py
index cc0896a2..ed4cb51f 100644
--- a/api.py
+++ b/api.py
@@ -208,6 +208,7 @@ def clean_hifigan_model():
     hifigan_model = None
     try:
         torch.cuda.empty_cache()
+        torch.musa.empty_cache()
     except:
         pass
@@ -219,6 +220,7 @@ def clean_bigvgan_model():
     bigvgan_model = None
     try:
         torch.cuda.empty_cache()
+        torch.musa.empty_cache()
     except:
         pass
@@ -230,6 +232,7 @@ def clean_sv_cn_model():
     sv_cn_model = None
     try:
         torch.cuda.empty_cache()
+        torch.musa.empty_cache()
     except:
         pass
diff --git a/config.py b/config.py
index fdc11c0a..00a7bb43 100644
--- a/config.py
+++ b/config.py
@@ -4,6 +4,8 @@ import sys
 
 import torch
 
+import musa_utils
+
 from tools.i18n.i18n import I18nAuto
 
 i18n = I18nAuto(language=os.environ.get("language", "Auto"))
@@ -147,11 +149,15 @@ api_port = 9880
 
 # Thanks to the contribution of @Karasukaigan and @XXXXRT666
 def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
+    device_idx = idx
     cpu = torch.device("cpu")
     cuda = torch.device(f"cuda:{idx}")
+    if musa_utils.is_available():
+        musa = torch.device(f"musa:{idx}")
+        mdtype, mversion, mmem_gb = musa_utils.get_device_dtype(device_idx)
+        return musa, mdtype, mversion, mmem_gb
     if not torch.cuda.is_available():
         return cpu, torch.float32, 0.0, 0.0
-    device_idx = idx
     capability = torch.cuda.get_device_capability(device_idx)
     name = torch.cuda.get_device_name(device_idx)
     mem_bytes = torch.cuda.get_device_properties(device_idx).total_memory
@@ -167,11 +173,16 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, flo
         return cuda, torch.float16, sm_version, mem_gb
     return cpu, torch.float32, 0.0, 0.0
 
+
+def get_gpu_count():
+    if musa_utils.is_available():
+        return musa_utils.device_count()
+    return torch.cuda.device_count()
+
+
 IS_GPU = True
 GPU_INFOS: list[str] = []
 GPU_INDEX: set[int] = set()
-GPU_COUNT = torch.cuda.device_count()
+GPU_COUNT = get_gpu_count()
 CPU_INFO: str = "0\tCPU " + i18n("CPU训练,较慢")
 tmp: list[tuple[torch.device, torch.dtype, float, float]] = []
 memset: set[float] = set()
@@ -183,7 +194,11 @@
 for j in tmp:
     device = j[0]
     memset.add(j[3])
     if device.type != "cpu":
-        GPU_INFOS.append(f"{device.index}\t{torch.cuda.get_device_name(device.index)}")
+        if device.type == "cuda" and torch.cuda.is_available():
+            device_name = torch.cuda.get_device_name(device.index)
+        elif device.type == "musa" and musa_utils.is_available():
+            device_name = torch.musa.get_device_properties(device.index).name
+        GPU_INFOS.append(f"{device.index}\t{device_name}")
         GPU_INDEX.add(device.index)
 
 if not GPU_INFOS:
diff --git a/musa_utils.py b/musa_utils.py
new file mode 100644
index 00000000..d6de79dc
--- /dev/null
+++ b/musa_utils.py
@@ -0,0 +1,99 @@
+import torch
+from typing import Optional
+import re
+
+_musa_available = False
+_musa_err_msg = "torch_musa not found or not configured correctly."
+
+try:
+    if hasattr(torch, 'musa') and torch.musa.is_available():
+
+        try:
+            major, minor = torch.musa.get_device_capability()
+            version_val = major + minor / 10.0
+
+            if version_val < 2.1:
+                raise RuntimeError(
+                    f"MUSA version check failed! "
+                    f"Found capability {major}.{minor} (version value {version_val:.2f}), "
+                    f"but this project requires a version >= 2.1. "
+                    f"Please upgrade your torch_musa and MUSA SDK. "
+                    f"See: https://github.com/MooreThreads/torch_musa"
+                )
+
+            _musa_available = True
+
+        except Exception as e:
+            _musa_err_msg = f"MUSA availability check failed: {e}"
+            if isinstance(e, RuntimeError):
+                raise e
+            _musa_available = False
+
+except Exception:
+    _musa_available = False
+
+
+def is_available() -> bool:
+    return _musa_available
+
+
+def get_device() -> Optional[torch.device]:
+    return torch.device("musa") if _musa_available else None
+
+
+def device_count() -> int:
+    if _musa_available:
+        try:
+            return torch.musa.device_count()
+        except Exception:
+            return 0
+    return 0
+
+
+def set_device(device_index: int):
+    if _musa_available:
+        try:
+            torch.musa.set_device(device_index)
+        except Exception:
+            pass
+
+
+def empty_cache():
+    if _musa_available:
+        try:
+            torch.musa.empty_cache()
+        except Exception:
+            pass
+
+
+def manual_seed(seed: int):
+    if _musa_available:
+        try:
+            torch.musa.manual_seed(seed)
+        except Exception:
+            pass
+
+
+def manual_seed_all(seed: int):
+    if _musa_available:
+        try:
+            torch.musa.manual_seed_all(seed)
+        except Exception:
+            pass
+
+
+def get_device_dtype(device_idx: int = 0) -> tuple[torch.dtype, float, float]:
+    properties = torch.musa.get_device_properties(device_idx)
+    major, minor = torch.musa.get_device_capability()
+    version_val = major + minor / 10.0
+    mem_bytes = properties.total_memory
+    mem_gb = mem_bytes / (1024**3) + 0.4
+    device_name = properties.name
+    dtype = torch.float32
+    numbers_in_name = [int(n) for n in re.findall(r'\d+', device_name)]
+    if any(num >= 4000 for num in numbers_in_name):
+        dtype = torch.float16
+    return dtype, version_val, mem_gb
+
+
+def should_ddp(device_idx: int = 0) -> bool:
+    # fp16 and DDP are both keyed on S4000-class hardware (a "4000"+ in the device name).
+    device_name = torch.musa.get_device_properties(device_idx).name
+    numbers_in_name = [int(n) for n in re.findall(r'\d+', device_name)]
+    return any(num >= 4000 for num in numbers_in_name)
+
+
+DEVICE: Optional[torch.device] = get_device()
\ No newline at end of file
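For reviewers without a Moore Threads card, the intended call pattern for the new module is roughly the following (sketch; it assumes torch_musa registers the "musa" device string with PyTorch):

import torch

import musa_utils

if musa_utils.is_available():
    device = musa_utils.get_device()  # torch.device("musa")
    dtype, capability, mem_gb = musa_utils.get_device_dtype(0)
    print(f"{musa_utils.device_count()} MUSA device(s), {mem_gb:.1f} GB, dtype {dtype}")
    x = torch.randn(8, 8, device=device, dtype=dtype)
    musa_utils.empty_cache()
else:
    x = torch.randn(8, 8)  # CUDA/CPU hosts fall through untouched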
"cuda" if torch.cuda.is_available() else "cpu" + if use_torch_musa: + device = "musa" if torch.musa.is_available() else "cpu" model = WhisperModel(model_path, device=device, compute_type=precision) input_file_names = os.listdir(input_folder) diff --git a/tools/uvr5/bsroformer.py b/tools/uvr5/bsroformer.py index ddcbfa74..f081f283 100644 --- a/tools/uvr5/bsroformer.py +++ b/tools/uvr5/bsroformer.py @@ -10,6 +10,12 @@ import torch.nn as nn import yaml from tqdm import tqdm +try: + import torch_musa + use_torch_musa = True +except ImportError: + use_torch_musa = False + warnings.filterwarnings("ignore") @@ -135,7 +141,13 @@ class Roformer_Loader: window_middle[-fade_size:] *= fadeout window_middle[:fade_size] *= fadein - with torch.amp.autocast("cuda"): + if use_torch_musa: + if torch.musa.is_available(): + set_device = "musa" + else: + set_device = "cuda" + + with torch.amp.autocast(set_device): with torch.inference_mode(): if self.config["training"]["target_instrument"] is None: req_shape = (len(self.config["training"]["instruments"]),) + tuple(mix.shape) diff --git a/tools/uvr5/webui.py b/tools/uvr5/webui.py index f5f8d3f6..1066fd89 100644 --- a/tools/uvr5/webui.py +++ b/tools/uvr5/webui.py @@ -18,6 +18,12 @@ from bsroformer import Roformer_Loader from mdxnet import MDXNetDereverb from vr import AudioPre, AudioPreDeEcho +try: + import torch_musa + use_torch_musa = True +except ImportError: + use_torch_musa = False + weight_uvr5_root = "tools/uvr5/uvr5_weights" uvr5_names = [] for name in os.listdir(weight_uvr5_root): @@ -122,6 +128,9 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format print("clean_empty_cache") if torch.cuda.is_available(): torch.cuda.empty_cache() + if use_torch_musa: + if torch.musa.is_available(): + torch.musa.empty_cache() yield "\n".join(infos) diff --git a/webui.py b/webui.py index cf5d8a3a..04432646 100644 --- a/webui.py +++ b/webui.py @@ -92,6 +92,8 @@ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 当遇到mps不支持的步骤时使用cpu import gradio as gr +import musa_utils + n_cpu = cpu_count() set_gpu_numbers = GPU_INDEX @@ -347,6 +349,8 @@ def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, so os.environ["is_half"] = str(is_half) os.environ["infer_ttswebui"] = str(webui_port_infer_tts) os.environ["is_share"] = str(is_share) + if musa_utils.is_available(): + os.environ["_MUSA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_number)) yield ( process_info(process_name_tts, "opened"), {"__type__": "update", "visible": False}, @@ -629,6 +633,8 @@ def open1Bb( # data["version"]=version os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_numbers(gpu_numbers.replace("-", ","))) + if musa_utils.is_available(): + os.environ["_MUSA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_numbers.replace("-", ","))) os.environ["hz"] = "25hz" tmp_config_path = "%s/tmp_s1.yaml" % tmp with open(tmp_config_path, "w") as f: @@ -805,6 +811,8 @@ def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir): "is_half": str(is_half), } ) + if musa_utils.is_available(): + config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),}) os.environ.update(config) cmd = '"%s" -s GPT_SoVITS/prepare_datasets/1-get-text.py' % python_exec print(cmd) @@ -895,6 +903,8 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])), } ) + if musa_utils.is_available(): + config.update({"_MUSA_VISIBLE_DEVICES": 
diff --git a/webui.py b/webui.py
index cf5d8a3a..04432646 100644
--- a/webui.py
+++ b/webui.py
@@ -92,6 +92,8 @@ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
 # os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # 当遇到mps不支持的步骤时使用cpu
 import gradio as gr
 
+import musa_utils
+
 n_cpu = cpu_count()
 
 set_gpu_numbers = GPU_INDEX
@@ -347,6 +349,8 @@ def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, so
         os.environ["is_half"] = str(is_half)
         os.environ["infer_ttswebui"] = str(webui_port_infer_tts)
         os.environ["is_share"] = str(is_share)
+        if musa_utils.is_available():
+            os.environ["_MUSA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_number))
         yield (
             process_info(process_name_tts, "opened"),
             {"__type__": "update", "visible": False},
@@ -629,6 +633,8 @@ def open1Bb(
         # data["version"]=version
 
         os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_numbers(gpu_numbers.replace("-", ",")))
+        if musa_utils.is_available():
+            os.environ["_MUSA_VISIBLE_DEVICES"] = str(fix_gpu_numbers(gpu_numbers.replace("-", ",")))
         os.environ["hz"] = "25hz"
         tmp_config_path = "%s/tmp_s1.yaml" % tmp
         with open(tmp_config_path, "w") as f:
@@ -805,6 +811,8 @@ def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir):
                     "is_half": str(is_half),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/1-get-text.py' % python_exec
             print(cmd)
@@ -895,6 +903,8 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
                     "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py' % python_exec
             print(cmd)
@@ -917,6 +927,8 @@
                     "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-sv.py' % python_exec
             print(cmd)
@@ -989,6 +1001,8 @@ def open1c(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, pretrained_s2G
                     "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/3-get-semantic.py' % python_exec
             print(cmd)
@@ -1089,6 +1103,8 @@ def open1abc(
                     "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/1-get-text.py' % python_exec
             print(cmd)
@@ -1136,6 +1152,8 @@
                     "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py' % python_exec
             print(cmd)
@@ -1158,6 +1176,8 @@
                     "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-sv.py' % python_exec
             print(cmd)
@@ -1198,6 +1218,8 @@
                     "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/3-get-semantic.py' % python_exec
             print(cmd)