Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
add DPO training
parent 41041715a4
commit 070ac9b2b2
.gitignore (vendored): 1 addition
@@ -8,3 +8,4 @@ output
 logs
 reference
 SoVITS_weights
+GPT_weights
GPT_SoVITS/AR/models/t2s_model.py

@@ -8,6 +8,9 @@ from AR.models.utils import (
     sample,
     logits_to_probs,
     multinomial_sample_one_no_sync,
+    dpo_loss,
+    make_reject_y,
+    get_batch_logps
 )
 from AR.modules.embedding import SinePositionalEmbedding
 from AR.modules.embedding import TokenEmbedding
@@ -85,11 +88,104 @@ class Text2SemanticDecoder(nn.Module):
             ignore_index=self.EOS,
         )
 
+    def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
+        x = self.ar_text_embedding(x)
+        x = x + self.bert_proj(bert_feature.transpose(1, 2))
+        x = self.ar_text_position(x)
+        x_mask = make_pad_mask(x_lens)
+
+        y_mask = make_pad_mask(y_lens)
+        y_mask_int = y_mask.type(torch.int64)
+        codes = y.type(torch.int64) * (1 - y_mask_int)
+
+        # Training
+        # AR Decoder
+        y, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)
+        x_len = x_lens.max()
+        y_len = y_lens.max()
+        y_emb = self.ar_audio_embedding(y)
+        y_pos = self.ar_audio_position(y_emb)
+
+        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
+
+        ar_xy_padding_mask = xy_padding_mask
+
+        x_attn_mask = F.pad(
+            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
+            (0, y_len),
+            value=True,
+        )
+
+        y_attn_mask = F.pad(
+            torch.triu(
+                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
+                diagonal=1,
+            ),
+            (x_len, 0),
+            value=False,
+        )
+
+        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
+        bsz, src_len = x.shape[0], x_len + y_len
+        _xy_padding_mask = (
+            ar_xy_padding_mask.view(bsz, 1, 1, src_len)
+            .expand(-1, self.num_head, -1, -1)
+            .reshape(bsz * self.num_head, 1, src_len)
+        )
+        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
+        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
+        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
+        xy_attn_mask = new_attn_mask
+        # feed x and the full y to the model in a single pass
+        xy_pos = torch.concat([x, y_pos], dim=1)
+
+        return xy_pos, xy_attn_mask, targets
+
     def forward(self, x, x_lens, y, y_lens, bert_feature):
         """
         x: phoneme_ids
         y: semantic_ids
         """
+
+        reject_y, reject_y_lens = make_reject_y(y, y_lens)
+
+        xy_pos, xy_attn_mask, targets = self.make_input_data(x, x_lens, y, y_lens, bert_feature)
+
+        xy_dec, _ = self.h(
+            (xy_pos, None),
+            mask=xy_attn_mask,
+        )
+        x_len = x_lens.max()
+        logits = self.ar_predict_layer(xy_dec[:, x_len:])
+
+        ###### DPO #############
+        reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
+
+        reject_xy_dec, _ = self.h(
+            (reject_xy_pos, None),
+            mask=reject_xy_attn_mask,
+        )
+        x_len = x_lens.max()
+        reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
+
+        # loss
+        # from feiteng: the longer the duration, the larger the gradient update should be, hence reduction="sum"
+
+        loss_1 = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
+        acc = self.ar_accuracy_metric(logits.permute(0, 2, 1).detach(), targets).item()
+
+        A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
+        loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
+
+        loss = loss_1 + loss_2
+
+        return loss, acc
+
+    def forward_old(self, x, x_lens, y, y_lens, bert_feature):
+        """
+        x: phoneme_ids
+        y: semantic_ids
+        """
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_position(x)
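
For context: the new forward pass optimizes a summed cross-entropy loss on the accepted sequence plus a reference-free DPO term over summed sequence log-probabilities, with beta = 0.2. The following standalone sketch of that objective on toy tensors is illustrative only and not part of this commit:

import torch
import torch.nn.functional as F

# toy shapes: batch=2, steps=5, vocab=1025 (1024 codes + EOS)
bsz, T, V = 2, 5, 1025
logits = torch.randn(bsz, T, V)          # policy logits on the accepted y
reject_logits = torch.randn(bsz, T, V)   # policy logits on the corrupted y
targets = torch.randint(0, V, (bsz, T))
reject_targets = torch.randint(0, V, (bsz, T))

# summed per-sequence log-probs, as in get_batch_logps
lp_accept = torch.gather(logits.log_softmax(-1), 2, targets.unsqueeze(2)).squeeze(2).sum(-1)
lp_reject = torch.gather(reject_logits.log_softmax(-1), 2, reject_targets.unsqueeze(2)).squeeze(2).sum(-1)

# reference-free DPO: raise the accept/reject log-ratio
beta = 0.2
loss_dpo = -F.logsigmoid(beta * (lp_accept - lp_reject)).mean()

# total objective, mirroring the new forward: summed CE + DPO term
loss_ce = F.cross_entropy(logits.permute(0, 2, 1), targets, reduction="sum")
loss = loss_ce + loss_dpo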
@@ -231,6 +327,7 @@ class Text2SemanticDecoder(nn.Module):
         prompts,  #### reference-audio tokens
         bert_feature,
         top_k: int = -100,
+        top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
     ):
@@ -305,7 +402,7 @@ class Text2SemanticDecoder(nn.Module):
             if(idx==0):  ### the first step must not sample EOS, otherwise nothing is generated
                 logits = logits[:, :-1]  ### drop the probability of stop token 1024
             samples = sample(
-                logits[0], y, top_k=top_k, top_p=1.0, repetition_penalty=1.35
+                logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.05, temperature=temperature
             )[0].unsqueeze(0)
             if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
                 print("use early stop num:", early_stop_num)
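
For context: infer now forwards top_p and temperature into sample() instead of hard-coding top_p=1.0, and repetition_penalty drops from 1.35 to 1.05. For readers unfamiliar with these knobs, here is a minimal, self-contained sketch of temperature / top-k / top-p filtering; names and structure are illustrative, and the repository's real implementation is sample() in AR/models/utils.py:

import torch

def sample_sketch(logits, top_k=20, top_p=0.6, temperature=0.6):
    # temperature < 1 sharpens the distribution, > 1 flattens it
    logits = logits.clone() / max(temperature, 1e-5)
    # top-k: keep only the k largest logits
    kth = torch.topk(logits, min(top_k, logits.size(-1)))[0][..., -1, None]
    logits[logits < kth] = -float("inf")
    # top-p (nucleus): drop the low-probability tail of the sorted cumulative mass
    probs = logits.softmax(-1)
    sorted_probs, idx = probs.sort(-1, descending=True)
    cum = sorted_probs.cumsum(-1)
    sorted_probs[cum - sorted_probs > top_p] = 0.0
    probs = torch.zeros_like(probs).scatter_(-1, idx, sorted_probs)
    probs = probs / probs.sum(-1, keepdim=True)
    return torch.multinomial(probs, num_samples=1)

print(sample_sketch(torch.randn(1025)))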
GPT_SoVITS/AR/models/utils.py

@@ -1,7 +1,7 @@
 # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/utils.py\
 import torch
 import torch.nn.functional as F
+from typing import Tuple
 
 def sequence_mask(length, max_length=None):
     if max_length is None:
@@ -158,3 +158,70 @@ def sample(
     )
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs
+
+
+def dpo_loss(policy_chosen_logps: torch.FloatTensor,
+             policy_rejected_logps: torch.FloatTensor,
+             reference_chosen_logps: torch.FloatTensor,
+             reference_rejected_logps: torch.FloatTensor,
+             beta: float,
+             reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+    pi_logratios = policy_chosen_logps - policy_rejected_logps
+    ref_logratios = reference_chosen_logps - reference_rejected_logps
+
+    if reference_free:
+        ref_logratios = 0
+
+    logits = pi_logratios - ref_logratios
+
+    losses = -F.logsigmoid(beta * logits)
+    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
+    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
+
+    return losses.mean(), chosen_rewards, rejected_rewards
+
+
+def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+
+    # dummy token; we'll ignore the losses on these tokens later
+
+    per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
+    per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
+
+    return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
+
+
+def make_reject_y(y_o, y_lens):
+    def repeat_P(y):
+        range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
+        pre = y[:range_idx[0]]
+        shf = y[range_idx[1]:]
+        range_text = y[range_idx[0]:range_idx[1]]
+        new_y = torch.cat([pre, range_text, range_text, shf])
+        return new_y
+
+    def lost_P(y):
+        range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
+        pre = y[:range_idx[0]]
+        shf = y[range_idx[1]:]
+        range_text = y[range_idx[0]:range_idx[1]]
+        new_y = torch.cat([pre, shf])
+        return new_y
+
+    bs = len(y_lens)
+    reject_y = []
+    reject_y_lens = []
+    for b in range(bs):
+        process_item_idx = torch.randint(0, 1, size=(1, ))[0]
+        if process_item_idx == 0:
+            new_y = repeat_P(y_o[b])
+            reject_y.append(new_y)
+            reject_y_lens.append(len(new_y))
+        elif process_item_idx==1:
+            new_y = lost_P(y_o[b])
+            reject_y.append(new_y)
+            reject_y_lens.append(len(new_y))
+    max_length = max(reject_y_lens)
+    for b in range(bs):
+        pad_length = max_length - reject_y_lens[b]
+        reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
+
+    reject_y = torch.stack(reject_y, dim = 0)
+    reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
+
+    return reject_y, reject_y_lens
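
For context: make_reject_y fabricates the "rejected" sample for DPO by corrupting the ground-truth semantic tokens, either duplicating a random span (repeat_P) or deleting one (lost_P), then right-pads the batch with zeros. As written, torch.randint(0, 1, ...) can only return 0, so the lost_P branch is unreachable and every rejected sample comes from span repetition. A toy invocation, assuming this commit's functions are importable from AR.models.utils, would look like:

import torch
from AR.models.utils import make_reject_y

y = torch.randint(0, 1024, (2, 8))    # fake semantic tokens, batch of 2
y_lens = torch.tensor([8, 8])
reject_y, reject_y_lens = make_reject_y(y, y_lens)
print(reject_y.shape, reject_y_lens)  # repeated spans usually make sequences longer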
GPT_SoVITS/inference_webui.py

@@ -359,7 +359,7 @@ def merge_short_text_in_array(texts, threshold):
             result[len(result) - 1] += text
     return result
 
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切")):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, how_to_cut=i18n("不切"), top_k=20, top_p=0.6, temperature=0.6):
     t0 = ttime()
     prompt_language = dict_language[prompt_language]
     text_language = dict_language[text_language]
@@ -438,7 +438,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
             prompt,
             bert,
             # prompt_phone_len=ph_offset,
-            top_k=config["inference"]["top_k"],
+            top_k=top_k,
+            top_p=top_p,
+            temperature=temperature,
             early_stop_num=hz * max_sec,
         )
         t3 = ttime()
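
For context: with the new keyword arguments, sampling can be steered per request instead of being read from config["inference"]. A hypothetical direct call, assuming the WebUI's models are already loaded and that get_tts_wav yields (sample_rate, waveform) chunks as in the surrounding file:

for sr, audio in get_tts_wav(
    ref_wav_path="ref.wav",      # hypothetical paths and texts
    prompt_text="...",
    prompt_language=i18n("中文"),
    text="...",
    text_language=i18n("中文"),
    top_k=20, top_p=0.6, temperature=0.6,
):
    pass  # e.g. collect or play each synthesized chunk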
@@ -615,6 +617,10 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                 value=i18n("凑四句一切"),
                 interactive=True,
             )
+            with gr.Row():
+                top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=20,interactive=True)
+                top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=0.6,interactive=True)
+                temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=0.6,interactive=True)
             inference_button = gr.Button(i18n("合成语音"), variant="primary")
             output = gr.Audio(label=i18n("输出的语音"))
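
For context: for the three sliders to take effect, the inference button's click handler must also pass them through to get_tts_wav. That wiring is not shown in the hunks above; the sketch below is an assumption about how it looks, and the component names other than top_k, top_p, and temperature are hypothetical:

# hypothetical wiring; the actual inputs list in this commit may differ
inference_button.click(
    get_tts_wav,
    inputs=[inp_ref, prompt_text, prompt_language, text, text_language, how_to_cut, top_k, top_p, temperature],
    outputs=[output],
)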