mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-06-29 17:58:17 +08:00
.
This commit is contained in:
parent
915ed53c76
commit
c1a4ff476c
4
.github/build_windows_packages.ps1
vendored
4
.github/build_windows_packages.ps1
vendored
@ -157,12 +157,12 @@ Write-Host "[INFO] Installing PyTorch..."
|
||||
switch ($cuda) {
|
||||
"cu126" {
|
||||
& ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location --no-cache-dir
|
||||
& ".\runtime\python.exe" -m pip install torch --index-url https://download.pytorch.org/whl/cu126 --no-warn-script-location --no-cache-dir
|
||||
& ".\runtime\python.exe" -m pip install torch torchao --index-url https://download.pytorch.org/whl/cu126 --no-warn-script-location --no-cache-dir
|
||||
& ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation --no-cache-dir
|
||||
}
|
||||
"cu128" {
|
||||
& ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location --no-cache-dir
|
||||
& ".\runtime\python.exe" -m pip install torch --index-url https://download.pytorch.org/whl/cu128 --no-warn-script-location --no-cache-dir
|
||||
& ".\runtime\python.exe" -m pip install torch torchao --index-url https://download.pytorch.org/whl/cu128 --no-warn-script-location --no-cache-dir
|
||||
& ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation --no-cache-dir
|
||||
}
|
||||
default {
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -20,6 +20,7 @@ tools/AP_BWE/24kto48k/*
|
||||
!tools/AP_BWE/24kto48k/readme.txt
|
||||
onnx_export
|
||||
compile_cache
|
||||
profiler
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
||||
@ -60,10 +60,10 @@ source "$HOME/.bashrc"
|
||||
"$HOME/miniconda3/bin/conda" install gcc=11 gxx ffmpeg cmake make unzip $SYSROOT_PKG "libstdcxx-ng>=11" -q -y
|
||||
|
||||
if [ "$CUDA_VERSION" = "12.8" ]; then
|
||||
"$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu128
|
||||
"$HOME/miniconda3/bin/pip" install torch torchao --no-cache-dir --index-url https://download.pytorch.org/whl/cu128
|
||||
"$HOME/miniconda3/bin/conda" install cuda-nvcc=12.8 -c nvidia
|
||||
elif [ "$CUDA_VERSION" = "12.6" ]; then
|
||||
"$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
|
||||
"$HOME/miniconda3/bin/pip" install torch torchao --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
|
||||
"$HOME/miniconda3/bin/conda" install cuda-nvcc=12.6 -c nvidia
|
||||
fi
|
||||
|
||||
|
||||
@ -5,8 +5,10 @@ if importlib.util.find_spec("mlx") is not None and platform.system() == "Darwin"
|
||||
from .sample_funcs_mlx import sample_naive as sample_naive_mlx
|
||||
from .t2s_engine_mlx import T2SEngine as T2SEngineMLX
|
||||
|
||||
backends = ["mlx_static", "mlx_quantized_mxfp4", "mlx_quantized_affine", "mlx_varlen"]
|
||||
backends = ["mlx_static", "mlx_varlen"]
|
||||
else:
|
||||
backends = []
|
||||
|
||||
__all__ = ["T2SEngineMLX", "sample_naive_mlx", "backends"]
|
||||
quantization_methods_mlx = [None, "MXFP4", "Affine"]
|
||||
|
||||
__all__ = ["T2SEngineMLX", "sample_naive_mlx", "backends", "quantization_methods_mlx"]
|
||||
|
||||
@ -1,179 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from ..structs_mlx import KVCacheQ
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
KVCache,
|
||||
KVCacheHND,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Array = mx.array
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
self.kc_class = KVCacheHND
|
||||
|
||||
@staticmethod
|
||||
def quantized_scaled_dot_product_attention(
|
||||
queries: Array,
|
||||
q_keys: tuple[Array, Array, Array],
|
||||
q_values: tuple[Array, Array, Array],
|
||||
scale: float,
|
||||
mask: Array,
|
||||
group_size: int = 32,
|
||||
bits: int = 8,
|
||||
) -> Array:
|
||||
queries *= scale
|
||||
|
||||
scores = mx.quantized_matmul(queries, *q_keys, transpose=True, group_size=group_size, bits=bits)
|
||||
scores = mx.where(mask, scores, -mx.inf)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True) # type: ignore
|
||||
out = mx.quantized_matmul(scores, *q_values, transpose=False, group_size=group_size, bits=bits)
|
||||
|
||||
return out
|
||||
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).split(3, axis=-1)
|
||||
|
||||
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
|
||||
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
|
||||
kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
|
||||
assert len(kv_cache) == 2
|
||||
|
||||
max_idx = int(input_pos.max())
|
||||
|
||||
q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache))
|
||||
|
||||
mask = attn_mask[..., :max_idx]
|
||||
|
||||
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
|
||||
|
||||
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
# def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
# bsz, seqlen, _ = x.shape
|
||||
|
||||
# q, k, v = self.in_proj(x).split(3, axis=-1)
|
||||
|
||||
# q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
|
||||
# q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
|
||||
# kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
|
||||
|
||||
# assert len(kv_cache) == 3
|
||||
# (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits) = kv_cache
|
||||
|
||||
# k_q, k_s, k_b, v_q, v_s, v_b = map(lambda x: x[..., : int(input_pos.max()), :], (k_q, k_s, k_b, v_q, v_s, v_b))
|
||||
|
||||
# mask = attn_mask[..., : int(input_pos.max())]
|
||||
|
||||
# attn = Attention.quantized_scaled_dot_product_attention(
|
||||
# q,
|
||||
# (k_q, k_s, k_b),
|
||||
# (v_q, v_s, v_b),
|
||||
# self.scale,
|
||||
# mask,
|
||||
# group_size,
|
||||
# bits,
|
||||
# )
|
||||
|
||||
# attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
# output = self.out_proj(attn)
|
||||
|
||||
# return output
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length, *args, **kwds)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length, *args, **kwds)
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
n_layer: int,
|
||||
n_head: int,
|
||||
ffn_dim: int,
|
||||
vocab_size: int,
|
||||
max_seq_length: int,
|
||||
max_batch_size: int,
|
||||
*args,
|
||||
**kwds,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
self.layers = [
|
||||
TransformerBlock(
|
||||
n_head,
|
||||
ffn_dim,
|
||||
hidden_dim,
|
||||
max_seq_length,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
for _ in range(n_layer)
|
||||
]
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 2000,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.h = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHND
|
||||
self.group_size = 32
|
||||
self.bits = 8
|
||||
self.mode = "affine"
|
||||
|
||||
def set_mode(self, mode: str):
|
||||
assert mode in ["affine", "mxfp4"]
|
||||
self.mode = mode
|
||||
if self.mode == "mxfp4":
|
||||
self.bits = 4
|
||||
else:
|
||||
self.bits = 8
|
||||
|
||||
def quantized(self):
|
||||
nn.quantize(self, self.group_size, self.bits, mode=self.mode)
|
||||
# for layer in self.h.layers:
|
||||
# nn.quantize(layer.feed_forward, self.group_size, self.bits)
|
||||
# nn.quantize(layer.attention, self.group_size, self.bits)
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from ..structs_mlx import KVCache, KVCacheQ
|
||||
from ..structs_mlx import KVCache
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
KVCacheHND,
|
||||
@ -19,23 +19,24 @@ class Attention(AttentionABC):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
self.kc_class = KVCacheHND
|
||||
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache, cache_idx: Array, attn_mask: Array):
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).split(3, axis=-1)
|
||||
qkv = self.in_proj(x)
|
||||
|
||||
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
q, k, v = mx.split(qkv, 3, -1)
|
||||
|
||||
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
q = q.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
k = k.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
v = v.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
|
||||
assert len(kv_cache) == 2
|
||||
|
||||
k, v = kv_cache
|
||||
|
||||
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attn_mask)
|
||||
|
||||
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
|
||||
attn = attn.transpose(0, 2, 1, 3).reshape(bsz, seqlen, -1)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
@ -85,7 +86,7 @@ class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 2000,
|
||||
max_seq_length: int = 1500,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from ..structs_mlx import KVCache, KVCacheQ
|
||||
from ..structs_mlx import KVCache
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
KVCacheHND,
|
||||
@ -19,7 +19,7 @@ class Attention(AttentionABC):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
self.kc_class = KVCacheHND
|
||||
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache, cache_idx: Array, attn_mask: Array):
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).split(3, axis=-1)
|
||||
@ -29,7 +29,6 @@ class Attention(AttentionABC):
|
||||
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
|
||||
kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
|
||||
assert len(kv_cache) == 2
|
||||
|
||||
max_idx = int(input_pos.max())
|
||||
|
||||
@ -39,7 +38,7 @@ class Attention(AttentionABC):
|
||||
|
||||
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
|
||||
|
||||
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
|
||||
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, -1)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
@ -89,7 +88,7 @@ class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 2000,
|
||||
max_seq_length: int = 1500,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from functools import partial
|
||||
from typing import Protocol
|
||||
|
||||
import mlx.core as mx
|
||||
@ -17,8 +18,56 @@ class SampleProtocolMLX(Protocol):
|
||||
) -> Array: ...
|
||||
|
||||
|
||||
def apply_repetition_penalty(logits: Array, previous_tokens: Array, repetition_penalty: float):
|
||||
batch_idx = mx.arange(previous_tokens.shape[0])
|
||||
selected_logits = logits[batch_idx, previous_tokens]
|
||||
selected_logits = mx.where(
|
||||
selected_logits < 0, selected_logits * repetition_penalty, selected_logits / repetition_penalty
|
||||
)
|
||||
logits[batch_idx, previous_tokens] = selected_logits
|
||||
return logits
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def apply_greedy_sampling(logits: Array):
|
||||
return mx.argmax(logits, axis=-1, keepdims=True).astype(mx.int32)
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def apply_temperature(logits: Array, temperature: float):
|
||||
return logits / temperature
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def apply_top_k(logits: Array, top_k: int):
|
||||
v = mx.topk(logits, top_k)
|
||||
pivot = mx.expand_dims(v[:, 0], -1)
|
||||
logits = mx.where(logits < pivot, -mx.inf, logits)
|
||||
return logits
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def apply_top_p(logits: Array, top_p: float):
|
||||
sorted_indices = mx.argsort(-logits, axis=-1)
|
||||
sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
|
||||
cum_probs = mx.cumsum(mx.softmax(sorted_logits, axis=-1), axis=-1)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[:, -1] = False
|
||||
indices_to_remove = mx.zeros_like(logits).astype(mx.bool_)
|
||||
batch_indices = mx.arange(logits.shape[0])[:, None]
|
||||
indices_to_remove[batch_indices, sorted_indices] = sorted_indices_to_remove
|
||||
logits = mx.where(indices_to_remove, -mx.inf, logits)
|
||||
return logits
|
||||
|
||||
|
||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||
def apply_sampling(logits: Array):
|
||||
gumbel_noise = mx.random.gumbel(shape=logits.shape, dtype=logits.dtype)
|
||||
idx_next = mx.argmax(logits + gumbel_noise, axis=-1, keepdims=True).astype(mx.int32)
|
||||
return idx_next
|
||||
|
||||
|
||||
class sample_naive(SampleProtocolMLX):
|
||||
# @partial(mx.compile)
|
||||
@staticmethod
|
||||
def __call__(
|
||||
logits,
|
||||
@ -28,38 +77,18 @@ class sample_naive(SampleProtocolMLX):
|
||||
top_p,
|
||||
repetition_penalty,
|
||||
):
|
||||
if temperature <= 1e-5:
|
||||
probs = mx.softmax(logits, axis=-1)
|
||||
return mx.argmax(probs, axis=-1, keepdims=True).astype(mx.int32)
|
||||
|
||||
if repetition_penalty != 1.0:
|
||||
batch_idx = mx.arange(previous_tokens.shape[0])
|
||||
previous_tokens = previous_tokens.astype(mx.int64)
|
||||
selected_logists = logits[batch_idx, previous_tokens]
|
||||
selected_logists = mx.where(
|
||||
selected_logists < 0, selected_logists * repetition_penalty, selected_logists / repetition_penalty
|
||||
)
|
||||
logits[batch_idx, previous_tokens] = selected_logists
|
||||
logits = apply_repetition_penalty(logits, previous_tokens, repetition_penalty)
|
||||
|
||||
if temperature <= 1e-5:
|
||||
return apply_greedy_sampling(logits)
|
||||
elif temperature < 1.0:
|
||||
logits = apply_temperature(logits, temperature)
|
||||
|
||||
if top_k < 1025:
|
||||
logits = apply_top_k(logits, top_k)
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_indices = mx.argsort(-logits, axis=-1)
|
||||
sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
|
||||
cum_probs = mx.cumsum(mx.softmax(sorted_logits, axis=-1), axis=-1)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[:, -1] = False
|
||||
indices_to_remove = mx.zeros_like(logits).astype(mx.bool_)
|
||||
batch_indices = mx.arange(logits.shape[0])[:, None]
|
||||
indices_to_remove[batch_indices, sorted_indices] = sorted_indices_to_remove
|
||||
logits = mx.where(indices_to_remove, -mx.inf, logits)
|
||||
logits = apply_top_p(logits, top_p)
|
||||
|
||||
if temperature < 1.0:
|
||||
logits = logits / temperature
|
||||
|
||||
v = mx.topk(logits, top_k)
|
||||
pivot = mx.expand_dims(v[:, 0], -1)
|
||||
logits = mx.where(logits < pivot, -mx.inf, logits)
|
||||
|
||||
gumbel_noise = mx.random.gumbel(shape=logits.shape, dtype=logits.dtype)
|
||||
idx_next = mx.argmax(logits + gumbel_noise, axis=-1, keepdims=True).astype(mx.int32)
|
||||
|
||||
return idx_next
|
||||
return apply_sampling(logits)
|
||||
|
||||
@ -29,6 +29,7 @@ class T2SRequestMLX:
|
||||
early_stop_num: int = -1
|
||||
temperature: float = 1.0
|
||||
repetition_penalty: float = 1.35
|
||||
debug: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, request: T2SRequest) -> T2SRequestMLX:
|
||||
@ -48,29 +49,27 @@ class T2SRequestMLX:
|
||||
request.early_stop_num,
|
||||
request.temperature,
|
||||
request.repetition_penalty,
|
||||
request.debug,
|
||||
)
|
||||
|
||||
|
||||
KVCache: TypeAlias = tuple[Array, Array]
|
||||
KVCacheQ: TypeAlias = tuple[tuple[Array, Array, Array], tuple[Array, Array, Array], tuple[int, int]]
|
||||
|
||||
|
||||
class KVCacheProtocol(Protocol):
|
||||
@staticmethod
|
||||
def empty(kv_cache: KVCache | KVCacheQ) -> None: ...
|
||||
def empty(kv_cache: KVCache) -> None: ...
|
||||
|
||||
@staticmethod
|
||||
def update_cache(
|
||||
input_pos: Array, k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array
|
||||
) -> KVCache | KVCacheQ: ...
|
||||
def update_cache(input_pos: Array, k_val: Array, v_val: Array, kv_cache: KVCache, cache_idx: Array) -> KVCache: ...
|
||||
|
||||
@staticmethod
|
||||
def prefill_kv(k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ) -> None: ...
|
||||
def prefill_kv(k_val: Array, v_val: Array, kv_cache: KVCache) -> None: ...
|
||||
|
||||
@staticmethod
|
||||
def init_cache(
|
||||
batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype, *args, **kwds
|
||||
) -> KVCache | KVCacheQ: ...
|
||||
) -> KVCache: ...
|
||||
|
||||
|
||||
class T2SDecoderProtocol(Protocol):
|
||||
@ -104,7 +103,7 @@ class T2SSessionMLX:
|
||||
self.y_len = y_len
|
||||
|
||||
# Cache
|
||||
self.kv_cache: MutableSequence[KVCache | KVCacheQ]
|
||||
self.kv_cache: MutableSequence[KVCache]
|
||||
self.sample = sample_func()
|
||||
|
||||
# Forward args
|
||||
@ -118,6 +117,7 @@ class T2SSessionMLX:
|
||||
|
||||
self.input_pos = mx.zeros_like(self.prefill_len)
|
||||
self.input_pos += self.prefill_len
|
||||
self.input_pos = self.input_pos.squeeze(0) # 30% Performance Improvement
|
||||
|
||||
# EOS
|
||||
self.completed = mx.array([False] * len(self.x)).astype(mx.bool_)
|
||||
@ -148,5 +148,3 @@ class T2SSessionMLX:
|
||||
|
||||
attn_mask = mx.repeat(mx.expand_dims(attn_mask, 1), decoder.n_head, 1)
|
||||
self.attn_mask = attn_mask
|
||||
|
||||
mx.eval(self.attn_mask)
|
||||
|
||||
@ -1,22 +1,23 @@
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import cast
|
||||
from typing import Literal, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
from rich.progress import BarColumn, Progress, TextColumn
|
||||
|
||||
from ..logger import SpeedColumnToken, console, logger
|
||||
from ..logger import SpeedColumnToken, Timer, console, logger
|
||||
from ..PyTorch.structs import T2SEngineProtocol, T2SRequest, T2SResult
|
||||
from .backends import mlx_quantized, mlx_static, mlx_varlen
|
||||
from .backends import mlx_static, mlx_varlen
|
||||
from .structs_mlx import T2SSessionMLX
|
||||
from .t2s_model_abc import T2SDecoderABC
|
||||
|
||||
Array = mx.array
|
||||
Tensor = torch.Tensor
|
||||
|
||||
timer = Timer()
|
||||
|
||||
|
||||
class T2SEngine(T2SEngineProtocol):
|
||||
def __init__(
|
||||
@ -31,10 +32,10 @@ class T2SEngine(T2SEngineProtocol):
|
||||
device = mx.Device(mx.cpu)
|
||||
case "mx.gpu":
|
||||
device = mx.Device(mx.gpu)
|
||||
|
||||
device = cast(mx.Device, device)
|
||||
match dtype:
|
||||
case torch.float32:
|
||||
dtype = mx.float32
|
||||
dtype = mx.float16 if device.type == mx.gpu else mx.float32
|
||||
case torch.float16:
|
||||
dtype = mx.float16
|
||||
case torch.bfloat16:
|
||||
@ -59,13 +60,14 @@ class T2SEngine(T2SEngineProtocol):
|
||||
decoder = self.decoder_model
|
||||
session = T2SSessionMLX(decoder, request, device=self.device, dtype=self.dtype)
|
||||
batch_idx = mx.arange(session.bsz)
|
||||
debug = request.debug
|
||||
|
||||
t1 = 0.0
|
||||
infer_speed = 0.0
|
||||
infer_time = 0.0
|
||||
idx = 0
|
||||
|
||||
with (
|
||||
mx.stream(session.device),
|
||||
Progress(
|
||||
TextColumn("[cyan]{task.description}"),
|
||||
BarColumn(),
|
||||
@ -75,29 +77,47 @@ class T2SEngine(T2SEngineProtocol):
|
||||
transient=True,
|
||||
) as progress,
|
||||
):
|
||||
max_token = min(2000 - int(session.input_pos.max()), 1500)
|
||||
max_token = min(1500 - int(session.input_pos.max()), 1000) * session.bsz
|
||||
|
||||
task = progress.add_task("T2S Decoding", total=max_token)
|
||||
for idx in range(1500):
|
||||
progress.update(task, advance=1)
|
||||
for idx in range(max_token):
|
||||
progress.update(task, advance=session.bsz)
|
||||
if idx == 0:
|
||||
session.kv_cache = decoder.init_cache(session.bsz)
|
||||
xy_dec = decoder.h.prefill(
|
||||
session.xy_pos,
|
||||
session.attn_mask,
|
||||
session.kv_cache,
|
||||
) # bs, seq_len, embed_dim
|
||||
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
|
||||
t1 = time.perf_counter()
|
||||
with timer("MLX.Prefill", debug=debug):
|
||||
xy_dec = decoder.h.prefill(
|
||||
session.xy_pos,
|
||||
session.attn_mask,
|
||||
session.kv_cache,
|
||||
) # bs, seq_len, embed_dim
|
||||
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
|
||||
if debug:
|
||||
mx.eval(xy_dec)
|
||||
else:
|
||||
args, kwds = decoder.pre_forward(session)
|
||||
xy_dec = decoder.h(
|
||||
session.input_pos,
|
||||
session.xy_pos,
|
||||
session.kv_cache,
|
||||
batch_idx,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
if debug:
|
||||
mx.eval(session.input_pos, session.xy_pos, session.kv_cache, args, kwds, batch_idx)
|
||||
|
||||
if debug and idx == 50 and os.environ.get("MTL_CAPTURE_ENABLED") == "1":
|
||||
os.makedirs("./profiler/mlx", exist_ok=True)
|
||||
mx.metal.start_capture(f"./profiler/mlx/{time.time()}.gputrace")
|
||||
|
||||
with timer("MLX.Decode", debug=debug):
|
||||
xy_dec = decoder.h(
|
||||
session.input_pos,
|
||||
session.xy_pos,
|
||||
session.kv_cache,
|
||||
batch_idx,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
if debug:
|
||||
mx.eval(xy_dec)
|
||||
|
||||
if debug and idx == 50 and os.environ.get("MTL_CAPTURE_ENABLED") == "1":
|
||||
mx.metal.stop_capture()
|
||||
|
||||
decoder.post_forward(idx, session)
|
||||
logits = decoder.ar_predict_layer(xy_dec[:, -1])
|
||||
@ -106,28 +126,38 @@ class T2SEngine(T2SEngineProtocol):
|
||||
if idx == 0:
|
||||
logits[:, -1] = -mx.inf
|
||||
|
||||
samples = session.sample(
|
||||
logits=logits,
|
||||
previous_tokens=session.y[:, : session.y_len + idx],
|
||||
top_k=request.top_k,
|
||||
top_p=request.top_p,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
)
|
||||
with timer("MLX.Sampling", debug=debug):
|
||||
samples = session.sample(
|
||||
logits=logits,
|
||||
previous_tokens=session.y[:, : session.y_len + idx],
|
||||
top_k=request.top_k,
|
||||
top_p=request.top_p,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
)
|
||||
|
||||
session.y[batch_idx, session.y_len + idx] = samples
|
||||
session.y[batch_idx, session.y_len + idx] = samples
|
||||
|
||||
argmax_token = mx.argmax(logits, axis=-1)
|
||||
sample_token = samples.squeeze(1)
|
||||
EOS_mask = (cast(Array, argmax_token == decoder.EOS)) | (sample_token == decoder.EOS)
|
||||
if debug:
|
||||
mx.eval(samples)
|
||||
|
||||
newly_done_mask = EOS_mask & (~session.completed)
|
||||
newly_done_indices = mx.where(newly_done_mask, batch_idx, -1)
|
||||
pos = mx.where(newly_done_indices != -1, batch_idx, session.bsz)
|
||||
pos_sorted = mx.sort(pos, axis=0)
|
||||
valid_count = session.bsz - mx.sum(cast(Array, pos_sorted == session.bsz))
|
||||
pos_final = pos_sorted[: int(valid_count)]
|
||||
newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0)
|
||||
with timer("MLX.EOS", debug=debug):
|
||||
mx.set_default_device(mx.Device(mx.cpu))
|
||||
argmax_token = mx.argmax(logits, axis=-1)
|
||||
sample_token = samples.squeeze(1)
|
||||
EOS_mask = (cast(Array, argmax_token == decoder.EOS)) | (sample_token == decoder.EOS)
|
||||
|
||||
newly_done_mask = EOS_mask & (~session.completed)
|
||||
newly_done_indices = mx.where(newly_done_mask, batch_idx, -1)
|
||||
pos = mx.where(newly_done_indices != -1, batch_idx, session.bsz)
|
||||
pos_sorted = mx.sort(pos, axis=0)
|
||||
valid_count = session.bsz - mx.sum(cast(Array, pos_sorted == session.bsz))
|
||||
pos_final = pos_sorted[: int(valid_count)]
|
||||
newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0)
|
||||
mx.set_default_device(self.device)
|
||||
|
||||
if debug:
|
||||
mx.eval(newly_done_indices)
|
||||
|
||||
if newly_done_indices.size > 0:
|
||||
for i in newly_done_indices:
|
||||
@ -135,54 +165,53 @@ class T2SEngine(T2SEngineProtocol):
|
||||
session.completed[newly_done_indices] = True
|
||||
|
||||
if mx.all(session.completed).item():
|
||||
if session.y[:, session.y_len :].sum() == 0:
|
||||
session.y_results = [mx.array([0]) for _ in range(session.bsz)]
|
||||
logger.error("Bad Zero Prediction")
|
||||
else:
|
||||
logger.info(
|
||||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.shape[-1] for i in session.y_results].__str__().strip('[]')}"
|
||||
)
|
||||
logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx - 1) / infer_time
|
||||
logger.info(
|
||||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.shape[-1] for i in session.y_results].__str__().strip('[]')}"
|
||||
)
|
||||
logger.info(f"Infer Speed: {(idx + 1) / (time.perf_counter() - t1):.2f} token/s")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx + 1) / infer_time
|
||||
break
|
||||
|
||||
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
|
||||
for j in range(session.bsz):
|
||||
if not session.completed[j].item():
|
||||
session.y_results[j] = session.y[[j], session.y_len : session.y_len + 1499]
|
||||
session.y_results[j] = session.y[[j], session.y_len : session.y_len + idx]
|
||||
session.completed[j] = True
|
||||
logger.error("Bad Full Prediction")
|
||||
logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
|
||||
logger.info(f"Infer Speed: {(idx + 1) / (time.perf_counter() - t1):.2f} token/s")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx - 1) / infer_time
|
||||
infer_speed = (idx + 1) / infer_time
|
||||
break
|
||||
|
||||
y_emb = decoder.ar_audio_embedding(samples)
|
||||
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
|
||||
mx.eval(session.xy_pos, session.y)
|
||||
with timer("MLX.NextPos", debug=debug):
|
||||
y_emb = decoder.ar_audio_embedding(samples)
|
||||
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
|
||||
mx.eval(session.xy_pos, session.y)
|
||||
|
||||
if idx == 1:
|
||||
t1 = time.perf_counter()
|
||||
|
||||
if idx % 100 == 0:
|
||||
if idx % 128 == 0:
|
||||
mx.clear_cache()
|
||||
|
||||
match session.device:
|
||||
case mx.gpu:
|
||||
mx.clear_cache()
|
||||
case mx.cpu:
|
||||
gc.collect()
|
||||
|
||||
result_mlx = session.y_results[: request.valid_length]
|
||||
mx.eval(result_mlx)
|
||||
result = [torch.tensor(k) for k in result_mlx]
|
||||
return result, infer_speed, infer_time
|
||||
mx.clear_cache()
|
||||
|
||||
if debug:
|
||||
timer.summary()
|
||||
timer.clear()
|
||||
|
||||
return result, infer_speed, infer_time, (idx + 1) * session.bsz
|
||||
|
||||
def generate(self, request: T2SRequest):
|
||||
try:
|
||||
result, infer_speed, infer_time = self._handle_request(request)
|
||||
t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
|
||||
result, infer_speed, infer_time, total_tokens = self._handle_request(request)
|
||||
t2s_result = T2SResult(
|
||||
result=result,
|
||||
infer_speed=(infer_speed, infer_time),
|
||||
total_tokens=total_tokens,
|
||||
status="Success",
|
||||
)
|
||||
except Exception as e:
|
||||
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
|
||||
return t2s_result
|
||||
@ -199,40 +228,39 @@ class T2SEngine(T2SEngineProtocol):
|
||||
.replace("norm1", "attention_norm")
|
||||
.replace("norm2", "ffn_norm")
|
||||
)
|
||||
value_mlx = mx.array(value) # type: ignore
|
||||
value_mlx = mx.array(value.to(torch.float32).cpu().numpy())
|
||||
state_dict_mlx.append((key, value_mlx))
|
||||
return state_dict_mlx
|
||||
|
||||
@staticmethod
|
||||
def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "MLX-Varlen"):
|
||||
def load_decoder(
|
||||
weights_path: os.PathLike,
|
||||
max_batch_size: int = 1,
|
||||
backend: str = "MLX-Varlen",
|
||||
quantize_mode: Literal["Affine", "MXFP4"] | None = None,
|
||||
) -> T2SDecoderABC:
|
||||
logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
|
||||
dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
|
||||
dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=True, mmap=True)
|
||||
config = dict_s1["config"]
|
||||
match backend:
|
||||
case "MLX-Varlen":
|
||||
decoder_cls: type[T2SDecoderABC] = mlx_varlen.T2SDecoder
|
||||
case "MLX-Static":
|
||||
decoder_cls = mlx_static.T2SDecoder
|
||||
case "MLX-Quantized-Affine" | "MLX-Quantized-MXFP4":
|
||||
decoder_cls = mlx_quantized.T2SDecoder
|
||||
case _:
|
||||
raise RuntimeError(f"Backend {backend} Not Found")
|
||||
|
||||
decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
|
||||
state_dict = dict_s1["weight"]
|
||||
state_dict_mlx = T2SEngine.replace_key(state_dict)
|
||||
|
||||
decoder.load_weights(state_dict_mlx)
|
||||
decoder.eval()
|
||||
|
||||
if quantize_mode is not None:
|
||||
decoder.quantize(quantize_mode)
|
||||
logger.info(
|
||||
f"Quantized to {decoder.bits}-Bit with Group Size {decoder.group_size} by {quantize_mode} Quantization"
|
||||
)
|
||||
|
||||
mx.eval(decoder)
|
||||
|
||||
if "Quantized" in backend and isinstance(decoder, mlx_quantized.T2SDecoder):
|
||||
if backend == "MLX-Quantized-Affine":
|
||||
decoder.set_mode("affine")
|
||||
elif backend == "MLX-Quantized-MXFP4":
|
||||
decoder.set_mode("mxfp4")
|
||||
else:
|
||||
raise RuntimeError(f"Quantized Backend {backend} Not Supported")
|
||||
decoder.quantized()
|
||||
mx.eval(decoder)
|
||||
|
||||
return decoder
|
||||
|
||||
@ -2,12 +2,13 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import MutableSequence
|
||||
from typing import Literal, MutableSequence, Type
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.core import Dtype
|
||||
|
||||
from .structs_mlx import KVCache, KVCacheProtocol, KVCacheQ, T2SDecoderProtocol, T2SSessionMLX
|
||||
from .structs_mlx import KVCache, KVCacheProtocol, T2SDecoderProtocol, T2SSessionMLX
|
||||
|
||||
Array = mx.array
|
||||
|
||||
@ -43,26 +44,28 @@ class SinePositionalEmbedding(nn.Module):
|
||||
embedding_dim: int,
|
||||
scale: bool = False,
|
||||
max_batch_size: int = 10,
|
||||
max_seq_len: int = 2000,
|
||||
max_seq_length: int = 1500,
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
|
||||
self.alpha = mx.ones(1)
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
self.max_seq_length = max_seq_length
|
||||
|
||||
self.reverse = False
|
||||
self._pe = mx.zeros((max_batch_size, max_seq_len, embedding_dim))
|
||||
self.compute_pe()
|
||||
self.pe: Array | None = None
|
||||
|
||||
def compute_pe(self):
|
||||
"""Reset the positional encodings."""
|
||||
def compute_pe(self, dtype: Dtype):
|
||||
"""Compute the positional encodings."""
|
||||
|
||||
if self.pe is not None and self.pe.dtype == dtype:
|
||||
return
|
||||
|
||||
if self.reverse:
|
||||
position = mx.expand_dims(mx.arange(self.max_seq_len - 1, -1, -1.0), axis=1)
|
||||
position = mx.expand_dims(mx.arange(self.max_seq_length - 1, -1, -1.0), axis=1)
|
||||
else:
|
||||
position = mx.expand_dims(mx.arange(self.max_seq_len), axis=1)
|
||||
position = mx.expand_dims(mx.arange(self.max_seq_length), axis=1)
|
||||
div_term = mx.exp(
|
||||
mx.arange(
|
||||
0,
|
||||
@ -70,10 +73,13 @@ class SinePositionalEmbedding(nn.Module):
|
||||
2,
|
||||
)
|
||||
* -(math.log(10000.0) / self.embedding_dim)
|
||||
)
|
||||
pe = self._pe
|
||||
pe[:, :, 0::2] = mx.sin(position * div_term)
|
||||
pe[:, :, 1::2] = mx.cos(position * div_term)
|
||||
).astype(dtype)
|
||||
pe = mx.zeros((self.max_batch_size, self.max_seq_length, self.embedding_dim)).astype(dtype)
|
||||
|
||||
pe[:, :, 0::2] = mx.sin(position * div_term).astype(dtype)
|
||||
pe[:, :, 1::2] = mx.cos(position * div_term).astype(dtype)
|
||||
|
||||
self.pe = pe
|
||||
|
||||
def __call__(self, input_pos: Array, x: Array):
|
||||
"""
|
||||
@ -84,9 +90,11 @@ class SinePositionalEmbedding(nn.Module):
|
||||
Returns:
|
||||
embedded_x (Array): [batch_size, 1, embed_dim]
|
||||
"""
|
||||
self.compute_pe(x.dtype)
|
||||
assert self.pe is not None
|
||||
|
||||
batch_size = x.shape[0]
|
||||
pe_values = self._pe[mx.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
|
||||
pe_values = self.pe[mx.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
|
||||
|
||||
return x * self.x_scale + self.alpha * mx.expand_dims(pe_values, 1) # (batch_size, 1, embed_dim)
|
||||
|
||||
@ -98,7 +106,10 @@ class SinePositionalEmbedding(nn.Module):
|
||||
Returns:
|
||||
embedded_x (Array): [batch_size, seq_len, embed_dim]
|
||||
"""
|
||||
pe_values = self._pe[:, : x.shape[-2]]
|
||||
self.compute_pe(x.dtype)
|
||||
assert self.pe is not None
|
||||
|
||||
pe_values = self.pe[:, : x.shape[-2]]
|
||||
return x * self.x_scale + self.alpha * pe_values
|
||||
|
||||
|
||||
@ -125,12 +136,12 @@ class KVCacheHND(KVCacheProtocol):
|
||||
|
||||
@staticmethod
|
||||
def prefill_kv(k_val, v_val, kv_cache):
|
||||
# k_val: [B, S, H, D]
|
||||
# k_val: [B, H, S, D]
|
||||
assert len(kv_cache) == 2
|
||||
k_cache, v_cache = kv_cache
|
||||
|
||||
k_cache[..., : k_val.shape[1], :] = k_val.swapaxes(1, 2)
|
||||
v_cache[..., : v_val.shape[1], :] = v_val.swapaxes(1, 2)
|
||||
k_cache[..., : k_val.shape[2], :] = k_val
|
||||
v_cache[..., : v_val.shape[2], :] = v_val
|
||||
|
||||
@staticmethod
|
||||
def init_cache(batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype) -> KVCache:
|
||||
@ -139,118 +150,6 @@ class KVCacheHND(KVCacheProtocol):
|
||||
return (mx.zeros(cache_shape, dtype=dtype), mx.zeros(cache_shape, dtype=dtype))
|
||||
|
||||
|
||||
class KVCacheHNDQuantized(KVCacheProtocol):
|
||||
@staticmethod
|
||||
def _el_per_int(bits: int) -> int:
|
||||
return 32 // bits
|
||||
|
||||
@staticmethod
|
||||
def _packed_dim(head_dim: int, bits: int = 8) -> int:
|
||||
el_per_int = KVCacheHNDQuantized._el_per_int(bits)
|
||||
if head_dim % el_per_int != 0:
|
||||
raise ValueError(f"{head_dim=} is not divisible by {el_per_int=} ({bits=})")
|
||||
return head_dim // el_per_int
|
||||
|
||||
@staticmethod
|
||||
def _group_count(head_dim: int, group_size: int = 32) -> int:
|
||||
assert group_size in {32, 64, 128}
|
||||
if head_dim % group_size != 0:
|
||||
raise ValueError(f"{head_dim} is not divisible by {group_size=}")
|
||||
return head_dim // group_size
|
||||
|
||||
@staticmethod
|
||||
def empty(kv_cache) -> None:
|
||||
assert len(kv_cache) == 3
|
||||
(k_q, k_s, k_b), (v_q, v_s, v_b), (_, __) = kv_cache
|
||||
|
||||
k_q[:] = 0
|
||||
k_s[:] = 0
|
||||
k_b[:] = 0
|
||||
v_q[:] = 0
|
||||
v_s[:] = 0
|
||||
v_b[:] = 0
|
||||
|
||||
@staticmethod
|
||||
def update_cache(
|
||||
input_pos,
|
||||
k_val,
|
||||
v_val,
|
||||
kv_cache,
|
||||
cache_idx,
|
||||
):
|
||||
# input_pos: [B, ], k_val: [B, H, 1, D]
|
||||
|
||||
assert len(kv_cache) == 3
|
||||
(k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache
|
||||
|
||||
k_q, k_s, k_b = mx.quantize(k_val, group_size=group_size, bits=bits)
|
||||
v_q, v_s, v_b = mx.quantize(v_val, group_size=group_size, bits=bits)
|
||||
|
||||
ip0 = input_pos - 1
|
||||
|
||||
k_q_out[cache_idx, :, ip0, None] = k_q
|
||||
k_s_out[cache_idx, :, ip0, None] = k_s
|
||||
k_b_out[cache_idx, :, ip0, None] = k_b
|
||||
|
||||
v_q_out[cache_idx, :, ip0, None] = v_q
|
||||
v_s_out[cache_idx, :, ip0, None] = v_s
|
||||
v_b_out[cache_idx, :, ip0, None] = v_b
|
||||
|
||||
return (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits)
|
||||
|
||||
@staticmethod
|
||||
def prefill_kv(
|
||||
k_val,
|
||||
v_val,
|
||||
kv_cache,
|
||||
) -> None:
|
||||
assert len(kv_cache) == 3
|
||||
(k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache
|
||||
|
||||
S = k_val.shape[1]
|
||||
|
||||
k_sw = k_val.swapaxes(1, 2)
|
||||
v_sw = v_val.swapaxes(1, 2)
|
||||
|
||||
k_q, k_s, k_b = mx.quantize(k_sw, group_size=group_size, bits=bits)
|
||||
v_q, v_s, v_b = mx.quantize(v_sw, group_size=group_size, bits=bits)
|
||||
|
||||
k_q_out[..., :S, :] = k_q
|
||||
k_s_out[..., :S, :] = k_s
|
||||
k_b_out[..., :S, :] = k_b
|
||||
|
||||
v_q_out[..., :S, :] = v_q
|
||||
v_s_out[..., :S, :] = v_s
|
||||
v_b_out[..., :S, :] = v_b
|
||||
|
||||
@staticmethod
|
||||
def init_cache(
|
||||
batch_size: int,
|
||||
max_seq_length: int,
|
||||
n_heads: int,
|
||||
head_dim: int,
|
||||
dtype: mx.Dtype,
|
||||
*,
|
||||
group_size: int = 32,
|
||||
bits: int = 8,
|
||||
) -> KVCacheQ:
|
||||
packed_dim = KVCacheHNDQuantized._packed_dim(head_dim, bits=bits)
|
||||
group_cnt = KVCacheHNDQuantized._group_count(head_dim, group_size=group_size)
|
||||
|
||||
packed_shape = (batch_size, n_heads, max_seq_length, packed_dim)
|
||||
group_shape = (batch_size, n_heads, max_seq_length, group_cnt)
|
||||
|
||||
k_q = mx.zeros(packed_shape, dtype=mx.uint32)
|
||||
k_s = mx.zeros(group_shape, dtype=dtype)
|
||||
k_b = mx.zeros(group_shape, dtype=dtype)
|
||||
|
||||
v_q = mx.zeros(packed_shape, dtype=mx.uint32)
|
||||
v_s = mx.zeros(group_shape, dtype=dtype)
|
||||
v_b = mx.zeros(group_shape, dtype=dtype)
|
||||
|
||||
return (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits)
|
||||
|
||||
|
||||
class AttentionABC(ABC, nn.Module):
|
||||
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int, *args, **kwds):
|
||||
super().__init__()
|
||||
@ -271,26 +170,26 @@ class AttentionABC(ABC, nn.Module):
|
||||
self.kc_class: KVCacheProtocol
|
||||
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array
|
||||
) -> Array: ...
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache, cache_idx: Array, attn_mask: Array) -> Array: ...
|
||||
|
||||
def prefill(self, x: Array, kv_cache: KVCache | KVCacheQ, attn_mask: Array):
|
||||
def prefill(self, x: Array, kv_cache: KVCache, attn_mask: Array):
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).split(3, axis=-1)
|
||||
qkv = self.in_proj(x)
|
||||
|
||||
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
q, k, v = mx.split(qkv, 3, -1)
|
||||
|
||||
q = q.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
k = k.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
v = v.reshape(bsz, seqlen, self.n_head, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
self.kc_class.prefill_kv(k, v, kv_cache)
|
||||
|
||||
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
|
||||
attn = mx.fast.scaled_dot_product_attention(q, k, v, mask=attn_mask, scale=self.scale)
|
||||
|
||||
attn = mx.nan_to_num(attn)
|
||||
|
||||
attn = attn.swapaxes(1, 2).reshape(1, -1, self.hidden_dim)
|
||||
attn = attn.transpose(0, 2, 1, 3).reshape(bsz, seqlen, -1)
|
||||
|
||||
output = self.out_proj(attn)
|
||||
|
||||
@ -321,7 +220,7 @@ class TransformerBlockABC(nn.Module):
|
||||
self.attention_norm = nn.LayerNorm(self.hidden_dim)
|
||||
self.ffn_norm = nn.LayerNorm(self.hidden_dim)
|
||||
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache, cache_idx: Array, attn_mask: Array):
|
||||
h = self.attention_norm(
|
||||
x
|
||||
+ self.attention(
|
||||
@ -335,7 +234,7 @@ class TransformerBlockABC(nn.Module):
|
||||
out = self.ffn_norm(h + self.feed_forward(h))
|
||||
return out
|
||||
|
||||
def prefill(self, x: Array, attn_mask: Array, kv_cache: KVCache | KVCacheQ):
|
||||
def prefill(self, x: Array, attn_mask: Array, kv_cache: KVCache):
|
||||
h = self.attention_norm(
|
||||
x
|
||||
+ self.attention.prefill(
|
||||
@ -382,7 +281,7 @@ class TransformerDecoderABC(nn.Module):
|
||||
self,
|
||||
input_pos: Array,
|
||||
x: Array,
|
||||
kv_caches: MutableSequence[KVCache | KVCacheQ],
|
||||
kv_caches: MutableSequence[KVCache],
|
||||
cache_idx: Array,
|
||||
*args,
|
||||
**kwds,
|
||||
@ -399,7 +298,7 @@ class TransformerDecoderABC(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCache | KVCacheQ]):
|
||||
def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCache]):
|
||||
for layer, kv_cache in zip(self.layers, kv_caches):
|
||||
x = layer.prefill(
|
||||
x,
|
||||
@ -413,7 +312,7 @@ class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 2000,
|
||||
max_seq_length: int = 1500,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -451,24 +350,27 @@ class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
|
||||
self.embedding_dim,
|
||||
scale=False,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_length,
|
||||
max_seq_length=max_seq_length,
|
||||
)
|
||||
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
|
||||
self.ar_audio_position = SinePositionalEmbedding(
|
||||
self.embedding_dim,
|
||||
scale=False,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_length,
|
||||
max_seq_length=max_seq_length,
|
||||
)
|
||||
|
||||
self.kv_class: KVCacheProtocol
|
||||
self.kv_class: Type[KVCacheProtocol]
|
||||
|
||||
def init_cache(self, bsz: int = 0, *args, **kwds) -> MutableSequence[KVCache | KVCacheQ]:
|
||||
self.bits: int = -1
|
||||
self.group_size: int = -1
|
||||
|
||||
def init_cache(self, bsz: int = 0, *args, **kwds) -> MutableSequence[KVCache]:
|
||||
bsz = bsz or self.h.max_batch_size
|
||||
assert bsz <= self.h.max_batch_size
|
||||
seq_lens = self.h.max_seq_length
|
||||
dtype = self.bert_proj.bias.dtype
|
||||
cache: MutableSequence[KVCache | KVCacheQ] = [
|
||||
cache: MutableSequence[KVCache] = [
|
||||
self.kv_class.init_cache(bsz, seq_lens, self.n_head, self.head_dim, dtype, *args, **kwds)
|
||||
for _ in range(self.n_layer)
|
||||
]
|
||||
@ -503,8 +405,7 @@ class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
|
||||
return xy_pos
|
||||
|
||||
def compile(self):
|
||||
setattr(self.h, "__call__", mx.compile(self.h.__call__))
|
||||
# setattr(self.h, "prefill", mx.compile(self.h.prefill, shapeless=True))
|
||||
setattr(self.h, "__call__", mx.compile(self.h.__call__, shapeless=True))
|
||||
|
||||
def pre_forward(self, session: T2SSessionMLX):
|
||||
attn_mask = session.attn_mask
|
||||
@ -525,4 +426,21 @@ class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
|
||||
attn_mask = session.attn_mask
|
||||
input_pos = session.input_pos
|
||||
attn_mask[mx.arange(session.bsz), :, :, input_pos] = True
|
||||
mx.eval(attn_mask)
|
||||
|
||||
def quantize(self, mode: Literal["Affine", "MXFP4"] | None = None) -> None:
|
||||
if mode is None:
|
||||
return
|
||||
if mode not in {"Affine", "MXFP4"}:
|
||||
raise ValueError(f"Unsupported quantization mode: {mode}")
|
||||
match mode:
|
||||
case "Affine":
|
||||
self.bits = 8
|
||||
self.group_size = 32
|
||||
nn.quantize(self.h, group_size=self.group_size, bits=self.bits, mode="affine")
|
||||
case "MXFP4":
|
||||
self.bits = 4
|
||||
self.group_size = 32
|
||||
nn.quantize(self.h, group_size=self.group_size, bits=self.bits, mode="mxfp4")
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Unsupported Quantization Mode for MLX: {mode}")
|
||||
|
||||
@ -7,10 +7,23 @@ from .structs import T2SRequest, T2SResult
|
||||
from .t2s_engine import T2SEngine as T2SEngineTorch
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
if torch.__version__ >= "2.9.0":
|
||||
torch.backends.fp32_precision = "tf32" # type: ignore
|
||||
else:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.backends.cuda.preferred_blas_library("cublaslt")
|
||||
|
||||
|
||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
|
||||
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
||||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
backends = ["torch_varlen"]
|
||||
if torch.cuda.is_available():
|
||||
@ -30,5 +43,21 @@ if torch.cuda.is_available():
|
||||
# if torch.mps.is_available():
|
||||
# backends.append("mps_flash_attn_varlen")
|
||||
|
||||
BLACKWELL = False
|
||||
if torch.cuda.is_available():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
major, minor = torch.cuda.get_device_capability(i)
|
||||
sm_version = major + minor / 10.0
|
||||
if sm_version >= 9.0:
|
||||
BLACKWELL = True
|
||||
|
||||
__all__ = ["T2SEngineTorch", "T2SRequest", "sample_naive", "T2SResult", "backends"]
|
||||
quantization_methods_torch: list[str | None] = [None]
|
||||
if importlib.util.find_spec("torchao") is not None:
|
||||
quantization_methods_torch.append("Int8")
|
||||
if BLACKWELL:
|
||||
quantization_methods_torch.append("FP8")
|
||||
if BLACKWELL:
|
||||
quantization_methods_torch.append("FP8_E4M3FN")
|
||||
|
||||
|
||||
__all__ = ["T2SEngineTorch", "T2SRequest", "sample_naive", "T2SResult", "backends", "quantization_methods_torch"]
|
||||
|
||||
@ -100,7 +100,7 @@ class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=2000,
|
||||
max_seq_length=1500,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
assert torch.cuda.is_available()
|
||||
|
||||
@ -78,7 +78,7 @@ class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=2000,
|
||||
max_seq_length=1500,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
@ -94,7 +94,7 @@ class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=2000,
|
||||
max_seq_length=1500,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
@ -78,7 +78,7 @@ class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=2000,
|
||||
max_seq_length=1500,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
@ -86,7 +86,7 @@ class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=2000,
|
||||
max_seq_length=1500,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
158
GPT_SoVITS/Accelerate/PyTorch/quantization.py
Normal file
158
GPT_SoVITS/Accelerate/PyTorch/quantization.py
Normal file
@ -0,0 +1,158 @@
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
|
||||
from . import nn
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
# based on ComfyUI's and MinusZoneAI's fp8_linear optimization
|
||||
def fp8_linear_forward(cls: nn.Linear, input: Tensor):
|
||||
weight_dtype = cls.weight.dtype
|
||||
base_dtype = input.dtype
|
||||
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
if len(input.shape) == 3:
|
||||
input_shape = input.shape
|
||||
|
||||
scale_weight: Tensor | None = getattr(cls, "scale_weight", None)
|
||||
if scale_weight is None:
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
else:
|
||||
scale_weight = scale_weight.to(input.device).squeeze()
|
||||
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
inn = (
|
||||
input.reshape(-1, input_shape[2]).to(torch.float8_e4m3fn).contiguous()
|
||||
) # always e4m3fn because e5m2 * e5m2 is not supported
|
||||
|
||||
bias = cls.bias if cls.bias is not None else None
|
||||
|
||||
o = torch._scaled_mm(
|
||||
inn, cls.weight.t(), out_dtype=base_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight
|
||||
)
|
||||
|
||||
return o.reshape((-1, input_shape[1], cls.weight.shape[0]))
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def convert_fp8_linear(
|
||||
module: nn.Module,
|
||||
):
|
||||
apply_fn = fp8_linear_forward
|
||||
|
||||
for _, sub in list(module.named_modules()):
|
||||
if isinstance(sub, nn.Linear):
|
||||
if getattr(sub, "_fp8", False):
|
||||
continue
|
||||
setattr(sub, "forward", apply_fn)
|
||||
setattr(sub, "_fp8", True)
|
||||
return module
|
||||
|
||||
|
||||
def per_tensor_quantize(tensor: torch.Tensor) -> tuple[Tensor, Tensor]:
|
||||
"""Quantize a tensor using per-tensor static scaling factor.
|
||||
Args:
|
||||
tensor: The input tensor.
|
||||
"""
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
# Calculate the scale as dtype max divided by absmax.
|
||||
# Since .abs() creates a new tensor, we use aminmax to get
|
||||
# the min and max first and then calculate the absmax.
|
||||
if tensor.numel() == 0:
|
||||
# Deal with empty tensors (triggered by empty MoE experts)
|
||||
min_val, max_val = (
|
||||
torch.tensor(0.0, dtype=tensor.dtype),
|
||||
torch.tensor(1.0, dtype=tensor.dtype),
|
||||
)
|
||||
else:
|
||||
min_val, max_val = tensor.aminmax()
|
||||
amax = min_val.abs().max(max_val.abs())
|
||||
scale = finfo.max / amax.clamp(min=1e-12)
|
||||
# scale and clamp the tensor to bring it to
|
||||
# the representative range of float8 data type
|
||||
# (as default cast is unsaturated)
|
||||
qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
# Return both float8 data and the inverse scale (as float),
|
||||
# as both required as inputs to torch._scaled_mm
|
||||
qweight = qweight.to(torch.float8_e4m3fn)
|
||||
scale = scale.float().reciprocal()
|
||||
return qweight, scale
|
||||
|
||||
|
||||
def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
|
||||
cuda_compute_capability = torch.cuda.get_device_capability()
|
||||
if cuda_compute_capability >= (9, 0):
|
||||
output, _ = torch._scaled_mm(
|
||||
A,
|
||||
B.t(),
|
||||
out_dtype=out_dtype,
|
||||
scale_a=A_scale,
|
||||
scale_b=B_scale,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
output = torch.nn.functional.linear(
|
||||
A.to(out_dtype) * A_scale,
|
||||
B.to(out_dtype) * B_scale.to(out_dtype),
|
||||
bias=bias,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class FP8DynamicLinear(nn.Module):
|
||||
def __init__(self, qweight: Tensor, scale: Tensor, bias: Tensor):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
|
||||
self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
|
||||
self.bias = bias
|
||||
|
||||
def __call__(self, x):
|
||||
qinput, x_scale = per_tensor_quantize(x)
|
||||
output = fp8_gemm(
|
||||
A=qinput,
|
||||
A_scale=x_scale,
|
||||
B=self.weight,
|
||||
B_scale=self.weight_scale,
|
||||
bias=self.bias,
|
||||
out_dtype=x.dtype,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def replace_all_linear_with_fp8(model: nn.Module):
|
||||
"""
|
||||
Recursively replace every nn.Linear with FP8DynamicLinear in-place.
|
||||
"""
|
||||
|
||||
def _recursively_replace(parent: nn.Module):
|
||||
for child_name, child in list(parent.named_children()):
|
||||
child = cast(nn.Module, child)
|
||||
if isinstance(child, FP8DynamicLinear):
|
||||
continue
|
||||
|
||||
if isinstance(child, nn.Linear):
|
||||
device = child.weight.device
|
||||
|
||||
w = child.weight
|
||||
|
||||
b = child.bias.clone()
|
||||
|
||||
qw, qs = per_tensor_quantize(w)
|
||||
|
||||
quant_linear = FP8DynamicLinear(qw, qs, b)
|
||||
|
||||
quant_linear.to(device)
|
||||
|
||||
setattr(parent, child_name, quant_linear)
|
||||
|
||||
del child
|
||||
else:
|
||||
_recursively_replace(child)
|
||||
|
||||
_recursively_replace(model)
|
||||
@ -1,20 +1,79 @@
|
||||
from typing import Protocol
|
||||
from typing import Callable, Protocol, TypeVar, cast
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
def script(fn: Callable[P, R]) -> Callable[P, R]:
|
||||
scripted = torch.jit.script(fn)
|
||||
return cast(Callable[P, R], scripted)
|
||||
|
||||
|
||||
@script
|
||||
def apply_repetition_penalty(logits: Tensor, previous_tokens: Tensor, repetition_penalty: float):
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
return logits
|
||||
|
||||
|
||||
@script
|
||||
def apply_greedy_sampling(logits: Tensor):
|
||||
return torch.argmax(logits, dim=-1, keepdim=True).to(dtype=torch.int32)
|
||||
|
||||
|
||||
@script
|
||||
def apply_temperature(logits: Tensor, temperature: float):
|
||||
return logits / temperature
|
||||
|
||||
|
||||
@script
|
||||
def apply_top_k(logits: Tensor, top_k: int):
|
||||
v, _ = torch.topk(logits, top_k)
|
||||
pivot = v[:, -1].unsqueeze(-1)
|
||||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||
return logits
|
||||
|
||||
|
||||
@script
|
||||
def apply_top_p(logits: Tensor, top_p: float):
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
cum_probs[cum_probs > 1] = 1
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
return logits
|
||||
|
||||
|
||||
@script
|
||||
def apply_sampling(logits: Tensor):
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
q = -torch.log(torch.rand_like(probs))
|
||||
idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
|
||||
return idx_next
|
||||
|
||||
|
||||
class SampleProtocol(Protocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
logits: Tensor,
|
||||
previous_tokens: Tensor,
|
||||
repetition_penalty: float,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
repetition_penalty: float,
|
||||
) -> Tensor: ...
|
||||
|
||||
|
||||
@ -23,45 +82,23 @@ class sample_naive(SampleProtocol):
|
||||
def __call__(
|
||||
logits: Tensor,
|
||||
previous_tokens: Tensor,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
repetition_penalty: float,
|
||||
repetition_penalty: float = 1.35,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = 15,
|
||||
top_p: float = 1.0,
|
||||
):
|
||||
if temperature <= 1e-5:
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
return torch.argmax(probs, dim=-1, keepdim=True).to(dtype=torch.int32)
|
||||
|
||||
if repetition_penalty != 1.0:
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
logits = apply_repetition_penalty(logits, previous_tokens, repetition_penalty)
|
||||
|
||||
if temperature <= 1e-5:
|
||||
return apply_greedy_sampling(logits)
|
||||
elif temperature < 1.0:
|
||||
logits = apply_temperature(logits, temperature)
|
||||
|
||||
if top_k < 1025:
|
||||
logits = apply_top_k(logits, top_k)
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
cum_probs[cum_probs > 1] = 1
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
logits = apply_top_p(logits, top_p)
|
||||
|
||||
if temperature < 1.0:
|
||||
logits /= temperature
|
||||
|
||||
v, _ = torch.topk(logits, top_k)
|
||||
pivot = v[:, -1].unsqueeze(-1)
|
||||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
q = -torch.log(torch.rand_like(probs))
|
||||
idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
|
||||
|
||||
return idx_next
|
||||
return apply_sampling(logits)
|
||||
|
||||
@ -18,6 +18,7 @@ Tensor = torch.Tensor
|
||||
class T2SResult:
|
||||
result: list[Tensor] | None = None
|
||||
infer_speed: tuple[float, float] = (0.0, 0.0)
|
||||
total_tokens: int = 0
|
||||
status: Literal["Success", "Error"] = "Success"
|
||||
exception: Optional[Exception] = None
|
||||
traceback: Optional[str] = None
|
||||
@ -66,7 +67,7 @@ class T2SDecoderProtocol(Protocol):
|
||||
|
||||
|
||||
class T2SEngineProtocol(Protocol):
|
||||
def _handle_request(self, request: T2SRequest) -> tuple[list[Tensor], float, float]: ...
|
||||
def _handle_request(self, request: T2SRequest) -> tuple[list[Tensor], float, float, int]: ...
|
||||
|
||||
def generate(self, request: T2SRequest) -> T2SResult: ...
|
||||
|
||||
@ -107,6 +108,7 @@ class T2SSession:
|
||||
|
||||
self.input_pos = torch.zeros_like(self.prefill_len)
|
||||
self.input_pos.add_(self.prefill_len)
|
||||
self.input_pos.squeeze_(0)
|
||||
|
||||
# CUDA Graph
|
||||
self.stream: Optional[torch.cuda.Stream] = None
|
||||
|
||||
@ -1,15 +1,15 @@
|
||||
import contextlib
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from importlib import import_module
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from rich.progress import BarColumn, Progress, TextColumn
|
||||
|
||||
from ..logger import SpeedColumnToken, console, logger
|
||||
from ..logger import SpeedColumnToken, console, logger, timer
|
||||
from .structs import T2SEngineProtocol, T2SRequest, T2SResult, T2SSession
|
||||
from .t2s_model_abc import (
|
||||
CUDAGraphCacheABC,
|
||||
@ -41,12 +41,14 @@ class T2SEngine(T2SEngineProtocol):
|
||||
decoder = self.decoder_model
|
||||
session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
|
||||
batch_idx = torch.arange(session.bsz)
|
||||
debug = request.debug
|
||||
|
||||
t1 = 0.0
|
||||
infer_speed = 0.0
|
||||
infer_time = 0.0
|
||||
idx = 0
|
||||
|
||||
torch_profiler = TorchProfiler(request.debug)
|
||||
torch_profiler = TorchProfiler(debug)
|
||||
with (
|
||||
torch_profiler.profiler(),
|
||||
Progress(
|
||||
@ -59,14 +61,15 @@ class T2SEngine(T2SEngineProtocol):
|
||||
) as progress,
|
||||
):
|
||||
torch_profiler.start()
|
||||
max_token = int(min(2000 - session.input_pos.max(), 1500))
|
||||
max_token = min(int(1500 - session.input_pos.max()), 1000)
|
||||
task = progress.add_task("T2S Decoding", total=max_token)
|
||||
|
||||
for idx in range(max_token):
|
||||
progress.update(task, advance=1)
|
||||
if idx == 0:
|
||||
with torch_profiler.record("Prefill"):
|
||||
with torch_profiler.record("Prefill"), timer("Torch.Prefill", debug=debug):
|
||||
session.kv_cache = decoder.init_cache(session.bsz)
|
||||
t1 = time.perf_counter()
|
||||
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
|
||||
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
|
||||
else:
|
||||
@ -78,7 +81,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
):
|
||||
self.graphcache.assign_graph(session)
|
||||
|
||||
with torch_profiler.record("Decode"):
|
||||
with torch_profiler.record("Decode"), timer("Torch.Decode", debug=debug):
|
||||
if session.graph:
|
||||
assert session.stream
|
||||
session.stream.wait_stream(torch.cuda.default_stream())
|
||||
@ -103,7 +106,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
if idx == 0:
|
||||
logits[:, -1] = float("-inf")
|
||||
|
||||
with torch_profiler.record("Sampling"):
|
||||
with torch_profiler.record("Sampling"), timer("Torch.Sampling", debug=debug):
|
||||
samples = session.sample(
|
||||
logits=logits,
|
||||
previous_tokens=session.y[:, : session.y_len + idx],
|
||||
@ -115,7 +118,7 @@ class T2SEngine(T2SEngineProtocol):
|
||||
session.y[batch_idx, session.y_len + idx] = samples
|
||||
session.input_pos.add_(1)
|
||||
|
||||
with torch_profiler.record("EOS"):
|
||||
with torch_profiler.record("EOS"), timer("Torch.EOS", debug=debug):
|
||||
argmax_token = torch.argmax(logits, dim=-1)
|
||||
sample_token = samples.squeeze(1)
|
||||
EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
|
||||
@ -128,73 +131,48 @@ class T2SEngine(T2SEngineProtocol):
|
||||
session.y_results[i] = session.y[i, session.y_len : session.y_len + idx]
|
||||
session.completed[newly_done_indices] = True
|
||||
|
||||
if torch.all(session.completed).item():
|
||||
if session.y[:, session.y_len :].sum() == 0:
|
||||
session.y_results = [torch.tensor(0) for _ in range(session.bsz)]
|
||||
logger.error("Bad Zero Prediction")
|
||||
else:
|
||||
logger.info(
|
||||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
|
||||
)
|
||||
logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx - 1) / infer_time
|
||||
break
|
||||
if torch.all(session.completed).item():
|
||||
logger.info(
|
||||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
|
||||
)
|
||||
logger.info(
|
||||
f"Infer Speed: {(idx + 1) * session.bsz / (time.perf_counter() - t1):.2f} token/s"
|
||||
)
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx + 1) * session.bsz / infer_time
|
||||
break
|
||||
|
||||
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
|
||||
for i in range(session.bsz):
|
||||
if not session.completed[i].item():
|
||||
session.y_results[i] = session.y[i, session.y_len : session.y_len + 1499]
|
||||
session.completed[i] = True
|
||||
logger.error("Bad Full Prediction")
|
||||
break
|
||||
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
|
||||
for i in range(session.bsz):
|
||||
if not session.completed[i].item():
|
||||
session.y_results[i] = session.y[[i], session.y_len : session.y_len + idx]
|
||||
session.completed[i] = True
|
||||
logger.error("Bad Full Prediction")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx + 1) * session.bsz / infer_time
|
||||
break
|
||||
|
||||
with torch_profiler.record("NextPos"):
|
||||
with torch_profiler.record("NextPos"), timer("Torch.NextPos", debug=debug):
|
||||
y_emb = decoder.ar_audio_embedding(samples)
|
||||
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
|
||||
|
||||
if idx == 1:
|
||||
t1 = time.perf_counter()
|
||||
|
||||
if idx == 20:
|
||||
if idx == 10:
|
||||
torch_profiler.end()
|
||||
|
||||
if idx % 100 == 0:
|
||||
match session.device.type:
|
||||
case "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
case "mps":
|
||||
torch.mps.empty_cache()
|
||||
case "xpu":
|
||||
torch.xpu.empty_cache()
|
||||
case "mtia":
|
||||
torch.mtia.empty_cache()
|
||||
case "cpu":
|
||||
pass
|
||||
|
||||
match session.device.type:
|
||||
case "cuda":
|
||||
if session.stream is not None:
|
||||
torch.cuda.current_stream().wait_stream(session.stream)
|
||||
torch.cuda.empty_cache()
|
||||
case "mps":
|
||||
torch.mps.empty_cache()
|
||||
case "xpu":
|
||||
torch.xpu.empty_cache()
|
||||
case "mtia":
|
||||
torch.mtia.empty_cache()
|
||||
case "cpu":
|
||||
gc.collect(1)
|
||||
|
||||
if request.use_cuda_graph and self.graphcache.is_applicable:
|
||||
self.graphcache.release_graph(session)
|
||||
|
||||
return session.y_results[: request.valid_length], infer_speed, infer_time
|
||||
return session.y_results[: request.valid_length], infer_speed, infer_time, (idx + 1) * session.bsz
|
||||
|
||||
def generate(self, request: T2SRequest):
|
||||
try:
|
||||
result, infer_speed, infer_time = self._handle_request(request)
|
||||
t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
|
||||
result, infer_speed, infer_time, total_tokens = self._handle_request(request)
|
||||
t2s_result = T2SResult(
|
||||
result=result,
|
||||
infer_speed=(infer_speed, infer_time),
|
||||
total_tokens=total_tokens,
|
||||
status="Success",
|
||||
)
|
||||
except Exception as e:
|
||||
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
|
||||
if self.decoder_model.compiled:
|
||||
@ -203,7 +181,12 @@ class T2SEngine(T2SEngineProtocol):
|
||||
return t2s_result
|
||||
|
||||
@staticmethod
|
||||
def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "Flash-Attn-Varlen-CUDAGraph"):
|
||||
def load_decoder(
|
||||
weights_path: os.PathLike,
|
||||
max_batch_size: int = 1,
|
||||
backend: str = "Flash-Attn-Varlen-CUDAGraph",
|
||||
quantize_mode: Literal["Int8", "FP8", "FP8_E4M3FN"] | None = None,
|
||||
) -> T2SDecoderABC:
|
||||
logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
|
||||
module_path = f".backends.{backend.lower().replace('-', '_').replace('cudagraph', 'cuda_graph')}"
|
||||
decoder_cls_name = "T2SDecoder"
|
||||
@ -215,6 +198,10 @@ class T2SEngine(T2SEngineProtocol):
|
||||
state_dict = dict_s1["weight"]
|
||||
decoder.load_state_dict(state_dict)
|
||||
|
||||
if quantize_mode is not None:
|
||||
decoder.quantize(quantize_mode)
|
||||
logger.info(f"Quantized by {quantize_mode} Quantization")
|
||||
|
||||
return decoder.eval()
|
||||
|
||||
def init_cache(self):
|
||||
|
||||
@ -13,7 +13,7 @@ import time
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import MutableSequence
|
||||
from typing import Literal, MutableSequence
|
||||
|
||||
import torch
|
||||
import torch._inductor.config
|
||||
@ -24,6 +24,7 @@ from torch.profiler import ExecutionTraceObserver, ProfilerAction, tensorboard_t
|
||||
from tools.my_utils import get_machine_id
|
||||
|
||||
from . import nn
|
||||
from .quantization import replace_all_linear_with_fp8
|
||||
from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession
|
||||
|
||||
Tensor = torch.Tensor
|
||||
@ -61,26 +62,26 @@ class SinePositionalEmbedding(nn.Module):
|
||||
scale: bool = False,
|
||||
alpha: bool = False,
|
||||
max_batch_size: int = 10,
|
||||
max_seq_len: int = 2000,
|
||||
max_seq_length: int = 1500,
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
|
||||
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
self.max_seq_length = max_seq_length
|
||||
|
||||
self.reverse = False
|
||||
self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
|
||||
self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_length, embedding_dim), persistent=False)
|
||||
self.pe: torch.Tensor
|
||||
self.compute_pe()
|
||||
|
||||
def compute_pe(self):
|
||||
"""Reset the positional encodings."""
|
||||
if self.reverse:
|
||||
position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
|
||||
position = torch.arange(self.max_seq_length - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
|
||||
else:
|
||||
position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
|
||||
position = torch.arange(self.max_seq_length, dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
|
||||
)
|
||||
@ -423,7 +424,7 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 2000,
|
||||
max_seq_length: int = 1500,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -467,7 +468,7 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
|
||||
scale=False,
|
||||
alpha=True,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_length,
|
||||
max_seq_length=max_seq_length,
|
||||
)
|
||||
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
|
||||
self.ar_audio_position = SinePositionalEmbedding(
|
||||
@ -475,9 +476,12 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
|
||||
scale=False,
|
||||
alpha=True,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_length,
|
||||
max_seq_length=max_seq_length,
|
||||
)
|
||||
|
||||
self.bits: int
|
||||
self.group_size: int
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
|
||||
@ -608,6 +612,32 @@ class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
|
||||
def post_forward(self, idx: int, session: T2SSession) -> None:
|
||||
return
|
||||
|
||||
def quantize(self, mode: Literal["Int8", "FP8", "FP8_E4M3FN"] | None = None) -> None:
|
||||
if mode is None:
|
||||
return
|
||||
if mode not in {"Int8", "FP8", "FP8_E4M3FN"}:
|
||||
raise ValueError(f"Unsupported quantization mode: {mode}")
|
||||
match mode:
|
||||
case "Int8":
|
||||
self.bits = 8
|
||||
self.group_size = 32
|
||||
import torchao
|
||||
|
||||
torchao.quantization.quantize_(self.h, torchao.quantization.Int8WeightOnlyConfig(self.group_size))
|
||||
|
||||
case "FP8":
|
||||
self.bits = 8
|
||||
import torchao
|
||||
|
||||
torchao.quantization.quantize_(self.h, torchao.quantization.Float8WeightOnlyConfig())
|
||||
|
||||
case "FP8_E4M3FN":
|
||||
self.bits = 8
|
||||
replace_all_linear_with_fp8(self.h)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Unsupported Quantization Mode for PyTorch: {mode}")
|
||||
|
||||
|
||||
class CUDAGraphCacheABC(ABC):
|
||||
def __init__(
|
||||
@ -662,13 +692,13 @@ class CUDAGraphCacheABC(ABC):
|
||||
|
||||
|
||||
class TorchProfiler:
|
||||
def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
|
||||
self.debug = debug
|
||||
self.log_dir = log_dir + str(time.time())
|
||||
def __init__(self, debug: bool, log_dir: str = "./profiler/torch") -> None:
|
||||
self.debug = debug and os.environ.get("TORCH_PROFILER") == "1"
|
||||
self.log_dir = log_dir + "/" + str(time.time())
|
||||
self.__profiler: torch.profiler.profile
|
||||
|
||||
if self.debug and not os.path.exists(self.log_dir):
|
||||
os.makedirs(self.log_dir)
|
||||
os.makedirs(self.log_dir, exist_ok=True)
|
||||
|
||||
self.tensorboard_handler = tensorboard_trace_handler(self.log_dir)
|
||||
|
||||
|
||||
@ -1,18 +1,13 @@
|
||||
from . import MLX, PyTorch
|
||||
from .logger import console, logger, tb
|
||||
from .PyTorch import T2SEngineTorch, T2SRequest, T2SResult
|
||||
from .MLX import quantization_methods_mlx
|
||||
from .PyTorch import T2SEngineTorch, T2SRequest, T2SResult, quantization_methods_torch
|
||||
from .PyTorch.structs import T2SEngineProtocol
|
||||
|
||||
backends = PyTorch.backends + MLX.backends
|
||||
|
||||
backends = [
|
||||
b.replace("_", "-")
|
||||
.title()
|
||||
.replace("Mlx", "MLX")
|
||||
.replace("Mps", "MPS")
|
||||
.replace("Cuda", "CUDA")
|
||||
.replace("Mxfp4", "MXFP4")
|
||||
for b in backends
|
||||
b.replace("_", "-").title().replace("Mlx", "MLX").replace("Mps", "MPS").replace("Cuda", "CUDA") for b in backends
|
||||
]
|
||||
|
||||
|
||||
@ -27,4 +22,6 @@ __all__ = [
|
||||
"console",
|
||||
"tb",
|
||||
"T2SEngineProtocol",
|
||||
"quantization_methods_torch",
|
||||
"quantization_methods_mlx",
|
||||
]
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
@ -201,3 +204,46 @@ if __name__ == "__main__":
|
||||
raise RuntimeError()
|
||||
except Exception:
|
||||
logger.bind(show_locals=False).exception("TEST")
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.records: dict[str, list[float]] = defaultdict(list)
|
||||
self._stack: list[tuple[str, int]] = []
|
||||
|
||||
def __call__(self, category: str, debug=False):
|
||||
timer = self
|
||||
|
||||
class _Ctx:
|
||||
def __enter__(self):
|
||||
timer._stack.append((category, time.perf_counter_ns()))
|
||||
return timer # 如需在with块里调用timer方法
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
end = time.perf_counter_ns()
|
||||
if not timer._stack:
|
||||
raise RuntimeError("Timer stack underflow: __exit__ without matching __enter__")
|
||||
cat, start = timer._stack.pop()
|
||||
if cat != category:
|
||||
raise RuntimeError(f"Mismatched timer context: expected '{cat}', got '{category}'")
|
||||
elapsed_sec = (end - start) / 1e9
|
||||
timer.records[cat].append(elapsed_sec)
|
||||
return False
|
||||
|
||||
if debug:
|
||||
return _Ctx()
|
||||
else:
|
||||
return nullcontext()
|
||||
|
||||
def clear(self):
|
||||
self.records.clear()
|
||||
self._stack.clear()
|
||||
|
||||
def summary(self):
|
||||
for cat, times in self.records.items():
|
||||
total = sum(times)
|
||||
avg = total / len(times) if times else 0.0
|
||||
print(f"{cat}: count={len(times)}, total={total:.6f}s, avg={avg:.6f}s")
|
||||
|
||||
|
||||
timer = Timer()
|
||||
|
||||
@ -51,13 +51,11 @@ warnings.filterwarnings(
|
||||
)
|
||||
warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*")
|
||||
|
||||
logging.getLogger("markdown_it").setLevel(logging.ERROR)
|
||||
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||
logging.getLogger("httpcore").setLevel(logging.ERROR)
|
||||
logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||
logging.getLogger("asyncio").setLevel(logging.ERROR)
|
||||
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
|
||||
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
||||
logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@ -88,6 +86,12 @@ def lang_type(text: str) -> str:
|
||||
return "Auto"
|
||||
|
||||
|
||||
def none_or_str(value: str):
|
||||
if value == "None":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(
|
||||
prog="inference_webui",
|
||||
@ -108,6 +112,15 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
help="AR Inference Backend",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--quantization",
|
||||
"-q",
|
||||
default="None",
|
||||
choices=MLX.quantization_methods_mlx + PyTorch.quantization_methods_torch,
|
||||
type=none_or_str,
|
||||
help="Quantization Method",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--device",
|
||||
"-d",
|
||||
@ -393,10 +406,9 @@ with contextlib.suppress(UnboundLocalError):
|
||||
|
||||
def change_gpt_weights(gpt_path):
|
||||
global t2s_engine, config
|
||||
|
||||
if "mlx" in ar_backend.lower():
|
||||
t2s_engine = MLX.T2SEngineMLX(
|
||||
MLX.T2SEngineMLX.load_decoder(Path(gpt_path), backend=ar_backend),
|
||||
MLX.T2SEngineMLX.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization),
|
||||
"mx.gpu" if infer_device.type != "cpu" else "mx.cpu",
|
||||
dtype=dtype,
|
||||
)
|
||||
@ -404,7 +416,7 @@ def change_gpt_weights(gpt_path):
|
||||
total = sum((p[-1].size for p in mxutils.tree_flatten(t2s_engine.decoder_model.parameters()))) # type: ignore
|
||||
else:
|
||||
t2s_engine = PyTorch.T2SEngineTorch(
|
||||
PyTorch.T2SEngineTorch.load_decoder(Path(gpt_path), backend=ar_backend),
|
||||
PyTorch.T2SEngineTorch.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization),
|
||||
device,
|
||||
dtype=dtype,
|
||||
)
|
||||
@ -824,7 +836,7 @@ def get_tts_wav(
|
||||
pred_semantic_list = t2s_result.result
|
||||
assert pred_semantic_list, t2s_result.traceback
|
||||
pred_semantic = pred_semantic_list[0].unsqueeze(0).to(infer_device)
|
||||
infer_len.append(pred_semantic.shape[-1])
|
||||
infer_len.append(t2s_result.total_tokens)
|
||||
infer_time.append(t2s_result.infer_speed[-1])
|
||||
|
||||
cache[i_text] = pred_semantic
|
||||
|
||||
@ -30,6 +30,7 @@ pinyin_to_symbol_map = {
|
||||
parent_directory = os.path.dirname(current_file_path)
|
||||
|
||||
is_g2pw = os.getenv("G2PW", "1") == "1"
|
||||
debug = os.getenv("DEBUG", "0") == "1"
|
||||
if is_g2pw:
|
||||
g2pw = G2PWPinyin(
|
||||
model_dir="GPT_SoVITS/text/G2PWModel",
|
||||
@ -202,7 +203,8 @@ def _g2p(segments):
|
||||
# assert len(sub_initials) == len(sub_finals) == len(word)
|
||||
initials = sum(initials, [])
|
||||
finals = sum(finals, [])
|
||||
print("pypinyin结果", initials, finals)
|
||||
if debug:
|
||||
print("pypinyin结果", initials, finals)
|
||||
else:
|
||||
# g2pw采用整句推理
|
||||
pinyins = g2pw.lazy_pinyin(seg, neutral_tone_with_five=True, style=Style.TONE3)
|
||||
|
||||
@ -13,7 +13,7 @@ A Powerful Few-shot Voice Conversion and Text-to-Speech WebUI.<br><br>
|
||||
[](https://github.com/RVC-Boss/gpt-sovits/releases)
|
||||
|
||||
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
|
||||
[](https://lj1995-gpt-sovits-proplus.hf.space/)
|
||||
[](https://lj1995-gpt-sovits-proplus.hf.space/)
|
||||
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
|
||||
|
||||
[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
|
||||
@ -57,7 +57,7 @@ Unseen speakers few-shot fine-tuning demo:
|
||||
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph |
|
||||
| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined |
|
||||
| Apple M4 | 0.16 | UNK | 1 | MLX Varlen |
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@
|
||||
[](https://github.com/RVC-Boss/gpt-sovits/releases)
|
||||
|
||||
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
|
||||
[](https://lj1995-gpt-sovits-proplus.hf.space/)
|
||||
[](https://lj1995-gpt-sovits-proplus.hf.space/)
|
||||
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
|
||||
|
||||
[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
|
||||
@ -57,7 +57,7 @@
|
||||
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph |
|
||||
| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined |
|
||||
| Apple M4 | 0.16 | UNK | 1 | MLX Varlen |
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@
|
||||
[](https://github.com/RVC-Boss/gpt-sovits/releases)
|
||||
|
||||
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
|
||||
[](https://lj1995-gpt-sovits-proplus.hf.space/)
|
||||
[](https://lj1995-gpt-sovits-proplus.hf.space/)
|
||||
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
|
||||
|
||||
[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
|
||||
@ -57,7 +57,7 @@
|
||||
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph |
|
||||
| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined |
|
||||
| Apple M4 | 0.16 | UNK | 1 | MLX Varlen |
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@
|
||||
[](https://github.com/RVC-Boss/gpt-sovits/releases)
|
||||
|
||||
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
|
||||
[](https://lj1995-gpt-sovits-proplus.hf.space/)
|
||||
[](https://lj1995-gpt-sovits-proplus.hf.space/)
|
||||
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
|
||||
|
||||
[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
|
||||
@ -57,7 +57,7 @@
|
||||
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph |
|
||||
| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined |
|
||||
| Apple M4 | 0.16 | UNK | 1 | MLX Varlen |
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ Güçlü Birkaç Örnekli Ses Dönüştürme ve Metinden Konuşmaya Web Arayüz
|
||||
[](https://github.com/RVC-Boss/gpt-sovits/releases)
|
||||
|
||||
[](https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb)
|
||||
[](https://lj1995-gpt-sovits-proplus.hf.space/)
|
||||
[](https://lj1995-gpt-sovits-proplus.hf.space/)
|
||||
[](https://hub.docker.com/r/xxxxrt666/gpt-sovits)
|
||||
|
||||
[](https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e)
|
||||
@ -57,7 +57,7 @@ Görünmeyen konuşmacılar birkaç örnekli ince ayar demosu:
|
||||
| RTX 4090 | 0.014 | UNK | 24 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4060 Ti | 0.07 | 460 ms | 1 | Flash Attn Varlen CUDAGraph |
|
||||
| RTX 4060 Ti | 0.028 | UNK | 28 | Flash Attn Varlen CUDAGraph |
|
||||
| Apple M4 | 0.21 | UNK | 1 | MLX Quantized Affined |
|
||||
| Apple M4 | 0.16 | UNK | 1 | MLX Varlen |
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
@ -225,7 +225,7 @@ switch ($Device) {
|
||||
Write-Warning "CUDA 12.8 Is Not Supported By Current Driver"
|
||||
}
|
||||
Write-Info "Installing PyTorch For CUDA 12.8..."
|
||||
Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cu128"
|
||||
Invoke-Pip torch torchao --index-url "https://download.pytorch.org/whl/cu128"
|
||||
Invoke-Conda cuda-nvcc=12.8
|
||||
Invoke-Pip psutil ninja packaging wheel "setuptools>=42"
|
||||
Write-Info "Installing Flash Attn..."
|
||||
@ -240,7 +240,7 @@ switch ($Device) {
|
||||
Write-Warning "CUDA 12.6 Is Not Supported By Current Driver"
|
||||
}
|
||||
Write-Info "Installing PyTorch For CUDA 12.6..."
|
||||
Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cu126"
|
||||
Invoke-Pip torch torchao --index-url "https://download.pytorch.org/whl/cu126"
|
||||
Invoke-Conda cuda-nvcc=12.6
|
||||
Invoke-Pip psutil ninja packaging wheel "setuptools>=42"
|
||||
Write-Info "Installing Flash Attn..."
|
||||
@ -249,7 +249,7 @@ switch ($Device) {
|
||||
}
|
||||
"CPU" {
|
||||
Write-Info "Installing PyTorch For CPU..."
|
||||
Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cpu"
|
||||
Invoke-Pip torch torchao --index-url "https://download.pytorch.org/whl/cpu"
|
||||
}
|
||||
}
|
||||
Write-Success "PyTorch Installed"
|
||||
|
||||
10
install.sh
10
install.sh
@ -334,14 +334,14 @@ if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then
|
||||
echo -r "${WARNING}CUDA 12.8 Is Not Supported By Current Driver"
|
||||
fi
|
||||
echo -e "${INFO}Installing PyTorch For CUDA 12.8..."
|
||||
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cu128"
|
||||
run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/cu128"
|
||||
run_conda_quiet cuda-nvcc=12.8
|
||||
elif [ "$CUDA" = 126 ]; then
|
||||
if awk "BEGIN {exit !($CUDAVERSION < 12.6)}"; then
|
||||
echo -r "${WARNING}CUDA 12.6 Is Not Supported By Current Driver"
|
||||
fi
|
||||
echo -e "${INFO}Installing PyTorch For CUDA 12.6..."
|
||||
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cu126"
|
||||
run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/cu126"
|
||||
run_conda_quiet cuda-nvcc=12.6
|
||||
fi
|
||||
echo -e "${INFO}Installing Flash Attn"
|
||||
@ -350,14 +350,14 @@ if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then
|
||||
echo -e "${SUCCESS}Flash Attn Installed"
|
||||
elif [ "$USE_MLX" = true ] && [ "$WORKFLOW" = false ]; then
|
||||
echo -e "${INFO}Installing MLX & PyTorch For MPS..."
|
||||
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cpu"
|
||||
run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/cpu"
|
||||
run_pip_quiet mlx
|
||||
elif [ "$USE_ROCM" = true ] && [ "$WORKFLOW" = false ]; then
|
||||
echo -e "${INFO}Installing PyTorch For ROCm 6.2..."
|
||||
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/rocm6.2"
|
||||
run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/rocm6.2"
|
||||
elif [ "$USE_CPU" = true ] && [ "$WORKFLOW" = false ]; then
|
||||
echo -e "${INFO}Installing PyTorch For CPU..."
|
||||
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cpu"
|
||||
run_pip_quiet torch torchao --index-url "https://download.pytorch.org/whl/cpu"
|
||||
elif [ "$WORKFLOW" = false ]; then
|
||||
echo -e "${ERROR}Unknown Err"
|
||||
exit 1
|
||||
|
||||
@ -13,6 +13,7 @@ peft
|
||||
py-cpuinfo
|
||||
pypinyin
|
||||
split-lang
|
||||
torchao
|
||||
torchaudio
|
||||
torchcodec
|
||||
transformers
|
||||
|
||||
879
test.py
Normal file
879
test.py
Normal file
@ -0,0 +1,879 @@
|
||||
import argparse
|
||||
import contextlib
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
import warnings
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from time import perf_counter as ttime
|
||||
from typing import Any
|
||||
|
||||
import gradio as gr
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
from config import (
|
||||
change_choices,
|
||||
get_dtype,
|
||||
get_weights_names,
|
||||
pretrained_sovits_name,
|
||||
)
|
||||
from config import (
|
||||
infer_device as default_device,
|
||||
)
|
||||
from GPT_SoVITS.Accelerate import MLX, PyTorch, T2SEngineProtocol, T2SRequest, backends
|
||||
from GPT_SoVITS.Accelerate.logger import console, timer
|
||||
from GPT_SoVITS.feature_extractor import cnhubert
|
||||
from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch
|
||||
from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3
|
||||
from GPT_SoVITS.process_ckpt import inspect_version
|
||||
from GPT_SoVITS.text import cleaned_text_to_sequence
|
||||
from GPT_SoVITS.text.cleaner import clean_text
|
||||
from GPT_SoVITS.text.LangSegmenter import LangSegmenter
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
from tools.my_utils import DictToAttrRecursive
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="MPS: The constant padding of more than 3 dimensions is not currently supported natively."
|
||||
)
|
||||
warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*")
|
||||
|
||||
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||
logging.getLogger("httpcore").setLevel(logging.ERROR)
|
||||
logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||
logging.getLogger("asyncio").setLevel(logging.ERROR)
|
||||
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
|
||||
logging.getLogger("multipart.multipart").setLevel(logging.ERROR)
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
|
||||
_LANG_RE = re.compile(r"^[a-z]{2}[_-][A-Z]{2}$")
|
||||
|
||||
|
||||
def lang_type(text: str) -> str:
|
||||
if text == "Auto":
|
||||
return text
|
||||
if not _LANG_RE.match(text):
|
||||
raise argparse.ArgumentTypeError(f"Unspported Format: {text}, Expected ll_CC/ll-CC")
|
||||
ll, cc = re.split(r"[_-]", text)
|
||||
language = f"{ll}_{cc}"
|
||||
if language in scan_language_list():
|
||||
return language
|
||||
else:
|
||||
return "Auto"
|
||||
|
||||
|
||||
def none_or_str(value: str):
|
||||
if value == "None":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(
|
||||
prog="inference_webui",
|
||||
description=f"python -s -m GPT_SoVITS.inference_webui zh_CN -b {backends[-1]}",
|
||||
)
|
||||
p.add_argument(
|
||||
"language",
|
||||
nargs="?",
|
||||
default="Auto",
|
||||
type=lang_type,
|
||||
help="Language Code, Such as zh_CN, en-US",
|
||||
)
|
||||
p.add_argument(
|
||||
"--backends",
|
||||
"-b",
|
||||
choices=backends,
|
||||
default=backends[-1],
|
||||
help="AR Inference Backend",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--quantization",
|
||||
"-q",
|
||||
default="None",
|
||||
choices=MLX.quantization_methods_mlx + PyTorch.quantization_methods_torch,
|
||||
type=none_or_str,
|
||||
help="Quantization Method",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--device",
|
||||
"-d",
|
||||
default=str(default_device),
|
||||
help="Inference Device",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--port",
|
||||
"-p",
|
||||
default=9872,
|
||||
type=int,
|
||||
help="WebUI Binding Port",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--share",
|
||||
"-s",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Gradio Share Link",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--cnhubert",
|
||||
default="GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
||||
help="CNHuBERT Pretrain",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--bert",
|
||||
default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
help="BERT Pretrain",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--gpt",
|
||||
default="",
|
||||
help="GPT Model",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--sovits",
|
||||
default="",
|
||||
help="SoVITS Model",
|
||||
required=False,
|
||||
)
|
||||
|
||||
return p
|
||||
|
||||
|
||||
args = build_parser().parse_args()
|
||||
|
||||
hps: Any = None
|
||||
vq_model: SynthesizerTrn | SynthesizerTrnV3 | None = None
|
||||
t2s_engine: T2SEngineProtocol | None = None
|
||||
|
||||
version = model_version = "v2"
|
||||
path_sovits_v3 = pretrained_sovits_name["v3"]
|
||||
path_sovits_v4 = pretrained_sovits_name["v4"]
|
||||
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
||||
is_exist_s2gv4 = os.path.exists(path_sovits_v4)
|
||||
|
||||
cnhubert_base_path = str(args.cnhubert)
|
||||
bert_path = str(args.bert)
|
||||
infer_ttswebui = int(args.port)
|
||||
is_share = bool(args.share)
|
||||
|
||||
|
||||
i18n = I18nAuto(language=args.language)
|
||||
ar_backend: str = args.backends
|
||||
change_choices_i18n = partial(change_choices, i18n=i18n)
|
||||
|
||||
SoVITS_names, GPT_names = get_weights_names(i18n)
|
||||
|
||||
|
||||
dict_language_v1 = {
|
||||
i18n("中文"): "all_zh", # 全部按中文识别
|
||||
i18n("英文"): "en", # 全部按英文识别
|
||||
i18n("日文"): "all_ja", # 全部按日文识别
|
||||
i18n("中英混合"): "zh", # 按中英混合识别
|
||||
i18n("日英混合"): "ja", # 按日英混合识别
|
||||
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
||||
}
|
||||
dict_language_v2 = {
|
||||
i18n("中文"): "all_zh", # 全部按中文识别
|
||||
i18n("英文"): "en", # 全部按英文识别
|
||||
i18n("日文"): "all_ja", # 全部按日文识别
|
||||
i18n("粤语"): "all_yue", # 全部按粤语识别
|
||||
i18n("韩文"): "all_ko", # 全部按韩文识别
|
||||
i18n("中英混合"): "zh",
|
||||
i18n("日英混合"): "ja",
|
||||
i18n("粤英混合"): "yue",
|
||||
i18n("韩英混合"): "ko",
|
||||
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
||||
i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
|
||||
}
|
||||
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
|
||||
|
||||
punctuation = set(["!", "?", "…", ",", ".", "-", " "])
|
||||
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"}
|
||||
v3v4set = {"v3", "v4"}
|
||||
|
||||
infer_device = torch.device(args.device)
|
||||
device = infer_device if infer_device.type == "cuda" else torch.device("cpu")
|
||||
|
||||
dtype = get_dtype(device.index)
|
||||
is_half = dtype == torch.float16
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path).to(infer_device, dtype)
|
||||
|
||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||
ssl_model = cnhubert.get_model().to(infer_device, dtype)
|
||||
|
||||
|
||||
def mel_fn(x):
|
||||
return mel_spectrogram_torch(
|
||||
y=x,
|
||||
n_fft=1024,
|
||||
num_mels=100,
|
||||
sampling_rate=24000,
|
||||
hop_size=256,
|
||||
win_size=1024,
|
||||
fmin=0,
|
||||
fmax=None,
|
||||
center=False,
|
||||
)
|
||||
|
||||
|
||||
gpt_path = str(args.gpt) or GPT_names[0][-1]
|
||||
sovits_path = str(args.sovits) or SoVITS_names[0][-1]
|
||||
|
||||
|
||||
def get_bert_feature(text, word2ph):
|
||||
inputs = tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(infer_device)
|
||||
res = bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
|
||||
assert len(word2ph) == len(text)
|
||||
phone_level_feature = []
|
||||
for i in range(len(word2ph)):
|
||||
repeat_feature = res[i].repeat(word2ph[i], 1)
|
||||
phone_level_feature.append(repeat_feature)
|
||||
phone_level_feature_t = torch.cat(phone_level_feature, dim=0)
|
||||
return phone_level_feature_t.T
|
||||
|
||||
|
||||
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
|
||||
global vq_model, hps, version, model_version, dict_language
|
||||
model_version, version, is_lora, hps, dict_s2 = inspect_version(sovits_path)
|
||||
is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
|
||||
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
|
||||
if is_lora is True and is_exist is False:
|
||||
info = f"{path_sovits} SoVITS {model_version} {i18n('底模缺失,无法加载相应 LoRA 权重')}"
|
||||
gr.Warning(info)
|
||||
raise FileNotFoundError(info)
|
||||
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
|
||||
visible_sample_steps = visible_inp_refs = None
|
||||
if prompt_language is not None and text_language is not None:
|
||||
if prompt_language in list(dict_language.keys()):
|
||||
prompt_text_update, prompt_language_update = gr.skip(), gr.update(choices=list(dict_language.keys()))
|
||||
else:
|
||||
prompt_text_update = gr.update(value="")
|
||||
prompt_language_update = gr.update(value=i18n("中文"), choices=list(dict_language.keys()))
|
||||
if text_language in list(dict_language.keys()):
|
||||
text_update, text_language_update = gr.skip(), gr.skip()
|
||||
else:
|
||||
text_update = gr.update(value="")
|
||||
text_language_update = gr.update(value=i18n("中文"), choices=list(dict_language.keys()))
|
||||
|
||||
if model_version in v3v4set:
|
||||
visible_sample_steps = True
|
||||
visible_inp_refs = False
|
||||
else:
|
||||
visible_sample_steps = False
|
||||
visible_inp_refs = True
|
||||
yield (
|
||||
prompt_text_update,
|
||||
prompt_language_update,
|
||||
text_update,
|
||||
text_language_update,
|
||||
gr.update(
|
||||
visible=visible_sample_steps,
|
||||
value=32 if model_version == "v3" else 8,
|
||||
choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
|
||||
),
|
||||
gr.update(visible=visible_inp_refs),
|
||||
gr.update(value=False, interactive=True if model_version not in v3v4set else False),
|
||||
gr.update(visible=True if model_version == "v3" else False),
|
||||
gr.update(value=i18n("模型加载中,请等待"), interactive=False),
|
||||
)
|
||||
|
||||
hps = DictToAttrRecursive(hps)
|
||||
hps.model.semantic_frame_rate = "25hz"
|
||||
hps.model.version = model_version
|
||||
if model_version not in v3v4set:
|
||||
vq_model = SynthesizerTrn(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model,
|
||||
)
|
||||
else:
|
||||
vq_model = SynthesizerTrnV3(
|
||||
hps.data.filter_length // 2 + 1,
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
**hps.model,
|
||||
).eval()
|
||||
|
||||
if "pretrained" not in sovits_path:
|
||||
if hasattr(vq_model, "enc_q"):
|
||||
del vq_model.enc_q
|
||||
|
||||
vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
|
||||
vq_model = vq_model.to(infer_device, dtype)
|
||||
|
||||
yield (
|
||||
gr.skip(),
|
||||
gr.skip(),
|
||||
gr.skip(),
|
||||
gr.skip(),
|
||||
gr.skip(),
|
||||
gr.skip(),
|
||||
gr.skip(),
|
||||
gr.skip(),
|
||||
gr.update(value=i18n("合成语音"), interactive=True),
|
||||
)
|
||||
|
||||
|
||||
with contextlib.suppress(UnboundLocalError):
|
||||
next(change_sovits_weights(sovits_path))
|
||||
|
||||
|
||||
def change_gpt_weights(gpt_path):
|
||||
global t2s_engine, config
|
||||
if "mlx" in ar_backend.lower():
|
||||
t2s_engine = MLX.T2SEngineMLX(
|
||||
MLX.T2SEngineMLX.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization),
|
||||
"mx.gpu" if infer_device.type != "cpu" else "mx.cpu",
|
||||
dtype=dtype,
|
||||
)
|
||||
# t2s_engine.decoder_model.compile()
|
||||
else:
|
||||
t2s_engine = PyTorch.T2SEngineTorch(
|
||||
PyTorch.T2SEngineTorch.load_decoder(Path(gpt_path), backend=ar_backend, quantize_mode=args.quantization),
|
||||
device,
|
||||
dtype=dtype,
|
||||
)
|
||||
# t2s_engine.decoder_model.compile()
|
||||
|
||||
|
||||
change_gpt_weights(gpt_path)
|
||||
|
||||
resample_transform_dict = {}
|
||||
|
||||
|
||||
def resample(audio_tensor, sr0, sr1, device):
|
||||
global resample_transform_dict
|
||||
key = f"{sr0}-{sr1}-{device}"
|
||||
if key not in resample_transform_dict:
|
||||
resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
|
||||
return resample_transform_dict[key](audio_tensor)
|
||||
|
||||
|
||||
def get_spepc(hps, filename, dtype, device, is_v2pro=False):
|
||||
sr1 = int(hps.data.sampling_rate)
|
||||
audio, sr0 = torchaudio.load_with_torchcodec(filename)
|
||||
audio = audio.to(device)
|
||||
|
||||
if sr0 != sr1:
|
||||
audio = resample(audio, sr0, sr1, device)
|
||||
if audio.shape[0] > 1:
|
||||
audio = audio.mean(0).unsqueeze(0)
|
||||
|
||||
maxx = float(audio.abs().max())
|
||||
if maxx > 1:
|
||||
audio /= min(2, maxx)
|
||||
spec = spectrogram_torch(
|
||||
audio,
|
||||
hps.data.filter_length,
|
||||
hps.data.sampling_rate,
|
||||
hps.data.hop_length,
|
||||
hps.data.win_length,
|
||||
center=False,
|
||||
)
|
||||
spec = spec.to(dtype)
|
||||
if is_v2pro is True:
|
||||
audio = resample(audio, sr1, 16000, device).to(dtype)
|
||||
return spec, audio
|
||||
|
||||
|
||||
def clean_text_inf(text, language, version):
|
||||
language = language.replace("all_", "")
|
||||
phones, word2ph, norm_text = clean_text(text, language, version)
|
||||
phones = cleaned_text_to_sequence(phones, version)
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
|
||||
def get_bert_inf(phones, word2ph, norm_text, language):
|
||||
language = language.replace("all_", "")
|
||||
if language == "zh":
|
||||
bert = get_bert_feature(norm_text, word2ph).to(device) # .to(dtype)
|
||||
else:
|
||||
bert = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
dtype=torch.float16 if is_half is True else torch.float32,
|
||||
).to(device)
|
||||
|
||||
return bert
|
||||
|
||||
|
||||
def get_first(text):
|
||||
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
|
||||
text = re.split(pattern, text)[0].strip()
|
||||
return text
|
||||
|
||||
|
||||
def get_phones_and_bert(text, language, version, final=False):
|
||||
text = re.sub(r" {2,}", " ", text)
|
||||
textlist = []
|
||||
langlist = []
|
||||
if language == "all_zh":
|
||||
for tmp in LangSegmenter.getTexts(text, "zh"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_yue":
|
||||
for tmp in LangSegmenter.getTexts(text, "zh"):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ja":
|
||||
for tmp in LangSegmenter.getTexts(text, "ja"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ko":
|
||||
for tmp in LangSegmenter.getTexts(text, "ko"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "en":
|
||||
langlist.append("en")
|
||||
textlist.append(text)
|
||||
elif language == "auto":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if langlist:
|
||||
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
# 因无法区别中日韩文汉字,以用户输入为准
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
phones_list = []
|
||||
bert_list = []
|
||||
norm_text_list = []
|
||||
for i in range(len(textlist)):
|
||||
lang = langlist[i]
|
||||
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
|
||||
bert = get_bert_inf(phones, word2ph, norm_text, lang)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
phones = sum(phones_list, [])
|
||||
norm_text = "".join(norm_text_list)
|
||||
|
||||
if not final and len(phones) < 6:
|
||||
return get_phones_and_bert("." + text, language, version, final=True)
|
||||
|
||||
return phones, bert.to(dtype), norm_text
|
||||
|
||||
|
||||
def merge_short_text_in_array(texts, threshold):
|
||||
if (len(texts)) < 2:
|
||||
return texts
|
||||
result = []
|
||||
text = ""
|
||||
for ele in texts:
|
||||
text += ele
|
||||
if len(text) >= threshold:
|
||||
result.append(text)
|
||||
text = ""
|
||||
if len(text) > 0:
|
||||
if len(result) == 0:
|
||||
result.append(text)
|
||||
else:
|
||||
result[len(result) - 1] += text
|
||||
return result
|
||||
|
||||
|
||||
sr_model = None
|
||||
|
||||
|
||||
def audio_sr(audio, sr):
|
||||
global sr_model
|
||||
if sr_model is None:
|
||||
from tools.audio_sr import AP_BWE
|
||||
|
||||
try:
|
||||
sr_model = AP_BWE(infer_device, DictToAttrRecursive)
|
||||
except FileNotFoundError:
|
||||
gr.Warning(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
|
||||
return audio.cpu().numpy(), sr
|
||||
return sr_model(audio, sr)
|
||||
|
||||
|
||||
cache: dict[int, Any] = {}
|
||||
|
||||
|
||||
def get_tts_wav(
|
||||
ref_wav_path,
|
||||
prompt_text,
|
||||
prompt_language,
|
||||
text,
|
||||
text_language,
|
||||
how_to_cut=i18n("不切"),
|
||||
top_k=20,
|
||||
top_p=0.6,
|
||||
temperature=0.6,
|
||||
ref_free=False,
|
||||
speed=1,
|
||||
if_freeze=False,
|
||||
inp_refs=None,
|
||||
sample_steps=8,
|
||||
if_sr=False,
|
||||
pause_second=0.3,
|
||||
):
|
||||
torch.set_grad_enabled(False)
|
||||
debug = os.getenv("DEBUG") == "1"
|
||||
ttfb_time = ttime()
|
||||
|
||||
if ref_wav_path:
|
||||
pass
|
||||
else:
|
||||
gr.Warning(i18n("请上传参考音频"))
|
||||
if text:
|
||||
pass
|
||||
else:
|
||||
gr.Warning(i18n("请填入推理文本"))
|
||||
t = []
|
||||
if prompt_text is None or len(prompt_text) == 0:
|
||||
ref_free = True
|
||||
if model_version in v3v4set:
|
||||
ref_free = False # s2v3暂不支持ref_free
|
||||
t0 = ttime()
|
||||
prompt_language = dict_language[prompt_language]
|
||||
text_language = dict_language[text_language]
|
||||
|
||||
if not ref_free:
|
||||
prompt_text = prompt_text.strip("\n")
|
||||
if prompt_text[-1] not in splits:
|
||||
prompt_text += "。" if prompt_language != "en" else "."
|
||||
text = text.strip("\n")
|
||||
|
||||
zero_wav = np.zeros(
|
||||
int(hps.data.sampling_rate * pause_second),
|
||||
dtype=np.float16 if is_half is True else np.float32,
|
||||
)
|
||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||
if is_half is True:
|
||||
zero_wav_torch = zero_wav_torch.half().to(infer_device)
|
||||
else:
|
||||
zero_wav_torch = zero_wav_torch.to(infer_device)
|
||||
if not ref_free:
|
||||
assert vq_model
|
||||
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
||||
if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
|
||||
gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
|
||||
raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
|
||||
wav16k_t = torch.from_numpy(wav16k)
|
||||
if is_half is True:
|
||||
wav16k_t = wav16k_t.half().to(infer_device)
|
||||
else:
|
||||
wav16k_t = wav16k_t.to(infer_device)
|
||||
wav16k_t = torch.cat([wav16k_t, zero_wav_torch])
|
||||
ssl_content = ssl_model.model(wav16k_t.unsqueeze(0))["last_hidden_state"].transpose(1, 2) # .float()
|
||||
codes = vq_model.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompt = prompt_semantic.unsqueeze(0).to(device)
|
||||
else:
|
||||
prompt = torch.zeros((1, 0)).to(device, torch.int32)
|
||||
|
||||
t1 = ttime()
|
||||
t.append(t1 - t0)
|
||||
|
||||
if how_to_cut == i18n("凑四句一切"):
|
||||
text = cut1(text)
|
||||
elif how_to_cut == i18n("凑50字一切"):
|
||||
text = cut2(text)
|
||||
elif how_to_cut == i18n("按中文句号。切"):
|
||||
text = cut3(text)
|
||||
elif how_to_cut == i18n("按英文句号.切"):
|
||||
text = cut4(text)
|
||||
elif how_to_cut == i18n("按标点符号切"):
|
||||
text = cut5(text)
|
||||
while "\n\n" in text:
|
||||
text = text.replace("\n\n", "\n")
|
||||
texts = text.split("\n")
|
||||
texts = merge_short_text_in_array(texts, 5)
|
||||
audio_opt = []
|
||||
# s2v3暂不支持ref_free
|
||||
if not ref_free:
|
||||
phones1, bert1, _ = get_phones_and_bert(prompt_text, prompt_language, version)
|
||||
else:
|
||||
phones1, bert1 = [], torch.zeros(1024, 0).to(device, dtype)
|
||||
|
||||
infer_len: list[int] = []
|
||||
infer_time: list[float] = []
|
||||
assert vq_model
|
||||
|
||||
for i_text, text in enumerate(texts):
|
||||
if len(text.strip()) == 0:
|
||||
continue
|
||||
if text[-1] not in splits:
|
||||
text += "。" if text_language != "en" else "."
|
||||
phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
|
||||
|
||||
bert = torch.cat([bert1, bert2], 1)
|
||||
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
||||
|
||||
bert = bert.to(device).unsqueeze(0)
|
||||
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
||||
|
||||
t2 = ttime()
|
||||
if i_text in cache and if_freeze is True:
|
||||
pred_semantic = cache[i_text]
|
||||
else:
|
||||
t2s_request = T2SRequest(
|
||||
[all_phoneme_ids.squeeze(0)],
|
||||
all_phoneme_len,
|
||||
prompt,
|
||||
[bert.squeeze(0)],
|
||||
valid_length=1,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
early_stop_num=1500,
|
||||
use_cuda_graph=torch.cuda.is_available(), # Try to use CUDA Graph for all backend, fallback to normal if not applicapble
|
||||
debug=debug,
|
||||
)
|
||||
assert t2s_engine
|
||||
t2s_result = t2s_engine.generate(t2s_request)
|
||||
if t2s_result.exception is not None:
|
||||
console.print(t2s_result.traceback)
|
||||
raise RuntimeError()
|
||||
pred_semantic_list = t2s_result.result
|
||||
assert pred_semantic_list, t2s_result.traceback
|
||||
pred_semantic = pred_semantic_list[0].unsqueeze(0).to(infer_device)
|
||||
infer_len.append(t2s_result.total_tokens)
|
||||
infer_time.append(t2s_result.infer_speed[-1])
|
||||
|
||||
cache[i_text] = pred_semantic
|
||||
t3 = ttime()
|
||||
is_v2pro = model_version in {"v2Pro", "v2ProPlus"}
|
||||
|
||||
refers = []
|
||||
if inp_refs:
|
||||
for path in inp_refs:
|
||||
try: # 这里加上提取sv的逻辑,要么一堆sv一堆refer,要么单个sv单个refer
|
||||
refer, audio_tensor = get_spepc(hps, path.name, dtype, infer_device, is_v2pro)
|
||||
refers.append(refer)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
traceback.print_exc()
|
||||
if len(refers) == 0:
|
||||
refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, infer_device, is_v2pro)
|
||||
refers = [refers]
|
||||
audio = vq_model.decode(
|
||||
pred_semantic,
|
||||
torch.LongTensor(phones2).to(infer_device).unsqueeze(0),
|
||||
refers,
|
||||
speed=speed,
|
||||
)[0][0] # type: ignore
|
||||
|
||||
if i_text == 0:
|
||||
ttfb_time = ttime() - ttfb_time
|
||||
max_audio = torch.abs(audio).max() # 简单防止16bit爆音
|
||||
if max_audio > 1:
|
||||
audio = audio / max_audio
|
||||
audio_opt.append(audio)
|
||||
audio_opt.append(zero_wav_torch) # zero_wav
|
||||
t4 = ttime()
|
||||
t.extend([t2 - t1, t3 - t2, t4 - t3])
|
||||
t1 = ttime()
|
||||
|
||||
audio_opt_t = torch.cat(audio_opt, 0) # np.concatenate
|
||||
if model_version in {"v1", "v2", "v2Pro", "v2ProPlus"}:
|
||||
opt_sr = 32000
|
||||
elif model_version == "v3":
|
||||
opt_sr = 24000
|
||||
else:
|
||||
opt_sr = 48000 # v4
|
||||
audio_opt_n = audio_opt_t.cpu().numpy()
|
||||
|
||||
t0 = t[0]
|
||||
t1 = sum(t[1::3])
|
||||
t2 = sum(t[2::3])
|
||||
t3 = sum(t[3::3])
|
||||
|
||||
infer_speed_avg = sum(infer_len) / sum(infer_time)
|
||||
rtf_value = sum(t) / (audio_opt_n.__len__() / opt_sr)
|
||||
|
||||
console.print(f">> Time Stamps: {t0:.3f}\t{t1:.3f}\t{t2:.3f}\t{t3:.3f}")
|
||||
console.print(f">> Infer Speed: {infer_speed_avg:.2f} Token/s")
|
||||
console.print(f">> RTF: {rtf_value:.2f}")
|
||||
|
||||
if ttfb_time > 2:
|
||||
console.print(f">> TTFB: {ttfb_time:.3f} s")
|
||||
else:
|
||||
console.print(f">> TTFB: {ttfb_time * 1000:.3f} ms")
|
||||
|
||||
yield opt_sr, (audio_opt_n * 32767).astype(np.int16)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
def split(todo_text):
|
||||
todo_text = todo_text.replace("……", "。").replace("——", ",")
|
||||
if todo_text[-1] not in splits:
|
||||
todo_text += "。"
|
||||
i_split_head = i_split_tail = 0
|
||||
len_text = len(todo_text)
|
||||
todo_texts = []
|
||||
while 1:
|
||||
if i_split_head >= len_text:
|
||||
break # 结尾一定有标点,所以直接跳出即可,最后一段在上次已加入
|
||||
if todo_text[i_split_head] in splits:
|
||||
i_split_head += 1
|
||||
todo_texts.append(todo_text[i_split_tail:i_split_head])
|
||||
i_split_tail = i_split_head
|
||||
else:
|
||||
i_split_head += 1
|
||||
return todo_texts
|
||||
|
||||
|
||||
def cut1(inp):
|
||||
inp = inp.strip("\n")
|
||||
inps = split(inp)
|
||||
split_idx: list[int | None] = list(range(0, len(inps) + 1, 4))
|
||||
split_idx[-1] = None
|
||||
if len(split_idx) > 1:
|
||||
opts = []
|
||||
for idx in range(len(split_idx) - 1):
|
||||
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
|
||||
else:
|
||||
opts = [inp]
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opts)
|
||||
|
||||
|
||||
def cut2(inp):
|
||||
inp = inp.strip("\n")
|
||||
inps = split(inp)
|
||||
if len(inps) < 2:
|
||||
return inp
|
||||
opts = []
|
||||
summ = 0
|
||||
tmp_str = ""
|
||||
for i in range(len(inps)):
|
||||
summ += len(inps[i])
|
||||
tmp_str += inps[i]
|
||||
if summ > 50:
|
||||
summ = 0
|
||||
opts.append(tmp_str)
|
||||
tmp_str = ""
|
||||
if tmp_str != "":
|
||||
opts.append(tmp_str)
|
||||
if len(opts) > 1 and len(opts[-1]) < 50: # 如果最后一个太短了,和前一个合一起
|
||||
opts[-2] = opts[-2] + opts[-1]
|
||||
opts = opts[:-1]
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opts)
|
||||
|
||||
|
||||
def cut3(inp):
|
||||
inp = inp.strip("\n")
|
||||
opts = inp.strip("。").split("。")
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opts)
|
||||
|
||||
|
||||
def cut4(inp):
|
||||
inp = inp.strip("\n")
|
||||
opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opts)
|
||||
|
||||
|
||||
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
|
||||
def cut5(inp):
|
||||
inp = inp.strip("\n")
|
||||
punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
|
||||
mergeitems = []
|
||||
items = []
|
||||
|
||||
for i, char in enumerate(inp):
|
||||
if char in punds:
|
||||
if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
|
||||
items.append(char)
|
||||
else:
|
||||
items.append(char)
|
||||
mergeitems.append("".join(items))
|
||||
items = []
|
||||
else:
|
||||
items.append(char)
|
||||
|
||||
if items:
|
||||
mergeitems.append("".join(items))
|
||||
|
||||
opt = [item for item in mergeitems if not set(item).issubset(punds)]
|
||||
return "\n".join(opt)
|
||||
|
||||
|
||||
a = get_tts_wav(
|
||||
"/Users/XXXXRT/Desktop/参考/不过呢因为有些特殊情况,所以我在一年半之前并没有这个,退网啊.wav",
|
||||
"不过呢因为有些特殊情况,所以我在一年半之前并没有这个,退网啊",
|
||||
i18n("中文"),
|
||||
"我在我青春韶华的时候遇到了你,还记得刚刚开学的时候,那是第一次见你,我和我朋友在楼道间打闹的时候无意间瞟到了你正在学习时的侧颜",
|
||||
i18n("中文"),
|
||||
)
|
||||
|
||||
next(a)
|
||||
|
||||
timer.clear()
|
||||
|
||||
a = get_tts_wav(
|
||||
"/Users/XXXXRT/Desktop/参考/Cream去能理解很多人的想法时,既然已经被这样想了,没有挽回的余地了.wav",
|
||||
"去能理解很多人的想法时,既然已经被这样想了,没有挽回的余地了",
|
||||
i18n("中文"),
|
||||
"我在我青春韶华的时候遇到了你,还记得刚刚开学的时候,那是第一次见你,我和我朋友在楼道间打闹的时候无意间瞟到了你正在学习时的侧颜",
|
||||
i18n("中文"),
|
||||
)
|
||||
|
||||
|
||||
next(a)
|
||||
|
||||
timer.summary()
|
||||
|
||||
a = get_tts_wav(
|
||||
"/Users/XXXXRT/Desktop/参考/不过呢因为有些特殊情况,所以我在一年半之前并没有这个,退网啊.wav",
|
||||
"不过呢因为有些特殊情况,所以我在一年半之前并没有这个,退网啊",
|
||||
i18n("中文"),
|
||||
"我在我青春韶华的时候遇到了你,还记得刚刚开学的时候,那是第一次见你,我和我朋友在楼道间打闹的时候无意间瞟到了你正在学习时的侧颜",
|
||||
i18n("中文"),
|
||||
)
|
||||
|
||||
|
||||
next(a)
|
||||
|
||||
timer.summary()
|
||||
|
||||
print("-" * 15 + "test2" + "-" * 15)
|
||||
51
webui.py
51
webui.py
@ -8,6 +8,7 @@ import traceback
|
||||
from functools import partial
|
||||
from multiprocessing import cpu_count
|
||||
from subprocess import Popen
|
||||
from typing import cast
|
||||
|
||||
import gradio as gr
|
||||
import psutil
|
||||
@ -37,7 +38,15 @@ from config import (
|
||||
webui_port_subfix,
|
||||
webui_port_uvr5,
|
||||
)
|
||||
from GPT_SoVITS.Accelerate import backends, console, logger
|
||||
from GPT_SoVITS.Accelerate import (
|
||||
MLX,
|
||||
PyTorch,
|
||||
backends,
|
||||
console,
|
||||
logger,
|
||||
quantization_methods_mlx,
|
||||
quantization_methods_torch,
|
||||
)
|
||||
from tools import my_utils
|
||||
from tools.asr.config import asr_dict
|
||||
from tools.assets import css, js, top_html
|
||||
@ -310,8 +319,8 @@ def change_tts_inference(
|
||||
sovits_path: str,
|
||||
batched_infer_enabled: bool,
|
||||
backends_dropdown: str,
|
||||
quantization_methods_dropdown: str | None,
|
||||
):
|
||||
console.print(gpt_path, sovits_path)
|
||||
global p_tts_inference
|
||||
env = os.environ.copy()
|
||||
cmd: list[str] = [python_exec, "-s", "-m"]
|
||||
@ -334,6 +343,7 @@ def change_tts_inference(
|
||||
"-b", backends_dropdown,
|
||||
"-d", f"{infer_device.type}:{gpu_number}",
|
||||
"-p", str(webui_port_infer_tts),
|
||||
"-q", str(quantization_methods_dropdown),
|
||||
"--gpt", gpt_path,
|
||||
"--sovits", sovits_path,
|
||||
]
|
||||
@ -344,7 +354,6 @@ def change_tts_inference(
|
||||
|
||||
if p_tts_inference is None:
|
||||
yield (
|
||||
process_info(process_name_tts, "opened"),
|
||||
gr.update(visible=False),
|
||||
gr.update(visible=True),
|
||||
)
|
||||
@ -354,7 +363,6 @@ def change_tts_inference(
|
||||
kill_process(p_tts_inference.pid, process_name_tts)
|
||||
p_tts_inference = None
|
||||
yield (
|
||||
process_info(process_name_tts, "closed"),
|
||||
gr.update(visible=True),
|
||||
gr.update(visible=False),
|
||||
)
|
||||
@ -1280,6 +1288,21 @@ def changeBackend(flag: bool):
|
||||
return gr.update(choices=backends_gradio, value=backends_gradio[-1][-1])
|
||||
|
||||
|
||||
def changeQuantization(backend: str, gradio_call=True):
|
||||
backend = backend.lower().replace("-", "_")
|
||||
if backend in MLX.backends:
|
||||
choices = quantization_methods_mlx
|
||||
elif backend in PyTorch.backends:
|
||||
choices = quantization_methods_torch
|
||||
else:
|
||||
choices = [None]
|
||||
|
||||
if gradio_call:
|
||||
return gr.update(choices=choices, value=None)
|
||||
else:
|
||||
return choices
|
||||
|
||||
|
||||
GPU_INDEX.add(0)
|
||||
GPU_INDEX_LIST = list(GPU_INDEX)
|
||||
GPU_INDEX_LIST.sort()
|
||||
@ -1891,7 +1914,13 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
tts_info = gr.Textbox(label=process_info(process_name_tts, "info"))
|
||||
with gr.Column():
|
||||
quantization_methods_dropdown = gr.Dropdown(
|
||||
choices=cast(list, changeQuantization(backends_gradio[-1][-1], gradio_call=False)),
|
||||
label=i18n("量化方法"),
|
||||
value=None,
|
||||
interactive=True,
|
||||
)
|
||||
open_tts = gr.Button(
|
||||
value=process_info(process_name_tts, "open"), variant="primary", visible=True
|
||||
)
|
||||
@ -1904,6 +1933,12 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
[batched_infer_enabled],
|
||||
[backends_dropdown],
|
||||
)
|
||||
backends_dropdown.change(
|
||||
changeQuantization,
|
||||
[backends_dropdown],
|
||||
[quantization_methods_dropdown],
|
||||
)
|
||||
|
||||
open_tts.click(
|
||||
change_tts_inference,
|
||||
[
|
||||
@ -1912,8 +1947,9 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
SoVITS_dropdown,
|
||||
batched_infer_enabled,
|
||||
backends_dropdown,
|
||||
quantization_methods_dropdown,
|
||||
],
|
||||
[tts_info, open_tts, close_tts],
|
||||
[open_tts, close_tts],
|
||||
)
|
||||
close_tts.click(
|
||||
change_tts_inference,
|
||||
@ -1923,8 +1959,9 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
SoVITS_dropdown,
|
||||
batched_infer_enabled,
|
||||
backends_dropdown,
|
||||
quantization_methods_dropdown,
|
||||
],
|
||||
[tts_info, open_tts, close_tts],
|
||||
[open_tts, close_tts],
|
||||
)
|
||||
button1Ba_open.click(
|
||||
open1Ba,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user