Update models.py

This commit is contained in:
XXXXRT666 2024-05-23 02:00:46 +01:00 committed by XXXXRT666
parent 3e47ee19c8
commit b5c707ebf2

View File

@ -906,6 +906,11 @@ class SynthesizerTrn(nn.Module):
ge = self.ref_enc(y * y_mask, y_mask)
with autocast(enabled=False):
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
if self.freeze_quantizer:
self.ssl_proj.eval()
self.quantizer.eval()
ssl = self.ssl_proj(ssl)
quantized, codes, commit_loss, quantized_list = self.quantizer(
ssl, layers=[0]