Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git (synced 2025-10-07 07:02:57 +08:00)
test: feasibility of embedding a speaker vector
parent 35106c977e
commit 6d2e5319d3
AR/models/t2s_lightning_module.py

@@ -23,7 +23,8 @@ class Text2SemanticLightningModule(LightningModule):
             # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
             print(
                 self.load_state_dict(
-                    torch.load(pretrained_s1, map_location="cpu")["weight"]
+                    torch.load(pretrained_s1, map_location="cpu")["weight"],
+                    strict=False,
                 )
             )
         if is_train:
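Note: strict=False is what lets the pretrained s1 checkpoint load even though it
has no entries for the newly added speaker_proj / speaker_feat parameters; a
strict load would raise a RuntimeError. A minimal sketch of that behavior, using
hypothetical stand-in modules rather than the repo's classes:

    import torch.nn as nn

    ckpt_model = nn.Sequential(nn.Linear(4, 4))      # stands in for the pretrained weights
    new_model = nn.Sequential(nn.Linear(4, 4))       # stands in for the extended model
    new_model.add_module("extra", nn.Linear(4, 4))   # parameters the checkpoint lacks

    result = new_model.load_state_dict(ckpt_model.state_dict(), strict=False)
    print(result.missing_keys)    # ['extra.weight', 'extra.bias']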
@@ -31,6 +32,12 @@ class Text2SemanticLightningModule(LightningModule):
         self.save_hyperparameters()
         self.eval_dir = output_dir / "eval"
         self.eval_dir.mkdir(parents=True, exist_ok=True)
+        for param in self.model.parameters():
+            param.requires_grad = False
+        self.model.speaker_proj.weight.requires_grad = True
+        self.model.speaker_proj.bias.requires_grad = True
+        self.model.speaker_proj.train()
+        self.model.speaker_feat.requires_grad = True

     def training_step(self, batch: Dict, batch_idx: int):
         opt = self.optimizers()
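This freeze-then-unfreeze pattern leaves only the speaker projection and the
speaker vector trainable; everything else keeps its pretrained weights. A quick
way to confirm which leaves still receive gradients, on a hypothetical stand-in
module:

    import torch.nn as nn

    model = nn.ModuleDict({
        "backbone": nn.Linear(8, 8),       # stands in for the frozen bulk of the model
        "speaker_proj": nn.Linear(8, 8),   # the part left trainable
    })
    for param in model.parameters():
        param.requires_grad = False
    model["speaker_proj"].weight.requires_grad = True
    model["speaker_proj"].bias.requires_grad = True

    trainable = [name for name, p in model.named_parameters() if p.requires_grad]
    print(trainable)    # ['speaker_proj.weight', 'speaker_proj.bias']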
@@ -47,6 +54,7 @@ class Text2SemanticLightningModule(LightningModule):
         opt.step()
         opt.zero_grad()
         scheduler.step()
+        torch.save(self.model.speaker_feat.data, "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt")

         self.log(
             "total_loss",
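The vector is checkpointed after every optimizer step, as a bare tensor rather
than a state_dict, to the same hard-coded absolute path the decoder reads from
at construction time (see the t2s_model.py hunks below). A minimal sketch of
that round trip, using a hypothetical relative path:

    import torch
    import torch.nn as nn

    feat = nn.Parameter(torch.randn(1024) * 0.1)
    # .data strips the autograd wrapper, so only the raw values hit the disk
    torch.save(feat.data, "speaker_feat.pt")    # hypothetical path
    restored = torch.load("speaker_feat.pt", map_location="cpu")
    assert restored.shape == (1024,) and not restored.requires_grad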
AR/models/t2s_model.py

@@ -1,6 +1,7 @@
 # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
 import torch
 from tqdm import tqdm
+import os

 from AR.models.utils import make_pad_mask
 from AR.models.utils import (
@@ -51,6 +52,12 @@ class Text2SemanticDecoder(nn.Module):
         # should be same as num of kmeans bin
         # assert self.EOS == 1024
         self.bert_proj = nn.Linear(1024, self.embedding_dim)
+        self.speaker_proj = nn.Linear(1024, self.embedding_dim)
+        self.path_speaker = "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt"
+        if not os.path.exists(self.path_speaker):
+            self.speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)
+        else:
+            self.speaker_feat = nn.Parameter(torch.load(self.path_speaker, map_location="cpu"))
         self.ar_text_embedding = TokenEmbedding(
             self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
         )
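The speaker vector is initialized from the file written during training when it
exists, and from small random noise otherwise. Wrapping it in nn.Parameter is
what registers it with the module, so it appears in parameters() (and hence in
the optimizer and the state_dict). A minimal sketch:

    import torch
    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            # assigning an nn.Parameter registers it automatically
            self.speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)

    toy = Toy()
    print([name for name, _ in toy.named_parameters()])    # ['speaker_feat']
    print("speaker_feat" in toy.state_dict())               # True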
@@ -104,6 +111,7 @@ class Text2SemanticDecoder(nn.Module):
         x_len = x_lens.max()
         y_len = y_lens.max()
         y_emb = self.ar_audio_embedding(y)
+        y_emb = y_emb + self.speaker_proj(self.speaker_feat).view(1,1,-1)
         y_pos = self.ar_audio_position(y_emb)

         xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
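During training, the projected speaker vector is added as a batch- and
time-invariant bias on the audio-token embeddings; view(1,1,-1) reshapes the
(embedding_dim,) projection so it broadcasts over (batch, seq_len,
embedding_dim). A shape sketch with a stand-in embedding_dim of 512:

    import torch
    import torch.nn as nn

    speaker_proj = nn.Linear(1024, 512)    # 512 stands in for embedding_dim
    speaker_feat = torch.randn(1024)
    y_emb = torch.randn(2, 7, 512)         # (batch, seq_len, embedding_dim)

    bias = speaker_proj(speaker_feat).view(1, 1, -1)    # (1, 1, 512)
    print((y_emb + bias).shape)            # torch.Size([2, 7, 512])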
@@ -330,6 +338,7 @@ class Text2SemanticDecoder(nn.Module):
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
+        use_speaker_feat=True,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -353,12 +362,16 @@ class Text2SemanticDecoder(nn.Module):
             "first_infer": 1,
             "stage": 0,
         }
+        if use_speaker_feat:
+            speaker_feat = self.speaker_proj(self.speaker_feat).view(1,1,-1)
+        else:
+            speaker_feat = 0
         for idx in tqdm(range(1500)):
             if cache["first_infer"] == 1:
-                y_emb = self.ar_audio_embedding(y)
+                y_emb = self.ar_audio_embedding(y) + speaker_feat
             else:
                 y_emb = torch.cat(
-                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
+                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:]) + speaker_feat], 1
                )
             cache["y_emb"] = y_emb
             y_pos = self.ar_audio_position(y_emb)
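At inference, the same bias is added at every decoding step. When
use_speaker_feat is False, speaker_feat is the plain integer 0, so the
unconditional additions inside the loop become no-ops and the decoder
reproduces the pre-commit behavior. A one-line check of that no-op:

    import torch

    y_emb = torch.randn(1, 3, 512)
    speaker_feat = 0    # the use_speaker_feat=False branch
    assert torch.equal(y_emb + speaker_feat, y_emb)    # integer 0 is an additive no-op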