mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
fix some bugs GPT_SoVITS/TTS_infer_pack/TTS.py
This commit is contained in:
parent
17832e5c4a
commit
7556e8cc96
@ -275,7 +275,7 @@ class TTS:
|
||||
prompt_semantic = codes[0, 0].to(self.configs.device)
|
||||
self.prompt_cache["prompt_semantic"] = prompt_semantic
|
||||
|
||||
def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0):
|
||||
def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length:int=None):
|
||||
seq = sequences[0]
|
||||
ndim = seq.dim()
|
||||
if axis < 0:
|
||||
@ -283,7 +283,10 @@ class TTS:
|
||||
dtype:torch.dtype = seq.dtype
|
||||
pad_value = torch.tensor(pad_value, dtype=dtype)
|
||||
seq_lengths = [seq.shape[axis] for seq in sequences]
|
||||
max_length = max(seq_lengths)
|
||||
if max_length is None:
|
||||
max_length = max(seq_lengths)
|
||||
else:
|
||||
max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length
|
||||
|
||||
padded_sequences = []
|
||||
for seq, length in zip(sequences, seq_lengths):
|
||||
@ -333,6 +336,8 @@ class TTS:
|
||||
all_phones_list = []
|
||||
all_bert_features_list = []
|
||||
norm_text_batch = []
|
||||
bert_max_len = 0
|
||||
phones_max_len = 0
|
||||
for item in item_list:
|
||||
if prompt_data is not None:
|
||||
all_bert_features = torch.cat([prompt_data["bert_features"].clone(), item["bert_features"]], 1)
|
||||
@ -344,15 +349,18 @@ class TTS:
|
||||
phones = torch.LongTensor(item["phones"])
|
||||
all_phones = phones.clone()
|
||||
# norm_text = item["norm_text"]
|
||||
|
||||
bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
|
||||
phones_max_len = max(phones_max_len, phones.shape[-1])
|
||||
|
||||
phones_list.append(phones)
|
||||
all_phones_list.append(all_phones)
|
||||
all_bert_features_list.append(all_bert_features)
|
||||
norm_text_batch.append(item["norm_text"])
|
||||
# phones_batch = phones_list
|
||||
phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0)
|
||||
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0)
|
||||
all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, all_phones_batch.shape[-1])
|
||||
max_len = max(bert_max_len, phones_max_len)
|
||||
phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
|
||||
all_bert_features_batch.zero_()
|
||||
|
||||
for idx, item in enumerate(all_bert_features_list):
|
||||
|
Loading…
x
Reference in New Issue
Block a user