Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-08-07 15:33:29 +08:00)
Commit 6df7921d56: Merge branch 'RVC-Boss:main' into feat/frontend-usability-enhancements
@@ -28,7 +28,8 @@ class Text2SemanticLightningModule(LightningModule):
self.load_state_dict(
torch.load(
pretrained_s1,
map_location="cpu", weights_only=False,
map_location="cpu",
weights_only=False,
)["weight"],
)
)

@@ -354,7 +354,7 @@ class ScaledAdam(BatchedOptimizer):
if ans < 1.0:
first_state["num_clipped"] += 1
if ans < 0.1:
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
logging.warning(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
if self.show_dominant_parameters:
assert p.shape[0] == len(param_names)
self._show_gradient_dominating_parameter(tuples, tot_sumsq)

@@ -362,7 +362,7 @@ class ScaledAdam(BatchedOptimizer):

def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
"""
Show information of parameter wihch dominanting tot_sumsq.
Show information of parameter which dominating tot_sumsq.

Args:
tuples: a list of tuples of (param, state, param_names)

@@ -415,7 +415,7 @@ class ScaledAdam(BatchedOptimizer):
dominant_grad,
) = sorted_by_proportion[dominant_param_name]
logging.info(
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
f"Parameter Dominating tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
f"={dominant_sumsq:.3e},"
@@ -32,19 +32,21 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer

from tools.audio_sr import AP_BWE
from tools.i18n.i18n import I18nAuto, scan_language_list
from tools.my_utils import load_audio
from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
from sv import SV
resample_transform_dict={}
def resample(audio_tensor, sr0,sr1,device):

resample_transform_dict = {}


def resample(audio_tensor, sr0, sr1, device):
global resample_transform_dict
key="%s-%s-%s"%(sr0,sr1,str(device))
key = "%s-%s-%s" % (sr0, sr1, str(device))
if key not in resample_transform_dict:
resample_transform_dict[key] = torchaudio.transforms.Resample(
sr0, sr1
).to(device)
resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
return resample_transform_dict[key](audio_tensor)


language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
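Note: the resample helper above memoizes one torchaudio Resample module per (source rate, target rate, device) key, so repeated calls reuse the same precomputed kernel. A minimal standalone sketch of that caching pattern (names below are illustrative, not taken from the repository):

    import torch
    import torchaudio

    _resamplers = {}  # (sr_in, sr_out, device) -> torchaudio.transforms.Resample

    def cached_resample(audio: torch.Tensor, sr_in: int, sr_out: int, device: str = "cpu") -> torch.Tensor:
        """Resample `audio` from sr_in to sr_out, building each Resample module only once."""
        key = (sr_in, sr_out, device)
        if key not in _resamplers:
            _resamplers[key] = torchaudio.transforms.Resample(sr_in, sr_out).to(device)
        return _resamplers[key](audio.to(device))

    if __name__ == "__main__":
        wav = torch.randn(1, 32000)               # one second of fake audio at 32 kHz
        out = cached_resample(wav, 32000, 16000)  # second call with the same key reuses the module
        print(out.shape)                          # torch.Size([1, 16000])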
@@ -111,6 +113,7 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int):

return processed_audio


class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)

@@ -479,7 +482,7 @@ class TTS:
def init_vits_weights(self, weights_path: str):
self.configs.vits_weights_path = weights_path
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
if "Pro"in model_version:
if "Pro" in model_version:
self.init_sv_model()
path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]

@@ -498,9 +501,9 @@ class TTS:
else:
hps["model"]["version"] = "v2"
version = hps["model"]["version"]
v3v4set={"v3", "v4"}
v3v4set = {"v3", "v4"}
if model_version not in v3v4set:
if "Pro"not in model_version:
if "Pro" not in model_version:
model_version = version
else:
hps["model"]["version"] = model_version

@@ -542,7 +545,7 @@ class TTS:
if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
del vits_model.enc_q

self.is_v2pro=model_version in {"v2Pro","v2ProPlus"}
self.is_v2pro = model_version in {"v2Pro", "v2ProPlus"}

if if_lora_v3 == False:
print(

@@ -632,7 +635,9 @@ class TTS:
)
self.vocoder.remove_weight_norm()
state_dict_g = torch.load(
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,),
map_location="cpu",
weights_only=False,
)
print("loading vocoder", self.vocoder.load_state_dict(state_dict_g))
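Note: the vocoder hunk only reflows the long torch.load call, but the explicit weights_only=False matters because recent PyTorch releases default to weights_only=True and would refuse a fully pickled checkpoint. A hedged sketch of the same loading pattern (the path and function name are placeholders):

    import torch

    def load_vocoder_weights(model: torch.nn.Module, ckpt_path: str) -> None:
        # weights_only=False permits arbitrary pickled objects; only use it for trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        print("missing keys:", missing, "unexpected keys:", unexpected)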
@@ -752,11 +757,13 @@ class TTS:

if raw_sr != self.configs.sampling_rate:
audio = raw_audio.to(self.configs.device)
if (audio.shape[0] == 2): audio = audio.mean(0).unsqueeze(0)
if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0)
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
else:
audio = raw_audio.to(self.configs.device)
if (audio.shape[0] == 2): audio = audio.mean(0).unsqueeze(0)
if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0)

maxx = audio.abs().max()
if maxx > 1:

@@ -775,8 +782,9 @@ class TTS:
audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
if self.configs.is_half:
audio = audio.half()
else:audio=None
return spec,audio
else:
audio = None
return spec, audio

def _set_prompt_semantic(self, ref_wav_path: str):
zero_wav = np.zeros(

@@ -1073,7 +1081,10 @@ class TTS:

###### setting reference audio and prompt text preprocessing ########
t0 = time.perf_counter()
if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"]):
if (ref_audio_path is not None) and (
ref_audio_path != self.prompt_cache["ref_audio_path"]
or (self.is_v2pro and self.prompt_cache["refer_spec"][0][1] is None)
):
if not os.path.exists(ref_audio_path):
raise ValueError(f"{ref_audio_path} not exists")
self.set_ref_audio(ref_audio_path)

@@ -1212,9 +1223,10 @@ class TTS:
t_34 += t4 - t3

refer_audio_spec = []
if self.is_v2pro:sv_emb=[]
for spec,audio_tensor in self.prompt_cache["refer_spec"]:
spec=spec.to(dtype=self.precision, device=self.configs.device)
if self.is_v2pro:
sv_emb = []
for spec, audio_tensor in self.prompt_cache["refer_spec"]:
spec = spec.to(dtype=self.precision, device=self.configs.device)
refer_audio_spec.append(spec)
if self.is_v2pro:
sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
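Note: for the v2Pro models the loop above collects, per reference clip, both the spectrogram and a speaker-verification embedding computed from the raw audio tensor. A simplified sketch of that gathering step (the compute_embedding3 call mirrors the diff; the surrounding function and argument names are illustrative):

    import torch

    def collect_reference_features(refer_cache, precision, device, is_v2pro, sv_model=None):
        """refer_cache: list of (spectrogram, audio_tensor) pairs from the prompt cache."""
        refer_audio_spec = []
        sv_emb = [] if is_v2pro else None
        for spec, audio_tensor in refer_cache:
            refer_audio_spec.append(spec.to(dtype=precision, device=device))
            if is_v2pro:
                # the Pro variants additionally condition decoding on a speaker embedding
                sv_emb.append(sv_model.compute_embedding3(audio_tensor))
        return refer_audio_spec, sv_emb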
@@ -1249,10 +1261,14 @@ class TTS:
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
if self.is_v2pro!=True:
_batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
if self.is_v2pro != True:
_batch_audio_fragment = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
else:
_batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor,sv_emb=sv_emb).detach()[0, 0, :]
_batch_audio_fragment = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
).detach()[0, 0, :]
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment = [
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]

@@ -1266,9 +1282,13 @@ class TTS:
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
if self.is_v2pro != True:
audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
audio_fragment = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
else:
audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor,sv_emb=sv_emb).detach()[0, 0, :]
audio_fragment = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
).detach()[0, 0, :]
batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
else:
if parallel_infer:

@@ -1410,7 +1430,7 @@ class TTS:
raw_entry = self.prompt_cache["refer_spec"][0]
if isinstance(raw_entry, tuple):
raw_entry = raw_entry[0]
refer_audio_spec = raw_entry.to(dtype=self.precision,device=self.configs.device)
refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)

fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]

@@ -1480,7 +1500,7 @@ class TTS:
raw_entry = self.prompt_cache["refer_spec"][0]
if isinstance(raw_entry, tuple):
raw_entry = raw_entry[0]
refer_audio_spec = raw_entry.to(dtype=self.precision,device=self.configs.device)
refer_audio_spec = raw_entry.to(dtype=self.precision, device=self.configs.device)

fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
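Note: the two decode branches above differ only in whether the speaker-verification embeddings are forwarded. A compact way to express the same call (a sketch, not the repository's code) is to build the keyword arguments once:

    def decode_fragment(vits_model, pred_semantic, phones, refer_spec, speed, is_v2pro, sv_emb=None):
        kwargs = dict(speed=speed)
        if is_v2pro:
            kwargs["sv_emb"] = sv_emb  # only the Pro variants accept this argument
        return vits_model.decode(pred_semantic, phones, refer_spec, **kwargs).detach()[0, 0, :]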
@@ -159,6 +159,12 @@ class TextPreprocessor:
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (
tmp["lang"] != "en" and langlist[-1] != "en"
):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
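Note: the TextPreprocessor hunk merges consecutive segments returned by LangSegmenter.getTexts when they fall on the same side of the English/non-English split, so mixed-language text is grouped into longer runs before phonemization. A self-contained sketch of that merging rule (plain dicts stand in for the segmenter's output, and the simplification of the non-English branch is an assumption):

    def merge_language_runs(segments):
        """segments: iterable of {"lang": str, "text": str}; adjacent en / non-en runs are merged."""
        langlist, textlist = [], []
        for seg in segments:
            same_side = bool(langlist) and ((seg["lang"] == "en") == (langlist[-1] == "en"))
            if same_side:
                textlist[-1] += seg["text"]  # extend the previous run, keep its language tag
            else:
                langlist.append(seg["lang"])
                textlist.append(seg["text"])
        return langlist, textlist

    print(merge_language_runs([
        {"lang": "zh", "text": "你好"},
        {"lang": "ja", "text": "こんにちは"},  # non-English follows non-English -> merged
        {"lang": "en", "text": "hello"},
    ]))
    # (['zh', 'en'], ['你好こんにちは', 'hello'])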
@@ -1,13 +1,12 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""
Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
"""

Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
"""

import torch
import math

@@ -16,15 +15,14 @@ import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF

class ReLU(nn.Hardtanh):

class ReLU(nn.Hardtanh):
def __init__(self, inplace=False):
super(ReLU, self).__init__(0, 20, inplace)

def __repr__(self):
inplace_str = 'inplace' if self.inplace else ''
return self.__class__.__name__ + ' (' \
+ inplace_str + ')'
inplace_str = "inplace" if self.inplace else ""
return self.__class__.__name__ + " (" + inplace_str + ")"


class BasicBlockERes2Net(nn.Module):

@@ -32,28 +30,28 @@ class BasicBlockERes2Net(nn.Module):

def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
super(BasicBlockERes2Net, self).__init__()
width = int(math.floor(planes*(baseWidth/64.0)))
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width*scale)
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale

convs=[]
bns=[]
convs = []
bns = []
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.relu = ReLU(inplace=True)

self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes*self.expansion)

self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes))
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes),
)
self.stride = stride
self.width = width
self.scale = scale
@@ -64,18 +62,18 @@ class BasicBlockERes2Net(nn.Module):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out,self.width,1)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i==0:
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i==0:
out = sp
else:
out = torch.cat((out,sp),1)
if i == 0:
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i == 0:
out = sp
else:
out = torch.cat((out, sp), 1)

out = self.conv3(out)
out = self.bn3(out)

@@ -86,22 +84,23 @@ class BasicBlockERes2Net(nn.Module):

return out


class BasicBlockERes2Net_diff_AFF(nn.Module):
expansion = 2

def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
super(BasicBlockERes2Net_diff_AFF, self).__init__()
width = int(math.floor(planes*(baseWidth/64.0)))
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width*scale)
width = int(math.floor(planes * (baseWidth / 64.0)))
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width * scale)
self.nums = scale

convs=[]
fuse_models=[]
bns=[]
convs = []
fuse_models = []
bns = []
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
for j in range(self.nums - 1):
fuse_models.append(AFF(channels=width))

@@ -109,15 +108,15 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
self.bns = nn.ModuleList(bns)
self.fuse_models = nn.ModuleList(fuse_models)
self.relu = ReLU(inplace=True)

self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes*self.expansion)

self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes))
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes),
)
self.stride = stride
self.width = width
self.scale = scale
@@ -128,19 +127,19 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out,self.width,1)
spx = torch.split(out, self.width, 1)
for i in range(self.nums):
if i==0:
if i == 0:
sp = spx[i]
else:
sp = self.fuse_models[i-1](sp, spx[i])

sp = self.fuse_models[i - 1](sp, spx[i])

sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i==0:
if i == 0:
out = sp
else:
out = torch.cat((out,sp),1)
out = torch.cat((out, sp), 1)

out = self.conv3(out)
out = self.bn3(out)

@@ -151,16 +150,19 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):

return out


class ERes2Net(nn.Module):
def __init__(self,
block=BasicBlockERes2Net,
block_fuse=BasicBlockERes2Net_diff_AFF,
num_blocks=[3, 4, 6, 3],
m_channels=32,
feat_dim=80,
embedding_size=192,
pooling_func='TSTP',
two_emb_layer=False):
def __init__(
self,
block=BasicBlockERes2Net,
block_fuse=BasicBlockERes2Net_diff_AFF,
num_blocks=[3, 4, 6, 3],
m_channels=32,
feat_dim=80,
embedding_size=192,
pooling_func="TSTP",
two_emb_layer=False,
):
super(ERes2Net, self).__init__()
self.in_planes = m_channels
self.feat_dim = feat_dim

@@ -176,20 +178,24 @@ class ERes2Net(nn.Module):
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)

# Downsampling module for each layer
self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False)
self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
self.layer1_downsample = nn.Conv2d(
m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False
)
self.layer2_downsample = nn.Conv2d(
m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
)
self.layer3_downsample = nn.Conv2d(
m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
)

# Bottom-up fusion module
self.fuse_mode12 = AFF(channels=m_channels * 4)
self.fuse_mode123 = AFF(channels=m_channels * 8)
self.fuse_mode1234 = AFF(channels=m_channels * 16)

self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
embedding_size)
self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
self.seg_2 = nn.Linear(embedding_size, embedding_size)

@@ -212,7 +218,7 @@ class ERes2Net(nn.Module):
out1 = self.layer1(out)
out2 = self.layer2(out1)
out1_downsample = self.layer1_downsample(out1)
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
out3 = self.layer3(out2)
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)

@@ -243,18 +249,16 @@ class ERes2Net(nn.Module):
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2).mean(-1)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
return fuse_out1234


if __name__ == '__main__':

if __name__ == "__main__":
x = torch.zeros(10, 300, 80)
model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func='TSTP')
model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func="TSTP")
model.eval()
out = model(x)
print(out.shape) # torch.Size([10, 192])
print(out.shape) # torch.Size([10, 192])

num_params = sum(param.numel() for param in model.parameters())
print("{} M".format(num_params / 1e6)) # 6.61M

print("{} M".format(num_params / 1e6)) # 6.61M
@ -1,14 +1,12 @@
|
||||
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""
|
||||
To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension
|
||||
within each stage. However, this modification also increases the number of model parameters and computational complexity.
|
||||
To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundant structures, ultimately reducing
|
||||
both the model parameters and its computational cost.
|
||||
"""
|
||||
|
||||
|
||||
To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension
|
||||
within each stage. However, this modification also increases the number of model parameters and computational complexity.
|
||||
To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundant structures, ultimately reducing
|
||||
both the model parameters and its computational cost.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
@ -17,47 +15,42 @@ import torch.nn.functional as F
|
||||
import pooling_layers as pooling_layers
|
||||
from fusion import AFF
|
||||
|
||||
class ReLU(nn.Hardtanh):
|
||||
|
||||
class ReLU(nn.Hardtanh):
|
||||
def __init__(self, inplace=False):
|
||||
super(ReLU, self).__init__(0, 20, inplace)
|
||||
|
||||
def __repr__(self):
|
||||
inplace_str = 'inplace' if self.inplace else ''
|
||||
return self.__class__.__name__ + ' (' \
|
||||
+ inplace_str + ')'
|
||||
inplace_str = "inplace" if self.inplace else ""
|
||||
return self.__class__.__name__ + " (" + inplace_str + ")"
|
||||
|
||||
|
||||
class BasicBlockERes2NetV2(nn.Module):
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
|
||||
super(BasicBlockERes2NetV2, self).__init__()
|
||||
width = int(math.floor(planes*(baseWidth/64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width*scale)
|
||||
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width * scale)
|
||||
self.nums = scale
|
||||
self.expansion = expansion
|
||||
|
||||
convs=[]
|
||||
bns=[]
|
||||
convs = []
|
||||
bns = []
|
||||
for i in range(self.nums):
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList(bns)
|
||||
self.relu = ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
|
||||
|
||||
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes,
|
||||
self.expansion * planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes))
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes),
|
||||
)
|
||||
self.stride = stride
|
||||
self.width = width
|
||||
self.scale = scale
|
||||
@ -68,18 +61,18 @@ class BasicBlockERes2NetV2(nn.Module):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = torch.split(out,self.width,1)
|
||||
spx = torch.split(out, self.width, 1)
|
||||
for i in range(self.nums):
|
||||
if i==0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = sp + spx[i]
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i==0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out,sp),1)
|
||||
if i == 0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = sp + spx[i]
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i == 0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out, sp), 1)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
@ -90,22 +83,22 @@ class BasicBlockERes2NetV2(nn.Module):
|
||||
|
||||
return out
|
||||
|
||||
class BasicBlockERes2NetV2AFF(nn.Module):
|
||||
|
||||
class BasicBlockERes2NetV2AFF(nn.Module):
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
|
||||
super(BasicBlockERes2NetV2AFF, self).__init__()
|
||||
width = int(math.floor(planes*(baseWidth/64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width*scale)
|
||||
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width * scale)
|
||||
self.nums = scale
|
||||
self.expansion = expansion
|
||||
|
||||
convs=[]
|
||||
fuse_models=[]
|
||||
bns=[]
|
||||
convs = []
|
||||
fuse_models = []
|
||||
bns = []
|
||||
for i in range(self.nums):
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
for j in range(self.nums - 1):
|
||||
fuse_models.append(AFF(channels=width, r=4))
|
||||
|
||||
@ -113,18 +106,15 @@ class BasicBlockERes2NetV2AFF(nn.Module):
|
||||
self.bns = nn.ModuleList(bns)
|
||||
self.fuse_models = nn.ModuleList(fuse_models)
|
||||
self.relu = ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
|
||||
|
||||
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes,
|
||||
self.expansion * planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes))
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes),
|
||||
)
|
||||
self.stride = stride
|
||||
self.width = width
|
||||
self.scale = scale
|
||||
@ -135,19 +125,19 @@ class BasicBlockERes2NetV2AFF(nn.Module):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = torch.split(out,self.width,1)
|
||||
spx = torch.split(out, self.width, 1)
|
||||
for i in range(self.nums):
|
||||
if i==0:
|
||||
if i == 0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = self.fuse_models[i-1](sp, spx[i])
|
||||
|
||||
sp = self.fuse_models[i - 1](sp, spx[i])
|
||||
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i==0:
|
||||
if i == 0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out,sp),1)
|
||||
out = torch.cat((out, sp), 1)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
@ -158,19 +148,22 @@ class BasicBlockERes2NetV2AFF(nn.Module):
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ERes2NetV2(nn.Module):
|
||||
def __init__(self,
|
||||
block=BasicBlockERes2NetV2,
|
||||
block_fuse=BasicBlockERes2NetV2AFF,
|
||||
num_blocks=[3, 4, 6, 3],
|
||||
m_channels=64,
|
||||
feat_dim=80,
|
||||
embedding_size=192,
|
||||
baseWidth=26,
|
||||
scale=2,
|
||||
expansion=2,
|
||||
pooling_func='TSTP',
|
||||
two_emb_layer=False):
|
||||
def __init__(
|
||||
self,
|
||||
block=BasicBlockERes2NetV2,
|
||||
block_fuse=BasicBlockERes2NetV2AFF,
|
||||
num_blocks=[3, 4, 6, 3],
|
||||
m_channels=64,
|
||||
feat_dim=80,
|
||||
embedding_size=192,
|
||||
baseWidth=26,
|
||||
scale=2,
|
||||
expansion=2,
|
||||
pooling_func="TSTP",
|
||||
two_emb_layer=False,
|
||||
):
|
||||
super(ERes2NetV2, self).__init__()
|
||||
self.in_planes = m_channels
|
||||
self.feat_dim = feat_dim
|
||||
@ -181,42 +174,29 @@ class ERes2NetV2(nn.Module):
|
||||
self.scale = scale
|
||||
self.expansion = expansion
|
||||
|
||||
self.conv1 = nn.Conv2d(1,
|
||||
m_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False)
|
||||
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(m_channels)
|
||||
self.layer1 = self._make_layer(block,
|
||||
m_channels,
|
||||
num_blocks[0],
|
||||
stride=1)
|
||||
self.layer2 = self._make_layer(block,
|
||||
m_channels * 2,
|
||||
num_blocks[1],
|
||||
stride=2)
|
||||
self.layer3 = self._make_layer(block_fuse,
|
||||
m_channels * 4,
|
||||
num_blocks[2],
|
||||
stride=2)
|
||||
self.layer4 = self._make_layer(block_fuse,
|
||||
m_channels * 8,
|
||||
num_blocks[3],
|
||||
stride=2)
|
||||
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
|
||||
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
|
||||
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
|
||||
|
||||
# Downsampling module
|
||||
self.layer3_ds = nn.Conv2d(m_channels * 4 * self.expansion, m_channels * 8 * self.expansion, kernel_size=3, \
|
||||
padding=1, stride=2, bias=False)
|
||||
self.layer3_ds = nn.Conv2d(
|
||||
m_channels * 4 * self.expansion,
|
||||
m_channels * 8 * self.expansion,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=2,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
# Bottom-up fusion module
|
||||
self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
|
||||
|
||||
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
|
||||
self.pool = getattr(pooling_layers, pooling_func)(
|
||||
in_dim=self.stats_dim * self.expansion)
|
||||
self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
|
||||
embedding_size)
|
||||
self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
|
||||
self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * self.expansion)
|
||||
self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats, embedding_size)
|
||||
if self.two_emb_layer:
|
||||
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
||||
self.seg_2 = nn.Linear(embedding_size, embedding_size)
|
||||
@ -228,7 +208,11 @@ class ERes2NetV2(nn.Module):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion))
|
||||
layers.append(
|
||||
block(
|
||||
self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion
|
||||
)
|
||||
)
|
||||
self.in_planes = planes * self.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
@ -264,7 +248,7 @@ class ERes2NetV2(nn.Module):
|
||||
out3_ds = self.layer3_ds(out3)
|
||||
fuse_out34 = self.fuse34(out4, out3_ds)
|
||||
# print(111111111,fuse_out34.shape)#111111111 torch.Size([16, 2048, 10, 72])
|
||||
return fuse_out34.flatten(start_dim=1,end_dim=2).mean(-1)
|
||||
return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
|
||||
# stats = self.pool(fuse_out34)
|
||||
#
|
||||
# embed_a = self.seg_1(stats)
|
||||
@ -276,17 +260,13 @@ class ERes2NetV2(nn.Module):
|
||||
# else:
|
||||
# return embed_a
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.randn(1, 300, 80)
|
||||
model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
|
||||
model.eval()
|
||||
y = model(x)
|
||||
print(y.size())
|
||||
macs, num_params = profile(model, inputs=(x, ))
|
||||
print("Params: {} M".format(num_params / 1e6)) # 17.86 M
|
||||
print("MACs: {} G".format(macs / 1e9)) # 12.69 G
|
||||
|
||||
|
||||
|
||||
|
||||
macs, num_params = profile(model, inputs=(x,))
|
||||
print("Params: {} M".format(num_params / 1e6)) # 17.86 M
|
||||
print("MACs: {} G".format(macs / 1e9)) # 12.69 G
|
||||
|
@ -1,14 +1,13 @@
|
||||
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
|
||||
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
|
||||
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
|
||||
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
|
||||
ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
|
||||
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
|
||||
"""Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
|
||||
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
|
||||
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
|
||||
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
|
||||
ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
|
||||
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
|
||||
"""
|
||||
import pdb
|
||||
|
||||
import torch
|
||||
import math
|
||||
@ -17,15 +16,14 @@ import torch.nn.functional as F
|
||||
import pooling_layers as pooling_layers
|
||||
from fusion import AFF
|
||||
|
||||
class ReLU(nn.Hardtanh):
|
||||
|
||||
class ReLU(nn.Hardtanh):
|
||||
def __init__(self, inplace=False):
|
||||
super(ReLU, self).__init__(0, 20, inplace)
|
||||
|
||||
def __repr__(self):
|
||||
inplace_str = 'inplace' if self.inplace else ''
|
||||
return self.__class__.__name__ + ' (' \
|
||||
+ inplace_str + ')'
|
||||
inplace_str = "inplace" if self.inplace else ""
|
||||
return self.__class__.__name__ + " (" + inplace_str + ")"
|
||||
|
||||
|
||||
class BasicBlockERes2Net(nn.Module):
|
||||
@ -33,27 +31,28 @@ class BasicBlockERes2Net(nn.Module):
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
|
||||
super(BasicBlockERes2Net, self).__init__()
|
||||
width = int(math.floor(planes*(baseWidth/64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width*scale)
|
||||
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width * scale)
|
||||
self.nums = scale
|
||||
|
||||
convs=[]
|
||||
bns=[]
|
||||
convs = []
|
||||
bns = []
|
||||
for i in range(self.nums):
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList(bns)
|
||||
self.relu = ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
|
||||
|
||||
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes))
|
||||
nn.BatchNorm2d(self.expansion * planes),
|
||||
)
|
||||
self.stride = stride
|
||||
self.width = width
|
||||
self.scale = scale
|
||||
@ -64,18 +63,18 @@ class BasicBlockERes2Net(nn.Module):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = torch.split(out,self.width,1)
|
||||
spx = torch.split(out, self.width, 1)
|
||||
for i in range(self.nums):
|
||||
if i==0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = sp + spx[i]
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i==0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out,sp),1)
|
||||
if i == 0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = sp + spx[i]
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i == 0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out, sp), 1)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
@ -86,22 +85,23 @@ class BasicBlockERes2Net(nn.Module):
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class BasicBlockERes2Net_diff_AFF(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
|
||||
super(BasicBlockERes2Net_diff_AFF, self).__init__()
|
||||
width = int(math.floor(planes*(baseWidth/64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width*scale)
|
||||
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width * scale)
|
||||
self.nums = scale
|
||||
|
||||
convs=[]
|
||||
fuse_models=[]
|
||||
bns=[]
|
||||
convs = []
|
||||
fuse_models = []
|
||||
bns = []
|
||||
for i in range(self.nums):
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
for j in range(self.nums - 1):
|
||||
fuse_models.append(AFF(channels=width))
|
||||
|
||||
@ -109,14 +109,15 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
|
||||
self.bns = nn.ModuleList(bns)
|
||||
self.fuse_models = nn.ModuleList(fuse_models)
|
||||
self.relu = ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
|
||||
|
||||
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes))
|
||||
nn.BatchNorm2d(self.expansion * planes),
|
||||
)
|
||||
self.stride = stride
|
||||
self.width = width
|
||||
self.scale = scale
|
||||
@ -127,20 +128,19 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = torch.split(out,self.width,1)
|
||||
spx = torch.split(out, self.width, 1)
|
||||
for i in range(self.nums):
|
||||
if i==0:
|
||||
if i == 0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = self.fuse_models[i-1](sp, spx[i])
|
||||
|
||||
sp = self.fuse_models[i - 1](sp, spx[i])
|
||||
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i==0:
|
||||
if i == 0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out,sp),1)
|
||||
|
||||
out = torch.cat((out, sp), 1)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
@ -151,16 +151,19 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ERes2Net(nn.Module):
|
||||
def __init__(self,
|
||||
block=BasicBlockERes2Net,
|
||||
block_fuse=BasicBlockERes2Net_diff_AFF,
|
||||
num_blocks=[3, 4, 6, 3],
|
||||
m_channels=64,
|
||||
feat_dim=80,
|
||||
embedding_size=192,
|
||||
pooling_func='TSTP',
|
||||
two_emb_layer=False):
|
||||
def __init__(
|
||||
self,
|
||||
block=BasicBlockERes2Net,
|
||||
block_fuse=BasicBlockERes2Net_diff_AFF,
|
||||
num_blocks=[3, 4, 6, 3],
|
||||
m_channels=64,
|
||||
feat_dim=80,
|
||||
embedding_size=192,
|
||||
pooling_func="TSTP",
|
||||
two_emb_layer=False,
|
||||
):
|
||||
super(ERes2Net, self).__init__()
|
||||
self.in_planes = m_channels
|
||||
self.feat_dim = feat_dim
|
||||
@ -176,17 +179,22 @@ class ERes2Net(nn.Module):
|
||||
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
|
||||
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
|
||||
|
||||
self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
|
||||
self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
|
||||
self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False)
|
||||
self.layer1_downsample = nn.Conv2d(
|
||||
m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
|
||||
)
|
||||
self.layer2_downsample = nn.Conv2d(
|
||||
m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
|
||||
)
|
||||
self.layer3_downsample = nn.Conv2d(
|
||||
m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False
|
||||
)
|
||||
|
||||
self.fuse_mode12 = AFF(channels=m_channels * 8)
|
||||
self.fuse_mode123 = AFF(channels=m_channels * 16)
|
||||
self.fuse_mode1234 = AFF(channels=m_channels * 32)
|
||||
|
||||
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
|
||||
self.pool = getattr(pooling_layers, pooling_func)(
|
||||
in_dim=self.stats_dim * block.expansion)
|
||||
self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
|
||||
self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
|
||||
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
|
||||
if self.two_emb_layer:
|
||||
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
||||
@ -229,7 +237,7 @@ class ERes2Net(nn.Module):
|
||||
else:
|
||||
return embed_a
|
||||
|
||||
def forward2(self, x,if_mean):
|
||||
def forward2(self, x, if_mean):
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
|
||||
x = x.unsqueeze_(1)
|
||||
@ -243,14 +251,13 @@ class ERes2Net(nn.Module):
|
||||
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
||||
out4 = self.layer4(out3)
|
||||
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2)#bs,20480,T
|
||||
if(if_mean==False):
|
||||
mean=fuse_out1234[0].transpose(1,0)#(T,20480),bs=T
|
||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2) # bs,20480,T
|
||||
if if_mean == False:
|
||||
mean = fuse_out1234[0].transpose(1, 0) # (T,20480),bs=T
|
||||
else:
|
||||
mean = fuse_out1234.mean(2)#bs,20480
|
||||
mean_std=torch.cat([mean,torch.zeros_like(mean)],1)
|
||||
return self.seg_1(mean_std)#(T,192)
|
||||
|
||||
mean = fuse_out1234.mean(2) # bs,20480
|
||||
mean_std = torch.cat([mean, torch.zeros_like(mean)], 1)
|
||||
return self.seg_1(mean_std) # (T,192)
|
||||
|
||||
# stats = self.pool(fuse_out1234)
|
||||
# if self.two_emb_layer:
|
||||
@ -275,12 +282,8 @@ class ERes2Net(nn.Module):
|
||||
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
||||
out4 = self.layer4(out3)
|
||||
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2).mean(-1)
|
||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
|
||||
return fuse_out1234
|
||||
# print(fuse_out1234.shape)
|
||||
# print(fuse_out1234.flatten(start_dim=1,end_dim=2).shape)
|
||||
# pdb.set_trace()
|
||||
|
||||
|
||||
|
||||
|
||||
|
@@ -6,7 +6,6 @@ import torch.nn as nn


class AFF(nn.Module):

def __init__(self, channels=64, r=4):
super(AFF, self).__init__()
inter_channels = int(channels // r)

@@ -23,7 +22,6 @@ class AFF(nn.Module):
xa = torch.cat((x, ds_y), dim=1)
x_att = self.local_att(xa)
x_att = 1.0 + torch.tanh(x_att)
xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0-x_att)
xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)

return xo
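Note: the AFF change above only reflows the fusion line; the operation itself is an attention-weighted sum of the two inputs, x * att + ds_y * (2 - att) with att = 1 + tanh(local_att(cat(x, ds_y))). A minimal sketch of that fusion; the internal bottleneck layers (SiLU, channel reduction factor) are assumptions, not the repository's exact local_att:

    import torch
    import torch.nn as nn

    class TinyAFF(nn.Module):
        def __init__(self, channels=64, r=4):
            super().__init__()
            inter = channels // r
            # 1x1-conv bottleneck producing a per-position attention map (assumed structure)
            self.local_att = nn.Sequential(
                nn.Conv2d(channels * 2, inter, kernel_size=1),
                nn.BatchNorm2d(inter),
                nn.SiLU(),
                nn.Conv2d(inter, channels, kernel_size=1),
                nn.BatchNorm2d(channels),
            )

        def forward(self, x, ds_y):
            x_att = 1.0 + torch.tanh(self.local_att(torch.cat((x, ds_y), dim=1)))
            return x * x_att + ds_y * (2.0 - x_att)  # the two weights sum to 2, as in the diff

    out = TinyAFF(channels=8)(torch.randn(1, 8, 4, 4), torch.randn(1, 8, 4, 4))
    print(out.shape)  # torch.Size([1, 8, 4, 4])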
@@ -144,7 +144,7 @@ def _get_waveform_and_window_properties(
)
assert 0 < window_shift, "`window_shift` must be greater than 0"
assert padded_window_size % 2 == 0, (
"the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
"the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`"
)
assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
assert sample_frequency > 0, "`sample_frequency` must be greater than zero"

@@ -441,7 +441,9 @@ def get_mel_banks(
high_freq: float,
vtln_low: float,
vtln_high: float,
vtln_warp_factor: float,device=None,dtype=None
vtln_warp_factor: float,
device=None,
dtype=None,
) -> Tuple[Tensor, Tensor]:
"""
Returns:

@@ -457,9 +459,9 @@ def get_mel_banks(
if high_freq <= 0.0:
high_freq += nyquist

assert (
(0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), (
"Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
)

# fft-bin width [think of it as Nyquist-freq / half-window-length]
fft_bin_width = sample_freq / window_length_padded

@@ -475,7 +477,7 @@ def get_mel_banks(

assert vtln_warp_factor == 1.0 or (
(low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
), "Bad values in options: vtln-low {} and vtln-high {}, versus low-freq {} and high-freq {}".format(
vtln_low, vtln_high, low_freq, high_freq
)

@@ -508,9 +510,12 @@ def get_mel_banks(
bins[up_idx] = up_slope[up_idx]
bins[down_idx] = down_slope[down_idx]

return bins.to(device=device,dtype=dtype)#, center_freqs
return bins.to(device=device, dtype=dtype)  # , center_freqs


cache = {}


cache={}
def fbank(
waveform: Tensor,
blackman_coeff: float = 0.42,

@@ -620,14 +625,34 @@ def fbank(
# size (num_mel_bins, padded_window_size // 2)
# print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp)

cache_key="%s-%s-%s-%s-%s-%s-%s-%s-%s-%s"%(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp,device,dtype)
cache_key = "%s-%s-%s-%s-%s-%s-%s-%s-%s-%s" % (
num_mel_bins,
padded_window_size,
sample_frequency,
low_freq,
high_freq,
vtln_low,
vtln_high,
vtln_warp,
device,
dtype,
)
if cache_key not in cache:
mel_energies = get_mel_banks(
num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp,device,dtype
num_mel_bins,
padded_window_size,
sample_frequency,
low_freq,
high_freq,
vtln_low,
vtln_high,
vtln_warp,
device,
dtype,
)
cache[cache_key]=mel_energies
cache[cache_key] = mel_energies
else:
mel_energies=cache[cache_key]
mel_energies = cache[cache_key]

# pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
@@ -1,7 +1,7 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

""" This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
"""This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""

import torch
import torch.nn as nn

@@ -11,6 +11,7 @@ class TAP(nn.Module):
"""
Temporal average pooling, only first-order mean is considered
"""

def __init__(self, **kwargs):
super(TAP, self).__init__()

@@ -25,6 +26,7 @@ class TSDP(nn.Module):
"""
Temporal standard deviation pooling, only second-order std is considered
"""

def __init__(self, **kwargs):
super(TSDP, self).__init__()

@@ -41,6 +43,7 @@ class TSTP(nn.Module):
x-vector
Comment: simple concatenation can not make full use of both statistics
"""

def __init__(self, **kwargs):
super(TSTP, self).__init__()

@@ -56,9 +59,10 @@ class TSTP(nn.Module):


class ASTP(nn.Module):
""" Attentive statistics pooling: Channel- and context-dependent
statistics pooling, first used in ECAPA_TDNN.
"""Attentive statistics pooling: Channel- and context-dependent
statistics pooling, first used in ECAPA_TDNN.
"""

def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
super(ASTP, self).__init__()
self.global_context_att = global_context_att

@@ -66,15 +70,10 @@ class ASTP(nn.Module):
# Use Conv1d with stride == 1 rather than Linear, then we don't
# need to transpose inputs.
if global_context_att:
self.linear1 = nn.Conv1d(
in_dim * 3, bottleneck_dim,
kernel_size=1) # equals W and b in the paper
self.linear1 = nn.Conv1d(in_dim * 3, bottleneck_dim, kernel_size=1) # equals W and b in the paper
else:
self.linear1 = nn.Conv1d(
in_dim, bottleneck_dim,
kernel_size=1) # equals W and b in the paper
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
kernel_size=1) # equals V and k in the paper
self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1) # equals W and b in the paper
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1) # equals V and k in the paper

def forward(self, x):
"""

@@ -88,15 +87,13 @@ class ASTP(nn.Module):

if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt(
torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
x_in = torch.cat((x, context_mean, context_std), dim=1)
else:
x_in = x

# DON'T use ReLU here! ReLU may be hard to converge.
alpha = torch.tanh(
self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
alpha = torch.tanh(self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2)
var = torch.sum(alpha * (x**2), dim=2) - mean**2
@ -1,6 +1,7 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import argparse
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from my_utils import load_audio
|
||||
import torch
|
||||
@ -17,6 +18,9 @@ from module.models_onnx import SynthesizerTrn
|
||||
|
||||
from inference_webui import get_phones_and_bert
|
||||
|
||||
from sv import SV
|
||||
import kaldi as Kaldi
|
||||
|
||||
import os
|
||||
import soundfile
|
||||
|
||||
@ -32,6 +36,25 @@ default_config = {
|
||||
"EOS": 1024,
|
||||
}
|
||||
|
||||
sv_cn_model = None
|
||||
|
||||
|
||||
def init_sv_cn(device, is_half):
|
||||
global sv_cn_model
|
||||
sv_cn_model = SV(device, is_half)
|
||||
|
||||
|
||||
def load_sovits_new(sovits_path):
|
||||
f = open(sovits_path, "rb")
|
||||
meta = f.read(2)
|
||||
if meta != b"PK":
|
||||
data = b"PK" + f.read()
|
||||
bio = BytesIO()
|
||||
bio.write(data)
|
||||
bio.seek(0)
|
||||
return torch.load(bio, map_location="cpu", weights_only=False)
|
||||
return torch.load(sovits_path, map_location="cpu", weights_only=False)
|
||||
|
||||
|
||||
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
|
||||
config = dict_s1["config"]
|
||||
@ -83,7 +106,7 @@ def logits_to_probs(
|
||||
@torch.jit.script
|
||||
def multinomial_sample_one_no_sync(probs_sort):
|
||||
# Does multinomial sampling without a cuda synchronization
|
||||
q = torch.randn_like(probs_sort)
|
||||
q = torch.empty_like(probs_sort).exponential_(1.0)
|
||||
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||
|
||||
|
||||
@ -94,7 +117,7 @@ def sample(
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
repetition_penalty: float = 1.0,
|
||||
repetition_penalty: float = 1.35,
|
||||
):
|
||||
probs = logits_to_probs(
|
||||
logits=logits,
|
||||
@ -109,8 +132,10 @@ def sample(
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
|
||||
hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
|
||||
def spectrogram_torch(
|
||||
hann_window: Tensor, y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False
|
||||
):
|
||||
# hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||
@ -289,8 +314,9 @@ class T2SBlock:
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||
# attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
||||
# attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
|
||||
attn = F.linear(attn, self.out_w, self.out_b)
|
||||
|
||||
x = x + attn
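# The rewritten reshape keeps the scaled_dot_product_attention output in its native
# (batch, heads, q_len, head_dim) layout and flattens the heads directly. An illustrative
# check that it yields the same tensor as the older permute/view path:
import torch

batch_size, n_heads, q_len, head_dim = 2, 4, 7, 16
hidden_dim = n_heads * head_dim
attn = torch.randn(batch_size, n_heads, q_len, head_dim)

old = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, hidden_dim)
old = old.view(q_len, batch_size, hidden_dim).transpose(1, 0)
new = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
print(torch.equal(old, new))   # True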
|
||||
@ -328,15 +354,22 @@ class T2STransformer:
|
||||
|
||||
|
||||
class VitsModel(nn.Module):
|
||||
def __init__(self, vits_path):
|
||||
def __init__(self, vits_path, version=None, is_half=True, device="cpu"):
|
||||
super().__init__()
|
||||
# dict_s2 = torch.load(vits_path,map_location="cpu")
|
||||
dict_s2 = torch.load(vits_path, weights_only=False)
|
||||
dict_s2 = load_sovits_new(vits_path)
|
||||
self.hps = dict_s2["config"]
|
||||
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
||||
self.hps["model"]["version"] = "v1"
|
||||
|
||||
if version is None:
|
||||
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
||||
self.hps["model"]["version"] = "v1"
|
||||
else:
|
||||
self.hps["model"]["version"] = "v2"
|
||||
else:
|
||||
self.hps["model"]["version"] = "v2"
|
||||
if version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]:
|
||||
self.hps["model"]["version"] = version
|
||||
else:
|
||||
raise ValueError(f"Unsupported version: {version}")
|
||||
|
||||
self.hps = DictToAttrRecursive(self.hps)
|
||||
self.hps.model.semantic_frame_rate = "25hz"
|
||||
@ -346,11 +379,19 @@ class VitsModel(nn.Module):
|
||||
n_speakers=self.hps.data.n_speakers,
|
||||
**self.hps.model,
|
||||
)
|
||||
self.vq_model.eval()
|
||||
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
self.vq_model.dec.remove_weight_norm()
|
||||
if is_half:
|
||||
self.vq_model = self.vq_model.half()
|
||||
self.vq_model = self.vq_model.to(device)
|
||||
self.vq_model.eval()
|
||||
self.hann_window = torch.hann_window(
|
||||
self.hps.data.win_length, device=device, dtype=torch.float16 if is_half else torch.float32
|
||||
)
|
||||
|
||||
def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0):
|
||||
def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0, sv_emb=None):
|
||||
refer = spectrogram_torch(
|
||||
self.hann_window,
|
||||
ref_audio,
|
||||
self.hps.data.filter_length,
|
||||
self.hps.data.sampling_rate,
|
||||
@ -358,7 +399,7 @@ class VitsModel(nn.Module):
|
||||
self.hps.data.win_length,
|
||||
center=False,
|
||||
)
|
||||
return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
|
||||
return self.vq_model(pred_semantic, text_seq, refer, speed=speed, sv_emb=sv_emb)[0, 0]
|
||||
|
||||
|
||||
class T2SModel(nn.Module):
|
||||
@ -632,7 +673,9 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
||||
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
|
||||
ref_bert = ref_bert_T.T.to(ref_seq.device)
|
||||
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
|
||||
"这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
|
||||
"这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
|
||||
"auto",
|
||||
"v2",
|
||||
)
|
||||
text_seq = torch.LongTensor([text_seq_id]).to(device)
|
||||
text_bert = text_bert_T.T.to(text_seq.device)
|
||||
@ -640,7 +683,7 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
||||
ssl_content = ssl(ref_audio).to(device)
|
||||
|
||||
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||
vits = VitsModel(vits_path).to(device)
|
||||
vits = VitsModel(vits_path, device=device, is_half=False)
|
||||
vits.eval()
|
||||
|
||||
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
||||
@ -679,6 +722,124 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
||||
print("#### exported gpt_sovits ####")
|
||||
|
||||
|
||||
def export_prov2(
|
||||
gpt_path,
|
||||
vits_path,
|
||||
version,
|
||||
ref_audio_path,
|
||||
ref_text,
|
||||
output_path,
|
||||
export_bert_and_ssl=False,
|
||||
device="cpu",
|
||||
is_half=True,
|
||||
):
|
||||
if sv_cn_model == None:
|
||||
init_sv_cn(device, is_half)
|
||||
|
||||
if not os.path.exists(output_path):
|
||||
os.makedirs(output_path)
|
||||
print(f"目录已创建: {output_path}")
|
||||
else:
|
||||
print(f"目录已存在: {output_path}")
|
||||
|
||||
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
|
||||
ssl = SSLModel()
|
||||
if export_bert_and_ssl:
|
||||
s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
|
||||
ssl_path = os.path.join(output_path, "ssl_model.pt")
|
||||
torch.jit.script(s).save(ssl_path)
|
||||
print("#### exported ssl ####")
|
||||
export_bert(output_path)
|
||||
else:
|
||||
s = ExportSSLModel(ssl)
|
||||
|
||||
print(f"device: {device}")
|
||||
|
||||
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
|
||||
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
|
||||
ref_bert = ref_bert_T.T
|
||||
if is_half:
|
||||
ref_bert = ref_bert.half()
|
||||
ref_bert = ref_bert.to(ref_seq.device)
|
||||
|
||||
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
|
||||
"这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
|
||||
"auto",
|
||||
"v2",
|
||||
)
|
||||
text_seq = torch.LongTensor([text_seq_id]).to(device)
|
||||
text_bert = text_bert_T.T
|
||||
if is_half:
|
||||
text_bert = text_bert.half()
|
||||
text_bert = text_bert.to(text_seq.device)
|
||||
|
||||
ssl_content = ssl(ref_audio)
|
||||
if is_half:
|
||||
ssl_content = ssl_content.half()
|
||||
ssl_content = ssl_content.to(device)
|
||||
|
||||
sv_model = ExportERes2NetV2(sv_cn_model)
|
||||
|
||||
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||
vits = VitsModel(vits_path, version, is_half=is_half, device=device)
|
||||
vits.eval()
|
||||
|
||||
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
||||
# dict_s1 = torch.load(gpt_path, map_location=device)
|
||||
dict_s1 = torch.load(gpt_path, weights_only=False)
|
||||
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
||||
print("#### get_raw_t2s_model ####")
|
||||
print(raw_t2s.config)
|
||||
if is_half:
|
||||
raw_t2s = raw_t2s.half()
|
||||
t2s_m = T2SModel(raw_t2s)
|
||||
t2s_m.eval()
|
||||
t2s = torch.jit.script(t2s_m).to(device)
|
||||
print("#### script t2s_m ####")
|
||||
|
||||
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
|
||||
gpt_sovits = GPT_SoVITS_V2Pro(t2s, vits, sv_model).to(device)
|
||||
gpt_sovits.eval()
|
||||
|
||||
ref_audio_sr = s.resample(ref_audio, 16000, 32000)
|
||||
if is_half:
|
||||
ref_audio_sr = ref_audio_sr.half()
|
||||
ref_audio_sr = ref_audio_sr.to(device)
|
||||
|
||||
torch._dynamo.mark_dynamic(ssl_content, 2)
|
||||
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
|
||||
torch._dynamo.mark_dynamic(ref_seq, 1)
|
||||
torch._dynamo.mark_dynamic(text_seq, 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert, 0)
|
||||
torch._dynamo.mark_dynamic(text_bert, 0)
|
||||
# torch._dynamo.mark_dynamic(sv_emb, 0)
|
||||
|
||||
top_k = torch.LongTensor([5]).to(device)
|
||||
# 先跑一遍 sv_model 让它加载 cache,详情见 L880
|
||||
gpt_sovits.sv_model(ref_audio_sr)
|
||||
|
||||
with torch.no_grad():
|
||||
gpt_sovits_export = torch.jit.trace(
|
||||
gpt_sovits,
|
||||
example_inputs=(
|
||||
ssl_content,
|
||||
ref_audio_sr,
|
||||
ref_seq,
|
||||
text_seq,
|
||||
ref_bert,
|
||||
text_bert,
|
||||
top_k,
|
||||
),
|
||||
)
|
||||
|
||||
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
|
||||
gpt_sovits_export.save(gpt_sovits_path)
|
||||
print("#### exported gpt_sovits ####")
|
||||
audio = gpt_sovits_export(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||
print("start write wav")
|
||||
soundfile.write("out.wav", audio.float().detach().cpu().numpy(), 32000)
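# Hypothetical direct invocation of export_prov2 above (all paths and the reference text are
# placeholders); main() further down wires the same arguments from the command line whenever
# --version is v2Pro or v2ProPlus.
export_prov2(
    gpt_path="GPT_weights_v2/demo-e15.ckpt",
    vits_path="SoVITS_weights_v2Pro/demo_e8_s216.pth",
    version="v2Pro",
    ref_audio_path="output/denoise_opt/ref.wav",
    ref_text="这是一条参考音频对应的文本。",
    output_path="onnx/demo",
    export_bert_and_ssl=True,
    device="cuda",
    is_half=True,
)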
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def parse_audio(ref_audio):
|
||||
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device)
|
||||
@ -717,6 +878,67 @@ class GPT_SoVITS(nn.Module):
|
||||
return audio
|
||||
|
||||
|
||||
class ExportERes2NetV2(nn.Module):
|
||||
def __init__(self, sv_cn_model: SV):
|
||||
super(ExportERes2NetV2, self).__init__()
|
||||
self.bn1 = sv_cn_model.embedding_model.bn1
|
||||
self.conv1 = sv_cn_model.embedding_model.conv1
|
||||
self.layer1 = sv_cn_model.embedding_model.layer1
|
||||
self.layer2 = sv_cn_model.embedding_model.layer2
|
||||
self.layer3 = sv_cn_model.embedding_model.layer3
|
||||
self.layer4 = sv_cn_model.embedding_model.layer4
|
||||
self.layer3_ds = sv_cn_model.embedding_model.layer3_ds
|
||||
self.fuse34 = sv_cn_model.embedding_model.fuse34
|
||||
|
||||
# audio_16k.shape: [1,N]
|
||||
def forward(self, audio_16k):
|
||||
# 这个 fbank 函数有一个 cache, 不过不要紧,它跟 audio_16k 的长度无关
|
||||
# 只跟 device 和 dtype 有关
|
||||
x = Kaldi.fbank(audio_16k, num_mel_bins=80, sample_frequency=16000, dither=0)
|
||||
x = torch.stack([x])
|
||||
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
x = x.unsqueeze_(1)
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out1 = self.layer1(out)
|
||||
out2 = self.layer2(out1)
|
||||
out3 = self.layer3(out2)
|
||||
out4 = self.layer4(out3)
|
||||
out3_ds = self.layer3_ds(out3)
|
||||
fuse_out34 = self.fuse34(out4, out3_ds)
|
||||
return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
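# Illustrative use of the export wrapper above (assumes init_sv_cn() has already populated
# sv_cn_model). It mirrors SV.compute_embedding3 minus the 32k->16k resampling, so the input
# is a 16 kHz waveform shaped (1, num_samples) and the output is the (1, 20480) sv_emb vector
# consumed elsewhere in this diff.
sv_model = ExportERes2NetV2(sv_cn_model)
audio_16k = torch.randn(1, 16000)    # placeholder one-second clip; real speech expected in practice
print(sv_model(audio_16k).shape)     # expected: torch.Size([1, 20480])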
|
||||
|
||||
|
||||
class GPT_SoVITS_V2Pro(nn.Module):
|
||||
def __init__(self, t2s: T2SModel, vits: VitsModel, sv_model: ExportERes2NetV2):
|
||||
super().__init__()
|
||||
self.t2s = t2s
|
||||
self.vits = vits
|
||||
self.sv_model = sv_model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ssl_content: torch.Tensor,
|
||||
ref_audio_sr: torch.Tensor,
|
||||
ref_seq: Tensor,
|
||||
text_seq: Tensor,
|
||||
ref_bert: Tensor,
|
||||
text_bert: Tensor,
|
||||
top_k: LongTensor,
|
||||
speed=1.0,
|
||||
):
|
||||
codes = self.vits.vq_model.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompts = prompt_semantic.unsqueeze(0)
|
||||
|
||||
audio_16k = resamplex(ref_audio_sr, 32000, 16000).to(ref_audio_sr.dtype)
|
||||
sv_emb = self.sv_model(audio_16k)
|
||||
|
||||
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||
audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed, sv_emb)
|
||||
return audio
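# Minimal consumption sketch (not part of this diff) for the TorchScript artifact exported above.
# "onnx/demo" is a placeholder output_path; the six input tensors are assumed to be prepared
# exactly as in export_prov2 (SSL features, 32 kHz reference audio, phoneme ids, BERT features).
import torch

gpt_sovits = torch.jit.load("onnx/demo/gpt_sovits_model.pt", map_location="cuda")
top_k = torch.LongTensor([5]).to("cuda")
with torch.no_grad():
    audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
# soundfile.write("out.wav", audio.float().cpu().numpy(), 32000)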
|
||||
|
||||
|
||||
def test():
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||
@ -839,23 +1061,37 @@ def main():
|
||||
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
||||
parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
|
||||
parser.add_argument("--device", help="Device to use")
|
||||
parser.add_argument("--version", help="version of the model", default="v2")
|
||||
parser.add_argument("--no-half", action="store_true", help="Do not use half precision for model weights")
|
||||
|
||||
args = parser.parse_args()
|
||||
export(
|
||||
gpt_path=args.gpt_model,
|
||||
vits_path=args.sovits_model,
|
||||
ref_audio_path=args.ref_audio,
|
||||
ref_text=args.ref_text,
|
||||
output_path=args.output_path,
|
||||
device=args.device,
|
||||
export_bert_and_ssl=args.export_common_model,
|
||||
)
|
||||
if args.version in ["v2Pro", "v2ProPlus"]:
|
||||
is_half = not args.no_half
|
||||
print(f"Using half precision: {is_half}")
|
||||
export_prov2(
|
||||
gpt_path=args.gpt_model,
|
||||
vits_path=args.sovits_model,
|
||||
version=args.version,
|
||||
ref_audio_path=args.ref_audio,
|
||||
ref_text=args.ref_text,
|
||||
output_path=args.output_path,
|
||||
export_bert_and_ssl=args.export_common_model,
|
||||
device=args.device,
|
||||
is_half=is_half,
|
||||
)
|
||||
else:
|
||||
export(
|
||||
gpt_path=args.gpt_model,
|
||||
vits_path=args.sovits_model,
|
||||
ref_audio_path=args.ref_audio,
|
||||
ref_text=args.ref_text,
|
||||
output_path=args.output_path,
|
||||
device=args.device,
|
||||
export_bert_and_ssl=args.export_common_model,
|
||||
)
|
||||
|
||||
|
||||
import inference_webui
|
||||
|
||||
if __name__ == "__main__":
|
||||
inference_webui.is_half = False
|
||||
inference_webui.dtype = torch.float32
|
||||
main()
|
||||
with torch.no_grad():
|
||||
main()
|
||||
# test()
|
||||
|
@ -243,6 +243,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
|
||||
self.sampling_rate: int = hps.data.sampling_rate
|
||||
self.hop_length: int = hps.data.hop_length
|
||||
self.win_length: int = hps.data.win_length
|
||||
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -255,6 +256,7 @@ class ExportGPTSovitsHalf(torch.nn.Module):
|
||||
top_k,
|
||||
):
|
||||
refer = spectrogram_torch(
|
||||
self.hann_window,
|
||||
ref_audio_32k,
|
||||
self.filter_length,
|
||||
self.sampling_rate,
|
||||
@ -321,6 +323,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
|
||||
self.sampling_rate: int = hps.data.sampling_rate
|
||||
self.hop_length: int = hps.data.hop_length
|
||||
self.win_length: int = hps.data.win_length
|
||||
self.hann_window = torch.hann_window(self.win_length, device=device, dtype=torch.float32)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -333,6 +336,7 @@ class ExportGPTSovitsV4Half(torch.nn.Module):
|
||||
top_k,
|
||||
):
|
||||
refer = spectrogram_torch(
|
||||
self.hann_window,
|
||||
ref_audio_32k,
|
||||
self.filter_length,
|
||||
self.sampling_rate,
|
||||
@ -402,7 +406,7 @@ class GPTSoVITSV3(torch.nn.Module):
|
||||
chunk_len = 934 - fea_ref.shape[2]
|
||||
wav_gen_list = []
|
||||
idx = 0
|
||||
fea_todo = fea_todo[:,:,:-5]
|
||||
fea_todo = fea_todo[:, :, :-5]
|
||||
wav_gen_length = fea_todo.shape[2] * 256
|
||||
while 1:
|
||||
# current_time = datetime.now()
|
||||
@ -434,7 +438,8 @@ class GPTSoVITSV3(torch.nn.Module):
|
||||
|
||||
wav_gen = torch.cat(wav_gen_list, 2)
|
||||
return wav_gen[0][0][:wav_gen_length]
|
||||
|
||||
|
||||
|
||||
class GPTSoVITSV4(torch.nn.Module):
|
||||
def __init__(self, gpt_sovits_half, cfm, hifigan):
|
||||
super().__init__()
|
||||
@ -461,7 +466,7 @@ class GPTSoVITSV4(torch.nn.Module):
|
||||
chunk_len = 1000 - fea_ref.shape[2]
|
||||
wav_gen_list = []
|
||||
idx = 0
|
||||
fea_todo = fea_todo[:,:,:-10]
|
||||
fea_todo = fea_todo[:, :, :-10]
|
||||
wav_gen_length = fea_todo.shape[2] * 480
|
||||
while 1:
|
||||
# current_time = datetime.now()
|
||||
@ -577,6 +582,7 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
|
||||
|
||||
v3v4set = {"v3", "v4"}
|
||||
|
||||
|
||||
def get_sovits_weights(sovits_path):
|
||||
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
||||
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
||||
@ -699,14 +705,13 @@ def export_cfm(
|
||||
return export_cfm
|
||||
|
||||
|
||||
def export_1(ref_wav_path,ref_wav_text,version="v3"):
|
||||
def export_1(ref_wav_path, ref_wav_text, version="v3"):
|
||||
if version == "v3":
|
||||
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/s2Gv3.pth")
|
||||
init_bigvgan()
|
||||
else:
|
||||
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
|
||||
init_hifigan()
|
||||
|
||||
|
||||
dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
|
||||
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
||||
@ -751,9 +756,7 @@ def export_1(ref_wav_path,ref_wav_text,version="v3"):
|
||||
# phones1, bert1, norm_text1 = get_phones_and_bert(
|
||||
# "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
|
||||
# )
|
||||
phones1, bert1, norm_text1 = get_phones_and_bert(
|
||||
ref_wav_text, "auto", "v3"
|
||||
)
|
||||
phones1, bert1, norm_text1 = get_phones_and_bert(ref_wav_text, "auto", "v3")
|
||||
phones2, bert2, norm_text2 = get_phones_and_bert(
|
||||
"这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
|
||||
"auto",
|
||||
@ -914,7 +917,7 @@ def export_1(ref_wav_path,ref_wav_text,version="v3"):
|
||||
hifigan_model_ = torch.jit.trace(hifigan_model, optimize=True, example_inputs=(cmf_res_rand,))
|
||||
hifigan_model_.save("onnx/ad/hifigan_model.pt")
|
||||
wav_gen = hifigan_model(cmf_res)
|
||||
|
||||
|
||||
print("wav_gen:", wav_gen.shape, wav_gen.dtype)
|
||||
audio = wav_gen[0][0].cpu().detach().numpy()
|
||||
|
||||
@ -1149,7 +1152,7 @@ def export_2(version="v3"):
|
||||
raw_t2s = raw_t2s.half().to(device)
|
||||
t2s_m = T2SModel(raw_t2s).half().to(device)
|
||||
t2s_m.eval()
|
||||
t2s_m = torch.jit.script(t2s_m)
|
||||
t2s_m = torch.jit.script(t2s_m).to(device)
|
||||
t2s_m.eval()
|
||||
# t2s_m.top_k = 15
|
||||
logger.info("t2s_m ok")
|
||||
@ -1201,7 +1204,6 @@ def export_2(version="v3"):
|
||||
gpt_sovits_v3v4 = gpt_sovits_v3 if version == "v3" else gpt_sovits_v4
|
||||
sr = 24000 if version == "v3" else 48000
|
||||
|
||||
|
||||
time.sleep(5)
|
||||
# print("thread:", torch.get_num_threads())
|
||||
# print("thread:", torch.get_num_interop_threads())
|
||||
@ -1212,14 +1214,14 @@ def export_2(version="v3"):
|
||||
"汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
|
||||
gpt_sovits_v3v4,
|
||||
"out.wav",
|
||||
sr
|
||||
sr,
|
||||
)
|
||||
|
||||
test_export(
|
||||
"你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
|
||||
gpt_sovits_v3v4,
|
||||
"out2.wav",
|
||||
sr
|
||||
sr,
|
||||
)
|
||||
|
||||
# test_export(
|
||||
@ -1251,6 +1253,6 @@ def test_export_gpt_sovits_v3():
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
|
||||
# export_2("v4")
|
||||
# export_1("onnx/ad/ref.wav","你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。","v4")
|
||||
export_2("v4")
|
||||
# test_export_gpt_sovits_v3()
|
||||
|
@ -143,9 +143,9 @@ class DiT(nn.Module):
|
||||
drop_audio_cond=False, # cfg for cond audio
|
||||
drop_text=False, # cfg for text
|
||||
# mask: bool["b n"] | None = None, # noqa: F722
|
||||
infer=False, # bool
|
||||
text_cache=None, # torch tensor as text_embed
|
||||
dt_cache=None, # torch tensor as dt
|
||||
infer=False, # bool
|
||||
text_cache=None, # torch tensor as text_embed
|
||||
dt_cache=None, # torch tensor as dt
|
||||
):
|
||||
x = x0.transpose(2, 1)
|
||||
cond = cond0.transpose(2, 1)
|
||||
@ -191,4 +191,4 @@ class DiT(nn.Module):
|
||||
if infer:
|
||||
return output, text_embed, dt
|
||||
else:
|
||||
return output
|
||||
return output
|
||||
|
@ -214,7 +214,7 @@ v3v4set = {"v3", "v4"}
|
||||
|
||||
|
||||
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
|
||||
if "!" in sovits_path:
|
||||
if "!" in sovits_path or "!" in sovits_path:
|
||||
sovits_path = name2sovits_path[sovits_path]
|
||||
global vq_model, hps, version, model_version, dict_language, if_lora_v3
|
||||
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
||||
@ -361,7 +361,7 @@ except:
|
||||
|
||||
|
||||
def change_gpt_weights(gpt_path):
|
||||
if "!" in gpt_path:
|
||||
if "!" in gpt_path or "!" in gpt_path:
|
||||
gpt_path = name2gpt_path[gpt_path]
|
||||
global hz, max_sec, t2s_model, config
|
||||
hz = 50
|
||||
@ -623,6 +623,10 @@ def get_phones_and_bert(text, language, version, final=False):
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if langlist:
|
||||
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
|
@ -114,11 +114,11 @@ tts_config.device = device
|
||||
tts_config.is_half = is_half
|
||||
tts_config.version = version
|
||||
if gpt_path is not None:
|
||||
if "!" in gpt_path:
|
||||
if "!" in gpt_path or "!" in gpt_path:
|
||||
gpt_path = name2gpt_path[gpt_path]
|
||||
tts_config.t2s_weights_path = gpt_path
|
||||
if sovits_path is not None:
|
||||
if "!" in sovits_path:
|
||||
if "!" in sovits_path or "!" in sovits_path:
|
||||
sovits_path = name2sovits_path[sovits_path]
|
||||
tts_config.vits_weights_path = sovits_path
|
||||
if cnhubert_base_path is not None:
|
||||
@ -217,7 +217,7 @@ v3v4set = {"v3", "v4"}
|
||||
|
||||
|
||||
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
|
||||
if "!" in sovits_path:
|
||||
if "!" in sovits_path or "!" in sovits_path:
|
||||
sovits_path = name2sovits_path[sovits_path]
|
||||
global version, model_version, dict_language, if_lora_v3
|
||||
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
||||
@ -283,6 +283,12 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
|
||||
f.write(json.dumps(data))
|
||||
|
||||
|
||||
def change_gpt_weights(gpt_path):
|
||||
if "!" in gpt_path or "!" in gpt_path:
|
||||
gpt_path = name2gpt_path[gpt_path]
|
||||
tts_pipeline.init_t2s_weights(gpt_path)
|
||||
|
||||
|
||||
with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app:
|
||||
gr.HTML(
|
||||
top_html.format(
|
||||
@ -457,7 +463,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
inference_button,
|
||||
],
|
||||
) #
|
||||
GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
|
||||
GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
|
||||
|
||||
with gr.Group():
|
||||
gr.Markdown(
|
||||
|
@ -21,7 +21,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
3) computes spectrograms from audio files.
|
||||
"""
|
||||
|
||||
def __init__(self, hparams, version=None,val=False):
|
||||
def __init__(self, hparams, version=None, val=False):
|
||||
exp_dir = hparams.exp_dir
|
||||
self.path2 = "%s/2-name2text.txt" % exp_dir
|
||||
self.path4 = "%s/4-cnhubert" % exp_dir
|
||||
@ -29,7 +29,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
assert os.path.exists(self.path2)
|
||||
assert os.path.exists(self.path4)
|
||||
assert os.path.exists(self.path5)
|
||||
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
|
||||
self.is_v2Pro = version in {"v2Pro", "v2ProPlus"}
|
||||
if self.is_v2Pro:
|
||||
self.path7 = "%s/7-sv_cn" % exp_dir
|
||||
assert os.path.exists(self.path7)
|
||||
@ -118,7 +118,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
|
||||
ssl.requires_grad = False
|
||||
if self.is_v2Pro:
|
||||
sv_emb=torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu")
|
||||
sv_emb = torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu")
|
||||
except:
|
||||
traceback.print_exc()
|
||||
spec = torch.zeros(1025, 100)
|
||||
@ -126,10 +126,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
ssl = torch.zeros(1, 768, 100)
|
||||
text = text[-1:]
|
||||
if self.is_v2Pro:
|
||||
sv_emb=torch.zeros(1,20480)
|
||||
sv_emb = torch.zeros(1, 20480)
|
||||
print("load audio or ssl error!!!!!!", audiopath)
|
||||
if self.is_v2Pro:
|
||||
return (ssl, spec, wav, text,sv_emb)
|
||||
return (ssl, spec, wav, text, sv_emb)
|
||||
else:
|
||||
return (ssl, spec, wav, text)
|
||||
|
||||
@ -192,9 +192,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
class TextAudioSpeakerCollate:
|
||||
"""Zero-pads model inputs and targets"""
|
||||
|
||||
def __init__(self, return_ids=False,version=None):
|
||||
def __init__(self, return_ids=False, version=None):
|
||||
self.return_ids = return_ids
|
||||
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
|
||||
self.is_v2Pro = version in {"v2Pro", "v2ProPlus"}
|
||||
|
||||
def __call__(self, batch):
|
||||
"""Collate's training batch from normalized text, audio and speaker identities
|
||||
@ -228,7 +228,7 @@ class TextAudioSpeakerCollate:
|
||||
text_padded.zero_()
|
||||
|
||||
if self.is_v2Pro:
|
||||
sv_embs=torch.FloatTensor(len(batch),20480)
|
||||
sv_embs = torch.FloatTensor(len(batch), 20480)
|
||||
|
||||
for i in range(len(ids_sorted_decreasing)):
|
||||
row = batch[ids_sorted_decreasing[i]]
|
||||
@ -250,11 +250,30 @@ class TextAudioSpeakerCollate:
|
||||
text_lengths[i] = text.size(0)
|
||||
|
||||
if self.is_v2Pro:
|
||||
sv_embs[i]=row[4]
|
||||
sv_embs[i] = row[4]
|
||||
if self.is_v2Pro:
|
||||
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths,sv_embs
|
||||
return (
|
||||
ssl_padded,
|
||||
ssl_lengths,
|
||||
spec_padded,
|
||||
spec_lengths,
|
||||
wav_padded,
|
||||
wav_lengths,
|
||||
text_padded,
|
||||
text_lengths,
|
||||
sv_embs,
|
||||
)
|
||||
else:
|
||||
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
|
||||
return (
|
||||
ssl_padded,
|
||||
ssl_lengths,
|
||||
spec_padded,
|
||||
spec_lengths,
|
||||
wav_padded,
|
||||
wav_lengths,
|
||||
text_padded,
|
||||
text_lengths,
|
||||
)
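# Illustrative wiring (it mirrors run() in s2_train.py later in this diff): for v2Pro/v2ProPlus
# the dataset and collate function are version-aware, and each batch gains a (batch, 20480)
# sv_embs tensor as its last element. hps is assumed to be an already-loaded hparams object.
from torch.utils.data import DataLoader

train_dataset = TextAudioSpeakerLoader(hps.data, version="v2Pro")
collate_fn = TextAudioSpeakerCollate(version="v2Pro")
loader = DataLoader(train_dataset, batch_size=4, collate_fn=collate_fn)
ssl, ssl_lengths, spec, spec_lengths, wav, wav_lengths, text, text_lengths, sv_embs = next(iter(loader))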
|
||||
|
||||
|
||||
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
|
@ -586,12 +586,17 @@ class DiscriminatorS(torch.nn.Module):
|
||||
|
||||
return x, fmap
|
||||
|
||||
v2pro_set={"v2Pro","v2ProPlus"}
|
||||
|
||||
v2pro_set = {"v2Pro", "v2ProPlus"}
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
def __init__(self, use_spectral_norm=False,version=None):
|
||||
def __init__(self, use_spectral_norm=False, version=None):
|
||||
super(MultiPeriodDiscriminator, self).__init__()
|
||||
if version in v2pro_set:periods = [2, 3, 5, 7, 11,17,23]
|
||||
else:periods = [2, 3, 5, 7, 11]
|
||||
if version in v2pro_set:
|
||||
periods = [2, 3, 5, 7, 11, 17, 23]
|
||||
else:
|
||||
periods = [2, 3, 5, 7, 11]
|
||||
|
||||
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
||||
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
||||
@ -787,6 +792,7 @@ class CodePredictor(nn.Module):
|
||||
|
||||
return pred_codes.transpose(0, 1)
|
||||
|
||||
|
||||
class SynthesizerTrn(nn.Module):
|
||||
"""
|
||||
Synthesizer for Training
|
||||
@ -886,13 +892,13 @@ class SynthesizerTrn(nn.Module):
|
||||
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
||||
self.freeze_quantizer = freeze_quantizer
|
||||
|
||||
self.is_v2pro=self.version in v2pro_set
|
||||
self.is_v2pro = self.version in v2pro_set
|
||||
if self.is_v2pro:
|
||||
self.sv_emb = nn.Linear(20480, gin_channels)
|
||||
self.ge_to512 = nn.Linear(gin_channels, 512)
|
||||
self.prelu = nn.PReLU(num_parameters=gin_channels)
|
||||
|
||||
def forward(self, ssl, y, y_lengths, text, text_lengths,sv_emb=None):
|
||||
def forward(self, ssl, y, y_lengths, text, text_lengths, sv_emb=None):
|
||||
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
||||
if self.version == "v1":
|
||||
ge = self.ref_enc(y * y_mask, y_mask)
|
||||
@ -952,7 +958,7 @@ class SynthesizerTrn(nn.Module):
|
||||
return o, y_mask, (z, z_p, m_p, logs_p)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, codes, text, refer,noise_scale=0.5, speed=1, sv_emb=None):
|
||||
def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
|
||||
def get_ge(refer, sv_emb):
|
||||
ge = None
|
||||
if refer is not None:
|
||||
@ -970,8 +976,8 @@ class SynthesizerTrn(nn.Module):
|
||||
|
||||
if type(refer) == list:
|
||||
ges = []
|
||||
for idx,_refer in enumerate(refer):
|
||||
ge = get_ge(_refer, sv_emb[idx]if self.is_v2pro else None)
|
||||
for idx, _refer in enumerate(refer):
|
||||
ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
|
||||
ges.append(ge)
|
||||
ge = torch.stack(ges, 0).mean(0)
|
||||
else:
|
||||
@ -983,7 +989,14 @@ class SynthesizerTrn(nn.Module):
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2,1)).transpose(2,1)if self.is_v2pro else ge, speed)
|
||||
x, m_p, logs_p, y_mask = self.enc_p(
|
||||
quantized,
|
||||
y_lengths,
|
||||
text,
|
||||
text_lengths,
|
||||
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
|
||||
speed,
|
||||
)
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
|
||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||
@ -996,6 +1009,7 @@ class SynthesizerTrn(nn.Module):
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||
return codes.transpose(0, 1)
|
||||
|
||||
|
||||
class CFM(torch.nn.Module):
|
||||
def __init__(self, in_channels, dit):
|
||||
super().__init__()
|
||||
@ -1029,7 +1043,18 @@ class CFM(torch.nn.Module):
|
||||
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
|
||||
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
|
||||
v_pred, text_emb, dt = self.estimator(
|
||||
x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False, infer=True, text_cache=text_cache, dt_cache=dt_cache
|
||||
x,
|
||||
prompt_x,
|
||||
x_lens,
|
||||
t_tensor,
|
||||
d_tensor,
|
||||
mu,
|
||||
use_grad_ckpt=False,
|
||||
drop_audio_cond=False,
|
||||
drop_text=False,
|
||||
infer=True,
|
||||
text_cache=text_cache,
|
||||
dt_cache=dt_cache,
|
||||
)
|
||||
v_pred = v_pred.transpose(2, 1)
|
||||
if self.use_conditioner_cache:
|
||||
@ -1037,18 +1062,18 @@ class CFM(torch.nn.Module):
|
||||
dt_cache = dt
|
||||
if inference_cfg_rate > 1e-5:
|
||||
neg, text_cfg_emb, _ = self.estimator(
|
||||
x,
|
||||
prompt_x,
|
||||
x_lens,
|
||||
t_tensor,
|
||||
d_tensor,
|
||||
mu,
|
||||
use_grad_ckpt=False,
|
||||
drop_audio_cond=True,
|
||||
drop_text=True,
|
||||
infer=True,
|
||||
text_cache=text_cfg_cache,
|
||||
dt_cache=dt_cache
|
||||
x,
|
||||
prompt_x,
|
||||
x_lens,
|
||||
t_tensor,
|
||||
d_tensor,
|
||||
mu,
|
||||
use_grad_ckpt=False,
|
||||
drop_audio_cond=True,
|
||||
drop_text=True,
|
||||
infer=True,
|
||||
text_cache=text_cfg_cache,
|
||||
dt_cache=dt_cache,
|
||||
)
|
||||
neg = neg.transpose(2, 1)
|
||||
if self.use_conditioner_cache:
|
||||
|
@ -763,6 +763,9 @@ class CodePredictor(nn.Module):
|
||||
return pred_codes.transpose(0, 1)
|
||||
|
||||
|
||||
v2pro_set = {"v2Pro", "v2ProPlus"}
|
||||
|
||||
|
||||
class SynthesizerTrn(nn.Module):
|
||||
"""
|
||||
Synthesizer for Training
|
||||
@ -867,20 +870,33 @@ class SynthesizerTrn(nn.Module):
|
||||
# self.enc_p.text_embedding.requires_grad_(False)
|
||||
# self.enc_p.encoder_text.requires_grad_(False)
|
||||
# self.enc_p.mrte.requires_grad_(False)
|
||||
self.is_v2pro = self.version in v2pro_set
|
||||
if self.is_v2pro:
|
||||
self.sv_emb = nn.Linear(20480, gin_channels)
|
||||
self.ge_to512 = nn.Linear(gin_channels, 512)
|
||||
self.prelu = nn.PReLU(num_parameters=gin_channels)
|
||||
|
||||
def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
|
||||
def forward(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
|
||||
refer_mask = torch.ones_like(refer[:1, :1, :])
|
||||
if self.version == "v1":
|
||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||
else:
|
||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||
if self.is_v2pro:
|
||||
sv_emb = self.sv_emb(sv_emb)
|
||||
ge += sv_emb.unsqueeze(-1)
|
||||
ge = self.prelu(ge)
|
||||
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
|
||||
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
|
||||
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
|
||||
if self.is_v2pro:
|
||||
ge_ = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge_, speed)
|
||||
else:
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
|
||||
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -718,8 +719,10 @@ class MelStyleEncoder(nn.Module):
|
||||
else:
|
||||
len_ = (~mask).sum(dim=1).unsqueeze(1)
|
||||
x = x.masked_fill(mask.unsqueeze(-1), 0)
|
||||
x = x.sum(dim=1)
|
||||
out = torch.div(x, len_)
|
||||
dtype = x.dtype
|
||||
x = x.float()
|
||||
x = torch.div(x, len_.unsqueeze(1))
|
||||
out = x.sum(dim=1).to(dtype)
|
||||
return out
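# The change above divides by the valid length in float32 before summing and only then casts
# back, which keeps half-precision inputs from overflowing on long sequences. A standalone
# sketch of the same masked temporal average pooling (names here are illustrative):
import torch

def temporal_avg_pool(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # x: (batch, time, channels); mask: (batch, time), True at padded frames
    len_ = (~mask).sum(dim=1).unsqueeze(1)        # valid frames per item
    x = x.masked_fill(mask.unsqueeze(-1), 0)
    dtype = x.dtype
    x = x.float()
    x = torch.div(x, len_.unsqueeze(1))           # divide first, in float32
    return x.sum(dim=1).to(dtype)

out = temporal_avg_pool(torch.randn(2, 10, 8).half(), torch.zeros(2, 10, dtype=torch.bool))
print(out.shape)   # torch.Size([2, 8])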
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
@ -743,7 +746,6 @@ class MelStyleEncoder(nn.Module):
|
||||
x = self.fc(x)
|
||||
# temoral average pooling
|
||||
w = self.temporal_avg_pool(x, mask=mask)
|
||||
|
||||
return w.unsqueeze(-1)
|
||||
|
||||
|
||||
|
@ -10,7 +10,6 @@ i_part = os.environ.get("i_part")
|
||||
all_parts = os.environ.get("all_parts")
|
||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
from feature_extractor import cnhubert
|
||||
|
||||
opt_dir = os.environ.get("opt_dir")
|
||||
sv_path = os.environ.get("sv_path")
|
||||
@ -19,19 +18,18 @@ import torch
|
||||
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
||||
|
||||
import traceback
|
||||
import numpy as np
|
||||
from scipy.io import wavfile
|
||||
import torchaudio
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
sys.path.append(f"{now_dir}/GPT_SoVITS/eres2net")
|
||||
from tools.my_utils import load_audio, clean_path
|
||||
from tools.my_utils import clean_path
|
||||
from time import time as ttime
|
||||
import shutil
|
||||
from ERes2NetV2 import ERes2NetV2
|
||||
import kaldi as Kaldi
|
||||
|
||||
|
||||
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
||||
dir = os.path.dirname(path)
|
||||
name = os.path.basename(path)
|
||||
@ -56,37 +54,45 @@ if torch.cuda.is_available():
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
|
||||
class SV:
|
||||
def __init__(self,device,is_half):
|
||||
pretrained_state = torch.load(sv_path, map_location='cpu')
|
||||
embedding_model = ERes2NetV2(baseWidth=24,scale=4,expansion=4)
|
||||
def __init__(self, device, is_half):
|
||||
pretrained_state = torch.load(sv_path, map_location="cpu")
|
||||
embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
|
||||
embedding_model.load_state_dict(pretrained_state)
|
||||
embedding_model.eval()
|
||||
self.embedding_model=embedding_model
|
||||
self.res=torchaudio.transforms.Resample(32000, 16000).to(device)
|
||||
self.embedding_model = embedding_model
|
||||
self.res = torchaudio.transforms.Resample(32000, 16000).to(device)
|
||||
if is_half == False:
|
||||
self.embedding_model=self.embedding_model.to(device)
|
||||
self.embedding_model = self.embedding_model.to(device)
|
||||
else:
|
||||
self.embedding_model=self.embedding_model.half().to(device)
|
||||
self.is_half=is_half
|
||||
self.embedding_model = self.embedding_model.half().to(device)
|
||||
self.is_half = is_half
|
||||
|
||||
def compute_embedding3(self,wav):#(1,x)#-1~1
|
||||
def compute_embedding3(self, wav): # (1,x)#-1~1
|
||||
with torch.no_grad():
|
||||
wav=self.res(wav)
|
||||
if self.is_half==True:wav=wav.half()
|
||||
feat = torch.stack([Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav])
|
||||
wav = self.res(wav)
|
||||
if self.is_half == True:
|
||||
wav = wav.half()
|
||||
feat = torch.stack(
|
||||
[Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]
|
||||
)
|
||||
sv_emb = self.embedding_model.forward3(feat)
|
||||
return sv_emb
|
||||
|
||||
sv=SV(device,is_half)
|
||||
|
||||
sv = SV(device, is_half)
|
||||
|
||||
|
||||
def name2go(wav_name, wav_path):
|
||||
sv_cn_path = "%s/%s.pt" % (sv_cn_dir, wav_name)
|
||||
if os.path.exists(sv_cn_path):return
|
||||
wav_path="%s/%s" % (wav32dir, wav_name)
|
||||
wav32k,sr0 = torchaudio.load(wav_path)
|
||||
assert sr0==32000
|
||||
if os.path.exists(sv_cn_path):
|
||||
return
|
||||
wav_path = "%s/%s" % (wav32dir, wav_name)
|
||||
wav32k, sr0 = torchaudio.load(wav_path)
|
||||
assert sr0 == 32000
|
||||
wav32k = wav32k.to(device)
|
||||
emb=sv.compute_embedding3(wav32k).cpu() # torch.Size([1, 20480])
|
||||
emb = sv.compute_embedding3(wav32k).cpu() # torch.Size([1, 20480])
|
||||
my_save(emb, sv_cn_path)
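# Hypothetical driver loop (not in this diff): cache one speaker-verification embedding per
# 32 kHz training clip via name2go above; the real pipeline shards this work across processes
# using i_part/all_parts.
for wav_name in sorted(os.listdir(wav32dir)):
    if wav_name.endswith(".wav"):
        name2go(wav_name, "%s/%s" % (wav32dir, wav_name))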
|
||||
|
||||
|
||||
|
@ -17,15 +17,16 @@ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
||||
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
||||
|
||||
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
model_version2byte={
|
||||
"v3":b"03",
|
||||
"v4":b"04",
|
||||
"v2Pro":b"05",
|
||||
"v2ProPlus":b"06",
|
||||
model_version2byte = {
|
||||
"v3": b"03",
|
||||
"v4": b"04",
|
||||
"v2Pro": b"05",
|
||||
"v2ProPlus": b"06",
|
||||
}
|
||||
|
||||
|
||||
def my_save2(fea, path, model_version):
|
||||
bio = BytesIO()
|
||||
torch.save(fea, bio)
|
||||
@ -50,7 +51,7 @@ def savee(ckpt, name, epoch, steps, hps, model_version=None, lora_rank=None):
|
||||
if lora_rank:
|
||||
opt["lora_rank"] = lora_rank
|
||||
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
|
||||
elif (model_version!=None and "Pro"in model_version):
|
||||
elif model_version != None and "Pro" in model_version:
|
||||
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
|
||||
else:
|
||||
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
|
||||
@ -58,6 +59,7 @@ def savee(ckpt, name, epoch, steps, hps, model_version=None, lora_rank=None):
|
||||
except:
|
||||
return traceback.format_exc()
|
||||
|
||||
|
||||
"""
|
||||
00:v1
|
||||
01:v2
|
||||
@ -127,7 +129,7 @@ def get_sovits_version_from_path_fast(sovits_path):
|
||||
def load_sovits_new(sovits_path):
|
||||
f = open(sovits_path, "rb")
|
||||
meta = f.read(2)
|
||||
if meta != "PK":
|
||||
if meta != b"PK":
|
||||
data = b"PK" + f.read()
|
||||
bio = BytesIO()
|
||||
bio.write(data)
|
||||
|
@ -36,7 +36,7 @@ from module.models import (
|
||||
MultiPeriodDiscriminator,
|
||||
SynthesizerTrn,
|
||||
)
|
||||
from process_ckpt import savee,my_save2
|
||||
from process_ckpt import savee
|
||||
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = False
|
||||
@ -87,11 +87,30 @@ def run(rank, n_gpus, hps):
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
train_dataset = TextAudioSpeakerLoader(hps.data,version=hps.model.version)
|
||||
train_dataset = TextAudioSpeakerLoader(hps.data, version=hps.model.version)
|
||||
train_sampler = DistributedBucketSampler(
|
||||
train_dataset,
|
||||
hps.train.batch_size,
|
||||
[32,300,400,500,600,700,800,900,1000,1100,1200,1300,1400,1500,1600,1700,1800,1900,],
|
||||
[
|
||||
32,
|
||||
300,
|
||||
400,
|
||||
500,
|
||||
600,
|
||||
700,
|
||||
800,
|
||||
900,
|
||||
1000,
|
||||
1100,
|
||||
1200,
|
||||
1300,
|
||||
1400,
|
||||
1500,
|
||||
1600,
|
||||
1700,
|
||||
1800,
|
||||
1900,
|
||||
],
|
||||
num_replicas=n_gpus,
|
||||
rank=rank,
|
||||
shuffle=True,
|
||||
@ -130,9 +149,9 @@ def run(rank, n_gpus, hps):
|
||||
)
|
||||
|
||||
net_d = (
|
||||
MultiPeriodDiscriminator(hps.model.use_spectral_norm,version=hps.model.version).cuda(rank)
|
||||
MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).cuda(rank)
|
||||
if torch.cuda.is_available()
|
||||
else MultiPeriodDiscriminator(hps.model.use_spectral_norm,version=hps.model.version).to(device)
|
||||
else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).to(device)
|
||||
)
|
||||
for name, param in net_g.named_parameters():
|
||||
if not param.requires_grad:
|
||||
@ -235,7 +254,7 @@ def run(rank, n_gpus, hps):
|
||||
print(
|
||||
"loaded pretrained %s" % hps.train.pretrained_s2D,
|
||||
net_d.module.load_state_dict(
|
||||
torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"],strict=False
|
||||
torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"], strict=False
|
||||
)
|
||||
if torch.cuda.is_available()
|
||||
else net_d.load_state_dict(
|
||||
@ -310,17 +329,44 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
net_g.train()
|
||||
net_d.train()
|
||||
for batch_idx, data in enumerate(tqdm(train_loader)):
|
||||
if hps.model.version in {"v2Pro","v2ProPlus"}:
|
||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths,sv_emb=data
|
||||
if hps.model.version in {"v2Pro", "v2ProPlus"}:
|
||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = data
|
||||
else:
|
||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths=data
|
||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data
|
||||
if torch.cuda.is_available():
|
||||
spec, spec_lengths = (spec.cuda(rank,non_blocking=True,),spec_lengths.cuda(rank,non_blocking=True,),)
|
||||
y, y_lengths = (y.cuda(rank,non_blocking=True,),y_lengths.cuda(rank,non_blocking=True,),)
|
||||
spec, spec_lengths = (
|
||||
spec.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
spec_lengths.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
)
|
||||
y, y_lengths = (
|
||||
y.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
y_lengths.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
)
|
||||
ssl = ssl.cuda(rank, non_blocking=True)
|
||||
ssl.requires_grad = False
|
||||
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
|
||||
text, text_lengths = (text.cuda(rank,non_blocking=True,),text_lengths.cuda(rank,non_blocking=True,),)
|
||||
text, text_lengths = (
|
||||
text.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
text_lengths.cuda(
|
||||
rank,
|
||||
non_blocking=True,
|
||||
),
|
||||
)
|
||||
if hps.model.version in {"v2Pro", "v2ProPlus"}:
|
||||
sv_emb = sv_emb.cuda(rank, non_blocking=True)
|
||||
else:
|
||||
@ -334,9 +380,19 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
sv_emb = sv_emb.to(device)
|
||||
with autocast(enabled=hps.train.fp16_run):
|
||||
if hps.model.version in {"v2Pro", "v2ProPlus"}:
|
||||
(y_hat,kl_ssl,ids_slice,x_mask,z_mask,(z, z_p, m_p, logs_p, m_q, logs_q),stats_ssl) = net_g(ssl, spec, spec_lengths, text, text_lengths,sv_emb)
|
||||
(y_hat, kl_ssl, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl) = net_g(
|
||||
ssl, spec, spec_lengths, text, text_lengths, sv_emb
|
||||
)
|
||||
else:
|
||||
(y_hat,kl_ssl,ids_slice,x_mask,z_mask,(z, z_p, m_p, logs_p, m_q, logs_q),stats_ssl,) = net_g(ssl, spec, spec_lengths, text, text_lengths)
|
||||
(
|
||||
y_hat,
|
||||
kl_ssl,
|
||||
ids_slice,
|
||||
x_mask,
|
||||
z_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||
stats_ssl,
|
||||
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
|
||||
|
||||
mel = spec_to_mel_torch(
|
||||
spec,
|
||||
@ -508,7 +564,14 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
||||
% (
|
||||
hps.name,
|
||||
epoch,
|
||||
savee(ckpt,hps.name + "_e%s_s%s" % (epoch, global_step),epoch,global_step,hps,model_version=None if hps.model.version not in {"v2Pro","v2ProPlus"}else hps.model.version),
|
||||
savee(
|
||||
ckpt,
|
||||
hps.name + "_e%s_s%s" % (epoch, global_step),
|
||||
epoch,
|
||||
global_step,
|
||||
hps,
|
||||
model_version=None if hps.model.version not in {"v2Pro", "v2ProPlus"} else hps.model.version,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -1,24 +1,32 @@
|
||||
import sys,os,torch
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
|
||||
sys.path.append(f"{os.getcwd()}/GPT_SoVITS/eres2net")
|
||||
sv_path = "GPT_SoVITS/pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
|
||||
from ERes2NetV2 import ERes2NetV2
|
||||
import kaldi as Kaldi
|
||||
|
||||
|
||||
class SV:
|
||||
def __init__(self,device,is_half):
|
||||
pretrained_state = torch.load(sv_path, map_location='cpu', weights_only=False)
|
||||
embedding_model = ERes2NetV2(baseWidth=24,scale=4,expansion=4)
|
||||
def __init__(self, device, is_half):
|
||||
pretrained_state = torch.load(sv_path, map_location="cpu", weights_only=False)
|
||||
embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
|
||||
embedding_model.load_state_dict(pretrained_state)
|
||||
embedding_model.eval()
|
||||
self.embedding_model=embedding_model
|
||||
self.embedding_model = embedding_model
|
||||
if is_half == False:
|
||||
self.embedding_model=self.embedding_model.to(device)
|
||||
self.embedding_model = self.embedding_model.to(device)
|
||||
else:
|
||||
self.embedding_model=self.embedding_model.half().to(device)
|
||||
self.is_half=is_half
|
||||
self.embedding_model = self.embedding_model.half().to(device)
|
||||
self.is_half = is_half
|
||||
|
||||
def compute_embedding3(self,wav):
|
||||
def compute_embedding3(self, wav):
|
||||
with torch.no_grad():
|
||||
if self.is_half==True:wav=wav.half()
|
||||
feat = torch.stack([Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav])
|
||||
if self.is_half == True:
|
||||
wav = wav.half()
|
||||
feat = torch.stack(
|
||||
[Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]
|
||||
)
|
||||
sv_emb = self.embedding_model.forward3(feat)
|
||||
return sv_emb
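# Minimal usage sketch for the SV helper above. It assumes the pretrained eres2netv2 checkpoint
# exists at sv_path; the waveform is a placeholder and is expected to be 16 kHz mono shaped
# (1, num_samples).
import torch

sv = SV(device="cpu", is_half=False)
wav16k = torch.randn(1, 16000)       # placeholder clip
emb = sv.compute_embedding3(wav16k)
print(emb.shape)                     # expected: torch.Size([1, 20480])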
|
||||
|
@ -3,38 +3,44 @@ import re
|
||||
|
||||
# jieba静音
|
||||
import jieba
|
||||
|
||||
jieba.setLogLevel(logging.CRITICAL)
|
||||
|
||||
# 更改fast_langdetect大模型位置
|
||||
from pathlib import Path
|
||||
import fast_langdetect
|
||||
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
|
||||
|
||||
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(
|
||||
fast_langdetect.infer.LangDetectConfig(
|
||||
cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
from split_lang import LangSplitter
|
||||
|
||||
|
||||
def full_en(text):
|
||||
pattern = r'^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$'
|
||||
pattern = r"^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$"
|
||||
return bool(re.match(pattern, text))
|
||||
|
||||
|
||||
def full_cjk(text):
|
||||
# 来自wiki
|
||||
cjk_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DB5), # CJK Extension A
|
||||
(0x20000, 0x2A6DD), # CJK Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Extension F
|
||||
(0x30000, 0x3134A), # CJK Extension G
|
||||
(0x31350, 0x323AF), # CJK Extension H
|
||||
(0x2EBF0, 0x2EE5D), # CJK Extension H
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DB5), # CJK Extension A
|
||||
(0x20000, 0x2A6DD), # CJK Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Extension F
|
||||
(0x30000, 0x3134A), # CJK Extension G
|
||||
(0x31350, 0x323AF), # CJK Extension H
|
||||
(0x2EBF0, 0x2EE5D), # CJK Extension H
|
||||
]
|
||||
|
||||
pattern = r'[0-9、-〜。!?.!?… /]+$'
|
||||
pattern = r"[0-9、-〜。!?.!?… /]+$"
|
||||
|
||||
cjk_text = ""
|
||||
for char in text:
|
||||
@ -45,7 +51,7 @@ def full_cjk(text):
|
||||
return cjk_text
|
||||
|
||||
|
||||
def split_jako(tag_lang,item):
|
||||
def split_jako(tag_lang, item):
|
||||
if tag_lang == "ja":
|
||||
pattern = r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)"
|
||||
else:
|
||||
@ -53,41 +59,40 @@ def split_jako(tag_lang,item):
|
||||
|
||||
lang_list: list[dict] = []
|
||||
tag = 0
|
||||
for match in re.finditer(pattern, item['text']):
|
||||
for match in re.finditer(pattern, item["text"]):
|
||||
if match.start() > tag:
|
||||
lang_list.append({'lang':item['lang'],'text':item['text'][tag:match.start()]})
|
||||
lang_list.append({"lang": item["lang"], "text": item["text"][tag : match.start()]})
|
||||
|
||||
tag = match.end()
|
||||
lang_list.append({'lang':tag_lang,'text':item['text'][match.start():match.end()]})
|
||||
lang_list.append({"lang": tag_lang, "text": item["text"][match.start() : match.end()]})
|
||||
|
||||
if tag < len(item['text']):
|
||||
lang_list.append({'lang':item['lang'],'text':item['text'][tag:len(item['text'])]})
|
||||
if tag < len(item["text"]):
|
||||
lang_list.append({"lang": item["lang"], "text": item["text"][tag : len(item["text"])]})
|
||||
|
||||
return lang_list
|
||||
|
||||
|
||||
def merge_lang(lang_list, item):
|
||||
if lang_list and item['lang'] == lang_list[-1]['lang']:
|
||||
lang_list[-1]['text'] += item['text']
|
||||
if lang_list and item["lang"] == lang_list[-1]["lang"]:
|
||||
lang_list[-1]["text"] += item["text"]
|
||||
else:
|
||||
lang_list.append(item)
|
||||
return lang_list
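# Illustrative behaviour of merge_lang above: consecutive chunks carrying the same language tag
# are concatenated, so later stages see one entry per language run.
chunks = [
    {"lang": "zh", "text": "你好,"},
    {"lang": "zh", "text": "世界。"},
    {"lang": "en", "text": "Hello."},
]
merged = []
for item in chunks:
    merged = merge_lang(merged, item)
print(merged)   # [{'lang': 'zh', 'text': '你好,世界。'}, {'lang': 'en', 'text': 'Hello.'}]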
|
||||
|
||||
|
||||
class LangSegmenter():
|
||||
class LangSegmenter:
|
||||
# 默认过滤器, 基于gsv目前四种语言
|
||||
DEFAULT_LANG_MAP = {
|
||||
"zh": "zh",
|
||||
"yue": "zh", # 粤语
|
||||
"wuu": "zh", # 吴语
|
||||
"zh-cn": "zh",
|
||||
"zh-tw": "x", # 繁体设置为x
|
||||
"zh-tw": "x", # 繁体设置为x
|
||||
"ko": "ko",
|
||||
"ja": "ja",
|
||||
"en": "en",
|
||||
}
|
||||
|
||||
|
||||
def getTexts(text):
|
||||
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
|
||||
substr = lang_splitter.split_by_lang(text=text)
|
||||
@ -95,18 +100,18 @@ class LangSegmenter():
|
||||
lang_list: list[dict] = []
|
||||
|
||||
for _, item in enumerate(substr):
|
||||
dict_item = {'lang':item.lang,'text':item.text}
|
||||
dict_item = {"lang": item.lang, "text": item.text}
|
||||
|
||||
# 处理短英文被识别为其他语言的问题
|
||||
if full_en(dict_item['text']):
|
||||
dict_item['lang'] = 'en'
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
if full_en(dict_item["text"]):
|
||||
dict_item["lang"] = "en"
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
continue
|
||||
|
||||
# 处理非日语夹日文的问题(不包含CJK)
|
||||
ja_list: list[dict] = []
|
||||
if dict_item['lang'] != 'ja':
|
||||
ja_list = split_jako('ja',dict_item)
|
||||
if dict_item["lang"] != "ja":
|
||||
ja_list = split_jako("ja", dict_item)
|
||||
|
||||
if not ja_list:
|
||||
ja_list.append(dict_item)
|
||||
@ -115,8 +120,8 @@ class LangSegmenter():
|
||||
ko_list: list[dict] = []
|
||||
temp_list: list[dict] = []
|
||||
for _, ko_item in enumerate(ja_list):
|
||||
if ko_item["lang"] != 'ko':
|
||||
ko_list = split_jako('ko',ko_item)
|
||||
if ko_item["lang"] != "ko":
|
||||
ko_list = split_jako("ko", ko_item)
|
||||
|
||||
if ko_list:
|
||||
temp_list.extend(ko_list)
|
||||
@ -126,50 +131,50 @@ class LangSegmenter():
|
||||
# 未存在非日韩文夹日韩文
|
||||
if len(temp_list) == 1:
|
||||
# 未知语言检查是否为CJK
|
||||
if dict_item['lang'] == 'x':
|
||||
cjk_text = full_cjk(dict_item['text'])
|
||||
if dict_item["lang"] == "x":
|
||||
cjk_text = full_cjk(dict_item["text"])
|
||||
if cjk_text:
|
||||
dict_item = {'lang':'zh','text':cjk_text}
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
dict_item = {"lang": "zh", "text": cjk_text}
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
else:
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
continue
|
||||
else:
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
continue
|
||||
|
||||
# 存在非日韩文夹日韩文
|
||||
for _, temp_item in enumerate(temp_list):
|
||||
# 未知语言检查是否为CJK
|
||||
if temp_item['lang'] == 'x':
|
||||
cjk_text = full_cjk(dict_item['text'])
|
||||
if temp_item["lang"] == "x":
|
||||
cjk_text = full_cjk(dict_item["text"])
|
||||
if cjk_text:
|
||||
dict_item = {'lang':'zh','text':cjk_text}
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
dict_item = {"lang": "zh", "text": cjk_text}
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
else:
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
else:
|
||||
lang_list = merge_lang(lang_list,temp_item)
|
||||
lang_list = merge_lang(lang_list, temp_item)
|
||||
|
||||
temp_list = lang_list
|
||||
lang_list = []
|
||||
for _, temp_item in enumerate(temp_list):
|
||||
if temp_item['lang'] == 'x':
|
||||
if temp_item["lang"] == "x":
|
||||
if lang_list:
|
||||
temp_item['lang'] = lang_list[-1]['lang']
|
||||
temp_item["lang"] = lang_list[-1]["lang"]
|
||||
elif len(temp_list) > 1:
|
||||
temp_item['lang'] = temp_list[1]['lang']
|
||||
temp_item["lang"] = temp_list[1]["lang"]
|
||||
else:
|
||||
temp_item['lang'] = 'zh'
|
||||
temp_item["lang"] = "zh"
|
||||
|
||||
lang_list = merge_lang(lang_list,temp_item)
|
||||
lang_list = merge_lang(lang_list, temp_item)
|
||||
|
||||
return lang_list
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = "MyGO?,你也喜欢まいご吗?"
|
||||
print(LangSegmenter.getTexts(text))
|
||||
|
||||
text = "ねえ、知ってる?最近、僕は天文学を勉強してるんだ。君の瞳が星空みたいにキラキラしてるからさ。"
|
||||
print(LangSegmenter.getTexts(text))
|
||||
print(LangSegmenter.getTexts(text))
|
||||
|
@ -3,7 +3,6 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import warnings
|
||||
import zipfile
|
||||
from typing import Any, Dict, List, Tuple
|
||||
@ -23,8 +22,9 @@ from .utils import load_config
|
||||
onnxruntime.set_default_logger_severity(3)
|
||||
try:
|
||||
onnxruntime.preload_dlls()
|
||||
except:pass
|
||||
#traceback.print_exc()
|
||||
except:
|
||||
pass
|
||||
# traceback.print_exc()
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
model_version = "1.1"
|
||||
|
@ -655,11 +655,7 @@ class ToneSandhi:
|
||||
while i < len(seg):
|
||||
word, pos = seg[i]
|
||||
merged = False
|
||||
if (
|
||||
i - 1 >= 0
|
||||
and word == "一"
|
||||
and i + 1 < len(seg)
|
||||
):
|
||||
if i - 1 >= 0 and word == "一" and i + 1 < len(seg):
|
||||
last = new_seg[-1] if new_seg else seg[i - 1]
|
||||
if last[0] == seg[i + 1][0] and last[1] == "v" and seg[i + 1][1] == "v":
|
||||
combined = last[0] + "一" + seg[i + 1][0]
|
||||
|
@ -283,7 +283,7 @@ def get_hparams_from_file(config_path):
def check_git_hash(model_dir):
source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warn(
logger.warning(
"{} is not a git repository, therefore hash value comparison will be ignored.".format(
source_dir,
)
@ -296,7 +296,7 @@ def check_git_hash(model_dir):
if os.path.exists(path):
saved_hash = open(path).read()
if saved_hash != cur_hash:
logger.warn(
logger.warning(
"git hash values are different. {}(saved) != {}(current)".format(
saved_hash[:8],
cur_hash[:8],
README.md (17 lines changed)
@ -9,10 +9,14 @@ A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.<br><br>

<!-- img src="https://counter.seku.su/cmoe?name=gptsovits&theme=r34" /><br> -->

[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/colab_webui.ipynb)
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
[](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
[](https://discord.gg/dnrgs5GHfG)
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
[](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)

[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
[](https://rentry.co/GPT-SoVITS-guide#/)
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/docs/en/Changelog_EN.md)

**English** | [**中文简体**](./docs/cn/README.md) | [**日本語**](./docs/ja/README.md) | [**한국어**](./docs/ko/README.md) | [**Türkçe**](./docs/tr/README.md)

@ -128,8 +132,9 @@ Due to rapid development in the codebase and a slower Docker image release cycle

- Check [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits) for the latest available image tags
- Choose an appropriate image tag for your environment
- `Lite` means the Docker image does not include ASR models and UVR5 models. You can manually download the UVR5 models, while the program will automatically download the ASR models as needed
- `Lite` means the Docker image **does not include** ASR models and UVR5 models. You can manually download the UVR5 models, while the program will automatically download the ASR models as needed
- The appropriate architecture image (amd64/arm64) will be automatically pulled during Docker Compose
- Docker Compose will mount **all files** in the current directory. Please switch to the project root directory and **pull the latest code** before using the Docker image
- Optionally, build the image locally using the provided Dockerfile for the most up-to-date changes

#### Environment Variables
@ -333,7 +338,7 @@ Use v4 from v1/v2/v3 environment:
New Features:

1. Slightly higher VRAM usage than v2, surpassing v4's performance, with v2's hardware cost and speed.
[more details](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7))
[more details](<https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)>)

2. v1/v2 and the v2Pro series share the same characteristics, while v3/v4 have similar features. For training sets with average audio quality, v1/v2/v2Pro can deliver decent results, but v3/v4 cannot. Additionally, the synthesized tone and timbre of v3/v4 lean more toward the reference audio rather than the overall training set.
api.py (288 lines changed)
@ -163,7 +163,7 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import numpy as np
|
||||
from feature_extractor import cnhubert
|
||||
from io import BytesIO
|
||||
from module.models import SynthesizerTrn, SynthesizerTrnV3
|
||||
from module.models import Generator, SynthesizerTrn, SynthesizerTrnV3
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from text import cleaned_text_to_sequence
|
||||
@ -198,8 +198,44 @@ def is_full(*items): # 任意一项为空返回False
|
||||
return True
|
||||
|
||||
|
||||
def init_bigvgan():
|
||||
bigvgan_model = hifigan_model = sv_cn_model = None
|
||||
|
||||
|
||||
def clean_hifigan_model():
|
||||
global hifigan_model
|
||||
if hifigan_model:
|
||||
hifigan_model = hifigan_model.cpu()
|
||||
hifigan_model = None
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def clean_bigvgan_model():
|
||||
global bigvgan_model
|
||||
if bigvgan_model:
|
||||
bigvgan_model = bigvgan_model.cpu()
|
||||
bigvgan_model = None
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def clean_sv_cn_model():
|
||||
global sv_cn_model
|
||||
if sv_cn_model:
|
||||
sv_cn_model.embedding_model = sv_cn_model.embedding_model.cpu()
|
||||
sv_cn_model = None
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def init_bigvgan():
|
||||
global bigvgan_model, hifigan_model, sv_cn_model
|
||||
from BigVGAN import bigvgan
|
||||
|
||||
bigvgan_model = bigvgan.BigVGAN.from_pretrained(
|
||||
@ -209,20 +245,57 @@ def init_bigvgan():
|
||||
# remove weight norm in the model and set to eval mode
|
||||
bigvgan_model.remove_weight_norm()
|
||||
bigvgan_model = bigvgan_model.eval()
|
||||
|
||||
if is_half == True:
|
||||
bigvgan_model = bigvgan_model.half().to(device)
|
||||
else:
|
||||
bigvgan_model = bigvgan_model.to(device)
|
||||
|
||||
|
||||
def init_hifigan():
|
||||
global hifigan_model, bigvgan_model, sv_cn_model
|
||||
hifigan_model = Generator(
|
||||
initial_channel=100,
|
||||
resblock="1",
|
||||
resblock_kernel_sizes=[3, 7, 11],
|
||||
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
upsample_rates=[10, 6, 2, 2, 2],
|
||||
upsample_initial_channel=512,
|
||||
upsample_kernel_sizes=[20, 12, 4, 4, 4],
|
||||
gin_channels=0,
|
||||
is_bias=True,
|
||||
)
|
||||
hifigan_model.eval()
|
||||
hifigan_model.remove_weight_norm()
|
||||
state_dict_g = torch.load(
|
||||
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,),
|
||||
map_location="cpu",
|
||||
weights_only=False,
|
||||
)
|
||||
print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
|
||||
if is_half == True:
|
||||
hifigan_model = hifigan_model.half().to(device)
|
||||
else:
|
||||
hifigan_model = hifigan_model.to(device)
|
||||
|
||||
|
||||
from sv import SV
|
||||
|
||||
|
||||
def init_sv_cn():
|
||||
global hifigan_model, bigvgan_model, sv_cn_model
|
||||
sv_cn_model = SV(device, is_half)
|
||||
|
||||
|
||||
resample_transform_dict = {}
|
||||
|
||||
|
||||
def resample(audio_tensor, sr0):
|
||||
def resample(audio_tensor, sr0, sr1, device):
|
||||
global resample_transform_dict
|
||||
if sr0 not in resample_transform_dict:
|
||||
resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device)
|
||||
return resample_transform_dict[sr0](audio_tensor)
|
||||
key = "%s-%s-%s" % (sr0, sr1, str(device))
|
||||
if key not in resample_transform_dict:
|
||||
resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
|
||||
return resample_transform_dict[key](audio_tensor)
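The rewritten `resample` above caches one `torchaudio.transforms.Resample` instance per `(sr0, sr1, device)` triple, so repeated conversions between the same pair of rates reuse a single transform. A minimal usage sketch, assuming the module-level `device` from api.py and a placeholder file name:

```python
import torchaudio

# Illustrative only: "ref.wav" is a placeholder reference clip.
wav, sr0 = torchaudio.load("ref.wav")
wav = wav.to(device)

wav_24k = resample(wav, sr0, 24000, device)  # first call builds and caches the transform
wav_16k = resample(wav, sr0, 16000, device)  # different target rate gets its own cache entry
```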
|
||||
|
||||
|
||||
from module.mel_processing import mel_spectrogram_torch
|
||||
@ -252,6 +325,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
|
||||
"center": False,
|
||||
},
|
||||
)
|
||||
mel_fn_v4 = lambda x: mel_spectrogram_torch(
|
||||
x,
|
||||
**{
|
||||
"n_fft": 1280,
|
||||
"win_size": 1280,
|
||||
"hop_size": 320,
|
||||
"num_mels": 100,
|
||||
"sampling_rate": 32000,
|
||||
"fmin": 0,
|
||||
"fmax": None,
|
||||
"center": False,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
sr_model = None
|
||||
@ -293,12 +379,19 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
|
||||
|
||||
|
||||
def get_sovits_weights(sovits_path):
|
||||
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
||||
from config import pretrained_sovits_name
|
||||
|
||||
path_sovits_v3 = pretrained_sovits_name["v3"]
|
||||
path_sovits_v4 = pretrained_sovits_name["v4"]
|
||||
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
||||
is_exist_s2gv4 = os.path.exists(path_sovits_v4)
|
||||
|
||||
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
||||
if if_lora_v3 == True and is_exist_s2gv3 == False:
|
||||
logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
|
||||
is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
|
||||
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
|
||||
|
||||
if if_lora_v3 == True and is_exist == False:
|
||||
logger.info("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
|
||||
|
||||
dict_s2 = load_sovits_new(sovits_path)
|
||||
hps = dict_s2["config"]
|
||||
@ -311,11 +404,13 @@ def get_sovits_weights(sovits_path):
|
||||
else:
|
||||
hps.model.version = "v2"
|
||||
|
||||
if model_version == "v3":
|
||||
hps.model.version = "v3"
|
||||
|
||||
model_params_dict = vars(hps.model)
|
||||
if model_version != "v3":
|
||||
if model_version not in {"v3", "v4"}:
|
||||
if "Pro" in model_version:
|
||||
hps.model.version = model_version
|
||||
if sv_cn_model == None:
|
||||
init_sv_cn()
|
||||
|
||||
vq_model = SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
@ -323,13 +418,18 @@ def get_sovits_weights(sovits_path):
|
||||
**model_params_dict,
|
||||
)
|
||||
else:
|
||||
hps.model.version = model_version
|
||||
vq_model = SynthesizerTrnV3(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**model_params_dict,
|
||||
)
|
||||
init_bigvgan()
|
||||
if model_version == "v3":
|
||||
init_bigvgan()
|
||||
if model_version == "v4":
|
||||
init_hifigan()
|
||||
|
||||
model_version = hps.model.version
|
||||
logger.info(f"模型版本: {model_version}")
|
||||
if "pretrained" not in sovits_path:
|
||||
@ -345,7 +445,8 @@ def get_sovits_weights(sovits_path):
|
||||
if if_lora_v3 == False:
|
||||
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
else:
|
||||
vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
|
||||
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
|
||||
vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False)
|
||||
lora_rank = dict_s2["lora_rank"]
|
||||
lora_config = LoraConfig(
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
@ -479,6 +580,10 @@ def get_phones_and_bert(text, language, version, final=False):
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if langlist:
|
||||
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
@ -533,23 +638,34 @@ class DictToAttrRecursive(dict):
|
||||
raise AttributeError(f"Attribute {item} not found")
|
||||
|
||||
|
||||
def get_spepc(hps, filename):
|
||||
audio, _ = librosa.load(filename, sr=int(hps.data.sampling_rate))
|
||||
audio = torch.FloatTensor(audio)
|
||||
def get_spepc(hps, filename, dtype, device, is_v2pro=False):
|
||||
sr1 = int(hps.data.sampling_rate)
|
||||
audio, sr0 = torchaudio.load(filename)
|
||||
if sr0 != sr1:
|
||||
audio = audio.to(device)
|
||||
if audio.shape[0] == 2:
|
||||
audio = audio.mean(0).unsqueeze(0)
|
||||
audio = resample(audio, sr0, sr1, device)
|
||||
else:
|
||||
audio = audio.to(device)
|
||||
if audio.shape[0] == 2:
|
||||
audio = audio.mean(0).unsqueeze(0)
|
||||
|
||||
maxx = audio.abs().max()
|
||||
if maxx > 1:
|
||||
audio /= min(2, maxx)
|
||||
audio_norm = audio
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
spec = spectrogram_torch(
|
||||
audio_norm,
|
||||
audio,
|
||||
hps.data.filter_length,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.hop_length,
|
||||
hps.data.win_length,
|
||||
center=False,
|
||||
)
|
||||
return spec
|
||||
spec = spec.to(dtype)
|
||||
if is_v2pro == True:
|
||||
audio = resample(audio, sr1, 16000, device).to(dtype)
|
||||
return spec, audio
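With the new signature, `get_spepc` returns both the spectrogram and the loaded audio (resampled to 16 kHz when `is_v2pro` is set), which the v2Pro code path further down feeds to the speaker-verification model. A rough sketch of that call pattern, assuming `hps`, `dtype`, `device`, and `sv_cn_model` are already initialized as in the surrounding code and using a placeholder path:

```python
# Sketch only; mirrors how the v2Pro branch later in this file uses the return values.
refer_spec, ref_audio_16k = get_spepc(hps, "ref.wav", dtype, device, is_v2pro=True)
sv_emb = sv_cn_model.compute_embedding3(ref_audio_16k)  # speaker embedding for v2Pro decoding
```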
|
||||
|
||||
|
||||
def pack_audio(audio_bytes, data, rate):
|
||||
@ -736,6 +852,16 @@ def get_tts_wav(
|
||||
t2s_model = infer_gpt.t2s_model
|
||||
max_sec = infer_gpt.max_sec
|
||||
|
||||
if version == "v3":
|
||||
if sample_steps not in [4, 8, 16, 32, 64, 128]:
|
||||
sample_steps = 32
|
||||
elif version == "v4":
|
||||
if sample_steps not in [4, 8, 16, 32]:
|
||||
sample_steps = 8
|
||||
|
||||
if if_sr and version != "v3":
|
||||
if_sr = False
|
||||
|
||||
t0 = ttime()
|
||||
prompt_text = prompt_text.strip("\n")
|
||||
if prompt_text[-1] not in splits:
|
||||
@ -759,19 +885,29 @@ def get_tts_wav(
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||
|
||||
if version != "v3":
|
||||
is_v2pro = version in {"v2Pro", "v2ProPlus"}
|
||||
if version not in {"v3", "v4"}:
|
||||
refers = []
|
||||
if is_v2pro:
|
||||
sv_emb = []
|
||||
if sv_cn_model == None:
|
||||
init_sv_cn()
|
||||
if inp_refs:
|
||||
for path in inp_refs:
|
||||
try:
|
||||
refer = get_spepc(hps, path).to(dtype).to(device)
|
||||
try: #####这里加上提取sv的逻辑,要么一堆sv一堆refer,要么单个sv单个refer
|
||||
refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro)
|
||||
refers.append(refer)
|
||||
if is_v2pro:
|
||||
sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
if len(refers) == 0:
|
||||
refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
|
||||
refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro)
|
||||
refers = [refers]
|
||||
if is_v2pro:
|
||||
sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
|
||||
else:
|
||||
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
|
||||
refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device)
|
||||
|
||||
t1 = ttime()
|
||||
# os.environ['version'] = version
|
||||
@ -811,41 +947,56 @@ def get_tts_wav(
|
||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
|
||||
t3 = ttime()
|
||||
|
||||
if version != "v3":
|
||||
audio = (
|
||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()[0, 0]
|
||||
) ###试试重建不带上prompt部分
|
||||
if version not in {"v3", "v4"}:
|
||||
if is_v2pro:
|
||||
audio = (
|
||||
vq_model.decode(
|
||||
pred_semantic,
|
||||
torch.LongTensor(phones2).to(device).unsqueeze(0),
|
||||
refers,
|
||||
speed=speed,
|
||||
sv_emb=sv_emb,
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()[0, 0]
|
||||
)
|
||||
else:
|
||||
audio = (
|
||||
vq_model.decode(
|
||||
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()[0, 0]
|
||||
)
|
||||
else:
|
||||
phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
|
||||
phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
|
||||
# print(11111111, phoneme_ids0, phoneme_ids1)
|
||||
|
||||
fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
|
||||
ref_audio, sr = torchaudio.load(ref_wav_path)
|
||||
ref_audio = ref_audio.to(device).float()
|
||||
if ref_audio.shape[0] == 2:
|
||||
ref_audio = ref_audio.mean(0).unsqueeze(0)
|
||||
if sr != 24000:
|
||||
ref_audio = resample(ref_audio, sr)
|
||||
# print("ref_audio",ref_audio.abs().mean())
|
||||
mel2 = mel_fn(ref_audio)
|
||||
|
||||
tgt_sr = 24000 if version == "v3" else 32000
|
||||
if sr != tgt_sr:
|
||||
ref_audio = resample(ref_audio, sr, tgt_sr, device)
|
||||
mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
|
||||
mel2 = norm_spec(mel2)
|
||||
T_min = min(mel2.shape[2], fea_ref.shape[2])
|
||||
mel2 = mel2[:, :, :T_min]
|
||||
fea_ref = fea_ref[:, :, :T_min]
|
||||
if T_min > 468:
|
||||
mel2 = mel2[:, :, -468:]
|
||||
fea_ref = fea_ref[:, :, -468:]
|
||||
T_min = 468
|
||||
chunk_len = 934 - T_min
|
||||
# print("fea_ref",fea_ref,fea_ref.shape)
|
||||
# print("mel2",mel2)
|
||||
Tref = 468 if version == "v3" else 500
|
||||
Tchunk = 934 if version == "v3" else 1000
|
||||
if T_min > Tref:
|
||||
mel2 = mel2[:, :, -Tref:]
|
||||
fea_ref = fea_ref[:, :, -Tref:]
|
||||
T_min = Tref
|
||||
chunk_len = Tchunk - T_min
|
||||
mel2 = mel2.to(dtype)
|
||||
fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
|
||||
# print("fea_todo",fea_todo)
|
||||
# print("ge",ge.abs().mean())
|
||||
cfm_resss = []
|
||||
idx = 0
|
||||
while 1:
|
||||
@ -854,22 +1005,24 @@ def get_tts_wav(
|
||||
break
|
||||
idx += chunk_len
|
||||
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
|
||||
# set_seed(123)
|
||||
cfm_res = vq_model.cfm.inference(
|
||||
fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
|
||||
)
|
||||
cfm_res = cfm_res[:, :, mel2.shape[2] :]
|
||||
mel2 = cfm_res[:, :, -T_min:]
|
||||
# print("fea", fea)
|
||||
# print("mel2in", mel2)
|
||||
fea_ref = fea_todo_chunk[:, :, -T_min:]
|
||||
cfm_resss.append(cfm_res)
|
||||
cmf_res = torch.cat(cfm_resss, 2)
|
||||
cmf_res = denorm_spec(cmf_res)
|
||||
if bigvgan_model == None:
|
||||
init_bigvgan()
|
||||
cfm_res = torch.cat(cfm_resss, 2)
|
||||
cfm_res = denorm_spec(cfm_res)
|
||||
if version == "v3":
|
||||
if bigvgan_model == None:
|
||||
init_bigvgan()
|
||||
else: # v4
|
||||
if hifigan_model == None:
|
||||
init_hifigan()
|
||||
vocoder_model = bigvgan_model if version == "v3" else hifigan_model
|
||||
with torch.inference_mode():
|
||||
wav_gen = bigvgan_model(cmf_res)
|
||||
wav_gen = vocoder_model(cfm_res)
|
||||
audio = wav_gen[0][0].cpu().detach().numpy()
|
||||
|
||||
max_audio = np.abs(audio).max()
|
||||
@ -880,7 +1033,13 @@ def get_tts_wav(
|
||||
audio_opt = np.concatenate(audio_opt, 0)
|
||||
t4 = ttime()
|
||||
|
||||
sr = hps.data.sampling_rate if version != "v3" else 24000
|
||||
if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
|
||||
sr = 32000
|
||||
elif version == "v3":
|
||||
sr = 24000
|
||||
else:
|
||||
sr = 48000 # v4
|
||||
|
||||
if if_sr and sr == 24000:
|
||||
audio_opt = torch.from_numpy(audio_opt).float().to(device)
|
||||
audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
|
||||
@ -900,8 +1059,12 @@ def get_tts_wav(
|
||||
|
||||
if not stream_mode == "normal":
|
||||
if media_type == "wav":
|
||||
sr = 48000 if if_sr else 24000
|
||||
sr = hps.data.sampling_rate if version != "v3" else sr
|
||||
if version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
|
||||
sr = 32000
|
||||
elif version == "v3":
|
||||
sr = 48000 if if_sr else 24000
|
||||
else:
|
||||
sr = 48000 # v4
|
||||
audio_bytes = pack_wav(audio_bytes, sr)
|
||||
yield audio_bytes.getvalue()
|
||||
|
||||
@ -966,9 +1129,6 @@ def handle(
|
||||
if not default_refer.is_ready():
|
||||
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
|
||||
|
||||
if sample_steps not in [4, 8, 16, 32]:
|
||||
sample_steps = 32
|
||||
|
||||
if cut_punc == None:
|
||||
text = cut_text(text, default_cut_punc)
|
||||
else:
|
||||
@ -1071,10 +1231,10 @@ default_refer = DefaultRefer(args.default_refer_path, args.default_refer_text, a
|
||||
# 模型路径检查
|
||||
if sovits_path == "":
|
||||
sovits_path = g_config.pretrained_sovits_path
|
||||
logger.warn(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
|
||||
logger.warning(f"未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
|
||||
if gpt_path == "":
|
||||
gpt_path = g_config.pretrained_gpt_path
|
||||
logger.warn(f"未指定GPT模型路径, fallback后当前值: {gpt_path}")
|
||||
logger.warning(f"未指定GPT模型路径, fallback后当前值: {gpt_path}")
|
||||
|
||||
# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
|
||||
if default_refer.path == "" or default_refer.text == "" or default_refer.language == "":
|
||||
|
@ -33,14 +33,14 @@ POST:
"text_split_method": "cut0", # str. text split method, see text_segmentation_method.py for details.
"batch_size": 1, # int. batch size for inference
"batch_threshold": 0.75, # float. threshold for batch splitting.
"split_bucket: True, # bool. whether to split the batch into multiple buckets.
"split_bucket": True, # bool. whether to split the batch into multiple buckets.
"speed_factor":1.0, # float. control the speed of the synthesized audio.
"streaming_mode": False, # bool. whether to return a streaming response.
"seed": -1, # int. random seed for reproducibility.
"parallel_infer": True, # bool. whether to use parallel inference.
"repetition_penalty": 1.35 # float. repetition penalty for T2S model.
"repetition_penalty": 1.35, # float. repetition penalty for T2S model.
"sample_steps": 32, # int. number of sampling steps for VITS model V3.
"super_sampling": False, # bool. whether to use super-sampling for audio when using VITS model V3.
"super_sampling": False # bool. whether to use super-sampling for audio when using VITS model V3.
}
```
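For orientation, a request exercising these fields might look like the sketch below. The host, port, and `/tts` path are assumptions based on the API defaults elsewhere in the repo, and the required text and reference-audio fields from the first half of this docstring are only hinted at; verify both against api_v2.py before relying on this.

```python
import requests

payload = {
    # ...plus the required text and reference-audio fields shown earlier in this docstring...
    "text_split_method": "cut0",
    "batch_size": 1,
    "batch_threshold": 0.75,
    "split_bucket": True,
    "speed_factor": 1.0,
    "streaming_mode": False,
    "seed": -1,
    "parallel_infer": True,
    "repetition_penalty": 1.35,
    "sample_steps": 32,
    "super_sampling": False,
}
resp = requests.post("http://127.0.0.1:9880/tts", json=payload)  # endpoint and port assumed
resp.raise_for_status()
with open("output.wav", "wb") as f:
    f.write(resp.content)
```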
|
@ -1,442 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import pdb
|
||||
import signal
|
||||
import sys
|
||||
from time import time as ttime
|
||||
import torch
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
import uvicorn
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import numpy as np
|
||||
from feature_extractor import cnhubert
|
||||
from io import BytesIO
|
||||
from module.models import SynthesizerTrn
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from text import cleaned_text_to_sequence
|
||||
from text.cleaner import clean_text
|
||||
from module.mel_processing import spectrogram_torch
|
||||
from my_utils import load_audio
|
||||
import config as global_config
|
||||
|
||||
g_config = global_config.Config()
|
||||
|
||||
# AVAILABLE_COMPUTE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS api")
|
||||
|
||||
parser.add_argument("-s", "--sovits_path", type=str, default=g_config.sovits_path, help="SoVITS模型路径")
|
||||
parser.add_argument("-g", "--gpt_path", type=str, default=g_config.gpt_path, help="GPT模型路径")
|
||||
|
||||
parser.add_argument("-dr", "--default_refer_path", type=str, default="",
|
||||
help="默认参考音频路径, 请求缺少参考音频时调用")
|
||||
parser.add_argument("-dt", "--default_refer_text", type=str, default="", help="默认参考音频文本")
|
||||
parser.add_argument("-dl", "--default_refer_language", type=str, default="", help="默认参考音频语种")
|
||||
|
||||
parser.add_argument("-d", "--device", type=str, default=g_config.infer_device, help="cuda / cpu")
|
||||
parser.add_argument("-p", "--port", type=int, default=g_config.api_port, help="default: 9880")
|
||||
parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
|
||||
parser.add_argument("-fp", "--full_precision", action="store_true", default=False, help="覆盖config.is_half为False, 使用全精度")
|
||||
parser.add_argument("-hp", "--half_precision", action="store_true", default=False, help="覆盖config.is_half为True, 使用半精度")
|
||||
# bool值的用法为 `python ./api.py -fp ...`
|
||||
# 此时 full_precision==True, half_precision==False
|
||||
|
||||
parser.add_argument("-hb", "--hubert_path", type=str, default=g_config.cnhubert_path, help="覆盖config.cnhubert_path")
|
||||
parser.add_argument("-b", "--bert_path", type=str, default=g_config.bert_path, help="覆盖config.bert_path")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
sovits_path = args.sovits_path
|
||||
gpt_path = args.gpt_path
|
||||
|
||||
default_refer_path = args.default_refer_path
|
||||
default_refer_text = args.default_refer_text
|
||||
default_refer_language = args.default_refer_language
|
||||
has_preset = False
|
||||
|
||||
device = args.device
|
||||
port = args.port
|
||||
host = args.bind_addr
|
||||
|
||||
if sovits_path == "":
|
||||
sovits_path = g_config.pretrained_sovits_path
|
||||
print(f"[WARN] 未指定SoVITS模型路径, fallback后当前值: {sovits_path}")
|
||||
if gpt_path == "":
|
||||
gpt_path = g_config.pretrained_gpt_path
|
||||
print(f"[WARN] 未指定GPT模型路径, fallback后当前值: {gpt_path}")
|
||||
|
||||
# 指定默认参考音频, 调用方 未提供/未给全 参考音频参数时使用
|
||||
if default_refer_path == "" or default_refer_text == "" or default_refer_language == "":
|
||||
default_refer_path, default_refer_text, default_refer_language = "", "", ""
|
||||
print("[INFO] 未指定默认参考音频")
|
||||
has_preset = False
|
||||
else:
|
||||
print(f"[INFO] 默认参考音频路径: {default_refer_path}")
|
||||
print(f"[INFO] 默认参考音频文本: {default_refer_text}")
|
||||
print(f"[INFO] 默认参考音频语种: {default_refer_language}")
|
||||
has_preset = True
|
||||
|
||||
is_half = g_config.is_half
|
||||
if args.full_precision:
|
||||
is_half = False
|
||||
if args.half_precision:
|
||||
is_half = True
|
||||
if args.full_precision and args.half_precision:
|
||||
is_half = g_config.is_half # 炒饭fallback
|
||||
|
||||
print(f"[INFO] 半精: {is_half}")
|
||||
|
||||
cnhubert_base_path = args.hubert_path
|
||||
bert_path = args.bert_path
|
||||
|
||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
|
||||
if is_half:
|
||||
bert_model = bert_model.half().to(device)
|
||||
else:
|
||||
bert_model = bert_model.to(device)
|
||||
|
||||
|
||||
def get_bert_feature(text, word2ph):
|
||||
with torch.no_grad():
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model
|
||||
res = bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
assert len(word2ph) == len(text)
|
||||
phone_level_feature = []
|
||||
for i in range(len(word2ph)):
|
||||
repeat_feature = res[i].repeat(word2ph[i], 1)
|
||||
phone_level_feature.append(repeat_feature)
|
||||
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
||||
# if(is_half==True):phone_level_feature=phone_level_feature.half()
|
||||
return phone_level_feature.T
|
||||
|
||||
|
||||
n_semantic = 1024
|
||||
dict_s2 = torch.load(sovits_path, map_location="cpu", weights_only=False)
|
||||
hps = dict_s2["config"]
|
||||
print(hps)
|
||||
|
||||
class DictToAttrRecursive(dict):
|
||||
def __init__(self, input_dict):
|
||||
super().__init__(input_dict)
|
||||
for key, value in input_dict.items():
|
||||
if isinstance(value, dict):
|
||||
value = DictToAttrRecursive(value)
|
||||
self[key] = value
|
||||
setattr(self, key, value)
|
||||
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return self[item]
|
||||
except KeyError:
|
||||
raise AttributeError(f"Attribute {item} not found")
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if isinstance(value, dict):
|
||||
value = DictToAttrRecursive(value)
|
||||
super(DictToAttrRecursive, self).__setitem__(key, value)
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __delattr__(self, item):
|
||||
try:
|
||||
del self[item]
|
||||
except KeyError:
|
||||
raise AttributeError(f"Attribute {item} not found")
|
||||
|
||||
|
||||
hps = DictToAttrRecursive(hps)
|
||||
hps.model.semantic_frame_rate = "25hz"
|
||||
dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False)
|
||||
config = dict_s1["config"]
|
||||
ssl_model = cnhubert.get_model()
|
||||
if is_half:
|
||||
ssl_model = ssl_model.half().to(device)
|
||||
else:
|
||||
ssl_model = ssl_model.to(device)
|
||||
|
||||
vq_model = SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model)
|
||||
if is_half:
|
||||
vq_model = vq_model.half().to(device)
|
||||
else:
|
||||
vq_model = vq_model.to(device)
|
||||
vq_model.eval()
|
||||
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
||||
hz = 50
|
||||
max_sec = config['data']['max_sec']
|
||||
t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
|
||||
t2s_model.load_state_dict(dict_s1["weight"])
|
||||
if is_half:
|
||||
t2s_model = t2s_model.half()
|
||||
t2s_model = t2s_model.to(device)
|
||||
t2s_model.eval()
|
||||
total = sum([param.nelement() for param in t2s_model.parameters()])
|
||||
print("Number of parameter: %.2fM" % (total / 1e6))
|
||||
|
||||
|
||||
def get_spepc(hps, filename):
|
||||
audio = load_audio(filename, int(hps.data.sampling_rate))
|
||||
audio = torch.FloatTensor(audio)
|
||||
audio_norm = audio
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
spec = spectrogram_torch(audio_norm, hps.data.filter_length, hps.data.sampling_rate, hps.data.hop_length,
|
||||
hps.data.win_length, center=False)
|
||||
return spec
|
||||
|
||||
|
||||
dict_language = {
|
||||
"中文": "zh",
|
||||
"英文": "en",
|
||||
"日文": "ja",
|
||||
"ZH": "zh",
|
||||
"EN": "en",
|
||||
"JA": "ja",
|
||||
"zh": "zh",
|
||||
"en": "en",
|
||||
"ja": "ja"
|
||||
}
|
||||
|
||||
|
||||
def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
|
||||
t0 = ttime()
|
||||
prompt_text = prompt_text.strip("\n")
|
||||
prompt_language, text = prompt_language, text.strip("\n")
|
||||
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
|
||||
with torch.no_grad():
|
||||
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
||||
wav16k = torch.from_numpy(wav16k)
|
||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||
if (is_half == True):
|
||||
wav16k = wav16k.half().to(device)
|
||||
zero_wav_torch = zero_wav_torch.half().to(device)
|
||||
else:
|
||||
wav16k = wav16k.to(device)
|
||||
zero_wav_torch = zero_wav_torch.to(device)
|
||||
wav16k=torch.cat([wav16k,zero_wav_torch])
|
||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
||||
codes = vq_model.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
t1 = ttime()
|
||||
prompt_language = dict_language[prompt_language]
|
||||
text_language = dict_language[text_language]
|
||||
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
|
||||
phones1 = cleaned_text_to_sequence(phones1)
|
||||
texts = text.split("\n")
|
||||
audio_opt = []
|
||||
|
||||
for text in texts:
|
||||
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
|
||||
phones2 = cleaned_text_to_sequence(phones2)
|
||||
if (prompt_language == "zh"):
|
||||
bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
|
||||
else:
|
||||
bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(
|
||||
device)
|
||||
if (text_language == "zh"):
|
||||
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
|
||||
else:
|
||||
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
|
||||
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
||||
bert = bert.to(device).unsqueeze(0)
|
||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||
t2 = ttime()
|
||||
with torch.no_grad():
|
||||
# pred_semantic = t2s_model.model.infer(
|
||||
pred_semantic, idx = t2s_model.model.infer_panel(
|
||||
all_phoneme_ids,
|
||||
all_phoneme_len,
|
||||
prompt,
|
||||
bert,
|
||||
# prompt_phone_len=ph_offset,
|
||||
top_k=config['inference']['top_k'],
|
||||
early_stop_num=hz * max_sec)
|
||||
t3 = ttime()
|
||||
# print(pred_semantic.shape,idx)
|
||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
refer = get_spepc(hps, ref_wav_path) # .to(device)
|
||||
if (is_half == True):
|
||||
refer = refer.half().to(device)
|
||||
else:
|
||||
refer = refer.to(device)
|
||||
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
||||
audio = \
|
||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
|
||||
refer).detach().cpu().numpy()[
|
||||
0, 0] ###试试重建不带上prompt部分
|
||||
audio_opt.append(audio)
|
||||
audio_opt.append(zero_wav)
|
||||
t4 = ttime()
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||
# yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
|
||||
return hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16)
|
||||
def get_tts_wavs(ref_wav_path, prompt_text, prompt_language, textss, text_language):
|
||||
t0 = ttime()
|
||||
prompt_text = prompt_text.strip("\n")
|
||||
zero_wav = np.zeros(int(hps.data.sampling_rate * 0.3), dtype=np.float16 if is_half == True else np.float32)
|
||||
with torch.no_grad():
|
||||
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
||||
wav16k = torch.from_numpy(wav16k)
|
||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||
if (is_half == True):
|
||||
wav16k = wav16k.half().to(device)
|
||||
zero_wav_torch = zero_wav_torch.half().to(device)
|
||||
else:
|
||||
wav16k = wav16k.to(device)
|
||||
zero_wav_torch = zero_wav_torch.to(device)
|
||||
wav16k=torch.cat([wav16k,zero_wav_torch])
|
||||
ssl_content = ssl_model.model(wav16k.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
||||
codes = vq_model.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
t1 = ttime()
|
||||
prompt_language = dict_language[prompt_language]
|
||||
text_language = dict_language[text_language]
|
||||
phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
|
||||
phones1 = cleaned_text_to_sequence(phones1)
|
||||
audios_opt=[]
|
||||
for text0 in textss:
|
||||
texts = text0.strip("\n").split("\n")
|
||||
audio_opt = []
|
||||
for text in texts:
|
||||
text=text.strip("。")+"。"
|
||||
phones2, word2ph2, norm_text2 = clean_text(text, text_language)
|
||||
phones2 = cleaned_text_to_sequence(phones2)
|
||||
if (prompt_language == "zh"):
|
||||
bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
|
||||
else:
|
||||
bert1 = torch.zeros((1024, len(phones1)), dtype=torch.float16 if is_half == True else torch.float32).to(
|
||||
device)
|
||||
if (text_language == "zh"):
|
||||
bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
|
||||
else:
|
||||
bert2 = torch.zeros((1024, len(phones2))).to(bert1)
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
|
||||
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
||||
bert = bert.to(device).unsqueeze(0)
|
||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||
t2 = ttime()
|
||||
with torch.no_grad():
|
||||
# pred_semantic = t2s_model.model.infer(
|
||||
pred_semantic, idx = t2s_model.model.infer_panel(
|
||||
all_phoneme_ids,
|
||||
all_phoneme_len,
|
||||
prompt,
|
||||
bert,
|
||||
# prompt_phone_len=ph_offset,
|
||||
top_k=config['inference']['top_k'],
|
||||
early_stop_num=hz * max_sec)
|
||||
t3 = ttime()
|
||||
# print(pred_semantic.shape,idx)
|
||||
pred_semantic = pred_semantic[:, -idx:].unsqueeze(0) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||
refer = get_spepc(hps, ref_wav_path) # .to(device)
|
||||
if (is_half == True):
|
||||
refer = refer.half().to(device)
|
||||
else:
|
||||
refer = refer.to(device)
|
||||
# audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
|
||||
audio = \
|
||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
|
||||
refer).detach().cpu().numpy()[
|
||||
0, 0] ###试试重建不带上prompt部分
|
||||
audio_opt.append(audio)
|
||||
audio_opt.append(zero_wav)
|
||||
t4 = ttime()
|
||||
print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
|
||||
audios_opt.append([text0,(np.concatenate(audio_opt, 0) * 32768).astype(np.int16)])
|
||||
return audios_opt
|
||||
|
||||
|
||||
# get_tts_wav(r"D:\BaiduNetdiskDownload\gsv\speech\萧逸声音-你得先从滑雪的基本技巧学起.wav", "你得先从滑雪的基本技巧学起。", "中文", "我觉得还是该给喜欢的女孩子一场认真的告白。", "中文")
|
||||
# with open(r"D:\BaiduNetdiskDownload\gsv\烟嗓-todo1.txt","r",encoding="utf8")as f:
|
||||
# with open(r"D:\BaiduNetdiskDownload\gsv\年下-todo1.txt","r",encoding="utf8")as f:
|
||||
# with open(r"D:\BaiduNetdiskDownload\gsv\萧逸3b.txt","r",encoding="utf8")as f:
|
||||
with open(r"D:\BaiduNetdiskDownload\gsv\萧逸4.txt","r",encoding="utf8")as f:
|
||||
textss=f.read().split("\n")
|
||||
for idx,(text,audio)in enumerate(get_tts_wavs(r"D:\BaiduNetdiskDownload\gsv\speech\萧逸声音-你得先从滑雪的基本技巧学起.wav", "你得先从滑雪的基本技巧学起。", "中文", textss, "中文")):
|
||||
|
||||
# for idx,(text,audio)in enumerate(get_tts_wavs(r"D:\BaiduNetdiskDownload\gsv\足够的能力,去制定好自己的生活规划。低沉烟嗓.MP3_1940480_2095360.wav", "足够的能力,去制定好自己的生活规划。", "中文", textss, "中文")):
|
||||
# for idx,(text,audio)in enumerate(get_tts_wavs(r"D:\BaiduNetdiskDownload\gsv\不会呀!你前几天才吃过你还说好吃来着。年下少年音.MP3_537600_711040.wav", "不会呀!你前几天才吃过你还说好吃来着。", "中文", textss, "中文")):
|
||||
print(idx,text)
|
||||
# sf.write(r"D:\BaiduNetdiskDownload\gsv\output\烟嗓第一批\%04d-%s.wav"%(idx,text),audio,32000)
|
||||
# sf.write(r"D:\BaiduNetdiskDownload\gsv\output\年下\%04d-%s.wav"%(idx,text),audio,32000)
|
||||
sf.write(r"D:\BaiduNetdiskDownload\gsv\output\萧逸第4批\%04d-%s.wav"%(idx,text),audio,32000)
|
||||
|
||||
|
||||
# def handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language):
|
||||
# if command == "/restart":
|
||||
# os.execl(g_config.python_exec, g_config.python_exec, *sys.argv)
|
||||
# elif command == "/exit":
|
||||
# os.kill(os.getpid(), signal.SIGTERM)
|
||||
# exit(0)
|
||||
#
|
||||
# if (
|
||||
# refer_wav_path == "" or refer_wav_path is None
|
||||
# or prompt_text == "" or prompt_text is None
|
||||
# or prompt_language == "" or prompt_language is None
|
||||
# ):
|
||||
# refer_wav_path, prompt_text, prompt_language = (
|
||||
# default_refer_path,
|
||||
# default_refer_text,
|
||||
# default_refer_language,
|
||||
# )
|
||||
# if not has_preset:
|
||||
# raise HTTPException(status_code=400, detail="未指定参考音频且接口无预设")
|
||||
#
|
||||
# with torch.no_grad():
|
||||
# gen = get_tts_wav(
|
||||
# refer_wav_path, prompt_text, prompt_language, text, text_language
|
||||
# )
|
||||
# sampling_rate, audio_data = next(gen)
|
||||
#
|
||||
# wav = BytesIO()
|
||||
# sf.write(wav, audio_data, sampling_rate, format="wav")
|
||||
# wav.seek(0)
|
||||
#
|
||||
# torch.cuda.empty_cache()
|
||||
# return StreamingResponse(wav, media_type="audio/wav")
|
||||
|
||||
|
||||
# app = FastAPI()
|
||||
#
|
||||
#
|
||||
# @app.post("/")
|
||||
# async def tts_endpoint(request: Request):
|
||||
# json_post_raw = await request.json()
|
||||
# return handle(
|
||||
# json_post_raw.get("command"),
|
||||
# json_post_raw.get("refer_wav_path"),
|
||||
# json_post_raw.get("prompt_text"),
|
||||
# json_post_raw.get("prompt_language"),
|
||||
# json_post_raw.get("text"),
|
||||
# json_post_raw.get("text_language"),
|
||||
# )
|
||||
#
|
||||
#
|
||||
# @app.get("/")
|
||||
# async def tts_endpoint(
|
||||
# command: str = None,
|
||||
# refer_wav_path: str = None,
|
||||
# prompt_text: str = None,
|
||||
# prompt_language: str = None,
|
||||
# text: str = None,
|
||||
# text_language: str = None,
|
||||
# ):
|
||||
# return handle(command, refer_wav_path, prompt_text, prompt_language, text, text_language)
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# uvicorn.run(app, host=host, port=port, workers=1)
|
config.py (14 lines changed)
@ -144,7 +144,8 @@ webui_port_subfix = 9871

api_port = 9880

#Thanks to the contribution of @Karasukaigan and @XXXXRT666

# Thanks to the contribution of @Karasukaigan and @XXXXRT666
def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
cpu = torch.device("cpu")
cuda = torch.device(f"cuda:{idx}")
@ -157,10 +158,13 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, flo
mem_gb = mem_bytes / (1024**3) + 0.4
major, minor = capability
sm_version = major + minor / 10.0
is_16_series = bool(re.search(r"16\d{2}", name))and sm_version == 7.5
if mem_gb < 4 or sm_version < 5.3:return cpu, torch.float32, 0.0, 0.0
if sm_version == 6.1 or is_16_series==True:return cuda, torch.float32, sm_version, mem_gb
if sm_version > 6.1:return cuda, torch.float16, sm_version, mem_gb
is_16_series = bool(re.search(r"16\d{2}", name)) and sm_version == 7.5
if mem_gb < 4 or sm_version < 5.3:
    return cpu, torch.float32, 0.0, 0.0
if sm_version == 6.1 or is_16_series == True:
    return cuda, torch.float32, sm_version, mem_gb
if sm_version > 6.1:
    return cuda, torch.float16, sm_version, mem_gb
return cpu, torch.float32, 0.0, 0.0
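A short sketch of how this helper can be queried per GPU; only the signature comes from the diff above, while the loop and printing are illustrative assumptions:

```python
import torch

from config import get_device_dtype_sm  # import path assumed from this diff

if torch.cuda.is_available():
    for idx in range(torch.cuda.device_count()):
        device, dtype, sm, mem_gb = get_device_dtype_sm(idx)
        print(f"gpu {idx}: device={device}, dtype={dtype}, sm={sm}, mem={mem_gb:.1f} GiB")
else:
    print("no CUDA device detected; the CPU/float32 fallback will be used")
```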
|
||||
|
||||
|
||||
|
@ -12,10 +12,6 @@ services:
|
||||
- "9880:9880"
|
||||
volumes:
|
||||
- .:/workspace/GPT-SoVITS
|
||||
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models
|
||||
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel
|
||||
- /dev/null:/workspace/GPT-SoVITS/tools/asr/models
|
||||
- /dev/null:/workspace/GPT-SoVITS/tools/uvr5/uvr5_weights
|
||||
environment:
|
||||
- is_half=true
|
||||
tty: true
|
||||
@ -34,10 +30,6 @@ services:
|
||||
- "9880:9880"
|
||||
volumes:
|
||||
- .:/workspace/GPT-SoVITS
|
||||
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models
|
||||
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel
|
||||
- /dev/null:/workspace/GPT-SoVITS/tools/asr/models
|
||||
- /dev/null:/workspace/GPT-SoVITS/tools/uvr5/uvr5_weights
|
||||
- tools/asr/models:/workspace/models/asr_models
|
||||
- tools/uvr5/uvr5_weights:/workspace/models/uvr5_weights
|
||||
environment:
|
||||
@ -58,10 +50,6 @@ services:
|
||||
- "9880:9880"
|
||||
volumes:
|
||||
- .:/workspace/GPT-SoVITS
|
||||
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models
|
||||
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel
|
||||
- /dev/null:/workspace/GPT-SoVITS/tools/asr/models
|
||||
- /dev/null:/workspace/GPT-SoVITS/tools/uvr5/uvr5_weights
|
||||
environment:
|
||||
- is_half=true
|
||||
tty: true
|
||||
@ -80,10 +68,6 @@ services:
|
||||
- "9880:9880"
|
||||
volumes:
|
||||
- .:/workspace/GPT-SoVITS
|
||||
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models
|
||||
- /dev/null:/workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel
|
||||
- /dev/null:/workspace/GPT-SoVITS/tools/asr/models
|
||||
- /dev/null:/workspace/GPT-SoVITS/tools/uvr5/uvr5_weights
|
||||
- tools/asr/models:/workspace/models/asr_models
|
||||
- tools/uvr5/uvr5_weights:/workspace/models/uvr5_weights
|
||||
environment:
|
||||
|
@ -578,3 +578,19 @@
|
||||
- 内容: 优化精度自动检测逻辑, 给 WebUI 前端界面模块增加折叠功能.
|
||||
- 类型: 新功能
|
||||
- 提交: XXXXRT666, RVC-Boss
|
||||
- 2025.06.06 [PR#2427](https://github.com/RVC-Boss/GPT-SoVITS/pull/2427)
|
||||
- 内容: X一X型多音字判断修复
|
||||
- 类型: 修复
|
||||
- 提交: wzy3650
|
||||
- 2025.06.05 [PR#2439](https://github.com/RVC-Boss/GPT-SoVITS/pull/2439)
|
||||
- 内容: 配置修复;sovits模型读取修复
|
||||
- 类型: 修复
|
||||
- 提交: wzy3650
|
||||
- 2025.06.09 [Commit#8056efe4](https://github.com/RVC-Boss/GPT-SoVITS/commit/8056efe4ab7bbc3610c72ae356a6f37518441f7d)
|
||||
- 内容: 修复ge.sum数值可能爆炸导致推理无声的问题
|
||||
- 类型: 修复
|
||||
- 提交: RVC-Boss
|
||||
- 2025.06.10 [Commit#2c0436b9](https://github.com/RVC-Boss/GPT-SoVITS/commit/2c0436b9ce397424ae03476c836fb64c6e5ebcc6)
|
||||
- 内容: 修复实验名结尾出现空格在win中路径不正确的问题
|
||||
- 类型: 修复
|
||||
- 提交: RVC-Boss
|
||||
|
@ -7,12 +7,14 @@
|
||||
|
||||
<a href="https://trendshift.io/repositories/7033" target="_blank"><img src="https://trendshift.io/api/badge/repositories/7033" alt="RVC-Boss%2FGPT-SoVITS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
<!-- img src="https://counter.seku.su/cmoe?name=gptsovits&theme=r34" /><br> -->
|
||||
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
|
||||
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
|
||||
[](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
|
||||
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
|
||||
|
||||
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/colab_webui.ipynb)
|
||||
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
|
||||
[](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
|
||||
[](https://discord.gg/dnrgs5GHfG)
|
||||
[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
|
||||
[](https://rentry.co/GPT-SoVITS-guide#/)
|
||||
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/docs/cn/Changelog_CN.md)
|
||||
|
||||
[**English**](../../README.md) | **中文简体** | [**日本語**](../ja/README.md) | [**한국어**](../ko/README.md) | [**Türkçe**](../tr/README.md)
|
||||
|
||||
@ -128,8 +130,9 @@ brew install ffmpeg
|
||||
|
||||
- 前往 [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits) 查看最新可用的镜像标签(tags)
|
||||
- 根据你的运行环境选择合适的镜像标签
|
||||
- `Lite` Docker 镜像不包含 ASR 模型和 UVR5 模型. 你可以自行下载 UVR5 模型, ASR 模型则会在需要时由程序自动下载
|
||||
- `Lite` Docker 镜像**不包含** ASR 模型和 UVR5 模型. 你可以自行下载 UVR5 模型, ASR 模型则会在需要时由程序自动下载
|
||||
- 在使用 Docker Compose 时, 会自动拉取适配的架构镜像 (amd64 或 arm64)
|
||||
- Docker Compose 将会挂载当前目录的**所有文件**, 请在使用 Docker 镜像前先切换到项目根目录并**拉取代码更新**
|
||||
- 可选:为了获得最新的更改, 你可以使用提供的 Dockerfile 在本地构建镜像
|
||||
|
||||
#### 环境变量
|
||||
@ -329,7 +332,7 @@ python webui.py
|
||||
新特性:
|
||||
|
||||
1. **相比 V2 占用稍高显存, 性能超过 V4, 在保留 V2 硬件成本和推理速度优势的同时实现更高音质.**
|
||||
[更多详情](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7))
|
||||
[更多详情](<https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)>)
|
||||
|
||||
2. V1/V2 与 V2Pro 系列具有相同特性, V3/V4 则具备相近功能. 对于平均音频质量较低的训练集, V1/V2/V2Pro 可以取得较好的效果, 但 V3/V4 无法做到. 此外, V3/V4 合成的声音更偏向参考音频, 而不是整体训练集的风格.
|
||||
|
||||
|
@ -5,12 +5,16 @@
|
||||
|
||||
[](https://github.com/RVC-Boss/GPT-SoVITS)
|
||||
|
||||
<img src="https://counter.seku.su/cmoe?name=gptsovits&theme=r34" /><br>
|
||||
<a href="https://trendshift.io/repositories/7033" target="_blank"><img src="https://trendshift.io/api/badge/repositories/7033" alt="RVC-Boss%2FGPT-SoVITS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/colab_webui.ipynb)
|
||||
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
|
||||
[](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
|
||||
[](https://discord.gg/dnrgs5GHfG)
|
||||
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
|
||||
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
|
||||
[](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
|
||||
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
|
||||
|
||||
[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
|
||||
[](https://rentry.co/GPT-SoVITS-guide#/)
|
||||
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/docs/ja/Changelog_JA.md)
|
||||
|
||||
[**English**](../../README.md) | [**中文简体**](../cn/README.md) | **日本語** | [**한국어**](../ko/README.md) | [**Türkçe**](../tr/README.md)
|
||||
|
||||
@ -122,8 +126,9 @@ brew install ffmpeg
|
||||
|
||||
- [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits) で最新のイメージタグを確認してください
|
||||
- 環境に合った適切なイメージタグを選択してください
|
||||
- `Lite` とは、Docker イメージに ASR モデルおよび UVR5 モデルが含まれていないことを意味します. UVR5 モデルは手動でダウンロードし、ASR モデルは必要に応じてプログラムが自動的にダウンロードします
|
||||
- `Lite` とは、Docker イメージに ASR モデルおよび UVR5 モデルが**含まれていない**ことを意味します. UVR5 モデルは手動でダウンロードし、ASR モデルは必要に応じてプログラムが自動的にダウンロードします
|
||||
- Docker Compose 実行時に、対応するアーキテクチャ (amd64 または arm64) のイメージが自動的に取得されます
|
||||
- Docker Compose は現在のディレクトリ内の**すべてのファイル**をマウントします. Docker イメージを使用する前に、プロジェクトのルートディレクトリに移動し、**コードを最新の状態に更新**してください
|
||||
- オプション:最新の変更を反映させるため、提供されている Dockerfile を使ってローカルでイメージをビルドすることも可能です
|
||||
|
||||
#### 環境変数
|
||||
@ -304,7 +309,7 @@ v2 環境から v3 を使用する方法:
|
||||
新機能:
|
||||
|
||||
1. **V4 は、V3 で発生していた非整数倍アップサンプリングによる金属音の問題を修正し、音声がこもる問題を防ぐためにネイティブに 48kHz 音声を出力します(V3 はネイティブに 24kHz 音声のみ出力)**. 作者は V4 を V3 の直接的な置き換えとして推奨していますが、さらなるテストが必要です.
|
||||
[詳細はこちら](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90v3v4%E2%80%90features-(%E6%96%B0%E7%89%B9%E6%80%A7))
|
||||
[詳細はこちら](<https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90v3v4%E2%80%90features-(%E6%96%B0%E7%89%B9%E6%80%A7)>)
|
||||
|
||||
V1/V2/V3 環境から V4 への移行方法:
|
||||
|
||||
@ -319,7 +324,7 @@ V1/V2/V3 環境から V4 への移行方法:
|
||||
新機能:
|
||||
|
||||
1. **V2 と比較してやや高いメモリ使用量ですが、ハードウェアコストと推論速度は維持しつつ、V4 よりも高い性能と音質を実現します. **
|
||||
[詳細はこちら](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7))
|
||||
[詳細はこちら](<https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)>)
|
||||
|
||||
2. V1/V2 と V2Pro シリーズは類似した特徴を持ち、V3/V4 も同様の機能を持っています. 平均音質が低いトレーニングセットの場合、V1/V2/V2Pro は良好な結果を出すことができますが、V3/V4 では対応できません. また、V3/V4 の合成音声はトレーニング全体ではなく、より参考音声に寄った音質になります.
|
||||
|
||||
|
@ -5,12 +5,16 @@
|
||||
|
||||
[](https://github.com/RVC-Boss/GPT-SoVITS)
|
||||
|
||||
<img src="https://counter.seku.su/cmoe?name=gptsovits&theme=r34" /><br>
|
||||
<a href="https://trendshift.io/repositories/7033" target="_blank"><img src="https://trendshift.io/api/badge/repositories/7033" alt="RVC-Boss%2FGPT-SoVITS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/colab_webui.ipynb)
|
||||
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
|
||||
[](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
|
||||
[](https://discord.gg/dnrgs5GHfG)
|
||||
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
|
||||
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
|
||||
[](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
|
||||
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
|
||||
|
||||
[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
|
||||
[](https://rentry.co/GPT-SoVITS-guide#/)
|
||||
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/docs/ko/Changelog_KO.md)
|
||||
|
||||
[**English**](../../README.md) | [**中文简体**](../cn/README.md) | [**日本語**](../ja/README.md) | **한국어** | [**Türkçe**](../tr/README.md)
|
||||
|
||||
@ -122,8 +126,9 @@ brew install ffmpeg
|
||||
|
||||
- [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits)에서 최신 이미지 태그를 확인하세요
|
||||
- 환경에 맞는 적절한 이미지 태그를 선택하세요
|
||||
- `Lite` 는 Docker 이미지에 ASR 모델과 UVR5 모델이 포함되어 있지 않음을 의미합니다. UVR5 모델은 사용자가 직접 다운로드해야 하며, ASR 모델은 필요 시 프로그램이 자동으로 다운로드합니다
|
||||
- `Lite` 는 Docker 이미지에 ASR 모델과 UVR5 모델이 **포함되어 있지 않음**을 의미합니다. UVR5 모델은 사용자가 직접 다운로드해야 하며, ASR 모델은 필요 시 프로그램이 자동으로 다운로드합니다
|
||||
- Docker Compose 실행 시, 해당 아키텍처에 맞는 이미지(amd64 또는 arm64)가 자동으로 다운로드됩니다
|
||||
- Docker Compose는 현재 디렉터리의 **모든 파일**을 마운트합니다. Docker 이미지를 사용하기 전에 프로젝트 루트 디렉터리로 이동하여 코드를 **최신 상태로 업데이트**하세요
|
||||
- 선택 사항: 최신 변경사항을 반영하려면 제공된 Dockerfile을 사용하여 로컬에서 직접 이미지를 빌드할 수 있습니다
|
||||
|
||||
#### 환경 변수
|
||||
@ -319,7 +324,7 @@ V1/V2/V3 환경에서 V4로 전환 방법:
|
||||
신규 기능:
|
||||
|
||||
1. **V2보다 약간 높은 VRAM 사용량이지만 성능은 V4보다 우수하며, V2 수준의 하드웨어 비용과 속도를 유지합니다**.
|
||||
[자세히 보기](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7))
|
||||
[자세히 보기](<https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)>)
|
||||
|
||||
2. V1/V2와 V2Pro 시리즈는 유사한 특징을 가지며, V3/V4도 비슷한 기능을 가지고 있습니다. 평균 음질이 낮은 학습 데이터셋에서는 V1/V2/V2Pro가 좋은 결과를 내지만 V3/V4는 그렇지 못합니다. 또한 V3/V4의 합성 음색은 전체 학습 데이터셋보다는 참고 음성에 더 가깝습니다.
|
||||
|
||||
|
@ -7,12 +7,14 @@ Güçlü Birkaç Örnekli Ses Dönüştürme ve Metinden Konuşmaya Web Arayüz
|
||||
|
||||
<a href="https://trendshift.io/repositories/7033" target="_blank"><img src="https://trendshift.io/api/badge/repositories/7033" alt="RVC-Boss%2FGPT-SoVITS | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
|
||||
<!-- img src="https://counter.seku.su/cmoe?name=gptsovits&theme=r34" /><br> -->
|
||||
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
|
||||
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
|
||||
[](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
|
||||
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
|
||||
|
||||
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/colab_webui.ipynb)
|
||||
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/LICENSE)
|
||||
[](https://huggingface.co/spaces/lj1995/GPT-SoVITS-v2)
|
||||
[](https://discord.gg/dnrgs5GHfG)
|
||||
[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
|
||||
[](https://rentry.co/GPT-SoVITS-guide#/)
|
||||
[](https://github.com/RVC-Boss/GPT-SoVITS/blob/main/docs/tr/Changelog_TR.md)
|
||||
|
||||
[**English**](../../README.md) | [**中文简体**](../cn/README.md) | [**日本語**](../ja/README.md) | [**한국어**](../ko/README.md) | **Türkçe**
|
||||
|
||||
@ -124,8 +126,9 @@ Kod tabanı hızla geliştiği halde Docker imajları daha yavaş yayınlandığ
|
||||
|
||||
- En güncel kullanılabilir imaj etiketlerini görmek için [Docker Hub](https://hub.docker.com/r/xxxxrt666/gpt-sovits) adresini kontrol edin
|
||||
- Ortamınıza uygun bir imaj etiketi seçin
|
||||
- `Lite`, Docker imajında ASR modelleri ve UVR5 modellerinin bulunmadığı anlamına gelir. UVR5 modellerini manuel olarak indirebilirsiniz; ASR modelleri ise gerektiğinde program tarafından otomatik olarak indirilir
|
||||
- `Lite`, Docker imajında ASR modelleri ve UVR5 modellerinin **bulunmadığı** anlamına gelir. UVR5 modellerini manuel olarak indirebilirsiniz; ASR modelleri ise gerektiğinde program tarafından otomatik olarak indirilir
|
||||
- Docker Compose sırasında, uygun mimariye (amd64 veya arm64) ait imaj otomatik olarak indirilir
|
||||
- Docker Compose, mevcut dizindeki **tüm dosyaları** bağlayacaktır. Docker imajını kullanmadan önce lütfen proje kök dizinine geçin ve **en son kodu çekin**
|
||||
- Opsiyonel: En güncel değişiklikleri almak için, sağlanan Dockerfile ile yerel olarak imajı kendiniz oluşturabilirsiniz
|
||||
|
||||
#### Ortam Değişkenleri
|
||||
@ -323,7 +326,7 @@ V1/V2/V3 ortamından V4'e geçiş:
|
||||
Yeni Özellikler:
|
||||
|
||||
1. **V2 ile karşılaştırıldığında biraz daha yüksek VRAM kullanımı sağlar ancak V4'ten daha iyi performans gösterir; aynı donanım maliyeti ve hız avantajını korur**.
|
||||
[Daha fazla bilgi](https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7))
|
||||
[Daha fazla bilgi](<https://github.com/RVC-Boss/GPT-SoVITS/wiki/GPT%E2%80%90SoVITS%E2%80%90features-(%E5%90%84%E7%89%88%E6%9C%AC%E7%89%B9%E6%80%A7)>)
|
||||
|
||||
2. V1/V2 ve V2Pro serisi benzer özelliklere sahipken, V3/V4 de yakın işlevleri paylaşır. Ortalama kalite düşük olan eğitim setleriyle V1/V2/V2Pro iyi sonuçlar verebilir ama V3/V4 veremez. Ayrıca, V3/V4’ün ürettiği ses tonu genel eğitim setine değil, referans ses örneğine daha çok benzemektedir.
|
||||
|
||||
|
install.sh (222 lines changed)
@ -5,15 +5,61 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
|
||||
|
||||
cd "$SCRIPT_DIR" || exit 1
|
||||
|
||||
set -e
|
||||
RESET="\033[0m"
|
||||
BOLD="\033[1m"
|
||||
ERROR="\033[1;31m[ERROR]: $RESET"
|
||||
WARNING="\033[1;33m[WARNING]: $RESET"
|
||||
INFO="\033[1;32m[INFO]: $RESET"
|
||||
SUCCESS="\033[1;34m[SUCCESS]: $RESET"
|
||||
|
||||
set -eE
|
||||
set -o errtrace
|
||||
|
||||
trap 'on_error $LINENO "$BASH_COMMAND" $?' ERR
|
||||
|
||||
# shellcheck disable=SC2317
|
||||
on_error() {
|
||||
local lineno="$1"
|
||||
local cmd="$2"
|
||||
local code="$3"
|
||||
|
||||
echo -e "${ERROR}${BOLD}Command \"${cmd}\" Failed${RESET} at ${BOLD}Line ${lineno}${RESET} with Exit Code ${BOLD}${code}${RESET}"
|
||||
echo -e "${ERROR}${BOLD}Call Stack:${RESET}"
|
||||
for ((i = ${#FUNCNAME[@]} - 1; i >= 1; i--)); do
|
||||
echo -e " in ${BOLD}${FUNCNAME[i]}()${RESET} at ${BASH_SOURCE[i]}:${BOLD}${BASH_LINENO[i - 1]}${RESET}"
|
||||
done
|
||||
exit "$code"
|
||||
}
|
||||
|
||||
run_conda_quiet() {
|
||||
local output
|
||||
output=$(conda install --yes --quiet -c conda-forge "$@" 2>&1) || {
|
||||
echo -e "${ERROR} Conda install failed:\n$output"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
|
||||
run_pip_quiet() {
|
||||
local output
|
||||
output=$(pip install "$@" 2>&1) || {
|
||||
echo -e "${ERROR} Pip install failed:\n$output"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
|
||||
run_wget_quiet() {
|
||||
local output
|
||||
output=$(wget --tries=25 --wait=5 --read-timeout=40 --retry-on-http-error=404 "$@" 2>&1) || {
|
||||
echo -e "${ERROR} Wget failed:\n$output"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
|
||||
if ! command -v conda &>/dev/null; then
|
||||
echo "Conda Not Found"
|
||||
echo -e "${ERROR}Conda Not Found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
trap 'echo "Error Occured at \"$BASH_COMMAND\" with exit code $?"; exit 1' ERR
|
||||
|
||||
USE_CUDA=false
|
||||
USE_ROCM=false
|
||||
USE_CPU=false
|
||||
@ -34,8 +80,8 @@ print_help() {
|
||||
echo " -h, --help Show this help message and exit"
|
||||
echo ""
|
||||
echo "Examples:"
|
||||
echo " bash install.sh --source HF --download-uvr5"
|
||||
echo " bash install.sh --source ModelScope"
|
||||
echo " bash install.sh --device CU128 --source HF --download-uvr5"
|
||||
echo " bash install.sh --device MPS --source ModelScope"
|
||||
}
|
||||
|
||||
# Show help if no arguments provided
|
||||
@ -59,8 +105,8 @@ while [[ $# -gt 0 ]]; do
|
||||
USE_MODELSCOPE=true
|
||||
;;
|
||||
*)
|
||||
echo "Error: Invalid Download Source: $2"
|
||||
echo "Choose From: [HF, HF-Mirror, ModelScope]"
|
||||
echo -e "${ERROR}Error: Invalid Download Source: $2"
|
||||
echo -e "${ERROR}Choose From: [HF, HF-Mirror, ModelScope]"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
@ -86,8 +132,8 @@ while [[ $# -gt 0 ]]; do
|
||||
USE_CPU=true
|
||||
;;
|
||||
*)
|
||||
echo "Error: Invalid Device: $2"
|
||||
echo "Choose From: [CU126, CU128, ROCM, MPS, CPU]"
|
||||
echo -e "${ERROR}Error: Invalid Device: $2"
|
||||
echo -e "${ERROR}Choose From: [CU126, CU128, ROCM, MPS, CPU]"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
@ -102,22 +148,23 @@ while [[ $# -gt 0 ]]; do
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Unknown Argument: $1"
|
||||
echo "Use -h or --help to see available options."
|
||||
echo -e "${ERROR}Unknown Argument: $1"
|
||||
echo ""
|
||||
print_help
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if ! $USE_CUDA && ! $USE_ROCM && ! $USE_CPU; then
|
||||
echo "Error: Device is REQUIRED"
|
||||
echo -e "${ERROR}Error: Device is REQUIRED"
|
||||
echo ""
|
||||
print_help
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! $USE_HF && ! $USE_HF_MIRROR && ! $USE_MODELSCOPE; then
|
||||
echo "Error: Download Source is REQUIRED"
|
||||
echo -e "${ERROR}Error: Download Source is REQUIRED"
|
||||
echo ""
|
||||
print_help
|
||||
exit 1
|
||||
@ -125,55 +172,65 @@ fi
|
||||
|
||||
# 安装构建工具
|
||||
# Install build tools
|
||||
echo -e "${INFO}Detected system: $(uname -s) $(uname -r) $(uname -m)"
|
||||
if [ "$(uname)" != "Darwin" ]; then
|
||||
gcc_major_version=$(command -v gcc >/dev/null 2>&1 && gcc -dumpversion | cut -d. -f1 || echo 0)
|
||||
if [ "$gcc_major_version" -lt 11 ]; then
|
||||
echo "Installing GCC & G++..."
|
||||
conda install -c conda-forge gcc=11 gxx=11 -q -y
|
||||
echo -e "${INFO}Installing GCC & G++..."
|
||||
run_conda_quiet gcc=11 gxx=11
|
||||
echo -e "${SUCCESS}GCC & G++ Installed..."
|
||||
else
|
||||
echo "GCC >=11"
|
||||
echo -e "${INFO}Detected GCC Version: $gcc_major_version"
|
||||
echo -e "${INFO}Skip Installing GCC & G++ From Conda-Forge"
|
||||
fi
|
||||
else
|
||||
if ! xcode-select -p &>/dev/null; then
|
||||
echo "Installing Xcode Command Line Tools..."
|
||||
echo -e "${INFO}Installing Xcode Command Line Tools..."
|
||||
xcode-select --install
|
||||
fi
|
||||
echo "Waiting For Xcode Command Line Tools Installation Complete..."
|
||||
while true; do
|
||||
sleep 20
|
||||
echo -e "${INFO}Waiting For Xcode Command Line Tools Installation Complete..."
|
||||
while true; do
|
||||
sleep 20
|
||||
|
||||
if xcode-select -p &>/dev/null; then
|
||||
echo "Xcode Command Line Tools Installed"
|
||||
break
|
||||
else
|
||||
echo "Installing,Please Wait..."
|
||||
if xcode-select -p &>/dev/null; then
|
||||
echo -e "${SUCCESS}Xcode Command Line Tools Installed"
|
||||
break
|
||||
else
|
||||
echo -e "${INFO}Installing,Please Wait..."
|
||||
fi
|
||||
done
|
||||
else
|
||||
XCODE_PATH=$(xcode-select -p)
|
||||
if [[ "$XCODE_PATH" == *"Xcode.app"* ]]; then
|
||||
echo -e "${WARNING} Detected Xcode path: $XCODE_PATH"
|
||||
echo -e "${WARNING} If your Xcode version does not match your macOS version, it may cause unexpected issues during compilation or package builds."
|
||||
fi
|
||||
done
|
||||
conda install -c conda-forge -q -y
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Installing ffmpeg and cmake..."
|
||||
conda install ffmpeg cmake make -q -y
|
||||
echo -e "${INFO}Installing FFmpeg & CMake..."
|
||||
run_conda_quiet ffmpeg cmake make
|
||||
echo -e "${SUCCESS}FFmpeg & CMake Installed"
|
||||
|
||||
echo "Installing unzip..."
|
||||
conda install unzip -y --quiet
|
||||
echo -e "${INFO}Installing unzip..."
|
||||
run_conda_quiet unzip
|
||||
echo -e "${SUCCESS}unzip Installed"
|
||||
|
||||
if [ "$USE_HF" = "true" ]; then
|
||||
echo "Download Model From HuggingFace"
|
||||
echo -e "${INFO}Download Model From HuggingFace"
|
||||
PRETRINED_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/pretrained_models.zip"
|
||||
G2PW_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/G2PWModel.zip"
|
||||
UVR5_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/uvr5_weights.zip"
|
||||
NLTK_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/nltk_data.zip"
|
||||
PYOPENJTALK_URL="https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/open_jtalk_dic_utf_8-1.11.tar.gz"
|
||||
elif [ "$USE_HF_MIRROR" = "true" ]; then
|
||||
echo "Download Model From HuggingFace-Mirror"
|
||||
echo -e "${INFO}Download Model From HuggingFace-Mirror"
|
||||
PRETRINED_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/pretrained_models.zip"
|
||||
G2PW_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/G2PWModel.zip"
|
||||
UVR5_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/uvr5_weights.zip"
|
||||
NLTK_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/nltk_data.zip"
|
||||
PYOPENJTALK_URL="https://hf-mirror.com/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/open_jtalk_dic_utf_8-1.11.tar.gz"
|
||||
elif [ "$USE_MODELSCOPE" = "true" ]; then
|
||||
echo "Download Model From ModelScope"
|
||||
echo -e "${INFO}Download Model From ModelScope"
|
||||
PRETRINED_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/pretrained_models.zip"
|
||||
G2PW_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/G2PWModel.zip"
|
||||
UVR5_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/uvr5_weights.zip"
|
||||
@ -181,118 +238,129 @@ elif [ "$USE_MODELSCOPE" = "true" ]; then
|
||||
PYOPENJTALK_URL="https://www.modelscope.cn/models/XXXXRT/GPT-SoVITS-Pretrained/resolve/master/open_jtalk_dic_utf_8-1.11.tar.gz"
|
||||
fi
|
||||
|
||||
if [ "$WORKFLOW" = "true" ]; then
|
||||
WGET_CMD=(wget -nv --tries=25 --wait=5 --read-timeout=40 --retry-on-http-error=404)
|
||||
else
|
||||
WGET_CMD=(wget --tries=25 --wait=5 --read-timeout=40 --retry-on-http-error=404)
|
||||
fi
|
||||
|
||||
if find -L "GPT_SoVITS/pretrained_models" -mindepth 1 ! -name '.gitignore' | grep -q .; then
|
||||
echo "Pretrained Model Exists"
|
||||
echo -e "${INFO}Pretrained Model Exists"
|
||||
echo -e "${INFO}Skip Downloading Pretrained Models"
|
||||
else
|
||||
echo "Download Pretrained Models"
|
||||
"${WGET_CMD[@]}" "$PRETRINED_URL"
|
||||
echo -e "${INFO}Downloading Pretrained Models..."
|
||||
rm -rf pretrained_models.zip
|
||||
run_wget_quiet "$PRETRINED_URL"
|
||||
|
||||
unzip -q -o pretrained_models.zip -d GPT_SoVITS
|
||||
rm -rf pretrained_models.zip
|
||||
echo -e "${SUCCESS}Pretrained Models Downloaded"
|
||||
fi
|
||||
|
||||
if [ ! -d "GPT_SoVITS/text/G2PWModel" ]; then
|
||||
echo "Download G2PWModel"
|
||||
"${WGET_CMD[@]}" "$G2PW_URL"
|
||||
echo -e "${INFO}Downloading G2PWModel.."
|
||||
rm -rf G2PWModel.zip
|
||||
run_wget_quiet "$G2PW_URL"
|
||||
|
||||
unzip -q -o G2PWModel.zip -d GPT_SoVITS/text
|
||||
rm -rf G2PWModel.zip
|
||||
echo -e "${SUCCESS}G2PWModel Downloaded"
|
||||
else
|
||||
echo "G2PWModel Exists"
|
||||
echo -e "${INFO}G2PWModel Exists"
|
||||
echo -e "${INFO}Skip Downloading G2PWModel"
|
||||
fi
|
||||
|
||||
if [ "$DOWNLOAD_UVR5" = "true" ]; then
|
||||
if find -L "tools/uvr5/uvr5_weights" -mindepth 1 ! -name '.gitignore' | grep -q .; then
|
||||
echo "UVR5 Model Exists"
|
||||
echo -e"${INFO}UVR5 Models Exists"
|
||||
echo -e "${INFO}Skip Downloading UVR5 Models"
|
||||
else
|
||||
echo "Download UVR5 Model"
|
||||
"${WGET_CMD[@]}" "$UVR5_URL"
|
||||
echo -e "${INFO}Downloading UVR5 Models..."
|
||||
rm -rf uvr5_weights.zip
|
||||
run_wget_quiet "$UVR5_URL"
|
||||
|
||||
unzip -q -o uvr5_weights.zip -d tools/uvr5
|
||||
rm -rf uvr5_weights.zip
|
||||
echo -e "${SUCCESS}UVR5 Models Downloaded"
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then
|
||||
echo "Checking for CUDA installation..."
|
||||
echo -e "${INFO}Checking For Nvidia Driver Installation..."
|
||||
if command -v nvidia-smi &>/dev/null; then
|
||||
echo "CUDA found."
|
||||
echo "${INFO}Nvidia Driver Founded"
|
||||
else
|
||||
echo -e "${WARNING}Nvidia Driver Not Found, Fallback to CPU"
|
||||
USE_CUDA=false
|
||||
USE_CPU=true
|
||||
echo "CUDA not found."
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "$USE_ROCM" = true ] && [ "$WORKFLOW" = false ]; then
|
||||
echo "Checking for ROCm installation..."
|
||||
echo -e "${INFO}Checking For ROCm Installation..."
|
||||
if [ -d "/opt/rocm" ]; then
|
||||
echo "ROCm found."
|
||||
echo -e "${INFO}ROCm Founded"
|
||||
if grep -qi "microsoft" /proc/version; then
|
||||
echo "You are running WSL."
|
||||
echo -e "${INFO}WSL2 Founded"
|
||||
IS_WSL=true
|
||||
else
|
||||
echo "You are NOT running WSL."
|
||||
IS_WSL=false
|
||||
fi
|
||||
else
|
||||
echo -e "${WARNING}ROCm Not Found, Fallback to CPU"
|
||||
USE_ROCM=false
|
||||
USE_CPU=true
|
||||
echo "ROCm not found."
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then
|
||||
echo "Installing PyTorch with CUDA support..."
|
||||
if [ "$CUDA" = 128 ]; then
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||
echo -e "${INFO}Installing PyTorch For CUDA 12.8..."
|
||||
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cu128"
|
||||
elif [ "$CUDA" = 126 ]; then
|
||||
pip install torch==2.6 torchaudio --index-url https://download.pytorch.org/whl/cu126
|
||||
echo -e "${INFO}Installing PyTorch For CUDA 12.6..."
|
||||
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cu126"
|
||||
fi
|
||||
elif [ "$USE_ROCM" = true ] && [ "$WORKFLOW" = false ]; then
|
||||
echo "Installing PyTorch with ROCm support..."
|
||||
pip install torch==2.6 torchaudio --index-url https://download.pytorch.org/whl/rocm6.2
|
||||
echo -e "${INFO}Installing PyTorch For ROCm 6.2..."
|
||||
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/rocm6.2"
|
||||
elif [ "$USE_CPU" = true ] && [ "$WORKFLOW" = false ]; then
|
||||
echo "Installing PyTorch for CPU..."
|
||||
pip install torch==2.6 torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
echo -e "${INFO}Installing PyTorch For CPU..."
|
||||
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cpu"
|
||||
elif [ "$WORKFLOW" = false ]; then
|
||||
echo "Unknown Err"
|
||||
echo -e "${ERROR}Unknown Err"
|
||||
exit 1
|
||||
fi
|
||||
echo -e "${SUCCESS}PyTorch Installed"
|
||||
|
||||
echo "Installing Python dependencies from requirements.txt..."
|
||||
echo -e "${INFO}Installing Python Dependencies From requirements.txt..."
|
||||
|
||||
# 刷新环境
|
||||
# Refresh environment
|
||||
hash -r
|
||||
|
||||
pip install -r extra-req.txt --no-deps --quiet
|
||||
run_pip_quiet -r extra-req.txt --no-deps
|
||||
|
||||
pip install -r requirements.txt --quiet
|
||||
run_pip_quiet -r requirements.txt
|
||||
|
||||
echo -e "${SUCCESS}Python Dependencies Installed"
|
||||
|
||||
PY_PREFIX=$(python -c "import sys; print(sys.prefix)")
|
||||
PYOPENJTALK_PREFIX=$(python -c "import os, pyopenjtalk; print(os.path.dirname(pyopenjtalk.__file__))")
|
||||
|
||||
"${WGET_CMD[@]}" "$NLTK_URL" -O nltk_data.zip
|
||||
echo -e "${INFO}Downloading NLTK Data..."
|
||||
rm -rf nltk_data.zip
|
||||
run_wget_quiet "$NLTK_URL" -O nltk_data.zip
|
||||
unzip -q -o nltk_data -d "$PY_PREFIX"
|
||||
rm -rf nltk_data.zip
|
||||
echo -e "${SUCCESS}NLTK Data Downloaded"
|
||||
|
||||
"${WGET_CMD[@]}" "$PYOPENJTALK_URL" -O open_jtalk_dic_utf_8-1.11.tar.gz
|
||||
tar -xvzf open_jtalk_dic_utf_8-1.11.tar.gz -C "$PYOPENJTALK_PREFIX"
|
||||
echo -e "${INFO}Downloading Open JTalk Dict..."
|
||||
rm -rf open_jtalk_dic_utf_8-1.11.tar.gz
|
||||
run_wget_quiet "$PYOPENJTALK_URL" -O open_jtalk_dic_utf_8-1.11.tar.gz
|
||||
tar -xzf open_jtalk_dic_utf_8-1.11.tar.gz -C "$PYOPENJTALK_PREFIX"
|
||||
rm -rf open_jtalk_dic_utf_8-1.11.tar.gz
|
||||
echo -e "${SUCCESS}Open JTalk Dic Downloaded"
|
||||
|
||||
if [ "$USE_ROCM" = true ] && [ "$IS_WSL" = true ]; then
|
||||
echo "Update to WSL compatible runtime lib..."
|
||||
echo -e "${INFO}Updating WSL Compatible Runtime Lib For ROCm..."
|
||||
location=$(pip show torch | grep Location | awk -F ": " '{print $2}')
|
||||
cd "${location}"/torch/lib/ || exit
|
||||
rm libhsa-runtime64.so*
|
||||
cp /opt/rocm/lib/libhsa-runtime64.so.1.2 libhsa-runtime64.so
|
||||
echo -e "${SUCCESS}ROCm Runtime Lib Updated..."
|
||||
fi
|
||||
|
||||
echo "Installation completed successfully!"
|
||||
echo -e "${SUCCESS}Installation Completed"
|
||||
|
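As a usage note for the rewritten installer above, the sketch below shows one typical end-to-end invocation. It assumes a machine with an NVIDIA driver suitable for CUDA 12.8, HuggingFace as the download source, and an already-activated conda environment, since the script installs its toolchain (gcc/gxx, ffmpeg, cmake, unzip) through conda and PyTorch through pip; the flags come from the script's own `--help` output shown earlier.

```bash
# From the repository root (path shown here is illustrative)
cd GPT-SoVITS

# CUDA 12.8 build, HuggingFace model source, also fetch the UVR5 weights
bash install.sh --device CU128 --source HF --download-uvr5

# Other devices follow the same pattern, e.g.:
# bash install.sh --device CPU --source ModelScope
# bash install.sh --device MPS --source HF
```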
@ -1,81 +1,38 @@
|
||||
js = """
|
||||
function createGradioAnimation() {
|
||||
function deleteTheme() {
|
||||
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
if (params.get('__theme') !== 'light') {
|
||||
params.set('__theme', 'light'); // 仅当 __theme 不是 'light' 时设置为 'light'
|
||||
window.location.search = params.toString(); // 更新 URL,触发页面刷新
|
||||
}
|
||||
|
||||
var container = document.createElement('div');
|
||||
container.id = 'gradio-animation';
|
||||
container.style.fontSize = '2em';
|
||||
container.style.fontWeight = '500';
|
||||
container.style.textAlign = 'center';
|
||||
container.style.marginBottom = '20px';
|
||||
container.style.fontFamily = '-apple-system, sans-serif, Arial, Calibri';
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
if (params.has('__theme')) {
|
||||
params.delete('__theme');
|
||||
const newUrl = `${window.location.pathname}?${params.toString()}`;
|
||||
window.location.replace(newUrl);
|
||||
}
|
||||
|
||||
var text = 'Welcome to GPT-SoVITS !';
|
||||
for (var i = 0; i < text.length; i++) {
|
||||
(function(i){
|
||||
setTimeout(function(){
|
||||
var letter = document.createElement('span');
|
||||
letter.style.opacity = '0';
|
||||
letter.style.transition = 'opacity 0.5s';
|
||||
letter.innerText = text[i];
|
||||
|
||||
container.appendChild(letter);
|
||||
|
||||
setTimeout(function() {
|
||||
letter.style.opacity = '1';
|
||||
}, 50);
|
||||
}, i * 250);
|
||||
})(i);
|
||||
}
|
||||
return 'Animation created';
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
css = """
|
||||
/* CSSStyleRule */
|
||||
|
||||
.markdown {
|
||||
background-color: lightblue;
|
||||
padding: 6px 10px;
|
||||
}
|
||||
|
||||
.checkbox_info {
|
||||
color: var(--block-title-text-color) !important;
|
||||
font-size: var(--block-title-text-size) !important;
|
||||
font-weight: var(--block-title-text-weight) !important;
|
||||
height: 22px;
|
||||
margin-bottom: 8px !important;
|
||||
@media (prefers-color-scheme: light) {
|
||||
.markdown {
|
||||
background-color: lightblue;
|
||||
color: #000;
|
||||
}
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
.markdown {
|
||||
background-color: #4b4b4b;
|
||||
color: rgb(244, 244, 245);
|
||||
}
|
||||
}
|
||||
|
||||
::selection {
|
||||
background: #ffc078; !important;
|
||||
}
|
||||
|
||||
#checkbox_train_dpo input[type="checkbox"]{
|
||||
margin-top: 6px;
|
||||
}
|
||||
|
||||
#checkbox_train_dpo span {
|
||||
margin-top: 6px;
|
||||
}
|
||||
|
||||
#checkbox_align_train {
|
||||
padding-top: 18px;
|
||||
padding-bottom: 18px;
|
||||
}
|
||||
|
||||
#checkbox_align_infer input[type="checkbox"] {
|
||||
margin-top: 10px;
|
||||
}
|
||||
|
||||
#checkbox_align_infer span {
|
||||
margin-top: 10px;
|
||||
background: #ffc078 !important;
|
||||
}
|
||||
|
||||
footer {
|
||||
@ -91,6 +48,7 @@ footer * {
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
top_html = """
|
||||
<div align="center">
|
||||
<div style="margin-bottom: 5px; font-size: 15px;">{}</div>
|
||||
|
@ -109,7 +109,7 @@ def check_details(path_list=None, is_train=False, is_dataset_processing=False):
|
||||
if os.path.exists(wav_path):
|
||||
...
|
||||
else:
|
||||
gr.Warning(wav_path+i18n("路径错误"))
|
||||
gr.Warning(wav_path + i18n("路径错误"))
|
||||
return
|
||||
if is_train:
|
||||
path_list.append(os.path.join(path_list[0], "2-name2text.txt"))
|
||||
|
@ -1,5 +1,6 @@
|
||||
import sys
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else "Auto"
|
||||
i18n = I18nAuto(language=language)
|
||||
import argparse
|
||||
@ -309,7 +310,9 @@ if __name__ == "__main__":
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as demo:
|
||||
gr.Markdown(
|
||||
value=i18n("Submit Text: 将当前页所有文本框内容手工保存到内存和文件(翻页前后或者退出标注页面前如果没点这个按钮,你再翻回来就回滚了,白忙活。)")
|
||||
value=i18n(
|
||||
"Submit Text: 将当前页所有文本框内容手工保存到内存和文件(翻页前后或者退出标注页面前如果没点这个按钮,你再翻回来就回滚了,白忙活。)"
|
||||
)
|
||||
)
|
||||
with gr.Row():
|
||||
btn_change_index = gr.Button("Change Index")
|
||||
|
@ -190,14 +190,14 @@ class Predictor:
|
||||
opt_path_vocal = path_vocal[:-4] + ".%s" % format
|
||||
opt_path_other = path_other[:-4] + ".%s" % format
|
||||
if os.path.exists(path_vocal):
|
||||
os.system("ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path_vocal, opt_path_vocal))
|
||||
os.system('ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path_vocal, opt_path_vocal))
|
||||
if os.path.exists(opt_path_vocal):
|
||||
try:
|
||||
os.remove(path_vocal)
|
||||
except:
|
||||
pass
|
||||
if os.path.exists(path_other):
|
||||
os.system("ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path_other, opt_path_other))
|
||||
os.system('ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path_other, opt_path_other))
|
||||
if os.path.exists(opt_path_other):
|
||||
try:
|
||||
os.remove(path_other)
|
||||
|
@ -140,7 +140,7 @@ class AudioPre:
|
||||
)
|
||||
if os.path.exists(path):
|
||||
opt_format_path = path[:-4] + ".%s" % format
|
||||
cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
|
||||
cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
|
||||
print(cmd)
|
||||
os.system(cmd)
|
||||
if os.path.exists(opt_format_path):
|
||||
@ -177,7 +177,7 @@ class AudioPre:
|
||||
)
|
||||
if os.path.exists(path):
|
||||
opt_format_path = path[:-4] + ".%s" % format
|
||||
cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
|
||||
cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
|
||||
print(cmd)
|
||||
os.system(cmd)
|
||||
if os.path.exists(opt_format_path):
|
||||
@ -307,7 +307,7 @@ class AudioPreDeEcho:
|
||||
)
|
||||
if os.path.exists(path):
|
||||
opt_format_path = path[:-4] + ".%s" % format
|
||||
cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
|
||||
cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
|
||||
print(cmd)
|
||||
os.system(cmd)
|
||||
if os.path.exists(opt_format_path):
|
||||
@ -340,7 +340,7 @@ class AudioPreDeEcho:
|
||||
)
|
||||
if os.path.exists(path):
|
||||
opt_format_path = path[:-4] + ".%s" % format
|
||||
cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
|
||||
cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
|
||||
print(cmd)
|
||||
os.system(cmd)
|
||||
if os.path.exists(opt_format_path):
|
||||
|
webui.py
@ -507,6 +507,7 @@ def open1Ba(
|
||||
):
|
||||
global p_train_SoVITS
|
||||
if p_train_SoVITS == None:
|
||||
exp_name = exp_name.rstrip(" ")
|
||||
config_file = (
|
||||
"GPT_SoVITS/configs/s2.json"
|
||||
if version not in {"v2Pro", "v2ProPlus"}
|
||||
@ -603,6 +604,7 @@ def open1Bb(
|
||||
):
|
||||
global p_train_GPT
|
||||
if p_train_GPT == None:
|
||||
exp_name = exp_name.rstrip(" ")
|
||||
with open(
|
||||
"GPT_SoVITS/configs/s1longer.yaml" if version == "v1" else "GPT_SoVITS/configs/s1longer-v2.yaml"
|
||||
) as f:
|
||||
@ -785,6 +787,7 @@ def open1a(inp_text, inp_wav_dir, exp_name, gpu_numbers, bert_pretrained_dir):
|
||||
inp_wav_dir = my_utils.clean_path(inp_wav_dir)
|
||||
if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
|
||||
check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
|
||||
exp_name = exp_name.rstrip(" ")
|
||||
if ps1a == []:
|
||||
opt_dir = "%s/%s" % (exp_root, exp_name)
|
||||
config = {
|
||||
@ -874,6 +877,7 @@ def open1b(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, ssl_pretrained
|
||||
inp_wav_dir = my_utils.clean_path(inp_wav_dir)
|
||||
if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
|
||||
check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
|
||||
exp_name = exp_name.rstrip(" ")
|
||||
if ps1b == []:
|
||||
config = {
|
||||
"inp_text": inp_text,
|
||||
@ -962,6 +966,7 @@ def open1c(version, inp_text, inp_wav_dir, exp_name, gpu_numbers, pretrained_s2G
|
||||
inp_text = my_utils.clean_path(inp_text)
|
||||
if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
|
||||
check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
|
||||
exp_name = exp_name.rstrip(" ")
|
||||
if ps1c == []:
|
||||
opt_dir = "%s/%s" % (exp_root, exp_name)
|
||||
config_file = (
|
||||
@ -1059,6 +1064,7 @@ def open1abc(
|
||||
inp_wav_dir = my_utils.clean_path(inp_wav_dir)
|
||||
if check_for_existance([inp_text, inp_wav_dir], is_dataset_processing=True):
|
||||
check_details([inp_text, inp_wav_dir], is_dataset_processing=True)
|
||||
exp_name = exp_name.rstrip(" ")
|
||||
if ps1abc == []:
|
||||
opt_dir = "%s/%s" % (exp_root, exp_name)
|
||||
try:
|
||||
|