diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py index 1d348802..1a848d45 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py +++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py @@ -51,6 +51,8 @@ class T2SEngine(T2SEngineProtocol): infer_time = 0.0 idx = 0 graph_state = None + bsz = session.bsz + result: list[torch.Tensor] = [] torch_profiler = TorchProfiler(debug) with ( @@ -65,15 +67,15 @@ class T2SEngine(T2SEngineProtocol): ) as progress, ): torch_profiler.start() - max_token = min(int(1500 - session.input_pos.max()), 1000) * session.bsz + max_token = min(int(1500 - session.input_pos.max()), 1000) * bsz task = progress.add_task("T2S Decoding", total=max_token) try: for idx in range(max_token): - progress.update(task, advance=session.bsz) + progress.update(task, advance=bsz) if idx == 0: with torch_profiler.record("Prefill"), timer("Torch.Prefill", debug=debug): - session.kv_cache = decoder.init_cache(session.bsz) + session.kv_cache = decoder.init_cache(bsz) t1 = time.perf_counter() xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask) xy_dec = xy_dec[batch_idx, None, session.input_pos - 1] @@ -86,7 +88,7 @@ class T2SEngine(T2SEngineProtocol): and torch.version.cuda is not None and os.environ.get("CUDAGraph", "1") != "0" ): - graph_state = self.graphcache[session.bsz].assign_graph(session) + graph_state = self.graphcache[bsz].assign_graph(session) with torch_profiler.record("Decode"), timer("Torch.Decode", debug=debug): if session.graph: @@ -148,15 +150,13 @@ class T2SEngine(T2SEngineProtocol): logger.info( f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}" ) - logger.info( - f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s" - ) + logger.info(f"Infer Speed: {(idx + 1) * bsz / (time.perf_counter() - t1):.2f} token/s") infer_time = time.perf_counter() - t1 - infer_speed = (idx + 1) * session.bsz / infer_time + infer_speed = (idx + 1) * bsz / infer_time break if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1: - for i in range(session.bsz): + for i in range(bsz): if not session.completed[i].item(): session.y_results[i] = session.y[ [i], session.y_len : session.y_len + idx @@ -164,7 +164,7 @@ class T2SEngine(T2SEngineProtocol): session.completed[i] = True logger.error("Bad Full Prediction") infer_time = time.perf_counter() - t1 - infer_speed = (idx + 1) * session.bsz / infer_time + infer_speed = (idx + 1) * bsz / infer_time break with torch_profiler.record("NextPos"), timer("Torch.NextPos", debug=debug): @@ -178,7 +178,13 @@ class T2SEngine(T2SEngineProtocol): timer.summary() timer.clear() + result = session.y_results[: request.valid_length] + + return result, infer_speed, infer_time, (idx + 1) * bsz + finally: + del session + if ( request.use_cuda_graph and self.graphcache.is_applicable @@ -200,8 +206,6 @@ class T2SEngine(T2SEngineProtocol): case "cpu": gc.collect(1) - return session.y_results[: request.valid_length], infer_speed, infer_time, (idx + 1) * session.bsz - def generate(self, request: T2SRequest): try: result, infer_speed, infer_time, total_tokens = self._handle_request(request) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index aaa661d8..1ed238f9 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -964,6 +964,7 @@ class TTS: """ ########## variables initialization ########### torch.set_grad_enabled(False) + self.empty_cache() ttft_time = time.perf_counter() self.stop_flag: bool = False text: str = inputs.get("text", "") @@ -1334,7 +1335,7 @@ class TTS: gc.collect() if self.configs.device.type == "cuda": - logger.info(str(torch.cuda.memory_allocated(self.configs.device.index) / 1024**3) + "GB") + logger.info("Curr: " + str(torch.cuda.memory_allocated(self.configs.device.index) / 1024**3) + " GB") elif self.configs.device.type == "mps": logger.info("Curr: " + str(torch.mps.current_allocated_memory() / 1024**3) + " GB") logger.info("Driver: " + str(torch.mps.driver_allocated_memory() / 1024**3) + " GB") diff --git a/README.md b/README.md index 3b5aa06f..12f682a9 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Unseen speakers few-shot fine-tuning demo: | Device | RTF | TTFT | Batch Size | Backend | | :---------: | :----: | :----: | :--------: | :-------------------------: | | RTX 5090 | 0.05 | 150 ms | 1 | Flash Attn Varlen CUDAGraph | -| RTX 5090 | 0.0122 | UNK | 30 | Flash Attn Varlen CUDAGraph | +| RTX 5090 | 0.0109 | UNK | 40 | Flash Attn Varlen CUDAGraph | | RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph | | RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph | | Apple M4 | 0.16 | 1363ms | 1 | MLX Varlen |