modify freeze_quantizer mode to avoid updating the quantizer's codebook (#953)

hcwu1993 2024-05-02 21:26:44 +08:00 committed by GitHub
parent de57e59139
commit 2981788b4c

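What this commit does: instead of permanently calling `requires_grad_(False)` on `ssl_proj` and the quantizer at construction time, it stores the flag and forces both modules into `eval()` mode (inside a `no_grad` block) on every forward pass. The motivation, as I read it: the `ResidualVectorQuantizer` here is EnCodec-style, and such codebooks are updated by an exponential moving average that runs during any train-mode forward, not by gradient descent, so disabling gradients alone never stopped the codebook from updating. A toy sketch of that failure mode (illustrative only, not code from this repo; the EMA rule is a stand-in for the real quantizer's):

import torch
import torch.nn as nn

class ToyEMACodebook(nn.Module):
    # Stand-in for an EnCodec-style codebook: it moves via EMA, not gradients.
    def __init__(self, bins=4, dim=2, decay=0.99):
        super().__init__()
        self.decay = decay
        self.register_buffer("embed", torch.randn(bins, dim))

    def forward(self, x):
        idx = torch.cdist(x, self.embed).argmin(dim=-1)
        if self.training:  # EMA update is gated on train mode only
            for i in idx.unique():
                mean = x[idx == i].mean(dim=0)
                self.embed[i] = self.decay * self.embed[i] + (1 - self.decay) * mean
        return self.embed[idx]

cb = ToyEMACodebook()
cb.requires_grad_(False)              # the old freeze: gradients only
before = cb.embed.clone()
cb(torch.randn(8, 2))
print(torch.equal(before, cb.embed))  # False: the codebook drifted anyway
cb.eval()                             # the new freeze: EMA no longer runs
after = cb.embed.clone()
cb(torch.randn(8, 2))
print(torch.equal(after, cb.embed))   # True: eval() pins the codebook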

@@ -15,6 +15,7 @@ from module.mrte_model import MRTE
 from module.quantize import ResidualVectorQuantizer
 from text import symbols
 from torch.cuda.amp import autocast
+import contextlib


 class StochasticDurationPredictor(nn.Module):
@@ -890,9 +891,10 @@ class SynthesizerTrn(nn.Module):
         self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
         self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
-        if freeze_quantizer:
-            self.ssl_proj.requires_grad_(False)
-            self.quantizer.requires_grad_(False)
+        self.freeze_quantizer = freeze_quantizer
+        # if freeze_quantizer:
+        #     self.ssl_proj.requires_grad_(False)
+        #     self.quantizer.requires_grad_(False)
             #self.quantizer.eval()
             # self.enc_p.text_embedding.requires_grad_(False)
             # self.enc_p.encoder_text.requires_grad_(False)
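The constructor now only records the intent; the freeze itself is applied in forward() (next hunk). One likely reason for deferring it: a one-time `eval()`-style freeze at construction would be undone by the usual `model.train()` call at the top of a training loop, whereas re-applying it on every forward survives that. A minimal demonstration of the pitfall, with a placeholder module rather than the real quantizer:

import torch.nn as nn

quantizer = nn.Linear(2, 2)      # placeholder for the real quantizer
quantizer.eval()                 # freeze applied once, constructor-style
model = nn.Sequential(quantizer)
model.train()                    # typical training-loop call propagates to children
print(quantizer.training)        # True again: the one-time freeze was undone
# Hence the diff re-applies eval() inside forward(), every step.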
@@ -905,6 +907,11 @@ class SynthesizerTrn(nn.Module):
         ge = self.ref_enc(y * y_mask, y_mask)
         with autocast(enabled=False):
+            maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
+            with maybe_no_grad:
+                if self.freeze_quantizer:
+                    self.ssl_proj.eval()
+                    self.quantizer.eval()
             ssl = self.ssl_proj(ssl)
             quantized, codes, commit_loss, quantized_list = self.quantizer(
                 ssl, layers=[0]
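Two details worth flagging. First, `contextlib.nullcontext` must be instantiated (`nullcontext()`, as above); handing the bare class to `with` raises at runtime. Second, as the context lines show, only the `eval()` calls sit inside the `no_grad` block here, so the projection and quantizer forwards below it still build an autograd graph. A stricter variant of the same pattern (my generalization, not this repo's code) also keeps the forward pass inside the conditional context:

import contextlib
import torch
import torch.nn as nn

def frozen_forward(module: nn.Module, x: torch.Tensor, freeze: bool) -> torch.Tensor:
    # Optionally disable autograd AND force eval mode, then run the module.
    ctx = torch.no_grad() if freeze else contextlib.nullcontext()
    with ctx:
        if freeze:
            module.eval()  # stops dropout/EMA-style train-time updates
        return module(x)

proj = nn.Conv1d(4, 4, 1)
x = torch.randn(1, 4, 10, requires_grad=True)
y = frozen_forward(proj, x, freeze=True)
print(y.requires_grad)  # False: no graph was built through the frozen module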