mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-05 19:41:56 +08:00
modify freeze_quantizer mode, avoid quantizer's codebook updating (#953)
This commit is contained in:
parent
0b806dba37
commit
a95e6c13b8
@ -15,6 +15,7 @@ from module.mrte_model import MRTE
|
|||||||
from module.quantize import ResidualVectorQuantizer
|
from module.quantize import ResidualVectorQuantizer
|
||||||
from text import symbols
|
from text import symbols
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
|
||||||
class StochasticDurationPredictor(nn.Module):
|
class StochasticDurationPredictor(nn.Module):
|
||||||
@ -890,9 +891,10 @@ class SynthesizerTrn(nn.Module):
|
|||||||
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
|
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
|
||||||
|
|
||||||
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
||||||
if freeze_quantizer:
|
self.freeze_quantizer = freeze_quantizer
|
||||||
self.ssl_proj.requires_grad_(False)
|
# if freeze_quantizer:
|
||||||
self.quantizer.requires_grad_(False)
|
# self.ssl_proj.requires_grad_(False)
|
||||||
|
# self.quantizer.requires_grad_(False)
|
||||||
#self.quantizer.eval()
|
#self.quantizer.eval()
|
||||||
# self.enc_p.text_embedding.requires_grad_(False)
|
# self.enc_p.text_embedding.requires_grad_(False)
|
||||||
# self.enc_p.encoder_text.requires_grad_(False)
|
# self.enc_p.encoder_text.requires_grad_(False)
|
||||||
@ -905,6 +907,11 @@ class SynthesizerTrn(nn.Module):
|
|||||||
ge = self.ref_enc(y * y_mask, y_mask)
|
ge = self.ref_enc(y * y_mask, y_mask)
|
||||||
|
|
||||||
with autocast(enabled=False):
|
with autocast(enabled=False):
|
||||||
|
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext
|
||||||
|
with maybe_no_grad:
|
||||||
|
if self.freeze_quantizer:
|
||||||
|
self.ssl_proj.eval()
|
||||||
|
self.quantizer.eval()
|
||||||
ssl = self.ssl_proj(ssl)
|
ssl = self.ssl_proj(ssl)
|
||||||
quantized, codes, commit_loss, quantized_list = self.quantizer(
|
quantized, codes, commit_loss, quantized_list = self.quantizer(
|
||||||
ssl, layers=[0]
|
ssl, layers=[0]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user