mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-10-16 05:36:34 +08:00
.
This commit is contained in:
parent
184ebbf54a
commit
8b390b4112
@ -58,15 +58,17 @@ class T2SEngine(T2SEngineProtocol):
|
||||
transient=True,
|
||||
) as progress,
|
||||
):
|
||||
torch_profiler.start()
|
||||
max_token = int(min(2000 - session.input_pos.max(), 1500))
|
||||
task = progress.add_task("T2S Decoding", total=max_token)
|
||||
|
||||
for idx in range(max_token):
|
||||
progress.update(task, advance=1)
|
||||
if idx == 0:
|
||||
session.kv_cache = decoder.init_cache(session.bsz)
|
||||
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
|
||||
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
|
||||
with torch_profiler.record("Prefill"):
|
||||
session.kv_cache = decoder.init_cache(session.bsz)
|
||||
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
|
||||
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
|
||||
else:
|
||||
if (
|
||||
request.use_cuda_graph
|
||||
@ -76,7 +78,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
):
|
||||
self.graphcache.assign_graph(session)
|
||||
|
||||
with torch_profiler.record("AR"):
|
||||
with torch_profiler.record("Decode"):
|
||||
if session.graph:
|
||||
assert session.stream
|
||||
session.stream.wait_stream(torch.cuda.default_stream())
|
||||
@ -152,7 +154,6 @@ class T2SEngine(T2SEngineProtocol):
|
||||
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
|
||||
|
||||
if idx == 1:
|
||||
torch_profiler.start()
|
||||
t1 = time.perf_counter()
|
||||
|
||||
if idx == 51:
|
||||
|
@ -678,6 +678,7 @@ class TorchProfiler:
|
||||
|
||||
@staticmethod
|
||||
def three_step_schedule(step: int) -> ProfilerAction:
|
||||
print(111, step)
|
||||
if step == 0:
|
||||
return ProfilerAction.NONE
|
||||
elif step == 1:
|
||||
@ -688,12 +689,14 @@ class TorchProfiler:
|
||||
return ProfilerAction.NONE
|
||||
|
||||
def start(self):
|
||||
print(222)
|
||||
if not self.debug:
|
||||
return
|
||||
assert self.__profiler is not None
|
||||
self.__profiler.step()
|
||||
|
||||
def end(self):
|
||||
print(333)
|
||||
if not self.debug:
|
||||
return
|
||||
assert self.__profiler is not None
|
||||
|
Loading…
x
Reference in New Issue
Block a user