mirror of
https://github.com/RVC-Boss/GPT-SoVITS.git
synced 2025-04-23 18:02:58 +08:00
Compare commits
174 Commits
20240821v2
...
main
Author | SHA1 | Date | |
---|---|---|---|
|
9202c74761 | ||
|
6a1ece8992 | ||
|
e31d67eeff | ||
|
fbdab94e17 | ||
|
a19f49604f | ||
|
590c83d766 | ||
|
7405427a0a | ||
|
e0f2818df7 | ||
|
bc2fe5ec86 | ||
|
839ff9ce5b | ||
|
8b394a15bc | ||
|
ec7ec370ef | ||
|
9d481da610 | ||
|
50e9ba0218 | ||
|
c6cb6b45f3 | ||
|
e0c452f007 | ||
|
b43ae64a1e | ||
|
c0b46314ca | ||
|
53cac93589 | ||
|
9da7e17efe | ||
|
b0de354c63 | ||
|
41090e5a7c | ||
|
605b380114 | ||
|
9f8d455130 | ||
|
7abae557fb | ||
|
6a60e5edb1 | ||
|
28bdff356f | ||
|
03b662a769 | ||
|
6c468583c5 | ||
|
ee4a466f79 | ||
|
b65ea9181e | ||
|
c0ce55a132 | ||
|
13573a1b06 | ||
|
fef65d40fe | ||
|
b0e465eb72 | ||
|
f1332ff53a | ||
|
b88bd391fc | ||
|
4635cb4293 | ||
|
86e6dea694 | ||
|
d7c24e9ac9 | ||
|
6c1c1bb72a | ||
|
265586990c | ||
|
7394dc7b0c | ||
|
165882d64f | ||
|
271db6a4de | ||
|
053a356ffe | ||
|
fe2f04bdb8 | ||
|
6dd2f72090 | ||
|
959a2ddbeb | ||
|
bb8a8efeca | ||
|
df33574a26 | ||
|
4fd57b0ea7 | ||
|
118c2bed68 | ||
|
a32a2b8934 | ||
|
c38b169019 | ||
|
a69be1eae7 | ||
|
9ef2c00bbc | ||
|
b446f7954b | ||
|
ff299d17d3 | ||
|
e9af7921fa | ||
|
3356fc9e09 | ||
|
531a38f119 | ||
|
780524e0cc | ||
|
e35ade8f60 | ||
|
fc1400bdba | ||
|
2cd843dcbc | ||
|
94ffcbe616 | ||
|
a0ff6fa7a2 | ||
|
0b20a949ed | ||
|
a68e3c4354 | ||
|
ffcba8e553 | ||
|
250b1c73cb | ||
|
060a0d91dc | ||
|
af80e8f113 | ||
|
849bd36684 | ||
|
8e95391caa | ||
|
6c43e2f052 | ||
|
49dd04d051 | ||
|
ffbf205ae4 | ||
|
00bdc01113 | ||
|
9208fb8157 | ||
|
b7a6e43a4c | ||
|
b88d78d64b | ||
|
c242346280 | ||
|
7c56946d95 | ||
|
270a41eb89 | ||
|
92961c3f68 | ||
|
25a6829ddc | ||
|
00cbadcaad | ||
|
db811b7cc8 | ||
|
778fbc2880 | ||
|
237526a7e6 | ||
|
8158a97909 | ||
|
c4606c1cc1 | ||
|
e061e9d38e | ||
|
fbb9f21e53 | ||
|
aa07216bba | ||
|
514fb692db | ||
|
e6a32e15b0 | ||
|
e937b625e4 | ||
|
56509a17c9 | ||
|
11a0d9a265 | ||
|
137cb26d58 | ||
|
3737496389 | ||
|
a2bb1dab91 | ||
|
c17dd642c7 | ||
|
c70daefea2 | ||
|
087fd24579 | ||
|
f454834cbb | ||
|
c09af5ab7d | ||
|
1c5edd2949 | ||
|
1867780d56 | ||
|
72d839e40a | ||
|
16941a7c14 | ||
|
87a3b908ee | ||
|
d8fc921771 | ||
|
cfb5bed554 | ||
|
753472c9bd | ||
|
1b25e1097a | ||
|
c2b3298bed | ||
|
86acb7a89d | ||
|
8e9e3c07d5 | ||
|
598cddabaf | ||
|
282ae1d9b2 | ||
|
ca3cc4997a | ||
|
15cbd1b673 | ||
|
6e2b49186c | ||
|
347becb4e4 | ||
|
f29921d85d | ||
|
38e93fd047 | ||
|
8109e9bbfd | ||
|
53aa20350f | ||
|
504b72a489 | ||
|
05645a6593 | ||
|
0db482e87d | ||
|
206325a402 | ||
|
6b12b4b10b | ||
|
17d9be2a70 | ||
|
f7b617d4f7 | ||
|
89f1194c70 | ||
|
25d1aaf202 | ||
|
76c8150fea | ||
|
96a7dca4b8 | ||
|
bfe241143b | ||
|
c8f7604ba7 | ||
|
190dd6198c | ||
|
21eaf84f23 | ||
|
00ab793ff2 | ||
|
4d5e9d27a9 | ||
|
43eabf21da | ||
|
25cb1bf400 | ||
|
fa42d26d0e | ||
|
ed207c4b87 | ||
|
b7a904a671 | ||
|
a1fe2267af | ||
|
793b0043de | ||
|
a70e1ad30c | ||
|
6d82050b9c | ||
|
98cc47699c | ||
|
5d126f98b2 | ||
|
eee607b71d | ||
|
a95b2b85f7 | ||
|
38cd881578 | ||
|
5efb960898 | ||
|
78c68d46cb | ||
|
192ea6f6c9 | ||
|
0c000191b3 | ||
|
570da092c9 | ||
|
40cd22e69d | ||
|
3488cffd68 | ||
|
d67bbd2166 | ||
|
f35f6e9b5e | ||
|
7dac47ca95 | ||
|
2a9512a63e |
184
.gitignore
vendored
184
.gitignore
vendored
@ -12,7 +12,189 @@ GPT_weights
|
|||||||
SoVITS_weights
|
SoVITS_weights
|
||||||
GPT_weights_v2
|
GPT_weights_v2
|
||||||
SoVITS_weights_v2
|
SoVITS_weights_v2
|
||||||
|
GPT_weights_v3
|
||||||
|
SoVITS_weights_v3
|
||||||
TEMP
|
TEMP
|
||||||
weight.json
|
weight.json
|
||||||
ffmpeg*
|
ffmpeg*
|
||||||
ffprobe*
|
ffprobe*
|
||||||
|
cfg.json
|
||||||
|
speakers.json
|
||||||
|
ref_audios
|
||||||
|
tools/AP_BWE_main/24kto48k/*
|
||||||
|
!tools/AP_BWE_main/24kto48k/readme.txt
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# UV
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
#uv.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||||
|
.pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
.pdm-build/
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
|
|
||||||
|
# Ruff stuff:
|
||||||
|
.ruff_cache/
|
||||||
|
|
||||||
|
# PyPI configuration file
|
||||||
|
.pypirc
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
# Download moda ASR related models
|
# Download moda ASR related models
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
model_dir = snapshot_download('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',revision="v2.0.4")
|
|
||||||
model_dir = snapshot_download('damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',revision="v2.0.4")
|
model_dir = snapshot_download(
|
||||||
model_dir = snapshot_download('damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',revision="v2.0.4")
|
"damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", revision="v2.0.4"
|
||||||
|
)
|
||||||
|
model_dir = snapshot_download("damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", revision="v2.0.4")
|
||||||
|
model_dir = snapshot_download("damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", revision="v2.0.4")
|
||||||
|
@ -4,14 +4,11 @@ import itertools
|
|||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from random import shuffle
|
from random import shuffle
|
||||||
from typing import Iterator
|
from typing import Iterator, Optional, TypeVar
|
||||||
from typing import Optional
|
|
||||||
from typing import TypeVar
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset, Sampler
|
||||||
from torch.utils.data import Sampler
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DistributedBucketSampler",
|
"DistributedBucketSampler",
|
||||||
@ -50,10 +47,7 @@ class DistributedBucketSampler(Sampler[T_co]):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.set_device(rank)
|
torch.cuda.set_device(rank)
|
||||||
if rank >= num_replicas or rank < 0:
|
if rank >= num_replicas or rank < 0:
|
||||||
raise ValueError(
|
raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
|
||||||
"Invalid rank {}, rank should be in the interval"
|
|
||||||
" [0, {}]".format(rank, num_replicas - 1)
|
|
||||||
)
|
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.num_replicas = num_replicas
|
self.num_replicas = num_replicas
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
@ -61,19 +55,16 @@ class DistributedBucketSampler(Sampler[T_co]):
|
|||||||
self.drop_last = drop_last
|
self.drop_last = drop_last
|
||||||
# If the dataset length is evenly divisible by # of replicas, then there
|
# If the dataset length is evenly divisible by # of replicas, then there
|
||||||
# is no need to drop any data, since the dataset will be split equally.
|
# is no need to drop any data, since the dataset will be split equally.
|
||||||
if (
|
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
|
||||||
self.drop_last and len(self.dataset) % self.num_replicas != 0
|
|
||||||
): # type: ignore[arg-type]
|
|
||||||
# Split to nearest available length that is evenly divisible.
|
# Split to nearest available length that is evenly divisible.
|
||||||
# This is to ensure each rank receives the same amount of data when
|
# This is to ensure each rank receives the same amount of data when
|
||||||
# using this Sampler.
|
# using this Sampler.
|
||||||
self.num_samples = math.ceil(
|
self.num_samples = math.ceil(
|
||||||
(len(self.dataset) - self.num_replicas)
|
(len(self.dataset) - self.num_replicas) / self.num_replicas, # type: ignore[arg-type]
|
||||||
/ self.num_replicas # type: ignore[arg-type]
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.num_samples = math.ceil(
|
self.num_samples = math.ceil(
|
||||||
len(self.dataset) / self.num_replicas
|
len(self.dataset) / self.num_replicas,
|
||||||
) # type: ignore[arg-type]
|
) # type: ignore[arg-type]
|
||||||
self.total_size = self.num_samples * self.num_replicas
|
self.total_size = self.num_samples * self.num_replicas
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
@ -118,10 +109,7 @@ class DistributedBucketSampler(Sampler[T_co]):
|
|||||||
grouped_batch_size = self.batch_size * self.num_replicas
|
grouped_batch_size = self.batch_size * self.num_replicas
|
||||||
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
|
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
|
||||||
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
|
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
|
||||||
batches = [
|
batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
|
||||||
shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
|
|
||||||
for b in range(n_batch)
|
|
||||||
]
|
|
||||||
shuffle(batches)
|
shuffle(batches)
|
||||||
indices = list(itertools.chain(*batches))
|
indices = list(itertools.chain(*batches))
|
||||||
else:
|
else:
|
||||||
@ -134,9 +122,7 @@ class DistributedBucketSampler(Sampler[T_co]):
|
|||||||
if padding_size <= len(indices):
|
if padding_size <= len(indices):
|
||||||
indices += indices[:padding_size]
|
indices += indices[:padding_size]
|
||||||
else:
|
else:
|
||||||
indices += (indices * math.ceil(padding_size / len(indices)))[
|
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
||||||
:padding_size
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
# remove tail of data to make it evenly divisible.
|
# remove tail of data to make it evenly divisible.
|
||||||
indices = indices[: self.total_size]
|
indices = indices[: self.total_size]
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
|
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
|
||||||
# reference: https://github.com/lifeiteng/vall-e
|
# reference: https://github.com/lifeiteng/vall-e
|
||||||
from pytorch_lightning import LightningDataModule
|
from pytorch_lightning import LightningDataModule
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from AR.data.bucket_sampler import DistributedBucketSampler
|
from AR.data.bucket_sampler import DistributedBucketSampler
|
||||||
from AR.data.dataset import Text2SemanticDataset
|
from AR.data.dataset import Text2SemanticDataset
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
|
|
||||||
class Text2SemanticDataModule(LightningDataModule):
|
class Text2SemanticDataModule(LightningDataModule):
|
||||||
@ -42,8 +43,12 @@ class Text2SemanticDataModule(LightningDataModule):
|
|||||||
# pad_val=self.config['data']['pad_val'])
|
# pad_val=self.config['data']['pad_val'])
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
|
batch_size = (
|
||||||
batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存
|
self.config["train"]["batch_size"] // 2
|
||||||
|
if self.config["train"].get("if_dpo", False) is True
|
||||||
|
else self.config["train"]["batch_size"]
|
||||||
|
)
|
||||||
|
batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1) # 防止不保存
|
||||||
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
|
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
self._train_dataset,
|
self._train_dataset,
|
||||||
|
@ -1,21 +1,17 @@
|
|||||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
|
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
|
||||||
# reference: https://github.com/lifeiteng/vall-e
|
# reference: https://github.com/lifeiteng/vall-e
|
||||||
import pdb
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
|
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
|
||||||
import traceback, os
|
import os
|
||||||
from typing import Dict
|
import traceback
|
||||||
from typing import List
|
from typing import Dict, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch, json
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader, Dataset
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
version = os.environ.get('version',None)
|
version = os.environ.get("version", None)
|
||||||
|
|
||||||
from text import cleaned_text_to_sequence
|
from text import cleaned_text_to_sequence
|
||||||
|
|
||||||
@ -34,9 +30,7 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
|
|||||||
|
|
||||||
padded_sequences = []
|
padded_sequences = []
|
||||||
for seq, length in zip(sequences, seq_lengths):
|
for seq, length in zip(sequences, seq_lengths):
|
||||||
padding = (
|
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
|
||||||
[(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
|
|
||||||
)
|
|
||||||
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
|
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
|
||||||
padded_sequences.append(padded_seq)
|
padded_sequences.append(padded_seq)
|
||||||
batch = np.stack(padded_sequences)
|
batch = np.stack(padded_sequences)
|
||||||
@ -61,12 +55,16 @@ class Text2SemanticDataset(Dataset):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.semantic_data = pd.read_csv(
|
self.semantic_data = pd.read_csv(
|
||||||
semantic_path, delimiter="\t", encoding="utf-8"
|
semantic_path,
|
||||||
|
delimiter="\t",
|
||||||
|
encoding="utf-8",
|
||||||
)
|
)
|
||||||
# get dict
|
# get dict
|
||||||
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
|
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
|
||||||
self.path3 = "%s/3-bert" % (
|
self.path3 = "%s/3-bert" % (
|
||||||
os.path.dirname(phoneme_path)
|
os.path.dirname(
|
||||||
|
phoneme_path,
|
||||||
|
)
|
||||||
) # "%s/3-bert"%exp_dir#bert_dir
|
) # "%s/3-bert"%exp_dir#bert_dir
|
||||||
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
|
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
|
||||||
assert os.path.exists(self.path2)
|
assert os.path.exists(self.path2)
|
||||||
@ -127,7 +125,7 @@ class Text2SemanticDataset(Dataset):
|
|||||||
for i in range(semantic_data_len):
|
for i in range(semantic_data_len):
|
||||||
# 先依次遍历
|
# 先依次遍历
|
||||||
# get str
|
# get str
|
||||||
item_name = self.semantic_data.iloc[i,0]
|
item_name = self.semantic_data.iloc[i, 0]
|
||||||
# print(self.phoneme_data)
|
# print(self.phoneme_data)
|
||||||
try:
|
try:
|
||||||
phoneme, word2ph, text = self.phoneme_data[item_name]
|
phoneme, word2ph, text = self.phoneme_data[item_name]
|
||||||
@ -137,7 +135,7 @@ class Text2SemanticDataset(Dataset):
|
|||||||
num_not_in += 1
|
num_not_in += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
semantic_str = self.semantic_data.iloc[i,1]
|
semantic_str = self.semantic_data.iloc[i, 1]
|
||||||
# get token list
|
# get token list
|
||||||
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
|
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
|
||||||
# (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
|
# (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
|
||||||
@ -158,9 +156,7 @@ class Text2SemanticDataset(Dataset):
|
|||||||
num_not_in += 1
|
num_not_in += 1
|
||||||
continue
|
continue
|
||||||
# if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行
|
# if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行
|
||||||
if (
|
if len(phoneme_ids) > self.max_sec * self.hz / 2.5: ###########2:改为恒定限制为semantic/2.5就行
|
||||||
len(phoneme_ids) > self.max_sec * self.hz / 2.5
|
|
||||||
): ###########2:改为恒定限制为semantic/2.5就行
|
|
||||||
num_deleted_ps += 1
|
num_deleted_ps += 1
|
||||||
continue
|
continue
|
||||||
# if len(semantic_ids) > 1000:###########3
|
# if len(semantic_ids) > 1000:###########3
|
||||||
@ -169,9 +165,7 @@ class Text2SemanticDataset(Dataset):
|
|||||||
|
|
||||||
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
|
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
|
||||||
|
|
||||||
if (
|
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio: ##########4#3~25#每秒多少个phone
|
||||||
ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
|
|
||||||
): ##########4#3~25#每秒多少个phone
|
|
||||||
num_deleted_ps += 1
|
num_deleted_ps += 1
|
||||||
# print(item_name)
|
# print(item_name)
|
||||||
continue
|
continue
|
||||||
@ -194,12 +188,12 @@ class Text2SemanticDataset(Dataset):
|
|||||||
print(f"there are {num_not_in} semantic datas not in phoneme datas")
|
print(f"there are {num_not_in} semantic datas not in phoneme datas")
|
||||||
if num_deleted_bigger > 0:
|
if num_deleted_bigger > 0:
|
||||||
print(
|
print(
|
||||||
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
|
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
|
||||||
)
|
)
|
||||||
if num_deleted_ps > 0:
|
if num_deleted_ps > 0:
|
||||||
# 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
|
# 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
|
||||||
print(
|
print(
|
||||||
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
|
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
there are 31 semantic datas not in phoneme datas
|
there are 31 semantic datas not in phoneme datas
|
||||||
@ -306,7 +300,10 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
batch_size = 12
|
batch_size = 12
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
|
dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
collate_fn=dataset.collate,
|
||||||
|
shuffle=False,
|
||||||
)
|
)
|
||||||
for i, batch in enumerate(dataloader):
|
for i, batch in enumerate(dataloader):
|
||||||
if i % 1000 == 0:
|
if i % 1000 == 0:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
|
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
|
||||||
# reference: https://github.com/lifeiteng/vall-e
|
# reference: https://github.com/lifeiteng/vall-e
|
||||||
import os, sys
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
sys.path.append(now_dir)
|
sys.path.append(now_dir)
|
||||||
@ -8,10 +9,12 @@ from typing import Dict
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch_lightning import LightningModule
|
from pytorch_lightning import LightningModule
|
||||||
|
|
||||||
from AR.models.t2s_model import Text2SemanticDecoder
|
from AR.models.t2s_model import Text2SemanticDecoder
|
||||||
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
||||||
from AR.modules.optim import ScaledAdam
|
from AR.modules.optim import ScaledAdam
|
||||||
|
|
||||||
|
|
||||||
class Text2SemanticLightningModule(LightningModule):
|
class Text2SemanticLightningModule(LightningModule):
|
||||||
def __init__(self, config, output_dir, is_train=True):
|
def __init__(self, config, output_dir, is_train=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -23,7 +26,10 @@ class Text2SemanticLightningModule(LightningModule):
|
|||||||
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
|
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
|
||||||
print(
|
print(
|
||||||
self.load_state_dict(
|
self.load_state_dict(
|
||||||
torch.load(pretrained_s1, map_location="cpu")["weight"]
|
torch.load(
|
||||||
|
pretrained_s1,
|
||||||
|
map_location="cpu",
|
||||||
|
)["weight"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if is_train:
|
if is_train:
|
||||||
@ -35,7 +41,7 @@ class Text2SemanticLightningModule(LightningModule):
|
|||||||
def training_step(self, batch: Dict, batch_idx: int):
|
def training_step(self, batch: Dict, batch_idx: int):
|
||||||
opt = self.optimizers()
|
opt = self.optimizers()
|
||||||
scheduler = self.lr_schedulers()
|
scheduler = self.lr_schedulers()
|
||||||
forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
|
forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
|
||||||
loss, acc = forward(
|
loss, acc = forward(
|
||||||
batch["phoneme_ids"],
|
batch["phoneme_ids"],
|
||||||
batch["phoneme_ids_len"],
|
batch["phoneme_ids_len"],
|
||||||
@ -113,9 +119,7 @@ class Text2SemanticLightningModule(LightningModule):
|
|||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
model_parameters = self.model.parameters()
|
model_parameters = self.model.parameters()
|
||||||
parameters_names = []
|
parameters_names = []
|
||||||
parameters_names.append(
|
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
|
||||||
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
|
|
||||||
)
|
|
||||||
lm_opt = ScaledAdam(
|
lm_opt = ScaledAdam(
|
||||||
model_parameters,
|
model_parameters,
|
||||||
lr=0.01,
|
lr=0.01,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
|
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
|
||||||
# reference: https://github.com/lifeiteng/vall-e
|
# reference: https://github.com/lifeiteng/vall-e
|
||||||
import os, sys
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
sys.path.append(now_dir)
|
sys.path.append(now_dir)
|
||||||
@ -8,6 +9,7 @@ from typing import Dict
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch_lightning import LightningModule
|
from pytorch_lightning import LightningModule
|
||||||
|
|
||||||
from AR.models.t2s_model_onnx import Text2SemanticDecoder
|
from AR.models.t2s_model_onnx import Text2SemanticDecoder
|
||||||
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
||||||
from AR.modules.optim import ScaledAdam
|
from AR.modules.optim import ScaledAdam
|
||||||
@ -24,8 +26,11 @@ class Text2SemanticLightningModule(LightningModule):
|
|||||||
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
|
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
|
||||||
print(
|
print(
|
||||||
self.load_state_dict(
|
self.load_state_dict(
|
||||||
torch.load(pretrained_s1, map_location="cpu")["weight"]
|
torch.load(
|
||||||
)
|
pretrained_s1,
|
||||||
|
map_location="cpu",
|
||||||
|
)["weight"],
|
||||||
|
),
|
||||||
)
|
)
|
||||||
if is_train:
|
if is_train:
|
||||||
self.automatic_optimization = False
|
self.automatic_optimization = False
|
||||||
@ -79,9 +84,7 @@ class Text2SemanticLightningModule(LightningModule):
|
|||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
model_parameters = self.model.parameters()
|
model_parameters = self.model.parameters()
|
||||||
parameters_names = []
|
parameters_names = []
|
||||||
parameters_names.append(
|
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
|
||||||
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
|
|
||||||
)
|
|
||||||
lm_opt = ScaledAdam(
|
lm_opt = ScaledAdam(
|
||||||
model_parameters,
|
model_parameters,
|
||||||
lr=0.01,
|
lr=0.01,
|
||||||
|
@ -2,27 +2,24 @@
|
|||||||
# reference: https://github.com/lifeiteng/vall-e
|
# reference: https://github.com/lifeiteng/vall-e
|
||||||
import math
|
import math
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from AR.models.utils import make_pad_mask
|
import torch
|
||||||
from AR.models.utils import (
|
|
||||||
topk_sampling,
|
|
||||||
sample,
|
|
||||||
logits_to_probs,
|
|
||||||
multinomial_sample_one_no_sync,
|
|
||||||
dpo_loss,
|
|
||||||
make_reject_y,
|
|
||||||
get_batch_logps
|
|
||||||
)
|
|
||||||
from AR.modules.embedding import SinePositionalEmbedding
|
|
||||||
from AR.modules.embedding import TokenEmbedding
|
|
||||||
from AR.modules.transformer import LayerNorm
|
|
||||||
from AR.modules.transformer import TransformerEncoder
|
|
||||||
from AR.modules.transformer import TransformerEncoderLayer
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torchmetrics.classification import MulticlassAccuracy
|
from torchmetrics.classification import MulticlassAccuracy
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from AR.models.utils import (
|
||||||
|
dpo_loss,
|
||||||
|
get_batch_logps,
|
||||||
|
make_pad_mask,
|
||||||
|
make_pad_mask_left,
|
||||||
|
make_reject_y,
|
||||||
|
sample,
|
||||||
|
topk_sampling,
|
||||||
|
)
|
||||||
|
from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
||||||
|
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
||||||
|
|
||||||
default_config = {
|
default_config = {
|
||||||
"embedding_dim": 512,
|
"embedding_dim": 512,
|
||||||
@ -36,10 +33,17 @@ default_config = {
|
|||||||
"EOS": 1024,
|
"EOS": 1024,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定
|
# @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定
|
||||||
# Efficient implementation equivalent to the following:
|
# Efficient implementation equivalent to the following:
|
||||||
def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
|
def scaled_dot_product_attention(
|
||||||
B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
scale: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
|
||||||
if scale is None:
|
if scale is None:
|
||||||
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
|
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
|
||||||
else:
|
else:
|
||||||
@ -59,12 +63,13 @@ def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:tor
|
|||||||
if attn_mask.dtype == torch.bool:
|
if attn_mask.dtype == torch.bool:
|
||||||
attn_weight.masked_fill_(attn_mask, 0)
|
attn_weight.masked_fill_(attn_mask, 0)
|
||||||
else:
|
else:
|
||||||
attn_mask[attn_mask!=float("-inf")] =0
|
attn_mask[attn_mask != float("-inf")] = 0
|
||||||
attn_mask[attn_mask==float("-inf")] =1
|
attn_mask[attn_mask == float("-inf")] = 1
|
||||||
attn_weight.masked_fill_(attn_mask, 0)
|
attn_weight.masked_fill_(attn_mask, 0)
|
||||||
|
|
||||||
return attn_weight @ value
|
return attn_weight @ value
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
class T2SMLP:
|
class T2SMLP:
|
||||||
def __init__(self, w1, b1, w2, b2):
|
def __init__(self, w1, b1, w2, b2):
|
||||||
@ -82,20 +87,20 @@ class T2SMLP:
|
|||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
class T2SBlock:
|
class T2SBlock:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_heads,
|
num_heads,
|
||||||
hidden_dim: int,
|
hidden_dim: int,
|
||||||
mlp: T2SMLP,
|
mlp: T2SMLP,
|
||||||
qkv_w,
|
qkv_w,
|
||||||
qkv_b,
|
qkv_b,
|
||||||
out_w,
|
out_w,
|
||||||
out_b,
|
out_b,
|
||||||
norm_w1,
|
norm_w1,
|
||||||
norm_b1,
|
norm_b1,
|
||||||
norm_eps1,
|
norm_eps1,
|
||||||
norm_w2,
|
norm_w2,
|
||||||
norm_b2,
|
norm_b2,
|
||||||
norm_eps2,
|
norm_eps2,
|
||||||
):
|
):
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.mlp = mlp
|
self.mlp = mlp
|
||||||
@ -114,24 +119,32 @@ class T2SBlock:
|
|||||||
self.false = torch.tensor(False, dtype=torch.bool)
|
self.false = torch.tensor(False, dtype=torch.bool)
|
||||||
|
|
||||||
@torch.jit.ignore
|
@torch.jit.ignore
|
||||||
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
|
def to_mask(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
padding_mask: Optional[torch.Tensor],
|
||||||
|
):
|
||||||
if padding_mask is None:
|
if padding_mask is None:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
if padding_mask.dtype == torch.bool:
|
if padding_mask.dtype == torch.bool:
|
||||||
return x.masked_fill(padding_mask, 0)
|
return x.masked_fill(padding_mask, 0)
|
||||||
else:
|
else:
|
||||||
return x * padding_mask
|
return x * padding_mask
|
||||||
|
|
||||||
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
|
|
||||||
|
|
||||||
|
def process_prompt(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
attn_mask: torch.Tensor,
|
||||||
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
|
torch_sdpa: bool = True,
|
||||||
|
):
|
||||||
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||||
|
|
||||||
batch_size = q.shape[0]
|
batch_size = q.shape[0]
|
||||||
q_len = q.shape[1]
|
q_len = q.shape[1]
|
||||||
kv_len = k.shape[1]
|
kv_len = k.shape[1]
|
||||||
|
|
||||||
q = self.to_mask(q, padding_mask)
|
q = self.to_mask(q, padding_mask)
|
||||||
k_cache = self.to_mask(k, padding_mask)
|
k_cache = self.to_mask(k, padding_mask)
|
||||||
v_cache = self.to_mask(v, padding_mask)
|
v_cache = self.to_mask(v, padding_mask)
|
||||||
@ -145,53 +158,34 @@ class T2SBlock:
|
|||||||
else:
|
else:
|
||||||
attn = scaled_dot_product_attention(q, k, v, attn_mask)
|
attn = scaled_dot_product_attention(q, k, v, attn_mask)
|
||||||
|
|
||||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
|
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
|
||||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
|
||||||
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
||||||
|
|
||||||
if padding_mask is not None:
|
x = x + attn
|
||||||
for i in range(batch_size):
|
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||||
# mask = padding_mask[i,:,0]
|
x = x + self.mlp.forward(x)
|
||||||
if self.false.device!= padding_mask.device:
|
x = F.layer_norm(
|
||||||
self.false = self.false.to(padding_mask.device)
|
x,
|
||||||
idx = torch.where(padding_mask[i,:,0]==self.false)[0]
|
[self.hidden_dim],
|
||||||
x_item = x[i,idx,:].unsqueeze(0)
|
self.norm_w2,
|
||||||
attn_item = attn[i,idx,:].unsqueeze(0)
|
self.norm_b2,
|
||||||
x_item = x_item + attn_item
|
self.norm_eps2,
|
||||||
x_item = F.layer_norm(
|
)
|
||||||
x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
|
||||||
)
|
|
||||||
x_item = x_item + self.mlp.forward(x_item)
|
|
||||||
x_item = F.layer_norm(
|
|
||||||
x_item,
|
|
||||||
[self.hidden_dim],
|
|
||||||
self.norm_w2,
|
|
||||||
self.norm_b2,
|
|
||||||
self.norm_eps2,
|
|
||||||
)
|
|
||||||
x[i,idx,:] = x_item.squeeze(0)
|
|
||||||
x = self.to_mask(x, padding_mask)
|
|
||||||
else:
|
|
||||||
x = x + attn
|
|
||||||
x = F.layer_norm(
|
|
||||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
|
||||||
)
|
|
||||||
x = x + self.mlp.forward(x)
|
|
||||||
x = F.layer_norm(
|
|
||||||
x,
|
|
||||||
[self.hidden_dim],
|
|
||||||
self.norm_w2,
|
|
||||||
self.norm_b2,
|
|
||||||
self.norm_eps2,
|
|
||||||
)
|
|
||||||
return x, k_cache, v_cache
|
return x, k_cache, v_cache
|
||||||
|
|
||||||
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
|
def decode_next_token(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
v_cache: torch.Tensor,
|
||||||
|
attn_mask: torch.Tensor = None,
|
||||||
|
torch_sdpa: bool = True,
|
||||||
|
):
|
||||||
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||||
|
|
||||||
k_cache = torch.cat([k_cache, k], dim=1)
|
k_cache = torch.cat([k_cache, k], dim=1)
|
||||||
v_cache = torch.cat([v_cache, v], dim=1)
|
v_cache = torch.cat([v_cache, v], dim=1)
|
||||||
|
|
||||||
batch_size = q.shape[0]
|
batch_size = q.shape[0]
|
||||||
q_len = q.shape[1]
|
q_len = q.shape[1]
|
||||||
kv_len = k_cache.shape[1]
|
kv_len = k_cache.shape[1]
|
||||||
@ -200,19 +194,21 @@ class T2SBlock:
|
|||||||
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||||
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||||
|
|
||||||
|
|
||||||
if torch_sdpa:
|
if torch_sdpa:
|
||||||
attn = F.scaled_dot_product_attention(q, k, v)
|
attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
|
||||||
else:
|
else:
|
||||||
attn = scaled_dot_product_attention(q, k, v, attn_mask)
|
attn = scaled_dot_product_attention(q, k, v, attn_mask)
|
||||||
|
|
||||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
|
attn = attn.transpose(1, 2).reshape(batch_size, q_len, -1)
|
||||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
|
||||||
attn = F.linear(attn, self.out_w, self.out_b)
|
attn = F.linear(attn, self.out_w, self.out_b)
|
||||||
|
|
||||||
x = x + attn
|
x = x + attn
|
||||||
x = F.layer_norm(
|
x = F.layer_norm(
|
||||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
x,
|
||||||
|
[self.hidden_dim],
|
||||||
|
self.norm_w1,
|
||||||
|
self.norm_b1,
|
||||||
|
self.norm_eps1,
|
||||||
)
|
)
|
||||||
x = x + self.mlp.forward(x)
|
x = x + self.mlp.forward(x)
|
||||||
x = F.layer_norm(
|
x = F.layer_norm(
|
||||||
@ -227,17 +223,19 @@ class T2SBlock:
|
|||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
class T2STransformer:
|
class T2STransformer:
|
||||||
def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
|
def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
|
||||||
self.num_blocks : int = num_blocks
|
self.num_blocks: int = num_blocks
|
||||||
self.blocks = blocks
|
self.blocks = blocks
|
||||||
|
|
||||||
def process_prompt(
|
def process_prompt(
|
||||||
self, x:torch.Tensor, attn_mask : torch.Tensor,
|
self,
|
||||||
padding_mask : Optional[torch.Tensor]=None,
|
x: torch.Tensor,
|
||||||
torch_sdpa:bool=True
|
attn_mask: torch.Tensor,
|
||||||
):
|
padding_mask: Optional[torch.Tensor] = None,
|
||||||
k_cache : List[torch.Tensor] = []
|
torch_sdpa: bool = True,
|
||||||
v_cache : List[torch.Tensor] = []
|
):
|
||||||
|
k_cache: List[torch.Tensor] = []
|
||||||
|
v_cache: List[torch.Tensor] = []
|
||||||
for i in range(self.num_blocks):
|
for i in range(self.num_blocks):
|
||||||
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
|
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
|
||||||
k_cache.append(k_cache_)
|
k_cache.append(k_cache_)
|
||||||
@ -245,14 +243,17 @@ class T2STransformer:
|
|||||||
return x, k_cache, v_cache
|
return x, k_cache, v_cache
|
||||||
|
|
||||||
def decode_next_token(
|
def decode_next_token(
|
||||||
self, x:torch.Tensor,
|
self,
|
||||||
k_cache: List[torch.Tensor],
|
x: torch.Tensor,
|
||||||
v_cache: List[torch.Tensor],
|
k_cache: List[torch.Tensor],
|
||||||
attn_mask : Optional[torch.Tensor]=None,
|
v_cache: List[torch.Tensor],
|
||||||
torch_sdpa:bool=True
|
attn_mask: torch.Tensor = None,
|
||||||
|
torch_sdpa: bool = True,
|
||||||
):
|
):
|
||||||
for i in range(self.num_blocks):
|
for i in range(self.num_blocks):
|
||||||
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
|
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
|
||||||
|
x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
|
||||||
|
)
|
||||||
return x, k_cache, v_cache
|
return x, k_cache, v_cache
|
||||||
|
|
||||||
|
|
||||||
@ -274,16 +275,26 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
# assert self.EOS == 1024
|
# assert self.EOS == 1024
|
||||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||||
self.ar_text_embedding = TokenEmbedding(
|
self.ar_text_embedding = TokenEmbedding(
|
||||||
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
|
self.embedding_dim,
|
||||||
|
self.phoneme_vocab_size,
|
||||||
|
self.p_dropout,
|
||||||
)
|
)
|
||||||
self.ar_text_position = SinePositionalEmbedding(
|
self.ar_text_position = SinePositionalEmbedding(
|
||||||
self.embedding_dim, dropout=0.1, scale=False, alpha=True
|
self.embedding_dim,
|
||||||
|
dropout=0.1,
|
||||||
|
scale=False,
|
||||||
|
alpha=True,
|
||||||
)
|
)
|
||||||
self.ar_audio_embedding = TokenEmbedding(
|
self.ar_audio_embedding = TokenEmbedding(
|
||||||
self.embedding_dim, self.vocab_size, self.p_dropout
|
self.embedding_dim,
|
||||||
|
self.vocab_size,
|
||||||
|
self.p_dropout,
|
||||||
)
|
)
|
||||||
self.ar_audio_position = SinePositionalEmbedding(
|
self.ar_audio_position = SinePositionalEmbedding(
|
||||||
self.embedding_dim, dropout=0.1, scale=False, alpha=True
|
self.embedding_dim,
|
||||||
|
dropout=0.1,
|
||||||
|
scale=False,
|
||||||
|
alpha=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.h = TransformerEncoder(
|
self.h = TransformerEncoder(
|
||||||
@ -318,7 +329,7 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
layer.linear1.weight,
|
layer.linear1.weight,
|
||||||
layer.linear1.bias,
|
layer.linear1.bias,
|
||||||
layer.linear2.weight,
|
layer.linear2.weight,
|
||||||
layer.linear2.bias
|
layer.linear2.bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
block = T2SBlock(
|
block = T2SBlock(
|
||||||
@ -334,11 +345,11 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
layer.norm1.eps,
|
layer.norm1.eps,
|
||||||
layer.norm2.weight,
|
layer.norm2.weight,
|
||||||
layer.norm2.bias,
|
layer.norm2.bias,
|
||||||
layer.norm2.eps
|
layer.norm2.eps,
|
||||||
)
|
)
|
||||||
|
|
||||||
blocks.append(block)
|
blocks.append(block)
|
||||||
|
|
||||||
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
||||||
|
|
||||||
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
|
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
|
||||||
@ -412,7 +423,9 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
logits = self.ar_predict_layer(xy_dec[:, x_len:])
|
logits = self.ar_predict_layer(xy_dec[:, x_len:])
|
||||||
|
|
||||||
###### DPO #############
|
###### DPO #############
|
||||||
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
|
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
|
||||||
|
x, x_lens, reject_y, reject_y_lens, bert_feature
|
||||||
|
)
|
||||||
|
|
||||||
reject_xy_dec, _ = self.h(
|
reject_xy_dec, _ = self.h(
|
||||||
(reject_xy_pos, None),
|
(reject_xy_pos, None),
|
||||||
@ -429,7 +442,7 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
|
|
||||||
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
|
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
|
||||||
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
|
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
|
||||||
|
|
||||||
loss = loss_1 + loss_2
|
loss = loss_1 + loss_2
|
||||||
|
|
||||||
return loss, acc
|
return loss, acc
|
||||||
@ -498,14 +511,14 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
|
|
||||||
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
|
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
|
||||||
def infer(
|
def infer(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
x_lens,
|
x_lens,
|
||||||
prompts,
|
prompts,
|
||||||
bert_feature,
|
bert_feature,
|
||||||
top_k: int = -100,
|
top_k: int = -100,
|
||||||
early_stop_num: int = -1,
|
early_stop_num: int = -1,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
):
|
):
|
||||||
x = self.ar_text_embedding(x)
|
x = self.ar_text_embedding(x)
|
||||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||||
@ -533,18 +546,14 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
(x_len, 0),
|
(x_len, 0),
|
||||||
value=False,
|
value=False,
|
||||||
)
|
)
|
||||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
|
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
|
||||||
y.device
|
|
||||||
)
|
|
||||||
|
|
||||||
xy_dec, _ = self.h(
|
xy_dec, _ = self.h(
|
||||||
(xy_pos, None),
|
(xy_pos, None),
|
||||||
mask=xy_attn_mask,
|
mask=xy_attn_mask,
|
||||||
)
|
)
|
||||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||||
samples = topk_sampling(
|
samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
|
||||||
logits, top_k=top_k, top_p=1.0, temperature=temperature
|
|
||||||
)
|
|
||||||
|
|
||||||
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||||
print("use early stop num:", early_stop_num)
|
print("use early stop num:", early_stop_num)
|
||||||
@ -567,18 +576,16 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
return y
|
return y
|
||||||
|
|
||||||
def pad_y_eos(self, y, y_mask_int, eos_id):
|
def pad_y_eos(self, y, y_mask_int, eos_id):
|
||||||
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
|
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
|
||||||
y_mask_int, (0, 1), value=1
|
|
||||||
)
|
|
||||||
# 错位
|
# 错位
|
||||||
return targets[:, :-1], targets[:, 1:]
|
return targets[:, :-1], targets[:, 1:]
|
||||||
|
|
||||||
def infer_panel_batch_infer(
|
def infer_panel_batch_infer(
|
||||||
self,
|
self,
|
||||||
x:List[torch.LongTensor], #####全部文本token
|
x: List[torch.LongTensor], #####全部文本token
|
||||||
x_lens:torch.LongTensor,
|
x_lens: torch.LongTensor,
|
||||||
prompts:torch.LongTensor, ####参考音频token
|
prompts: torch.LongTensor, ####参考音频token
|
||||||
bert_feature:List[torch.LongTensor],
|
bert_feature: List[torch.LongTensor],
|
||||||
top_k: int = -100,
|
top_k: int = -100,
|
||||||
top_p: int = 100,
|
top_p: int = 100,
|
||||||
early_stop_num: int = -1,
|
early_stop_num: int = -1,
|
||||||
@ -588,149 +595,168 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
):
|
):
|
||||||
if prompts is None:
|
if prompts is None:
|
||||||
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
|
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
|
||||||
return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
|
return self.infer_panel_naive_batched(
|
||||||
|
x,
|
||||||
|
x_lens,
|
||||||
|
prompts,
|
||||||
|
bert_feature,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
early_stop_num=early_stop_num,
|
||||||
|
temperature=temperature,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
max_len = kwargs.get("max_len", x_lens.max())
|
||||||
max_len = kwargs.get("max_len",x_lens.max())
|
|
||||||
x_list = []
|
x_list = []
|
||||||
for x_item, bert_item in zip(x, bert_feature):
|
for x_item, bert_item in zip(x, bert_feature):
|
||||||
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||||
x_item = self.ar_text_embedding(x_item.unsqueeze(0))
|
x_item = self.ar_text_embedding(x_item.unsqueeze(0))
|
||||||
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
|
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
|
||||||
x_item = self.ar_text_position(x_item).squeeze(0)
|
x_item = self.ar_text_position(x_item).squeeze(0)
|
||||||
x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item
|
# x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
|
||||||
|
x_item = (
|
||||||
|
F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
|
||||||
|
) ### padding left
|
||||||
x_list.append(x_item)
|
x_list.append(x_item)
|
||||||
x = torch.stack(x_list, dim=0)
|
x: torch.Tensor = torch.stack(x_list, dim=0)
|
||||||
|
|
||||||
|
|
||||||
# AR Decoder
|
# AR Decoder
|
||||||
y = prompts
|
y = prompts
|
||||||
|
|
||||||
x_len = x.shape[1]
|
x_len = x.shape[1]
|
||||||
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
|
||||||
stop = False
|
stop = False
|
||||||
|
|
||||||
k_cache = None
|
k_cache = None
|
||||||
v_cache = None
|
v_cache = None
|
||||||
################### first step ##########################
|
################### first step ##########################
|
||||||
if y is not None:
|
assert y is not None, "Error: Prompt free is not supported batch_infer!"
|
||||||
y_emb = self.ar_audio_embedding(y)
|
ref_free = False
|
||||||
y_len = y_emb.shape[1]
|
|
||||||
prefix_len = y.shape[1]
|
|
||||||
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
|
|
||||||
y_pos = self.ar_audio_position(y_emb)
|
|
||||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
|
||||||
ref_free = False
|
|
||||||
else:
|
|
||||||
y_emb = None
|
|
||||||
y_len = 0
|
|
||||||
prefix_len = 0
|
|
||||||
y_lens = torch.LongTensor([y_len]*x.shape[0]).to(x.device)
|
|
||||||
y_pos = None
|
|
||||||
xy_pos = x
|
|
||||||
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
|
|
||||||
ref_free = True
|
|
||||||
|
|
||||||
|
y_emb = self.ar_audio_embedding(y)
|
||||||
|
y_len = y_emb.shape[1]
|
||||||
|
prefix_len = y.shape[1]
|
||||||
|
y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
|
||||||
|
y_pos = self.ar_audio_position(y_emb)
|
||||||
|
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||||
|
|
||||||
##### create mask #####
|
##### create mask #####
|
||||||
bsz = x.shape[0]
|
bsz = x.shape[0]
|
||||||
src_len = x_len + y_len
|
src_len = x_len + y_len
|
||||||
y_paddind_mask = make_pad_mask(y_lens, y_len)
|
y_paddind_mask = make_pad_mask_left(y_lens, y_len)
|
||||||
x_paddind_mask = make_pad_mask(x_lens, max_len)
|
x_paddind_mask = make_pad_mask_left(x_lens, max_len)
|
||||||
|
|
||||||
# (bsz, x_len + y_len)
|
# (bsz, x_len + y_len)
|
||||||
xy_padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
|
padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
|
||||||
|
|
||||||
x_mask = F.pad(
|
x_mask = F.pad(
|
||||||
x_attn_mask,
|
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
|
||||||
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
(0, y_len),
|
||||||
value=True,
|
value=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
|
||||||
(x_len, 0),
|
(x_len, 0),
|
||||||
value=False,
|
value=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
|
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
|
||||||
_xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
|
# padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
|
||||||
|
### 上面是错误的,会导致padding的token被"看见"
|
||||||
for i in range(bsz):
|
|
||||||
l = x_lens[i]
|
# 正确的padding_mask应该是:
|
||||||
_xy_padding_mask[i,l:max_len,:]=True
|
# | pad_len | x_len | y_len |
|
||||||
|
# [[PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||||
xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
|
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||||
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
|
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], 前3行按理说也应该被mask掉,但是为了防止计算attention时不出现nan,还是保留了,不影响结果
|
||||||
xy_attn_mask = xy_attn_mask.bool()
|
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||||
xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1).expand(-1, -1, self.model_dim)
|
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||||
|
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||||
|
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||||
|
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||||
|
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
|
||||||
|
|
||||||
|
padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
|
||||||
|
|
||||||
|
attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
|
||||||
|
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
|
||||||
|
|
||||||
|
# 正确的attn_mask应该是这样的:
|
||||||
|
# | pad_len | x_len | y_len |
|
||||||
|
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||||
|
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||||
|
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS], 前3行按理说也应该被mask掉,但是为了防止计算attention时不出现nan,还是保留了,不影响结果
|
||||||
|
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||||
|
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||||
|
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||||
|
# [PAD, PAD, PAD, 1, 2, 3, 4, EOS, EOS],
|
||||||
|
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
|
||||||
|
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
|
||||||
|
|
||||||
###### decode #####
|
###### decode #####
|
||||||
y_list = [None]*y.shape[0]
|
y_list = [None] * y.shape[0]
|
||||||
batch_idx_map = list(range(y.shape[0]))
|
batch_idx_map = list(range(y.shape[0]))
|
||||||
idx_list = [None]*y.shape[0]
|
idx_list = [None] * y.shape[0]
|
||||||
for idx in tqdm(range(1500)):
|
for idx in tqdm(range(1500)):
|
||||||
if idx == 0:
|
if idx == 0:
|
||||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask, False)
|
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
|
||||||
else:
|
else:
|
||||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask, False)
|
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
|
||||||
logits = self.ar_predict_layer(
|
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||||
xy_dec[:, -1]
|
|
||||||
)
|
|
||||||
|
|
||||||
if idx == 0:
|
if idx == 0:
|
||||||
xy_attn_mask = F.pad(xy_attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
|
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
|
||||||
logits = logits[:, :-1]
|
logits = logits[:, :-1]
|
||||||
else:
|
else:
|
||||||
xy_attn_mask = F.pad(xy_attn_mask,(0,1),value=False)
|
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
||||||
|
|
||||||
samples = sample(
|
samples = sample(
|
||||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
y = torch.concat([y, samples], dim=1)
|
y = torch.concat([y, samples], dim=1)
|
||||||
|
|
||||||
####### 移除batch中已经生成完毕的序列,进一步优化计算量
|
####### 移除batch中已经生成完毕的序列,进一步优化计算量
|
||||||
tokens = torch.argmax(logits, dim=-1)
|
tokens = torch.argmax(logits, dim=-1)
|
||||||
reserved_idx_of_batch_for_y = None
|
reserved_idx_of_batch_for_y = None
|
||||||
if (self.EOS in samples[:, 0]) or \
|
if (self.EOS in samples[:, 0]) or (self.EOS in tokens): ###如果生成到EOS,则停止
|
||||||
(self.EOS in tokens): ###如果生成到EOS,则停止
|
l1 = samples[:, 0] == self.EOS
|
||||||
l1 = samples[:, 0]==self.EOS
|
l2 = tokens == self.EOS
|
||||||
l2 = tokens==self.EOS
|
l = l1.logical_or(l2)
|
||||||
l = l1.logical_or(l2)
|
removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
|
||||||
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
|
reserved_idx_of_batch_for_y = torch.where(l == False)[0]
|
||||||
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
|
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
|
||||||
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
|
for i in removed_idx_of_batch_for_y:
|
||||||
for i in removed_idx_of_batch_for_y:
|
batch_index = batch_idx_map[i]
|
||||||
batch_index = batch_idx_map[i]
|
idx_list[batch_index] = idx
|
||||||
idx_list[batch_index] = idx - 1
|
y_list[batch_index] = y[i, :-1]
|
||||||
y_list[batch_index] = y[i, :-1]
|
|
||||||
|
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
|
||||||
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
|
|
||||||
|
# 只保留batch中未生成完毕的序列
|
||||||
# 只保留batch中未生成完毕的序列
|
|
||||||
if reserved_idx_of_batch_for_y is not None:
|
if reserved_idx_of_batch_for_y is not None:
|
||||||
# index = torch.LongTensor(batch_idx_map).to(y.device)
|
# index = torch.LongTensor(batch_idx_map).to(y.device)
|
||||||
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
|
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
|
||||||
xy_attn_mask = torch.index_select(xy_attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
||||||
if k_cache is not None :
|
if k_cache is not None:
|
||||||
for i in range(len(k_cache)):
|
for i in range(len(k_cache)):
|
||||||
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
||||||
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
||||||
|
|
||||||
|
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
|
||||||
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
|
|
||||||
print("use early stop num:", early_stop_num)
|
print("use early stop num:", early_stop_num)
|
||||||
stop = True
|
stop = True
|
||||||
for i, batch_index in enumerate(batch_idx_map):
|
for i, batch_index in enumerate(batch_idx_map):
|
||||||
batch_index = batch_idx_map[i]
|
batch_index = batch_idx_map[i]
|
||||||
idx_list[batch_index] = idx
|
idx_list[batch_index] = idx
|
||||||
y_list[batch_index] = y[i, :-1]
|
y_list[batch_index] = y[i, :-1]
|
||||||
|
|
||||||
if not (None in idx_list):
|
if None not in idx_list:
|
||||||
stop = True
|
stop = True
|
||||||
|
|
||||||
if stop:
|
if stop:
|
||||||
if y.shape[1]==0:
|
if y.shape[1] == 0:
|
||||||
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||||
print("bad zero prediction")
|
print("bad zero prediction")
|
||||||
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
||||||
@ -738,60 +764,65 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
|
|
||||||
####################### update next step ###################################
|
####################### update next step ###################################
|
||||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
|
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||||
|
:, y_len + idx
|
||||||
|
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||||
|
|
||||||
if (None in idx_list):
|
if None in idx_list:
|
||||||
for i in range(x.shape[0]):
|
for i in range(x.shape[0]):
|
||||||
if idx_list[i] is None:
|
if idx_list[i] is None:
|
||||||
idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替
|
idx_list[i] = 1500 - 1 ###如果没有生成到EOS,就用最大长度代替
|
||||||
|
|
||||||
if ref_free:
|
if ref_free:
|
||||||
return y_list, [0]*x.shape[0]
|
return y_list, [0] * x.shape[0]
|
||||||
# print(idx_list)
|
# print(idx_list)
|
||||||
return y_list, idx_list
|
return y_list, idx_list
|
||||||
|
|
||||||
def infer_panel_naive_batched(self,
|
def infer_panel_naive_batched(
|
||||||
x:List[torch.LongTensor], #####全部文本token
|
self,
|
||||||
x_lens:torch.LongTensor,
|
x: List[torch.LongTensor], #####全部文本token
|
||||||
prompts:torch.LongTensor, ####参考音频token
|
x_lens: torch.LongTensor,
|
||||||
bert_feature:List[torch.LongTensor],
|
prompts: torch.LongTensor, ####参考音频token
|
||||||
|
bert_feature: List[torch.LongTensor],
|
||||||
top_k: int = -100,
|
top_k: int = -100,
|
||||||
top_p: int = 100,
|
top_p: int = 100,
|
||||||
early_stop_num: int = -1,
|
early_stop_num: int = -1,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
repetition_penalty: float = 1.35,
|
repetition_penalty: float = 1.35,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
y_list = []
|
y_list = []
|
||||||
idx_list = []
|
idx_list = []
|
||||||
for i in range(len(x)):
|
for i in range(len(x)):
|
||||||
y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
|
y, idx = self.infer_panel_naive(
|
||||||
x_lens[i],
|
x[i].unsqueeze(0),
|
||||||
prompts[i].unsqueeze(0) if prompts is not None else None,
|
x_lens[i],
|
||||||
bert_feature[i].unsqueeze(0),
|
prompts[i].unsqueeze(0) if prompts is not None else None,
|
||||||
top_k,
|
bert_feature[i].unsqueeze(0),
|
||||||
top_p,
|
top_k,
|
||||||
early_stop_num,
|
top_p,
|
||||||
temperature,
|
early_stop_num,
|
||||||
repetition_penalty,
|
temperature,
|
||||||
**kwargs)
|
repetition_penalty,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
y_list.append(y[0])
|
y_list.append(y[0])
|
||||||
idx_list.append(idx)
|
idx_list.append(idx)
|
||||||
|
|
||||||
return y_list, idx_list
|
return y_list, idx_list
|
||||||
|
|
||||||
def infer_panel_naive(
|
def infer_panel_naive(
|
||||||
self,
|
self,
|
||||||
x:torch.LongTensor, #####全部文本token
|
x: torch.LongTensor, #####全部文本token
|
||||||
x_lens:torch.LongTensor,
|
x_lens: torch.LongTensor,
|
||||||
prompts:torch.LongTensor, ####参考音频token
|
prompts: torch.LongTensor, ####参考音频token
|
||||||
bert_feature:torch.LongTensor,
|
bert_feature: torch.LongTensor,
|
||||||
top_k: int = -100,
|
top_k: int = -100,
|
||||||
top_p: int = 100,
|
top_p: int = 100,
|
||||||
early_stop_num: int = -1,
|
early_stop_num: int = -1,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
repetition_penalty: float = 1.35,
|
repetition_penalty: float = 1.35,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
x = self.ar_text_embedding(x)
|
x = self.ar_text_embedding(x)
|
||||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||||
@ -836,11 +867,13 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
(x_len, 0),
|
(x_len, 0),
|
||||||
value=False,
|
value=False,
|
||||||
)
|
)
|
||||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
|
xy_attn_mask = (
|
||||||
.unsqueeze(0)\
|
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||||
.expand(bsz*self.num_head, -1, -1)\
|
.unsqueeze(0)
|
||||||
.view(bsz, self.num_head, src_len, src_len)\
|
.expand(bsz * self.num_head, -1, -1)
|
||||||
.to(device=x.device, dtype=torch.bool)
|
.view(bsz, self.num_head, src_len, src_len)
|
||||||
|
.to(device=x.device, dtype=torch.bool)
|
||||||
|
)
|
||||||
|
|
||||||
for idx in tqdm(range(1500)):
|
for idx in tqdm(range(1500)):
|
||||||
if xy_attn_mask is not None:
|
if xy_attn_mask is not None:
|
||||||
@ -848,12 +881,11 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
else:
|
else:
|
||||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
|
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
|
||||||
|
|
||||||
logits = self.ar_predict_layer(
|
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||||
xy_dec[:, -1]
|
|
||||||
)
|
|
||||||
|
|
||||||
if idx == 0:
|
if idx == 0:
|
||||||
xy_attn_mask = None
|
xy_attn_mask = None
|
||||||
|
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
||||||
logits = logits[:, :-1]
|
logits = logits[:, :-1]
|
||||||
|
|
||||||
samples = sample(
|
samples = sample(
|
||||||
@ -877,24 +909,27 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
|
|
||||||
####################### update next step ###################################
|
####################### update next step ###################################
|
||||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
|
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||||
|
:, y_len + idx
|
||||||
|
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||||
|
|
||||||
if ref_free:
|
if ref_free:
|
||||||
return y[:, :-1], 0
|
return y[:, :-1], 0
|
||||||
return y[:, :-1], idx - 1
|
return y[:, :-1], idx
|
||||||
|
|
||||||
|
|
||||||
def infer_panel(
|
def infer_panel(
|
||||||
self,
|
self,
|
||||||
x:torch.LongTensor, #####全部文本token
|
x: torch.LongTensor, #####全部文本token
|
||||||
x_lens:torch.LongTensor,
|
x_lens: torch.LongTensor,
|
||||||
prompts:torch.LongTensor, ####参考音频token
|
prompts: torch.LongTensor, ####参考音频token
|
||||||
bert_feature:torch.LongTensor,
|
bert_feature: torch.LongTensor,
|
||||||
top_k: int = -100,
|
top_k: int = -100,
|
||||||
top_p: int = 100,
|
top_p: int = 100,
|
||||||
early_stop_num: int = -1,
|
early_stop_num: int = -1,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
repetition_penalty: float = 1.35,
|
repetition_penalty: float = 1.35,
|
||||||
**kwargs
|
**kwargs,
|
||||||
):
|
):
|
||||||
return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
|
return self.infer_panel_naive(
|
||||||
|
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
|
||||||
|
)
|
||||||
|
@ -1,17 +1,13 @@
|
|||||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
||||||
# reference: https://github.com/lifeiteng/vall-e
|
# reference: https://github.com/lifeiteng/vall-e
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from AR.modules.embedding_onnx import SinePositionalEmbedding
|
|
||||||
from AR.modules.embedding_onnx import TokenEmbedding
|
|
||||||
from AR.modules.transformer_onnx import LayerNorm
|
|
||||||
from AR.modules.transformer_onnx import TransformerEncoder
|
|
||||||
from AR.modules.transformer_onnx import TransformerEncoderLayer
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torchmetrics.classification import MulticlassAccuracy
|
from torchmetrics.classification import MulticlassAccuracy
|
||||||
|
|
||||||
|
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
|
||||||
|
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
||||||
|
|
||||||
default_config = {
|
default_config = {
|
||||||
"embedding_dim": 512,
|
"embedding_dim": 512,
|
||||||
"hidden_dim": 512,
|
"hidden_dim": 512,
|
||||||
@ -26,12 +22,13 @@ default_config = {
|
|||||||
|
|
||||||
inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
|
inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
|
||||||
|
|
||||||
|
|
||||||
def logits_to_probs(
|
def logits_to_probs(
|
||||||
logits,
|
logits,
|
||||||
previous_tokens = None,
|
previous_tokens=None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_k = None,
|
top_k=None,
|
||||||
top_p = None,
|
top_p=None,
|
||||||
repetition_penalty: float = 1.0,
|
repetition_penalty: float = 1.0,
|
||||||
):
|
):
|
||||||
previous_tokens = previous_tokens.squeeze()
|
previous_tokens = previous_tokens.squeeze()
|
||||||
@ -39,19 +36,27 @@ def logits_to_probs(
|
|||||||
previous_tokens = previous_tokens.long()
|
previous_tokens = previous_tokens.long()
|
||||||
score = torch.gather(logits, dim=0, index=previous_tokens)
|
score = torch.gather(logits, dim=0, index=previous_tokens)
|
||||||
score = torch.where(
|
score = torch.where(
|
||||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
score < 0,
|
||||||
|
score * repetition_penalty,
|
||||||
|
score / repetition_penalty,
|
||||||
)
|
)
|
||||||
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
||||||
|
|
||||||
if top_p is not None and top_p < 1.0:
|
if top_p is not None and top_p < 1.0:
|
||||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||||
cum_probs = torch.cumsum(
|
cum_probs = torch.cumsum(
|
||||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
torch.nn.functional.softmax(
|
||||||
|
sorted_logits,
|
||||||
|
dim=-1,
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
)
|
)
|
||||||
sorted_indices_to_remove = cum_probs > top_p
|
sorted_indices_to_remove = cum_probs > top_p
|
||||||
sorted_indices_to_remove[0] = False # keep at least one option
|
sorted_indices_to_remove[0] = False # keep at least one option
|
||||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||||
dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
dim=0,
|
||||||
|
index=sorted_indices,
|
||||||
|
src=sorted_indices_to_remove,
|
||||||
)
|
)
|
||||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||||
|
|
||||||
@ -67,7 +72,7 @@ def logits_to_probs(
|
|||||||
|
|
||||||
|
|
||||||
def multinomial_sample_one_no_sync(
|
def multinomial_sample_one_no_sync(
|
||||||
probs_sort
|
probs_sort,
|
||||||
): # Does multinomial sampling without a cuda synchronization
|
): # Does multinomial sampling without a cuda synchronization
|
||||||
q = torch.randn_like(probs_sort)
|
q = torch.randn_like(probs_sort)
|
||||||
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||||
@ -79,7 +84,9 @@ def sample(
|
|||||||
**sampling_kwargs,
|
**sampling_kwargs,
|
||||||
):
|
):
|
||||||
probs = logits_to_probs(
|
probs = logits_to_probs(
|
||||||
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
|
logits=logits,
|
||||||
|
previous_tokens=previous_tokens,
|
||||||
|
**sampling_kwargs,
|
||||||
)
|
)
|
||||||
idx_next = multinomial_sample_one_no_sync(probs)
|
idx_next = multinomial_sample_one_no_sync(probs)
|
||||||
return idx_next, probs
|
return idx_next, probs
|
||||||
@ -91,7 +98,7 @@ class OnnxEncoder(nn.Module):
|
|||||||
self.ar_text_embedding = ar_text_embedding
|
self.ar_text_embedding = ar_text_embedding
|
||||||
self.bert_proj = bert_proj
|
self.bert_proj = bert_proj
|
||||||
self.ar_text_position = ar_text_position
|
self.ar_text_position = ar_text_position
|
||||||
|
|
||||||
def forward(self, x, bert_feature):
|
def forward(self, x, bert_feature):
|
||||||
x = self.ar_text_embedding(x)
|
x = self.ar_text_embedding(x)
|
||||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||||
@ -99,8 +106,18 @@ class OnnxEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class T2SFirstStageDecoder(nn.Module):
|
class T2SFirstStageDecoder(nn.Module):
|
||||||
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
|
def __init__(
|
||||||
top_k, early_stop_num, num_layers):
|
self,
|
||||||
|
ar_audio_embedding,
|
||||||
|
ar_audio_position,
|
||||||
|
h,
|
||||||
|
ar_predict_layer,
|
||||||
|
loss_fct,
|
||||||
|
ar_accuracy_metric,
|
||||||
|
top_k,
|
||||||
|
early_stop_num,
|
||||||
|
num_layers,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.ar_audio_embedding = ar_audio_embedding
|
self.ar_audio_embedding = ar_audio_embedding
|
||||||
self.ar_audio_position = ar_audio_position
|
self.ar_audio_position = ar_audio_position
|
||||||
@ -111,11 +128,11 @@ class T2SFirstStageDecoder(nn.Module):
|
|||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.early_stop_num = early_stop_num
|
self.early_stop_num = early_stop_num
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
|
||||||
def forward(self, x, prompt):
|
def forward(self, x, prompt):
|
||||||
y = prompt
|
y = prompt
|
||||||
x_example = x[:,:,0] * 0.0
|
x_example = x[:, :, 0] * 0.0
|
||||||
#N, 1, 512
|
# N, 1, 512
|
||||||
cache = {
|
cache = {
|
||||||
"all_stage": self.num_layers,
|
"all_stage": self.num_layers,
|
||||||
"k": None,
|
"k": None,
|
||||||
@ -132,11 +149,15 @@ class T2SFirstStageDecoder(nn.Module):
|
|||||||
|
|
||||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||||
|
|
||||||
y_example = y_pos[:,:,0] * 0.0
|
y_example = y_pos[:, :, 0] * 0.0
|
||||||
x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool()
|
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
|
||||||
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
|
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
|
||||||
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
|
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
|
||||||
torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0
|
torch.ones_like(
|
||||||
|
y_example.transpose(0, 1),
|
||||||
|
dtype=torch.int64,
|
||||||
|
),
|
||||||
|
dim=0,
|
||||||
)
|
)
|
||||||
y_attn_mask = y_attn_mask > 0
|
y_attn_mask = y_attn_mask > 0
|
||||||
|
|
||||||
@ -145,10 +166,16 @@ class T2SFirstStageDecoder(nn.Module):
|
|||||||
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
|
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
|
||||||
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
|
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
|
||||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||||
cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
|
cache["k"] = (
|
||||||
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
|
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
|
||||||
cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
|
.unsqueeze(1)
|
||||||
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
|
.repeat(self.num_layers, 1, 1, 1)
|
||||||
|
)
|
||||||
|
cache["v"] = (
|
||||||
|
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
|
||||||
|
.unsqueeze(1)
|
||||||
|
.repeat(self.num_layers, 1, 1, 1)
|
||||||
|
)
|
||||||
|
|
||||||
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
|
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
|
||||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||||
@ -160,8 +187,18 @@ class T2SFirstStageDecoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class T2SStageDecoder(nn.Module):
|
class T2SStageDecoder(nn.Module):
|
||||||
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
|
def __init__(
|
||||||
top_k, early_stop_num, num_layers):
|
self,
|
||||||
|
ar_audio_embedding,
|
||||||
|
ar_audio_position,
|
||||||
|
h,
|
||||||
|
ar_predict_layer,
|
||||||
|
loss_fct,
|
||||||
|
ar_accuracy_metric,
|
||||||
|
top_k,
|
||||||
|
early_stop_num,
|
||||||
|
num_layers,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.ar_audio_embedding = ar_audio_embedding
|
self.ar_audio_embedding = ar_audio_embedding
|
||||||
self.ar_audio_position = ar_audio_position
|
self.ar_audio_position = ar_audio_position
|
||||||
@ -184,14 +221,18 @@ class T2SStageDecoder(nn.Module):
|
|||||||
}
|
}
|
||||||
|
|
||||||
y_emb = torch.cat(
|
y_emb = torch.cat(
|
||||||
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
|
[
|
||||||
|
cache["y_emb"],
|
||||||
|
self.ar_audio_embedding(y[:, -1:]),
|
||||||
|
],
|
||||||
|
1,
|
||||||
)
|
)
|
||||||
cache["y_emb"] = y_emb
|
cache["y_emb"] = y_emb
|
||||||
y_pos = self.ar_audio_position(y_emb)
|
y_pos = self.ar_audio_position(y_emb)
|
||||||
|
|
||||||
xy_pos = y_pos[:, -1:]
|
xy_pos = y_pos[:, -1:]
|
||||||
|
|
||||||
y_example = y_pos[:,:,0] * 0.0
|
y_example = y_pos[:, :, 0] * 0.0
|
||||||
|
|
||||||
xy_attn_mask = torch.cat([x_example, y_example], dim=1)
|
xy_attn_mask = torch.cat([x_example, y_example], dim=1)
|
||||||
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
|
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
|
||||||
@ -250,12 +291,28 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
|
|
||||||
def init_onnx(self):
|
def init_onnx(self):
|
||||||
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
|
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
|
||||||
self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
|
self.first_stage_decoder = T2SFirstStageDecoder(
|
||||||
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
|
self.ar_audio_embedding,
|
||||||
self.num_layers)
|
self.ar_audio_position,
|
||||||
self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
|
self.h,
|
||||||
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
|
self.ar_predict_layer,
|
||||||
self.num_layers)
|
self.loss_fct,
|
||||||
|
self.ar_accuracy_metric,
|
||||||
|
self.top_k,
|
||||||
|
self.early_stop_num,
|
||||||
|
self.num_layers,
|
||||||
|
)
|
||||||
|
self.stage_decoder = T2SStageDecoder(
|
||||||
|
self.ar_audio_embedding,
|
||||||
|
self.ar_audio_position,
|
||||||
|
self.h,
|
||||||
|
self.ar_predict_layer,
|
||||||
|
self.loss_fct,
|
||||||
|
self.ar_accuracy_metric,
|
||||||
|
self.top_k,
|
||||||
|
self.early_stop_num,
|
||||||
|
self.num_layers,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x, prompts, bert_feature):
|
def forward(self, x, prompts, bert_feature):
|
||||||
early_stop_num = self.early_stop_num
|
early_stop_num = self.early_stop_num
|
||||||
@ -286,7 +343,7 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
y = prompts
|
y = prompts
|
||||||
prefix_len = y.shape[1]
|
prefix_len = y.shape[1]
|
||||||
x_len = x.shape[1]
|
x_len = x.shape[1]
|
||||||
x_example = x[:,:,0] * 0.0
|
x_example = x[:, :, 0] * 0.0
|
||||||
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
|
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
|
||||||
x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
|
x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
|
||||||
|
|
||||||
@ -303,9 +360,7 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
if cache["first_infer"] == 1:
|
if cache["first_infer"] == 1:
|
||||||
y_emb = self.ar_audio_embedding(y)
|
y_emb = self.ar_audio_embedding(y)
|
||||||
else:
|
else:
|
||||||
y_emb = torch.cat(
|
y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
|
||||||
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
|
|
||||||
)
|
|
||||||
cache["y_emb"] = y_emb
|
cache["y_emb"] = y_emb
|
||||||
y_pos = self.ar_audio_position(y_emb)
|
y_pos = self.ar_audio_position(y_emb)
|
||||||
if cache["first_infer"] == 1:
|
if cache["first_infer"] == 1:
|
||||||
@ -317,7 +372,8 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
|
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
|
||||||
y_attn_mask = F.pad(
|
y_attn_mask = F.pad(
|
||||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||||
(x_len, 0), value=False
|
(x_len, 0),
|
||||||
|
value=False,
|
||||||
)
|
)
|
||||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||||
else:
|
else:
|
||||||
@ -335,4 +391,4 @@ class Text2SemanticDecoder(nn.Module):
|
|||||||
break
|
break
|
||||||
y = torch.concat([y, samples], dim=1)
|
y = torch.concat([y, samples], dim=1)
|
||||||
cache["first_infer"] = 0
|
cache["first_infer"] = 0
|
||||||
return y, idx
|
return y, idx
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
|
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
|
||||||
# reference: https://github.com/lifeiteng/vall-e
|
# reference: https://github.com/lifeiteng/vall-e
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
def sequence_mask(length, max_length=None):
|
def sequence_mask(length, max_length=None):
|
||||||
if max_length is None:
|
if max_length is None:
|
||||||
@ -39,9 +41,46 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
|||||||
return expaned_lengths >= lengths.unsqueeze(-1)
|
return expaned_lengths >= lengths.unsqueeze(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
lengths:
|
||||||
|
A 1-D tensor containing sentence lengths.
|
||||||
|
max_len:
|
||||||
|
The length of masks.
|
||||||
|
Returns:
|
||||||
|
Return a 2-D bool tensor, where masked positions
|
||||||
|
are filled with `True` and non-masked positions are
|
||||||
|
filled with `False`.
|
||||||
|
|
||||||
|
#>>> lengths = torch.tensor([1, 3, 2, 5])
|
||||||
|
#>>> make_pad_mask(lengths)
|
||||||
|
tensor(
|
||||||
|
[
|
||||||
|
[True, True, False],
|
||||||
|
[True, False, False],
|
||||||
|
[True, True, False],
|
||||||
|
...
|
||||||
|
]
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
assert lengths.ndim == 1, lengths.ndim
|
||||||
|
max_len = max(max_len, lengths.max())
|
||||||
|
n = lengths.size(0)
|
||||||
|
seq_range = torch.arange(0, max_len, device=lengths.device)
|
||||||
|
expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1)
|
||||||
|
expaned_lengths -= (max_len - lengths).unsqueeze(-1)
|
||||||
|
|
||||||
|
return expaned_lengths < 0
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
|
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
|
||||||
def top_k_top_p_filtering(
|
def top_k_top_p_filtering(
|
||||||
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
|
logits,
|
||||||
|
top_k=0,
|
||||||
|
top_p=1.0,
|
||||||
|
filter_value=-float("Inf"),
|
||||||
|
min_tokens_to_keep=1,
|
||||||
):
|
):
|
||||||
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||||
Args:
|
Args:
|
||||||
@ -72,9 +111,7 @@ def top_k_top_p_filtering(
|
|||||||
sorted_indices_to_remove[..., 0] = 0
|
sorted_indices_to_remove[..., 0] = 0
|
||||||
|
|
||||||
# scatter sorted tensors to original indexing
|
# scatter sorted tensors to original indexing
|
||||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||||
1, sorted_indices, sorted_indices_to_remove
|
|
||||||
)
|
|
||||||
logits[indices_to_remove] = filter_value
|
logits[indices_to_remove] = filter_value
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
@ -97,7 +134,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
|||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
def multinomial_sample_one_no_sync(
|
def multinomial_sample_one_no_sync(
|
||||||
@ -123,19 +160,21 @@ def logits_to_probs(
|
|||||||
previous_tokens = previous_tokens.long()
|
previous_tokens = previous_tokens.long()
|
||||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||||
score = torch.where(
|
score = torch.where(
|
||||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
score < 0,
|
||||||
|
score * repetition_penalty,
|
||||||
|
score / repetition_penalty,
|
||||||
)
|
)
|
||||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||||
|
|
||||||
if top_p is not None and top_p < 1.0:
|
if top_p is not None and top_p < 1.0:
|
||||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||||
cum_probs = torch.cumsum(
|
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
||||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
|
||||||
)
|
|
||||||
sorted_indices_to_remove = cum_probs > top_p
|
sorted_indices_to_remove = cum_probs > top_p
|
||||||
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
||||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
dim=1,
|
||||||
|
index=sorted_indices,
|
||||||
|
src=sorted_indices_to_remove,
|
||||||
)
|
)
|
||||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||||
|
|
||||||
@ -143,7 +182,7 @@ def logits_to_probs(
|
|||||||
|
|
||||||
if top_k is not None:
|
if top_k is not None:
|
||||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||||
pivot = v[: , -1].unsqueeze(-1)
|
pivot = v[:, -1].unsqueeze(-1)
|
||||||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||||
|
|
||||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||||
@ -155,18 +194,19 @@ def sample(
|
|||||||
previous_tokens: Optional[torch.Tensor] = None,
|
previous_tokens: Optional[torch.Tensor] = None,
|
||||||
**sampling_kwargs,
|
**sampling_kwargs,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
probs = logits_to_probs(
|
probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
|
||||||
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
|
|
||||||
)
|
|
||||||
idx_next = multinomial_sample_one_no_sync(probs)
|
idx_next = multinomial_sample_one_no_sync(probs)
|
||||||
return idx_next, probs
|
return idx_next, probs
|
||||||
|
|
||||||
def dpo_loss(policy_chosen_logps: torch.FloatTensor,
|
|
||||||
policy_rejected_logps: torch.FloatTensor,
|
def dpo_loss(
|
||||||
reference_chosen_logps: torch.FloatTensor,
|
policy_chosen_logps: torch.FloatTensor,
|
||||||
reference_rejected_logps: torch.FloatTensor,
|
policy_rejected_logps: torch.FloatTensor,
|
||||||
beta: float,
|
reference_chosen_logps: torch.FloatTensor,
|
||||||
reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
reference_rejected_logps: torch.FloatTensor,
|
||||||
|
beta: float,
|
||||||
|
reference_free: bool = False,
|
||||||
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||||
pi_logratios = policy_chosen_logps - policy_rejected_logps
|
pi_logratios = policy_chosen_logps - policy_rejected_logps
|
||||||
ref_logratios = reference_chosen_logps - reference_rejected_logps
|
ref_logratios = reference_chosen_logps - reference_rejected_logps
|
||||||
|
|
||||||
@ -181,40 +221,53 @@ def dpo_loss(policy_chosen_logps: torch.FloatTensor,
|
|||||||
|
|
||||||
return losses.mean(), chosen_rewards, rejected_rewards
|
return losses.mean(), chosen_rewards, rejected_rewards
|
||||||
|
|
||||||
def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
|
||||||
|
|
||||||
|
def get_batch_logps(
|
||||||
|
logits_target: torch.FloatTensor,
|
||||||
|
logits_reject: torch.FloatTensor,
|
||||||
|
labels_target: torch.LongTensor,
|
||||||
|
labels_reject: torch.LongTensor,
|
||||||
|
average_log_prob: bool = False,
|
||||||
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||||
# dummy token; we'll ignore the losses on these tokens later
|
# dummy token; we'll ignore the losses on these tokens later
|
||||||
|
|
||||||
per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
|
per_token_logps_target = torch.gather(
|
||||||
per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
|
logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)
|
||||||
|
).squeeze(2)
|
||||||
|
per_token_logps_reject = torch.gather(
|
||||||
|
logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)
|
||||||
|
).squeeze(2)
|
||||||
|
|
||||||
return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
|
return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
|
||||||
|
|
||||||
|
|
||||||
def make_reject_y(y_o, y_lens):
|
def make_reject_y(y_o, y_lens):
|
||||||
def repeat_P(y):
|
def repeat_P(y):
|
||||||
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
|
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
|
||||||
pre = y[:range_idx[0]]
|
pre = y[: range_idx[0]]
|
||||||
shf = y[range_idx[1]:]
|
shf = y[range_idx[1] :]
|
||||||
range_text = y[range_idx[0]:range_idx[1]]
|
range_text = y[range_idx[0] : range_idx[1]]
|
||||||
new_y = torch.cat([pre, range_text, range_text, shf])
|
new_y = torch.cat([pre, range_text, range_text, shf])
|
||||||
return new_y
|
return new_y
|
||||||
|
|
||||||
def lost_P(y):
|
def lost_P(y):
|
||||||
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
|
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
|
||||||
pre = y[:range_idx[0]]
|
pre = y[: range_idx[0]]
|
||||||
shf = y[range_idx[1]:]
|
shf = y[range_idx[1] :]
|
||||||
range_text = y[range_idx[0]:range_idx[1]]
|
range_text = y[range_idx[0] : range_idx[1]]
|
||||||
new_y = torch.cat([pre, shf])
|
new_y = torch.cat([pre, shf])
|
||||||
return new_y
|
return new_y
|
||||||
|
|
||||||
bs = len(y_lens)
|
bs = len(y_lens)
|
||||||
reject_y = []
|
reject_y = []
|
||||||
reject_y_lens = []
|
reject_y_lens = []
|
||||||
for b in range(bs):
|
for b in range(bs):
|
||||||
process_item_idx = torch.randint(0, 1, size=(1, ))[0]
|
process_item_idx = torch.randint(0, 1, size=(1,))[0]
|
||||||
if process_item_idx == 0:
|
if process_item_idx == 0:
|
||||||
new_y = repeat_P(y_o[b])
|
new_y = repeat_P(y_o[b])
|
||||||
reject_y.append(new_y)
|
reject_y.append(new_y)
|
||||||
reject_y_lens.append(len(new_y))
|
reject_y_lens.append(len(new_y))
|
||||||
elif process_item_idx==1:
|
elif process_item_idx == 1:
|
||||||
new_y = lost_P(y_o[b])
|
new_y = lost_P(y_o[b])
|
||||||
reject_y.append(new_y)
|
reject_y.append(new_y)
|
||||||
reject_y_lens.append(len(new_y))
|
reject_y_lens.append(len(new_y))
|
||||||
@ -223,7 +276,7 @@ def make_reject_y(y_o, y_lens):
|
|||||||
pad_length = max_length - reject_y_lens[b]
|
pad_length = max_length - reject_y_lens[b]
|
||||||
reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
|
reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
|
||||||
|
|
||||||
reject_y = torch.stack(reject_y, dim = 0)
|
reject_y = torch.stack(reject_y, dim=0)
|
||||||
reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
|
reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
|
||||||
|
|
||||||
return reject_y, reject_y_lens
|
return reject_y, reject_y_lens
|
||||||
|
@ -1,17 +1,14 @@
|
|||||||
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
|
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
from typing import Tuple
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn import Linear
|
from torch.nn import Linear, Module
|
||||||
from torch.nn import Module
|
from torch.nn import functional as F
|
||||||
from torch.nn.init import constant_
|
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
||||||
from torch.nn.init import xavier_normal_
|
|
||||||
from torch.nn.init import xavier_uniform_
|
|
||||||
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from torch.nn import functional as F
|
|
||||||
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
|
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
|
||||||
|
|
||||||
F.multi_head_attention_forward = multi_head_attention_forward_patched
|
F.multi_head_attention_forward = multi_head_attention_forward_patched
|
||||||
@ -73,6 +70,7 @@ class MultiheadAttention(Module):
|
|||||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__constants__ = ["batch_first"]
|
__constants__ = ["batch_first"]
|
||||||
bias_k: Optional[torch.Tensor]
|
bias_k: Optional[torch.Tensor]
|
||||||
bias_v: Optional[torch.Tensor]
|
bias_v: Optional[torch.Tensor]
|
||||||
@ -104,9 +102,7 @@ class MultiheadAttention(Module):
|
|||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.batch_first = batch_first
|
self.batch_first = batch_first
|
||||||
self.head_dim = embed_dim // num_heads
|
self.head_dim = embed_dim // num_heads
|
||||||
assert (
|
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||||
self.head_dim * num_heads == self.embed_dim
|
|
||||||
), "embed_dim must be divisible by num_heads"
|
|
||||||
|
|
||||||
if add_bias_kv:
|
if add_bias_kv:
|
||||||
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||||
@ -117,31 +113,32 @@ class MultiheadAttention(Module):
|
|||||||
if linear1_cls == Linear:
|
if linear1_cls == Linear:
|
||||||
if not self._qkv_same_embed_dim:
|
if not self._qkv_same_embed_dim:
|
||||||
self.q_proj_weight = Parameter(
|
self.q_proj_weight = Parameter(
|
||||||
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
torch.empty((embed_dim, embed_dim), **factory_kwargs),
|
||||||
)
|
)
|
||||||
self.k_proj_weight = Parameter(
|
self.k_proj_weight = Parameter(
|
||||||
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
torch.empty((embed_dim, self.kdim), **factory_kwargs),
|
||||||
)
|
)
|
||||||
self.v_proj_weight = Parameter(
|
self.v_proj_weight = Parameter(
|
||||||
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
torch.empty((embed_dim, self.vdim), **factory_kwargs),
|
||||||
)
|
)
|
||||||
self.register_parameter("in_proj_weight", None)
|
self.register_parameter("in_proj_weight", None)
|
||||||
else:
|
else:
|
||||||
self.in_proj_weight = Parameter(
|
self.in_proj_weight = Parameter(
|
||||||
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
|
||||||
)
|
)
|
||||||
self.register_parameter("q_proj_weight", None)
|
self.register_parameter("q_proj_weight", None)
|
||||||
self.register_parameter("k_proj_weight", None)
|
self.register_parameter("k_proj_weight", None)
|
||||||
self.register_parameter("v_proj_weight", None)
|
self.register_parameter("v_proj_weight", None)
|
||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
self.in_proj_bias = Parameter(
|
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
|
||||||
torch.empty(3 * embed_dim, **factory_kwargs)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.register_parameter("in_proj_bias", None)
|
self.register_parameter("in_proj_bias", None)
|
||||||
self.out_proj = NonDynamicallyQuantizableLinear(
|
self.out_proj = NonDynamicallyQuantizableLinear(
|
||||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
embed_dim,
|
||||||
|
embed_dim,
|
||||||
|
bias=bias,
|
||||||
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._reset_parameters()
|
self._reset_parameters()
|
||||||
@ -150,7 +147,10 @@ class MultiheadAttention(Module):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
else:
|
else:
|
||||||
self.in_proj_linear = linear1_cls(
|
self.in_proj_linear = linear1_cls(
|
||||||
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
|
embed_dim,
|
||||||
|
3 * embed_dim,
|
||||||
|
bias=bias,
|
||||||
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
self.in_proj_weight = self.in_proj_linear.weight
|
self.in_proj_weight = self.in_proj_linear.weight
|
||||||
|
|
||||||
@ -164,7 +164,10 @@ class MultiheadAttention(Module):
|
|||||||
self.register_parameter("in_proj_bias", None)
|
self.register_parameter("in_proj_bias", None)
|
||||||
|
|
||||||
self.out_proj = linear2_cls(
|
self.out_proj = linear2_cls(
|
||||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
embed_dim,
|
||||||
|
embed_dim,
|
||||||
|
bias=bias,
|
||||||
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.bias_k is not None:
|
if self.bias_k is not None:
|
||||||
@ -261,28 +264,26 @@ class MultiheadAttention(Module):
|
|||||||
if key_padding_mask is not None:
|
if key_padding_mask is not None:
|
||||||
_kpm_dtype = key_padding_mask.dtype
|
_kpm_dtype = key_padding_mask.dtype
|
||||||
if _kpm_dtype != torch.bool and not torch.is_floating_point(
|
if _kpm_dtype != torch.bool and not torch.is_floating_point(
|
||||||
key_padding_mask
|
key_padding_mask,
|
||||||
):
|
):
|
||||||
raise AssertionError(
|
raise AssertionError("only bool and floating types of key_padding_mask are supported")
|
||||||
"only bool and floating types of key_padding_mask are supported"
|
|
||||||
)
|
|
||||||
why_not_fast_path = ""
|
why_not_fast_path = ""
|
||||||
if not is_batched:
|
if not is_batched:
|
||||||
why_not_fast_path = (
|
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
||||||
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
|
||||||
)
|
|
||||||
elif query is not key or key is not value:
|
elif query is not key or key is not value:
|
||||||
# When lifting this restriction, don't forget to either
|
# When lifting this restriction, don't forget to either
|
||||||
# enforce that the dtypes all match or test cases where
|
# enforce that the dtypes all match or test cases where
|
||||||
# they don't!
|
# they don't!
|
||||||
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
||||||
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
|
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
|
||||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
why_not_fast_path = (
|
||||||
elif (
|
f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
||||||
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
|
)
|
||||||
):
|
elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
|
||||||
# this case will fail anyway, but at least they'll get a useful error message.
|
# this case will fail anyway, but at least they'll get a useful error message.
|
||||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
why_not_fast_path = (
|
||||||
|
f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
||||||
|
)
|
||||||
elif self.training:
|
elif self.training:
|
||||||
why_not_fast_path = "training is enabled"
|
why_not_fast_path = "training is enabled"
|
||||||
elif not self.batch_first:
|
elif not self.batch_first:
|
||||||
@ -300,9 +301,7 @@ class MultiheadAttention(Module):
|
|||||||
elif attn_mask is not None:
|
elif attn_mask is not None:
|
||||||
why_not_fast_path = "attn_mask was not None"
|
why_not_fast_path = "attn_mask was not None"
|
||||||
elif query.is_nested and key_padding_mask is not None:
|
elif query.is_nested and key_padding_mask is not None:
|
||||||
why_not_fast_path = (
|
why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
|
||||||
"key_padding_mask is not supported with NestedTensor input"
|
|
||||||
)
|
|
||||||
elif self.num_heads % 2 == 1:
|
elif self.num_heads % 2 == 1:
|
||||||
why_not_fast_path = "num_heads is odd"
|
why_not_fast_path = "num_heads is odd"
|
||||||
elif torch.is_autocast_enabled():
|
elif torch.is_autocast_enabled():
|
||||||
@ -322,20 +321,10 @@ class MultiheadAttention(Module):
|
|||||||
# generator expressions.
|
# generator expressions.
|
||||||
if torch.overrides.has_torch_function(tensor_args):
|
if torch.overrides.has_torch_function(tensor_args):
|
||||||
why_not_fast_path = "some Tensor argument has_torch_function"
|
why_not_fast_path = "some Tensor argument has_torch_function"
|
||||||
elif not all(
|
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
|
||||||
[
|
|
||||||
(x is None or x.is_cuda or "cpu" in str(x.device))
|
|
||||||
for x in tensor_args
|
|
||||||
]
|
|
||||||
):
|
|
||||||
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
|
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
|
||||||
elif torch.is_grad_enabled() and any(
|
elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
|
||||||
[x is not None and x.requires_grad for x in tensor_args]
|
why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
|
||||||
):
|
|
||||||
why_not_fast_path = (
|
|
||||||
"grad is enabled and at least one of query or the "
|
|
||||||
"input/output projection weights or biases requires_grad"
|
|
||||||
)
|
|
||||||
if not why_not_fast_path:
|
if not why_not_fast_path:
|
||||||
return torch._native_multi_head_attention(
|
return torch._native_multi_head_attention(
|
||||||
query,
|
query,
|
||||||
@ -350,11 +339,7 @@ class MultiheadAttention(Module):
|
|||||||
key_padding_mask if key_padding_mask is not None else attn_mask,
|
key_padding_mask if key_padding_mask is not None else attn_mask,
|
||||||
need_weights,
|
need_weights,
|
||||||
average_attn_weights,
|
average_attn_weights,
|
||||||
1
|
1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
|
||||||
if key_padding_mask is not None
|
|
||||||
else 0
|
|
||||||
if attn_mask is not None
|
|
||||||
else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
any_nested = query.is_nested or key.is_nested or value.is_nested
|
any_nested = query.is_nested or key.is_nested or value.is_nested
|
||||||
|
@ -1,17 +1,13 @@
|
|||||||
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
|
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
from typing import Tuple
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn import Linear
|
from torch.nn import Linear, Module
|
||||||
from torch.nn import Module
|
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
||||||
from torch.nn.init import constant_
|
|
||||||
from torch.nn.init import xavier_normal_
|
|
||||||
from torch.nn.init import xavier_uniform_
|
|
||||||
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from torch.nn import functional as F
|
|
||||||
from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
|
from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
|
||||||
|
|
||||||
|
|
||||||
@ -47,9 +43,7 @@ class MultiheadAttention(Module):
|
|||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.batch_first = batch_first
|
self.batch_first = batch_first
|
||||||
self.head_dim = embed_dim // num_heads
|
self.head_dim = embed_dim // num_heads
|
||||||
assert (
|
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||||
self.head_dim * num_heads == self.embed_dim
|
|
||||||
), "embed_dim must be divisible by num_heads"
|
|
||||||
|
|
||||||
if add_bias_kv:
|
if add_bias_kv:
|
||||||
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||||
@ -60,18 +54,30 @@ class MultiheadAttention(Module):
|
|||||||
if linear1_cls == Linear:
|
if linear1_cls == Linear:
|
||||||
if not self._qkv_same_embed_dim:
|
if not self._qkv_same_embed_dim:
|
||||||
self.q_proj_weight = Parameter(
|
self.q_proj_weight = Parameter(
|
||||||
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
torch.empty(
|
||||||
|
(embed_dim, embed_dim),
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.k_proj_weight = Parameter(
|
self.k_proj_weight = Parameter(
|
||||||
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
torch.empty(
|
||||||
|
(embed_dim, self.kdim),
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.v_proj_weight = Parameter(
|
self.v_proj_weight = Parameter(
|
||||||
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
torch.empty(
|
||||||
|
(embed_dim, self.vdim),
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.register_parameter("in_proj_weight", None)
|
self.register_parameter("in_proj_weight", None)
|
||||||
else:
|
else:
|
||||||
self.in_proj_weight = Parameter(
|
self.in_proj_weight = Parameter(
|
||||||
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
torch.empty(
|
||||||
|
(3 * embed_dim, embed_dim),
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.register_parameter("q_proj_weight", None)
|
self.register_parameter("q_proj_weight", None)
|
||||||
self.register_parameter("k_proj_weight", None)
|
self.register_parameter("k_proj_weight", None)
|
||||||
@ -79,13 +85,11 @@ class MultiheadAttention(Module):
|
|||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
self.in_proj_bias = Parameter(
|
self.in_proj_bias = Parameter(
|
||||||
torch.empty(3 * embed_dim, **factory_kwargs)
|
torch.empty(3 * embed_dim, **factory_kwargs),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.register_parameter("in_proj_bias", None)
|
self.register_parameter("in_proj_bias", None)
|
||||||
self.out_proj = NonDynamicallyQuantizableLinear(
|
self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
|
||||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
self._reset_parameters()
|
self._reset_parameters()
|
||||||
else:
|
else:
|
||||||
@ -93,7 +97,10 @@ class MultiheadAttention(Module):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
else:
|
else:
|
||||||
self.in_proj_linear = linear1_cls(
|
self.in_proj_linear = linear1_cls(
|
||||||
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
|
embed_dim,
|
||||||
|
3 * embed_dim,
|
||||||
|
bias=bias,
|
||||||
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
self.in_proj_weight = self.in_proj_linear.weight
|
self.in_proj_weight = self.in_proj_linear.weight
|
||||||
|
|
||||||
@ -107,7 +114,10 @@ class MultiheadAttention(Module):
|
|||||||
self.register_parameter("in_proj_bias", None)
|
self.register_parameter("in_proj_bias", None)
|
||||||
|
|
||||||
self.out_proj = linear2_cls(
|
self.out_proj = linear2_cls(
|
||||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
embed_dim,
|
||||||
|
embed_dim,
|
||||||
|
bias=bias,
|
||||||
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.bias_k is not None:
|
if self.bias_k is not None:
|
||||||
|
@ -60,14 +60,11 @@ class SinePositionalEmbedding(nn.Module):
|
|||||||
return
|
return
|
||||||
pe = torch.zeros(x.size(1), self.embedding_dim)
|
pe = torch.zeros(x.size(1), self.embedding_dim)
|
||||||
if self.reverse:
|
if self.reverse:
|
||||||
position = torch.arange(
|
position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
|
||||||
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
|
||||||
).unsqueeze(1)
|
|
||||||
else:
|
else:
|
||||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||||
div_term = torch.exp(
|
div_term = torch.exp(
|
||||||
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
|
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
|
||||||
* -(math.log(10000.0) / self.embedding_dim)
|
|
||||||
)
|
)
|
||||||
pe[:, 0::2] = torch.sin(position * div_term)
|
pe[:, 0::2] = torch.sin(position * div_term)
|
||||||
pe[:, 1::2] = torch.cos(position * div_term)
|
pe[:, 1::2] = torch.cos(position * div_term)
|
||||||
|
@ -50,7 +50,7 @@ class SinePositionalEmbedding(nn.Module):
|
|||||||
self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
|
self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
|
||||||
|
|
||||||
def extend_pe(self, x):
|
def extend_pe(self, x):
|
||||||
position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1)
|
position = torch.cumsum(torch.ones_like(x[:, :, 0]), dim=1).transpose(0, 1)
|
||||||
scpe = (position * self.div_term).unsqueeze(0)
|
scpe = (position * self.div_term).unsqueeze(0)
|
||||||
pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
|
pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
|
||||||
pe = pe.contiguous().view(1, -1, self.embedding_dim)
|
pe = pe.contiguous().view(1, -1, self.embedding_dim)
|
||||||
|
@ -49,13 +49,9 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
|
|||||||
lr = self.end_lr
|
lr = self.end_lr
|
||||||
|
|
||||||
else:
|
else:
|
||||||
decay_ratio = (self._current_step - self.warmup_steps) / (
|
decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
|
||||||
self.total_steps - self.warmup_steps
|
|
||||||
)
|
|
||||||
if decay_ratio < 0.0 or decay_ratio > 1.0:
|
if decay_ratio < 0.0 or decay_ratio > 1.0:
|
||||||
raise RuntimeError(
|
raise RuntimeError("Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.")
|
||||||
"Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
|
|
||||||
)
|
|
||||||
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
||||||
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
|
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
|
||||||
|
|
||||||
@ -70,7 +66,13 @@ if __name__ == "__main__":
|
|||||||
m = nn.Linear(10, 10)
|
m = nn.Linear(10, 10)
|
||||||
opt = Adam(m.parameters(), lr=1e-4)
|
opt = Adam(m.parameters(), lr=1e-4)
|
||||||
s = WarmupCosineLRSchedule(
|
s = WarmupCosineLRSchedule(
|
||||||
opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
|
opt,
|
||||||
|
1e-6,
|
||||||
|
2e-4,
|
||||||
|
1e-6,
|
||||||
|
warmup_steps=2000,
|
||||||
|
total_steps=20000,
|
||||||
|
current_step=0,
|
||||||
)
|
)
|
||||||
lrs = []
|
lrs = []
|
||||||
for i in range(25000):
|
for i in range(25000):
|
||||||
|
@ -16,8 +16,7 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import List
|
from typing import List, Tuple
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
@ -71,12 +70,8 @@ class BatchedOptimizer(Optimizer):
|
|||||||
group_params_names: name for each parameter in group,
|
group_params_names: name for each parameter in group,
|
||||||
which is List[str].
|
which is List[str].
|
||||||
"""
|
"""
|
||||||
batches = defaultdict(
|
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
|
||||||
list
|
batches_names = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
|
||||||
) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
|
|
||||||
batches_names = defaultdict(
|
|
||||||
list
|
|
||||||
) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
|
|
||||||
|
|
||||||
assert len(param_group) == len(group_params_names)
|
assert len(param_group) == len(group_params_names)
|
||||||
for p, named_p in zip(param_group, group_params_names):
|
for p, named_p in zip(param_group, group_params_names):
|
||||||
@ -85,11 +80,8 @@ class BatchedOptimizer(Optimizer):
|
|||||||
batches_names[key].append(named_p)
|
batches_names[key].append(named_p)
|
||||||
|
|
||||||
batches_names_keys = list(batches_names.keys())
|
batches_names_keys = list(batches_names.keys())
|
||||||
sorted_idx = sorted(
|
sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
|
||||||
range(len(batches_names)), key=lambda i: batches_names_keys[i])
|
batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
|
||||||
batches_names = [
|
|
||||||
batches_names[batches_names_keys[idx]] for idx in sorted_idx
|
|
||||||
]
|
|
||||||
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
|
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
|
||||||
|
|
||||||
stacked_params_dict = dict()
|
stacked_params_dict = dict()
|
||||||
@ -106,16 +98,14 @@ class BatchedOptimizer(Optimizer):
|
|||||||
# group. class Optimizer will take care of saving/loading state.
|
# group. class Optimizer will take care of saving/loading state.
|
||||||
state = self.state[p]
|
state = self.state[p]
|
||||||
p_stacked = torch.stack(batch)
|
p_stacked = torch.stack(batch)
|
||||||
grad = torch.stack([
|
grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
|
||||||
torch.zeros_like(p) if p.grad is None else p.grad for p in batch
|
|
||||||
])
|
|
||||||
p_stacked.grad = grad
|
p_stacked.grad = grad
|
||||||
stacked_params_dict[key] = p_stacked
|
stacked_params_dict[key] = p_stacked
|
||||||
tuples.append((p_stacked, state, batch_names))
|
tuples.append((p_stacked, state, batch_names))
|
||||||
|
|
||||||
yield tuples # <-- calling code will do the actual optimization here!
|
yield tuples # <-- calling code will do the actual optimization here!
|
||||||
|
|
||||||
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
|
for (stacked_params, _state, _names), batch in zip(tuples, batches):
|
||||||
for i, p in enumerate(batch): # batch is list of Parameter
|
for i, p in enumerate(batch): # batch is list of Parameter
|
||||||
p.copy_(stacked_params[i])
|
p.copy_(stacked_params[i])
|
||||||
|
|
||||||
@ -164,25 +154,24 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params,
|
params,
|
||||||
lr=3e-02,
|
lr=3e-02,
|
||||||
clipping_scale=None,
|
clipping_scale=None,
|
||||||
betas=(0.9, 0.98),
|
betas=(0.9, 0.98),
|
||||||
scalar_lr_scale=0.1,
|
scalar_lr_scale=0.1,
|
||||||
eps=1.0e-08,
|
eps=1.0e-08,
|
||||||
param_min_rms=1.0e-05,
|
param_min_rms=1.0e-05,
|
||||||
param_max_rms=3.0,
|
param_max_rms=3.0,
|
||||||
scalar_max=10.0,
|
scalar_max=10.0,
|
||||||
size_update_period=4,
|
size_update_period=4,
|
||||||
clipping_update_period=100,
|
clipping_update_period=100,
|
||||||
parameters_names=None,
|
parameters_names=None,
|
||||||
show_dominant_parameters=True, ):
|
show_dominant_parameters=True,
|
||||||
|
):
|
||||||
assert parameters_names is not None, (
|
assert parameters_names is not None, (
|
||||||
"Please prepare parameters_names,"
|
"Please prepare parameters_names,which is a List[List[str]]. Each List[str] is for a groupand each str is for a parameter"
|
||||||
"which is a List[List[str]]. Each List[str] is for a group"
|
)
|
||||||
"and each str is for a parameter")
|
|
||||||
defaults = dict(
|
defaults = dict(
|
||||||
lr=lr,
|
lr=lr,
|
||||||
clipping_scale=clipping_scale,
|
clipping_scale=clipping_scale,
|
||||||
@ -193,7 +182,8 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
param_max_rms=param_max_rms,
|
param_max_rms=param_max_rms,
|
||||||
scalar_max=scalar_max,
|
scalar_max=scalar_max,
|
||||||
size_update_period=size_update_period,
|
size_update_period=size_update_period,
|
||||||
clipping_update_period=clipping_update_period, )
|
clipping_update_period=clipping_update_period,
|
||||||
|
)
|
||||||
|
|
||||||
super(ScaledAdam, self).__init__(params, defaults)
|
super(ScaledAdam, self).__init__(params, defaults)
|
||||||
assert len(self.param_groups) == len(parameters_names)
|
assert len(self.param_groups) == len(parameters_names)
|
||||||
@ -218,18 +208,13 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
|
|
||||||
batch = True
|
batch = True
|
||||||
|
|
||||||
for group, group_params_names in zip(self.param_groups,
|
for group, group_params_names in zip(self.param_groups, self.parameters_names):
|
||||||
self.parameters_names):
|
with self.batched_params(group["params"], group_params_names) as batches:
|
||||||
|
|
||||||
with self.batched_params(group["params"],
|
|
||||||
group_params_names) as batches:
|
|
||||||
|
|
||||||
# batches is list of pairs (stacked_param, state). stacked_param is like
|
# batches is list of pairs (stacked_param, state). stacked_param is like
|
||||||
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
|
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
|
||||||
# a stacking dim, it is not a real dim.
|
# a stacking dim, it is not a real dim.
|
||||||
|
|
||||||
if (len(batches[0][1]) ==
|
if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
|
||||||
0): # if len(first state) == 0: not yet initialized
|
|
||||||
clipping_scale = 1
|
clipping_scale = 1
|
||||||
else:
|
else:
|
||||||
clipping_scale = self._get_clipping_scale(group, batches)
|
clipping_scale = self._get_clipping_scale(group, batches)
|
||||||
@ -239,9 +224,7 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
# grad is not going to be None, we handled that when creating the batches.
|
# grad is not going to be None, we handled that when creating the batches.
|
||||||
grad = p.grad
|
grad = p.grad
|
||||||
if grad.is_sparse:
|
if grad.is_sparse:
|
||||||
raise RuntimeError(
|
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
|
||||||
"ScaledAdam optimizer does not support sparse gradients"
|
|
||||||
)
|
|
||||||
# State initialization
|
# State initialization
|
||||||
if len(state) == 0:
|
if len(state) == 0:
|
||||||
self._init_state(group, p, state)
|
self._init_state(group, p, state)
|
||||||
@ -274,8 +257,7 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
# parameter-change "delta", which combines all forms of
|
# parameter-change "delta", which combines all forms of
|
||||||
# update. this is equivalent to how it's done in Adam,
|
# update. this is equivalent to how it's done in Adam,
|
||||||
# except for the first few steps.
|
# except for the first few steps.
|
||||||
state["delta"] = torch.zeros_like(
|
state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||||
p, memory_format=torch.preserve_format)
|
|
||||||
|
|
||||||
batch_size = p.shape[0]
|
batch_size = p.shape[0]
|
||||||
numel = p.numel() // batch_size
|
numel = p.numel() // batch_size
|
||||||
@ -285,22 +267,16 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
# "param_rms" just periodically records the scalar root-mean-square value of
|
# "param_rms" just periodically records the scalar root-mean-square value of
|
||||||
# the parameter tensor.
|
# the parameter tensor.
|
||||||
# it has a shape like (batch_size, 1, 1, 1, 1)
|
# it has a shape like (batch_size, 1, 1, 1, 1)
|
||||||
param_rms = (
|
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
|
||||||
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
|
|
||||||
state["param_rms"] = param_rms
|
state["param_rms"] = param_rms
|
||||||
|
|
||||||
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
|
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
|
||||||
state["scale_grads"] = torch.zeros(size_update_period,
|
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs)
|
||||||
*param_rms.shape, **kwargs)
|
|
||||||
|
|
||||||
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
||||||
state["exp_avg_sq"] = torch.zeros_like(
|
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||||
p, memory_format=torch.preserve_format)
|
|
||||||
|
|
||||||
def _get_clipping_scale(self,
|
def _get_clipping_scale(self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]) -> float:
|
||||||
group: dict,
|
|
||||||
tuples: List[Tuple[Tensor, dict, List[str]]]
|
|
||||||
) -> float:
|
|
||||||
"""
|
"""
|
||||||
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
|
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
|
||||||
by this amount before applying the rest of the update.
|
by this amount before applying the rest of the update.
|
||||||
@ -325,20 +301,18 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
clipping_update_period = group["clipping_update_period"]
|
clipping_update_period = group["clipping_update_period"]
|
||||||
|
|
||||||
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
||||||
for (p, state, param_names) in tuples:
|
for p, state, param_names in tuples:
|
||||||
grad = p.grad
|
grad = p.grad
|
||||||
if grad.is_sparse:
|
if grad.is_sparse:
|
||||||
raise RuntimeError(
|
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
|
||||||
"ScaledAdam optimizer does not support sparse gradients")
|
|
||||||
if p.numel() == p.shape[0]: # a batch of scalars
|
if p.numel() == p.shape[0]: # a batch of scalars
|
||||||
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
|
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
|
||||||
else:
|
else:
|
||||||
tot_sumsq += ((grad * state["param_rms"])**2).sum()
|
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
|
||||||
|
|
||||||
tot_norm = tot_sumsq.sqrt()
|
tot_norm = tot_sumsq.sqrt()
|
||||||
if "model_norms" not in first_state:
|
if "model_norms" not in first_state:
|
||||||
first_state["model_norms"] = torch.zeros(
|
first_state["model_norms"] = torch.zeros(clipping_update_period, device=p.device)
|
||||||
clipping_update_period, device=p.device)
|
|
||||||
first_state["model_norms"][step % clipping_update_period] = tot_norm
|
first_state["model_norms"][step % clipping_update_period] = tot_norm
|
||||||
|
|
||||||
if step % clipping_update_period == 0:
|
if step % clipping_update_period == 0:
|
||||||
@ -350,20 +324,20 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
for n in range(0, 5):
|
for n in range(0, 5):
|
||||||
index = min(
|
index = min(
|
||||||
clipping_update_period - 1,
|
clipping_update_period - 1,
|
||||||
(clipping_update_period // 4) * n, )
|
(clipping_update_period // 4) * n,
|
||||||
|
)
|
||||||
quartiles.append(sorted_norms[index].item())
|
quartiles.append(sorted_norms[index].item())
|
||||||
|
|
||||||
median = quartiles[2]
|
median = quartiles[2]
|
||||||
threshold = clipping_scale * median
|
threshold = clipping_scale * median
|
||||||
first_state["model_norm_threshold"] = threshold
|
first_state["model_norm_threshold"] = threshold
|
||||||
percent_clipped = (first_state["num_clipped"] * 100.0 /
|
percent_clipped = (
|
||||||
clipping_update_period
|
first_state["num_clipped"] * 100.0 / clipping_update_period if "num_clipped" in first_state else 0.0
|
||||||
if "num_clipped" in first_state else 0.0)
|
)
|
||||||
first_state["num_clipped"] = 0
|
first_state["num_clipped"] = 0
|
||||||
quartiles = " ".join(["%.3e" % x for x in quartiles])
|
quartiles = " ".join(["%.3e" % x for x in quartiles])
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
|
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
|
||||||
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if step < clipping_update_period:
|
if step < clipping_update_period:
|
||||||
@ -373,25 +347,20 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
model_norm_threshold = first_state["model_norm_threshold"]
|
model_norm_threshold = first_state["model_norm_threshold"]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logging.info(
|
logging.info(
|
||||||
"Warning: model_norm_threshold not in state: possibly "
|
"Warning: model_norm_threshold not in state: possibly you changed config when restarting, adding clipping_scale option?"
|
||||||
"you changed config when restarting, adding clipping_scale option?"
|
|
||||||
)
|
)
|
||||||
return 1.0
|
return 1.0
|
||||||
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
||||||
if ans < 1.0:
|
if ans < 1.0:
|
||||||
first_state["num_clipped"] += 1
|
first_state["num_clipped"] += 1
|
||||||
if ans < 0.1:
|
if ans < 0.1:
|
||||||
logging.warn(
|
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
|
||||||
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
|
|
||||||
)
|
|
||||||
if self.show_dominant_parameters:
|
if self.show_dominant_parameters:
|
||||||
assert p.shape[0] == len(param_names)
|
assert p.shape[0] == len(param_names)
|
||||||
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
|
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
def _show_gradient_dominating_parameter(
|
def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
|
||||||
self, tuples: List[Tuple[Tensor, dict, List[str]]],
|
|
||||||
tot_sumsq: Tensor):
|
|
||||||
"""
|
"""
|
||||||
Show information of parameter wihch dominanting tot_sumsq.
|
Show information of parameter wihch dominanting tot_sumsq.
|
||||||
|
|
||||||
@ -406,7 +375,7 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
from tuples, we still pass it to save some time.
|
from tuples, we still pass it to save some time.
|
||||||
"""
|
"""
|
||||||
all_sumsq_orig = {}
|
all_sumsq_orig = {}
|
||||||
for (p, state, batch_param_names) in tuples:
|
for p, state, batch_param_names in tuples:
|
||||||
# p is a stacked batch parameters.
|
# p is a stacked batch parameters.
|
||||||
batch_grad = p.grad
|
batch_grad = p.grad
|
||||||
if p.numel() == p.shape[0]: # a batch of scalars
|
if p.numel() == p.shape[0]: # a batch of scalars
|
||||||
@ -415,41 +384,46 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
batch_rms_orig = torch.ones(p.shape[0])
|
batch_rms_orig = torch.ones(p.shape[0])
|
||||||
else:
|
else:
|
||||||
batch_rms_orig = state["param_rms"]
|
batch_rms_orig = state["param_rms"]
|
||||||
batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
|
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(dim=list(range(1, batch_grad.ndim)))
|
||||||
dim=list(range(1, batch_grad.ndim)))
|
|
||||||
|
|
||||||
for name, sumsq_orig, rms, grad in zip(batch_param_names,
|
|
||||||
batch_sumsq_orig,
|
|
||||||
batch_rms_orig, batch_grad):
|
|
||||||
|
|
||||||
|
for name, sumsq_orig, rms, grad in zip(
|
||||||
|
batch_param_names,
|
||||||
|
batch_sumsq_orig,
|
||||||
|
batch_rms_orig,
|
||||||
|
batch_grad,
|
||||||
|
):
|
||||||
proportion_orig = sumsq_orig / tot_sumsq
|
proportion_orig = sumsq_orig / tot_sumsq
|
||||||
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
|
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
|
||||||
|
|
||||||
assert torch.isclose(
|
assert torch.isclose(
|
||||||
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
|
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
|
||||||
torch.tensor(1.0), )
|
torch.tensor(1.0),
|
||||||
|
)
|
||||||
sorted_by_proportion = {
|
sorted_by_proportion = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in sorted(
|
for k, v in sorted(
|
||||||
all_sumsq_orig.items(),
|
all_sumsq_orig.items(),
|
||||||
key=lambda item: item[1][0],
|
key=lambda item: item[1][0],
|
||||||
reverse=True, )
|
reverse=True,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
dominant_param_name = next(iter(sorted_by_proportion))
|
dominant_param_name = next(iter(sorted_by_proportion))
|
||||||
(dominant_proportion, dominant_sumsq, dominant_rms,
|
(
|
||||||
dominant_grad, ) = sorted_by_proportion[dominant_param_name]
|
dominant_proportion,
|
||||||
logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}"
|
dominant_sumsq,
|
||||||
f" with proportion {dominant_proportion:.2f},"
|
dominant_rms,
|
||||||
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
|
dominant_grad,
|
||||||
f"={dominant_sumsq:.3e},"
|
) = sorted_by_proportion[dominant_param_name]
|
||||||
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
|
logging.info(
|
||||||
f" orig_rms_sq={(dominant_rms**2).item():.3e}")
|
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
|
||||||
|
f" with proportion {dominant_proportion:.2f},"
|
||||||
|
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
|
||||||
|
f"={dominant_sumsq:.3e},"
|
||||||
|
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
|
||||||
|
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
|
||||||
|
)
|
||||||
|
|
||||||
def _step_one_batch(self,
|
def _step_one_batch(self, group: dict, p: Tensor, state: dict, clipping_scale: float):
|
||||||
group: dict,
|
|
||||||
p: Tensor,
|
|
||||||
state: dict,
|
|
||||||
clipping_scale: float):
|
|
||||||
"""
|
"""
|
||||||
Do the step for one parameter, which is actually going to be a batch of
|
Do the step for one parameter, which is actually going to be a batch of
|
||||||
`real` parameters, with dim 0 as the batch dim.
|
`real` parameters, with dim 0 as the batch dim.
|
||||||
@ -475,13 +449,10 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
if numel > 1:
|
if numel > 1:
|
||||||
# Update the size/scale of p, and set param_rms
|
# Update the size/scale of p, and set param_rms
|
||||||
scale_grads = state["scale_grads"]
|
scale_grads = state["scale_grads"]
|
||||||
scale_grads[step % size_update_period] = (p * grad).sum(
|
scale_grads[step % size_update_period] = (p * grad).sum(dim=list(range(1, p.ndim)), keepdim=True)
|
||||||
dim=list(range(1, p.ndim)), keepdim=True)
|
|
||||||
if step % size_update_period == size_update_period - 1:
|
if step % size_update_period == size_update_period - 1:
|
||||||
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
|
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
|
||||||
param_rms.copy_((p**2)
|
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
|
||||||
.mean(dim=list(range(1, p.ndim)), keepdim=True)
|
|
||||||
.sqrt())
|
|
||||||
if step > 0:
|
if step > 0:
|
||||||
# self._size_update() learns the overall scale on the
|
# self._size_update() learns the overall scale on the
|
||||||
# parameter, by shrinking or expanding it.
|
# parameter, by shrinking or expanding it.
|
||||||
@ -496,11 +467,13 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
|
|
||||||
state["step"] = step + 1
|
state["step"] = step + 1
|
||||||
|
|
||||||
def _size_update(self,
|
def _size_update(
|
||||||
group: dict,
|
self,
|
||||||
scale_grads: Tensor,
|
group: dict,
|
||||||
p: Tensor,
|
scale_grads: Tensor,
|
||||||
state: dict) -> None:
|
p: Tensor,
|
||||||
|
state: dict,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Called only where p.numel() > 1, this updates the scale of the parameter.
|
Called only where p.numel() > 1, this updates the scale of the parameter.
|
||||||
If we imagine: p = underlying_param * scale.exp(), and we are doing
|
If we imagine: p = underlying_param * scale.exp(), and we are doing
|
||||||
@ -529,11 +502,11 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
# faster decay at this level.
|
# faster decay at this level.
|
||||||
beta2_corr = beta2**size_update_period
|
beta2_corr = beta2**size_update_period
|
||||||
|
|
||||||
scale_exp_avg_sq = state[
|
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
|
||||||
"scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
|
|
||||||
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
||||||
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
|
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
|
||||||
alpha=1 - beta2_corr, ) # shape is (batch_size, 1, 1, ...)
|
alpha=1 - beta2_corr,
|
||||||
|
) # shape is (batch_size, 1, 1, ...)
|
||||||
|
|
||||||
# The 1st time we reach here is when size_step == 1.
|
# The 1st time we reach here is when size_step == 1.
|
||||||
size_step = (step + 1) // size_update_period
|
size_step = (step + 1) // size_update_period
|
||||||
@ -543,8 +516,7 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
|
|
||||||
denom = scale_exp_avg_sq.sqrt() + eps
|
denom = scale_exp_avg_sq.sqrt() + eps
|
||||||
|
|
||||||
scale_step = (-size_lr * (bias_correction2**0.5) *
|
scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
|
||||||
scale_grads.sum(dim=0) / denom)
|
|
||||||
|
|
||||||
is_too_small = param_rms < param_min_rms
|
is_too_small = param_rms < param_min_rms
|
||||||
is_too_large = param_rms > param_max_rms
|
is_too_large = param_rms > param_max_rms
|
||||||
@ -580,9 +552,8 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
exp_avg_sq = state["exp_avg_sq"]
|
exp_avg_sq = state["exp_avg_sq"]
|
||||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
|
||||||
|
|
||||||
this_step = state["step"] - (state["zero_step"]
|
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
|
||||||
if "zero_step" in state else 0)
|
bias_correction2 = 1 - beta2 ** (this_step + 1)
|
||||||
bias_correction2 = 1 - beta2**(this_step + 1)
|
|
||||||
if bias_correction2 < 0.99:
|
if bias_correction2 < 0.99:
|
||||||
# note: not in-place.
|
# note: not in-place.
|
||||||
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
|
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
|
||||||
@ -613,7 +584,7 @@ class ScaledAdam(BatchedOptimizer):
|
|||||||
|
|
||||||
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
|
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
|
||||||
# slower update at the start will help stability anyway.
|
# slower update at the start will help stability anyway.
|
||||||
bias_correction2 = 1 - beta2**(state["step"] + 1)
|
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
|
||||||
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
|
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
|
||||||
|
|
||||||
delta = state["delta"]
|
delta = state["delta"]
|
||||||
|
@ -5,40 +5,39 @@ from torch.nn.functional import (
|
|||||||
_none_or_dtype,
|
_none_or_dtype,
|
||||||
_in_projection_packed,
|
_in_projection_packed,
|
||||||
)
|
)
|
||||||
from torch.nn import functional as F
|
|
||||||
import torch
|
import torch
|
||||||
# Tensor = torch.Tensor
|
# Tensor = torch.Tensor
|
||||||
# from typing import Callable, List, Optional, Tuple, Union
|
# from typing import Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
def multi_head_attention_forward_patched(
|
def multi_head_attention_forward_patched(
|
||||||
query: Tensor,
|
query,
|
||||||
key: Tensor,
|
key,
|
||||||
value: Tensor,
|
value,
|
||||||
embed_dim_to_check: int,
|
embed_dim_to_check,
|
||||||
num_heads: int,
|
num_heads,
|
||||||
in_proj_weight: Optional[Tensor],
|
in_proj_weight,
|
||||||
in_proj_bias: Optional[Tensor],
|
in_proj_bias,
|
||||||
bias_k: Optional[Tensor],
|
bias_k,
|
||||||
bias_v: Optional[Tensor],
|
bias_v,
|
||||||
add_zero_attn: bool,
|
add_zero_attn,
|
||||||
dropout_p: float,
|
dropout_p: float,
|
||||||
out_proj_weight: Tensor,
|
out_proj_weight,
|
||||||
out_proj_bias: Optional[Tensor],
|
out_proj_bias,
|
||||||
training: bool = True,
|
training=True,
|
||||||
key_padding_mask: Optional[Tensor] = None,
|
key_padding_mask=None,
|
||||||
need_weights: bool = True,
|
need_weights=True,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask=None,
|
||||||
use_separate_proj_weight: bool = False,
|
use_separate_proj_weight=False,
|
||||||
q_proj_weight: Optional[Tensor] = None,
|
q_proj_weight=None,
|
||||||
k_proj_weight: Optional[Tensor] = None,
|
k_proj_weight=None,
|
||||||
v_proj_weight: Optional[Tensor] = None,
|
v_proj_weight=None,
|
||||||
static_k: Optional[Tensor] = None,
|
static_k=None,
|
||||||
static_v: Optional[Tensor] = None,
|
static_v=None,
|
||||||
average_attn_weights: bool = True,
|
average_attn_weights=True,
|
||||||
is_causal: bool = False,
|
is_causal=False,
|
||||||
cache=None,
|
cache=None,
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
query, key, value: map a query and a set of key-value pairs to an output.
|
query, key, value: map a query and a set of key-value pairs to an output.
|
||||||
@ -156,9 +155,7 @@ def multi_head_attention_forward_patched(
|
|||||||
cache=cache,
|
cache=cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
is_batched = _mha_shape_check(
|
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
|
||||||
query, key, value, key_padding_mask, attn_mask, num_heads
|
|
||||||
)
|
|
||||||
|
|
||||||
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
|
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
|
||||||
# is batched, run the computation and before returning squeeze the
|
# is batched, run the computation and before returning squeeze the
|
||||||
@ -211,45 +208,33 @@ def multi_head_attention_forward_patched(
|
|||||||
# longer causal.
|
# longer causal.
|
||||||
is_causal = False
|
is_causal = False
|
||||||
|
|
||||||
assert (
|
assert embed_dim == embed_dim_to_check, (
|
||||||
embed_dim == embed_dim_to_check
|
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
|
||||||
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
|
)
|
||||||
if isinstance(embed_dim, torch.Tensor):
|
if isinstance(embed_dim, torch.Tensor):
|
||||||
# embed_dim can be a tensor when JIT tracing
|
# embed_dim can be a tensor when JIT tracing
|
||||||
head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
|
head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
|
||||||
else:
|
else:
|
||||||
head_dim = embed_dim // num_heads
|
head_dim = embed_dim // num_heads
|
||||||
assert (
|
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
|
||||||
head_dim * num_heads == embed_dim
|
|
||||||
), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
|
|
||||||
if use_separate_proj_weight:
|
if use_separate_proj_weight:
|
||||||
# allow MHA to have different embedding dimensions when separate projection weights are used
|
# allow MHA to have different embedding dimensions when separate projection weights are used
|
||||||
assert (
|
assert key.shape[:2] == value.shape[:2], (
|
||||||
key.shape[:2] == value.shape[:2]
|
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
|
||||||
), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
|
)
|
||||||
else:
|
else:
|
||||||
assert (
|
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
|
||||||
key.shape == value.shape
|
|
||||||
), f"key shape {key.shape} does not match value shape {value.shape}"
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# compute in-projection
|
# compute in-projection
|
||||||
#
|
#
|
||||||
if not use_separate_proj_weight:
|
if not use_separate_proj_weight:
|
||||||
assert (
|
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
|
||||||
in_proj_weight is not None
|
|
||||||
), "use_separate_proj_weight is False but in_proj_weight is None"
|
|
||||||
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
|
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
|
||||||
else:
|
else:
|
||||||
assert (
|
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
|
||||||
q_proj_weight is not None
|
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
|
||||||
), "use_separate_proj_weight is True but q_proj_weight is None"
|
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
|
||||||
assert (
|
|
||||||
k_proj_weight is not None
|
|
||||||
), "use_separate_proj_weight is True but k_proj_weight is None"
|
|
||||||
assert (
|
|
||||||
v_proj_weight is not None
|
|
||||||
), "use_separate_proj_weight is True but v_proj_weight is None"
|
|
||||||
if in_proj_bias is None:
|
if in_proj_bias is None:
|
||||||
b_q = b_k = b_v = None
|
b_q = b_k = b_v = None
|
||||||
else:
|
else:
|
||||||
@ -312,9 +297,7 @@ def multi_head_attention_forward_patched(
|
|||||||
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
|
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
|
||||||
f"attn_mask's dimension {attn_mask.dim()} is not supported"
|
|
||||||
)
|
|
||||||
|
|
||||||
# add bias along batch dimension (currently second)
|
# add bias along batch dimension (currently second)
|
||||||
if bias_k is not None and bias_v is not None:
|
if bias_k is not None and bias_v is not None:
|
||||||
@ -338,34 +321,26 @@ def multi_head_attention_forward_patched(
|
|||||||
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
||||||
else:
|
else:
|
||||||
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
||||||
assert (
|
assert static_k.size(0) == bsz * num_heads, (
|
||||||
static_k.size(0) == bsz * num_heads
|
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
|
||||||
), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
|
)
|
||||||
assert (
|
assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
|
||||||
static_k.size(2) == head_dim
|
|
||||||
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
|
|
||||||
k = static_k
|
k = static_k
|
||||||
if static_v is None:
|
if static_v is None:
|
||||||
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
||||||
else:
|
else:
|
||||||
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
||||||
assert (
|
assert static_v.size(0) == bsz * num_heads, (
|
||||||
static_v.size(0) == bsz * num_heads
|
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
|
||||||
), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
|
)
|
||||||
assert (
|
assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
|
||||||
static_v.size(2) == head_dim
|
|
||||||
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
|
|
||||||
v = static_v
|
v = static_v
|
||||||
|
|
||||||
# add zero attention along batch dimension (now first)
|
# add zero attention along batch dimension (now first)
|
||||||
if add_zero_attn:
|
if add_zero_attn:
|
||||||
zero_attn_shape = (bsz * num_heads, 1, head_dim)
|
zero_attn_shape = (bsz * num_heads, 1, head_dim)
|
||||||
k = torch.cat(
|
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
|
||||||
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
|
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
|
||||||
)
|
|
||||||
v = torch.cat(
|
|
||||||
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
|
|
||||||
)
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
attn_mask = pad(attn_mask, (0, 1))
|
attn_mask = pad(attn_mask, (0, 1))
|
||||||
if key_padding_mask is not None:
|
if key_padding_mask is not None:
|
||||||
@ -381,9 +356,7 @@ def multi_head_attention_forward_patched(
|
|||||||
src_len,
|
src_len,
|
||||||
), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
|
), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
|
||||||
key_padding_mask = (
|
key_padding_mask = (
|
||||||
key_padding_mask.view(bsz, 1, 1, src_len)
|
key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
|
||||||
.expand(-1, num_heads, -1, -1)
|
|
||||||
.reshape(bsz * num_heads, 1, src_len)
|
|
||||||
)
|
)
|
||||||
if attn_mask is None:
|
if attn_mask is None:
|
||||||
attn_mask = key_padding_mask
|
attn_mask = key_padding_mask
|
||||||
@ -402,14 +375,10 @@ def multi_head_attention_forward_patched(
|
|||||||
B, Nt, E = q.shape
|
B, Nt, E = q.shape
|
||||||
q_scaled = q / math.sqrt(E)
|
q_scaled = q / math.sqrt(E)
|
||||||
|
|
||||||
assert not (
|
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
|
||||||
is_causal and attn_mask is None
|
|
||||||
), "FIXME: is_causal not implemented for need_weights"
|
|
||||||
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
attn_output_weights = torch.baddbmm(
|
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
|
||||||
attn_mask, q_scaled, k.transpose(-2, -1)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
|
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
|
||||||
attn_output_weights = softmax(attn_output_weights, dim=-1)
|
attn_output_weights = softmax(attn_output_weights, dim=-1)
|
||||||
@ -418,9 +387,7 @@ def multi_head_attention_forward_patched(
|
|||||||
|
|
||||||
attn_output = torch.bmm(attn_output_weights, v)
|
attn_output = torch.bmm(attn_output_weights, v)
|
||||||
|
|
||||||
attn_output = (
|
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
|
||||||
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
|
|
||||||
)
|
|
||||||
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
||||||
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
||||||
|
|
||||||
@ -449,13 +416,9 @@ def multi_head_attention_forward_patched(
|
|||||||
v = v.view(bsz, num_heads, src_len, head_dim)
|
v = v.view(bsz, num_heads, src_len, head_dim)
|
||||||
|
|
||||||
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
||||||
attn_output = scaled_dot_product_attention(
|
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
|
||||||
q, k, v, attn_mask, dropout_p, is_causal
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = (
|
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
|
||||||
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
||||||
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
from torch.nn.functional import *
|
from torch.nn.functional import *
|
||||||
from torch.nn.functional import (
|
from torch.nn.functional import (
|
||||||
_mha_shape_check,
|
|
||||||
_canonical_mask,
|
_canonical_mask,
|
||||||
_none_or_dtype,
|
|
||||||
_in_projection_packed,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def multi_head_attention_forward_patched(
|
def multi_head_attention_forward_patched(
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
@ -34,7 +32,6 @@ def multi_head_attention_forward_patched(
|
|||||||
is_causal: bool = False,
|
is_causal: bool = False,
|
||||||
cache=None,
|
cache=None,
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
|
|
||||||
# set up shape vars
|
# set up shape vars
|
||||||
_, _, embed_dim = query.shape
|
_, _, embed_dim = query.shape
|
||||||
attn_mask = _canonical_mask(
|
attn_mask = _canonical_mask(
|
||||||
@ -80,12 +77,8 @@ def multi_head_attention_forward_patched(
|
|||||||
q = q.view(num_heads, -1, head_dim).unsqueeze(0)
|
q = q.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||||
k = k.view(num_heads, -1, head_dim).unsqueeze(0)
|
k = k.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||||
v = v.view(num_heads, -1, head_dim).unsqueeze(0)
|
v = v.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||||
attn_output = scaled_dot_product_attention(
|
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
|
||||||
q, k, v, attn_mask, dropout_p, is_causal
|
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
|
||||||
)
|
|
||||||
attn_output = (
|
|
||||||
attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
|
|
||||||
)
|
|
||||||
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
||||||
attn_output = attn_output.view(-1, 1, attn_output.size(1))
|
attn_output = attn_output.view(-1, 1, attn_output.size(1))
|
||||||
|
|
||||||
|
@ -13,12 +13,9 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
import random
|
import random
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -61,9 +58,7 @@ class DoubleSwishFunction(torch.autograd.Function):
|
|||||||
# floors), should be expectation-preserving.
|
# floors), should be expectation-preserving.
|
||||||
floor = -0.043637
|
floor = -0.043637
|
||||||
ceil = 1.2
|
ceil = 1.2
|
||||||
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
|
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
|
||||||
deriv
|
|
||||||
)
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# for self-testing only.
|
# for self-testing only.
|
||||||
assert d_scaled.min() >= 0.0
|
assert d_scaled.min() >= 0.0
|
||||||
@ -153,13 +148,9 @@ def _compute_scale_factor(
|
|||||||
else:
|
else:
|
||||||
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
|
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
|
||||||
# x_abs)_mean , min_abs.
|
# x_abs)_mean , min_abs.
|
||||||
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
|
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
|
||||||
min=0, max=max_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
|
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
|
||||||
min=0, max=max_factor
|
|
||||||
)
|
|
||||||
|
|
||||||
return below_threshold - above_threshold
|
return below_threshold - above_threshold
|
||||||
|
|
||||||
@ -181,18 +172,16 @@ def _compute_sign_factor(
|
|||||||
else:
|
else:
|
||||||
# 0 if proportion_positive >= min_positive, else can be
|
# 0 if proportion_positive >= min_positive, else can be
|
||||||
# as large as max_factor.
|
# as large as max_factor.
|
||||||
factor1 = (
|
factor1 = ((min_positive - proportion_positive) * (gain_factor / min_positive)).clamp_(min=0, max=max_factor)
|
||||||
(min_positive - proportion_positive) * (gain_factor / min_positive)
|
|
||||||
).clamp_(min=0, max=max_factor)
|
|
||||||
|
|
||||||
if max_positive == 1.0:
|
if max_positive == 1.0:
|
||||||
factor2 = 0.0
|
factor2 = 0.0
|
||||||
else:
|
else:
|
||||||
# 0 if self.proportion_positive <= max_positive, else can be
|
# 0 if self.proportion_positive <= max_positive, else can be
|
||||||
# as large as -max_factor.
|
# as large as -max_factor.
|
||||||
factor2 = (
|
factor2 = ((proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))).clamp_(
|
||||||
(proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
|
min=0, max=max_factor
|
||||||
).clamp_(min=0, max=max_factor)
|
)
|
||||||
sign_factor = factor1 - factor2
|
sign_factor = factor1 - factor2
|
||||||
# require min_positive != 0 or max_positive != 1:
|
# require min_positive != 0 or max_positive != 1:
|
||||||
assert not isinstance(sign_factor, float)
|
assert not isinstance(sign_factor, float)
|
||||||
@ -320,15 +309,11 @@ class ActivationBalancer(torch.nn.Module):
|
|||||||
return _no_op(x)
|
return _no_op(x)
|
||||||
|
|
||||||
|
|
||||||
def BalancedDoubleSwish(
|
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25) -> nn.Sequential:
|
||||||
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
|
|
||||||
) -> nn.Sequential:
|
|
||||||
"""
|
"""
|
||||||
ActivationBalancer -> DoubleSwish
|
ActivationBalancer -> DoubleSwish
|
||||||
"""
|
"""
|
||||||
balancer = ActivationBalancer(
|
balancer = ActivationBalancer(d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
|
||||||
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
|
|
||||||
)
|
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
balancer,
|
balancer,
|
||||||
DoubleSwish(),
|
DoubleSwish(),
|
||||||
|
@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
|
|||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.elementwise_affine = elementwise_affine
|
self.elementwise_affine = elementwise_affine
|
||||||
if self.elementwise_affine:
|
if self.elementwise_affine:
|
||||||
self.weight = nn.Parameter(
|
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||||
)
|
|
||||||
self.bias = nn.Parameter(
|
|
||||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.register_parameter("weight", None)
|
self.register_parameter("weight", None)
|
||||||
self.register_parameter("bias", None)
|
self.register_parameter("bias", None)
|
||||||
@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert embedding is None
|
assert embedding is None
|
||||||
return F.layer_norm(
|
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||||
input, self.normalized_shape, self.weight, self.bias, self.eps
|
|
||||||
)
|
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
return (
|
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
||||||
"{normalized_shape}, eps={eps}, "
|
|
||||||
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class IdentityNorm(nn.Module):
|
class IdentityNorm(nn.Module):
|
||||||
@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
>>> src = torch.rand(10, 32, 512)
|
>>> src = torch.rand(10, 32, 512)
|
||||||
>>> out = transformer_encoder(src)
|
>>> out = transformer_encoder(src)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__constants__ = ["norm"]
|
__constants__ = ["norm"]
|
||||||
|
|
||||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||||
@ -218,13 +210,9 @@ class TransformerEncoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Implementation of Feedforward model
|
# Implementation of Feedforward model
|
||||||
self.linear1 = linear1_feedforward_cls(
|
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
|
||||||
d_model, dim_feedforward, **factory_kwargs
|
|
||||||
)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
self.linear2 = linear2_feedforward_cls(
|
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
|
||||||
dim_feedforward, d_model, **factory_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
self.norm_first = norm_first
|
self.norm_first = norm_first
|
||||||
self.dropout1 = nn.Dropout(dropout)
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
@ -291,12 +279,8 @@ class TransformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
if src_key_padding_mask is not None:
|
if src_key_padding_mask is not None:
|
||||||
_skpm_dtype = src_key_padding_mask.dtype
|
_skpm_dtype = src_key_padding_mask.dtype
|
||||||
if _skpm_dtype != torch.bool and not torch.is_floating_point(
|
if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
|
||||||
src_key_padding_mask
|
raise AssertionError("only bool and floating types of key_padding_mask are supported")
|
||||||
):
|
|
||||||
raise AssertionError(
|
|
||||||
"only bool and floating types of key_padding_mask are supported"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.norm_first:
|
if self.norm_first:
|
||||||
x = x + self._sa_block(
|
x = x + self._sa_block(
|
||||||
|
@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
|
|||||||
self.eps = eps
|
self.eps = eps
|
||||||
self.elementwise_affine = elementwise_affine
|
self.elementwise_affine = elementwise_affine
|
||||||
if self.elementwise_affine:
|
if self.elementwise_affine:
|
||||||
self.weight = nn.Parameter(
|
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||||
)
|
|
||||||
self.bias = nn.Parameter(
|
|
||||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.register_parameter("weight", None)
|
self.register_parameter("weight", None)
|
||||||
self.register_parameter("bias", None)
|
self.register_parameter("bias", None)
|
||||||
@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert embedding is None
|
assert embedding is None
|
||||||
return F.layer_norm(
|
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||||
input, self.normalized_shape, self.weight, self.bias, self.eps
|
|
||||||
)
|
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
return (
|
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
||||||
"{normalized_shape}, eps={eps}, "
|
|
||||||
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class IdentityNorm(nn.Module):
|
class IdentityNorm(nn.Module):
|
||||||
@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
>>> src = torch.rand(10, 32, 512)
|
>>> src = torch.rand(10, 32, 512)
|
||||||
>>> out = transformer_encoder(src)
|
>>> out = transformer_encoder(src)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__constants__ = ["norm"]
|
__constants__ = ["norm"]
|
||||||
|
|
||||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||||
@ -154,6 +146,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
|
|
||||||
class TransformerEncoderLayer(nn.Module):
|
class TransformerEncoderLayer(nn.Module):
|
||||||
__constants__ = ["batch_first", "norm_first"]
|
__constants__ = ["batch_first", "norm_first"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
d_model: int,
|
d_model: int,
|
||||||
@ -184,13 +177,9 @@ class TransformerEncoderLayer(nn.Module):
|
|||||||
linear2_cls=linear2_self_attention_cls,
|
linear2_cls=linear2_self_attention_cls,
|
||||||
**factory_kwargs,
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
self.linear1 = linear1_feedforward_cls(
|
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
|
||||||
d_model, dim_feedforward, **factory_kwargs
|
|
||||||
)
|
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
self.linear2 = linear2_feedforward_cls(
|
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
|
||||||
dim_feedforward, d_model, **factory_kwargs
|
|
||||||
)
|
|
||||||
self.norm_first = norm_first
|
self.norm_first = norm_first
|
||||||
self.dropout1 = nn.Dropout(dropout)
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
self.dropout2 = nn.Dropout(dropout)
|
self.dropout2 = nn.Dropout(dropout)
|
||||||
|
@ -30,9 +30,7 @@ class GruutPhonemizer:
|
|||||||
"«": "«",
|
"«": "«",
|
||||||
"»": "»",
|
"»": "»",
|
||||||
}
|
}
|
||||||
self._punctuation_regexp: str = (
|
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
|
||||||
rf"([{''.join(self._special_cases_dict.keys())}])"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _normalize_punctuation(self, text: str) -> str:
|
def _normalize_punctuation(self, text: str) -> str:
|
||||||
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
|
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
|
||||||
@ -53,13 +51,8 @@ class GruutPhonemizer:
|
|||||||
|
|
||||||
def phonemize(self, text: str, espeak: bool = False) -> str:
|
def phonemize(self, text: str, espeak: bool = False) -> str:
|
||||||
text_to_phonemize: str = self._normalize_punctuation(text)
|
text_to_phonemize: str = self._normalize_punctuation(text)
|
||||||
sents: List[Sentence] = [
|
sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
|
||||||
sent
|
words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
|
||||||
for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
|
|
||||||
]
|
|
||||||
words: List[str] = [
|
|
||||||
self._convert_punctuation(word) for word in itertools.chain(*sents)
|
|
||||||
]
|
|
||||||
return " ".join(words)
|
return " ".join(words)
|
||||||
|
|
||||||
def transform(self, phonemes):
|
def transform(self, phonemes):
|
||||||
|
@ -3,7 +3,9 @@
|
|||||||
PAD = "_"
|
PAD = "_"
|
||||||
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
|
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
|
||||||
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||||
IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
IPA_LETTERS = (
|
||||||
|
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
||||||
|
)
|
||||||
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
|
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
|
||||||
SPACE_ID = SYMBOLS.index(" ")
|
SPACE_ID = SYMBOLS.index(" ")
|
||||||
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
|
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
|
||||||
|
@ -2,12 +2,12 @@ import re
|
|||||||
|
|
||||||
|
|
||||||
def str2bool(str):
|
def str2bool(str):
|
||||||
return True if str.lower() == 'true' else False
|
return True if str.lower() == "true" else False
|
||||||
|
|
||||||
|
|
||||||
def get_newest_ckpt(string_list):
|
def get_newest_ckpt(string_list):
|
||||||
# 定义一个正则表达式模式,用于匹配字符串中的数字
|
# 定义一个正则表达式模式,用于匹配字符串中的数字
|
||||||
pattern = r'epoch=(\d+)-step=(\d+)\.ckpt'
|
pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"
|
||||||
|
|
||||||
# 使用正则表达式提取每个字符串中的数字信息,并创建一个包含元组的列表
|
# 使用正则表达式提取每个字符串中的数字信息,并创建一个包含元组的列表
|
||||||
extracted_info = []
|
extracted_info = []
|
||||||
@ -18,8 +18,7 @@ def get_newest_ckpt(string_list):
|
|||||||
step = int(match.group(2))
|
step = int(match.group(2))
|
||||||
extracted_info.append((epoch, step, string))
|
extracted_info.append((epoch, step, string))
|
||||||
# 按照 epoch 后面的数字和 step 后面的数字进行排序
|
# 按照 epoch 后面的数字和 step 后面的数字进行排序
|
||||||
sorted_info = sorted(
|
sorted_info = sorted(extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
|
||||||
extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
|
|
||||||
# 获取最新的 ckpt 文件名
|
# 获取最新的 ckpt 文件名
|
||||||
newest_ckpt = sorted_info[0][2]
|
newest_ckpt = sorted_info[0][2]
|
||||||
return newest_ckpt
|
return newest_ckpt
|
||||||
@ -28,9 +27,9 @@ def get_newest_ckpt(string_list):
|
|||||||
# 文本存在且不为空时 return True
|
# 文本存在且不为空时 return True
|
||||||
def check_txt_file(file_path):
|
def check_txt_file(file_path):
|
||||||
try:
|
try:
|
||||||
with open(file_path, 'r') as file:
|
with open(file_path, "r") as file:
|
||||||
text = file.readline().strip()
|
text = file.readline().strip()
|
||||||
assert text.strip() != ''
|
assert text.strip() != ""
|
||||||
return text
|
return text
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Initialize modules for espnet2 neural networks."""
|
"""Initialize modules for espnet2 neural networks."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typeguard import check_argument_types
|
from typeguard import check_argument_types
|
||||||
|
|
||||||
|
@ -18,14 +18,10 @@ def save_config_to_yaml(config, path):
|
|||||||
|
|
||||||
|
|
||||||
def write_args(args, path):
|
def write_args(args, path):
|
||||||
args_dict = dict(
|
args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith("_"))
|
||||||
(name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
|
|
||||||
)
|
|
||||||
with open(path, "a") as args_file:
|
with open(path, "a") as args_file:
|
||||||
args_file.write("==> torch version: {}\n".format(torch.__version__))
|
args_file.write("==> torch version: {}\n".format(torch.__version__))
|
||||||
args_file.write(
|
args_file.write("==> cudnn version: {}\n".format(torch.backends.cudnn.version()))
|
||||||
"==> cudnn version: {}\n".format(torch.backends.cudnn.version())
|
|
||||||
)
|
|
||||||
args_file.write("==> Cmd:\n")
|
args_file.write("==> Cmd:\n")
|
||||||
args_file.write(str(sys.argv))
|
args_file.write(str(sys.argv))
|
||||||
args_file.write("\n==> args:\n")
|
args_file.write("\n==> args:\n")
|
||||||
|
21
GPT_SoVITS/BigVGAN/LICENSE
Normal file
21
GPT_SoVITS/BigVGAN/LICENSE
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
266
GPT_SoVITS/BigVGAN/README.md
Normal file
266
GPT_SoVITS/BigVGAN/README.md
Normal file
@ -0,0 +1,266 @@
|
|||||||
|
## BigVGAN: A Universal Neural Vocoder with Large-Scale Training
|
||||||
|
|
||||||
|
#### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon
|
||||||
|
|
||||||
|
[[Paper]](https://arxiv.org/abs/2206.04658) - [[Code]](https://github.com/NVIDIA/BigVGAN) - [[Showcase]](https://bigvgan-demo.github.io/) - [[Project Page]](https://research.nvidia.com/labs/adlr/projects/bigvgan/) - [[Weights]](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a) - [[Demo]](https://huggingface.co/spaces/nvidia/BigVGAN)
|
||||||
|
|
||||||
|
[](https://paperswithcode.com/sota/speech-synthesis-on-libritts?p=bigvgan-a-universal-neural-vocoder-with-large)
|
||||||
|
|
||||||
|
<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
|
||||||
|
|
||||||
|
## News
|
||||||
|
- **Sep 2024 (v2.4):**
|
||||||
|
- We have updated the pretrained checkpoints trained for 5M steps. This is final release of the BigVGAN-v2 checkpoints.
|
||||||
|
|
||||||
|
- **Jul 2024 (v2.3):**
|
||||||
|
- General refactor and code improvements for improved readability.
|
||||||
|
- Fully fused CUDA kernel of anti-alised activation (upsampling + activation + downsampling) with inference speed benchmark.
|
||||||
|
|
||||||
|
- **Jul 2024 (v2.2):** The repository now includes an interactive local demo using gradio.
|
||||||
|
|
||||||
|
- **Jul 2024 (v2.1):** BigVGAN is now integrated with 🤗 Hugging Face Hub with easy access to inference using pretrained checkpoints. We also provide an interactive demo on Hugging Face Spaces.
|
||||||
|
|
||||||
|
- **Jul 2024 (v2):** We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
|
||||||
|
- Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.
|
||||||
|
- Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
|
||||||
|
- Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
|
||||||
|
- We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
|
||||||
|
conda activate bigvgan
|
||||||
|
```
|
||||||
|
|
||||||
|
Clone the repository and install dependencies:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/NVIDIA/BigVGAN
|
||||||
|
cd BigVGAN
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## Inference Quickstart using 🤗 Hugging Face Hub
|
||||||
|
|
||||||
|
Below example describes how you can use BigVGAN: load the pretrained BigVGAN generator from Hugging Face Hub, compute mel spectrogram from input waveform, and generate synthesized waveform using the mel spectrogram as the model's input.
|
||||||
|
|
||||||
|
```python
|
||||||
|
device = 'cuda'
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import bigvgan
|
||||||
|
import librosa
|
||||||
|
from meldataset import get_mel_spectrogram
|
||||||
|
|
||||||
|
# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
|
||||||
|
model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
|
||||||
|
|
||||||
|
# remove weight norm in the model and set to eval mode
|
||||||
|
model.remove_weight_norm()
|
||||||
|
model = model.eval().to(device)
|
||||||
|
|
||||||
|
# load wav file and compute mel spectrogram
|
||||||
|
wav_path = '/path/to/your/audio.wav'
|
||||||
|
wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
|
||||||
|
wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]
|
||||||
|
|
||||||
|
# compute mel spectrogram from the ground truth audio
|
||||||
|
mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]
|
||||||
|
|
||||||
|
# generate waveform from mel
|
||||||
|
with torch.inference_mode():
|
||||||
|
wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
|
||||||
|
wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time]
|
||||||
|
|
||||||
|
# you can convert the generated waveform to 16 bit linear PCM
|
||||||
|
wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16') # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
|
||||||
|
```
|
||||||
|
|
||||||
|
## Local gradio demo <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>
|
||||||
|
|
||||||
|
You can run a local gradio demo using below command:
|
||||||
|
|
||||||
|
```python
|
||||||
|
pip install -r demo/requirements.txt
|
||||||
|
python demo/app.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|
||||||
|
Create symbolic link to the root of the dataset. The codebase uses filelist with the relative path from the dataset. Below are the example commands for LibriTTS dataset:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
cd filelists/LibriTTS && \
|
||||||
|
ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
|
||||||
|
ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
|
||||||
|
ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
|
||||||
|
ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
|
||||||
|
ln -s /path/to/your/LibriTTS/dev-other dev-other && \
|
||||||
|
ln -s /path/to/your/LibriTTS/test-clean test-clean && \
|
||||||
|
ln -s /path/to/your/LibriTTS/test-other test-other && \
|
||||||
|
cd ../..
|
||||||
|
```
|
||||||
|
|
||||||
|
Train BigVGAN model. Below is an example command for training BigVGAN-v2 using LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python train.py \
|
||||||
|
--config configs/bigvgan_v2_24khz_100band_256x.json \
|
||||||
|
--input_wavs_dir filelists/LibriTTS \
|
||||||
|
--input_training_file filelists/LibriTTS/train-full.txt \
|
||||||
|
--input_validation_file filelists/LibriTTS/val-full.txt \
|
||||||
|
--list_input_unseen_wavs_dir filelists/LibriTTS filelists/LibriTTS \
|
||||||
|
--list_input_unseen_validation_file filelists/LibriTTS/dev-clean.txt filelists/LibriTTS/dev-other.txt \
|
||||||
|
--checkpoint_path exp/bigvgan_v2_24khz_100band_256x
|
||||||
|
```
|
||||||
|
|
||||||
|
## Synthesis
|
||||||
|
|
||||||
|
Synthesize from BigVGAN model. Below is an example command for generating audio from the model.
|
||||||
|
It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python inference.py \
|
||||||
|
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
|
||||||
|
--input_wavs_dir /path/to/your/input_wav \
|
||||||
|
--output_dir /path/to/your/output_wav
|
||||||
|
```
|
||||||
|
|
||||||
|
`inference_e2e.py` supports synthesis directly from the mel spectrogram saved in `.npy` format, with shapes `[1, channel, frame]` or `[channel, frame]`.
|
||||||
|
It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
|
||||||
|
|
||||||
|
Make sure that the STFT hyperparameters for mel spectrogram are the same as the model, which are defined in `config.json` of the corresponding model.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
python inference_e2e.py \
|
||||||
|
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
|
||||||
|
--input_mels_dir /path/to/your/input_mel \
|
||||||
|
--output_dir /path/to/your/output_wav
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using Custom CUDA Kernel for Synthesis
|
||||||
|
|
||||||
|
You can apply the fast CUDA inference kernel by using a parameter `use_cuda_kernel` when instantiating BigVGAN:
|
||||||
|
|
||||||
|
```python
|
||||||
|
generator = BigVGAN(h, use_cuda_kernel=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
|
||||||
|
|
||||||
|
When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
|
||||||
|
|
||||||
|
Please make sure that both are installed in your system and `nvcc` installed in your system matches the version your PyTorch build is using.
|
||||||
|
|
||||||
|
We recommend running `test_cuda_vs_torch_model.py` first to build and check the correctness of the CUDA kernel. See below example command and its output, where it returns `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
python tests/test_cuda_vs_torch_model.py \
|
||||||
|
--checkpoint_file /path/to/your/bigvgan_generator.pt
|
||||||
|
```
|
||||||
|
|
||||||
|
```shell
|
||||||
|
loading plain Pytorch BigVGAN
|
||||||
|
...
|
||||||
|
loading CUDA kernel BigVGAN with auto-build
|
||||||
|
Detected CUDA files, patching ldflags
|
||||||
|
Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja..
|
||||||
|
Building extension module anti_alias_activation_cuda...
|
||||||
|
...
|
||||||
|
Loading extension module anti_alias_activation_cuda...
|
||||||
|
...
|
||||||
|
Loading '/path/to/your/bigvgan_generator.pt'
|
||||||
|
...
|
||||||
|
[Success] test CUDA fused vs. plain torch BigVGAN inference
|
||||||
|
> mean_difference=0.0007238413265440613
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check if `nvcc` installed in your system is compatible with your PyTorch version.
|
||||||
|
|
||||||
|
## Pretrained Models
|
||||||
|
|
||||||
|
We provide the [pretrained models on Hugging Face Collections](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a).
|
||||||
|
One can download the checkpoints of the generator weight (named `bigvgan_generator.pt`) and its discriminator/optimizer states (named `bigvgan_discriminator_optimizer.pt`) within the listed model repositories.
|
||||||
|
|
||||||
|
| Model Name | Sampling Rate | Mel band | fmax | Upsampling Ratio | Params | Dataset | Steps | Fine-Tuned |
|
||||||
|
|:--------------------------------------------------------------------------------------------------------:|:-------------:|:--------:|:-----:|:----------------:|:------:|:--------------------------:|:-----:|:----------:|
|
||||||
|
| [bigvgan_v2_44khz_128band_512x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x) | 44 kHz | 128 | 22050 | 512 | 122M | Large-scale Compilation | 5M | No |
|
||||||
|
| [bigvgan_v2_44khz_128band_256x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x) | 44 kHz | 128 | 22050 | 256 | 112M | Large-scale Compilation | 5M | No |
|
||||||
|
| [bigvgan_v2_24khz_100band_256x](https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x) | 24 kHz | 100 | 12000 | 256 | 112M | Large-scale Compilation | 5M | No |
|
||||||
|
| [bigvgan_v2_22khz_80band_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x) | 22 kHz | 80 | 11025 | 256 | 112M | Large-scale Compilation | 5M | No |
|
||||||
|
| [bigvgan_v2_22khz_80band_fmax8k_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x) | 22 kHz | 80 | 8000 | 256 | 112M | Large-scale Compilation | 5M | No |
|
||||||
|
| [bigvgan_24khz_100band](https://huggingface.co/nvidia/bigvgan_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 112M | LibriTTS | 5M | No |
|
||||||
|
| [bigvgan_base_24khz_100band](https://huggingface.co/nvidia/bigvgan_base_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 14M | LibriTTS | 5M | No |
|
||||||
|
| [bigvgan_22khz_80band](https://huggingface.co/nvidia/bigvgan_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 112M | LibriTTS + VCTK + LJSpeech | 5M | No |
|
||||||
|
| [bigvgan_base_22khz_80band](https://huggingface.co/nvidia/bigvgan_base_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 14M | LibriTTS + VCTK + LJSpeech | 5M | No |
|
||||||
|
|
||||||
|
The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on LibriTTS dataset.
|
||||||
|
We also provide 22kHz BigVGAN models with band-limited setup (i.e., fmax=8000) for TTS applications.
|
||||||
|
Note that the checkpoints use `snakebeta` activation with log scale parameterization, which have the best overall quality.
|
||||||
|
|
||||||
|
You can fine-tune the models by:
|
||||||
|
|
||||||
|
1. downloading the checkpoints (both the generator weight and its discriminator/optimizer states)
|
||||||
|
2. resuming training using your audio dataset by specifying `--checkpoint_path` that includes the checkpoints when launching `train.py`
|
||||||
|
|
||||||
|
## Training Details of BigVGAN-v2
|
||||||
|
|
||||||
|
Comapred to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 used `batch_size=32` with a longer `segment_size=65536` and are trained using 8 A100 GPUs.
|
||||||
|
|
||||||
|
Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` as default to fit in a single A100 GPU for training. You can fine-tune the models adjusting `batch_size` depending on your GPUs.
|
||||||
|
|
||||||
|
When training BigVGAN-v2 from scratch with small batch size, it can potentially encounter the early divergence problem mentioned in the paper. In such case, we recommend lowering the `clip_grad_norm` value (e.g. `100`) for the early training iterations (e.g. 20000 steps) and increase the value to the default `500`.
|
||||||
|
|
||||||
|
## Evaluation Results of BigVGAN-v2
|
||||||
|
|
||||||
|
Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements of the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
|
||||||
|
|
||||||
|
| Model | Dataset | Steps | PESQ(↑) | M-STFT(↓) | MCD(↓) | Periodicity(↓) | V/UV F1(↑) |
|
||||||
|
|:----------:|:-----------------------:|:-----:|:---------:|:----------:|:----------:|:--------------:|:----------:|
|
||||||
|
| BigVGAN | LibriTTS | 1M | 4.027 | 0.7997 | 0.3745 | 0.1018 | 0.9598 |
|
||||||
|
| BigVGAN | LibriTTS | 5M | 4.256 | 0.7409 | 0.2988 | 0.0809 | 0.9698 |
|
||||||
|
| BigVGAN-v2 | Large-scale Compilation | 3M | 4.359 | 0.7134 | 0.3060 | 0.0621 | 0.9777 |
|
||||||
|
| BigVGAN-v2 | Large-scale Compilation | 5M | **4.362** | **0.7026** | **0.2903** | **0.0593** | **0.9793** |
|
||||||
|
|
||||||
|
## Speed Benchmark
|
||||||
|
|
||||||
|
Below are the speed and VRAM usage benchmark results of BigVGAN from `tests/test_cuda_vs_torch_model.py`, using `bigvgan_v2_24khz_100band_256x` as a reference model.
|
||||||
|
|
||||||
|
| GPU | num_mel_frame | use_cuda_kernel | Speed (kHz) | Real-time Factor | VRAM (GB) |
|
||||||
|
|:--------------------------:|:-------------:|:---------------:|:-----------:|:----------------:|:---------:|
|
||||||
|
| NVIDIA A100 | 256 | False | 1672.1 | 69.7x | 1.3 |
|
||||||
|
| | | True | 3916.5 | 163.2x | 1.3 |
|
||||||
|
| | 2048 | False | 1899.6 | 79.2x | 1.7 |
|
||||||
|
| | | True | 5330.1 | 222.1x | 1.7 |
|
||||||
|
| | 16384 | False | 1973.8 | 82.2x | 5.0 |
|
||||||
|
| | | True | 5761.7 | 240.1x | 4.4 |
|
||||||
|
| NVIDIA GeForce RTX 3080 | 256 | False | 841.1 | 35.0x | 1.3 |
|
||||||
|
| | | True | 1598.1 | 66.6x | 1.3 |
|
||||||
|
| | 2048 | False | 929.9 | 38.7x | 1.7 |
|
||||||
|
| | | True | 1971.3 | 82.1x | 1.6 |
|
||||||
|
| | 16384 | False | 943.4 | 39.3x | 5.0 |
|
||||||
|
| | | True | 2026.5 | 84.4x | 3.9 |
|
||||||
|
| NVIDIA GeForce RTX 2080 Ti | 256 | False | 515.6 | 21.5x | 1.3 |
|
||||||
|
| | | True | 811.3 | 33.8x | 1.3 |
|
||||||
|
| | 2048 | False | 576.5 | 24.0x | 1.7 |
|
||||||
|
| | | True | 1023.0 | 42.6x | 1.5 |
|
||||||
|
| | 16384 | False | 589.4 | 24.6x | 5.0 |
|
||||||
|
| | | True | 1068.1 | 44.5x | 3.2 |
|
||||||
|
|
||||||
|
## Acknowledgements
|
||||||
|
|
||||||
|
We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
|
||||||
|
- [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
|
||||||
|
- [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
|
||||||
|
- [Julius](https://github.com/adefossez/julius) (for low-pass filter)
|
||||||
|
- [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
|
||||||
|
- [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
|
||||||
|
- [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)
|
122
GPT_SoVITS/BigVGAN/activations.py
Normal file
122
GPT_SoVITS/BigVGAN/activations.py
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn, sin, pow
|
||||||
|
from torch.nn import Parameter
|
||||||
|
|
||||||
|
|
||||||
|
class Snake(nn.Module):
|
||||||
|
"""
|
||||||
|
Implementation of a sine-based periodic activation function
|
||||||
|
Shape:
|
||||||
|
- Input: (B, C, T)
|
||||||
|
- Output: (B, C, T), same shape as the input
|
||||||
|
Parameters:
|
||||||
|
- alpha - trainable parameter
|
||||||
|
References:
|
||||||
|
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||||
|
https://arxiv.org/abs/2006.08195
|
||||||
|
Examples:
|
||||||
|
>>> a1 = snake(256)
|
||||||
|
>>> x = torch.randn(256)
|
||||||
|
>>> x = a1(x)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||||
|
"""
|
||||||
|
Initialization.
|
||||||
|
INPUT:
|
||||||
|
- in_features: shape of the input
|
||||||
|
- alpha: trainable parameter
|
||||||
|
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||||
|
alpha will be trained along with the rest of your model.
|
||||||
|
"""
|
||||||
|
super(Snake, self).__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
|
||||||
|
# Initialize alpha
|
||||||
|
self.alpha_logscale = alpha_logscale
|
||||||
|
if self.alpha_logscale: # Log scale alphas initialized to zeros
|
||||||
|
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
else: # Linear scale alphas initialized to ones
|
||||||
|
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||||
|
|
||||||
|
self.alpha.requires_grad = alpha_trainable
|
||||||
|
|
||||||
|
self.no_div_by_zero = 0.000000001
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass of the function.
|
||||||
|
Applies the function to the input elementwise.
|
||||||
|
Snake ∶= x + 1/a * sin^2 (xa)
|
||||||
|
"""
|
||||||
|
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
|
||||||
|
if self.alpha_logscale:
|
||||||
|
alpha = torch.exp(alpha)
|
||||||
|
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SnakeBeta(nn.Module):
|
||||||
|
"""
|
||||||
|
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
||||||
|
Shape:
|
||||||
|
- Input: (B, C, T)
|
||||||
|
- Output: (B, C, T), same shape as the input
|
||||||
|
Parameters:
|
||||||
|
- alpha - trainable parameter that controls frequency
|
||||||
|
- beta - trainable parameter that controls magnitude
|
||||||
|
References:
|
||||||
|
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||||
|
https://arxiv.org/abs/2006.08195
|
||||||
|
Examples:
|
||||||
|
>>> a1 = snakebeta(256)
|
||||||
|
>>> x = torch.randn(256)
|
||||||
|
>>> x = a1(x)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||||
|
"""
|
||||||
|
Initialization.
|
||||||
|
INPUT:
|
||||||
|
- in_features: shape of the input
|
||||||
|
- alpha - trainable parameter that controls frequency
|
||||||
|
- beta - trainable parameter that controls magnitude
|
||||||
|
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||||
|
beta is initialized to 1 by default, higher values = higher-magnitude.
|
||||||
|
alpha will be trained along with the rest of your model.
|
||||||
|
"""
|
||||||
|
super(SnakeBeta, self).__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
|
||||||
|
# Initialize alpha
|
||||||
|
self.alpha_logscale = alpha_logscale
|
||||||
|
if self.alpha_logscale: # Log scale alphas initialized to zeros
|
||||||
|
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
||||||
|
else: # Linear scale alphas initialized to ones
|
||||||
|
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||||
|
self.beta = Parameter(torch.ones(in_features) * alpha)
|
||||||
|
|
||||||
|
self.alpha.requires_grad = alpha_trainable
|
||||||
|
self.beta.requires_grad = alpha_trainable
|
||||||
|
|
||||||
|
self.no_div_by_zero = 0.000000001
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Forward pass of the function.
|
||||||
|
Applies the function to the input elementwise.
|
||||||
|
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||||
|
"""
|
||||||
|
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T]
|
||||||
|
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||||
|
if self.alpha_logscale:
|
||||||
|
alpha = torch.exp(alpha)
|
||||||
|
beta = torch.exp(beta)
|
||||||
|
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||||
|
|
||||||
|
return x
|
@ -0,0 +1,69 @@
|
|||||||
|
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from alias_free_activation.torch.resample import UpSample1d, DownSample1d
|
||||||
|
|
||||||
|
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
|
||||||
|
from alias_free_activation.cuda import load
|
||||||
|
|
||||||
|
anti_alias_activation_cuda = load.load()
|
||||||
|
|
||||||
|
|
||||||
|
class FusedAntiAliasActivation(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
|
||||||
|
The hyperparameters are hard-coded in the kernel to maximize speed.
|
||||||
|
NOTE: The fused kenrel is incorrect for Activation1d with different hyperparameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
|
||||||
|
activation_results = anti_alias_activation_cuda.forward(inputs, up_ftr, down_ftr, alpha, beta)
|
||||||
|
|
||||||
|
return activation_results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, output_grads):
|
||||||
|
raise NotImplementedError
|
||||||
|
return output_grads, None, None
|
||||||
|
|
||||||
|
|
||||||
|
class Activation1d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
activation,
|
||||||
|
up_ratio: int = 2,
|
||||||
|
down_ratio: int = 2,
|
||||||
|
up_kernel_size: int = 12,
|
||||||
|
down_kernel_size: int = 12,
|
||||||
|
fused: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.up_ratio = up_ratio
|
||||||
|
self.down_ratio = down_ratio
|
||||||
|
self.act = activation
|
||||||
|
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||||
|
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||||
|
|
||||||
|
self.fused = fused # Whether to use fused CUDA kernel or not
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if not self.fused:
|
||||||
|
x = self.upsample(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.downsample(x)
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
if self.act.__class__.__name__ == "Snake":
|
||||||
|
beta = self.act.alpha.data # Snake uses same params for alpha and beta
|
||||||
|
else:
|
||||||
|
beta = self.act.beta.data # Snakebeta uses different params for alpha and beta
|
||||||
|
alpha = self.act.alpha.data
|
||||||
|
if not self.act.alpha_logscale: # Exp baked into cuda kernel, cancel it out with a log
|
||||||
|
alpha = torch.log(alpha)
|
||||||
|
beta = torch.log(beta)
|
||||||
|
|
||||||
|
x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta)
|
||||||
|
return x
|
@ -0,0 +1,23 @@
|
|||||||
|
/* coding=utf-8
|
||||||
|
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
|
||||||
|
}
|
@ -0,0 +1,246 @@
|
|||||||
|
/* coding=utf-8
|
||||||
|
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cuda_profiler_api.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include "type_shim.h"
|
||||||
|
#include <assert.h>
|
||||||
|
#include <cfloat>
|
||||||
|
#include <limits>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <c10/macros/Macros.h>
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
// Hard-coded hyperparameters
|
||||||
|
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
||||||
|
constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
|
||||||
|
constexpr int BUFFER_SIZE = 32;
|
||||||
|
constexpr int FILTER_SIZE = 12;
|
||||||
|
constexpr int HALF_FILTER_SIZE = 6;
|
||||||
|
constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
|
||||||
|
constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
|
||||||
|
constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
|
||||||
|
|
||||||
|
template <typename input_t, typename output_t, typename acc_t>
|
||||||
|
__global__ void anti_alias_activation_forward(
|
||||||
|
output_t *dst,
|
||||||
|
const input_t *src,
|
||||||
|
const input_t *up_ftr,
|
||||||
|
const input_t *down_ftr,
|
||||||
|
const input_t *alpha,
|
||||||
|
const input_t *beta,
|
||||||
|
int batch_size,
|
||||||
|
int channels,
|
||||||
|
int seq_len)
|
||||||
|
{
|
||||||
|
// Up and downsample filters
|
||||||
|
input_t up_filter[FILTER_SIZE];
|
||||||
|
input_t down_filter[FILTER_SIZE];
|
||||||
|
|
||||||
|
// Load data from global memory including extra indices reserved for replication paddings
|
||||||
|
input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
|
||||||
|
input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
|
||||||
|
|
||||||
|
// Output stores downsampled output before writing to dst
|
||||||
|
output_t output[BUFFER_SIZE];
|
||||||
|
|
||||||
|
// blockDim/threadIdx = (128, 1, 1)
|
||||||
|
// gridDim/blockIdx = (seq_blocks, channels, batches)
|
||||||
|
int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
||||||
|
int local_offset = threadIdx.x * BUFFER_SIZE;
|
||||||
|
int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
|
||||||
|
|
||||||
|
// intermediate have double the seq_len
|
||||||
|
int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
|
||||||
|
int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
|
||||||
|
|
||||||
|
// Get values needed for replication padding before moving pointer
|
||||||
|
const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
||||||
|
input_t seq_left_most_value = right_most_pntr[0];
|
||||||
|
input_t seq_right_most_value = right_most_pntr[seq_len - 1];
|
||||||
|
|
||||||
|
// Move src and dst pointers
|
||||||
|
src += block_offset + local_offset;
|
||||||
|
dst += block_offset + local_offset;
|
||||||
|
|
||||||
|
// Alpha and beta values for snake activatons. Applies exp by default
|
||||||
|
alpha = alpha + blockIdx.y;
|
||||||
|
input_t alpha_val = expf(alpha[0]);
|
||||||
|
beta = beta + blockIdx.y;
|
||||||
|
input_t beta_val = expf(beta[0]);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < FILTER_SIZE; it += 1)
|
||||||
|
{
|
||||||
|
up_filter[it] = up_ftr[it];
|
||||||
|
down_filter[it] = down_ftr[it];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply replication padding for upsampling, matching torch impl
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
|
||||||
|
{
|
||||||
|
int element_index = seq_offset + it; // index for element
|
||||||
|
if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
|
||||||
|
{
|
||||||
|
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
|
||||||
|
}
|
||||||
|
if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
|
||||||
|
{
|
||||||
|
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
|
||||||
|
}
|
||||||
|
if ((element_index >= 0) && (element_index < seq_len))
|
||||||
|
{
|
||||||
|
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampilng conv later
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
|
||||||
|
{
|
||||||
|
input_t acc = 0.0;
|
||||||
|
int element_index = intermediate_seq_offset + it; // index for intermediate
|
||||||
|
#pragma unroll
|
||||||
|
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
||||||
|
{
|
||||||
|
if ((element_index + f_idx) >= 0)
|
||||||
|
{
|
||||||
|
acc += up_filter[f_idx] * elements[it + f_idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampilng conv later
|
||||||
|
double no_div_by_zero = 0.000000001;
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
|
||||||
|
{
|
||||||
|
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply replication padding before downsampling conv from intermediates
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
|
||||||
|
{
|
||||||
|
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
|
||||||
|
}
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
|
||||||
|
{
|
||||||
|
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply downsample strided convolution (assuming stride=2) from intermediates
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < BUFFER_SIZE; it += 1)
|
||||||
|
{
|
||||||
|
input_t acc = 0.0;
|
||||||
|
#pragma unroll
|
||||||
|
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
||||||
|
{
|
||||||
|
// Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
|
||||||
|
acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
|
||||||
|
}
|
||||||
|
output[it] = acc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write output to dst
|
||||||
|
#pragma unroll
|
||||||
|
for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
|
||||||
|
{
|
||||||
|
int element_index = seq_offset + it;
|
||||||
|
if (element_index < seq_len)
|
||||||
|
{
|
||||||
|
dst[it] = output[it];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename input_t, typename output_t, typename acc_t>
|
||||||
|
void dispatch_anti_alias_activation_forward(
|
||||||
|
output_t *dst,
|
||||||
|
const input_t *src,
|
||||||
|
const input_t *up_ftr,
|
||||||
|
const input_t *down_ftr,
|
||||||
|
const input_t *alpha,
|
||||||
|
const input_t *beta,
|
||||||
|
int batch_size,
|
||||||
|
int channels,
|
||||||
|
int seq_len)
|
||||||
|
{
|
||||||
|
if (seq_len == 0)
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Use 128 threads per block to maximimize gpu utilization
|
||||||
|
constexpr int threads_per_block = 128;
|
||||||
|
constexpr int seq_len_per_block = 4096;
|
||||||
|
int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
|
||||||
|
dim3 blocks(blocks_per_seq_len, channels, batch_size);
|
||||||
|
dim3 threads(threads_per_block, 1, 1);
|
||||||
|
|
||||||
|
anti_alias_activation_forward<input_t, output_t, acc_t>
|
||||||
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
|
||||||
|
{
|
||||||
|
// Input is a 3d tensor with dimensions [batches, channels, seq_len]
|
||||||
|
const int batches = input.size(0);
|
||||||
|
const int channels = input.size(1);
|
||||||
|
const int seq_len = input.size(2);
|
||||||
|
|
||||||
|
// Output
|
||||||
|
auto act_options = input.options().requires_grad(false);
|
||||||
|
|
||||||
|
torch::Tensor anti_alias_activation_results =
|
||||||
|
torch::empty({batches, channels, seq_len}, act_options);
|
||||||
|
|
||||||
|
void *input_ptr = static_cast<void *>(input.data_ptr());
|
||||||
|
void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
|
||||||
|
void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
|
||||||
|
void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
|
||||||
|
void *beta_ptr = static_cast<void *>(beta.data_ptr());
|
||||||
|
void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
|
||||||
|
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.scalar_type(),
|
||||||
|
"dispatch anti alias activation_forward",
|
||||||
|
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
|
||||||
|
reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
|
||||||
|
reinterpret_cast<const scalar_t *>(input_ptr),
|
||||||
|
reinterpret_cast<const scalar_t *>(up_filter_ptr),
|
||||||
|
reinterpret_cast<const scalar_t *>(down_filter_ptr),
|
||||||
|
reinterpret_cast<const scalar_t *>(alpha_ptr),
|
||||||
|
reinterpret_cast<const scalar_t *>(beta_ptr),
|
||||||
|
batches,
|
||||||
|
channels,
|
||||||
|
seq_len););
|
||||||
|
return anti_alias_activation_results;
|
||||||
|
}
|
1
GPT_SoVITS/BigVGAN/alias_free_activation/cuda/build/_
Normal file
1
GPT_SoVITS/BigVGAN/alias_free_activation/cuda/build/_
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
29
GPT_SoVITS/BigVGAN/alias_free_activation/cuda/compat.h
Normal file
29
GPT_SoVITS/BigVGAN/alias_free_activation/cuda/compat.h
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
/* coding=utf-8
|
||||||
|
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*This code is copied fron NVIDIA apex:
|
||||||
|
* https://github.com/NVIDIA/apex
|
||||||
|
* with minor changes. */
|
||||||
|
|
||||||
|
#ifndef TORCH_CHECK
|
||||||
|
#define TORCH_CHECK AT_CHECK
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef VERSION_GE_1_3
|
||||||
|
#define DATA_PTR data_ptr
|
||||||
|
#else
|
||||||
|
#define DATA_PTR data
|
||||||
|
#endif
|
82
GPT_SoVITS/BigVGAN/alias_free_activation/cuda/load.py
Normal file
82
GPT_SoVITS/BigVGAN/alias_free_activation/cuda/load.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from torch.utils import cpp_extension
|
||||||
|
|
||||||
|
"""
|
||||||
|
Setting this param to a list has a problem of generating different compilation commands (with diferent order of architectures) and leading to recompilation of fused kernels.
|
||||||
|
Set it to empty stringo avoid recompilation and assign arch flags explicity in extra_cuda_cflags below
|
||||||
|
"""
|
||||||
|
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
|
||||||
|
|
||||||
|
|
||||||
|
def load():
|
||||||
|
# Check if cuda 11 is installed for compute capability 8.0
|
||||||
|
cc_flag = []
|
||||||
|
_, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
|
||||||
|
if int(bare_metal_major) >= 11:
|
||||||
|
cc_flag.append("-gencode")
|
||||||
|
cc_flag.append("arch=compute_80,code=sm_80")
|
||||||
|
|
||||||
|
# Build path
|
||||||
|
srcpath = pathlib.Path(__file__).parent.absolute()
|
||||||
|
buildpath = srcpath / "build"
|
||||||
|
_create_build_dir(buildpath)
|
||||||
|
|
||||||
|
# Helper function to build the kernels.
|
||||||
|
def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
|
||||||
|
return cpp_extension.load(
|
||||||
|
name=name,
|
||||||
|
sources=sources,
|
||||||
|
build_directory=buildpath,
|
||||||
|
extra_cflags=[
|
||||||
|
"-O3",
|
||||||
|
],
|
||||||
|
extra_cuda_cflags=[
|
||||||
|
"-O3",
|
||||||
|
"-gencode",
|
||||||
|
"arch=compute_70,code=sm_70",
|
||||||
|
"--use_fast_math",
|
||||||
|
]
|
||||||
|
+ extra_cuda_flags
|
||||||
|
+ cc_flag,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
extra_cuda_flags = [
|
||||||
|
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||||
|
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
||||||
|
"--expt-relaxed-constexpr",
|
||||||
|
"--expt-extended-lambda",
|
||||||
|
]
|
||||||
|
|
||||||
|
sources = [
|
||||||
|
srcpath / "anti_alias_activation.cpp",
|
||||||
|
srcpath / "anti_alias_activation_cuda.cu",
|
||||||
|
]
|
||||||
|
anti_alias_activation_cuda = _cpp_extention_load_helper("anti_alias_activation_cuda", sources, extra_cuda_flags)
|
||||||
|
|
||||||
|
return anti_alias_activation_cuda
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cuda_bare_metal_version(cuda_dir):
|
||||||
|
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||||
|
output = raw_output.split()
|
||||||
|
release_idx = output.index("release") + 1
|
||||||
|
release = output[release_idx].split(".")
|
||||||
|
bare_metal_major = release[0]
|
||||||
|
bare_metal_minor = release[1][0]
|
||||||
|
|
||||||
|
return raw_output, bare_metal_major, bare_metal_minor
|
||||||
|
|
||||||
|
|
||||||
|
def _create_build_dir(buildpath):
|
||||||
|
try:
|
||||||
|
os.mkdir(buildpath)
|
||||||
|
except OSError:
|
||||||
|
if not os.path.isdir(buildpath):
|
||||||
|
print(f"Creation of the build directory {buildpath} failed")
|
92
GPT_SoVITS/BigVGAN/alias_free_activation/cuda/type_shim.h
Normal file
92
GPT_SoVITS/BigVGAN/alias_free_activation/cuda/type_shim.h
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
/* coding=utf-8
|
||||||
|
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
#include "compat.h"
|
||||||
|
|
||||||
|
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
||||||
|
switch (TYPE) \
|
||||||
|
{ \
|
||||||
|
case at::ScalarType::Float: \
|
||||||
|
{ \
|
||||||
|
using scalar_t = float; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::Half: \
|
||||||
|
{ \
|
||||||
|
using scalar_t = at::Half; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::BFloat16: \
|
||||||
|
{ \
|
||||||
|
using scalar_t = at::BFloat16; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
default: \
|
||||||
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
||||||
|
switch (TYPEIN) \
|
||||||
|
{ \
|
||||||
|
case at::ScalarType::Float: \
|
||||||
|
{ \
|
||||||
|
using scalar_t_in = float; \
|
||||||
|
switch (TYPEOUT) \
|
||||||
|
{ \
|
||||||
|
case at::ScalarType::Float: \
|
||||||
|
{ \
|
||||||
|
using scalar_t_out = float; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::Half: \
|
||||||
|
{ \
|
||||||
|
using scalar_t_out = at::Half; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::BFloat16: \
|
||||||
|
{ \
|
||||||
|
using scalar_t_out = at::BFloat16; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
default: \
|
||||||
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
|
||||||
|
} \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::Half: \
|
||||||
|
{ \
|
||||||
|
using scalar_t_in = at::Half; \
|
||||||
|
using scalar_t_out = at::Half; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::BFloat16: \
|
||||||
|
{ \
|
||||||
|
using scalar_t_in = at::BFloat16; \
|
||||||
|
using scalar_t_out = at::BFloat16; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
default: \
|
||||||
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
|
||||||
|
}
|
@ -0,0 +1,6 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
from .filter import *
|
||||||
|
from .resample import *
|
||||||
|
from .act import *
|
30
GPT_SoVITS/BigVGAN/alias_free_activation/torch/act.py
Normal file
30
GPT_SoVITS/BigVGAN/alias_free_activation/torch/act.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from .resample import UpSample1d, DownSample1d
|
||||||
|
|
||||||
|
|
||||||
|
class Activation1d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
activation,
|
||||||
|
up_ratio: int = 2,
|
||||||
|
down_ratio: int = 2,
|
||||||
|
up_kernel_size: int = 12,
|
||||||
|
down_kernel_size: int = 12,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.up_ratio = up_ratio
|
||||||
|
self.down_ratio = down_ratio
|
||||||
|
self.act = activation
|
||||||
|
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||||
|
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||||
|
|
||||||
|
# x: [B,C,T]
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.upsample(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.downsample(x)
|
||||||
|
|
||||||
|
return x
|
99
GPT_SoVITS/BigVGAN/alias_free_activation/torch/filter.py
Normal file
99
GPT_SoVITS/BigVGAN/alias_free_activation/torch/filter.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
|
||||||
|
if "sinc" in dir(torch):
|
||||||
|
sinc = torch.sinc
|
||||||
|
else:
|
||||||
|
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
||||||
|
# https://adefossez.github.io/julius/julius/core.html
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
def sinc(x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
||||||
|
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
||||||
|
"""
|
||||||
|
return torch.where(
|
||||||
|
x == 0,
|
||||||
|
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
||||||
|
torch.sin(math.pi * x) / math.pi / x,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||||
|
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||||
|
even = kernel_size % 2 == 0
|
||||||
|
half_size = kernel_size // 2
|
||||||
|
|
||||||
|
# For kaiser window
|
||||||
|
delta_f = 4 * half_width
|
||||||
|
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||||
|
if A > 50.0:
|
||||||
|
beta = 0.1102 * (A - 8.7)
|
||||||
|
elif A >= 21.0:
|
||||||
|
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
||||||
|
else:
|
||||||
|
beta = 0.0
|
||||||
|
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||||
|
|
||||||
|
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
||||||
|
if even:
|
||||||
|
time = torch.arange(-half_size, half_size) + 0.5
|
||||||
|
else:
|
||||||
|
time = torch.arange(kernel_size) - half_size
|
||||||
|
if cutoff == 0:
|
||||||
|
filter_ = torch.zeros_like(time)
|
||||||
|
else:
|
||||||
|
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
||||||
|
"""
|
||||||
|
Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
|
||||||
|
"""
|
||||||
|
filter_ /= filter_.sum()
|
||||||
|
filter = filter_.view(1, 1, kernel_size)
|
||||||
|
|
||||||
|
return filter
|
||||||
|
|
||||||
|
|
||||||
|
class LowPassFilter1d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cutoff=0.5,
|
||||||
|
half_width=0.6,
|
||||||
|
stride: int = 1,
|
||||||
|
padding: bool = True,
|
||||||
|
padding_mode: str = "replicate",
|
||||||
|
kernel_size: int = 12,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if cutoff < -0.0:
|
||||||
|
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||||
|
if cutoff > 0.5:
|
||||||
|
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.even = kernel_size % 2 == 0
|
||||||
|
self.pad_left = kernel_size // 2 - int(self.even)
|
||||||
|
self.pad_right = kernel_size // 2
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
self.padding_mode = padding_mode
|
||||||
|
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||||
|
self.register_buffer("filter", filter)
|
||||||
|
|
||||||
|
# Input [B, C, T]
|
||||||
|
def forward(self, x):
|
||||||
|
_, C, _ = x.shape
|
||||||
|
|
||||||
|
if self.padding:
|
||||||
|
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||||
|
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||||
|
|
||||||
|
return out
|
48
GPT_SoVITS/BigVGAN/alias_free_activation/torch/resample.py
Normal file
48
GPT_SoVITS/BigVGAN/alias_free_activation/torch/resample.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from .filter import LowPassFilter1d
|
||||||
|
from .filter import kaiser_sinc_filter1d
|
||||||
|
|
||||||
|
|
||||||
|
class UpSample1d(nn.Module):
|
||||||
|
def __init__(self, ratio=2, kernel_size=None):
|
||||||
|
super().__init__()
|
||||||
|
self.ratio = ratio
|
||||||
|
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||||
|
self.stride = ratio
|
||||||
|
self.pad = self.kernel_size // ratio - 1
|
||||||
|
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||||
|
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||||
|
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
|
||||||
|
self.register_buffer("filter", filter)
|
||||||
|
|
||||||
|
# x: [B, C, T]
|
||||||
|
def forward(self, x):
|
||||||
|
_, C, _ = x.shape
|
||||||
|
|
||||||
|
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||||
|
x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||||
|
x = x[..., self.pad_left : -self.pad_right]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DownSample1d(nn.Module):
|
||||||
|
def __init__(self, ratio=2, kernel_size=None):
|
||||||
|
super().__init__()
|
||||||
|
self.ratio = ratio
|
||||||
|
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||||
|
self.lowpass = LowPassFilter1d(
|
||||||
|
cutoff=0.5 / ratio,
|
||||||
|
half_width=0.6 / ratio,
|
||||||
|
stride=ratio,
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
xx = self.lowpass(x)
|
||||||
|
|
||||||
|
return xx
|
461
GPT_SoVITS/BigVGAN/bigvgan.py
Normal file
461
GPT_SoVITS/BigVGAN/bigvgan.py
Normal file
@ -0,0 +1,461 @@
|
|||||||
|
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Union, Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import Conv1d, ConvTranspose1d
|
||||||
|
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||||
|
|
||||||
|
from . import activations
|
||||||
|
from .utils0 import init_weights, get_padding
|
||||||
|
from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
|
||||||
|
from .env import AttrDict
|
||||||
|
|
||||||
|
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
||||||
|
|
||||||
|
|
||||||
|
def load_hparams_from_json(path) -> AttrDict:
|
||||||
|
with open(path) as f:
|
||||||
|
data = f.read()
|
||||||
|
return AttrDict(json.loads(data))
|
||||||
|
|
||||||
|
|
||||||
|
class AMPBlock1(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
||||||
|
AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
|
||||||
|
|
||||||
|
Args:
|
||||||
|
h (AttrDict): Hyperparameters.
|
||||||
|
channels (int): Number of convolution channels.
|
||||||
|
kernel_size (int): Size of the convolution kernel. Default is 3.
|
||||||
|
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
||||||
|
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
h: AttrDict,
|
||||||
|
channels: int,
|
||||||
|
kernel_size: int = 3,
|
||||||
|
dilation: tuple = (1, 3, 5),
|
||||||
|
activation: str = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.h = h
|
||||||
|
|
||||||
|
self.convs1 = nn.ModuleList(
|
||||||
|
[
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=d,
|
||||||
|
padding=get_padding(kernel_size, d),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for d in dilation
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.convs1.apply(init_weights)
|
||||||
|
|
||||||
|
self.convs2 = nn.ModuleList(
|
||||||
|
[
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for _ in range(len(dilation))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.convs2.apply(init_weights)
|
||||||
|
|
||||||
|
self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers
|
||||||
|
|
||||||
|
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||||
|
if self.h.get("use_cuda_kernel", False):
|
||||||
|
from .alias_free_activation.cuda.activation1d import (
|
||||||
|
Activation1d as CudaActivation1d,
|
||||||
|
)
|
||||||
|
|
||||||
|
Activation1d = CudaActivation1d
|
||||||
|
else:
|
||||||
|
Activation1d = TorchActivation1d
|
||||||
|
|
||||||
|
# Activation functions
|
||||||
|
if activation == "snake":
|
||||||
|
self.activations = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif activation == "snakebeta":
|
||||||
|
self.activations = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
||||||
|
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
||||||
|
xt = a1(x)
|
||||||
|
xt = c1(xt)
|
||||||
|
xt = a2(xt)
|
||||||
|
xt = c2(xt)
|
||||||
|
x = xt + x
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
for l in self.convs1:
|
||||||
|
remove_weight_norm(l)
|
||||||
|
for l in self.convs2:
|
||||||
|
remove_weight_norm(l)
|
||||||
|
|
||||||
|
|
||||||
|
class AMPBlock2(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
||||||
|
Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1
|
||||||
|
|
||||||
|
Args:
|
||||||
|
h (AttrDict): Hyperparameters.
|
||||||
|
channels (int): Number of convolution channels.
|
||||||
|
kernel_size (int): Size of the convolution kernel. Default is 3.
|
||||||
|
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
||||||
|
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
h: AttrDict,
|
||||||
|
channels: int,
|
||||||
|
kernel_size: int = 3,
|
||||||
|
dilation: tuple = (1, 3, 5),
|
||||||
|
activation: str = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.h = h
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
dilation=d,
|
||||||
|
padding=get_padding(kernel_size, d),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for d in dilation
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.convs.apply(init_weights)
|
||||||
|
|
||||||
|
self.num_layers = len(self.convs) # Total number of conv layers
|
||||||
|
|
||||||
|
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||||
|
if self.h.get("use_cuda_kernel", False):
|
||||||
|
from .alias_free_activation.cuda.activation1d import (
|
||||||
|
Activation1d as CudaActivation1d,
|
||||||
|
)
|
||||||
|
|
||||||
|
Activation1d = CudaActivation1d
|
||||||
|
else:
|
||||||
|
Activation1d = TorchActivation1d
|
||||||
|
|
||||||
|
# Activation functions
|
||||||
|
if activation == "snake":
|
||||||
|
self.activations = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif activation == "snakebeta":
|
||||||
|
self.activations = nn.ModuleList(
|
||||||
|
[
|
||||||
|
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for c, a in zip(self.convs, self.activations):
|
||||||
|
xt = a(x)
|
||||||
|
xt = c(xt)
|
||||||
|
x = xt + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
for l in self.convs:
|
||||||
|
remove_weight_norm(l)
|
||||||
|
|
||||||
|
|
||||||
|
class BigVGAN(
|
||||||
|
torch.nn.Module,
|
||||||
|
PyTorchModelHubMixin,
|
||||||
|
# library_name="bigvgan",
|
||||||
|
# repo_url="https://github.com/NVIDIA/BigVGAN",
|
||||||
|
# docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
|
||||||
|
# pipeline_tag="audio-to-audio",
|
||||||
|
# license="mit",
|
||||||
|
# tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
|
||||||
|
New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
h (AttrDict): Hyperparameters.
|
||||||
|
use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
|
||||||
|
- Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.h = h
|
||||||
|
self.h["use_cuda_kernel"] = use_cuda_kernel
|
||||||
|
|
||||||
|
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||||
|
if self.h.get("use_cuda_kernel", False):
|
||||||
|
from .alias_free_activation.cuda.activation1d import (
|
||||||
|
Activation1d as CudaActivation1d,
|
||||||
|
)
|
||||||
|
|
||||||
|
Activation1d = CudaActivation1d
|
||||||
|
else:
|
||||||
|
Activation1d = TorchActivation1d
|
||||||
|
|
||||||
|
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||||
|
self.num_upsamples = len(h.upsample_rates)
|
||||||
|
|
||||||
|
# Pre-conv
|
||||||
|
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
|
||||||
|
|
||||||
|
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
||||||
|
if h.resblock == "1":
|
||||||
|
resblock_class = AMPBlock1
|
||||||
|
elif h.resblock == "2":
|
||||||
|
resblock_class = AMPBlock2
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")
|
||||||
|
|
||||||
|
# Transposed conv-based upsamplers. does not apply anti-aliasing
|
||||||
|
self.ups = nn.ModuleList()
|
||||||
|
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||||
|
self.ups.append(
|
||||||
|
nn.ModuleList(
|
||||||
|
[
|
||||||
|
weight_norm(
|
||||||
|
ConvTranspose1d(
|
||||||
|
h.upsample_initial_channel // (2**i),
|
||||||
|
h.upsample_initial_channel // (2 ** (i + 1)),
|
||||||
|
k,
|
||||||
|
u,
|
||||||
|
padding=(k - u) // 2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
||||||
|
self.resblocks = nn.ModuleList()
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
||||||
|
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||||
|
self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))
|
||||||
|
|
||||||
|
# Post-conv
|
||||||
|
activation_post = (
|
||||||
|
activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||||
|
if h.activation == "snake"
|
||||||
|
else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None)
|
||||||
|
)
|
||||||
|
if activation_post is None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.activation_post = Activation1d(activation=activation_post)
|
||||||
|
|
||||||
|
# Whether to use bias for the final conv_post. Default to True for backward compatibility
|
||||||
|
self.use_bias_at_final = h.get("use_bias_at_final", True)
|
||||||
|
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))
|
||||||
|
|
||||||
|
# Weight initialization
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
self.ups[i].apply(init_weights)
|
||||||
|
self.conv_post.apply(init_weights)
|
||||||
|
|
||||||
|
# Final tanh activation. Defaults to True for backward compatibility
|
||||||
|
self.use_tanh_at_final = h.get("use_tanh_at_final", True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# Pre-conv
|
||||||
|
x = self.conv_pre(x)
|
||||||
|
|
||||||
|
for i in range(self.num_upsamples):
|
||||||
|
# Upsampling
|
||||||
|
for i_up in range(len(self.ups[i])):
|
||||||
|
x = self.ups[i][i_up](x)
|
||||||
|
# AMP blocks
|
||||||
|
xs = None
|
||||||
|
for j in range(self.num_kernels):
|
||||||
|
if xs is None:
|
||||||
|
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
else:
|
||||||
|
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
x = xs / self.num_kernels
|
||||||
|
|
||||||
|
# Post-conv
|
||||||
|
x = self.activation_post(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
# Final tanh activation
|
||||||
|
if self.use_tanh_at_final:
|
||||||
|
x = torch.tanh(x)
|
||||||
|
else:
|
||||||
|
x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
try:
|
||||||
|
# print("Removing weight norm...")
|
||||||
|
for l in self.ups:
|
||||||
|
for l_i in l:
|
||||||
|
remove_weight_norm(l_i)
|
||||||
|
for l in self.resblocks:
|
||||||
|
l.remove_weight_norm()
|
||||||
|
remove_weight_norm(self.conv_pre)
|
||||||
|
remove_weight_norm(self.conv_post)
|
||||||
|
except ValueError:
|
||||||
|
print("[INFO] Model already removed weight norm. Skipping!")
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Additional methods for huggingface_hub support
|
||||||
|
def _save_pretrained(self, save_directory: Path) -> None:
|
||||||
|
"""Save weights and config.json from a Pytorch model to a local directory."""
|
||||||
|
|
||||||
|
model_path = save_directory / "bigvgan_generator.pt"
|
||||||
|
torch.save({"generator": self.state_dict()}, model_path)
|
||||||
|
|
||||||
|
config_path = save_directory / "config.json"
|
||||||
|
with open(config_path, "w") as config_file:
|
||||||
|
json.dump(self.h, config_file, indent=4)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _from_pretrained(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
model_id: str,
|
||||||
|
revision: str,
|
||||||
|
cache_dir: str,
|
||||||
|
force_download: bool,
|
||||||
|
proxies: Optional[Dict],
|
||||||
|
resume_download: bool,
|
||||||
|
local_files_only: bool,
|
||||||
|
token: Union[str, bool, None],
|
||||||
|
map_location: str = "cpu", # Additional argument
|
||||||
|
strict: bool = False, # Additional argument
|
||||||
|
use_cuda_kernel: bool = False,
|
||||||
|
**model_kwargs,
|
||||||
|
):
|
||||||
|
"""Load Pytorch pretrained weights and return the loaded model."""
|
||||||
|
|
||||||
|
# Download and load hyperparameters (h) used by BigVGAN
|
||||||
|
if os.path.isdir(model_id):
|
||||||
|
# print("Loading config.json from local directory")
|
||||||
|
config_file = os.path.join(model_id, "config.json")
|
||||||
|
else:
|
||||||
|
config_file = hf_hub_download(
|
||||||
|
repo_id=model_id,
|
||||||
|
filename="config.json",
|
||||||
|
revision=revision,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
token=token,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
h = load_hparams_from_json(config_file)
|
||||||
|
|
||||||
|
# instantiate BigVGAN using h
|
||||||
|
if use_cuda_kernel:
|
||||||
|
print(
|
||||||
|
"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
|
||||||
|
)
|
||||||
|
model = cls(h, use_cuda_kernel=use_cuda_kernel)
|
||||||
|
|
||||||
|
# Download and load pretrained generator weight
|
||||||
|
if os.path.isdir(model_id):
|
||||||
|
# print("Loading weights from local directory")
|
||||||
|
model_file = os.path.join(model_id, "bigvgan_generator.pt")
|
||||||
|
else:
|
||||||
|
# print(f"Loading weights from {model_id}")
|
||||||
|
model_file = hf_hub_download(
|
||||||
|
repo_id=model_id,
|
||||||
|
filename="bigvgan_generator.pt",
|
||||||
|
revision=revision,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
resume_download=resume_download,
|
||||||
|
token=token,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoint_dict = torch.load(model_file, map_location=map_location)
|
||||||
|
|
||||||
|
try:
|
||||||
|
model.load_state_dict(checkpoint_dict["generator"])
|
||||||
|
except RuntimeError:
|
||||||
|
print(
|
||||||
|
"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
||||||
|
)
|
||||||
|
model.remove_weight_norm()
|
||||||
|
model.load_state_dict(checkpoint_dict["generator"])
|
||||||
|
|
||||||
|
return model
|
45
GPT_SoVITS/BigVGAN/configs/bigvgan_22khz_80band.json
Normal file
45
GPT_SoVITS/BigVGAN/configs/bigvgan_22khz_80band.json
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
{
|
||||||
|
"resblock": "1",
|
||||||
|
"num_gpus": 0,
|
||||||
|
"batch_size": 32,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"adam_b1": 0.8,
|
||||||
|
"adam_b2": 0.99,
|
||||||
|
"lr_decay": 0.9999996,
|
||||||
|
"seed": 1234,
|
||||||
|
|
||||||
|
"upsample_rates": [4,4,2,2,2,2],
|
||||||
|
"upsample_kernel_sizes": [8,8,4,4,4,4],
|
||||||
|
"upsample_initial_channel": 1536,
|
||||||
|
"resblock_kernel_sizes": [3,7,11],
|
||||||
|
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||||
|
|
||||||
|
"activation": "snakebeta",
|
||||||
|
"snake_logscale": true,
|
||||||
|
|
||||||
|
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
|
||||||
|
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||||
|
"use_spectral_norm": false,
|
||||||
|
"discriminator_channel_mult": 1,
|
||||||
|
|
||||||
|
"segment_size": 8192,
|
||||||
|
"num_mels": 80,
|
||||||
|
"num_freq": 1025,
|
||||||
|
"n_fft": 1024,
|
||||||
|
"hop_size": 256,
|
||||||
|
"win_size": 1024,
|
||||||
|
|
||||||
|
"sampling_rate": 22050,
|
||||||
|
|
||||||
|
"fmin": 0,
|
||||||
|
"fmax": 8000,
|
||||||
|
"fmax_for_loss": null,
|
||||||
|
|
||||||
|
"num_workers": 4,
|
||||||
|
|
||||||
|
"dist_config": {
|
||||||
|
"dist_backend": "nccl",
|
||||||
|
"dist_url": "tcp://localhost:54321",
|
||||||
|
"world_size": 1
|
||||||
|
}
|
||||||
|
}
|
45
GPT_SoVITS/BigVGAN/configs/bigvgan_24khz_100band.json
Normal file
45
GPT_SoVITS/BigVGAN/configs/bigvgan_24khz_100band.json
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
{
|
||||||
|
"resblock": "1",
|
||||||
|
"num_gpus": 0,
|
||||||
|
"batch_size": 32,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"adam_b1": 0.8,
|
||||||
|
"adam_b2": 0.99,
|
||||||
|
"lr_decay": 0.9999996,
|
||||||
|
"seed": 1234,
|
||||||
|
|
||||||
|
"upsample_rates": [4,4,2,2,2,2],
|
||||||
|
"upsample_kernel_sizes": [8,8,4,4,4,4],
|
||||||
|
"upsample_initial_channel": 1536,
|
||||||
|
"resblock_kernel_sizes": [3,7,11],
|
||||||
|
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||||
|
|
||||||
|
"activation": "snakebeta",
|
||||||
|
"snake_logscale": true,
|
||||||
|
|
||||||
|
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
|
||||||
|
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||||
|
"use_spectral_norm": false,
|
||||||
|
"discriminator_channel_mult": 1,
|
||||||
|
|
||||||
|
"segment_size": 8192,
|
||||||
|
"num_mels": 100,
|
||||||
|
"num_freq": 1025,
|
||||||
|
"n_fft": 1024,
|
||||||
|
"hop_size": 256,
|
||||||
|
"win_size": 1024,
|
||||||
|
|
||||||
|
"sampling_rate": 24000,
|
||||||
|
|
||||||
|
"fmin": 0,
|
||||||
|
"fmax": 12000,
|
||||||
|
"fmax_for_loss": null,
|
||||||
|
|
||||||
|
"num_workers": 4,
|
||||||
|
|
||||||
|
"dist_config": {
|
||||||
|
"dist_backend": "nccl",
|
||||||
|
"dist_url": "tcp://localhost:54321",
|
||||||
|
"world_size": 1
|
||||||
|
}
|
||||||
|
}
|
45
GPT_SoVITS/BigVGAN/configs/bigvgan_base_22khz_80band.json
Normal file
45
GPT_SoVITS/BigVGAN/configs/bigvgan_base_22khz_80band.json
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
{
|
||||||
|
"resblock": "1",
|
||||||
|
"num_gpus": 0,
|
||||||
|
"batch_size": 32,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"adam_b1": 0.8,
|
||||||
|
"adam_b2": 0.99,
|
||||||
|
"lr_decay": 0.9999996,
|
||||||
|
"seed": 1234,
|
||||||
|
|
||||||
|
"upsample_rates": [8,8,2,2],
|
||||||
|
"upsample_kernel_sizes": [16,16,4,4],
|
||||||
|
"upsample_initial_channel": 512,
|
||||||
|
"resblock_kernel_sizes": [3,7,11],
|
||||||
|
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||||
|
|
||||||
|
"activation": "snakebeta",
|
||||||
|
"snake_logscale": true,
|
||||||
|
|
||||||
|
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
|
||||||
|
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||||
|
"use_spectral_norm": false,
|
||||||
|
"discriminator_channel_mult": 1,
|
||||||
|
|
||||||
|
"segment_size": 8192,
|
||||||
|
"num_mels": 80,
|
||||||
|
"num_freq": 1025,
|
||||||
|
"n_fft": 1024,
|
||||||
|
"hop_size": 256,
|
||||||
|
"win_size": 1024,
|
||||||
|
|
||||||
|
"sampling_rate": 22050,
|
||||||
|
|
||||||
|
"fmin": 0,
|
||||||
|
"fmax": 8000,
|
||||||
|
"fmax_for_loss": null,
|
||||||
|
|
||||||
|
"num_workers": 4,
|
||||||
|
|
||||||
|
"dist_config": {
|
||||||
|
"dist_backend": "nccl",
|
||||||
|
"dist_url": "tcp://localhost:54321",
|
||||||
|
"world_size": 1
|
||||||
|
}
|
||||||
|
}
|
45
GPT_SoVITS/BigVGAN/configs/bigvgan_base_24khz_100band.json
Normal file
45
GPT_SoVITS/BigVGAN/configs/bigvgan_base_24khz_100band.json
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
{
|
||||||
|
"resblock": "1",
|
||||||
|
"num_gpus": 0,
|
||||||
|
"batch_size": 32,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"adam_b1": 0.8,
|
||||||
|
"adam_b2": 0.99,
|
||||||
|
"lr_decay": 0.9999996,
|
||||||
|
"seed": 1234,
|
||||||
|
|
||||||
|
"upsample_rates": [8,8,2,2],
|
||||||
|
"upsample_kernel_sizes": [16,16,4,4],
|
||||||
|
"upsample_initial_channel": 512,
|
||||||
|
"resblock_kernel_sizes": [3,7,11],
|
||||||
|
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||||
|
|
||||||
|
"activation": "snakebeta",
|
||||||
|
"snake_logscale": true,
|
||||||
|
|
||||||
|
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
|
||||||
|
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||||
|
"use_spectral_norm": false,
|
||||||
|
"discriminator_channel_mult": 1,
|
||||||
|
|
||||||
|
"segment_size": 8192,
|
||||||
|
"num_mels": 100,
|
||||||
|
"num_freq": 1025,
|
||||||
|
"n_fft": 1024,
|
||||||
|
"hop_size": 256,
|
||||||
|
"win_size": 1024,
|
||||||
|
|
||||||
|
"sampling_rate": 24000,
|
||||||
|
|
||||||
|
"fmin": 0,
|
||||||
|
"fmax": 12000,
|
||||||
|
"fmax_for_loss": null,
|
||||||
|
|
||||||
|
"num_workers": 4,
|
||||||
|
|
||||||
|
"dist_config": {
|
||||||
|
"dist_backend": "nccl",
|
||||||
|
"dist_url": "tcp://localhost:54321",
|
||||||
|
"world_size": 1
|
||||||
|
}
|
||||||
|
}
|
61
GPT_SoVITS/BigVGAN/configs/bigvgan_v2_22khz_80band_256x.json
Normal file
61
GPT_SoVITS/BigVGAN/configs/bigvgan_v2_22khz_80band_256x.json
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
{
|
||||||
|
"resblock": "1",
|
||||||
|
"num_gpus": 0,
|
||||||
|
"batch_size": 4,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"adam_b1": 0.8,
|
||||||
|
"adam_b2": 0.99,
|
||||||
|
"lr_decay": 0.9999996,
|
||||||
|
"seed": 1234,
|
||||||
|
|
||||||
|
"upsample_rates": [4,4,2,2,2,2],
|
||||||
|
"upsample_kernel_sizes": [8,8,4,4,4,4],
|
||||||
|
"upsample_initial_channel": 1536,
|
||||||
|
"resblock_kernel_sizes": [3,7,11],
|
||||||
|
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||||
|
|
||||||
|
"use_tanh_at_final": false,
|
||||||
|
"use_bias_at_final": false,
|
||||||
|
|
||||||
|
"activation": "snakebeta",
|
||||||
|
"snake_logscale": true,
|
||||||
|
|
||||||
|
"use_cqtd_instead_of_mrd": true,
|
||||||
|
"cqtd_filters": 128,
|
||||||
|
"cqtd_max_filters": 1024,
|
||||||
|
"cqtd_filters_scale": 1,
|
||||||
|
"cqtd_dilations": [1, 2, 4],
|
||||||
|
"cqtd_hop_lengths": [512, 256, 256],
|
||||||
|
"cqtd_n_octaves": [9, 9, 9],
|
||||||
|
"cqtd_bins_per_octaves": [24, 36, 48],
|
||||||
|
|
||||||
|
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||||
|
"use_spectral_norm": false,
|
||||||
|
"discriminator_channel_mult": 1,
|
||||||
|
|
||||||
|
"use_multiscale_melloss": true,
|
||||||
|
"lambda_melloss": 15,
|
||||||
|
|
||||||
|
"clip_grad_norm": 500,
|
||||||
|
|
||||||
|
"segment_size": 65536,
|
||||||
|
"num_mels": 80,
|
||||||
|
"num_freq": 1025,
|
||||||
|
"n_fft": 1024,
|
||||||
|
"hop_size": 256,
|
||||||
|
"win_size": 1024,
|
||||||
|
|
||||||
|
"sampling_rate": 22050,
|
||||||
|
|
||||||
|
"fmin": 0,
|
||||||
|
"fmax": null,
|
||||||
|
"fmax_for_loss": null,
|
||||||
|
|
||||||
|
"num_workers": 4,
|
||||||
|
|
||||||
|
"dist_config": {
|
||||||
|
"dist_backend": "nccl",
|
||||||
|
"dist_url": "tcp://localhost:54321",
|
||||||
|
"world_size": 1
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,61 @@
|
|||||||
|
{
|
||||||
|
"resblock": "1",
|
||||||
|
"num_gpus": 0,
|
||||||
|
"batch_size": 4,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"adam_b1": 0.8,
|
||||||
|
"adam_b2": 0.99,
|
||||||
|
"lr_decay": 0.9999996,
|
||||||
|
"seed": 1234,
|
||||||
|
|
||||||
|
"upsample_rates": [4,4,2,2,2,2],
|
||||||
|
"upsample_kernel_sizes": [8,8,4,4,4,4],
|
||||||
|
"upsample_initial_channel": 1536,
|
||||||
|
"resblock_kernel_sizes": [3,7,11],
|
||||||
|
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||||
|
|
||||||
|
"use_tanh_at_final": false,
|
||||||
|
"use_bias_at_final": false,
|
||||||
|
|
||||||
|
"activation": "snakebeta",
|
||||||
|
"snake_logscale": true,
|
||||||
|
|
||||||
|
"use_cqtd_instead_of_mrd": true,
|
||||||
|
"cqtd_filters": 128,
|
||||||
|
"cqtd_max_filters": 1024,
|
||||||
|
"cqtd_filters_scale": 1,
|
||||||
|
"cqtd_dilations": [1, 2, 4],
|
||||||
|
"cqtd_hop_lengths": [512, 256, 256],
|
||||||
|
"cqtd_n_octaves": [9, 9, 9],
|
||||||
|
"cqtd_bins_per_octaves": [24, 36, 48],
|
||||||
|
|
||||||
|
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||||
|
"use_spectral_norm": false,
|
||||||
|
"discriminator_channel_mult": 1,
|
||||||
|
|
||||||
|
"use_multiscale_melloss": true,
|
||||||
|
"lambda_melloss": 15,
|
||||||
|
|
||||||
|
"clip_grad_norm": 500,
|
||||||
|
|
||||||
|
"segment_size": 65536,
|
||||||
|
"num_mels": 80,
|
||||||
|
"num_freq": 1025,
|
||||||
|
"n_fft": 1024,
|
||||||
|
"hop_size": 256,
|
||||||
|
"win_size": 1024,
|
||||||
|
|
||||||
|
"sampling_rate": 22050,
|
||||||
|
|
||||||
|
"fmin": 0,
|
||||||
|
"fmax": 8000,
|
||||||
|
"fmax_for_loss": null,
|
||||||
|
|
||||||
|
"num_workers": 4,
|
||||||
|
|
||||||
|
"dist_config": {
|
||||||
|
"dist_backend": "nccl",
|
||||||
|
"dist_url": "tcp://localhost:54321",
|
||||||
|
"world_size": 1
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,61 @@
|
|||||||
|
{
|
||||||
|
"resblock": "1",
|
||||||
|
"num_gpus": 0,
|
||||||
|
"batch_size": 4,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"adam_b1": 0.8,
|
||||||
|
"adam_b2": 0.99,
|
||||||
|
"lr_decay": 0.9999996,
|
||||||
|
"seed": 1234,
|
||||||
|
|
||||||
|
"upsample_rates": [4,4,2,2,2,2],
|
||||||
|
"upsample_kernel_sizes": [8,8,4,4,4,4],
|
||||||
|
"upsample_initial_channel": 1536,
|
||||||
|
"resblock_kernel_sizes": [3,7,11],
|
||||||
|
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||||
|
|
||||||
|
"use_tanh_at_final": false,
|
||||||
|
"use_bias_at_final": false,
|
||||||
|
|
||||||
|
"activation": "snakebeta",
|
||||||
|
"snake_logscale": true,
|
||||||
|
|
||||||
|
"use_cqtd_instead_of_mrd": true,
|
||||||
|
"cqtd_filters": 128,
|
||||||
|
"cqtd_max_filters": 1024,
|
||||||
|
"cqtd_filters_scale": 1,
|
||||||
|
"cqtd_dilations": [1, 2, 4],
|
||||||
|
"cqtd_hop_lengths": [512, 256, 256],
|
||||||
|
"cqtd_n_octaves": [9, 9, 9],
|
||||||
|
"cqtd_bins_per_octaves": [24, 36, 48],
|
||||||
|
|
||||||
|
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||||
|
"use_spectral_norm": false,
|
||||||
|
"discriminator_channel_mult": 1,
|
||||||
|
|
||||||
|
"use_multiscale_melloss": true,
|
||||||
|
"lambda_melloss": 15,
|
||||||
|
|
||||||
|
"clip_grad_norm": 500,
|
||||||
|
|
||||||
|
"segment_size": 65536,
|
||||||
|
"num_mels": 100,
|
||||||
|
"num_freq": 1025,
|
||||||
|
"n_fft": 1024,
|
||||||
|
"hop_size": 256,
|
||||||
|
"win_size": 1024,
|
||||||
|
|
||||||
|
"sampling_rate": 24000,
|
||||||
|
|
||||||
|
"fmin": 0,
|
||||||
|
"fmax": null,
|
||||||
|
"fmax_for_loss": null,
|
||||||
|
|
||||||
|
"num_workers": 4,
|
||||||
|
|
||||||
|
"dist_config": {
|
||||||
|
"dist_backend": "nccl",
|
||||||
|
"dist_url": "tcp://localhost:54321",
|
||||||
|
"world_size": 1
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,61 @@
|
|||||||
|
{
|
||||||
|
"resblock": "1",
|
||||||
|
"num_gpus": 0,
|
||||||
|
"batch_size": 4,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"adam_b1": 0.8,
|
||||||
|
"adam_b2": 0.99,
|
||||||
|
"lr_decay": 0.9999996,
|
||||||
|
"seed": 1234,
|
||||||
|
|
||||||
|
"upsample_rates": [4,4,2,2,2,2],
|
||||||
|
"upsample_kernel_sizes": [8,8,4,4,4,4],
|
||||||
|
"upsample_initial_channel": 1536,
|
||||||
|
"resblock_kernel_sizes": [3,7,11],
|
||||||
|
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||||
|
|
||||||
|
"use_tanh_at_final": false,
|
||||||
|
"use_bias_at_final": false,
|
||||||
|
|
||||||
|
"activation": "snakebeta",
|
||||||
|
"snake_logscale": true,
|
||||||
|
|
||||||
|
"use_cqtd_instead_of_mrd": true,
|
||||||
|
"cqtd_filters": 128,
|
||||||
|
"cqtd_max_filters": 1024,
|
||||||
|
"cqtd_filters_scale": 1,
|
||||||
|
"cqtd_dilations": [1, 2, 4],
|
||||||
|
"cqtd_hop_lengths": [512, 256, 256],
|
||||||
|
"cqtd_n_octaves": [9, 9, 9],
|
||||||
|
"cqtd_bins_per_octaves": [24, 36, 48],
|
||||||
|
|
||||||
|
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||||
|
"use_spectral_norm": false,
|
||||||
|
"discriminator_channel_mult": 1,
|
||||||
|
|
||||||
|
"use_multiscale_melloss": true,
|
||||||
|
"lambda_melloss": 15,
|
||||||
|
|
||||||
|
"clip_grad_norm": 500,
|
||||||
|
|
||||||
|
"segment_size": 65536,
|
||||||
|
"num_mels": 128,
|
||||||
|
"num_freq": 1025,
|
||||||
|
"n_fft": 1024,
|
||||||
|
"hop_size": 256,
|
||||||
|
"win_size": 1024,
|
||||||
|
|
||||||
|
"sampling_rate": 44100,
|
||||||
|
|
||||||
|
"fmin": 0,
|
||||||
|
"fmax": null,
|
||||||
|
"fmax_for_loss": null,
|
||||||
|
|
||||||
|
"num_workers": 4,
|
||||||
|
|
||||||
|
"dist_config": {
|
||||||
|
"dist_backend": "nccl",
|
||||||
|
"dist_url": "tcp://localhost:54321",
|
||||||
|
"world_size": 1
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,61 @@
|
|||||||
|
{
|
||||||
|
"resblock": "1",
|
||||||
|
"num_gpus": 0,
|
||||||
|
"batch_size": 4,
|
||||||
|
"learning_rate": 0.0001,
|
||||||
|
"adam_b1": 0.8,
|
||||||
|
"adam_b2": 0.99,
|
||||||
|
"lr_decay": 0.9999996,
|
||||||
|
"seed": 1234,
|
||||||
|
|
||||||
|
"upsample_rates": [8,4,2,2,2,2],
|
||||||
|
"upsample_kernel_sizes": [16,8,4,4,4,4],
|
||||||
|
"upsample_initial_channel": 1536,
|
||||||
|
"resblock_kernel_sizes": [3,7,11],
|
||||||
|
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
||||||
|
|
||||||
|
"use_tanh_at_final": false,
|
||||||
|
"use_bias_at_final": false,
|
||||||
|
|
||||||
|
"activation": "snakebeta",
|
||||||
|
"snake_logscale": true,
|
||||||
|
|
||||||
|
"use_cqtd_instead_of_mrd": true,
|
||||||
|
"cqtd_filters": 128,
|
||||||
|
"cqtd_max_filters": 1024,
|
||||||
|
"cqtd_filters_scale": 1,
|
||||||
|
"cqtd_dilations": [1, 2, 4],
|
||||||
|
"cqtd_hop_lengths": [512, 256, 256],
|
||||||
|
"cqtd_n_octaves": [9, 9, 9],
|
||||||
|
"cqtd_bins_per_octaves": [24, 36, 48],
|
||||||
|
|
||||||
|
"mpd_reshapes": [2, 3, 5, 7, 11],
|
||||||
|
"use_spectral_norm": false,
|
||||||
|
"discriminator_channel_mult": 1,
|
||||||
|
|
||||||
|
"use_multiscale_melloss": true,
|
||||||
|
"lambda_melloss": 15,
|
||||||
|
|
||||||
|
"clip_grad_norm": 500,
|
||||||
|
|
||||||
|
"segment_size": 65536,
|
||||||
|
"num_mels": 128,
|
||||||
|
"num_freq": 2049,
|
||||||
|
"n_fft": 2048,
|
||||||
|
"hop_size": 512,
|
||||||
|
"win_size": 2048,
|
||||||
|
|
||||||
|
"sampling_rate": 44100,
|
||||||
|
|
||||||
|
"fmin": 0,
|
||||||
|
"fmax": null,
|
||||||
|
"fmax_for_loss": null,
|
||||||
|
|
||||||
|
"num_workers": 4,
|
||||||
|
|
||||||
|
"dist_config": {
|
||||||
|
"dist_backend": "nccl",
|
||||||
|
"dist_url": "tcp://localhost:54321",
|
||||||
|
"world_size": 1
|
||||||
|
}
|
||||||
|
}
|
625
GPT_SoVITS/BigVGAN/discriminators.py
Normal file
625
GPT_SoVITS/BigVGAN/discriminators.py
Normal file
@ -0,0 +1,625 @@
|
|||||||
|
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import Conv2d
|
||||||
|
from torch.nn.utils import weight_norm, spectral_norm
|
||||||
|
from torchaudio.transforms import Spectrogram, Resample
|
||||||
|
|
||||||
|
from env import AttrDict
|
||||||
|
from utils import get_padding
|
||||||
|
import typing
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorP(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
h: AttrDict,
|
||||||
|
period: List[int],
|
||||||
|
kernel_size: int = 5,
|
||||||
|
stride: int = 3,
|
||||||
|
use_spectral_norm: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.period = period
|
||||||
|
self.d_mult = h.discriminator_channel_mult
|
||||||
|
norm_f = weight_norm if not use_spectral_norm else spectral_norm
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
norm_f(
|
||||||
|
Conv2d(
|
||||||
|
1,
|
||||||
|
int(32 * self.d_mult),
|
||||||
|
(kernel_size, 1),
|
||||||
|
(stride, 1),
|
||||||
|
padding=(get_padding(5, 1), 0),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
norm_f(
|
||||||
|
Conv2d(
|
||||||
|
int(32 * self.d_mult),
|
||||||
|
int(128 * self.d_mult),
|
||||||
|
(kernel_size, 1),
|
||||||
|
(stride, 1),
|
||||||
|
padding=(get_padding(5, 1), 0),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
norm_f(
|
||||||
|
Conv2d(
|
||||||
|
int(128 * self.d_mult),
|
||||||
|
int(512 * self.d_mult),
|
||||||
|
(kernel_size, 1),
|
||||||
|
(stride, 1),
|
||||||
|
padding=(get_padding(5, 1), 0),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
norm_f(
|
||||||
|
Conv2d(
|
||||||
|
int(512 * self.d_mult),
|
||||||
|
int(1024 * self.d_mult),
|
||||||
|
(kernel_size, 1),
|
||||||
|
(stride, 1),
|
||||||
|
padding=(get_padding(5, 1), 0),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
norm_f(
|
||||||
|
Conv2d(
|
||||||
|
int(1024 * self.d_mult),
|
||||||
|
int(1024 * self.d_mult),
|
||||||
|
(kernel_size, 1),
|
||||||
|
1,
|
||||||
|
padding=(2, 0),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||||
|
fmap = []
|
||||||
|
|
||||||
|
# 1d to 2d
|
||||||
|
b, c, t = x.shape
|
||||||
|
if t % self.period != 0: # pad first
|
||||||
|
n_pad = self.period - (t % self.period)
|
||||||
|
x = F.pad(x, (0, n_pad), "reflect")
|
||||||
|
t = t + n_pad
|
||||||
|
x = x.view(b, c, t // self.period, self.period)
|
||||||
|
|
||||||
|
for l in self.convs:
|
||||||
|
x = l(x)
|
||||||
|
x = F.leaky_relu(x, 0.1)
|
||||||
|
fmap.append(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
fmap.append(x)
|
||||||
|
x = torch.flatten(x, 1, -1)
|
||||||
|
|
||||||
|
return x, fmap
|
||||||
|
|
||||||
|
|
||||||
|
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||||
|
def __init__(self, h: AttrDict):
|
||||||
|
super().__init__()
|
||||||
|
self.mpd_reshapes = h.mpd_reshapes
|
||||||
|
print(f"mpd_reshapes: {self.mpd_reshapes}")
|
||||||
|
self.discriminators = nn.ModuleList(
|
||||||
|
[DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||||
|
) -> Tuple[
|
||||||
|
List[torch.Tensor],
|
||||||
|
List[torch.Tensor],
|
||||||
|
List[List[torch.Tensor]],
|
||||||
|
List[List[torch.Tensor]],
|
||||||
|
]:
|
||||||
|
y_d_rs = []
|
||||||
|
y_d_gs = []
|
||||||
|
fmap_rs = []
|
||||||
|
fmap_gs = []
|
||||||
|
for i, d in enumerate(self.discriminators):
|
||||||
|
y_d_r, fmap_r = d(y)
|
||||||
|
y_d_g, fmap_g = d(y_hat)
|
||||||
|
y_d_rs.append(y_d_r)
|
||||||
|
fmap_rs.append(fmap_r)
|
||||||
|
y_d_gs.append(y_d_g)
|
||||||
|
fmap_gs.append(fmap_g)
|
||||||
|
|
||||||
|
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorR(nn.Module):
|
||||||
|
def __init__(self, cfg: AttrDict, resolution: List[List[int]]):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.resolution = resolution
|
||||||
|
assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
|
||||||
|
self.lrelu_slope = 0.1
|
||||||
|
|
||||||
|
norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
|
||||||
|
if hasattr(cfg, "mrd_use_spectral_norm"):
|
||||||
|
print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
|
||||||
|
norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
|
||||||
|
self.d_mult = cfg.discriminator_channel_mult
|
||||||
|
if hasattr(cfg, "mrd_channel_mult"):
|
||||||
|
print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
|
||||||
|
self.d_mult = cfg.mrd_channel_mult
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
|
||||||
|
norm_f(
|
||||||
|
nn.Conv2d(
|
||||||
|
int(32 * self.d_mult),
|
||||||
|
int(32 * self.d_mult),
|
||||||
|
(3, 9),
|
||||||
|
stride=(1, 2),
|
||||||
|
padding=(1, 4),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
norm_f(
|
||||||
|
nn.Conv2d(
|
||||||
|
int(32 * self.d_mult),
|
||||||
|
int(32 * self.d_mult),
|
||||||
|
(3, 9),
|
||||||
|
stride=(1, 2),
|
||||||
|
padding=(1, 4),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
norm_f(
|
||||||
|
nn.Conv2d(
|
||||||
|
int(32 * self.d_mult),
|
||||||
|
int(32 * self.d_mult),
|
||||||
|
(3, 9),
|
||||||
|
stride=(1, 2),
|
||||||
|
padding=(1, 4),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
norm_f(
|
||||||
|
nn.Conv2d(
|
||||||
|
int(32 * self.d_mult),
|
||||||
|
int(32 * self.d_mult),
|
||||||
|
(3, 3),
|
||||||
|
padding=(1, 1),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||||
|
fmap = []
|
||||||
|
|
||||||
|
x = self.spectrogram(x)
|
||||||
|
x = x.unsqueeze(1)
|
||||||
|
for l in self.convs:
|
||||||
|
x = l(x)
|
||||||
|
x = F.leaky_relu(x, self.lrelu_slope)
|
||||||
|
fmap.append(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
fmap.append(x)
|
||||||
|
x = torch.flatten(x, 1, -1)
|
||||||
|
|
||||||
|
return x, fmap
|
||||||
|
|
||||||
|
def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
n_fft, hop_length, win_length = self.resolution
|
||||||
|
x = F.pad(
|
||||||
|
x,
|
||||||
|
(int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
|
||||||
|
mode="reflect",
|
||||||
|
)
|
||||||
|
x = x.squeeze(1)
|
||||||
|
x = torch.stft(
|
||||||
|
x,
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
center=False,
|
||||||
|
return_complex=True,
|
||||||
|
)
|
||||||
|
x = torch.view_as_real(x) # [B, F, TT, 2]
|
||||||
|
mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
|
||||||
|
|
||||||
|
return mag
|
||||||
|
|
||||||
|
|
||||||
|
class MultiResolutionDiscriminator(nn.Module):
|
||||||
|
def __init__(self, cfg, debug=False):
|
||||||
|
super().__init__()
|
||||||
|
self.resolutions = cfg.resolutions
|
||||||
|
assert len(self.resolutions) == 3, (
|
||||||
|
f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
|
||||||
|
)
|
||||||
|
self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||||
|
) -> Tuple[
|
||||||
|
List[torch.Tensor],
|
||||||
|
List[torch.Tensor],
|
||||||
|
List[List[torch.Tensor]],
|
||||||
|
List[List[torch.Tensor]],
|
||||||
|
]:
|
||||||
|
y_d_rs = []
|
||||||
|
y_d_gs = []
|
||||||
|
fmap_rs = []
|
||||||
|
fmap_gs = []
|
||||||
|
|
||||||
|
for i, d in enumerate(self.discriminators):
|
||||||
|
y_d_r, fmap_r = d(x=y)
|
||||||
|
y_d_g, fmap_g = d(x=y_hat)
|
||||||
|
y_d_rs.append(y_d_r)
|
||||||
|
fmap_rs.append(fmap_r)
|
||||||
|
y_d_gs.append(y_d_g)
|
||||||
|
fmap_gs.append(fmap_g)
|
||||||
|
|
||||||
|
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||||
|
|
||||||
|
|
||||||
|
# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
|
||||||
|
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
class DiscriminatorB(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
window_length: int,
|
||||||
|
channels: int = 32,
|
||||||
|
hop_factor: float = 0.25,
|
||||||
|
bands: Tuple[Tuple[float, float], ...] = (
|
||||||
|
(0.0, 0.1),
|
||||||
|
(0.1, 0.25),
|
||||||
|
(0.25, 0.5),
|
||||||
|
(0.5, 0.75),
|
||||||
|
(0.75, 1.0),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.window_length = window_length
|
||||||
|
self.hop_factor = hop_factor
|
||||||
|
self.spec_fn = Spectrogram(
|
||||||
|
n_fft=window_length,
|
||||||
|
hop_length=int(window_length * hop_factor),
|
||||||
|
win_length=window_length,
|
||||||
|
power=None,
|
||||||
|
)
|
||||||
|
n_fft = window_length // 2 + 1
|
||||||
|
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
|
||||||
|
self.bands = bands
|
||||||
|
convs = lambda: nn.ModuleList(
|
||||||
|
[
|
||||||
|
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
||||||
|
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||||
|
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||||
|
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||||
|
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
||||||
|
|
||||||
|
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
|
||||||
|
|
||||||
|
def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||||
|
# Remove DC offset
|
||||||
|
x = x - x.mean(dim=-1, keepdims=True)
|
||||||
|
# Peak normalize the volume of input audio
|
||||||
|
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
||||||
|
x = self.spec_fn(x)
|
||||||
|
x = torch.view_as_real(x)
|
||||||
|
x = x.permute(0, 3, 2, 1) # [B, F, T, C] -> [B, C, T, F]
|
||||||
|
# Split into bands
|
||||||
|
x_bands = [x[..., b[0] : b[1]] for b in self.bands]
|
||||||
|
return x_bands
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||||
|
x_bands = self.spectrogram(x.squeeze(1))
|
||||||
|
fmap = []
|
||||||
|
x = []
|
||||||
|
|
||||||
|
for band, stack in zip(x_bands, self.band_convs):
|
||||||
|
for i, layer in enumerate(stack):
|
||||||
|
band = layer(band)
|
||||||
|
band = torch.nn.functional.leaky_relu(band, 0.1)
|
||||||
|
if i > 0:
|
||||||
|
fmap.append(band)
|
||||||
|
x.append(band)
|
||||||
|
|
||||||
|
x = torch.cat(x, dim=-1)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
fmap.append(x)
|
||||||
|
|
||||||
|
return x, fmap
|
||||||
|
|
||||||
|
|
||||||
|
# Method based on descript-audio-codec: https://github.com/descriptinc/descript-audio-codec
|
||||||
|
# Modified code adapted from https://github.com/gemelo-ai/vocos under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
class MultiBandDiscriminator(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
h,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Multi-band multi-scale STFT discriminator, with the architecture based on https://github.com/descriptinc/descript-audio-codec.
|
||||||
|
and the modified code adapted from https://github.com/gemelo-ai/vocos.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
# fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
|
||||||
|
self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
|
||||||
|
self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||||
|
) -> Tuple[
|
||||||
|
List[torch.Tensor],
|
||||||
|
List[torch.Tensor],
|
||||||
|
List[List[torch.Tensor]],
|
||||||
|
List[List[torch.Tensor]],
|
||||||
|
]:
|
||||||
|
y_d_rs = []
|
||||||
|
y_d_gs = []
|
||||||
|
fmap_rs = []
|
||||||
|
fmap_gs = []
|
||||||
|
|
||||||
|
for d in self.discriminators:
|
||||||
|
y_d_r, fmap_r = d(x=y)
|
||||||
|
y_d_g, fmap_g = d(x=y_hat)
|
||||||
|
y_d_rs.append(y_d_r)
|
||||||
|
fmap_rs.append(fmap_r)
|
||||||
|
y_d_gs.append(y_d_g)
|
||||||
|
fmap_gs.append(fmap_g)
|
||||||
|
|
||||||
|
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
class DiscriminatorCQT(nn.Module):
|
||||||
|
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int):
|
||||||
|
super().__init__()
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
self.filters = cfg["cqtd_filters"]
|
||||||
|
self.max_filters = cfg["cqtd_max_filters"]
|
||||||
|
self.filters_scale = cfg["cqtd_filters_scale"]
|
||||||
|
self.kernel_size = (3, 9)
|
||||||
|
self.dilations = cfg["cqtd_dilations"]
|
||||||
|
self.stride = (1, 2)
|
||||||
|
|
||||||
|
self.in_channels = cfg["cqtd_in_channels"]
|
||||||
|
self.out_channels = cfg["cqtd_out_channels"]
|
||||||
|
self.fs = cfg["sampling_rate"]
|
||||||
|
self.hop_length = hop_length
|
||||||
|
self.n_octaves = n_octaves
|
||||||
|
self.bins_per_octave = bins_per_octave
|
||||||
|
|
||||||
|
# Lazy-load
|
||||||
|
from nnAudio import features
|
||||||
|
|
||||||
|
self.cqt_transform = features.cqt.CQT2010v2(
|
||||||
|
sr=self.fs * 2,
|
||||||
|
hop_length=self.hop_length,
|
||||||
|
n_bins=self.bins_per_octave * self.n_octaves,
|
||||||
|
bins_per_octave=self.bins_per_octave,
|
||||||
|
output_format="Complex",
|
||||||
|
pad_mode="constant",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv_pres = nn.ModuleList()
|
||||||
|
for _ in range(self.n_octaves):
|
||||||
|
self.conv_pres.append(
|
||||||
|
nn.Conv2d(
|
||||||
|
self.in_channels * 2,
|
||||||
|
self.in_channels * 2,
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
padding=self.get_2d_padding(self.kernel_size),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList()
|
||||||
|
|
||||||
|
self.convs.append(
|
||||||
|
nn.Conv2d(
|
||||||
|
self.in_channels * 2,
|
||||||
|
self.filters,
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
padding=self.get_2d_padding(self.kernel_size),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
in_chs = min(self.filters_scale * self.filters, self.max_filters)
|
||||||
|
for i, dilation in enumerate(self.dilations):
|
||||||
|
out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
|
||||||
|
self.convs.append(
|
||||||
|
weight_norm(
|
||||||
|
nn.Conv2d(
|
||||||
|
in_chs,
|
||||||
|
out_chs,
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
stride=self.stride,
|
||||||
|
dilation=(dilation, 1),
|
||||||
|
padding=self.get_2d_padding(self.kernel_size, (dilation, 1)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
in_chs = out_chs
|
||||||
|
out_chs = min(
|
||||||
|
(self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
|
||||||
|
self.max_filters,
|
||||||
|
)
|
||||||
|
self.convs.append(
|
||||||
|
weight_norm(
|
||||||
|
nn.Conv2d(
|
||||||
|
in_chs,
|
||||||
|
out_chs,
|
||||||
|
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
|
||||||
|
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv_post = weight_norm(
|
||||||
|
nn.Conv2d(
|
||||||
|
out_chs,
|
||||||
|
self.out_channels,
|
||||||
|
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
|
||||||
|
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.activation = torch.nn.LeakyReLU(negative_slope=0.1)
|
||||||
|
self.resample = Resample(orig_freq=self.fs, new_freq=self.fs * 2)
|
||||||
|
|
||||||
|
self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
|
||||||
|
if self.cqtd_normalize_volume:
|
||||||
|
print(
|
||||||
|
"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_2d_padding(
|
||||||
|
self,
|
||||||
|
kernel_size: typing.Tuple[int, int],
|
||||||
|
dilation: typing.Tuple[int, int] = (1, 1),
|
||||||
|
):
|
||||||
|
return (
|
||||||
|
((kernel_size[0] - 1) * dilation[0]) // 2,
|
||||||
|
((kernel_size[1] - 1) * dilation[1]) // 2,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||||
|
fmap = []
|
||||||
|
|
||||||
|
if self.cqtd_normalize_volume:
|
||||||
|
# Remove DC offset
|
||||||
|
x = x - x.mean(dim=-1, keepdims=True)
|
||||||
|
# Peak normalize the volume of input audio
|
||||||
|
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
||||||
|
|
||||||
|
x = self.resample(x)
|
||||||
|
|
||||||
|
z = self.cqt_transform(x)
|
||||||
|
|
||||||
|
z_amplitude = z[:, :, :, 0].unsqueeze(1)
|
||||||
|
z_phase = z[:, :, :, 1].unsqueeze(1)
|
||||||
|
|
||||||
|
z = torch.cat([z_amplitude, z_phase], dim=1)
|
||||||
|
z = torch.permute(z, (0, 1, 3, 2)) # [B, C, W, T] -> [B, C, T, W]
|
||||||
|
|
||||||
|
latent_z = []
|
||||||
|
for i in range(self.n_octaves):
|
||||||
|
latent_z.append(
|
||||||
|
self.conv_pres[i](
|
||||||
|
z[
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
:,
|
||||||
|
i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
latent_z = torch.cat(latent_z, dim=-1)
|
||||||
|
|
||||||
|
for i, l in enumerate(self.convs):
|
||||||
|
latent_z = l(latent_z)
|
||||||
|
|
||||||
|
latent_z = self.activation(latent_z)
|
||||||
|
fmap.append(latent_z)
|
||||||
|
|
||||||
|
latent_z = self.conv_post(latent_z)
|
||||||
|
|
||||||
|
return latent_z, fmap
|
||||||
|
|
||||||
|
|
||||||
|
class MultiScaleSubbandCQTDiscriminator(nn.Module):
|
||||||
|
def __init__(self, cfg: AttrDict):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cfg = cfg
|
||||||
|
# Using get with defaults
|
||||||
|
self.cfg["cqtd_filters"] = self.cfg.get("cqtd_filters", 32)
|
||||||
|
self.cfg["cqtd_max_filters"] = self.cfg.get("cqtd_max_filters", 1024)
|
||||||
|
self.cfg["cqtd_filters_scale"] = self.cfg.get("cqtd_filters_scale", 1)
|
||||||
|
self.cfg["cqtd_dilations"] = self.cfg.get("cqtd_dilations", [1, 2, 4])
|
||||||
|
self.cfg["cqtd_in_channels"] = self.cfg.get("cqtd_in_channels", 1)
|
||||||
|
self.cfg["cqtd_out_channels"] = self.cfg.get("cqtd_out_channels", 1)
|
||||||
|
# Multi-scale params to loop over
|
||||||
|
self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
|
||||||
|
self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
|
||||||
|
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
|
||||||
|
|
||||||
|
self.discriminators = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DiscriminatorCQT(
|
||||||
|
self.cfg,
|
||||||
|
hop_length=self.cfg["cqtd_hop_lengths"][i],
|
||||||
|
n_octaves=self.cfg["cqtd_n_octaves"][i],
|
||||||
|
bins_per_octave=self.cfg["cqtd_bins_per_octaves"][i],
|
||||||
|
)
|
||||||
|
for i in range(len(self.cfg["cqtd_hop_lengths"]))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||||
|
) -> Tuple[
|
||||||
|
List[torch.Tensor],
|
||||||
|
List[torch.Tensor],
|
||||||
|
List[List[torch.Tensor]],
|
||||||
|
List[List[torch.Tensor]],
|
||||||
|
]:
|
||||||
|
y_d_rs = []
|
||||||
|
y_d_gs = []
|
||||||
|
fmap_rs = []
|
||||||
|
fmap_gs = []
|
||||||
|
|
||||||
|
for disc in self.discriminators:
|
||||||
|
y_d_r, fmap_r = disc(y)
|
||||||
|
y_d_g, fmap_g = disc(y_hat)
|
||||||
|
y_d_rs.append(y_d_r)
|
||||||
|
fmap_rs.append(fmap_r)
|
||||||
|
y_d_gs.append(y_d_g)
|
||||||
|
fmap_gs.append(fmap_g)
|
||||||
|
|
||||||
|
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||||
|
|
||||||
|
|
||||||
|
class CombinedDiscriminator(nn.Module):
|
||||||
|
"""
|
||||||
|
Wrapper of chaining multiple discrimiantor architectures.
|
||||||
|
Example: combine mbd and cqtd as a single class
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, list_discriminator: List[nn.Module]):
|
||||||
|
super().__init__()
|
||||||
|
self.discrimiantor = nn.ModuleList(list_discriminator)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||||
|
) -> Tuple[
|
||||||
|
List[torch.Tensor],
|
||||||
|
List[torch.Tensor],
|
||||||
|
List[List[torch.Tensor]],
|
||||||
|
List[List[torch.Tensor]],
|
||||||
|
]:
|
||||||
|
y_d_rs = []
|
||||||
|
y_d_gs = []
|
||||||
|
fmap_rs = []
|
||||||
|
fmap_gs = []
|
||||||
|
|
||||||
|
for disc in self.discrimiantor:
|
||||||
|
y_d_r, y_d_g, fmap_r, fmap_g = disc(y, y_hat)
|
||||||
|
y_d_rs.extend(y_d_r)
|
||||||
|
fmap_rs.extend(fmap_r)
|
||||||
|
y_d_gs.extend(y_d_g)
|
||||||
|
fmap_gs.extend(fmap_g)
|
||||||
|
|
||||||
|
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
18
GPT_SoVITS/BigVGAN/env.py
Normal file
18
GPT_SoVITS/BigVGAN/env.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
|
class AttrDict(dict):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super(AttrDict, self).__init__(*args, **kwargs)
|
||||||
|
self.__dict__ = self
|
||||||
|
|
||||||
|
|
||||||
|
def build_env(config, config_name, path):
|
||||||
|
t_path = os.path.join(path, config_name)
|
||||||
|
if config != t_path:
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
|
shutil.copyfile(config, os.path.join(path, config_name))
|
21
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_1
Normal file
21
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_1
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2020 Jungil Kong
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
21
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_2
Normal file
21
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_2
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2020 Edward Dixon
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
201
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_3
Normal file
201
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_3
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
29
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_4
Normal file
29
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_4
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
BSD 3-Clause License
|
||||||
|
|
||||||
|
Copyright (c) 2019, Seungwon Park 박승원
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
3. Neither the name of the copyright holder nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
16
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_5
Normal file
16
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_5
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
Copyright 2020 Alexandre Défossez
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
|
||||||
|
associated documentation files (the "Software"), to deal in the Software without restriction,
|
||||||
|
including without limitation the rights to use, copy, modify, merge, publish, distribute,
|
||||||
|
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all copies or
|
||||||
|
substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
|
||||||
|
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||||
|
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
21
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_6
Normal file
21
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_6
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023-present, Descript
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
21
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_7
Normal file
21
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_7
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 Charactr Inc.
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
21
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_8
Normal file
21
GPT_SoVITS/BigVGAN/incl_licenses/LICENSE_8
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 Amphion
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
85
GPT_SoVITS/BigVGAN/inference.py
Normal file
85
GPT_SoVITS/BigVGAN/inference.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import librosa
|
||||||
|
from utils import load_checkpoint
|
||||||
|
from meldataset import get_mel_spectrogram
|
||||||
|
from scipy.io.wavfile import write
|
||||||
|
from env import AttrDict
|
||||||
|
from meldataset import MAX_WAV_VALUE
|
||||||
|
from bigvgan import BigVGAN as Generator
|
||||||
|
|
||||||
|
h = None
|
||||||
|
device = None
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
|
||||||
|
def inference(a, h):
|
||||||
|
generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device)
|
||||||
|
|
||||||
|
state_dict_g = load_checkpoint(a.checkpoint_file, device)
|
||||||
|
generator.load_state_dict(state_dict_g["generator"])
|
||||||
|
|
||||||
|
filelist = os.listdir(a.input_wavs_dir)
|
||||||
|
|
||||||
|
os.makedirs(a.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
generator.eval()
|
||||||
|
generator.remove_weight_norm()
|
||||||
|
with torch.no_grad():
|
||||||
|
for i, filname in enumerate(filelist):
|
||||||
|
# Load the ground truth audio and resample if necessary
|
||||||
|
wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True)
|
||||||
|
wav = torch.FloatTensor(wav).to(device)
|
||||||
|
# Compute mel spectrogram from the ground truth audio
|
||||||
|
x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
|
||||||
|
|
||||||
|
y_g_hat = generator(x)
|
||||||
|
|
||||||
|
audio = y_g_hat.squeeze()
|
||||||
|
audio = audio * MAX_WAV_VALUE
|
||||||
|
audio = audio.cpu().numpy().astype("int16")
|
||||||
|
|
||||||
|
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav")
|
||||||
|
write(output_file, h.sampling_rate, audio)
|
||||||
|
print(output_file)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("Initializing Inference Process..")
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--input_wavs_dir", default="test_files")
|
||||||
|
parser.add_argument("--output_dir", default="generated_files")
|
||||||
|
parser.add_argument("--checkpoint_file", required=True)
|
||||||
|
parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
|
||||||
|
|
||||||
|
a = parser.parse_args()
|
||||||
|
|
||||||
|
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
|
||||||
|
with open(config_file) as f:
|
||||||
|
data = f.read()
|
||||||
|
|
||||||
|
global h
|
||||||
|
json_config = json.loads(data)
|
||||||
|
h = AttrDict(json_config)
|
||||||
|
|
||||||
|
torch.manual_seed(h.seed)
|
||||||
|
global device
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(h.seed)
|
||||||
|
device = torch.device("cuda")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
inference(a, h)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
100
GPT_SoVITS/BigVGAN/inference_e2e.py
Normal file
100
GPT_SoVITS/BigVGAN/inference_e2e.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
from scipy.io.wavfile import write
|
||||||
|
from env import AttrDict
|
||||||
|
from meldataset import MAX_WAV_VALUE
|
||||||
|
from bigvgan import BigVGAN as Generator
|
||||||
|
|
||||||
|
h = None
|
||||||
|
device = None
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(filepath, device):
|
||||||
|
assert os.path.isfile(filepath)
|
||||||
|
print(f"Loading '{filepath}'")
|
||||||
|
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||||
|
print("Complete.")
|
||||||
|
return checkpoint_dict
|
||||||
|
|
||||||
|
|
||||||
|
def scan_checkpoint(cp_dir, prefix):
|
||||||
|
pattern = os.path.join(cp_dir, prefix + "*")
|
||||||
|
cp_list = glob.glob(pattern)
|
||||||
|
if len(cp_list) == 0:
|
||||||
|
return ""
|
||||||
|
return sorted(cp_list)[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def inference(a, h):
|
||||||
|
generator = Generator(h, use_cuda_kernel=a.use_cuda_kernel).to(device)
|
||||||
|
|
||||||
|
state_dict_g = load_checkpoint(a.checkpoint_file, device)
|
||||||
|
generator.load_state_dict(state_dict_g["generator"])
|
||||||
|
|
||||||
|
filelist = os.listdir(a.input_mels_dir)
|
||||||
|
|
||||||
|
os.makedirs(a.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
generator.eval()
|
||||||
|
generator.remove_weight_norm()
|
||||||
|
with torch.no_grad():
|
||||||
|
for i, filname in enumerate(filelist):
|
||||||
|
# Load the mel spectrogram in .npy format
|
||||||
|
x = np.load(os.path.join(a.input_mels_dir, filname))
|
||||||
|
x = torch.FloatTensor(x).to(device)
|
||||||
|
if len(x.shape) == 2:
|
||||||
|
x = x.unsqueeze(0)
|
||||||
|
|
||||||
|
y_g_hat = generator(x)
|
||||||
|
|
||||||
|
audio = y_g_hat.squeeze()
|
||||||
|
audio = audio * MAX_WAV_VALUE
|
||||||
|
audio = audio.cpu().numpy().astype("int16")
|
||||||
|
|
||||||
|
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav")
|
||||||
|
write(output_file, h.sampling_rate, audio)
|
||||||
|
print(output_file)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("Initializing Inference Process..")
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--input_mels_dir", default="test_mel_files")
|
||||||
|
parser.add_argument("--output_dir", default="generated_files_from_mel")
|
||||||
|
parser.add_argument("--checkpoint_file", required=True)
|
||||||
|
parser.add_argument("--use_cuda_kernel", action="store_true", default=False)
|
||||||
|
|
||||||
|
a = parser.parse_args()
|
||||||
|
|
||||||
|
config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
|
||||||
|
with open(config_file) as f:
|
||||||
|
data = f.read()
|
||||||
|
|
||||||
|
global h
|
||||||
|
json_config = json.loads(data)
|
||||||
|
h = AttrDict(json_config)
|
||||||
|
|
||||||
|
torch.manual_seed(h.seed)
|
||||||
|
global device
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(h.seed)
|
||||||
|
device = torch.device("cuda")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
|
||||||
|
inference(a, h)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
238
GPT_SoVITS/BigVGAN/loss.py
Normal file
238
GPT_SoVITS/BigVGAN/loss.py
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from librosa.filters import mel as librosa_mel_fn
|
||||||
|
from scipy import signal
|
||||||
|
|
||||||
|
import typing
|
||||||
|
from typing import List, Tuple
|
||||||
|
from collections import namedtuple
|
||||||
|
import math
|
||||||
|
import functools
|
||||||
|
|
||||||
|
|
||||||
|
# Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||||
|
"""Compute distance between mel spectrograms. Can be used
|
||||||
|
in a multi-scale way.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
n_mels : List[int]
|
||||||
|
Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320],
|
||||||
|
window_lengths : List[int], optional
|
||||||
|
Length of each window of each STFT, by default [32, 64, 128, 256, 512, 1024, 2048]
|
||||||
|
loss_fn : typing.Callable, optional
|
||||||
|
How to compare each loss, by default nn.L1Loss()
|
||||||
|
clamp_eps : float, optional
|
||||||
|
Clamp on the log magnitude, below, by default 1e-5
|
||||||
|
mag_weight : float, optional
|
||||||
|
Weight of raw magnitude portion of loss, by default 0.0 (no ampliciation on mag part)
|
||||||
|
log_weight : float, optional
|
||||||
|
Weight of log magnitude portion of loss, by default 1.0
|
||||||
|
pow : float, optional
|
||||||
|
Power to raise magnitude to before taking log, by default 1.0
|
||||||
|
weight : float, optional
|
||||||
|
Weight of this loss, by default 1.0
|
||||||
|
match_stride : bool, optional
|
||||||
|
Whether to match the stride of convolutional layers, by default False
|
||||||
|
|
||||||
|
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
|
||||||
|
Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sampling_rate: int,
|
||||||
|
n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320],
|
||||||
|
window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048],
|
||||||
|
loss_fn: typing.Callable = nn.L1Loss(),
|
||||||
|
clamp_eps: float = 1e-5,
|
||||||
|
mag_weight: float = 0.0,
|
||||||
|
log_weight: float = 1.0,
|
||||||
|
pow: float = 1.0,
|
||||||
|
weight: float = 1.0,
|
||||||
|
match_stride: bool = False,
|
||||||
|
mel_fmin: List[float] = [0, 0, 0, 0, 0, 0, 0],
|
||||||
|
mel_fmax: List[float] = [None, None, None, None, None, None, None],
|
||||||
|
window_type: str = "hann",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.sampling_rate = sampling_rate
|
||||||
|
|
||||||
|
STFTParams = namedtuple(
|
||||||
|
"STFTParams",
|
||||||
|
["window_length", "hop_length", "window_type", "match_stride"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.stft_params = [
|
||||||
|
STFTParams(
|
||||||
|
window_length=w,
|
||||||
|
hop_length=w // 4,
|
||||||
|
match_stride=match_stride,
|
||||||
|
window_type=window_type,
|
||||||
|
)
|
||||||
|
for w in window_lengths
|
||||||
|
]
|
||||||
|
self.n_mels = n_mels
|
||||||
|
self.loss_fn = loss_fn
|
||||||
|
self.clamp_eps = clamp_eps
|
||||||
|
self.log_weight = log_weight
|
||||||
|
self.mag_weight = mag_weight
|
||||||
|
self.weight = weight
|
||||||
|
self.mel_fmin = mel_fmin
|
||||||
|
self.mel_fmax = mel_fmax
|
||||||
|
self.pow = pow
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@functools.lru_cache(None)
|
||||||
|
def get_window(
|
||||||
|
window_type,
|
||||||
|
window_length,
|
||||||
|
):
|
||||||
|
return signal.get_window(window_type, window_length)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@functools.lru_cache(None)
|
||||||
|
def get_mel_filters(sr, n_fft, n_mels, fmin, fmax):
|
||||||
|
return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
|
||||||
|
|
||||||
|
def mel_spectrogram(
|
||||||
|
self,
|
||||||
|
wav,
|
||||||
|
n_mels,
|
||||||
|
fmin,
|
||||||
|
fmax,
|
||||||
|
window_length,
|
||||||
|
hop_length,
|
||||||
|
match_stride,
|
||||||
|
window_type,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
|
||||||
|
https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
|
||||||
|
"""
|
||||||
|
B, C, T = wav.shape
|
||||||
|
|
||||||
|
if match_stride:
|
||||||
|
assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
|
||||||
|
right_pad = math.ceil(T / hop_length) * hop_length - T
|
||||||
|
pad = (window_length - hop_length) // 2
|
||||||
|
else:
|
||||||
|
right_pad = 0
|
||||||
|
pad = 0
|
||||||
|
|
||||||
|
wav = torch.nn.functional.pad(wav, (pad, pad + right_pad), mode="reflect")
|
||||||
|
|
||||||
|
window = self.get_window(window_type, window_length)
|
||||||
|
window = torch.from_numpy(window).to(wav.device).float()
|
||||||
|
|
||||||
|
stft = torch.stft(
|
||||||
|
wav.reshape(-1, T),
|
||||||
|
n_fft=window_length,
|
||||||
|
hop_length=hop_length,
|
||||||
|
window=window,
|
||||||
|
return_complex=True,
|
||||||
|
center=True,
|
||||||
|
)
|
||||||
|
_, nf, nt = stft.shape
|
||||||
|
stft = stft.reshape(B, C, nf, nt)
|
||||||
|
if match_stride:
|
||||||
|
"""
|
||||||
|
Drop first two and last two frames, which are added, because of padding. Now num_frames * hop_length = num_samples.
|
||||||
|
"""
|
||||||
|
stft = stft[..., 2:-2]
|
||||||
|
magnitude = torch.abs(stft)
|
||||||
|
|
||||||
|
nf = magnitude.shape[2]
|
||||||
|
mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
|
||||||
|
mel_basis = torch.from_numpy(mel_basis).to(wav.device)
|
||||||
|
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
|
||||||
|
mel_spectrogram = mel_spectrogram.transpose(-1, 2)
|
||||||
|
|
||||||
|
return mel_spectrogram
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Computes mel loss between an estimate and a reference
|
||||||
|
signal.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Estimate signal
|
||||||
|
y : torch.Tensor
|
||||||
|
Reference signal
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Mel loss.
|
||||||
|
"""
|
||||||
|
|
||||||
|
loss = 0.0
|
||||||
|
for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
|
||||||
|
kwargs = {
|
||||||
|
"n_mels": n_mels,
|
||||||
|
"fmin": fmin,
|
||||||
|
"fmax": fmax,
|
||||||
|
"window_length": s.window_length,
|
||||||
|
"hop_length": s.hop_length,
|
||||||
|
"match_stride": s.match_stride,
|
||||||
|
"window_type": s.window_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
x_mels = self.mel_spectrogram(x, **kwargs)
|
||||||
|
y_mels = self.mel_spectrogram(y, **kwargs)
|
||||||
|
x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
|
||||||
|
y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
|
||||||
|
|
||||||
|
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
|
||||||
|
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
# Loss functions
|
||||||
|
def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
|
||||||
|
loss = 0
|
||||||
|
for dr, dg in zip(fmap_r, fmap_g):
|
||||||
|
for rl, gl in zip(dr, dg):
|
||||||
|
loss += torch.mean(torch.abs(rl - gl))
|
||||||
|
|
||||||
|
return loss * 2 # This equates to lambda=2.0 for the feature matching loss
|
||||||
|
|
||||||
|
|
||||||
|
def discriminator_loss(
|
||||||
|
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
|
||||||
|
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
|
||||||
|
loss = 0
|
||||||
|
r_losses = []
|
||||||
|
g_losses = []
|
||||||
|
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||||
|
r_loss = torch.mean((1 - dr) ** 2)
|
||||||
|
g_loss = torch.mean(dg**2)
|
||||||
|
loss += r_loss + g_loss
|
||||||
|
r_losses.append(r_loss.item())
|
||||||
|
g_losses.append(g_loss.item())
|
||||||
|
|
||||||
|
return loss, r_losses, g_losses
|
||||||
|
|
||||||
|
|
||||||
|
def generator_loss(
|
||||||
|
disc_outputs: List[torch.Tensor],
|
||||||
|
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||||
|
loss = 0
|
||||||
|
gen_losses = []
|
||||||
|
for dg in disc_outputs:
|
||||||
|
l = torch.mean((1 - dg) ** 2)
|
||||||
|
gen_losses.append(l)
|
||||||
|
loss += l
|
||||||
|
|
||||||
|
return loss, gen_losses
|
370
GPT_SoVITS/BigVGAN/meldataset.py
Normal file
370
GPT_SoVITS/BigVGAN/meldataset.py
Normal file
@ -0,0 +1,370 @@
|
|||||||
|
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
import numpy as np
|
||||||
|
import librosa
|
||||||
|
from librosa.filters import mel as librosa_mel_fn
|
||||||
|
import pathlib
|
||||||
|
from tqdm import tqdm
|
||||||
|
from typing import List, Tuple, Optional
|
||||||
|
from .env import AttrDict
|
||||||
|
|
||||||
|
MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||||
|
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_range_decompression(x, C=1):
|
||||||
|
return np.exp(x) / C
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||||
|
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_range_decompression_torch(x, C=1):
|
||||||
|
return torch.exp(x) / C
|
||||||
|
|
||||||
|
|
||||||
|
def spectral_normalize_torch(magnitudes):
|
||||||
|
return dynamic_range_compression_torch(magnitudes)
|
||||||
|
|
||||||
|
|
||||||
|
def spectral_de_normalize_torch(magnitudes):
|
||||||
|
return dynamic_range_decompression_torch(magnitudes)
|
||||||
|
|
||||||
|
|
||||||
|
mel_basis_cache = {}
|
||||||
|
hann_window_cache = {}
|
||||||
|
|
||||||
|
|
||||||
|
def mel_spectrogram(
|
||||||
|
y: torch.Tensor,
|
||||||
|
n_fft: int,
|
||||||
|
num_mels: int,
|
||||||
|
sampling_rate: int,
|
||||||
|
hop_size: int,
|
||||||
|
win_size: int,
|
||||||
|
fmin: int,
|
||||||
|
fmax: int = None,
|
||||||
|
center: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Calculate the mel spectrogram of an input signal.
|
||||||
|
This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (torch.Tensor): Input signal.
|
||||||
|
n_fft (int): FFT size.
|
||||||
|
num_mels (int): Number of mel bins.
|
||||||
|
sampling_rate (int): Sampling rate of the input signal.
|
||||||
|
hop_size (int): Hop size for STFT.
|
||||||
|
win_size (int): Window size for STFT.
|
||||||
|
fmin (int): Minimum frequency for mel filterbank.
|
||||||
|
fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
|
||||||
|
center (bool): Whether to pad the input to center the frames. Default is False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Mel spectrogram.
|
||||||
|
"""
|
||||||
|
if torch.min(y) < -1.0:
|
||||||
|
print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
|
||||||
|
if torch.max(y) > 1.0:
|
||||||
|
print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
|
||||||
|
|
||||||
|
device = y.device
|
||||||
|
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
|
||||||
|
|
||||||
|
if key not in mel_basis_cache:
|
||||||
|
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||||
|
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
|
||||||
|
hann_window_cache[key] = torch.hann_window(win_size).to(device)
|
||||||
|
|
||||||
|
mel_basis = mel_basis_cache[key]
|
||||||
|
hann_window = hann_window_cache[key]
|
||||||
|
|
||||||
|
padding = (n_fft - hop_size) // 2
|
||||||
|
y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
|
||||||
|
|
||||||
|
spec = torch.stft(
|
||||||
|
y,
|
||||||
|
n_fft,
|
||||||
|
hop_length=hop_size,
|
||||||
|
win_length=win_size,
|
||||||
|
window=hann_window,
|
||||||
|
center=center,
|
||||||
|
pad_mode="reflect",
|
||||||
|
normalized=False,
|
||||||
|
onesided=True,
|
||||||
|
return_complex=True,
|
||||||
|
)
|
||||||
|
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
|
||||||
|
|
||||||
|
mel_spec = torch.matmul(mel_basis, spec)
|
||||||
|
mel_spec = spectral_normalize_torch(mel_spec)
|
||||||
|
|
||||||
|
return mel_spec
|
||||||
|
|
||||||
|
|
||||||
|
def get_mel_spectrogram(wav, h):
|
||||||
|
"""
|
||||||
|
Generate mel spectrogram from a waveform using given hyperparameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
wav (torch.Tensor): Input waveform.
|
||||||
|
h: Hyperparameters object with attributes n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Mel spectrogram.
|
||||||
|
"""
|
||||||
|
return mel_spectrogram(
|
||||||
|
wav,
|
||||||
|
h.n_fft,
|
||||||
|
h.num_mels,
|
||||||
|
h.sampling_rate,
|
||||||
|
h.hop_size,
|
||||||
|
h.win_size,
|
||||||
|
h.fmin,
|
||||||
|
h.fmax,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset_filelist(a):
|
||||||
|
training_files = []
|
||||||
|
validation_files = []
|
||||||
|
list_unseen_validation_files = []
|
||||||
|
|
||||||
|
with open(a.input_training_file, "r", encoding="utf-8") as fi:
|
||||||
|
training_files = [
|
||||||
|
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
|
||||||
|
]
|
||||||
|
print(f"first training file: {training_files[0]}")
|
||||||
|
|
||||||
|
with open(a.input_validation_file, "r", encoding="utf-8") as fi:
|
||||||
|
validation_files = [
|
||||||
|
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
|
||||||
|
]
|
||||||
|
print(f"first validation file: {validation_files[0]}")
|
||||||
|
|
||||||
|
for i in range(len(a.list_input_unseen_validation_file)):
|
||||||
|
with open(a.list_input_unseen_validation_file[i], "r", encoding="utf-8") as fi:
|
||||||
|
unseen_validation_files = [
|
||||||
|
os.path.join(a.list_input_unseen_wavs_dir[i], x.split("|")[0] + ".wav")
|
||||||
|
for x in fi.read().split("\n")
|
||||||
|
if len(x) > 0
|
||||||
|
]
|
||||||
|
print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}")
|
||||||
|
list_unseen_validation_files.append(unseen_validation_files)
|
||||||
|
|
||||||
|
return training_files, validation_files, list_unseen_validation_files
|
||||||
|
|
||||||
|
|
||||||
|
class MelDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
training_files: List[str],
|
||||||
|
hparams: AttrDict,
|
||||||
|
segment_size: int,
|
||||||
|
n_fft: int,
|
||||||
|
num_mels: int,
|
||||||
|
hop_size: int,
|
||||||
|
win_size: int,
|
||||||
|
sampling_rate: int,
|
||||||
|
fmin: int,
|
||||||
|
fmax: Optional[int],
|
||||||
|
split: bool = True,
|
||||||
|
shuffle: bool = True,
|
||||||
|
device: str = None,
|
||||||
|
fmax_loss: Optional[int] = None,
|
||||||
|
fine_tuning: bool = False,
|
||||||
|
base_mels_path: str = None,
|
||||||
|
is_seen: bool = True,
|
||||||
|
):
|
||||||
|
self.audio_files = training_files
|
||||||
|
random.seed(1234)
|
||||||
|
if shuffle:
|
||||||
|
random.shuffle(self.audio_files)
|
||||||
|
self.hparams = hparams
|
||||||
|
self.is_seen = is_seen
|
||||||
|
if self.is_seen:
|
||||||
|
self.name = pathlib.Path(self.audio_files[0]).parts[0]
|
||||||
|
else:
|
||||||
|
self.name = "-".join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/")
|
||||||
|
|
||||||
|
self.segment_size = segment_size
|
||||||
|
self.sampling_rate = sampling_rate
|
||||||
|
self.split = split
|
||||||
|
self.n_fft = n_fft
|
||||||
|
self.num_mels = num_mels
|
||||||
|
self.hop_size = hop_size
|
||||||
|
self.win_size = win_size
|
||||||
|
self.fmin = fmin
|
||||||
|
self.fmax = fmax
|
||||||
|
self.fmax_loss = fmax_loss
|
||||||
|
self.device = device
|
||||||
|
self.fine_tuning = fine_tuning
|
||||||
|
self.base_mels_path = base_mels_path
|
||||||
|
|
||||||
|
print("[INFO] checking dataset integrity...")
|
||||||
|
for i in tqdm(range(len(self.audio_files))):
|
||||||
|
assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found"
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
|
||||||
|
try:
|
||||||
|
filename = self.audio_files[index]
|
||||||
|
|
||||||
|
# Use librosa.load that ensures loading waveform into mono with [-1, 1] float values
|
||||||
|
# Audio is ndarray with shape [T_time]. Disable auto-resampling here to minimize overhead
|
||||||
|
# The on-the-fly resampling during training will be done only for the obtained random chunk
|
||||||
|
audio, source_sampling_rate = librosa.load(filename, sr=None, mono=True)
|
||||||
|
|
||||||
|
# Main logic that uses <mel, audio> pair for training BigVGAN
|
||||||
|
if not self.fine_tuning:
|
||||||
|
if self.split: # Training step
|
||||||
|
# Obtain randomized audio chunk
|
||||||
|
if source_sampling_rate != self.sampling_rate:
|
||||||
|
# Adjust segment size to crop if the source sr is different
|
||||||
|
target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate))
|
||||||
|
else:
|
||||||
|
target_segment_size = self.segment_size
|
||||||
|
|
||||||
|
# Compute upper bound index for the random chunk
|
||||||
|
random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size)
|
||||||
|
|
||||||
|
# Crop or pad audio to obtain random chunk with target_segment_size
|
||||||
|
if audio.shape[0] >= target_segment_size:
|
||||||
|
audio_start = random.randint(0, random_chunk_upper_bound)
|
||||||
|
audio = audio[audio_start : audio_start + target_segment_size]
|
||||||
|
else:
|
||||||
|
audio = np.pad(
|
||||||
|
audio,
|
||||||
|
(0, target_segment_size - audio.shape[0]),
|
||||||
|
mode="constant",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resample audio chunk to self.sampling rate
|
||||||
|
if source_sampling_rate != self.sampling_rate:
|
||||||
|
audio = librosa.resample(
|
||||||
|
audio,
|
||||||
|
orig_sr=source_sampling_rate,
|
||||||
|
target_sr=self.sampling_rate,
|
||||||
|
)
|
||||||
|
if audio.shape[0] > self.segment_size:
|
||||||
|
# trim last elements to match self.segment_size (e.g., 16385 for 44khz downsampled to 24khz -> 16384)
|
||||||
|
audio = audio[: self.segment_size]
|
||||||
|
|
||||||
|
else: # Validation step
|
||||||
|
# Resample full audio clip to target sampling rate
|
||||||
|
if source_sampling_rate != self.sampling_rate:
|
||||||
|
audio = librosa.resample(
|
||||||
|
audio,
|
||||||
|
orig_sr=source_sampling_rate,
|
||||||
|
target_sr=self.sampling_rate,
|
||||||
|
)
|
||||||
|
# Trim last elements to match audio length to self.hop_size * n for evaluation
|
||||||
|
if (audio.shape[0] % self.hop_size) != 0:
|
||||||
|
audio = audio[: -(audio.shape[0] % self.hop_size)]
|
||||||
|
|
||||||
|
# BigVGAN is trained using volume-normalized waveform
|
||||||
|
audio = librosa.util.normalize(audio) * 0.95
|
||||||
|
|
||||||
|
# Cast ndarray to torch tensor
|
||||||
|
audio = torch.FloatTensor(audio)
|
||||||
|
audio = audio.unsqueeze(0) # [B(1), self.segment_size]
|
||||||
|
|
||||||
|
# Compute mel spectrogram corresponding to audio
|
||||||
|
mel = mel_spectrogram(
|
||||||
|
audio,
|
||||||
|
self.n_fft,
|
||||||
|
self.num_mels,
|
||||||
|
self.sampling_rate,
|
||||||
|
self.hop_size,
|
||||||
|
self.win_size,
|
||||||
|
self.fmin,
|
||||||
|
self.fmax,
|
||||||
|
center=False,
|
||||||
|
) # [B(1), self.num_mels, self.segment_size // self.hop_size]
|
||||||
|
|
||||||
|
# Fine-tuning logic that uses pre-computed mel. Example: Using TTS model-generated mel as input
|
||||||
|
else:
|
||||||
|
# For fine-tuning, assert that the waveform is in the defined sampling_rate
|
||||||
|
# Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly)
|
||||||
|
assert source_sampling_rate == self.sampling_rate, (
|
||||||
|
f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cast ndarray to torch tensor
|
||||||
|
audio = torch.FloatTensor(audio)
|
||||||
|
audio = audio.unsqueeze(0) # [B(1), T_time]
|
||||||
|
|
||||||
|
# Load pre-computed mel from disk
|
||||||
|
mel = np.load(
|
||||||
|
os.path.join(
|
||||||
|
self.base_mels_path,
|
||||||
|
os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mel = torch.from_numpy(mel)
|
||||||
|
|
||||||
|
if len(mel.shape) < 3:
|
||||||
|
mel = mel.unsqueeze(0) # ensure [B, C, T]
|
||||||
|
|
||||||
|
if self.split:
|
||||||
|
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
|
||||||
|
|
||||||
|
if audio.size(1) >= self.segment_size:
|
||||||
|
mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
|
||||||
|
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
|
||||||
|
audio = audio[
|
||||||
|
:,
|
||||||
|
mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Pad pre-computed mel and audio to match length to ensuring fine-tuning without error.
|
||||||
|
# NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio>
|
||||||
|
# To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is the integer multiple of self.hop_size
|
||||||
|
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
|
||||||
|
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
|
||||||
|
|
||||||
|
# Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
|
||||||
|
mel_loss = mel_spectrogram(
|
||||||
|
audio,
|
||||||
|
self.n_fft,
|
||||||
|
self.num_mels,
|
||||||
|
self.sampling_rate,
|
||||||
|
self.hop_size,
|
||||||
|
self.win_size,
|
||||||
|
self.fmin,
|
||||||
|
self.fmax_loss,
|
||||||
|
center=False,
|
||||||
|
) # [B(1), self.num_mels, self.segment_size // self.hop_size]
|
||||||
|
|
||||||
|
# Shape sanity checks
|
||||||
|
assert (
|
||||||
|
audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size
|
||||||
|
), (
|
||||||
|
f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
|
||||||
|
|
||||||
|
# If it encounters error during loading the data, skip this sample and load random other sample to the batch
|
||||||
|
except Exception as e:
|
||||||
|
if self.fine_tuning:
|
||||||
|
raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
|
||||||
|
else:
|
||||||
|
print(f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}")
|
||||||
|
return self[random.randrange(len(self))]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.audio_files)
|
1
GPT_SoVITS/BigVGAN/nv-modelcard++/.gitkeep
Normal file
1
GPT_SoVITS/BigVGAN/nv-modelcard++/.gitkeep
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
4
GPT_SoVITS/BigVGAN/nv-modelcard++/bias.md
Normal file
4
GPT_SoVITS/BigVGAN/nv-modelcard++/bias.md
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
| Field | Response |
|
||||||
|
| :--------------------------------------------------------------------------------------------------------- | :--------------------------------------------------- |
|
||||||
|
| Participation considerations from adversely impacted groups protected classes in model design and testing: | None |
|
||||||
|
| Measures taken to mitigate against unwanted bias: | No measures taken to mitigate against unwanted bias. |
|
13
GPT_SoVITS/BigVGAN/nv-modelcard++/explainability.md
Normal file
13
GPT_SoVITS/BigVGAN/nv-modelcard++/explainability.md
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
| Field | Response |
|
||||||
|
| :---------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| Intended Application & Domain: | Generating waveform from mel spectrogram. |
|
||||||
|
| Model Type: | Convolutional Neural Network (CNN) |
|
||||||
|
| Intended Users: | This model is intended for developers to synthesize and generate waveforms from the AI-generated mel spectrograms. |
|
||||||
|
| Output: | Audio Waveform |
|
||||||
|
| Describe how the model works: | Model generates audio waveform corresponding to the input mel spectrogram. |
|
||||||
|
| Name the adversely impacted groups this has been tested to deliver comparable outcomes regardless of: | Not Applicable |
|
||||||
|
| Technical Limitations: | This may not perform well on synthetically-generated mel spectrograms that deviate significantly from the profile of mel spectrograms on which this was trained. |
|
||||||
|
| Verified to have met prescribed NVIDIA quality standards: | Yes |
|
||||||
|
| Performance Metrics: | Perceptual Evaluation of Speech Quality (PESQ), Virtual Speech Quality Objective Listener (VISQOL), Multi-resolution STFT (MRSTFT), Mel cepstral distortion (MCD), Periodicity RMSE, Voice/Unvoiced F1 Score (V/UV F1) |
|
||||||
|
| Potential Known Risks: | This model may generate low-quality or distorted soundwaves. |
|
||||||
|
| Licensing: | https://github.com/NVIDIA/BigVGAN/blob/main/LICENSE |
|
126
GPT_SoVITS/BigVGAN/nv-modelcard++/overview.md
Normal file
126
GPT_SoVITS/BigVGAN/nv-modelcard++/overview.md
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
# Model Overview
|
||||||
|
|
||||||
|
## Description:
|
||||||
|
|
||||||
|
BigVGAN is a generative AI model specialized in synthesizing audio waveforms using Mel spectrogram as inputs.
|
||||||
|
|
||||||
|
<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
|
||||||
|
|
||||||
|
BigVGAN is a fully convolutional architecture with several upsampling blocks using transposed convolution followed by multiple residual dilated convolution layers.
|
||||||
|
|
||||||
|
BigVGAN consists of a novel module, called anti-aliased multi-periodicity composition (AMP), which is specifically designed for generating waveforms. AMP is specialized in synthesizing high-frequency and periodic soundwaves drawing inspiration from audio signal processing principles.
|
||||||
|
|
||||||
|
It applies a periodic activation function, called Snake, which provides an inductive bias to the architecture in generating periodic soundwaves. It also applies anti-aliasing filters to reduce undesired artifacts in the generated waveforms. <br>
|
||||||
|
|
||||||
|
This model is ready for commercial use.<br>
|
||||||
|
|
||||||
|
## References(s):
|
||||||
|
|
||||||
|
- [BigVGAN: A Universal Neural Vocoder with Large-Scale Training](https://arxiv.org/abs/2206.04658) <br>
|
||||||
|
- [Project Page](https://research.nvidia.com/labs/adlr/projects/bigvgan/) <br>
|
||||||
|
- [Audio Demo](https://bigvgan-demo.github.io/) <br>
|
||||||
|
|
||||||
|
## Model Architecture:
|
||||||
|
|
||||||
|
**Architecture Type:** Convolution Neural Network (CNN) <br>
|
||||||
|
**Network Architecture:** You can see the details of this model on this link: https://github.com/NVIDIA/BigVGAN and the related paper can be found here: https://arxiv.org/abs/2206.04658<br>
|
||||||
|
**Model Version:** 2.0 <br>
|
||||||
|
|
||||||
|
## Input:
|
||||||
|
|
||||||
|
**Input Type:** Audio <br>
|
||||||
|
**Input Format:** Mel Spectrogram <br>
|
||||||
|
**Input Parameters:** None <br>
|
||||||
|
**Other Properties Related to Input:** The input mel spectrogram has shape `[batch, channels, frames]`, where `channels` refers to the number of mel bands defined by the model and `frames` refers to the temporal length. The model supports arbitrary long `frames` that fits into the GPU memory.
|
||||||
|
|
||||||
|
## Output:
|
||||||
|
|
||||||
|
**Input Type:** Audio <br>
|
||||||
|
**Output Format:** Audio Waveform <br>
|
||||||
|
**Output Parameters:** None <br>
|
||||||
|
**Other Properties Related to Output:** The output audio waveform has shape `[batch, 1, time]`, where `1` refers to the mono audio channels and `time` refers to the temporal length. `time` is defined as a fixed integer multiple of input `frames`, which is an upsampling ratio of the model (`time = upsampling ratio * frames`). The output audio waveform consitutes float values with a range of `[-1, 1]`.
|
||||||
|
|
||||||
|
## Software Integration:
|
||||||
|
|
||||||
|
**Runtime Engine(s):** PyTorch
|
||||||
|
|
||||||
|
**Supported Hardware Microarchitecture Compatibility:** NVIDIA Ampere, NVIDIA Hopper, NVIDIA Lovelace, NVIDIA Turing, NVIDIA Volta <br>
|
||||||
|
|
||||||
|
## Preferred/Supported Operating System(s):
|
||||||
|
|
||||||
|
Linux
|
||||||
|
|
||||||
|
## Model Version(s):
|
||||||
|
|
||||||
|
v2.0
|
||||||
|
|
||||||
|
## Training, Testing, and Evaluation Datasets:
|
||||||
|
|
||||||
|
### Training Dataset:
|
||||||
|
|
||||||
|
The dataset contains diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
|
||||||
|
|
||||||
|
**Links:**
|
||||||
|
|
||||||
|
- [AAM: Artificial Audio Multitracks Dataset](https://zenodo.org/records/5794629)
|
||||||
|
- [AudioCaps](https://audiocaps.github.io/)
|
||||||
|
- [AudioSet](https://research.google.com/audioset/index.html)
|
||||||
|
- [common-accent](https://huggingface.co/datasets/DTU54DL/common-accent)
|
||||||
|
- [Crowd Sourced Emotional Multimodal Actors Dataset (CREMA-D)](https://ieeexplore.ieee.org/document/6849440)
|
||||||
|
- [DCASE2017 Challenge, Task 4: Large-scale weakly supervised sound event detection for smart cars](https://dcase.community/challenge2017/task-large-scale-sound-event-detection)
|
||||||
|
- [FSDnoisy18k](https://zenodo.org/records/2529934)
|
||||||
|
- [Free Universal Sound Separation Dataset](https://zenodo.org/records/3694384)
|
||||||
|
- [Greatest Hits dataset](https://andrewowens.com/vis/)
|
||||||
|
- [GTZAN](https://ieeexplore.ieee.org/document/1021072)
|
||||||
|
- [JL corpus](https://www.kaggle.com/datasets/tli725/jl-corpus)
|
||||||
|
- [Medley-solos-DB: a cross-collection dataset for musical instrument recognition](https://zenodo.org/records/3464194)
|
||||||
|
- [MUSAN: A Music, Speech, and Noise Corpus](https://www.openslr.org/17/)
|
||||||
|
- [MusicBench](https://huggingface.co/datasets/amaai-lab/MusicBench)
|
||||||
|
- [MusicCaps](https://www.kaggle.com/datasets/googleai/musiccaps)
|
||||||
|
- [MusicNet](https://www.kaggle.com/datasets/imsparsh/musicnet-dataset)
|
||||||
|
- [NSynth](https://magenta.tensorflow.org/datasets/nsynth)
|
||||||
|
- [OnAir-Music-Dataset](https://github.com/sevagh/OnAir-Music-Dataset)
|
||||||
|
- [Audio Piano Triads Dataset](https://zenodo.org/records/4740877)
|
||||||
|
- [Pitch Audio Dataset (Surge synthesizer)](https://zenodo.org/records/4677097)
|
||||||
|
- [SONYC Urban Sound Tagging (SONYC-UST): a multilabel dataset from an urban acoustic sensor network](https://zenodo.org/records/3966543)
|
||||||
|
- [VocalSound: A Dataset for Improving Human Vocal Sounds Recognition](https://arxiv.org/abs/2205.03433)
|
||||||
|
- [WavText5K](https://github.com/microsoft/WavText5K)
|
||||||
|
- [CSS10: A Collection of Single Speaker Speech Datasets for 10 Languages](https://github.com/Kyubyong/css10)
|
||||||
|
- [Hi-Fi Multi-Speaker English TTS Dataset (Hi-Fi TTS)](https://www.openslr.org/109/)
|
||||||
|
- [IIIT-H Indic Speech Databases](http://festvox.org/databases/iiit_voices/)
|
||||||
|
- [Libri-Light: A Benchmark for ASR with Limited or No Supervision](https://arxiv.org/abs/1912.07875)
|
||||||
|
- [LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech](https://www.openslr.org/60)
|
||||||
|
- [LibriTTS-R: A Restored Multi-Speaker Text-to-Speech Corpus](https://www.openslr.org/141/)
|
||||||
|
- [The SIWIS French Speech Synthesis Database](https://datashare.ed.ac.uk/handle/10283/2353)
|
||||||
|
- [Crowdsourced high-quality Colombian Spanish speech data set](https://openslr.org/72/)
|
||||||
|
- [TTS-Portuguese Corpus](https://github.com/Edresson/TTS-Portuguese-Corpus)
|
||||||
|
- [CSTR VCTK Corpus: English Multi-speaker Corpus for CSTR Voice Cloning Toolkit](https://datashare.ed.ac.uk/handle/10283/3443)
|
||||||
|
|
||||||
|
\*\* Data Collection Method by dataset <br>
|
||||||
|
|
||||||
|
- Human <br>
|
||||||
|
|
||||||
|
\*\* Labeling Method by dataset (for those with labels) <br>
|
||||||
|
|
||||||
|
- Hybrid: Automated, Human, Unknown <br>
|
||||||
|
|
||||||
|
### Evaluating Dataset:
|
||||||
|
|
||||||
|
Properties: The audio generation quality of BigVGAN is evaluated using `dev` splits of the [LibriTTS dataset](https://www.openslr.org/60/) and [Hi-Fi TTS dataset](https://www.openslr.org/109/). The datasets include speech in English language with equal balance of genders.
|
||||||
|
|
||||||
|
\*\* Data Collection Method by dataset <br>
|
||||||
|
|
||||||
|
- Human <br>
|
||||||
|
|
||||||
|
\*\* Labeling Method by dataset <br>
|
||||||
|
|
||||||
|
- Automated <br>
|
||||||
|
|
||||||
|
## Inference:
|
||||||
|
|
||||||
|
**Engine:** PyTorch <br>
|
||||||
|
**Test Hardware:** NVIDIA A100 GPU <br>
|
||||||
|
|
||||||
|
## Ethical Considerations:
|
||||||
|
|
||||||
|
NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. For more detailed information on ethical considerations for this model, please see the Model Card++ Explainability, Bias, Safety & Security, and Privacy Subcards. Please report security vulnerabilities or NVIDIA AI Concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
|
14
GPT_SoVITS/BigVGAN/nv-modelcard++/privacy.md
Normal file
14
GPT_SoVITS/BigVGAN/nv-modelcard++/privacy.md
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
| Field | Response |
|
||||||
|
| :------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------- |
|
||||||
|
| Generatable or reverse engineerable personal information? | None |
|
||||||
|
| Protected class data used to create this model? | None |
|
||||||
|
| Was consent obtained for any personal data used? | Not Applicable (No Personal Data) |
|
||||||
|
| How often is dataset reviewed? | Before Release |
|
||||||
|
| Is a mechanism in place to honor data subject right of access or deletion of personal data? | Not Applicable |
|
||||||
|
| If personal collected for the development of the model, was it collected directly by NVIDIA? | Not Applicable |
|
||||||
|
| If personal collected for the development of the model by NVIDIA, do you maintain or have access to disclosures made to data subjects? | Not Applicable |
|
||||||
|
| If personal collected for the development of this AI model, was it minimized to only what was required? | Not Applicable |
|
||||||
|
| Is data in dataset traceable? | Yes |
|
||||||
|
| Is there provenance for all datasets used in training? | Yes |
|
||||||
|
| Does data labeling (annotation, metadata) comply with privacy laws? | Yes |
|
||||||
|
| Is data compliant with data subject requests for data correction or removal, if such a request was made? | No, not possible with externally-sourced data. |
|
6
GPT_SoVITS/BigVGAN/nv-modelcard++/safety.md
Normal file
6
GPT_SoVITS/BigVGAN/nv-modelcard++/safety.md
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
| Field | Response |
|
||||||
|
| :---------------------------------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| Model Application(s): | Synethic Audio Generation |
|
||||||
|
| Describe the life critical impact (if present). | Not Applicable |
|
||||||
|
| Use Case Restrictions: | None |
|
||||||
|
| Model and dataset restrictions: | The Principle of least privilege (PoLP) is applied limiting access for dataset generation and model development. Restrictions enforce dataset access during training, and dataset license constraints adhered to. |
|
13
GPT_SoVITS/BigVGAN/requirements.txt
Normal file
13
GPT_SoVITS/BigVGAN/requirements.txt
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
torch
|
||||||
|
numpy
|
||||||
|
librosa>=0.8.1
|
||||||
|
scipy
|
||||||
|
tensorboard
|
||||||
|
soundfile
|
||||||
|
matplotlib
|
||||||
|
pesq
|
||||||
|
auraloss
|
||||||
|
tqdm
|
||||||
|
nnAudio
|
||||||
|
ninja
|
||||||
|
huggingface_hub>=0.23.4
|
62
GPT_SoVITS/BigVGAN/tests/test_activation.py
Normal file
62
GPT_SoVITS/BigVGAN/tests/test_activation.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# to import modules from parent_dir
|
||||||
|
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
sys.path.append(parent_dir)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from alias_free_activation.cuda import activation1d
|
||||||
|
from activations import Snake
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_fused_kernels():
|
||||||
|
try:
|
||||||
|
print("[Success] load_fused_kernels")
|
||||||
|
except ImportError as e:
|
||||||
|
print("[Fail] load_fused_kernels")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def test_anti_alias_activation():
|
||||||
|
data = torch.rand((10, 10, 200), device="cuda")
|
||||||
|
|
||||||
|
# Check activations.Snake cuda vs. torch
|
||||||
|
fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
|
||||||
|
fused_activation_output = fused_anti_alias_activation(data)
|
||||||
|
|
||||||
|
torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
|
||||||
|
torch_activation_output = torch_anti_alias_activation(data)
|
||||||
|
|
||||||
|
test_result = (fused_activation_output - torch_activation_output).abs()
|
||||||
|
|
||||||
|
while test_result.dim() != 1:
|
||||||
|
test_result = test_result.mean(dim=-1)
|
||||||
|
|
||||||
|
diff = test_result.mean(dim=-1)
|
||||||
|
|
||||||
|
if diff <= 1e-3:
|
||||||
|
print(
|
||||||
|
f"\n[Success] test_fused_anti_alias_activation"
|
||||||
|
f"\n > mean_difference={diff}"
|
||||||
|
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}"
|
||||||
|
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"\n[Fail] test_fused_anti_alias_activation"
|
||||||
|
f"\n > mean_difference={diff}, "
|
||||||
|
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}, "
|
||||||
|
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from alias_free_activation.cuda import load
|
||||||
|
|
||||||
|
load.load()
|
||||||
|
test_load_fused_kernels()
|
||||||
|
test_anti_alias_activation()
|
62
GPT_SoVITS/BigVGAN/tests/test_activation_snake_beta.py
Normal file
62
GPT_SoVITS/BigVGAN/tests/test_activation_snake_beta.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# to import modules from parent_dir
|
||||||
|
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
sys.path.append(parent_dir)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from alias_free_activation.cuda import activation1d
|
||||||
|
from activations import SnakeBeta
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_fused_kernels():
|
||||||
|
try:
|
||||||
|
print("[Success] load_fused_kernels")
|
||||||
|
except ImportError as e:
|
||||||
|
print("[Fail] load_fused_kernels")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def test_anti_alias_activation():
|
||||||
|
data = torch.rand((10, 10, 200), device="cuda")
|
||||||
|
|
||||||
|
# Check activations, Snake CUDA vs. Torch
|
||||||
|
fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
|
||||||
|
fused_activation_output = fused_anti_alias_activation(data)
|
||||||
|
|
||||||
|
torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
|
||||||
|
torch_activation_output = torch_anti_alias_activation(data)
|
||||||
|
|
||||||
|
test_result = (fused_activation_output - torch_activation_output).abs()
|
||||||
|
|
||||||
|
while test_result.dim() != 1:
|
||||||
|
test_result = test_result.mean(dim=-1)
|
||||||
|
|
||||||
|
diff = test_result.mean(dim=-1)
|
||||||
|
|
||||||
|
if diff <= 1e-3:
|
||||||
|
print(
|
||||||
|
f"\n[Success] test_fused_anti_alias_activation"
|
||||||
|
f"\n > mean_difference={diff}"
|
||||||
|
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}"
|
||||||
|
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"\n[Fail] test_fused_anti_alias_activation"
|
||||||
|
f"\n > mean_difference={diff}, "
|
||||||
|
f"\n > fused_values={fused_activation_output[-1][-1][:].tolist()}, "
|
||||||
|
f"\n > torch_values={torch_activation_output[-1][-1][:].tolist()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from alias_free_activation.cuda import load
|
||||||
|
|
||||||
|
load.load()
|
||||||
|
test_load_fused_kernels()
|
||||||
|
test_anti_alias_activation()
|
215
GPT_SoVITS/BigVGAN/tests/test_cuda_vs_torch_model.py
Normal file
215
GPT_SoVITS/BigVGAN/tests/test_cuda_vs_torch_model.py
Normal file
@ -0,0 +1,215 @@
|
|||||||
|
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# to import modules from parent_dir
|
||||||
|
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
sys.path.append(parent_dir)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
from env import AttrDict
|
||||||
|
from bigvgan import BigVGAN
|
||||||
|
from time import time
|
||||||
|
from tqdm import tqdm
|
||||||
|
from meldataset import mel_spectrogram, MAX_WAV_VALUE
|
||||||
|
from scipy.io.wavfile import write
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
# For easier debugging
|
||||||
|
torch.set_printoptions(linewidth=200, threshold=10_000)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_soundwave(duration=5.0, sr=24000):
|
||||||
|
t = np.linspace(0, duration, int(sr * duration), False, dtype=np.float32)
|
||||||
|
|
||||||
|
modulation = np.sin(2 * np.pi * t / duration)
|
||||||
|
|
||||||
|
min_freq = 220
|
||||||
|
max_freq = 1760
|
||||||
|
frequencies = min_freq + (max_freq - min_freq) * (modulation + 1) / 2
|
||||||
|
soundwave = np.sin(2 * np.pi * frequencies * t)
|
||||||
|
|
||||||
|
soundwave = soundwave / np.max(np.abs(soundwave)) * 0.95
|
||||||
|
|
||||||
|
return soundwave, sr
|
||||||
|
|
||||||
|
|
||||||
|
def get_mel(x, h):
|
||||||
|
return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(filepath, device):
|
||||||
|
assert os.path.isfile(filepath)
|
||||||
|
print(f"Loading '{filepath}'")
|
||||||
|
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||||
|
print("Complete.")
|
||||||
|
return checkpoint_dict
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint_file",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the checkpoint file. Assumes config.json exists in the directory.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
config_file = os.path.join(os.path.split(args.checkpoint_file)[0], "config.json")
|
||||||
|
with open(config_file) as f:
|
||||||
|
config = f.read()
|
||||||
|
json_config = json.loads(config)
|
||||||
|
h = AttrDict({**json_config})
|
||||||
|
|
||||||
|
print("loading plain Pytorch BigVGAN")
|
||||||
|
generator_original = BigVGAN(h).to("cuda")
|
||||||
|
print("loading CUDA kernel BigVGAN with auto-build")
|
||||||
|
generator_cuda_kernel = BigVGAN(h, use_cuda_kernel=True).to("cuda")
|
||||||
|
|
||||||
|
state_dict_g = load_checkpoint(args.checkpoint_file, "cuda")
|
||||||
|
generator_original.load_state_dict(state_dict_g["generator"])
|
||||||
|
generator_cuda_kernel.load_state_dict(state_dict_g["generator"])
|
||||||
|
|
||||||
|
generator_original.remove_weight_norm()
|
||||||
|
generator_original.eval()
|
||||||
|
generator_cuda_kernel.remove_weight_norm()
|
||||||
|
generator_cuda_kernel.eval()
|
||||||
|
|
||||||
|
# define number of samples and length of mel frame to benchmark
|
||||||
|
num_sample = 10
|
||||||
|
num_mel_frame = 16384
|
||||||
|
|
||||||
|
# CUDA kernel correctness check
|
||||||
|
diff = 0.0
|
||||||
|
for i in tqdm(range(num_sample)):
|
||||||
|
# Random mel
|
||||||
|
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
audio_original = generator_original(data)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
audio_cuda_kernel = generator_cuda_kernel(data)
|
||||||
|
|
||||||
|
# Both outputs should be (almost) the same
|
||||||
|
test_result = (audio_original - audio_cuda_kernel).abs()
|
||||||
|
diff += test_result.mean(dim=-1).item()
|
||||||
|
|
||||||
|
diff /= num_sample
|
||||||
|
if diff <= 2e-3: # We can expect a small difference (~1e-3) which does not affect perceptual quality
|
||||||
|
print(
|
||||||
|
f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
|
||||||
|
f"\n > mean_difference={diff}"
|
||||||
|
f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}"
|
||||||
|
f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"\n[Fail] test CUDA fused vs. plain torch BigVGAN inference"
|
||||||
|
f"\n > mean_difference={diff}"
|
||||||
|
f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, "
|
||||||
|
f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
del data, audio_original, audio_cuda_kernel
|
||||||
|
|
||||||
|
# Variables for tracking total time and VRAM usage
|
||||||
|
toc_total_original = 0
|
||||||
|
toc_total_cuda_kernel = 0
|
||||||
|
vram_used_original_total = 0
|
||||||
|
vram_used_cuda_kernel_total = 0
|
||||||
|
audio_length_total = 0
|
||||||
|
|
||||||
|
# Measure Original inference in isolation
|
||||||
|
for i in tqdm(range(num_sample)):
|
||||||
|
torch.cuda.reset_peak_memory_stats(device="cuda")
|
||||||
|
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
tic = time()
|
||||||
|
with torch.inference_mode():
|
||||||
|
audio_original = generator_original(data)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
toc = time() - tic
|
||||||
|
toc_total_original += toc
|
||||||
|
|
||||||
|
vram_used_original_total += torch.cuda.max_memory_allocated(device="cuda")
|
||||||
|
|
||||||
|
del data, audio_original
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Measure CUDA kernel inference in isolation
|
||||||
|
for i in tqdm(range(num_sample)):
|
||||||
|
torch.cuda.reset_peak_memory_stats(device="cuda")
|
||||||
|
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
tic = time()
|
||||||
|
with torch.inference_mode():
|
||||||
|
audio_cuda_kernel = generator_cuda_kernel(data)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
toc = time() - tic
|
||||||
|
toc_total_cuda_kernel += toc
|
||||||
|
|
||||||
|
audio_length_total += audio_cuda_kernel.shape[-1]
|
||||||
|
|
||||||
|
vram_used_cuda_kernel_total += torch.cuda.max_memory_allocated(device="cuda")
|
||||||
|
|
||||||
|
del data, audio_cuda_kernel
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Calculate metrics
|
||||||
|
audio_second = audio_length_total / h.sampling_rate
|
||||||
|
khz_original = audio_length_total / toc_total_original / 1000
|
||||||
|
khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
|
||||||
|
vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
|
||||||
|
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)
|
||||||
|
|
||||||
|
# Print results
|
||||||
|
print(
|
||||||
|
f"Original BigVGAN: took {toc_total_original:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_original:.1f}kHz, {audio_second / toc_total_original:.1f} faster than realtime, VRAM used {vram_used_original_gb:.1f} GB"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"CUDA kernel BigVGAN: took {toc_total_cuda_kernel:.2f} seconds to generate {audio_second:.2f} seconds of audio, {khz_cuda_kernel:.1f}kHz, {audio_second / toc_total_cuda_kernel:.1f} faster than realtime, VRAM used {vram_used_cuda_kernel_gb:.1f} GB"
|
||||||
|
)
|
||||||
|
print(f"speedup of CUDA kernel: {khz_cuda_kernel / khz_original}")
|
||||||
|
print(f"VRAM saving of CUDA kernel: {vram_used_original_gb / vram_used_cuda_kernel_gb}")
|
||||||
|
|
||||||
|
# Use artificial sine waves for inference test
|
||||||
|
audio_real, sr = generate_soundwave(duration=5.0, sr=h.sampling_rate)
|
||||||
|
audio_real = torch.tensor(audio_real).to("cuda")
|
||||||
|
# Compute mel spectrogram from the ground truth audio
|
||||||
|
x = get_mel(audio_real.unsqueeze(0), h)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
y_g_hat_original = generator_original(x)
|
||||||
|
y_g_hat_cuda_kernel = generator_cuda_kernel(x)
|
||||||
|
|
||||||
|
audio_real = audio_real.squeeze()
|
||||||
|
audio_real = audio_real * MAX_WAV_VALUE
|
||||||
|
audio_real = audio_real.cpu().numpy().astype("int16")
|
||||||
|
|
||||||
|
audio_original = y_g_hat_original.squeeze()
|
||||||
|
audio_original = audio_original * MAX_WAV_VALUE
|
||||||
|
audio_original = audio_original.cpu().numpy().astype("int16")
|
||||||
|
|
||||||
|
audio_cuda_kernel = y_g_hat_cuda_kernel.squeeze()
|
||||||
|
audio_cuda_kernel = audio_cuda_kernel * MAX_WAV_VALUE
|
||||||
|
audio_cuda_kernel = audio_cuda_kernel.cpu().numpy().astype("int16")
|
||||||
|
|
||||||
|
os.makedirs("tmp", exist_ok=True)
|
||||||
|
output_file_real = os.path.join("tmp", "audio_real.wav")
|
||||||
|
output_file_original = os.path.join("tmp", "audio_generated_original.wav")
|
||||||
|
output_file_cuda_kernel = os.path.join("tmp", "audio_generated_cuda_kernel.wav")
|
||||||
|
write(output_file_real, h.sampling_rate, audio_real)
|
||||||
|
write(output_file_original, h.sampling_rate, audio_original)
|
||||||
|
write(output_file_cuda_kernel, h.sampling_rate, audio_cuda_kernel)
|
||||||
|
print("Example generated audios of original vs. fused CUDA kernel written to tmp!")
|
||||||
|
print("Done")
|
716
GPT_SoVITS/BigVGAN/train.py
Normal file
716
GPT_SoVITS/BigVGAN/train.py
Normal file
@ -0,0 +1,716 @@
|
|||||||
|
# Copyright (c) 2024 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||||
|
import itertools
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from torch.utils.data import DistributedSampler, DataLoader
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from torch.distributed import init_process_group
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from env import AttrDict, build_env
|
||||||
|
from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist, MAX_WAV_VALUE
|
||||||
|
|
||||||
|
from bigvgan import BigVGAN
|
||||||
|
from discriminators import (
|
||||||
|
MultiPeriodDiscriminator,
|
||||||
|
MultiResolutionDiscriminator,
|
||||||
|
MultiBandDiscriminator,
|
||||||
|
MultiScaleSubbandCQTDiscriminator,
|
||||||
|
)
|
||||||
|
from loss import (
|
||||||
|
feature_loss,
|
||||||
|
generator_loss,
|
||||||
|
discriminator_loss,
|
||||||
|
MultiScaleMelSpectrogramLoss,
|
||||||
|
)
|
||||||
|
|
||||||
|
from utils import (
|
||||||
|
plot_spectrogram,
|
||||||
|
plot_spectrogram_clipped,
|
||||||
|
scan_checkpoint,
|
||||||
|
load_checkpoint,
|
||||||
|
save_checkpoint,
|
||||||
|
save_audio,
|
||||||
|
)
|
||||||
|
import torchaudio as ta
|
||||||
|
from pesq import pesq
|
||||||
|
from tqdm import tqdm
|
||||||
|
import auraloss
|
||||||
|
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
|
||||||
|
def train(rank, a, h):
|
||||||
|
if h.num_gpus > 1:
|
||||||
|
# initialize distributed
|
||||||
|
init_process_group(
|
||||||
|
backend=h.dist_config["dist_backend"],
|
||||||
|
init_method=h.dist_config["dist_url"],
|
||||||
|
world_size=h.dist_config["world_size"] * h.num_gpus,
|
||||||
|
rank=rank,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set seed and device
|
||||||
|
torch.cuda.manual_seed(h.seed)
|
||||||
|
torch.cuda.set_device(rank)
|
||||||
|
device = torch.device(f"cuda:{rank:d}")
|
||||||
|
|
||||||
|
# Define BigVGAN generator
|
||||||
|
generator = BigVGAN(h).to(device)
|
||||||
|
|
||||||
|
# Define discriminators. MPD is used by default
|
||||||
|
mpd = MultiPeriodDiscriminator(h).to(device)
|
||||||
|
|
||||||
|
# Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
|
||||||
|
# New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
|
||||||
|
if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD
|
||||||
|
print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
|
||||||
|
# Variable name is kept as "mrd" for backward compatibility & minimal code change
|
||||||
|
mrd = MultiBandDiscriminator(h).to(device)
|
||||||
|
elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD
|
||||||
|
print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
|
||||||
|
mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
|
||||||
|
else: # Fallback to original MRD in BigVGAN-v1
|
||||||
|
mrd = MultiResolutionDiscriminator(h).to(device)
|
||||||
|
|
||||||
|
# New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
|
||||||
|
if h.get("use_multiscale_melloss", False):
|
||||||
|
print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss")
|
||||||
|
fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
|
||||||
|
sampling_rate=h.sampling_rate
|
||||||
|
) # NOTE: accepts waveform as input
|
||||||
|
else:
|
||||||
|
fn_mel_loss_singlescale = F.l1_loss
|
||||||
|
|
||||||
|
# Print the model & number of parameters, and create or scan the latest checkpoint from checkpoints directory
|
||||||
|
if rank == 0:
|
||||||
|
print(generator)
|
||||||
|
print(mpd)
|
||||||
|
print(mrd)
|
||||||
|
print(f"Generator params: {sum(p.numel() for p in generator.parameters())}")
|
||||||
|
print(f"Discriminator mpd params: {sum(p.numel() for p in mpd.parameters())}")
|
||||||
|
print(f"Discriminator mrd params: {sum(p.numel() for p in mrd.parameters())}")
|
||||||
|
os.makedirs(a.checkpoint_path, exist_ok=True)
|
||||||
|
print(f"Checkpoints directory: {a.checkpoint_path}")
|
||||||
|
|
||||||
|
if os.path.isdir(a.checkpoint_path):
|
||||||
|
# New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
|
||||||
|
cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt")
|
||||||
|
cp_do = scan_checkpoint(
|
||||||
|
a.checkpoint_path,
|
||||||
|
prefix="do_",
|
||||||
|
renamed_file="bigvgan_discriminator_optimizer.pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load the latest checkpoint if exists
|
||||||
|
steps = 0
|
||||||
|
if cp_g is None or cp_do is None:
|
||||||
|
state_dict_do = None
|
||||||
|
last_epoch = -1
|
||||||
|
else:
|
||||||
|
state_dict_g = load_checkpoint(cp_g, device)
|
||||||
|
state_dict_do = load_checkpoint(cp_do, device)
|
||||||
|
generator.load_state_dict(state_dict_g["generator"])
|
||||||
|
mpd.load_state_dict(state_dict_do["mpd"])
|
||||||
|
mrd.load_state_dict(state_dict_do["mrd"])
|
||||||
|
steps = state_dict_do["steps"] + 1
|
||||||
|
last_epoch = state_dict_do["epoch"]
|
||||||
|
|
||||||
|
# Initialize DDP, optimizers, and schedulers
|
||||||
|
if h.num_gpus > 1:
|
||||||
|
generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
|
||||||
|
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
|
||||||
|
mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)
|
||||||
|
|
||||||
|
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
||||||
|
optim_d = torch.optim.AdamW(
|
||||||
|
itertools.chain(mrd.parameters(), mpd.parameters()),
|
||||||
|
h.learning_rate,
|
||||||
|
betas=[h.adam_b1, h.adam_b2],
|
||||||
|
)
|
||||||
|
|
||||||
|
if state_dict_do is not None:
|
||||||
|
optim_g.load_state_dict(state_dict_do["optim_g"])
|
||||||
|
optim_d.load_state_dict(state_dict_do["optim_d"])
|
||||||
|
|
||||||
|
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
|
||||||
|
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
|
||||||
|
|
||||||
|
# Define training and validation datasets
|
||||||
|
|
||||||
|
"""
|
||||||
|
unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
|
||||||
|
Example: trained on LibriTTS, validate on VCTK
|
||||||
|
"""
|
||||||
|
training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a)
|
||||||
|
|
||||||
|
trainset = MelDataset(
|
||||||
|
training_filelist,
|
||||||
|
h,
|
||||||
|
h.segment_size,
|
||||||
|
h.n_fft,
|
||||||
|
h.num_mels,
|
||||||
|
h.hop_size,
|
||||||
|
h.win_size,
|
||||||
|
h.sampling_rate,
|
||||||
|
h.fmin,
|
||||||
|
h.fmax,
|
||||||
|
shuffle=False if h.num_gpus > 1 else True,
|
||||||
|
fmax_loss=h.fmax_for_loss,
|
||||||
|
device=device,
|
||||||
|
fine_tuning=a.fine_tuning,
|
||||||
|
base_mels_path=a.input_mels_dir,
|
||||||
|
is_seen=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
trainset,
|
||||||
|
num_workers=h.num_workers,
|
||||||
|
shuffle=False,
|
||||||
|
sampler=train_sampler,
|
||||||
|
batch_size=h.batch_size,
|
||||||
|
pin_memory=True,
|
||||||
|
drop_last=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
validset = MelDataset(
|
||||||
|
validation_filelist,
|
||||||
|
h,
|
||||||
|
h.segment_size,
|
||||||
|
h.n_fft,
|
||||||
|
h.num_mels,
|
||||||
|
h.hop_size,
|
||||||
|
h.win_size,
|
||||||
|
h.sampling_rate,
|
||||||
|
h.fmin,
|
||||||
|
h.fmax,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
fmax_loss=h.fmax_for_loss,
|
||||||
|
device=device,
|
||||||
|
fine_tuning=a.fine_tuning,
|
||||||
|
base_mels_path=a.input_mels_dir,
|
||||||
|
is_seen=True,
|
||||||
|
)
|
||||||
|
validation_loader = DataLoader(
|
||||||
|
validset,
|
||||||
|
num_workers=1,
|
||||||
|
shuffle=False,
|
||||||
|
sampler=None,
|
||||||
|
batch_size=1,
|
||||||
|
pin_memory=True,
|
||||||
|
drop_last=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
list_unseen_validset = []
|
||||||
|
list_unseen_validation_loader = []
|
||||||
|
for i in range(len(list_unseen_validation_filelist)):
|
||||||
|
unseen_validset = MelDataset(
|
||||||
|
list_unseen_validation_filelist[i],
|
||||||
|
h,
|
||||||
|
h.segment_size,
|
||||||
|
h.n_fft,
|
||||||
|
h.num_mels,
|
||||||
|
h.hop_size,
|
||||||
|
h.win_size,
|
||||||
|
h.sampling_rate,
|
||||||
|
h.fmin,
|
||||||
|
h.fmax,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
fmax_loss=h.fmax_for_loss,
|
||||||
|
device=device,
|
||||||
|
fine_tuning=a.fine_tuning,
|
||||||
|
base_mels_path=a.input_mels_dir,
|
||||||
|
is_seen=False,
|
||||||
|
)
|
||||||
|
unseen_validation_loader = DataLoader(
|
||||||
|
unseen_validset,
|
||||||
|
num_workers=1,
|
||||||
|
shuffle=False,
|
||||||
|
sampler=None,
|
||||||
|
batch_size=1,
|
||||||
|
pin_memory=True,
|
||||||
|
drop_last=True,
|
||||||
|
)
|
||||||
|
list_unseen_validset.append(unseen_validset)
|
||||||
|
list_unseen_validation_loader.append(unseen_validation_loader)
|
||||||
|
|
||||||
|
# Tensorboard logger
|
||||||
|
sw = SummaryWriter(os.path.join(a.checkpoint_path, "logs"))
|
||||||
|
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||||
|
os.makedirs(os.path.join(a.checkpoint_path, "samples"), exist_ok=True)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Validation loop, "mode" parameter is automatically defined as (seen or unseen)_(name of the dataset).
|
||||||
|
If the name of the dataset contains "nonspeech", it skips PESQ calculation to prevent errors
|
||||||
|
"""
|
||||||
|
|
||||||
|
def validate(rank, a, h, loader, mode="seen"):
|
||||||
|
assert rank == 0, "validate should only run on rank=0"
|
||||||
|
generator.eval()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
val_err_tot = 0
|
||||||
|
val_pesq_tot = 0
|
||||||
|
val_mrstft_tot = 0
|
||||||
|
|
||||||
|
# Modules for evaluation metrics
|
||||||
|
pesq_resampler = ta.transforms.Resample(h.sampling_rate, 16000).cuda()
|
||||||
|
loss_mrstft = auraloss.freq.MultiResolutionSTFTLoss(device="cuda")
|
||||||
|
|
||||||
|
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||||
|
os.makedirs(
|
||||||
|
os.path.join(a.checkpoint_path, "samples", f"gt_{mode}"),
|
||||||
|
exist_ok=True,
|
||||||
|
)
|
||||||
|
os.makedirs(
|
||||||
|
os.path.join(a.checkpoint_path, "samples", f"{mode}_{steps:08d}"),
|
||||||
|
exist_ok=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
print(f"step {steps} {mode} speaker validation...")
|
||||||
|
|
||||||
|
# Loop over validation set and compute metrics
|
||||||
|
for j, batch in enumerate(tqdm(loader)):
|
||||||
|
x, y, _, y_mel = batch
|
||||||
|
y = y.to(device)
|
||||||
|
if hasattr(generator, "module"):
|
||||||
|
y_g_hat = generator.module(x.to(device))
|
||||||
|
else:
|
||||||
|
y_g_hat = generator(x.to(device))
|
||||||
|
y_mel = y_mel.to(device, non_blocking=True)
|
||||||
|
y_g_hat_mel = mel_spectrogram(
|
||||||
|
y_g_hat.squeeze(1),
|
||||||
|
h.n_fft,
|
||||||
|
h.num_mels,
|
||||||
|
h.sampling_rate,
|
||||||
|
h.hop_size,
|
||||||
|
h.win_size,
|
||||||
|
h.fmin,
|
||||||
|
h.fmax_for_loss,
|
||||||
|
)
|
||||||
|
min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
|
||||||
|
val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item()
|
||||||
|
|
||||||
|
# PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
|
||||||
|
if "nonspeech" not in mode: # Skips if the name of dataset (in mode string) contains "nonspeech"
|
||||||
|
# Resample to 16000 for pesq
|
||||||
|
y_16k = pesq_resampler(y)
|
||||||
|
y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
|
||||||
|
y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||||
|
y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||||
|
val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")
|
||||||
|
|
||||||
|
# MRSTFT calculation
|
||||||
|
min_t = min(y.size(-1), y_g_hat.size(-1))
|
||||||
|
val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item()
|
||||||
|
|
||||||
|
# Log audio and figures to Tensorboard
|
||||||
|
if j % a.eval_subsample == 0: # Subsample every nth from validation set
|
||||||
|
if steps >= 0:
|
||||||
|
sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
|
||||||
|
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||||
|
save_audio(
|
||||||
|
y[0],
|
||||||
|
os.path.join(
|
||||||
|
a.checkpoint_path,
|
||||||
|
"samples",
|
||||||
|
f"gt_{mode}",
|
||||||
|
f"{j:04d}.wav",
|
||||||
|
),
|
||||||
|
h.sampling_rate,
|
||||||
|
)
|
||||||
|
sw.add_figure(
|
||||||
|
f"gt_{mode}/y_spec_{j}",
|
||||||
|
plot_spectrogram(x[0]),
|
||||||
|
steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
sw.add_audio(
|
||||||
|
f"generated_{mode}/y_hat_{j}",
|
||||||
|
y_g_hat[0],
|
||||||
|
steps,
|
||||||
|
h.sampling_rate,
|
||||||
|
)
|
||||||
|
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||||
|
save_audio(
|
||||||
|
y_g_hat[0, 0],
|
||||||
|
os.path.join(
|
||||||
|
a.checkpoint_path,
|
||||||
|
"samples",
|
||||||
|
f"{mode}_{steps:08d}",
|
||||||
|
f"{j:04d}.wav",
|
||||||
|
),
|
||||||
|
h.sampling_rate,
|
||||||
|
)
|
||||||
|
# Spectrogram of synthesized audio
|
||||||
|
y_hat_spec = mel_spectrogram(
|
||||||
|
y_g_hat.squeeze(1),
|
||||||
|
h.n_fft,
|
||||||
|
h.num_mels,
|
||||||
|
h.sampling_rate,
|
||||||
|
h.hop_size,
|
||||||
|
h.win_size,
|
||||||
|
h.fmin,
|
||||||
|
h.fmax,
|
||||||
|
)
|
||||||
|
sw.add_figure(
|
||||||
|
f"generated_{mode}/y_hat_spec_{j}",
|
||||||
|
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()),
|
||||||
|
steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Visualization of spectrogram difference between GT and synthesized audio, difference higher than 1 is clipped for better visualization.
|
||||||
|
"""
|
||||||
|
spec_delta = torch.clamp(
|
||||||
|
torch.abs(x[0] - y_hat_spec.squeeze(0).cpu()),
|
||||||
|
min=1e-6,
|
||||||
|
max=1.0,
|
||||||
|
)
|
||||||
|
sw.add_figure(
|
||||||
|
f"delta_dclip1_{mode}/spec_{j}",
|
||||||
|
plot_spectrogram_clipped(spec_delta.numpy(), clip_max=1.0),
|
||||||
|
steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
val_err = val_err_tot / (j + 1)
|
||||||
|
val_pesq = val_pesq_tot / (j + 1)
|
||||||
|
val_mrstft = val_mrstft_tot / (j + 1)
|
||||||
|
# Log evaluation metrics to Tensorboard
|
||||||
|
sw.add_scalar(f"validation_{mode}/mel_spec_error", val_err, steps)
|
||||||
|
sw.add_scalar(f"validation_{mode}/pesq", val_pesq, steps)
|
||||||
|
sw.add_scalar(f"validation_{mode}/mrstft", val_mrstft, steps)
|
||||||
|
|
||||||
|
generator.train()
|
||||||
|
|
||||||
|
# If the checkpoint is loaded, start with validation loop
|
||||||
|
if steps != 0 and rank == 0 and not a.debug:
|
||||||
|
if not a.skip_seen:
|
||||||
|
validate(
|
||||||
|
rank,
|
||||||
|
a,
|
||||||
|
h,
|
||||||
|
validation_loader,
|
||||||
|
mode=f"seen_{train_loader.dataset.name}",
|
||||||
|
)
|
||||||
|
for i in range(len(list_unseen_validation_loader)):
|
||||||
|
validate(
|
||||||
|
rank,
|
||||||
|
a,
|
||||||
|
h,
|
||||||
|
list_unseen_validation_loader[i],
|
||||||
|
mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}",
|
||||||
|
)
|
||||||
|
# Exit the script if --evaluate is set to True
|
||||||
|
if a.evaluate:
|
||||||
|
exit()
|
||||||
|
|
||||||
|
# Main training loop
|
||||||
|
generator.train()
|
||||||
|
mpd.train()
|
||||||
|
mrd.train()
|
||||||
|
for epoch in range(max(0, last_epoch), a.training_epochs):
|
||||||
|
if rank == 0:
|
||||||
|
start = time.time()
|
||||||
|
print(f"Epoch: {epoch + 1}")
|
||||||
|
|
||||||
|
if h.num_gpus > 1:
|
||||||
|
train_sampler.set_epoch(epoch)
|
||||||
|
|
||||||
|
for i, batch in enumerate(train_loader):
|
||||||
|
if rank == 0:
|
||||||
|
start_b = time.time()
|
||||||
|
x, y, _, y_mel = batch
|
||||||
|
|
||||||
|
x = x.to(device, non_blocking=True)
|
||||||
|
y = y.to(device, non_blocking=True)
|
||||||
|
y_mel = y_mel.to(device, non_blocking=True)
|
||||||
|
y = y.unsqueeze(1)
|
||||||
|
|
||||||
|
y_g_hat = generator(x)
|
||||||
|
y_g_hat_mel = mel_spectrogram(
|
||||||
|
y_g_hat.squeeze(1),
|
||||||
|
h.n_fft,
|
||||||
|
h.num_mels,
|
||||||
|
h.sampling_rate,
|
||||||
|
h.hop_size,
|
||||||
|
h.win_size,
|
||||||
|
h.fmin,
|
||||||
|
h.fmax_for_loss,
|
||||||
|
)
|
||||||
|
|
||||||
|
optim_d.zero_grad()
|
||||||
|
|
||||||
|
# MPD
|
||||||
|
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
||||||
|
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
||||||
|
|
||||||
|
# MRD
|
||||||
|
y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
|
||||||
|
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
||||||
|
|
||||||
|
loss_disc_all = loss_disc_s + loss_disc_f
|
||||||
|
|
||||||
|
# Set clip_grad_norm value
|
||||||
|
clip_grad_norm = h.get("clip_grad_norm", 1000.0) # Default to 1000
|
||||||
|
|
||||||
|
# Whether to freeze D for initial training steps
|
||||||
|
if steps >= a.freeze_step:
|
||||||
|
loss_disc_all.backward()
|
||||||
|
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm)
|
||||||
|
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm)
|
||||||
|
optim_d.step()
|
||||||
|
else:
|
||||||
|
print(f"[WARNING] skipping D training for the first {a.freeze_step} steps")
|
||||||
|
grad_norm_mpd = 0.0
|
||||||
|
grad_norm_mrd = 0.0
|
||||||
|
|
||||||
|
# Generator
|
||||||
|
optim_g.zero_grad()
|
||||||
|
|
||||||
|
# L1 Mel-Spectrogram Loss
|
||||||
|
lambda_melloss = h.get("lambda_melloss", 45.0) # Defaults to 45 in BigVGAN-v1 if not set
|
||||||
|
if h.get("use_multiscale_melloss", False): # uses wav <y, y_g_hat> for loss
|
||||||
|
loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
|
||||||
|
else: # Uses mel <y_mel, y_g_hat_mel> for loss
|
||||||
|
loss_mel = fn_mel_loss_singlescale(y_mel, y_g_hat_mel) * lambda_melloss
|
||||||
|
|
||||||
|
# MPD loss
|
||||||
|
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
||||||
|
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
||||||
|
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
||||||
|
|
||||||
|
# MRD loss
|
||||||
|
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(y, y_g_hat)
|
||||||
|
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
||||||
|
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
||||||
|
|
||||||
|
if steps >= a.freeze_step:
|
||||||
|
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
||||||
|
else:
|
||||||
|
print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps")
|
||||||
|
loss_gen_all = loss_mel
|
||||||
|
|
||||||
|
loss_gen_all.backward()
|
||||||
|
grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm)
|
||||||
|
optim_g.step()
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
# STDOUT logging
|
||||||
|
if steps % a.stdout_interval == 0:
|
||||||
|
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to stdout
|
||||||
|
print(
|
||||||
|
f"Steps: {steps:d}, "
|
||||||
|
f"Gen Loss Total: {loss_gen_all:4.3f}, "
|
||||||
|
f"Mel Error: {mel_error:4.3f}, "
|
||||||
|
f"s/b: {time.time() - start_b:4.3f} "
|
||||||
|
f"lr: {optim_g.param_groups[0]['lr']:4.7f} "
|
||||||
|
f"grad_norm_g: {grad_norm_g:4.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Checkpointing
|
||||||
|
if steps % a.checkpoint_interval == 0 and steps != 0:
|
||||||
|
checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
|
||||||
|
save_checkpoint(
|
||||||
|
checkpoint_path,
|
||||||
|
{"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()},
|
||||||
|
)
|
||||||
|
checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
|
||||||
|
save_checkpoint(
|
||||||
|
checkpoint_path,
|
||||||
|
{
|
||||||
|
"mpd": (mpd.module if h.num_gpus > 1 else mpd).state_dict(),
|
||||||
|
"mrd": (mrd.module if h.num_gpus > 1 else mrd).state_dict(),
|
||||||
|
"optim_g": optim_g.state_dict(),
|
||||||
|
"optim_d": optim_d.state_dict(),
|
||||||
|
"steps": steps,
|
||||||
|
"epoch": epoch,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tensorboard summary logging
|
||||||
|
if steps % a.summary_interval == 0:
|
||||||
|
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to tensorboard
|
||||||
|
sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
|
||||||
|
sw.add_scalar("training/mel_spec_error", mel_error, steps)
|
||||||
|
sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
|
||||||
|
sw.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps)
|
||||||
|
sw.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps)
|
||||||
|
sw.add_scalar("training/grad_norm_mpd", grad_norm_mpd, steps)
|
||||||
|
sw.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps)
|
||||||
|
sw.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps)
|
||||||
|
sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
|
||||||
|
sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
|
||||||
|
sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
|
||||||
|
sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
|
||||||
|
sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
|
||||||
|
sw.add_scalar("training/epoch", epoch + 1, steps)
|
||||||
|
|
||||||
|
# Validation
|
||||||
|
if steps % a.validation_interval == 0:
|
||||||
|
# Plot training input x so far used
|
||||||
|
for i_x in range(x.shape[0]):
|
||||||
|
sw.add_figure(
|
||||||
|
f"training_input/x_{i_x}",
|
||||||
|
plot_spectrogram(x[i_x].cpu()),
|
||||||
|
steps,
|
||||||
|
)
|
||||||
|
sw.add_audio(
|
||||||
|
f"training_input/y_{i_x}",
|
||||||
|
y[i_x][0],
|
||||||
|
steps,
|
||||||
|
h.sampling_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Seen and unseen speakers validation loops
|
||||||
|
if not a.debug and steps != 0:
|
||||||
|
validate(
|
||||||
|
rank,
|
||||||
|
a,
|
||||||
|
h,
|
||||||
|
validation_loader,
|
||||||
|
mode=f"seen_{train_loader.dataset.name}",
|
||||||
|
)
|
||||||
|
for i in range(len(list_unseen_validation_loader)):
|
||||||
|
validate(
|
||||||
|
rank,
|
||||||
|
a,
|
||||||
|
h,
|
||||||
|
list_unseen_validation_loader[i],
|
||||||
|
mode=f"unseen_{list_unseen_validation_loader[i].dataset.name}",
|
||||||
|
)
|
||||||
|
steps += 1
|
||||||
|
|
||||||
|
# BigVGAN-v2 learning rate scheduler is changed from epoch-level to step-level
|
||||||
|
scheduler_g.step()
|
||||||
|
scheduler_d.step()
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("Initializing Training Process..")
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--group_name", default=None)
|
||||||
|
|
||||||
|
parser.add_argument("--input_wavs_dir", default="LibriTTS")
|
||||||
|
parser.add_argument("--input_mels_dir", default="ft_dataset")
|
||||||
|
parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt")
|
||||||
|
parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--list_input_unseen_wavs_dir",
|
||||||
|
nargs="+",
|
||||||
|
default=["tests/LibriTTS", "tests/LibriTTS"],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--list_input_unseen_validation_file",
|
||||||
|
nargs="+",
|
||||||
|
default=["tests/LibriTTS/dev-clean.txt", "tests/LibriTTS/dev-other.txt"],
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--checkpoint_path", default="exp/bigvgan")
|
||||||
|
parser.add_argument("--config", default="")
|
||||||
|
|
||||||
|
parser.add_argument("--training_epochs", default=100000, type=int)
|
||||||
|
parser.add_argument("--stdout_interval", default=5, type=int)
|
||||||
|
parser.add_argument("--checkpoint_interval", default=50000, type=int)
|
||||||
|
parser.add_argument("--summary_interval", default=100, type=int)
|
||||||
|
parser.add_argument("--validation_interval", default=50000, type=int)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--freeze_step",
|
||||||
|
default=0,
|
||||||
|
type=int,
|
||||||
|
help="freeze D for the first specified steps. G only uses regression loss for these steps.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--fine_tuning", default=False, type=bool)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug",
|
||||||
|
default=False,
|
||||||
|
type=bool,
|
||||||
|
help="debug mode. skips validation loop throughout training",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--evaluate",
|
||||||
|
default=False,
|
||||||
|
type=bool,
|
||||||
|
help="only run evaluation from checkpoint and exit",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--eval_subsample",
|
||||||
|
default=5,
|
||||||
|
type=int,
|
||||||
|
help="subsampling during evaluation loop",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_seen",
|
||||||
|
default=False,
|
||||||
|
type=bool,
|
||||||
|
help="skip seen dataset. useful for test set inference",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save_audio",
|
||||||
|
default=False,
|
||||||
|
type=bool,
|
||||||
|
help="save audio of test set inference to disk",
|
||||||
|
)
|
||||||
|
|
||||||
|
a = parser.parse_args()
|
||||||
|
|
||||||
|
with open(a.config) as f:
|
||||||
|
data = f.read()
|
||||||
|
|
||||||
|
json_config = json.loads(data)
|
||||||
|
h = AttrDict(json_config)
|
||||||
|
|
||||||
|
build_env(a.config, "config.json", a.checkpoint_path)
|
||||||
|
|
||||||
|
torch.manual_seed(h.seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed(h.seed)
|
||||||
|
h.num_gpus = torch.cuda.device_count()
|
||||||
|
h.batch_size = int(h.batch_size / h.num_gpus)
|
||||||
|
print(f"Batch size per GPU: {h.batch_size}")
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if h.num_gpus > 1:
|
||||||
|
mp.spawn(
|
||||||
|
train,
|
||||||
|
nprocs=h.num_gpus,
|
||||||
|
args=(
|
||||||
|
a,
|
||||||
|
h,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
train(0, a, h)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
99
GPT_SoVITS/BigVGAN/utils0.py
Normal file
99
GPT_SoVITS/BigVGAN/utils0.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import matplotlib
|
||||||
|
import torch
|
||||||
|
from torch.nn.utils import weight_norm
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pylab as plt
|
||||||
|
from .meldataset import MAX_WAV_VALUE
|
||||||
|
from scipy.io.wavfile import write
|
||||||
|
|
||||||
|
|
||||||
|
def plot_spectrogram(spectrogram):
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 2))
|
||||||
|
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
||||||
|
plt.colorbar(im, ax=ax)
|
||||||
|
|
||||||
|
fig.canvas.draw()
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 2))
|
||||||
|
im = ax.imshow(
|
||||||
|
spectrogram,
|
||||||
|
aspect="auto",
|
||||||
|
origin="lower",
|
||||||
|
interpolation="none",
|
||||||
|
vmin=1e-6,
|
||||||
|
vmax=clip_max,
|
||||||
|
)
|
||||||
|
plt.colorbar(im, ax=ax)
|
||||||
|
|
||||||
|
fig.canvas.draw()
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def init_weights(m, mean=0.0, std=0.01):
|
||||||
|
classname = m.__class__.__name__
|
||||||
|
if classname.find("Conv") != -1:
|
||||||
|
m.weight.data.normal_(mean, std)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_weight_norm(m):
|
||||||
|
classname = m.__class__.__name__
|
||||||
|
if classname.find("Conv") != -1:
|
||||||
|
weight_norm(m)
|
||||||
|
|
||||||
|
|
||||||
|
def get_padding(kernel_size, dilation=1):
|
||||||
|
return int((kernel_size * dilation - dilation) / 2)
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(filepath, device):
|
||||||
|
assert os.path.isfile(filepath)
|
||||||
|
print(f"Loading '{filepath}'")
|
||||||
|
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||||
|
print("Complete.")
|
||||||
|
return checkpoint_dict
|
||||||
|
|
||||||
|
|
||||||
|
def save_checkpoint(filepath, obj):
|
||||||
|
print(f"Saving checkpoint to {filepath}")
|
||||||
|
torch.save(obj, filepath)
|
||||||
|
print("Complete.")
|
||||||
|
|
||||||
|
|
||||||
|
def scan_checkpoint(cp_dir, prefix, renamed_file=None):
|
||||||
|
# Fallback to original scanning logic first
|
||||||
|
pattern = os.path.join(cp_dir, prefix + "????????")
|
||||||
|
cp_list = glob.glob(pattern)
|
||||||
|
|
||||||
|
if len(cp_list) > 0:
|
||||||
|
last_checkpoint_path = sorted(cp_list)[-1]
|
||||||
|
print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
|
||||||
|
return last_checkpoint_path
|
||||||
|
|
||||||
|
# If no pattern-based checkpoints are found, check for renamed file
|
||||||
|
if renamed_file:
|
||||||
|
renamed_path = os.path.join(cp_dir, renamed_file)
|
||||||
|
if os.path.isfile(renamed_path):
|
||||||
|
print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
|
||||||
|
return renamed_path
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def save_audio(audio, path, sr):
|
||||||
|
# wav: torch with 1d shape
|
||||||
|
audio = audio * MAX_WAV_VALUE
|
||||||
|
audio = audio.cpu().numpy().astype("int16")
|
||||||
|
write(path, sr, audio)
|
File diff suppressed because it is too large
Load Diff
@ -1,13 +1,15 @@
|
|||||||
|
import os
|
||||||
import os, sys
|
import sys
|
||||||
|
import threading
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
sys.path.append(now_dir)
|
sys.path.append(now_dir)
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
import LangSegment
|
from text.LangSegmenter import LangSegmenter
|
||||||
from text import chinese
|
from text import chinese
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
from text.cleaner import clean_text
|
from text.cleaner import clean_text
|
||||||
@ -17,17 +19,19 @@ from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_
|
|||||||
|
|
||||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||||
|
|
||||||
language=os.environ.get("language","Auto")
|
language = os.environ.get("language", "Auto")
|
||||||
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||||
i18n = I18nAuto(language=language)
|
i18n = I18nAuto(language=language)
|
||||||
punctuation = set(['!', '?', '…', ',', '.', '-'," "])
|
punctuation = set(["!", "?", "…", ",", ".", "-"])
|
||||||
|
|
||||||
def get_first(text:str) -> str:
|
|
||||||
|
def get_first(text: str) -> str:
|
||||||
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
|
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
|
||||||
text = re.split(pattern, text)[0].strip()
|
text = re.split(pattern, text)[0].strip()
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def merge_short_text_in_array(texts:str, threshold:int) -> list:
|
|
||||||
|
def merge_short_text_in_array(texts: str, threshold: int) -> list:
|
||||||
if (len(texts)) < 2:
|
if (len(texts)) < 2:
|
||||||
return texts
|
return texts
|
||||||
result = []
|
result = []
|
||||||
@ -37,7 +41,7 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
|
|||||||
if len(text) >= threshold:
|
if len(text) >= threshold:
|
||||||
result.append(text)
|
result.append(text)
|
||||||
text = ""
|
text = ""
|
||||||
if (len(text) > 0):
|
if len(text) > 0:
|
||||||
if len(result) == 0:
|
if len(result) == 0:
|
||||||
result.append(text)
|
result.append(text)
|
||||||
else:
|
else:
|
||||||
@ -45,27 +49,24 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TextPreprocessor:
|
class TextPreprocessor:
|
||||||
def __init__(self, bert_model:AutoModelForMaskedLM,
|
def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
|
||||||
tokenizer:AutoTokenizer, device:torch.device):
|
|
||||||
self.bert_model = bert_model
|
self.bert_model = bert_model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.bert_lock = threading.RLock()
|
||||||
def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v1")->List[Dict]:
|
|
||||||
print(i18n("############ 切分文本 ############"))
|
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
|
||||||
text = self.replace_consecutive_punctuation(text) # 变量命名应该是写错了
|
print(f"############ {i18n('切分文本')} ############")
|
||||||
|
text = self.replace_consecutive_punctuation(text)
|
||||||
texts = self.pre_seg_text(text, lang, text_split_method)
|
texts = self.pre_seg_text(text, lang, text_split_method)
|
||||||
result = []
|
result = []
|
||||||
print(i18n("############ 提取文本Bert特征 ############"))
|
print(f"############ {i18n('提取文本Bert特征')} ############")
|
||||||
for text in tqdm(texts):
|
for text in tqdm(texts):
|
||||||
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
|
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
|
||||||
if phones is None or norm_text=="":
|
if phones is None or norm_text == "":
|
||||||
continue
|
continue
|
||||||
res={
|
res = {
|
||||||
"phones": phones,
|
"phones": phones,
|
||||||
"bert_features": bert_features,
|
"bert_features": bert_features,
|
||||||
"norm_text": norm_text,
|
"norm_text": norm_text,
|
||||||
@ -73,18 +74,18 @@ class TextPreprocessor:
|
|||||||
result.append(res)
|
result.append(res)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def pre_seg_text(self, text:str, lang:str, text_split_method:str):
|
def pre_seg_text(self, text: str, lang: str, text_split_method: str):
|
||||||
text = text.strip("\n")
|
text = text.strip("\n")
|
||||||
if len(text) == 0:
|
if len(text) == 0:
|
||||||
return []
|
return []
|
||||||
if (text[0] not in splits and len(get_first(text)) < 4):
|
if text[0] not in splits and len(get_first(text)) < 4:
|
||||||
text = "。" + text if lang != "en" else "." + text
|
text = "。" + text if lang != "en" else "." + text
|
||||||
print(i18n("实际输入的目标文本:"))
|
print(i18n("实际输入的目标文本:"))
|
||||||
print(text)
|
print(text)
|
||||||
|
|
||||||
seg_method = get_seg_method(text_split_method)
|
seg_method = get_seg_method(text_split_method)
|
||||||
text = seg_method(text)
|
text = seg_method(text)
|
||||||
|
|
||||||
while "\n\n" in text:
|
while "\n\n" in text:
|
||||||
text = text.replace("\n\n", "\n")
|
text = text.replace("\n\n", "\n")
|
||||||
|
|
||||||
@ -93,103 +94,99 @@ class TextPreprocessor:
|
|||||||
_texts = merge_short_text_in_array(_texts, 5)
|
_texts = merge_short_text_in_array(_texts, 5)
|
||||||
texts = []
|
texts = []
|
||||||
|
|
||||||
|
|
||||||
for text in _texts:
|
for text in _texts:
|
||||||
# 解决输入目标文本的空行导致报错的问题
|
# 解决输入目标文本的空行导致报错的问题
|
||||||
if (len(text.strip()) == 0):
|
if len(text.strip()) == 0:
|
||||||
continue
|
continue
|
||||||
if not re.sub("\W+", "", text):
|
if not re.sub("\W+", "", text):
|
||||||
# 检测一下,如果是纯符号,就跳过。
|
# 检测一下,如果是纯符号,就跳过。
|
||||||
continue
|
continue
|
||||||
if (text[-1] not in splits): text += "。" if lang != "en" else "."
|
if text[-1] not in splits:
|
||||||
|
text += "。" if lang != "en" else "."
|
||||||
|
|
||||||
# 解决句子过长导致Bert报错的问题
|
# 解决句子过长导致Bert报错的问题
|
||||||
if (len(text) > 510):
|
if len(text) > 510:
|
||||||
texts.extend(split_big_text(text))
|
texts.extend(split_big_text(text))
|
||||||
else:
|
else:
|
||||||
texts.append(text)
|
texts.append(text)
|
||||||
|
|
||||||
print(i18n("实际输入的目标文本(切句后):"))
|
print(i18n("实际输入的目标文本(切句后):"))
|
||||||
print(texts)
|
print(texts)
|
||||||
return texts
|
return texts
|
||||||
|
|
||||||
def segment_and_extract_feature_for_text(self, text:str, language:str, version:str="v1")->Tuple[list, torch.Tensor, str]:
|
def segment_and_extract_feature_for_text(
|
||||||
|
self, text: str, language: str, version: str = "v1"
|
||||||
|
) -> Tuple[list, torch.Tensor, str]:
|
||||||
return self.get_phones_and_bert(text, language, version)
|
return self.get_phones_and_bert(text, language, version)
|
||||||
|
|
||||||
def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
|
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
|
||||||
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
|
with self.bert_lock:
|
||||||
language = language.replace("all_","")
|
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
|
||||||
if language == "en":
|
# language = language.replace("all_","")
|
||||||
LangSegment.setfilters(["en"])
|
|
||||||
formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
|
|
||||||
else:
|
|
||||||
# 因无法区别中日韩文汉字,以用户输入为准
|
|
||||||
formattext = text
|
formattext = text
|
||||||
while " " in formattext:
|
while " " in formattext:
|
||||||
formattext = formattext.replace(" ", " ")
|
formattext = formattext.replace(" ", " ")
|
||||||
if language == "zh":
|
if language == "all_zh":
|
||||||
if re.search(r'[A-Za-z]', formattext):
|
if re.search(r"[A-Za-z]", formattext):
|
||||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
|
||||||
|
formattext = chinese.mix_text_normalize(formattext)
|
||||||
|
return self.get_phones_and_bert(formattext, "zh", version)
|
||||||
|
else:
|
||||||
|
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
||||||
|
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
||||||
|
elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
|
||||||
|
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
|
||||||
formattext = chinese.mix_text_normalize(formattext)
|
formattext = chinese.mix_text_normalize(formattext)
|
||||||
return self.get_phones_and_bert(formattext,"zh",version)
|
return self.get_phones_and_bert(formattext, "yue", version)
|
||||||
else:
|
else:
|
||||||
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
||||||
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
bert = torch.zeros(
|
||||||
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
|
(1024, len(phones)),
|
||||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
dtype=torch.float32,
|
||||||
formattext = chinese.mix_text_normalize(formattext)
|
).to(self.device)
|
||||||
return self.get_phones_and_bert(formattext,"yue",version)
|
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
|
||||||
else:
|
textlist = []
|
||||||
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
langlist = []
|
||||||
bert = torch.zeros(
|
if language == "auto":
|
||||||
(1024, len(phones)),
|
for tmp in LangSegmenter.getTexts(text):
|
||||||
dtype=torch.float32,
|
|
||||||
).to(self.device)
|
|
||||||
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
|
|
||||||
textlist=[]
|
|
||||||
langlist=[]
|
|
||||||
LangSegment.setfilters(["zh","ja","en","ko"])
|
|
||||||
if language == "auto":
|
|
||||||
for tmp in LangSegment.getTexts(text):
|
|
||||||
langlist.append(tmp["lang"])
|
|
||||||
textlist.append(tmp["text"])
|
|
||||||
elif language == "auto_yue":
|
|
||||||
for tmp in LangSegment.getTexts(text):
|
|
||||||
if tmp["lang"] == "zh":
|
|
||||||
tmp["lang"] = "yue"
|
|
||||||
langlist.append(tmp["lang"])
|
|
||||||
textlist.append(tmp["text"])
|
|
||||||
else:
|
|
||||||
for tmp in LangSegment.getTexts(text):
|
|
||||||
if tmp["lang"] == "en":
|
|
||||||
langlist.append(tmp["lang"])
|
langlist.append(tmp["lang"])
|
||||||
else:
|
textlist.append(tmp["text"])
|
||||||
# 因无法区别中日韩文汉字,以用户输入为准
|
elif language == "auto_yue":
|
||||||
langlist.append(language)
|
for tmp in LangSegmenter.getTexts(text):
|
||||||
textlist.append(tmp["text"])
|
if tmp["lang"] == "zh":
|
||||||
# print(textlist)
|
tmp["lang"] = "yue"
|
||||||
# print(langlist)
|
langlist.append(tmp["lang"])
|
||||||
phones_list = []
|
textlist.append(tmp["text"])
|
||||||
bert_list = []
|
else:
|
||||||
norm_text_list = []
|
for tmp in LangSegmenter.getTexts(text):
|
||||||
for i in range(len(textlist)):
|
if tmp["lang"] == "en":
|
||||||
lang = langlist[i]
|
langlist.append(tmp["lang"])
|
||||||
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
|
else:
|
||||||
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
|
# 因无法区别中日韩文汉字,以用户输入为准
|
||||||
phones_list.append(phones)
|
langlist.append(language)
|
||||||
norm_text_list.append(norm_text)
|
textlist.append(tmp["text"])
|
||||||
bert_list.append(bert)
|
# print(textlist)
|
||||||
bert = torch.cat(bert_list, dim=1)
|
# print(langlist)
|
||||||
phones = sum(phones_list, [])
|
phones_list = []
|
||||||
norm_text = ''.join(norm_text_list)
|
bert_list = []
|
||||||
|
norm_text_list = []
|
||||||
|
for i in range(len(textlist)):
|
||||||
|
lang = langlist[i]
|
||||||
|
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
|
||||||
|
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
|
||||||
|
phones_list.append(phones)
|
||||||
|
norm_text_list.append(norm_text)
|
||||||
|
bert_list.append(bert)
|
||||||
|
bert = torch.cat(bert_list, dim=1)
|
||||||
|
phones = sum(phones_list, [])
|
||||||
|
norm_text = "".join(norm_text_list)
|
||||||
|
|
||||||
if not final and len(phones) < 6:
|
if not final and len(phones) < 6:
|
||||||
return self.get_phones_and_bert("." + text,language,version,final=True)
|
return self.get_phones_and_bert("." + text, language, version, final=True)
|
||||||
|
|
||||||
return phones, bert, norm_text
|
return phones, bert, norm_text
|
||||||
|
|
||||||
|
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
|
||||||
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = self.tokenizer(text, return_tensors="pt")
|
inputs = self.tokenizer(text, return_tensors="pt")
|
||||||
for i in inputs:
|
for i in inputs:
|
||||||
@ -203,14 +200,15 @@ class TextPreprocessor:
|
|||||||
phone_level_feature.append(repeat_feature)
|
phone_level_feature.append(repeat_feature)
|
||||||
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
||||||
return phone_level_feature.T
|
return phone_level_feature.T
|
||||||
|
|
||||||
def clean_text_inf(self, text:str, language:str, version:str="v1"):
|
def clean_text_inf(self, text: str, language: str, version: str = "v2"):
|
||||||
|
language = language.replace("all_", "")
|
||||||
phones, word2ph, norm_text = clean_text(text, language, version)
|
phones, word2ph, norm_text = clean_text(text, language, version)
|
||||||
phones = cleaned_text_to_sequence(phones, version)
|
phones = cleaned_text_to_sequence(phones, version)
|
||||||
return phones, word2ph, norm_text
|
return phones, word2ph, norm_text
|
||||||
|
|
||||||
def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str):
|
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
|
||||||
language=language.replace("all_","")
|
language = language.replace("all_", "")
|
||||||
if language == "zh":
|
if language == "zh":
|
||||||
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
||||||
else:
|
else:
|
||||||
@ -221,24 +219,19 @@ class TextPreprocessor:
|
|||||||
|
|
||||||
return feature
|
return feature
|
||||||
|
|
||||||
|
def filter_text(self, texts):
|
||||||
def filter_text(self,texts):
|
_text = []
|
||||||
_text=[]
|
if all(text in [None, " ", "\n", ""] for text in texts):
|
||||||
if all(text in [None, " ", "\n",""] for text in texts):
|
|
||||||
raise ValueError(i18n("请输入有效文本"))
|
raise ValueError(i18n("请输入有效文本"))
|
||||||
for text in texts:
|
for text in texts:
|
||||||
if text in [None, " ", ""]:
|
if text in [None, " ", ""]:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
_text.append(text)
|
_text.append(text)
|
||||||
return _text
|
return _text
|
||||||
|
|
||||||
|
|
||||||
def replace_consecutive_punctuation(self,text):
|
def replace_consecutive_punctuation(self, text):
|
||||||
punctuations = ''.join(re.escape(p) for p in punctuation)
|
punctuations = "".join(re.escape(p) for p in punctuation)
|
||||||
pattern = f'([{punctuations}])([{punctuations}])+'
|
pattern = f"([{punctuations}])([{punctuations}])+"
|
||||||
result = re.sub(pattern, r'\1', text)
|
result = re.sub(pattern, r"\1", text)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -1 +1 @@
|
|||||||
from . import TTS, text_segmentation_method
|
from . import TTS, text_segmentation_method
|
||||||
|
@ -1,41 +1,57 @@
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
punctuation = set(['!', '?', '…', ',', '.', '-'," "])
|
punctuation = set(["!", "?", "…", ",", ".", "-", " "])
|
||||||
METHODS = dict()
|
METHODS = dict()
|
||||||
|
|
||||||
def get_method(name:str)->Callable:
|
|
||||||
|
def get_method(name: str) -> Callable:
|
||||||
method = METHODS.get(name, None)
|
method = METHODS.get(name, None)
|
||||||
if method is None:
|
if method is None:
|
||||||
raise ValueError(f"Method {name} not found")
|
raise ValueError(f"Method {name} not found")
|
||||||
return method
|
return method
|
||||||
|
|
||||||
def get_method_names()->list:
|
|
||||||
|
def get_method_names() -> list:
|
||||||
return list(METHODS.keys())
|
return list(METHODS.keys())
|
||||||
|
|
||||||
|
|
||||||
def register_method(name):
|
def register_method(name):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
METHODS[name] = func
|
METHODS[name] = func
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
|
|
||||||
|
splits = {
|
||||||
|
",",
|
||||||
|
"。",
|
||||||
|
"?",
|
||||||
|
"!",
|
||||||
|
",",
|
||||||
|
".",
|
||||||
|
"?",
|
||||||
|
"!",
|
||||||
|
"~",
|
||||||
|
":",
|
||||||
|
":",
|
||||||
|
"—",
|
||||||
|
"…",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def split_big_text(text, max_len=510):
|
def split_big_text(text, max_len=510):
|
||||||
# 定义全角和半角标点符号
|
# 定义全角和半角标点符号
|
||||||
punctuation = "".join(splits)
|
punctuation = "".join(splits)
|
||||||
|
|
||||||
# 切割文本
|
# 切割文本
|
||||||
segments = re.split('([' + punctuation + '])', text)
|
segments = re.split("([" + punctuation + "])", text)
|
||||||
|
|
||||||
# 初始化结果列表和当前片段
|
# 初始化结果列表和当前片段
|
||||||
result = []
|
result = []
|
||||||
current_segment = ''
|
current_segment = ""
|
||||||
|
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
# 如果当前片段加上新的片段长度超过max_len,就将当前片段加入结果列表,并重置当前片段
|
# 如果当前片段加上新的片段长度超过max_len,就将当前片段加入结果列表,并重置当前片段
|
||||||
if len(current_segment + segment) > max_len:
|
if len(current_segment + segment) > max_len:
|
||||||
@ -43,13 +59,12 @@ def split_big_text(text, max_len=510):
|
|||||||
current_segment = segment
|
current_segment = segment
|
||||||
else:
|
else:
|
||||||
current_segment += segment
|
current_segment += segment
|
||||||
|
|
||||||
# 将最后一个片段加入结果列表
|
# 将最后一个片段加入结果列表
|
||||||
if current_segment:
|
if current_segment:
|
||||||
result.append(current_segment)
|
result.append(current_segment)
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def split(todo_text):
|
def split(todo_text):
|
||||||
@ -90,7 +105,7 @@ def cut1(inp):
|
|||||||
if len(split_idx) > 1:
|
if len(split_idx) > 1:
|
||||||
opts = []
|
opts = []
|
||||||
for idx in range(len(split_idx) - 1):
|
for idx in range(len(split_idx) - 1):
|
||||||
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
|
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
|
||||||
else:
|
else:
|
||||||
opts = [inp]
|
opts = [inp]
|
||||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||||
@ -123,6 +138,7 @@ def cut2(inp):
|
|||||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||||
return "\n".join(opts)
|
return "\n".join(opts)
|
||||||
|
|
||||||
|
|
||||||
# 按中文句号。切
|
# 按中文句号。切
|
||||||
@register_method("cut3")
|
@register_method("cut3")
|
||||||
def cut3(inp):
|
def cut3(inp):
|
||||||
@ -131,26 +147,28 @@ def cut3(inp):
|
|||||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||||
return "\n".join(opts)
|
return "\n".join(opts)
|
||||||
|
|
||||||
#按英文句号.切
|
|
||||||
|
# 按英文句号.切
|
||||||
@register_method("cut4")
|
@register_method("cut4")
|
||||||
def cut4(inp):
|
def cut4(inp):
|
||||||
inp = inp.strip("\n")
|
inp = inp.strip("\n")
|
||||||
opts = ["%s" % item for item in inp.strip(".").split(".")]
|
opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
|
||||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||||
return "\n".join(opts)
|
return "\n".join(opts)
|
||||||
|
|
||||||
|
|
||||||
# 按标点符号切
|
# 按标点符号切
|
||||||
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
|
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
|
||||||
@register_method("cut5")
|
@register_method("cut5")
|
||||||
def cut5(inp):
|
def cut5(inp):
|
||||||
inp = inp.strip("\n")
|
inp = inp.strip("\n")
|
||||||
punds = {',', '.', ';', '?', '!', '、', ',', '。', '?', '!', ';', ':', '…'}
|
punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
|
||||||
mergeitems = []
|
mergeitems = []
|
||||||
items = []
|
items = []
|
||||||
|
|
||||||
for i, char in enumerate(inp):
|
for i, char in enumerate(inp):
|
||||||
if char in punds:
|
if char in punds:
|
||||||
if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
|
if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
|
||||||
items.append(char)
|
items.append(char)
|
||||||
else:
|
else:
|
||||||
items.append(char)
|
items.append(char)
|
||||||
@ -166,8 +184,6 @@ def cut5(inp):
|
|||||||
return "\n".join(opt)
|
return "\n".join(opt)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
if __name__ == '__main__':
|
|
||||||
method = get_method("cut5")
|
method = get_method("cut5")
|
||||||
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))
|
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))
|
||||||
|
|
||||||
|
1
GPT_SoVITS/configs/.gitignore
vendored
Normal file
1
GPT_SoVITS/configs/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
*.yaml
|
@ -18,7 +18,8 @@
|
|||||||
"warmup_epochs": 0,
|
"warmup_epochs": 0,
|
||||||
"c_mel": 45,
|
"c_mel": 45,
|
||||||
"c_kl": 1.0,
|
"c_kl": 1.0,
|
||||||
"text_low_lr_rate": 0.4
|
"text_low_lr_rate": 0.4,
|
||||||
|
"grad_ckpt": false
|
||||||
},
|
},
|
||||||
"data": {
|
"data": {
|
||||||
"max_wav_value": 32768.0,
|
"max_wav_value": 32768.0,
|
||||||
|
@ -6,7 +6,7 @@ custom:
|
|||||||
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
|
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
|
||||||
version: v2
|
version: v2
|
||||||
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
|
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
|
||||||
default:
|
v1:
|
||||||
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
|
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
|
||||||
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
||||||
device: cpu
|
device: cpu
|
||||||
@ -14,7 +14,7 @@ default:
|
|||||||
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
||||||
version: v1
|
version: v1
|
||||||
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
|
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
|
||||||
default_v2:
|
v2:
|
||||||
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
|
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
|
||||||
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
||||||
device: cpu
|
device: cpu
|
||||||
@ -22,3 +22,19 @@ default_v2:
|
|||||||
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
|
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
|
||||||
version: v2
|
version: v2
|
||||||
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
|
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
|
||||||
|
v3:
|
||||||
|
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
|
||||||
|
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
||||||
|
device: cpu
|
||||||
|
is_half: false
|
||||||
|
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
|
||||||
|
version: v3
|
||||||
|
vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
|
||||||
|
v4:
|
||||||
|
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
|
||||||
|
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
||||||
|
device: cpu
|
||||||
|
is_half: false
|
||||||
|
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
|
||||||
|
version: v4
|
||||||
|
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth
|
||||||
|
@ -1,5 +1,13 @@
|
|||||||
import os, sys
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
sys.path.insert(0, now_dir)
|
sys.path.insert(0, now_dir)
|
||||||
from text.g2pw import G2PWPinyin
|
from text.g2pw import G2PWPinyin
|
||||||
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
|
|
||||||
|
g2pw = G2PWPinyin(
|
||||||
|
model_dir="GPT_SoVITS/text/G2PWModel",
|
||||||
|
model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||||
|
v_to_u=False,
|
||||||
|
neutral_tone_with_five=True,
|
||||||
|
)
|
||||||
|
861
GPT_SoVITS/export_torch_script.py
Normal file
861
GPT_SoVITS/export_torch_script.py
Normal file
@ -0,0 +1,861 @@
|
|||||||
|
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
||||||
|
# reference: https://github.com/lifeiteng/vall-e
|
||||||
|
import argparse
|
||||||
|
from typing import Optional
|
||||||
|
from my_utils import load_audio
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
from torch import IntTensor, LongTensor, Tensor, nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||||
|
from feature_extractor import cnhubert
|
||||||
|
|
||||||
|
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
|
||||||
|
from module.models_onnx import SynthesizerTrn
|
||||||
|
|
||||||
|
from inference_webui import get_phones_and_bert
|
||||||
|
|
||||||
|
import os
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
default_config = {
|
||||||
|
"embedding_dim": 512,
|
||||||
|
"hidden_dim": 512,
|
||||||
|
"num_head": 8,
|
||||||
|
"num_layers": 12,
|
||||||
|
"num_codebook": 8,
|
||||||
|
"p_dropout": 0.0,
|
||||||
|
"vocab_size": 1024 + 1,
|
||||||
|
"phoneme_vocab_size": 512,
|
||||||
|
"EOS": 1024,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
|
||||||
|
config = dict_s1["config"]
|
||||||
|
config["model"]["dropout"] = float(config["model"]["dropout"])
|
||||||
|
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
|
||||||
|
t2s_model.load_state_dict(dict_s1["weight"])
|
||||||
|
t2s_model = t2s_model.eval()
|
||||||
|
return t2s_model
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def logits_to_probs(
|
||||||
|
logits,
|
||||||
|
previous_tokens: Optional[torch.Tensor] = None,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_p: Optional[int] = None,
|
||||||
|
repetition_penalty: float = 1.0,
|
||||||
|
):
|
||||||
|
# if previous_tokens is not None:
|
||||||
|
# previous_tokens = previous_tokens.squeeze()
|
||||||
|
# print(logits.shape,previous_tokens.shape)
|
||||||
|
# pdb.set_trace()
|
||||||
|
if previous_tokens is not None and repetition_penalty != 1.0:
|
||||||
|
previous_tokens = previous_tokens.long()
|
||||||
|
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||||
|
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
|
||||||
|
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||||
|
|
||||||
|
if top_p is not None and top_p < 1.0:
|
||||||
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||||
|
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
||||||
|
sorted_indices_to_remove = cum_probs > top_p
|
||||||
|
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
||||||
|
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
||||||
|
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||||
|
|
||||||
|
logits = logits / max(temperature, 1e-5)
|
||||||
|
|
||||||
|
if top_k is not None:
|
||||||
|
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||||
|
pivot = v[:, -1].unsqueeze(-1)
|
||||||
|
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||||
|
|
||||||
|
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def multinomial_sample_one_no_sync(probs_sort):
|
||||||
|
# Does multinomial sampling without a cuda synchronization
|
||||||
|
q = torch.randn_like(probs_sort)
|
||||||
|
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def sample(
|
||||||
|
logits,
|
||||||
|
previous_tokens,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_k: Optional[int] = None,
|
||||||
|
top_p: Optional[int] = None,
|
||||||
|
repetition_penalty: float = 1.0,
|
||||||
|
):
|
||||||
|
probs = logits_to_probs(
|
||||||
|
logits=logits,
|
||||||
|
previous_tokens=previous_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
)
|
||||||
|
idx_next = multinomial_sample_one_no_sync(probs)
|
||||||
|
return idx_next, probs
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
|
||||||
|
hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
|
||||||
|
y = torch.nn.functional.pad(
|
||||||
|
y.unsqueeze(1),
|
||||||
|
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||||
|
mode="reflect",
|
||||||
|
)
|
||||||
|
y = y.squeeze(1)
|
||||||
|
spec = torch.stft(
|
||||||
|
y,
|
||||||
|
n_fft,
|
||||||
|
hop_length=hop_size,
|
||||||
|
win_length=win_size,
|
||||||
|
window=hann_window,
|
||||||
|
center=center,
|
||||||
|
pad_mode="reflect",
|
||||||
|
normalized=False,
|
||||||
|
onesided=True,
|
||||||
|
return_complex=False,
|
||||||
|
)
|
||||||
|
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||||
|
return spec
|
||||||
|
|
||||||
|
|
||||||
|
class DictToAttrRecursive(dict):
|
||||||
|
def __init__(self, input_dict):
|
||||||
|
super().__init__(input_dict)
|
||||||
|
for key, value in input_dict.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = DictToAttrRecursive(value)
|
||||||
|
self[key] = value
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
try:
|
||||||
|
return self[item]
|
||||||
|
except KeyError:
|
||||||
|
raise AttributeError(f"Attribute {item} not found")
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = DictToAttrRecursive(value)
|
||||||
|
super(DictToAttrRecursive, self).__setitem__(key, value)
|
||||||
|
super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __delattr__(self, item):
|
||||||
|
try:
|
||||||
|
del self[item]
|
||||||
|
except KeyError:
|
||||||
|
raise AttributeError(f"Attribute {item} not found")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
class T2SMLP:
|
||||||
|
def __init__(self, w1, b1, w2, b2):
|
||||||
|
self.w1 = w1
|
||||||
|
self.b1 = b1
|
||||||
|
self.w2 = w2
|
||||||
|
self.b2 = b2
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.relu(F.linear(x, self.w1, self.b1))
|
||||||
|
x = F.linear(x, self.w2, self.b2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
class T2SBlock:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
mlp: T2SMLP,
|
||||||
|
qkv_w,
|
||||||
|
qkv_b,
|
||||||
|
out_w,
|
||||||
|
out_b,
|
||||||
|
norm_w1,
|
||||||
|
norm_b1,
|
||||||
|
norm_eps1: float,
|
||||||
|
norm_w2,
|
||||||
|
norm_b2,
|
||||||
|
norm_eps2: float,
|
||||||
|
):
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.mlp = mlp
|
||||||
|
self.hidden_dim: int = hidden_dim
|
||||||
|
self.qkv_w = qkv_w
|
||||||
|
self.qkv_b = qkv_b
|
||||||
|
self.out_w = out_w
|
||||||
|
self.out_b = out_b
|
||||||
|
self.norm_w1 = norm_w1
|
||||||
|
self.norm_b1 = norm_b1
|
||||||
|
self.norm_eps1 = norm_eps1
|
||||||
|
self.norm_w2 = norm_w2
|
||||||
|
self.norm_b2 = norm_b2
|
||||||
|
self.norm_eps2 = norm_eps2
|
||||||
|
|
||||||
|
self.false = torch.tensor(False, dtype=torch.bool)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def to_mask(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]):
|
||||||
|
if padding_mask is None:
|
||||||
|
return x
|
||||||
|
|
||||||
|
if padding_mask.dtype == torch.bool:
|
||||||
|
return x.masked_fill(padding_mask, 0)
|
||||||
|
else:
|
||||||
|
return x * padding_mask
|
||||||
|
|
||||||
|
def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
|
||||||
|
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||||
|
|
||||||
|
batch_size = q.shape[0]
|
||||||
|
q_len = q.shape[1]
|
||||||
|
kv_len = k.shape[1]
|
||||||
|
|
||||||
|
q = self.to_mask(q, padding_mask)
|
||||||
|
k_cache = self.to_mask(k, padding_mask)
|
||||||
|
v_cache = self.to_mask(v, padding_mask)
|
||||||
|
|
||||||
|
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
|
||||||
|
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||||
|
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||||
|
|
||||||
|
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
|
||||||
|
|
||||||
|
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
||||||
|
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||||
|
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
||||||
|
|
||||||
|
if padding_mask is not None:
|
||||||
|
for i in range(batch_size):
|
||||||
|
# mask = padding_mask[i,:,0]
|
||||||
|
if self.false.device != padding_mask.device:
|
||||||
|
self.false = self.false.to(padding_mask.device)
|
||||||
|
idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
|
||||||
|
x_item = x[i, idx, :].unsqueeze(0)
|
||||||
|
attn_item = attn[i, idx, :].unsqueeze(0)
|
||||||
|
x_item = x_item + attn_item
|
||||||
|
x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||||
|
x_item = x_item + self.mlp.forward(x_item)
|
||||||
|
x_item = F.layer_norm(
|
||||||
|
x_item,
|
||||||
|
[self.hidden_dim],
|
||||||
|
self.norm_w2,
|
||||||
|
self.norm_b2,
|
||||||
|
self.norm_eps2,
|
||||||
|
)
|
||||||
|
x[i, idx, :] = x_item.squeeze(0)
|
||||||
|
x = self.to_mask(x, padding_mask)
|
||||||
|
else:
|
||||||
|
x = x + attn
|
||||||
|
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||||
|
x = x + self.mlp.forward(x)
|
||||||
|
x = F.layer_norm(
|
||||||
|
x,
|
||||||
|
[self.hidden_dim],
|
||||||
|
self.norm_w2,
|
||||||
|
self.norm_b2,
|
||||||
|
self.norm_eps2,
|
||||||
|
)
|
||||||
|
return x, k_cache, v_cache
|
||||||
|
|
||||||
|
def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
|
||||||
|
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||||
|
|
||||||
|
k_cache = torch.cat([k_cache, k], dim=1)
|
||||||
|
v_cache = torch.cat([v_cache, v], dim=1)
|
||||||
|
|
||||||
|
batch_size = q.shape[0]
|
||||||
|
q_len = q.shape[1]
|
||||||
|
kv_len = k_cache.shape[1]
|
||||||
|
|
||||||
|
q = q.view(batch_size, q_len, self.num_heads, -1).transpose(1, 2)
|
||||||
|
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||||
|
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||||
|
|
||||||
|
attn = F.scaled_dot_product_attention(q, k, v)
|
||||||
|
|
||||||
|
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
||||||
|
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||||
|
attn = F.linear(attn, self.out_w, self.out_b)
|
||||||
|
|
||||||
|
x = x + attn
|
||||||
|
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||||
|
x = x + self.mlp.forward(x)
|
||||||
|
x = F.layer_norm(
|
||||||
|
x,
|
||||||
|
[self.hidden_dim],
|
||||||
|
self.norm_w2,
|
||||||
|
self.norm_b2,
|
||||||
|
self.norm_eps2,
|
||||||
|
)
|
||||||
|
return x, k_cache, v_cache
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
class T2STransformer:
|
||||||
|
def __init__(self, num_blocks: int, blocks: list[T2SBlock]):
|
||||||
|
self.num_blocks: int = num_blocks
|
||||||
|
self.blocks = blocks
|
||||||
|
|
||||||
|
def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
|
||||||
|
k_cache: list[torch.Tensor] = []
|
||||||
|
v_cache: list[torch.Tensor] = []
|
||||||
|
for i in range(self.num_blocks):
|
||||||
|
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
|
||||||
|
k_cache.append(k_cache_)
|
||||||
|
v_cache.append(v_cache_)
|
||||||
|
return x, k_cache, v_cache
|
||||||
|
|
||||||
|
def decode_next_token(self, x: torch.Tensor, k_cache: list[torch.Tensor], v_cache: list[torch.Tensor]):
|
||||||
|
for i in range(self.num_blocks):
|
||||||
|
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
|
||||||
|
return x, k_cache, v_cache
|
||||||
|
|
||||||
|
|
||||||
|
class VitsModel(nn.Module):
|
||||||
|
def __init__(self, vits_path):
|
||||||
|
super().__init__()
|
||||||
|
# dict_s2 = torch.load(vits_path,map_location="cpu")
|
||||||
|
dict_s2 = torch.load(vits_path)
|
||||||
|
self.hps = dict_s2["config"]
|
||||||
|
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
||||||
|
self.hps["model"]["version"] = "v1"
|
||||||
|
else:
|
||||||
|
self.hps["model"]["version"] = "v2"
|
||||||
|
|
||||||
|
self.hps = DictToAttrRecursive(self.hps)
|
||||||
|
self.hps.model.semantic_frame_rate = "25hz"
|
||||||
|
self.vq_model = SynthesizerTrn(
|
||||||
|
self.hps.data.filter_length // 2 + 1,
|
||||||
|
self.hps.train.segment_size // self.hps.data.hop_length,
|
||||||
|
n_speakers=self.hps.data.n_speakers,
|
||||||
|
**self.hps.model,
|
||||||
|
)
|
||||||
|
self.vq_model.eval()
|
||||||
|
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||||
|
|
||||||
|
def forward(self, text_seq, pred_semantic, ref_audio, speed=1.0):
|
||||||
|
refer = spectrogram_torch(
|
||||||
|
ref_audio,
|
||||||
|
self.hps.data.filter_length,
|
||||||
|
self.hps.data.sampling_rate,
|
||||||
|
self.hps.data.hop_length,
|
||||||
|
self.hps.data.win_length,
|
||||||
|
center=False,
|
||||||
|
)
|
||||||
|
return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
|
||||||
|
|
||||||
|
|
||||||
|
class T2SModel(nn.Module):
|
||||||
|
def __init__(self, raw_t2s: Text2SemanticLightningModule):
|
||||||
|
super(T2SModel, self).__init__()
|
||||||
|
self.model_dim = raw_t2s.model.model_dim
|
||||||
|
self.embedding_dim = raw_t2s.model.embedding_dim
|
||||||
|
self.num_head = raw_t2s.model.num_head
|
||||||
|
self.num_layers = raw_t2s.model.num_layers
|
||||||
|
self.vocab_size = raw_t2s.model.vocab_size
|
||||||
|
self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size
|
||||||
|
# self.p_dropout = float(raw_t2s.model.p_dropout)
|
||||||
|
self.EOS: int = int(raw_t2s.model.EOS)
|
||||||
|
self.norm_first = raw_t2s.model.norm_first
|
||||||
|
assert self.EOS == self.vocab_size - 1
|
||||||
|
self.hz = 50
|
||||||
|
|
||||||
|
self.bert_proj = raw_t2s.model.bert_proj
|
||||||
|
self.ar_text_embedding = raw_t2s.model.ar_text_embedding
|
||||||
|
self.ar_text_position = raw_t2s.model.ar_text_position
|
||||||
|
self.ar_audio_embedding = raw_t2s.model.ar_audio_embedding
|
||||||
|
self.ar_audio_position = raw_t2s.model.ar_audio_position
|
||||||
|
|
||||||
|
# self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
||||||
|
# self.t2s_transformer = raw_t2s.model.t2s_transformer
|
||||||
|
|
||||||
|
blocks = []
|
||||||
|
h = raw_t2s.model.h
|
||||||
|
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
layer = h.layers[i]
|
||||||
|
t2smlp = T2SMLP(layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias)
|
||||||
|
|
||||||
|
block = T2SBlock(
|
||||||
|
self.num_head,
|
||||||
|
self.model_dim,
|
||||||
|
t2smlp,
|
||||||
|
layer.self_attn.in_proj_weight,
|
||||||
|
layer.self_attn.in_proj_bias,
|
||||||
|
layer.self_attn.out_proj.weight,
|
||||||
|
layer.self_attn.out_proj.bias,
|
||||||
|
layer.norm1.weight,
|
||||||
|
layer.norm1.bias,
|
||||||
|
layer.norm1.eps,
|
||||||
|
layer.norm2.weight,
|
||||||
|
layer.norm2.bias,
|
||||||
|
layer.norm2.eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
blocks.append(block)
|
||||||
|
|
||||||
|
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
||||||
|
|
||||||
|
# self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
|
||||||
|
self.ar_predict_layer = raw_t2s.model.ar_predict_layer
|
||||||
|
# self.loss_fct = nn.CrossEntropyLoss(reduction="sum")
|
||||||
|
self.max_sec = raw_t2s.config["data"]["max_sec"]
|
||||||
|
self.top_k = int(raw_t2s.config["inference"]["top_k"])
|
||||||
|
self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
prompts: LongTensor,
|
||||||
|
ref_seq: LongTensor,
|
||||||
|
text_seq: LongTensor,
|
||||||
|
ref_bert: torch.Tensor,
|
||||||
|
text_bert: torch.Tensor,
|
||||||
|
top_k: LongTensor,
|
||||||
|
):
|
||||||
|
bert = torch.cat([ref_bert.T, text_bert.T], 1)
|
||||||
|
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
||||||
|
bert = bert.unsqueeze(0)
|
||||||
|
|
||||||
|
x = self.ar_text_embedding(all_phoneme_ids)
|
||||||
|
x = x + self.bert_proj(bert.transpose(1, 2))
|
||||||
|
x: torch.Tensor = self.ar_text_position(x)
|
||||||
|
|
||||||
|
early_stop_num = self.early_stop_num
|
||||||
|
|
||||||
|
# [1,N,512] [1,N]
|
||||||
|
# y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
||||||
|
y = prompts
|
||||||
|
# x_example = x[:,:,0] * 0.0
|
||||||
|
|
||||||
|
x_len = x.shape[1]
|
||||||
|
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
||||||
|
|
||||||
|
y_emb = self.ar_audio_embedding(y)
|
||||||
|
y_len = y_emb.shape[1]
|
||||||
|
prefix_len = y.shape[1]
|
||||||
|
y_pos = self.ar_audio_position(y_emb)
|
||||||
|
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||||
|
|
||||||
|
bsz = x.shape[0]
|
||||||
|
src_len = x_len + y_len
|
||||||
|
x_attn_mask_pad = F.pad(
|
||||||
|
x_attn_mask,
|
||||||
|
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
||||||
|
value=True,
|
||||||
|
)
|
||||||
|
y_attn_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||||
|
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||||
|
(x_len, 0),
|
||||||
|
value=False,
|
||||||
|
)
|
||||||
|
xy_attn_mask = (
|
||||||
|
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(bsz * self.num_head, -1, -1)
|
||||||
|
.view(bsz, self.num_head, src_len, src_len)
|
||||||
|
.to(device=x.device, dtype=torch.bool)
|
||||||
|
)
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
top_k = int(top_k)
|
||||||
|
|
||||||
|
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
|
||||||
|
|
||||||
|
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||||
|
logits = logits[:, :-1]
|
||||||
|
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
||||||
|
y = torch.concat([y, samples], dim=1)
|
||||||
|
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||||
|
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||||
|
:, y_len + idx
|
||||||
|
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||||
|
|
||||||
|
stop = False
|
||||||
|
# for idx in range(1, 50):
|
||||||
|
for idx in range(1, 1500):
|
||||||
|
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||||
|
# y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
|
||||||
|
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
|
||||||
|
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||||
|
|
||||||
|
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
||||||
|
logits = logits[:, :-1]
|
||||||
|
|
||||||
|
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
||||||
|
|
||||||
|
y = torch.concat([y, samples], dim=1)
|
||||||
|
|
||||||
|
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||||
|
stop = True
|
||||||
|
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
||||||
|
stop = True
|
||||||
|
if stop:
|
||||||
|
if y.shape[1] == 0:
|
||||||
|
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||||
|
break
|
||||||
|
|
||||||
|
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||||
|
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||||
|
:, y_len + idx
|
||||||
|
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||||
|
|
||||||
|
y[0, -1] = 0
|
||||||
|
|
||||||
|
return y[:, -idx:].unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large")
|
||||||
|
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
||||||
|
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def build_phone_level_feature(res: Tensor, word2ph: IntTensor):
|
||||||
|
phone_level_feature = []
|
||||||
|
for i in range(word2ph.shape[0]):
|
||||||
|
repeat_feature = res[i].repeat(word2ph[i].item(), 1)
|
||||||
|
phone_level_feature.append(repeat_feature)
|
||||||
|
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
||||||
|
# [sum(word2ph), 1024]
|
||||||
|
return phone_level_feature
|
||||||
|
|
||||||
|
|
||||||
|
class MyBertModel(torch.nn.Module):
|
||||||
|
def __init__(self, bert_model):
|
||||||
|
super(MyBertModel, self).__init__()
|
||||||
|
self.bert = bert_model
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: IntTensor
|
||||||
|
):
|
||||||
|
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
|
||||||
|
# res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
|
||||||
|
res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
|
||||||
|
return build_phone_level_feature(res, word2ph)
|
||||||
|
|
||||||
|
|
||||||
|
class SSLModel(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.ssl = cnhubert.get_model().model
|
||||||
|
|
||||||
|
def forward(self, ref_audio_16k) -> torch.Tensor:
|
||||||
|
ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
|
||||||
|
return ssl_content
|
||||||
|
|
||||||
|
|
||||||
|
class ExportSSLModel(torch.nn.Module):
|
||||||
|
def __init__(self, ssl: SSLModel):
|
||||||
|
super().__init__()
|
||||||
|
self.ssl = ssl
|
||||||
|
|
||||||
|
def forward(self, ref_audio: torch.Tensor):
|
||||||
|
return self.ssl(ref_audio)
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
|
def resample(self, ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
|
||||||
|
audio = resamplex(ref_audio, src_sr, dst_sr).float()
|
||||||
|
return audio
|
||||||
|
|
||||||
|
|
||||||
|
def export_bert(output_path):
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||||
|
|
||||||
|
text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么."
|
||||||
|
ref_bert_inputs = tokenizer(text, return_tensors="pt")
|
||||||
|
word2ph = []
|
||||||
|
for c in text:
|
||||||
|
if c in [",", "。", ":", "?", ",", ".", "?"]:
|
||||||
|
word2ph.append(1)
|
||||||
|
else:
|
||||||
|
word2ph.append(2)
|
||||||
|
ref_bert_inputs["word2ph"] = torch.Tensor(word2ph).int()
|
||||||
|
|
||||||
|
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)
|
||||||
|
my_bert_model = MyBertModel(bert_model)
|
||||||
|
|
||||||
|
ref_bert_inputs = {
|
||||||
|
"input_ids": ref_bert_inputs["input_ids"],
|
||||||
|
"attention_mask": ref_bert_inputs["attention_mask"],
|
||||||
|
"token_type_ids": ref_bert_inputs["token_type_ids"],
|
||||||
|
"word2ph": ref_bert_inputs["word2ph"],
|
||||||
|
}
|
||||||
|
|
||||||
|
torch._dynamo.mark_dynamic(ref_bert_inputs["input_ids"], 1)
|
||||||
|
torch._dynamo.mark_dynamic(ref_bert_inputs["attention_mask"], 1)
|
||||||
|
torch._dynamo.mark_dynamic(ref_bert_inputs["token_type_ids"], 1)
|
||||||
|
torch._dynamo.mark_dynamic(ref_bert_inputs["word2ph"], 0)
|
||||||
|
|
||||||
|
my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs)
|
||||||
|
output_path = os.path.join(output_path, "bert_model.pt")
|
||||||
|
my_bert_model.save(output_path)
|
||||||
|
print("#### exported bert ####")
|
||||||
|
|
||||||
|
|
||||||
|
def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device="cpu"):
|
||||||
|
if not os.path.exists(output_path):
|
||||||
|
os.makedirs(output_path)
|
||||||
|
print(f"目录已创建: {output_path}")
|
||||||
|
else:
|
||||||
|
print(f"目录已存在: {output_path}")
|
||||||
|
|
||||||
|
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
|
||||||
|
ssl = SSLModel()
|
||||||
|
if export_bert_and_ssl:
|
||||||
|
s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
|
||||||
|
ssl_path = os.path.join(output_path, "ssl_model.pt")
|
||||||
|
torch.jit.script(s).save(ssl_path)
|
||||||
|
print("#### exported ssl ####")
|
||||||
|
export_bert(output_path)
|
||||||
|
else:
|
||||||
|
s = ExportSSLModel(ssl)
|
||||||
|
|
||||||
|
print(f"device: {device}")
|
||||||
|
|
||||||
|
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
|
||||||
|
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
|
||||||
|
ref_bert = ref_bert_T.T.to(ref_seq.device)
|
||||||
|
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
|
||||||
|
"这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
|
||||||
|
)
|
||||||
|
text_seq = torch.LongTensor([text_seq_id]).to(device)
|
||||||
|
text_bert = text_bert_T.T.to(text_seq.device)
|
||||||
|
|
||||||
|
ssl_content = ssl(ref_audio).to(device)
|
||||||
|
|
||||||
|
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||||
|
vits = VitsModel(vits_path).to(device)
|
||||||
|
vits.eval()
|
||||||
|
|
||||||
|
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
||||||
|
# dict_s1 = torch.load(gpt_path, map_location=device)
|
||||||
|
dict_s1 = torch.load(gpt_path)
|
||||||
|
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
||||||
|
print("#### get_raw_t2s_model ####")
|
||||||
|
print(raw_t2s.config)
|
||||||
|
t2s_m = T2SModel(raw_t2s)
|
||||||
|
t2s_m.eval()
|
||||||
|
t2s = torch.jit.script(t2s_m).to(device)
|
||||||
|
print("#### script t2s_m ####")
|
||||||
|
|
||||||
|
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
|
||||||
|
gpt_sovits = GPT_SoVITS(t2s, vits).to(device)
|
||||||
|
gpt_sovits.eval()
|
||||||
|
|
||||||
|
ref_audio_sr = s.resample(ref_audio, 16000, 32000).to(device)
|
||||||
|
|
||||||
|
torch._dynamo.mark_dynamic(ssl_content, 2)
|
||||||
|
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
|
||||||
|
torch._dynamo.mark_dynamic(ref_seq, 1)
|
||||||
|
torch._dynamo.mark_dynamic(text_seq, 1)
|
||||||
|
torch._dynamo.mark_dynamic(ref_bert, 0)
|
||||||
|
torch._dynamo.mark_dynamic(text_bert, 0)
|
||||||
|
|
||||||
|
top_k = torch.LongTensor([5]).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
gpt_sovits_export = torch.jit.trace(
|
||||||
|
gpt_sovits, example_inputs=(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||||
|
)
|
||||||
|
|
||||||
|
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
|
||||||
|
gpt_sovits_export.save(gpt_sovits_path)
|
||||||
|
print("#### exported gpt_sovits ####")
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def parse_audio(ref_audio):
|
||||||
|
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device)
|
||||||
|
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, 32000).float() # .to(ref_audio.device)
|
||||||
|
return ref_audio_16k, ref_audio_sr
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def resamplex(ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
|
||||||
|
return torchaudio.functional.resample(ref_audio, src_sr, dst_sr).float()
|
||||||
|
|
||||||
|
|
||||||
|
class GPT_SoVITS(nn.Module):
|
||||||
|
def __init__(self, t2s: T2SModel, vits: VitsModel):
|
||||||
|
super().__init__()
|
||||||
|
self.t2s = t2s
|
||||||
|
self.vits = vits
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
ssl_content: torch.Tensor,
|
||||||
|
ref_audio_sr: torch.Tensor,
|
||||||
|
ref_seq: Tensor,
|
||||||
|
text_seq: Tensor,
|
||||||
|
ref_bert: Tensor,
|
||||||
|
text_bert: Tensor,
|
||||||
|
top_k: LongTensor,
|
||||||
|
speed=1.0,
|
||||||
|
):
|
||||||
|
codes = self.vits.vq_model.extract_latent(ssl_content)
|
||||||
|
prompt_semantic = codes[0, 0]
|
||||||
|
prompts = prompt_semantic.unsqueeze(0)
|
||||||
|
|
||||||
|
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||||
|
audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
|
||||||
|
return audio
|
||||||
|
|
||||||
|
|
||||||
|
def test():
|
||||||
|
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||||
|
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||||
|
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
||||||
|
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
||||||
|
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
||||||
|
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
gpt_path = args.gpt_model
|
||||||
|
vits_path = args.sovits_model
|
||||||
|
ref_audio_path = args.ref_audio
|
||||||
|
ref_text = args.ref_text
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||||
|
# bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
|
||||||
|
# bert = MyBertModel(bert_model)
|
||||||
|
my_bert = torch.jit.load("onnx/bert_model.pt", map_location="cuda")
|
||||||
|
|
||||||
|
# dict_s1 = torch.load(gpt_path, map_location="cuda")
|
||||||
|
# raw_t2s = get_raw_t2s_model(dict_s1)
|
||||||
|
# t2s = T2SModel(raw_t2s)
|
||||||
|
# t2s.eval()
|
||||||
|
# t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda')
|
||||||
|
|
||||||
|
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||||
|
# vits = VitsModel(vits_path)
|
||||||
|
# vits.eval()
|
||||||
|
|
||||||
|
# ssl = ExportSSLModel(SSLModel()).to('cuda')
|
||||||
|
# ssl.eval()
|
||||||
|
ssl = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda")
|
||||||
|
|
||||||
|
# gpt_sovits = GPT_SoVITS(t2s,vits)
|
||||||
|
gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda")
|
||||||
|
|
||||||
|
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
|
||||||
|
ref_seq = torch.LongTensor([ref_seq_id])
|
||||||
|
ref_bert = ref_bert_T.T.to(ref_seq.device)
|
||||||
|
# text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2')
|
||||||
|
text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."
|
||||||
|
|
||||||
|
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(text, "all_zh", "v2")
|
||||||
|
|
||||||
|
test_bert = tokenizer(text, return_tensors="pt")
|
||||||
|
word2ph = []
|
||||||
|
for c in text:
|
||||||
|
if c in [",", "。", ":", "?", "?", ",", "."]:
|
||||||
|
word2ph.append(1)
|
||||||
|
else:
|
||||||
|
word2ph.append(2)
|
||||||
|
test_bert["word2ph"] = torch.Tensor(word2ph).int()
|
||||||
|
|
||||||
|
test_bert = my_bert(
|
||||||
|
test_bert["input_ids"].to("cuda"),
|
||||||
|
test_bert["attention_mask"].to("cuda"),
|
||||||
|
test_bert["token_type_ids"].to("cuda"),
|
||||||
|
test_bert["word2ph"].to("cuda"),
|
||||||
|
)
|
||||||
|
|
||||||
|
text_seq = torch.LongTensor([text_seq_id])
|
||||||
|
text_bert = text_bert_T.T.to(text_seq.device)
|
||||||
|
|
||||||
|
print("text_bert:", text_bert.shape, text_bert)
|
||||||
|
print("test_bert:", test_bert.shape, test_bert)
|
||||||
|
print(torch.allclose(text_bert.to("cuda"), test_bert))
|
||||||
|
|
||||||
|
print("text_seq:", text_seq.shape)
|
||||||
|
print("text_bert:", text_bert.shape, text_bert.type())
|
||||||
|
|
||||||
|
# [1,N]
|
||||||
|
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to("cuda")
|
||||||
|
print("ref_audio:", ref_audio.shape)
|
||||||
|
|
||||||
|
ref_audio_sr = ssl.resample(ref_audio, 16000, 32000)
|
||||||
|
print("start ssl")
|
||||||
|
ssl_content = ssl(ref_audio)
|
||||||
|
|
||||||
|
print("start gpt_sovits:")
|
||||||
|
print("ssl_content:", ssl_content.shape)
|
||||||
|
print("ref_audio_sr:", ref_audio_sr.shape)
|
||||||
|
print("ref_seq:", ref_seq.shape)
|
||||||
|
ref_seq = ref_seq.to("cuda")
|
||||||
|
print("text_seq:", text_seq.shape)
|
||||||
|
text_seq = text_seq.to("cuda")
|
||||||
|
print("ref_bert:", ref_bert.shape)
|
||||||
|
ref_bert = ref_bert.to("cuda")
|
||||||
|
print("text_bert:", text_bert.shape)
|
||||||
|
text_bert = text_bert.to("cuda")
|
||||||
|
|
||||||
|
top_k = torch.LongTensor([5]).to("cuda")
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k)
|
||||||
|
print("start write wav")
|
||||||
|
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
|
||||||
|
|
||||||
|
|
||||||
|
import text
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def export_symbel(version="v2"):
|
||||||
|
if version == "v1":
|
||||||
|
symbols = text._symbol_to_id_v1
|
||||||
|
with open("onnx/symbols_v1.json", "w") as file:
|
||||||
|
json.dump(symbols, file, indent=4)
|
||||||
|
else:
|
||||||
|
symbols = text._symbol_to_id_v2
|
||||||
|
with open("onnx/symbols_v2.json", "w") as file:
|
||||||
|
json.dump(symbols, file, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||||
|
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||||
|
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
||||||
|
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
||||||
|
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
||||||
|
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
||||||
|
parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
|
||||||
|
parser.add_argument("--device", help="Device to use")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
export(
|
||||||
|
gpt_path=args.gpt_model,
|
||||||
|
vits_path=args.sovits_model,
|
||||||
|
ref_audio_path=args.ref_audio,
|
||||||
|
ref_text=args.ref_text,
|
||||||
|
output_path=args.output_path,
|
||||||
|
device=args.device,
|
||||||
|
export_bert_and_ssl=args.export_common_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
import inference_webui
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
inference_webui.is_half = False
|
||||||
|
inference_webui.dtype = torch.float32
|
||||||
|
main()
|
||||||
|
# test()
|
1035
GPT_SoVITS/export_torch_script_v3.py
Normal file
1035
GPT_SoVITS/export_torch_script_v3.py
Normal file
File diff suppressed because it is too large
Load Diff
13
GPT_SoVITS/f5_tts/model/__init__.py
Normal file
13
GPT_SoVITS/f5_tts/model/__init__.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# from f5_tts.model.cfm import CFM
|
||||||
|
#
|
||||||
|
# from f5_tts.model.backbones.unett import UNetT
|
||||||
|
from GPT_SoVITS.f5_tts.model.backbones.dit import DiT
|
||||||
|
# from f5_tts.model.backbones.dit import DiTNoCond
|
||||||
|
# from f5_tts.model.backbones.dit import DiTNoCondNoT
|
||||||
|
# from f5_tts.model.backbones.mmdit import MMDiT
|
||||||
|
|
||||||
|
# from f5_tts.model.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
# __all__ = ["CFM", "UNetT", "DiT", "MMDiT", "Trainer"]
|
||||||
|
# __all__ = ["CFM", "UNetT", "DiTNoCond","DiT", "MMDiT"]
|
20
GPT_SoVITS/f5_tts/model/backbones/README.md
Normal file
20
GPT_SoVITS/f5_tts/model/backbones/README.md
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
## Backbones quick introduction
|
||||||
|
|
||||||
|
|
||||||
|
### unett.py
|
||||||
|
- flat unet transformer
|
||||||
|
- structure same as in e2-tts & voicebox paper except using rotary pos emb
|
||||||
|
- update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
|
||||||
|
|
||||||
|
### dit.py
|
||||||
|
- adaln-zero dit
|
||||||
|
- embedded timestep as condition
|
||||||
|
- concatted noised_input + masked_cond + embedded_text, linear proj in
|
||||||
|
- possible abs pos emb & convnextv2 blocks for embedded text before concat
|
||||||
|
- possible long skip connection (first layer to last layer)
|
||||||
|
|
||||||
|
### mmdit.py
|
||||||
|
- sd3 structure
|
||||||
|
- timestep as condition
|
||||||
|
- left stream: text embedded and applied a abs pos emb
|
||||||
|
- right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
|
180
GPT_SoVITS/f5_tts/model/backbones/dit.py
Normal file
180
GPT_SoVITS/f5_tts/model/backbones/dit.py
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
"""
|
||||||
|
ein notation:
|
||||||
|
b - batch
|
||||||
|
n - sequence
|
||||||
|
nt - text sequence
|
||||||
|
nw - raw wave length
|
||||||
|
d - dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.utils.checkpoint import checkpoint
|
||||||
|
|
||||||
|
from x_transformers.x_transformers import RotaryEmbedding
|
||||||
|
|
||||||
|
from GPT_SoVITS.f5_tts.model.modules import (
|
||||||
|
TimestepEmbedding,
|
||||||
|
ConvNeXtV2Block,
|
||||||
|
ConvPositionEmbedding,
|
||||||
|
DiTBlock,
|
||||||
|
AdaLayerNormZero_Final,
|
||||||
|
precompute_freqs_cis,
|
||||||
|
get_pos_embed_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
from module.commons import sequence_mask
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbedding(nn.Module):
|
||||||
|
def __init__(self, text_dim, conv_layers=0, conv_mult=2):
|
||||||
|
super().__init__()
|
||||||
|
if conv_layers > 0:
|
||||||
|
self.extra_modeling = True
|
||||||
|
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
||||||
|
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
||||||
|
self.text_blocks = nn.Sequential(
|
||||||
|
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.extra_modeling = False
|
||||||
|
|
||||||
|
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
||||||
|
batch, text_len = text.shape[0], text.shape[1]
|
||||||
|
|
||||||
|
if drop_text: # cfg for text
|
||||||
|
text = torch.zeros_like(text)
|
||||||
|
|
||||||
|
# possible extra modeling
|
||||||
|
if self.extra_modeling:
|
||||||
|
# sinus pos emb
|
||||||
|
batch_start = torch.zeros((batch,), dtype=torch.long)
|
||||||
|
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
|
||||||
|
text_pos_embed = self.freqs_cis[pos_idx]
|
||||||
|
|
||||||
|
# print(23333333,text.shape,text_pos_embed.shape)#torch.Size([7, 465, 256]) torch.Size([7, 465, 256])
|
||||||
|
|
||||||
|
text = text + text_pos_embed
|
||||||
|
|
||||||
|
# convnextv2 blocks
|
||||||
|
text = self.text_blocks(text)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
# noised input audio and context mixing embedding
|
||||||
|
|
||||||
|
|
||||||
|
class InputEmbedding(nn.Module):
|
||||||
|
def __init__(self, mel_dim, text_dim, out_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
|
||||||
|
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
|
||||||
|
|
||||||
|
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
|
||||||
|
if drop_audio_cond: # cfg for cond audio
|
||||||
|
cond = torch.zeros_like(cond)
|
||||||
|
|
||||||
|
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
|
||||||
|
x = self.conv_pos_embed(x) + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# Transformer backbone using DiT blocks
|
||||||
|
|
||||||
|
|
||||||
|
class DiT(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dim,
|
||||||
|
depth=8,
|
||||||
|
heads=8,
|
||||||
|
dim_head=64,
|
||||||
|
dropout=0.1,
|
||||||
|
ff_mult=4,
|
||||||
|
mel_dim=100,
|
||||||
|
text_dim=None,
|
||||||
|
conv_layers=0,
|
||||||
|
long_skip_connection=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.time_embed = TimestepEmbedding(dim)
|
||||||
|
self.d_embed = TimestepEmbedding(dim)
|
||||||
|
if text_dim is None:
|
||||||
|
text_dim = mel_dim
|
||||||
|
self.text_embed = TextEmbedding(text_dim, conv_layers=conv_layers)
|
||||||
|
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
||||||
|
|
||||||
|
self.rotary_embed = RotaryEmbedding(dim_head)
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
|
||||||
|
)
|
||||||
|
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
|
||||||
|
|
||||||
|
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
||||||
|
self.proj_out = nn.Linear(dim, mel_dim)
|
||||||
|
|
||||||
|
def ckpt_wrapper(self, module):
|
||||||
|
# https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
|
||||||
|
def ckpt_forward(*inputs):
|
||||||
|
outputs = module(*inputs)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
return ckpt_forward
|
||||||
|
|
||||||
|
def forward( # x, prompt_x, x_lens, t, style,cond
|
||||||
|
self, # d is channel,n is T
|
||||||
|
x0: float["b n d"], # nosied input audio # noqa: F722
|
||||||
|
cond0: float["b n d"], # masked cond audio # noqa: F722
|
||||||
|
x_lens,
|
||||||
|
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||||
|
dt_base_bootstrap,
|
||||||
|
text0, # : int["b nt"] # noqa: F722#####condition feature
|
||||||
|
use_grad_ckpt=False, # bool
|
||||||
|
###no-use
|
||||||
|
drop_audio_cond=False, # cfg for cond audio
|
||||||
|
drop_text=False, # cfg for text
|
||||||
|
# mask: bool["b n"] | None = None, # noqa: F722
|
||||||
|
):
|
||||||
|
x = x0.transpose(2, 1)
|
||||||
|
cond = cond0.transpose(2, 1)
|
||||||
|
text = text0.transpose(2, 1)
|
||||||
|
mask = sequence_mask(x_lens, max_length=x.size(1)).to(x.device)
|
||||||
|
|
||||||
|
batch, seq_len = x.shape[0], x.shape[1]
|
||||||
|
if time.ndim == 0:
|
||||||
|
time = time.repeat(batch)
|
||||||
|
|
||||||
|
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||||
|
t = self.time_embed(time)
|
||||||
|
dt = self.d_embed(dt_base_bootstrap)
|
||||||
|
t += dt
|
||||||
|
text_embed = self.text_embed(text, seq_len, drop_text=drop_text) ###need to change
|
||||||
|
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||||
|
|
||||||
|
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||||
|
|
||||||
|
if self.long_skip_connection is not None:
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
if use_grad_ckpt:
|
||||||
|
x = checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
|
||||||
|
else:
|
||||||
|
x = block(x, t, mask=mask, rope=rope)
|
||||||
|
|
||||||
|
if self.long_skip_connection is not None:
|
||||||
|
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
|
||||||
|
|
||||||
|
x = self.norm_out(x, t)
|
||||||
|
output = self.proj_out(x)
|
||||||
|
|
||||||
|
return output
|
146
GPT_SoVITS/f5_tts/model/backbones/mmdit.py
Normal file
146
GPT_SoVITS/f5_tts/model/backbones/mmdit.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
"""
|
||||||
|
ein notation:
|
||||||
|
b - batch
|
||||||
|
n - sequence
|
||||||
|
nt - text sequence
|
||||||
|
nw - raw wave length
|
||||||
|
d - dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from x_transformers.x_transformers import RotaryEmbedding
|
||||||
|
|
||||||
|
from f5_tts.model.modules import (
|
||||||
|
TimestepEmbedding,
|
||||||
|
ConvPositionEmbedding,
|
||||||
|
MMDiTBlock,
|
||||||
|
AdaLayerNormZero_Final,
|
||||||
|
precompute_freqs_cis,
|
||||||
|
get_pos_embed_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# text embedding
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbedding(nn.Module):
|
||||||
|
def __init__(self, out_dim, text_num_embeds):
|
||||||
|
super().__init__()
|
||||||
|
self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
|
||||||
|
|
||||||
|
self.precompute_max_pos = 1024
|
||||||
|
self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
|
||||||
|
|
||||||
|
def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
|
||||||
|
text = text + 1
|
||||||
|
if drop_text:
|
||||||
|
text = torch.zeros_like(text)
|
||||||
|
text = self.text_embed(text)
|
||||||
|
|
||||||
|
# sinus pos emb
|
||||||
|
batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
|
||||||
|
batch_text_len = text.shape[1]
|
||||||
|
pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
|
||||||
|
text_pos_embed = self.freqs_cis[pos_idx]
|
||||||
|
|
||||||
|
text = text + text_pos_embed
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
# noised input & masked cond audio embedding
|
||||||
|
|
||||||
|
|
||||||
|
class AudioEmbedding(nn.Module):
|
||||||
|
def __init__(self, in_dim, out_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(2 * in_dim, out_dim)
|
||||||
|
self.conv_pos_embed = ConvPositionEmbedding(out_dim)
|
||||||
|
|
||||||
|
def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
|
||||||
|
if drop_audio_cond:
|
||||||
|
cond = torch.zeros_like(cond)
|
||||||
|
x = torch.cat((x, cond), dim=-1)
|
||||||
|
x = self.linear(x)
|
||||||
|
x = self.conv_pos_embed(x) + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# Transformer backbone using MM-DiT blocks
|
||||||
|
|
||||||
|
|
||||||
|
class MMDiT(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dim,
|
||||||
|
depth=8,
|
||||||
|
heads=8,
|
||||||
|
dim_head=64,
|
||||||
|
dropout=0.1,
|
||||||
|
ff_mult=4,
|
||||||
|
text_num_embeds=256,
|
||||||
|
mel_dim=100,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.time_embed = TimestepEmbedding(dim)
|
||||||
|
self.text_embed = TextEmbedding(dim, text_num_embeds)
|
||||||
|
self.audio_embed = AudioEmbedding(mel_dim, dim)
|
||||||
|
|
||||||
|
self.rotary_embed = RotaryEmbedding(dim_head)
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
|
self.transformer_blocks = nn.ModuleList(
|
||||||
|
[
|
||||||
|
MMDiTBlock(
|
||||||
|
dim=dim,
|
||||||
|
heads=heads,
|
||||||
|
dim_head=dim_head,
|
||||||
|
dropout=dropout,
|
||||||
|
ff_mult=ff_mult,
|
||||||
|
context_pre_only=i == depth - 1,
|
||||||
|
)
|
||||||
|
for i in range(depth)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
||||||
|
self.proj_out = nn.Linear(dim, mel_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: float["b n d"], # nosied input audio # noqa: F722
|
||||||
|
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||||
|
text: int["b nt"], # text # noqa: F722
|
||||||
|
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||||
|
drop_audio_cond, # cfg for cond audio
|
||||||
|
drop_text, # cfg for text
|
||||||
|
mask: bool["b n"] | None = None, # noqa: F722
|
||||||
|
):
|
||||||
|
batch = x.shape[0]
|
||||||
|
if time.ndim == 0:
|
||||||
|
time = time.repeat(batch)
|
||||||
|
|
||||||
|
# t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
|
||||||
|
t = self.time_embed(time)
|
||||||
|
c = self.text_embed(text, drop_text=drop_text)
|
||||||
|
x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
|
||||||
|
|
||||||
|
seq_len = x.shape[1]
|
||||||
|
text_len = text.shape[1]
|
||||||
|
rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||||
|
rope_text = self.rotary_embed.forward_from_seq_len(text_len)
|
||||||
|
|
||||||
|
for block in self.transformer_blocks:
|
||||||
|
c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
|
||||||
|
|
||||||
|
x = self.norm_out(x, t)
|
||||||
|
output = self.proj_out(x)
|
||||||
|
|
||||||
|
return output
|
219
GPT_SoVITS/f5_tts/model/backbones/unett.py
Normal file
219
GPT_SoVITS/f5_tts/model/backbones/unett.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
"""
|
||||||
|
ein notation:
|
||||||
|
b - batch
|
||||||
|
n - sequence
|
||||||
|
nt - text sequence
|
||||||
|
nw - raw wave length
|
||||||
|
d - dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from x_transformers import RMSNorm
|
||||||
|
from x_transformers.x_transformers import RotaryEmbedding
|
||||||
|
|
||||||
|
from f5_tts.model.modules import (
|
||||||
|
TimestepEmbedding,
|
||||||
|
ConvNeXtV2Block,
|
||||||
|
ConvPositionEmbedding,
|
||||||
|
Attention,
|
||||||
|
AttnProcessor,
|
||||||
|
FeedForward,
|
||||||
|
precompute_freqs_cis,
|
||||||
|
get_pos_embed_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Text embedding
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbedding(nn.Module):
|
||||||
|
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
||||||
|
super().__init__()
|
||||||
|
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
||||||
|
|
||||||
|
if conv_layers > 0:
|
||||||
|
self.extra_modeling = True
|
||||||
|
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
||||||
|
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
||||||
|
self.text_blocks = nn.Sequential(
|
||||||
|
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.extra_modeling = False
|
||||||
|
|
||||||
|
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
||||||
|
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||||
|
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||||
|
batch, text_len = text.shape[0], text.shape[1]
|
||||||
|
text = F.pad(text, (0, seq_len - text_len), value=0)
|
||||||
|
|
||||||
|
if drop_text: # cfg for text
|
||||||
|
text = torch.zeros_like(text)
|
||||||
|
|
||||||
|
text = self.text_embed(text) # b n -> b n d
|
||||||
|
|
||||||
|
# possible extra modeling
|
||||||
|
if self.extra_modeling:
|
||||||
|
# sinus pos emb
|
||||||
|
batch_start = torch.zeros((batch,), dtype=torch.long)
|
||||||
|
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
|
||||||
|
text_pos_embed = self.freqs_cis[pos_idx]
|
||||||
|
text = text + text_pos_embed
|
||||||
|
|
||||||
|
# convnextv2 blocks
|
||||||
|
text = self.text_blocks(text)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
# noised input audio and context mixing embedding
|
||||||
|
|
||||||
|
|
||||||
|
class InputEmbedding(nn.Module):
|
||||||
|
def __init__(self, mel_dim, text_dim, out_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
|
||||||
|
self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)
|
||||||
|
|
||||||
|
def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False): # noqa: F722
|
||||||
|
if drop_audio_cond: # cfg for cond audio
|
||||||
|
cond = torch.zeros_like(cond)
|
||||||
|
|
||||||
|
x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
|
||||||
|
x = self.conv_pos_embed(x) + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# Flat UNet Transformer backbone
|
||||||
|
|
||||||
|
|
||||||
|
class UNetT(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dim,
|
||||||
|
depth=8,
|
||||||
|
heads=8,
|
||||||
|
dim_head=64,
|
||||||
|
dropout=0.1,
|
||||||
|
ff_mult=4,
|
||||||
|
mel_dim=100,
|
||||||
|
text_num_embeds=256,
|
||||||
|
text_dim=None,
|
||||||
|
conv_layers=0,
|
||||||
|
skip_connect_type: Literal["add", "concat", "none"] = "concat",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert depth % 2 == 0, "UNet-Transformer's depth should be even."
|
||||||
|
|
||||||
|
self.time_embed = TimestepEmbedding(dim)
|
||||||
|
if text_dim is None:
|
||||||
|
text_dim = mel_dim
|
||||||
|
self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
|
||||||
|
self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
|
||||||
|
|
||||||
|
self.rotary_embed = RotaryEmbedding(dim_head)
|
||||||
|
|
||||||
|
# transformer layers & skip connections
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.skip_connect_type = skip_connect_type
|
||||||
|
needs_skip_proj = skip_connect_type == "concat"
|
||||||
|
|
||||||
|
self.depth = depth
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
|
||||||
|
for idx in range(depth):
|
||||||
|
is_later_half = idx >= (depth // 2)
|
||||||
|
|
||||||
|
attn_norm = RMSNorm(dim)
|
||||||
|
attn = Attention(
|
||||||
|
processor=AttnProcessor(),
|
||||||
|
dim=dim,
|
||||||
|
heads=heads,
|
||||||
|
dim_head=dim_head,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
ff_norm = RMSNorm(dim)
|
||||||
|
ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||||
|
|
||||||
|
skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None
|
||||||
|
|
||||||
|
self.layers.append(
|
||||||
|
nn.ModuleList(
|
||||||
|
[
|
||||||
|
skip_proj,
|
||||||
|
attn_norm,
|
||||||
|
attn,
|
||||||
|
ff_norm,
|
||||||
|
ff,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm_out = RMSNorm(dim)
|
||||||
|
self.proj_out = nn.Linear(dim, mel_dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: float["b n d"], # nosied input audio # noqa: F722
|
||||||
|
cond: float["b n d"], # masked cond audio # noqa: F722
|
||||||
|
text: int["b nt"], # text # noqa: F722
|
||||||
|
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||||
|
drop_audio_cond, # cfg for cond audio
|
||||||
|
drop_text, # cfg for text
|
||||||
|
mask: bool["b n"] | None = None, # noqa: F722
|
||||||
|
):
|
||||||
|
batch, seq_len = x.shape[0], x.shape[1]
|
||||||
|
if time.ndim == 0:
|
||||||
|
time = time.repeat(batch)
|
||||||
|
|
||||||
|
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||||
|
t = self.time_embed(time)
|
||||||
|
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
|
||||||
|
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||||
|
|
||||||
|
# postfix time t to input x, [b n d] -> [b n+1 d]
|
||||||
|
x = torch.cat([t.unsqueeze(1), x], dim=1) # pack t to x
|
||||||
|
if mask is not None:
|
||||||
|
mask = F.pad(mask, (1, 0), value=1)
|
||||||
|
|
||||||
|
rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
|
||||||
|
|
||||||
|
# flat unet transformer
|
||||||
|
skip_connect_type = self.skip_connect_type
|
||||||
|
skips = []
|
||||||
|
for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
|
||||||
|
layer = idx + 1
|
||||||
|
|
||||||
|
# skip connection logic
|
||||||
|
is_first_half = layer <= (self.depth // 2)
|
||||||
|
is_later_half = not is_first_half
|
||||||
|
|
||||||
|
if is_first_half:
|
||||||
|
skips.append(x)
|
||||||
|
|
||||||
|
if is_later_half:
|
||||||
|
skip = skips.pop()
|
||||||
|
if skip_connect_type == "concat":
|
||||||
|
x = torch.cat((x, skip), dim=-1)
|
||||||
|
x = maybe_skip_proj(x)
|
||||||
|
elif skip_connect_type == "add":
|
||||||
|
x = x + skip
|
||||||
|
|
||||||
|
# attention and feedforward blocks
|
||||||
|
x = attn(attn_norm(x), rope=rope, mask=mask) + x
|
||||||
|
x = ff(ff_norm(x)) + x
|
||||||
|
|
||||||
|
assert len(skips) == 0
|
||||||
|
|
||||||
|
x = self.norm_out(x)[:, 1:, :] # unpack t from x
|
||||||
|
|
||||||
|
return self.proj_out(x)
|
666
GPT_SoVITS/f5_tts/model/modules.py
Normal file
666
GPT_SoVITS/f5_tts/model/modules.py
Normal file
@ -0,0 +1,666 @@
|
|||||||
|
"""
|
||||||
|
ein notation:
|
||||||
|
b - batch
|
||||||
|
n - sequence
|
||||||
|
nt - text sequence
|
||||||
|
nw - raw wave length
|
||||||
|
d - dimension
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torchaudio
|
||||||
|
from librosa.filters import mel as librosa_mel_fn
|
||||||
|
from torch import nn
|
||||||
|
from x_transformers.x_transformers import apply_rotary_pos_emb
|
||||||
|
|
||||||
|
|
||||||
|
# raw wav to mel spec
|
||||||
|
|
||||||
|
|
||||||
|
mel_basis_cache = {}
|
||||||
|
hann_window_cache = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_bigvgan_mel_spectrogram(
|
||||||
|
waveform,
|
||||||
|
n_fft=1024,
|
||||||
|
n_mel_channels=100,
|
||||||
|
target_sample_rate=24000,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=1024,
|
||||||
|
fmin=0,
|
||||||
|
fmax=None,
|
||||||
|
center=False,
|
||||||
|
): # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
|
||||||
|
device = waveform.device
|
||||||
|
key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"
|
||||||
|
|
||||||
|
if key not in mel_basis_cache:
|
||||||
|
mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
|
||||||
|
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) # TODO: why they need .float()?
|
||||||
|
hann_window_cache[key] = torch.hann_window(win_length).to(device)
|
||||||
|
|
||||||
|
mel_basis = mel_basis_cache[key]
|
||||||
|
hann_window = hann_window_cache[key]
|
||||||
|
|
||||||
|
padding = (n_fft - hop_length) // 2
|
||||||
|
waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
|
||||||
|
|
||||||
|
spec = torch.stft(
|
||||||
|
waveform,
|
||||||
|
n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
win_length=win_length,
|
||||||
|
window=hann_window,
|
||||||
|
center=center,
|
||||||
|
pad_mode="reflect",
|
||||||
|
normalized=False,
|
||||||
|
onesided=True,
|
||||||
|
return_complex=True,
|
||||||
|
)
|
||||||
|
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
|
||||||
|
|
||||||
|
mel_spec = torch.matmul(mel_basis, spec)
|
||||||
|
mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
|
||||||
|
|
||||||
|
return mel_spec
|
||||||
|
|
||||||
|
|
||||||
|
def get_vocos_mel_spectrogram(
|
||||||
|
waveform,
|
||||||
|
n_fft=1024,
|
||||||
|
n_mel_channels=100,
|
||||||
|
target_sample_rate=24000,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=1024,
|
||||||
|
):
|
||||||
|
mel_stft = torchaudio.transforms.MelSpectrogram(
|
||||||
|
sample_rate=target_sample_rate,
|
||||||
|
n_fft=n_fft,
|
||||||
|
win_length=win_length,
|
||||||
|
hop_length=hop_length,
|
||||||
|
n_mels=n_mel_channels,
|
||||||
|
power=1,
|
||||||
|
center=True,
|
||||||
|
normalized=False,
|
||||||
|
norm=None,
|
||||||
|
).to(waveform.device)
|
||||||
|
if len(waveform.shape) == 3:
|
||||||
|
waveform = waveform.squeeze(1) # 'b 1 nw -> b nw'
|
||||||
|
|
||||||
|
assert len(waveform.shape) == 2
|
||||||
|
|
||||||
|
mel = mel_stft(waveform)
|
||||||
|
mel = mel.clamp(min=1e-5).log()
|
||||||
|
return mel
|
||||||
|
|
||||||
|
|
||||||
|
class MelSpec(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_fft=1024,
|
||||||
|
hop_length=256,
|
||||||
|
win_length=1024,
|
||||||
|
n_mel_channels=100,
|
||||||
|
target_sample_rate=24_000,
|
||||||
|
mel_spec_type="vocos",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert mel_spec_type in ["vocos", "bigvgan"], print("We only support two extract mel backend: vocos or bigvgan")
|
||||||
|
|
||||||
|
self.n_fft = n_fft
|
||||||
|
self.hop_length = hop_length
|
||||||
|
self.win_length = win_length
|
||||||
|
self.n_mel_channels = n_mel_channels
|
||||||
|
self.target_sample_rate = target_sample_rate
|
||||||
|
|
||||||
|
if mel_spec_type == "vocos":
|
||||||
|
self.extractor = get_vocos_mel_spectrogram
|
||||||
|
elif mel_spec_type == "bigvgan":
|
||||||
|
self.extractor = get_bigvgan_mel_spectrogram
|
||||||
|
|
||||||
|
self.register_buffer("dummy", torch.tensor(0), persistent=False)
|
||||||
|
|
||||||
|
def forward(self, wav):
|
||||||
|
if self.dummy.device != wav.device:
|
||||||
|
self.to(wav.device)
|
||||||
|
|
||||||
|
mel = self.extractor(
|
||||||
|
waveform=wav,
|
||||||
|
n_fft=self.n_fft,
|
||||||
|
n_mel_channels=self.n_mel_channels,
|
||||||
|
target_sample_rate=self.target_sample_rate,
|
||||||
|
hop_length=self.hop_length,
|
||||||
|
win_length=self.win_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
return mel
|
||||||
|
|
||||||
|
|
||||||
|
# sinusoidal position embedding
|
||||||
|
|
||||||
|
|
||||||
|
class SinusPositionEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def forward(self, x, scale=1000):
|
||||||
|
device = x.device
|
||||||
|
half_dim = self.dim // 2
|
||||||
|
emb = math.log(10000) / (half_dim - 1)
|
||||||
|
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
||||||
|
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
||||||
|
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
# convolutional position embedding
|
||||||
|
|
||||||
|
|
||||||
|
class ConvPositionEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, kernel_size=31, groups=16):
|
||||||
|
super().__init__()
|
||||||
|
assert kernel_size % 2 != 0
|
||||||
|
self.conv1d = nn.Sequential(
|
||||||
|
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
||||||
|
nn.Mish(),
|
||||||
|
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
||||||
|
nn.Mish(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask[..., None]
|
||||||
|
x = x.masked_fill(~mask, 0.0)
|
||||||
|
|
||||||
|
x = x.permute(0, 2, 1)
|
||||||
|
x = self.conv1d(x)
|
||||||
|
out = x.permute(0, 2, 1)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
out = out.masked_fill(~mask, 0.0)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# rotary positional embedding related
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
|
||||||
|
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||||
|
# has some connection to NTK literature
|
||||||
|
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||||
|
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
|
||||||
|
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
||||||
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||||
|
t = torch.arange(end, device=freqs.device) # type: ignore
|
||||||
|
freqs = torch.outer(t, freqs).float() # type: ignore
|
||||||
|
freqs_cos = torch.cos(freqs) # real part
|
||||||
|
freqs_sin = torch.sin(freqs) # imaginary part
|
||||||
|
return torch.cat([freqs_cos, freqs_sin], dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
|
||||||
|
# length = length if isinstance(length, int) else length.max()
|
||||||
|
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
|
||||||
|
pos = (
|
||||||
|
start.unsqueeze(1)
|
||||||
|
+ (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
|
||||||
|
)
|
||||||
|
# avoid extra long error.
|
||||||
|
pos = torch.where(pos < max_pos, pos, max_pos - 1)
|
||||||
|
return pos
|
||||||
|
|
||||||
|
|
||||||
|
# Global Response Normalization layer (Instance Normalization ?)
|
||||||
|
|
||||||
|
|
||||||
|
class GRN(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
|
||||||
|
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
|
||||||
|
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||||
|
return self.gamma * (x * Nx) + self.beta + x
|
||||||
|
|
||||||
|
|
||||||
|
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
|
||||||
|
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
|
||||||
|
|
||||||
|
|
||||||
|
class ConvNeXtV2Block(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
intermediate_dim: int,
|
||||||
|
dilation: int = 1,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
padding = (dilation * (7 - 1)) // 2
|
||||||
|
self.dwconv = nn.Conv1d(
|
||||||
|
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
|
||||||
|
) # depthwise conv
|
||||||
|
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||||
|
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.grn = GRN(intermediate_dim)
|
||||||
|
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
residual = x
|
||||||
|
x = x.transpose(1, 2) # b n d -> b d n
|
||||||
|
x = self.dwconv(x)
|
||||||
|
x = x.transpose(1, 2) # b d n -> b n d
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.pwconv1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.grn(x)
|
||||||
|
x = self.pwconv2(x)
|
||||||
|
return residual + x
|
||||||
|
|
||||||
|
|
||||||
|
# AdaLayerNormZero
|
||||||
|
# return with modulated x for attn input, and params for later mlp modulation
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNormZero(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
self.linear = nn.Linear(dim, dim * 6)
|
||||||
|
|
||||||
|
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
|
||||||
|
def forward(self, x, emb=None):
|
||||||
|
emb = self.linear(self.silu(emb))
|
||||||
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
|
||||||
|
|
||||||
|
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||||
|
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||||
|
|
||||||
|
|
||||||
|
# AdaLayerNormZero for final layer
|
||||||
|
# return only with modulated x for attn input, cuz no more mlp modulation
|
||||||
|
|
||||||
|
|
||||||
|
class AdaLayerNormZero_Final(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.silu = nn.SiLU()
|
||||||
|
self.linear = nn.Linear(dim, dim * 2)
|
||||||
|
|
||||||
|
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
|
||||||
|
def forward(self, x, emb):
|
||||||
|
emb = self.linear(self.silu(emb))
|
||||||
|
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||||
|
|
||||||
|
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# FeedForward
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = int(dim * mult)
|
||||||
|
dim_out = dim_out if dim_out is not None else dim
|
||||||
|
|
||||||
|
activation = nn.GELU(approximate=approximate)
|
||||||
|
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
|
||||||
|
self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.ff(x)
|
||||||
|
|
||||||
|
|
||||||
|
# Attention with possible joint part
|
||||||
|
# modified from diffusers/src/diffusers/models/attention_processor.py
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
processor: JointAttnProcessor | AttnProcessor,
|
||||||
|
dim: int,
|
||||||
|
heads: int = 8,
|
||||||
|
dim_head: int = 64,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
context_dim: Optional[int] = None, # if not None -> joint attention
|
||||||
|
context_pre_only=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if not hasattr(F, "scaled_dot_product_attention"):
|
||||||
|
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||||
|
|
||||||
|
self.processor = processor
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.heads = heads
|
||||||
|
self.inner_dim = dim_head * heads
|
||||||
|
self.dropout = dropout
|
||||||
|
|
||||||
|
self.context_dim = context_dim
|
||||||
|
self.context_pre_only = context_pre_only
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(dim, self.inner_dim)
|
||||||
|
self.to_k = nn.Linear(dim, self.inner_dim)
|
||||||
|
self.to_v = nn.Linear(dim, self.inner_dim)
|
||||||
|
|
||||||
|
if self.context_dim is not None:
|
||||||
|
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
|
||||||
|
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
|
||||||
|
if self.context_pre_only is not None:
|
||||||
|
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
|
||||||
|
|
||||||
|
self.to_out = nn.ModuleList([])
|
||||||
|
self.to_out.append(nn.Linear(self.inner_dim, dim))
|
||||||
|
self.to_out.append(nn.Dropout(dropout))
|
||||||
|
|
||||||
|
if self.context_pre_only is not None and not self.context_pre_only:
|
||||||
|
self.to_out_c = nn.Linear(self.inner_dim, dim)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: float["b n d"], # noised input x # noqa: F722
|
||||||
|
c: float["b n d"] = None, # context c # noqa: F722
|
||||||
|
mask: bool["b n"] | None = None, # noqa: F722
|
||||||
|
rope=None, # rotary position embedding for x
|
||||||
|
c_rope=None, # rotary position embedding for c
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if c is not None:
|
||||||
|
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
|
||||||
|
else:
|
||||||
|
return self.processor(self, x, mask=mask, rope=rope)
|
||||||
|
|
||||||
|
|
||||||
|
# Attention processor
|
||||||
|
|
||||||
|
|
||||||
|
# from torch.nn.attention import SDPBackend
|
||||||
|
# torch.backends.cuda.enable_flash_sdp(True)
|
||||||
|
class AttnProcessor:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn: Attention,
|
||||||
|
x: float["b n d"], # noised input x # noqa: F722
|
||||||
|
mask: bool["b n"] | None = None, # noqa: F722
|
||||||
|
rope=None, # rotary position embedding
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
|
||||||
|
# `sample` projections.
|
||||||
|
query = attn.to_q(x)
|
||||||
|
key = attn.to_k(x)
|
||||||
|
value = attn.to_v(x)
|
||||||
|
|
||||||
|
# apply rotary position embedding
|
||||||
|
if rope is not None:
|
||||||
|
freqs, xpos_scale = rope
|
||||||
|
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
||||||
|
|
||||||
|
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||||
|
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||||
|
if mask is not None:
|
||||||
|
attn_mask = mask
|
||||||
|
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||||
|
# print(3433333333,attn_mask.shape)
|
||||||
|
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
||||||
|
else:
|
||||||
|
attn_mask = None
|
||||||
|
# with torch.nn.attention.sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
|
||||||
|
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True):
|
||||||
|
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
||||||
|
# print(torch.backends.cuda.flash_sdp_enabled())
|
||||||
|
# print(torch.backends.cuda.mem_efficient_sdp_enabled())
|
||||||
|
# print(torch.backends.cuda.math_sdp_enabled())
|
||||||
|
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||||
|
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
x = x.to(query.dtype)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
x = attn.to_out[0](x)
|
||||||
|
# dropout
|
||||||
|
x = attn.to_out[1](x)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
x = x.masked_fill(~mask, 0.0)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# Joint Attention processor for MM-DiT
|
||||||
|
# modified from diffusers/src/diffusers/models/attention_processor.py
|
||||||
|
|
||||||
|
|
||||||
|
class JointAttnProcessor:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn: Attention,
|
||||||
|
x: float["b n d"], # noised input x # noqa: F722
|
||||||
|
c: float["b nt d"] = None, # context c, here text # noqa: F722
|
||||||
|
mask: bool["b n"] | None = None, # noqa: F722
|
||||||
|
rope=None, # rotary position embedding for x
|
||||||
|
c_rope=None, # rotary position embedding for c
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
residual = x
|
||||||
|
|
||||||
|
batch_size = c.shape[0]
|
||||||
|
|
||||||
|
# `sample` projections.
|
||||||
|
query = attn.to_q(x)
|
||||||
|
key = attn.to_k(x)
|
||||||
|
value = attn.to_v(x)
|
||||||
|
|
||||||
|
# `context` projections.
|
||||||
|
c_query = attn.to_q_c(c)
|
||||||
|
c_key = attn.to_k_c(c)
|
||||||
|
c_value = attn.to_v_c(c)
|
||||||
|
|
||||||
|
# apply rope for context and noised input independently
|
||||||
|
if rope is not None:
|
||||||
|
freqs, xpos_scale = rope
|
||||||
|
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
||||||
|
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||||
|
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||||
|
if c_rope is not None:
|
||||||
|
freqs, xpos_scale = c_rope
|
||||||
|
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
||||||
|
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
||||||
|
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
query = torch.cat([query, c_query], dim=1)
|
||||||
|
key = torch.cat([key, c_key], dim=1)
|
||||||
|
value = torch.cat([value, c_value], dim=1)
|
||||||
|
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||||
|
if mask is not None:
|
||||||
|
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
|
||||||
|
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||||
|
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
||||||
|
else:
|
||||||
|
attn_mask = None
|
||||||
|
|
||||||
|
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||||
|
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
x = x.to(query.dtype)
|
||||||
|
|
||||||
|
# Split the attention outputs.
|
||||||
|
x, c = (
|
||||||
|
x[:, : residual.shape[1]],
|
||||||
|
x[:, residual.shape[1] :],
|
||||||
|
)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
x = attn.to_out[0](x)
|
||||||
|
# dropout
|
||||||
|
x = attn.to_out[1](x)
|
||||||
|
if not attn.context_pre_only:
|
||||||
|
c = attn.to_out_c(c)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
x = x.masked_fill(~mask, 0.0)
|
||||||
|
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
|
||||||
|
|
||||||
|
return x, c
|
||||||
|
|
||||||
|
|
||||||
|
# DiT Block
|
||||||
|
|
||||||
|
|
||||||
|
class DiTBlock(nn.Module):
|
||||||
|
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attn_norm = AdaLayerNormZero(dim)
|
||||||
|
self.attn = Attention(
|
||||||
|
processor=AttnProcessor(),
|
||||||
|
dim=dim,
|
||||||
|
heads=heads,
|
||||||
|
dim_head=dim_head,
|
||||||
|
dropout=dropout,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||||
|
|
||||||
|
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
|
||||||
|
# pre-norm & modulation for attention input
|
||||||
|
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
attn_output = self.attn(x=norm, mask=mask, rope=rope)
|
||||||
|
|
||||||
|
# process attention output for input x
|
||||||
|
x = x + gate_msa.unsqueeze(1) * attn_output
|
||||||
|
|
||||||
|
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||||
|
ff_output = self.ff(norm)
|
||||||
|
x = x + gate_mlp.unsqueeze(1) * ff_output
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# MMDiT Block https://arxiv.org/abs/2403.03206
|
||||||
|
|
||||||
|
|
||||||
|
class MMDiTBlock(nn.Module):
|
||||||
|
r"""
|
||||||
|
modified from diffusers/src/diffusers/models/attention.py
|
||||||
|
|
||||||
|
notes.
|
||||||
|
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
|
||||||
|
_x: noised input related. (right part)
|
||||||
|
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.context_pre_only = context_pre_only
|
||||||
|
|
||||||
|
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
|
||||||
|
self.attn_norm_x = AdaLayerNormZero(dim)
|
||||||
|
self.attn = Attention(
|
||||||
|
processor=JointAttnProcessor(),
|
||||||
|
dim=dim,
|
||||||
|
heads=heads,
|
||||||
|
dim_head=dim_head,
|
||||||
|
dropout=dropout,
|
||||||
|
context_dim=dim,
|
||||||
|
context_pre_only=context_pre_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not context_pre_only:
|
||||||
|
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||||
|
else:
|
||||||
|
self.ff_norm_c = None
|
||||||
|
self.ff_c = None
|
||||||
|
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||||
|
self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||||
|
|
||||||
|
def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
|
||||||
|
# pre-norm & modulation for attention input
|
||||||
|
if self.context_pre_only:
|
||||||
|
norm_c = self.attn_norm_c(c, t)
|
||||||
|
else:
|
||||||
|
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
|
||||||
|
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
|
||||||
|
|
||||||
|
# process attention output for context c
|
||||||
|
if self.context_pre_only:
|
||||||
|
c = None
|
||||||
|
else: # if not last layer
|
||||||
|
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
|
||||||
|
|
||||||
|
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||||
|
c_ff_output = self.ff_c(norm_c)
|
||||||
|
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
|
||||||
|
|
||||||
|
# process attention output for input x
|
||||||
|
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
|
||||||
|
|
||||||
|
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
|
||||||
|
x_ff_output = self.ff_x(norm_x)
|
||||||
|
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
|
||||||
|
|
||||||
|
return c, x
|
||||||
|
|
||||||
|
|
||||||
|
# time step conditioning embedding
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim, freq_embed_dim=256):
|
||||||
|
super().__init__()
|
||||||
|
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
||||||
|
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||||
|
|
||||||
|
def forward(self, timestep: float["b"]): # noqa: F821
|
||||||
|
time_hidden = self.time_embed(timestep)
|
||||||
|
time_hidden = time_hidden.to(timestep.dtype)
|
||||||
|
time = self.time_mlp(time_hidden) # b d
|
||||||
|
return time
|
@ -1,6 +1,3 @@
|
|||||||
from . import cnhubert, whisper_enc
|
from . import cnhubert, whisper_enc
|
||||||
|
|
||||||
content_module_map = {
|
content_module_map = {"cnhubert": cnhubert, "whisper": whisper_enc}
|
||||||
'cnhubert': cnhubert,
|
|
||||||
'whisper': whisper_enc
|
|
||||||
}
|
|
||||||
|
@ -1,14 +1,11 @@
|
|||||||
import time
|
|
||||||
|
|
||||||
import librosa
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import soundfile as sf
|
|
||||||
import os
|
import os
|
||||||
from transformers import logging as tf_logging
|
from transformers import logging as tf_logging
|
||||||
|
|
||||||
tf_logging.set_verbosity_error()
|
tf_logging.set_verbosity_error()
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@ -23,21 +20,19 @@ cnhubert_base_path = None
|
|||||||
|
|
||||||
|
|
||||||
class CNHubert(nn.Module):
|
class CNHubert(nn.Module):
|
||||||
def __init__(self, base_path:str=None):
|
def __init__(self, base_path: str = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if base_path is None:
|
if base_path is None:
|
||||||
base_path = cnhubert_base_path
|
base_path = cnhubert_base_path
|
||||||
if os.path.exists(base_path):...
|
if os.path.exists(base_path):
|
||||||
else:raise FileNotFoundError(base_path)
|
...
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(base_path)
|
||||||
self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
|
self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
|
||||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True)
|
||||||
base_path, local_files_only=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
input_values = self.feature_extractor(
|
input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
|
||||||
x, return_tensors="pt", sampling_rate=16000
|
|
||||||
).input_values.to(x.device)
|
|
||||||
feats = self.model(input_values)["last_hidden_state"]
|
feats = self.model(input_values)["last_hidden_state"]
|
||||||
return feats
|
return feats
|
||||||
|
|
||||||
|
@ -19,7 +19,5 @@ def get_content(model=None, wav_16k_tensor=None):
|
|||||||
feature_len = mel.shape[-1] // 2
|
feature_len = mel.shape[-1] // 2
|
||||||
assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频"
|
assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频"
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
|
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1, 2)
|
||||||
:1, :feature_len, :
|
|
||||||
].transpose(1, 2)
|
|
||||||
return feature
|
return feature
|
||||||
|
@ -7,13 +7,23 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights
|
|||||||
|
|
||||||
i18n = I18nAuto()
|
i18n = I18nAuto()
|
||||||
|
|
||||||
def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path, ref_language, target_text_path, target_language, output_path):
|
|
||||||
|
def synthesize(
|
||||||
|
GPT_model_path,
|
||||||
|
SoVITS_model_path,
|
||||||
|
ref_audio_path,
|
||||||
|
ref_text_path,
|
||||||
|
ref_language,
|
||||||
|
target_text_path,
|
||||||
|
target_language,
|
||||||
|
output_path,
|
||||||
|
):
|
||||||
# Read reference text
|
# Read reference text
|
||||||
with open(ref_text_path, 'r', encoding='utf-8') as file:
|
with open(ref_text_path, "r", encoding="utf-8") as file:
|
||||||
ref_text = file.read()
|
ref_text = file.read()
|
||||||
|
|
||||||
# Read target text
|
# Read target text
|
||||||
with open(target_text_path, 'r', encoding='utf-8') as file:
|
with open(target_text_path, "r", encoding="utf-8") as file:
|
||||||
target_text = file.read()
|
target_text = file.read()
|
||||||
|
|
||||||
# Change model weights
|
# Change model weights
|
||||||
@ -21,12 +31,16 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
|
|||||||
change_sovits_weights(sovits_path=SoVITS_model_path)
|
change_sovits_weights(sovits_path=SoVITS_model_path)
|
||||||
|
|
||||||
# Synthesize audio
|
# Synthesize audio
|
||||||
synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
|
synthesis_result = get_tts_wav(
|
||||||
prompt_text=ref_text,
|
ref_wav_path=ref_audio_path,
|
||||||
prompt_language=i18n(ref_language),
|
prompt_text=ref_text,
|
||||||
text=target_text,
|
prompt_language=i18n(ref_language),
|
||||||
text_language=i18n(target_language), top_p=1, temperature=1)
|
text=target_text,
|
||||||
|
text_language=i18n(target_language),
|
||||||
|
top_p=1,
|
||||||
|
temperature=1,
|
||||||
|
)
|
||||||
|
|
||||||
result_list = list(synthesis_result)
|
result_list = list(synthesis_result)
|
||||||
|
|
||||||
if result_list:
|
if result_list:
|
||||||
@ -35,21 +49,38 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
|
|||||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||||
print(f"Audio saved to {output_wav_path}")
|
print(f"Audio saved to {output_wav_path}")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||||
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
|
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||||
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
|
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
||||||
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
|
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
||||||
parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
|
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
||||||
parser.add_argument('--ref_language', required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio")
|
parser.add_argument(
|
||||||
parser.add_argument('--target_text', required=True, help="Path to the target text file")
|
"--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio"
|
||||||
parser.add_argument('--target_language', required=True, choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"], help="Language of the target text")
|
)
|
||||||
parser.add_argument('--output_path', required=True, help="Path to the output directory")
|
parser.add_argument("--target_text", required=True, help="Path to the target text file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--target_language",
|
||||||
|
required=True,
|
||||||
|
choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"],
|
||||||
|
help="Language of the target text",
|
||||||
|
)
|
||||||
|
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
synthesize(args.gpt_model, args.sovits_model, args.ref_audio, args.ref_text, args.ref_language, args.target_text, args.target_language, args.output_path)
|
synthesize(
|
||||||
|
args.gpt_model,
|
||||||
|
args.sovits_model,
|
||||||
|
args.ref_audio,
|
||||||
|
args.ref_text,
|
||||||
|
args.ref_language,
|
||||||
|
args.target_text,
|
||||||
|
args.target_language,
|
||||||
|
args.output_path,
|
||||||
|
)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QSta
|
|||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
from tools.i18n.i18n import I18nAuto
|
from tools.i18n.i18n import I18nAuto
|
||||||
|
|
||||||
i18n = I18nAuto()
|
i18n = I18nAuto()
|
||||||
|
|
||||||
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
|
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||||
@ -18,7 +19,7 @@ class GPTSoVITSGUI(QMainWindow):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.setWindowTitle('GPT-SoVITS GUI')
|
self.setWindowTitle("GPT-SoVITS GUI")
|
||||||
self.setGeometry(800, 450, 950, 850)
|
self.setGeometry(800, 450, 950, 850)
|
||||||
|
|
||||||
self.setStyleSheet("""
|
self.setStyleSheet("""
|
||||||
@ -61,11 +62,12 @@ class GPTSoVITSGUI(QMainWindow):
|
|||||||
border: 1px solid #45a049;
|
border: 1px solid #45a049;
|
||||||
box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1);
|
box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1);
|
||||||
}
|
}
|
||||||
""")
|
""")
|
||||||
|
|
||||||
license_text = (
|
license_text = (
|
||||||
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
|
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
|
||||||
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
|
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
|
||||||
|
)
|
||||||
license_label = QLabel(license_text)
|
license_label = QLabel(license_text)
|
||||||
license_label.setWordWrap(True)
|
license_label.setWordWrap(True)
|
||||||
|
|
||||||
@ -124,14 +126,16 @@ class GPTSoVITSGUI(QMainWindow):
|
|||||||
self.output_text = QTextEdit()
|
self.output_text = QTextEdit()
|
||||||
self.output_text.setReadOnly(True)
|
self.output_text.setReadOnly(True)
|
||||||
|
|
||||||
self.add_drag_drop_events([
|
self.add_drag_drop_events(
|
||||||
self.GPT_model_input,
|
[
|
||||||
self.SoVITS_model_input,
|
self.GPT_model_input,
|
||||||
self.ref_audio_input,
|
self.SoVITS_model_input,
|
||||||
self.ref_text_input,
|
self.ref_audio_input,
|
||||||
self.target_text_input,
|
self.ref_text_input,
|
||||||
self.output_input,
|
self.target_text_input,
|
||||||
])
|
self.output_input,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
self.synthesize_button = QPushButton("合成")
|
self.synthesize_button = QPushButton("合成")
|
||||||
self.synthesize_button.clicked.connect(self.synthesize)
|
self.synthesize_button.clicked.connect(self.synthesize)
|
||||||
@ -235,14 +239,14 @@ class GPTSoVITSGUI(QMainWindow):
|
|||||||
def upload_ref_text(self):
|
def upload_ref_text(self):
|
||||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
|
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
|
||||||
if file_path:
|
if file_path:
|
||||||
with open(file_path, 'r', encoding='utf-8') as file:
|
with open(file_path, "r", encoding="utf-8") as file:
|
||||||
content = file.read()
|
content = file.read()
|
||||||
self.ref_text_input.setText(content)
|
self.ref_text_input.setText(content)
|
||||||
|
|
||||||
def upload_target_text(self):
|
def upload_target_text(self):
|
||||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
|
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
|
||||||
if file_path:
|
if file_path:
|
||||||
with open(file_path, 'r', encoding='utf-8') as file:
|
with open(file_path, "r", encoding="utf-8") as file:
|
||||||
content = file.read()
|
content = file.read()
|
||||||
self.target_text_input.setText(content)
|
self.target_text_input.setText(content)
|
||||||
|
|
||||||
@ -284,17 +288,19 @@ class GPTSoVITSGUI(QMainWindow):
|
|||||||
change_sovits_weights(sovits_path=SoVITS_model_path)
|
change_sovits_weights(sovits_path=SoVITS_model_path)
|
||||||
self.SoVITS_Path = SoVITS_model_path
|
self.SoVITS_Path = SoVITS_model_path
|
||||||
|
|
||||||
synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
|
synthesis_result = get_tts_wav(
|
||||||
prompt_text=ref_text,
|
ref_wav_path=ref_audio_path,
|
||||||
prompt_language=language_combobox,
|
prompt_text=ref_text,
|
||||||
text=target_text,
|
prompt_language=language_combobox,
|
||||||
text_language=target_language_combobox)
|
text=target_text,
|
||||||
|
text_language=target_language_combobox,
|
||||||
|
)
|
||||||
|
|
||||||
result_list = list(synthesis_result)
|
result_list = list(synthesis_result)
|
||||||
|
|
||||||
if result_list:
|
if result_list:
|
||||||
last_sampling_rate, last_audio_data = result_list[-1]
|
last_sampling_rate, last_audio_data = result_list[-1]
|
||||||
output_wav_path = os.path.join(output_path, "output.wav")
|
output_wav_path = os.path.join(output_path, "output.wav")
|
||||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||||
|
|
||||||
result = "Audio saved to " + output_wav_path
|
result = "Audio saved to " + output_wav_path
|
||||||
@ -303,8 +309,8 @@ class GPTSoVITSGUI(QMainWindow):
|
|||||||
self.output_text.append("处理结果:\n" + result)
|
self.output_text.append("处理结果:\n" + result)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
app = QApplication(sys.argv)
|
app = QApplication(sys.argv)
|
||||||
mainWin = GPTSoVITSGUI()
|
mainWin = GPTSoVITSGUI()
|
||||||
mainWin.show()
|
mainWin.show()
|
||||||
sys.exit(app.exec_())
|
sys.exit(app.exec_())
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,14 +1,19 @@
|
|||||||
'''
|
"""
|
||||||
按中英混合识别
|
按中英混合识别
|
||||||
按日英混合识别
|
按日英混合识别
|
||||||
多语种启动切分识别语种
|
多语种启动切分识别语种
|
||||||
全部按中文识别
|
全部按中文识别
|
||||||
全部按英文识别
|
全部按英文识别
|
||||||
全部按日文识别
|
全部按日文识别
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
import os, re, logging
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
now_dir = os.getcwd()
|
now_dir = os.getcwd()
|
||||||
sys.path.append(now_dir)
|
sys.path.append(now_dir)
|
||||||
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
||||||
@ -20,9 +25,15 @@ logging.getLogger("httpx").setLevel(logging.ERROR)
|
|||||||
logging.getLogger("asyncio").setLevel(logging.ERROR)
|
logging.getLogger("asyncio").setLevel(logging.ERROR)
|
||||||
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
|
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
|
||||||
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
||||||
import pdb
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
try:
|
||||||
|
import gradio.analytics as analytics
|
||||||
|
|
||||||
|
analytics.version_check = lambda: None
|
||||||
|
except:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
|
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
|
||||||
infer_ttswebui = int(infer_ttswebui)
|
infer_ttswebui = int(infer_ttswebui)
|
||||||
@ -36,15 +47,16 @@ gpt_path = os.environ.get("gpt_path", None)
|
|||||||
sovits_path = os.environ.get("sovits_path", None)
|
sovits_path = os.environ.get("sovits_path", None)
|
||||||
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
|
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
|
||||||
bert_path = os.environ.get("bert_path", None)
|
bert_path = os.environ.get("bert_path", None)
|
||||||
version=os.environ.get("version","v2")
|
version = model_version = os.environ.get("version", "v2")
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from TTS_infer_pack.TTS import TTS, TTS_Config
|
|
||||||
from TTS_infer_pack.text_segmentation_method import get_method
|
from TTS_infer_pack.text_segmentation_method import get_method
|
||||||
|
from TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config
|
||||||
|
|
||||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||||
|
|
||||||
language=os.environ.get("language","Auto")
|
language = os.environ.get("language", "Auto")
|
||||||
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||||
i18n = I18nAuto(language=language)
|
i18n = I18nAuto(language=language)
|
||||||
|
|
||||||
|
|
||||||
@ -56,32 +68,35 @@ if torch.cuda.is_available():
|
|||||||
# device = "mps"
|
# device = "mps"
|
||||||
else:
|
else:
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
|
|
||||||
|
# is_half = False
|
||||||
|
# device = "cpu"
|
||||||
|
|
||||||
dict_language_v1 = {
|
dict_language_v1 = {
|
||||||
i18n("中文"): "all_zh",#全部按中文识别
|
i18n("中文"): "all_zh", # 全部按中文识别
|
||||||
i18n("英文"): "en",#全部按英文识别#######不变
|
i18n("英文"): "en", # 全部按英文识别#######不变
|
||||||
i18n("日文"): "all_ja",#全部按日文识别
|
i18n("日文"): "all_ja", # 全部按日文识别
|
||||||
i18n("中英混合"): "zh",#按中英混合识别####不变
|
i18n("中英混合"): "zh", # 按中英混合识别####不变
|
||||||
i18n("日英混合"): "ja",#按日英混合识别####不变
|
i18n("日英混合"): "ja", # 按日英混合识别####不变
|
||||||
i18n("多语种混合"): "auto",#多语种启动切分识别语种
|
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
||||||
}
|
}
|
||||||
dict_language_v2 = {
|
dict_language_v2 = {
|
||||||
i18n("中文"): "all_zh",#全部按中文识别
|
i18n("中文"): "all_zh", # 全部按中文识别
|
||||||
i18n("英文"): "en",#全部按英文识别#######不变
|
i18n("英文"): "en", # 全部按英文识别#######不变
|
||||||
i18n("日文"): "all_ja",#全部按日文识别
|
i18n("日文"): "all_ja", # 全部按日文识别
|
||||||
i18n("粤语"): "all_yue",#全部按中文识别
|
i18n("粤语"): "all_yue", # 全部按中文识别
|
||||||
i18n("韩文"): "all_ko",#全部按韩文识别
|
i18n("韩文"): "all_ko", # 全部按韩文识别
|
||||||
i18n("中英混合"): "zh",#按中英混合识别####不变
|
i18n("中英混合"): "zh", # 按中英混合识别####不变
|
||||||
i18n("日英混合"): "ja",#按日英混合识别####不变
|
i18n("日英混合"): "ja", # 按日英混合识别####不变
|
||||||
i18n("粤英混合"): "yue",#按粤英混合识别####不变
|
i18n("粤英混合"): "yue", # 按粤英混合识别####不变
|
||||||
i18n("韩英混合"): "ko",#按韩英混合识别####不变
|
i18n("韩英混合"): "ko", # 按韩英混合识别####不变
|
||||||
i18n("多语种混合"): "auto",#多语种启动切分识别语种
|
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
||||||
i18n("多语种混合(粤语)"): "auto_yue",#多语种启动切分识别语种
|
i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
|
||||||
}
|
}
|
||||||
dict_language = dict_language_v1 if version =='v1' else dict_language_v2
|
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
|
||||||
|
|
||||||
cut_method = {
|
cut_method = {
|
||||||
i18n("不切"):"cut0",
|
i18n("不切"): "cut0",
|
||||||
i18n("凑四句一切"): "cut1",
|
i18n("凑四句一切"): "cut1",
|
||||||
i18n("凑50字一切"): "cut2",
|
i18n("凑50字一切"): "cut2",
|
||||||
i18n("按中文句号。切"): "cut3",
|
i18n("按中文句号。切"): "cut3",
|
||||||
@ -101,29 +116,40 @@ if cnhubert_base_path is not None:
|
|||||||
tts_config.cnhuhbert_base_path = cnhubert_base_path
|
tts_config.cnhuhbert_base_path = cnhubert_base_path
|
||||||
if bert_path is not None:
|
if bert_path is not None:
|
||||||
tts_config.bert_base_path = bert_path
|
tts_config.bert_base_path = bert_path
|
||||||
|
|
||||||
print(tts_config)
|
print(tts_config)
|
||||||
tts_pipeline = TTS(tts_config)
|
tts_pipeline = TTS(tts_config)
|
||||||
gpt_path = tts_config.t2s_weights_path
|
gpt_path = tts_config.t2s_weights_path
|
||||||
sovits_path = tts_config.vits_weights_path
|
sovits_path = tts_config.vits_weights_path
|
||||||
version = tts_config.version
|
version = tts_config.version
|
||||||
|
|
||||||
def inference(text, text_lang,
|
|
||||||
ref_audio_path,
|
|
||||||
aux_ref_audio_paths,
|
|
||||||
prompt_text,
|
|
||||||
prompt_lang, top_k,
|
|
||||||
top_p, temperature,
|
|
||||||
text_split_method, batch_size,
|
|
||||||
speed_factor, ref_text_free,
|
|
||||||
split_bucket,fragment_interval,
|
|
||||||
seed, keep_random, parallel_infer,
|
|
||||||
repetition_penalty
|
|
||||||
):
|
|
||||||
|
|
||||||
|
def inference(
|
||||||
|
text,
|
||||||
|
text_lang,
|
||||||
|
ref_audio_path,
|
||||||
|
aux_ref_audio_paths,
|
||||||
|
prompt_text,
|
||||||
|
prompt_lang,
|
||||||
|
top_k,
|
||||||
|
top_p,
|
||||||
|
temperature,
|
||||||
|
text_split_method,
|
||||||
|
batch_size,
|
||||||
|
speed_factor,
|
||||||
|
ref_text_free,
|
||||||
|
split_bucket,
|
||||||
|
fragment_interval,
|
||||||
|
seed,
|
||||||
|
keep_random,
|
||||||
|
parallel_infer,
|
||||||
|
repetition_penalty,
|
||||||
|
sample_steps,
|
||||||
|
super_sampling,
|
||||||
|
):
|
||||||
seed = -1 if keep_random else seed
|
seed = -1 if keep_random else seed
|
||||||
actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
|
actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
|
||||||
inputs={
|
inputs = {
|
||||||
"text": text,
|
"text": text,
|
||||||
"text_lang": dict_language[text_lang],
|
"text_lang": dict_language[text_lang],
|
||||||
"ref_audio_path": ref_audio_path,
|
"ref_audio_path": ref_audio_path,
|
||||||
@ -134,21 +160,27 @@ def inference(text, text_lang,
|
|||||||
"top_p": top_p,
|
"top_p": top_p,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"text_split_method": cut_method[text_split_method],
|
"text_split_method": cut_method[text_split_method],
|
||||||
"batch_size":int(batch_size),
|
"batch_size": int(batch_size),
|
||||||
"speed_factor":float(speed_factor),
|
"speed_factor": float(speed_factor),
|
||||||
"split_bucket":split_bucket,
|
"split_bucket": split_bucket,
|
||||||
"return_fragment":False,
|
"return_fragment": False,
|
||||||
"fragment_interval":fragment_interval,
|
"fragment_interval": fragment_interval,
|
||||||
"seed":actual_seed,
|
"seed": actual_seed,
|
||||||
"parallel_infer": parallel_infer,
|
"parallel_infer": parallel_infer,
|
||||||
"repetition_penalty": repetition_penalty,
|
"repetition_penalty": repetition_penalty,
|
||||||
|
"sample_steps": int(sample_steps),
|
||||||
|
"super_sampling": super_sampling,
|
||||||
}
|
}
|
||||||
for item in tts_pipeline.run(inputs):
|
try:
|
||||||
yield item, actual_seed
|
for item in tts_pipeline.run(inputs):
|
||||||
|
yield item, actual_seed
|
||||||
|
except NO_PROMPT_ERROR:
|
||||||
|
gr.Warning(i18n("V3不支持无参考文本模式,请填写参考文本!"))
|
||||||
|
|
||||||
|
|
||||||
def custom_sort_key(s):
|
def custom_sort_key(s):
|
||||||
# 使用正则表达式提取字符串中的数字部分和非数字部分
|
# 使用正则表达式提取字符串中的数字部分和非数字部分
|
||||||
parts = re.split('(\d+)', s)
|
parts = re.split("(\d+)", s)
|
||||||
# 将数字部分转换为整数,非数字部分保持不变
|
# 将数字部分转换为整数,非数字部分保持不变
|
||||||
parts = [int(part) if part.isdigit() else part for part in parts]
|
parts = [int(part) if part.isdigit() else part for part in parts]
|
||||||
return parts
|
return parts
|
||||||
@ -156,89 +188,200 @@ def custom_sort_key(s):
|
|||||||
|
|
||||||
def change_choices():
|
def change_choices():
|
||||||
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
|
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
|
||||||
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
|
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {
|
||||||
|
"choices": sorted(GPT_names, key=custom_sort_key),
|
||||||
|
"__type__": "update",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
|
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
|
||||||
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
|
path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
|
||||||
_ =[[],[]]
|
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
||||||
for i in range(2):
|
is_exist_s2gv4 = os.path.exists(path_sovits_v4)
|
||||||
|
pretrained_sovits_name = [
|
||||||
|
"GPT_SoVITS/pretrained_models/s2G488k.pth",
|
||||||
|
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
|
||||||
|
"GPT_SoVITS/pretrained_models/s2Gv3.pth",
|
||||||
|
"GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
|
||||||
|
]
|
||||||
|
pretrained_gpt_name = [
|
||||||
|
"GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
|
||||||
|
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
|
||||||
|
"GPT_SoVITS/pretrained_models/s1v3.ckpt",
|
||||||
|
"GPT_SoVITS/pretrained_models/s1v3.ckpt",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
_ = [[], []]
|
||||||
|
for i in range(4):
|
||||||
if os.path.exists(pretrained_gpt_name[i]):
|
if os.path.exists(pretrained_gpt_name[i]):
|
||||||
_[0].append(pretrained_gpt_name[i])
|
_[0].append(pretrained_gpt_name[i])
|
||||||
if os.path.exists(pretrained_sovits_name[i]):
|
if os.path.exists(pretrained_sovits_name[i]):
|
||||||
_[-1].append(pretrained_sovits_name[i])
|
_[-1].append(pretrained_sovits_name[i])
|
||||||
pretrained_gpt_name,pretrained_sovits_name = _
|
pretrained_gpt_name, pretrained_sovits_name = _
|
||||||
|
|
||||||
|
if os.path.exists("./weight.json"):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
with open("./weight.json", "w", encoding="utf-8") as file:
|
||||||
|
json.dump({"GPT": {}, "SoVITS": {}}, file)
|
||||||
|
|
||||||
|
with open("./weight.json", "r", encoding="utf-8") as file:
|
||||||
|
weight_data = file.read()
|
||||||
|
weight_data = json.loads(weight_data)
|
||||||
|
gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, pretrained_gpt_name))
|
||||||
|
sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, pretrained_sovits_name))
|
||||||
|
if isinstance(gpt_path, list):
|
||||||
|
gpt_path = gpt_path[0]
|
||||||
|
if isinstance(sovits_path, list):
|
||||||
|
sovits_path = sovits_path[0]
|
||||||
|
|
||||||
|
|
||||||
|
SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3", "SoVITS_weights_v4"]
|
||||||
|
GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3", "GPT_weights_v4"]
|
||||||
|
for path in SoVITS_weight_root + GPT_weight_root:
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
|
|
||||||
SoVITS_weight_root=["SoVITS_weights_v2","SoVITS_weights"]
|
|
||||||
GPT_weight_root=["GPT_weights_v2","GPT_weights"]
|
|
||||||
for path in SoVITS_weight_root+GPT_weight_root:
|
|
||||||
os.makedirs(path,exist_ok=True)
|
|
||||||
|
|
||||||
def get_weights_names(GPT_weight_root, SoVITS_weight_root):
|
def get_weights_names(GPT_weight_root, SoVITS_weight_root):
|
||||||
SoVITS_names = [i for i in pretrained_sovits_name]
|
SoVITS_names = [i for i in pretrained_sovits_name]
|
||||||
for path in SoVITS_weight_root:
|
for path in SoVITS_weight_root:
|
||||||
for name in os.listdir(path):
|
for name in os.listdir(path):
|
||||||
if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (path, name))
|
if name.endswith(".pth"):
|
||||||
|
SoVITS_names.append("%s/%s" % (path, name))
|
||||||
GPT_names = [i for i in pretrained_gpt_name]
|
GPT_names = [i for i in pretrained_gpt_name]
|
||||||
for path in GPT_weight_root:
|
for path in GPT_weight_root:
|
||||||
for name in os.listdir(path):
|
for name in os.listdir(path):
|
||||||
if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (path, name))
|
if name.endswith(".ckpt"):
|
||||||
|
GPT_names.append("%s/%s" % (path, name))
|
||||||
return SoVITS_names, GPT_names
|
return SoVITS_names, GPT_names
|
||||||
|
|
||||||
|
|
||||||
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
|
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
|
||||||
|
|
||||||
|
|
||||||
|
from process_ckpt import get_sovits_version_from_path_fast
|
||||||
|
|
||||||
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
|
v3v4set={"v3","v4"}
|
||||||
tts_pipeline.init_vits_weights(sovits_path)
|
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
|
||||||
global version, dict_language
|
global version, model_version, dict_language, if_lora_v3
|
||||||
dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
|
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
||||||
|
# print(sovits_path,version, model_version, if_lora_v3)
|
||||||
|
is_exist=is_exist_s2gv3 if model_version=="v3"else is_exist_s2gv4
|
||||||
|
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
|
||||||
|
if if_lora_v3 == True and is_exist == False:
|
||||||
|
info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重"%model_version)
|
||||||
|
gr.Warning(info)
|
||||||
|
raise FileExistsError(info)
|
||||||
|
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
|
||||||
if prompt_language is not None and text_language is not None:
|
if prompt_language is not None and text_language is not None:
|
||||||
if prompt_language in list(dict_language.keys()):
|
if prompt_language in list(dict_language.keys()):
|
||||||
prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
|
prompt_text_update, prompt_language_update = (
|
||||||
|
{"__type__": "update"},
|
||||||
|
{"__type__": "update", "value": prompt_language},
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
prompt_text_update = {'__type__':'update', 'value':''}
|
prompt_text_update = {"__type__": "update", "value": ""}
|
||||||
prompt_language_update = {'__type__':'update', 'value':i18n("中文")}
|
prompt_language_update = {"__type__": "update", "value": i18n("中文")}
|
||||||
if text_language in list(dict_language.keys()):
|
if text_language in list(dict_language.keys()):
|
||||||
text_update, text_language_update = {'__type__':'update'}, {'__type__':'update', 'value':text_language}
|
text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
|
||||||
else:
|
else:
|
||||||
text_update = {'__type__':'update', 'value':''}
|
text_update = {"__type__": "update", "value": ""}
|
||||||
text_language_update = {'__type__':'update', 'value':i18n("中文")}
|
text_language_update = {"__type__": "update", "value": i18n("中文")}
|
||||||
return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
|
if model_version in v3v4set:
|
||||||
|
visible_sample_steps = True
|
||||||
|
visible_inp_refs = False
|
||||||
|
else:
|
||||||
|
visible_sample_steps = False
|
||||||
|
visible_inp_refs = True
|
||||||
|
yield (
|
||||||
|
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||||
|
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||||
|
prompt_text_update,
|
||||||
|
prompt_language_update,
|
||||||
|
text_update,
|
||||||
|
text_language_update,
|
||||||
|
{"__type__": "update", "interactive": visible_sample_steps, "value": 32},
|
||||||
|
{"__type__": "update", "visible": visible_inp_refs},
|
||||||
|
{"__type__": "update", "interactive": True if model_version not in v3v4set else False},
|
||||||
|
{"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
|
||||||
|
)
|
||||||
|
|
||||||
|
tts_pipeline.init_vits_weights(sovits_path)
|
||||||
|
yield (
|
||||||
|
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||||
|
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||||
|
prompt_text_update,
|
||||||
|
prompt_language_update,
|
||||||
|
text_update,
|
||||||
|
text_language_update,
|
||||||
|
{"__type__": "update", "interactive": visible_sample_steps, "value": 32},
|
||||||
|
{"__type__": "update", "visible": visible_inp_refs},
|
||||||
|
{"__type__": "update", "interactive": True if model_version not in v3v4set else False},
|
||||||
|
{"__type__": "update", "value": i18n("合成语音"), "interactive": True},
|
||||||
|
)
|
||||||
|
with open("./weight.json") as f:
|
||||||
|
data = f.read()
|
||||||
|
data = json.loads(data)
|
||||||
|
data["SoVITS"][version] = sovits_path
|
||||||
|
with open("./weight.json", "w") as f:
|
||||||
|
f.write(json.dumps(data))
|
||||||
|
|
||||||
|
|
||||||
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>.")
|
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
|
||||||
|
+ "<br>"
|
||||||
|
+ i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
# with gr.Group():
|
# with gr.Group():
|
||||||
gr.Markdown(value=i18n("模型切换"))
|
gr.Markdown(value=i18n("模型切换"))
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
|
GPT_dropdown = gr.Dropdown(
|
||||||
SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
|
label=i18n("GPT模型列表"),
|
||||||
|
choices=sorted(GPT_names, key=custom_sort_key),
|
||||||
|
value=gpt_path,
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
SoVITS_dropdown = gr.Dropdown(
|
||||||
|
label=i18n("SoVITS模型列表"),
|
||||||
|
choices=sorted(SoVITS_names, key=custom_sort_key),
|
||||||
|
value=sovits_path,
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
|
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
|
||||||
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
|
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
|
||||||
|
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown(value=i18n("*请上传并填写参考信息"))
|
gr.Markdown(value=i18n("*请上传并填写参考信息"))
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath")
|
inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath")
|
||||||
inp_refs = gr.File(label=i18n("辅参考音频(可选多个,或不选)"),file_count="multiple")
|
inp_refs = gr.File(
|
||||||
|
label=i18n("辅参考音频(可选多个,或不选)"),
|
||||||
|
file_count="multiple",
|
||||||
|
visible=True if model_version != "v3" else False,
|
||||||
|
)
|
||||||
prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
|
prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
prompt_language = gr.Dropdown(
|
prompt_language = gr.Dropdown(
|
||||||
label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
|
label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
|
||||||
)
|
)
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
|
ref_text_free = gr.Checkbox(
|
||||||
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT,听不清参考音频说的啥(不晓得写啥)可以开,开启后无视填写的参考文本。"))
|
label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"),
|
||||||
|
value=False,
|
||||||
|
interactive=True if model_version != "v3" else False,
|
||||||
|
show_label=True,
|
||||||
|
)
|
||||||
|
gr.Markdown(
|
||||||
|
i18n("使用无参考文本模式时建议使用微调的GPT")
|
||||||
|
+ "<br>"
|
||||||
|
+ i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
|
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
|
||||||
text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=20, max_lines=20)
|
text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=20, max_lines=20)
|
||||||
@ -246,86 +389,158 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
|||||||
label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
|
label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
gr.Markdown(value=i18n("推理设置"))
|
gr.Markdown(value=i18n("推理设置"))
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
|
with gr.Row():
|
||||||
fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
|
batch_size = gr.Slider(
|
||||||
speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="speed_factor",value=1.0,interactive=True)
|
minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
|
||||||
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
|
)
|
||||||
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
|
sample_steps = gr.Radio(
|
||||||
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
|
label=i18n("采样步数(仅对V3/4生效)"), value=32, choices=[4, 8, 16, 32, 64, 128], visible=True
|
||||||
repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
|
)
|
||||||
|
with gr.Row():
|
||||||
|
fragment_interval = gr.Slider(
|
||||||
|
minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
|
||||||
|
)
|
||||||
|
speed_factor = gr.Slider(
|
||||||
|
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
|
||||||
|
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
|
||||||
|
with gr.Row():
|
||||||
|
temperature = gr.Slider(
|
||||||
|
minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
|
||||||
|
)
|
||||||
|
repetition_penalty = gr.Slider(
|
||||||
|
minimum=0, maximum=2, step=0.05, label=i18n("重复惩罚"), value=1.35, interactive=True
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
how_to_cut = gr.Dropdown(
|
how_to_cut = gr.Dropdown(
|
||||||
label=i18n("怎么切"),
|
label=i18n("怎么切"),
|
||||||
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
|
choices=[
|
||||||
value=i18n("凑四句一切"),
|
i18n("不切"),
|
||||||
interactive=True, scale=1
|
i18n("凑四句一切"),
|
||||||
)
|
i18n("凑50字一切"),
|
||||||
|
i18n("按中文句号。切"),
|
||||||
|
i18n("按英文句号.切"),
|
||||||
|
i18n("按标点符号切"),
|
||||||
|
],
|
||||||
|
value=i18n("凑四句一切"),
|
||||||
|
interactive=True,
|
||||||
|
scale=1,
|
||||||
|
)
|
||||||
|
super_sampling = gr.Checkbox(
|
||||||
|
label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
|
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
|
||||||
split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
|
split_bucket = gr.Checkbox(
|
||||||
|
label=i18n("数据分桶(并行推理时会降低一点计算量)"),
|
||||||
with gr.Row():
|
value=True,
|
||||||
seed = gr.Number(label=i18n("随机种子"),value=-1)
|
interactive=True,
|
||||||
|
show_label=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
seed = gr.Number(label=i18n("随机种子"), value=-1)
|
||||||
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
|
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
|
||||||
|
|
||||||
output = gr.Audio(label=i18n("输出的语音"))
|
output = gr.Audio(label=i18n("输出的语音"))
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
inference_button = gr.Button(i18n("合成语音"), variant="primary")
|
inference_button = gr.Button(i18n("合成语音"), variant="primary")
|
||||||
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
|
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
|
||||||
|
|
||||||
|
|
||||||
inference_button.click(
|
inference_button.click(
|
||||||
inference,
|
inference,
|
||||||
[
|
[
|
||||||
text,text_language, inp_ref, inp_refs,
|
text,
|
||||||
prompt_text, prompt_language,
|
text_language,
|
||||||
top_k, top_p, temperature,
|
inp_ref,
|
||||||
how_to_cut, batch_size,
|
inp_refs,
|
||||||
speed_factor, ref_text_free,
|
prompt_text,
|
||||||
split_bucket,fragment_interval,
|
prompt_language,
|
||||||
seed, keep_random, parallel_infer,
|
top_k,
|
||||||
repetition_penalty
|
top_p,
|
||||||
],
|
temperature,
|
||||||
|
how_to_cut,
|
||||||
|
batch_size,
|
||||||
|
speed_factor,
|
||||||
|
ref_text_free,
|
||||||
|
split_bucket,
|
||||||
|
fragment_interval,
|
||||||
|
seed,
|
||||||
|
keep_random,
|
||||||
|
parallel_infer,
|
||||||
|
repetition_penalty,
|
||||||
|
sample_steps,
|
||||||
|
super_sampling,
|
||||||
|
],
|
||||||
[output, seed],
|
[output, seed],
|
||||||
)
|
)
|
||||||
stop_infer.click(tts_pipeline.stop, [], [])
|
stop_infer.click(tts_pipeline.stop, [], [])
|
||||||
SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language])
|
SoVITS_dropdown.change(
|
||||||
|
change_sovits_weights,
|
||||||
|
[SoVITS_dropdown, prompt_language, text_language],
|
||||||
|
[
|
||||||
|
prompt_language,
|
||||||
|
text_language,
|
||||||
|
prompt_text,
|
||||||
|
prompt_language,
|
||||||
|
text,
|
||||||
|
text_language,
|
||||||
|
sample_steps,
|
||||||
|
inp_refs,
|
||||||
|
ref_text_free,
|
||||||
|
inference_button,
|
||||||
|
],
|
||||||
|
) #
|
||||||
GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
|
GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
|
||||||
|
|
||||||
with gr.Group():
|
with gr.Group():
|
||||||
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
|
gr.Markdown(
|
||||||
|
value=i18n(
|
||||||
|
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
|
||||||
|
)
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
|
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
_how_to_cut = gr.Radio(
|
_how_to_cut = gr.Radio(
|
||||||
label=i18n("怎么切"),
|
label=i18n("怎么切"),
|
||||||
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
|
choices=[
|
||||||
value=i18n("凑四句一切"),
|
i18n("不切"),
|
||||||
interactive=True,
|
i18n("凑四句一切"),
|
||||||
)
|
i18n("凑50字一切"),
|
||||||
cut_text= gr.Button(i18n("切分"), variant="primary")
|
i18n("按中文句号。切"),
|
||||||
|
i18n("按英文句号.切"),
|
||||||
|
i18n("按标点符号切"),
|
||||||
|
],
|
||||||
|
value=i18n("凑四句一切"),
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
cut_text = gr.Button(i18n("切分"), variant="primary")
|
||||||
|
|
||||||
def to_cut(text_inp, how_to_cut):
|
def to_cut(text_inp, how_to_cut):
|
||||||
if len(text_inp.strip()) == 0 or text_inp==[]:
|
if len(text_inp.strip()) == 0 or text_inp == []:
|
||||||
return ""
|
return ""
|
||||||
method = get_method(cut_method[how_to_cut])
|
method = get_method(cut_method[how_to_cut])
|
||||||
return method(text_inp)
|
return method(text_inp)
|
||||||
|
|
||||||
text_opt = gr.Textbox(label=i18n("切分后文本"), value="", lines=4)
|
text_opt = gr.Textbox(label=i18n("切分后文本"), value="", lines=4)
|
||||||
cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
|
cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
|
||||||
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
|
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
app.queue().launch(#concurrency_count=511, max_size=1022
|
app.queue().launch( # concurrency_count=511, max_size=1022
|
||||||
server_name="0.0.0.0",
|
server_name="0.0.0.0",
|
||||||
inbrowser=True,
|
inbrowser=True,
|
||||||
share=is_share,
|
share=is_share,
|
||||||
server_port=infer_ttswebui,
|
server_port=infer_ttswebui,
|
||||||
quiet=True,
|
# quiet=True,
|
||||||
)
|
)
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user