mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-15 21:39:50 +08:00
stream_infer: 更方便找规律的图
This commit is contained in:
parent
920bbafb12
commit
03f99256c7
@ -121,7 +121,7 @@ class StreamT2SModel(nn.Module):
|
|||||||
# if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
# if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||||
# stop = True
|
# stop = True
|
||||||
if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS:
|
if torch.argmax(logits, dim=-1)[0] == self.t2s.EOS or samples[0, 0] == self.t2s.EOS:
|
||||||
return y[:,:-1], xy_pos, last_token, k_cache, v_cache
|
return y[:,:-1], xy_pos, self.t2s.EOS, k_cache, v_cache
|
||||||
|
|
||||||
# if stop:
|
# if stop:
|
||||||
# if y.shape[1] == 0:
|
# if y.shape[1] == 0:
|
||||||
@ -192,7 +192,6 @@ def test_stream(
|
|||||||
ref_audio_path,
|
ref_audio_path,
|
||||||
ref_text,
|
ref_text,
|
||||||
output_path,
|
output_path,
|
||||||
export_bert_and_ssl=False,
|
|
||||||
device="cpu",
|
device="cpu",
|
||||||
is_half=True,
|
is_half=True,
|
||||||
):
|
):
|
||||||
@ -286,13 +285,30 @@ def test_stream(
|
|||||||
audios = []
|
audios = []
|
||||||
full_audios = []
|
full_audios = []
|
||||||
print("y.shape:", y.shape)
|
print("y.shape:", y.shape)
|
||||||
|
cut_id = 0
|
||||||
while True:
|
while True:
|
||||||
y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
|
y, xy_pos, last_token, k_cache, v_cache = stream_t2s(idx, top_k, y_len, y, xy_pos, k_cache, v_cache)
|
||||||
# print("y.shape:", y.shape)
|
# print("y.shape:", y.shape)
|
||||||
stop = last_token==t2s.EOS
|
stop = last_token==t2s.EOS
|
||||||
|
print('idx:',idx , 'y.shape:', y.shape, y.shape[1]-idx)
|
||||||
|
|
||||||
|
if last_token < 30 and idx-last_idx > (len(audios)+1) * 25 and idx > cut_id:
|
||||||
|
cut_id = idx + 7
|
||||||
|
print('trigger:',idx, last_idx, y[:,-idx+last_idx:], y[:,-idx+last_idx:].shape)
|
||||||
|
# y = torch.cat([y, y[:,-1:]], dim=1)
|
||||||
|
# idx+=1
|
||||||
|
|
||||||
|
if stop :
|
||||||
|
idx -=1
|
||||||
|
print('stop')
|
||||||
|
print(idx, y[:,-idx+last_idx:])
|
||||||
|
print(idx,last_idx, y.shape)
|
||||||
|
print(y[:,-idx:-idx+20])
|
||||||
|
|
||||||
|
|
||||||
# 玄学这档子事说不清楚
|
# 玄学这档子事说不清楚
|
||||||
if (y[0,-8] < 30 and idx-last_idx > (len(audios)+1) * 25) or stop:
|
if idx == cut_id or stop:
|
||||||
|
print(f"idx: {idx}, last_idx: {last_idx}, cut_id: {cut_id}, stop: {stop}")
|
||||||
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
|
audio = vits.vq_model(y[:,-idx:].unsqueeze(0), text_seq, refer, speed=1.0, sv_emb=sv_emb)[0, 0]
|
||||||
full_audios.append(audio)
|
full_audios.append(audio)
|
||||||
if last_idx == 0:
|
if last_idx == 0:
|
||||||
@ -303,13 +319,10 @@ def test_stream(
|
|||||||
audio = audio[last_idx*1280 -1280*8:]
|
audio = audio[last_idx*1280 -1280*8:]
|
||||||
else:
|
else:
|
||||||
audio = audio[last_idx*1280 -1280*8:-1280*8]
|
audio = audio[last_idx*1280 -1280*8:-1280*8]
|
||||||
print(y[:,-idx+last_idx:])
|
|
||||||
last_idx = idx
|
last_idx = idx
|
||||||
# print(f'write {output_path}/out_{audio_index}')
|
# print(f'write {output_path}/out_{audio_index}')
|
||||||
# soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
|
# soundfile.write(f"{output_path}/out_{audio_index}.wav", audio.float().detach().cpu().numpy(), 32000)
|
||||||
audios.append(audio)
|
audios.append(audio)
|
||||||
|
|
||||||
idx+=1
|
|
||||||
# print(idx,'/',1500 , y.shape, y[0,-1].item(), stop)
|
# print(idx,'/',1500 , y.shape, y[0,-1].item(), stop)
|
||||||
if idx>1500:
|
if idx>1500:
|
||||||
break
|
break
|
||||||
@ -317,6 +330,8 @@ def test_stream(
|
|||||||
if stop:
|
if stop:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
idx+=1
|
||||||
|
|
||||||
at = time.time()
|
at = time.time()
|
||||||
|
|
||||||
for (i,a) in enumerate(audios):
|
for (i,a) in enumerate(audios):
|
||||||
@ -349,6 +364,10 @@ def test_stream(
|
|||||||
axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio')
|
axes[-1].plot(audio.float().detach().cpu().numpy(), color='black', label='Final Audio')
|
||||||
axes[-1].set_xlim(0, max_duration)
|
axes[-1].set_xlim(0, max_duration)
|
||||||
|
|
||||||
|
for i,y in enumerate(y[0][-idx:]):
|
||||||
|
axes[-1].text(i*1280, 0.05, str(int(y)), fontsize=12, ha='center')
|
||||||
|
axes[-1].axvline(x=i*1280, color='gray', linestyle=':', alpha=0.5)
|
||||||
|
|
||||||
# plt.title('Overlapped Waveform Comparison')
|
# plt.title('Overlapped Waveform Comparison')
|
||||||
# plt.xlabel('Sample Number')
|
# plt.xlabel('Sample Number')
|
||||||
# plt.ylabel('Amplitude')
|
# plt.ylabel('Amplitude')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user