Compare commits

...

9 Commits

Author SHA1 Message Date
linguikun1986
2ec561432b
Merge 7f6787121bca21f78a77a9c9de9206f9b48179d7 into 938f05fce8bcfb2407b8311fbbc10ac4d9ffe1c0 2026-04-18 17:20:13 +08:00
Harikrishna KP
938f05fce8
fix: correct torch.randint upper bound to include both values (#2733) 2026-04-18 17:19:55 +08:00
huang yutong
445d18ccce
fix: 修复 TTS 音频后处理中的多个缺陷 (#2753)
1. 修复音频超采样时 int16 双重转换导致整数溢出(CRITICAL)
   - audio_postprocess 中 `audio = (audio * 32768).astype(np.int16)` 位于
     if/else 块之外无条件执行,当 super_sampling=True 时音频已在分支内
     转为 int16,再次乘以 32768 导致溢出和音频完全失真
   - 同时修复 super_sampling=True 但超分模型不存在时 torch.Tensor 调用
     .astype() 的 AttributeError

2. 修复 batched vocoder 推理中 padding_len=0 导致音频丢失(HIGH)
   - 当 padding_len 恰好为 0 时,`-0 * upsample_rate == 0`,切片
     `audio[x:0]` 返回空张量,导致整段音频丢失

3. 修复文件不存在时错误地抛出 FileExistsError(LOW)
   - 应为 FileNotFoundError

Made-with: Cursor
2026-04-18 17:16:24 +08:00
Mushroomcowisheggs
00ce973412
feat: 添加数据集的错误处理提示 (#2758)
Co-authored-by: moomushroom <107208254+moomushroom@users.noreply.github.com>
2026-04-18 17:13:30 +08:00
huang yutong
14191901cd
fix: 修复多个模块中的独立 bug (#2755)
1. 修复 sync_buffer 中除以函数对象而非调用结果(distrib.py)
   - `buffer.data /= world_size` 中 world_size 是函数,缺少 (),
     导致 TypeError 使分布式训练 buffer 同步失败

2. 修复 istft 函数缺少 return 语句(spec_utils.py)
   - 函数计算了结果但未返回,调用者始终得到 None

3. 修复 cut0 返回字面量 "/n" 而非换行符 "\n"(text_segmentation_method.py)
   - 导致后续 text.split("\n") 无法正确切分,字面 /n 被当作文本内容

4. 修复粤语 ASR 的 vad/punc model_revision 被无条件覆盖(funasr_asr.py)
   - 粤语分支将 vad_model_revision 设为空(因不使用 VAD/标点模型),
     但 if/else 外的赋值将其覆盖为 "v2.0.4",传入错误的 revision 参数

Made-with: Cursor
2026-04-18 17:10:56 +08:00
东云
780383d5bd
[codex] Improve Windows single-GPU v3 LoRA training / 改进 Windows 单卡 v3 LoRA 训练流程 (#2767)
* Improve Windows single-GPU v3 LoRA training

* Drop unrelated checkpoint helper change from PR

* Tighten PR scope to single-GPU training path fixes
2026-04-18 16:54:26 +08:00
白菜工厂1145号员工
ba8de9b760
优化 G2PW 的推理输入构造与多音字处理流程,减少重复计算,降低长句场景下的推理开销 (#2763)
* Enhance G2P processing by implementing batch input handling in _g2p function, improving efficiency. Update prepare_onnx_input to utilize caching for tokenization and add optional parameters for character ID mapping and phoneme masks. Refactor G2PWOnnxConverter to streamline model loading and configuration management.

* Enhance G2PW model input handling by introducing polyphonic context character support and updating the data preparation method to return additional query IDs. This improves the processing of polyphonic characters in sentences.
2026-04-18 16:52:32 +08:00
kun
7f6787121b Merge branch 'kun' of https://github.com/linguikun1986/GPT-SoVITS-Kun into kun 2026-01-21 00:59:08 +08:00
ChasonJiang
6e027ec111 新增直接打开推理页面的 bat 命令;针对参考音频、推理参数做了持久化配置,解决每次推理都要重复操作的痛点;新增模型记忆,即每次打开推理页面时默认加载最后一次选择的模型,如果是从主页进入,则主动加载主页选择的模型;修复 bug (#2704) 2026-01-21 00:58:49 +08:00
16 changed files with 1488 additions and 440 deletions

5
.gitignore vendored
View File

@ -193,3 +193,8 @@ cython_debug/
# PyPI configuration file
.pypirc
/.vs
/GPT_SoVITS/configs/tts_infer.yaml
/GPT_SoVITS/configs/infer_settings.json
/last_selected_preset.json
/last_selected_models.json

View File

@ -67,8 +67,10 @@ class Text2SemanticDataset(Dataset):
)
) # "%s/3-bert"%exp_dir#bert_dir
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
assert os.path.exists(self.path2)
assert os.path.exists(self.path6)
if not os.path.exists(self.path2):
raise FileNotFoundError(f"Phoneme data file not found: {self.path2}")
if not os.path.exists(self.path6):
raise FileNotFoundError(f"Semantic data file not found: {self.path6}")
self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
@ -131,7 +133,7 @@ class Text2SemanticDataset(Dataset):
phoneme, word2ph, text = self.phoneme_data[item_name]
except Exception:
traceback.print_exc()
# print(f"{item_name} not in self.phoneme_data !")
print(f"Warning: File \"{item_name}\" not in self.phoneme_data! Skipped. ")
num_not_in += 1
continue
@ -152,7 +154,7 @@ class Text2SemanticDataset(Dataset):
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except:
traceback.print_exc()
# print(f"{item_name} not in self.phoneme_data !")
print(f"Warning: Failed to convert phonemes to sequence for file \"{item_name}\"! Skipped. ")
num_not_in += 1
continue
# if len(phoneme_ids) >400:###########2改为恒定限制为semantic/2.5就行
@ -228,7 +230,11 @@ class Text2SemanticDataset(Dataset):
# bert_feature=torch.zeros_like(phoneme_ids,dtype=torch.float32)
bert_feature = None
else:
try:
assert bert_feature.shape[-1] == len(phoneme_ids)
except AssertionError:
print(f"AssertionError: The BERT feature dimension ({bert_feature.shape[-1]}) of the file '{item_name}' does not match the length of the phoneme sequence ({len(phoneme_ids)}).")
raise
return {
"idx": idx,
"phoneme_ids": phoneme_ids,

View File

@ -262,7 +262,7 @@ def make_reject_y(y_o, y_lens):
reject_y = []
reject_y_lens = []
for b in range(bs):
process_item_idx = torch.randint(0, 1, size=(1,))[0]
process_item_idx = torch.randint(0, 2, size=(1,))[0]
if process_item_idx == 0:
new_y = repeat_P(y_o[b])
reject_y.append(new_y)

View File

@ -499,7 +499,7 @@ class TTS:
if if_lora_v3 == True and os.path.exists(path_sovits) == False:
info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
raise FileExistsError(info)
raise FileNotFoundError(info)
# dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
dict_s2 = load_sovits_new(weights_path)
@ -1579,14 +1579,13 @@ class TTS:
if max_audio > 1:
audio /= max_audio
audio = (audio * 32768).astype(np.int16)
else:
audio = audio.cpu().numpy()
audio = (audio * 32768).astype(np.int16)
t2 = time.perf_counter()
print(f"超采样用时:{t2 - t1:.3f}s")
else:
# audio = audio.float() * 32768
# audio = audio.to(dtype=torch.int16).clamp(-32768, 32767).cpu().numpy()
audio = audio.cpu().numpy()
audio = (audio * 32768).astype(np.int16)
@ -1768,7 +1767,10 @@ class TTS:
pos += chunk_len * upsample_rate
audio = self.sola_algorithm(audio_fragments, overlapped_len * upsample_rate)
if padding_len > 0:
audio = audio[overlapped_len * upsample_rate : -padding_len * upsample_rate]
else:
audio = audio[overlapped_len * upsample_rate :]
audio_fragments = []
for feat_len in feat_lens:

View File

@ -92,7 +92,7 @@ def cut0(inp):
if not set(inp).issubset(punctuation):
return inp
else:
return "/n"
return "\n"
# 凑四句一切

View File

@ -1,56 +0,0 @@
custom:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cuda
is_half: true
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
version: v2
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
v1:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
version: v1
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
v2:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
version: v2
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
v2Pro:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v2Pro
vits_weights_path: GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth
v2ProPlus:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v2ProPlus
vits_weights_path: GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth
v3:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v3
vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
v4:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v4
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth

File diff suppressed because it is too large Load Diff

View File

@ -87,7 +87,7 @@ def sync_buffer(buffers, average=True):
for buffer, handle in handles:
handle.wait()
if average:
buffer.data /= world_size
buffer.data /= world_size()
def sync_grad(params):

View File

@ -0,0 +1,425 @@
# -*- coding: utf-8 -*-
"""
GPT-SoVITS 持久化工具类
包含模型配置参考音频推理参数 的持久化读写与管理
抽离自主文件减少主文件臃肿方便后续维护
"""
import json
import yaml
import hashlib
import os
import shutil
import random
from pathlib import Path
# ===================== Global configuration (one place for every persistence file path) =====================
# JSON file that remembers the last selected GPT / SoVITS model paths and version
LAST_SELECTED_MODELS_JSON = Path("./last_selected_models.json")
# JSON file that remembers the name of the last selected reference preset
LAST_SELECTED_PRESET_JSON = Path("./last_selected_preset.json")
# Directory into which reference audio files are copied for persistence
REF_AUDIO_DIR = Path("GPT_SoVITS/ref_audios")
# YAML file holding the list of reference presets
REF_PRESETS_YAML = Path("GPT_SoVITS/configs/ref_audios_presets.yaml")
# JSON file holding the saved inference parameters
INFER_SETTINGS_JSON = Path("GPT_SoVITS/configs/infer_settings.json")
# Reference-audio naming constants: the sanitized preset name is truncated to
# MAX_FILENAME_LENGTH characters, and every character in INVALID_FILE_CHARS is
# replaced by an underscore when building a file name.
MAX_FILENAME_LENGTH = 40
INVALID_FILE_CHARS = set(r'\/:*?"<>|')
# Default inference parameters; used when no settings file exists or it cannot be
# read, and as the base that saved settings are merged over.
DEFAULT_INFER_SETTINGS = {
"batch_size": 20,
"sample_steps": 32,
"fragment_interval": 0.2,
"speed_factor": 1.0,
"top_k": 5,
"top_p": 1.0,
"temperature": 1.0,
"repetition_penalty": 1.35,
"how_to_cut": "凑四句一切",
"super_sampling": False,
"parallel_infer": True,
"split_bucket": True,
"seed": -1,
"keep_random": True
}
# ===================== 通用工具函数(抽离重复逻辑) =====================
def sanitize_filename(name):
    """Return *name* with every forbidden file-name character replaced by ``_``.

    A falsy name (empty string or ``None``) yields the placeholder
    ``"unnamed_preset"``. Forbidden characters are the members of
    ``INVALID_FILE_CHARS``.
    """
    if not name:
        return "unnamed_preset"
    cleaned = []
    for ch in name:
        cleaned.append('_' if ch in INVALID_FILE_CHARS else ch)
    return ''.join(cleaned)
def get_audio_md5(file_path, chunk_size=4096):
    """Return the first 8 hex digits of the MD5 digest of the file's bytes.

    The digest is used to tell different audio contents apart.

    Args:
        file_path: Path of the file to hash.
        chunk_size: Number of bytes read per iteration.

    Returns:
        ``"invalid_file"`` when the path does not exist, the 8-character digest
        prefix on success, or ``"err_<8 random digits>"`` when hashing fails.
    """
    if not os.path.exists(file_path):
        return "invalid_file"
    try:
        digest = hashlib.md5()
        with open(file_path, 'rb') as stream:
            # iter() with a sentinel stops as soon as read() returns b"".
            for block in iter(lambda: stream.read(chunk_size), b""):
                digest.update(block)
        return digest.hexdigest()[:8]
    except Exception as e:
        print(f"计算音频MD5失败{e}")
        return f"err_{random.randint(10000000, 99999999)}"
def ensure_dir_exists(dir_path):
    """Create *dir_path* (including missing parents) unless it already exists.

    A falsy *dir_path* such as ``None`` is ignored.
    """
    if not dir_path:
        return
    if dir_path.exists():
        return
    dir_path.mkdir(exist_ok=True, parents=True)
# ===================== 1. 模型配置持久化last_selected_models.json =====================
def init_last_selected_models(gpt_default, sovits_default, current_version):
    """Create the model-selection file with the given defaults and return them.

    Args:
        gpt_default: GPT model path stored under ``"gpt_model_path"``.
        sovits_default: SoVITS model path stored under ``"sovits_model_path"``.
        current_version: Value stored under ``"version"``.

    Returns:
        The dictionary that was written to ``LAST_SELECTED_MODELS_JSON``.
    """
    ensure_dir_exists(LAST_SELECTED_MODELS_JSON.parent)
    init_data = dict(
        gpt_model_path=gpt_default,
        sovits_model_path=sovits_default,
        version=current_version,
    )
    with open(LAST_SELECTED_MODELS_JSON, "w", encoding="utf-8") as out:
        json.dump(init_data, out, ensure_ascii=False, indent=4)
    print(f"首次生成模型配置文件:{LAST_SELECTED_MODELS_JSON}")
    return init_data
def read_last_selected_models():
    """Load the stored model selection.

    Returns:
        The parsed JSON content when the file exists and contains the keys
        ``"gpt_model_path"``, ``"sovits_model_path"`` and ``"version"``;
        otherwise ``None`` (missing file, missing key, or read/parse error).
    """
    if not LAST_SELECTED_MODELS_JSON.exists():
        return None
    try:
        with open(LAST_SELECTED_MODELS_JSON, "r", encoding="utf-8") as src:
            data = json.load(src)
        # Every required key must be present, otherwise the file is ignored.
        required_fields = ("gpt_model_path", "sovits_model_path", "version")
        if any(field not in data for field in required_fields):
            return None
        return data
    except Exception as e:
        print(f"读取模型配置失败:{e}")
        return None
def write_last_selected_models(gpt_path_new, sovits_path_new, current_version):
    """Persist the newly selected model paths and version.

    Any other keys already stored in the file are kept. Failures are printed,
    not raised.
    """
    ensure_dir_exists(LAST_SELECTED_MODELS_JSON.parent)
    try:
        data = read_last_selected_models() or {}
        updates = (
            ("gpt_model_path", gpt_path_new),
            ("sovits_model_path", sovits_path_new),
            ("version", current_version),
        )
        for key, value in updates:
            data[key] = value
        with open(LAST_SELECTED_MODELS_JSON, "w", encoding="utf-8") as out:
            json.dump(data, out, ensure_ascii=False, indent=4)
    except Exception as e:
        print(f"写入模型配置失败:{e}")
# ===================== 2. 参考音频预设持久化last_selected_preset.json + ref_audios_presets.yaml =====================
# 2.1 最后选中预设的读写清
def read_last_selected_preset():
    """Return the stored name of the last selected preset.

    Returns ``None`` when the file is missing, lacks the
    ``"last_selected_preset"`` key, or cannot be read.
    """
    if LAST_SELECTED_PRESET_JSON.exists():
        try:
            with open(LAST_SELECTED_PRESET_JSON, "r", encoding="utf-8") as src:
                stored = json.load(src)
            return stored.get("last_selected_preset")
        except Exception as e:
            print(f"读取最后选中预设失败:{e}")
    return None
def write_last_selected_preset(preset_name):
    """Record *preset_name* (stripped of surrounding whitespace) as the last selection.

    Failures, including a *preset_name* without ``strip()``, are printed and
    not raised.
    """
    ensure_dir_exists(LAST_SELECTED_PRESET_JSON.parent)
    try:
        payload = {"last_selected_preset": preset_name.strip()}
        with open(LAST_SELECTED_PRESET_JSON, "w", encoding="utf-8") as out:
            json.dump(payload, out, ensure_ascii=False, indent=4)
        print(f"已记录最后选中的预设:{preset_name.strip()}")
    except Exception as e:
        print(f"写入最后选中预设失败:{e}")
def clear_last_selected_preset():
    """Reset the stored last-selected preset name to an empty string.

    Does nothing when the record file does not exist. Failures are printed.
    """
    if LAST_SELECTED_PRESET_JSON.exists():
        try:
            with open(LAST_SELECTED_PRESET_JSON, "w", encoding="utf-8") as out:
                json.dump({"last_selected_preset": ""}, out, ensure_ascii=False, indent=4)
        except Exception as e:
            print(f"清空最后选中预设失败:{e}")
# 2.2 参考预设配置的加载/保存/删除
def load_ref_presets():
    """Load the list of reference presets from ``REF_PRESETS_YAML``.

    When the file is missing, an empty preset file is created and ``[]`` is
    returned. A legacy single-dict file is converted to a one-element list.
    Missing fields are filled with defaults, audio paths that no longer exist
    are reset to ``None``, and audio files not referenced by any preset are
    removed via ``clean_unreferenced_audios``.

    Returns:
        The list of preset dictionaries, or ``[]`` when loading fails.
    """
    ensure_dir_exists(REF_PRESETS_YAML.parent)
    # No config yet: create an empty one so later writes have a file to update.
    if not REF_PRESETS_YAML.exists():
        with open(REF_PRESETS_YAML, "w", encoding="utf-8") as out:
            yaml.dump([], out, indent=4, allow_unicode=True)
        print(f"暂未检测到参考预设配置文件,已自动创建空文件:{REF_PRESETS_YAML}")
        return []
    try:
        with open(REF_PRESETS_YAML, "r", encoding="utf-8") as src:
            presets = yaml.load(src, Loader=yaml.FullLoader) or []
        # Backward compatibility: the old format stored a single mapping.
        if isinstance(presets, dict):
            legacy = presets
            presets = [{"name": "旧配置转换", "ref_audio_path": legacy.get("ref_audio_path"),
                        "prompt_text": legacy.get("prompt_text", ""),
                        "prompt_language": legacy.get("prompt_language", "中文")}]
        # Fill in missing fields and validate each audio path.
        default_template = {"name": "", "ref_audio_path": None, "prompt_text": "", "prompt_language": "中文"}
        for entry in presets:
            for field, fallback in default_template.items():
                entry.setdefault(field, fallback)
            stored_audio = entry["ref_audio_path"]
            if stored_audio and not os.path.exists(str(stored_audio)):
                entry["ref_audio_path"] = None
        # Remove audio files that no preset points to any more.
        clean_unreferenced_audios(presets)
        print(f"参考预设加载成功,共 {len(presets)}")
        return presets
    except Exception as e:
        print(f"参考预设加载失败:{e}")
        return []
def get_preset_by_name(preset_name, presets=None):
    """Look up a preset by name, ignoring surrounding whitespace.

    Args:
        preset_name: Name to search for; ``None`` yields the blank preset.
        presets: Preset list to search. When falsy, the list is loaded with
            ``load_ref_presets``.

    Returns:
        The matching preset dictionary (the same object held in the list), or
        a fresh blank preset dictionary when nothing matches.
    """
    def _blank_preset():
        return {"name": "", "ref_audio_path": None, "prompt_text": "", "prompt_language": "中文"}

    # Guard first: calling strip() on None would raise AttributeError.
    if preset_name is None:
        return _blank_preset()
    candidates = presets if presets else load_ref_presets()
    wanted = preset_name.strip()
    match = next((item for item in candidates if item["name"].strip() == wanted), None)
    if match is None:
        return _blank_preset()
    return match
def save_ref_preset_core(preset_name, ref_audio_path, prompt_text, prompt_language, confirm_override=False):
    """Save a new reference preset or overwrite an existing one.

    Args:
        preset_name: Preset name; surrounding whitespace is stripped and
            ``None`` is treated as an empty name.
        ref_audio_path: Path of the reference audio to persist.
        prompt_text: Prompt text stored with the preset.
        prompt_language: Prompt language stored with the preset.
        confirm_override: Must be true to replace a preset with the same name.

    Returns:
        A ``(message, success, preset_names)`` tuple.
    """
    ensure_dir_exists(REF_AUDIO_DIR)
    presets = load_ref_presets()
    # A None name (e.g. from an empty UI field) is handled as "empty", not a crash.
    preset_name = (preset_name or "").strip()
    existing_names = [p["name"] for p in presets]

    # Input validation.
    if not ref_audio_path or not os.path.exists(str(ref_audio_path)):
        return "保存失败!请先上传有效的主参考音频文件。", False, existing_names
    if not preset_name:
        return "保存失败!配置名称不能为空。", False, existing_names

    # Duplicate-name detection happens BEFORE the audio is persisted, because
    # persisting removes the old audio stored for this preset name. An
    # unconfirmed override must leave the existing preset untouched.
    preset_index = next(
        (idx for idx, p in enumerate(presets) if p["name"].strip() == preset_name),
        -1,
    )
    if preset_index >= 0 and not confirm_override:
        return f"配置「{preset_name}」已存在,如需替换请确认覆盖!", False, existing_names

    # Copy the audio into the persistent directory.
    persistent_audio_path = get_persistent_audio_path(ref_audio_path, preset_name)
    if not persistent_audio_path:
        return "保存失败!音频文件持久化存储失败。", False, existing_names

    new_preset = {
        "name": preset_name,
        "ref_audio_path": persistent_audio_path,
        "prompt_text": prompt_text,
        "prompt_language": prompt_language
    }
    is_new_preset = preset_index < 0
    if is_new_preset:
        presets.append(new_preset)
        tip = "新配置已新增!"
    else:
        presets[preset_index] = new_preset
        tip = "同名配置已覆盖!"

    try:
        with open(REF_PRESETS_YAML, "w", encoding="utf-8") as f:
            yaml.dump(presets, f, indent=4, allow_unicode=True)
        # A newly added preset automatically becomes the last selected one.
        if is_new_preset:
            write_last_selected_preset(preset_name)
        preset_names = [p["name"] for p in presets]
        return f"配置保存成功!{tip}", True, preset_names
    except Exception as e:
        return f"保存失败:{str(e)}", False, [p["name"] for p in presets]
def delete_ref_preset_core(preset_name):
    """Delete the preset called *preset_name* together with its stored audio.

    Args:
        preset_name: Name of the preset to delete; surrounding whitespace is
            stripped and ``None`` is treated as an empty name.

    Returns:
        A ``(message, preset_names, selected_preset)`` tuple. On success
        ``selected_preset`` is the first remaining preset name or ``None``.
        When the config file cannot be written, the unchanged name list and
        the requested name are returned.
    """
    presets = load_ref_presets()
    # The UI may pass None when nothing is selected; do not crash on strip().
    preset_name = (preset_name or "").strip()
    if not presets:
        return "暂无配置可删除!", [], None

    # Audio file stored for the preset that is about to be removed.
    target_audio_path = next(
        (p.get("ref_audio_path") for p in presets if p["name"].strip() == preset_name),
        None,
    )
    remaining = [p for p in presets if p["name"].strip() != preset_name]

    try:
        with open(REF_PRESETS_YAML, "w", encoding="utf-8") as f:
            yaml.dump(remaining, f, indent=4, allow_unicode=True)
    except Exception as e:
        # Nothing was deleted, so report the list as it was before the attempt.
        return f"删除失败:{str(e)}", [p["name"] for p in presets], preset_name

    # The preset is already gone from the config. A file that cannot be removed
    # (for example because it is still open) must not turn this into a failure.
    if target_audio_path and os.path.exists(target_audio_path):
        try:
            os.unlink(target_audio_path)
            print(f"同步删除配置对应音频:{target_audio_path}")
        except OSError as e:
            print(f"删除配置对应音频失败:{e}")

    # Forget the last-selected record when it pointed at the deleted preset.
    last_selected = read_last_selected_preset()
    if last_selected and last_selected == preset_name:
        clear_last_selected_preset()

    preset_names = [p["name"] for p in remaining]
    new_selected = preset_names[0] if preset_names else None
    tip = "配置删除成功!已同步清理对应音频文件" if preset_names else "配置删除成功!已同步清理对应音频文件,当前无剩余配置"
    return tip, preset_names, new_selected
# 2.3 参考音频文件管理
def get_persistent_audio_path(src_audio_path, preset_name):
    """Copy a reference audio into ``REF_AUDIO_DIR`` and return the new path.

    The destination name is ``<sanitized preset name>_<md5 prefix><suffix>``.
    Older audio files stored for the same preset name are removed first.

    Args:
        src_audio_path: Path of the audio file to persist.
        preset_name: Preset name used to build the destination file name.

    Returns:
        The destination path as a string, or ``None`` when the source is
        missing or the copy fails.
    """
    import re  # local import: only needed here

    if not src_audio_path or not os.path.exists(src_audio_path):
        return None
    audio_suffixes = (".wav", ".mp3", ".flac", ".ogg", ".m4a")

    # Build a safe, length-limited base name.
    safe_preset_name = sanitize_filename(preset_name)[:MAX_FILENAME_LENGTH]
    # Unknown or missing suffixes fall back to ".wav".
    src_suffix = Path(src_audio_path).suffix.lower()
    if src_suffix not in audio_suffixes:
        src_suffix = ".wav"
    audio_md5 = get_audio_md5(src_audio_path)
    dst_path = REF_AUDIO_DIR / f"{safe_preset_name}_{audio_md5}{src_suffix}"

    src_resolved = Path(src_audio_path).resolve()
    # Match only files named for exactly this preset: "<name>_<tag>", where
    # <tag> is one of the values get_audio_md5 can return. A plain
    # "<name>_*" glob would also hit presets that merely share the prefix
    # (e.g. "a" vs "a_b") and would treat "[" / "]" in the name as wildcards.
    own_stem = re.compile(
        re.escape(safe_preset_name) + r"_(?:[0-9a-f]{8}|err_\d{8}|invalid_file)"
    )
    if REF_AUDIO_DIR.is_dir():
        for old_audio in REF_AUDIO_DIR.iterdir():
            if old_audio.suffix.lower() not in audio_suffixes:
                continue
            if not own_stem.fullmatch(old_audio.stem):
                continue
            # Never delete the file we are about to copy from: re-saving a
            # preset with its already-persisted audio would otherwise destroy it.
            if old_audio.resolve() == src_resolved:
                continue
            try:
                old_audio.unlink()
            except Exception as e:
                print(f"清理旧音频失败:{e}")

    # Source already is the destination: nothing to copy (copy2 would raise).
    if dst_path.resolve() == src_resolved:
        return str(dst_path)
    try:
        shutil.copy2(src_audio_path, dst_path)
        return str(dst_path)
    except Exception as e:
        print(f"音频持久化复制失败:{e}")
        return None
def clean_unreferenced_audios(presets):
    """Delete audio files in ``REF_AUDIO_DIR`` that no preset refers to.

    Args:
        presets: Preset dictionaries; each existing ``"ref_audio_path"`` is
            treated as referenced.
    """
    if not REF_AUDIO_DIR.exists():
        return
    # Absolute paths of every audio file still referenced by a preset.
    referenced = set()
    for entry in presets:
        stored_audio = entry.get("ref_audio_path")
        if not stored_audio or not os.path.exists(stored_audio):
            continue
        referenced.add(Path(stored_audio).absolute())
    # Remove every audio file that is not in the referenced set.
    audio_suffixes = [".wav", ".mp3", ".flac", ".ogg", ".m4a"]
    deleted_count = 0
    for candidate in REF_AUDIO_DIR.glob("*"):
        if not candidate.is_file():
            continue
        if candidate.suffix.lower() not in audio_suffixes:
            continue
        if candidate.absolute() in referenced:
            continue
        try:
            candidate.unlink()
            deleted_count += 1
        except Exception as e:
            print(f"清理冗余音频失败:{e}")
    if deleted_count > 0:
        print(f"清理冗余未引用音频 {deleted_count}")
# ===================== 3. 推理参数持久化infer_settings.json =====================
def load_infer_settings():
    """Load the saved inference parameters merged over the defaults.

    Returns:
        A new dictionary each call: the saved values layered over
        ``DEFAULT_INFER_SETTINGS``, or a copy of the defaults when the file is
        missing or cannot be read. Returning a copy keeps callers from
        mutating the shared module-level defaults.
    """
    ensure_dir_exists(INFER_SETTINGS_JSON.parent)
    if not INFER_SETTINGS_JSON.exists():
        return dict(DEFAULT_INFER_SETTINGS)
    try:
        with open(INFER_SETTINGS_JSON, "r", encoding="utf-8") as f:
            saved = json.load(f)
        # Saved values win; keys missing from the file keep their defaults.
        return {**DEFAULT_INFER_SETTINGS, **saved}
    except Exception as e:
        print(f"加载推理参数失败,使用默认值:{e}")
        return dict(DEFAULT_INFER_SETTINGS)
def save_infer_settings_core(settings):
    """Write *settings* to ``INFER_SETTINGS_JSON`` and return a status message.

    Args:
        settings: JSON-serialisable inference settings.

    Returns:
        A success message, or a failure message containing the error text.
    """
    ensure_dir_exists(INFER_SETTINGS_JSON.parent)
    try:
        with open(INFER_SETTINGS_JSON, "w", encoding="utf-8") as out:
            json.dump(settings, out, indent=4, ensure_ascii=False)
    except Exception as e:
        print(f"❌ 推理配置保存失败:{e}")
        return f"推理设置保存失败:{str(e)}"
    else:
        # Short log line only.
        print(f"✅ 推理配置保存成功:{INFER_SETTINGS_JSON.absolute()}")
        return "推理设置保存成功!已覆盖原有配置文件。"
def restore_default_infer_settings_core():
    """Overwrite the settings file with the defaults and return the default values.

    Returns:
        The default values as a list in a fixed order (the order the UI
        components expect). The list is returned even when writing fails.
    """
    ensure_dir_exists(INFER_SETTINGS_JSON.parent)
    try:
        with open(INFER_SETTINGS_JSON, "w", encoding="utf-8") as out:
            json.dump(DEFAULT_INFER_SETTINGS, out, indent=4, ensure_ascii=False)
        print(f"✅ 推理配置已恢复默认值:{INFER_SETTINGS_JSON.absolute()}")
    except Exception as e:
        print(f"❌ 推理配置恢复默认失败:{e}")
    # Explicit key order so the result does not depend on the dict's layout.
    ordered_keys = (
        "batch_size",
        "sample_steps",
        "fragment_interval",
        "speed_factor",
        "top_k",
        "top_p",
        "temperature",
        "repetition_penalty",
        "how_to_cut",
        "super_sampling",
        "parallel_infer",
        "split_bucket",
        "seed",
        "keep_random",
    )
    return [DEFAULT_INFER_SETTINGS[key] for key in ordered_keys]

View File

@ -55,6 +55,10 @@ def main():
n_gpus = torch.cuda.device_count()
else:
n_gpus = 1
if n_gpus <= 1:
run(0, n_gpus, hps)
return
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(randint(20000, 55555))
@ -77,6 +81,8 @@ def run(rank, n_gpus, hps):
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
use_ddp = n_gpus > 1
if use_ddp:
dist.init_process_group(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False",
@ -118,15 +124,20 @@ def run(rank, n_gpus, hps):
shuffle=True,
)
collate_fn = TextAudioSpeakerCollate()
train_loader = DataLoader(
train_dataset,
num_workers=5,
worker_count = 0 if os.name == "nt" and n_gpus <= 1 else min(2 if os.name == "nt" else 5, os.cpu_count() or 1)
loader_kwargs = dict(
num_workers=worker_count,
shuffle=False,
pin_memory=True,
pin_memory=torch.cuda.is_available(),
collate_fn=collate_fn,
batch_sampler=train_sampler,
persistent_workers=True,
prefetch_factor=3,
)
if worker_count > 0:
loader_kwargs["persistent_workers"] = True
loader_kwargs["prefetch_factor"] = 2 if os.name == "nt" else 3
train_loader = DataLoader(
train_dataset,
**loader_kwargs,
)
save_root = "%s/logs_s2_%s_lora_%s" % (hps.data.exp_dir, hps.model.version, hps.train.lora_rank)
os.makedirs(save_root, exist_ok=True)
@ -156,7 +167,9 @@ def run(rank, n_gpus, hps):
def model2cuda(net_g, rank):
if torch.cuda.is_available():
net_g = DDP(net_g.cuda(rank), device_ids=[rank], find_unused_parameters=True)
net_g = net_g.cuda(rank)
if use_ddp:
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
else:
net_g = net_g.to(device)
return net_g
@ -242,6 +255,8 @@ def run(rank, n_gpus, hps):
None,
)
scheduler_g.step()
if use_ddp and dist.is_initialized():
dist.destroy_process_group()
print("training done")

View File

@ -180,10 +180,15 @@ def _merge_erhua(initials: list[str], finals: list[str], word: str, pos: str) ->
def _g2p(segments):
phones_list = []
word2ph = []
for seg in segments:
g2pw_batch_results = []
g2pw_batch_cursor = 0
processed_segments = [re.sub("[a-zA-Z]+", "", seg) for seg in segments]
if is_g2pw:
batch_inputs = [seg for seg in processed_segments if seg]
g2pw_batch_results = g2pw._g2pw(batch_inputs) if batch_inputs else []
for seg in processed_segments:
pinyins = []
# Replace all English words in the sentence
seg = re.sub("[a-zA-Z]+", "", seg)
seg_cut = psg.lcut(seg)
seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
initials = []
@ -204,8 +209,10 @@ def _g2p(segments):
finals = sum(finals, [])
print("pypinyin结果", initials, finals)
else:
# g2pw采用整句推理
pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3)
# g2pw采用整句推理批量推理逐句取结果
if seg:
pinyins = g2pw_batch_results[g2pw_batch_cursor]
g2pw_batch_cursor += 1
pre_word_length = 0
for word, pos in seg_cut:

View File

@ -18,6 +18,7 @@ Credits
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
@ -37,6 +38,8 @@ def prepare_onnx_input(
use_mask: bool = False,
window_size: int = None,
max_len: int = 512,
char2id: Optional[Dict[str, int]] = None,
char_phoneme_masks: Optional[Dict[str, List[int]]] = None,
) -> Dict[str, np.array]:
if window_size is not None:
truncated_texts, truncated_query_ids = _truncate_texts(
@ -48,33 +51,88 @@ def prepare_onnx_input(
phoneme_masks = []
char_ids = []
position_ids = []
tokenized_cache = {}
if char2id is None:
char2id = {char: idx for idx, char in enumerate(chars)}
if use_mask:
if char_phoneme_masks is None:
char_phoneme_masks = {
char: [1 if i in char2phonemes[char] else 0 for i in range(len(labels))]
for char in char2phonemes
}
else:
full_phoneme_mask = [1] * len(labels)
for idx in range(len(texts)):
text = (truncated_texts if window_size else texts)[idx].lower()
query_id = (truncated_query_ids if window_size else query_ids)[idx]
cached = tokenized_cache.get(text)
if cached is None:
try:
tokens, text2token, token2text = tokenize_and_map(tokenizer=tokenizer, text=text)
except Exception:
print(f'warning: text "{text}" is invalid')
return {}
text, query_id, tokens, text2token, token2text = _truncate(
max_len=max_len, text=text, query_id=query_id, tokens=tokens, text2token=text2token, token2text=token2text
)
if len(tokens) <= max_len - 2:
processed_tokens = ["[CLS]"] + tokens + ["[SEP]"]
shared_input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
shared_token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
shared_attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
cached = {
"is_short": True,
"tokens": tokens,
"text2token": text2token,
"token2text": token2text,
"input_id": shared_input_id,
"token_type_id": shared_token_type_id,
"attention_mask": shared_attention_mask,
}
else:
cached = {
"is_short": False,
"tokens": tokens,
"text2token": text2token,
"token2text": token2text,
}
tokenized_cache[text] = cached
if cached["is_short"]:
text_for_query = text
query_id_for_query = query_id
text2token_for_query = cached["text2token"]
input_id = cached["input_id"]
token_type_id = cached["token_type_id"]
attention_mask = cached["attention_mask"]
else:
(
text_for_query,
query_id_for_query,
tokens_for_query,
text2token_for_query,
_token2text_for_query,
) = _truncate(
max_len=max_len,
text=text,
query_id=query_id,
tokens=cached["tokens"],
text2token=cached["text2token"],
token2text=cached["token2text"],
)
processed_tokens = ["[CLS]"] + tokens_for_query + ["[SEP]"]
input_id = list(np.array(tokenizer.convert_tokens_to_ids(processed_tokens)))
token_type_id = list(np.zeros((len(processed_tokens),), dtype=int))
attention_mask = list(np.ones((len(processed_tokens),), dtype=int))
query_char = text[query_id]
phoneme_mask = (
[1 if i in char2phonemes[query_char] else 0 for i in range(len(labels))] if use_mask else [1] * len(labels)
)
char_id = chars.index(query_char)
position_id = text2token[query_id] + 1 # [CLS] token locate at first place
query_char = text_for_query[query_id_for_query]
if use_mask:
phoneme_mask = char_phoneme_masks[query_char]
else:
phoneme_mask = full_phoneme_mask
char_id = char2id[query_char]
position_id = text2token_for_query[query_id_for_query] + 1 # [CLS] token locate at first place
input_ids.append(input_id)
token_type_ids.append(token_type_id)
@ -83,10 +141,15 @@ def prepare_onnx_input(
char_ids.append(char_id)
position_ids.append(position_id)
max_token_length = max(len(seq) for seq in input_ids)
def _pad_sequences(sequences, pad_value=0):
return [seq + [pad_value] * (max_token_length - len(seq)) for seq in sequences]
outputs = {
"input_ids": np.array(input_ids).astype(np.int64),
"token_type_ids": np.array(token_type_ids).astype(np.int64),
"attention_masks": np.array(attention_masks).astype(np.int64),
"input_ids": np.array(_pad_sequences(input_ids, pad_value=0)).astype(np.int64),
"token_type_ids": np.array(_pad_sequences(token_type_ids, pad_value=0)).astype(np.int64),
"attention_masks": np.array(_pad_sequences(attention_masks, pad_value=0)).astype(np.int64),
"phoneme_masks": np.array(phoneme_masks).astype(np.float32),
"char_ids": np.array(char_ids).astype(np.int64),
"position_ids": np.array(position_ids).astype(np.int64),

View File

@ -10,7 +10,6 @@ from typing import Any, Dict, List, Tuple
import numpy as np
import onnxruntime
import requests
import torch
from opencc import OpenCC
from pypinyin import Style, pinyin
from transformers.models.auto.tokenization_auto import AutoTokenizer
@ -22,9 +21,8 @@ from .utils import load_config
onnxruntime.set_default_logger_severity(3)
try:
onnxruntime.preload_dlls()
except:
except Exception:
pass
# traceback.print_exc()
warnings.filterwarnings("ignore")
model_version = "1.1"
@ -55,6 +53,24 @@ def predict(session, onnx_input: Dict[str, Any], labels: List[str]) -> Tuple[Lis
return all_preds, all_confidences
def _load_json_from_candidates(filename: str, candidate_dirs: List[str]) -> Dict[str, Any]:
    """Load *filename* from the first candidate directory that contains it.

    Falsy directory entries are skipped.

    Raises:
        FileNotFoundError: If no candidate directory contains *filename*.
    """
    for base_dir in (d for d in candidate_dirs if d):
        candidate = os.path.join(base_dir, filename)
        if not os.path.exists(candidate):
            continue
        with open(candidate, "r", encoding="utf-8") as handle:
            return json.load(handle)
    raise FileNotFoundError(f"Cannot locate {filename} in candidate dirs: {candidate_dirs}")
def _find_first_existing_file(*paths: str) -> str:
    """Return the first non-empty path in *paths* that exists.

    Raises:
        FileNotFoundError: If none of the given paths exists.
    """
    found = next((p for p in paths if p and os.path.exists(p)), None)
    if found is None:
        raise FileNotFoundError(f"Files not found: {paths}")
    return found
def download_and_decompress(model_dir: str = "G2PWModel/"):
if not os.path.exists(model_dir):
parent_directory = os.path.dirname(model_dir)
@ -62,7 +78,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"):
extract_dir = os.path.join(parent_directory, "G2PWModel_1.1")
extract_dir_new = os.path.join(parent_directory, "G2PWModel")
print("Downloading g2pw model...")
modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip" # "https://paddlespeech.cdn.bcebos.com/Parakeet/released_models/g2p/G2PWModel_1.1.zip"
modelscope_url = "https://www.modelscope.cn/models/kamiorinn/g2pw/resolve/master/G2PWModel_1.1.zip"
with requests.get(modelscope_url, stream=True) as r:
r.raise_for_status()
with open(zip_dir, "wb") as f:
@ -79,7 +95,7 @@ def download_and_decompress(model_dir: str = "G2PWModel/"):
return model_dir
class G2PWOnnxConverter:
class _G2PWBaseOnnxConverter:
def __init__(
self,
model_dir: str = "G2PWModel/",
@ -87,33 +103,16 @@ class G2PWOnnxConverter:
model_source: str = None,
enable_non_tradional_chinese: bool = False,
):
uncompress_path = download_and_decompress(model_dir)
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0
if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
else:
self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options,
providers=["CPUExecutionProvider"],
)
self.config = load_config(config_path=os.path.join(uncompress_path, "config.py"), use_default=True)
self.model_dir = download_and_decompress(model_dir)
self.config = load_config(config_path=os.path.join(self.model_dir, "config.py"), use_default=True)
self.model_source = model_source if model_source else self.config.model_source
self.enable_opencc = enable_non_tradional_chinese
self.tokenizer = AutoTokenizer.from_pretrained(self.model_source)
polyphonic_chars_path = os.path.join(uncompress_path, "POLYPHONIC_CHARS.txt")
monophonic_chars_path = os.path.join(uncompress_path, "MONOPHONIC_CHARS.txt")
polyphonic_chars_path = os.path.join(self.model_dir, "POLYPHONIC_CHARS.txt")
monophonic_chars_path = os.path.join(self.model_dir, "MONOPHONIC_CHARS.txt")
self.polyphonic_chars = [
line.split("\t") for line in open(polyphonic_chars_path, encoding="utf-8").read().strip().split("\n")
]
@ -149,31 +148,47 @@ class G2PWOnnxConverter:
)
self.chars = sorted(list(self.char2phonemes.keys()))
self.char2id = {char: idx for idx, char in enumerate(self.chars)}
self.char_phoneme_masks = (
{
char: [1 if i in self.char2phonemes[char] else 0 for i in range(len(self.labels))]
for char in self.char2phonemes
}
if self.config.use_mask
else None
)
self.polyphonic_chars_new = set(self.chars)
for char in self.non_polyphonic:
if char in self.polyphonic_chars_new:
self.polyphonic_chars_new.remove(char)
self.polyphonic_chars_new.discard(char)
self.monophonic_chars_dict = {char: phoneme for char, phoneme in self.monophonic_chars}
for char in self.non_monophonic:
if char in self.monophonic_chars_dict:
self.monophonic_chars_dict.pop(char)
self.monophonic_chars_dict.pop(char, None)
self.pos_tags = ["UNK", "A", "C", "D", "I", "N", "P", "T", "V", "DE", "SHI"]
default_asset_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "G2PWModel"))
candidate_asset_dirs = [self.model_dir, default_asset_dir]
self.bopomofo_convert_dict = _load_json_from_candidates(
"bopomofo_to_pinyin_wo_tune_dict.json", candidate_asset_dirs
)
self.char_bopomofo_dict = _load_json_from_candidates("char_bopomofo_dict.json", candidate_asset_dirs)
with open(os.path.join(uncompress_path, "bopomofo_to_pinyin_wo_tune_dict.json"), "r", encoding="utf-8") as fr:
self.bopomofo_convert_dict = json.load(fr)
self.style_convert_func = {
"bopomofo": lambda x: x,
"pinyin": self._convert_bopomofo_to_pinyin,
}[style]
with open(os.path.join(uncompress_path, "char_bopomofo_dict.json"), "r", encoding="utf-8") as fr:
self.char_bopomofo_dict = json.load(fr)
if self.enable_opencc:
self.cc = OpenCC("s2tw")
self.enable_sentence_dedup = os.getenv("g2pw_sentence_dedup", "true").strip().lower() in {
"1",
"true",
"yes",
"y",
"on",
}
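# Focus on the context around the polyphonic characters: by default 16 characters on each side; set to 0 to disable cropping and use the whole sentence.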
self.polyphonic_context_chars = max(0, int(os.getenv("g2pw_polyphonic_context_chars", "16")))
def _convert_bopomofo_to_pinyin(self, bopomofo: str) -> str:
tone = bopomofo[-1]
@ -181,7 +196,6 @@ class G2PWOnnxConverter:
component = self.bopomofo_convert_dict.get(bopomofo[:-1])
if component:
return component + tone
else:
print(f'Warning: "{bopomofo}" cannot convert to pinyin')
return None
@ -197,51 +211,147 @@ class G2PWOnnxConverter:
translated_sentences.append(translated_sent)
sentences = translated_sentences
texts, query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
texts, model_query_ids, result_query_ids, sent_ids, partial_results = self._prepare_data(sentences=sentences)
if len(texts) == 0:
# sentences no polyphonic words
return partial_results
onnx_input = prepare_onnx_input(
model_input = prepare_onnx_input(
tokenizer=self.tokenizer,
labels=self.labels,
char2phonemes=self.char2phonemes,
chars=self.chars,
texts=texts,
query_ids=query_ids,
query_ids=model_query_ids,
use_mask=self.config.use_mask,
window_size=None,
char2id=self.char2id,
char_phoneme_masks=self.char_phoneme_masks,
)
preds, confidences = predict(session=self.session_g2pW, onnx_input=onnx_input, labels=self.labels)
if not model_input:
return partial_results
if self.enable_sentence_dedup:
preds, _confidences = self._predict_with_sentence_dedup(model_input=model_input, texts=texts)
else:
preds, _confidences = self._predict(model_input=model_input)
if self.config.use_char_phoneme:
preds = [pred.split(" ")[1] for pred in preds]
results = partial_results
for sent_id, query_id, pred in zip(sent_ids, query_ids, preds):
for sent_id, query_id, pred in zip(sent_ids, result_query_ids, preds):
results[sent_id][query_id] = self.style_convert_func(pred)
return results
# Walk every sentence, fill in pronunciations that need no model call, and
# collect one model query per polyphonic character.
# NOTE(review): this span looks like a rendered diff in which the earlier
# 4-tuple signature/lines (using `query_ids`) appear next to their 5-tuple
# replacements (using `model_query_ids` / `result_query_ids`) — confirm which
# lines are live before editing.
def _prepare_data(self, sentences: List[str]) -> Tuple[List[str], List[int], List[int], List[List[str]]]:
texts, query_ids, sent_ids, partial_results = [], [], [], []
def _prepare_data(
self, sentences: List[str]
) -> Tuple[List[str], List[int], List[int], List[int], List[List[str]]]:
texts, model_query_ids, result_query_ids, sent_ids, partial_results = [], [], [], [], []
for sent_id, sent in enumerate(sentences):
# pypinyin works better on Simplified Chinese than on Traditional Chinese
sent_s = tranditional_to_simplified(sent)
pypinyin_result = pinyin(sent_s, neutral_tone_with_five=True, style=Style.TONE3)
# One slot per character; slots of polyphonic characters stay None here.
partial_result = [None] * len(sent)
polyphonic_indices: List[int] = []
for i, char in enumerate(sent):
if char in self.polyphonic_chars_new:
texts.append(sent)
query_ids.append(i)
sent_ids.append(sent_id)
polyphonic_indices.append(i)
elif char in self.monophonic_chars_dict:
partial_result[i] = self.style_convert_func(self.monophonic_chars_dict[char])
elif char in self.char_bopomofo_dict:
partial_result[i] = pypinyin_result[i][0]
# partial_result[i] = self.style_convert_func(self.char_bopomofo_dict[char][0])
else:
partial_result[i] = pypinyin_result[i][0]
if polyphonic_indices:
# Crop the text to the span from the first to the last polyphonic
# character, widened by `polyphonic_context_chars` on both sides and
# clamped to the sentence bounds; a value of 0 keeps the whole sentence.
if self.polyphonic_context_chars > 0:
left = max(0, polyphonic_indices[0] - self.polyphonic_context_chars)
right = min(len(sent), polyphonic_indices[-1] + self.polyphonic_context_chars + 1)
sent_for_predict = sent[left:right]
query_offset = left
else:
sent_for_predict = sent
query_offset = 0
for index in polyphonic_indices:
texts.append(sent_for_predict)
# Position inside the (possibly cropped) text.
model_query_ids.append(index - query_offset)
# Position inside the original sentence.
result_query_ids.append(index)
sent_ids.append(sent_id)
partial_results.append(partial_result)
return texts, query_ids, sent_ids, partial_results
return texts, model_query_ids, result_query_ids, sent_ids, partial_results
def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
raise NotImplementedError
def _predict_with_sentence_dedup(
self, model_input: Dict[str, Any], texts: List[str]
) -> Tuple[List[str], List[float]]:
if len(texts) <= 1:
return self._predict(model_input=model_input)
grouped_indices: Dict[str, List[int]] = {}
for idx, text in enumerate(texts):
grouped_indices.setdefault(text, []).append(idx)
if all(len(indices) == 1 for indices in grouped_indices.values()):
return self._predict(model_input=model_input)
preds: List[str] = [""] * len(texts)
confidences: List[float] = [0.0] * len(texts)
for indices in grouped_indices.values():
group_input = {name: value[indices] for name, value in model_input.items()}
if len(indices) > 1:
for name in ("input_ids", "token_type_ids", "attention_masks"):
group_input[name] = group_input[name][:1]
group_preds, group_confidences = self._predict(model_input=group_input)
for output_idx, pred, confidence in zip(indices, group_preds, group_confidences):
preds[output_idx] = pred
confidences[output_idx] = confidence
return preds, confidences
class G2PWOnnxConverter(_G2PWBaseOnnxConverter):
    """G2PW converter whose predictions come from an ONNX Runtime session."""

    def __init__(
        self,
        model_dir: str = "G2PWModel/",
        style: str = "bopomofo",
        model_source: str = None,
        enable_non_tradional_chinese: bool = False,
    ):
        """Initialise the base converter, then open the ONNX inference session.

        Args:
            model_dir: Model directory, forwarded to the base class.
            style: Output style, forwarded to the base class.
            model_source: Tokenizer source, forwarded to the base class.
            enable_non_tradional_chinese: Forwarded to the base class.
        """
        super().__init__(
            model_dir=model_dir,
            style=style,
            model_source=model_source,
            enable_non_tradional_chinese=enable_non_tradional_chinese,
        )

        options = onnxruntime.SessionOptions()
        options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
        options.intra_op_num_threads = 2

        # The model file may be spelled with either capitalisation; the helper
        # picks the path to load from these two candidates.
        weights_path = _find_first_existing_file(
            os.path.join(self.model_dir, "g2pW.onnx"),
            os.path.join(self.model_dir, "g2pw.onnx"),
        )

        # Prefer CUDA when this onnxruntime build offers it, with CPU listed
        # after it; otherwise run on CPU only.
        cuda_available = "CUDAExecutionProvider" in onnxruntime.get_available_providers()
        if cuda_available:
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]

        self.session_g2pw = onnxruntime.InferenceSession(
            weights_path,
            sess_options=options,
            providers=providers,
        )

    def _predict(self, model_input: Dict[str, Any]) -> Tuple[List[str], List[float]]:
        """Run ``predict`` on the ONNX session with this converter's labels."""
        return predict(session=self.session_g2pw, onnx_input=model_input, labels=self.labels)

33
gowebui_batched_infer.bat Normal file
View File

@ -0,0 +1,33 @@
@echo off
:: 1. Switch the console code page to UTF-8 so Chinese text is not garbled; this must come first.
chcp 65001 > nul
:: 2. Get the directory that contains this .bat file and drop its last character (the trailing backslash).
set "SCRIPT_DIR=%~dp0"
set "SCRIPT_DIR=%SCRIPT_DIR:~0,-1%"
:: 3. Change to the script's root directory.
cd /d "%SCRIPT_DIR%"
:: 4. Create a dedicated TEMP directory and point TEMP at it (replicates a core step of the main page launcher).
if not exist "TEMP" md "TEMP"
set "TEMP=%SCRIPT_DIR%\TEMP"
:: 5. Set the core environment variables (configuration the inference script depends on).
set "version=v2Pro"
:: Language setting.
set "language=zh_CN"
:: Enable half-precision inference (recommended for GPU users; CPU users should change this to False).
set "is_half=True"
:: Select the GPU index (adjust on multi-GPU machines; delete this line if there is no GPU).
set "_CUDA_VISIBLE_DEVICES=0"
:: 6. Prepend the runtime directory to PATH so the bundled Python can be invoked.
set "PATH=%SCRIPT_DIR%\runtime;%PATH%"
:: 7. Launch the batched-inference script directly, passing the Chinese language argument.
echo 正在启动GPT-SoVITS并行推理页面...
runtime\python.exe -I GPT_SoVITS/inference_webui_fast.py zh_CN
:: 8. Pause once it finishes so any error output can be read.
pause

View File

@ -39,6 +39,7 @@ def create_model(language="zh"):
local_dir="tools/asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
)
model_revision = "v2.0.4"
vad_model_revision = punc_model_revision = "v2.0.4"
elif language == "yue":
path_asr = "tools/asr/models/speech_UniASR_asr_2pass-cantonese-CHS-16k-common-vocab1468-tensorflow1-online"
snapshot_download(
@ -51,8 +52,6 @@ def create_model(language="zh"):
else:
raise ValueError(f"{language} is not supported")
vad_model_revision = punc_model_revision = "v2.0.4"
if language in funasr_models:
return funasr_models[language]
else:

View File

@ -485,6 +485,8 @@ def istft(spec, hl):
wave_right = librosa.istft(spec_right, hop_length=hl)
wave = np.asfortranarray([wave_left, wave_right])
return wave
if __name__ == "__main__":
import argparse