The local tts.py was not up to date, so an earlier fix got overwritten. Roll back to the latest upstream version and add v2 support.
parent 16a37b7b48
commit bd53aa8200
@@ -1,6 +1,6 @@
 from copy import deepcopy
 import math
-import os, sys
+import os, sys, gc
 import random
 import traceback

@@ -209,7 +209,7 @@ class TTS:
         self.text_preprocessor: TextPreprocessor = \
                             TextPreprocessor(self.bert_model,
                                              self.bert_tokenizer,
-                                             self.configs.device, version=self.version)
+                                             self.configs.device)

         self.prompt_cache: dict = {
             "ref_audio_path": None,
@@ -301,7 +301,7 @@ class TTS:
         if self.configs.is_half and str(self.configs.device) != "cpu":
             self.t2s_model = self.t2s_model.half()

-    def enable_half_precision(self, enable: bool = True):
+    def enable_half_precision(self, enable: bool = True, save: bool = True):
         '''
             To enable half precision for the TTS model.
             Args:
@@ -314,6 +314,7 @@ class TTS:

         self.configs.is_half = enable
         self.precision = torch.float16 if enable else torch.float32
-        self.configs.save_configs()
+        if save:
+            self.configs.save_configs()
         if enable:
             if self.t2s_model is not None:
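With the new save flag, callers can flip precision for a single session without self.configs.save_configs() rewriting the config file. A minimal usage sketch, assuming GPT_SoVITS/ is put on sys.path the way the repo's API scripts do; the config path is illustrative:

import sys

# Assumption: GPT_SoVITS/ is added to the import path, as in the API scripts.
sys.path.append("GPT_SoVITS")
from TTS_infer_pack.TTS import TTS, TTS_Config

tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")  # illustrative path
tts = TTS(tts_config)

# Cast the loaded models to float16 for this process only; save=False keeps
# self.configs.save_configs() from touching the on-disk config, while the old
# single-argument call (now save=True by default) still persists the change.
tts.enable_half_precision(True, save=False)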
@@ -334,13 +335,14 @@ class TTS:
            if self.cnhuhbert_model is not None:
                self.cnhuhbert_model = self.cnhuhbert_model.float()

-    def set_device(self, device: torch.device):
+    def set_device(self, device: torch.device, save: bool = True):
         '''
             To set the device for all models.
             Args:
                 device: torch.device, the device to use for all models.
         '''
         self.configs.device = device
-        self.configs.save_configs()
+        if save:
+            self.configs.save_configs()
         if self.t2s_model is not None:
             self.t2s_model = self.t2s_model.to(device)
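set_device gets the same treatment, so a temporary device move is not written back to disk either. Continuing the sketch above (device strings are the usual torch ones; the CUDA check is just defensive):

import torch

# Hypothetical: run one batch on the GPU, then fall back to CPU, with neither
# move persisted by self.configs.save_configs().
if torch.cuda.is_available():
    tts.set_device(torch.device("cuda:0"), save=False)
# ... inference here ...
tts.set_device(torch.device("cpu"), save=False)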
@@ -360,6 +362,10 @@ class TTS:
         '''
         self._set_prompt_semantic(ref_audio_path)
         self._set_ref_spec(ref_audio_path)
+        self._set_ref_audio_path(ref_audio_path)

+    def _set_ref_audio_path(self, ref_audio_path):
+        self.prompt_cache["ref_audio_path"] = ref_audio_path
+
     def _set_ref_spec(self, ref_audio_path):
         audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
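Caching the reference path in prompt_cache lets callers skip re-extracting prompt features when the same reference audio is requested twice. A sketch of that check, assuming the public entry point is set_ref_audio as in the current upstream file; the helper and the path below are illustrative, not part of this commit:

def ensure_ref_audio(tts, ref_audio_path: str) -> None:
    # Illustrative helper: only re-run reference feature extraction when the
    # cached path differs from the one being requested.
    if tts.prompt_cache.get("ref_audio_path") != ref_audio_path:
        tts.set_ref_audio(ref_audio_path)

ensure_ref_audio(tts, "ref_audio/speaker01.wav")  # hypothetical path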
@@ -801,7 +807,6 @@ class TTS:
                 audio_frag_end_idx = [sum(audio_frag_idx[:i + 1]) for i in range(0, len(audio_frag_idx))]
                 all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
                 _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
-
                 _batch_audio_fragment = (self.vits_model.decode(
                         all_pred_semantic, _batch_phones, refer_audio_spec
                     ).detach()[0, 0, :])
@@ -867,6 +872,7 @@ class TTS:

     def empty_cache(self):
         try:
+            gc.collect()  # Trigger gc's garbage collection to keep memory from growing indefinitely.
             if "cuda" in str(self.configs.device):
                 torch.cuda.empty_cache()
             elif str(self.configs.device) == "mps":
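The added gc.collect() drops unreachable Python objects before the backend cache is released, which is what actually stops a long-running server from creeping up in memory. A standalone sketch of the same cleanup pattern, assuming a recent PyTorch for the mps branch (torch.mps.empty_cache() is not present in older releases):

import gc
import torch

def release_memory(device: torch.device) -> None:
    # Collect garbage first so the allocator has freed blocks to return.
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
        torch.mps.empty_cache()

release_memory(torch.device("cpu"))  # on CPU-only setups this is just the gc pass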