diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py index b1d7cbf0..85fcbbdc 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py +++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py @@ -58,15 +58,17 @@ class T2SEngine(T2SEngineProtocol): transient=True, ) as progress, ): + torch_profiler.start() max_token = int(min(2000 - session.input_pos.max(), 1500)) task = progress.add_task("T2S Decoding", total=max_token) for idx in range(max_token): progress.update(task, advance=1) if idx == 0: - session.kv_cache = decoder.init_cache(session.bsz) - xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask) - xy_dec = xy_dec[None, batch_idx, session.input_pos - 1] + with torch_profiler.record("Prefill"): + session.kv_cache = decoder.init_cache(session.bsz) + xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask) + xy_dec = xy_dec[None, batch_idx, session.input_pos - 1] else: if ( request.use_cuda_graph @@ -76,7 +78,7 @@ class T2SEngine(T2SEngineProtocol): ): self.graphcache.assign_graph(session) - with torch_profiler.record("AR"): + with torch_profiler.record("Decode"): if session.graph: assert session.stream session.stream.wait_stream(torch.cuda.default_stream()) @@ -152,7 +154,6 @@ class T2SEngine(T2SEngineProtocol): session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb) if idx == 1: - torch_profiler.start() t1 = time.perf_counter() if idx == 51: diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py index d86d44e4..6ce5b632 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py +++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py @@ -678,6 +678,7 @@ class TorchProfiler: @staticmethod def three_step_schedule(step: int) -> ProfilerAction: + print(111, step) if step == 0: return ProfilerAction.NONE elif step == 1: @@ -688,12 +689,14 @@ class TorchProfiler: return ProfilerAction.NONE def start(self): + print(222) if not self.debug: return assert self.__profiler is not None self.__profiler.step() def end(self): + print(333) if not self.debug: return assert self.__profiler is not None