mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-27 16:28:15 +08:00
.
This commit is contained in:
parent
99fba5154e
commit
8a2448fa3c
@ -51,6 +51,8 @@ class T2SEngine(T2SEngineProtocol):
|
||||
infer_time = 0.0
|
||||
idx = 0
|
||||
graph_state = None
|
||||
bsz = session.bsz
|
||||
result: list[torch.Tensor] = []
|
||||
|
||||
torch_profiler = TorchProfiler(debug)
|
||||
with (
|
||||
@ -65,15 +67,15 @@ class T2SEngine(T2SEngineProtocol):
|
||||
) as progress,
|
||||
):
|
||||
torch_profiler.start()
|
||||
max_token = min(int(1500 - session.input_pos.max()), 1000) * session.bsz
|
||||
max_token = min(int(1500 - session.input_pos.max()), 1000) * bsz
|
||||
task = progress.add_task("T2S Decoding", total=max_token)
|
||||
|
||||
try:
|
||||
for idx in range(max_token):
|
||||
progress.update(task, advance=session.bsz)
|
||||
progress.update(task, advance=bsz)
|
||||
if idx == 0:
|
||||
with torch_profiler.record("Prefill"), timer("Torch.Prefill", debug=debug):
|
||||
session.kv_cache = decoder.init_cache(session.bsz)
|
||||
session.kv_cache = decoder.init_cache(bsz)
|
||||
t1 = time.perf_counter()
|
||||
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
|
||||
xy_dec = xy_dec[batch_idx, None, session.input_pos - 1]
|
||||
@ -86,7 +88,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
and torch.version.cuda is not None
|
||||
and os.environ.get("CUDAGraph", "1") != "0"
|
||||
):
|
||||
graph_state = self.graphcache[session.bsz].assign_graph(session)
|
||||
graph_state = self.graphcache[bsz].assign_graph(session)
|
||||
|
||||
with torch_profiler.record("Decode"), timer("Torch.Decode", debug=debug):
|
||||
if session.graph:
|
||||
@ -148,15 +150,13 @@ class T2SEngine(T2SEngineProtocol):
|
||||
logger.info(
|
||||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
|
||||
)
|
||||
logger.info(
|
||||
f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s"
|
||||
)
|
||||
logger.info(f"Infer Speed: {(idx + 1) * bsz / (time.perf_counter() - t1):.2f} token/s")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx + 1) * session.bsz / infer_time
|
||||
infer_speed = (idx + 1) * bsz / infer_time
|
||||
break
|
||||
|
||||
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
|
||||
for i in range(session.bsz):
|
||||
for i in range(bsz):
|
||||
if not session.completed[i].item():
|
||||
session.y_results[i] = session.y[
|
||||
[i], session.y_len : session.y_len + idx
|
||||
@ -164,7 +164,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
session.completed[i] = True
|
||||
logger.error("Bad Full Prediction")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx + 1) * session.bsz / infer_time
|
||||
infer_speed = (idx + 1) * bsz / infer_time
|
||||
break
|
||||
|
||||
with torch_profiler.record("NextPos"), timer("Torch.NextPos", debug=debug):
|
||||
@ -178,7 +178,13 @@ class T2SEngine(T2SEngineProtocol):
|
||||
timer.summary()
|
||||
timer.clear()
|
||||
|
||||
result = session.y_results[: request.valid_length]
|
||||
|
||||
return result, infer_speed, infer_time, (idx + 1) * bsz
|
||||
|
||||
finally:
|
||||
del session
|
||||
|
||||
if (
|
||||
request.use_cuda_graph
|
||||
and self.graphcache.is_applicable
|
||||
@ -200,8 +206,6 @@ class T2SEngine(T2SEngineProtocol):
|
||||
case "cpu":
|
||||
gc.collect(1)
|
||||
|
||||
return session.y_results[: request.valid_length], infer_speed, infer_time, (idx + 1) * session.bsz
|
||||
|
||||
def generate(self, request: T2SRequest):
|
||||
try:
|
||||
result, infer_speed, infer_time, total_tokens = self._handle_request(request)
|
||||
|
||||
@ -964,6 +964,7 @@ class TTS:
|
||||
"""
|
||||
########## variables initialization ###########
|
||||
torch.set_grad_enabled(False)
|
||||
self.empty_cache()
|
||||
ttft_time = time.perf_counter()
|
||||
self.stop_flag: bool = False
|
||||
text: str = inputs.get("text", "")
|
||||
@ -1334,7 +1335,7 @@ class TTS:
|
||||
gc.collect()
|
||||
|
||||
if self.configs.device.type == "cuda":
|
||||
logger.info(str(torch.cuda.memory_allocated(self.configs.device.index) / 1024**3) + "GB")
|
||||
logger.info("Curr: " + str(torch.cuda.memory_allocated(self.configs.device.index) / 1024**3) + " GB")
|
||||
elif self.configs.device.type == "mps":
|
||||
logger.info("Curr: " + str(torch.mps.current_allocated_memory() / 1024**3) + " GB")
|
||||
logger.info("Driver: " + str(torch.mps.driver_allocated_memory() / 1024**3) + " GB")
|
||||
|
||||
@ -54,7 +54,7 @@ Unseen speakers few-shot fine-tuning demo:
|
||||
| Device | RTF | TTFT | Batch Size | Backend |
|
||||
| :---------: | :----: | :----: | :--------: | :-------------------------: |
|
||||
| RTX 5090 | 0.05 | 150 ms | 1 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 5090 | 0.0122 | UNK | 30 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 5090 | 0.0109 | UNK | 40 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph |
|
||||
| Apple M4 | 0.16 | 1363ms | 1 | MLX Varlen |
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user