diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
index 426929f8..7aa419f0 100644
--- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
+++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -235,3 +235,76 @@ class TextPreprocessor:
         pattern = f"([{punctuations}])([{punctuations}])+"
         result = re.sub(pattern, r"\1", text)
         return result
+
+    def batch_get_phones_and_bert(self, texts: List[str], language: str, version: str):
+        """Compute phone ids and phone-level BERT features for a batch of sentences.
+
+        Every sentence is cleaned with ``clean_text_inf``, the normalised texts are
+        tokenised as one padded batch, and ``self.bert_model`` runs once for the
+        whole batch. Token vector ``j`` of a sentence is then repeated
+        ``word2ph[j]`` times to build that sentence's feature matrix.
+
+        Args:
+            texts: Sentences to process.
+            language: Forwarded to ``clean_text_inf``.
+            version: Forwarded to ``clean_text_inf`` and ``cleaned_text_to_sequence``.
+
+        Returns:
+            ``(phones_list, bert_list, norm_text_list)``, three lists with one entry
+            per input sentence. Each ``bert_list`` entry is a CPU tensor with the
+            hidden dimension first (the stacked token rows, transposed).
+
+        NOTE(review): nothing in this patch calls this method yet.
+        """
+        if not texts:
+            # An empty batch cannot be tokenised and has nothing to extract.
+            return [], [], []
+
+        # Each entry is indexed below as (phones, word2ph, norm_text).
+        format_texts = [self.clean_text_inf(t, language, version) for t in texts]
+        word2phs = [x[1] for x in format_texts]
+        norm_texts = [x[2] for x in format_texts]
+
+        # Tokenise the whole batch at once; shorter sentences get padded.
+        inputs = self.tokenizer(norm_texts, return_tensors="pt", padding=True, truncation=True)
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        with torch.no_grad():
+            outputs = self.bert_model(**inputs, output_hidden_states=True)
+
+        # NOTE(review): confirm the single-sentence path reads the same layer; if it
+        # takes an intermediate entry of outputs.hidden_states, use that one here too.
+        hidden_states = outputs.last_hidden_state.cpu()  # [batch_size, padded_len, hidden_dim]
+        # Real (non-padding) token count per sentence, special tokens included.
+        token_counts = inputs["attention_mask"].sum(dim=1).tolist()
+
+        phones_list = []
+        bert_list = []
+        norm_text_list = []
+        for i, (fmt, word2ph, norm_text) in enumerate(zip(format_texts, word2phs, norm_texts)):
+            # Drop [CLS] and this sentence's own [SEP], which also drops its padding.
+            # A plain [1:-1] on the padded row keeps [SEP]/[PAD] vectors for every
+            # sentence shorter than the longest one in the batch.
+            # Assumes right-side padding — TODO confirm tokenizer.padding_side.
+            res = hidden_states[i, 1 : token_counts[i] - 1]
+
+            phone_level_feature = []
+            for j, repeats in enumerate(word2ph):
+                if j >= res.shape[0]:
+                    print(f"警告:BERT输出不足,跳过第 {i} 句中第 {j} 个 token")
+                    continue
+                phone_level_feature.append(res[j].repeat(repeats, 1))
+
+            if phone_level_feature:
+                phone_level_feature = torch.cat(phone_level_feature, dim=0)
+            else:
+                # torch.cat rejects an empty list; emit an empty [0, hidden_dim] matrix.
+                phone_level_feature = res.new_zeros((0, res.shape[-1]))
+
+            bert_list.append(phone_level_feature.T)
+            # NOTE(review): if clean_text_inf already returns ids rather than symbols,
+            # this second conversion must be dropped — verify against its definition.
+            phones_list.append(cleaned_text_to_sequence(fmt[0], version))
+            norm_text_list.append(norm_text)
+
+        return phones_list, bert_list, norm_text_list
diff --git a/GPT_SoVITS/configs/tts_infer.yaml b/GPT_SoVITS/configs/tts_infer.yaml
index e2c13c28..20c41a20 100644
--- a/GPT_SoVITS/configs/tts_infer.yaml
+++ b/GPT_SoVITS/configs/tts_infer.yaml
@@ -1,8 +1,8 @@
 custom:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
-  device: cpu
-  is_half: false
+  device: cuda
+  is_half: true
   t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
   version: v2
   vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
diff --git a/GPT_SoVITS/export_torch_script_v3.py b/GPT_SoVITS/export_torch_script_v3.py
index 2f57df0e..b34495a7 100644
--- a/GPT_SoVITS/export_torch_script_v3.py
+++ b/GPT_SoVITS/export_torch_script_v3.py
@@ -26,7 +26,7 @@ from inference_webui import get_spepc, norm_spec, resample, ssl_model
 logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
 logger = logging.getLogger("uvicorn")
 
-is_half = False
+is_half = torch.cuda.is_available()  # fp16 only when a GPU is present; device falls back to CPU below
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 now_dir = os.getcwd()
diff --git a/GPT_SoVITS/inference_webui_fast.py b/GPT_SoVITS/inference_webui_fast.py
index 02686a00..b2a02685 100644
--- a/GPT_SoVITS/inference_webui_fast.py
+++ b/GPT_SoVITS/inference_webui_fast.py
@@ -69,7 +69,7 @@ if torch.cuda.is_available():
 else:
     device = "cpu"
 
-is_half = False
 device = "cpu"
+is_half = device == "cuda"  # fp16 only on GPU; stays off while device is pinned to CPU above
 
 dict_language_v1 = {