diff --git a/GPT_SoVITS/AR/models/t2s_lightning_module.py b/GPT_SoVITS/AR/models/t2s_lightning_module.py
index f9dfc648..39427695 100644
--- a/GPT_SoVITS/AR/models/t2s_lightning_module.py
+++ b/GPT_SoVITS/AR/models/t2s_lightning_module.py
@@ -23,7 +23,8 @@ class Text2SemanticLightningModule(LightningModule):
             # print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
             print(
                 self.load_state_dict(
-                    torch.load(pretrained_s1, map_location="cpu")["weight"]
+                    torch.load(pretrained_s1, map_location="cpu")["weight"],
+                    strict=False,
                 )
             )
         if is_train:
@@ -31,6 +32,12 @@ class Text2SemanticLightningModule(LightningModule):
             self.save_hyperparameters()
             self.eval_dir = output_dir / "eval"
             self.eval_dir.mkdir(parents=True, exist_ok=True)
+            for param in self.model.parameters():
+                param.requires_grad = False
+            self.model.speaker_proj.weight.requires_grad = True
+            self.model.speaker_proj.bias.requires_grad = True
+            self.model.speaker_proj.train()
+            self.model.speaker_feat.requires_grad = True
 
     def training_step(self, batch: Dict, batch_idx: int):
         opt = self.optimizers()
@@ -47,6 +54,7 @@ class Text2SemanticLightningModule(LightningModule):
             opt.step()
             opt.zero_grad()
             scheduler.step()
+            torch.save(self.model.speaker_feat.data, "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt")
 
         self.log(
             "total_loss",
diff --git a/GPT_SoVITS/AR/models/t2s_model.py b/GPT_SoVITS/AR/models/t2s_model.py
index d3e550d1..111663ae 100644
--- a/GPT_SoVITS/AR/models/t2s_model.py
+++ b/GPT_SoVITS/AR/models/t2s_model.py
@@ -1,6 +1,7 @@
 # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
 import torch
 from tqdm import tqdm
+import os
 
 from AR.models.utils import make_pad_mask
 from AR.models.utils import (
@@ -51,6 +52,12 @@ class Text2SemanticDecoder(nn.Module):
         # should be same as num of kmeans bin
         # assert self.EOS == 1024
         self.bert_proj = nn.Linear(1024, self.embedding_dim)
+        self.speaker_proj = nn.Linear(1024, self.embedding_dim)
+        self.path_speaker = "C:/Users/86150/Desktop/GPT-SoVITS/zyj.pt"
+        if not os.path.exists(self.path_speaker):
+            self.speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)
+        else:
+            self.speaker_feat = nn.Parameter(torch.load(self.path_speaker, map_location="cpu"))
         self.ar_text_embedding = TokenEmbedding(
             self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
         )
@@ -104,6 +111,7 @@ class Text2SemanticDecoder(nn.Module):
         x_len = x_lens.max()
         y_len = y_lens.max()
         y_emb = self.ar_audio_embedding(y)
+        y_emb = y_emb + self.speaker_proj(self.speaker_feat).view(1, 1, -1)
         y_pos = self.ar_audio_position(y_emb)
 
         xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
@@ -330,6 +338,7 @@ class Text2SemanticDecoder(nn.Module):
         top_p: int = 100,
         early_stop_num: int = -1,
         temperature: float = 1.0,
+        use_speaker_feat: bool = True,
     ):
         x = self.ar_text_embedding(x)
         x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -353,12 +362,16 @@ class Text2SemanticDecoder(nn.Module):
             "first_infer": 1,
             "stage": 0,
         }
+        if use_speaker_feat:
+            speaker_feat = self.speaker_proj(self.speaker_feat).view(1, 1, -1)
+        else:
+            speaker_feat = 0
         for idx in tqdm(range(1500)):
             if cache["first_infer"] == 1:
-                y_emb = self.ar_audio_embedding(y)
+                y_emb = self.ar_audio_embedding(y) + speaker_feat
             else:
                 y_emb = torch.cat(
-                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
+                    [cache["y_emb"], self.ar_audio_embedding(y[:, -1:]) + speaker_feat], 1
                 )
             cache["y_emb"] = y_emb
             y_pos = self.ar_audio_position(y_emb)
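
For context, this patch freezes the pretrained s1 (GPT) weights and fine-tunes only a learnable 1024-dim speaker vector (`speaker_feat`) plus a linear projection (`speaker_proj`) whose output is added to every audio-token embedding during both training and autoregressive inference. Below is a minimal, self-contained sketch of that pattern; `ToyDecoder` and its dimensions are illustrative assumptions, not code from this repository.

```python
# Minimal sketch of the technique in this patch: freeze a pretrained decoder
# and train only a 1024-dim speaker vector plus a linear projection that is
# added onto every audio-token embedding.
# `ToyDecoder` and its sizes are illustrative assumptions, not repo code.
import torch
import torch.nn as nn


class ToyDecoder(nn.Module):
    def __init__(self, vocab_size: int = 1025, embedding_dim: int = 512):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)
        # Learnable speaker vector, analogous to `speaker_feat` in the patch.
        self.speaker_feat = nn.Parameter(torch.randn(1024) * 0.1)
        # Maps the speaker vector into the embedding space, like `speaker_proj`.
        self.speaker_proj = nn.Linear(1024, embedding_dim)

    def embed(self, tokens: torch.Tensor, use_speaker_feat: bool = True) -> torch.Tensor:
        emb = self.token_embedding(tokens)  # (B, T, D)
        if use_speaker_feat:
            # Broadcast the projected speaker vector over batch and time.
            emb = emb + self.speaker_proj(self.speaker_feat).view(1, 1, -1)
        return emb


model = ToyDecoder()

# Freeze everything, then re-enable gradients only for the speaker components,
# mirroring what the patched LightningModule does when is_train=True.
for p in model.parameters():
    p.requires_grad = False
model.speaker_feat.requires_grad = True
for p in model.speaker_proj.parameters():
    p.requires_grad = True

print([n for n, p in model.named_parameters() if p.requires_grad])
# ['speaker_feat', 'speaker_proj.weight', 'speaker_proj.bias']

tokens = torch.randint(0, 1025, (2, 7))
print(model.embed(tokens).shape)  # torch.Size([2, 7, 512])
```

The learned vector can then be persisted with `torch.save(model.speaker_feat.data, path)` and reloaded before inference, which is what the patch does with the hardcoded `zyj.pt` path in `training_step` and in `Text2SemanticDecoder.__init__`.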