# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Residual vector quantizer implementation."""

from dataclasses import dataclass, field
import typing as tp

import torch
from torch import nn

from module.core_vq import ResidualVectorQuantization


@dataclass
class QuantizedResult:
    quantized: torch.Tensor
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    penalty: tp.Optional[torch.Tensor] = None
    metrics: dict = field(default_factory=dict)


class ResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified
            threshold with a randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.n_q = n_q
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
        )
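
    # Each quantizer in the stack operates on the residual left by the previous
    # one: layer 0 quantizes the input, layer 1 quantizes the error of layer 0,
    # and so on, so codes from later layers progressively refine the
    # approximation produced by the earlier ones.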
    def forward(
        self,
        x: torch.Tensor,
        n_q: tp.Optional[int] = None,
        layers: tp.Optional[list] = None,
    ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, tp.List[torch.Tensor]]:
        """Residual vector quantization on the given input tensor.

        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. Default: all quantizers.
            layers (list): Indices of the layers whose individual quantized output
                should also be returned. Default: None.
        Returns:
            The quantized (or approximately quantized) representation, the codes
            produced by each quantizer, the mean commitment loss, and the list of
            per-layer quantized outputs requested through `layers`.
        """
        n_q = n_q if n_q else self.n_q
        if layers and max(layers) >= n_q:
            raise ValueError(
                f"Last layer index in layers: {max(layers)}. "
                f"Number of quantizers in RVQ: {self.n_q}. "
                "The last layer index must be less than the number of quantizers."
            )
        quantized, codes, commit_loss, quantized_list = self.vq(
            x, n_q=n_q, layers=layers
        )
        return quantized, codes, torch.mean(commit_loss), quantized_list
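
    # encode()/decode() form the inference path: encode() maps features to one
    # index tensor per quantizer, and decode() sums the corresponding codebook
    # vectors (starting from layer `st`) back into a continuous representation.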
    def encode(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
    ) -> torch.Tensor:
        """Encode the given input tensor into discrete codes.

        The RVQ encode method sets the appropriate number of quantizers to use
        and returns the indices for each quantizer.

        Args:
            x (torch.Tensor): Input tensor.
            n_q (int): Number of quantizers used to quantize. Default: all quantizers.
            st (int): Index of the layer from which to start encoding. Default: 0.
        """
        n_q = n_q if n_q else self.n_q
        st = st or 0
        codes = self.vq.encode(x, n_q=n_q, st=st)
        return codes
    def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor:
        """Decode the given codes to the quantized representation.

        Args:
            codes (torch.Tensor): Input indices for each quantizer.
            st (int): Index of the layer from which to start decoding. Default: 0.
        """
        quantized = self.vq.decode(codes, st=st)
        return quantized
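

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). It assumes the
    # ResidualVectorQuantization layers in module.core_vq take inputs of shape
    # (batch, dimension, frames), as in the EnCodec-style implementation this
    # file appears to derive from; adjust the dummy shape if core_vq differs.
    rvq = ResidualVectorQuantizer(dimension=256, n_q=4, bins=64)
    x = torch.randn(2, 256, 100)  # dummy features: (batch, dimension, frames)

    # Training-style call: returns the summed quantized output, the codes of
    # every quantizer, the mean commitment loss, and any per-layer outputs
    # requested through `layers`.
    quantized, codes, commit_loss, _ = rvq(x, n_q=4)

    # Inference-style round trip: encode to discrete indices, then decode them
    # back into a continuous approximation of `x`.
    indices = rvq.encode(x, n_q=4, st=0)
    reconstructed = rvq.decode(indices, st=0)
    print(quantized.shape, indices.shape, reconstructed.shape, commit_loss.item())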