From 070ac9b2b2d72c3c6c71139bf0e811276e479a32 Mon Sep 17 00:00:00 2001
From: liufenghua
Date: Sun, 11 Feb 2024 15:06:01 +0800
Subject: [PATCH] add DPO training

---
 .gitignore                        |  3 +-
 GPT_SoVITS/AR/models/t2s_model.py | 99 ++++++++++++++++++++++++++++++-
 GPT_SoVITS/AR/models/utils.py     | 69 ++++++++++++++++++++-
 GPT_SoVITS/inference_webui.py     | 10 +++-
 4 files changed, 176 insertions(+), 5 deletions(-)

diff --git a/.gitignore b/.gitignore
index 3a239f8..00f6bb9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,5 @@ runtime
 output
 logs
 reference
-SoVITS_weights
\ No newline at end of file
+SoVITS_weights
+GPT_weights
\ No newline at end of file

diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index 083dc09..d3e550d 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -8,6 +8,9 @@ from AR.models.utils import (
     sample,
     logits_to_probs,
     multinomial_sample_one_no_sync,
+    dpo_loss,
+    make_reject_y,
+    get_batch_logps
 )
 from AR.modules.embedding import SinePositionalEmbedding
 from AR.modules.embedding import TokenEmbedding
@@ -85,11 +88,104 @@ class Text2SemanticDecoder(nn.Module):
             ignore_index=self.EOS,
         )
 
+    def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
+        x = self.ar_text_embedding(x)
+        x = x + self.bert_proj(bert_feature.transpose(1, 2))
+        x = self.ar_text_position(x)
+        x_mask = make_pad_mask(x_lens)
+
+        y_mask = make_pad_mask(y_lens)
+        y_mask_int = y_mask.type(torch.int64)
+        codes = y.type(torch.int64) * (1 - y_mask_int)
+
+        # Training
+        # AR Decoder
+        y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
+        x_len = x_lens.max()
+        y_len = y_lens.max()
+        y_emb = self.ar_audio_embedding(y)
+        y_pos = self.ar_audio_position(y_emb)
+
+        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
+
+        ar_xy_padding_mask = xy_padding_mask
+
+        x_attn_mask = F.pad(
+            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
+            (0, y_len),
+            value=True,
+        )
+
+        y_attn_mask = F.pad(
+            torch.triu(
+                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
+                diagonal=1,
+            ),
+            (x_len, 0),
+            value=False,
+        )
+
+        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
+        bsz, src_len = x.shape[0], x_len + y_len
+        _xy_padding_mask = (
+            ar_xy_padding_mask.view(bsz, 1, 1, src_len)
+            .expand(-1, self.num_head, -1, -1)
+            .reshape(bsz * self.num_head, 1, src_len)
+        )
+        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
+        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
+        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
+        xy_attn_mask = new_attn_mask
+        # feed x and the full y to the model in a single pass
+        xy_pos = torch.concat([x, y_pos], dim=1)
+
+        return xy_pos, xy_attn_mask, targets
+
     def forward(self, x, x_lens, y, y_lens, bert_feature):
         """
         x: phoneme_ids
         y: semantic_ids
         """
+
+        reject_y, reject_y_lens = make_reject_y(y, y_lens)
+
+        xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
+
+        xy_dec, _ = self.h(
+            (xy_pos, None),
+            mask=xy_attn_mask,
+        )
+        x_len = x_lens.max()
+        logits = self.ar_predict_layer(xy_dec[:, x_len:])
+
+        ###### DPO #############
+        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
+
+        reject_xy_dec, _ = self.h(
+            (reject_xy_pos, None),
+            mask=reject_xy_attn_mask,
+        )
+        x_len = x_lens.max()
+        reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
+
+        # loss
+        # from feiteng: the longer the duration, the larger the gradient update should be, hence reduction="sum"
+        loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
+        acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item()
+
+        A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
+        loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
+
+        loss = loss_1 + loss_2
+
+        return loss, acc
+
+    def forward_old(self, x, x_lens, y, y_lens, bert_feature):
+        """
+        x: phoneme_ids
+        y: semantic_ids
+        """
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_position(x)
@@ -231,6 +327,7 @@ class Text2SemanticDecoder(nn.Module):
         prompts,  #### reference audio tokens
         bert_feature,
         top_k: int = -100,
+        top_p: float = 1.0,
         early_stop_num: int = -1,
         temperature: float = 1.0,
     ):
@@ -305,7 +402,7 @@ class Text2SemanticDecoder(nn.Module):
             if(idx==0):  ### the first step must not emit EOS, otherwise nothing is generated
                 logits = logits[:, :-1]  ### drop the probability of the EOS token (id 1024)
             samples = sample(
-                logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35
+                logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.05, temperature=temperature
            )[0].unsqueeze(0)
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 print("use early stop num:", early_stop_num)
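Note on the forward pass above: each ground-truth sequence y is paired with a synthetically corrupted reject_y, and both are run through the decoder. With reference_free=True and beta=0.2, the dpo_loss call (added to utils.py in the next file) collapses to the reference-free DPO objective, so the total loss is the summed cross-entropy plus a preference term that pushes the model to rank y above reject_y. A minimal sketch of that preference term, using toy sequence-level log-probabilities rather than real model outputs:

    import torch
    import torch.nn.functional as F

    # Toy sequence-level log-probs for a batch of 3 (illustrative values only).
    chosen_logps = torch.tensor([-12.0, -9.5, -14.2])    # ground-truth y
    rejected_logps = torch.tensor([-15.0, -9.0, -16.8])  # corrupted reject_y

    beta = 0.2  # the value hard-coded in the dpo_loss call above
    # reference_free=True zeroes the reference log-ratio, leaving
    # -logsigmoid(beta * (logp_chosen - logp_rejected)), averaged over the batch.
    loss_2 = -F.logsigmoid(beta * (chosen_logps - rejected_logps)).mean()

The second sample (-9.5 vs -9.0) is the only mis-ranked pair, so it contributes the largest share of the loss.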
diff --git a/GPT_SoVITS/AR/models/utils.py b/GPT_SoVITS/AR/models/utils.py
index 25fe446..bc5f2d0 100644
--- a/GPT_SoVITS/AR/models/utils.py
+++ b/GPT_SoVITS/AR/models/utils.py
@@ -1,7 +1,7 @@
 # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/utils.py\
 import torch
 import torch.nn.functional as F
-
+from typing import Tuple
 
 def sequence_mask(length, max_length=None):
     if max_length is None:
@@ -158,3 +158,70 @@ def sample(
     )
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
+
+def dpo_loss(policy_chosen_logps: torch.FloatTensor,
+             policy_rejected_logps: torch.FloatTensor,
+             reference_chosen_logps: torch.FloatTensor,
+             reference_rejected_logps: torch.FloatTensor,
+             beta: float,
+             reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+    pi_logratios = policy_chosen_logps - policy_rejected_logps
+    ref_logratios = reference_chosen_logps - reference_rejected_logps
+
+    if reference_free:
+        ref_logratios = 0
+
+    logits = pi_logratios - ref_logratios
+
+    losses = -F.logsigmoid(beta * logits)
+    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
+    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
+
+    return losses.mean(), chosen_rewards, rejected_rewards
+
+def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+
+    # gather each target token's log-probability and sum over the time axis
+    # (note: padding positions are not masked out here)
+    per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
+    per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
+
+    return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
+
+def make_reject_y(y_o, y_lens):
+    def repeat_P(y):
+        # duplicate a random span of y (simulates stuttering/repeated words)
+        range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
+        pre = y[:range_idx[0]]
+        shf = y[range_idx[1]:]
+        range_text = y[range_idx[0]:range_idx[1]]
+        new_y = torch.cat([pre, range_text, range_text, shf])
+        return new_y
+
+    def lost_P(y):
+        # drop a random span from y (simulates skipped words)
+        range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
+        pre = y[:range_idx[0]]
+        shf = y[range_idx[1]:]
+        new_y = torch.cat([pre, shf])
+        return new_y
+
+    bs = len(y_lens)
+    reject_y = []
+    reject_y_lens = []
+    for b in range(bs):
+        # choose the corruption uniformly: 0 -> repeat_P, 1 -> lost_P
+        process_item_idx = torch.randint(0, 2, size=(1,))[0]
+        if process_item_idx == 0:
+            new_y = repeat_P(y_o[b])
+        else:
+            new_y = lost_P(y_o[b])
+        reject_y.append(new_y)
+        reject_y_lens.append(len(new_y))
+    max_length = max(reject_y_lens)
+    for b in range(bs):
+        pad_length = max_length - reject_y_lens[b]
+        reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
+
+    reject_y = torch.stack(reject_y, dim=0)
+    reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
+
+    return reject_y, reject_y_lens
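The two corruptions in make_reject_y target the classic failure modes of autoregressive TTS: repeat_P duplicates a span (stutter), lost_P deletes one (dropped words), which yields a plausibly worse reject_y for each preference pair. A quick toy run (token values are arbitrary; assumes GPT_SoVITS is on sys.path, as in the repo's own imports):

    import torch
    from AR.models.utils import make_reject_y

    y = torch.tensor([[3, 1, 4, 1, 5, 9, 2, 6]])  # one semantic-token sequence
    y_lens = torch.tensor([8])

    reject_y, reject_y_lens = make_reject_y(y, y_lens)
    # reject_y is y with a random span either duplicated or removed,
    # zero-padded to the longest sequence in the batch.
    print(reject_y, reject_y_lens)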
diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index fc8af08..a85b611 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -359,7 +359,7 @@ def merge_short_text_in_array(texts, threshold):
             result[len(result) - 1] += text
     return result
 
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切")):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6):
     t0 = ttime()
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
@@ -438,7 +438,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
             prompt,
             bert,
             # prompt_phone_len=ph_offset,
-            top_k=config["inference"]["top_k"],
+            top_k=top_k,
+            top_p=top_p,
+            temperature=temperature,
             early_stop_num=hz * max_sec,
         )
     t3 = ttime()
@@ -615,6 +617,10 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                 value=i18n("凑四句一切"),
                 interactive=True,
             )
+            with gr.Row():
+                top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=20, interactive=True)
+                top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=0.6, interactive=True)
+                temperature = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=0.6, interactive=True)
             inference_button = gr.Button(i18n("合成语音"), variant="primary")
             output = gr.Audio(label=i18n("输出的语音"))
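For context on the new sliders: top_k, top_p and temperature flow from the WebUI through get_tts_wav into infer_panel, which forwards them to sample(). The repo's actual filtering lives in logits_to_probs, which this patch does not show, so the following is only a generic sketch of how the three knobs interact, not the project's exact code:

    import torch

    def sample_next_token(logits, top_k=20, top_p=0.6, temperature=0.6):
        # Illustrative top-k / top-p / temperature sampling over a 1-D logits vector.
        logits = logits / max(temperature, 1e-5)  # sharpen (or flatten) the distribution
        if top_k > 0:
            # keep only the k largest logits
            kth_value = torch.topk(logits, min(top_k, logits.size(-1)))[0][-1]
            logits = logits.masked_fill(logits < kth_value, float("-inf"))
        if top_p < 1.0:
            # nucleus filtering: drop tokens once the probability mass
            # accumulated before them already exceeds top_p
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            probs_sorted = torch.softmax(sorted_logits, dim=-1)
            to_remove = probs_sorted.cumsum(dim=-1) - probs_sorted > top_p
            logits[sorted_idx[to_remove]] = float("-inf")
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    next_token = sample_next_token(torch.randn(1025))  # 1024 codes + EOS

Lower temperature sharpens the distribution, top_k caps the candidate set, and top_p then trims it to the smallest nucleus holding roughly top_p of the probability mass; the WebUI defaults of 20 / 0.6 / 0.6 are therefore fairly conservative.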