GPT-SoVITS/GPT_SoVITS/TTS_infer_pack/tts_instance_pool.py
2024-05-27 16:20:06 +08:00

134 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import threading
from time import perf_counter
import traceback
from typing import Dict, Union
from GPT_SoVITS.TTS_infer_pack.TTS import TTS, TTS_Config
class TTSWrapper(TTS):
    """TTS engine that tracks how often and how long it is used.

    ``heat`` is the number of ``run()`` calls per wall-clock second since
    ``first_used_time``. TTSInstancePool.acquire() recycles the idle instance
    with the lowest heat when it has to load different model weights.
    """

    heat: float = 0.0  # run() calls per second since first_used_time
    usage_count: int = 0  # NOTE(review): never incremented in this class; only reset
    usage_counter: int = 0  # run() calls since creation or the last reset_heat()
    usage_time: float = 0.0  # seconds spent inside run() since creation / last reset
    first_used_time: float = 0.0  # perf_counter() value at creation / last reset

    def __init__(self, configs: Union[dict, str, TTS_Config]):
        super().__init__(configs)
        self.first_used_time = perf_counter()

    def __hash__(self) -> int:
        # Identity-based so the key is stable for the object's whole lifetime.
        # reset_heat() rewrites first_used_time, so hashing that field would
        # change the key of a live object.
        return id(self)

    def run(self, *args, **kwargs):
        """Yield the results of ``TTS.run`` while updating usage statistics.

        The statistics are updated in ``finally``, so runs that the consumer
        abandons early, or that raise, are still counted.
        """
        self.usage_counter += 1
        started = perf_counter()
        try:
            yield from super().run(*args, **kwargs)
        finally:
            finished = perf_counter()
            self.usage_time += finished - started
            # Heat = runs per wall-clock second since creation/reset. Both
            # operands are perf_counter() timestamps, so the difference is a
            # non-negative duration. Guard the (theoretical) zero case.
            elapsed = finished - self.first_used_time
            if elapsed > 0:
                self.heat = self.usage_counter / elapsed

    def reset_heat(self):
        """Restart all usage statistics, e.g. after different weights are loaded."""
        self.heat = 0.0
        self.usage_count = 0
        # This is the counter run() increments; it must be reset as well, or
        # heat would be inflated right after a reset.
        self.usage_counter = 0
        self.usage_time = 0.0
        self.first_used_time = perf_counter()
class TTSInstancePool:
    """Thread-safe pool of at most ``max_size`` TTSWrapper instances.

    A semaphore bounds how many instances may be checked out at once. Idle
    instances live in ``self.pool``, keyed by ``hash(instance)``.
    """

    def __init__(self, max_size):
        self.max_size: int = max_size
        # One permit per instance that may be checked out concurrently.
        self.semaphore: threading.Semaphore = threading.Semaphore(max_size)
        # Guards ``pool`` and ``size``.
        self.pool_lock: threading.Lock = threading.Lock()
        # Idle instances, keyed by hash(instance).
        self.pool: Dict[int, TTSWrapper] = dict()
        self.current_index: int = 0  # NOTE(review): not used by this class
        # Number of live instances (idle + checked out) created by this pool.
        self.size: int = 0

    def acquire(self, configs: "TTS_Config"):
        """Check out an instance configured for ``configs``.

        Blocks until a permit is available, then prefers, in order:

        1. an idle instance already loaded with the same VITS and T2S weights;
        2. a brand-new instance, while fewer than ``max_size`` exist;
        3. the idle instance with the lowest ``heat``, reloaded with the
           requested weights.

        The caller must hand the instance back with ``release()``.

        Raises:
            ValueError: if no idle instance is available to recycle.
            Exception: anything raised while building or reconfiguring an
                instance. The permit is returned before re-raising.
        """
        self.semaphore.acquire()
        try:
            with self.pool_lock:
                # Look for an idle instance with exactly the requested weights.
                indexed_key = None
                for key, tts_instance in self.pool.items():
                    if (tts_instance.configs.vits_weights_path == configs.vits_weights_path
                            and tts_instance.configs.t2s_weights_path == configs.t2s_weights_path):
                        indexed_key = key
                # Coldest idle instance (first one on ties), used for recycling.
                coldest_key = min(self.pool, key=lambda k: self.pool[k].heat, default=None)

                # An instance already has the right weights: reuse it directly.
                if indexed_key is not None:
                    tts_instance = self._reuse_instance(indexed_key, configs)
                    print(f"如果已有实例匹配,则直接复用: {configs.vits_weights_path} {configs.t2s_weights_path}")
                    return tts_instance

                # The pool is not full yet: create a new instance.
                if self.size < self.max_size:
                    tts_instance = TTSWrapper(configs)
                    self.size += 1
                    print(f"如果pool未满则创建一个新实例: {configs.vits_weights_path} {configs.t2s_weights_path}")
                    return tts_instance

                # Otherwise recycle the least-used idle instance.
                tts_instance = self._reuse_instance(coldest_key, configs)
                print(f"否则用最合适的实例进行复用: {configs.vits_weights_path} {configs.t2s_weights_path}")
                return tts_instance
        except Exception:
            self.semaphore.release()
            traceback.print_exc()
            raise

    def release(self, tts_instance: "TTSWrapper"):
        """Return a checked-out instance to the pool and free its permit.

        If an entry with the same key is already in the pool, nothing
        happens. This way an accidental double release does not
        over-release the semaphore.
        """
        assert tts_instance is not None
        with self.pool_lock:
            key = hash(tts_instance)
            if key in self.pool:
                return
            self.pool[key] = tts_instance
        self.semaphore.release()

    def clear_pool(self):
        """Drop every instance once all of them have been returned.

        Blocks until all ``max_size`` permits can be taken, i.e. nothing is
        checked out. It then empties the pool and resets ``size``, so later
        ``acquire()`` calls build fresh instances.
        """
        for _ in range(self.max_size):
            self.semaphore.acquire()
        try:
            with self.pool_lock:
                self.pool.clear()
                # Without this, acquire() would still think the pool is at
                # capacity and try to recycle instances that no longer exist.
                self.size = 0
        finally:
            self.semaphore.release(self.max_size)

    def _reuse_instance(self, instance_key: int, configs: "TTS_Config") -> "TTSWrapper":
        """Take an idle instance out of the pool and reconfigure it.

        Called by ``acquire()`` with ``pool_lock`` held.

        Args:
            instance_key: pool key of the idle instance to reuse.
            configs: the requested TTS_Config. Its device and weight paths
                are applied to the instance.

        Returns:
            TTSWrapper: the reconfigured instance, no longer in the pool.

        Raises:
            ValueError: if ``instance_key`` is not in the pool.
        """
        tts_instance = self.pool.pop(instance_key, None)
        if tts_instance is None:
            raise ValueError("Instance not found")
        try:
            tts_instance.configs.device = configs.device
            if tts_instance.configs.vits_weights_path != configs.vits_weights_path \
                    or tts_instance.configs.t2s_weights_path != configs.t2s_weights_path:
                # Different model weights: the old usage statistics no longer apply.
                tts_instance.reset_heat()
            if tts_instance.configs.vits_weights_path != configs.vits_weights_path:
                tts_instance.init_vits_weights(configs.vits_weights_path, False)
                tts_instance.configs.vits_weights_path = configs.vits_weights_path
            if tts_instance.configs.t2s_weights_path != configs.t2s_weights_path:
                tts_instance.init_t2s_weights(configs.t2s_weights_path, False)
                tts_instance.configs.t2s_weights_path = configs.t2s_weights_path
            tts_instance.set_device(configs.device, False)
        except Exception:
            # The half-reconfigured instance has already left the pool and is
            # discarded. Shrink the live count so acquire() can build a
            # replacement instead of trying to recycle a missing instance.
            self.size -= 1
            raise
        return tts_instance