fix nan issue(causing sovits zerodivision)

fix nan issue(which will cause sovits zerodivision)
This commit is contained in:
RVC-Boss 2024-01-23 16:59:25 +08:00 committed by GitHub
parent 948e7fc086
commit 93c47cd9f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -49,10 +49,13 @@ maxx=0.95
alpha=0.5 alpha=0.5
device="cuda:0" device="cuda:0"
model=cnhubert.get_model() model=cnhubert.get_model()
# is_half=False
if(is_half==True): if(is_half==True):
model=model.half().to(device) model=model.half().to(device)
else: else:
model = model.to(device) model = model.to(device)
nan_fails=[]
def name2go(wav_name): def name2go(wav_name):
hubert_path="%s/%s.pt"%(hubert_dir,wav_name) hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
if(os.path.exists(hubert_path)):return if(os.path.exists(hubert_path)):return
@ -60,25 +63,27 @@ def name2go(wav_name):
tmp_audio = load_audio(wav_path, 32000) tmp_audio = load_audio(wav_path, 32000)
tmp_max = np.abs(tmp_audio).max() tmp_max = np.abs(tmp_audio).max()
if tmp_max > 2.2: if tmp_max > 2.2:
print("%s-%s-%s-filtered" % (idx0, idx1, tmp_max)) print("%s-filtered" % (wav_name, tmp_max))
return return
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
tmp_audio = librosa.resample( tmp_audio = librosa.resample(
tmp_audio32, orig_sr=32000, target_sr=16000 tmp_audio32, orig_sr=32000, target_sr=16000
) )#不是重采样问题
tensor_wav16 = torch.from_numpy(tmp_audio) tensor_wav16 = torch.from_numpy(tmp_audio)
if (is_half == True): if (is_half == True):
tensor_wav16=tensor_wav16.half().to(device) tensor_wav16=tensor_wav16.half().to(device)
else: else:
tensor_wav16 = tensor_wav16.to(device) tensor_wav16 = tensor_wav16.to(device)
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215]) ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum()!= 0:return if np.isnan(ssl.detach().numpy()).sum()!= 0:
nan_fails.append(wav_name)
print("nan filtered:%s"%wav_name)
return
wavfile.write( wavfile.write(
"%s/%s"%(wav32dir,wav_name), "%s/%s"%(wav32dir,wav_name),
32000, 32000,
tmp_audio32.astype("int16"), tmp_audio32.astype("int16"),
) )
# torch.save(ssl,hubert_path )
my_save(ssl,hubert_path ) my_save(ssl,hubert_path )
with open(inp_text,"r",encoding="utf8")as f: with open(inp_text,"r",encoding="utf8")as f:
@ -92,3 +97,12 @@ for line in lines[int(i_part)::int(all_parts)]:
name2go(wav_name) name2go(wav_name)
except: except:
print(line,traceback.format_exc()) print(line,traceback.format_exc())
if(len(nan_fails)>0 and is_half==True):
is_half=False
model=model.float()
for wav_name in nan_fails:
try:
name2go(wav_name)
except:
print(wav_name,traceback.format_exc())