mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 19:41:56 +08:00
缓解了batch_size>1时的复读问题,缓解方法是:在T2S模型中,先对phones进行embedding、对bert_features进行project,再pad到相同长度。
This commit is contained in:
parent
3c78539c44
commit
864a148d75
@ -504,18 +504,29 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
def infer_panel_batch_infer_with_flash_attn(
|
||||
self,
|
||||
x, #####全部文本token
|
||||
x_lens,
|
||||
prompts, ####参考音频token
|
||||
bert_feature,
|
||||
x:List[torch.LongTensor], #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:List[torch.LongTensor],
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
|
||||
bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
|
||||
x = self.ar_text_embedding(x)
|
||||
# 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读)
|
||||
max_len = 0
|
||||
for x_item, bert_item in zip(x, bert_feature):
|
||||
max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||
x_list = [self.ar_text_embedding(item) for item in x]
|
||||
x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
|
||||
x = torch.stack(x_list, dim=0)
|
||||
|
||||
bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
|
||||
bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
|
||||
bert_feature = torch.stack(bert_features_list, dim=0)
|
||||
|
||||
# bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
|
||||
# x = self.ar_text_embedding(x)
|
||||
x = x + bert_feature
|
||||
x = self.ar_text_position(x)
|
||||
|
||||
@ -658,17 +669,30 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
def infer_panel_batch_only(
|
||||
self,
|
||||
x, #####全部文本token
|
||||
x_lens,
|
||||
prompts, ####参考音频token
|
||||
bert_feature,
|
||||
x:List[torch.LongTensor], #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:List[torch.LongTensor],
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
x = self.ar_text_embedding(x)
|
||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||
# 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读)
|
||||
max_len = 0
|
||||
for x_item, bert_item in zip(x, bert_feature):
|
||||
max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||
x_list = [self.ar_text_embedding(item) for item in x]
|
||||
x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
|
||||
x = torch.stack(x_list, dim=0)
|
||||
|
||||
bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
|
||||
bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
|
||||
bert_feature = torch.stack(bert_features_list, dim=0)
|
||||
|
||||
# bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
|
||||
# x = self.ar_text_embedding(x)
|
||||
x = x + bert_feature
|
||||
x = self.ar_text_position(x)
|
||||
|
||||
# AR Decoder
|
||||
|
@ -55,6 +55,7 @@ def set_seed(seed:int):
|
||||
seed = int(seed)
|
||||
seed = seed if seed != -1 else random.randrange(1 << 32)
|
||||
print(f"Set seed to {seed}")
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
@ -428,7 +429,14 @@ class TTS:
|
||||
batch = torch.stack(padded_sequences)
|
||||
return batch
|
||||
|
||||
def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold:float=0.75, split_bucket:bool=True):
|
||||
def to_batch(self, data:list,
|
||||
prompt_data:dict=None,
|
||||
batch_size:int=5,
|
||||
threshold:float=0.75,
|
||||
split_bucket:bool=True,
|
||||
device:torch.device=torch.device("cpu"),
|
||||
precison:torch.dtype=torch.float32,
|
||||
):
|
||||
|
||||
_data:list = []
|
||||
index_and_len_list = []
|
||||
@ -480,14 +488,14 @@ class TTS:
|
||||
for item in item_list:
|
||||
if prompt_data is not None:
|
||||
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
|
||||
.to(dtype=self.precison)
|
||||
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
|
||||
phones = torch.LongTensor(item["phones"])
|
||||
.to(dtype=precison, device=device)
|
||||
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device)
|
||||
phones = torch.LongTensor(item["phones"]).to(device)
|
||||
# norm_text = prompt_data["norm_text"]+item["norm_text"]
|
||||
else:
|
||||
all_bert_features = item["bert_features"]\
|
||||
.to(dtype=self.precison)
|
||||
phones = torch.LongTensor(item["phones"])
|
||||
.to(dtype=precison, device=device)
|
||||
phones = torch.LongTensor(item["phones"]).to(device)
|
||||
all_phones = phones
|
||||
# norm_text = item["norm_text"]
|
||||
|
||||
@ -502,19 +510,33 @@ class TTS:
|
||||
norm_text_batch.append(item["norm_text"])
|
||||
|
||||
phones_batch = phones_list
|
||||
max_len = max(bert_max_len, phones_max_len)
|
||||
all_phones_batch = all_phones_list
|
||||
all_bert_features_batch = all_bert_features_list
|
||||
|
||||
|
||||
# max_len = max(bert_max_len, phones_max_len)
|
||||
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
#### 直接对phones和bert_features进行pad,会增大复读概率。
|
||||
# all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||
# all_bert_features_batch = all_bert_features_list
|
||||
all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=self.precison)
|
||||
for idx, item in enumerate(all_bert_features_list):
|
||||
all_bert_features_batch[idx, :, : item.shape[-1]] = item
|
||||
# all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precison, device=device)
|
||||
# for idx, item in enumerate(all_bert_features_list):
|
||||
# all_bert_features_batch[idx, :, : item.shape[-1]] = item
|
||||
|
||||
# #### 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读)
|
||||
# all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
|
||||
# all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
|
||||
# all_phones_batch = torch.stack(all_phones_list, dim=0)
|
||||
|
||||
# all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list]
|
||||
# all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list]
|
||||
# all_bert_features_batch = torch.stack(all_bert_features_list, dim=0)
|
||||
|
||||
batch = {
|
||||
"phones": phones_batch,
|
||||
"phones_len": torch.LongTensor(phones_len_list),
|
||||
"phones_len": torch.LongTensor(phones_len_list).to(device),
|
||||
"all_phones": all_phones_batch,
|
||||
"all_phones_len": torch.LongTensor(all_phones_len_list),
|
||||
"all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
|
||||
"all_bert_features": all_bert_features_batch,
|
||||
"norm_text": norm_text_batch
|
||||
}
|
||||
@ -658,7 +680,9 @@ class TTS:
|
||||
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
||||
batch_size=batch_size,
|
||||
threshold=batch_threshold,
|
||||
split_bucket=split_bucket
|
||||
split_bucket=split_bucket,
|
||||
device=self.configs.device,
|
||||
precison=self.precison
|
||||
)
|
||||
else:
|
||||
print(i18n("############ 切分文本 ############"))
|
||||
@ -688,7 +712,9 @@ class TTS:
|
||||
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
||||
batch_size=batch_size,
|
||||
threshold=batch_threshold,
|
||||
split_bucket=False
|
||||
split_bucket=False,
|
||||
device=self.configs.device,
|
||||
precison=self.precison
|
||||
)
|
||||
return batch[0]
|
||||
|
||||
@ -706,26 +732,18 @@ class TTS:
|
||||
if item is None:
|
||||
continue
|
||||
|
||||
batch_phones = item["phones"]
|
||||
batch_phones_len = item["phones_len"]
|
||||
all_phoneme_ids = item["all_phones"]
|
||||
all_phoneme_lens = item["all_phones_len"]
|
||||
all_bert_features = item["all_bert_features"]
|
||||
norm_text = item["norm_text"]
|
||||
|
||||
# batch_phones = batch_phones.to(self.configs.device)
|
||||
batch_phones_len = batch_phones_len.to(self.configs.device)
|
||||
all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
|
||||
all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
|
||||
all_bert_features = all_bert_features.to(self.configs.device)
|
||||
if self.configs.is_half:
|
||||
all_bert_features = all_bert_features.half()
|
||||
batch_phones:List[torch.LongTensor] = item["phones"]
|
||||
batch_phones_len:torch.LongTensor = item["phones_len"]
|
||||
all_phoneme_ids:List[torch.LongTensor] = item["all_phones"]
|
||||
all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
|
||||
all_bert_features:List[torch.LongTensor] = item["all_bert_features"]
|
||||
norm_text:str = item["norm_text"]
|
||||
|
||||
print(i18n("前端处理后的文本(每句):"), norm_text)
|
||||
if no_prompt_text :
|
||||
prompt = None
|
||||
else:
|
||||
prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
|
||||
prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
|
||||
|
Loading…
x
Reference in New Issue
Block a user