mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
parent
8dd7cfab93
commit
836bfec1fb
@ -301,7 +301,7 @@ class TTS:
|
|||||||
if self.configs.is_half and str(self.configs.device)!="cpu":
|
if self.configs.is_half and str(self.configs.device)!="cpu":
|
||||||
self.t2s_model = self.t2s_model.half()
|
self.t2s_model = self.t2s_model.half()
|
||||||
|
|
||||||
def enable_half_precision(self, enable: bool = True):
|
def enable_half_precision(self, enable: bool = True, save: bool = True):
|
||||||
'''
|
'''
|
||||||
To enable half precision for the TTS model.
|
To enable half precision for the TTS model.
|
||||||
Args:
|
Args:
|
||||||
@ -314,6 +314,7 @@ class TTS:
|
|||||||
|
|
||||||
self.configs.is_half = enable
|
self.configs.is_half = enable
|
||||||
self.precision = torch.float16 if enable else torch.float32
|
self.precision = torch.float16 if enable else torch.float32
|
||||||
|
if save:
|
||||||
self.configs.save_configs()
|
self.configs.save_configs()
|
||||||
if enable:
|
if enable:
|
||||||
if self.t2s_model is not None:
|
if self.t2s_model is not None:
|
||||||
@ -334,13 +335,14 @@ class TTS:
|
|||||||
if self.cnhuhbert_model is not None:
|
if self.cnhuhbert_model is not None:
|
||||||
self.cnhuhbert_model = self.cnhuhbert_model.float()
|
self.cnhuhbert_model = self.cnhuhbert_model.float()
|
||||||
|
|
||||||
def set_device(self, device: torch.device):
|
def set_device(self, device: torch.device, save: bool = True):
|
||||||
'''
|
'''
|
||||||
To set the device for all models.
|
To set the device for all models.
|
||||||
Args:
|
Args:
|
||||||
device: torch.device, the device to use for all models.
|
device: torch.device, the device to use for all models.
|
||||||
'''
|
'''
|
||||||
self.configs.device = device
|
self.configs.device = device
|
||||||
|
if save:
|
||||||
self.configs.save_configs()
|
self.configs.save_configs()
|
||||||
if self.t2s_model is not None:
|
if self.t2s_model is not None:
|
||||||
self.t2s_model = self.t2s_model.to(device)
|
self.t2s_model = self.t2s_model.to(device)
|
||||||
|
12
api_v3.py
12
api_v3.py
@ -320,7 +320,7 @@ async def tts_handle(req: dict):
|
|||||||
try:
|
try:
|
||||||
tts_instance = get_tts_instance(tts_config)
|
tts_instance = get_tts_instance(tts_config)
|
||||||
|
|
||||||
move_to_gpu(tts_instance, tts_config)
|
move_to_original(tts_instance, tts_config)
|
||||||
|
|
||||||
tts_generator = tts_instance.run(req)
|
tts_generator = tts_instance.run(req)
|
||||||
|
|
||||||
@ -347,13 +347,15 @@ async def tts_handle(req: dict):
|
|||||||
|
|
||||||
def move_to_cpu(tts):
|
def move_to_cpu(tts):
|
||||||
cpu_device = torch.device('cpu')
|
cpu_device = torch.device('cpu')
|
||||||
tts.set_device(cpu_device)
|
tts.set_device(cpu_device, False)
|
||||||
|
tts.enable_half_precision(False, False)
|
||||||
print("Moved TTS models to CPU to save GPU memory.")
|
print("Moved TTS models to CPU to save GPU memory.")
|
||||||
|
|
||||||
|
|
||||||
def move_to_gpu(tts: TTS, tts_config: TTS_Config):
|
def move_to_original(tts: TTS, tts_config: TTS_Config):
|
||||||
tts.set_device(tts_config.device)
|
tts.set_device(tts_config.device, False)
|
||||||
print("Moved TTS models back to GPU for performance.")
|
tts.enable_half_precision(tts_config.is_half, False)
|
||||||
|
print("Moved TTS models back to original device for performance.")
|
||||||
|
|
||||||
|
|
||||||
@APP.get("/control")
|
@APP.get("/control")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user