mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-27 03:56:41 +08:00
修复了,中英文混合文本合成英文时, 出现空字符报错的问题
优化了代码, 增加了健壮性
This commit is contained in:
parent
46826b28b0
commit
bfd7286068
@ -173,35 +173,36 @@ class TTS:
|
|||||||
|
|
||||||
|
|
||||||
self.stop_flag:bool = False
|
self.stop_flag:bool = False
|
||||||
|
self.precison:torch.dtype = torch.float16 if self.configs.is_half else torch.float32
|
||||||
|
|
||||||
def _init_models(self,):
|
def _init_models(self,):
|
||||||
self.init_t2s_weights(self.configs.t2s_weights_path)
|
self.init_t2s_weights(self.configs.t2s_weights_path)
|
||||||
self.init_vits_weights(self.configs.vits_weights_path)
|
self.init_vits_weights(self.configs.vits_weights_path)
|
||||||
self.init_bert_weights(self.configs.bert_base_path)
|
self.init_bert_weights(self.configs.bert_base_path)
|
||||||
self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
|
self.init_cnhuhbert_weights(self.configs.cnhuhbert_base_path)
|
||||||
|
self.enable_half_precision(self.configs.is_half)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def init_cnhuhbert_weights(self, base_path: str):
|
def init_cnhuhbert_weights(self, base_path: str):
|
||||||
|
print(f"Loading CNHuBERT weights from {base_path}")
|
||||||
self.cnhuhbert_model = CNHubert(base_path)
|
self.cnhuhbert_model = CNHubert(base_path)
|
||||||
self.cnhuhbert_model=self.cnhuhbert_model.eval()
|
self.cnhuhbert_model=self.cnhuhbert_model.eval()
|
||||||
if self.configs.is_half == True:
|
|
||||||
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
|
||||||
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
|
self.cnhuhbert_model = self.cnhuhbert_model.to(self.configs.device)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def init_bert_weights(self, base_path: str):
|
def init_bert_weights(self, base_path: str):
|
||||||
|
print(f"Loading BERT weights from {base_path}")
|
||||||
self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
|
self.bert_tokenizer = AutoTokenizer.from_pretrained(base_path)
|
||||||
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
|
self.bert_model = AutoModelForMaskedLM.from_pretrained(base_path)
|
||||||
self.bert_model=self.bert_model.eval()
|
self.bert_model=self.bert_model.eval()
|
||||||
if self.configs.is_half:
|
|
||||||
self.bert_model = self.bert_model.half()
|
|
||||||
self.bert_model = self.bert_model.to(self.configs.device)
|
self.bert_model = self.bert_model.to(self.configs.device)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def init_vits_weights(self, weights_path: str):
|
def init_vits_weights(self, weights_path: str):
|
||||||
|
print(f"Loading VITS weights from {weights_path}")
|
||||||
self.configs.vits_weights_path = weights_path
|
self.configs.vits_weights_path = weights_path
|
||||||
self.configs.save_configs()
|
self.configs.save_configs()
|
||||||
dict_s2 = torch.load(weights_path, map_location=self.configs.device)
|
dict_s2 = torch.load(weights_path, map_location=self.configs.device)
|
||||||
@ -224,8 +225,6 @@ class TTS:
|
|||||||
if hasattr(vits_model, "enc_q"):
|
if hasattr(vits_model, "enc_q"):
|
||||||
del vits_model.enc_q
|
del vits_model.enc_q
|
||||||
|
|
||||||
if self.configs.is_half:
|
|
||||||
vits_model = vits_model.half()
|
|
||||||
vits_model = vits_model.to(self.configs.device)
|
vits_model = vits_model.to(self.configs.device)
|
||||||
vits_model = vits_model.eval()
|
vits_model = vits_model.eval()
|
||||||
vits_model.load_state_dict(dict_s2["weight"], strict=False)
|
vits_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||||
@ -233,6 +232,7 @@ class TTS:
|
|||||||
|
|
||||||
|
|
||||||
def init_t2s_weights(self, weights_path: str):
|
def init_t2s_weights(self, weights_path: str):
|
||||||
|
print(f"Loading Text2Semantic weights from {weights_path}")
|
||||||
self.configs.t2s_weights_path = weights_path
|
self.configs.t2s_weights_path = weights_path
|
||||||
self.configs.save_configs()
|
self.configs.save_configs()
|
||||||
self.configs.hz = 50
|
self.configs.hz = 50
|
||||||
@ -242,12 +242,60 @@ class TTS:
|
|||||||
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
|
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False,
|
||||||
flash_attn_enabled=self.configs.flash_attn_enabled)
|
flash_attn_enabled=self.configs.flash_attn_enabled)
|
||||||
t2s_model.load_state_dict(dict_s1["weight"])
|
t2s_model.load_state_dict(dict_s1["weight"])
|
||||||
if self.configs.is_half:
|
|
||||||
t2s_model = t2s_model.half()
|
|
||||||
t2s_model = t2s_model.to(self.configs.device)
|
t2s_model = t2s_model.to(self.configs.device)
|
||||||
t2s_model = t2s_model.eval()
|
t2s_model = t2s_model.eval()
|
||||||
self.t2s_model = t2s_model
|
self.t2s_model = t2s_model
|
||||||
|
|
||||||
|
def enable_half_precision(self, enable: bool = True):
|
||||||
|
'''
|
||||||
|
To enable half precision for the TTS model.
|
||||||
|
Args:
|
||||||
|
enable: bool, whether to enable half precision.
|
||||||
|
|
||||||
|
'''
|
||||||
|
if self.configs.device == "cpu":
|
||||||
|
print("Half precision is not supported on CPU.")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.configs.is_half = enable
|
||||||
|
self.precison = torch.float16 if enable else torch.float32
|
||||||
|
self.configs.save_configs()
|
||||||
|
if enable:
|
||||||
|
if self.t2s_model is not None:
|
||||||
|
self.t2s_model =self.t2s_model.half()
|
||||||
|
if self.vits_model is not None:
|
||||||
|
self.vits_model = self.vits_model.half()
|
||||||
|
if self.bert_model is not None:
|
||||||
|
self.bert_model =self.bert_model.half()
|
||||||
|
if self.cnhuhbert_model is not None:
|
||||||
|
self.cnhuhbert_model = self.cnhuhbert_model.half()
|
||||||
|
else:
|
||||||
|
if self.t2s_model is not None:
|
||||||
|
self.t2s_model = self.t2s_model.float()
|
||||||
|
if self.vits_model is not None:
|
||||||
|
self.vits_model = self.vits_model.float()
|
||||||
|
if self.bert_model is not None:
|
||||||
|
self.bert_model = self.bert_model.float()
|
||||||
|
if self.cnhuhbert_model is not None:
|
||||||
|
self.cnhuhbert_model = self.cnhuhbert_model.float()
|
||||||
|
|
||||||
|
def set_device(self, device: torch.device):
|
||||||
|
'''
|
||||||
|
To set the device for all models.
|
||||||
|
Args:
|
||||||
|
device: torch.device, the device to use for all models.
|
||||||
|
'''
|
||||||
|
self.configs.device = device
|
||||||
|
self.configs.save_configs()
|
||||||
|
if self.t2s_model is not None:
|
||||||
|
self.t2s_model = self.t2s_model.to(device)
|
||||||
|
if self.vits_model is not None:
|
||||||
|
self.vits_model = self.vits_model.to(device)
|
||||||
|
if self.bert_model is not None:
|
||||||
|
self.bert_model = self.bert_model.to(device)
|
||||||
|
if self.cnhuhbert_model is not None:
|
||||||
|
self.cnhuhbert_model = self.cnhuhbert_model.to(device)
|
||||||
|
|
||||||
def set_ref_audio(self, ref_audio_path:str):
|
def set_ref_audio(self, ref_audio_path:str):
|
||||||
'''
|
'''
|
||||||
To set the reference audio for the TTS model,
|
To set the reference audio for the TTS model,
|
||||||
@ -347,7 +395,7 @@ class TTS:
|
|||||||
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
|
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
|
||||||
while pos < pos_end:
|
while pos < pos_end:
|
||||||
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
|
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
|
||||||
score=batch[(pos_end-pos)//2]/batch.mean()
|
score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8)
|
||||||
if (score>=threshold) or (pos_end-pos==1):
|
if (score>=threshold) or (pos_end-pos==1):
|
||||||
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
|
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
|
||||||
batch_index_list_len += len(batch_index)
|
batch_index_list_len += len(batch_index)
|
||||||
@ -379,13 +427,13 @@ class TTS:
|
|||||||
for item in item_list:
|
for item in item_list:
|
||||||
if prompt_data is not None:
|
if prompt_data is not None:
|
||||||
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
|
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
|
||||||
.to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
|
.to(dtype=self.precison)
|
||||||
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
|
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
|
||||||
phones = torch.LongTensor(item["phones"])
|
phones = torch.LongTensor(item["phones"])
|
||||||
# norm_text = prompt_data["norm_text"]+item["norm_text"]
|
# norm_text = prompt_data["norm_text"]+item["norm_text"]
|
||||||
else:
|
else:
|
||||||
all_bert_features = item["bert_features"]\
|
all_bert_features = item["bert_features"]\
|
||||||
.to(dtype=torch.float32 if not self.configs.is_half else torch.float16)
|
.to(dtype=self.precison)
|
||||||
phones = torch.LongTensor(item["phones"])
|
phones = torch.LongTensor(item["phones"])
|
||||||
all_phones = phones
|
all_phones = phones
|
||||||
# norm_text = item["norm_text"]
|
# norm_text = item["norm_text"]
|
||||||
@ -405,7 +453,7 @@ class TTS:
|
|||||||
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||||
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||||
# all_bert_features_batch = all_bert_features_list
|
# all_bert_features_batch = all_bert_features_list
|
||||||
all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=torch.float32)
|
all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=self.precison)
|
||||||
for idx, item in enumerate(all_bert_features_list):
|
for idx, item in enumerate(all_bert_features_list):
|
||||||
all_bert_features_batch[idx, :, : item.shape[-1]] = item
|
all_bert_features_batch[idx, :, : item.shape[-1]] = item
|
||||||
|
|
||||||
@ -535,6 +583,11 @@ class TTS:
|
|||||||
|
|
||||||
###### text preprocessing ########
|
###### text preprocessing ########
|
||||||
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
|
data = self.text_preprocessor.preprocess(text, text_lang, text_split_method)
|
||||||
|
if len(data) == 0:
|
||||||
|
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate * 0.3),
|
||||||
|
dtype=np.int16)
|
||||||
|
return
|
||||||
|
|
||||||
t1 = ttime()
|
t1 = ttime()
|
||||||
data, batch_index_list = self.to_batch(data,
|
data, batch_index_list = self.to_batch(data,
|
||||||
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
||||||
@ -587,10 +640,8 @@ class TTS:
|
|||||||
t4 = ttime()
|
t4 = ttime()
|
||||||
t_34 += t4 - t3
|
t_34 += t4 - t3
|
||||||
|
|
||||||
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"].to(self.configs.device)
|
refer_audio_spepc:torch.Tensor = self.prompt_cache["refer_spepc"]\
|
||||||
if self.configs.is_half:
|
.to(dtype=self.precison, device=self.configs.device)
|
||||||
refer_audio_spepc = refer_audio_spepc.half()
|
|
||||||
|
|
||||||
|
|
||||||
batch_audio_fragment = []
|
batch_audio_fragment = []
|
||||||
|
|
||||||
@ -672,7 +723,7 @@ class TTS:
|
|||||||
split_bucket:bool=True)->tuple[int, np.ndarray]:
|
split_bucket:bool=True)->tuple[int, np.ndarray]:
|
||||||
zero_wav = torch.zeros(
|
zero_wav = torch.zeros(
|
||||||
int(self.configs.sampling_rate * 0.3),
|
int(self.configs.sampling_rate * 0.3),
|
||||||
dtype=torch.float16 if self.configs.is_half else torch.float32,
|
dtype=self.precison,
|
||||||
device=self.configs.device
|
device=self.configs.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -59,6 +59,8 @@ class TextPreprocessor:
|
|||||||
print(i18n("############ 提取文本Bert特征 ############"))
|
print(i18n("############ 提取文本Bert特征 ############"))
|
||||||
for text in tqdm(texts):
|
for text in tqdm(texts):
|
||||||
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
|
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang)
|
||||||
|
if phones is None:
|
||||||
|
continue
|
||||||
res={
|
res={
|
||||||
"phones": phones,
|
"phones": phones,
|
||||||
"bert_features": bert_features,
|
"bert_features": bert_features,
|
||||||
@ -79,14 +81,12 @@ class TextPreprocessor:
|
|||||||
|
|
||||||
while "\n\n" in text:
|
while "\n\n" in text:
|
||||||
text = text.replace("\n\n", "\n")
|
text = text.replace("\n\n", "\n")
|
||||||
print(i18n("实际输入的目标文本(切句后):"))
|
|
||||||
print(text)
|
|
||||||
_texts = text.split("\n")
|
_texts = text.split("\n")
|
||||||
_texts = merge_short_text_in_array(_texts, 5)
|
_texts = merge_short_text_in_array(_texts, 5)
|
||||||
texts = []
|
texts = []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for text in _texts:
|
for text in _texts:
|
||||||
# 解决输入目标文本的空行导致报错的问题
|
# 解决输入目标文本的空行导致报错的问题
|
||||||
if (len(text.strip()) == 0):
|
if (len(text.strip()) == 0):
|
||||||
@ -94,15 +94,21 @@ class TextPreprocessor:
|
|||||||
if (text[-1] not in splits): text += "。" if lang != "en" else "."
|
if (text[-1] not in splits): text += "。" if lang != "en" else "."
|
||||||
|
|
||||||
# 解决句子过长导致Bert报错的问题
|
# 解决句子过长导致Bert报错的问题
|
||||||
|
if (len(text) > 510):
|
||||||
texts.extend(split_big_text(text))
|
texts.extend(split_big_text(text))
|
||||||
|
else:
|
||||||
|
texts.append(text)
|
||||||
|
|
||||||
|
print(i18n("实际输入的目标文本(切句后):"))
|
||||||
|
print(texts)
|
||||||
return texts
|
return texts
|
||||||
|
|
||||||
def segment_and_extract_feature_for_text(self, texts:list, language:str)->Tuple[list, torch.Tensor, str]:
|
def segment_and_extract_feature_for_text(self, texts:list, language:str)->Tuple[list, torch.Tensor, str]:
|
||||||
textlist, langlist = self.seg_text(texts, language)
|
textlist, langlist = self.seg_text(texts, language)
|
||||||
phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
|
if len(textlist) == 0:
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
phones, bert_features, norm_text = self.extract_bert_feature(textlist, langlist)
|
||||||
return phones, bert_features, norm_text
|
return phones, bert_features, norm_text
|
||||||
|
|
||||||
|
|
||||||
@ -113,6 +119,8 @@ class TextPreprocessor:
|
|||||||
if language in ["auto", "zh", "ja"]:
|
if language in ["auto", "zh", "ja"]:
|
||||||
LangSegment.setfilters(["zh","ja","en","ko"])
|
LangSegment.setfilters(["zh","ja","en","ko"])
|
||||||
for tmp in LangSegment.getTexts(text):
|
for tmp in LangSegment.getTexts(text):
|
||||||
|
if tmp["text"] == "":
|
||||||
|
continue
|
||||||
if tmp["lang"] == "ko":
|
if tmp["lang"] == "ko":
|
||||||
langlist.append("zh")
|
langlist.append("zh")
|
||||||
elif tmp["lang"] == "en":
|
elif tmp["lang"] == "en":
|
||||||
@ -126,14 +134,18 @@ class TextPreprocessor:
|
|||||||
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
|
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
|
||||||
while " " in formattext:
|
while " " in formattext:
|
||||||
formattext = formattext.replace(" ", " ")
|
formattext = formattext.replace(" ", " ")
|
||||||
|
if formattext != "":
|
||||||
textlist.append(formattext)
|
textlist.append(formattext)
|
||||||
langlist.append("en")
|
langlist.append("en")
|
||||||
|
|
||||||
elif language in ["all_zh","all_ja"]:
|
elif language in ["all_zh","all_ja"]:
|
||||||
|
|
||||||
formattext = text
|
formattext = text
|
||||||
while " " in formattext:
|
while " " in formattext:
|
||||||
formattext = formattext.replace(" ", " ")
|
formattext = formattext.replace(" ", " ")
|
||||||
language = language.replace("all_","")
|
language = language.replace("all_","")
|
||||||
|
if text == "":
|
||||||
|
return [],[]
|
||||||
textlist.append(formattext)
|
textlist.append(formattext)
|
||||||
langlist.append(language)
|
langlist.append(language)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user