mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-09-29 08:49:59 +08:00
Merge 26d5eaf1b4903199f5206eb990f1ee4ffcdcf901 into fdf794e31d1fd6f91c5cb4fbb0396094491a31ac
This commit is contained in:
commit
dfb741fce2
20
.github/build_windows_packages.ps1
vendored
20
.github/build_windows_packages.ps1
vendored
@ -31,8 +31,8 @@ $UVR5_URL = "$baseHF/uvr5_weights.zip"
|
||||
$NLTK_URL = "$baseHF/nltk_data.zip"
|
||||
$JTALK_URL = "$baseHF/open_jtalk_dic_utf_8-1.11.tar.gz"
|
||||
|
||||
$PYTHON_VERSION = "3.11.12"
|
||||
$PY_RELEASE_VERSION = "20250409"
|
||||
$PYTHON_VERSION = "3.10.18"
|
||||
$PY_RELEASE_VERSION = "20250902"
|
||||
|
||||
Write-Host "[INFO] Cleaning .git..."
|
||||
Remove-Item "$srcDir\.git" -Recurse -Force -ErrorAction SilentlyContinue
|
||||
@ -115,12 +115,17 @@ Remove-Item $ffDir.FullName -Recurse -Force
|
||||
Write-Host "[INFO] Installing PyTorch..."
|
||||
& ".\runtime\python.exe" -m ensurepip
|
||||
& ".\runtime\python.exe" -m pip install --upgrade pip --no-warn-script-location
|
||||
|
||||
switch ($cuda) {
|
||||
"cu124" {
|
||||
& ".\runtime\python.exe" -m pip install torch==2.6 torchaudio --index-url https://download.pytorch.org/whl/cu124 --no-warn-script-location
|
||||
"cu126" {
|
||||
& ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location
|
||||
& ".\runtime\python.exe" -m pip install torch --index-url https://download.pytorch.org/whl/cu126 --no-warn-script-location
|
||||
& ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
|
||||
}
|
||||
"cu128" {
|
||||
& ".\runtime\python.exe" -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128 --no-warn-script-location
|
||||
& ".\runtime\python.exe" -m pip install psutil ninja packaging wheel "setuptools>=42" --no-warn-script-location
|
||||
& ".\runtime\python.exe" -m pip install torch --index-url https://download.pytorch.org/whl/cu128 --no-warn-script-location
|
||||
& ".\runtime\python.exe" -m pip install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
|
||||
}
|
||||
default {
|
||||
Write-Error "Unsupported CUDA version: $cuda"
|
||||
@ -129,6 +134,7 @@ switch ($cuda) {
|
||||
}
|
||||
|
||||
Write-Host "[INFO] Installing dependencies..."
|
||||
& ".\runtime\python.exe" -m pip install --pre torchcodec --index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
& ".\runtime\python.exe" -m pip install -r extra-req.txt --no-deps --no-warn-script-location
|
||||
& ".\runtime\python.exe" -m pip install -r requirements.txt --no-warn-script-location
|
||||
|
||||
@ -162,7 +168,7 @@ Copy-Item -Path $curr -Destination $pkgName -Recurse
|
||||
$7zPath = "$pkgName.7z"
|
||||
$start = Get-Date
|
||||
Write-Host "Compress Starting at $start"
|
||||
& "C:\Program Files\7-Zip\7z.exe" a -t7z "$7zPath" "$pkgName" -m0=lzma2 -mx=9 -md=1g -ms=1g -mmc=500 -mfb=273 -mlc=0 -mlp=4 -mpb=4 -mc=8g -mmt=on -bsp1
|
||||
& "C:\Program Files\7-Zip\7z.exe" a -t7z "$7zPath" "$pkgName" -m0=lzma2 -mx=9 -mmt=on -bsp1
|
||||
$end = Get-Date
|
||||
Write-Host "Elapsed time: $($end - $start)"
|
||||
Get-ChildItem .
|
||||
@ -189,6 +195,6 @@ if (-not $hfUser -or -not $hfToken) {
|
||||
exit 1
|
||||
}
|
||||
$env:HF_HUB_ENABLE_HF_TRANSFER = "1"
|
||||
huggingface-cli upload "$hfUser/GPT-SoVITS-Packages" "$7zPath" "$7zPath" --repo-type model --token $hfToken
|
||||
hf upload "$hfUser/GPT-SoVITS-Packages" "$7zPath" "$7zPath" --repo-type model --token $hfToken
|
||||
|
||||
Write-Host "[SUCCESS] Uploaded: $7zPath to HuggingFace"
|
||||
|
10
.github/workflows/build_windows_packages.yaml
vendored
10
.github/workflows/build_windows_packages.yaml
vendored
@ -17,7 +17,7 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
strategy:
|
||||
matrix:
|
||||
torch_cuda: [cu124, cu128]
|
||||
torch_cuda: [cu126, cu128]
|
||||
env:
|
||||
TORCH_CUDA: ${{ matrix.torch_cuda }}
|
||||
MODELSCOPE_USERNAME: ${{ secrets.MODELSCOPE_USERNAME }}
|
||||
@ -31,6 +31,14 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Windows CUDA 12.9
|
||||
uses: Jimver/cuda-toolkit@v0.2.24
|
||||
id: cuda-toolkit-win-129
|
||||
with:
|
||||
cuda: 12.9.0
|
||||
method: "network"
|
||||
sub-packages: '["nvcc", "cudart", "visual_studio_integration"]'
|
||||
|
||||
- name: Run Build and Upload Script
|
||||
shell: pwsh
|
||||
run: |
|
||||
|
5
.gitignore
vendored
5
.gitignore
vendored
@ -16,8 +16,9 @@ ffprobe*
|
||||
cfg.json
|
||||
speakers.json
|
||||
ref_audios
|
||||
tools/AP_BWE_main/24kto48k/*
|
||||
!tools/AP_BWE_main/24kto48k/readme.txt
|
||||
tools/AP_BWE/24kto48k/*
|
||||
!tools/AP_BWE/24kto48k/readme.txt
|
||||
onnx_export
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
|
@ -23,8 +23,10 @@ fi
|
||||
|
||||
if [ "$TARGETPLATFORM" = "linux/amd64" ]; then
|
||||
"${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-x86_64.sh
|
||||
SYSROOT_PKG="sysroot_linux-64>=2.28"
|
||||
elif [ "$TARGETPLATFORM" = "linux/arm64" ]; then
|
||||
"${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-aarch64.sh
|
||||
SYSROOT_PKG="sysroot_linux-aarch64>=2.28"
|
||||
else
|
||||
exit 1
|
||||
fi
|
||||
@ -45,20 +47,36 @@ rm miniconda.sh
|
||||
|
||||
source "$HOME/miniconda3/etc/profile.d/conda.sh"
|
||||
|
||||
"$HOME/miniconda3/bin/conda" init bash
|
||||
|
||||
source "$HOME/.bashrc"
|
||||
|
||||
"$HOME/miniconda3/bin/conda" config --add channels conda-forge
|
||||
|
||||
"$HOME/miniconda3/bin/conda" update -q --all -y 1>/dev/null
|
||||
|
||||
"$HOME/miniconda3/bin/conda" install python=3.11 -q -y
|
||||
|
||||
"$HOME/miniconda3/bin/conda" install gcc=14 gxx ffmpeg cmake make unzip -q -y
|
||||
"$HOME/miniconda3/bin/conda" install gcc=11 gxx ffmpeg cmake make unzip $SYSROOT_PKG "libstdcxx-ng>=11" -q -y
|
||||
|
||||
if [ "$CUDA_VERSION" = "12.8" ]; then
|
||||
"$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu128
|
||||
"$HOME/miniconda3/bin/conda" install cuda-nvcc=12.8 -c nvidia
|
||||
elif [ "$CUDA_VERSION" = "12.6" ]; then
|
||||
"$HOME/miniconda3/bin/pip" install torch==2.6 torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
|
||||
"$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
|
||||
"$HOME/miniconda3/bin/conda" install cuda-nvcc=12.6 -c nvidia
|
||||
fi
|
||||
|
||||
CUDA_PATH=$(echo "$HOME/miniconda3/targets/"*-linux | awk '{print $1}')
|
||||
|
||||
export CUDA_HOME=$CUDA_PATH
|
||||
export PATH="$HOME/miniconda3/bin:$PATH"
|
||||
export PATH="$CUDA_HOME/bin:$PATH"
|
||||
export PATH="$CUDA_HOME/nvvm/bin:$PATH"
|
||||
|
||||
"$HOME/miniconda3/bin/pip" install psutil ninja packaging wheel "setuptools>=42"
|
||||
"$HOME/miniconda3/bin/pip" install flash-attn -i https://xxxxrt666.github.io/PIP-Index/ --no-build-isolation
|
||||
|
||||
"$HOME/miniconda3/bin/pip" cache purge
|
||||
|
||||
rm $LOG_PATH
|
||||
|
@ -39,12 +39,12 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = dist.get_world_size() if torch.cuda.is_available() else 1
|
||||
num_replicas = dist.get_world_size() if torch.cuda.device_count() > 1 else 1
|
||||
if rank is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = dist.get_rank() if torch.cuda.is_available() else 0
|
||||
if torch.cuda.is_available():
|
||||
rank = dist.get_rank() if torch.cuda.device_count() > 1 else 0
|
||||
if torch.cuda.device_count() > 1:
|
||||
torch.cuda.set_device(rank)
|
||||
if rank >= num_replicas or rank < 0:
|
||||
raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
|
||||
|
@ -3,8 +3,8 @@
|
||||
from pytorch_lightning import LightningDataModule
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from AR.data.bucket_sampler import DistributedBucketSampler
|
||||
from AR.data.dataset import Text2SemanticDataset
|
||||
from GPT_SoVITS.AR.data.bucket_sampler import DistributedBucketSampler
|
||||
from GPT_SoVITS.AR.data.dataset import Text2SemanticDataset
|
||||
|
||||
|
||||
class Text2SemanticDataModule(LightningDataModule):
|
||||
|
@ -11,9 +11,9 @@ import pandas as pd
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
version = os.environ.get("version", None)
|
||||
from GPT_SoVITS.text import cleaned_text_to_sequence
|
||||
|
||||
from text import cleaned_text_to_sequence
|
||||
version = os.environ.get("version", None)
|
||||
|
||||
# from config import exp_dir
|
||||
|
||||
@ -220,7 +220,7 @@ class Text2SemanticDataset(Dataset):
|
||||
|
||||
flag = 0
|
||||
path_bert = "%s/%s.pt" % (self.path3, item_name)
|
||||
if os.path.exists(path_bert) == True:
|
||||
if os.path.exists(path_bert) is True:
|
||||
bert_feature = torch.load(path_bert, map_location="cpu")
|
||||
else:
|
||||
flag = 1
|
||||
|
@ -1,18 +1,14 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import os
|
||||
import sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import LightningModule
|
||||
|
||||
from AR.models.t2s_model import Text2SemanticDecoder
|
||||
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
||||
from AR.modules.optim import ScaledAdam
|
||||
from ..modules.lr_schedulers import WarmupCosineLRSchedule
|
||||
from ..modules.optim import ScaledAdam
|
||||
from .t2s_model import Text2SemanticDecoder
|
||||
|
||||
|
||||
class Text2SemanticLightningModule(LightningModule):
|
||||
@ -42,7 +38,7 @@ class Text2SemanticLightningModule(LightningModule):
|
||||
def training_step(self, batch: Dict, batch_idx: int):
|
||||
opt = self.optimizers()
|
||||
scheduler = self.lr_schedulers()
|
||||
forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
|
||||
forward = self.model.forward if self.config["train"].get("if_dpo", False) is True else self.model.forward_old
|
||||
loss, acc = forward(
|
||||
batch["phoneme_ids"],
|
||||
batch["phoneme_ids_len"],
|
||||
|
@ -1,18 +1,10 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import os
|
||||
import sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import LightningModule
|
||||
|
||||
from AR.models.t2s_model_onnx import Text2SemanticDecoder
|
||||
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
||||
from AR.modules.optim import ScaledAdam
|
||||
from .t2s_model_onnx import Text2SemanticDecoder
|
||||
|
||||
|
||||
class Text2SemanticLightningModule(LightningModule):
|
||||
@ -21,90 +13,3 @@ class Text2SemanticLightningModule(LightningModule):
|
||||
self.config = config
|
||||
self.top_k = 3
|
||||
self.model = Text2SemanticDecoder(config=config, top_k=self.top_k)
|
||||
pretrained_s1 = config.get("pretrained_s1")
|
||||
if pretrained_s1 and is_train:
|
||||
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
|
||||
print(
|
||||
self.load_state_dict(
|
||||
torch.load(
|
||||
pretrained_s1,
|
||||
map_location="cpu",
|
||||
)["weight"],
|
||||
),
|
||||
)
|
||||
if is_train:
|
||||
self.automatic_optimization = False
|
||||
self.save_hyperparameters()
|
||||
self.eval_dir = output_dir / "eval"
|
||||
self.eval_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def training_step(self, batch: Dict, batch_idx: int):
|
||||
opt = self.optimizers()
|
||||
scheduler = self.lr_schedulers()
|
||||
loss, acc = self.model.forward(
|
||||
batch["phoneme_ids"],
|
||||
batch["phoneme_ids_len"],
|
||||
batch["semantic_ids"],
|
||||
batch["semantic_ids_len"],
|
||||
batch["bert_feature"],
|
||||
)
|
||||
self.manual_backward(loss)
|
||||
if batch_idx > 0 and batch_idx % 4 == 0:
|
||||
opt.step()
|
||||
opt.zero_grad()
|
||||
scheduler.step()
|
||||
|
||||
self.log(
|
||||
"total_loss",
|
||||
loss,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
sync_dist=True,
|
||||
)
|
||||
self.log(
|
||||
"lr",
|
||||
scheduler.get_last_lr()[0],
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
sync_dist=True,
|
||||
)
|
||||
self.log(
|
||||
f"top_{self.top_k}_acc",
|
||||
acc,
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
sync_dist=True,
|
||||
)
|
||||
|
||||
def validation_step(self, batch: Dict, batch_idx: int):
|
||||
return
|
||||
|
||||
def configure_optimizers(self):
|
||||
model_parameters = self.model.parameters()
|
||||
parameters_names = []
|
||||
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
|
||||
lm_opt = ScaledAdam(
|
||||
model_parameters,
|
||||
lr=0.01,
|
||||
betas=(0.9, 0.95),
|
||||
clipping_scale=2.0,
|
||||
parameters_names=parameters_names,
|
||||
show_dominant_parameters=False,
|
||||
clipping_update_period=1000,
|
||||
)
|
||||
|
||||
return {
|
||||
"optimizer": lm_opt,
|
||||
"lr_scheduler": {
|
||||
"scheduler": WarmupCosineLRSchedule(
|
||||
lm_opt,
|
||||
init_lr=self.config["optimizer"]["lr_init"],
|
||||
peak_lr=self.config["optimizer"]["lr"],
|
||||
end_lr=self.config["optimizer"]["lr_end"],
|
||||
warmup_steps=self.config["optimizer"]["warmup_steps"],
|
||||
total_steps=self.config["optimizer"]["decay_steps"],
|
||||
)
|
||||
},
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ from torch.nn import functional as F
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
from tqdm import tqdm
|
||||
|
||||
from AR.models.utils import (
|
||||
from GPT_SoVITS.AR.models.utils import (
|
||||
dpo_loss,
|
||||
get_batch_logps,
|
||||
make_pad_mask,
|
||||
@ -18,8 +18,8 @@ from AR.models.utils import (
|
||||
sample,
|
||||
topk_sampling,
|
||||
)
|
||||
from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
||||
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
||||
from GPT_SoVITS.AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
||||
from GPT_SoVITS.AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
||||
|
||||
default_config = {
|
||||
"embedding_dim": 512,
|
||||
@ -420,7 +420,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
mask=xy_attn_mask,
|
||||
)
|
||||
x_len = x_lens.max()
|
||||
logits = self.ar_predict_layer(xy_dec[:, x_len-1:])
|
||||
logits = self.ar_predict_layer(xy_dec[:, x_len - 1 :])
|
||||
|
||||
###### DPO #############
|
||||
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
|
||||
@ -432,7 +432,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
mask=reject_xy_attn_mask,
|
||||
)
|
||||
x_len = x_lens.max()
|
||||
reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len-1:])
|
||||
reject_logits = self.ar_predict_layer(reject_xy_dec[:, x_len - 1 :])
|
||||
|
||||
# loss
|
||||
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
|
||||
@ -502,7 +502,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
(xy_pos, None),
|
||||
mask=xy_attn_mask,
|
||||
)
|
||||
logits = self.ar_predict_layer(xy_dec[:, x_len-1:]).permute(0, 2, 1)
|
||||
logits = self.ar_predict_layer(xy_dec[:, x_len - 1 :]).permute(0, 2, 1)
|
||||
# loss
|
||||
# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sum
|
||||
loss = F.cross_entropy(logits, targets, reduction="sum")
|
||||
@ -724,8 +724,8 @@ class Text2SemanticDecoder(nn.Module):
|
||||
l1 = samples[:, 0] == self.EOS
|
||||
l2 = tokens == self.EOS
|
||||
l = l1.logical_or(l2)
|
||||
removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
|
||||
reserved_idx_of_batch_for_y = torch.where(l == False)[0]
|
||||
removed_idx_of_batch_for_y = torch.where(l is True)[0].tolist()
|
||||
reserved_idx_of_batch_for_y = torch.where(l is False)[0]
|
||||
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
|
||||
for i in removed_idx_of_batch_for_y:
|
||||
batch_index = batch_idx_map[i]
|
||||
|
@ -5,8 +5,8 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
|
||||
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
|
||||
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
||||
from GPT_SoVITS.AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
|
||||
from GPT_SoVITS.AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
||||
|
||||
default_config = {
|
||||
"embedding_dim": 512,
|
||||
|
@ -9,7 +9,7 @@ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
||||
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
|
||||
from .patched_mha_with_cache import multi_head_attention_forward_patched
|
||||
|
||||
F.multi_head_attention_forward = multi_head_attention_forward_patched
|
||||
|
||||
@ -86,8 +86,8 @@ class MultiheadAttention(Module):
|
||||
kdim=None,
|
||||
vdim=None,
|
||||
batch_first=False,
|
||||
linear1_cls=Linear,
|
||||
linear2_cls=Linear,
|
||||
linear1_cls: type[Module] = Linear,
|
||||
linear2_cls: type[Module] = Linear,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
@ -383,7 +383,7 @@ class MultiheadAttention(Module):
|
||||
k_proj_weight=self.k_proj_weight,
|
||||
v_proj_weight=self.v_proj_weight,
|
||||
average_attn_weights=average_attn_weights,
|
||||
cache=cache,
|
||||
cache=cache, # type: ignore
|
||||
)
|
||||
else:
|
||||
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
||||
@ -405,7 +405,7 @@ class MultiheadAttention(Module):
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
average_attn_weights=average_attn_weights,
|
||||
cache=cache,
|
||||
cache=cache, # type: ignore
|
||||
)
|
||||
if self.batch_first and is_batched:
|
||||
return attn_output.transpose(1, 0), attn_output_weights
|
||||
|
@ -1,5 +1,5 @@
|
||||
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@ -8,7 +8,7 @@ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
||||
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
|
||||
from .patched_mha_with_cache_onnx import multi_head_attention_forward_patched
|
||||
|
||||
|
||||
class MultiheadAttention(Module):
|
||||
@ -161,7 +161,7 @@ class MultiheadAttention(Module):
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
average_attn_weights: bool = True,
|
||||
cache=None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
) -> Tensor:
|
||||
any_nested = query.is_nested or key.is_nested or value.is_nested
|
||||
query = key = value = query.transpose(1, 0)
|
||||
attn_output = multi_head_attention_forward_patched(
|
||||
|
@ -1,6 +1,7 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/modules/lr_schedulers.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
@ -38,10 +39,9 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
|
||||
def set_lr(self, lr):
|
||||
self._last_lr = [g["lr"] for g in self.optimizer.param_groups]
|
||||
for g in self.optimizer.param_groups:
|
||||
# g['lr'] = lr
|
||||
g["lr"] = self.end_lr ###锁定用线性
|
||||
g["lr"] = self.end_lr
|
||||
|
||||
def step(self):
|
||||
def step(self, epoch: Optional[int] = None):
|
||||
if self._current_step < self.warmup_steps:
|
||||
lr = self.init_lr + self._warmup_rate * self._current_step
|
||||
|
||||
@ -55,11 +55,10 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
|
||||
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
||||
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
|
||||
|
||||
self.lr = lr = self.end_lr = 0.002 ###锁定用线性###不听话,直接锁定!
|
||||
self.lr = lr = self.end_lr = 0.002
|
||||
self.set_lr(lr)
|
||||
self.lr = lr
|
||||
self._current_step += 1
|
||||
return self.lr
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
@ -1,43 +1,46 @@
|
||||
from torch.nn.functional import *
|
||||
from torch.nn.functional import (
|
||||
_mha_shape_check,
|
||||
_canonical_mask,
|
||||
_none_or_dtype,
|
||||
_in_projection_packed,
|
||||
)
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
# Tensor = torch.Tensor
|
||||
# from typing import Callable, List, Optional, Tuple, Union
|
||||
from torch.nn.functional import * # noqa: F403
|
||||
from torch.nn.functional import (
|
||||
_canonical_mask,
|
||||
_in_projection_packed, # type: ignore
|
||||
_mha_shape_check, # type: ignore
|
||||
_none_or_dtype,
|
||||
)
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
def multi_head_attention_forward_patched(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
embed_dim_to_check,
|
||||
num_heads,
|
||||
embed_dim_to_check: int,
|
||||
num_heads: int,
|
||||
in_proj_weight,
|
||||
in_proj_bias,
|
||||
bias_k,
|
||||
bias_v,
|
||||
add_zero_attn,
|
||||
in_proj_bias: Optional[Tensor],
|
||||
bias_k: Optional[Tensor],
|
||||
bias_v: Optional[Tensor],
|
||||
add_zero_attn: bool,
|
||||
dropout_p: float,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
training=True,
|
||||
key_padding_mask=None,
|
||||
need_weights=True,
|
||||
attn_mask=None,
|
||||
use_separate_proj_weight=False,
|
||||
q_proj_weight=None,
|
||||
k_proj_weight=None,
|
||||
v_proj_weight=None,
|
||||
static_k=None,
|
||||
static_v=None,
|
||||
average_attn_weights=True,
|
||||
is_causal=False,
|
||||
cache=None,
|
||||
):
|
||||
out_proj_weight: Tensor,
|
||||
out_proj_bias: Optional[Tensor],
|
||||
training: bool = True,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
use_separate_proj_weight: bool = False,
|
||||
q_proj_weight: Optional[Tensor] = None,
|
||||
k_proj_weight: Optional[Tensor] = None,
|
||||
v_proj_weight: Optional[Tensor] = None,
|
||||
static_k: Optional[Tensor] = None,
|
||||
static_v: Optional[Tensor] = None,
|
||||
average_attn_weights: bool = True,
|
||||
is_causal: bool = False,
|
||||
cache: dict | None = None,
|
||||
) -> Tuple[Tensor, Tensor | None]:
|
||||
r"""
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
@ -250,27 +253,18 @@ def multi_head_attention_forward_patched(
|
||||
b_k,
|
||||
b_v,
|
||||
)
|
||||
if cache != None:
|
||||
if cache is not None:
|
||||
if cache["first_infer"] == 1:
|
||||
cache["k"][cache["stage"]] = k
|
||||
# print(0,cache["k"].shape)
|
||||
cache["v"][cache["stage"]] = v
|
||||
else: ###12个layer每个都要留自己的cache_kv
|
||||
# print(1,cache["k"].shape)
|
||||
cache["k"][cache["stage"]] = torch.cat(
|
||||
[cache["k"][cache["stage"]], k], 0
|
||||
) ##本来时序是1,但是proj的时候可能transpose了所以时序到0维了
|
||||
else:
|
||||
cache["k"][cache["stage"]] = torch.cat([cache["k"][cache["stage"]], k], 0)
|
||||
cache["v"][cache["stage"]] = torch.cat([cache["v"][cache["stage"]], v], 0)
|
||||
# print(2, cache["k"].shape)
|
||||
src_len = cache["k"][cache["stage"]].shape[0]
|
||||
k = cache["k"][cache["stage"]]
|
||||
v = cache["v"][cache["stage"]]
|
||||
# if attn_mask is not None:
|
||||
# attn_mask=attn_mask[-1:,]
|
||||
# print(attn_mask.shape,attn_mask)
|
||||
cache["stage"] = (cache["stage"] + 1) % cache["all_stage"]
|
||||
# print(2333,cache)
|
||||
# prep attention mask
|
||||
|
||||
attn_mask = _canonical_mask(
|
||||
mask=attn_mask,
|
||||
|
@ -1,7 +1,9 @@
|
||||
from torch.nn.functional import *
|
||||
from torch.nn.functional import (
|
||||
_canonical_mask,
|
||||
)
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.functional import _canonical_mask, linear, scaled_dot_product_attention
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
def multi_head_attention_forward_patched(
|
||||
@ -30,8 +32,8 @@ def multi_head_attention_forward_patched(
|
||||
static_v: Optional[Tensor] = None,
|
||||
average_attn_weights: bool = True,
|
||||
is_causal: bool = False,
|
||||
cache=None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
cache: dict | None = None,
|
||||
) -> Tensor:
|
||||
# set up shape vars
|
||||
_, _, embed_dim = query.shape
|
||||
attn_mask = _canonical_mask(
|
||||
@ -48,6 +50,7 @@ def multi_head_attention_forward_patched(
|
||||
proj_qkv = proj_qkv.unflatten(-1, (3, query.size(-1))).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
|
||||
q, k, v = proj_qkv[0], proj_qkv[1], proj_qkv[2]
|
||||
|
||||
assert cache
|
||||
if cache["first_infer"] == 1:
|
||||
cache["k"][cache["stage"]] = k
|
||||
cache["v"][cache["stage"]] = v
|
||||
@ -66,6 +69,7 @@ def multi_head_attention_forward_patched(
|
||||
target_type=q.dtype,
|
||||
check_other=False,
|
||||
)
|
||||
assert attn_mask
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
|
||||
q = q.view(-1, num_heads, head_dim).transpose(0, 1)
|
||||
|
@ -14,8 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -41,7 +40,6 @@ class DoubleSwishFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: Tensor) -> Tensor:
|
||||
requires_grad = x.requires_grad
|
||||
x_dtype = x.dtype
|
||||
if x.dtype == torch.float16:
|
||||
x = x.to(torch.float32)
|
||||
|
||||
|
@ -2,20 +2,15 @@
|
||||
import copy
|
||||
import numbers
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from AR.modules.activation import MultiheadAttention
|
||||
from AR.modules.scaling import BalancedDoubleSwish
|
||||
from torch import nn
|
||||
from torch import Tensor
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .activation import MultiheadAttention
|
||||
from .scaling import BalancedDoubleSwish
|
||||
|
||||
_shape_t = Union[int, List[int], torch.Size]
|
||||
|
||||
|
||||
@ -55,7 +50,7 @@ class LayerNorm(nn.Module):
|
||||
nn.init.ones_(self.weight)
|
||||
nn.init.zeros_(self.bias)
|
||||
|
||||
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
||||
def forward(self, input: Tensor, embedding: Any = None) -> tuple[Tensor, Any] | Tensor:
|
||||
if isinstance(input, tuple):
|
||||
input, embedding = input
|
||||
return (
|
||||
@ -128,7 +123,7 @@ class TransformerEncoder(nn.Module):
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
return_layer_states: bool = False,
|
||||
cache=None,
|
||||
) -> Tensor:
|
||||
) -> Tensor | tuple[list[Tensor], Tensor]:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
Args:
|
||||
@ -186,11 +181,11 @@ class TransformerEncoderLayer(nn.Module):
|
||||
norm_first: bool = False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
linear1_self_attention_cls: nn.Module = nn.Linear,
|
||||
linear2_self_attention_cls: nn.Module = nn.Linear,
|
||||
linear1_feedforward_cls: nn.Module = nn.Linear,
|
||||
linear2_feedforward_cls: nn.Module = nn.Linear,
|
||||
layer_norm_cls: nn.Module = LayerNorm,
|
||||
linear1_self_attention_cls: type[nn.Module] = nn.Linear,
|
||||
linear2_self_attention_cls: type[nn.Module] = nn.Linear,
|
||||
linear1_feedforward_cls: type[nn.Module] = nn.Linear,
|
||||
linear2_feedforward_cls: type[nn.Module] = nn.Linear,
|
||||
layer_norm_cls: type[nn.Module] = LayerNorm,
|
||||
layer_norm_eps: float = 1e-5,
|
||||
adaptive_layer_norm=False,
|
||||
) -> None:
|
||||
@ -260,7 +255,7 @@ class TransformerEncoderLayer(nn.Module):
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
cache=None,
|
||||
) -> Tensor:
|
||||
) -> Tensor | tuple[Tensor, Any]:
|
||||
r"""Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
|
@ -2,20 +2,15 @@
|
||||
import copy
|
||||
import numbers
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from AR.modules.activation_onnx import MultiheadAttention
|
||||
from AR.modules.scaling import BalancedDoubleSwish
|
||||
from torch import nn
|
||||
from torch import Tensor
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from GPT_SoVITS.AR.modules.activation_onnx import MultiheadAttention
|
||||
from GPT_SoVITS.AR.modules.scaling import BalancedDoubleSwish
|
||||
|
||||
_shape_t = Union[int, List[int], torch.Size]
|
||||
|
||||
|
||||
|
@ -1,72 +0,0 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/phonemizer.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import itertools
|
||||
import re
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
import regex
|
||||
from gruut import sentences
|
||||
from gruut.const import Sentence
|
||||
from gruut.const import Word
|
||||
from AR.text_processing.symbols import SYMBOL_TO_ID
|
||||
|
||||
|
||||
class GruutPhonemizer:
|
||||
def __init__(self, language: str):
|
||||
self._phonemizer = sentences
|
||||
self.lang = language
|
||||
self.symbol_to_id = SYMBOL_TO_ID
|
||||
self._special_cases_dict: Dict[str] = {
|
||||
r"\.\.\.": "... ",
|
||||
";": "; ",
|
||||
":": ": ",
|
||||
",": ", ",
|
||||
r"\.": ". ",
|
||||
"!": "! ",
|
||||
r"\?": "? ",
|
||||
"—": "—",
|
||||
"…": "… ",
|
||||
"«": "«",
|
||||
"»": "»",
|
||||
}
|
||||
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
|
||||
|
||||
def _normalize_punctuation(self, text: str) -> str:
|
||||
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
|
||||
text = regex.sub(rf"{self._punctuation_regexp}(\pL)", r"\1 \2", text)
|
||||
text = regex.sub(r"\pZ+", r" ", text)
|
||||
return text.strip()
|
||||
|
||||
def _convert_punctuation(self, word: Word) -> str:
|
||||
if not word.phonemes:
|
||||
return ""
|
||||
if word.phonemes[0] in ["‖", "|"]:
|
||||
return word.text.strip()
|
||||
|
||||
phonemes = "".join(word.phonemes)
|
||||
# remove modifier characters ˈˌː with regex
|
||||
phonemes = re.sub(r"[ˈˌː͡]", "", phonemes)
|
||||
return phonemes.strip()
|
||||
|
||||
def phonemize(self, text: str, espeak: bool = False) -> str:
|
||||
text_to_phonemize: str = self._normalize_punctuation(text)
|
||||
sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
|
||||
words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
|
||||
return " ".join(words)
|
||||
|
||||
def transform(self, phonemes):
|
||||
# convert phonemes to ids
|
||||
# dictionary is in symbols.py
|
||||
return [self.symbol_to_id[p] for p in phonemes if p in self.symbol_to_id.keys()]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
phonemizer = GruutPhonemizer("en-us")
|
||||
# text -> IPA
|
||||
phonemes = phonemizer.phonemize("Hello, wor-ld ?")
|
||||
print("phonemes:", phonemes)
|
||||
print("len(phonemes):", len(phonemes))
|
||||
phoneme_ids = phonemizer.transform(phonemes)
|
||||
print("phoneme_ids:", phoneme_ids)
|
||||
print("len(phoneme_ids):", len(phoneme_ids))
|
@ -1,12 +0,0 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/text_processing/symbols.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
PAD = "_"
|
||||
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
|
||||
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
IPA_LETTERS = (
|
||||
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
||||
)
|
||||
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
|
||||
SPACE_ID = SYMBOLS.index(" ")
|
||||
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
|
||||
ID_TO_SYMBOL = {i: s for i, s in enumerate(SYMBOLS)}
|
12
GPT_SoVITS/Accelerate/MLX/__init__.py
Normal file
12
GPT_SoVITS/Accelerate/MLX/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
import importlib.util
|
||||
import platform
|
||||
|
||||
if importlib.util.find_spec("mlx") is not None and platform.system() == "Darwin":
|
||||
from .sample_funcs_mlx import sample_naive as sample_naive_mlx
|
||||
from .t2s_engine_mlx import T2SEngine as T2SEngineMLX
|
||||
|
||||
backends = ["mlx_static", "mlx_quantized_mxfp4", "mlx_quantized_affine", "mlx_varlen"]
|
||||
else:
|
||||
backends = []
|
||||
|
||||
__all__ = ["T2SEngineMLX", "sample_naive_mlx", "backends"]
|
181
GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py
Normal file
181
GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py
Normal file
@ -0,0 +1,181 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from ..structs_mlx import KVCacheQ
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
KVCache,
|
||||
KVCacheHND,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Array = mx.array
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
self.kc_class = KVCacheHND
|
||||
|
||||
@staticmethod
|
||||
def quantized_scaled_dot_product_attention(
|
||||
queries: Array,
|
||||
q_keys: tuple[Array, Array, Array],
|
||||
q_values: tuple[Array, Array, Array],
|
||||
scale: float,
|
||||
mask: Array,
|
||||
group_size: int = 32,
|
||||
bits: int = 8,
|
||||
) -> Array:
|
||||
queries *= scale
|
||||
|
||||
scores = mx.quantized_matmul(queries, *q_keys, transpose=True, group_size=group_size, bits=bits)
|
||||
scores = mx.where(mask, scores, -mx.inf)
|
||||
scores = mx.softmax(scores, axis=-1, precise=True) # type: ignore
|
||||
out = mx.quantized_matmul(scores, *q_values, transpose=False, group_size=group_size, bits=bits)
|
||||
|
||||
return out
|
||||
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
|
||||
|
||||
q, k, v = self.in_proj(x).split(3, axis=-1)
|
||||
|
||||
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
|
||||
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
|
||||
kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
|
||||
assert len(kv_cache) == 2
|
||||
|
||||
max_idx = int(input_pos.max())
|
||||
|
||||
q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache))
|
||||
|
||||
mask = attn_mask[..., :max_idx]
|
||||
|
||||
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
|
||||
|
||||
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
# def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
# bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
|
||||
|
||||
# q, k, v = self.in_proj(x).split(3, axis=-1)
|
||||
|
||||
# q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
|
||||
# q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
|
||||
# kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
|
||||
|
||||
# assert len(kv_cache) == 3
|
||||
# (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits) = kv_cache
|
||||
|
||||
# k_q, k_s, k_b, v_q, v_s, v_b = map(lambda x: x[..., : int(input_pos.max()), :], (k_q, k_s, k_b, v_q, v_s, v_b))
|
||||
|
||||
# mask = attn_mask[..., : int(input_pos.max())]
|
||||
|
||||
# attn = Attention.quantized_scaled_dot_product_attention(
|
||||
# q,
|
||||
# (k_q, k_s, k_b),
|
||||
# (v_q, v_s, v_b),
|
||||
# self.scale,
|
||||
# mask,
|
||||
# group_size,
|
||||
# bits,
|
||||
# )
|
||||
|
||||
# attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
# output = self.out_proj(attn)
|
||||
|
||||
# return output
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length, *args, **kwds)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length, *args, **kwds)
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
n_layer: int,
|
||||
n_head: int,
|
||||
ffn_dim: int,
|
||||
vocab_size: int,
|
||||
max_seq_length: int,
|
||||
max_batch_size: int,
|
||||
*args,
|
||||
**kwds,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
self.layers = [
|
||||
TransformerBlock(
|
||||
n_head,
|
||||
ffn_dim,
|
||||
hidden_dim,
|
||||
max_seq_length,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
for _ in range(n_layer)
|
||||
]
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 2000,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.h = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHND
|
||||
self.group_size = 32
|
||||
self.bits = 8
|
||||
self.mode = "affine"
|
||||
|
||||
def set_mode(self, mode: str):
|
||||
assert mode in ["affine", "mxfp4"]
|
||||
self.mode = mode
|
||||
if self.mode == "mxfp4":
|
||||
self.bits = 4
|
||||
else:
|
||||
self.bits = 8
|
||||
|
||||
def quantized(self):
|
||||
nn.quantize(self, self.group_size, self.bits, mode=self.mode)
|
||||
# for layer in self.h.layers:
|
||||
# nn.quantize(layer.feed_forward, self.group_size, self.bits)
|
||||
# nn.quantize(layer.attention, self.group_size, self.bits)
|
99
GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py
Normal file
99
GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py
Normal file
@ -0,0 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from ..structs_mlx import KVCache, KVCacheQ
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
KVCacheHND,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Array = mx.array
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
self.kc_class = KVCacheHND
|
||||
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
|
||||
|
||||
q, k, v = self.in_proj(x).split(3, axis=-1)
|
||||
|
||||
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
|
||||
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
|
||||
kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
|
||||
assert len(kv_cache) == 2
|
||||
|
||||
k, v = kv_cache
|
||||
|
||||
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attn_mask)
|
||||
|
||||
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length)
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
n_layer: int,
|
||||
n_head: int,
|
||||
ffn_dim: int,
|
||||
vocab_size: int,
|
||||
max_seq_length: int,
|
||||
max_batch_size: int,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
)
|
||||
|
||||
self.layers = [
|
||||
TransformerBlock(
|
||||
n_head,
|
||||
ffn_dim,
|
||||
hidden_dim,
|
||||
max_seq_length,
|
||||
)
|
||||
for _ in range(n_layer)
|
||||
]
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 2000,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.h = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHND
|
103
GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py
Normal file
103
GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py
Normal file
@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from ..structs_mlx import KVCache, KVCacheQ
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
KVCacheHND,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Array = mx.array
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
self.kc_class = KVCacheHND
|
||||
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
|
||||
|
||||
q, k, v = self.in_proj(x).split(3, axis=-1)
|
||||
|
||||
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
|
||||
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
|
||||
kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
|
||||
assert len(kv_cache) == 2
|
||||
|
||||
max_idx = int(input_pos.max())
|
||||
|
||||
q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache))
|
||||
|
||||
mask = attn_mask[..., :max_idx]
|
||||
|
||||
attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
|
||||
|
||||
attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length)
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
n_layer: int,
|
||||
n_head: int,
|
||||
ffn_dim: int,
|
||||
vocab_size: int,
|
||||
max_seq_length: int,
|
||||
max_batch_size: int,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
)
|
||||
|
||||
self.layers = [
|
||||
TransformerBlock(
|
||||
n_head,
|
||||
ffn_dim,
|
||||
hidden_dim,
|
||||
max_seq_length,
|
||||
)
|
||||
for _ in range(n_layer)
|
||||
]
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 2000,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.h = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHND
|
65
GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py
Normal file
65
GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py
Normal file
@ -0,0 +1,65 @@
|
||||
from typing import Protocol, cast
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
Array = mx.array
|
||||
|
||||
|
||||
class SampleProtocolMLX(Protocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
logits: Array,
|
||||
previous_tokens: Array,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
repetition_penalty: float,
|
||||
) -> Array: ...
|
||||
|
||||
|
||||
class sample_naive(SampleProtocolMLX):
|
||||
# @partial(mx.compile)
|
||||
@staticmethod
|
||||
def __call__(
|
||||
logits,
|
||||
previous_tokens,
|
||||
temperature,
|
||||
top_k,
|
||||
top_p,
|
||||
repetition_penalty,
|
||||
):
|
||||
if temperature <= 1e-5:
|
||||
probs = mx.softmax(logits, axis=-1)
|
||||
return mx.argmax(probs, axis=-1, keepdims=True).astype(mx.int32)
|
||||
|
||||
if repetition_penalty != 1.0:
|
||||
batch_idx = mx.arange(cast(tuple[int, ...], previous_tokens.shape)[0])
|
||||
previous_tokens = previous_tokens.astype(mx.int64)
|
||||
selected_logists = logits[batch_idx, previous_tokens]
|
||||
selected_logists = mx.where(
|
||||
selected_logists < 0, selected_logists * repetition_penalty, selected_logists / repetition_penalty
|
||||
)
|
||||
logits[batch_idx, previous_tokens] = selected_logists
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_indices = mx.argsort(-logits, axis=-1)
|
||||
sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
|
||||
cum_probs = mx.cumsum(mx.softmax(sorted_logits, axis=-1), axis=-1)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[:, -1] = False
|
||||
indices_to_remove = mx.zeros_like(logits).astype(mx.bool_)
|
||||
batch_indices = mx.arange(cast(tuple[int, ...], logits.shape)[0])[:, None]
|
||||
indices_to_remove[batch_indices, sorted_indices] = sorted_indices_to_remove
|
||||
logits = mx.where(indices_to_remove, -mx.inf, logits)
|
||||
|
||||
if temperature < 1.0:
|
||||
logits = logits / temperature
|
||||
|
||||
v = mx.topk(logits, top_k)
|
||||
pivot = mx.expand_dims(v[:, 0], -1)
|
||||
logits = mx.where(logits < pivot, -mx.inf, logits)
|
||||
|
||||
gumbel_noise = mx.random.gumbel(shape=cast(tuple[int, ...], logits.shape), dtype=logits.dtype)
|
||||
idx_next = mx.argmax(logits + gumbel_noise, axis=-1, keepdims=True).astype(mx.int32)
|
||||
|
||||
return idx_next
|
152
GPT_SoVITS/Accelerate/MLX/structs_mlx.py
Normal file
152
GPT_SoVITS/Accelerate/MLX/structs_mlx.py
Normal file
@ -0,0 +1,152 @@
|
||||
"""
|
||||
Modified From https://github.com/XXXXRT666/GPT-SoVITS
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, MutableSequence, Protocol, TypeAlias, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
|
||||
from ..PyTorch.structs import T2SRequest
|
||||
from .sample_funcs_mlx import SampleProtocolMLX, sample_naive
|
||||
|
||||
Tensor = torch.Tensor
|
||||
Array = mx.array
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class T2SRequestMLX:
|
||||
x: List[Array]
|
||||
x_lens: Array
|
||||
prompts: Array
|
||||
bert_feature: List[Array]
|
||||
valid_length: int
|
||||
top_k: int = 5
|
||||
top_p: float = 1
|
||||
early_stop_num: int = -1
|
||||
temperature: float = 1.0
|
||||
repetition_penalty: float = 1.35
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, request: T2SRequest) -> T2SRequestMLX:
|
||||
x = list(map(lambda tensor: mx.array(tensor.cpu()), request.x))
|
||||
x_lens = mx.array(request.x_lens.cpu())
|
||||
prompts = mx.array(request.prompts.cpu())
|
||||
bert_feature = list(map(lambda tensor: mx.array(tensor.cpu()), request.bert_feature))
|
||||
|
||||
return cls(
|
||||
x,
|
||||
x_lens,
|
||||
prompts,
|
||||
bert_feature,
|
||||
request.valid_length,
|
||||
request.top_k,
|
||||
request.top_p,
|
||||
request.early_stop_num,
|
||||
request.temperature,
|
||||
request.repetition_penalty,
|
||||
)
|
||||
|
||||
|
||||
KVCache: TypeAlias = tuple[Array, Array]
|
||||
KVCacheQ: TypeAlias = tuple[tuple[Array, Array, Array], tuple[Array, Array, Array], tuple[int, int]]
|
||||
|
||||
|
||||
class KVCacheProtocol(Protocol):
|
||||
@staticmethod
|
||||
def empty(kv_cache: KVCache | KVCacheQ) -> None: ...
|
||||
|
||||
@staticmethod
|
||||
def update_cache(
|
||||
input_pos: Array, k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array
|
||||
) -> KVCache | KVCacheQ: ...
|
||||
|
||||
@staticmethod
|
||||
def prefill_kv(k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ) -> None: ...
|
||||
|
||||
@staticmethod
|
||||
def init_cache(
|
||||
batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype, *args, **kwds
|
||||
) -> KVCache | KVCacheQ: ...
|
||||
|
||||
|
||||
class T2SDecoderProtocol(Protocol):
|
||||
max_seq_length: int
|
||||
EOS: int
|
||||
n_head: int
|
||||
|
||||
def embed(self, x: list[Array], y: Array, bert_features: list[Array]) -> Array: ...
|
||||
|
||||
|
||||
class T2SSessionMLX:
|
||||
def __init__(
|
||||
self,
|
||||
decoder: T2SDecoderProtocol,
|
||||
request_torch: T2SRequest,
|
||||
sample_func: type[SampleProtocolMLX] = sample_naive,
|
||||
device: mx.Device = mx.Device(mx.cpu),
|
||||
dtype: mx.Dtype = mx.float32,
|
||||
):
|
||||
with mx.stream(device):
|
||||
request = T2SRequestMLX.from_torch(request_torch)
|
||||
|
||||
self.decoder = decoder
|
||||
self.request = request
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
bsz = len(request.x)
|
||||
y_len: int = cast(tuple[int, ...], request.prompts.shape)[-1]
|
||||
self.bsz = bsz
|
||||
self.y_len = y_len
|
||||
|
||||
# Cache
|
||||
self.kv_cache: MutableSequence[KVCache | KVCacheQ]
|
||||
self.sample = sample_func()
|
||||
|
||||
# Forward args
|
||||
self.x = [i.astype(mx.int32) for i in request.x]
|
||||
self.x_lens = request.x_lens.astype(mx.int32)
|
||||
self.y = mx.zeros((bsz, decoder.max_seq_length)).astype(mx.int32)
|
||||
self.y[:, : cast(tuple[int, ...], request.prompts.shape)[-1]] = request.prompts.astype(mx.int32)
|
||||
self.bert_feature = [i.astype(dtype) for i in request.bert_feature]
|
||||
|
||||
self.prefill_len = self.x_lens + cast(tuple[int, ...], request.prompts.shape)[1]
|
||||
|
||||
self.input_pos = mx.zeros_like(self.prefill_len)
|
||||
self.input_pos += self.prefill_len
|
||||
|
||||
# EOS
|
||||
self.completed = mx.array([False] * len(self.x)).astype(mx.bool_)
|
||||
self.y_results: List[Array] = [None] * len(self.x) # type: ignore
|
||||
|
||||
self.xy_pos = decoder.embed(self.x, request.prompts, self.bert_feature)
|
||||
|
||||
max_len = int(self.prefill_len.max(-1))
|
||||
attn_mask = mx.zeros(shape=(bsz, max_len, max_len), dtype=mx.bool_)
|
||||
|
||||
for bs in range(bsz):
|
||||
pos = int(self.x_lens[bs])
|
||||
seq_len = pos + y_len
|
||||
|
||||
attn_mask[bs, :seq_len, :pos] = True
|
||||
|
||||
ar_mask = ~mx.triu(
|
||||
x=mx.ones(
|
||||
shape=(
|
||||
y_len,
|
||||
y_len,
|
||||
),
|
||||
dtype=mx.bool_,
|
||||
),
|
||||
k=1,
|
||||
)
|
||||
attn_mask[bs, pos:seq_len, pos:seq_len] = ar_mask
|
||||
|
||||
attn_mask = mx.repeat(mx.expand_dims(attn_mask, 1), decoder.n_head, 1)
|
||||
self.attn_mask = attn_mask
|
||||
|
||||
mx.eval(self.attn_mask)
|
238
GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py
Normal file
238
GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py
Normal file
@ -0,0 +1,238 @@
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from typing import cast
|
||||
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
from rich.progress import BarColumn, Progress, TextColumn
|
||||
|
||||
from ..logger import SpeedColumnToken, console, logger
|
||||
from ..PyTorch.structs import T2SEngineProtocol, T2SRequest, T2SResult
|
||||
from .backends import mlx_quantized, mlx_static, mlx_varlen
|
||||
from .structs_mlx import T2SSessionMLX
|
||||
from .t2s_model_abc import T2SDecoderABC
|
||||
|
||||
Array = mx.array
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class T2SEngine(T2SEngineProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
decoder_model: T2SDecoderABC,
|
||||
device: mx.Device | str = mx.Device(mx.cpu),
|
||||
dtype: torch.dtype | mx.Dtype = torch.float32,
|
||||
) -> None:
|
||||
if isinstance(device, str):
|
||||
match device:
|
||||
case "mx.cpu":
|
||||
device = mx.Device(mx.cpu)
|
||||
case "mx.gpu":
|
||||
device = mx.Device(mx.gpu)
|
||||
|
||||
match dtype:
|
||||
case torch.float32:
|
||||
dtype = mx.float32
|
||||
case torch.float16:
|
||||
dtype = mx.float16
|
||||
case torch.bfloat16:
|
||||
dtype = mx.bfloat16
|
||||
|
||||
device = cast(mx.Device, device)
|
||||
dtype = cast(mx.Dtype, dtype)
|
||||
|
||||
assert device.type.value in {0, 1}
|
||||
assert dtype in {mx.float16, mx.bfloat16, mx.float32}
|
||||
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
mx.set_default_device(device)
|
||||
decoder_model.set_dtype(self.dtype)
|
||||
|
||||
self.decoder_model: T2SDecoderABC = decoder_model
|
||||
self.decoder_model.compile()
|
||||
|
||||
def _handle_request(self, request: T2SRequest):
|
||||
decoder = self.decoder_model
|
||||
session = T2SSessionMLX(decoder, request, device=self.device, dtype=self.dtype)
|
||||
batch_idx = mx.arange(session.bsz)
|
||||
|
||||
t1 = 0.0
|
||||
infer_speed = 0.0
|
||||
infer_time = 0.0
|
||||
|
||||
with (
|
||||
mx.stream(session.device),
|
||||
Progress(
|
||||
TextColumn("[cyan]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("{task.completed}/{task.total}"),
|
||||
SpeedColumnToken(show_speed=True),
|
||||
console=console,
|
||||
transient=True,
|
||||
) as progress,
|
||||
):
|
||||
max_token = min(2000 - int(session.input_pos.max()), 1500)
|
||||
|
||||
task = progress.add_task("T2S Decoding", total=max_token)
|
||||
for idx in range(1500):
|
||||
progress.update(task, advance=1)
|
||||
if idx == 0:
|
||||
session.kv_cache = decoder.init_cache(session.bsz)
|
||||
xy_dec = decoder.h.prefill(
|
||||
session.xy_pos,
|
||||
session.attn_mask,
|
||||
session.kv_cache,
|
||||
) # bs, seq_len, embed_dim
|
||||
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
|
||||
else:
|
||||
args, kwds = decoder.pre_forward(session)
|
||||
xy_dec = decoder.h(
|
||||
session.input_pos,
|
||||
session.xy_pos,
|
||||
session.kv_cache,
|
||||
batch_idx,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
decoder.post_forward(idx, session)
|
||||
logits = decoder.ar_predict_layer(xy_dec[:, -1])
|
||||
session.input_pos += 1
|
||||
|
||||
if idx == 0:
|
||||
logits[:, -1] = -mx.inf
|
||||
|
||||
samples = session.sample(
|
||||
logits=logits,
|
||||
previous_tokens=session.y[:, : session.y_len + idx],
|
||||
top_k=request.top_k,
|
||||
top_p=request.top_p,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
)
|
||||
|
||||
session.y[batch_idx, session.y_len + idx] = samples
|
||||
|
||||
argmax_token = mx.argmax(logits, axis=-1)
|
||||
sample_token = samples.squeeze(1)
|
||||
EOS_mask = (cast(Array, argmax_token == decoder.EOS)) | (sample_token == decoder.EOS)
|
||||
|
||||
newly_done_mask = EOS_mask & (~session.completed)
|
||||
newly_done_indices = mx.where(newly_done_mask, batch_idx, -1)
|
||||
pos = mx.where(newly_done_indices != -1, batch_idx, session.bsz)
|
||||
pos_sorted = mx.sort(pos, axis=0)
|
||||
valid_count = session.bsz - mx.sum(cast(Array, pos_sorted == session.bsz))
|
||||
pos_final = pos_sorted[: int(valid_count)]
|
||||
newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0)
|
||||
|
||||
if newly_done_indices.size > 0:
|
||||
for i in newly_done_indices:
|
||||
session.y_results[int(i)] = session.y[i, session.y_len : session.y_len + idx]
|
||||
session.completed[newly_done_indices] = True
|
||||
|
||||
if mx.all(session.completed).item():
|
||||
if session.y[:, session.y_len :].sum() == 0:
|
||||
session.y_results = [mx.array([0]) for _ in range(session.bsz)]
|
||||
logger.error("Bad Zero Prediction")
|
||||
else:
|
||||
logger.info(
|
||||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[cast(tuple[int, ...], i.shape)[-1] for i in session.y_results].__str__().strip('[]')}"
|
||||
)
|
||||
logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx - 1) / infer_time
|
||||
break
|
||||
|
||||
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
|
||||
for j in range(session.bsz):
|
||||
if not session.completed[j].item():
|
||||
session.y_results[j] = session.y[[j], session.y_len : session.y_len + 1499]
|
||||
session.completed[j] = True
|
||||
logger.error("Bad Full Prediction")
|
||||
logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx - 1) / infer_time
|
||||
break
|
||||
|
||||
y_emb = decoder.ar_audio_embedding(samples)
|
||||
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
|
||||
mx.eval(session.xy_pos, session.y)
|
||||
|
||||
if idx == 1:
|
||||
t1 = time.perf_counter()
|
||||
|
||||
if idx % 100 == 0:
|
||||
mx.clear_cache()
|
||||
|
||||
match session.device:
|
||||
case mx.gpu:
|
||||
mx.clear_cache()
|
||||
case mx.cpu:
|
||||
gc.collect()
|
||||
|
||||
result_mlx = session.y_results[: request.valid_length]
|
||||
mx.eval(result_mlx)
|
||||
result = [torch.tensor(k) for k in result_mlx]
|
||||
return result, infer_speed, infer_time
|
||||
|
||||
def generate(self, request: T2SRequest):
|
||||
try:
|
||||
result, infer_speed, infer_time = self._handle_request(request)
|
||||
t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
|
||||
except Exception as e:
|
||||
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
|
||||
return t2s_result
|
||||
|
||||
@staticmethod
|
||||
def replace_key(state_dict: dict[str, Tensor]):
|
||||
state_dict_mlx: list[tuple[str, Array]] = []
|
||||
for key, value in state_dict.items():
|
||||
key = (
|
||||
key.replace("model.", "")
|
||||
.replace("in_proj_", "in_proj.")
|
||||
.replace("self_attn", "attention")
|
||||
.replace("linear", "feed_forward.linear")
|
||||
.replace("norm1", "attention_norm")
|
||||
.replace("norm2", "ffn_norm")
|
||||
)
|
||||
value_mlx = mx.array(value)
|
||||
state_dict_mlx.append((key, value_mlx))
|
||||
return state_dict_mlx
|
||||
|
||||
@staticmethod
|
||||
def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "MLX-Varlen"):
|
||||
logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
|
||||
dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
|
||||
config = dict_s1["config"]
|
||||
match backend:
|
||||
case "MLX-Varlen":
|
||||
decoder_cls: type[T2SDecoderABC] = mlx_varlen.T2SDecoder
|
||||
case "MLX-Static":
|
||||
decoder_cls = mlx_static.T2SDecoder
|
||||
case "MLX-Quantized-Affine" | "MLX-Quantized-MXFP4":
|
||||
decoder_cls = mlx_quantized.T2SDecoder
|
||||
case _:
|
||||
raise RuntimeError(f"Backend {backend} Not Found")
|
||||
|
||||
decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
|
||||
state_dict = dict_s1["weight"]
|
||||
state_dict_mlx = T2SEngine.replace_key(state_dict)
|
||||
decoder.load_weights(state_dict_mlx)
|
||||
decoder.eval()
|
||||
mx.eval(decoder)
|
||||
|
||||
if "Quantized" in backend and isinstance(decoder, mlx_quantized.T2SDecoder):
|
||||
if backend == "MLX-Quantized-Affine":
|
||||
decoder.set_mode("affine")
|
||||
elif backend == "MLX-Quantized-MXFP4":
|
||||
decoder.set_mode("mxfp4")
|
||||
else:
|
||||
raise RuntimeError(f"Quantized Backend {backend} Not Supported")
|
||||
decoder.quantized()
|
||||
mx.eval(decoder)
|
||||
|
||||
return decoder
|
530
GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py
Normal file
530
GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py
Normal file
@ -0,0 +1,530 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import MutableSequence, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .structs_mlx import KVCache, KVCacheProtocol, KVCacheQ, T2SDecoderProtocol, T2SSessionMLX
|
||||
|
||||
Array = mx.array
|
||||
|
||||
|
||||
class TokenEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.word_embeddings.weight
|
||||
|
||||
def embedding(self, index: int):
|
||||
return self.word_embeddings.weight[index : index + 1]
|
||||
|
||||
def __call__(self, x: Array):
|
||||
x = self.word_embeddings(x)
|
||||
return x
|
||||
|
||||
|
||||
class SinePositionalEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
scale: bool = False,
|
||||
max_batch_size: int = 10,
|
||||
max_seq_len: int = 2000,
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
|
||||
self.alpha = mx.ones(1)
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
self.reverse = False
|
||||
self._pe = mx.zeros((max_batch_size, max_seq_len, embedding_dim))
|
||||
self.compute_pe()
|
||||
|
||||
def compute_pe(self):
|
||||
"""Reset the positional encodings."""
|
||||
|
||||
if self.reverse:
|
||||
position = mx.expand_dims(mx.arange(self.max_seq_len - 1, -1, -1.0), axis=1)
|
||||
else:
|
||||
position = mx.expand_dims(mx.arange(self.max_seq_len), axis=1)
|
||||
div_term = mx.exp(
|
||||
mx.arange(
|
||||
0,
|
||||
self.embedding_dim,
|
||||
2,
|
||||
)
|
||||
* -(math.log(10000.0) / self.embedding_dim)
|
||||
)
|
||||
pe = self._pe
|
||||
pe[:, :, 0::2] = mx.sin(position * div_term)
|
||||
pe[:, :, 1::2] = mx.cos(position * div_term)
|
||||
|
||||
def __call__(self, input_pos: Array, x: Array):
|
||||
"""
|
||||
Args:
|
||||
input_pos (Array): [batch_size, ]
|
||||
x (Array): [batch_size, 1, embed_dim]
|
||||
|
||||
Returns:
|
||||
embedded_x (Array): [batch_size, 1, embed_dim]
|
||||
"""
|
||||
|
||||
batch_size = cast(tuple[int, ...], x.shape)[0]
|
||||
pe_values = self._pe[mx.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
|
||||
|
||||
return x * self.x_scale + self.alpha * mx.expand_dims(pe_values, 1) # (batch_size, 1, embed_dim)
|
||||
|
||||
def prefill(self, x: Array):
|
||||
"""
|
||||
Args:
|
||||
x (Array): [batch_size, seq_len, embed_dim]
|
||||
|
||||
Returns:
|
||||
embedded_x (Array): [batch_size, seq_len, embed_dim]
|
||||
"""
|
||||
pe_values = self._pe[:, : cast(tuple[int, ...], x.shape)[-2]]
|
||||
return x * self.x_scale + self.alpha * pe_values
|
||||
|
||||
|
||||
class KVCacheHND(KVCacheProtocol):
|
||||
@staticmethod
|
||||
def empty(kv_cache):
|
||||
assert len(kv_cache) == 2
|
||||
k_cache, v_cache = kv_cache
|
||||
|
||||
k_cache[:] = 0
|
||||
v_cache[:] = 0
|
||||
|
||||
@staticmethod
|
||||
def update_cache(input_pos, k_val, v_val, kv_cache, cache_idx):
|
||||
# input_pos: [B, ], k_val: [B, H, 1, D]
|
||||
assert len(kv_cache) == 2
|
||||
k_out, v_out = kv_cache
|
||||
ip0 = input_pos - 1
|
||||
|
||||
k_out[cache_idx, :, ip0, None] = k_val
|
||||
v_out[cache_idx, :, ip0, None] = v_val
|
||||
|
||||
return k_out, v_out
|
||||
|
||||
@staticmethod
|
||||
def prefill_kv(k_val, v_val, kv_cache):
|
||||
# k_val: [B, S, H, D]
|
||||
assert len(kv_cache) == 2
|
||||
k_cache, v_cache = kv_cache
|
||||
|
||||
k_cache[..., : cast(tuple[int, ...], k_val.shape)[1], :] = k_val.swapaxes(1, 2)
|
||||
v_cache[..., : cast(tuple[int, ...], v_val.shape)[1], :] = v_val.swapaxes(1, 2)
|
||||
|
||||
@staticmethod
|
||||
def init_cache(batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype) -> KVCache:
|
||||
cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
|
||||
|
||||
return (mx.zeros(cache_shape, dtype=dtype), mx.zeros(cache_shape, dtype=dtype))
|
||||
|
||||
|
||||
class KVCacheHNDQuantized(KVCacheProtocol):
|
||||
@staticmethod
|
||||
def _el_per_int(bits: int) -> int:
|
||||
return 32 // bits
|
||||
|
||||
@staticmethod
|
||||
def _packed_dim(head_dim: int, bits: int = 8) -> int:
|
||||
el_per_int = KVCacheHNDQuantized._el_per_int(bits)
|
||||
if head_dim % el_per_int != 0:
|
||||
raise ValueError(f"{head_dim=} is not divisible by {el_per_int=} ({bits=})")
|
||||
return head_dim // el_per_int
|
||||
|
||||
@staticmethod
|
||||
def _group_count(head_dim: int, group_size: int = 32) -> int:
|
||||
assert group_size in {32, 64, 128}
|
||||
if head_dim % group_size != 0:
|
||||
raise ValueError(f"{head_dim} is not divisible by {group_size=}")
|
||||
return head_dim // group_size
|
||||
|
||||
@staticmethod
|
||||
def empty(kv_cache) -> None:
|
||||
assert len(kv_cache) == 3
|
||||
(k_q, k_s, k_b), (v_q, v_s, v_b), (_, __) = kv_cache
|
||||
|
||||
k_q[:] = 0
|
||||
k_s[:] = 0
|
||||
k_b[:] = 0
|
||||
v_q[:] = 0
|
||||
v_s[:] = 0
|
||||
v_b[:] = 0
|
||||
|
||||
@staticmethod
|
||||
def update_cache(
|
||||
input_pos,
|
||||
k_val,
|
||||
v_val,
|
||||
kv_cache,
|
||||
cache_idx,
|
||||
):
|
||||
# input_pos: [B, ], k_val: [B, H, 1, D]
|
||||
|
||||
assert len(kv_cache) == 3
|
||||
(k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache
|
||||
|
||||
k_q, k_s, k_b = mx.quantize(k_val, group_size=group_size, bits=bits)
|
||||
v_q, v_s, v_b = mx.quantize(v_val, group_size=group_size, bits=bits)
|
||||
|
||||
ip0 = input_pos - 1
|
||||
|
||||
k_q_out[cache_idx, :, ip0, None] = k_q
|
||||
k_s_out[cache_idx, :, ip0, None] = k_s
|
||||
k_b_out[cache_idx, :, ip0, None] = k_b
|
||||
|
||||
v_q_out[cache_idx, :, ip0, None] = v_q
|
||||
v_s_out[cache_idx, :, ip0, None] = v_s
|
||||
v_b_out[cache_idx, :, ip0, None] = v_b
|
||||
|
||||
return (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits)
|
||||
|
||||
@staticmethod
|
||||
def prefill_kv(
|
||||
k_val,
|
||||
v_val,
|
||||
kv_cache,
|
||||
) -> None:
|
||||
assert len(kv_cache) == 3
|
||||
(k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache
|
||||
|
||||
S = cast(tuple[int, ...], k_val.shape)[1]
|
||||
|
||||
k_sw = k_val.swapaxes(1, 2)
|
||||
v_sw = v_val.swapaxes(1, 2)
|
||||
|
||||
k_q, k_s, k_b = mx.quantize(k_sw, group_size=group_size, bits=bits)
|
||||
v_q, v_s, v_b = mx.quantize(v_sw, group_size=group_size, bits=bits)
|
||||
|
||||
k_q_out[..., :S, :] = k_q
|
||||
k_s_out[..., :S, :] = k_s
|
||||
k_b_out[..., :S, :] = k_b
|
||||
|
||||
v_q_out[..., :S, :] = v_q
|
||||
v_s_out[..., :S, :] = v_s
|
||||
v_b_out[..., :S, :] = v_b
|
||||
|
||||
@staticmethod
|
||||
def init_cache(
|
||||
batch_size: int,
|
||||
max_seq_length: int,
|
||||
n_heads: int,
|
||||
head_dim: int,
|
||||
dtype: mx.Dtype,
|
||||
*,
|
||||
group_size: int = 32,
|
||||
bits: int = 8,
|
||||
) -> KVCacheQ:
|
||||
packed_dim = KVCacheHNDQuantized._packed_dim(head_dim, bits=bits)
|
||||
group_cnt = KVCacheHNDQuantized._group_count(head_dim, group_size=group_size)
|
||||
|
||||
packed_shape = (batch_size, n_heads, max_seq_length, packed_dim)
|
||||
group_shape = (batch_size, n_heads, max_seq_length, group_cnt)
|
||||
|
||||
k_q = mx.zeros(packed_shape, dtype=mx.uint32)
|
||||
k_s = mx.zeros(group_shape, dtype=dtype)
|
||||
k_b = mx.zeros(group_shape, dtype=dtype)
|
||||
|
||||
v_q = mx.zeros(packed_shape, dtype=mx.uint32)
|
||||
v_s = mx.zeros(group_shape, dtype=dtype)
|
||||
v_b = mx.zeros(group_shape, dtype=dtype)
|
||||
|
||||
return (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits)
|
||||
|
||||
|
||||
class AttentionABC(ABC, nn.Module):
|
||||
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int, *args, **kwds):
|
||||
super().__init__()
|
||||
|
||||
self.n_head = n_head
|
||||
self.hidden_dim = hidden_dim
|
||||
assert hidden_dim % n_head == 0
|
||||
self.head_dim = hidden_dim // n_head
|
||||
|
||||
self.max_seq_length = max_seq_length
|
||||
|
||||
# key, query, value projections for all heads, but in a batch
|
||||
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
self.scale = 1 / math.sqrt(self.head_dim)
|
||||
|
||||
self.kc_class: KVCacheProtocol
|
||||
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array
|
||||
) -> Array: ...
|
||||
|
||||
def prefill(self, x: Array, kv_cache: KVCache | KVCacheQ, attn_mask: Array):
|
||||
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
|
||||
|
||||
q, k, v = self.in_proj(x).split(3, axis=-1)
|
||||
|
||||
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
|
||||
self.kc_class.prefill_kv(k, v, kv_cache)
|
||||
|
||||
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
||||
|
||||
attn = mx.fast.scaled_dot_product_attention(q, k, v, mask=attn_mask, scale=self.scale)
|
||||
|
||||
attn = mx.nan_to_num(attn)
|
||||
|
||||
attn = attn.swapaxes(1, 2).reshape(1, -1, self.hidden_dim)
|
||||
|
||||
output = self.out_proj(attn)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
|
||||
self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
|
||||
|
||||
def __call__(self, x: Array):
|
||||
return self.linear2(nn.relu(self.linear1(x)))
|
||||
|
||||
|
||||
class TransformerBlockABC(nn.Module):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
self.max_seq_length = max_seq_length
|
||||
|
||||
self.attention: AttentionABC
|
||||
|
||||
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
||||
self.attention_norm = nn.LayerNorm(self.hidden_dim)
|
||||
self.ffn_norm = nn.LayerNorm(self.hidden_dim)
|
||||
|
||||
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
||||
h = self.attention_norm(
|
||||
x
|
||||
+ self.attention(
|
||||
x,
|
||||
input_pos,
|
||||
kv_cache,
|
||||
cache_idx,
|
||||
attn_mask,
|
||||
)
|
||||
)
|
||||
out = self.ffn_norm(h + self.feed_forward(h))
|
||||
return out
|
||||
|
||||
def prefill(self, x: Array, attn_mask: Array, kv_cache: KVCache | KVCacheQ):
|
||||
h = self.attention_norm(
|
||||
x
|
||||
+ self.attention.prefill(
|
||||
x,
|
||||
kv_cache,
|
||||
attn_mask,
|
||||
)
|
||||
)
|
||||
out = self.ffn_norm(h + self.feed_forward(h))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class TransformerDecoderABC(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
n_layer: int,
|
||||
n_head: int,
|
||||
ffn_dim: int,
|
||||
vocab_size: int,
|
||||
max_seq_length: int,
|
||||
max_batch_size: int,
|
||||
*args,
|
||||
**kwds,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
self.n_head = n_head
|
||||
assert hidden_dim % n_head == 0
|
||||
|
||||
self.head_dim = hidden_dim // n_head
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self.n_layer = n_layer
|
||||
|
||||
self.layers: MutableSequence[TransformerBlockABC]
|
||||
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_pos: Array,
|
||||
x: Array,
|
||||
kv_caches: MutableSequence[KVCache | KVCacheQ],
|
||||
cache_idx: Array,
|
||||
*args,
|
||||
**kwds,
|
||||
):
|
||||
for layer, kv_cache in zip(self.layers, kv_caches):
|
||||
x = layer(
|
||||
x,
|
||||
input_pos,
|
||||
kv_cache,
|
||||
cache_idx,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCache | KVCacheQ]):
|
||||
for layer, kv_cache in zip(self.layers, kv_caches):
|
||||
x = layer.prefill(
|
||||
x,
|
||||
mask,
|
||||
kv_cache,
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 2000,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_dim: int = config["model"]["hidden_dim"]
|
||||
embedding_dim: int = config["model"]["embedding_dim"]
|
||||
n_head: int = config["model"]["head"]
|
||||
n_layer: int = config["model"]["n_layer"]
|
||||
vocab_size: int = config["model"]["vocab_size"]
|
||||
phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
|
||||
EOS: int = config["model"]["EOS"]
|
||||
ffn_dim: int = hidden_dim * 4
|
||||
|
||||
self.n_layer = int(n_layer)
|
||||
self.hidden_dim = int(hidden_dim)
|
||||
self.n_head = int(n_head)
|
||||
assert hidden_dim % n_head == 0
|
||||
|
||||
self.head_dim = int(hidden_dim // n_head)
|
||||
self.embedding_dim = int(embedding_dim)
|
||||
self.ffn_dim = int(ffn_dim)
|
||||
self.vocab_size = int(vocab_size)
|
||||
self.phoneme_vocab_size = int(phoneme_vocab_size)
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
self.EOS = EOS
|
||||
assert self.EOS == self.vocab_size - 1
|
||||
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
||||
self.h: TransformerDecoderABC
|
||||
|
||||
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
|
||||
self.ar_text_position = SinePositionalEmbedding(
|
||||
self.embedding_dim,
|
||||
scale=False,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_length,
|
||||
)
|
||||
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
|
||||
self.ar_audio_position = SinePositionalEmbedding(
|
||||
self.embedding_dim,
|
||||
scale=False,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_length,
|
||||
)
|
||||
|
||||
self.kv_class: KVCacheProtocol
|
||||
|
||||
def init_cache(self, bsz: int = 0, *args, **kwds) -> MutableSequence[KVCache | KVCacheQ]:
|
||||
bsz = bsz or self.h.max_batch_size
|
||||
assert bsz <= self.h.max_batch_size
|
||||
seq_lens = self.h.max_seq_length
|
||||
dtype = self.bert_proj.bias.dtype
|
||||
cache: MutableSequence[KVCache | KVCacheQ] = [
|
||||
self.kv_class.init_cache(bsz, seq_lens, self.n_head, self.head_dim, dtype, *args, **kwds)
|
||||
for _ in range(self.n_layer)
|
||||
]
|
||||
mx.eval(cache)
|
||||
return cache
|
||||
|
||||
def embed(
|
||||
self,
|
||||
x: list[Array],
|
||||
y: Array,
|
||||
bert_features: list[Array],
|
||||
):
|
||||
x_len: list[int] = [cast(tuple[int, ...], i.shape)[0] for i in x]
|
||||
x_len_max = max(x_len)
|
||||
xy_pos = mx.zeros((len(x), x_len_max + cast(tuple[int, ...], y.shape)[1], self.embedding_dim)).astype(
|
||||
bert_features[0].dtype
|
||||
)
|
||||
|
||||
bert_features = list(map(lambda x: x.swapaxes(0, 1), bert_features))
|
||||
|
||||
y_len = cast(tuple[int, ...], y.shape)[1]
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
y_pos = self.ar_audio_position.prefill(y_emb)
|
||||
|
||||
for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
|
||||
x_emb = self.ar_text_embedding(x_)
|
||||
bert = self.bert_proj(bert_feature)
|
||||
x_emb = x_emb + bert
|
||||
x_pos = self.ar_text_position.prefill(mx.expand_dims(x_emb, 0))
|
||||
xy_pos[[bs], :len_] = x_pos
|
||||
xy_pos[[bs], len_ : len_ + y_len] = y_pos
|
||||
|
||||
mx.eval(xy_pos)
|
||||
return xy_pos
|
||||
|
||||
def compile(self):
|
||||
setattr(self.h, "__call__", mx.compile(self.h.__call__))
|
||||
# setattr(self.h, "prefill", mx.compile(self.h.prefill, shapeless=True))
|
||||
|
||||
def pre_forward(self, session: T2SSessionMLX):
|
||||
attn_mask = session.attn_mask
|
||||
return list(), dict(attn_mask=attn_mask)
|
||||
|
||||
def post_forward(self, idx: int, session: T2SSessionMLX) -> None:
|
||||
if idx == 0:
|
||||
prefill_len = session.prefill_len
|
||||
bsz = session.bsz
|
||||
|
||||
range_tensor = mx.arange(self.max_seq_length).reshape(1, 1, 1, self.max_seq_length)
|
||||
prefill_len_expanded = prefill_len.reshape(bsz, 1, 1, 1)
|
||||
attn_mask = range_tensor < prefill_len_expanded
|
||||
attn_mask = mx.repeat(attn_mask, self.n_head, 1)
|
||||
|
||||
session.attn_mask = attn_mask
|
||||
|
||||
attn_mask = session.attn_mask
|
||||
input_pos = session.input_pos
|
||||
attn_mask[mx.arange(session.bsz), :, :, input_pos] = True
|
||||
mx.eval(attn_mask)
|
30
GPT_SoVITS/Accelerate/PyTorch/__init__.py
Normal file
30
GPT_SoVITS/Accelerate/PyTorch/__init__.py
Normal file
@ -0,0 +1,30 @@
|
||||
import importlib.util
|
||||
|
||||
import torch
|
||||
|
||||
from .sample_funcs import sample_naive
|
||||
from .structs import T2SRequest, T2SResult
|
||||
from .t2s_engine import T2SEngine as T2SEngineTorch
|
||||
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
backends = ["torch_varlen"]
|
||||
if torch.cuda.is_available():
|
||||
backends.append("torch_static_cuda_graph")
|
||||
# if importlib.util.find_spec("sageattention") is not None:
|
||||
# for i in range(torch.cuda.device_count()):
|
||||
# major, minor = torch.cuda.get_device_capability(i)
|
||||
# sm_version = major + minor / 10.0
|
||||
# if sm_version >= 7.0:
|
||||
# backends.append("sage_attn_varlen_cuda_graph")
|
||||
if importlib.util.find_spec("flash_attn") is not None:
|
||||
for i in range(torch.cuda.device_count()):
|
||||
major, minor = torch.cuda.get_device_capability(i)
|
||||
sm_version = major + minor / 10.0
|
||||
if sm_version >= 7.5:
|
||||
backends.append("flash_attn_varlen_cuda_graph")
|
||||
# if torch.mps.is_available():
|
||||
# backends.append("mps_flash_attn_varlen")
|
||||
|
||||
|
||||
__all__ = ["T2SEngineTorch", "T2SRequest", "sample_naive", "T2SResult", "backends"]
|
@ -0,0 +1,158 @@
|
||||
"""
|
||||
Modified From https://github.com/XXXXRT666/GPT-SoVITS
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import kernels
|
||||
import torch
|
||||
|
||||
from .. import nn
|
||||
from ..structs import T2SSession
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
CUDAGraphCacheABC,
|
||||
FeedForward,
|
||||
KVCacheNHD,
|
||||
KVCacheProtocol,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
flash_attn_kernel = None
|
||||
try:
|
||||
import flash_attn_interface as flash_attn # type: ignore
|
||||
|
||||
flash_attn_kernel = flash_attn.flash_attn_with_kvcache
|
||||
except ModuleNotFoundError:
|
||||
try:
|
||||
import flash_attn # type: ignore
|
||||
|
||||
flash_attn_kernel = flash_attn.flash_attn_with_kvcache
|
||||
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
if flash_attn_kernel is None:
|
||||
flash_attn_kernel = kernels.get_kernel("kernels-community/flash-attn").flash_attn_with_kvcache
|
||||
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head, hidden_dim, max_seq_length):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
|
||||
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds) -> Tensor:
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
||||
|
||||
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
|
||||
attn: Tensor = flash_attn.flash_attn_with_kvcache( # type: ignore
|
||||
q, kv_cache.k_cache, kv_cache.v_cache, k, v, cache_seqlens=input_pos - 1
|
||||
)
|
||||
|
||||
attn = attn.view(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head, ffn_dim, hidden_dim, max_seq_length) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length)
|
||||
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
||||
self.attention_norm = nn.LayerNorm([self.hidden_dim])
|
||||
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
) -> None:
|
||||
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
|
||||
|
||||
self.layers = nn.ModuleList( # type: ignore
|
||||
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
|
||||
)
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=2000,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
assert torch.cuda.is_available()
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
||||
self.h: TransformerDecoderABC = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheNHD
|
||||
|
||||
def post_forward(self, idx: int, session: T2SSession) -> None:
|
||||
return super().post_forward(idx, session)
|
||||
|
||||
def pre_forward(self, session: T2SSession) -> Tuple[List, Dict]:
|
||||
return super().pre_forward(session)
|
||||
|
||||
|
||||
class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoder: T2SDecoder,
|
||||
) -> None:
|
||||
self.is_applicable = True
|
||||
super().__init__(decoder)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
if session.id == self.id:
|
||||
self.assigned = False
|
||||
else:
|
||||
del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache
|
||||
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
assert self.graph
|
||||
session.graph = self.graph
|
||||
session.stream = self.stream
|
||||
|
||||
session.xy_pos_ = self.xy_pos
|
||||
session.xy_dec_ = self.xy_dec
|
||||
session.input_pos = self.input_pos.copy_(session.input_pos)
|
||||
|
||||
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
|
||||
cache.sync_cache(cache_)
|
||||
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
session.xy_pos_ = self.xy_pos.clone()
|
||||
session.xy_dec_ = self.xy_dec.clone()
|
||||
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
|
||||
|
||||
args, kwds = self.decoder.pre_forward(session)
|
||||
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
|
||||
session.graph = graph
|
||||
session.stream = torch.cuda.Stream() # type: ignore
|
166
GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py
Normal file
166
GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py
Normal file
@ -0,0 +1,166 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .. import nn
|
||||
from ..structs import KVCacheProtocol, T2SSession
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
CUDAGraphCacheABC,
|
||||
FeedForward,
|
||||
KVCacheHND,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head, hidden_dim, max_seq_length):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
|
||||
# key, query, value projections for all heads, but in a batch
|
||||
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
||||
|
||||
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
|
||||
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
||||
|
||||
k, v = kv_cache.update(input_pos, k, v)
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
|
||||
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length)
|
||||
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
||||
self.attention_norm = nn.LayerNorm([self.hidden_dim])
|
||||
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
) -> None:
|
||||
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
|
||||
|
||||
self.layers = nn.ModuleList( # type: ignore
|
||||
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
|
||||
)
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=2000,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
||||
self.h: TransformerDecoderABC = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHND
|
||||
|
||||
def pre_forward(self, session: T2SSession):
|
||||
attn_mask = session.attn_mask
|
||||
return list(), dict(attn_mask=attn_mask)
|
||||
|
||||
def post_forward(self, idx: int, session: T2SSession) -> None:
|
||||
if idx == 0:
|
||||
prefill_len = session.prefill_len
|
||||
bsz = session.bsz
|
||||
|
||||
range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
|
||||
prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
|
||||
attn_mask = range_tensor < prefill_len_expanded
|
||||
attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
|
||||
|
||||
session.attn_mask = attn_mask
|
||||
|
||||
attn_mask = session.attn_mask
|
||||
input_pos = session.input_pos
|
||||
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
|
||||
|
||||
|
||||
class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoder,
|
||||
) -> None:
|
||||
self.is_applicable = False
|
||||
super().__init__(decoder)
|
||||
if torch.cuda.is_available():
|
||||
self.attn_mask = (
|
||||
torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
|
||||
.bool()
|
||||
.to(self.device, self.dtype)
|
||||
)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
if session.id == self.id:
|
||||
self.assigned = False
|
||||
else:
|
||||
del (
|
||||
session.graph,
|
||||
session.xy_pos_,
|
||||
session.xy_dec_,
|
||||
session.input_pos,
|
||||
session.kv_cache,
|
||||
session.attn_mask,
|
||||
)
|
||||
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
assert self.graph
|
||||
session.graph = self.graph
|
||||
session.stream = self.stream
|
||||
|
||||
session.xy_pos_ = self.xy_pos
|
||||
session.xy_dec_ = self.xy_dec
|
||||
session.input_pos = self.input_pos.copy_(session.input_pos)
|
||||
|
||||
session.attn_mask = self.attn_mask
|
||||
|
||||
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
|
||||
cache.sync_cache(cache_)
|
||||
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
session.xy_pos_ = self.xy_pos.clone()
|
||||
session.xy_dec_ = self.xy_dec.clone()
|
||||
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
|
||||
|
||||
session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)
|
||||
|
||||
args, kwds = self.decoder.pre_forward(session)
|
||||
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
|
||||
session.graph = graph
|
||||
session.stream = torch.cuda.Stream() # type: ignore
|
@ -0,0 +1,175 @@
|
||||
import sageattention # type: ignore
|
||||
import torch
|
||||
|
||||
from .. import nn
|
||||
from ..structs import T2SSession
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
CUDAGraphCacheABC,
|
||||
FeedForward,
|
||||
KVCacheHND,
|
||||
KVCacheProtocol,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head, hidden_dim, max_seq_length):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
|
||||
# key, query, value projections for all heads, but in a batch
|
||||
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: Tensor,
|
||||
input_pos: Tensor,
|
||||
kv_cache: KVCacheProtocol,
|
||||
cu_seqlens_q: Tensor,
|
||||
cu_seqlens_kv: Tensor,
|
||||
) -> Tensor:
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
||||
|
||||
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
|
||||
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
||||
|
||||
k, v = kv_cache.update(input_pos, k, v)
|
||||
|
||||
attn: Tensor = sageattention.sageattn_varlen(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_kv=cu_seqlens_kv,
|
||||
max_seqlen_q=1,
|
||||
max_seqlen_k=self.max_seq_length,
|
||||
)
|
||||
|
||||
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head, ffn_dim, hidden_dim, max_seq_length) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length)
|
||||
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
||||
self.attention_norm = nn.LayerNorm([self.hidden_dim])
|
||||
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
) -> None:
|
||||
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
|
||||
|
||||
self.layers = nn.ModuleList( # type: ignore
|
||||
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
|
||||
)
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=2000,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
||||
self.h: TransformerDecoderABC = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHND
|
||||
|
||||
def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
|
||||
return list(), dict(cu_seqlens_q=session.cu_seqlens_q, cu_seqlens_kv=session.cu_seqlens_kv)
|
||||
|
||||
def post_forward(self, idx: int, session: T2SSession):
|
||||
if idx == 0:
|
||||
session.cu_seqlens_q = torch.arange(0, session.bsz + 1, dtype=torch.int32)
|
||||
session.cu_seqlens_kv = torch.cat([torch.tensor(0, dtype=torch.int32), session.input_pos])
|
||||
else:
|
||||
cu_seqlens_q = session.cu_seqlens_q
|
||||
cu_seqlens_kv = session.cu_seqlens_kv
|
||||
cu_seqlens_kv.add_(cu_seqlens_q)
|
||||
|
||||
|
||||
class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoder: T2SDecoder,
|
||||
) -> None:
|
||||
self.is_applicable = False
|
||||
super().__init__(decoder)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.cu_seqlens_q = torch.arange(0, decoder.max_batch_size + 1, dtype=torch.int32).to(self.device)
|
||||
self.cu_seqlens_kv = torch.cat([torch.tensor(0, dtype=torch.int32), self.input_pos]).to(self.device)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
if session.id == self.id:
|
||||
self.assigned = False
|
||||
else:
|
||||
del (
|
||||
session.graph,
|
||||
session.xy_pos_,
|
||||
session.xy_dec_,
|
||||
session.input_pos,
|
||||
session.kv_cache,
|
||||
session.cu_seqlens_q,
|
||||
session.cu_seqlens_kv,
|
||||
)
|
||||
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
assert self.graph
|
||||
session.graph = self.graph
|
||||
session.stream = self.stream
|
||||
|
||||
session.xy_pos_ = self.xy_pos
|
||||
session.xy_dec_ = self.xy_dec
|
||||
session.input_pos = self.input_pos.copy_(session.input_pos)
|
||||
|
||||
session.cu_seqlens_q = self.cu_seqlens_q
|
||||
session.cu_seqlens_kv = self.cu_seqlens_kv
|
||||
|
||||
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
|
||||
cache.sync_cache(cache_)
|
||||
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
session.xy_pos_ = self.xy_pos.clone()
|
||||
session.xy_dec_ = self.xy_dec.clone()
|
||||
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
|
||||
|
||||
session.cu_seqlens_q = self.cu_seqlens_q.clone().copy_(session.cu_seqlens_q)
|
||||
session.cu_seqlens_kv = self.cu_seqlens_kv.clone().copy_(session.cu_seqlens_kv)
|
||||
|
||||
args, kwds = self.decoder.pre_forward(session)
|
||||
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
|
||||
session.graph = graph
|
||||
session.stream = torch.cuda.Stream() # type: ignore
|
@ -0,0 +1,166 @@
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .. import nn
|
||||
from ..structs import KVCacheProtocol, T2SSession
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
CUDAGraphCacheABC,
|
||||
FeedForward,
|
||||
KVCacheHND,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head, hidden_dim, max_seq_length):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
|
||||
# key, query, value projections for all heads, but in a batch
|
||||
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
||||
|
||||
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
|
||||
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
||||
|
||||
k, v = kv_cache.update(input_pos, k, v)
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
|
||||
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length)
|
||||
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
||||
self.attention_norm = nn.LayerNorm([self.hidden_dim])
|
||||
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
) -> None:
|
||||
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
|
||||
|
||||
self.layers = nn.ModuleList( # type: ignore
|
||||
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
|
||||
)
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=2000,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
||||
self.h: TransformerDecoderABC = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHND
|
||||
|
||||
def pre_forward(self, session: T2SSession):
|
||||
attn_mask = session.attn_mask
|
||||
return list(), dict(attn_mask=attn_mask)
|
||||
|
||||
def post_forward(self, idx: int, session: T2SSession) -> None:
|
||||
if idx == 0:
|
||||
prefill_len = session.prefill_len
|
||||
bsz = session.bsz
|
||||
|
||||
range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
|
||||
prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
|
||||
attn_mask = range_tensor < prefill_len_expanded
|
||||
attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
|
||||
|
||||
session.attn_mask = attn_mask
|
||||
|
||||
attn_mask = session.attn_mask
|
||||
input_pos = session.input_pos
|
||||
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
|
||||
|
||||
|
||||
class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoder,
|
||||
) -> None:
|
||||
self.is_applicable = True
|
||||
super().__init__(decoder)
|
||||
if torch.cuda.is_available():
|
||||
self.attn_mask = (
|
||||
torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
|
||||
.bool()
|
||||
.to(self.device, self.dtype)
|
||||
)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
if session.id == self.id:
|
||||
self.assigned = False
|
||||
else:
|
||||
del (
|
||||
session.graph,
|
||||
session.xy_pos_,
|
||||
session.xy_dec_,
|
||||
session.input_pos,
|
||||
session.kv_cache,
|
||||
session.attn_mask,
|
||||
)
|
||||
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
assert self.graph
|
||||
session.graph = self.graph
|
||||
session.stream = self.stream
|
||||
|
||||
session.xy_pos_ = self.xy_pos
|
||||
session.xy_dec_ = self.xy_dec
|
||||
session.input_pos = self.input_pos.copy_(session.input_pos)
|
||||
|
||||
session.attn_mask = self.attn_mask
|
||||
|
||||
for cache, cache_ in zip(self.kv_cache, session.kv_cache):
|
||||
cache.sync_cache(cache_)
|
||||
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
session.xy_pos_ = self.xy_pos.clone()
|
||||
session.xy_dec_ = self.xy_dec.clone()
|
||||
session.input_pos = self.input_pos.clone().copy_(session.input_pos)
|
||||
|
||||
session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)
|
||||
|
||||
args, kwds = self.decoder.pre_forward(session)
|
||||
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
|
||||
session.graph = graph
|
||||
session.stream = torch.cuda.Stream() # type: ignore
|
145
GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py
Normal file
145
GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py
Normal file
@ -0,0 +1,145 @@
|
||||
from typing import NoReturn
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .. import nn
|
||||
from ..structs import KVCacheProtocol, T2SSession
|
||||
from ..t2s_model_abc import (
|
||||
AttentionABC,
|
||||
CUDAGraphCacheABC,
|
||||
FeedForward,
|
||||
KVCacheHNDVarlen,
|
||||
T2SDecoderABC,
|
||||
TransformerBlockABC,
|
||||
TransformerDecoderABC,
|
||||
)
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class Attention(AttentionABC):
|
||||
def __init__(self, n_head, hidden_dim, max_seq_length):
|
||||
super().__init__(n_head, hidden_dim, max_seq_length)
|
||||
|
||||
# key, query, value projections for all heads, but in a batch
|
||||
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
||||
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
||||
|
||||
q = q.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
k = k.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
v = v.view(bsz, seqlen, self.n_head, self.head_dim)
|
||||
|
||||
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
||||
|
||||
k, v = kv_cache.update(input_pos, k, v)
|
||||
|
||||
max_idx = input_pos.max()
|
||||
|
||||
q, k, v = map(lambda x: x[..., :max_idx, :], (q, k, v))
|
||||
|
||||
mask = attn_mask[..., :max_idx]
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v, mask)
|
||||
|
||||
attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)
|
||||
|
||||
attn = self.out_proj(attn)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class TransformerBlock(TransformerBlockABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
|
||||
super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
|
||||
|
||||
self.attention = Attention(n_head, hidden_dim, max_seq_length)
|
||||
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
||||
self.attention_norm = nn.LayerNorm([self.hidden_dim])
|
||||
self.ffn_norm = nn.LayerNorm([self.hidden_dim])
|
||||
|
||||
|
||||
class TransformerDecoder(TransformerDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim,
|
||||
n_layer,
|
||||
n_head,
|
||||
ffn_dim,
|
||||
vocab_size,
|
||||
max_seq_length,
|
||||
max_batch_size,
|
||||
) -> None:
|
||||
super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)
|
||||
|
||||
self.layers = nn.ModuleList( # type: ignore
|
||||
TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
|
||||
)
|
||||
|
||||
|
||||
class T2SDecoder(T2SDecoderABC):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_length=2000,
|
||||
max_batch_size=10,
|
||||
) -> None:
|
||||
super().__init__(config, max_seq_length, max_batch_size)
|
||||
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
||||
self.h: TransformerDecoderABC = TransformerDecoder(
|
||||
self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
|
||||
)
|
||||
|
||||
self.kv_class = KVCacheHNDVarlen
|
||||
|
||||
def capture(
|
||||
self,
|
||||
*args,
|
||||
**kwds,
|
||||
) -> NoReturn:
|
||||
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
|
||||
|
||||
def pre_forward(self, session: T2SSession):
|
||||
attn_mask = session.attn_mask
|
||||
return list(), dict(attn_mask=attn_mask)
|
||||
|
||||
def post_forward(self, idx: int, session: T2SSession) -> None:
|
||||
if idx == 0:
|
||||
prefill_len = session.prefill_len
|
||||
bsz = session.bsz
|
||||
|
||||
range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
|
||||
prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
|
||||
attn_mask = range_tensor < prefill_len_expanded
|
||||
attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)
|
||||
|
||||
session.attn_mask = attn_mask
|
||||
|
||||
attn_mask = session.attn_mask
|
||||
input_pos = session.input_pos
|
||||
attn_mask[torch.arange(session.bsz), :, :, input_pos] = True
|
||||
|
||||
|
||||
class CUDAGraphCache(CUDAGraphCacheABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoder,
|
||||
) -> None:
|
||||
self.is_applicable = False
|
||||
super().__init__(decoder)
|
||||
|
||||
def release_graph(self, session: T2SSession):
|
||||
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
|
||||
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
|
||||
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
|
69
GPT_SoVITS/Accelerate/PyTorch/nn.py
Normal file
69
GPT_SoVITS/Accelerate/PyTorch/nn.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""
|
||||
Enhanced Type Hint nn.Module
|
||||
Modified From https://github.com/labmlai/labml/blob/master/helpers/labml_helpers/module.py
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch.nn
|
||||
from torch.nn import (
|
||||
functional as functional,
|
||||
)
|
||||
from torch.nn import (
|
||||
utils as utils,
|
||||
)
|
||||
from torch.nn.modules import * # type: ignore # noqa: F403
|
||||
from torch.nn.parameter import (
|
||||
Parameter as Parameter,
|
||||
)
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class Module(torch.nn.Module):
|
||||
r"""
|
||||
Wraps ``torch.nn.Module`` to overload ``__call__`` instead of
|
||||
``forward`` for better type checking.
|
||||
|
||||
`PyTorch Github issue for clarification <https://github.com/pytorch/pytorch/issues/44605>`_
|
||||
"""
|
||||
|
||||
def _forward_unimplemented(self, *input: Any) -> None:
|
||||
# To stop PyTorch from giving abstract methods warning
|
||||
pass
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
if cls.__dict__.get("__call__", None) is None:
|
||||
return
|
||||
|
||||
setattr(cls, "forward", cls.__dict__["__call__"])
|
||||
delattr(cls, "__call__")
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
params = self.parameters()
|
||||
try:
|
||||
sample_param = next(params)
|
||||
return sample_param.device
|
||||
except StopIteration:
|
||||
raise RuntimeError(f"Unable to determine device of {self.__class__.__name__}") from None
|
||||
|
||||
|
||||
class Linear(torch.nn.Linear):
|
||||
def __call__(self, input: Tensor) -> Tensor:
|
||||
return super().__call__(input)
|
||||
|
||||
|
||||
class Dropout(torch.nn.Dropout):
|
||||
def __call__(self, input: Tensor) -> Tensor:
|
||||
return super().__call__(input)
|
||||
|
||||
|
||||
class Embedding(torch.nn.Embedding):
|
||||
def __call__(self, input: Tensor) -> Tensor:
|
||||
return super().__call__(input)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
def __call__(self, input: Tensor) -> Tensor:
|
||||
return super().__call__(input)
|
67
GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py
Normal file
67
GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py
Normal file
@ -0,0 +1,67 @@
|
||||
from typing import Protocol
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class SampleProtocol(Protocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
logits: Tensor,
|
||||
previous_tokens: Tensor,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
repetition_penalty: float,
|
||||
) -> Tensor: ...
|
||||
|
||||
|
||||
class sample_naive(SampleProtocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
logits: Tensor,
|
||||
previous_tokens: Tensor,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
repetition_penalty: float,
|
||||
):
|
||||
if temperature <= 1e-5:
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
return torch.argmax(probs, dim=-1, keepdim=True).to(dtype=torch.int32)
|
||||
|
||||
if repetition_penalty != 1.0:
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
cum_probs[cum_probs > 1] = 1
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
|
||||
if temperature < 1.0:
|
||||
logits /= temperature
|
||||
|
||||
v, _ = torch.topk(logits, top_k)
|
||||
pivot = v[:, -1].unsqueeze(-1)
|
||||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
q = -torch.log(torch.rand_like(probs))
|
||||
idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)
|
||||
|
||||
return idx_next
|
151
GPT_SoVITS/Accelerate/PyTorch/structs.py
Normal file
151
GPT_SoVITS/Accelerate/PyTorch/structs.py
Normal file
@ -0,0 +1,151 @@
|
||||
"""
|
||||
Modified From https://github.com/XXXXRT666/GPT-SoVITS
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, MutableSequence, Optional, Protocol
|
||||
|
||||
import torch
|
||||
|
||||
from .sample_funcs import SampleProtocol, sample_naive
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class T2SResult:
|
||||
result: list[Tensor] | None = None
|
||||
infer_speed: tuple[float, float] = (0.0, 0.0)
|
||||
status: Literal["Success", "Error"] = "Success"
|
||||
exception: Optional[Exception] = None
|
||||
traceback: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class T2SRequest:
|
||||
x: list[torch.Tensor]
|
||||
x_lens: Tensor
|
||||
prompts: torch.Tensor
|
||||
bert_feature: list[Tensor]
|
||||
valid_length: int
|
||||
top_k: int = 5
|
||||
top_p: float = 1
|
||||
early_stop_num: int = -1
|
||||
temperature: float = 1.0
|
||||
repetition_penalty: float = 1.35
|
||||
use_cuda_graph: bool = False
|
||||
debug: bool = False
|
||||
|
||||
|
||||
class KVCacheProtocol(Protocol):
|
||||
k_cache: Tensor
|
||||
v_cache: Tensor
|
||||
|
||||
def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None: ...
|
||||
|
||||
def empty(self) -> None: ...
|
||||
|
||||
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> tuple[Tensor, Tensor]: ...
|
||||
|
||||
def prefill_kv(self, k_val: Tensor, v_val: Tensor) -> None: ...
|
||||
|
||||
def sync_cache(self, kv_cache: KVCacheProtocol) -> None: ...
|
||||
|
||||
|
||||
class T2SDecoderProtocol(Protocol):
|
||||
max_seq_length: int
|
||||
EOS: int
|
||||
n_head: int
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device: ...
|
||||
|
||||
def embed(self, x: list[Tensor], y: Tensor, bert_features: list[Tensor]) -> Tensor: ...
|
||||
|
||||
|
||||
class T2SEngineProtocol(Protocol):
|
||||
def _handle_request(self, request: T2SRequest) -> tuple[list[Tensor], float, float]: ...
|
||||
|
||||
def generate(self, request: T2SRequest) -> T2SResult: ...
|
||||
|
||||
|
||||
class T2SSession:
|
||||
def __init__(
|
||||
self,
|
||||
decoder: T2SDecoderProtocol,
|
||||
request: T2SRequest,
|
||||
sapmle_func: type[SampleProtocol] = sample_naive,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
with device:
|
||||
self.decoder = decoder
|
||||
self.request = request
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
bsz = len(request.x)
|
||||
y_len = request.prompts.size(-1)
|
||||
self.bsz = bsz
|
||||
self.y_len = y_len
|
||||
request.prompts = request.prompts.to(device, torch.int32)
|
||||
|
||||
# Cache
|
||||
self.kv_cache: MutableSequence[KVCacheProtocol]
|
||||
self.sample = sapmle_func()
|
||||
|
||||
# Forward args
|
||||
self.x = [i.to(device) for i in request.x]
|
||||
self.x_lens = request.x_lens.to(torch.int32)
|
||||
self.y = torch.zeros((bsz, decoder.max_seq_length)).to(torch.int32)
|
||||
self.y[:, : request.prompts.shape[-1]] = request.prompts
|
||||
self.bert_feature = [i.to(device, dtype) for i in request.bert_feature]
|
||||
|
||||
self.prefill_len = self.x_lens + request.prompts.size(1)
|
||||
|
||||
self.input_pos = torch.zeros_like(self.prefill_len)
|
||||
self.input_pos.add_(self.prefill_len)
|
||||
|
||||
# CUDA Graph
|
||||
self.stream: Optional[torch.cuda.Stream] = None
|
||||
self.graph: Optional[torch.cuda.CUDAGraph] = None
|
||||
self.xy_pos_: Tensor
|
||||
self.xy_dec_: Tensor
|
||||
|
||||
# EOS
|
||||
self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
|
||||
self.y_results: list[Tensor] = [None] * len(self.x) # type: ignore
|
||||
|
||||
self.xy_pos = decoder.embed(self.x, request.prompts, self.bert_feature)
|
||||
|
||||
max_len = int(self.prefill_len.max().item())
|
||||
attn_mask = torch.zeros(size=(bsz, max_len, max_len), dtype=torch.bool)
|
||||
|
||||
for bs in range(bsz):
|
||||
pos = int(self.x_lens[bs])
|
||||
seq_len = pos + y_len
|
||||
|
||||
attn_mask[bs, :seq_len, :pos] = True
|
||||
|
||||
ar_mask = ~torch.triu(
|
||||
input=torch.ones(
|
||||
size=(
|
||||
y_len,
|
||||
y_len,
|
||||
),
|
||||
dtype=torch.bool,
|
||||
),
|
||||
diagonal=1,
|
||||
)
|
||||
attn_mask[bs, pos:seq_len, pos:seq_len] = ar_mask
|
||||
|
||||
self.attn_mask = attn_mask
|
||||
self.attn_mask = attn_mask.unsqueeze(0).expand(-1, decoder.n_head, -1, -1)
|
||||
|
||||
self.id: int = -1
|
||||
|
||||
# Sage Attn & Transformer Engine Impl
|
||||
self.cu_seqlens_q: Tensor
|
||||
self.cu_seqlens_kv: Tensor
|
223
GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py
Normal file
223
GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py
Normal file
@ -0,0 +1,223 @@
|
||||
import contextlib
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from importlib import import_module
|
||||
|
||||
import torch
|
||||
from rich.progress import BarColumn, Progress, TextColumn
|
||||
|
||||
from ..logger import SpeedColumnToken, console, logger
|
||||
from .structs import T2SEngineProtocol, T2SRequest, T2SResult, T2SSession
|
||||
from .t2s_model_abc import (
|
||||
CUDAGraphCacheABC,
|
||||
T2SDecoderABC,
|
||||
TorchProfiler,
|
||||
)
|
||||
|
||||
|
||||
class T2SEngine(T2SEngineProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
decoder_model: T2SDecoderABC,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
dtype: torch.dtype = torch.float32,
|
||||
) -> None:
|
||||
assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"}
|
||||
assert dtype in {torch.float16, torch.bfloat16, torch.float32}
|
||||
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
|
||||
|
||||
self.graphcache: CUDAGraphCacheABC = self.init_cache()
|
||||
|
||||
def _handle_request(self, request: T2SRequest):
|
||||
with self.device:
|
||||
decoder = self.decoder_model
|
||||
session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
|
||||
batch_idx = torch.arange(session.bsz)
|
||||
|
||||
t1 = 0.0
|
||||
infer_speed = 0.0
|
||||
infer_time = 0.0
|
||||
|
||||
torch_profiler = TorchProfiler(request.debug)
|
||||
with (
|
||||
torch_profiler.profiler(),
|
||||
Progress(
|
||||
TextColumn("[cyan]{task.description}"),
|
||||
BarColumn(),
|
||||
TextColumn("{task.completed}/{task.total} tokens"),
|
||||
SpeedColumnToken(show_speed=True),
|
||||
console=console,
|
||||
transient=True,
|
||||
) as progress,
|
||||
):
|
||||
max_token = int(min(2000 - session.input_pos.max(), 1500))
|
||||
task = progress.add_task("T2S Decoding", total=max_token)
|
||||
|
||||
for idx in range(max_token):
|
||||
progress.update(task, advance=1)
|
||||
if idx == 0:
|
||||
session.kv_cache = decoder.init_cache(session.bsz)
|
||||
xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
|
||||
xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
|
||||
else:
|
||||
if (
|
||||
request.use_cuda_graph
|
||||
and session.graph is None
|
||||
and self.graphcache.is_applicable
|
||||
and torch.cuda.is_available()
|
||||
):
|
||||
self.graphcache.assign_graph(session)
|
||||
|
||||
with torch_profiler.record("AR"):
|
||||
if session.graph:
|
||||
assert session.stream
|
||||
session.stream.wait_stream(torch.cuda.default_stream())
|
||||
with torch.cuda.stream(session.stream):
|
||||
session.xy_pos_.copy_(session.xy_pos)
|
||||
session.graph.replay()
|
||||
xy_dec = session.xy_dec_.clone()
|
||||
else:
|
||||
args, kwds = decoder.pre_forward(session)
|
||||
xy_dec = decoder.h(
|
||||
session.input_pos,
|
||||
session.xy_pos,
|
||||
session.kv_cache,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
|
||||
with torch.cuda.stream(session.stream) if session.stream is not None else contextlib.nullcontext():
|
||||
decoder.post_forward(idx, session)
|
||||
logits = decoder.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if idx == 0:
|
||||
logits[:, -1] = float("-inf")
|
||||
|
||||
with torch_profiler.record("Sampling"):
|
||||
samples = session.sample(
|
||||
logits=logits,
|
||||
previous_tokens=session.y[:, : session.y_len + idx],
|
||||
top_k=request.top_k,
|
||||
top_p=request.top_p,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
)
|
||||
session.y[batch_idx, session.y_len + idx] = samples
|
||||
session.input_pos.add_(1)
|
||||
|
||||
with torch_profiler.record("EOS"):
|
||||
argmax_token = torch.argmax(logits, dim=-1)
|
||||
sample_token = samples.squeeze(1)
|
||||
EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
|
||||
|
||||
newly_done_mask = EOS_mask & (~session.completed)
|
||||
newly_done_indices = newly_done_mask.nonzero()
|
||||
|
||||
if newly_done_indices.numel() > 0:
|
||||
for i in newly_done_indices:
|
||||
session.y_results[i] = session.y[i, session.y_len : session.y_len + idx]
|
||||
session.completed[newly_done_indices] = True
|
||||
|
||||
if torch.all(session.completed).item():
|
||||
if session.y[:, session.y_len :].sum() == 0:
|
||||
session.y_results = [torch.tensor(0) for _ in range(session.bsz)]
|
||||
logger.error("Bad Zero Prediction")
|
||||
else:
|
||||
logger.info(
|
||||
f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
|
||||
)
|
||||
logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
|
||||
infer_time = time.perf_counter() - t1
|
||||
infer_speed = (idx - 1) / infer_time
|
||||
break
|
||||
|
||||
if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
|
||||
for i in range(session.bsz):
|
||||
if not session.completed[i].item():
|
||||
session.y_results[i] = session.y[i, session.y_len : session.y_len + 1499]
|
||||
session.completed[i] = True
|
||||
logger.error("Bad Full Prediction")
|
||||
break
|
||||
|
||||
with torch_profiler.record("NextPos"):
|
||||
y_emb = decoder.ar_audio_embedding(samples)
|
||||
session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
|
||||
|
||||
if idx == 1:
|
||||
torch_profiler.start()
|
||||
t1 = time.perf_counter()
|
||||
|
||||
if idx == 51:
|
||||
torch_profiler.end()
|
||||
|
||||
if idx % 100 == 0:
|
||||
match session.device.type:
|
||||
case "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
case "mps":
|
||||
torch.mps.empty_cache()
|
||||
case "xpu":
|
||||
torch.xpu.empty_cache()
|
||||
case "mtia":
|
||||
torch.mtia.empty_cache()
|
||||
|
||||
match session.device.type:
|
||||
case "cuda":
|
||||
if session.stream is not None:
|
||||
torch.cuda.current_stream().wait_stream(session.stream)
|
||||
torch.cuda.empty_cache()
|
||||
case "mps":
|
||||
torch.mps.empty_cache()
|
||||
case "xpu":
|
||||
torch.xpu.empty_cache()
|
||||
case "mtia":
|
||||
torch.mtia.empty_cache()
|
||||
case "cpu":
|
||||
gc.collect()
|
||||
|
||||
torch_profiler.end()
|
||||
if request.use_cuda_graph and torch.cuda.is_available():
|
||||
self.graphcache.release_graph(session)
|
||||
|
||||
return session.y_results[: request.valid_length], infer_speed, infer_time
|
||||
|
||||
def generate(self, request: T2SRequest):
|
||||
try:
|
||||
result, infer_speed, infer_time = self._handle_request(request)
|
||||
t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
|
||||
except Exception as e:
|
||||
t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
|
||||
return t2s_result
|
||||
|
||||
@staticmethod
|
||||
def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "Flash-Attn-Varlen-CUDAGraph"):
|
||||
logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
|
||||
module_path = f".backends.{backend.lower().replace('-', '_').replace('cudagraph', 'cuda_graph')}"
|
||||
decoder_cls_name = "T2SDecoder"
|
||||
decoder_mod = import_module(module_path, package=__package__)
|
||||
decoder_cls: type[T2SDecoderABC] = getattr(decoder_mod, decoder_cls_name)
|
||||
dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
|
||||
config = dict_s1["config"]
|
||||
decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
|
||||
state_dict = dict_s1["weight"]
|
||||
decoder.load_state_dict(state_dict)
|
||||
|
||||
return decoder.eval()
|
||||
|
||||
def init_cache(self):
|
||||
assert self.decoder_model
|
||||
|
||||
module_name = self.decoder_model.__class__.__module__
|
||||
module = sys.modules.get(module_name)
|
||||
assert module
|
||||
|
||||
target_class: type[CUDAGraphCacheABC] = getattr(module, "CUDAGraphCache")
|
||||
|
||||
return target_class(self.decoder_model)
|
672
GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py
Normal file
672
GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py
Normal file
@ -0,0 +1,672 @@
|
||||
"""
|
||||
Modified From https://github.com/XXXXRT666/GPT-SoVITS
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import nullcontext
|
||||
from typing import MutableSequence
|
||||
|
||||
import torch
|
||||
import torch._inductor.config
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.graphs import CUDAGraph
|
||||
from torch.profiler import ProfilerAction, tensorboard_trace_handler
|
||||
|
||||
from . import nn
|
||||
from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession
|
||||
|
||||
Tensor = torch.Tensor
|
||||
|
||||
|
||||
class TokenEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
|
||||
|
||||
@property
|
||||
def weight(self) -> Tensor:
|
||||
return self.word_embeddings.weight
|
||||
|
||||
def embedding(self, index: int) -> Tensor:
|
||||
return self.word_embeddings.weight[index : index + 1]
|
||||
|
||||
def __call__(self, x: Tensor):
|
||||
x = self.word_embeddings(x)
|
||||
return x
|
||||
|
||||
|
||||
class SinePositionalEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
scale: bool = False,
|
||||
alpha: bool = False,
|
||||
max_batch_size: int = 10,
|
||||
max_seq_len: int = 2000,
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
|
||||
self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
self.reverse = False
|
||||
self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
|
||||
self.pe: torch.Tensor
|
||||
self.compute_pe()
|
||||
|
||||
def compute_pe(self):
|
||||
"""Reset the positional encodings."""
|
||||
if self.reverse:
|
||||
position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
|
||||
else:
|
||||
position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
|
||||
)
|
||||
pe = self.pe
|
||||
pe[:, :, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, :, 1::2] = torch.cos(position * div_term)
|
||||
|
||||
def __call__(self, input_pos: Tensor, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
input_pos (Tensor): [batch_size, ]
|
||||
x (Tensor): [batch_size, 1, embed_dim]
|
||||
|
||||
Returns:
|
||||
embedded_x (Tensor): [batch_size, 1, embed_dim]
|
||||
"""
|
||||
|
||||
batch_size = x.shape[0]
|
||||
pe_values = self.pe[torch.arange(batch_size), input_pos - 1] # (batch_size, embed_dim)
|
||||
|
||||
return x * self.x_scale + self.alpha * pe_values.unsqueeze(1) # (batch_size, 1, embed_dim)
|
||||
|
||||
def prefill(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): [batch_size, seq_len, embed_dim]
|
||||
|
||||
Returns:
|
||||
embedded_x (Tensor): [batch_size, seq_len, embed_dim]
|
||||
"""
|
||||
|
||||
batch_size = x.shape[0]
|
||||
pe_values = self.pe[:batch_size, : x.shape[-2]]
|
||||
return x * self.x_scale + self.alpha * pe_values
|
||||
|
||||
|
||||
class KVCacheABC(nn.Module, ABC, KVCacheProtocol):
|
||||
def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.n_head = n_heads
|
||||
self.head_dim = head_dim
|
||||
self.batch_size = batch_size
|
||||
self.max_seq_length = max_seq_length
|
||||
|
||||
self.k_cache: Tensor
|
||||
self.v_cache: Tensor
|
||||
|
||||
def empty(self):
|
||||
self.k_cache.zero_()
|
||||
self.v_cache.zero_()
|
||||
|
||||
@abstractmethod
|
||||
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> tuple[Tensor, Tensor]: ...
|
||||
|
||||
@abstractmethod
|
||||
def prefill_kv(self, k_val: Tensor, v_val: Tensor) -> None: ...
|
||||
|
||||
def sync_cache(self, kv_cache: KVCacheProtocol):
|
||||
self.k_cache.copy_(kv_cache.k_cache)
|
||||
self.v_cache.copy_(kv_cache.v_cache)
|
||||
|
||||
|
||||
class KVCacheNHD(KVCacheABC):
|
||||
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
|
||||
super().__init__(batch_size, max_seq_length, n_heads, head_dim)
|
||||
|
||||
assert batch_size > 0
|
||||
cache_shape = (batch_size, max_seq_length, n_heads, head_dim)
|
||||
|
||||
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
|
||||
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
|
||||
|
||||
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
|
||||
# input_pos: [B, ], k_val: [B, 1, H, D]
|
||||
|
||||
index = (
|
||||
(input_pos - 1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
.expand(
|
||||
-1,
|
||||
-1,
|
||||
self.n_head,
|
||||
self.head_dim,
|
||||
)
|
||||
.to(torch.int64)
|
||||
) # (bs, 1, num_head, head_dim)
|
||||
|
||||
k_out = self.k_cache
|
||||
v_out = self.v_cache
|
||||
k_out.scatter_(1, index, k_val)
|
||||
v_out.scatter_(1, index, v_val)
|
||||
|
||||
return k_out, v_out
|
||||
|
||||
def empty(self):
|
||||
self.k_cache.zero_()
|
||||
self.v_cache.zero_()
|
||||
|
||||
def prefill_kv(self, k_val: Tensor, v_val: Tensor):
|
||||
# input_pos: int, k_val: [B, S, H, D]
|
||||
|
||||
self.k_cache[:, : k_val.shape[1]] = k_val
|
||||
self.v_cache[:, : v_val.shape[1]] = v_val
|
||||
|
||||
|
||||
class KVCacheHND(KVCacheABC):
|
||||
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
|
||||
super().__init__(batch_size, max_seq_length, n_heads, head_dim)
|
||||
|
||||
cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
|
||||
|
||||
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
|
||||
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
|
||||
|
||||
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
|
||||
# input_pos: [B, ], k_val: [B, H, 1, D]
|
||||
|
||||
index = (
|
||||
(input_pos - 1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
.expand(
|
||||
-1,
|
||||
self.n_head,
|
||||
-1,
|
||||
self.head_dim,
|
||||
)
|
||||
.to(torch.int64)
|
||||
) # (bs, num_head, 1, head_dim)
|
||||
|
||||
k_out = self.k_cache
|
||||
v_out = self.v_cache
|
||||
k_out.scatter_(2, index, k_val)
|
||||
v_out.scatter_(2, index, v_val)
|
||||
|
||||
return k_out, v_out
|
||||
|
||||
def empty(self):
|
||||
self.k_cache.zero_()
|
||||
self.v_cache.zero_()
|
||||
|
||||
def prefill_kv(self, k_val: Tensor, v_val: Tensor):
|
||||
# input_pos: int, k_val: [B, S, H, D]
|
||||
|
||||
self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
|
||||
self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)
|
||||
|
||||
|
||||
class KVCacheHNDVarlen(KVCacheABC):
|
||||
def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
|
||||
super().__init__(batch_size, max_seq_length, n_heads, head_dim)
|
||||
|
||||
cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
|
||||
self.cache_idx: Tensor
|
||||
|
||||
self.register_buffer("cache_idx", torch.arange(batch_size), persistent=False)
|
||||
self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
|
||||
self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)
|
||||
|
||||
def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
|
||||
# input_pos: [B, ], k_val: [B, H, 1, D]
|
||||
|
||||
k_out = self.k_cache
|
||||
v_out = self.v_cache
|
||||
|
||||
ip0 = input_pos - 1
|
||||
|
||||
k_out[self.cache_idx, :, ip0, None] = k_val
|
||||
v_out[self.cache_idx, :, ip0, None] = v_val
|
||||
|
||||
return k_out, v_out
|
||||
|
||||
def empty(self):
|
||||
self.k_cache.zero_()
|
||||
self.v_cache.zero_()
|
||||
|
||||
def prefill_kv(self, k_val: Tensor, v_val: Tensor):
|
||||
# input_pos: int, k_val: [B, S, H, D]
|
||||
|
||||
self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
|
||||
self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)
|
||||
|
||||
|
||||
class AttentionABC(nn.Module, ABC):
|
||||
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
|
||||
super().__init__()
|
||||
|
||||
self.n_head = n_head
|
||||
self.hidden_dim = hidden_dim
|
||||
assert hidden_dim % n_head == 0
|
||||
self.head_dim = hidden_dim // n_head
|
||||
|
||||
self.max_seq_length = max_seq_length
|
||||
|
||||
# key, query, value projections for all heads, but in a batch
|
||||
self.in_proj: nn.Linear
|
||||
self.out_proj: nn.Linear
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
|
||||
keys_to_modify = [key for key in state_dict if "in_proj_" in key]
|
||||
for key in keys_to_modify:
|
||||
new_key = key.replace("in_proj_", "in_proj.") # in_proj_ -> in_proj.
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds) -> Tensor: ...
|
||||
|
||||
def prefill(self, x: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor) -> Tensor:
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
||||
|
||||
q, k, v = map(lambda x: x.contiguous().view(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
||||
|
||||
kv_cache.prefill_kv(k, v)
|
||||
|
||||
q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
|
||||
|
||||
attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)
|
||||
|
||||
output = self.out_proj(attn)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
|
||||
self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
|
||||
|
||||
def __call__(self, x: Tensor):
|
||||
return self.linear2(F.relu(self.linear1(x), inplace=True))
|
||||
|
||||
|
||||
class TransformerBlockABC(nn.Module, ABC):
|
||||
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
self.max_seq_length = max_seq_length
|
||||
|
||||
self.attention: AttentionABC
|
||||
self.feed_forward: FeedForward
|
||||
self.attention_norm: nn.LayerNorm
|
||||
self.ffn_norm: nn.LayerNorm
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
|
||||
for key in list(state_dict.keys()):
|
||||
new_key = (
|
||||
key.replace("self_attn", "attention")
|
||||
.replace("linear", "feed_forward.linear")
|
||||
.replace("norm1", "attention_norm")
|
||||
.replace("norm2", "ffn_norm")
|
||||
)
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds):
|
||||
h = self.attention_norm(
|
||||
x
|
||||
+ self.attention(
|
||||
x,
|
||||
input_pos,
|
||||
kv_cache,
|
||||
*args,
|
||||
**kwds,
|
||||
)
|
||||
)
|
||||
out = self.ffn_norm(h + self.feed_forward(h))
|
||||
return out
|
||||
|
||||
def prefill(
|
||||
self,
|
||||
x: Tensor,
|
||||
kv_cache: KVCacheProtocol,
|
||||
attn_mask: Tensor,
|
||||
) -> Tensor:
|
||||
h = self.attention_norm(
|
||||
x
|
||||
+ self.attention.prefill(
|
||||
x,
|
||||
kv_cache,
|
||||
attn_mask,
|
||||
)
|
||||
)
|
||||
out = self.ffn_norm(h + self.feed_forward(h))
|
||||
return out
|
||||
|
||||
|
||||
class TransformerDecoderABC(nn.Module, ABC):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
n_layer: int,
|
||||
n_head: int,
|
||||
ffn_dim: int,
|
||||
vocab_size: int,
|
||||
max_seq_length: int,
|
||||
max_batch_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
self.n_head = n_head
|
||||
assert hidden_dim % n_head == 0
|
||||
|
||||
self.head_dim = hidden_dim // n_head
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self.n_layer = n_layer
|
||||
|
||||
self.layers: MutableSequence[TransformerBlockABC]
|
||||
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
|
||||
def __call__(self, input_pos: Tensor, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds):
|
||||
for layer, kv_cache in zip(self.layers, kv_caches):
|
||||
x = layer(x, input_pos, kv_cache, *args, **kwds)
|
||||
return x
|
||||
|
||||
def prefill(self, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], attn_mask: Tensor):
|
||||
for layer, kv_cache in zip(self.layers, kv_caches):
|
||||
x = layer.prefill(x, kv_cache, attn_mask)
|
||||
return x
|
||||
|
||||
|
||||
class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
max_seq_length: int = 2000,
|
||||
max_batch_size: int = 10,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_dim: int = config["model"]["hidden_dim"]
|
||||
embedding_dim: int = config["model"]["embedding_dim"]
|
||||
n_head: int = config["model"]["head"]
|
||||
n_layer: int = config["model"]["n_layer"]
|
||||
vocab_size: int = config["model"]["vocab_size"]
|
||||
phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
|
||||
EOS: int = config["model"]["EOS"]
|
||||
ffn_dim: int = hidden_dim * 4
|
||||
|
||||
self.n_layer = int(n_layer)
|
||||
self.hidden_dim = int(hidden_dim)
|
||||
self.n_head = int(n_head)
|
||||
assert hidden_dim % n_head == 0
|
||||
|
||||
self.head_dim = int(hidden_dim // n_head)
|
||||
self.embedding_dim = int(embedding_dim)
|
||||
self.ffn_dim = int(ffn_dim)
|
||||
self.vocab_size = int(vocab_size)
|
||||
self.phoneme_vocab_size = int(phoneme_vocab_size)
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
self.EOS = EOS
|
||||
assert self.EOS == self.vocab_size - 1
|
||||
|
||||
self.bert_proj: nn.Linear
|
||||
self.ar_predict_layer: nn.Linear
|
||||
self.h: TransformerDecoderABC
|
||||
|
||||
self.kv_class: type[KVCacheABC]
|
||||
|
||||
self.GraphCache: CUDAGraphCacheABC | None
|
||||
|
||||
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
|
||||
self.ar_text_position = SinePositionalEmbedding(
|
||||
self.embedding_dim,
|
||||
scale=False,
|
||||
alpha=True,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_length,
|
||||
)
|
||||
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
|
||||
self.ar_audio_position = SinePositionalEmbedding(
|
||||
self.embedding_dim,
|
||||
scale=False,
|
||||
alpha=True,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_length,
|
||||
)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
|
||||
model_keys = [key for key in state_dict if key.startswith("model.")]
|
||||
for key in model_keys:
|
||||
new_key = key[len("model.") :]
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
def init_cache(self, bsz: int = 0) -> MutableSequence[KVCacheProtocol]:
|
||||
bsz = bsz or self.h.max_batch_size
|
||||
assert bsz <= self.h.max_batch_size
|
||||
seq_lens = self.h.max_seq_length
|
||||
dtype = self.bert_proj.bias.dtype
|
||||
kvclass = self.kv_class
|
||||
|
||||
return nn.ModuleList(
|
||||
[kvclass(bsz, seq_lens, self.n_head, self.head_dim) for _ in range(self.n_layer)],
|
||||
).to(self.device, dtype) # type: ignore
|
||||
|
||||
def embed(
|
||||
self,
|
||||
x: list[torch.Tensor],
|
||||
y: torch.Tensor,
|
||||
bert_features: list[torch.Tensor],
|
||||
):
|
||||
x_len: list[int] = [i.shape[0] for i in x]
|
||||
x_len_max = max(x_len)
|
||||
xy_pos = torch.zeros((len(x), x_len_max + y.shape[1], self.embedding_dim)).to(bert_features[0].dtype)
|
||||
|
||||
bert_features = list(map(lambda x: x.transpose(0, 1), bert_features))
|
||||
|
||||
y_len = y.shape[1]
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
y_pos = self.ar_audio_position.prefill(y_emb)
|
||||
|
||||
for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
|
||||
x_emb = self.ar_text_embedding(x_)
|
||||
bert = self.bert_proj(bert_feature)
|
||||
x_emb = x_emb + bert
|
||||
x_pos = self.ar_text_position.prefill(x_emb.unsqueeze(0))
|
||||
xy_pos[[bs], :len_] = x_pos
|
||||
xy_pos[[bs], len_ : len_ + y_len] = y_pos
|
||||
|
||||
return xy_pos
|
||||
|
||||
def compile(self, *args, **kwds):
|
||||
# Experimental features to reduce compilation times, will be on by default in future
|
||||
torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
|
||||
torch._inductor.config.coordinate_descent_tuning = True
|
||||
torch._inductor.config.triton.unique_kernel_names = True
|
||||
torch._inductor.config.fx_graph_cache = True
|
||||
torch._inductor.config.triton.cudagraph_trees = True
|
||||
torch._inductor.config.triton.cudagraph_support_input_mutation = True
|
||||
self.h.compile(fullgraph=True, mode="reduce-overhead")
|
||||
|
||||
def capture(
|
||||
self, input_pos: Tensor, x: Tensor, x_dec: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds
|
||||
) -> CUDAGraph:
|
||||
assert torch.cuda.is_available()
|
||||
s = torch.cuda.Stream()
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
|
||||
with torch.cuda.stream(s):
|
||||
for _ in range(5):
|
||||
self.h(input_pos, x, kv_caches, *args, **kwds)
|
||||
torch.cuda.current_stream().wait_stream(s)
|
||||
|
||||
with torch.cuda.graph(graph):
|
||||
x_dec.copy_(self.h(input_pos, x, kv_caches, *args, **kwds))
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return graph
|
||||
|
||||
@abstractmethod
|
||||
def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
|
||||
return list(), dict()
|
||||
|
||||
@abstractmethod
|
||||
def post_forward(self, idx: int, session: T2SSession) -> None:
|
||||
return
|
||||
|
||||
|
||||
class CUDAGraphCacheABC(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
decoder: T2SDecoderABC,
|
||||
) -> None:
|
||||
self.is_applicable: bool
|
||||
|
||||
if torch.cuda.is_available() and self.is_applicable:
|
||||
self.device: torch.device = decoder.device
|
||||
self.dtype = decoder.bert_proj.bias.dtype
|
||||
|
||||
self.assigned: bool = False
|
||||
|
||||
self.decoder: T2SDecoderABC = decoder
|
||||
self.kv_cache: MutableSequence[KVCacheProtocol] = decoder.init_cache(decoder.max_batch_size)
|
||||
self.xy_pos = torch.rand(size=(decoder.max_batch_size, 1, decoder.embedding_dim), device=self.device).to(
|
||||
self.dtype
|
||||
)
|
||||
self.xy_dec = self.xy_pos.clone()
|
||||
|
||||
self.input_pos = torch.tensor([10] * decoder.max_batch_size, device=self.device).int()
|
||||
self.graph: torch.cuda.CUDAGraph | None = None
|
||||
self.stream: torch.cuda.Stream | None
|
||||
|
||||
self.id: int = random.randint(1, 2**32 - 1)
|
||||
|
||||
def assign_graph(self, session: T2SSession):
|
||||
if self.graph is None:
|
||||
args, kwds = self.decoder.pre_forward(session)
|
||||
graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
|
||||
self.graph = graph
|
||||
self.stream = torch.cuda.Stream()
|
||||
|
||||
if self.assigned is False:
|
||||
self.get_cache_graph(session)
|
||||
session.id = self.id
|
||||
self.assigned = True
|
||||
else:
|
||||
self.capture_new_graph(session)
|
||||
|
||||
@abstractmethod
|
||||
def release_graph(self, session: T2SSession): ...
|
||||
|
||||
@abstractmethod
|
||||
def get_cache_graph(self, session: T2SSession):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def capture_new_graph(self, session: T2SSession):
|
||||
pass
|
||||
|
||||
|
||||
class TorchProfiler:
|
||||
def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
|
||||
self.debug = debug
|
||||
self.log_dir = log_dir
|
||||
self.__profiler: torch.profiler.profile
|
||||
|
||||
if self.debug and not os.path.exists(self.log_dir):
|
||||
os.makedirs(self.log_dir)
|
||||
|
||||
self.tensorboard_handler = tensorboard_trace_handler(self.log_dir)
|
||||
|
||||
def profiler_callback(self, prof: torch.profiler.profile):
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
|
||||
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
|
||||
self.tensorboard_handler(prof)
|
||||
|
||||
@staticmethod
|
||||
def three_step_schedule(step: int) -> ProfilerAction:
|
||||
if step == 0:
|
||||
return ProfilerAction.NONE
|
||||
elif step == 1:
|
||||
return ProfilerAction.RECORD
|
||||
elif step == 2:
|
||||
return ProfilerAction.RECORD_AND_SAVE
|
||||
else:
|
||||
return ProfilerAction.NONE
|
||||
|
||||
def start(self):
|
||||
if not self.debug:
|
||||
return
|
||||
assert self.__profiler is not None
|
||||
self.__profiler.step()
|
||||
|
||||
def end(self):
|
||||
if not self.debug:
|
||||
return
|
||||
assert self.__profiler is not None
|
||||
self.__profiler.step()
|
||||
|
||||
def profiler(self):
|
||||
if self.debug:
|
||||
activities_list = [torch.profiler.ProfilerActivity.CPU]
|
||||
if torch.cuda.is_available():
|
||||
activities_list.append(torch.profiler.ProfilerActivity.CUDA)
|
||||
|
||||
self.__profiler = torch.profiler.profile(
|
||||
activities=activities_list,
|
||||
record_shapes=True,
|
||||
with_stack=True,
|
||||
with_modules=True,
|
||||
profile_memory=True,
|
||||
schedule=self.three_step_schedule,
|
||||
on_trace_ready=self.profiler_callback,
|
||||
)
|
||||
return self.__profiler
|
||||
else:
|
||||
return nullcontext()
|
||||
|
||||
def record(self, func_name: str):
|
||||
if self.debug:
|
||||
return torch.profiler.record_function(func_name)
|
||||
else:
|
||||
return nullcontext()
|
30
GPT_SoVITS/Accelerate/__init__.py
Normal file
30
GPT_SoVITS/Accelerate/__init__.py
Normal file
@ -0,0 +1,30 @@
|
||||
from . import MLX, PyTorch
|
||||
from .logger import console, logger, tb
|
||||
from .PyTorch import T2SEngineTorch, T2SRequest, T2SResult
|
||||
from .PyTorch.structs import T2SEngineProtocol
|
||||
|
||||
backends = PyTorch.backends + MLX.backends
|
||||
|
||||
backends = [
|
||||
b.replace("_", "-")
|
||||
.title()
|
||||
.replace("Mlx", "MLX")
|
||||
.replace("Mps", "MPS")
|
||||
.replace("Cuda", "CUDA")
|
||||
.replace("Mxfp4", "MXFP4")
|
||||
for b in backends
|
||||
]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"T2SEngineTorch",
|
||||
"T2SRequest",
|
||||
"T2SResult",
|
||||
"backends",
|
||||
"MLX",
|
||||
"PyTorch",
|
||||
"logger",
|
||||
"console",
|
||||
"tb",
|
||||
"T2SEngineProtocol",
|
||||
]
|
203
GPT_SoVITS/Accelerate/logger.py
Normal file
203
GPT_SoVITS/Accelerate/logger.py
Normal file
@ -0,0 +1,203 @@
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
from rich.console import Console, JustifyMethod
|
||||
from rich.highlighter import Highlighter
|
||||
from rich.logging import RichHandler
|
||||
from rich.progress import Task, TextColumn
|
||||
from rich.style import StyleType
|
||||
from rich.table import Column
|
||||
from rich.text import Text
|
||||
from rich.traceback import Traceback, install
|
||||
|
||||
console = Console(stderr=False)
|
||||
install(console=console)
|
||||
|
||||
|
||||
def loguru_format(record):
|
||||
level = record["level"].name
|
||||
color = {
|
||||
"DEBUG": "green",
|
||||
"INFO": "blue",
|
||||
"WARNING": "yellow",
|
||||
"ERROR": "red",
|
||||
"CRITICAL": "bright_red",
|
||||
}.get(level, "white")
|
||||
|
||||
return f"[bold {color}][{level}][/bold {color}] " + "{message}"
|
||||
|
||||
|
||||
handler_with_locals = RichHandler(
|
||||
console=console,
|
||||
show_time=False,
|
||||
show_path=False,
|
||||
rich_tracebacks=True,
|
||||
tracebacks_show_locals=True,
|
||||
show_level=False,
|
||||
markup=True,
|
||||
)
|
||||
handler_without_locals = RichHandler(
|
||||
console=console,
|
||||
show_time=False,
|
||||
show_path=False,
|
||||
rich_tracebacks=True,
|
||||
tracebacks_show_locals=False,
|
||||
show_level=False,
|
||||
markup=True,
|
||||
)
|
||||
|
||||
|
||||
def local_filter(r):
|
||||
return r["extra"].get("show_locals", True)
|
||||
|
||||
|
||||
logger.remove()
|
||||
logger.add(handler_with_locals, format=loguru_format, filter=local_filter)
|
||||
logger.add(handler_without_locals, format=loguru_format, filter=lambda x: not local_filter(x))
|
||||
|
||||
|
||||
class SpeedColumnToken(TextColumn):
|
||||
"""Show task progress as a percentage.
|
||||
|
||||
Args:
|
||||
text_format (str, optional): Format for percentage display. Defaults to "[progress.percentage]{task.percentage:>3.0f}%".
|
||||
text_format_no_percentage (str, optional): Format if percentage is unknown. Defaults to "".
|
||||
style (StyleType, optional): Style of output. Defaults to "none".
|
||||
justify (JustifyMethod, optional): Text justification. Defaults to "left".
|
||||
markup (bool, optional): Enable markup. Defaults to True.
|
||||
highlighter (Optional[Highlighter], optional): Highlighter to apply to output. Defaults to None.
|
||||
table_column (Optional[Column], optional): Table Column to use. Defaults to None.
|
||||
show_speed (bool, optional): Show speed if total is unknown. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_format: str = "[progress.percentage]{task.percentage:>3.0f}%",
|
||||
text_format_no_percentage: str = "",
|
||||
style: StyleType = "none",
|
||||
justify: JustifyMethod = "left",
|
||||
markup: bool = True,
|
||||
highlighter: Optional[Highlighter] = None,
|
||||
table_column: Optional[Column] = None,
|
||||
show_speed: bool = True,
|
||||
) -> None:
|
||||
self.text_format_no_percentage = text_format_no_percentage
|
||||
self.show_speed = show_speed
|
||||
super().__init__(
|
||||
text_format=text_format,
|
||||
style=style,
|
||||
justify=justify,
|
||||
markup=markup,
|
||||
highlighter=highlighter,
|
||||
table_column=table_column,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def render_speed(cls, speed: Optional[float]) -> Text:
|
||||
"""Render the speed in iterations per second.
|
||||
|
||||
Args:
|
||||
task (Task): A Task object.
|
||||
|
||||
Returns:
|
||||
Text: Text object containing the task speed.
|
||||
"""
|
||||
if speed is None:
|
||||
return Text("", style="progress.percentage")
|
||||
return Text(f"{speed:.1f} token/s", style="progress.percentage")
|
||||
|
||||
def render(self, task: Task) -> Text:
|
||||
if self.show_speed:
|
||||
return self.render_speed(task.finished_speed or task.speed)
|
||||
text_format = self.text_format_no_percentage if task.total is None else self.text_format
|
||||
_text = text_format.format(task=task)
|
||||
if self.markup:
|
||||
text = Text.from_markup(_text, style=self.style, justify=self.justify)
|
||||
else:
|
||||
text = Text(_text, style=self.style, justify=self.justify)
|
||||
if self.highlighter:
|
||||
self.highlighter.highlight(text)
|
||||
return text
|
||||
|
||||
|
||||
class SpeedColumnIteration(TextColumn):
|
||||
"""Show task progress as a percentage.
|
||||
|
||||
Args:
|
||||
text_format (str, optional): Format for percentage display. Defaults to "[progress.percentage]{task.percentage:>3.0f}%".
|
||||
text_format_no_percentage (str, optional): Format if percentage is unknown. Defaults to "".
|
||||
style (StyleType, optional): Style of output. Defaults to "none".
|
||||
justify (JustifyMethod, optional): Text justification. Defaults to "left".
|
||||
markup (bool, optional): Enable markup. Defaults to True.
|
||||
highlighter (Optional[Highlighter], optional): Highlighter to apply to output. Defaults to None.
|
||||
table_column (Optional[Column], optional): Table Column to use. Defaults to None.
|
||||
show_speed (bool, optional): Show speed if total is unknown. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_format: str = "[progress.percentage]{task.percentage:>3.0f}%",
|
||||
text_format_no_percentage: str = "",
|
||||
style: StyleType = "none",
|
||||
justify: JustifyMethod = "left",
|
||||
markup: bool = True,
|
||||
highlighter: Optional[Highlighter] = None,
|
||||
table_column: Optional[Column] = None,
|
||||
show_speed: bool = True,
|
||||
) -> None:
|
||||
self.text_format_no_percentage = text_format_no_percentage
|
||||
self.show_speed = show_speed
|
||||
super().__init__(
|
||||
text_format=text_format,
|
||||
style=style,
|
||||
justify=justify,
|
||||
markup=markup,
|
||||
highlighter=highlighter,
|
||||
table_column=table_column,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def render_speed(cls, speed: Optional[float]) -> Text:
|
||||
"""Render the speed in iterations per second.
|
||||
|
||||
Args:
|
||||
task (Task): A Task object.
|
||||
|
||||
Returns:
|
||||
Text: Text object containing the task speed.
|
||||
"""
|
||||
if speed is None:
|
||||
return Text("", style="progress.percentage")
|
||||
return Text(f"{speed:.1f} it/s", style="progress.percentage")
|
||||
|
||||
def render(self, task: Task) -> Text:
|
||||
if self.show_speed:
|
||||
return self.render_speed(task.finished_speed or task.speed)
|
||||
text_format = self.text_format_no_percentage if task.total is None else self.text_format
|
||||
_text = text_format.format(task=task)
|
||||
if self.markup:
|
||||
text = Text.from_markup(_text, style=self.style, justify=self.justify)
|
||||
else:
|
||||
text = Text(_text, style=self.style, justify=self.justify)
|
||||
if self.highlighter:
|
||||
self.highlighter.highlight(text)
|
||||
return text
|
||||
|
||||
|
||||
def tb(show_locals: bool = True):
|
||||
exc_type, exc_value, exc_tb = sys.exc_info()
|
||||
assert exc_type
|
||||
assert exc_value
|
||||
tb = Traceback.from_exception(exc_type, exc_value, exc_tb, show_locals=show_locals)
|
||||
|
||||
return tb
|
||||
|
||||
|
||||
__all__ = ["logger", "console", "tb", "SpeedColumnToken", "SpeedColumnIteration"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
raise RuntimeError()
|
||||
except Exception:
|
||||
logger.bind(show_locals=False).exception("TEST")
|
@ -1,266 +0,0 @@
|
||||
## BigVGAN: A Universal Neural Vocoder with Large-Scale Training
|
||||
|
||||
#### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon
|
||||
|
||||
[[Paper]](https://arxiv.org/abs/2206.04658) - [[Code]](https://github.com/NVIDIA/BigVGAN) - [[Showcase]](https://bigvgan-demo.github.io/) - [[Project Page]](https://research.nvidia.com/labs/adlr/projects/bigvgan/) - [[Weights]](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a) - [[Demo]](https://huggingface.co/spaces/nvidia/BigVGAN)
|
||||
|
||||
[](https://paperswithcode.com/sota/speech-synthesis-on-libritts?p=bigvgan-a-universal-neural-vocoder-with-large)
|
||||
|
||||
<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
|
||||
|
||||
## News
|
||||
- **Sep 2024 (v2.4):**
|
||||
- We have updated the pretrained checkpoints trained for 5M steps. This is final release of the BigVGAN-v2 checkpoints.
|
||||
|
||||
- **Jul 2024 (v2.3):**
|
||||
- General refactor and code improvements for improved readability.
|
||||
- Fully fused CUDA kernel of anti-alised activation (upsampling + activation + downsampling) with inference speed benchmark.
|
||||
|
||||
- **Jul 2024 (v2.2):** The repository now includes an interactive local demo using gradio.
|
||||
|
||||
- **Jul 2024 (v2.1):** BigVGAN is now integrated with 🤗 Hugging Face Hub with easy access to inference using pretrained checkpoints. We also provide an interactive demo on Hugging Face Spaces.
|
||||
|
||||
- **Jul 2024 (v2):** We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
|
||||
- Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.
|
||||
- Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
|
||||
- Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
|
||||
- We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
|
||||
|
||||
## Installation
|
||||
|
||||
The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
|
||||
|
||||
```shell
|
||||
conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
|
||||
conda activate bigvgan
|
||||
```
|
||||
|
||||
Clone the repository and install dependencies:
|
||||
|
||||
```shell
|
||||
git clone https://github.com/NVIDIA/BigVGAN
|
||||
cd BigVGAN
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Inference Quickstart using 🤗 Hugging Face Hub
|
||||
|
||||
Below example describes how you can use BigVGAN: load the pretrained BigVGAN generator from Hugging Face Hub, compute mel spectrogram from input waveform, and generate synthesized waveform using the mel spectrogram as the model's input.
|
||||
|
||||
```python
|
||||
device = 'cuda'
|
||||
|
||||
import torch
|
||||
import bigvgan
|
||||
import librosa
|
||||
from meldataset import get_mel_spectrogram
|
||||
|
||||
# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
|
||||
model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
|
||||
|
||||
# remove weight norm in the model and set to eval mode
|
||||
model.remove_weight_norm()
|
||||
model = model.eval().to(device)
|
||||
|
||||
# load wav file and compute mel spectrogram
|
||||
wav_path = '/path/to/your/audio.wav'
|
||||
wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
|
||||
wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]
|
||||
|
||||
# compute mel spectrogram from the ground truth audio
|
||||
mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]
|
||||
|
||||
# generate waveform from mel
|
||||
with torch.inference_mode():
|
||||
wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
|
||||
wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time]
|
||||
|
||||
# you can convert the generated waveform to 16 bit linear PCM
|
||||
wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16') # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
|
||||
```
|
||||
|
||||
## Local gradio demo <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>
|
||||
|
||||
You can run a local gradio demo using below command:
|
||||
|
||||
```python
|
||||
pip install -r demo/requirements.txt
|
||||
python demo/app.py
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
Create symbolic link to the root of the dataset. The codebase uses filelist with the relative path from the dataset. Below are the example commands for LibriTTS dataset:
|
||||
|
||||
```shell
|
||||
cd filelists/LibriTTS && \
|
||||
ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
|
||||
ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
|
||||
ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
|
||||
ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
|
||||
ln -s /path/to/your/LibriTTS/dev-other dev-other && \
|
||||
ln -s /path/to/your/LibriTTS/test-clean test-clean && \
|
||||
ln -s /path/to/your/LibriTTS/test-other test-other && \
|
||||
cd ../..
|
||||
```
|
||||
|
||||
Train BigVGAN model. Below is an example command for training BigVGAN-v2 using LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input:
|
||||
|
||||
```shell
|
||||
python train.py \
|
||||
--config configs/bigvgan_v2_24khz_100band_256x.json \
|
||||
--input_wavs_dir filelists/LibriTTS \
|
||||
--input_training_file filelists/LibriTTS/train-full.txt \
|
||||
--input_validation_file filelists/LibriTTS/val-full.txt \
|
||||
--list_input_unseen_wavs_dir filelists/LibriTTS filelists/LibriTTS \
|
||||
--list_input_unseen_validation_file filelists/LibriTTS/dev-clean.txt filelists/LibriTTS/dev-other.txt \
|
||||
--checkpoint_path exp/bigvgan_v2_24khz_100band_256x
|
||||
```
|
||||
|
||||
## Synthesis
|
||||
|
||||
Synthesize from BigVGAN model. Below is an example command for generating audio from the model.
|
||||
It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
|
||||
|
||||
```shell
|
||||
python inference.py \
|
||||
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
|
||||
--input_wavs_dir /path/to/your/input_wav \
|
||||
--output_dir /path/to/your/output_wav
|
||||
```
|
||||
|
||||
`inference_e2e.py` supports synthesis directly from the mel spectrogram saved in `.npy` format, with shapes `[1, channel, frame]` or `[channel, frame]`.
|
||||
It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
|
||||
|
||||
Make sure that the STFT hyperparameters for mel spectrogram are the same as the model, which are defined in `config.json` of the corresponding model.
|
||||
|
||||
```shell
|
||||
python inference_e2e.py \
|
||||
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
|
||||
--input_mels_dir /path/to/your/input_mel \
|
||||
--output_dir /path/to/your/output_wav
|
||||
```
|
||||
|
||||
## Using Custom CUDA Kernel for Synthesis
|
||||
|
||||
You can apply the fast CUDA inference kernel by using a parameter `use_cuda_kernel` when instantiating BigVGAN:
|
||||
|
||||
```python
|
||||
generator = BigVGAN(h, use_cuda_kernel=True)
|
||||
```
|
||||
|
||||
You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
|
||||
|
||||
When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
|
||||
|
||||
Please make sure that both are installed in your system and `nvcc` installed in your system matches the version your PyTorch build is using.
|
||||
|
||||
We recommend running `test_cuda_vs_torch_model.py` first to build and check the correctness of the CUDA kernel. See below example command and its output, where it returns `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
|
||||
|
||||
```python
|
||||
python tests/test_cuda_vs_torch_model.py \
|
||||
--checkpoint_file /path/to/your/bigvgan_generator.pt
|
||||
```
|
||||
|
||||
```shell
|
||||
loading plain Pytorch BigVGAN
|
||||
...
|
||||
loading CUDA kernel BigVGAN with auto-build
|
||||
Detected CUDA files, patching ldflags
|
||||
Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja..
|
||||
Building extension module anti_alias_activation_cuda...
|
||||
...
|
||||
Loading extension module anti_alias_activation_cuda...
|
||||
...
|
||||
Loading '/path/to/your/bigvgan_generator.pt'
|
||||
...
|
||||
[Success] test CUDA fused vs. plain torch BigVGAN inference
|
||||
> mean_difference=0.0007238413265440613
|
||||
...
|
||||
```
|
||||
|
||||
If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check if `nvcc` installed in your system is compatible with your PyTorch version.
|
||||
|
||||
## Pretrained Models
|
||||
|
||||
We provide the [pretrained models on Hugging Face Collections](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a).
|
||||
One can download the checkpoints of the generator weight (named `bigvgan_generator.pt`) and its discriminator/optimizer states (named `bigvgan_discriminator_optimizer.pt`) within the listed model repositories.
|
||||
|
||||
| Model Name | Sampling Rate | Mel band | fmax | Upsampling Ratio | Params | Dataset | Steps | Fine-Tuned |
|
||||
|:--------------------------------------------------------------------------------------------------------:|:-------------:|:--------:|:-----:|:----------------:|:------:|:--------------------------:|:-----:|:----------:|
|
||||
| [bigvgan_v2_44khz_128band_512x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x) | 44 kHz | 128 | 22050 | 512 | 122M | Large-scale Compilation | 5M | No |
|
||||
| [bigvgan_v2_44khz_128band_256x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x) | 44 kHz | 128 | 22050 | 256 | 112M | Large-scale Compilation | 5M | No |
|
||||
| [bigvgan_v2_24khz_100band_256x](https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x) | 24 kHz | 100 | 12000 | 256 | 112M | Large-scale Compilation | 5M | No |
|
||||
| [bigvgan_v2_22khz_80band_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x) | 22 kHz | 80 | 11025 | 256 | 112M | Large-scale Compilation | 5M | No |
|
||||
| [bigvgan_v2_22khz_80band_fmax8k_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x) | 22 kHz | 80 | 8000 | 256 | 112M | Large-scale Compilation | 5M | No |
|
||||
| [bigvgan_24khz_100band](https://huggingface.co/nvidia/bigvgan_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 112M | LibriTTS | 5M | No |
|
||||
| [bigvgan_base_24khz_100band](https://huggingface.co/nvidia/bigvgan_base_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 14M | LibriTTS | 5M | No |
|
||||
| [bigvgan_22khz_80band](https://huggingface.co/nvidia/bigvgan_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 112M | LibriTTS + VCTK + LJSpeech | 5M | No |
|
||||
| [bigvgan_base_22khz_80band](https://huggingface.co/nvidia/bigvgan_base_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 14M | LibriTTS + VCTK + LJSpeech | 5M | No |
|
||||
|
||||
The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on LibriTTS dataset.
|
||||
We also provide 22kHz BigVGAN models with band-limited setup (i.e., fmax=8000) for TTS applications.
|
||||
Note that the checkpoints use `snakebeta` activation with log scale parameterization, which have the best overall quality.
|
||||
|
||||
You can fine-tune the models by:
|
||||
|
||||
1. downloading the checkpoints (both the generator weight and its discriminator/optimizer states)
|
||||
2. resuming training using your audio dataset by specifying `--checkpoint_path` that includes the checkpoints when launching `train.py`
|
||||
|
||||
## Training Details of BigVGAN-v2
|
||||
|
||||
Comapred to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 used `batch_size=32` with a longer `segment_size=65536` and are trained using 8 A100 GPUs.
|
||||
|
||||
Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` as default to fit in a single A100 GPU for training. You can fine-tune the models adjusting `batch_size` depending on your GPUs.
|
||||
|
||||
When training BigVGAN-v2 from scratch with small batch size, it can potentially encounter the early divergence problem mentioned in the paper. In such case, we recommend lowering the `clip_grad_norm` value (e.g. `100`) for the early training iterations (e.g. 20000 steps) and increase the value to the default `500`.
|
||||
|
||||
## Evaluation Results of BigVGAN-v2
|
||||
|
||||
Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements of the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
|
||||
|
||||
| Model | Dataset | Steps | PESQ(↑) | M-STFT(↓) | MCD(↓) | Periodicity(↓) | V/UV F1(↑) |
|
||||
|:----------:|:-----------------------:|:-----:|:---------:|:----------:|:----------:|:--------------:|:----------:|
|
||||
| BigVGAN | LibriTTS | 1M | 4.027 | 0.7997 | 0.3745 | 0.1018 | 0.9598 |
|
||||
| BigVGAN | LibriTTS | 5M | 4.256 | 0.7409 | 0.2988 | 0.0809 | 0.9698 |
|
||||
| BigVGAN-v2 | Large-scale Compilation | 3M | 4.359 | 0.7134 | 0.3060 | 0.0621 | 0.9777 |
|
||||
| BigVGAN-v2 | Large-scale Compilation | 5M | **4.362** | **0.7026** | **0.2903** | **0.0593** | **0.9793** |
|
||||
|
||||
## Speed Benchmark
|
||||
|
||||
Below are the speed and VRAM usage benchmark results of BigVGAN from `tests/test_cuda_vs_torch_model.py`, using `bigvgan_v2_24khz_100band_256x` as a reference model.
|
||||
|
||||
| GPU | num_mel_frame | use_cuda_kernel | Speed (kHz) | Real-time Factor | VRAM (GB) |
|
||||
|:--------------------------:|:-------------:|:---------------:|:-----------:|:----------------:|:---------:|
|
||||
| NVIDIA A100 | 256 | False | 1672.1 | 69.7x | 1.3 |
|
||||
| | | True | 3916.5 | 163.2x | 1.3 |
|
||||
| | 2048 | False | 1899.6 | 79.2x | 1.7 |
|
||||
| | | True | 5330.1 | 222.1x | 1.7 |
|
||||
| | 16384 | False | 1973.8 | 82.2x | 5.0 |
|
||||
| | | True | 5761.7 | 240.1x | 4.4 |
|
||||
| NVIDIA GeForce RTX 3080 | 256 | False | 841.1 | 35.0x | 1.3 |
|
||||
| | | True | 1598.1 | 66.6x | 1.3 |
|
||||
| | 2048 | False | 929.9 | 38.7x | 1.7 |
|
||||
| | | True | 1971.3 | 82.1x | 1.6 |
|
||||
| | 16384 | False | 943.4 | 39.3x | 5.0 |
|
||||
| | | True | 2026.5 | 84.4x | 3.9 |
|
||||
| NVIDIA GeForce RTX 2080 Ti | 256 | False | 515.6 | 21.5x | 1.3 |
|
||||
| | | True | 811.3 | 33.8x | 1.3 |
|
||||
| | 2048 | False | 576.5 | 24.0x | 1.7 |
|
||||
| | | True | 1023.0 | 42.6x | 1.5 |
|
||||
| | 16384 | False | 589.4 | 24.6x | 5.0 |
|
||||
| | | True | 1068.1 | 44.5x | 3.2 |
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
|
||||
|
||||
## References
|
||||
|
||||
- [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
|
||||
- [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
|
||||
- [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
|
||||
- [Julius](https://github.com/adefossez/julius) (for low-pass filter)
|
||||
- [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
|
||||
- [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
|
||||
- [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)
|
@ -2,7 +2,7 @@
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch
|
||||
from torch import nn, sin, pow
|
||||
from torch import nn, pow, sin
|
||||
from torch.nn import Parameter
|
||||
|
||||
|
||||
|
@ -4,22 +4,22 @@
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, Dict
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from . import activations
|
||||
from .utils0 import init_weights, get_padding
|
||||
from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
|
||||
from .env import AttrDict
|
||||
|
||||
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
||||
from .utils0 import get_padding, init_weights
|
||||
|
||||
|
||||
def load_hparams_from_json(path) -> AttrDict:
|
||||
|
@ -1,45 +0,0 @@
|
||||
{
|
||||
"resblock": "1",
|
||||
"num_gpus": 0,
|
||||
"batch_size": 32,
|
||||
"learning_rate": 0.0001,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.9999996,
|
||||
"seed": 1234,
|
||||
|
||||
"upsample_rates": [4,4,2,2,2,2],
|
||||
"upsample_kernel_sizes": [8,8,4,4,4,4],
|
||||
"upsample_initial_channel": 1536,
|
||||
"resblock_kernel_sizes": [3,7,11],
|
||||
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||
|
||||
"activation": "snakebeta",
|
||||
"snake_logscale": true,
|
||||
|
||||
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
|
||||
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||
"use_spectral_norm": false,
|
||||
"discriminator_channel_mult": 1,
|
||||
|
||||
"segment_size": 8192,
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"n_fft": 1024,
|
||||
"hop_size": 256,
|
||||
"win_size": 1024,
|
||||
|
||||
"sampling_rate": 22050,
|
||||
|
||||
"fmin": 0,
|
||||
"fmax": 8000,
|
||||
"fmax_for_loss": null,
|
||||
|
||||
"num_workers": 4,
|
||||
|
||||
"dist_config": {
|
||||
"dist_backend": "nccl",
|
||||
"dist_url": "tcp://localhost:54321",
|
||||
"world_size": 1
|
||||
}
|
||||
}
|
@ -1,45 +0,0 @@
|
||||
{
|
||||
"resblock": "1",
|
||||
"num_gpus": 0,
|
||||
"batch_size": 32,
|
||||
"learning_rate": 0.0001,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.9999996,
|
||||
"seed": 1234,
|
||||
|
||||
"upsample_rates": [4,4,2,2,2,2],
|
||||
"upsample_kernel_sizes": [8,8,4,4,4,4],
|
||||
"upsample_initial_channel": 1536,
|
||||
"resblock_kernel_sizes": [3,7,11],
|
||||
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||
|
||||
"activation": "snakebeta",
|
||||
"snake_logscale": true,
|
||||
|
||||
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
|
||||
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||
"use_spectral_norm": false,
|
||||
"discriminator_channel_mult": 1,
|
||||
|
||||
"segment_size": 8192,
|
||||
"num_mels": 100,
|
||||
"num_freq": 1025,
|
||||
"n_fft": 1024,
|
||||
"hop_size": 256,
|
||||
"win_size": 1024,
|
||||
|
||||
"sampling_rate": 24000,
|
||||
|
||||
"fmin": 0,
|
||||
"fmax": 12000,
|
||||
"fmax_for_loss": null,
|
||||
|
||||
"num_workers": 4,
|
||||
|
||||
"dist_config": {
|
||||
"dist_backend": "nccl",
|
||||
"dist_url": "tcp://localhost:54321",
|
||||
"world_size": 1
|
||||
}
|
||||
}
|
@ -1,45 +0,0 @@
|
||||
{
|
||||
"resblock": "1",
|
||||
"num_gpus": 0,
|
||||
"batch_size": 32,
|
||||
"learning_rate": 0.0001,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.9999996,
|
||||
"seed": 1234,
|
||||
|
||||
"upsample_rates": [8,8,2,2],
|
||||
"upsample_kernel_sizes": [16,16,4,4],
|
||||
"upsample_initial_channel": 512,
|
||||
"resblock_kernel_sizes": [3,7,11],
|
||||
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||
|
||||
"activation": "snakebeta",
|
||||
"snake_logscale": true,
|
||||
|
||||
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
|
||||
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||
"use_spectral_norm": false,
|
||||
"discriminator_channel_mult": 1,
|
||||
|
||||
"segment_size": 8192,
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"n_fft": 1024,
|
||||
"hop_size": 256,
|
||||
"win_size": 1024,
|
||||
|
||||
"sampling_rate": 22050,
|
||||
|
||||
"fmin": 0,
|
||||
"fmax": 8000,
|
||||
"fmax_for_loss": null,
|
||||
|
||||
"num_workers": 4,
|
||||
|
||||
"dist_config": {
|
||||
"dist_backend": "nccl",
|
||||
"dist_url": "tcp://localhost:54321",
|
||||
"world_size": 1
|
||||
}
|
||||
}
|
@ -1,45 +0,0 @@
|
||||
{
|
||||
"resblock": "1",
|
||||
"num_gpus": 0,
|
||||
"batch_size": 32,
|
||||
"learning_rate": 0.0001,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.9999996,
|
||||
"seed": 1234,
|
||||
|
||||
"upsample_rates": [8,8,2,2],
|
||||
"upsample_kernel_sizes": [16,16,4,4],
|
||||
"upsample_initial_channel": 512,
|
||||
"resblock_kernel_sizes": [3,7,11],
|
||||
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||
|
||||
"activation": "snakebeta",
|
||||
"snake_logscale": true,
|
||||
|
||||
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
|
||||
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||
"use_spectral_norm": false,
|
||||
"discriminator_channel_mult": 1,
|
||||
|
||||
"segment_size": 8192,
|
||||
"num_mels": 100,
|
||||
"num_freq": 1025,
|
||||
"n_fft": 1024,
|
||||
"hop_size": 256,
|
||||
"win_size": 1024,
|
||||
|
||||
"sampling_rate": 24000,
|
||||
|
||||
"fmin": 0,
|
||||
"fmax": 12000,
|
||||
"fmax_for_loss": null,
|
||||
|
||||
"num_workers": 4,
|
||||
|
||||
"dist_config": {
|
||||
"dist_backend": "nccl",
|
||||
"dist_url": "tcp://localhost:54321",
|
||||
"world_size": 1
|
||||
}
|
||||
}
|
@ -1,61 +0,0 @@
|
||||
{
|
||||
"resblock": "1",
|
||||
"num_gpus": 0,
|
||||
"batch_size": 4,
|
||||
"learning_rate": 0.0001,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.9999996,
|
||||
"seed": 1234,
|
||||
|
||||
"upsample_rates": [4,4,2,2,2,2],
|
||||
"upsample_kernel_sizes": [8,8,4,4,4,4],
|
||||
"upsample_initial_channel": 1536,
|
||||
"resblock_kernel_sizes": [3,7,11],
|
||||
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||
|
||||
"use_tanh_at_final": false,
|
||||
"use_bias_at_final": false,
|
||||
|
||||
"activation": "snakebeta",
|
||||
"snake_logscale": true,
|
||||
|
||||
"use_cqtd_instead_of_mrd": true,
|
||||
"cqtd_filters": 128,
|
||||
"cqtd_max_filters": 1024,
|
||||
"cqtd_filters_scale": 1,
|
||||
"cqtd_dilations": [1, 2, 4],
|
||||
"cqtd_hop_lengths": [512, 256, 256],
|
||||
"cqtd_n_octaves": [9, 9, 9],
|
||||
"cqtd_bins_per_octaves": [24, 36, 48],
|
||||
|
||||
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||
"use_spectral_norm": false,
|
||||
"discriminator_channel_mult": 1,
|
||||
|
||||
"use_multiscale_melloss": true,
|
||||
"lambda_melloss": 15,
|
||||
|
||||
"clip_grad_norm": 500,
|
||||
|
||||
"segment_size": 65536,
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"n_fft": 1024,
|
||||
"hop_size": 256,
|
||||
"win_size": 1024,
|
||||
|
||||
"sampling_rate": 22050,
|
||||
|
||||
"fmin": 0,
|
||||
"fmax": null,
|
||||
"fmax_for_loss": null,
|
||||
|
||||
"num_workers": 4,
|
||||
|
||||
"dist_config": {
|
||||
"dist_backend": "nccl",
|
||||
"dist_url": "tcp://localhost:54321",
|
||||
"world_size": 1
|
||||
}
|
||||
}
|
@ -1,61 +0,0 @@
|
||||
{
|
||||
"resblock": "1",
|
||||
"num_gpus": 0,
|
||||
"batch_size": 4,
|
||||
"learning_rate": 0.0001,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.9999996,
|
||||
"seed": 1234,
|
||||
|
||||
"upsample_rates": [4,4,2,2,2,2],
|
||||
"upsample_kernel_sizes": [8,8,4,4,4,4],
|
||||
"upsample_initial_channel": 1536,
|
||||
"resblock_kernel_sizes": [3,7,11],
|
||||
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||
|
||||
"use_tanh_at_final": false,
|
||||
"use_bias_at_final": false,
|
||||
|
||||
"activation": "snakebeta",
|
||||
"snake_logscale": true,
|
||||
|
||||
"use_cqtd_instead_of_mrd": true,
|
||||
"cqtd_filters": 128,
|
||||
"cqtd_max_filters": 1024,
|
||||
"cqtd_filters_scale": 1,
|
||||
"cqtd_dilations": [1, 2, 4],
|
||||
"cqtd_hop_lengths": [512, 256, 256],
|
||||
"cqtd_n_octaves": [9, 9, 9],
|
||||
"cqtd_bins_per_octaves": [24, 36, 48],
|
||||
|
||||
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||
"use_spectral_norm": false,
|
||||
"discriminator_channel_mult": 1,
|
||||
|
||||
"use_multiscale_melloss": true,
|
||||
"lambda_melloss": 15,
|
||||
|
||||
"clip_grad_norm": 500,
|
||||
|
||||
"segment_size": 65536,
|
||||
"num_mels": 80,
|
||||
"num_freq": 1025,
|
||||
"n_fft": 1024,
|
||||
"hop_size": 256,
|
||||
"win_size": 1024,
|
||||
|
||||
"sampling_rate": 22050,
|
||||
|
||||
"fmin": 0,
|
||||
"fmax": 8000,
|
||||
"fmax_for_loss": null,
|
||||
|
||||
"num_workers": 4,
|
||||
|
||||
"dist_config": {
|
||||
"dist_backend": "nccl",
|
||||
"dist_url": "tcp://localhost:54321",
|
||||
"world_size": 1
|
||||
}
|
||||
}
|
@ -1,61 +0,0 @@
|
||||
{
|
||||
"resblock": "1",
|
||||
"num_gpus": 0,
|
||||
"batch_size": 4,
|
||||
"learning_rate": 0.0001,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.9999996,
|
||||
"seed": 1234,
|
||||
|
||||
"upsample_rates": [4,4,2,2,2,2],
|
||||
"upsample_kernel_sizes": [8,8,4,4,4,4],
|
||||
"upsample_initial_channel": 1536,
|
||||
"resblock_kernel_sizes": [3,7,11],
|
||||
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||
|
||||
"use_tanh_at_final": false,
|
||||
"use_bias_at_final": false,
|
||||
|
||||
"activation": "snakebeta",
|
||||
"snake_logscale": true,
|
||||
|
||||
"use_cqtd_instead_of_mrd": true,
|
||||
"cqtd_filters": 128,
|
||||
"cqtd_max_filters": 1024,
|
||||
"cqtd_filters_scale": 1,
|
||||
"cqtd_dilations": [1, 2, 4],
|
||||
"cqtd_hop_lengths": [512, 256, 256],
|
||||
"cqtd_n_octaves": [9, 9, 9],
|
||||
"cqtd_bins_per_octaves": [24, 36, 48],
|
||||
|
||||
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||
"use_spectral_norm": false,
|
||||
"discriminator_channel_mult": 1,
|
||||
|
||||
"use_multiscale_melloss": true,
|
||||
"lambda_melloss": 15,
|
||||
|
||||
"clip_grad_norm": 500,
|
||||
|
||||
"segment_size": 65536,
|
||||
"num_mels": 128,
|
||||
"num_freq": 1025,
|
||||
"n_fft": 1024,
|
||||
"hop_size": 256,
|
||||
"win_size": 1024,
|
||||
|
||||
"sampling_rate": 44100,
|
||||
|
||||
"fmin": 0,
|
||||
"fmax": null,
|
||||
"fmax_for_loss": null,
|
||||
|
||||
"num_workers": 4,
|
||||
|
||||
"dist_config": {
|
||||
"dist_backend": "nccl",
|
||||
"dist_url": "tcp://localhost:54321",
|
||||
"world_size": 1
|
||||
}
|
||||
}
|
@ -1,61 +0,0 @@
|
||||
{
|
||||
"resblock": "1",
|
||||
"num_gpus": 0,
|
||||
"batch_size": 4,
|
||||
"learning_rate": 0.0001,
|
||||
"adam_b1": 0.8,
|
||||
"adam_b2": 0.99,
|
||||
"lr_decay": 0.9999996,
|
||||
"seed": 1234,
|
||||
|
||||
"upsample_rates": [8,4,2,2,2,2],
|
||||
"upsample_kernel_sizes": [16,8,4,4,4,4],
|
||||
"upsample_initial_channel": 1536,
|
||||
"resblock_kernel_sizes": [3,7,11],
|
||||
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||
|
||||
"use_tanh_at_final": false,
|
||||
"use_bias_at_final": false,
|
||||
|
||||
"activation": "snakebeta",
|
||||
"snake_logscale": true,
|
||||
|
||||
"use_cqtd_instead_of_mrd": true,
|
||||
"cqtd_filters": 128,
|
||||
"cqtd_max_filters": 1024,
|
||||
"cqtd_filters_scale": 1,
|
||||
"cqtd_dilations": [1, 2, 4],
|
||||
"cqtd_hop_lengths": [512, 256, 256],
|
||||
"cqtd_n_octaves": [9, 9, 9],
|
||||
"cqtd_bins_per_octaves": [24, 36, 48],
|
||||
|
||||
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||
"use_spectral_norm": false,
|
||||
"discriminator_channel_mult": 1,
|
||||
|
||||
"use_multiscale_melloss": true,
|
||||
"lambda_melloss": 15,
|
||||
|
||||
"clip_grad_norm": 500,
|
||||
|
||||
"segment_size": 65536,
|
||||
"num_mels": 128,
|
||||
"num_freq": 2049,
|
||||
"n_fft": 2048,
|
||||
"hop_size": 512,
|
||||
"win_size": 2048,
|
||||
|
||||
"sampling_rate": 44100,
|
||||
|
||||
"fmin": 0,
|
||||
"fmax": null,
|
||||
"fmax_for_loss": null,
|
||||
|
||||
"num_workers": 4,
|
||||
|
||||
"dist_config": {
|
||||
"dist_backend": "nccl",
|
||||
"dist_url": "tcp://localhost:54321",
|
||||
"world_size": 1
|
||||
}
|
||||
}
|
@ -1,625 +0,0 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from torch.nn import Conv2d
|
||||
from torch.nn.utils import weight_norm, spectral_norm
|
||||
from torchaudio.transforms import Spectrogram, Resample
|
||||
|
||||
from env import AttrDict
|
||||
from utils import get_padding
|
||||
import typing
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
h: AttrDict,
|
||||
period: List[int],
|
||||
kernel_size: int = 5,
|
||||
stride: int = 3,
|
||||
use_spectral_norm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.period = period
|
||||
self.d_mult = h.discriminator_channel_mult
|
||||
norm_f = weight_norm if not use_spectral_norm else spectral_norm
|
||||
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(
|
||||
Conv2d(
|
||||
1,
|
||||
int(32 * self.d_mult),
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
int(32 * self.d_mult),
|
||||
int(128 * self.d_mult),
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
int(128 * self.d_mult),
|
||||
int(512 * self.d_mult),
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
int(512 * self.d_mult),
|
||||
int(1024 * self.d_mult),
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(get_padding(5, 1), 0),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
Conv2d(
|
||||
int(1024 * self.d_mult),
|
||||
int(1024 * self.d_mult),
|
||||
(kernel_size, 1),
|
||||
1,
|
||||
padding=(2, 0),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), "reflect")
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, 0.1)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
def __init__(self, h: AttrDict):
|
||||
super().__init__()
|
||||
self.mpd_reshapes = h.mpd_reshapes
|
||||
print(f"mpd_reshapes: {self.mpd_reshapes}")
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class DiscriminatorR(nn.Module):
|
||||
def __init__(self, cfg: AttrDict, resolution: List[List[int]]):
|
||||
super().__init__()
|
||||
|
||||
self.resolution = resolution
|
||||
assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
|
||||
self.lrelu_slope = 0.1
|
||||
|
||||
norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
|
||||
if hasattr(cfg, "mrd_use_spectral_norm"):
|
||||
print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
|
||||
norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
|
||||
self.d_mult = cfg.discriminator_channel_mult
|
||||
if hasattr(cfg, "mrd_channel_mult"):
|
||||
print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
|
||||
self.d_mult = cfg.mrd_channel_mult
|
||||
|
||||
self.convs = nn.ModuleList(
|
||||
[
|
||||
norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
int(32 * self.d_mult),
|
||||
int(32 * self.d_mult),
|
||||
(3, 9),
|
||||
stride=(1, 2),
|
||||
padding=(1, 4),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
int(32 * self.d_mult),
|
||||
int(32 * self.d_mult),
|
||||
(3, 9),
|
||||
stride=(1, 2),
|
||||
padding=(1, 4),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
int(32 * self.d_mult),
|
||||
int(32 * self.d_mult),
|
||||
(3, 9),
|
||||
stride=(1, 2),
|
||||
padding=(1, 4),
|
||||
)
|
||||
),
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
int(32 * self.d_mult),
|
||||
int(32 * self.d_mult),
|
||||
(3, 3),
|
||||
padding=(1, 1),
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
|
||||
x = self.spectrogram(x)
|
||||
x = x.unsqueeze(1)
|
||||
for l in self.convs:
|
||||
x = l(x)
|
||||
x = F.leaky_relu(x, self.lrelu_slope)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
|
||||
n_fft, hop_length, win_length = self.resolution
|
||||
x = F.pad(
|
||||
x,
|
||||
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
|
||||
mode="reflect",
|
||||
)
|
||||
x = x.squeeze(1)
|
||||
x = torch.stft(
|
||||
x,
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
center=False,
|
||||
return_complex=True,
|
||||
)
|
||||
x = torch.view_as_real(x) # [B, F, TT, 2]
|
||||
mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
|
||||
|
||||
return mag
|
||||
|
||||
|
||||
class MultiResolutionDiscriminator(nn.Module):
|
||||
def __init__(self, cfg, debug=False):
|
||||
super().__init__()
|
||||
self.resolutions = cfg.resolutions
|
||||
assert len(self.resolutions) == 3, (
|
||||
f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
|
||||
)
|
||||
self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])
|
||||
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(x=y)
|
||||
y_d_g, fmap_g = d(x=y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
|
||||
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
class DiscriminatorB(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
window_length: int,
|
||||
channels: int = 32,
|
||||
hop_factor: float = 0.25,
|
||||
bands: Tuple[Tuple[float, float], ...] = (
|
||||
(0.0, 0.1),
|
||||
(0.1, 0.25),
|
||||
(0.25, 0.5),
|
||||
(0.5, 0.75),
|
||||
(0.75, 1.0),
|
||||
),
|
||||
):
|
||||
super().__init__()
|
||||
self.window_length = window_length
|
||||
self.hop_factor = hop_factor
|
||||
self.spec_fn = Spectrogram(
|
||||
n_fft=window_length,
|
||||
hop_length=int(window_length * hop_factor),
|
||||
win_length=window_length,
|
||||
power=None,
|
||||
)
|
||||
n_fft = window_length // 2 + 1
|
||||
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
|
||||
self.bands = bands
|
||||
convs = lambda: nn.ModuleList(
|
||||
[
|
||||
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
|
||||
]
|
||||
)
|
||||
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
||||
|
||||
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
|
||||
|
||||
def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
# Remove DC offset
|
||||
x = x - x.mean(dim=-1, keepdims=True)
|
||||
# Peak normalize the volume of input audio
|
||||
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
||||
x = self.spec_fn(x)
|
||||
x = torch.view_as_real(x)
|
||||
x = x.permute(0, 3, 2, 1) # [B, F, T, C] -> [B, C, T, F]
|
||||
# Split into bands
|
||||
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
|
||||
return x_bands
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
x_bands = self.spectrogram(x.squeeze(1))
|
||||
fmap = []
|
||||
x = []
|
||||
|
||||
for band, stack in zip(x_bands, self.band_convs):
|
||||
for i, layer in enumerate(stack):
|
||||
band = layer(band)
|
||||
band = torch.nn.functional.leaky_relu(band, 0.1)
|
||||
if i > 0:
|
||||
fmap.append(band)
|
||||
x.append(band)
|
||||
|
||||
x = torch.cat(x, dim=-1)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
|
||||
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
class MultiBandDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
h,
|
||||
):
|
||||
"""
|
||||
Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
|
||||
and the modified code adapted from https://github.com/gemelo-ai/vocos.
|
||||
"""
|
||||
super().__init__()
|
||||
# fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
|
||||
self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
|
||||
self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])
|
||||
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
|
||||
for d in self.discriminators:
|
||||
y_d_r, fmap_r = d(x=y)
|
||||
y_d_g, fmap_g = d(x=y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
class DiscriminatorCQT(nn.Module):
|
||||
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
self.filters = cfg["cqtd_filters"]
|
||||
self.max_filters = cfg["cqtd_max_filters"]
|
||||
self.filters_scale = cfg["cqtd_filters_scale"]
|
||||
self.kernel_size = (3, 9)
|
||||
self.dilations = cfg["cqtd_dilations"]
|
||||
self.stride = (1, 2)
|
||||
|
||||
self.in_channels = cfg["cqtd_in_channels"]
|
||||
self.out_channels = cfg["cqtd_out_channels"]
|
||||
self.fs = cfg["sampling_rate"]
|
||||
self.hop_length = hop_length
|
||||
self.n_octaves = n_octaves
|
||||
self.bins_per_octave = bins_per_octave
|
||||
|
||||
# Lazy-load
|
||||
from nnAudio import features
|
||||
|
||||
self.cqt_transform = features.cqt.CQT2010v2(
|
||||
sr=self.fs * 2,
|
||||
hop_length=self.hop_length,
|
||||
n_bins=self.bins_per_octave * self.n_octaves,
|
||||
bins_per_octave=self.bins_per_octave,
|
||||
output_format="Complex",
|
||||
pad_mode="constant",
|
||||
)
|
||||
|
||||
self.conv_pres = nn.ModuleList()
|
||||
for _ in range(self.n_octaves):
|
||||
self.conv_pres.append(
|
||||
nn.Conv2d(
|
||||
self.in_channels * 2,
|
||||
self.in_channels * 2,
|
||||
kernel_size=self.kernel_size,
|
||||
padding=self.get_2d_padding(self.kernel_size),
|
||||
)
|
||||
)
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
|
||||
self.convs.append(
|
||||
nn.Conv2d(
|
||||
self.in_channels * 2,
|
||||
self.filters,
|
||||
kernel_size=self.kernel_size,
|
||||
padding=self.get_2d_padding(self.kernel_size),
|
||||
)
|
||||
)
|
||||
|
||||
in_chs = min(self.filters_scale * self.filters, self.max_filters)
|
||||
for i, dilation in enumerate(self.dilations):
|
||||
out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
|
||||
self.convs.append(
|
||||
weight_norm(
|
||||
nn.Conv2d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=self.kernel_size,
|
||||
stride=self.stride,
|
||||
dilation=(dilation, 1),
|
||||
padding=self.get_2d_padding(self.kernel_size, (dilation, 1)),
|
||||
)
|
||||
)
|
||||
)
|
||||
in_chs = out_chs
|
||||
out_chs = min(
|
||||
(self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
|
||||
self.max_filters,
|
||||
)
|
||||
self.convs.append(
|
||||
weight_norm(
|
||||
nn.Conv2d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
|
||||
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
self.conv_post = weight_norm(
|
||||
nn.Conv2d(
|
||||
out_chs,
|
||||
self.out_channels,
|
||||
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
|
||||
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
|
||||
)
|
||||
)
|
||||
|
||||
self.activation = torch.nn.LeakyReLU(negative_slope=0.1)
|
||||
self.resample = Resample(orig_freq=self.fs, new_freq=self.fs * 2)
|
||||
|
||||
self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
|
||||
if self.cqtd_normalize_volume:
|
||||
print(
|
||||
"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
|
||||
)
|
||||
|
||||
def get_2d_padding(
|
||||
self,
|
||||
kernel_size: typing.Tuple[int, int],
|
||||
dilation: typing.Tuple[int, int] = (1, 1),
|
||||
):
|
||||
return (
|
||||
((kernel_size[0] - 1) * dilation[0]) // 2,
|
||||
((kernel_size[1] - 1) * dilation[1]) // 2,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
|
||||
if self.cqtd_normalize_volume:
|
||||
# Remove DC offset
|
||||
x = x - x.mean(dim=-1, keepdims=True)
|
||||
# Peak normalize the volume of input audio
|
||||
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
||||
|
||||
x = self.resample(x)
|
||||
|
||||
z = self.cqt_transform(x)
|
||||
|
||||
z_amplitude = z[:, :, :, 0].unsqueeze(1)
|
||||
z_phase = z[:, :, :, 1].unsqueeze(1)
|
||||
|
||||
z = torch.cat([z_amplitude, z_phase], dim=1)
|
||||
z = torch.permute(z, (0, 1, 3, 2)) # [B, C, W, T] -> [B, C, T, W]
|
||||
|
||||
latent_z = []
|
||||
for i in range(self.n_octaves):
|
||||
latent_z.append(
|
||||
self.conv_pres[i](
|
||||
z[
|
||||
:,
|
||||
:,
|
||||
:,
|
||||
i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
|
||||
]
|
||||
)
|
||||
)
|
||||
latent_z = torch.cat(latent_z, dim=-1)
|
||||
|
||||
for i, l in enumerate(self.convs):
|
||||
latent_z = l(latent_z)
|
||||
|
||||
latent_z = self.activation(latent_z)
|
||||
fmap.append(latent_z)
|
||||
|
||||
latent_z = self.conv_post(latent_z)
|
||||
|
||||
return latent_z, fmap
|
||||
|
||||
|
||||
class MultiScaleSubbandCQTDiscriminator(nn.Module):
|
||||
def __init__(self, cfg: AttrDict):
|
||||
super().__init__()
|
||||
|
||||
self.cfg = cfg
|
||||
# Using get with defaults
|
||||
self.cfg["cqtd_filters"] = self.cfg.get("cqtd_filters", 32)
|
||||
self.cfg["cqtd_max_filters"] = self.cfg.get("cqtd_max_filters", 1024)
|
||||
self.cfg["cqtd_filters_scale"] = self.cfg.get("cqtd_filters_scale", 1)
|
||||
self.cfg["cqtd_dilations"] = self.cfg.get("cqtd_dilations", [1, 2, 4])
|
||||
self.cfg["cqtd_in_channels"] = self.cfg.get("cqtd_in_channels", 1)
|
||||
self.cfg["cqtd_out_channels"] = self.cfg.get("cqtd_out_channels", 1)
|
||||
# Multi-scale params to loop over
|
||||
self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
|
||||
self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
|
||||
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
|
||||
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorCQT(
|
||||
self.cfg,
|
||||
hop_length=self.cfg["cqtd_hop_lengths"][i],
|
||||
n_octaves=self.cfg["cqtd_n_octaves"][i],
|
||||
bins_per_octave=self.cfg["cqtd_bins_per_octaves"][i],
|
||||
)
|
||||
for i in range(len(self.cfg["cqtd_hop_lengths"]))
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
|
||||
for disc in self.discriminators:
|
||||
y_d_r, fmap_r = disc(y)
|
||||
y_d_g, fmap_g = disc(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class CombinedDiscriminator(nn.Module):
|
||||
"""
|
||||
Wrapper of chaining multiple discrimiantor architectures.
|
||||
Example: combine mbd and cqtd as a single class
|
||||
"""
|
||||
|
||||
def __init__(self, list_discriminator: List[nn.Module]):
|
||||
super().__init__()
|
||||
self.discrimiantor = nn.ModuleList(list_discriminator)
|
||||
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
|
||||
for disc in self.discrimiantor:
|
||||
y_d_r, y_d_g, fmap_r, fmap_g = disc(y, y_hat)
|
||||
y_d_rs.extend(y_d_r)
|
||||
fmap_rs.extend(fmap_r)
|
||||
y_d_gs.extend(y_d_g)
|
||||
fmap_gs.extend(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
@ -1,85 +0,0 @@
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
import torch
|
||||
import librosa
|
||||
from utils import load_checkpoint
|
||||
from meldataset import get_mel_spectrogram
|
||||
from scipy.io.wavfile import write
|
||||
from env import AttrDict
|
||||
from meldataset import MAX_WAV_VALUE
|
||||
from bigvgan import BigVGAN as Generator
|
||||
|
||||
h = None
|
||||
device = None
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def inference(a, h):
|
||||
generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device)
|
||||
|
||||
state_dict_g = load_checkpoint(a.checkpoint_file, device)
|
||||
generator.load_state_dict(state_dict_g["generator"])
|
||||
|
||||
filelist = os.listdir(a.input_wavs_dir)
|
||||
|
||||
os.makedirs(a.output_dir, exist_ok=True)
|
||||
|
||||
generator.eval()
|
||||
generator.remove_weight_norm()
|
||||
with torch.no_grad():
|
||||
for i, filname in enumerate(filelist):
|
||||
# Load the ground truth audio and resample if necessary
|
||||
wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True)
|
||||
wav = torch.FloatTensor(wav).to(device)
|
||||
# Compute mel spectrogram from the ground truth audio
|
||||
x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
|
||||
|
||||
y_g_hat = generator(x)
|
||||
|
||||
audio = y_g_hat.squeeze()
|
||||
audio = audio * MAX_WAV_VALUE
|
||||
audio = audio.cpu().numpy().astype("int16")
|
||||
|
||||
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav")
|
||||
write(output_file, h.sampling_rate, audio)
|
||||
print(output_file)
|
||||
|
||||
|
||||
def main():
|
||||
print("Initializing Inference Process..")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input_wavs_dir", default="test_files")
|
||||
parser.add_argument("--output_dir", default="generated_files")
|
||||
parser.add_argument("--checkpoint_file", required=True)
|
||||
parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
|
||||
|
||||
a = parser.parse_args()
|
||||
|
||||
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
|
||||
with open(config_file) as f:
|
||||
data = f.read()
|
||||
|
||||
global h
|
||||
json_config = json.loads(data)
|
||||
h = AttrDict(json_config)
|
||||
|
||||
torch.manual_seed(h.seed)
|
||||
global device
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(h.seed)
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
inference(a, h)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,100 +0,0 @@
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import glob
|
||||
import os
|
||||
import numpy as np
|
||||
import argparse
|
||||
import json
|
||||
import torch
|
||||
from scipy.io.wavfile import write
|
||||
from env import AttrDict
|
||||
from meldataset import MAX_WAV_VALUE
|
||||
from bigvgan import BigVGAN as Generator
|
||||
|
||||
h = None
|
||||
device = None
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def load_checkpoint(filepath, device):
|
||||
assert os.path.isfile(filepath)
|
||||
print(f"Loading '{filepath}'")
|
||||
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||
print("Complete.")
|
||||
return checkpoint_dict
|
||||
|
||||
|
||||
def scan_checkpoint(cp_dir, prefix):
|
||||
pattern = os.path.join(cp_dir, prefix + "*")
|
||||
cp_list = glob.glob(pattern)
|
||||
if len(cp_list) == 0:
|
||||
return ""
|
||||
return sorted(cp_list)[-1]
|
||||
|
||||
|
||||
def inference(a, h):
|
||||
generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device)
|
||||
|
||||
state_dict_g = load_checkpoint(a.checkpoint_file, device)
|
||||
generator.load_state_dict(state_dict_g["generator"])
|
||||
|
||||
filelist = os.listdir(a.input_mels_dir)
|
||||
|
||||
os.makedirs(a.output_dir, exist_ok=True)
|
||||
|
||||
generator.eval()
|
||||
generator.remove_weight_norm()
|
||||
with torch.no_grad():
|
||||
for i, filname in enumerate(filelist):
|
||||
# Load the mel spectrogram in .npy format
|
||||
x = np.load(os.path.join(a.input_mels_dir, filname))
|
||||
x = torch.FloatTensor(x).to(device)
|
||||
if len(x.shape) == 2:
|
||||
x = x.unsqueeze(0)
|
||||
|
||||
y_g_hat = generator(x)
|
||||
|
||||
audio = y_g_hat.squeeze()
|
||||
audio = audio * MAX_WAV_VALUE
|
||||
audio = audio.cpu().numpy().astype("int16")
|
||||
|
||||
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav")
|
||||
write(output_file, h.sampling_rate, audio)
|
||||
print(output_file)
|
||||
|
||||
|
||||
def main():
|
||||
print("Initializing Inference Process..")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input_mels_dir", default="test_mel_files")
|
||||
parser.add_argument("--output_dir", default="generated_files_from_mel")
|
||||
parser.add_argument("--checkpoint_file", required=True)
|
||||
parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
|
||||
|
||||
a = parser.parse_args()
|
||||
|
||||
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
|
||||
with open(config_file) as f:
|
||||
data = f.read()
|
||||
|
||||
global h
|
||||
json_config = json.loads(data)
|
||||
h = AttrDict(json_config)
|
||||
|
||||
torch.manual_seed(h.seed)
|
||||
global device
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(h.seed)
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
inference(a, h)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,238 +0,0 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from scipy import signal
|
||||
|
||||
import typing
|
||||
from typing import List, Tuple
|
||||
from collections import namedtuple
|
||||
import math
|
||||
import functools
|
||||
|
||||
|
||||
# Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
"""Compute distance between mel spectrograms. Can be used
|
||||
in a multi-scale way.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n_mels : List[int]
|
||||
Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320],
|
||||
window_lengths : List[int], optional
|
||||
Length of each window of each STFT, by default [32, 64, 128, 256, 512, 1024, 2048]
|
||||
loss_fn : typing.Callable, optional
|
||||
How to compare each loss, by default nn.L1Loss()
|
||||
clamp_eps : float, optional
|
||||
Clamp on the log magnitude, below, by default 1e-5
|
||||
mag_weight : float, optional
|
||||
Weight of raw magnitude portion of loss, by default 0.0 (no ampliciation on mag part)
|
||||
log_weight : float, optional
|
||||
Weight of log magnitude portion of loss, by default 1.0
|
||||
pow : float, optional
|
||||
Power to raise magnitude to before taking log, by default 1.0
|
||||
weight : float, optional
|
||||
Weight of this loss, by default 1.0
|
||||
match_stride : bool, optional
|
||||
Whether to match the stride of convolutional layers, by default False
|
||||
|
||||
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
|
||||
Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sampling_rate: int,
|
||||
n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
|
||||
window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
|
||||
loss_fn: typing.Callable = nn.L1Loss(),
|
||||
clamp_eps: float = 1e-5,
|
||||
mag_weight: float = 0.0,
|
||||
log_weight: float = 1.0,
|
||||
pow: float = 1.0,
|
||||
weight: float = 1.0,
|
||||
match_stride: bool = False,
|
||||
mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0],
|
||||
mel_fmax: List[float] = [None, None, None, None, None, None, None],
|
||||
window_type: str = "hann",
|
||||
):
|
||||
super().__init__()
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
STFTParams = namedtuple(
|
||||
"STFTParams",
|
||||
["window_length", "hop_length", "window_type", "match_stride"],
|
||||
)
|
||||
|
||||
self.stft_params = [
|
||||
STFTParams(
|
||||
window_length=w,
|
||||
hop_length=w // 4,
|
||||
match_stride=match_stride,
|
||||
window_type=window_type,
|
||||
)
|
||||
for w in window_lengths
|
||||
]
|
||||
self.n_mels = n_mels
|
||||
self.loss_fn = loss_fn
|
||||
self.clamp_eps = clamp_eps
|
||||
self.log_weight = log_weight
|
||||
self.mag_weight = mag_weight
|
||||
self.weight = weight
|
||||
self.mel_fmin = mel_fmin
|
||||
self.mel_fmax = mel_fmax
|
||||
self.pow = pow
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def get_window(
|
||||
window_type,
|
||||
window_length,
|
||||
):
|
||||
return signal.get_window(window_type, window_length)
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def get_mel_filters(sr, n_fft, n_mels, fmin, fmax):
|
||||
return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
|
||||
|
||||
def mel_spectrogram(
|
||||
self,
|
||||
wav,
|
||||
n_mels,
|
||||
fmin,
|
||||
fmax,
|
||||
window_length,
|
||||
hop_length,
|
||||
match_stride,
|
||||
window_type,
|
||||
):
|
||||
"""
|
||||
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
|
||||
https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
|
||||
"""
|
||||
B, C, T = wav.shape
|
||||
|
||||
if match_stride:
|
||||
assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
|
||||
right_pad = math.ceil(T / hop_length) * hop_length - T
|
||||
pad = (window_length - hop_length) // 2
|
||||
else:
|
||||
right_pad = 0
|
||||
pad = 0
|
||||
|
||||
wav = torch.nn.functional.pad(wav, (pad, pad + right_pad), mode="reflect")
|
||||
|
||||
window = self.get_window(window_type, window_length)
|
||||
window = torch.from_numpy(window).to(wav.device).float()
|
||||
|
||||
stft = torch.stft(
|
||||
wav.reshape(-1, T),
|
||||
n_fft=window_length,
|
||||
hop_length=hop_length,
|
||||
window=window,
|
||||
return_complex=True,
|
||||
center=True,
|
||||
)
|
||||
_, nf, nt = stft.shape
|
||||
stft = stft.reshape(B, C, nf, nt)
|
||||
if match_stride:
|
||||
"""
|
||||
Drop first two and last two frames, which are added, because of padding. Now num_frames * hop_length = num_samples.
|
||||
"""
|
||||
stft = stft[..., 2:-2]
|
||||
magnitude = torch.abs(stft)
|
||||
|
||||
nf = magnitude.shape[2]
|
||||
mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
|
||||
mel_basis = torch.from_numpy(mel_basis).to(wav.device)
|
||||
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
|
||||
mel_spectrogram = mel_spectrogram.transpose(-1, 2)
|
||||
|
||||
return mel_spectrogram
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes mel loss between an estimate and a reference
|
||||
signal.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Estimate signal
|
||||
y : torch.Tensor
|
||||
Reference signal
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Mel loss.
|
||||
"""
|
||||
|
||||
loss = 0.0
|
||||
for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
|
||||
kwargs = {
|
||||
"n_mels": n_mels,
|
||||
"fmin": fmin,
|
||||
"fmax": fmax,
|
||||
"window_length": s.window_length,
|
||||
"hop_length": s.hop_length,
|
||||
"match_stride": s.match_stride,
|
||||
"window_type": s.window_type,
|
||||
}
|
||||
|
||||
x_mels = self.mel_spectrogram(x, **kwargs)
|
||||
y_mels = self.mel_spectrogram(y, **kwargs)
|
||||
x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
|
||||
y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
|
||||
|
||||
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
|
||||
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
# Loss functions
|
||||
def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
loss += torch.mean(torch.abs(rl - gl))
|
||||
|
||||
return loss * 2 # This equates to lambda=2.0 for the feature matching loss
|
||||
|
||||
|
||||
def discriminator_loss(
|
||||
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
r_loss = torch.mean((1 - dr) ** 2)
|
||||
g_loss = torch.mean(dg**2)
|
||||
loss += r_loss + g_loss
|
||||
r_losses.append(r_loss.item())
|
||||
g_losses.append(g_loss.item())
|
||||
|
||||
return loss, r_losses, g_losses
|
||||
|
||||
|
||||
def generator_loss(
|
||||
disc_outputs: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
l = torch.mean((1 - dg) ** 2)
|
||||
gen_losses.append(l)
|
||||
loss += l
|
||||
|
||||
return loss, gen_losses
|
@ -1,370 +0,0 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
import librosa
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
import pathlib
|
||||
from tqdm import tqdm
|
||||
from typing import List, Tuple, Optional
|
||||
from .env import AttrDict
|
||||
|
||||
MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
|
||||
|
||||
|
||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression(x, C=1):
|
||||
return np.exp(x) / C
|
||||
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression_torch(x, C=1):
|
||||
return torch.exp(x) / C
|
||||
|
||||
|
||||
def spectral_normalize_torch(magnitudes):
|
||||
return dynamic_range_compression_torch(magnitudes)
|
||||
|
||||
|
||||
def spectral_de_normalize_torch(magnitudes):
|
||||
return dynamic_range_decompression_torch(magnitudes)
|
||||
|
||||
|
||||
mel_basis_cache = {}
|
||||
hann_window_cache = {}
|
||||
|
||||
|
||||
def mel_spectrogram(
|
||||
y: torch.Tensor,
|
||||
n_fft: int,
|
||||
num_mels: int,
|
||||
sampling_rate: int,
|
||||
hop_size: int,
|
||||
win_size: int,
|
||||
fmin: int,
|
||||
fmax: int = None,
|
||||
center: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the mel spectrogram of an input signal.
|
||||
This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Input signal.
|
||||
n_fft (int): FFT size.
|
||||
num_mels (int): Number of mel bins.
|
||||
sampling_rate (int): Sampling rate of the input signal.
|
||||
hop_size (int): Hop size for STFT.
|
||||
win_size (int): Window size for STFT.
|
||||
fmin (int): Minimum frequency for mel filterbank.
|
||||
fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
|
||||
center (bool): Whether to pad the input to center the frames. Default is False.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Mel spectrogram.
|
||||
"""
|
||||
if torch.min(y) < -1.0:
|
||||
print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
|
||||
if torch.max(y) > 1.0:
|
||||
print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
|
||||
|
||||
device = y.device
|
||||
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
|
||||
|
||||
if key not in mel_basis_cache:
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
|
||||
hann_window_cache[key] = torch.hann_window(win_size).to(device)
|
||||
|
||||
mel_basis = mel_basis_cache[key]
|
||||
hann_window = hann_window_cache[key]
|
||||
|
||||
padding = (n_fft - hop_size) // 2
|
||||
y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
|
||||
|
||||
spec = torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_size,
|
||||
win_length=win_size,
|
||||
window=hann_window,
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
return_complex=True,
|
||||
)
|
||||
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
|
||||
|
||||
mel_spec = torch.matmul(mel_basis, spec)
|
||||
mel_spec = spectral_normalize_torch(mel_spec)
|
||||
|
||||
return mel_spec
|
||||
|
||||
|
||||
def get_mel_spectrogram(wav, h):
|
||||
"""
|
||||
Generate mel spectrogram from a waveform using given hyperparameters.
|
||||
|
||||
Args:
|
||||
wav (torch.Tensor): Input waveform.
|
||||
h: Hyperparameters object with attributes n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Mel spectrogram.
|
||||
"""
|
||||
return mel_spectrogram(
|
||||
wav,
|
||||
h.n_fft,
|
||||
h.num_mels,
|
||||
h.sampling_rate,
|
||||
h.hop_size,
|
||||
h.win_size,
|
||||
h.fmin,
|
||||
h.fmax,
|
||||
)
|
||||
|
||||
|
||||
def get_dataset_filelist(a):
|
||||
training_files = []
|
||||
validation_files = []
|
||||
list_unseen_validation_files = []
|
||||
|
||||
with open(a.input_training_file, "r", encoding="utf-8") as fi:
|
||||
training_files = [
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
|
||||
]
|
||||
print(f"first training file: {training_files[0]}")
|
||||
|
||||
with open(a.input_validation_file, "r", encoding="utf-8") as fi:
|
||||
validation_files = [
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
|
||||
]
|
||||
print(f"first validation file: {validation_files[0]}")
|
||||
|
||||
for i in range(len(a.list_input_unseen_validation_file)):
|
||||
with open(a.list_input_unseen_validation_file[i], "r", encoding="utf-8") as fi:
|
||||
unseen_validation_files = [
|
||||
os.path.join(a.list_input_unseen_wavs_dir[i], x.split("|")[0] + ".wav")
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
]
|
||||
print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}")
|
||||
list_unseen_validation_files.append(unseen_validation_files)
|
||||
|
||||
return training_files, validation_files, list_unseen_validation_files
|
||||
|
||||
|
||||
class MelDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
training_files: List[str],
|
||||
hparams: AttrDict,
|
||||
segment_size: int,
|
||||
n_fft: int,
|
||||
num_mels: int,
|
||||
hop_size: int,
|
||||
win_size: int,
|
||||
sampling_rate: int,
|
||||
fmin: int,
|
||||
fmax: Optional[int],
|
||||
split: bool = True,
|
||||
shuffle: bool = True,
|
||||
device: str = None,
|
||||
fmax_loss: Optional[int] = None,
|
||||
fine_tuning: bool = False,
|
||||
base_mels_path: str = None,
|
||||
is_seen: bool = True,
|
||||
):
|
||||
self.audio_files = training_files
|
||||
random.seed(1234)
|
||||
if shuffle:
|
||||
random.shuffle(self.audio_files)
|
||||
self.hparams = hparams
|
||||
self.is_seen = is_seen
|
||||
if self.is_seen:
|
||||
self.name = pathlib.Path(self.audio_files[0]).parts[0]
|
||||
else:
|
||||
self.name = "-".join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/")
|
||||
|
||||
self.segment_size = segment_size
|
||||
self.sampling_rate = sampling_rate
|
||||
self.split = split
|
||||
self.n_fft = n_fft
|
||||
self.num_mels = num_mels
|
||||
self.hop_size = hop_size
|
||||
self.win_size = win_size
|
||||
self.fmin = fmin
|
||||
self.fmax = fmax
|
||||
self.fmax_loss = fmax_loss
|
||||
self.device = device
|
||||
self.fine_tuning = fine_tuning
|
||||
self.base_mels_path = base_mels_path
|
||||
|
||||
print("[INFO] checking dataset integrity...")
|
||||
for i in tqdm(range(len(self.audio_files))):
|
||||
assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found"
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
|
||||
try:
|
||||
filename = self.audio_files[index]
|
||||
|
||||
# Use librosa.load that ensures loading waveform into mono with [-1, 1] float values
|
||||
# Audio is ndarray with shape [T_time]. Disable auto-resampling here to minimize overhead
|
||||
# The on-the-fly resampling during training will be done only for the obtained random chunk
|
||||
audio, source_sampling_rate = librosa.load(filename, sr=None, mono=True)
|
||||
|
||||
# Main logic that uses <mel, audio> pair for training BigVGAN
|
||||
if not self.fine_tuning:
|
||||
if self.split: # Training step
|
||||
# Obtain randomized audio chunk
|
||||
if source_sampling_rate != self.sampling_rate:
|
||||
# Adjust segment size to crop if the source sr is different
|
||||
target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate))
|
||||
else:
|
||||
target_segment_size = self.segment_size
|
||||
|
||||
# Compute upper bound index for the random chunk
|
||||
random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size)
|
||||
|
||||
# Crop or pad audio to obtain random chunk with target_segment_size
|
||||
if audio.shape[0] >= target_segment_size:
|
||||
audio_start = random.randint(0, random_chunk_upper_bound)
|
||||
audio = audio[audio_start : audio_start + target_segment_size]
|
||||
else:
|
||||
audio = np.pad(
|
||||
audio,
|
||||
(0, target_segment_size - audio.shape[0]),
|
||||
mode="constant",
|
||||
)
|
||||
|
||||
# Resample audio chunk to self.sampling rate
|
||||
if source_sampling_rate != self.sampling_rate:
|
||||
audio = librosa.resample(
|
||||
audio,
|
||||
orig_sr=source_sampling_rate,
|
||||
target_sr=self.sampling_rate,
|
||||
)
|
||||
if audio.shape[0] > self.segment_size:
|
||||
# trim last elements to match self.segment_size (e.g., 16385 for 44khz downsampled to 24khz -> 16384)
|
||||
audio = audio[: self.segment_size]
|
||||
|
||||
else: # Validation step
|
||||
# Resample full audio clip to target sampling rate
|
||||
if source_sampling_rate != self.sampling_rate:
|
||||
audio = librosa.resample(
|
||||
audio,
|
||||
orig_sr=source_sampling_rate,
|
||||
target_sr=self.sampling_rate,
|
||||
)
|
||||
# Trim last elements to match audio length to self.hop_size * n for evaluation
|
||||
if (audio.shape[0] % self.hop_size) != 0:
|
||||
audio = audio[: -(audio.shape[0] % self.hop_size)]
|
||||
|
||||
# BigVGAN is trained using volume-normalized waveform
|
||||
audio = librosa.util.normalize(audio) * 0.95
|
||||
|
||||
# Cast ndarray to torch tensor
|
||||
audio = torch.FloatTensor(audio)
|
||||
audio = audio.unsqueeze(0) # [B(1), self.segment_size]
|
||||
|
||||
# Compute mel spectrogram corresponding to audio
|
||||
mel = mel_spectrogram(
|
||||
audio,
|
||||
self.n_fft,
|
||||
self.num_mels,
|
||||
self.sampling_rate,
|
||||
self.hop_size,
|
||||
self.win_size,
|
||||
self.fmin,
|
||||
self.fmax,
|
||||
center=False,
|
||||
) # [B(1), self.num_mels, self.segment_size // self.hop_size]
|
||||
|
||||
# Fine-tuning logic that uses pre-computed mel. Example: Using TTS model-generated mel as input
|
||||
else:
|
||||
# For fine-tuning, assert that the waveform is in the defined sampling_rate
|
||||
# Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly)
|
||||
assert source_sampling_rate == self.sampling_rate, (
|
||||
f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
|
||||
)
|
||||
|
||||
# Cast ndarray to torch tensor
|
||||
audio = torch.FloatTensor(audio)
|
||||
audio = audio.unsqueeze(0) # [B(1), T_time]
|
||||
|
||||
# Load pre-computed mel from disk
|
||||
mel = np.load(
|
||||
os.path.join(
|
||||
self.base_mels_path,
|
||||
os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
|
||||
)
|
||||
)
|
||||
mel = torch.from_numpy(mel)
|
||||
|
||||
if len(mel.shape) < 3:
|
||||
mel = mel.unsqueeze(0) # ensure [B, C, T]
|
||||
|
||||
if self.split:
|
||||
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
|
||||
|
||||
if audio.size(1) >= self.segment_size:
|
||||
mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
|
||||
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
|
||||
audio = audio[
|
||||
:,
|
||||
mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size,
|
||||
]
|
||||
|
||||
# Pad pre-computed mel and audio to match length to ensuring fine-tuning without error.
|
||||
# NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio>
|
||||
# To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is the integer multiple of self.hop_size
|
||||
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
|
||||
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
|
||||
|
||||
# Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
|
||||
mel_loss = mel_spectrogram(
|
||||
audio,
|
||||
self.n_fft,
|
||||
self.num_mels,
|
||||
self.sampling_rate,
|
||||
self.hop_size,
|
||||
self.win_size,
|
||||
self.fmin,
|
||||
self.fmax_loss,
|
||||
center=False,
|
||||
) # [B(1), self.num_mels, self.segment_size // self.hop_size]
|
||||
|
||||
# Shape sanity checks
|
||||
assert (
|
||||
audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size
|
||||
), (
|
||||
f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
|
||||
)
|
||||
|
||||
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
|
||||
|
||||
# If it encounters error during loading the data, skip this sample and load random other sample to the batch
|
||||
except Exception as e:
|
||||
if self.fine_tuning:
|
||||
raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
|
||||
else:
|
||||
print(f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}")
|
||||
return self[random.randrange(len(self))]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audio_files)
|
@ -1 +0,0 @@
|
||||
|
@ -1,4 +0,0 @@
|
||||
| Field | Response |
|
||||
| :--------------------------------------------------------------------------------------------------------- | :--------------------------------------------------- |
|
||||
| Participation considerations from adversely impacted groups protected classes in model design and testing: | None |
|
||||
| Measures taken to mitigate against unwanted bias: | No measures taken to mitigate against unwanted bias. |
|
@ -1,13 +0,0 @@
|
||||
| Field | Response |
|
||||
| :---------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Intended Application & Domain: | Generating waveform from mel spectrogram. |
|
||||
| Model Type: | Convolutional Neural Network (CNN) |
|
||||
| Intended Users: | This model is intended for developers to synthesize and generate waveforms from the AI-generated mel spectrograms. |
|
||||
| Output: | Audio Waveform |
|
||||
| Describe how the model works: | Model generates audio waveform corresponding to the input mel spectrogram. |
|
||||
| Name the adversely impacted groups this has been tested to deliver comparable outcomes regardless of: | Not Applicable |
|
||||
| Technical Limitations: | This may not perform well on synthetically-generated mel spectrograms that deviate significantly from the profile of mel spectrograms on which this was trained. |
|
||||
| Verified to have met prescribed NVIDIA quality standards: | Yes |
|
||||
| Performance Metrics: | Perceptual Evaluation of Speech Quality (PESQ), Virtual Speech Quality Objective Listener (VISQOL), Multi-resolution STFT (MRSTFT), Mel cepstral distortion (MCD), Periodicity RMSE, Voice/Unvoiced F1 Score (V/UV F1) |
|
||||
| Potential Known Risks: | This model may generate low-quality or distorted soundwaves. |
|
||||
| Licensing: | https://github.com/NVIDIA/BigVGAN/blob/main/LICENSE |
|
@ -1,126 +0,0 @@
|
||||
# Model Overview
|
||||
|
||||
## Description:
|
||||
|
||||
BigVGAN is a generative AI model specialized in synthesizing audio waveforms using Mel spectrogram as inputs.
|
||||
|
||||
<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
|
||||
|
||||
BigVGAN is a fully convolutional architecture with several upsampling blocks using transposed convolution followed by multiple residual dilated convolution layers.
|
||||
|
||||
BigVGAN consists of a novel module, called anti-aliased multi-periodicity composition (AMP), which is specifically designed for generating waveforms. AMP is specialized in synthesizing high-frequency and periodic soundwaves drawing inspiration from audio signal processing principles.
|
||||
|
||||
It applies a periodic activation function, called Snake, which provides an inductive bias to the architecture in generating periodic soundwaves. It also applies anti-aliasing filters to reduce undesired artifacts in the generated waveforms. <br>
|
||||
|
||||
This model is ready for commercial use.<br>
|
||||
|
||||
## References(s):
|
||||
|
||||
- [BigVGAN: A Universal Neural Vocoder with Large-Scale Training](https://arxiv.org/abs/2206.04658) <br>
|
||||
- [Project Page](https://research.nvidia.com/labs/adlr/projects/bigvgan/) <br>
|
||||
- [Audio Demo](https://bigvgan-demo.github.io/) <br>
|
||||
|
||||
## Model Architecture:
|
||||
|
||||
**Architecture Type:** Convolution Neural Network (CNN) <br>
|
||||
**Network Architecture:** You can see the details of this model on this link: https://github.com/NVIDIA/BigVGAN and the related paper can be found here: https://arxiv.org/abs/2206.04658<br>
|
||||
**Model Version:** 2.0 <br>
|
||||
|
||||
## Input:
|
||||
|
||||
**Input Type:** Audio <br>
|
||||
**Input Format:** Mel Spectrogram <br>
|
||||
**Input Parameters:** None <br>
|
||||
**Other Properties Related to Input:** The input mel spectrogram has shape `[batch, channels, frames]`, where `channels` refers to the number of mel bands defined by the model and `frames` refers to the temporal length. The model supports arbitrary long `frames` that fits into the GPU memory.
|
||||
|
||||
## Output:
|
||||
|
||||
**Input Type:** Audio <br>
|
||||
**Output Format:** Audio Waveform <br>
|
||||
**Output Parameters:** None <br>
|
||||
**Other Properties Related to Output:** The output audio waveform has shape `[batch, 1, time]`, where `1` refers to the mono audio channels and `time` refers to the temporal length. `time` is defined as a fixed integer multiple of input `frames`, which is an upsampling ratio of the model (`time = upsampling ratio * frames`). The output audio waveform consitutes float values with a range of `[-1, 1]`.
|
||||
|
||||
## Software Integration:
|
||||
|
||||
**Runtime Engine(s):** PyTorch
|
||||
|
||||
**Supported Hardware Microarchitecture Compatibility:** NVIDIA Ampere, NVIDIA Hopper, NVIDIA Lovelace, NVIDIA Turing, NVIDIA Volta <br>
|
||||
|
||||
## Preferred/Supported Operating System(s):
|
||||
|
||||
Linux
|
||||
|
||||
## Model Version(s):
|
||||
|
||||
v2.0
|
||||
|
||||
## Training, Testing, and Evaluation Datasets:
|
||||
|
||||
### Training Dataset:
|
||||
|
||||
The dataset contains diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
|
||||
|
||||
**Links:**
|
||||
|
||||
- [AAM: Artificial Audio Multitracks Dataset](https://zenodo.org/records/5794629)
|
||||
- [AudioCaps](https://audiocaps.github.io/)
|
||||
- [AudioSet](https://research.google.com/audioset/index.html)
|
||||
- [common-accent](https://huggingface.co/datasets/DTU54DL/common-accent)
|
||||
- [Crowd Sourced Emotional Multimodal Actors Dataset (CREMA-D)](https://ieeexplore.ieee.org/document/6849440)
|
||||
- [DCASE2017 Challenge, Task 4: Large-scale weakly supervised sound event detection for smart cars](https://dcase.community/challenge2017/task-large-scale-sound-event-detection)
|
||||
- [FSDnoisy18k](https://zenodo.org/records/2529934)
|
||||
- [Free Universal Sound Separation Dataset](https://zenodo.org/records/3694384)
|
||||
- [Greatest Hits dataset](https://andrewowens.com/vis/)
|
||||
- [GTZAN](https://ieeexplore.ieee.org/document/1021072)
|
||||
- [JL corpus](https://www.kaggle.com/datasets/tli725/jl-corpus)
|
||||
- [Medley-solos-DB: a cross-collection dataset for musical instrument recognition](https://zenodo.org/records/3464194)
|
||||
- [MUSAN: A Music, Speech, and Noise Corpus](https://www.openslr.org/17/)
|
||||
- [MusicBench](https://huggingface.co/datasets/amaai-lab/MusicBench)
|
||||
- [MusicCaps](https://www.kaggle.com/datasets/googleai/musiccaps)
|
||||
- [MusicNet](https://www.kaggle.com/datasets/imsparsh/musicnet-dataset)
|
||||
- [NSynth](https://magenta.tensorflow.org/datasets/nsynth)
|
||||
- [OnAir-Music-Dataset](https://github.com/sevagh/OnAir-Music-Dataset)
|
||||
- [Audio Piano Triads Dataset](https://zenodo.org/records/4740877)
|
||||
- [Pitch Audio Dataset (Surge synthesizer)](https://zenodo.org/records/4677097)
|
||||
- [SONYC Urban Sound Tagging (SONYC-UST): a multilabel dataset from an urban acoustic sensor network](https://zenodo.org/records/3966543)
|
||||
- [VocalSound: A Dataset for Improving Human Vocal Sounds Recognition](https://arxiv.org/abs/2205.03433)
|
||||
- [WavText5K](https://github.com/microsoft/WavText5K)
|
||||
- [CSS10: A Collection of Single Speaker Speech Datasets for 10 Languages](https://github.com/Kyubyong/css10)
|
||||
- [Hi-Fi Multi-Speaker English TTS Dataset (Hi-Fi TTS)](https://www.openslr.org/109/)
|
||||
- [IIIT-H Indic Speech Databases](http://festvox.org/databases/iiit_voices/)
|
||||
- [Libri-Light: A Benchmark for ASR with Limited or No Supervision](https://arxiv.org/abs/1912.07875)
|
||||
- [LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech](https://www.openslr.org/60)
|
||||
- [LibriTTS-R: A Restored Multi-Speaker Text-to-Speech Corpus](https://www.openslr.org/141/)
|
||||
- [The SIWIS French Speech Synthesis Database](https://datashare.ed.ac.uk/handle/10283/2353)
|
||||
- [Crowdsourced high-quality Colombian Spanish speech data set](https://openslr.org/72/)
|
||||
- [TTS-Portuguese Corpus](https://github.com/Edresson/TTS-Portuguese-Corpus)
|
||||
- [CSTR VCTK Corpus: English Multi-speaker Corpus for CSTR Voice Cloning Toolkit](https://datashare.ed.ac.uk/handle/10283/3443)
|
||||
|
||||
\*\* Data Collection Method by dataset <br>
|
||||
|
||||
- Human <br>
|
||||
|
||||
\*\* Labeling Method by dataset (for those with labels) <br>
|
||||
|
||||
- Hybrid: Automated, Human, Unknown <br>
|
||||
|
||||
### Evaluating Dataset:
|
||||
|
||||
Properties: The audio generation quality of BigVGAN is evaluated using `dev` splits of the [LibriTTS dataset](https://www.openslr.org/60/) and [Hi-Fi TTS dataset](https://www.openslr.org/109/). The datasets include speech in English language with equal balance of genders.
|
||||
|
||||
\*\* Data Collection Method by dataset <br>
|
||||
|
||||
- Human <br>
|
||||
|
||||
\*\* Labeling Method by dataset <br>
|
||||
|
||||
- Automated <br>
|
||||
|
||||
## Inference:
|
||||
|
||||
**Engine:** PyTorch <br>
|
||||
**Test Hardware:** NVIDIA A100 GPU <br>
|
||||
|
||||
## Ethical Considerations:
|
||||
|
||||
NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. For more detailed information on ethical considerations for this model, please see the Model Card++ Explainability, Bias, Safety & Security, and Privacy Subcards. Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
|
@ -1,14 +0,0 @@
|
||||
| Field | Response |
|
||||
| :------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------- |
|
||||
| Generatable or reverse engineerable personal information? | None |
|
||||
| Protected class data used to create this model? | None |
|
||||
| Was consent obtained for any personal data used? | Not Applicable (No Personal Data) |
|
||||
| How often is dataset reviewed? | Before Release |
|
||||
| Is a mechanism in place to honor data subject right of access or deletion of personal data? | Not Applicable |
|
||||
| If personal collected for the development of the model, was it collected directly by NVIDIA? | Not Applicable |
|
||||
| If personal collected for the development of the model by NVIDIA, do you maintain or have access to disclosures made to data subjects? | Not Applicable |
|
||||
| If personal collected for the development of this AI model, was it minimized to only what was required? | Not Applicable |
|
||||
| Is data in dataset traceable? | Yes |
|
||||
| Is there provenance for all datasets used in training? | Yes |
|
||||
| Does data labeling (annotation, metadata) comply with privacy laws? | Yes |
|
||||
| Is data compliant with data subject requests for data correction or removal, if such a request was made? | No, not possible with externally-sourced data. |
|
@ -1,6 +0,0 @@
|
||||
| Field | Response |
|
||||
| :---------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Model Application(s): | Synethic Audio Generation |
|
||||
| Describe the life critical impact (if present). | Not Applicable |
|
||||
| Use Case Restrictions: | None |
|
||||
| Model and dataset restrictions: | The Principle of least privilege (PoLP) is applied limiting access for dataset generation and model development. Restrictions enforce dataset access during training, and dataset license constraints adhered to. |
|
@ -1,13 +0,0 @@
|
||||
torch
|
||||
numpy
|
||||
librosa>=0.8.1
|
||||
scipy
|
||||
tensorboard
|
||||
soundfile
|
||||
matplotlib
|
||||
pesq
|
||||
auraloss
|
||||
tqdm
|
||||
nnAudio
|
||||
ninja
|
||||
huggingface_hub>=0.23.4
|
@ -1,62 +0,0 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# to import modules from parent_dir
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
import torch
|
||||
from alias_free_activation.cuda import activation1d
|
||||
from activations import Snake
|
||||
|
||||
|
||||
def test_load_fused_kernels():
|
||||
try:
|
||||
print("[Success] load_fused_kernels")
|
||||
except ImportError as e:
|
||||
print("[Fail] load_fused_kernels")
|
||||
raise e
|
||||
|
||||
|
||||
def test_anti_alias_activation():
|
||||
data = torch.rand((10, 10, 200), device="cuda")
|
||||
|
||||
# Check activations.Snake cuda vs. torch
|
||||
fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
|
||||
fused_activation_output = fused_anti_alias_activation(data)
|
||||
|
||||
torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
|
||||
torch_activation_output = torch_anti_alias_activation(data)
|
||||
|
||||
test_result = (fused_activation_output - torch_activation_output).abs()
|
||||
|
||||
while test_result.dim() != 1:
|
||||
test_result = test_result.mean(dim=-1)
|
||||
|
||||
diff = test_result.mean(dim=-1)
|
||||
|
||||
if diff <= 1e-3:
|
||||
print(
|
||||
f"\n[Success] test_fused_anti_alias_activation"
|
||||
f"\n > mean_difference={diff}"
|
||||
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}"
|
||||
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"\n[Fail] test_fused_anti_alias_activation"
|
||||
f"\n > mean_difference={diff}, "
|
||||
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}, "
|
||||
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from alias_free_activation.cuda import load
|
||||
|
||||
load.load()
|
||||
test_load_fused_kernels()
|
||||
test_anti_alias_activation()
|
@ -1,62 +0,0 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# to import modules from parent_dir
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
import torch
|
||||
from alias_free_activation.cuda import activation1d
|
||||
from activations import SnakeBeta
|
||||
|
||||
|
||||
def test_load_fused_kernels():
|
||||
try:
|
||||
print("[Success] load_fused_kernels")
|
||||
except ImportError as e:
|
||||
print("[Fail] load_fused_kernels")
|
||||
raise e
|
||||
|
||||
|
||||
def test_anti_alias_activation():
|
||||
data = torch.rand((10, 10, 200), device="cuda")
|
||||
|
||||
# Check activations, Snake CUDA vs. Torch
|
||||
fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
|
||||
fused_activation_output = fused_anti_alias_activation(data)
|
||||
|
||||
torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
|
||||
torch_activation_output = torch_anti_alias_activation(data)
|
||||
|
||||
test_result = (fused_activation_output - torch_activation_output).abs()
|
||||
|
||||
while test_result.dim() != 1:
|
||||
test_result = test_result.mean(dim=-1)
|
||||
|
||||
diff = test_result.mean(dim=-1)
|
||||
|
||||
if diff <= 1e-3:
|
||||
print(
|
||||
f"\n[Success] test_fused_anti_alias_activation"
|
||||
f"\n > mean_difference={diff}"
|
||||
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}"
|
||||
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"\n[Fail] test_fused_anti_alias_activation"
|
||||
f"\n > mean_difference={diff}, "
|
||||
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}, "
|
||||
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from alias_free_activation.cuda import load
|
||||
|
||||
load.load()
|
||||
test_load_fused_kernels()
|
||||
test_anti_alias_activation()
|
@ -1,215 +0,0 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# to import modules from parent_dir
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
import torch
|
||||
import json
|
||||
from env import AttrDict
|
||||
from bigvgan import BigVGAN
|
||||
from time import time
|
||||
from tqdm import tqdm
|
||||
from meldataset import mel_spectrogram, MAX_WAV_VALUE
|
||||
from scipy.io.wavfile import write
|
||||
import numpy as np
|
||||
|
||||
import argparse
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
# For easier debugging
|
||||
torch.set_printoptions(linewidth=200, threshold=10_000)
|
||||
|
||||
|
||||
def generate_soundwave(duration=5.0, sr=24000):
|
||||
t = np.linspace(0, duration, int(sr * duration), False, dtype=np.float32)
|
||||
|
||||
modulation = np.sin(2 * np.pi * t / duration)
|
||||
|
||||
min_freq = 220
|
||||
max_freq = 1760
|
||||
frequencies = min_freq + (max_freq - min_freq) * (modulation + 1) / 2
|
||||
soundwave = np.sin(2 * np.pi * frequencies * t)
|
||||
|
||||
soundwave = soundwave / np.max(np.abs(soundwave)) * 0.95
|
||||
|
||||
return soundwave, sr
|
||||
|
||||
|
||||
def get_mel(x, h):
|
||||
return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
|
||||
|
||||
|
||||
def load_checkpoint(filepath, device):
|
||||
assert os.path.isfile(filepath)
|
||||
print(f"Loading '{filepath}'")
|
||||
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||
print("Complete.")
|
||||
return checkpoint_dict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
|
||||
parser.add_argument(
|
||||
"--checkpoint_file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint file. Assumes config.json exists in the directory.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config_file = os.path.join(os.path.split(args.checkpoint_file)[0], "config.json")
|
||||
with open(config_file) as f:
|
||||
config = f.read()
|
||||
json_config = json.loads(config)
|
||||
h = AttrDict({**json_config})
|
||||
|
||||
print("loading plain Pytorch BigVGAN")
|
||||
generator_original = BigVGAN(h).to("cuda")
|
||||
print("loading CUDA kernel BigVGAN with auto-build")
|
||||
generator_cuda_kernel = BigVGAN(h, use_cuda_kernel=True).to("cuda")
|
||||
|
||||
state_dict_g = load_checkpoint(args.checkpoint_file, "cuda")
|
||||
generator_original.load_state_dict(state_dict_g["generator"])
|
||||
generator_cuda_kernel.load_state_dict(state_dict_g["generator"])
|
||||
|
||||
generator_original.remove_weight_norm()
|
||||
generator_original.eval()
|
||||
generator_cuda_kernel.remove_weight_norm()
|
||||
generator_cuda_kernel.eval()
|
||||
|
||||
# define number of samples and length of mel frame to benchmark
|
||||
num_sample = 10
|
||||
num_mel_frame = 16384
|
||||
|
||||
# CUDA kernel correctness check
|
||||
diff = 0.0
|
||||
for i in tqdm(range(num_sample)):
|
||||
# Random mel
|
||||
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
|
||||
|
||||
with torch.inference_mode():
|
||||
audio_original = generator_original(data)
|
||||
|
||||
with torch.inference_mode():
|
||||
audio_cuda_kernel = generator_cuda_kernel(data)
|
||||
|
||||
# Both outputs should be (almost) the same
|
||||
test_result = (audio_original - audio_cuda_kernel).abs()
|
||||
diff += test_result.mean(dim=-1).item()
|
||||
|
||||
diff /= num_sample
|
||||
if diff <= 2e-3: # We can expect a small difference (~1e-3) which does not affect perceptual quality
|
||||
print(
|
||||
f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
|
||||
f"\n > mean_difference={diff}"
|
||||
f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}"
|
||||
f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"\n[Fail] test CUDA fused vs. plain torch BigVGAN inference"
|
||||
f"\n > mean_difference={diff}"
|
||||
f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, "
|
||||
f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
|
||||
)
|
||||
|
||||
del data, audio_original, audio_cuda_kernel
|
||||
|
||||
# Variables for tracking total time and VRAM usage
|
||||
toc_total_original = 0
|
||||
toc_total_cuda_kernel = 0
|
||||
vram_used_original_total = 0
|
||||
vram_used_cuda_kernel_total = 0
|
||||
audio_length_total = 0
|
||||
|
||||
# Measure Original inference in isolation
|
||||
for i in tqdm(range(num_sample)):
|
||||
torch.cuda.reset_peak_memory_stats(device="cuda")
|
||||
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
|
||||
torch.cuda.synchronize()
|
||||
tic = time()
|
||||
with torch.inference_mode():
|
||||
audio_original = generator_original(data)
|
||||
torch.cuda.synchronize()
|
||||
toc = time() - tic
|
||||
toc_total_original += toc
|
||||
|
||||
vram_used_original_total += torch.cuda.max_memory_allocated(device="cuda")
|
||||
|
||||
del data, audio_original
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Measure CUDA kernel inference in isolation
|
||||
for i in tqdm(range(num_sample)):
|
||||
torch.cuda.reset_peak_memory_stats(device="cuda")
|
||||
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
|
||||
torch.cuda.synchronize()
|
||||
tic = time()
|
||||
with torch.inference_mode():
|
||||
audio_cuda_kernel = generator_cuda_kernel(data)
|
||||
torch.cuda.synchronize()
|
||||
toc = time() - tic
|
||||
toc_total_cuda_kernel += toc
|
||||
|
||||
audio_length_total += audio_cuda_kernel.shape[-1]
|
||||
|
||||
vram_used_cuda_kernel_total += torch.cuda.max_memory_allocated(device="cuda")
|
||||
|
||||
del data, audio_cuda_kernel
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Calculate metrics
|
||||
audio_second = audio_length_total / h.sampling_rate
|
||||
khz_original = audio_length_total / toc_total_original / 1000
|
||||
khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
|
||||
vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
|
||||
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)
|
||||
|
||||
# Print results
|
||||
print(
|
||||
f"Original BigVGAN: took {toc_total_original:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_original:.1f}kHz, {audio_second / toc_total_original:.1f} faster than realtime, VRAM used {vram_used_original_gb:.1f} GB"
|
||||
)
|
||||
print(
|
||||
f"CUDA kernel BigVGAN: took {toc_total_cuda_kernel:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_cuda_kernel:.1f}kHz, {audio_second / toc_total_cuda_kernel:.1f} faster than realtime, VRAM used {vram_used_cuda_kernel_gb:.1f} GB"
|
||||
)
|
||||
print(f"speedup of CUDA kernel: {khz_cuda_kernel / khz_original}")
|
||||
print(f"VRAM saving of CUDA kernel: {vram_used_original_gb / vram_used_cuda_kernel_gb}")
|
||||
|
||||
# Use artificial sine waves for inference test
|
||||
audio_real, sr = generate_soundwave(duration=5.0, sr=h.sampling_rate)
|
||||
audio_real = torch.tensor(audio_real).to("cuda")
|
||||
# Compute mel spectrogram from the ground truth audio
|
||||
x = get_mel(audio_real.unsqueeze(0), h)
|
||||
|
||||
with torch.inference_mode():
|
||||
y_g_hat_original = generator_original(x)
|
||||
y_g_hat_cuda_kernel = generator_cuda_kernel(x)
|
||||
|
||||
audio_real = audio_real.squeeze()
|
||||
audio_real = audio_real * MAX_WAV_VALUE
|
||||
audio_real = audio_real.cpu().numpy().astype("int16")
|
||||
|
||||
audio_original = y_g_hat_original.squeeze()
|
||||
audio_original = audio_original * MAX_WAV_VALUE
|
||||
audio_original = audio_original.cpu().numpy().astype("int16")
|
||||
|
||||
audio_cuda_kernel = y_g_hat_cuda_kernel.squeeze()
|
||||
audio_cuda_kernel = audio_cuda_kernel * MAX_WAV_VALUE
|
||||
audio_cuda_kernel = audio_cuda_kernel.cpu().numpy().astype("int16")
|
||||
|
||||
os.makedirs("tmp", exist_ok=True)
|
||||
output_file_real = os.path.join("tmp", "audio_real.wav")
|
||||
output_file_original = os.path.join("tmp", "audio_generated_original.wav")
|
||||
output_file_cuda_kernel = os.path.join("tmp", "audio_generated_cuda_kernel.wav")
|
||||
write(output_file_real, h.sampling_rate, audio_real)
|
||||
write(output_file_original, h.sampling_rate, audio_original)
|
||||
write(output_file_cuda_kernel, h.sampling_rate, audio_cuda_kernel)
|
||||
print("Example generated audios of original vs. fused CUDA kernel written to tmp!")
|
||||
print("Done")
|
@ -1,716 +0,0 @@
|
||||
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
import itertools
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import json
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torch.utils.data import DistributedSampler, DataLoader
|
||||
import torch.multiprocessing as mp
|
||||
from torch.distributed import init_process_group
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from env import AttrDict, build_env
|
||||
from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist, MAX_WAV_VALUE
|
||||
|
||||
from bigvgan import BigVGAN
|
||||
from discriminators import (
|
||||
MultiPeriodDiscriminator,
|
||||
MultiResolutionDiscriminator,
|
||||
MultiBandDiscriminator,
|
||||
MultiScaleSubbandCQTDiscriminator,
|
||||
)
|
||||
from loss import (
|
||||
feature_loss,
|
||||
generator_loss,
|
||||
discriminator_loss,
|
||||
MultiScaleMelSpectrogramLoss,
|
||||
)
|
||||
|
||||
from utils import (
|
||||
plot_spectrogram,
|
||||
plot_spectrogram_clipped,
|
||||
scan_checkpoint,
|
||||
load_checkpoint,
|
||||
save_checkpoint,
|
||||
save_audio,
|
||||
)
|
||||
import torchaudio as ta
|
||||
from pesq import pesq
|
||||
from tqdm import tqdm
|
||||
import auraloss
|
||||
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def train(rank, a, h):
|
||||
if h.num_gpus > 1:
|
||||
# initialize distributed
|
||||
init_process_group(
|
||||
backend=h.dist_config["dist_backend"],
|
||||
init_method=h.dist_config["dist_url"],
|
||||
world_size=h.dist_config["world_size"] * h.num_gpus,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
# Set seed and device
|
||||
torch.cuda.manual_seed(h.seed)
|
||||
torch.cuda.set_device(rank)
|
||||
device = torch.device(f"cuda:{rank:d}")
|
||||
|
||||
# Define BigVGAN generator
|
||||
generator = BigVGAN(h).to(device)
|
||||
|
||||
# Define discriminators. MPD is used by default
|
||||
mpd = MultiPeriodDiscriminator(h).to(device)
|
||||
|
||||
# Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
|
||||
# New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
|
||||
if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD
|
||||
print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
|
||||
# Variable name is kept as "mrd" for backward compatibility & minimal code change
|
||||
mrd = MultiBandDiscriminator(h).to(device)
|
||||
elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD
|
||||
print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
|
||||
mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
|
||||
else: # Fallback to original MRD in BigVGAN-v1
|
||||
mrd = MultiResolutionDiscriminator(h).to(device)
|
||||
|
||||
# New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
|
||||
if h.get("use_multiscale_melloss", False):
|
||||
print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss")
|
||||
fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
|
||||
sampling_rate=h.sampling_rate
|
||||
) # NOTE: accepts waveform as input
|
||||
else:
|
||||
fn_mel_loss_singlescale = F.l1_loss
|
||||
|
||||
# Print the model & number of parameters, and create or scan the latest checkpoint from checkpoints directory
|
||||
if rank == 0:
|
||||
print(generator)
|
||||
print(mpd)
|
||||
print(mrd)
|
||||
print(f"Generator params: {sum(p.numel() for p in generator.parameters())}")
|
||||
print(f"Discriminator mpd params: {sum(p.numel() for p in mpd.parameters())}")
|
||||
print(f"Discriminator mrd params: {sum(p.numel() for p in mrd.parameters())}")
|
||||
os.makedirs(a.checkpoint_path, exist_ok=True)
|
||||
print(f"Checkpoints directory: {a.checkpoint_path}")
|
||||
|
||||
if os.path.isdir(a.checkpoint_path):
|
||||
# New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
|
||||
cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt")
|
||||
cp_do = scan_checkpoint(
|
||||
a.checkpoint_path,
|
||||
prefix="do_",
|
||||
renamed_file="bigvgan_discriminator_optimizer.pt",
|
||||
)
|
||||
|
||||
# Load the latest checkpoint if exists
|
||||
steps = 0
|
||||
if cp_g is None or cp_do is None:
|
||||
state_dict_do = None
|
||||
last_epoch = -1
|
||||
else:
|
||||
state_dict_g = load_checkpoint(cp_g, device)
|
||||
state_dict_do = load_checkpoint(cp_do, device)
|
||||
generator.load_state_dict(state_dict_g["generator"])
|
||||
mpd.load_state_dict(state_dict_do["mpd"])
|
||||
mrd.load_state_dict(state_dict_do["mrd"])
|
||||
steps = state_dict_do["steps"] + 1
|
||||
last_epoch = state_dict_do["epoch"]
|
||||
|
||||
# Initialize DDP, optimizers, and schedulers
|
||||
if h.num_gpus > 1:
|
||||
generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
|
||||
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
|
||||
mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)
|
||||
|
||||
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
||||
optim_d = torch.optim.AdamW(
|
||||
itertools.chain(mrd.parameters(), mpd.parameters()),
|
||||
h.learning_rate,
|
||||
betas=[h.adam_b1, h.adam_b2],
|
||||
)
|
||||
|
||||
if state_dict_do is not None:
|
||||
optim_g.load_state_dict(state_dict_do["optim_g"])
|
||||
optim_d.load_state_dict(state_dict_do["optim_d"])
|
||||
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
|
||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
|
||||
|
||||
# Define training and validation datasets
|
||||
|
||||
"""
|
||||
unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
|
||||
Example: trained on LibriTTS, validate on VCTK
|
||||
"""
|
||||
training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a)
|
||||
|
||||
trainset = MelDataset(
|
||||
training_filelist,
|
||||
h,
|
||||
h.segment_size,
|
||||
h.n_fft,
|
||||
h.num_mels,
|
||||
h.hop_size,
|
||||
h.win_size,
|
||||
h.sampling_rate,
|
||||
h.fmin,
|
||||
h.fmax,
|
||||
shuffle=False if h.num_gpus > 1 else True,
|
||||
fmax_loss=h.fmax_for_loss,
|
||||
device=device,
|
||||
fine_tuning=a.fine_tuning,
|
||||
base_mels_path=a.input_mels_dir,
|
||||
is_seen=True,
|
||||
)
|
||||
|
||||
train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
|
||||
|
||||
train_loader = DataLoader(
|
||||
trainset,
|
||||
num_workers=h.num_workers,
|
||||
shuffle=False,
|
||||
sampler=train_sampler,
|
||||
batch_size=h.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
if rank == 0:
|
||||
validset = MelDataset(
|
||||
validation_filelist,
|
||||
h,
|
||||
h.segment_size,
|
||||
h.n_fft,
|
||||
h.num_mels,
|
||||
h.hop_size,
|
||||
h.win_size,
|
||||
h.sampling_rate,
|
||||
h.fmin,
|
||||
h.fmax,
|
||||
False,
|
||||
False,
|
||||
fmax_loss=h.fmax_for_loss,
|
||||
device=device,
|
||||
fine_tuning=a.fine_tuning,
|
||||
base_mels_path=a.input_mels_dir,
|
||||
is_seen=True,
|
||||
)
|
||||
validation_loader = DataLoader(
|
||||
validset,
|
||||
num_workers=1,
|
||||
shuffle=False,
|
||||
sampler=None,
|
||||
batch_size=1,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
list_unseen_validset = []
|
||||
list_unseen_validation_loader = []
|
||||
for i in range(len(list_unseen_validation_filelist)):
|
||||
unseen_validset = MelDataset(
|
||||
list_unseen_validation_filelist[i],
|
||||
h,
|
||||
h.segment_size,
|
||||
h.n_fft,
|
||||
h.num_mels,
|
||||
h.hop_size,
|
||||
h.win_size,
|
||||
h.sampling_rate,
|
||||
h.fmin,
|
||||
h.fmax,
|
||||
False,
|
||||
False,
|
||||
fmax_loss=h.fmax_for_loss,
|
||||
device=device,
|
||||
fine_tuning=a.fine_tuning,
|
||||
base_mels_path=a.input_mels_dir,
|
||||
is_seen=False,
|
||||
)
|
||||
unseen_validation_loader = DataLoader(
|
||||
unseen_validset,
|
||||
num_workers=1,
|
||||
shuffle=False,
|
||||
sampler=None,
|
||||
batch_size=1,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
)
|
||||
list_unseen_validset.append(unseen_validset)
|
||||
list_unseen_validation_loader.append(unseen_validation_loader)
|
||||
|
||||
# Tensorboard logger
|
||||
sw = SummaryWriter(os.path.join(a.checkpoint_path, "logs"))
|
||||
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||
os.makedirs(os.path.join(a.checkpoint_path, "samples"), exist_ok=True)
|
||||
|
||||
"""
|
||||
Validation loop, "mode" parameter is automatically defined as (seen or unseen)_(name of the dataset).
|
||||
If the name of the dataset contains "nonspeech", it skips PESQ calculation to prevent errors
|
||||
"""
|
||||
|
||||
def validate(rank, a, h, loader, mode="seen"):
|
||||
assert rank == 0, "validate should only run on rank=0"
|
||||
generator.eval()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
val_err_tot = 0
|
||||
val_pesq_tot = 0
|
||||
val_mrstft_tot = 0
|
||||
|
||||
# Modules for evaluation metrics
|
||||
pesq_resampler = ta.transforms.Resample(h.sampling_rate, 16000).cuda()
|
||||
loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda")
|
||||
|
||||
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||
os.makedirs(
|
||||
os.path.join(a.checkpoint_path, "samples", f"gt_{mode}"),
|
||||
exist_ok=True,
|
||||
)
|
||||
os.makedirs(
|
||||
os.path.join(a.checkpoint_path, "samples", f"{mode}_{steps:08d}"),
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
print(f"step {steps} {mode} speaker validation...")
|
||||
|
||||
# Loop over validation set and compute metrics
|
||||
for j, batch in enumerate(tqdm(loader)):
|
||||
x, y, _, y_mel = batch
|
||||
y = y.to(device)
|
||||
if hasattr(generator, "module"):
|
||||
y_g_hat = generator.module(x.to(device))
|
||||
else:
|
||||
y_g_hat = generator(x.to(device))
|
||||
y_mel = y_mel.to(device, non_blocking=True)
|
||||
y_g_hat_mel = mel_spectrogram(
|
||||
y_g_hat.squeeze(1),
|
||||
h.n_fft,
|
||||
h.num_mels,
|
||||
h.sampling_rate,
|
||||
h.hop_size,
|
||||
h.win_size,
|
||||
h.fmin,
|
||||
h.fmax_for_loss,
|
||||
)
|
||||
min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
|
||||
val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item()
|
||||
|
||||
# PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
|
||||
if "nonspeech" not in mode: # Skips if the name of dataset (in mode string) contains "nonspeech"
|
||||
# Resample to 16000 for pesq
|
||||
y_16k = pesq_resampler(y)
|
||||
y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
|
||||
y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||
y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||
val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")
|
||||
|
||||
# MRSTFT calculation
|
||||
min_t = min(y.size(-1), y_g_hat.size(-1))
|
||||
val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item()
|
||||
|
||||
# Log audio and figures to Tensorboard
|
||||
if j % a.eval_subsample == 0: # Subsample every nth from validation set
|
||||
if steps >= 0:
|
||||
sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
|
||||
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||
save_audio(
|
||||
y[0],
|
||||
os.path.join(
|
||||
a.checkpoint_path,
|
||||
"samples",
|
||||
f"gt_{mode}",
|
||||
f"{j:04d}.wav",
|
||||
),
|
||||
h.sampling_rate,
|
||||
)
|
||||
sw.add_figure(
|
||||
f"gt_{mode}/y_spec_{j}",
|
||||
plot_spectrogram(x[0]),
|
||||
steps,
|
||||
)
|
||||
|
||||
sw.add_audio(
|
||||
f"generated_{mode}/y_hat_{j}",
|
||||
y_g_hat[0],
|
||||
steps,
|
||||
h.sampling_rate,
|
||||
)
|
||||
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||
save_audio(
|
||||
y_g_hat[0, 0],
|
||||
os.path.join(
|
||||
a.checkpoint_path,
|
||||
"samples",
|
||||
f"{mode}_{steps:08d}",
|
||||
f"{j:04d}.wav",
|
||||
),
|
||||
h.sampling_rate,
|
||||
)
|
||||
# Spectrogram of synthesized audio
|
||||
y_hat_spec = mel_spectrogram(
|
||||
y_g_hat.squeeze(1),
|
||||
h.n_fft,
|
||||
h.num_mels,
|
||||
h.sampling_rate,
|
||||
h.hop_size,
|
||||
h.win_size,
|
||||
h.fmin,
|
||||
h.fmax,
|
||||
)
|
||||
sw.add_figure(
|
||||
f"generated_{mode}/y_hat_spec_{j}",
|
||||
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()),
|
||||
steps,
|
||||
)
|
||||
|
||||
"""
|
||||
Visualization of spectrogram difference between GT and synthesized audio, difference higher than 1 is clipped for better visualization.
|
||||
"""
|
||||
spec_delta = torch.clamp(
|
||||
torch.abs(x[0] - y_hat_spec.squeeze(0).cpu()),
|
||||
min=1e-6,
|
||||
max=1.0,
|
||||
)
|
||||
sw.add_figure(
|
||||
f"delta_dclip1_{mode}/spec_{j}",
|
||||
plot_spectrogram_clipped(spec_delta.numpy(), clip_max=1.0),
|
||||
steps,
|
||||
)
|
||||
|
||||
val_err = val_err_tot / (j + 1)
|
||||
val_pesq = val_pesq_tot / (j + 1)
|
||||
val_mrstft = val_mrstft_tot / (j + 1)
|
||||
# Log evaluation metrics to Tensorboard
|
||||
sw.add_scalar(f"validation_{mode}/mel_spec_error", val_err, steps)
|
||||
sw.add_scalar(f"validation_{mode}/pesq", val_pesq, steps)
|
||||
sw.add_scalar(f"validation_{mode}/mrstft", val_mrstft, steps)
|
||||
|
||||
generator.train()
|
||||
|
||||
# If the checkpoint is loaded, start with validation loop
|
||||
if steps != 0 and rank == 0 and not a.debug:
|
||||
if not a.skip_seen:
|
||||
validate(
|
||||
rank,
|
||||
a,
|
||||
h,
|
||||
validation_loader,
|
||||
mode=f"seen_{train_loader.dataset.name}",
|
||||
)
|
||||
for i in range(len(list_unseen_validation_loader)):
|
||||
validate(
|
||||
rank,
|
||||
a,
|
||||
h,
|
||||
list_unseen_validation_loader[i],
|
||||
mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}",
|
||||
)
|
||||
# Exit the script if --evaluate is set to True
|
||||
if a.evaluate:
|
||||
exit()
|
||||
|
||||
# Main training loop
|
||||
generator.train()
|
||||
mpd.train()
|
||||
mrd.train()
|
||||
for epoch in range(max(0, last_epoch), a.training_epochs):
|
||||
if rank == 0:
|
||||
start = time.time()
|
||||
print(f"Epoch: {epoch + 1}")
|
||||
|
||||
if h.num_gpus > 1:
|
||||
train_sampler.set_epoch(epoch)
|
||||
|
||||
for i, batch in enumerate(train_loader):
|
||||
if rank == 0:
|
||||
start_b = time.time()
|
||||
x, y, _, y_mel = batch
|
||||
|
||||
x = x.to(device, non_blocking=True)
|
||||
y = y.to(device, non_blocking=True)
|
||||
y_mel = y_mel.to(device, non_blocking=True)
|
||||
y = y.unsqueeze(1)
|
||||
|
||||
y_g_hat = generator(x)
|
||||
y_g_hat_mel = mel_spectrogram(
|
||||
y_g_hat.squeeze(1),
|
||||
h.n_fft,
|
||||
h.num_mels,
|
||||
h.sampling_rate,
|
||||
h.hop_size,
|
||||
h.win_size,
|
||||
h.fmin,
|
||||
h.fmax_for_loss,
|
||||
)
|
||||
|
||||
optim_d.zero_grad()
|
||||
|
||||
# MPD
|
||||
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
||||
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
||||
|
||||
# MRD
|
||||
y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
|
||||
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
||||
|
||||
loss_disc_all = loss_disc_s + loss_disc_f
|
||||
|
||||
# Set clip_grad_norm value
|
||||
clip_grad_norm = h.get("clip_grad_norm", 1000.0) # Default to 1000
|
||||
|
||||
# Whether to freeze D for initial training steps
|
||||
if steps >= a.freeze_step:
|
||||
loss_disc_all.backward()
|
||||
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm)
|
||||
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm)
|
||||
optim_d.step()
|
||||
else:
|
||||
print(f"[WARNING] skipping D training for the first {a.freeze_step} steps")
|
||||
grad_norm_mpd = 0.0
|
||||
grad_norm_mrd = 0.0
|
||||
|
||||
# Generator
|
||||
optim_g.zero_grad()
|
||||
|
||||
# L1 Mel-Spectrogram Loss
|
||||
lambda_melloss = h.get("lambda_melloss", 45.0) # Defaults to 45 in BigVGAN-v1 if not set
|
||||
if h.get("use_multiscale_melloss", False): # uses wav <y, y_g_hat> for loss
|
||||
loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
|
||||
else: # Uses mel <y_mel, y_g_hat_mel> for loss
|
||||
loss_mel = fn_mel_loss_singlescale(y_mel, y_g_hat_mel) * lambda_melloss
|
||||
|
||||
# MPD loss
|
||||
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
||||
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
||||
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
||||
|
||||
# MRD loss
|
||||
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(y, y_g_hat)
|
||||
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
||||
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
||||
|
||||
if steps >= a.freeze_step:
|
||||
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
||||
else:
|
||||
print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps")
|
||||
loss_gen_all = loss_mel
|
||||
|
||||
loss_gen_all.backward()
|
||||
grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm)
|
||||
optim_g.step()
|
||||
|
||||
if rank == 0:
|
||||
# STDOUT logging
|
||||
if steps % a.stdout_interval == 0:
|
||||
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to stdout
|
||||
print(
|
||||
f"Steps: {steps:d}, "
|
||||
f"Gen Loss Total: {loss_gen_all:4.3f}, "
|
||||
f"Mel Error: {mel_error:4.3f}, "
|
||||
f"s/b: {time.time() - start_b:4.3f} "
|
||||
f"lr: {optim_g.param_groups[0]['lr']:4.7f} "
|
||||
f"grad_norm_g: {grad_norm_g:4.3f}"
|
||||
)
|
||||
|
||||
# Checkpointing
|
||||
if steps % a.checkpoint_interval == 0 and steps != 0:
|
||||
checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
|
||||
save_checkpoint(
|
||||
checkpoint_path,
|
||||
{"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()},
|
||||
)
|
||||
checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
|
||||
save_checkpoint(
|
||||
checkpoint_path,
|
||||
{
|
||||
"mpd": (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
|
||||
"mrd": (mrd.module if h.num_gpus > 1 else mrd).state_dict(),
|
||||
"optim_g": optim_g.state_dict(),
|
||||
"optim_d": optim_d.state_dict(),
|
||||
"steps": steps,
|
||||
"epoch": epoch,
|
||||
},
|
||||
)
|
||||
|
||||
# Tensorboard summary logging
|
||||
if steps % a.summary_interval == 0:
|
||||
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to tensorboard
|
||||
sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
|
||||
sw.add_scalar("training/mel_spec_error", mel_error, steps)
|
||||
sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
|
||||
sw.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps)
|
||||
sw.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps)
|
||||
sw.add_scalar("training/grad_norm_mpd", grad_norm_mpd, steps)
|
||||
sw.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps)
|
||||
sw.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps)
|
||||
sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
|
||||
sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
|
||||
sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
|
||||
sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
|
||||
sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
|
||||
sw.add_scalar("training/epoch", epoch + 1, steps)
|
||||
|
||||
# Validation
|
||||
if steps % a.validation_interval == 0:
|
||||
# Plot training input x so far used
|
||||
for i_x in range(x.shape[0]):
|
||||
sw.add_figure(
|
||||
f"training_input/x_{i_x}",
|
||||
plot_spectrogram(x[i_x].cpu()),
|
||||
steps,
|
||||
)
|
||||
sw.add_audio(
|
||||
f"training_input/y_{i_x}",
|
||||
y[i_x][0],
|
||||
steps,
|
||||
h.sampling_rate,
|
||||
)
|
||||
|
||||
# Seen and unseen speakers validation loops
|
||||
if not a.debug and steps != 0:
|
||||
validate(
|
||||
rank,
|
||||
a,
|
||||
h,
|
||||
validation_loader,
|
||||
mode=f"seen_{train_loader.dataset.name}",
|
||||
)
|
||||
for i in range(len(list_unseen_validation_loader)):
|
||||
validate(
|
||||
rank,
|
||||
a,
|
||||
h,
|
||||
list_unseen_validation_loader[i],
|
||||
mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}",
|
||||
)
|
||||
steps += 1
|
||||
|
||||
# BigVGAN-v2 learning rate scheduler is changed from epoch-level to step-level
|
||||
scheduler_g.step()
|
||||
scheduler_d.step()
|
||||
|
||||
if rank == 0:
|
||||
print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n")
|
||||
|
||||
|
||||
def main():
|
||||
print("Initializing Training Process..")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--group_name", default=None)
|
||||
|
||||
parser.add_argument("--input_wavs_dir", default="LibriTTS")
|
||||
parser.add_argument("--input_mels_dir", default="ft_dataset")
|
||||
parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt")
|
||||
parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt")
|
||||
|
||||
parser.add_argument(
|
||||
"--list_input_unseen_wavs_dir",
|
||||
nargs="+",
|
||||
default=["tests/LibriTTS", "tests/LibriTTS"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list_input_unseen_validation_file",
|
||||
nargs="+",
|
||||
default=["tests/LibriTTS/dev-clean.txt", "tests/LibriTTS/dev-other.txt"],
|
||||
)
|
||||
|
||||
parser.add_argument("--checkpoint_path", default="exp/bigvgan")
|
||||
parser.add_argument("--config", default="")
|
||||
|
||||
parser.add_argument("--training_epochs", default=100000, type=int)
|
||||
parser.add_argument("--stdout_interval", default=5, type=int)
|
||||
parser.add_argument("--checkpoint_interval", default=50000, type=int)
|
||||
parser.add_argument("--summary_interval", default=100, type=int)
|
||||
parser.add_argument("--validation_interval", default=50000, type=int)
|
||||
|
||||
parser.add_argument(
|
||||
"--freeze_step",
|
||||
default=0,
|
||||
type=int,
|
||||
help="freeze D for the first specified steps. G only uses regression loss for these steps.",
|
||||
)
|
||||
|
||||
parser.add_argument("--fine_tuning", default=False, type=bool)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="debug mode. skips validation loop throughout training",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--evaluate",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="only run evaluation from checkpoint and exit",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_subsample",
|
||||
default=5,
|
||||
type=int,
|
||||
help="subsampling during evaluation loop",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_seen",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="skip seen dataset. useful for test set inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_audio",
|
||||
default=False,
|
||||
type=bool,
|
||||
help="save audio of test set inference to disk",
|
||||
)
|
||||
|
||||
a = parser.parse_args()
|
||||
|
||||
with open(a.config) as f:
|
||||
data = f.read()
|
||||
|
||||
json_config = json.loads(data)
|
||||
h = AttrDict(json_config)
|
||||
|
||||
build_env(a.config, "config.json", a.checkpoint_path)
|
||||
|
||||
torch.manual_seed(h.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(h.seed)
|
||||
h.num_gpus = torch.cuda.device_count()
|
||||
h.batch_size = int(h.batch_size / h.num_gpus)
|
||||
print(f"Batch size per GPU: {h.batch_size}")
|
||||
else:
|
||||
pass
|
||||
|
||||
if h.num_gpus > 1:
|
||||
mp.spawn(
|
||||
train,
|
||||
nprocs=h.num_gpus,
|
||||
args=(
|
||||
a,
|
||||
h,
|
||||
),
|
||||
)
|
||||
else:
|
||||
train(0, a, h)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -3,44 +3,10 @@
|
||||
|
||||
import glob
|
||||
import os
|
||||
import matplotlib
|
||||
|
||||
import torch
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pylab as plt
|
||||
from .meldataset import MAX_WAV_VALUE
|
||||
from scipy.io.wavfile import write
|
||||
|
||||
|
||||
def plot_spectrogram(spectrogram):
|
||||
fig, ax = plt.subplots(figsize=(10, 2))
|
||||
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
||||
plt.colorbar(im, ax=ax)
|
||||
|
||||
fig.canvas.draw()
|
||||
plt.close()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
|
||||
fig, ax = plt.subplots(figsize=(10, 2))
|
||||
im = ax.imshow(
|
||||
spectrogram,
|
||||
aspect="auto",
|
||||
origin="lower",
|
||||
interpolation="none",
|
||||
vmin=1e-6,
|
||||
vmax=clip_max,
|
||||
)
|
||||
plt.colorbar(im, ax=ax)
|
||||
|
||||
fig.canvas.draw()
|
||||
plt.close()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
@ -90,10 +56,3 @@ def scan_checkpoint(cp_dir, prefix, renamed_file=None):
|
||||
return renamed_path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def save_audio(audio, path, sr):
|
||||
# wav: torch with 1d shape
|
||||
audio = audio * MAX_WAV_VALUE
|
||||
audio = audio.cpu().numpy().astype("int16")
|
||||
write(path, sr, audio)
|
||||
|
@ -2,17 +2,9 @@ import gc
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
|
||||
import torchaudio
|
||||
from tqdm import tqdm
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
import os
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import ffmpeg
|
||||
@ -20,21 +12,26 @@ import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
import yaml
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from BigVGAN.bigvgan import BigVGAN
|
||||
from feature_extractor.cnhubert import CNHubert
|
||||
from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
|
||||
from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from GPT_SoVITS.BigVGAN.bigvgan import BigVGAN
|
||||
from GPT_SoVITS.feature_extractor.cnhubert import CNHubert
|
||||
from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch
|
||||
from GPT_SoVITS.module.models import Generator, SynthesizerTrn, SynthesizerTrnV3
|
||||
from GPT_SoVITS.process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
|
||||
from GPT_SoVITS.sv import SV
|
||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import splits
|
||||
from GPT_SoVITS.TTS_infer_pack.TextPreprocessor import TextPreprocessor
|
||||
from tools.audio_sr import AP_BWE
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
from TTS_infer_pack.text_segmentation_method import splits
|
||||
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
|
||||
from sv import SV
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from tools.my_utils import DictToAttrRecursive
|
||||
|
||||
now_dir = os.getcwd()
|
||||
|
||||
resample_transform_dict = {}
|
||||
|
||||
@ -48,7 +45,6 @@ def resample(audio_tensor, sr0, sr1, device):
|
||||
|
||||
|
||||
language = os.environ.get("language", "Auto")
|
||||
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||
i18n = I18nAuto(language=language)
|
||||
|
||||
|
||||
@ -64,33 +60,32 @@ def denorm_spec(x):
|
||||
return (x + 1) / 2 * (spec_max - spec_min) + spec_min
|
||||
|
||||
|
||||
mel_fn = lambda x: mel_spectrogram_torch(
|
||||
x,
|
||||
**{
|
||||
"n_fft": 1024,
|
||||
"win_size": 1024,
|
||||
"hop_size": 256,
|
||||
"num_mels": 100,
|
||||
"sampling_rate": 24000,
|
||||
"fmin": 0,
|
||||
"fmax": None,
|
||||
"center": False,
|
||||
},
|
||||
)
|
||||
def mel_fn(x):
|
||||
return mel_spectrogram_torch(
|
||||
y=x,
|
||||
n_fft=1024,
|
||||
num_mels=100,
|
||||
sampling_rate=24000,
|
||||
hop_size=256,
|
||||
win_size=1024,
|
||||
fmin=0,
|
||||
fmax=None,
|
||||
center=False,
|
||||
)
|
||||
|
||||
mel_fn_v4 = lambda x: mel_spectrogram_torch(
|
||||
x,
|
||||
**{
|
||||
"n_fft": 1280,
|
||||
"win_size": 1280,
|
||||
"hop_size": 320,
|
||||
"num_mels": 100,
|
||||
"sampling_rate": 32000,
|
||||
"fmin": 0,
|
||||
"fmax": None,
|
||||
"center": False,
|
||||
},
|
||||
)
|
||||
|
||||
def mel_fn_v4(x):
|
||||
return mel_spectrogram_torch(
|
||||
y=x,
|
||||
n_fft=1280,
|
||||
num_mels=100,
|
||||
sampling_rate=32000,
|
||||
hop_size=320,
|
||||
win_size=1280,
|
||||
fmin=0,
|
||||
fmax=None,
|
||||
center=False,
|
||||
)
|
||||
|
||||
|
||||
def speed_change(input_audio: np.ndarray, speed: float, sr: int):
|
||||
@ -114,34 +109,6 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int):
|
||||
return processed_audio
|
||||
|
||||
|
||||
class DictToAttrRecursive(dict):
|
||||
def __init__(self, input_dict):
|
||||
super().__init__(input_dict)
|
||||
for key, value in input_dict.items():
|
||||
if isinstance(value, dict):
|
||||
value = DictToAttrRecursive(value)
|
||||
self[key] = value
|
||||
setattr(self, key, value)
|
||||
|
||||
def __getattr__(self, item):
|
||||
try:
|
||||
return self[item]
|
||||
except KeyError:
|
||||
raise AttributeError(f"Attribute {item} not found")
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if isinstance(value, dict):
|
||||
value = DictToAttrRecursive(value)
|
||||
super(DictToAttrRecursive, self).__setitem__(key, value)
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __delattr__(self, item):
|
||||
try:
|
||||
del self[item]
|
||||
except KeyError:
|
||||
raise AttributeError(f"Attribute {item} not found")
|
||||
|
||||
|
||||
class NO_PROMPT_ERROR(Exception):
|
||||
pass
|
||||
|
||||
@ -316,7 +283,7 @@ class TTS_Config:
|
||||
|
||||
self.is_half = self.configs.get("is_half", False)
|
||||
if str(self.device) == "cpu" and self.is_half:
|
||||
print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
|
||||
print("Warning: Half precision is not supported on CPU, set is_half to False.")
|
||||
self.is_half = False
|
||||
|
||||
version = self.configs.get("version", None)
|
||||
@ -488,7 +455,7 @@ class TTS:
|
||||
self.init_sv_model()
|
||||
path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
|
||||
|
||||
if if_lora_v3 == True and os.path.exists(path_sovits) == False:
|
||||
if if_lora_v3 is True and os.path.exists(path_sovits) is False:
|
||||
info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
|
||||
raise FileExistsError(info)
|
||||
|
||||
@ -549,7 +516,7 @@ class TTS:
|
||||
|
||||
self.is_v2pro = model_version in {"v2Pro", "v2ProPlus"}
|
||||
|
||||
if if_lora_v3 == False:
|
||||
if if_lora_v3 is False:
|
||||
print(
|
||||
f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
|
||||
)
|
||||
@ -580,8 +547,6 @@ class TTS:
|
||||
|
||||
self.configs.save_configs()
|
||||
|
||||
|
||||
|
||||
def init_t2s_weights(self, weights_path: str):
|
||||
print(f"Loading Text2Semantic weights from {weights_path}")
|
||||
self.configs.t2s_weights_path = weights_path
|
||||
@ -654,7 +619,7 @@ class TTS:
|
||||
self.vocoder_configs["overlapped_len"] = 12
|
||||
|
||||
self.vocoder = self.vocoder.eval()
|
||||
if self.configs.is_half == True:
|
||||
if self.configs.is_half is True:
|
||||
self.vocoder = self.vocoder.half().to(self.configs.device)
|
||||
else:
|
||||
self.vocoder = self.vocoder.to(self.configs.device)
|
||||
@ -756,19 +721,18 @@ class TTS:
|
||||
self.prompt_cache["refer_spec"][0] = spec_audio
|
||||
|
||||
def _get_ref_spec(self, ref_audio_path):
|
||||
raw_audio, raw_sr = torchaudio.load(ref_audio_path)
|
||||
raw_audio = raw_audio.to(self.configs.device).float()
|
||||
raw_audio, raw_sr = torchaudio.load_with_torchcodec(ref_audio_path)
|
||||
self.prompt_cache["raw_audio"] = raw_audio
|
||||
self.prompt_cache["raw_sr"] = raw_sr
|
||||
|
||||
if raw_sr != self.configs.sampling_rate:
|
||||
audio = raw_audio.to(self.configs.device)
|
||||
if audio.shape[0] == 2:
|
||||
if audio.shape[0] > 1:
|
||||
audio = audio.mean(0).unsqueeze(0)
|
||||
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
|
||||
else:
|
||||
audio = raw_audio.to(self.configs.device)
|
||||
if audio.shape[0] == 2:
|
||||
if audio.shape[0] > 1:
|
||||
audio = audio.mean(0).unsqueeze(0)
|
||||
|
||||
maxx = audio.abs().max()
|
||||
@ -784,7 +748,7 @@ class TTS:
|
||||
)
|
||||
if self.configs.is_half:
|
||||
spec = spec.half()
|
||||
if self.is_v2pro == True:
|
||||
if self.is_v2pro is True:
|
||||
audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
|
||||
if self.configs.is_half:
|
||||
audio = audio.half()
|
||||
@ -1235,7 +1199,7 @@ class TTS:
|
||||
spec = spec.to(dtype=self.precision, device=self.configs.device)
|
||||
refer_audio_spec.append(spec)
|
||||
if self.is_v2pro:
|
||||
sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
|
||||
sv_emb.append(self.sv_model.compute_embedding(audio_tensor))
|
||||
|
||||
batch_audio_fragment = []
|
||||
|
||||
|
@ -1,22 +1,18 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
|
||||
import re
|
||||
import torch
|
||||
from text.LangSegmenter import LangSegmenter
|
||||
from text import chinese
|
||||
from typing import Dict, List, Tuple
|
||||
from text.cleaner import clean_text
|
||||
from text import cleaned_text_to_sequence
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_method as get_seg_method
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
|
||||
from GPT_SoVITS.text import cleaned_text_to_sequence
|
||||
from GPT_SoVITS.text.cleaner import clean_text
|
||||
from GPT_SoVITS.text.LangSegmenter import LangSegmenter
|
||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method as get_seg_method
|
||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import split_big_text, splits
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
language = os.environ.get("language", "Auto")
|
||||
@ -121,25 +117,25 @@ class TextPreprocessor:
|
||||
|
||||
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
|
||||
with self.bert_lock:
|
||||
text = re.sub(r' {2,}', ' ', text)
|
||||
text = re.sub(r" {2,}", " ", text)
|
||||
textlist = []
|
||||
langlist = []
|
||||
if language == "all_zh":
|
||||
for tmp in LangSegmenter.getTexts(text,"zh"):
|
||||
for tmp in LangSegmenter.getTexts(text, "zh"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_yue":
|
||||
for tmp in LangSegmenter.getTexts(text,"zh"):
|
||||
for tmp in LangSegmenter.getTexts(text, "zh"):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ja":
|
||||
for tmp in LangSegmenter.getTexts(text,"ja"):
|
||||
for tmp in LangSegmenter.getTexts(text, "ja"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "all_ko":
|
||||
for tmp in LangSegmenter.getTexts(text,"ko"):
|
||||
for tmp in LangSegmenter.getTexts(text, "ko"):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "en":
|
||||
@ -158,7 +154,9 @@ class TextPreprocessor:
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if langlist:
|
||||
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
|
||||
if (tmp["lang"] == "en" and langlist[-1] == "en") or (
|
||||
tmp["lang"] != "en" and langlist[-1] != "en"
|
||||
):
|
||||
textlist[-1] += tmp["text"]
|
||||
continue
|
||||
if tmp["lang"] == "en":
|
||||
@ -189,12 +187,11 @@ class TextPreprocessor:
|
||||
return phones, bert, norm_text
|
||||
|
||||
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(self.device)
|
||||
res = self.bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
inputs[i] = inputs[i].to(self.device)
|
||||
res = self.bert_model(**inputs, output_hidden_states=True)
|
||||
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
||||
assert len(word2ph) == len(text)
|
||||
phone_level_feature = []
|
||||
for i in range(len(word2ph)):
|
||||
@ -209,6 +206,7 @@ class TextPreprocessor:
|
||||
phones = cleaned_text_to_sequence(phones, version)
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
@torch.no_grad()
|
||||
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
|
||||
language = language.replace("all_", "")
|
||||
if language == "zh":
|
||||
@ -236,4 +234,4 @@ class TextPreprocessor:
|
||||
punctuations = "".join(re.escape(p) for p in punctuation)
|
||||
pattern = f"([{punctuations}])([{punctuations}])+"
|
||||
result = re.sub(pattern, r"\1", text)
|
||||
return result
|
||||
return result
|
||||
|
@ -1 +1,3 @@
|
||||
from . import TTS, text_segmentation_method
|
||||
|
||||
__all__ = ["TTS", "text_segmentation_method"]
|
||||
|
@ -100,7 +100,7 @@ def cut0(inp):
|
||||
def cut1(inp):
|
||||
inp = inp.strip("\n")
|
||||
inps = split(inp)
|
||||
split_idx = list(range(0, len(inps), 4))
|
||||
split_idx = list(range(0, len(inps) + 1, 4))
|
||||
split_idx[-1] = None
|
||||
if len(split_idx) > 1:
|
||||
opts = []
|
||||
|
@ -1,13 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.insert(0, now_dir)
|
||||
from text.g2pw import G2PWPinyin
|
||||
|
||||
g2pw = G2PWPinyin(
|
||||
model_dir="GPT_SoVITS/text/G2PWModel",
|
||||
model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
v_to_u=False,
|
||||
neutral_tone_with_five=True,
|
||||
)
|
@ -1,264 +0,0 @@
|
||||
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""
|
||||
Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
|
||||
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
|
||||
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
|
||||
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import pooling_layers as pooling_layers
|
||||
from fusion import AFF
|
||||
|
||||
|
||||
class ReLU(nn.Hardtanh):
|
||||
def __init__(self, inplace=False):
|
||||
super(ReLU, self).__init__(0, 20, inplace)
|
||||
|
||||
def __repr__(self):
|
||||
inplace_str = "inplace" if self.inplace else ""
|
||||
return self.__class__.__name__ + " (" + inplace_str + ")"
|
||||
|
||||
|
||||
class BasicBlockERes2Net(nn.Module):
|
||||
expansion = 2
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
|
||||
super(BasicBlockERes2Net, self).__init__()
|
||||
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width * scale)
|
||||
self.nums = scale
|
||||
|
||||
convs = []
|
||||
bns = []
|
||||
for i in range(self.nums):
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList(bns)
|
||||
self.relu = ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes),
|
||||
)
|
||||
self.stride = stride
|
||||
self.width = width
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = torch.split(out, self.width, 1)
|
||||
for i in range(self.nums):
|
||||
if i == 0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = sp + spx[i]
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i == 0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out, sp), 1)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
residual = self.shortcut(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class BasicBlockERes2Net_diff_AFF(nn.Module):
|
||||
expansion = 2
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
|
||||
super(BasicBlockERes2Net_diff_AFF, self).__init__()
|
||||
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width * scale)
|
||||
self.nums = scale
|
||||
|
||||
convs = []
|
||||
fuse_models = []
|
||||
bns = []
|
||||
for i in range(self.nums):
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
for j in range(self.nums - 1):
|
||||
fuse_models.append(AFF(channels=width))
|
||||
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList(bns)
|
||||
self.fuse_models = nn.ModuleList(fuse_models)
|
||||
self.relu = ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes),
|
||||
)
|
||||
self.stride = stride
|
||||
self.width = width
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = torch.split(out, self.width, 1)
|
||||
for i in range(self.nums):
|
||||
if i == 0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = self.fuse_models[i - 1](sp, spx[i])
|
||||
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i == 0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out, sp), 1)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
residual = self.shortcut(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ERes2Net(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
block=BasicBlockERes2Net,
|
||||
block_fuse=BasicBlockERes2Net_diff_AFF,
|
||||
num_blocks=[3, 4, 6, 3],
|
||||
m_channels=32,
|
||||
feat_dim=80,
|
||||
embedding_size=192,
|
||||
pooling_func="TSTP",
|
||||
two_emb_layer=False,
|
||||
):
|
||||
super(ERes2Net, self).__init__()
|
||||
self.in_planes = m_channels
|
||||
self.feat_dim = feat_dim
|
||||
self.embedding_size = embedding_size
|
||||
self.stats_dim = int(feat_dim / 8) * m_channels * 8
|
||||
self.two_emb_layer = two_emb_layer
|
||||
|
||||
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(m_channels)
|
||||
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
|
||||
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
|
||||
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
|
||||
|
||||
# Downsampling module for each layer
|
||||
self.layer1_downsample = nn.Conv2d(
|
||||
m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False
|
||||
)
|
||||
self.layer2_downsample = nn.Conv2d(
|
||||
m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
|
||||
)
|
||||
self.layer3_downsample = nn.Conv2d(
|
||||
m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
|
||||
)
|
||||
|
||||
# Bottom-up fusion module
|
||||
self.fuse_mode12 = AFF(channels=m_channels * 4)
|
||||
self.fuse_mode123 = AFF(channels=m_channels * 8)
|
||||
self.fuse_mode1234 = AFF(channels=m_channels * 16)
|
||||
|
||||
self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
|
||||
self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
|
||||
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
|
||||
if self.two_emb_layer:
|
||||
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
||||
self.seg_2 = nn.Linear(embedding_size, embedding_size)
|
||||
else:
|
||||
self.seg_bn_1 = nn.Identity()
|
||||
self.seg_2 = nn.Identity()
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride))
|
||||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
x = x.unsqueeze_(1)
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out1 = self.layer1(out)
|
||||
out2 = self.layer2(out1)
|
||||
out1_downsample = self.layer1_downsample(out1)
|
||||
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
|
||||
out3 = self.layer3(out2)
|
||||
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
|
||||
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
||||
out4 = self.layer4(out3)
|
||||
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
|
||||
stats = self.pool(fuse_out1234)
|
||||
|
||||
embed_a = self.seg_1(stats)
|
||||
if self.two_emb_layer:
|
||||
out = F.relu(embed_a)
|
||||
out = self.seg_bn_1(out)
|
||||
embed_b = self.seg_2(out)
|
||||
return embed_b
|
||||
else:
|
||||
return embed_a
|
||||
|
||||
def forward3(self, x):
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
x = x.unsqueeze_(1)
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out1 = self.layer1(out)
|
||||
out2 = self.layer2(out1)
|
||||
out1_downsample = self.layer1_downsample(out1)
|
||||
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
|
||||
out3 = self.layer3(out2)
|
||||
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
|
||||
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
||||
out4 = self.layer4(out3)
|
||||
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
|
||||
return fuse_out1234
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.zeros(10, 300, 80)
|
||||
model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func="TSTP")
|
||||
model.eval()
|
||||
out = model(x)
|
||||
print(out.shape) # torch.Size([10, 192])
|
||||
|
||||
num_params = sum(param.numel() for param in model.parameters())
|
||||
print("{} M".format(num_params / 1e6)) # 6.61M
|
@ -8,12 +8,14 @@ To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundan
|
||||
both the model parameters and its computational cost.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import pooling_layers as pooling_layers
|
||||
from fusion import AFF
|
||||
|
||||
from . import pooling_layers as pooling_layers
|
||||
from .fusion import AFF
|
||||
|
||||
|
||||
class ReLU(nn.Hardtanh):
|
||||
@ -247,26 +249,4 @@ class ERes2NetV2(nn.Module):
|
||||
out4 = self.layer4(out3)
|
||||
out3_ds = self.layer3_ds(out3)
|
||||
fuse_out34 = self.fuse34(out4, out3_ds)
|
||||
# print(111111111,fuse_out34.shape)#111111111 torch.Size([16, 2048, 10, 72])
|
||||
return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
|
||||
# stats = self.pool(fuse_out34)
|
||||
#
|
||||
# embed_a = self.seg_1(stats)
|
||||
# if self.two_emb_layer:
|
||||
# out = F.relu(embed_a)
|
||||
# out = self.seg_bn_1(out)
|
||||
# embed_b = self.seg_2(out)
|
||||
# return embed_b
|
||||
# else:
|
||||
# return embed_a
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x = torch.randn(1, 300, 80)
|
||||
model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
|
||||
model.eval()
|
||||
y = model(x)
|
||||
print(y.size())
|
||||
macs, num_params = profile(model, inputs=(x,))
|
||||
print("Params: {} M".format(num_params / 1e6)) # 17.86 M
|
||||
print("MACs: {} G".format(macs / 1e9)) # 12.69 G
|
||||
|
@ -1,289 +0,0 @@
|
||||
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
|
||||
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
|
||||
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
|
||||
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
|
||||
ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
|
||||
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import pooling_layers as pooling_layers
|
||||
from fusion import AFF
|
||||
|
||||
|
||||
class ReLU(nn.Hardtanh):
|
||||
def __init__(self, inplace=False):
|
||||
super(ReLU, self).__init__(0, 20, inplace)
|
||||
|
||||
def __repr__(self):
|
||||
inplace_str = "inplace" if self.inplace else ""
|
||||
return self.__class__.__name__ + " (" + inplace_str + ")"
|
||||
|
||||
|
||||
class BasicBlockERes2Net(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
|
||||
super(BasicBlockERes2Net, self).__init__()
|
||||
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width * scale)
|
||||
self.nums = scale
|
||||
|
||||
convs = []
|
||||
bns = []
|
||||
for i in range(self.nums):
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList(bns)
|
||||
self.relu = ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes),
|
||||
)
|
||||
self.stride = stride
|
||||
self.width = width
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = torch.split(out, self.width, 1)
|
||||
for i in range(self.nums):
|
||||
if i == 0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = sp + spx[i]
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i == 0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out, sp), 1)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
residual = self.shortcut(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class BasicBlockERes2Net_diff_AFF(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
|
||||
super(BasicBlockERes2Net_diff_AFF, self).__init__()
|
||||
width = int(math.floor(planes * (baseWidth / 64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width * scale)
|
||||
self.nums = scale
|
||||
|
||||
convs = []
|
||||
fuse_models = []
|
||||
bns = []
|
||||
for i in range(self.nums):
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
for j in range(self.nums - 1):
|
||||
fuse_models.append(AFF(channels=width))
|
||||
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList(bns)
|
||||
self.fuse_models = nn.ModuleList(fuse_models)
|
||||
self.relu = ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes),
|
||||
)
|
||||
self.stride = stride
|
||||
self.width = width
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = torch.split(out, self.width, 1)
|
||||
for i in range(self.nums):
|
||||
if i == 0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = self.fuse_models[i - 1](sp, spx[i])
|
||||
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i == 0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out, sp), 1)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
residual = self.shortcut(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ERes2Net(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
block=BasicBlockERes2Net,
|
||||
block_fuse=BasicBlockERes2Net_diff_AFF,
|
||||
num_blocks=[3, 4, 6, 3],
|
||||
m_channels=64,
|
||||
feat_dim=80,
|
||||
embedding_size=192,
|
||||
pooling_func="TSTP",
|
||||
two_emb_layer=False,
|
||||
):
|
||||
super(ERes2Net, self).__init__()
|
||||
self.in_planes = m_channels
|
||||
self.feat_dim = feat_dim
|
||||
self.embedding_size = embedding_size
|
||||
self.stats_dim = int(feat_dim / 8) * m_channels * 8
|
||||
self.two_emb_layer = two_emb_layer
|
||||
|
||||
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(m_channels)
|
||||
|
||||
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
|
||||
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
|
||||
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
|
||||
|
||||
self.layer1_downsample = nn.Conv2d(
|
||||
m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False
|
||||
)
|
||||
self.layer2_downsample = nn.Conv2d(
|
||||
m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False
|
||||
)
|
||||
self.layer3_downsample = nn.Conv2d(
|
||||
m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False
|
||||
)
|
||||
|
||||
self.fuse_mode12 = AFF(channels=m_channels * 8)
|
||||
self.fuse_mode123 = AFF(channels=m_channels * 16)
|
||||
self.fuse_mode1234 = AFF(channels=m_channels * 32)
|
||||
|
||||
self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
|
||||
self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * block.expansion)
|
||||
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
|
||||
if self.two_emb_layer:
|
||||
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
||||
self.seg_2 = nn.Linear(embedding_size, embedding_size)
|
||||
else:
|
||||
self.seg_bn_1 = nn.Identity()
|
||||
self.seg_2 = nn.Identity()
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride))
|
||||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
|
||||
x = x.unsqueeze_(1)
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out1 = self.layer1(out)
|
||||
out2 = self.layer2(out1)
|
||||
out1_downsample = self.layer1_downsample(out1)
|
||||
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
|
||||
out3 = self.layer3(out2)
|
||||
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
|
||||
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
||||
out4 = self.layer4(out3)
|
||||
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
|
||||
stats = self.pool(fuse_out1234)
|
||||
|
||||
embed_a = self.seg_1(stats)
|
||||
if self.two_emb_layer:
|
||||
out = F.relu(embed_a)
|
||||
out = self.seg_bn_1(out)
|
||||
embed_b = self.seg_2(out)
|
||||
return embed_b
|
||||
else:
|
||||
return embed_a
|
||||
|
||||
def forward2(self, x, if_mean):
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
|
||||
x = x.unsqueeze_(1)
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out1 = self.layer1(out)
|
||||
out2 = self.layer2(out1)
|
||||
out1_downsample = self.layer1_downsample(out1)
|
||||
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
|
||||
out3 = self.layer3(out2)
|
||||
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
|
||||
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
||||
out4 = self.layer4(out3)
|
||||
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2) # bs,20480,T
|
||||
if if_mean == False:
|
||||
mean = fuse_out1234[0].transpose(1, 0) # (T,20480),bs=T
|
||||
else:
|
||||
mean = fuse_out1234.mean(2) # bs,20480
|
||||
mean_std = torch.cat([mean, torch.zeros_like(mean)], 1)
|
||||
return self.seg_1(mean_std) # (T,192)
|
||||
|
||||
# stats = self.pool(fuse_out1234)
|
||||
# if self.two_emb_layer:
|
||||
# out = F.relu(embed_a)
|
||||
# out = self.seg_bn_1(out)
|
||||
# embed_b = self.seg_2(out)
|
||||
# return embed_b
|
||||
# else:
|
||||
# return embed_a
|
||||
|
||||
def forward3(self, x):
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
|
||||
x = x.unsqueeze_(1)
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out1 = self.layer1(out)
|
||||
out2 = self.layer2(out1)
|
||||
out1_downsample = self.layer1_downsample(out1)
|
||||
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
|
||||
out3 = self.layer3(out2)
|
||||
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
|
||||
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
||||
out4 = self.layer4(out3)
|
||||
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
|
||||
return fuse_out1234
|
||||
# print(fuse_out1234.shape)
|
||||
# print(fuse_out1234.flatten(start_dim=1,end_dim=2).shape)
|
||||
# pdb.set_trace()
|
@ -1,28 +1,26 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from my_utils import load_audio
|
||||
|
||||
import soundfile
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
from torch import IntTensor, LongTensor, Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from torchaudio.compliance.kaldi import fbank
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
from feature_extractor import cnhubert
|
||||
|
||||
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from module.models_onnx import SynthesizerTrn
|
||||
|
||||
from inference_webui import get_phones_and_bert
|
||||
|
||||
from sv import SV
|
||||
import kaldi as Kaldi
|
||||
|
||||
import os
|
||||
import soundfile
|
||||
import GPT_SoVITS.text as text
|
||||
from GPT_SoVITS.AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||
from GPT_SoVITS.feature_extractor import cnhubert
|
||||
from GPT_SoVITS.inference_webui import get_phones_and_bert
|
||||
from GPT_SoVITS.module.models_onnx import SynthesizerTrn
|
||||
from GPT_SoVITS.sv import SV
|
||||
from tools.my_utils import load_audio
|
||||
|
||||
default_config = {
|
||||
"embedding_dim": 512,
|
||||
@ -477,7 +475,7 @@ class T2SModel(nn.Module):
|
||||
|
||||
# avoid dtype inconsistency when exporting
|
||||
bert = bert.to(dtype=self.bert_proj.weight.dtype)
|
||||
|
||||
|
||||
x = x + self.bert_proj(bert.transpose(1, 2))
|
||||
x: torch.Tensor = self.ar_text_position(x)
|
||||
|
||||
@ -737,7 +735,7 @@ def export_prov2(
|
||||
device="cpu",
|
||||
is_half=True,
|
||||
):
|
||||
if sv_cn_model == None:
|
||||
if sv_cn_model is None:
|
||||
init_sv_cn(device, is_half)
|
||||
|
||||
if not os.path.exists(output_path):
|
||||
@ -898,7 +896,7 @@ class ExportERes2NetV2(nn.Module):
|
||||
def forward(self, audio_16k):
|
||||
# 这个 fbank 函数有一个 cache, 不过不要紧,它跟 audio_16k 的长度无关
|
||||
# 只跟 device 和 dtype 有关
|
||||
x = Kaldi.fbank(audio_16k, num_mel_bins=80, sample_frequency=16000, dither=0)
|
||||
x = fbank(audio_16k, num_mel_bins=80, sample_frequency=16000, dither=0)
|
||||
x = torch.stack([x])
|
||||
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
@ -1041,10 +1039,6 @@ def test():
|
||||
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
|
||||
|
||||
|
||||
import text
|
||||
import json
|
||||
|
||||
|
||||
def export_symbel(version="v2"):
|
||||
if version == "v1":
|
||||
symbols = text._symbol_to_id_v1
|
||||
|
@ -1,27 +1,25 @@
|
||||
import logging
|
||||
import os
|
||||
from export_torch_script import (
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import soundfile
|
||||
import torch
|
||||
import uvicorn
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
|
||||
from GPT_SoVITS.export_torch_script import (
|
||||
T2SModel,
|
||||
get_raw_t2s_model,
|
||||
resamplex,
|
||||
spectrogram_torch,
|
||||
)
|
||||
from f5_tts.model.backbones.dit import DiT
|
||||
from inference_webui import get_phones_and_bert
|
||||
import librosa
|
||||
from module import commons
|
||||
from module.mel_processing import mel_spectrogram_torch
|
||||
from module.models_onnx import CFM, Generator, SynthesizerTrnV3
|
||||
import numpy as np
|
||||
import torch._dynamo.config
|
||||
import torchaudio
|
||||
import logging
|
||||
import uvicorn
|
||||
import torch
|
||||
import soundfile
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
|
||||
|
||||
from inference_webui import get_spepc, norm_spec, resample, ssl_model
|
||||
from GPT_SoVITS.f5_tts.model.backbones.dit import DiT
|
||||
from GPT_SoVITS.inference_webui import get_phones_and_bert, get_spepc, norm_spec, ssl_model
|
||||
from GPT_SoVITS.module import commons
|
||||
from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch
|
||||
from GPT_SoVITS.module.models_onnx import CFM, Generator, SynthesizerTrnV3
|
||||
from GPT_SoVITS.process_ckpt import inspect_version
|
||||
|
||||
logging.config.dictConfig(uvicorn.config.LOGGING_CONFIG)
|
||||
logger = logging.getLogger("uvicorn")
|
||||
@ -176,32 +174,33 @@ class ExportCFM(torch.nn.Module):
|
||||
return cfm_res, fea_ref, mel2
|
||||
|
||||
|
||||
mel_fn = lambda x: mel_spectrogram_torch(
|
||||
x,
|
||||
**{
|
||||
"n_fft": 1024,
|
||||
"win_size": 1024,
|
||||
"hop_size": 256,
|
||||
"num_mels": 100,
|
||||
"sampling_rate": 24000,
|
||||
"fmin": 0,
|
||||
"fmax": None,
|
||||
"center": False,
|
||||
},
|
||||
)
|
||||
mel_fn_v4 = lambda x: mel_spectrogram_torch(
|
||||
x,
|
||||
**{
|
||||
"n_fft": 1280,
|
||||
"win_size": 1280,
|
||||
"hop_size": 320,
|
||||
"num_mels": 100,
|
||||
"sampling_rate": 32000,
|
||||
"fmin": 0,
|
||||
"fmax": None,
|
||||
"center": False,
|
||||
},
|
||||
)
|
||||
def mel_fn(x):
|
||||
return mel_spectrogram_torch(
|
||||
y=x,
|
||||
n_fft=1024,
|
||||
num_mels=100,
|
||||
sampling_rate=24000,
|
||||
hop_size=256,
|
||||
win_size=1024,
|
||||
fmin=0,
|
||||
fmax=None,
|
||||
center=False,
|
||||
)
|
||||
|
||||
|
||||
def mel_fn_v4(x):
|
||||
return mel_spectrogram_torch(
|
||||
y=x,
|
||||
n_fft=1280,
|
||||
num_mels=100,
|
||||
sampling_rate=32000,
|
||||
hop_size=320,
|
||||
win_size=1280,
|
||||
fmin=0,
|
||||
fmax=None,
|
||||
center=False,
|
||||
)
|
||||
|
||||
|
||||
spec_min = -12
|
||||
spec_max = 2
|
||||
@ -511,7 +510,7 @@ def init_bigvgan():
|
||||
# remove weight norm in the model and set to eval mode
|
||||
bigvgan_model.remove_weight_norm()
|
||||
bigvgan_model = bigvgan_model.eval()
|
||||
if is_half == True:
|
||||
if is_half is True:
|
||||
bigvgan_model = bigvgan_model.half().to(device)
|
||||
else:
|
||||
bigvgan_model = bigvgan_model.to(device)
|
||||
@ -536,7 +535,7 @@ def init_hifigan():
|
||||
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu"
|
||||
)
|
||||
print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
|
||||
if is_half == True:
|
||||
if is_half is True:
|
||||
hifigan_model = hifigan_model.half().to(device)
|
||||
else:
|
||||
hifigan_model = hifigan_model.to(device)
|
||||
@ -578,8 +577,6 @@ class DictToAttrRecursive(dict):
|
||||
raise AttributeError(f"Attribute {item} not found")
|
||||
|
||||
|
||||
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
|
||||
|
||||
v3v4set = {"v3", "v4"}
|
||||
|
||||
|
||||
@ -587,12 +584,10 @@ def get_sovits_weights(sovits_path):
|
||||
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
||||
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
||||
|
||||
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
||||
if if_lora_v3 == True and is_exist_s2gv3 == False:
|
||||
model_version, version, if_lora_v3, hps, dict_s2 = inspect_version(sovits_path)
|
||||
if if_lora_v3 is True and is_exist_s2gv3 is False:
|
||||
logger.info("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
|
||||
|
||||
dict_s2 = load_sovits_new(sovits_path)
|
||||
hps = dict_s2["config"]
|
||||
hps = DictToAttrRecursive(hps)
|
||||
hps.model.semantic_frame_rate = "25hz"
|
||||
if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
|
||||
@ -617,7 +612,7 @@ def get_sovits_weights(sovits_path):
|
||||
model_version = hps.model.version
|
||||
logger.info(f"模型版本: {model_version}")
|
||||
|
||||
if is_half == True:
|
||||
if is_half is True:
|
||||
vq_model = vq_model.half().to(device)
|
||||
else:
|
||||
vq_model = vq_model.to(device)
|
||||
@ -729,11 +724,11 @@ def export_1(ref_wav_path, ref_wav_text, version="v3"):
|
||||
# ref_wav_path = "onnx/ad/ref.wav"
|
||||
speed = 1.0
|
||||
sample_steps = 8
|
||||
dtype = torch.float16 if is_half == True else torch.float32
|
||||
dtype = torch.float16 if is_half is True else torch.float32
|
||||
refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
|
||||
zero_wav = np.zeros(
|
||||
int(hps.data.sampling_rate * 0.3),
|
||||
dtype=np.float16 if is_half == True else np.float32,
|
||||
dtype=np.float16 if is_half is True else np.float32,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
@ -741,7 +736,7 @@ def export_1(ref_wav_path, ref_wav_text, version="v3"):
|
||||
wav16k = torch.from_numpy(wav16k)
|
||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||
|
||||
if is_half == True:
|
||||
if is_half is True:
|
||||
wav16k = wav16k.half().to(device)
|
||||
zero_wav_torch = zero_wav_torch.half().to(device)
|
||||
else:
|
||||
@ -828,13 +823,11 @@ def export_1(ref_wav_path, ref_wav_text, version="v3"):
|
||||
gpt_sovits_half = ExportGPTSovitsV4Half(sovits.hps, script_t2s, trace_vq_model)
|
||||
torch.jit.script(gpt_sovits_half).save("onnx/ad/gpt_sovits_v4_half.pt")
|
||||
|
||||
ref_audio, sr = torchaudio.load(ref_wav_path)
|
||||
tgt_sr = 24000 if version == "v3" else 32000
|
||||
ref_audio = torch.from_numpy(librosa.load(ref_wav_path, sr=tgt_sr)[0]).unsqueeze(0)
|
||||
ref_audio = ref_audio.to(device).float()
|
||||
if ref_audio.shape[0] == 2:
|
||||
ref_audio = ref_audio.mean(0).unsqueeze(0)
|
||||
tgt_sr = 24000 if version == "v3" else 32000
|
||||
if sr != tgt_sr:
|
||||
ref_audio = resample(ref_audio, sr, tgt_sr)
|
||||
# mel2 = mel_fn(ref_audio)
|
||||
mel2 = mel_fn(ref_audio) if version == "v3" else mel_fn_v4(ref_audio)
|
||||
mel2 = norm_spec(mel2)
|
||||
@ -940,11 +933,11 @@ def test_export(
|
||||
speed = 1.0
|
||||
sample_steps = 8
|
||||
|
||||
dtype = torch.float16 if is_half == True else torch.float32
|
||||
dtype = torch.float16 if is_half is True else torch.float32
|
||||
|
||||
zero_wav = np.zeros(
|
||||
int(16000 * 0.3),
|
||||
dtype=np.float16 if is_half == True else np.float32,
|
||||
dtype=np.float16 if is_half is True else np.float32,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
@ -952,7 +945,7 @@ def test_export(
|
||||
wav16k = torch.from_numpy(wav16k)
|
||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||
|
||||
if is_half == True:
|
||||
if is_half is True:
|
||||
wav16k = wav16k.half().to(device)
|
||||
zero_wav_torch = zero_wav_torch.half().to(device)
|
||||
else:
|
||||
@ -1058,11 +1051,11 @@ def test_export(
|
||||
speed = 1.0
|
||||
sample_steps = torch.LongTensor([16])
|
||||
|
||||
dtype = torch.float16 if is_half == True else torch.float32
|
||||
dtype = torch.float16 if is_half is True else torch.float32
|
||||
|
||||
zero_wav = np.zeros(
|
||||
int(out_sr * 0.3),
|
||||
dtype=np.float16 if is_half == True else np.float32,
|
||||
dtype=np.float16 if is_half is True else np.float32,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
@ -1070,7 +1063,7 @@ def test_export(
|
||||
wav16k = torch.from_numpy(wav16k)
|
||||
zero_wav_torch = torch.from_numpy(zero_wav)
|
||||
|
||||
if is_half == True:
|
||||
if is_half is True:
|
||||
wav16k = wav16k.half().to(device)
|
||||
zero_wav_torch = zero_wav_torch.half().to(device)
|
||||
else:
|
||||
|
@ -1,13 +1,3 @@
|
||||
# from f5_tts.model.cfm import CFM
|
||||
#
|
||||
# from f5_tts.model.backbones.unett import UNetT
|
||||
from GPT_SoVITS.f5_tts.model.backbones.dit import DiT
|
||||
# from f5_tts.model.backbones.dit import DiTNoCond
|
||||
# from f5_tts.model.backbones.dit import DiTNoCondNoT
|
||||
# from f5_tts.model.backbones.mmdit import MMDiT
|
||||
from .backbones.dit import DiT
|
||||
|
||||
# from f5_tts.model.trainer import Trainer
|
||||
|
||||
|
||||
# __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
|
||||
# __all__ = ["CFM", "UNetT", "DiTNoCond","DiT", "MMDiT"]
|
||||
__all__ = ["DiT"]
|
||||
|
@ -12,21 +12,20 @@ from __future__ import annotations
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
|
||||
from GPT_SoVITS.f5_tts.model.modules import (
|
||||
TimestepEmbedding,
|
||||
from GPT_SoVITS.module.commons import sequence_mask
|
||||
|
||||
from ..modules import (
|
||||
AdaLayerNormZero_Final,
|
||||
ConvNeXtV2Block,
|
||||
ConvPositionEmbedding,
|
||||
DiTBlock,
|
||||
AdaLayerNormZero_Final,
|
||||
precompute_freqs_cis,
|
||||
TimestepEmbedding,
|
||||
get_pos_embed_indices,
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
|
||||
from module.commons import sequence_mask
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_dim, conv_layers=0, conv_mult=2):
|
||||
|
@ -11,19 +11,17 @@ from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
|
||||
from f5_tts.model.modules import (
|
||||
TimestepEmbedding,
|
||||
from ..modules import (
|
||||
AdaLayerNormZero_Final,
|
||||
ConvPositionEmbedding,
|
||||
MMDiTBlock,
|
||||
AdaLayerNormZero_Final,
|
||||
precompute_freqs_cis,
|
||||
TimestepEmbedding,
|
||||
get_pos_embed_indices,
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
|
||||
|
||||
# text embedding
|
||||
|
||||
|
||||
|
@ -8,27 +8,26 @@ d - dimension
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import nn
|
||||
from x_transformers import RMSNorm
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
|
||||
from f5_tts.model.modules import (
|
||||
TimestepEmbedding,
|
||||
ConvNeXtV2Block,
|
||||
ConvPositionEmbedding,
|
||||
from ..modules import (
|
||||
Attention,
|
||||
AttnProcessor,
|
||||
ConvNeXtV2Block,
|
||||
ConvPositionEmbedding,
|
||||
FeedForward,
|
||||
precompute_freqs_cis,
|
||||
TimestepEmbedding,
|
||||
get_pos_embed_indices,
|
||||
precompute_freqs_cis,
|
||||
)
|
||||
|
||||
|
||||
# Text embedding
|
||||
|
||||
|
||||
|
@ -19,7 +19,6 @@ from librosa.filters import mel as librosa_mel_fn
|
||||
from torch import nn
|
||||
from x_transformers.x_transformers import apply_rotary_pos_emb
|
||||
|
||||
|
||||
# raw wav to mel spec
|
||||
|
||||
|
||||
|
@ -1,28 +1,25 @@
|
||||
import torch
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import (
|
||||
HubertModel,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
)
|
||||
from transformers import logging as tf_logging
|
||||
|
||||
tf_logging.set_verbosity_error()
|
||||
|
||||
import logging
|
||||
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
|
||||
from transformers import (
|
||||
Wav2Vec2FeatureExtractor,
|
||||
HubertModel,
|
||||
)
|
||||
|
||||
import utils
|
||||
import torch.nn as nn
|
||||
|
||||
cnhubert_base_path = None
|
||||
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
||||
|
||||
|
||||
class CNHubert(nn.Module):
|
||||
def __init__(self, base_path: str = None):
|
||||
def __init__(self, base_path: str = ""):
|
||||
super().__init__()
|
||||
if base_path is None:
|
||||
if not base_path:
|
||||
base_path = cnhubert_base_path
|
||||
if os.path.exists(base_path):
|
||||
...
|
||||
@ -37,70 +34,13 @@ class CNHubert(nn.Module):
|
||||
return feats
|
||||
|
||||
|
||||
# class CNHubertLarge(nn.Module):
|
||||
# def __init__(self):
|
||||
# super().__init__()
|
||||
# self.model = HubertModel.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
|
||||
# self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-hubert-large")
|
||||
# def forward(self, x):
|
||||
# input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
|
||||
# feats = self.model(input_values)["last_hidden_state"]
|
||||
# return feats
|
||||
#
|
||||
# class CVec(nn.Module):
|
||||
# def __init__(self):
|
||||
# super().__init__()
|
||||
# self.model = HubertModel.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
|
||||
# self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/vc-webui-big/hubert_base")
|
||||
# def forward(self, x):
|
||||
# input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
|
||||
# feats = self.model(input_values)["last_hidden_state"]
|
||||
# return feats
|
||||
#
|
||||
# class cnw2v2base(nn.Module):
|
||||
# def __init__(self):
|
||||
# super().__init__()
|
||||
# self.model = Wav2Vec2Model.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
|
||||
# self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("/data/docker/liujing04/gpt-vits/chinese-wav2vec2-base")
|
||||
# def forward(self, x):
|
||||
# input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
|
||||
# feats = self.model(input_values)["last_hidden_state"]
|
||||
# return feats
|
||||
|
||||
|
||||
def get_model():
|
||||
model = CNHubert()
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
# def get_large_model():
|
||||
# model = CNHubertLarge()
|
||||
# model.eval()
|
||||
# return model
|
||||
#
|
||||
# def get_model_cvec():
|
||||
# model = CVec()
|
||||
# model.eval()
|
||||
# return model
|
||||
#
|
||||
# def get_model_cnw2v2base():
|
||||
# model = cnw2v2base()
|
||||
# model.eval()
|
||||
# return model
|
||||
|
||||
|
||||
def get_content(hmodel, wav_16k_tensor):
|
||||
with torch.no_grad():
|
||||
feats = hmodel(wav_16k_tensor)
|
||||
return feats.transpose(1, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = get_model()
|
||||
src_path = "/Users/Shared/原音频2.wav"
|
||||
wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000)
|
||||
model = model
|
||||
wav_16k_tensor = wav_16k_tensor
|
||||
feats = get_content(model, wav_16k_tensor)
|
||||
print(feats.shape)
|
||||
|
@ -1,17 +1,16 @@
|
||||
import torch
|
||||
import whisper
|
||||
from whisper import log_mel_spectrogram, pad_or_trim
|
||||
|
||||
|
||||
def get_model():
|
||||
import whisper
|
||||
|
||||
model = whisper.load_model("small", device="cpu")
|
||||
|
||||
return model.encoder
|
||||
|
||||
|
||||
def get_content(model=None, wav_16k_tensor=None):
|
||||
from whisper import log_mel_spectrogram, pad_or_trim
|
||||
|
||||
def get_content(model: whisper.Whisper, wav_16k_tensor: torch.Tensor):
|
||||
assert model
|
||||
dev = next(model.parameters()).device
|
||||
mel = log_mel_spectrogram(wav_16k_tensor).to(dev)[:, :3000]
|
||||
# if torch.cuda.is_available():
|
||||
@ -19,5 +18,5 @@ def get_content(model=None, wav_16k_tensor=None):
|
||||
feature_len = mel.shape[-1] // 2
|
||||
assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频"
|
||||
with torch.no_grad():
|
||||
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1, 2)
|
||||
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1, 2) # type: ignore
|
||||
return feature
|
||||
|
@ -1,86 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import soundfile as sf
|
||||
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
|
||||
def synthesize(
|
||||
GPT_model_path,
|
||||
SoVITS_model_path,
|
||||
ref_audio_path,
|
||||
ref_text_path,
|
||||
ref_language,
|
||||
target_text_path,
|
||||
target_language,
|
||||
output_path,
|
||||
):
|
||||
# Read reference text
|
||||
with open(ref_text_path, "r", encoding="utf-8") as file:
|
||||
ref_text = file.read()
|
||||
|
||||
# Read target text
|
||||
with open(target_text_path, "r", encoding="utf-8") as file:
|
||||
target_text = file.read()
|
||||
|
||||
# Change model weights
|
||||
change_gpt_weights(gpt_path=GPT_model_path)
|
||||
change_sovits_weights(sovits_path=SoVITS_model_path)
|
||||
|
||||
# Synthesize audio
|
||||
synthesis_result = get_tts_wav(
|
||||
ref_wav_path=ref_audio_path,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=i18n(ref_language),
|
||||
text=target_text,
|
||||
text_language=i18n(target_language),
|
||||
top_p=1,
|
||||
temperature=1,
|
||||
)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
if result_list:
|
||||
last_sampling_rate, last_audio_data = result_list[-1]
|
||||
output_wav_path = os.path.join(output_path, "output.wav")
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
print(f"Audio saved to {output_wav_path}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
||||
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
||||
parser.add_argument(
|
||||
"--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio"
|
||||
)
|
||||
parser.add_argument("--target_text", required=True, help="Path to the target text file")
|
||||
parser.add_argument(
|
||||
"--target_language",
|
||||
required=True,
|
||||
choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"],
|
||||
help="Language of the target text",
|
||||
)
|
||||
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
synthesize(
|
||||
args.gpt_model,
|
||||
args.sovits_model,
|
||||
args.ref_audio,
|
||||
args.ref_text,
|
||||
args.ref_language,
|
||||
args.target_text,
|
||||
args.target_language,
|
||||
args.output_path,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,316 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
from PyQt5.QtCore import QEvent
|
||||
from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QLineEdit, QPushButton, QTextEdit
|
||||
from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QStatusBar, QComboBox
|
||||
import soundfile as sf
|
||||
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
|
||||
|
||||
class GPTSoVITSGUI(QMainWindow):
|
||||
GPT_Path = gpt_path
|
||||
SoVITS_Path = sovits_path
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.setWindowTitle("GPT-SoVITS GUI")
|
||||
self.setGeometry(800, 450, 950, 850)
|
||||
|
||||
self.setStyleSheet("""
|
||||
QWidget {
|
||||
background-color: #a3d3b1;
|
||||
}
|
||||
|
||||
QTabWidget::pane {
|
||||
background-color: #a3d3b1;
|
||||
}
|
||||
|
||||
QTabWidget::tab-bar {
|
||||
alignment: left;
|
||||
}
|
||||
|
||||
QTabBar::tab {
|
||||
background: #8da4bf;
|
||||
color: #ffffff;
|
||||
padding: 8px;
|
||||
}
|
||||
|
||||
QTabBar::tab:selected {
|
||||
background: #2a3f54;
|
||||
}
|
||||
|
||||
QLabel {
|
||||
color: #000000;
|
||||
}
|
||||
|
||||
QPushButton {
|
||||
background-color: #4CAF50;
|
||||
color: white;
|
||||
padding: 8px;
|
||||
border: 1px solid #4CAF50;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
QPushButton:hover {
|
||||
background-color: #45a049;
|
||||
border: 1px solid #45a049;
|
||||
box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
""")
|
||||
|
||||
license_text = (
|
||||
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
|
||||
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
|
||||
)
|
||||
license_label = QLabel(license_text)
|
||||
license_label.setWordWrap(True)
|
||||
|
||||
self.GPT_model_label = QLabel("选择GPT模型:")
|
||||
self.GPT_model_input = QLineEdit()
|
||||
self.GPT_model_input.setPlaceholderText("拖拽或选择文件")
|
||||
self.GPT_model_input.setText(self.GPT_Path)
|
||||
self.GPT_model_input.setReadOnly(True)
|
||||
self.GPT_model_button = QPushButton("选择GPT模型文件")
|
||||
self.GPT_model_button.clicked.connect(self.select_GPT_model)
|
||||
|
||||
self.SoVITS_model_label = QLabel("选择SoVITS模型:")
|
||||
self.SoVITS_model_input = QLineEdit()
|
||||
self.SoVITS_model_input.setPlaceholderText("拖拽或选择文件")
|
||||
self.SoVITS_model_input.setText(self.SoVITS_Path)
|
||||
self.SoVITS_model_input.setReadOnly(True)
|
||||
self.SoVITS_model_button = QPushButton("选择SoVITS模型文件")
|
||||
self.SoVITS_model_button.clicked.connect(self.select_SoVITS_model)
|
||||
|
||||
self.ref_audio_label = QLabel("上传参考音频:")
|
||||
self.ref_audio_input = QLineEdit()
|
||||
self.ref_audio_input.setPlaceholderText("拖拽或选择文件")
|
||||
self.ref_audio_input.setReadOnly(True)
|
||||
self.ref_audio_button = QPushButton("选择音频文件")
|
||||
self.ref_audio_button.clicked.connect(self.select_ref_audio)
|
||||
|
||||
self.ref_text_label = QLabel("参考音频文本:")
|
||||
self.ref_text_input = QLineEdit()
|
||||
self.ref_text_input.setPlaceholderText("直接输入文字或上传文本")
|
||||
self.ref_text_button = QPushButton("上传文本")
|
||||
self.ref_text_button.clicked.connect(self.upload_ref_text)
|
||||
|
||||
self.ref_language_label = QLabel("参考音频语言:")
|
||||
self.ref_language_combobox = QComboBox()
|
||||
self.ref_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
|
||||
self.ref_language_combobox.setCurrentText("多语种混合")
|
||||
|
||||
self.target_text_label = QLabel("合成目标文本:")
|
||||
self.target_text_input = QLineEdit()
|
||||
self.target_text_input.setPlaceholderText("直接输入文字或上传文本")
|
||||
self.target_text_button = QPushButton("上传文本")
|
||||
self.target_text_button.clicked.connect(self.upload_target_text)
|
||||
|
||||
self.target_language_label = QLabel("合成音频语言:")
|
||||
self.target_language_combobox = QComboBox()
|
||||
self.target_language_combobox.addItems(["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"])
|
||||
self.target_language_combobox.setCurrentText("多语种混合")
|
||||
|
||||
self.output_label = QLabel("输出音频路径:")
|
||||
self.output_input = QLineEdit()
|
||||
self.output_input.setPlaceholderText("拖拽或选择文件")
|
||||
self.output_input.setReadOnly(True)
|
||||
self.output_button = QPushButton("选择文件夹")
|
||||
self.output_button.clicked.connect(self.select_output_path)
|
||||
|
||||
self.output_text = QTextEdit()
|
||||
self.output_text.setReadOnly(True)
|
||||
|
||||
self.add_drag_drop_events(
|
||||
[
|
||||
self.GPT_model_input,
|
||||
self.SoVITS_model_input,
|
||||
self.ref_audio_input,
|
||||
self.ref_text_input,
|
||||
self.target_text_input,
|
||||
self.output_input,
|
||||
]
|
||||
)
|
||||
|
||||
self.synthesize_button = QPushButton("合成")
|
||||
self.synthesize_button.clicked.connect(self.synthesize)
|
||||
|
||||
self.clear_output_button = QPushButton("清空输出")
|
||||
self.clear_output_button.clicked.connect(self.clear_output)
|
||||
|
||||
self.status_bar = QStatusBar()
|
||||
|
||||
main_layout = QVBoxLayout()
|
||||
|
||||
input_layout = QGridLayout(self)
|
||||
input_layout.setSpacing(10)
|
||||
|
||||
input_layout.addWidget(license_label, 0, 0, 1, 3)
|
||||
|
||||
input_layout.addWidget(self.GPT_model_label, 1, 0)
|
||||
input_layout.addWidget(self.GPT_model_input, 2, 0, 1, 2)
|
||||
input_layout.addWidget(self.GPT_model_button, 2, 2)
|
||||
|
||||
input_layout.addWidget(self.SoVITS_model_label, 3, 0)
|
||||
input_layout.addWidget(self.SoVITS_model_input, 4, 0, 1, 2)
|
||||
input_layout.addWidget(self.SoVITS_model_button, 4, 2)
|
||||
|
||||
input_layout.addWidget(self.ref_audio_label, 5, 0)
|
||||
input_layout.addWidget(self.ref_audio_input, 6, 0, 1, 2)
|
||||
input_layout.addWidget(self.ref_audio_button, 6, 2)
|
||||
|
||||
input_layout.addWidget(self.ref_language_label, 7, 0)
|
||||
input_layout.addWidget(self.ref_language_combobox, 8, 0, 1, 1)
|
||||
input_layout.addWidget(self.ref_text_label, 9, 0)
|
||||
input_layout.addWidget(self.ref_text_input, 10, 0, 1, 2)
|
||||
input_layout.addWidget(self.ref_text_button, 10, 2)
|
||||
|
||||
input_layout.addWidget(self.target_language_label, 11, 0)
|
||||
input_layout.addWidget(self.target_language_combobox, 12, 0, 1, 1)
|
||||
input_layout.addWidget(self.target_text_label, 13, 0)
|
||||
input_layout.addWidget(self.target_text_input, 14, 0, 1, 2)
|
||||
input_layout.addWidget(self.target_text_button, 14, 2)
|
||||
|
||||
input_layout.addWidget(self.output_label, 15, 0)
|
||||
input_layout.addWidget(self.output_input, 16, 0, 1, 2)
|
||||
input_layout.addWidget(self.output_button, 16, 2)
|
||||
|
||||
main_layout.addLayout(input_layout)
|
||||
|
||||
output_layout = QVBoxLayout()
|
||||
output_layout.addWidget(self.output_text)
|
||||
main_layout.addLayout(output_layout)
|
||||
|
||||
main_layout.addWidget(self.synthesize_button)
|
||||
|
||||
main_layout.addWidget(self.clear_output_button)
|
||||
|
||||
main_layout.addWidget(self.status_bar)
|
||||
|
||||
self.central_widget = QWidget()
|
||||
self.central_widget.setLayout(main_layout)
|
||||
self.setCentralWidget(self.central_widget)
|
||||
|
||||
def dragEnterEvent(self, event):
|
||||
if event.mimeData().hasUrls():
|
||||
event.acceptProposedAction()
|
||||
|
||||
def dropEvent(self, event):
|
||||
if event.mimeData().hasUrls():
|
||||
file_paths = [url.toLocalFile() for url in event.mimeData().urls()]
|
||||
if len(file_paths) == 1:
|
||||
self.update_ref_audio(file_paths[0])
|
||||
else:
|
||||
self.update_ref_audio(", ".join(file_paths))
|
||||
|
||||
def add_drag_drop_events(self, widgets):
|
||||
for widget in widgets:
|
||||
widget.setAcceptDrops(True)
|
||||
widget.installEventFilter(self)
|
||||
|
||||
def eventFilter(self, obj, event):
|
||||
if event.type() in (QEvent.DragEnter, QEvent.Drop):
|
||||
mime_data = event.mimeData()
|
||||
if mime_data.hasUrls():
|
||||
event.acceptProposedAction()
|
||||
|
||||
return super().eventFilter(obj, event)
|
||||
|
||||
def select_GPT_model(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择GPT模型文件", "", "GPT Files (*.ckpt)")
|
||||
if file_path:
|
||||
self.GPT_model_input.setText(file_path)
|
||||
|
||||
def select_SoVITS_model(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择SoVITS模型文件", "", "SoVITS Files (*.pth)")
|
||||
if file_path:
|
||||
self.SoVITS_model_input.setText(file_path)
|
||||
|
||||
def select_ref_audio(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择参考音频文件", "", "Audio Files (*.wav *.mp3)")
|
||||
if file_path:
|
||||
self.update_ref_audio(file_path)
|
||||
|
||||
def upload_ref_text(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
|
||||
if file_path:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
content = file.read()
|
||||
self.ref_text_input.setText(content)
|
||||
|
||||
def upload_target_text(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
|
||||
if file_path:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
content = file.read()
|
||||
self.target_text_input.setText(content)
|
||||
|
||||
def select_output_path(self):
|
||||
options = QFileDialog.Options()
|
||||
options |= QFileDialog.DontUseNativeDialog
|
||||
options |= QFileDialog.ShowDirsOnly
|
||||
|
||||
folder_dialog = QFileDialog()
|
||||
folder_dialog.setOptions(options)
|
||||
folder_dialog.setFileMode(QFileDialog.Directory)
|
||||
|
||||
if folder_dialog.exec_():
|
||||
folder_path = folder_dialog.selectedFiles()[0]
|
||||
self.output_input.setText(folder_path)
|
||||
|
||||
def update_ref_audio(self, file_path):
|
||||
self.ref_audio_input.setText(file_path)
|
||||
|
||||
def clear_output(self):
|
||||
self.output_text.clear()
|
||||
|
||||
def synthesize(self):
|
||||
GPT_model_path = self.GPT_model_input.text()
|
||||
SoVITS_model_path = self.SoVITS_model_input.text()
|
||||
ref_audio_path = self.ref_audio_input.text()
|
||||
language_combobox = self.ref_language_combobox.currentText()
|
||||
language_combobox = i18n(language_combobox)
|
||||
ref_text = self.ref_text_input.text()
|
||||
target_language_combobox = self.target_language_combobox.currentText()
|
||||
target_language_combobox = i18n(target_language_combobox)
|
||||
target_text = self.target_text_input.text()
|
||||
output_path = self.output_input.text()
|
||||
|
||||
if GPT_model_path != self.GPT_Path:
|
||||
change_gpt_weights(gpt_path=GPT_model_path)
|
||||
self.GPT_Path = GPT_model_path
|
||||
if SoVITS_model_path != self.SoVITS_Path:
|
||||
change_sovits_weights(sovits_path=SoVITS_model_path)
|
||||
self.SoVITS_Path = SoVITS_model_path
|
||||
|
||||
synthesis_result = get_tts_wav(
|
||||
ref_wav_path=ref_audio_path,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=language_combobox,
|
||||
text=target_text,
|
||||
text_language=target_language_combobox,
|
||||
)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
if result_list:
|
||||
last_sampling_rate, last_audio_data = result_list[-1]
|
||||
output_wav_path = os.path.join(output_path, "output.wav")
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
|
||||
result = "Audio saved to " + output_wav_path
|
||||
|
||||
self.status_bar.showMessage("合成完成!输出路径:" + output_wav_path, 5000)
|
||||
self.output_text.append("处理结果:\n" + result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QApplication(sys.argv)
|
||||
mainWin = GPTSoVITSGUI()
|
||||
mainWin.show()
|
||||
sys.exit(app.exec_())
|
File diff suppressed because it is too large
Load Diff
@ -6,32 +6,26 @@
|
||||
全部按英文识别
|
||||
全部按日文识别
|
||||
"""
|
||||
import psutil
|
||||
import os
|
||||
|
||||
def set_high_priority():
|
||||
"""把当前 Python 进程设为 HIGH_PRIORITY_CLASS"""
|
||||
if os.name != "nt":
|
||||
return # 仅 Windows 有效
|
||||
p = psutil.Process(os.getpid())
|
||||
try:
|
||||
p.nice(psutil.HIGH_PRIORITY_CLASS)
|
||||
print("已将进程优先级设为 High")
|
||||
except psutil.AccessDenied:
|
||||
print("权限不足,无法修改优先级(请用管理员运行)")
|
||||
set_high_priority()
|
||||
import json
|
||||
import argparse
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
import gradio as gr
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
||||
from config import change_choices, custom_sort_key, get_dtype, get_weights_names, pretrained_sovits_name
|
||||
from config import infer_device as default_device
|
||||
from GPT_SoVITS.process_ckpt import inspect_version
|
||||
from GPT_SoVITS.TTS_infer_pack.text_segmentation_method import get_method
|
||||
from GPT_SoVITS.TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config
|
||||
from tools.assets import css, js, top_html
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
logging.getLogger("markdown_it").setLevel(logging.ERROR)
|
||||
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
||||
@ -41,44 +35,128 @@ logging.getLogger("asyncio").setLevel(logging.ERROR)
|
||||
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
|
||||
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
||||
|
||||
|
||||
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
|
||||
infer_ttswebui = int(infer_ttswebui)
|
||||
is_share = os.environ.get("is_share", "False")
|
||||
is_share = eval(is_share)
|
||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
|
||||
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
||||
gpt_path = os.environ.get("gpt_path", None)
|
||||
sovits_path = os.environ.get("sovits_path", None)
|
||||
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
|
||||
bert_path = os.environ.get("bert_path", None)
|
||||
version = model_version = os.environ.get("version", "v2")
|
||||
|
||||
import gradio as gr
|
||||
from TTS_infer_pack.text_segmentation_method import get_method
|
||||
from TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config
|
||||
|
||||
from tools.assets import css, js, top_html
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
language = os.environ.get("language", "Auto")
|
||||
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||
i18n = I18nAuto(language=language)
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
# os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 确保直接启动推理UI时也能够设置。
|
||||
def set_high_priority():
|
||||
if os.name != "nt":
|
||||
return
|
||||
p = psutil.Process(os.getpid())
|
||||
with contextlib.suppress(psutil.AccessDenied):
|
||||
p.nice(psutil.HIGH_PRIORITY_CLASS)
|
||||
print("已将进程优先级设为 High")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
# elif torch.backends.mps.is_available():
|
||||
# device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
# is_half = False
|
||||
# device = "cpu"
|
||||
set_high_priority()
|
||||
|
||||
|
||||
_LANG_RE = re.compile(r"^[a-z]{2}[_-][A-Z]{2}$")
|
||||
|
||||
|
||||
def lang_type(text: str) -> str:
|
||||
if text == "Auto":
|
||||
return text
|
||||
if not _LANG_RE.match(text):
|
||||
raise argparse.ArgumentTypeError(f"Unspported Format: {text}, Expected ll_CC/ll-CC")
|
||||
ll, cc = re.split(r"[_-]", text)
|
||||
language = f"{ll}_{cc}"
|
||||
if language in scan_language_list():
|
||||
return language
|
||||
else:
|
||||
return "Auto"
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(
|
||||
prog="inference_webui",
|
||||
description="python -s inference_webui.py zh_CN -i naive",
|
||||
)
|
||||
p.add_argument(
|
||||
"language",
|
||||
nargs="?",
|
||||
default="Auto",
|
||||
type=lang_type,
|
||||
help="Language Code, Such as zh_CN, en-US",
|
||||
)
|
||||
p.add_argument(
|
||||
"--device",
|
||||
"-d",
|
||||
default=str(default_device),
|
||||
help="Inference Device",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--port",
|
||||
"-p",
|
||||
default=9872,
|
||||
type=int,
|
||||
help="WebUI Binding Port",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--share",
|
||||
"-s",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Gradio Share Link",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--cnhubert",
|
||||
default="GPT_SoVITS/pretrained_models/chinese-hubert-base",
|
||||
help="CNHuBERT Pretrain",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--bert",
|
||||
default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
help="BERT Pretrain",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--gpt",
|
||||
default="",
|
||||
help="GPT Model",
|
||||
required=False,
|
||||
)
|
||||
p.add_argument(
|
||||
"--sovits",
|
||||
default="",
|
||||
help="SoVITS Model",
|
||||
required=False,
|
||||
)
|
||||
|
||||
return p
|
||||
|
||||
|
||||
args = build_parser().parse_args()
|
||||
|
||||
|
||||
infer_ttswebui = int(args.port)
|
||||
is_share = args.share
|
||||
|
||||
infer_device = torch.device(args.device)
|
||||
device = infer_device
|
||||
|
||||
if torch.mps.is_available():
|
||||
device = torch.device("cpu")
|
||||
|
||||
dtype = get_dtype(device.index)
|
||||
is_half = dtype == torch.float16
|
||||
|
||||
i18n = I18nAuto(language=args.language)
|
||||
change_choices_gradio = partial(change_choices, i18n=i18n)
|
||||
|
||||
SoVITS_names, GPT_names = get_weights_names(i18n)
|
||||
|
||||
gpt_path = str(args.gpt) or GPT_names[0][-1]
|
||||
sovits_path = str(args.sovits) or SoVITS_names[0][-1]
|
||||
|
||||
cnhubert_base_path = str(args.cuhubert)
|
||||
bert_path = str(args.bert)
|
||||
|
||||
version = model_version = "v2"
|
||||
|
||||
|
||||
dict_language_v1 = {
|
||||
i18n("中文"): "all_zh", # 全部按中文识别
|
||||
@ -112,10 +190,6 @@ cut_method = {
|
||||
i18n("按标点符号切"): "cut5",
|
||||
}
|
||||
|
||||
from config import change_choices, get_weights_names, name2gpt_path, name2sovits_path
|
||||
|
||||
SoVITS_names, GPT_names = get_weights_names()
|
||||
from config import pretrained_sovits_name
|
||||
|
||||
path_sovits_v3 = pretrained_sovits_name["v3"]
|
||||
path_sovits_v4 = pretrained_sovits_name["v4"]
|
||||
@ -128,12 +202,8 @@ tts_config.is_half = is_half
|
||||
# tts_config.version = version
|
||||
tts_config.update_version(version)
|
||||
if gpt_path is not None:
|
||||
if "!" in gpt_path or "!" in gpt_path:
|
||||
gpt_path = name2gpt_path[gpt_path]
|
||||
tts_config.t2s_weights_path = gpt_path
|
||||
if sovits_path is not None:
|
||||
if "!" in sovits_path or "!" in sovits_path:
|
||||
sovits_path = name2sovits_path[sovits_path]
|
||||
tts_config.vits_weights_path = sovits_path
|
||||
if cnhubert_base_path is not None:
|
||||
tts_config.cnhuhbert_base_path = cnhubert_base_path
|
||||
@ -201,62 +271,31 @@ def inference(
|
||||
gr.Warning(i18n("V3不支持无参考文本模式,请填写参考文本!"))
|
||||
|
||||
|
||||
def custom_sort_key(s):
|
||||
# 使用正则表达式提取字符串中的数字部分和非数字部分
|
||||
parts = re.split("(\d+)", s)
|
||||
# 将数字部分转换为整数,非数字部分保持不变
|
||||
parts = [int(part) if part.isdigit() else part for part in parts]
|
||||
return parts
|
||||
|
||||
|
||||
if os.path.exists("./weight.json"):
|
||||
pass
|
||||
else:
|
||||
with open("./weight.json", "w", encoding="utf-8") as file:
|
||||
json.dump({"GPT": {}, "SoVITS": {}}, file)
|
||||
|
||||
with open("./weight.json", "r", encoding="utf-8") as file:
|
||||
weight_data = file.read()
|
||||
weight_data = json.loads(weight_data)
|
||||
gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, GPT_names[-1]))
|
||||
sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, SoVITS_names[0]))
|
||||
if isinstance(gpt_path, list):
|
||||
gpt_path = gpt_path[0]
|
||||
if isinstance(sovits_path, list):
|
||||
sovits_path = sovits_path[0]
|
||||
|
||||
from process_ckpt import get_sovits_version_from_path_fast
|
||||
|
||||
v3v4set = {"v3", "v4"}
|
||||
|
||||
|
||||
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
|
||||
if "!" in sovits_path or "!" in sovits_path:
|
||||
sovits_path = name2sovits_path[sovits_path]
|
||||
global version, model_version, dict_language, if_lora_v3
|
||||
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
||||
# print(sovits_path,version, model_version, if_lora_v3)
|
||||
global version, model_version, dict_language, is_lora
|
||||
model_version, version, is_lora, _, __ = inspect_version(sovits_path)
|
||||
# print(sovits_path,version, model_version, is_lora)
|
||||
is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
|
||||
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
|
||||
if if_lora_v3 == True and is_exist == False:
|
||||
info = path_sovits + "SoVITS %s" % model_version + i18n("底模缺失,无法加载相应 LoRA 权重")
|
||||
if is_lora is True and is_exist is False:
|
||||
info = path_sovits + f"SoVITS {model_version}" + i18n("底模缺失,无法加载相应 LoRA 权重")
|
||||
gr.Warning(info)
|
||||
raise FileExistsError(info)
|
||||
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
|
||||
if prompt_language is not None and text_language is not None:
|
||||
if prompt_language in list(dict_language.keys()):
|
||||
prompt_text_update, prompt_language_update = (
|
||||
{"__type__": "update"},
|
||||
{"__type__": "update", "value": prompt_language},
|
||||
)
|
||||
prompt_text_update, prompt_language_update = gr.skip(), gr.skip()
|
||||
else:
|
||||
prompt_text_update = {"__type__": "update", "value": ""}
|
||||
prompt_language_update = {"__type__": "update", "value": i18n("中文")}
|
||||
prompt_text_update = gr.update(value="")
|
||||
prompt_language_update = gr.update(value=i18n("中文"))
|
||||
if text_language in list(dict_language.keys()):
|
||||
text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
|
||||
text_update, text_language_update = gr.skip(), gr.skip()
|
||||
else:
|
||||
text_update = {"__type__": "update", "value": ""}
|
||||
text_language_update = {"__type__": "update", "value": i18n("中文")}
|
||||
text_update = gr.update(value="")
|
||||
text_language_update = gr.update(value=i18n("中文"))
|
||||
if model_version in v3v4set:
|
||||
visible_sample_steps = True
|
||||
visible_inp_refs = False
|
||||
@ -264,42 +303,42 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
|
||||
visible_sample_steps = False
|
||||
visible_inp_refs = True
|
||||
yield (
|
||||
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||
gr.update(choices=list(dict_language.keys())),
|
||||
gr.update(choices=list(dict_language.keys())),
|
||||
prompt_text_update,
|
||||
prompt_language_update,
|
||||
text_update,
|
||||
text_language_update,
|
||||
{"__type__": "update", "interactive": visible_sample_steps, "value": 32},
|
||||
{"__type__": "update", "visible": visible_inp_refs},
|
||||
{"__type__": "update", "interactive": True if model_version not in v3v4set else False},
|
||||
{"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
|
||||
gr.update(
|
||||
visible=visible_sample_steps,
|
||||
value=32 if model_version == "v3" else 8,
|
||||
choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
|
||||
),
|
||||
gr.update(visible=visible_inp_refs),
|
||||
gr.update(interactive=True if model_version not in v3v4set else False),
|
||||
gr.update(value=i18n("模型加载中,请等待"), interactive=False),
|
||||
)
|
||||
|
||||
tts_pipeline.init_vits_weights(sovits_path)
|
||||
yield (
|
||||
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||
gr.update(choices=list(dict_language.keys())),
|
||||
gr.update(choices=list(dict_language.keys())),
|
||||
prompt_text_update,
|
||||
prompt_language_update,
|
||||
text_update,
|
||||
text_language_update,
|
||||
{"__type__": "update", "interactive": visible_sample_steps, "value": 32},
|
||||
{"__type__": "update", "visible": visible_inp_refs},
|
||||
{"__type__": "update", "interactive": True if model_version not in v3v4set else False},
|
||||
{"__type__": "update", "value": i18n("合成语音"), "interactive": True},
|
||||
gr.update(
|
||||
visible=visible_sample_steps,
|
||||
value=32 if model_version == "v3" else 8,
|
||||
choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
|
||||
),
|
||||
gr.update(visible=visible_inp_refs),
|
||||
gr.update(interactive=True if model_version not in v3v4set else False),
|
||||
gr.update(value=i18n("合成语音"), interactive=True),
|
||||
)
|
||||
with open("./weight.json") as f:
|
||||
data = f.read()
|
||||
data = json.loads(data)
|
||||
data["SoVITS"][version] = sovits_path
|
||||
with open("./weight.json", "w") as f:
|
||||
f.write(json.dumps(data))
|
||||
|
||||
|
||||
def change_gpt_weights(gpt_path):
|
||||
if "!" in gpt_path or "!" in gpt_path:
|
||||
gpt_path = name2gpt_path[gpt_path]
|
||||
tts_pipeline.init_t2s_weights(gpt_path)
|
||||
|
||||
|
||||
@ -315,7 +354,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
with gr.Column():
|
||||
# with gr.Group():
|
||||
gr.Markdown(value=i18n("模型切换"))
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
GPT_dropdown = gr.Dropdown(
|
||||
label=i18n("GPT模型列表"),
|
||||
choices=sorted(GPT_names, key=custom_sort_key),
|
||||
@ -329,20 +368,24 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
interactive=True,
|
||||
)
|
||||
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
|
||||
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
|
||||
refresh_button.click(fn=change_choices_gradio, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column():
|
||||
gr.Markdown(value=i18n("*请上传并填写参考信息"))
|
||||
with gr.Row():
|
||||
inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath")
|
||||
with gr.Row(equal_height=True):
|
||||
inp_ref = gr.Audio(
|
||||
label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"),
|
||||
type="filepath",
|
||||
waveform_options={"show_recording_waveform": False},
|
||||
)
|
||||
inp_refs = gr.File(
|
||||
label=i18n("辅参考音频(可选多个,或不选)"),
|
||||
file_count="multiple",
|
||||
visible=True if model_version != "v3" else False,
|
||||
)
|
||||
prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
prompt_language = gr.Dropdown(
|
||||
label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
|
||||
)
|
||||
@ -368,26 +411,26 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
|
||||
with gr.Group():
|
||||
gr.Markdown(value=i18n("推理设置"))
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
batch_size = gr.Slider(
|
||||
minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
|
||||
)
|
||||
sample_steps = gr.Radio(
|
||||
label=i18n("采样步数(仅对V3/4生效)"), value=32, choices=[4, 8, 16, 32, 64, 128], visible=True
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
fragment_interval = gr.Slider(
|
||||
minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
|
||||
)
|
||||
speed_factor = gr.Slider(
|
||||
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
|
||||
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
temperature = gr.Slider(
|
||||
minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
|
||||
)
|
||||
@ -396,7 +439,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
how_to_cut = gr.Dropdown(
|
||||
label=i18n("怎么切"),
|
||||
choices=[
|
||||
@ -415,7 +458,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
|
||||
split_bucket = gr.Checkbox(
|
||||
label=i18n("数据分桶(并行推理时会降低一点计算量)"),
|
||||
@ -424,12 +467,15 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
show_label=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
seed = gr.Number(label=i18n("随机种子"), value=-1)
|
||||
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
|
||||
|
||||
output = gr.Audio(label=i18n("输出的语音"))
|
||||
with gr.Row():
|
||||
output = gr.Audio(
|
||||
label=i18n("输出的语音"),
|
||||
waveform_options={"show_recording_waveform": False},
|
||||
)
|
||||
with gr.Row(equal_height=True):
|
||||
inference_button = gr.Button(i18n("合成语音"), variant="primary")
|
||||
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
|
||||
|
||||
@ -485,7 +531,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css
|
||||
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
|
||||
)
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Row(equal_height=True):
|
||||
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
|
||||
with gr.Column():
|
||||
_how_to_cut = gr.Radio(
|
||||
|
@ -1,10 +1,13 @@
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from module import commons
|
||||
from module.modules import LayerNorm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
|
||||
from . import commons
|
||||
from .modules import LayerNorm
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
@ -392,10 +395,6 @@ class FFN(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
|
||||
|
||||
class Depthwise_Separable_Conv1D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -1,11 +1,11 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from module import commons
|
||||
|
||||
from typing import Optional
|
||||
from . import commons
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
|
@ -1,15 +1,17 @@
|
||||
import os
|
||||
import random
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
from tqdm import tqdm
|
||||
|
||||
from module.mel_processing import spectrogram_torch, spec_to_mel_torch
|
||||
from text import cleaned_text_to_sequence
|
||||
import torch.nn.functional as F
|
||||
from GPT_SoVITS.text import cleaned_text_to_sequence
|
||||
from tools.my_utils import load_audio
|
||||
|
||||
from .mel_processing import spec_to_mel_torch, spectrogram_torch
|
||||
|
||||
version = os.environ.get("version", None)
|
||||
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user