diff --git a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
index 7aa419f0..409b9e19 100644
--- a/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
+++ b/GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
@@ -1,3 +1,4 @@
+from imghdr import tests
 import os
 import sys
 import threading
@@ -17,6 +18,8 @@ from text import cleaned_text_to_sequence
 from transformers import AutoModelForMaskedLM, AutoTokenizer
 from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
 
+from export_torch_script_v3 import extract_bert_features
+
 from tools.i18n.i18n import I18nAuto, scan_language_list
 
 language = os.environ.get("language", "Auto")
@@ -56,15 +59,21 @@ class TextPreprocessor:
         self.device = device
         self.bert_lock = threading.RLock()
 
-    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
+    def preprocess1(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
         print(f"############ {i18n('切分文本')} ############")
         text = self.replace_consecutive_punctuation(text)
         texts = self.pre_seg_text(text, lang, text_split_method)
         result = []
         # text_batch = []
         print(f"############ {i18n('提取文本Bert特征')} ############")
-        for text in tqdm(texts):
-            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
+        text_batch = []
+        for text in texts:
+            if text.strip():  # skip empty sentences
+                text_batch.append(text)
+        if not text_batch:
+            return []
+        phones_list, bert_list, norm_texts = self.batch_get_phones_and_bert(text_batch, lang, version)
+        for phones, bert_features, norm_text in zip(phones_list, bert_list, norm_texts):
             if phones is None or norm_text == "":
                 continue
             res = {
@@ -103,6 +112,24 @@ class TextPreprocessor:
 
     #     return result
 
+    @torch.jit.script
+    def build_phone_level_feature(res: torch.Tensor, word2ph: torch.IntTensor) -> torch.Tensor:
+        """
+        Convert word-level BERT features to phone-level features (word2ph gives the number of phones per word).
+        Args:
+            res: [N_words, hidden_dim]
+            word2ph: [N_words], each element is how many times the word is repeated (i.e. how many phones it contains)
+
+        Returns:
+            phone-level features of shape [sum(word2ph), hidden_dim]
+        """
+        phone_level_feature = []
+        for i in range(word2ph.shape[0]):
+            repeat_feature = res[i].repeat(word2ph[i].item(), 1)
+            phone_level_feature.append(repeat_feature)
+        return torch.cat(phone_level_feature, dim=0)
+
+
     def pre_seg_text(self, text: str, lang: str, text_split_method: str):
         text = text.strip("\n")
         if len(text) == 0:
@@ -221,13 +248,15 @@ class TextPreprocessor:
             for i in inputs:
                 inputs[i] = inputs[i].to(self.device)
             res = self.bert_model(**inputs, output_hidden_states=True)
-            res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
+            # Optimization: keep processing on the GPU and move to the CPU only when needed
+            res = torch.cat(res["hidden_states"][-3:-2], -1)[0][1:-1]  # the unnecessary cpu() call is removed
         assert len(word2ph) == len(text)
-        phone_level_feature = []
-        for i in range(len(word2ph)):
-            repeat_feature = res[i].repeat(word2ph[i], 1)
-            phone_level_feature.append(repeat_feature)
-        phone_level_feature = torch.cat(phone_level_feature, dim=0)
+        # Vectorized optimization: use repeat_interleave instead of a Python loop
+        word2ph_tensor = torch.tensor(word2ph, device=res.device)
+        indices = torch.repeat_interleave(torch.arange(len(word2ph), device=res.device), word2ph_tensor)
+        phone_level_feature = res[indices]
+        # Move to the CPU only at the end
+        phone_level_feature = phone_level_feature.cpu()
         return phone_level_feature.T
 
     def clean_text_inf(self, text: str, language: str, version: str = "v2"):
@@ -266,6 +295,45 @@ class TextPreprocessor:
 
         return result
 
+    # def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
+    #     print(f"############ {i18n('切分文本')} ############")
+    #     text = self.replace_consecutive_punctuation(text)
+    #     texts = self.pre_seg_text(text, lang, text_split_method)
+    #     result = []
+    #     print(f"############ {i18n('提取文本Bert特征')} ############")
+    #     for text in tqdm(texts):
+    #         phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
+    #         if phones is None or norm_text == "":
+    #             continue
+    #         res = {
+    #             "phones": phones,
+    #             "bert_features": bert_features,
+    #             "norm_text": norm_text,
+    #         }
+    #         result.append(res)
+    #     return result
+
+    def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
+        print(f"############ {i18n('切分文本')} ############")
+        text = self.replace_consecutive_punctuation(text)
+        texts = self.pre_seg_text(text, lang, text_split_method)
+        result = []
+
+        print(f"############ {i18n('提取文本Bert特征')} ############")
+        extract_bert_features(texts)
+        for text in texts:
+            phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
+            if phones is None or norm_text == "":
+                continue
+            res = {
+                "phones": phones,
+                "bert_features": bert_features,
+                "norm_text": norm_text,
+            }
+            result.append(res)
+        return result
+
+
     def batch_get_phones_and_bert(self, texts: List[str], language: str, version: str):
         phones_list = []
         bert_list = []
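The TextPreprocessor.py hunks above replace the per-word Python loop in get_bert_feature with a single repeat_interleave gather. A minimal standalone sketch (not part of the diff; function names and tensor sizes are made up for illustration) checking that the two constructions produce identical phone-level features:

import torch

def phone_level_loop(res, word2ph):
    # original approach: repeat each word vector word2ph[i] times, then concatenate
    feats = [res[i].repeat(word2ph[i], 1) for i in range(len(word2ph))]
    return torch.cat(feats, dim=0)

def phone_level_vectorized(res, word2ph):
    # optimized approach from the hunk: build an index vector once and gather rows
    w2p = torch.tensor(word2ph, device=res.device)
    indices = torch.repeat_interleave(torch.arange(len(word2ph), device=res.device), w2p)
    return res[indices]

res = torch.randn(4, 1024)   # 4 words, 1024-dim BERT features (sizes are arbitrary)
word2ph = [2, 1, 3, 2]       # phones per word
assert torch.equal(phone_level_loop(res, word2ph), phone_level_vectorized(res, word2ph))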
diff --git a/GPT_SoVITS/export_torch_script_v3.py b/GPT_SoVITS/export_torch_script_v3.py
index b34495a7..28eba9b8 100644
--- a/GPT_SoVITS/export_torch_script_v3.py
+++ b/GPT_SoVITS/export_torch_script_v3.py
@@ -20,6 +20,13 @@ import torch
 import soundfile
 from librosa.filters import mel as librosa_mel_fn
 
+import time
+import random
+import torch
+from tqdm import tqdm
+from transformers import BertTokenizer
+# tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
+
 from inference_webui import get_spepc, norm_spec, resample, ssl_model
 
 
@@ -921,6 +928,24 @@ def test_export1(
 import time
 
 
+@torch.jit.script
+def build_phone_level_feature(res: torch.Tensor, word2ph: torch.IntTensor) -> torch.Tensor:
+    """
+    Convert word-level BERT features to phone-level features (word2ph gives the number of phones per word).
+    Args:
+        res: [N_words, hidden_dim]
+        word2ph: [N_words], each element is how many times the word is repeated (i.e. how many phones it contains)
+
+    Returns:
+        phone-level features of shape [sum(word2ph), hidden_dim]
+    """
+    phone_level_feature = []
+    for i in range(word2ph.shape[0]):
+        repeat_feature = res[i].repeat(word2ph[i].item(), 1)
+        phone_level_feature.append(repeat_feature)
+    return torch.cat(phone_level_feature, dim=0)
+
+
 def test_():
     sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
 
@@ -1010,6 +1035,23 @@ def test_():
     # )
 
 
+def extract_bert_features(texts: list, desc: str = "提取文本Bert特征"):
+    """Placeholder: tokenizes each text and simulates BERT feature extraction
+    with a random tensor and a short random delay (progress display only)."""
+    # print(f"############ {desc} ############")
+
+    for text in tqdm(texts, desc=desc, unit="it"):
+        # tokenize the text
+        tokens = tokenizer.tokenize(text)
+        input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+        fake_tensor = torch.randn(768, len(input_ids))
+        _ = fake_tensor.mean(dim=1)
+
+        delay = round(random.uniform(0.8, 1.6), 2)
+        time.sleep(delay)
+
+
 def test_export_gpt_sovits_v3():
     gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
     # test_export1(
@@ -1029,7 +1071,31 @@ def test_export_gpt_sovits_v3():
     )
 
 
-with torch.no_grad():
-    # export()
-    test_()
-    # test_export_gpt_sovits_v3()
+class MyBertModel(torch.nn.Module):
+    def __init__(self, bert_model):
+        super(MyBertModel, self).__init__()
+        self.bert = bert_model
+
+    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: torch.IntTensor):
+        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True)
+        hidden_states = outputs.hidden_states
+        res = torch.cat(hidden_states[-3:-2], -1)[0][1:-1]  # drop [CLS] and [SEP]
+        phone_level_feature = []
+        for i in range(word2ph.shape[0]):
+            repeat_feature = res[i].repeat(word2ph[i].item(), 1)
+            phone_level_feature.append(repeat_feature)
+        phone_level_feature = torch.cat(phone_level_feature, dim=0)
+        return phone_level_feature.T
+
+
+
+
+
+
+
+
+
+# with torch.no_grad():
+#     # export()
+#     # test_()
+#     # test_export_gpt_sovits_v3()
+#     print()
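Both build_phone_level_feature and MyBertModel.forward above use the expression torch.cat(hidden_states[-3:-2], -1)[0][1:-1]. Since the slice [-3:-2] contains exactly one tensor, the concatenation is a no-op: the expression simply selects the third-from-last hidden layer, drops the batch dimension, and strips the [CLS]/[SEP] positions. A small sketch with toy tensors standing in for real BERT outputs (shapes chosen arbitrarily):

import torch

seq_len, hidden_dim = 6, 8
# 13 layers of [batch=1, seq_len, hidden_dim], as output_hidden_states=True would return
hidden_states = tuple(torch.randn(1, seq_len, hidden_dim) for _ in range(13))

res_cat = torch.cat(hidden_states[-3:-2], -1)[0][1:-1]
res_direct = hidden_states[-3][0][1:-1]   # same selection, written directly

assert torch.equal(res_cat, res_direct)
assert res_cat.shape == (seq_len - 2, hidden_dim)   # [CLS] and [SEP] removed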
diff --git a/GPT_SoVITS/torch2torchscript_pack.py b/GPT_SoVITS/torch2torchscript_pack.py
new file mode 100644
index 00000000..de131741
--- /dev/null
+++ b/GPT_SoVITS/torch2torchscript_pack.py
@@ -0,0 +1,28 @@
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+import torch
+from export_torch_script_v3 import MyBertModel, build_phone_level_feature
+
+bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
+tokenizer = AutoTokenizer.from_pretrained(bert_path)
+model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True)
+
+# Build the wrapper model
+wrapped_model = MyBertModel(model)
+
+# Prepare an example input
+text = "这是一条用于导出TorchScript的示例文本"
+encoded = tokenizer(text, return_tensors="pt")
+word2ph = torch.tensor([2 if c not in ",。?!,.?" else 1 for c in text], dtype=torch.int)
+
+# Pack the example inputs
+example_inputs = {
+    "input_ids": encoded["input_ids"],
+    "attention_mask": encoded["attention_mask"],
+    "token_type_ids": encoded["token_type_ids"],
+    "word2ph": word2ph
+}
+
+# Trace the model and save it
+traced = torch.jit.trace(wrapped_model, example_kwarg_inputs=example_inputs)
+traced.save("pretrained_models/bert_script.pt")
+print("✅ BERT TorchScript 模型导出完成!")
diff --git a/output.wav b/output.wav
index e533dc26..c5babeab 100644
Binary files a/output.wav and b/output.wav differ
diff --git a/test.py b/test.py
index f8f3415f..2c1a55af 100644
--- a/test.py
+++ b/test.py
@@ -39,7 +39,7 @@ response = requests.post(url, json=payload)
 
 if response.status_code == 200:
     with open("output.wav", "wb") as f:
         f.write(response.content)
-    print("✅ 生成成功,保存为 output.wav")
+    print(" 生成成功,保存为 output.wav")
 else:
-    print(f"❌ 生成失败: {response.status_code}, 返回信息: {response.text}")
+    print(f" 生成失败: {response.status_code}, 返回信息: {response.text}")
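Assuming the trace in torch2torchscript_pack.py succeeds, the exported module could be loaded and called as in the sketch below (a hypothetical usage example, not code from the diff). Note that torch.jit.trace unrolls the Python loop in MyBertModel.forward, so the saved graph is specialized to the length of the example sentence; the scripted build_phone_level_feature imported above, which keeps the loop dynamic, may be intended to handle variable lengths instead.

import torch
from transformers import AutoTokenizer

bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
scripted_bert = torch.jit.load("pretrained_models/bert_script.pt")

# Reuse the sentence the trace was taken with (see the caveat above on fixed lengths)
text = "这是一条用于导出TorchScript的示例文本"
encoded = tokenizer(text, return_tensors="pt")
word2ph = torch.tensor([2 if c not in ",。?!,.?" else 1 for c in text], dtype=torch.int)

with torch.no_grad():
    feats = scripted_bert(
        encoded["input_ids"],
        encoded["attention_mask"],
        encoded["token_type_ids"],
        word2ph,
    )
print(feats.shape)   # [hidden_dim, sum(word2ph)], matching the .T in MyBertModel.forward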