From a95e6c13b80736f500929b284d4d8f1aafec7c5a Mon Sep 17 00:00:00 2001
From: hcwu1993 <15855138469@163.com>
Date: Thu, 2 May 2024 21:26:44 +0800
Subject: [PATCH] modify freeze_quantizer mode, avoid quantizer's codebook
 updating (#953)

---
Review note: contextlib.nullcontext must be instantiated
(contextlib.nullcontext()); passing the bare class to `with` fails whenever
freeze_quantizer is False, because the class object itself is not a context
manager.

 GPT_SoVITS/module/models.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py
index 29676f4..0059033 100644
--- a/GPT_SoVITS/module/models.py
+++ b/GPT_SoVITS/module/models.py
@@ -15,6 +15,7 @@ from module.mrte_model import MRTE
 from module.quantize import ResidualVectorQuantizer
 from text import symbols
 from torch.cuda.amp import autocast
+import contextlib
 
 
 class StochasticDurationPredictor(nn.Module):
@@ -890,9 +891,10 @@ class SynthesizerTrn(nn.Module):
         self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
 
         self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
-        if freeze_quantizer:
-            self.ssl_proj.requires_grad_(False)
-            self.quantizer.requires_grad_(False)
+        self.freeze_quantizer = freeze_quantizer
+        # if freeze_quantizer:
+        #     self.ssl_proj.requires_grad_(False)
+        #     self.quantizer.requires_grad_(False)
             #self.quantizer.eval()
         # self.enc_p.text_embedding.requires_grad_(False)
         # self.enc_p.encoder_text.requires_grad_(False)
@@ -905,6 +907,11 @@ class SynthesizerTrn(nn.Module):
         ge = self.ref_enc(y * y_mask, y_mask)
 
         with autocast(enabled=False):
+            maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
+            with maybe_no_grad:
+                if self.freeze_quantizer:
+                    self.ssl_proj.eval()
+                    self.quantizer.eval()
             ssl = self.ssl_proj(ssl)
             quantized, codes, commit_loss, quantized_list = self.quantizer(
                 ssl, layers=[0]