name2go fix for fast_inference (#1133)

This commit is contained in:
XXXXRT666 2024-05-26 17:45:20 +01:00 committed by GitHub
parent f822b9588f
commit 8fc1e34f58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -82,7 +82,7 @@ def name2go(wav_name,wav_path):
tensor_wav16 = tensor_wav16.to(device)
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum()!= 0:
nan_fails.append(wav_name)
nan_fails.append((wav_name,wav_path))
print("nan filtered:%s"%wav_name)
return
wavfile.write(
@ -90,7 +90,7 @@ def name2go(wav_name,wav_path):
32000,
tmp_audio32.astype("int16"),
)
my_save(ssl,hubert_path )
my_save(ssl,hubert_path)
with open(inp_text,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
@ -113,8 +113,8 @@ for line in lines[int(i_part)::int(all_parts)]:
if(len(nan_fails)>0 and is_half==True):
is_half=False
model=model.float()
for wav_name in nan_fails:
for wav in nan_fails:
try:
name2go(wav_name)
name2go(wav[0],wav[1])
except:
print(wav_name,traceback.format_exc())