This commit is contained in:
XXXXRT666 2025-10-20 00:07:27 +01:00
parent 41e44afa12
commit 620b7810e6
9 changed files with 109 additions and 141 deletions

View File

@ -20,11 +20,11 @@ class SampleProtocolMLX(Protocol):
def apply_repetition_penalty(logits: Array, previous_tokens: Array, repetition_penalty: float):
batch_idx = mx.arange(previous_tokens.shape[0])
selected_logits = logits[batch_idx, previous_tokens]
selected_logits = mx.take_along_axis(logits, previous_tokens, axis=1)
selected_logits = mx.where(
selected_logits < 0, selected_logits * repetition_penalty, selected_logits / repetition_penalty
)
logits[batch_idx, previous_tokens] = selected_logits
logits[batch_idx.reshape(-1, 1), previous_tokens] = selected_logits
return logits

View File

@ -117,7 +117,8 @@ class T2SSessionMLX:
self.input_pos = mx.zeros_like(self.prefill_len)
self.input_pos += self.prefill_len
self.input_pos = self.input_pos.squeeze(0) # 30% Performance Improvement
if bsz == 1:
self.input_pos = self.input_pos.squeeze(0) # 30% Performance Improvement in bsz=1
# EOS
self.completed = mx.array([False] * len(self.x)).astype(mx.bool_)

View File

@ -91,7 +91,7 @@ class T2SEngine(T2SEngineProtocol):
session.attn_mask,
session.kv_cache,
) # bs, seq_len, embed_dim
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
xy_dec = xy_dec[batch_idx, None, session.input_pos - 1]
if debug:
mx.eval(xy_dec)
else:
@ -120,7 +120,7 @@ class T2SEngine(T2SEngineProtocol):
mx.metal.stop_capture()
decoder.post_forward(idx, session)
logits = decoder.ar_predict_layer(xy_dec[:, -1])
logits = decoder.ar_predict_layer(xy_dec.squeeze(1))
session.input_pos += 1
if idx == 0:
@ -136,7 +136,7 @@ class T2SEngine(T2SEngineProtocol):
temperature=request.temperature,
)
session.y[batch_idx, session.y_len + idx] = samples
session.y[batch_idx.reshape(-1, 1), session.y_len + idx] = samples
if debug:
mx.eval(samples)
@ -168,9 +168,9 @@ class T2SEngine(T2SEngineProtocol):
logger.info(
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.shape[-1] for i in session.y_results].__str__().strip('[]')}"
)
logger.info(f"Infer Speed: {(idx + 1) / (time.perf_counter() - t1):.2f} token/s")
logger.info(f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s")
infer_time = time.perf_counter() - t1
infer_speed = (idx + 1) / infer_time
infer_speed = (idx + 1) * session.bsz / infer_time
break
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
@ -179,9 +179,9 @@ class T2SEngine(T2SEngineProtocol):
session.y_results[j] = session.y[[j], session.y_len : session.y_len + idx]
session.completed[j] = True
logger.error("Bad Full Prediction")
logger.info(f"Infer Speed: {(idx + 1) / (time.perf_counter() - t1):.2f} token/s")
logger.info(f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s")
infer_time = time.perf_counter() - t1
infer_speed = (idx + 1) / infer_time
infer_speed = (idx + 1) * session.bsz / infer_time
break
with timer("MLX.NextPos", debug=debug):

View File

@ -109,7 +109,8 @@ class SinePositionalEmbedding(nn.Module):
self.compute_pe(x.dtype)
assert self.pe is not None
pe_values = self.pe[:, : x.shape[-2]]
batch_size = x.shape[0]
pe_values = self.pe[mx.arange(batch_size), : x.shape[-2]]
return x * self.x_scale + self.alpha * pe_values

View File

@ -71,7 +71,7 @@ class T2SEngine(T2SEngineProtocol):
session.kv_cache = decoder.init_cache(session.bsz)
t1 = time.perf_counter()
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
xy_dec = xy_dec[batch_idx, None, session.input_pos - 1]
else:
if (
request.use_cuda_graph
@ -101,7 +101,7 @@ class T2SEngine(T2SEngineProtocol):
with torch.cuda.stream(session.stream) if session.stream is not None else contextlib.nullcontext():
decoder.post_forward(idx, session)
logits = decoder.ar_predict_layer(xy_dec[:, -1])
logits = decoder.ar_predict_layer(xy_dec.squeeze(1))
if idx == 0:
logits[:, -1] = float("-inf")
@ -115,7 +115,7 @@ class T2SEngine(T2SEngineProtocol):
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
)
session.y[batch_idx, session.y_len + idx] = samples
session.y[batch_idx.reshape(-1, 1), session.y_len + idx] = samples
session.input_pos.add_(1)
with torch_profiler.record("EOS"), timer("Torch.EOS", debug=debug):

View File

@ -5,7 +5,8 @@ import random
import time
import traceback
from copy import deepcopy
from typing import List, Tuple, Union
from pathlib import Path
from typing import Any
import ffmpeg
import librosa
@ -18,12 +19,12 @@ from peft import LoraConfig, get_peft_model
from tqdm import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer
from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
from GPT_SoVITS.Accelerate import MLX, PyTorch, T2SEngineProtocol, T2SRequest, backends
from GPT_SoVITS.BigVGAN.bigvgan import BigVGAN
from GPT_SoVITS.feature_extractor.cnhubert import CNHubert
from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch
from GPT_SoVITS.module.models import Generator, SynthesizerTrn, SynthesizerTrnV3
from GPT_SoVITS.process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
from GPT_SoVITS.process_ckpt import inspect_version
from GPT_SoVITS.sv import SV
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import splits
from GPT_SoVITS.TTS_infer_pack.TextPreprocessor import TextPreprocessor
@ -34,6 +35,7 @@ from tools.my_utils import DictToAttrRecursive
now_dir = os.getcwd()
resample_transform_dict = {}
v3v4set = {"v3", "v4"}
def resample(audio_tensor, sr0, sr1, device):
@ -170,12 +172,6 @@ def set_seed(seed: int):
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.enabled = True
# 开启后会影响精度
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
except:
pass
return seed
@ -254,7 +250,7 @@ class TTS_Config:
# "auto",#多语种启动切分识别语种
# "auto_yue",#多语种启动切分识别语种
def __init__(self, configs: Union[dict, str] = None):
def __init__(self, configs: dict | str = None):
# 设置默认配置文件路径
configs_base_path: str = "GPT_SoVITS/configs/"
os.makedirs(configs_base_path, exist_ok=True)
@ -312,7 +308,6 @@ class TTS_Config:
self.update_configs()
self.max_sec = None
self.hz: int = 50
self.semantic_frame_rate: str = "25hz"
self.segment_size: int = 20480
self.filter_length: int = 2048
@ -377,14 +372,14 @@ class TTS_Config:
class TTS:
def __init__(self, configs: Union[dict, str, TTS_Config]):
def __init__(self, configs: dict | str | TTS_Config):
if isinstance(configs, TTS_Config):
self.configs = configs
else:
self.configs: TTS_Config = TTS_Config(configs)
self.t2s_model: Text2SemanticLightningModule = None
self.vits_model: Union[SynthesizerTrn, SynthesizerTrnV3] = None
self.t2s_model: T2SEngineProtocol = None
self.vits_model: SynthesizerTrn | SynthesizerTrnV3 = None
self.bert_tokenizer: AutoTokenizer = None
self.bert_model: AutoModelForMaskedLM = None
self.cnhuhbert_model: CNHubert = None
@ -401,12 +396,6 @@ class TTS:
"overlapped_len": None,
}
self._init_models()
self.text_preprocessor: TextPreprocessor = TextPreprocessor(
self.bert_model, self.bert_tokenizer, self.configs.device
)
self.prompt_cache: dict = {
"ref_audio_path": None,
"prompt_semantic": None,
@ -422,6 +411,12 @@ class TTS:
self.stop_flag: bool = False
self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32
self._init_models()
self.text_preprocessor: TextPreprocessor = TextPreprocessor(
self.bert_model, self.bert_tokenizer, self.configs.device
)
def _init_models(
self,
):
@ -450,34 +445,17 @@ class TTS:
def init_vits_weights(self, weights_path: str):
self.configs.vits_weights_path = weights_path
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
model_version, version, if_lora_v3, hps, dict_s2 = inspect_version(weights_path)
if "Pro" in model_version:
self.init_sv_model()
path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
path_sovits: str = self.configs.default_configs[model_version]["vits_weights_path"]
if if_lora_v3 is True and os.path.exists(path_sovits) is False:
info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
info = path_sovits + i18n(f"SoVITS {model_version} 底模缺失,无法加载相应 LoRA 权重")
raise FileExistsError(info)
# dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
dict_s2 = load_sovits_new(weights_path)
hps = dict_s2["config"]
hps["model"]["semantic_frame_rate"] = "25hz"
if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
hps["model"]["version"] = "v2" # v3model,v2sybomls
elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
hps["model"]["version"] = "v1"
else:
hps["model"]["version"] = "v2"
version = hps["model"]["version"]
v3v4set = {"v3", "v4"}
if model_version not in v3v4set:
if "Pro" not in model_version:
model_version = version
else:
hps["model"]["version"] = model_version
else:
hps["model"]["version"] = model_version
self.configs.filter_length = hps["data"]["filter_length"]
self.configs.segment_size = hps["train"]["segment_size"]
@ -518,12 +496,14 @@ class TTS:
if if_lora_v3 is False:
print(
f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=True)}"
)
else:
print(
f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits)['weight'], strict=False)}"
)
print(f">> loading sovits_{model_version}spretrained_G")
dict_pretrain = torch.load(path_sovits)["weight"]
print(f">> loading sovits_{model_version}_lora{model_version}")
dict_pretrain.update(dict_s2["weight"])
state_dict = dict_pretrain
lora_rank = dict_s2["lora_rank"]
lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@ -531,12 +511,10 @@ class TTS:
lora_alpha=lora_rank,
init_lora_weights=True,
)
vits_model.cfm = get_peft_model(vits_model.cfm, lora_config)
print(
f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
)
vits_model.cfm = vits_model.cfm.merge_and_unload()
vits_model.cfm = get_peft_model(vits_model.cfm, lora_config) # type: ignore
vits_model.load_state_dict(state_dict, strict=True)
vits_model.cfm = vits_model.cfm.merge_and_unload() # pyright: ignore[reportAttributeAccessIssue, reportCallIssue]
vits_model.eval()
vits_model = vits_model.to(self.configs.device)
vits_model = vits_model.eval()
@ -547,21 +525,29 @@ class TTS:
self.configs.save_configs()
def init_t2s_weights(self, weights_path: str):
def init_t2s_weights(self, weights_path: str, ar_backend: str = backends[-1], quantization: Any = None):
print(f"Loading Text2Semantic weights from {weights_path}")
self.configs.t2s_weights_path = weights_path
self.configs.save_configs()
self.configs.hz = 50
dict_s1 = torch.load(weights_path, map_location=self.configs.device, weights_only=False)
config = dict_s1["config"]
self.configs.max_sec = config["data"]["max_sec"]
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
t2s_model = t2s_model.to(self.configs.device)
t2s_model = t2s_model.eval()
self.t2s_model = t2s_model
if self.configs.is_half and str(self.configs.device) != "cpu":
self.t2s_model = self.t2s_model.half()
if "mlx" in ar_backend.lower():
t2s_engine = MLX.T2SEngineMLX(
MLX.T2SEngineMLX.load_decoder(
Path(weights_path), backend=ar_backend, quantize_mode=quantization, max_batch_size=20
),
"mx.gpu" if self.configs.device.type != "cpu" else "mx.cpu",
dtype=self.precision,
)
else:
t2s_engine = PyTorch.T2SEngineTorch(
PyTorch.T2SEngineTorch.load_decoder(
Path(weights_path), backend=ar_backend, quantize_mode=quantization, max_batch_size=20
),
self.configs.device if not torch.mps.is_available() else torch.device("cpu"),
dtype=self.precision,
)
self.t2s_model = t2s_engine
def init_vocoder(self, version: str):
if version == "v3":
@ -782,7 +768,7 @@ class TTS:
prompt_semantic = codes[0, 0].to(self.configs.device)
self.prompt_cache["prompt_semantic"] = prompt_semantic
def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None):
def batch_sequences(self, sequences: list[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None):
seq = sequences[0]
ndim = seq.dim()
if axis < 0:
@ -923,11 +909,11 @@ class TTS:
Recovery the order of the audio according to the batch_index_list.
Args:
data (List[list(torch.Tensor)]): the out of order audio .
batch_index_list (List[list[int]]): the batch index list.
data (list[list(torch.Tensor)]): the out of order audio .
batch_index_list (list[list[int]]): the batch index list.
Returns:
list (List[torch.Tensor]): the data in the original order.
list (list[torch.Tensor]): the data in the original order.
"""
length = len(sum(batch_index_list, []))
_data = [None] * length
@ -964,7 +950,6 @@ class TTS:
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"return_fragment": False, # bool. step by step return the audio fragment.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
@ -992,40 +977,18 @@ class TTS:
batch_size = inputs.get("batch_size", 1)
batch_threshold = inputs.get("batch_threshold", 0.75)
speed_factor = inputs.get("speed_factor", 1.0)
split_bucket = inputs.get("split_bucket", True)
return_fragment = inputs.get("return_fragment", False)
fragment_interval = inputs.get("fragment_interval", 0.3)
seed = inputs.get("seed", -1)
seed = -1 if seed in ["", None] else seed
actual_seed = set_seed(seed)
set_seed(seed)
parallel_infer = inputs.get("parallel_infer", True)
repetition_penalty = inputs.get("repetition_penalty", 1.35)
sample_steps = inputs.get("sample_steps", 32)
super_sampling = inputs.get("super_sampling", False)
if parallel_infer:
print(i18n("并行推理模式已开启"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
else:
print(i18n("并行推理模式已关闭"))
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
if return_fragment:
print(i18n("分段返回模式已开启"))
if split_bucket:
split_bucket = False
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer):
print(i18n("分桶处理模式已开启"))
elif speed_factor != 1.0:
print(i18n("语速调节不支持分桶处理,已自动关闭分桶处理"))
split_bucket = False
elif self.configs.use_vocoder and parallel_infer:
print(i18n("当开启并行推理模式时SoVits V3/4模型不支持分桶处理已自动关闭分桶处理"))
split_bucket = False
else:
print(i18n("分桶处理模式已关闭"))
if fragment_interval < 0.01:
fragment_interval = 0.01
@ -1102,7 +1065,7 @@ class TTS:
prompt_data=self.prompt_cache if not no_prompt_text else None,
batch_size=batch_size,
threshold=batch_threshold,
split_bucket=split_bucket,
split_bucket=False,
device=self.configs.device,
precision=self.precision,
)
@ -1158,37 +1121,44 @@ class TTS:
if item is None:
continue
batch_phones: List[torch.LongTensor] = item["phones"]
batch_phones: list[torch.Tensor] = item["phones"]
# batch_phones:torch.LongTensor = item["phones"]
batch_phones_len: torch.LongTensor = item["phones_len"]
all_phoneme_ids: torch.LongTensor = item["all_phones"]
all_phoneme_lens: torch.LongTensor = item["all_phones_len"]
all_bert_features: torch.LongTensor = item["all_bert_features"]
batch_phones_len: torch.Tensor = item["phones_len"]
all_phoneme_ids: list[torch.Tensor] = item["all_phones"]
all_phoneme_lens: torch.Tensor = item["all_phones_len"]
all_bert_features: list[torch.Tensor] = item["all_bert_features"]
norm_text: str = item["norm_text"]
max_len = item["max_len"]
print(i18n("前端处理后的文本(每句):"), norm_text)
if no_prompt_text:
prompt = None
prompt = torch.zeros(1, 0).to(self.configs.device, self.precision)
else:
prompt = (
self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
)
prompt = self.prompt_cache["prompt_semantic"].to(self.configs.device).unsqueeze(0)
print(f"############ {i18n('预测语义Token')} ############")
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
t2s_request = T2SRequest(
all_phoneme_ids,
all_phoneme_lens,
prompt,
all_bert_features,
# prompt_phone_len=ph_offset,
valid_length=len(all_phoneme_ids),
top_k=top_k,
top_p=top_p,
temperature=temperature,
early_stop_num=self.configs.hz * self.configs.max_sec,
max_len=max_len,
repetition_penalty=repetition_penalty,
debug=os.environ.get("DEBUG", "0") == "1",
use_cuda_graph=torch.cuda.is_available(),
)
t2s_result = self.t2s_model.generate(t2s_request)
if t2s_result.exception is not None:
print(t2s_result.traceback)
raise RuntimeError()
pred_semantic_list = t2s_result.result
assert pred_semantic_list
pred_semantic_list = [semantic.squeeze(0) for semantic in pred_semantic_list]
t4 = time.perf_counter()
t_34 += t4 - t3
@ -1220,7 +1190,6 @@ class TTS:
if speed_factor == 1.0:
print(f"{i18n('并行合成中')}...")
# ## vits并行推理 method 2
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
upsample_rate = math.prod(self.vits_model.upsample_rates)
audio_frag_idx = [
pred_semantic_list[i].shape[0] * 2 * upsample_rate
@ -1231,7 +1200,7 @@ class TTS:
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
if self.is_v2pro != True:
if not self.is_v2pro:
_batch_audio_fragment = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
@ -1251,7 +1220,7 @@ class TTS:
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
if self.is_v2pro != True:
if not self.is_v2pro:
audio_fragment = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
@ -1308,7 +1277,7 @@ class TTS:
output_sr,
batch_index_list,
speed_factor,
split_bucket,
False,
fragment_interval,
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
)
@ -1340,14 +1309,14 @@ class TTS:
def audio_postprocess(
self,
audio: List[torch.Tensor],
audio: list[torch.Tensor],
sr: int,
batch_index_list: list = None,
speed_factor: float = 1.0,
split_bucket: bool = True,
split_bucket: bool = False,
fragment_interval: float = 0.3,
super_sampling: bool = False,
) -> Tuple[int, np.ndarray]:
) -> tuple[int, np.ndarray]:
zero_wav = torch.zeros(
int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device
)
@ -1459,12 +1428,12 @@ class TTS:
def using_vocoder_synthesis_batched_infer(
self,
idx_list: List[int],
semantic_tokens_list: List[torch.Tensor],
batch_phones: List[torch.Tensor],
idx_list: list[int],
semantic_tokens_list: list[torch.Tensor],
batch_phones: list[torch.Tensor],
speed: float = 1.0,
sample_steps: int = 32,
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
raw_entry = self.prompt_cache["refer_spec"][0]
@ -1574,7 +1543,7 @@ class TTS:
def sola_algorithm(
self,
audio_fragments: List[torch.Tensor],
audio_fragments: list[torch.Tensor],
overlap_len: int,
):
for i in range(len(audio_fragments) - 1):

View File

@ -365,7 +365,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
del vq_model.enc_q
if is_lora is False:
console.print(f">> loading sovits_{model_version}", vq_model.load_state_dict(dict_s2["weight"], strict=False))
console.print(f">> loading sovits_{model_version}", vq_model.load_state_dict(dict_s2["weight"]))
else:
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
console.print(f">> loading sovits_{model_version}spretrained_G")
@ -381,7 +381,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
init_lora_weights=True,
)
vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) # type: ignore
vq_model.load_state_dict(state_dict, strict=False)
vq_model.load_state_dict(state_dict)
vq_model.cfm = vq_model.cfm.merge_and_unload() # pyright: ignore[reportAttributeAccessIssue, reportCallIssue]
vq_model.eval()

View File

@ -138,9 +138,6 @@ is_share = args.share
infer_device = torch.device(args.device)
device = infer_device
if torch.mps.is_available():
device = torch.device("cpu")
dtype = get_dtype(device.index)
is_half = dtype == torch.float16
@ -152,7 +149,7 @@ SoVITS_names, GPT_names = get_weights_names(i18n)
gpt_path = str(args.gpt) or GPT_names[0][-1]
sovits_path = str(args.sovits) or SoVITS_names[0][-1]
cnhubert_base_path = str(args.cuhubert)
cnhubert_base_path = str(args.cnhubert)
bert_path = str(args.bert)
version = model_version = "v2"
@ -415,7 +412,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
with gr.Column():
with gr.Row(equal_height=True):
batch_size = gr.Slider(
minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
minimum=1, maximum=20, step=1, label=i18n("batch_size"), value=10, interactive=True
)
sample_steps = gr.Radio(
label=i18n("采样步数(仅对V3/4生效)"), value=32, choices=[4, 8, 16, 32, 64, 128], visible=True
@ -462,8 +459,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
split_bucket = gr.Checkbox(
label=i18n("数据分桶(并行推理时会降低一点计算量)"),
value=True,
interactive=True,
value=False,
interactive=False,
show_label=True,
)

View File

@ -323,7 +323,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
if hasattr(vq_model, "enc_q"):
del vq_model.enc_q
vq_model.load_state_dict(dict_s2["weight"], strict=False)
vq_model.load_state_dict(dict_s2["weight"])
vq_model = vq_model.to(infer_device, dtype)