This commit is contained in:
XXXXRT666 2025-08-16 18:34:35 +08:00
parent fdf794e31d
commit a3f6b5f1e9
179 changed files with 5908 additions and 9363 deletions

View File

@ -115,12 +115,17 @@ Remove-Item $ffDir.FullName -Recurse -Force
Write-Host "[INFO] Installing PyTorch..."
& ".\runtime\python.exe" -m ensurepip
& ".\runtime\python.exe" -m pip install --upgrade pip --no-warn-script-location
switch ($cuda) {
"cu124" {
& ".\runtime\python.exe" -m pip install torch==2.6 torchaudio --index-url https://download.pytorch.org/whl/cu124 --no-warn-script-location
& ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location
& ".\runtime\python.exe" -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu124 --no-warn-script-location
& ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
}
"cu128" {
& ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location
& ".\runtime\python.exe" -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128 --no-warn-script-location
& ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
}
default {
Write-Error "Unsupported CUDA version: $cuda"

View File

@ -31,6 +31,15 @@ jobs:
- name: Checkout
uses: actions/checkout@v4
- name: Install Windows CUDA 12.9
if: ${{ runner.os == 'Windows' && matrix.torch_cuda == '12.8' }}
uses: Jimver/cuda-toolkit
id: cuda-toolkit-win-129
with:
cuda: 12.9.1
method: "network"
sub-packages: '["nvcc", "cudart", "visual_studio_integration"]'
- name: Run Build and Upload Script
shell: pwsh
run: |

4
.gitignore vendored
View File

@ -16,8 +16,8 @@ ffprobe*
cfg.json
speakers.json
ref_audios
tools/AP_BWE_main/24kto48k/*
!tools/AP_BWE_main/24kto48k/readme.txt
tools/AP_BWE/24kto48k/*
!tools/AP_BWE/24kto48k/readme.txt
# Byte-compiled / optimized / DLL files
__pycache__/

View File

@ -23,8 +23,10 @@ fi
if [ "$TARGETPLATFORM" = "linux/amd64" ]; then
"${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-x86_64.sh
SYSROOT_PKG="sysroot_linux-64>=2.28"
elif [ "$TARGETPLATFORM" = "linux/arm64" ]; then
"${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-aarch64.sh
SYSROOT_PKG="sysroot_linux-aarch64>=2.28"
else
exit 1
fi
@ -45,20 +47,36 @@ rm miniconda.sh
source "$HOME/miniconda3/etc/profile.d/conda.sh"
"$HOME/miniconda3/bin/conda" init bash
source "$HOME/.bashrc"
"$HOME/miniconda3/bin/conda" config --add channels conda-forge
"$HOME/miniconda3/bin/conda" update -q --all -y 1>/dev/null
"$HOME/miniconda3/bin/conda" install python=3.11 -q -y
"$HOME/miniconda3/bin/conda" install gcc=14 gxx ffmpeg cmake make unzip -q -y
"$HOME/miniconda3/bin/conda" install gcc=11 gxx ffmpeg cmake make unzip $SYSROOT_PKG "libstdcxx-ng>=11" -q -y
if [ "$CUDA_VERSION" = "12.8" ]; then
"$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu128
"$HOME/miniconda3/bin/conda" install cuda-nvcc=12.8 -c nvidia
elif [ "$CUDA_VERSION" = "12.6" ]; then
"$HOME/miniconda3/bin/pip" install torch==2.6 torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
"$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
"$HOME/miniconda3/bin/conda" install cuda-nvcc=12.6 -c nvidia
fi
CUDA_PATH=$(echo "$HOME/miniconda3/targets/"*-linux | awk '{print $1}')
export CUDA_HOME=$CUDA_PATH
export PATH="$HOME/miniconda3/bin:$PATH"
export PATH="$CUDA_HOME/bin:$PATH"
export PATH="$CUDA_HOME/nvvm/bin:$PATH"
"$HOME/miniconda3/bin/pip" install psutil ninja packaging wheel "setuptools>=42"
"$HOME/miniconda3/bin/pip" install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
"$HOME/miniconda3/bin/pip" cache purge
rm $LOG_PATH

View File

@ -3,8 +3,8 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from AR.data.bucket_sampler import DistributedBucketSampler
from AR.data.dataset import Text2SemanticDataset
from GPT_SoVITS.AR.data.bucket_sampler import DistributedBucketSampler
from GPT_SoVITS.AR.data.dataset import Text2SemanticDataset
class Text2SemanticDataModule(LightningDataModule):

View File

@ -220,7 +220,7 @@ class Text2SemanticDataset(Dataset):
flag = 0
path_bert = "%s/%s.pt" % (self.path3, item_name)
if os.path.exists(path_bert) == True:
if os.path.exists(path_bert) is True:
bert_feature = torch.load(path_bert, map_location="cpu")
else:
flag = 1

View File

@ -10,9 +10,9 @@ from typing import Dict
import torch
from pytorch_lightning import LightningModule
from AR.models.t2s_model import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam
from GPT_SoVITS.AR.models.t2s_model import Text2SemanticDecoder
from GPT_SoVITS.AR.modules.lr_schedulers import WarmupCosineLRSchedule
from GPT_SoVITS.AR.modules.optim import ScaledAdam
class Text2SemanticLightningModule(LightningModule):
@ -42,7 +42,7 @@ class Text2SemanticLightningModule(LightningModule):
def training_step(self, batch: Dict, batch_idx: int):
opt = self.optimizers()
scheduler = self.lr_schedulers()
forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
forward = self.model.forward if self.config["train"].get("if_dpo", False) is True else self.model.forward_old
loss, acc = forward(
batch["phoneme_ids"],
batch["phoneme_ids_len"],

View File

@ -10,9 +10,9 @@ from typing import Dict
import torch
from pytorch_lightning import LightningModule
from AR.models.t2s_model_onnx import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam
from GPT_SoVITS.AR.models.t2s_model_onnx import Text2SemanticDecoder
from GPT_SoVITS.AR.modules.lr_schedulers import WarmupCosineLRSchedule
from GPT_SoVITS.AR.modules.optim import ScaledAdam
class Text2SemanticLightningModule(LightningModule):

View File

@ -9,7 +9,7 @@ from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from tqdm import tqdm
from AR.models.utils import (
from GPT_SoVITS.AR.models.utils import (
dpo_loss,
get_batch_logps,
make_pad_mask,
@ -18,8 +18,8 @@ from AR.models.utils import (
sample,
topk_sampling,
)
from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
from GPT_SoVITS.AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
from GPT_SoVITS.AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
default_config = {
"embedding_dim": 512,
@ -420,7 +420,7 @@ class Text2SemanticDecoder(nn.Module):
mask=xy_attn_mask,
)
x_len = x_lens.max()
logits = self.ar_predict_layer(xy_dec[:, x_len-1:])
logits = self.ar_predict_layer(xy_dec[:, x_len - 1 :])
###### DPO #############
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
@ -432,7 +432,7 @@ class Text2SemanticDecoder(nn.Module):
mask=reject_xy_attn_mask,
)
x_len = x_lens.max()
reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len-1:])
reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len - 1 :])
# loss
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
@ -502,7 +502,7 @@ class Text2SemanticDecoder(nn.Module):
(xy_pos, None),
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, x_len-1:]).permute(0, 2, 1)
logits = self.ar_predict_layer(xy_dec[:, x_len - 1 :]).permute(0, 2, 1)
# loss
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
loss = F.cross_entropy(logits, targets, reduction="sum")
@ -724,8 +724,8 @@ class Text2SemanticDecoder(nn.Module):
l1 = samples[:, 0] == self.EOS
l2 = tokens == self.EOS
l = l1.logical_or(l2)
removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l == False)[0]
removed_idx_of_batch_for_y = torch.where(l is True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l is False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]

View File

@ -5,8 +5,8 @@ from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
from GPT_SoVITS.AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
from GPT_SoVITS.AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
default_config = {
"embedding_dim": 512,

View File

@ -9,7 +9,7 @@ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
from GPT_SoVITS.AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
F.multi_head_attention_forward = multi_head_attention_forward_patched

View File

@ -8,7 +8,7 @@ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
from GPT_SoVITS.AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
class MultiheadAttention(Module):

View File

@ -2,20 +2,15 @@
import copy
import numbers
from functools import partial
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
from AR.modules.activation import MultiheadAttention
from AR.modules.scaling import BalancedDoubleSwish
from torch import nn
from torch import Tensor
from torch import Tensor, nn
from torch.nn import functional as F
from GPT_SoVITS.AR.modules.activation import MultiheadAttention
from GPT_SoVITS.AR.modules.scaling import BalancedDoubleSwish
_shape_t = Union[int, List[int], torch.Size]

View File

@ -2,20 +2,15 @@
import copy
import numbers
from functools import partial
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from typing import Any, Callable, List, Optional, Tuple, Union
import torch
from AR.modules.activation_onnx import MultiheadAttention
from AR.modules.scaling import BalancedDoubleSwish
from torch import nn
from torch import Tensor
from torch import Tensor, nn
from torch.nn import functional as F
from GPT_SoVITS.AR.modules.activation_onnx import MultiheadAttention
from GPT_SoVITS.AR.modules.scaling import BalancedDoubleSwish
_shape_t = Union[int, List[int], torch.Size]

View File

@ -1,72 +0,0 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/phonemizer.py
# reference: https://github.com/lifeiteng/vall-e
import itertools
import re
from typing import Dict
from typing import List
import regex
from gruut import sentences
from gruut.const import Sentence
from gruut.const import Word
from AR.text_processing.symbols import SYMBOL_TO_ID
class GruutPhonemizer:
def __init__(self, language: str):
self._phonemizer = sentences
self.lang = language
self.symbol_to_id = SYMBOL_TO_ID
self._special_cases_dict: Dict[str] = {
r"\.\.\.": "... ",
";": "; ",
":": ": ",
",": ", ",
r"\.": ". ",
"!": "! ",
r"\?": "? ",
"": "",
"": "",
"«": "«",
"»": "»",
}
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
def _normalize_punctuation(self, text: str) -> str:
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
text = regex.sub(r"\pZ+", r" ", text)
return text.strip()
def _convert_punctuation(self, word: Word) -> str:
if not word.phonemes:
return ""
if word.phonemes[0] in ["", "|"]:
return word.text.strip()
phonemes = "".join(word.phonemes)
# remove modifier characters ˈˌː with regex
phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
return phonemes.strip()
def phonemize(self, text: str, espeak: bool = False) -> str:
text_to_phonemize: str = self._normalize_punctuation(text)
sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
return " ".join(words)
def transform(self, phonemes):
# convert phonemes to ids
# dictionary is in symbols.py
return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()]
if __name__ == "__main__":
phonemizer = GruutPhonemizer("en-us")
# text -> IPA
phonemes = phonemizer.phonemize("Hello, wor-ld ?")
print("phonemes:", phonemes)
print("len(phonemes):", len(phonemes))
phoneme_ids = phonemizer.transform(phonemes)
print("phoneme_ids:", phoneme_ids)
print("len(phoneme_ids):", len(phoneme_ids))

View File

@ -1,12 +0,0 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/symbols.py
# reference: https://github.com/lifeiteng/vall-e
PAD = "_"
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
IPA_LETTERS = (
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
)
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
SPACE_ID = SYMBOLS.index(" ")
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)}

View File

@ -0,0 +1,11 @@
import importlib.util
if importlib.util.find_spec("mlx") is not None:
from .sample_funcs_mlx import sample_naive as sample_naive_mlx
from .t2s_engine_mlx import T2SEngine as T2SEngineMLX
backends = ["mlx_static", "mlx_quantized", "mlx_varlen"]
else:
backends = []
__all__ = ["T2SEngineMLX", "sample_naive_mlx", "backends"]

View File

@ -0,0 +1,174 @@
from __future__ import annotations
from typing import cast
import mlx.core as mx
import mlx.nn as nn
from ..structs_mlx import KVCacheQ
from ..t2s_model_abc import (
AttentionABC,
KVCache,
KVCacheHND,
T2SDecoderABC,
TransformerBlockABC,
TransformerDecoderABC,
)
Array = mx.array
class Attention(AttentionABC):
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
super().__init__(n_head, hidden_dim, max_seq_length)
self.kc_class = KVCacheHND
@staticmethod
def quantized_scaled_dot_product_attention(
queries: Array,
q_keys: tuple[Array, Array, Array],
q_values: tuple[Array, Array, Array],
scale: float,
mask: Array,
group_size: int = 32,
bits: int = 8,
) -> Array:
queries *= scale
scores = mx.quantized_matmul(queries, *q_keys, transpose=True, group_size=group_size, bits=bits)
scores = mx.where(mask, scores, -mx.inf)
scores = mx.softmax(scores, axis=-1, precise=True) # type: ignore
out = mx.quantized_matmul(scores, *q_values, transpose=False, group_size=group_size, bits=bits)
return out
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
q, k, v = self.in_proj(x).split(3, axis=-1)
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
assert len(kv_cache) == 2
max_idx = int(input_pos.max())
q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache))
mask = attn_mask[..., :max_idx]
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
# def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
# bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
# q, k, v = self.in_proj(x).split(3, axis=-1)
# q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
# q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
# kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
# assert len(kv_cache) == 3
# (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits) = kv_cache
# k_q, k_s, k_b, v_q, v_s, v_b = map(lambda x: x[..., : int(input_pos.max()), :], (k_q, k_s, k_b, v_q, v_s, v_b))
# mask = attn_mask[..., : int(input_pos.max())]
# attn = Attention.quantized_scaled_dot_product_attention(
# q,
# (k_q, k_s, k_b),
# (v_q, v_s, v_b),
# self.scale,
# mask,
# group_size,
# bits,
# )
# attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
# output = self.out_proj(attn)
# return output
class TransformerBlock(TransformerBlockABC):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None:
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length, *args, **kwds)
self.attention = Attention(n_head, hidden_dim, max_seq_length, *args, **kwds)
class TransformerDecoder(TransformerDecoderABC):
def __init__(
self,
hidden_dim: int,
n_layer: int,
n_head: int,
ffn_dim: int,
vocab_size: int,
max_seq_length: int,
max_batch_size: int,
*args,
**kwds,
) -> None:
super().__init__(
hidden_dim,
n_layer,
n_head,
ffn_dim,
vocab_size,
max_seq_length,
max_batch_size,
*args,
**kwds,
)
self.layers = [
TransformerBlock(
n_head,
ffn_dim,
hidden_dim,
max_seq_length,
*args,
**kwds,
)
for _ in range(n_layer)
]
class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config: dict,
max_seq_length: int = 1800,
max_batch_size: int = 10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
self.h = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.kv_class = KVCacheHND
self.group_size = 32
self.bits = 8
# def init_cache(self, bsz: int = 0):
# return super().init_cache(bsz, group_size=self.group_size, bits=self.bits)
def quantized(self):
for layer in self.h.layers:
# nn.quantize(layer.feed_forward, self.group_size, self.bits)
nn.quantize(layer.attention, self.group_size, self.bits)

View File

@ -0,0 +1,99 @@
from __future__ import annotations
from typing import cast
import mlx.core as mx
from ..structs_mlx import KVCache, KVCacheQ
from ..t2s_model_abc import (
AttentionABC,
KVCacheHND,
T2SDecoderABC,
TransformerBlockABC,
TransformerDecoderABC,
)
Array = mx.array
class Attention(AttentionABC):
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
super().__init__(n_head, hidden_dim, max_seq_length)
self.kc_class = KVCacheHND
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
q, k, v = self.in_proj(x).split(3, axis=-1)
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
assert len(kv_cache) == 2
k, v = kv_cache
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attn_mask)
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
class TransformerBlock(TransformerBlockABC):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
self.attention = Attention(n_head, hidden_dim, max_seq_length)
class TransformerDecoder(TransformerDecoderABC):
def __init__(
self,
hidden_dim: int,
n_layer: int,
n_head: int,
ffn_dim: int,
vocab_size: int,
max_seq_length: int,
max_batch_size: int,
) -> None:
super().__init__(
hidden_dim,
n_layer,
n_head,
ffn_dim,
vocab_size,
max_seq_length,
max_batch_size,
)
self.layers = [
TransformerBlock(
n_head,
ffn_dim,
hidden_dim,
max_seq_length,
)
for _ in range(n_layer)
]
class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config: dict,
max_seq_length: int = 1800,
max_batch_size: int = 10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
self.h = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.kv_class = KVCacheHND

View File

@ -0,0 +1,103 @@
from __future__ import annotations
from typing import cast
import mlx.core as mx
from ..structs_mlx import KVCache, KVCacheQ
from ..t2s_model_abc import (
AttentionABC,
KVCacheHND,
T2SDecoderABC,
TransformerBlockABC,
TransformerDecoderABC,
)
Array = mx.array
class Attention(AttentionABC):
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
super().__init__(n_head, hidden_dim, max_seq_length)
self.kc_class = KVCacheHND
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
q, k, v = self.in_proj(x).split(3, axis=-1)
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
assert len(kv_cache) == 2
max_idx = int(input_pos.max())
q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache))
mask = attn_mask[..., :max_idx]
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
class TransformerBlock(TransformerBlockABC):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
self.attention = Attention(n_head, hidden_dim, max_seq_length)
class TransformerDecoder(TransformerDecoderABC):
def __init__(
self,
hidden_dim: int,
n_layer: int,
n_head: int,
ffn_dim: int,
vocab_size: int,
max_seq_length: int,
max_batch_size: int,
) -> None:
super().__init__(
hidden_dim,
n_layer,
n_head,
ffn_dim,
vocab_size,
max_seq_length,
max_batch_size,
)
self.layers = [
TransformerBlock(
n_head,
ffn_dim,
hidden_dim,
max_seq_length,
)
for _ in range(n_layer)
]
class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config: dict,
max_seq_length: int = 1800,
max_batch_size: int = 10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
self.h = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.kv_class = KVCacheHND

View File

@ -0,0 +1,64 @@
from functools import partial
from typing import Protocol, cast
import mlx.core as mx
Array = mx.array
class SampleProtocolMLX(Protocol):
@staticmethod
def __call__(
logits: Array,
previous_tokens: Array,
temperature: float,
top_k: int,
top_p: float,
repetition_penalty: float,
) -> Array: ...
class sample_naive(SampleProtocolMLX):
# @partial(mx.compile)
@staticmethod
def __call__(
logits,
previous_tokens,
temperature,
top_k,
top_p,
repetition_penalty,
):
if temperature <= 1e-5:
probs = mx.softmax(logits, axis=-1)
return mx.argmax(probs, axis=-1, keepdims=True).astype(mx.int32)
if repetition_penalty != 1.0:
batch_idx = mx.arange(cast(tuple[int, ...], previous_tokens.shape)[0])
previous_tokens = previous_tokens.astype(mx.int64)
selected_logists = logits[batch_idx, previous_tokens]
selected_logists = mx.where(
selected_logists < 0, selected_logists * repetition_penalty, selected_logists / repetition_penalty
)
logits[batch_idx, previous_tokens] = selected_logists
sorted_indices = mx.argsort(-logits, axis=-1)
sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
cum_probs = mx.cumsum(mx.softmax(sorted_logits, axis=-1), axis=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, -1] = False
indices_to_remove = mx.zeros_like(logits).astype(mx.bool_)
batch_indices = mx.arange(cast(tuple[int, ...], logits.shape)[0])[:, None]
indices_to_remove[batch_indices, sorted_indices] = sorted_indices_to_remove
logits = mx.where(indices_to_remove, -mx.inf, logits)
logits = logits / temperature
v = mx.topk(logits, top_k)
pivot = mx.expand_dims(v[:, 0], -1)
logits = mx.where(logits < pivot, -mx.inf, logits)
gumbel_noise = mx.random.gumbel(shape=cast(tuple[int, ...], logits.shape), dtype=logits.dtype)
idx_next = mx.argmax(logits + gumbel_noise, axis=-1, keepdims=True).astype(mx.int32)
return idx_next

View File

@ -0,0 +1,164 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import List, MutableSequence, Protocol, TypeAlias, cast
import mlx.core as mx
import torch
from ..PyTorch.structs import T2SRequest, T2SResult
from .sample_funcs_mlx import SampleProtocolMLX, sample_naive
Tensor = torch.Tensor
Array = mx.array
@dataclass(slots=True)
class T2SRequestMLX:
x: List[Array]
x_lens: Array
prompts: Array
bert_feature: List[Array]
valid_length: int
top_k: int = 5
top_p: float = 1
early_stop_num: int = -1
temperature: float = 1.0
repetition_penalty: float = 1.35
@classmethod
def from_torch(cls, request: T2SRequest) -> T2SRequestMLX:
x = list(map(lambda tensor: mx.array(tensor.cpu()), request.x))
x_lens = mx.array(request.x_lens.cpu())
prompts = mx.array(request.prompts.cpu())
bert_feature = list(map(lambda tensor: mx.array(tensor.cpu()), request.bert_feature))
return cls(
x,
x_lens,
prompts,
bert_feature,
request.valid_length,
request.top_k,
request.top_p,
request.early_stop_num,
request.temperature,
request.repetition_penalty,
)
KVCache: TypeAlias = tuple[Array, Array]
KVCacheQ: TypeAlias = tuple[tuple[Array, Array, Array], tuple[Array, Array, Array], tuple[int, int]]
class KVCacheProtocol(Protocol):
@staticmethod
def empty(kv_cache: KVCache | KVCacheQ) -> None: ...
@staticmethod
def update_cache(
input_pos: Array, k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array
) -> KVCache | KVCacheQ: ...
@staticmethod
def prefill_kv(k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ) -> None: ...
@staticmethod
def init_cache(
batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype, *args, **kwds
) -> KVCache | KVCacheQ: ...
class T2SDecoderProtocol(Protocol):
max_seq_length: int
EOS: int
n_head: int
def embed(self, x: list[Array], y: Array, bert_features: list[Array]) -> Array: ...
class T2SEngineProtocol(Protocol):
def _handle_request(self, request: T2SRequest) -> tuple[list[Array], float]: ...
def generate(self, request: T2SRequest) -> T2SResult: ...
@staticmethod
def load_decoder(
weights_path: os.PathLike, max_batch_size: int = 1, implement: str = "MLX"
) -> T2SDecoderProtocol: ...
class T2SSessionMLX:
def __init__(
self,
decoder: T2SDecoderProtocol,
request_torch: T2SRequest,
sample_func: type[SampleProtocolMLX] = sample_naive,
device: mx.Device = mx.Device(mx.cpu),
dtype: mx.Dtype = mx.float32,
):
with mx.stream(device):
request = T2SRequestMLX.from_torch(request_torch)
self.decoder = decoder
self.request = request
self.device = device
self.dtype = dtype
bsz = len(request.x)
y_len: int = cast(tuple[int, ...], request.prompts.shape)[-1]
self.bsz = bsz
self.y_len = y_len
# Cache
self.kv_cache: MutableSequence[KVCache | KVCacheQ]
self.sample = sample_func()
# Forward args
self.x = [i.astype(mx.int32) for i in request.x]
self.x_lens = request.x_lens.astype(mx.int32)
self.y = mx.zeros((bsz, decoder.max_seq_length)).astype(mx.int32)
self.y[:, : cast(tuple[int, ...], request.prompts.shape)[-1]] = request.prompts.astype(mx.int32)
self.bert_feature = [i.astype(dtype) for i in request.bert_feature]
self.prefill_len = self.x_lens + cast(tuple[int, ...], request.prompts.shape)[1]
self.input_pos = mx.zeros_like(self.prefill_len)
self.input_pos += self.prefill_len
# EOS
self.completed = mx.array([False] * len(self.x)).astype(mx.bool_)
self.y_results: List[Array] = [None] * len(self.x) # type: ignore
self.xy_pos = decoder.embed(self.x, request.prompts, self.bert_feature)
max_len = int(self.prefill_len.max(-1))
attn_mask = mx.zeros(shape=(bsz, max_len, max_len), dtype=mx.bool_)
for bs in range(bsz):
pos = int(self.x_lens[bs])
seq_len = pos + y_len
attn_mask[bs, :seq_len, :pos] = True
ar_mask = ~mx.triu(
x=mx.ones(
shape=(
y_len,
y_len,
),
dtype=mx.bool_,
),
k=1,
)
attn_mask[bs, pos:seq_len, pos:seq_len] = ar_mask
attn_mask = mx.repeat(mx.expand_dims(attn_mask, 1), decoder.n_head, 1)
self.attn_mask = attn_mask
mx.eval(self.attn_mask)

View File

@ -0,0 +1,232 @@
import gc
import os
import time
import traceback
from typing import cast
import mlx.core as mx
import torch
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
from ..logger import console, logger
from ..PyTorch.structs import T2SEngineProtocol, T2SRequest
from .backends import mlx_quantized, mlx_static, mlx_varlen
from .structs_mlx import T2SResult, T2SSessionMLX
from .t2s_model_abc import T2SDecoderABC
Array = mx.array
Tensor = torch.Tensor
class T2SEngine(T2SEngineProtocol):
def __init__(
self,
decoder_model: T2SDecoderABC,
device: mx.Device | str = mx.Device(mx.cpu),
dtype: torch.dtype | mx.Dtype = torch.float32,
) -> None:
if isinstance(device, str):
match device:
case "mx.cpu":
device = mx.Device(mx.cpu)
case "mx.gpu":
device = mx.Device(mx.gpu)
match dtype:
case torch.float32:
dtype = mx.float32
case torch.float16:
dtype = mx.float16
case torch.bfloat16:
dtype = mx.bfloat16
device = cast(mx.Device, device)
dtype = cast(mx.Dtype, dtype)
assert device.type.value in {0, 1}
assert dtype in {mx.float16, mx.bfloat16, mx.float32}
self.device = device
self.dtype = dtype
mx.set_default_device(device)
decoder_model.set_dtype(self.dtype)
self.decoder_model: T2SDecoderABC = decoder_model
# self.decoder_model.compile()
def _handle_request(self, request: T2SRequest):
decoder = self.decoder_model
session = T2SSessionMLX(decoder, request, device=self.device, dtype=self.dtype)
batch_idx = mx.arange(session.bsz)
t1 = 0.0
infer_speed = 0.0
infer_time = 0.0
with (
mx.stream(session.device),
Progress(
TextColumn("[cyan]{task.description}"),
BarColumn(),
TextColumn("{task.completed}/{task.total}"),
TimeRemainingColumn(),
console=console,
transient=True,
) as progress,
):
max_token = min(1800 - int(session.input_pos.max()), 1500)
task = progress.add_task("T2S Decoding", total=max_token)
for idx in range(1500):
progress.update(task, advance=1)
if idx == 0:
session.kv_cache = decoder.init_cache(session.bsz)
xy_dec = decoder.h.prefill(
session.xy_pos,
session.attn_mask,
session.kv_cache,
) # bs, seq_len, embed_dim
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
else:
args, kwds = decoder.pre_forward(session)
xy_dec = decoder.h(
session.input_pos,
session.xy_pos,
session.kv_cache,
batch_idx,
*args,
**kwds,
)
decoder.post_forward(idx, session)
logits = decoder.ar_predict_layer(xy_dec[:, -1])
session.input_pos += 1
if idx == 0:
logits[:, -1] = -mx.inf
samples = session.sample(
logits=logits,
previous_tokens=session.y[:, : session.y_len + idx],
top_k=request.top_k,
top_p=request.top_p,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
)
session.y[batch_idx, session.y_len + idx] = samples
argmax_token = mx.argmax(logits, axis=-1)
sample_token = samples.squeeze(1)
EOS_mask = (cast(Array, argmax_token == decoder.EOS)) | (sample_token == decoder.EOS)
newly_done_mask = EOS_mask & (~session.completed)
newly_done_indices = mx.where(newly_done_mask, batch_idx, -1)
pos = mx.where(newly_done_indices != -1, batch_idx, session.bsz)
pos_sorted = mx.sort(pos, axis=0)
valid_count = session.bsz - mx.sum(cast(Array, pos_sorted == session.bsz))
pos_final = pos_sorted[: int(valid_count)]
newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0)
if newly_done_indices.size > 0:
for i in newly_done_indices:
session.y_results[int(i)] = session.y[i, session.y_len : session.y_len + idx]
session.completed[newly_done_indices] = True
if mx.all(session.completed).item():
if session.y[:, session.y_len :].sum() == 0:
session.y_results = [mx.array([0]) for _ in range(session.bsz)]
logger.error("Bad Zero Prediction")
else:
logger.info(
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[cast(tuple[int, ...], i.shape)[-1] for i in session.y_results].__str__().strip('[]')}"
)
logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
infer_time = time.perf_counter() - t1
infer_speed = (idx - 1) / infer_time
break
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
for j in range(session.bsz):
if not session.completed[j].item():
session.y_results[j] = session.y[[j], session.y_len : session.y_len + 1499]
session.completed[j] = True
logger.error("Bad Full Prediction")
logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
infer_time = time.perf_counter() - t1
infer_speed = (idx - 1) / infer_time
break
y_emb = decoder.ar_audio_embedding(samples)
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
mx.eval(session.xy_pos, session.y)
if idx == 1:
t1 = time.perf_counter()
if idx % 100 == 0:
mx.clear_cache()
match session.device:
case mx.gpu:
mx.clear_cache()
case mx.cpu:
gc.collect()
result_mlx = session.y_results[: request.valid_length]
mx.eval(result_mlx)
result = [torch.tensor(k) for k in result_mlx]
return result, infer_speed, infer_time
def generate(self, request: T2SRequest):
try:
result, infer_speed, infer_time = self._handle_request(request)
t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
except Exception as e:
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
return t2s_result
@staticmethod
def replace_key(state_dict: dict[str, Tensor]):
state_dict_mlx: list[tuple[str, Array]] = []
for key, value in state_dict.items():
key = (
key.replace("model.", "")
.replace("in_proj_", "in_proj.")
.replace("self_attn", "attention")
.replace("linear", "feed_forward.linear")
.replace("norm1", "attention_norm")
.replace("norm2", "ffn_norm")
)
value_mlx = mx.array(value)
state_dict_mlx.append((key, value_mlx))
return state_dict_mlx
@staticmethod
def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "MLX-Varlen"):
logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
config = dict_s1["config"]
match backend:
case "MLX-Varlen":
decoder_cls: type[T2SDecoderABC] = mlx_varlen.T2SDecoder
case "MLX-Static":
decoder_cls = mlx_static.T2SDecoder
case "MLX-Quantized":
decoder_cls = mlx_quantized.T2SDecoder
case _:
raise RuntimeError(f"Backend {backend} Not Found")
decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
state_dict = dict_s1["weight"]
state_dict_mlx = T2SEngine.replace_key(state_dict)
decoder.load_weights(state_dict_mlx)
decoder.eval()
mx.eval(decoder)
if "Quantized" in backend and isinstance(decoder, mlx_quantized.T2SDecoder):
decoder.quantized()
mx.eval(decoder)
return decoder

View File

@ -0,0 +1,530 @@
from __future__ import annotations
import math
from abc import ABC, abstractmethod
from typing import MutableSequence, cast
import mlx.core as mx
import mlx.nn as nn
from .structs_mlx import KVCache, KVCacheProtocol, KVCacheQ, T2SDecoderProtocol, T2SSessionMLX
Array = mx.array
class TokenEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
vocab_size: int,
):
super().__init__()
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
@property
def weight(self):
return self.word_embeddings.weight
def embedding(self, index: int):
return self.word_embeddings.weight[index : index + 1]
def __call__(self, x: Array):
x = self.word_embeddings(x)
return x
class SinePositionalEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
scale: bool = False,
max_batch_size: int = 10,
max_seq_len: int = 1800,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
self.alpha = mx.ones(1)
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.reverse = False
self._pe = mx.zeros((max_batch_size, max_seq_len, embedding_dim))
self.compute_pe()
def compute_pe(self):
"""Reset the positional encodings."""
if self.reverse:
position = mx.expand_dims(mx.arange(self.max_seq_len - 1, -1, -1.0), axis=1)
else:
position = mx.expand_dims(mx.arange(self.max_seq_len), axis=1)
div_term = mx.exp(
mx.arange(
0,
self.embedding_dim,
2,
)
* -(math.log(10000.0) / self.embedding_dim)
)
pe = self._pe
pe[:, :, 0::2] = mx.sin(position * div_term)
pe[:, :, 1::2] = mx.cos(position * div_term)
def __call__(self, input_pos: Array, x: Array):
"""
Args:
input_pos (Array): [batch_size, ]
x (Array): [batch_size, 1, embed_dim]
Returns:
embedded_x (Array): [batch_size, 1, embed_dim]
"""
batch_size = cast(tuple[int, ...], x.shape)[0]
pe_values = self._pe[mx.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
return x * self.x_scale + self.alpha * mx.expand_dims(pe_values, 1) # (batch_size, 1, embed_dim)
def prefill(self, x: Array):
"""
Args:
x (Array): [batch_size, seq_len, embed_dim]
Returns:
embedded_x (Array): [batch_size, seq_len, embed_dim]
"""
pe_values = self._pe[:, : cast(tuple[int, ...], x.shape)[-2]]
return x * self.x_scale + self.alpha * pe_values
class KVCacheHND(KVCacheProtocol):
@staticmethod
def empty(kv_cache):
assert len(kv_cache) == 2
k_cache, v_cache = kv_cache
k_cache[:] = 0
v_cache[:] = 0
@staticmethod
def update_cache(input_pos, k_val, v_val, kv_cache, cache_idx):
# input_pos: [B, ], k_val: [B, H, 1, D]
assert len(kv_cache) == 2
k_out, v_out = kv_cache
ip0 = input_pos - 1
k_out[cache_idx, :, ip0, None] = k_val
v_out[cache_idx, :, ip0, None] = v_val
return k_out, v_out
@staticmethod
def prefill_kv(k_val, v_val, kv_cache):
# k_val: [B, S, H, D]
assert len(kv_cache) == 2
k_cache, v_cache = kv_cache
k_cache[..., : cast(tuple[int, ...], k_val.shape)[1], :] = k_val.swapaxes(1, 2)
v_cache[..., : cast(tuple[int, ...], v_val.shape)[1], :] = v_val.swapaxes(1, 2)
@staticmethod
def init_cache(batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype) -> KVCache:
cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
return (mx.zeros(cache_shape, dtype=dtype), mx.zeros(cache_shape, dtype=dtype))
class KVCacheHNDQuantized(KVCacheProtocol):
@staticmethod
def _el_per_int(bits: int) -> int:
return 32 // bits
@staticmethod
def _packed_dim(head_dim: int, bits: int = 8) -> int:
el_per_int = KVCacheHNDQuantized._el_per_int(bits)
if head_dim % el_per_int != 0:
raise ValueError(f"{head_dim=} is not divisible by {el_per_int=} ({bits=})")
return head_dim // el_per_int
@staticmethod
def _group_count(head_dim: int, group_size: int = 32) -> int:
assert group_size in {32, 64, 128}
if head_dim % group_size != 0:
raise ValueError(f"{head_dim} is not divisible by {group_size=}")
return head_dim // group_size
@staticmethod
def empty(kv_cache) -> None:
assert len(kv_cache) == 3
(k_q, k_s, k_b), (v_q, v_s, v_b), (_, __) = kv_cache
k_q[:] = 0
k_s[:] = 0
k_b[:] = 0
v_q[:] = 0
v_s[:] = 0
v_b[:] = 0
@staticmethod
def update_cache(
input_pos,
k_val,
v_val,
kv_cache,
cache_idx,
):
# input_pos: [B, ], k_val: [B, H, 1, D]
assert len(kv_cache) == 3
(k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache
k_q, k_s, k_b = mx.quantize(k_val, group_size=group_size, bits=bits)
v_q, v_s, v_b = mx.quantize(v_val, group_size=group_size, bits=bits)
ip0 = input_pos - 1
k_q_out[cache_idx, :, ip0, None] = k_q
k_s_out[cache_idx, :, ip0, None] = k_s
k_b_out[cache_idx, :, ip0, None] = k_b
v_q_out[cache_idx, :, ip0, None] = v_q
v_s_out[cache_idx, :, ip0, None] = v_s
v_b_out[cache_idx, :, ip0, None] = v_b
return (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits)
@staticmethod
def prefill_kv(
k_val,
v_val,
kv_cache,
) -> None:
assert len(kv_cache) == 3
(k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache
S = cast(tuple[int, ...], k_val.shape)[1]
k_sw = k_val.swapaxes(1, 2)
v_sw = v_val.swapaxes(1, 2)
k_q, k_s, k_b = mx.quantize(k_sw, group_size=group_size, bits=bits)
v_q, v_s, v_b = mx.quantize(v_sw, group_size=group_size, bits=bits)
k_q_out[..., :S, :] = k_q
k_s_out[..., :S, :] = k_s
k_b_out[..., :S, :] = k_b
v_q_out[..., :S, :] = v_q
v_s_out[..., :S, :] = v_s
v_b_out[..., :S, :] = v_b
@staticmethod
def init_cache(
batch_size: int,
max_seq_length: int,
n_heads: int,
head_dim: int,
dtype: mx.Dtype,
*,
group_size: int = 32,
bits: int = 8,
) -> KVCacheQ:
packed_dim = KVCacheHNDQuantized._packed_dim(head_dim, bits=bits)
group_cnt = KVCacheHNDQuantized._group_count(head_dim, group_size=group_size)
packed_shape = (batch_size, n_heads, max_seq_length, packed_dim)
group_shape = (batch_size, n_heads, max_seq_length, group_cnt)
k_q = mx.zeros(packed_shape, dtype=mx.uint32)
k_s = mx.zeros(group_shape, dtype=dtype)
k_b = mx.zeros(group_shape, dtype=dtype)
v_q = mx.zeros(packed_shape, dtype=mx.uint32)
v_s = mx.zeros(group_shape, dtype=dtype)
v_b = mx.zeros(group_shape, dtype=dtype)
return (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits)
class AttentionABC(ABC, nn.Module):
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int, *args, **kwds):
super().__init__()
self.n_head = n_head
self.hidden_dim = hidden_dim
assert hidden_dim % n_head == 0
self.head_dim = hidden_dim // n_head
self.max_seq_length = max_seq_length
# key, query, value projections for all heads, but in a batch
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
self.scale = 1 / math.sqrt(self.head_dim)
self.kc_class: KVCacheProtocol
@abstractmethod
def __call__(
self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array
) -> Array: ...
def prefill(self, x: Array, kv_cache: KVCache | KVCacheQ, attn_mask: Array):
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
q, k, v = self.in_proj(mx.expand_dims(x, 0)).split(3, axis=-1)
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
self.kc_class.prefill_kv(k, v, kv_cache)
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
attn = mx.fast.scaled_dot_product_attention(q, k, v, mask=attn_mask, scale=self.scale)
attn = mx.nan_to_num(attn)
attn = attn.swapaxes(1, 2).reshape(1, -1, self.hidden_dim)
output = self.out_proj(attn)
return output
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int) -> None:
super().__init__()
self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
def __call__(self, x: Array):
return self.linear2(nn.relu(self.linear1(x)))
class TransformerBlockABC(nn.Module):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.max_seq_length = max_seq_length
self.attention: AttentionABC
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
self.attention_norm = nn.LayerNorm(self.hidden_dim)
self.ffn_norm = nn.LayerNorm(self.hidden_dim)
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
h = self.attention_norm(
x
+ self.attention(
x,
input_pos,
kv_cache,
cache_idx,
attn_mask,
)
)
out = self.ffn_norm(h + self.feed_forward(h))
return out
def prefill(self, x: Array, attn_mask: Array, kv_cache: KVCache | KVCacheQ):
h = self.attention_norm(
x
+ self.attention.prefill(
x,
kv_cache,
attn_mask,
)
)
out = self.ffn_norm(h + self.feed_forward(h))
return out
class TransformerDecoderABC(nn.Module):
def __init__(
self,
hidden_dim: int,
n_layer: int,
n_head: int,
ffn_dim: int,
vocab_size: int,
max_seq_length: int,
max_batch_size: int,
*args,
**kwds,
) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.n_head = n_head
assert hidden_dim % n_head == 0
self.head_dim = hidden_dim // n_head
self.vocab_size = vocab_size
self.n_layer = n_layer
self.layers: MutableSequence[TransformerBlockABC]
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
def __call__(
self,
input_pos: Array,
x: Array,
kv_caches: MutableSequence[KVCache | KVCacheQ],
cache_idx: Array,
*args,
**kwds,
):
for layer, kv_cache in zip(self.layers, kv_caches):
x = layer(
x,
input_pos,
kv_cache,
cache_idx,
*args,
**kwds,
)
return x
def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCache | KVCacheQ]):
for layer, kv_cache in zip(self.layers, kv_caches):
x = layer.prefill(
x,
mask,
kv_cache,
)
return x
class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
def __init__(
self,
config: dict,
max_seq_length: int = 1800,
max_batch_size: int = 10,
) -> None:
super().__init__()
hidden_dim: int = config["model"]["hidden_dim"]
embedding_dim: int = config["model"]["embedding_dim"]
n_head: int = config["model"]["head"]
n_layer: int = config["model"]["n_layer"]
vocab_size: int = config["model"]["vocab_size"]
phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
EOS: int = config["model"]["EOS"]
ffn_dim: int = hidden_dim * 4
self.n_layer = int(n_layer)
self.hidden_dim = int(hidden_dim)
self.n_head = int(n_head)
assert hidden_dim % n_head == 0
self.head_dim = int(hidden_dim // n_head)
self.embedding_dim = int(embedding_dim)
self.ffn_dim = int(ffn_dim)
self.vocab_size = int(vocab_size)
self.phoneme_vocab_size = int(phoneme_vocab_size)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
self.EOS = EOS
assert self.EOS == self.vocab_size - 1
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
self.h: TransformerDecoderABC
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim,
scale=False,
max_batch_size=max_batch_size,
max_seq_len=max_seq_length,
)
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim,
scale=False,
max_batch_size=max_batch_size,
max_seq_len=max_seq_length,
)
self.kv_class: KVCacheProtocol
def init_cache(self, bsz: int = 0, *args, **kwds) -> MutableSequence[KVCache | KVCacheQ]:
bsz = bsz or self.h.max_batch_size
assert bsz <= self.h.max_batch_size
seq_lens = self.h.max_seq_length
dtype = self.bert_proj.bias.dtype
cache: MutableSequence[KVCache | KVCacheQ] = [
self.kv_class.init_cache(bsz, seq_lens, self.n_head, self.head_dim, dtype, *args, **kwds)
for _ in range(self.n_layer)
]
mx.eval(cache)
return cache
def embed(
self,
x: list[Array],
y: Array,
bert_features: list[Array],
):
x_len: list[int] = [cast(tuple[int, ...], i.shape)[0] for i in x]
x_len_max = max(x_len)
xy_pos = mx.zeros((len(x), x_len_max + cast(tuple[int, ...], y.shape)[1], self.embedding_dim)).astype(
bert_features[0].dtype
)
bert_features = list(map(lambda x: x.swapaxes(0, 1), bert_features))
y_len = cast(tuple[int, ...], y.shape)[1]
y_emb = self.ar_audio_embedding(y)
y_pos = self.ar_audio_position.prefill(y_emb)
for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
x_emb = self.ar_text_embedding(x_)
bert = self.bert_proj(bert_feature)
x_emb = x_emb + bert
x_pos = self.ar_text_position.prefill(mx.expand_dims(x_emb, 0))
xy_pos[[bs], :len_] = x_pos
xy_pos[[bs], len_ : len_ + y_len] = y_pos
mx.eval(xy_pos)
return xy_pos
def compile(self):
setattr(self.h, "__call__", mx.compile(self.h.__call__))
# setattr(self.h, "prefill", mx.compile(self.h.prefill, shapeless=True))
def pre_forward(self, session: T2SSessionMLX):
attn_mask = session.attn_mask
return list(), dict(attn_mask=attn_mask)
def post_forward(self, idx: int, session: T2SSessionMLX) -> None:
if idx == 0:
prefill_len = session.prefill_len
bsz = session.bsz
range_tensor = mx.arange(self.max_seq_length).reshape(1, 1, 1, self.max_seq_length)
prefill_len_expanded = prefill_len.reshape(bsz, 1, 1, 1)
attn_mask = range_tensor < prefill_len_expanded
attn_mask = mx.repeat(attn_mask, self.n_head, 1)
session.attn_mask = attn_mask
attn_mask = session.attn_mask
input_pos = session.input_pos
attn_mask[mx.arange(session.bsz), :, :, input_pos] = True
mx.eval(attn_mask)

View File

@ -0,0 +1,28 @@
import importlib.util
import torch
from .sample_funcs import sample_naive
from .structs import T2SRequest, T2SResult
from .t2s_engine import T2SEngine as T2SEngineTorch
backends = ["torch_varlen"]
if torch.cuda.is_available():
backends.append("torch_static_cuda_graph")
if importlib.util.find_spec("sageattention") is not None:
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(i)
sm_version = major + minor / 10.0
if sm_version >= 7.0:
backends.append("sage_attn_varlen_cuda_graph")
if importlib.util.find_spec("flash_attn") is not None:
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(i)
sm_version = major + minor / 10.0
if sm_version >= 7.5:
backends.append("flash_attn_varlen_cuda_graph")
if torch.mps.is_available():
backends.append("mps_flash_attn_varlen")
__all__ = ["T2SEngineTorch", "T2SRequest", "sample_naive", "T2SResult", "backends"]

View File

@ -0,0 +1,157 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""
from typing import Dict, List, Tuple
import kernels
import torch
from .. import nn
from ..structs import T2SSession
from ..t2s_model_abc import (
AttentionABC,
CUDAGraphCacheABC,
FeedForward,
KVCacheNHD,
KVCacheProtocol,
T2SDecoderABC,
TransformerBlockABC,
TransformerDecoderABC,
)
flash_attn_kernel = None
try:
import flash_attn_interface as flash_attn # type: ignore
flash_attn_kernel = flash_attn.flash_attn_with_kvcache
except ModuleNotFoundError:
try:
import flash_attn # type: ignore
flash_attn_kernel = flash_attn.flash_attn_with_kvcache
except ModuleNotFoundError:
pass
if flash_attn_kernel is None:
flash_attn_kernel = kernels.get_kernel("kernels-community/flash-attn").flash_attn_with_kvcache
Tensor = torch.Tensor
class Attention(AttentionABC):
def __init__(self, n_head, hidden_dim, max_seq_length):
super().__init__(n_head, hidden_dim, max_seq_length)
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds) -> Tensor:
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
attn: Tensor = flash_attn.flash_attn_with_kvcache( # type: ignore
q, kv_cache.k_cache, kv_cache.v_cache, k, v, cache_seqlens=input_pos - 1
)
attn = attn.view(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
class TransformerBlock(TransformerBlockABC):
def __init__(self, n_head, ffn_dim, hidden_dim, max_seq_length) -> None:
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
self.attention = Attention(n_head, hidden_dim, max_seq_length)
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
self.attention_norm = nn.LayerNorm([self.hidden_dim])
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
class TransformerDecoder(TransformerDecoderABC):
def __init__(
self,
hidden_dim,
n_layer,
n_head,
ffn_dim,
vocab_size,
max_seq_length,
max_batch_size,
) -> None:
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
self.layers = nn.ModuleList( # type: ignore
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
)
class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
max_seq_length=1800,
max_batch_size=10,
) -> None:
assert torch.cuda.is_available()
super().__init__(config, max_seq_length, max_batch_size)
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
self.h: TransformerDecoderABC = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.kv_class = KVCacheNHD
def post_forward(self, idx: int, session: T2SSession) -> None:
return super().post_forward(idx, session)
def pre_forward(self, session: T2SSession) -> Tuple[List, Dict]:
return super().pre_forward(session)
class CUDAGraphCache(CUDAGraphCacheABC):
def __init__(
self,
decoder: T2SDecoder,
) -> None:
super().__init__(decoder)
def release_graph(self, session: T2SSession):
if session.id != self.id:
self.assigned = False
else:
del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
def get_cache_graph(self, session: T2SSession):
assert self.graph
session.graph = self.graph
session.stream = self.stream
session.xy_pos_ = self.xy_pos
session.xy_dec_ = self.xy_dec
session.input_pos = self.input_pos.copy_(session.input_pos)
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
cache.sync_cache(cache_)
def capture_new_graph(self, session: T2SSession):
session.xy_pos_ = self.xy_pos.clone()
session.xy_dec_ = self.xy_dec.clone()
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
session.graph = graph
session.stream = torch.cuda.Stream() # type: ignore

View File

@ -0,0 +1,165 @@
import torch
from torch.nn import functional as F
from .. import nn
from ..structs import KVCacheProtocol, T2SSession
from ..t2s_model_abc import (
AttentionABC,
CUDAGraphCacheABC,
FeedForward,
KVCacheHND,
T2SDecoderABC,
TransformerBlockABC,
TransformerDecoderABC,
)
Tensor = torch.Tensor
class Attention(AttentionABC):
def __init__(self, n_head, hidden_dim, max_seq_length):
super().__init__(n_head, hidden_dim, max_seq_length)
# key, query, value projections for all heads, but in a batch
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
k, v = kv_cache.update(input_pos, k, v)
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
class TransformerBlock(TransformerBlockABC):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
self.attention = Attention(n_head, hidden_dim, max_seq_length)
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
self.attention_norm = nn.LayerNorm([self.hidden_dim])
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
class TransformerDecoder(TransformerDecoderABC):
def __init__(
self,
hidden_dim,
n_layer,
n_head,
ffn_dim,
vocab_size,
max_seq_length,
max_batch_size,
) -> None:
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
self.layers = nn.ModuleList( # type: ignore
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
)
class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
max_seq_length=1800,
max_batch_size=10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
self.h: TransformerDecoderABC = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.kv_class = KVCacheHND
def pre_forward(self, session: T2SSession):
attn_mask = session.attn_mask
return list(), dict(attn_mask=attn_mask)
def post_forward(self, idx: int, session: T2SSession) -> None:
if idx == 0:
prefill_len = session.prefill_len
bsz = session.bsz
range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
attn_mask = range_tensor < prefill_len_expanded
attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
session.attn_mask = attn_mask
attn_mask = session.attn_mask
input_pos = session.input_pos
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
class CUDAGraphCache(CUDAGraphCacheABC):
def __init__(
self,
decoder,
) -> None:
super().__init__(decoder)
if torch.cuda.is_available():
self.attn_mask = (
torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
.bool()
.to(self.device, self.dtype)
)
def release_graph(self, session: T2SSession):
if session.id != self.id:
self.assigned = False
else:
del (
session.graph,
session.xy_pos_,
session.xy_dec_,
session.input_pos,
session.kv_cache,
session.attn_mask,
)
def get_cache_graph(self, session: T2SSession):
assert self.graph
session.graph = self.graph
session.stream = self.stream
session.xy_pos_ = self.xy_pos
session.xy_dec_ = self.xy_dec
session.input_pos = self.input_pos.copy_(session.input_pos)
session.attn_mask = self.attn_mask
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
cache.sync_cache(cache_)
def capture_new_graph(self, session: T2SSession):
session.xy_pos_ = self.xy_pos.clone()
session.xy_dec_ = self.xy_dec.clone()
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
session.graph = graph
session.stream = torch.cuda.Stream() # type: ignore

View File

@ -0,0 +1,176 @@
from typing import MutableSequence
import sageattention # type: ignore
import torch
from .. import nn
from ..structs import T2SSession
from ..t2s_model_abc import (
AttentionABC,
CUDAGraphCacheABC,
FeedForward,
KVCacheHND,
KVCacheProtocol,
T2SDecoderABC,
TransformerBlockABC,
TransformerDecoderABC,
)
Tensor = torch.Tensor
class Attention(AttentionABC):
def __init__(self, n_head, hidden_dim, max_seq_length):
super().__init__(n_head, hidden_dim, max_seq_length)
# key, query, value projections for all heads, but in a batch
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(
self,
x: Tensor,
input_pos: Tensor,
kv_cache: KVCacheProtocol,
cu_seqlens_q: Tensor,
cu_seqlens_kv: Tensor,
) -> Tensor:
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
k, v = kv_cache.update(input_pos, k, v)
attn: Tensor = sageattention.sageattn_varlen(
q,
k,
v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=1,
max_seqlen_k=self.max_seq_length,
)
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
class TransformerBlock(TransformerBlockABC):
def __init__(self, n_head, ffn_dim, hidden_dim, max_seq_length) -> None:
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
self.attention = Attention(n_head, hidden_dim, max_seq_length)
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
self.attention_norm = nn.LayerNorm([self.hidden_dim])
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
class TransformerDecoder(TransformerDecoderABC):
def __init__(
self,
hidden_dim,
n_layer,
n_head,
ffn_dim,
vocab_size,
max_seq_length,
max_batch_size,
) -> None:
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
self.layers = nn.ModuleList( # type: ignore
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
)
class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
max_seq_length=1800,
max_batch_size=10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
self.h: TransformerDecoderABC = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.kv_class = KVCacheHND
def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
return list(), dict(cu_seqlens_q=session.cu_seqlens_q, cu_seqlens_kv=session.cu_seqlens_kv)
def post_forward(self, idx: int, session: T2SSession):
if idx == 0:
session.cu_seqlens_q = torch.arange(0, session.bsz + 1, dtype=torch.int32)
session.cu_seqlens_kv = torch.cat([torch.tensor(0, dtype=torch.int32), session.input_pos])
else:
cu_seqlens_q = session.cu_seqlens_q
cu_seqlens_kv = session.cu_seqlens_kv
cu_seqlens_kv.add_(cu_seqlens_q)
class CUDAGraphCache(CUDAGraphCacheABC):
def __init__(
self,
decoder: T2SDecoder,
) -> None:
super().__init__(decoder)
if torch.cuda.is_available():
self.cu_seqlens_q = torch.arange(0, decoder.max_batch_size + 1, dtype=torch.int32).to(self.device)
self.cu_seqlens_kv = torch.cat([torch.tensor(0, dtype=torch.int32), self.input_pos]).to(self.device)
def release_graph(self, session: T2SSession):
if session.id != self.id:
self.assigned = False
else:
del (
session.graph,
session.xy_pos_,
session.xy_dec_,
session.input_pos,
session.kv_cache,
session.cu_seqlens_q,
session.cu_seqlens_kv,
)
def get_cache_graph(self, session: T2SSession):
assert self.graph
session.graph = self.graph
session.stream = self.stream
session.xy_pos_ = self.xy_pos
session.xy_dec_ = self.xy_dec
session.input_pos = self.input_pos.copy_(session.input_pos)
session.cu_seqlens_q = self.cu_seqlens_q
session.cu_seqlens_kv = self.cu_seqlens_kv
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
cache.sync_cache(cache_)
def capture_new_graph(self, session: T2SSession):
session.xy_pos_ = self.xy_pos.clone()
session.xy_dec_ = self.xy_dec.clone()
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
session.cu_seqlens_q = self.cu_seqlens_q.clone().copy_(session.cu_seqlens_q)
session.cu_seqlens_kv = self.cu_seqlens_kv.clone().copy_(session.cu_seqlens_kv)
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
session.graph = graph
session.stream = torch.cuda.Stream() # type: ignore

View File

@ -0,0 +1,165 @@
import torch
from torch.nn import functional as F
from .. import nn
from ..structs import KVCacheProtocol, T2SSession
from ..t2s_model_abc import (
AttentionABC,
CUDAGraphCacheABC,
FeedForward,
KVCacheHND,
T2SDecoderABC,
TransformerBlockABC,
TransformerDecoderABC,
)
Tensor = torch.Tensor
class Attention(AttentionABC):
def __init__(self, n_head, hidden_dim, max_seq_length):
super().__init__(n_head, hidden_dim, max_seq_length)
# key, query, value projections for all heads, but in a batch
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
k, v = kv_cache.update(input_pos, k, v)
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
class TransformerBlock(TransformerBlockABC):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
self.attention = Attention(n_head, hidden_dim, max_seq_length)
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
self.attention_norm = nn.LayerNorm([self.hidden_dim])
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
class TransformerDecoder(TransformerDecoderABC):
def __init__(
self,
hidden_dim,
n_layer,
n_head,
ffn_dim,
vocab_size,
max_seq_length,
max_batch_size,
) -> None:
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
self.layers = nn.ModuleList( # type: ignore
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
)
class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
max_seq_length=1800,
max_batch_size=10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
self.h: TransformerDecoderABC = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.kv_class = KVCacheHND
def pre_forward(self, session: T2SSession):
attn_mask = session.attn_mask
return list(), dict(attn_mask=attn_mask)
def post_forward(self, idx: int, session: T2SSession) -> None:
if idx == 0:
prefill_len = session.prefill_len
bsz = session.bsz
range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
attn_mask = range_tensor < prefill_len_expanded
attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
session.attn_mask = attn_mask
attn_mask = session.attn_mask
input_pos = session.input_pos
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
class CUDAGraphCache(CUDAGraphCacheABC):
def __init__(
self,
decoder,
) -> None:
super().__init__(decoder)
if torch.cuda.is_available():
self.attn_mask = (
torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
.bool()
.to(self.device, self.dtype)
)
def release_graph(self, session: T2SSession):
if session.id != self.id:
self.assigned = False
else:
del (
session.graph,
session.xy_pos_,
session.xy_dec_,
session.input_pos,
session.kv_cache,
session.attn_mask,
)
def get_cache_graph(self, session: T2SSession):
assert self.graph
session.graph = self.graph
session.stream = self.stream
session.xy_pos_ = self.xy_pos
session.xy_dec_ = self.xy_dec
session.input_pos = self.input_pos.copy_(session.input_pos)
session.attn_mask = self.attn_mask
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
cache.sync_cache(cache_)
def capture_new_graph(self, session: T2SSession):
session.xy_pos_ = self.xy_pos.clone()
session.xy_dec_ = self.xy_dec.clone()
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
session.graph = graph
session.stream = torch.cuda.Stream() # type: ignore

View File

@ -0,0 +1,144 @@
from typing import NoReturn
import torch
from torch.nn import functional as F
from .. import nn
from ..structs import KVCacheProtocol, T2SSession
from ..t2s_model_abc import (
AttentionABC,
CUDAGraphCacheABC,
FeedForward,
KVCacheHNDVarlen,
T2SDecoderABC,
TransformerBlockABC,
TransformerDecoderABC,
)
Tensor = torch.Tensor
class Attention(AttentionABC):
def __init__(self, n_head, hidden_dim, max_seq_length):
super().__init__(n_head, hidden_dim, max_seq_length)
# key, query, value projections for all heads, but in a batch
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x).chunk(3, dim=-1)
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
k, v = kv_cache.update(input_pos, k, v)
max_idx = input_pos.max()
q, k, v = map(lambda x: x[..., :max_idx, :], (q, k, v))
mask = attn_mask[..., :max_idx]
attn = F.scaled_dot_product_attention(q, k, v, mask)
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
attn = self.out_proj(attn)
return attn
class TransformerBlock(TransformerBlockABC):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
self.attention = Attention(n_head, hidden_dim, max_seq_length)
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
self.attention_norm = nn.LayerNorm([self.hidden_dim])
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
class TransformerDecoder(TransformerDecoderABC):
def __init__(
self,
hidden_dim,
n_layer,
n_head,
ffn_dim,
vocab_size,
max_seq_length,
max_batch_size,
) -> None:
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
self.layers = nn.ModuleList( # type: ignore
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
)
class T2SDecoder(T2SDecoderABC):
def __init__(
self,
config,
max_seq_length=1800,
max_batch_size=10,
) -> None:
super().__init__(config, max_seq_length, max_batch_size)
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
self.h: TransformerDecoderABC = TransformerDecoder(
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
)
self.kv_class = KVCacheHNDVarlen
def capture(
self,
*args,
**kwds,
) -> NoReturn:
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
def pre_forward(self, session: T2SSession):
attn_mask = session.attn_mask
return list(), dict(attn_mask=attn_mask)
def post_forward(self, idx: int, session: T2SSession) -> None:
if idx == 0:
prefill_len = session.prefill_len
bsz = session.bsz
range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
attn_mask = range_tensor < prefill_len_expanded
attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
session.attn_mask = attn_mask
attn_mask = session.attn_mask
input_pos = session.input_pos
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
class CUDAGraphCache(CUDAGraphCacheABC):
def __init__(
self,
decoder,
) -> None:
super().__init__(decoder, False)
def release_graph(self, session: T2SSession):
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
def get_cache_graph(self, session: T2SSession):
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
def capture_new_graph(self, session: T2SSession):
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")

View File

@ -0,0 +1,69 @@
"""
Enhanced Type Hint nn.Module
Modified From https://github.com/labmlai/labml/blob/master/helpers/labml_helpers/module.py
"""
from typing import Any
import torch.nn
from torch.nn import (
functional as functional,
)
from torch.nn import (
utils as utils,
)
from torch.nn.modules import * # type: ignore # noqa: F403
from torch.nn.parameter import (
Parameter as Parameter,
)
Tensor = torch.Tensor
class Module(torch.nn.Module):
r"""
Wraps ``torch.nn.Module`` to overload ``__call__`` instead of
``forward`` for better type checking.
`PyTorch Github issue for clarification <https://github.com/pytorch/pytorch/issues/44605>`_
"""
def _forward_unimplemented(self, *input: Any) -> None:
# To stop PyTorch from giving abstract methods warning
pass
def __init_subclass__(cls, **kwargs):
if cls.__dict__.get("__call__", None) is None:
return
setattr(cls, "forward", cls.__dict__["__call__"])
delattr(cls, "__call__")
@property
def device(self) -> torch.device:
params = self.parameters()
try:
sample_param = next(params)
return sample_param.device
except StopIteration:
raise RuntimeError(f"Unable to determine device of {self.__class__.__name__}") from None
class Linear(torch.nn.Linear):
def __call__(self, input: Tensor) -> Tensor:
return super().__call__(input)
class Dropout(torch.nn.Dropout):
def __call__(self, input: Tensor) -> Tensor:
return super().__call__(input)
class Embedding(torch.nn.Embedding):
def __call__(self, input: Tensor) -> Tensor:
return super().__call__(input)
class LayerNorm(torch.nn.LayerNorm):
def __call__(self, input: Tensor) -> Tensor:
return super().__call__(input)

View File

@ -0,0 +1,63 @@
from typing import Protocol
import torch
import torch.nn.functional as F
Tensor = torch.Tensor
class SampleProtocol(Protocol):
@staticmethod
def __call__(
logits: Tensor,
previous_tokens: Tensor,
temperature: float,
top_k: int,
top_p: float,
repetition_penalty: float,
) -> Tensor: ...
class sample_naive(SampleProtocol):
@staticmethod
def __call__(
logits: Tensor,
previous_tokens: Tensor,
temperature: float,
top_k: int,
top_p: float,
repetition_penalty: float,
):
if temperature <= 1e-5:
probs = F.softmax(logits, dim=-1)
return torch.argmax(probs, dim=-1, keepdim=True).to(dtype=torch.int32)
if repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=1, index=previous_tokens, src=score)
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
cum_probs[cum_probs > 1] = 1
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
logits /= temperature
v, _ = torch.topk(logits, top_k)
pivot = v[:, -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = F.softmax(logits, dim=-1)
q = torch.empty_like(probs).exponential_(1.0)
idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
return idx_next

View File

@ -0,0 +1,151 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal, MutableSequence, Optional, Protocol
import torch
from .sample_funcs import SampleProtocol, sample_naive
Tensor = torch.Tensor
@dataclass
class T2SResult:
result: list[Tensor] | None = None
infer_speed: tuple[float, float] = (0.0, 0.0)
status: Literal["Success", "Error"] = "Success"
exception: Optional[Exception] = None
traceback: Optional[str] = None
@dataclass
class T2SRequest:
x: list[torch.Tensor]
x_lens: Tensor
prompts: torch.Tensor
bert_feature: list[Tensor]
valid_length: int
top_k: int = 5
top_p: float = 1
early_stop_num: int = -1
temperature: float = 1.0
repetition_penalty: float = 1.35
use_cuda_graph: bool = False
debug: bool = False
class KVCacheProtocol(Protocol):
k_cache: Tensor
v_cache: Tensor
def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None: ...
def empty(self) -> None: ...
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> tuple[Tensor, Tensor]: ...
def prefill_kv(self, k_val: Tensor, v_val: Tensor) -> None: ...
def sync_cache(self, kv_cache: KVCacheProtocol) -> None: ...
class T2SDecoderProtocol(Protocol):
max_seq_length: int
EOS: int
n_head: int
@property
def device(self) -> torch.device: ...
def embed(self, x: list[Tensor], y: Tensor, bert_features: list[Tensor]) -> Tensor: ...
class T2SEngineProtocol(Protocol):
def _handle_request(self, request: T2SRequest) -> tuple[list[Tensor], float, float]: ...
def generate(self, request: T2SRequest) -> T2SResult: ...
class T2SSession:
def __init__(
self,
decoder: T2SDecoderProtocol,
request: T2SRequest,
sapmle_func: type[SampleProtocol] = sample_naive,
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
):
with device:
self.decoder = decoder
self.request = request
self.device = device
self.dtype = dtype
bsz = len(request.x)
y_len = request.prompts.size(-1)
self.bsz = bsz
self.y_len = y_len
request.prompts = request.prompts.to(device, torch.int32)
# Cache
self.kv_cache: MutableSequence[KVCacheProtocol]
self.sample = sapmle_func()
# Forward args
self.x = [i.to(device) for i in request.x]
self.x_lens = request.x_lens.to(torch.int32)
self.y = torch.zeros((bsz, decoder.max_seq_length)).to(torch.int32)
self.y[:, : request.prompts.shape[-1]] = request.prompts
self.bert_feature = [i.to(device, dtype) for i in request.bert_feature]
self.prefill_len = self.x_lens + request.prompts.size(1)
self.input_pos = torch.zeros_like(self.prefill_len)
self.input_pos.add_(self.prefill_len)
# CUDA Graph
self.stream: Optional[torch.cuda.Stream] = None
self.graph: Optional[torch.cuda.CUDAGraph] = None
self.xy_pos_: Tensor
self.xy_dec_: Tensor
# EOS
self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
self.y_results: list[Tensor] = [None] * len(self.x) # type: ignore
self.xy_pos = decoder.embed(self.x, request.prompts, self.bert_feature)
max_len = int(self.prefill_len.max().item())
attn_mask = torch.zeros(size=(bsz, max_len, max_len), dtype=torch.bool)
for bs in range(bsz):
pos = int(self.x_lens[bs])
seq_len = pos + y_len
attn_mask[bs, :seq_len, :pos] = True
ar_mask = ~torch.triu(
input=torch.ones(
size=(
y_len,
y_len,
),
dtype=torch.bool,
),
diagonal=1,
)
attn_mask[bs, pos:seq_len, pos:seq_len] = ar_mask
self.attn_mask = attn_mask
self.attn_mask = attn_mask.unsqueeze(0).expand(-1, decoder.n_head, -1, -1)
self.id: int = -1
# Sage Attn & Transformer Engine Impl
self.cu_seqlens_q: Tensor
self.cu_seqlens_kv: Tensor

View File

@ -0,0 +1,220 @@
import contextlib
import gc
import os
import sys
import time
import traceback
from importlib import import_module
import torch
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
from ..logger import console, logger
from .structs import T2SEngineProtocol, T2SRequest, T2SResult, T2SSession
from .t2s_model_abc import (
CUDAGraphCacheABC,
T2SDecoderABC,
TorchProfiler,
)
torch.set_grad_enabled(False)
class T2SEngine(T2SEngineProtocol):
def __init__(
self,
decoder_model: T2SDecoderABC,
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
) -> None:
assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"}
assert dtype in {torch.float16, torch.bfloat16, torch.float32}
self.device = device
self.dtype = dtype
self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
self.graphcache: CUDAGraphCacheABC = self.init_cache()
def _handle_request(self, request: T2SRequest):
with self.device:
decoder = self.decoder_model
session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
batch_idx = torch.arange(session.bsz)
t1 = 0.0
infer_speed = 0.0
infer_time = 0.0
torch_profiler = TorchProfiler(request.debug)
with (
torch_profiler.profiler(),
Progress(
TextColumn("[cyan]{task.description}"),
BarColumn(),
TextColumn("{task.completed}/{task.total}"),
TimeRemainingColumn(),
console=console,
transient=True,
) as progress,
):
max_token = min(1800 - session.input_pos.max(), 1500)
task = progress.add_task("T2S Decoding", total=max_token)
for idx in range(max_token):
progress.update(task, advance=1)
if idx == 0:
session.kv_cache = decoder.init_cache(session.bsz)
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
else:
if request.use_cuda_graph and session.graph is None and torch.cuda.is_available():
self.graphcache.assign_graph(session)
with torch_profiler.record("AR"):
if session.graph:
assert session.stream
session.stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(session.stream):
session.xy_pos_.copy_(session.xy_pos)
session.graph.replay()
xy_dec = session.xy_dec_.clone()
else:
args, kwds = decoder.pre_forward(session)
xy_dec = decoder.h(
session.input_pos,
session.xy_pos,
session.kv_cache,
*args,
**kwds,
)
with torch.cuda.stream(session.stream) if session.stream is not None else contextlib.nullcontext():
decoder.post_forward(idx, session)
logits = decoder.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
logits[:, -1] = float("-inf")
with torch_profiler.record("Sampling"):
samples = session.sample(
logits=logits,
previous_tokens=session.y[:, : session.y_len + idx],
top_k=request.top_k,
top_p=request.top_p,
repetition_penalty=request.repetition_penalty,
temperature=request.temperature,
)
session.y[batch_idx, session.y_len + idx] = samples
session.input_pos.add_(1)
with torch_profiler.record("EOS"):
argmax_token = torch.argmax(logits, dim=-1)
sample_token = samples.squeeze(1)
EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
newly_done_mask = EOS_mask & (~session.completed)
newly_done_indices = newly_done_mask.nonzero()
if newly_done_indices.numel() > 0:
for i in newly_done_indices:
session.y_results[i] = session.y[i, session.y_len : session.y_len + idx]
session.completed[newly_done_indices] = True
if torch.all(session.completed).item():
if session.y[:, session.y_len :].sum() == 0:
session.y_results = [torch.tensor(0) for _ in range(session.bsz)]
logger.error("Bad Zero Prediction")
else:
logger.info(
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
)
logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
infer_time = time.perf_counter() - t1
infer_speed = (idx - 1) / infer_time
break
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
for i in range(session.bsz):
if not session.completed[i].item():
session.y_results[i] = session.y[i, session.y_len : session.y_len + 1499]
session.completed[i] = True
logger.error("Bad Full Prediction")
break
with torch_profiler.record("NextPos"):
y_emb = decoder.ar_audio_embedding(samples)
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
if idx == 1:
torch_profiler.start()
t1 = time.perf_counter()
if idx == 51:
torch_profiler.end()
if idx % 100 == 0:
match session.device.type:
case "cuda":
torch.cuda.empty_cache()
case "mps":
torch.mps.empty_cache()
case "xpu":
torch.xpu.empty_cache()
case "mtia":
torch.mtia.empty_cache()
match session.device.type:
case "cuda":
if session.stream is not None:
torch.cuda.current_stream().wait_stream(session.stream)
torch.cuda.empty_cache()
case "mps":
torch.mps.empty_cache()
case "xpu":
torch.xpu.empty_cache()
case "mtia":
torch.mtia.empty_cache()
case "cpu":
gc.collect()
torch_profiler.end()
if request.use_cuda_graph and torch.cuda.is_available():
self.graphcache.release_graph(session)
return session.y_results[: request.valid_length], infer_speed, infer_time
def generate(self, request: T2SRequest):
try:
result, infer_speed, infer_time = self._handle_request(request)
t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
except Exception as e:
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
return t2s_result
@staticmethod
def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "Flash Attn CUDAGraph"):
logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
module_path = f".backends.{backend.lower().replace('-', '_')}"
decoder_cls_name = "T2SDecoder"
decoder_mod = import_module(module_path, package=__package__)
decoder_cls: type[T2SDecoderABC] = getattr(decoder_mod, decoder_cls_name)
dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
config = dict_s1["config"]
decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
state_dict = dict_s1["weight"]
decoder.load_state_dict(state_dict)
return decoder.eval()
def init_cache(self):
assert self.decoder_model
module_name = self.decoder_model.__class__.__module__
module = sys.modules.get(module_name)
assert module
target_class: type[CUDAGraphCacheABC] = getattr(module, "CUDAGraphCache")
return target_class(self.decoder_model)

View File

@ -0,0 +1,670 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""
from __future__ import annotations
import math
import os
import random
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import MutableSequence
import torch
import torch._inductor.config
import torch.nn.functional as F
from torch.cuda.graphs import CUDAGraph
from torch.profiler import ProfilerAction, tensorboard_trace_handler
from . import nn
from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession
Tensor = torch.Tensor
class TokenEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
vocab_size: int,
):
super().__init__()
self.vocab_size = vocab_size
self.embedding_dim = embedding_dim
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
@property
def weight(self) -> Tensor:
return self.word_embeddings.weight
def embedding(self, index: int) -> Tensor:
return self.word_embeddings.weight[index : index + 1]
def __call__(self, x: Tensor):
x = self.word_embeddings(x)
return x
class SinePositionalEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
scale: bool = False,
alpha: bool = False,
max_batch_size: int = 10,
max_seq_len: int = 1800,
):
super().__init__()
self.embedding_dim = embedding_dim
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.reverse = False
self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
self.pe: torch.Tensor
self.compute_pe()
def compute_pe(self):
"""Reset the positional encodings."""
if self.reverse:
position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
else:
position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
)
pe = self.pe
pe[:, :, 0::2] = torch.sin(position * div_term)
pe[:, :, 1::2] = torch.cos(position * div_term)
def __call__(self, input_pos: Tensor, x: Tensor) -> Tensor:
"""
Args:
input_pos (Tensor): [batch_size, ]
x (Tensor): [batch_size, 1, embed_dim]
Returns:
embedded_x (Tensor): [batch_size, 1, embed_dim]
"""
batch_size = x.shape[0]
pe_values = self.pe[torch.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
return x * self.x_scale + self.alpha * pe_values.unsqueeze(1) # (batch_size, 1, embed_dim)
def prefill(self, x: Tensor) -> Tensor:
"""
Args:
x (Tensor): [batch_size, seq_len, embed_dim]
Returns:
embedded_x (Tensor): [batch_size, seq_len, embed_dim]
"""
pe_values = self.pe[:, : x.shape[-2]]
return x * self.x_scale + self.alpha.item() * pe_values
class KVCacheABC(nn.Module, ABC, KVCacheProtocol):
def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None:
super().__init__()
self.n_head = n_heads
self.head_dim = head_dim
self.batch_size = batch_size
self.max_seq_length = max_seq_length
self.k_cache: Tensor
self.v_cache: Tensor
def empty(self):
self.k_cache.zero_()
self.v_cache.zero_()
@abstractmethod
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> tuple[Tensor, Tensor]: ...
@abstractmethod
def prefill_kv(self, k_val: Tensor, v_val: Tensor) -> None: ...
def sync_cache(self, kv_cache: KVCacheProtocol):
self.k_cache.copy_(kv_cache.k_cache)
self.v_cache.copy_(kv_cache.v_cache)
class KVCacheNHD(KVCacheABC):
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
super().__init__(batch_size, max_seq_length, n_heads, head_dim)
assert batch_size > 0
cache_shape = (batch_size, max_seq_length, n_heads, head_dim)
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
# input_pos: [B, ], k_val: [B, 1, H, D]
index = (
(input_pos - 1)
.unsqueeze(-1)
.unsqueeze(-1)
.unsqueeze(-1)
.expand(
-1,
-1,
self.n_head,
self.head_dim,
)
.to(torch.int64)
) # (bs, 1, num_head, head_dim)
k_out = self.k_cache
v_out = self.v_cache
k_out.scatter_(1, index, k_val)
v_out.scatter_(1, index, v_val)
return k_out, v_out
def empty(self):
self.k_cache.zero_()
self.v_cache.zero_()
def prefill_kv(self, k_val: Tensor, v_val: Tensor):
# input_pos: int, k_val: [B, S, H, D]
self.k_cache[:, : k_val.shape[1]] = k_val
self.v_cache[:, : v_val.shape[1]] = v_val
class KVCacheHND(KVCacheABC):
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
super().__init__(batch_size, max_seq_length, n_heads, head_dim)
cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
# input_pos: [B, ], k_val: [B, H, 1, D]
index = (
(input_pos - 1)
.unsqueeze(-1)
.unsqueeze(-1)
.unsqueeze(-1)
.expand(
-1,
self.n_head,
-1,
self.head_dim,
)
.to(torch.int64)
) # (bs, num_head, 1, head_dim)
k_out = self.k_cache
v_out = self.v_cache
k_out.scatter_(2, index, k_val)
v_out.scatter_(2, index, v_val)
return k_out, v_out
def empty(self):
self.k_cache.zero_()
self.v_cache.zero_()
def prefill_kv(self, k_val: Tensor, v_val: Tensor):
# input_pos: int, k_val: [B, S, H, D]
self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)
class KVCacheHNDVarlen(KVCacheABC):
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
super().__init__(batch_size, max_seq_length, n_heads, head_dim)
cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
self.cache_idx: Tensor
self.register_buffer("cache_idx", torch.arange(batch_size), persistent=False)
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
# input_pos: [B, ], k_val: [B, H, 1, D]
k_out = self.k_cache
v_out = self.v_cache
ip0 = input_pos - 1
k_out[self.cache_idx, :, ip0, None] = k_val
v_out[self.cache_idx, :, ip0, None] = v_val
return k_out, v_out
def empty(self):
self.k_cache.zero_()
self.v_cache.zero_()
def prefill_kv(self, k_val: Tensor, v_val: Tensor):
# input_pos: int, k_val: [B, S, H, D]
self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)
class AttentionABC(nn.Module, ABC):
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
super().__init__()
self.n_head = n_head
self.hidden_dim = hidden_dim
assert hidden_dim % n_head == 0
self.head_dim = hidden_dim // n_head
self.max_seq_length = max_seq_length
# key, query, value projections for all heads, but in a batch
self.in_proj: nn.Linear
self.out_proj: nn.Linear
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
keys_to_modify = [key for key in state_dict if "in_proj_" in key]
for key in keys_to_modify:
new_key = key.replace("in_proj_", "in_proj.") # in_proj_ -> in_proj.
state_dict[new_key] = state_dict.pop(key)
@abstractmethod
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds) -> Tensor: ...
def prefill(self, x: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor) -> Tensor:
bsz, seqlen, _ = x.shape
q, k, v = self.in_proj(x.unsqueeze(0)).chunk(3, dim=-1)
q, k, v = map(lambda x: x.contiguous().view(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
kv_cache.prefill_kv(k, v)
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)
output = self.out_proj(attn)
return output
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int) -> None:
super().__init__()
self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
def __call__(self, x: Tensor):
return self.linear2(F.relu(self.linear1(x), inplace=True))
class TransformerBlockABC(nn.Module, ABC):
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.max_seq_length = max_seq_length
self.attention: AttentionABC
self.feed_forward: FeedForward
self.attention_norm: nn.LayerNorm
self.ffn_norm: nn.LayerNorm
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
for key in list(state_dict.keys()):
new_key = (
key.replace("self_attn", "attention")
.replace("linear", "feed_forward.linear")
.replace("norm1", "attention_norm")
.replace("norm2", "ffn_norm")
)
state_dict[new_key] = state_dict.pop(key)
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds):
h = self.attention_norm(
x
+ self.attention(
x,
input_pos,
kv_cache,
*args,
**kwds,
)
)
out = self.ffn_norm(h + self.feed_forward(h))
return out
def prefill(
self,
x: Tensor,
kv_cache: KVCacheProtocol,
attn_mask: Tensor,
) -> Tensor:
h = self.attention_norm(
x
+ self.attention.prefill(
x,
kv_cache,
attn_mask,
)
)
out = self.ffn_norm(h + self.feed_forward(h))
return out
class TransformerDecoderABC(nn.Module, ABC):
def __init__(
self,
hidden_dim: int,
n_layer: int,
n_head: int,
ffn_dim: int,
vocab_size: int,
max_seq_length: int,
max_batch_size: int,
) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.n_head = n_head
assert hidden_dim % n_head == 0
self.head_dim = hidden_dim // n_head
self.vocab_size = vocab_size
self.n_layer = n_layer
self.layers: MutableSequence[TransformerBlockABC]
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
def __call__(self, input_pos: Tensor, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds):
for layer, kv_cache in zip(self.layers, kv_caches):
x = layer(x, input_pos, kv_cache, *args, **kwds)
return x
def prefill(self, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], attn_mask: Tensor):
for layer, kv_cache in zip(self.layers, kv_caches):
x = layer.prefill(x, kv_cache, attn_mask)
return x
class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
def __init__(
self,
config: dict,
max_seq_length: int = 1800,
max_batch_size: int = 10,
) -> None:
super().__init__()
hidden_dim: int = config["model"]["hidden_dim"]
embedding_dim: int = config["model"]["embedding_dim"]
n_head: int = config["model"]["head"]
n_layer: int = config["model"]["n_layer"]
vocab_size: int = config["model"]["vocab_size"]
phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
EOS: int = config["model"]["EOS"]
ffn_dim: int = hidden_dim * 4
self.n_layer = int(n_layer)
self.hidden_dim = int(hidden_dim)
self.n_head = int(n_head)
assert hidden_dim % n_head == 0
self.head_dim = int(hidden_dim // n_head)
self.embedding_dim = int(embedding_dim)
self.ffn_dim = int(ffn_dim)
self.vocab_size = int(vocab_size)
self.phoneme_vocab_size = int(phoneme_vocab_size)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
self.EOS = EOS
assert self.EOS == self.vocab_size - 1
self.bert_proj: nn.Linear
self.ar_predict_layer: nn.Linear
self.h: TransformerDecoderABC
self.kv_class: type[KVCacheABC]
self.GraphCache: CUDAGraphCacheABC | None
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim,
scale=False,
alpha=True,
max_batch_size=max_batch_size,
max_seq_len=max_seq_length,
)
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim,
scale=False,
alpha=True,
max_batch_size=max_batch_size,
max_seq_len=max_seq_length,
)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
model_keys = [key for key in state_dict if key.startswith("model.")]
for key in model_keys:
new_key = key[len("model.") :]
state_dict[new_key] = state_dict.pop(key)
def init_cache(self, bsz: int = 0) -> MutableSequence[KVCacheProtocol]:
bsz = bsz or self.h.max_batch_size
assert bsz <= self.h.max_batch_size
seq_lens = self.h.max_seq_length
dtype = self.bert_proj.bias.dtype
kvclass = self.kv_class
return nn.ModuleList(
[kvclass(bsz, seq_lens, self.n_head, self.head_dim) for _ in range(self.n_layer)],
).to(self.device, dtype) # type: ignore
def embed(
self,
x: list[torch.Tensor],
y: torch.Tensor,
bert_features: list[torch.Tensor],
):
x_len: list[int] = [i.shape[0] for i in x]
x_len_max = max(x_len)
xy_pos = torch.zeros((len(x), x_len_max + y.shape[1], self.embedding_dim)).to(bert_features[0].dtype)
bert_features = list(map(lambda x: x.transpose(0, 1), bert_features))
y_len = y.shape[1]
y_emb = self.ar_audio_embedding(y)
y_pos = self.ar_audio_position.prefill(y_emb)
for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
x_emb = self.ar_text_embedding(x_)
bert = self.bert_proj(bert_feature)
x_emb = x_emb + bert
x_pos = self.ar_text_position.prefill(x_emb.unsqueeze(0))
xy_pos[[bs], :len_] = x_pos
xy_pos[[bs], len_ : len_ + y_len] = y_pos
return xy_pos
def compile(self, *args, **kwds):
# Experimental features to reduce compilation times, will be on by default in future
torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True
torch._inductor.config.triton.cudagraph_trees = True
torch._inductor.config.triton.cudagraph_support_input_mutation = True
self.h.compile(fullgraph=True, mode="reduce-overhead")
def capture(
self, input_pos: Tensor, x: Tensor, x_dec: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds
) -> CUDAGraph:
assert torch.cuda.is_available()
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
graph = torch.cuda.CUDAGraph()
with torch.cuda.stream(s): # type: ignore
for _ in range(5):
self.h(input_pos, x, kv_caches, *args, **kwds)
torch.cuda.current_stream().wait_stream(s)
with torch.cuda.graph(graph):
x_dec.copy_(self.h(input_pos, x, kv_caches, *args, **kwds))
torch.cuda.synchronize()
return graph
@abstractmethod
def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
return list(), dict()
@abstractmethod
def post_forward(self, idx: int, session: T2SSession) -> None:
return
class CUDAGraphCacheABC(ABC):
def __init__(
self,
decoder: T2SDecoderABC,
enabled: bool = False,
) -> None:
if torch.cuda.is_available() and enabled:
self.device: torch.device = decoder.device
self.dtype = decoder.bert_proj.bias.dtype
self.assigned: bool = False
self.decoder: T2SDecoderABC = decoder
self.kv_cache: MutableSequence[KVCacheProtocol] = decoder.init_cache(decoder.max_batch_size)
self.xy_pos = torch.rand(size=(decoder.max_batch_size, 1, decoder.embedding_dim), device=self.device).to(
self.dtype
)
self.xy_dec = self.xy_pos.clone()
self.input_pos = torch.tensor([10] * decoder.max_batch_size, device=self.device).int()
self.graph: torch.cuda.CUDAGraph | None = None
self.stream: torch.cuda.Stream | None
self.id: int = random.randint(1, 2**32 - 1)
def assign_graph(self, session: T2SSession):
if self.graph is None:
args, kwds = self.decoder.pre_forward(session)
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
self.graph = graph
self.stream = torch.cuda.Stream() # type: ignore
if self.assigned is False:
self.get_cache_graph(session)
session.id = self.id
self.assigned = True
else:
self.capture_new_graph(session)
@abstractmethod
def release_graph(self, session: T2SSession): ...
@abstractmethod
def get_cache_graph(self, session: T2SSession):
pass
@abstractmethod
def capture_new_graph(self, session: T2SSession):
pass
class TorchProfiler:
def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
self.debug = debug
self.log_dir = log_dir
self.__profiler: torch.profiler.profile
if self.debug and not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
self.tensorboard_handler = tensorboard_trace_handler(self.log_dir)
def profiler_callback(self, prof: torch.profiler.profile):
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
self.tensorboard_handler(prof)
@staticmethod
def three_step_schedule(step: int) -> ProfilerAction:
if step == 0:
return ProfilerAction.NONE
elif step == 1:
return ProfilerAction.RECORD
elif step == 2:
return ProfilerAction.RECORD_AND_SAVE
else:
return ProfilerAction.NONE
def start(self):
if not self.debug:
return
assert self.__profiler is not None
self.__profiler.step()
def end(self):
if not self.debug:
return
assert self.__profiler is not None
self.__profiler.step()
def profiler(self):
if self.debug:
activities_list = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
activities_list.append(torch.profiler.ProfilerActivity.CUDA)
self.__profiler = torch.profiler.profile(
activities=activities_list,
record_shapes=True,
with_stack=True,
with_modules=True,
profile_memory=True,
schedule=self.three_step_schedule,
on_trace_ready=self.profiler_callback,
)
return self.__profiler
else:
return nullcontext()
def record(self, func_name: str):
if self.debug:
return torch.profiler.record_function(func_name)
else:
return nullcontext()

View File

@ -0,0 +1,12 @@
from . import MLX, PyTorch
from .logger import logger, tb
from .PyTorch import T2SEngineTorch, T2SRequest, T2SResult
backends = PyTorch.backends + MLX.backends
backends = [
b.replace("_", "-").title().replace("Mlx", "MLX").replace("Mps", "MPS").replace("Cuda", "CUDA") for b in backends
]
__all__ = ["T2SEngineTorch", "T2SRequest", "T2SResult", "backends", "MLX", "PyTorch", "logger", "tb"]

View File

@ -0,0 +1,36 @@
import sys
from loguru import logger
from rich.console import Console
from rich.traceback import Traceback, install
install()
def rich_format(record):
level = record["level"].name
color = {
"DEBUG": "green",
"INFO": "cyan",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "magenta",
}.get(level, "black")
return f"[bold {color}][{level}][/bold {color}] {record['message']}"
def tb(show_locals: bool = True):
exc_type, exc_value, exc_tb = sys.exc_info()
assert exc_type
assert exc_value
tb = Traceback.from_exception(exc_type, exc_value, exc_tb, show_locals=show_locals)
return tb
console = Console()
logger.remove()
logger.add(console.print, format=rich_format)
__all__ = ["logger", "console", "tb"]

View File

@ -1,266 +0,0 @@
## BigVGAN: A Universal Neural Vocoder with Large-Scale Training
#### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon
[[Paper]](https://arxiv.org/abs/2206.04658) - [[Code]](https://github.com/NVIDIA/BigVGAN) - [[Showcase]](https://bigvgan-demo.github.io/) - [[Project Page]](https://research.nvidia.com/labs/adlr/projects/bigvgan/) - [[Weights]](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a) - [[Demo]](https://huggingface.co/spaces/nvidia/BigVGAN)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bigvgan-a-universal-neural-vocoder-with-large/speech-synthesis-on-libritts)](https://paperswithcode.com/sota/speech-synthesis-on-libritts?p=bigvgan-a-universal-neural-vocoder-with-large)
<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
## News
- **Sep 2024 (v2.4):**
- We have updated the pretrained checkpoints trained for 5M steps. This is final release of the BigVGAN-v2 checkpoints.
- **Jul 2024 (v2.3):**
- General refactor and code improvements for improved readability.
- Fully fused CUDA kernel of anti-alised activation (upsampling + activation + downsampling) with inference speed benchmark.
- **Jul 2024 (v2.2):** The repository now includes an interactive local demo using gradio.
- **Jul 2024 (v2.1):** BigVGAN is now integrated with 🤗 Hugging Face Hub with easy access to inference using pretrained checkpoints. We also provide an interactive demo on Hugging Face Spaces.
- **Jul 2024 (v2):** We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
- Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.
- Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
- Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
- We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
## Installation
The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
```shell
conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
conda activate bigvgan
```
Clone the repository and install dependencies:
```shell
git clone https://github.com/NVIDIA/BigVGAN
cd BigVGAN
pip install -r requirements.txt
```
## Inference Quickstart using 🤗 Hugging Face Hub
Below example describes how you can use BigVGAN: load the pretrained BigVGAN generator from Hugging Face Hub, compute mel spectrogram from input waveform, and generate synthesized waveform using the mel spectrogram as the model's input.
```python
device = 'cuda'
import torch
import bigvgan
import librosa
from meldataset import get_mel_spectrogram
# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
# remove weight norm in the model and set to eval mode
model.remove_weight_norm()
model = model.eval().to(device)
# load wav file and compute mel spectrogram
wav_path = '/path/to/your/audio.wav'
wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]
# compute mel spectrogram from the ground truth audio
mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]
# generate waveform from mel
with torch.inference_mode():
wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time]
# you can convert the generated waveform to 16 bit linear PCM
wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16') # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
```
## Local gradio demo <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>
You can run a local gradio demo using below command:
```python
pip install -r demo/requirements.txt
python demo/app.py
```
## Training
Create symbolic link to the root of the dataset. The codebase uses filelist with the relative path from the dataset. Below are the example commands for LibriTTS dataset:
```shell
cd filelists/LibriTTS && \
ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
ln -s /path/to/your/LibriTTS/dev-other dev-other && \
ln -s /path/to/your/LibriTTS/test-clean test-clean && \
ln -s /path/to/your/LibriTTS/test-other test-other && \
cd ../..
```
Train BigVGAN model. Below is an example command for training BigVGAN-v2 using LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input:
```shell
python train.py \
--config configs/bigvgan_v2_24khz_100band_256x.json \
--input_wavs_dir filelists/LibriTTS \
--input_training_file filelists/LibriTTS/train-full.txt \
--input_validation_file filelists/LibriTTS/val-full.txt \
--list_input_unseen_wavs_dir filelists/LibriTTS filelists/LibriTTS \
--list_input_unseen_validation_file filelists/LibriTTS/dev-clean.txt filelists/LibriTTS/dev-other.txt \
--checkpoint_path exp/bigvgan_v2_24khz_100band_256x
```
## Synthesis
Synthesize from BigVGAN model. Below is an example command for generating audio from the model.
It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
```shell
python inference.py \
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
--input_wavs_dir /path/to/your/input_wav \
--output_dir /path/to/your/output_wav
```
`inference_e2e.py` supports synthesis directly from the mel spectrogram saved in `.npy` format, with shapes `[1, channel, frame]` or `[channel, frame]`.
It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
Make sure that the STFT hyperparameters for mel spectrogram are the same as the model, which are defined in `config.json` of the corresponding model.
```shell
python inference_e2e.py \
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
--input_mels_dir /path/to/your/input_mel \
--output_dir /path/to/your/output_wav
```
## Using Custom CUDA Kernel for Synthesis
You can apply the fast CUDA inference kernel by using a parameter `use_cuda_kernel` when instantiating BigVGAN:
```python
generator = BigVGAN(h, use_cuda_kernel=True)
```
You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
Please make sure that both are installed in your system and `nvcc` installed in your system matches the version your PyTorch build is using.
We recommend running `test_cuda_vs_torch_model.py` first to build and check the correctness of the CUDA kernel. See below example command and its output, where it returns `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
```python
python tests/test_cuda_vs_torch_model.py \
--checkpoint_file /path/to/your/bigvgan_generator.pt
```
```shell
loading plain Pytorch BigVGAN
...
loading CUDA kernel BigVGAN with auto-build
Detected CUDA files, patching ldflags
Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja..
Building extension module anti_alias_activation_cuda...
...
Loading extension module anti_alias_activation_cuda...
...
Loading '/path/to/your/bigvgan_generator.pt'
...
[Success] test CUDA fused vs. plain torch BigVGAN inference
> mean_difference=0.0007238413265440613
...
```
If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check if `nvcc` installed in your system is compatible with your PyTorch version.
## Pretrained Models
We provide the [pretrained models on Hugging Face Collections](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a).
One can download the checkpoints of the generator weight (named `bigvgan_generator.pt`) and its discriminator/optimizer states (named `bigvgan_discriminator_optimizer.pt`) within the listed model repositories.
| Model Name | Sampling Rate | Mel band | fmax | Upsampling Ratio | Params | Dataset | Steps | Fine-Tuned |
|:--------------------------------------------------------------------------------------------------------:|:-------------:|:--------:|:-----:|:----------------:|:------:|:--------------------------:|:-----:|:----------:|
| [bigvgan_v2_44khz_128band_512x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x) | 44 kHz | 128 | 22050 | 512 | 122M | Large-scale Compilation | 5M | No |
| [bigvgan_v2_44khz_128band_256x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x) | 44 kHz | 128 | 22050 | 256 | 112M | Large-scale Compilation | 5M | No |
| [bigvgan_v2_24khz_100band_256x](https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x) | 24 kHz | 100 | 12000 | 256 | 112M | Large-scale Compilation | 5M | No |
| [bigvgan_v2_22khz_80band_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x) | 22 kHz | 80 | 11025 | 256 | 112M | Large-scale Compilation | 5M | No |
| [bigvgan_v2_22khz_80band_fmax8k_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x) | 22 kHz | 80 | 8000 | 256 | 112M | Large-scale Compilation | 5M | No |
| [bigvgan_24khz_100band](https://huggingface.co/nvidia/bigvgan_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 112M | LibriTTS | 5M | No |
| [bigvgan_base_24khz_100band](https://huggingface.co/nvidia/bigvgan_base_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 14M | LibriTTS | 5M | No |
| [bigvgan_22khz_80band](https://huggingface.co/nvidia/bigvgan_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 112M | LibriTTS + VCTK + LJSpeech | 5M | No |
| [bigvgan_base_22khz_80band](https://huggingface.co/nvidia/bigvgan_base_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 14M | LibriTTS + VCTK + LJSpeech | 5M | No |
The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on LibriTTS dataset.
We also provide 22kHz BigVGAN models with band-limited setup (i.e., fmax=8000) for TTS applications.
Note that the checkpoints use `snakebeta` activation with log scale parameterization, which have the best overall quality.
You can fine-tune the models by:
1. downloading the checkpoints (both the generator weight and its discriminator/optimizer states)
2. resuming training using your audio dataset by specifying `--checkpoint_path` that includes the checkpoints when launching `train.py`
## Training Details of BigVGAN-v2
Comapred to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 used `batch_size=32` with a longer `segment_size=65536` and are trained using 8 A100 GPUs.
Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` as default to fit in a single A100 GPU for training. You can fine-tune the models adjusting `batch_size` depending on your GPUs.
When training BigVGAN-v2 from scratch with small batch size, it can potentially encounter the early divergence problem mentioned in the paper. In such case, we recommend lowering the `clip_grad_norm` value (e.g. `100`) for the early training iterations (e.g. 20000 steps) and increase the value to the default `500`.
## Evaluation Results of BigVGAN-v2
Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements of the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
| Model | Dataset | Steps | PESQ(↑) | M-STFT(↓) | MCD(↓) | Periodicity(↓) | V/UV F1(↑) |
|:----------:|:-----------------------:|:-----:|:---------:|:----------:|:----------:|:--------------:|:----------:|
| BigVGAN | LibriTTS | 1M | 4.027 | 0.7997 | 0.3745 | 0.1018 | 0.9598 |
| BigVGAN | LibriTTS | 5M | 4.256 | 0.7409 | 0.2988 | 0.0809 | 0.9698 |
| BigVGAN-v2 | Large-scale Compilation | 3M | 4.359 | 0.7134 | 0.3060 | 0.0621 | 0.9777 |
| BigVGAN-v2 | Large-scale Compilation | 5M | **4.362** | **0.7026** | **0.2903** | **0.0593** | **0.9793** |
## Speed Benchmark
Below are the speed and VRAM usage benchmark results of BigVGAN from `tests/test_cuda_vs_torch_model.py`, using `bigvgan_v2_24khz_100band_256x` as a reference model.
| GPU | num_mel_frame | use_cuda_kernel | Speed (kHz) | Real-time Factor | VRAM (GB) |
|:--------------------------:|:-------------:|:---------------:|:-----------:|:----------------:|:---------:|
| NVIDIA A100 | 256 | False | 1672.1 | 69.7x | 1.3 |
| | | True | 3916.5 | 163.2x | 1.3 |
| | 2048 | False | 1899.6 | 79.2x | 1.7 |
| | | True | 5330.1 | 222.1x | 1.7 |
| | 16384 | False | 1973.8 | 82.2x | 5.0 |
| | | True | 5761.7 | 240.1x | 4.4 |
| NVIDIA GeForce RTX 3080 | 256 | False | 841.1 | 35.0x | 1.3 |
| | | True | 1598.1 | 66.6x | 1.3 |
| | 2048 | False | 929.9 | 38.7x | 1.7 |
| | | True | 1971.3 | 82.1x | 1.6 |
| | 16384 | False | 943.4 | 39.3x | 5.0 |
| | | True | 2026.5 | 84.4x | 3.9 |
| NVIDIA GeForce RTX 2080 Ti | 256 | False | 515.6 | 21.5x | 1.3 |
| | | True | 811.3 | 33.8x | 1.3 |
| | 2048 | False | 576.5 | 24.0x | 1.7 |
| | | True | 1023.0 | 42.6x | 1.5 |
| | 16384 | False | 589.4 | 24.6x | 5.0 |
| | | True | 1068.1 | 44.5x | 3.2 |
## Acknowledgements
We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
## References
- [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
- [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
- [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
- [Julius](https://github.com/adefossez/julius) (for low-pass filter)
- [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
- [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
- [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)

View File

@ -1,122 +0,0 @@
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
from torch import nn, sin, pow
from torch.nn import Parameter
class Snake(nn.Module):
"""
Implementation of a sine-based periodic activation function
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter
References:
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snake(256)
>>> x = torch.randn(256)
>>> x = a1(x)
"""
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
"""
Initialization.
INPUT:
- in_features: shape of the input
- alpha: trainable parameter
alpha is initialized to 1 by default, higher values = higher-frequency.
alpha will be trained along with the rest of your model.
"""
super(Snake, self).__init__()
self.in_features = in_features
# Initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # Log scale alphas initialized to zeros
self.alpha = Parameter(torch.zeros(in_features) * alpha)
else: # Linear scale alphas initialized to ones
self.alpha = Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
"""
Forward pass of the function.
Applies the function to the input elementwise.
Snake = x + 1/a * sin^2 (xa)
"""
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
if self.alpha_logscale:
alpha = torch.exp(alpha)
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x
class SnakeBeta(nn.Module):
"""
A modified Snake function which uses separate parameters for the magnitude of the periodic components
Shape:
- Input: (B, C, T)
- Output: (B, C, T), same shape as the input
Parameters:
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
References:
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
https://arxiv.org/abs/2006.08195
Examples:
>>> a1 = snakebeta(256)
>>> x = torch.randn(256)
>>> x = a1(x)
"""
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
"""
Initialization.
INPUT:
- in_features: shape of the input
- alpha - trainable parameter that controls frequency
- beta - trainable parameter that controls magnitude
alpha is initialized to 1 by default, higher values = higher-frequency.
beta is initialized to 1 by default, higher values = higher-magnitude.
alpha will be trained along with the rest of your model.
"""
super(SnakeBeta, self).__init__()
self.in_features = in_features
# Initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # Log scale alphas initialized to zeros
self.alpha = Parameter(torch.zeros(in_features) * alpha)
self.beta = Parameter(torch.zeros(in_features) * alpha)
else: # Linear scale alphas initialized to ones
self.alpha = Parameter(torch.ones(in_features) * alpha)
self.beta = Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
"""
Forward pass of the function.
Applies the function to the input elementwise.
SnakeBeta = x + 1/b * sin^2 (xa)
"""
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
return x

View File

@ -1,45 +0,0 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 32,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [4,4,2,2,2,2],
"upsample_kernel_sizes": [8,8,4,4,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"activation": "snakebeta",
"snake_logscale": true,
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"segment_size": 8192,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 22050,
"fmin": 0,
"fmax": 8000,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@ -1,45 +0,0 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 32,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [4,4,2,2,2,2],
"upsample_kernel_sizes": [8,8,4,4,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"activation": "snakebeta",
"snake_logscale": true,
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"segment_size": 8192,
"num_mels": 100,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 24000,
"fmin": 0,
"fmax": 12000,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@ -1,45 +0,0 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 32,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [8,8,2,2],
"upsample_kernel_sizes": [16,16,4,4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"activation": "snakebeta",
"snake_logscale": true,
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"segment_size": 8192,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 22050,
"fmin": 0,
"fmax": 8000,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@ -1,45 +0,0 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 32,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [8,8,2,2],
"upsample_kernel_sizes": [16,16,4,4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"activation": "snakebeta",
"snake_logscale": true,
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"segment_size": 8192,
"num_mels": 100,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 24000,
"fmin": 0,
"fmax": 12000,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@ -1,61 +0,0 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 4,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [4,4,2,2,2,2],
"upsample_kernel_sizes": [8,8,4,4,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"use_tanh_at_final": false,
"use_bias_at_final": false,
"activation": "snakebeta",
"snake_logscale": true,
"use_cqtd_instead_of_mrd": true,
"cqtd_filters": 128,
"cqtd_max_filters": 1024,
"cqtd_filters_scale": 1,
"cqtd_dilations": [1, 2, 4],
"cqtd_hop_lengths": [512, 256, 256],
"cqtd_n_octaves": [9, 9, 9],
"cqtd_bins_per_octaves": [24, 36, 48],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"use_multiscale_melloss": true,
"lambda_melloss": 15,
"clip_grad_norm": 500,
"segment_size": 65536,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 22050,
"fmin": 0,
"fmax": null,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@ -1,61 +0,0 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 4,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [4,4,2,2,2,2],
"upsample_kernel_sizes": [8,8,4,4,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"use_tanh_at_final": false,
"use_bias_at_final": false,
"activation": "snakebeta",
"snake_logscale": true,
"use_cqtd_instead_of_mrd": true,
"cqtd_filters": 128,
"cqtd_max_filters": 1024,
"cqtd_filters_scale": 1,
"cqtd_dilations": [1, 2, 4],
"cqtd_hop_lengths": [512, 256, 256],
"cqtd_n_octaves": [9, 9, 9],
"cqtd_bins_per_octaves": [24, 36, 48],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"use_multiscale_melloss": true,
"lambda_melloss": 15,
"clip_grad_norm": 500,
"segment_size": 65536,
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 22050,
"fmin": 0,
"fmax": 8000,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@ -1,61 +0,0 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 4,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [4,4,2,2,2,2],
"upsample_kernel_sizes": [8,8,4,4,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"use_tanh_at_final": false,
"use_bias_at_final": false,
"activation": "snakebeta",
"snake_logscale": true,
"use_cqtd_instead_of_mrd": true,
"cqtd_filters": 128,
"cqtd_max_filters": 1024,
"cqtd_filters_scale": 1,
"cqtd_dilations": [1, 2, 4],
"cqtd_hop_lengths": [512, 256, 256],
"cqtd_n_octaves": [9, 9, 9],
"cqtd_bins_per_octaves": [24, 36, 48],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"use_multiscale_melloss": true,
"lambda_melloss": 15,
"clip_grad_norm": 500,
"segment_size": 65536,
"num_mels": 128,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 44100,
"fmin": 0,
"fmax": null,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@ -1,61 +0,0 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 4,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.9999996,
"seed": 1234,
"upsample_rates": [8,4,2,2,2,2],
"upsample_kernel_sizes": [16,8,4,4,4,4],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"use_tanh_at_final": false,
"use_bias_at_final": false,
"activation": "snakebeta",
"snake_logscale": true,
"use_cqtd_instead_of_mrd": true,
"cqtd_filters": 128,
"cqtd_max_filters": 1024,
"cqtd_filters_scale": 1,
"cqtd_dilations": [1, 2, 4],
"cqtd_hop_lengths": [512, 256, 256],
"cqtd_n_octaves": [9, 9, 9],
"cqtd_bins_per_octaves": [24, 36, 48],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"use_multiscale_melloss": true,
"lambda_melloss": 15,
"clip_grad_norm": 500,
"segment_size": 65536,
"num_mels": 128,
"num_freq": 2049,
"n_fft": 2048,
"hop_size": 512,
"win_size": 2048,
"sampling_rate": 44100,
"fmin": 0,
"fmax": null,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}

View File

@ -1,625 +0,0 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv2d
from torch.nn.utils import weight_norm, spectral_norm
from torchaudio.transforms import Spectrogram, Resample
from env import AttrDict
from utils import get_padding
import typing
from typing import List, Tuple
class DiscriminatorP(torch.nn.Module):
def __init__(
self,
h: AttrDict,
period: List[int],
kernel_size: int = 5,
stride: int = 3,
use_spectral_norm: bool = False,
):
super().__init__()
self.period = period
self.d_mult = h.discriminator_channel_mult
norm_f = weight_norm if not use_spectral_norm else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(
Conv2d(
1,
int(32 * self.d_mult),
(kernel_size, 1),
(stride, 1),
padding=(get_padding(5, 1), 0),
)
),
norm_f(
Conv2d(
int(32 * self.d_mult),
int(128 * self.d_mult),
(kernel_size, 1),
(stride, 1),
padding=(get_padding(5, 1), 0),
)
),
norm_f(
Conv2d(
int(128 * self.d_mult),
int(512 * self.d_mult),
(kernel_size, 1),
(stride, 1),
padding=(get_padding(5, 1), 0),
)
),
norm_f(
Conv2d(
int(512 * self.d_mult),
int(1024 * self.d_mult),
(kernel_size, 1),
(stride, 1),
padding=(get_padding(5, 1), 0),
)
),
norm_f(
Conv2d(
int(1024 * self.d_mult),
int(1024 * self.d_mult),
(kernel_size, 1),
1,
padding=(2, 0),
)
),
]
)
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, 0.1)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, h: AttrDict):
super().__init__()
self.mpd_reshapes = h.mpd_reshapes
print(f"mpd_reshapes: {self.mpd_reshapes}")
self.discriminators = nn.ModuleList(
[DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
)
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorR(nn.Module):
def __init__(self, cfg: AttrDict, resolution: List[List[int]]):
super().__init__()
self.resolution = resolution
assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
self.lrelu_slope = 0.1
norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
if hasattr(cfg, "mrd_use_spectral_norm"):
print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
self.d_mult = cfg.discriminator_channel_mult
if hasattr(cfg, "mrd_channel_mult"):
print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
self.d_mult = cfg.mrd_channel_mult
self.convs = nn.ModuleList(
[
norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
norm_f(
nn.Conv2d(
int(32 * self.d_mult),
int(32 * self.d_mult),
(3, 9),
stride=(1, 2),
padding=(1, 4),
)
),
norm_f(
nn.Conv2d(
int(32 * self.d_mult),
int(32 * self.d_mult),
(3, 9),
stride=(1, 2),
padding=(1, 4),
)
),
norm_f(
nn.Conv2d(
int(32 * self.d_mult),
int(32 * self.d_mult),
(3, 9),
stride=(1, 2),
padding=(1, 4),
)
),
norm_f(
nn.Conv2d(
int(32 * self.d_mult),
int(32 * self.d_mult),
(3, 3),
padding=(1, 1),
)
),
]
)
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
x = self.spectrogram(x)
x = x.unsqueeze(1)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, self.lrelu_slope)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
n_fft, hop_length, win_length = self.resolution
x = F.pad(
x,
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
mode="reflect",
)
x = x.squeeze(1)
x = torch.stft(
x,
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
center=False,
return_complex=True,
)
x = torch.view_as_real(x) # [B, F, TT, 2]
mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
return mag
class MultiResolutionDiscriminator(nn.Module):
def __init__(self, cfg, debug=False):
super().__init__()
self.resolutions = cfg.resolutions
assert len(self.resolutions) == 3, (
f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
)
self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(x=y)
y_d_g, fmap_g = d(x=y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
# LICENSE is in incl_licenses directory.
class DiscriminatorB(nn.Module):
def __init__(
self,
window_length: int,
channels: int = 32,
hop_factor: float = 0.25,
bands: Tuple[Tuple[float, float], ...] = (
(0.0, 0.1),
(0.1, 0.25),
(0.25, 0.5),
(0.5, 0.75),
(0.75, 1.0),
),
):
super().__init__()
self.window_length = window_length
self.hop_factor = hop_factor
self.spec_fn = Spectrogram(
n_fft=window_length,
hop_length=int(window_length * hop_factor),
win_length=window_length,
power=None,
)
n_fft = window_length // 2 + 1
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
self.bands = bands
convs = lambda: nn.ModuleList(
[
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
]
)
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
# Remove DC offset
x = x - x.mean(dim=-1, keepdims=True)
# Peak normalize the volume of input audio
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
x = self.spec_fn(x)
x = torch.view_as_real(x)
x = x.permute(0, 3, 2, 1) # [B, F, T, C] -> [B, C, T, F]
# Split into bands
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
return x_bands
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
x_bands = self.spectrogram(x.squeeze(1))
fmap = []
x = []
for band, stack in zip(x_bands, self.band_convs):
for i, layer in enumerate(stack):
band = layer(band)
band = torch.nn.functional.leaky_relu(band, 0.1)
if i > 0:
fmap.append(band)
x.append(band)
x = torch.cat(x, dim=-1)
x = self.conv_post(x)
fmap.append(x)
return x, fmap
# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
# LICENSE is in incl_licenses directory.
class MultiBandDiscriminator(nn.Module):
def __init__(
self,
h,
):
"""
Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
and the modified code adapted from https://github.com/gemelo-ai/vocos.
"""
super().__init__()
# fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(x=y)
y_d_g, fmap_g = d(x=y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
# LICENSE is in incl_licenses directory.
class DiscriminatorCQT(nn.Module):
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int):
super().__init__()
self.cfg = cfg
self.filters = cfg["cqtd_filters"]
self.max_filters = cfg["cqtd_max_filters"]
self.filters_scale = cfg["cqtd_filters_scale"]
self.kernel_size = (3, 9)
self.dilations = cfg["cqtd_dilations"]
self.stride = (1, 2)
self.in_channels = cfg["cqtd_in_channels"]
self.out_channels = cfg["cqtd_out_channels"]
self.fs = cfg["sampling_rate"]
self.hop_length = hop_length
self.n_octaves = n_octaves
self.bins_per_octave = bins_per_octave
# Lazy-load
from nnAudio import features
self.cqt_transform = features.cqt.CQT2010v2(
sr=self.fs * 2,
hop_length=self.hop_length,
n_bins=self.bins_per_octave * self.n_octaves,
bins_per_octave=self.bins_per_octave,
output_format="Complex",
pad_mode="constant",
)
self.conv_pres = nn.ModuleList()
for _ in range(self.n_octaves):
self.conv_pres.append(
nn.Conv2d(
self.in_channels * 2,
self.in_channels * 2,
kernel_size=self.kernel_size,
padding=self.get_2d_padding(self.kernel_size),
)
)
self.convs = nn.ModuleList()
self.convs.append(
nn.Conv2d(
self.in_channels * 2,
self.filters,
kernel_size=self.kernel_size,
padding=self.get_2d_padding(self.kernel_size),
)
)
in_chs = min(self.filters_scale * self.filters, self.max_filters)
for i, dilation in enumerate(self.dilations):
out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
self.convs.append(
weight_norm(
nn.Conv2d(
in_chs,
out_chs,
kernel_size=self.kernel_size,
stride=self.stride,
dilation=(dilation, 1),
padding=self.get_2d_padding(self.kernel_size, (dilation, 1)),
)
)
)
in_chs = out_chs
out_chs = min(
(self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
self.max_filters,
)
self.convs.append(
weight_norm(
nn.Conv2d(
in_chs,
out_chs,
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
)
)
)
self.conv_post = weight_norm(
nn.Conv2d(
out_chs,
self.out_channels,
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
)
)
self.activation = torch.nn.LeakyReLU(negative_slope=0.1)
self.resample = Resample(orig_freq=self.fs, new_freq=self.fs * 2)
self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
if self.cqtd_normalize_volume:
print(
"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
)
def get_2d_padding(
self,
kernel_size: typing.Tuple[int, int],
dilation: typing.Tuple[int, int] = (1, 1),
):
return (
((kernel_size[0] - 1) * dilation[0]) // 2,
((kernel_size[1] - 1) * dilation[1]) // 2,
)
def forward(self, x: torch.tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
if self.cqtd_normalize_volume:
# Remove DC offset
x = x - x.mean(dim=-1, keepdims=True)
# Peak normalize the volume of input audio
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
x = self.resample(x)
z = self.cqt_transform(x)
z_amplitude = z[:, :, :, 0].unsqueeze(1)
z_phase = z[:, :, :, 1].unsqueeze(1)
z = torch.cat([z_amplitude, z_phase], dim=1)
z = torch.permute(z, (0, 1, 3, 2)) # [B, C, W, T] -> [B, C, T, W]
latent_z = []
for i in range(self.n_octaves):
latent_z.append(
self.conv_pres[i](
z[
:,
:,
:,
i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
]
)
)
latent_z = torch.cat(latent_z, dim=-1)
for i, l in enumerate(self.convs):
latent_z = l(latent_z)
latent_z = self.activation(latent_z)
fmap.append(latent_z)
latent_z = self.conv_post(latent_z)
return latent_z, fmap
class MultiScaleSubbandCQTDiscriminator(nn.Module):
def __init__(self, cfg: AttrDict):
super().__init__()
self.cfg = cfg
# Using get with defaults
self.cfg["cqtd_filters"] = self.cfg.get("cqtd_filters", 32)
self.cfg["cqtd_max_filters"] = self.cfg.get("cqtd_max_filters", 1024)
self.cfg["cqtd_filters_scale"] = self.cfg.get("cqtd_filters_scale", 1)
self.cfg["cqtd_dilations"] = self.cfg.get("cqtd_dilations", [1, 2, 4])
self.cfg["cqtd_in_channels"] = self.cfg.get("cqtd_in_channels", 1)
self.cfg["cqtd_out_channels"] = self.cfg.get("cqtd_out_channels", 1)
# Multi-scale params to loop over
self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
self.discriminators = nn.ModuleList(
[
DiscriminatorCQT(
self.cfg,
hop_length=self.cfg["cqtd_hop_lengths"][i],
n_octaves=self.cfg["cqtd_n_octaves"][i],
bins_per_octave=self.cfg["cqtd_bins_per_octaves"][i],
)
for i in range(len(self.cfg["cqtd_hop_lengths"]))
]
)
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for disc in self.discriminators:
y_d_r, fmap_r = disc(y)
y_d_g, fmap_g = disc(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class CombinedDiscriminator(nn.Module):
"""
Wrapper of chaining multiple discrimiantor architectures.
Example: combine mbd and cqtd as a single class
"""
def __init__(self, list_discriminator: List[nn.Module]):
super().__init__()
self.discrimiantor = nn.ModuleList(list_discriminator)
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for disc in self.discrimiantor:
y_d_r, y_d_g, fmap_r, fmap_g = disc(y, y_hat)
y_d_rs.extend(y_d_r)
fmap_rs.extend(fmap_r)
y_d_gs.extend(y_d_g)
fmap_gs.extend(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs

View File

@ -1,85 +0,0 @@
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import argparse
import json
import torch
import librosa
from utils import load_checkpoint
from meldataset import get_mel_spectrogram
from scipy.io.wavfile import write
from env import AttrDict
from meldataset import MAX_WAV_VALUE
from bigvgan import BigVGAN as Generator
h = None
device = None
torch.backends.cudnn.benchmark = False
def inference(a, h):
generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device)
state_dict_g = load_checkpoint(a.checkpoint_file, device)
generator.load_state_dict(state_dict_g["generator"])
filelist = os.listdir(a.input_wavs_dir)
os.makedirs(a.output_dir, exist_ok=True)
generator.eval()
generator.remove_weight_norm()
with torch.no_grad():
for i, filname in enumerate(filelist):
# Load the ground truth audio and resample if necessary
wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True)
wav = torch.FloatTensor(wav).to(device)
# Compute mel spectrogram from the ground truth audio
x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
y_g_hat = generator(x)
audio = y_g_hat.squeeze()
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype("int16")
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav")
write(output_file, h.sampling_rate, audio)
print(output_file)
def main():
print("Initializing Inference Process..")
parser = argparse.ArgumentParser()
parser.add_argument("--input_wavs_dir", default="test_files")
parser.add_argument("--output_dir", default="generated_files")
parser.add_argument("--checkpoint_file", required=True)
parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
a = parser.parse_args()
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
with open(config_file) as f:
data = f.read()
global h
json_config = json.loads(data)
h = AttrDict(json_config)
torch.manual_seed(h.seed)
global device
if torch.cuda.is_available():
torch.cuda.manual_seed(h.seed)
device = torch.device("cuda")
else:
device = torch.device("cpu")
inference(a, h)
if __name__ == "__main__":
main()

View File

@ -1,100 +0,0 @@
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
from __future__ import absolute_import, division, print_function, unicode_literals
import glob
import os
import numpy as np
import argparse
import json
import torch
from scipy.io.wavfile import write
from env import AttrDict
from meldataset import MAX_WAV_VALUE
from bigvgan import BigVGAN as Generator
h = None
device = None
torch.backends.cudnn.benchmark = False
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
print(f"Loading '{filepath}'")
checkpoint_dict = torch.load(filepath, map_location=device)
print("Complete.")
return checkpoint_dict
def scan_checkpoint(cp_dir, prefix):
pattern = os.path.join(cp_dir, prefix + "*")
cp_list = glob.glob(pattern)
if len(cp_list) == 0:
return ""
return sorted(cp_list)[-1]
def inference(a, h):
generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device)
state_dict_g = load_checkpoint(a.checkpoint_file, device)
generator.load_state_dict(state_dict_g["generator"])
filelist = os.listdir(a.input_mels_dir)
os.makedirs(a.output_dir, exist_ok=True)
generator.eval()
generator.remove_weight_norm()
with torch.no_grad():
for i, filname in enumerate(filelist):
# Load the mel spectrogram in .npy format
x = np.load(os.path.join(a.input_mels_dir, filname))
x = torch.FloatTensor(x).to(device)
if len(x.shape) == 2:
x = x.unsqueeze(0)
y_g_hat = generator(x)
audio = y_g_hat.squeeze()
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype("int16")
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav")
write(output_file, h.sampling_rate, audio)
print(output_file)
def main():
print("Initializing Inference Process..")
parser = argparse.ArgumentParser()
parser.add_argument("--input_mels_dir", default="test_mel_files")
parser.add_argument("--output_dir", default="generated_files_from_mel")
parser.add_argument("--checkpoint_file", required=True)
parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
a = parser.parse_args()
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
with open(config_file) as f:
data = f.read()
global h
json_config = json.loads(data)
h = AttrDict(json_config)
torch.manual_seed(h.seed)
global device
if torch.cuda.is_available():
torch.cuda.manual_seed(h.seed)
device = torch.device("cuda")
else:
device = torch.device("cpu")
inference(a, h)
if __name__ == "__main__":
main()

View File

@ -1,238 +0,0 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn as nn
from librosa.filters import mel as librosa_mel_fn
from scipy import signal
import typing
from typing import List, Tuple
from collections import namedtuple
import math
import functools
# Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license.
# LICENSE is in incl_licenses directory.
class MultiScaleMelSpectrogramLoss(nn.Module):
"""Compute distance between mel spectrograms. Can be used
in a multi-scale way.
Parameters
----------
n_mels : List[int]
Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320],
window_lengths : List[int], optional
Length of each window of each STFT, by default [32, 64, 128, 256, 512, 1024, 2048]
loss_fn : typing.Callable, optional
How to compare each loss, by default nn.L1Loss()
clamp_eps : float, optional
Clamp on the log magnitude, below, by default 1e-5
mag_weight : float, optional
Weight of raw magnitude portion of loss, by default 0.0 (no ampliciation on mag part)
log_weight : float, optional
Weight of log magnitude portion of loss, by default 1.0
pow : float, optional
Power to raise magnitude to before taking log, by default 1.0
weight : float, optional
Weight of this loss, by default 1.0
match_stride : bool, optional
Whether to match the stride of convolutional layers, by default False
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
"""
def __init__(
self,
sampling_rate: int,
n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
loss_fn: typing.Callable = nn.L1Loss(),
clamp_eps: float = 1e-5,
mag_weight: float = 0.0,
log_weight: float = 1.0,
pow: float = 1.0,
weight: float = 1.0,
match_stride: bool = False,
mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0],
mel_fmax: List[float] = [None, None, None, None, None, None, None],
window_type: str = "hann",
):
super().__init__()
self.sampling_rate = sampling_rate
STFTParams = namedtuple(
"STFTParams",
["window_length", "hop_length", "window_type", "match_stride"],
)
self.stft_params = [
STFTParams(
window_length=w,
hop_length=w // 4,
match_stride=match_stride,
window_type=window_type,
)
for w in window_lengths
]
self.n_mels = n_mels
self.loss_fn = loss_fn
self.clamp_eps = clamp_eps
self.log_weight = log_weight
self.mag_weight = mag_weight
self.weight = weight
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.pow = pow
@staticmethod
@functools.lru_cache(None)
def get_window(
window_type,
window_length,
):
return signal.get_window(window_type, window_length)
@staticmethod
@functools.lru_cache(None)
def get_mel_filters(sr, n_fft, n_mels, fmin, fmax):
return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
def mel_spectrogram(
self,
wav,
n_mels,
fmin,
fmax,
window_length,
hop_length,
match_stride,
window_type,
):
"""
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
"""
B, C, T = wav.shape
if match_stride:
assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
right_pad = math.ceil(T / hop_length) * hop_length - T
pad = (window_length - hop_length) // 2
else:
right_pad = 0
pad = 0
wav = torch.nn.functional.pad(wav, (pad, pad + right_pad), mode="reflect")
window = self.get_window(window_type, window_length)
window = torch.from_numpy(window).to(wav.device).float()
stft = torch.stft(
wav.reshape(-1, T),
n_fft=window_length,
hop_length=hop_length,
window=window,
return_complex=True,
center=True,
)
_, nf, nt = stft.shape
stft = stft.reshape(B, C, nf, nt)
if match_stride:
"""
Drop first two and last two frames, which are added, because of padding. Now num_frames * hop_length = num_samples.
"""
stft = stft[..., 2:-2]
magnitude = torch.abs(stft)
nf = magnitude.shape[2]
mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
mel_basis = torch.from_numpy(mel_basis).to(wav.device)
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
mel_spectrogram = mel_spectrogram.transpose(-1, 2)
return mel_spectrogram
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Computes mel loss between an estimate and a reference
signal.
Parameters
----------
x : torch.Tensor
Estimate signal
y : torch.Tensor
Reference signal
Returns
-------
torch.Tensor
Mel loss.
"""
loss = 0.0
for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
kwargs = {
"n_mels": n_mels,
"fmin": fmin,
"fmax": fmax,
"window_length": s.window_length,
"hop_length": s.hop_length,
"match_stride": s.match_stride,
"window_type": s.window_type,
}
x_mels = self.mel_spectrogram(x, **kwargs)
y_mels = self.mel_spectrogram(y, **kwargs)
x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
return loss
# Loss functions
def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
loss += torch.mean(torch.abs(rl - gl))
return loss * 2 # This equates to lambda=2.0 for the feature matching loss
def discriminator_loss(
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
loss = 0
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
r_loss = torch.mean((1 - dr) ** 2)
g_loss = torch.mean(dg**2)
loss += r_loss + g_loss
r_losses.append(r_loss.item())
g_losses.append(g_loss.item())
return loss, r_losses, g_losses
def generator_loss(
disc_outputs: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
loss = 0
gen_losses = []
for dg in disc_outputs:
l = torch.mean((1 - dg) ** 2)
gen_losses.append(l)
loss += l
return loss, gen_losses

View File

@ -1,370 +0,0 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import math
import os
import random
import torch
import torch.utils.data
import numpy as np
import librosa
from librosa.filters import mel as librosa_mel_fn
import pathlib
from tqdm import tqdm
from typing import List, Tuple, Optional
from .env import AttrDict
MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
def dynamic_range_compression(x, C=1, clip_val=1e-5):
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.exp(x) / C
def spectral_normalize_torch(magnitudes):
return dynamic_range_compression_torch(magnitudes)
def spectral_de_normalize_torch(magnitudes):
return dynamic_range_decompression_torch(magnitudes)
mel_basis_cache = {}
hann_window_cache = {}
def mel_spectrogram(
y: torch.Tensor,
n_fft: int,
num_mels: int,
sampling_rate: int,
hop_size: int,
win_size: int,
fmin: int,
fmax: int = None,
center: bool = False,
) -> torch.Tensor:
"""
Calculate the mel spectrogram of an input signal.
This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
Args:
y (torch.Tensor): Input signal.
n_fft (int): FFT size.
num_mels (int): Number of mel bins.
sampling_rate (int): Sampling rate of the input signal.
hop_size (int): Hop size for STFT.
win_size (int): Window size for STFT.
fmin (int): Minimum frequency for mel filterbank.
fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
center (bool): Whether to pad the input to center the frames. Default is False.
Returns:
torch.Tensor: Mel spectrogram.
"""
if torch.min(y) < -1.0:
print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
if torch.max(y) > 1.0:
print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
device = y.device
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
if key not in mel_basis_cache:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
hann_window_cache[key] = torch.hann_window(win_size).to(device)
mel_basis = mel_basis_cache[key]
hann_window = hann_window_cache[key]
padding = (n_fft - hop_size) // 2
y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window,
center=center,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
mel_spec = torch.matmul(mel_basis, spec)
mel_spec = spectral_normalize_torch(mel_spec)
return mel_spec
def get_mel_spectrogram(wav, h):
"""
Generate mel spectrogram from a waveform using given hyperparameters.
Args:
wav (torch.Tensor): Input waveform.
h: Hyperparameters object with attributes n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax.
Returns:
torch.Tensor: Mel spectrogram.
"""
return mel_spectrogram(
wav,
h.n_fft,
h.num_mels,
h.sampling_rate,
h.hop_size,
h.win_size,
h.fmin,
h.fmax,
)
def get_dataset_filelist(a):
training_files = []
validation_files = []
list_unseen_validation_files = []
with open(a.input_training_file, "r", encoding="utf-8") as fi:
training_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
]
print(f"first training file: {training_files[0]}")
with open(a.input_validation_file, "r", encoding="utf-8") as fi:
validation_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
]
print(f"first validation file: {validation_files[0]}")
for i in range(len(a.list_input_unseen_validation_file)):
with open(a.list_input_unseen_validation_file[i], "r", encoding="utf-8") as fi:
unseen_validation_files = [
os.path.join(a.list_input_unseen_wavs_dir[i], x.split("|")[0] + ".wav")
for x in fi.read().split("\n")
if len(x) > 0
]
print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}")
list_unseen_validation_files.append(unseen_validation_files)
return training_files, validation_files, list_unseen_validation_files
class MelDataset(torch.utils.data.Dataset):
def __init__(
self,
training_files: List[str],
hparams: AttrDict,
segment_size: int,
n_fft: int,
num_mels: int,
hop_size: int,
win_size: int,
sampling_rate: int,
fmin: int,
fmax: Optional[int],
split: bool = True,
shuffle: bool = True,
device: str = None,
fmax_loss: Optional[int] = None,
fine_tuning: bool = False,
base_mels_path: str = None,
is_seen: bool = True,
):
self.audio_files = training_files
random.seed(1234)
if shuffle:
random.shuffle(self.audio_files)
self.hparams = hparams
self.is_seen = is_seen
if self.is_seen:
self.name = pathlib.Path(self.audio_files[0]).parts[0]
else:
self.name = "-".join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/")
self.segment_size = segment_size
self.sampling_rate = sampling_rate
self.split = split
self.n_fft = n_fft
self.num_mels = num_mels
self.hop_size = hop_size
self.win_size = win_size
self.fmin = fmin
self.fmax = fmax
self.fmax_loss = fmax_loss
self.device = device
self.fine_tuning = fine_tuning
self.base_mels_path = base_mels_path
print("[INFO] checking dataset integrity...")
for i in tqdm(range(len(self.audio_files))):
assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found"
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
try:
filename = self.audio_files[index]
# Use librosa.load that ensures loading waveform into mono with [-1, 1] float values
# Audio is ndarray with shape [T_time]. Disable auto-resampling here to minimize overhead
# The on-the-fly resampling during training will be done only for the obtained random chunk
audio, source_sampling_rate = librosa.load(filename, sr=None, mono=True)
# Main logic that uses <mel, audio> pair for training BigVGAN
if not self.fine_tuning:
if self.split: # Training step
# Obtain randomized audio chunk
if source_sampling_rate != self.sampling_rate:
# Adjust segment size to crop if the source sr is different
target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate))
else:
target_segment_size = self.segment_size
# Compute upper bound index for the random chunk
random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size)
# Crop or pad audio to obtain random chunk with target_segment_size
if audio.shape[0] >= target_segment_size:
audio_start = random.randint(0, random_chunk_upper_bound)
audio = audio[audio_start : audio_start + target_segment_size]
else:
audio = np.pad(
audio,
(0, target_segment_size - audio.shape[0]),
mode="constant",
)
# Resample audio chunk to self.sampling rate
if source_sampling_rate != self.sampling_rate:
audio = librosa.resample(
audio,
orig_sr=source_sampling_rate,
target_sr=self.sampling_rate,
)
if audio.shape[0] > self.segment_size:
# trim last elements to match self.segment_size (e.g., 16385 for 44khz downsampled to 24khz -> 16384)
audio = audio[: self.segment_size]
else: # Validation step
# Resample full audio clip to target sampling rate
if source_sampling_rate != self.sampling_rate:
audio = librosa.resample(
audio,
orig_sr=source_sampling_rate,
target_sr=self.sampling_rate,
)
# Trim last elements to match audio length to self.hop_size * n for evaluation
if (audio.shape[0] % self.hop_size) != 0:
audio = audio[: -(audio.shape[0] % self.hop_size)]
# BigVGAN is trained using volume-normalized waveform
audio = librosa.util.normalize(audio) * 0.95
# Cast ndarray to torch tensor
audio = torch.FloatTensor(audio)
audio = audio.unsqueeze(0) # [B(1), self.segment_size]
# Compute mel spectrogram corresponding to audio
mel = mel_spectrogram(
audio,
self.n_fft,
self.num_mels,
self.sampling_rate,
self.hop_size,
self.win_size,
self.fmin,
self.fmax,
center=False,
) # [B(1), self.num_mels, self.segment_size // self.hop_size]
# Fine-tuning logic that uses pre-computed mel. Example: Using TTS model-generated mel as input
else:
# For fine-tuning, assert that the waveform is in the defined sampling_rate
# Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly)
assert source_sampling_rate == self.sampling_rate, (
f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
)
# Cast ndarray to torch tensor
audio = torch.FloatTensor(audio)
audio = audio.unsqueeze(0) # [B(1), T_time]
# Load pre-computed mel from disk
mel = np.load(
os.path.join(
self.base_mels_path,
os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
)
)
mel = torch.from_numpy(mel)
if len(mel.shape) < 3:
mel = mel.unsqueeze(0) # ensure [B, C, T]
if self.split:
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
if audio.size(1) >= self.segment_size:
mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
audio = audio[
:,
mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size,
]
# Pad pre-computed mel and audio to match length to ensuring fine-tuning without error.
# NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio>
# To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is the integer multiple of self.hop_size
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
# Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
mel_loss = mel_spectrogram(
audio,
self.n_fft,
self.num_mels,
self.sampling_rate,
self.hop_size,
self.win_size,
self.fmin,
self.fmax_loss,
center=False,
) # [B(1), self.num_mels, self.segment_size // self.hop_size]
# Shape sanity checks
assert (
audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size
), (
f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
)
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
# If it encounters error during loading the data, skip this sample and load random other sample to the batch
except Exception as e:
if self.fine_tuning:
raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
else:
print(f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}")
return self[random.randrange(len(self))]
def __len__(self):
return len(self.audio_files)

View File

@ -1 +0,0 @@

View File

@ -1,4 +0,0 @@
| Field | Response |
| :--------------------------------------------------------------------------------------------------------- | :--------------------------------------------------- |
| Participation considerations from adversely impacted groups protected classes in model design and testing: | None |
| Measures taken to mitigate against unwanted bias: | No measures taken to mitigate against unwanted bias. |

View File

@ -1,13 +0,0 @@
| Field | Response |
| :---------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Intended Application & Domain: | Generating waveform from mel spectrogram. |
| Model Type: | Convolutional Neural Network (CNN) |
| Intended Users: | This model is intended for developers to synthesize and generate waveforms from the AI-generated mel spectrograms. |
| Output: | Audio Waveform |
| Describe how the model works: | Model generates audio waveform corresponding to the input mel spectrogram. |
| Name the adversely impacted groups this has been tested to deliver comparable outcomes regardless of: | Not Applicable |
| Technical Limitations: | This may not perform well on synthetically-generated mel spectrograms that deviate significantly from the profile of mel spectrograms on which this was trained. |
| Verified to have met prescribed NVIDIA quality standards: | Yes |
| Performance Metrics: | Perceptual Evaluation of Speech Quality (PESQ), Virtual Speech Quality Objective Listener (VISQOL), Multi-resolution STFT (MRSTFT), Mel cepstral distortion (MCD), Periodicity RMSE, Voice/Unvoiced F1 Score (V/UV F1) |
| Potential Known Risks: | This model may generate low-quality or distorted soundwaves. |
| Licensing: | https://github.com/NVIDIA/BigVGAN/blob/main/LICENSE |

View File

@ -1,126 +0,0 @@
# Model Overview
## Description:
BigVGAN is a generative AI model specialized in synthesizing audio waveforms using Mel spectrogram as inputs.
<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
BigVGAN is a fully convolutional architecture with several upsampling blocks using transposed convolution followed by multiple residual dilated convolution layers.
BigVGAN consists of a novel module, called anti-aliased multi-periodicity composition (AMP), which is specifically designed for generating waveforms. AMP is specialized in synthesizing high-frequency and periodic soundwaves drawing inspiration from audio signal processing principles.
It applies a periodic activation function, called Snake, which provides an inductive bias to the architecture in generating periodic soundwaves. It also applies anti-aliasing filters to reduce undesired artifacts in the generated waveforms. <br>
This model is ready for commercial use.<br>
## References(s):
- [BigVGAN: A Universal Neural Vocoder with Large-Scale Training](https://arxiv.org/abs/2206.04658) <br>
- [Project Page](https://research.nvidia.com/labs/adlr/projects/bigvgan/) <br>
- [Audio Demo](https://bigvgan-demo.github.io/) <br>
## Model Architecture:
**Architecture Type:** Convolution Neural Network (CNN) <br>
**Network Architecture:** You can see the details of this model on this link: https://github.com/NVIDIA/BigVGAN and the related paper can be found here: https://arxiv.org/abs/2206.04658<br>
**Model Version:** 2.0 <br>
## Input:
**Input Type:** Audio <br>
**Input Format:** Mel Spectrogram <br>
**Input Parameters:** None <br>
**Other Properties Related to Input:** The input mel spectrogram has shape `[batch, channels, frames]`, where `channels` refers to the number of mel bands defined by the model and `frames` refers to the temporal length. The model supports arbitrary long `frames` that fits into the GPU memory.
## Output:
**Input Type:** Audio <br>
**Output Format:** Audio Waveform <br>
**Output Parameters:** None <br>
**Other Properties Related to Output:** The output audio waveform has shape `[batch, 1, time]`, where `1` refers to the mono audio channels and `time` refers to the temporal length. `time` is defined as a fixed integer multiple of input `frames`, which is an upsampling ratio of the model (`time = upsampling ratio * frames`). The output audio waveform consitutes float values with a range of `[-1, 1]`.
## Software Integration:
**Runtime Engine(s):** PyTorch
**Supported Hardware Microarchitecture Compatibility:** NVIDIA Ampere, NVIDIA Hopper, NVIDIA Lovelace, NVIDIA Turing, NVIDIA Volta <br>
## Preferred/Supported Operating System(s):
Linux
## Model Version(s):
v2.0
## Training, Testing, and Evaluation Datasets:
### Training Dataset:
The dataset contains diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
**Links:**
- [AAM: Artificial Audio Multitracks Dataset](https://zenodo.org/records/5794629)
- [AudioCaps](https://audiocaps.github.io/)
- [AudioSet](https://research.google.com/audioset/index.html)
- [common-accent](https://huggingface.co/datasets/DTU54DL/common-accent)
- [Crowd Sourced Emotional Multimodal Actors Dataset (CREMA-D)](https://ieeexplore.ieee.org/document/6849440)
- [DCASE2017 Challenge, Task 4: Large-scale weakly supervised sound event detection for smart cars](https://dcase.community/challenge2017/task-large-scale-sound-event-detection)
- [FSDnoisy18k](https://zenodo.org/records/2529934)
- [Free Universal Sound Separation Dataset](https://zenodo.org/records/3694384)
- [Greatest Hits dataset](https://andrewowens.com/vis/)
- [GTZAN](https://ieeexplore.ieee.org/document/1021072)
- [JL corpus](https://www.kaggle.com/datasets/tli725/jl-corpus)
- [Medley-solos-DB: a cross-collection dataset for musical instrument recognition](https://zenodo.org/records/3464194)
- [MUSAN: A Music, Speech, and Noise Corpus](https://www.openslr.org/17/)
- [MusicBench](https://huggingface.co/datasets/amaai-lab/MusicBench)
- [MusicCaps](https://www.kaggle.com/datasets/googleai/musiccaps)
- [MusicNet](https://www.kaggle.com/datasets/imsparsh/musicnet-dataset)
- [NSynth](https://magenta.tensorflow.org/datasets/nsynth)
- [OnAir-Music-Dataset](https://github.com/sevagh/OnAir-Music-Dataset)
- [Audio Piano Triads Dataset](https://zenodo.org/records/4740877)
- [Pitch Audio Dataset (Surge synthesizer)](https://zenodo.org/records/4677097)
- [SONYC Urban Sound Tagging (SONYC-UST): a multilabel dataset from an urban acoustic sensor network](https://zenodo.org/records/3966543)
- [VocalSound: A Dataset for Improving Human Vocal Sounds Recognition](https://arxiv.org/abs/2205.03433)
- [WavText5K](https://github.com/microsoft/WavText5K)
- [CSS10: A Collection of Single Speaker Speech Datasets for 10 Languages](https://github.com/Kyubyong/css10)
- [Hi-Fi Multi-Speaker English TTS Dataset (Hi-Fi TTS)](https://www.openslr.org/109/)
- [IIIT-H Indic Speech Databases](http://festvox.org/databases/iiit_voices/)
- [Libri-Light: A Benchmark for ASR with Limited or No Supervision](https://arxiv.org/abs/1912.07875)
- [LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech](https://www.openslr.org/60)
- [LibriTTS-R: A Restored Multi-Speaker Text-to-Speech Corpus](https://www.openslr.org/141/)
- [The SIWIS French Speech Synthesis Database](https://datashare.ed.ac.uk/handle/10283/2353)
- [Crowdsourced high-quality Colombian Spanish speech data set](https://openslr.org/72/)
- [TTS-Portuguese Corpus](https://github.com/Edresson/TTS-Portuguese-Corpus)
- [CSTR VCTK Corpus: English Multi-speaker Corpus for CSTR Voice Cloning Toolkit](https://datashare.ed.ac.uk/handle/10283/3443)
\*\* Data Collection Method by dataset <br>
- Human <br>
\*\* Labeling Method by dataset (for those with labels) <br>
- Hybrid: Automated, Human, Unknown <br>
### Evaluating Dataset:
Properties: The audio generation quality of BigVGAN is evaluated using `dev` splits of the [LibriTTS dataset](https://www.openslr.org/60/) and [Hi-Fi TTS dataset](https://www.openslr.org/109/). The datasets include speech in English language with equal balance of genders.
\*\* Data Collection Method by dataset <br>
- Human <br>
\*\* Labeling Method by dataset <br>
- Automated <br>
## Inference:
**Engine:** PyTorch <br>
**Test Hardware:** NVIDIA A100 GPU <br>
## Ethical Considerations:
NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. For more detailed information on ethical considerations for this model, please see the Model Card++ Explainability, Bias, Safety & Security, and Privacy Subcards. Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).

View File

@ -1,14 +0,0 @@
| Field | Response |
| :------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------- |
| Generatable or reverse engineerable personal information? | None |
| Protected class data used to create this model? | None |
| Was consent obtained for any personal data used? | Not Applicable (No Personal Data) |
| How often is dataset reviewed? | Before Release |
| Is a mechanism in place to honor data subject right of access or deletion of personal data? | Not Applicable |
| If personal collected for the development of the model, was it collected directly by NVIDIA? | Not Applicable |
| If personal collected for the development of the model by NVIDIA, do you maintain or have access to disclosures made to data subjects? | Not Applicable |
| If personal collected for the development of this AI model, was it minimized to only what was required? | Not Applicable |
| Is data in dataset traceable? | Yes |
| Is there provenance for all datasets used in training? | Yes |
| Does data labeling (annotation, metadata) comply with privacy laws? | Yes |
| Is data compliant with data subject requests for data correction or removal, if such a request was made? | No, not possible with externally-sourced data. |

View File

@ -1,6 +0,0 @@
| Field | Response |
| :---------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Model Application(s): | Synethic Audio Generation |
| Describe the life critical impact (if present). | Not Applicable |
| Use Case Restrictions: | None |
| Model and dataset restrictions: | The Principle of least privilege (PoLP) is applied limiting access for dataset generation and model development. Restrictions enforce dataset access during training, and dataset license constraints adhered to. |

View File

@ -1,13 +0,0 @@
torch
numpy
librosa>=0.8.1
scipy
tensorboard
soundfile
matplotlib
pesq
auraloss
tqdm
nnAudio
ninja
huggingface_hub>=0.23.4

View File

@ -1,62 +0,0 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
import os
import sys
# to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
import torch
from alias_free_activation.cuda import activation1d
from activations import Snake
def test_load_fused_kernels():
try:
print("[Success] load_fused_kernels")
except ImportError as e:
print("[Fail] load_fused_kernels")
raise e
def test_anti_alias_activation():
data = torch.rand((10, 10, 200), device="cuda")
# Check activations.Snake cuda vs. torch
fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
fused_activation_output = fused_anti_alias_activation(data)
torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
torch_activation_output = torch_anti_alias_activation(data)
test_result = (fused_activation_output - torch_activation_output).abs()
while test_result.dim() != 1:
test_result = test_result.mean(dim=-1)
diff = test_result.mean(dim=-1)
if diff <= 1e-3:
print(
f"\n[Success] test_fused_anti_alias_activation"
f"\n > mean_difference={diff}"
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}"
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
)
else:
print(
f"\n[Fail] test_fused_anti_alias_activation"
f"\n > mean_difference={diff}, "
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}, "
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
)
if __name__ == "__main__":
from alias_free_activation.cuda import load
load.load()
test_load_fused_kernels()
test_anti_alias_activation()

View File

@ -1,62 +0,0 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
import os
import sys
# to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
import torch
from alias_free_activation.cuda import activation1d
from activations import SnakeBeta
def test_load_fused_kernels():
try:
print("[Success] load_fused_kernels")
except ImportError as e:
print("[Fail] load_fused_kernels")
raise e
def test_anti_alias_activation():
data = torch.rand((10, 10, 200), device="cuda")
# Check activations, Snake CUDA vs. Torch
fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
fused_activation_output = fused_anti_alias_activation(data)
torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
torch_activation_output = torch_anti_alias_activation(data)
test_result = (fused_activation_output - torch_activation_output).abs()
while test_result.dim() != 1:
test_result = test_result.mean(dim=-1)
diff = test_result.mean(dim=-1)
if diff <= 1e-3:
print(
f"\n[Success] test_fused_anti_alias_activation"
f"\n > mean_difference={diff}"
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}"
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
)
else:
print(
f"\n[Fail] test_fused_anti_alias_activation"
f"\n > mean_difference={diff}, "
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}, "
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
)
if __name__ == "__main__":
from alias_free_activation.cuda import load
load.load()
test_load_fused_kernels()
test_anti_alias_activation()

View File

@ -1,215 +0,0 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
import os
import sys
# to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
import torch
import json
from env import AttrDict
from bigvgan import BigVGAN
from time import time
from tqdm import tqdm
from meldataset import mel_spectrogram, MAX_WAV_VALUE
from scipy.io.wavfile import write
import numpy as np
import argparse
torch.backends.cudnn.benchmark = True
# For easier debugging
torch.set_printoptions(linewidth=200, threshold=10_000)
def generate_soundwave(duration=5.0, sr=24000):
t = np.linspace(0, duration, int(sr * duration), False, dtype=np.float32)
modulation = np.sin(2 * np.pi * t / duration)
min_freq = 220
max_freq = 1760
frequencies = min_freq + (max_freq - min_freq) * (modulation + 1) / 2
soundwave = np.sin(2 * np.pi * frequencies * t)
soundwave = soundwave / np.max(np.abs(soundwave)) * 0.95
return soundwave, sr
def get_mel(x, h):
return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
print(f"Loading '{filepath}'")
checkpoint_dict = torch.load(filepath, map_location=device)
print("Complete.")
return checkpoint_dict
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
parser.add_argument(
"--checkpoint_file",
type=str,
required=True,
help="Path to the checkpoint file. Assumes config.json exists in the directory.",
)
args = parser.parse_args()
config_file = os.path.join(os.path.split(args.checkpoint_file)[0], "config.json")
with open(config_file) as f:
config = f.read()
json_config = json.loads(config)
h = AttrDict({**json_config})
print("loading plain Pytorch BigVGAN")
generator_original = BigVGAN(h).to("cuda")
print("loading CUDA kernel BigVGAN with auto-build")
generator_cuda_kernel = BigVGAN(h, use_cuda_kernel=True).to("cuda")
state_dict_g = load_checkpoint(args.checkpoint_file, "cuda")
generator_original.load_state_dict(state_dict_g["generator"])
generator_cuda_kernel.load_state_dict(state_dict_g["generator"])
generator_original.remove_weight_norm()
generator_original.eval()
generator_cuda_kernel.remove_weight_norm()
generator_cuda_kernel.eval()
# define number of samples and length of mel frame to benchmark
num_sample = 10
num_mel_frame = 16384
# CUDA kernel correctness check
diff = 0.0
for i in tqdm(range(num_sample)):
# Random mel
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
with torch.inference_mode():
audio_original = generator_original(data)
with torch.inference_mode():
audio_cuda_kernel = generator_cuda_kernel(data)
# Both outputs should be (almost) the same
test_result = (audio_original - audio_cuda_kernel).abs()
diff += test_result.mean(dim=-1).item()
diff /= num_sample
if diff <= 2e-3: # We can expect a small difference (~1e-3) which does not affect perceptual quality
print(
f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
f"\n > mean_difference={diff}"
f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}"
f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
)
else:
print(
f"\n[Fail] test CUDA fused vs. plain torch BigVGAN inference"
f"\n > mean_difference={diff}"
f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, "
f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
)
del data, audio_original, audio_cuda_kernel
# Variables for tracking total time and VRAM usage
toc_total_original = 0
toc_total_cuda_kernel = 0
vram_used_original_total = 0
vram_used_cuda_kernel_total = 0
audio_length_total = 0
# Measure Original inference in isolation
for i in tqdm(range(num_sample)):
torch.cuda.reset_peak_memory_stats(device="cuda")
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
torch.cuda.synchronize()
tic = time()
with torch.inference_mode():
audio_original = generator_original(data)
torch.cuda.synchronize()
toc = time() - tic
toc_total_original += toc
vram_used_original_total += torch.cuda.max_memory_allocated(device="cuda")
del data, audio_original
torch.cuda.empty_cache()
# Measure CUDA kernel inference in isolation
for i in tqdm(range(num_sample)):
torch.cuda.reset_peak_memory_stats(device="cuda")
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
torch.cuda.synchronize()
tic = time()
with torch.inference_mode():
audio_cuda_kernel = generator_cuda_kernel(data)
torch.cuda.synchronize()
toc = time() - tic
toc_total_cuda_kernel += toc
audio_length_total += audio_cuda_kernel.shape[-1]
vram_used_cuda_kernel_total += torch.cuda.max_memory_allocated(device="cuda")
del data, audio_cuda_kernel
torch.cuda.empty_cache()
# Calculate metrics
audio_second = audio_length_total / h.sampling_rate
khz_original = audio_length_total / toc_total_original / 1000
khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)
# Print results
print(
f"Original BigVGAN: took {toc_total_original:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_original:.1f}kHz, {audio_second / toc_total_original:.1f} faster than realtime, VRAM used {vram_used_original_gb:.1f} GB"
)
print(
f"CUDA kernel BigVGAN: took {toc_total_cuda_kernel:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_cuda_kernel:.1f}kHz, {audio_second / toc_total_cuda_kernel:.1f} faster than realtime, VRAM used {vram_used_cuda_kernel_gb:.1f} GB"
)
print(f"speedup of CUDA kernel: {khz_cuda_kernel / khz_original}")
print(f"VRAM saving of CUDA kernel: {vram_used_original_gb / vram_used_cuda_kernel_gb}")
# Use artificial sine waves for inference test
audio_real, sr = generate_soundwave(duration=5.0, sr=h.sampling_rate)
audio_real = torch.tensor(audio_real).to("cuda")
# Compute mel spectrogram from the ground truth audio
x = get_mel(audio_real.unsqueeze(0), h)
with torch.inference_mode():
y_g_hat_original = generator_original(x)
y_g_hat_cuda_kernel = generator_cuda_kernel(x)
audio_real = audio_real.squeeze()
audio_real = audio_real * MAX_WAV_VALUE
audio_real = audio_real.cpu().numpy().astype("int16")
audio_original = y_g_hat_original.squeeze()
audio_original = audio_original * MAX_WAV_VALUE
audio_original = audio_original.cpu().numpy().astype("int16")
audio_cuda_kernel = y_g_hat_cuda_kernel.squeeze()
audio_cuda_kernel = audio_cuda_kernel * MAX_WAV_VALUE
audio_cuda_kernel = audio_cuda_kernel.cpu().numpy().astype("int16")
os.makedirs("tmp", exist_ok=True)
output_file_real = os.path.join("tmp", "audio_real.wav")
output_file_original = os.path.join("tmp", "audio_generated_original.wav")
output_file_cuda_kernel = os.path.join("tmp", "audio_generated_cuda_kernel.wav")
write(output_file_real, h.sampling_rate, audio_real)
write(output_file_original, h.sampling_rate, audio_original)
write(output_file_cuda_kernel, h.sampling_rate, audio_cuda_kernel)
print("Example generated audios of original vs. fused CUDA kernel written to tmp!")
print("Done")

View File

@ -1,716 +0,0 @@
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
import itertools
import os
import time
import argparse
import json
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DistributedSampler, DataLoader
import torch.multiprocessing as mp
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
from env import AttrDict, build_env
from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist, MAX_WAV_VALUE
from bigvgan import BigVGAN
from discriminators import (
MultiPeriodDiscriminator,
MultiResolutionDiscriminator,
MultiBandDiscriminator,
MultiScaleSubbandCQTDiscriminator,
)
from loss import (
feature_loss,
generator_loss,
discriminator_loss,
MultiScaleMelSpectrogramLoss,
)
from utils import (
plot_spectrogram,
plot_spectrogram_clipped,
scan_checkpoint,
load_checkpoint,
save_checkpoint,
save_audio,
)
import torchaudio as ta
from pesq import pesq
from tqdm import tqdm
import auraloss
torch.backends.cudnn.benchmark = False
def train(rank, a, h):
if h.num_gpus > 1:
# initialize distributed
init_process_group(
backend=h.dist_config["dist_backend"],
init_method=h.dist_config["dist_url"],
world_size=h.dist_config["world_size"] * h.num_gpus,
rank=rank,
)
# Set seed and device
torch.cuda.manual_seed(h.seed)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank:d}")
# Define BigVGAN generator
generator = BigVGAN(h).to(device)
# Define discriminators. MPD is used by default
mpd = MultiPeriodDiscriminator(h).to(device)
# Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
# New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD
print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
# Variable name is kept as "mrd" for backward compatibility & minimal code change
mrd = MultiBandDiscriminator(h).to(device)
elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD
print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
else: # Fallback to original MRD in BigVGAN-v1
mrd = MultiResolutionDiscriminator(h).to(device)
# New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
if h.get("use_multiscale_melloss", False):
print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss")
fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
sampling_rate=h.sampling_rate
) # NOTE: accepts waveform as input
else:
fn_mel_loss_singlescale = F.l1_loss
# Print the model & number of parameters, and create or scan the latest checkpoint from checkpoints directory
if rank == 0:
print(generator)
print(mpd)
print(mrd)
print(f"Generator params: {sum(p.numel() for p in generator.parameters())}")
print(f"Discriminator mpd params: {sum(p.numel() for p in mpd.parameters())}")
print(f"Discriminator mrd params: {sum(p.numel() for p in mrd.parameters())}")
os.makedirs(a.checkpoint_path, exist_ok=True)
print(f"Checkpoints directory: {a.checkpoint_path}")
if os.path.isdir(a.checkpoint_path):
# New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt")
cp_do = scan_checkpoint(
a.checkpoint_path,
prefix="do_",
renamed_file="bigvgan_discriminator_optimizer.pt",
)
# Load the latest checkpoint if exists
steps = 0
if cp_g is None or cp_do is None:
state_dict_do = None
last_epoch = -1
else:
state_dict_g = load_checkpoint(cp_g, device)
state_dict_do = load_checkpoint(cp_do, device)
generator.load_state_dict(state_dict_g["generator"])
mpd.load_state_dict(state_dict_do["mpd"])
mrd.load_state_dict(state_dict_do["mrd"])
steps = state_dict_do["steps"] + 1
last_epoch = state_dict_do["epoch"]
# Initialize DDP, optimizers, and schedulers
if h.num_gpus > 1:
generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
optim_d = torch.optim.AdamW(
itertools.chain(mrd.parameters(), mpd.parameters()),
h.learning_rate,
betas=[h.adam_b1, h.adam_b2],
)
if state_dict_do is not None:
optim_g.load_state_dict(state_dict_do["optim_g"])
optim_d.load_state_dict(state_dict_do["optim_d"])
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
# Define training and validation datasets
"""
unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
Example: trained on LibriTTS, validate on VCTK
"""
training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a)
trainset = MelDataset(
training_filelist,
h,
h.segment_size,
h.n_fft,
h.num_mels,
h.hop_size,
h.win_size,
h.sampling_rate,
h.fmin,
h.fmax,
shuffle=False if h.num_gpus > 1 else True,
fmax_loss=h.fmax_for_loss,
device=device,
fine_tuning=a.fine_tuning,
base_mels_path=a.input_mels_dir,
is_seen=True,
)
train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
train_loader = DataLoader(
trainset,
num_workers=h.num_workers,
shuffle=False,
sampler=train_sampler,
batch_size=h.batch_size,
pin_memory=True,
drop_last=True,
)
if rank == 0:
validset = MelDataset(
validation_filelist,
h,
h.segment_size,
h.n_fft,
h.num_mels,
h.hop_size,
h.win_size,
h.sampling_rate,
h.fmin,
h.fmax,
False,
False,
fmax_loss=h.fmax_for_loss,
device=device,
fine_tuning=a.fine_tuning,
base_mels_path=a.input_mels_dir,
is_seen=True,
)
validation_loader = DataLoader(
validset,
num_workers=1,
shuffle=False,
sampler=None,
batch_size=1,
pin_memory=True,
drop_last=True,
)
list_unseen_validset = []
list_unseen_validation_loader = []
for i in range(len(list_unseen_validation_filelist)):
unseen_validset = MelDataset(
list_unseen_validation_filelist[i],
h,
h.segment_size,
h.n_fft,
h.num_mels,
h.hop_size,
h.win_size,
h.sampling_rate,
h.fmin,
h.fmax,
False,
False,
fmax_loss=h.fmax_for_loss,
device=device,
fine_tuning=a.fine_tuning,
base_mels_path=a.input_mels_dir,
is_seen=False,
)
unseen_validation_loader = DataLoader(
unseen_validset,
num_workers=1,
shuffle=False,
sampler=None,
batch_size=1,
pin_memory=True,
drop_last=True,
)
list_unseen_validset.append(unseen_validset)
list_unseen_validation_loader.append(unseen_validation_loader)
# Tensorboard logger
sw = SummaryWriter(os.path.join(a.checkpoint_path, "logs"))
if a.save_audio: # Also save audio to disk if --save_audio is set to True
os.makedirs(os.path.join(a.checkpoint_path, "samples"), exist_ok=True)
"""
Validation loop, "mode" parameter is automatically defined as (seen or unseen)_(name of the dataset).
If the name of the dataset contains "nonspeech", it skips PESQ calculation to prevent errors
"""
def validate(rank, a, h, loader, mode="seen"):
assert rank == 0, "validate should only run on rank=0"
generator.eval()
torch.cuda.empty_cache()
val_err_tot = 0
val_pesq_tot = 0
val_mrstft_tot = 0
# Modules for evaluation metrics
pesq_resampler = ta.transforms.Resample(h.sampling_rate, 16000).cuda()
loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda")
if a.save_audio: # Also save audio to disk if --save_audio is set to True
os.makedirs(
os.path.join(a.checkpoint_path, "samples", f"gt_{mode}"),
exist_ok=True,
)
os.makedirs(
os.path.join(a.checkpoint_path, "samples", f"{mode}_{steps:08d}"),
exist_ok=True,
)
with torch.no_grad():
print(f"step {steps} {mode} speaker validation...")
# Loop over validation set and compute metrics
for j, batch in enumerate(tqdm(loader)):
x, y, _, y_mel = batch
y = y.to(device)
if hasattr(generator, "module"):
y_g_hat = generator.module(x.to(device))
else:
y_g_hat = generator(x.to(device))
y_mel = y_mel.to(device, non_blocking=True)
y_g_hat_mel = mel_spectrogram(
y_g_hat.squeeze(1),
h.n_fft,
h.num_mels,
h.sampling_rate,
h.hop_size,
h.win_size,
h.fmin,
h.fmax_for_loss,
)
min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item()
# PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
if "nonspeech" not in mode: # Skips if the name of dataset (in mode string) contains "nonspeech"
# Resample to 16000 for pesq
y_16k = pesq_resampler(y)
y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")
# MRSTFT calculation
min_t = min(y.size(-1), y_g_hat.size(-1))
val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item()
# Log audio and figures to Tensorboard
if j % a.eval_subsample == 0: # Subsample every nth from validation set
if steps >= 0:
sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
if a.save_audio: # Also save audio to disk if --save_audio is set to True
save_audio(
y[0],
os.path.join(
a.checkpoint_path,
"samples",
f"gt_{mode}",
f"{j:04d}.wav",
),
h.sampling_rate,
)
sw.add_figure(
f"gt_{mode}/y_spec_{j}",
plot_spectrogram(x[0]),
steps,
)
sw.add_audio(
f"generated_{mode}/y_hat_{j}",
y_g_hat[0],
steps,
h.sampling_rate,
)
if a.save_audio: # Also save audio to disk if --save_audio is set to True
save_audio(
y_g_hat[0, 0],
os.path.join(
a.checkpoint_path,
"samples",
f"{mode}_{steps:08d}",
f"{j:04d}.wav",
),
h.sampling_rate,
)
# Spectrogram of synthesized audio
y_hat_spec = mel_spectrogram(
y_g_hat.squeeze(1),
h.n_fft,
h.num_mels,
h.sampling_rate,
h.hop_size,
h.win_size,
h.fmin,
h.fmax,
)
sw.add_figure(
f"generated_{mode}/y_hat_spec_{j}",
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()),
steps,
)
"""
Visualization of spectrogram difference between GT and synthesized audio, difference higher than 1 is clipped for better visualization.
"""
spec_delta = torch.clamp(
torch.abs(x[0] - y_hat_spec.squeeze(0).cpu()),
min=1e-6,
max=1.0,
)
sw.add_figure(
f"delta_dclip1_{mode}/spec_{j}",
plot_spectrogram_clipped(spec_delta.numpy(), clip_max=1.0),
steps,
)
val_err = val_err_tot / (j + 1)
val_pesq = val_pesq_tot / (j + 1)
val_mrstft = val_mrstft_tot / (j + 1)
# Log evaluation metrics to Tensorboard
sw.add_scalar(f"validation_{mode}/mel_spec_error", val_err, steps)
sw.add_scalar(f"validation_{mode}/pesq", val_pesq, steps)
sw.add_scalar(f"validation_{mode}/mrstft", val_mrstft, steps)
generator.train()
# If the checkpoint is loaded, start with validation loop
if steps != 0 and rank == 0 and not a.debug:
if not a.skip_seen:
validate(
rank,
a,
h,
validation_loader,
mode=f"seen_{train_loader.dataset.name}",
)
for i in range(len(list_unseen_validation_loader)):
validate(
rank,
a,
h,
list_unseen_validation_loader[i],
mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}",
)
# Exit the script if --evaluate is set to True
if a.evaluate:
exit()
# Main training loop
generator.train()
mpd.train()
mrd.train()
for epoch in range(max(0, last_epoch), a.training_epochs):
if rank == 0:
start = time.time()
print(f"Epoch: {epoch + 1}")
if h.num_gpus > 1:
train_sampler.set_epoch(epoch)
for i, batch in enumerate(train_loader):
if rank == 0:
start_b = time.time()
x, y, _, y_mel = batch
x = x.to(device, non_blocking=True)
y = y.to(device, non_blocking=True)
y_mel = y_mel.to(device, non_blocking=True)
y = y.unsqueeze(1)
y_g_hat = generator(x)
y_g_hat_mel = mel_spectrogram(
y_g_hat.squeeze(1),
h.n_fft,
h.num_mels,
h.sampling_rate,
h.hop_size,
h.win_size,
h.fmin,
h.fmax_for_loss,
)
optim_d.zero_grad()
# MPD
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
# MRD
y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
loss_disc_all = loss_disc_s + loss_disc_f
# Set clip_grad_norm value
clip_grad_norm = h.get("clip_grad_norm", 1000.0) # Default to 1000
# Whether to freeze D for initial training steps
if steps >= a.freeze_step:
loss_disc_all.backward()
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm)
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm)
optim_d.step()
else:
print(f"[WARNING] skipping D training for the first {a.freeze_step} steps")
grad_norm_mpd = 0.0
grad_norm_mrd = 0.0
# Generator
optim_g.zero_grad()
# L1 Mel-Spectrogram Loss
lambda_melloss = h.get("lambda_melloss", 45.0) # Defaults to 45 in BigVGAN-v1 if not set
if h.get("use_multiscale_melloss", False): # uses wav <y, y_g_hat> for loss
loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
else: # Uses mel <y_mel, y_g_hat_mel> for loss
loss_mel = fn_mel_loss_singlescale(y_mel, y_g_hat_mel) * lambda_melloss
# MPD loss
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
# MRD loss
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(y, y_g_hat)
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
if steps >= a.freeze_step:
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
else:
print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps")
loss_gen_all = loss_mel
loss_gen_all.backward()
grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm)
optim_g.step()
if rank == 0:
# STDOUT logging
if steps % a.stdout_interval == 0:
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to stdout
print(
f"Steps: {steps:d}, "
f"Gen Loss Total: {loss_gen_all:4.3f}, "
f"Mel Error: {mel_error:4.3f}, "
f"s/b: {time.time() - start_b:4.3f} "
f"lr: {optim_g.param_groups[0]['lr']:4.7f} "
f"grad_norm_g: {grad_norm_g:4.3f}"
)
# Checkpointing
if steps % a.checkpoint_interval == 0 and steps != 0:
checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
save_checkpoint(
checkpoint_path,
{"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()},
)
checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
save_checkpoint(
checkpoint_path,
{
"mpd": (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
"mrd": (mrd.module if h.num_gpus > 1 else mrd).state_dict(),
"optim_g": optim_g.state_dict(),
"optim_d": optim_d.state_dict(),
"steps": steps,
"epoch": epoch,
},
)
# Tensorboard summary logging
if steps % a.summary_interval == 0:
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to tensorboard
sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
sw.add_scalar("training/mel_spec_error", mel_error, steps)
sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
sw.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps)
sw.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps)
sw.add_scalar("training/grad_norm_mpd", grad_norm_mpd, steps)
sw.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps)
sw.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps)
sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
sw.add_scalar("training/epoch", epoch + 1, steps)
# Validation
if steps % a.validation_interval == 0:
# Plot training input x so far used
for i_x in range(x.shape[0]):
sw.add_figure(
f"training_input/x_{i_x}",
plot_spectrogram(x[i_x].cpu()),
steps,
)
sw.add_audio(
f"training_input/y_{i_x}",
y[i_x][0],
steps,
h.sampling_rate,
)
# Seen and unseen speakers validation loops
if not a.debug and steps != 0:
validate(
rank,
a,
h,
validation_loader,
mode=f"seen_{train_loader.dataset.name}",
)
for i in range(len(list_unseen_validation_loader)):
validate(
rank,
a,
h,
list_unseen_validation_loader[i],
mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}",
)
steps += 1
# BigVGAN-v2 learning rate scheduler is changed from epoch-level to step-level
scheduler_g.step()
scheduler_d.step()
if rank == 0:
print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n")
def main():
print("Initializing Training Process..")
parser = argparse.ArgumentParser()
parser.add_argument("--group_name", default=None)
parser.add_argument("--input_wavs_dir", default="LibriTTS")
parser.add_argument("--input_mels_dir", default="ft_dataset")
parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt")
parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt")
parser.add_argument(
"--list_input_unseen_wavs_dir",
nargs="+",
default=["tests/LibriTTS", "tests/LibriTTS"],
)
parser.add_argument(
"--list_input_unseen_validation_file",
nargs="+",
default=["tests/LibriTTS/dev-clean.txt", "tests/LibriTTS/dev-other.txt"],
)
parser.add_argument("--checkpoint_path", default="exp/bigvgan")
parser.add_argument("--config", default="")
parser.add_argument("--training_epochs", default=100000, type=int)
parser.add_argument("--stdout_interval", default=5, type=int)
parser.add_argument("--checkpoint_interval", default=50000, type=int)
parser.add_argument("--summary_interval", default=100, type=int)
parser.add_argument("--validation_interval", default=50000, type=int)
parser.add_argument(
"--freeze_step",
default=0,
type=int,
help="freeze D for the first specified steps. G only uses regression loss for these steps.",
)
parser.add_argument("--fine_tuning", default=False, type=bool)
parser.add_argument(
"--debug",
default=False,
type=bool,
help="debug mode. skips validation loop throughout training",
)
parser.add_argument(
"--evaluate",
default=False,
type=bool,
help="only run evaluation from checkpoint and exit",
)
parser.add_argument(
"--eval_subsample",
default=5,
type=int,
help="subsampling during evaluation loop",
)
parser.add_argument(
"--skip_seen",
default=False,
type=bool,
help="skip seen dataset. useful for test set inference",
)
parser.add_argument(
"--save_audio",
default=False,
type=bool,
help="save audio of test set inference to disk",
)
a = parser.parse_args()
with open(a.config) as f:
data = f.read()
json_config = json.loads(data)
h = AttrDict(json_config)
build_env(a.config, "config.json", a.checkpoint_path)
torch.manual_seed(h.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(h.seed)
h.num_gpus = torch.cuda.device_count()
h.batch_size = int(h.batch_size / h.num_gpus)
print(f"Batch size per GPU: {h.batch_size}")
else:
pass
if h.num_gpus > 1:
mp.spawn(
train,
nprocs=h.num_gpus,
args=(
a,
h,
),
)
else:
train(0, a, h)
if __name__ == "__main__":
main()

View File

@ -21,20 +21,20 @@ import numpy as np
import torch
import torch.nn.functional as F
import yaml
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from BigVGAN.bigvgan import BigVGAN
from feature_extractor.cnhubert import CNHubert
from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
from peft import LoraConfig, get_peft_model
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
from sv import SV
from transformers import AutoModelForMaskedLM, AutoTokenizer
from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch
from GPT_SoVITS.module.models import Generator, SynthesizerTrn, SynthesizerTrnV3
from GPT_SoVITS.process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
from tools.audio_sr import AP_BWE
from tools.i18n.i18n import I18nAuto, scan_language_list
from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
from sv import SV
resample_transform_dict = {}
@ -64,33 +64,32 @@ def denorm_spec(x):
return (x + 1) / 2 * (spec_max - spec_min) + spec_min
mel_fn = lambda x: mel_spectrogram_torch(
x,
**{
"n_fft": 1024,
"win_size": 1024,
"hop_size": 256,
"num_mels": 100,
"sampling_rate": 24000,
"fmin": 0,
"fmax": None,
"center": False,
},
)
def mel_fn(x):
return mel_spectrogram_torch(
y=x,
n_fft=1024,
num_mels=100,
sampling_rate=24000,
hop_size=256,
win_size=1024,
fmin=0,
fmax=None,
center=False,
)
mel_fn_v4 = lambda x: mel_spectrogram_torch(
x,
**{
"n_fft": 1280,
"win_size": 1280,
"hop_size": 320,
"num_mels": 100,
"sampling_rate": 32000,
"fmin": 0,
"fmax": None,
"center": False,
},
)
def mel_fn_v4(x):
return mel_spectrogram_torch(
y=x,
n_fft=1280,
num_mels=100,
sampling_rate=32000,
hop_size=320,
win_size=1280,
fmin=0,
fmax=None,
center=False,
)
def speed_change(input_audio: np.ndarray, speed: float, sr: int):
@ -488,7 +487,7 @@ class TTS:
self.init_sv_model()
path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
if if_lora_v3 == True and os.path.exists(path_sovits) == False:
if if_lora_v3 is True and os.path.exists(path_sovits) is False:
info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
raise FileExistsError(info)
@ -549,7 +548,7 @@ class TTS:
self.is_v2pro = model_version in {"v2Pro", "v2ProPlus"}
if if_lora_v3 == False:
if if_lora_v3 is False:
print(
f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
)
@ -580,8 +579,6 @@ class TTS:
self.configs.save_configs()
def init_t2s_weights(self, weights_path: str):
print(f"Loading Text2Semantic weights from {weights_path}")
self.configs.t2s_weights_path = weights_path
@ -654,7 +651,7 @@ class TTS:
self.vocoder_configs["overlapped_len"] = 12
self.vocoder = self.vocoder.eval()
if self.configs.is_half == True:
if self.configs.is_half is True:
self.vocoder = self.vocoder.half().to(self.configs.device)
else:
self.vocoder = self.vocoder.to(self.configs.device)
@ -784,7 +781,7 @@ class TTS:
)
if self.configs.is_half:
spec = spec.half()
if self.is_v2pro == True:
if self.is_v2pro is True:
audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
if self.configs.is_half:
audio = audio.half()

View File

@ -9,11 +9,12 @@ ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of pa
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
"""
import torch
import math
import pooling_layers as pooling_layers
import torch
import torch.nn as nn
import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF
@ -252,7 +253,7 @@ class ERes2Net(nn.Module):
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2) # bs,20480,T
if if_mean == False:
if if_mean is False:
mean = fuse_out1234[0].transpose(1, 0) # (T,20480),bs=T
else:
mean = fuse_out1234.mean(2) # bs,20480

View File

@ -663,7 +663,7 @@ def fbank(
# avoid log of zero (which should be prevented anyway by dithering)
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
# if use_energy then add it as the last column for htk_compat == true else first column
# if use_energy then add it as the last column for htk_compat is True else first column
if use_energy:
signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
# returns size (m, num_mel_bins + 1)
@ -826,7 +826,7 @@ def mfcc(
lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
feature *= lifter_coeffs.to(device=device, dtype=dtype)
# if use_energy then replace the last column for htk_compat == true else first column
# if use_energy then replace the last column for htk_compat is True else first column
if use_energy:
feature[:, 0] = signal_log_energy

View File

@ -1,28 +1,24 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import argparse
import os
from io import BytesIO
from typing import Optional
from my_utils import load_audio
import kaldi as Kaldi
import soundfile
import torch
import torchaudio
from feature_extractor import cnhubert
from inference_webui import get_phones_and_bert
from my_utils import load_audio
from sv import SV
from torch import IntTensor, LongTensor, Tensor, nn
from torch.nn import functional as F
from transformers import AutoModelForMaskedLM, AutoTokenizer
from feature_extractor import cnhubert
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from module.models_onnx import SynthesizerTrn
from inference_webui import get_phones_and_bert
from sv import SV
import kaldi as Kaldi
import os
import soundfile
from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
from GPT_SoVITS.module.models_onnx import SynthesizerTrn
default_config = {
"embedding_dim": 512,
@ -477,7 +473,7 @@ class T2SModel(nn.Module):
# avoid dtype inconsistency when exporting
bert = bert.to(dtype=self.bert_proj.weight.dtype)
x = x + self.bert_proj(bert.transpose(1, 2))
x: torch.Tensor = self.ar_text_position(x)
@ -737,7 +733,7 @@ def export_prov2(
device="cpu",
is_half=True,
):
if sv_cn_model == None:
if sv_cn_model is None:
init_sv_cn(device, is_half)
if not os.path.exists(output_path):
@ -1041,9 +1037,10 @@ def test():
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
import text
import json
import text
def export_symbel(version="v2"):
if version == "v1":

View File

@ -1,4 +1,13 @@
import logging
import os
import librosa
import numpy as np
import soundfile
import torch
import torch._dynamo.config
import torchaudio
import uvicorn
from export_torch_script import (
T2SModel,
get_raw_t2s_model,
@ -6,22 +15,12 @@ from export_torch_script import (
spectrogram_torch,
)
from f5_tts.model.backbones.dit import DiT
from inference_webui import get_phones_and_bert
import librosa
from module import commons
from module.mel_processing import mel_spectrogram_torch
from module.models_onnx import CFM, Generator, SynthesizerTrnV3
import numpy as np
import torch._dynamo.config
import torchaudio
import logging
import uvicorn
import torch
import soundfile
from inference_webui import get_phones_and_bert, get_spepc, norm_spec, resample, ssl_model
from librosa.filters import mel as librosa_mel_fn
from inference_webui import get_spepc, norm_spec, resample, ssl_model
from GPT_SoVITS.module import commons
from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch
from GPT_SoVITS.module.models_onnx import CFM, Generator, SynthesizerTrnV3
logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
logger = logging.getLogger("uvicorn")
@ -176,32 +175,33 @@ class ExportCFM(torch.nn.Module):
return cfm_res, fea_ref, mel2
mel_fn = lambda x: mel_spectrogram_torch(
x,
**{
"n_fft": 1024,
"win_size": 1024,
"hop_size": 256,
"num_mels": 100,
"sampling_rate": 24000,
"fmin": 0,
"fmax": None,
"center": False,
},
)
mel_fn_v4 = lambda x: mel_spectrogram_torch(
x,
**{
"n_fft": 1280,
"win_size": 1280,
"hop_size": 320,
"num_mels": 100,
"sampling_rate": 32000,
"fmin": 0,
"fmax": None,
"center": False,
},
)
def mel_fn(x):
return mel_spectrogram_torch(
y=x,
n_fft=1024,
num_mels=100,
sampling_rate=24000,
hop_size=256,
win_size=1024,
fmin=0,
fmax=None,
center=False,
)
def mel_fn_v4(x):
return mel_spectrogram_torch(
y=x,
n_fft=1280,
num_mels=100,
sampling_rate=32000,
hop_size=320,
win_size=1280,
fmin=0,
fmax=None,
center=False,
)
spec_min = -12
spec_max = 2
@ -511,7 +511,7 @@ def init_bigvgan():
# remove weight norm in the model and set to eval mode
bigvgan_model.remove_weight_norm()
bigvgan_model = bigvgan_model.eval()
if is_half == True:
if is_half is True:
bigvgan_model = bigvgan_model.half().to(device)
else:
bigvgan_model = bigvgan_model.to(device)
@ -536,7 +536,7 @@ def init_hifigan():
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu"
)
print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
if is_half == True:
if is_half is True:
hifigan_model = hifigan_model.half().to(device)
else:
hifigan_model = hifigan_model.to(device)
@ -578,7 +578,7 @@ class DictToAttrRecursive(dict):
raise AttributeError(f"Attribute {item} not found")
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
from GPT_SoVITS.process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
v3v4set = {"v3", "v4"}
@ -588,7 +588,7 @@ def get_sovits_weights(sovits_path):
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
if if_lora_v3 == True and is_exist_s2gv3 == False:
if if_lora_v3 is True and is_exist_s2gv3 is False:
logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
dict_s2 = load_sovits_new(sovits_path)
@ -617,7 +617,7 @@ def get_sovits_weights(sovits_path):
model_version = hps.model.version
logger.info(f"模型版本: {model_version}")
if is_half == True:
if is_half is True:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
@ -729,11 +729,11 @@ def export_1(ref_wav_path, ref_wav_text, version="v3"):
# ref_wav_path = "onnx/ad/ref.wav"
speed = 1.0
sample_steps = 8
dtype = torch.float16 if is_half == True else torch.float32
dtype = torch.float16 if is_half is True else torch.float32
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
zero_wav = np.zeros(
int(hps.data.sampling_rate * 0.3),
dtype=np.float16 if is_half == True else np.float32,
dtype=np.float16 if is_half is True else np.float32,
)
with torch.no_grad():
@ -741,7 +741,7 @@ def export_1(ref_wav_path, ref_wav_text, version="v3"):
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
if is_half == True:
if is_half is True:
wav16k = wav16k.half().to(device)
zero_wav_torch = zero_wav_torch.half().to(device)
else:
@ -940,11 +940,11 @@ def test_export(
speed = 1.0
sample_steps = 8
dtype = torch.float16 if is_half == True else torch.float32
dtype = torch.float16 if is_half is True else torch.float32
zero_wav = np.zeros(
int(16000 * 0.3),
dtype=np.float16 if is_half == True else np.float32,
dtype=np.float16 if is_half is True else np.float32,
)
with torch.no_grad():
@ -952,7 +952,7 @@ def test_export(
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
if is_half == True:
if is_half is True:
wav16k = wav16k.half().to(device)
zero_wav_torch = zero_wav_torch.half().to(device)
else:
@ -1058,11 +1058,11 @@ def test_export(
speed = 1.0
sample_steps = torch.LongTensor([16])
dtype = torch.float16 if is_half == True else torch.float32
dtype = torch.float16 if is_half is True else torch.float32
zero_wav = np.zeros(
int(out_sr * 0.3),
dtype=np.float16 if is_half == True else np.float32,
dtype=np.float16 if is_half is True else np.float32,
)
with torch.no_grad():
@ -1070,7 +1070,7 @@ def test_export(
wav16k = torch.from_numpy(wav16k)
zero_wav_torch = torch.from_numpy(zero_wav)
if is_half == True:
if is_half is True:
wav16k = wav16k.half().to(device)
zero_wav_torch = zero_wav_torch.half().to(device)
else:

View File

@ -12,20 +12,18 @@ from __future__ import annotations
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from x_transformers.x_transformers import RotaryEmbedding
from GPT_SoVITS.f5_tts.model.modules import (
TimestepEmbedding,
AdaLayerNormZero_Final,
ConvNeXtV2Block,
ConvPositionEmbedding,
DiTBlock,
AdaLayerNormZero_Final,
precompute_freqs_cis,
TimestepEmbedding,
get_pos_embed_indices,
precompute_freqs_cis,
)
from module.commons import sequence_mask
from GPT_SoVITS.module.commons import sequence_mask
class TextEmbedding(nn.Module):

View File

@ -1,5 +1,6 @@
import torch
import os
import torch
from transformers import logging as tf_logging
tf_logging.set_verbosity_error()
@ -8,13 +9,13 @@ import logging
logging.getLogger("numba").setLevel(logging.WARNING)
import torch.nn as nn
from transformers import (
Wav2Vec2FeatureExtractor,
HubertModel,
Wav2Vec2FeatureExtractor,
)
import utils
import torch.nn as nn
import GPT_SoVITS.utils as utils
cnhubert_base_path = None

View File

@ -1,86 +0,0 @@
import argparse
import os
import soundfile as sf
from tools.i18n.i18n import I18nAuto
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
i18n = I18nAuto()
def synthesize(
GPT_model_path,
SoVITS_model_path,
ref_audio_path,
ref_text_path,
ref_language,
target_text_path,
target_language,
output_path,
):
# Read reference text
with open(ref_text_path, "r", encoding="utf-8") as file:
ref_text = file.read()
# Read target text
with open(target_text_path, "r", encoding="utf-8") as file:
target_text = file.read()
# Change model weights
change_gpt_weights(gpt_path=GPT_model_path)
change_sovits_weights(sovits_path=SoVITS_model_path)
# Synthesize audio
synthesis_result = get_tts_wav(
ref_wav_path=ref_audio_path,
prompt_text=ref_text,
prompt_language=i18n(ref_language),
text=target_text,
text_language=i18n(target_language),
top_p=1,
temperature=1,
)
result_list = list(synthesis_result)
if result_list:
last_sampling_rate, last_audio_data = result_list[-1]
output_wav_path = os.path.join(output_path, "output.wav")
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
print(f"Audio saved to {output_wav_path}")
def main():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
parser.add_argument(
"--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio"
)
parser.add_argument("--target_text", required=True, help="Path to the target text file")
parser.add_argument(
"--target_language",
required=True,
choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"],
help="Language of the target text",
)
parser.add_argument("--output_path", required=True, help="Path to the output directory")
args = parser.parse_args()
synthesize(
args.gpt_model,
args.sovits_model,
args.ref_audio,
args.ref_text,
args.ref_language,
args.target_text,
args.target_language,
args.output_path,
)
if __name__ == "__main__":
main()

View File

@ -1,316 +0,0 @@
import os
import sys
from PyQt5.QtCore import QEvent
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QPushButton, QTextEdit
from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QStatusBar, QComboBox
import soundfile as sf
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
class GPTSoVITSGUI(QMainWindow):
GPT_Path = gpt_path
SoVITS_Path = sovits_path
def __init__(self):
super().__init__()
self.setWindowTitle("GPT-SoVITS GUI")
self.setGeometry(800, 450, 950, 850)
self.setStyleSheet("""
QWidget {
background-color: #a3d3b1;
}
QTabWidget::pane {
background-color: #a3d3b1;
}
QTabWidget::tab-bar {
alignment: left;
}
QTabBar::tab {
background: #8da4bf;
color: #ffffff;
padding: 8px;
}
QTabBar::tab:selected {
background: #2a3f54;
}
QLabel {
color: #000000;
}
QPushButton {
background-color: #4CAF50;
color: white;
padding: 8px;
border: 1px solid #4CAF50;
border-radius: 4px;
}
QPushButton:hover {
background-color: #45a049;
border: 1px solid #45a049;
box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1);
}
""")
license_text = (
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
)
license_label = QLabel(license_text)
license_label.setWordWrap(True)
self.GPT_model_label = QLabel("选择GPT模型:")
self.GPT_model_input = QLineEdit()
self.GPT_model_input.setPlaceholderText("拖拽或选择文件")
self.GPT_model_input.setText(self.GPT_Path)
self.GPT_model_input.setReadOnly(True)
self.GPT_model_button = QPushButton("选择GPT模型文件")
self.GPT_model_button.clicked.connect(self.select_GPT_model)
self.SoVITS_model_label = QLabel("选择SoVITS模型:")
self.SoVITS_model_input = QLineEdit()
self.SoVITS_model_input.setPlaceholderText("拖拽或选择文件")
self.SoVITS_model_input.setText(self.SoVITS_Path)
self.SoVITS_model_input.setReadOnly(True)
self.SoVITS_model_button = QPushButton("选择SoVITS模型文件")
self.SoVITS_model_button.clicked.connect(self.select_SoVITS_model)
self.ref_audio_label = QLabel("上传参考音频:")
self.ref_audio_input = QLineEdit()
self.ref_audio_input.setPlaceholderText("拖拽或选择文件")
self.ref_audio_input.setReadOnly(True)
self.ref_audio_button = QPushButton("选择音频文件")
self.ref_audio_button.clicked.connect(self.select_ref_audio)
self.ref_text_label = QLabel("参考音频文本:")
self.ref_text_input = QLineEdit()
self.ref_text_input.setPlaceholderText("直接输入文字或上传文本")
self.ref_text_button = QPushButton("上传文本")
self.ref_text_button.clicked.connect(self.upload_ref_text)
self.ref_language_label = QLabel("参考音频语言:")
self.ref_language_combobox = QComboBox()
self.ref_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
self.ref_language_combobox.setCurrentText("多语种混合")
self.target_text_label = QLabel("合成目标文本:")
self.target_text_input = QLineEdit()
self.target_text_input.setPlaceholderText("直接输入文字或上传文本")
self.target_text_button = QPushButton("上传文本")
self.target_text_button.clicked.connect(self.upload_target_text)
self.target_language_label = QLabel("合成音频语言:")
self.target_language_combobox = QComboBox()
self.target_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
self.target_language_combobox.setCurrentText("多语种混合")
self.output_label = QLabel("输出音频路径:")
self.output_input = QLineEdit()
self.output_input.setPlaceholderText("拖拽或选择文件")
self.output_input.setReadOnly(True)
self.output_button = QPushButton("选择文件夹")
self.output_button.clicked.connect(self.select_output_path)
self.output_text = QTextEdit()
self.output_text.setReadOnly(True)
self.add_drag_drop_events(
[
self.GPT_model_input,
self.SoVITS_model_input,
self.ref_audio_input,
self.ref_text_input,
self.target_text_input,
self.output_input,
]
)
self.synthesize_button = QPushButton("合成")
self.synthesize_button.clicked.connect(self.synthesize)
self.clear_output_button = QPushButton("清空输出")
self.clear_output_button.clicked.connect(self.clear_output)
self.status_bar = QStatusBar()
main_layout = QVBoxLayout()
input_layout = QGridLayout(self)
input_layout.setSpacing(10)
input_layout.addWidget(license_label, 0, 0, 1, 3)
input_layout.addWidget(self.GPT_model_label, 1, 0)
input_layout.addWidget(self.GPT_model_input, 2, 0, 1, 2)
input_layout.addWidget(self.GPT_model_button, 2, 2)
input_layout.addWidget(self.SoVITS_model_label, 3, 0)
input_layout.addWidget(self.SoVITS_model_input, 4, 0, 1, 2)
input_layout.addWidget(self.SoVITS_model_button, 4, 2)
input_layout.addWidget(self.ref_audio_label, 5, 0)
input_layout.addWidget(self.ref_audio_input, 6, 0, 1, 2)
input_layout.addWidget(self.ref_audio_button, 6, 2)
input_layout.addWidget(self.ref_language_label, 7, 0)
input_layout.addWidget(self.ref_language_combobox, 8, 0, 1, 1)
input_layout.addWidget(self.ref_text_label, 9, 0)
input_layout.addWidget(self.ref_text_input, 10, 0, 1, 2)
input_layout.addWidget(self.ref_text_button, 10, 2)
input_layout.addWidget(self.target_language_label, 11, 0)
input_layout.addWidget(self.target_language_combobox, 12, 0, 1, 1)
input_layout.addWidget(self.target_text_label, 13, 0)
input_layout.addWidget(self.target_text_input, 14, 0, 1, 2)
input_layout.addWidget(self.target_text_button, 14, 2)
input_layout.addWidget(self.output_label, 15, 0)
input_layout.addWidget(self.output_input, 16, 0, 1, 2)
input_layout.addWidget(self.output_button, 16, 2)
main_layout.addLayout(input_layout)
output_layout = QVBoxLayout()
output_layout.addWidget(self.output_text)
main_layout.addLayout(output_layout)
main_layout.addWidget(self.synthesize_button)
main_layout.addWidget(self.clear_output_button)
main_layout.addWidget(self.status_bar)
self.central_widget = QWidget()
self.central_widget.setLayout(main_layout)
self.setCentralWidget(self.central_widget)
def dragEnterEvent(self, event):
if event.mimeData().hasUrls():
event.acceptProposedAction()
def dropEvent(self, event):
if event.mimeData().hasUrls():
file_paths = [url.toLocalFile() for url in event.mimeData().urls()]
if len(file_paths) == 1:
self.update_ref_audio(file_paths[0])
else:
self.update_ref_audio(", ".join(file_paths))
def add_drag_drop_events(self, widgets):
for widget in widgets:
widget.setAcceptDrops(True)
widget.installEventFilter(self)
def eventFilter(self, obj, event):
if event.type() in (QEvent.DragEnter, QEvent.Drop):
mime_data = event.mimeData()
if mime_data.hasUrls():
event.acceptProposedAction()
return super().eventFilter(obj, event)
def select_GPT_model(self):
file_path, _ = QFileDialog.getOpenFileName(self, "选择GPT模型文件", "", "GPT Files (*.ckpt)")
if file_path:
self.GPT_model_input.setText(file_path)
def select_SoVITS_model(self):
file_path, _ = QFileDialog.getOpenFileName(self, "选择SoVITS模型文件", "", "SoVITS Files (*.pth)")
if file_path:
self.SoVITS_model_input.setText(file_path)
def select_ref_audio(self):
file_path, _ = QFileDialog.getOpenFileName(self, "选择参考音频文件", "", "Audio Files (*.wav *.mp3)")
if file_path:
self.update_ref_audio(file_path)
def upload_ref_text(self):
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
if file_path:
with open(file_path, "r", encoding="utf-8") as file:
content = file.read()
self.ref_text_input.setText(content)
def upload_target_text(self):
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
if file_path:
with open(file_path, "r", encoding="utf-8") as file:
content = file.read()
self.target_text_input.setText(content)
def select_output_path(self):
options = QFileDialog.Options()
options |= QFileDialog.DontUseNativeDialog
options |= QFileDialog.ShowDirsOnly
folder_dialog = QFileDialog()
folder_dialog.setOptions(options)
folder_dialog.setFileMode(QFileDialog.Directory)
if folder_dialog.exec_():
folder_path = folder_dialog.selectedFiles()[0]
self.output_input.setText(folder_path)
def update_ref_audio(self, file_path):
self.ref_audio_input.setText(file_path)
def clear_output(self):
self.output_text.clear()
def synthesize(self):
GPT_model_path = self.GPT_model_input.text()
SoVITS_model_path = self.SoVITS_model_input.text()
ref_audio_path = self.ref_audio_input.text()
language_combobox = self.ref_language_combobox.currentText()
language_combobox = i18n(language_combobox)
ref_text = self.ref_text_input.text()
target_language_combobox = self.target_language_combobox.currentText()
target_language_combobox = i18n(target_language_combobox)
target_text = self.target_text_input.text()
output_path = self.output_input.text()
if GPT_model_path != self.GPT_Path:
change_gpt_weights(gpt_path=GPT_model_path)
self.GPT_Path = GPT_model_path
if SoVITS_model_path != self.SoVITS_Path:
change_sovits_weights(sovits_path=SoVITS_model_path)
self.SoVITS_Path = SoVITS_model_path
synthesis_result = get_tts_wav(
ref_wav_path=ref_audio_path,
prompt_text=ref_text,
prompt_language=language_combobox,
text=target_text,
text_language=target_language_combobox,
)
result_list = list(synthesis_result)
if result_list:
last_sampling_rate, last_audio_data = result_list[-1]
output_wav_path = os.path.join(output_path, "output.wav")
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
result = "Audio saved to " + output_wav_path
self.status_bar.showMessage("合成完成!输出路径:" + output_wav_path, 5000)
self.output_text.append("处理结果:\n" + result)
if __name__ == "__main__":
app = QApplication(sys.argv)
mainWin = GPTSoVITSGUI()
mainWin.show()
sys.exit(app.exec_())

File diff suppressed because it is too large Load Diff

View File

@ -6,19 +6,24 @@
全部按英文识别
全部按日文识别
"""
import psutil
import os
import psutil
def set_high_priority():
"""把当前 Python 进程设为 HIGH_PRIORITY_CLASS"""
if os.name != "nt":
return # 仅 Windows 有效
return # 仅 Windows 有效
p = psutil.Process(os.getpid())
try:
p.nice(psutil.HIGH_PRIORITY_CLASS)
print("已将进程优先级设为 High")
except psutil.AccessDenied:
print("权限不足,无法修改优先级(请用管理员运行)")
set_high_priority()
import json
import logging
@ -225,7 +230,7 @@ with open("./weight.json", "r", encoding="utf-8") as file:
if isinstance(sovits_path, list):
sovits_path = sovits_path[0]
from process_ckpt import get_sovits_version_from_path_fast
from GPT_SoVITS.process_ckpt import get_sovits_version_from_path_fast
v3v4set = {"v3", "v4"}
@ -238,7 +243,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
# print(sovits_path,version, model_version, if_lora_v3)
is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
if if_lora_v3 == True and is_exist == False:
if if_lora_v3 is True and is_exist is False:
info = path_sovits + "SoVITS %s" % model_version + i18n("底模缺失,无法加载相应 LoRA 权重")
gr.Warning(info)
raise FileExistsError(info)
@ -315,7 +320,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
with gr.Column():
# with gr.Group():
gr.Markdown(value=i18n("模型切换"))
with gr.Row():
with gr.Row(equal_height=True):
GPT_dropdown = gr.Dropdown(
label=i18n("GPT模型列表"),
choices=sorted(GPT_names, key=custom_sort_key),
@ -331,18 +336,22 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
with gr.Row():
with gr.Row(equal_height=True):
with gr.Column():
gr.Markdown(value=i18n("*请上传并填写参考信息"))
with gr.Row():
inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频超过会报错)"), type="filepath")
with gr.Row(equal_height=True):
inp_ref = gr.Audio(
label=i18n("主参考音频(请上传3~10秒内参考音频超过会报错)"),
type="filepath",
waveform_options={"show_recording_waveform": False},
)
inp_refs = gr.File(
label=i18n("辅参考音频(可选多个,或不选)"),
file_count="multiple",
visible=True if model_version != "v3" else False,
)
prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
with gr.Row():
with gr.Row(equal_height=True):
prompt_language = gr.Dropdown(
label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
)
@ -368,26 +377,26 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
with gr.Group():
gr.Markdown(value=i18n("推理设置"))
with gr.Row():
with gr.Row(equal_height=True):
with gr.Column():
with gr.Row():
with gr.Row(equal_height=True):
batch_size = gr.Slider(
minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
)
sample_steps = gr.Radio(
label=i18n("采样步数(仅对V3/4生效)"), value=32, choices=[4, 8, 16, 32, 64, 128], visible=True
)
with gr.Row():
with gr.Row(equal_height=True):
fragment_interval = gr.Slider(
minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
)
speed_factor = gr.Slider(
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
)
with gr.Row():
with gr.Row(equal_height=True):
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
with gr.Row():
with gr.Row(equal_height=True):
temperature = gr.Slider(
minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
)
@ -396,7 +405,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
)
with gr.Column():
with gr.Row():
with gr.Row(equal_height=True):
how_to_cut = gr.Dropdown(
label=i18n("怎么切"),
choices=[
@ -415,7 +424,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
)
with gr.Row():
with gr.Row(equal_height=True):
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
split_bucket = gr.Checkbox(
label=i18n("数据分桶(并行推理时会降低一点计算量)"),
@ -424,12 +433,15 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
show_label=True,
)
with gr.Row():
with gr.Row(equal_height=True):
seed = gr.Number(label=i18n("随机种子"), value=-1)
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
output = gr.Audio(label=i18n("输出的语音"))
with gr.Row():
output = gr.Audio(
label=i18n("输出的语音"),
waveform_options={"show_recording_waveform": False},
)
with gr.Row(equal_height=True):
inference_button = gr.Button(i18n("合成语音"), variant="primary")
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
@ -485,7 +497,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
)
)
with gr.Row():
with gr.Row(equal_height=True):
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
with gr.Column():
_how_to_cut = gr.Radio(

View File

@ -1,10 +1,11 @@
import math
import torch
from torch import nn
from torch.nn import functional as F
from module import commons
from module.modules import LayerNorm
from GPT_SoVITS.module import commons
from GPT_SoVITS.module.modules import LayerNorm
class Encoder(nn.Module):

View File

@ -1,11 +1,11 @@
import math
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
from module import commons
from typing import Optional
from GPT_SoVITS.module import commons
class LayerNorm(nn.Module):

View File

@ -1,13 +1,14 @@
import os
import random
import traceback
import torch
import torch.nn.functional as F
import torch.utils.data
from text import cleaned_text_to_sequence
from tqdm import tqdm
from module.mel_processing import spectrogram_torch, spec_to_mel_torch
from text import cleaned_text_to_sequence
import torch.nn.functional as F
from GPT_SoVITS.module.mel_processing import spec_to_mel_torch, spectrogram_torch
from tools.my_utils import load_audio
version = os.environ.get("version", None)

View File

@ -4,7 +4,7 @@ import torch
def feature_loss(fmap_r, fmap_g):
loss = 0
loss = torch.tensor(0)
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
rl = rl.float().detach()
@ -15,7 +15,7 @@ def feature_loss(fmap_r, fmap_g):
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
loss = 0
loss = torch.tensor(0)
r_losses = []
g_losses = []
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
@ -31,7 +31,7 @@ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
def generator_loss(disc_outputs):
loss = 0
loss = torch.tensor(0)
gen_losses = []
for dg in disc_outputs:
dg = dg.float()

View File

@ -1,28 +1,27 @@
import warnings
warnings.filterwarnings("ignore")
import contextlib
import math
import random
import warnings
import torch
from torch import nn
from torch.nn import functional as F
from module import commons
from module import modules
from module import attentions
from f5_tts.model import DiT
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.mrte_model import MRTE
from module.quantize import ResidualVectorQuantizer
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast
import contextlib
import random
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from GPT_SoVITS.f5_tts.model import DiT
from GPT_SoVITS.module import attentions, commons, modules
from GPT_SoVITS.module.commons import get_padding, init_weights
from GPT_SoVITS.module.mrte_model import MRTE
from GPT_SoVITS.module.quantize import ResidualVectorQuantizer
from GPT_SoVITS.text import symbols as symbols_v1
from GPT_SoVITS.text import symbols2 as symbols_v2
from GPT_SoVITS.utils import HParams
warnings.filterwarnings("ignore")
torch.serialization.add_safe_globals([(HParams, "utils.HParams")])
class StochasticDurationPredictor(nn.Module):
@ -230,25 +229,6 @@ class TextEncoder(nn.Module):
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask
def extract_latent(self, x):
x = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(x)
return codes.transpose(0, 1)
def decode_latent(self, codes, y_mask, refer, refer_mask, ge):
quantized = self.quantizer.decode(codes)
y = self.vq_proj(quantized) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
y = self.mrte(y, y_mask, refer, refer_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
stats = self.proj(y) * y_mask
m, logs = torch.split(stats, self.out_channels, dim=1)
return y, m, logs, y_mask, quantized
class ResidualCouplingBlock(nn.Module):
def __init__(
@ -483,7 +463,7 @@ class DiscriminatorP(torch.nn.Module):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(
@ -560,7 +540,7 @@ class DiscriminatorP(torch.nn.Module):
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
@ -1206,7 +1186,7 @@ class SynthesizerTrnV3(nn.Module):
100,
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
) # text_dim is condition feature dim
if self.freeze_quantizer == True:
if self.freeze_quantizer is True:
set_no_grad(self.ssl_proj)
set_no_grad(self.quantizer)
set_no_grad(self.enc_p)
@ -1245,7 +1225,7 @@ class SynthesizerTrnV3(nn.Module):
def decode_encp(self, codes, text, refer, ge=None, speed=1):
# print(2333333,refer.shape)
# ge=None
if ge == None:
if ge is None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
@ -1409,7 +1389,7 @@ class SynthesizerTrnV3b(nn.Module):
def decode_encp(self, codes, text, refer, ge=None):
# print(2333333,refer.shape)
# ge=None
if ge == None:
if ge is None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)

View File

@ -1,23 +1,21 @@
import math
from typing import Optional
import torch
from torch import nn
from torch.nn import functional as F
from module import commons
from module import modules
from module import attentions_onnx as attentions
from f5_tts.model import DiT
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.quantize import ResidualVectorQuantizer
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch import nn
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from GPT_SoVITS.module import attentions_onnx as attentions
from GPT_SoVITS.module import commons, modules
from GPT_SoVITS.module.commons import get_padding, init_weights
from GPT_SoVITS.module.quantize import ResidualVectorQuantizer
class StochasticDurationPredictor(nn.Module):
@ -459,7 +457,7 @@ class DiscriminatorP(torch.nn.Module):
super(DiscriminatorP, self).__init__()
self.period = period
self.use_spectral_norm = use_spectral_norm
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(
@ -536,7 +534,7 @@ class DiscriminatorP(torch.nn.Module):
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
@ -1057,7 +1055,7 @@ class SynthesizerTrnV3(nn.Module):
100,
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
) # text_dim is condition feature dim
if freeze_quantizer == True:
if freeze_quantizer is True:
set_no_grad(self.ssl_proj)
set_no_grad(self.quantizer)
set_no_grad(self.enc_p)

View File

@ -2,17 +2,15 @@ import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm
from module import commons
from module.commons import init_weights, get_padding
from module.transforms import piecewise_rational_quadratic_transform
import torch.distributions as D
from torch import nn
from torch.nn import Conv1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, weight_norm
from GPT_SoVITS.module import commons
from GPT_SoVITS.module.commons import get_padding, init_weights
from GPT_SoVITS.module.transforms import piecewise_rational_quadratic_transform
LRELU_SLOPE = 0.1

View File

@ -3,7 +3,8 @@
import torch
from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm
from module.attentions import MultiHeadAttention
from GPT_SoVITS.module.attentions import MultiHeadAttention
class MRTE(nn.Module):
@ -23,7 +24,7 @@ class MRTE(nn.Module):
self.c_post = nn.Conv1d(hidden_size, out_channels, 1)
def forward(self, ssl_enc, ssl_mask, text, text_mask, ge, test=None):
if ge == None:
if ge is None:
ge = 0
attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)

View File

@ -6,13 +6,13 @@
"""Residual vector quantizer implementation."""
from dataclasses import dataclass, field
import typing as tp
from dataclasses import dataclass, field
import torch
from torch import nn
from module.core_vq import ResidualVectorQuantization
from GPT_SoVITS.module.core_vq import ResidualVectorQuantization
@dataclass

View File

@ -1,10 +1,11 @@
import torch
import torchaudio
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
from feature_extractor import cnhubert
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from torch import nn
from GPT_SoVITS.AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
from GPT_SoVITS.module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model()

View File

@ -1,113 +1,24 @@
# -*- coding: utf-8 -*-
import argparse
import os
inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
opt_dir = os.environ.get("opt_dir")
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
version = os.environ.get("version", None)
import traceback
import os.path
from text.cleaner import clean_text
import traceback
from multiprocessing import Process, Queue, set_start_method
import torch
from rich.progress import track
from transformers import AutoModelForMaskedLM, AutoTokenizer
from GPT_SoVITS.Accelerate import logger, tb
from GPT_SoVITS.text.cleaner import clean_text
from tools.my_utils import clean_path
# inp_text=sys.argv[1]
# inp_wav_dir=sys.argv[2]
# exp_name=sys.argv[3]
# i_part=sys.argv[4]
# all_parts=sys.argv[5]
# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]#i_gpu
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
# bert_pretrained_dir="/data/docker/liujing04/bert-vits2/Bert-VITS2-master20231106/bert/chinese-roberta-wwm-ext-large"
torch.set_grad_enabled(False)
from time import time as ttime
import shutil
set_start_method("spawn", force=True)
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
tmp_path = "%s%s.pth" % (ttime(), i_part)
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
if os.path.exists(txt_path) == False:
bert_dir = "%s/3-bert" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(bert_dir, exist_ok=True)
if torch.cuda.is_available():
device = "cuda:0"
# elif torch.backends.mps.is_available():
# device = "mps"
else:
device = "cpu"
if os.path.exists(bert_pretrained_dir):
...
else:
raise FileNotFoundError(bert_pretrained_dir)
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
if is_half == True:
bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
def get_bert_feature(text, word2ph):
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt")
for i in inputs:
inputs[i] = inputs[i].to(device)
res = bert_model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
repeat_feature = res[i].repeat(word2ph[i], 1)
phone_level_feature.append(repeat_feature)
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
def process(data, res):
for name, text, lan in data:
try:
name = clean_path(name)
name = os.path.basename(name)
print(name)
phones, word2ph, norm_text = clean_text(text.replace("%", "-").replace("", ","), lan, version)
path_bert = "%s/%s.pt" % (bert_dir, name)
if os.path.exists(path_bert) == False and lan == "zh":
bert_feature = get_bert_feature(norm_text, word2ph)
assert bert_feature.shape[-1] == len(phones)
# torch.save(bert_feature, path_bert)
my_save(bert_feature, path_bert)
phones = " ".join(phones)
# res.append([name,phones])
res.append([name, phones, word2ph, norm_text])
except:
print(name, text, traceback.format_exc())
todo = []
res = []
with open(inp_text, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
language_v1_to_language_v2 = {
def lang_map(lang: str) -> str:
m = {
"ZH": "zh",
"zh": "zh",
"JP": "ja",
@ -124,20 +35,179 @@ if os.path.exists(txt_path) == False:
"YUE": "yue",
"Yue": "yue",
}
for line in lines[int(i_part) :: int(all_parts)]:
try:
wav_name, spk_name, language, text = line.split("|")
# todo.append([name,text,"zh"])
if language in language_v1_to_language_v2.keys():
todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)])
else:
print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m")
except:
print(line, traceback.format_exc())
return m.get(lang, "")
process(todo, res)
opt = []
for name, phones, word2ph, norm_text in res:
opt.append("%s\t%s\t%s\t%s" % (name, phones, word2ph, norm_text))
with open(txt_path, "w", encoding="utf8") as f:
f.write("\n".join(opt) + "\n")
def parse_inp_text_line(line: str) -> tuple[str, str, str]:
wav_name, _, language, text = line.split("|", 3)
return wav_name, language, text
def build_device_strings(device_type: str, device_ids: list[int], procs_per_device: int) -> list[str]:
devices = []
for device_id in device_ids:
dstr = f"{device_type}:{device_id}"
devices.extend([dstr] * procs_per_device)
return devices
def worker_run(
wid: int,
device_str: str,
tasks_q: Queue[tuple[int, str, str, str]],
results_q: Queue[tuple[int, tuple[str, str, list[int] | None, str]]],
bert_pretrained_dir: str,
opt_dir: str,
fp16: bool,
version: str,
):
device = torch.device(device_str)
if device.type == "cuda":
assert torch.cuda.is_available()
torch.cuda.set_device(device.index)
elif device.type == "mps":
assert torch.backends.mps.is_available()
bert_dir = os.path.join(opt_dir, "3-bert")
os.makedirs(bert_dir, exist_ok=True)
if not os.path.exists(bert_pretrained_dir):
raise FileNotFoundError(bert_pretrained_dir)
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
if fp16:
bert_model = bert_model.half().to(device)
else:
bert_model = bert_model.to(device)
def get_bert_feature(text: str, word2ph: list[int]) -> torch.Tensor:
inputs = tokenizer(text, return_tensors="pt")
for k in inputs:
inputs[k] = inputs[k].to(device)
out = bert_model(**inputs, output_hidden_states=True)
layer = out.hidden_states[-3][0].cpu()[1:-1] # [seq-2, hid]
assert len(word2ph) == len(text)
phone_level_feature = []
for i in range(len(word2ph)):
phone_level_feature.append(layer[i].repeat(word2ph[i], 1))
feats = torch.cat(phone_level_feature, dim=0) # [phones, hid]
return feats.T # [hid, phones]
while True:
item = tasks_q.get()
if item is None:
break
idx, wav_name, language, text = item
try:
name = clean_path(os.path.basename(wav_name))
mapped_lang = lang_map(language)
if not mapped_lang:
logger.warning(f"[W{wid}] Unsupported language: {language} of {wav_name}")
results_q.put((idx, ("", "", [], "")))
continue
phones, word2ph, norm_text = clean_text(
text.replace("%", "-").replace("", ","),
mapped_lang,
version,
)
if mapped_lang == "zh":
path_bert = os.path.join(bert_dir, f"{name}.pt")
if not os.path.exists(path_bert):
assert word2ph
bert_feature = get_bert_feature(norm_text, word2ph)
assert bert_feature.shape[-1] == len(phones)
torch.save(bert_feature, path_bert)
phones_str = " ".join(phones)
results_q.put((idx, (name, phones_str, word2ph, norm_text)))
except Exception:
logger.error(f"[W{wid}] Failed: {wav_name} | {text}\n{tb()}")
results_q.put((idx, ("", "", [], "")))
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--inp", type=str, required=True, help="list Filewav|spk|lang|text")
parser.add_argument("--opt", type=str, required=True)
parser.add_argument("--bert", type=str, required=True)
parser.add_argument("--version", type=str, default=None)
parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "cuda", "mps"])
parser.add_argument("--device-id", type=str, default="0", help="CUDA_VISIBLE_DEVICE")
parser.add_argument("--nproc", type=int, default=1)
parser.add_argument("--fp16", action="store_true")
args = parser.parse_args()
device_ids = [int(x) for x in args.devices.split(",") if x.strip() != ""]
if args.device in {"cpu", "mps"} and device_ids != [0]:
raise ValueError(f"Invalid Device ID {device_ids}")
if args.nproc < 1:
raise ValueError(f"Invalid Num Process {args.nproc}")
os.makedirs(args.opt, exist_ok=True)
merged_path = os.path.join(args.opt, "2-name2text.txt")
with open(args.inp, "r", encoding="utf8") as f:
lines = [ln for ln in f.read().splitlines() if ln.strip()]
tasks_all: list[tuple[int, str, str, str]] = []
for idx, line in enumerate(lines):
try:
wav_name, language, text = parse_inp_text_line(line)
tasks_all.append((idx, wav_name, language, text))
except Exception:
logger.error(f"Skip line {idx}: {line}\n{traceback.format_exc()}")
n_tasks = len(tasks_all)
if n_tasks == 0:
logger.warning("Empty list")
with open(merged_path, "w", encoding="utf8") as fout:
pass
return
device_strs = build_device_strings(args.device, device_ids, args.nproc)
total_workers = len(device_strs)
tasks_q: Queue[tuple[int, str, str, str] | None] = Queue(maxsize=total_workers * 2)
results_q: Queue = Queue()
for task in tasks_all:
tasks_q.put(task)
for _ in range(total_workers):
tasks_q.put(None)
procs: list[Process] = []
for wid, dstr in enumerate(device_strs):
p = Process(
target=worker_run,
args=(wid, dstr, tasks_q, results_q, args.bert, args.opt, bool(args.fp16), args.version),
daemon=False,
)
p.start()
procs.append(p)
ordered: list[tuple[str, str, list[int], str]] = [("", "", [], "")] * n_tasks
for _ in track(range(n_tasks)):
idx, tup = results_q.get() # (idx, (name, phones_str, word2ph, norm_text))
ordered[idx] = tup
for p in procs:
p.join()
with open(merged_path, "w", encoding="utf8") as fout:
for name, phones_str, word2ph, norm_text in ordered:
if name == "":
pass
else:
fout.write(f"{name}\t{phones_str}\t{word2ph}\t{norm_text}\n")
logger.info(f"Done: {merged_path}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,135 @@
# -*- coding: utf-8 -*-
import os
import sys
inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
from feature_extractor import cnhubert
opt_dir = os.environ.get("opt_dir")
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
import traceback
import librosa
import numpy as np
from scipy.io import wavfile
now_dir = os.getcwd()
sys.path.append(now_dir)
import shutil
# from config import cnhubert_base_path
# cnhubert.cnhubert_base_path=cnhubert_base_path
# inp_text=sys.argv[1]
# inp_wav_dir=sys.argv[2]
# exp_name=sys.argv[3]
# i_part=sys.argv[4]
# all_parts=sys.argv[5]
# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]
# cnhubert.cnhubert_base_path=sys.argv[7]
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
from time import time as ttime
from tools.my_utils import clean_path, load_audio
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
tmp_path = "%s%s.pth" % (ttime(), i_part)
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
hubert_dir = "%s/4-cnhubert" % (opt_dir)
wav32dir = "%s/5-wav32k" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(hubert_dir, exist_ok=True)
os.makedirs(wav32dir, exist_ok=True)
maxx = 0.95
alpha = 0.5
if torch.cuda.is_available():
device = "cuda:0"
# elif torch.backends.mps.is_available():
# device = "mps"
else:
device = "cpu"
model = cnhubert.get_model()
# is_half=False
if is_half is True:
model = model.half().to(device)
else:
model = model.to(device)
nan_fails = []
def name2go(wav_name, wav_path):
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
if os.path.exists(hubert_path):
return
tmp_audio = load_audio(wav_path, 32000)
tmp_max = np.abs(tmp_audio).max()
if tmp_max > 2.2:
print("%s-filtered,%s" % (wav_name, tmp_max))
return
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000) # 不是重采样问题
tensor_wav16 = torch.from_numpy(tmp_audio)
if is_half is True:
tensor_wav16 = tensor_wav16.half().to(device)
else:
tensor_wav16 = tensor_wav16.to(device)
ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu() # torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum() != 0:
nan_fails.append((wav_name, wav_path))
print("nan filtered:%s" % wav_name)
return
wavfile.write(
"%s/%s" % (wav32dir, wav_name),
32000,
tmp_audio32.astype("int16"),
)
my_save(ssl, hubert_path)
with open(inp_text, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines[int(i_part) :: int(all_parts)]:
try:
# wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|")
wav_name = clean_path(wav_name)
if inp_wav_dir != "" and inp_wav_dir != None:
wav_name = os.path.basename(wav_name)
wav_path = "%s/%s" % (inp_wav_dir, wav_name)
else:
wav_path = wav_name
wav_name = os.path.basename(wav_name)
name2go(wav_name, wav_path)
except:
print(line, traceback.format_exc())
if len(nan_fails) > 0 and is_half is True:
is_half = False
model = model.float()
for wav in nan_fails:
try:
name2go(wav[0], wav[1])
except:
print(wav_name, traceback.format_exc())

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import sys
import os
import sys
inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
@ -19,13 +19,14 @@ import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
import traceback
import librosa
import numpy as np
from scipy.io import wavfile
import librosa
now_dir = os.getcwd()
sys.path.append(now_dir)
from tools.my_utils import load_audio, clean_path
import shutil
# from config import cnhubert_base_path
# cnhubert.cnhubert_base_path=cnhubert_base_path
@ -37,9 +38,9 @@ from tools.my_utils import load_audio, clean_path
# os.environ["CUDA_VISIBLE_DEVICES"]=sys.argv[6]
# cnhubert.cnhubert_base_path=sys.argv[7]
# opt_dir="/data/docker/liujing04/gpt-vits/fine_tune_dataset/%s"%exp_name
from time import time as ttime
import shutil
from tools.my_utils import clean_path, load_audio
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
@ -67,7 +68,7 @@ else:
device = "cpu"
model = cnhubert.get_model()
# is_half=False
if is_half == True:
if is_half is True:
model = model.half().to(device)
else:
model = model.to(device)
@ -88,7 +89,7 @@ def name2go(wav_name, wav_path):
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000) # 不是重采样问题
tensor_wav16 = torch.from_numpy(tmp_audio)
if is_half == True:
if is_half is True:
tensor_wav16 = tensor_wav16.half().to(device)
else:
tensor_wav16 = tensor_wav16.to(device)
@ -124,7 +125,7 @@ for line in lines[int(i_part) :: int(all_parts)]:
except:
print(line, traceback.format_exc())
if len(nan_fails) > 0 and is_half == True:
if len(nan_fails) > 0 and is_half is True:
is_half = False
model = model.float()
for wav in nan_fails:

View File

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import sys
import os
import sys
inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
@ -18,16 +18,19 @@ import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
import traceback
import torchaudio
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append(f"{now_dir}/GPT_SoVITS/eres2net")
from tools.my_utils import clean_path
from time import time as ttime
import shutil
from ERes2NetV2 import ERes2NetV2
from time import time as ttime
import kaldi as Kaldi
from ERes2NetV2 import ERes2NetV2
from tools.my_utils import clean_path
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
@ -63,7 +66,7 @@ class SV:
embedding_model.eval()
self.embedding_model = embedding_model
self.res = torchaudio.transforms.Resample(32000, 16000).to(device)
if is_half == False:
if is_half is False:
self.embedding_model = self.embedding_model.to(device)
else:
self.embedding_model = self.embedding_model.half().to(device)
@ -72,7 +75,7 @@ class SV:
def compute_embedding3(self, wav): # (1,x)#-1~1
with torch.no_grad():
wav = self.res(wav)
if self.is_half == True:
if self.is_half is True:
wav = wav.half()
feat = torch.stack(
[Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]

View File

@ -29,18 +29,19 @@ else:
import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
import traceback
import sys
import traceback
now_dir = os.getcwd()
sys.path.append(now_dir)
import logging
import utils
import GPT_SoVITS.utils as utils
if version != "v3":
from module.models import SynthesizerTrn
from GPT_SoVITS.module.models import SynthesizerTrn
else:
from module.models import SynthesizerTrnV3 as SynthesizerTrn
from GPT_SoVITS.module.models import SynthesizerTrnV3 as SynthesizerTrn
from tools.my_utils import clean_path
logging.getLogger("numba").setLevel(logging.WARNING)
@ -56,7 +57,7 @@ logging.getLogger("numba").setLevel(logging.WARNING)
hubert_dir = "%s/4-cnhubert" % (opt_dir)
semantic_path = "%s/6-name2semantic-%s.tsv" % (opt_dir, i_part)
if os.path.exists(semantic_path) == False:
if os.path.exists(semantic_path) is False:
os.makedirs(opt_dir, exist_ok=True)
if torch.cuda.is_available():
@ -73,7 +74,7 @@ if os.path.exists(semantic_path) == False:
version=version,
**hps.model,
)
if is_half == True:
if is_half is True:
vq_model = vq_model.half().to(device)
else:
vq_model = vq_model.to(device)
@ -88,10 +89,10 @@ if os.path.exists(semantic_path) == False:
def name2go(wav_name, lines):
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
if os.path.exists(hubert_path) == False:
if os.path.exists(hubert_path) is False:
return
ssl_content = torch.load(hubert_path, map_location="cpu")
if is_half == True:
if is_half is True:
ssl_content = ssl_content.half().to(device)
else:
ssl_content = ssl_content.to(device)

View File

@ -1,15 +1,17 @@
import os
import shutil
import traceback
from collections import OrderedDict
from time import time as ttime
import shutil
import os
import torch
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
def my_save(fea, path): # fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
tmp_path = "%s.pth" % (ttime())
@ -17,27 +19,6 @@ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
shutil.move(tmp_path, "%s/%s" % (dir, name))
from io import BytesIO
model_version2byte = {
"v3": b"03",
"v4": b"04",
"v2Pro": b"05",
"v2ProPlus": b"06",
}
def my_save2(fea, path, model_version):
bio = BytesIO()
torch.save(fea, bio)
bio.seek(0)
data = bio.getvalue()
byte = model_version2byte[model_version]
data = byte + data[2:]
with open(path, "wb") as f:
f.write(data)
def savee(ckpt, name, epoch, steps, hps, model_version=None, lora_rank=None):
try:
opt = OrderedDict()
@ -50,89 +31,33 @@ def savee(ckpt, name, epoch, steps, hps, model_version=None, lora_rank=None):
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
if lora_rank:
opt["lora_rank"] = lora_rank
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
elif model_version != None and "Pro" in model_version:
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
else:
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
my_save(opt, f"{hps.save_weight_dir}/{name}.pth")
return "Success."
except:
except Exception:
return traceback.format_exc()
"""
00:v1
01:v2
02:v3
03:v3lora
04:v4lora
05:v2Pro
06:v2ProPlus
"""
head2version = {
b"00": ["v1", "v1", False],
b"01": ["v2", "v2", False],
b"02": ["v2", "v3", False],
b"03": ["v2", "v3", True],
b"04": ["v2", "v4", True],
b"05": ["v2", "v2Pro", False],
b"06": ["v2", "v2ProPlus", False],
}
hash_pretrained_dict = {
"dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained
"43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained
"6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained
"4f26b9476d0c5033e04162c486074374": ["v2", "v4", False], # s2Gv4.pth#sovits_v4_pretrained
"c7e9fce2223f3db685cdfa1e6368728a": ["v2", "v2Pro", False], # s2Gv2Pro.pth#sovits_v2Pro_pretrained
"66b313e39455b57ab1b0bc0b239c9d0a": ["v2", "v2ProPlus", False], # s2Gv2ProPlus.pth#sovits_v2ProPlus_pretrained
}
import hashlib
def inspect_version(f: str):
dict_s2 = torch.load(f, map_location="cpu", mmap=True)
hps = dict_s2["config"]
version = None
if "version" in hps:
version = hps.version
is_lora = "lora_rank" in dict_s2.keys()
def get_hash_from_file(sovits_path):
with open(sovits_path, "rb") as f:
data = f.read(8192)
hash_md5 = hashlib.md5()
hash_md5.update(data)
return hash_md5.hexdigest()
def get_sovits_version_from_path_fast(sovits_path):
###1-if it is pretrained sovits models, by hash
hash = get_hash_from_file(sovits_path)
if hash in hash_pretrained_dict:
return hash_pretrained_dict[hash]
###2-new weights, by head
with open(sovits_path, "rb") as f:
version = f.read(2)
if version != b"PK":
return head2version[version]
###3-old weights, by file size
if_lora_v3 = False
size = os.path.getsize(sovits_path)
"""
v1weights:about 82942KB
half thr:82978KB
v2weights:about 83014KB
v3weights:about 750MB
"""
if size < 82978 * 1024:
model_version = version = "v1"
elif size < 700 * 1024 * 1024:
model_version = version = "v2"
if version is not None:
lang_version = "v2"
model_version = version
else:
version = "v2"
model_version = "v3"
return version, model_version, if_lora_v3
if "dec.conv_pre.weight" in dict_s2["weight"].keys():
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
lang_version = model_version = "v1"
else:
lang_version = model_version = "v2"
else:
lang_version = "v2"
model_version = "v3"
if dict_s2["info"] == "pretrained_s2G_v4":
model_version = "v4"
def load_sovits_new(sovits_path):
f = open(sovits_path, "rb")
meta = f.read(2)
if meta != b"PK":
data = b"PK" + f.read()
bio = BytesIO()
bio.write(data)
bio.seek(0)
return torch.load(bio, map_location="cpu", weights_only=False)
return torch.load(sovits_path, map_location="cpu", weights_only=False)
return model_version, lang_version, is_lora, hps, dict_s2

View File

@ -1,32 +1,36 @@
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
import os
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
import argparse
import logging
import os
import platform
from collections import OrderedDict
from pathlib import Path
from typing import Any
import torch
from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger # WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.strategies import DDPStrategy, SingleDeviceStrategy
from pytorch_lightning.strategies.strategy import Strategy
from GPT_SoVITS.AR.data.data_module import Text2SemanticDataModule
from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
from GPT_SoVITS.AR.utils import get_newest_ckpt
from GPT_SoVITS.AR.utils.io import load_yaml_config
from GPT_SoVITS.process_ckpt import my_save
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
torch.set_float32_matmul_precision("high")
from collections import OrderedDict
from AR.utils import get_newest_ckpt
from process_ckpt import my_save
class my_model_ckpt(ModelCheckpoint):
os.environ["MASTER_ADDR"] = "localhost"
if platform.system() == "Windows":
os.environ["USE_LIBUV"] = "0"
class ARModelCheckpoint(ModelCheckpoint):
def __init__(
self,
config,
@ -44,40 +48,31 @@ class my_model_ckpt(ModelCheckpoint):
self.config = config
def on_train_epoch_end(self, trainer, pl_module):
# if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
if self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer)
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
if (
self.if_save_latest == True
): ####如果设置只保存最后一个ckpt在保存下一个ckpt后要清理掉之前的所有ckpt
to_clean = list(os.listdir(self.dirpath))
self._save_topk_checkpoint(trainer, monitor_candidates)
if self.if_save_latest == True:
if self.if_save_latest is True: # 如果设置只保存最后一个ckpt在保存下一个ckpt后要清理掉之前的所有ckpt
to_clean = list(os.listdir(self.dirpath))
for name in to_clean:
try:
os.remove("%s/%s" % (self.dirpath, name))
except:
os.remove(f"{self.dirpath}/{name}")
except Exception as _:
pass
if self.if_save_every_weights == True:
to_save_od = OrderedDict()
if self.if_save_every_weights is True:
to_save_od: OrderedDict[str, Any] = OrderedDict()
to_save_od["weight"] = OrderedDict()
dictt = trainer.strategy._lightning_module.state_dict()
for key in dictt:
to_save_od["weight"][key] = dictt[key].half()
to_save_od["config"] = self.config
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
to_save_od["info"] = f"GPT-e{trainer.current_epoch + 1}"
# torch.save(
# print(os.environ)
if os.environ.get("LOCAL_RANK", "0") == "0":
my_save(
to_save_od,
"%s/%s-e%s.ckpt"
% (
self.half_weights_save_dir,
self.exp_name,
trainer.current_epoch + 1,
),
f"{self.half_weights_save_dir}/{self.exp_name}-e{trainer.current_epoch + 1}.ckpt",
)
self._save_last_checkpoint(trainer, monitor_candidates)
@ -91,8 +86,17 @@ def main(args):
ckpt_dir = output_dir / "ckpt"
ckpt_dir.mkdir(parents=True, exist_ok=True)
if torch.cuda.is_available():
if torch.cuda.device_count() > 1:
strategy: Strategy = DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
else:
strategy = SingleDeviceStrategy("cuda")
else:
strategy = SingleDeviceStrategy("cpu")
seed_everything(config["train"]["seed"], workers=True)
ckpt_callback: ModelCheckpoint = my_model_ckpt(
ckpt_callback: ModelCheckpoint = ARModelCheckpoint(
config=config,
if_save_latest=config["train"]["if_save_latest"],
if_save_every_weights=config["train"]["if_save_every_weights"],
@ -106,20 +110,15 @@ def main(args):
dirpath=ckpt_dir,
)
logger = TensorBoardLogger(name=output_dir.stem, save_dir=output_dir)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["USE_LIBUV"] = "0"
trainer: Trainer = Trainer(
max_epochs=config["train"]["epochs"],
accelerator="gpu" if torch.cuda.is_available() else "cpu",
# val_check_interval=9999999999999999999999,###不要验证
# check_val_every_n_epoch=None,
limit_val_batches=0,
devices=-1 if torch.cuda.is_available() else 1,
benchmark=False,
fast_dev_run=False,
strategy=DDPStrategy(process_group_backend="nccl" if platform.system() != "Windows" else "gloo")
if torch.cuda.is_available()
else "auto",
strategy=strategy,
precision=config["train"]["precision"],
logger=logger,
num_sanity_val_steps=0,
@ -133,8 +132,6 @@ def main(args):
config,
train_semantic_path=config["train_semantic_path"],
train_phoneme_path=config["train_phoneme_path"],
# dev_semantic_path=args.dev_semantic_path,
# dev_phoneme_path=args.dev_phoneme_path
)
try:

View File

@ -1,53 +1,56 @@
import warnings
warnings.filterwarnings("ignore")
import os
import utils
hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging
import os
import platform
import warnings
from random import randint
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
import torch.multiprocessing.spawn as mp
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
from random import randint
from module import commons
from module.data_utils import (
import GPT_SoVITS.utils as utils
from GPT_SoVITS.module import commons
from GPT_SoVITS.module.data_utils import (
DistributedBucketSampler,
TextAudioSpeakerCollate,
TextAudioSpeakerLoader,
)
from module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from module.models import (
from GPT_SoVITS.module.losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from GPT_SoVITS.module.models import (
MultiPeriodDiscriminator,
SynthesizerTrn,
)
from process_ckpt import savee
from GPT_SoVITS.process_ckpt import savee
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
hps = utils.get_hparams(stage=2)
warnings.filterwarnings("ignore")
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
###反正A100fp32更快那试试tf32吧
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True ###反正A100fp32更快那试试tf32吧
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响
# from config import pretrained_s2G,pretrained_s2D
global_step = 0
device = "cpu" # cuda以外的设备等mps优化后加入
if torch.cuda.is_available():
device_str = "cuda"
else:
device_str = "cpu"
multigpu = torch.cuda.device_count() > 1 if torch.cuda.is_available() else False
def main():
@ -70,19 +73,23 @@ def main():
def run(rank, n_gpus, hps):
global global_step
device = torch.device("{device_str}:{rank}")
if rank == 0:
logger = utils.get_logger(hps.data.exp_dir)
logger.info(hps)
# utils.check_git_hash(hps.s2_ckpt_dir)
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
else:
logger = writer = writer_eval = None
if multigpu:
dist.init_process_group(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False" if platform.system() == "Windows" else "env://",
world_size=n_gpus,
rank=rank,
)
dist.init_process_group(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False",
world_size=n_gpus,
rank=rank,
)
torch.manual_seed(hps.train.seed)
if torch.cuda.is_available():
torch.cuda.set_device(rank)
@ -126,33 +133,19 @@ def run(rank, n_gpus, hps):
persistent_workers=True,
prefetch_factor=4,
)
# if rank == 0:
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
# eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
# batch_size=1, pin_memory=True,
# drop_last=False, collate_fn=collate_fn)
net_g = (
SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).cuda(rank)
if torch.cuda.is_available()
else SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device)
)
net_g = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device)
net_d = MultiPeriodDiscriminator(
hps.model.use_spectral_norm,
version=hps.model.version,
).to(device)
net_d = (
MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).cuda(rank)
if torch.cuda.is_available()
else MultiPeriodDiscriminator(hps.model.use_spectral_norm, version=hps.model.version).to(device)
)
for name, param in net_g.named_parameters():
if not param.requires_grad:
print(name, "not requires_grad")
@ -165,10 +158,6 @@ def run(rank, n_gpus, hps):
net_g.parameters(),
)
# te_p=net_g.enc_p.text_embedding.parameters()
# et_p=net_g.enc_p.encoder_text.parameters()
# mrte_p=net_g.enc_p.mrte.parameters()
optim_g = torch.optim.AdamW(
# filter(lambda p: p.requires_grad, net_g.parameters()),###默认所有层lr一致
[
@ -196,7 +185,7 @@ def run(rank, n_gpus, hps):
betas=hps.train.betas,
eps=hps.train.eps,
)
if torch.cuda.is_available():
if multigpu:
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
else:
@ -204,49 +193,49 @@ def run(rank, n_gpus, hps):
net_d = net_d.to(device)
try: # 如果能加载自动resume
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "D_*.pth"),
epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path(f"{(hps.data.exp_dir,)}/logs_s2_{hps.model.version}", "D_*.pth"),
net_d,
optim_d,
) # D多半加载没事
)[-1] # D多半加载没事
if rank == 0:
logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path(f"{hps.data.exp_dir}/logs_s2_{hps.model.version}", "G_*.pth"),
net_g,
optim_g,
)
)[-1]
epoch_str += 1
global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1
# global_step = 0
except: # 如果首次不能加载加载pretrain
except Exception: # 如果首次不能加载加载pretrain
# traceback.print_exc()
epoch_str = 1
global_step = 0
if (
hps.train.pretrained_s2G != ""
and hps.train.pretrained_s2G != None
and hps.train.pretrained_s2G is not None
and os.path.exists(hps.train.pretrained_s2G)
):
if rank == 0:
logger.info("loaded pretrained %s" % hps.train.pretrained_s2G)
logger.info(f"loaded pretrained {hps.train.pretrained_s2G}")
print(
"loaded pretrained %s" % hps.train.pretrained_s2G,
net_g.module.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
)
if torch.cuda.is_available()
if multigpu
else net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
strict=False,
),
) ##测试不加载优化器
if (
hps.train.pretrained_s2D != ""
and hps.train.pretrained_s2D != None
and hps.train.pretrained_s2D is not None
and os.path.exists(hps.train.pretrained_s2D)
):
if rank == 0:
@ -254,11 +243,11 @@ def run(rank, n_gpus, hps):
print(
"loaded pretrained %s" % hps.train.pretrained_s2D,
net_d.module.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"], strict=False
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"], strict=False
)
if torch.cuda.is_available()
if multigpu
else net_d.load_state_dict(
torch.load(hps.train.pretrained_s2D, map_location="cpu", weights_only=False)["weight"],
torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"],
),
)
@ -279,13 +268,13 @@ def run(rank, n_gpus, hps):
scheduler_g.step()
scheduler_d.step()
scaler = GradScaler(enabled=hps.train.fp16_run)
scaler = GradScaler(device=device.type, enabled=hps.train.fp16_run)
print("start training from epoch %s" % epoch_str)
for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:
train_and_evaluate(
rank,
device,
epoch,
hps,
[net_g, net_d],
@ -299,7 +288,7 @@ def run(rank, n_gpus, hps):
)
else:
train_and_evaluate(
rank,
device,
epoch,
hps,
[net_g, net_d],
@ -315,7 +304,7 @@ def run(rank, n_gpus, hps):
print("training done")
def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
def train_and_evaluate(device: torch.device, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers):
net_g, net_d = nets
optim_g, optim_d = optims
# scheduler_g, scheduler_d = schedulers
@ -331,54 +320,19 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
for batch_idx, data in enumerate(tqdm(train_loader)):
if hps.model.version in {"v2Pro", "v2ProPlus"}:
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = data
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb = map(
lambda x: x.to(device, non_blocking=True),
(ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths, sv_emb),
)
else:
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = data
if torch.cuda.is_available():
spec, spec_lengths = (
spec.cuda(
rank,
non_blocking=True,
),
spec_lengths.cuda(
rank,
non_blocking=True,
),
ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths = map(
lambda x: x.to(device, non_blocking=True),
(ssl, ssl_lengths, spec, spec_lengths, y, y_lengths, text, text_lengths),
)
y, y_lengths = (
y.cuda(
rank,
non_blocking=True,
),
y_lengths.cuda(
rank,
non_blocking=True,
),
)
ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = (
text.cuda(
rank,
non_blocking=True,
),
text_lengths.cuda(
rank,
non_blocking=True,
),
)
if hps.model.version in {"v2Pro", "v2ProPlus"}:
sv_emb = sv_emb.cuda(rank, non_blocking=True)
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
y, y_lengths = y.to(device), y_lengths.to(device)
ssl = ssl.to(device)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.to(device), text_lengths.to(device)
if hps.model.version in {"v2Pro", "v2ProPlus"}:
sv_emb = sv_emb.to(device)
with autocast(enabled=hps.train.fp16_run):
ssl.requires_grad = False
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
if hps.model.version in {"v2Pro", "v2ProPlus"}:
(y_hat, kl_ssl, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q), stats_ssl) = net_g(
ssl, spec, spec_lengths, text, text_lengths, sv_emb
@ -418,22 +372,23 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
# Discriminator
y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
with autocast(enabled=False):
with autocast(device_type=device.type, enabled=False):
loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
y_d_hat_r,
y_d_hat_g,
)
loss_disc_all = loss_disc
optim_d.zero_grad()
scaler.scale(loss_disc_all).backward()
scaler.unscale_(optim_d)
grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
scaler.step(optim_d)
with autocast(enabled=hps.train.fp16_run):
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
# Generator
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
with autocast(enabled=False):
with autocast(device_type=device.type, enabled=False):
loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
@ -448,7 +403,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
scaler.step(optim_g)
scaler.update()
if rank == 0:
if device.index == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]["lr"]
losses = [loss_disc, loss_gen, loss_fm, loss_mel, kl_ssl, loss_kl]
@ -480,7 +435,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
# scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)})
# scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)})
image_dict = None
try: ###Some people installed the wrong version of matplotlib.
try: # Some people installed the wrong version of matplotlib.
image_dict = {
"slice/mel_org": utils.plot_spectrogram_to_numpy(
y_mel[0].data.cpu().numpy(),
@ -495,7 +450,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
stats_ssl[0].data.cpu().numpy(),
),
}
except:
except Exception as _:
pass
if image_dict:
utils.summarize(
@ -511,7 +466,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
scalars=scalar_dict,
)
global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if epoch % hps.train.save_every_epoch == 0 and device.index == 0:
if hps.train.if_save_latest == 0:
utils.save_checkpoint(
net_g,
@ -519,8 +474,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(global_step),
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
f"G_{global_step}.pth",
),
)
utils.save_checkpoint(
@ -529,8 +484,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"D_{}.pth".format(global_step),
"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
"D_{global_step}.pth",
),
)
else:
@ -540,8 +495,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(233333333333),
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
"G_233333333333.pth",
),
)
utils.save_checkpoint(
@ -550,11 +505,11 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"D_{}.pth".format(233333333333),
"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
"D_233333333333.pth",
),
)
if rank == 0 and hps.train.if_save_every_weights == True:
if device.index == 0 and hps.train.if_save_every_weights is True:
if hasattr(net_g, "module"):
ckpt = net_g.module.state_dict()
else:
@ -566,7 +521,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
epoch,
savee(
ckpt,
hps.name + "_e%s_s%s" % (epoch, global_step),
hps.name + f"_e{epoch}_s{global_step}",
epoch,
global_step,
hps,
@ -575,11 +530,11 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
)
)
if rank == 0:
logger.info("====> Epoch: {}".format(epoch))
if device.index == 0:
logger.info(f"====> Epoch: {epoch}")
def evaluate(hps, generator, eval_loader, writer_eval):
def evaluate(hps, generator, eval_loader, writer_eval, device):
generator.eval()
image_dict = {}
audio_dict = {}
@ -595,17 +550,10 @@ def evaluate(hps, generator, eval_loader, writer_eval):
text,
text_lengths,
) in enumerate(eval_loader):
print(111)
if torch.cuda.is_available():
spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
y, y_lengths = y.cuda(), y_lengths.cuda()
ssl = ssl.cuda()
text, text_lengths = text.cuda(), text_lengths.cuda()
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
y, y_lengths = y.to(device), y_lengths.to(device)
ssl = ssl.to(device)
text, text_lengths = text.to(device), text_lengths.to(device)
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
y, y_lengths = y.to(device), y_lengths.to(device)
ssl = ssl.to(device)
text, text_lengths = text.to(device), text_lengths.to(device)
for test in [0, 1]:
y_hat, mask, *_ = (
generator.module.infer(
@ -665,11 +613,6 @@ def evaluate(hps, generator, eval_loader, writer_eval):
)
audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
# y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None)
# audio_dict.update({
# f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :]
# })
utils.summarize(
writer=writer_eval,
global_step=global_step,

View File

@ -1,53 +1,51 @@
import warnings
warnings.filterwarnings("ignore")
import os
import utils
hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging
import os
import platform
import warnings
from random import randint
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast
import torch.multiprocessing.spawn as mp
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import GPT_SoVITS.utils as utils
from GPT_SoVITS.module import commons
from GPT_SoVITS.module.data_utils import (
DistributedBucketSampler,
TextAudioSpeakerCollateV3,
TextAudioSpeakerLoaderV3,
)
from GPT_SoVITS.module.models import SynthesizerTrnV3
from GPT_SoVITS.process_ckpt import savee
hps = utils.get_hparams(stage=2)
warnings.filterwarnings("ignore")
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
from random import randint
from module import commons
from module.data_utils import (
DistributedBucketSampler,
)
from module.data_utils import (
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
)
from module.data_utils import (
TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
)
from module.models import (
SynthesizerTrnV3 as SynthesizerTrn,
)
from process_ckpt import savee
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
###反正A100fp32更快那试试tf32吧
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True ###反正A100fp32更快那试试tf32吧
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就快一丁点),对于结果造成不了影响
# from config import pretrained_s2G,pretrained_s2D
global_step = 0
device = "cpu" # cuda以外的设备等mps优化后加入
if torch.cuda.is_available():
device_str = "cuda"
else:
device_str = "cpu"
multigpu = torch.cuda.device_count() > 1 if torch.cuda.is_available() else False
def main():
@ -70,24 +68,29 @@ def main():
def run(rank, n_gpus, hps):
global global_step
device = torch.device("{device_str}:{rank}")
if rank == 0:
logger = utils.get_logger(hps.data.exp_dir)
logger.info(hps)
# utils.check_git_hash(hps.s2_ckpt_dir)
writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
else:
logger = writer = writer_eval = None
if multigpu:
dist.init_process_group(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False" if platform.system() == "Windows" else "env://",
world_size=n_gpus,
rank=rank,
)
dist.init_process_group(
backend="gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
init_method="env://?use_libuv=False",
world_size=n_gpus,
rank=rank,
)
torch.manual_seed(hps.train.seed)
if torch.cuda.is_available():
torch.cuda.set_device(rank)
train_dataset = TextAudioSpeakerLoader(hps.data) ########
train_dataset = TextAudioSpeakerLoaderV3(hps.data)
train_sampler = DistributedBucketSampler(
train_dataset,
hps.train.batch_size,
@ -101,21 +104,12 @@ def run(rank, n_gpus, hps):
800,
900,
1000,
# 1100,
# 1200,
# 1300,
# 1400,
# 1500,
# 1600,
# 1700,
# 1800,
# 1900,
],
num_replicas=n_gpus,
rank=rank,
shuffle=True,
)
collate_fn = TextAudioSpeakerCollate()
collate_fn = TextAudioSpeakerCollateV3()
train_loader = DataLoader(
train_dataset,
num_workers=6,
@ -127,31 +121,17 @@ def run(rank, n_gpus, hps):
prefetch_factor=4,
)
# if rank == 0:
# eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True)
# eval_dataset = TextAudioSpeakerLoaderV3(hps.data.validation_files, hps.data, val=True)
# eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False,
# batch_size=1, pin_memory=True,
# drop_last=False, collate_fn=collate_fn)
net_g = (
SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).cuda(rank)
if torch.cuda.is_available()
else SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device)
)
# net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to(device)
# for name, param in net_g.named_parameters():
# if not param.requires_grad:
# print(name, "not requires_grad")
net_g = SynthesizerTrnV3(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
**hps.model,
).to(device)
optim_g = torch.optim.AdamW(
filter(lambda p: p.requires_grad, net_g.parameters()), ###默认所有层lr一致
@ -159,44 +139,26 @@ def run(rank, n_gpus, hps):
betas=hps.train.betas,
eps=hps.train.eps,
)
# optim_d = torch.optim.AdamW(
# net_d.parameters(),
# hps.train.learning_rate,
# betas=hps.train.betas,
# eps=hps.train.eps,
# )
if torch.cuda.is_available():
if multigpu:
net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
# net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
else:
net_g = net_g.to(device)
# net_d = net_d.to(device)
try: # 如果能加载自动resume
# _, _, _, epoch_str = utils.load_checkpoint(
# utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_*.pth"),
# net_d,
# optim_d,
# ) # D多半加载没事
# if rank == 0:
# logger.info("loaded D")
# _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
_, _, _, epoch_str = utils.load_checkpoint(
utils.latest_checkpoint_path("%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"),
utils.latest_checkpoint_path(f"{hps.data.exp_dir}/logs_s2_{hps.model.version}", "G_*.pth"),
net_g,
optim_g,
)
epoch_str += 1
global_step = (epoch_str - 1) * len(train_loader)
# epoch_str = 1
# global_step = 0
except: # 如果首次不能加载加载pretrain
# traceback.print_exc()
except Exception: # 如果首次不能加载加载pretrain
epoch_str = 1
global_step = 0
if (
hps.train.pretrained_s2G != ""
and hps.train.pretrained_s2G != None
and hps.train.pretrained_s2G is not None
and os.path.exists(hps.train.pretrained_s2G)
):
if rank == 0:
@ -207,42 +169,26 @@ def run(rank, n_gpus, hps):
torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
strict=False,
)
if torch.cuda.is_available()
if multigpu
else net_g.load_state_dict(
torch.load(hps.train.pretrained_s2G, map_location="cpu", weights_only=False)["weight"],
strict=False,
),
) ##测试不加载优化器
# if hps.train.pretrained_s2D != ""and hps.train.pretrained_s2D != None and os.path.exists(hps.train.pretrained_s2D):
# if rank == 0:
# logger.info("loaded pretrained %s" % hps.train.pretrained_s2D)
# print(
# net_d.module.load_state_dict(
# torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
# ) if torch.cuda.is_available() else net_d.load_state_dict(
# torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
# )
# )
# scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=-1)
# scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
# optim_d, gamma=hps.train.lr_decay, last_epoch=-1
# )
for _ in range(epoch_str):
scheduler_g.step()
# scheduler_d.step()
scaler = GradScaler(enabled=hps.train.fp16_run)
scaler = GradScaler(device=device.type, enabled=hps.train.fp16_run)
net_d = optim_d = scheduler_d = None
print("start training from epoch %s" % epoch_str)
print(f"start training from epoch {epoch_str}")
for epoch in range(epoch_str, hps.train.epochs + 1):
if rank == 0:
train_and_evaluate(
rank,
device,
epoch,
hps,
[net_g, net_d],
@ -256,7 +202,7 @@ def run(rank, n_gpus, hps):
)
else:
train_and_evaluate(
rank,
device,
epoch,
hps,
[net_g, net_d],
@ -273,7 +219,7 @@ def run(rank, n_gpus, hps):
def train_and_evaluate(
rank,
device: torch.device,
epoch,
hps,
nets,
@ -309,40 +255,14 @@ def train_and_evaluate(
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(
tqdm(train_loader)
):
if torch.cuda.is_available():
spec, spec_lengths = (
spec.cuda(
rank,
non_blocking=True,
),
spec_lengths.cuda(
rank,
non_blocking=True,
),
)
mel, mel_lengths = mel.cuda(rank, non_blocking=True), mel_lengths.cuda(rank, non_blocking=True)
ssl = ssl.cuda(rank, non_blocking=True)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = (
text.cuda(
rank,
non_blocking=True,
),
text_lengths.cuda(
rank,
non_blocking=True,
),
)
else:
spec, spec_lengths = spec.to(device), spec_lengths.to(device)
mel, mel_lengths = mel.to(device), mel_lengths.to(device)
ssl = ssl.to(device)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.to(device), text_lengths.to(device)
spec, spec_lengths = spec.to(device, non_blocking=True), spec_lengths.to(device, non_blocking=True)
mel, mel_lengths = mel.to(device, non_blocking=True), mel_lengths.to(device, non_blocking=True)
ssl = ssl.to(device, non_blocking=True)
ssl.requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text, text_lengths = text.to(device, non_blocking=True), text_lengths.to(device, non_blocking=True)
with autocast(enabled=hps.train.fp16_run):
with autocast(device_type=device.type, dtype=torch.float16, enabled=hps.train.fp16_run):
cfm_loss = net_g(
ssl,
spec,
@ -362,7 +282,7 @@ def train_and_evaluate(
scaler.step(optim_g)
scaler.update()
if rank == 0:
if device.index == 0:
if global_step % hps.train.log_interval == 0:
lr = optim_g.param_groups[0]["lr"]
# losses = [commit_loss,cfm_loss,mel_loss,loss_disc, loss_gen, loss_fm, loss_mel, loss_kl]
@ -376,12 +296,6 @@ def train_and_evaluate(
logger.info([x.item() for x in losses] + [global_step, lr])
scalar_dict = {"loss/g/total": loss_gen_all, "learning_rate": lr, "grad_norm_g": grad_norm_g}
# image_dict = {
# "slice/mel_org": utils.plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
# "slice/mel_gen": utils.plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
# "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
# "all/stats_ssl": utils.plot_spectrogram_to_numpy(stats_ssl[0].data.cpu().numpy()),
# }
utils.summarize(
writer=writer,
global_step=global_step,
@ -389,16 +303,8 @@ def train_and_evaluate(
scalars=scalar_dict,
)
# if global_step % hps.train.eval_interval == 0:
# # evaluate(hps, net_g, eval_loader, writer_eval)
# utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch,os.path.join(hps.s2_ckpt_dir, "G_{}.pth".format(global_step)),scaler)
# # utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch,os.path.join(hps.s2_ckpt_dir, "D_{}.pth".format(global_step)),scaler)
# # keep_ckpts = getattr(hps.train, 'keep_ckpts', 3)
# # if keep_ckpts > 0:
# # utils.clean_checkpoints(path_to_models=hps.s2_ckpt_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True)
global_step += 1
if epoch % hps.train.save_every_epoch == 0 and rank == 0:
if epoch % hps.train.save_every_epoch == 0 and device.index == 0:
if hps.train.if_save_latest == 0:
utils.save_checkpoint(
net_g,
@ -406,19 +312,11 @@ def train_and_evaluate(
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(global_step),
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
f"G_{global_step}.pth",
),
)
# utils.save_checkpoint(
# net_d,
# optim_d,
# hps.train.learning_rate,
# epoch,
# os.path.join(
# "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(global_step)
# ),
# )
else:
utils.save_checkpoint(
net_g,
@ -426,41 +324,27 @@ def train_and_evaluate(
hps.train.learning_rate,
epoch,
os.path.join(
"%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
"G_{}.pth".format(233333333333),
f"{hps.data.exp_dir}/logs_s2_{hps.model.version}",
"G_233333333333.pth",
),
)
# utils.save_checkpoint(
# net_d,
# optim_d,
# hps.train.learning_rate,
# epoch,
# os.path.join(
# "%s/logs_s2_%s" % (hps.data.exp_dir,hps.model.version), "D_{}.pth".format(233333333333)
# ),
# )
if rank == 0 and hps.train.if_save_every_weights == True:
if device.index == 0 and hps.train.if_save_every_weights is True:
if hasattr(net_g, "module"):
ckpt = net_g.module.state_dict()
else:
ckpt = net_g.state_dict()
logger.info(
"saving ckpt %s_e%s:%s"
% (
hps.name,
epoch,
savee(
ckpt,
hps.name + "_e%s_s%s" % (epoch, global_step),
epoch,
global_step,
hps,
),
)
save_info = savee(
ckpt,
hps.name + f"_e{epoch}_s{global_step}",
epoch,
global_step,
hps,
)
logger.info(f"saving ckpt {hps.name}_e{epoch}:{save_info}")
if rank == 0:
logger.info("====> Epoch: {}".format(epoch))
if device.index == 0:
logger.info(f"====> Epoch: {epoch}")
if __name__ == "__main__":

View File

@ -3,10 +3,9 @@ import warnings
warnings.filterwarnings("ignore")
import os
import utils
import GPT_SoVITS.utils as utils
hps = utils.get_hparams(stage=2)
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import logging
import torch
@ -24,19 +23,20 @@ logging.getLogger("numba").setLevel(logging.INFO)
from collections import OrderedDict as od
from random import randint
from module import commons
from module.data_utils import (
from peft import LoraConfig, get_peft_model
from GPT_SoVITS.module import commons
from GPT_SoVITS.module.data_utils import (
DistributedBucketSampler,
TextAudioSpeakerCollateV3,
TextAudioSpeakerLoaderV3,
TextAudioSpeakerCollateV4,
TextAudioSpeakerLoaderV3,
TextAudioSpeakerLoaderV4,
)
from module.models import (
from GPT_SoVITS.module.models import (
SynthesizerTrnV3 as SynthesizerTrn,
)
from peft import LoraConfig, get_peft_model
from process_ckpt import savee
from GPT_SoVITS.process_ckpt import savee
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
@ -343,7 +343,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
epoch,
os.path.join(save_root, "G_{}.pth".format(233333333333)),
)
if rank == 0 and hps.train.if_save_every_weights == True:
if rank == 0 and hps.train.if_save_every_weights is True:
if hasattr(net_g, "module"):
ckpt = net_g.module.state_dict()
else:

View File

@ -1,21 +1,22 @@
import sys
import os
import torch
import sys
import torch
import torchaudio
from GPT_SoVITS.eres2net.ERes2NetV2 import ERes2NetV2
sys.path.append(f"{os.getcwd()}/GPT_SoVITS/eres2net")
sv_path = "GPT_SoVITS/pretrained_models/sv/pretrained_eres2netv2w24s4ep4.ckpt"
from ERes2NetV2 import ERes2NetV2
import kaldi as Kaldi
class SV:
def __init__(self, device, is_half):
pretrained_state = torch.load(sv_path, map_location="cpu", weights_only=False)
pretrained_state = torch.load(sv_path, map_location="cpu")
embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
embedding_model.load_state_dict(pretrained_state)
embedding_model.eval()
self.embedding_model = embedding_model
if is_half == False:
if is_half is False:
self.embedding_model = self.embedding_model.to(device)
else:
self.embedding_model = self.embedding_model.half().to(device)
@ -23,10 +24,15 @@ class SV:
def compute_embedding3(self, wav):
with torch.no_grad():
if self.is_half == True:
if self.is_half is True:
wav = wav.half()
feat = torch.stack(
[Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]
[
torchaudio.compliance.kaldi.fbank(
wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0
)
for wav0 in wav
]
)
sv_emb = self.embedding_model.forward3(feat)
return sv_emb

View File

@ -1,40 +1,41 @@
import logging
import re
from pathlib import Path
# jieba静音
import fast_langdetect
import jieba
from split_lang import LangSplitter
jieba.setLogLevel(logging.CRITICAL)
# 更改fast_langdetect大模型位置
from pathlib import Path
import fast_langdetect
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(fast_langdetect.infer.LangDetectConfig(cache_dir=Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect"))
from split_lang import LangSplitter
fast_langdetect.infer._default_detector = fast_langdetect.infer.LangDetector(
fast_langdetect.infer.LangDetectConfig(
cache_dir=str(Path(__file__).parent.parent.parent / "pretrained_models" / "fast_langdetect")
)
)
def full_en(text):
pattern = r'^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$'
pattern = r"^(?=.*[A-Za-z])[A-Za-z0-9\s\u0020-\u007E\u2000-\u206F\u3000-\u303F\uFF00-\uFFEF]+$"
return bool(re.match(pattern, text))
def full_cjk(text):
# 来自wiki
cjk_ranges = [
(0x4E00, 0x9FFF), # CJK Unified Ideographs
(0x3400, 0x4DB5), # CJK Extension A
(0x20000, 0x2A6DD), # CJK Extension B
(0x2A700, 0x2B73F), # CJK Extension C
(0x2B740, 0x2B81F), # CJK Extension D
(0x2B820, 0x2CEAF), # CJK Extension E
(0x2CEB0, 0x2EBEF), # CJK Extension F
(0x30000, 0x3134A), # CJK Extension G
(0x31350, 0x323AF), # CJK Extension H
(0x2EBF0, 0x2EE5D), # CJK Extension H
(0x4E00, 0x9FFF), # CJK Unified Ideographs
(0x3400, 0x4DB5), # CJK Extension A
(0x20000, 0x2A6DD), # CJK Extension B
(0x2A700, 0x2B73F), # CJK Extension C
(0x2B740, 0x2B81F), # CJK Extension D
(0x2B820, 0x2CEAF), # CJK Extension E
(0x2CEB0, 0x2EBEF), # CJK Extension F
(0x30000, 0x3134A), # CJK Extension G
(0x31350, 0x323AF), # CJK Extension H
(0x2EBF0, 0x2EE5D), # CJK Extension H
]
pattern = r'[0-9、-〜。!?.!?… /]+$'
pattern = r"[0-9、-〜。!?.!?… /]+$"
cjk_text = ""
for char in text:
@ -45,7 +46,7 @@ def full_cjk(text):
return cjk_text
def split_jako(tag_lang,item):
def split_jako(tag_lang, item):
if tag_lang == "ja":
pattern = r"([\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]+(?:[0-9、-〜。!?.!?… ]+[\u3041-\u3096\u3099\u309A\u30A1-\u30FA\u30FC]*)*)"
else:
@ -53,41 +54,42 @@ def split_jako(tag_lang,item):
lang_list: list[dict] = []
tag = 0
for match in re.finditer(pattern, item['text']):
for match in re.finditer(pattern, item["text"]):
if match.start() > tag:
lang_list.append({'lang':item['lang'],'text':item['text'][tag:match.start()]})
lang_list.append({"lang": item["lang"], "text": item["text"][tag : match.start()]})
tag = match.end()
lang_list.append({'lang':tag_lang,'text':item['text'][match.start():match.end()]})
lang_list.append({"lang": tag_lang, "text": item["text"][match.start() : match.end()]})
if tag < len(item['text']):
lang_list.append({'lang':item['lang'],'text':item['text'][tag:len(item['text'])]})
if tag < len(item["text"]):
lang_list.append({"lang": item["lang"], "text": item["text"][tag : len(item["text"])]})
return lang_list
def merge_lang(lang_list, item):
if lang_list and item['lang'] == lang_list[-1]['lang']:
lang_list[-1]['text'] += item['text']
if lang_list and item["lang"] == lang_list[-1]["lang"]:
lang_list[-1]["text"] += item["text"]
else:
lang_list.append(item)
return lang_list
class LangSegmenter():
class LangSegmenter:
# 默认过滤器, 基于gsv目前四种语言
DEFAULT_LANG_MAP = {
"zh": "zh",
"yue": "zh", # 粤语
"wuu": "zh", # 吴语
"zh-cn": "zh",
"zh-tw": "x", # 繁体设置为x
"zh-tw": "x", # 繁体设置为x
"ko": "ko",
"ja": "ja",
"en": "en",
}
def getTexts(text,default_lang = ""):
@staticmethod
def getTexts(text, default_lang=""):
lang_splitter = LangSplitter(lang_map=LangSegmenter.DEFAULT_LANG_MAP)
lang_splitter.merge_across_digit = False
substr = lang_splitter.split_by_lang(text=text)
@ -97,31 +99,31 @@ class LangSegmenter():
have_num = False
for _, item in enumerate(substr):
dict_item = {'lang':item.lang,'text':item.text}
dict_item = {"lang": item.lang, "text": item.text}
if dict_item['lang'] == 'digit':
if dict_item["lang"] == "digit":
if default_lang != "":
dict_item['lang'] = default_lang
dict_item["lang"] = default_lang
else:
have_num = True
lang_list = merge_lang(lang_list,dict_item)
lang_list = merge_lang(lang_list, dict_item)
continue
# 处理短英文被识别为其他语言的问题
if full_en(dict_item['text']):
dict_item['lang'] = 'en'
lang_list = merge_lang(lang_list,dict_item)
if full_en(dict_item["text"]):
dict_item["lang"] = "en"
lang_list = merge_lang(lang_list, dict_item)
continue
if default_lang != "":
dict_item['lang'] = default_lang
lang_list = merge_lang(lang_list,dict_item)
dict_item["lang"] = default_lang
lang_list = merge_lang(lang_list, dict_item)
continue
else:
# 处理非日语夹日文的问题(不包含CJK)
ja_list: list[dict] = []
if dict_item['lang'] != 'ja':
ja_list = split_jako('ja',dict_item)
if dict_item["lang"] != "ja":
ja_list = split_jako("ja", dict_item)
if not ja_list:
ja_list.append(dict_item)
@ -130,8 +132,8 @@ class LangSegmenter():
ko_list: list[dict] = []
temp_list: list[dict] = []
for _, ko_item in enumerate(ja_list):
if ko_item["lang"] != 'ko':
ko_list = split_jako('ko',ko_item)
if ko_item["lang"] != "ko":
ko_list = split_jako("ko", ko_item)
if ko_list:
temp_list.extend(ko_list)
@ -141,77 +143,76 @@ class LangSegmenter():
# 未存在非日韩文夹日韩文
if len(temp_list) == 1:
# 未知语言检查是否为CJK
if dict_item['lang'] == 'x':
cjk_text = full_cjk(dict_item['text'])
if dict_item["lang"] == "x":
cjk_text = full_cjk(dict_item["text"])
if cjk_text:
dict_item = {'lang':'zh','text':cjk_text}
lang_list = merge_lang(lang_list,dict_item)
dict_item = {"lang": "zh", "text": cjk_text}
lang_list = merge_lang(lang_list, dict_item)
else:
lang_list = merge_lang(lang_list,dict_item)
lang_list = merge_lang(lang_list, dict_item)
continue
else:
lang_list = merge_lang(lang_list,dict_item)
lang_list = merge_lang(lang_list, dict_item)
continue
# 存在非日韩文夹日韩文
for _, temp_item in enumerate(temp_list):
# 未知语言检查是否为CJK
if temp_item['lang'] == 'x':
cjk_text = full_cjk(temp_item['text'])
if temp_item["lang"] == "x":
cjk_text = full_cjk(temp_item["text"])
if cjk_text:
lang_list = merge_lang(lang_list,{'lang':'zh','text':cjk_text})
lang_list = merge_lang(lang_list, {"lang": "zh", "text": cjk_text})
else:
lang_list = merge_lang(lang_list,temp_item)
lang_list = merge_lang(lang_list, temp_item)
else:
lang_list = merge_lang(lang_list,temp_item)
lang_list = merge_lang(lang_list, temp_item)
# 有数字
if have_num:
temp_list = lang_list
lang_list = []
for i, temp_item in enumerate(temp_list):
if temp_item['lang'] == 'digit':
if temp_item["lang"] == "digit":
if default_lang:
temp_item['lang'] = default_lang
temp_item["lang"] = default_lang
elif lang_list and i == len(temp_list) - 1:
temp_item['lang'] = lang_list[-1]['lang']
temp_item["lang"] = lang_list[-1]["lang"]
elif not lang_list and i < len(temp_list) - 1:
temp_item['lang'] = temp_list[1]['lang']
temp_item["lang"] = temp_list[1]["lang"]
elif lang_list and i < len(temp_list) - 1:
if lang_list[-1]['lang'] == temp_list[i + 1]['lang']:
temp_item['lang'] = lang_list[-1]['lang']
elif lang_list[-1]['text'][-1] in [",",".","!","?","","","",""]:
temp_item['lang'] = temp_list[i + 1]['lang']
elif temp_list[i + 1]['text'][0] in [",",".","!","?","","","",""]:
temp_item['lang'] = lang_list[-1]['lang']
elif temp_item['text'][-1] in ["","."]:
temp_item['lang'] = lang_list[-1]['lang']
elif len(lang_list[-1]['text']) >= len(temp_list[i + 1]['text']):
temp_item['lang'] = lang_list[-1]['lang']
if lang_list[-1]["lang"] == temp_list[i + 1]["lang"]:
temp_item["lang"] = lang_list[-1]["lang"]
elif lang_list[-1]["text"][-1] in [",", ".", "!", "?", "", "", "", ""]:
temp_item["lang"] = temp_list[i + 1]["lang"]
elif temp_list[i + 1]["text"][0] in [",", ".", "!", "?", "", "", "", ""]:
temp_item["lang"] = lang_list[-1]["lang"]
elif temp_item["text"][-1] in ["", "."]:
temp_item["lang"] = lang_list[-1]["lang"]
elif len(lang_list[-1]["text"]) >= len(temp_list[i + 1]["text"]):
temp_item["lang"] = lang_list[-1]["lang"]
else:
temp_item['lang'] = temp_list[i + 1]['lang']
temp_item["lang"] = temp_list[i + 1]["lang"]
else:
temp_item['lang'] = 'zh'
lang_list = merge_lang(lang_list,temp_item)
temp_item["lang"] = "zh"
lang_list = merge_lang(lang_list, temp_item)
# 筛X
temp_list = lang_list
lang_list = []
for _, temp_item in enumerate(temp_list):
if temp_item['lang'] == 'x':
if temp_item["lang"] == "x":
if lang_list:
temp_item['lang'] = lang_list[-1]['lang']
temp_item["lang"] = lang_list[-1]["lang"]
elif len(temp_list) > 1:
temp_item['lang'] = temp_list[1]['lang']
temp_item["lang"] = temp_list[1]["lang"]
else:
temp_item['lang'] = 'zh'
temp_item["lang"] = "zh"
lang_list = merge_lang(lang_list,temp_item)
lang_list = merge_lang(lang_list, temp_item)
return lang_list
if __name__ == "__main__":
text = "MyGO?,你也喜欢まいご吗?"
@ -221,5 +222,5 @@ if __name__ == "__main__":
print(LangSegmenter.getTexts(text))
text = "当时ThinkPad T60刚刚发布一同推出的还有一款名为Advanced Dock的扩展坞配件。这款扩展坞通过连接T60底部的插槽扩展出包括PCIe在内的一大堆接口并且自带电源让T60可以安装桌面显卡来提升性能。"
print(LangSegmenter.getTexts(text,"zh"))
print(LangSegmenter.getTexts(text))
print(LangSegmenter.getTexts(text, "zh"))
print(LangSegmenter.getTexts(text))

View File

@ -1,12 +1,13 @@
from text import cleaned_text_to_sequence
import os
from text import cleaned_text_to_sequence
# if os.environ.get("version","v1")=="v1":
# from text import chinese
# from text.symbols import symbols
# else:
# from text import chinese2 as chinese
# from text.symbols2 import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
@ -18,7 +19,7 @@ special = [
]
def clean_text(text, language, version=None):
def clean_text(text, language, version=None) -> tuple[list[str], list[int] | None, str]:
if version is None:
version = os.environ.get("version", "v2")
if version == "v1":

View File

@ -3,9 +3,10 @@
from __future__ import print_function
import re
import inflect
import unicodedata
import inflect
# 后缀计量单位替换表
measurement_map = {
"m": ["meter", "meters"],
@ -109,7 +110,7 @@ def _expand_measurement(m):
num = int(m.group(1).replace(sign, "").replace(".", ""))
decimal_part = m.group(2)
# 上面判断的漏洞,比如 0.1 的情况,在这里排除了
if decimal_part == None and num == 1:
if decimal_part is None and num == 1:
ptr = 0
return m.group(1).replace(sign, " " + measurement_map[sign][ptr])

View File

@ -3,9 +3,11 @@ import glob
import json
import logging
import os
import shutil
import subprocess
import sys
import traceback
from time import time as ttime
import librosa
import numpy as np
@ -42,9 +44,9 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
saved_state_dict[k].shape,
v.shape,
)
except:
except AssertionError:
traceback.print_exc()
print("error, %s is not in the checkpoint" % k) # shape不对也会比如text_embedding当cleaner修改时
print(f"error, {k} is not in the checkpoint") # shape不对也会比如text_embedding当cleaner修改时
new_state_dict[k] = v
if hasattr(model, "module"):
model.module.load_state_dict(new_state_dict)
@ -60,14 +62,10 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
return model, optimizer, learning_rate, iteration
import shutil
from time import time as ttime
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
tmp_path = "%s.pth" % (ttime())
tmp_path = f"{ttime()}.pth"
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
@ -136,8 +134,7 @@ def plot_spectrogram_to_numpy(spectrogram):
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
data = np.asarray(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8)[:, :, :3]
plt.close()
return data
@ -169,8 +166,7 @@ def plot_alignment_to_numpy(alignment, info=None):
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
data = np.asarray(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8)[:, :, :3]
plt.close()
return data
@ -245,18 +241,31 @@ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_tim
import re
ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
name_key = lambda _f: int(re.compile("._(\d+)\.pth").match(_f).group(1))
time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))
def name_key(_f):
return int(re.compile("._(\d+)\.pth").match(_f).group(1))
def time_key(_f):
return os.path.getmtime(os.path.join(path_to_models, _f))
sort_key = time_key if sort_by_time else name_key
x_sorted = lambda _x: sorted(
[f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
key=sort_key,
)
def x_sorted(_x):
return sorted(
[f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
key=sort_key,
)
to_del = [
os.path.join(path_to_models, fn) for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
]
del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
del_routine = lambda x: [os.remove(x), del_info(x)]
def del_info(fn):
return logger.info(f".. Free up space by deleting ckpt {fn}")
def del_routine(x):
return [os.remove(x), del_info(x)]
rs = [del_routine(fn) for fn in to_del]
@ -324,7 +333,7 @@ def get_logger(model_dir, filename="train.log"):
class HParams:
def __init__(self, **kwargs):
for k, v in kwargs.items():
if type(v) == dict:
if isinstance(v, dict):
v = HParams(**v)
self[k] = v
@ -352,10 +361,13 @@ class HParams:
def __repr__(self):
return self.__dict__.__repr__()
def __getstate__(self):
def to_dict(obj):
if isinstance(obj, HParams):
return {k: to_dict(v) for k, v in obj.items()}
elif isinstance(obj, dict):
return {k: to_dict(v) for k, v in obj.items()}
else:
return obj
if __name__ == "__main__":
print(
load_wav_to_torch(
"/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac",
)
)
return to_dict(self)

Some files were not shown because too many files have changed in this diff Show More