Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-08-23 21:19:47 +08:00)

Commit c7b61c6fd4 (parent 0d654cd238): Modify batch processing of BERT features
@@ -61,6 +61,7 @@ class TextPreprocessor:
         text = self.replace_consecutive_punctuation(text)
         texts = self.pre_seg_text(text, lang, text_split_method)
         result = []
+        # text_batch = []
         print(f"############ {i18n('提取文本Bert特征')} ############")
         for text in tqdm(texts):
             phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
@@ -73,6 +74,34 @@ class TextPreprocessor:
             }
             result.append(res)
         return result
+        # for text in texts:
+        #     if text.strip():  # ignore empty sentences
+        #         text_batch.append(text)
+        # phones_list, bert_list, norm_texts = self.batch_get_phones_and_bert(text_batch, lang, version)
+        # for phones, bert_features, norm_text in zip(phones_list, bert_list, norm_texts):
+        #     if phones is None or norm_text == "":
+        #         continue
+        #     res = {
+        #         "phones": phones,
+        #         "bert_features": bert_features,
+        #         "norm_text": norm_text,
+        #     }
+        #     result.append(res)
+        # return result
+
+
+        # for text in tqdm(texts):
+        #     phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
+        #     if phones is None or norm_text == "":
+        #         continue
+        #     res = {
+        #         "phones": phones,
+        #         "bert_features": bert_features,
+        #         "norm_text": norm_text,
+        #     }
+        #     result.append(res)
+
+        # return result
+
 
     def pre_seg_text(self, text: str, lang: str, text_split_method: str):
         text = text.strip("\n")
@@ -235,3 +264,42 @@ class TextPreprocessor:
         pattern = f"([{punctuations}])([{punctuations}])+"
         result = re.sub(pattern, r"\1", text)
         return result
+
+
+    def batch_get_phones_and_bert(self, texts: List[str], language: str, version: str):
+        phones_list = []
+        bert_list = []
+        norm_text_list = []
+
+        # Preprocess each sentence to get its phones, word2ph and norm_text
+        format_texts = [self.clean_text_inf(t, language, version) for t in texts]
+        norm_texts = [x[2] for x in format_texts]
+        word2phs = [x[1] for x in format_texts]
+
+        # Feed the whole batch into the tokenizer at once
+        inputs = self.tokenizer(norm_texts, return_tensors="pt", padding=True, truncation=True)
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        with torch.no_grad():
+            outputs = self.bert_model(**inputs, output_hidden_states=True)
+            # Using last_hidden_state is the correct and efficient choice here
+            hidden_states = outputs.last_hidden_state  # [batch_size, seq_len, hidden_dim]
+
+        for i in range(len(texts)):
+            res = hidden_states[i][1:-1].cpu()  # drop [CLS] and [SEP]
+
+            word2ph = word2phs[i]
+            phone_level_feature = []
+            for j in range(len(word2ph)):
+                if j >= res.shape[0]:
+                    print(f"Warning: BERT output too short, skipping token {j} of sentence {i}")
+                    continue
+                phone_level_feature.append(res[j].repeat(word2ph[j], 1))
+            phone_level_feature = torch.cat(phone_level_feature, dim=0)
+
+            bert_list.append(phone_level_feature.T)
+            phones_list.append(cleaned_text_to_sequence(format_texts[i][0], version))
+            norm_text_list.append(norm_texts[i])
+
+        return phones_list, bert_list, norm_text_list
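For reference, a standalone sketch of the batched technique this hunk adds: one tokenizer/BERT call for all sentences, then per-sentence expansion of token vectors to phone level via word2ph. Assumptions: transformers and torch are installed, the model path is the one named in the config below, and the sentences/word2ph values are toy data for illustration only.

    import torch
    from transformers import AutoModel, AutoTokenizer

    model_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    bert_model = AutoModel.from_pretrained(model_path).to(device)

    norm_texts = ["今天天气很好。", "我们出去散步吧。"]    # toy normalized sentences
    word2phs = [[2] * len(t) for t in norm_texts]          # toy word2ph: 2 phones per character

    # One tokenizer/BERT call for the whole batch, as in batch_get_phones_and_bert
    inputs = tokenizer(norm_texts, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        hidden = bert_model(**inputs).last_hidden_state    # [batch, seq_len, dim]

    bert_features = []
    for i, word2ph in enumerate(word2phs):
        res = hidden[i][1:-1].cpu()           # drop [CLS] and the final position ([SEP] or padding)
        feats = [res[j].repeat(word2ph[j], 1) # repeat each token vector once per phone
                 for j in range(min(len(word2ph), res.shape[0]))]
        bert_features.append(torch.cat(feats, dim=0).T)    # [dim, total_phones], as in the diff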
@@ -1,8 +1,8 @@
 custom:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
-  device: cpu
-  is_half: false
+  device: cuda
+  is_half: true
   t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
   version: v2
   vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
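A minimal sketch of reading the custom block above and deriving device/dtype from it. Assumptions: the file lives at GPT_SoVITS/configs/tts_infer.yaml (the file name is not shown in this diff) and PyYAML is available; this is not the project's actual config loader.

    import torch
    import yaml

    with open("GPT_SoVITS/configs/tts_infer.yaml", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)["custom"]

    device = cfg["device"] if torch.cuda.is_available() else "cpu"  # fall back to CPU without a GPU
    dtype = torch.float16 if cfg["is_half"] and device == "cuda" else torch.float32
    # A model loaded from cfg["bert_base_path"] would then be moved with .to(device, dtype)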
@@ -26,7 +26,7 @@ from inference_webui import get_spepc, norm_spec, resample, ssl_model
 logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
 logger = logging.getLogger("uvicorn")
 
-is_half = False
+is_half = True
 device = "cuda" if torch.cuda.is_available() else "cpu"
 now_dir = os.getcwd()
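Hard-coding is_half = True is only safe when a CUDA device is actually present; a small defensive variant (an editorial suggestion, not part of this commit) ties the flag to device availability:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    is_half = device == "cuda"   # request half precision only when running on GPU
    dtype = torch.float16 if is_half else torch.float32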
@@ -69,7 +69,7 @@ if torch.cuda.is_available():
 else:
     device = "cpu"
 
-is_half = False
+is_half = True
 device = "cpu"
 
 dict_language_v1 = {
BIN output.wav
Binary file not shown.