mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-11 19:09:51 +08:00
Format
This commit is contained in:
parent
a8f366ac14
commit
1e59f757a2
@ -28,7 +28,8 @@ class Text2SemanticLightningModule(LightningModule):
|
|||||||
self.load_state_dict(
|
self.load_state_dict(
|
||||||
torch.load(
|
torch.load(
|
||||||
pretrained_s1,
|
pretrained_s1,
|
||||||
map_location="cpu", weights_only=False,
|
map_location="cpu",
|
||||||
|
weights_only=False,
|
||||||
)["weight"],
|
)["weight"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -32,19 +32,21 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
|
|||||||
|
|
||||||
from tools.audio_sr import AP_BWE
|
from tools.audio_sr import AP_BWE
|
||||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||||
from tools.my_utils import load_audio
|
|
||||||
from TTS_infer_pack.text_segmentation_method import splits
|
from TTS_infer_pack.text_segmentation_method import splits
|
||||||
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
|
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
|
||||||
from sv import SV
|
from sv import SV
|
||||||
|
|
||||||
resample_transform_dict = {}
|
resample_transform_dict = {}
|
||||||
|
|
||||||
|
|
||||||
def resample(audio_tensor, sr0, sr1, device):
|
def resample(audio_tensor, sr0, sr1, device):
|
||||||
global resample_transform_dict
|
global resample_transform_dict
|
||||||
key = "%s-%s-%s" % (sr0, sr1, str(device))
|
key = "%s-%s-%s" % (sr0, sr1, str(device))
|
||||||
if key not in resample_transform_dict:
|
if key not in resample_transform_dict:
|
||||||
resample_transform_dict[key] = torchaudio.transforms.Resample(
|
resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
|
||||||
sr0, sr1
|
|
||||||
).to(device)
|
|
||||||
return resample_transform_dict[key](audio_tensor)
|
return resample_transform_dict[key](audio_tensor)
|
||||||
|
|
||||||
|
|
||||||
language = os.environ.get("language", "Auto")
|
language = os.environ.get("language", "Auto")
|
||||||
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||||
i18n = I18nAuto(language=language)
|
i18n = I18nAuto(language=language)
|
||||||
@ -111,6 +113,7 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int):
|
|||||||
|
|
||||||
return processed_audio
|
return processed_audio
|
||||||
|
|
||||||
|
|
||||||
class DictToAttrRecursive(dict):
|
class DictToAttrRecursive(dict):
|
||||||
def __init__(self, input_dict):
|
def __init__(self, input_dict):
|
||||||
super().__init__(input_dict)
|
super().__init__(input_dict)
|
||||||
@ -632,7 +635,9 @@ class TTS:
|
|||||||
)
|
)
|
||||||
self.vocoder.remove_weight_norm()
|
self.vocoder.remove_weight_norm()
|
||||||
state_dict_g = torch.load(
|
state_dict_g = torch.load(
|
||||||
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
|
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,),
|
||||||
|
map_location="cpu",
|
||||||
|
weights_only=False,
|
||||||
)
|
)
|
||||||
print("loading vocoder", self.vocoder.load_state_dict(state_dict_g))
|
print("loading vocoder", self.vocoder.load_state_dict(state_dict_g))
|
||||||
|
|
||||||
@ -752,11 +757,13 @@ class TTS:
|
|||||||
|
|
||||||
if raw_sr != self.configs.sampling_rate:
|
if raw_sr != self.configs.sampling_rate:
|
||||||
audio = raw_audio.to(self.configs.device)
|
audio = raw_audio.to(self.configs.device)
|
||||||
if (audio.shape[0] == 2): audio = audio.mean(0).unsqueeze(0)
|
if audio.shape[0] == 2:
|
||||||
|
audio = audio.mean(0).unsqueeze(0)
|
||||||
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
|
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
|
||||||
else:
|
else:
|
||||||
audio = raw_audio.to(self.configs.device)
|
audio = raw_audio.to(self.configs.device)
|
||||||
if (audio.shape[0] == 2): audio = audio.mean(0).unsqueeze(0)
|
if audio.shape[0] == 2:
|
||||||
|
audio = audio.mean(0).unsqueeze(0)
|
||||||
|
|
||||||
maxx = audio.abs().max()
|
maxx = audio.abs().max()
|
||||||
if maxx > 1:
|
if maxx > 1:
|
||||||
@ -775,7 +782,8 @@ class TTS:
|
|||||||
audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
|
audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
|
||||||
if self.configs.is_half:
|
if self.configs.is_half:
|
||||||
audio = audio.half()
|
audio = audio.half()
|
||||||
else:audio=None
|
else:
|
||||||
|
audio = None
|
||||||
return spec, audio
|
return spec, audio
|
||||||
|
|
||||||
def _set_prompt_semantic(self, ref_wav_path: str):
|
def _set_prompt_semantic(self, ref_wav_path: str):
|
||||||
@ -1073,7 +1081,10 @@ class TTS:
|
|||||||
|
|
||||||
###### setting reference audio and prompt text preprocessing ########
|
###### setting reference audio and prompt text preprocessing ########
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
if (ref_audio_path is not None) and (ref_audio_path != self.prompt_cache["ref_audio_path"] or (self.is_v2pro and self.prompt_cache["refer_spec"][0][1] is None)):
|
if (ref_audio_path is not None) and (
|
||||||
|
ref_audio_path != self.prompt_cache["ref_audio_path"]
|
||||||
|
or (self.is_v2pro and self.prompt_cache["refer_spec"][0][1] is None)
|
||||||
|
):
|
||||||
if not os.path.exists(ref_audio_path):
|
if not os.path.exists(ref_audio_path):
|
||||||
raise ValueError(f"{ref_audio_path} not exists")
|
raise ValueError(f"{ref_audio_path} not exists")
|
||||||
self.set_ref_audio(ref_audio_path)
|
self.set_ref_audio(ref_audio_path)
|
||||||
@ -1212,7 +1223,8 @@ class TTS:
|
|||||||
t_34 += t4 - t3
|
t_34 += t4 - t3
|
||||||
|
|
||||||
refer_audio_spec = []
|
refer_audio_spec = []
|
||||||
if self.is_v2pro:sv_emb=[]
|
if self.is_v2pro:
|
||||||
|
sv_emb = []
|
||||||
for spec, audio_tensor in self.prompt_cache["refer_spec"]:
|
for spec, audio_tensor in self.prompt_cache["refer_spec"]:
|
||||||
spec = spec.to(dtype=self.precision, device=self.configs.device)
|
spec = spec.to(dtype=self.precision, device=self.configs.device)
|
||||||
refer_audio_spec.append(spec)
|
refer_audio_spec.append(spec)
|
||||||
@ -1250,9 +1262,13 @@ class TTS:
|
|||||||
)
|
)
|
||||||
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
|
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
|
||||||
if self.is_v2pro != True:
|
if self.is_v2pro != True:
|
||||||
_batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
|
_batch_audio_fragment = self.vits_model.decode(
|
||||||
|
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
|
||||||
|
).detach()[0, 0, :]
|
||||||
else:
|
else:
|
||||||
_batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor,sv_emb=sv_emb).detach()[0, 0, :]
|
_batch_audio_fragment = self.vits_model.decode(
|
||||||
|
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
|
||||||
|
).detach()[0, 0, :]
|
||||||
audio_frag_end_idx.insert(0, 0)
|
audio_frag_end_idx.insert(0, 0)
|
||||||
batch_audio_fragment = [
|
batch_audio_fragment = [
|
||||||
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
|
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
|
||||||
@ -1266,9 +1282,13 @@ class TTS:
|
|||||||
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
|
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
|
||||||
) # .unsqueeze(0)#mq要多unsqueeze一次
|
) # .unsqueeze(0)#mq要多unsqueeze一次
|
||||||
if self.is_v2pro != True:
|
if self.is_v2pro != True:
|
||||||
audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
|
audio_fragment = self.vits_model.decode(
|
||||||
|
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
|
||||||
|
).detach()[0, 0, :]
|
||||||
else:
|
else:
|
||||||
audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor,sv_emb=sv_emb).detach()[0, 0, :]
|
audio_fragment = self.vits_model.decode(
|
||||||
|
_pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb
|
||||||
|
).detach()[0, 0, :]
|
||||||
batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
|
batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
|
||||||
else:
|
else:
|
||||||
if parallel_infer:
|
if parallel_infer:
|
||||||
|
@ -160,7 +160,9 @@ class TextPreprocessor:
|
|||||||
else:
|
else:
|
||||||
for tmp in LangSegmenter.getTexts(text):
|
for tmp in LangSegmenter.getTexts(text):
|
||||||
if langlist:
|
if langlist:
|
||||||
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
|
if (tmp["lang"] == "en" and langlist[-1] == "en") or (
|
||||||
|
tmp["lang"] != "en" and langlist[-1] != "en"
|
||||||
|
):
|
||||||
textlist[-1] += tmp["text"]
|
textlist[-1] += tmp["text"]
|
||||||
continue
|
continue
|
||||||
if tmp["lang"] == "en":
|
if tmp["lang"] == "en":
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
|
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -16,15 +15,14 @@ import torch.nn.functional as F
|
|||||||
import pooling_layers as pooling_layers
|
import pooling_layers as pooling_layers
|
||||||
from fusion import AFF
|
from fusion import AFF
|
||||||
|
|
||||||
class ReLU(nn.Hardtanh):
|
|
||||||
|
|
||||||
|
class ReLU(nn.Hardtanh):
|
||||||
def __init__(self, inplace=False):
|
def __init__(self, inplace=False):
|
||||||
super(ReLU, self).__init__(0, 20, inplace)
|
super(ReLU, self).__init__(0, 20, inplace)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
inplace_str = 'inplace' if self.inplace else ''
|
inplace_str = "inplace" if self.inplace else ""
|
||||||
return self.__class__.__name__ + ' (' \
|
return self.__class__.__name__ + " (" + inplace_str + ")"
|
||||||
+ inplace_str + ')'
|
|
||||||
|
|
||||||
|
|
||||||
class BasicBlockERes2Net(nn.Module):
|
class BasicBlockERes2Net(nn.Module):
|
||||||
@ -51,9 +49,9 @@ class BasicBlockERes2Net(nn.Module):
|
|||||||
self.shortcut = nn.Sequential()
|
self.shortcut = nn.Sequential()
|
||||||
if stride != 1 or in_planes != self.expansion * planes:
|
if stride != 1 or in_planes != self.expansion * planes:
|
||||||
self.shortcut = nn.Sequential(
|
self.shortcut = nn.Sequential(
|
||||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
|
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||||
stride=stride, bias=False),
|
nn.BatchNorm2d(self.expansion * planes),
|
||||||
nn.BatchNorm2d(self.expansion * planes))
|
)
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.width = width
|
self.width = width
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
@ -86,6 +84,7 @@ class BasicBlockERes2Net(nn.Module):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class BasicBlockERes2Net_diff_AFF(nn.Module):
|
class BasicBlockERes2Net_diff_AFF(nn.Module):
|
||||||
expansion = 2
|
expansion = 2
|
||||||
|
|
||||||
@ -115,9 +114,9 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
|
|||||||
self.shortcut = nn.Sequential()
|
self.shortcut = nn.Sequential()
|
||||||
if stride != 1 or in_planes != self.expansion * planes:
|
if stride != 1 or in_planes != self.expansion * planes:
|
||||||
self.shortcut = nn.Sequential(
|
self.shortcut = nn.Sequential(
|
||||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
|
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||||
stride=stride, bias=False),
|
nn.BatchNorm2d(self.expansion * planes),
|
||||||
nn.BatchNorm2d(self.expansion * planes))
|
)
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.width = width
|
self.width = width
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
@ -151,16 +150,19 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ERes2Net(nn.Module):
|
class ERes2Net(nn.Module):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
block=BasicBlockERes2Net,
|
block=BasicBlockERes2Net,
|
||||||
block_fuse=BasicBlockERes2Net_diff_AFF,
|
block_fuse=BasicBlockERes2Net_diff_AFF,
|
||||||
num_blocks=[3, 4, 6, 3],
|
num_blocks=[3, 4, 6, 3],
|
||||||
m_channels=32,
|
m_channels=32,
|
||||||
feat_dim=80,
|
feat_dim=80,
|
||||||
embedding_size=192,
|
embedding_size=192,
|
||||||
pooling_func='TSTP',
|
pooling_func="TSTP",
|
||||||
two_emb_layer=False):
|
two_emb_layer=False,
|
||||||
|
):
|
||||||
super(ERes2Net, self).__init__()
|
super(ERes2Net, self).__init__()
|
||||||
self.in_planes = m_channels
|
self.in_planes = m_channels
|
||||||
self.feat_dim = feat_dim
|
self.feat_dim = feat_dim
|
||||||
@ -176,20 +178,24 @@ class ERes2Net(nn.Module):
|
|||||||
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
|
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
|
||||||
|
|
||||||
# Downsampling module for each layer
|
# Downsampling module for each layer
|
||||||
self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False)
|
self.layer1_downsample = nn.Conv2d(
|
||||||
self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
|
m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False
|
||||||
self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
|
)
|
||||||
|
self.layer2_downsample = nn.Conv2d(
|
||||||
|
m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
|
||||||
|
)
|
||||||
|
self.layer3_downsample = nn.Conv2d(
|
||||||
|
m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
# Bottom-up fusion module
|
# Bottom-up fusion module
|
||||||
self.fuse_mode12 = AFF(channels=m_channels * 4)
|
self.fuse_mode12 = AFF(channels=m_channels * 4)
|
||||||
self.fuse_mode123 = AFF(channels=m_channels * 8)
|
self.fuse_mode123 = AFF(channels=m_channels * 8)
|
||||||
self.fuse_mode1234 = AFF(channels=m_channels * 16)
|
self.fuse_mode1234 = AFF(channels=m_channels * 16)
|
||||||
|
|
||||||
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
|
self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
|
||||||
self.pool = getattr(pooling_layers, pooling_func)(
|
self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
|
||||||
in_dim=self.stats_dim * block.expansion)
|
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
|
||||||
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
|
|
||||||
embedding_size)
|
|
||||||
if self.two_emb_layer:
|
if self.two_emb_layer:
|
||||||
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
||||||
self.seg_2 = nn.Linear(embedding_size, embedding_size)
|
self.seg_2 = nn.Linear(embedding_size, embedding_size)
|
||||||
@ -247,14 +253,12 @@ class ERes2Net(nn.Module):
|
|||||||
return fuse_out1234
|
return fuse_out1234
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
|
|
||||||
x = torch.zeros(10, 300, 80)
|
x = torch.zeros(10, 300, 80)
|
||||||
model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func='TSTP')
|
model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func="TSTP")
|
||||||
model.eval()
|
model.eval()
|
||||||
out = model(x)
|
out = model(x)
|
||||||
print(out.shape) # torch.Size([10, 192])
|
print(out.shape) # torch.Size([10, 192])
|
||||||
|
|
||||||
num_params = sum(param.numel() for param in model.parameters())
|
num_params = sum(param.numel() for param in model.parameters())
|
||||||
print("{} M".format(num_params / 1e6)) # 6.61M
|
print("{} M".format(num_params / 1e6)) # 6.61M
|
||||||
|
|
||||||
|
@ -8,8 +8,6 @@
|
|||||||
both the model parameters and its computational cost.
|
both the model parameters and its computational cost.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -17,19 +15,17 @@ import torch.nn.functional as F
|
|||||||
import pooling_layers as pooling_layers
|
import pooling_layers as pooling_layers
|
||||||
from fusion import AFF
|
from fusion import AFF
|
||||||
|
|
||||||
class ReLU(nn.Hardtanh):
|
|
||||||
|
|
||||||
|
class ReLU(nn.Hardtanh):
|
||||||
def __init__(self, inplace=False):
|
def __init__(self, inplace=False):
|
||||||
super(ReLU, self).__init__(0, 20, inplace)
|
super(ReLU, self).__init__(0, 20, inplace)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
inplace_str = 'inplace' if self.inplace else ''
|
inplace_str = "inplace" if self.inplace else ""
|
||||||
return self.__class__.__name__ + ' (' \
|
return self.__class__.__name__ + " (" + inplace_str + ")"
|
||||||
+ inplace_str + ')'
|
|
||||||
|
|
||||||
|
|
||||||
class BasicBlockERes2NetV2(nn.Module):
|
class BasicBlockERes2NetV2(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
|
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
|
||||||
super(BasicBlockERes2NetV2, self).__init__()
|
super(BasicBlockERes2NetV2, self).__init__()
|
||||||
width = int(math.floor(planes * (baseWidth / 64.0)))
|
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||||
@ -52,12 +48,9 @@ class BasicBlockERes2NetV2(nn.Module):
|
|||||||
self.shortcut = nn.Sequential()
|
self.shortcut = nn.Sequential()
|
||||||
if stride != 1 or in_planes != self.expansion * planes:
|
if stride != 1 or in_planes != self.expansion * planes:
|
||||||
self.shortcut = nn.Sequential(
|
self.shortcut = nn.Sequential(
|
||||||
nn.Conv2d(in_planes,
|
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||||
self.expansion * planes,
|
nn.BatchNorm2d(self.expansion * planes),
|
||||||
kernel_size=1,
|
)
|
||||||
stride=stride,
|
|
||||||
bias=False),
|
|
||||||
nn.BatchNorm2d(self.expansion * planes))
|
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.width = width
|
self.width = width
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
@ -90,8 +83,8 @@ class BasicBlockERes2NetV2(nn.Module):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class BasicBlockERes2NetV2AFF(nn.Module):
|
|
||||||
|
|
||||||
|
class BasicBlockERes2NetV2AFF(nn.Module):
|
||||||
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
|
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
|
||||||
super(BasicBlockERes2NetV2AFF, self).__init__()
|
super(BasicBlockERes2NetV2AFF, self).__init__()
|
||||||
width = int(math.floor(planes * (baseWidth / 64.0)))
|
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||||
@ -119,12 +112,9 @@ class BasicBlockERes2NetV2AFF(nn.Module):
|
|||||||
self.shortcut = nn.Sequential()
|
self.shortcut = nn.Sequential()
|
||||||
if stride != 1 or in_planes != self.expansion * planes:
|
if stride != 1 or in_planes != self.expansion * planes:
|
||||||
self.shortcut = nn.Sequential(
|
self.shortcut = nn.Sequential(
|
||||||
nn.Conv2d(in_planes,
|
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||||
self.expansion * planes,
|
nn.BatchNorm2d(self.expansion * planes),
|
||||||
kernel_size=1,
|
)
|
||||||
stride=stride,
|
|
||||||
bias=False),
|
|
||||||
nn.BatchNorm2d(self.expansion * planes))
|
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.width = width
|
self.width = width
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
@ -158,8 +148,10 @@ class BasicBlockERes2NetV2AFF(nn.Module):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ERes2NetV2(nn.Module):
|
class ERes2NetV2(nn.Module):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
block=BasicBlockERes2NetV2,
|
block=BasicBlockERes2NetV2,
|
||||||
block_fuse=BasicBlockERes2NetV2AFF,
|
block_fuse=BasicBlockERes2NetV2AFF,
|
||||||
num_blocks=[3, 4, 6, 3],
|
num_blocks=[3, 4, 6, 3],
|
||||||
@ -169,8 +161,9 @@ class ERes2NetV2(nn.Module):
|
|||||||
baseWidth=26,
|
baseWidth=26,
|
||||||
scale=2,
|
scale=2,
|
||||||
expansion=2,
|
expansion=2,
|
||||||
pooling_func='TSTP',
|
pooling_func="TSTP",
|
||||||
two_emb_layer=False):
|
two_emb_layer=False,
|
||||||
|
):
|
||||||
super(ERes2NetV2, self).__init__()
|
super(ERes2NetV2, self).__init__()
|
||||||
self.in_planes = m_channels
|
self.in_planes = m_channels
|
||||||
self.feat_dim = feat_dim
|
self.feat_dim = feat_dim
|
||||||
@ -181,42 +174,29 @@ class ERes2NetV2(nn.Module):
|
|||||||
self.scale = scale
|
self.scale = scale
|
||||||
self.expansion = expansion
|
self.expansion = expansion
|
||||||
|
|
||||||
self.conv1 = nn.Conv2d(1,
|
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||||
m_channels,
|
|
||||||
kernel_size=3,
|
|
||||||
stride=1,
|
|
||||||
padding=1,
|
|
||||||
bias=False)
|
|
||||||
self.bn1 = nn.BatchNorm2d(m_channels)
|
self.bn1 = nn.BatchNorm2d(m_channels)
|
||||||
self.layer1 = self._make_layer(block,
|
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
|
||||||
m_channels,
|
self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
|
||||||
num_blocks[0],
|
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
|
||||||
stride=1)
|
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
|
||||||
self.layer2 = self._make_layer(block,
|
|
||||||
m_channels * 2,
|
|
||||||
num_blocks[1],
|
|
||||||
stride=2)
|
|
||||||
self.layer3 = self._make_layer(block_fuse,
|
|
||||||
m_channels * 4,
|
|
||||||
num_blocks[2],
|
|
||||||
stride=2)
|
|
||||||
self.layer4 = self._make_layer(block_fuse,
|
|
||||||
m_channels * 8,
|
|
||||||
num_blocks[3],
|
|
||||||
stride=2)
|
|
||||||
|
|
||||||
# Downsampling module
|
# Downsampling module
|
||||||
self.layer3_ds = nn.Conv2d(m_channels * 4 * self.expansion, m_channels * 8 * self.expansion, kernel_size=3, \
|
self.layer3_ds = nn.Conv2d(
|
||||||
padding=1, stride=2, bias=False)
|
m_channels * 4 * self.expansion,
|
||||||
|
m_channels * 8 * self.expansion,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
stride=2,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Bottom-up fusion module
|
# Bottom-up fusion module
|
||||||
self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
|
self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
|
||||||
|
|
||||||
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
|
self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
|
||||||
self.pool = getattr(pooling_layers, pooling_func)(
|
self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * self.expansion)
|
||||||
in_dim=self.stats_dim * self.expansion)
|
self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats, embedding_size)
|
||||||
self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
|
|
||||||
embedding_size)
|
|
||||||
if self.two_emb_layer:
|
if self.two_emb_layer:
|
||||||
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
||||||
self.seg_2 = nn.Linear(embedding_size, embedding_size)
|
self.seg_2 = nn.Linear(embedding_size, embedding_size)
|
||||||
@ -228,7 +208,11 @@ class ERes2NetV2(nn.Module):
|
|||||||
strides = [stride] + [1] * (num_blocks - 1)
|
strides = [stride] + [1] * (num_blocks - 1)
|
||||||
layers = []
|
layers = []
|
||||||
for stride in strides:
|
for stride in strides:
|
||||||
layers.append(block(self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion))
|
layers.append(
|
||||||
|
block(
|
||||||
|
self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion
|
||||||
|
)
|
||||||
|
)
|
||||||
self.in_planes = planes * self.expansion
|
self.in_planes = planes * self.expansion
|
||||||
return nn.Sequential(*layers)
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
@ -276,8 +260,8 @@ class ERes2NetV2(nn.Module):
|
|||||||
# else:
|
# else:
|
||||||
# return embed_a
|
# return embed_a
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
x = torch.randn(1, 300, 80)
|
x = torch.randn(1, 300, 80)
|
||||||
model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
|
model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -286,7 +270,3 @@ if __name__ == '__main__':
|
|||||||
macs, num_params = profile(model, inputs=(x,))
|
macs, num_params = profile(model, inputs=(x,))
|
||||||
print("Params: {} M".format(num_params / 1e6)) # 17.86 M
|
print("Params: {} M".format(num_params / 1e6)) # 17.86 M
|
||||||
print("MACs: {} G".format(macs / 1e9)) # 12.69 G
|
print("MACs: {} G".format(macs / 1e9)) # 12.69 G
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
|
ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
|
||||||
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
|
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
|
||||||
"""
|
"""
|
||||||
import pdb
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
@ -17,15 +16,14 @@ import torch.nn.functional as F
|
|||||||
import pooling_layers as pooling_layers
|
import pooling_layers as pooling_layers
|
||||||
from fusion import AFF
|
from fusion import AFF
|
||||||
|
|
||||||
class ReLU(nn.Hardtanh):
|
|
||||||
|
|
||||||
|
class ReLU(nn.Hardtanh):
|
||||||
def __init__(self, inplace=False):
|
def __init__(self, inplace=False):
|
||||||
super(ReLU, self).__init__(0, 20, inplace)
|
super(ReLU, self).__init__(0, 20, inplace)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
inplace_str = 'inplace' if self.inplace else ''
|
inplace_str = "inplace" if self.inplace else ""
|
||||||
return self.__class__.__name__ + ' (' \
|
return self.__class__.__name__ + " (" + inplace_str + ")"
|
||||||
+ inplace_str + ')'
|
|
||||||
|
|
||||||
|
|
||||||
class BasicBlockERes2Net(nn.Module):
|
class BasicBlockERes2Net(nn.Module):
|
||||||
@ -53,7 +51,8 @@ class BasicBlockERes2Net(nn.Module):
|
|||||||
if stride != 1 or in_planes != self.expansion * planes:
|
if stride != 1 or in_planes != self.expansion * planes:
|
||||||
self.shortcut = nn.Sequential(
|
self.shortcut = nn.Sequential(
|
||||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||||
nn.BatchNorm2d(self.expansion * planes))
|
nn.BatchNorm2d(self.expansion * planes),
|
||||||
|
)
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.width = width
|
self.width = width
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
@ -86,6 +85,7 @@ class BasicBlockERes2Net(nn.Module):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class BasicBlockERes2Net_diff_AFF(nn.Module):
|
class BasicBlockERes2Net_diff_AFF(nn.Module):
|
||||||
expansion = 4
|
expansion = 4
|
||||||
|
|
||||||
@ -116,7 +116,8 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
|
|||||||
if stride != 1 or in_planes != self.expansion * planes:
|
if stride != 1 or in_planes != self.expansion * planes:
|
||||||
self.shortcut = nn.Sequential(
|
self.shortcut = nn.Sequential(
|
||||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||||
nn.BatchNorm2d(self.expansion * planes))
|
nn.BatchNorm2d(self.expansion * planes),
|
||||||
|
)
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.width = width
|
self.width = width
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
@ -141,7 +142,6 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
|
|||||||
else:
|
else:
|
||||||
out = torch.cat((out, sp), 1)
|
out = torch.cat((out, sp), 1)
|
||||||
|
|
||||||
|
|
||||||
out = self.conv3(out)
|
out = self.conv3(out)
|
||||||
out = self.bn3(out)
|
out = self.bn3(out)
|
||||||
|
|
||||||
@ -151,16 +151,19 @@ class BasicBlockERes2Net_diff_AFF(nn.Module):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ERes2Net(nn.Module):
|
class ERes2Net(nn.Module):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
block=BasicBlockERes2Net,
|
block=BasicBlockERes2Net,
|
||||||
block_fuse=BasicBlockERes2Net_diff_AFF,
|
block_fuse=BasicBlockERes2Net_diff_AFF,
|
||||||
num_blocks=[3, 4, 6, 3],
|
num_blocks=[3, 4, 6, 3],
|
||||||
m_channels=64,
|
m_channels=64,
|
||||||
feat_dim=80,
|
feat_dim=80,
|
||||||
embedding_size=192,
|
embedding_size=192,
|
||||||
pooling_func='TSTP',
|
pooling_func="TSTP",
|
||||||
two_emb_layer=False):
|
two_emb_layer=False,
|
||||||
|
):
|
||||||
super(ERes2Net, self).__init__()
|
super(ERes2Net, self).__init__()
|
||||||
self.in_planes = m_channels
|
self.in_planes = m_channels
|
||||||
self.feat_dim = feat_dim
|
self.feat_dim = feat_dim
|
||||||
@ -176,17 +179,22 @@ class ERes2Net(nn.Module):
|
|||||||
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
|
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
|
||||||
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
|
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
|
||||||
|
|
||||||
self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
|
self.layer1_downsample = nn.Conv2d(
|
||||||
self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
|
m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
|
||||||
self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False)
|
)
|
||||||
|
self.layer2_downsample = nn.Conv2d(
|
||||||
|
m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
|
||||||
|
)
|
||||||
|
self.layer3_downsample = nn.Conv2d(
|
||||||
|
m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
self.fuse_mode12 = AFF(channels=m_channels * 8)
|
self.fuse_mode12 = AFF(channels=m_channels * 8)
|
||||||
self.fuse_mode123 = AFF(channels=m_channels * 16)
|
self.fuse_mode123 = AFF(channels=m_channels * 16)
|
||||||
self.fuse_mode1234 = AFF(channels=m_channels * 32)
|
self.fuse_mode1234 = AFF(channels=m_channels * 32)
|
||||||
|
|
||||||
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
|
self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
|
||||||
self.pool = getattr(pooling_layers, pooling_func)(
|
self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
|
||||||
in_dim=self.stats_dim * block.expansion)
|
|
||||||
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
|
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
|
||||||
if self.two_emb_layer:
|
if self.two_emb_layer:
|
||||||
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
||||||
@ -244,14 +252,13 @@ class ERes2Net(nn.Module):
|
|||||||
out4 = self.layer4(out3)
|
out4 = self.layer4(out3)
|
||||||
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2) # bs,20480,T
|
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2) # bs,20480,T
|
||||||
if(if_mean==False):
|
if if_mean == False:
|
||||||
mean = fuse_out1234[0].transpose(1, 0) # (T,20480),bs=T
|
mean = fuse_out1234[0].transpose(1, 0) # (T,20480),bs=T
|
||||||
else:
|
else:
|
||||||
mean = fuse_out1234.mean(2) # bs,20480
|
mean = fuse_out1234.mean(2) # bs,20480
|
||||||
mean_std = torch.cat([mean, torch.zeros_like(mean)], 1)
|
mean_std = torch.cat([mean, torch.zeros_like(mean)], 1)
|
||||||
return self.seg_1(mean_std) # (T,192)
|
return self.seg_1(mean_std) # (T,192)
|
||||||
|
|
||||||
|
|
||||||
# stats = self.pool(fuse_out1234)
|
# stats = self.pool(fuse_out1234)
|
||||||
# if self.two_emb_layer:
|
# if self.two_emb_layer:
|
||||||
# out = F.relu(embed_a)
|
# out = F.relu(embed_a)
|
||||||
@ -280,7 +287,3 @@ class ERes2Net(nn.Module):
|
|||||||
# print(fuse_out1234.shape)
|
# print(fuse_out1234.shape)
|
||||||
# print(fuse_out1234.flatten(start_dim=1,end_dim=2).shape)
|
# print(fuse_out1234.flatten(start_dim=1,end_dim=2).shape)
|
||||||
# pdb.set_trace()
|
# pdb.set_trace()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,7 +6,6 @@ import torch.nn as nn
|
|||||||
|
|
||||||
|
|
||||||
class AFF(nn.Module):
|
class AFF(nn.Module):
|
||||||
|
|
||||||
def __init__(self, channels=64, r=4):
|
def __init__(self, channels=64, r=4):
|
||||||
super(AFF, self).__init__()
|
super(AFF, self).__init__()
|
||||||
inter_channels = int(channels // r)
|
inter_channels = int(channels // r)
|
||||||
@ -26,4 +25,3 @@ class AFF(nn.Module):
|
|||||||
xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)
|
xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)
|
||||||
|
|
||||||
return xo
|
return xo
|
||||||
|
|
||||||
|
@ -144,7 +144,7 @@ def _get_waveform_and_window_properties(
|
|||||||
)
|
)
|
||||||
assert 0 < window_shift, "`window_shift` must be greater than 0"
|
assert 0 < window_shift, "`window_shift` must be greater than 0"
|
||||||
assert padded_window_size % 2 == 0, (
|
assert padded_window_size % 2 == 0, (
|
||||||
"the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
|
"the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`"
|
||||||
)
|
)
|
||||||
assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
|
assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
|
||||||
assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
|
assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
|
||||||
@ -441,7 +441,9 @@ def get_mel_banks(
|
|||||||
high_freq: float,
|
high_freq: float,
|
||||||
vtln_low: float,
|
vtln_low: float,
|
||||||
vtln_high: float,
|
vtln_high: float,
|
||||||
vtln_warp_factor: float,device=None,dtype=None
|
vtln_warp_factor: float,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
) -> Tuple[Tensor, Tensor]:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
@ -457,9 +459,9 @@ def get_mel_banks(
|
|||||||
if high_freq <= 0.0:
|
if high_freq <= 0.0:
|
||||||
high_freq += nyquist
|
high_freq += nyquist
|
||||||
|
|
||||||
assert (
|
assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), (
|
||||||
(0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
|
"Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
|
||||||
), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
|
)
|
||||||
|
|
||||||
# fft-bin width [think of it as Nyquist-freq / half-window-length]
|
# fft-bin width [think of it as Nyquist-freq / half-window-length]
|
||||||
fft_bin_width = sample_freq / window_length_padded
|
fft_bin_width = sample_freq / window_length_padded
|
||||||
@ -475,7 +477,7 @@ def get_mel_banks(
|
|||||||
|
|
||||||
assert vtln_warp_factor == 1.0 or (
|
assert vtln_warp_factor == 1.0 or (
|
||||||
(low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
|
(low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
|
||||||
), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
|
), "Bad values in options: vtln-low {} and vtln-high {}, versus low-freq {} and high-freq {}".format(
|
||||||
vtln_low, vtln_high, low_freq, high_freq
|
vtln_low, vtln_high, low_freq, high_freq
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -510,7 +512,10 @@ def get_mel_banks(
|
|||||||
|
|
||||||
return bins.to(device=device, dtype=dtype) # , center_freqs
|
return bins.to(device=device, dtype=dtype) # , center_freqs
|
||||||
|
|
||||||
|
|
||||||
cache = {}
|
cache = {}
|
||||||
|
|
||||||
|
|
||||||
def fbank(
|
def fbank(
|
||||||
waveform: Tensor,
|
waveform: Tensor,
|
||||||
blackman_coeff: float = 0.42,
|
blackman_coeff: float = 0.42,
|
||||||
@ -620,10 +625,30 @@ def fbank(
|
|||||||
# size (num_mel_bins, padded_window_size // 2)
|
# size (num_mel_bins, padded_window_size // 2)
|
||||||
# print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
|
# print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
|
||||||
|
|
||||||
cache_key="%s-%s-%s-%s-%s-%s-%s-%s-%s-%s"%(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp,device,dtype)
|
cache_key = "%s-%s-%s-%s-%s-%s-%s-%s-%s-%s" % (
|
||||||
|
num_mel_bins,
|
||||||
|
padded_window_size,
|
||||||
|
sample_frequency,
|
||||||
|
low_freq,
|
||||||
|
high_freq,
|
||||||
|
vtln_low,
|
||||||
|
vtln_high,
|
||||||
|
vtln_warp,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
)
|
||||||
if cache_key not in cache:
|
if cache_key not in cache:
|
||||||
mel_energies = get_mel_banks(
|
mel_energies = get_mel_banks(
|
||||||
num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp,device,dtype
|
num_mel_bins,
|
||||||
|
padded_window_size,
|
||||||
|
sample_frequency,
|
||||||
|
low_freq,
|
||||||
|
high_freq,
|
||||||
|
vtln_low,
|
||||||
|
vtln_high,
|
||||||
|
vtln_warp,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
)
|
)
|
||||||
cache[cache_key] = mel_energies
|
cache[cache_key] = mel_energies
|
||||||
else:
|
else:
|
||||||
|
@ -11,6 +11,7 @@ class TAP(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Temporal average pooling, only first-order mean is considered
|
Temporal average pooling, only first-order mean is considered
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super(TAP, self).__init__()
|
super(TAP, self).__init__()
|
||||||
|
|
||||||
@ -25,6 +26,7 @@ class TSDP(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Temporal standard deviation pooling, only second-order std is considered
|
Temporal standard deviation pooling, only second-order std is considered
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super(TSDP, self).__init__()
|
super(TSDP, self).__init__()
|
||||||
|
|
||||||
@ -41,6 +43,7 @@ class TSTP(nn.Module):
|
|||||||
x-vector
|
x-vector
|
||||||
Comment: simple concatenation can not make full use of both statistics
|
Comment: simple concatenation can not make full use of both statistics
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super(TSTP, self).__init__()
|
super(TSTP, self).__init__()
|
||||||
|
|
||||||
@ -59,6 +62,7 @@ class ASTP(nn.Module):
|
|||||||
"""Attentive statistics pooling: Channel- and context-dependent
|
"""Attentive statistics pooling: Channel- and context-dependent
|
||||||
statistics pooling, first used in ECAPA_TDNN.
|
statistics pooling, first used in ECAPA_TDNN.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
|
def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
|
||||||
super(ASTP, self).__init__()
|
super(ASTP, self).__init__()
|
||||||
self.global_context_att = global_context_att
|
self.global_context_att = global_context_att
|
||||||
@ -66,15 +70,10 @@ class ASTP(nn.Module):
|
|||||||
# Use Conv1d with stride == 1 rather than Linear, then we don't
|
# Use Conv1d with stride == 1 rather than Linear, then we don't
|
||||||
# need to transpose inputs.
|
# need to transpose inputs.
|
||||||
if global_context_att:
|
if global_context_att:
|
||||||
self.linear1 = nn.Conv1d(
|
self.linear1 = nn.Conv1d(in_dim * 3, bottleneck_dim, kernel_size=1) # equals W and b in the paper
|
||||||
in_dim * 3, bottleneck_dim,
|
|
||||||
kernel_size=1) # equals W and b in the paper
|
|
||||||
else:
|
else:
|
||||||
self.linear1 = nn.Conv1d(
|
self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1) # equals W and b in the paper
|
||||||
in_dim, bottleneck_dim,
|
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1) # equals V and k in the paper
|
||||||
kernel_size=1) # equals W and b in the paper
|
|
||||||
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
|
|
||||||
kernel_size=1) # equals V and k in the paper
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
@ -88,15 +87,13 @@ class ASTP(nn.Module):
|
|||||||
|
|
||||||
if self.global_context_att:
|
if self.global_context_att:
|
||||||
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
||||||
context_std = torch.sqrt(
|
context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
|
||||||
torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
|
|
||||||
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
||||||
else:
|
else:
|
||||||
x_in = x
|
x_in = x
|
||||||
|
|
||||||
# DON'T use ReLU here! ReLU may be hard to converge.
|
# DON'T use ReLU here! ReLU may be hard to converge.
|
||||||
alpha = torch.tanh(
|
alpha = torch.tanh(self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
|
||||||
self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
|
|
||||||
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
||||||
mean = torch.sum(alpha * x, dim=2)
|
mean = torch.sum(alpha * x, dim=2)
|
||||||
var = torch.sum(alpha * (x**2), dim=2) - mean**2
|
var = torch.sum(alpha * (x**2), dim=2) - mean**2
|
||||||
|
@ -435,6 +435,7 @@ class GPTSoVITSV3(torch.nn.Module):
|
|||||||
wav_gen = torch.cat(wav_gen_list, 2)
|
wav_gen = torch.cat(wav_gen_list, 2)
|
||||||
return wav_gen[0][0][:wav_gen_length]
|
return wav_gen[0][0][:wav_gen_length]
|
||||||
|
|
||||||
|
|
||||||
class GPTSoVITSV4(torch.nn.Module):
|
class GPTSoVITSV4(torch.nn.Module):
|
||||||
def __init__(self, gpt_sovits_half, cfm, hifigan):
|
def __init__(self, gpt_sovits_half, cfm, hifigan):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -577,6 +578,7 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
|
|||||||
|
|
||||||
v3v4set = {"v3", "v4"}
|
v3v4set = {"v3", "v4"}
|
||||||
|
|
||||||
|
|
||||||
def get_sovits_weights(sovits_path):
|
def get_sovits_weights(sovits_path):
|
||||||
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
||||||
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
||||||
@ -707,7 +709,6 @@ def export_1(ref_wav_path,ref_wav_text,version="v3"):
|
|||||||
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
|
sovits = get_sovits_weights("GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth")
|
||||||
init_hifigan()
|
init_hifigan()
|
||||||
|
|
||||||
|
|
||||||
dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
|
dict_s1 = torch.load("GPT_SoVITS/pretrained_models/s1v3.ckpt")
|
||||||
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
||||||
print("#### get_raw_t2s_model ####")
|
print("#### get_raw_t2s_model ####")
|
||||||
@ -751,9 +752,7 @@ def export_1(ref_wav_path,ref_wav_text,version="v3"):
|
|||||||
# phones1, bert1, norm_text1 = get_phones_and_bert(
|
# phones1, bert1, norm_text1 = get_phones_and_bert(
|
||||||
# "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
|
# "你这老坏蛋,我找了你这么久,真没想到在这里找到你。他说。", "all_zh", "v3"
|
||||||
# )
|
# )
|
||||||
phones1, bert1, norm_text1 = get_phones_and_bert(
|
phones1, bert1, norm_text1 = get_phones_and_bert(ref_wav_text, "auto", "v3")
|
||||||
ref_wav_text, "auto", "v3"
|
|
||||||
)
|
|
||||||
phones2, bert2, norm_text2 = get_phones_and_bert(
|
phones2, bert2, norm_text2 = get_phones_and_bert(
|
||||||
"这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
|
"这是一个简单的示例,真没想到这么简单就完成了。The King and His Stories.Once there was a king. He likes to write stories, but his stories were not good. As people were afraid of him, they all said his stories were good.After reading them, the writer at once turned to the soldiers and said: Take me back to prison, please.",
|
||||||
"auto",
|
"auto",
|
||||||
@ -1201,7 +1200,6 @@ def export_2(version="v3"):
|
|||||||
gpt_sovits_v3v4 = gpt_sovits_v3 if version == "v3" else gpt_sovits_v4
|
gpt_sovits_v3v4 = gpt_sovits_v3 if version == "v3" else gpt_sovits_v4
|
||||||
sr = 24000 if version == "v3" else 48000
|
sr = 24000 if version == "v3" else 48000
|
||||||
|
|
||||||
|
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
# print("thread:", torch.get_num_threads())
|
# print("thread:", torch.get_num_threads())
|
||||||
# print("thread:", torch.get_num_interop_threads())
|
# print("thread:", torch.get_num_interop_threads())
|
||||||
@ -1212,14 +1210,14 @@ def export_2(version="v3"):
|
|||||||
"汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
|
"汗流浃背了呀!老弟~ My uncle has two dogs. One is big and the other is small. He likes them very much. He often plays with them. He takes them for a walk every day. He says they are his good friends. He is very happy with them. 最后还是我得了 MVP....",
|
||||||
gpt_sovits_v3v4,
|
gpt_sovits_v3v4,
|
||||||
"out.wav",
|
"out.wav",
|
||||||
sr
|
sr,
|
||||||
)
|
)
|
||||||
|
|
||||||
test_export(
|
test_export(
|
||||||
"你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
|
"你小子是什么来路.汗流浃背了呀!老弟~ My uncle has two dogs. He is very happy with them. 最后还是我得了 MVP!",
|
||||||
gpt_sovits_v3v4,
|
gpt_sovits_v3v4,
|
||||||
"out2.wav",
|
"out2.wav",
|
||||||
sr
|
sr,
|
||||||
)
|
)
|
||||||
|
|
||||||
# test_export(
|
# test_export(
|
||||||
|
@ -252,9 +252,28 @@ class TextAudioSpeakerCollate:
|
|||||||
if self.is_v2Pro:
|
if self.is_v2Pro:
|
||||||
sv_embs[i] = row[4]
|
sv_embs[i] = row[4]
|
||||||
if self.is_v2Pro:
|
if self.is_v2Pro:
|
||||||
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths,sv_embs
|
return (
|
||||||
|
ssl_padded,
|
||||||
|
ssl_lengths,
|
||||||
|
spec_padded,
|
||||||
|
spec_lengths,
|
||||||
|
wav_padded,
|
||||||
|
wav_lengths,
|
||||||
|
text_padded,
|
||||||
|
text_lengths,
|
||||||
|
sv_embs,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
|
return (
|
||||||
|
ssl_padded,
|
||||||
|
ssl_lengths,
|
||||||
|
spec_padded,
|
||||||
|
spec_lengths,
|
||||||
|
wav_padded,
|
||||||
|
wav_lengths,
|
||||||
|
text_padded,
|
||||||
|
text_lengths,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||||
|
@ -586,12 +586,17 @@ class DiscriminatorS(torch.nn.Module):
|
|||||||
|
|
||||||
return x, fmap
|
return x, fmap
|
||||||
|
|
||||||
|
|
||||||
v2pro_set = {"v2Pro", "v2ProPlus"}
|
v2pro_set = {"v2Pro", "v2ProPlus"}
|
||||||
|
|
||||||
|
|
||||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||||
def __init__(self, use_spectral_norm=False, version=None):
|
def __init__(self, use_spectral_norm=False, version=None):
|
||||||
super(MultiPeriodDiscriminator, self).__init__()
|
super(MultiPeriodDiscriminator, self).__init__()
|
||||||
if version in v2pro_set:periods = [2, 3, 5, 7, 11,17,23]
|
if version in v2pro_set:
|
||||||
else:periods = [2, 3, 5, 7, 11]
|
periods = [2, 3, 5, 7, 11, 17, 23]
|
||||||
|
else:
|
||||||
|
periods = [2, 3, 5, 7, 11]
|
||||||
|
|
||||||
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
||||||
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
||||||
@ -787,6 +792,7 @@ class CodePredictor(nn.Module):
|
|||||||
|
|
||||||
return pred_codes.transpose(0, 1)
|
return pred_codes.transpose(0, 1)
|
||||||
|
|
||||||
|
|
||||||
class SynthesizerTrn(nn.Module):
|
class SynthesizerTrn(nn.Module):
|
||||||
"""
|
"""
|
||||||
Synthesizer for Training
|
Synthesizer for Training
|
||||||
@ -983,7 +989,14 @@ class SynthesizerTrn(nn.Module):
|
|||||||
quantized = self.quantizer.decode(codes)
|
quantized = self.quantizer.decode(codes)
|
||||||
if self.semantic_frame_rate == "25hz":
|
if self.semantic_frame_rate == "25hz":
|
||||||
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
||||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2,1)).transpose(2,1)if self.is_v2pro else ge, speed)
|
x, m_p, logs_p, y_mask = self.enc_p(
|
||||||
|
quantized,
|
||||||
|
y_lengths,
|
||||||
|
text,
|
||||||
|
text_lengths,
|
||||||
|
self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
|
||||||
|
speed,
|
||||||
|
)
|
||||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||||
|
|
||||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||||
@ -996,6 +1009,7 @@ class SynthesizerTrn(nn.Module):
|
|||||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||||
return codes.transpose(0, 1)
|
return codes.transpose(0, 1)
|
||||||
|
|
||||||
|
|
||||||
class CFM(torch.nn.Module):
|
class CFM(torch.nn.Module):
|
||||||
def __init__(self, in_channels, dit):
|
def __init__(self, in_channels, dit):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -1029,7 +1043,18 @@ class CFM(torch.nn.Module):
|
|||||||
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
|
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
|
||||||
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
|
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
|
||||||
v_pred, text_emb, dt = self.estimator(
|
v_pred, text_emb, dt = self.estimator(
|
||||||
x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False, infer=True, text_cache=text_cache, dt_cache=dt_cache
|
x,
|
||||||
|
prompt_x,
|
||||||
|
x_lens,
|
||||||
|
t_tensor,
|
||||||
|
d_tensor,
|
||||||
|
mu,
|
||||||
|
use_grad_ckpt=False,
|
||||||
|
drop_audio_cond=False,
|
||||||
|
drop_text=False,
|
||||||
|
infer=True,
|
||||||
|
text_cache=text_cache,
|
||||||
|
dt_cache=dt_cache,
|
||||||
)
|
)
|
||||||
v_pred = v_pred.transpose(2, 1)
|
v_pred = v_pred.transpose(2, 1)
|
||||||
if self.use_conditioner_cache:
|
if self.use_conditioner_cache:
|
||||||
@ -1048,7 +1073,7 @@ class CFM(torch.nn.Module):
|
|||||||
drop_text=True,
|
drop_text=True,
|
||||||
infer=True,
|
infer=True,
|
||||||
text_cache=text_cfg_cache,
|
text_cache=text_cfg_cache,
|
||||||
dt_cache=dt_cache
|
dt_cache=dt_cache,
|
||||||
)
|
)
|
||||||
neg = neg.transpose(2, 1)
|
neg = neg.transpose(2, 1)
|
||||||
if self.use_conditioner_cache:
|
if self.use_conditioner_cache:
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import math
|
import math
|
||||||
import pdb
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -10,7 +10,6 @@ i_part = os.environ.get("i_part")
|
|||||||
all_parts = os.environ.get("all_parts")
|
all_parts = os.environ.get("all_parts")
|
||||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||||
from feature_extractor import cnhubert
|
|
||||||
|
|
||||||
opt_dir = os.environ.get("opt_dir")
|
opt_dir = os.environ.get("opt_dir")
|
||||||
sv_path = os.environ.get("sv_path")
|
sv_path = os.environ.get("sv_path")
|
||||||
@ -19,19 +18,18 @@ import torch
|
|||||||
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
import numpy as np
|
|
||||||
from scipy.io import wavfile
|
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
sys.path.append(now_dir)
|
sys.path.append(now_dir)
|
||||||
sys.path.append(f"{now_dir}/GPT_SoVITS/eres2net")
|
sys.path.append(f"{now_dir}/GPT_SoVITS/eres2net")
|
||||||
from tools.my_utils import load_audio, clean_path
|
from tools.my_utils import clean_path
|
||||||
from time import time as ttime
|
from time import time as ttime
|
||||||
import shutil
|
import shutil
|
||||||
from ERes2NetV2 import ERes2NetV2
|
from ERes2NetV2 import ERes2NetV2
|
||||||
import kaldi as Kaldi
|
import kaldi as Kaldi
|
||||||
|
|
||||||
|
|
||||||
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
||||||
dir = os.path.dirname(path)
|
dir = os.path.dirname(path)
|
||||||
name = os.path.basename(path)
|
name = os.path.basename(path)
|
||||||
@ -56,9 +54,10 @@ if torch.cuda.is_available():
|
|||||||
else:
|
else:
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
|
||||||
|
|
||||||
class SV:
|
class SV:
|
||||||
def __init__(self, device, is_half):
|
def __init__(self, device, is_half):
|
||||||
pretrained_state = torch.load(sv_path, map_location='cpu')
|
pretrained_state = torch.load(sv_path, map_location="cpu")
|
||||||
embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
|
embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
|
||||||
embedding_model.load_state_dict(pretrained_state)
|
embedding_model.load_state_dict(pretrained_state)
|
||||||
embedding_model.eval()
|
embedding_model.eval()
|
||||||
@ -73,15 +72,22 @@ class SV:
|
|||||||
def compute_embedding3(self, wav): # (1,x)#-1~1
|
def compute_embedding3(self, wav): # (1,x)#-1~1
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
wav = self.res(wav)
|
wav = self.res(wav)
|
||||||
if self.is_half==True:wav=wav.half()
|
if self.is_half == True:
|
||||||
feat = torch.stack([Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav])
|
wav = wav.half()
|
||||||
|
feat = torch.stack(
|
||||||
|
[Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]
|
||||||
|
)
|
||||||
sv_emb = self.embedding_model.forward3(feat)
|
sv_emb = self.embedding_model.forward3(feat)
|
||||||
return sv_emb
|
return sv_emb
|
||||||
|
|
||||||
|
|
||||||
sv = SV(device, is_half)
|
sv = SV(device, is_half)
|
||||||
|
|
||||||
|
|
||||||
def name2go(wav_name, wav_path):
|
def name2go(wav_name, wav_path):
|
||||||
sv_cn_path = "%s/%s.pt" % (sv_cn_dir, wav_name)
|
sv_cn_path = "%s/%s.pt" % (sv_cn_dir, wav_name)
|
||||||
if os.path.exists(sv_cn_path):return
|
if os.path.exists(sv_cn_path):
|
||||||
|
return
|
||||||
wav_path = "%s/%s" % (wav32dir, wav_name)
|
wav_path = "%s/%s" % (wav32dir, wav_name)
|
||||||
wav32k, sr0 = torchaudio.load(wav_path)
|
wav32k, sr0 = torchaudio.load(wav_path)
|
||||||
assert sr0 == 32000
|
assert sr0 == 32000
|
||||||
|
@ -17,7 +17,6 @@ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
|||||||
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
model_version2byte = {
|
model_version2byte = {
|
||||||
@ -26,6 +25,8 @@ model_version2byte={
|
|||||||
"v2Pro": b"05",
|
"v2Pro": b"05",
|
||||||
"v2ProPlus": b"06",
|
"v2ProPlus": b"06",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def my_save2(fea, path, model_version):
|
def my_save2(fea, path, model_version):
|
||||||
bio = BytesIO()
|
bio = BytesIO()
|
||||||
torch.save(fea, bio)
|
torch.save(fea, bio)
|
||||||
@ -50,7 +51,7 @@ def savee(ckpt, name, epoch, steps, hps, model_version=None, lora_rank=None):
|
|||||||
if lora_rank:
|
if lora_rank:
|
||||||
opt["lora_rank"] = lora_rank
|
opt["lora_rank"] = lora_rank
|
||||||
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
|
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
|
||||||
elif (model_version!=None and "Pro"in model_version):
|
elif model_version != None and "Pro" in model_version:
|
||||||
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
|
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
|
||||||
else:
|
else:
|
||||||
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
|
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
|
||||||
@ -58,6 +59,7 @@ def savee(ckpt, name, epoch, steps, hps, model_version=None, lora_rank=None):
|
|||||||
except:
|
except:
|
||||||
return traceback.format_exc()
|
return traceback.format_exc()
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
00:v1
|
00:v1
|
||||||
01:v2
|
01:v2
|
||||||
|
@ -36,7 +36,7 @@ from module.models import (
|
|||||||
MultiPeriodDiscriminator,
|
MultiPeriodDiscriminator,
|
||||||
SynthesizerTrn,
|
SynthesizerTrn,
|
||||||
)
|
)
|
||||||
from process_ckpt import savee,my_save2
|
from process_ckpt import savee
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
torch.backends.cudnn.deterministic = False
|
torch.backends.cudnn.deterministic = False
|
||||||
@ -91,7 +91,26 @@ def run(rank, n_gpus, hps):
|
|||||||
train_sampler = DistributedBucketSampler(
|
train_sampler = DistributedBucketSampler(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
hps.train.batch_size,
|
hps.train.batch_size,
|
||||||
[32,300,400,500,600,700,800,900,1000,1100,1200,1300,1400,1500,1600,1700,1800,1900,],
|
[
|
||||||
|
32,
|
||||||
|
300,
|
||||||
|
400,
|
||||||
|
500,
|
||||||
|
600,
|
||||||
|
700,
|
||||||
|
800,
|
||||||
|
900,
|
||||||
|
1000,
|
||||||
|
1100,
|
||||||
|
1200,
|
||||||
|
1300,
|
||||||
|
1400,
|
||||||
|
1500,
|
||||||
|
1600,
|
||||||
|
1700,
|
||||||
|
1800,
|
||||||
|
1900,
|
||||||
|
],
|
||||||
num_replicas=n_gpus,
|
num_replicas=n_gpus,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
@ -315,12 +334,39 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
|||||||
else:
|
else:
|
||||||
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data
|
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
spec, spec_lengths = (spec.cuda(rank,non_blocking=True,),spec_lengths.cuda(rank,non_blocking=True,),)
|
spec, spec_lengths = (
|
||||||
y, y_lengths = (y.cuda(rank,non_blocking=True,),y_lengths.cuda(rank,non_blocking=True,),)
|
spec.cuda(
|
||||||
|
rank,
|
||||||
|
non_blocking=True,
|
||||||
|
),
|
||||||
|
spec_lengths.cuda(
|
||||||
|
rank,
|
||||||
|
non_blocking=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
y, y_lengths = (
|
||||||
|
y.cuda(
|
||||||
|
rank,
|
||||||
|
non_blocking=True,
|
||||||
|
),
|
||||||
|
y_lengths.cuda(
|
||||||
|
rank,
|
||||||
|
non_blocking=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
ssl = ssl.cuda(rank, non_blocking=True)
|
ssl = ssl.cuda(rank, non_blocking=True)
|
||||||
ssl.requires_grad = False
|
ssl.requires_grad = False
|
||||||
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
|
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
|
||||||
text, text_lengths = (text.cuda(rank,non_blocking=True,),text_lengths.cuda(rank,non_blocking=True,),)
|
text, text_lengths = (
|
||||||
|
text.cuda(
|
||||||
|
rank,
|
||||||
|
non_blocking=True,
|
||||||
|
),
|
||||||
|
text_lengths.cuda(
|
||||||
|
rank,
|
||||||
|
non_blocking=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
if hps.model.version in {"v2Pro", "v2ProPlus"}:
|
if hps.model.version in {"v2Pro", "v2ProPlus"}:
|
||||||
sv_emb = sv_emb.cuda(rank, non_blocking=True)
|
sv_emb = sv_emb.cuda(rank, non_blocking=True)
|
||||||
else:
|
else:
|
||||||
@ -334,9 +380,19 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
|||||||
sv_emb = sv_emb.to(device)
|
sv_emb = sv_emb.to(device)
|
||||||
with autocast(enabled=hps.train.fp16_run):
|
with autocast(enabled=hps.train.fp16_run):
|
||||||
if hps.model.version in {"v2Pro", "v2ProPlus"}:
|
if hps.model.version in {"v2Pro", "v2ProPlus"}:
|
||||||
(y_hat,kl_ssl,ids_slice,x_mask,z_mask,(z, z_p, m_p, logs_p, m_q, logs_q),stats_ssl) = net_g(ssl, spec, spec_lengths, text, text_lengths,sv_emb)
|
(y_hat, kl_ssl, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl) = net_g(
|
||||||
|
ssl, spec, spec_lengths, text, text_lengths, sv_emb
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
(y_hat,kl_ssl,ids_slice,x_mask,z_mask,(z, z_p, m_p, logs_p, m_q, logs_q),stats_ssl,) = net_g(ssl, spec, spec_lengths, text, text_lengths)
|
(
|
||||||
|
y_hat,
|
||||||
|
kl_ssl,
|
||||||
|
ids_slice,
|
||||||
|
x_mask,
|
||||||
|
z_mask,
|
||||||
|
(z, z_p, m_p, logs_p, m_q, logs_q),
|
||||||
|
stats_ssl,
|
||||||
|
) = net_g(ssl, spec, spec_lengths, text, text_lengths)
|
||||||
|
|
||||||
mel = spec_to_mel_torch(
|
mel = spec_to_mel_torch(
|
||||||
spec,
|
spec,
|
||||||
@ -508,7 +564,14 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
|
|||||||
% (
|
% (
|
||||||
hps.name,
|
hps.name,
|
||||||
epoch,
|
epoch,
|
||||||
savee(ckpt,hps.name + "_e%s_s%s" % (epoch, global_step),epoch,global_step,hps,model_version=None if hps.model.version not in {"v2Pro","v2ProPlus"}else hps.model.version),
|
savee(
|
||||||
|
ckpt,
|
||||||
|
hps.name + "_e%s_s%s" % (epoch, global_step),
|
||||||
|
epoch,
|
||||||
|
global_step,
|
||||||
|
hps,
|
||||||
|
model_version=None if hps.model.version not in {"v2Pro", "v2ProPlus"} else hps.model.version,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,11 +1,16 @@
|
|||||||
import sys,os,torch
|
import sys
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
|
||||||
sys.path.append(f"{os.getcwd()}/GPT_SoVITS/eres2net")
|
sys.path.append(f"{os.getcwd()}/GPT_SoVITS/eres2net")
|
||||||
sv_path = "GPT_SoVITS/pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
|
sv_path = "GPT_SoVITS/pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
|
||||||
from ERes2NetV2 import ERes2NetV2
|
from ERes2NetV2 import ERes2NetV2
|
||||||
import kaldi as Kaldi
|
import kaldi as Kaldi
|
||||||
|
|
||||||
|
|
||||||
class SV:
|
class SV:
|
||||||
def __init__(self, device, is_half):
|
def __init__(self, device, is_half):
|
||||||
pretrained_state = torch.load(sv_path, map_location='cpu', weights_only=False)
|
pretrained_state = torch.load(sv_path, map_location="cpu", weights_only=False)
|
||||||
embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
|
embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
|
||||||
embedding_model.load_state_dict(pretrained_state)
|
embedding_model.load_state_dict(pretrained_state)
|
||||||
embedding_model.eval()
|
embedding_model.eval()
|
||||||
@ -18,7 +23,10 @@ class SV:
|
|||||||
|
|
||||||
def compute_embedding3(self, wav):
|
def compute_embedding3(self, wav):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.is_half==True:wav=wav.half()
|
if self.is_half == True:
|
||||||
feat = torch.stack([Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav])
|
wav = wav.half()
|
||||||
|
feat = torch.stack(
|
||||||
|
[Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]
|
||||||
|
)
|
||||||
sv_emb = self.embedding_model.forward3(feat)
|
sv_emb = self.embedding_model.forward3(feat)
|
||||||
return sv_emb
|
return sv_emb
|
||||||
|
@ -3,19 +3,25 @@ import re
|
|||||||
|
|
||||||
# jieba静音
|
# jieba静音
|
||||||
import jieba
|
import jieba
|
||||||
|
|
||||||
jieba.setLogLevel(logging.CRITICAL)
|
jieba.setLogLevel(logging.CRITICAL)
|
||||||
|
|
||||||
# 更改fast_langdetect大模型位置
|
# 更改fast_langdetect大模型位置
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import fast_langdetect
|
import fast_langdetect
|
||||||
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
|
|
||||||
|
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(
|
||||||
|
fast_langdetect.infer.LangDetectConfig(
|
||||||
|
cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
from split_lang import LangSplitter
|
from split_lang import LangSplitter
|
||||||
|
|
||||||
|
|
||||||
def full_en(text):
|
def full_en(text):
|
||||||
pattern = r'^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$'
|
pattern = r"^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$"
|
||||||
return bool(re.match(pattern, text))
|
return bool(re.match(pattern, text))
|
||||||
|
|
||||||
|
|
||||||
@ -34,7 +40,7 @@ def full_cjk(text):
|
|||||||
(0x2EBF0, 0x2EE5D), # CJK Extension H
|
(0x2EBF0, 0x2EE5D), # CJK Extension H
|
||||||
]
|
]
|
||||||
|
|
||||||
pattern = r'[0-9、-〜。!?.!?… /]+$'
|
pattern = r"[0-9、-〜。!?.!?… /]+$"
|
||||||
|
|
||||||
cjk_text = ""
|
cjk_text = ""
|
||||||
for char in text:
|
for char in text:
|
||||||
@ -53,28 +59,28 @@ def split_jako(tag_lang,item):
|
|||||||
|
|
||||||
lang_list: list[dict] = []
|
lang_list: list[dict] = []
|
||||||
tag = 0
|
tag = 0
|
||||||
for match in re.finditer(pattern, item['text']):
|
for match in re.finditer(pattern, item["text"]):
|
||||||
if match.start() > tag:
|
if match.start() > tag:
|
||||||
lang_list.append({'lang':item['lang'],'text':item['text'][tag:match.start()]})
|
lang_list.append({"lang": item["lang"], "text": item["text"][tag : match.start()]})
|
||||||
|
|
||||||
tag = match.end()
|
tag = match.end()
|
||||||
lang_list.append({'lang':tag_lang,'text':item['text'][match.start():match.end()]})
|
lang_list.append({"lang": tag_lang, "text": item["text"][match.start() : match.end()]})
|
||||||
|
|
||||||
if tag < len(item['text']):
|
if tag < len(item["text"]):
|
||||||
lang_list.append({'lang':item['lang'],'text':item['text'][tag:len(item['text'])]})
|
lang_list.append({"lang": item["lang"], "text": item["text"][tag : len(item["text"])]})
|
||||||
|
|
||||||
return lang_list
|
return lang_list
|
||||||
|
|
||||||
|
|
||||||
def merge_lang(lang_list, item):
|
def merge_lang(lang_list, item):
|
||||||
if lang_list and item['lang'] == lang_list[-1]['lang']:
|
if lang_list and item["lang"] == lang_list[-1]["lang"]:
|
||||||
lang_list[-1]['text'] += item['text']
|
lang_list[-1]["text"] += item["text"]
|
||||||
else:
|
else:
|
||||||
lang_list.append(item)
|
lang_list.append(item)
|
||||||
return lang_list
|
return lang_list
|
||||||
|
|
||||||
|
|
||||||
class LangSegmenter():
|
class LangSegmenter:
|
||||||
# 默认过滤器, 基于gsv目前四种语言
|
# 默认过滤器, 基于gsv目前四种语言
|
||||||
DEFAULT_LANG_MAP = {
|
DEFAULT_LANG_MAP = {
|
||||||
"zh": "zh",
|
"zh": "zh",
|
||||||
@ -87,7 +93,6 @@ class LangSegmenter():
|
|||||||
"en": "en",
|
"en": "en",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def getTexts(text):
|
def getTexts(text):
|
||||||
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
|
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
|
||||||
substr = lang_splitter.split_by_lang(text=text)
|
substr = lang_splitter.split_by_lang(text=text)
|
||||||
@ -95,18 +100,18 @@ class LangSegmenter():
|
|||||||
lang_list: list[dict] = []
|
lang_list: list[dict] = []
|
||||||
|
|
||||||
for _, item in enumerate(substr):
|
for _, item in enumerate(substr):
|
||||||
dict_item = {'lang':item.lang,'text':item.text}
|
dict_item = {"lang": item.lang, "text": item.text}
|
||||||
|
|
||||||
# 处理短英文被识别为其他语言的问题
|
# 处理短英文被识别为其他语言的问题
|
||||||
if full_en(dict_item['text']):
|
if full_en(dict_item["text"]):
|
||||||
dict_item['lang'] = 'en'
|
dict_item["lang"] = "en"
|
||||||
lang_list = merge_lang(lang_list, dict_item)
|
lang_list = merge_lang(lang_list, dict_item)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 处理非日语夹日文的问题(不包含CJK)
|
# 处理非日语夹日文的问题(不包含CJK)
|
||||||
ja_list: list[dict] = []
|
ja_list: list[dict] = []
|
||||||
if dict_item['lang'] != 'ja':
|
if dict_item["lang"] != "ja":
|
||||||
ja_list = split_jako('ja',dict_item)
|
ja_list = split_jako("ja", dict_item)
|
||||||
|
|
||||||
if not ja_list:
|
if not ja_list:
|
||||||
ja_list.append(dict_item)
|
ja_list.append(dict_item)
|
||||||
@ -115,8 +120,8 @@ class LangSegmenter():
|
|||||||
ko_list: list[dict] = []
|
ko_list: list[dict] = []
|
||||||
temp_list: list[dict] = []
|
temp_list: list[dict] = []
|
||||||
for _, ko_item in enumerate(ja_list):
|
for _, ko_item in enumerate(ja_list):
|
||||||
if ko_item["lang"] != 'ko':
|
if ko_item["lang"] != "ko":
|
||||||
ko_list = split_jako('ko',ko_item)
|
ko_list = split_jako("ko", ko_item)
|
||||||
|
|
||||||
if ko_list:
|
if ko_list:
|
||||||
temp_list.extend(ko_list)
|
temp_list.extend(ko_list)
|
||||||
@ -126,10 +131,10 @@ class LangSegmenter():
|
|||||||
# 未存在非日韩文夹日韩文
|
# 未存在非日韩文夹日韩文
|
||||||
if len(temp_list) == 1:
|
if len(temp_list) == 1:
|
||||||
# 未知语言检查是否为CJK
|
# 未知语言检查是否为CJK
|
||||||
if dict_item['lang'] == 'x':
|
if dict_item["lang"] == "x":
|
||||||
cjk_text = full_cjk(dict_item['text'])
|
cjk_text = full_cjk(dict_item["text"])
|
||||||
if cjk_text:
|
if cjk_text:
|
||||||
dict_item = {'lang':'zh','text':cjk_text}
|
dict_item = {"lang": "zh", "text": cjk_text}
|
||||||
lang_list = merge_lang(lang_list, dict_item)
|
lang_list = merge_lang(lang_list, dict_item)
|
||||||
else:
|
else:
|
||||||
lang_list = merge_lang(lang_list, dict_item)
|
lang_list = merge_lang(lang_list, dict_item)
|
||||||
@ -141,10 +146,10 @@ class LangSegmenter():
|
|||||||
# 存在非日韩文夹日韩文
|
# 存在非日韩文夹日韩文
|
||||||
for _, temp_item in enumerate(temp_list):
|
for _, temp_item in enumerate(temp_list):
|
||||||
# 未知语言检查是否为CJK
|
# 未知语言检查是否为CJK
|
||||||
if temp_item['lang'] == 'x':
|
if temp_item["lang"] == "x":
|
||||||
cjk_text = full_cjk(dict_item['text'])
|
cjk_text = full_cjk(dict_item["text"])
|
||||||
if cjk_text:
|
if cjk_text:
|
||||||
dict_item = {'lang':'zh','text':cjk_text}
|
dict_item = {"lang": "zh", "text": cjk_text}
|
||||||
lang_list = merge_lang(lang_list, dict_item)
|
lang_list = merge_lang(lang_list, dict_item)
|
||||||
else:
|
else:
|
||||||
lang_list = merge_lang(lang_list, dict_item)
|
lang_list = merge_lang(lang_list, dict_item)
|
||||||
@ -154,13 +159,13 @@ class LangSegmenter():
|
|||||||
temp_list = lang_list
|
temp_list = lang_list
|
||||||
lang_list = []
|
lang_list = []
|
||||||
for _, temp_item in enumerate(temp_list):
|
for _, temp_item in enumerate(temp_list):
|
||||||
if temp_item['lang'] == 'x':
|
if temp_item["lang"] == "x":
|
||||||
if lang_list:
|
if lang_list:
|
||||||
temp_item['lang'] = lang_list[-1]['lang']
|
temp_item["lang"] = lang_list[-1]["lang"]
|
||||||
elif len(temp_list) > 1:
|
elif len(temp_list) > 1:
|
||||||
temp_item['lang'] = temp_list[1]['lang']
|
temp_item["lang"] = temp_list[1]["lang"]
|
||||||
else:
|
else:
|
||||||
temp_item['lang'] = 'zh'
|
temp_item["lang"] = "zh"
|
||||||
|
|
||||||
lang_list = merge_lang(lang_list, temp_item)
|
lang_list = merge_lang(lang_list, temp_item)
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import traceback
|
|
||||||
import warnings
|
import warnings
|
||||||
import zipfile
|
import zipfile
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
@ -655,11 +655,7 @@ class ToneSandhi:
|
|||||||
while i < len(seg):
|
while i < len(seg):
|
||||||
word, pos = seg[i]
|
word, pos = seg[i]
|
||||||
merged = False
|
merged = False
|
||||||
if (
|
if i - 1 >= 0 and word == "一" and i + 1 < len(seg):
|
||||||
i - 1 >= 0
|
|
||||||
and word == "一"
|
|
||||||
and i + 1 < len(seg)
|
|
||||||
):
|
|
||||||
last = new_seg[-1] if new_seg else seg[i - 1]
|
last = new_seg[-1] if new_seg else seg[i - 1]
|
||||||
if last[0] == seg[i + 1][0] and last[1] == "v" and seg[i + 1][1] == "v":
|
if last[0] == seg[i + 1][0] and last[1] == "v" and seg[i + 1][1] == "v":
|
||||||
combined = last[0] + "一" + seg[i + 1][0]
|
combined = last[0] + "一" + seg[i + 1][0]
|
||||||
|
38
api.py
38
api.py
@ -199,6 +199,8 @@ def is_full(*items): # 任意一项为空返回False
|
|||||||
|
|
||||||
|
|
||||||
bigvgan_model = hifigan_model = sv_cn_model = None
|
bigvgan_model = hifigan_model = sv_cn_model = None
|
||||||
|
|
||||||
|
|
||||||
def clean_hifigan_model():
|
def clean_hifigan_model():
|
||||||
global hifigan_model
|
global hifigan_model
|
||||||
if hifigan_model:
|
if hifigan_model:
|
||||||
@ -208,6 +210,8 @@ def clean_hifigan_model():
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def clean_bigvgan_model():
|
def clean_bigvgan_model():
|
||||||
global bigvgan_model
|
global bigvgan_model
|
||||||
if bigvgan_model:
|
if bigvgan_model:
|
||||||
@ -217,6 +221,8 @@ def clean_bigvgan_model():
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def clean_sv_cn_model():
|
def clean_sv_cn_model():
|
||||||
global sv_cn_model
|
global sv_cn_model
|
||||||
if sv_cn_model:
|
if sv_cn_model:
|
||||||
@ -262,7 +268,9 @@ def init_hifigan():
|
|||||||
hifigan_model.eval()
|
hifigan_model.eval()
|
||||||
hifigan_model.remove_weight_norm()
|
hifigan_model.remove_weight_norm()
|
||||||
state_dict_g = torch.load(
|
state_dict_g = torch.load(
|
||||||
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
|
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,),
|
||||||
|
map_location="cpu",
|
||||||
|
weights_only=False,
|
||||||
)
|
)
|
||||||
print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
|
print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
|
||||||
if is_half == True:
|
if is_half == True:
|
||||||
@ -272,19 +280,21 @@ def init_hifigan():
|
|||||||
|
|
||||||
|
|
||||||
from sv import SV
|
from sv import SV
|
||||||
|
|
||||||
|
|
||||||
def init_sv_cn():
|
def init_sv_cn():
|
||||||
global hifigan_model, bigvgan_model, sv_cn_model
|
global hifigan_model, bigvgan_model, sv_cn_model
|
||||||
sv_cn_model = SV(device, is_half)
|
sv_cn_model = SV(device, is_half)
|
||||||
|
|
||||||
|
|
||||||
resample_transform_dict = {}
|
resample_transform_dict = {}
|
||||||
|
|
||||||
|
|
||||||
def resample(audio_tensor, sr0, sr1, device):
|
def resample(audio_tensor, sr0, sr1, device):
|
||||||
global resample_transform_dict
|
global resample_transform_dict
|
||||||
key = "%s-%s-%s" % (sr0, sr1, str(device))
|
key = "%s-%s-%s" % (sr0, sr1, str(device))
|
||||||
if key not in resample_transform_dict:
|
if key not in resample_transform_dict:
|
||||||
resample_transform_dict[key] = torchaudio.transforms.Resample(
|
resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
|
||||||
sr0, sr1
|
|
||||||
).to(device)
|
|
||||||
return resample_transform_dict[key](audio_tensor)
|
return resample_transform_dict[key](audio_tensor)
|
||||||
|
|
||||||
|
|
||||||
@ -370,6 +380,7 @@ from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
|
|||||||
|
|
||||||
def get_sovits_weights(sovits_path):
|
def get_sovits_weights(sovits_path):
|
||||||
from config import pretrained_sovits_name
|
from config import pretrained_sovits_name
|
||||||
|
|
||||||
path_sovits_v3 = pretrained_sovits_name["v3"]
|
path_sovits_v3 = pretrained_sovits_name["v3"]
|
||||||
path_sovits_v4 = pretrained_sovits_name["v4"]
|
path_sovits_v4 = pretrained_sovits_name["v4"]
|
||||||
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
||||||
@ -632,11 +643,13 @@ def get_spepc(hps, filename, dtype, device, is_v2pro=False):
|
|||||||
audio, sr0 = torchaudio.load(filename)
|
audio, sr0 = torchaudio.load(filename)
|
||||||
if sr0 != sr1:
|
if sr0 != sr1:
|
||||||
audio = audio.to(device)
|
audio = audio.to(device)
|
||||||
if(audio.shape[0]==2):audio=audio.mean(0).unsqueeze(0)
|
if audio.shape[0] == 2:
|
||||||
|
audio = audio.mean(0).unsqueeze(0)
|
||||||
audio = resample(audio, sr0, sr1, device)
|
audio = resample(audio, sr0, sr1, device)
|
||||||
else:
|
else:
|
||||||
audio = audio.to(device)
|
audio = audio.to(device)
|
||||||
if(audio.shape[0]==2):audio=audio.mean(0).unsqueeze(0)
|
if audio.shape[0] == 2:
|
||||||
|
audio = audio.mean(0).unsqueeze(0)
|
||||||
|
|
||||||
maxx = audio.abs().max()
|
maxx = audio.abs().max()
|
||||||
if maxx > 1:
|
if maxx > 1:
|
||||||
@ -937,14 +950,22 @@ def get_tts_wav(
|
|||||||
if version not in {"v3", "v4"}:
|
if version not in {"v3", "v4"}:
|
||||||
if is_v2pro:
|
if is_v2pro:
|
||||||
audio = (
|
audio = (
|
||||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed,sv_emb=sv_emb)
|
vq_model.decode(
|
||||||
|
pred_semantic,
|
||||||
|
torch.LongTensor(phones2).to(device).unsqueeze(0),
|
||||||
|
refers,
|
||||||
|
speed=speed,
|
||||||
|
sv_emb=sv_emb,
|
||||||
|
)
|
||||||
.detach()
|
.detach()
|
||||||
.cpu()
|
.cpu()
|
||||||
.numpy()[0, 0]
|
.numpy()[0, 0]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
audio = (
|
audio = (
|
||||||
vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)
|
vq_model.decode(
|
||||||
|
pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
|
||||||
|
)
|
||||||
.detach()
|
.detach()
|
||||||
.cpu()
|
.cpu()
|
||||||
.numpy()[0, 0]
|
.numpy()[0, 0]
|
||||||
@ -1108,7 +1129,6 @@ def handle(
|
|||||||
if not default_refer.is_ready():
|
if not default_refer.is_ready():
|
||||||
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
|
return JSONResponse({"code": 400, "message": "未指定参考音频且接口无预设"}, status_code=400)
|
||||||
|
|
||||||
|
|
||||||
if cut_punc == None:
|
if cut_punc == None:
|
||||||
text = cut_text(text, default_cut_punc)
|
text = cut_text(text, default_cut_punc)
|
||||||
else:
|
else:
|
||||||
|
10
config.py
10
config.py
@ -144,6 +144,7 @@ webui_port_subfix = 9871
|
|||||||
|
|
||||||
api_port = 9880
|
api_port = 9880
|
||||||
|
|
||||||
|
|
||||||
# Thanks to the contribution of @Karasukaigan and @XXXXRT666
|
# Thanks to the contribution of @Karasukaigan and @XXXXRT666
|
||||||
def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
|
def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, float]:
|
||||||
cpu = torch.device("cpu")
|
cpu = torch.device("cpu")
|
||||||
@ -158,9 +159,12 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, flo
|
|||||||
major, minor = capability
|
major, minor = capability
|
||||||
sm_version = major + minor / 10.0
|
sm_version = major + minor / 10.0
|
||||||
is_16_series = bool(re.search(r"16\d{2}", name)) and sm_version == 7.5
|
is_16_series = bool(re.search(r"16\d{2}", name)) and sm_version == 7.5
|
||||||
if mem_gb < 4 or sm_version < 5.3:return cpu, torch.float32, 0.0, 0.0
|
if mem_gb < 4 or sm_version < 5.3:
|
||||||
if sm_version == 6.1 or is_16_series==True:return cuda, torch.float32, sm_version, mem_gb
|
return cpu, torch.float32, 0.0, 0.0
|
||||||
if sm_version > 6.1:return cuda, torch.float16, sm_version, mem_gb
|
if sm_version == 6.1 or is_16_series == True:
|
||||||
|
return cuda, torch.float32, sm_version, mem_gb
|
||||||
|
if sm_version > 6.1:
|
||||||
|
return cuda, torch.float16, sm_version, mem_gb
|
||||||
return cpu, torch.float32, 0.0, 0.0
|
return cpu, torch.float32, 0.0, 0.0
|
||||||
|
|
||||||
|
|
||||||
|
@ -190,14 +190,14 @@ class Predictor:
|
|||||||
opt_path_vocal = path_vocal[:-4] + ".%s" % format
|
opt_path_vocal = path_vocal[:-4] + ".%s" % format
|
||||||
opt_path_other = path_other[:-4] + ".%s" % format
|
opt_path_other = path_other[:-4] + ".%s" % format
|
||||||
if os.path.exists(path_vocal):
|
if os.path.exists(path_vocal):
|
||||||
os.system("ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path_vocal, opt_path_vocal))
|
os.system('ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path_vocal, opt_path_vocal))
|
||||||
if os.path.exists(opt_path_vocal):
|
if os.path.exists(opt_path_vocal):
|
||||||
try:
|
try:
|
||||||
os.remove(path_vocal)
|
os.remove(path_vocal)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
if os.path.exists(path_other):
|
if os.path.exists(path_other):
|
||||||
os.system("ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path_other, opt_path_other))
|
os.system('ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path_other, opt_path_other))
|
||||||
if os.path.exists(opt_path_other):
|
if os.path.exists(opt_path_other):
|
||||||
try:
|
try:
|
||||||
os.remove(path_other)
|
os.remove(path_other)
|
||||||
|
@ -140,7 +140,7 @@ class AudioPre:
|
|||||||
)
|
)
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
opt_format_path = path[:-4] + ".%s" % format
|
opt_format_path = path[:-4] + ".%s" % format
|
||||||
cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
|
cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
|
||||||
print(cmd)
|
print(cmd)
|
||||||
os.system(cmd)
|
os.system(cmd)
|
||||||
if os.path.exists(opt_format_path):
|
if os.path.exists(opt_format_path):
|
||||||
@ -177,7 +177,7 @@ class AudioPre:
|
|||||||
)
|
)
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
opt_format_path = path[:-4] + ".%s" % format
|
opt_format_path = path[:-4] + ".%s" % format
|
||||||
cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
|
cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
|
||||||
print(cmd)
|
print(cmd)
|
||||||
os.system(cmd)
|
os.system(cmd)
|
||||||
if os.path.exists(opt_format_path):
|
if os.path.exists(opt_format_path):
|
||||||
@ -307,7 +307,7 @@ class AudioPreDeEcho:
|
|||||||
)
|
)
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
opt_format_path = path[:-4] + ".%s" % format
|
opt_format_path = path[:-4] + ".%s" % format
|
||||||
cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
|
cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
|
||||||
print(cmd)
|
print(cmd)
|
||||||
os.system(cmd)
|
os.system(cmd)
|
||||||
if os.path.exists(opt_format_path):
|
if os.path.exists(opt_format_path):
|
||||||
@ -340,7 +340,7 @@ class AudioPreDeEcho:
|
|||||||
)
|
)
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
opt_format_path = path[:-4] + ".%s" % format
|
opt_format_path = path[:-4] + ".%s" % format
|
||||||
cmd="ffmpeg -i \"%s\" -vn \"%s\" -q:a 2 -y" % (path, opt_format_path)
|
cmd = 'ffmpeg -i "%s" -vn "%s" -q:a 2 -y' % (path, opt_format_path)
|
||||||
print(cmd)
|
print(cmd)
|
||||||
os.system(cmd)
|
os.system(cmd)
|
||||||
if os.path.exists(opt_format_path):
|
if os.path.exists(opt_format_path):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user