test: 嵌入说话人向量的可行性

This commit is contained in:
bwnotfound 2024-02-12 21:46:33 +08:00
parent 35106c977e
commit 6d2e5319d3
2 changed files with 24 additions and 3 deletions

View File

@ -23,7 +23,8 @@ class Text2SemanticLightningModule(LightningModule):
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
print(
self.load_state_dict(
torch.load(pretrained_s1, map_location="cpu")["weight"]
torch.load(pretrained_s1, map_location="cpu")["weight"],
strict=False,
)
)
if is_train:
@ -31,6 +32,12 @@ class Text2SemanticLightningModule(LightningModule):
self.save_hyperparameters()
self.eval_dir = output_dir / "eval"
self.eval_dir.mkdir(parents=True, exist_ok=True)
for param in self.model.parameters():
param.requires_grad = False
self.model.speaker_proj.weight.requires_grad = True
self.model.speaker_proj.bias.requires_grad = True
self.model.speaker_proj.train()
self.model.speaker_feat.requires_grad = True
def training_step(self, batch: Dict, batch_idx: int):
opt = self.optimizers()
@ -47,6 +54,7 @@ class Text2SemanticLightningModule(LightningModule):
opt.step()
opt.zero_grad()
scheduler.step()
torch.save(self.model.speaker_feat.data, "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt")
self.log(
"total_loss",

View File

@ -1,6 +1,7 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
import torch
from tqdm import tqdm
import os
from AR.models.utils import make_pad_mask
from AR.models.utils import (
@ -51,6 +52,12 @@ class Text2SemanticDecoder(nn.Module):
# should be same as num of kmeans bin
# assert self.EOS == 1024
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.speaker_proj = nn.Linear(1024, self.embedding_dim)
self.path_speaker = "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt"
if not os.path.exists(self.path_speaker):
self.speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)
else:
self.speaker_feat = nn.Parameter(torch.load(self.path_speaker, map_location="cpu"))
self.ar_text_embedding = TokenEmbedding(
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
)
@ -104,6 +111,7 @@ class Text2SemanticDecoder(nn.Module):
x_len = x_lens.max()
y_len = y_lens.max()
y_emb = self.ar_audio_embedding(y)
y_emb = y_emb + self.speaker_proj(self.speaker_feat).view(1,1,-1)
y_pos = self.ar_audio_position(y_emb)
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
@ -330,6 +338,7 @@ class Text2SemanticDecoder(nn.Module):
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
use_speaker_feat=True,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@ -353,12 +362,16 @@ class Text2SemanticDecoder(nn.Module):
"first_infer": 1,
"stage": 0,
}
if use_speaker_feat:
speaker_feat = self.speaker_proj(self.speaker_feat).view(1,1,-1)
else:
speaker_feat = 0
for idx in tqdm(range(1500)):
if cache["first_infer"] == 1:
y_emb = self.ar_audio_embedding(y)
y_emb = self.ar_audio_embedding(y) + speaker_feat
else:
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:]) + speaker_feat], 1
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)