XXXXRT666 c1a4ff476c .
2025-10-19 21:51:54 +01:00

159 lines
4.9 KiB
Python

from typing import cast
import torch
from . import nn
Tensor = torch.Tensor
# based on ComfyUI's and MinusZoneAI's fp8_linear optimization
def fp8_linear_forward(cls: nn.Linear, input: Tensor):
weight_dtype = cls.weight.dtype
base_dtype = input.dtype
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if len(input.shape) == 3:
input_shape = input.shape
scale_weight: Tensor | None = getattr(cls, "scale_weight", None)
if scale_weight is None:
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
else:
scale_weight = scale_weight.to(input.device).squeeze()
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
inn = (
input.reshape(-1, input_shape[2]).to(torch.float8_e4m3fn).contiguous()
) # always e4m3fn because e5m2 * e5m2 is not supported
bias = cls.bias if cls.bias is not None else None
o = torch._scaled_mm(
inn, cls.weight.t(), out_dtype=base_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight
)
return o.reshape((-1, input_shape[1], cls.weight.shape[0]))
else:
raise
else:
raise
def convert_fp8_linear(
module: nn.Module,
):
apply_fn = fp8_linear_forward
for _, sub in list(module.named_modules()):
if isinstance(sub, nn.Linear):
if getattr(sub, "_fp8", False):
continue
setattr(sub, "forward", apply_fn)
setattr(sub, "_fp8", True)
return module
def per_tensor_quantize(tensor: torch.Tensor) -> tuple[Tensor, Tensor]:
"""Quantize a tensor using per-tensor static scaling factor.
Args:
tensor: The input tensor.
"""
finfo = torch.finfo(torch.float8_e4m3fn)
# Calculate the scale as dtype max divided by absmax.
# Since .abs() creates a new tensor, we use aminmax to get
# the min and max first and then calculate the absmax.
if tensor.numel() == 0:
# Deal with empty tensors (triggered by empty MoE experts)
min_val, max_val = (
torch.tensor(0.0, dtype=tensor.dtype),
torch.tensor(1.0, dtype=tensor.dtype),
)
else:
min_val, max_val = tensor.aminmax()
amax = min_val.abs().max(max_val.abs())
scale = finfo.max / amax.clamp(min=1e-12)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(torch.float8_e4m3fn)
scale = scale.float().reciprocal()
return qweight, scale
def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
cuda_compute_capability = torch.cuda.get_device_capability()
if cuda_compute_capability >= (9, 0):
output, _ = torch._scaled_mm(
A,
B.t(),
out_dtype=out_dtype,
scale_a=A_scale,
scale_b=B_scale,
bias=bias,
)
else:
output = torch.nn.functional.linear(
A.to(out_dtype) * A_scale,
B.to(out_dtype) * B_scale.to(out_dtype),
bias=bias,
)
return output
class FP8DynamicLinear(nn.Module):
def __init__(self, qweight: Tensor, scale: Tensor, bias: Tensor):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
self.bias = bias
def __call__(self, x):
qinput, x_scale = per_tensor_quantize(x)
output = fp8_gemm(
A=qinput,
A_scale=x_scale,
B=self.weight,
B_scale=self.weight_scale,
bias=self.bias,
out_dtype=x.dtype,
)
return output
def replace_all_linear_with_fp8(model: nn.Module):
"""
Recursively replace every nn.Linear with FP8DynamicLinear in-place.
"""
def _recursively_replace(parent: nn.Module):
for child_name, child in list(parent.named_children()):
child = cast(nn.Module, child)
if isinstance(child, FP8DynamicLinear):
continue
if isinstance(child, nn.Linear):
device = child.weight.device
w = child.weight
b = child.bias.clone()
qw, qs = per_tensor_quantize(w)
quant_linear = FP8DynamicLinear(qw, qs, b)
quant_linear.to(device)
setattr(parent, child_name, quant_linear)
del child
else:
_recursively_replace(child)
_recursively_replace(model)