diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py index dbca6bae..fee4e7e8 100644 --- a/GPT_SoVITS/AR/models/t2s_model.py +++ b/GPT_SoVITS/AR/models/t2s_model.py @@ -328,8 +328,8 @@ class Text2SemanticDecoder(nn.Module): prompts, ####参考音频token bert_feature, top_k: int = -100, - top_p: int = 100, min_p: float = 0.0, + top_p: int = 100, early_stop_num: int = -1, temperature: float = 1.0, ): @@ -398,7 +398,7 @@ class Text2SemanticDecoder(nn.Module): if(idx==0):###第一次跑不能EOS否则没有了 logits = logits[:, :-1] ###刨除1024终止符号的概率 samples = sample( - logits[0], y, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=1.35, temperature=temperature + logits[0], y, top_k=top_k, min_p=min_p, top_p=top_p, repetition_penalty=1.35, temperature=temperature )[0].unsqueeze(0) # 本次生成的 semantic_ids 和之前的 y 构成新的 y # print(samples.shape)#[1,1]#第一个1是bs