diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index ac905f4b..7ed63e22 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -553,6 +553,10 @@ class Text2SemanticDecoder(nn.Module): mask=xy_attn_mask, ) logits = self.ar_predict_layer(xy_dec[:, -1]) + + eos_penalty = 2.0 + logits[:, self.EOS] -= eos_penalty + samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature) if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num: @@ -705,6 +709,9 @@ class Text2SemanticDecoder(nn.Module): xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask) logits = self.ar_predict_layer(xy_dec[:, -1]) + eos_penalty = 2.0 + logits[:, self.EOS] -= eos_penalty + if idx == 0: attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False) else: @@ -895,6 +902,9 @@ class Text2SemanticDecoder(nn.Module): logits = self.ar_predict_layer(xy_dec[:, -1]) + eos_penalty = 2.0 + logits[:, self.EOS] -= eos_penalty + if idx == 0: xy_attn_mask = None if idx < 11: ###至少预测出10个token不然不给停止(0.4s)