mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-15 21:26:51 +08:00
.
This commit is contained in:
parent
8b390b4112
commit
34373c5b07
@ -156,7 +156,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
if idx == 1:
|
||||
t1 = time.perf_counter()
|
||||
|
||||
if idx == 51:
|
||||
if idx == 20:
|
||||
torch_profiler.end()
|
||||
|
||||
if idx % 100 == 0:
|
||||
@ -184,7 +184,6 @@ class T2SEngine(T2SEngineProtocol):
|
||||
case "cpu":
|
||||
gc.collect()
|
||||
|
||||
torch_profiler.end()
|
||||
if request.use_cuda_graph and self.graphcache.is_applicable:
|
||||
self.graphcache.release_graph(session)
|
||||
|
||||
|
@ -9,6 +9,7 @@ import os
|
||||
import pickle
|
||||
import platform
|
||||
import random
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
@ -18,7 +19,7 @@ import torch
|
||||
import torch._inductor.config
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.graphs import CUDAGraph
|
||||
from torch.profiler import ProfilerAction, tensorboard_trace_handler
|
||||
from torch.profiler import ExecutionTraceObserver, ProfilerAction, tensorboard_trace_handler
|
||||
|
||||
from tools.my_utils import get_machine_id
|
||||
|
||||
@ -663,7 +664,7 @@ class CUDAGraphCacheABC(ABC):
|
||||
class TorchProfiler:
|
||||
def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
|
||||
self.debug = debug
|
||||
self.log_dir = log_dir
|
||||
self.log_dir = log_dir + str(time.time())
|
||||
self.__profiler: torch.profiler.profile
|
||||
|
||||
if self.debug and not os.path.exists(self.log_dir):
|
||||
@ -678,7 +679,6 @@ class TorchProfiler:
|
||||
|
||||
@staticmethod
|
||||
def three_step_schedule(step: int) -> ProfilerAction:
|
||||
print(111, step)
|
||||
if step == 0:
|
||||
return ProfilerAction.NONE
|
||||
elif step == 1:
|
||||
@ -686,17 +686,15 @@ class TorchProfiler:
|
||||
elif step == 2:
|
||||
return ProfilerAction.RECORD_AND_SAVE
|
||||
else:
|
||||
return ProfilerAction.NONE
|
||||
return ProfilerAction.RECORD_AND_SAVE
|
||||
|
||||
def start(self):
|
||||
print(222)
|
||||
if not self.debug:
|
||||
return
|
||||
assert self.__profiler is not None
|
||||
self.__profiler.step()
|
||||
|
||||
def end(self):
|
||||
print(333)
|
||||
if not self.debug:
|
||||
return
|
||||
assert self.__profiler is not None
|
||||
@ -713,9 +711,13 @@ class TorchProfiler:
|
||||
record_shapes=True,
|
||||
with_stack=True,
|
||||
with_modules=True,
|
||||
with_flops=True,
|
||||
profile_memory=True,
|
||||
schedule=self.three_step_schedule,
|
||||
on_trace_ready=self.profiler_callback,
|
||||
execution_trace_observer=(
|
||||
ExecutionTraceObserver().register_callback(f"{self.log_dir}/execution_trace.json")
|
||||
),
|
||||
)
|
||||
return self.__profiler
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user