feat: make GPU selectable in get_tts_wav

This commit is contained in:
Jacky He 2025-09-02 17:48:44 +08:00
parent fdf794e31d
commit 611ff1e8c0

View File

@ -765,8 +765,17 @@ def get_tts_wav(
sample_steps=8, sample_steps=8,
if_sr=False, if_sr=False,
pause_second=0.3, pause_second=0.3,
device_override=None
): ):
global cache global cache
global device
if device_override:
device = device_override
# Check if models are loaded
if (tokenizer is None or bert_model is None or ssl_model is None or
vq_model is None or t2s_model is None):
raise RuntimeError("Models not loaded. Please call load_models() first.")
if ref_wav_path: if ref_wav_path:
pass pass
else: else: