Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-04-05 19:41:56 +08:00)

Commit 827b84205e: Merge c8c5d6a4cc22d133990eb66b5d2883690bdcba96 into 959a2ddbeb36844afbd97aadabbaa17e39e3b616
@@ -629,11 +629,6 @@ class Text2SemanticDecoder(nn.Module):
         xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
         _xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
-
-        for i in range(bsz):
-            l = x_lens[i]
-            _xy_padding_mask[i,l:max_len,:]=True
-
         xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
         xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
         xy_attn_mask = xy_attn_mask.bool()
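For context, the code around this hunk builds a combined attention mask: a text/semantic attention mask is merged with a key-padding mask via logical_or and then expanded per attention head. The snippet below is a minimal, self-contained sketch of that pattern, not the repository's code; the shapes and names (bsz, src_len, num_head) and the plain causal mask standing in for the x/y mask are illustrative assumptions.

```python
import torch

# Minimal sketch of the mask-construction pattern above; all shapes are assumed.
bsz, src_len, num_head = 2, 6, 4

# Base attention mask: True marks positions a query may NOT attend to.
# A plain causal (upper-triangular) mask stands in for the x/y mask here.
xy_mask = torch.triu(torch.ones(src_len, src_len, dtype=torch.bool), diagonal=1)
xy_mask = xy_mask.view(1, src_len, src_len).repeat(bsz, 1, 1)

# Key-padding mask: True marks padded key positions, broadcast over all queries.
xy_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
xy_padding_mask[0, 4:] = True  # pretend sample 0 has two padded tokens
_xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)

# Merge both masks and replicate the result for every attention head.
xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, num_head, -1, -1).bool()
print(xy_attn_mask.shape)  # torch.Size([2, 4, 6, 6])
```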
@@ -645,7 +640,7 @@ class Text2SemanticDecoder(nn.Module):
         idx_list = [None]*y.shape[0]
         for idx in tqdm(range(1500)):
             if idx == 0:
-                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask, False)
+                xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None, False)
             else:
                 xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask, False)
             logits = self.ar_predict_layer(
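The changed call sits inside the autoregressive decode loop: the first iteration runs the whole prompt through process_prompt (the change passes None instead of xy_padding_mask there), and later iterations reuse the returned k_cache/v_cache via decode_next_token. The sketch below illustrates that prefill-then-decode pattern with toy stand-in functions; process_prompt and decode_next_token here are illustrative stubs, not the project's t2s_transformer methods, and the tensor shapes are assumptions.

```python
import torch

# Toy stand-ins for the prefill/decode split used in the loop above.
def process_prompt(x):
    # "Prefill": run the full prompt once and keep its keys/values as the cache.
    k_cache, v_cache = x.clone(), x.clone()
    return x[:, -1:], k_cache, v_cache

def decode_next_token(x, k_cache, v_cache):
    # Decode step: append the new token's keys/values to the existing cache.
    k_cache = torch.cat([k_cache, x], dim=1)
    v_cache = torch.cat([v_cache, x], dim=1)
    return x, k_cache, v_cache

xy_pos = torch.randn(1, 8, 16)  # (batch, prompt_len, dim), assumed shape
for idx in range(4):
    if idx == 0:
        out, k_cache, v_cache = process_prompt(xy_pos)
    else:
        out, k_cache, v_cache = decode_next_token(out, k_cache, v_cache)
print(k_cache.shape)  # torch.Size([1, 11, 16]): 8 prompt tokens + 3 decoded steps
```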