mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-29 22:10:21 +08:00
修复了OutOfMemoryError时,显存无法释放的问题
This commit is contained in:
parent
f2cbc826c7
commit
d60d8ea3fb
@ -2,6 +2,7 @@ from copy import deepcopy
|
|||||||
import math
|
import math
|
||||||
import os, sys
|
import os, sys
|
||||||
import random
|
import random
|
||||||
|
import traceback
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
sys.path.append(now_dir)
|
sys.path.append(now_dir)
|
||||||
import ffmpeg
|
import ffmpeg
|
||||||
@ -48,7 +49,17 @@ custom:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# def set_seed(seed):
|
||||||
|
# random.seed(seed)
|
||||||
|
# os.environ['PYTHONHASHSEED'] = str(seed)
|
||||||
|
# np.random.seed(seed)
|
||||||
|
# torch.manual_seed(seed)
|
||||||
|
# torch.cuda.manual_seed(seed)
|
||||||
|
# torch.cuda.manual_seed_all(seed)
|
||||||
|
# torch.backends.cudnn.deterministic = True
|
||||||
|
# torch.backends.cudnn.benchmark = False
|
||||||
|
# torch.backends.cudnn.enabled = True
|
||||||
|
# set_seed(1234)
|
||||||
|
|
||||||
class TTS_Config:
|
class TTS_Config:
|
||||||
default_configs={
|
default_configs={
|
||||||
@ -630,7 +641,7 @@ class TTS:
|
|||||||
split_bucket=split_bucket
|
split_bucket=split_bucket
|
||||||
)
|
)
|
||||||
t2 = ttime()
|
t2 = ttime()
|
||||||
|
try:
|
||||||
print("############ 推理 ############")
|
print("############ 推理 ############")
|
||||||
###### inference ######
|
###### inference ######
|
||||||
t_34 = 0.0
|
t_34 = 0.0
|
||||||
@ -741,14 +752,30 @@ class TTS:
|
|||||||
batch_index_list,
|
batch_index_list,
|
||||||
speed_factor,
|
speed_factor,
|
||||||
split_bucket)
|
split_bucket)
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
# 必须返回一个空音频, 否则会导致显存不释放。
|
||||||
|
yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
|
||||||
|
dtype=np.int16)
|
||||||
|
# 重置模型, 否则会导致显存释放不完全。
|
||||||
|
del self.t2s_model
|
||||||
|
del self.vits_model
|
||||||
|
self.t2s_model = None
|
||||||
|
self.vits_model = None
|
||||||
|
self.init_t2s_weights(self.configs.t2s_weights_path)
|
||||||
|
self.init_vits_weights(self.configs.vits_weights_path)
|
||||||
|
finally:
|
||||||
|
self.empty_cache()
|
||||||
|
|
||||||
|
def empty_cache(self):
|
||||||
try:
|
try:
|
||||||
|
if str(self.configs.device) == "cuda":
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
elif str(self.configs.device) == "mps":
|
||||||
|
torch.mps.empty_cache()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def audio_postprocess(self,
|
def audio_postprocess(self,
|
||||||
audio:List[torch.Tensor],
|
audio:List[torch.Tensor],
|
||||||
sr:int,
|
sr:int,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user