恢复先前缩进

This commit is contained in:
XTer 2024-04-06 23:16:42 +08:00
parent adb7f71b64
commit 3bfb20763d

View File

@ -280,6 +280,7 @@ class TTS:
if self.configs.is_half and str(self.configs.device)!="cpu":
self.vits_model = self.vits_model.half()
def init_t2s_weights(self, weights_path: str):
print(f"Loading Text2Semantic weights from {weights_path}")
self.configs.t2s_weights_path = weights_path
@ -376,6 +377,7 @@ class TTS:
# self.refer_spec = spec
self.prompt_cache["refer_spec"] = spec
def _set_prompt_semantic(self, ref_wav_path:str):
zero_wav = np.zeros(
int(self.configs.sampling_rate * 0.3),
@ -416,8 +418,7 @@ class TTS:
max_length = max(seq_lengths)
else:
max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length
# 我爱套 torch.no_grad()
# with torch.no_grad():
padded_sequences = []
for seq, length in zip(sequences, seq_lengths):
padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1)
@ -434,7 +435,6 @@ class TTS:
device:torch.device=torch.device("cpu"),
precision:torch.dtype=torch.float32,
):
# 但是这里不能套,反而会负优化
# with torch.no_grad():
_data:list = []
@ -472,6 +472,7 @@ class TTS:
batch_index_list.append([])
batch_index_list[-1].append(i)
for batch_idx, index_list in enumerate(batch_index_list):
item_list = [data[idx] for idx in index_list]
phones_list = []
@ -513,6 +514,7 @@ class TTS:
all_phones_batch = all_phones_list
all_bert_features_batch = all_bert_features_list
# max_len = max(bert_max_len, phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
#### 直接对phones和bert_features进行pad会增大复读概率。
@ -567,6 +569,7 @@ class TTS:
'''
self.stop_flag = True
def run(self, inputs:dict):
"""
Text to speech inference.
@ -881,12 +884,14 @@ class TTS:
audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
audio[i][j] = audio_fragment.cpu().numpy()
if split_bucket:
audio = self.recovery_order(audio, batch_index_list)
else:
# audio = [item for batch in audio for item in batch]
audio = sum(audio, [])
audio = np.concatenate(audio, 0)
audio = (audio * 32768).astype(np.int16)
@ -899,6 +904,8 @@ class TTS:
return sr, audio
def speed_change(input_audio:np.ndarray, speed:float, sr:int):
# 将 NumPy 数组转换为原始 PCM 流
raw_audio = input_audio.astype(np.int16).tobytes()