From 34373c5b07e9aa4cd981ce9ad73b790b155f9339 Mon Sep 17 00:00:00 2001 From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com> Date: Sun, 5 Oct 2025 07:37:28 +0100 Subject: [PATCH] . --- GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py | 3 +-- GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py index 85fcbbdc..604ce05a 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py +++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py @@ -156,7 +156,7 @@ class T2SEngine(T2SEngineProtocol): if idx == 1: t1 = time.perf_counter() - if idx == 51: + if idx == 20: torch_profiler.end() if idx % 100 == 0: @@ -184,7 +184,6 @@ class T2SEngine(T2SEngineProtocol): case "cpu": gc.collect() - torch_profiler.end() if request.use_cuda_graph and self.graphcache.is_applicable: self.graphcache.release_graph(session) diff --git a/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py b/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py index 6ce5b632..7b46b6ab 100644 --- a/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py +++ b/GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py @@ -9,6 +9,7 @@ import os import pickle import platform import random +import time from abc import ABC, abstractmethod from contextlib import nullcontext from pathlib import Path @@ -18,7 +19,7 @@ import torch import torch._inductor.config import torch.nn.functional as F from torch.cuda.graphs import CUDAGraph -from torch.profiler import ProfilerAction, tensorboard_trace_handler +from torch.profiler import ExecutionTraceObserver, ProfilerAction, tensorboard_trace_handler from tools.my_utils import get_machine_id @@ -663,7 +664,7 @@ class CUDAGraphCacheABC(ABC): class TorchProfiler: def __init__(self, debug: bool, log_dir: str = "./profiler") -> None: self.debug = debug - self.log_dir = log_dir + self.log_dir = log_dir + str(time.time()) self.__profiler: torch.profiler.profile if self.debug and not os.path.exists(self.log_dir): @@ -678,7 +679,6 @@ class TorchProfiler: @staticmethod def three_step_schedule(step: int) -> ProfilerAction: - print(111, step) if step == 0: return ProfilerAction.NONE elif step == 1: @@ -686,17 +686,15 @@ class TorchProfiler: elif step == 2: return ProfilerAction.RECORD_AND_SAVE else: - return ProfilerAction.NONE + return ProfilerAction.RECORD_AND_SAVE def start(self): - print(222) if not self.debug: return assert self.__profiler is not None self.__profiler.step() def end(self): - print(333) if not self.debug: return assert self.__profiler is not None @@ -713,9 +711,13 @@ class TorchProfiler: record_shapes=True, with_stack=True, with_modules=True, + with_flops=True, profile_memory=True, schedule=self.three_step_schedule, on_trace_ready=self.profiler_callback, + execution_trace_observer=( + ExecutionTraceObserver().register_callback(f"{self.log_dir}/execution_trace.json") + ), ) return self.__profiler else: