test: feasibility of embedding a speaker vector

bwnotfound 2024-02-12 21:46:33 +08:00
parent 35106c977e
commit 6d2e5319d3
2 changed files with 24 additions and 3 deletions

View File

@@ -23,7 +23,8 @@ class Text2SemanticLightningModule(LightningModule):
             # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
             print(
                 self.load_state_dict(
-                    torch.load(pretrained_s1, map_location="cpu")["weight"]
+                    torch.load(pretrained_s1, map_location="cpu")["weight"],
+                    strict=False,
                 )
             )
         if is_train:
@@ -31,6 +32,12 @@ class Text2SemanticLightningModule(LightningModule):
             self.save_hyperparameters()
             self.eval_dir = output_dir / "eval"
             self.eval_dir.mkdir(parents=True, exist_ok=True)
+            for param in self.model.parameters():
+                param.requires_grad = False
+            self.model.speaker_proj.weight.requires_grad = True
+            self.model.speaker_proj.bias.requires_grad = True
+            self.model.speaker_proj.train()
+            self.model.speaker_feat.requires_grad = True
 
     def training_step(self, batch: Dict, batch_idx: int):
         opt = self.optimizers()
@@ -47,6 +54,7 @@ class Text2SemanticLightningModule(LightningModule):
             opt.step()
             opt.zero_grad()
             scheduler.step()
+            torch.save(self.model.speaker_feat.data, "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt")
 
         self.log(
             "total_loss",

View File

@@ -1,6 +1,7 @@
 # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
 import torch
 from tqdm import tqdm
+import os
 
 from AR.models.utils import make_pad_mask
 from AR.models.utils import (
@@ -51,6 +52,12 @@ class Text2SemanticDecoder(nn.Module):
         # should be same as num of kmeans bin
         # assert self.EOS == 1024
         self.bert_proj = nn.Linear(1024, self.embedding_dim)
+        self.speaker_proj = nn.Linear(1024, self.embedding_dim)
+        self.path_speaker = "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt"
+        if not os.path.exists(self.path_speaker):
+            self.speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)
+        else:
+            self.speaker_feat = nn.Parameter(torch.load(self.path_speaker, map_location="cpu"))
         self.ar_text_embedding = TokenEmbedding(
             self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
         )
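
This hunk adds a 1024-d learnable speaker vector that is restored from disk when a previously saved copy exists and otherwise initialised with small random values. A minimal standalone sketch of that load-or-init pattern (the file name below is illustrative, not the commit's hard-coded path):

import os
import torch
import torch.nn as nn

path_speaker = "speaker_feat.pt"  # hypothetical location for the saved vector
if os.path.exists(path_speaker):
    # Resume from a previously trained speaker vector.
    speaker_feat = nn.Parameter(torch.load(path_speaker, map_location="cpu"))
else:
    # Fresh start: small random initialisation.
    speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)

# After (or during) training, persist the current value for the next run.
torch.save(speaker_feat.data, path_speaker)
print(speaker_feat.shape)  # torch.Size([1024])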
@@ -104,6 +111,7 @@ class Text2SemanticDecoder(nn.Module):
         x_len = x_lens.max()
         y_len = y_lens.max()
         y_emb = self.ar_audio_embedding(y)
+        y_emb = y_emb + self.speaker_proj(self.speaker_feat).view(1,1,-1)
         y_pos = self.ar_audio_position(y_emb)
 
         xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
@@ -330,6 +338,7 @@ class Text2SemanticDecoder(nn.Module):
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
+        use_speaker_feat=True,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -353,12 +362,16 @@ class Text2SemanticDecoder(nn.Module):
             "first_infer": 1,
             "stage": 0,
         }
+        if use_speaker_feat:
+            speaker_feat = self.speaker_proj(self.speaker_feat).view(1,1,-1)
+        else:
+            speaker_feat = 0
         for idx in tqdm(range(1500)):
             if cache["first_infer"] == 1:
-                y_emb = self.ar_audio_embedding(y)
+                y_emb = self.ar_audio_embedding(y) + speaker_feat
             else:
                 y_emb = torch.cat(
-                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
+                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:]) + speaker_feat], 1
                 )
             cache["y_emb"] = y_emb
             y_pos = self.ar_audio_position(y_emb)
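
In both training and inference the injection itself is the same operation: the speaker vector is projected to the decoder embedding size and broadcast-added to every audio token embedding before positional encoding. A minimal standalone sketch, assuming a 1024-d speaker vector and an illustrative embedding size of 512:

import torch
import torch.nn as nn

embedding_dim = 512                                     # illustrative decoder embedding size
speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)    # learned speaker vector
speaker_proj = nn.Linear(1024, embedding_dim)           # 1024 -> embedding_dim projection

y_emb = torch.randn(2, 17, embedding_dim)               # (batch, time, dim) audio token embeddings
y_emb = y_emb + speaker_proj(speaker_feat).view(1, 1, -1)  # broadcast over batch and time
print(y_emb.shape)  # torch.Size([2, 17, 512])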