恢复先前缩进

This commit is contained in:
XTer 2024-04-06 23:16:42 +08:00
parent adb7f71b64
commit 3bfb20763d

View File

@ -272,7 +272,7 @@ class TTS:
# if ("pretrained" not in weights_path):
if hasattr(vits_model, "enc_q"):
del vits_model.enc_q
vits_model = vits_model.to(self.configs.device)
vits_model = vits_model.eval()
vits_model.load_state_dict(dict_s2["weight"], strict=False)
@ -280,6 +280,7 @@ class TTS:
if self.configs.is_half and str(self.configs.device)!="cpu":
self.vits_model = self.vits_model.half()
def init_t2s_weights(self, weights_path: str):
print(f"Loading Text2Semantic weights from {weights_path}")
self.configs.t2s_weights_path = weights_path
@ -296,7 +297,7 @@ class TTS:
self.t2s_model = t2s_model
if self.configs.is_half and str(self.configs.device)!="cpu":
self.t2s_model = self.t2s_model.half()
def enable_half_precision(self, enable: bool = True):
'''
To enable half precision for the TTS model.
@ -307,7 +308,7 @@ class TTS:
if str(self.configs.device) == "cpu" and enable:
print("Half precision is not supported on CPU.")
return
self.configs.is_half = enable
self.precision = torch.float16 if enable else torch.float32
self.configs.save_configs()
@ -329,7 +330,7 @@ class TTS:
self.bert_model = self.bert_model.float()
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.float()
def set_device(self, device: torch.device):
'''
To set the device for all models.
@ -346,7 +347,7 @@ class TTS:
self.bert_model = self.bert_model.to(device)
if self.cnhuhbert_model is not None:
self.cnhuhbert_model = self.cnhuhbert_model.to(device)
def set_ref_audio(self, ref_audio_path:str):
'''
To set the reference audio for the TTS model,
@ -356,7 +357,7 @@ class TTS:
'''
self._set_prompt_semantic(ref_audio_path)
self._set_ref_spec(ref_audio_path)
def _set_ref_spec(self, ref_audio_path):
audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
audio = torch.FloatTensor(audio)
@ -375,7 +376,8 @@ class TTS:
spec = spec.half()
# self.refer_spec = spec
self.prompt_cache["refer_spec"] = spec
def _set_prompt_semantic(self, ref_wav_path:str):
zero_wav = np.zeros(
int(self.configs.sampling_rate * 0.3),
@ -400,10 +402,10 @@ class TTS:
1, 2
) # .float()
codes = self.vits_model.extract_latent(hubert_feature)
prompt_semantic = codes[0, 0].to(self.configs.device)
self.prompt_cache["prompt_semantic"] = prompt_semantic
def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length:int=None):
seq = sequences[0]
ndim = seq.dim()
@ -416,8 +418,7 @@ class TTS:
max_length = max(seq_lengths)
else:
max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length
# 我爱套 torch.no_grad()
# with torch.no_grad():
padded_sequences = []
for seq, length in zip(sequences, seq_lengths):
padding = [0] * axis + [0, max_length - length] + [0] * (ndim - axis - 1)
@ -425,7 +426,7 @@ class TTS:
padded_sequences.append(padded_seq)
batch = torch.stack(padded_sequences)
return batch
def to_batch(self, data:list,
prompt_data:dict=None,
batch_size:int=5,
@ -434,115 +435,116 @@ class TTS:
device:torch.device=torch.device("cpu"),
precision:torch.dtype=torch.float32,
):
# 但是这里不能套,反而会负优化
# with torch.no_grad():
_data:list = []
index_and_len_list = []
for idx, item in enumerate(data):
norm_text_len = len(item["norm_text"])
index_and_len_list.append([idx, norm_text_len])
_data:list = []
index_and_len_list = []
for idx, item in enumerate(data):
norm_text_len = len(item["norm_text"])
index_and_len_list.append([idx, norm_text_len])
batch_index_list = []
if split_bucket:
index_and_len_list.sort(key=lambda x: x[1])
index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
batch_index_list = []
if split_bucket:
index_and_len_list.sort(key=lambda x: x[1])
index_and_len_list = np.array(index_and_len_list, dtype=np.int64)
batch_index_list_len = 0
pos = 0
while pos <index_and_len_list.shape[0]:
# batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
while pos < pos_end:
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8)
if (score>=threshold) or (pos_end-pos==1):
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
batch_index_list_len += len(batch_index)
batch_index_list.append(batch_index)
pos = pos_end
break
pos_end=pos_end-1
assert batch_index_list_len == len(data)
else:
for i in range(len(data)):
if i%batch_size == 0:
batch_index_list.append([])
batch_index_list[-1].append(i)
batch_index_list_len = 0
pos = 0
while pos <index_and_len_list.shape[0]:
# batch_index_list.append(index_and_len_list[pos:min(pos+batch_size,len(index_and_len_list))])
pos_end = min(pos+batch_size,index_and_len_list.shape[0])
while pos < pos_end:
batch=index_and_len_list[pos:pos_end, 1].astype(np.float32)
score=batch[(pos_end-pos)//2]/(batch.mean()+1e-8)
if (score>=threshold) or (pos_end-pos==1):
batch_index=index_and_len_list[pos:pos_end, 0].tolist()
batch_index_list_len += len(batch_index)
batch_index_list.append(batch_index)
pos = pos_end
break
pos_end=pos_end-1
assert batch_index_list_len == len(data)
else:
for i in range(len(data)):
if i%batch_size == 0:
batch_index_list.append([])
batch_index_list[-1].append(i)
for batch_idx, index_list in enumerate(batch_index_list):
item_list = [data[idx] for idx in index_list]
phones_list = []
phones_len_list = []
# bert_features_list = []
all_phones_list = []
all_phones_len_list = []
all_bert_features_list = []
norm_text_batch = []
bert_max_len = 0
phones_max_len = 0
# 但是这里也不能套,反而会负优化
# with torch.no_grad():
for item in item_list:
if prompt_data is not None:
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
.to(dtype=precision, device=device)
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device)
phones = torch.LongTensor(item["phones"]).to(device)
# norm_text = prompt_data["norm_text"]+item["norm_text"]
else:
all_bert_features = item["bert_features"]\
for batch_idx, index_list in enumerate(batch_index_list):
item_list = [data[idx] for idx in index_list]
phones_list = []
phones_len_list = []
# bert_features_list = []
all_phones_list = []
all_phones_len_list = []
all_bert_features_list = []
norm_text_batch = []
bert_max_len = 0
phones_max_len = 0
# 但是这里也不能套,反而会负优化
# with torch.no_grad():
for item in item_list:
if prompt_data is not None:
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
.to(dtype=precision, device=device)
phones = torch.LongTensor(item["phones"]).to(device)
all_phones = phones
# norm_text = item["norm_text"]
bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
phones_max_len = max(phones_max_len, phones.shape[-1])
phones_list.append(phones)
phones_len_list.append(phones.shape[-1])
all_phones_list.append(all_phones)
all_phones_len_list.append(all_phones.shape[-1])
all_bert_features_list.append(all_bert_features)
norm_text_batch.append(item["norm_text"])
phones_batch = phones_list
all_phones_batch = all_phones_list
all_bert_features_batch = all_bert_features_list
# max_len = max(bert_max_len, phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
#### 直接对phones和bert_features进行pad会增大复读概率。
# all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
# all_bert_features_batch = all_bert_features_list
# all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device)
# for idx, item in enumerate(all_bert_features_list):
# all_bert_features_batch[idx, :, : item.shape[-1]] = item
# #### 先对phones进行embedding、对bert_features进行project再pad到相同长度以缓解复读问题。可能还有其他因素导致复读
# all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
# all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
# all_phones_batch = torch.stack(all_phones_list, dim=0)
# all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list]
# all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list]
# all_bert_features_batch = torch.stack(all_bert_features_list, dim=0)
batch = {
"phones": phones_batch,
"phones_len": torch.LongTensor(phones_len_list).to(device),
"all_phones": all_phones_batch,
"all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
"all_bert_features": all_bert_features_batch,
"norm_text": norm_text_batch
}
_data.append(batch)
return _data, batch_index_list
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device)
phones = torch.LongTensor(item["phones"]).to(device)
# norm_text = prompt_data["norm_text"]+item["norm_text"]
else:
all_bert_features = item["bert_features"]\
.to(dtype=precision, device=device)
phones = torch.LongTensor(item["phones"]).to(device)
all_phones = phones
# norm_text = item["norm_text"]
bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
phones_max_len = max(phones_max_len, phones.shape[-1])
phones_list.append(phones)
phones_len_list.append(phones.shape[-1])
all_phones_list.append(all_phones)
all_phones_len_list.append(all_phones.shape[-1])
all_bert_features_list.append(all_bert_features)
norm_text_batch.append(item["norm_text"])
phones_batch = phones_list
all_phones_batch = all_phones_list
all_bert_features_batch = all_bert_features_list
# max_len = max(bert_max_len, phones_max_len)
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
#### 直接对phones和bert_features进行pad会增大复读概率。
# all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
# all_bert_features_batch = all_bert_features_list
# all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precision, device=device)
# for idx, item in enumerate(all_bert_features_list):
# all_bert_features_batch[idx, :, : item.shape[-1]] = item
# #### 先对phones进行embedding、对bert_features进行project再pad到相同长度以缓解复读问题。可能还有其他因素导致复读
# all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
# all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
# all_phones_batch = torch.stack(all_phones_list, dim=0)
# all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list]
# all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list]
# all_bert_features_batch = torch.stack(all_bert_features_list, dim=0)
batch = {
"phones": phones_batch,
"phones_len": torch.LongTensor(phones_len_list).to(device),
"all_phones": all_phones_batch,
"all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
"all_bert_features": all_bert_features_batch,
"norm_text": norm_text_batch
}
_data.append(batch)
return _data, batch_index_list
def recovery_order(self, data:list, batch_index_list:list)->list:
'''
Recovery the order of the audio according to the batch_index_list.
@ -566,7 +568,8 @@ class TTS:
Stop the inference process.
'''
self.stop_flag = True
def run(self, inputs:dict):
"""
Text to speech inference.
@ -850,7 +853,7 @@ class TTS:
raise e
finally:
self.empty_cache()
def empty_cache(self):
try:
if "cuda" in str(self.configs.device):
@ -859,7 +862,7 @@ class TTS:
torch.mps.empty_cache()
except:
pass
def audio_postprocess(self,
audio:List[torch.Tensor],
sr:int,
@ -873,32 +876,36 @@ class TTS:
dtype=self.precision,
device=self.configs.device
)
for i, batch in enumerate(audio):
for j, audio_fragment in enumerate(batch):
max_audio=torch.abs(audio_fragment).max()#简单防止16bit爆音
if max_audio>1: audio_fragment/=max_audio
audio_fragment:torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
audio[i][j] = audio_fragment.cpu().numpy()
if split_bucket:
audio = self.recovery_order(audio, batch_index_list)
else:
# audio = [item for batch in audio for item in batch]
audio = sum(audio, [])
audio = np.concatenate(audio, 0)
audio = (audio * 32768).astype(np.int16)
try:
if speed_factor != 1.0:
audio = speed_change(audio, speed=speed_factor, sr=int(sr))
except Exception as e:
print(f"Failed to change speed of audio: \n{e}")
return sr, audio
def speed_change(input_audio:np.ndarray, speed:float, sr:int):
# 将 NumPy 数组转换为原始 PCM 流
raw_audio = input_audio.astype(np.int16).tobytes()
@ -918,4 +925,4 @@ def speed_change(input_audio:np.ndarray, speed:float, sr:int):
# 将管道输出解码为 NumPy 数组
processed_audio = np.frombuffer(out, np.int16)
return processed_audio
return processed_audio