mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-24 05:29:45 +08:00
修改批处理bert特征
This commit is contained in:
parent
0d654cd238
commit
c7b61c6fd4
@ -61,6 +61,7 @@ class TextPreprocessor:
|
|||||||
text = self.replace_consecutive_punctuation(text)
|
text = self.replace_consecutive_punctuation(text)
|
||||||
texts = self.pre_seg_text(text, lang, text_split_method)
|
texts = self.pre_seg_text(text, lang, text_split_method)
|
||||||
result = []
|
result = []
|
||||||
|
# text_batch = []
|
||||||
print(f"############ {i18n('提取文本Bert特征')} ############")
|
print(f"############ {i18n('提取文本Bert特征')} ############")
|
||||||
for text in tqdm(texts):
|
for text in tqdm(texts):
|
||||||
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
|
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
|
||||||
@ -73,6 +74,34 @@ class TextPreprocessor:
|
|||||||
}
|
}
|
||||||
result.append(res)
|
result.append(res)
|
||||||
return result
|
return result
|
||||||
|
# for text in texts:
|
||||||
|
# if text.strip(): # 忽略空句子
|
||||||
|
# text_batch.append(text)
|
||||||
|
# phones_list, bert_list, norm_texts = self.batch_get_phones_and_bert(text_batch, lang, version)
|
||||||
|
# for phones, bert_features, norm_text in zip(phones_list, bert_list, norm_texts):
|
||||||
|
# if phones is None or norm_text == "":
|
||||||
|
# continue
|
||||||
|
# res = {
|
||||||
|
# "phones": phones,
|
||||||
|
# "bert_features": bert_features,
|
||||||
|
# "norm_text": norm_text,
|
||||||
|
# }
|
||||||
|
# result.append(res)
|
||||||
|
# return result
|
||||||
|
|
||||||
|
|
||||||
|
# for text in tqdm(texts):
|
||||||
|
# phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
|
||||||
|
# if phones is None or norm_text == "":
|
||||||
|
# continue
|
||||||
|
# res = {
|
||||||
|
# "phones": phones,
|
||||||
|
# "bert_features": bert_features,
|
||||||
|
# "norm_text": norm_text,
|
||||||
|
# }
|
||||||
|
# result.append(res)
|
||||||
|
|
||||||
|
# return result
|
||||||
|
|
||||||
def pre_seg_text(self, text: str, lang: str, text_split_method: str):
|
def pre_seg_text(self, text: str, lang: str, text_split_method: str):
|
||||||
text = text.strip("\n")
|
text = text.strip("\n")
|
||||||
@ -235,3 +264,42 @@ class TextPreprocessor:
|
|||||||
pattern = f"([{punctuations}])([{punctuations}])+"
|
pattern = f"([{punctuations}])([{punctuations}])+"
|
||||||
result = re.sub(pattern, r"\1", text)
|
result = re.sub(pattern, r"\1", text)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def batch_get_phones_and_bert(self, texts: List[str], language: str, version: str):
|
||||||
|
phones_list = []
|
||||||
|
bert_list = []
|
||||||
|
norm_text_list = []
|
||||||
|
|
||||||
|
# 预处理文本,获取每句的 phones, word2ph, norm_text
|
||||||
|
format_texts = [self.clean_text_inf(t, language, version) for t in texts]
|
||||||
|
norm_texts = [x[2] for x in format_texts]
|
||||||
|
word2phs = [x[1] for x in format_texts]
|
||||||
|
|
||||||
|
# 批量送入 tokenizer
|
||||||
|
inputs = self.tokenizer(norm_texts, return_tensors="pt", padding=True, truncation=True)
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.bert_model(**inputs, output_hidden_states=True)
|
||||||
|
# 使用 last_hidden_state 是正确且高效的方式
|
||||||
|
hidden_states = outputs.last_hidden_state # [batch_size, seq_len, hidden_dim]
|
||||||
|
|
||||||
|
for i in range(len(texts)):
|
||||||
|
res = hidden_states[i][1:-1].cpu() # 去掉 [CLS] 和 [SEP]
|
||||||
|
|
||||||
|
word2ph = word2phs[i]
|
||||||
|
phone_level_feature = []
|
||||||
|
for j in range(len(word2ph)):
|
||||||
|
if j >= res.shape[0]:
|
||||||
|
print(f"警告:BERT输出不足,跳过第 {i} 句中第 {j} 个 token")
|
||||||
|
continue
|
||||||
|
phone_level_feature.append(res[j].repeat(word2ph[j], 1))
|
||||||
|
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
||||||
|
|
||||||
|
bert_list.append(phone_level_feature.T)
|
||||||
|
phones_list.append(cleaned_text_to_sequence(format_texts[i][0], version))
|
||||||
|
norm_text_list.append(norm_texts[i])
|
||||||
|
|
||||||
|
return phones_list, bert_list, norm_text_list
|
||||||
|
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
custom:
|
custom:
|
||||||
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
|
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
|
||||||
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
||||||
device: cpu
|
device: cuda
|
||||||
is_half: false
|
is_half: true
|
||||||
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
|
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
|
||||||
version: v2
|
version: v2
|
||||||
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
|
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
|
||||||
|
@ -26,7 +26,7 @@ from inference_webui import get_spepc, norm_spec, resample, ssl_model
|
|||||||
logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
|
logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
|
||||||
logger = logging.getLogger("uvicorn")
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
is_half = False
|
is_half = True
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
|
|
||||||
|
@ -69,7 +69,7 @@ if torch.cuda.is_available():
|
|||||||
else:
|
else:
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
|
||||||
is_half = False
|
is_half = True
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
|
||||||
dict_language_v1 = {
|
dict_language_v1 = {
|
||||||
|
BIN
output.wav
BIN
output.wav
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user