mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-16 05:36:34 +08:00

This commit is contained in:
parent 184ebbf54a
commit 8b390b4112
@@ -58,15 +58,17 @@ class T2SEngine(T2SEngineProtocol):
                 transient=True,
             ) as progress,
         ):
+            torch_profiler.start()
             max_token = int(min(2000 - session.input_pos.max(), 1500))
             task = progress.add_task("T2S Decoding", total=max_token)

             for idx in range(max_token):
                 progress.update(task, advance=1)
                 if idx == 0:
-                    session.kv_cache = decoder.init_cache(session.bsz)
-                    xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
-                    xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
+                    with torch_profiler.record("Prefill"):
+                        session.kv_cache = decoder.init_cache(session.bsz)
+                        xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
+                        xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
                 else:
                     if (
                         request.use_cuda_graph
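
The hunk above wraps the prefill pass in torch_profiler.record("Prefill"). The TorchProfiler.record helper itself is not shown in this commit; as a rough sketch only, an equivalent labelled-region helper can be built on torch.profiler.record_function (the standalone record() function and its enabled flag below are illustrative, not the repository's actual implementation):

from contextlib import contextmanager

import torch.profiler


@contextmanager
def record(name: str, enabled: bool = True):
    # Tag a code region so it appears as a named range in profiler traces.
    if not enabled:
        # Profiling disabled: run the body with no bookkeeping at all.
        yield
        return
    with torch.profiler.record_function(name):
        yield

Used as `with record("Prefill"): ...`, the wrapped region shows up under that label in the exported trace.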
@@ -76,7 +78,7 @@ class T2SEngine(T2SEngineProtocol):
                     ):
                         self.graphcache.assign_graph(session)

-                    with torch_profiler.record("AR"):
+                    with torch_profiler.record("Decode"):
                         if session.graph:
                             assert session.stream
                             session.stream.wait_stream(torch.cuda.default_stream())
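
In the decode branch above, session.stream.wait_stream(torch.cuda.default_stream()) makes the graph-replay stream wait for work already queued on the default stream before the captured CUDA graph is launched. A minimal sketch of that capture/replay pattern, with a toy model standing in for the real decoder (the names model, static_in, and decode_stream are illustrative only):

import torch

device = "cuda"
model = torch.nn.Linear(512, 512).to(device)
static_in = torch.randn(8, 512, device=device)

# Warm up on a side stream, as recommended before CUDA graph capture.
warmup_stream = torch.cuda.Stream()
warmup_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(warmup_stream):
    static_out = model(static_in)
torch.cuda.current_stream().wait_stream(warmup_stream)

# Capture one step; later replays rerun the same kernels on whatever data
# currently sits in static_in.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# Replay on a dedicated stream, mirroring the handshake in the diff: the side
# stream must not start until the default stream's pending work is ordered first.
decode_stream = torch.cuda.Stream()
decode_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(decode_stream):
    static_in.copy_(torch.randn(8, 512, device=device))
    graph.replay()
torch.cuda.default_stream().wait_stream(decode_stream)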
@@ -152,7 +154,6 @@ class T2SEngine(T2SEngineProtocol):
                 session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)

                 if idx == 1:
-                    torch_profiler.start()
                     t1 = time.perf_counter()

                 if idx == 51:
@@ -678,6 +678,7 @@ class TorchProfiler:

     @staticmethod
     def three_step_schedule(step: int) -> ProfilerAction:
+        print(111, step)
         if step == 0:
             return ProfilerAction.NONE
         elif step == 1:
@@ -688,12 +689,14 @@ class TorchProfiler:
             return ProfilerAction.NONE

     def start(self):
+        print(222)
         if not self.debug:
             return
         assert self.__profiler is not None
         self.__profiler.step()

     def end(self):
+        print(333)
         if not self.debug:
             return
         assert self.__profiler is not None
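
The TorchProfiler hunks touch a step-indexed schedule (three_step_schedule returning a ProfilerAction) and start()/end() methods that advance the underlying profiler via self.__profiler.step(). A minimal sketch of how such a schedule callable plugs into torch.profiler.profile, assuming a simple warmup-then-record layout for the steps this hunk does not show:

import torch
from torch.profiler import ProfilerAction, ProfilerActivity, profile


def three_step_schedule(step: int) -> ProfilerAction:
    # Assumed layout: skip step 0, warm up on step 1, record and save on
    # step 2, then stay idle; the real schedule's middle branches are not
    # visible in the hunk above.
    if step == 0:
        return ProfilerAction.NONE
    elif step == 1:
        return ProfilerAction.WARMUP
    elif step == 2:
        return ProfilerAction.RECORD_AND_SAVE
    return ProfilerAction.NONE


model = torch.nn.Linear(256, 256)
x = torch.randn(32, 256)

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=three_step_schedule,
    on_trace_ready=lambda prof: print(prof.key_averages().table(row_limit=5)),
) as prof:
    for _ in range(4):
        model(x)
        prof.step()  # advances the schedule, as self.__profiler.step() does above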