diff --git a/GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py b/GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py index dcc02a63..8d81d311 100644 --- a/GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py +++ b/GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py @@ -20,11 +20,11 @@ class SampleProtocolMLX(Protocol): def apply_repetition_penalty(logits: Array, previous_tokens: Array, repetition_penalty: float): batch_idx = mx.arange(previous_tokens.shape[0]) - selected_logits = logits[batch_idx, previous_tokens] + selected_logits = mx.take_along_axis(logits, previous_tokens, axis=1) selected_logits = mx.where( selected_logits < 0, selected_logits * repetition_penalty, selected_logits / repetition_penalty ) - logits[batch_idx, previous_tokens] = selected_logits + logits[batch_idx.reshape(-1, 1), previous_tokens] = selected_logits return logits diff --git a/GPT_SoVITS/Accelerate/MLX/structs_mlx.py b/GPT_SoVITS/Accelerate/MLX/structs_mlx.py index 426309e6..60b655df 100644 --- a/GPT_SoVITS/Accelerate/MLX/structs_mlx.py +++ b/GPT_SoVITS/Accelerate/MLX/structs_mlx.py @@ -117,7 +117,8 @@ class T2SSessionMLX: self.input_pos = mx.zeros_like(self.prefill_len) self.input_pos += self.prefill_len - self.input_pos = self.input_pos.squeeze(0) # 30% Performance Improvement + if bsz == 1: + self.input_pos = self.input_pos.squeeze(0) # 30% Performance Improvement in bsz=1 # EOS self.completed = mx.array([False] * len(self.x)).astype(mx.bool_) diff --git a/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py b/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py index d21557cb..1b603ec1 100644 --- a/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py +++ b/GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py @@ -91,7 +91,7 @@ class T2SEngine(T2SEngineProtocol): session.attn_mask, session.kv_cache, ) # bs, seq_len, embed_dim - xy_dec = xy_dec[None, batch_idx, session.input_pos - 1] + xy_dec = xy_dec[batch_idx, None, session.input_pos - 1] if debug: mx.eval(xy_dec) else: @@ -120,7 +120,7 @@ class T2SEngine(T2SEngineProtocol): mx.metal.stop_capture() decoder.post_forward(idx, session) - logits = decoder.ar_predict_layer(xy_dec[:, -1]) + logits = decoder.ar_predict_layer(xy_dec.squeeze(1)) session.input_pos += 1 if idx == 0: @@ -136,7 +136,7 @@ class T2SEngine(T2SEngineProtocol): temperature=request.temperature, ) - session.y[batch_idx, session.y_len + idx] = samples + session.y[batch_idx.reshape(-1, 1), session.y_len + idx] = samples if debug: mx.eval(samples) @@ -168,9 +168,9 @@ class T2SEngine(T2SEngineProtocol): logger.info( f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.shape[-1] for i in session.y_results].__str__().strip('[]')}" ) - logger.info(f"Infer Speed: {(idx + 1) / (time.perf_counter() - t1):.2f} token/s") + logger.info(f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s") infer_time = time.perf_counter() - t1 - infer_speed = (idx + 1) / infer_time + infer_speed = (idx + 1) * session.bsz / infer_time break if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1: @@ -179,9 +179,9 @@ class T2SEngine(T2SEngineProtocol): session.y_results[j] = session.y[[j], session.y_len : session.y_len + idx] session.completed[j] = True logger.error("Bad Full Prediction") - logger.info(f"Infer Speed: {(idx + 1) / (time.perf_counter() - t1):.2f} token/s") + logger.info(f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s") infer_time = time.perf_counter() - t1 - infer_speed = (idx + 1) / infer_time + infer_speed = (idx + 1) * session.bsz / infer_time break with timer("MLX.NextPos", debug=debug): diff --git a/GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py b/GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py index 80de8f39..4e22940a 100644 --- a/GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py +++ b/GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py @@ -109,7 +109,8 @@ class SinePositionalEmbedding(nn.Module): self.compute_pe(x.dtype) assert self.pe is not None - pe_values = self.pe[:, : x.shape[-2]] + batch_size = x.shape[0] + pe_values = self.pe[mx.arange(batch_size), : x.shape[-2]] return x * self.x_scale + self.alpha * pe_values diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py index 89f18bc8..b92ab23b 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py +++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py @@ -71,7 +71,7 @@ class T2SEngine(T2SEngineProtocol): session.kv_cache = decoder.init_cache(session.bsz) t1 = time.perf_counter() xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask) - xy_dec = xy_dec[None, batch_idx, session.input_pos - 1] + xy_dec = xy_dec[batch_idx, None, session.input_pos - 1] else: if ( request.use_cuda_graph @@ -101,7 +101,7 @@ class T2SEngine(T2SEngineProtocol): with torch.cuda.stream(session.stream) if session.stream is not None else contextlib.nullcontext(): decoder.post_forward(idx, session) - logits = decoder.ar_predict_layer(xy_dec[:, -1]) + logits = decoder.ar_predict_layer(xy_dec.squeeze(1)) if idx == 0: logits[:, -1] = float("-inf") @@ -115,7 +115,7 @@ class T2SEngine(T2SEngineProtocol): repetition_penalty=request.repetition_penalty, temperature=request.temperature, ) - session.y[batch_idx, session.y_len + idx] = samples + session.y[batch_idx.reshape(-1, 1), session.y_len + idx] = samples session.input_pos.add_(1) with torch_profiler.record("EOS"), timer("Torch.EOS", debug=debug): diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index be2a4c64..350e7ba4 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -5,7 +5,8 @@ import random import time import traceback from copy import deepcopy -from typing import List, Tuple, Union +from pathlib import Path +from typing import Any import ffmpeg import librosa @@ -18,12 +19,12 @@ from peft import LoraConfig, get_peft_model from tqdm import tqdm from transformers import AutoModelForMaskedLM, AutoTokenizer -from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule +from GPT_SoVITS.Accelerate import MLX, PyTorch, T2SEngineProtocol, T2SRequest, backends from GPT_SoVITS.BigVGAN.bigvgan import BigVGAN from GPT_SoVITS.feature_extractor.cnhubert import CNHubert from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch from GPT_SoVITS.module.models import Generator, SynthesizerTrn, SynthesizerTrnV3 -from GPT_SoVITS.process_ckpt import get_sovits_version_from_path_fast, load_sovits_new +from GPT_SoVITS.process_ckpt import inspect_version from GPT_SoVITS.sv import SV from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import splits from GPT_SoVITS.TTS_infer_pack.TextPreprocessor import TextPreprocessor @@ -34,6 +35,7 @@ from tools.my_utils import DictToAttrRecursive now_dir = os.getcwd() resample_transform_dict = {} +v3v4set = {"v3", "v4"} def resample(audio_tensor, sr0, sr1, device): @@ -170,12 +172,6 @@ def set_seed(seed: int): if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) - # torch.backends.cudnn.deterministic = True - # torch.backends.cudnn.benchmark = False - # torch.backends.cudnn.enabled = True - # 开启后会影响精度 - torch.backends.cuda.matmul.allow_tf32 = False - torch.backends.cudnn.allow_tf32 = False except: pass return seed @@ -254,7 +250,7 @@ class TTS_Config: # "auto",#多语种启动切分识别语种 # "auto_yue",#多语种启动切分识别语种 - def __init__(self, configs: Union[dict, str] = None): + def __init__(self, configs: dict | str = None): # 设置默认配置文件路径 configs_base_path: str = "GPT_SoVITS/configs/" os.makedirs(configs_base_path, exist_ok=True) @@ -312,7 +308,6 @@ class TTS_Config: self.update_configs() self.max_sec = None - self.hz: int = 50 self.semantic_frame_rate: str = "25hz" self.segment_size: int = 20480 self.filter_length: int = 2048 @@ -377,14 +372,14 @@ class TTS_Config: class TTS: - def __init__(self, configs: Union[dict, str, TTS_Config]): + def __init__(self, configs: dict | str | TTS_Config): if isinstance(configs, TTS_Config): self.configs = configs else: self.configs: TTS_Config = TTS_Config(configs) - self.t2s_model: Text2SemanticLightningModule = None - self.vits_model: Union[SynthesizerTrn, SynthesizerTrnV3] = None + self.t2s_model: T2SEngineProtocol = None + self.vits_model: SynthesizerTrn | SynthesizerTrnV3 = None self.bert_tokenizer: AutoTokenizer = None self.bert_model: AutoModelForMaskedLM = None self.cnhuhbert_model: CNHubert = None @@ -401,12 +396,6 @@ class TTS: "overlapped_len": None, } - self._init_models() - - self.text_preprocessor: TextPreprocessor = TextPreprocessor( - self.bert_model, self.bert_tokenizer, self.configs.device - ) - self.prompt_cache: dict = { "ref_audio_path": None, "prompt_semantic": None, @@ -422,6 +411,12 @@ class TTS: self.stop_flag: bool = False self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32 + self._init_models() + + self.text_preprocessor: TextPreprocessor = TextPreprocessor( + self.bert_model, self.bert_tokenizer, self.configs.device + ) + def _init_models( self, ): @@ -450,34 +445,17 @@ class TTS: def init_vits_weights(self, weights_path: str): self.configs.vits_weights_path = weights_path - version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path) + model_version, version, if_lora_v3, hps, dict_s2 = inspect_version(weights_path) if "Pro" in model_version: self.init_sv_model() - path_sovits = self.configs.default_configs[model_version]["vits_weights_path"] + path_sovits: str = self.configs.default_configs[model_version]["vits_weights_path"] if if_lora_v3 is True and os.path.exists(path_sovits) is False: - info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version) + info = path_sovits + i18n(f"SoVITS {model_version} 底模缺失,无法加载相应 LoRA 权重") raise FileExistsError(info) # dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False) - dict_s2 = load_sovits_new(weights_path) - hps = dict_s2["config"] hps["model"]["semantic_frame_rate"] = "25hz" - if "enc_p.text_embedding.weight" not in dict_s2["weight"]: - hps["model"]["version"] = "v2" # v3model,v2sybomls - elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: - hps["model"]["version"] = "v1" - else: - hps["model"]["version"] = "v2" - version = hps["model"]["version"] - v3v4set = {"v3", "v4"} - if model_version not in v3v4set: - if "Pro" not in model_version: - model_version = version - else: - hps["model"]["version"] = model_version - else: - hps["model"]["version"] = model_version self.configs.filter_length = hps["data"]["filter_length"] self.configs.segment_size = hps["train"]["segment_size"] @@ -518,12 +496,14 @@ class TTS: if if_lora_v3 is False: print( - f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}" + f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=True)}" ) else: - print( - f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits)['weight'], strict=False)}" - ) + print(f">> loading sovits_{model_version}spretrained_G") + dict_pretrain = torch.load(path_sovits)["weight"] + print(f">> loading sovits_{model_version}_lora{model_version}") + dict_pretrain.update(dict_s2["weight"]) + state_dict = dict_pretrain lora_rank = dict_s2["lora_rank"] lora_config = LoraConfig( target_modules=["to_k", "to_q", "to_v", "to_out.0"], @@ -531,12 +511,10 @@ class TTS: lora_alpha=lora_rank, init_lora_weights=True, ) - vits_model.cfm = get_peft_model(vits_model.cfm, lora_config) - print( - f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}" - ) - - vits_model.cfm = vits_model.cfm.merge_and_unload() + vits_model.cfm = get_peft_model(vits_model.cfm, lora_config) # type: ignore + vits_model.load_state_dict(state_dict, strict=True) + vits_model.cfm = vits_model.cfm.merge_and_unload() # pyright: ignore[reportAttributeAccessIssue, reportCallIssue] + vits_model.eval() vits_model = vits_model.to(self.configs.device) vits_model = vits_model.eval() @@ -547,21 +525,29 @@ class TTS: self.configs.save_configs() - def init_t2s_weights(self, weights_path: str): + def init_t2s_weights(self, weights_path: str, ar_backend: str = backends[-1], quantization: Any = None): print(f"Loading Text2Semantic weights from {weights_path}") self.configs.t2s_weights_path = weights_path self.configs.save_configs() - self.configs.hz = 50 - dict_s1 = torch.load(weights_path, map_location=self.configs.device, weights_only=False) - config = dict_s1["config"] - self.configs.max_sec = config["data"]["max_sec"] - t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) - t2s_model.load_state_dict(dict_s1["weight"]) - t2s_model = t2s_model.to(self.configs.device) - t2s_model = t2s_model.eval() - self.t2s_model = t2s_model - if self.configs.is_half and str(self.configs.device) != "cpu": - self.t2s_model = self.t2s_model.half() + + if "mlx" in ar_backend.lower(): + t2s_engine = MLX.T2SEngineMLX( + MLX.T2SEngineMLX.load_decoder( + Path(weights_path), backend=ar_backend, quantize_mode=quantization, max_batch_size=20 + ), + "mx.gpu" if self.configs.device.type != "cpu" else "mx.cpu", + dtype=self.precision, + ) + else: + t2s_engine = PyTorch.T2SEngineTorch( + PyTorch.T2SEngineTorch.load_decoder( + Path(weights_path), backend=ar_backend, quantize_mode=quantization, max_batch_size=20 + ), + self.configs.device if not torch.mps.is_available() else torch.device("cpu"), + dtype=self.precision, + ) + + self.t2s_model = t2s_engine def init_vocoder(self, version: str): if version == "v3": @@ -782,7 +768,7 @@ class TTS: prompt_semantic = codes[0, 0].to(self.configs.device) self.prompt_cache["prompt_semantic"] = prompt_semantic - def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None): + def batch_sequences(self, sequences: list[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None): seq = sequences[0] ndim = seq.dim() if axis < 0: @@ -923,11 +909,11 @@ class TTS: Recovery the order of the audio according to the batch_index_list. Args: - data (List[list(torch.Tensor)]): the out of order audio . - batch_index_list (List[list[int]]): the batch index list. + data (list[list(torch.Tensor)]): the out of order audio . + batch_index_list (list[list[int]]): the batch index list. Returns: - list (List[torch.Tensor]): the data in the original order. + list (list[torch.Tensor]): the data in the original order. """ length = len(sum(batch_index_list, [])) _data = [None] * length @@ -964,7 +950,6 @@ class TTS: "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details. "batch_size": 1, # int. batch size for inference "batch_threshold": 0.75, # float. threshold for batch splitting. - "split_bucket: True, # bool. whether to split the batch into multiple buckets. "return_fragment": False, # bool. step by step return the audio fragment. "speed_factor":1.0, # float. control the speed of the synthesized audio. "fragment_interval":0.3, # float. to control the interval of the audio fragment. @@ -992,40 +977,18 @@ class TTS: batch_size = inputs.get("batch_size", 1) batch_threshold = inputs.get("batch_threshold", 0.75) speed_factor = inputs.get("speed_factor", 1.0) - split_bucket = inputs.get("split_bucket", True) return_fragment = inputs.get("return_fragment", False) fragment_interval = inputs.get("fragment_interval", 0.3) seed = inputs.get("seed", -1) seed = -1 if seed in ["", None] else seed - actual_seed = set_seed(seed) + set_seed(seed) parallel_infer = inputs.get("parallel_infer", True) repetition_penalty = inputs.get("repetition_penalty", 1.35) sample_steps = inputs.get("sample_steps", 32) super_sampling = inputs.get("super_sampling", False) - if parallel_infer: - print(i18n("并行推理模式已开启")) - self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer - else: - print(i18n("并行推理模式已关闭")) - self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched - if return_fragment: print(i18n("分段返回模式已开启")) - if split_bucket: - split_bucket = False - print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理")) - - if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer): - print(i18n("分桶处理模式已开启")) - elif speed_factor != 1.0: - print(i18n("语速调节不支持分桶处理,已自动关闭分桶处理")) - split_bucket = False - elif self.configs.use_vocoder and parallel_infer: - print(i18n("当开启并行推理模式时,SoVits V3/4模型不支持分桶处理,已自动关闭分桶处理")) - split_bucket = False - else: - print(i18n("分桶处理模式已关闭")) if fragment_interval < 0.01: fragment_interval = 0.01 @@ -1102,7 +1065,7 @@ class TTS: prompt_data=self.prompt_cache if not no_prompt_text else None, batch_size=batch_size, threshold=batch_threshold, - split_bucket=split_bucket, + split_bucket=False, device=self.configs.device, precision=self.precision, ) @@ -1158,37 +1121,44 @@ class TTS: if item is None: continue - batch_phones: List[torch.LongTensor] = item["phones"] + batch_phones: list[torch.Tensor] = item["phones"] # batch_phones:torch.LongTensor = item["phones"] - batch_phones_len: torch.LongTensor = item["phones_len"] - all_phoneme_ids: torch.LongTensor = item["all_phones"] - all_phoneme_lens: torch.LongTensor = item["all_phones_len"] - all_bert_features: torch.LongTensor = item["all_bert_features"] + batch_phones_len: torch.Tensor = item["phones_len"] + all_phoneme_ids: list[torch.Tensor] = item["all_phones"] + all_phoneme_lens: torch.Tensor = item["all_phones_len"] + all_bert_features: list[torch.Tensor] = item["all_bert_features"] norm_text: str = item["norm_text"] max_len = item["max_len"] print(i18n("前端处理后的文本(每句):"), norm_text) if no_prompt_text: - prompt = None + prompt = torch.zeros(1, 0).to(self.configs.device, self.precision) else: - prompt = ( - self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device) - ) + prompt = self.prompt_cache["prompt_semantic"].to(self.configs.device).unsqueeze(0) print(f"############ {i18n('预测语义Token')} ############") - pred_semantic_list, idx_list = self.t2s_model.model.infer_panel( + t2s_request = T2SRequest( all_phoneme_ids, all_phoneme_lens, prompt, all_bert_features, - # prompt_phone_len=ph_offset, + valid_length=len(all_phoneme_ids), top_k=top_k, top_p=top_p, temperature=temperature, - early_stop_num=self.configs.hz * self.configs.max_sec, - max_len=max_len, - repetition_penalty=repetition_penalty, + debug=os.environ.get("DEBUG", "0") == "1", + use_cuda_graph=torch.cuda.is_available(), ) + t2s_result = self.t2s_model.generate(t2s_request) + + if t2s_result.exception is not None: + print(t2s_result.traceback) + raise RuntimeError() + + pred_semantic_list = t2s_result.result + assert pred_semantic_list + pred_semantic_list = [semantic.squeeze(0) for semantic in pred_semantic_list] + t4 = time.perf_counter() t_34 += t4 - t3 @@ -1220,7 +1190,6 @@ class TTS: if speed_factor == 1.0: print(f"{i18n('并行合成中')}...") # ## vits并行推理 method 2 - pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)] upsample_rate = math.prod(self.vits_model.upsample_rates) audio_frag_idx = [ pred_semantic_list[i].shape[0] * 2 * upsample_rate @@ -1231,7 +1200,7 @@ class TTS: torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) ) _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) - if self.is_v2pro != True: + if not self.is_v2pro: _batch_audio_fragment = self.vits_model.decode( all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor ).detach()[0, 0, :] @@ -1251,7 +1220,7 @@ class TTS: _pred_semantic = ( pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) ) # .unsqueeze(0)#mq要多unsqueeze一次 - if self.is_v2pro != True: + if not self.is_v2pro: audio_fragment = self.vits_model.decode( _pred_semantic, phones, refer_audio_spec, speed=speed_factor ).detach()[0, 0, :] @@ -1308,7 +1277,7 @@ class TTS: output_sr, batch_index_list, speed_factor, - split_bucket, + False, fragment_interval, super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False, ) @@ -1340,14 +1309,14 @@ class TTS: def audio_postprocess( self, - audio: List[torch.Tensor], + audio: list[torch.Tensor], sr: int, batch_index_list: list = None, speed_factor: float = 1.0, - split_bucket: bool = True, + split_bucket: bool = False, fragment_interval: float = 0.3, super_sampling: bool = False, - ) -> Tuple[int, np.ndarray]: + ) -> tuple[int, np.ndarray]: zero_wav = torch.zeros( int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device ) @@ -1459,12 +1428,12 @@ class TTS: def using_vocoder_synthesis_batched_infer( self, - idx_list: List[int], - semantic_tokens_list: List[torch.Tensor], - batch_phones: List[torch.Tensor], + idx_list: list[int], + semantic_tokens_list: list[torch.Tensor], + batch_phones: list[torch.Tensor], speed: float = 1.0, sample_steps: int = 32, - ) -> List[torch.Tensor]: + ) -> list[torch.Tensor]: prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device) prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device) raw_entry = self.prompt_cache["refer_spec"][0] @@ -1574,7 +1543,7 @@ class TTS: def sola_algorithm( self, - audio_fragments: List[torch.Tensor], + audio_fragments: list[torch.Tensor], overlap_len: int, ): for i in range(len(audio_fragments) - 1): diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py index 904affbb..4e05e5e0 100644 --- a/GPT_SoVITS/inference_webui.py +++ b/GPT_SoVITS/inference_webui.py @@ -365,7 +365,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None) del vq_model.enc_q if is_lora is False: - console.print(f">> loading sovits_{model_version}", vq_model.load_state_dict(dict_s2["weight"], strict=False)) + console.print(f">> loading sovits_{model_version}", vq_model.load_state_dict(dict_s2["weight"])) else: path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 console.print(f">> loading sovits_{model_version}spretrained_G") @@ -381,7 +381,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None) init_lora_weights=True, ) vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) # type: ignore - vq_model.load_state_dict(state_dict, strict=False) + vq_model.load_state_dict(state_dict) vq_model.cfm = vq_model.cfm.merge_and_unload() # pyright: ignore[reportAttributeAccessIssue, reportCallIssue] vq_model.eval() diff --git a/GPT_SoVITS/inference_webui_fast.py b/GPT_SoVITS/inference_webui_fast.py index bbf0227f..6d3ca208 100644 --- a/GPT_SoVITS/inference_webui_fast.py +++ b/GPT_SoVITS/inference_webui_fast.py @@ -138,9 +138,6 @@ is_share = args.share infer_device = torch.device(args.device) device = infer_device -if torch.mps.is_available(): - device = torch.device("cpu") - dtype = get_dtype(device.index) is_half = dtype == torch.float16 @@ -152,7 +149,7 @@ SoVITS_names, GPT_names = get_weights_names(i18n) gpt_path = str(args.gpt) or GPT_names[0][-1] sovits_path = str(args.sovits) or SoVITS_names[0][-1] -cnhubert_base_path = str(args.cuhubert) +cnhubert_base_path = str(args.cnhubert) bert_path = str(args.bert) version = model_version = "v2" @@ -415,7 +412,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css with gr.Column(): with gr.Row(equal_height=True): batch_size = gr.Slider( - minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True + minimum=1, maximum=20, step=1, label=i18n("batch_size"), value=10, interactive=True ) sample_steps = gr.Radio( label=i18n("采样步数(仅对V3/4生效)"), value=32, choices=[4, 8, 16, 32, 64, 128], visible=True @@ -462,8 +459,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True) split_bucket = gr.Checkbox( label=i18n("数据分桶(并行推理时会降低一点计算量)"), - value=True, - interactive=True, + value=False, + interactive=False, show_label=True, ) diff --git a/test.py b/test.py index fd4b7127..894d3db4 100644 --- a/test.py +++ b/test.py @@ -323,7 +323,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None) if hasattr(vq_model, "enc_q"): del vq_model.enc_q - vq_model.load_state_dict(dict_s2["weight"], strict=False) + vq_model.load_state_dict(dict_s2["weight"]) vq_model = vq_model.to(infer_device, dtype)