mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2026-07-01 11:14:06 +08:00
Draft
This commit is contained in:
parent
fdf794e31d
commit
d6d89a224d
7
.github/build_windows_packages.ps1
vendored
7
.github/build_windows_packages.ps1
vendored
@ -115,12 +115,17 @@ Remove-Item $ffDir.FullName -Recurse -Force
|
||||
Write-Host "[INFO] Installing PyTorch..."
|
||||
& ".\runtime\python.exe" -m ensurepip
|
||||
& ".\runtime\python.exe" -m pip install --upgrade pip --no-warn-script-location
|
||||
|
||||
switch ($cuda) {
|
||||
"cu124" {
|
||||
& ".\runtime\python.exe" -m pip install torch==2.6 torchaudio --index-url https://download.pytorch.org/whl/cu124 --no-warn-script-location
|
||||
& ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location
|
||||
& ".\runtime\python.exe" -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu124 --no-warn-script-location
|
||||
& ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
|
||||
}
|
||||
"cu128" {
|
||||
& ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location
|
||||
& ".\runtime\python.exe" -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128 --no-warn-script-location
|
||||
& ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
|
||||
}
|
||||
default {
|
||||
Write-Error "Unsupported CUDA version: $cuda"
|
||||
|
||||
@ -31,6 +31,15 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Windows CUDA 12.9
|
||||
if: ${{ runner.os == 'Windows' && matrix.torch_cuda == '12.8' }}
|
||||
uses: Jimver/cuda-toolkit
|
||||
id: cuda-toolkit-win-129
|
||||
with:
|
||||
cuda: 12.9.1
|
||||
method: "network"
|
||||
sub-packages: '["nvcc", "cudart", "visual_studio_integration"]'
|
||||
|
||||
- name: Run Build and Upload Script
|
||||
shell: pwsh
|
||||
run: |
|
||||
|
||||
@ -23,8 +23,10 @@ fi
|
||||
|
||||
if [ "$TARGETPLATFORM" = "linux/amd64" ]; then
|
||||
"${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-x86_64.sh
|
||||
SYSROOT_PKG="sysroot_linux-64>=2.28"
|
||||
elif [ "$TARGETPLATFORM" = "linux/arm64" ]; then
|
||||
"${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-aarch64.sh
|
||||
SYSROOT_PKG="sysroot_linux-aarch64>=2.28"
|
||||
else
|
||||
exit 1
|
||||
fi
|
||||
@ -45,20 +47,36 @@ rm miniconda.sh
|
||||
|
||||
source "$HOME/miniconda3/etc/profile.d/conda.sh"
|
||||
|
||||
"$HOME/miniconda3/bin/conda" init bash
|
||||
|
||||
source "$HOME/.bashrc"
|
||||
|
||||
"$HOME/miniconda3/bin/conda" config --add channels conda-forge
|
||||
|
||||
"$HOME/miniconda3/bin/conda" update -q --all -y 1>/dev/null
|
||||
|
||||
"$HOME/miniconda3/bin/conda" install python=3.11 -q -y
|
||||
|
||||
"$HOME/miniconda3/bin/conda" install gcc=14 gxx ffmpeg cmake make unzip -q -y
|
||||
"$HOME/miniconda3/bin/conda" install gcc=11 gxx ffmpeg cmake make unzip $SYSROOT_PKG "libstdcxx-ng>=11" -q -y
|
||||
|
||||
if [ "$CUDA_VERSION" = "12.8" ]; then
|
||||
"$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu128
|
||||
"$HOME/miniconda3/bin/conda" install cuda-nvcc=12.8 -c nvidia
|
||||
elif [ "$CUDA_VERSION" = "12.6" ]; then
|
||||
"$HOME/miniconda3/bin/pip" install torch==2.6 torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
|
||||
"$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
|
||||
"$HOME/miniconda3/bin/conda" install cuda-nvcc=12.6 -c nvidia
|
||||
fi
|
||||
|
||||
CUDA_PATH=$(echo "$HOME/miniconda3/targets/"*-linux | awk '{print $1}')
|
||||
|
||||
export CUDA_HOME=$CUDA_PATH
|
||||
export PATH="$HOME/miniconda3/bin:$PATH"
|
||||
export PATH="$CUDA_HOME/bin:$PATH"
|
||||
export PATH="$CUDA_HOME/nvvm/bin:$PATH"
|
||||
|
||||
"$HOME/miniconda3/bin/pip" install psutil ninja packaging wheel "setuptools>=42"
|
||||
"$HOME/miniconda3/bin/pip" install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
|
||||
|
||||
"$HOME/miniconda3/bin/pip" cache purge
|
||||
|
||||
rm $LOG_PATH
|
||||
|
||||
@ -1,72 +0,0 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/phonemizer.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import itertools
|
||||
import re
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
import regex
|
||||
from gruut import sentences
|
||||
from gruut.const import Sentence
|
||||
from gruut.const import Word
|
||||
from AR.text_processing.symbols import SYMBOL_TO_ID
|
||||
|
||||
|
||||
class GruutPhonemizer:
|
||||
def __init__(self, language: str):
|
||||
self._phonemizer = sentences
|
||||
self.lang = language
|
||||
self.symbol_to_id = SYMBOL_TO_ID
|
||||
self._special_cases_dict: Dict[str] = {
|
||||
r"\.\.\.": "... ",
|
||||
";": "; ",
|
||||
":": ": ",
|
||||
",": ", ",
|
||||
r"\.": ". ",
|
||||
"!": "! ",
|
||||
r"\?": "? ",
|
||||
"—": "—",
|
||||
"…": "… ",
|
||||
"«": "«",
|
||||
"»": "»",
|
||||
}
|
||||
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
|
||||
|
||||
def _normalize_punctuation(self, text: str) -> str:
|
||||
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
|
||||
text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
|
||||
text = regex.sub(r"\pZ+", r" ", text)
|
||||
return text.strip()
|
||||
|
||||
def _convert_punctuation(self, word: Word) -> str:
|
||||
if not word.phonemes:
|
||||
return ""
|
||||
if word.phonemes[0] in ["‖", "|"]:
|
||||
return word.text.strip()
|
||||
|
||||
phonemes = "".join(word.phonemes)
|
||||
# remove modifier characters ˈˌː with regex
|
||||
phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
|
||||
return phonemes.strip()
|
||||
|
||||
def phonemize(self, text: str, espeak: bool = False) -> str:
|
||||
text_to_phonemize: str = self._normalize_punctuation(text)
|
||||
sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
|
||||
words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
|
||||
return " ".join(words)
|
||||
|
||||
def transform(self, phonemes):
|
||||
# convert phonemes to ids
|
||||
# dictionary is in symbols.py
|
||||
return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
phonemizer = GruutPhonemizer("en-us")
|
||||
# text -> IPA
|
||||
phonemes = phonemizer.phonemize("Hello, wor-ld ?")
|
||||
print("phonemes:", phonemes)
|
||||
print("len(phonemes):", len(phonemes))
|
||||
phoneme_ids = phonemizer.transform(phonemes)
|
||||
print("phoneme_ids:", phoneme_ids)
|
||||
print("len(phoneme_ids):", len(phoneme_ids))
|
||||
@ -1,12 +0,0 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/symbols.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
PAD = "_"
|
||||
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
|
||||
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
IPA_LETTERS = (
|
||||
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
||||
)
|
||||
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
|
||||
SPACE_ID = SYMBOLS.index(" ")
|
||||
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
|
||||
ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)}
|
||||
11
GPT_SoVITS/Accelerate/MLX/__init__.py
Normal file
11
GPT_SoVITS/Accelerate/MLX/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
import importlib.util
|
||||
|
||||
if importlib.util.find_spec("mlx") is not None:
|
||||
from .sample_funcs_mlx import sample_naive as sample_naive_mlx
|
||||
from .t2s_engine_mlx import T2SEngine as T2SEngineMLX
|
||||
|
||||
backends = ["mlx_static", "mlx_quantized", "mlx_varlen"]
|
||||
else:
|
||||
backends = []
|
||||
|
||||
__all__ = ["T2SEngineMLX", "sample_naive_mlx", "backends"]
|
||||
172
GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py
Normal file
172
GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py
Normal file
@ -0,0 +1,172 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from ..structs_mlx import KVCacheQ
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
KVCache,
|
||||
KVCacheHND,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Array = mx.array
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
self.kc_class = KVCacheHND
|
||||
|
||||
@staticmethod
|
||||
def quantized_scaled_dot_product_attention(
|
||||
queries: Array,
|
||||
q_keys: tuple[Array, Array, Array],
|
||||
q_values: tuple[Array, Array, Array],
|
||||
scale: float,
|
||||
mask: Array,
|
||||
group_size: int = 32,
|
||||
bits: int = 8,
|
||||
) -> Array:
|
||||
queries *= scale
|
||||
|
||||
scores = mx.quantized_matmul(queries, *q_keys, transpose=True, group_size=group_size, bits=bits)
|
||||
scores = mx.where(mask, scores, -mx.inf)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True) # type: ignore
|
||||
out = mx.quantized_matmul(scores, *q_values, transpose=False, group_size=group_size, bits=bits)
|
||||
|
||||
return out
|
||||
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
|
||||
|
||||
q, k, v = self.in_proj(x).split(3, axis=-1)
|
||||
|
||||
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
|
||||
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
|
||||
kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
|
||||
assert len(kv_cache) == 2
|
||||
|
||||
max_idx = int(input_pos.max())
|
||||
|
||||
q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache))
|
||||
|
||||
mask = attn_mask[..., :max_idx]
|
||||
|
||||
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
|
||||
|
||||
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
# def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
# bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
|
||||
|
||||
# q, k, v = self.in_proj(x).split(3, axis=-1)
|
||||
|
||||
# q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
|
||||
# q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
|
||||
# kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
|
||||
|
||||
# assert len(kv_cache) == 3
|
||||
# (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits) = kv_cache
|
||||
|
||||
# k_q, k_s, k_b, v_q, v_s, v_b = map(lambda x: x[..., : int(input_pos.max()), :], (k_q, k_s, k_b, v_q, v_s, v_b))
|
||||
|
||||
# mask = attn_mask[..., : int(input_pos.max())]
|
||||
|
||||
# attn = Attention.quantized_scaled_dot_product_attention(
|
||||
# q,
|
||||
# (k_q, k_s, k_b),
|
||||
# (v_q, v_s, v_b),
|
||||
# self.scale,
|
||||
# mask,
|
||||
# group_size,
|
||||
# bits,
|
||||
# )
|
||||
|
||||
# attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
# output = self.out_proj(attn)
|
||||
|
||||
# return output
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length, *args, **kwds)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length, *args, **kwds)
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
n_layer: int,
|
||||
n_head: int,
|
||||
ffn_dim: int,
|
||||
vocab_size: int,
|
||||
max_seq_length: int,
|
||||
max_batch_size: int,
|
||||
*args,
|
||||
**kwds,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
self.layers = [
|
||||
TransformerBlock(
|
||||
n_head,
|
||||
ffn_dim,
|
||||
hidden_dim,
|
||||
max_seq_length,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
for _ in range(n_layer)
|
||||
]
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 1800,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.h = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHND
|
||||
self.group_size = 32
|
||||
self.bits = 8
|
||||
|
||||
# def init_cache(self, bsz: int = 0):
|
||||
# return super().init_cache(bsz, group_size=self.group_size, bits=self.bits)
|
||||
|
||||
def quantized(self):
|
||||
nn.quantize(self.h, self.group_size, self.bits)
|
||||
99
GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py
Normal file
99
GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py
Normal file
@ -0,0 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from ..structs_mlx import KVCache, KVCacheQ
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
KVCacheHND,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Array = mx.array
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
self.kc_class = KVCacheHND
|
||||
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
|
||||
|
||||
q, k, v = self.in_proj(x).split(3, axis=-1)
|
||||
|
||||
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
|
||||
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
|
||||
kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
|
||||
assert len(kv_cache) == 2
|
||||
|
||||
k, v = kv_cache
|
||||
|
||||
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attn_mask)
|
||||
|
||||
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length)
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
n_layer: int,
|
||||
n_head: int,
|
||||
ffn_dim: int,
|
||||
vocab_size: int,
|
||||
max_seq_length: int,
|
||||
max_batch_size: int,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
)
|
||||
|
||||
self.layers = [
|
||||
TransformerBlock(
|
||||
n_head,
|
||||
ffn_dim,
|
||||
hidden_dim,
|
||||
max_seq_length,
|
||||
)
|
||||
for _ in range(n_layer)
|
||||
]
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 1800,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.h = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHND
|
||||
103
GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py
Normal file
103
GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py
Normal file
@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from ..structs_mlx import KVCache, KVCacheQ
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
KVCacheHND,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Array = mx.array
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
self.kc_class = KVCacheHND
|
||||
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
|
||||
|
||||
q, k, v = self.in_proj(x).split(3, axis=-1)
|
||||
|
||||
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
|
||||
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
|
||||
kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
|
||||
assert len(kv_cache) == 2
|
||||
|
||||
max_idx = int(input_pos.max())
|
||||
|
||||
q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache))
|
||||
|
||||
mask = attn_mask[..., :max_idx]
|
||||
|
||||
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
|
||||
|
||||
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length)
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
n_layer: int,
|
||||
n_head: int,
|
||||
ffn_dim: int,
|
||||
vocab_size: int,
|
||||
max_seq_length: int,
|
||||
max_batch_size: int,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
)
|
||||
|
||||
self.layers = [
|
||||
TransformerBlock(
|
||||
n_head,
|
||||
ffn_dim,
|
||||
hidden_dim,
|
||||
max_seq_length,
|
||||
)
|
||||
for _ in range(n_layer)
|
||||
]
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 1800,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.h = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHND
|
||||
64
GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py
Normal file
64
GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py
Normal file
@ -0,0 +1,64 @@
|
||||
from functools import partial
|
||||
from typing import Protocol, cast
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
Array = mx.array
|
||||
|
||||
|
||||
class SampleProtocolMLX(Protocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
logits: Array,
|
||||
previous_tokens: Array,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
repetition_penalty: float,
|
||||
) -> Array: ...
|
||||
|
||||
|
||||
class sample_naive(SampleProtocolMLX):
|
||||
@partial(mx.compile)
|
||||
@staticmethod
|
||||
def __call__(
|
||||
logits,
|
||||
previous_tokens,
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
repetition_penalty,
|
||||
):
|
||||
if temperature <= 1e-5:
|
||||
probs = mx.softmax(logits, axis=-1)
|
||||
return mx.argmax(probs, axis=-1, keepdims=True)
|
||||
|
||||
if repetition_penalty != 1.0:
|
||||
batch_idx = mx.arange(cast(tuple[int, ...], previous_tokens.shape)[0])
|
||||
previous_tokens = previous_tokens.astype(mx.int64)
|
||||
selected_logists = logits[batch_idx, previous_tokens]
|
||||
selected_logists = mx.where(
|
||||
selected_logists < 0, selected_logists * repetition_penalty, selected_logists / repetition_penalty
|
||||
)
|
||||
logits[batch_idx, previous_tokens] = selected_logists
|
||||
|
||||
sorted_indices = mx.argsort(-logits, axis=-1)
|
||||
sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
|
||||
cum_probs = mx.cumsum(mx.softmax(sorted_logits, axis=-1), axis=-1)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[:, 0] = False
|
||||
indices_to_remove = mx.zeros_like(logits).astype(mx.bool_)
|
||||
batch_indices = mx.arange(cast(tuple[int, ...], logits.shape)[0])[:, None]
|
||||
indices_to_remove[batch_indices, sorted_indices] = sorted_indices_to_remove
|
||||
logits = mx.where(indices_to_remove, -mx.inf, logits)
|
||||
|
||||
logits = logits / temperature
|
||||
|
||||
v = mx.topk(logits, top_k)
|
||||
pivot = mx.expand_dims(v[:, -1], -1)
|
||||
logits = mx.where(logits < pivot, -mx.inf, logits)
|
||||
|
||||
gumbel_noise = mx.random.gumbel(shape=cast(tuple[int, ...], logits.shape), dtype=logits.dtype)
|
||||
idx_next = mx.argmax(logits + gumbel_noise, axis=-1, keepdims=True).astype(mx.int32)
|
||||
|
||||
return idx_next
|
||||
164
GPT_SoVITS/Accelerate/MLX/structs_mlx.py
Normal file
164
GPT_SoVITS/Accelerate/MLX/structs_mlx.py
Normal file
@ -0,0 +1,164 @@
|
||||
"""
|
||||
Modified From https://github.com/XXXXRT666/GPT-SoVITS
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, MutableSequence, Protocol, TypeAlias, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
|
||||
from ..PyTorch.structs import T2SRequest, T2SResult
|
||||
from .sample_funcs_mlx import SampleProtocolMLX, sample_naive
|
||||
|
||||
Tensor = torch.Tensor
|
||||
Array = mx.array
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class T2SRequestMLX:
|
||||
x: List[Array]
|
||||
x_lens: Array
|
||||
prompts: Array
|
||||
bert_feature: List[Array]
|
||||
valid_length: int
|
||||
top_k: int = 5
|
||||
top_p: float = 1
|
||||
early_stop_num: int = -1
|
||||
temperature: float = 1.0
|
||||
repetition_penalty: float = 1.35
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, request: T2SRequest) -> T2SRequestMLX:
|
||||
x = list(map(lambda tensor: mx.array(tensor.cpu()), request.x))
|
||||
x_lens = mx.array(request.x_lens.cpu())
|
||||
prompts = mx.array(request.prompts.cpu())
|
||||
bert_feature = list(map(lambda tensor: mx.array(tensor.cpu()), request.bert_feature))
|
||||
|
||||
return cls(
|
||||
x,
|
||||
x_lens,
|
||||
prompts,
|
||||
bert_feature,
|
||||
request.valid_length,
|
||||
request.top_k,
|
||||
request.top_p,
|
||||
request.early_stop_num,
|
||||
request.temperature,
|
||||
request.repetition_penalty,
|
||||
)
|
||||
|
||||
|
||||
KVCache: TypeAlias = tuple[Array, Array]
|
||||
KVCacheQ: TypeAlias = tuple[tuple[Array, Array, Array], tuple[Array, Array, Array], tuple[int, int]]
|
||||
|
||||
|
||||
class KVCacheProtocol(Protocol):
|
||||
@staticmethod
|
||||
def empty(kv_cache: KVCache | KVCacheQ) -> None: ...
|
||||
|
||||
@staticmethod
|
||||
def update_cache(
|
||||
input_pos: Array, k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array
|
||||
) -> KVCache | KVCacheQ: ...
|
||||
|
||||
@staticmethod
|
||||
def prefill_kv(k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ) -> None: ...
|
||||
|
||||
@staticmethod
|
||||
def init_cache(
|
||||
batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype, *args, **kwds
|
||||
) -> KVCache | KVCacheQ: ...
|
||||
|
||||
|
||||
class T2SDecoderProtocol(Protocol):
|
||||
max_seq_length: int
|
||||
EOS: int
|
||||
n_head: int
|
||||
|
||||
def embed(self, x: list[Array], y: Array, bert_features: list[Array]) -> Array: ...
|
||||
|
||||
|
||||
class T2SEngineProtocol(Protocol):
|
||||
def _handle_request(self, request: T2SRequest) -> tuple[list[Array], float]: ...
|
||||
|
||||
def generate(self, request: T2SRequest) -> T2SResult: ...
|
||||
|
||||
@staticmethod
|
||||
def load_decoder(
|
||||
weights_path: os.PathLike, max_batch_size: int = 1, implement: str = "MLX"
|
||||
) -> T2SDecoderProtocol: ...
|
||||
|
||||
|
||||
class T2SSessionMLX:
|
||||
def __init__(
|
||||
self,
|
||||
decoder: T2SDecoderProtocol,
|
||||
request_torch: T2SRequest,
|
||||
sample_func: type[SampleProtocolMLX] = sample_naive,
|
||||
device: mx.Device = mx.Device(mx.cpu),
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
):
|
||||
with mx.stream(device):
|
||||
request = T2SRequestMLX.from_torch(request_torch)
|
||||
|
||||
self.decoder = decoder
|
||||
self.request = request
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
bsz = len(request.x)
|
||||
y_len: int = cast(tuple[int, ...], request.prompts.shape)[-1]
|
||||
self.bsz = bsz
|
||||
self.y_len = y_len
|
||||
|
||||
# Cache
|
||||
self.kv_cache: MutableSequence[KVCache | KVCacheQ]
|
||||
self.sample = sample_func()
|
||||
|
||||
# Forward args
|
||||
self.x = [i.astype(mx.int32) for i in request.x]
|
||||
self.x_lens = request.x_lens.astype(mx.int32)
|
||||
self.y = mx.zeros((bsz, decoder.max_seq_length)).astype(mx.int32)
|
||||
self.y[:, : cast(tuple[int, ...], request.prompts.shape)[-1]] = request.prompts.astype(mx.int32)
|
||||
self.bert_feature = [i.astype(dtype) for i in request.bert_feature]
|
||||
|
||||
self.prefill_len = self.x_lens + cast(tuple[int, ...], request.prompts.shape)[1]
|
||||
|
||||
self.input_pos = mx.zeros_like(self.prefill_len)
|
||||
self.input_pos += self.prefill_len
|
||||
|
||||
# EOS
|
||||
self.completed = mx.array([False] * len(self.x)).astype(mx.bool_)
|
||||
self.y_results: List[Array] = [None] * len(self.x) # type: ignore
|
||||
|
||||
self.xy_pos = decoder.embed(self.x, request.prompts, self.bert_feature)
|
||||
|
||||
max_len = int(self.prefill_len.max(-1))
|
||||
attn_mask = mx.zeros(shape=(bsz, max_len, max_len), dtype=mx.bool_)
|
||||
|
||||
for bs in range(bsz):
|
||||
pos = int(self.x_lens[bs])
|
||||
seq_len = pos + y_len
|
||||
|
||||
attn_mask[bs, :seq_len, :pos] = True
|
||||
|
||||
ar_mask = ~mx.triu(
|
||||
x=mx.ones(
|
||||
shape=(
|
||||
y_len,
|
||||
y_len,
|
||||
),
|
||||
dtype=mx.bool_,
|
||||
),
|
||||
k=1,
|
||||
)
|
||||
attn_mask[bs, pos:seq_len, pos:seq_len] = ar_mask
|
||||
|
||||
attn_mask = mx.repeat(mx.expand_dims(attn_mask, 1), decoder.n_head, 1)
|
||||
self.attn_mask = attn_mask
|
||||
|
||||
mx.eval(self.attn_mask)
|
||||
231
GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py
Normal file
231
GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py
Normal file
@ -0,0 +1,231 @@
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import cast
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..PyTorch.structs import T2SEngineProtocol, T2SRequest
|
||||
from .backends import mlx_quantized, mlx_static, mlx_varlen
|
||||
from .structs_mlx import T2SResult, T2SSessionMLX
|
||||
from .t2s_model_abc import T2SDecoderABC
|
||||
|
||||
Array = mx.array
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class T2SEngine(T2SEngineProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
decoder_model: T2SDecoderABC,
|
||||
device: mx.Device | str = mx.Device(mx.cpu),
|
||||
dtype: torch.dtype | mx.Dtype = torch.float32,
|
||||
) -> None:
|
||||
if isinstance(device, str):
|
||||
match device:
|
||||
case "mx.cpu":
|
||||
device = mx.Device(mx.cpu)
|
||||
case "mx.gpu":
|
||||
device = mx.Device(mx.gpu)
|
||||
|
||||
match dtype:
|
||||
case torch.float32:
|
||||
dtype = mx.float32
|
||||
case torch.float16:
|
||||
dtype = mx.float16
|
||||
case torch.bfloat16:
|
||||
dtype = mx.bfloat16
|
||||
|
||||
device = cast(mx.Device, device)
|
||||
dtype = cast(mx.Dtype, dtype)
|
||||
|
||||
assert device.type.value in {0, 1}
|
||||
assert dtype in {mx.float16, mx.bfloat16, mx.float32}
|
||||
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
mx.set_default_device(device)
|
||||
decoder_model.set_dtype(self.dtype)
|
||||
|
||||
self.decoder_model: T2SDecoderABC = decoder_model
|
||||
self.decoder_model.compile()
|
||||
|
||||
def _handle_request(self, request: T2SRequest):
|
||||
decoder = self.decoder_model
|
||||
session = T2SSessionMLX(decoder, request, device=self.device, dtype=self.dtype)
|
||||
batch_idx = mx.arange(session.bsz)
|
||||
|
||||
t1 = 0.0
|
||||
infer_speed = 0.0
|
||||
infer_time = 0.0
|
||||
|
||||
with (
|
||||
mx.stream(session.device),
|
||||
Progress(
|
||||
TextColumn("[cyan]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("{task.completed}/{task.total}"),
|
||||
TimeRemainingColumn(),
|
||||
transient=True,
|
||||
) as progress,
|
||||
):
|
||||
task = progress.add_task("T2S Decoding", total=1500)
|
||||
for idx in range(1500):
|
||||
progress.update(task, advance=1)
|
||||
if idx == 0:
|
||||
session.kv_cache = decoder.init_cache(session.bsz)
|
||||
xy_dec = decoder.h.prefill(
|
||||
session.xy_pos,
|
||||
session.attn_mask,
|
||||
session.kv_cache,
|
||||
) # bs, seq_len, embed_dim
|
||||
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
|
||||
else:
|
||||
args, kwds = decoder.pre_forward(session)
|
||||
xy_dec = decoder.h(
|
||||
session.input_pos,
|
||||
session.xy_pos,
|
||||
session.kv_cache,
|
||||
batch_idx,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
decoder.post_forward(idx, session)
|
||||
logits = decoder.ar_predict_layer(xy_dec[:, -1])
|
||||
session.input_pos += 1
|
||||
|
||||
if idx == 0:
|
||||
logits[:, -1] = -mx.inf
|
||||
|
||||
samples = session.sample(
|
||||
logits=logits,
|
||||
previous_tokens=session.y[:, : session.y_len + idx],
|
||||
top_k=request.top_k,
|
||||
top_p=request.top_p,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
)
|
||||
|
||||
session.y[batch_idx, session.y_len + idx] = samples
|
||||
|
||||
argmax_token = mx.argmax(logits, axis=-1)
|
||||
sample_token = samples.squeeze(1)
|
||||
EOS_mask = (cast(Array, argmax_token == decoder.EOS)) | (sample_token == decoder.EOS)
|
||||
|
||||
newly_done_mask = EOS_mask & (~session.completed)
|
||||
newly_done_indices = mx.where(newly_done_mask, batch_idx, -1)
|
||||
pos = mx.where(newly_done_indices != -1, batch_idx, session.bsz)
|
||||
pos_sorted = mx.sort(pos, axis=0)
|
||||
valid_count = session.bsz - mx.sum(cast(Array, pos_sorted == session.bsz))
|
||||
pos_final = pos_sorted[: int(valid_count)]
|
||||
newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0)
|
||||
|
||||
if newly_done_indices.size > 0:
|
||||
for i in newly_done_indices:
|
||||
session.y_results[int(i)] = session.y[i, session.y_len : session.y_len + idx]
|
||||
session.completed[newly_done_indices] = True
|
||||
|
||||
if mx.all(session.completed).item():
|
||||
if session.y.sum() == 0:
|
||||
session.y_results = [mx.array([0]) for _ in range(session.bsz)]
|
||||
tqdm.write("Bad Zero Prediction")
|
||||
else:
|
||||
tqdm.write(
|
||||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[cast(tuple[int, ...], i.shape)[-1] for i in session.y_results].__str__().strip('[]')}"
|
||||
)
|
||||
tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx - 1) / infer_time
|
||||
break
|
||||
|
||||
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == 1499:
|
||||
for j in range(session.bsz):
|
||||
if not session.completed[j].item():
|
||||
session.y_results[j] = session.y[[j], session.y_len : session.y_len + 1499]
|
||||
session.completed[j] = True
|
||||
tqdm.write(
|
||||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[cast(tuple[int, ...], i.shape)[-1] for i in session.y_results].__str__().strip('[]')}"
|
||||
)
|
||||
tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx - 1) / infer_time
|
||||
break
|
||||
|
||||
y_emb = decoder.ar_audio_embedding(samples)
|
||||
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
|
||||
mx.eval(session.xy_pos, session.y)
|
||||
|
||||
if idx == 1:
|
||||
t1 = time.perf_counter()
|
||||
|
||||
if idx % 100 == 0:
|
||||
mx.clear_cache()
|
||||
|
||||
match session.device:
|
||||
case mx.gpu:
|
||||
mx.clear_cache()
|
||||
case mx.cpu:
|
||||
gc.collect()
|
||||
|
||||
result_mlx = session.y_results[: request.valid_length]
|
||||
mx.eval(result_mlx)
|
||||
result = [torch.tensor(k) for k in result_mlx]
|
||||
return result, infer_speed, infer_time
|
||||
|
||||
def generate(self, request: T2SRequest):
|
||||
try:
|
||||
result, infer_speed, infer_time = self._handle_request(request)
|
||||
t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
|
||||
except Exception as e:
|
||||
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
|
||||
return t2s_result
|
||||
|
||||
@staticmethod
|
||||
def replace_key(state_dict: dict[str, Tensor]):
|
||||
state_dict_mlx: list[tuple[str, Array]] = []
|
||||
for key, value in state_dict.items():
|
||||
key = (
|
||||
key.replace("model.", "")
|
||||
.replace("in_proj_", "in_proj.")
|
||||
.replace("self_attn", "attention")
|
||||
.replace("linear", "feed_forward.linear")
|
||||
.replace("norm1", "attention_norm")
|
||||
.replace("norm2", "ffn_norm")
|
||||
)
|
||||
value_mlx = mx.array(value)
|
||||
state_dict_mlx.append((key, value_mlx))
|
||||
return state_dict_mlx
|
||||
|
||||
@staticmethod
|
||||
def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "MLX-Varlen"):
|
||||
print(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
|
||||
dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
|
||||
config = dict_s1["config"]
|
||||
match backend:
|
||||
case "MLX-Varlen":
|
||||
decoder_cls: type[T2SDecoderABC] = mlx_varlen.T2SDecoder
|
||||
case "MLX-Static":
|
||||
decoder_cls = mlx_static.T2SDecoder
|
||||
case "MLX-Quantized":
|
||||
decoder_cls = mlx_quantized.T2SDecoder
|
||||
case _:
|
||||
raise RuntimeError(f"Backend {backend} Not Found")
|
||||
|
||||
decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
|
||||
state_dict = dict_s1["weight"]
|
||||
state_dict_mlx = T2SEngine.replace_key(state_dict)
|
||||
decoder.load_weights(state_dict_mlx)
|
||||
decoder.eval()
|
||||
mx.eval(decoder)
|
||||
|
||||
if "Quantized" in backend and isinstance(decoder, mlx_quantized.T2SDecoder):
|
||||
decoder.quantized()
|
||||
mx.eval(decoder)
|
||||
|
||||
return decoder
|
||||
530
GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py
Normal file
530
GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py
Normal file
@ -0,0 +1,530 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import MutableSequence, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .structs_mlx import KVCache, KVCacheProtocol, KVCacheQ, T2SDecoderProtocol, T2SSessionMLX
|
||||
|
||||
Array = mx.array
|
||||
|
||||
|
||||
class TokenEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.word_embeddings.weight
|
||||
|
||||
def embedding(self, index: int):
|
||||
return self.word_embeddings.weight[index : index + 1]
|
||||
|
||||
def __call__(self, x: Array):
|
||||
x = self.word_embeddings(x)
|
||||
return x
|
||||
|
||||
|
||||
class SinePositionalEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
scale: bool = False,
|
||||
max_batch_size: int = 10,
|
||||
max_seq_len: int = 1800,
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
|
||||
self.alpha = mx.ones(1)
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
self.reverse = False
|
||||
self._pe = mx.zeros((max_batch_size, max_seq_len, embedding_dim))
|
||||
self.compute_pe()
|
||||
|
||||
def compute_pe(self):
|
||||
"""Reset the positional encodings."""
|
||||
|
||||
if self.reverse:
|
||||
position = mx.expand_dims(mx.arange(self.max_seq_len - 1, -1, -1.0), axis=1)
|
||||
else:
|
||||
position = mx.expand_dims(mx.arange(self.max_seq_len), axis=1)
|
||||
div_term = mx.exp(
|
||||
mx.arange(
|
||||
0,
|
||||
self.embedding_dim,
|
||||
2,
|
||||
)
|
||||
* -(math.log(10000.0) / self.embedding_dim)
|
||||
)
|
||||
pe = self._pe
|
||||
pe[:, :, 0::2] = mx.sin(position * div_term)
|
||||
pe[:, :, 1::2] = mx.cos(position * div_term)
|
||||
|
||||
def __call__(self, input_pos: Array, x: Array):
|
||||
"""
|
||||
Args:
|
||||
input_pos (Array): [batch_size, ]
|
||||
x (Array): [batch_size, 1, embed_dim]
|
||||
|
||||
Returns:
|
||||
embedded_x (Array): [batch_size, 1, embed_dim]
|
||||
"""
|
||||
|
||||
batch_size = cast(tuple[int, ...], x.shape)[0]
|
||||
pe_values = self._pe[mx.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
|
||||
|
||||
return x * self.x_scale + self.alpha * mx.expand_dims(pe_values, 1) # (batch_size, 1, embed_dim)
|
||||
|
||||
def prefill(self, x: Array):
|
||||
"""
|
||||
Args:
|
||||
x (Array): [batch_size, seq_len, embed_dim]
|
||||
|
||||
Returns:
|
||||
embedded_x (Array): [batch_size, seq_len, embed_dim]
|
||||
"""
|
||||
pe_values = self._pe[:, : cast(tuple[int, ...], x.shape)[-2]]
|
||||
return x * self.x_scale + self.alpha * pe_values
|
||||
|
||||
|
||||
class KVCacheHND(KVCacheProtocol):
|
||||
@staticmethod
|
||||
def empty(kv_cache):
|
||||
assert len(kv_cache) == 2
|
||||
k_cache, v_cache = kv_cache
|
||||
|
||||
k_cache[:] = 0
|
||||
v_cache[:] = 0
|
||||
|
||||
@staticmethod
|
||||
def update_cache(input_pos, k_val, v_val, kv_cache, cache_idx):
|
||||
# input_pos: [B, ], k_val: [B, H, 1, D]
|
||||
assert len(kv_cache) == 2
|
||||
k_out, v_out = kv_cache
|
||||
ip0 = input_pos - 1
|
||||
|
||||
k_out[cache_idx, :, ip0, None] = k_val
|
||||
v_out[cache_idx, :, ip0, None] = v_val
|
||||
|
||||
return k_out, v_out
|
||||
|
||||
@staticmethod
|
||||
def prefill_kv(k_val, v_val, kv_cache):
|
||||
# k_val: [B, S, H, D]
|
||||
assert len(kv_cache) == 2
|
||||
k_cache, v_cache = kv_cache
|
||||
|
||||
k_cache[..., : cast(tuple[int, ...], k_val.shape)[1], :] = k_val.swapaxes(1, 2)
|
||||
v_cache[..., : cast(tuple[int, ...], v_val.shape)[1], :] = v_val.swapaxes(1, 2)
|
||||
|
||||
@staticmethod
|
||||
def init_cache(batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype) -> KVCache:
|
||||
cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
|
||||
|
||||
return (mx.zeros(cache_shape, dtype=dtype), mx.zeros(cache_shape, dtype=dtype))
|
||||
|
||||
|
||||
class KVCacheHNDQuantized(KVCacheProtocol):
|
||||
@staticmethod
|
||||
def _el_per_int(bits: int) -> int:
|
||||
return 32 // bits
|
||||
|
||||
@staticmethod
|
||||
def _packed_dim(head_dim: int, bits: int = 8) -> int:
|
||||
el_per_int = KVCacheHNDQuantized._el_per_int(bits)
|
||||
if head_dim % el_per_int != 0:
|
||||
raise ValueError(f"{head_dim=} is not divisible by {el_per_int=} ({bits=})")
|
||||
return head_dim // el_per_int
|
||||
|
||||
@staticmethod
|
||||
def _group_count(head_dim: int, group_size: int = 32) -> int:
|
||||
assert group_size in {32, 64, 128}
|
||||
if head_dim % group_size != 0:
|
||||
raise ValueError(f"{head_dim} is not divisible by {group_size=}")
|
||||
return head_dim // group_size
|
||||
|
||||
@staticmethod
|
||||
def empty(kv_cache) -> None:
|
||||
assert len(kv_cache) == 3
|
||||
(k_q, k_s, k_b), (v_q, v_s, v_b), (_, __) = kv_cache
|
||||
|
||||
k_q[:] = 0
|
||||
k_s[:] = 0
|
||||
k_b[:] = 0
|
||||
v_q[:] = 0
|
||||
v_s[:] = 0
|
||||
v_b[:] = 0
|
||||
|
||||
@staticmethod
|
||||
def update_cache(
|
||||
input_pos,
|
||||
k_val,
|
||||
v_val,
|
||||
kv_cache,
|
||||
cache_idx,
|
||||
):
|
||||
# input_pos: [B, ], k_val: [B, H, 1, D]
|
||||
|
||||
assert len(kv_cache) == 3
|
||||
(k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache
|
||||
|
||||
k_q, k_s, k_b = mx.quantize(k_val, group_size=group_size, bits=bits)
|
||||
v_q, v_s, v_b = mx.quantize(v_val, group_size=group_size, bits=bits)
|
||||
|
||||
ip0 = input_pos - 1
|
||||
|
||||
k_q_out[cache_idx, :, ip0, None] = k_q
|
||||
k_s_out[cache_idx, :, ip0, None] = k_s
|
||||
k_b_out[cache_idx, :, ip0, None] = k_b
|
||||
|
||||
v_q_out[cache_idx, :, ip0, None] = v_q
|
||||
v_s_out[cache_idx, :, ip0, None] = v_s
|
||||
v_b_out[cache_idx, :, ip0, None] = v_b
|
||||
|
||||
return (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits)
|
||||
|
||||
@staticmethod
|
||||
def prefill_kv(
|
||||
k_val,
|
||||
v_val,
|
||||
kv_cache,
|
||||
) -> None:
|
||||
assert len(kv_cache) == 3
|
||||
(k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache
|
||||
|
||||
S = cast(tuple[int, ...], k_val.shape)[1]
|
||||
|
||||
k_sw = k_val.swapaxes(1, 2)
|
||||
v_sw = v_val.swapaxes(1, 2)
|
||||
|
||||
k_q, k_s, k_b = mx.quantize(k_sw, group_size=group_size, bits=bits)
|
||||
v_q, v_s, v_b = mx.quantize(v_sw, group_size=group_size, bits=bits)
|
||||
|
||||
k_q_out[..., :S, :] = k_q
|
||||
k_s_out[..., :S, :] = k_s
|
||||
k_b_out[..., :S, :] = k_b
|
||||
|
||||
v_q_out[..., :S, :] = v_q
|
||||
v_s_out[..., :S, :] = v_s
|
||||
v_b_out[..., :S, :] = v_b
|
||||
|
||||
@staticmethod
|
||||
def init_cache(
|
||||
batch_size: int,
|
||||
max_seq_length: int,
|
||||
n_heads: int,
|
||||
head_dim: int,
|
||||
dtype: mx.Dtype,
|
||||
*,
|
||||
group_size: int = 32,
|
||||
bits: int = 8,
|
||||
) -> KVCacheQ:
|
||||
packed_dim = KVCacheHNDQuantized._packed_dim(head_dim, bits=bits)
|
||||
group_cnt = KVCacheHNDQuantized._group_count(head_dim, group_size=group_size)
|
||||
|
||||
packed_shape = (batch_size, n_heads, max_seq_length, packed_dim)
|
||||
group_shape = (batch_size, n_heads, max_seq_length, group_cnt)
|
||||
|
||||
k_q = mx.zeros(packed_shape, dtype=mx.uint32)
|
||||
k_s = mx.zeros(group_shape, dtype=dtype)
|
||||
k_b = mx.zeros(group_shape, dtype=dtype)
|
||||
|
||||
v_q = mx.zeros(packed_shape, dtype=mx.uint32)
|
||||
v_s = mx.zeros(group_shape, dtype=dtype)
|
||||
v_b = mx.zeros(group_shape, dtype=dtype)
|
||||
|
||||
return (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits)
|
||||
|
||||
|
||||
class AttentionABC(ABC, nn.Module):
|
||||
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int, *args, **kwds):
|
||||
super().__init__()
|
||||
|
||||
self.n_head = n_head
|
||||
self.hidden_dim = hidden_dim
|
||||
assert hidden_dim % n_head == 0
|
||||
self.head_dim = hidden_dim // n_head
|
||||
|
||||
self.max_seq_length = max_seq_length
|
||||
|
||||
# key, query, value projections for all heads, but in a batch
|
||||
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
self.scale = 1 / math.sqrt(self.head_dim)
|
||||
|
||||
self.kc_class: KVCacheProtocol
|
||||
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array
|
||||
) -> Array: ...
|
||||
|
||||
def prefill(self, x: Array, kv_cache: KVCache | KVCacheQ, attn_mask: Array):
|
||||
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
|
||||
|
||||
q, k, v = self.in_proj(mx.expand_dims(x, 0)).split(3, axis=-1)
|
||||
|
||||
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
|
||||
self.kc_class.prefill_kv(k, v, kv_cache)
|
||||
|
||||
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
|
||||
attn = mx.fast.scaled_dot_product_attention(q, k, v, mask=attn_mask, scale=self.scale)
|
||||
|
||||
attn = mx.nan_to_num(attn)
|
||||
|
||||
attn = attn.swapaxes(1, 2).reshape(1, -1, self.hidden_dim)
|
||||
|
||||
output = self.out_proj(attn)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
|
||||
self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
|
||||
|
||||
def __call__(self, x: Array):
|
||||
return self.linear2(nn.relu(self.linear1(x)))
|
||||
|
||||
|
||||
class TransformerBlockABC(nn.Module):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
self.max_seq_length = max_seq_length
|
||||
|
||||
self.attention: AttentionABC
|
||||
|
||||
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
||||
self.attention_norm = nn.LayerNorm(self.hidden_dim)
|
||||
self.ffn_norm = nn.LayerNorm(self.hidden_dim)
|
||||
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
h = self.attention_norm(
|
||||
x
|
||||
+ self.attention(
|
||||
x,
|
||||
input_pos,
|
||||
kv_cache,
|
||||
cache_idx,
|
||||
attn_mask,
|
||||
)
|
||||
)
|
||||
out = self.ffn_norm(h + self.feed_forward(h))
|
||||
return out
|
||||
|
||||
def prefill(self, x: Array, attn_mask: Array, kv_cache: KVCache | KVCacheQ):
|
||||
h = self.attention_norm(
|
||||
x
|
||||
+ self.attention.prefill(
|
||||
x,
|
||||
kv_cache,
|
||||
attn_mask,
|
||||
)
|
||||
)
|
||||
out = self.ffn_norm(h + self.feed_forward(h))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class TransformerDecoderABC(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
n_layer: int,
|
||||
n_head: int,
|
||||
ffn_dim: int,
|
||||
vocab_size: int,
|
||||
max_seq_length: int,
|
||||
max_batch_size: int,
|
||||
*args,
|
||||
**kwds,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
self.n_head = n_head
|
||||
assert hidden_dim % n_head == 0
|
||||
|
||||
self.head_dim = hidden_dim // n_head
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self.n_layer = n_layer
|
||||
|
||||
self.layers: MutableSequence[TransformerBlockABC]
|
||||
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_pos: Array,
|
||||
x: Array,
|
||||
kv_caches: MutableSequence[KVCache | KVCacheQ],
|
||||
cache_idx: Array,
|
||||
*args,
|
||||
**kwds,
|
||||
):
|
||||
for layer, kv_cache in zip(self.layers, kv_caches):
|
||||
x = layer(
|
||||
x,
|
||||
input_pos,
|
||||
kv_cache,
|
||||
cache_idx,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCache | KVCacheQ]):
|
||||
for layer, kv_cache in zip(self.layers, kv_caches):
|
||||
x = layer.prefill(
|
||||
x,
|
||||
mask,
|
||||
kv_cache,
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 1800,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_dim: int = config["model"]["hidden_dim"]
|
||||
embedding_dim: int = config["model"]["embedding_dim"]
|
||||
n_head: int = config["model"]["head"]
|
||||
n_layer: int = config["model"]["n_layer"]
|
||||
vocab_size: int = config["model"]["vocab_size"]
|
||||
phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
|
||||
EOS: int = config["model"]["EOS"]
|
||||
ffn_dim: int = hidden_dim * 4
|
||||
|
||||
self.n_layer = int(n_layer)
|
||||
self.hidden_dim = int(hidden_dim)
|
||||
self.n_head = int(n_head)
|
||||
assert hidden_dim % n_head == 0
|
||||
|
||||
self.head_dim = int(hidden_dim // n_head)
|
||||
self.embedding_dim = int(embedding_dim)
|
||||
self.ffn_dim = int(ffn_dim)
|
||||
self.vocab_size = int(vocab_size)
|
||||
self.phoneme_vocab_size = int(phoneme_vocab_size)
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
self.EOS = EOS
|
||||
assert self.EOS == self.vocab_size - 1
|
||||
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
||||
self.h: TransformerDecoderABC
|
||||
|
||||
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
|
||||
self.ar_text_position = SinePositionalEmbedding(
|
||||
self.embedding_dim,
|
||||
scale=False,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_length,
|
||||
)
|
||||
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
|
||||
self.ar_audio_position = SinePositionalEmbedding(
|
||||
self.embedding_dim,
|
||||
scale=False,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_length,
|
||||
)
|
||||
|
||||
self.kv_class: KVCacheProtocol
|
||||
|
||||
def init_cache(self, bsz: int = 0, *args, **kwds) -> MutableSequence[KVCache | KVCacheQ]:
|
||||
bsz = bsz or self.h.max_batch_size
|
||||
assert bsz <= self.h.max_batch_size
|
||||
seq_lens = self.h.max_seq_length
|
||||
dtype = self.bert_proj.bias.dtype
|
||||
cache: MutableSequence[KVCache | KVCacheQ] = [
|
||||
self.kv_class.init_cache(bsz, seq_lens, self.n_head, self.head_dim, dtype, *args, **kwds)
|
||||
for _ in range(self.n_layer)
|
||||
]
|
||||
mx.eval(cache)
|
||||
return cache
|
||||
|
||||
def embed(
|
||||
self,
|
||||
x: list[Array],
|
||||
y: Array,
|
||||
bert_features: list[Array],
|
||||
):
|
||||
x_len: list[int] = [cast(tuple[int, ...], i.shape)[0] for i in x]
|
||||
x_len_max = max(x_len)
|
||||
xy_pos = mx.zeros((len(x), x_len_max + cast(tuple[int, ...], y.shape)[1], self.embedding_dim)).astype(
|
||||
bert_features[0].dtype
|
||||
)
|
||||
|
||||
bert_features = list(map(lambda x: x.swapaxes(0, 1), bert_features))
|
||||
|
||||
y_len = cast(tuple[int, ...], y.shape)[1]
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
y_pos = self.ar_audio_position.prefill(y_emb)
|
||||
|
||||
for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
|
||||
x_emb = self.ar_text_embedding(x_)
|
||||
bert = self.bert_proj(bert_feature)
|
||||
x_emb = x_emb + bert
|
||||
x_pos = self.ar_text_position.prefill(mx.expand_dims(x_emb, 0))
|
||||
xy_pos[[bs], :len_] = x_pos
|
||||
xy_pos[[bs], len_ : len_ + y_len] = y_pos
|
||||
|
||||
mx.eval(xy_pos)
|
||||
return xy_pos
|
||||
|
||||
def compile(self):
|
||||
setattr(self.h, "__call__", mx.compile(self.h.__call__))
|
||||
# setattr(self.h, "prefill", mx.compile(self.h.prefill, shapeless=True))
|
||||
|
||||
def pre_forward(self, session: T2SSessionMLX):
|
||||
attn_mask = session.attn_mask
|
||||
return list(), dict(attn_mask=attn_mask)
|
||||
|
||||
def post_forward(self, idx: int, session: T2SSessionMLX) -> None:
|
||||
if idx == 0:
|
||||
prefill_len = session.prefill_len
|
||||
bsz = session.bsz
|
||||
|
||||
range_tensor = mx.arange(self.max_seq_length).reshape(1, 1, 1, self.max_seq_length)
|
||||
prefill_len_expanded = prefill_len.reshape(bsz, 1, 1, 1)
|
||||
attn_mask = range_tensor < prefill_len_expanded
|
||||
attn_mask = mx.repeat(attn_mask, self.n_head, 1)
|
||||
|
||||
session.attn_mask = attn_mask
|
||||
|
||||
attn_mask = session.attn_mask
|
||||
input_pos = session.input_pos
|
||||
attn_mask[mx.arange(session.bsz), :, :, input_pos] = True
|
||||
mx.eval(attn_mask)
|
||||
28
GPT_SoVITS/Accelerate/PyTorch/__init__.py
Normal file
28
GPT_SoVITS/Accelerate/PyTorch/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
import importlib.util
|
||||
|
||||
import torch
|
||||
|
||||
from .sample_funcs import sample_naive
|
||||
from .structs import T2SRequest, T2SResult
|
||||
from .t2s_engine import T2SEngine as T2SEngineTorch
|
||||
|
||||
backends = ["torch_varlen"]
|
||||
if torch.cuda.is_available():
|
||||
backends.append("torch_static_cuda_graph")
|
||||
if importlib.util.find_spec("sageattention") is not None:
|
||||
for i in range(torch.cuda.device_count()):
|
||||
major, minor = torch.cuda.get_device_capability(i)
|
||||
sm_version = major + minor / 10.0
|
||||
if sm_version >= 7.0:
|
||||
backends.append("sage_attn_varlen_cuda_graph")
|
||||
if importlib.util.find_spec("flash_attn") is not None:
|
||||
for i in range(torch.cuda.device_count()):
|
||||
major, minor = torch.cuda.get_device_capability(i)
|
||||
sm_version = major + minor / 10.0
|
||||
if sm_version >= 7.5:
|
||||
backends.append("flash_attn_varlen_cuda_graph")
|
||||
if torch.mps.is_available():
|
||||
backends.append("mps_flash_attn_varlen")
|
||||
|
||||
|
||||
__all__ = ["T2SEngineTorch", "T2SRequest", "sample_naive", "T2SResult", "backends"]
|
||||
@ -0,0 +1,157 @@
|
||||
"""
|
||||
Modified From https://github.com/XXXXRT666/GPT-SoVITS
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import kernels
|
||||
import torch
|
||||
|
||||
from .. import nn
|
||||
from ..structs import T2SSession
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
CUDAGraphCacheABC,
|
||||
FeedForward,
|
||||
KVCacheNHD,
|
||||
KVCacheProtocol,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
flash_attn_kernel = None
|
||||
try:
|
||||
import flash_attn_interface as flash_attn # type: ignore
|
||||
|
||||
flash_attn_kernel = flash_attn.flash_attn_with_kvcache
|
||||
except ModuleNotFoundError:
|
||||
try:
|
||||
import flash_attn # type: ignore
|
||||
|
||||
flash_attn_kernel = flash_attn.flash_attn_with_kvcache
|
||||
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
if flash_attn_kernel is None:
|
||||
flash_attn_kernel = kernels.get_kernel("kernels-community/flash-attn").flash_attn_with_kvcache
|
||||
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head, hidden_dim, max_seq_length):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
|
||||
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds) -> Tensor:
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
||||
|
||||
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
|
||||
attn: Tensor = flash_attn.flash_attn_with_kvcache( # type: ignore
|
||||
q, kv_cache.k_cache, kv_cache.v_cache, k, v, cache_seqlens=input_pos - 1
|
||||
)
|
||||
|
||||
attn = attn.view(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head, ffn_dim, hidden_dim, max_seq_length) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length)
|
||||
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
||||
self.attention_norm = nn.LayerNorm([self.hidden_dim])
|
||||
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
) -> None:
|
||||
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
|
||||
|
||||
self.layers = nn.ModuleList( # type: ignore
|
||||
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
|
||||
)
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=1800,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
assert torch.cuda.is_available()
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
||||
self.h: TransformerDecoderABC = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheNHD
|
||||
|
||||
def post_forward(self, idx: int, session: T2SSession) -> None:
|
||||
return super().post_forward(idx, session)
|
||||
|
||||
def pre_forward(self, session: T2SSession) -> Tuple[List, Dict]:
|
||||
return super().pre_forward(session)
|
||||
|
||||
|
||||
class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoder: T2SDecoder,
|
||||
) -> None:
|
||||
super().__init__(decoder)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
if session.id != self.id:
|
||||
self.assigned = False
|
||||
else:
|
||||
del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
|
||||
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
assert self.graph
|
||||
session.graph = self.graph
|
||||
session.stream = self.stream
|
||||
|
||||
session.xy_pos_ = self.xy_pos
|
||||
session.xy_dec_ = self.xy_dec
|
||||
session.input_pos = self.input_pos.copy_(session.input_pos)
|
||||
|
||||
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
|
||||
cache.sync_cache(cache_)
|
||||
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
session.xy_pos_ = self.xy_pos.clone()
|
||||
session.xy_dec_ = self.xy_dec.clone()
|
||||
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
|
||||
|
||||
args, kwds = self.decoder.pre_forward(session)
|
||||
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
|
||||
session.graph = graph
|
||||
session.stream = torch.cuda.Stream() # type: ignore
|
||||
165
GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py
Normal file
165
GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py
Normal file
@ -0,0 +1,165 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .. import nn
|
||||
from ..structs import KVCacheProtocol, T2SSession
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
CUDAGraphCacheABC,
|
||||
FeedForward,
|
||||
KVCacheHND,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head, hidden_dim, max_seq_length):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
|
||||
# key, query, value projections for all heads, but in a batch
|
||||
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
||||
|
||||
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
|
||||
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
||||
|
||||
k, v = kv_cache.update(input_pos, k, v)
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
|
||||
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length)
|
||||
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
||||
self.attention_norm = nn.LayerNorm([self.hidden_dim])
|
||||
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
) -> None:
|
||||
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
|
||||
|
||||
self.layers = nn.ModuleList( # type: ignore
|
||||
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
|
||||
)
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=1800,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
||||
self.h: TransformerDecoderABC = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHND
|
||||
|
||||
def pre_forward(self, session: T2SSession):
|
||||
attn_mask = session.attn_mask
|
||||
return list(), dict(attn_mask=attn_mask)
|
||||
|
||||
def post_forward(self, idx: int, session: T2SSession) -> None:
|
||||
if idx == 0:
|
||||
prefill_len = session.prefill_len
|
||||
bsz = session.bsz
|
||||
|
||||
range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
|
||||
prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
|
||||
attn_mask = range_tensor < prefill_len_expanded
|
||||
attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
|
||||
|
||||
session.attn_mask = attn_mask
|
||||
|
||||
attn_mask = session.attn_mask
|
||||
input_pos = session.input_pos
|
||||
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
|
||||
|
||||
|
||||
class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoder,
|
||||
) -> None:
|
||||
super().__init__(decoder)
|
||||
if torch.cuda.is_available():
|
||||
self.attn_mask = (
|
||||
torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
|
||||
.bool()
|
||||
.to(self.device, self.dtype)
|
||||
)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
if session.id != self.id:
|
||||
self.assigned = False
|
||||
else:
|
||||
del (
|
||||
session.graph,
|
||||
session.xy_pos_,
|
||||
session.xy_dec_,
|
||||
session.input_pos,
|
||||
session.kv_cache,
|
||||
session.attn_mask,
|
||||
)
|
||||
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
assert self.graph
|
||||
session.graph = self.graph
|
||||
session.stream = self.stream
|
||||
|
||||
session.xy_pos_ = self.xy_pos
|
||||
session.xy_dec_ = self.xy_dec
|
||||
session.input_pos = self.input_pos.copy_(session.input_pos)
|
||||
|
||||
session.attn_mask = self.attn_mask
|
||||
|
||||
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
|
||||
cache.sync_cache(cache_)
|
||||
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
session.xy_pos_ = self.xy_pos.clone()
|
||||
session.xy_dec_ = self.xy_dec.clone()
|
||||
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
|
||||
|
||||
session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)
|
||||
|
||||
args, kwds = self.decoder.pre_forward(session)
|
||||
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
|
||||
session.graph = graph
|
||||
session.stream = torch.cuda.Stream() # type: ignore
|
||||
@ -0,0 +1,176 @@
|
||||
from typing import MutableSequence
|
||||
|
||||
import sageattention # type: ignore
|
||||
import torch
|
||||
|
||||
from .. import nn
|
||||
from ..structs import T2SSession
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
CUDAGraphCacheABC,
|
||||
FeedForward,
|
||||
KVCacheHND,
|
||||
KVCacheProtocol,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head, hidden_dim, max_seq_length):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
|
||||
# key, query, value projections for all heads, but in a batch
|
||||
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: Tensor,
|
||||
input_pos: Tensor,
|
||||
kv_cache: KVCacheProtocol,
|
||||
cu_seqlens_q: Tensor,
|
||||
cu_seqlens_kv: Tensor,
|
||||
) -> Tensor:
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
||||
|
||||
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
|
||||
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
||||
|
||||
k, v = kv_cache.update(input_pos, k, v)
|
||||
|
||||
attn: Tensor = sageattention.sageattn_varlen(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_kv=cu_seqlens_kv,
|
||||
max_seqlen_q=1,
|
||||
max_seqlen_k=self.max_seq_length,
|
||||
)
|
||||
|
||||
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head, ffn_dim, hidden_dim, max_seq_length) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length)
|
||||
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
||||
self.attention_norm = nn.LayerNorm([self.hidden_dim])
|
||||
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
) -> None:
|
||||
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
|
||||
|
||||
self.layers = nn.ModuleList( # type: ignore
|
||||
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
|
||||
)
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=1800,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
||||
self.h: TransformerDecoderABC = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHND
|
||||
|
||||
def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
|
||||
return list(), dict(cu_seqlens_q=session.cu_seqlens_q, cu_seqlens_kv=session.cu_seqlens_kv)
|
||||
|
||||
def post_forward(self, idx: int, session: T2SSession):
|
||||
if idx == 0:
|
||||
session.cu_seqlens_q = torch.arange(0, session.bsz + 1, dtype=torch.int32)
|
||||
session.cu_seqlens_kv = torch.cat([torch.tensor(0, dtype=torch.int32), session.input_pos])
|
||||
else:
|
||||
cu_seqlens_q = session.cu_seqlens_q
|
||||
cu_seqlens_kv = session.cu_seqlens_kv
|
||||
cu_seqlens_kv.add_(cu_seqlens_q)
|
||||
|
||||
|
||||
class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoder: T2SDecoder,
|
||||
) -> None:
|
||||
super().__init__(decoder)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.cu_seqlens_q = torch.arange(0, decoder.max_batch_size + 1, dtype=torch.int32).to(self.device)
|
||||
self.cu_seqlens_kv = torch.cat([torch.tensor(0, dtype=torch.int32), self.input_pos]).to(self.device)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
if session.id != self.id:
|
||||
self.assigned = False
|
||||
else:
|
||||
del (
|
||||
session.graph,
|
||||
session.xy_pos_,
|
||||
session.xy_dec_,
|
||||
session.input_pos,
|
||||
session.kv_cache,
|
||||
session.cu_seqlens_q,
|
||||
session.cu_seqlens_kv,
|
||||
)
|
||||
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
assert self.graph
|
||||
session.graph = self.graph
|
||||
session.stream = self.stream
|
||||
|
||||
session.xy_pos_ = self.xy_pos
|
||||
session.xy_dec_ = self.xy_dec
|
||||
session.input_pos = self.input_pos.copy_(session.input_pos)
|
||||
|
||||
session.cu_seqlens_q = self.cu_seqlens_q
|
||||
session.cu_seqlens_kv = self.cu_seqlens_kv
|
||||
|
||||
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
|
||||
cache.sync_cache(cache_)
|
||||
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
session.xy_pos_ = self.xy_pos.clone()
|
||||
session.xy_dec_ = self.xy_dec.clone()
|
||||
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
|
||||
|
||||
session.cu_seqlens_q = self.cu_seqlens_q.clone().copy_(session.cu_seqlens_q)
|
||||
session.cu_seqlens_kv = self.cu_seqlens_kv.clone().copy_(session.cu_seqlens_kv)
|
||||
|
||||
args, kwds = self.decoder.pre_forward(session)
|
||||
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
|
||||
session.graph = graph
|
||||
session.stream = torch.cuda.Stream() # type: ignore
|
||||
@ -0,0 +1,165 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .. import nn
|
||||
from ..structs import KVCacheProtocol, T2SSession
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
CUDAGraphCacheABC,
|
||||
FeedForward,
|
||||
KVCacheHND,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head, hidden_dim, max_seq_length):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
|
||||
# key, query, value projections for all heads, but in a batch
|
||||
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
||||
|
||||
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
|
||||
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
||||
|
||||
k, v = kv_cache.update(input_pos, k, v)
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
|
||||
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length)
|
||||
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
||||
self.attention_norm = nn.LayerNorm([self.hidden_dim])
|
||||
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
) -> None:
|
||||
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
|
||||
|
||||
self.layers = nn.ModuleList( # type: ignore
|
||||
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
|
||||
)
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=1800,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
||||
self.h: TransformerDecoderABC = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHND
|
||||
|
||||
def pre_forward(self, session: T2SSession):
|
||||
attn_mask = session.attn_mask
|
||||
return list(), dict(attn_mask=attn_mask)
|
||||
|
||||
def post_forward(self, idx: int, session: T2SSession) -> None:
|
||||
if idx == 0:
|
||||
prefill_len = session.prefill_len
|
||||
bsz = session.bsz
|
||||
|
||||
range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
|
||||
prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
|
||||
attn_mask = range_tensor < prefill_len_expanded
|
||||
attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
|
||||
|
||||
session.attn_mask = attn_mask
|
||||
|
||||
attn_mask = session.attn_mask
|
||||
input_pos = session.input_pos
|
||||
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
|
||||
|
||||
|
||||
class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoder,
|
||||
) -> None:
|
||||
super().__init__(decoder)
|
||||
if torch.cuda.is_available():
|
||||
self.attn_mask = (
|
||||
torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
|
||||
.bool()
|
||||
.to(self.device, self.dtype)
|
||||
)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
if session.id != self.id:
|
||||
self.assigned = False
|
||||
else:
|
||||
del (
|
||||
session.graph,
|
||||
session.xy_pos_,
|
||||
session.xy_dec_,
|
||||
session.input_pos,
|
||||
session.kv_cache,
|
||||
session.attn_mask,
|
||||
)
|
||||
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
assert self.graph
|
||||
session.graph = self.graph
|
||||
session.stream = self.stream
|
||||
|
||||
session.xy_pos_ = self.xy_pos
|
||||
session.xy_dec_ = self.xy_dec
|
||||
session.input_pos = self.input_pos.copy_(session.input_pos)
|
||||
|
||||
session.attn_mask = self.attn_mask
|
||||
|
||||
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
|
||||
cache.sync_cache(cache_)
|
||||
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
session.xy_pos_ = self.xy_pos.clone()
|
||||
session.xy_dec_ = self.xy_dec.clone()
|
||||
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
|
||||
|
||||
session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)
|
||||
|
||||
args, kwds = self.decoder.pre_forward(session)
|
||||
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
|
||||
session.graph = graph
|
||||
session.stream = torch.cuda.Stream() # type: ignore
|
||||
144
GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py
Normal file
144
GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py
Normal file
@ -0,0 +1,144 @@
|
||||
from typing import NoReturn
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .. import nn
|
||||
from ..structs import KVCacheProtocol, T2SSession
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
CUDAGraphCacheABC,
|
||||
FeedForward,
|
||||
KVCacheHNDVarlen,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head, hidden_dim, max_seq_length):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
|
||||
# key, query, value projections for all heads, but in a batch
|
||||
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
||||
|
||||
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
|
||||
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
||||
|
||||
k, v = kv_cache.update(input_pos, k, v)
|
||||
|
||||
max_idx = input_pos.max()
|
||||
|
||||
q, k, v = map(lambda x: x[..., :max_idx, :], (q, k, v))
|
||||
|
||||
mask = attn_mask[..., :max_idx]
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v, mask)
|
||||
|
||||
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length)
|
||||
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
||||
self.attention_norm = nn.LayerNorm([self.hidden_dim])
|
||||
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
) -> None:
|
||||
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
|
||||
|
||||
self.layers = nn.ModuleList( # type: ignore
|
||||
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
|
||||
)
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=1800,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
||||
self.h: TransformerDecoderABC = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHNDVarlen
|
||||
|
||||
def capture(
|
||||
self,
|
||||
*args,
|
||||
**kwds,
|
||||
) -> NoReturn:
|
||||
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
|
||||
|
||||
def pre_forward(self, session: T2SSession):
|
||||
attn_mask = session.attn_mask
|
||||
return list(), dict(attn_mask=attn_mask)
|
||||
|
||||
def post_forward(self, idx: int, session: T2SSession) -> None:
|
||||
if idx == 0:
|
||||
prefill_len = session.prefill_len
|
||||
bsz = session.bsz
|
||||
|
||||
range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
|
||||
prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
|
||||
attn_mask = range_tensor < prefill_len_expanded
|
||||
attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
|
||||
|
||||
session.attn_mask = attn_mask
|
||||
|
||||
attn_mask = session.attn_mask
|
||||
input_pos = session.input_pos
|
||||
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
|
||||
|
||||
|
||||
class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoder,
|
||||
) -> None:
|
||||
super().__init__(decoder, False)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
|
||||
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
|
||||
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
|
||||
69
GPT_SoVITS/Accelerate/PyTorch/nn.py
Normal file
69
GPT_SoVITS/Accelerate/PyTorch/nn.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""
|
||||
Enhanced Type Hint nn.Module
|
||||
Modified From https://github.com/labmlai/labml/blob/master/helpers/labml_helpers/module.py
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch.nn
|
||||
from torch.nn import (
|
||||
functional as functional,
|
||||
)
|
||||
from torch.nn import (
|
||||
utils as utils,
|
||||
)
|
||||
from torch.nn.modules import * # type: ignore # noqa: F403
|
||||
from torch.nn.parameter import (
|
||||
Parameter as Parameter,
|
||||
)
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
r"""
|
||||
Wraps ``torch.nn.Module`` to overload ``__call__`` instead of
|
||||
``forward`` for better type checking.
|
||||
|
||||
`PyTorch Github issue for clarification <https://github.com/pytorch/pytorch/issues/44605>`_
|
||||
"""
|
||||
|
||||
def _forward_unimplemented(self, *input: Any) -> None:
|
||||
# To stop PyTorch from giving abstract methods warning
|
||||
pass
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
if cls.__dict__.get("__call__", None) is None:
|
||||
return
|
||||
|
||||
setattr(cls, "forward", cls.__dict__["__call__"])
|
||||
delattr(cls, "__call__")
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
params = self.parameters()
|
||||
try:
|
||||
sample_param = next(params)
|
||||
return sample_param.device
|
||||
except StopIteration:
|
||||
raise RuntimeError(f"Unable to determine device of {self.__class__.__name__}") from None
|
||||
|
||||
|
||||
class Linear(torch.nn.Linear):
|
||||
def __call__(self, input: Tensor) -> Tensor:
|
||||
return super().__call__(input)
|
||||
|
||||
|
||||
class Dropout(torch.nn.Dropout):
|
||||
def __call__(self, input: Tensor) -> Tensor:
|
||||
return super().__call__(input)
|
||||
|
||||
|
||||
class Embedding(torch.nn.Embedding):
|
||||
def __call__(self, input: Tensor) -> Tensor:
|
||||
return super().__call__(input)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
def __call__(self, input: Tensor) -> Tensor:
|
||||
return super().__call__(input)
|
||||
62
GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py
Normal file
62
GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py
Normal file
@ -0,0 +1,62 @@
|
||||
from typing import Protocol
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class SampleProtocol(Protocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
logits: Tensor,
|
||||
previous_tokens: Tensor,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
repetition_penalty: float,
|
||||
) -> Tensor: ...
|
||||
|
||||
|
||||
class sample_naive(SampleProtocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
logits: Tensor,
|
||||
previous_tokens: Tensor,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
repetition_penalty: float,
|
||||
):
|
||||
if temperature <= 1e-5:
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
return torch.argmax(probs, dim=-1, keepdim=True)
|
||||
|
||||
if repetition_penalty != 1.0:
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
|
||||
logits /= temperature
|
||||
|
||||
v, _ = torch.topk(logits, top_k)
|
||||
pivot = v[:, -1].unsqueeze(-1)
|
||||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
q = torch.empty_like(probs).exponential_(1.0)
|
||||
idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
|
||||
|
||||
return idx_next
|
||||
151
GPT_SoVITS/Accelerate/PyTorch/structs.py
Normal file
151
GPT_SoVITS/Accelerate/PyTorch/structs.py
Normal file
@ -0,0 +1,151 @@
|
||||
"""
|
||||
Modified From https://github.com/XXXXRT666/GPT-SoVITS
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, MutableSequence, Optional, Protocol
|
||||
|
||||
import torch
|
||||
|
||||
from .sample_funcs import SampleProtocol, sample_naive
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class T2SResult:
|
||||
result: list[Tensor] | None = None
|
||||
infer_speed: tuple[float, float] = (0.0, 0.0)
|
||||
status: Literal["Success", "Error"] = "Success"
|
||||
exception: Optional[Exception] = None
|
||||
traceback: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class T2SRequest:
|
||||
x: list[torch.Tensor]
|
||||
x_lens: Tensor
|
||||
prompts: torch.Tensor
|
||||
bert_feature: list[Tensor]
|
||||
valid_length: int
|
||||
top_k: int = 5
|
||||
top_p: float = 1
|
||||
early_stop_num: int = -1
|
||||
temperature: float = 1.0
|
||||
repetition_penalty: float = 1.35
|
||||
use_cuda_graph: bool = False
|
||||
debug: bool = False
|
||||
|
||||
|
||||
class KVCacheProtocol(Protocol):
|
||||
k_cache: Tensor
|
||||
v_cache: Tensor
|
||||
|
||||
def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None: ...
|
||||
|
||||
def empty(self) -> None: ...
|
||||
|
||||
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> tuple[Tensor, Tensor]: ...
|
||||
|
||||
def prefill_kv(self, k_val: Tensor, v_val: Tensor) -> None: ...
|
||||
|
||||
def sync_cache(self, kv_cache: KVCacheProtocol) -> None: ...
|
||||
|
||||
|
||||
class T2SDecoderProtocol(Protocol):
|
||||
max_seq_length: int
|
||||
EOS: int
|
||||
n_head: int
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device: ...
|
||||
|
||||
def embed(self, x: list[Tensor], y: Tensor, bert_features: list[Tensor]) -> Tensor: ...
|
||||
|
||||
|
||||
class T2SEngineProtocol(Protocol):
|
||||
def _handle_request(self, request: T2SRequest) -> tuple[list[Tensor], float]: ...
|
||||
|
||||
def generate(self, request: T2SRequest) -> T2SResult: ...
|
||||
|
||||
|
||||
class T2SSession:
|
||||
def __init__(
|
||||
self,
|
||||
decoder: T2SDecoderProtocol,
|
||||
request: T2SRequest,
|
||||
sapmle_func: type[SampleProtocol] = sample_naive,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
with device:
|
||||
self.decoder = decoder
|
||||
self.request = request
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
bsz = len(request.x)
|
||||
y_len = request.prompts.size(-1)
|
||||
self.bsz = bsz
|
||||
self.y_len = y_len
|
||||
request.prompts = request.prompts.to(device, torch.int32)
|
||||
|
||||
# Cache
|
||||
self.kv_cache: MutableSequence[KVCacheProtocol]
|
||||
self.sample = sapmle_func()
|
||||
|
||||
# Forward args
|
||||
self.x = [i.to(device) for i in request.x]
|
||||
self.x_lens = request.x_lens.to(torch.int32)
|
||||
self.y = torch.zeros((bsz, decoder.max_seq_length)).to(torch.int32)
|
||||
self.y[:, : request.prompts.shape[-1]] = request.prompts
|
||||
self.bert_feature = [i.to(device, dtype) for i in request.bert_feature]
|
||||
|
||||
self.prefill_len = self.x_lens + request.prompts.size(1)
|
||||
|
||||
self.input_pos = torch.zeros_like(self.prefill_len)
|
||||
self.input_pos.add_(self.prefill_len)
|
||||
|
||||
# CUDA Graph
|
||||
self.stream: Optional[torch.cuda.Stream] = None
|
||||
self.graph: Optional[torch.cuda.CUDAGraph] = None
|
||||
self.xy_pos_: Tensor
|
||||
self.xy_dec_: Tensor
|
||||
|
||||
# EOS
|
||||
self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
|
||||
self.y_results: list[Tensor] = [None] * len(self.x) # type: ignore
|
||||
|
||||
self.xy_pos = decoder.embed(self.x, request.prompts, self.bert_feature)
|
||||
|
||||
max_len = int(self.prefill_len.max().item())
|
||||
attn_mask = torch.zeros(size=(bsz, max_len, max_len), dtype=torch.bool)
|
||||
|
||||
for bs in range(bsz):
|
||||
pos = int(self.x_lens[bs])
|
||||
seq_len = pos + y_len
|
||||
|
||||
attn_mask[bs, :seq_len, :pos] = True
|
||||
|
||||
ar_mask = ~torch.triu(
|
||||
input=torch.ones(
|
||||
size=(
|
||||
y_len,
|
||||
y_len,
|
||||
),
|
||||
dtype=torch.bool,
|
||||
),
|
||||
diagonal=1,
|
||||
)
|
||||
attn_mask[bs, pos:seq_len, pos:seq_len] = ar_mask
|
||||
|
||||
self.attn_mask = attn_mask
|
||||
self.attn_mask = attn_mask.unsqueeze(0).expand(-1, decoder.n_head, -1, -1)
|
||||
|
||||
self.id: int = -1
|
||||
|
||||
# Sage Attn & Transformer Engine Impl
|
||||
self.cu_seqlens_q: Tensor
|
||||
self.cu_seqlens_kv: Tensor
|
||||
202
GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py
Normal file
202
GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py
Normal file
@ -0,0 +1,202 @@
|
||||
import contextlib
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from importlib import import_module
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from .structs import T2SEngineProtocol, T2SRequest, T2SResult, T2SSession
|
||||
from .t2s_model_abc import (
|
||||
CUDAGraphCacheABC,
|
||||
T2SDecoderABC,
|
||||
TorchProfiler,
|
||||
)
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
|
||||
class T2SEngine(T2SEngineProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
decoder_model: T2SDecoderABC,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
dtype: torch.dtype = torch.float32,
|
||||
) -> None:
|
||||
assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"}
|
||||
assert dtype in {torch.float16, torch.bfloat16, torch.float32}
|
||||
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
|
||||
|
||||
self.graphcache: CUDAGraphCacheABC = self.init_cache()
|
||||
|
||||
def _handle_request(self, request: T2SRequest):
|
||||
with self.device:
|
||||
decoder = self.decoder_model
|
||||
session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
|
||||
batch_idx = torch.arange(session.bsz)
|
||||
|
||||
t1 = 0.0
|
||||
infer_speed = 0.0
|
||||
|
||||
torch_profiler = TorchProfiler(request.debug)
|
||||
with torch_profiler.profiler():
|
||||
for idx in tqdm(range(1500)):
|
||||
if idx == 0:
|
||||
session.kv_cache = decoder.init_cache(session.bsz)
|
||||
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
|
||||
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
|
||||
else:
|
||||
if request.use_cuda_graph and session.graph is None and torch.cuda.is_available():
|
||||
self.graphcache.assign_graph(session)
|
||||
|
||||
with torch_profiler.record("AR"):
|
||||
if session.graph:
|
||||
assert session.stream
|
||||
session.stream.wait_stream(torch.cuda.default_stream())
|
||||
with torch.cuda.stream(session.stream):
|
||||
session.xy_pos_.copy_(session.xy_pos)
|
||||
session.graph.replay()
|
||||
xy_dec = session.xy_dec_.clone()
|
||||
else:
|
||||
args, kwds = decoder.pre_forward(session)
|
||||
xy_dec = decoder.h(
|
||||
session.input_pos,
|
||||
session.xy_pos,
|
||||
session.kv_cache,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
with torch.cuda.stream(session.stream) if session.stream is not None else contextlib.nullcontext():
|
||||
decoder.post_forward(idx, session)
|
||||
logits = decoder.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if idx == 0:
|
||||
logits[:, -1] = float("-inf")
|
||||
|
||||
with torch_profiler.record("Sampling"):
|
||||
samples = session.sample(
|
||||
logits=logits,
|
||||
previous_tokens=session.y[:, : session.y_len + idx],
|
||||
top_k=request.top_k,
|
||||
top_p=request.top_p,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
)
|
||||
session.y[batch_idx, session.y_len + idx] = samples
|
||||
session.input_pos.add_(1)
|
||||
|
||||
with torch_profiler.record("EOS"):
|
||||
argmax_token = torch.argmax(logits, dim=-1)
|
||||
sample_token = samples.squeeze(1)
|
||||
EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
|
||||
|
||||
newly_done_mask = EOS_mask & (~session.completed)
|
||||
newly_done_indices = newly_done_mask.nonzero()
|
||||
|
||||
if newly_done_indices.numel() > 0:
|
||||
for i in newly_done_indices:
|
||||
print(i, i.shape, newly_done_indices, newly_done_indices.shape)
|
||||
session.y_results[i] = session.y[i, session.y_len : session.y_len + idx]
|
||||
session.completed[newly_done_indices] = True
|
||||
|
||||
if torch.all(session.completed).item():
|
||||
if session.y.sum() == 0:
|
||||
session.y_results = [torch.tensor(0) for _ in range(session.bsz)]
|
||||
tqdm.write("Bad Zero Prediction")
|
||||
else:
|
||||
tqdm.write(
|
||||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n{[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
|
||||
)
|
||||
tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
|
||||
infer_speed = (idx - 1) / (time.perf_counter() - t1)
|
||||
break
|
||||
|
||||
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == 1499:
|
||||
for i in range(session.bsz):
|
||||
if not session.completed[i].item():
|
||||
session.y_results[i] = session.y[i, session.y_len : session.y_len + 1499]
|
||||
session.completed[i] = True
|
||||
break
|
||||
|
||||
with torch_profiler.record("NextPos"):
|
||||
y_emb = decoder.ar_audio_embedding(samples)
|
||||
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
|
||||
|
||||
if idx == 1:
|
||||
torch_profiler.start()
|
||||
t1 = time.perf_counter()
|
||||
|
||||
if idx == 51:
|
||||
torch_profiler.end()
|
||||
|
||||
if idx % 100 == 0:
|
||||
match session.device.type:
|
||||
case "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
case "mps":
|
||||
torch.mps.empty_cache()
|
||||
case "xpu":
|
||||
torch.xpu.empty_cache()
|
||||
case "mtia":
|
||||
torch.mtia.empty_cache()
|
||||
|
||||
match session.device.type:
|
||||
case "cuda":
|
||||
if session.stream is not None:
|
||||
torch.cuda.current_stream().wait_stream(session.stream)
|
||||
torch.cuda.empty_cache()
|
||||
case "mps":
|
||||
torch.mps.empty_cache()
|
||||
case "xpu":
|
||||
torch.xpu.empty_cache()
|
||||
case "mtia":
|
||||
torch.mtia.empty_cache()
|
||||
case "cpu":
|
||||
gc.collect()
|
||||
|
||||
torch_profiler.end()
|
||||
if request.use_cuda_graph and torch.cuda.is_available():
|
||||
self.graphcache.release_graph(session)
|
||||
return session.y_results[: request.valid_length], infer_speed
|
||||
|
||||
def generate(self, request: T2SRequest):
|
||||
try:
|
||||
result, infer_speed = self._handle_request(request)
|
||||
t2s_result = T2SResult(result=result, infer_speed=infer_speed, status="Success")
|
||||
except Exception as e:
|
||||
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
|
||||
return t2s_result
|
||||
|
||||
@staticmethod
|
||||
def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "Flash Attn CUDAGraph"):
|
||||
print(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
|
||||
module_path = f".backends.{backend.lower().replace('-', '_')}"
|
||||
decoder_cls_name = "T2SDecoder"
|
||||
decoder_mod = import_module(module_path, package=__package__)
|
||||
decoder_cls: type[T2SDecoderABC] = getattr(decoder_mod, decoder_cls_name)
|
||||
dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
|
||||
config = dict_s1["config"]
|
||||
decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
|
||||
state_dict = dict_s1["weight"]
|
||||
decoder.load_state_dict(state_dict)
|
||||
|
||||
return decoder.eval()
|
||||
|
||||
def init_cache(self):
|
||||
assert self.decoder_model
|
||||
|
||||
module_name = self.decoder_model.__class__.__module__
|
||||
module = sys.modules.get(module_name)
|
||||
assert module
|
||||
|
||||
target_class: type[CUDAGraphCacheABC] = getattr(module, "CUDAGraphCache")
|
||||
|
||||
return target_class(self.decoder_model)
|
||||
668
GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py
Normal file
668
GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py
Normal file
@ -0,0 +1,668 @@
|
||||
"""
|
||||
Modified From https://github.com/XXXXRT666/GPT-SoVITS
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import nullcontext
|
||||
from typing import MutableSequence
|
||||
|
||||
import torch
|
||||
import torch._inductor.config
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.graphs import CUDAGraph
|
||||
from torch.profiler import ProfilerAction, tensorboard_trace_handler
|
||||
|
||||
from . import nn
|
||||
from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class TokenEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
|
||||
|
||||
@property
|
||||
def weight(self) -> Tensor:
|
||||
return self.word_embeddings.weight
|
||||
|
||||
def embedding(self, index: int) -> Tensor:
|
||||
return self.word_embeddings.weight[index : index + 1]
|
||||
|
||||
def __call__(self, x: Tensor):
|
||||
x = self.word_embeddings(x)
|
||||
return x
|
||||
|
||||
|
||||
class SinePositionalEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
scale: bool = False,
|
||||
alpha: bool = False,
|
||||
max_batch_size: int = 10,
|
||||
max_seq_len: int = 1800,
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
|
||||
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
self.reverse = False
|
||||
self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
|
||||
self.pe: torch.Tensor
|
||||
self.compute_pe()
|
||||
|
||||
def compute_pe(self):
|
||||
"""Reset the positional encodings."""
|
||||
if self.reverse:
|
||||
position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
|
||||
else:
|
||||
position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
|
||||
)
|
||||
pe = self.pe
|
||||
pe[:, :, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, :, 1::2] = torch.cos(position * div_term)
|
||||
|
||||
def __call__(self, input_pos: Tensor, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
input_pos (Tensor): [batch_size, ]
|
||||
x (Tensor): [batch_size, 1, embed_dim]
|
||||
|
||||
Returns:
|
||||
embedded_x (Tensor): [batch_size, 1, embed_dim]
|
||||
"""
|
||||
|
||||
batch_size = x.shape[0]
|
||||
pe_values = self.pe[torch.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
|
||||
|
||||
return x * self.x_scale + self.alpha * pe_values.unsqueeze(1) # (batch_size, 1, embed_dim)
|
||||
|
||||
def prefill(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): [batch_size, seq_len, embed_dim]
|
||||
|
||||
Returns:
|
||||
embedded_x (Tensor): [batch_size, seq_len, embed_dim]
|
||||
"""
|
||||
|
||||
pe_values = self.pe[:, : x.shape[-2]]
|
||||
return x * self.x_scale + self.alpha.item() * pe_values
|
||||
|
||||
|
||||
class KVCacheABC(nn.Module, ABC, KVCacheProtocol):
|
||||
def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.n_head = n_heads
|
||||
self.head_dim = head_dim
|
||||
self.batch_size = batch_size
|
||||
self.max_seq_length = max_seq_length
|
||||
|
||||
self.k_cache: Tensor
|
||||
self.v_cache: Tensor
|
||||
|
||||
def empty(self):
|
||||
self.k_cache.zero_()
|
||||
self.v_cache.zero_()
|
||||
|
||||
@abstractmethod
|
||||
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> tuple[Tensor, Tensor]: ...
|
||||
|
||||
@abstractmethod
|
||||
def prefill_kv(self, k_val: Tensor, v_val: Tensor) -> None: ...
|
||||
|
||||
def sync_cache(self, kv_cache: KVCacheProtocol):
|
||||
self.k_cache.copy_(kv_cache.k_cache)
|
||||
self.v_cache.copy_(kv_cache.v_cache)
|
||||
|
||||
|
||||
class KVCacheNHD(KVCacheABC):
|
||||
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
|
||||
super().__init__(batch_size, max_seq_length, n_heads, head_dim)
|
||||
|
||||
assert batch_size > 0
|
||||
cache_shape = (batch_size, max_seq_length, n_heads, head_dim)
|
||||
|
||||
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
|
||||
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
|
||||
|
||||
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
|
||||
# input_pos: [B, ], k_val: [B, 1, H, D]
|
||||
|
||||
index = (
|
||||
(input_pos - 1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
.expand(
|
||||
-1,
|
||||
-1,
|
||||
self.n_head,
|
||||
self.head_dim,
|
||||
)
|
||||
.to(torch.int64)
|
||||
) # (bs, 1, num_head, head_dim)
|
||||
|
||||
k_out = self.k_cache
|
||||
v_out = self.v_cache
|
||||
k_out.scatter_(1, index, k_val)
|
||||
v_out.scatter_(1, index, v_val)
|
||||
|
||||
return k_out, v_out
|
||||
|
||||
def empty(self):
|
||||
self.k_cache.zero_()
|
||||
self.v_cache.zero_()
|
||||
|
||||
def prefill_kv(self, k_val: Tensor, v_val: Tensor):
|
||||
# input_pos: int, k_val: [B, S, H, D]
|
||||
|
||||
self.k_cache[:, : k_val.shape[1]] = k_val
|
||||
self.v_cache[:, : v_val.shape[1]] = v_val
|
||||
|
||||
|
||||
class KVCacheHND(KVCacheABC):
|
||||
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
|
||||
super().__init__(batch_size, max_seq_length, n_heads, head_dim)
|
||||
|
||||
cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
|
||||
|
||||
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
|
||||
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
|
||||
|
||||
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
|
||||
# input_pos: [B, ], k_val: [B, H, 1, D]
|
||||
|
||||
index = (
|
||||
(input_pos - 1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
.expand(
|
||||
-1,
|
||||
self.n_head,
|
||||
-1,
|
||||
self.head_dim,
|
||||
)
|
||||
.to(torch.int64)
|
||||
) # (bs, num_head, 1, head_dim)
|
||||
|
||||
k_out = self.k_cache
|
||||
v_out = self.v_cache
|
||||
k_out.scatter_(2, index, k_val)
|
||||
v_out.scatter_(2, index, v_val)
|
||||
|
||||
return k_out, v_out
|
||||
|
||||
def empty(self):
|
||||
self.k_cache.zero_()
|
||||
self.v_cache.zero_()
|
||||
|
||||
def prefill_kv(self, k_val: Tensor, v_val: Tensor):
|
||||
# input_pos: int, k_val: [B, S, H, D]
|
||||
|
||||
self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
|
||||
self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)
|
||||
|
||||
|
||||
class KVCacheHNDVarlen(KVCacheABC):
|
||||
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
|
||||
super().__init__(batch_size, max_seq_length, n_heads, head_dim)
|
||||
|
||||
cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
|
||||
self.cache_idx: Tensor
|
||||
|
||||
self.register_buffer("cache_idx", torch.arange(batch_size), persistent=False)
|
||||
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
|
||||
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
|
||||
|
||||
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
|
||||
# input_pos: [B, ], k_val: [B, H, 1, D]
|
||||
|
||||
k_out = self.k_cache
|
||||
v_out = self.v_cache
|
||||
|
||||
k_out[self.cache_idx, :, input_pos - 1, :] = k_val
|
||||
v_out[self.cache_idx, :, input_pos - 1, :] = v_val
|
||||
|
||||
return k_out, v_out
|
||||
|
||||
def empty(self):
|
||||
self.k_cache.zero_()
|
||||
self.v_cache.zero_()
|
||||
|
||||
def prefill_kv(self, k_val: Tensor, v_val: Tensor):
|
||||
# input_pos: int, k_val: [B, S, H, D]
|
||||
|
||||
self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
|
||||
self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)
|
||||
|
||||
|
||||
class AttentionABC(nn.Module, ABC):
|
||||
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
|
||||
super().__init__()
|
||||
|
||||
self.n_head = n_head
|
||||
self.hidden_dim = hidden_dim
|
||||
assert hidden_dim % n_head == 0
|
||||
self.head_dim = hidden_dim // n_head
|
||||
|
||||
self.max_seq_length = max_seq_length
|
||||
|
||||
# key, query, value projections for all heads, but in a batch
|
||||
self.in_proj: nn.Linear
|
||||
self.out_proj: nn.Linear
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
|
||||
keys_to_modify = [key for key in state_dict if "in_proj_" in key]
|
||||
for key in keys_to_modify:
|
||||
new_key = key.replace("in_proj_", "in_proj.") # in_proj_ -> in_proj.
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds) -> Tensor: ...
|
||||
|
||||
def prefill(self, x: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor) -> Tensor:
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x.unsqueeze(0)).chunk(3, dim=-1)
|
||||
|
||||
q, k, v = map(lambda x: x.contiguous().view(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
|
||||
kv_cache.prefill_kv(k, v)
|
||||
|
||||
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
|
||||
|
||||
attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)
|
||||
|
||||
output = self.out_proj(attn)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
|
||||
self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
|
||||
|
||||
def __call__(self, x: Tensor):
|
||||
return self.linear2(F.relu(self.linear1(x), inplace=True))
|
||||
|
||||
|
||||
class TransformerBlockABC(nn.Module, ABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
self.max_seq_length = max_seq_length
|
||||
|
||||
self.attention: AttentionABC
|
||||
self.feed_forward: FeedForward
|
||||
self.attention_norm: nn.LayerNorm
|
||||
self.ffn_norm: nn.LayerNorm
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
|
||||
for key in list(state_dict.keys()):
|
||||
new_key = (
|
||||
key.replace("self_attn", "attention")
|
||||
.replace("linear", "feed_forward.linear")
|
||||
.replace("norm1", "attention_norm")
|
||||
.replace("norm2", "ffn_norm")
|
||||
)
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds):
|
||||
h = self.attention_norm(
|
||||
x
|
||||
+ self.attention(
|
||||
x,
|
||||
input_pos,
|
||||
kv_cache,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
)
|
||||
out = self.ffn_norm(h + self.feed_forward(h))
|
||||
return out
|
||||
|
||||
def prefill(
|
||||
self,
|
||||
x: Tensor,
|
||||
kv_cache: KVCacheProtocol,
|
||||
attn_mask: Tensor,
|
||||
) -> Tensor:
|
||||
h = self.attention_norm(
|
||||
x
|
||||
+ self.attention.prefill(
|
||||
x,
|
||||
kv_cache,
|
||||
attn_mask,
|
||||
)
|
||||
)
|
||||
out = self.ffn_norm(h + self.feed_forward(h))
|
||||
return out
|
||||
|
||||
|
||||
class TransformerDecoderABC(nn.Module, ABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
n_layer: int,
|
||||
n_head: int,
|
||||
ffn_dim: int,
|
||||
vocab_size: int,
|
||||
max_seq_length: int,
|
||||
max_batch_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
self.n_head = n_head
|
||||
assert hidden_dim % n_head == 0
|
||||
|
||||
self.head_dim = hidden_dim // n_head
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self.n_layer = n_layer
|
||||
|
||||
self.layers: MutableSequence[TransformerBlockABC]
|
||||
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
|
||||
def __call__(self, input_pos: Tensor, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds):
|
||||
for layer, kv_cache in zip(self.layers, kv_caches):
|
||||
x = layer(x, input_pos, kv_cache, *args, **kwds)
|
||||
return x
|
||||
|
||||
def prefill(self, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], attn_mask: Tensor):
|
||||
for layer, kv_cache in zip(self.layers, kv_caches):
|
||||
x = layer.prefill(x, kv_cache, attn_mask)
|
||||
return x
|
||||
|
||||
|
||||
class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 1800,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_dim: int = config["model"]["hidden_dim"]
|
||||
embedding_dim: int = config["model"]["embedding_dim"]
|
||||
n_head: int = config["model"]["head"]
|
||||
n_layer: int = config["model"]["n_layer"]
|
||||
vocab_size: int = config["model"]["vocab_size"]
|
||||
phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
|
||||
EOS: int = config["model"]["EOS"]
|
||||
ffn_dim: int = hidden_dim * 4
|
||||
|
||||
self.n_layer = int(n_layer)
|
||||
self.hidden_dim = int(hidden_dim)
|
||||
self.n_head = int(n_head)
|
||||
assert hidden_dim % n_head == 0
|
||||
|
||||
self.head_dim = int(hidden_dim // n_head)
|
||||
self.embedding_dim = int(embedding_dim)
|
||||
self.ffn_dim = int(ffn_dim)
|
||||
self.vocab_size = int(vocab_size)
|
||||
self.phoneme_vocab_size = int(phoneme_vocab_size)
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
self.EOS = EOS
|
||||
assert self.EOS == self.vocab_size - 1
|
||||
|
||||
self.bert_proj: nn.Linear
|
||||
self.ar_predict_layer: nn.Linear
|
||||
self.h: TransformerDecoderABC
|
||||
|
||||
self.kv_class: type[KVCacheABC]
|
||||
|
||||
self.GraphCache: CUDAGraphCacheABC | None
|
||||
|
||||
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
|
||||
self.ar_text_position = SinePositionalEmbedding(
|
||||
self.embedding_dim,
|
||||
scale=False,
|
||||
alpha=True,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_length,
|
||||
)
|
||||
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
|
||||
self.ar_audio_position = SinePositionalEmbedding(
|
||||
self.embedding_dim,
|
||||
scale=False,
|
||||
alpha=True,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_length,
|
||||
)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
|
||||
model_keys = [key for key in state_dict if key.startswith("model.")]
|
||||
for key in model_keys:
|
||||
new_key = key[len("model.") :]
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
def init_cache(self, bsz: int = 0) -> MutableSequence[KVCacheProtocol]:
|
||||
bsz = bsz or self.h.max_batch_size
|
||||
assert bsz <= self.h.max_batch_size
|
||||
seq_lens = self.h.max_seq_length
|
||||
dtype = self.bert_proj.bias.dtype
|
||||
kvclass = self.kv_class
|
||||
|
||||
return nn.ModuleList(
|
||||
[kvclass(bsz, seq_lens, self.n_head, self.head_dim) for _ in range(self.n_layer)],
|
||||
).to(self.device, dtype) # type: ignore
|
||||
|
||||
def embed(
|
||||
self,
|
||||
x: list[torch.Tensor],
|
||||
y: torch.Tensor,
|
||||
bert_features: list[torch.Tensor],
|
||||
):
|
||||
x_len: list[int] = [i.shape[0] for i in x]
|
||||
x_len_max = max(x_len)
|
||||
xy_pos = torch.zeros((len(x), x_len_max + y.shape[1], self.embedding_dim)).to(bert_features[0].dtype)
|
||||
|
||||
bert_features = list(map(lambda x: x.transpose(0, 1), bert_features))
|
||||
|
||||
y_len = y.shape[1]
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
y_pos = self.ar_audio_position.prefill(y_emb)
|
||||
|
||||
for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
|
||||
x_emb = self.ar_text_embedding(x_)
|
||||
bert = self.bert_proj(bert_feature)
|
||||
x_emb = x_emb + bert
|
||||
x_pos = self.ar_text_position.prefill(x_emb.unsqueeze(0))
|
||||
xy_pos[[bs], :len_] = x_pos
|
||||
xy_pos[[bs], len_ : len_ + y_len] = y_pos
|
||||
|
||||
return xy_pos
|
||||
|
||||
def compile(self, *args, **kwds):
|
||||
# Experimental features to reduce compilation times, will be on by default in future
|
||||
torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
|
||||
torch._inductor.config.coordinate_descent_tuning = True
|
||||
torch._inductor.config.triton.unique_kernel_names = True
|
||||
torch._inductor.config.fx_graph_cache = True
|
||||
torch._inductor.config.triton.cudagraph_trees = True
|
||||
torch._inductor.config.triton.cudagraph_support_input_mutation = True
|
||||
self.h.compile(fullgraph=True, mode="reduce-overhead")
|
||||
|
||||
def capture(
|
||||
self, input_pos: Tensor, x: Tensor, x_dec: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds
|
||||
) -> CUDAGraph:
|
||||
assert torch.cuda.is_available()
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
|
||||
with torch.cuda.stream(s): # type: ignore
|
||||
for _ in range(5):
|
||||
self.h(input_pos, x, kv_caches, *args, **kwds)
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
|
||||
with torch.cuda.graph(graph):
|
||||
x_dec.copy_(self.h(input_pos, x, kv_caches, *args, **kwds))
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return graph
|
||||
|
||||
@abstractmethod
|
||||
def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
|
||||
return list(), dict()
|
||||
|
||||
@abstractmethod
|
||||
def post_forward(self, idx: int, session: T2SSession) -> None:
|
||||
return
|
||||
|
||||
|
||||
class CUDAGraphCacheABC(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoder: T2SDecoderABC,
|
||||
enabled: bool = False,
|
||||
) -> None:
|
||||
if torch.cuda.is_available() and enabled:
|
||||
self.device: torch.device = decoder.device
|
||||
self.dtype = decoder.bert_proj.bias.dtype
|
||||
|
||||
self.assigned: bool = False
|
||||
|
||||
self.decoder: T2SDecoderABC = decoder
|
||||
self.kv_cache: MutableSequence[KVCacheProtocol] = decoder.init_cache(decoder.max_batch_size)
|
||||
self.xy_pos = torch.rand(size=(decoder.max_batch_size, 1, decoder.embedding_dim), device=self.device).to(
|
||||
self.dtype
|
||||
)
|
||||
self.xy_dec = self.xy_pos.clone()
|
||||
|
||||
self.input_pos = torch.tensor([10] * decoder.max_batch_size, device=self.device).int()
|
||||
self.graph: torch.cuda.CUDAGraph | None = None
|
||||
self.stream: torch.cuda.Stream | None
|
||||
|
||||
self.id: int = random.randint(1, 2**32 - 1)
|
||||
|
||||
def assign_graph(self, session: T2SSession):
|
||||
if self.graph is None:
|
||||
args, kwds = self.decoder.pre_forward(session)
|
||||
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
|
||||
self.graph = graph
|
||||
self.stream = torch.cuda.Stream() # type: ignore
|
||||
|
||||
if self.assigned is False:
|
||||
self.get_cache_graph(session)
|
||||
session.id = self.id
|
||||
self.assigned = True
|
||||
else:
|
||||
self.capture_new_graph(session)
|
||||
|
||||
@abstractmethod
|
||||
def release_graph(self, session: T2SSession): ...
|
||||
|
||||
@abstractmethod
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
pass
|
||||
|
||||
|
||||
class TorchProfiler:
|
||||
def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
|
||||
self.debug = debug
|
||||
self.log_dir = log_dir
|
||||
self.__profiler: torch.profiler.profile
|
||||
|
||||
if self.debug and not os.path.exists(self.log_dir):
|
||||
os.makedirs(self.log_dir)
|
||||
|
||||
self.tensorboard_handler = tensorboard_trace_handler(self.log_dir)
|
||||
|
||||
def profiler_callback(self, prof: torch.profiler.profile):
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
|
||||
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
|
||||
self.tensorboard_handler(prof)
|
||||
|
||||
@staticmethod
|
||||
def three_step_schedule(step: int) -> ProfilerAction:
|
||||
if step == 0:
|
||||
return ProfilerAction.NONE
|
||||
elif step == 1:
|
||||
return ProfilerAction.RECORD
|
||||
elif step == 2:
|
||||
return ProfilerAction.RECORD_AND_SAVE
|
||||
else:
|
||||
return ProfilerAction.NONE
|
||||
|
||||
def start(self):
|
||||
if not self.debug:
|
||||
return
|
||||
assert self.__profiler is not None
|
||||
self.__profiler.step()
|
||||
|
||||
def end(self):
|
||||
if not self.debug:
|
||||
return
|
||||
assert self.__profiler is not None
|
||||
self.__profiler.step()
|
||||
|
||||
def profiler(self):
|
||||
if self.debug:
|
||||
activities_list = [torch.profiler.ProfilerActivity.CPU]
|
||||
if torch.cuda.is_available():
|
||||
activities_list.append(torch.profiler.ProfilerActivity.CUDA)
|
||||
|
||||
self.__profiler = torch.profiler.profile(
|
||||
activities=activities_list,
|
||||
record_shapes=True,
|
||||
with_stack=True,
|
||||
with_modules=True,
|
||||
profile_memory=True,
|
||||
schedule=self.three_step_schedule,
|
||||
on_trace_ready=self.profiler_callback,
|
||||
)
|
||||
return self.__profiler
|
||||
else:
|
||||
return nullcontext()
|
||||
|
||||
def record(self, func_name: str):
|
||||
if self.debug:
|
||||
return torch.profiler.record_function(func_name)
|
||||
else:
|
||||
return nullcontext()
|
||||
11
GPT_SoVITS/Accelerate/__init__.py
Normal file
11
GPT_SoVITS/Accelerate/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
from . import MLX, PyTorch
|
||||
from .PyTorch import T2SEngineTorch, T2SRequest, T2SResult
|
||||
|
||||
backends = PyTorch.backends + MLX.backends
|
||||
|
||||
backends = [
|
||||
b.replace("_", "-").title().replace("Mlx", "MLX").replace("Mps", "MPS").replace("Cuda", "CUDA") for b in backends
|
||||
]
|
||||
|
||||
|
||||
__all__ = ["T2SEngineTorch", "T2SRequest", "T2SResult", "backends", "MLX", "PyTorch"]
|
||||
@ -1,86 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import soundfile as sf
|
||||
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
|
||||
def synthesize(
|
||||
GPT_model_path,
|
||||
SoVITS_model_path,
|
||||
ref_audio_path,
|
||||
ref_text_path,
|
||||
ref_language,
|
||||
target_text_path,
|
||||
target_language,
|
||||
output_path,
|
||||
):
|
||||
# Read reference text
|
||||
with open(ref_text_path, "r", encoding="utf-8") as file:
|
||||
ref_text = file.read()
|
||||
|
||||
# Read target text
|
||||
with open(target_text_path, "r", encoding="utf-8") as file:
|
||||
target_text = file.read()
|
||||
|
||||
# Change model weights
|
||||
change_gpt_weights(gpt_path=GPT_model_path)
|
||||
change_sovits_weights(sovits_path=SoVITS_model_path)
|
||||
|
||||
# Synthesize audio
|
||||
synthesis_result = get_tts_wav(
|
||||
ref_wav_path=ref_audio_path,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=i18n(ref_language),
|
||||
text=target_text,
|
||||
text_language=i18n(target_language),
|
||||
top_p=1,
|
||||
temperature=1,
|
||||
)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
if result_list:
|
||||
last_sampling_rate, last_audio_data = result_list[-1]
|
||||
output_wav_path = os.path.join(output_path, "output.wav")
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
print(f"Audio saved to {output_wav_path}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
||||
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
||||
parser.add_argument(
|
||||
"--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio"
|
||||
)
|
||||
parser.add_argument("--target_text", required=True, help="Path to the target text file")
|
||||
parser.add_argument(
|
||||
"--target_language",
|
||||
required=True,
|
||||
choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"],
|
||||
help="Language of the target text",
|
||||
)
|
||||
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
synthesize(
|
||||
args.gpt_model,
|
||||
args.sovits_model,
|
||||
args.ref_audio,
|
||||
args.ref_text,
|
||||
args.ref_language,
|
||||
args.target_text,
|
||||
args.target_language,
|
||||
args.output_path,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,316 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
from PyQt5.QtCore import QEvent
|
||||
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QPushButton, QTextEdit
|
||||
from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QStatusBar, QComboBox
|
||||
import soundfile as sf
|
||||
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
|
||||
|
||||
class GPTSoVITSGUI(QMainWindow):
|
||||
GPT_Path = gpt_path
|
||||
SoVITS_Path = sovits_path
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.setWindowTitle("GPT-SoVITS GUI")
|
||||
self.setGeometry(800, 450, 950, 850)
|
||||
|
||||
self.setStyleSheet("""
|
||||
QWidget {
|
||||
background-color: #a3d3b1;
|
||||
}
|
||||
|
||||
QTabWidget::pane {
|
||||
background-color: #a3d3b1;
|
||||
}
|
||||
|
||||
QTabWidget::tab-bar {
|
||||
alignment: left;
|
||||
}
|
||||
|
||||
QTabBar::tab {
|
||||
background: #8da4bf;
|
||||
color: #ffffff;
|
||||
padding: 8px;
|
||||
}
|
||||
|
||||
QTabBar::tab:selected {
|
||||
background: #2a3f54;
|
||||
}
|
||||
|
||||
QLabel {
|
||||
color: #000000;
|
||||
}
|
||||
|
||||
QPushButton {
|
||||
background-color: #4CAF50;
|
||||
color: white;
|
||||
padding: 8px;
|
||||
border: 1px solid #4CAF50;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
QPushButton:hover {
|
||||
background-color: #45a049;
|
||||
border: 1px solid #45a049;
|
||||
box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
""")
|
||||
|
||||
license_text = (
|
||||
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
|
||||
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
|
||||
)
|
||||
license_label = QLabel(license_text)
|
||||
license_label.setWordWrap(True)
|
||||
|
||||
self.GPT_model_label = QLabel("选择GPT模型:")
|
||||
self.GPT_model_input = QLineEdit()
|
||||
self.GPT_model_input.setPlaceholderText("拖拽或选择文件")
|
||||
self.GPT_model_input.setText(self.GPT_Path)
|
||||
self.GPT_model_input.setReadOnly(True)
|
||||
self.GPT_model_button = QPushButton("选择GPT模型文件")
|
||||
self.GPT_model_button.clicked.connect(self.select_GPT_model)
|
||||
|
||||
self.SoVITS_model_label = QLabel("选择SoVITS模型:")
|
||||
self.SoVITS_model_input = QLineEdit()
|
||||
self.SoVITS_model_input.setPlaceholderText("拖拽或选择文件")
|
||||
self.SoVITS_model_input.setText(self.SoVITS_Path)
|
||||
self.SoVITS_model_input.setReadOnly(True)
|
||||
self.SoVITS_model_button = QPushButton("选择SoVITS模型文件")
|
||||
self.SoVITS_model_button.clicked.connect(self.select_SoVITS_model)
|
||||
|
||||
self.ref_audio_label = QLabel("上传参考音频:")
|
||||
self.ref_audio_input = QLineEdit()
|
||||
self.ref_audio_input.setPlaceholderText("拖拽或选择文件")
|
||||
self.ref_audio_input.setReadOnly(True)
|
||||
self.ref_audio_button = QPushButton("选择音频文件")
|
||||
self.ref_audio_button.clicked.connect(self.select_ref_audio)
|
||||
|
||||
self.ref_text_label = QLabel("参考音频文本:")
|
||||
self.ref_text_input = QLineEdit()
|
||||
self.ref_text_input.setPlaceholderText("直接输入文字或上传文本")
|
||||
self.ref_text_button = QPushButton("上传文本")
|
||||
self.ref_text_button.clicked.connect(self.upload_ref_text)
|
||||
|
||||
self.ref_language_label = QLabel("参考音频语言:")
|
||||
self.ref_language_combobox = QComboBox()
|
||||
self.ref_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
|
||||
self.ref_language_combobox.setCurrentText("多语种混合")
|
||||
|
||||
self.target_text_label = QLabel("合成目标文本:")
|
||||
self.target_text_input = QLineEdit()
|
||||
self.target_text_input.setPlaceholderText("直接输入文字或上传文本")
|
||||
self.target_text_button = QPushButton("上传文本")
|
||||
self.target_text_button.clicked.connect(self.upload_target_text)
|
||||
|
||||
self.target_language_label = QLabel("合成音频语言:")
|
||||
self.target_language_combobox = QComboBox()
|
||||
self.target_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
|
||||
self.target_language_combobox.setCurrentText("多语种混合")
|
||||
|
||||
self.output_label = QLabel("输出音频路径:")
|
||||
self.output_input = QLineEdit()
|
||||
self.output_input.setPlaceholderText("拖拽或选择文件")
|
||||
self.output_input.setReadOnly(True)
|
||||
self.output_button = QPushButton("选择文件夹")
|
||||
self.output_button.clicked.connect(self.select_output_path)
|
||||
|
||||
self.output_text = QTextEdit()
|
||||
self.output_text.setReadOnly(True)
|
||||
|
||||
self.add_drag_drop_events(
|
||||
[
|
||||
self.GPT_model_input,
|
||||
self.SoVITS_model_input,
|
||||
self.ref_audio_input,
|
||||
self.ref_text_input,
|
||||
self.target_text_input,
|
||||
self.output_input,
|
||||
]
|
||||
)
|
||||
|
||||
self.synthesize_button = QPushButton("合成")
|
||||
self.synthesize_button.clicked.connect(self.synthesize)
|
||||
|
||||
self.clear_output_button = QPushButton("清空输出")
|
||||
self.clear_output_button.clicked.connect(self.clear_output)
|
||||
|
||||
self.status_bar = QStatusBar()
|
||||
|
||||
main_layout = QVBoxLayout()
|
||||
|
||||
input_layout = QGridLayout(self)
|
||||
input_layout.setSpacing(10)
|
||||
|
||||
input_layout.addWidget(license_label, 0, 0, 1, 3)
|
||||
|
||||
input_layout.addWidget(self.GPT_model_label, 1, 0)
|
||||
input_layout.addWidget(self.GPT_model_input, 2, 0, 1, 2)
|
||||
input_layout.addWidget(self.GPT_model_button, 2, 2)
|
||||
|
||||
input_layout.addWidget(self.SoVITS_model_label, 3, 0)
|
||||
input_layout.addWidget(self.SoVITS_model_input, 4, 0, 1, 2)
|
||||
input_layout.addWidget(self.SoVITS_model_button, 4, 2)
|
||||
|
||||
input_layout.addWidget(self.ref_audio_label, 5, 0)
|
||||
input_layout.addWidget(self.ref_audio_input, 6, 0, 1, 2)
|
||||
input_layout.addWidget(self.ref_audio_button, 6, 2)
|
||||
|
||||
input_layout.addWidget(self.ref_language_label, 7, 0)
|
||||
input_layout.addWidget(self.ref_language_combobox, 8, 0, 1, 1)
|
||||
input_layout.addWidget(self.ref_text_label, 9, 0)
|
||||
input_layout.addWidget(self.ref_text_input, 10, 0, 1, 2)
|
||||
input_layout.addWidget(self.ref_text_button, 10, 2)
|
||||
|
||||
input_layout.addWidget(self.target_language_label, 11, 0)
|
||||
input_layout.addWidget(self.target_language_combobox, 12, 0, 1, 1)
|
||||
input_layout.addWidget(self.target_text_label, 13, 0)
|
||||
input_layout.addWidget(self.target_text_input, 14, 0, 1, 2)
|
||||
input_layout.addWidget(self.target_text_button, 14, 2)
|
||||
|
||||
input_layout.addWidget(self.output_label, 15, 0)
|
||||
input_layout.addWidget(self.output_input, 16, 0, 1, 2)
|
||||
input_layout.addWidget(self.output_button, 16, 2)
|
||||
|
||||
main_layout.addLayout(input_layout)
|
||||
|
||||
output_layout = QVBoxLayout()
|
||||
output_layout.addWidget(self.output_text)
|
||||
main_layout.addLayout(output_layout)
|
||||
|
||||
main_layout.addWidget(self.synthesize_button)
|
||||
|
||||
main_layout.addWidget(self.clear_output_button)
|
||||
|
||||
main_layout.addWidget(self.status_bar)
|
||||
|
||||
self.central_widget = QWidget()
|
||||
self.central_widget.setLayout(main_layout)
|
||||
self.setCentralWidget(self.central_widget)
|
||||
|
||||
def dragEnterEvent(self, event):
|
||||
if event.mimeData().hasUrls():
|
||||
event.acceptProposedAction()
|
||||
|
||||
def dropEvent(self, event):
|
||||
if event.mimeData().hasUrls():
|
||||
file_paths = [url.toLocalFile() for url in event.mimeData().urls()]
|
||||
if len(file_paths) == 1:
|
||||
self.update_ref_audio(file_paths[0])
|
||||
else:
|
||||
self.update_ref_audio(", ".join(file_paths))
|
||||
|
||||
def add_drag_drop_events(self, widgets):
|
||||
for widget in widgets:
|
||||
widget.setAcceptDrops(True)
|
||||
widget.installEventFilter(self)
|
||||
|
||||
def eventFilter(self, obj, event):
|
||||
if event.type() in (QEvent.DragEnter, QEvent.Drop):
|
||||
mime_data = event.mimeData()
|
||||
if mime_data.hasUrls():
|
||||
event.acceptProposedAction()
|
||||
|
||||
return super().eventFilter(obj, event)
|
||||
|
||||
def select_GPT_model(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择GPT模型文件", "", "GPT Files (*.ckpt)")
|
||||
if file_path:
|
||||
self.GPT_model_input.setText(file_path)
|
||||
|
||||
def select_SoVITS_model(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择SoVITS模型文件", "", "SoVITS Files (*.pth)")
|
||||
if file_path:
|
||||
self.SoVITS_model_input.setText(file_path)
|
||||
|
||||
def select_ref_audio(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择参考音频文件", "", "Audio Files (*.wav *.mp3)")
|
||||
if file_path:
|
||||
self.update_ref_audio(file_path)
|
||||
|
||||
def upload_ref_text(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
|
||||
if file_path:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
content = file.read()
|
||||
self.ref_text_input.setText(content)
|
||||
|
||||
def upload_target_text(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
|
||||
if file_path:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
content = file.read()
|
||||
self.target_text_input.setText(content)
|
||||
|
||||
def select_output_path(self):
|
||||
options = QFileDialog.Options()
|
||||
options |= QFileDialog.DontUseNativeDialog
|
||||
options |= QFileDialog.ShowDirsOnly
|
||||
|
||||
folder_dialog = QFileDialog()
|
||||
folder_dialog.setOptions(options)
|
||||
folder_dialog.setFileMode(QFileDialog.Directory)
|
||||
|
||||
if folder_dialog.exec_():
|
||||
folder_path = folder_dialog.selectedFiles()[0]
|
||||
self.output_input.setText(folder_path)
|
||||
|
||||
def update_ref_audio(self, file_path):
|
||||
self.ref_audio_input.setText(file_path)
|
||||
|
||||
def clear_output(self):
|
||||
self.output_text.clear()
|
||||
|
||||
def synthesize(self):
|
||||
GPT_model_path = self.GPT_model_input.text()
|
||||
SoVITS_model_path = self.SoVITS_model_input.text()
|
||||
ref_audio_path = self.ref_audio_input.text()
|
||||
language_combobox = self.ref_language_combobox.currentText()
|
||||
language_combobox = i18n(language_combobox)
|
||||
ref_text = self.ref_text_input.text()
|
||||
target_language_combobox = self.target_language_combobox.currentText()
|
||||
target_language_combobox = i18n(target_language_combobox)
|
||||
target_text = self.target_text_input.text()
|
||||
output_path = self.output_input.text()
|
||||
|
||||
if GPT_model_path != self.GPT_Path:
|
||||
change_gpt_weights(gpt_path=GPT_model_path)
|
||||
self.GPT_Path = GPT_model_path
|
||||
if SoVITS_model_path != self.SoVITS_Path:
|
||||
change_sovits_weights(sovits_path=SoVITS_model_path)
|
||||
self.SoVITS_Path = SoVITS_model_path
|
||||
|
||||
synthesis_result = get_tts_wav(
|
||||
ref_wav_path=ref_audio_path,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=language_combobox,
|
||||
text=target_text,
|
||||
text_language=target_language_combobox,
|
||||
)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
if result_list:
|
||||
last_sampling_rate, last_audio_data = result_list[-1]
|
||||
output_wav_path = os.path.join(output_path, "output.wav")
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
|
||||
result = "Audio saved to " + output_wav_path
|
||||
|
||||
self.status_bar.showMessage("合成完成!输出路径:" + output_wav_path, 5000)
|
||||
self.output_text.append("处理结果:\n" + result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QApplication(sys.argv)
|
||||
mainWin = GPTSoVITSGUI()
|
||||
mainWin.show()
|
||||
sys.exit(app.exec_())
|
||||
File diff suppressed because it is too large
Load Diff
@ -315,7 +315,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
with gr.Column():
|
||||
# with gr.Group():
|
||||
gr.Markdown(value=i18n("模型切换"))
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
GPT_dropdown = gr.Dropdown(
|
||||
label=i18n("GPT模型列表"),
|
||||
choices=sorted(GPT_names, key=custom_sort_key),
|
||||
@ -331,18 +331,22 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
|
||||
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column():
|
||||
gr.Markdown(value=i18n("*请上传并填写参考信息"))
|
||||
with gr.Row():
|
||||
inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath")
|
||||
with gr.Row(equal_height=True):
|
||||
inp_ref = gr.Audio(
|
||||
label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"),
|
||||
type="filepath",
|
||||
waveform_options={"show_recording_waveform": False},
|
||||
)
|
||||
inp_refs = gr.File(
|
||||
label=i18n("辅参考音频(可选多个,或不选)"),
|
||||
file_count="multiple",
|
||||
visible=True if model_version != "v3" else False,
|
||||
)
|
||||
prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
prompt_language = gr.Dropdown(
|
||||
label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
|
||||
)
|
||||
@ -368,26 +372,26 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
|
||||
with gr.Group():
|
||||
gr.Markdown(value=i18n("推理设置"))
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
batch_size = gr.Slider(
|
||||
minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
|
||||
)
|
||||
sample_steps = gr.Radio(
|
||||
label=i18n("采样步数(仅对V3/4生效)"), value=32, choices=[4, 8, 16, 32, 64, 128], visible=True
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
fragment_interval = gr.Slider(
|
||||
minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
|
||||
)
|
||||
speed_factor = gr.Slider(
|
||||
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
|
||||
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
temperature = gr.Slider(
|
||||
minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
|
||||
)
|
||||
@ -396,7 +400,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
how_to_cut = gr.Dropdown(
|
||||
label=i18n("怎么切"),
|
||||
choices=[
|
||||
@ -415,7 +419,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
|
||||
split_bucket = gr.Checkbox(
|
||||
label=i18n("数据分桶(并行推理时会降低一点计算量)"),
|
||||
@ -424,12 +428,15 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
show_label=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
seed = gr.Number(label=i18n("随机种子"), value=-1)
|
||||
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
|
||||
|
||||
output = gr.Audio(label=i18n("输出的语音"))
|
||||
with gr.Row():
|
||||
output = gr.Audio(
|
||||
label=i18n("输出的语音"),
|
||||
waveform_options={"show_recording_waveform": False},
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
inference_button = gr.Button(i18n("合成语音"), variant="primary")
|
||||
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
|
||||
|
||||
@ -485,7 +492,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
|
||||
)
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
|
||||
with gr.Column():
|
||||
_how_to_cut = gr.Radio(
|
||||
|
||||
@ -1,29 +1,32 @@
|
||||
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
|
||||
import os
|
||||
|
||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from AR.data.data_module import Text2SemanticDataModule
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from AR.utils.io import load_yaml_config
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
|
||||
from pytorch_lightning.strategies import DDPStrategy
|
||||
from pytorch_lightning.strategies import DDPStrategy, SingleDeviceStrategy
|
||||
|
||||
from GPT_SoVITS.AR.data.data_module import Text2SemanticDataModule
|
||||
from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from GPT_SoVITS.AR.utils import get_newest_ckpt
|
||||
from GPT_SoVITS.AR.utils.io import load_yaml_config
|
||||
from GPT_SoVITS.process_ckpt import my_save
|
||||
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
logging.getLogger("matplotlib").setLevel(logging.WARNING)
|
||||
torch.set_float32_matmul_precision("high")
|
||||
from collections import OrderedDict
|
||||
|
||||
from AR.utils import get_newest_ckpt
|
||||
from process_ckpt import my_save
|
||||
|
||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["USE_LIBUV"] = "0"
|
||||
|
||||
|
||||
class my_model_ckpt(ModelCheckpoint):
|
||||
@ -49,35 +52,30 @@ class my_model_ckpt(ModelCheckpoint):
|
||||
monitor_candidates = self._monitor_candidates(trainer)
|
||||
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
|
||||
if (
|
||||
self.if_save_latest == True
|
||||
self.if_save_latest is True
|
||||
): ####如果设置只保存最后一个ckpt,在保存下一个ckpt后要清理掉之前的所有ckpt
|
||||
to_clean = list(os.listdir(self.dirpath))
|
||||
self._save_topk_checkpoint(trainer, monitor_candidates)
|
||||
if self.if_save_latest == True:
|
||||
if self.if_save_latest is True:
|
||||
for name in to_clean:
|
||||
try:
|
||||
os.remove("%s/%s" % (self.dirpath, name))
|
||||
except:
|
||||
os.remove(f"{self.dirpath}/{name}")
|
||||
except Exception as _:
|
||||
pass
|
||||
if self.if_save_every_weights == True:
|
||||
if self.if_save_every_weights is True:
|
||||
to_save_od = OrderedDict()
|
||||
to_save_od["weight"] = OrderedDict()
|
||||
dictt = trainer.strategy._lightning_module.state_dict()
|
||||
for key in dictt:
|
||||
to_save_od["weight"][key] = dictt[key].half()
|
||||
to_save_od["config"] = self.config
|
||||
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
|
||||
to_save_od["info"] = f"GPT-e{trainer.current_epoch + 1}"
|
||||
# torch.save(
|
||||
# print(os.environ)
|
||||
if os.environ.get("LOCAL_RANK", "0") == "0":
|
||||
my_save(
|
||||
to_save_od,
|
||||
"%s/%s-e%s.ckpt"
|
||||
% (
|
||||
self.half_weights_save_dir,
|
||||
self.exp_name,
|
||||
trainer.current_epoch + 1,
|
||||
),
|
||||
f"{self.half_weights_save_dir}/{self.exp_name}-e{trainer.current_epoch + 1}.ckpt",
|
||||
)
|
||||
self._save_last_checkpoint(trainer, monitor_candidates)
|
||||
|
||||
@ -91,6 +89,14 @@ def main(args):
|
||||
ckpt_dir = output_dir / "ckpt"
|
||||
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.device_count() > 1:
|
||||
strategy = DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
|
||||
else:
|
||||
strategy = SingleDeviceStrategy("cuda")
|
||||
else:
|
||||
strategy = SingleDeviceStrategy("cpu")
|
||||
|
||||
seed_everything(config["train"]["seed"], workers=True)
|
||||
ckpt_callback: ModelCheckpoint = my_model_ckpt(
|
||||
config=config,
|
||||
@ -106,8 +112,7 @@ def main(args):
|
||||
dirpath=ckpt_dir,
|
||||
)
|
||||
logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["USE_LIBUV"] = "0"
|
||||
|
||||
trainer: Trainer = Trainer(
|
||||
max_epochs=config["train"]["epochs"],
|
||||
accelerator="gpu" if torch.cuda.is_available() else "cpu",
|
||||
@ -117,9 +122,7 @@ def main(args):
|
||||
devices=-1 if torch.cuda.is_available() else 1,
|
||||
benchmark=False,
|
||||
fast_dev_run=False,
|
||||
strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
|
||||
if torch.cuda.is_available()
|
||||
else "auto",
|
||||
strategy=strategy,
|
||||
precision=config["train"]["precision"],
|
||||
logger=logger,
|
||||
num_sanity_val_steps=0,
|
||||
|
||||
@ -1,40 +1,41 @@
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
# jieba静音
|
||||
import fast_langdetect
|
||||
import jieba
|
||||
from split_lang import LangSplitter
|
||||
|
||||
jieba.setLogLevel(logging.CRITICAL)
|
||||
|
||||
# 更改fast_langdetect大模型位置
|
||||
from pathlib import Path
|
||||
import fast_langdetect
|
||||
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
|
||||
|
||||
|
||||
from split_lang import LangSplitter
|
||||
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(
|
||||
fast_langdetect.infer.LangDetectConfig(
|
||||
cache_dir=str(Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect")
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def full_en(text):
|
||||
pattern = r'^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$'
|
||||
pattern = r"^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$"
|
||||
return bool(re.match(pattern, text))
|
||||
|
||||
|
||||
def full_cjk(text):
|
||||
# 来自wiki
|
||||
cjk_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DB5), # CJK Extension A
|
||||
(0x20000, 0x2A6DD), # CJK Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Extension F
|
||||
(0x30000, 0x3134A), # CJK Extension G
|
||||
(0x31350, 0x323AF), # CJK Extension H
|
||||
(0x2EBF0, 0x2EE5D), # CJK Extension H
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DB5), # CJK Extension A
|
||||
(0x20000, 0x2A6DD), # CJK Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Extension F
|
||||
(0x30000, 0x3134A), # CJK Extension G
|
||||
(0x31350, 0x323AF), # CJK Extension H
|
||||
(0x2EBF0, 0x2EE5D), # CJK Extension H
|
||||
]
|
||||
|
||||
pattern = r'[0-9、-〜。!?.!?… /]+$'
|
||||
pattern = r"[0-9、-〜。!?.!?… /]+$"
|
||||
|
||||
cjk_text = ""
|
||||
for char in text:
|
||||
@ -45,7 +46,7 @@ def full_cjk(text):
|
||||
return cjk_text
|
||||
|
||||
|
||||
def split_jako(tag_lang,item):
|
||||
def split_jako(tag_lang, item):
|
||||
if tag_lang == "ja":
|
||||
pattern = r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)"
|
||||
else:
|
||||
@ -53,41 +54,42 @@ def split_jako(tag_lang,item):
|
||||
|
||||
lang_list: list[dict] = []
|
||||
tag = 0
|
||||
for match in re.finditer(pattern, item['text']):
|
||||
for match in re.finditer(pattern, item["text"]):
|
||||
if match.start() > tag:
|
||||
lang_list.append({'lang':item['lang'],'text':item['text'][tag:match.start()]})
|
||||
lang_list.append({"lang": item["lang"], "text": item["text"][tag : match.start()]})
|
||||
|
||||
tag = match.end()
|
||||
lang_list.append({'lang':tag_lang,'text':item['text'][match.start():match.end()]})
|
||||
lang_list.append({"lang": tag_lang, "text": item["text"][match.start() : match.end()]})
|
||||
|
||||
if tag < len(item['text']):
|
||||
lang_list.append({'lang':item['lang'],'text':item['text'][tag:len(item['text'])]})
|
||||
if tag < len(item["text"]):
|
||||
lang_list.append({"lang": item["lang"], "text": item["text"][tag : len(item["text"])]})
|
||||
|
||||
return lang_list
|
||||
|
||||
|
||||
def merge_lang(lang_list, item):
|
||||
if lang_list and item['lang'] == lang_list[-1]['lang']:
|
||||
lang_list[-1]['text'] += item['text']
|
||||
if lang_list and item["lang"] == lang_list[-1]["lang"]:
|
||||
lang_list[-1]["text"] += item["text"]
|
||||
else:
|
||||
lang_list.append(item)
|
||||
return lang_list
|
||||
|
||||
|
||||
class LangSegmenter():
|
||||
class LangSegmenter:
|
||||
# 默认过滤器, 基于gsv目前四种语言
|
||||
DEFAULT_LANG_MAP = {
|
||||
"zh": "zh",
|
||||
"yue": "zh", # 粤语
|
||||
"wuu": "zh", # 吴语
|
||||
"zh-cn": "zh",
|
||||
"zh-tw": "x", # 繁体设置为x
|
||||
"zh-tw": "x", # 繁体设置为x
|
||||
"ko": "ko",
|
||||
"ja": "ja",
|
||||
"en": "en",
|
||||
}
|
||||
|
||||
def getTexts(text,default_lang = ""):
|
||||
@staticmethod
|
||||
def getTexts(text, default_lang=""):
|
||||
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
|
||||
lang_splitter.merge_across_digit = False
|
||||
substr = lang_splitter.split_by_lang(text=text)
|
||||
@ -97,31 +99,31 @@ class LangSegmenter():
|
||||
have_num = False
|
||||
|
||||
for _, item in enumerate(substr):
|
||||
dict_item = {'lang':item.lang,'text':item.text}
|
||||
dict_item = {"lang": item.lang, "text": item.text}
|
||||
|
||||
if dict_item['lang'] == 'digit':
|
||||
if dict_item["lang"] == "digit":
|
||||
if default_lang != "":
|
||||
dict_item['lang'] = default_lang
|
||||
dict_item["lang"] = default_lang
|
||||
else:
|
||||
have_num = True
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
continue
|
||||
|
||||
# 处理短英文被识别为其他语言的问题
|
||||
if full_en(dict_item['text']):
|
||||
dict_item['lang'] = 'en'
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
if full_en(dict_item["text"]):
|
||||
dict_item["lang"] = "en"
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
continue
|
||||
|
||||
if default_lang != "":
|
||||
dict_item['lang'] = default_lang
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
dict_item["lang"] = default_lang
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
continue
|
||||
else:
|
||||
# 处理非日语夹日文的问题(不包含CJK)
|
||||
ja_list: list[dict] = []
|
||||
if dict_item['lang'] != 'ja':
|
||||
ja_list = split_jako('ja',dict_item)
|
||||
if dict_item["lang"] != "ja":
|
||||
ja_list = split_jako("ja", dict_item)
|
||||
|
||||
if not ja_list:
|
||||
ja_list.append(dict_item)
|
||||
@ -130,8 +132,8 @@ class LangSegmenter():
|
||||
ko_list: list[dict] = []
|
||||
temp_list: list[dict] = []
|
||||
for _, ko_item in enumerate(ja_list):
|
||||
if ko_item["lang"] != 'ko':
|
||||
ko_list = split_jako('ko',ko_item)
|
||||
if ko_item["lang"] != "ko":
|
||||
ko_list = split_jako("ko", ko_item)
|
||||
|
||||
if ko_list:
|
||||
temp_list.extend(ko_list)
|
||||
@ -141,77 +143,76 @@ class LangSegmenter():
|
||||
# 未存在非日韩文夹日韩文
|
||||
if len(temp_list) == 1:
|
||||
# 未知语言检查是否为CJK
|
||||
if dict_item['lang'] == 'x':
|
||||
cjk_text = full_cjk(dict_item['text'])
|
||||
if dict_item["lang"] == "x":
|
||||
cjk_text = full_cjk(dict_item["text"])
|
||||
if cjk_text:
|
||||
dict_item = {'lang':'zh','text':cjk_text}
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
dict_item = {"lang": "zh", "text": cjk_text}
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
else:
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
continue
|
||||
else:
|
||||
lang_list = merge_lang(lang_list,dict_item)
|
||||
lang_list = merge_lang(lang_list, dict_item)
|
||||
continue
|
||||
|
||||
# 存在非日韩文夹日韩文
|
||||
for _, temp_item in enumerate(temp_list):
|
||||
# 未知语言检查是否为CJK
|
||||
if temp_item['lang'] == 'x':
|
||||
cjk_text = full_cjk(temp_item['text'])
|
||||
if temp_item["lang"] == "x":
|
||||
cjk_text = full_cjk(temp_item["text"])
|
||||
if cjk_text:
|
||||
lang_list = merge_lang(lang_list,{'lang':'zh','text':cjk_text})
|
||||
lang_list = merge_lang(lang_list, {"lang": "zh", "text": cjk_text})
|
||||
else:
|
||||
lang_list = merge_lang(lang_list,temp_item)
|
||||
lang_list = merge_lang(lang_list, temp_item)
|
||||
else:
|
||||
lang_list = merge_lang(lang_list,temp_item)
|
||||
lang_list = merge_lang(lang_list, temp_item)
|
||||
|
||||
# 有数字
|
||||
if have_num:
|
||||
temp_list = lang_list
|
||||
lang_list = []
|
||||
for i, temp_item in enumerate(temp_list):
|
||||
if temp_item['lang'] == 'digit':
|
||||
if temp_item["lang"] == "digit":
|
||||
if default_lang:
|
||||
temp_item['lang'] = default_lang
|
||||
temp_item["lang"] = default_lang
|
||||
elif lang_list and i == len(temp_list) - 1:
|
||||
temp_item['lang'] = lang_list[-1]['lang']
|
||||
temp_item["lang"] = lang_list[-1]["lang"]
|
||||
elif not lang_list and i < len(temp_list) - 1:
|
||||
temp_item['lang'] = temp_list[1]['lang']
|
||||
temp_item["lang"] = temp_list[1]["lang"]
|
||||
elif lang_list and i < len(temp_list) - 1:
|
||||
if lang_list[-1]['lang'] == temp_list[i + 1]['lang']:
|
||||
temp_item['lang'] = lang_list[-1]['lang']
|
||||
elif lang_list[-1]['text'][-1] in [",",".","!","?",",","。","!","?"]:
|
||||
temp_item['lang'] = temp_list[i + 1]['lang']
|
||||
elif temp_list[i + 1]['text'][0] in [",",".","!","?",",","。","!","?"]:
|
||||
temp_item['lang'] = lang_list[-1]['lang']
|
||||
elif temp_item['text'][-1] in ["。","."]:
|
||||
temp_item['lang'] = lang_list[-1]['lang']
|
||||
elif len(lang_list[-1]['text']) >= len(temp_list[i + 1]['text']):
|
||||
temp_item['lang'] = lang_list[-1]['lang']
|
||||
if lang_list[-1]["lang"] == temp_list[i + 1]["lang"]:
|
||||
temp_item["lang"] = lang_list[-1]["lang"]
|
||||
elif lang_list[-1]["text"][-1] in [",", ".", "!", "?", ",", "。", "!", "?"]:
|
||||
temp_item["lang"] = temp_list[i + 1]["lang"]
|
||||
elif temp_list[i + 1]["text"][0] in [",", ".", "!", "?", ",", "。", "!", "?"]:
|
||||
temp_item["lang"] = lang_list[-1]["lang"]
|
||||
elif temp_item["text"][-1] in ["。", "."]:
|
||||
temp_item["lang"] = lang_list[-1]["lang"]
|
||||
elif len(lang_list[-1]["text"]) >= len(temp_list[i + 1]["text"]):
|
||||
temp_item["lang"] = lang_list[-1]["lang"]
|
||||
else:
|
||||
temp_item['lang'] = temp_list[i + 1]['lang']
|
||||
temp_item["lang"] = temp_list[i + 1]["lang"]
|
||||
else:
|
||||
temp_item['lang'] = 'zh'
|
||||
|
||||
lang_list = merge_lang(lang_list,temp_item)
|
||||
temp_item["lang"] = "zh"
|
||||
|
||||
lang_list = merge_lang(lang_list, temp_item)
|
||||
|
||||
# 筛X
|
||||
temp_list = lang_list
|
||||
lang_list = []
|
||||
for _, temp_item in enumerate(temp_list):
|
||||
if temp_item['lang'] == 'x':
|
||||
if temp_item["lang"] == "x":
|
||||
if lang_list:
|
||||
temp_item['lang'] = lang_list[-1]['lang']
|
||||
temp_item["lang"] = lang_list[-1]["lang"]
|
||||
elif len(temp_list) > 1:
|
||||
temp_item['lang'] = temp_list[1]['lang']
|
||||
temp_item["lang"] = temp_list[1]["lang"]
|
||||
else:
|
||||
temp_item['lang'] = 'zh'
|
||||
temp_item["lang"] = "zh"
|
||||
|
||||
lang_list = merge_lang(lang_list,temp_item)
|
||||
lang_list = merge_lang(lang_list, temp_item)
|
||||
|
||||
return lang_list
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = "MyGO?,你也喜欢まいご吗?"
|
||||
@ -221,5 +222,5 @@ if __name__ == "__main__":
|
||||
print(LangSegmenter.getTexts(text))
|
||||
|
||||
text = "当时ThinkPad T60刚刚发布,一同推出的还有一款名为Advanced Dock的扩展坞配件。这款扩展坞通过连接T60底部的插槽,扩展出包括PCIe在内的一大堆接口,并且自带电源,让T60可以安装桌面显卡来提升性能。"
|
||||
print(LangSegmenter.getTexts(text,"zh"))
|
||||
print(LangSegmenter.getTexts(text))
|
||||
print(LangSegmenter.getTexts(text, "zh"))
|
||||
print(LangSegmenter.getTexts(text))
|
||||
|
||||
@ -248,13 +248,13 @@ if you want to switch to V1,then double-click`go-webui-v1.bat` or use `go-webui-
|
||||
#### Others
|
||||
|
||||
```bash
|
||||
python webui.py <language(optional)>
|
||||
PYTHONPATH=. python webui.py <language(optional)>
|
||||
```
|
||||
|
||||
if you want to switch to V1,then
|
||||
|
||||
```bash
|
||||
python webui.py v1 <language(optional)>
|
||||
PYTHONPATH=. python webui.py v1 <language(optional)>
|
||||
```
|
||||
|
||||
Or maunally switch version in WebUI
|
||||
@ -285,7 +285,7 @@ python GPT_SoVITS/inference_webui.py <language(optional)>
|
||||
OR
|
||||
|
||||
```bash
|
||||
python webui.py
|
||||
PYTHONPATH=. python webui.py
|
||||
```
|
||||
|
||||
then open the inference webui at `1-GPT-SoVITS-TTS/1C-inference`
|
||||
|
||||
21
config.py
21
config.py
@ -161,7 +161,7 @@ def get_device_dtype_sm(idx: int) -> tuple[torch.device, torch.dtype, float, flo
|
||||
is_16_series = bool(re.search(r"16\d{2}", name)) and sm_version == 7.5
|
||||
if mem_gb < 4 or sm_version < 5.3:
|
||||
return cpu, torch.float32, 0.0, 0.0
|
||||
if sm_version == 6.1 or is_16_series == True:
|
||||
if sm_version == 6.1 or is_16_series is True:
|
||||
return cuda, torch.float32, sm_version, mem_gb
|
||||
if sm_version > 6.1:
|
||||
return cuda, torch.float16, sm_version, mem_gb
|
||||
@ -216,3 +216,22 @@ class Config:
|
||||
self.webui_port_subfix = webui_port_subfix
|
||||
|
||||
self.api_port = api_port
|
||||
|
||||
|
||||
def get_implement(device: torch.device):
|
||||
if torch.cuda.is_available():
|
||||
idx = device.index
|
||||
capability = torch.cuda.get_device_capability(idx)
|
||||
major, minor = capability
|
||||
sm_version = major + minor / 10.0
|
||||
if sm_version >= 7.5:
|
||||
return "flash_attn"
|
||||
else:
|
||||
if sys.platform == "linux":
|
||||
return "sage_attn"
|
||||
else:
|
||||
return "naive"
|
||||
elif torch.mps.is_available():
|
||||
return "mlx"
|
||||
else:
|
||||
return "naive"
|
||||
|
||||
@ -236,13 +236,13 @@ D:\GPT-SoVITS\xxx/xxx.wav|xxx|zh|我爱玩原神.
|
||||
#### 其他
|
||||
|
||||
```bash
|
||||
python webui.py <language(optional)>
|
||||
PYTHONPATH=. python webui.py <language(optional)>
|
||||
```
|
||||
|
||||
若想使用 V1,则
|
||||
|
||||
```bash
|
||||
python webui.py v1 <language(optional)>
|
||||
PYTHONPATH=. python webui.py v1 <language(optional)>
|
||||
```
|
||||
|
||||
或者在 webUI 内动态切换
|
||||
@ -273,7 +273,7 @@ python GPT_SoVITS/inference_webui.py <language(optional)>
|
||||
或者
|
||||
|
||||
```bash
|
||||
python webui.py
|
||||
PYTHONPATH=. python webui.py
|
||||
```
|
||||
|
||||
然后在 `1-GPT-SoVITS-TTS/1C-推理` 中打开推理 webUI
|
||||
|
||||
@ -222,13 +222,13 @@ V1 に切り替えたい場合は、`go-webui-v1.bat`をダブルクリックす
|
||||
#### その他
|
||||
|
||||
```bash
|
||||
python webui.py <言語(オプション)>
|
||||
PYTHONPATH=. python webui.py <言語(オプション)>
|
||||
```
|
||||
|
||||
V1 に切り替えたい場合は
|
||||
|
||||
```bash
|
||||
python webui.py v1 <言語(オプション)>
|
||||
PYTHONPATH=. python webui.py v1 <言語(オプション)>
|
||||
```
|
||||
|
||||
または WebUI で手動でバージョンを切り替えてください.
|
||||
@ -259,7 +259,7 @@ python GPT_SoVITS/inference_webui.py <言語(オプション)>
|
||||
または
|
||||
|
||||
```bash
|
||||
python webui.py
|
||||
PYTHONPATH=. python webui.py
|
||||
```
|
||||
|
||||
その後、`1-GPT-SoVITS-TTS/1C-inference`で推論 webui を開きます.
|
||||
|
||||
@ -228,13 +228,13 @@ V1으로 전환하려면, `go-webui-v1.bat`을 더블 클릭하거나 `go-webui-
|
||||
#### 기타
|
||||
|
||||
```bash
|
||||
python webui.py <언어(옵션)>
|
||||
PYTHONPATH=. python webui.py <언어(옵션)>
|
||||
```
|
||||
|
||||
V1으로 전환하려면,
|
||||
|
||||
```bash
|
||||
python webui.py v1 <언어(옵션)>
|
||||
PYTHONPATH=. python webui.py v1 <언어(옵션)>
|
||||
```
|
||||
|
||||
또는 WebUI에서 수동으로 버전을 전환하십시오.
|
||||
@ -265,7 +265,7 @@ python GPT_SoVITS/inference_webui.py <언어(옵션)>
|
||||
또는
|
||||
|
||||
```bash
|
||||
python webui.py
|
||||
PYTHONPATH=. python webui.py
|
||||
```
|
||||
|
||||
그런 다음 `1-GPT-SoVITS-TTS/1C-inference`에서 추론 webui를 엽니다.
|
||||
|
||||
@ -229,13 +229,13 @@ V1'e geçmek istiyorsanız, `go-webui-v1.bat` dosyasına çift tıklayın veya `
|
||||
#### Diğerleri
|
||||
|
||||
```bash
|
||||
python webui.py <dil(isteğe bağlı)>
|
||||
PYTHONPATH=. python webui.py <dil(isteğe bağlı)>
|
||||
```
|
||||
|
||||
V1'e geçmek istiyorsanız,
|
||||
|
||||
```bash
|
||||
python webui.py v1 <dil(isteğe bağlı)>
|
||||
PYTHONPATH=. python webui.py v1 <dil(isteğe bağlı)>
|
||||
```
|
||||
|
||||
veya WebUI'de manuel olarak sürüm değiştirin.
|
||||
@ -266,7 +266,7 @@ python GPT_SoVITS/inference_webui.py <dil(isteğe bağlı)>
|
||||
VEYA
|
||||
|
||||
```bash
|
||||
python webui.py
|
||||
PYTHONPATH=. python webui.py
|
||||
```
|
||||
|
||||
ardından çıkarım webui'sini `1-GPT-SoVITS-TTS/1C-inference` adresinde açın.
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
set "SCRIPT_DIR=%~dp0"
|
||||
set "SCRIPT_DIR=%SCRIPT_DIR:~0,-1%"
|
||||
cd /d "%SCRIPT_DIR%"
|
||||
set "PATH=%SCRIPT_DIR%\runtime;%PATH%"
|
||||
set "PATH=%SCRIPT_DIR%\runtime"
|
||||
set "PYTHONPATH=%SCRIPT_DIR%"
|
||||
runtime\python.exe -I webui.py zh_CN
|
||||
pause
|
||||
|
||||
@ -2,6 +2,7 @@ $ErrorActionPreference = "SilentlyContinue"
|
||||
chcp 65001
|
||||
Set-Location $PSScriptRoot
|
||||
$runtimePath = Join-Path $PSScriptRoot "runtime"
|
||||
$env:PATH = "$runtimePath;$env:PATH"
|
||||
$env:PATH = "$runtimePath"
|
||||
$env:PYTHONPATH = "$runtimePath"
|
||||
& "$runtimePath\python.exe" -I "$PSScriptRoot\webui.py" zh_CN
|
||||
pause
|
||||
|
||||
@ -1,243 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9fd922fb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Deprecated"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "45857cb2",
|
||||
"metadata": {
|
||||
"_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
|
||||
"_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-02-18T14:43:46.735480Z",
|
||||
"iopub.status.busy": "2024-02-18T14:43:46.735183Z",
|
||||
"iopub.status.idle": "2024-02-18T14:48:10.724175Z",
|
||||
"shell.execute_reply": "2024-02-18T14:48:10.723059Z"
|
||||
},
|
||||
"papermill": {
|
||||
"duration": 263.994935,
|
||||
"end_time": "2024-02-18T14:48:10.726613",
|
||||
"exception": false,
|
||||
"start_time": "2024-02-18T14:43:46.731678",
|
||||
"status": "completed"
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!git clone https://github.com/RVC-Boss/GPT-SoVITS.git\n",
|
||||
"%cd GPT-SoVITS\n",
|
||||
"!apt-get update && apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && git lfs install\n",
|
||||
"!pip install -r requirements.txt\n",
|
||||
"!pip install -r extra-req.txt --no-deps"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b9d346b4",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-02-18T14:48:10.815802Z",
|
||||
"iopub.status.busy": "2024-02-18T14:48:10.814899Z",
|
||||
"iopub.status.idle": "2024-02-18T14:50:31.253276Z",
|
||||
"shell.execute_reply": "2024-02-18T14:50:31.252024Z"
|
||||
},
|
||||
"papermill": {
|
||||
"duration": 140.484893,
|
||||
"end_time": "2024-02-18T14:50:31.255720",
|
||||
"exception": false,
|
||||
"start_time": "2024-02-18T14:48:10.770827",
|
||||
"status": "completed"
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# @title Download pretrained models 下载预训练模型\n",
|
||||
"!mkdir -p /kaggle/working/GPT-SoVITS/GPT_SoVITS/pretrained_models\n",
|
||||
"!mkdir -p /kaggle/working/GPT-SoVITS/tools/asr/models\n",
|
||||
"!mkdir -p /kaggle/working/GPT-SoVITS/tools/uvr5\n",
|
||||
"%cd /kaggle/working/GPT-SoVITS/GPT_SoVITS/pretrained_models\n",
|
||||
"!git clone https://huggingface.co/lj1995/GPT-SoVITS\n",
|
||||
"%cd /kaggle/working/GPT-SoVITS/tools/asr/models\n",
|
||||
"!git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git\n",
|
||||
"!git clone https://www.modelscope.cn/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch.git\n",
|
||||
"!git clone https://www.modelscope.cn/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch.git\n",
|
||||
"# # @title UVR5 pretrains 安装uvr5模型\n",
|
||||
"%cd /kaggle/working/GPT-SoVITS/tools/uvr5\n",
|
||||
"!git clone https://huggingface.co/Delik/uvr5_weights\n",
|
||||
"!git config core.sparseCheckout true\n",
|
||||
"!mv /kaggle/working/GPT-SoVITS/GPT_SoVITS/pretrained_models/GPT-SoVITS/* /kaggle/working/GPT-SoVITS/GPT_SoVITS/pretrained_models/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ea94d245",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-02-18T14:29:01.071549Z",
|
||||
"iopub.status.busy": "2024-02-18T14:29:01.070592Z",
|
||||
"iopub.status.idle": "2024-02-18T14:40:45.318368Z",
|
||||
"shell.execute_reply": "2024-02-18T14:40:45.317130Z",
|
||||
"shell.execute_reply.started": "2024-02-18T14:29:01.071512Z"
|
||||
},
|
||||
"papermill": {
|
||||
"duration": null,
|
||||
"end_time": null,
|
||||
"exception": false,
|
||||
"start_time": "2024-02-18T14:50:31.309013",
|
||||
"status": "running"
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# @title launch WebUI 启动WebUI\n",
|
||||
"%cd /kaggle/working/GPT-SoVITS/\n",
|
||||
"!npm install -g localtunnel\n",
|
||||
"import subprocess\n",
|
||||
"import threading\n",
|
||||
"import time\n",
|
||||
"import socket\n",
|
||||
"import urllib.request\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def iframe_thread(port):\n",
|
||||
" while True:\n",
|
||||
" time.sleep(0.5)\n",
|
||||
" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
|
||||
" result = sock.connect_ex((\"127.0.0.1\", port))\n",
|
||||
" if result == 0:\n",
|
||||
" break\n",
|
||||
" sock.close()\n",
|
||||
"\n",
|
||||
" from colorama import Fore, Style\n",
|
||||
" print(\n",
|
||||
" Fore.GREEN + \"\\nIP: \",\n",
|
||||
" Fore.RED,\n",
|
||||
" urllib.request.urlopen(\"https://ipv4.icanhazip.com\").read().decode(\"utf8\").strip(\"\\n\"),\n",
|
||||
" \"\\n\",\n",
|
||||
" Style.RESET_ALL,\n",
|
||||
" )\n",
|
||||
" p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
|
||||
" for line in p.stdout:\n",
|
||||
" print(line.decode(), end=\"\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"threading.Thread(target=iframe_thread, daemon=True, args=(9874,)).start()\n",
|
||||
"\n",
|
||||
"!python webui.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dda88a6d",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-02-18T14:40:56.880608Z",
|
||||
"iopub.status.busy": "2024-02-18T14:40:56.879879Z"
|
||||
},
|
||||
"papermill": {
|
||||
"duration": null,
|
||||
"end_time": null,
|
||||
"exception": null,
|
||||
"start_time": null,
|
||||
"status": "pending"
|
||||
},
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 开启推理页面\n",
|
||||
"%cd /kaggle/working/GPT-SoVITS/\n",
|
||||
"!npm install -g localtunnel\n",
|
||||
"import threading\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def iframe_thread(port):\n",
|
||||
" while True:\n",
|
||||
" time.sleep(0.5)\n",
|
||||
" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n",
|
||||
" result = sock.connect_ex((\"127.0.0.1\", port))\n",
|
||||
" if result == 0:\n",
|
||||
" break\n",
|
||||
" sock.close()\n",
|
||||
"\n",
|
||||
" from colorama import Fore, Style\n",
|
||||
" print(\n",
|
||||
" Fore.GREEN + \"\\nIP: \",\n",
|
||||
" Fore.RED,\n",
|
||||
" urllib.request.urlopen(\"https://ipv4.icanhazip.com\").read().decode(\"utf8\").strip(\"\\n\"),\n",
|
||||
" \"\\n\",\n",
|
||||
" Style.RESET_ALL,\n",
|
||||
" )\n",
|
||||
" p = subprocess.Popen([\"lt\", \"--port\", \"{}\".format(port)], stdout=subprocess.PIPE)\n",
|
||||
" for line in p.stdout:\n",
|
||||
" print(line.decode(), end=\"\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"threading.Thread(target=iframe_thread, daemon=True, args=(9872,)).start()\n",
|
||||
"\n",
|
||||
"!python ./GPT_SoVITS/inference_webui.py"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kaggle": {
|
||||
"accelerator": "nvidiaTeslaT4",
|
||||
"dataSources": [
|
||||
{
|
||||
"datasetId": 4459328,
|
||||
"sourceId": 7649639,
|
||||
"sourceType": "datasetVersion"
|
||||
}
|
||||
],
|
||||
"dockerImageVersionId": 30646,
|
||||
"isGpuEnabled": true,
|
||||
"isInternetEnabled": true,
|
||||
"language": "python",
|
||||
"sourceType": "notebook"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.13"
|
||||
},
|
||||
"papermill": {
|
||||
"default_parameters": {},
|
||||
"duration": null,
|
||||
"end_time": null,
|
||||
"environment_variables": {},
|
||||
"exception": null,
|
||||
"input_path": "__notebook__.ipynb",
|
||||
"output_path": "__notebook__.ipynb",
|
||||
"parameters": {},
|
||||
"start_time": "2024-02-18T14:43:44.011910",
|
||||
"version": "2.5.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
24
install.ps1
24
install.ps1
@ -40,6 +40,10 @@ function Write-Info($msg) {
|
||||
Write-Host "[INFO]:" -ForegroundColor Green -NoNewline
|
||||
Write-Host " $msg"
|
||||
}
|
||||
function Write-Warning($msg) {
|
||||
Write-Host "[Warning]:" -ForegroundColor Yellow -NoNewline
|
||||
Write-Host " $msg"
|
||||
}
|
||||
function Write-Success($msg) {
|
||||
Write-Host "[SUCCESS]:" -ForegroundColor Blue -NoNewline
|
||||
Write-Host " $msg"
|
||||
@ -137,7 +141,7 @@ chcp 65001
|
||||
Set-Location $PSScriptRoot
|
||||
|
||||
Write-Info "Installing FFmpeg & CMake..."
|
||||
Invoke-Conda ffmpeg cmake
|
||||
Invoke-Conda ffmpeg cmake vc14_runtime
|
||||
Write-Success "FFmpeg & CMake Installed"
|
||||
|
||||
$PretrainedURL = ""
|
||||
@ -208,12 +212,30 @@ if ($DownloadUVR5) {
|
||||
|
||||
switch ($Device) {
|
||||
"CU128" {
|
||||
$cudaLine = nvidia-smi | Select-String "CUDA Version"
|
||||
$version = ($cudaLine -split "CUDA Version:")[1].Trim()
|
||||
Write-Info "Maximum CUDA Version Supported By Current Driver: $version"
|
||||
if ([version](nvidia-smi | Select-String "CUDA Version" | ForEach-Object { ($_ -split "CUDA Version:")[1].Trim() }) -ge [version]"12.8") {
|
||||
Write-Warning "CUDA 12.8 Is Not Supported By Current Driver"
|
||||
}
|
||||
Write-Info "Installing PyTorch For CUDA 12.8..."
|
||||
Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cu128"
|
||||
Invoke-Conda cuda-nvcc=12.8
|
||||
Invoke-Pip psutil ninja packaging wheel "setuptools>=42"
|
||||
Invoke-Pip flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
|
||||
}
|
||||
"CU126" {
|
||||
$cudaLine = nvidia-smi | Select-String "CUDA Version"
|
||||
$version = ($cudaLine -split "CUDA Version:")[1].Trim()
|
||||
Write-Info "Maximum CUDA Version Supported By Current Driver: $version"
|
||||
if ([version](nvidia-smi | Select-String "CUDA Version" | ForEach-Object { ($_ -split "CUDA Version:")[1].Trim() }) -ge [version]"12.8") {
|
||||
Write-Warning "CUDA 12.6 Is Not Supported By Current Driver"
|
||||
}
|
||||
Write-Info "Installing PyTorch For CUDA 12.6..."
|
||||
Invoke-Pip torch torchaudio --index-url "https://download.pytorch.org/whl/cu126"
|
||||
Invoke-Conda cuda-nvcc=12.6
|
||||
Invoke-Pip psutil ninja packaging wheel "setuptools>=42"
|
||||
Invoke-Pip flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
|
||||
}
|
||||
"CPU" {
|
||||
Write-Info "Installing PyTorch For CPU..."
|
||||
|
||||
20
install.sh
20
install.sh
@ -127,7 +127,7 @@ while [[ $# -gt 0 ]]; do
|
||||
USE_ROCM=true
|
||||
;;
|
||||
MPS)
|
||||
USE_CPU=true
|
||||
USE_MPS=true
|
||||
;;
|
||||
CPU)
|
||||
USE_CPU=true
|
||||
@ -157,7 +157,7 @@ while [[ $# -gt 0 ]]; do
|
||||
esac
|
||||
done
|
||||
|
||||
if ! $USE_CUDA && ! $USE_ROCM && ! $USE_CPU; then
|
||||
if ! $USE_CUDA && ! $USE_ROCM && ! $USE_MPS && ! $USE_CPU; then
|
||||
echo -e "${ERROR}Error: Device is REQUIRED"
|
||||
echo ""
|
||||
print_help
|
||||
@ -322,13 +322,29 @@ if [ "$USE_ROCM" = true ] && [ "$WORKFLOW" = false ]; then
|
||||
fi
|
||||
|
||||
if [ "$USE_CUDA" = true ] && [ "$WORKFLOW" = false ]; then
|
||||
CUDAVERSION=$(nvidia-smi | grep "CUDA Version" | sed -E 's/.*CUDA Version: ([0-9]+\.[0-9]+).*/\1/')
|
||||
echo -e "${INFO}Maximum CUDA Version Supported By Current Driver: $CUDAVERSION"
|
||||
if [ "$CUDA" = 128 ]; then
|
||||
if awk "BEGIN {exit !($CUDAVERSION < 12.8)}"; then
|
||||
echo -r "${WARNING}CUDA 12.8 Is Not Supported By Current Driver"
|
||||
fi
|
||||
echo -e "${INFO}Installing PyTorch For CUDA 12.8..."
|
||||
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cu128"
|
||||
run_conda_quiet cuda-nvcc=12.8
|
||||
elif [ "$CUDA" = 126 ]; then
|
||||
if awk "BEGIN {exit !($CUDAVERSION < 12.6)}"; then
|
||||
echo -r "${WARNING}CUDA 12.6 Is Not Supported By Current Driver"
|
||||
fi
|
||||
echo -e "${INFO}Installing PyTorch For CUDA 12.6..."
|
||||
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cu126"
|
||||
run_conda_quiet cuda-nvcc=12.6
|
||||
fi
|
||||
run_pip_quiet psutil ninja packaging wheel "setuptools>=42"
|
||||
run_pip_quiet flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
|
||||
elif [ "$USE_MPS" = true ] && [ "$WORKFLOW" = false ]; then
|
||||
echo -e "${INFO}Installing PyTorch For MPS..."
|
||||
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/cpu"
|
||||
run_pip_quiet mlx
|
||||
elif [ "$USE_ROCM" = true ] && [ "$WORKFLOW" = false ]; then
|
||||
echo -e "${INFO}Installing PyTorch For ROCm 6.2..."
|
||||
run_pip_quiet torch torchaudio --index-url "https://download.pytorch.org/whl/rocm6.2"
|
||||
|
||||
@ -5,7 +5,7 @@ tensorboard
|
||||
librosa==0.10.2
|
||||
numba
|
||||
pytorch-lightning>=2.4
|
||||
gradio<5
|
||||
gradio==5.25.0
|
||||
ffmpeg-python
|
||||
onnxruntime; platform_machine == "aarch64" or platform_machine == "arm64"
|
||||
onnxruntime-gpu; platform_machine == "x86_64" or platform_machine == "AMD64"
|
||||
@ -16,9 +16,11 @@ pypinyin
|
||||
pyopenjtalk>=0.4.1
|
||||
g2p_en
|
||||
torchaudio
|
||||
modelscope==1.10.0
|
||||
modelscope
|
||||
sentencepiece
|
||||
transformers>=4.43,<=4.50
|
||||
transformers
|
||||
huggingface_hub
|
||||
kernels
|
||||
peft
|
||||
chardet
|
||||
PyYAML
|
||||
@ -39,7 +41,6 @@ x_transformers
|
||||
torchmetrics<=1.5
|
||||
pydantic<=2.10.6
|
||||
ctranslate2>=4.0,<5
|
||||
huggingface_hub>=0.13
|
||||
tokenizers>=0.13,<1
|
||||
av>=11
|
||||
tqdm
|
||||
|
||||
@ -222,5 +222,6 @@
|
||||
"预训练SoVITS-D模型路径": "Pretrained SoVITS-D Model Path",
|
||||
"预训练SoVITS-G模型路径": "Pretrained SoVITS-G Model Path",
|
||||
"预训练中文BERT模型路径": "Pretrained Chinese BERT Model Path",
|
||||
"预训练模型路径": "Pretrained Model Path"
|
||||
"预训练模型路径": "Pretrained Model Path",
|
||||
"推理后端": "Inference Backend"
|
||||
}
|
||||
|
||||
@ -222,5 +222,6 @@
|
||||
"预训练SoVITS-D模型路径": "预训练SoVITS-D模型路径",
|
||||
"预训练SoVITS-G模型路径": "预训练SoVITS-G模型路径",
|
||||
"预训练中文BERT模型路径": "预训练中文BERT模型路径",
|
||||
"预训练模型路径": "预训练模型路径"
|
||||
"预训练模型路径": "预训练模型路径",
|
||||
"推理后端": "推理后端"
|
||||
}
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import sys
|
||||
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else "Auto"
|
||||
@ -314,7 +315,7 @@ if __name__ == "__main__":
|
||||
"Submit Text: 将当前页所有文本框内容手工保存到内存和文件(翻页前后或者退出标注页面前如果没点这个按钮,你再翻回来就回滚了,白忙活。)"
|
||||
)
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
btn_change_index = gr.Button("Change Index")
|
||||
btn_submit_change = gr.Button("Submit Text")
|
||||
btn_merge_audio = gr.Button("Merge Audio")
|
||||
@ -322,7 +323,7 @@ if __name__ == "__main__":
|
||||
btn_previous_index = gr.Button("Previous Index")
|
||||
btn_next_index = gr.Button("Next Index")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
index_slider = gr.Slider(minimum=0, maximum=g_max_json_index, value=g_index, step=1, label="Index", scale=3)
|
||||
splitpoint_slider = gr.Slider(
|
||||
minimum=0, maximum=120.0, value=0, step=0.1, label="Audio Split Point(s)", scale=3
|
||||
@ -331,18 +332,23 @@ if __name__ == "__main__":
|
||||
btn_save_json = gr.Button("Save File", visible=True, scale=1)
|
||||
btn_invert_selection = gr.Button("Invert Selection", scale=1)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column():
|
||||
for _ in range(0, g_batch):
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
text = gr.Textbox(label="Text", visible=True, scale=5)
|
||||
audio_output = gr.Audio(label="Output Audio", visible=True, scale=5)
|
||||
audio_output = gr.Audio(
|
||||
label="Output Audio",
|
||||
visible=True,
|
||||
scale=5,
|
||||
waveform_options={"show_recording_waveform": False},
|
||||
)
|
||||
audio_check = gr.Checkbox(label="Yes", show_label=True, info="Choose Audio", scale=1)
|
||||
g_text_list.append(text)
|
||||
g_audio_list.append(audio_output)
|
||||
g_checkbox_list.append(audio_check)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
batchsize_slider = gr.Slider(
|
||||
minimum=1, maximum=g_batch, value=g_batch, step=1, label="Batch Size", scale=3, interactive=False
|
||||
)
|
||||
|
||||
@ -168,7 +168,7 @@ with gr.Blocks(title="UVR5 WebUI", analytics_enabled=False) as app:
|
||||
"h4",
|
||||
)
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column():
|
||||
model_choose = gr.Dropdown(label=i18n("模型"), choices=uvr5_names)
|
||||
dir_wav_input = gr.Textbox(
|
||||
@ -197,9 +197,9 @@ with gr.Blocks(title="UVR5 WebUI", analytics_enabled=False) as app:
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
but2 = gr.Button(i18n("转换"), variant="primary")
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
vc_output4 = gr.Textbox(label=i18n("输出信息"), lines=3)
|
||||
but2.click(
|
||||
uvr,
|
||||
|
||||
586
webui.py
586
webui.py
@ -1,22 +1,96 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
os.environ["version"] = version = "v2Pro"
|
||||
now_dir = os.getcwd()
|
||||
sys.path.insert(0, now_dir)
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shutil
|
||||
import signal
|
||||
import site
|
||||
import subprocess
|
||||
import traceback
|
||||
import warnings
|
||||
from multiprocessing import cpu_count
|
||||
from subprocess import Popen
|
||||
|
||||
import gradio as gr
|
||||
import psutil
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from config import (
|
||||
GPU_INDEX,
|
||||
GPU_INFOS,
|
||||
IS_GPU,
|
||||
GPT_weight_root,
|
||||
GPT_weight_version2root,
|
||||
SoVITS_weight_root,
|
||||
SoVITS_weight_version2root,
|
||||
change_choices,
|
||||
exp_root,
|
||||
get_weights_names,
|
||||
infer_device,
|
||||
is_half,
|
||||
is_share,
|
||||
memset,
|
||||
pretrained_gpt_name,
|
||||
pretrained_sovits_name,
|
||||
python_exec,
|
||||
webui_port_infer_tts,
|
||||
webui_port_main,
|
||||
webui_port_subfix,
|
||||
webui_port_uvr5,
|
||||
)
|
||||
from GPT_SoVITS.Accelerate import backends
|
||||
from tools import my_utils
|
||||
from tools.asr.config import asr_dict
|
||||
from tools.assets import css, js, top_html
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
from tools.my_utils import check_details, check_for_existance
|
||||
|
||||
os.environ["PYTHONPATH"] = now_dir = os.getcwd()
|
||||
|
||||
backends_gradio = [(b.replace("-", " "), b) for b in backends]
|
||||
|
||||
_LANG_RE = re.compile(r"^[a-z]{2}[_-][A-Z]{2}$")
|
||||
|
||||
|
||||
def lang_type(text: str) -> str:
|
||||
if text == "Auto":
|
||||
return text
|
||||
if not _LANG_RE.match(text):
|
||||
raise argparse.ArgumentTypeError(f"Unspported Format: {text}, Expected ll_CC/ll-CC")
|
||||
ll, cc = re.split(r"[_-]", text)
|
||||
language = f"{ll}_{cc}"
|
||||
if language in scan_language_list():
|
||||
return language
|
||||
else:
|
||||
return "en_US"
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(
|
||||
prog="train_webui",
|
||||
description="python -s webui.py zh_CN",
|
||||
)
|
||||
p.add_argument(
|
||||
"language",
|
||||
nargs="?",
|
||||
default="Auto",
|
||||
type=lang_type,
|
||||
help="Language Code, Such as zh_CN, en-US",
|
||||
)
|
||||
return p
|
||||
|
||||
|
||||
args = build_parser().parse_args()
|
||||
|
||||
os.environ["version"] = version = "v2Pro"
|
||||
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
torch.manual_seed(233333)
|
||||
tmp = os.path.join(now_dir, "TEMP")
|
||||
os.makedirs(tmp, exist_ok=True)
|
||||
@ -32,8 +106,6 @@ if os.path.exists(tmp):
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
pass
|
||||
import site
|
||||
import traceback
|
||||
|
||||
site_packages_roots = []
|
||||
for path in site.getsitepackages():
|
||||
@ -41,7 +113,6 @@ for path in site.getsitepackages():
|
||||
site_packages_roots.append(path)
|
||||
if site_packages_roots == []:
|
||||
site_packages_roots = ["%s/runtime/Lib/site-packages" % now_dir]
|
||||
# os.environ["OPENBLAS_NUM_THREADS"] = "4"
|
||||
os.environ["no_proxy"] = "localhost, 127.0.0.1, ::1"
|
||||
os.environ["all_proxy"] = ""
|
||||
for site_packages_root in site_packages_roots:
|
||||
@ -56,41 +127,10 @@ for site_packages_root in site_packages_roots:
|
||||
break
|
||||
except PermissionError:
|
||||
traceback.print_exc()
|
||||
import shutil
|
||||
import subprocess
|
||||
from subprocess import Popen
|
||||
|
||||
from tools.assets import css, js, top_html
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else "Auto"
|
||||
os.environ["language"] = language
|
||||
language = args.language
|
||||
i18n = I18nAuto(language=language)
|
||||
from multiprocessing import cpu_count
|
||||
|
||||
from config import (
|
||||
GPU_INDEX,
|
||||
GPU_INFOS,
|
||||
IS_GPU,
|
||||
exp_root,
|
||||
infer_device,
|
||||
is_half,
|
||||
is_share,
|
||||
memset,
|
||||
python_exec,
|
||||
webui_port_infer_tts,
|
||||
webui_port_main,
|
||||
webui_port_subfix,
|
||||
webui_port_uvr5,
|
||||
)
|
||||
from tools import my_utils
|
||||
from tools.my_utils import check_details, check_for_existance
|
||||
|
||||
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
|
||||
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 当遇到mps不支持的步骤时使用cpu
|
||||
import gradio as gr
|
||||
|
||||
n_cpu = cpu_count()
|
||||
|
||||
@ -147,7 +187,7 @@ def fix_gpu_number(input): # 将越界的number强制改到界内
|
||||
try:
|
||||
if int(input) not in set_gpu_numbers:
|
||||
return default_gpu_numbers
|
||||
except:
|
||||
except Exception as _:
|
||||
return input
|
||||
return input
|
||||
|
||||
@ -158,13 +198,10 @@ def fix_gpu_numbers(inputs):
|
||||
for input in inputs.split(","):
|
||||
output.append(str(fix_gpu_number(input)))
|
||||
return ",".join(output)
|
||||
except:
|
||||
except Exception as _:
|
||||
return inputs
|
||||
|
||||
|
||||
from config import pretrained_gpt_name, pretrained_sovits_name
|
||||
|
||||
|
||||
def check_pretrained_is_exist(version):
|
||||
pretrained_model_list = (
|
||||
pretrained_sovits_name[version],
|
||||
@ -189,14 +226,6 @@ for key in pretrained_gpt_name.keys():
|
||||
if os.path.exists(pretrained_gpt_name[key]) == False:
|
||||
pretrained_gpt_name[key] = ""
|
||||
|
||||
from config import (
|
||||
GPT_weight_root,
|
||||
GPT_weight_version2root,
|
||||
SoVITS_weight_root,
|
||||
SoVITS_weight_version2root,
|
||||
change_choices,
|
||||
get_weights_names,
|
||||
)
|
||||
|
||||
for root in SoVITS_weight_root + GPT_weight_root:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
@ -218,15 +247,11 @@ def kill_proc_tree(pid, including_parent=True):
|
||||
|
||||
children = parent.children(recursive=True)
|
||||
for child in children:
|
||||
try:
|
||||
with contextlib.suppress(OSError):
|
||||
os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
|
||||
except OSError:
|
||||
pass
|
||||
if including_parent:
|
||||
try:
|
||||
with contextlib.suppress(OSError):
|
||||
os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
system = platform.system()
|
||||
@ -329,21 +354,20 @@ def change_uvr5():
|
||||
process_name_tts = i18n("TTS推理WebUI")
|
||||
|
||||
|
||||
def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, sovits_path, batched_infer_enabled):
|
||||
def change_tts_inference(
|
||||
bert_path, cnhubert_base_path, gpu_number, gpt_path, sovits_path, batched_infer_enabled, backends_dropdown
|
||||
):
|
||||
global p_tts_inference
|
||||
if batched_infer_enabled:
|
||||
cmd = '"%s" -s GPT_SoVITS/inference_webui_fast.py "%s"' % (python_exec, language)
|
||||
cmd = f"'{python_exec}' -s GPT_SoVITS/inference_webui_fast.py {language}"
|
||||
else:
|
||||
cmd = '"%s" -s GPT_SoVITS/inference_webui.py "%s"' % (python_exec, language)
|
||||
# #####v3暂不支持加速推理
|
||||
# if version=="v3":
|
||||
# cmd = '"%s" GPT_SoVITS/inference_webui.py "%s"'%(python_exec, language)
|
||||
cmd = f"'{python_exec}' -s GPT_SoVITS/inference_webui.py {language} -b {backends_dropdown}"
|
||||
if p_tts_inference is None:
|
||||
os.environ["gpt_path"] = gpt_path
|
||||
os.environ["sovits_path"] = sovits_path
|
||||
os.environ["cnhubert_base_path"] = cnhubert_base_path
|
||||
os.environ["bert_path"] = bert_path
|
||||
os.environ["_CUDA_VISIBLE_DEVICES"] = fix_gpu_number(gpu_number)
|
||||
os.environ["_CUDA_VISIBLE_DEVICES"] = str(fix_gpu_number(gpu_number))
|
||||
os.environ["is_half"] = str(is_half)
|
||||
os.environ["infer_ttswebui"] = str(webui_port_infer_tts)
|
||||
os.environ["is_share"] = str(is_share)
|
||||
@ -364,8 +388,6 @@ def change_tts_inference(bert_path, cnhubert_base_path, gpu_number, gpt_path, so
|
||||
)
|
||||
|
||||
|
||||
from tools.asr.config import asr_dict
|
||||
|
||||
process_name_asr = i18n("语音识别")
|
||||
|
||||
|
||||
@ -764,7 +786,7 @@ def close_slice():
|
||||
for p_slice in ps_slice:
|
||||
try:
|
||||
kill_process(p_slice.pid, process_name_slice)
|
||||
except:
|
||||
except Exception as _:
|
||||
traceback.print_exc()
|
||||
ps_slice = []
|
||||
return (
|
||||
@ -853,7 +875,7 @@ def close1a():
|
||||
for p1a in ps1a:
|
||||
try:
|
||||
kill_process(p1a.pid, process_name_1a)
|
||||
except:
|
||||
except Exception as _:
|
||||
traceback.print_exc()
|
||||
ps1a = []
|
||||
return (
|
||||
@ -944,7 +966,7 @@ def close1b():
|
||||
for p1b in ps1b:
|
||||
try:
|
||||
kill_process(p1b.pid, process_name_1b)
|
||||
except:
|
||||
except Exception as _:
|
||||
traceback.print_exc()
|
||||
ps1b = []
|
||||
return (
|
||||
@ -1030,7 +1052,7 @@ def close1c():
|
||||
for p1c in ps1c:
|
||||
try:
|
||||
kill_process(p1c.pid, process_name_1c)
|
||||
except:
|
||||
except Exception as _:
|
||||
traceback.print_exc()
|
||||
ps1c = []
|
||||
return (
|
||||
@ -1230,7 +1252,7 @@ def open1abc(
|
||||
{"__type__": "update", "visible": True},
|
||||
{"__type__": "update", "visible": False},
|
||||
)
|
||||
except:
|
||||
except Exception as _:
|
||||
traceback.print_exc()
|
||||
close1abc()
|
||||
yield (
|
||||
@ -1252,7 +1274,7 @@ def close1abc():
|
||||
for p1abc in ps1abc:
|
||||
try:
|
||||
kill_process(p1abc.pid, process_name_1abc)
|
||||
except:
|
||||
except Exception as _:
|
||||
traceback.print_exc()
|
||||
ps1abc = []
|
||||
return (
|
||||
@ -1303,6 +1325,14 @@ def sync(text):
|
||||
return {"__type__": "update", "value": text}
|
||||
|
||||
|
||||
def changeBackend(flag: bool):
|
||||
if flag:
|
||||
return gr.update(choices=["Torch Varlen"], value="Torch Varlen")
|
||||
else:
|
||||
return gr.update(choices=backends_gradio, value=backends_gradio[-1][-1])
|
||||
|
||||
|
||||
GPU_INDEX.add(0)
|
||||
with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app:
|
||||
gr.HTML(
|
||||
top_html.format(
|
||||
@ -1315,9 +1345,9 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
with gr.Tabs():
|
||||
with gr.TabItem("0-" + i18n("前置数据集获取工具")): # 提前随机切片防止uvr5爆内存->uvr5->slicer->asr->打标
|
||||
with gr.Accordion(label="0a-" + i18n("UVR5人声伴奏分离&去混响去延迟工具")):
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column(scale=3):
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
uvr5_info = gr.Textbox(label=process_info(process_name_uvr5, "info"))
|
||||
open_uvr5 = gr.Button(
|
||||
value=process_info(process_name_uvr5, "open"), variant="primary", visible=True
|
||||
@ -1327,14 +1357,14 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
)
|
||||
|
||||
with gr.Accordion(label="0b-" + i18n("语音切分工具")):
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column(scale=3):
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
slice_inp_path = gr.Textbox(label=i18n("音频自动切分输入路径,可文件可文件夹"), value="")
|
||||
slice_opt_root = gr.Textbox(
|
||||
label=i18n("切分后的子音频的输出根目录"), value="output/slicer_opt"
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
threshold = gr.Textbox(
|
||||
label=i18n("threshold:音量小于这个值视作静音的备选切割点"), value="-34"
|
||||
)
|
||||
@ -1348,7 +1378,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
value="10",
|
||||
)
|
||||
max_sil_kept = gr.Textbox(label=i18n("max_sil_kept:切完后静音最多留多长"), value="500")
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
_max = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
@ -1365,7 +1395,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
value=0.25,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
n_process = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=n_cpu,
|
||||
@ -1385,10 +1415,10 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
# gr.Markdown(value="0bb-" + i18n("语音降噪工具")+i18n("(不稳定,先别用,可能劣化模型效果!)"))
|
||||
with gr.Row(visible=False):
|
||||
with gr.Column(scale=3):
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
denoise_input_dir = gr.Textbox(label=i18n("输入文件夹路径"), value="")
|
||||
denoise_output_dir = gr.Textbox(label=i18n("输出文件夹路径"), value="output/denoise_opt")
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
denoise_info = gr.Textbox(label=process_info(process_name_denoise, "info"))
|
||||
open_denoise_button = gr.Button(
|
||||
value=process_info(process_name_denoise, "open"), variant="primary", visible=True
|
||||
@ -1398,16 +1428,16 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
)
|
||||
|
||||
with gr.Accordion(label="0c-" + i18n("语音识别工具")):
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column(scale=3):
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
asr_inp_dir = gr.Textbox(
|
||||
label=i18n("输入文件夹路径"), value="D:\\GPT-SoVITS\\raw\\xxx", interactive=True
|
||||
)
|
||||
asr_opt_dir = gr.Textbox(
|
||||
label=i18n("输出文件夹路径"), value="output/asr_opt", interactive=True
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
asr_model = gr.Dropdown(
|
||||
label=i18n("ASR 模型"),
|
||||
choices=list(asr_dict.keys()),
|
||||
@ -1423,7 +1453,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
asr_precision = gr.Dropdown(
|
||||
label=i18n("数据类型精度"), choices=["float32"], interactive=True, value="float32"
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
asr_info = gr.Textbox(label=process_info(process_name_asr, "info"))
|
||||
open_asr_button = gr.Button(
|
||||
value=process_info(process_name_asr, "open"), variant="primary", visible=True
|
||||
@ -1455,9 +1485,9 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
asr_model.change(change_precision_choices, [asr_model], [asr_precision])
|
||||
|
||||
with gr.Accordion(label="0d-" + i18n("语音文本校对标注工具")):
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column(scale=3):
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
path_list = gr.Textbox(
|
||||
label=i18n("标注文件路径 (含文件后缀 *.list)"),
|
||||
value="D:\\RVC1006\\GPT-SoVITS\\raw\\xxx.list",
|
||||
@ -1478,7 +1508,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
|
||||
with gr.TabItem(i18n("1-GPT-SoVITS-TTS")):
|
||||
with gr.Accordion(i18n("微调模型信息")):
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Row(equal_height=True):
|
||||
exp_name = gr.Textbox(
|
||||
label=i18n("*实验/模型名"),
|
||||
@ -1500,7 +1530,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
scale=5,
|
||||
)
|
||||
with gr.Accordion(label=i18n("预训练模型路径"), open=False):
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Row(equal_height=True):
|
||||
pretrained_s1 = gr.Textbox(
|
||||
label=i18n("预训练GPT模型路径"),
|
||||
@ -1529,15 +1559,15 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
|
||||
with gr.TabItem("1A-" + i18n("训练集格式化工具")):
|
||||
with gr.Accordion(label=i18n("输出logs/实验名目录下应有23456开头的文件和文件夹")):
|
||||
with gr.Row():
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Row(equal_height=True):
|
||||
inp_text = gr.Textbox(
|
||||
label=i18n("*文本标注文件"),
|
||||
value=r"D:\RVC1006\GPT-SoVITS\raw\xxx.list",
|
||||
interactive=True,
|
||||
scale=10,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
inp_wav_dir = gr.Textbox(
|
||||
label=i18n("*训练集音频文件目录"),
|
||||
# value=r"D:\RVC1006\GPT-SoVITS\raw\xxx",
|
||||
@ -1549,90 +1579,90 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
)
|
||||
|
||||
with gr.Accordion(label="1Aa-" + process_name_1a):
|
||||
with gr.Row():
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Row(equal_height=True):
|
||||
gpu_numbers1a = gr.Textbox(
|
||||
label=i18n("GPU卡号以-分割,每个卡号一个进程"),
|
||||
value="%s-%s" % (gpus, gpus),
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
bert_pretrained_dir = gr.Textbox(
|
||||
label=i18n("预训练中文BERT模型路径"),
|
||||
value="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
interactive=False,
|
||||
lines=2,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
button1a_open = gr.Button(
|
||||
value=process_info(process_name_1a, "open"), variant="primary", visible=True
|
||||
)
|
||||
button1a_close = gr.Button(
|
||||
value=process_info(process_name_1a, "close"), variant="primary", visible=False
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
info1a = gr.Textbox(label=process_info(process_name_1a, "info"))
|
||||
|
||||
with gr.Accordion(label="1Ab-" + process_name_1b):
|
||||
with gr.Row():
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Row(equal_height=True):
|
||||
gpu_numbers1Ba = gr.Textbox(
|
||||
label=i18n("GPU卡号以-分割,每个卡号一个进程"),
|
||||
value="%s-%s" % (gpus, gpus),
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
cnhubert_base_dir = gr.Textbox(
|
||||
label=i18n("预训练SSL模型路径"),
|
||||
value="GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
||||
interactive=False,
|
||||
lines=2,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
button1b_open = gr.Button(
|
||||
value=process_info(process_name_1b, "open"), variant="primary", visible=True
|
||||
)
|
||||
button1b_close = gr.Button(
|
||||
value=process_info(process_name_1b, "close"), variant="primary", visible=False
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
info1b = gr.Textbox(label=process_info(process_name_1b, "info"))
|
||||
|
||||
with gr.Accordion(label="1Ac-" + process_name_1c):
|
||||
with gr.Row():
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Row(equal_height=True):
|
||||
gpu_numbers1c = gr.Textbox(
|
||||
label=i18n("GPU卡号以-分割,每个卡号一个进程"),
|
||||
value="%s-%s" % (gpus, gpus),
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
pretrained_s2G_ = gr.Textbox(
|
||||
label=i18n("预训练SoVITS-G模型路径"),
|
||||
value=pretrained_sovits_name[version],
|
||||
interactive=False,
|
||||
lines=2,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
button1c_open = gr.Button(
|
||||
value=process_info(process_name_1c, "open"), variant="primary", visible=True
|
||||
)
|
||||
button1c_close = gr.Button(
|
||||
value=process_info(process_name_1c, "close"), variant="primary", visible=False
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
info1c = gr.Textbox(label=process_info(process_name_1c, "info"))
|
||||
|
||||
with gr.Accordion(label="1Aabc-" + process_name_1abc):
|
||||
with gr.Row():
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Row(equal_height=True):
|
||||
button1abc_open = gr.Button(
|
||||
value=process_info(process_name_1abc, "open"), variant="primary", visible=True
|
||||
)
|
||||
button1abc_close = gr.Button(
|
||||
value=process_info(process_name_1abc, "close"), variant="primary", visible=False
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
info1abc = gr.Textbox(label=process_info(process_name_1abc, "info"))
|
||||
|
||||
pretrained_s2G.change(sync, [pretrained_s2G], [pretrained_s2G_])
|
||||
@ -1704,149 +1734,146 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
|
||||
with gr.TabItem("1B-" + i18n("微调训练")):
|
||||
with gr.Accordion(label="1Ba-" + i18n("SoVITS 训练: 模型权重文件在 SoVITS_weights/")):
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
batch_size = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=default_max_batch_size,
|
||||
step=1,
|
||||
label=i18n("每张显卡的batch_size"),
|
||||
value=default_batch_size,
|
||||
interactive=True,
|
||||
)
|
||||
total_epoch = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=max_sovits_epoch,
|
||||
step=1,
|
||||
label=i18n("总训练轮数total_epoch,不建议太高"),
|
||||
value=default_sovits_epoch,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=2):
|
||||
if_save_latest = gr.Checkbox(
|
||||
label=i18n("是否仅保存最新的权重文件以节省硬盘空间"),
|
||||
value=True,
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
)
|
||||
if_save_every_weights = gr.Checkbox(
|
||||
label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"),
|
||||
value=True,
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
)
|
||||
if_grad_ckpt = gr.Checkbox(
|
||||
label="v3是否开启梯度检查点节省显存占用",
|
||||
value=False,
|
||||
interactive=True if version in v3v4set else False,
|
||||
show_label=True,
|
||||
visible=False,
|
||||
) # 只有V3s2可以用
|
||||
with gr.Row(equal_height=True):
|
||||
text_low_lr_rate = gr.Slider(
|
||||
minimum=0.2,
|
||||
maximum=0.6,
|
||||
step=0.05,
|
||||
label=i18n("文本模块学习率权重"),
|
||||
value=0.4,
|
||||
visible=True if version not in v3v4set else False,
|
||||
) # v3v4 not need
|
||||
lora_rank = gr.Radio(
|
||||
label=i18n("LoRA秩"),
|
||||
value="32",
|
||||
choices=["16", "32", "64", "128"],
|
||||
visible=True if version in v3v4set else False,
|
||||
) # v1v2 not need
|
||||
save_every_epoch = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=max_sovits_save_every_epoch,
|
||||
step=1,
|
||||
label=i18n("保存频率save_every_epoch"),
|
||||
value=default_sovits_save_every_epoch,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
gpu_numbers1Ba = gr.Textbox(
|
||||
label=i18n("GPU卡号以-分割,每个卡号一个进程"),
|
||||
value="%s" % (gpus),
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=default_max_batch_size,
|
||||
step=1,
|
||||
label=i18n("每张显卡的batch_size"),
|
||||
value=default_batch_size,
|
||||
interactive=True,
|
||||
)
|
||||
total_epoch = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=max_sovits_epoch,
|
||||
step=1,
|
||||
label=i18n("总训练轮数total_epoch,不建议太高"),
|
||||
value=default_sovits_epoch,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
text_low_lr_rate = gr.Slider(
|
||||
minimum=0.2,
|
||||
maximum=0.6,
|
||||
step=0.05,
|
||||
label=i18n("文本模块学习率权重"),
|
||||
value=0.4,
|
||||
visible=True if version not in v3v4set else False,
|
||||
) # v3v4 not need
|
||||
lora_rank = gr.Radio(
|
||||
label=i18n("LoRA秩"),
|
||||
value="32",
|
||||
choices=["16", "32", "64", "128"],
|
||||
visible=True if version in v3v4set else False,
|
||||
) # v1v2 not need
|
||||
save_every_epoch = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=max_sovits_save_every_epoch,
|
||||
step=1,
|
||||
label=i18n("保存频率save_every_epoch"),
|
||||
value=default_sovits_save_every_epoch,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column():
|
||||
with gr.Column():
|
||||
if_save_latest = gr.Checkbox(
|
||||
label=i18n("是否仅保存最新的权重文件以节省硬盘空间"),
|
||||
value=True,
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
)
|
||||
if_save_every_weights = gr.Checkbox(
|
||||
label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"),
|
||||
value=True,
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
)
|
||||
if_grad_ckpt = gr.Checkbox(
|
||||
label="v3是否开启梯度检查点节省显存占用",
|
||||
value=False,
|
||||
interactive=True if version in v3v4set else False,
|
||||
show_label=True,
|
||||
visible=False,
|
||||
) # 只有V3s2可以用
|
||||
with gr.Row():
|
||||
gpu_numbers1Ba = gr.Textbox(
|
||||
label=i18n("GPU卡号以-分割,每个卡号一个进程"),
|
||||
value="%s" % (gpus),
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row():
|
||||
button1Ba_open = gr.Button(
|
||||
value=process_info(process_name_sovits, "open"), variant="primary", visible=True
|
||||
)
|
||||
button1Ba_close = gr.Button(
|
||||
value=process_info(process_name_sovits, "close"), variant="primary", visible=False
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
info1Ba = gr.Textbox(label=process_info(process_name_sovits, "info"))
|
||||
with gr.Accordion(label="1Bb-" + i18n("GPT 训练: 模型权重文件在 GPT_weights/")):
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
batch_size1Bb = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=40,
|
||||
step=1,
|
||||
label=i18n("每张显卡的batch_size"),
|
||||
value=default_batch_size_s1,
|
||||
interactive=True,
|
||||
)
|
||||
total_epoch1Bb = gr.Slider(
|
||||
minimum=2,
|
||||
maximum=50,
|
||||
step=1,
|
||||
label=i18n("总训练轮数total_epoch"),
|
||||
value=15,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=2):
|
||||
if_save_latest1Bb = gr.Checkbox(
|
||||
label=i18n("是否仅保存最新的权重文件以节省硬盘空间"),
|
||||
value=True,
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
)
|
||||
if_save_every_weights1Bb = gr.Checkbox(
|
||||
label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"),
|
||||
value=True,
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
# with gr.Column():
|
||||
save_every_epoch1Bb = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=50,
|
||||
step=1,
|
||||
label=i18n("保存频率save_every_epoch"),
|
||||
value=5,
|
||||
interactive=True,
|
||||
)
|
||||
# with gr.Column():
|
||||
if_dpo = gr.Checkbox(
|
||||
label=i18n("是否开启DPO训练选项(实验性)"),
|
||||
value=False,
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
)
|
||||
with gr.Column(scale=2):
|
||||
gpu_numbers1Bb = gr.Textbox(
|
||||
label=i18n("GPU卡号以-分割,每个卡号一个进程"),
|
||||
value="%s" % (gpus),
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
batch_size1Bb = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=40,
|
||||
step=1,
|
||||
label=i18n("每张显卡的batch_size"),
|
||||
value=default_batch_size_s1,
|
||||
interactive=True,
|
||||
with gr.Row(equal_height=True):
|
||||
button1Bb_open = gr.Button(
|
||||
value=process_info(process_name_gpt, "open"), variant="primary", visible=True
|
||||
)
|
||||
total_epoch1Bb = gr.Slider(
|
||||
minimum=2,
|
||||
maximum=50,
|
||||
step=1,
|
||||
label=i18n("总训练轮数total_epoch"),
|
||||
value=15,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
save_every_epoch1Bb = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=50,
|
||||
step=1,
|
||||
label=i18n("保存频率save_every_epoch"),
|
||||
value=5,
|
||||
interactive=True,
|
||||
)
|
||||
if_dpo = gr.Checkbox(
|
||||
label=i18n("是否开启DPO训练选项(实验性)"),
|
||||
value=False,
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
button1Bb_close = gr.Button(
|
||||
value=process_info(process_name_gpt, "close"), variant="primary", visible=False
|
||||
)
|
||||
with gr.Column():
|
||||
with gr.Column():
|
||||
if_save_latest1Bb = gr.Checkbox(
|
||||
label=i18n("是否仅保存最新的权重文件以节省硬盘空间"),
|
||||
value=True,
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
)
|
||||
if_save_every_weights1Bb = gr.Checkbox(
|
||||
label=i18n("是否在每次保存时间点将最终小模型保存至weights文件夹"),
|
||||
value=True,
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
)
|
||||
with gr.Row():
|
||||
gpu_numbers1Bb = gr.Textbox(
|
||||
label=i18n("GPU卡号以-分割,每个卡号一个进程"),
|
||||
value="%s" % (gpus),
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row():
|
||||
button1Bb_open = gr.Button(
|
||||
value=process_info(process_name_gpt, "open"), variant="primary", visible=True
|
||||
)
|
||||
button1Bb_close = gr.Button(
|
||||
value=process_info(process_name_gpt, "close"), variant="primary", visible=False
|
||||
)
|
||||
with gr.Row():
|
||||
info1Bb = gr.Textbox(label=process_info(process_name_gpt, "info"))
|
||||
|
||||
button1Ba_close.click(close1Ba, [], [info1Ba, button1Ba_open, button1Ba_close])
|
||||
@ -1858,41 +1885,60 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
"选择训练完存放在SoVITS_weights和GPT_weights下的模型。默认的几个是底模,体验5秒Zero Shot TTS不训练推理用。"
|
||||
)
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column(scale=2):
|
||||
with gr.Row():
|
||||
GPT_dropdown = gr.Dropdown(
|
||||
label=i18n("GPT模型列表"),
|
||||
choices=GPT_names,
|
||||
value=GPT_names[-1],
|
||||
interactive=True,
|
||||
)
|
||||
SoVITS_dropdown = gr.Dropdown(
|
||||
label=i18n("SoVITS模型列表"),
|
||||
choices=SoVITS_names,
|
||||
value=SoVITS_names[0],
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column():
|
||||
GPT_dropdown = gr.Dropdown(
|
||||
label=i18n("GPT模型列表"),
|
||||
choices=GPT_names,
|
||||
value=GPT_names[-1],
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column():
|
||||
SoVITS_dropdown = gr.Dropdown(
|
||||
label=i18n("SoVITS模型列表"),
|
||||
choices=SoVITS_names,
|
||||
value=SoVITS_names[0],
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=2):
|
||||
with gr.Row():
|
||||
gpu_number_1C = gr.Textbox(
|
||||
label=i18n("GPU卡号,只能填1个整数"), value=gpus, interactive=True
|
||||
with gr.Row(equal_height=True):
|
||||
gpu_number_1C = gr.Dropdown(
|
||||
choices=sorted(list(GPU_INDEX)),
|
||||
value=sorted(list(GPU_INDEX))[0],
|
||||
label=i18n("GPU卡号,只能填1个整数"),
|
||||
interactive=True,
|
||||
)
|
||||
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
|
||||
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Row():
|
||||
batched_infer_enabled = gr.Checkbox(
|
||||
label=i18n("启用并行推理版本"), value=False, interactive=True, show_label=True
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column():
|
||||
batched_infer_enabled = gr.Checkbox(
|
||||
label=i18n("启用并行推理版本"), value=False, interactive=True, show_label=True
|
||||
)
|
||||
with gr.Column():
|
||||
backends_dropdown = gr.Dropdown(
|
||||
choices=backends_gradio,
|
||||
label=i18n("推理后端"),
|
||||
value=backends_gradio[-1][-1],
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
tts_info = gr.Textbox(label=process_info(process_name_tts, "info"))
|
||||
open_tts = gr.Button(
|
||||
value=process_info(process_name_tts, "open"), variant="primary", visible=True
|
||||
)
|
||||
close_tts = gr.Button(
|
||||
value=process_info(process_name_tts, "close"), variant="primary", visible=False
|
||||
)
|
||||
with gr.Column():
|
||||
tts_info = gr.Textbox(label=process_info(process_name_tts, "info"), scale=2)
|
||||
|
||||
batched_infer_enabled.change(
|
||||
changeBackend,
|
||||
[batched_infer_enabled],
|
||||
[backends_dropdown],
|
||||
)
|
||||
open_tts.click(
|
||||
change_tts_inference,
|
||||
[
|
||||
@ -1902,6 +1948,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
GPT_dropdown,
|
||||
SoVITS_dropdown,
|
||||
batched_infer_enabled,
|
||||
backends_dropdown,
|
||||
],
|
||||
[tts_info, open_tts, close_tts],
|
||||
)
|
||||
@ -1914,6 +1961,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
GPT_dropdown,
|
||||
SoVITS_dropdown,
|
||||
batched_infer_enabled,
|
||||
backends_dropdown,
|
||||
],
|
||||
[tts_info, open_tts, close_tts],
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user