diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index 9f98a24..09f3175 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -275,7 +275,7 @@ class TTS: prompt_semantic = codes[0, 0].to(self.configs.device) self.prompt_cache["prompt_semantic"] = prompt_semantic - def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0): + def batch_sequences(self, sequences: List[torch.Tensor], axis: int = 0, pad_value: int = 0, max_length:int=None): seq = sequences[0] ndim = seq.dim() if axis < 0: @@ -283,7 +283,10 @@ class TTS: dtype:torch.dtype = seq.dtype pad_value = torch.tensor(pad_value, dtype=dtype) seq_lengths = [seq.shape[axis] for seq in sequences] - max_length = max(seq_lengths) + if max_length is None: + max_length = max(seq_lengths) + else: + max_length = max(seq_lengths) if max_length < max(seq_lengths) else max_length padded_sequences = [] for seq, length in zip(sequences, seq_lengths): @@ -333,6 +336,8 @@ class TTS: all_phones_list = [] all_bert_features_list = [] norm_text_batch = [] + bert_max_len = 0 + phones_max_len = 0 for item in item_list: if prompt_data is not None: all_bert_features = torch.cat([prompt_data["bert_features"].clone(), item["bert_features"]], 1) @@ -344,15 +349,18 @@ class TTS: phones = torch.LongTensor(item["phones"]) all_phones = phones.clone() # norm_text = item["norm_text"] - + bert_max_len = max(bert_max_len, all_bert_features.shape[-1]) + phones_max_len = max(phones_max_len, phones.shape[-1]) + phones_list.append(phones) all_phones_list.append(all_phones) all_bert_features_list.append(all_bert_features) norm_text_batch.append(item["norm_text"]) # phones_batch = phones_list - phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0) - all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0) - all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, all_phones_batch.shape[-1]) + max_len = max(bert_max_len, phones_max_len) + phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len) + all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len) + all_bert_features_batch = torch.FloatTensor(len(item_list), 1024, max_len) all_bert_features_batch.zero_() for idx, item in enumerate(all_bert_features_list):