mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-08-29 18:24:11 +08:00
fix some bugs GPT_SoVITS/TTS_infer_pack/TTS.py
This commit is contained in:
parent
17832e5c4a
commit
7556e8cc96
@ -275,7 +275,7 @@ class TTS:
|
|||||||
prompt_semantic = codes[0, 0].to(self.configs.device)
|
prompt_semantic = codes[0, 0].to(self.configs.device)
|
||||||
self.prompt_cache["prompt_semantic"] = prompt_semantic
|
self.prompt_cache["prompt_semantic"] = prompt_semantic
|
||||||
|
|
||||||
def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0):
|
def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length:int=None):
|
||||||
seq = sequences[0]
|
seq = sequences[0]
|
||||||
ndim = seq.dim()
|
ndim = seq.dim()
|
||||||
if axis < 0:
|
if axis < 0:
|
||||||
@ -283,7 +283,10 @@ class TTS:
|
|||||||
dtype:torch.dtype = seq.dtype
|
dtype:torch.dtype = seq.dtype
|
||||||
pad_value = torch.tensor(pad_value, dtype=dtype)
|
pad_value = torch.tensor(pad_value, dtype=dtype)
|
||||||
seq_lengths = [seq.shape[axis] for seq in sequences]
|
seq_lengths = [seq.shape[axis] for seq in sequences]
|
||||||
max_length = max(seq_lengths)
|
if max_length is None:
|
||||||
|
max_length = max(seq_lengths)
|
||||||
|
else:
|
||||||
|
max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length
|
||||||
|
|
||||||
padded_sequences = []
|
padded_sequences = []
|
||||||
for seq, length in zip(sequences, seq_lengths):
|
for seq, length in zip(sequences, seq_lengths):
|
||||||
@ -333,6 +336,8 @@ class TTS:
|
|||||||
all_phones_list = []
|
all_phones_list = []
|
||||||
all_bert_features_list = []
|
all_bert_features_list = []
|
||||||
norm_text_batch = []
|
norm_text_batch = []
|
||||||
|
bert_max_len = 0
|
||||||
|
phones_max_len = 0
|
||||||
for item in item_list:
|
for item in item_list:
|
||||||
if prompt_data is not None:
|
if prompt_data is not None:
|
||||||
all_bert_features = torch.cat([prompt_data["bert_features"].clone(), item["bert_features"]], 1)
|
all_bert_features = torch.cat([prompt_data["bert_features"].clone(), item["bert_features"]], 1)
|
||||||
@ -344,15 +349,18 @@ class TTS:
|
|||||||
phones = torch.LongTensor(item["phones"])
|
phones = torch.LongTensor(item["phones"])
|
||||||
all_phones = phones.clone()
|
all_phones = phones.clone()
|
||||||
# norm_text = item["norm_text"]
|
# norm_text = item["norm_text"]
|
||||||
|
bert_max_len = max(bert_max_len, all_bert_features.shape[-1])
|
||||||
|
phones_max_len = max(phones_max_len, phones.shape[-1])
|
||||||
|
|
||||||
phones_list.append(phones)
|
phones_list.append(phones)
|
||||||
all_phones_list.append(all_phones)
|
all_phones_list.append(all_phones)
|
||||||
all_bert_features_list.append(all_bert_features)
|
all_bert_features_list.append(all_bert_features)
|
||||||
norm_text_batch.append(item["norm_text"])
|
norm_text_batch.append(item["norm_text"])
|
||||||
# phones_batch = phones_list
|
# phones_batch = phones_list
|
||||||
phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0)
|
max_len = max(bert_max_len, phones_max_len)
|
||||||
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0)
|
phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||||
all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, all_phones_batch.shape[-1])
|
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||||
|
all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len)
|
||||||
all_bert_features_batch.zero_()
|
all_bert_features_batch.zero_()
|
||||||
|
|
||||||
for idx, item in enumerate(all_bert_features_list):
|
for idx, item in enumerate(all_bert_features_list):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user