This commit is contained in:
XXXXRT666 2025-10-05 07:10:14 +01:00
parent 184ebbf54a
commit 8b390b4112
2 changed files with 9 additions and 5 deletions

View File

@ -58,15 +58,17 @@ class T2SEngine(T2SEngineProtocol):
transient=True,
) as progress,
):
torch_profiler.start()
max_token = int(min(2000 - session.input_pos.max(), 1500))
task = progress.add_task("T2S Decoding", total=max_token)
for idx in range(max_token):
progress.update(task, advance=1)
if idx == 0:
session.kv_cache = decoder.init_cache(session.bsz)
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
with torch_profiler.record("Prefill"):
session.kv_cache = decoder.init_cache(session.bsz)
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
else:
if (
request.use_cuda_graph
@ -76,7 +78,7 @@ class T2SEngine(T2SEngineProtocol):
):
self.graphcache.assign_graph(session)
with torch_profiler.record("AR"):
with torch_profiler.record("Decode"):
if session.graph:
assert session.stream
session.stream.wait_stream(torch.cuda.default_stream())
@ -152,7 +154,6 @@ class T2SEngine(T2SEngineProtocol):
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
if idx == 1:
torch_profiler.start()
t1 = time.perf_counter()
if idx == 51:

View File

@ -678,6 +678,7 @@ class TorchProfiler:
@staticmethod
def three_step_schedule(step: int) -> ProfilerAction:
print(111, step)
if step == 0:
return ProfilerAction.NONE
elif step == 1:
@ -688,12 +689,14 @@ class TorchProfiler:
return ProfilerAction.NONE
def start(self):
print(222)
if not self.debug:
return
assert self.__profiler is not None
self.__profiler.step()
def end(self):
print(333)
if not self.debug:
return
assert self.__profiler is not None