Support for MUSA devices.

fix

Update musa_utils.py

Update musa_utils.py

Update config.py

fix

rollback S1 train

DDP only supports S4000

fix
KakaruHayate 2025-10-10 21:14:27 +08:00
parent 11aa78bd9b
commit 72be145051
21 changed files with 393 additions and 30 deletions

View File

@ -146,4 +146,4 @@ class DistributedBucketSampler(Sampler[T_co]):
Args:
epoch (int): Epoch number.
"""
self.epoch = epoch
self.epoch = epoch

View File

@ -36,6 +36,8 @@ from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
from sv import SV
import musa_utils
resample_transform_dict = {}
@ -209,6 +211,10 @@ def set_seed(seed: int):
# enabling this affects precision
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
elif musa_utils.is_available():
musa_utils.manual_seed(seed)
musa_utils.manual_seed_all(seed)
torch.backends.mudnn.allow_tf32 = False
except:
pass
return seed
@ -310,8 +316,10 @@ class TTS_Config:
self.default_configs = deepcopy(configs_)
self.device = self.configs.get("device", torch.device("cpu"))
if "cuda" in str(self.device) and not torch.cuda.is_available():
print("Warning: CUDA is not available, set device to CPU.")
cuda_mismatch = "cuda" in str(self.device) and not torch.cuda.is_available()
musa_mismatch = "musa" in str(self.device) and not musa_utils.is_available()
if cuda_mismatch or musa_mismatch:
print(f"Warning: Requested device '{self.device}' is not available, falling back to CPU.")
self.device = torch.device("cpu")
self.is_half = self.configs.get("is_half", False)
@ -1369,6 +1377,8 @@ class TTS:
gc.collect()  # trigger garbage collection so memory does not keep growing
if "cuda" in str(self.configs.device):
torch.cuda.empty_cache()
elif "musa" in str(self.configs.device):
torch.musa.empty_cache()
elif str(self.configs.device) == "mps":
torch.mps.empty_cache()
except:

View File

@ -5,6 +5,8 @@ import torch
import torchaudio
from torch import Tensor
import musa_utils
__all__ = [
"get_mel_banks",
"inverse_mel_scale",
@ -305,7 +307,12 @@ def spectrogram(
)
# size (m, padded_window_size // 2 + 1, 2)
    if musa_utils.is_available():  # yet another operator torch_musa does not support: run rfft on CPU
        ori_device = strided_input.device
        strided_input = strided_input.cpu()
    fft = torch.fft.rfft(strided_input)
    if musa_utils.is_available():
        fft = fft.to(ori_device)
# Convert the FFT into a power spectrum
power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1)
@ -618,7 +625,12 @@ def fbank(
)
# size (m, padded_window_size // 2 + 1)
    if musa_utils.is_available():  # same unsupported-operator workaround: compute rfft on CPU
        ori_device = strided_input.device
        strided_input = strided_input.cpu()
    spectrum = torch.fft.rfft(strided_input).abs()
    if musa_utils.is_available():
        spectrum = spectrum.to(ori_device)
if use_power:
spectrum = spectrum.pow(2.0)
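The CPU round-trip above works around torch.fft.rfft not yet being implemented in torch_musa. The same fallback could be factored into a single helper instead of being repeated at each call site — a minimal sketch, not part of this commit; cpu_fallback is a hypothetical name:

import torch
import musa_utils

def cpu_fallback(fn, *tensors):
    # Run fn on CPU when the inputs live on a MUSA device, then move
    # the result back; a no-op for CUDA/CPU inputs.
    if musa_utils.is_available() and tensors and tensors[0].device.type == "musa":
        ori_device = tensors[0].device
        out = fn(*(t.cpu() for t in tensors))
        return out.to(ori_device)
    return fn(*tensors)

# usage: fft = cpu_fallback(torch.fft.rfft, strided_input)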

View File

@ -49,6 +49,8 @@ from config import change_choices, get_weights_names, name2gpt_path, name2sovits
SoVITS_names, GPT_names = get_weights_names()
from config import pretrained_sovits_name
import musa_utils
path_sovits_v3 = pretrained_sovits_name["v3"]
path_sovits_v4 = pretrained_sovits_name["v4"]
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
@ -87,7 +89,9 @@ is_share = os.environ.get("is_share", "False")
is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
if "_MUSA_VISIBLE_DEVICES" in os.environ:
os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
# is_half=False
punctuation = set(["!", "?", "…", ",", ".", "-", " "])
import gradio as gr
@ -112,6 +116,8 @@ def set_seed(seed):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if musa_utils.is_available():
musa_utils.manual_seed(seed)
# set_seed(42)
@ -134,6 +140,8 @@ i18n = I18nAuto(language=language)
if torch.cuda.is_available():
device = "cuda"
elif musa_utils.is_available():
device = "musa"
else:
device = "cpu"
@ -411,6 +419,7 @@ def clean_hifigan_model():
hifigan_model = None
try:
torch.cuda.empty_cache()
torch.musa.empty_cache()
except:
pass
@ -422,6 +431,7 @@ def clean_bigvgan_model():
bigvgan_model = None
try:
torch.cuda.empty_cache()
torch.musa.empty_cache()
except:
pass
@ -433,6 +443,7 @@ def clean_sv_cn_model():
sv_cn_model = None
try:
torch.cuda.empty_cache()
torch.musa.empty_cache()
except:
pass

View File

@ -29,6 +29,8 @@ import sys
import torch
import musa_utils
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
@ -48,8 +50,10 @@ is_share = os.environ.get("is_share", "False")
is_share = eval(is_share)
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
if "_MUSA_VISIBLE_DEVICES" in os.environ:
os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
gpt_path = os.environ.get("gpt_path", None)
sovits_path = os.environ.get("sovits_path", None)
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
@ -72,6 +76,8 @@ i18n = I18nAuto(language=language)
if torch.cuda.is_available():
device = "cuda"
elif musa_utils.is_available():
device = "musa"
# elif torch.backends.mps.is_available():
# device = "mps"
else:

View File

@ -7,6 +7,7 @@ import torch
from torch import nn
from torch.nn import functional as F
import musa_utils
from module import commons
from module import modules
from module import attentions
@ -20,7 +21,10 @@ from module.quantize import ResidualVectorQuantizer
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast
if musa_utils.is_available():
    autocast = torch.musa.amp.autocast
else:
    from torch.cuda.amp import autocast  # importable even when CUDA is absent
import contextlib
import random
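An alternative to swapping the autocast import per backend: recent PyTorch exposes a device-agnostic torch.amp.autocast(device_type), and the tools/uvr5 change further down already calls it with "musa". A sketch built on that assumption:

import torch
import musa_utils

def amp_autocast(enabled: bool = True):
    # Choose the amp device type for the active backend; assumes
    # torch_musa accepts "musa" here, as tools/uvr5/bsroformer.py does.
    device_type = "musa" if musa_utils.is_available() else "cuda"
    return torch.amp.autocast(device_type, enabled=enabled)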

View File

@ -9,11 +9,14 @@ i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
if "_MUSA_VISIBLE_DEVICES" in os.environ:
os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
opt_dir = os.environ.get("opt_dir")
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
import torch
import musa_utils
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
version = os.environ.get("version", None)
import traceback
import os.path
@ -50,6 +53,8 @@ if os.path.exists(txt_path) == False:
os.makedirs(bert_dir, exist_ok=True)
if torch.cuda.is_available():
device = "cuda:0"
elif musa_utils.is_available():
device = "musa:0"
# elif torch.backends.mps.is_available():
# device = "mps"
else:

View File

@ -10,13 +10,16 @@ i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
if "_MUSA_VISIBLE_DEVICES" in os.environ:
os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
from feature_extractor import cnhubert
opt_dir = os.environ.get("opt_dir")
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
import torch
import musa_utils
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
import traceback
import numpy as np
@ -61,6 +64,8 @@ maxx = 0.95
alpha = 0.5
if torch.cuda.is_available():
device = "cuda:0"
elif musa_utils.is_available():
device = "musa:0"
# elif torch.backends.mps.is_available():
# device = "mps"
else:

View File

@ -10,12 +10,15 @@ i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
if "_MUSA_VISIBLE_DEVICES" in os.environ:
os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
opt_dir = os.environ.get("opt_dir")
sv_path = os.environ.get("sv_path")
import torch
import musa_utils
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
import traceback
import torchaudio
@ -49,6 +52,8 @@ maxx = 0.95
alpha = 0.5
if torch.cuda.is_available():
device = "cuda:0"
elif musa_utils.is_available():
device = "musa:0"
# elif torch.backends.mps.is_available():
# device = "mps"
else:

View File

@ -6,6 +6,8 @@ i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
if "_MUSA_VISIBLE_DEVICES" in os.environ:
os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
opt_dir = os.environ.get("opt_dir")
pretrained_s2G = os.environ.get("pretrained_s2G")
s2config_path = os.environ.get("s2config_path")
@ -27,8 +29,9 @@ elif size < 700 * 1024 * 1024:
else:
version = "v3"
import torch
import musa_utils
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
import traceback
import sys
@ -61,6 +64,8 @@ if os.path.exists(semantic_path) == False:
if torch.cuda.is_available():
device = "cuda"
elif musa_utils.is_available():
device = "musa:0"
# elif torch.backends.mps.is_available():
# device = "mps"
else:

View File

@ -168,4 +168,4 @@ if __name__ == "__main__":
args = parser.parse_args()
logging.info(str(args))
main(args)
main(args)

View File

@ -4,15 +4,23 @@ warnings.filterwarnings("ignore")
import os
import utils
import musa_utils
hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
os.environ["MUSA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.cuda.amp import GradScaler
musa_ddp = False
if musa_utils.is_available():
    autocast = torch.musa.amp.autocast
    musa_ddp = musa_utils.should_ddp()
else:
    from torch.cuda.amp import autocast  # importable even when CUDA is absent
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
@ -43,16 +51,22 @@ torch.backends.cudnn.deterministic = False
### A100 is faster in fp32 anyway, so let's try tf32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
if musa_utils.is_available():
torch.backends.mudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally so); does not affect the results
# from config import pretrained_s2G,pretrained_s2D
global_step = 0
device = "cpu" # cuda以外的设备等mps优化后加入
if not musa_ddp:
device = "musa"
def main():
if torch.cuda.is_available():
n_gpus = torch.cuda.device_count()
elif musa_ddp:
n_gpus = musa_utils.device_count()
else:
n_gpus = 1
os.environ["MASTER_ADDR"] = "localhost"
@ -78,7 +92,7 @@ def run(rank, n_gpus, hps):
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
backend = "mccl" if torch.musa.is_available() else ("gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"),
init_method="env://?use_libuv=False",
world_size=n_gpus,
rank=rank,
@ -86,6 +100,8 @@ def run(rank, n_gpus, hps):
torch.manual_seed(hps.train.seed)
if torch.cuda.is_available():
torch.cuda.set_device(rank)
elif musa_ddp:
musa_utils.set_device(rank)
train_dataset = TextAudioSpeakerLoader(hps.data, version=hps.model.version)
train_sampler = DistributedBucketSampler(
@ -140,6 +156,13 @@ def run(rank, n_gpus, hps):
**hps.model,
).cuda(rank)
if torch.cuda.is_available()
else SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).musa(rank)
if musa_ddp
else SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
@ -151,6 +174,8 @@ def run(rank, n_gpus, hps):
net_d = (
MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).cuda(rank)
if torch.cuda.is_available()
else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).musa(rank)
if musa_ddp
else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).to(device)
)
for name, param in net_g.named_parameters():
@ -196,7 +221,7 @@ def run(rank, n_gpus, hps):
betas=hps.train.betas,
eps=hps.train.eps,
)
if torch.cuda.is_available():
if torch.cuda.is_available() or musa_ddp:
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
else:
@ -238,7 +263,7 @@ def run(rank, n_gpus, hps):
torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
strict=False,
)
if torch.cuda.is_available()
if torch.cuda.is_available() or musa_ddp
else net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
strict=False,
@ -256,7 +281,7 @@ def run(rank, n_gpus, hps):
net_d.module.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"], strict=False
)
if torch.cuda.is_available()
if torch.cuda.is_available() or musa_ddp
else net_d.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"],
),
@ -279,7 +304,10 @@ def run(rank, n_gpus, hps):
scheduler_g.step()
scheduler_d.step()
scaler = GradScaler(enabled=hps.train.fp16_run)
if musa_utils.is_available():
scaler = torch.musa.amp.GradScaler(enabled=hps.train.fp16_run)
else:
scaler = GradScaler(enabled=hps.train.fp16_run)
print("start training from epoch %s" % epoch_str)
for epoch in range(epoch_str, hps.train.epochs + 1):
@ -369,6 +397,42 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
)
if hps.model.version in {"v2Pro", "v2ProPlus"}:
sv_emb = sv_emb.cuda(rank, non_blocking=True)
elif musa_ddp:
spec, spec_lengths = (
spec.musa(
rank,
non_blocking=True,
),
spec_lengths.musa(
rank,
non_blocking=True,
),
)
y, y_lengths = (
y.musa(
rank,
non_blocking=True,
),
y_lengths.musa(
rank,
non_blocking=True,
),
)
ssl = ssl.musa(rank, non_blocking=True)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.musa(rank, non_blocking=True)
text, text_lengths = (
text.musa(
rank,
non_blocking=True,
),
text_lengths.musa(
rank,
non_blocking=True,
),
)
if hps.model.version in {"v2Pro", "v2ProPlus"}:
sv_emb = sv_emb.musa(rank, non_blocking=True)
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
y, y_lengths = y.to(device), y_lengths.to(device)
@ -595,12 +659,17 @@ def evaluate(hps, generator, eval_loader, writer_eval):
text,
text_lengths,
) in enumerate(eval_loader):
print(111)
print("确实在跑")
if torch.cuda.is_available():
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
y, y_lengths = y.cuda(), y_lengths.cuda()
ssl = ssl.cuda()
text, text_lengths = text.cuda(), text_lengths.cuda()
elif musa_utils.is_available():
spec, spec_lengths = spec.musa(), spec_lengths.musa()
y, y_lengths = y.musa(), y_lengths.musa()
ssl = ssl.musa()
text, text_lengths = text.musa(), text_lengths.musa()
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
y, y_lengths = y.to(device), y_lengths.to(device)
@ -616,7 +685,7 @@ def evaluate(hps, generator, eval_loader, writer_eval):
text_lengths,
test=test,
)
if torch.cuda.is_available()
if torch.cuda.is_available() or musa_utils.is_available()
else generator.infer(
ssl,
spec,

View File

@ -4,15 +4,22 @@ warnings.filterwarnings("ignore")
import os
import utils
import musa_utils
hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
os.environ["MUSA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.cuda.amp import GradScaler
musa_ddp = False
if musa_utils.is_available():
    autocast = torch.musa.amp.autocast
    musa_ddp = musa_utils.should_ddp()
else:
    from torch.cuda.amp import autocast  # importable even when CUDA is absent
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
@ -43,16 +50,22 @@ torch.backends.cudnn.deterministic = False
### A100 is faster in fp32 anyway, so let's try tf32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
if musa_utils.is_available():
torch.backends.mudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally so); does not affect the results
# from config import pretrained_s2G,pretrained_s2D
global_step = 0
device = "cpu" # cuda以外的设备等mps优化后加入
if not musa_ddp:
device = "musa"
def main():
if torch.cuda.is_available():
n_gpus = torch.cuda.device_count()
elif musa_ddp:
n_gpus = musa_utils.device_count()
else:
n_gpus = 1
os.environ["MASTER_ADDR"] = "localhost"
@ -78,7 +91,7 @@ def run(rank, n_gpus, hps):
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
dist.init_process_group(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
backend= "mccl" if torch.musa.is_available() else ("gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"),
init_method="env://?use_libuv=False",
world_size=n_gpus,
rank=rank,
@ -86,6 +99,8 @@ def run(rank, n_gpus, hps):
torch.manual_seed(hps.train.seed)
if torch.cuda.is_available():
torch.cuda.set_device(rank)
elif musa_ddp:
musa_utils.set_device(rank)
train_dataset = TextAudioSpeakerLoader(hps.data) ########
train_sampler = DistributedBucketSampler(
@ -140,6 +155,13 @@ def run(rank, n_gpus, hps):
**hps.model,
).cuda(rank)
if torch.cuda.is_available()
else SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).musa(rank)
if musa_ddp
else SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
@ -165,7 +187,7 @@ def run(rank, n_gpus, hps):
# betas=hps.train.betas,
# eps=hps.train.eps,
# )
if torch.cuda.is_available():
if torch.cuda.is_available() or musa_ddp:
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
# net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
else:
@ -207,7 +229,7 @@ def run(rank, n_gpus, hps):
torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
strict=False,
)
if torch.cuda.is_available()
if torch.cuda.is_available() or musa_ddp
else net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
strict=False,
@ -235,7 +257,10 @@ def run(rank, n_gpus, hps):
scheduler_g.step()
# scheduler_d.step()
scaler = GradScaler(enabled=hps.train.fp16_run)
if musa_utils.is_available():
scaler = torch.musa.amp.GradScaler(enabled=hps.train.fp16_run)
else:
scaler = GradScaler(enabled=hps.train.fp16_run)
net_d = optim_d = scheduler_d = None
print("start training from epoch %s" % epoch_str)
@ -334,6 +359,31 @@ def train_and_evaluate(
non_blocking=True,
),
)
elif musa_ddp:
spec, spec_lengths = (
spec.musa(
rank,
non_blocking=True,
),
spec_lengths.musa(
rank,
non_blocking=True,
),
)
mel, mel_lengths = mel.musa(rank, non_blocking=True), mel_lengths.musa(rank, non_blocking=True)
ssl = ssl.musa(rank, non_blocking=True)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.musa(rank, non_blocking=True)
text, text_lengths = (
text.musa(
rank,
non_blocking=True,
),
text_lengths.musa(
rank,
non_blocking=True,
),
)
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
mel, mel_lengths = mel.to(device), mel_lengths.to(device)

View File

@ -4,15 +4,21 @@ warnings.filterwarnings("ignore")
import os
import utils
import musa_utils
hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
os.environ["MUSA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
from torch.cuda.amp import GradScaler
if musa_utils.is_available():
    autocast = torch.musa.amp.autocast
else:
    from torch.cuda.amp import autocast  # importable even when CUDA is absent
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
@ -43,11 +49,15 @@ torch.backends.cudnn.deterministic = False
### A100 is faster in fp32 anyway, so let's try tf32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
if musa_utils.is_available():
torch.backends.mudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally so); does not affect the results
# from config import pretrained_s2G,pretrained_s2D
global_step = 0
device = "cpu" # cuda以外的设备等mps优化后加入
if musa_utils.is_available(): # DDP支持不写了没设备测试
device = "musa"
def main():
@ -209,7 +219,10 @@ def run(rank, n_gpus, hps):
for _ in range(epoch_str):
scheduler_g.step()
scaler = GradScaler(enabled=hps.train.fp16_run)
if musa_utils.is_available():
scaler = torch.musa.amp.GradScaler(enabled=hps.train.fp16_run)
else:
scaler = GradScaler(enabled=hps.train.fp16_run)
net_d = optim_d = scheduler_d = None
print("start training from epoch %s" % epoch_str)

api.py
View File

@ -208,6 +208,7 @@ def clean_hifigan_model():
hifigan_model = None
try:
torch.cuda.empty_cache()
torch.musa.empty_cache()
except:
pass
@ -219,6 +220,7 @@ def clean_bigvgan_model():
bigvgan_model = None
try:
torch.cuda.empty_cache()
torch.musa.empty_cache()
except:
pass
@ -230,6 +232,7 @@ def clean_sv_cn_model():
sv_cn_model = None
try:
torch.cuda.empty_cache()
torch.musa.empty_cache()
except:
pass

View File

@ -4,6 +4,8 @@ import sys
import torch
import musa_utils
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto(language=os.environ.get("language", "Auto"))
@ -147,11 +149,15 @@ api_port = 9880
# Thanks to the contribution of @Karasukaigan and @XXXXRT666
def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
device_idx = idx
cpu = torch.device("cpu")
cuda = torch.device(f"cuda:{idx}")
if musa_utils.is_available():
musa = torch.device(f"musa:{idx}")
mdtype, mversion, mmem_gb = musa_utils.get_device_dtype(device_idx)
return musa, mdtype, mversion, mmem_gb
if not torch.cuda.is_available():
return cpu, torch.float32, 0.0, 0.0
device_idx = idx
capability = torch.cuda.get_device_capability(device_idx)
name = torch.cuda.get_device_name(device_idx)
mem_bytes = torch.cuda.get_device_properties(device_idx).total_memory
@ -167,11 +173,16 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, flo
return cuda, torch.float16, sm_version, mem_gb
return cpu, torch.float32, 0.0, 0.0
def get_gpu_count():
    if musa_utils.is_available():
        return musa_utils.device_count()
    else:
        return torch.cuda.device_count()
IS_GPU = True
GPU_INFOS: list[str] = []
GPU_INDEX: set[int] = set()
GPU_COUNT = torch.cuda.device_count()
GPU_COUNT = get_gpu_count()
CPU_INFO: str = "0\tCPU " + i18n("CPU训练,较慢")
tmp: list[tuple[torch.device, torch.dtype, float, float]] = []
memset: set[float] = set()
@ -183,7 +194,11 @@ for j in tmp:
device = j[0]
memset.add(j[3])
if device.type != "cpu":
GPU_INFOS.append(f"{device.index}\t{torch.cuda.get_device_name(device.index)}")
if device.type == "cuda" and torch.cuda.is_available():
device_name = torch.cuda.get_device_name(device.index)
elif device.type == "musa" and musa_utils.is_available():
device_name = torch.musa.get_device_properties(device.index).name
GPU_INFOS.append(f"{device.index}\t{device_name}")
GPU_INDEX.add(device.index)
if not GPU_INFOS:

musa_utils.py (new file)
View File

@ -0,0 +1,99 @@
import torch
from typing import Optional
import re

_musa_available = False
_musa_err_msg = "torch_musa not found or not configured correctly."

try:
    if hasattr(torch, "musa") and torch.musa.is_available():
        try:
            major, minor = torch.musa.get_device_capability()
            version_val = major + minor / 10.0
            if version_val < 2.1:
                raise RuntimeError(
                    f"MUSA version check failed! "
                    f"Found capability {major}.{minor} (version value {version_val:.2f}), "
                    f"but this project requires a version >= 2.1. "
                    f"Please upgrade your torch_musa and MUSA SDK. "
                    f"See: https://github.com/MooreThreads/torch_musa"
                )
            _musa_available = True
        except Exception as e:
            _musa_err_msg = f"MUSA availability check failed: {e}"
            if isinstance(e, RuntimeError):
                raise e
            _musa_available = False
except Exception:
    _musa_available = False


def is_available() -> bool:
    return _musa_available


def get_device() -> Optional[torch.device]:
    return torch.device("musa") if _musa_available else None


def device_count() -> int:
    if _musa_available:
        try:
            return torch.musa.device_count()
        except Exception:
            return 0
    return 0


def set_device(device_index: int):
    if _musa_available:
        try:
            torch.musa.set_device(device_index)
        except Exception:
            pass


def empty_cache():
    if _musa_available:
        try:
            torch.musa.empty_cache()
        except Exception:
            pass


def manual_seed(seed: int):
    if _musa_available:
        try:
            torch.musa.manual_seed(seed)
        except Exception:
            pass


def manual_seed_all(seed: int):
    if _musa_available:
        try:
            torch.musa.manual_seed_all(seed)
        except Exception:
            pass


def get_device_dtype(device_idx: int = 0) -> tuple[torch.dtype, float, float]:
    properties = torch.musa.get_device_properties(device_idx)
    major, minor = torch.musa.get_device_capability()
    version_val = major + minor / 10.0
    mem_bytes = properties.total_memory
    mem_gb = mem_bytes / (1024**3) + 0.4  # +0.4 rounding fudge, mirroring the CUDA path in config.py
    device_name = properties.name
    dtype = torch.float32
    numbers_in_name = [int(n) for n in re.findall(r"\d+", device_name)]
    if any(num >= 4000 for num in numbers_in_name):  # fp16 only on S4000-class cards
        dtype = torch.float16
    return dtype, version_val, mem_gb


def should_ddp(device_idx: int = 0) -> bool:
    # DDP is only enabled on S4000-class cards (see commit message)
    device_name = torch.musa.get_device_properties(device_idx).name
    numbers_in_name = [int(n) for n in re.findall(r"\d+", device_name)]
    return any(num >= 4000 for num in numbers_in_name)


DEVICE: Optional[torch.device] = get_device()
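A typical call site for the helpers above — illustrative only:

import musa_utils

if musa_utils.is_available():
    musa_utils.set_device(0)
    musa_utils.manual_seed_all(1234)
    dtype, version_val, mem_gb = musa_utils.get_device_dtype(0)
    print(f"MUSA capability {version_val:.1f}, {mem_gb:.1f} GB, dtype {dtype}")
    if musa_utils.should_ddp():  # True only on S4000-class cards
        print("multi-card DDP path enabled")
    musa_utils.empty_cache()
else:
    print("MUSA unavailable, falling back to CUDA/CPU")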

View File

@ -13,6 +13,12 @@ from tools.asr.config import get_models
from tools.asr.funasr_asr import only_asr
from tools.my_utils import load_cudnn
try:
    import torch_musa
    use_torch_musa = True
except ImportError:
    use_torch_musa = False
# fmt: off
language_code_list = [
"af", "am", "ar", "as", "az",
@ -86,6 +92,8 @@ def execute_asr(input_folder, output_folder, model_path, language, precision):
language = None # 不设置语种由模型自动输出概率最高的语种
print("loading faster whisper model:", model_path, model_path)
device = "cuda" if torch.cuda.is_available() else "cpu"
if use_torch_musa and torch.musa.is_available():
    device = "musa"
model = WhisperModel(model_path, device=device, compute_type=precision)
input_file_names = os.listdir(input_folder)

View File

@ -10,6 +10,12 @@ import torch.nn as nn
import yaml
from tqdm import tqdm
try:
    import torch_musa
    use_torch_musa = True
except ImportError:
    use_torch_musa = False
warnings.filterwarnings("ignore")
@ -135,7 +141,13 @@ class Roformer_Loader:
window_middle[-fade_size:] *= fadeout
window_middle[:fade_size] *= fadein
with torch.amp.autocast("cuda"):
set_device = "cuda"
if use_torch_musa and torch.musa.is_available():
    set_device = "musa"
with torch.amp.autocast(set_device):
with torch.inference_mode():
if self.config["training"]["target_instrument"] is None:
req_shape = (len(self.config["training"]["instruments"]),) + tuple(mix.shape)

View File

@ -18,6 +18,12 @@ from bsroformer import Roformer_Loader
from mdxnet import MDXNetDereverb
from vr import AudioPre, AudioPreDeEcho
try:
    import torch_musa
    use_torch_musa = True
except ImportError:
    use_torch_musa = False
weight_uvr5_root = "tools/uvr5/uvr5_weights"
uvr5_names = []
for name in os.listdir(weight_uvr5_root):
@ -122,6 +128,9 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
print("clean_empty_cache")
if torch.cuda.is_available():
torch.cuda.empty_cache()
if use_torch_musa and torch.musa.is_available():
    torch.musa.empty_cache()
yield "\n".join(infos)

View File

@ -92,6 +92,8 @@ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 当遇到mps不支持的步骤时使用cpu
import gradio as gr
import musa_utils
n_cpu = cpu_count()
set_gpu_numbers = GPU_INDEX
@ -347,6 +349,8 @@ def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, so
os.environ["is_half"] = str(is_half)
os.environ["infer_ttswebui"] = str(webui_port_infer_tts)
os.environ["is_share"] = str(is_share)
if musa_utils.is_available():
os.environ["_MUSA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_number))
yield (
process_info(process_name_tts, "opened"),
{"__type__": "update", "visible": False},
@ -629,6 +633,8 @@ def open1Bb(
# data["version"]=version
os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_numbers(gpu_numbers.replace("-", ",")))
if musa_utils.is_available():
os.environ["_MUSA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_numbers.replace("-", ",")))
os.environ["hz"] = "25hz"
tmp_config_path = "%s/tmp_s1.yaml" % tmp
with open(tmp_config_path, "w") as f:
@ -805,6 +811,8 @@ def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir):
"is_half": str(is_half),
}
)
if musa_utils.is_available():
config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),})
os.environ.update(config)
cmd = '"%s" -s GPT_SoVITS/prepare_datasets/1-get-text.py' % python_exec
print(cmd)
@ -895,6 +903,8 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
}
)
if musa_utils.is_available():
config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),})
os.environ.update(config)
cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py' % python_exec
print(cmd)
@ -917,6 +927,8 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
}
)
if musa_utils.is_available():
config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),})
os.environ.update(config)
cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-sv.py' % python_exec
print(cmd)
@ -989,6 +1001,8 @@ def open1c(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, pretrained_s2G
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
}
)
if musa_utils.is_available():
config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),})
os.environ.update(config)
cmd = '"%s" -s GPT_SoVITS/prepare_datasets/3-get-semantic.py' % python_exec
print(cmd)
@ -1089,6 +1103,8 @@ def open1abc(
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
}
)
if musa_utils.is_available():
config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),})
os.environ.update(config)
cmd = '"%s" -s GPT_SoVITS/prepare_datasets/1-get-text.py' % python_exec
print(cmd)
@ -1136,6 +1152,8 @@ def open1abc(
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
}
)
if musa_utils.is_available():
config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),})
os.environ.update(config)
cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py' % python_exec
print(cmd)
@ -1158,6 +1176,8 @@ def open1abc(
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
}
)
if musa_utils.is_available():
config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),})
os.environ.update(config)
cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-sv.py' % python_exec
print(cmd)
@ -1198,6 +1218,8 @@ def open1abc(
"_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
}
)
if musa_utils.is_available():
config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),})
os.environ.update(config)
cmd = '"%s" -s GPT_SoVITS/prepare_datasets/3-get-semantic.py' % python_exec
print(cmd)
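As with _CUDA_VISIBLE_DEVICES, the underscore-prefixed _MUSA_VISIBLE_DEVICES is only staged in the parent process; each child script promotes it to the real variable before torch_musa initializes, so the device mask never affects the WebUI process itself. Condensed, the two halves look like this (illustrative sketch):

# parent (webui.py): stage the mask for the subprocess
import os
config = {"_MUSA_VISIBLE_DEVICES": "0"}
os.environ.update(config)
# ...then spawn the prepare_datasets script...

# child (e.g. 1-get-text.py): promote it before torch_musa is imported
import os
if "_MUSA_VISIBLE_DEVICES" in os.environ:
    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]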