mirror of https://github.com/RVC-Boss/GPT-SoVITS.git
Fix the GPT loss calculation
commit e9475921d0 (parent b9211657d8)
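This commit reworks how the autoregressive training loss is aligned. Previously, `pad_y_eos` returned targets shifted by one while the logits were sliced from position `x_len` onward, so the output at the last text position, the one that predicts the very first semantic token, never contributed to the loss. The new code keeps the targets unshifted and starts the logits slice at `x_len - 1`, scoring all `T + 1` predictions including EOS. A minimal shape sketch of the before/after alignment (tensor sizes are made up for illustration):

```python
import torch

# Toy sizes, purely for illustration.
B, x_len, T, V = 2, 4, 6, 10           # batch, text len, semantic len, vocab
xy_dec = torch.randn(B, x_len + T, V)  # decoder output over [text ; semantic]

# Old slicing: T logits matched against targets shifted by one, so the output
# at the last text position (the prediction of the first semantic token) was
# discarded.
old_logits = xy_dec[:, x_len:]         # (B, T, V)

# New slicing: start one position earlier; with unshifted targets of length
# T + 1, position x_len - 1 predicts the first semantic token and the final
# position predicts EOS.
new_logits = xy_dec[:, x_len - 1:]     # (B, T + 1, V)

assert old_logits.shape[1] == T and new_logits.shape[1] == T + 1
```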
```diff
@@ -356,7 +356,7 @@ class Text2SemanticDecoder(nn.Module):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_position(x)
-        x_mask = make_pad_mask(x_lens)
+        x_mask = make_pad_mask_left(x_lens)
 
         y_mask = make_pad_mask(y_lens)
         y_mask_int = y_mask.type(torch.int64)
```
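Switching `x_mask` to `make_pad_mask_left` is what makes the `x_len - 1` slice safe in a batch: if the text is left-padded, position `x_len - 1` is the last real text token for every sample, rather than padding for the shorter ones. The helpers themselves are not shown in this diff; a minimal sketch of what they plausibly compute (`True` marks padded positions):

```python
import torch

def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    # Right padding: content fills [0, length), padding fills [length, max_len).
    max_len = max(max_len, int(lengths.max()))
    pos = torch.arange(max_len, device=lengths.device)
    return pos.unsqueeze(0) >= lengths.unsqueeze(1)

def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    # Left padding: padding fills [0, max_len - length), content fills the rest.
    max_len = max(max_len, int(lengths.max()))
    pos = torch.arange(max_len, device=lengths.device)
    return pos.unsqueeze(0) < (max_len - lengths).unsqueeze(1)

lengths = torch.tensor([3, 5])
print(make_pad_mask(lengths))       # [[F, F, F, T, T], [F, F, F, F, F]]
print(make_pad_mask_left(lengths))  # [[T, T, F, F, F], [F, F, F, F, F]]
```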
```diff
@@ -420,7 +420,7 @@ class Text2SemanticDecoder(nn.Module):
             mask=xy_attn_mask,
         )
         x_len = x_lens.max()
-        logits = self.ar_predict_layer(xy_dec[:, x_len:])
+        logits = self.ar_predict_layer(xy_dec[:, x_len-1:])
 
         ###### DPO #############
         reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
```
```diff
@@ -432,7 +432,7 @@ class Text2SemanticDecoder(nn.Module):
             mask=reject_xy_attn_mask,
         )
         x_len = x_lens.max()
-        reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len:])
+        reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len-1:])
 
         # loss
         # from feiteng: more duration means the gradient update should be larger too, so use sum
```
```diff
@@ -455,7 +455,7 @@ class Text2SemanticDecoder(nn.Module):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
         x = self.ar_text_position(x)
-        x_mask = make_pad_mask(x_lens)
+        x_mask = make_pad_mask_left(x_lens)
 
         y_mask = make_pad_mask(y_lens)
         y_mask_int = y_mask.type(torch.int64)
```
```diff
@@ -502,7 +502,7 @@ class Text2SemanticDecoder(nn.Module):
             (xy_pos, None),
             mask=xy_attn_mask,
         )
-        logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
+        logits = self.ar_predict_layer(xy_dec[:, x_len-1:]).permute(0, 2, 1)
         # loss
         # from feiteng: more duration means the gradient update should be larger too, so use sum
         loss = F.cross_entropy(logits, targets, reduction="sum")
```
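feiteng's comment explains the `reduction="sum"`: summing over tokens lets an utterance with more frames contribute proportionally more gradient than a short one, whereas the default `"mean"` would weight every utterance equally regardless of duration. A toy comparison, using the `(N, C, d)` layout that the `.permute(0, 2, 1)` above sets up:

```python
import torch
import torch.nn.functional as F

V = 10
short = torch.randn(1, V, 5)      # logits for a 5-token sequence, (N, C, d)
long = torch.randn(1, V, 50)      # logits for a 50-token sequence
t_short = torch.randint(V, (1, 5))
t_long = torch.randint(V, (1, 50))

# With "sum", the long sequence's loss (and gradient) is roughly 10x the
# short one's; with "mean", both sequences would weigh the same.
print(F.cross_entropy(short, t_short, reduction="sum").item())
print(F.cross_entropy(long, t_long, reduction="sum").item())
```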
```diff
@@ -578,7 +578,7 @@ class Text2SemanticDecoder(nn.Module):
     def pad_y_eos(self, y, y_mask_int, eos_id):
         targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
         # offset by one position
-        return targets[:, :-1], targets[:, 1:]
+        return targets[:, :-1], targets
 
     def infer_panel_batch_infer(
         self,
```
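`pad_y_eos` is the target side of the same fix: it appends EOS and now returns the full padded sequence as the loss target rather than a shifted copy, giving `T + 1` targets to match the `T + 1` logits taken from `x_len - 1` onward. A tiny self-contained check (toy token values, hypothetical `eos_id`):

```python
import torch
import torch.nn.functional as F

eos_id = 1024                            # hypothetical EOS token id
y = torch.tensor([[7, 8, 9]])            # (B, T): semantic tokens
y_mask_int = torch.tensor([[0, 0, 0]])   # 1 would mark padded positions

# Append one slot and write EOS into it (and into any padded positions).
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)

y_in = targets[:, :-1]   # decoder input: [7, 8, 9]        (length T)
tgt = targets            # loss targets:  [7, 8, 9, 1024]  (length T + 1)
# Before the fix, the second return value was targets[:, 1:], which dropped
# the first semantic token (7) from the loss.
print(y_in.tolist(), tgt.tolist())
```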
The commit also updates an inference config, replacing the pretrained default checkpoints with locally trained ones:

```diff
@@ -3,9 +3,9 @@ custom:
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
   device: cuda
   is_half: true
-  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
+  t2s_weights_path: "GPT_weights_v2Pro/韵儿-e15.ckpt"
   version: v2
-  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
+  vits_weights_path: "SoVITS_weights_v2/韵儿_e25_s9475.pth"
 v1:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
```