Mirror of https://github.com/RVC-Boss/GPT-SoVITS.git

commit 6d2e5319d3
parent 35106c977e

test: feasibility of embedding a speaker vector
AR/models/t2s_lightning_module.py
@@ -23,7 +23,8 @@ class Text2SemanticLightningModule(LightningModule):
             # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
             print(
                 self.load_state_dict(
-                    torch.load(pretrained_s1, map_location="cpu")["weight"]
+                    torch.load(pretrained_s1, map_location="cpu")["weight"],
+                    strict=False,
                 )
             )
         if is_train:
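Note: the switch to strict=False is what lets the old pretrained s1 checkpoint load into a model that now carries extra speaker parameters; a strict load would raise on the missing keys. A minimal runnable sketch of that behavior, with toy modules standing in for the repo's classes:

import torch
import torch.nn as nn

class Old(nn.Module):
    def __init__(self):
        super().__init__()
        self.core = nn.Linear(4, 4)

class New(Old):
    def __init__(self):
        super().__init__()
        self.speaker_proj = nn.Linear(4, 4)  # new module, absent from the checkpoint

ckpt = Old().state_dict()
result = New().load_state_dict(ckpt, strict=False)
print(result.missing_keys)  # ['speaker_proj.weight', 'speaker_proj.bias']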
@@ -31,6 +32,12 @@ class Text2SemanticLightningModule(LightningModule):
             self.save_hyperparameters()
             self.eval_dir = output_dir / "eval"
             self.eval_dir.mkdir(parents=True, exist_ok=True)
+            for param in self.model.parameters():
+                param.requires_grad = False
+            self.model.speaker_proj.weight.requires_grad = True
+            self.model.speaker_proj.bias.requires_grad = True
+            self.model.speaker_proj.train()
+            self.model.speaker_feat.requires_grad = True

     def training_step(self, batch: Dict, batch_idx: int):
         opt = self.optimizers()
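Note: this hunk implements a freeze-then-unfreeze scheme so that only the speaker projection and the raw speaker vector receive gradients, while the rest of the decoder stays fixed. A self-contained sketch of the same pattern (Toy is a hypothetical stand-in for Text2SemanticDecoder):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)         # stays frozen
        self.speaker_proj = nn.Linear(1024, 8)  # re-enabled below
        self.speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)

model = Toy()
for param in model.parameters():                 # 1) freeze everything
    param.requires_grad = False
model.speaker_proj.weight.requires_grad = True   # 2) un-freeze only the
model.speaker_proj.bias.requires_grad = True     #    speaker-specific parts
model.speaker_feat.requires_grad = True

print([n for n, p in model.named_parameters() if p.requires_grad])
# ['speaker_proj.weight', 'speaker_proj.bias', 'speaker_feat']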
@@ -47,6 +54,7 @@ class Text2SemanticLightningModule(LightningModule):
         opt.step()
         opt.zero_grad()
         scheduler.step()
+        torch.save(self.model.speaker_feat.data, "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt")

         self.log(
             "total_loss",
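Note: the learned speaker vector is dumped to a hard-coded local path after every optimizer step, which suits a quick feasibility test rather than general checkpointing. In isolation the round trip looks like this (a relative path is used here for the sketch):

import torch

speaker_feat = torch.nn.Parameter(torch.randn(1024) * 0.1)

# .data detaches from the autograd graph, so only the raw tensor is written
torch.save(speaker_feat.data, "zyj.pt")
restored = torch.load("zyj.pt", map_location="cpu")
assert torch.equal(speaker_feat.data, restored)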
AR/models/t2s_model.py
@@ -1,6 +1,7 @@
 # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
 import torch
 from tqdm import tqdm
+import os

 from AR.models.utils import make_pad_mask
 from AR.models.utils import (
@@ -51,6 +52,12 @@ class Text2SemanticDecoder(nn.Module):
         # should be same as num of kmeans bin
         # assert self.EOS == 1024
         self.bert_proj = nn.Linear(1024, self.embedding_dim)
+        self.speaker_proj = nn.Linear(1024, self.embedding_dim)
+        self.path_speaker = "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt"
+        if not os.path.exists(self.path_speaker):
+            self.speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)
+        else:
+            self.speaker_feat = nn.Parameter(torch.load(self.path_speaker, map_location="cpu"))
         self.ar_text_embedding = TokenEmbedding(
             self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
         )
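Note: on construction the model either initializes a fresh 1024-dim speaker vector or restores the one saved during training, keyed on whether the file exists. A standalone sketch of the same init-or-restore logic (placeholder relative path):

import os
import torch
import torch.nn as nn

path_speaker = "zyj.pt"  # the commit hard-codes an absolute Windows path
if not os.path.exists(path_speaker):
    # cold start: small random 1024-dim vector (same 0.1 scale as the commit)
    speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)
else:
    # warm start: resume from the vector saved by training_step
    speaker_feat = nn.Parameter(torch.load(path_speaker, map_location="cpu"))
print(speaker_feat.shape)  # torch.Size([1024])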
@@ -104,6 +111,7 @@ class Text2SemanticDecoder(nn.Module):
         x_len = x_lens.max()
         y_len = y_lens.max()
         y_emb = self.ar_audio_embedding(y)
+        y_emb = y_emb + self.speaker_proj(self.speaker_feat).view(1,1,-1)
         y_pos = self.ar_audio_position(y_emb)

         xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
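Note: the projected speaker vector is reshaped to (1, 1, embedding_dim) so it broadcasts over both batch and time when added to the audio embeddings. A shape check with illustrative dimensions (512 is not taken from the repo's config):

import torch
import torch.nn as nn

embedding_dim = 512
speaker_proj = nn.Linear(1024, embedding_dim)
speaker_feat = torch.randn(1024)

y_emb = torch.randn(2, 7, embedding_dim)          # (batch, time, dim)
bias = speaker_proj(speaker_feat).view(1, 1, -1)  # (1, 1, dim)
print((y_emb + bias).shape)                       # torch.Size([2, 7, 512])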
@@ -330,6 +338,7 @@ class Text2SemanticDecoder(nn.Module):
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
+        use_speaker_feat=True,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -353,12 +362,16 @@ class Text2SemanticDecoder(nn.Module):
             "first_infer": 1,
             "stage": 0,
         }
+        if use_speaker_feat:
+            speaker_feat = self.speaker_proj(self.speaker_feat).view(1,1,-1)
+        else:
+            speaker_feat = 0
         for idx in tqdm(range(1500)):
             if cache["first_infer"] == 1:
-                y_emb = self.ar_audio_embedding(y)
+                y_emb = self.ar_audio_embedding(y) + speaker_feat
             else:
                 y_emb = torch.cat(
-                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
+                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:]) + speaker_feat], 1
                 )
             cache["y_emb"] = y_emb
             y_pos = self.ar_audio_position(y_emb)
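Note: at inference the speaker bias can be switched off; the else branch assigns the integer 0, which broadcasts over any tensor shape, so the same "+ speaker_feat" expression works in both branches. A toy-shape sketch:

import torch

speaker_bias = torch.randn(1, 1, 512)  # projected speaker vector (toy shape)

for use_speaker_feat in (True, False):
    speaker_feat = speaker_bias if use_speaker_feat else 0
    y_emb = torch.randn(2, 1, 512) + speaker_feat  # adding 0 is a no-op
    print(y_emb.shape)  # torch.Size([2, 1, 512]) either way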