This commit is contained in:
XXXXRT666 2025-10-05 07:37:28 +01:00
parent 8b390b4112
commit 34373c5b07
2 changed files with 9 additions and 8 deletions

View File

@ -156,7 +156,7 @@ class T2SEngine(T2SEngineProtocol):
if idx == 1: if idx == 1:
t1 = time.perf_counter() t1 = time.perf_counter()
if idx == 51: if idx == 20:
torch_profiler.end() torch_profiler.end()
if idx % 100 == 0: if idx % 100 == 0:
@ -184,7 +184,6 @@ class T2SEngine(T2SEngineProtocol):
case "cpu": case "cpu":
gc.collect() gc.collect()
torch_profiler.end()
if request.use_cuda_graph and self.graphcache.is_applicable: if request.use_cuda_graph and self.graphcache.is_applicable:
self.graphcache.release_graph(session) self.graphcache.release_graph(session)

View File

@ -9,6 +9,7 @@ import os
import pickle import pickle
import platform import platform
import random import random
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path from pathlib import Path
@ -18,7 +19,7 @@ import torch
import torch._inductor.config import torch._inductor.config
import torch.nn.functional as F import torch.nn.functional as F
from torch.cuda.graphs import CUDAGraph from torch.cuda.graphs import CUDAGraph
from torch.profiler import ProfilerAction, tensorboard_trace_handler from torch.profiler import ExecutionTraceObserver, ProfilerAction, tensorboard_trace_handler
from tools.my_utils import get_machine_id from tools.my_utils import get_machine_id
@ -663,7 +664,7 @@ class CUDAGraphCacheABC(ABC):
class TorchProfiler: class TorchProfiler:
def __init__(self, debug: bool, log_dir: str = "./profiler") -> None: def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
self.debug = debug self.debug = debug
self.log_dir = log_dir self.log_dir = log_dir + str(time.time())
self.__profiler: torch.profiler.profile self.__profiler: torch.profiler.profile
if self.debug and not os.path.exists(self.log_dir): if self.debug and not os.path.exists(self.log_dir):
@ -678,7 +679,6 @@ class TorchProfiler:
@staticmethod @staticmethod
def three_step_schedule(step: int) -> ProfilerAction: def three_step_schedule(step: int) -> ProfilerAction:
print(111, step)
if step == 0: if step == 0:
return ProfilerAction.NONE return ProfilerAction.NONE
elif step == 1: elif step == 1:
@ -686,17 +686,15 @@ class TorchProfiler:
elif step == 2: elif step == 2:
return ProfilerAction.RECORD_AND_SAVE return ProfilerAction.RECORD_AND_SAVE
else: else:
return ProfilerAction.NONE return ProfilerAction.RECORD_AND_SAVE
def start(self): def start(self):
print(222)
if not self.debug: if not self.debug:
return return
assert self.__profiler is not None assert self.__profiler is not None
self.__profiler.step() self.__profiler.step()
def end(self): def end(self):
print(333)
if not self.debug: if not self.debug:
return return
assert self.__profiler is not None assert self.__profiler is not None
@ -713,9 +711,13 @@ class TorchProfiler:
record_shapes=True, record_shapes=True,
with_stack=True, with_stack=True,
with_modules=True, with_modules=True,
with_flops=True,
profile_memory=True, profile_memory=True,
schedule=self.three_step_schedule, schedule=self.three_step_schedule,
on_trace_ready=self.profiler_callback, on_trace_ready=self.profiler_callback,
execution_trace_observer=(
ExecutionTraceObserver().register_callback(f"{self.log_dir}/execution_trace.json")
),
) )
return self.__profiler return self.__profiler
else: else: