本地的tts.py不是最新的,导致之前的修复被替换了。回滚到最新版本并添加v2支持

This commit is contained in:
CyberWon 2024-08-08 22:49:57 +08:00
parent 16a37b7b48
commit bd53aa8200

View File

@ -1,6 +1,6 @@
from copy import deepcopy from copy import deepcopy
import math import math
import os, sys import os, sys, gc
import random import random
import traceback import traceback
@ -39,7 +39,7 @@ default:
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
custom: custom:
device: cuda device: cuda
is_half: true is_half: true
@ -209,7 +209,7 @@ class TTS:
self.text_preprocessor: TextPreprocessor = \ self.text_preprocessor: TextPreprocessor = \
TextPreprocessor(self.bert_model, TextPreprocessor(self.bert_model,
self.bert_tokenizer, self.bert_tokenizer,
self.configs.device, version=self.version) self.configs.device)
self.prompt_cache: dict = { self.prompt_cache: dict = {
"ref_audio_path": None, "ref_audio_path": None,
@ -301,12 +301,12 @@ class TTS:
if self.configs.is_half and str(self.configs.device) != "cpu": if self.configs.is_half and str(self.configs.device) != "cpu":
self.t2s_model = self.t2s_model.half() self.t2s_model = self.t2s_model.half()
def enable_half_precision(self, enable: bool = True): def enable_half_precision(self, enable: bool = True, save: bool = True):
''' '''
To enable half precision for the TTS model. To enable half precision for the TTS model.
Args: Args:
enable: bool, whether to enable half precision. enable: bool, whether to enable half precision.
''' '''
if str(self.configs.device) == "cpu" and enable: if str(self.configs.device) == "cpu" and enable:
print("Half precision is not supported on CPU.") print("Half precision is not supported on CPU.")
@ -314,7 +314,8 @@ class TTS:
self.configs.is_half = enable self.configs.is_half = enable
self.precision = torch.float16 if enable else torch.float32 self.precision = torch.float16 if enable else torch.float32
self.configs.save_configs() if save:
self.configs.save_configs()
if enable: if enable:
if self.t2s_model is not None: if self.t2s_model is not None:
self.t2s_model = self.t2s_model.half() self.t2s_model = self.t2s_model.half()
@ -334,14 +335,15 @@ class TTS:
if self.cnhuhbert_model is not None: if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.float() self.cnhuhbert_model = self.cnhuhbert_model.float()
def set_device(self, device: torch.device): def set_device(self, device: torch.device, save: bool = True):
''' '''
To set the device for all models. To set the device for all models.
Args: Args:
device: torch.device, the device to use for all models. device: torch.device, the device to use for all models.
''' '''
self.configs.device = device self.configs.device = device
self.configs.save_configs() if save:
self.configs.save_configs()
if self.t2s_model is not None: if self.t2s_model is not None:
self.t2s_model = self.t2s_model.to(device) self.t2s_model = self.t2s_model.to(device)
if self.vits_model is not None: if self.vits_model is not None:
@ -353,13 +355,17 @@ class TTS:
def set_ref_audio(self, ref_audio_path: str): def set_ref_audio(self, ref_audio_path: str):
''' '''
To set the reference audio for the TTS model, To set the reference audio for the TTS model,
including the prompt_semantic and refer_spec. including the prompt_semantic and refer_spec.
Args: Args:
ref_audio_path: str, the path of the reference audio. ref_audio_path: str, the path of the reference audio.
''' '''
self._set_prompt_semantic(ref_audio_path) self._set_prompt_semantic(ref_audio_path)
self._set_ref_spec(ref_audio_path) self._set_ref_spec(ref_audio_path)
self._set_ref_audio_path(ref_audio_path)
def _set_ref_audio_path(self, ref_audio_path):
self.prompt_cache["ref_audio_path"] = ref_audio_path
def _set_ref_spec(self, ref_audio_path): def _set_ref_spec(self, ref_audio_path):
audio = load_audio(ref_audio_path, int(self.configs.sampling_rate)) audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
@ -545,11 +551,11 @@ class TTS:
def recovery_order(self, data: list, batch_index_list: list) -> list: def recovery_order(self, data: list, batch_index_list: list) -> list:
''' '''
Recover the order of the audio according to the batch_index_list. Recover the order of the audio according to the batch_index_list.
Args: Args:
data (List[list(np.ndarray)]): the out-of-order audio. data (List[list(np.ndarray)]): the out-of-order audio.
batch_index_list (List[list[int]]): the batch index list. batch_index_list (List[list[int]]): the batch index list.
Returns: Returns:
list (List[np.ndarray]): the data in the original order. list (List[np.ndarray]): the data in the original order.
''' '''
@ -570,9 +576,9 @@ class TTS:
def run(self, inputs: dict): def run(self, inputs: dict):
""" """
Text to speech inference. Text to speech inference.
Args: Args:
inputs (dict): inputs (dict):
{ {
"text": "", # str.(required) text to be synthesized "text": "", # str.(required) text to be synthesized
"text_lang": "", # str.(required) language of the text to be synthesized "text_lang": "", # str.(required) language of the text to be synthesized
@ -801,7 +807,6 @@ class TTS:
audio_frag_end_idx = [sum(audio_frag_idx[:i + 1]) for i in range(0, len(audio_frag_idx))] audio_frag_end_idx = [sum(audio_frag_idx[:i + 1]) for i in range(0, len(audio_frag_idx))]
all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device) all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device) _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
_batch_audio_fragment = (self.vits_model.decode( _batch_audio_fragment = (self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec all_pred_semantic, _batch_phones, refer_audio_spec
).detach()[0, 0, :]) ).detach()[0, 0, :])
@ -867,6 +872,7 @@ class TTS:
def empty_cache(self): def empty_cache(self):
try: try:
gc.collect() # 触发gc的垃圾回收。避免内存一直增长。
if "cuda" in str(self.configs.device): if "cuda" in str(self.configs.device):
torch.cuda.empty_cache() torch.cuda.empty_cache()
elif str(self.configs.device) == "mps": elif str(self.configs.device) == "mps":
@ -932,4 +938,4 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int):
# 将管道输出解码为 NumPy 数组 # 将管道输出解码为 NumPy 数组
processed_audio = np.frombuffer(out, np.int16) processed_audio = np.frombuffer(out, np.int16)
return processed_audio return processed_audio