From 434ca2e82c45186745ea6ee222ee37ff3884596c Mon Sep 17 00:00:00 2001 From: Keming Date: Fri, 22 Nov 2024 00:46:18 -0800 Subject: [PATCH] fix memory overflow issue in get-hubert-wav --- .../prepare_datasets/2-get-hubert-wav32k.py | 54 ++++++++++--------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py index 27b61f27..0c6f239c 100644 --- a/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py +++ b/GPT_SoVITS/prepare_datasets/2-get-hubert-wav32k.py @@ -17,6 +17,7 @@ is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available() import pdb,traceback,numpy as np,logging from scipy.io import wavfile import librosa +import gc now_dir = os.getcwd() sys.path.append(now_dir) from tools.my_utils import load_audio,clean_path @@ -64,35 +65,38 @@ else: model = model.to(device) nan_fails=[] -def name2go(wav_name,wav_path): - hubert_path="%s/%s.pt"%(hubert_dir,wav_name) - if(os.path.exists(hubert_path)):return +def name2go(wav_name, wav_path): + hubert_path = f"{hubert_dir}/{wav_name}.pt" + if os.path.exists(hubert_path): + return tmp_audio = load_audio(wav_path, 32000) tmp_max = np.abs(tmp_audio).max() if tmp_max > 2.2: - print("%s-filtered,%s" % (wav_name, tmp_max)) + print(f"{wav_name}-filtered, {tmp_max}") return - tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio - tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha*1145.14)) + ((1 - alpha)*1145.14) * tmp_audio - tmp_audio = librosa.resample( - tmp_audio32b, orig_sr=32000, target_sr=16000 - )#不是重采样问题 - tensor_wav16 = torch.from_numpy(tmp_audio) - if (is_half == True): - tensor_wav16=tensor_wav16.half().to(device) - else: - tensor_wav16 = tensor_wav16.to(device) - ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215]) - if np.isnan(ssl.detach().numpy()).sum()!= 0: - nan_fails.append((wav_name,wav_path)) - print("nan filtered:%s"%wav_name) - return - wavfile.write( - "%s/%s"%(wav32dir,wav_name), - 32000, - tmp_audio32.astype("int16"), - ) - my_save(ssl,hubert_path) + tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio + tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio + tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000) + + tensor_wav16 = torch.from_numpy(tmp_audio).to(device) + if is_half: + tensor_wav16 = tensor_wav16.half() + + try: + with torch.no_grad(): + ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu() + if torch.isnan(ssl).any(): + nan_fails.append((wav_name, wav_path)) + print(f"nan filtered: {wav_name}") + return + wavfile.write(f"{wav32dir}/{wav_name}", 32000, tmp_audio32.astype("int16")) + my_save(ssl, hubert_path) + except Exception as e: + print(f"Error processing {wav_name}: {e}") + finally: + del tensor_wav16, ssl + torch.cuda.empty_cache() + gc.collect() with open(inp_text,"r",encoding="utf8")as f: lines=f.read().strip("\n").split("\n")