modify freeze_quantizer mode, avoid quantizer's codebook updating (#953)

This commit is contained in:
hcwu1993 2024-05-02 21:26:44 +08:00 committed by GitHub
parent 0b806dba37
commit a95e6c13b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,6 +15,7 @@ from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer from module.quantize import ResidualVectorQuantizer
from text import symbols from text import symbols
from torch.cuda.amp import autocast from torch.cuda.amp import autocast
import contextlib
class StochasticDurationPredictor(nn.Module): class StochasticDurationPredictor(nn.Module):
@ -890,9 +891,10 @@ class SynthesizerTrn(nn.Module):
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1) self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024) self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
if freeze_quantizer: self.freeze_quantizer = freeze_quantizer
self.ssl_proj.requires_grad_(False) # if freeze_quantizer:
self.quantizer.requires_grad_(False) # self.ssl_proj.requires_grad_(False)
# self.quantizer.requires_grad_(False)
#self.quantizer.eval() #self.quantizer.eval()
# self.enc_p.text_embedding.requires_grad_(False) # self.enc_p.text_embedding.requires_grad_(False)
# self.enc_p.encoder_text.requires_grad_(False) # self.enc_p.encoder_text.requires_grad_(False)
@ -905,6 +907,11 @@ class SynthesizerTrn(nn.Module):
ge = self.ref_enc(y * y_mask, y_mask) ge = self.ref_enc(y * y_mask, y_mask)
with autocast(enabled=False): with autocast(enabled=False):
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext
with maybe_no_grad:
if self.freeze_quantizer:
self.ssl_proj.eval()
self.quantizer.eval()
ssl = self.ssl_proj(ssl) ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer( quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0] ssl, layers=[0]