Merge pull request #1 from XXXXRT666/XXXXRT666-patch-1

Update .pre-commit-config.yaml
XXXXRT666 2025-05-01 12:29:38 +01:00 committed by GitHub
commit c26354b1b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 292 additions and 400 deletions

View File

@@ -13,9 +13,3 @@ repos:
       - id: ruff-format
         types_or: [ python, pyi ]
         args: [ --line-length, "120", --target-version, "py310" ]
-
-  # - repo: https://github.com/codespell-project/codespell
-  #   rev: v2.4.1
-  #   hooks:
-  #     - id: codespell
-  #       files: ^.*\.(py|md)$

View File

@@ -108,7 +108,7 @@ resample_transform_dict = {}
 def resample(audio_tensor, sr0, sr1, device):
     global resample_transform_dict
-    key="%s-%s"%(sr0,sr1)
+    key = "%s-%s" % (sr0, sr1)
     if key not in resample_transform_dict:
         resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
     return resample_transform_dict[key](audio_tensor)
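The hunk above only reformats a small caching helper, but the pattern is worth noting: constructing a `torchaudio.transforms.Resample` kernel is comparatively expensive, so one instance is memoized per `(sr0, sr1)` pair and reused. A minimal standalone sketch of the same idea (helper and variable names here are illustrative, not from the repo):

```python
import torch
import torchaudio

_resamplers: dict = {}  # one Resample module per (src_rate, dst_rate) pair

def cached_resample(audio: torch.Tensor, sr0: int, sr1: int, device: str = "cpu") -> torch.Tensor:
    # Building the resampling kernel is the costly part; do it once per rate pair.
    key = (sr0, sr1)
    if key not in _resamplers:
        _resamplers[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
    return _resamplers[key](audio.to(device))

# Example: one second of 48 kHz noise down to 24 kHz.
wav = torch.randn(1, 48000)
print(cached_resample(wav, 48000, 24000).shape)  # torch.Size([1, 24000])
```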
@@ -252,7 +252,6 @@ class TTS_Config:
             "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
             "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
         },
     }
-
     configs: dict = None
     v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
@@ -432,7 +431,6 @@ class TTS:
             "aux_ref_audio_paths": [],
         }
-
         self.stop_flag: bool = False
         self.precision: torch.dtype = torch.float16 if self.configs.is_half else torch.float32
@@ -468,7 +466,7 @@ class TTS:
         path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
         if if_lora_v3 == True and os.path.exists(path_sovits) == False:
-            info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重"%model_version)
+            info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
             raise FileExistsError(info)
         # dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
@@ -507,7 +505,7 @@ class TTS:
             )
             self.configs.use_vocoder = False
         else:
-            kwargs["version"]=model_version
+            kwargs["version"] = model_version
             vits_model = SynthesizerTrnV3(
                 self.configs.filter_length // 2 + 1,
                 self.configs.segment_size // self.configs.hop_length,
@@ -572,7 +570,7 @@ class TTS:
                 self.vocoder.cpu()
                 del self.vocoder
                 self.empty_cache()
             self.vocoder = BigVGAN.from_pretrained(
                 "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
                 use_cuda_kernel=False,
@@ -595,18 +593,21 @@ class TTS:
                 self.empty_cache()
             self.vocoder = Generator(
                 initial_channel=100,
                 resblock="1",
                 resblock_kernel_sizes=[3, 7, 11],
                 resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                 upsample_rates=[10, 6, 2, 2, 2],
                 upsample_initial_channel=512,
                 upsample_kernel_sizes=[20, 12, 4, 4, 4],
-                gin_channels=0, is_bias=True
+                gin_channels=0,
+                is_bias=True,
             )
             self.vocoder.remove_weight_norm()
-            state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
-            print("loading vocoder",self.vocoder.load_state_dict(state_dict_g))
+            state_dict_g = torch.load(
+                "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu"
+            )
+            print("loading vocoder", self.vocoder.load_state_dict(state_dict_g))

             self.vocoder_configs["sr"] = 48000
             self.vocoder_configs["T_ref"] = 500
@@ -614,9 +615,6 @@ class TTS:
             self.vocoder_configs["upsample_rate"] = 480
             self.vocoder_configs["overlapped_len"] = 12
-
-
-
         self.vocoder = self.vocoder.eval()

         if self.configs.is_half == True:
             self.vocoder = self.vocoder.half().to(self.configs.device)
@@ -1439,7 +1437,7 @@ class TTS:
         ref_audio = ref_audio.to(self.configs.device).float()
         if ref_audio.shape[0] == 2:
             ref_audio = ref_audio.mean(0).unsqueeze(0)
         # tgt_sr = self.vocoder_configs["sr"]
         tgt_sr = 24000 if self.configs.version == "v3" else 32000
         if ref_sr != tgt_sr:

View File

@@ -106,7 +106,7 @@ cnhubert.cnhubert_base_path = cnhubert_base_path
 import random
-from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3,Generator
+from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3, Generator

 def set_seed(seed):
@@ -226,9 +226,9 @@ else:
 resample_transform_dict = {}

-def resample(audio_tensor, sr0,sr1):
+def resample(audio_tensor, sr0, sr1):
     global resample_transform_dict
-    key="%s-%s"%(sr0,sr1)
+    key = "%s-%s" % (sr0, sr1)
     if key not in resample_transform_dict:
         resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
     return resample_transform_dict[key](audio_tensor)
@@ -238,14 +238,18 @@ def resample(audio_tensor, sr0,sr1):
 # symbol_version-model_version-if_lora_v3
 from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new

-v3v4set={"v3","v4"}
+v3v4set = {"v3", "v4"}
+
+
 def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
     global vq_model, hps, version, model_version, dict_language, if_lora_v3
     version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
-    print(sovits_path,version, model_version, if_lora_v3)
-    is_exist=is_exist_s2gv3 if model_version=="v3"else is_exist_s2gv4
+    print(sovits_path, version, model_version, if_lora_v3)
+    is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
     if if_lora_v3 == True and is_exist == False:
-        info = "GPT_SoVITS/pretrained_models/s2Gv3.pth" + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重"%model_version)
+        info = "GPT_SoVITS/pretrained_models/s2Gv3.pth" + i18n(
+            "SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version
+        )
         gr.Warning(info)
         raise FileExistsError(info)
     dict_language = dict_language_v1 if version == "v1" else dict_language_v2
@@ -276,10 +280,15 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
                 prompt_language_update,
                 text_update,
                 text_language_update,
-                {"__type__": "update", "visible": visible_sample_steps, "value": 32 if model_version=="v3"else 8,"choices":[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32]},
+                {
+                    "__type__": "update",
+                    "visible": visible_sample_steps,
+                    "value": 32 if model_version == "v3" else 8,
+                    "choices": [4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
+                },
                 {"__type__": "update", "visible": visible_inp_refs},
                 {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
-                {"__type__": "update", "visible": True if model_version =="v3" else False},
+                {"__type__": "update", "visible": True if model_version == "v3" else False},
                 {"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
             )
@@ -304,7 +313,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
             )
             model_version = version
         else:
-            hps.model.version=model_version
+            hps.model.version = model_version
             vq_model = SynthesizerTrnV3(
                 hps.data.filter_length // 2 + 1,
                 hps.train.segment_size // hps.data.hop_length,
@@ -326,7 +335,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
         else:
             path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
             print(
-                "loading sovits_%spretrained_G"%model_version,
+                "loading sovits_%spretrained_G" % model_version,
                 vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False),
             )
         lora_rank = dict_s2["lora_rank"]
@@ -337,7 +346,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
             init_lora_weights=True,
         )
         vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
-        print("loading sovits_%s_lora%s" % (model_version,lora_rank))
+        print("loading sovits_%s_lora%s" % (model_version, lora_rank))
         vq_model.load_state_dict(dict_s2["weight"], strict=False)
         vq_model.cfm = vq_model.cfm.merge_and_unload()
         # torch.save(vq_model.state_dict(),"merge_win.pth")
@@ -350,10 +359,15 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
             prompt_language_update,
             text_update,
             text_language_update,
-            {"__type__": "update", "visible": visible_sample_steps, "value":32 if model_version=="v3"else 8,"choices":[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32]},
+            {
+                "__type__": "update",
+                "visible": visible_sample_steps,
+                "value": 32 if model_version == "v3" else 8,
+                "choices": [4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
+            },
             {"__type__": "update", "visible": visible_inp_refs},
             {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
-            {"__type__": "update", "visible": True if model_version =="v3" else False},
+            {"__type__": "update", "visible": True if model_version == "v3" else False},
             {"__type__": "update", "value": i18n("合成语音"), "interactive": True},
         )

 with open("./weight.json") as f:
@@ -400,7 +414,7 @@ now_dir = os.getcwd()

 def init_bigvgan():
-    global bigvgan_model,hifigan_model
+    global bigvgan_model, hifigan_model
     from BigVGAN import bigvgan

     bigvgan_model = bigvgan.BigVGAN.from_pretrained(
@@ -411,17 +425,20 @@ def init_bigvgan():
     bigvgan_model.remove_weight_norm()
     bigvgan_model = bigvgan_model.eval()
     if hifigan_model:
-        hifigan_model=hifigan_model.cpu()
-        hifigan_model=None
-        try:torch.cuda.empty_cache()
-        except:pass
+        hifigan_model = hifigan_model.cpu()
+        hifigan_model = None
+        try:
+            torch.cuda.empty_cache()
+        except:
+            pass
     if is_half == True:
         bigvgan_model = bigvgan_model.half().to(device)
     else:
         bigvgan_model = bigvgan_model.to(device)

+
 def init_hifigan():
-    global hifigan_model,bigvgan_model
+    global hifigan_model, bigvgan_model
     hifigan_model = Generator(
         initial_channel=100,
         resblock="1",
@@ -430,26 +447,32 @@ def init_hifigan():
         upsample_rates=[10, 6, 2, 2, 2],
         upsample_initial_channel=512,
         upsample_kernel_sizes=[20, 12, 4, 4, 4],
-        gin_channels=0, is_bias=True
+        gin_channels=0,
+        is_bias=True,
     )
     hifigan_model.eval()
     hifigan_model.remove_weight_norm()
-    state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
-    print("loading vocoder",hifigan_model.load_state_dict(state_dict_g))
+    state_dict_g = torch.load(
+        "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu"
+    )
+    print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
     if bigvgan_model:
-        bigvgan_model=bigvgan_model.cpu()
-        bigvgan_model=None
-        try:torch.cuda.empty_cache()
-        except:pass
+        bigvgan_model = bigvgan_model.cpu()
+        bigvgan_model = None
+        try:
+            torch.cuda.empty_cache()
+        except:
+            pass
     if is_half == True:
         hifigan_model = hifigan_model.half().to(device)
     else:
         hifigan_model = hifigan_model.to(device)

-bigvgan_model=hifigan_model=None
-if model_version=="v3":
+
+bigvgan_model = hifigan_model = None
+if model_version == "v3":
     init_bigvgan()
-if model_version=="v4":
+if model_version == "v4":
     init_hifigan()
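`init_bigvgan` and `init_hifigan` above are mutually exclusive: loading one vocoder first moves the other to CPU, drops the reference, and empties the CUDA cache, so only one model occupies GPU memory at a time. A hedged sketch of that release step in isolation (the function name is illustrative, not from the repo):

```python
import torch

def release_gpu_model(model):
    """Free a module's GPU memory before loading its replacement (sketch, assuming CUDA may be absent)."""
    if model is not None:
        model = model.cpu()  # move parameters off the GPU
        del model            # drop the reference so the allocator can reclaim the memory
    try:
        torch.cuda.empty_cache()  # guarded, mirroring the bare try/except in the diff
    except Exception:
        pass
```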
@@ -831,17 +854,17 @@ def get_tts_wav(
                 ref_audio = ref_audio.to(device).float()
                 if ref_audio.shape[0] == 2:
                     ref_audio = ref_audio.mean(0).unsqueeze(0)
-                tgt_sr=24000 if model_version=="v3"else 32000
+                tgt_sr = 24000 if model_version == "v3" else 32000
                 if sr != tgt_sr:
-                    ref_audio = resample(ref_audio, sr,tgt_sr)
+                    ref_audio = resample(ref_audio, sr, tgt_sr)
                 # print("ref_audio",ref_audio.abs().mean())
-                mel2 = mel_fn(ref_audio)if model_version=="v3"else mel_fn_v4(ref_audio)
+                mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
                 mel2 = norm_spec(mel2)
                 T_min = min(mel2.shape[2], fea_ref.shape[2])
                 mel2 = mel2[:, :, :T_min]
                 fea_ref = fea_ref[:, :, :T_min]
-                Tref=468 if model_version=="v3"else 500
-                Tchunk=934 if model_version=="v3"else 1000
+                Tref = 468 if model_version == "v3" else 500
+                Tchunk = 934 if model_version == "v3" else 1000
                 if T_min > Tref:
                     mel2 = mel2[:, :, -Tref:]
                     fea_ref = fea_ref[:, :, -Tref:]
@@ -866,13 +889,13 @@ def get_tts_wav(
                     cfm_resss.append(cfm_res)
                 cfm_res = torch.cat(cfm_resss, 2)
                 cfm_res = denorm_spec(cfm_res)
-                if model_version=="v3":
+                if model_version == "v3":
                     if bigvgan_model == None:
                         init_bigvgan()
-                else:#v4
+                else:  # v4
                     if hifigan_model == None:
                         init_hifigan()
-                vocoder_model=bigvgan_model if model_version=="v3"else hifigan_model
+                vocoder_model = bigvgan_model if model_version == "v3" else hifigan_model
                 with torch.inference_mode():
                     wav_gen = vocoder_model(cfm_res)
                     audio = wav_gen[0][0]  # .cpu().detach().numpy()
@@ -886,9 +909,12 @@ def get_tts_wav(
         t1 = ttime()
         print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
     audio_opt = torch.cat(audio_opt, 0)  # np.concatenate
-    if model_version in {"v1","v2"}:opt_sr=32000
-    elif model_version=="v3":opt_sr=24000
-    else:opt_sr=48000#v4
+    if model_version in {"v1", "v2"}:
+        opt_sr = 32000
+    elif model_version == "v3":
+        opt_sr = 24000
+    else:
+        opt_sr = 48000  # v4
     if if_sr == True and opt_sr == 24000:
         print(i18n("音频超分中"))
         audio_opt, opt_sr = audio_sr(audio_opt.unsqueeze(0), opt_sr)
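The reformatted branch above makes the output sample rates explicit: v1/v2 models synthesize at 32 kHz, v3 at 24 kHz (optionally super-resolved to 48 kHz by `audio_sr`), and v4 at 48 kHz natively. The same mapping as a table lookup, a hypothetical helper rather than code from the diff:

```python
# Output sample rate per model family, per the if/elif/else above.
OPT_SR = {"v1": 32000, "v2": 32000, "v3": 24000, "v4": 48000}

def output_sample_rate(model_version: str) -> int:
    return OPT_SR[model_version]

assert output_sample_rate("v3") == 24000  # may then be upsampled to 48 kHz when if_sr is set
```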
@@ -1131,16 +1157,16 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                     sample_steps = (
                         gr.Radio(
                             label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
-                            value=32 if model_version=="v3"else 8,
-                            choices=[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32],
+                            value=32 if model_version == "v3" else 8,
+                            choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
                             visible=True,
                         )
                         if model_version in v3v4set
                         else gr.Radio(
                             label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
-                            choices=[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32],
+                            choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
                             visible=False,
-                            value=32 if model_version=="v3"else 8,
+                            value=32 if model_version == "v3" else 8,
                         )
                     )
                     if_sr_Checkbox = gr.Checkbox(
@@ -1148,7 +1174,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                         value=False,
                         interactive=True,
                         show_label=True,
-                        visible=False if model_version !="v3" else True,
+                        visible=False if model_version != "v3" else True,
                     )
             gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3"))
             with gr.Row():

View File

@@ -262,15 +262,17 @@ SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
 from process_ckpt import get_sovits_version_from_path_fast

-v3v4set={"v3","v4"}
+v3v4set = {"v3", "v4"}
+
+
 def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
     global version, model_version, dict_language, if_lora_v3
     version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
     # print(sovits_path,version, model_version, if_lora_v3)
-    is_exist=is_exist_s2gv3 if model_version=="v3"else is_exist_s2gv4
+    is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
     path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
     if if_lora_v3 == True and is_exist == False:
-        info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重"%model_version)
+        info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
         gr.Warning(info)
         raise FileExistsError(info)
     dict_language = dict_language_v1 if version == "v1" else dict_language_v2

View File

@@ -470,6 +470,7 @@ class TextAudioSpeakerCollateV3:
         # return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
         return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths

+
 class TextAudioSpeakerLoaderV4(torch.utils.data.Dataset):
     """
     1) loads audio, speaker_id, text pairs
@@ -596,7 +597,7 @@ class TextAudioSpeakerLoaderV4(torch.utils.data.Dataset):
             audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
         )
         spec = torch.squeeze(spec, 0)
-        spec1 = spectrogram_torch(audio_norm, 1280,32000, 320, 1280,center=False)
+        spec1 = spectrogram_torch(audio_norm, 1280, 32000, 320, 1280, center=False)
         mel = spec_to_mel_torch(spec1, 1280, 100, 32000, 0, None)
         mel = self.norm_spec(torch.squeeze(mel, 0))
         return spec, mel
@@ -643,7 +644,7 @@ class TextAudioSpeakerCollateV4:
         mel_lengths = torch.LongTensor(len(batch))

         spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
-        mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_spec_len*2)
+        mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_spec_len * 2)
         ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
         text_padded = torch.LongTensor(len(batch), max_text_len)
         # wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)

View File

@@ -39,24 +39,36 @@ hann_window = {}
 def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
     if torch.min(y) < -1.2:
-        print('min value is ', torch.min(y))
+        print("min value is ", torch.min(y))
     if torch.max(y) > 1.2:
-        print('max value is ', torch.max(y))
+        print("max value is ", torch.max(y))

     global hann_window
-    dtype_device = str(y.dtype) + '_' + str(y.device)
+    dtype_device = str(y.dtype) + "_" + str(y.device)
     # wnsize_dtype_device = str(win_size) + '_' + dtype_device
-    key = "%s-%s-%s-%s-%s" %(dtype_device,n_fft, sampling_rate, hop_size, win_size)
+    key = "%s-%s-%s-%s-%s" % (dtype_device, n_fft, sampling_rate, hop_size, win_size)
     # if wnsize_dtype_device not in hann_window:
     if key not in hann_window:
         # hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
         hann_window[key] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

-    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+    )
     y = y.squeeze(1)
     # spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
-    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[key],
-                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
+    spec = torch.stft(
+        y,
+        n_fft,
+        hop_length=hop_size,
+        win_length=win_size,
+        window=hann_window[key],
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=False,
+    )
     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-8)
     return spec
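The reformatting above leaves the memoization logic intact: the Hann window is cached under a key that folds in dtype, device, and every STFT parameter, so a window created for one configuration is never reused for another. A small sketch of just the key construction (illustrative, extracted from the function above):

```python
import torch

hann_window = {}  # cache shared across calls, as in the diff

def get_hann(y: torch.Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int) -> torch.Tensor:
    # One window per (dtype, device, n_fft, sr, hop, win) combination.
    dtype_device = str(y.dtype) + "_" + str(y.device)
    key = "%s-%s-%s-%s-%s" % (dtype_device, n_fft, sampling_rate, hop_size, win_size)
    if key not in hann_window:
        hann_window[key] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
    return hann_window[key]
```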
@@ -64,9 +76,9 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
 def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
     global mel_basis
-    dtype_device = str(spec.dtype) + '_' + str(spec.device)
+    dtype_device = str(spec.dtype) + "_" + str(spec.device)
     # fmax_dtype_device = str(fmax) + '_' + dtype_device
-    key = "%s-%s-%s-%s-%s-%s"%(dtype_device,n_fft, num_mels, sampling_rate, fmin, fmax)
+    key = "%s-%s-%s-%s-%s-%s" % (dtype_device, n_fft, num_mels, sampling_rate, fmin, fmax)
     # if fmax_dtype_device not in mel_basis:
     if key not in mel_basis:
         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
@@ -78,17 +90,25 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
     return spec

 def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
     if torch.min(y) < -1.2:
-        print('min value is ', torch.min(y))
+        print("min value is ", torch.min(y))
     if torch.max(y) > 1.2:
-        print('max value is ', torch.max(y))
+        print("max value is ", torch.max(y))

     global mel_basis, hann_window
-    dtype_device = str(y.dtype) + '_' + str(y.device)
+    dtype_device = str(y.dtype) + "_" + str(y.device)
     # fmax_dtype_device = str(fmax) + '_' + dtype_device
-    fmax_dtype_device = "%s-%s-%s-%s-%s-%s-%s-%s"%(dtype_device,n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax)
+    fmax_dtype_device = "%s-%s-%s-%s-%s-%s-%s-%s" % (
+        dtype_device,
+        n_fft,
+        num_mels,
+        sampling_rate,
+        hop_size,
+        win_size,
+        fmin,
+        fmax,
+    )
     # wnsize_dtype_device = str(win_size) + '_' + dtype_device
     wnsize_dtype_device = fmax_dtype_device
     if fmax_dtype_device not in mel_basis:
@@ -97,11 +117,23 @@ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size,
     if wnsize_dtype_device not in hann_window:
         hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

-    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+    )
     y = y.squeeze(1)

-    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
-                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
+    spec = torch.stft(
+        y,
+        n_fft,
+        hop_length=hop_size,
+        win_length=win_size,
+        window=hann_window[wnsize_dtype_device],
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=False,
+    )
     spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-8)

View File

@@ -414,7 +414,8 @@ class Generator(torch.nn.Module):
         upsample_rates,
         upsample_initial_channel,
         upsample_kernel_sizes,
-        gin_channels=0,is_bias=False,
+        gin_channels=0,
+        is_bias=False,
     ):
         super(Generator, self).__init__()
         self.num_kernels = len(resblock_kernel_sizes)
@@ -1173,7 +1174,7 @@ class SynthesizerTrnV3(nn.Module):
         quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  ##BCT
         x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
         fea = self.bridge(x)
-        fea = F.interpolate(fea, scale_factor=(1.875 if self.version=="v3"else 2), mode="nearest")  ##BCT
+        fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest")  ##BCT
         fea, y_mask_ = self.wns1(
             fea, mel_lengths, ge
         )  ##If the 1-minute fine-tuning works fine, no need to manually adjust the learning rate.
@@ -1196,9 +1197,9 @@ class SynthesizerTrnV3(nn.Module):
         ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
         y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
         if speed == 1:
-            sizee = int(codes.size(2) * (3.875 if self.version=="v3"else 4))
+            sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4))
         else:
-            sizee = int(codes.size(2) * (3.875 if self.version=="v3"else 4) / speed) + 1
+            sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4) / speed) + 1
         y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
         text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
@@ -1207,7 +1208,7 @@ class SynthesizerTrnV3(nn.Module):
         quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  ##BCT
         x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
         fea = self.bridge(x)
-        fea = F.interpolate(fea, scale_factor=(1.875 if self.version=="v3"else 2), mode="nearest")  ##BCT
+        fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest")  ##BCT
         ####more wn paramter to learn mel
         fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
         return fea, ge
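The constants being reformatted above encode the version-dependent length arithmetic: at `speed == 1` the target frame count is `codes * 3.875` for v3 and `codes * 4` for v4, while other speeds divide by `speed` and add one frame of slack. A sketch of that computation (hypothetical helper, constants taken from the diff):

```python
def target_frames(num_codes: int, version: str, speed: float = 1.0) -> int:
    # Mel frames produced per semantic code: 3.875 for v3, 4 for v4 (from the diff above).
    factor = 3.875 if version == "v3" else 4
    if speed == 1:
        return int(num_codes * factor)
    return int(num_codes * factor / speed) + 1

print(target_frames(100, "v3"))        # 387
print(target_frames(100, "v4", 1.25))  # 321
```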

View File

@@ -28,18 +28,18 @@ def my_save(fea, path):  #####fix issue: torch.save doesn't support chinese path
 from io import BytesIO


-def my_save2(fea, path,cfm_version):
+def my_save2(fea, path, cfm_version):
     bio = BytesIO()
     torch.save(fea, bio)
     bio.seek(0)
     data = bio.getvalue()
-    byte=b"03" if cfm_version=="v3"else b"04"
+    byte = b"03" if cfm_version == "v3" else b"04"
     data = byte + data[2:]
     with open(path, "wb") as f:
         f.write(data)


-def savee(ckpt, name, epoch, steps, hps, cfm_version=None,lora_rank=None):
+def savee(ckpt, name, epoch, steps, hps, cfm_version=None, lora_rank=None):
     try:
         opt = OrderedDict()
         opt["weight"] = {}
@@ -51,7 +51,7 @@ def savee(ckpt, name, epoch, steps, hps, cfm_version=None,lora_rank=None):
         opt["info"] = "%sepoch_%siteration" % (epoch, steps)
         if lora_rank:
             opt["lora_rank"] = lora_rank
-            my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name),cfm_version)
+            my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), cfm_version)
         else:
             my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
         return "Success."
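`my_save2` above serializes with `torch.save` into memory, then overwrites the first two bytes of the stream with `b"03"` or `b"04"` to tag the CFM version inside the checkpoint file itself (a plain `torch.save` zip archive would begin with `b"PK"` instead). A hedged sketch of how such a tag could be read back (hypothetical reader; the repo's actual detection lives in `get_sovits_version_from_path_fast`):

```python
def read_cfm_tag(path: str):
    # Tagged checkpoints start with b"03" (v3) or b"04" (v4); untagged ones start with b"PK".
    with open(path, "rb") as f:
        head = f.read(2)
    return {b"03": "v3", b"04": "v4"}.get(head)  # None means an untagged checkpoint
```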

View File

@@ -31,7 +31,6 @@ from module.data_utils import (
     TextAudioSpeakerLoaderV3,
     TextAudioSpeakerCollateV4,
     TextAudioSpeakerLoaderV4,
-
 )
 from module.models import (
     SynthesizerTrnV3 as SynthesizerTrn,
@@ -88,8 +87,8 @@ def run(rank, n_gpus, hps):
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)

-    TextAudioSpeakerLoader=TextAudioSpeakerLoaderV3 if hps.model.version=="v3"else TextAudioSpeakerLoaderV4
-    TextAudioSpeakerCollate=TextAudioSpeakerCollateV3 if hps.model.version=="v3"else TextAudioSpeakerCollateV4
+    TextAudioSpeakerLoader = TextAudioSpeakerLoaderV3 if hps.model.version == "v3" else TextAudioSpeakerLoaderV4
+    TextAudioSpeakerCollate = TextAudioSpeakerCollateV3 if hps.model.version == "v3" else TextAudioSpeakerCollateV4
     train_dataset = TextAudioSpeakerLoader(hps.data)  ########
     train_sampler = DistributedBucketSampler(
         train_dataset,
@@ -365,7 +364,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
                     hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank),
                     epoch,
                     global_step,
-                    hps,cfm_version=hps.model.version,
+                    hps,
+                    cfm_version=hps.model.version,
                     lora_rank=lora_rank,
                 ),
             )

View File

@@ -32,18 +32,10 @@ def make_pair(mix_dir, inst_dir):
     input_exts = [".wav", ".m4a", ".mp3", ".mp4", ".flac"]

     X_list = sorted(
-        [
-            os.path.join(mix_dir, fname)
-            for fname in os.listdir(mix_dir)
-            if os.path.splitext(fname)[1] in input_exts
-        ]
+        [os.path.join(mix_dir, fname) for fname in os.listdir(mix_dir) if os.path.splitext(fname)[1] in input_exts]
     )
     y_list = sorted(
-        [
-            os.path.join(inst_dir, fname)
-            for fname in os.listdir(inst_dir)
-            if os.path.splitext(fname)[1] in input_exts
-        ]
+        [os.path.join(inst_dir, fname) for fname in os.listdir(inst_dir) if os.path.splitext(fname)[1] in input_exts]
     )

     filelist = list(zip(X_list, y_list))
@@ -65,14 +57,10 @@ def train_val_split(dataset_dir, split_mode, val_rate, val_filelist):
             train_filelist = filelist[:-val_size]
             val_filelist = filelist[-val_size:]
         else:
-            train_filelist = [
-                pair for pair in filelist if list(pair) not in val_filelist
-            ]
+            train_filelist = [pair for pair in filelist if list(pair) not in val_filelist]
     elif split_mode == "subdirs":
         if len(val_filelist) != 0:
-            raise ValueError(
-                "The `val_filelist` option is not available in `subdirs` mode"
-            )
+            raise ValueError("The `val_filelist` option is not available in `subdirs` mode")

         train_filelist = make_pair(
             os.path.join(dataset_dir, "training/mixtures"),
@@ -91,9 +79,7 @@ def augment(X, y, reduction_rate, reduction_mask, mixup_rate, mixup_alpha):
     perm = np.random.permutation(len(X))
     for i, idx in enumerate(tqdm(perm)):
         if np.random.uniform() < reduction_rate:
-            y[idx] = spec_utils.reduce_vocal_aggressively(
-                X[idx], y[idx], reduction_mask
-            )
+            y[idx] = spec_utils.reduce_vocal_aggressively(X[idx], y[idx], reduction_mask)

         if np.random.uniform() < 0.5:
             # swap channel
@@ -152,9 +138,7 @@ def make_training_set(filelist, cropsize, patches, sr, hop_length, n_fft, offset

 def make_validation_set(filelist, cropsize, sr, hop_length, n_fft, offset):
     patch_list = []
-    patch_dir = "cs{}_sr{}_hl{}_nf{}_of{}".format(
-        cropsize, sr, hop_length, n_fft, offset
-    )
+    patch_dir = "cs{}_sr{}_hl{}_nf{}_of{}".format(cropsize, sr, hop_length, n_fft, offset)
     os.makedirs(patch_dir, exist_ok=True)

     for i, (X_path, y_path) in enumerate(tqdm(filelist)):

View File

@@ -63,9 +63,7 @@ class Encoder(nn.Module):

 class Decoder(nn.Module):
-    def __init__(
-        self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
-    ):
+    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
         super(Decoder, self).__init__()
         self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
         self.dropout = nn.Dropout2d(0.1) if dropout else None
@@ -91,24 +89,14 @@ class ASPPModule(nn.Module):
             Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ),
         )
         self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
-        self.conv3 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[0], dilations[0], activ=activ
-        )
-        self.conv4 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[1], dilations[1], activ=activ
-        )
-        self.conv5 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
-        )
-        self.bottleneck = nn.Sequential(
-            Conv2DBNActiv(nin * 5, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1)
-        )
+        self.conv3 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[0], dilations[0], activ=activ)
+        self.conv4 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[1], dilations[1], activ=activ)
+        self.conv5 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
+        self.bottleneck = nn.Sequential(Conv2DBNActiv(nin * 5, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1))

     def forward(self, x):
         _, _, h, w = x.size()
-        feat1 = F.interpolate(
-            self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
-        )
+        feat1 = F.interpolate(self.conv1(x), size=(h, w), mode="bilinear", align_corners=True)
         feat2 = self.conv2(x)
         feat3 = self.conv3(x)
         feat4 = self.conv4(x)

View File

@@ -63,9 +63,7 @@ class Encoder(nn.Module):

 class Decoder(nn.Module):
-    def __init__(
-        self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
-    ):
+    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
         super(Decoder, self).__init__()
         self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
         self.dropout = nn.Dropout2d(0.1) if dropout else None
@@ -91,24 +89,14 @@ class ASPPModule(nn.Module):
             Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ),
         )
         self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
-        self.conv3 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[0], dilations[0], activ=activ
-        )
-        self.conv4 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[1], dilations[1], activ=activ
-        )
-        self.conv5 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
-        )
-        self.bottleneck = nn.Sequential(
-            Conv2DBNActiv(nin * 5, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1)
-        )
+        self.conv3 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[0], dilations[0], activ=activ)
+        self.conv4 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[1], dilations[1], activ=activ)
+        self.conv5 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
+        self.bottleneck = nn.Sequential(Conv2DBNActiv(nin * 5, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1))

     def forward(self, x):
         _, _, h, w = x.size()
-        feat1 = F.interpolate(
-            self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
-        )
+        feat1 = F.interpolate(self.conv1(x), size=(h, w), mode="bilinear", align_corners=True)
         feat2 = self.conv2(x)
         feat3 = self.conv3(x)
         feat4 = self.conv4(x)

View File

@@ -63,9 +63,7 @@ class Encoder(nn.Module):

 class Decoder(nn.Module):
-    def __init__(
-        self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
-    ):
+    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
         super(Decoder, self).__init__()
         self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
         self.dropout = nn.Dropout2d(0.1) if dropout else None
@@ -91,24 +89,14 @@ class ASPPModule(nn.Module):
             Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ),
         )
         self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
-        self.conv3 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[0], dilations[0], activ=activ
-        )
-        self.conv4 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[1], dilations[1], activ=activ
-        )
-        self.conv5 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
-        )
-        self.bottleneck = nn.Sequential(
-            Conv2DBNActiv(nin * 5, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1)
-        )
+        self.conv3 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[0], dilations[0], activ=activ)
+        self.conv4 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[1], dilations[1], activ=activ)
+        self.conv5 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
+        self.bottleneck = nn.Sequential(Conv2DBNActiv(nin * 5, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1))

     def forward(self, x):
         _, _, h, w = x.size()
-        feat1 = F.interpolate(
-            self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
-        )
+        feat1 = F.interpolate(self.conv1(x), size=(h, w), mode="bilinear", align_corners=True)
         feat2 = self.conv2(x)
         feat3 = self.conv3(x)
         feat4 = self.conv4(x)

View File

@@ -63,9 +63,7 @@ class Encoder(nn.Module):

 class Decoder(nn.Module):
-    def __init__(
-        self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
-    ):
+    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
         super(Decoder, self).__init__()
         self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
         self.dropout = nn.Dropout2d(0.1) if dropout else None
@@ -91,30 +89,16 @@ class ASPPModule(nn.Module):
             Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ),
         )
         self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
-        self.conv3 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[0], dilations[0], activ=activ
-        )
-        self.conv4 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[1], dilations[1], activ=activ
-        )
-        self.conv5 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
-        )
-        self.conv6 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
-        )
-        self.conv7 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
-        )
-        self.bottleneck = nn.Sequential(
-            Conv2DBNActiv(nin * 7, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1)
-        )
+        self.conv3 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[0], dilations[0], activ=activ)
+        self.conv4 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[1], dilations[1], activ=activ)
+        self.conv5 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
+        self.conv6 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
+        self.conv7 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
+        self.bottleneck = nn.Sequential(Conv2DBNActiv(nin * 7, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1))

     def forward(self, x):
         _, _, h, w = x.size()
-        feat1 = F.interpolate(
-            self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
-        )
+        feat1 = F.interpolate(self.conv1(x), size=(h, w), mode="bilinear", align_corners=True)
         feat2 = self.conv2(x)
         feat3 = self.conv3(x)
         feat4 = self.conv4(x)

View File

@@ -63,9 +63,7 @@ class Encoder(nn.Module):

 class Decoder(nn.Module):
-    def __init__(
-        self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
-    ):
+    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
         super(Decoder, self).__init__()
         self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
         self.dropout = nn.Dropout2d(0.1) if dropout else None
@@ -91,30 +89,16 @@ class ASPPModule(nn.Module):
             Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ),
         )
         self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
-        self.conv3 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[0], dilations[0], activ=activ
-        )
-        self.conv4 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[1], dilations[1], activ=activ
-        )
-        self.conv5 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
-        )
-        self.conv6 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
-        )
-        self.conv7 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
-        )
-        self.bottleneck = nn.Sequential(
-            Conv2DBNActiv(nin * 7, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1)
-        )
+        self.conv3 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[0], dilations[0], activ=activ)
+        self.conv4 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[1], dilations[1], activ=activ)
+        self.conv5 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
+        self.conv6 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
+        self.conv7 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
+        self.bottleneck = nn.Sequential(Conv2DBNActiv(nin * 7, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1))

     def forward(self, x):
         _, _, h, w = x.size()
-        feat1 = F.interpolate(
-            self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
-        )
+        feat1 = F.interpolate(self.conv1(x), size=(h, w), mode="bilinear", align_corners=True)
         feat2 = self.conv2(x)
         feat3 = self.conv3(x)
         feat4 = self.conv4(x)

View File

@@ -63,9 +63,7 @@ class Encoder(nn.Module):

 class Decoder(nn.Module):
-    def __init__(
-        self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
-    ):
+    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
         super(Decoder, self).__init__()
         self.conv = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
         self.dropout = nn.Dropout2d(0.1) if dropout else None
@@ -91,30 +89,16 @@ class ASPPModule(nn.Module):
             Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ),
         )
         self.conv2 = Conv2DBNActiv(nin, nin, 1, 1, 0, activ=activ)
-        self.conv3 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[0], dilations[0], activ=activ
-        )
-        self.conv4 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[1], dilations[1], activ=activ
-        )
-        self.conv5 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
-        )
-        self.conv6 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
-        )
-        self.conv7 = SeperableConv2DBNActiv(
-            nin, nin, 3, 1, dilations[2], dilations[2], activ=activ
-        )
-        self.bottleneck = nn.Sequential(
-            Conv2DBNActiv(nin * 7, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1)
-        )
+        self.conv3 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[0], dilations[0], activ=activ)
+        self.conv4 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[1], dilations[1], activ=activ)
+        self.conv5 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
+        self.conv6 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
+        self.conv7 = SeperableConv2DBNActiv(nin, nin, 3, 1, dilations[2], dilations[2], activ=activ)
+        self.bottleneck = nn.Sequential(Conv2DBNActiv(nin * 7, nout, 1, 1, 0, activ=activ), nn.Dropout2d(0.1))

     def forward(self, x):
         _, _, h, w = x.size()
-        feat1 = F.interpolate(
-            self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
-        )
+        feat1 = F.interpolate(self.conv1(x), size=(h, w), mode="bilinear", align_corners=True)
         feat2 = self.conv2(x)
         feat3 = self.conv3(x)
         feat4 = self.conv4(x)

View File

@@ -40,9 +40,7 @@ class Encoder(nn.Module):

 class Decoder(nn.Module):
-    def __init__(
-        self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False
-    ):
+    def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.ReLU, dropout=False):
        super(Decoder, self).__init__()
         self.conv1 = Conv2DBNActiv(nin, nout, ksize, 1, pad, activ=activ)
         # self.conv2 = Conv2DBNActiv(nout, nout, ksize, 1, pad, activ=activ)
@@ -72,23 +70,15 @@ class ASPPModule(nn.Module):
             Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ),
         )
         self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
-        self.conv3 = Conv2DBNActiv(
-            nin, nout, 3, 1, dilations[0], dilations[0], activ=activ
-        )
-        self.conv4 = Conv2DBNActiv(
-            nin, nout, 3, 1, dilations[1], dilations[1], activ=activ
-        )
-        self.conv5 = Conv2DBNActiv(
-            nin, nout, 3, 1, dilations[2], dilations[2], activ=activ
-        )
+        self.conv3 = Conv2DBNActiv(nin, nout, 3, 1, dilations[0], dilations[0], activ=activ)
+        self.conv4 = Conv2DBNActiv(nin, nout, 3, 1, dilations[1], dilations[1], activ=activ)
+        self.conv5 = Conv2DBNActiv(nin, nout, 3, 1, dilations[2], dilations[2], activ=activ)
         self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
         self.dropout = nn.Dropout2d(0.1) if dropout else None

     def forward(self, x):
         _, _, h, w = x.size()
-        feat1 = F.interpolate(
-            self.conv1(x), size=(h, w), mode="bilinear", align_corners=True
-        )
+        feat1 = F.interpolate(self.conv1(x), size=(h, w), mode="bilinear", align_corners=True)
         feat2 = self.conv2(x)
         feat3 = self.conv3(x)
         feat4 = self.conv4(x)
@@ -106,12 +96,8 @@ class LSTMModule(nn.Module):
     def __init__(self, nin_conv, nin_lstm, nout_lstm):
         super(LSTMModule, self).__init__()
         self.conv = Conv2DBNActiv(nin_conv, 1, 1, 1, 0)
-        self.lstm = nn.LSTM(
-            input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True
-        )
-        self.dense = nn.Sequential(
-            nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU()
-        )
+        self.lstm = nn.LSTM(input_size=nin_lstm, hidden_size=nout_lstm // 2, bidirectional=True)
+        self.dense = nn.Sequential(nn.Linear(nout_lstm, nin_lstm), nn.BatchNorm1d(nin_lstm), nn.ReLU())

     def forward(self, x):
         N, _, nbins, nframes = x.size()

View File

@@ -1,5 +1,4 @@
 import json
-import os
 import pathlib

 default_param = {}
@@ -48,9 +47,7 @@ class ModelParameters(object):
             import zipfile

             with zipfile.ZipFile(config_path, "r") as zip:
-                self.param = json.loads(
-                    zip.read("param.json"), object_pairs_hook=int_keys
-                )
+                self.param = json.loads(zip.read("param.json"), object_pairs_hook=int_keys)
         elif ".json" == pathlib.Path(config_path).suffix:
             with open(config_path, "r") as f:
                 self.param = json.loads(f.read(), object_pairs_hook=int_keys)
@@ -65,5 +62,5 @@ class ModelParameters(object):
             "stereo_n",
             "reverse",
         ]:
-            if not k in self.param:
+            if k not in self.param:
                 self.param[k] = False

View File

@@ -3,8 +3,6 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from . import spec_utils
-

 class BaseASPPNet(nn.Module):
     def __init__(self, nin, ch, dilations=(4, 8, 16)):

View File

@@ -1,4 +1,3 @@
-import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn

View File

@@ -1,4 +1,3 @@
-import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn

View File

@@ -6,9 +6,7 @@ from . import layers_new

 class BaseNet(nn.Module):
-    def __init__(
-        self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))
-    ):
+    def __init__(self, nin, nout, nin_lstm, nout_lstm, dilations=((4, 2), (8, 4), (12, 6))):
         super(BaseNet, self).__init__()
         self.enc1 = layers_new.Conv2DBNActiv(nin, nout, 3, 1, 1)
         self.enc2 = layers_new.Encoder(nout, nout * 2, 3, 2, 1)
@@ -56,21 +54,15 @@ class CascadedNet(nn.Module):
             layers_new.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0),
         )
-        self.stg1_high_band_net = BaseNet(
-            2, nout // 4, self.nin_lstm // 2, nout_lstm // 2
-        )
+        self.stg1_high_band_net = BaseNet(2, nout // 4, self.nin_lstm // 2, nout_lstm // 2)

         self.stg2_low_band_net = nn.Sequential(
             BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
             layers_new.Conv2DBNActiv(nout, nout // 2, 1, 1, 0),
         )
-        self.stg2_high_band_net = BaseNet(
-            nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2
-        )
+        self.stg2_high_band_net = BaseNet(nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2)

-        self.stg3_full_band_net = BaseNet(
-            3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm
-        )
+        self.stg3_full_band_net = BaseNet(3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm)

         self.out = nn.Conv2d(nout, 2, 1, bias=False)
         self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)

View File

@@ -27,9 +27,7 @@ def crop_center(h1, h2):
    return h1


def wave_to_spectrogram(wave, hop_length, n_fft, mid_side=False, mid_side_b2=False, reverse=False):
    if reverse:
        wave_left = np.flip(np.asfortranarray(wave[0]))
        wave_right = np.flip(np.asfortranarray(wave[1]))
@@ -43,7 +41,7 @@ def wave_to_spectrogram(
        wave_left = np.asfortranarray(wave[0])
        wave_right = np.asfortranarray(wave[1])

    spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
    spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)

    spec = np.asfortranarray([spec_left, spec_right])
@@ -51,9 +49,7 @@ def wave_to_spectrogram(
    return spec


def wave_to_spectrogram_mt(wave, hop_length, n_fft, mid_side=False, mid_side_b2=False, reverse=False):
    import threading

    if reverse:
@@ -103,21 +99,13 @@ def combine_spectrograms(specs, mp):
        raise ValueError("Too much bins")

    # lowpass filter
    if mp.param["pre_filter_start"] > 0:  # and mp.param['band'][bands_n]['res_type'] in ['scipy', 'polyphase']:
        if bands_n == 1:
            spec_c = fft_lp_filter(spec_c, mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
        else:
            gp = 1
            for b in range(mp.param["pre_filter_start"] + 1, mp.param["pre_filter_stop"]):
                g = math.pow(10, -(b - mp.param["pre_filter_start"]) * (3.5 - gp) / 20.0)
                gp = g
                spec_c[:, b, :] *= g
@@ -189,9 +177,7 @@ def mask_silence(mag, ref, thres=0.2, min_range=64, fade_size=32):
        else:
            e += fade_size

        mag[:, :, s + fade_size : e - fade_size] += ref[:, :, s + fade_size : e - fade_size]
        old_e = e

    return mag
@@ -207,9 +193,7 @@ def cache_or_load(mix_path, inst_path, mp):
    mix_basename = os.path.splitext(os.path.basename(mix_path))[0]
    inst_basename = os.path.splitext(os.path.basename(inst_path))[0]

    cache_dir = "mph{}".format(hashlib.sha1(json.dumps(mp.param, sort_keys=True).encode("utf-8")).hexdigest())
    mix_cache_dir = os.path.join("cache", cache_dir)
    inst_cache_dir = os.path.join("cache", cache_dir)
@@ -230,31 +214,27 @@ def cache_or_load(mix_path, inst_path, mp):
            if d == len(mp.param["band"]):  # high-end band
                X_wave[d], _ = librosa.load(
                    mix_path, sr=bp["sr"], mono=False, dtype=np.float32, res_type=bp["res_type"]
                )
                y_wave[d], _ = librosa.load(
                    inst_path,
                    sr=bp["sr"],
                    mono=False,
                    dtype=np.float32,
                    res_type=bp["res_type"],
                )
            else:  # lower bands
                X_wave[d] = librosa.resample(
                    X_wave[d + 1],
                    orig_sr=mp.param["band"][d + 1]["sr"],
                    target_sr=bp["sr"],
                    res_type=bp["res_type"],
                )
                y_wave[d] = librosa.resample(
                    y_wave[d + 1],
                    orig_sr=mp.param["band"][d + 1]["sr"],
                    target_sr=bp["sr"],
                    res_type=bp["res_type"],
                )

            X_wave[d], y_wave[d] = align_wave_head_and_tail(X_wave[d], y_wave[d])
@@ -302,9 +282,7 @@ def spectrogram_to_wave(spec, hop_length, mid_side, mid_side_b2, reverse):
    if reverse:
        return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)])
    elif mid_side:
        return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
    elif mid_side_b2:
        return np.asfortranarray(
            [
@@ -326,9 +304,7 @@ def spectrogram_to_wave_mt(spec, hop_length, mid_side, reverse, mid_side_b2):
        global wave_left
        wave_left = librosa.istft(**kwargs)

    thread = threading.Thread(target=run_thread, kwargs={"stft_matrix": spec_left, "hop_length": hop_length})
    thread.start()
    wave_right = librosa.istft(spec_right, hop_length=hop_length)
    thread.join()
@@ -336,9 +312,7 @@ def spectrogram_to_wave_mt(spec, hop_length, mid_side, reverse, mid_side_b2):
    if reverse:
        return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)])
    elif mid_side:
        return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
    elif mid_side_b2:
        return np.asfortranarray(
            [
@@ -357,21 +331,15 @@ def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None):
    for d in range(1, bands_n + 1):
        bp = mp.param["band"][d]
        spec_s = np.ndarray(shape=(2, bp["n_fft"] // 2 + 1, spec_m.shape[2]), dtype=complex)
        h = bp["crop_stop"] - bp["crop_start"]
        spec_s[:, bp["crop_start"] : bp["crop_stop"], :] = spec_m[:, offset : offset + h, :]

        offset += h
        if d == bands_n:  # higher
            if extra_bins_h:  # if --high_end_process bypass
                max_bin = bp["n_fft"] // 2
                spec_s[:, max_bin - extra_bins_h : max_bin, :] = extra_bins[:, :extra_bins_h, :]
            if bp["hpf_start"] > 0:
                spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
            if bands_n == 1:
@@ -405,9 +373,9 @@ def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None):
                        mp.param["mid_side_b2"],
                        mp.param["reverse"],
                    ),
                    orig_sr=bp["sr"],
                    target_sr=sr,
                    res_type="sinc_fastest",
                )
            else:  # mid
                spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
@@ -456,10 +424,7 @@ def mirroring(a, spec_m, input_high_end, mp):
            np.abs(
                spec_m[
                    :,
                    mp.param["pre_filter_start"] - 10 - input_high_end.shape[1] : mp.param["pre_filter_start"] - 10,
                    :,
                ]
            ),
@@ -467,19 +432,14 @@ def mirroring(a, spec_m, input_high_end, mp):
        )
        mirror = mirror * np.exp(1.0j * np.angle(input_high_end))

        return np.where(np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror)

    if "mirroring2" == a:
        mirror = np.flip(
            np.abs(
                spec_m[
                    :,
                    mp.param["pre_filter_start"] - 10 - input_high_end.shape[1] : mp.param["pre_filter_start"] - 10,
                    :,
                ]
            ),
@@ -528,7 +488,6 @@ def istft(spec, hl):

if __name__ == "__main__":
    import argparse
    import sys
    import time

    import cv2
@@ -573,10 +532,10 @@ if __name__ == "__main__":
            if d == len(mp.param["band"]):  # high-end band
                wave[d], _ = librosa.load(
                    args.input[i],
                    sr=bp["sr"],
                    mono=False,
                    dtype=np.float32,
                    res_type=bp["res_type"],
                )

                if len(wave[d].shape) == 1:  # mono to stereo
@@ -584,9 +543,9 @@ if __name__ == "__main__":
            else:  # lower bands
                wave[d] = librosa.resample(
                    wave[d + 1],
                    orig_sr=mp.param["band"][d + 1]["sr"],
                    target_sr=bp["sr"],
                    res_type=bp["res_type"],
                )

            spec[d] = wave_to_spectrogram(

View File

@@ -27,9 +27,7 @@ def inference(X_spec, device, model, aggressiveness, data):
    data : dict configs
    """

    def _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness, is_half=True):
        model.eval()
        with torch.no_grad():
            preds = []
@@ -39,9 +37,7 @@ def inference(X_spec, device, model, aggressiveness, data):
            total_iterations = sum(iterations)
            for i in tqdm(range(n_window)):
                start = i * roi_size
                X_mag_window = X_mag_pad[None, :, :, start : start + data["window_size"]]
                X_mag_window = torch.from_numpy(X_mag_window)
                if is_half:
                    X_mag_window = X_mag_window.half()
@@ -76,9 +72,7 @@ def inference(X_spec, device, model, aggressiveness, data):
        is_half = True
    else:
        is_half = False
    pred = _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness, is_half)
    pred = pred[:, :, :n_frame]

    if data["tta"]:
@@ -88,9 +82,7 @@ def inference(X_spec, device, model, aggressiveness, data):
        X_mag_pad = np.pad(X_mag_pre, ((0, 0), (0, 0), (pad_l, pad_r)), mode="constant")

        pred_tta = _execute(X_mag_pad, roi_size, n_window, device, model, aggressiveness, is_half)
        pred_tta = pred_tta[:, :, roi_size // 2 :]
        pred_tta = pred_tta[:, :, :n_frame]

View File

@@ -147,7 +147,9 @@ if torch.cuda.is_available() or ngpu != 0:
        # mem.append(psutil.virtual_memory().total / 1024 / 1024 / 1024)  # tested: using system RAM as VRAM does not overflow VRAM

v3v4set = {"v3", "v4"}


def set_default():
    global \
        default_batch_size, \
@@ -589,6 +591,7 @@ def close_denoise():
p_train_SoVITS = None
process_name_sovits = i18n("SoVITS训练")


def open1Ba(
    batch_size,
    total_epoch,
@@ -641,7 +644,9 @@ def open1Ba(
        yield (
            process_info(process_name_sovits, "opened"),
            {"__type__": "update", "visible": False},
            {"__type__": "update", "visible": True},
            {"__type__": "update"},
            {"__type__": "update"},
        )
        print(cmd)
        p_train_SoVITS = Popen(cmd, shell=True)
@@ -651,13 +656,17 @@ def open1Ba(
        yield (
            process_info(process_name_sovits, "finish"),
            {"__type__": "update", "visible": True},
            {"__type__": "update", "visible": False},
            SoVITS_dropdown_update,
            GPT_dropdown_update,
        )
    else:
        yield (
            process_info(process_name_sovits, "occupy"),
            {"__type__": "update", "visible": False},
            {"__type__": "update", "visible": True},
            {"__type__": "update"},
            {"__type__": "update"},
        )
@@ -726,7 +735,9 @@ def open1Bb(
        yield (
            process_info(process_name_gpt, "opened"),
            {"__type__": "update", "visible": False},
            {"__type__": "update", "visible": True},
            {"__type__": "update"},
            {"__type__": "update"},
        )
        print(cmd)
        p_train_GPT = Popen(cmd, shell=True)
@@ -736,13 +747,17 @@ def open1Bb(
        yield (
            process_info(process_name_gpt, "finish"),
            {"__type__": "update", "visible": True},
            {"__type__": "update", "visible": False},
            SoVITS_dropdown_update,
            GPT_dropdown_update,
        )
    else:
        yield (
            process_info(process_name_gpt, "occupy"),
            {"__type__": "update", "visible": False},
            {"__type__": "update", "visible": True},
            {"__type__": "update"},
            {"__type__": "update"},
        )
@@ -1291,6 +1306,7 @@ def close1abc():
        {"__type__": "update", "visible": False},
    )


def switch_version(version_):
    os.environ["version"] = version_
    global version
@@ -1492,7 +1508,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                with gr.Row():
                    exp_name = gr.Textbox(label=i18n("*实验/模型名"), value="xxx", interactive=True)
                    gpu_info = gr.Textbox(label=i18n("显卡信息"), value=gpu_info, visible=True, interactive=False)
                    version_checkbox = gr.Radio(label=i18n("版本"), value=version, choices=["v1", "v2", "v4"])  # , "v3"
                with gr.Row():
                    pretrained_s2G = gr.Textbox(
                        label=i18n("预训练SoVITS-G模型路径"),
@@ -1915,7 +1931,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
            if_grad_ckpt,
            lora_rank,
        ],
        [info1Ba, button1Ba_open, button1Ba_close, SoVITS_dropdown, GPT_dropdown],
    )
    button1Bb_open.click(
        open1Bb,
@@ -1930,7 +1946,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
            gpu_numbers1Bb,
            pretrained_s1,
        ],
        [info1Bb, button1Bb_open, button1Bb_close, SoVITS_dropdown, GPT_dropdown],
    )
    version_checkbox.change(
        switch_version,