This commit is contained in:
XXXXRT666 2025-10-05 07:37:28 +01:00
parent 8b390b4112
commit 34373c5b07
2 changed files with 9 additions and 8 deletions

View File

@ -156,7 +156,7 @@ class T2SEngine(T2SEngineProtocol):
if idx == 1:
t1 = time.perf_counter()
if idx == 51:
if idx == 20:
torch_profiler.end()
if idx % 100 == 0:
@ -184,7 +184,6 @@ class T2SEngine(T2SEngineProtocol):
case "cpu":
gc.collect()
torch_profiler.end()
if request.use_cuda_graph and self.graphcache.is_applicable:
self.graphcache.release_graph(session)

View File

@ -9,6 +9,7 @@ import os
import pickle
import platform
import random
import time
from abc import ABC, abstractmethod
from contextlib import nullcontext
from pathlib import Path
@ -18,7 +19,7 @@ import torch
import torch._inductor.config
import torch.nn.functional as F
from torch.cuda.graphs import CUDAGraph
from torch.profiler import ProfilerAction, tensorboard_trace_handler
from torch.profiler import ExecutionTraceObserver, ProfilerAction, tensorboard_trace_handler
from tools.my_utils import get_machine_id
@ -663,7 +664,7 @@ class CUDAGraphCacheABC(ABC):
class TorchProfiler:
def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
self.debug = debug
self.log_dir = log_dir
self.log_dir = log_dir + str(time.time())
self.__profiler: torch.profiler.profile
if self.debug and not os.path.exists(self.log_dir):
@ -678,7 +679,6 @@ class TorchProfiler:
@staticmethod
def three_step_schedule(step: int) -> ProfilerAction:
print(111, step)
if step == 0:
return ProfilerAction.NONE
elif step == 1:
@ -686,17 +686,15 @@ class TorchProfiler:
elif step == 2:
return ProfilerAction.RECORD_AND_SAVE
else:
return ProfilerAction.NONE
return ProfilerAction.RECORD_AND_SAVE
def start(self):
print(222)
if not self.debug:
return
assert self.__profiler is not None
self.__profiler.step()
def end(self):
print(333)
if not self.debug:
return
assert self.__profiler is not None
@ -713,9 +711,13 @@ class TorchProfiler:
record_shapes=True,
with_stack=True,
with_modules=True,
with_flops=True,
profile_memory=True,
schedule=self.three_step_schedule,
on_trace_ready=self.profiler_callback,
execution_trace_observer=(
ExecutionTraceObserver().register_callback(f"{self.log_dir}/execution_trace.json")
),
)
return self.__profiler
else: