fix some bugs in GPT_SoVITS/TTS_infer_pack/TTS.py

chasonjiang 2024-03-09 01:02:09 +08:00
parent 17832e5c4a
commit 7556e8cc96

GPT_SoVITS/TTS_infer_pack/TTS.py

@@ -275,7 +275,7 @@ class TTS:
         prompt_semantic = codes[0, 0].to(self.configs.device)
         self.prompt_cache["prompt_semantic"] = prompt_semantic
 
-    def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0):
+    def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length:int=None):
         seq = sequences[0]
         ndim = seq.dim()
         if axis < 0:
@@ -283,7 +283,10 @@ class TTS:
         dtype:torch.dtype = seq.dtype
         pad_value = torch.tensor(pad_value, dtype=dtype)
         seq_lengths = [seq.shape[axis] for seq in sequences]
-        max_length = max(seq_lengths)
+        if max_length is None:
+            max_length = max(seq_lengths)
+        else:
+            max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length
 
         padded_sequences = []
         for seq, length in zip(sequences, seq_lengths):
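
For reference, the semantics of the extended signature: with max_length=None the old behavior is kept (pad to the longest sequence in the batch), and a caller-supplied max_length only takes effect when it is at least as large as the longest sequence, so nothing is ever truncated. A minimal standalone sketch of those padding semantics; the loop body is cut off in the hunk above, so the F.pad-based mechanics here are an assumption, not the method's actual code:

    from typing import List, Optional

    import torch
    import torch.nn.functional as F

    def batch_sequences(sequences: List[torch.Tensor], axis: int = 0,
                        pad_value: int = 0, max_length: Optional[int] = None) -> torch.Tensor:
        seq_lengths = [seq.shape[axis] for seq in sequences]
        # max_length=None keeps the old behavior: pad to the longest sequence.
        # A smaller caller value is ignored, so no sequence is ever truncated.
        if max_length is None or max_length < max(seq_lengths):
            max_length = max(seq_lengths)
        padded = []
        for seq, length in zip(sequences, seq_lengths):
            # F.pad reads (left, right) pairs from the last dim backwards;
            # pad only on the right of the target axis.
            pad_width = [0] * (2 * seq.dim())
            pad_width[2 * (seq.dim() - 1 - axis) + 1] = max_length - length
            padded.append(F.pad(seq, pad_width, value=pad_value))
        return torch.stack(padded, dim=0)

    # Two phone sequences of length 3 and 5, padded to an external width of 6:
    # batch_sequences([torch.ones(3, dtype=torch.long),
    #                  torch.ones(5, dtype=torch.long)], max_length=6).shape == (2, 6)
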
@@ -333,6 +336,8 @@ class TTS:
         all_phones_list = []
         all_bert_features_list = []
         norm_text_batch = []
+        bert_max_len = 0
+        phones_max_len = 0
         for item in item_list:
             if prompt_data is not None:
                 all_bert_features = torch.cat([prompt_data["bert_features"].clone(), item["bert_features"]], 1)
@@ -344,15 +349,18 @@ class TTS:
                 phones = torch.LongTensor(item["phones"])
                 all_phones = phones.clone()
                 # norm_text = item["norm_text"]
 
+            bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
+            phones_max_len = max(phones_max_len, phones.shape[-1])
             phones_list.append(phones)
             all_phones_list.append(all_phones)
             all_bert_features_list.append(all_bert_features)
             norm_text_batch.append(item["norm_text"])
 
         # phones_batch = phones_list
-        phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0)
-        all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0)
-        all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, all_phones_batch.shape[-1])
+        max_len = max(bert_max_len, phones_max_len)
+        phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
+        all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
+        all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
         all_bert_features_batch.zero_()
         for idx, item in enumerate(all_bert_features_list):
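
Taken together, the fix makes the phone batch and the BERT-feature batch share one width: max_len covers both running maxima, both batch_sequences calls pad to it, and the pre-zeroed FloatTensor is allocated at max_len rather than at all_phones_batch.shape[-1], so the per-item copy can no longer overrun the buffer when the two lengths disagree. A minimal standalone sketch of the assembly (assemble_batch is a hypothetical helper; the [1024, n_phones] feature shape and the copy loop after the cut are assumptions based on the visible context):

    import torch
    import torch.nn.functional as F

    def assemble_batch(phones_list, all_bert_features_list):
        # One shared width for both batches, covering the longest of either.
        bert_max_len = max(f.shape[-1] for f in all_bert_features_list)
        phones_max_len = max(p.shape[-1] for p in phones_list)
        max_len = max(bert_max_len, phones_max_len)

        # Right-pad every phone sequence to max_len, then stack.
        phones_batch = torch.stack(
            [F.pad(p, (0, max_len - p.shape[-1])) for p in phones_list])

        # Allocate the BERT batch at max_len (the old code sized it from
        # all_phones_batch.shape[-1]) and zero it so the padding region is
        # defined before each item is copied in.
        bert_batch = torch.zeros(len(all_bert_features_list), 1024, max_len)
        for idx, feat in enumerate(all_bert_features_list):
            bert_batch[idx, :, : feat.shape[-1]] = feat
        return phones_batch, bert_batch
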