This commit is contained in:
XXXXRT666 2025-06-11 23:18:57 +08:00
parent a8f366ac14
commit 1e59f757a2
27 changed files with 719 additions and 545 deletions

View File

@ -28,7 +28,8 @@ class Text2SemanticLightningModule(LightningModule):
self.load_state_dict(
torch.load(
pretrained_s1,
map_location="cpu", weights_only=False,
map_location="cpu",
weights_only=False,
)["weight"],
)
)

View File

@ -32,19 +32,21 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
from tools.audio_sr import AP_BWE
from tools.i18n.i18n import I18nAuto, scan_language_list
from tools.my_utils import load_audio
from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
from sv import SV
resample_transform_dict = {}
def resample(audio_tensor, sr0, sr1, device):
global resample_transform_dict
key = "%s-%s-%s" % (sr0, sr1, str(device))
if key not in resample_transform_dict:
resample_transform_dict[key] = torchaudio.transforms.Resample(
sr0, sr1
).to(device)
resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
return resample_transform_dict[key](audio_tensor)
language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
@ -111,6 +113,7 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int):
return processed_audio
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
@ -632,7 +635,9 @@ class TTS:
)
self.vocoder.remove_weight_norm()
state_dict_g = torch.load(
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,),
map_location="cpu",
weights_only=False,
)
print("loading vocoder", self.vocoder.load_state_dict(state_dict_g))
@ -752,11 +757,13 @@ class TTS:
if raw_sr != self.configs.sampling_rate:
audio = raw_audio.to(self.configs.device)
if (audio.shape[0] == 2): audio = audio.mean(0).unsqueeze(0)
if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0)
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
else:
audio = raw_audio.to(self.configs.device)
if (audio.shape[0] == 2): audio = audio.mean(0).unsqueeze(0)
if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0)
maxx = audio.abs().max()
if maxx > 1:
@ -775,7 +782,8 @@ class TTS:
audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
if self.configs.is_half:
audio = audio.half()
else:audio=None
else:
audio = None
return spec, audio
def _set_prompt_semantic(self, ref_wav_path: str):
@ -1073,7 +1081,10 @@ class TTS:
###### setting reference audio and prompt text preprocessing ########
t0 = time.perf_counter()
if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"] or (self.is_v2pro and self.prompt_cache["refer_spec"][0][1] is None)):
if (ref_audio_path is not None) and (
ref_audio_path != self.prompt_cache["ref_audio_path"]
or (self.is_v2pro and self.prompt_cache["refer_spec"][0][1] is None)
):
if not os.path.exists(ref_audio_path):
raise ValueError(f"{ref_audio_path} not exists")
self.set_ref_audio(ref_audio_path)
@ -1212,7 +1223,8 @@ class TTS:
t_34 += t4 - t3
refer_audio_spec = []
if self.is_v2pro:sv_emb=[]
if self.is_v2pro:
sv_emb = []
for spec, audio_tensor in self.prompt_cache["refer_spec"]:
spec = spec.to(dtype=self.precision, device=self.configs.device)
refer_audio_spec.append(spec)
@ -1250,9 +1262,13 @@ class TTS:
)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
if self.is_v2pro != True:
_batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
_batch_audio_fragment = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
else:
_batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor,sv_emb=sv_emb).detach()[0, 0, :]
_batch_audio_fragment = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
).detach()[0, 0, :]
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment = [
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
@ -1266,9 +1282,13 @@ class TTS:
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
if self.is_v2pro != True:
audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
audio_fragment = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
else:
audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor,sv_emb=sv_emb).detach()[0, 0, :]
audio_fragment = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
).detach()[0, 0, :]
batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
else:
if parallel_infer:

View File

@ -160,7 +160,9 @@ class TextPreprocessor:
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
if (tmp["lang"] == "en" and langlist[-1] == "en") or (
tmp["lang"] != "en" and langlist[-1] != "en"
):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":

View File

@ -8,7 +8,6 @@
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
"""
import torch
import math
import torch.nn as nn
@ -16,15 +15,14 @@ import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF
class ReLU(nn.Hardtanh):
class ReLU(nn.Hardtanh):
def __init__(self, inplace=False):
super(ReLU, self).__init__(0, 20, inplace)
def __repr__(self):
inplace_str = 'inplace' if self.inplace else ''
return self.__class__.__name__ + ' (' \
+ inplace_str + ')'
inplace_str = "inplace" if self.inplace else ""
return self.__class__.__name__ + " (" + inplace_str + ")"
class BasicBlockERes2Net(nn.Module):
@ -51,9 +49,9 @@ class BasicBlockERes2Net(nn.Module):
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes))
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes),
)
self.stride = stride
self.width = width
self.scale = scale
@ -86,6 +84,7 @@ class BasicBlockERes2Net(nn.Module):
return out
class BasicBlockERes2Net_diff_AFF(nn.Module):
expansion = 2
@ -115,9 +114,9 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes))
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes),
)
self.stride = stride
self.width = width
self.scale = scale
@ -151,16 +150,19 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
return out
class ERes2Net(nn.Module):
def __init__(self,
def __init__(
self,
block=BasicBlockERes2Net,
block_fuse=BasicBlockERes2Net_diff_AFF,
num_blocks=[3, 4, 6, 3],
m_channels=32,
feat_dim=80,
embedding_size=192,
pooling_func='TSTP',
two_emb_layer=False):
pooling_func="TSTP",
two_emb_layer=False,
):
super(ERes2Net, self).__init__()
self.in_planes = m_channels
self.feat_dim = feat_dim
@ -176,20 +178,24 @@ class ERes2Net(nn.Module):
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
# Downsampling module for each layer
self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False)
self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
self.layer1_downsample = nn.Conv2d(
m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False
)
self.layer2_downsample = nn.Conv2d(
m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
)
self.layer3_downsample = nn.Conv2d(
m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
)
# Bottom-up fusion module
self.fuse_mode12 = AFF(channels=m_channels * 4)
self.fuse_mode123 = AFF(channels=m_channels * 8)
self.fuse_mode1234 = AFF(channels=m_channels * 16)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
embedding_size)
self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
self.seg_2 = nn.Linear(embedding_size, embedding_size)
@ -247,14 +253,12 @@ class ERes2Net(nn.Module):
return fuse_out1234
if __name__ == '__main__':
if __name__ == "__main__":
x = torch.zeros(10, 300, 80)
model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func='TSTP')
model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func="TSTP")
model.eval()
out = model(x)
print(out.shape) # torch.Size([10, 192])
num_params = sum(param.numel() for param in model.parameters())
print("{} M".format(num_params / 1e6)) # 6.61M

View File

@ -8,8 +8,6 @@
both the model parameters and its computational cost.
"""
import torch
import math
import torch.nn as nn
@ -17,19 +15,17 @@ import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF
class ReLU(nn.Hardtanh):
class ReLU(nn.Hardtanh):
def __init__(self, inplace=False):
super(ReLU, self).__init__(0, 20, inplace)
def __repr__(self):
inplace_str = 'inplace' if self.inplace else ''
return self.__class__.__name__ + ' (' \
+ inplace_str + ')'
inplace_str = "inplace" if self.inplace else ""
return self.__class__.__name__ + " (" + inplace_str + ")"
class BasicBlockERes2NetV2(nn.Module):
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
super(BasicBlockERes2NetV2, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
@ -52,12 +48,9 @@ class BasicBlockERes2NetV2(nn.Module):
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(self.expansion * planes))
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes),
)
self.stride = stride
self.width = width
self.scale = scale
@ -90,8 +83,8 @@ class BasicBlockERes2NetV2(nn.Module):
return out
class BasicBlockERes2NetV2AFF(nn.Module):
class BasicBlockERes2NetV2AFF(nn.Module):
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
super(BasicBlockERes2NetV2AFF, self).__init__()
width = int(math.floor(planes * (baseWidth / 64.0)))
@ -119,12 +112,9 @@ class BasicBlockERes2NetV2AFF(nn.Module):
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(self.expansion * planes))
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes),
)
self.stride = stride
self.width = width
self.scale = scale
@ -158,8 +148,10 @@ class BasicBlockERes2NetV2AFF(nn.Module):
return out
class ERes2NetV2(nn.Module):
def __init__(self,
def __init__(
self,
block=BasicBlockERes2NetV2,
block_fuse=BasicBlockERes2NetV2AFF,
num_blocks=[3, 4, 6, 3],
@ -169,8 +161,9 @@ class ERes2NetV2(nn.Module):
baseWidth=26,
scale=2,
expansion=2,
pooling_func='TSTP',
two_emb_layer=False):
pooling_func="TSTP",
two_emb_layer=False,
):
super(ERes2NetV2, self).__init__()
self.in_planes = m_channels
self.feat_dim = feat_dim
@ -181,42 +174,29 @@ class ERes2NetV2(nn.Module):
self.scale = scale
self.expansion = expansion
self.conv1 = nn.Conv2d(1,
m_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(block,
m_channels,
num_blocks[0],
stride=1)
self.layer2 = self._make_layer(block,
m_channels * 2,
num_blocks[1],
stride=2)
self.layer3 = self._make_layer(block_fuse,
m_channels * 4,
num_blocks[2],
stride=2)
self.layer4 = self._make_layer(block_fuse,
m_channels * 8,
num_blocks[3],
stride=2)
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
# Downsampling module
self.layer3_ds = nn.Conv2d(m_channels * 4 * self.expansion, m_channels * 8 * self.expansion, kernel_size=3, \
padding=1, stride=2, bias=False)
self.layer3_ds = nn.Conv2d(
m_channels * 4 * self.expansion,
m_channels * 8 * self.expansion,
kernel_size=3,
padding=1,
stride=2,
bias=False,
)
# Bottom-up fusion module
self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=self.stats_dim * self.expansion)
self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
embedding_size)
self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * self.expansion)
self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats, embedding_size)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
self.seg_2 = nn.Linear(embedding_size, embedding_size)
@ -228,7 +208,11 @@ class ERes2NetV2(nn.Module):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion))
layers.append(
block(
self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion
)
)
self.in_planes = planes * self.expansion
return nn.Sequential(*layers)
@ -276,8 +260,8 @@ class ERes2NetV2(nn.Module):
# else:
# return embed_a
if __name__ == '__main__':
if __name__ == "__main__":
x = torch.randn(1, 300, 80)
model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
model.eval()
@ -286,7 +270,3 @@ if __name__ == '__main__':
macs, num_params = profile(model, inputs=(x,))
print("Params: {} M".format(num_params / 1e6)) # 17.86 M
print("MACs: {} G".format(macs / 1e9)) # 12.69 G

View File

@ -8,7 +8,6 @@
ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
"""
import pdb
import torch
import math
@ -17,15 +16,14 @@ import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF
class ReLU(nn.Hardtanh):
class ReLU(nn.Hardtanh):
def __init__(self, inplace=False):
super(ReLU, self).__init__(0, 20, inplace)
def __repr__(self):
inplace_str = 'inplace' if self.inplace else ''
return self.__class__.__name__ + ' (' \
+ inplace_str + ')'
inplace_str = "inplace" if self.inplace else ""
return self.__class__.__name__ + " (" + inplace_str + ")"
class BasicBlockERes2Net(nn.Module):
@ -53,7 +51,8 @@ class BasicBlockERes2Net(nn.Module):
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes))
nn.BatchNorm2d(self.expansion * planes),
)
self.stride = stride
self.width = width
self.scale = scale
@ -86,6 +85,7 @@ class BasicBlockERes2Net(nn.Module):
return out
class BasicBlockERes2Net_diff_AFF(nn.Module):
expansion = 4
@ -116,7 +116,8 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes))
nn.BatchNorm2d(self.expansion * planes),
)
self.stride = stride
self.width = width
self.scale = scale
@ -141,7 +142,6 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
else:
out = torch.cat((out, sp), 1)
out = self.conv3(out)
out = self.bn3(out)
@ -151,16 +151,19 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
return out
class ERes2Net(nn.Module):
def __init__(self,
def __init__(
self,
block=BasicBlockERes2Net,
block_fuse=BasicBlockERes2Net_diff_AFF,
num_blocks=[3, 4, 6, 3],
m_channels=64,
feat_dim=80,
embedding_size=192,
pooling_func='TSTP',
two_emb_layer=False):
pooling_func="TSTP",
two_emb_layer=False,
):
super(ERes2Net, self).__init__()
self.in_planes = m_channels
self.feat_dim = feat_dim
@ -176,17 +179,22 @@ class ERes2Net(nn.Module):
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False)
self.layer1_downsample = nn.Conv2d(
m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
)
self.layer2_downsample = nn.Conv2d(
m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
)
self.layer3_downsample = nn.Conv2d(
m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False
)
self.fuse_mode12 = AFF(channels=m_channels * 8)
self.fuse_mode123 = AFF(channels=m_channels * 16)
self.fuse_mode1234 = AFF(channels=m_channels * 32)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=self.stats_dim * block.expansion)
self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
@ -244,14 +252,13 @@ class ERes2Net(nn.Module):
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2) # bs,20480,T
if(if_mean==False):
if if_mean == False:
mean = fuse_out1234[0].transpose(1, 0) # (T,20480),bs=T
else:
mean = fuse_out1234.mean(2) # bs,20480
mean_std = torch.cat([mean, torch.zeros_like(mean)], 1)
return self.seg_1(mean_std) # (T,192)
# stats = self.pool(fuse_out1234)
# if self.two_emb_layer:
# out = F.relu(embed_a)
@ -280,7 +287,3 @@ class ERes2Net(nn.Module):
# print(fuse_out1234.shape)
# print(fuse_out1234.flatten(start_dim=1,end_dim=2).shape)
# pdb.set_trace()

View File

@ -6,7 +6,6 @@ import torch.nn as nn
class AFF(nn.Module):
def __init__(self, channels=64, r=4):
super(AFF, self).__init__()
inter_channels = int(channels // r)
@ -26,4 +25,3 @@ class AFF(nn.Module):
xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)
return xo

View File

@ -144,7 +144,7 @@ def _get_waveform_and_window_properties(
)
assert 0 < window_shift, "`window_shift` must be greater than 0"
assert padded_window_size % 2 == 0, (
"the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
"the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`"
)
assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
@ -441,7 +441,9 @@ def get_mel_banks(
high_freq: float,
vtln_low: float,
vtln_high: float,
vtln_warp_factor: float,device=None,dtype=None
vtln_warp_factor: float,
device=None,
dtype=None,
) -> Tuple[Tensor, Tensor]:
"""
Returns:
@ -457,9 +459,9 @@ def get_mel_banks(
if high_freq <= 0.0:
high_freq += nyquist
assert (
(0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), (
"Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
)
# fft-bin width [think of it as Nyquist-freq / half-window-length]
fft_bin_width = sample_freq / window_length_padded
@ -475,7 +477,7 @@ def get_mel_banks(
assert vtln_warp_factor == 1.0 or (
(low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
), "Bad values in options: vtln-low {} and vtln-high {}, versus low-freq {} and high-freq {}".format(
vtln_low, vtln_high, low_freq, high_freq
)
@ -510,7 +512,10 @@ def get_mel_banks(
return bins.to(device=device, dtype=dtype) # , center_freqs
cache = {}
def fbank(
waveform: Tensor,
blackman_coeff: float = 0.42,
@ -620,10 +625,30 @@ def fbank(
# size (num_mel_bins, padded_window_size // 2)
# print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
cache_key="%s-%s-%s-%s-%s-%s-%s-%s-%s-%s"%(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp,device,dtype)
cache_key = "%s-%s-%s-%s-%s-%s-%s-%s-%s-%s" % (
num_mel_bins,
padded_window_size,
sample_frequency,
low_freq,
high_freq,
vtln_low,
vtln_high,
vtln_warp,
device,
dtype,
)
if cache_key not in cache:
mel_energies = get_mel_banks(
num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp,device,dtype
num_mel_bins,
padded_window_size,
sample_frequency,
low_freq,
high_freq,
vtln_low,
vtln_high,
vtln_warp,
device,
dtype,
)
cache[cache_key] = mel_energies
else:

View File

@ -11,6 +11,7 @@ class TAP(nn.Module):
"""
Temporal average pooling, only first-order mean is considered
"""
def __init__(self, **kwargs):
super(TAP, self).__init__()
@ -25,6 +26,7 @@ class TSDP(nn.Module):
"""
Temporal standard deviation pooling, only second-order std is considered
"""
def __init__(self, **kwargs):
super(TSDP, self).__init__()
@ -41,6 +43,7 @@ class TSTP(nn.Module):
x-vector
Comment: simple concatenation can not make full use of both statistics
"""
def __init__(self, **kwargs):
super(TSTP, self).__init__()
@ -59,6 +62,7 @@ class ASTP(nn.Module):
"""Attentive statistics pooling: Channel- and context-dependent
statistics pooling, first used in ECAPA_TDNN.
"""
def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
super(ASTP, self).__init__()
self.global_context_att = global_context_att
@ -66,15 +70,10 @@ class ASTP(nn.Module):
# Use Conv1d with stride == 1 rather than Linear, then we don't
# need to transpose inputs.
if global_context_att:
self.linear1 = nn.Conv1d(
in_dim * 3, bottleneck_dim,
kernel_size=1) # equals W and b in the paper
self.linear1 = nn.Conv1d(in_dim * 3, bottleneck_dim, kernel_size=1) # equals W and b in the paper
else:
self.linear1 = nn.Conv1d(
in_dim, bottleneck_dim,
kernel_size=1) # equals W and b in the paper
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
kernel_size=1) # equals V and k in the paper
self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1) # equals W and b in the paper
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1) # equals V and k in the paper
def forward(self, x):
"""
@ -88,15 +87,13 @@ class ASTP(nn.Module):
if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt(
torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
x_in = torch.cat((x, context_mean, context_std), dim=1)
else:
x_in = x
# DON'T use ReLU here! ReLU may be hard to converge.
alpha = torch.tanh(
self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
alpha = torch.tanh(self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2)
var = torch.sum(alpha * (x**2), dim=2) - mean**2

View File

@ -435,6 +435,7 @@ class GPTSoVITSV3(torch.nn.Module):
wav_gen = torch.cat(wav_gen_list, 2)
return wav_gen[0][0][:wav_gen_length]
class GPTSoVITSV4(torch.nn.Module):
def __init__(self, gpt_sovits_half, cfm, hifigan):
super().__init__()
@ -577,6 +578,7 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
v3v4set = {"v3", "v4"}
def get_sovits_weights(sovits_path):
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
@ -707,7 +709,6 @@ def export_1(ref_wav_path,ref_wav_text,version="v3"):
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
init_hifigan()
dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
print("#### get_raw_t2s_model ####")
@ -751,9 +752,7 @@ def export_1(ref_wav_path,ref_wav_text,version="v3"):
# phones1, bert1, norm_text1 = get_phones_and_bert(
# "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
# )
phones1, bert1, norm_text1 = get_phones_and_bert(
ref_wav_text, "auto", "v3"
)
phones1, bert1, norm_text1 = get_phones_and_bert(ref_wav_text, "auto", "v3")
phones2, bert2, norm_text2 = get_phones_and_bert(
"这是一个简单的示例真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
"auto",
@ -1201,7 +1200,6 @@ def export_2(version="v3"):
gpt_sovits_v3v4 = gpt_sovits_v3 if version == "v3" else gpt_sovits_v4
sr = 24000 if version == "v3" else 48000
time.sleep(5)
# print("thread:", torch.get_num_threads())
# print("thread:", torch.get_num_interop_threads())
@ -1212,14 +1210,14 @@ def export_2(version="v3"):
"汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
gpt_sovits_v3v4,
"out.wav",
sr
sr,
)
test_export(
"你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
gpt_sovits_v3v4,
"out2.wav",
sr
sr,
)
# test_export(

View File

@ -252,9 +252,28 @@ class TextAudioSpeakerCollate:
if self.is_v2Pro:
sv_embs[i] = row[4]
if self.is_v2Pro:
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths,sv_embs
return (
ssl_padded,
ssl_lengths,
spec_padded,
spec_lengths,
wav_padded,
wav_lengths,
text_padded,
text_lengths,
sv_embs,
)
else:
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
return (
ssl_padded,
ssl_lengths,
spec_padded,
spec_lengths,
wav_padded,
wav_lengths,
text_padded,
text_lengths,
)
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):

View File

@ -586,12 +586,17 @@ class DiscriminatorS(torch.nn.Module):
return x, fmap
v2pro_set = {"v2Pro", "v2ProPlus"}
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False, version=None):
super(MultiPeriodDiscriminator, self).__init__()
if version in v2pro_set:periods = [2, 3, 5, 7, 11,17,23]
else:periods = [2, 3, 5, 7, 11]
if version in v2pro_set:
periods = [2, 3, 5, 7, 11, 17, 23]
else:
periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
@ -787,6 +792,7 @@ class CodePredictor(nn.Module):
return pred_codes.transpose(0, 1)
class SynthesizerTrn(nn.Module):
"""
Synthesizer for Training
@ -983,7 +989,14 @@ class SynthesizerTrn(nn.Module):
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2,1)).transpose(2,1)if self.is_v2pro else ge, speed)
x, m_p, logs_p, y_mask = self.enc_p(
quantized,
y_lengths,
text,
text_lengths,
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
speed,
)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)
@ -996,6 +1009,7 @@ class SynthesizerTrn(nn.Module):
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0, 1)
class CFM(torch.nn.Module):
def __init__(self, in_channels, dit):
super().__init__()
@ -1029,7 +1043,18 @@ class CFM(torch.nn.Module):
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
v_pred, text_emb, dt = self.estimator(
x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False, infer=True, text_cache=text_cache, dt_cache=dt_cache
x,
prompt_x,
x_lens,
t_tensor,
d_tensor,
mu,
use_grad_ckpt=False,
drop_audio_cond=False,
drop_text=False,
infer=True,
text_cache=text_cache,
dt_cache=dt_cache,
)
v_pred = v_pred.transpose(2, 1)
if self.use_conditioner_cache:
@ -1048,7 +1073,7 @@ class CFM(torch.nn.Module):
drop_text=True,
infer=True,
text_cache=text_cfg_cache,
dt_cache=dt_cache
dt_cache=dt_cache,
)
neg = neg.transpose(2, 1)
if self.use_conditioner_cache:

View File

@ -1,5 +1,4 @@
import math
import pdb
import numpy as np
import torch

View File

@ -10,7 +10,6 @@ i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
from feature_extractor import cnhubert
opt_dir = os.environ.get("opt_dir")
sv_path = os.environ.get("sv_path")
@ -19,19 +18,18 @@ import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
import traceback
import numpy as np
from scipy.io import wavfile
import torchaudio
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append(f"{now_dir}/GPT_SoVITS/eres2net")
from tools.my_utils import load_audio, clean_path
from tools.my_utils import clean_path
from time import time as ttime
import shutil
from ERes2NetV2 import ERes2NetV2
import kaldi as Kaldi
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
@ -56,9 +54,10 @@ if torch.cuda.is_available():
else:
device = "cpu"
class SV:
def __init__(self, device, is_half):
pretrained_state = torch.load(sv_path, map_location='cpu')
pretrained_state = torch.load(sv_path, map_location="cpu")
embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
embedding_model.load_state_dict(pretrained_state)
embedding_model.eval()
@ -73,15 +72,22 @@ class SV:
def compute_embedding3(self, wav): # (1,x)#-1~1
with torch.no_grad():
wav = self.res(wav)
if self.is_half==True:wav=wav.half()
feat = torch.stack([Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav])
if self.is_half == True:
wav = wav.half()
feat = torch.stack(
[Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]
)
sv_emb = self.embedding_model.forward3(feat)
return sv_emb
sv = SV(device, is_half)
def name2go(wav_name, wav_path):
sv_cn_path = "%s/%s.pt" % (sv_cn_dir, wav_name)
if os.path.exists(sv_cn_path):return
if os.path.exists(sv_cn_path):
return
wav_path = "%s/%s" % (wav32dir, wav_name)
wav32k, sr0 = torchaudio.load(wav_path)
assert sr0 == 32000

View File

@ -17,7 +17,6 @@ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
shutil.move(tmp_path, "%s/%s" % (dir, name))
from io import BytesIO
model_version2byte = {
@ -26,6 +25,8 @@ model_version2byte={
"v2Pro": b"05",
"v2ProPlus": b"06",
}
def my_save2(fea, path, model_version):
bio = BytesIO()
torch.save(fea, bio)
@ -50,7 +51,7 @@ def savee(ckpt, name, epoch, steps, hps, model_version=None, lora_rank=None):
if lora_rank:
opt["lora_rank"] = lora_rank
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
elif (model_version!=None and "Pro"in model_version):
elif model_version != None and "Pro" in model_version:
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
else:
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
@ -58,6 +59,7 @@ def savee(ckpt, name, epoch, steps, hps, model_version=None, lora_rank=None):
except:
return traceback.format_exc()
"""
00:v1
01:v2

View File

@ -36,7 +36,7 @@ from module.models import (
MultiPeriodDiscriminator,
SynthesizerTrn,
)
from process_ckpt import savee,my_save2
from process_ckpt import savee
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
@ -91,7 +91,26 @@ def run(rank, n_gpus, hps):
train_sampler = DistributedBucketSampler(
train_dataset,
hps.train.batch_size,
[32,300,400,500,600,700,800,900,1000,1100,1200,1300,1400,1500,1600,1700,1800,1900,],
[
32,
300,
400,
500,
600,
700,
800,
900,
1000,
1100,
1200,
1300,
1400,
1500,
1600,
1700,
1800,
1900,
],
num_replicas=n_gpus,
rank=rank,
shuffle=True,
@ -315,12 +334,39 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
else:
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data
if torch.cuda.is_available():
spec, spec_lengths = (spec.cuda(rank,non_blocking=True,),spec_lengths.cuda(rank,non_blocking=True,),)
y, y_lengths = (y.cuda(rank,non_blocking=True,),y_lengths.cuda(rank,non_blocking=True,),)
spec, spec_lengths = (
spec.cuda(
rank,
non_blocking=True,
),
spec_lengths.cuda(
rank,
non_blocking=True,
),
)
y, y_lengths = (
y.cuda(
rank,
non_blocking=True,
),
y_lengths.cuda(
rank,
non_blocking=True,
),
)
ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = (text.cuda(rank,non_blocking=True,),text_lengths.cuda(rank,non_blocking=True,),)
text, text_lengths = (
text.cuda(
rank,
non_blocking=True,
),
text_lengths.cuda(
rank,
non_blocking=True,
),
)
if hps.model.version in {"v2Pro", "v2ProPlus"}:
sv_emb = sv_emb.cuda(rank, non_blocking=True)
else:
@ -334,9 +380,19 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
sv_emb = sv_emb.to(device)
with autocast(enabled=hps.train.fp16_run):
if hps.model.version in {"v2Pro", "v2ProPlus"}:
(y_hat,kl_ssl,ids_slice,x_mask,z_mask,(z, z_p, m_p, logs_p, m_q, logs_q),stats_ssl) = net_g(ssl, spec, spec_lengths, text, text_lengths,sv_emb)
(y_hat, kl_ssl, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl) = net_g(
ssl, spec, spec_lengths, text, text_lengths, sv_emb
)
else:
(y_hat,kl_ssl,ids_slice,x_mask,z_mask,(z, z_p, m_p, logs_p, m_q, logs_q),stats_ssl,) = net_g(ssl, spec, spec_lengths, text, text_lengths)
(
y_hat,
kl_ssl,
ids_slice,
x_mask,
z_mask,
(z, z_p, m_p, logs_p, m_q, logs_q),
stats_ssl,
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
mel = spec_to_mel_torch(
spec,
@ -508,7 +564,14 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
% (
hps.name,
epoch,
savee(ckpt,hps.name + "_e%s_s%s" % (epoch, global_step),epoch,global_step,hps,model_version=None if hps.model.version not in {"v2Pro","v2ProPlus"}else hps.model.version),
savee(
ckpt,
hps.name + "_e%s_s%s" % (epoch, global_step),
epoch,
global_step,
hps,
model_version=None if hps.model.version not in {"v2Pro", "v2ProPlus"} else hps.model.version,
),
)
)

View File

@ -1,11 +1,16 @@
import sys,os,torch
import sys
import os
import torch
sys.path.append(f"{os.getcwd()}/GPT_SoVITS/eres2net")
sv_path = "GPT_SoVITS/pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
from ERes2NetV2 import ERes2NetV2
import kaldi as Kaldi
class SV:
def __init__(self, device, is_half):
pretrained_state = torch.load(sv_path, map_location='cpu', weights_only=False)
pretrained_state = torch.load(sv_path, map_location="cpu", weights_only=False)
embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
embedding_model.load_state_dict(pretrained_state)
embedding_model.eval()
@ -18,7 +23,10 @@ class SV:
def compute_embedding3(self, wav):
with torch.no_grad():
if self.is_half==True:wav=wav.half()
feat = torch.stack([Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav])
if self.is_half == True:
wav = wav.half()
feat = torch.stack(
[Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]
)
sv_emb = self.embedding_model.forward3(feat)
return sv_emb

View File

@ -3,19 +3,25 @@ import re
# jieba静音
import jieba
jieba.setLogLevel(logging.CRITICAL)
# 更改fast_langdetect大模型位置
from pathlib import Path
import fast_langdetect
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(
fast_langdetect.infer.LangDetectConfig(
cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
)
)
from split_lang import LangSplitter
def full_en(text):
pattern = r'^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$'
pattern = r"^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$"
return bool(re.match(pattern, text))
@ -34,7 +40,7 @@ def full_cjk(text):
(0x2EBF0, 0x2EE5D), # CJK Extension H
]
pattern = r'[0-9、-〜。!?.!?… /]+$'
pattern = r"[0-9、-〜。!?.!?… /]+$"
cjk_text = ""
for char in text:
@ -53,28 +59,28 @@ def split_jako(tag_lang,item):
lang_list: list[dict] = []
tag = 0
for match in re.finditer(pattern, item['text']):
for match in re.finditer(pattern, item["text"]):
if match.start() > tag:
lang_list.append({'lang':item['lang'],'text':item['text'][tag:match.start()]})
lang_list.append({"lang": item["lang"], "text": item["text"][tag : match.start()]})
tag = match.end()
lang_list.append({'lang':tag_lang,'text':item['text'][match.start():match.end()]})
lang_list.append({"lang": tag_lang, "text": item["text"][match.start() : match.end()]})
if tag < len(item['text']):
lang_list.append({'lang':item['lang'],'text':item['text'][tag:len(item['text'])]})
if tag < len(item["text"]):
lang_list.append({"lang": item["lang"], "text": item["text"][tag : len(item["text"])]})
return lang_list
def merge_lang(lang_list, item):
if lang_list and item['lang'] == lang_list[-1]['lang']:
lang_list[-1]['text'] += item['text']
if lang_list and item["lang"] == lang_list[-1]["lang"]:
lang_list[-1]["text"] += item["text"]
else:
lang_list.append(item)
return lang_list
class LangSegmenter():
class LangSegmenter:
# 默认过滤器, 基于gsv目前四种语言
DEFAULT_LANG_MAP = {
"zh": "zh",
@ -87,7 +93,6 @@ class LangSegmenter():
"en": "en",
}
def getTexts(text):
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
substr = lang_splitter.split_by_lang(text=text)
@ -95,18 +100,18 @@ class LangSegmenter():
lang_list: list[dict] = []
for _, item in enumerate(substr):
dict_item = {'lang':item.lang,'text':item.text}
dict_item = {"lang": item.lang, "text": item.text}
# 处理短英文被识别为其他语言的问题
if full_en(dict_item['text']):
dict_item['lang'] = 'en'
if full_en(dict_item["text"]):
dict_item["lang"] = "en"
lang_list = merge_lang(lang_list, dict_item)
continue
# 处理非日语夹日文的问题(不包含CJK)
ja_list: list[dict] = []
if dict_item['lang'] != 'ja':
ja_list = split_jako('ja',dict_item)
if dict_item["lang"] != "ja":
ja_list = split_jako("ja", dict_item)
if not ja_list:
ja_list.append(dict_item)
@ -115,8 +120,8 @@ class LangSegmenter():
ko_list: list[dict] = []
temp_list: list[dict] = []
for _, ko_item in enumerate(ja_list):
if ko_item["lang"] != 'ko':
ko_list = split_jako('ko',ko_item)
if ko_item["lang"] != "ko":
ko_list = split_jako("ko", ko_item)
if ko_list:
temp_list.extend(ko_list)
@ -126,10 +131,10 @@ class LangSegmenter():
# 未存在非日韩文夹日韩文
if len(temp_list) == 1:
# 未知语言检查是否为CJK
if dict_item['lang'] == 'x':
cjk_text = full_cjk(dict_item['text'])
if dict_item["lang"] == "x":
cjk_text = full_cjk(dict_item["text"])
if cjk_text:
dict_item = {'lang':'zh','text':cjk_text}
dict_item = {"lang": "zh", "text": cjk_text}
lang_list = merge_lang(lang_list, dict_item)
else:
lang_list = merge_lang(lang_list, dict_item)
@ -141,10 +146,10 @@ class LangSegmenter():
# 存在非日韩文夹日韩文
for _, temp_item in enumerate(temp_list):
# 未知语言检查是否为CJK
if temp_item['lang'] == 'x':
cjk_text = full_cjk(dict_item['text'])
if temp_item["lang"] == "x":
cjk_text = full_cjk(dict_item["text"])
if cjk_text:
dict_item = {'lang':'zh','text':cjk_text}
dict_item = {"lang": "zh", "text": cjk_text}
lang_list = merge_lang(lang_list, dict_item)
else:
lang_list = merge_lang(lang_list, dict_item)
@ -154,13 +159,13 @@ class LangSegmenter():
temp_list = lang_list
lang_list = []
for _, temp_item in enumerate(temp_list):
if temp_item['lang'] == 'x':
if temp_item["lang"] == "x":
if lang_list:
temp_item['lang'] = lang_list[-1]['lang']
temp_item["lang"] = lang_list[-1]["lang"]
elif len(temp_list) > 1:
temp_item['lang'] = temp_list[1]['lang']
temp_item["lang"] = temp_list[1]["lang"]
else:
temp_item['lang'] = 'zh'
temp_item["lang"] = "zh"
lang_list = merge_lang(lang_list, temp_item)

View File

@ -3,7 +3,6 @@
import json
import os
import traceback
import warnings
import zipfile
from typing import Any, Dict, List, Tuple

View File

@ -655,11 +655,7 @@ class ToneSandhi:
while i < len(seg):
word, pos = seg[i]
merged = False
if (
i - 1 >= 0
and word == ""
and i + 1 < len(seg)
):
if i - 1 >= 0 and word == "" and i + 1 < len(seg):
last = new_seg[-1] if new_seg else seg[i - 1]
if last[0] == seg[i + 1][0] and last[1] == "v" and seg[i + 1][1] == "v":
combined = last[0] + "" + seg[i + 1][0]

38
api.py
View File

@ -199,6 +199,8 @@ def is_full(*items): # 任意一项为空返回False
bigvgan_model = hifigan_model = sv_cn_model = None
def clean_hifigan_model():
global hifigan_model
if hifigan_model:
@ -208,6 +210,8 @@ def clean_hifigan_model():
torch.cuda.empty_cache()
except:
pass
def clean_bigvgan_model():
global bigvgan_model
if bigvgan_model:
@ -217,6 +221,8 @@ def clean_bigvgan_model():
torch.cuda.empty_cache()
except:
pass
def clean_sv_cn_model():
global sv_cn_model
if sv_cn_model:
@ -262,7 +268,9 @@ def init_hifigan():
hifigan_model.eval()
hifigan_model.remove_weight_norm()
state_dict_g = torch.load(
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,),
map_location="cpu",
weights_only=False,
)
print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
if is_half == True:
@ -272,19 +280,21 @@ def init_hifigan():
from sv import SV
def init_sv_cn():
global hifigan_model, bigvgan_model, sv_cn_model
sv_cn_model = SV(device, is_half)
resample_transform_dict = {}
def resample(audio_tensor, sr0, sr1, device):
global resample_transform_dict
key = "%s-%s-%s" % (sr0, sr1, str(device))
if key not in resample_transform_dict:
resample_transform_dict[key] = torchaudio.transforms.Resample(
sr0, sr1
).to(device)
resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
return resample_transform_dict[key](audio_tensor)
@ -370,6 +380,7 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
def get_sovits_weights(sovits_path):
from config import pretrained_sovits_name
path_sovits_v3 = pretrained_sovits_name["v3"]
path_sovits_v4 = pretrained_sovits_name["v4"]
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
@ -632,11 +643,13 @@ def get_spepc(hps, filename, dtype, device, is_v2pro=False):
audio, sr0 = torchaudio.load(filename)
if sr0 != sr1:
audio = audio.to(device)
if(audio.shape[0]==2):audio=audio.mean(0).unsqueeze(0)
if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0)
audio = resample(audio, sr0, sr1, device)
else:
audio = audio.to(device)
if(audio.shape[0]==2):audio=audio.mean(0).unsqueeze(0)
if audio.shape[0] == 2:
audio = audio.mean(0).unsqueeze(0)
maxx = audio.abs().max()
if maxx > 1:
@ -937,14 +950,22 @@ def get_tts_wav(
if version not in {"v3", "v4"}:
if is_v2pro:
audio = (
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed,sv_emb=sv_emb)
vq_model.decode(
pred_semantic,
torch.LongTensor(phones2).to(device).unsqueeze(0),
refers,
speed=speed,
sv_emb=sv_emb,
)
.detach()
.cpu()
.numpy()[0, 0]
)
else:
audio = (
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
vq_model.decode(
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
)
.detach()
.cpu()
.numpy()[0, 0]
@ -1108,7 +1129,6 @@ def handle(
if not default_refer.is_ready():
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
if cut_punc == None:
text = cut_text(text, default_cut_punc)
else:

View File

@ -144,6 +144,7 @@ webui_port_subfix = 9871
api_port = 9880
# Thanks to the contribution of @Karasukaigan and @XXXXRT666
def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
cpu = torch.device("cpu")
@ -158,9 +159,12 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, flo
major, minor = capability
sm_version = major + minor / 10.0
is_16_series = bool(re.search(r"16\d{2}", name)) and sm_version == 7.5
if mem_gb < 4 or sm_version < 5.3:return cpu, torch.float32, 0.0, 0.0
if sm_version == 6.1 or is_16_series==True:return cuda, torch.float32, sm_version, mem_gb
if sm_version > 6.1:return cuda, torch.float16, sm_version, mem_gb
if mem_gb < 4 or sm_version < 5.3:
return cpu, torch.float32, 0.0, 0.0
if sm_version == 6.1 or is_16_series == True:
return cuda, torch.float32, sm_version, mem_gb
if sm_version > 6.1:
return cuda, torch.float16, sm_version, mem_gb
return cpu, torch.float32, 0.0, 0.0

View File

@ -190,14 +190,14 @@ class Predictor:
opt_path_vocal = path_vocal[:-4] + ".%s" % format
opt_path_other = path_other[:-4] + ".%s" % format
if os.path.exists(path_vocal):
os.system("ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path_vocal, opt_path_vocal))
os.system('ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path_vocal, opt_path_vocal))
if os.path.exists(opt_path_vocal):
try:
os.remove(path_vocal)
except:
pass
if os.path.exists(path_other):
os.system("ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path_other, opt_path_other))
os.system('ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path_other, opt_path_other))
if os.path.exists(opt_path_other):
try:
os.remove(path_other)

View File

@ -140,7 +140,7 @@ class AudioPre:
)
if os.path.exists(path):
opt_format_path = path[:-4] + ".%s" % format
cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
print(cmd)
os.system(cmd)
if os.path.exists(opt_format_path):
@ -177,7 +177,7 @@ class AudioPre:
)
if os.path.exists(path):
opt_format_path = path[:-4] + ".%s" % format
cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
print(cmd)
os.system(cmd)
if os.path.exists(opt_format_path):
@ -307,7 +307,7 @@ class AudioPreDeEcho:
)
if os.path.exists(path):
opt_format_path = path[:-4] + ".%s" % format
cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
print(cmd)
os.system(cmd)
if os.path.exists(opt_format_path):
@ -340,7 +340,7 @@ class AudioPreDeEcho:
)
if os.path.exists(path):
opt_format_path = path[:-4] + ".%s" % format
cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
print(cmd)
os.system(cmd)
if os.path.exists(opt_format_path):