mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2026-04-29 21:00:42 +08:00)

Compare commits: 0488d07bd2 ... 2ec561432b (9 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 2ec561432b | |
| | 938f05fce8 | |
| | 445d18ccce | |
| | 00ce973412 | |
| | 14191901cd | |
| | 780383d5bd | |
| | ba8de9b760 | |
| | 7f6787121b | |
| | 6e027ec111 | |
.gitignore (vendored): 5 additions

@@ -193,3 +193,8 @@ cython_debug/
# PyPI configuration file
.pypirc
/.vs
/GPT_SoVITS/configs/tts_infer.yaml
/GPT_SoVITS/configs/infer_settings.json
/last_selected_preset.json
/last_selected_models.json
@@ -67,8 +67,10 @@ class Text2SemanticDataset(Dataset):
            )
        )  # "%s/3-bert"%exp_dir#bert_dir
        self.path6 = semantic_path  # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path6)
        if not os.path.exists(self.path2):
            raise FileNotFoundError(f"Phoneme data file not found: {self.path2}")
        if not os.path.exists(self.path6):
            raise FileNotFoundError(f"Semantic data file not found: {self.path6}")
        self.phoneme_data = {}
        with open(self.path2, "r", encoding="utf8") as f:
            lines = f.read().strip("\n").split("\n")
@@ -131,7 +133,7 @@ class Text2SemanticDataset(Dataset):
                phoneme, word2ph, text = self.phoneme_data[item_name]
            except Exception:
                traceback.print_exc()
                # print(f"{item_name} not in self.phoneme_data !")
                print(f"Warning: File \"{item_name}\" not in self.phoneme_data! Skipped. ")
                num_not_in += 1
                continue
@@ -152,7 +154,7 @@ class Text2SemanticDataset(Dataset):
                phoneme_ids = cleaned_text_to_sequence(phoneme, version)
            except:
                traceback.print_exc()
                # print(f"{item_name} not in self.phoneme_data !")
                print(f"Warning: Failed to convert phonemes to sequence for file \"{item_name}\"! Skipped. ")
                num_not_in += 1
                continue
        # if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行
@@ -228,7 +230,11 @@ class Text2SemanticDataset(Dataset):
            # bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
            bert_feature = None
        else:
            assert bert_feature.shape[-1] == len(phoneme_ids)
            try:
                assert bert_feature.shape[-1] == len(phoneme_ids)
            except AssertionError:
                print(f"AssertionError: The BERT feature dimension ({bert_feature.shape[-1]}) of the file '{item_name}' does not match the length of the phoneme sequence ({len(phoneme_ids)}).")
                raise
        return {
            "idx": idx,
            "phoneme_ids": phoneme_ids,
@@ -262,7 +262,7 @@ def make_reject_y(y_o, y_lens):
    reject_y = []
    reject_y_lens = []
    for b in range(bs):
        process_item_idx = torch.randint(0, 1, size=(1,))[0]
        process_item_idx = torch.randint(0, 2, size=(1,))[0]
        if process_item_idx == 0:
            new_y = repeat_P(y_o[b])
            reject_y.append(new_y)
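Reviewer note: this one-character change matters because torch.randint's upper bound is exclusive, so randint(0, 1) can only ever return 0 and the second rejection-sample branch was unreachable. A minimal sketch of the difference (standalone, PyTorch only):

```python
import torch

# torch.randint(low, high, size) samples from [low, high): high is exclusive.
only_zeros = torch.randint(0, 1, size=(1000,))      # every element is 0
zeros_and_ones = torch.randint(0, 2, size=(1000,))  # mixture of 0 and 1

print(only_zeros.unique())      # tensor([0])
print(zeros_and_ones.unique())  # tensor([0, 1])
```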
@@ -499,7 +499,7 @@ class TTS:

        if if_lora_v3 == True and os.path.exists(path_sovits) == False:
            info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
            raise FileExistsError(info)
            raise FileNotFoundError(info)

        # dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
        dict_s2 = load_sovits_new(weights_path)
@@ -1578,16 +1578,15 @@ class TTS:
                max_audio = np.abs(audio).max()
                if max_audio > 1:
                    audio /= max_audio
                    audio = (audio * 32768).astype(np.int16)
                    audio = (audio * 32768).astype(np.int16)
            else:
                audio = audio.cpu().numpy()
                audio = (audio * 32768).astype(np.int16)
            t2 = time.perf_counter()
            print(f"超采样用时:{t2 - t1:.3f}s")
        else:
            # audio = audio.float() * 32768
            # audio = audio.to(dtype=torch.int16).clamp(-32768, 32767).cpu().numpy()

            audio = audio.cpu().numpy()

            audio = (audio * 32768).astype(np.int16)
            audio = (audio * 32768).astype(np.int16)


        # try:
@@ -1768,7 +1767,10 @@ class TTS:
                pos += chunk_len * upsample_rate

            audio = self.sola_algorithm(audio_fragments, overlapped_len * upsample_rate)
            audio = audio[overlapped_len * upsample_rate : -padding_len * upsample_rate]
            if padding_len > 0:
                audio = audio[overlapped_len * upsample_rate : -padding_len * upsample_rate]
            else:
                audio = audio[overlapped_len * upsample_rate :]

            audio_fragments = []
            for feat_len in feat_lens:
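Reviewer note: the new padding_len > 0 guard avoids a Python slicing pitfall. When padding_len is 0, -padding_len * upsample_rate evaluates to 0 and audio[start:0] is an empty slice, silently discarding the whole signal. A small sketch of the pitfall (array sizes are made up):

```python
import numpy as np

audio = np.arange(10)
overlapped, padding = 2, 0

broken = audio[overlapped : -padding]                                   # -0 == 0, empty slice
fixed = audio[overlapped : -padding] if padding > 0 else audio[overlapped:]

print(broken)  # []
print(fixed)   # [2 3 4 5 6 7 8 9]
```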
@@ -92,7 +92,7 @@ def cut0(inp):
    if not set(inp).issubset(punctuation):
        return inp
    else:
        return "/n"
        return "\n"


# 凑四句一切
@@ -1,56 +0,0 @@
custom:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cuda
  is_half: true
  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
  version: v2
  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
v1:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
  version: v1
  vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
v2:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
  version: v2
  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
v2Pro:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
  version: v2Pro
  vits_weights_path: GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth
v2ProPlus:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
  version: v2ProPlus
  vits_weights_path: GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth
v3:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
  version: v3
  vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
v4:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
  version: v4
  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth
File diff suppressed because it is too large
@@ -87,7 +87,7 @@ def sync_buffer(buffers, average=True):
    for buffer, handle in handles:
        handle.wait()
        if average:
            buffer.data /= world_size
            buffer.data /= world_size()


def sync_grad(params):
GPT_SoVITS/persistence_tools.py (new file, 425 lines)

@@ -0,0 +1,425 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
GPT-SoVITS 持久化工具类
|
||||
包含:模型配置、参考音频、推理参数 的持久化读写与管理
|
||||
抽离自主文件,减少主文件臃肿,方便后续维护
|
||||
"""
|
||||
import json
|
||||
import yaml
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
# ===================== 全局配置(统一管理所有持久化文件路径) =====================
|
||||
# 模型持久化配置文件
|
||||
LAST_SELECTED_MODELS_JSON = Path("./last_selected_models.json")
|
||||
# 参考预设最后选中配置文件
|
||||
LAST_SELECTED_PRESET_JSON = Path("./last_selected_preset.json")
|
||||
# 参考音频持久化目录
|
||||
REF_AUDIO_DIR = Path("GPT_SoVITS/ref_audios")
|
||||
# 参考预设配置文件
|
||||
REF_PRESETS_YAML = Path("GPT_SoVITS/configs/ref_audios_presets.yaml")
|
||||
# 推理参数配置文件
|
||||
INFER_SETTINGS_JSON = Path("GPT_SoVITS/configs/infer_settings.json")
|
||||
|
||||
# 参考音频配置常量
|
||||
MAX_FILENAME_LENGTH = 40
|
||||
INVALID_FILE_CHARS = set(r'\/:*?"<>|')
|
||||
|
||||
# 默认推理参数
|
||||
DEFAULT_INFER_SETTINGS = {
|
||||
"batch_size": 20,
|
||||
"sample_steps": 32,
|
||||
"fragment_interval": 0.2,
|
||||
"speed_factor": 1.0,
|
||||
"top_k": 5,
|
||||
"top_p": 1.0,
|
||||
"temperature": 1.0,
|
||||
"repetition_penalty": 1.35,
|
||||
"how_to_cut": "凑四句一切",
|
||||
"super_sampling": False,
|
||||
"parallel_infer": True,
|
||||
"split_bucket": True,
|
||||
"seed": -1,
|
||||
"keep_random": True
|
||||
}
|
||||
|
||||
# ===================== 通用工具函数(抽离重复逻辑) =====================
|
||||
def sanitize_filename(name):
|
||||
"""清理文件名中的非法字符,替换为下划线"""
|
||||
if not name:
|
||||
return "unnamed_preset"
|
||||
return ''.join(c if c not in INVALID_FILE_CHARS else '_' for c in name)
|
||||
|
||||
def get_audio_md5(file_path, chunk_size=4096):
|
||||
"""计算音频文件的MD5值(取前8位),用于区分不同音频内容"""
|
||||
if not os.path.exists(file_path):
|
||||
return "invalid_file"
|
||||
try:
|
||||
md5 = hashlib.md5()
|
||||
with open(file_path, 'rb') as f:
|
||||
while chunk := f.read(chunk_size):
|
||||
md5.update(chunk)
|
||||
return md5.hexdigest()[:8]
|
||||
except Exception as e:
|
||||
print(f"计算音频MD5失败:{e}")
|
||||
return f"err_{random.randint(10000000, 99999999)}"
|
||||
|
||||
def ensure_dir_exists(dir_path):
|
||||
"""确保目录存在,不存在则创建"""
|
||||
if dir_path and not dir_path.exists():
|
||||
dir_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# ===================== 1. 模型配置持久化(last_selected_models.json) =====================
|
||||
def init_last_selected_models(gpt_default, sovits_default, current_version):
|
||||
"""初始化模型配置文件,写入默认模型路径"""
|
||||
ensure_dir_exists(LAST_SELECTED_MODELS_JSON.parent)
|
||||
init_data = {
|
||||
"gpt_model_path": gpt_default,
|
||||
"sovits_model_path": sovits_default,
|
||||
"version": current_version
|
||||
}
|
||||
with open(LAST_SELECTED_MODELS_JSON, "w", encoding="utf-8") as f:
|
||||
json.dump(init_data, f, ensure_ascii=False, indent=4)
|
||||
print(f"首次生成模型配置文件:{LAST_SELECTED_MODELS_JSON}")
|
||||
return init_data
|
||||
|
||||
def read_last_selected_models():
|
||||
"""读取模型配置文件中的路径"""
|
||||
if not LAST_SELECTED_MODELS_JSON.exists():
|
||||
return None
|
||||
try:
|
||||
with open(LAST_SELECTED_MODELS_JSON, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
# 校验必要字段
|
||||
required_fields = ["gpt_model_path", "sovits_model_path", "version"]
|
||||
for field in required_fields:
|
||||
if field not in data:
|
||||
return None
|
||||
return data
|
||||
except Exception as e:
|
||||
print(f"读取模型配置失败:{e}")
|
||||
return None
|
||||
|
||||
def write_last_selected_models(gpt_path_new, sovits_path_new, current_version):
|
||||
"""写入新的模型路径到配置文件"""
|
||||
ensure_dir_exists(LAST_SELECTED_MODELS_JSON.parent)
|
||||
try:
|
||||
data = read_last_selected_models() or {}
|
||||
data["gpt_model_path"] = gpt_path_new
|
||||
data["sovits_model_path"] = sovits_path_new
|
||||
data["version"] = current_version
|
||||
with open(LAST_SELECTED_MODELS_JSON, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=4)
|
||||
except Exception as e:
|
||||
print(f"写入模型配置失败:{e}")
|
||||
|
||||
# ===================== 2. 参考音频预设持久化(last_selected_preset.json + ref_audios_presets.yaml) =====================
|
||||
# 2.1 最后选中预设的读写清
|
||||
def read_last_selected_preset():
|
||||
"""读取最后一次选中的预设名称"""
|
||||
if not LAST_SELECTED_PRESET_JSON.exists():
|
||||
return None
|
||||
try:
|
||||
with open(LAST_SELECTED_PRESET_JSON, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return data.get("last_selected_preset")
|
||||
except Exception as e:
|
||||
print(f"读取最后选中预设失败:{e}")
|
||||
return None
|
||||
|
||||
def write_last_selected_preset(preset_name):
|
||||
"""写入最后一次选中的预设名称"""
|
||||
ensure_dir_exists(LAST_SELECTED_PRESET_JSON.parent)
|
||||
try:
|
||||
data = {"last_selected_preset": preset_name.strip()}
|
||||
with open(LAST_SELECTED_PRESET_JSON, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=4)
|
||||
print(f"已记录最后选中的预设:{preset_name.strip()}")
|
||||
except Exception as e:
|
||||
print(f"写入最后选中预设失败:{e}")
|
||||
|
||||
def clear_last_selected_preset():
|
||||
"""清空最后选中的预设记录"""
|
||||
if not LAST_SELECTED_PRESET_JSON.exists():
|
||||
return
|
||||
try:
|
||||
with open(LAST_SELECTED_PRESET_JSON, "w", encoding="utf-8") as f:
|
||||
json.dump({"last_selected_preset": ""}, f, ensure_ascii=False, indent=4)
|
||||
except Exception as e:
|
||||
print(f"清空最后选中预设失败:{e}")
|
||||
|
||||
# 2.2 参考预设配置的加载/保存/删除
|
||||
def load_ref_presets():
|
||||
"""加载多组参考预设配置"""
|
||||
ensure_dir_exists(REF_PRESETS_YAML.parent)
|
||||
|
||||
# 新增:配置文件不存在时,自动创建空文件
|
||||
if not REF_PRESETS_YAML.exists():
|
||||
with open(REF_PRESETS_YAML, "w", encoding="utf-8") as f:
|
||||
yaml.dump([], f, indent=4, allow_unicode=True)
|
||||
print(f"暂未检测到参考预设配置文件,已自动创建空文件:{REF_PRESETS_YAML}")
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(REF_PRESETS_YAML, "r", encoding="utf-8") as f:
|
||||
presets = yaml.load(f, Loader=yaml.FullLoader) or []
|
||||
|
||||
# 兼容旧格式转换
|
||||
if isinstance(presets, dict):
|
||||
presets = [{"name": "旧配置转换", "ref_audio_path": presets.get("ref_audio_path"),
|
||||
"prompt_text": presets.get("prompt_text", ""), "prompt_language": presets.get("prompt_language", "中文")}]
|
||||
|
||||
# 补充缺失字段 + 校验音频路径
|
||||
default_template = {"name": "", "ref_audio_path": None, "prompt_text": "", "prompt_language": "中文"}
|
||||
for preset in presets:
|
||||
for key, value in default_template.items():
|
||||
preset.setdefault(key, value)
|
||||
# 校验音频路径有效性
|
||||
audio_path = preset["ref_audio_path"]
|
||||
if audio_path and not os.path.exists(str(audio_path)):
|
||||
preset["ref_audio_path"] = None
|
||||
|
||||
# 清理冗余音频
|
||||
clean_unreferenced_audios(presets)
|
||||
print(f"参考预设加载成功,共 {len(presets)} 组")
|
||||
return presets
|
||||
except Exception as e:
|
||||
print(f"参考预设加载失败:{e}")
|
||||
return []
|
||||
|
||||
def get_preset_by_name(preset_name, presets=None):
|
||||
"""根据配置名称查询对应的配置详情"""
|
||||
# 核心修复:先判断 preset_name 是否为 None,避免 AttributeError
|
||||
if preset_name is None:
|
||||
return {"name": "", "ref_audio_path": None, "prompt_text": "", "prompt_language": "中文"}
|
||||
|
||||
if not presets:
|
||||
presets = load_ref_presets()
|
||||
|
||||
# 现在再调用 strip(),确保 preset_name 不是 None
|
||||
preset_name_str = preset_name.strip()
|
||||
for preset in presets:
|
||||
if preset["name"].strip() == preset_name_str:
|
||||
return preset
|
||||
|
||||
# 无匹配预设时,返回空的合法预设字典
|
||||
return {"name": "", "ref_audio_path": None, "prompt_text": "", "prompt_language": "中文"}
|
||||
|
||||
def save_ref_preset_core(preset_name, ref_audio_path, prompt_text, prompt_language, confirm_override=False):
|
||||
"""保存/覆盖参考预设核心逻辑(返回:提示信息、是否成功、预设列表)"""
|
||||
ensure_dir_exists(REF_AUDIO_DIR)
|
||||
presets = load_ref_presets()
|
||||
preset_name = preset_name.strip()
|
||||
|
||||
# 前置校验
|
||||
if not ref_audio_path or not os.path.exists(str(ref_audio_path)):
|
||||
return "保存失败!请先上传有效的主参考音频文件。", False, [p["name"] for p in presets]
|
||||
if not preset_name:
|
||||
return "保存失败!配置名称不能为空。", False, [p["name"] for p in presets]
|
||||
|
||||
# 音频持久化处理
|
||||
persistent_audio_path = get_persistent_audio_path(ref_audio_path, preset_name)
|
||||
if not persistent_audio_path:
|
||||
return "保存失败!音频文件持久化存储失败。", False, [p["name"] for p in presets]
|
||||
|
||||
# 同名检测
|
||||
preset_index = -1
|
||||
for idx, p in enumerate(presets):
|
||||
if p["name"].strip() == preset_name:
|
||||
preset_index = idx
|
||||
break
|
||||
|
||||
if preset_index >= 0 and not confirm_override:
|
||||
return f"配置「{preset_name}」已存在,如需替换请确认覆盖!", False, [p["name"] for p in presets]
|
||||
|
||||
# 构造新配置
|
||||
new_preset = {
|
||||
"name": preset_name,
|
||||
"ref_audio_path": persistent_audio_path,
|
||||
"prompt_text": prompt_text,
|
||||
"prompt_language": prompt_language
|
||||
}
|
||||
|
||||
# 更新配置列表
|
||||
is_new_preset = preset_index < 0
|
||||
if preset_index >= 0:
|
||||
presets[preset_index] = new_preset
|
||||
tip = "同名配置已覆盖!"
|
||||
else:
|
||||
presets.append(new_preset)
|
||||
tip = "新配置已新增!"
|
||||
|
||||
# 写入配置文件
|
||||
try:
|
||||
with open(REF_PRESETS_YAML, "w", encoding="utf-8") as f:
|
||||
yaml.dump(presets, f, indent=4, allow_unicode=True)
|
||||
|
||||
# 新增预设自动记录为最后选中
|
||||
if is_new_preset:
|
||||
write_last_selected_preset(preset_name)
|
||||
|
||||
preset_names = [p["name"] for p in presets]
|
||||
return f"配置保存成功!{tip}", True, preset_names
|
||||
except Exception as e:
|
||||
return f"保存失败:{str(e)}", False, [p["name"] for p in presets]
|
||||
|
||||
def delete_ref_preset_core(preset_name):
|
||||
"""删除参考预设核心逻辑(返回:提示信息、预设列表、默认选中预设)"""
|
||||
presets = load_ref_presets()
|
||||
preset_name = preset_name.strip()
|
||||
|
||||
if not presets:
|
||||
return "暂无配置可删除!", [], None
|
||||
|
||||
# 获取待删除音频路径
|
||||
target_audio_path = None
|
||||
for p in presets:
|
||||
if p["name"].strip() == preset_name:
|
||||
target_audio_path = p.get("ref_audio_path")
|
||||
break
|
||||
|
||||
# 过滤删除
|
||||
presets = [p for p in presets if p["name"].strip() != preset_name]
|
||||
|
||||
# 写入配置文件
|
||||
try:
|
||||
with open(REF_PRESETS_YAML, "w", encoding="utf-8") as f:
|
||||
yaml.dump(presets, f, indent=4, allow_unicode=True)
|
||||
|
||||
# 删除对应音频
|
||||
if target_audio_path and os.path.exists(target_audio_path):
|
||||
os.unlink(target_audio_path)
|
||||
print(f"同步删除配置对应音频:{target_audio_path}")
|
||||
|
||||
# 清空最后选中记录(若删除的是最后选中的预设)
|
||||
last_selected = read_last_selected_preset()
|
||||
if last_selected and last_selected == preset_name:
|
||||
clear_last_selected_preset()
|
||||
|
||||
preset_names = [p["name"] for p in presets]
|
||||
new_selected = preset_names[0] if preset_names else None
|
||||
tip = "配置删除成功!已同步清理对应音频文件" if preset_names else "配置删除成功!已同步清理对应音频文件,当前无剩余配置"
|
||||
return tip, preset_names, new_selected
|
||||
except Exception as e:
|
||||
return f"删除失败:{str(e)}", [p["name"] for p in presets], preset_name
|
||||
|
||||
# 2.3 参考音频文件管理
|
||||
def get_persistent_audio_path(src_audio_path, preset_name):
|
||||
"""获取音频持久化路径,清理同配置名旧音频"""
|
||||
if not src_audio_path or not os.path.exists(src_audio_path):
|
||||
return None
|
||||
|
||||
# 清理文件名
|
||||
safe_preset_name = sanitize_filename(preset_name)
|
||||
safe_preset_name = safe_preset_name[:MAX_FILENAME_LENGTH]
|
||||
|
||||
# 提取后缀
|
||||
src_suffix = Path(src_audio_path).suffix.lower()
|
||||
if not src_suffix or src_suffix not in [".wav", ".mp3", ".flac", ".ogg", ".m4a"]:
|
||||
src_suffix = ".wav"
|
||||
|
||||
# 计算MD5
|
||||
audio_md5 = get_audio_md5(src_audio_path)
|
||||
dst_filename = f"{safe_preset_name}_{audio_md5}{src_suffix}"
|
||||
dst_path = REF_AUDIO_DIR / dst_filename
|
||||
|
||||
# 清理同配置名旧音频
|
||||
for old_audio in REF_AUDIO_DIR.glob(f"{safe_preset_name}_*"):
|
||||
if old_audio.suffix.lower() in [".wav", ".mp3", ".flac", ".ogg", ".m4a"]:
|
||||
try:
|
||||
old_audio.unlink()
|
||||
except Exception as e:
|
||||
print(f"清理旧音频失败:{e}")
|
||||
|
||||
# 复制新音频
|
||||
try:
|
||||
shutil.copy2(src_audio_path, dst_path)
|
||||
return str(dst_path)
|
||||
except Exception as e:
|
||||
print(f"音频持久化复制失败:{e}")
|
||||
return None
|
||||
|
||||
def clean_unreferenced_audios(presets):
|
||||
"""清理未被任何预设引用的冗余音频"""
|
||||
if not REF_AUDIO_DIR.exists():
|
||||
return
|
||||
|
||||
# 收集已引用音频
|
||||
referenced = set()
|
||||
for preset in presets:
|
||||
audio_path = preset.get("ref_audio_path")
|
||||
if audio_path and os.path.exists(audio_path):
|
||||
referenced.add(Path(audio_path).absolute())
|
||||
|
||||
# 删除未引用音频
|
||||
deleted_count = 0
|
||||
for audio_file in REF_AUDIO_DIR.glob("*"):
|
||||
if audio_file.is_file() and audio_file.suffix.lower() in [".wav", ".mp3", ".flac", ".ogg", ".m4a"]:
|
||||
if audio_file.absolute() not in referenced:
|
||||
try:
|
||||
audio_file.unlink()
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
print(f"清理冗余音频失败:{e}")
|
||||
|
||||
if deleted_count > 0:
|
||||
print(f"清理冗余未引用音频 {deleted_count} 个")
|
||||
|
||||
# ===================== 3. 推理参数持久化(infer_settings.json) =====================
|
||||
def load_infer_settings():
|
||||
"""加载推理参数配置"""
|
||||
ensure_dir_exists(INFER_SETTINGS_JSON.parent)
|
||||
if not INFER_SETTINGS_JSON.exists():
|
||||
return DEFAULT_INFER_SETTINGS
|
||||
try:
|
||||
with open(INFER_SETTINGS_JSON, "r", encoding="utf-8") as f:
|
||||
saved = json.load(f)
|
||||
return {**DEFAULT_INFER_SETTINGS, **saved}
|
||||
except Exception as e:
|
||||
print(f"加载推理参数失败,使用默认值:{e}")
|
||||
return DEFAULT_INFER_SETTINGS
|
||||
|
||||
def save_infer_settings_core(settings):
|
||||
"""保存推理参数核心逻辑(返回:提示信息)"""
|
||||
ensure_dir_exists(INFER_SETTINGS_JSON.parent)
|
||||
try:
|
||||
with open(INFER_SETTINGS_JSON, "w", encoding="utf-8") as f:
|
||||
json.dump(settings, f, indent=4, ensure_ascii=False)
|
||||
|
||||
# 精简日志输出
|
||||
print(f"✅ 推理配置保存成功:{INFER_SETTINGS_JSON.absolute()}")
|
||||
return "推理设置保存成功!已覆盖原有配置文件。"
|
||||
except Exception as e:
|
||||
print(f"❌ 推理配置保存失败:{e}")
|
||||
return f"推理设置保存失败:{str(e)}"
|
||||
|
||||
def restore_default_infer_settings_core():
|
||||
"""恢复推理参数默认值核心逻辑(返回:默认参数列表)"""
|
||||
ensure_dir_exists(INFER_SETTINGS_JSON.parent)
|
||||
try:
|
||||
with open(INFER_SETTINGS_JSON, "w", encoding="utf-8") as f:
|
||||
json.dump(DEFAULT_INFER_SETTINGS, f, indent=4, ensure_ascii=False)
|
||||
print(f"✅ 推理配置已恢复默认值:{INFER_SETTINGS_JSON.absolute()}")
|
||||
except Exception as e:
|
||||
print(f"❌ 推理配置恢复默认失败:{e}")
|
||||
|
||||
# 返回默认参数(按顺序对应UI组件)
|
||||
return [
|
||||
DEFAULT_INFER_SETTINGS["batch_size"],
|
||||
DEFAULT_INFER_SETTINGS["sample_steps"],
|
||||
DEFAULT_INFER_SETTINGS["fragment_interval"],
|
||||
DEFAULT_INFER_SETTINGS["speed_factor"],
|
||||
DEFAULT_INFER_SETTINGS["top_k"],
|
||||
DEFAULT_INFER_SETTINGS["top_p"],
|
||||
DEFAULT_INFER_SETTINGS["temperature"],
|
||||
DEFAULT_INFER_SETTINGS["repetition_penalty"],
|
||||
DEFAULT_INFER_SETTINGS["how_to_cut"],
|
||||
DEFAULT_INFER_SETTINGS["super_sampling"],
|
||||
DEFAULT_INFER_SETTINGS["parallel_infer"],
|
||||
DEFAULT_INFER_SETTINGS["split_bucket"],
|
||||
DEFAULT_INFER_SETTINGS["seed"],
|
||||
DEFAULT_INFER_SETTINGS["keep_random"]
|
||||
]
|
||||
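Reviewer note: a hedged usage sketch of the new persistence module, based only on the function signatures visible above (the preset name, audio path, and weight paths below are made-up example values):

```python
import persistence_tools as pt  # assumes GPT_SoVITS/ is on sys.path, as the WebUI sets it up (assumption)

# Inference parameters: defaults merged with whatever was saved earlier.
settings = pt.load_infer_settings()
settings["top_k"] = 10
print(pt.save_infer_settings_core(settings))

# Reference-audio presets: save (or confirm an overwrite), then look one up.
msg, ok, names = pt.save_ref_preset_core(
    preset_name="demo_speaker",         # hypothetical preset name
    ref_audio_path="path/to/ref.wav",   # hypothetical audio path
    prompt_text="参考文本",
    prompt_language="中文",
)
preset = pt.get_preset_by_name("demo_speaker")

# Last selected GPT/SoVITS weights survive a WebUI restart.
pt.write_last_selected_models("GPT_weights/demo.ckpt", "SoVITS_weights/demo.pth", "v2Pro")
print(pt.read_last_selected_models())
```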
@@ -55,6 +55,10 @@ def main():
        n_gpus = torch.cuda.device_count()
    else:
        n_gpus = 1
    if n_gpus <= 1:
        run(0, n_gpus, hps)
        return

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))
@@ -77,12 +81,14 @@ def run(rank, n_gpus, hps):
        writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))

    dist.init_process_group(
        backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
        init_method="env://?use_libuv=False",
        world_size=n_gpus,
        rank=rank,
    )
    use_ddp = n_gpus > 1
    if use_ddp:
        dist.init_process_group(
            backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
            init_method="env://?use_libuv=False",
            world_size=n_gpus,
            rank=rank,
        )
    torch.manual_seed(hps.train.seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
@@ -118,15 +124,20 @@ def run(rank, n_gpus, hps):
        shuffle=True,
    )
    collate_fn = TextAudioSpeakerCollate()
    train_loader = DataLoader(
        train_dataset,
        num_workers=5,
    worker_count = 0 if os.name == "nt" and n_gpus <= 1 else min(2 if os.name == "nt" else 5, os.cpu_count() or 1)
    loader_kwargs = dict(
        num_workers=worker_count,
        shuffle=False,
        pin_memory=True,
        pin_memory=torch.cuda.is_available(),
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=3,
    )
    if worker_count > 0:
        loader_kwargs["persistent_workers"] = True
        loader_kwargs["prefetch_factor"] = 2 if os.name == "nt" else 3
    train_loader = DataLoader(
        train_dataset,
        **loader_kwargs,
    )
    save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
    os.makedirs(save_root, exist_ok=True)
@@ -156,7 +167,9 @@ def run(rank, n_gpus, hps):

    def model2cuda(net_g, rank):
        if torch.cuda.is_available():
            net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
            net_g = net_g.cuda(rank)
            if use_ddp:
                net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
        else:
            net_g = net_g.to(device)
        return net_g
@@ -242,6 +255,8 @@ def run(rank, n_gpus, hps):
                None,
            )
        scheduler_g.step()
    if use_ddp and dist.is_initialized():
        dist.destroy_process_group()
    print("training done")
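Reviewer note: the training-loop changes gate both init_process_group and DDP wrapping behind use_ddp = n_gpus > 1, and pick DataLoader worker settings per platform. A standalone sketch of the worker-count rule exactly as written above (the platform/CPU values are illustrative):

```python
def pick_worker_count(n_gpus: int, cpu_count: int, windows: bool) -> int:
    # Mirrors the diff: no workers for a single-GPU Windows run,
    # otherwise 2 workers on Windows and up to 5 elsewhere, capped by CPU count.
    if windows and n_gpus <= 1:
        return 0
    return min(2 if windows else 5, cpu_count or 1)

print(pick_worker_count(1, 8, windows=True))    # 0 -> in-process loading, no persistent workers
print(pick_worker_count(2, 8, windows=True))    # 2
print(pick_worker_count(1, 16, windows=False))  # 5
```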
@@ -180,10 +180,15 @@ def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) ->
def _g2p(segments):
    phones_list = []
    word2ph = []
    for seg in segments:
    g2pw_batch_results = []
    g2pw_batch_cursor = 0
    processed_segments = [re.sub("[a-zA-Z]+", "", seg) for seg in segments]
    if is_g2pw:
        batch_inputs = [seg for seg in processed_segments if seg]
        g2pw_batch_results = g2pw._g2pw(batch_inputs) if batch_inputs else []

    for seg in processed_segments:
        pinyins = []
        # Replace all English words in the sentence
        seg = re.sub("[a-zA-Z]+", "", seg)
        seg_cut = psg.lcut(seg)
        seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
        initials = []
@@ -204,8 +209,10 @@ def _g2p(segments):
            finals = sum(finals, [])
            print("pypinyin结果", initials, finals)
        else:
            # g2pw采用整句推理
            pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3)
            # g2pw采用整句推理(批量推理,逐句取结果)
            if seg:
                pinyins = g2pw_batch_results[g2pw_batch_cursor]
                g2pw_batch_cursor += 1

        pre_word_length = 0
        for word, pos in seg_cut:
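Reviewer note: the frontend now strips English once, runs g2pw over all non-empty segments in a single batch, and then walks the results with a cursor instead of calling the model per segment. A minimal sketch of the cursor pattern (the batch function below is a stand-in, not the real g2pw call):

```python
segments = ["你好世界", "", "第二句"]

def fake_batch_g2p(batch):
    # stand-in for the batched g2pw inference; one result per non-empty input
    return [f"pinyin({s})" for s in batch]

batch_inputs = [seg for seg in segments if seg]
batch_results = fake_batch_g2p(batch_inputs) if batch_inputs else []

cursor = 0
for seg in segments:
    if seg:                      # empty segments consume no batch slot
        result = batch_results[cursor]
        cursor += 1
    else:
        result = None
    print(seg, "->", result)
```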
@@ -18,6 +18,7 @@ Credits

from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
@@ -37,6 +38,8 @@ def prepare_onnx_input(
    use_mask: bool = False,
    window_size: int = None,
    max_len: int = 512,
    char2id: Optional[Dict[str, int]] = None,
    char_phoneme_masks: Optional[Dict[str, List[int]]] = None,
) -> Dict[str, np.array]:
    if window_size is not None:
        truncated_texts, truncated_query_ids = _truncate_texts(
@@ -48,33 +51,88 @@ def prepare_onnx_input(
    phoneme_masks = []
    char_ids = []
    position_ids = []
    tokenized_cache = {}

    if char2id is None:
        char2id = {char: idx for idx, char in enumerate(chars)}
    if use_mask:
        if char_phoneme_masks is None:
            char_phoneme_masks = {
                char: [1 if i in char2phonemes[char] else 0 for i in range(len(labels))]
                for char in char2phonemes
            }
    else:
        full_phoneme_mask = [1] * len(labels)

    for idx in range(len(texts)):
        text = (truncated_texts if window_size else texts)[idx].lower()
        query_id = (truncated_query_ids if window_size else query_ids)[idx]

        try:
            tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
        except Exception:
            print(f'warning: text "{text}" is invalid')
            return {}
        cached = tokenized_cache.get(text)
        if cached is None:
            try:
                tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
            except Exception:
                print(f'warning: text "{text}" is invalid')
                return {}

        text, query_id, tokens, text2token, token2text = _truncate(
            max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
        )
            if len(tokens) <= max_len - 2:
                processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
                shared_input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
                shared_token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
                shared_attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
                cached = {
                    "is_short": True,
                    "tokens": tokens,
                    "text2token": text2token,
                    "token2text": token2text,
                    "input_id": shared_input_id,
                    "token_type_id": shared_token_type_id,
                    "attention_mask": shared_attention_mask,
                }
            else:
                cached = {
                    "is_short": False,
                    "tokens": tokens,
                    "text2token": text2token,
                    "token2text": token2text,
                }
            tokenized_cache[text] = cached

        processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
        if cached["is_short"]:
            text_for_query = text
            query_id_for_query = query_id
            text2token_for_query = cached["text2token"]
            input_id = cached["input_id"]
            token_type_id = cached["token_type_id"]
            attention_mask = cached["attention_mask"]
        else:
            (
                text_for_query,
                query_id_for_query,
                tokens_for_query,
                text2token_for_query,
                _token2text_for_query,
            ) = _truncate(
                max_len=max_len,
                text=text,
                query_id=query_id,
                tokens=cached["tokens"],
                text2token=cached["text2token"],
                token2text=cached["token2text"],
            )
            processed_tokens = ["[CLS]"] + tokens_for_query + ["[SEP]"]
            input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
            token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
            attention_mask = list(np.ones((len(processed_tokens),), dtype=int))

        input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
        token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
        attention_mask = list(np.ones((len(processed_tokens),), dtype=int))

        query_char = text[query_id]
        phoneme_mask = (
            [1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels)
        )
        char_id = chars.index(query_char)
        position_id = text2token[query_id] + 1  # [CLS] token locate at first place
        query_char = text_for_query[query_id_for_query]
        if use_mask:
            phoneme_mask = char_phoneme_masks[query_char]
        else:
            phoneme_mask = full_phoneme_mask
        char_id = char2id[query_char]
        position_id = text2token_for_query[query_id_for_query] + 1  # [CLS] token locate at first place

        input_ids.append(input_id)
        token_type_ids.append(token_type_id)
@@ -83,10 +141,15 @@ def prepare_onnx_input(
        char_ids.append(char_id)
        position_ids.append(position_id)

    max_token_length = max(len(seq) for seq in input_ids)

    def _pad_sequences(sequences, pad_value=0):
        return [seq + [pad_value] * (max_token_length - len(seq)) for seq in sequences]

    outputs = {
        "input_ids": np.array(input_ids).astype(np.int64),
        "token_type_ids": np.array(token_type_ids).astype(np.int64),
        "attention_masks": np.array(attention_masks).astype(np.int64),
        "input_ids": np.array(_pad_sequences(input_ids, pad_value=0)).astype(np.int64),
        "token_type_ids": np.array(_pad_sequences(token_type_ids, pad_value=0)).astype(np.int64),
        "attention_masks": np.array(_pad_sequences(attention_masks, pad_value=0)).astype(np.int64),
        "phoneme_masks": np.array(phoneme_masks).astype(np.float32),
        "char_ids": np.array(char_ids).astype(np.int64),
        "position_ids": np.array(position_ids).astype(np.int64),
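Reviewer note: because queries against different sentences can now reuse cached tokenizations, a batch may mix sequences of different lengths, so the ONNX inputs are padded to the longest sequence before stacking. A small sketch of the same padding idea (toy token ids):

```python
import numpy as np

input_ids = [[101, 1, 2, 102], [101, 7, 8, 9, 10, 102]]

max_token_length = max(len(seq) for seq in input_ids)

def _pad_sequences(sequences, pad_value=0):
    return [seq + [pad_value] * (max_token_length - len(seq)) for seq in sequences]

batch = np.array(_pad_sequences(input_ids)).astype(np.int64)
print(batch.shape)  # (2, 6), rectangular and ready for onnxruntime
```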
@@ -10,7 +10,6 @@ from typing import Any, Dict, List, Tuple
import numpy as np
import onnxruntime
import requests
import torch
from opencc import OpenCC
from pypinyin import Style, pinyin
from transformers.models.auto.tokenization_auto import AutoTokenizer
@@ -22,9 +21,8 @@ from .utils import load_config
onnxruntime.set_default_logger_severity(3)
try:
    onnxruntime.preload_dlls()
except:
except Exception:
    pass
    # traceback.print_exc()
warnings.filterwarnings("ignore")

model_version = "1.1"
@@ -55,6 +53,24 @@ def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[Lis
    return all_preds, all_confidences


def _load_json_from_candidates(filename: str, candidate_dirs: List[str]) -> Dict[str, Any]:
    for candidate_dir in candidate_dirs:
        if not candidate_dir:
            continue
        json_path = os.path.join(candidate_dir, filename)
        if os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as fr:
                return json.load(fr)
    raise FileNotFoundError(f"Cannot locate {filename} in candidate dirs: {candidate_dirs}")


def _find_first_existing_file(*paths: str) -> str:
    for path in paths:
        if path and os.path.exists(path):
            return path
    raise FileNotFoundError(f"Files not found: {paths}")


def download_and_decompress(model_dir: str = "G2PWModel/"):
    if not os.path.exists(model_dir):
        parent_directory = os.path.dirname(model_dir)
@@ -62,7 +78,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"):
        extract_dir = os.path.join(parent_directory, "G2PWModel_1.1")
        extract_dir_new = os.path.join(parent_directory, "G2PWModel")
        print("Downloading g2pw model...")
        modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip"  # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
        modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip"
        with requests.get(modelscope_url, stream=True) as r:
            r.raise_for_status()
            with open(zip_dir, "wb") as f:
@@ -79,7 +95,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"):
    return model_dir


class G2PWOnnxConverter:
class _G2PWBaseOnnxConverter:
    def __init__(
        self,
        model_dir: str = "G2PWModel/",
@ -87,33 +103,16 @@ class G2PWOnnxConverter:
|
||||
model_source: str = None,
|
||||
enable_non_tradional_chinese: bool = False,
|
||||
):
|
||||
uncompress_path = download_and_decompress(model_dir)
|
||||
|
||||
sess_options = onnxruntime.SessionOptions()
|
||||
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
||||
sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0
|
||||
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
|
||||
self.session_g2pW = onnxruntime.InferenceSession(
|
||||
os.path.join(uncompress_path, "g2pW.onnx"),
|
||||
sess_options=sess_options,
|
||||
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
)
|
||||
else:
|
||||
self.session_g2pW = onnxruntime.InferenceSession(
|
||||
os.path.join(uncompress_path, "g2pW.onnx"),
|
||||
sess_options=sess_options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True)
|
||||
self.model_dir = download_and_decompress(model_dir)
|
||||
self.config = load_config(config_path=os.path.join(self.model_dir, "config.py"), use_default=True)
|
||||
|
||||
self.model_source = model_source if model_source else self.config.model_source
|
||||
self.enable_opencc = enable_non_tradional_chinese
|
||||
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
|
||||
|
||||
polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt")
|
||||
monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt")
|
||||
polyphonic_chars_path = os.path.join(self.model_dir, "POLYPHONIC_CHARS.txt")
|
||||
monophonic_chars_path = os.path.join(self.model_dir, "MONOPHONIC_CHARS.txt")
|
||||
|
||||
self.polyphonic_chars = [
|
||||
line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n")
|
||||
]
|
||||
@ -149,31 +148,47 @@ class G2PWOnnxConverter:
|
||||
)
|
||||
|
||||
self.chars = sorted(list(self.char2phonemes.keys()))
|
||||
self.char2id = {char: idx for idx, char in enumerate(self.chars)}
|
||||
self.char_phoneme_masks = (
|
||||
{
|
||||
char: [1 if i in self.char2phonemes[char] else 0 for i in range(len(self.labels))]
|
||||
for char in self.char2phonemes
|
||||
}
|
||||
if self.config.use_mask
|
||||
else None
|
||||
)
|
||||
|
||||
self.polyphonic_chars_new = set(self.chars)
|
||||
for char in self.non_polyphonic:
|
||||
if char in self.polyphonic_chars_new:
|
||||
self.polyphonic_chars_new.remove(char)
|
||||
self.polyphonic_chars_new.discard(char)
|
||||
|
||||
self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars}
|
||||
for char in self.non_monophonic:
|
||||
if char in self.monophonic_chars_dict:
|
||||
self.monophonic_chars_dict.pop(char)
|
||||
self.monophonic_chars_dict.pop(char, None)
|
||||
|
||||
self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"]
|
||||
default_asset_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "G2PWModel"))
|
||||
candidate_asset_dirs = [self.model_dir, default_asset_dir]
|
||||
self.bopomofo_convert_dict = _load_json_from_candidates(
|
||||
"bopomofo_to_pinyin_wo_tune_dict.json", candidate_asset_dirs
|
||||
)
|
||||
self.char_bopomofo_dict = _load_json_from_candidates("char_bopomofo_dict.json", candidate_asset_dirs)
|
||||
|
||||
with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr:
|
||||
self.bopomofo_convert_dict = json.load(fr)
|
||||
self.style_convert_func = {
|
||||
"bopomofo": lambda x: x,
|
||||
"pinyin": self._convert_bopomofo_to_pinyin,
|
||||
}[style]
|
||||
|
||||
with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr:
|
||||
self.char_bopomofo_dict = json.load(fr)
|
||||
|
||||
if self.enable_opencc:
|
||||
self.cc = OpenCC("s2tw")
|
||||
self.enable_sentence_dedup = os.getenv("g2pw_sentence_dedup", "true").strip().lower() in {
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
"y",
|
||||
"on",
|
||||
}
|
||||
# 聚焦到多音字附近上下文,默认左右各16字;设为0表示关闭裁剪(整句)。
|
||||
self.polyphonic_context_chars = max(0, int(os.getenv("g2pw_polyphonic_context_chars", "16")))
|
||||
|
||||
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
|
||||
tone = bopomofo[-1]
|
||||
@ -181,9 +196,8 @@ class G2PWOnnxConverter:
|
||||
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
|
||||
if component:
|
||||
return component + tone
|
||||
else:
|
||||
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
|
||||
return None
|
||||
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
|
||||
return None
|
||||
|
||||
def __call__(self, sentences: List[str]) -> List[List[str]]:
|
||||
if isinstance(sentences, str):
|
||||
@ -197,51 +211,147 @@ class G2PWOnnxConverter:
|
||||
translated_sentences.append(translated_sent)
|
||||
sentences = translated_sentences
|
||||
|
||||
texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
|
||||
texts, model_query_ids, result_query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
|
||||
if len(texts) == 0:
|
||||
# sentences no polyphonic words
|
||||
return partial_results
|
||||
|
||||
onnx_input = prepare_onnx_input(
|
||||
model_input = prepare_onnx_input(
|
||||
tokenizer=self.tokenizer,
|
||||
labels=self.labels,
|
||||
char2phonemes=self.char2phonemes,
|
||||
chars=self.chars,
|
||||
texts=texts,
|
||||
query_ids=query_ids,
|
||||
query_ids=model_query_ids,
|
||||
use_mask=self.config.use_mask,
|
||||
window_size=None,
|
||||
char2id=self.char2id,
|
||||
char_phoneme_masks=self.char_phoneme_masks,
|
||||
)
|
||||
|
||||
preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels)
|
||||
if not model_input:
|
||||
return partial_results
|
||||
|
||||
if self.enable_sentence_dedup:
|
||||
preds, _confidences = self._predict_with_sentence_dedup(model_input=model_input, texts=texts)
|
||||
else:
|
||||
preds, _confidences = self._predict(model_input=model_input)
|
||||
|
||||
if self.config.use_char_phoneme:
|
||||
preds = [pred.split(" ")[1] for pred in preds]
|
||||
|
||||
results = partial_results
|
||||
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
|
||||
for sent_id, query_id, pred in zip(sent_ids, result_query_ids, preds):
|
||||
results[sent_id][query_id] = self.style_convert_func(pred)
|
||||
|
||||
return results
|
||||
|
||||
def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
|
||||
texts, query_ids, sent_ids, partial_results = [], [], [], []
|
||||
def _prepare_data(
|
||||
self, sentences: List[str]
|
||||
) -> Tuple[List[str], List[int], List[int], List[int], List[List[str]]]:
|
||||
texts, model_query_ids, result_query_ids, sent_ids, partial_results = [], [], [], [], []
|
||||
for sent_id, sent in enumerate(sentences):
|
||||
# pypinyin works well for Simplified Chinese than Traditional Chinese
|
||||
sent_s = tranditional_to_simplified(sent)
|
||||
pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3)
|
||||
partial_result = [None] * len(sent)
|
||||
polyphonic_indices: List[int] = []
|
||||
for i, char in enumerate(sent):
|
||||
if char in self.polyphonic_chars_new:
|
||||
texts.append(sent)
|
||||
query_ids.append(i)
|
||||
sent_ids.append(sent_id)
|
||||
polyphonic_indices.append(i)
|
||||
elif char in self.monophonic_chars_dict:
|
||||
partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char])
|
||||
elif char in self.char_bopomofo_dict:
|
||||
partial_result[i] = pypinyin_result[i][0]
|
||||
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
|
||||
else:
|
||||
partial_result[i] = pypinyin_result[i][0]
|
||||
|
||||
if polyphonic_indices:
|
||||
if self.polyphonic_context_chars > 0:
|
||||
left = max(0, polyphonic_indices[0] - self.polyphonic_context_chars)
|
||||
right = min(len(sent), polyphonic_indices[-1] + self.polyphonic_context_chars + 1)
|
||||
sent_for_predict = sent[left:right]
|
||||
query_offset = left
|
||||
else:
|
||||
sent_for_predict = sent
|
||||
query_offset = 0
|
||||
|
||||
for index in polyphonic_indices:
|
||||
texts.append(sent_for_predict)
|
||||
model_query_ids.append(index - query_offset)
|
||||
result_query_ids.append(index)
|
||||
sent_ids.append(sent_id)
|
||||
|
||||
partial_results.append(partial_result)
|
||||
return texts, query_ids, sent_ids, partial_results
|
||||
return texts, model_query_ids, result_query_ids, sent_ids, partial_results
|
||||
|
||||
def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _predict_with_sentence_dedup(
|
||||
self, model_input: Dict[str, Any], texts: List[str]
|
||||
) -> Tuple[List[str], List[float]]:
|
||||
if len(texts) <= 1:
|
||||
return self._predict(model_input=model_input)
|
||||
|
||||
grouped_indices: Dict[str, List[int]] = {}
|
||||
for idx, text in enumerate(texts):
|
||||
grouped_indices.setdefault(text, []).append(idx)
|
||||
|
||||
if all(len(indices) == 1 for indices in grouped_indices.values()):
|
||||
return self._predict(model_input=model_input)
|
||||
|
||||
preds: List[str] = [""] * len(texts)
|
||||
confidences: List[float] = [0.0] * len(texts)
|
||||
for indices in grouped_indices.values():
|
||||
group_input = {name: value[indices] for name, value in model_input.items()}
|
||||
if len(indices) > 1:
|
||||
for name in ("input_ids", "token_type_ids", "attention_masks"):
|
||||
group_input[name] = group_input[name][:1]
|
||||
|
||||
group_preds, group_confidences = self._predict(model_input=group_input)
|
||||
for output_idx, pred, confidence in zip(indices, group_preds, group_confidences):
|
||||
preds[output_idx] = pred
|
||||
confidences[output_idx] = confidence
|
||||
|
||||
return preds, confidences
|
||||
|
||||
|
||||
class G2PWOnnxConverter(_G2PWBaseOnnxConverter):
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str = "G2PWModel/",
|
||||
style: str = "bopomofo",
|
||||
model_source: str = None,
|
||||
enable_non_tradional_chinese: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
model_dir=model_dir,
|
||||
style=style,
|
||||
model_source=model_source,
|
||||
enable_non_tradional_chinese=enable_non_tradional_chinese,
|
||||
)
|
||||
|
||||
sess_options = onnxruntime.SessionOptions()
|
||||
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
|
||||
sess_options.intra_op_num_threads = 2
|
||||
|
||||
onnx_path = _find_first_existing_file(
|
||||
os.path.join(self.model_dir, "g2pW.onnx"),
|
||||
os.path.join(self.model_dir, "g2pw.onnx"),
|
||||
)
|
||||
|
||||
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
|
||||
self.session_g2pw = onnxruntime.InferenceSession(
|
||||
onnx_path,
|
||||
sess_options=sess_options,
|
||||
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
|
||||
)
|
||||
else:
|
||||
self.session_g2pw = onnxruntime.InferenceSession(
|
||||
onnx_path,
|
||||
sess_options=sess_options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
|
||||
return predict(session=self.session_g2pw, onnx_input=model_input, labels=self.labels)
|
||||
|
||||
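Reviewer note: the two new behaviours added to the converter are read from environment variables in __init__ (g2pw_sentence_dedup, default on, and g2pw_polyphonic_context_chars, default 16 characters of context around the polyphonic span). A hedged sketch of turning them off before constructing the converter; the import path and the example sentence are assumptions, and constructor arguments are taken from the signature above:

```python
import os

# Disable the per-sentence dedup and the context cropping introduced above.
os.environ["g2pw_sentence_dedup"] = "false"
os.environ["g2pw_polyphonic_context_chars"] = "0"  # 0 = always feed the whole sentence

from text.g2pw.onnx_api import G2PWOnnxConverter  # module path as in this diff; package prefix is an assumption

g2pw = G2PWOnnxConverter(style="pinyin", enable_non_tradional_chinese=True)
print(g2pw(["他长得很高"]))  # one list of pinyin strings per input sentence
```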
gowebui_batched_infer.bat (new file, 33 lines)

@@ -0,0 +1,33 @@
@echo off
:: 1. 切换命令行编码为UTF-8,解决中文显示乱码(必须放在最前面)
chcp 65001 > nul

:: 2. 获取当前bat文件所在目录并格式化
set "SCRIPT_DIR=%~dp0"
set "SCRIPT_DIR=%SCRIPT_DIR:~0,-1%"

:: 3. 切换到脚本根目录
cd /d "%SCRIPT_DIR%"

:: 4. 创建专属TEMP目录(补充主页面的核心步骤)
if not exist "TEMP" md "TEMP"
set "TEMP=%SCRIPT_DIR%\TEMP"

:: 5. 设置核心环境变量(补充推理脚本依赖的配置)
set "version=v2Pro"
:: 语言配置
set "language=zh_CN"
:: 启用半精度推理(GPU用户推荐,CPU用户改为False)
set "is_half=True"
:: 指定GPU卡号(多卡可修改,无GPU则删除此行)
set "_CUDA_VISIBLE_DEVICES=0"

:: 6. 将runtime目录加入环境变量,确保能调用内置python
set "PATH=%SCRIPT_DIR%\runtime;%PATH%"

:: 7. 直接启动并行推理脚本,传入中文语言参数
echo 正在启动GPT-SoVITS并行推理页面...
runtime\python.exe -I GPT_SoVITS/inference_webui_fast.py zh_CN

:: 8. 执行完成后暂停,便于查看报错信息
pause
@@ -39,6 +39,7 @@ def create_model(language="zh"):
            local_dir="tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        )
        model_revision = "v2.0.4"
        vad_model_revision = punc_model_revision = "v2.0.4"
    elif language == "yue":
        path_asr = "tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
        snapshot_download(
@@ -51,8 +52,6 @@ def create_model(language="zh"):
    else:
        raise ValueError(f"{language} is not supported")

    vad_model_revision = punc_model_revision = "v2.0.4"

    if language in funasr_models:
        return funasr_models[language]
    else:
@@ -485,6 +485,8 @@ def istft(spec, hl):
    wave_right = librosa.istft(spec_right, hop_length=hl)
    wave = np.asfortranarray([wave_left, wave_right])

    return wave


if __name__ == "__main__":
    import argparse