mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-15 13:29:51 +08:00
stream_infer: 更方便找规律的图 (a plot that makes it easier to spot patterns)
This commit is contained in:
parent
920bbafb12
commit
03f99256c7
@ -121,7 +121,7 @@ class StreamT2SModel(nn.Module):
|
||||
# if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||
# stop = True
|
||||
if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS:
|
||||
return y[:,:-1], xy_pos, last_token, k_cache, v_cache
|
||||
return y[:,:-1], xy_pos, self.t2s.EOS, k_cache, v_cache
|
||||
|
||||
# if stop:
|
||||
# if y.shape[1] == 0:
|
||||
@ -192,7 +192,6 @@ def test_stream(
|
||||
ref_audio_path,
|
||||
ref_text,
|
||||
output_path,
|
||||
export_bert_and_ssl=False,
|
||||
device="cpu",
|
||||
is_half=True,
|
||||
):
|
||||
@ -286,13 +285,30 @@ def test_stream(
|
||||
audios = []
|
||||
full_audios = []
|
||||
print("y.shape:", y.shape)
|
||||
cut_id = 0
|
||||
while True:
|
||||
y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
|
||||
# print("y.shape:", y.shape)
|
||||
stop = last_token==t2s.EOS
|
||||
print('idx:',idx , 'y.shape:', y.shape, y.shape[1]-idx)
|
||||
|
||||
if last_token < 30 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id:
|
||||
cut_id = idx + 7
|
||||
print('trigger:',idx, last_idx, y[:,-idx+last_idx:], y[:,-idx+last_idx:].shape)
|
||||
# y = torch.cat([y, y[:,-1:]], dim=1)
|
||||
# idx+=1
|
||||
|
||||
if stop :
|
||||
idx -=1
|
||||
print('stop')
|
||||
print(idx, y[:,-idx+last_idx:])
|
||||
print(idx,last_idx, y.shape)
|
||||
print(y[:,-idx:-idx+20])
|
||||
|
||||
|
||||
# 玄学这档子事说不清楚 (this is black magic; there is no clear explanation for why it works)
|
||||
if (y[0,-8] < 30 and idx-last_idx > (len(audios)+1) * 25) or stop:
|
||||
if idx == cut_id or stop:
|
||||
print(f"idx: {idx}, last_idx: {last_idx}, cut_id: {cut_id}, stop: {stop}")
|
||||
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
|
||||
full_audios.append(audio)
|
||||
if last_idx == 0:
|
||||
@ -303,13 +319,10 @@ def test_stream(
|
||||
audio = audio[last_idx*1280 -1280*8:]
|
||||
else:
|
||||
audio = audio[last_idx*1280 -1280*8:-1280*8]
|
||||
print(y[:,-idx+last_idx:])
|
||||
last_idx = idx
|
||||
# print(f'write {output_path}/out_{audio_index}')
|
||||
# soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||
audios.append(audio)
|
||||
|
||||
idx+=1
|
||||
# print(idx,'/',1500 , y.shape, y[0,-1].item(), stop)
|
||||
if idx>1500:
|
||||
break
|
||||
@ -317,6 +330,8 @@ def test_stream(
|
||||
if stop:
|
||||
break
|
||||
|
||||
idx+=1
|
||||
|
||||
at = time.time()
|
||||
|
||||
for (i,a) in enumerate(audios):
|
||||
@ -349,6 +364,10 @@ def test_stream(
|
||||
axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio')
|
||||
axes[-1].set_xlim(0, max_duration)
|
||||
|
||||
for i,y in enumerate(y[0][-idx:]):
|
||||
axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center')
|
||||
axes[-1].axvline(x=i*1280, color='gray', linestyle=':', alpha=0.5)
|
||||
|
||||
# plt.title('Overlapped Waveform Comparison')
|
||||
# plt.xlabel('Sample Number')
|
||||
# plt.ylabel('Amplitude')
|
||||
|
Loading…
x
Reference in New Issue
Block a user