mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-06 03:57:44 +08:00
缓解了batch_size>1时的复读问题,缓解方法是:在T2S模型中,先对phones进行embedding、对bert_features进行project,再pad到相同长度。
This commit is contained in:
parent
3c78539c44
commit
864a148d75
@ -504,18 +504,29 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
|
|
||||||
def infer_panel_batch_infer_with_flash_attn(
|
def infer_panel_batch_infer_with_flash_attn(
|
||||||
self,
|
self,
|
||||||
x, #####全部文本token
|
x:List[torch.LongTensor], #####全部文本token
|
||||||
x_lens,
|
x_lens:torch.LongTensor,
|
||||||
prompts, ####参考音频token
|
prompts:torch.LongTensor, ####参考音频token
|
||||||
bert_feature,
|
bert_feature:List[torch.LongTensor],
|
||||||
top_k: int = -100,
|
top_k: int = -100,
|
||||||
top_p: int = 100,
|
top_p: int = 100,
|
||||||
early_stop_num: int = -1,
|
early_stop_num: int = -1,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
):
|
):
|
||||||
|
# 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读)
|
||||||
bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
|
max_len = 0
|
||||||
x = self.ar_text_embedding(x)
|
for x_item, bert_item in zip(x, bert_feature):
|
||||||
|
max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||||
|
x_list = [self.ar_text_embedding(item) for item in x]
|
||||||
|
x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
|
||||||
|
x = torch.stack(x_list, dim=0)
|
||||||
|
|
||||||
|
bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
|
||||||
|
bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
|
||||||
|
bert_feature = torch.stack(bert_features_list, dim=0)
|
||||||
|
|
||||||
|
# bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
|
||||||
|
# x = self.ar_text_embedding(x)
|
||||||
x = x + bert_feature
|
x = x + bert_feature
|
||||||
x = self.ar_text_position(x)
|
x = self.ar_text_position(x)
|
||||||
|
|
||||||
@ -658,17 +669,30 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
|
|
||||||
def infer_panel_batch_only(
|
def infer_panel_batch_only(
|
||||||
self,
|
self,
|
||||||
x, #####全部文本token
|
x:List[torch.LongTensor], #####全部文本token
|
||||||
x_lens,
|
x_lens:torch.LongTensor,
|
||||||
prompts, ####参考音频token
|
prompts:torch.LongTensor, ####参考音频token
|
||||||
bert_feature,
|
bert_feature:List[torch.LongTensor],
|
||||||
top_k: int = -100,
|
top_k: int = -100,
|
||||||
top_p: int = 100,
|
top_p: int = 100,
|
||||||
early_stop_num: int = -1,
|
early_stop_num: int = -1,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
):
|
):
|
||||||
x = self.ar_text_embedding(x)
|
# 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读)
|
||||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
max_len = 0
|
||||||
|
for x_item, bert_item in zip(x, bert_feature):
|
||||||
|
max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||||
|
x_list = [self.ar_text_embedding(item) for item in x]
|
||||||
|
x_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) if item.shape[0]<max_len else item for item in x_list]
|
||||||
|
x = torch.stack(x_list, dim=0)
|
||||||
|
|
||||||
|
bert_features_list = [self.bert_proj(item.transpose(0, 1)) for item in bert_feature]
|
||||||
|
bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) if item.shape[0]<max_len else item for item in bert_features_list]
|
||||||
|
bert_feature = torch.stack(bert_features_list, dim=0)
|
||||||
|
|
||||||
|
# bert_feature = self.bert_proj(bert_feature.transpose(1, 2))
|
||||||
|
# x = self.ar_text_embedding(x)
|
||||||
|
x = x + bert_feature
|
||||||
x = self.ar_text_position(x)
|
x = self.ar_text_position(x)
|
||||||
|
|
||||||
# AR Decoder
|
# AR Decoder
|
||||||
|
@ -55,6 +55,7 @@ def set_seed(seed:int):
|
|||||||
seed = int(seed)
|
seed = int(seed)
|
||||||
seed = seed if seed != -1 else random.randrange(1 << 32)
|
seed = seed if seed != -1 else random.randrange(1 << 32)
|
||||||
print(f"Set seed to {seed}")
|
print(f"Set seed to {seed}")
|
||||||
|
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
@ -428,7 +429,14 @@ class TTS:
|
|||||||
batch = torch.stack(padded_sequences)
|
batch = torch.stack(padded_sequences)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def to_batch(self, data:list, prompt_data:dict=None, batch_size:int=5, threshold:float=0.75, split_bucket:bool=True):
|
def to_batch(self, data:list,
|
||||||
|
prompt_data:dict=None,
|
||||||
|
batch_size:int=5,
|
||||||
|
threshold:float=0.75,
|
||||||
|
split_bucket:bool=True,
|
||||||
|
device:torch.device=torch.device("cpu"),
|
||||||
|
precison:torch.dtype=torch.float32,
|
||||||
|
):
|
||||||
|
|
||||||
_data:list = []
|
_data:list = []
|
||||||
index_and_len_list = []
|
index_and_len_list = []
|
||||||
@ -480,14 +488,14 @@ class TTS:
|
|||||||
for item in item_list:
|
for item in item_list:
|
||||||
if prompt_data is not None:
|
if prompt_data is not None:
|
||||||
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
|
all_bert_features = torch.cat([prompt_data["bert_features"], item["bert_features"]], 1)\
|
||||||
.to(dtype=self.precison)
|
.to(dtype=precison, device=device)
|
||||||
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"])
|
all_phones = torch.LongTensor(prompt_data["phones"]+item["phones"]).to(device)
|
||||||
phones = torch.LongTensor(item["phones"])
|
phones = torch.LongTensor(item["phones"]).to(device)
|
||||||
# norm_text = prompt_data["norm_text"]+item["norm_text"]
|
# norm_text = prompt_data["norm_text"]+item["norm_text"]
|
||||||
else:
|
else:
|
||||||
all_bert_features = item["bert_features"]\
|
all_bert_features = item["bert_features"]\
|
||||||
.to(dtype=self.precison)
|
.to(dtype=precison, device=device)
|
||||||
phones = torch.LongTensor(item["phones"])
|
phones = torch.LongTensor(item["phones"]).to(device)
|
||||||
all_phones = phones
|
all_phones = phones
|
||||||
# norm_text = item["norm_text"]
|
# norm_text = item["norm_text"]
|
||||||
|
|
||||||
@ -502,19 +510,33 @@ class TTS:
|
|||||||
norm_text_batch.append(item["norm_text"])
|
norm_text_batch.append(item["norm_text"])
|
||||||
|
|
||||||
phones_batch = phones_list
|
phones_batch = phones_list
|
||||||
max_len = max(bert_max_len, phones_max_len)
|
all_phones_batch = all_phones_list
|
||||||
|
all_bert_features_batch = all_bert_features_list
|
||||||
|
|
||||||
|
|
||||||
|
# max_len = max(bert_max_len, phones_max_len)
|
||||||
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
# phones_batch = self.batch_sequences(phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||||
all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
#### 直接对phones和bert_features进行pad,会增大复读概率。
|
||||||
|
# all_phones_batch = self.batch_sequences(all_phones_list, axis=0, pad_value=0, max_length=max_len)
|
||||||
# all_bert_features_batch = all_bert_features_list
|
# all_bert_features_batch = all_bert_features_list
|
||||||
all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=self.precison)
|
# all_bert_features_batch = torch.zeros(len(item_list), 1024, max_len, dtype=precison, device=device)
|
||||||
for idx, item in enumerate(all_bert_features_list):
|
# for idx, item in enumerate(all_bert_features_list):
|
||||||
all_bert_features_batch[idx, :, : item.shape[-1]] = item
|
# all_bert_features_batch[idx, :, : item.shape[-1]] = item
|
||||||
|
|
||||||
|
# #### 先对phones进行embedding、对bert_features进行project,再pad到相同长度,以缓解复读问题。(可能还有其他因素导致复读)
|
||||||
|
# all_phones_list = [self.t2s_model.model.ar_text_embedding(item.to(self.t2s_model.device)) for item in all_phones_list]
|
||||||
|
# all_phones_list = [F.pad(item,(0,0,0,max_len-item.shape[0]),value=0) for item in all_phones_list]
|
||||||
|
# all_phones_batch = torch.stack(all_phones_list, dim=0)
|
||||||
|
|
||||||
|
# all_bert_features_list = [self.t2s_model.model.bert_proj(item.to(self.t2s_model.device).transpose(0, 1)) for item in all_bert_features_list]
|
||||||
|
# all_bert_features_list = [F.pad(item,(0,0,0,max_len-item.shape[0]), value=0) for item in all_bert_features_list]
|
||||||
|
# all_bert_features_batch = torch.stack(all_bert_features_list, dim=0)
|
||||||
|
|
||||||
batch = {
|
batch = {
|
||||||
"phones": phones_batch,
|
"phones": phones_batch,
|
||||||
"phones_len": torch.LongTensor(phones_len_list),
|
"phones_len": torch.LongTensor(phones_len_list).to(device),
|
||||||
"all_phones": all_phones_batch,
|
"all_phones": all_phones_batch,
|
||||||
"all_phones_len": torch.LongTensor(all_phones_len_list),
|
"all_phones_len": torch.LongTensor(all_phones_len_list).to(device),
|
||||||
"all_bert_features": all_bert_features_batch,
|
"all_bert_features": all_bert_features_batch,
|
||||||
"norm_text": norm_text_batch
|
"norm_text": norm_text_batch
|
||||||
}
|
}
|
||||||
@ -658,7 +680,9 @@ class TTS:
|
|||||||
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
threshold=batch_threshold,
|
threshold=batch_threshold,
|
||||||
split_bucket=split_bucket
|
split_bucket=split_bucket,
|
||||||
|
device=self.configs.device,
|
||||||
|
precison=self.precison
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
print(i18n("############ 切分文本 ############"))
|
print(i18n("############ 切分文本 ############"))
|
||||||
@ -688,7 +712,9 @@ class TTS:
|
|||||||
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
prompt_data=self.prompt_cache if not no_prompt_text else None,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
threshold=batch_threshold,
|
threshold=batch_threshold,
|
||||||
split_bucket=False
|
split_bucket=False,
|
||||||
|
device=self.configs.device,
|
||||||
|
precison=self.precison
|
||||||
)
|
)
|
||||||
return batch[0]
|
return batch[0]
|
||||||
|
|
||||||
@ -706,26 +732,18 @@ class TTS:
|
|||||||
if item is None:
|
if item is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
batch_phones = item["phones"]
|
batch_phones:List[torch.LongTensor] = item["phones"]
|
||||||
batch_phones_len = item["phones_len"]
|
batch_phones_len:torch.LongTensor = item["phones_len"]
|
||||||
all_phoneme_ids = item["all_phones"]
|
all_phoneme_ids:List[torch.LongTensor] = item["all_phones"]
|
||||||
all_phoneme_lens = item["all_phones_len"]
|
all_phoneme_lens:torch.LongTensor = item["all_phones_len"]
|
||||||
all_bert_features = item["all_bert_features"]
|
all_bert_features:List[torch.LongTensor] = item["all_bert_features"]
|
||||||
norm_text = item["norm_text"]
|
norm_text:str = item["norm_text"]
|
||||||
|
|
||||||
# batch_phones = batch_phones.to(self.configs.device)
|
|
||||||
batch_phones_len = batch_phones_len.to(self.configs.device)
|
|
||||||
all_phoneme_ids = all_phoneme_ids.to(self.configs.device)
|
|
||||||
all_phoneme_lens = all_phoneme_lens.to(self.configs.device)
|
|
||||||
all_bert_features = all_bert_features.to(self.configs.device)
|
|
||||||
if self.configs.is_half:
|
|
||||||
all_bert_features = all_bert_features.half()
|
|
||||||
|
|
||||||
print(i18n("前端处理后的文本(每句):"), norm_text)
|
print(i18n("前端处理后的文本(每句):"), norm_text)
|
||||||
if no_prompt_text :
|
if no_prompt_text :
|
||||||
prompt = None
|
prompt = None
|
||||||
else:
|
else:
|
||||||
prompt = self.prompt_cache["prompt_semantic"].expand(all_phoneme_ids.shape[0], -1).to(self.configs.device)
|
prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
|
pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user