This commit is contained in:
XXXXRT666 2025-08-16 18:34:35 +08:00
parent fdf794e31d
commit 4cf4dd7236
41 changed files with 4387 additions and 758 deletions

View File

@ -115,12 +115,17 @@ Remove-Item $ffDir.FullName -Recurse -Force
Write-Host "[INFO] Installing PyTorch..."
& ".\runtime\python.exe" -m ensurepip
& ".\runtime\python.exe" -m pip install --upgrade pip --no-warn-script-location
switch ($cuda) {
"cu124" {
& ".\runtime\python.exe" -m pip install torch==2.6 torchaudio --index-url https://download.pytorch.org/whl/cu124 --no-warn-script-location
& ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location
& ".\runtime\python.exe" -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu124 --no-warn-script-location
& ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
}
"cu128" {
& ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location
& ".\runtime\python.exe" -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128 --no-warn-script-location
& ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
}
default {
Write-Error "Unsupported CUDA version: $cuda"

View File

@ -31,6 +31,15 @@ jobs:
- name: Checkout
uses: actions/checkout@v4
- name: Install Windows CUDA 12.9
if: ${{ runner.os == 'Windows' && matrix.torch_cuda == '12.8' }}
uses: Jimver/cuda-toolkit
id: cuda-toolkit-win-129
with:
cuda: 12.9.1
method: "network"
sub-packages: '["nvcc", "cudart", "visual_studio_integration"]'
- name: Run Build and Upload Script
shell: pwsh
run: |

View File

@ -23,8 +23,10 @@ fi
if [ "$TARGETPLATFORM" = "linux/amd64" ]; then
"${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-x86_64.sh
SYSROOT_PKG="sysroot_linux-64>=2.28"
elif [ "$TARGETPLATFORM" = "linux/arm64" ]; then
"${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-aarch64.sh
SYSROOT_PKG="sysroot_linux-aarch64>=2.28"
else
exit 1
fi
@ -45,20 +47,36 @@ rm miniconda.sh
source "$HOME/miniconda3/etc/profile.d/conda.sh"
"$HOME/miniconda3/bin/conda" init bash
source "$HOME/.bashrc"
"$HOME/miniconda3/bin/conda" config --add channels conda-forge
"$HOME/miniconda3/bin/conda" update -q --all -y 1>/dev/null
"$HOME/miniconda3/bin/conda" install python=3.11 -q -y
"$HOME/miniconda3/bin/conda" install gcc=14 gxx ffmpeg cmake make unzip -q -y
"$HOME/miniconda3/bin/conda" install gcc=11 gxx ffmpeg cmake make unzip $SYSROOT_PKG "libstdcxx-ng>=11" -q -y
if [ "$CUDA_VERSION" = "12.8" ]; then
"$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu128
"$HOME/miniconda3/bin/conda" install cuda-nvcc=12.8 -c nvidia
elif [ "$CUDA_VERSION" = "12.6" ]; then
"$HOME/miniconda3/bin/pip" install torch==2.6 torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
"$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
"$HOME/miniconda3/bin/conda" install cuda-nvcc=12.6 -c nvidia
fi
CUDA_PATH=$(echo "$HOME/miniconda3/targets/"*-linux | awk '{print $1}')
export CUDA_HOME=$CUDA_PATH
export PATH="$HOME/miniconda3/bin:$PATH"
export PATH="$CUDA_HOME/bin:$PATH"
export PATH="$CUDA_HOME/nvvm/bin:$PATH"
"$HOME/miniconda3/bin/pip" install psutil ninja packaging wheel "setuptools>=42"
"$HOME/miniconda3/bin/pip" install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
"$HOME/miniconda3/bin/pip" cache purge
rm $LOG_PATH

View File

@ -0,0 +1,11 @@
import importlib.util
if importlib.util.find_spec("mlx") is not None:
from .sample_funcs_mlx import sample_naive as sample_naive_mlx
from .t2s_engine_mlx import T2SEngine as T2SEngineMLX
backends = ["MLX"]
else:
backends = []
__all__ = ["T2SEngineMLX", "sample_naive_mlx", "backends"]

View File

@ -0,0 +1,64 @@
from functools import partial
from typing import Protocol, cast
import mlx.core as mx
Array = mx.array
class SampleProtocolMLX(Protocol):
@staticmethod
def __call__(
logits: Array,
previous_tokens: Array,
temperature: float,
top_k: int,
top_p: float,
repetition_penalty: float,
) -> Array: ...
class sample_naive(SampleProtocolMLX):
@partial(mx.compile, shapeless=True)
@staticmethod
def __call__(
logits,
previous_tokens,
temperature,
top_k,
top_p,
repetition_penalty,
):
if temperature <= 1e-5:
probs = mx.softmax(logits, axis=-1)
return mx.argmax(probs, axis=-1, keepdims=True)
if repetition_penalty != 1.0:
batch_idx = mx.arange(cast(tuple[int, ...], previous_tokens.shape)[0])
previous_tokens = previous_tokens.astype(mx.int64)
selected_logists = logits[batch_idx, previous_tokens]
selected_logists = mx.where(
selected_logists < 0, selected_logists * repetition_penalty, selected_logists / repetition_penalty
)
logits[batch_idx, previous_tokens] = selected_logists
sorted_indices = mx.argsort(-logits, axis=-1)
sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
cum_probs = mx.cumsum(mx.softmax(sorted_logits, axis=-1), axis=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, 0] = False
indices_to_remove = mx.zeros_like(logits).astype(mx.bool_)
batch_indices = mx.arange(cast(tuple[int, ...], logits.shape)[0])[:, None]
indices_to_remove[batch_indices, sorted_indices] = sorted_indices_to_remove
logits = mx.where(indices_to_remove, -mx.inf, logits)
logits = logits / temperature
v = mx.topk(logits, top_k)
pivot = mx.expand_dims(v[:, -1], -1)
logits = mx.where(logits < pivot, -mx.inf, logits)
gumbel_noise = mx.random.gumbel(shape=cast(tuple[int, ...], logits.shape), dtype=logits.dtype)
idx_next = mx.argmax(logits + gumbel_noise, axis=-1, keepdims=True).astype(mx.int32)
return idx_next

View File

@ -0,0 +1,155 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Callable, List, MutableSequence, Protocol, Type, cast
import mlx.core as mx
import torch
from ..PyTorch.structs import T2SRequest, T2SResult
from .sample_funcs_mlx import SampleProtocolMLX, sample_naive
Tensor = torch.Tensor
Array = mx.array
@dataclass(slots=True)
class T2SRequestMLX:
x: List[Array]
x_lens: Array
prompts: Array
bert_feature: List[Array]
valid_length: int
top_k: int = 5
top_p: float = 1
early_stop_num: int = -1
temperature: float = 1.0
repetition_penalty: float = 1.35
@classmethod
def from_torch(cls, request: T2SRequest) -> T2SRequestMLX:
x = list(map(lambda tensor: mx.array(tensor.cpu()), request.x))
x_lens = mx.array(request.x_lens.cpu())
prompts = mx.array(request.prompts.cpu())
bert_feature = list(map(lambda tensor: mx.array(tensor.cpu()), request.bert_feature))
return cls(
x,
x_lens,
prompts,
bert_feature,
request.valid_length,
request.top_k,
request.top_p,
request.early_stop_num,
request.temperature,
request.repetition_penalty,
)
class KVCacheProtocol(Protocol):
k_cache: Array
v_cache: Array
def empty(self) -> None: ...
def update_cache(self, input_pos: Array, k_val: Array, v_val: Array, *args, **kwds) -> tuple[Array, Array]: ...
def prefill_kv(self, k_val: Array, v_val: Array) -> None: ...
def sync_cache(self, kv_cache: KVCacheProtocol) -> None: ...
class T2SDecoderProtocol(Protocol):
max_seq_length: int
EOS: int
n_head: int
def embed(self, x: list[Array], y: Array, bert_features: list[Array]) -> Array: ...
class T2SEngineProtocol(Protocol):
def _handle_request(self, request: T2SRequest) -> tuple[list[Array], float]: ...
def generate(self, request: T2SRequest) -> T2SResult: ...
@staticmethod
def load_decoder(
weights_path: os.PathLike, max_batch_size: int = 1, implement: str = "MLX"
) -> T2SDecoderProtocol: ...
class T2SSessionMLX:
def __init__(
self,
decoder: T2SDecoderProtocol,
request_torch: T2SRequest,
sample_func: Type[SampleProtocolMLX] = sample_naive,
device: mx.Device = mx.Device(mx.cpu),
dtype: mx.Dtype = mx.float32,
):
with mx.stream(device):
request = T2SRequestMLX.from_torch(request_torch)
self.decoder = decoder
self.request = request
self.device = device
self.dtype = dtype
bsz = len(request.x)
y_len: int = cast(tuple[int, ...], request.prompts.shape)[-1]
self.bsz = bsz
self.y_len = y_len
# Cache
self.kv_cache: MutableSequence[KVCacheProtocol]
self.sample = sample_func()
# Forward args
self.x = [i.astype(mx.int32) for i in request.x]
self.x_lens = request.x_lens.astype(mx.int32)
self.y = mx.zeros((bsz, decoder.max_seq_length)).astype(mx.int32)
self.y[:, : cast(tuple[int, ...], request.prompts.shape)[-1]] = request.prompts.astype(mx.int32)
self.bert_feature = [i.astype(dtype) for i in request.bert_feature]
self.prefill_len = self.x_lens + cast(tuple[int, ...], request.prompts.shape)[1]
self.input_pos = mx.zeros_like(self.prefill_len)
self.input_pos += self.prefill_len
# EOS
self.completed = mx.array([False] * len(self.x)).astype(mx.bool_)
self.y_results: List[Array] = [None] * len(self.x) # type: ignore
self.xy_pos = decoder.embed(self.x, request.prompts, self.bert_feature)
max_len = int(self.prefill_len.max(-1))
attn_mask = mx.zeros(shape=(bsz, max_len, max_len), dtype=mx.bool_)
for bs in range(bsz):
pos = int(self.x_lens[bs])
seq_len = pos + y_len
attn_mask[bs, :seq_len, :pos] = True
ar_mask = ~mx.triu(
x=mx.ones(
shape=(
y_len,
y_len,
),
dtype=mx.bool_,
),
k=1,
)
attn_mask[bs, pos:seq_len, pos:seq_len] = ar_mask
attn_mask = mx.repeat(mx.expand_dims(attn_mask, 1), decoder.n_head, 1)
self.attn_mask = attn_mask
mx.eval(self.attn_mask)

View File

@ -0,0 +1,195 @@
import gc
import os
import time
import traceback
from typing import cast
import mlx.core as mx
import torch
from tqdm import tqdm
from ..PyTorch.structs import T2SEngineProtocol, T2SRequest
from .structs_mlx import T2SResult, T2SSessionMLX
from .t2s_model_mlx_varlen import T2SDecoder
Array = mx.array
Tensor = torch.Tensor
class T2SEngine(T2SEngineProtocol):
def __init__(
self,
decoder_model: T2SDecoder,
device: mx.Device | str = mx.Device(mx.cpu),
dtype: torch.dtype | mx.Dtype = torch.float32,
) -> None:
if isinstance(device, str):
match device:
case "mx.cpu":
device = mx.Device(mx.cpu)
case "mx.gpu":
device = mx.Device(mx.gpu)
match dtype:
case torch.float32:
dtype = mx.float32
case torch.float16:
dtype = mx.float16
case torch.bfloat16:
dtype = mx.bfloat16
device = cast(mx.Device, device)
dtype = cast(mx.Dtype, dtype)
assert device.type.value in {0, 1}
assert dtype in {mx.float16, mx.bfloat16, mx.float32}
self.device = device
self.dtype = dtype
mx.set_default_device(device)
decoder_model.set_dtype(self.dtype)
self.decoder_model: T2SDecoder = decoder_model
def _handle_request(self, request: T2SRequest):
decoder = self.decoder_model
session = T2SSessionMLX(decoder, request, device=self.device, dtype=self.dtype)
batch_idx = mx.arange(session.bsz)
t1 = 0.0
infer_speed = 0.0
with mx.stream(session.device):
for idx in tqdm(range(1500)):
if idx == 0:
session.kv_cache = decoder.init_cache(session.bsz)
xy_dec = decoder.h.prefill(
session.xy_pos, session.attn_mask, session.kv_cache
) # bs, seq_len, embed_dim
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
else:
args, kwds = decoder.pre_forward(session)
xy_dec = decoder.h(
session.input_pos,
session.xy_pos,
session.kv_cache,
*args,
**kwds,
)
decoder.post_forward(idx, session)
logits = decoder.ar_predict_layer(xy_dec[:, -1])
session.input_pos += 1
if idx == 0:
logits[:, -1] = float("-inf")
samples = session.sample(
logits=logits,
previous_tokens=session.y[:, : session.y_len + idx],
top_k=request.top_k,
top_p=request.top_p,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
)
session.y[batch_idx, session.y_len + idx] = samples
argmax_token = mx.argmax(logits, axis=-1)
sample_token = samples.squeeze(1)
EOS_mask = (cast(Array, argmax_token == decoder.EOS)) | (sample_token == decoder.EOS)
newly_done_mask = EOS_mask & (~session.completed)
newly_done_indices = mx.where(newly_done_mask, batch_idx, -1)
pos = mx.where(newly_done_indices != -1, batch_idx, session.bsz)
pos_sorted = mx.sort(pos, axis=0)
valid_count = session.bsz - mx.sum(cast(Array, pos_sorted == session.bsz))
pos_final = pos_sorted[: int(valid_count)]
newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0)
if newly_done_indices.size > 0:
for i in newly_done_indices:
session.y_results[int(i)] = session.y[i, session.y_len : session.y_len + idx]
session.completed[newly_done_indices] = True
if mx.all(session.completed).item():
if session.y.sum() == 0:
session.y_results = [mx.array([0]) for _ in range(session.bsz)]
tqdm.write("Bad Zero Prediction")
else:
tqdm.write(
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n{[cast(tuple[int, ...], i.shape)[-1] for i in session.y_results].__str__().strip('[]')}"
)
tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
infer_speed = (idx - 1) / (time.perf_counter() - t1)
break
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == 1499:
for j in range(session.bsz):
if not session.completed[j].item():
session.y_results[j] = session.y[[j], session.y_len : session.y_len + 1499]
session.completed[j] = True
break
y_emb = decoder.ar_audio_embedding(samples)
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
mx.eval(session.xy_pos, session.y)
if idx == 1:
t1 = time.perf_counter()
if idx % 100 == 0:
mx.clear_cache()
match session.device:
case mx.gpu:
mx.clear_cache()
case mx.cpu:
gc.collect()
result_mlx = session.y_results[: request.valid_length]
mx.eval(result_mlx)
result = [torch.tensor(k) for k in result_mlx]
return result, infer_speed
def generate(self, request: T2SRequest):
try:
result, infer_speed = self._handle_request(request)
t2s_result = T2SResult(result=result, infer_speed=infer_speed, status="Success")
except Exception as e:
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
return t2s_result
@staticmethod
def replace_key(state_dict: dict[str, Tensor]):
state_dict_mlx: list[tuple[str, Array]] = []
for key, value in state_dict.items():
key = (
key.replace("model.", "")
.replace("in_proj_", "in_proj.")
.replace("self_attn", "attention")
.replace("linear", "feed_forward.linear")
.replace("norm1", "attention_norm")
.replace("norm2", "ffn_norm")
)
value_mlx = mx.array(value)
state_dict_mlx.append((key, value_mlx))
return state_dict_mlx
@staticmethod
def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "MLX"):
if backend != "MLX":
raise RuntimeError("")
print(f"Loading Text2Semantic Weights from {weights_path} with MLX Backend")
dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
config = dict_s1["config"]
decoder: T2SDecoder = T2SDecoder(config, max_batch_size=max_batch_size)
state_dict = dict_s1["weight"]
state_dict_mlx = T2SEngine.replace_key(state_dict)
decoder.load_weights(state_dict_mlx)
decoder.eval()
mx.eval(decoder)
return decoder

View File

@ -0,0 +1,440 @@
from __future__ import annotations
import math
from typing import MutableSequence, cast
import mlx.core as mx
import mlx.nn as nn
from .structs_mlx import KVCacheProtocol, T2SDecoderProtocol, T2SSessionMLX
Array = mx.array
class TokenEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
vocab_size: int,
dropout: float = 0.0,
):
super().__init__()
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.dropout = nn.Dropout(p=dropout)
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
@property
def weight(self):
return self.word_embeddings.weight
def embedding(self, index: int):
return self.word_embeddings.weight[index : index + 1]
def __call__(self, x: Array):
x = self.word_embeddings(x)
x = self.dropout(x)
return x
class SinePositionalEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
dropout: float = 0.0,
scale: bool = False,
max_batch_size: int = 10,
max_seq_len: int = 1800,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
self.alpha = mx.ones(1)
self.dropout = nn.Dropout(p=dropout)
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.reverse = False
self._pe = mx.zeros((max_batch_size, max_seq_len, embedding_dim))
self.compute_pe()
def compute_pe(self):
"""Reset the positional encodings."""
if self.reverse:
position = mx.expand_dims(mx.arange(self.max_seq_len - 1, -1, -1.0), axis=1)
else:
position = mx.expand_dims(mx.arange(self.max_seq_len), axis=1)
div_term = mx.exp(
mx.arange(
0,
self.embedding_dim,
2,
)
* -(math.log(10000.0) / self.embedding_dim)
)
pe = self._pe
pe[:, :, 0::2] = mx.sin(position * div_term)
pe[:, :, 1::2] = mx.cos(position * div_term)
def __call__(self, input_pos: Array, x: Array):
"""
Args:
input_pos (Array): [batch_size, ]
x (Array): [batch_size, 1, embed_dim]
Returns:
embedded_x (Array): [batch_size, 1, embed_dim]
"""
batch_size = cast(tuple[int, ...], x.shape)[0]
pe_values = self._pe[mx.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
return x * self.x_scale + self.alpha * mx.expand_dims(pe_values, 1) # (batch_size, 1, embed_dim)
def prefill(self, x: Array):
"""
Args:
x (Array): [batch_size, seq_len, embed_dim]
Returns:
embedded_x (Array): [batch_size, seq_len, embed_dim]
"""
pe_values = self._pe[:, : cast(tuple[int, ...], x.shape)[-2]]
return x * self.x_scale + self.alpha * pe_values
class KVCache(nn.Module, KVCacheProtocol):
def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None:
super().__init__()
self.n_head = n_heads
self.head_dim = head_dim
self.batch_size = batch_size
self.max_seq_length = max_seq_length
assert batch_size > 0
cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
self.cache_idx = mx.arange(batch_size)
self.k_cache: Array = mx.zeros(cache_shape)
self.v_cache: Array = mx.zeros(cache_shape)
def empty(self):
self.k_cache[:] = 0
self.v_cache[:] = 0
def update_cache(self, input_pos: Array, k_val: Array, v_val: Array):
# input_pos: [B, ], k_val: [B, H, 1, D]
k_out = self.k_cache
v_out = self.v_cache
k_out[self.cache_idx, :, input_pos, None] = k_val
v_out[self.cache_idx, :, input_pos, None] = v_val
return k_out, v_out
def prefill_kv(self, k_val: Array, v_val: Array):
# k_val: [B, S, H, D]
self.k_cache[..., : cast(tuple[int, ...], k_val.shape)[1], :] = k_val.swapaxes(1, 2)
self.v_cache[..., : cast(tuple[int, ...], v_val.shape)[1], :] = v_val.swapaxes(1, 2)
def sync_cache(self, kv_cache: KVCacheProtocol):
self.k_cache[:] = kv_cache.k_cache
self.v_cache[:] = kv_cache.v_cache
class Attention(nn.Module):
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
super().__init__()
self.n_head = n_head
self.hidden_dim = hidden_dim
assert hidden_dim % n_head == 0
self.head_dim = hidden_dim // n_head
self.max_seq_length = max_seq_length
# key, query, value projections for all heads, but in a batch
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
self.scale = 1 / math.sqrt(self.head_dim)
self.dropout = nn.Dropout(0.1)
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCacheProtocol, attn_mask: Array):
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
q, k, v = self.in_proj(x).split(3, axis=-1)
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
k, v = kv_cache.update_cache(input_pos, k, v)
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attn_mask)
attn = self.dropout(attn)
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
def prefill(self, x: Array, kv_cache: KVCacheProtocol, attn_mask: Array):
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
q, k, v = self.in_proj(mx.expand_dims(x, 0)).split(3, axis=-1)
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
kv_cache.prefill_kv(k, v)
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
attn = mx.fast.scaled_dot_product_attention(q, k, v, mask=attn_mask, scale=self.scale)
attn = mx.nan_to_num(attn)
attn = self.dropout(attn)
attn = attn.swapaxes(1, 2).reshape(1, -1, self.hidden_dim)
output = self.out_proj(attn)
return output
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int) -> None:
super().__init__()
self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
self.dropout = nn.Dropout(0.1)
def __call__(self, x: Array):
return self.dropout(self.linear2(self.dropout(nn.relu(self.linear1(x)))))
class TransformerBlock(nn.Module):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.max_seq_length = max_seq_length
self.attention = Attention(n_head, hidden_dim, max_seq_length)
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
self.attention_norm = nn.LayerNorm(self.hidden_dim)
self.ffn_norm = nn.LayerNorm(self.hidden_dim)
self.dropout = nn.Dropout(0.1)
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCacheProtocol, attn_mask: Array):
h = self.attention_norm(
x
+ self.dropout(
self.attention(
x,
input_pos,
kv_cache,
attn_mask,
)
)
)
out = self.ffn_norm(h + self.feed_forward(h))
return out
def prefill(self, x: Array, mask: Array, kv_cache: KVCacheProtocol):
h = self.attention_norm(
x
+ self.dropout(
self.attention.prefill(
x,
kv_cache,
mask,
)
)
)
out = self.ffn_norm(h + self.feed_forward(h))
return out
class TransformerDecoder(nn.Module):
def __init__(
self,
hidden_dim: int,
n_layer: int,
n_head: int,
ffn_dim: int,
vocab_size: int,
max_seq_length: int,
max_batch_size: int,
) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.n_head = n_head
assert hidden_dim % n_head == 0
self.head_dim = hidden_dim // n_head
self.vocab_size = vocab_size
self.n_layer = n_layer
self.layers: MutableSequence[TransformerBlock] = [
TransformerBlock(
n_head,
ffn_dim,
hidden_dim,
max_seq_length,
)
for _ in range(n_layer)
]
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
def __call__(self, input_pos: Array, x: Array, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds):
for layer, kv_cache in zip(self.layers, kv_caches):
x = layer(x, input_pos, kv_cache, *args, **kwds)
return x
def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCacheProtocol]):
for layer, kv_cache in zip(self.layers, kv_caches):
x = layer.prefill(x, mask, kv_cache)
return x
class T2SDecoder(nn.Module, T2SDecoderProtocol):
def __init__(
self,
config: dict,
max_seq_length: int = 1800,
max_batch_size: int = 10,
) -> None:
super().__init__()
hidden_dim: int = config["model"]["hidden_dim"]
embedding_dim: int = config["model"]["embedding_dim"]
n_head: int = config["model"]["head"]
n_layer: int = config["model"]["n_layer"]
vocab_size: int = config["model"]["vocab_size"]
phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
p_dropout: float = config["model"]["dropout"]
EOS: int = config["model"]["EOS"]
ffn_dim: int = hidden_dim * 4
self.n_layer = int(n_layer)
self.hidden_dim = int(hidden_dim)
self.n_head = int(n_head)
assert hidden_dim % n_head == 0
self.head_dim = int(hidden_dim // n_head)
self.embedding_dim = int(embedding_dim)
self.ffn_dim = int(ffn_dim)
self.vocab_size = int(vocab_size)
self.phoneme_vocab_size = int(phoneme_vocab_size)
self.p_dropout = float(p_dropout)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
self.EOS = EOS
assert self.EOS == self.vocab_size - 1
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
self.h = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim,
dropout=0.1,
scale=False,
max_batch_size=max_batch_size,
max_seq_len=max_seq_length,
)
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim,
dropout=0.1,
scale=False,
max_batch_size=max_batch_size,
max_seq_len=max_seq_length,
)
def init_cache(self, bsz: int = 0) -> MutableSequence[KVCacheProtocol]:
bsz = bsz or self.h.max_batch_size
assert bsz <= self.h.max_batch_size
seq_lens = self.h.max_seq_length
dtype = self.bert_proj.bias.dtype
cache: MutableSequence[KVCacheProtocol] = [
KVCache(bsz, seq_lens, self.n_head, self.head_dim) for _ in range(self.n_layer)
]
for c in cache:
cast(KVCache, c).set_dtype(dtype)
mx.eval(cache)
return cache
def embed(
self,
x: list[Array],
y: Array,
bert_features: list[Array],
):
x_len: list[int] = [cast(tuple[int, ...], i.shape)[0] for i in x]
x_len_max = max(x_len)
xy_pos = mx.zeros((len(x), x_len_max + cast(tuple[int, ...], y.shape)[1], self.embedding_dim)).astype(
bert_features[0].dtype
)
bert_features = list(map(lambda x: x.swapaxes(0, 1), bert_features))
y_len = cast(tuple[int, ...], y.shape)[1]
y_emb = self.ar_audio_embedding(y)
y_pos = self.ar_audio_position.prefill(y_emb)
for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
x_emb = self.ar_text_embedding(x_)
bert = self.bert_proj(bert_feature)
x_emb = x_emb + bert
x_pos = self.ar_text_position.prefill(mx.expand_dims(x_emb, 0))
xy_pos[[bs], :len_] = x_pos
xy_pos[[bs], len_ : len_ + y_len] = y_pos
mx.eval(xy_pos)
return xy_pos
# def compile(self):
# setattr(self.h, "__call__", mx.compile(self.h.__call__))
# setattr(self.h, "prefill", mx.compile(self.h.prefill, shapeless=True))
def pre_forward(self, session: T2SSessionMLX):
attn_mask = session.attn_mask
return list(), dict(attn_mask=attn_mask)
def post_forward(self, idx: int, session: T2SSessionMLX) -> None:
if idx == 0:
prefill_len = session.prefill_len
bsz = session.bsz
range_tensor = mx.arange(self.max_seq_length).reshape(1, 1, 1, self.max_seq_length)
prefill_len_expanded = prefill_len.reshape(bsz, 1, 1, 1)
attn_mask = range_tensor < prefill_len_expanded
attn_mask = mx.repeat(attn_mask, self.n_head, 1)
session.attn_mask = attn_mask
attn_mask = session.attn_mask
input_pos = session.input_pos
attn_mask[mx.arange(session.bsz), :, :, input_pos] = True
mx.eval(attn_mask)

View File

@ -0,0 +1,440 @@
from __future__ import annotations
import math
from typing import MutableSequence, cast
import mlx.core as mx
import mlx.nn as nn
from .structs_mlx import KVCacheProtocol, T2SDecoderProtocol, T2SSessionMLX
Array = mx.array
class TokenEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
vocab_size: int,
dropout: float = 0.0,
):
super().__init__()
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.dropout = nn.Dropout(p=dropout)
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
@property
def weight(self):
return self.word_embeddings.weight
def embedding(self, index: int):
return self.word_embeddings.weight[index : index + 1]
def __call__(self, x: Array):
x = self.word_embeddings(x)
x = self.dropout(x)
return x
class SinePositionalEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
dropout: float = 0.0,
scale: bool = False,
max_batch_size: int = 10,
max_seq_len: int = 1800,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
self.alpha = mx.ones(1)
self.dropout = nn.Dropout(p=dropout)
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.reverse = False
self._pe = mx.zeros((max_batch_size, max_seq_len, embedding_dim))
self.compute_pe()
def compute_pe(self):
"""Reset the positional encodings."""
if self.reverse:
position = mx.expand_dims(mx.arange(self.max_seq_len - 1, -1, -1.0), axis=1)
else:
position = mx.expand_dims(mx.arange(self.max_seq_len), axis=1)
div_term = mx.exp(
mx.arange(
0,
self.embedding_dim,
2,
)
* -(math.log(10000.0) / self.embedding_dim)
)
pe = self._pe
pe[:, :, 0::2] = mx.sin(position * div_term)
pe[:, :, 1::2] = mx.cos(position * div_term)
def __call__(self, input_pos: Array, x: Array):
"""
Args:
input_pos (Array): [batch_size, ]
x (Array): [batch_size, 1, embed_dim]
Returns:
embedded_x (Array): [batch_size, 1, embed_dim]
"""
batch_size = cast(tuple[int, ...], x.shape)[0]
pe_values = self._pe[mx.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
return x * self.x_scale + self.alpha * mx.expand_dims(pe_values, 1) # (batch_size, 1, embed_dim)
def prefill(self, x: Array):
"""
Args:
x (Array): [batch_size, seq_len, embed_dim]
Returns:
embedded_x (Array): [batch_size, seq_len, embed_dim]
"""
pe_values = self._pe[:, : cast(tuple[int, ...], x.shape)[-2]]
return x * self.x_scale + self.alpha * pe_values
class KVCache(nn.Module, KVCacheProtocol):
def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None:
super().__init__()
self.n_head = n_heads
self.head_dim = head_dim
self.batch_size = batch_size
self.max_seq_length = max_seq_length
assert batch_size > 0
cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
self.cache_idx = mx.arange(batch_size)
self.k_cache: Array = mx.zeros(cache_shape)
self.v_cache: Array = mx.zeros(cache_shape)
def empty(self):
self.k_cache[:] = 0
self.v_cache[:] = 0
def update_cache(self, input_pos: Array, k_val: Array, v_val: Array):
# input_pos: [B, ], k_val: [B, H, 1, D]
k_out = self.k_cache
v_out = self.v_cache
k_out[self.cache_idx, :, input_pos, None] = k_val
v_out[self.cache_idx, :, input_pos, None] = v_val
return k_out, v_out
def prefill_kv(self, k_val: Array, v_val: Array):
# k_val: [B, S, H, D]
self.k_cache[..., : cast(tuple[int, ...], k_val.shape)[1], :] = k_val.swapaxes(1, 2)
self.v_cache[..., : cast(tuple[int, ...], v_val.shape)[1], :] = v_val.swapaxes(1, 2)
def sync_cache(self, kv_cache: KVCacheProtocol):
self.k_cache[:] = kv_cache.k_cache
self.v_cache[:] = kv_cache.v_cache
class Attention(nn.Module):
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
super().__init__()
self.n_head = n_head
self.hidden_dim = hidden_dim
assert hidden_dim % n_head == 0
self.head_dim = hidden_dim // n_head
self.max_seq_length = max_seq_length
# key, query, value projections for all heads, but in a batch
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
self.scale = 1 / math.sqrt(self.head_dim)
self.dropout = nn.Dropout(0.1)
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCacheProtocol, attn_mask: Array):
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
q, k, v = self.in_proj(x).split(3, axis=-1)
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
k, v = kv_cache.update_cache(input_pos, k, v)
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attn_mask)
attn = self.dropout(attn)
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
def prefill(self, x: Array, kv_cache: KVCacheProtocol, attn_mask: Array):
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
q, k, v = self.in_proj(mx.expand_dims(x, 0)).split(3, axis=-1)
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
kv_cache.prefill_kv(k, v)
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
attn = mx.fast.scaled_dot_product_attention(q, k, v, mask=attn_mask, scale=self.scale)
attn = mx.nan_to_num(attn)
attn = self.dropout(attn)
attn = attn.swapaxes(1, 2).reshape(1, -1, self.hidden_dim)
output = self.out_proj(attn)
return output
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int) -> None:
super().__init__()
self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
self.dropout = nn.Dropout(0.1)
def __call__(self, x: Array):
return self.dropout(self.linear2(self.dropout(nn.relu(self.linear1(x)))))
class TransformerBlock(nn.Module):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.max_seq_length = max_seq_length
self.attention = Attention(n_head, hidden_dim, max_seq_length)
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
self.attention_norm = nn.LayerNorm(self.hidden_dim)
self.ffn_norm = nn.LayerNorm(self.hidden_dim)
self.dropout = nn.Dropout(0.1)
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCacheProtocol, attn_mask: Array):
h = self.attention_norm(
x
+ self.dropout(
self.attention(
x,
input_pos,
kv_cache,
attn_mask,
)
)
)
out = self.ffn_norm(h + self.feed_forward(h))
return out
def prefill(self, x: Array, mask: Array, kv_cache: KVCacheProtocol):
h = self.attention_norm(
x
+ self.dropout(
self.attention.prefill(
x,
kv_cache,
mask,
)
)
)
out = self.ffn_norm(h + self.feed_forward(h))
return out
class TransformerDecoder(nn.Module):
def __init__(
self,
hidden_dim: int,
n_layer: int,
n_head: int,
ffn_dim: int,
vocab_size: int,
max_seq_length: int,
max_batch_size: int,
) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.n_head = n_head
assert hidden_dim % n_head == 0
self.head_dim = hidden_dim // n_head
self.vocab_size = vocab_size
self.n_layer = n_layer
self.layers: MutableSequence[TransformerBlock] = [
TransformerBlock(
n_head,
ffn_dim,
hidden_dim,
max_seq_length,
)
for _ in range(n_layer)
]
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
def __call__(self, input_pos: Array, x: Array, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds):
for layer, kv_cache in zip(self.layers, kv_caches):
x = layer(x, input_pos, kv_cache, *args, **kwds)
return x
def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCacheProtocol]):
for layer, kv_cache in zip(self.layers, kv_caches):
x = layer.prefill(x, mask, kv_cache)
return x
class T2SDecoder(nn.Module, T2SDecoderProtocol):
def __init__(
self,
config: dict,
max_seq_length: int = 1800,
max_batch_size: int = 10,
) -> None:
super().__init__()
hidden_dim: int = config["model"]["hidden_dim"]
embedding_dim: int = config["model"]["embedding_dim"]
n_head: int = config["model"]["head"]
n_layer: int = config["model"]["n_layer"]
vocab_size: int = config["model"]["vocab_size"]
phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
p_dropout: float = config["model"]["dropout"]
EOS: int = config["model"]["EOS"]
ffn_dim: int = hidden_dim * 4
self.n_layer = int(n_layer)
self.hidden_dim = int(hidden_dim)
self.n_head = int(n_head)
assert hidden_dim % n_head == 0
self.head_dim = int(hidden_dim // n_head)
self.embedding_dim = int(embedding_dim)
self.ffn_dim = int(ffn_dim)
self.vocab_size = int(vocab_size)
self.phoneme_vocab_size = int(phoneme_vocab_size)
self.p_dropout = float(p_dropout)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
self.EOS = EOS
assert self.EOS == self.vocab_size - 1
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
self.h = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim,
dropout=0.1,
scale=False,
max_batch_size=max_batch_size,
max_seq_len=max_seq_length,
)
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim,
dropout=0.1,
scale=False,
max_batch_size=max_batch_size,
max_seq_len=max_seq_length,
)
def init_cache(self, bsz: int = 0) -> MutableSequence[KVCacheProtocol]:
bsz = bsz or self.h.max_batch_size
assert bsz <= self.h.max_batch_size
seq_lens = self.h.max_seq_length
dtype = self.bert_proj.bias.dtype
cache: MutableSequence[KVCacheProtocol] = [
KVCache(bsz, seq_lens, self.n_head, self.head_dim) for _ in range(self.n_layer)
]
for c in cache:
cast(KVCache, c).set_dtype(dtype)
mx.eval(cache)
return cache
def embed(
self,
x: list[Array],
y: Array,
bert_features: list[Array],
):
x_len: list[int] = [cast(tuple[int, ...], i.shape)[0] for i in x]
x_len_max = max(x_len)
xy_pos = mx.zeros((len(x), x_len_max + cast(tuple[int, ...], y.shape)[1], self.embedding_dim)).astype(
bert_features[0].dtype
)
bert_features = list(map(lambda x: x.swapaxes(0, 1), bert_features))
y_len = cast(tuple[int, ...], y.shape)[1]
y_emb = self.ar_audio_embedding(y)
y_pos = self.ar_audio_position.prefill(y_emb)
for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
x_emb = self.ar_text_embedding(x_)
bert = self.bert_proj(bert_feature)
x_emb = x_emb + bert
x_pos = self.ar_text_position.prefill(mx.expand_dims(x_emb, 0))
xy_pos[[bs], :len_] = x_pos
xy_pos[[bs], len_ : len_ + y_len] = y_pos
mx.eval(xy_pos)
return xy_pos
# def compile(self):
# setattr(self.h, "__call__", mx.compile(self.h.__call__))
# setattr(self.h, "prefill", mx.compile(self.h.prefill, shapeless=True))
def pre_forward(self, session: T2SSessionMLX):
attn_mask = session.attn_mask
return list(), dict(attn_mask=attn_mask)
def post_forward(self, idx: int, session: T2SSessionMLX) -> None:
if idx == 0:
prefill_len = session.prefill_len
bsz = session.bsz
range_tensor = mx.arange(self.max_seq_length).reshape(1, 1, 1, self.max_seq_length)
prefill_len_expanded = prefill_len.reshape(bsz, 1, 1, 1)
attn_mask = range_tensor < prefill_len_expanded
attn_mask = mx.repeat(attn_mask, self.n_head, 1)
session.attn_mask = attn_mask
attn_mask = session.attn_mask
input_pos = session.input_pos
attn_mask[mx.arange(session.bsz), :, :, input_pos] = True
mx.eval(attn_mask)

View File

@ -0,0 +1,25 @@
import importlib.util
import torch
from .sample_funcs import sample_naive
from .structs import T2SRequest, T2SResult
from .t2s_engine import T2SEngine as T2SEngineTorch
backends = ["naive"]
if torch.cuda.is_available():
if importlib.util.find_spec("sageattention") is not None:
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(i)
sm_version = major + minor / 10.0
if sm_version >= 7.0:
backends.append("sage_attn")
if importlib.util.find_spec("flash_attn") is not None:
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(i)
sm_version = major + minor / 10.0
if sm_version >= 7.5:
backends.append("flash_attn")
__all__ = ["T2SEngineTorch", "T2SRequest", "sample_naive", "T2SResult", "backends"]

View File

@ -0,0 +1,159 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""
from typing import Dict, List, Tuple
import kernels
import torch
from .. import nn
from ..structs import T2SSession
from ..t2s_model_abc import (
AttentionABC,
CUDAGraphCacheABC,
FeedForward,
KVCacheNHD,
KVCacheProtocol,
T2SDecoderABC,
TransformerBlockABC,
TransformerDecoderABC,
)
flash_attn_kernel = None
try:
import flash_attn_interface as flash_attn # type: ignore
flash_attn_kernel = flash_attn.flash_attn_with_kvcache
except ModuleNotFoundError:
try:
import flash_attn # type: ignore
flash_attn_kernel = flash_attn.flash_attn_with_kvcache
except ModuleNotFoundError:
pass
if flash_attn_kernel is None:
flash_attn_kernel = kernels.get_kernel("kernels-community/flash-attn").flash_attn_with_kvcache
Tensor = torch.Tensor
class Attention(AttentionABC):
def __init__(self, n_head, hidden_dim, max_seq_length):
super().__init__(n_head, hidden_dim, max_seq_length)
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds) -> Tensor:
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
attn: Tensor = flash_attn.flash_attn_with_kvcache( # type: ignore
q, kv_cache.k_cache, kv_cache.v_cache, k, v, cache_seqlens=input_pos - 1
)
attn = self.dropout(attn)
attn = attn.view(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
class TransformerBlock(TransformerBlockABC):
def __init__(self, n_head, ffn_dim, hidden_dim, max_seq_length) -> None:
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
self.attention = Attention(n_head, hidden_dim, max_seq_length)
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
self.attention_norm = nn.LayerNorm([self.hidden_dim])
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
class TransformerDecoder(TransformerDecoderABC):
def __init__(
self,
hidden_dim,
n_layer,
n_head,
ffn_dim,
vocab_size,
max_seq_length,
max_batch_size,
) -> None:
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
self.layers = nn.ModuleList( # type: ignore
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
)
class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
max_seq_length=1800,
max_batch_size=10,
) -> None:
assert torch.cuda.is_available()
super().__init__(config, max_seq_length, max_batch_size)
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
self.h: TransformerDecoderABC = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.kv_class = KVCacheNHD
def post_forward(self, idx: int, session: T2SSession) -> None:
return super().post_forward(idx, session)
def pre_forward(self, session: T2SSession) -> Tuple[List, Dict]:
return super().pre_forward(session)
class CUDAGraphCache(CUDAGraphCacheABC):
def __init__(
self,
decoder: T2SDecoder,
) -> None:
super().__init__(decoder)
def release_graph(self, session: T2SSession):
if session.id != self.id:
self.assigned = False
else:
del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
def get_cache_graph(self, session: T2SSession):
assert self.graph
session.graph = self.graph
session.stream = self.stream
session.xy_pos_ = self.xy_pos
session.xy_dec_ = self.xy_dec
session.input_pos = self.input_pos.copy_(session.input_pos)
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
cache.sync_cache(cache_)
def capture_new_graph(self, session: T2SSession):
session.xy_pos_ = self.xy_pos.clone()
session.xy_dec_ = self.xy_dec.clone()
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
session.graph = graph
session.stream = torch.cuda.Stream() # type: ignore

View File

@ -0,0 +1,167 @@
import torch
from torch.nn import functional as F
from .. import nn
from ..structs import KVCacheProtocol, T2SSession
from ..t2s_model_abc import (
AttentionABC,
CUDAGraphCacheABC,
FeedForward,
KVCacheHND,
T2SDecoderABC,
TransformerBlockABC,
TransformerDecoderABC,
)
Tensor = torch.Tensor
class Attention(AttentionABC):
def __init__(self, n_head, hidden_dim, max_seq_length):
super().__init__(n_head, hidden_dim, max_seq_length)
# key, query, value projections for all heads, but in a batch
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
k, v = kv_cache.update(input_pos, k, v)
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
attn = self.dropout(attn)
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
class TransformerBlock(TransformerBlockABC):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
self.attention = Attention(n_head, hidden_dim, max_seq_length)
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
self.attention_norm = nn.LayerNorm([self.hidden_dim])
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
class TransformerDecoder(TransformerDecoderABC):
def __init__(
self,
hidden_dim,
n_layer,
n_head,
ffn_dim,
vocab_size,
max_seq_length,
max_batch_size,
) -> None:
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
self.layers = nn.ModuleList( # type: ignore
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
)
class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
max_seq_length=1800,
max_batch_size=10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
self.h: TransformerDecoderABC = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.kv_class = KVCacheHND
def pre_forward(self, session: T2SSession):
attn_mask = session.attn_mask
return list(), dict(attn_mask=attn_mask)
def post_forward(self, idx: int, session: T2SSession) -> None:
if idx == 0:
prefill_len = session.prefill_len
bsz = session.bsz
range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
attn_mask = range_tensor < prefill_len_expanded
attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
session.attn_mask = attn_mask
attn_mask = session.attn_mask
input_pos = session.input_pos
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
class CUDAGraphCache(CUDAGraphCacheABC):
def __init__(
self,
decoder,
) -> None:
super().__init__(decoder)
if torch.cuda.is_available():
self.attn_mask = (
torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
.bool()
.to(self.device, self.dtype)
)
def release_graph(self, session: T2SSession):
if session.id != self.id:
self.assigned = False
else:
del (
session.graph,
session.xy_pos_,
session.xy_dec_,
session.input_pos,
session.kv_cache,
session.attn_mask,
)
def get_cache_graph(self, session: T2SSession):
assert self.graph
session.graph = self.graph
session.stream = self.stream
session.xy_pos_ = self.xy_pos
session.xy_dec_ = self.xy_dec
session.input_pos = self.input_pos.copy_(session.input_pos)
session.attn_mask = self.attn_mask
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
cache.sync_cache(cache_)
def capture_new_graph(self, session: T2SSession):
session.xy_pos_ = self.xy_pos.clone()
session.xy_dec_ = self.xy_dec.clone()
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
session.graph = graph
session.stream = torch.cuda.Stream() # type: ignore

View File

@ -0,0 +1,167 @@
import torch
from torch.nn import functional as F
from .. import nn
from ..structs import KVCacheProtocol, T2SSession
from ..t2s_model_abc import (
AttentionABC,
CUDAGraphCacheABC,
FeedForward,
KVCacheHND,
T2SDecoderABC,
TransformerBlockABC,
TransformerDecoderABC,
)
Tensor = torch.Tensor
class Attention(AttentionABC):
def __init__(self, n_head, hidden_dim, max_seq_length):
super().__init__(n_head, hidden_dim, max_seq_length)
# key, query, value projections for all heads, but in a batch
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
k, v = kv_cache.update(input_pos, k, v)
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
attn = self.dropout(attn)
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
class TransformerBlock(TransformerBlockABC):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
self.attention = Attention(n_head, hidden_dim, max_seq_length)
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
self.attention_norm = nn.LayerNorm([self.hidden_dim])
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
class TransformerDecoder(TransformerDecoderABC):
def __init__(
self,
hidden_dim,
n_layer,
n_head,
ffn_dim,
vocab_size,
max_seq_length,
max_batch_size,
) -> None:
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
self.layers = nn.ModuleList( # type: ignore
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
)
class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
max_seq_length=1800,
max_batch_size=10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
self.h: TransformerDecoderABC = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.kv_class = KVCacheHND
def pre_forward(self, session: T2SSession):
attn_mask = session.attn_mask
return list(), dict(attn_mask=attn_mask)
def post_forward(self, idx: int, session: T2SSession) -> None:
if idx == 0:
prefill_len = session.prefill_len
bsz = session.bsz
range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
attn_mask = range_tensor < prefill_len_expanded
attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
session.attn_mask = attn_mask
attn_mask = session.attn_mask
input_pos = session.input_pos
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
class CUDAGraphCache(CUDAGraphCacheABC):
def __init__(
self,
decoder,
) -> None:
super().__init__(decoder)
if torch.cuda.is_available():
self.attn_mask = (
torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
.bool()
.to(self.device, self.dtype)
)
def release_graph(self, session: T2SSession):
if session.id != self.id:
self.assigned = False
else:
del (
session.graph,
session.xy_pos_,
session.xy_dec_,
session.input_pos,
session.kv_cache,
session.attn_mask,
)
def get_cache_graph(self, session: T2SSession):
assert self.graph
session.graph = self.graph
session.stream = self.stream
session.xy_pos_ = self.xy_pos
session.xy_dec_ = self.xy_dec
session.input_pos = self.input_pos.copy_(session.input_pos)
session.attn_mask = self.attn_mask
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
cache.sync_cache(cache_)
def capture_new_graph(self, session: T2SSession):
session.xy_pos_ = self.xy_pos.clone()
session.xy_dec_ = self.xy_dec.clone()
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
session.graph = graph
session.stream = torch.cuda.Stream() # type: ignore

View File

@ -0,0 +1,167 @@
import torch
from torch.nn import functional as F
from .. import nn
from ..structs import KVCacheProtocol, T2SSession
from ..t2s_model_abc import (
AttentionABC,
CUDAGraphCacheABC,
FeedForward,
KVCacheHND,
T2SDecoderABC,
TransformerBlockABC,
TransformerDecoderABC,
)
Tensor = torch.Tensor
class Attention(AttentionABC):
def __init__(self, n_head, hidden_dim, max_seq_length):
super().__init__(n_head, hidden_dim, max_seq_length)
# key, query, value projections for all heads, but in a batch
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
k, v = kv_cache.update(input_pos, k, v)
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
attn = self.dropout(attn)
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
class TransformerBlock(TransformerBlockABC):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
self.attention = Attention(n_head, hidden_dim, max_seq_length)
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
self.attention_norm = nn.LayerNorm([self.hidden_dim])
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
class TransformerDecoder(TransformerDecoderABC):
def __init__(
self,
hidden_dim,
n_layer,
n_head,
ffn_dim,
vocab_size,
max_seq_length,
max_batch_size,
) -> None:
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
self.layers = nn.ModuleList( # type: ignore
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
)
class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
max_seq_length=1800,
max_batch_size=10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
self.h: TransformerDecoderABC = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.kv_class = KVCacheHND
def pre_forward(self, session: T2SSession):
attn_mask = session.attn_mask
return list(), dict(attn_mask=attn_mask)
def post_forward(self, idx: int, session: T2SSession) -> None:
if idx == 0:
prefill_len = session.prefill_len
bsz = session.bsz
range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
attn_mask = range_tensor < prefill_len_expanded
attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
session.attn_mask = attn_mask
attn_mask = session.attn_mask
input_pos = session.input_pos
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
class CUDAGraphCache(CUDAGraphCacheABC):
def __init__(
self,
decoder,
) -> None:
super().__init__(decoder)
if torch.cuda.is_available():
self.attn_mask = (
torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
.bool()
.to(self.device, self.dtype)
)
def release_graph(self, session: T2SSession):
if session.id != self.id:
self.assigned = False
else:
del (
session.graph,
session.xy_pos_,
session.xy_dec_,
session.input_pos,
session.kv_cache,
session.attn_mask,
)
def get_cache_graph(self, session: T2SSession):
assert self.graph
session.graph = self.graph
session.stream = self.stream
session.xy_pos_ = self.xy_pos
session.xy_dec_ = self.xy_dec
session.input_pos = self.input_pos.copy_(session.input_pos)
session.attn_mask = self.attn_mask
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
cache.sync_cache(cache_)
def capture_new_graph(self, session: T2SSession):
session.xy_pos_ = self.xy_pos.clone()
session.xy_dec_ = self.xy_dec.clone()
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
session.graph = graph
session.stream = torch.cuda.Stream() # type: ignore

View File

@ -0,0 +1,167 @@
import torch
from torch.nn import functional as F
from .. import nn
from ..structs import KVCacheProtocol, T2SSession
from ..t2s_model_abc import (
AttentionABC,
CUDAGraphCacheABC,
FeedForward,
KVCacheHND,
T2SDecoderABC,
TransformerBlockABC,
TransformerDecoderABC,
)
Tensor = torch.Tensor
class Attention(AttentionABC):
def __init__(self, n_head, hidden_dim, max_seq_length):
super().__init__(n_head, hidden_dim, max_seq_length)
# key, query, value projections for all heads, but in a batch
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
k, v = kv_cache.update(input_pos, k, v)
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
attn = self.dropout(attn)
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
class TransformerBlock(TransformerBlockABC):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
self.attention = Attention(n_head, hidden_dim, max_seq_length)
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
self.attention_norm = nn.LayerNorm([self.hidden_dim])
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
class TransformerDecoder(TransformerDecoderABC):
def __init__(
self,
hidden_dim,
n_layer,
n_head,
ffn_dim,
vocab_size,
max_seq_length,
max_batch_size,
) -> None:
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
self.layers = nn.ModuleList( # type: ignore
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
)
class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
max_seq_length=1800,
max_batch_size=10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
self.h: TransformerDecoderABC = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.kv_class = KVCacheHND
def pre_forward(self, session: T2SSession):
attn_mask = session.attn_mask
return list(), dict(attn_mask=attn_mask)
def post_forward(self, idx: int, session: T2SSession) -> None:
if idx == 0:
prefill_len = session.prefill_len
bsz = session.bsz
range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
attn_mask = range_tensor < prefill_len_expanded
attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
session.attn_mask = attn_mask
attn_mask = session.attn_mask
input_pos = session.input_pos
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
class CUDAGraphCache(CUDAGraphCacheABC):
def __init__(
self,
decoder,
) -> None:
super().__init__(decoder)
if torch.cuda.is_available():
self.attn_mask = (
torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
.bool()
.to(self.device, self.dtype)
)
def release_graph(self, session: T2SSession):
if session.id != self.id:
self.assigned = False
else:
del (
session.graph,
session.xy_pos_,
session.xy_dec_,
session.input_pos,
session.kv_cache,
session.attn_mask,
)
def get_cache_graph(self, session: T2SSession):
assert self.graph
session.graph = self.graph
session.stream = self.stream
session.xy_pos_ = self.xy_pos
session.xy_dec_ = self.xy_dec
session.input_pos = self.input_pos.copy_(session.input_pos)
session.attn_mask = self.attn_mask
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
cache.sync_cache(cache_)
def capture_new_graph(self, session: T2SSession):
session.xy_pos_ = self.xy_pos.clone()
session.xy_dec_ = self.xy_dec.clone()
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
session.graph = graph
session.stream = torch.cuda.Stream() # type: ignore

View File

@ -0,0 +1,178 @@
from typing import MutableSequence
import sageattention # type: ignore
import torch
from .. import nn
from ..structs import T2SSession
from ..t2s_model_abc import (
AttentionABC,
CUDAGraphCacheABC,
FeedForward,
KVCacheHND,
KVCacheProtocol,
T2SDecoderABC,
TransformerBlockABC,
TransformerDecoderABC,
)
Tensor = torch.Tensor
class Attention(AttentionABC):
def __init__(self, n_head, hidden_dim, max_seq_length):
super().__init__(n_head, hidden_dim, max_seq_length)
# key, query, value projections for all heads, but in a batch
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(
self,
x: Tensor,
input_pos: Tensor,
kv_cache: KVCacheProtocol,
cu_seqlens_q: Tensor,
cu_seqlens_kv: Tensor,
) -> Tensor:
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
k, v = kv_cache.update(input_pos, k, v)
attn: Tensor = sageattention.sageattn_varlen(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=1,
max_seqlen_k=self.max_seq_length,
)
attn = self.dropout(attn)
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
class TransformerBlock(TransformerBlockABC):
def __init__(self, n_head, ffn_dim, hidden_dim, max_seq_length) -> None:
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
self.attention = Attention(n_head, hidden_dim, max_seq_length)
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
self.attention_norm = nn.LayerNorm([self.hidden_dim])
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
class TransformerDecoder(TransformerDecoderABC):
def __init__(
self,
hidden_dim,
n_layer,
n_head,
ffn_dim,
vocab_size,
max_seq_length,
max_batch_size,
) -> None:
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
self.layers = nn.ModuleList( # type: ignore
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
)
class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
max_seq_length=1800,
max_batch_size=10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
self.h: TransformerDecoderABC = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.kv_class = KVCacheHND
def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
return list(), dict(cu_seqlens_q=session.cu_seqlens_q, cu_seqlens_kv=session.cu_seqlens_kv)
def post_forward(self, idx: int, session: T2SSession):
if idx == 0:
session.cu_seqlens_q = torch.arange(0, session.bsz + 1, dtype=torch.int32)
session.cu_seqlens_kv = torch.cat([torch.tensor(0, dtype=torch.int32), session.input_pos])
else:
cu_seqlens_q = session.cu_seqlens_q
cu_seqlens_kv = session.cu_seqlens_kv
cu_seqlens_kv.add_(cu_seqlens_q)
class CUDAGraphCache(CUDAGraphCacheABC):
def __init__(
self,
decoder: T2SDecoder,
) -> None:
super().__init__(decoder)
if torch.cuda.is_available():
self.cu_seqlens_q = torch.arange(0, decoder.max_batch_size + 1, dtype=torch.int32).to(self.device)
self.cu_seqlens_kv = torch.cat([torch.tensor(0, dtype=torch.int32), self.input_pos]).to(self.device)
def release_graph(self, session: T2SSession):
if session.id != self.id:
self.assigned = False
else:
del (
session.graph,
session.xy_pos_,
session.xy_dec_,
session.input_pos,
session.kv_cache,
session.cu_seqlens_q,
session.cu_seqlens_kv,
)
def get_cache_graph(self, session: T2SSession):
assert self.graph
session.graph = self.graph
session.stream = self.stream
session.xy_pos_ = self.xy_pos
session.xy_dec_ = self.xy_dec
session.input_pos = self.input_pos.copy_(session.input_pos)
session.cu_seqlens_q = self.cu_seqlens_q
session.cu_seqlens_kv = self.cu_seqlens_kv
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
cache.sync_cache(cache_)
def capture_new_graph(self, session: T2SSession):
session.xy_pos_ = self.xy_pos.clone()
session.xy_dec_ = self.xy_dec.clone()
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
session.cu_seqlens_q = self.cu_seqlens_q.clone().copy_(session.cu_seqlens_q)
session.cu_seqlens_kv = self.cu_seqlens_kv.clone().copy_(session.cu_seqlens_kv)
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
session.graph = graph
session.stream = torch.cuda.Stream() # type: ignore

View File

@ -0,0 +1,69 @@
"""
Enhanced Type Hint nn.Module
Modified From https://github.com/labmlai/labml/blob/master/helpers/labml_helpers/module.py
"""
from typing import Any
import torch.nn
from torch.nn import (
functional as functional,
)
from torch.nn import (
utils as utils,
)
from torch.nn.modules import * # type: ignore # noqa: F403
from torch.nn.parameter import (
Parameter as Parameter,
)
Tensor = torch.Tensor
class Module(torch.nn.Module):
r"""
Wraps ``torch.nn.Module`` to overload ``__call__`` instead of
``forward`` for better type checking.
`PyTorch Github issue for clarification <https://github.com/pytorch/pytorch/issues/44605>`_
"""
def _forward_unimplemented(self, *input: Any) -> None:
# To stop PyTorch from giving abstract methods warning
pass
def __init_subclass__(cls, **kwargs):
if cls.__dict__.get("__call__", None) is None:
return
setattr(cls, "forward", cls.__dict__["__call__"])
delattr(cls, "__call__")
@property
def device(self) -> torch.device:
params = self.parameters()
try:
sample_param = next(params)
return sample_param.device
except StopIteration:
raise RuntimeError(f"Unable to determine device of {self.__class__.__name__}") from None
class Linear(torch.nn.Linear):
def __call__(self, input: Tensor) -> Tensor:
return super().__call__(input)
class Dropout(torch.nn.Dropout):
def __call__(self, input: Tensor) -> Tensor:
return super().__call__(input)
class Embedding(torch.nn.Embedding):
def __call__(self, input: Tensor) -> Tensor:
return super().__call__(input)
class LayerNorm(torch.nn.LayerNorm):
def __call__(self, input: Tensor) -> Tensor:
return super().__call__(input)

View File

@ -0,0 +1,62 @@
from typing import Protocol
import torch
import torch.nn.functional as F
Tensor = torch.Tensor
class SampleProtocol(Protocol):
@staticmethod
def __call__(
logits: Tensor,
previous_tokens: Tensor,
temperature: float,
top_k: int,
top_p: float,
repetition_penalty: float,
) -> Tensor: ...
class sample_naive(SampleProtocol):
@staticmethod
def __call__(
logits: Tensor,
previous_tokens: Tensor,
temperature: float,
top_k: int,
top_p: float,
repetition_penalty: float,
):
if temperature <= 1e-5:
probs = F.softmax(logits, dim=-1)
return torch.argmax(probs, dim=-1, keepdim=True)
if repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=1, index=previous_tokens, src=score)
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
logits /= temperature
v, _ = torch.topk(logits, top_k)
pivot = v[:, -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = F.softmax(logits, dim=-1)
q = torch.empty_like(probs).exponential_(1.0)
idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
return idx_next

View File

@ -0,0 +1,149 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal, MutableSequence, Optional, Protocol, Type
import torch
from .sample_funcs import SampleProtocol, sample_naive
Tensor = torch.Tensor
@dataclass
class T2SResult:
result: list[Tensor] | None = None
infer_speed: float = 0.0
status: Literal["Success", "Error"] = "Success"
exception: Optional[Exception] = None
traceback: Optional[str] = None
@dataclass
class T2SRequest:
x: list[torch.Tensor]
x_lens: Tensor
prompts: torch.Tensor
bert_feature: list[Tensor]
valid_length: int
top_k: int = 5
top_p: float = 1
early_stop_num: int = -1
temperature: float = 1.0
repetition_penalty: float = 1.35
use_cuda_graph: bool = False
debug: bool = False
class KVCacheProtocol(Protocol):
k_cache: Tensor
v_cache: Tensor
def empty(self) -> None: ...
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> tuple[Tensor, Tensor]: ...
def prefill_kv(self, k_val: Tensor, v_val: Tensor) -> None: ...
def sync_cache(self, kv_cache: KVCacheProtocol) -> None: ...
class T2SDecoderProtocol(Protocol):
max_seq_length: int
EOS: int
n_head: int
@property
def device(self) -> torch.device: ...
def embed(self, x: list[Tensor], y: Tensor, bert_features: list[Tensor]) -> Tensor: ...
class T2SEngineProtocol(Protocol):
def _handle_request(self, request: T2SRequest) -> tuple[list[Tensor], float]: ...
def generate(self, request: T2SRequest) -> T2SResult: ...
class T2SSession:
def __init__(
self,
decoder: T2SDecoderProtocol,
request: T2SRequest,
sapmle_func: Type[SampleProtocol] = sample_naive,
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
):
with device:
self.decoder = decoder
self.request = request
self.device = device
self.dtype = dtype
bsz = len(request.x)
y_len = request.prompts.size(-1)
self.bsz = bsz
self.y_len = y_len
request.prompts = request.prompts.to(device, torch.int32)
# Cache
self.kv_cache: MutableSequence[KVCacheProtocol]
self.sample = sapmle_func()
# Forward args
self.x = [i.to(device) for i in request.x]
self.x_lens = request.x_lens.to(torch.int32)
self.y = torch.zeros((bsz, decoder.max_seq_length)).to(torch.int32)
self.y[:, : request.prompts.shape[-1]] = request.prompts
self.bert_feature = [i.to(device, dtype) for i in request.bert_feature]
self.prefill_len = self.x_lens + request.prompts.size(1)
self.input_pos = torch.zeros_like(self.prefill_len)
self.input_pos.add_(self.prefill_len)
# CUDA Graph
self.stream: Optional[torch.cuda.Stream] = None
self.graph: Optional[torch.cuda.CUDAGraph] = None
self.xy_pos_: Tensor
self.xy_dec_: Tensor
# EOS
self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
self.y_results: list[Tensor] = [None] * len(self.x) # type: ignore
self.xy_pos = decoder.embed(self.x, request.prompts, self.bert_feature)
max_len = int(self.prefill_len.max().item())
attn_mask = torch.zeros(size=(bsz, max_len, max_len), dtype=torch.bool)
for bs in range(bsz):
pos = int(self.x_lens[bs])
seq_len = pos + y_len
attn_mask[bs, :seq_len, :pos] = True
ar_mask = ~torch.triu(
input=torch.ones(
size=(
y_len,
y_len,
),
dtype=torch.bool,
),
diagonal=1,
)
attn_mask[bs, pos:seq_len, pos:seq_len] = ar_mask
self.attn_mask = attn_mask
self.attn_mask = attn_mask.unsqueeze(0).expand(-1, decoder.n_head, -1, -1)
self.id: int = -1
# Sage Attn & Transformer Engine Impl
self.cu_seqlens_q: Tensor
self.cu_seqlens_kv: Tensor

View File

@ -0,0 +1,204 @@
import contextlib
import gc
import os
import sys
import time
import traceback
from importlib import import_module
from typing import Type
import torch
from tqdm import tqdm
from .structs import T2SEngineProtocol, T2SRequest, T2SResult, T2SSession
from .t2s_model_abc import (
CUDAGraphCacheABC,
T2SDecoderABC,
TorchProfiler,
)
torch.set_grad_enabled(False)
class T2SEngine(T2SEngineProtocol):
def __init__(
self,
decoder_model: T2SDecoderABC,
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
) -> None:
assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"}
assert dtype in {torch.float16, torch.bfloat16, torch.float32}
self.device = device
self.dtype = dtype
self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
self.graphcache: CUDAGraphCacheABC = self.init_cache()
def _handle_request(self, request: T2SRequest):
with self.device:
decoder = self.decoder_model
session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
batch_idx = torch.arange(session.bsz)
t1 = 0.0
infer_speed = 0.0
torch_profiler = TorchProfiler(request.debug)
with torch_profiler.profiler():
for idx in tqdm(range(1500)):
if idx == 0:
session.kv_cache = decoder.init_cache(session.bsz)
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
else:
if request.use_cuda_graph and session.graph is None and torch.cuda.is_available():
self.graphcache.assign_graph(session)
with torch_profiler.record("AR"):
if session.graph:
assert session.stream
session.stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(session.stream):
session.xy_pos_.copy_(session.xy_pos)
session.graph.replay()
xy_dec = session.xy_dec_.clone()
else:
args, kwds = decoder.pre_forward(session)
xy_dec = decoder.h(
session.input_pos,
session.xy_pos,
session.kv_cache,
*args,
**kwds,
)
with torch.cuda.stream(session.stream) if session.stream is not None else contextlib.nullcontext():
decoder.post_forward(idx, session)
logits = decoder.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
logits[:, -1] = float("-inf")
with torch_profiler.record("Sampling"):
samples = session.sample(
logits=logits,
previous_tokens=session.y,
input_pos=session.input_pos,
top_k=request.top_k,
top_p=request.top_p,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
)
session.y[batch_idx, session.y_len + idx] = samples
session.input_pos.add_(1)
with torch_profiler.record("EOS"):
argmax_token = torch.argmax(logits, dim=-1)
sample_token = samples.squeeze(1)
EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
newly_done_mask = EOS_mask & (~session.completed)
newly_done_indices = newly_done_mask.nonzero()
if newly_done_indices.numel() > 0:
for i in newly_done_indices:
print(i, i.shape, newly_done_indices, newly_done_indices.shape)
session.y_results[i] = session.y[i, session.y_len : session.y_len + idx]
session.completed[newly_done_indices] = True
if torch.all(session.completed).item():
if session.y.sum() == 0:
session.y_results = [torch.tensor(0) for _ in range(session.bsz)]
tqdm.write("Bad Zero Prediction")
else:
tqdm.write(
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n{[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
)
tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
infer_speed = (idx - 1) / (time.perf_counter() - t1)
break
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == 1499:
for i in range(session.bsz):
if not session.completed[i].item():
session.y_results[i] = session.y[i, session.y_len : session.y_len + 1499]
session.completed[i] = True
break
with torch_profiler.record("NextPos"):
y_emb = decoder.ar_audio_embedding(samples)
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
if idx == 1:
torch_profiler.start()
t1 = time.perf_counter()
if idx == 51:
torch_profiler.end()
if idx % 100 == 0:
match session.device.type:
case "cuda":
torch.cuda.empty_cache()
case "mps":
torch.mps.empty_cache()
case "xpu":
torch.xpu.empty_cache()
case "mtia":
torch.mtia.empty_cache()
match session.device.type:
case "cuda":
if session.stream is not None:
torch.cuda.current_stream().wait_stream(session.stream)
torch.cuda.empty_cache()
case "mps":
torch.mps.empty_cache()
case "xpu":
torch.xpu.empty_cache()
case "mtia":
torch.mtia.empty_cache()
case "cpu":
gc.collect()
torch_profiler.end()
if request.use_cuda_graph and torch.cuda.is_available():
self.graphcache.release_graph(session)
return session.y_results[: request.valid_length], infer_speed
def generate(self, request: T2SRequest):
try:
result, infer_speed = self._handle_request(request)
t2s_result = T2SResult(result=result, infer_speed=infer_speed, status="Success")
except Exception as e:
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
return t2s_result
@staticmethod
def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "flash_attn"):
print(f"Loading Text2Semantic Weights from {weights_path} with {backend.replace('_', ' ').title()} Backend")
module_path = f".backends.t2s_model_{backend.lower()}"
decoder_cls_name = "T2SDecoder"
decoder_mod = import_module(module_path, package=__package__)
decoder_cls: Type[T2SDecoderABC] = getattr(decoder_mod, decoder_cls_name)
dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
config = dict_s1["config"]
decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
state_dict = dict_s1["weight"]
decoder.load_state_dict(state_dict)
return decoder.eval()
def init_cache(self):
assert self.decoder_model
module_name = self.decoder_model.__class__.__module__
module = sys.modules.get(module_name)
assert module
target_class: Type[CUDAGraphCacheABC] = getattr(module, "CUDAGraphCache")
return target_class(self.decoder_model)

View File

@ -0,0 +1,652 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""
from __future__ import annotations
import math
import os
import random
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import MutableSequence, Type
import torch
import torch._inductor.config
import torch.nn.functional as F
from torch.cuda.graphs import CUDAGraph
from torch.profiler import ProfilerAction, tensorboard_trace_handler
from . import nn
from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession
Tensor = torch.Tensor
class TokenEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
vocab_size: int,
dropout: float = 0.0,
):
super().__init__()
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.dropout = nn.Dropout(p=dropout)
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
@property
def weight(self) -> Tensor:
return self.word_embeddings.weight
def embedding(self, index: int) -> Tensor:
return self.word_embeddings.weight[index : index + 1]
def __call__(self, x: Tensor):
x = self.word_embeddings(x)
x = self.dropout(x)
return x
class SinePositionalEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
dropout: float = 0.0,
scale: bool = False,
alpha: bool = False,
max_batch_size: int = 10,
max_seq_len: int = 1800,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
self.dropout = nn.Dropout(p=dropout)
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.reverse = False
self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
self.pe: torch.Tensor
self.compute_pe()
def compute_pe(self):
"""Reset the positional encodings."""
if self.reverse:
position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
else:
position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
)
pe = self.pe
pe[:, :, 0::2] = torch.sin(position * div_term)
pe[:, :, 1::2] = torch.cos(position * div_term)
def __call__(self, input_pos: Tensor, x: Tensor) -> Tensor:
"""
Args:
input_pos (Tensor): [batch_size, ]
x (Tensor): [batch_size, 1, embed_dim]
Returns:
embedded_x (Tensor): [batch_size, 1, embed_dim]
"""
batch_size = x.shape[0]
pe_values = self.pe[torch.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
return x * self.x_scale + self.alpha * pe_values.unsqueeze(1) # (batch_size, 1, embed_dim)
def prefill(self, x: Tensor) -> Tensor:
"""
Args:
x (Tensor): [batch_size, seq_len, embed_dim]
Returns:
embedded_x (Tensor): [batch_size, seq_len, embed_dim]
"""
pe_values = self.pe[:, : x.shape[-2]]
return x * self.x_scale + self.alpha.item() * pe_values
class KVCacheABC(nn.Module, ABC, KVCacheProtocol):
def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None:
super().__init__()
self.n_head = n_heads
self.head_dim = head_dim
self.batch_size = batch_size
self.max_seq_length = max_seq_length
self.k_cache: Tensor
self.v_cache: Tensor
def empty(self):
self.k_cache.zero_()
self.v_cache.zero_()
@abstractmethod
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> tuple[Tensor, Tensor]: ...
@abstractmethod
def prefill_kv(self, k_val: Tensor, v_val: Tensor) -> None: ...
def sync_cache(self, kv_cache: KVCacheProtocol):
self.k_cache.copy_(kv_cache.k_cache)
self.v_cache.copy_(kv_cache.v_cache)
class KVCacheNHD(KVCacheABC):
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
super().__init__(batch_size, max_seq_length, n_heads, head_dim)
assert batch_size > 0
cache_shape = (batch_size, max_seq_length, n_heads, head_dim)
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
# input_pos: [B, ], k_val: [B, 1, H, D]
index = (
(input_pos - 1)
.unsqueeze(-1)
.unsqueeze(-1)
.unsqueeze(-1)
.expand(
-1,
-1,
self.n_head,
self.head_dim,
)
.to(torch.int64)
) # (bs, 1, num_head, head_dim)
k_out = self.k_cache
v_out = self.v_cache
k_out.scatter_(1, index, k_val)
v_out.scatter_(1, index, v_val)
return k_out, v_out
def empty(self):
self.k_cache.zero_()
self.v_cache.zero_()
def prefill_kv(self, k_val: Tensor, v_val: Tensor):
# input_pos: int, k_val: [B, S, H, D]
self.k_cache[:, : k_val.shape[1]] = k_val
self.v_cache[:, : v_val.shape[1]] = v_val
class KVCacheHND(KVCacheABC):
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
super().__init__(batch_size, max_seq_length, n_heads, head_dim)
cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
# input_pos: [B, ], k_val: [B, H, 1, D]
index = (
(input_pos - 1)
.unsqueeze(-1)
.unsqueeze(-1)
.unsqueeze(-1)
.expand(
-1,
self.n_head,
-1,
self.head_dim,
)
.to(torch.int64)
) # (bs, num_head, 1, head_dim)
k_out = self.k_cache
v_out = self.v_cache
k_out.scatter_(2, index, k_val)
v_out.scatter_(2, index, v_val)
return k_out, v_out
def empty(self):
self.k_cache.zero_()
self.v_cache.zero_()
def prefill_kv(self, k_val: Tensor, v_val: Tensor):
# input_pos: int, k_val: [B, S, H, D]
self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)
class AttentionABC(nn.Module, ABC):
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
super().__init__()
self.n_head = n_head
self.hidden_dim = hidden_dim
assert hidden_dim % n_head == 0
self.head_dim = hidden_dim // n_head
self.max_seq_length = max_seq_length
# key, query, value projections for all heads, but in a batch
self.in_proj: nn.Linear
self.out_proj: nn.Linear
self.dropout = nn.Dropout(0.1)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
keys_to_modify = [key for key in state_dict if "in_proj_" in key]
for key in keys_to_modify:
new_key = key.replace("in_proj_", "in_proj.") # in_proj_ -> in_proj.
state_dict[new_key] = state_dict.pop(key)
@abstractmethod
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds) -> Tensor: ...
def prefill(self, x: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor) -> Tensor:
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x.unsqueeze(0)).chunk(3, dim=-1)
q, k, v = map(lambda x: x.contiguous().view(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
kv_cache.prefill_kv(k, v)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
attn = self.dropout(attn)
attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)
output = self.out_proj(attn)
return output
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int) -> None:
super().__init__()
self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
self.dropout = nn.Dropout(0.1)
def __call__(self, x: Tensor):
return self.dropout(self.linear2(self.dropout(F.relu(self.linear1(x)))))
class TransformerBlockABC(nn.Module, ABC):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.max_seq_length = max_seq_length
self.attention: AttentionABC
self.feed_forward: FeedForward
self.attention_norm: nn.LayerNorm
self.ffn_norm: nn.LayerNorm
self.dropout = nn.Dropout(0.1)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
for key in list(state_dict.keys()):
new_key = (
key.replace("self_attn", "attention")
.replace("linear", "feed_forward.linear")
.replace("norm1", "attention_norm")
.replace("norm2", "ffn_norm")
)
state_dict[new_key] = state_dict.pop(key)
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds):
h = self.attention_norm(
x
+ self.dropout(
self.attention(
x,
input_pos,
kv_cache,
*args,
**kwds,
)
)
)
out = self.ffn_norm(h + self.feed_forward(h))
return out
def prefill(
self,
x: Tensor,
kv_cache: KVCacheProtocol,
attn_mask: Tensor,
) -> Tensor:
h = self.attention_norm(
x
+ self.dropout(
self.attention.prefill(
x,
kv_cache,
attn_mask,
)
)
)
out = self.ffn_norm(h + self.feed_forward(h))
return out
class TransformerDecoderABC(nn.Module, ABC):
def __init__(
self,
hidden_dim: int,
n_layer: int,
n_head: int,
ffn_dim: int,
vocab_size: int,
max_seq_length: int,
max_batch_size: int,
) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.n_head = n_head
assert hidden_dim % n_head == 0
self.head_dim = hidden_dim // n_head
self.vocab_size = vocab_size
self.n_layer = n_layer
self.layers: MutableSequence[TransformerBlockABC]
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
def __call__(self, input_pos: Tensor, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds):
for layer, kv_cache in zip(self.layers, kv_caches):
x = layer(x, input_pos, kv_cache, *args, **kwds)
return x
def prefill(self, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], attn_mask: Tensor):
for layer, kv_cache in zip(self.layers, kv_caches):
x = layer.prefill(x, kv_cache, attn_mask)
return x
class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
def __init__(
self,
config: dict,
max_seq_length: int = 1800,
max_batch_size: int = 10,
) -> None:
super().__init__()
hidden_dim: int = config["model"]["hidden_dim"]
embedding_dim: int = config["model"]["embedding_dim"]
n_head: int = config["model"]["head"]
n_layer: int = config["model"]["n_layer"]
vocab_size: int = config["model"]["vocab_size"]
phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
p_dropout: float = config["model"]["dropout"]
EOS: int = config["model"]["EOS"]
ffn_dim: int = hidden_dim * 4
self.n_layer = int(n_layer)
self.hidden_dim = int(hidden_dim)
self.n_head = int(n_head)
assert hidden_dim % n_head == 0
self.head_dim = int(hidden_dim // n_head)
self.embedding_dim = int(embedding_dim)
self.ffn_dim = int(ffn_dim)
self.vocab_size = int(vocab_size)
self.phoneme_vocab_size = int(phoneme_vocab_size)
self.p_dropout = float(p_dropout)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
self.EOS = EOS
assert self.EOS == self.vocab_size - 1
self.bert_proj: nn.Linear
self.ar_predict_layer: nn.Linear
self.h: TransformerDecoderABC
self.kv_class: Type[KVCacheNHD] | Type[KVCacheHND]
self.GraphCache: CUDAGraphCacheABC | None
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size, self.p_dropout)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim,
dropout=0.1,
scale=False,
alpha=True,
max_batch_size=max_batch_size,
max_seq_len=max_seq_length,
)
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size, self.p_dropout)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim,
dropout=0.1,
scale=False,
alpha=True,
max_batch_size=max_batch_size,
max_seq_len=max_seq_length,
)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
model_keys = [key for key in state_dict if key.startswith("model.")]
for key in model_keys:
new_key = key[len("model.") :]
state_dict[new_key] = state_dict.pop(key)
def init_cache(self, bsz: int = 0) -> MutableSequence[KVCacheProtocol]:
bsz = bsz or self.h.max_batch_size
assert bsz <= self.h.max_batch_size
seq_lens = self.h.max_seq_length
dtype = self.bert_proj.bias.dtype
kvclass = self.kv_class
return nn.ModuleList(
[kvclass(bsz, seq_lens, self.n_head, self.head_dim) for _ in range(self.n_layer)],
).to(self.device, dtype) # type: ignore
def embed(
self,
x: list[torch.Tensor],
y: torch.Tensor,
bert_features: list[torch.Tensor],
):
x_len: list[int] = [i.shape[0] for i in x]
x_len_max = max(x_len)
xy_pos = torch.zeros((len(x), x_len_max + y.shape[1], self.embedding_dim)).to(bert_features[0].dtype)
bert_features = list(map(lambda x: x.transpose(0, 1), bert_features))
y_len = y.shape[1]
y_emb = self.ar_audio_embedding(y)
y_pos = self.ar_audio_position.prefill(y_emb)
for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
x_emb = self.ar_text_embedding(x_)
bert = self.bert_proj(bert_feature)
x_emb = x_emb + bert
x_pos = self.ar_text_position.prefill(x_emb.unsqueeze(0))
xy_pos[[bs], :len_] = x_pos
xy_pos[[bs], len_ : len_ + y_len] = y_pos
return xy_pos
def compile(self, *args, **kwds):
# Experimental features to reduce compilation times, will be on by default in future
torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True
torch._inductor.config.triton.cudagraph_trees = True
torch._inductor.config.triton.cudagraph_support_input_mutation = True
self.h.compile(fullgraph=True, mode="reduce-overhead")
def capture(
self, input_pos: Tensor, x: Tensor, x_dec: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds
) -> CUDAGraph:
assert torch.cuda.is_available()
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
graph = torch.cuda.CUDAGraph()
with torch.cuda.stream(s): # type: ignore
for _ in range(5):
self.h(input_pos, x, kv_caches, *args, **kwds)
torch.cuda.current_stream().wait_stream(s)
with torch.cuda.graph(graph):
x_dec.copy_(self.h(input_pos, x, kv_caches, *args, **kwds))
torch.cuda.synchronize()
return graph
@abstractmethod
def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
return list(), dict()
@abstractmethod
def post_forward(self, idx: int, session: T2SSession) -> None:
return
class CUDAGraphCacheABC(ABC):
def __init__(
self,
decoder: T2SDecoderABC,
) -> None:
if torch.cuda.is_available():
self.device: torch.device = decoder.device
self.dtype = decoder.bert_proj.bias.dtype
self.assigned: bool = False
self.decoder: T2SDecoderABC = decoder
self.kv_cache: MutableSequence[KVCacheProtocol] = decoder.init_cache(decoder.max_batch_size)
self.xy_pos = torch.rand(size=(decoder.max_batch_size, 1, decoder.embedding_dim), device=self.device).to(
self.dtype
)
self.xy_dec = self.xy_pos.clone()
self.input_pos = torch.tensor([10] * decoder.max_batch_size, device=self.device).int()
self.graph: torch.cuda.CUDAGraph | None = None
self.stream: torch.cuda.Stream | None
self.id: int = random.randint(1, 2**32 - 1)
def assign_graph(self, session: T2SSession):
if self.graph is None:
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
self.graph = graph
self.stream = torch.cuda.Stream() # type: ignore
if self.assigned is False:
self.get_cache_graph(session)
session.id = self.id
self.assigned = True
else:
self.capture_new_graph(session)
@abstractmethod
def release_graph(self, session: T2SSession): ...
@abstractmethod
def get_cache_graph(self, session: T2SSession):
pass
@abstractmethod
def capture_new_graph(self, session: T2SSession):
pass
class TorchProfiler:
def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
self.debug = debug
self.log_dir = log_dir
self.__profiler: torch.profiler.profile
if self.debug and not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
self.tensorboard_handler = tensorboard_trace_handler(self.log_dir)
def profiler_callback(self, prof: torch.profiler.profile):
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
self.tensorboard_handler(prof)
@staticmethod
def three_step_schedule(step: int) -> ProfilerAction:
if step == 0:
return ProfilerAction.NONE
elif step == 1:
return ProfilerAction.RECORD
elif step == 2:
return ProfilerAction.RECORD_AND_SAVE
else:
return ProfilerAction.NONE
def start(self):
if not self.debug:
return
assert self.__profiler is not None
self.__profiler.step()
def end(self):
if not self.debug:
return
assert self.__profiler is not None
self.__profiler.step()
def profiler(self):
if self.debug:
activities_list = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
activities_list.append(torch.profiler.ProfilerActivity.CUDA)
self.__profiler = torch.profiler.profile(
activities=activities_list,
record_shapes=True,
with_stack=True,
with_modules=True,
profile_memory=True,
schedule=self.three_step_schedule,
on_trace_ready=self.profiler_callback,
)
return self.__profiler
else:
return nullcontext()
def record(self, func_name: str):
if self.debug:
return torch.profiler.record_function(func_name)
else:
return nullcontext()

View File

@ -0,0 +1,6 @@
from . import MLX, PyTorch
from .PyTorch import T2SEngineTorch, T2SRequest, T2SResult
backends = PyTorch.backends + MLX.backends
__all__ = ["T2SEngineTorch", "T2SRequest", "T2SResult", "backends", "MLX", "PyTorch"]

File diff suppressed because it is too large Load Diff

View File

@ -315,7 +315,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
with gr.Column():
# with gr.Group():
gr.Markdown(value=i18n("模型切换"))
with gr.Row():
with gr.Row(equal_height=True):
GPT_dropdown = gr.Dropdown(
label=i18n("GPT模型列表"),
choices=sorted(GPT_names, key=custom_sort_key),
@ -331,18 +331,22 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
with gr.Row():
with gr.Row(equal_height=True):
with gr.Column():
gr.Markdown(value=i18n("*请上传并填写参考信息"))
with gr.Row():
inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频超过会报错)"), type="filepath")
with gr.Row(equal_height=True):
inp_ref = gr.Audio(
label=i18n("主参考音频(请上传3~10秒内参考音频超过会报错)"),
type="filepath",
waveform_options={"show_recording_waveform": False},
)
inp_refs = gr.File(
label=i18n("辅参考音频(可选多个,或不选)"),
file_count="multiple",
visible=True if model_version != "v3" else False,
)
prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
with gr.Row():
with gr.Row(equal_height=True):
prompt_language = gr.Dropdown(
label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
)
@ -368,26 +372,26 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
with gr.Group():
gr.Markdown(value=i18n("推理设置"))
with gr.Row():
with gr.Row(equal_height=True):
with gr.Column():
with gr.Row():
with gr.Row(equal_height=True):
batch_size = gr.Slider(
minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
)
sample_steps = gr.Radio(
label=i18n("采样步数(仅对V3/4生效)"), value=32, choices=[4, 8, 16, 32, 64, 128], visible=True
)
with gr.Row():
with gr.Row(equal_height=True):
fragment_interval = gr.Slider(
minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
)
speed_factor = gr.Slider(
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
)
with gr.Row():
with gr.Row(equal_height=True):
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
with gr.Row():
with gr.Row(equal_height=True):
temperature = gr.Slider(
minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
)
@ -396,7 +400,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
)
with gr.Column():
with gr.Row():
with gr.Row(equal_height=True):
how_to_cut = gr.Dropdown(
label=i18n("怎么切"),
choices=[
@ -415,7 +419,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
)
with gr.Row():
with gr.Row(equal_height=True):
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
split_bucket = gr.Checkbox(
label=i18n("数据分桶(并行推理时会降低一点计算量)"),
@ -424,12 +428,15 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
show_label=True,
)
with gr.Row():
with gr.Row(equal_height=True):
seed = gr.Number(label=i18n("随机种子"), value=-1)
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
output = gr.Audio(label=i18n("输出的语音"))
with gr.Row():
output = gr.Audio(
label=i18n("输出的语音"),
waveform_options={"show_recording_waveform": False},
)
with gr.Row(equal_height=True):
inference_button = gr.Button(i18n("合成语音"), variant="primary")
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
@ -485,7 +492,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
)
)
with gr.Row():
with gr.Row(equal_height=True):
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
with gr.Column():
_how_to_cut = gr.Radio(

View File

@ -1,40 +1,41 @@
import logging
import re
from pathlib import Path
# jieba静音
import fast_langdetect
import jieba
from split_lang import LangSplitter
jieba.setLogLevel(logging.CRITICAL)
# 更改fast_langdetect大模型位置
from pathlib import Path
import fast_langdetect
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
from split_lang import LangSplitter
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(
fast_langdetect.infer.LangDetectConfig(
cache_dir=str(Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect")
)
)
def full_en(text):
pattern = r'^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$'
pattern = r"^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$"
return bool(re.match(pattern, text))
def full_cjk(text):
# 来自wiki
cjk_ranges = [
(0x4E00, 0x9FFF), # CJK Unified Ideographs
(0x3400, 0x4DB5), # CJK Extension A
(0x20000, 0x2A6DD), # CJK Extension B
(0x2A700, 0x2B73F), # CJK Extension C
(0x2B740, 0x2B81F), # CJK Extension D
(0x2B820, 0x2CEAF), # CJK Extension E
(0x2CEB0, 0x2EBEF), # CJK Extension F
(0x30000, 0x3134A), # CJK Extension G
(0x31350, 0x323AF), # CJK Extension H
(0x2EBF0, 0x2EE5D), # CJK Extension H
(0x4E00, 0x9FFF), # CJK Unified Ideographs
(0x3400, 0x4DB5), # CJK Extension A
(0x20000, 0x2A6DD), # CJK Extension B
(0x2A700, 0x2B73F), # CJK Extension C
(0x2B740, 0x2B81F), # CJK Extension D
(0x2B820, 0x2CEAF), # CJK Extension E
(0x2CEB0, 0x2EBEF), # CJK Extension F
(0x30000, 0x3134A), # CJK Extension G
(0x31350, 0x323AF), # CJK Extension H
(0x2EBF0, 0x2EE5D), # CJK Extension H
]
pattern = r'[0-9、-〜。!?.!?… /]+$'
pattern = r"[0-9、-〜。!?.!?… /]+$"
cjk_text = ""
for char in text:
@ -45,7 +46,7 @@ def full_cjk(text):
return cjk_text
def split_jako(tag_lang,item):
def split_jako(tag_lang, item):
if tag_lang == "ja":
pattern = r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)"
else:
@ -53,41 +54,42 @@ def split_jako(tag_lang,item):
lang_list: list[dict] = []
tag = 0
for match in re.finditer(pattern, item['text']):
for match in re.finditer(pattern, item["text"]):
if match.start() > tag:
lang_list.append({'lang':item['lang'],'text':item['text'][tag:match.start()]})
lang_list.append({"lang": item["lang"], "text": item["text"][tag : match.start()]})
tag = match.end()
lang_list.append({'lang':tag_lang,'text':item['text'][match.start():match.end()]})
lang_list.append({"lang": tag_lang, "text": item["text"][match.start() : match.end()]})
if tag < len(item['text']):
lang_list.append({'lang':item['lang'],'text':item['text'][tag:len(item['text'])]})
if tag < len(item["text"]):
lang_list.append({"lang": item["lang"], "text": item["text"][tag : len(item["text"])]})
return lang_list
def merge_lang(lang_list, item):
if lang_list and item['lang'] == lang_list[-1]['lang']:
lang_list[-1]['text'] += item['text']
if lang_list and item["lang"] == lang_list[-1]["lang"]:
lang_list[-1]["text"] += item["text"]
else:
lang_list.append(item)
return lang_list
class LangSegmenter():
class LangSegmenter:
# 默认过滤器, 基于gsv目前四种语言
DEFAULT_LANG_MAP = {
"zh": "zh",
"yue": "zh", # 粤语
"wuu": "zh", # 吴语
"zh-cn": "zh",
"zh-tw": "x", # 繁体设置为x
"zh-tw": "x", # 繁体设置为x
"ko": "ko",
"ja": "ja",
"en": "en",
}
def getTexts(text,default_lang = ""):
@staticmethod
def getTexts(text, default_lang=""):
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
lang_splitter.merge_across_digit = False
substr = lang_splitter.split_by_lang(text=text)
@ -97,31 +99,31 @@ class LangSegmenter():
have_num = False
for _, item in enumerate(substr):
dict_item = {'lang':item.lang,'text':item.text}
dict_item = {"lang": item.lang, "text": item.text}
if dict_item['lang'] == 'digit':
if dict_item["lang"] == "digit":
if default_lang != "":
dict_item['lang'] = default_lang
dict_item["lang"] = default_lang
else:
have_num = True
lang_list = merge_lang(lang_list,dict_item)
lang_list = merge_lang(lang_list, dict_item)
continue
# 处理短英文被识别为其他语言的问题
if full_en(dict_item['text']):
dict_item['lang'] = 'en'
lang_list = merge_lang(lang_list,dict_item)
if full_en(dict_item["text"]):
dict_item["lang"] = "en"
lang_list = merge_lang(lang_list, dict_item)
continue
if default_lang != "":
dict_item['lang'] = default_lang
lang_list = merge_lang(lang_list,dict_item)
dict_item["lang"] = default_lang
lang_list = merge_lang(lang_list, dict_item)
continue
else:
# 处理非日语夹日文的问题(不包含CJK)
ja_list: list[dict] = []
if dict_item['lang'] != 'ja':
ja_list = split_jako('ja',dict_item)
if dict_item["lang"] != "ja":
ja_list = split_jako("ja", dict_item)
if not ja_list:
ja_list.append(dict_item)
@ -130,8 +132,8 @@ class LangSegmenter():
ko_list: list[dict] = []
temp_list: list[dict] = []
for _, ko_item in enumerate(ja_list):
if ko_item["lang"] != 'ko':
ko_list = split_jako('ko',ko_item)
if ko_item["lang"] != "ko":
ko_list = split_jako("ko", ko_item)
if ko_list:
temp_list.extend(ko_list)
@ -141,77 +143,76 @@ class LangSegmenter():
# 未存在非日韩文夹日韩文
if len(temp_list) == 1:
# 未知语言检查是否为CJK
if dict_item['lang'] == 'x':
cjk_text = full_cjk(dict_item['text'])
if dict_item["lang"] == "x":
cjk_text = full_cjk(dict_item["text"])
if cjk_text:
dict_item = {'lang':'zh','text':cjk_text}
lang_list = merge_lang(lang_list,dict_item)
dict_item = {"lang": "zh", "text": cjk_text}
lang_list = merge_lang(lang_list, dict_item)
else:
lang_list = merge_lang(lang_list,dict_item)
lang_list = merge_lang(lang_list, dict_item)
continue
else:
lang_list = merge_lang(lang_list,dict_item)
lang_list = merge_lang(lang_list, dict_item)
continue
# 存在非日韩文夹日韩文
for _, temp_item in enumerate(temp_list):
# 未知语言检查是否为CJK
if temp_item['lang'] == 'x':
cjk_text = full_cjk(temp_item['text'])
if temp_item["lang"] == "x":
cjk_text = full_cjk(temp_item["text"])
if cjk_text:
lang_list = merge_lang(lang_list,{'lang':'zh','text':cjk_text})
lang_list = merge_lang(lang_list, {"lang": "zh", "text": cjk_text})
else:
lang_list = merge_lang(lang_list,temp_item)
lang_list = merge_lang(lang_list, temp_item)
else:
lang_list = merge_lang(lang_list,temp_item)
lang_list = merge_lang(lang_list, temp_item)
# 有数字
if have_num:
temp_list = lang_list
lang_list = []
for i, temp_item in enumerate(temp_list):
if temp_item['lang'] == 'digit':
if temp_item["lang"] == "digit":
if default_lang:
temp_item['lang'] = default_lang
temp_item["lang"] = default_lang
elif lang_list and i == len(temp_list) - 1:
temp_item['lang'] = lang_list[-1]['lang']
temp_item["lang"] = lang_list[-1]["lang"]
elif not lang_list and i < len(temp_list) - 1:
temp_item['lang'] = temp_list[1]['lang']
temp_item["lang"] = temp_list[1]["lang"]
elif lang_list and i < len(temp_list) - 1:
if lang_list[-1]['lang'] == temp_list[i + 1]['lang']:
temp_item['lang'] = lang_list[-1]['lang']
elif lang_list[-1]['text'][-1] in [",",".","!","?","","","",""]:
temp_item['lang'] = temp_list[i + 1]['lang']
elif temp_list[i + 1]['text'][0] in [",",".","!","?","","","",""]:
temp_item['lang'] = lang_list[-1]['lang']
elif temp_item['text'][-1] in ["","."]:
temp_item['lang'] = lang_list[-1]['lang']
elif len(lang_list[-1]['text']) >= len(temp_list[i + 1]['text']):
temp_item['lang'] = lang_list[-1]['lang']
if lang_list[-1]["lang"] == temp_list[i + 1]["lang"]:
temp_item["lang"] = lang_list[-1]["lang"]
elif lang_list[-1]["text"][-1] in [",", ".", "!", "?", "", "", "", ""]:
temp_item["lang"] = temp_list[i + 1]["lang"]
elif temp_list[i + 1]["text"][0] in [",", ".", "!", "?", "", "", "", ""]:
temp_item["lang"] = lang_list[-1]["lang"]
elif temp_item["text"][-1] in ["", "."]:
temp_item["lang"] = lang_list[-1]["lang"]
elif len(lang_list[-1]["text"]) >= len(temp_list[i + 1]["text"]):
temp_item["lang"] = lang_list[-1]["lang"]
else:
temp_item['lang'] = temp_list[i + 1]['lang']
temp_item["lang"] = temp_list[i + 1]["lang"]
else:
temp_item['lang'] = 'zh'
lang_list = merge_lang(lang_list,temp_item)
temp_item["lang"] = "zh"
lang_list = merge_lang(lang_list, temp_item)
# 筛X
temp_list = lang_list
lang_list = []
for _, temp_item in enumerate(temp_list):
if temp_item['lang'] == 'x':
if temp_item["lang"] == "x":
if lang_list:
temp_item['lang'] = lang_list[-1]['lang']
temp_item["lang"] = lang_list[-1]["lang"]
elif len(temp_list) > 1:
temp_item['lang'] = temp_list[1]['lang']
temp_item["lang"] = temp_list[1]["lang"]
else:
temp_item['lang'] = 'zh'
temp_item["lang"] = "zh"
lang_list = merge_lang(lang_list,temp_item)
lang_list = merge_lang(lang_list, temp_item)
return lang_list
if __name__ == "__main__":
text = "MyGO?,你也喜欢まいご吗?"
@ -221,5 +222,5 @@ if __name__ == "__main__":
print(LangSegmenter.getTexts(text))
text = "当时ThinkPad T60刚刚发布一同推出的还有一款名为Advanced Dock的扩展坞配件。这款扩展坞通过连接T60底部的插槽扩展出包括PCIe在内的一大堆接口并且自带电源让T60可以安装桌面显卡来提升性能。"
print(LangSegmenter.getTexts(text,"zh"))
print(LangSegmenter.getTexts(text))
print(LangSegmenter.getTexts(text, "zh"))
print(LangSegmenter.getTexts(text))

View File

@ -248,13 +248,13 @@ if you want to switch to V1,then double-click`go-webui-v1.bat` or use `go-webui-
#### Others
```bash
python webui.py <language(optional)>
PYTHONPATH=. python webui.py <language(optional)>
```
if you want to switch to V1,then
```bash
python webui.py v1 <language(optional)>
PYTHONPATH=. python webui.py v1 <language(optional)>
```
Or maunally switch version in WebUI
@ -285,7 +285,7 @@ python GPT_SoVITS/inference_webui.py <language(optional)>
OR
```bash
python webui.py
PYTHONPATH=. python webui.py
```
then open the inference webui at `1-GPT-SoVITS-TTS/1C-inference`

View File

@ -161,7 +161,7 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, flo
is_16_series = bool(re.search(r"16\d{2}", name)) and sm_version == 7.5
if mem_gb < 4 or sm_version < 5.3:
return cpu, torch.float32, 0.0, 0.0
if sm_version == 6.1 or is_16_series == True:
if sm_version == 6.1 or is_16_series is True:
return cuda, torch.float32, sm_version, mem_gb
if sm_version > 6.1:
return cuda, torch.float16, sm_version, mem_gb
@ -216,3 +216,22 @@ class Config:
self.webui_port_subfix = webui_port_subfix
self.api_port = api_port
def get_implement(device: torch.device):
if torch.cuda.is_available():
idx = device.index
capability = torch.cuda.get_device_capability(idx)
major, minor = capability
sm_version = major + minor / 10.0
if sm_version >= 7.5:
return "flash_attn"
else:
if sys.platform == "linux":
return "sage_attn"
else:
return "naive"
elif torch.mps.is_available():
return "mlx"
else:
return "naive"

View File

@ -236,13 +236,13 @@ D:\GPT-SoVITS\xxx/xxx.wav|xxx|zh|我爱玩原神.
#### 其他
```bash
python webui.py <language(optional)>
PYTHONPATH=. python webui.py <language(optional)>
```
若想使用 V1,则
```bash
python webui.py v1 <language(optional)>
PYTHONPATH=. python webui.py v1 <language(optional)>
```
或者在 webUI 内动态切换
@ -273,7 +273,7 @@ python GPT_SoVITS/inference_webui.py <language(optional)>
或者
```bash
python webui.py
PYTHONPATH=. python webui.py
```
然后在 `1-GPT-SoVITS-TTS/1C-推理` 中打开推理 webUI

View File

@ -222,13 +222,13 @@ V1 に切り替えたい場合は、`go-webui-v1.bat`をダブルクリックす
#### その他
```bash
python webui.py <言語(オプション)>
PYTHONPATH=. python webui.py <言語(オプション)>
```
V1 に切り替えたい場合は
```bash
python webui.py v1 <言語(オプション)>
PYTHONPATH=. python webui.py v1 <言語(オプション)>
```
または WebUI で手動でバージョンを切り替えてください.
@ -259,7 +259,7 @@ python GPT_SoVITS/inference_webui.py <言語(オプション)>
または
```bash
python webui.py
PYTHONPATH=. python webui.py
```
その後、`1-GPT-SoVITS-TTS/1C-inference`で推論 webui を開きます.

View File

@ -228,13 +228,13 @@ V1으로 전환하려면, `go-webui-v1.bat`을 더블 클릭하거나 `go-webui-
#### 기타
```bash
python webui.py <언어(옵션)>
PYTHONPATH=. python webui.py <언어(옵션)>
```
V1으로 전환하려면,
```bash
python webui.py v1 <언어(옵션)>
PYTHONPATH=. python webui.py v1 <언어(옵션)>
```
또는 WebUI에서 수동으로 버전을 전환하십시오.
@ -265,7 +265,7 @@ python GPT_SoVITS/inference_webui.py <언어(옵션)>
또는
```bash
python webui.py
PYTHONPATH=. python webui.py
```
그런 다음 `1-GPT-SoVITS-TTS/1C-inference`에서 추론 webui를 엽니다.

View File

@ -229,13 +229,13 @@ V1'e geçmek istiyorsanız, `go-webui-v1.bat` dosyasına çift tıklayın veya `
#### Diğerleri
```bash
python webui.py <dil(isteğe bağlı)>
PYTHONPATH=. python webui.py <dil(isteğe bağlı)>
```
V1'e geçmek istiyorsanız,
```bash
python webui.py v1 <dil(isteğe bağlı)>
PYTHONPATH=. python webui.py v1 <dil(isteğe bağlı)>
```
veya WebUI'de manuel olarak sürüm değiştirin.
@ -266,7 +266,7 @@ python GPT_SoVITS/inference_webui.py <dil(isteğe bağlı)>
VEYA
```bash
python webui.py
PYTHONPATH=. python webui.py
```
ardından çıkarım webui'sini `1-GPT-SoVITS-TTS/1C-inference` adresinde açın.

View File

@ -2,5 +2,6 @@ set "SCRIPT_DIR=%~dp0"
set "SCRIPT_DIR=%SCRIPT_DIR:~0,-1%"
cd /d "%SCRIPT_DIR%"
set "PATH=%SCRIPT_DIR%\runtime;%PATH%"
set "PYTHONPATH=%SCRIPT_DIR%"
runtime\python.exe -I webui.py zh_CN
pause

View File

@ -3,5 +3,6 @@ chcp 65001
Set-Location $PSScriptRoot
$runtimePath = Join-Path $PSScriptRoot "runtime"
$env:PATH = "$runtimePath;$env:PATH"
$env:PYTHONPATH = "$runtimePath"
& "$runtimePath\python.exe" -I "$PSScriptRoot\webui.py" zh_CN
pause

View File

@ -40,6 +40,10 @@ function Write-Info($msg) {
Write-Host "[INFO]:" -ForegroundColor Green -NoNewline
Write-Host " $msg"
}
function Write-Warning($msg) {
Write-Host "[Warning]:" -ForegroundColor Yellow -NoNewline
Write-Host " $msg"
}
function Write-Success($msg) {
Write-Host "[SUCCESS]:" -ForegroundColor Blue -NoNewline
Write-Host " $msg"
@ -137,7 +141,7 @@ chcp 65001
Set-Location $PSScriptRoot
Write-Info "Installing FFmpeg & CMake..."
Invoke-Conda ffmpeg cmake
Invoke-Conda ffmpeg cmake vc14_runtime
Write-Success "FFmpeg & CMake Installed"
$PretrainedURL = ""
@ -208,12 +212,30 @@ if ($DownloadUVR5) {
switch ($Device) {
"CU128" {
$cudaLine = nvidia-smi | Select-String "CUDA Version"
$version = ($cudaLine -split "CUDA Version:")[1].Trim()
Write-Info "Maximum CUDA Version Supported By Current Driver: $version"
if ([version](nvidia-smi | Select-String "CUDA Version" | ForEach-Object { ($_ -split "CUDA Version:")[1].Trim() }) -ge [version]"12.8") {
Write-Warning "CUDA 12.8 Is Not Supported By Current Driver"
}
Write-Info "Installing PyTorch For CUDA 12.8..."
Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cu128"
Invoke-Conda cuda-nvcc=12.8
Invoke-Pip psutil ninja packaging wheel "setuptools>=42"
Invoke-Pip flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
}
"CU126" {
$cudaLine = nvidia-smi | Select-String "CUDA Version"
$version = ($cudaLine -split "CUDA Version:")[1].Trim()
Write-Info "Maximum CUDA Version Supported By Current Driver: $version"
if ([version](nvidia-smi | Select-String "CUDA Version" | ForEach-Object { ($_ -split "CUDA Version:")[1].Trim() }) -ge [version]"12.8") {
Write-Warning "CUDA 12.6 Is Not Supported By Current Driver"
}
Write-Info "Installing PyTorch For CUDA 12.6..."
Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cu126"
Invoke-Conda cuda-nvcc=12.6
Invoke-Pip psutil ninja packaging wheel "setuptools>=42"
Invoke-Pip flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
}
"CPU" {
Write-Info "Installing PyTorch For CPU..."

View File

@ -127,7 +127,7 @@ while [[ $# -gt 0 ]]; do
USE_ROCM=true
;;
MPS)
USE_CPU=true
USE_MPS=true
;;
CPU)
USE_CPU=true
@ -157,7 +157,7 @@ while [[ $# -gt 0 ]]; do
esac
done
if ! $USE_CUDA && ! $USE_ROCM && ! $USE_CPU; then
if ! $USE_CUDA && ! $USE_ROCM && ! $USE_MPS && ! $USE_CPU; then
echo -e "${ERROR}Error: Device is REQUIRED"
echo ""
print_help
@ -322,13 +322,29 @@ if [ "$USE_ROCM" = true ] && [ "$WORKFLOW" = false ]; then
fi
if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then
CUDAVERSION=$(nvidia-smi | grep "CUDA Version" | sed -E 's/.*CUDA Version: ([0-9]+\.[0-9]+).*/\1/')
echo -e "${INFO}Maximum CUDA Version Supported By Current Driver: $CUDAVERSION"
if [ "$CUDA" = 128 ]; then
if awk "BEGIN {exit !($CUDAVERSION < 12.8)}"; then
echo -r "${WARNING}CUDA 12.8 Is Not Supported By Current Driver"
fi
echo -e "${INFO}Installing PyTorch For CUDA 12.8..."
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cu128"
run_conda_quiet cuda-nvcc=12.8
elif [ "$CUDA" = 126 ]; then
if awk "BEGIN {exit !($CUDAVERSION < 12.6)}"; then
echo -r "${WARNING}CUDA 12.6 Is Not Supported By Current Driver"
fi
echo -e "${INFO}Installing PyTorch For CUDA 12.6..."
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cu126"
run_conda_quiet cuda-nvcc=12.6
fi
run_pip_quiet psutil ninja packaging wheel "setuptools>=42"
run_pip_quiet flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
elif [ "$USE_MPS" = true ] && [ "$WORKFLOW" = false ]; then
echo -e "${INFO}Installing PyTorch For MPS..."
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cpu"
run_pip_quiet mlx mlx-lm
elif [ "$USE_ROCM" = true ] && [ "$WORKFLOW" = false ]; then
echo -e "${INFO}Installing PyTorch For ROCm 6.2..."
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/rocm6.2"

View File

@ -5,7 +5,7 @@ tensorboard
librosa==0.10.2
numba
pytorch-lightning>=2.4
gradio<5
gradio==5.25.0
ffmpeg-python
onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64"
onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64"
@ -16,9 +16,11 @@ pypinyin
pyopenjtalk>=0.4.1
g2p_en
torchaudio
modelscope==1.10.0
modelscope
sentencepiece
transformers>=4.43,<=4.50
transformers
huggingface_hub
kernels
peft
chardet
PyYAML
@ -39,7 +41,6 @@ x_transformers
torchmetrics<=1.5
pydantic<=2.10.6
ctranslate2>=4.0,<5
huggingface_hub>=0.13
tokenizers>=0.13,<1
av>=11
tqdm

View File

@ -222,5 +222,6 @@
"预训练SoVITS-D模型路径": "Pretrained SoVITS-D Model Path",
"预训练SoVITS-G模型路径": "Pretrained SoVITS-G Model Path",
"预训练中文BERT模型路径": "Pretrained Chinese BERT Model Path",
"预训练模型路径": "Pretrained Model Path"
"预训练模型路径": "Pretrained Model Path",
"推理后端": "Inference Backend"
}

View File

@ -222,5 +222,6 @@
"预训练SoVITS-D模型路径": "预训练SoVITS-D模型路径",
"预训练SoVITS-G模型路径": "预训练SoVITS-G模型路径",
"预训练中文BERT模型路径": "预训练中文BERT模型路径",
"预训练模型路径": "预训练模型路径"
"预训练模型路径": "预训练模型路径",
"推理后端": "推理后端"
}

View File

@ -1,4 +1,5 @@
import sys
from tools.i18n.i18n import I18nAuto, scan_language_list
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else "Auto"
@ -314,7 +315,7 @@ if __name__ == "__main__":
"Submit Text: 将当前页所有文本框内容手工保存到内存和文件(翻页前后或者退出标注页面前如果没点这个按钮,你再翻回来就回滚了,白忙活。)"
)
)
with gr.Row():
with gr.Row(equal_height=True):
btn_change_index = gr.Button("Change Index")
btn_submit_change = gr.Button("Submit Text")
btn_merge_audio = gr.Button("Merge Audio")
@ -322,7 +323,7 @@ if __name__ == "__main__":
btn_previous_index = gr.Button("Previous Index")
btn_next_index = gr.Button("Next Index")
with gr.Row():
with gr.Row(equal_height=True):
index_slider = gr.Slider(minimum=0, maximum=g_max_json_index, value=g_index, step=1, label="Index", scale=3)
splitpoint_slider = gr.Slider(
minimum=0, maximum=120.0, value=0, step=0.1, label="Audio Split Point(s)", scale=3
@ -331,18 +332,23 @@ if __name__ == "__main__":
btn_save_json = gr.Button("Save File", visible=True, scale=1)
btn_invert_selection = gr.Button("Invert Selection", scale=1)
with gr.Row():
with gr.Row(equal_height=True):
with gr.Column():
for _ in range(0, g_batch):
with gr.Row():
with gr.Row(equal_height=True):
text = gr.Textbox(label="Text", visible=True, scale=5)
audio_output = gr.Audio(label="Output Audio", visible=True, scale=5)
audio_output = gr.Audio(
label="Output Audio",
visible=True,
scale=5,
waveform_options={"show_recording_waveform": False},
)
audio_check = gr.Checkbox(label="Yes", show_label=True, info="Choose Audio", scale=1)
g_text_list.append(text)
g_audio_list.append(audio_output)
g_checkbox_list.append(audio_check)
with gr.Row():
with gr.Row(equal_height=True):
batchsize_slider = gr.Slider(
minimum=1, maximum=g_batch, value=g_batch, step=1, label="Batch Size", scale=3, interactive=False
)

View File

@ -168,7 +168,7 @@ with gr.Blocks(title="UVR5 WebUI", analytics_enabled=False) as app:
"h4",
)
)
with gr.Row():
with gr.Row(equal_height=True):
with gr.Column():
model_choose = gr.Dropdown(label=i18n("模型"), choices=uvr5_names)
dir_wav_input = gr.Textbox(
@ -197,9 +197,9 @@ with gr.Blocks(title="UVR5 WebUI", analytics_enabled=False) as app:
interactive=True,
)
with gr.Column():
with gr.Row():
with gr.Row(equal_height=True):
but2 = gr.Button(i18n("转换"), variant="primary")
with gr.Row():
with gr.Row(equal_height=True):
vc_output4 = gr.Textbox(label=i18n("输出信息"), lines=3)
but2.click(
uvr,

586
webui.py
View File

@ -1,22 +1,96 @@
import os
import sys
os.environ["version"] = version = "v2Pro"
now_dir = os.getcwd()
sys.path.insert(0, now_dir)
import warnings
warnings.filterwarnings("ignore")
import argparse
import contextlib
import json
import os
import platform
import re
import shutil
import signal
import site
import subprocess
import sys
import traceback
import warnings
from multiprocessing import cpu_count
from subprocess import Popen
import gradio as gr
import psutil
import torch
import yaml
now_dir = os.getcwd()
sys.path.insert(0, now_dir)
from config import (
GPU_INDEX,
GPU_INFOS,
IS_GPU,
GPT_weight_root,
GPT_weight_version2root,
SoVITS_weight_root,
SoVITS_weight_version2root,
change_choices,
exp_root,
get_weights_names,
infer_device,
is_half,
is_share,
memset,
pretrained_gpt_name,
pretrained_sovits_name,
python_exec,
webui_port_infer_tts,
webui_port_main,
webui_port_subfix,
webui_port_uvr5,
)
from GPT_SoVITS.Accelerate import backends
from tools import my_utils
from tools.asr.config import asr_dict
from tools.assets import css, js, top_html
from tools.i18n.i18n import I18nAuto, scan_language_list
from tools.my_utils import check_details, check_for_existance
_LANG_RE = re.compile(r"^[a-z]{2}[_-][A-Z]{2}$")
def lang_type(text: str) -> str:
if text == "Auto":
return text
if not _LANG_RE.match(text):
raise argparse.ArgumentTypeError(f"Unspported Format: {text}, Expected ll_CC/ll-CC")
ll, cc = re.split(r"[_-]", text)
language = f"{ll}_{cc}"
if language in scan_language_list():
return language
else:
return "en_US"
def build_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
prog="train_webui",
description="python -s webui.py zh_CN",
)
p.add_argument(
"language",
nargs="?",
default="Auto",
type=lang_type,
help="Language Code, Such as zh_CN, en-US",
)
return p
args = build_parser().parse_args()
os.environ["version"] = version = "v2Pro"
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
warnings.filterwarnings("ignore")
torch.manual_seed(233333)
tmp = os.path.join(now_dir, "TEMP")
os.makedirs(tmp, exist_ok=True)
@ -32,8 +106,6 @@ if os.path.exists(tmp):
except Exception as e:
print(str(e))
pass
import site
import traceback
site_packages_roots = []
for path in site.getsitepackages():
@ -41,7 +113,6 @@ for path in site.getsitepackages():
site_packages_roots.append(path)
if site_packages_roots == []:
site_packages_roots = ["%s/runtime/Lib/site-packages" % now_dir]
# os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["no_proxy"] = "localhost, 127.0.0.1, ::1"
os.environ["all_proxy"] = ""
for site_packages_root in site_packages_roots:
@ -56,41 +127,10 @@ for site_packages_root in site_packages_roots:
break
except PermissionError:
traceback.print_exc()
import shutil
import subprocess
from subprocess import Popen
from tools.assets import css, js, top_html
from tools.i18n.i18n import I18nAuto, scan_language_list
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else "Auto"
os.environ["language"] = language
language = args.language
i18n = I18nAuto(language=language)
from multiprocessing import cpu_count
from config import (
GPU_INDEX,
GPU_INFOS,
IS_GPU,
exp_root,
infer_device,
is_half,
is_share,
memset,
python_exec,
webui_port_infer_tts,
webui_port_main,
webui_port_subfix,
webui_port_uvr5,
)
from tools import my_utils
from tools.my_utils import check_details, check_for_existance
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 当遇到mps不支持的步骤时使用cpu
import gradio as gr
n_cpu = cpu_count()
@ -147,7 +187,7 @@ def fix_gpu_number(input): # 将越界的number强制改到界内
try:
if int(input) not in set_gpu_numbers:
return default_gpu_numbers
except:
except Exception as _:
return input
return input
@ -158,13 +198,10 @@ def fix_gpu_numbers(inputs):
for input in inputs.split(","):
output.append(str(fix_gpu_number(input)))
return ",".join(output)
except:
except Exception as _:
return inputs
from config import pretrained_gpt_name, pretrained_sovits_name
def check_pretrained_is_exist(version):
pretrained_model_list = (
pretrained_sovits_name[version],
@ -189,14 +226,6 @@ for key in pretrained_gpt_name.keys():
if os.path.exists(pretrained_gpt_name[key]) == False:
pretrained_gpt_name[key] = ""
from config import (
GPT_weight_root,
GPT_weight_version2root,
SoVITS_weight_root,
SoVITS_weight_version2root,
change_choices,
get_weights_names,
)
for root in SoVITS_weight_root + GPT_weight_root:
os.makedirs(root, exist_ok=True)
@ -218,15 +247,11 @@ def kill_proc_tree(pid, including_parent=True):
children = parent.children(recursive=True)
for child in children:
try:
with contextlib.suppress(OSError):
os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
except OSError:
pass
if including_parent:
try:
with contextlib.suppress(OSError):
os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
except OSError:
pass
system = platform.system()
@ -329,21 +354,20 @@ def change_uvr5():
process_name_tts = i18n("TTS推理WebUI")
def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, sovits_path, batched_infer_enabled):
def change_tts_inference(
bert_path, cnhubert_base_path, gpu_number, gpt_path, sovits_path, batched_infer_enabled, backends_dropdown
):
global p_tts_inference
if batched_infer_enabled:
cmd = '"%s" -s GPT_SoVITS/inference_webui_fast.py "%s"' % (python_exec, language)
cmd = f"'{python_exec}' -s GPT_SoVITS/inference_webui_fast.py {language}"
else:
cmd = '"%s" -s GPT_SoVITS/inference_webui.py "%s"' % (python_exec, language)
# #####v3暂不支持加速推理
# if version=="v3":
# cmd = '"%s" GPT_SoVITS/inference_webui.py "%s"'%(python_exec, language)
cmd = f"'{python_exec}' -s GPT_SoVITS/inference_webui.py {language} -b {backends_dropdown}"
if p_tts_inference is None:
os.environ["gpt_path"] = gpt_path
os.environ["sovits_path"] = sovits_path
os.environ["cnhubert_base_path"] = cnhubert_base_path
os.environ["bert_path"] = bert_path
os.environ["_CUDA_VISIBLE_DEVICES"] = fix_gpu_number(gpu_number)
os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_number))
os.environ["is_half"] = str(is_half)
os.environ["infer_ttswebui"] = str(webui_port_infer_tts)
os.environ["is_share"] = str(is_share)
@ -364,8 +388,6 @@ def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, so
)
from tools.asr.config import asr_dict
process_name_asr = i18n("语音识别")
@ -764,7 +786,7 @@ def close_slice():
for p_slice in ps_slice:
try:
kill_process(p_slice.pid, process_name_slice)
except:
except Exception as _:
traceback.print_exc()
ps_slice = []
return (
@ -853,7 +875,7 @@ def close1a():
for p1a in ps1a:
try:
kill_process(p1a.pid, process_name_1a)
except:
except Exception as _:
traceback.print_exc()
ps1a = []
return (
@ -944,7 +966,7 @@ def close1b():
for p1b in ps1b:
try:
kill_process(p1b.pid, process_name_1b)
except:
except Exception as _:
traceback.print_exc()
ps1b = []
return (
@ -1030,7 +1052,7 @@ def close1c():
for p1c in ps1c:
try:
kill_process(p1c.pid, process_name_1c)
except:
except Exception as _:
traceback.print_exc()
ps1c = []
return (
@ -1230,7 +1252,7 @@ def open1abc(
{"__type__": "update", "visible": True},
{"__type__": "update", "visible": False},
)
except:
except Exception as _:
traceback.print_exc()
close1abc()
yield (
@ -1252,7 +1274,7 @@ def close1abc():
for p1abc in ps1abc:
try:
kill_process(p1abc.pid, process_name_1abc)
except:
except Exception as _:
traceback.print_exc()
ps1abc = []
return (
@ -1303,6 +1325,14 @@ def sync(text):
return {"__type__": "update", "value": text}
def changeBackend(flag: bool):
if flag:
return gr.update(choices=["naive"], value="naive")
else:
return gr.update(choices=backends, value=backends[-1])
GPU_INDEX.add(0)
with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app:
gr.HTML(
top_html.format(
@ -1315,9 +1345,9 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
with gr.Tabs():
with gr.TabItem("0-" + i18n("前置数据集获取工具")): # 提前随机切片防止uvr5爆内存->uvr5->slicer->asr->打标
with gr.Accordion(label="0a-" + i18n("UVR5人声伴奏分离&去混响去延迟工具")):
with gr.Row():
with gr.Row(equal_height=True):
with gr.Column(scale=3):
with gr.Row():
with gr.Row(equal_height=True):
uvr5_info = gr.Textbox(label=process_info(process_name_uvr5, "info"))
open_uvr5 = gr.Button(
value=process_info(process_name_uvr5, "open"), variant="primary", visible=True
@ -1327,14 +1357,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
)
with gr.Accordion(label="0b-" + i18n("语音切分工具")):
with gr.Row():
with gr.Row(equal_height=True):
with gr.Column(scale=3):
with gr.Row():
with gr.Row(equal_height=True):
slice_inp_path = gr.Textbox(label=i18n("音频自动切分输入路径,可文件可文件夹"), value="")
slice_opt_root = gr.Textbox(
label=i18n("切分后的子音频的输出根目录"), value="output/slicer_opt"
)
with gr.Row():
with gr.Row(equal_height=True):
threshold = gr.Textbox(
label=i18n("threshold:音量小于这个值视作静音的备选切割点"), value="-34"
)
@ -1348,7 +1378,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
value="10",
)
max_sil_kept = gr.Textbox(label=i18n("max_sil_kept:切完后静音最多留多长"), value="500")
with gr.Row():
with gr.Row(equal_height=True):
_max = gr.Slider(
minimum=0,
maximum=1,
@ -1365,7 +1395,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
value=0.25,
interactive=True,
)
with gr.Row():
with gr.Row(equal_height=True):
n_process = gr.Slider(
minimum=1,
maximum=n_cpu,
@ -1385,10 +1415,10 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
# gr.Markdown(value="0bb-" + i18n("语音降噪工具")+i18n("(不稳定,先别用,可能劣化模型效果!)"))
with gr.Row(visible=False):
with gr.Column(scale=3):
with gr.Row():
with gr.Row(equal_height=True):
denoise_input_dir = gr.Textbox(label=i18n("输入文件夹路径"), value="")
denoise_output_dir = gr.Textbox(label=i18n("输出文件夹路径"), value="output/denoise_opt")
with gr.Row():
with gr.Row(equal_height=True):
denoise_info = gr.Textbox(label=process_info(process_name_denoise, "info"))
open_denoise_button = gr.Button(
value=process_info(process_name_denoise, "open"), variant="primary", visible=True
@ -1398,16 +1428,16 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
)
with gr.Accordion(label="0c-" + i18n("语音识别工具")):
with gr.Row():
with gr.Row(equal_height=True):
with gr.Column(scale=3):
with gr.Row():
with gr.Row(equal_height=True):
asr_inp_dir = gr.Textbox(
label=i18n("输入文件夹路径"), value="D:\\GPT-SoVITS\\raw\\xxx", interactive=True
)
asr_opt_dir = gr.Textbox(
label=i18n("输出文件夹路径"), value="output/asr_opt", interactive=True
)
with gr.Row():
with gr.Row(equal_height=True):
asr_model = gr.Dropdown(
label=i18n("ASR 模型"),
choices=list(asr_dict.keys()),
@ -1423,7 +1453,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
asr_precision = gr.Dropdown(
label=i18n("数据类型精度"), choices=["float32"], interactive=True, value="float32"
)
with gr.Row():
with gr.Row(equal_height=True):
asr_info = gr.Textbox(label=process_info(process_name_asr, "info"))
open_asr_button = gr.Button(
value=process_info(process_name_asr, "open"), variant="primary", visible=True
@ -1455,9 +1485,9 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
asr_model.change(change_precision_choices, [asr_model], [asr_precision])
with gr.Accordion(label="0d-" + i18n("语音文本校对标注工具")):
with gr.Row():
with gr.Row(equal_height=True):
with gr.Column(scale=3):
with gr.Row():
with gr.Row(equal_height=True):
path_list = gr.Textbox(
label=i18n("标注文件路径 (含文件后缀 *.list)"),
value="D:\\RVC1006\\GPT-SoVITS\\raw\\xxx.list",
@ -1478,7 +1508,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
with gr.TabItem(i18n("1-GPT-SoVITS-TTS")):
with gr.Accordion(i18n("微调模型信息")):
with gr.Row():
with gr.Row(equal_height=True):
with gr.Row(equal_height=True):
exp_name = gr.Textbox(
label=i18n("*实验/模型名"),
@ -1500,7 +1530,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
scale=5,
)
with gr.Accordion(label=i18n("预训练模型路径"), open=False):
with gr.Row():
with gr.Row(equal_height=True):
with gr.Row(equal_height=True):
pretrained_s1 = gr.Textbox(
label=i18n("预训练GPT模型路径"),
@ -1529,15 +1559,15 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
with gr.TabItem("1A-" + i18n("训练集格式化工具")):
with gr.Accordion(label=i18n("输出logs/实验名目录下应有23456开头的文件和文件夹")):
with gr.Row():
with gr.Row():
with gr.Row(equal_height=True):
with gr.Row(equal_height=True):
inp_text = gr.Textbox(
label=i18n("*文本标注文件"),
value=r"D:\RVC1006\GPT-SoVITS\raw\xxx.list",
interactive=True,
scale=10,
)
with gr.Row():
with gr.Row(equal_height=True):
inp_wav_dir = gr.Textbox(
label=i18n("*训练集音频文件目录"),
# value=r"D:\RVC1006\GPT-SoVITS\raw\xxx",
@ -1549,90 +1579,90 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
)
with gr.Accordion(label="1Aa-" + process_name_1a):
with gr.Row():
with gr.Row():
with gr.Row(equal_height=True):
with gr.Row(equal_height=True):
gpu_numbers1a = gr.Textbox(
label=i18n("GPU卡号以-分割,每个卡号一个进程"),
value="%s-%s" % (gpus, gpus),
interactive=True,
)
with gr.Row():
with gr.Row(equal_height=True):
bert_pretrained_dir = gr.Textbox(
label=i18n("预训练中文BERT模型路径"),
value="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
interactive=False,
lines=2,
)
with gr.Row():
with gr.Row(equal_height=True):
button1a_open = gr.Button(
value=process_info(process_name_1a, "open"), variant="primary", visible=True
)
button1a_close = gr.Button(
value=process_info(process_name_1a, "close"), variant="primary", visible=False
)
with gr.Row():
with gr.Row(equal_height=True):
info1a = gr.Textbox(label=process_info(process_name_1a, "info"))
with gr.Accordion(label="1Ab-" + process_name_1b):
with gr.Row():
with gr.Row():
with gr.Row(equal_height=True):
with gr.Row(equal_height=True):
gpu_numbers1Ba = gr.Textbox(
label=i18n("GPU卡号以-分割,每个卡号一个进程"),
value="%s-%s" % (gpus, gpus),
interactive=True,
)
with gr.Row():
with gr.Row(equal_height=True):
cnhubert_base_dir = gr.Textbox(
label=i18n("预训练SSL模型路径"),
value="GPT_SoVITS/pretrained_models/chinese-hubert-base",
interactive=False,
lines=2,
)
with gr.Row():
with gr.Row(equal_height=True):
button1b_open = gr.Button(
value=process_info(process_name_1b, "open"), variant="primary", visible=True
)
button1b_close = gr.Button(
value=process_info(process_name_1b, "close"), variant="primary", visible=False
)
with gr.Row():
with gr.Row(equal_height=True):
info1b = gr.Textbox(label=process_info(process_name_1b, "info"))
with gr.Accordion(label="1Ac-" + process_name_1c):
with gr.Row():
with gr.Row():
with gr.Row(equal_height=True):
with gr.Row(equal_height=True):
gpu_numbers1c = gr.Textbox(
label=i18n("GPU卡号以-分割,每个卡号一个进程"),
value="%s-%s" % (gpus, gpus),
interactive=True,
)
with gr.Row():
with gr.Row(equal_height=True):
pretrained_s2G_ = gr.Textbox(
label=i18n("预训练SoVITS-G模型路径"),
value=pretrained_sovits_name[version],
interactive=False,
lines=2,
)
with gr.Row():
with gr.Row(equal_height=True):
button1c_open = gr.Button(
value=process_info(process_name_1c, "open"), variant="primary", visible=True
)
button1c_close = gr.Button(
value=process_info(process_name_1c, "close"), variant="primary", visible=False
)
with gr.Row():
with gr.Row(equal_height=True):
info1c = gr.Textbox(label=process_info(process_name_1c, "info"))
with gr.Accordion(label="1Aabc-" + process_name_1abc):
with gr.Row():
with gr.Row():
with gr.Row(equal_height=True):
with gr.Row(equal_height=True):
button1abc_open = gr.Button(
value=process_info(process_name_1abc, "open"), variant="primary", visible=True
)
button1abc_close = gr.Button(
value=process_info(process_name_1abc, "close"), variant="primary", visible=False
)
with gr.Row():
with gr.Row(equal_height=True):
info1abc = gr.Textbox(label=process_info(process_name_1abc, "info"))
pretrained_s2G.change(sync, [pretrained_s2G], [pretrained_s2G_])
@ -1704,149 +1734,146 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
with gr.TabItem("1B-" + i18n("微调训练")):
with gr.Accordion(label="1Ba-" + i18n("SoVITS 训练: 模型权重文件在 SoVITS_weights/")):
with gr.Row():
with gr.Row(equal_height=True):
batch_size = gr.Slider(
minimum=1,
maximum=default_max_batch_size,
step=1,
label=i18n("每张显卡的batch_size"),
value=default_batch_size,
interactive=True,
)
total_epoch = gr.Slider(
minimum=1,
maximum=max_sovits_epoch,
step=1,
label=i18n("总训练轮数total_epoch不建议太高"),
value=default_sovits_epoch,
interactive=True,
)
with gr.Column(scale=2):
if_save_latest = gr.Checkbox(
label=i18n("是否仅保存最新的权重文件以节省硬盘空间"),
value=True,
interactive=True,
show_label=True,
)
if_save_every_weights = gr.Checkbox(
label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"),
value=True,
interactive=True,
show_label=True,
)
if_grad_ckpt = gr.Checkbox(
label="v3是否开启梯度检查点节省显存占用",
value=False,
interactive=True if version in v3v4set else False,
show_label=True,
visible=False,
) # 只有V3s2可以用
with gr.Row(equal_height=True):
text_low_lr_rate = gr.Slider(
minimum=0.2,
maximum=0.6,
step=0.05,
label=i18n("文本模块学习率权重"),
value=0.4,
visible=True if version not in v3v4set else False,
) # v3v4 not need
lora_rank = gr.Radio(
label=i18n("LoRA秩"),
value="32",
choices=["16", "32", "64", "128"],
visible=True if version in v3v4set else False,
) # v1v2 not need
save_every_epoch = gr.Slider(
minimum=1,
maximum=max_sovits_save_every_epoch,
step=1,
label=i18n("保存频率save_every_epoch"),
value=default_sovits_save_every_epoch,
interactive=True,
)
with gr.Column(scale=3):
gpu_numbers1Ba = gr.Textbox(
label=i18n("GPU卡号以-分割,每个卡号一个进程"),
value="%s" % (gpus),
interactive=True,
)
with gr.Row(equal_height=True):
with gr.Column():
with gr.Row():
batch_size = gr.Slider(
minimum=1,
maximum=default_max_batch_size,
step=1,
label=i18n("每张显卡的batch_size"),
value=default_batch_size,
interactive=True,
)
total_epoch = gr.Slider(
minimum=1,
maximum=max_sovits_epoch,
step=1,
label=i18n("总训练轮数total_epoch不建议太高"),
value=default_sovits_epoch,
interactive=True,
)
with gr.Row():
text_low_lr_rate = gr.Slider(
minimum=0.2,
maximum=0.6,
step=0.05,
label=i18n("文本模块学习率权重"),
value=0.4,
visible=True if version not in v3v4set else False,
) # v3v4 not need
lora_rank = gr.Radio(
label=i18n("LoRA秩"),
value="32",
choices=["16", "32", "64", "128"],
visible=True if version in v3v4set else False,
) # v1v2 not need
save_every_epoch = gr.Slider(
minimum=1,
maximum=max_sovits_save_every_epoch,
step=1,
label=i18n("保存频率save_every_epoch"),
value=default_sovits_save_every_epoch,
interactive=True,
)
with gr.Column():
with gr.Column():
if_save_latest = gr.Checkbox(
label=i18n("是否仅保存最新的权重文件以节省硬盘空间"),
value=True,
interactive=True,
show_label=True,
)
if_save_every_weights = gr.Checkbox(
label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"),
value=True,
interactive=True,
show_label=True,
)
if_grad_ckpt = gr.Checkbox(
label="v3是否开启梯度检查点节省显存占用",
value=False,
interactive=True if version in v3v4set else False,
show_label=True,
visible=False,
) # 只有V3s2可以用
with gr.Row():
gpu_numbers1Ba = gr.Textbox(
label=i18n("GPU卡号以-分割,每个卡号一个进程"),
value="%s" % (gpus),
interactive=True,
)
with gr.Row():
with gr.Row():
button1Ba_open = gr.Button(
value=process_info(process_name_sovits, "open"), variant="primary", visible=True
)
button1Ba_close = gr.Button(
value=process_info(process_name_sovits, "close"), variant="primary", visible=False
)
with gr.Row():
with gr.Column():
info1Ba = gr.Textbox(label=process_info(process_name_sovits, "info"))
with gr.Accordion(label="1Bb-" + i18n("GPT 训练: 模型权重文件在 GPT_weights/")):
with gr.Row():
with gr.Row(equal_height=True):
batch_size1Bb = gr.Slider(
minimum=1,
maximum=40,
step=1,
label=i18n("每张显卡的batch_size"),
value=default_batch_size_s1,
interactive=True,
)
total_epoch1Bb = gr.Slider(
minimum=2,
maximum=50,
step=1,
label=i18n("总训练轮数total_epoch"),
value=15,
interactive=True,
)
with gr.Column(scale=2):
if_save_latest1Bb = gr.Checkbox(
label=i18n("是否仅保存最新的权重文件以节省硬盘空间"),
value=True,
interactive=True,
show_label=True,
)
if_save_every_weights1Bb = gr.Checkbox(
label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"),
value=True,
interactive=True,
show_label=True,
)
with gr.Row(equal_height=True):
# with gr.Column():
save_every_epoch1Bb = gr.Slider(
minimum=1,
maximum=50,
step=1,
label=i18n("保存频率save_every_epoch"),
value=5,
interactive=True,
)
# with gr.Column():
if_dpo = gr.Checkbox(
label=i18n("是否开启DPO训练选项(实验性)"),
value=False,
interactive=True,
show_label=True,
)
with gr.Column(scale=2):
gpu_numbers1Bb = gr.Textbox(
label=i18n("GPU卡号以-分割,每个卡号一个进程"),
value="%s" % (gpus),
interactive=True,
)
with gr.Row(equal_height=True):
with gr.Column():
with gr.Row():
batch_size1Bb = gr.Slider(
minimum=1,
maximum=40,
step=1,
label=i18n("每张显卡的batch_size"),
value=default_batch_size_s1,
interactive=True,
with gr.Row(equal_height=True):
button1Bb_open = gr.Button(
value=process_info(process_name_gpt, "open"), variant="primary", visible=True
)
total_epoch1Bb = gr.Slider(
minimum=2,
maximum=50,
step=1,
label=i18n("总训练轮数total_epoch"),
value=15,
interactive=True,
)
with gr.Row():
save_every_epoch1Bb = gr.Slider(
minimum=1,
maximum=50,
step=1,
label=i18n("保存频率save_every_epoch"),
value=5,
interactive=True,
)
if_dpo = gr.Checkbox(
label=i18n("是否开启DPO训练选项(实验性)"),
value=False,
interactive=True,
show_label=True,
button1Bb_close = gr.Button(
value=process_info(process_name_gpt, "close"), variant="primary", visible=False
)
with gr.Column():
with gr.Column():
if_save_latest1Bb = gr.Checkbox(
label=i18n("是否仅保存最新的权重文件以节省硬盘空间"),
value=True,
interactive=True,
show_label=True,
)
if_save_every_weights1Bb = gr.Checkbox(
label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"),
value=True,
interactive=True,
show_label=True,
)
with gr.Row():
gpu_numbers1Bb = gr.Textbox(
label=i18n("GPU卡号以-分割,每个卡号一个进程"),
value="%s" % (gpus),
interactive=True,
)
with gr.Row():
with gr.Row():
button1Bb_open = gr.Button(
value=process_info(process_name_gpt, "open"), variant="primary", visible=True
)
button1Bb_close = gr.Button(
value=process_info(process_name_gpt, "close"), variant="primary", visible=False
)
with gr.Row():
info1Bb = gr.Textbox(label=process_info(process_name_gpt, "info"))
button1Ba_close.click(close1Ba, [], [info1Ba, button1Ba_open, button1Ba_close])
@ -1858,41 +1885,60 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
"选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的几个是底模体验5秒Zero Shot TTS不训练推理用。"
)
)
with gr.Row():
with gr.Row(equal_height=True):
with gr.Column(scale=2):
with gr.Row():
GPT_dropdown = gr.Dropdown(
label=i18n("GPT模型列表"),
choices=GPT_names,
value=GPT_names[-1],
interactive=True,
)
SoVITS_dropdown = gr.Dropdown(
label=i18n("SoVITS模型列表"),
choices=SoVITS_names,
value=SoVITS_names[0],
interactive=True,
)
with gr.Row(equal_height=True):
with gr.Column():
GPT_dropdown = gr.Dropdown(
label=i18n("GPT模型列表"),
choices=GPT_names,
value=GPT_names[-1],
interactive=True,
)
with gr.Column():
SoVITS_dropdown = gr.Dropdown(
label=i18n("SoVITS模型列表"),
choices=SoVITS_names,
value=SoVITS_names[0],
interactive=True,
)
with gr.Column(scale=2):
with gr.Row():
gpu_number_1C = gr.Textbox(
label=i18n("GPU卡号,只能填1个整数"), value=gpus, interactive=True
with gr.Row(equal_height=True):
gpu_number_1C = gr.Dropdown(
choices=sorted(list(GPU_INDEX)),
value=sorted(list(GPU_INDEX))[0],
label=i18n("GPU卡号,只能填1个整数"),
interactive=True,
)
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
with gr.Row(equal_height=True):
with gr.Row():
batched_infer_enabled = gr.Checkbox(
label=i18n("启用并行推理版本"), value=False, interactive=True, show_label=True
)
with gr.Row(equal_height=True):
with gr.Column():
batched_infer_enabled = gr.Checkbox(
label=i18n("启用并行推理版本"), value=False, interactive=True, show_label=True
)
with gr.Column():
backends_dropdown = gr.Dropdown(
choices=backends,
label=i18n("推理后端"),
value=backends[-1],
interactive=True,
)
with gr.Row(equal_height=True):
tts_info = gr.Textbox(label=process_info(process_name_tts, "info"))
open_tts = gr.Button(
value=process_info(process_name_tts, "open"), variant="primary", visible=True
)
close_tts = gr.Button(
value=process_info(process_name_tts, "close"), variant="primary", visible=False
)
with gr.Column():
tts_info = gr.Textbox(label=process_info(process_name_tts, "info"), scale=2)
batched_infer_enabled.change(
changeBackend,
[batched_infer_enabled],
[backends_dropdown],
)
open_tts.click(
change_tts_inference,
[
@ -1902,6 +1948,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
GPT_dropdown,
SoVITS_dropdown,
batched_infer_enabled,
backends_dropdown,
],
[tts_info, open_tts, close_tts],
)
@ -1914,6 +1961,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
GPT_dropdown,
SoVITS_dropdown,
batched_infer_enabled,
backends_dropdown,
],
[tts_info, open_tts, close_tts],
)