Merge 90141d2029e8a752e9ee8ead77dfca69220fc0c3 into 11aa78bd9bda8b53047cfcae03abf7ca94d27391

Kakaru committed 2025-10-13 14:22:06 +00:00 (committed by GitHub) · commit bc11c2b618
22 changed files with 484 additions and 36 deletions

View File

@@ -146,4 +146,4 @@ class DistributedBucketSampler(Sampler[T_co]):
        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch

View File

@@ -36,6 +36,8 @@ from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
from sv import SV
+import musa_utils

resample_transform_dict = {}
@@ -209,6 +211,10 @@ def set_seed(seed: int):
            # enabling TF32 affects precision
            torch.backends.cuda.matmul.allow_tf32 = False
            torch.backends.cudnn.allow_tf32 = False
+        elif musa_utils.is_available():
+            musa_utils.manual_seed(seed)
+            musa_utils.manual_seed_all(seed)
+            torch.backends.mudnn.allow_tf32 = False
    except:
        pass
    return seed
@@ -310,8 +316,10 @@ class TTS_Config:
        self.default_configs = deepcopy(configs_)
        self.device = self.configs.get("device", torch.device("cpu"))

-        if "cuda" in str(self.device) and not torch.cuda.is_available():
-            print("Warning: CUDA is not available, set device to CPU.")
+        cuda_mismatch = "cuda" in str(self.device) and not torch.cuda.is_available()
+        musa_mismatch = "musa" in str(self.device) and not musa_utils.is_available()
+        if cuda_mismatch or musa_mismatch:
+            print(f"Warning: Requested device '{self.device}' is not available, falling back to CPU.")
            self.device = torch.device("cpu")

        self.is_half = self.configs.get("is_half", False)
@@ -1369,6 +1377,8 @@ class TTS:
            gc.collect()  # run garbage collection so memory does not keep growing
            if "cuda" in str(self.configs.device):
                torch.cuda.empty_cache()
+            elif "musa" in str(self.configs.device):
+                torch.musa.empty_cache()
            elif str(self.configs.device) == "mps":
                torch.mps.empty_cache()
        except:
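The cuda/musa/mps cache-release ladder above recurs throughout this PR (see the clean_hifigan_model, clean_bigvgan_model, and clean_sv_cn_model hunks below). A minimal sketch of a helper that could consolidate it, assuming the musa_utils module added later in this diff; the helper name is illustrative, not part of the PR:

import torch
import musa_utils  # availability shim added by this PR


def empty_device_cache(device) -> None:
    """Best-effort release of cached allocator memory on the active backend."""
    try:
        dev = str(device)
        if "cuda" in dev:
            torch.cuda.empty_cache()
        elif "musa" in dev:
            musa_utils.empty_cache()  # guarded wrapper around torch.musa.empty_cache()
        elif dev == "mps":
            torch.mps.empty_cache()
    except Exception:
        pass  # cache release is best-effort, mirroring the bare `except` in the diff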

View File

@@ -5,6 +5,8 @@ import torch
import torchaudio
from torch import Tensor

+import musa_utils

__all__ = [
    "get_mel_banks",
    "inverse_mel_scale",
@@ -305,7 +307,12 @@ def spectrogram(
    )

    # size (m, padded_window_size // 2 + 1, 2)
+    if musa_utils.is_available():  # some operators are still unsupported on MUSA; run the FFT on CPU
+        ori_device = strided_input.device
+        strided_input = strided_input.cpu()
    fft = torch.fft.rfft(strided_input)
+    if musa_utils.is_available():
+        fft = fft.to(ori_device)

    # Convert the FFT into a power spectrum
    power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log()  # size (m, padded_window_size // 2 + 1)
@@ -618,7 +625,12 @@ def fbank(
    )

    # size (m, padded_window_size // 2 + 1)
+    if musa_utils.is_available():  # some operators are still unsupported on MUSA; run the FFT on CPU
+        ori_device = strided_input.device
+        strided_input = strided_input.cpu()
    spectrum = torch.fft.rfft(strided_input).abs()
+    if musa_utils.is_available():
+        spectrum = spectrum.to(ori_device)
    if use_power:
        spectrum = spectrum.pow(2.0)
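The same round trip (copy to CPU, run the unsupported op, copy back) appears in both spectrogram and fbank above. A minimal sketch of a reusable wrapper for that pattern, assuming the musa_utils shim from this PR (the wrapper name is illustrative):

import torch
import musa_utils


def run_on_cpu_if_musa(op, tensor: torch.Tensor) -> torch.Tensor:
    """Run `op` on CPU when MUSA is active (for ops torch_musa lacks), then move the result back."""
    if musa_utils.is_available():
        return op(tensor.cpu()).to(tensor.device)
    return op(tensor)

# usage matching the diff above:
# fft = run_on_cpu_if_musa(torch.fft.rfft, strided_input)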

View File

@@ -49,6 +49,8 @@ from config import change_choices, get_weights_names, name2gpt_path, name2sovits
SoVITS_names, GPT_names = get_weights_names()
from config import pretrained_sovits_name

+import musa_utils

path_sovits_v3 = pretrained_sovits_name["v3"]
path_sovits_v4 = pretrained_sovits_name["v4"]
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
@@ -87,7 +89,9 @@ is_share = os.environ.get("is_share", "False")
is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
# is_half=False
punctuation = set(["!", "?", "…", ",", ".", "-", " "])
import gradio as gr
@@ -112,6 +116,8 @@ def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
+    if musa_utils.is_available():
+        musa_utils.manual_seed(seed)

# set_seed(42)
@@ -134,6 +140,8 @@ i18n = I18nAuto(language=language)

if torch.cuda.is_available():
    device = "cuda"
+elif musa_utils.is_available():
+    device = "musa"
else:
    device = "cpu"
@@ -411,6 +419,7 @@ def clean_hifigan_model():
        hifigan_model = None
        try:
            torch.cuda.empty_cache()
+            torch.musa.empty_cache()
        except:
            pass
@@ -422,6 +431,7 @@ def clean_bigvgan_model():
        bigvgan_model = None
        try:
            torch.cuda.empty_cache()
+            torch.musa.empty_cache()
        except:
            pass
@@ -433,6 +443,7 @@ def clean_sv_cn_model():
        sv_cn_model = None
        try:
            torch.cuda.empty_cache()
+            torch.musa.empty_cache()
        except:
            pass

View File

@@ -29,6 +29,8 @@ import sys
import torch

+import musa_utils

now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
@@ -48,8 +50,10 @@ is_share = os.environ.get("is_share", "False")
is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
gpt_path = os.environ.get("gpt_path", None)
sovits_path = os.environ.get("sovits_path", None)
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
@@ -72,6 +76,8 @@ i18n = I18nAuto(language=language)
if torch.cuda.is_available():
    device = "cuda"
+elif musa_utils.is_available():
+    device = "musa"
# elif torch.backends.mps.is_available():
#     device = "mps"
else:

View File

@@ -7,6 +7,7 @@ import torch
from torch import nn
from torch.nn import functional as F

+import musa_utils
from module import commons
from module import modules
from module import attentions
@@ -20,7 +21,10 @@ from module.quantize import ResidualVectorQuantizer
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2

-from torch.cuda.amp import autocast
+if musa_utils.is_available():
+    autocast = torch.musa.amp.autocast
+else:
+    from torch.cuda.amp import autocast  # importable even on CPU-only machines, so `autocast` is always bound
import contextlib
import random
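This availability-gated autocast selection recurs in the s2 training scripts below. A minimal sketch of one way to centralize the choice, assuming the musa_utils shim from this PR (get_autocast is an illustrative name, not part of the PR):

import torch
import musa_utils


def get_autocast():
    """Return the autocast context manager for the active backend."""
    if musa_utils.is_available():
        return torch.musa.amp.autocast
    from torch.cuda.amp import autocast  # also importable (and simply disabled) without CUDA
    return autocast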

View File

@@ -9,11 +9,14 @@ i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
opt_dir = os.environ.get("opt_dir")
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
import torch
+import musa_utils

-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
version = os.environ.get("version", None)
import traceback
import os.path
@@ -50,6 +53,8 @@ if os.path.exists(txt_path) == False:
    os.makedirs(bert_dir, exist_ok=True)
    if torch.cuda.is_available():
        device = "cuda:0"
+    elif musa_utils.is_available():
+        device = "musa:0"
    # elif torch.backends.mps.is_available():
    #     device = "mps"
    else:

View File

@@ -10,13 +10,16 @@ i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
from feature_extractor import cnhubert

opt_dir = os.environ.get("opt_dir")
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
import torch
+import musa_utils

-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
import traceback
import numpy as np
@@ -61,6 +64,8 @@ maxx = 0.95
alpha = 0.5
if torch.cuda.is_available():
    device = "cuda:0"
+elif musa_utils.is_available():
+    device = "musa:0"
# elif torch.backends.mps.is_available():
#     device = "mps"
else:

View File

@@ -10,12 +10,15 @@ i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
opt_dir = os.environ.get("opt_dir")
sv_path = os.environ.get("sv_path")
import torch
+import musa_utils

-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
import traceback
import torchaudio
@@ -49,6 +52,8 @@ maxx = 0.95
alpha = 0.5
if torch.cuda.is_available():
    device = "cuda:0"
+elif musa_utils.is_available():
+    device = "musa:0"
# elif torch.backends.mps.is_available():
#     device = "mps"
else:

View File

@@ -6,6 +6,8 @@ i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
opt_dir = os.environ.get("opt_dir")
pretrained_s2G = os.environ.get("pretrained_s2G")
s2config_path = os.environ.get("s2config_path")
@@ -27,8 +29,9 @@ elif size < 700 * 1024 * 1024:
else:
    version = "v3"
import torch
+import musa_utils

-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
import traceback
import sys
@@ -61,6 +64,8 @@ if os.path.exists(semantic_path) == False:
    if torch.cuda.is_available():
        device = "cuda"
+    elif musa_utils.is_available():
+        device = "musa:0"
    # elif torch.backends.mps.is_available():
    #     device = "mps"
    else:

View File

@@ -15,7 +15,7 @@ from AR.utils.io import load_yaml_config
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger  # WandbLogger
-from pytorch_lightning.strategies import DDPStrategy
+from pytorch_lightning.strategies import DDPStrategy, SingleDeviceStrategy

logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
@@ -25,6 +25,11 @@ from collections import OrderedDict
from AR.utils import get_newest_ckpt
from process_ckpt import my_save

+import musa_utils
+
+if musa_utils.is_available():
+    import musa_accelerator
+    if "_MUSA_VISIBLE_DEVICES" in os.environ:  # guard against a missing key; a bare lookup raises KeyError
+        os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]

class my_model_ckpt(ModelCheckpoint):
    def __init__(
@@ -108,18 +113,28 @@ def main(args):
    logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["USE_LIBUV"] = "0"
+    if musa_utils.is_available():
+        accelerator = musa_accelerator.MUSAAccelerator()
+        devices = -1  # MUSA runs single-GPU for now
+        strategy = SingleDeviceStrategy(device=musa_utils.get_device())  # no distributed training
+    elif torch.cuda.is_available():
+        accelerator = "gpu"
+        devices = -1
+        strategy = DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
+    else:
+        accelerator = "cpu"
+        devices = 1
+        strategy = "auto"
    trainer: Trainer = Trainer(
        max_epochs=config["train"]["epochs"],
-        accelerator="gpu" if torch.cuda.is_available() else "cpu",
+        accelerator=accelerator,
        # val_check_interval=9999999999999999999999,  ### do not validate
        # check_val_every_n_epoch=None,
        limit_val_batches=0,
-        devices=-1 if torch.cuda.is_available() else 1,
+        devices=devices,
        benchmark=False,
        fast_dev_run=False,
-        strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
-        if torch.cuda.is_available()
-        else "auto",
+        strategy=strategy,
        precision=config["train"]["precision"],
        logger=logger,
        num_sanity_val_steps=0,
@@ -168,4 +183,4 @@ if __name__ == "__main__":
    args = parser.parse_args()
    logging.info(str(args))
    main(args)

View File

@@ -4,15 +4,23 @@ warnings.filterwarnings("ignore")
import os

import utils
+import musa_utils

hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+os.environ["MUSA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
-from torch.cuda.amp import GradScaler, autocast
+from torch.cuda.amp import GradScaler
+
+musa_ddp = False
+if musa_utils.is_available():
+    autocast = torch.musa.amp.autocast
+    musa_ddp = musa_utils.should_ddp()
+else:
+    from torch.cuda.amp import autocast  # importable even without CUDA, so `autocast` is always bound
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
@@ -43,16 +51,22 @@ torch.backends.cudnn.deterministic = False
### fp32 is faster on A100 anyway, so try tf32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
+if musa_utils.is_available():
+    torch.backends.mudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally); does not affect results
# from config import pretrained_s2G,pretrained_s2D
global_step = 0

device = "cpu"  # devices other than CUDA will be added once MPS is optimized
+if musa_utils.is_available() and not musa_ddp:  # single-device MUSA runs go through `device`
+    device = "musa"


def main():
    if torch.cuda.is_available():
        n_gpus = torch.cuda.device_count()
+    elif musa_ddp:
+        n_gpus = musa_utils.device_count()
    else:
        n_gpus = 1
    os.environ["MASTER_ADDR"] = "localhost"
@@ -78,7 +92,7 @@ def run(rank, n_gpus, hps):
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

    dist.init_process_group(
-        backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+        backend="mccl" if musa_ddp else ("gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"),
        init_method="env://?use_libuv=False",
        world_size=n_gpus,
        rank=rank,
@@ -86,6 +100,8 @@ def run(rank, n_gpus, hps):
    torch.manual_seed(hps.train.seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
+    elif musa_ddp:
+        musa_utils.set_device(rank)

    train_dataset = TextAudioSpeakerLoader(hps.data, version=hps.model.version)
    train_sampler = DistributedBucketSampler(
@@ -140,6 +156,13 @@ def run(rank, n_gpus, hps):
            **hps.model,
        ).cuda(rank)
        if torch.cuda.is_available()
+        else SynthesizerTrn(
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model,
+        ).musa(rank)
+        if musa_ddp
        else SynthesizerTrn(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
@@ -151,6 +174,8 @@ def run(rank, n_gpus, hps):
    net_d = (
        MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).cuda(rank)
        if torch.cuda.is_available()
+        else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).musa(rank)
+        if musa_ddp
        else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).to(device)
    )
    for name, param in net_g.named_parameters():
@@ -196,7 +221,7 @@ def run(rank, n_gpus, hps):
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or musa_ddp:
        net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
        net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
    else:
@@ -238,7 +263,7 @@ def run(rank, n_gpus, hps):
            torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
            strict=False,
        )
-        if torch.cuda.is_available()
+        if torch.cuda.is_available() or musa_ddp
        else net_g.load_state_dict(
            torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
            strict=False,
@@ -256,7 +281,7 @@ def run(rank, n_gpus, hps):
            net_d.module.load_state_dict(
                torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"], strict=False
            )
-            if torch.cuda.is_available()
+            if torch.cuda.is_available() or musa_ddp
            else net_d.load_state_dict(
                torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"],
            ),
@@ -279,7 +304,10 @@ def run(rank, n_gpus, hps):
        scheduler_g.step()
        scheduler_d.step()

-    scaler = GradScaler(enabled=hps.train.fp16_run)
+    if musa_utils.is_available():
+        scaler = torch.musa.amp.GradScaler(enabled=hps.train.fp16_run)
+    else:
+        scaler = GradScaler(enabled=hps.train.fp16_run)

    print("start training from epoch %s" % epoch_str)
    for epoch in range(epoch_str, hps.train.epochs + 1):
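The GradScaler branch above is repeated in each s2 training script in this PR. A minimal sketch of a shared factory, assuming the musa_utils shim (make_grad_scaler is an illustrative name, not part of the PR):

import torch
from torch.cuda.amp import GradScaler
import musa_utils


def make_grad_scaler(fp16_run: bool):
    """Return the gradient scaler for the active backend."""
    if musa_utils.is_available():
        return torch.musa.amp.GradScaler(enabled=fp16_run)
    return GradScaler(enabled=fp16_run)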
@@ -369,6 +397,42 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
            )
            if hps.model.version in {"v2Pro", "v2ProPlus"}:
                sv_emb = sv_emb.cuda(rank, non_blocking=True)
+        elif musa_ddp:
+            spec, spec_lengths = (
+                spec.musa(rank, non_blocking=True),
+                spec_lengths.musa(rank, non_blocking=True),
+            )
+            y, y_lengths = (
+                y.musa(rank, non_blocking=True),
+                y_lengths.musa(rank, non_blocking=True),
+            )
+            ssl = ssl.musa(rank, non_blocking=True)
+            ssl.requires_grad = False
+            # ssl_lengths = ssl_lengths.musa(rank, non_blocking=True)
+            text, text_lengths = (
+                text.musa(rank, non_blocking=True),
+                text_lengths.musa(rank, non_blocking=True),
+            )
+            if hps.model.version in {"v2Pro", "v2ProPlus"}:
+                sv_emb = sv_emb.musa(rank, non_blocking=True)
        else:
            spec, spec_lengths = spec.to(device), spec_lengths.to(device)
            y, y_lengths = y.to(device), y_lengths.to(device)
@@ -595,12 +659,17 @@ def evaluate(hps, generator, eval_loader, writer_eval):
            text,
            text_lengths,
        ) in enumerate(eval_loader):
-            print(111)
+            print("eval loop is running")  # debug output
            if torch.cuda.is_available():
                spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
                y, y_lengths = y.cuda(), y_lengths.cuda()
                ssl = ssl.cuda()
                text, text_lengths = text.cuda(), text_lengths.cuda()
+            elif musa_utils.is_available():
+                spec, spec_lengths = spec.musa(), spec_lengths.musa()
+                y, y_lengths = y.musa(), y_lengths.musa()
+                ssl = ssl.musa()
+                text, text_lengths = text.musa(), text_lengths.musa()
            else:
                spec, spec_lengths = spec.to(device), spec_lengths.to(device)
                y, y_lengths = y.to(device), y_lengths.to(device)
@@ -616,7 +685,7 @@ def evaluate(hps, generator, eval_loader, writer_eval):
                text_lengths,
                test=test,
            )
-            if torch.cuda.is_available()
+            if torch.cuda.is_available() or musa_utils.is_available()
            else generator.infer(
                ssl,
                spec,

View File

@@ -4,15 +4,22 @@ warnings.filterwarnings("ignore")
import os

import utils
+import musa_utils

hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+os.environ["MUSA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
-from torch.cuda.amp import GradScaler, autocast
+from torch.cuda.amp import GradScaler
+
+musa_ddp = False  # default for CUDA/CPU runs
+if musa_utils.is_available():
+    autocast = torch.musa.amp.autocast
+    musa_ddp = musa_utils.should_ddp()
+else:
+    from torch.cuda.amp import autocast  # importable even without CUDA
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
@@ -43,16 +50,22 @@ torch.backends.cudnn.deterministic = False
### fp32 is faster on A100 anyway, so try tf32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
+if musa_utils.is_available():
+    torch.backends.mudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally); does not affect results
# from config import pretrained_s2G,pretrained_s2D
global_step = 0

device = "cpu"  # devices other than CUDA will be added once MPS is optimized
+if musa_utils.is_available() and not musa_ddp:  # single-device MUSA runs go through `device`
+    device = "musa"


def main():
    if torch.cuda.is_available():
        n_gpus = torch.cuda.device_count()
+    elif musa_ddp:
+        n_gpus = musa_utils.device_count()
    else:
        n_gpus = 1
    os.environ["MASTER_ADDR"] = "localhost"
@@ -78,7 +91,7 @@ def run(rank, n_gpus, hps):
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

    dist.init_process_group(
-        backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+        backend="mccl" if musa_ddp else ("gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"),
        init_method="env://?use_libuv=False",
        world_size=n_gpus,
        rank=rank,
@@ -86,6 +99,8 @@ def run(rank, n_gpus, hps):
    torch.manual_seed(hps.train.seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
+    elif musa_ddp:
+        musa_utils.set_device(rank)

    train_dataset = TextAudioSpeakerLoader(hps.data)  ########
    train_sampler = DistributedBucketSampler(
@@ -140,6 +155,13 @@ def run(rank, n_gpus, hps):
            **hps.model,
        ).cuda(rank)
        if torch.cuda.is_available()
+        else SynthesizerTrn(
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model,
+        ).musa(rank)
+        if musa_ddp
        else SynthesizerTrn(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
@@ -165,7 +187,7 @@ def run(rank, n_gpus, hps):
    #     betas=hps.train.betas,
    #     eps=hps.train.eps,
    # )
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or musa_ddp:
        net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
        # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
    else:
@@ -207,7 +229,7 @@ def run(rank, n_gpus, hps):
            torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
            strict=False,
        )
-        if torch.cuda.is_available()
+        if torch.cuda.is_available() or musa_ddp
        else net_g.load_state_dict(
            torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
            strict=False,
@@ -235,7 +257,10 @@ def run(rank, n_gpus, hps):
        scheduler_g.step()
        # scheduler_d.step()

-    scaler = GradScaler(enabled=hps.train.fp16_run)
+    if musa_utils.is_available():
+        scaler = torch.musa.amp.GradScaler(enabled=hps.train.fp16_run)
+    else:
+        scaler = GradScaler(enabled=hps.train.fp16_run)

    net_d = optim_d = scheduler_d = None
    print("start training from epoch %s" % epoch_str)
@@ -334,6 +359,31 @@ def train_and_evaluate(
                non_blocking=True,
            ),
        )
+        elif musa_ddp:
+            spec, spec_lengths = (
+                spec.musa(rank, non_blocking=True),
+                spec_lengths.musa(rank, non_blocking=True),
+            )
+            mel, mel_lengths = mel.musa(rank, non_blocking=True), mel_lengths.musa(rank, non_blocking=True)
+            ssl = ssl.musa(rank, non_blocking=True)
+            ssl.requires_grad = False
+            # ssl_lengths = ssl_lengths.musa(rank, non_blocking=True)
+            text, text_lengths = (
+                text.musa(rank, non_blocking=True),
+                text_lengths.musa(rank, non_blocking=True),
+            )
        else:
            spec, spec_lengths = spec.to(device), spec_lengths.to(device)
            mel, mel_lengths = mel.to(device), mel_lengths.to(device)

View File

@@ -4,15 +4,21 @@ warnings.filterwarnings("ignore")
import os

import utils
+import musa_utils

hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+os.environ["MUSA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
-from torch.cuda.amp import GradScaler, autocast
+from torch.cuda.amp import GradScaler
+
+if musa_utils.is_available():
+    autocast = torch.musa.amp.autocast
+else:
+    from torch.cuda.amp import autocast  # importable even without CUDA
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
@@ -43,11 +49,15 @@ torch.backends.cudnn.deterministic = False
### fp32 is faster on A100 anyway, so try tf32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
+if musa_utils.is_available():
+    torch.backends.mudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally); does not affect results
# from config import pretrained_s2G,pretrained_s2D
global_step = 0

device = "cpu"  # devices other than CUDA will be added once MPS is optimized
+if musa_utils.is_available():  # DDP support not implemented here; no hardware to test on
+    device = "musa"


def main():
@@ -209,7 +219,10 @@ def run(rank, n_gpus, hps):
    for _ in range(epoch_str):
        scheduler_g.step()

-    scaler = GradScaler(enabled=hps.train.fp16_run)
+    if musa_utils.is_available():
+        scaler = torch.musa.amp.GradScaler(enabled=hps.train.fp16_run)
+    else:
+        scaler = GradScaler(enabled=hps.train.fp16_run)

    net_d = optim_d = scheduler_d = None
    print("start training from epoch %s" % epoch_str)

api.py
View File

@@ -208,6 +208,7 @@ def clean_hifigan_model():
        hifigan_model = None
        try:
            torch.cuda.empty_cache()
+            torch.musa.empty_cache()
        except:
            pass
@@ -219,6 +220,7 @@ def clean_bigvgan_model():
        bigvgan_model = None
        try:
            torch.cuda.empty_cache()
+            torch.musa.empty_cache()
        except:
            pass
@@ -230,6 +232,7 @@ def clean_sv_cn_model():
        sv_cn_model = None
        try:
            torch.cuda.empty_cache()
+            torch.musa.empty_cache()
        except:
            pass

View File

@@ -4,6 +4,8 @@ import sys

import torch

+import musa_utils

from tools.i18n.i18n import I18nAuto

i18n = I18nAuto(language=os.environ.get("language", "Auto"))
@@ -147,11 +149,15 @@ api_port = 9880

# Thanks to the contribution of @Karasukaigan and @XXXXRT666
def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
+    device_idx = idx
    cpu = torch.device("cpu")
    cuda = torch.device(f"cuda:{idx}")
+    if musa_utils.is_available():
+        musa = torch.device(f"musa:{idx}")
+        mdtype, mversion, mmem_gb = musa_utils.get_device_dtype(device_idx)
+        return musa, mdtype, mversion, mmem_gb
    if not torch.cuda.is_available():
        return cpu, torch.float32, 0.0, 0.0
-    device_idx = idx
    capability = torch.cuda.get_device_capability(device_idx)
    name = torch.cuda.get_device_name(device_idx)
    mem_bytes = torch.cuda.get_device_properties(device_idx).total_memory
@@ -167,11 +173,16 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, flo
            return cuda, torch.float16, sm_version, mem_gb
    return cpu, torch.float32, 0.0, 0.0

+def get_gpu_count():
+    if musa_utils.is_available():
+        return musa_utils.device_count()
+    return torch.cuda.device_count()

IS_GPU = True
GPU_INFOS: list[str] = []
GPU_INDEX: set[int] = set()
-GPU_COUNT = torch.cuda.device_count()
+GPU_COUNT = get_gpu_count()
CPU_INFO: str = "0\tCPU " + i18n("CPU训练,较慢")
tmp: list[tuple[torch.device, torch.dtype, float, float]] = []
memset: set[float] = set()
@@ -183,7 +194,11 @@ for j in tmp:
    device = j[0]
    memset.add(j[3])
    if device.type != "cpu":
-        GPU_INFOS.append(f"{device.index}\t{torch.cuda.get_device_name(device.index)}")
+        if device.type == "cuda" and torch.cuda.is_available():
+            device_name = torch.cuda.get_device_name(device.index)
+        elif device.type == "musa" and musa_utils.is_available():
+            device_name = torch.musa.get_device_properties(device.index).name
+        GPU_INFOS.append(f"{device.index}\t{device_name}")
        GPU_INDEX.add(device.index)

if not GPU_INFOS:

musa_accelerator.py (new file)
View File

@@ -0,0 +1,91 @@
# refer to: https://github.com/plae-tljg/GPT-SoVITS-Musa/blob/main/code_patches.zip
# Kakaru (https://github.com/KakaruHayate/) 2025/10/11
import logging
from typing import Any

import torch
import torch_musa
from typing_extensions import override

from lightning_fabric.accelerators import _AcceleratorRegistry
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_log = logging.getLogger(__name__)


class MUSAAccelerator(Accelerator):
    """Accelerator for Moore Threads MUSA devices."""

    @override
    def setup_device(self, device: torch.device) -> None:
        """
        Raises:
            MisconfigurationException: If the selected device is not MUSA.
        """
        if device.type != "musa":
            raise MisconfigurationException(f"Device should be MUSA, got {device} instead.")
        if device.index is None:
            device = torch.device("musa:0")
        torch.musa.set_device(device)

    # @override
    # def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
    #     if device.type == "musa":
    #         return torch_musa.memory_stats(device)

    @override
    def teardown(self) -> None:
        """Clean up any state created by the accelerator."""
        torch.musa.empty_cache()

    @staticmethod
    def parse_devices(devices: Any) -> Any:
        return torch.device("musa") if torch_musa.is_available() else None

    @staticmethod
    @override
    def get_parallel_devices(devices: Any) -> list[torch.device]:
        """Gets parallel devices for the Accelerator."""
        if isinstance(devices, torch.device):
            return [devices]
        if devices is None or devices == "auto":
            devices = list(range(MUSAAccelerator.auto_device_count()))
        elif isinstance(devices, int):
            devices = list(range(devices))
        elif isinstance(devices, (list, tuple)):
            pass
        else:
            raise ValueError(
                f"Unsupported devices type: {type(devices)}. Expected torch.device, int, list, tuple, or 'auto'."
            )
        return [torch.device("musa", i) for i in devices]

    @staticmethod
    @override
    def auto_device_count() -> int:
        """Get the number of MUSA devices when set to `auto`."""
        return torch_musa.device_count()

    @staticmethod
    @override
    def is_available() -> bool:
        """Checks if MUSA is available on the system."""
        return torch_musa.is_available()

    @classmethod
    @override
    def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
        """Register this accelerator with its name and description."""
        accelerator_registry.register(
            "musa",
            cls,
            description=cls.__name__,
            override=True,
        )
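s1_train.py above wires this accelerator in by hand. A minimal standalone sketch of the same wiring under the PR's single-device constraint, assuming torch_musa is installed and musa_accelerator.py is importable; `devices=1` and the device string are illustrative:

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import SingleDeviceStrategy
from musa_accelerator import MUSAAccelerator

trainer = Trainer(
    accelerator=MUSAAccelerator(),  # custom accelerator instance, as in s1_train.py
    strategy=SingleDeviceStrategy(device=torch.device("musa:0")),  # no DDP yet
    devices=1,
)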

musa_utils.py (new file)
View File

@@ -0,0 +1,78 @@
from typing import Optional

import torch

_musa_available = False
_musa_err_msg = "torch_musa not found or not configured correctly."
try:
    import torch_musa  # noqa: F401  # importing torch_musa registers the torch.musa namespace
    if torch.musa.is_available():
        _musa_available = True
except Exception:
    _musa_available = False


def is_available() -> bool:
    return _musa_available


def get_device() -> Optional[torch.device]:
    return torch.device("musa") if _musa_available else None


def device_count() -> int:
    if _musa_available:
        try:
            return torch.musa.device_count()
        except Exception:
            return 0
    return 0


def set_device(device_index: int):
    if _musa_available:
        try:
            torch.musa.set_device(device_index)
        except Exception:
            pass


def empty_cache():
    if _musa_available:
        try:
            torch.musa.empty_cache()
        except Exception:
            pass


def manual_seed(seed: int):
    if _musa_available:
        try:
            torch.musa.manual_seed(seed)
        except Exception:
            pass


def manual_seed_all(seed: int):
    if _musa_available:
        try:
            torch.musa.manual_seed_all(seed)
        except Exception:
            pass


def get_device_dtype(device_idx: int = 0) -> tuple[torch.dtype, float, float]:
    properties = torch.musa.get_device_properties(device_idx)
    major, minor = torch.musa.get_device_capability()
    version_val = major + minor / 10.0
    mem_gb = properties.total_memory / (1024**3) + 0.4
    dtype = torch.float32 if version_val <= 2.1 else torch.float16  # fp16 is unreliable on older MUSA archs
    return dtype, version_val, mem_gb


def should_ddp(device_idx: int = 0) -> bool:
    # DDP is only enabled on newer MUSA architectures (compute capability above 2.1),
    # matching the dtype cutoff in get_device_dtype.
    major, minor = torch.musa.get_device_capability()
    return (major + minor / 10.0) > 2.1


DEVICE: Optional[torch.device] = get_device()
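The guard pattern these helpers enable is the one repeated across the diff. A minimal consumer sketch of the selection ladder, taken directly from the inference-script hunks above:

import torch
import musa_utils

if torch.cuda.is_available():
    device = "cuda"
elif musa_utils.is_available():
    device = "musa"
else:
    device = "cpu"

x = torch.zeros(4).to(device)  # tensors route through whichever backend was selected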

View File

@@ -13,6 +13,12 @@ from tools.asr.config import get_models
from tools.asr.funasr_asr import only_asr
from tools.my_utils import load_cudnn

+try:
+    import torch_musa
+    use_torch_musa = True
+except ImportError:
+    use_torch_musa = False

# fmt: off
language_code_list = [
    "af", "am", "ar", "as", "az",
@@ -86,6 +92,8 @@ def execute_asr(input_folder, output_folder, model_path, language, precision):
        language = None  # leave language unset; the model outputs the most probable one
    print("loading faster whisper model:", model_path, model_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
+    if use_torch_musa and torch.musa.is_available():  # only override when a MUSA device is actually present
+        device = "musa"
    model = WhisperModel(model_path, device=device, compute_type=precision)

    input_file_names = os.listdir(input_folder)

View File

@@ -10,6 +10,12 @@ import torch.nn as nn
import yaml
from tqdm import tqdm

+try:
+    import torch_musa
+    use_torch_musa = True
+except ImportError:
+    use_torch_musa = False

warnings.filterwarnings("ignore")
@@ -135,7 +141,13 @@ class Roformer_Loader:
            window_middle[-fade_size:] *= fadeout
            window_middle[:fade_size] *= fadein

-            with torch.amp.autocast("cuda"):
+            set_device = "musa" if use_torch_musa and torch.musa.is_available() else "cuda"
+            with torch.amp.autocast(set_device):
                with torch.inference_mode():
                    if self.config["training"]["target_instrument"] is None:
                        req_shape = (len(self.config["training"]["instruments"]),) + tuple(mix.shape)

View File

@@ -18,6 +18,12 @@ from bsroformer import Roformer_Loader
from mdxnet import MDXNetDereverb
from vr import AudioPre, AudioPreDeEcho

+try:
+    import torch_musa
+    use_torch_musa = True
+except ImportError:
+    use_torch_musa = False

weight_uvr5_root = "tools/uvr5/uvr5_weights"
uvr5_names = []
for name in os.listdir(weight_uvr5_root):
@@ -122,6 +128,9 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
            print("clean_empty_cache")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
+            if use_torch_musa and torch.musa.is_available():
+                torch.musa.empty_cache()
    yield "\n".join(infos)

View File

@@ -92,6 +92,8 @@ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # fall back to CPU for steps MPS does not support
import gradio as gr

+import musa_utils

n_cpu = cpu_count()

set_gpu_numbers = GPU_INDEX
@@ -347,6 +349,8 @@ def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, so
    os.environ["is_half"] = str(is_half)
    os.environ["infer_ttswebui"] = str(webui_port_infer_tts)
    os.environ["is_share"] = str(is_share)
+    if musa_utils.is_available():
+        os.environ["_MUSA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_number))
    yield (
        process_info(process_name_tts, "opened"),
        {"__type__": "update", "visible": False},
@@ -629,6 +633,8 @@ def open1Bb(
    # data["version"]=version
    os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_numbers(gpu_numbers.replace("-", ",")))
+    if musa_utils.is_available():
+        os.environ["_MUSA_VISIBLE_DEVICES"] = str(fix_gpu_numbers(gpu_numbers.replace("-", ",")))  # plural helper, matching the CUDA line for multi-GPU strings
    os.environ["hz"] = "25hz"
    tmp_config_path = "%s/tmp_s1.yaml" % tmp
    with open(tmp_config_path, "w") as f:
@@ -805,6 +811,8 @@ def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir):
                "is_half": str(is_half),
            }
        )
+        if musa_utils.is_available():
+            config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
        os.environ.update(config)
        cmd = '"%s" -s GPT_SoVITS/prepare_datasets/1-get-text.py' % python_exec
        print(cmd)
@@ -895,6 +903,8 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
                "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
            }
        )
+        if musa_utils.is_available():
+            config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
        os.environ.update(config)
        cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py' % python_exec
        print(cmd)
@@ -917,6 +927,8 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
                "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
            }
        )
+        if musa_utils.is_available():
+            config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
        os.environ.update(config)
        cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-sv.py' % python_exec
        print(cmd)
@@ -989,6 +1001,8 @@ def open1c(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, pretrained_s2G
                "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
            }
        )
+        if musa_utils.is_available():
+            config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
        os.environ.update(config)
        cmd = '"%s" -s GPT_SoVITS/prepare_datasets/3-get-semantic.py' % python_exec
        print(cmd)
@@ -1089,6 +1103,8 @@ def open1abc(
                "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
            }
        )
+        if musa_utils.is_available():
+            config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
        os.environ.update(config)
        cmd = '"%s" -s GPT_SoVITS/prepare_datasets/1-get-text.py' % python_exec
        print(cmd)
@@ -1136,6 +1152,8 @@ def open1abc(
                "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
            }
        )
+        if musa_utils.is_available():
+            config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
        os.environ.update(config)
        cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py' % python_exec
        print(cmd)
@@ -1158,6 +1176,8 @@ def open1abc(
                "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
            }
        )
+        if musa_utils.is_available():
+            config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
        os.environ.update(config)
        cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-sv.py' % python_exec
        print(cmd)
@@ -1198,6 +1218,8 @@ def open1abc(
                "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
            }
        )
+        if musa_utils.is_available():
+            config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
        os.environ.update(config)
        cmd = '"%s" -s GPT_SoVITS/prepare_datasets/3-get-semantic.py' % python_exec
        print(cmd)