mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-06 14:40:00 +08:00
fix memory overflow issue in get-hubert-wav
This commit is contained in:
parent
a70e1ad30c
commit
434ca2e82c
@ -17,6 +17,7 @@ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
|||||||
import pdb,traceback,numpy as np,logging
|
import pdb,traceback,numpy as np,logging
|
||||||
from scipy.io import wavfile
|
from scipy.io import wavfile
|
||||||
import librosa
|
import librosa
|
||||||
|
import gc
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
sys.path.append(now_dir)
|
sys.path.append(now_dir)
|
||||||
from tools.my_utils import load_audio,clean_path
|
from tools.my_utils import load_audio,clean_path
|
||||||
@ -64,35 +65,38 @@ else:
|
|||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
nan_fails=[]
|
nan_fails=[]
|
||||||
def name2go(wav_name,wav_path):
|
def name2go(wav_name, wav_path):
|
||||||
hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
|
hubert_path = f"{hubert_dir}/{wav_name}.pt"
|
||||||
if(os.path.exists(hubert_path)):return
|
if os.path.exists(hubert_path):
|
||||||
|
return
|
||||||
tmp_audio = load_audio(wav_path, 32000)
|
tmp_audio = load_audio(wav_path, 32000)
|
||||||
tmp_max = np.abs(tmp_audio).max()
|
tmp_max = np.abs(tmp_audio).max()
|
||||||
if tmp_max > 2.2:
|
if tmp_max > 2.2:
|
||||||
print("%s-filtered,%s" % (wav_name, tmp_max))
|
print(f"{wav_name}-filtered, {tmp_max}")
|
||||||
return
|
return
|
||||||
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
|
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
|
||||||
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha*1145.14)) + ((1 - alpha)*1145.14) * tmp_audio
|
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
|
||||||
tmp_audio = librosa.resample(
|
tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000)
|
||||||
tmp_audio32b, orig_sr=32000, target_sr=16000
|
|
||||||
)#不是重采样问题
|
tensor_wav16 = torch.from_numpy(tmp_audio).to(device)
|
||||||
tensor_wav16 = torch.from_numpy(tmp_audio)
|
if is_half:
|
||||||
if (is_half == True):
|
tensor_wav16 = tensor_wav16.half()
|
||||||
tensor_wav16=tensor_wav16.half().to(device)
|
|
||||||
else:
|
try:
|
||||||
tensor_wav16 = tensor_wav16.to(device)
|
with torch.no_grad():
|
||||||
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
|
ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu()
|
||||||
if np.isnan(ssl.detach().numpy()).sum()!= 0:
|
if torch.isnan(ssl).any():
|
||||||
nan_fails.append((wav_name,wav_path))
|
nan_fails.append((wav_name, wav_path))
|
||||||
print("nan filtered:%s"%wav_name)
|
print(f"nan filtered: {wav_name}")
|
||||||
return
|
return
|
||||||
wavfile.write(
|
wavfile.write(f"{wav32dir}/{wav_name}", 32000, tmp_audio32.astype("int16"))
|
||||||
"%s/%s"%(wav32dir,wav_name),
|
my_save(ssl, hubert_path)
|
||||||
32000,
|
except Exception as e:
|
||||||
tmp_audio32.astype("int16"),
|
print(f"Error processing {wav_name}: {e}")
|
||||||
)
|
finally:
|
||||||
my_save(ssl,hubert_path)
|
del tensor_wav16, ssl
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
with open(inp_text,"r",encoding="utf8")as f:
|
with open(inp_text,"r",encoding="utf8")as f:
|
||||||
lines=f.read().strip("\n").split("\n")
|
lines=f.read().strip("\n").split("\n")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user