Merge 45cf41f8fc349d06636d6d7b28f25c00f3843d2d into 13055fa56994e75a7152c176047c56c62bbeede4

commit 57b1744f7c
Author: Karyl01
Date: 2025-05-12 11:26:41 +08:00 (committed by GitHub)
12 changed files with 253 additions and 20 deletions

BIN
Arona_Academy_In_2.ogg.wav Normal file

Binary file not shown.


@ -18,6 +18,13 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
from tools.i18n.i18n import I18nAuto, scan_language_list
from functools import lru_cache
import torch
from .cached1 import get_cached_bert, CachedBertExtractor
language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
@ -56,6 +63,8 @@ class TextPreprocessor:
self.device = device
self.bert_lock = threading.RLock()
self.bert_extractor = CachedBertExtractor("bert-base-chinese", device=device)
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
print(f"############ {i18n('切分文本')} ############")
text = self.replace_consecutive_punctuation(text)
@ -186,20 +195,25 @@ class TextPreprocessor:
return phones, bert, norm_text
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(self.device)
res = self.bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
# def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
# with torch.no_grad():
# inputs = self.tokenizer(text, return_tensors="pt")
# for i in inputs:
# inputs[i] = inputs[i].to(self.device)
# res = self.bert_model(**inputs, output_hidden_states=True)
# res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
# assert len(word2ph) == len(text)
# phone_level_feature = []
# for i in range(len(word2ph)):
# repeat_feature = res[i].repeat(word2ph[i], 1)
# phone_level_feature.append(repeat_feature)
# phone_level_feature = torch.cat(phone_level_feature, dim=0)
# return phone_level_feature.T
def get_bert_feature(self, norm_text: str, word2ph: list) -> torch.Tensor:
# Note: word2ph is a list; convert it to a tuple so it can be used as the cache key
bert = get_cached_bert(norm_text, tuple(word2ph), str(self.device))
return bert.to(self.device)
def clean_text_inf(self, text: str, language: str, version: str = "v2"):
language = language.replace("all_", "")
@ -235,3 +249,5 @@ class TextPreprocessor:
pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result
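(Aside: a minimal sketch, with a hypothetical _expand helper, of why word2ph is converted to a tuple before it reaches the cache: functools.lru_cache hashes its arguments, and a Python list is unhashable.)

from functools import lru_cache

@lru_cache(maxsize=8)
def _expand(norm_text: str, word2ph_tuple: tuple) -> int:
    # stands in for the expensive BERT forward pass
    return sum(word2ph_tuple)

_expand("text", (2, 2, 1))   # computed and cached
_expand("text", (2, 2, 1))   # cache hit
print(_expand.cache_info())  # hits=1, misses=1
# _expand("text", [2, 2, 1]) # would raise TypeError: unhashable type: 'list'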


@ -0,0 +1,75 @@
from functools import lru_cache
from typing import List, Tuple
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
@lru_cache(maxsize=1000)
def get_cached_bert(norm_text: str, word2ph_tuple: tuple, device_str: str = "cuda"):
"""
缓存 BERT 提取函数用于相同 norm_text 时复用特征
Args:
norm_text (str): 清洗后的文本可复用
word2ph_tuple (tuple): word2ph 列表转换成 tuple因为 lru_cache 不支持 list
device_str (str): 设备信息用于转移到正确设备上
Returns:
Tensor: 形状 [hidden_dim, total_phonemes]
"""
# If used inside a class, switch these to self.tokenizer and self.model.
# Note: this standalone helper reloads the tokenizer and model on every cache miss;
# the CachedBertExtractor class below keeps them loaded instead.
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese", output_hidden_states=True).eval().to(device_str)
inputs = tokenizer(norm_text, return_tensors="pt").to(device_str)
with torch.no_grad():
outputs = model(**inputs)
hidden = torch.cat(outputs.hidden_states[-3:-2], dim=-1)[0][1:-1]  # drop the CLS/SEP tokens
word2ph = torch.tensor(list(word2ph_tuple), device=hidden.device)
indices = torch.repeat_interleave(torch.arange(len(word2ph), device=hidden.device), word2ph)
phone_level_feature = hidden[indices]
return phone_level_feature.T.cpu()
class CachedBertExtractor:
def __init__(self, model_name_or_path: str = "bert-base-chinese", device: str = "cuda"):
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.device = device
self.bert_model = AutoModelForMaskedLM.from_pretrained(
model_name_or_path, output_hidden_states=True
).eval().to(device)
def get_bert_feature(self, norm_text: str, word2ph: List[int]) -> torch.Tensor:
"""
Public method: gets cached BERT feature tensor
"""
word2ph_tuple = tuple(word2ph)
return self._cached_bert(norm_text, word2ph_tuple).to(self.device)
# lru_cache on an instance method also hashes `self` into the key and keeps the instance
# alive for the cache's lifetime; acceptable here for a single long-lived extractor.
@lru_cache(maxsize=1024)
def _cached_bert(self, norm_text: str, word2ph_tuple: Tuple[int, ...]) -> torch.Tensor:
"""
Cached private method: returns CPU tensor (for lru_cache compatibility)
"""
inputs = self.tokenizer(norm_text, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.bert_model(**inputs)
hidden = torch.cat(outputs.hidden_states[-3:-2], dim=-1)[0][1:-1] # shape: [seq_len-2, hidden_dim]
word2ph_tensor = torch.tensor(list(word2ph_tuple), device=self.device)
indices = torch.repeat_interleave(torch.arange(len(word2ph_tuple), device=self.device), word2ph_tensor)
phone_level_feature = hidden[indices] # [sum(word2ph), hidden_size]
return phone_level_feature.T.cpu() # cache-safe
def clear_cache(self):
"""
Clear the internal BERT feature cache
"""
self._cached_bert.cache_clear()
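Usage sketch for the class above (assumes the bert-base-chinese weights are reachable; device and inputs are illustrative only): repeated calls with the same norm_text/word2ph pair skip the BERT forward pass, and clear_cache() empties the memo.

extractor = CachedBertExtractor("bert-base-chinese", device="cpu")
feats = extractor.get_bert_feature("你好世界", [2, 2, 2, 2])         # runs BERT once
feats_cached = extractor.get_bert_feature("你好世界", [2, 2, 2, 2])  # served from lru_cache
print(feats.shape)       # torch.Size([768, 8]) -> [hidden_dim, sum(word2ph)]
extractor.clear_cache()  # e.g. before switching to different weights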


@ -20,6 +20,13 @@ import torch
import soundfile
from librosa.filters import mel as librosa_mel_fn
import time
import random
from tqdm import tqdm
from transformers import BertTokenizer
# tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
from inference_webui import get_spepc, norm_spec, resample, ssl_model
@ -921,6 +928,24 @@ def test_export1(
import time
@torch.jit.script
def build_phone_level_feature(res: torch.Tensor, word2ph: torch.IntTensor) -> torch.Tensor:
"""
将词级别的 BERT 特征转换为音素级别的特征通过 word2ph 指定每个词对应的音素数
Args:
res: [N_words, hidden_dim]
word2ph: [N_words], 每个元素表示当前词需要复制多少次即包含多少个音素
Returns:
[sum(word2ph), hidden_dim] phone 级别特征
"""
phone_level_feature = []
for i in range(word2ph.shape[0]):
repeat_feature = res[i].repeat(word2ph[i].item(), 1)
phone_level_feature.append(repeat_feature)
return torch.cat(phone_level_feature, dim=0)
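(Aside: a small equivalence sketch with made-up shapes, showing that the loop above matches a single torch.repeat_interleave call, which is essentially what cached1.py does via its index tensor.)

res = torch.randn(3, 1024)                        # 3 words, hidden_dim 1024
word2ph = torch.tensor([2, 1, 3], dtype=torch.int32)
looped = build_phone_level_feature(res, word2ph)
vectorized = torch.repeat_interleave(res, word2ph.long(), dim=0)
assert torch.equal(looped, vectorized)            # both are [sum(word2ph), hidden_dim] == [6, 1024]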
def test_():
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
@ -1010,6 +1035,23 @@ def test_():
# )
def extract_bert_features(texts: list, desc: str = "Extracting BERT text features"):
"""
Simulated BERT feature extraction: tokenizes each text, builds a placeholder tensor,
and sleeps briefly per item so the tqdm progress bar mimics real extraction.
Assumes a module-level `tokenizer` (see the commented-out BertTokenizer line above).
"""
# print(f"############ {desc} ############")
for text in tqdm(texts, desc=desc, unit="it"):
# tokenize
tokens = tokenizer.tokenize(text)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
fake_tensor = torch.randn(768, len(input_ids))
_ = fake_tensor.mean(dim=1)
delay = round(random.uniform(0.8, 1.6), 2)
time.sleep(delay)
def test_export_gpt_sovits_v3():
gpt_sovits_v3 = torch.jit.load("onnx/ad/gpt_sovits_v3.pt", map_location=device)
# test_export1(
@ -1029,7 +1071,31 @@ def test_export_gpt_sovits_v3():
)
with torch.no_grad():
# export()
test_()
# test_export_gpt_sovits_v3()
class MyBertModel(torch.nn.Module):
def __init__(self, bert_model):
super(MyBertModel, self).__init__()
self.bert = bert_model
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: torch.IntTensor):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True)
hidden_states = outputs.hidden_states
res = torch.cat(hidden_states[-3:-2], -1)[0][1:-1]  # drop the CLS and SEP tokens
phone_level_feature = []
for i in range(word2ph.shape[0]):
repeat_feature = res[i].repeat(word2ph[i].item(), 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
# with torch.no_grad():
# # export()
# # test_()
# # test_export_gpt_sovits_v3()
# print()


@ -69,8 +69,8 @@ if torch.cuda.is_available():
else:
device = "cpu"
# is_half = False
# device = "cpu"
is_half = True
device = "cpu"
dict_language_v1 = {
i18n("中文"): "all_zh", # 全部按中文识别

Binary file not shown.


@ -0,0 +1 @@
878b3caf4d1cd7c2927c26e85072a2f5


@ -0,0 +1,28 @@
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from export_torch_script_v3 import MyBertModel, build_phone_level_feature
bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
tokenizer = AutoTokenizer.from_pretrained(bert_path)
model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True)
# Build the wrapper model
wrapped_model = MyBertModel(model)
# Prepare example inputs (the sample sentence is intentionally Chinese, matching the model)
text = "这是一条用于导出TorchScript的示例文本"
encoded = tokenizer(text, return_tensors="pt")
# NOTE: this assumes one token per character; if the tokenizer splits the Latin word into
# several word pieces, word2ph must be adjusted so its length matches the token count (minus CLS/SEP).
word2ph = torch.tensor([2 if c not in ",。?!,.?" else 1 for c in text], dtype=torch.int)
# Pack everything into keyword inputs
example_inputs = {
"input_ids": encoded["input_ids"],
"attention_mask": encoded["attention_mask"],
"token_type_ids": encoded["token_type_ids"],
"word2ph": word2ph
}
# Trace the model and save the TorchScript artifact
traced = torch.jit.trace(wrapped_model, example_kwarg_inputs=example_inputs)
traced.save("pretrained_models/bert_script.pt")
print("✅ BERT TorchScript 模型导出完成!")


@ -15,6 +15,8 @@ cnhubert_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
bert_path = "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
pretrained_sovits_path = "GPT_SoVITS/pretrained_models/s2G488k.pth"
pretrained_gpt_path = "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"
# pretrained_sovits_path = "GPT_SoVITS/pretrained_models/Aris_e16_s272.pth.pth"
# pretrained_gpt_path = "GPT_SoVITS/pretrained_models/Aris_e15.ckpt"
exp_root = "logs"
python_exec = sys.executable or "python"

BIN
output.wav Normal file

Binary file not shown.

Binary file not shown.

test.py Normal file

@ -0,0 +1,45 @@
import requests
# API endpoint (local server)
url = "http://127.0.0.1:9880/tts"
# Request body (matches the POST schema defined in api_v2.py)
payload = {
"ref_audio_path": r"C:\Users\bdxly\Desktop\GPT-SoVITS\Arona_Academy_In_2.ogg.wav",
"prompt_text": "様々な授業やイベントが準備されているので、ご希望のスケジュールを選んでください!",
"prompt_lang": "ja",
"text": "这是我的失误。我的选择,和因它发生的这一切。 直到最后,迎来了这样的结局,我才明白您是对的。 …我知道,事到如今再来说这些,挺厚脸皮的。但还是拜托您了。老师。 我想,您一定会忘记我说的这些话,不过…没关系。因为就算您什么都不记得了,在相同的情况下,应该还是会做那样的选择吧…… 所以重要的不是经历,是选择。 很多很多,只有您才能做出的选择。 我们以前聊过……关于负责人之人的话题吧。 我当时不懂……但是现在,我能理解了。 身为大人的责任与义务。以及在其延长线上的,您所做出的选择。 甚至还有,您做出选择时的那份心情。…… 所以,老师。 您是我唯一可以信任的大人,我相信您一定能找到,通往与这条扭曲的终点截然不同的……另一个结局的正确选项。所以,老师,请您一定要",
"text_lang": "zh",
"top_k": 5,
"top_p": 1.0,
"temperature": 1.0,
"text_split_method": "cut0",
"batch_size": 1,
"batch_threshold": 0.75,
"split_bucket": True,
"speed_factor": 1.0,
"fragment_interval": 0.3,
"seed": -1,
"media_type": "wav",
"streaming_mode": False,
"parallel_infer": True,
"repetition_penalty": 1.35,
"sample_steps": 32,
"super_sampling": False
}
# Send the POST request
response = requests.post(url, json=payload)
# Check the response and save the audio
if response.status_code == 200:
with open("output.wav", "wb") as f:
f.write(response.content)
print(" 生成成功,保存为 output.wav")
else:
print(f" 生成失败: {response.status_code}, 返回信息: {response.text}")