Add support for MUSA devices.

fix

Update musa_utils.py

Update musa_utils.py

Update config.py

fix

rollback S1 train

DDP is only supported on S4000

DDP is only supported on S4000

fix
KakaruHayate committed 2025-10-10 21:14:27 +08:00
parent 11aa78bd9b
commit 72be145051
21 changed files with 393 additions and 30 deletions

@@ -146,4 +146,4 @@ class DistributedBucketSampler(Sampler[T_co]):
         Args:
             epoch (int): Epoch number.
         """
         self.epoch = epoch

@@ -36,6 +36,8 @@ from TTS_infer_pack.text_segmentation_method import splits
 from TTS_infer_pack.TextPreprocessor import TextPreprocessor
 from sv import SV

+import musa_utils
+
 resample_transform_dict = {}

@@ -209,6 +211,10 @@ def set_seed(seed: int):
            # enabling this affects precision
            torch.backends.cuda.matmul.allow_tf32 = False
            torch.backends.cudnn.allow_tf32 = False
+        elif musa_utils.is_available():
+            musa_utils.manual_seed(seed)
+            musa_utils.manual_seed_all(seed)
+            torch.backends.mudnn.allow_tf32 = False
     except:
         pass
     return seed

@@ -310,8 +316,10 @@ class TTS_Config:
         self.default_configs = deepcopy(configs_)

         self.device = self.configs.get("device", torch.device("cpu"))
-        if "cuda" in str(self.device) and not torch.cuda.is_available():
-            print("Warning: CUDA is not available, set device to CPU.")
+        cuda_mismatch = "cuda" in str(self.device) and not torch.cuda.is_available()
+        musa_mismatch = "musa" in str(self.device) and not musa_utils.is_available()
+        if cuda_mismatch or musa_mismatch:
+            print(f"Warning: Requested device '{self.device}' is not available, falling back to CPU.")
             self.device = torch.device("cpu")

         self.is_half = self.configs.get("is_half", False)

@@ -1369,6 +1377,8 @@ class TTS:
            gc.collect()  # trigger garbage collection so memory does not keep growing
            if "cuda" in str(self.configs.device):
                torch.cuda.empty_cache()
+            elif "musa" in str(self.configs.device):
+                torch.musa.empty_cache()
            elif str(self.configs.device) == "mps":
                torch.mps.empty_cache()
         except:
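The empty-cache dispatch above recurs in several files below. A minimal sketch of a shared helper, assuming the `musa_utils` module this commit adds and `torch_musa`'s `torch.musa.empty_cache()` (the `clear_device_cache` name itself is hypothetical, not part of the commit):

import torch
import musa_utils

def clear_device_cache(device: torch.device) -> None:
    # Route to whichever backend owns the device; no-op on plain CPU.
    dev = str(device)
    if "cuda" in dev and torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif "musa" in dev and musa_utils.is_available():
        torch.musa.empty_cache()  # registered by importing torch_musa
    elif dev == "mps":
        torch.mps.empty_cache()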

@@ -5,6 +5,8 @@ import torch
 import torchaudio
 from torch import Tensor

+import musa_utils
+
 __all__ = [
     "get_mel_banks",
     "inverse_mel_scale",

@@ -305,7 +307,12 @@ def spectrogram(
     )

     # size (m, padded_window_size // 2 + 1, 2)
+    if musa_utils.is_available():  # yet another operator the MUSA backend does not support
+        ori_device = strided_input.device
+        strided_input = strided_input.cpu()
     fft = torch.fft.rfft(strided_input)
+    if musa_utils.is_available():
+        fft = fft.to(ori_device)

     # Convert the FFT into a power spectrum
     power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log()  # size (m, padded_window_size // 2 + 1)

@@ -618,7 +625,12 @@ def fbank(
     )

     # size (m, padded_window_size // 2 + 1)
+    if musa_utils.is_available():  # yet another operator the MUSA backend does not support
+        ori_device = strided_input.device
+        strided_input = strided_input.cpu()
     spectrum = torch.fft.rfft(strided_input).abs()
+    if musa_utils.is_available():
+        spectrum = spectrum.to(ori_device)
     if use_power:
         spectrum = spectrum.pow(2.0)
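Both call sites work around the same gap: `torch.fft.rfft` has no MUSA kernel yet, so the tensor is round-tripped through the CPU. A generic sketch of the pattern (the wrapper name is hypothetical):

import torch
import musa_utils

def rfft_with_cpu_fallback(x: torch.Tensor) -> torch.Tensor:
    # Run the unsupported operator on CPU, then move the result back.
    if musa_utils.is_available() and x.device.type == "musa":
        return torch.fft.rfft(x.cpu()).to(x.device)
    return torch.fft.rfft(x)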

@@ -49,6 +49,8 @@ from config import change_choices, get_weights_names, name2gpt_path, name2sovits
 SoVITS_names, GPT_names = get_weights_names()
 from config import pretrained_sovits_name

+import musa_utils
+
 path_sovits_v3 = pretrained_sovits_name["v3"]
 path_sovits_v4 = pretrained_sovits_name["v4"]
 is_exist_s2gv3 = os.path.exists(path_sovits_v3)

@@ -87,7 +89,9 @@ is_share = os.environ.get("is_share", "False")
 is_share = eval(is_share)
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
 # is_half=False
 punctuation = set(["!", "?", "…", ",", ".", "-", " "])
 import gradio as gr

@@ -112,6 +116,8 @@ def set_seed(seed):
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    if musa_utils.is_available():
+        musa_utils.manual_seed(seed)

 # set_seed(42)

@@ -134,6 +140,8 @@ i18n = I18nAuto(language=language)
 if torch.cuda.is_available():
     device = "cuda"
+elif musa_utils.is_available():
+    device = "musa"
 else:
     device = "cpu"

@@ -411,6 +419,7 @@ def clean_hifigan_model():
     hifigan_model = None
     try:
         torch.cuda.empty_cache()
+        torch.musa.empty_cache()
     except:
         pass

@@ -422,6 +431,7 @@ def clean_bigvgan_model():
     bigvgan_model = None
     try:
         torch.cuda.empty_cache()
+        torch.musa.empty_cache()
     except:
         pass

@@ -433,6 +443,7 @@ def clean_sv_cn_model():
     sv_cn_model = None
     try:
         torch.cuda.empty_cache()
+        torch.musa.empty_cache()
     except:
         pass

@@ -29,6 +29,8 @@ import sys

 import torch

+import musa_utils
+
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 sys.path.append("%s/GPT_SoVITS" % (now_dir))

@@ -48,8 +50,10 @@ is_share = os.environ.get("is_share", "False")
 is_share = eval(is_share)
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
 gpt_path = os.environ.get("gpt_path", None)
 sovits_path = os.environ.get("sovits_path", None)
 cnhubert_base_path = os.environ.get("cnhubert_base_path", None)

@@ -72,6 +76,8 @@ i18n = I18nAuto(language=language)
 if torch.cuda.is_available():
     device = "cuda"
+elif musa_utils.is_available():
+    device = "musa"
 # elif torch.backends.mps.is_available():
 #     device = "mps"
 else:

@@ -7,6 +7,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F

+import musa_utils
 from module import commons
 from module import modules
 from module import attentions

@@ -20,7 +21,10 @@ from module.quantize import ResidualVectorQuantizer
 # from text import symbols
 from text import symbols as symbols_v1
 from text import symbols2 as symbols_v2

-from torch.cuda.amp import autocast
+if musa_utils.is_available():
+    autocast = torch.musa.amp.autocast
+elif torch.cuda.is_available():
+    from torch.cuda.amp import autocast
 import contextlib
 import random
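With this import arrangement, `autocast` stays unbound on machines that have neither MUSA nor CUDA, and the module presumably only references it on GPU paths. A CPU-safe variant (a suggestion, not what the commit does) would fall back to a no-op context manager:

import contextlib
import torch
import musa_utils

if musa_utils.is_available():
    autocast = torch.musa.amp.autocast
elif torch.cuda.is_available():
    from torch.cuda.amp import autocast
else:
    # No-op stand-in so the module still imports on CPU-only hosts.
    def autocast(enabled=True, **kwargs):
        return contextlib.nullcontext()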

@@ -9,11 +9,14 @@ i_part = os.environ.get("i_part")
 all_parts = os.environ.get("all_parts")
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
 opt_dir = os.environ.get("opt_dir")
 bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
 import torch
+import musa_utils

-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
 version = os.environ.get("version", None)
 import traceback
 import os.path

@@ -50,6 +53,8 @@ if os.path.exists(txt_path) == False:
     os.makedirs(bert_dir, exist_ok=True)
     if torch.cuda.is_available():
         device = "cuda:0"
+    elif musa_utils.is_available():
+        device = "musa:0"
     # elif torch.backends.mps.is_available():
     #     device = "mps"
     else:
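The `_MUSA_VISIBLE_DEVICES` indirection copies the existing CUDA convention: the launcher stages the device list under a private name, and each dataset worker promotes it to the real `MUSA_VISIBLE_DEVICES` before the MUSA runtime initializes. Condensed view of the handshake (device id "0" is just an example):

import os
import subprocess

# Launcher side (webui.py): stage the ids without touching its own env.
env = dict(os.environ, _MUSA_VISIBLE_DEVICES="0")
subprocess.run(["python", "GPT_SoVITS/prepare_datasets/1-get-text.py"], env=env)

# Worker side: promote the staged value before torch/torch_musa come up.
if "_MUSA_VISIBLE_DEVICES" in os.environ:
    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]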

@@ -10,13 +10,16 @@ i_part = os.environ.get("i_part")
 all_parts = os.environ.get("all_parts")
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
 from feature_extractor import cnhubert

 opt_dir = os.environ.get("opt_dir")
 cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
 import torch
+import musa_utils

-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
 import traceback
 import numpy as np

@@ -61,6 +64,8 @@ maxx = 0.95
 alpha = 0.5
 if torch.cuda.is_available():
     device = "cuda:0"
+elif musa_utils.is_available():
+    device = "musa:0"
 # elif torch.backends.mps.is_available():
 #     device = "mps"
 else:

@@ -10,12 +10,15 @@ i_part = os.environ.get("i_part")
 all_parts = os.environ.get("all_parts")
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
 opt_dir = os.environ.get("opt_dir")
 sv_path = os.environ.get("sv_path")
 import torch
+import musa_utils

-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
 import traceback
 import torchaudio

@@ -49,6 +52,8 @@ maxx = 0.95
 alpha = 0.5
 if torch.cuda.is_available():
     device = "cuda:0"
+elif musa_utils.is_available():
+    device = "musa:0"
 # elif torch.backends.mps.is_available():
 #     device = "mps"
 else:

@@ -6,6 +6,8 @@ i_part = os.environ.get("i_part")
 all_parts = os.environ.get("all_parts")
 if "_CUDA_VISIBLE_DEVICES" in os.environ:
     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+if "_MUSA_VISIBLE_DEVICES" in os.environ:
+    os.environ["MUSA_VISIBLE_DEVICES"] = os.environ["_MUSA_VISIBLE_DEVICES"]
 opt_dir = os.environ.get("opt_dir")
 pretrained_s2G = os.environ.get("pretrained_s2G")
 s2config_path = os.environ.get("s2config_path")

@@ -27,8 +29,9 @@ elif size < 700 * 1024 * 1024:
 else:
     version = "v3"
 import torch
+import musa_utils

-is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
+is_half = eval(os.environ.get("is_half", "True")) and (torch.cuda.is_available() or musa_utils.is_available())
 import traceback
 import sys

@@ -61,6 +64,8 @@ if os.path.exists(semantic_path) == False:
     if torch.cuda.is_available():
         device = "cuda"
+    elif musa_utils.is_available():
+        device = "musa:0"
     # elif torch.backends.mps.is_available():
     #     device = "mps"
     else:

@@ -168,4 +168,4 @@ if __name__ == "__main__":
    args = parser.parse_args()
    logging.info(str(args))

    main(args)

@@ -4,15 +4,23 @@ warnings.filterwarnings("ignore")

 import os

 import utils
+import musa_utils

 hps = utils.get_hparams(stage=2)
 os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+os.environ["MUSA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
 import logging

 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from torch.cuda.amp import GradScaler, autocast
+from torch.cuda.amp import GradScaler
+musa_ddp = False
+if musa_utils.is_available():
+    autocast = torch.musa.amp.autocast
+    musa_ddp = musa_utils.should_ddp()
+elif torch.cuda.is_available():
+    from torch.cuda.amp import autocast
 from torch.nn import functional as F
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader

@@ -43,16 +51,22 @@ torch.backends.cudnn.deterministic = False
 ### fp32 is faster on the A100 anyway, so let's try tf32
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
+if musa_utils.is_available():
+    torch.backends.mudnn.allow_tf32 = True
 torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally); no effect on results
 # from config import pretrained_s2G,pretrained_s2D
 global_step = 0

 device = "cpu"  # devices other than cuda will be added once mps is optimized
+if not musa_ddp:
+    device = "musa"


 def main():
     if torch.cuda.is_available():
         n_gpus = torch.cuda.device_count()
+    elif musa_ddp:
+        n_gpus = musa_utils.device_count()
     else:
         n_gpus = 1
     os.environ["MASTER_ADDR"] = "localhost"

@@ -78,7 +92,7 @@ def run(rank, n_gpus, hps):
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

     dist.init_process_group(
-        backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+        backend="mccl" if torch.musa.is_available() else ("gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"),
         init_method="env://?use_libuv=False",
         world_size=n_gpus,
         rank=rank,

@@ -86,6 +100,8 @@ def run(rank, n_gpus, hps):
     torch.manual_seed(hps.train.seed)
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
+    elif musa_ddp:
+        musa_utils.set_device(rank)

     train_dataset = TextAudioSpeakerLoader(hps.data, version=hps.model.version)
     train_sampler = DistributedBucketSampler(

@@ -140,6 +156,13 @@ def run(rank, n_gpus, hps):
             **hps.model,
         ).cuda(rank)
         if torch.cuda.is_available()
+        else SynthesizerTrn(
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model,
+        ).musa(rank)
+        if musa_ddp
         else SynthesizerTrn(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,

@@ -151,6 +174,8 @@ def run(rank, n_gpus, hps):
     net_d = (
         MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).cuda(rank)
         if torch.cuda.is_available()
+        else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).musa(rank)
+        if musa_ddp
         else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).to(device)
     )
     for name, param in net_g.named_parameters():

@@ -196,7 +221,7 @@ def run(rank, n_gpus, hps):
         betas=hps.train.betas,
         eps=hps.train.eps,
     )
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or musa_ddp:
         net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
         net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
     else:

@@ -238,7 +263,7 @@ def run(rank, n_gpus, hps):
             torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
             strict=False,
         )
-        if torch.cuda.is_available()
+        if torch.cuda.is_available() or musa_ddp
         else net_g.load_state_dict(
             torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
             strict=False,

@@ -256,7 +281,7 @@ def run(rank, n_gpus, hps):
         net_d.module.load_state_dict(
             torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"], strict=False
         )
-        if torch.cuda.is_available()
+        if torch.cuda.is_available() or musa_ddp
         else net_d.load_state_dict(
             torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"],
         ),

@@ -279,7 +304,10 @@ def run(rank, n_gpus, hps):
         scheduler_g.step()
         scheduler_d.step()

-    scaler = GradScaler(enabled=hps.train.fp16_run)
+    if musa_utils.is_available():
+        scaler = torch.musa.amp.GradScaler(enabled=hps.train.fp16_run)
+    else:
+        scaler = GradScaler(enabled=hps.train.fp16_run)

     print("start training from epoch %s" % epoch_str)
     for epoch in range(epoch_str, hps.train.epochs + 1):

@@ -369,6 +397,42 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
             )
             if hps.model.version in {"v2Pro", "v2ProPlus"}:
                 sv_emb = sv_emb.cuda(rank, non_blocking=True)
+        elif musa_ddp:
+            spec, spec_lengths = (
+                spec.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+                spec_lengths.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+            )
+            y, y_lengths = (
+                y.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+                y_lengths.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+            )
+            ssl = ssl.musa(rank, non_blocking=True)
+            ssl.requires_grad = False
+            # ssl_lengths = ssl_lengths.musa(rank, non_blocking=True)
+            text, text_lengths = (
+                text.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+                text_lengths.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+            )
+            if hps.model.version in {"v2Pro", "v2ProPlus"}:
+                sv_emb = sv_emb.musa(rank, non_blocking=True)
         else:
             spec, spec_lengths = spec.to(device), spec_lengths.to(device)
             y, y_lengths = y.to(device), y_lengths.to(device)

@@ -595,12 +659,17 @@ def evaluate(hps, generator, eval_loader, writer_eval):
             text,
             text_lengths,
         ) in enumerate(eval_loader):
-            print(111)
+            print("确实在跑")
             if torch.cuda.is_available():
                 spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
                 y, y_lengths = y.cuda(), y_lengths.cuda()
                 ssl = ssl.cuda()
                 text, text_lengths = text.cuda(), text_lengths.cuda()
+            elif musa_utils.is_available():
+                spec, spec_lengths = spec.musa(), spec_lengths.musa()
+                y, y_lengths = y.musa(), y_lengths.musa()
+                ssl = ssl.musa()
+                text, text_lengths = text.musa(), text_lengths.musa()
             else:
                 spec, spec_lengths = spec.to(device), spec_lengths.to(device)
                 y, y_lengths = y.to(device), y_lengths.to(device)

@@ -616,7 +685,7 @@ def evaluate(hps, generator, eval_loader, writer_eval):
                     text_lengths,
                     test=test,
                 )
-                if torch.cuda.is_available()
+                if torch.cuda.is_available() or musa_utils.is_available()
                 else generator.infer(
                     ssl,
                     spec,
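The backend expression now reads: MCCL (Moore Threads' collective communication library) when MUSA is present, otherwise gloo on Windows or CPU and NCCL on CUDA. The same decision, unrolled for readability (a sketch; `pick_backend` is not in the commit):

import os
import torch
import musa_utils

def pick_backend() -> str:
    if musa_utils.is_available():
        return "mccl"  # Moore Threads collective communication library
    if os.name == "nt" or not torch.cuda.is_available():
        return "gloo"
    return "nccl"

# dist.init_process_group(backend=pick_backend(), init_method="env://?use_libuv=False", ...)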

@@ -4,15 +4,22 @@ warnings.filterwarnings("ignore")

 import os

 import utils
+import musa_utils

 hps = utils.get_hparams(stage=2)
 os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+os.environ["MUSA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
 import logging

 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from torch.cuda.amp import GradScaler, autocast
+from torch.cuda.amp import GradScaler
+if musa_utils.is_available():
+    autocast = torch.musa.amp.autocast
+    musa_ddp = musa_utils.should_ddp()
+elif torch.cuda.is_available():
+    from torch.cuda.amp import autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter

@@ -43,16 +50,22 @@ torch.backends.cudnn.deterministic = False
 ### fp32 is faster on the A100 anyway, so let's try tf32
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
+if musa_utils.is_available():
+    torch.backends.mudnn.allow_tf32 = True
 torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally); no effect on results
 # from config import pretrained_s2G,pretrained_s2D
 global_step = 0

 device = "cpu"  # devices other than cuda will be added once mps is optimized
+if not musa_ddp:
+    device = "musa"


 def main():
     if torch.cuda.is_available():
         n_gpus = torch.cuda.device_count()
+    elif musa_ddp:
+        n_gpus = musa_utils.device_count()
     else:
         n_gpus = 1
     os.environ["MASTER_ADDR"] = "localhost"

@@ -78,7 +91,7 @@ def run(rank, n_gpus, hps):
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

     dist.init_process_group(
-        backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+        backend="mccl" if torch.musa.is_available() else ("gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"),
         init_method="env://?use_libuv=False",
         world_size=n_gpus,
         rank=rank,

@@ -86,6 +99,8 @@ def run(rank, n_gpus, hps):
     torch.manual_seed(hps.train.seed)
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
+    elif musa_ddp:
+        musa_utils.set_device(rank)

     train_dataset = TextAudioSpeakerLoader(hps.data)  ########
     train_sampler = DistributedBucketSampler(

@@ -140,6 +155,13 @@ def run(rank, n_gpus, hps):
             **hps.model,
         ).cuda(rank)
         if torch.cuda.is_available()
+        else SynthesizerTrn(
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model,
+        ).musa(rank)
+        if musa_ddp
         else SynthesizerTrn(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,

@@ -165,7 +187,7 @@ def run(rank, n_gpus, hps):
     #     betas=hps.train.betas,
     #     eps=hps.train.eps,
     # )
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or musa_ddp:
         net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
         # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
     else:

@@ -207,7 +229,7 @@ def run(rank, n_gpus, hps):
             torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
             strict=False,
         )
-        if torch.cuda.is_available()
+        if torch.cuda.is_available() or musa_ddp
         else net_g.load_state_dict(
             torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
             strict=False,

@@ -235,7 +257,10 @@ def run(rank, n_gpus, hps):
         scheduler_g.step()
         # scheduler_d.step()

-    scaler = GradScaler(enabled=hps.train.fp16_run)
+    if musa_utils.is_available():
+        scaler = torch.musa.amp.GradScaler(enabled=hps.train.fp16_run)
+    else:
+        scaler = GradScaler(enabled=hps.train.fp16_run)

     net_d = optim_d = scheduler_d = None
     print("start training from epoch %s" % epoch_str)

@@ -334,6 +359,31 @@ def train_and_evaluate(
                     non_blocking=True,
                 ),
             )
+        elif musa_ddp:
+            spec, spec_lengths = (
+                spec.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+                spec_lengths.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+            )
+            mel, mel_lengths = mel.musa(rank, non_blocking=True), mel_lengths.musa(rank, non_blocking=True)
+            ssl = ssl.musa(rank, non_blocking=True)
+            ssl.requires_grad = False
+            # ssl_lengths = ssl_lengths.musa(rank, non_blocking=True)
+            text, text_lengths = (
+                text.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+                text_lengths.musa(
+                    rank,
+                    non_blocking=True,
+                ),
+            )
         else:
             spec, spec_lengths = spec.to(device), spec_lengths.to(device)
             mel, mel_lengths = mel.to(device), mel_lengths.to(device)

@@ -4,15 +4,21 @@ warnings.filterwarnings("ignore")

 import os

 import utils
+import musa_utils

 hps = utils.get_hparams(stage=2)
 os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
+os.environ["MUSA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
 import logging

 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from torch.cuda.amp import GradScaler, autocast
+from torch.cuda.amp import GradScaler
+if musa_utils.is_available():
+    autocast = torch.musa.amp.autocast
+elif torch.cuda.is_available():
+    from torch.cuda.amp import autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter

@@ -43,11 +49,15 @@ torch.backends.cudnn.deterministic = False
 ### fp32 is faster on the A100 anyway, so let's try tf32
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
+if musa_utils.is_available():
+    torch.backends.mudnn.allow_tf32 = True
 torch.set_float32_matmul_precision("medium")  # lowest precision but fastest (only marginally); no effect on results
 # from config import pretrained_s2G,pretrained_s2D
 global_step = 0

 device = "cpu"  # devices other than cuda will be added once mps is optimized
+if musa_utils.is_available():  # DDP support not written; no device to test on
+    device = "musa"


 def main():

@@ -209,7 +219,10 @@ def run(rank, n_gpus, hps):
     for _ in range(epoch_str):
         scheduler_g.step()

-    scaler = GradScaler(enabled=hps.train.fp16_run)
+    if musa_utils.is_available():
+        scaler = torch.musa.amp.GradScaler(enabled=hps.train.fp16_run)
+    else:
+        scaler = GradScaler(enabled=hps.train.fp16_run)

     net_d = optim_d = scheduler_d = None
     print("start training from epoch %s" % epoch_str)

api.py

@@ -208,6 +208,7 @@ def clean_hifigan_model():
     hifigan_model = None
     try:
         torch.cuda.empty_cache()
+        torch.musa.empty_cache()
     except:
         pass

@@ -219,6 +220,7 @@ def clean_bigvgan_model():
     bigvgan_model = None
     try:
         torch.cuda.empty_cache()
+        torch.musa.empty_cache()
     except:
         pass

@@ -230,6 +232,7 @@ def clean_sv_cn_model():
     sv_cn_model = None
     try:
         torch.cuda.empty_cache()
+        torch.musa.empty_cache()
     except:
         pass

@@ -4,6 +4,8 @@ import sys

 import torch

+import musa_utils
+
 from tools.i18n.i18n import I18nAuto

 i18n = I18nAuto(language=os.environ.get("language", "Auto"))

@@ -147,11 +149,15 @@ api_port = 9880

 # Thanks to the contribution of @Karasukaigan and @XXXXRT666
 def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
+    device_idx = idx
     cpu = torch.device("cpu")
     cuda = torch.device(f"cuda:{idx}")
+    if musa_utils.is_available():
+        musa = torch.device(f"musa:{idx}")
+        mdtype, mversion, mmem_gb = musa_utils.get_device_dtype(device_idx)
+        return musa, mdtype, mversion, mmem_gb
     if not torch.cuda.is_available():
         return cpu, torch.float32, 0.0, 0.0
-    device_idx = idx
     capability = torch.cuda.get_device_capability(device_idx)
     name = torch.cuda.get_device_name(device_idx)
     mem_bytes = torch.cuda.get_device_properties(device_idx).total_memory

@@ -167,11 +173,16 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
         return cuda, torch.float16, sm_version, mem_gb
     return cpu, torch.float32, 0.0, 0.0

+def get_gpu_count():
+    if musa_utils.is_available():
+        return musa_utils.device_count()
+    else:
+        return torch.cuda.device_count()

 IS_GPU = True
 GPU_INFOS: list[str] = []
 GPU_INDEX: set[int] = set()
-GPU_COUNT = torch.cuda.device_count()
+GPU_COUNT = get_gpu_count()
 CPU_INFO: str = "0\tCPU " + i18n("CPU训练,较慢")
 tmp: list[tuple[torch.device, torch.dtype, float, float]] = []
 memset: set[float] = set()

@@ -183,7 +194,11 @@ for j in tmp:
     device = j[0]
     memset.add(j[3])
     if device.type != "cpu":
-        GPU_INFOS.append(f"{device.index}\t{torch.cuda.get_device_name(device.index)}")
+        if device.type == "cuda" and torch.cuda.is_available():
+            device_name = torch.cuda.get_device_name(device.index)
+        elif device.type == "musa" and musa_utils.is_available():
+            device_name = torch.musa.get_device_properties(device.index).name
+        GPU_INFOS.append(f"{device.index}\t{device_name}")
         GPU_INDEX.add(device.index)

 if not GPU_INFOS:
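For a quick sanity check of the new device probing, something like the following (illustrative only; it reuses this module's own names) prints what each index resolves to:

if __name__ == "__main__":
    for idx in range(get_gpu_count()):
        device, dtype, version, mem_gb = get_device_dtype_sm(idx)
        print(f"{idx}: {device} dtype={dtype} capability={version} mem={mem_gb:.1f} GB")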

musa_utils.py (new file, 99 lines)

@@ -0,0 +1,99 @@
import re
from typing import Optional

import torch

_musa_available = False
_musa_err_msg = "torch_musa not found or not configured correctly."

try:
    if hasattr(torch, "musa") and torch.musa.is_available():
        try:
            major, minor = torch.musa.get_device_capability()
            version_val = major + minor / 10.0
            if version_val < 2.1:
                raise RuntimeError(
                    f"MUSA version check failed! "
                    f"Found capability {major}.{minor} (version value {version_val:.2f}), "
                    f"but this project requires a version >= 2.1. "
                    f"Please upgrade your torch_musa and MUSA SDK. "
                    f"See: https://github.com/MooreThreads/torch_musa"
                )
            _musa_available = True
        except Exception as e:
            _musa_err_msg = f"MUSA availability check failed: {e}"
            if isinstance(e, RuntimeError):
                raise e
            _musa_available = False
except Exception:
    _musa_available = False


def is_available() -> bool:
    return _musa_available


def get_device() -> Optional[torch.device]:
    return torch.device("musa") if _musa_available else None


def device_count() -> int:
    if _musa_available:
        try:
            return torch.musa.device_count()
        except Exception:
            return 0
    return 0


def set_device(device_index: int):
    if _musa_available:
        try:
            torch.musa.set_device(device_index)
        except Exception:
            pass


def empty_cache():
    if _musa_available:
        try:
            torch.musa.empty_cache()
        except Exception:
            pass


def manual_seed(seed: int):
    if _musa_available:
        try:
            torch.musa.manual_seed(seed)
        except Exception:
            pass


def manual_seed_all(seed: int):
    if _musa_available:
        try:
            torch.musa.manual_seed_all(seed)
        except Exception:
            pass


def get_device_dtype(device_idx: int = 0) -> tuple[torch.dtype, float, float]:
    """Pick a dtype (fp16 only on S4000-class cards) and report capability and memory."""
    properties = torch.musa.get_device_properties(device_idx)
    major, minor = torch.musa.get_device_capability()
    version_val = major + minor / 10.0
    mem_bytes = properties.total_memory
    mem_gb = mem_bytes / (1024**3) + 0.4
    device_name = properties.name
    dtype = torch.float32
    numbers_in_name = [int(n) for n in re.findall(r"\d+", device_name)]
    if any(num >= 4000 for num in numbers_in_name):
        dtype = torch.float16
    return dtype, version_val, mem_gb


def should_ddp(device_idx: int = 0) -> bool:
    """DDP is only enabled on S4000-class devices (see commit message)."""
    device_name = torch.musa.get_device_properties(device_idx).name
    numbers_in_name = [int(n) for n in re.findall(r"\d+", device_name)]
    return any(num >= 4000 for num in numbers_in_name)


DEVICE: Optional[torch.device] = get_device()
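Typical call sites elsewhere in the commit reduce to this pattern (a sketch, assuming torch_musa >= 2.1 is installed and one MUSA device is visible):

import torch
import musa_utils

device = "musa" if musa_utils.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
musa_utils.manual_seed(42)  # all helpers no-op safely when MUSA is absent
x = torch.randn(2, 3).to(device)
print(device, musa_utils.device_count())
musa_utils.empty_cache()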

@@ -13,6 +13,12 @@ from tools.asr.config import get_models
 from tools.asr.funasr_asr import only_asr
 from tools.my_utils import load_cudnn

+try:
+    import torch_musa
+    use_torch_musa = True
+except ImportError:
+    use_torch_musa = False
+
 # fmt: off
 language_code_list = [
     "af", "am", "ar", "as", "az",

@@ -86,6 +92,8 @@ def execute_asr(input_folder, output_folder, model_path, language, precision):
         language = None  # leave the language unset; the model outputs the most probable one
     print("loading faster whisper model:", model_path, model_path)
     device = "cuda" if torch.cuda.is_available() else "cpu"
+    if use_torch_musa:
+        device = "musa" if torch.musa.is_available() else "cpu"
     model = WhisperModel(model_path, device=device, compute_type=precision)
     input_file_names = os.listdir(input_folder)

@@ -10,6 +10,12 @@ import torch.nn as nn
 import yaml
 from tqdm import tqdm

+try:
+    import torch_musa
+    use_torch_musa = True
+except ImportError:
+    use_torch_musa = False
+
 warnings.filterwarnings("ignore")

@@ -135,7 +141,13 @@ class Roformer_Loader:
             window_middle[-fade_size:] *= fadeout
             window_middle[:fade_size] *= fadein

+        if use_torch_musa:
+            if torch.musa.is_available():
+                set_device = "musa"
+            else:
+                set_device = "cuda"
-        with torch.amp.autocast("cuda"):
+        with torch.amp.autocast(set_device):
             with torch.inference_mode():
                 if self.config["training"]["target_instrument"] is None:
                     req_shape = (len(self.config["training"]["instruments"]),) + tuple(mix.shape)
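As reconstructed above, `set_device` is only assigned when `torch_musa` imported successfully, so the CUDA-only path would raise a NameError. A defensive variant (a suggestion, not part of the commit) binds the default first:

set_device = "cuda"  # safe default so set_device is always bound
if use_torch_musa and torch.musa.is_available():
    set_device = "musa"
with torch.amp.autocast(set_device):
    ...  # inference as in Roformer_Loader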

@@ -18,6 +18,12 @@ from bsroformer import Roformer_Loader
 from mdxnet import MDXNetDereverb
 from vr import AudioPre, AudioPreDeEcho

+try:
+    import torch_musa
+    use_torch_musa = True
+except ImportError:
+    use_torch_musa = False
+
 weight_uvr5_root = "tools/uvr5/uvr5_weights"
 uvr5_names = []
 for name in os.listdir(weight_uvr5_root):

@@ -122,6 +128,9 @@ def uvr(model_name, inp_root, save_root_vocal, paths, save_root_ins, agg, format
             print("clean_empty_cache")
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            if use_torch_musa:
+                if torch.musa.is_available():
+                    torch.musa.empty_cache()
     yield "\n".join(infos)

@@ -92,6 +92,8 @@ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
 # os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # fall back to CPU when a step is unsupported on mps
 import gradio as gr

+import musa_utils
+
 n_cpu = cpu_count()

 set_gpu_numbers = GPU_INDEX

@@ -347,6 +349,8 @@ def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, so
         os.environ["is_half"] = str(is_half)
         os.environ["infer_ttswebui"] = str(webui_port_infer_tts)
         os.environ["is_share"] = str(is_share)
+        if musa_utils.is_available():
+            os.environ["_MUSA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_number))
         yield (
             process_info(process_name_tts, "opened"),
             {"__type__": "update", "visible": False},

@@ -629,6 +633,8 @@ def open1Bb(
     # data["version"]=version
     os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_numbers(gpu_numbers.replace("-", ",")))
+    if musa_utils.is_available():
+        os.environ["_MUSA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_numbers.replace("-", ",")))
     os.environ["hz"] = "25hz"
     tmp_config_path = "%s/tmp_s1.yaml" % tmp
     with open(tmp_config_path, "w") as f:

@@ -805,6 +811,8 @@ def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir):
                     "is_half": str(is_half),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/1-get-text.py' % python_exec
             print(cmd)

@@ -895,6 +903,8 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
                     "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py' % python_exec
             print(cmd)

@@ -917,6 +927,8 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
                     "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-sv.py' % python_exec
             print(cmd)

@@ -989,6 +1001,8 @@ def open1c(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, pretrained_s2G
                     "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/3-get-semantic.py' % python_exec
             print(cmd)

@@ -1089,6 +1103,8 @@ def open1abc(
                     "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/1-get-text.py' % python_exec
             print(cmd)

@@ -1136,6 +1152,8 @@ def open1abc(
                     "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py' % python_exec
             print(cmd)

@@ -1158,6 +1176,8 @@ def open1abc(
                     "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/2-get-sv.py' % python_exec
             print(cmd)

@@ -1198,6 +1218,8 @@ def open1abc(
                     "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
+            if musa_utils.is_available():
+                config.update({"_MUSA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part]))})
             os.environ.update(config)
             cmd = '"%s" -s GPT_SoVITS/prepare_datasets/3-get-semantic.py' % python_exec
             print(cmd)