mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-07-02 20:10:27 +08:00
.
This commit is contained in:
parent
41e44afa12
commit
620b7810e6
@ -20,11 +20,11 @@ class SampleProtocolMLX(Protocol):
|
||||
|
||||
def apply_repetition_penalty(logits: Array, previous_tokens: Array, repetition_penalty: float):
|
||||
batch_idx = mx.arange(previous_tokens.shape[0])
|
||||
selected_logits = logits[batch_idx, previous_tokens]
|
||||
selected_logits = mx.take_along_axis(logits, previous_tokens, axis=1)
|
||||
selected_logits = mx.where(
|
||||
selected_logits < 0, selected_logits * repetition_penalty, selected_logits / repetition_penalty
|
||||
)
|
||||
logits[batch_idx, previous_tokens] = selected_logits
|
||||
logits[batch_idx.reshape(-1, 1), previous_tokens] = selected_logits
|
||||
return logits
|
||||
|
||||
|
||||
|
||||
@ -117,7 +117,8 @@ class T2SSessionMLX:
|
||||
|
||||
self.input_pos = mx.zeros_like(self.prefill_len)
|
||||
self.input_pos += self.prefill_len
|
||||
self.input_pos = self.input_pos.squeeze(0) # 30% Performance Improvement
|
||||
if bsz == 1:
|
||||
self.input_pos = self.input_pos.squeeze(0) # 30% Performance Improvement in bsz=1
|
||||
|
||||
# EOS
|
||||
self.completed = mx.array([False] * len(self.x)).astype(mx.bool_)
|
||||
|
||||
@ -91,7 +91,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
session.attn_mask,
|
||||
session.kv_cache,
|
||||
) # bs, seq_len, embed_dim
|
||||
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
|
||||
xy_dec = xy_dec[batch_idx, None, session.input_pos - 1]
|
||||
if debug:
|
||||
mx.eval(xy_dec)
|
||||
else:
|
||||
@ -120,7 +120,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
mx.metal.stop_capture()
|
||||
|
||||
decoder.post_forward(idx, session)
|
||||
logits = decoder.ar_predict_layer(xy_dec[:, -1])
|
||||
logits = decoder.ar_predict_layer(xy_dec.squeeze(1))
|
||||
session.input_pos += 1
|
||||
|
||||
if idx == 0:
|
||||
@ -136,7 +136,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
temperature=request.temperature,
|
||||
)
|
||||
|
||||
session.y[batch_idx, session.y_len + idx] = samples
|
||||
session.y[batch_idx.reshape(-1, 1), session.y_len + idx] = samples
|
||||
|
||||
if debug:
|
||||
mx.eval(samples)
|
||||
@ -168,9 +168,9 @@ class T2SEngine(T2SEngineProtocol):
|
||||
logger.info(
|
||||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.shape[-1] for i in session.y_results].__str__().strip('[]')}"
|
||||
)
|
||||
logger.info(f"Infer Speed: {(idx + 1) / (time.perf_counter() - t1):.2f} token/s")
|
||||
logger.info(f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx + 1) / infer_time
|
||||
infer_speed = (idx + 1) * session.bsz / infer_time
|
||||
break
|
||||
|
||||
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
|
||||
@ -179,9 +179,9 @@ class T2SEngine(T2SEngineProtocol):
|
||||
session.y_results[j] = session.y[[j], session.y_len : session.y_len + idx]
|
||||
session.completed[j] = True
|
||||
logger.error("Bad Full Prediction")
|
||||
logger.info(f"Infer Speed: {(idx + 1) / (time.perf_counter() - t1):.2f} token/s")
|
||||
logger.info(f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx + 1) / infer_time
|
||||
infer_speed = (idx + 1) * session.bsz / infer_time
|
||||
break
|
||||
|
||||
with timer("MLX.NextPos", debug=debug):
|
||||
|
||||
@ -109,7 +109,8 @@ class SinePositionalEmbedding(nn.Module):
|
||||
self.compute_pe(x.dtype)
|
||||
assert self.pe is not None
|
||||
|
||||
pe_values = self.pe[:, : x.shape[-2]]
|
||||
batch_size = x.shape[0]
|
||||
pe_values = self.pe[mx.arange(batch_size), : x.shape[-2]]
|
||||
return x * self.x_scale + self.alpha * pe_values
|
||||
|
||||
|
||||
|
||||
@ -71,7 +71,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
session.kv_cache = decoder.init_cache(session.bsz)
|
||||
t1 = time.perf_counter()
|
||||
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
|
||||
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
|
||||
xy_dec = xy_dec[batch_idx, None, session.input_pos - 1]
|
||||
else:
|
||||
if (
|
||||
request.use_cuda_graph
|
||||
@ -101,7 +101,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
|
||||
with torch.cuda.stream(session.stream) if session.stream is not None else contextlib.nullcontext():
|
||||
decoder.post_forward(idx, session)
|
||||
logits = decoder.ar_predict_layer(xy_dec[:, -1])
|
||||
logits = decoder.ar_predict_layer(xy_dec.squeeze(1))
|
||||
|
||||
if idx == 0:
|
||||
logits[:, -1] = float("-inf")
|
||||
@ -115,7 +115,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
)
|
||||
session.y[batch_idx, session.y_len + idx] = samples
|
||||
session.y[batch_idx.reshape(-1, 1), session.y_len + idx] = samples
|
||||
session.input_pos.add_(1)
|
||||
|
||||
with torch_profiler.record("EOS"), timer("Torch.EOS", debug=debug):
|
||||
|
||||
@ -5,7 +5,8 @@ import random
|
||||
import time
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from typing import List, Tuple, Union
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import ffmpeg
|
||||
import librosa
|
||||
@ -18,12 +19,12 @@ from peft import LoraConfig, get_peft_model
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from GPT_SoVITS.Accelerate import MLX, PyTorch, T2SEngineProtocol, T2SRequest, backends
|
||||
from GPT_SoVITS.BigVGAN.bigvgan import BigVGAN
|
||||
from GPT_SoVITS.feature_extractor.cnhubert import CNHubert
|
||||
from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch
|
||||
from GPT_SoVITS.module.models import Generator, SynthesizerTrn, SynthesizerTrnV3
|
||||
from GPT_SoVITS.process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
|
||||
from GPT_SoVITS.process_ckpt import inspect_version
|
||||
from GPT_SoVITS.sv import SV
|
||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import splits
|
||||
from GPT_SoVITS.TTS_infer_pack.TextPreprocessor import TextPreprocessor
|
||||
@ -34,6 +35,7 @@ from tools.my_utils import DictToAttrRecursive
|
||||
now_dir = os.getcwd()
|
||||
|
||||
resample_transform_dict = {}
|
||||
v3v4set = {"v3", "v4"}
|
||||
|
||||
|
||||
def resample(audio_tensor, sr0, sr1, device):
|
||||
@ -170,12 +172,6 @@ def set_seed(seed: int):
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.backends.cudnn.benchmark = False
|
||||
# torch.backends.cudnn.enabled = True
|
||||
# 开启后会影响精度
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
torch.backends.cudnn.allow_tf32 = False
|
||||
except:
|
||||
pass
|
||||
return seed
|
||||
@ -254,7 +250,7 @@ class TTS_Config:
|
||||
# "auto",#多语种启动切分识别语种
|
||||
# "auto_yue",#多语种启动切分识别语种
|
||||
|
||||
def __init__(self, configs: Union[dict, str] = None):
|
||||
def __init__(self, configs: dict | str = None):
|
||||
# 设置默认配置文件路径
|
||||
configs_base_path: str = "GPT_SoVITS/configs/"
|
||||
os.makedirs(configs_base_path, exist_ok=True)
|
||||
@ -312,7 +308,6 @@ class TTS_Config:
|
||||
self.update_configs()
|
||||
|
||||
self.max_sec = None
|
||||
self.hz: int = 50
|
||||
self.semantic_frame_rate: str = "25hz"
|
||||
self.segment_size: int = 20480
|
||||
self.filter_length: int = 2048
|
||||
@ -377,14 +372,14 @@ class TTS_Config:
|
||||
|
||||
|
||||
class TTS:
|
||||
def __init__(self, configs: Union[dict, str, TTS_Config]):
|
||||
def __init__(self, configs: dict | str | TTS_Config):
|
||||
if isinstance(configs, TTS_Config):
|
||||
self.configs = configs
|
||||
else:
|
||||
self.configs: TTS_Config = TTS_Config(configs)
|
||||
|
||||
self.t2s_model: Text2SemanticLightningModule = None
|
||||
self.vits_model: Union[SynthesizerTrn, SynthesizerTrnV3] = None
|
||||
self.t2s_model: T2SEngineProtocol = None
|
||||
self.vits_model: SynthesizerTrn | SynthesizerTrnV3 = None
|
||||
self.bert_tokenizer: AutoTokenizer = None
|
||||
self.bert_model: AutoModelForMaskedLM = None
|
||||
self.cnhuhbert_model: CNHubert = None
|
||||
@ -401,12 +396,6 @@ class TTS:
|
||||
"overlapped_len": None,
|
||||
}
|
||||
|
||||
self._init_models()
|
||||
|
||||
self.text_preprocessor: TextPreprocessor = TextPreprocessor(
|
||||
self.bert_model, self.bert_tokenizer, self.configs.device
|
||||
)
|
||||
|
||||
self.prompt_cache: dict = {
|
||||
"ref_audio_path": None,
|
||||
"prompt_semantic": None,
|
||||
@ -422,6 +411,12 @@ class TTS:
|
||||
self.stop_flag: bool = False
|
||||
self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32
|
||||
|
||||
self._init_models()
|
||||
|
||||
self.text_preprocessor: TextPreprocessor = TextPreprocessor(
|
||||
self.bert_model, self.bert_tokenizer, self.configs.device
|
||||
)
|
||||
|
||||
def _init_models(
|
||||
self,
|
||||
):
|
||||
@ -450,34 +445,17 @@ class TTS:
|
||||
|
||||
def init_vits_weights(self, weights_path: str):
|
||||
self.configs.vits_weights_path = weights_path
|
||||
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
|
||||
model_version, version, if_lora_v3, hps, dict_s2 = inspect_version(weights_path)
|
||||
if "Pro" in model_version:
|
||||
self.init_sv_model()
|
||||
path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
|
||||
path_sovits: str = self.configs.default_configs[model_version]["vits_weights_path"]
|
||||
|
||||
if if_lora_v3 is True and os.path.exists(path_sovits) is False:
|
||||
info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
|
||||
info = path_sovits + i18n(f"SoVITS {model_version} 底模缺失,无法加载相应 LoRA 权重")
|
||||
raise FileExistsError(info)
|
||||
|
||||
# dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
|
||||
dict_s2 = load_sovits_new(weights_path)
|
||||
hps = dict_s2["config"]
|
||||
hps["model"]["semantic_frame_rate"] = "25hz"
|
||||
if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
|
||||
hps["model"]["version"] = "v2" # v3model,v2sybomls
|
||||
elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
||||
hps["model"]["version"] = "v1"
|
||||
else:
|
||||
hps["model"]["version"] = "v2"
|
||||
version = hps["model"]["version"]
|
||||
v3v4set = {"v3", "v4"}
|
||||
if model_version not in v3v4set:
|
||||
if "Pro" not in model_version:
|
||||
model_version = version
|
||||
else:
|
||||
hps["model"]["version"] = model_version
|
||||
else:
|
||||
hps["model"]["version"] = model_version
|
||||
|
||||
self.configs.filter_length = hps["data"]["filter_length"]
|
||||
self.configs.segment_size = hps["train"]["segment_size"]
|
||||
@ -518,12 +496,14 @@ class TTS:
|
||||
|
||||
if if_lora_v3 is False:
|
||||
print(
|
||||
f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
|
||||
f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=True)}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits)['weight'], strict=False)}"
|
||||
)
|
||||
print(f">> loading sovits_{model_version}spretrained_G")
|
||||
dict_pretrain = torch.load(path_sovits)["weight"]
|
||||
print(f">> loading sovits_{model_version}_lora{model_version}")
|
||||
dict_pretrain.update(dict_s2["weight"])
|
||||
state_dict = dict_pretrain
|
||||
lora_rank = dict_s2["lora_rank"]
|
||||
lora_config = LoraConfig(
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
@ -531,12 +511,10 @@ class TTS:
|
||||
lora_alpha=lora_rank,
|
||||
init_lora_weights=True,
|
||||
)
|
||||
vits_model.cfm = get_peft_model(vits_model.cfm, lora_config)
|
||||
print(
|
||||
f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
|
||||
)
|
||||
|
||||
vits_model.cfm = vits_model.cfm.merge_and_unload()
|
||||
vits_model.cfm = get_peft_model(vits_model.cfm, lora_config) # type: ignore
|
||||
vits_model.load_state_dict(state_dict, strict=True)
|
||||
vits_model.cfm = vits_model.cfm.merge_and_unload() # pyright: ignore[reportAttributeAccessIssue, reportCallIssue]
|
||||
vits_model.eval()
|
||||
|
||||
vits_model = vits_model.to(self.configs.device)
|
||||
vits_model = vits_model.eval()
|
||||
@ -547,21 +525,29 @@ class TTS:
|
||||
|
||||
self.configs.save_configs()
|
||||
|
||||
def init_t2s_weights(self, weights_path: str):
|
||||
def init_t2s_weights(self, weights_path: str, ar_backend: str = backends[-1], quantization: Any = None):
|
||||
print(f"Loading Text2Semantic weights from {weights_path}")
|
||||
self.configs.t2s_weights_path = weights_path
|
||||
self.configs.save_configs()
|
||||
self.configs.hz = 50
|
||||
dict_s1 = torch.load(weights_path, map_location=self.configs.device, weights_only=False)
|
||||
config = dict_s1["config"]
|
||||
self.configs.max_sec = config["data"]["max_sec"]
|
||||
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
|
||||
t2s_model.load_state_dict(dict_s1["weight"])
|
||||
t2s_model = t2s_model.to(self.configs.device)
|
||||
t2s_model = t2s_model.eval()
|
||||
self.t2s_model = t2s_model
|
||||
if self.configs.is_half and str(self.configs.device) != "cpu":
|
||||
self.t2s_model = self.t2s_model.half()
|
||||
|
||||
if "mlx" in ar_backend.lower():
|
||||
t2s_engine = MLX.T2SEngineMLX(
|
||||
MLX.T2SEngineMLX.load_decoder(
|
||||
Path(weights_path), backend=ar_backend, quantize_mode=quantization, max_batch_size=20
|
||||
),
|
||||
"mx.gpu" if self.configs.device.type != "cpu" else "mx.cpu",
|
||||
dtype=self.precision,
|
||||
)
|
||||
else:
|
||||
t2s_engine = PyTorch.T2SEngineTorch(
|
||||
PyTorch.T2SEngineTorch.load_decoder(
|
||||
Path(weights_path), backend=ar_backend, quantize_mode=quantization, max_batch_size=20
|
||||
),
|
||||
self.configs.device if not torch.mps.is_available() else torch.device("cpu"),
|
||||
dtype=self.precision,
|
||||
)
|
||||
|
||||
self.t2s_model = t2s_engine
|
||||
|
||||
def init_vocoder(self, version: str):
|
||||
if version == "v3":
|
||||
@ -782,7 +768,7 @@ class TTS:
|
||||
prompt_semantic = codes[0, 0].to(self.configs.device)
|
||||
self.prompt_cache["prompt_semantic"] = prompt_semantic
|
||||
|
||||
def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None):
|
||||
def batch_sequences(self, sequences: list[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length: int = None):
|
||||
seq = sequences[0]
|
||||
ndim = seq.dim()
|
||||
if axis < 0:
|
||||
@ -923,11 +909,11 @@ class TTS:
|
||||
Recovery the order of the audio according to the batch_index_list.
|
||||
|
||||
Args:
|
||||
data (List[list(torch.Tensor)]): the out of order audio .
|
||||
batch_index_list (List[list[int]]): the batch index list.
|
||||
data (list[list(torch.Tensor)]): the out of order audio .
|
||||
batch_index_list (list[list[int]]): the batch index list.
|
||||
|
||||
Returns:
|
||||
list (List[torch.Tensor]): the data in the original order.
|
||||
list (list[torch.Tensor]): the data in the original order.
|
||||
"""
|
||||
length = len(sum(batch_index_list, []))
|
||||
_data = [None] * length
|
||||
@ -964,7 +950,6 @@ class TTS:
|
||||
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
|
||||
"batch_size": 1, # int. batch size for inference
|
||||
"batch_threshold": 0.75, # float. threshold for batch splitting.
|
||||
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
|
||||
"return_fragment": False, # bool. step by step return the audio fragment.
|
||||
"speed_factor":1.0, # float. control the speed of the synthesized audio.
|
||||
"fragment_interval":0.3, # float. to control the interval of the audio fragment.
|
||||
@ -992,40 +977,18 @@ class TTS:
|
||||
batch_size = inputs.get("batch_size", 1)
|
||||
batch_threshold = inputs.get("batch_threshold", 0.75)
|
||||
speed_factor = inputs.get("speed_factor", 1.0)
|
||||
split_bucket = inputs.get("split_bucket", True)
|
||||
return_fragment = inputs.get("return_fragment", False)
|
||||
fragment_interval = inputs.get("fragment_interval", 0.3)
|
||||
seed = inputs.get("seed", -1)
|
||||
seed = -1 if seed in ["", None] else seed
|
||||
actual_seed = set_seed(seed)
|
||||
set_seed(seed)
|
||||
parallel_infer = inputs.get("parallel_infer", True)
|
||||
repetition_penalty = inputs.get("repetition_penalty", 1.35)
|
||||
sample_steps = inputs.get("sample_steps", 32)
|
||||
super_sampling = inputs.get("super_sampling", False)
|
||||
|
||||
if parallel_infer:
|
||||
print(i18n("并行推理模式已开启"))
|
||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_batch_infer
|
||||
else:
|
||||
print(i18n("并行推理模式已关闭"))
|
||||
self.t2s_model.model.infer_panel = self.t2s_model.model.infer_panel_naive_batched
|
||||
|
||||
if return_fragment:
|
||||
print(i18n("分段返回模式已开启"))
|
||||
if split_bucket:
|
||||
split_bucket = False
|
||||
print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
|
||||
|
||||
if split_bucket and speed_factor == 1.0 and not (self.configs.use_vocoder and parallel_infer):
|
||||
print(i18n("分桶处理模式已开启"))
|
||||
elif speed_factor != 1.0:
|
||||
print(i18n("语速调节不支持分桶处理,已自动关闭分桶处理"))
|
||||
split_bucket = False
|
||||
elif self.configs.use_vocoder and parallel_infer:
|
||||
print(i18n("当开启并行推理模式时,SoVits V3/4模型不支持分桶处理,已自动关闭分桶处理"))
|
||||
split_bucket = False
|
||||
else:
|
||||
print(i18n("分桶处理模式已关闭"))
|
||||
|
||||
if fragment_interval < 0.01:
|
||||
fragment_interval = 0.01
|
||||
@ -1102,7 +1065,7 @@ class TTS:
|
||||
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
||||
batch_size=batch_size,
|
||||
threshold=batch_threshold,
|
||||
split_bucket=split_bucket,
|
||||
split_bucket=False,
|
||||
device=self.configs.device,
|
||||
precision=self.precision,
|
||||
)
|
||||
@ -1158,37 +1121,44 @@ class TTS:
|
||||
if item is None:
|
||||
continue
|
||||
|
||||
batch_phones: List[torch.LongTensor] = item["phones"]
|
||||
batch_phones: list[torch.Tensor] = item["phones"]
|
||||
# batch_phones:torch.LongTensor = item["phones"]
|
||||
batch_phones_len: torch.LongTensor = item["phones_len"]
|
||||
all_phoneme_ids: torch.LongTensor = item["all_phones"]
|
||||
all_phoneme_lens: torch.LongTensor = item["all_phones_len"]
|
||||
all_bert_features: torch.LongTensor = item["all_bert_features"]
|
||||
batch_phones_len: torch.Tensor = item["phones_len"]
|
||||
all_phoneme_ids: list[torch.Tensor] = item["all_phones"]
|
||||
all_phoneme_lens: torch.Tensor = item["all_phones_len"]
|
||||
all_bert_features: list[torch.Tensor] = item["all_bert_features"]
|
||||
norm_text: str = item["norm_text"]
|
||||
max_len = item["max_len"]
|
||||
|
||||
print(i18n("前端处理后的文本(每句):"), norm_text)
|
||||
if no_prompt_text:
|
||||
prompt = None
|
||||
prompt = torch.zeros(1, 0).to(self.configs.device, self.precision)
|
||||
else:
|
||||
prompt = (
|
||||
self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
|
||||
)
|
||||
prompt = self.prompt_cache["prompt_semantic"].to(self.configs.device).unsqueeze(0)
|
||||
|
||||
print(f"############ {i18n('预测语义Token')} ############")
|
||||
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
|
||||
t2s_request = T2SRequest(
|
||||
all_phoneme_ids,
|
||||
all_phoneme_lens,
|
||||
prompt,
|
||||
all_bert_features,
|
||||
# prompt_phone_len=ph_offset,
|
||||
valid_length=len(all_phoneme_ids),
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
early_stop_num=self.configs.hz * self.configs.max_sec,
|
||||
max_len=max_len,
|
||||
repetition_penalty=repetition_penalty,
|
||||
debug=os.environ.get("DEBUG", "0") == "1",
|
||||
use_cuda_graph=torch.cuda.is_available(),
|
||||
)
|
||||
t2s_result = self.t2s_model.generate(t2s_request)
|
||||
|
||||
if t2s_result.exception is not None:
|
||||
print(t2s_result.traceback)
|
||||
raise RuntimeError()
|
||||
|
||||
pred_semantic_list = t2s_result.result
|
||||
assert pred_semantic_list
|
||||
pred_semantic_list = [semantic.squeeze(0) for semantic in pred_semantic_list]
|
||||
|
||||
t4 = time.perf_counter()
|
||||
t_34 += t4 - t3
|
||||
|
||||
@ -1220,7 +1190,6 @@ class TTS:
|
||||
if speed_factor == 1.0:
|
||||
print(f"{i18n('并行合成中')}...")
|
||||
# ## vits并行推理 method 2
|
||||
pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
|
||||
upsample_rate = math.prod(self.vits_model.upsample_rates)
|
||||
audio_frag_idx = [
|
||||
pred_semantic_list[i].shape[0] * 2 * upsample_rate
|
||||
@ -1231,7 +1200,7 @@ class TTS:
|
||||
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
|
||||
)
|
||||
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
|
||||
if self.is_v2pro != True:
|
||||
if not self.is_v2pro:
|
||||
_batch_audio_fragment = self.vits_model.decode(
|
||||
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
|
||||
).detach()[0, 0, :]
|
||||
@ -1251,7 +1220,7 @@ class TTS:
|
||||
_pred_semantic = (
|
||||
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
|
||||
) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
if self.is_v2pro != True:
|
||||
if not self.is_v2pro:
|
||||
audio_fragment = self.vits_model.decode(
|
||||
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
|
||||
).detach()[0, 0, :]
|
||||
@ -1308,7 +1277,7 @@ class TTS:
|
||||
output_sr,
|
||||
batch_index_list,
|
||||
speed_factor,
|
||||
split_bucket,
|
||||
False,
|
||||
fragment_interval,
|
||||
super_sampling if self.configs.use_vocoder and self.configs.version == "v3" else False,
|
||||
)
|
||||
@ -1340,14 +1309,14 @@ class TTS:
|
||||
|
||||
def audio_postprocess(
|
||||
self,
|
||||
audio: List[torch.Tensor],
|
||||
audio: list[torch.Tensor],
|
||||
sr: int,
|
||||
batch_index_list: list = None,
|
||||
speed_factor: float = 1.0,
|
||||
split_bucket: bool = True,
|
||||
split_bucket: bool = False,
|
||||
fragment_interval: float = 0.3,
|
||||
super_sampling: bool = False,
|
||||
) -> Tuple[int, np.ndarray]:
|
||||
) -> tuple[int, np.ndarray]:
|
||||
zero_wav = torch.zeros(
|
||||
int(self.configs.sampling_rate * fragment_interval), dtype=self.precision, device=self.configs.device
|
||||
)
|
||||
@ -1459,12 +1428,12 @@ class TTS:
|
||||
|
||||
def using_vocoder_synthesis_batched_infer(
|
||||
self,
|
||||
idx_list: List[int],
|
||||
semantic_tokens_list: List[torch.Tensor],
|
||||
batch_phones: List[torch.Tensor],
|
||||
idx_list: list[int],
|
||||
semantic_tokens_list: list[torch.Tensor],
|
||||
batch_phones: list[torch.Tensor],
|
||||
speed: float = 1.0,
|
||||
sample_steps: int = 32,
|
||||
) -> List[torch.Tensor]:
|
||||
) -> list[torch.Tensor]:
|
||||
prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
|
||||
prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
|
||||
raw_entry = self.prompt_cache["refer_spec"][0]
|
||||
@ -1574,7 +1543,7 @@ class TTS:
|
||||
|
||||
def sola_algorithm(
|
||||
self,
|
||||
audio_fragments: List[torch.Tensor],
|
||||
audio_fragments: list[torch.Tensor],
|
||||
overlap_len: int,
|
||||
):
|
||||
for i in range(len(audio_fragments) - 1):
|
||||
|
||||
@ -365,7 +365,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
|
||||
del vq_model.enc_q
|
||||
|
||||
if is_lora is False:
|
||||
console.print(f">> loading sovits_{model_version}", vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
||||
console.print(f">> loading sovits_{model_version}", vq_model.load_state_dict(dict_s2["weight"]))
|
||||
else:
|
||||
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
|
||||
console.print(f">> loading sovits_{model_version}spretrained_G")
|
||||
@ -381,7 +381,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
|
||||
init_lora_weights=True,
|
||||
)
|
||||
vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) # type: ignore
|
||||
vq_model.load_state_dict(state_dict, strict=False)
|
||||
vq_model.load_state_dict(state_dict)
|
||||
vq_model.cfm = vq_model.cfm.merge_and_unload() # pyright: ignore[reportAttributeAccessIssue, reportCallIssue]
|
||||
vq_model.eval()
|
||||
|
||||
|
||||
@ -138,9 +138,6 @@ is_share = args.share
|
||||
infer_device = torch.device(args.device)
|
||||
device = infer_device
|
||||
|
||||
if torch.mps.is_available():
|
||||
device = torch.device("cpu")
|
||||
|
||||
dtype = get_dtype(device.index)
|
||||
is_half = dtype == torch.float16
|
||||
|
||||
@ -152,7 +149,7 @@ SoVITS_names, GPT_names = get_weights_names(i18n)
|
||||
gpt_path = str(args.gpt) or GPT_names[0][-1]
|
||||
sovits_path = str(args.sovits) or SoVITS_names[0][-1]
|
||||
|
||||
cnhubert_base_path = str(args.cuhubert)
|
||||
cnhubert_base_path = str(args.cnhubert)
|
||||
bert_path = str(args.bert)
|
||||
|
||||
version = model_version = "v2"
|
||||
@ -415,7 +412,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
with gr.Column():
|
||||
with gr.Row(equal_height=True):
|
||||
batch_size = gr.Slider(
|
||||
minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
|
||||
minimum=1, maximum=20, step=1, label=i18n("batch_size"), value=10, interactive=True
|
||||
)
|
||||
sample_steps = gr.Radio(
|
||||
label=i18n("采样步数(仅对V3/4生效)"), value=32, choices=[4, 8, 16, 32, 64, 128], visible=True
|
||||
@ -462,8 +459,8 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
|
||||
split_bucket = gr.Checkbox(
|
||||
label=i18n("数据分桶(并行推理时会降低一点计算量)"),
|
||||
value=True,
|
||||
interactive=True,
|
||||
value=False,
|
||||
interactive=False,
|
||||
show_label=True,
|
||||
)
|
||||
|
||||
|
||||
2
test.py
2
test.py
@ -323,7 +323,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
|
||||
if hasattr(vq_model, "enc_q"):
|
||||
del vq_model.enc_q
|
||||
|
||||
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
vq_model.load_state_dict(dict_s2["weight"])
|
||||
|
||||
vq_model = vq_model.to(infer_device, dtype)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user