恢复先前缩进

This commit is contained in:
XTer 2024-04-06 23:16:42 +08:00
parent adb7f71b64
commit 3bfb20763d

View File

@ -280,6 +280,7 @@ class TTS:
if self.configs.is_half and str(self.configs.device)!="cpu":
self.vits_model = self.vits_model.half()
def init_t2s_weights(self, weights_path: str):
print(f"Loading Text2Semantic weights from {weights_path}")
self.configs.t2s_weights_path = weights_path
@ -376,6 +377,7 @@ class TTS:
# self.refer_spec = spec
self.prompt_cache["refer_spec"] = spec
def _set_prompt_semantic(self, ref_wav_path:str):
zero_wav = np.zeros(
int(self.configs.sampling_rate * 0.3),
@ -416,8 +418,7 @@ class TTS:
max_length = max(seq_lengths)
else:
max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length
# 我爱套 torch.no_grad()
# with torch.no_grad():
padded_sequences = []
for seq, length in zip(sequences, seq_lengths):
padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1)
@ -434,114 +435,115 @@ class TTS:
device:torch.device=torch.device("cpu"),
precision:torch.dtype=torch.float32,
):
# 但是这里不能套,反而会负优化
# with torch.no_grad():
_data:list = []
index_and_len_list = []
for idx, item in enumerate(data):
norm_text_len = len(item["norm_text"])
index_and_len_list.append([idx, norm_text_len])
_data:list = []
index_and_len_list = []
for idx, item in enumerate(data):
norm_text_len = len(item["norm_text"])
index_and_len_list.append([idx, norm_text_len])
batch_index_list = []
if split_bucket:
index_and_len_list.sort(key=lambda x: x[1])
index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
batch_index_list = []
if split_bucket:
index_and_len_list.sort(key=lambda x: x[1])
index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
batch_index_list_len = 0
pos = 0
while pos <index_and_len_list.shape[0]:
# batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
while pos < pos_end:
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8)
if (score>=threshold) or (pos_end-pos==1):
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
batch_index_list_len += len(batch_index)
batch_index_list.append(batch_index)
pos = pos_end
break
pos_end=pos_end-1
batch_index_list_len = 0
pos = 0
while pos <index_and_len_list.shape[0]:
# batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
while pos < pos_end:
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8)
if (score>=threshold) or (pos_end-pos==1):
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
batch_index_list_len += len(batch_index)
batch_index_list.append(batch_index)
pos = pos_end
break
pos_end=pos_end-1
assert batch_index_list_len == len(data)
assert batch_index_list_len == len(data)
else:
for i in range(len(data)):
if i%batch_size == 0:
batch_index_list.append([])
batch_index_list[-1].append(i)
else:
for i in range(len(data)):
if i%batch_size == 0:
batch_index_list.append([])
batch_index_list[-1].append(i)
for batch_idx, index_list in enumerate(batch_index_list):
item_list = [data[idx] for idx in index_list]
phones_list = []
phones_len_list = []
# bert_features_list = []
all_phones_list = []
all_phones_len_list = []
all_bert_features_list = []
norm_text_batch = []
bert_max_len = 0
phones_max_len = 0
# 但是这里也不能套,反而会负优化
# with torch.no_grad():
for item in item_list:
if prompt_data is not None:
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
.to(dtype=precision, device=device)
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device)
phones = torch.LongTensor(item["phones"]).to(device)
# norm_text = prompt_data["norm_text"]+item["norm_text"]
else:
all_bert_features = item["bert_features"]\
for batch_idx, index_list in enumerate(batch_index_list):
item_list = [data[idx] for idx in index_list]
phones_list = []
phones_len_list = []
# bert_features_list = []
all_phones_list = []
all_phones_len_list = []
all_bert_features_list = []
norm_text_batch = []
bert_max_len = 0
phones_max_len = 0
# 但是这里也不能套,反而会负优化
# with torch.no_grad():
for item in item_list:
if prompt_data is not None:
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
.to(dtype=precision, device=device)
phones = torch.LongTensor(item["phones"]).to(device)
all_phones = phones
# norm_text = item["norm_text"]
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device)
phones = torch.LongTensor(item["phones"]).to(device)
# norm_text = prompt_data["norm_text"]+item["norm_text"]
else:
all_bert_features = item["bert_features"]\
.to(dtype=precision, device=device)
phones = torch.LongTensor(item["phones"]).to(device)
all_phones = phones
# norm_text = item["norm_text"]
bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
phones_max_len = max(phones_max_len, phones.shape[-1])
bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
phones_max_len = max(phones_max_len, phones.shape[-1])
phones_list.append(phones)
phones_len_list.append(phones.shape[-1])
all_phones_list.append(all_phones)
all_phones_len_list.append(all_phones.shape[-1])
all_bert_features_list.append(all_bert_features)
norm_text_batch.append(item["norm_text"])
phones_list.append(phones)
phones_len_list.append(phones.shape[-1])
all_phones_list.append(all_phones)
all_phones_len_list.append(all_phones.shape[-1])
all_bert_features_list.append(all_bert_features)
norm_text_batch.append(item["norm_text"])
phones_batch = phones_list
all_phones_batch = all_phones_list
all_bert_features_batch = all_bert_features_list
phones_batch = phones_list
all_phones_batch = all_phones_list
all_bert_features_batch = all_bert_features_list
# max_len = max(bert_max_len, phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
#### 直接对phones和bert_features进行pad会增大复读概率。
# all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
# all_bert_features_batch = all_bert_features_list
# all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device)
# for idx, item in enumerate(all_bert_features_list):
# all_bert_features_batch[idx, :, : item.shape[-1]] = item
# #### 先对phones进行embedding、对bert_features进行project再pad到相同长度以缓解复读问题。可能还有其他因素导致复读
# all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
# all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
# all_phones_batch = torch.stack(all_phones_list, dim=0)
# max_len = max(bert_max_len, phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
#### 直接对phones和bert_features进行pad会增大复读概率。
# all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
# all_bert_features_batch = all_bert_features_list
# all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device)
# for idx, item in enumerate(all_bert_features_list):
# all_bert_features_batch[idx, :, : item.shape[-1]] = item
# all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list]
# all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list]
# all_bert_features_batch = torch.stack(all_bert_features_list, dim=0)
# #### 先对phones进行embedding、对bert_features进行project再pad到相同长度以缓解复读问题。可能还有其他因素导致复读
# all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
# all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
# all_phones_batch = torch.stack(all_phones_list, dim=0)
batch = {
"phones": phones_batch,
"phones_len": torch.LongTensor(phones_len_list).to(device),
"all_phones": all_phones_batch,
"all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
"all_bert_features": all_bert_features_batch,
"norm_text": norm_text_batch
}
_data.append(batch)
# all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list]
# all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list]
# all_bert_features_batch = torch.stack(all_bert_features_list, dim=0)
return _data, batch_index_list
batch = {
"phones": phones_batch,
"phones_len": torch.LongTensor(phones_len_list).to(device),
"all_phones": all_phones_batch,
"all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
"all_bert_features": all_bert_features_batch,
"norm_text": norm_text_batch
}
_data.append(batch)
return _data, batch_index_list
def recovery_order(self, data:list, batch_index_list:list)->list:
'''
@ -567,6 +569,7 @@ class TTS:
'''
self.stop_flag = True
def run(self, inputs:dict):
"""
Text to speech inference.
@ -881,12 +884,14 @@ class TTS:
audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
audio[i][j] = audio_fragment.cpu().numpy()
if split_bucket:
audio = self.recovery_order(audio, batch_index_list)
else:
# audio = [item for batch in audio for item in batch]
audio = sum(audio, [])
audio = np.concatenate(audio, 0)
audio = (audio * 32768).astype(np.int16)
@ -899,6 +904,8 @@ class TTS:
return sr, audio
def speed_change(input_audio:np.ndarray, speed:float, sr:int):
# 将 NumPy 数组转换为原始 PCM 流
raw_audio = input_audio.astype(np.int16).tobytes()