Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-10-15 21:26:51 +08:00)

Commit 34373c5b07 (parent 8b390b4112)
@@ -156,7 +156,7 @@ class T2SEngine(T2SEngineProtocol):
             if idx == 1:
                 t1 = time.perf_counter()

-            if idx == 51:
+            if idx == 20:
                 torch_profiler.end()

             if idx % 100 == 0:
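The only functional change in this hunk is the step index at which profiling stops (20 instead of 51), i.e. a shorter profiled window at the start of the decode loop. A minimal sketch of the same idea using torch.profiler directly, with a hypothetical decode_step standing in for one iteration of the engine's loop (none of this is repository code):

    import torch
    from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

    def decode_step(x: torch.Tensor) -> torch.Tensor:
        # stand-in for one autoregressive decode step
        y = x @ x
        return y / (y.norm() + 1e-6)

    x = torch.randn(64, 64)
    with profile(
        activities=[ProfilerActivity.CPU],
        schedule=schedule(wait=1, warmup=2, active=17, repeat=1),  # recording window ends near step 20
        on_trace_ready=tensorboard_trace_handler("./profiler"),
    ) as prof:
        for idx in range(100):
            x = decode_step(x)
            prof.step()  # advance the profiler schedule once per loop iteration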
@@ -184,7 +184,6 @@ class T2SEngine(T2SEngineProtocol):
             case "cpu":
                 gc.collect()

-        torch_profiler.end()
         if request.use_cuda_graph and self.graphcache.is_applicable:
             self.graphcache.release_graph(session)

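For context, the case "cpu": branch sits in a device-specific cleanup block, and the torch_profiler.end() call removed here was presumably redundant once the loop above ends the profiler at idx == 20. A rough sketch of that kind of cleanup dispatch; the "cuda" and "mps" branches are assumptions for illustration, not lines from the repository:

    import gc

    import torch

    def cleanup(device: str) -> None:
        # free per-request memory after generation; only the "cpu" branch is visible in the diff
        match device:
            case "cuda":
                gc.collect()
                torch.cuda.empty_cache()
            case "mps":
                gc.collect()
                torch.mps.empty_cache()
            case "cpu":
                gc.collect()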
@@ -9,6 +9,7 @@ import os
 import pickle
 import platform
 import random
+import time
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
 from pathlib import Path
@@ -18,7 +19,7 @@ import torch
 import torch._inductor.config
 import torch.nn.functional as F
 from torch.cuda.graphs import CUDAGraph
-from torch.profiler import ProfilerAction, tensorboard_trace_handler
+from torch.profiler import ExecutionTraceObserver, ProfilerAction, tensorboard_trace_handler

 from tools.my_utils import get_machine_id

@@ -663,7 +664,7 @@ class CUDAGraphCacheABC(ABC):
 class TorchProfiler:
     def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
         self.debug = debug
-        self.log_dir = log_dir
+        self.log_dir = log_dir + str(time.time())
         self.__profiler: torch.profiler.profile

         if self.debug and not os.path.exists(self.log_dir):
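Appending str(time.time()) gives every run a unique profiler output directory so traces from successive runs do not overwrite each other, which is why this file's import hunk also adds import time. A small illustrative variant with a human-readable timestamp instead of a raw epoch float (not the repository's code):

    import os
    from datetime import datetime

    # e.g. "./profiler20251015-212651" rather than "./profiler1760534811.42"
    log_dir = "./profiler" + datetime.now().strftime("%Y%m%d-%H%M%S")
    os.makedirs(log_dir, exist_ok=True)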
@@ -678,7 +679,6 @@ class TorchProfiler:

     @staticmethod
     def three_step_schedule(step: int) -> ProfilerAction:
-        print(111, step)
         if step == 0:
             return ProfilerAction.NONE
         elif step == 1:
@@ -686,17 +686,15 @@ class TorchProfiler:
         elif step == 2:
             return ProfilerAction.RECORD_AND_SAVE
         else:
-            return ProfilerAction.NONE
+            return ProfilerAction.RECORD_AND_SAVE

     def start(self):
-        print(222)
         if not self.debug:
             return
         assert self.__profiler is not None
         self.__profiler.step()

     def end(self):
-        print(333)
         if not self.debug:
             return
         assert self.__profiler is not None
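With this change the schedule records and saves every step from step 2 onward instead of dropping back to ProfilerAction.NONE, and the stray print debugging calls are removed. As a reminder of how such a schedule is consumed, a hedged sketch (the step == 1 WARMUP branch is an assumption, since that line falls outside the hunks shown):

    import torch
    from torch.profiler import ProfilerAction, profile

    def three_step_schedule(step: int) -> ProfilerAction:
        if step == 0:
            return ProfilerAction.NONE
        if step == 1:
            return ProfilerAction.WARMUP  # assumed; not visible in the diff
        return ProfilerAction.RECORD_AND_SAVE

    # profile() evaluates the schedule at every prof.step() and acts on the returned action
    with profile(schedule=three_step_schedule) as prof:
        for _ in range(5):
            torch.randn(256, 256) @ torch.randn(256, 256)
            prof.step()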
@@ -713,9 +711,13 @@ class TorchProfiler:
                 record_shapes=True,
                 with_stack=True,
                 with_modules=True,
+                with_flops=True,
                 profile_memory=True,
                 schedule=self.three_step_schedule,
                 on_trace_ready=self.profiler_callback,
+                execution_trace_observer=(
+                    ExecutionTraceObserver().register_callback(f"{self.log_dir}/execution_trace.json")
+                ),
             )
             return self.__profiler
         else:
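The two additions turn on per-operator FLOP estimation (with_flops=True) and attach an ExecutionTraceObserver, which writes an execution trace (operator graph plus tensor metadata) as JSON alongside the usual TensorBoard/Kineto trace; register_callback() returns the observer, which is why it can be passed inline. A self-contained sketch of the same construction, keeping only the keyword arguments visible in the hunk and using an assumed ./profiler output directory (not repository code):

    import os

    import torch
    from torch.profiler import (
        ExecutionTraceObserver,
        ProfilerActivity,
        profile,
        tensorboard_trace_handler,
    )

    log_dir = "./profiler"
    os.makedirs(log_dir, exist_ok=True)

    prof = profile(
        activities=[ProfilerActivity.CPU],
        record_shapes=True,
        with_stack=True,
        with_modules=True,
        with_flops=True,
        profile_memory=True,
        on_trace_ready=tensorboard_trace_handler(log_dir),
        # register_callback() returns the observer, so it can be passed inline
        execution_trace_observer=(
            ExecutionTraceObserver().register_callback(f"{log_dir}/execution_trace.json")
        ),
    )

    with prof:
        torch.randn(128, 128) @ torch.randn(128, 128)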