Compare commits

...

42 Commits

Author SHA1 Message Date
RVC-Boss
11aa78bd9b
Fix issue where environment variables may not be str
2025-09-10 15:01:04 +08:00
XXXXRT666
fdf794e31d
Update WSL Rocm (#2561) 2025-08-02 17:47:15 +08:00
多玩幻灵qwq
0be59c8043
fix: correct links (#2539) 2025-07-19 00:29:48 +08:00
ChasonJiang
b5a67e6247
Fix GPT loss computation (#2537)
* Fix GPT loss computation

* fallback tts config
2025-07-18 14:59:59 +08:00
ChasonJiang
b9211657d8
Refactor TTS_Config code logic (#2536)
* Refactor TTS_Config code logic

* Save tts_config after loading VITS weights
2025-07-18 11:54:40 +08:00
XXXXRT666
cefafee32c
Add Distil (#2531) 2025-07-17 20:28:25 +08:00
RVC-Boss
2d09bbe63a
Update tts_infer.yaml 2025-07-16 15:44:04 +08:00
RVC-Boss
4d8ebf8523
Update TTS.py 2025-07-16 15:43:26 +08:00
jiangsier-xyz
e476b01f30
Fix TTS.py failing to recognize the actually supported versions v2Pro and v2ProPlus (#2490)
Also update the default configs.

Co-authored-by: jiangsier-xyz <jiangsier131@gmail.com>
2025-07-16 15:42:36 +08:00
RVC-Boss
42586e20f7
add RTF performance
2025-07-14 19:01:26 +08:00
RVC-Boss
85035f7ac0
add RTF performance
2025-07-14 18:56:22 +08:00
RVC-Boss
706bec74f8
Update assets.py 2025-07-11 16:11:08 +08:00
XXXXRT666
ec1218893e
Update Badge (#2518)
* Update README.md

* Update README.md

* Update Badges

* specify ranges
2025-07-11 16:10:07 +08:00
RVC-Boss
fec515dcce
Update Changelog_CN.md 2025-07-10 18:33:18 +08:00
RVC-Boss
426e1a2bb4
Raise inference process priority 2025-07-10 18:16:45 +08:00
RVC-Boss
4e3c69043c
Update inference_webui.py 2025-07-10 18:16:24 +08:00
RVC-Boss
e63e0901fd
Update assets.py 2025-07-10 18:12:24 +08:00
RVC-Boss
97e37c74d8
Update README.md 2025-07-10 18:06:04 +08:00
RVC-Boss
3a75f5023f
Update README.md 2025-07-10 18:05:03 +08:00
RVC-Boss
0899b7e432
Update README.md 2025-07-10 17:59:49 +08:00
Yixiao Chen
8c579d46dd
Update export_torch_script.py (#2494)
Avoid dtype inconsistency when exporting
2025-07-02 22:48:28 +08:00
KamioRinn
6df61f58e4
Language segmentation and formatting improvements (#2488)
* better LangSegmenter

* add version num2str

* better version num2str

* sync fast infer

* sync api

* remove duplicate spaces

* remove unnecessary code

---------

Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
2025-06-27 11:58:41 +08:00
KamioRinn
90ebefa78f
make sure ort providers available (#2489) 2025-06-27 10:41:52 +08:00
XXXXRT666
4839e82148
Add Windows Install Powershell Scripts (#2487) 2025-06-27 01:04:18 +08:00
XXXXRT666
37f5abfcb4
Fix Issues with libstdcxx and conda sysroot (#2482) 2025-06-25 14:52:27 +08:00
Ella Zhang
4987df5a71
fixed syntax errors in api_v2.py (#2473) 2025-06-19 15:34:11 +08:00
XXXXRT666
d46c069e52
Remove Debug Code (#2471) 2025-06-18 10:38:54 +08:00
XXXXRT666
6fdc67ca83
Fix bugs in install.sh, reduce log noise, and improve error reporting (#2464)
* Update Install.sh

* Format Code

* Delete dev null

* Update README, Support Dark Mode in CSS/JS
2025-06-17 15:21:36 +08:00
zzz
7dec5f5bb0
Merge pull request #2460 from L-jasmine/export_v2pro
Optimize torch_script model export
2025-06-13 22:10:11 +08:00
RVC-Boss
1a9b8854ee
Merge pull request #2456 from L-jasmine/export_v2pro
export_torch_script.py support v2Pro & v2ProPlus
2025-06-12 23:15:46 +08:00
csh
5c91e66d2e export_torch_script.py support v2Pro & v2ProPlus 2025-06-12 21:53:14 +08:00
RVC-Boss
ed89a02337
Fix a training blow-up that could be caused by the earlier "fix possible ge.sum value explosion" change
2025-06-11 23:14:52 +08:00
RVC-Boss
cd6de7398e
Merge pull request #2449 from KamioRinn/maga
support v4 v2Pro v2ProPlus for api & optimize LangSegmenter
2025-06-11 10:29:39 +08:00
YYuX-1145
dd2b9253aa
Update TTS.py (#2450) 2025-06-11 10:28:42 +08:00
KamioRinn
29165eb02e support v4 v2Pro v2ProPlus for api 2025-06-11 02:09:07 +08:00
KamioRinn
746cb536c6 Fix LangSegmenter 2025-06-10 19:18:05 +08:00
Emmanuel Ferdman
0d2f273402
Resolve Python Logger warnings (#2379)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-06-10 18:03:23 +08:00
RVC-Boss
d39836b8fa
Update Changelog_CN.md 2025-06-10 17:30:06 +08:00
RVC-Boss
2c0436b9ce
Fix incorrect Windows paths when the experiment name ends with a trailing space
2025-06-10 14:58:00 +08:00
RVC-Boss
8056efe4ab
Fix possible ge.sum value explosion
2025-06-09 23:53:16 +08:00
wzy3650
d6b78c927a
fix configs error (#2439)
* fix configs error

* fix configs error

---------

Co-authored-by: wangzeyuan <wangzeyuan@agora.io>
Co-authored-by: wangzeyuan <wangzeyuan@shengwang.cn>
2025-06-09 11:25:55 +08:00
RVC-Boss
74e79ae6d6
Delete batch_inference.py 2025-06-07 14:40:30 +08:00
54 changed files with 2185 additions and 1513 deletions

View File

@@ -28,7 +28,8 @@ class Text2SemanticLightningModule(LightningModule):
             self.load_state_dict(
                 torch.load(
                     pretrained_s1,
-                    map_location="cpu", weights_only=False,
+                    map_location="cpu",
+                    weights_only=False,
                 )["weight"],
             )
         )

View File

@@ -356,7 +356,7 @@ class Text2SemanticDecoder(nn.Module):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_position(x)
-        x_mask = make_pad_mask(x_lens)
+        x_mask = make_pad_mask_left(x_lens)
         y_mask = make_pad_mask(y_lens)
         y_mask_int = y_mask.type(torch.int64)
@@ -420,7 +420,7 @@ class Text2SemanticDecoder(nn.Module):
             mask=xy_attn_mask,
         )
         x_len = x_lens.max()
-        logits = self.ar_predict_layer(xy_dec[:, x_len:])
+        logits = self.ar_predict_layer(xy_dec[:, x_len-1:])

         ###### DPO #############
         reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
@@ -432,7 +432,7 @@ class Text2SemanticDecoder(nn.Module):
             mask=reject_xy_attn_mask,
         )
         x_len = x_lens.max()
-        reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
+        reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len-1:])

         # loss
         # from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
@@ -455,7 +455,7 @@ class Text2SemanticDecoder(nn.Module):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_position(x)
-        x_mask = make_pad_mask(x_lens)
+        x_mask = make_pad_mask_left(x_lens)
         y_mask = make_pad_mask(y_lens)
         y_mask_int = y_mask.type(torch.int64)
@@ -502,7 +502,7 @@ class Text2SemanticDecoder(nn.Module):
             (xy_pos, None),
             mask=xy_attn_mask,
         )
-        logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
+        logits = self.ar_predict_layer(xy_dec[:, x_len-1:]).permute(0, 2, 1)
         # loss
         # from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
         loss = F.cross_entropy(logits, targets, reduction="sum")
@@ -578,7 +578,7 @@ class Text2SemanticDecoder(nn.Module):
     def pad_y_eos(self, y, y_mask_int, eos_id):
         targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
         # 错位
-        return targets[:, :-1], targets[:, 1:]
+        return targets[:, :-1], targets

     def infer_panel_batch_infer(
         self,
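
Taken together, these hunks appear to widen the logit window by one position: pad_y_eos now returns the full EOS-padded targets, and the prediction slice starts at x_len-1 (the last text position) instead of x_len, so the first semantic token is also supervised. A minimal, self-contained sketch of that alignment with toy shapes, where `predict` is only a stand-in for ar_predict_layer and not the repository's module:

# Illustrative sketch only: how logits and targets line up before and after the change.
import torch
import torch.nn.functional as F

B, x_len, T, D, V = 2, 5, 7, 16, 32
xy_dec = torch.randn(B, x_len + T, D)        # decoder states over [text ; semantic] positions
predict = torch.nn.Linear(D, V)              # stand-in for self.ar_predict_layer
targets = torch.randint(0, V, (B, T + 1))    # y_0 .. y_{T-1} plus an appended EOS

old_logits = predict(xy_dec[:, x_len:])      # (B, T, V): pairs with targets[:, 1:], y_0 never supervised
new_logits = predict(xy_dec[:, x_len - 1:])  # (B, T+1, V): position x_len-1 now also predicts y_0
loss = F.cross_entropy(new_logits.permute(0, 2, 1), targets, reduction="sum")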

View File

@@ -354,7 +354,7 @@ class ScaledAdam(BatchedOptimizer):
                 if ans < 1.0:
                     first_state["num_clipped"] += 1
                 if ans < 0.1:
-                    logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
+                    logging.warning(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
                     if self.show_dominant_parameters:
                         assert p.shape[0] == len(param_names)
                         self._show_gradient_dominating_parameter(tuples, tot_sumsq)
@@ -362,7 +362,7 @@ class ScaledAdam(BatchedOptimizer):
     def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
         """
-        Show information of parameter wihch dominanting tot_sumsq.
+        Show information of parameter which dominating tot_sumsq.

         Args:
             tuples: a list of tuples of (param, state, param_names)
@@ -415,7 +415,7 @@ class ScaledAdam(BatchedOptimizer):
                 dominant_grad,
             ) = sorted_by_proportion[dominant_param_name]
             logging.info(
-                f"Parameter Dominanting tot_sumsq {dominant_param_name}"
+                f"Parameter Dominating tot_sumsq {dominant_param_name}"
                 f" with proportion {dominant_proportion:.2f},"
                 f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
                 f"={dominant_sumsq:.3e},"

View File

@@ -32,19 +32,21 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
 from tools.audio_sr import AP_BWE
 from tools.i18n.i18n import I18nAuto, scan_language_list
+from tools.my_utils import load_audio
 from TTS_infer_pack.text_segmentation_method import splits
 from TTS_infer_pack.TextPreprocessor import TextPreprocessor
 from sv import SV

-resample_transform_dict={}
-def resample(audio_tensor, sr0,sr1,device):
+resample_transform_dict = {}
+
+
+def resample(audio_tensor, sr0, sr1, device):
     global resample_transform_dict
-    key="%s-%s-%s"%(sr0,sr1,str(device))
+    key = "%s-%s-%s" % (sr0, sr1, str(device))
     if key not in resample_transform_dict:
-        resample_transform_dict[key] = torchaudio.transforms.Resample(
-            sr0, sr1
-        ).to(device)
+        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
     return resample_transform_dict[key](audio_tensor)

 language = os.environ.get("language", "Auto")
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
 i18n = I18nAuto(language=language)
@@ -111,6 +113,7 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int):
     return processed_audio

+
 class DictToAttrRecursive(dict):
     def __init__(self, input_dict):
         super().__init__(input_dict)
@@ -301,10 +304,10 @@ class TTS_Config:
         configs: dict = self._load_configs(self.configs_path)
         assert isinstance(configs, dict)
-        version = configs.get("version", "v2").lower()
-        assert version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]
-        self.default_configs[version] = configs.get(version, self.default_configs[version])
-        self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version]))
+        configs_ = deepcopy(self.default_configs)
+        configs_.update(configs)
+        self.configs: dict = configs_.get("custom", configs_["v2"])
+        self.default_configs = deepcopy(configs_)

         self.device = self.configs.get("device", torch.device("cpu"))
         if "cuda" in str(self.device) and not torch.cuda.is_available():
@@ -312,11 +315,13 @@ class TTS_Config:
             self.device = torch.device("cpu")

         self.is_half = self.configs.get("is_half", False)
-        # if str(self.device) == "cpu" and self.is_half:
-        #     print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
-        #     self.is_half = False
+        if str(self.device) == "cpu" and self.is_half:
+            print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
+            self.is_half = False

+        version = self.configs.get("version", None)
         self.version = version
+        assert self.version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"], "Invalid version!"
         self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
         self.vits_weights_path = self.configs.get("vits_weights_path", None)
         self.bert_base_path = self.configs.get("bert_base_path", None)
@@ -479,7 +484,7 @@ class TTS:
     def init_vits_weights(self, weights_path: str):
         self.configs.vits_weights_path = weights_path
         version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
-        if "Pro"in model_version:
+        if "Pro" in model_version:
             self.init_sv_model()
         path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
@@ -498,9 +503,9 @@ class TTS:
         else:
             hps["model"]["version"] = "v2"
             version = hps["model"]["version"]
-        v3v4set={"v3", "v4"}
+        v3v4set = {"v3", "v4"}
         if model_version not in v3v4set:
-            if "Pro"not in model_version:
+            if "Pro" not in model_version:
                 model_version = version
             else:
                 hps["model"]["version"] = model_version
@@ -542,7 +547,7 @@ class TTS:
         if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
             del vits_model.enc_q

-        self.is_v2pro=model_version in {"v2Pro","v2ProPlus"}
+        self.is_v2pro = model_version in {"v2Pro", "v2ProPlus"}

         if if_lora_v3 == False:
             print(
@@ -573,6 +578,10 @@ class TTS:
         if self.configs.is_half and str(self.configs.device) != "cpu":
             self.vits_model = self.vits_model.half()

+        self.configs.save_configs()
+
     def init_t2s_weights(self, weights_path: str):
         print(f"Loading Text2Semantic weights from {weights_path}")
         self.configs.t2s_weights_path = weights_path
@@ -632,7 +641,9 @@ class TTS:
             )
             self.vocoder.remove_weight_norm()
             state_dict_g = torch.load(
-                "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
+                "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,),
+                map_location="cpu",
+                weights_only=False,
             )
             print("loading vocoder", self.vocoder.load_state_dict(state_dict_g))
@@ -752,11 +763,13 @@ class TTS:
             if raw_sr != self.configs.sampling_rate:
                 audio = raw_audio.to(self.configs.device)
-                if (audio.shape[0] == 2): audio = audio.mean(0).unsqueeze(0)
+                if audio.shape[0] == 2:
+                    audio = audio.mean(0).unsqueeze(0)
                 audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
             else:
                 audio = raw_audio.to(self.configs.device)
-                if (audio.shape[0] == 2): audio = audio.mean(0).unsqueeze(0)
+                if audio.shape[0] == 2:
+                    audio = audio.mean(0).unsqueeze(0)

             maxx = audio.abs().max()
             if maxx > 1:
@@ -775,8 +788,9 @@ class TTS:
             audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
             if self.configs.is_half:
                 audio = audio.half()
-        else:audio=None
-        return spec,audio
+        else:
+            audio = None
+        return spec, audio

     def _set_prompt_semantic(self, ref_wav_path: str):
         zero_wav = np.zeros(
@@ -1073,7 +1087,10 @@ class TTS:
         ###### setting reference audio and prompt text preprocessing ########
         t0 = time.perf_counter()
-        if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
+        if (ref_audio_path is not None) and (
+            ref_audio_path != self.prompt_cache["ref_audio_path"]
+            or (self.is_v2pro and self.prompt_cache["refer_spec"][0][1] is None)
+        ):
             if not os.path.exists(ref_audio_path):
                 raise ValueError(f"{ref_audio_path} not exists")
             self.set_ref_audio(ref_audio_path)
@@ -1212,9 +1229,10 @@ class TTS:
             t_34 += t4 - t3

             refer_audio_spec = []
-            if self.is_v2pro:sv_emb=[]
-            for spec,audio_tensor in self.prompt_cache["refer_spec"]:
-                spec=spec.to(dtype=self.precision, device=self.configs.device)
+            if self.is_v2pro:
+                sv_emb = []
+            for spec, audio_tensor in self.prompt_cache["refer_spec"]:
+                spec = spec.to(dtype=self.precision, device=self.configs.device)
                 refer_audio_spec.append(spec)
                 if self.is_v2pro:
                     sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
@@ -1249,10 +1267,14 @@ class TTS:
                     torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
                 )
                 _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
-                if self.is_v2pro!=True:
-                    _batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
+                if self.is_v2pro != True:
+                    _batch_audio_fragment = self.vits_model.decode(
+                        all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
+                    ).detach()[0, 0, :]
                 else:
-                    _batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor,sv_emb=sv_emb).detach()[0, 0, :]
+                    _batch_audio_fragment = self.vits_model.decode(
+                        all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
+                    ).detach()[0, 0, :]
                 audio_frag_end_idx.insert(0, 0)
                 batch_audio_fragment = [
                     _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
@@ -1266,9 +1288,13 @@ class TTS:
                         pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
                     )  # .unsqueeze(0)#mq要多unsqueeze一次
                     if self.is_v2pro != True:
-                        audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
+                        audio_fragment = self.vits_model.decode(
+                            _pred_semantic, phones, refer_audio_spec, speed=speed_factor
+                        ).detach()[0, 0, :]
                     else:
-                        audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor,sv_emb=sv_emb).detach()[0, 0, :]
+                        audio_fragment = self.vits_model.decode(
+                            _pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
+                        ).detach()[0, 0, :]
                     batch_audio_fragment.append(audio_fragment)  ###试试重建不带上prompt部分
             else:
                 if parallel_infer:
@@ -1410,7 +1436,7 @@ class TTS:
         raw_entry = self.prompt_cache["refer_spec"][0]
         if isinstance(raw_entry, tuple):
             raw_entry = raw_entry[0]
-        refer_audio_spec = raw_entry.to(dtype=self.precision,device=self.configs.device)
+        refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)
         fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
         ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
@@ -1480,7 +1506,7 @@ class TTS:
         raw_entry = self.prompt_cache["refer_spec"][0]
         if isinstance(raw_entry, tuple):
             raw_entry = raw_entry[0]
-        refer_audio_spec = raw_entry.to(dtype=self.precision,device=self.configs.device)
+        refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)
         fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
         ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
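
The resample() hunk near the top of this file caches one torchaudio Resample transform per (sr0, sr1, device) key. A standalone sketch of the same caching pattern, assuming only that torch and torchaudio are installed (this is not the repository's module):

import torch
import torchaudio

_resamplers = {}  # (sr0, sr1, device) -> torchaudio.transforms.Resample

def cached_resample(audio: torch.Tensor, sr0: int, sr1: int, device: str = "cpu") -> torch.Tensor:
    key = (sr0, sr1, device)
    if key not in _resamplers:
        # Building the resampling kernel is the expensive part, so do it once per key.
        _resamplers[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
    return _resamplers[key](audio.to(device))

wav = torch.zeros(1, 32000)
out = cached_resample(wav, 32000, 16000)  # first call builds the transform
out = cached_resample(wav, 32000, 16000)  # second call reuses the cached one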

View File

@@ -121,33 +121,31 @@ class TextPreprocessor:
     def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
         with self.bert_lock:
-            if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
-                # language = language.replace("all_","")
-                formattext = text
-                while "  " in formattext:
-                    formattext = formattext.replace("  ", " ")
-                if language == "all_zh":
-                    if re.search(r"[A-Za-z]", formattext):
-                        formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
-                        formattext = chinese.mix_text_normalize(formattext)
-                        return self.get_phones_and_bert(formattext, "zh", version)
-                    else:
-                        phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
-                        bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
-                elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
-                    formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
-                    formattext = chinese.mix_text_normalize(formattext)
-                    return self.get_phones_and_bert(formattext, "yue", version)
-                else:
-                    phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
-                    bert = torch.zeros(
-                        (1024, len(phones)),
-                        dtype=torch.float32,
-                    ).to(self.device)
-            elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
+            text = re.sub(r' {2,}', ' ', text)
             textlist = []
             langlist = []
-            if language == "auto":
+            if language == "all_zh":
+                for tmp in LangSegmenter.getTexts(text,"zh"):
+                    langlist.append(tmp["lang"])
+                    textlist.append(tmp["text"])
+            elif language == "all_yue":
+                for tmp in LangSegmenter.getTexts(text,"zh"):
+                    if tmp["lang"] == "zh":
+                        tmp["lang"] = "yue"
+                    langlist.append(tmp["lang"])
+                    textlist.append(tmp["text"])
+            elif language == "all_ja":
+                for tmp in LangSegmenter.getTexts(text,"ja"):
+                    langlist.append(tmp["lang"])
+                    textlist.append(tmp["text"])
+            elif language == "all_ko":
+                for tmp in LangSegmenter.getTexts(text,"ko"):
+                    langlist.append(tmp["lang"])
+                    textlist.append(tmp["text"])
+            elif language == "en":
+                langlist.append("en")
+                textlist.append(text)
+            elif language == "auto":
                 for tmp in LangSegmenter.getTexts(text):
                     langlist.append(tmp["lang"])
                     textlist.append(tmp["text"])
@@ -159,6 +157,10 @@ class TextPreprocessor:
                     textlist.append(tmp["text"])
             else:
                 for tmp in LangSegmenter.getTexts(text):
+                    if langlist:
+                        if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
+                            textlist[-1] += tmp["text"]
+                            continue
                     if tmp["lang"] == "en":
                         langlist.append(tmp["lang"])
                     else:
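
The four lines added in the second hunk merge consecutive segments of the same broad class (English vs. non-English) so they are processed as one chunk. A small self-contained sketch of that grouping rule; the segment dicts below are hypothetical stand-ins for LangSegmenter.getTexts() output, and `language` plays the role of the user-selected language used for non-English chunks:

segments = [
    {"lang": "zh", "text": "你好"},
    {"lang": "ja", "text": "こんにちは"},  # non-English after non-English: merged into the previous chunk
    {"lang": "en", "text": "hello "},
    {"lang": "en", "text": "world"},       # English after English: merged into the previous chunk
]

language = "zh"
langlist, textlist = [], []
for tmp in segments:
    if langlist and ((tmp["lang"] == "en") == (langlist[-1] == "en")):
        textlist[-1] += tmp["text"]
        continue
    langlist.append(tmp["lang"] if tmp["lang"] == "en" else language)
    textlist.append(tmp["text"])

print(list(zip(langlist, textlist)))  # [('zh', '你好こんにちは'), ('en', 'hello world')]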

View File

@@ -22,6 +22,22 @@ v2:
   t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
   version: v2
   vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
+v2Pro:
+  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
+  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
+  device: cpu
+  is_half: false
+  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
+  version: v2Pro
+  vits_weights_path: GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro.pth
+v2ProPlus:
+  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
+  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
+  device: cpu
+  is_half: false
+  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
+  version: v2ProPlus
+  vits_weights_path: GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus.pth
 v3:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
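
The two blocks added above follow the same shape as the existing v1-v4 entries and would be picked up by the TTS_Config hunk earlier in this compare, which merges the YAML file over the built-in defaults and then selects one version's dict. A minimal sketch of that merge-then-select step, assuming PyYAML and a local tts_infer.yaml shaped like this file (the trimmed default dict here is a stand-in, not the real defaults):

from copy import deepcopy
import yaml

default_configs = {"v2": {"device": "cpu", "is_half": False, "version": "v2"}}

with open("tts_infer.yaml", "r", encoding="utf-8") as f:
    configs = yaml.safe_load(f) or {}

configs_ = deepcopy(default_configs)
configs_.update(configs)                         # entries from the file override the defaults
active = configs_.get("custom", configs_["v2"])  # an explicit "custom" block wins, else fall back to v2
print(active.get("version"), active.get("vits_weights_path"))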

View File

@@ -2,13 +2,12 @@
 # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 """
 Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
 ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
 The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
 The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
 """
-
 import torch
 import math
 import torch.nn as nn
@@ -16,15 +15,14 @@ import torch.nn.functional as F
 import pooling_layers as pooling_layers
 from fusion import AFF


 class ReLU(nn.Hardtanh):
     def __init__(self, inplace=False):
         super(ReLU, self).__init__(0, 20, inplace)

     def __repr__(self):
-        inplace_str = 'inplace' if self.inplace else ''
-        return self.__class__.__name__ + ' (' \
-            + inplace_str + ')'
+        inplace_str = "inplace" if self.inplace else ""
+        return self.__class__.__name__ + " (" + inplace_str + ")"


 class BasicBlockERes2Net(nn.Module):
@@ -32,13 +30,13 @@ class BasicBlockERes2Net(nn.Module):
     def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
         super(BasicBlockERes2Net, self).__init__()
-        width = int(math.floor(planes*(baseWidth/64.0)))
-        self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
-        self.bn1 = nn.BatchNorm2d(width*scale)
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
+        self.bn1 = nn.BatchNorm2d(width * scale)
         self.nums = scale

-        convs=[]
-        bns=[]
+        convs = []
+        bns = []
         for i in range(self.nums):
             convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
             bns.append(nn.BatchNorm2d(width))
@@ -46,14 +44,14 @@ class BasicBlockERes2Net(nn.Module):
         self.bns = nn.ModuleList(bns)
         self.relu = ReLU(inplace=True)

-        self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
-        self.bn3 = nn.BatchNorm2d(planes*self.expansion)
+        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
         self.shortcut = nn.Sequential()
         if stride != 1 or in_planes != self.expansion * planes:
             self.shortcut = nn.Sequential(
-                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
-                          stride=stride, bias=False),
-                nn.BatchNorm2d(self.expansion * planes))
+                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
         self.stride = stride
         self.width = width
         self.scale = scale
@@ -64,18 +62,18 @@ class BasicBlockERes2Net(nn.Module):
         out = self.conv1(x)
         out = self.bn1(out)
         out = self.relu(out)
-        spx = torch.split(out,self.width,1)
+        spx = torch.split(out, self.width, 1)
         for i in range(self.nums):
-            if i==0:
+            if i == 0:
                 sp = spx[i]
             else:
                 sp = sp + spx[i]
             sp = self.convs[i](sp)
             sp = self.relu(self.bns[i](sp))
-            if i==0:
+            if i == 0:
                 out = sp
             else:
-                out = torch.cat((out,sp),1)
+                out = torch.cat((out, sp), 1)
         out = self.conv3(out)
         out = self.bn3(out)
@@ -86,19 +84,20 @@ class BasicBlockERes2Net(nn.Module):
         return out


 class BasicBlockERes2Net_diff_AFF(nn.Module):
     expansion = 2

     def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
         super(BasicBlockERes2Net_diff_AFF, self).__init__()
-        width = int(math.floor(planes*(baseWidth/64.0)))
-        self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
-        self.bn1 = nn.BatchNorm2d(width*scale)
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
+        self.bn1 = nn.BatchNorm2d(width * scale)
         self.nums = scale

-        convs=[]
-        fuse_models=[]
-        bns=[]
+        convs = []
+        fuse_models = []
+        bns = []
         for i in range(self.nums):
             convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
             bns.append(nn.BatchNorm2d(width))
@@ -110,14 +109,14 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
         self.fuse_models = nn.ModuleList(fuse_models)
         self.relu = ReLU(inplace=True)

-        self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
-        self.bn3 = nn.BatchNorm2d(planes*self.expansion)
+        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
         self.shortcut = nn.Sequential()
         if stride != 1 or in_planes != self.expansion * planes:
             self.shortcut = nn.Sequential(
-                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
-                          stride=stride, bias=False),
-                nn.BatchNorm2d(self.expansion * planes))
+                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
         self.stride = stride
         self.width = width
         self.scale = scale
@@ -128,19 +127,19 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
         out = self.conv1(x)
         out = self.bn1(out)
         out = self.relu(out)
-        spx = torch.split(out,self.width,1)
+        spx = torch.split(out, self.width, 1)
         for i in range(self.nums):
-            if i==0:
+            if i == 0:
                 sp = spx[i]
             else:
-                sp = self.fuse_models[i-1](sp, spx[i])
+                sp = self.fuse_models[i - 1](sp, spx[i])
             sp = self.convs[i](sp)
             sp = self.relu(self.bns[i](sp))
-            if i==0:
+            if i == 0:
                 out = sp
             else:
-                out = torch.cat((out,sp),1)
+                out = torch.cat((out, sp), 1)
         out = self.conv3(out)
         out = self.bn3(out)
@@ -151,16 +150,19 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
         return out


 class ERes2Net(nn.Module):
-    def __init__(self,
-                 block=BasicBlockERes2Net,
-                 block_fuse=BasicBlockERes2Net_diff_AFF,
-                 num_blocks=[3, 4, 6, 3],
-                 m_channels=32,
-                 feat_dim=80,
-                 embedding_size=192,
-                 pooling_func='TSTP',
-                 two_emb_layer=False):
+    def __init__(
+        self,
+        block=BasicBlockERes2Net,
+        block_fuse=BasicBlockERes2Net_diff_AFF,
+        num_blocks=[3, 4, 6, 3],
+        m_channels=32,
+        feat_dim=80,
+        embedding_size=192,
+        pooling_func="TSTP",
+        two_emb_layer=False,
+    ):
         super(ERes2Net, self).__init__()
         self.in_planes = m_channels
         self.feat_dim = feat_dim
@@ -176,20 +178,24 @@ class ERes2Net(nn.Module):
         self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)

         # Downsampling module for each layer
-        self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False)
-        self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
-        self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
+        self.layer1_downsample = nn.Conv2d(
+            m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False
+        )
+        self.layer2_downsample = nn.Conv2d(
+            m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
+        )
+        self.layer3_downsample = nn.Conv2d(
+            m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
+        )

         # Bottom-up fusion module
         self.fuse_mode12 = AFF(channels=m_channels * 4)
         self.fuse_mode123 = AFF(channels=m_channels * 8)
         self.fuse_mode1234 = AFF(channels=m_channels * 16)

-        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
-        self.pool = getattr(pooling_layers, pooling_func)(
-            in_dim=self.stats_dim * block.expansion)
-        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
-                               embedding_size)
+        self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
+        self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
+        self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
         if self.two_emb_layer:
             self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
             self.seg_2 = nn.Linear(embedding_size, embedding_size)
@@ -243,18 +249,16 @@ class ERes2Net(nn.Module):
         fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
         out4 = self.layer4(out3)
         fuse_out123_downsample = self.layer3_downsample(fuse_out123)
-        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2).mean(-1)
+        fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
         return fuse_out1234


-if __name__ == '__main__':
+if __name__ == "__main__":
     x = torch.zeros(10, 300, 80)
-    model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func='TSTP')
+    model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func="TSTP")
     model.eval()
     out = model(x)
     print(out.shape)  # torch.Size([10, 192])

     num_params = sum(param.numel() for param in model.parameters())
     print("{} M".format(num_params / 1e6))  # 6.61M

View File

@@ -2,14 +2,12 @@
 # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
 """
 To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension
 within each stage. However, this modification also increases the number of model parameters and computational complexity.
 To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundant structures, ultimately reducing
 both the model parameters and its computational cost.
 """
-
-
 import torch
 import math
 import torch.nn as nn
@@ -17,29 +15,27 @@ import torch.nn.functional as F
 import pooling_layers as pooling_layers
 from fusion import AFF


 class ReLU(nn.Hardtanh):
     def __init__(self, inplace=False):
         super(ReLU, self).__init__(0, 20, inplace)

     def __repr__(self):
-        inplace_str = 'inplace' if self.inplace else ''
-        return self.__class__.__name__ + ' (' \
-            + inplace_str + ')'
+        inplace_str = "inplace" if self.inplace else ""
+        return self.__class__.__name__ + " (" + inplace_str + ")"


 class BasicBlockERes2NetV2(nn.Module):
     def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
         super(BasicBlockERes2NetV2, self).__init__()
-        width = int(math.floor(planes*(baseWidth/64.0)))
-        self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
-        self.bn1 = nn.BatchNorm2d(width*scale)
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
+        self.bn1 = nn.BatchNorm2d(width * scale)
         self.nums = scale
         self.expansion = expansion

-        convs=[]
-        bns=[]
+        convs = []
+        bns = []
         for i in range(self.nums):
             convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
             bns.append(nn.BatchNorm2d(width))
@@ -47,17 +43,14 @@ class BasicBlockERes2NetV2(nn.Module):
         self.bns = nn.ModuleList(bns)
         self.relu = ReLU(inplace=True)

-        self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
-        self.bn3 = nn.BatchNorm2d(planes*self.expansion)
+        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
         self.shortcut = nn.Sequential()
         if stride != 1 or in_planes != self.expansion * planes:
             self.shortcut = nn.Sequential(
-                nn.Conv2d(in_planes,
-                          self.expansion * planes,
-                          kernel_size=1,
-                          stride=stride,
-                          bias=False),
-                nn.BatchNorm2d(self.expansion * planes))
+                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
         self.stride = stride
         self.width = width
         self.scale = scale
@@ -68,18 +61,18 @@ class BasicBlockERes2NetV2(nn.Module):
         out = self.conv1(x)
         out = self.bn1(out)
         out = self.relu(out)
-        spx = torch.split(out,self.width,1)
+        spx = torch.split(out, self.width, 1)
         for i in range(self.nums):
-            if i==0:
+            if i == 0:
                 sp = spx[i]
             else:
                 sp = sp + spx[i]
             sp = self.convs[i](sp)
             sp = self.relu(self.bns[i](sp))
-            if i==0:
+            if i == 0:
                 out = sp
             else:
-                out = torch.cat((out,sp),1)
+                out = torch.cat((out, sp), 1)
         out = self.conv3(out)
         out = self.bn3(out)
@@ -90,19 +83,19 @@ class BasicBlockERes2NetV2(nn.Module):
         return out


 class BasicBlockERes2NetV2AFF(nn.Module):
     def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
         super(BasicBlockERes2NetV2AFF, self).__init__()
-        width = int(math.floor(planes*(baseWidth/64.0)))
-        self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
-        self.bn1 = nn.BatchNorm2d(width*scale)
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
+        self.bn1 = nn.BatchNorm2d(width * scale)
         self.nums = scale
         self.expansion = expansion

-        convs=[]
-        fuse_models=[]
-        bns=[]
+        convs = []
+        fuse_models = []
+        bns = []
         for i in range(self.nums):
             convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
             bns.append(nn.BatchNorm2d(width))
@@ -114,17 +107,14 @@ class BasicBlockERes2NetV2AFF(nn.Module):
         self.fuse_models = nn.ModuleList(fuse_models)
         self.relu = ReLU(inplace=True)

-        self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
-        self.bn3 = nn.BatchNorm2d(planes*self.expansion)
+        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
         self.shortcut = nn.Sequential()
         if stride != 1 or in_planes != self.expansion * planes:
             self.shortcut = nn.Sequential(
-                nn.Conv2d(in_planes,
-                          self.expansion * planes,
-                          kernel_size=1,
-                          stride=stride,
-                          bias=False),
-                nn.BatchNorm2d(self.expansion * planes))
+                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
         self.stride = stride
         self.width = width
         self.scale = scale
@@ -135,19 +125,19 @@ class BasicBlockERes2NetV2AFF(nn.Module):
         out = self.conv1(x)
         out = self.bn1(out)
         out = self.relu(out)
-        spx = torch.split(out,self.width,1)
+        spx = torch.split(out, self.width, 1)
         for i in range(self.nums):
-            if i==0:
+            if i == 0:
                 sp = spx[i]
             else:
-                sp = self.fuse_models[i-1](sp, spx[i])
+                sp = self.fuse_models[i - 1](sp, spx[i])
             sp = self.convs[i](sp)
             sp = self.relu(self.bns[i](sp))
-            if i==0:
+            if i == 0:
                 out = sp
             else:
-                out = torch.cat((out,sp),1)
+                out = torch.cat((out, sp), 1)
         out = self.conv3(out)
         out = self.bn3(out)
@@ -158,8 +148,10 @@ class BasicBlockERes2NetV2AFF(nn.Module):
         return out


 class ERes2NetV2(nn.Module):
-    def __init__(self,
+    def __init__(
+        self,
         block=BasicBlockERes2NetV2,
         block_fuse=BasicBlockERes2NetV2AFF,
         num_blocks=[3, 4, 6, 3],
@@ -169,8 +161,9 @@ class ERes2NetV2(nn.Module):
         baseWidth=26,
         scale=2,
         expansion=2,
-        pooling_func='TSTP',
-        two_emb_layer=False):
+        pooling_func="TSTP",
+        two_emb_layer=False,
+    ):
         super(ERes2NetV2, self).__init__()
         self.in_planes = m_channels
         self.feat_dim = feat_dim
@@ -181,42 +174,29 @@ class ERes2NetV2(nn.Module):
         self.scale = scale
         self.expansion = expansion

-        self.conv1 = nn.Conv2d(1,
-                               m_channels,
-                               kernel_size=3,
-                               stride=1,
-                               padding=1,
-                               bias=False)
+        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
         self.bn1 = nn.BatchNorm2d(m_channels)
-        self.layer1 = self._make_layer(block,
-                                       m_channels,
-                                       num_blocks[0],
-                                       stride=1)
-        self.layer2 = self._make_layer(block,
-                                       m_channels * 2,
-                                       num_blocks[1],
-                                       stride=2)
-        self.layer3 = self._make_layer(block_fuse,
-                                       m_channels * 4,
-                                       num_blocks[2],
-                                       stride=2)
-        self.layer4 = self._make_layer(block_fuse,
-                                       m_channels * 8,
-                                       num_blocks[3],
-                                       stride=2)
+        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)

         # Downsampling module
-        self.layer3_ds = nn.Conv2d(m_channels * 4 * self.expansion, m_channels * 8 * self.expansion, kernel_size=3, \
-                                   padding=1, stride=2, bias=False)
+        self.layer3_ds = nn.Conv2d(
+            m_channels * 4 * self.expansion,
+            m_channels * 8 * self.expansion,
+            kernel_size=3,
+            padding=1,
+            stride=2,
+            bias=False,
+        )

         # Bottom-up fusion module
         self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)

-        self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
-        self.pool = getattr(pooling_layers, pooling_func)(
-            in_dim=self.stats_dim * self.expansion)
-        self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
-                               embedding_size)
+        self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
+        self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * self.expansion)
+        self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats, embedding_size)
         if self.two_emb_layer:
             self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
             self.seg_2 = nn.Linear(embedding_size, embedding_size)
@@ -228,7 +208,11 @@ class ERes2NetV2(nn.Module):
         strides = [stride] + [1] * (num_blocks - 1)
         layers = []
         for stride in strides:
-            layers.append(block(self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion))
+            layers.append(
+                block(
+                    self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion
+                )
+            )
             self.in_planes = planes * self.expansion
         return nn.Sequential(*layers)
@@ -264,7 +248,7 @@ class ERes2NetV2(nn.Module):
         out3_ds = self.layer3_ds(out3)
         fuse_out34 = self.fuse34(out4, out3_ds)
         # print(111111111,fuse_out34.shape)#111111111 torch.Size([16, 2048, 10, 72])
-        return fuse_out34.flatten(start_dim=1,end_dim=2).mean(-1)
+        return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
         # stats = self.pool(fuse_out34)
         #
         # embed_a = self.seg_1(stats)
@@ -276,17 +260,13 @@ class ERes2NetV2(nn.Module):
         #     else:
         #         return embed_a


-if __name__ == '__main__':
+if __name__ == "__main__":
     x = torch.randn(1, 300, 80)
     model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
     model.eval()
     y = model(x)
     print(y.size())
-    macs, num_params = profile(model, inputs=(x, ))
+    macs, num_params = profile(model, inputs=(x,))
     print("Params: {} M".format(num_params / 1e6))  # 17.86 M
     print("MACs: {} G".format(macs / 1e9))  # 12.69 G

View File

@@ -1,14 +1,13 @@
 # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
 # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

-""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
+"""Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
 ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
 The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
 The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
 ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
 recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
 """
-import pdb

 import torch
 import math
@@ -17,15 +16,14 @@ import torch.nn.functional as F
 import pooling_layers as pooling_layers
 from fusion import AFF


 class ReLU(nn.Hardtanh):
     def __init__(self, inplace=False):
         super(ReLU, self).__init__(0, 20, inplace)

     def __repr__(self):
-        inplace_str = 'inplace' if self.inplace else ''
-        return self.__class__.__name__ + ' (' \
-            + inplace_str + ')'
+        inplace_str = "inplace" if self.inplace else ""
+        return self.__class__.__name__ + " (" + inplace_str + ")"


 class BasicBlockERes2Net(nn.Module):
@@ -33,13 +31,13 @@ class BasicBlockERes2Net(nn.Module):
     def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
         super(BasicBlockERes2Net, self).__init__()
-        width = int(math.floor(planes*(baseWidth/64.0)))
-        self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
-        self.bn1 = nn.BatchNorm2d(width*scale)
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
+        self.bn1 = nn.BatchNorm2d(width * scale)
         self.nums = scale

-        convs=[]
-        bns=[]
+        convs = []
+        bns = []
         for i in range(self.nums):
             convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
             bns.append(nn.BatchNorm2d(width))
@@ -47,13 +45,14 @@ class BasicBlockERes2Net(nn.Module):
         self.bns = nn.ModuleList(bns)
         self.relu = ReLU(inplace=True)

-        self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
-        self.bn3 = nn.BatchNorm2d(planes*self.expansion)
+        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
         self.shortcut = nn.Sequential()
         if stride != 1 or in_planes != self.expansion * planes:
             self.shortcut = nn.Sequential(
                 nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
-                nn.BatchNorm2d(self.expansion * planes))
+                nn.BatchNorm2d(self.expansion * planes),
+            )
         self.stride = stride
         self.width = width
         self.scale = scale
@@ -64,18 +63,18 @@ class BasicBlockERes2Net(nn.Module):
         out = self.conv1(x)
         out = self.bn1(out)
         out = self.relu(out)
-        spx = torch.split(out,self.width,1)
+        spx = torch.split(out, self.width, 1)
         for i in range(self.nums):
-            if i==0:
+            if i == 0:
                 sp = spx[i]
             else:
                 sp = sp + spx[i]
             sp = self.convs[i](sp)
             sp = self.relu(self.bns[i](sp))
-            if i==0:
+            if i == 0:
                 out = sp
             else:
-                out = torch.cat((out,sp),1)
+                out = torch.cat((out, sp), 1)
         out = self.conv3(out)
         out = self.bn3(out)
@@ -86,19 +85,20 @@ class BasicBlockERes2Net(nn.Module):
         return out


 class BasicBlockERes2Net_diff_AFF(nn.Module):
     expansion = 4

     def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
         super(BasicBlockERes2Net_diff_AFF, self).__init__()
-        width = int(math.floor(planes*(baseWidth/64.0)))
-        self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
-        self.bn1 = nn.BatchNorm2d(width*scale)
+        width = int(math.floor(planes * (baseWidth / 64.0)))
+        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
+        self.bn1 = nn.BatchNorm2d(width * scale)
         self.nums = scale

-        convs=[]
-        fuse_models=[]
-        bns=[]
+        convs = []
+        fuse_models = []
+        bns = []
         for i in range(self.nums):
             convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
             bns.append(nn.BatchNorm2d(width))
@@ -110,13 +110,14 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
         self.fuse_models = nn.ModuleList(fuse_models)
         self.relu = ReLU(inplace=True)

-        self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
-        self.bn3 = nn.BatchNorm2d(planes*self.expansion)
+        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
         self.shortcut = nn.Sequential()
         if stride != 1 or in_planes != self.expansion * planes:
             self.shortcut = nn.Sequential(
                 nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
-                nn.BatchNorm2d(self.expansion * planes))
+                nn.BatchNorm2d(self.expansion * planes),
+            )
         self.stride = stride
         self.width = width
         self.scale = scale
@@ -127,20 +128,19 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
         out = self.conv1(x)
         out = self.bn1(out)
         out = self.relu(out)
-        spx = torch.split(out,self.width,1)
+        spx = torch.split(out, self.width, 1)
         for i in range(self.nums):
-            if i==0:
+            if i == 0:
                 sp = spx[i]
             else:
-                sp = self.fuse_models[i-1](sp, spx[i])
+                sp = self.fuse_models[i - 1](sp, spx[i])
             sp = self.convs[i](sp)
             sp = self.relu(self.bns[i](sp))
-            if i==0:
+            if i == 0:
                 out = sp
             else:
-                out = torch.cat((out,sp),1)
+                out = torch.cat((out, sp), 1)
         out = self.conv3(out)
         out = self.bn3(out)
@@ -151,16 +151,19 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
         return out


 class ERes2Net(nn.Module):
-    def __init__(self,
-                 block=BasicBlockERes2Net,
-                 block_fuse=BasicBlockERes2Net_diff_AFF,
-                 num_blocks=[3, 4, 6, 3],
-                 m_channels=64,
-                 feat_dim=80,
-                 embedding_size=192,
-                 pooling_func='TSTP',
-                 two_emb_layer=False):
+    def __init__(
+        self,
+        block=BasicBlockERes2Net,
+        block_fuse=BasicBlockERes2Net_diff_AFF,
+        num_blocks=[3, 4, 6, 3],
+        m_channels=64,
+        feat_dim=80,
+        embedding_size=192,
+        pooling_func="TSTP",
+        two_emb_layer=False,
+    ):
         super(ERes2Net, self).__init__()
         self.in_planes = m_channels
         self.feat_dim = feat_dim
@@ -176,17 +179,22 @@ class ERes2Net(nn.Module):
         self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
         self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)

-        self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
-        self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
-        self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False)
+        self.layer1_downsample = nn.Conv2d(
+            m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
+        )
+        self.layer2_downsample = nn.Conv2d(
+            m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
)
self.layer3_downsample = nn.Conv2d(
m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False
)
self.fuse_mode12 = AFF(channels=m_channels * 8) self.fuse_mode12 = AFF(channels=m_channels * 8)
self.fuse_mode123 = AFF(channels=m_channels * 16) self.fuse_mode123 = AFF(channels=m_channels * 16)
self.fuse_mode1234 = AFF(channels=m_channels * 32) self.fuse_mode1234 = AFF(channels=m_channels * 32)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2 self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)( self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size) self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
if self.two_emb_layer: if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False) self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
@ -229,7 +237,7 @@ class ERes2Net(nn.Module):
else: else:
return embed_a return embed_a
def forward2(self, x,if_mean): def forward2(self, x, if_mean):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1) x = x.unsqueeze_(1)
@ -243,14 +251,13 @@ class ERes2Net(nn.Module):
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3) out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123) fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2)#bs,20480,T fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2) # bs,20480,T
if(if_mean==False): if if_mean == False:
mean=fuse_out1234[0].transpose(1,0)#(T,20480),bs=T mean = fuse_out1234[0].transpose(1, 0) # (T,20480),bs=T
else: else:
mean = fuse_out1234.mean(2)#bs,20480 mean = fuse_out1234.mean(2) # bs,20480
mean_std=torch.cat([mean,torch.zeros_like(mean)],1) mean_std = torch.cat([mean, torch.zeros_like(mean)], 1)
return self.seg_1(mean_std)#(T,192) return self.seg_1(mean_std) # (T,192)
# stats = self.pool(fuse_out1234) # stats = self.pool(fuse_out1234)
# if self.two_emb_layer: # if self.two_emb_layer:
@ -275,12 +282,8 @@ class ERes2Net(nn.Module):
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample) fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3) out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123) fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2).mean(-1) fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
return fuse_out1234 return fuse_out1234
# print(fuse_out1234.shape) # print(fuse_out1234.shape)
# print(fuse_out1234.flatten(start_dim=1,end_dim=2).shape) # print(fuse_out1234.flatten(start_dim=1,end_dim=2).shape)
# pdb.set_trace() # pdb.set_trace()
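The two blocks above follow the Res2Net multi-scale recipe: a 1x1 conv expands the input to width*scale channels, torch.split carves those into `scale` groups, each group is refined by a 3x3 conv (and, in the _diff_AFF variant, fused with the previous group's output), and the refined groups are concatenated back before the final 1x1 conv. A minimal, self-contained sketch of that split/process/concat flow with toy channel sizes, not the repository's module:

# Minimal sketch of the Res2Net-style split/process/concat flow used by the
# blocks above (toy sizes; illustrative only).
import torch
import torch.nn as nn

class ToyRes2NetBlock(nn.Module):
    def __init__(self, channels=32, scale=4):
        super().__init__()
        assert channels % scale == 0
        self.width = channels // scale
        self.convs = nn.ModuleList(
            [nn.Conv2d(self.width, self.width, kernel_size=3, padding=1, bias=False) for _ in range(scale)]
        )

    def forward(self, x):
        # split the channel dimension into `scale` groups of `width` channels
        chunks = torch.split(x, self.width, dim=1)
        outs = []
        for i, chunk in enumerate(chunks):
            # each group sees its own features plus the running output of the
            # previous group, which widens the effective receptive field
            sp = chunk if i == 0 else sp + chunk
            sp = self.convs[i](sp)
            outs.append(sp)
        return torch.cat(outs, dim=1)

print(ToyRes2NetBlock()(torch.randn(2, 32, 8, 8)).shape)  # torch.Size([2, 32, 8, 8])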

View File

@ -6,7 +6,6 @@ import torch.nn as nn
class AFF(nn.Module): class AFF(nn.Module):
def __init__(self, channels=64, r=4): def __init__(self, channels=64, r=4):
super(AFF, self).__init__() super(AFF, self).__init__()
inter_channels = int(channels // r) inter_channels = int(channels // r)
@ -23,7 +22,6 @@ class AFF(nn.Module):
xa = torch.cat((x, ds_y), dim=1) xa = torch.cat((x, ds_y), dim=1)
x_att = self.local_att(xa) x_att = self.local_att(xa)
x_att = 1.0 + torch.tanh(x_att) x_att = 1.0 + torch.tanh(x_att)
xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0-x_att) xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)
return xo return xo
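The AFF fusion above gates its two inputs with one shared attention map: with a = 1 + tanh(local_att(cat(x, ds_y))) lying in (0, 2), the output x * a + ds_y * (2 - a) shifts weight between the identity branch and the downsampled branch while the two weights always sum to 2. A tiny arithmetic sketch, with random noise standing in for local_att:

# Sketch of the AFF gating arithmetic only (random stand-in for local_att).
import torch

x = torch.randn(1, 4, 3, 3)        # identity branch
ds_y = torch.randn(1, 4, 3, 3)     # downsampled branch
att = torch.randn(1, 4, 3, 3)      # stand-in for self.local_att(torch.cat((x, ds_y), dim=1))

a = 1.0 + torch.tanh(att)          # each weight lies in (0, 2)
fused = x * a + ds_y * (2.0 - a)
print(fused.shape, float(a.min()), float(a.max()))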

View File

@ -144,7 +144,7 @@ def _get_waveform_and_window_properties(
) )
assert 0 < window_shift, "`window_shift` must be greater than 0" assert 0 < window_shift, "`window_shift` must be greater than 0"
assert padded_window_size % 2 == 0, ( assert padded_window_size % 2 == 0, (
"the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`" "the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`"
) )
assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]" assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
assert sample_frequency > 0, "`sample_frequency` must be greater than zero" assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
@ -441,7 +441,9 @@ def get_mel_banks(
high_freq: float, high_freq: float,
vtln_low: float, vtln_low: float,
vtln_high: float, vtln_high: float,
vtln_warp_factor: float,device=None,dtype=None vtln_warp_factor: float,
device=None,
dtype=None,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
""" """
Returns: Returns:
@ -457,9 +459,9 @@ def get_mel_banks(
if high_freq <= 0.0: if high_freq <= 0.0:
high_freq += nyquist high_freq += nyquist
assert ( assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), (
(0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq) "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist) )
# fft-bin width [think of it as Nyquist-freq / half-window-length] # fft-bin width [think of it as Nyquist-freq / half-window-length]
fft_bin_width = sample_freq / window_length_padded fft_bin_width = sample_freq / window_length_padded
@ -475,7 +477,7 @@ def get_mel_banks(
assert vtln_warp_factor == 1.0 or ( assert vtln_warp_factor == 1.0 or (
(low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high) (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format( ), "Bad values in options: vtln-low {} and vtln-high {}, versus low-freq {} and high-freq {}".format(
vtln_low, vtln_high, low_freq, high_freq vtln_low, vtln_high, low_freq, high_freq
) )
@ -508,9 +510,12 @@ def get_mel_banks(
bins[up_idx] = up_slope[up_idx] bins[up_idx] = up_slope[up_idx]
bins[down_idx] = down_slope[down_idx] bins[down_idx] = down_slope[down_idx]
return bins.to(device=device,dtype=dtype)#, center_freqs return bins.to(device=device, dtype=dtype) # , center_freqs
cache = {}
cache={}
def fbank( def fbank(
waveform: Tensor, waveform: Tensor,
blackman_coeff: float = 0.42, blackman_coeff: float = 0.42,
@ -620,14 +625,34 @@ def fbank(
# size (num_mel_bins, padded_window_size // 2) # size (num_mel_bins, padded_window_size // 2)
# print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp) # print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
cache_key="%s-%s-%s-%s-%s-%s-%s-%s-%s-%s"%(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp,device,dtype) cache_key = "%s-%s-%s-%s-%s-%s-%s-%s-%s-%s" % (
num_mel_bins,
padded_window_size,
sample_frequency,
low_freq,
high_freq,
vtln_low,
vtln_high,
vtln_warp,
device,
dtype,
)
if cache_key not in cache: if cache_key not in cache:
mel_energies = get_mel_banks( mel_energies = get_mel_banks(
num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp,device,dtype num_mel_bins,
padded_window_size,
sample_frequency,
low_freq,
high_freq,
vtln_low,
vtln_high,
vtln_warp,
device,
dtype,
) )
cache[cache_key]=mel_energies cache[cache_key] = mel_energies
else: else:
mel_energies=cache[cache_key] mel_energies = cache[cache_key]
# pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1) # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0) mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
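The new module-level `cache` memoizes `get_mel_banks` on every parameter that affects the filterbank, plus the target device and dtype, so repeated `fbank` calls with the same configuration reuse a single tensor instead of rebuilding it per utterance. The same keyed-memoization pattern in isolation, on a hypothetical toy function rather than the kaldi module:

# Keyed memoization: the cache key includes everything that changes the
# result, including device and dtype. (Toy function for illustration.)
import torch

_cache = {}

def cached_linspace(n, device=None, dtype=None):
    key = (n, str(device), str(dtype))
    if key not in _cache:
        _cache[key] = torch.linspace(0.0, 1.0, n, device=device, dtype=dtype)
    return _cache[key]

a = cached_linspace(80)
b = cached_linspace(80)
print(a is b)  # True: the second call reuses the cached tensor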

View File

@ -1,7 +1,7 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
""" This implementation is adapted from https://github.com/wenet-e2e/wespeaker.""" """This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -11,6 +11,7 @@ class TAP(nn.Module):
""" """
Temporal average pooling, only first-order mean is considered Temporal average pooling, only first-order mean is considered
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(TAP, self).__init__() super(TAP, self).__init__()
@ -25,6 +26,7 @@ class TSDP(nn.Module):
""" """
Temporal standard deviation pooling, only second-order std is considered Temporal standard deviation pooling, only second-order std is considered
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(TSDP, self).__init__() super(TSDP, self).__init__()
@ -41,6 +43,7 @@ class TSTP(nn.Module):
x-vector x-vector
Comment: simple concatenation can not make full use of both statistics Comment: simple concatenation can not make full use of both statistics
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
super(TSTP, self).__init__() super(TSTP, self).__init__()
@ -56,9 +59,10 @@ class TSTP(nn.Module):
class ASTP(nn.Module): class ASTP(nn.Module):
""" Attentive statistics pooling: Channel- and context-dependent """Attentive statistics pooling: Channel- and context-dependent
statistics pooling, first used in ECAPA_TDNN. statistics pooling, first used in ECAPA_TDNN.
""" """
def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False): def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
super(ASTP, self).__init__() super(ASTP, self).__init__()
self.global_context_att = global_context_att self.global_context_att = global_context_att
@ -66,15 +70,10 @@ class ASTP(nn.Module):
# Use Conv1d with stride == 1 rather than Linear, then we don't # Use Conv1d with stride == 1 rather than Linear, then we don't
# need to transpose inputs. # need to transpose inputs.
if global_context_att: if global_context_att:
self.linear1 = nn.Conv1d( self.linear1 = nn.Conv1d(in_dim * 3, bottleneck_dim, kernel_size=1) # equals W and b in the paper
in_dim * 3, bottleneck_dim,
kernel_size=1) # equals W and b in the paper
else: else:
self.linear1 = nn.Conv1d( self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1) # equals W and b in the paper
in_dim, bottleneck_dim, self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1) # equals V and k in the paper
kernel_size=1) # equals W and b in the paper
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
kernel_size=1) # equals V and k in the paper
def forward(self, x): def forward(self, x):
""" """
@ -88,15 +87,13 @@ class ASTP(nn.Module):
if self.global_context_att: if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt( context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
x_in = torch.cat((x, context_mean, context_std), dim=1) x_in = torch.cat((x, context_mean, context_std), dim=1)
else: else:
x_in = x x_in = x
# DON'T use ReLU here! ReLU may be hard to converge. # DON'T use ReLU here! ReLU may be hard to converge.
alpha = torch.tanh( alpha = torch.tanh(self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2) alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2) mean = torch.sum(alpha * x, dim=2)
var = torch.sum(alpha * (x**2), dim=2) - mean**2 var = torch.sum(alpha * (x**2), dim=2) - mean**2
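These pooling layers all collapse a (B, C, T) frame-level map into one utterance-level vector: TAP keeps the temporal mean, TSDP the standard deviation, TSTP concatenates both (hence n_stats = 2 in ERes2Net), and ASTP weights frames with learned attention before taking the weighted mean/std. A minimal TSTP-style computation on dummy features:

# Temporal statistics pooling (TSTP-style): mean and std over the time axis,
# concatenated into one utterance-level vector.
import torch

feats = torch.randn(4, 192, 120)                # (batch, channels, frames)
mean = feats.mean(dim=-1)                       # (4, 192)
std = torch.sqrt(feats.var(dim=-1) + 1e-10)     # (4, 192); eps keeps sqrt stable
utt_vec = torch.cat([mean, std], dim=-1)        # (4, 384) -> fed to the embedding layer
print(utt_vec.shape)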

View File

@ -1,6 +1,7 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py # modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e # reference: https://github.com/lifeiteng/vall-e
import argparse import argparse
from io import BytesIO
from typing import Optional from typing import Optional
from my_utils import load_audio from my_utils import load_audio
import torch import torch
@ -17,6 +18,9 @@ from module.models_onnx import SynthesizerTrn
from inference_webui import get_phones_and_bert from inference_webui import get_phones_and_bert
from sv import SV
import kaldi as Kaldi
import os import os
import soundfile import soundfile
@ -32,6 +36,25 @@ default_config = {
"EOS": 1024, "EOS": 1024,
} }
sv_cn_model = None
def init_sv_cn(device, is_half):
global sv_cn_model
sv_cn_model = SV(device, is_half)
def load_sovits_new(sovits_path):
f = open(sovits_path, "rb")
meta = f.read(2)
if meta != b"PK":
data = b"PK" + f.read()
bio = BytesIO()
bio.write(data)
bio.seek(0)
return torch.load(bio, map_location="cpu", weights_only=False)
return torch.load(sovits_path, map_location="cpu", weights_only=False)
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule: def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
config = dict_s1["config"] config = dict_s1["config"]
@ -83,7 +106,7 @@ def logits_to_probs(
@torch.jit.script @torch.jit.script
def multinomial_sample_one_no_sync(probs_sort): def multinomial_sample_one_no_sync(probs_sort):
# Does multinomial sampling without a cuda synchronization # Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort) q = torch.empty_like(probs_sort).exponential_(1.0)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
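The switch from `torch.randn_like` to `exponential_(1.0)` matters for correctness, not just speed: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax is the exponential-race form of the Gumbel-max trick, which returns index k with probability probs_sort[k] and needs no CUDA synchronization, whereas Gaussian noise does not yield that distribution. A quick empirical check of the trick in isolation:

# Empirical check: argmax(p / E) with E ~ Exp(1) samples index k with prob p[k].
import torch

probs = torch.tensor([0.1, 0.2, 0.7])
n = 200_000
q = torch.empty(n, 3).exponential_(1.0)
samples = torch.argmax(probs / q, dim=-1)
print(torch.bincount(samples, minlength=3).float() / n)  # ~ tensor([0.10, 0.20, 0.70])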
@ -94,7 +117,7 @@ def sample(
temperature: float = 1.0, temperature: float = 1.0,
top_k: Optional[int] = None, top_k: Optional[int] = None,
top_p: Optional[int] = None, top_p: Optional[int] = None,
repetition_penalty: float = 1.0, repetition_penalty: float = 1.35,
): ):
probs = logits_to_probs( probs = logits_to_probs(
logits=logits, logits=logits,
@ -109,8 +132,10 @@ def sample(
@torch.jit.script @torch.jit.script
def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False): def spectrogram_torch(
hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype) hann_window: Tensor, y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False
):
# hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
y = torch.nn.functional.pad( y = torch.nn.functional.pad(
y.unsqueeze(1), y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
@ -289,8 +314,9 @@ class T2SBlock:
attn = F.scaled_dot_product_attention(q, k, v) attn = F.scaled_dot_product_attention(q, k, v)
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim) # attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0) # attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
attn = F.linear(attn, self.out_w, self.out_b) attn = F.linear(attn, self.out_w, self.out_b)
x = x + attn x = x + attn
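The rewritten output path collapses the old permute/reshape/view/transpose chain into a single `transpose(1, 2).reshape(batch_size, q_len, -1)`. Assuming the SDPA output is laid out as (batch, heads, q_len, head_dim), both forms yield the same (batch, q_len, hidden_dim) tensor; the new one simply skips the intermediate flatten to (batch_size * q_len, hidden_dim). A sanity check under that shape assumption:

# Equivalence check for the attention-output reshape (layout assumed).
import torch

b, h, t, d = 2, 4, 5, 8
attn = torch.randn(b, h, t, d)

old = attn.permute(2, 0, 1, 3).reshape(b * t, h * d).view(t, b, h * d).transpose(1, 0)
new = attn.transpose(1, 2).reshape(b, t, -1)
print(torch.allclose(old, new))  # True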
@ -328,15 +354,22 @@ class T2STransformer:
class VitsModel(nn.Module): class VitsModel(nn.Module):
def __init__(self, vits_path): def __init__(self, vits_path, version=None, is_half=True, device="cpu"):
super().__init__() super().__init__()
# dict_s2 = torch.load(vits_path,map_location="cpu") # dict_s2 = torch.load(vits_path,map_location="cpu")
dict_s2 = torch.load(vits_path, weights_only=False) dict_s2 = load_sovits_new(vits_path)
self.hps = dict_s2["config"] self.hps = dict_s2["config"]
if version is None:
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322: if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1" self.hps["model"]["version"] = "v1"
else: else:
self.hps["model"]["version"] = "v2" self.hps["model"]["version"] = "v2"
else:
if version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]:
self.hps["model"]["version"] = version
else:
raise ValueError(f"Unsupported version: {version}")
self.hps = DictToAttrRecursive(self.hps) self.hps = DictToAttrRecursive(self.hps)
self.hps.model.semantic_frame_rate = "25hz" self.hps.model.semantic_frame_rate = "25hz"
@ -346,11 +379,19 @@ class VitsModel(nn.Module):
n_speakers=self.hps.data.n_speakers, n_speakers=self.hps.data.n_speakers,
**self.hps.model, **self.hps.model,
) )
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False) self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
self.vq_model.dec.remove_weight_norm()
if is_half:
self.vq_model = self.vq_model.half()
self.vq_model = self.vq_model.to(device)
self.vq_model.eval()
self.hann_window = torch.hann_window(
self.hps.data.win_length, device=device, dtype=torch.float16 if is_half else torch.float32
)
def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0): def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None):
refer = spectrogram_torch( refer = spectrogram_torch(
self.hann_window,
ref_audio, ref_audio,
self.hps.data.filter_length, self.hps.data.filter_length,
self.hps.data.sampling_rate, self.hps.data.sampling_rate,
@ -358,7 +399,7 @@ class VitsModel(nn.Module):
self.hps.data.win_length, self.hps.data.win_length,
center=False, center=False,
) )
return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0] return self.vq_model(pred_semantic, text_seq, refer, speed=speed, sv_emb=sv_emb)[0, 0]
class T2SModel(nn.Module): class T2SModel(nn.Module):
@ -433,6 +474,10 @@ class T2SModel(nn.Module):
bert = bert.unsqueeze(0) bert = bert.unsqueeze(0)
x = self.ar_text_embedding(all_phoneme_ids) x = self.ar_text_embedding(all_phoneme_ids)
# avoid dtype inconsistency when exporting
bert = bert.to(dtype=self.bert_proj.weight.dtype)
x = x + self.bert_proj(bert.transpose(1, 2)) x = x + self.bert_proj(bert.transpose(1, 2))
x: torch.Tensor = self.ar_text_position(x) x: torch.Tensor = self.ar_text_position(x)
@ -632,7 +677,9 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
ref_seq = torch.LongTensor([ref_seq_id]).to(device) ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T.to(ref_seq.device) ref_bert = ref_bert_T.T.to(ref_seq.device)
text_seq_id, text_bert_T, norm_text = get_phones_and_bert( text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2" "这是一个简单的示例真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
"auto",
"v2",
) )
text_seq = torch.LongTensor([text_seq_id]).to(device) text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T.to(text_seq.device) text_bert = text_bert_T.T.to(text_seq.device)
@ -640,7 +687,7 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
ssl_content = ssl(ref_audio).to(device) ssl_content = ssl(ref_audio).to(device)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth" # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path).to(device) vits = VitsModel(vits_path, device=device, is_half=False)
vits.eval() vits.eval()
# gpt_path = "GPT_weights_v2/xw-e15.ckpt" # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
@ -679,6 +726,124 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
print("#### exported gpt_sovits ####") print("#### exported gpt_sovits ####")
def export_prov2(
gpt_path,
vits_path,
version,
ref_audio_path,
ref_text,
output_path,
export_bert_and_ssl=False,
device="cpu",
is_half=True,
):
if sv_cn_model == None:
init_sv_cn(device, is_half)
if not os.path.exists(output_path):
os.makedirs(output_path)
print(f"目录已创建: {output_path}")
else:
print(f"目录已存在: {output_path}")
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
ssl = SSLModel()
if export_bert_and_ssl:
s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
ssl_path = os.path.join(output_path, "ssl_model.pt")
torch.jit.script(s).save(ssl_path)
print("#### exported ssl ####")
export_bert(output_path)
else:
s = ExportSSLModel(ssl)
print(f"device: {device}")
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T
if is_half:
ref_bert = ref_bert.half()
ref_bert = ref_bert.to(ref_seq.device)
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
"这是一个简单的示例真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
"auto",
"v2",
)
text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T
if is_half:
text_bert = text_bert.half()
text_bert = text_bert.to(text_seq.device)
ssl_content = ssl(ref_audio)
if is_half:
ssl_content = ssl_content.half()
ssl_content = ssl_content.to(device)
sv_model = ExportERes2NetV2(sv_cn_model)
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path, version, is_half=is_half, device=device)
vits.eval()
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
# dict_s1 = torch.load(gpt_path, map_location=device)
dict_s1 = torch.load(gpt_path, weights_only=False)
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
print("#### get_raw_t2s_model ####")
print(raw_t2s.config)
if is_half:
raw_t2s = raw_t2s.half()
t2s_m = T2SModel(raw_t2s)
t2s_m.eval()
t2s = torch.jit.script(t2s_m).to(device)
print("#### script t2s_m ####")
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
gpt_sovits = GPT_SoVITS_V2Pro(t2s, vits, sv_model).to(device)
gpt_sovits.eval()
ref_audio_sr = s.resample(ref_audio, 16000, 32000)
if is_half:
ref_audio_sr = ref_audio_sr.half()
ref_audio_sr = ref_audio_sr.to(device)
torch._dynamo.mark_dynamic(ssl_content, 2)
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
torch._dynamo.mark_dynamic(ref_seq, 1)
torch._dynamo.mark_dynamic(text_seq, 1)
torch._dynamo.mark_dynamic(ref_bert, 0)
torch._dynamo.mark_dynamic(text_bert, 0)
# torch._dynamo.mark_dynamic(sv_emb, 0)
top_k = torch.LongTensor([5]).to(device)
# Run sv_model once beforehand so it populates its cache; see L880 for details
gpt_sovits.sv_model(ref_audio_sr)
with torch.no_grad():
gpt_sovits_export = torch.jit.trace(
gpt_sovits,
example_inputs=(
ssl_content,
ref_audio_sr,
ref_seq,
text_seq,
ref_bert,
text_bert,
top_k,
),
)
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
gpt_sovits_export.save(gpt_sovits_path)
print("#### exported gpt_sovits ####")
audio = gpt_sovits_export(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
print("start write wav")
soundfile.write("out.wav", audio.float().detach().cpu().numpy(), 32000)
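Once `export_prov2` finishes, `output_path` holds `gpt_sovits_model.pt` (plus `ssl_model.pt` and the BERT export when `--export_common_model` is set). The traced module can be reloaded with `torch.jit.load` and called with the same seven tensors used for tracing; a hedged sketch of that reload step, with placeholder paths and the inputs assumed to be prepared exactly as above:

# Sketch: reloading the artefact written by export_prov2 (paths are placeholders).
import torch

def run_exported(model_path, ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k):
    # the seven tensors must match the dtypes/devices used when tracing
    gpt_sovits = torch.jit.load(model_path)
    gpt_sovits.eval()
    with torch.no_grad():
        return gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)

# audio = run_exported("output/gpt_sovits_model.pt", ...)  # -> 32 kHz mono waveform tensor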
@torch.jit.script @torch.jit.script
def parse_audio(ref_audio): def parse_audio(ref_audio):
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device) ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device)
@ -717,6 +882,67 @@ class GPT_SoVITS(nn.Module):
return audio return audio
class ExportERes2NetV2(nn.Module):
def __init__(self, sv_cn_model: SV):
super(ExportERes2NetV2, self).__init__()
self.bn1 = sv_cn_model.embedding_model.bn1
self.conv1 = sv_cn_model.embedding_model.conv1
self.layer1 = sv_cn_model.embedding_model.layer1
self.layer2 = sv_cn_model.embedding_model.layer2
self.layer3 = sv_cn_model.embedding_model.layer3
self.layer4 = sv_cn_model.embedding_model.layer4
self.layer3_ds = sv_cn_model.embedding_model.layer3_ds
self.fuse34 = sv_cn_model.embedding_model.fuse34
# audio_16k.shape: [1,N]
def forward(self, audio_16k):
# The fbank function below keeps a cache, but that is fine: it does not depend on the length of audio_16k,
# only on device and dtype
x = Kaldi.fbank(audio_16k, num_mel_bins=80, sample_frequency=16000, dither=0)
x = torch.stack([x])
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out3_ds = self.layer3_ds(out3)
fuse_out34 = self.fuse34(out4, out3_ds)
return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
class GPT_SoVITS_V2Pro(nn.Module):
def __init__(self, t2s: T2SModel, vits: VitsModel, sv_model: ExportERes2NetV2):
super().__init__()
self.t2s = t2s
self.vits = vits
self.sv_model = sv_model
def forward(
self,
ssl_content: torch.Tensor,
ref_audio_sr: torch.Tensor,
ref_seq: Tensor,
text_seq: Tensor,
ref_bert: Tensor,
text_bert: Tensor,
top_k: LongTensor,
speed=1.0,
):
codes = self.vits.vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
prompts = prompt_semantic.unsqueeze(0)
audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
sv_emb = self.sv_model(audio_16k)
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed, sv_emb)
return audio
def test(): def test():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool") parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file") parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
@ -839,8 +1065,25 @@ def main():
parser.add_argument("--output_path", required=True, help="Path to the output directory") parser.add_argument("--output_path", required=True, help="Path to the output directory")
parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model") parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
parser.add_argument("--device", help="Device to use") parser.add_argument("--device", help="Device to use")
parser.add_argument("--version", help="version of the model", default="v2")
parser.add_argument("--no-half", action="store_true", help="Do not use half precision for model weights")
args = parser.parse_args() args = parser.parse_args()
if args.version in ["v2Pro", "v2ProPlus"]:
is_half = not args.no_half
print(f"Using half precision: {is_half}")
export_prov2(
gpt_path=args.gpt_model,
vits_path=args.sovits_model,
version=args.version,
ref_audio_path=args.ref_audio,
ref_text=args.ref_text,
output_path=args.output_path,
export_bert_and_ssl=args.export_common_model,
device=args.device,
is_half=is_half,
)
else:
export( export(
gpt_path=args.gpt_model, gpt_path=args.gpt_model,
vits_path=args.sovits_model, vits_path=args.sovits_model,
@ -852,10 +1095,7 @@ def main():
) )
import inference_webui
if __name__ == "__main__": if __name__ == "__main__":
inference_webui.is_half = False with torch.no_grad():
inference_webui.dtype = torch.float32
main() main()
# test() # test()
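With the new `--version` and `--no-half` flags, main() routes v2Pro/v2ProPlus checkpoints through export_prov2 and everything else through the original export path. A hypothetical invocation (the script location and all weight/audio paths below are placeholders, not files shipped with the repository):

# Hypothetical CLI usage of the updated exporter (placeholder paths).
import subprocess

subprocess.run(
    [
        "python", "GPT_SoVITS/export_torch_script.py",         # assumed script location
        "--gpt_model", "GPT_weights_v2Pro/demo-e15.ckpt",       # placeholder
        "--sovits_model", "SoVITS_weights_v2Pro/demo_e8.pth",   # placeholder
        "--ref_audio", "ref.wav",                               # placeholder
        "--ref_text", "参考音频的文字内容",                       # placeholder
        "--output_path", "exported/v2pro",
        "--version", "v2ProPlus",
        "--device", "cuda",
        "--export_common_model",
    ],
    check=True,
)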

View File

@ -243,6 +243,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
self.sampling_rate: int = hps.data.sampling_rate self.sampling_rate: int = hps.data.sampling_rate
self.hop_length: int = hps.data.hop_length self.hop_length: int = hps.data.hop_length
self.win_length: int = hps.data.win_length self.win_length: int = hps.data.win_length
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
def forward( def forward(
self, self,
@ -255,6 +256,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
top_k, top_k,
): ):
refer = spectrogram_torch( refer = spectrogram_torch(
self.hann_window,
ref_audio_32k, ref_audio_32k,
self.filter_length, self.filter_length,
self.sampling_rate, self.sampling_rate,
@ -321,6 +323,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
self.sampling_rate: int = hps.data.sampling_rate self.sampling_rate: int = hps.data.sampling_rate
self.hop_length: int = hps.data.hop_length self.hop_length: int = hps.data.hop_length
self.win_length: int = hps.data.win_length self.win_length: int = hps.data.win_length
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
def forward( def forward(
self, self,
@ -333,6 +336,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
top_k, top_k,
): ):
refer = spectrogram_torch( refer = spectrogram_torch(
self.hann_window,
ref_audio_32k, ref_audio_32k,
self.filter_length, self.filter_length,
self.sampling_rate, self.sampling_rate,
@ -402,7 +406,7 @@ class GPTSoVITSV3(torch.nn.Module):
chunk_len = 934 - fea_ref.shape[2] chunk_len = 934 - fea_ref.shape[2]
wav_gen_list = [] wav_gen_list = []
idx = 0 idx = 0
fea_todo = fea_todo[:,:,:-5] fea_todo = fea_todo[:, :, :-5]
wav_gen_length = fea_todo.shape[2] * 256 wav_gen_length = fea_todo.shape[2] * 256
while 1: while 1:
# current_time = datetime.now() # current_time = datetime.now()
@ -435,6 +439,7 @@ class GPTSoVITSV3(torch.nn.Module):
wav_gen = torch.cat(wav_gen_list, 2) wav_gen = torch.cat(wav_gen_list, 2)
return wav_gen[0][0][:wav_gen_length] return wav_gen[0][0][:wav_gen_length]
class GPTSoVITSV4(torch.nn.Module): class GPTSoVITSV4(torch.nn.Module):
def __init__(self, gpt_sovits_half, cfm, hifigan): def __init__(self, gpt_sovits_half, cfm, hifigan):
super().__init__() super().__init__()
@ -461,7 +466,7 @@ class GPTSoVITSV4(torch.nn.Module):
chunk_len = 1000 - fea_ref.shape[2] chunk_len = 1000 - fea_ref.shape[2]
wav_gen_list = [] wav_gen_list = []
idx = 0 idx = 0
fea_todo = fea_todo[:,:,:-10] fea_todo = fea_todo[:, :, :-10]
wav_gen_length = fea_todo.shape[2] * 480 wav_gen_length = fea_todo.shape[2] * 480
while 1: while 1:
# current_time = datetime.now() # current_time = datetime.now()
@ -577,6 +582,7 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
v3v4set = {"v3", "v4"} v3v4set = {"v3", "v4"}
def get_sovits_weights(sovits_path): def get_sovits_weights(sovits_path):
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth" path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
is_exist_s2gv3 = os.path.exists(path_sovits_v3) is_exist_s2gv3 = os.path.exists(path_sovits_v3)
@ -699,7 +705,7 @@ def export_cfm(
return export_cfm return export_cfm
def export_1(ref_wav_path,ref_wav_text,version="v3"): def export_1(ref_wav_path, ref_wav_text, version="v3"):
if version == "v3": if version == "v3":
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth") sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
init_bigvgan() init_bigvgan()
@ -707,7 +713,6 @@ def export_1(ref_wav_path,ref_wav_text,version="v3"):
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth") sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
init_hifigan() init_hifigan()
dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt") dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
raw_t2s = get_raw_t2s_model(dict_s1).to(device) raw_t2s = get_raw_t2s_model(dict_s1).to(device)
print("#### get_raw_t2s_model ####") print("#### get_raw_t2s_model ####")
@ -751,9 +756,7 @@ def export_1(ref_wav_path,ref_wav_text,version="v3"):
# phones1, bert1, norm_text1 = get_phones_and_bert( # phones1, bert1, norm_text1 = get_phones_and_bert(
# "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3" # "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
# ) # )
phones1, bert1, norm_text1 = get_phones_and_bert( phones1, bert1, norm_text1 = get_phones_and_bert(ref_wav_text, "auto", "v3")
ref_wav_text, "auto", "v3"
)
phones2, bert2, norm_text2 = get_phones_and_bert( phones2, bert2, norm_text2 = get_phones_and_bert(
"这是一个简单的示例真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.", "这是一个简单的示例真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
"auto", "auto",
@ -1149,7 +1152,7 @@ def export_2(version="v3"):
raw_t2s = raw_t2s.half().to(device) raw_t2s = raw_t2s.half().to(device)
t2s_m = T2SModel(raw_t2s).half().to(device) t2s_m = T2SModel(raw_t2s).half().to(device)
t2s_m.eval() t2s_m.eval()
t2s_m = torch.jit.script(t2s_m) t2s_m = torch.jit.script(t2s_m).to(device)
t2s_m.eval() t2s_m.eval()
# t2s_m.top_k = 15 # t2s_m.top_k = 15
logger.info("t2s_m ok") logger.info("t2s_m ok")
@ -1201,7 +1204,6 @@ def export_2(version="v3"):
gpt_sovits_v3v4 = gpt_sovits_v3 if version == "v3" else gpt_sovits_v4 gpt_sovits_v3v4 = gpt_sovits_v3 if version == "v3" else gpt_sovits_v4
sr = 24000 if version == "v3" else 48000 sr = 24000 if version == "v3" else 48000
time.sleep(5) time.sleep(5)
# print("thread:", torch.get_num_threads()) # print("thread:", torch.get_num_threads())
# print("thread:", torch.get_num_interop_threads()) # print("thread:", torch.get_num_interop_threads())
@ -1212,14 +1214,14 @@ def export_2(version="v3"):
"汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....", "汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
gpt_sovits_v3v4, gpt_sovits_v3v4,
"out.wav", "out.wav",
sr sr,
) )
test_export( test_export(
"你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!", "你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
gpt_sovits_v3v4, gpt_sovits_v3v4,
"out2.wav", "out2.wav",
sr sr,
) )
# test_export( # test_export(
@ -1251,6 +1253,6 @@ def test_export_gpt_sovits_v3():
with torch.no_grad(): with torch.no_grad():
export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4") # export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
# export_2("v4") export_2("v4")
# test_export_gpt_sovits_v3() # test_export_gpt_sovits_v3()
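A recurring change in both export scripts is that `torch.hann_window` is no longer created inside `spectrogram_torch`; each wrapper builds the window once in `__init__` with an explicit device and dtype and passes it in, so the traced or scripted graph never bakes in a mismatched window. The pattern in isolation, on a toy frontend rather than the exporter's classes:

# Precompute the analysis window once with a fixed device/dtype and reuse it.
import torch

class SpecFrontend(torch.nn.Module):
    def __init__(self, win_length=1280, device="cpu", dtype=torch.float32):
        super().__init__()
        self.win_length = win_length
        # stored as a buffer so it moves with the module instead of being
        # re-created (possibly with the wrong dtype) inside forward()
        self.register_buffer("hann_window", torch.hann_window(win_length, device=device, dtype=dtype))

    def forward(self, y):
        return torch.stft(
            y,
            n_fft=self.win_length,
            hop_length=self.win_length // 4,
            win_length=self.win_length,
            window=self.hann_window,
            center=True,
            return_complex=True,
        ).abs()

print(SpecFrontend()(torch.randn(1, 16000)).shape)  # (1, 641, frames)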

View File

@ -6,7 +6,20 @@
全部按英文识别 全部按英文识别
全部按日文识别 全部按日文识别
""" """
import psutil
import os
def set_high_priority():
"""把当前 Python 进程设为 HIGH_PRIORITY_CLASS"""
if os.name != "nt":
return # 仅 Windows 有效
p = psutil.Process(os.getpid())
try:
p.nice(psutil.HIGH_PRIORITY_CLASS)
print("已将进程优先级设为 High")
except psutil.AccessDenied:
print("权限不足,无法修改优先级(请用管理员运行)")
set_high_priority()
import json import json
import logging import logging
import os import os
@ -214,7 +227,7 @@ v3v4set = {"v3", "v4"}
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None): def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
if "" in sovits_path: if "" in sovits_path or "!" in sovits_path:
sovits_path = name2sovits_path[sovits_path] sovits_path = name2sovits_path[sovits_path]
global vq_model, hps, version, model_version, dict_language, if_lora_v3 global vq_model, hps, version, model_version, dict_language, if_lora_v3
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
@ -361,7 +374,7 @@ except:
def change_gpt_weights(gpt_path): def change_gpt_weights(gpt_path):
if "" in gpt_path: if "" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path] gpt_path = name2gpt_path[gpt_path]
global hz, max_sec, t2s_model, config global hz, max_sec, t2s_model, config
hz = 50 hz = 50
@ -586,32 +599,31 @@ from text import chinese
def get_phones_and_bert(text, language, version, final=False): def get_phones_and_bert(text, language, version, final=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: text = re.sub(r' {2,}', ' ', text)
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
if language == "all_zh":
if re.search(r"[A-Za-z]", formattext):
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext, "zh", version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = get_bert_feature(norm_text, word2ph).to(device)
elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext, "yue", version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist = [] textlist = []
langlist = [] langlist = []
if language == "auto": if language == "all_zh":
for tmp in LangSegmenter.getTexts(text,"zh"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_yue":
for tmp in LangSegmenter.getTexts(text,"zh"):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ja":
for tmp in LangSegmenter.getTexts(text,"ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ko":
for tmp in LangSegmenter.getTexts(text,"ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
langlist.append("en")
textlist.append(text)
elif language == "auto":
for tmp in LangSegmenter.getTexts(text): for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"]) langlist.append(tmp["lang"])
textlist.append(tmp["text"]) textlist.append(tmp["text"])
@ -623,6 +635,10 @@ def get_phones_and_bert(text, language, version, final=False):
textlist.append(tmp["text"]) textlist.append(tmp["text"])
else: else:
for tmp in LangSegmenter.getTexts(text): for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en": if tmp["lang"] == "en":
langlist.append(tmp["lang"]) langlist.append(tmp["lang"])
else: else:
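The new guard at the top of the language loop merges adjacent LangSegmenter segments whenever both are English or both are non-English, so mixed-language text is not shredded into many one-word chunks before phoneme extraction. The merge rule in isolation, on hypothetical segmenter output:

# Merge rule sketch (segmenter output is hypothetical).
segments = [
    {"lang": "zh", "text": "你好,"},
    {"lang": "ja", "text": "こんにちは、"},   # non-en after non-en -> merged
    {"lang": "en", "text": "hello "},
    {"lang": "en", "text": "world."},         # en after en -> merged
]

langlist, textlist = [], []
for tmp in segments:
    if langlist and (
        (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en")
    ):
        textlist[-1] += tmp["text"]
        continue
    langlist.append(tmp["lang"])
    textlist.append(tmp["text"])

print(langlist)   # ['zh', 'en']
print(textlist)   # ['你好,こんにちは、', 'hello world.']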

View File

@ -6,7 +6,20 @@
全部按英文识别 全部按英文识别
全部按日文识别 全部按日文识别
""" """
import psutil
import os
def set_high_priority():
"""把当前 Python 进程设为 HIGH_PRIORITY_CLASS"""
if os.name != "nt":
return # 仅 Windows 有效
p = psutil.Process(os.getpid())
try:
p.nice(psutil.HIGH_PRIORITY_CLASS)
print("已将进程优先级设为 High")
except psutil.AccessDenied:
print("权限不足,无法修改优先级(请用管理员运行)")
set_high_priority()
import json import json
import logging import logging
import os import os
@ -112,13 +125,14 @@ is_exist_s2gv4 = os.path.exists(path_sovits_v4)
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml") tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts_config.device = device tts_config.device = device
tts_config.is_half = is_half tts_config.is_half = is_half
tts_config.version = version # tts_config.version = version
tts_config.update_version(version)
if gpt_path is not None: if gpt_path is not None:
if "" in gpt_path: if "" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path] gpt_path = name2gpt_path[gpt_path]
tts_config.t2s_weights_path = gpt_path tts_config.t2s_weights_path = gpt_path
if sovits_path is not None: if sovits_path is not None:
if "" in sovits_path: if "" in sovits_path or "!" in sovits_path:
sovits_path = name2sovits_path[sovits_path] sovits_path = name2sovits_path[sovits_path]
tts_config.vits_weights_path = sovits_path tts_config.vits_weights_path = sovits_path
if cnhubert_base_path is not None: if cnhubert_base_path is not None:
@ -217,7 +231,7 @@ v3v4set = {"v3", "v4"}
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None): def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
if "" in sovits_path: if "" in sovits_path or "!" in sovits_path:
sovits_path = name2sovits_path[sovits_path] sovits_path = name2sovits_path[sovits_path]
global version, model_version, dict_language, if_lora_v3 global version, model_version, dict_language, if_lora_v3
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
@ -283,6 +297,12 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
f.write(json.dumps(data)) f.write(json.dumps(data))
def change_gpt_weights(gpt_path):
if "" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path]
tts_pipeline.init_t2s_weights(gpt_path)
with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app: with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app:
gr.HTML( gr.HTML(
top_html.format( top_html.format(
@ -457,7 +477,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
inference_button, inference_button,
], ],
) # ) #
GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], []) GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
with gr.Group(): with gr.Group():
gr.Markdown( gr.Markdown(

View File

@ -21,7 +21,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
3) computes spectrograms from audio files. 3) computes spectrograms from audio files.
""" """
def __init__(self, hparams, version=None,val=False): def __init__(self, hparams, version=None, val=False):
exp_dir = hparams.exp_dir exp_dir = hparams.exp_dir
self.path2 = "%s/2-name2text.txt" % exp_dir self.path2 = "%s/2-name2text.txt" % exp_dir
self.path4 = "%s/4-cnhubert" % exp_dir self.path4 = "%s/4-cnhubert" % exp_dir
@ -29,7 +29,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
assert os.path.exists(self.path2) assert os.path.exists(self.path2)
assert os.path.exists(self.path4) assert os.path.exists(self.path4)
assert os.path.exists(self.path5) assert os.path.exists(self.path5)
self.is_v2Pro=version in {"v2Pro","v2ProPlus"} self.is_v2Pro = version in {"v2Pro", "v2ProPlus"}
if self.is_v2Pro: if self.is_v2Pro:
self.path7 = "%s/7-sv_cn" % exp_dir self.path7 = "%s/7-sv_cn" % exp_dir
assert os.path.exists(self.path7) assert os.path.exists(self.path7)
@ -118,7 +118,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee) ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False ssl.requires_grad = False
if self.is_v2Pro: if self.is_v2Pro:
sv_emb=torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu") sv_emb = torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu")
except: except:
traceback.print_exc() traceback.print_exc()
spec = torch.zeros(1025, 100) spec = torch.zeros(1025, 100)
@ -126,10 +126,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
ssl = torch.zeros(1, 768, 100) ssl = torch.zeros(1, 768, 100)
text = text[-1:] text = text[-1:]
if self.is_v2Pro: if self.is_v2Pro:
sv_emb=torch.zeros(1,20480) sv_emb = torch.zeros(1, 20480)
print("load audio or ssl error!!!!!!", audiopath) print("load audio or ssl error!!!!!!", audiopath)
if self.is_v2Pro: if self.is_v2Pro:
return (ssl, spec, wav, text,sv_emb) return (ssl, spec, wav, text, sv_emb)
else: else:
return (ssl, spec, wav, text) return (ssl, spec, wav, text)
@ -192,9 +192,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
class TextAudioSpeakerCollate: class TextAudioSpeakerCollate:
"""Zero-pads model inputs and targets""" """Zero-pads model inputs and targets"""
def __init__(self, return_ids=False,version=None): def __init__(self, return_ids=False, version=None):
self.return_ids = return_ids self.return_ids = return_ids
self.is_v2Pro=version in {"v2Pro","v2ProPlus"} self.is_v2Pro = version in {"v2Pro", "v2ProPlus"}
def __call__(self, batch): def __call__(self, batch):
"""Collate's training batch from normalized text, audio and speaker identities """Collate's training batch from normalized text, audio and speaker identities
@ -228,7 +228,7 @@ class TextAudioSpeakerCollate:
text_padded.zero_() text_padded.zero_()
if self.is_v2Pro: if self.is_v2Pro:
sv_embs=torch.FloatTensor(len(batch),20480) sv_embs = torch.FloatTensor(len(batch), 20480)
for i in range(len(ids_sorted_decreasing)): for i in range(len(ids_sorted_decreasing)):
row = batch[ids_sorted_decreasing[i]] row = batch[ids_sorted_decreasing[i]]
@ -250,11 +250,30 @@ class TextAudioSpeakerCollate:
text_lengths[i] = text.size(0) text_lengths[i] = text.size(0)
if self.is_v2Pro: if self.is_v2Pro:
sv_embs[i]=row[4] sv_embs[i] = row[4]
if self.is_v2Pro: if self.is_v2Pro:
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths,sv_embs return (
ssl_padded,
ssl_lengths,
spec_padded,
spec_lengths,
wav_padded,
wav_lengths,
text_padded,
text_lengths,
sv_embs,
)
else: else:
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths return (
ssl_padded,
ssl_lengths,
spec_padded,
spec_lengths,
wav_padded,
wav_lengths,
text_padded,
text_lengths,
)
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset): class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
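For v2Pro/v2ProPlus the dataset now returns a fifth item (the 20480-dim speaker-verification embedding from 7-sv_cn) and the collate function stacks it into an extra `sv_embs` tensor, so the dataset and collate must be constructed with the same `version`. A hedged wiring sketch, assuming the classes are imported from module.data_utils and that `hparams` comes from the usual s2 training config:

# Dataset and collate must share the same version for v2Pro batches.
from torch.utils.data import DataLoader
from module.data_utils import TextAudioSpeakerLoader, TextAudioSpeakerCollate  # assumed import path

version = "v2Pro"
dataset = TextAudioSpeakerLoader(hparams, version=version)  # hparams: from the training config (assumed)
collate = TextAudioSpeakerCollate(version=version)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate)

for batch in loader:
    # v2Pro/v2ProPlus batches carry 9 items; the last is sv_embs of shape (B, 20480)
    ssl, ssl_len, spec, spec_len, wav, wav_len, text, text_len, sv_embs = batch
    break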

View File

@ -586,12 +586,17 @@ class DiscriminatorS(torch.nn.Module):
return x, fmap return x, fmap
v2pro_set={"v2Pro","v2ProPlus"}
v2pro_set = {"v2Pro", "v2ProPlus"}
class MultiPeriodDiscriminator(torch.nn.Module): class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False,version=None): def __init__(self, use_spectral_norm=False, version=None):
super(MultiPeriodDiscriminator, self).__init__() super(MultiPeriodDiscriminator, self).__init__()
if version in v2pro_set:periods = [2, 3, 5, 7, 11,17,23] if version in v2pro_set:
else:periods = [2, 3, 5, 7, 11] periods = [2, 3, 5, 7, 11, 17, 23]
else:
periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
@ -787,6 +792,7 @@ class CodePredictor(nn.Module):
return pred_codes.transpose(0, 1) return pred_codes.transpose(0, 1)
class SynthesizerTrn(nn.Module): class SynthesizerTrn(nn.Module):
""" """
Synthesizer for Training Synthesizer for Training
@ -886,13 +892,13 @@ class SynthesizerTrn(nn.Module):
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024) self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
self.freeze_quantizer = freeze_quantizer self.freeze_quantizer = freeze_quantizer
self.is_v2pro=self.version in v2pro_set self.is_v2pro = self.version in v2pro_set
if self.is_v2pro: if self.is_v2pro:
self.sv_emb = nn.Linear(20480, gin_channels) self.sv_emb = nn.Linear(20480, gin_channels)
self.ge_to512 = nn.Linear(gin_channels, 512) self.ge_to512 = nn.Linear(gin_channels, 512)
self.prelu = nn.PReLU(num_parameters=gin_channels) self.prelu = nn.PReLU(num_parameters=gin_channels)
def forward(self, ssl, y, y_lengths, text, text_lengths,sv_emb=None): def forward(self, ssl, y, y_lengths, text, text_lengths, sv_emb=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
if self.version == "v1": if self.version == "v1":
ge = self.ref_enc(y * y_mask, y_mask) ge = self.ref_enc(y * y_mask, y_mask)
@ -952,7 +958,7 @@ class SynthesizerTrn(nn.Module):
return o, y_mask, (z, z_p, m_p, logs_p) return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad() @torch.no_grad()
def decode(self, codes, text, refer,noise_scale=0.5, speed=1, sv_emb=None): def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
def get_ge(refer, sv_emb): def get_ge(refer, sv_emb):
ge = None ge = None
if refer is not None: if refer is not None:
@ -970,8 +976,8 @@ class SynthesizerTrn(nn.Module):
if type(refer) == list: if type(refer) == list:
ges = [] ges = []
for idx,_refer in enumerate(refer): for idx, _refer in enumerate(refer):
ge = get_ge(_refer, sv_emb[idx]if self.is_v2pro else None) ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
ges.append(ge) ges.append(ge)
ge = torch.stack(ges, 0).mean(0) ge = torch.stack(ges, 0).mean(0)
else: else:
@ -983,7 +989,14 @@ class SynthesizerTrn(nn.Module):
quantized = self.quantizer.decode(codes) quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz": if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2,1)).transpose(2,1)if self.is_v2pro else ge, speed) x, m_p, logs_p, y_mask = self.enc_p(
quantized,
y_lengths,
text,
text_lengths,
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
speed,
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True) z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -996,6 +1009,7 @@ class SynthesizerTrn(nn.Module):
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl) quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0, 1) return codes.transpose(0, 1)
class CFM(torch.nn.Module): class CFM(torch.nn.Module):
def __init__(self, in_channels, dit): def __init__(self, in_channels, dit):
super().__init__() super().__init__()
@ -1029,7 +1043,18 @@ class CFM(torch.nn.Module):
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
# v_pred = model(x, t_tensor, d_tensor, **extra_args) # v_pred = model(x, t_tensor, d_tensor, **extra_args)
v_pred, text_emb, dt = self.estimator( v_pred, text_emb, dt = self.estimator(
x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False, infer=True, text_cache=text_cache, dt_cache=dt_cache x,
prompt_x,
x_lens,
t_tensor,
d_tensor,
mu,
use_grad_ckpt=False,
drop_audio_cond=False,
drop_text=False,
infer=True,
text_cache=text_cache,
dt_cache=dt_cache,
) )
v_pred = v_pred.transpose(2, 1) v_pred = v_pred.transpose(2, 1)
if self.use_conditioner_cache: if self.use_conditioner_cache:
@ -1048,7 +1073,7 @@ class CFM(torch.nn.Module):
drop_text=True, drop_text=True,
infer=True, infer=True,
text_cache=text_cfg_cache, text_cache=text_cfg_cache,
dt_cache=dt_cache dt_cache=dt_cache,
) )
neg = neg.transpose(2, 1) neg = neg.transpose(2, 1)
if self.use_conditioner_cache: if self.use_conditioner_cache:

View File

@ -763,6 +763,9 @@ class CodePredictor(nn.Module):
return pred_codes.transpose(0, 1) return pred_codes.transpose(0, 1)
v2pro_set = {"v2Pro", "v2ProPlus"}
class SynthesizerTrn(nn.Module): class SynthesizerTrn(nn.Module):
""" """
Synthesizer for Training Synthesizer for Training
@ -867,19 +870,32 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.text_embedding.requires_grad_(False) # self.enc_p.text_embedding.requires_grad_(False)
# self.enc_p.encoder_text.requires_grad_(False) # self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False) # self.enc_p.mrte.requires_grad_(False)
self.is_v2pro = self.version in v2pro_set
if self.is_v2pro:
self.sv_emb = nn.Linear(20480, gin_channels)
self.ge_to512 = nn.Linear(gin_channels, 512)
self.prelu = nn.PReLU(num_parameters=gin_channels)
def forward(self, codes, text, refer, noise_scale=0.5, speed=1): def forward(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
refer_mask = torch.ones_like(refer[:1, :1, :]) refer_mask = torch.ones_like(refer[:1, :1, :])
if self.version == "v1": if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask) ge = self.ref_enc(refer * refer_mask, refer_mask)
else: else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb)
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
quantized = self.quantizer.decode(codes) quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz": if self.semantic_frame_rate == "25hz":
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0) dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1) quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
if self.is_v2pro:
ge_ = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge_, speed)
else:
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed) x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

View File

@ -1,4 +1,5 @@
import math import math
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
@ -718,8 +719,10 @@ class MelStyleEncoder(nn.Module):
else: else:
len_ = (~mask).sum(dim=1).unsqueeze(1) len_ = (~mask).sum(dim=1).unsqueeze(1)
x = x.masked_fill(mask.unsqueeze(-1), 0) x = x.masked_fill(mask.unsqueeze(-1), 0)
x = x.sum(dim=1) dtype = x.dtype
out = torch.div(x, len_) x = x.float()
x = torch.div(x, len_.unsqueeze(1))
out = x.sum(dim=1).to(dtype)
return out return out
def forward(self, x, mask=None): def forward(self, x, mask=None):
@ -743,7 +746,6 @@ class MelStyleEncoder(nn.Module):
x = self.fc(x) x = self.fc(x)
# temoral average pooling # temoral average pooling
w = self.temporal_avg_pool(x, mask=mask) w = self.temporal_avg_pool(x, mask=mask)
return w.unsqueeze(-1) return w.unsqueeze(-1)

View File

@ -10,7 +10,6 @@ i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts") all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ: if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"] os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
from feature_extractor import cnhubert
opt_dir = os.environ.get("opt_dir") opt_dir = os.environ.get("opt_dir")
sv_path = os.environ.get("sv_path") sv_path = os.environ.get("sv_path")
@ -19,19 +18,18 @@ import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
import traceback import traceback
import numpy as np
from scipy.io import wavfile
import torchaudio import torchaudio
now_dir = os.getcwd() now_dir = os.getcwd()
sys.path.append(now_dir) sys.path.append(now_dir)
sys.path.append(f"{now_dir}/GPT_SoVITS/eres2net") sys.path.append(f"{now_dir}/GPT_SoVITS/eres2net")
from tools.my_utils import load_audio, clean_path from tools.my_utils import clean_path
from time import time as ttime from time import time as ttime
import shutil import shutil
from ERes2NetV2 import ERes2NetV2 from ERes2NetV2 import ERes2NetV2
import kaldi as Kaldi import kaldi as Kaldi
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path) dir = os.path.dirname(path)
name = os.path.basename(path) name = os.path.basename(path)
@ -56,37 +54,45 @@ if torch.cuda.is_available():
else: else:
device = "cpu" device = "cpu"
class SV: class SV:
def __init__(self,device,is_half): def __init__(self, device, is_half):
pretrained_state = torch.load(sv_path, map_location='cpu') pretrained_state = torch.load(sv_path, map_location="cpu")
embedding_model = ERes2NetV2(baseWidth=24,scale=4,expansion=4) embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
embedding_model.load_state_dict(pretrained_state) embedding_model.load_state_dict(pretrained_state)
embedding_model.eval() embedding_model.eval()
self.embedding_model=embedding_model self.embedding_model = embedding_model
self.res=torchaudio.transforms.Resample(32000, 16000).to(device) self.res = torchaudio.transforms.Resample(32000, 16000).to(device)
if is_half == False: if is_half == False:
self.embedding_model=self.embedding_model.to(device) self.embedding_model = self.embedding_model.to(device)
else: else:
self.embedding_model=self.embedding_model.half().to(device) self.embedding_model = self.embedding_model.half().to(device)
self.is_half=is_half self.is_half = is_half
def compute_embedding3(self,wav):#(1,x)#-1~1 def compute_embedding3(self, wav): # (1,x)#-1~1
with torch.no_grad(): with torch.no_grad():
wav=self.res(wav) wav = self.res(wav)
if self.is_half==True:wav=wav.half() if self.is_half == True:
feat = torch.stack([Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]) wav = wav.half()
feat = torch.stack(
[Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]
)
sv_emb = self.embedding_model.forward3(feat) sv_emb = self.embedding_model.forward3(feat)
return sv_emb return sv_emb
sv=SV(device,is_half)
sv = SV(device, is_half)
def name2go(wav_name, wav_path): def name2go(wav_name, wav_path):
sv_cn_path = "%s/%s.pt" % (sv_cn_dir, wav_name) sv_cn_path = "%s/%s.pt" % (sv_cn_dir, wav_name)
if os.path.exists(sv_cn_path):return if os.path.exists(sv_cn_path):
wav_path="%s/%s" % (wav32dir, wav_name) return
wav32k,sr0 = torchaudio.load(wav_path) wav_path = "%s/%s" % (wav32dir, wav_name)
assert sr0==32000 wav32k, sr0 = torchaudio.load(wav_path)
assert sr0 == 32000
wav32k = wav32k.to(device) wav32k = wav32k.to(device)
emb=sv.compute_embedding3(wav32k).cpu() # torch.Size([1, 20480]) emb = sv.compute_embedding3(wav32k).cpu() # torch.Size([1, 20480])
my_save(emb, sv_cn_path) my_save(emb, sv_cn_path)

View File

@ -17,15 +17,16 @@ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
shutil.move(tmp_path, "%s/%s" % (dir, name)) shutil.move(tmp_path, "%s/%s" % (dir, name))
from io import BytesIO from io import BytesIO
model_version2byte={ model_version2byte = {
"v3":b"03", "v3": b"03",
"v4":b"04", "v4": b"04",
"v2Pro":b"05", "v2Pro": b"05",
"v2ProPlus":b"06", "v2ProPlus": b"06",
} }
def my_save2(fea, path, model_version): def my_save2(fea, path, model_version):
bio = BytesIO() bio = BytesIO()
torch.save(fea, bio) torch.save(fea, bio)
@ -50,7 +51,7 @@ def savee(ckpt, name, epoch, steps, hps, model_version=None, lora_rank=None):
if lora_rank: if lora_rank:
opt["lora_rank"] = lora_rank opt["lora_rank"] = lora_rank
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version) my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
elif (model_version!=None and "Pro"in model_version): elif model_version != None and "Pro" in model_version:
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version) my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
else: else:
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name)) my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
@ -58,6 +59,7 @@ def savee(ckpt, name, epoch, steps, hps, model_version=None, lora_rank=None):
except: except:
return traceback.format_exc() return traceback.format_exc()
""" """
00:v1 00:v1
01:v2 01:v2
@ -127,7 +129,7 @@ def get_sovits_version_from_path_fast(sovits_path):
def load_sovits_new(sovits_path): def load_sovits_new(sovits_path):
f = open(sovits_path, "rb") f = open(sovits_path, "rb")
meta = f.read(2) meta = f.read(2)
if meta != "PK": if meta != b"PK":
data = b"PK" + f.read() data = b"PK" + f.read()
bio = BytesIO() bio = BytesIO()
bio.write(data) bio.write(data)

View File

@ -36,7 +36,7 @@ from module.models import (
MultiPeriodDiscriminator, MultiPeriodDiscriminator,
SynthesizerTrn, SynthesizerTrn,
) )
from process_ckpt import savee,my_save2 from process_ckpt import savee
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False torch.backends.cudnn.deterministic = False
@ -87,11 +87,30 @@ def run(rank, n_gpus, hps):
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.set_device(rank) torch.cuda.set_device(rank)
train_dataset = TextAudioSpeakerLoader(hps.data,version=hps.model.version) train_dataset = TextAudioSpeakerLoader(hps.data, version=hps.model.version)
train_sampler = DistributedBucketSampler( train_sampler = DistributedBucketSampler(
train_dataset, train_dataset,
hps.train.batch_size, hps.train.batch_size,
[32,300,400,500,600,700,800,900,1000,1100,1200,1300,1400,1500,1600,1700,1800,1900,], [
32,
300,
400,
500,
600,
700,
800,
900,
1000,
1100,
1200,
1300,
1400,
1500,
1600,
1700,
1800,
1900,
],
num_replicas=n_gpus, num_replicas=n_gpus,
rank=rank, rank=rank,
shuffle=True, shuffle=True,
@ -130,9 +149,9 @@ def run(rank, n_gpus, hps):
) )
net_d = ( net_d = (
MultiPeriodDiscriminator(hps.model.use_spectral_norm,version=hps.model.version).cuda(rank) MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).cuda(rank)
if torch.cuda.is_available() if torch.cuda.is_available()
else MultiPeriodDiscriminator(hps.model.use_spectral_norm,version=hps.model.version).to(device) else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).to(device)
) )
for name, param in net_g.named_parameters(): for name, param in net_g.named_parameters():
if not param.requires_grad: if not param.requires_grad:
@ -235,7 +254,7 @@ def run(rank, n_gpus, hps):
print( print(
"loaded pretrained %s" % hps.train.pretrained_s2D, "loaded pretrained %s" % hps.train.pretrained_s2D,
net_d.module.load_state_dict( net_d.module.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"],strict=False torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"], strict=False
) )
if torch.cuda.is_available() if torch.cuda.is_available()
else net_d.load_state_dict( else net_d.load_state_dict(
@ -310,17 +329,44 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
net_g.train() net_g.train()
net_d.train() net_d.train()
for batch_idx, data in enumerate(tqdm(train_loader)): for batch_idx, data in enumerate(tqdm(train_loader)):
if hps.model.version in {"v2Pro","v2ProPlus"}: if hps.model.version in {"v2Pro", "v2ProPlus"}:
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths,sv_emb=data ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = data
else: else:
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths=data ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data
if torch.cuda.is_available(): if torch.cuda.is_available():
spec, spec_lengths = (spec.cuda(rank,non_blocking=True,),spec_lengths.cuda(rank,non_blocking=True,),) spec, spec_lengths = (
y, y_lengths = (y.cuda(rank,non_blocking=True,),y_lengths.cuda(rank,non_blocking=True,),) spec.cuda(
rank,
non_blocking=True,
),
spec_lengths.cuda(
rank,
non_blocking=True,
),
)
y, y_lengths = (
y.cuda(
rank,
non_blocking=True,
),
y_lengths.cuda(
rank,
non_blocking=True,
),
)
ssl = ssl.cuda(rank, non_blocking=True) ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True) # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = (text.cuda(rank,non_blocking=True,),text_lengths.cuda(rank,non_blocking=True,),) text, text_lengths = (
text.cuda(
rank,
non_blocking=True,
),
text_lengths.cuda(
rank,
non_blocking=True,
),
)
if hps.model.version in {"v2Pro", "v2ProPlus"}: if hps.model.version in {"v2Pro", "v2ProPlus"}:
sv_emb = sv_emb.cuda(rank, non_blocking=True) sv_emb = sv_emb.cuda(rank, non_blocking=True)
else: else:
@ -334,9 +380,19 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
sv_emb = sv_emb.to(device) sv_emb = sv_emb.to(device)
with autocast(enabled=hps.train.fp16_run): with autocast(enabled=hps.train.fp16_run):
if hps.model.version in {"v2Pro", "v2ProPlus"}: if hps.model.version in {"v2Pro", "v2ProPlus"}:
(y_hat,kl_ssl,ids_slice,x_mask,z_mask,(z, z_p, m_p, logs_p, m_q, logs_q),stats_ssl) = net_g(ssl, spec, spec_lengths, text, text_lengths,sv_emb) (y_hat, kl_ssl, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl) = net_g(
ssl, spec, spec_lengths, text, text_lengths, sv_emb
)
else: else:
(y_hat,kl_ssl,ids_slice,x_mask,z_mask,(z, z_p, m_p, logs_p, m_q, logs_q),stats_ssl,) = net_g(ssl, spec, spec_lengths, text, text_lengths) (
y_hat,
kl_ssl,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
stats_ssl,
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
mel = spec_to_mel_torch( mel = spec_to_mel_torch(
spec, spec,
@ -508,7 +564,14 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
% ( % (
hps.name, hps.name,
epoch, epoch,
savee(ckpt,hps.name + "_e%s_s%s" % (epoch, global_step),epoch,global_step,hps,model_version=None if hps.model.version not in {"v2Pro","v2ProPlus"}else hps.model.version), savee(
ckpt,
hps.name + "_e%s_s%s" % (epoch, global_step),
epoch,
global_step,
hps,
model_version=None if hps.model.version not in {"v2Pro", "v2ProPlus"} else hps.model.version,
),
) )
) )

View File

@ -1,24 +1,32 @@
import sys,os,torch import sys
import os
import torch
sys.path.append(f"{os.getcwd()}/GPT_SoVITS/eres2net") sys.path.append(f"{os.getcwd()}/GPT_SoVITS/eres2net")
sv_path = "GPT_SoVITS/pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt" sv_path = "GPT_SoVITS/pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
from ERes2NetV2 import ERes2NetV2 from ERes2NetV2 import ERes2NetV2
import kaldi as Kaldi import kaldi as Kaldi
class SV: class SV:
def __init__(self,device,is_half): def __init__(self, device, is_half):
pretrained_state = torch.load(sv_path, map_location='cpu', weights_only=False) pretrained_state = torch.load(sv_path, map_location="cpu", weights_only=False)
embedding_model = ERes2NetV2(baseWidth=24,scale=4,expansion=4) embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
embedding_model.load_state_dict(pretrained_state) embedding_model.load_state_dict(pretrained_state)
embedding_model.eval() embedding_model.eval()
self.embedding_model=embedding_model self.embedding_model = embedding_model
if is_half == False: if is_half == False:
self.embedding_model=self.embedding_model.to(device) self.embedding_model = self.embedding_model.to(device)
else: else:
self.embedding_model=self.embedding_model.half().to(device) self.embedding_model = self.embedding_model.half().to(device)
self.is_half=is_half self.is_half = is_half
def compute_embedding3(self,wav): def compute_embedding3(self, wav):
with torch.no_grad(): with torch.no_grad():
if self.is_half==True:wav=wav.half() if self.is_half == True:
feat = torch.stack([Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]) wav = wav.half()
feat = torch.stack(
[Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]
)
sv_emb = self.embedding_model.forward3(feat) sv_emb = self.embedding_model.forward3(feat)
return sv_emb return sv_emb

View File

@ -87,22 +87,37 @@ class LangSegmenter():
"en": "en", "en": "en",
} }
def getTexts(text,default_lang = ""):
def getTexts(text):
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP) lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
lang_splitter.merge_across_digit = False
substr = lang_splitter.split_by_lang(text=text) substr = lang_splitter.split_by_lang(text=text)
lang_list: list[dict] = [] lang_list: list[dict] = []
have_num = False
for _, item in enumerate(substr): for _, item in enumerate(substr):
dict_item = {'lang':item.lang,'text':item.text} dict_item = {'lang':item.lang,'text':item.text}
if dict_item['lang'] == 'digit':
if default_lang != "":
dict_item['lang'] = default_lang
else:
have_num = True
lang_list = merge_lang(lang_list,dict_item)
continue
# 处理短英文被识别为其他语言的问题 # 处理短英文被识别为其他语言的问题
if full_en(dict_item['text']): if full_en(dict_item['text']):
dict_item['lang'] = 'en' dict_item['lang'] = 'en'
lang_list = merge_lang(lang_list,dict_item) lang_list = merge_lang(lang_list,dict_item)
continue continue
if default_lang != "":
dict_item['lang'] = default_lang
lang_list = merge_lang(lang_list,dict_item)
continue
else:
# 处理非日语夹日文的问题(不包含CJK) # 处理非日语夹日文的问题(不包含CJK)
ja_list: list[dict] = [] ja_list: list[dict] = []
if dict_item['lang'] != 'ja': if dict_item['lang'] != 'ja':
@ -142,15 +157,46 @@ class LangSegmenter():
for _, temp_item in enumerate(temp_list): for _, temp_item in enumerate(temp_list):
# 未知语言检查是否为CJK # 未知语言检查是否为CJK
if temp_item['lang'] == 'x': if temp_item['lang'] == 'x':
cjk_text = full_cjk(dict_item['text']) cjk_text = full_cjk(temp_item['text'])
if cjk_text: if cjk_text:
dict_item = {'lang':'zh','text':cjk_text} lang_list = merge_lang(lang_list,{'lang':'zh','text':cjk_text})
lang_list = merge_lang(lang_list,dict_item)
else: else:
lang_list = merge_lang(lang_list,dict_item) lang_list = merge_lang(lang_list,temp_item)
else: else:
lang_list = merge_lang(lang_list,temp_item) lang_list = merge_lang(lang_list,temp_item)
# 有数字
if have_num:
temp_list = lang_list
lang_list = []
for i, temp_item in enumerate(temp_list):
if temp_item['lang'] == 'digit':
if default_lang:
temp_item['lang'] = default_lang
elif lang_list and i == len(temp_list) - 1:
temp_item['lang'] = lang_list[-1]['lang']
elif not lang_list and i < len(temp_list) - 1:
temp_item['lang'] = temp_list[1]['lang']
elif lang_list and i < len(temp_list) - 1:
if lang_list[-1]['lang'] == temp_list[i + 1]['lang']:
temp_item['lang'] = lang_list[-1]['lang']
elif lang_list[-1]['text'][-1] in [",",".","!","?","，","。","！","？"]:
temp_item['lang'] = temp_list[i + 1]['lang']
elif temp_list[i + 1]['text'][0] in [",",".","!","?","，","。","！","？"]:
temp_item['lang'] = lang_list[-1]['lang']
elif temp_item['text'][-1] in ["。","."]:
temp_item['lang'] = lang_list[-1]['lang']
elif len(lang_list[-1]['text']) >= len(temp_list[i + 1]['text']):
temp_item['lang'] = lang_list[-1]['lang']
else:
temp_item['lang'] = temp_list[i + 1]['lang']
else:
temp_item['lang'] = 'zh'
lang_list = merge_lang(lang_list,temp_item)
# 筛X
temp_list = lang_list temp_list = lang_list
lang_list = [] lang_list = []
for _, temp_item in enumerate(temp_list): for _, temp_item in enumerate(temp_list):
@ -173,3 +219,7 @@ if __name__ == "__main__":
text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。" text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。"
print(LangSegmenter.getTexts(text)) print(LangSegmenter.getTexts(text))
text = "当时ThinkPad T60刚刚发布一同推出的还有一款名为Advanced Dock的扩展坞配件。这款扩展坞通过连接T60底部的插槽扩展出包括PCIe在内的一大堆接口并且自带电源让T60可以安装桌面显卡来提升性能。"
print(LangSegmenter.getTexts(text,"zh"))
print(LangSegmenter.getTexts(text))

View File

@ -181,20 +181,6 @@ def text_normalize(text):
return dest_text return dest_text
# 不排除英文的文本格式化
def mix_text_normalize(text):
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
tx = TextNormalizer()
sentences = tx.normalize(text)
dest_text = ""
for sentence in sentences:
dest_text += replace_punctuation_with_en(sentence)
# 避免重复标点引起的参考泄露
dest_text = replace_consecutive_punctuation(dest_text)
return dest_text
if __name__ == "__main__": if __name__ == "__main__":
text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏" text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
text = "呣呣呣~就是…大人的鼹鼠党吧?" text = "呣呣呣~就是…大人的鼹鼠党吧?"

View File

@ -326,20 +326,6 @@ def text_normalize(text):
return dest_text return dest_text
# 不排除英文的文本格式化
def mix_text_normalize(text):
# https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
tx = TextNormalizer()
sentences = tx.normalize(text)
dest_text = ""
for sentence in sentences:
dest_text += replace_punctuation_with_en(sentence)
# 避免重复标点引起的参考泄露
dest_text = replace_consecutive_punctuation(dest_text)
return dest_text
if __name__ == "__main__": if __name__ == "__main__":
text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏" text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏"
text = "呣呣呣~就是…大人的鼹鼠党吧?" text = "呣呣呣~就是…大人的鼹鼠党吧?"

View File

@ -3,7 +3,6 @@
import json import json
import os import os
import traceback
import warnings import warnings
import zipfile import zipfile
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
@ -23,8 +22,9 @@ from .utils import load_config
onnxruntime.set_default_logger_severity(3) onnxruntime.set_default_logger_severity(3)
try: try:
onnxruntime.preload_dlls() onnxruntime.preload_dlls()
except:pass except:
#traceback.print_exc() pass
# traceback.print_exc()
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
model_version = "1.1" model_version = "1.1"
@ -93,13 +93,13 @@ class G2PWOnnxConverter:
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0 sess_options.intra_op_num_threads = 2 if torch.cuda.is_available() else 0
try: if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
self.session_g2pW = onnxruntime.InferenceSession( self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"), os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options, sess_options=sess_options,
providers=["CUDAExecutionProvider", "CPUExecutionProvider"], providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
) )
except: else:
self.session_g2pW = onnxruntime.InferenceSession( self.session_g2pW = onnxruntime.InferenceSession(
os.path.join(uncompress_path, "g2pW.onnx"), os.path.join(uncompress_path, "g2pW.onnx"),
sess_options=sess_options, sess_options=sess_options,

View File

@ -655,11 +655,7 @@ class ToneSandhi:
while i < len(seg): while i < len(seg):
word, pos = seg[i] word, pos = seg[i]
merged = False merged = False
if ( if i - 1 >= 0 and word == "一" and i + 1 < len(seg):
i - 1 >= 0
and word == "一"
and i + 1 < len(seg)
):
last = new_seg[-1] if new_seg else seg[i - 1] last = new_seg[-1] if new_seg else seg[i - 1]
if last[0] == seg[i + 1][0] and last[1] == "v" and seg[i + 1][1] == "v": if last[0] == seg[i + 1][0] and last[1] == "v" and seg[i + 1][1] == "v":
combined = last[0] + "一" + seg[i + 1][0] combined = last[0] + "一" + seg[i + 1][0]

View File

@ -256,6 +256,24 @@ def replace_to_range(match) -> str:
return result return result
RE_VERSION_NUM = re.compile(r"((\d+)(\.\d+)(\.\d+)?(\.\d+)+)")
def replace_vrsion_num(match) -> str:
"""
Args:
match (re.Match)
Returns:
str
"""
result = ""
for c in match.group(1):
if c == ".":
result += "点"
else:
result += num2str(c)
return result
def _get_value(value_string: str, use_zero: bool = True) -> List[str]: def _get_value(value_string: str, use_zero: bool = True) -> List[str]:
stripped = value_string.lstrip("0") stripped = value_string.lstrip("0")
if len(stripped) == 0: if len(stripped) == 0:
@ -308,7 +326,11 @@ def num2str(value_string: str) -> str:
result = verbalize_cardinal(integer) result = verbalize_cardinal(integer)
if decimal.endswith("0"):
decimal = decimal.rstrip("0") + "0"
else:
decimal = decimal.rstrip("0") decimal = decimal.rstrip("0")
if decimal: if decimal:
# '.22' is verbalized as '零点二二' # '.22' is verbalized as '零点二二'
# '3.20' is verbalized as '三点二 # '3.20' is verbalized as '三点二
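Taken together, the two hunks above change how numbers are verbalized: dotted version-like strings (at least two dots, per `RE_VERSION_NUM`) are read digit by digit with 点 for every dot, and a decimal ending in zero now keeps one trailing zero (so "3.20" would read as 三点二零 rather than 三点二). A standalone illustrative sketch of the digit-by-digit reading rule, not the repository's `num2str`/`verbalize_digit` helpers:

```python
# Illustrative only: mimic the intended reading of a dotted version string,
# using an explicit digit table instead of the repository's num2str helpers.
DIGITS = dict(zip("0123456789", "零一二三四五六七八九"))

def read_version(ver: str) -> str:
    return "".join("点" if ch == "." else DIGITS[ch] for ch in ver)

print(read_version("2.0.1"))   # 二点零点一
print(read_version("1.10.2"))  # 一点一零点二
```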

View File

@ -25,6 +25,7 @@ from .chronology import replace_time
from .constants import F2H_ASCII_LETTERS from .constants import F2H_ASCII_LETTERS
from .constants import F2H_DIGITS from .constants import F2H_DIGITS
from .constants import F2H_SPACE from .constants import F2H_SPACE
from .num import RE_VERSION_NUM
from .num import RE_DECIMAL_NUM from .num import RE_DECIMAL_NUM
from .num import RE_DEFAULT_NUM from .num import RE_DEFAULT_NUM
from .num import RE_FRAC from .num import RE_FRAC
@ -36,6 +37,7 @@ from .num import RE_RANGE
from .num import RE_TO_RANGE from .num import RE_TO_RANGE
from .num import RE_ASMD from .num import RE_ASMD
from .num import RE_POWER from .num import RE_POWER
from .num import replace_vrsion_num
from .num import replace_default_num from .num import replace_default_num
from .num import replace_frac from .num import replace_frac
from .num import replace_negative_num from .num import replace_negative_num
@ -158,6 +160,7 @@ class TextNormalizer:
sentence = RE_RANGE.sub(replace_range, sentence) sentence = RE_RANGE.sub(replace_range, sentence)
sentence = RE_INTEGER.sub(replace_negative_num, sentence) sentence = RE_INTEGER.sub(replace_negative_num, sentence)
sentence = RE_VERSION_NUM.sub(replace_vrsion_num, sentence)
sentence = RE_DECIMAL_NUM.sub(replace_number, sentence) sentence = RE_DECIMAL_NUM.sub(replace_number, sentence)
sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier, sentence) sentence = RE_POSITIVE_QUANTIFIERS.sub(replace_positive_quantifier, sentence)
sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence) sentence = RE_DEFAULT_NUM.sub(replace_default_num, sentence)

View File

@ -283,7 +283,7 @@ def get_hparams_from_file(config_path):
def check_git_hash(model_dir): def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__)) source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")): if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warn( logger.warning(
"{} is not a git repository, therefore hash value comparison will be ignored.".format( "{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir, source_dir,
) )
@ -296,7 +296,7 @@ def check_git_hash(model_dir):
if os.path.exists(path): if os.path.exists(path):
saved_hash = open(path).read() saved_hash = open(path).read()
if saved_hash != cur_hash: if saved_hash != cur_hash:
logger.warn( logger.warning(
"git hash values are different. {}(saved) != {}(current)".format( "git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8], saved_hash[:8],
cur_hash[:8], cur_hash[:8],

View File

@ -9,10 +9,17 @@ A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.<br><br>
<!-- img src="https://counter.seku.su/cmoe?name=gptsovits&theme=r34" /><br> --> <!-- img src="https://counter.seku.su/cmoe?name=gptsovits&theme=r34" /><br> -->
[![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/colab_webui.ipynb) [![Python](https://img.shields.io/badge/python-3.10--3.12-blue?style=for-the-badge&logo=python)](https://www.python.org)
[![License](https://img.shields.io/badge/LICENSE-MIT-green.svg?style=for-the-badge)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE) [![GitHub release](https://img.shields.io/github/v/release/RVC-Boss/gpt-sovits?style=for-the-badge&logo=github)](https://github.com/RVC-Boss/gpt-sovits/releases)
[![Huggingface](https://img.shields.io/badge/🤗%20-online%20demo-yellow.svg?style=for-the-badge)](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
[![Discord](https://img.shields.io/discord/1198701940511617164?color=%23738ADB&label=Discord&style=for-the-badge)](https://discord.gg/dnrgs5GHfG) [![Train In Colab](https://img.shields.io/badge/Colab-Training-F9AB00?style=for-the-badge&logo=googlecolab)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
[![Huggingface](https://img.shields.io/badge/免费在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/)
[![Image Size](https://img.shields.io/docker/image-size/xxxxrt666/gpt-sovits/latest?style=for-the-badge&logo=docker)](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
[![简体中文](https://img.shields.io/badge/简体中文-阅读文档-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
[![English](https://img.shields.io/badge/English-Read%20Docs-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://rentry.co/GPT-SoVITS-guide#/)
[![Change Log](https://img.shields.io/badge/Change%20Log-View%20Updates-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/docs/en/Changelog_EN.md)
[![License](https://img.shields.io/badge/LICENSE-MIT-green.svg?style=for-the-badge&logo=opensourceinitiative)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
**English** | [**中文简体**](./docs/cn/README.md) | [**日本語**](./docs/ja/README.md) | [**한국어**](./docs/ko/README.md) | [**Türkçe**](./docs/tr/README.md) **English** | [**中文简体**](./docs/cn/README.md) | [**日本語**](./docs/ja/README.md) | [**한국어**](./docs/ko/README.md) | [**Türkçe**](./docs/tr/README.md)
@ -36,6 +43,11 @@ Unseen speakers few-shot fine-tuning demo:
https://github.com/RVC-Boss/GPT-SoVITS/assets/129054828/05bee1fa-bdd8-4d85-9350-80c060ab47fb https://github.com/RVC-Boss/GPT-SoVITS/assets/129054828/05bee1fa-bdd8-4d85-9350-80c060ab47fb
**RTF (inference speed) of GPT-SoVITS v2 ProPlus**:
0.028 measured on an RTX 4060 Ti, 0.014 on an RTX 4090 (about 1400 words ≈ 4 minutes of audio, synthesized in 3.36 s), and 0.526 on an Apple M4 CPU. You can try our [huggingface demo](https://lj1995-gpt-sovits-proplus.hf.space/) (running on half an H200) to experience high-speed inference.
Please don't baselessly bash GPT-SoVITS for slow inference, thanks.
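For reference, RTF here is wall-clock synthesis time divided by the duration of the generated audio. A minimal sketch using only the RTX 4090 figures quoted above (nothing measured here):

```python
# Minimal sketch: relating the quoted RTX 4090 figures to the RTF definition.
audio_seconds = 4 * 60      # ~1400 words ≈ 4 minutes of synthesized audio (quoted above)
inference_seconds = 3.36    # quoted synthesis time on an RTX 4090
rtf = inference_seconds / audio_seconds
print(f"RTF ≈ {rtf:.3f}")   # ≈ 0.014, i.e. roughly 70x faster than real time
```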
**User guide: [简体中文](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e) | [English](https://rentry.co/GPT-SoVITS-guide#/)** **User guide: [简体中文](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e) | [English](https://rentry.co/GPT-SoVITS-guide#/)**
## Installation ## Installation
@ -60,6 +72,14 @@ If you are a Windows user (tested with win>=10), you can [download the integrate
**Users in China can [download the package here](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e/dkxgpiy9zb96hob4#KTvnO).** **Users in China can [download the package here](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e/dkxgpiy9zb96hob4#KTvnO).**
Install the program by running the following commands:
```pwsh
conda create -n GPTSoVits python=3.10
conda activate GPTSoVits
pwsh -F install.ps1 --Device <CU126|CU128|CPU> --Source <HF|HF-Mirror|ModelScope> [--DownloadUVR5]
```
### Linux ### Linux
```bash ```bash
@ -128,8 +148,9 @@ Due to rapid development in the codebase and a slower Docker image release cycle
- Check [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits) for the latest available image tags - Check [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits) for the latest available image tags
- Choose an appropriate image tag for your environment - Choose an appropriate image tag for your environment
- `Lite` means the Docker image does not include ASR models and UVR5 models. You can manually download the UVR5 models, while the program will automatically download the ASR models as needed - `Lite` means the Docker image **does not include** ASR models and UVR5 models. You can manually download the UVR5 models, while the program will automatically download the ASR models as needed
- The appropriate architecture image (amd64/arm64) will be automatically pulled during Docker Compose - The appropriate architecture image (amd64/arm64) will be automatically pulled during Docker Compose
- Docker Compose will mount **all files** in the current directory. Please switch to the project root directory and **pull the latest code** before using the Docker image
- Optionally, build the image locally using the provided Dockerfile for the most up-to-date changes - Optionally, build the image locally using the provided Dockerfile for the most up-to-date changes
#### Environment Variables #### Environment Variables
@ -333,7 +354,7 @@ Use v4 from v1/v2/v3 environment:
New Features: New Features:
1. Slightly higher VRAM usage than v2, surpassing v4's performance, with v2's hardware cost and speed. 1. Slightly higher VRAM usage than v2, surpassing v4's performance, with v2's hardware cost and speed.
[more details](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)) [more details](<https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)>)
2. v1/v2 and the v2Pro series share the same characteristics, while v3/v4 have similar features. For training sets with average audio quality, v1/v2/v2Pro can deliver decent results, but v3/v4 cannot. Additionally, the synthesized tone and timbre of v3/v4 lean more toward the reference audio than toward the overall training set.

319
api.py
View File

@ -163,7 +163,7 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np import numpy as np
from feature_extractor import cnhubert from feature_extractor import cnhubert
from io import BytesIO from io import BytesIO
from module.models import SynthesizerTrn, SynthesizerTrnV3 from module.models import Generator, SynthesizerTrn, SynthesizerTrnV3
from peft import LoraConfig, get_peft_model from peft import LoraConfig, get_peft_model
from AR.models.t2s_lightning_module import Text2SemanticLightningModule from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence from text import cleaned_text_to_sequence
@ -198,8 +198,44 @@ def is_full(*items): # 任意一项为空返回False
return True return True
def init_bigvgan(): bigvgan_model = hifigan_model = sv_cn_model = None
def clean_hifigan_model():
global hifigan_model
if hifigan_model:
hifigan_model = hifigan_model.cpu()
hifigan_model = None
try:
torch.cuda.empty_cache()
except:
pass
def clean_bigvgan_model():
global bigvgan_model global bigvgan_model
if bigvgan_model:
bigvgan_model = bigvgan_model.cpu()
bigvgan_model = None
try:
torch.cuda.empty_cache()
except:
pass
def clean_sv_cn_model():
global sv_cn_model
if sv_cn_model:
sv_cn_model.embedding_model = sv_cn_model.embedding_model.cpu()
sv_cn_model = None
try:
torch.cuda.empty_cache()
except:
pass
def init_bigvgan():
global bigvgan_model, hifigan_model, sv_cn_model
from BigVGAN import bigvgan from BigVGAN import bigvgan
bigvgan_model = bigvgan.BigVGAN.from_pretrained( bigvgan_model = bigvgan.BigVGAN.from_pretrained(
@ -209,20 +245,57 @@ def init_bigvgan():
# remove weight norm in the model and set to eval mode # remove weight norm in the model and set to eval mode
bigvgan_model.remove_weight_norm() bigvgan_model.remove_weight_norm()
bigvgan_model = bigvgan_model.eval() bigvgan_model = bigvgan_model.eval()
if is_half == True: if is_half == True:
bigvgan_model = bigvgan_model.half().to(device) bigvgan_model = bigvgan_model.half().to(device)
else: else:
bigvgan_model = bigvgan_model.to(device) bigvgan_model = bigvgan_model.to(device)
def init_hifigan():
global hifigan_model, bigvgan_model, sv_cn_model
hifigan_model = Generator(
initial_channel=100,
resblock="1",
resblock_kernel_sizes=[3, 7, 11],
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
upsample_rates=[10, 6, 2, 2, 2],
upsample_initial_channel=512,
upsample_kernel_sizes=[20, 12, 4, 4, 4],
gin_channels=0,
is_bias=True,
)
hifigan_model.eval()
hifigan_model.remove_weight_norm()
state_dict_g = torch.load(
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,),
map_location="cpu",
weights_only=False,
)
print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
if is_half == True:
hifigan_model = hifigan_model.half().to(device)
else:
hifigan_model = hifigan_model.to(device)
from sv import SV
def init_sv_cn():
global hifigan_model, bigvgan_model, sv_cn_model
sv_cn_model = SV(device, is_half)
resample_transform_dict = {} resample_transform_dict = {}
def resample(audio_tensor, sr0): def resample(audio_tensor, sr0, sr1, device):
global resample_transform_dict global resample_transform_dict
if sr0 not in resample_transform_dict: key = "%s-%s-%s" % (sr0, sr1, str(device))
resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device) if key not in resample_transform_dict:
return resample_transform_dict[sr0](audio_tensor) resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
return resample_transform_dict[key](audio_tensor)
from module.mel_processing import mel_spectrogram_torch from module.mel_processing import mel_spectrogram_torch
@ -252,6 +325,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
"center": False, "center": False,
}, },
) )
mel_fn_v4 = lambda x: mel_spectrogram_torch(
x,
**{
"n_fft": 1280,
"win_size": 1280,
"hop_size": 320,
"num_mels": 100,
"sampling_rate": 32000,
"fmin": 0,
"fmax": None,
"center": False,
},
)
sr_model = None sr_model = None
@ -293,12 +379,19 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
def get_sovits_weights(sovits_path): def get_sovits_weights(sovits_path):
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth" from config import pretrained_sovits_name
path_sovits_v3 = pretrained_sovits_name["v3"]
path_sovits_v4 = pretrained_sovits_name["v4"]
is_exist_s2gv3 = os.path.exists(path_sovits_v3) is_exist_s2gv3 = os.path.exists(path_sovits_v3)
is_exist_s2gv4 = os.path.exists(path_sovits_v4)
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path) version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
if if_lora_v3 == True and is_exist_s2gv3 == False: is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重") path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
if if_lora_v3 == True and is_exist == False:
logger.info("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
dict_s2 = load_sovits_new(sovits_path) dict_s2 = load_sovits_new(sovits_path)
hps = dict_s2["config"] hps = dict_s2["config"]
@ -311,11 +404,13 @@ def get_sovits_weights(sovits_path):
else: else:
hps.model.version = "v2" hps.model.version = "v2"
if model_version == "v3":
hps.model.version = "v3"
model_params_dict = vars(hps.model) model_params_dict = vars(hps.model)
if model_version != "v3": if model_version not in {"v3", "v4"}:
if "Pro" in model_version:
hps.model.version = model_version
if sv_cn_model == None:
init_sv_cn()
vq_model = SynthesizerTrn( vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1, hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length, hps.train.segment_size // hps.data.hop_length,
@ -323,13 +418,18 @@ def get_sovits_weights(sovits_path):
**model_params_dict, **model_params_dict,
) )
else: else:
hps.model.version = model_version
vq_model = SynthesizerTrnV3( vq_model = SynthesizerTrnV3(
hps.data.filter_length // 2 + 1, hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length, hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers, n_speakers=hps.data.n_speakers,
**model_params_dict, **model_params_dict,
) )
if model_version == "v3":
init_bigvgan() init_bigvgan()
if model_version == "v4":
init_hifigan()
model_version = hps.model.version model_version = hps.model.version
logger.info(f"模型版本: {model_version}") logger.info(f"模型版本: {model_version}")
if "pretrained" not in sovits_path: if "pretrained" not in sovits_path:
@ -345,7 +445,8 @@ def get_sovits_weights(sovits_path):
if if_lora_v3 == False: if if_lora_v3 == False:
vq_model.load_state_dict(dict_s2["weight"], strict=False) vq_model.load_state_dict(dict_s2["weight"], strict=False)
else: else:
vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False) path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False)
lora_rank = dict_s2["lora_rank"] lora_rank = dict_s2["lora_rank"]
lora_config = LoraConfig( lora_config = LoraConfig(
target_modules=["to_k", "to_q", "to_v", "to_out.0"], target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@ -442,32 +543,31 @@ from text import chinese
def get_phones_and_bert(text, language, version, final=False): def get_phones_and_bert(text, language, version, final=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}: text = re.sub(r' {2,}', ' ', text)
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
if language == "all_zh":
if re.search(r"[A-Za-z]", formattext):
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext, "zh", version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = get_bert_feature(norm_text, word2ph).to(device)
elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return get_phones_and_bert(formattext, "yue", version)
else:
phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float16 if is_half == True else torch.float32,
).to(device)
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist = [] textlist = []
langlist = [] langlist = []
if language == "auto": if language == "all_zh":
for tmp in LangSegmenter.getTexts(text,"zh"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_yue":
for tmp in LangSegmenter.getTexts(text,"zh"):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ja":
for tmp in LangSegmenter.getTexts(text,"ja"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "all_ko":
for tmp in LangSegmenter.getTexts(text,"ko"):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "en":
langlist.append("en")
textlist.append(text)
elif language == "auto":
for tmp in LangSegmenter.getTexts(text): for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"]) langlist.append(tmp["lang"])
textlist.append(tmp["text"]) textlist.append(tmp["text"])
@ -479,6 +579,10 @@ def get_phones_and_bert(text, language, version, final=False):
textlist.append(tmp["text"]) textlist.append(tmp["text"])
else: else:
for tmp in LangSegmenter.getTexts(text): for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en": if tmp["lang"] == "en":
langlist.append(tmp["lang"]) langlist.append(tmp["lang"])
else: else:
@ -533,23 +637,34 @@ class DictToAttrRecursive(dict):
raise AttributeError(f"Attribute {item} not found") raise AttributeError(f"Attribute {item} not found")
def get_spepc(hps, filename): def get_spepc(hps, filename, dtype, device, is_v2pro=False):
audio, _ = librosa.load(filename, sr=int(hps.data.sampling_rate)) sr1 = int(hps.data.sampling_rate)
audio = torch.FloatTensor(audio) audio, sr0 = torchaudio.load(filename)
if sr0 != sr1:
audio = audio.to(device)
if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0)
audio = resample(audio, sr0, sr1, device)
else:
audio = audio.to(device)
if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0)
maxx = audio.abs().max() maxx = audio.abs().max()
if maxx > 1: if maxx > 1:
audio /= min(2, maxx) audio /= min(2, maxx)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch( spec = spectrogram_torch(
audio_norm, audio,
hps.data.filter_length, hps.data.filter_length,
hps.data.sampling_rate, hps.data.sampling_rate,
hps.data.hop_length, hps.data.hop_length,
hps.data.win_length, hps.data.win_length,
center=False, center=False,
) )
return spec spec = spec.to(dtype)
if is_v2pro == True:
audio = resample(audio, sr1, 16000, device).to(dtype)
return spec, audio
def pack_audio(audio_bytes, data, rate): def pack_audio(audio_bytes, data, rate):
@ -736,6 +851,16 @@ def get_tts_wav(
t2s_model = infer_gpt.t2s_model t2s_model = infer_gpt.t2s_model
max_sec = infer_gpt.max_sec max_sec = infer_gpt.max_sec
if version == "v3":
if sample_steps not in [4, 8, 16, 32, 64, 128]:
sample_steps = 32
elif version == "v4":
if sample_steps not in [4, 8, 16, 32]:
sample_steps = 8
if if_sr and version != "v3":
if_sr = False
t0 = ttime() t0 = ttime()
prompt_text = prompt_text.strip("\n") prompt_text = prompt_text.strip("\n")
if prompt_text[-1] not in splits: if prompt_text[-1] not in splits:
@ -759,19 +884,29 @@ def get_tts_wav(
prompt_semantic = codes[0, 0] prompt_semantic = codes[0, 0]
prompt = prompt_semantic.unsqueeze(0).to(device) prompt = prompt_semantic.unsqueeze(0).to(device)
if version != "v3": is_v2pro = version in {"v2Pro", "v2ProPlus"}
if version not in {"v3", "v4"}:
refers = [] refers = []
if is_v2pro:
sv_emb = []
if sv_cn_model == None:
init_sv_cn()
if inp_refs: if inp_refs:
for path in inp_refs: for path in inp_refs:
try: try: #####这里加上提取sv的逻辑要么一堆sv一堆refer要么单个sv单个refer
refer = get_spepc(hps, path).to(dtype).to(device) refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro)
refers.append(refer) refers.append(refer)
if is_v2pro:
sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor))
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
if len(refers) == 0: if len(refers) == 0:
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)] refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro)
refers = [refers]
if is_v2pro:
sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
else: else:
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype) refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device)
t1 = ttime() t1 = ttime()
# os.environ['version'] = version # os.environ['version'] = version
@ -811,41 +946,56 @@ def get_tts_wav(
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
t3 = ttime() t3 = ttime()
if version != "v3": if version not in {"v3", "v4"}:
if is_v2pro:
audio = ( audio = (
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed) vq_model.decode(
pred_semantic,
torch.LongTensor(phones2).to(device).unsqueeze(0),
refers,
speed=speed,
sv_emb=sv_emb,
)
.detach() .detach()
.cpu() .cpu()
.numpy()[0, 0] .numpy()[0, 0]
) ###试试重建不带上prompt部分 )
else:
audio = (
vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
)
.detach()
.cpu()
.numpy()[0, 0]
)
else: else:
phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0) phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0) phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
# print(11111111, phoneme_ids0, phoneme_ids1)
fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer) fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
ref_audio, sr = torchaudio.load(ref_wav_path) ref_audio, sr = torchaudio.load(ref_wav_path)
ref_audio = ref_audio.to(device).float() ref_audio = ref_audio.to(device).float()
if ref_audio.shape[0] == 2: if ref_audio.shape[0] == 2:
ref_audio = ref_audio.mean(0).unsqueeze(0) ref_audio = ref_audio.mean(0).unsqueeze(0)
if sr != 24000:
ref_audio = resample(ref_audio, sr) tgt_sr = 24000 if version == "v3" else 32000
# print("ref_audio",ref_audio.abs().mean()) if sr != tgt_sr:
mel2 = mel_fn(ref_audio) ref_audio = resample(ref_audio, sr, tgt_sr, device)
mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
mel2 = norm_spec(mel2) mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2]) T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min] mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min] fea_ref = fea_ref[:, :, :T_min]
if T_min > 468: Tref = 468 if version == "v3" else 500
mel2 = mel2[:, :, -468:] Tchunk = 934 if version == "v3" else 1000
fea_ref = fea_ref[:, :, -468:] if T_min > Tref:
T_min = 468 mel2 = mel2[:, :, -Tref:]
chunk_len = 934 - T_min fea_ref = fea_ref[:, :, -Tref:]
# print("fea_ref",fea_ref,fea_ref.shape) T_min = Tref
# print("mel2",mel2) chunk_len = Tchunk - T_min
mel2 = mel2.to(dtype) mel2 = mel2.to(dtype)
fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed) fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
# print("fea_todo",fea_todo)
# print("ge",ge.abs().mean())
cfm_resss = [] cfm_resss = []
idx = 0 idx = 0
while 1: while 1:
@ -854,22 +1004,24 @@ def get_tts_wav(
break break
idx += chunk_len idx += chunk_len
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1) fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
# set_seed(123)
cfm_res = vq_model.cfm.inference( cfm_res = vq_model.cfm.inference(
fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0 fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
) )
cfm_res = cfm_res[:, :, mel2.shape[2] :] cfm_res = cfm_res[:, :, mel2.shape[2] :]
mel2 = cfm_res[:, :, -T_min:] mel2 = cfm_res[:, :, -T_min:]
# print("fea", fea)
# print("mel2in", mel2)
fea_ref = fea_todo_chunk[:, :, -T_min:] fea_ref = fea_todo_chunk[:, :, -T_min:]
cfm_resss.append(cfm_res) cfm_resss.append(cfm_res)
cmf_res = torch.cat(cfm_resss, 2) cfm_res = torch.cat(cfm_resss, 2)
cmf_res = denorm_spec(cmf_res) cfm_res = denorm_spec(cfm_res)
if version == "v3":
if bigvgan_model == None: if bigvgan_model == None:
init_bigvgan() init_bigvgan()
else: # v4
if hifigan_model == None:
init_hifigan()
vocoder_model = bigvgan_model if version == "v3" else hifigan_model
with torch.inference_mode(): with torch.inference_mode():
wav_gen = bigvgan_model(cmf_res) wav_gen = vocoder_model(cfm_res)
audio = wav_gen[0][0].cpu().detach().numpy() audio = wav_gen[0][0].cpu().detach().numpy()
max_audio = np.abs(audio).max() max_audio = np.abs(audio).max()
@ -880,7 +1032,13 @@ def get_tts_wav(
audio_opt = np.concatenate(audio_opt, 0) audio_opt = np.concatenate(audio_opt, 0)
t4 = ttime() t4 = ttime()
sr = hps.data.sampling_rate if version != "v3" else 24000 if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
sr = 32000
elif version == "v3":
sr = 24000
else:
sr = 48000 # v4
if if_sr and sr == 24000: if if_sr and sr == 24000:
audio_opt = torch.from_numpy(audio_opt).float().to(device) audio_opt = torch.from_numpy(audio_opt).float().to(device)
audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr) audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
@ -900,8 +1058,12 @@ def get_tts_wav(
if not stream_mode == "normal": if not stream_mode == "normal":
if media_type == "wav": if media_type == "wav":
if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
sr = 32000
elif version == "v3":
sr = 48000 if if_sr else 24000 sr = 48000 if if_sr else 24000
sr = hps.data.sampling_rate if version != "v3" else sr else:
sr = 48000 # v4
audio_bytes = pack_wav(audio_bytes, sr) audio_bytes = pack_wav(audio_bytes, sr)
yield audio_bytes.getvalue() yield audio_bytes.getvalue()
@ -966,9 +1128,6 @@ def handle(
if not default_refer.is_ready(): if not default_refer.is_ready():
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400) return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
if sample_steps not in [4, 8, 16, 32]:
sample_steps = 32
if cut_punc == None: if cut_punc == None:
text = cut_text(text, default_cut_punc) text = cut_text(text, default_cut_punc)
else: else:
@ -1071,10 +1230,10 @@ default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, a
# 模型路径检查 # 模型路径检查
if sovits_path == "": if sovits_path == "":
sovits_path = g_config.pretrained_sovits_path sovits_path = g_config.pretrained_sovits_path
logger.warn(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}") logger.warning(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
if gpt_path == "": if gpt_path == "":
gpt_path = g_config.pretrained_gpt_path gpt_path = g_config.pretrained_gpt_path
logger.warn(f"未指定GPT模型路径, fallback后当前值: {gpt_path}") logger.warning(f"未指定GPT模型路径, fallback后当前值: {gpt_path}")
# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用 # 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
if default_refer.path == "" or default_refer.text == "" or default_refer.language == "": if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":

View File

@ -33,14 +33,14 @@ POST:
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details. "text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference "batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting. "batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets. "split_bucket": True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio. "speed_factor":1.0, # float. control the speed of the synthesized audio.
"streaming_mode": False, # bool. whether to return a streaming response. "streaming_mode": False, # bool. whether to return a streaming response.
"seed": -1, # int. random seed for reproducibility. "seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference. "parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35 # float. repetition penalty for T2S model. "repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3. "sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3. "super_sampling": False # bool. whether to use super-sampling for audio when using VITS model V3.
} }
``` ```
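Assuming this block documents the `/tts` route of `api_v2.py` on its default port, a minimal client sketch is shown below; the host/port and the text/reference fields are placeholders for illustration and are not part of the excerpt above:

```python
# Hypothetical client sketch: host/port, route, and the text/reference fields are assumptions;
# only the option fields mirror the documented block above.
import requests

payload = {
    "text": "Hello, world.",                # placeholder input text
    "text_lang": "en",                      # placeholder language code
    "ref_audio_path": "ref.wav",            # placeholder reference audio
    "prompt_text": "reference transcript",  # placeholder
    "prompt_lang": "en",                    # placeholder
    # options as documented above:
    "text_split_method": "cut0",
    "batch_size": 1,
    "batch_threshold": 0.75,
    "split_bucket": True,
    "speed_factor": 1.0,
    "streaming_mode": False,
    "seed": -1,
    "parallel_infer": True,
    "repetition_penalty": 1.35,
    "sample_steps": 32,
    "super_sampling": False,
}

resp = requests.post("http://127.0.0.1:9880/tts", json=payload)
resp.raise_for_status()
with open("output.wav", "wb") as f:  # write response body to a file (assumes a non-streaming audio response)
    f.write(resp.content)
```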

View File

@ -1,442 +0,0 @@
import argparse
import os
import pdb
import signal
import sys
from time import time as ttime
import torch
import librosa
import soundfile as sf
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse
import uvicorn
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
from feature_extractor import cnhubert
from io import BytesIO
from module.models import SynthesizerTrn
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from text import cleaned_text_to_sequence
from text.cleaner import clean_text
from module.mel_processing import spectrogram_torch
from my_utils import load_audio
import config as global_config
g_config = global_config.Config()
# AVAILABLE_COMPUTE = "cuda" if torch.cuda.is_available() else "cpu"
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径")
parser.add_argument("-dr", "--default_refer_path", type=str, default="",
help="默认参考音频路径, 请求缺少参考音频时调用")
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
# bool值的用法为 `python ./api.py -fp ...`
# 此时 full_precision==True, half_precision==False
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
args = parser.parse_args()
sovits_path = args.sovits_path
gpt_path = args.gpt_path
default_refer_path = args.default_refer_path
default_refer_text = args.default_refer_text
default_refer_language = args.default_refer_language
has_preset = False
device = args.device
port = args.port
host = args.bind_addr
if sovits_path == "":
sovits_path = g_config.pretrained_sovits_path
print(f"[WARN] 未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
if gpt_path == "":
gpt_path = g_config.pretrained_gpt_path
print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}")
# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
if default_refer_path == "" or default_refer_text == "" or default_refer_language == "":
default_refer_path, default_refer_text, default_refer_language = "", "", ""
print("[INFO] 未指定默认参考音频")
has_preset = False
else:
print(f"[INFO] 默认参考音频路径: {default_refer_path}")
print(f"[INFO] 默认参考音频文本: {default_refer_text}")
print(f"[INFO] 默认参考音频语种: {default_refer_language}")
has_preset = True
is_half = g_config.is_half
if args.full_precision:
is_half = False
if args.half_precision:
is_half = True
if args.full_precision and args.half_precision:
is_half = g_config.is_half # 炒饭fallback
print(f"[INFO] 半精: {is_half}")
cnhubert_base_path = args.hubert_path
bert_path = args.bert_path
cnhubert.cnhubert_base_path = cnhubert_base_path
tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
if is_half:
bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题精度随bert_model
res = bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
# if(is_half==True):phone_level_feature=phone_level_feature.half()
return phone_level_feature.T
n_semantic = 1024
dict_s2 = torch.load(sovits_path, map_location="cpu", weights_only=False)
hps = dict_s2["config"]
print(hps)
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
for key, value in input_dict.items():
if isinstance(value, dict):
value = DictToAttrRecursive(value)
self[key] = value
setattr(self, key, value)
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
def __setattr__(self, key, value):
if isinstance(value, dict):
value = DictToAttrRecursive(value)
super(DictToAttrRecursive, self).__setitem__(key, value)
super().__setattr__(key, value)
def __delattr__(self, item):
try:
del self[item]
except KeyError:
raise AttributeError(f"Attribute {item} not found")
hps = DictToAttrRecursive(hps)
hps.model.semantic_frame_rate = "25hz"
dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False)
config = dict_s1["config"]
ssl_model = cnhubert.get_model()
if is_half:
ssl_model = ssl_model.half().to(device)
else:
ssl_model = ssl_model.to(device)
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model)
if is_half:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
vq_model.eval()
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
hz = 50
max_sec = config['data']['max_sec']
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
t2s_model.load_state_dict(dict_s1["weight"])
if is_half:
t2s_model = t2s_model.half()
t2s_model = t2s_model.to(device)
t2s_model.eval()
total = sum([param.nelement() for param in t2s_model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
def get_spepc(hps, filename):
audio = load_audio(filename, int(hps.data.sampling_rate))
audio = torch.FloatTensor(audio)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
hps.data.win_length, center=False)
return spec
dict_language = {
"中文": "zh",
"英文": "en",
"日文": "ja",
"ZH": "zh",
"EN": "en",
"JA": "ja",
"zh": "zh",
"en": "en",
"ja": "ja"
}
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
t0 = ttime()
prompt_text = prompt_text.strip("\n")
prompt_language, text = prompt_language, text.strip("\n")
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
if (is_half == True):
wav16k = wav16k.half().to(device)
zero_wav_torch = zero_wav_torch.half().to(device)
else:
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k=torch.cat([wav16k,zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
phones1 = cleaned_text_to_sequence(phones1)
texts = text.split("\n")
audio_opt = []
for text in texts:
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
phones2 = cleaned_text_to_sequence(phones2)
if (prompt_language == "zh"):
bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
else:
bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(
device)
if (text_language == "zh"):
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
else:
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=config['inference']['top_k'],
early_stop_num=hz * max_sec)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path) # .to(device)
if (is_half == True):
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = \
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
refer).detach().cpu().numpy()[
0, 0] ###试试重建不带上prompt部分
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
# yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
return hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
def get_tts_wavs(ref_wav_path, prompt_text, prompt_language, textss, text_language):
t0 = ttime()
prompt_text = prompt_text.strip("\n")
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
with torch.no_grad():
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
if (is_half == True):
wav16k = wav16k.half().to(device)
zero_wav_torch = zero_wav_torch.half().to(device)
else:
wav16k = wav16k.to(device)
zero_wav_torch = zero_wav_torch.to(device)
wav16k=torch.cat([wav16k,zero_wav_torch])
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
codes = vq_model.extract_latent(ssl_content)
prompt_semantic = codes[0, 0]
t1 = ttime()
prompt_language = dict_language[prompt_language]
text_language = dict_language[text_language]
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
phones1 = cleaned_text_to_sequence(phones1)
audios_opt=[]
for text0 in textss:
texts = text0.strip("\n").split("\n")
audio_opt = []
for text in texts:
text=text.strip("")+""
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
phones2 = cleaned_text_to_sequence(phones2)
if (prompt_language == "zh"):
bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
else:
bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(
device)
if (text_language == "zh"):
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
else:
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
bert = torch.cat([bert1, bert2], 1)
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
prompt = prompt_semantic.unsqueeze(0).to(device)
t2 = ttime()
with torch.no_grad():
# pred_semantic = t2s_model.model.infer(
pred_semantic, idx = t2s_model.model.infer_panel(
all_phoneme_ids,
all_phoneme_len,
prompt,
bert,
# prompt_phone_len=ph_offset,
top_k=config['inference']['top_k'],
early_stop_num=hz * max_sec)
t3 = ttime()
# print(pred_semantic.shape,idx)
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
refer = get_spepc(hps, ref_wav_path) # .to(device)
if (is_half == True):
refer = refer.half().to(device)
else:
refer = refer.to(device)
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
audio = \
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
refer).detach().cpu().numpy()[
0, 0] ###试试重建不带上prompt部分
audio_opt.append(audio)
audio_opt.append(zero_wav)
t4 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
audios_opt.append([text0,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16)])
return audios_opt
# get_tts_wav(r"D:\BaiduNetdiskDownload\gsv\speech\萧逸声音-你得先从滑雪的基本技巧学起.wav", "你得先从滑雪的基本技巧学起。", "中文", "我觉得还是该给喜欢的女孩子一场认真的告白。", "中文")
# with open(r"D:\BaiduNetdiskDownload\gsv\烟嗓-todo1.txt","r",encoding="utf8")as f:
# with open(r"D:\BaiduNetdiskDownload\gsv\年下-todo1.txt","r",encoding="utf8")as f:
# with open(r"D:\BaiduNetdiskDownload\gsv\萧逸3b.txt","r",encoding="utf8")as f:
with open(r"D:\BaiduNetdiskDownload\gsv\萧逸4.txt","r",encoding="utf8")as f:
textss=f.read().split("\n")
for idx,(text,audio)in enumerate(get_tts_wavs(r"D:\BaiduNetdiskDownload\gsv\speech\萧逸声音-你得先从滑雪的基本技巧学起.wav", "你得先从滑雪的基本技巧学起。", "中文", textss, "中文")):
# for idx,(text,audio)in enumerate(get_tts_wavs(r"D:\BaiduNetdiskDownload\gsv\足够的能力,去制定好自己的生活规划。低沉烟嗓.MP3_1940480_2095360.wav", "足够的能力,去制定好自己的生活规划。", "中文", textss, "中文")):
# for idx,(text,audio)in enumerate(get_tts_wavs(r"D:\BaiduNetdiskDownload\gsv\不会呀!你前几天才吃过你还说好吃来着。年下少年音.MP3_537600_711040.wav", "不会呀!你前几天才吃过你还说好吃来着。", "中文", textss, "中文")):
print(idx,text)
# sf.write(r"D:\BaiduNetdiskDownload\gsv\output\烟嗓第一批\%04d-%s.wav"%(idx,text),audio,32000)
# sf.write(r"D:\BaiduNetdiskDownload\gsv\output\年下\%04d-%s.wav"%(idx,text),audio,32000)
sf.write(r"D:\BaiduNetdiskDownload\gsv\output\萧逸第4批\%04d-%s.wav"%(idx,text),audio,32000)
# def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language):
# if command == "/restart":
# os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
# elif command == "/exit":
# os.kill(os.getpid(), signal.SIGTERM)
# exit(0)
#
# if (
# refer_wav_path == "" or refer_wav_path is None
# or prompt_text == "" or prompt_text is None
# or prompt_language == "" or prompt_language is None
# ):
# refer_wav_path, prompt_text, prompt_language = (
# default_refer_path,
# default_refer_text,
# default_refer_language,
# )
# if not has_preset:
# raise HTTPException(status_code=400, detail="未指定参考音频且接口无预设")
#
# with torch.no_grad():
# gen = get_tts_wav(
# refer_wav_path, prompt_text, prompt_language, text, text_language
# )
# sampling_rate, audio_data = next(gen)
#
# wav = BytesIO()
# sf.write(wav, audio_data, sampling_rate, format="wav")
# wav.seek(0)
#
# torch.cuda.empty_cache()
# return StreamingResponse(wav, media_type="audio/wav")
# app = FastAPI()
#
#
# @app.post("/")
# async def tts_endpoint(request: Request):
# json_post_raw = await request.json()
# return handle(
# json_post_raw.get("command"),
# json_post_raw.get("refer_wav_path"),
# json_post_raw.get("prompt_text"),
# json_post_raw.get("prompt_language"),
# json_post_raw.get("text"),
# json_post_raw.get("text_language"),
# )
#
#
# @app.get("/")
# async def tts_endpoint(
# command: str = None,
# refer_wav_path: str = None,
# prompt_text: str = None,
# prompt_language: str = None,
# text: str = None,
# text_language: str = None,
# ):
# return handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language)
#
#
# if __name__ == "__main__":
# uvicorn.run(app, host=host, port=port, workers=1)
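In the deleted script above, get_bert_feature expands character-level BERT hidden states to phone level by repeating each character's 1024-dimensional vector word2ph[i] times. The standalone sketch below isolates just that expansion step with synthetic tensors; no BERT model or repository modules are involved.

```python
# Illustrative only: synthetic tensors stand in for the BERT hidden states.
import torch

num_chars, dim = 4, 1024
res = torch.randn(num_chars, dim)  # one 1024-d vector per input character
word2ph = [2, 1, 3, 2]             # phones per character, len(word2ph) == num_chars

phone_level_feature = torch.cat(
    [res[i].repeat(word2ph[i], 1) for i in range(num_chars)], dim=0
)
print(phone_level_feature.shape)   # torch.Size([8, 1024]) == (sum(word2ph), dim)
# The original function returns the transpose, i.e. shape (1024, sum(word2ph)).
```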

View File

@ -144,7 +144,8 @@ webui_port_subfix = 9871
api_port = 9880 api_port = 9880
#Thanks to the contribution of @Karasukaigan and @XXXXRT666
# Thanks to the contribution of @Karasukaigan and @XXXXRT666
def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]: def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
cpu = torch.device("cpu") cpu = torch.device("cpu")
cuda = torch.device(f"cuda:{idx}") cuda = torch.device(f"cuda:{idx}")
@ -157,10 +158,13 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, flo
mem_gb = mem_bytes / (1024**3) + 0.4 mem_gb = mem_bytes / (1024**3) + 0.4
major, minor = capability major, minor = capability
sm_version = major + minor / 10.0 sm_version = major + minor / 10.0
is_16_series = bool(re.search(r"16\d{2}", name))and sm_version == 7.5 is_16_series = bool(re.search(r"16\d{2}", name)) and sm_version == 7.5
if mem_gb < 4 or sm_version < 5.3:return cpu, torch.float32, 0.0, 0.0 if mem_gb < 4 or sm_version < 5.3:
if sm_version == 6.1 or is_16_series==True:return cuda, torch.float32, sm_version, mem_gb return cpu, torch.float32, 0.0, 0.0
if sm_version > 6.1:return cuda, torch.float16, sm_version, mem_gb if sm_version == 6.1 or is_16_series == True:
return cuda, torch.float32, sm_version, mem_gb
if sm_version > 6.1:
return cuda, torch.float16, sm_version, mem_gb
return cpu, torch.float32, 0.0, 0.0 return cpu, torch.float32, 0.0, 0.0
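The selection logic in this hunk boils down to: CPU/float32 when VRAM is under roughly 4 GB or compute capability is below 5.3; CUDA but float32 for SM 6.1 (Pascal GTX 10-series) and for GTX 16xx cards at SM 7.5, presumably because those cards are prone to numerical issues in half precision; CUDA/float16 otherwise. Below is a standalone restatement of that decision table with synthetic inputs instead of torch.cuda queries; it is a sketch, not the repo's config module.

```python
# Hypothetical restatement of the dtype/device selection shown above,
# using plain values so it runs without a GPU.
import re

def pick_dtype(name: str, sm_version: float, mem_gb: float) -> str:
    is_16_series = bool(re.search(r"16\d{2}", name)) and sm_version == 7.5
    if mem_gb < 4 or sm_version < 5.3:
        return "cpu / float32"
    if sm_version == 6.1 or is_16_series:
        return "cuda / float32"
    if sm_version > 6.1:
        return "cuda / float16"
    return "cpu / float32"

print(pick_dtype("NVIDIA GeForce GTX 1660", 7.5, 6.0))   # cuda / float32
print(pick_dtype("NVIDIA GeForce RTX 3090", 8.6, 24.0))  # cuda / float16
print(pick_dtype("NVIDIA GeForce GT 710", 3.5, 2.0))     # cpu / float32
```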

View File

@ -12,10 +12,6 @@ services:
- "9880:9880" - "9880:9880"
volumes: volumes:
- .:/workspace/GPT-SoVITS - .:/workspace/GPT-SoVITS
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel
- /dev/null:/workspace/GPT-SoVITS/tools/asr/models
- /dev/null:/workspace/GPT-SoVITS/tools/uvr5/uvr5_weights
environment: environment:
- is_half=true - is_half=true
tty: true tty: true
@ -34,10 +30,6 @@ services:
- "9880:9880" - "9880:9880"
volumes: volumes:
- .:/workspace/GPT-SoVITS - .:/workspace/GPT-SoVITS
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel
- /dev/null:/workspace/GPT-SoVITS/tools/asr/models
- /dev/null:/workspace/GPT-SoVITS/tools/uvr5/uvr5_weights
- tools/asr/models:/workspace/models/asr_models - tools/asr/models:/workspace/models/asr_models
- tools/uvr5/uvr5_weights:/workspace/models/uvr5_weights - tools/uvr5/uvr5_weights:/workspace/models/uvr5_weights
environment: environment:
@ -58,10 +50,6 @@ services:
- "9880:9880" - "9880:9880"
volumes: volumes:
- .:/workspace/GPT-SoVITS - .:/workspace/GPT-SoVITS
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel
- /dev/null:/workspace/GPT-SoVITS/tools/asr/models
- /dev/null:/workspace/GPT-SoVITS/tools/uvr5/uvr5_weights
environment: environment:
- is_half=true - is_half=true
tty: true tty: true
@ -80,10 +68,6 @@ services:
- "9880:9880" - "9880:9880"
volumes: volumes:
- .:/workspace/GPT-SoVITS - .:/workspace/GPT-SoVITS
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel
- /dev/null:/workspace/GPT-SoVITS/tools/asr/models
- /dev/null:/workspace/GPT-SoVITS/tools/uvr5/uvr5_weights
- tools/asr/models:/workspace/models/asr_models - tools/asr/models:/workspace/models/asr_models
- tools/uvr5/uvr5_weights:/workspace/models/uvr5_weights - tools/uvr5/uvr5_weights:/workspace/models/uvr5_weights
environment: environment:

View File

@ -578,3 +578,49 @@
- 内容: 优化精度自动检测逻辑, 给 WebUI 前端界面模块增加折叠功能. - 内容: 优化精度自动检测逻辑, 给 WebUI 前端界面模块增加折叠功能.
- 类型: 新功能 - 类型: 新功能
- 提交: XXXXRT666, RVC-Boss - 提交: XXXXRT666, RVC-Boss
- 2025.06.06 [PR#2427](https://github.com/RVC-Boss/GPT-SoVITS/pull/2427)
- 内容: X一X型多音字判断修复
- 类型: 修复
- 提交: wzy3650
- 2025.06.05 [PR#2439](https://github.com/RVC-Boss/GPT-SoVITS/pull/2439)
- 内容: 配置修复sovits模型读取修复
- 类型: 修复
- 提交: wzy3650
- 2025.06.09 [Commit#8056efe4](https://github.com/RVC-Boss/GPT-SoVITS/commit/8056efe4ab7bbc3610c72ae356a6f37518441f7d)
- 内容: 修复ge.sum数值可能爆炸导致推理无声的问题
- 类型: 修复
- 提交: RVC-Boss
- 2025.06.10 [Commit#2c0436b9](https://github.com/RVC-Boss/GPT-SoVITS/commit/2c0436b9ce397424ae03476c836fb64c6e5ebcc6)
- 内容: 修复实验名结尾出现空格在win中路径不正确的问题
- 类型: 修复
- 提交: RVC-Boss
- 2025.06.10 [Commit#746cb536](https://github.com/RVC-Boss/GPT-SoVITS/commit/746cb536c68b1fe6ce3ca7e882235375b8a8dd89)
- 内容: 语种分割优化
- 类型: 优化
- 提交: KamioRinn
- 2025.06.11 [Commit#dd2b9253](https://github.com/RVC-Boss/GPT-SoVITS/commit/dd2b9253aabb09db32db7a3344570ed9df043351)
- 内容: 修复并行推理对v2pro支持bug
- 类型: 修复
- 提交: YYuX-1145
- 2025.06.11 [Commit#ed89a023](https://github.com/RVC-Boss/GPT-SoVITS/commit/ed89a023378dabba9d4b6580235bb9742245816d)
- 内容: v2pro对ge提取时会出现数值溢出的问题修复
- 类型: 修复
- 提交: RVC-Boss
- 2025.06.11 [Commit#37f5abfc](https://github.com/RVC-Boss/GPT-SoVITS/commit/6fdc67ca83418306f11e90b9139278313ac5c3e9)[Commit#6fdc67ca](https://github.com/RVC-Boss/GPT-SoVITS/commit/37f5abfcb4a6553652235909db2e124b6f8ff3a5)
- 内容: install.sh逻辑优化
- 类型: 优化
- 提交: XXXXRT666
- 2025.06.27 [Commit#90ebefa7](https://github.com/RVC-Boss/GPT-SoVITS/commit/90ebefa78fd544da36eebe0b2003620879c921b0)
- 内容: onnxruntime加载逻辑优化对gpu/cpu的判断
- 类型: 优化
- 提交: KamioRinn
- 2025.06.27 [Commit#6df61f58](https://github.com/RVC-Boss/GPT-SoVITS/commit/6df61f58e4d18d4c2ad9d1eddd6a1bd690034c23)
- 内容: 语言分割及格式化优化
- 类型: 优化
- 提交: KamioRinn
- 2025.07.10 [Commit#426e1a2bb](https://github.com/RVC-Boss/GPT-SoVITS/commit/426e1a2bb43614af2479b877c37acfb0591e952f)
- 内容: 提升推理进程优先级修复win11下可能GPU利用率受限的问题
- 类型: 修复
- 提交: XianYue0125

View File

@ -7,12 +7,18 @@
<a href="https://trendshift.io/repositories/7033" target="_blank"><img src="https://trendshift.io/api/badge/repositories/7033" alt="RVC-Boss%2FGPT-SoVITS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a> <a href="https://trendshift.io/repositories/7033" target="_blank"><img src="https://trendshift.io/api/badge/repositories/7033" alt="RVC-Boss%2FGPT-SoVITS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<!-- img src="https://counter.seku.su/cmoe?name=gptsovits&theme=r34" /><br> --> [![Python](https://img.shields.io/badge/python-3.10--3.12-blue?style=for-the-badge&logo=python)](https://www.python.org)
[![GitHub release](https://img.shields.io/github/v/release/RVC-Boss/gpt-sovits?style=for-the-badge&logo=github)](https://github.com/RVC-Boss/gpt-sovits/releases)
[![Train In Colab](https://img.shields.io/badge/Colab-Training-F9AB00?style=for-the-badge&logo=googlecolab)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
[![Huggingface](https://img.shields.io/badge/免费在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/)
[![Image Size](https://img.shields.io/docker/image-size/xxxxrt666/gpt-sovits/latest?style=for-the-badge&logo=docker)](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
[![简体中文](https://img.shields.io/badge/简体中文-阅读文档-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
[![English](https://img.shields.io/badge/English-Read%20Docs-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://rentry.co/GPT-SoVITS-guide#/)
[![Change Log](https://img.shields.io/badge/Change%20Log-View%20Updates-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/docs/en/Changelog_EN.md)
[![License](https://img.shields.io/badge/LICENSE-MIT-green.svg?style=for-the-badge&logo=opensourceinitiative)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
[![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/colab_webui.ipynb)
[![License](https://img.shields.io/badge/LICENSE-MIT-green.svg?style=for-the-badge)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
[![Huggingface](https://img.shields.io/badge/🤗%20-online%20demo-yellow.svg?style=for-the-badge)](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
[![Discord](https://img.shields.io/discord/1198701940511617164?color=%23738ADB&label=Discord&style=for-the-badge)](https://discord.gg/dnrgs5GHfG)
[**English**](../../README.md) | **中文简体** | [**日本語**](../ja/README.md) | [**한국어**](../ko/README.md) | [**Türkçe**](../tr/README.md) [**English**](../../README.md) | **中文简体** | [**日本語**](../ja/README.md) | [**한국어**](../ko/README.md) | [**Türkçe**](../tr/README.md)
@ -60,6 +66,12 @@
**中国地区的用户可以[在此处下载整合包](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e/dkxgpiy9zb96hob4#KTvnO).** **中国地区的用户可以[在此处下载整合包](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e/dkxgpiy9zb96hob4#KTvnO).**
```pwsh
conda create -n GPTSoVits python=3.10
conda activate GPTSoVits
pwsh -F install.ps1 --Device <CU126|CU128|CPU> --Source <HF|HF-Mirror|ModelScope> [--DownloadUVR5]
```
### Linux ### Linux
```bash ```bash
@ -128,8 +140,9 @@ brew install ffmpeg
- 前往 [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits) 查看最新可用的镜像标签(tags) - 前往 [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits) 查看最新可用的镜像标签(tags)
- 根据你的运行环境选择合适的镜像标签 - 根据你的运行环境选择合适的镜像标签
- `Lite` Docker 镜像不包含 ASR 模型和 UVR5 模型. 你可以自行下载 UVR5 模型, ASR 模型则会在需要时由程序自动下载 - `Lite` Docker 镜像**不包含** ASR 模型和 UVR5 模型. 你可以自行下载 UVR5 模型, ASR 模型则会在需要时由程序自动下载
- 在使用 Docker Compose 时, 会自动拉取适配的架构镜像 (amd64 或 arm64) - 在使用 Docker Compose 时, 会自动拉取适配的架构镜像 (amd64 或 arm64)
- Docker Compose 将会挂载当前目录的**所有文件**, 请在使用 Docker 镜像前先切换到项目根目录并**拉取代码更新**
- 可选:为了获得最新的更改, 你可以使用提供的 Dockerfile 在本地构建镜像 - 可选:为了获得最新的更改, 你可以使用提供的 Dockerfile 在本地构建镜像
#### 环境变量 #### 环境变量
@ -329,7 +342,7 @@ python webui.py
新特性: 新特性:
1. **相比 V2 占用稍高显存, 性能超过 V4, 在保留 V2 硬件成本和推理速度优势的同时实现更高音质.** 1. **相比 V2 占用稍高显存, 性能超过 V4, 在保留 V2 硬件成本和推理速度优势的同时实现更高音质.**
[更多详情](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)) [更多详情](<https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)>)
2. V1/V2 与 V2Pro 系列具有相同特性, V3/V4 则具备相近功能. 对于平均音频质量较低的训练集, V1/V2/V2Pro 可以取得较好的效果, 但 V3/V4 无法做到. 此外, V3/V4 合成的声音更偏向参考音频, 而不是整体训练集的风格. 2. V1/V2 与 V2Pro 系列具有相同特性, V3/V4 则具备相近功能. 对于平均音频质量较低的训练集, V1/V2/V2Pro 可以取得较好的效果, 但 V3/V4 无法做到. 此外, V3/V4 合成的声音更偏向参考音频, 而不是整体训练集的风格.

View File

@ -5,12 +5,20 @@
[![madewithlove](https://img.shields.io/badge/made_with-%E2%9D%A4-red?style=for-the-badge&labelColor=orange)](https://github.com/RVC-Boss/GPT-SoVITS) [![madewithlove](https://img.shields.io/badge/made_with-%E2%9D%A4-red?style=for-the-badge&labelColor=orange)](https://github.com/RVC-Boss/GPT-SoVITS)
<img src="https://counter.seku.su/cmoe?name=gptsovits&theme=r34" /><br> <a href="https://trendshift.io/repositories/7033" target="_blank"><img src="https://trendshift.io/api/badge/repositories/7033" alt="RVC-Boss%2FGPT-SoVITS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![Python](https://img.shields.io/badge/python-3.10--3.12-blue?style=for-the-badge&logo=python)](https://www.python.org)
[![GitHub release](https://img.shields.io/github/v/release/RVC-Boss/gpt-sovits?style=for-the-badge&logo=github)](https://github.com/RVC-Boss/gpt-sovits/releases)
[![Train In Colab](https://img.shields.io/badge/Colab-Training-F9AB00?style=for-the-badge&logo=googlecolab)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
[![Huggingface](https://img.shields.io/badge/免费在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/)
[![Image Size](https://img.shields.io/docker/image-size/xxxxrt666/gpt-sovits/latest?style=for-the-badge&logo=docker)](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
[![简体中文](https://img.shields.io/badge/简体中文-阅读文档-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
[![English](https://img.shields.io/badge/English-Read%20Docs-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://rentry.co/GPT-SoVITS-guide#/)
[![Change Log](https://img.shields.io/badge/Change%20Log-View%20Updates-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/docs/en/Changelog_EN.md)
[![License](https://img.shields.io/badge/LICENSE-MIT-green.svg?style=for-the-badge&logo=opensourceinitiative)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
[![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/colab_webui.ipynb)
[![License](https://img.shields.io/badge/LICENSE-MIT-green.svg?style=for-the-badge)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
[![Huggingface](https://img.shields.io/badge/🤗%20-online%20demo-yellow.svg?style=for-the-badge)](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
[![Discord](https://img.shields.io/discord/1198701940511617164?color=%23738ADB&label=Discord&style=for-the-badge)](https://discord.gg/dnrgs5GHfG)
[**English**](../../README.md) | [**中文简体**](../cn/README.md) | **日本語** | [**한국어**](../ko/README.md) | [**Türkçe**](../tr/README.md) [**English**](../../README.md) | [**中文简体**](../cn/README.md) | **日本語** | [**한국어**](../ko/README.md) | [**Türkçe**](../tr/README.md)
@ -122,8 +130,9 @@ brew install ffmpeg
- [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits) で最新のイメージタグを確認してください - [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits) で最新のイメージタグを確認してください
- 環境に合った適切なイメージタグを選択してください - 環境に合った適切なイメージタグを選択してください
- `Lite` とは、Docker イメージに ASR モデルおよび UVR5 モデルが含まれていないことを意味します. UVR5 モデルは手動でダウンロードし、ASR モデルは必要に応じてプログラムが自動的にダウンロードします - `Lite` とは、Docker イメージに ASR モデルおよび UVR5 モデルが**含まれていない**ことを意味します. UVR5 モデルは手動でダウンロードし、ASR モデルは必要に応じてプログラムが自動的にダウンロードします
- Docker Compose 実行時に、対応するアーキテクチャ (amd64 または arm64) のイメージが自動的に取得されます - Docker Compose 実行時に、対応するアーキテクチャ (amd64 または arm64) のイメージが自動的に取得されます
- Docker Compose は現在のディレクトリ内の**すべてのファイル**をマウントします. Docker イメージを使用する前に、プロジェクトのルートディレクトリに移動し、**コードを最新の状態に更新**してください
- オプション:最新の変更を反映させるため、提供されている Dockerfile を使ってローカルでイメージをビルドすることも可能です - オプション:最新の変更を反映させるため、提供されている Dockerfile を使ってローカルでイメージをビルドすることも可能です
#### 環境変数 #### 環境変数
@ -304,7 +313,7 @@ v2 環境から v3 を使用する方法:
新機能: 新機能:
1. **V4 は、V3 で発生していた非整数倍アップサンプリングによる金属音の問題を修正し、音声がこもる問題を防ぐためにネイティブに 48kHz 音声を出力します(V3 はネイティブに 24kHz 音声のみ出力)**. 作者は V4 を V3 の直接的な置き換えとして推奨していますが、さらなるテストが必要です. 1. **V4 は、V3 で発生していた非整数倍アップサンプリングによる金属音の問題を修正し、音声がこもる問題を防ぐためにネイティブに 48kHz 音声を出力します(V3 はネイティブに 24kHz 音声のみ出力)**. 作者は V4 を V3 の直接的な置き換えとして推奨していますが、さらなるテストが必要です.
[詳細はこちら](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90v3v4%E2%80%90features-(%E6%96%B0%E7%89%B9%E6%80%A7)) [詳細はこちら](<https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90v3v4%E2%80%90features-(%E6%96%B0%E7%89%B9%E6%80%A7)>)
V1/V2/V3 環境から V4 への移行方法: V1/V2/V3 環境から V4 への移行方法:
@ -319,7 +328,7 @@ V1/V2/V3 環境から V4 への移行方法:
新機能: 新機能:
1. **V2 と比較してやや高いメモリ使用量ですが、ハードウェアコストと推論速度は維持しつつ、V4 よりも高い性能と音質を実現します. ** 1. **V2 と比較してやや高いメモリ使用量ですが、ハードウェアコストと推論速度は維持しつつ、V4 よりも高い性能と音質を実現します. **
[詳細はこちら](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)) [詳細はこちら](<https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)>)
2. V1/V2 と V2Pro シリーズは類似した特徴を持ち、V3/V4 も同様の機能を持っています. 平均音質が低いトレーニングセットの場合、V1/V2/V2Pro は良好な結果を出すことができますが、V3/V4 では対応できません. また、V3/V4 の合成音声はトレーニング全体ではなく、より参考音声に寄った音質になります. 2. V1/V2 と V2Pro シリーズは類似した特徴を持ち、V3/V4 も同様の機能を持っています. 平均音質が低いトレーニングセットの場合、V1/V2/V2Pro は良好な結果を出すことができますが、V3/V4 では対応できません. また、V3/V4 の合成音声はトレーニング全体ではなく、より参考音声に寄った音質になります.

View File

@ -5,12 +5,20 @@
[![madewithlove](https://img.shields.io/badge/made_with-%E2%9D%A4-red?style=for-the-badge&labelColor=orange)](https://github.com/RVC-Boss/GPT-SoVITS) [![madewithlove](https://img.shields.io/badge/made_with-%E2%9D%A4-red?style=for-the-badge&labelColor=orange)](https://github.com/RVC-Boss/GPT-SoVITS)
<img src="https://counter.seku.su/cmoe?name=gptsovits&theme=r34" /><br> <a href="https://trendshift.io/repositories/7033" target="_blank"><img src="https://trendshift.io/api/badge/repositories/7033" alt="RVC-Boss%2FGPT-SoVITS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![Python](https://img.shields.io/badge/python-3.10--3.12-blue?style=for-the-badge&logo=python)](https://www.python.org)
[![GitHub release](https://img.shields.io/github/v/release/RVC-Boss/gpt-sovits?style=for-the-badge&logo=github)](https://github.com/RVC-Boss/gpt-sovits/releases)
[![Train In Colab](https://img.shields.io/badge/Colab-Training-F9AB00?style=for-the-badge&logo=googlecolab)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
[![Huggingface](https://img.shields.io/badge/免费在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/)
[![Image Size](https://img.shields.io/docker/image-size/xxxxrt666/gpt-sovits/latest?style=for-the-badge&logo=docker)](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
[![简体中文](https://img.shields.io/badge/简体中文-阅读文档-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
[![English](https://img.shields.io/badge/English-Read%20Docs-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://rentry.co/GPT-SoVITS-guide#/)
[![Change Log](https://img.shields.io/badge/Change%20Log-View%20Updates-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/docs/en/Changelog_EN.md)
[![License](https://img.shields.io/badge/LICENSE-MIT-green.svg?style=for-the-badge&logo=opensourceinitiative)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
[![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/colab_webui.ipynb)
[![License](https://img.shields.io/badge/LICENSE-MIT-green.svg?style=for-the-badge)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
[![Huggingface](https://img.shields.io/badge/🤗%20-online%20demo-yellow.svg?style=for-the-badge)](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
[![Discord](https://img.shields.io/discord/1198701940511617164?color=%23738ADB&label=Discord&style=for-the-badge)](https://discord.gg/dnrgs5GHfG)
[**English**](../../README.md) | [**中文简体**](../cn/README.md) | [**日本語**](../ja/README.md) | **한국어** | [**Türkçe**](../tr/README.md) [**English**](../../README.md) | [**中文简体**](../cn/README.md) | [**日本語**](../ja/README.md) | **한국어** | [**Türkçe**](../tr/README.md)
@ -54,6 +62,12 @@ https://github.com/RVC-Boss/GPT-SoVITS/assets/129054828/05bee1fa-bdd8-4d85-9350-
Windows 사용자라면 (win>=10에서 테스트됨), [통합 패키지를 다운로드](https://huggingface.co/lj1995/GPT-SoVITS-windows-package/resolve/main/GPT-SoVITS-v3lora-20250228.7z?download=true)한 후 압축을 풀고 _go-webui.bat_ 파일을 더블 클릭하면 GPT-SoVITS-WebUI를 시작할 수 있습니다. Windows 사용자라면 (win>=10에서 테스트됨), [통합 패키지를 다운로드](https://huggingface.co/lj1995/GPT-SoVITS-windows-package/resolve/main/GPT-SoVITS-v3lora-20250228.7z?download=true)한 후 압축을 풀고 _go-webui.bat_ 파일을 더블 클릭하면 GPT-SoVITS-WebUI를 시작할 수 있습니다.
```pwsh
conda create -n GPTSoVits python=3.10
conda activate GPTSoVits
pwsh -F install.ps1 --Device <CU126|CU128|CPU> --Source <HF|HF-Mirror|ModelScope> [--DownloadUVR5]
```
### Linux ### Linux
```bash ```bash
@ -122,8 +136,9 @@ brew install ffmpeg
- [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits)에서 최신 이미지 태그를 확인하세요 - [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits)에서 최신 이미지 태그를 확인하세요
- 환경에 맞는 적절한 이미지 태그를 선택하세요 - 환경에 맞는 적절한 이미지 태그를 선택하세요
- `Lite` 는 Docker 이미지에 ASR 모델과 UVR5 모델이 포함되어 있지 않음을 의미합니다. UVR5 모델은 사용자가 직접 다운로드해야 하며, ASR 모델은 필요 시 프로그램이 자동으로 다운로드합니다 - `Lite` 는 Docker 이미지에 ASR 모델과 UVR5 모델이 **포함되어 있지 않음**을 의미합니다. UVR5 모델은 사용자가 직접 다운로드해야 하며, ASR 모델은 필요 시 프로그램이 자동으로 다운로드합니다
- Docker Compose 실행 시, 해당 아키텍처에 맞는 이미지(amd64 또는 arm64)가 자동으로 다운로드됩니다 - Docker Compose 실행 시, 해당 아키텍처에 맞는 이미지(amd64 또는 arm64)가 자동으로 다운로드됩니다
- Docker Compose는 현재 디렉터리의 **모든 파일**을 마운트합니다. Docker 이미지를 사용하기 전에 프로젝트 루트 디렉터리로 이동하여 코드를 **최신 상태로 업데이트**하세요
- 선택 사항: 최신 변경사항을 반영하려면 제공된 Dockerfile을 사용하여 로컬에서 직접 이미지를 빌드할 수 있습니다 - 선택 사항: 최신 변경사항을 반영하려면 제공된 Dockerfile을 사용하여 로컬에서 직접 이미지를 빌드할 수 있습니다
#### 환경 변수 #### 환경 변수
@ -319,7 +334,7 @@ V1/V2/V3 환경에서 V4로 전환 방법:
신규 기능: 신규 기능:
1. **V2보다 약간 높은 VRAM 사용량이지만 성능은 V4보다 우수하며, V2 수준의 하드웨어 비용과 속도를 유지합니다**. 1. **V2보다 약간 높은 VRAM 사용량이지만 성능은 V4보다 우수하며, V2 수준의 하드웨어 비용과 속도를 유지합니다**.
[자세히 보기](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)) [자세히 보기](<https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)>)
2. V1/V2와 V2Pro 시리즈는 유사한 특징을 가지며, V3/V4도 비슷한 기능을 가지고 있습니다. 평균 음질이 낮은 학습 데이터셋에서는 V1/V2/V2Pro가 좋은 결과를 내지만 V3/V4는 그렇지 못합니다. 또한 V3/V4의 합성 음색은 전체 학습 데이터셋보다는 참고 음성에 더 가깝습니다. 2. V1/V2와 V2Pro 시리즈는 유사한 특징을 가지며, V3/V4도 비슷한 기능을 가지고 있습니다. 평균 음질이 낮은 학습 데이터셋에서는 V1/V2/V2Pro가 좋은 결과를 내지만 V3/V4는 그렇지 못합니다. 또한 V3/V4의 합성 음색은 전체 학습 데이터셋보다는 참고 음성에 더 가깝습니다.

View File

@ -7,12 +7,17 @@ Güçlü Birkaç Örnekli Ses Dönüştürme ve Metinden Konuşmaya Web Arayüz
<a href="https://trendshift.io/repositories/7033" target="_blank"><img src="https://trendshift.io/api/badge/repositories/7033" alt="RVC-Boss%2FGPT-SoVITS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a> <a href="https://trendshift.io/repositories/7033" target="_blank"><img src="https://trendshift.io/api/badge/repositories/7033" alt="RVC-Boss%2FGPT-SoVITS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<!-- img src="https://counter.seku.su/cmoe?name=gptsovits&theme=r34" /><br> --> [![Python](https://img.shields.io/badge/python-3.10--3.12-blue?style=for-the-badge&logo=python)](https://www.python.org)
[![GitHub release](https://img.shields.io/github/v/release/RVC-Boss/gpt-sovits?style=for-the-badge&logo=github)](https://github.com/RVC-Boss/gpt-sovits/releases)
[![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/colab_webui.ipynb) [![Train In Colab](https://img.shields.io/badge/Colab-Training-F9AB00?style=for-the-badge&logo=googlecolab)](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
[![License](https://img.shields.io/badge/LICENSE-MIT-green.svg?style=for-the-badge)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE) [![Huggingface](https://img.shields.io/badge/免费在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface)](https://lj1995-gpt-sovits-proplus.hf.space/)
[![Huggingface](https://img.shields.io/badge/🤗%20-online%20demo-yellow.svg?style=for-the-badge)](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2) [![Image Size](https://img.shields.io/docker/image-size/xxxxrt666/gpt-sovits/latest?style=for-the-badge&logo=docker)](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
[![Discord](https://img.shields.io/discord/1198701940511617164?color=%23738ADB&label=Discord&style=for-the-badge)](https://discord.gg/dnrgs5GHfG)
[![简体中文](https://img.shields.io/badge/简体中文-阅读文档-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
[![English](https://img.shields.io/badge/English-Read%20Docs-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://rentry.co/GPT-SoVITS-guide#/)
[![Change Log](https://img.shields.io/badge/Change%20Log-View%20Updates-blue?style=for-the-badge&logo=googledocs&logoColor=white)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/docs/en/Changelog_EN.md)
[![License](https://img.shields.io/badge/LICENSE-MIT-green.svg?style=for-the-badge&logo=opensourceinitiative)](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
[**English**](../../README.md) | [**中文简体**](../cn/README.md) | [**日本語**](../ja/README.md) | [**한국어**](../ko/README.md) | **Türkçe** [**English**](../../README.md) | [**中文简体**](../cn/README.md) | [**日本語**](../ja/README.md) | [**한국어**](../ko/README.md) | **Türkçe**
@ -56,6 +61,12 @@ https://github.com/RVC-Boss/GPT-SoVITS/assets/129054828/05bee1fa-bdd8-4d85-9350-
Eğer bir Windows kullanıcısıysanız (win>=10 ile test edilmiştir), [entegre paketi indirin](https://huggingface.co/lj1995/GPT-SoVITS-windows-package/resolve/main/GPT-SoVITS-v3lora-20250228.7z?download=true) ve _go-webui.bat_ dosyasına çift tıklayarak GPT-SoVITS-WebUI'yi başlatın. Eğer bir Windows kullanıcısıysanız (win>=10 ile test edilmiştir), [entegre paketi indirin](https://huggingface.co/lj1995/GPT-SoVITS-windows-package/resolve/main/GPT-SoVITS-v3lora-20250228.7z?download=true) ve _go-webui.bat_ dosyasına çift tıklayarak GPT-SoVITS-WebUI'yi başlatın.
```pwsh
conda create -n GPTSoVits python=3.10
conda activate GPTSoVits
pwsh -F install.ps1 --Device <CU126|CU128|CPU> --Source <HF|HF-Mirror|ModelScope> [--DownloadUVR5]
```
### Linux ### Linux
```bash ```bash
@ -124,8 +135,9 @@ Kod tabanı hızla geliştiği halde Docker imajları daha yavaş yayınlandığ
- En güncel kullanılabilir imaj etiketlerini görmek için [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits) adresini kontrol edin - En güncel kullanılabilir imaj etiketlerini görmek için [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits) adresini kontrol edin
- Ortamınıza uygun bir imaj etiketi seçin - Ortamınıza uygun bir imaj etiketi seçin
- `Lite`, Docker imajında ASR modelleri ve UVR5 modellerinin bulunmadığı anlamına gelir. UVR5 modellerini manuel olarak indirebilirsiniz; ASR modelleri ise gerektiğinde program tarafından otomatik olarak indirilir - `Lite`, Docker imajında ASR modelleri ve UVR5 modellerinin **bulunmadığı** anlamına gelir. UVR5 modellerini manuel olarak indirebilirsiniz; ASR modelleri ise gerektiğinde program tarafından otomatik olarak indirilir
- Docker Compose sırasında, uygun mimariye (amd64 veya arm64) ait imaj otomatik olarak indirilir - Docker Compose sırasında, uygun mimariye (amd64 veya arm64) ait imaj otomatik olarak indirilir
- Docker Compose, mevcut dizindeki **tüm dosyaları** bağlayacaktır. Docker imajını kullanmadan önce lütfen proje kök dizinine geçin ve **en son kodu çekin**
- Opsiyonel: En güncel değişiklikleri almak için, sağlanan Dockerfile ile yerel olarak imajı kendiniz oluşturabilirsiniz - Opsiyonel: En güncel değişiklikleri almak için, sağlanan Dockerfile ile yerel olarak imajı kendiniz oluşturabilirsiniz
#### Ortam Değişkenleri #### Ortam Değişkenleri
@ -323,7 +335,7 @@ V1/V2/V3 ortamından V4'e geçiş:
Yeni Özellikler: Yeni Özellikler:
1. **V2 ile karşılaştırıldığında biraz daha yüksek VRAM kullanımı sağlar ancak V4'ten daha iyi performans gösterir; aynı donanım maliyeti ve hız avantajını korur**. 1. **V2 ile karşılaştırıldığında biraz daha yüksek VRAM kullanımı sağlar ancak V4'ten daha iyi performans gösterir; aynı donanım maliyeti ve hız avantajını korur**.
[Daha fazla bilgi](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)) [Daha fazla bilgi](<https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)>)
2. V1/V2 ve V2Pro serisi benzer özelliklere sahipken, V3/V4 de yakın işlevleri paylaşır. Ortalama kalite düşük olan eğitim setleriyle V1/V2/V2Pro iyi sonuçlar verebilir ama V3/V4 veremez. Ayrıca, V3/V4ün ürettiği ses tonu genel eğitim setine değil, referans ses örneğine daha çok benzemektedir. 2. V1/V2 ve V2Pro serisi benzer özelliklere sahipken, V3/V4 de yakın işlevleri paylaşır. Ortalama kalite düşük olan eğitim setleriyle V1/V2/V2Pro iyi sonuçlar verebilir ama V3/V4 veremez. Ayrıca, V3/V4ün ürettiği ses tonu genel eğitim setine değil, referans ses örneğine daha çok benzemektedir.

241
install.ps1 Normal file
View File

@ -0,0 +1,241 @@
Param (
[Parameter(Mandatory=$true)][ValidateSet("CU126", "CU128", "CPU")][string]$Device,
[Parameter(Mandatory=$true)][ValidateSet("HF", "HF-Mirror", "ModelScope")][string]$Source,
[switch]$DownloadUVR5
)
$global:ErrorActionPreference = 'Stop'
trap {
Write-ErrorLog $_
}
function Write-ErrorLog {
param (
[System.Management.Automation.ErrorRecord]$ErrorRecord
)
Write-Host "`n[ERROR] Command failed:" -ForegroundColor Red
if (-not $ErrorRecord.Exception.Message){
} else {
Write-Host "Message:" -ForegroundColor Red
$ErrorRecord.Exception.Message -split "`n" | ForEach-Object {
Write-Host " $_"
}
}
Write-Host "Command:" -ForegroundColor Red -NoNewline
Write-Host " $($ErrorRecord.InvocationInfo.Line)".Replace("`r", "").Replace("`n", "")
Write-Host "Location:" -ForegroundColor Red -NoNewline
Write-Host " $($ErrorRecord.InvocationInfo.ScriptName):$($ErrorRecord.InvocationInfo.ScriptLineNumber)"
Write-Host "Call Stack:" -ForegroundColor DarkRed
$ErrorRecord.ScriptStackTrace -split "`n" | ForEach-Object {
Write-Host " $_" -ForegroundColor DarkRed
}
exit 1
}
function Write-Info($msg) {
Write-Host "[INFO]:" -ForegroundColor Green -NoNewline
Write-Host " $msg"
}
function Write-Success($msg) {
Write-Host "[SUCCESS]:" -ForegroundColor Blue -NoNewline
Write-Host " $msg"
}
function Invoke-Conda {
param (
[Parameter(ValueFromRemainingArguments = $true)]
[string[]]$Args
)
$output = & conda install -y -q -c conda-forge @Args 2>&1
$exitCode = $LASTEXITCODE
if ($exitCode -ne 0) {
Write-Host "Conda Install $Args Failed" -ForegroundColor Red
$errorMessages = @()
foreach ($item in $output) {
if ($item -is [System.Management.Automation.ErrorRecord]) {
$msg = $item.Exception.Message
Write-Host "$msg" -ForegroundColor Red
$errorMessages += $msg
}
else {
Write-Host $item
$errorMessages += $item
}
}
throw [System.Exception]::new(($errorMessages -join "`n"))
}
}
function Invoke-Pip {
param (
[Parameter(ValueFromRemainingArguments = $true)]
[string[]]$Args
)
$output = & pip install @Args 2>&1
$exitCode = $LASTEXITCODE
if ($exitCode -ne 0) {
$errorMessages = @()
Write-Host "Pip Install $Args Failed" -ForegroundColor Red
foreach ($item in $output) {
if ($item -is [System.Management.Automation.ErrorRecord]) {
$msg = $item.Exception.Message
Write-Host "$msg" -ForegroundColor Red
$errorMessages += $msg
}
else {
Write-Host $item
$errorMessages += $item
}
}
throw [System.Exception]::new(($errorMessages -join "`n"))
}
}
function Invoke-Download {
param (
[Parameter(Mandatory = $true)]
[string]$Uri,
[Parameter()]
[string]$OutFile
)
try {
$params = @{
Uri = $Uri
}
if ($OutFile) {
$params["OutFile"] = $OutFile
}
$null = Invoke-WebRequest @params -ErrorAction Stop
} catch {
Write-Host "Failed to download:" -ForegroundColor Red
Write-Host " $Uri"
throw
}
}
function Invoke-Unzip {
param($ZipPath, $DestPath)
Expand-Archive -Path $ZipPath -DestinationPath $DestPath -Force
Remove-Item $ZipPath -Force
}
chcp 65001
Set-Location $PSScriptRoot
Write-Info "Installing FFmpeg & CMake..."
Invoke-Conda ffmpeg cmake
Write-Success "FFmpeg & CMake Installed"
$PretrainedURL = ""
$G2PWURL = ""
$UVR5URL = ""
$NLTKURL = ""
$OpenJTalkURL = ""
switch ($Source) {
"HF" {
Write-Info "Download Model From HuggingFace"
$PretrainedURL = "https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/pretrained_models.zip"
$G2PWURL = "https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/G2PWModel.zip"
$UVR5URL = "https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/uvr5_weights.zip"
$NLTKURL = "https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/nltk_data.zip"
$OpenJTalkURL = "https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/open_jtalk_dic_utf_8-1.11.tar.gz"
}
"HF-Mirror" {
Write-Info "Download Model From HuggingFace-Mirror"
$PretrainedURL = "https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/pretrained_models.zip"
$G2PWURL = "https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/G2PWModel.zip"
$UVR5URL = "https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/uvr5_weights.zip"
$NLTKURL = "https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/nltk_data.zip"
$OpenJTalkURL = "https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/open_jtalk_dic_utf_8-1.11.tar.gz"
}
"ModelScope" {
Write-Info "Download Model From ModelScope"
$PretrainedURL = "https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/pretrained_models.zip"
$G2PWURL = "https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/G2PWModel.zip"
$UVR5URL = "https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/uvr5_weights.zip"
$NLTKURL = "https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/nltk_data.zip"
$OpenJTalkURL = "https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/open_jtalk_dic_utf_8-1.11.tar.gz"
}
}
if (-not (Test-Path "GPT_SoVITS/pretrained_models/sv")) {
Write-Info "Downloading Pretrained Models..."
Invoke-Download -Uri $PretrainedURL -OutFile "pretrained_models.zip"
Invoke-Unzip "pretrained_models.zip" "GPT_SoVITS"
Write-Success "Pretrained Models Downloaded"
} else {
Write-Info "Pretrained Model Exists"
Write-Info "Skip Downloading Pretrained Models"
}
if (-not (Test-Path "GPT_SoVITS/text/G2PWModel")) {
Write-Info "Downloading G2PWModel..."
Invoke-Download -Uri $G2PWURL -OutFile "G2PWModel.zip"
Invoke-Unzip "G2PWModel.zip" "GPT_SoVITS/text"
Write-Success "G2PWModel Downloaded"
} else {
Write-Info "G2PWModel Exists"
Write-Info "Skip Downloading G2PWModel"
}
if ($DownloadUVR5) {
if (-not (Test-Path "tools/uvr5/uvr5_weights")) {
Write-Info "Downloading UVR5 Models..."
Invoke-Download -Uri $UVR5URL -OutFile "uvr5_weights.zip"
Invoke-Unzip "uvr5_weights.zip" "tools/uvr5"
Write-Success "UVR5 Models Downloaded"
} else {
Write-Info "UVR5 Models Exists"
Write-Info "Skip Downloading UVR5 Models"
}
}
switch ($Device) {
"CU128" {
Write-Info "Installing PyTorch For CUDA 12.8..."
Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cu128"
}
"CU126" {
Write-Info "Installing PyTorch For CUDA 12.6..."
Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cu126"
}
"CPU" {
Write-Info "Installing PyTorch For CPU..."
Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cpu"
}
}
Write-Success "PyTorch Installed"
Write-Info "Installing Python Dependencies From requirements.txt..."
Invoke-Pip -r extra-req.txt --no-deps
Invoke-Pip -r requirements.txt
Write-Success "Python Dependencies Installed"
Write-Info "Downloading NLTK Data..."
Invoke-Download -Uri $NLTKURL -OutFile "nltk_data.zip"
Invoke-Unzip "nltk_data.zip" (python -c "import sys; print(sys.prefix)").Trim()
Write-Info "Downloading Open JTalk Dict..."
Invoke-Download -Uri $OpenJTalkURL -OutFile "open_jtalk_dic_utf_8-1.11.tar.gz"
$target = (python -c "import os, pyopenjtalk; print(os.path.dirname(pyopenjtalk.__file__))").Trim()
tar -xzf open_jtalk_dic_utf_8-1.11.tar.gz -C $target
Remove-Item "open_jtalk_dic_utf_8-1.11.tar.gz" -Force
Write-Success "Open JTalk Dic Downloaded"
Write-Success "Installation Completed"

View File

@ -5,15 +5,62 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
cd "$SCRIPT_DIR" || exit 1 cd "$SCRIPT_DIR" || exit 1
set -e RESET="\033[0m"
BOLD="\033[1m"
ERROR="\033[1;31m[ERROR]: $RESET"
WARNING="\033[1;33m[WARNING]: $RESET"
INFO="\033[1;32m[INFO]: $RESET"
SUCCESS="\033[1;34m[SUCCESS]: $RESET"
set -eE
set -o errtrace
trap 'on_error $LINENO "$BASH_COMMAND" $?' ERR
# shellcheck disable=SC2317
on_error() {
local lineno="$1"
local cmd="$2"
local code="$3"
echo -e "${ERROR}${BOLD}Command \"${cmd}\" Failed${RESET} at ${BOLD}Line ${lineno}${RESET} with Exit Code ${BOLD}${code}${RESET}"
echo -e "${ERROR}${BOLD}Call Stack:${RESET}"
for ((i = ${#FUNCNAME[@]} - 1; i >= 1; i--)); do
echo -e " in ${BOLD}${FUNCNAME[i]}()${RESET} at ${BASH_SOURCE[i]}:${BOLD}${BASH_LINENO[i - 1]}${RESET}"
done
exit "$code"
}
run_conda_quiet() {
local output
output=$(conda install --yes --quiet -c conda-forge "$@" 2>&1) || {
echo -e "${ERROR} Conda install failed:\n$output"
exit 1
}
}
run_pip_quiet() {
local output
output=$(pip install "$@" 2>&1) || {
echo -e "${ERROR} Pip install failed:\n$output"
exit 1
}
}
run_wget_quiet() {
if wget --tries=25 --wait=5 --read-timeout=40 -q --show-progress "$@" 2>&1; then
tput cuu1 && tput el
else
echo -e "${ERROR} Wget failed"
exit 1
fi
}
if ! command -v conda &>/dev/null; then if ! command -v conda &>/dev/null; then
echo "Conda Not Found" echo -e "${ERROR}Conda Not Found"
exit 1 exit 1
fi fi
trap 'echo "Error Occured at \"$BASH_COMMAND\" with exit code $?"; exit 1' ERR
USE_CUDA=false USE_CUDA=false
USE_ROCM=false USE_ROCM=false
USE_CPU=false USE_CPU=false
@ -34,8 +81,8 @@ print_help() {
echo " -h, --help Show this help message and exit" echo " -h, --help Show this help message and exit"
echo "" echo ""
echo "Examples:" echo "Examples:"
echo " bash install.sh --source HF --download-uvr5" echo " bash install.sh --device CU128 --source HF --download-uvr5"
echo " bash install.sh --source ModelScope" echo " bash install.sh --device MPS --source ModelScope"
} }
# Show help if no arguments provided # Show help if no arguments provided
@ -59,8 +106,8 @@ while [[ $# -gt 0 ]]; do
USE_MODELSCOPE=true USE_MODELSCOPE=true
;; ;;
*) *)
echo "Error: Invalid Download Source: $2" echo -e "${ERROR}Error: Invalid Download Source: $2"
echo "Choose From: [HF, HF-Mirror, ModelScope]" echo -e "${ERROR}Choose From: [HF, HF-Mirror, ModelScope]"
exit 1 exit 1
;; ;;
esac esac
@ -86,8 +133,8 @@ while [[ $# -gt 0 ]]; do
USE_CPU=true USE_CPU=true
;; ;;
*) *)
echo "Error: Invalid Device: $2" echo -e "${ERROR}Error: Invalid Device: $2"
echo "Choose From: [CU126, CU128, ROCM, MPS, CPU]" echo -e "${ERROR}Choose From: [CU126, CU128, ROCM, MPS, CPU]"
exit 1 exit 1
;; ;;
esac esac
@ -102,78 +149,102 @@ while [[ $# -gt 0 ]]; do
exit 0 exit 0
;; ;;
*) *)
echo "Unknown Argument: $1" echo -e "${ERROR}Unknown Argument: $1"
echo "Use -h or --help to see available options." echo ""
print_help
exit 1 exit 1
;; ;;
esac esac
done done
if ! $USE_CUDA && ! $USE_ROCM && ! $USE_CPU; then if ! $USE_CUDA && ! $USE_ROCM && ! $USE_CPU; then
echo "Error: Device is REQUIRED" echo -e "${ERROR}Error: Device is REQUIRED"
echo "" echo ""
print_help print_help
exit 1 exit 1
fi fi
if ! $USE_HF && ! $USE_HF_MIRROR && ! $USE_MODELSCOPE; then if ! $USE_HF && ! $USE_HF_MIRROR && ! $USE_MODELSCOPE; then
echo "Error: Download Source is REQUIRED" echo -e "${ERROR}Error: Download Source is REQUIRED"
echo "" echo ""
print_help print_help
exit 1 exit 1
fi fi
# 安装构建工具 case "$(uname -m)" in
x86_64 | amd64) SYSROOT_PKG="sysroot_linux-64>=2.28" ;;
aarch64 | arm64) SYSROOT_PKG="sysroot_linux-aarch64>=2.28" ;;
ppc64le) SYSROOT_PKG="sysroot_linux-ppc64le>=2.28" ;;
*)
echo "Unsupported architecture: $(uname -m)"
exit 1
;;
esac
# Install build tools # Install build tools
echo -e "${INFO}Detected system: $(uname -s) $(uname -r) $(uname -m)"
if [ "$(uname)" != "Darwin" ]; then if [ "$(uname)" != "Darwin" ]; then
gcc_major_version=$(command -v gcc >/dev/null 2>&1 && gcc -dumpversion | cut -d. -f1 || echo 0) gcc_major_version=$(command -v gcc >/dev/null 2>&1 && gcc -dumpversion | cut -d. -f1 || echo 0)
if [ "$gcc_major_version" -lt 11 ]; then if [ "$gcc_major_version" -lt 11 ]; then
echo "Installing GCC & G++..." echo -e "${INFO}Installing GCC & G++..."
conda install -c conda-forge gcc=11 gxx=11 -q -y run_conda_quiet gcc=11 gxx=11
run_conda_quiet "$SYSROOT_PKG"
echo -e "${SUCCESS}GCC & G++ Installed..."
else else
echo "GCC >=11" echo -e "${INFO}Detected GCC Version: $gcc_major_version"
echo -e "${INFO}Skip Installing GCC & G++ From Conda-Forge"
echo -e "${INFO}Installing libstdcxx-ng From Conda-Forge"
run_conda_quiet "libstdcxx-ng>=$gcc_major_version"
echo -e "${SUCCESS}libstdcxx-ng=$gcc_major_version Installed..."
fi fi
else else
if ! xcode-select -p &>/dev/null; then if ! xcode-select -p &>/dev/null; then
echo "Installing Xcode Command Line Tools..." echo -e "${INFO}Installing Xcode Command Line Tools..."
xcode-select --install xcode-select --install
fi echo -e "${INFO}Waiting For Xcode Command Line Tools Installation Complete..."
echo "Waiting For Xcode Command Line Tools Installation Complete..."
while true; do while true; do
sleep 20 sleep 20
if xcode-select -p &>/dev/null; then if xcode-select -p &>/dev/null; then
echo "Xcode Command Line Tools Installed" echo -e "${SUCCESS}Xcode Command Line Tools Installed"
break break
else else
echo "InstallingPlease Wait..." echo -e "${INFO}InstallingPlease Wait..."
fi fi
done done
conda install -c conda-forge -q -y else
XCODE_PATH=$(xcode-select -p)
if [[ "$XCODE_PATH" == *"Xcode.app"* ]]; then
echo -e "${WARNING} Detected Xcode path: $XCODE_PATH"
echo -e "${WARNING} If your Xcode version does not match your macOS version, it may cause unexpected issues during compilation or package builds."
fi
fi
fi fi
echo "Installing ffmpeg and cmake..." echo -e "${INFO}Installing FFmpeg & CMake..."
conda install ffmpeg cmake make -q -y run_conda_quiet ffmpeg cmake make
echo -e "${SUCCESS}FFmpeg & CMake Installed"
echo "Installing unzip..." echo -e "${INFO}Installing unzip..."
conda install unzip -y --quiet run_conda_quiet unzip
echo -e "${SUCCESS}unzip Installed"
if [ "$USE_HF" = "true" ]; then if [ "$USE_HF" = "true" ]; then
echo "Download Model From HuggingFace" echo -e "${INFO}Download Model From HuggingFace"
PRETRINED_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/pretrained_models.zip" PRETRINED_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/pretrained_models.zip"
G2PW_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/G2PWModel.zip" G2PW_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/G2PWModel.zip"
UVR5_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/uvr5_weights.zip" UVR5_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/uvr5_weights.zip"
NLTK_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/nltk_data.zip" NLTK_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/nltk_data.zip"
PYOPENJTALK_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/open_jtalk_dic_utf_8-1.11.tar.gz" PYOPENJTALK_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/open_jtalk_dic_utf_8-1.11.tar.gz"
elif [ "$USE_HF_MIRROR" = "true" ]; then elif [ "$USE_HF_MIRROR" = "true" ]; then
echo "Download Model From HuggingFace-Mirror" echo -e "${INFO}Download Model From HuggingFace-Mirror"
PRETRINED_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/pretrained_models.zip" PRETRINED_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/pretrained_models.zip"
G2PW_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/G2PWModel.zip" G2PW_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/G2PWModel.zip"
UVR5_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/uvr5_weights.zip" UVR5_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/uvr5_weights.zip"
NLTK_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/nltk_data.zip" NLTK_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/nltk_data.zip"
PYOPENJTALK_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/open_jtalk_dic_utf_8-1.11.tar.gz" PYOPENJTALK_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/open_jtalk_dic_utf_8-1.11.tar.gz"
elif [ "$USE_MODELSCOPE" = "true" ]; then elif [ "$USE_MODELSCOPE" = "true" ]; then
echo "Download Model From ModelScope" echo -e "${INFO}Download Model From ModelScope"
PRETRINED_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/pretrained_models.zip" PRETRINED_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/pretrained_models.zip"
G2PW_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/G2PWModel.zip" G2PW_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/G2PWModel.zip"
UVR5_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/uvr5_weights.zip" UVR5_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/uvr5_weights.zip"
@ -181,118 +252,129 @@ elif [ "$USE_MODELSCOPE" = "true" ]; then
PYOPENJTALK_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/open_jtalk_dic_utf_8-1.11.tar.gz" PYOPENJTALK_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/open_jtalk_dic_utf_8-1.11.tar.gz"
fi fi
if [ "$WORKFLOW" = "true" ]; then if [ ! -d "GPT_SoVITS/pretrained_models/sv" ]; then
WGET_CMD=(wget -nv --tries=25 --wait=5 --read-timeout=40 --retry-on-http-error=404) echo -e "${INFO}Downloading Pretrained Models..."
else rm -rf pretrained_models.zip
WGET_CMD=(wget --tries=25 --wait=5 --read-timeout=40 --retry-on-http-error=404) run_wget_quiet "$PRETRINED_URL"
fi
if find -L "GPT_SoVITS/pretrained_models" -mindepth 1 ! -name '.gitignore' | grep -q .; then
echo "Pretrained Model Exists"
else
echo "Download Pretrained Models"
"${WGET_CMD[@]}" "$PRETRINED_URL"
unzip -q -o pretrained_models.zip -d GPT_SoVITS unzip -q -o pretrained_models.zip -d GPT_SoVITS
rm -rf pretrained_models.zip rm -rf pretrained_models.zip
echo -e "${SUCCESS}Pretrained Models Downloaded"
else
echo -e "${INFO}Pretrained Model Exists"
echo -e "${INFO}Skip Downloading Pretrained Models"
fi fi
if [ ! -d "GPT_SoVITS/text/G2PWModel" ]; then if [ ! -d "GPT_SoVITS/text/G2PWModel" ]; then
echo "Download G2PWModel" echo -e "${INFO}Downloading G2PWModel.."
"${WGET_CMD[@]}" "$G2PW_URL" rm -rf G2PWModel.zip
run_wget_quiet "$G2PW_URL"
unzip -q -o G2PWModel.zip -d GPT_SoVITS/text unzip -q -o G2PWModel.zip -d GPT_SoVITS/text
rm -rf G2PWModel.zip rm -rf G2PWModel.zip
echo -e "${SUCCESS}G2PWModel Downloaded"
else else
echo "G2PWModel Exists" echo -e "${INFO}G2PWModel Exists"
echo -e "${INFO}Skip Downloading G2PWModel"
fi fi
if [ "$DOWNLOAD_UVR5" = "true" ]; then if [ "$DOWNLOAD_UVR5" = "true" ]; then
if find -L "tools/uvr5/uvr5_weights" -mindepth 1 ! -name '.gitignore' | grep -q .; then if find -L "tools/uvr5/uvr5_weights" -mindepth 1 ! -name '.gitignore' | grep -q .; then
echo "UVR5 Model Exists" echo -e"${INFO}UVR5 Models Exists"
echo -e "${INFO}Skip Downloading UVR5 Models"
else else
echo "Download UVR5 Model" echo -e "${INFO}Downloading UVR5 Models..."
"${WGET_CMD[@]}" "$UVR5_URL" rm -rf uvr5_weights.zip
run_wget_quiet "$UVR5_URL"
unzip -q -o uvr5_weights.zip -d tools/uvr5 unzip -q -o uvr5_weights.zip -d tools/uvr5
rm -rf uvr5_weights.zip rm -rf uvr5_weights.zip
echo -e "${SUCCESS}UVR5 Models Downloaded"
fi fi
fi fi
if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then
echo "Checking for CUDA installation..." echo -e "${INFO}Checking For Nvidia Driver Installation..."
if command -v nvidia-smi &>/dev/null; then if command -v nvidia-smi &>/dev/null; then
echo "CUDA found." echo "${INFO}Nvidia Driver Founded"
else else
echo -e "${WARNING}Nvidia Driver Not Found, Fallback to CPU"
USE_CUDA=false USE_CUDA=false
USE_CPU=true USE_CPU=true
echo "CUDA not found."
fi fi
fi fi
if [ "$USE_ROCM" = true ] && [ "$WORKFLOW" = false ]; then if [ "$USE_ROCM" = true ] && [ "$WORKFLOW" = false ]; then
echo "Checking for ROCm installation..." echo -e "${INFO}Checking For ROCm Installation..."
if [ -d "/opt/rocm" ]; then if [ -d "/opt/rocm" ]; then
echo "ROCm found." echo -e "${INFO}ROCm Founded"
if grep -qi "microsoft" /proc/version; then if grep -qi "microsoft" /proc/version; then
echo "You are running WSL." echo -e "${INFO}WSL2 Founded"
IS_WSL=true IS_WSL=true
else else
echo "You are NOT running WSL."
IS_WSL=false IS_WSL=false
fi fi
else else
echo -e "${WARNING}ROCm Not Found, Fallback to CPU"
USE_ROCM=false USE_ROCM=false
USE_CPU=true USE_CPU=true
echo "ROCm not found."
fi fi
fi fi
if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then
echo "Installing PyTorch with CUDA support..."
if [ "$CUDA" = 128 ]; then if [ "$CUDA" = 128 ]; then
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128 echo -e "${INFO}Installing PyTorch For CUDA 12.8..."
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cu128"
elif [ "$CUDA" = 126 ]; then elif [ "$CUDA" = 126 ]; then
pip install torch==2.6 torchaudio --index-url https://download.pytorch.org/whl/cu126 echo -e "${INFO}Installing PyTorch For CUDA 12.6..."
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cu126"
fi fi
elif [ "$USE_ROCM" = true ] && [ "$WORKFLOW" = false ]; then elif [ "$USE_ROCM" = true ] && [ "$WORKFLOW" = false ]; then
echo "Installing PyTorch with ROCm support..." echo -e "${INFO}Installing PyTorch For ROCm 6.2..."
pip install torch==2.6 torchaudio --index-url https://download.pytorch.org/whl/rocm6.2 run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/rocm6.2"
elif [ "$USE_CPU" = true ] && [ "$WORKFLOW" = false ]; then elif [ "$USE_CPU" = true ] && [ "$WORKFLOW" = false ]; then
echo "Installing PyTorch for CPU..." echo -e "${INFO}Installing PyTorch For CPU..."
pip install torch==2.6 torchaudio --index-url https://download.pytorch.org/whl/cpu run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cpu"
elif [ "$WORKFLOW" = false ]; then elif [ "$WORKFLOW" = false ]; then
echo "Unknown Err" echo -e "${ERROR}Unknown Err"
exit 1 exit 1
fi fi
echo -e "${SUCCESS}PyTorch Installed"
echo "Installing Python dependencies from requirements.txt..." echo -e "${INFO}Installing Python Dependencies From requirements.txt..."
# 刷新环境
# Refresh environment
hash -r hash -r
pip install -r extra-req.txt --no-deps --quiet run_pip_quiet -r extra-req.txt --no-deps
pip install -r requirements.txt --quiet run_pip_quiet -r requirements.txt
echo -e "${SUCCESS}Python Dependencies Installed"
PY_PREFIX=$(python -c "import sys; print(sys.prefix)") PY_PREFIX=$(python -c "import sys; print(sys.prefix)")
PYOPENJTALK_PREFIX=$(python -c "import os, pyopenjtalk; print(os.path.dirname(pyopenjtalk.__file__))") PYOPENJTALK_PREFIX=$(python -c "import os, pyopenjtalk; print(os.path.dirname(pyopenjtalk.__file__))")
"${WGET_CMD[@]}" "$NLTK_URL" -O nltk_data.zip echo -e "${INFO}Downloading NLTK Data..."
rm -rf nltk_data.zip
run_wget_quiet "$NLTK_URL" -O nltk_data.zip
unzip -q -o nltk_data -d "$PY_PREFIX" unzip -q -o nltk_data -d "$PY_PREFIX"
rm -rf nltk_data.zip rm -rf nltk_data.zip
echo -e "${SUCCESS}NLTK Data Downloaded"
"${WGET_CMD[@]}" "$PYOPENJTALK_URL" -O open_jtalk_dic_utf_8-1.11.tar.gz echo -e "${INFO}Downloading Open JTalk Dict..."
tar -xvzf open_jtalk_dic_utf_8-1.11.tar.gz -C "$PYOPENJTALK_PREFIX"
rm -rf open_jtalk_dic_utf_8-1.11.tar.gz rm -rf open_jtalk_dic_utf_8-1.11.tar.gz
run_wget_quiet "$PYOPENJTALK_URL" -O open_jtalk_dic_utf_8-1.11.tar.gz
tar -xzf open_jtalk_dic_utf_8-1.11.tar.gz -C "$PYOPENJTALK_PREFIX"
rm -rf open_jtalk_dic_utf_8-1.11.tar.gz
echo -e "${SUCCESS}Open JTalk Dic Downloaded"
if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ]; then if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ]; then
echo "Update to WSL compatible runtime lib..." echo -e "${INFO}Updating WSL Compatible Runtime Lib For ROCm..."
location=$(pip show torch | grep Location | awk -F ": " '{print $2}') location=$(pip show torch | grep Location | awk -F ": " '{print $2}')
cd "${location}"/torch/lib/ || exit cd "${location}"/torch/lib/ || exit
rm libhsa-runtime64.so* rm libhsa-runtime64.so*
cp /opt/rocm/lib/libhsa-runtime64.so.1.2 libhsa-runtime64.so cp "$(readlink -f /opt/rocm/lib/libhsa-runtime64.so)" libhsa-runtime64.so
echo -e "${SUCCESS}ROCm Runtime Lib Updated..."
fi fi
echo "Installation completed successfully!" echo -e "${SUCCESS}Installation Completed"

View File

@@ -6,15 +6,10 @@ def check_fw_local_models():
     启动时检查本地是否有 Faster Whisper 模型.
     """
     model_size_list = [
-        "tiny",
-        "tiny.en",
-        "base",
-        "base.en",
-        "small",
-        "small.en",
         "medium",
         "medium.en",
-        "large",
+        "distil-large-v2",
+        "distil-large-v3",
         "large-v1",
         "large-v2",
         "large-v3",
@@ -25,11 +20,24 @@ def check_fw_local_models():
     return model_size_list
+def get_models():
+    model_size_list = [
+        "medium",
+        "medium.en",
+        "distil-large-v2",
+        "distil-large-v3",
+        "large-v1",
+        "large-v2",
+        "large-v3",
+    ]
+    return model_size_list
 asr_dict = {
     "达摩 ASR (中文)": {"lang": ["zh", "yue"], "size": ["large"], "path": "funasr_asr.py", "precision": ["float32"]},
     "Faster Whisper (多语种)": {
         "lang": ["auto", "zh", "en", "ja", "ko", "yue"],
-        "size": check_fw_local_models(),
+        "size": get_models(),
         "path": "fasterwhisper_asr.py",
         "precision": ["float32", "float16", "int8"],
     },

View File

@@ -1,15 +1,16 @@
 import argparse
 import os
+import time
 import traceback
-os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
-os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
 import torch
 from faster_whisper import WhisperModel
+from huggingface_hub import snapshot_download
+from huggingface_hub.errors import LocalEntryNotFoundError
 from tqdm import tqdm
-from tools.asr.config import check_fw_local_models
+from tools.asr.config import get_models
+from tools.asr.funasr_asr import only_asr
 from tools.my_utils import load_cudnn
 # fmt: off
@@ -38,20 +39,54 @@ language_code_list = [
 # fmt: on
-def execute_asr(input_folder, output_folder, model_size, language, precision):
-    if "-local" in model_size:
-        model_size = model_size[:-6]
-        model_path = f"tools/asr/models/faster-whisper-{model_size}"
-    else:
-        model_path = model_size
+def download_model(model_size: str):
+    if "distil" in model_size:
+        repo_id = "Systran/faster-{}-whisper-{}".format(*model_size.split("-", maxsplit=1))
+    else:
+        repo_id = f"Systran/faster-whisper-{model_size}"
+    model_path = f"tools/asr/models/{repo_id.strip('Systran/')}"
+    files: list[str] = [
+        "config.json",
+        "model.bin",
+        "tokenizer.json",
+        "vocabulary.txt",
+    ]
+    if model_size == "large-v3" or "distil" in model_size:
+        files.append("preprocessor_config.json")
+        files.append("vocabulary.json")
+        files.remove("vocabulary.txt")
+    for attempt in range(2):
+        try:
+            snapshot_download(
+                repo_id=repo_id,
+                allow_patterns=files,
+                local_dir=model_path,
+            )
+            break
+        except LocalEntryNotFoundError:
+            if attempt < 1:
+                time.sleep(2)
+            else:
+                print("[ERROR] LocalEntryNotFoundError and no fallback.")
+                traceback.print_exc()
+                exit(1)
+        except Exception as e:
+            print(f"[ERROR] Unexpected error on attempt {attempt + 1}: {e}")
+            traceback.print_exc()
+            exit(1)
+    return model_path
+def execute_asr(input_folder, output_folder, model_path, language, precision):
     if language == "auto":
         language = None  # 不设置语种由模型自动输出概率最高的语种
-    print("loading faster whisper model:", model_size, model_path)
+    print("loading faster whisper model:", model_path, model_path)
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    try:
-        model = WhisperModel(model_path, device=device, compute_type=precision)
-    except:
-        return print(traceback.format_exc())
+    model = WhisperModel(model_path, device=device, compute_type=precision)
     input_file_names = os.listdir(input_folder)
     input_file_names.sort()
@@ -73,16 +108,15 @@ def execute_asr(input_folder, output_folder, model_size, language, precision):
             if info.language == "zh":
                 print("检测为中文文本, 转 FunASR 处理")
-                if "only_asr" not in globals():
-                    from tools.asr.funasr_asr import only_asr  # 如果用英文就不需要导入下载模型
                 text = only_asr(file_path, language=info.language.lower())
             if text == "":
                 for segment in segments:
                     text += segment.text
             output.append(f"{file_path}|{output_file_name}|{info.language.upper()}|{text}")
-        except:
-            print(traceback.format_exc())
+        except Exception as e:
+            print(e)
+            traceback.print_exc()
     output_folder = output_folder or "output/asr_opt"
     os.makedirs(output_folder, exist_ok=True)
@@ -107,7 +141,7 @@ if __name__ == "__main__":
         "--model_size",
         type=str,
         default="large-v3",
-        choices=check_fw_local_models(),
+        choices=get_models(),
         help="Model Size of Faster Whisper",
     )
     parser.add_argument(
@@ -123,10 +157,14 @@
     )
     cmd = parser.parse_args()
+    model_size = cmd.model_size
+    if model_size == "large":
+        model_size = "large-v3"
+    model_path = download_model(model_size)
     output_file_path = execute_asr(
         input_folder=cmd.input_folder,
         output_folder=cmd.output_folder,
-        model_size=cmd.model_size,
+        model_path=model_path,
         language=cmd.language,
         precision=cmd.precision,
     )

View File

@@ -1,81 +1,38 @@
 js = """
-function createGradioAnimation() {
+function deleteTheme() {
     const params = new URLSearchParams(window.location.search);
-    if (params.get('__theme') !== 'light') {
-        params.set('__theme', 'light'); // 仅当 __theme 不是 'light' 时设置为 'light'
-        window.location.search = params.toString(); // 更新 URL触发页面刷新
-    }
-    var container = document.createElement('div');
-    container.id = 'gradio-animation';
-    container.style.fontSize = '2em';
-    container.style.fontWeight = '500';
-    container.style.textAlign = 'center';
-    container.style.marginBottom = '20px';
-    container.style.fontFamily = '-apple-system, sans-serif, Arial, Calibri';
-    var text = 'Welcome to GPT-SoVITS !';
-    for (var i = 0; i < text.length; i++) {
-        (function(i){
-            setTimeout(function(){
-                var letter = document.createElement('span');
-                letter.style.opacity = '0';
-                letter.style.transition = 'opacity 0.5s';
-                letter.innerText = text[i];
-                container.appendChild(letter);
-                setTimeout(function() {
-                    letter.style.opacity = '1';
-                }, 50);
-            }, i * 250);
-        })(i);
-    }
-    return 'Animation created';
+    if (params.has('__theme')) {
+        params.delete('__theme');
+        const newUrl = `${window.location.pathname}?${params.toString()}`;
+        window.location.replace(newUrl);
+    }
 }
 """
css = """ css = """
/* CSSStyleRule */ /* CSSStyleRule */
.markdown { .markdown {
background-color: lightblue;
padding: 6px 10px; padding: 6px 10px;
} }
.checkbox_info { @media (prefers-color-scheme: light) {
color: var(--block-title-text-color) !important; .markdown {
font-size: var(--block-title-text-size) !important; background-color: lightblue;
font-weight: var(--block-title-text-weight) !important; color: #000;
height: 22px; }
margin-bottom: 8px !important; }
@media (prefers-color-scheme: dark) {
.markdown {
background-color: #4b4b4b;
color: rgb(244, 244, 245);
}
} }
::selection { ::selection {
background: #ffc078; !important; background: #ffc078 !important;
}
#checkbox_train_dpo input[type="checkbox"]{
margin-top: 6px;
}
#checkbox_train_dpo span {
margin-top: 6px;
}
#checkbox_align_train {
padding-top: 18px;
padding-bottom: 18px;
}
#checkbox_align_infer input[type="checkbox"] {
margin-top: 10px;
}
#checkbox_align_infer span {
margin-top: 10px;
} }
footer { footer {
@@ -91,16 +48,20 @@ footer * {
 }
 """
 top_html = """
 <div align="center">
     <div style="margin-bottom: 5px; font-size: 15px;">{}</div>
-    <div style="display: flex; gap: 80px; justify-content: center;">
+    <div style="display: flex; gap: 60px; justify-content: center;">
         <a href="https://github.com/RVC-Boss/GPT-SoVITS" target="_blank">
             <img src="https://img.shields.io/badge/GitHub-GPT--SoVITS-blue.svg?style=for-the-badge&logo=github" style="width: auto; height: 30px;">
         </a>
         <a href="https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e" target="_blank">
             <img src="https://img.shields.io/badge/简体中文-阅读文档-blue?style=for-the-badge&logo=googledocs&logoColor=white" style="width: auto; height: 30px;">
         </a>
+        <a href="https://lj1995-gpt-sovits-proplus.hf.space/" target="_blank">
+            <img src="https://img.shields.io/badge/免费在线体验-free_online_demo-yellow.svg?style=for-the-badge&logo=huggingface" style="width: auto; height: 30px;">
+        </a>
         <a href="https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e" target="_blank">
             <img src="https://img.shields.io/badge/English-READ%20DOCS-blue?style=for-the-badge&logo=googledocs&logoColor=white" style="width: auto; height: 30px;">
         </a>

View File

@@ -109,7 +109,7 @@ def check_details(path_list=None, is_train=False, is_dataset_processing=False):
         if os.path.exists(wav_path):
             ...
         else:
-            gr.Warning(wav_path+i18n("路径错误"))
+            gr.Warning(wav_path + i18n("路径错误"))
             return
     if is_train:
         path_list.append(os.path.join(path_list[0], "2-name2text.txt"))

View File

@@ -1,5 +1,6 @@
 import sys
+from tools.i18n.i18n import I18nAuto, scan_language_list
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else "Auto"
 i18n = I18nAuto(language=language)
 import argparse
@@ -309,7 +310,9 @@ if __name__ == "__main__":
     with gr.Blocks(analytics_enabled=False) as demo:
         gr.Markdown(
-            value=i18n("Submit Text: 将当前页所有文本框内容手工保存到内存和文件(翻页前后或者退出标注页面前如果没点这个按钮,你再翻回来就回滚了,白忙活。)")
+            value=i18n(
+                "Submit Text: 将当前页所有文本框内容手工保存到内存和文件(翻页前后或者退出标注页面前如果没点这个按钮,你再翻回来就回滚了,白忙活。)"
+            )
         )
         with gr.Row():
             btn_change_index = gr.Button("Change Index")

View File

@@ -190,14 +190,14 @@ class Predictor:
         opt_path_vocal = path_vocal[:-4] + ".%s" % format
         opt_path_other = path_other[:-4] + ".%s" % format
         if os.path.exists(path_vocal):
-            os.system("ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path_vocal, opt_path_vocal))
+            os.system('ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path_vocal, opt_path_vocal))
             if os.path.exists(opt_path_vocal):
                 try:
                     os.remove(path_vocal)
                 except:
                     pass
         if os.path.exists(path_other):
-            os.system("ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path_other, opt_path_other))
+            os.system('ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path_other, opt_path_other))
             if os.path.exists(opt_path_other):
                 try:
                     os.remove(path_other)

View File

@@ -140,7 +140,7 @@ class AudioPre:
             )
             if os.path.exists(path):
                 opt_format_path = path[:-4] + ".%s" % format
-                cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
+                cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
                 print(cmd)
                 os.system(cmd)
                 if os.path.exists(opt_format_path):
@@ -177,7 +177,7 @@ class AudioPre:
             )
             if os.path.exists(path):
                 opt_format_path = path[:-4] + ".%s" % format
-                cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
+                cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
                 print(cmd)
                 os.system(cmd)
                 if os.path.exists(opt_format_path):
@@ -307,7 +307,7 @@ class AudioPreDeEcho:
             )
             if os.path.exists(path):
                 opt_format_path = path[:-4] + ".%s" % format
-                cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
+                cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
                 print(cmd)
                 os.system(cmd)
                 if os.path.exists(opt_format_path):
@@ -340,7 +340,7 @@ class AudioPreDeEcho:
             )
             if os.path.exists(path):
                 opt_format_path = path[:-4] + ".%s" % format
-                cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
+                cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
                 print(cmd)
                 os.system(cmd)
                 if os.path.exists(opt_format_path):

View File

@@ -86,13 +86,10 @@ from config import (
 from tools import my_utils
 from tools.my_utils import check_details, check_for_existance
-# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 当遇到mps不支持的步骤时使用cpu
-try:
-    import gradio.analytics as analytics
-    analytics.version_check = lambda: None
-except:
-    ...
+os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 当遇到mps不支持的步骤时使用cpu
 import gradio as gr
 n_cpu = cpu_count()
@@ -346,7 +343,7 @@ def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, so
     os.environ["sovits_path"] = sovits_path
     os.environ["cnhubert_base_path"] = cnhubert_base_path
     os.environ["bert_path"] = bert_path
-    os.environ["_CUDA_VISIBLE_DEVICES"] = fix_gpu_number(gpu_number)
+    os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_number))
     os.environ["is_half"] = str(is_half)
     os.environ["infer_ttswebui"] = str(webui_port_infer_tts)
     os.environ["is_share"] = str(is_share)
@@ -507,6 +504,7 @@
 ):
     global p_train_SoVITS
     if p_train_SoVITS == None:
+        exp_name = exp_name.rstrip(" ")
         config_file = (
             "GPT_SoVITS/configs/s2.json"
             if version not in {"v2Pro", "v2ProPlus"}
@@ -603,6 +601,7 @@
 ):
     global p_train_GPT
     if p_train_GPT == None:
+        exp_name = exp_name.rstrip(" ")
         with open(
             "GPT_SoVITS/configs/s1longer.yaml" if version == "v1" else "GPT_SoVITS/configs/s1longer-v2.yaml"
         ) as f:
@@ -629,7 +628,7 @@
         data["output_dir"] = "%s/logs_s1_%s" % (s1_dir, version)
         # data["version"]=version
-        os.environ["_CUDA_VISIBLE_DEVICES"] = fix_gpu_numbers(gpu_numbers.replace("-", ","))
+        os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_numbers(gpu_numbers.replace("-", ",")))
         os.environ["hz"] = "25hz"
         tmp_config_path = "%s/tmp_s1.yaml" % tmp
         with open(tmp_config_path, "w") as f:
@@ -785,6 +784,7 @@ def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir):
     inp_wav_dir = my_utils.clean_path(inp_wav_dir)
     if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
         check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
+    exp_name = exp_name.rstrip(" ")
     if ps1a == []:
         opt_dir = "%s/%s" % (exp_root, exp_name)
         config = {
@@ -801,7 +801,7 @@ def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir):
                 {
                     "i_part": str(i_part),
                     "all_parts": str(all_parts),
-                    "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
+                    "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                     "is_half": str(is_half),
                 }
             )
@@ -874,6 +874,7 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
     inp_wav_dir = my_utils.clean_path(inp_wav_dir)
     if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
         check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
+    exp_name = exp_name.rstrip(" ")
     if ps1b == []:
         config = {
             "inp_text": inp_text,
@@ -891,7 +892,7 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
                 {
                     "i_part": str(i_part),
                     "all_parts": str(all_parts),
-                    "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
+                    "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
             os.environ.update(config)
@@ -913,7 +914,7 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
                 {
                     "i_part": str(i_part),
                     "all_parts": str(all_parts),
-                    "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
+                    "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
             os.environ.update(config)
@@ -962,6 +963,7 @@ def open1c(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, pretrained_s2G
     inp_text = my_utils.clean_path(inp_text)
     if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
         check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
+    exp_name = exp_name.rstrip(" ")
     if ps1c == []:
         opt_dir = "%s/%s" % (exp_root, exp_name)
         config_file = (
@@ -984,7 +986,7 @@ def open1c(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, pretrained_s2G
                 {
                     "i_part": str(i_part),
                     "all_parts": str(all_parts),
-                    "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
+                    "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
             os.environ.update(config)
@@ -1059,6 +1061,7 @@ def open1abc(
     inp_wav_dir = my_utils.clean_path(inp_wav_dir)
     if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
         check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
+    exp_name = exp_name.rstrip(" ")
     if ps1abc == []:
         opt_dir = "%s/%s" % (exp_root, exp_name)
         try:
@@ -1083,7 +1086,7 @@ def open1abc(
                 {
                     "i_part": str(i_part),
                     "all_parts": str(all_parts),
-                    "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
+                    "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
             os.environ.update(config)
@@ -1130,7 +1133,7 @@ def open1abc(
                 {
                     "i_part": str(i_part),
                     "all_parts": str(all_parts),
-                    "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
+                    "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
             os.environ.update(config)
@@ -1152,7 +1155,7 @@ def open1abc(
                 {
                     "i_part": str(i_part),
                     "all_parts": str(all_parts),
-                    "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
+                    "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
             os.environ.update(config)
@@ -1192,7 +1195,7 @@ def open1abc(
                 {
                     "i_part": str(i_part),
                     "all_parts": str(all_parts),
-                    "_CUDA_VISIBLE_DEVICES": fix_gpu_number(gpu_names[i_part]),
+                    "_CUDA_VISIBLE_DEVICES": str(fix_gpu_number(gpu_names[i_part])),
                 }
             )
             os.environ.update(config)
@@ -1977,3 +1980,4 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
         server_port=webui_port_main,
         # quiet=True,
     )