mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-28 00:38:15 +08:00
159 lines
4.9 KiB
Python
159 lines
4.9 KiB
Python
from typing import cast
|
|
|
|
import torch
|
|
|
|
from . import nn
|
|
|
|
Tensor = torch.Tensor
|
|
|
|
|
|
# based on ComfyUI's and MinusZoneAI's fp8_linear optimization
|
|
def fp8_linear_forward(cls: nn.Linear, input: Tensor):
|
|
weight_dtype = cls.weight.dtype
|
|
base_dtype = input.dtype
|
|
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
|
if len(input.shape) == 3:
|
|
input_shape = input.shape
|
|
|
|
scale_weight: Tensor | None = getattr(cls, "scale_weight", None)
|
|
if scale_weight is None:
|
|
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
|
else:
|
|
scale_weight = scale_weight.to(input.device).squeeze()
|
|
|
|
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
|
|
|
input = torch.clamp(input, min=-448, max=448, out=input)
|
|
inn = (
|
|
input.reshape(-1, input_shape[2]).to(torch.float8_e4m3fn).contiguous()
|
|
) # always e4m3fn because e5m2 * e5m2 is not supported
|
|
|
|
bias = cls.bias if cls.bias is not None else None
|
|
|
|
o = torch._scaled_mm(
|
|
inn, cls.weight.t(), out_dtype=base_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight
|
|
)
|
|
|
|
return o.reshape((-1, input_shape[1], cls.weight.shape[0]))
|
|
else:
|
|
raise
|
|
else:
|
|
raise
|
|
|
|
|
|
def convert_fp8_linear(
    module: nn.Module,
):
    """Patch every ``nn.Linear`` under ``module`` to run ``fp8_linear_forward``.

    Each matching submodule gets its ``forward`` attribute replaced and is
    tagged with ``_fp8 = True``; submodules already tagged are skipped, so the
    function can be called repeatedly on the same model.

    Args:
        module: Root module whose submodules (including itself) are scanned.

    Returns:
        The same ``module`` object, modified in place.
    """
    from functools import partial

    for _, sub in list(module.named_modules()):
        if not isinstance(sub, nn.Linear):
            continue
        if getattr(sub, "_fp8", False):
            continue
        # A plain function stored on an instance is NOT bound to it, so
        # ``sub.forward(x)`` would call ``fp8_linear_forward(x)`` without the
        # module argument. Bind the module explicitly.
        setattr(sub, "forward", partial(fp8_linear_forward, sub))
        setattr(sub, "_fp8", True)
    return module
|
|
|
|
|
|
def per_tensor_quantize(tensor: torch.Tensor) -> tuple[Tensor, Tensor]:
|
|
"""Quantize a tensor using per-tensor static scaling factor.
|
|
Args:
|
|
tensor: The input tensor.
|
|
"""
|
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
|
# Calculate the scale as dtype max divided by absmax.
|
|
# Since .abs() creates a new tensor, we use aminmax to get
|
|
# the min and max first and then calculate the absmax.
|
|
if tensor.numel() == 0:
|
|
# Deal with empty tensors (triggered by empty MoE experts)
|
|
min_val, max_val = (
|
|
torch.tensor(0.0, dtype=tensor.dtype),
|
|
torch.tensor(1.0, dtype=tensor.dtype),
|
|
)
|
|
else:
|
|
min_val, max_val = tensor.aminmax()
|
|
amax = min_val.abs().max(max_val.abs())
|
|
scale = finfo.max / amax.clamp(min=1e-12)
|
|
# scale and clamp the tensor to bring it to
|
|
# the representative range of float8 data type
|
|
# (as default cast is unsaturated)
|
|
qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
|
|
# Return both float8 data and the inverse scale (as float),
|
|
# as both required as inputs to torch._scaled_mm
|
|
qweight = qweight.to(torch.float8_e4m3fn)
|
|
scale = scale.float().reciprocal()
|
|
return qweight, scale
|
|
|
|
|
|
def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
|
|
cuda_compute_capability = torch.cuda.get_device_capability()
|
|
if cuda_compute_capability >= (9, 0):
|
|
output, _ = torch._scaled_mm(
|
|
A,
|
|
B.t(),
|
|
out_dtype=out_dtype,
|
|
scale_a=A_scale,
|
|
scale_b=B_scale,
|
|
bias=bias,
|
|
)
|
|
else:
|
|
output = torch.nn.functional.linear(
|
|
A.to(out_dtype) * A_scale,
|
|
B.to(out_dtype) * B_scale.to(out_dtype),
|
|
bias=bias,
|
|
)
|
|
return output
|
|
|
|
|
|
class FP8DynamicLinear(nn.Module):
|
|
def __init__(self, qweight: Tensor, scale: Tensor, bias: Tensor):
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
|
|
self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
|
|
self.bias = bias
|
|
|
|
def __call__(self, x):
|
|
qinput, x_scale = per_tensor_quantize(x)
|
|
output = fp8_gemm(
|
|
A=qinput,
|
|
A_scale=x_scale,
|
|
B=self.weight,
|
|
B_scale=self.weight_scale,
|
|
bias=self.bias,
|
|
out_dtype=x.dtype,
|
|
)
|
|
return output
|
|
|
|
|
|
def replace_all_linear_with_fp8(model: nn.Module):
    """
    Recursively replace every nn.Linear with FP8DynamicLinear in-place.

    Each linear layer's weight is quantized with ``per_tensor_quantize`` and
    its bias is copied; layers created without a bias are supported. Children
    that are already ``FP8DynamicLinear`` are left untouched. Returns ``None``;
    ``model`` is modified in place.
    """

    def _recursively_replace(parent: nn.Module):
        # Snapshot the children because the parent is mutated while iterating.
        for child_name, child in list(parent.named_children()):
            child = cast(nn.Module, child)
            if isinstance(child, FP8DynamicLinear):
                continue

            if not isinstance(child, nn.Linear):
                _recursively_replace(child)
                continue

            device = child.weight.device

            # nn.Linear(bias=False) has ``bias is None``; pass that through
            # instead of calling ``.clone()`` on None.
            b = child.bias.clone() if child.bias is not None else None

            qw, qs = per_tensor_quantize(child.weight)

            quant_linear = FP8DynamicLinear(qw, qs, b)
            quant_linear.to(device)

            setattr(parent, child_name, quant_linear)

    _recursively_replace(model)
|